#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# MEDELLER : template-based membrane protein structure prediction
#
# Author: Sebastian Kelm
# Created: 10/09/2009
#
#



import sys, os.path
#import re
from array import array

if __name__ == "__main__":
  # Running as a script: extend the module search path *before* the prosci
  # imports below, so the package can be found without being installed.
  _script_dir = os.path.abspath(os.path.dirname(sys.argv[0]))
  # This file's own directory ...
  sys.path.append(_script_dir)
  # ... the sibling "lib-python" directory ...
  sys.path.append(os.path.join(_script_dir, "..", "lib-python"))
  # ... and every directory listed in the PATH environment variable.
  # os.pathsep is used instead of a hard-coded ':' so the split is also
  # correct on platforms with a different separator, and a missing or empty
  # PATH is skipped instead of raising KeyError.
  _path_env = os.environ.get('PATH')
  if _path_env:
    sys.path.extend(_path_env.split(os.pathsep))

from prosci.common import *
from prosci.util.pdb import Pdb, intersectPdb, diffPdb
from prosci.util.pdb3d import compare_structures
from prosci.util.residue import ResidueList
from prosci.util.ali import Ali
from prosci.util.gaps import deGappify
from prosci.medeller.fread import mapCore2TargetSeq



def check_sequences(models):
    """Check that all models agree on their amino-acid sequence.

    Every model is rendered as a gapped sequence over the residue-number
    range spanned by all models together (smallest ``ires`` of any first
    element up to the largest ``ires`` of any last element). The rendered
    sequences must all have the same length, and at every position all
    non-gap characters ('-' is the gap character) must be identical.

    Arguments:
      models -- non-empty sequence of non-empty model objects. Each model
                must support indexing with 0 and -1 (yielding objects with
                an ``ires`` attribute) and provide
                ``get_seq(gapped=..., firstres=..., lastres=...)``.

    Returns:
      A tuple ``(firstres, lastres, seqs)`` where ``seqs`` is the list of
      gapped sequences, in the same order as ``models``.

    Raises:
      ValueError -- if ``models`` is empty, if the gapped sequences differ
                    in length, or if two models disagree at any position.
    """
    if not models:
        raise ValueError("No models given")

    # Residue-number range covering every model
    firstres = min(m[0].ires for m in models)
    lastres = max(m[-1].ires for m in models)

    seqs = [m.get_seq(gapped=True, firstres=firstres, lastres=lastres)
            for m in models]

    expected_length = len(seqs[0])
    for s in seqs:
        if len(s) != expected_length:
            raise ValueError("Model sequence lengths don't match\n"+"\n".join(seqs)+"\n")

    # Walk the alignment column by column. All lengths are equal at this
    # point, so zip() does not drop any position.
    for column in zip(*seqs):
        letters = set(column)
        letters.discard('-')
        if len(letters) > 1:
            raise ValueError("Model sequences don't match: "+str(letters)+"\n"+"\n".join(seqs)+"\n")

    return firstres, lastres, seqs



def basename_noext(fname):
  """Return the final path component of fname with its last extension removed."""
  leaf = os.path.basename(fname)
  stem, _extension = os.path.splitext(leaf)
  return stem


if __name__ == "__main__":
    # Command-line entry point.
    #
    # The first positional argument is the native structure, all following
    # ones are models. The structures are cut down to a common set of
    # residues (or, with --imembrane, to the transmembrane domain of the
    # native structure), every model is compared to the native structure and
    # the results are printed to standard output as tab-separated lines of
    # the form "<label>\t<score name>\t<value>".

    import prosci.shell

    ################################
    # Command line option handling #
    ################################

    params = prosci.shell.Params(
       allowed=['help', 'singlechain', 'renumber', 'writemodels', 'fullpath'],
       required=[],
       withargument=['imembrane', 'completeseq', 'completeseqid'],
       helpoption='help',
       usage="""
USAGE:
  
  $0 [OPTIONS] <native> <model1> [model2 ...]
      \n""",

       help="""
OPTIONS:
    --help
      Display this message and exit.
      
    --renumber
      Renumber the atoms in the native structure
    
    --singlechain
      Assume only a single chain is present in each model.
      Rename all chains to "A".
      Use this when some modelling programs don't name their chains properly!
    
    --writemodels
      Write cut-down models to files, adding the suffix ".subset" to the end of
      the file name.
    
    --imembrane TEMFILE
      Provide iMembrane annotation, in order to cut down the models to the
      transmembrane domain, instead of the common residues.
    
    --completeseq ALIFILE
      Provide the complete sequence of the protein, which can be used to map
      all the models and native structure onto a consistent residue numbering
      scheme.
    
    --completeseqid ID
      Provide the ID that identifies the complete sequence in the above ALIFILE.
    
    --fullpath
      Output scores labelled with the full model path rather than just the meaningful part of the filename.
      \n""")

    if len(params.args) < 2:
        sys.stderr.write("Incorrect usage. Must provide at least 2 pdb files.\n")
        params.writeUsage()
        # A usage error is a failure: signal it with a non-zero exit status
        # so that calling scripts can detect it.
        sys.exit(1)

    ##################
    # Load the input #
    ##################

    # models[0] is the native structure, models[1:] are the models
    models = [Pdb(x) for x in params.args]
    if params.isOpt("fullpath"):
        # Label the output lines with the path exactly as given on the
        # command line
        for m, x in zip(models, params.args):
            m.code = x

    if params.isOpt("renumber"):
        models[0].renumber()
        if params.isOpt("writemodels"):
            out = open(params.args[0]+".renumbered", "w")
            try:
                out.write(str(models[0]))
            finally:
                # Close the file even if writing fails
                out.close()

    if params.isOpt("singlechain"):
        for m in models:
            for atm in m:
                atm.chain = "A"

    if params.isOpt("completeseq"):
        # Map the native structure and all models onto the numbering of the
        # complete sequence
        complete_ali = Ali(params.getOpt("completeseq"))
        if params.isOpt("completeseqid"):
            complete_seq = deGappify(complete_ali[params.getOpt("completeseqid")].master.seq)
        else:
            # No ID given - use the first entry of the alignment file
            complete_seq = deGappify(complete_ali[0].master.seq)
        del complete_ali

        mapCore2TargetSeq(complete_seq, models[0], targetid="complete", modelid="native", renumber=True)
        for i, m in enumerate(models[1:]):
            mapCore2TargetSeq(complete_seq, m, targetid="complete", modelid="model%d"%i, renumber=True)

    # Raises ValueError if the structures disagree on their sequence
    check_sequences(models)
    sys.stderr.write("Full model sequences ok\n")

    ################################################
    # Cut down the models to the relevant residues #
    ################################################

    if params.isOpt("imembrane"):
        # We have iMembrane annotation - cut down models to the TM domain
        #
        native_residues = ResidueList(models[0])
        annotation = Ali(params.getOpt("imembrane"))
        annotation.remove_gaps()
        if annotation[0].master.seq != native_residues.get_seq():
            raise RuntimeError("Sequence in iMembrane TEM file does not match sequence in native PDB file\n%s\n%s\n" % (annotation[0].master.seq, native_residues.get_seq()))
        membrane_layer = annotation[0]['membrane layer'].seq
        # Explicit check instead of an assert, which would be skipped when
        # Python runs with -O
        if len(native_residues) != len(membrane_layer):
            raise RuntimeError("Membrane layer annotation length (%d) does not match the number of native residues (%d)\n" % (len(membrane_layer), len(native_residues)))

        # Collect the residues annotated with layer "T" or "H", together with
        # the loops between them. A long loop splits the selection into
        # separate domains.
        #
        # all_tmdomains : list of [residues, number of "T" residues] pairs
        # first         : index of the first "T"/"H" residue (-1 = none yet)
        # last          : index just past the latest "T"/"H" residue
        all_tmdomains = []
        tmdomain = []
        tails = 0
        first = -1
        last = -1
        for i, (res, lay) in enumerate(zip(native_residues, membrane_layer)):
            if lay in "TH":
                if last >= 0:
                    if i-last < 30:
                        # Gap is short - include all of it
                        tmdomain.extend(native_residues[last:i])
                    else:
                        # Gap is long - include only the first and last 15
                        # residues, and start a new domain in between
                        tmdomain.extend(native_residues[last:last+15])
                        all_tmdomains.append([tmdomain, tails])
                        tmdomain = []
                        tails = 0
                        tmdomain.extend(native_residues[i-15:i])
                tmdomain.append(res)
                if lay == "T":
                    tails += 1
                last = i+1
                if first < 0:
                    first = i

        # Without any "T"/"H" residue there is nothing to cut down to (and
        # all_tmdomains would be empty below)
        if first < 0:
            raise RuntimeError("iMembrane annotation does not contain any residues in layer 'T' or 'H'\n")

        if tmdomain:
            all_tmdomains.append([tmdomain, tails])

        # Add 15 residues on either side
        all_tmdomains[0][0] = native_residues[max(0,first-15):first] + all_tmdomains[0][0]
        all_tmdomains[-1][0].extend(native_residues[last:last+15])

        # Accept only domains with enough Tail layer residues and with a minimum length
        tmdomain = []
        for residues, tails in all_tmdomains:
            if tails > 15 and len(residues) >= 40:
                tmdomain.extend(residues)

        tmdomain = ResidueList(tmdomain).to_pdb()

        check_sequences([models[0], tmdomain])

        # Keep only those regions common between each model and the cut-down native structure
        reduced_models = [tmdomain]
        for m in models[1:]:
            # Only the cut-down model is kept; the native side of the
            # intersection is not used
            _reduced_native, reduced_m = intersectPdb([tmdomain, m])
            check_sequences([models[0], tmdomain, m, reduced_m])
            reduced_models.append(reduced_m)
    else:
        # We don't have iMembrane annotation - cut down models to overlapping region
        #
        reduced_models = intersectPdb(models)

    check_sequences(reduced_models)
    sys.stderr.write("Reduced model sequences ok\n")

    ##########################
    # Compare and report     #
    ##########################

    # From here on, reduced_models holds only the models
    native = reduced_models.pop(0)
    native_seq = native.get_seq(gapped=True)
    # Residue count of the cut-down native structure (not of the full native
    # structure models[0]); used to normalise the scores below
    complete_rescount = native.rescount()

    if not params.isOpt("imembrane"):
        sys.stdout.write("all\tcommon_residues\t%g\n" % (native.rescount()))
        sys.stdout.write("all\tcommon_coverage\t%g\n" % (float(native.rescount())/models[0].rescount()))

    sys.stdout.write("all\tcomplete_length\t%g\n" % (models[0].rescount()))
    sys.stdout.write("all\tcropped_length\t%g\n" % (native.rescount()))

    if params.isOpt("writemodels"):
        for model, fname in zip([native]+reduced_models, params.args):
            out = open(fname+".subset", "w")
            try:
                out.write(str(model))
            finally:
                # Close the file even if writing fails
                out.close()

    for m, fname in zip(reduced_models, params.args[1:]):
        # Render the model over the native structure's residue range, so
        # that both sequences line up position by position
        m_seq = m.get_seq(gapped=True, firstres=native[0].ires, lastres=native[-1].ires)
        if len(native_seq) != len(m_seq):
            raise ValueError("Native and model sequences differ in length.\nNative:%s\nModel :%s\n"%(native_seq, m_seq))
        scores = compare_structures(native, m, native_seq, m_seq, normalise_by=complete_rescount)
        sys.stdout.write("%s\t%s\t%g\n" % (m.code, "length", m.rescount()))
        sys.stdout.write("%s\t%s\t%g\n" % (m.code, "coverage", float(m.rescount())/complete_rescount))
        for k in sorted(scores):
            sys.stdout.write("%s\t%s\t%g\n" % (m.code, k, scores[k]))
