
import sys
import os
import math
import subprocess as sub
from glob import glob
#for timestampingfile
from time import time
from datetime import datetime
##for writing a csv file
import csv
import numpy
import argparse
 
from prosci.util.pdb import Pdb
from prosci.util.residue import ResidueList
from prosci.util.ali import Ali
from prosci.util.pdb3d import fit_line, dist
from prosci.util.gaps import map_gaps
from prosci.util.seq import pid

#############################################################
# Some definitions
#############################################################
def isGap(c):
  """Return True when *c* occurs in the gap alphabet "-/".

  The check is a plain substring test against the two gap symbols,
  '-' and '/'.
  """
  gap_symbols = "-/"
  return c in gap_symbols

def isLoop(c):
  """Return True for every secondary-structure code other than "H".

  An earlier, narrower definition only accepted the codes in "CPE";
  the current rule treats anything that is not "H" as loop.
  """
  if c == "H":
    return False
  return True

def angle(crds1, crds2):
  """Return the angle, in radians, between the vectors *crds1* and *crds2*.

  The angle is acos(a.b / (|a| |b|)).  For parallel or antiparallel
  vectors, floating-point rounding can leave the cosine marginally
  outside [-1, 1] (e.g. 1.0000000000000002), which makes math.acos raise
  "math domain error".  The cosine is therefore clamped to the valid
  range before taking the arc cosine.
  """
  cosine = numpy.dot(crds1, crds2) / (numpy.linalg.norm(crds1) *
                                      numpy.linalg.norm(crds2))
  return math.acos(numpy.clip(cosine, -1.0, 1.0))



def cylinder_fit_c(n_linepoints, n_unitvector, fragment, num_atoms):
  """Fit a cylinder to *fragment* by running the external ./cylinder program.

  Parameters:
    n_linepoints -- starting point on the axis; indexed as [0], [1], [2]
                    (x, y, z), typically taken from a least-squares line fit.
    n_unitvector -- starting axis direction (x, y, z).
    fragment     -- object whose get_coords() yields per-atom (x, y, z).
    num_atoms    -- number of atoms from get_coords() passed to the program.

  Returns a numpy array of 9 floats:
    [0:3] point on the fitted axis, [3:6] axis direction (flipped if needed
    so that it points the same way as n_unitvector), [6] the seventh value
    printed by the program (NOTE(review): presumably the radius -- confirm
    against the ./cylinder source), [7] rmsd of the final fit, [8] rmsd of
    the initial guess.

  If the program reports failure (output starting with '*') the starting
  line point and direction are returned instead and entries [6:9] stay 0.

  Raises ValueError when the program output holds fewer than 9 numbers.
  subprocess errors (missing binary, non-zero exit) propagate unchanged.
  """
  cylinder = numpy.zeros([9])

  # Command line: program name, atom count, line point (x y z),
  # direction vector (x y z), then x y z for each of the atoms.
  coords = fragment.get_coords()  # fetched once instead of 3x per atom
  command_string = ["./cylinder", str(num_atoms)]
  command_string.extend(str(n_linepoints[k]) for k in range(3))
  # line point first, then the vector -- the order matters
  command_string.extend(str(n_unitvector[k]) for k in range(3))
  for atom in range(num_atoms):
    command_string.extend(str(coords[atom][axis]) for axis in range(3))

  # stdout holds the line point, the vector and the deviations as
  # whitespace-separated numbers.
  p = sub.check_output(command_string)
  if not isinstance(p, str):
    # check_output may hand back bytes rather than text
    p = p.decode()

  if p.startswith('*'):
    # NOTE: this diagnostic reads the module-level names pdb, start, end
    # and i that the calling script defines; they are not parameters.
    print('fit failed for vector ' + str(pdb[start+i].CA.ires) +
          ' in helix ' +
          str(pdb[start].CA.ires) + ' , ' + str(pdb[end-1].CA.ires) +
          '. reverting to least squares fit')
    # Fall back to the least-squares starting values.  The full (x, y, z)
    # line point is copied, not just its first component.
    cylinder[0:3] = [n_linepoints[k] for k in range(3)]
    cylinder[3:6] = n_unitvector
    return cylinder

  # Split on whitespace so a missing trailing space or a final newline
  # does not break the parse.
  values = [float(token) for token in p.split()[:9]]
  if len(values) != 9:
    raise ValueError('expected 9 numbers from ./cylinder, got %d: %r'
                     % (len(values), p))
  cylinder[:] = values

  # check that vectors are going in the same direction
  if numpy.dot(cylinder[3:6], n_unitvector) < 0:
    cylinder[3:6] = -cylinder[3:6]
    print('##############inverted vector #######################')

  # The last two values are summed squared deviations; convert them to
  # rmsds (index 7: final fit, index 8: initial guess).
  cylinder[7:9] = numpy.sqrt(cylinder[7:9]/num_atoms)
  return cylinder


# ---------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------
# A loop is bridged when the gap to the previous SSE is below this value.
max_loop_length = 3
# Number of residues in each fragment that gets an axis fitted to it.
helix_vector_length = 6
pdb_file = '3EMLA.atm'
tem_file = '3EMLA.annotation.tem'

soluble = False
pdbfile = '3EMLA.atm'
temfile = '3EMLA.annotation.tem'

# ---------------------------------------------------------------------
# Load the structure and its annotation
# ---------------------------------------------------------------------
pdb = ResidueList(Pdb(pdbfile).get_backbone())
# Despite the name, this list is built from get_backbone_no_O().
pdb_no_c_alpha = ResidueList(Pdb(pdbfile).get_backbone_no_O())
tem = Ali(temfile)

first_entry = tem[0]
sequence = first_entry[0].seq
sstruc = first_entry["DSSP"].seq
# Soluble proteins have no membrane annotation: mark every position 'N'.
layers = 'N' * len(sequence) if soluble else first_entry["membrane layer"].seq


# Hydrogen-bond annotation tracks
hbmmc = first_entry["Mainchain to mainchain hydrogen bonds (carbonyl)"].seq
hbmma = first_entry["mainchain to mainchain hydrogen bonds (amide)"].seq
hbh = first_entry["hydrogen bond to heterogen"].seq
hbmc = first_entry["hydrogen bond to mainchain CO"].seq
hbmn = first_entry["hydrogen bond to mainchain NH"].seq
hbsh = first_entry["hydrogen bond to other sidechain/heterogen"].seq

# Build the list of secondary structure elements (SSEs) as [first, last+1]
# index pairs.  Position 0 is never visited because each step looks back
# at residue i-1.
sstruc_bounds = []
for i in xrange(1, len(sstruc)):
    # Skip alignment gaps and anything that is not helix.
    if isGap(sequence[i]) or isLoop(sstruc[i]):
        continue
    # Skip positions whose layer is "N" unless the protein is soluble.
    if not (layers[i] != "N" or soluble == True):
        continue

    # Case 1: break in the protein chain (consecutive CA atoms more than
    # 4 apart) -- always open a new element.
    if dist(pdb[i-1].CA, pdb[i].CA) > 4:
        sstruc_bounds.append([i, i+1])
        continue

    # Case 2: the previous element already reaches this position, so we
    # are simply extending it.
    if sstruc_bounds and sstruc_bounds[-1][1] >= i:
        sstruc_bounds[-1][1] = i+1
        continue

    # Case 3: there is no element yet, or there is a gap since the last
    # one.  A short gap is bridged (the tiny loop is absorbed into the
    # previous element) provided, for membrane proteins, that no "H" layer
    # position lies inside the gap.
    bridges_small_loop = (
        sstruc_bounds and
        i - sstruc_bounds[-1][1] < max_loop_length and
        (soluble == True or "H" not in layers[sstruc_bounds[-1][1]:i])
    )
    if bridges_small_loop:
        sstruc_bounds[-1][1] = i+1
    else:
        # Really a new secondary structure element
        sstruc_bounds.append([i, i+1])
       



# Index range [start, end) of the helix that is analysed.
start = 4
end = 31

helix = pdb[start:end]
helix_no_c_alpha = pdb_no_c_alpha[start:end]
print("%s %s" % (start, end))  # so we know where we are

# Bookkeeping for the kink search
maxangle = 0.0       # the biggest angle
maxangle2 = 0.0      # the wobble at the biggest angle
maxpos = -1          # position of the kinks
kinkpos_helix = -1   # position relative to helix
angles = []          # angles of the helix
angles2 = []         # wobble angles
distance_to_H = []   # distance to TH for each of the residues

# One row per sliding-window position along the helix.
n_fragments = end - start - helix_vector_length + 1
unitvector = numpy.zeros((n_fragments, 3))
cylinder = numpy.zeros((n_fragments, 9))
cylinder_all = numpy.zeros((n_fragments, 9))
##########################################################################
# Slide a window of helix_vector_length residues along the helix and fit
# an axis to each window.  The loop variable must stay named i because
# cylinder_fit_c reads it (with start and end) in its failure message.
for i in xrange(n_fragments):
    window_end = i + helix_vector_length
    fragment = helix[i:window_end]
    fragment_no_c_alpha = helix_no_c_alpha[i:window_end]
    # Least-squares line through the first 21 coordinates gives a sensible
    # starting point for the cylinder fit.
    n_unitvector, n_linepoints = fit_line(fragment.get_coords()[0:21])
    # Cylinder fits seeded with that line: 18 atoms from the reduced
    # backbone, 24 atoms from the full backbone.
    cylinder[i] = cylinder_fit_c(n_linepoints[0], n_unitvector,
                                 fragment_no_c_alpha, 18)
    cylinder_all[i] = cylinder_fit_c(n_linepoints[0], n_unitvector,
                                     fragment, 24)
    # Keep the fitted axis direction of the reduced-backbone fit.
    unitvector[i] = cylinder[i][3:6]
 



