#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# iMembrane : Annotation database constructor
#
# Author: Sebastian Kelm
# Created: 10/05/2010
#


import sys
import os
import traceback
import subprocess
from glob import glob

# Make the bundled libraries importable: search the script's own directory
# first, then a "lib-python" directory one and two levels above it, and only
# afterwards the regular module search path.
sys.path = [
    os.path.join(os.path.abspath(os.path.dirname(sys.argv[0])), *subpath)
    for subpath in ((), ("..", "lib-python"), ("..", "..", "lib-python"))
] + sys.path

from prosci.common import NotFoundError
from prosci.util.pdb import Pdb
from prosci.util.ali import Ali
from prosci.util.seq import pid as sequence_identity
from prosci.shell import Params

from prosci.imembrane.projector import Projector, splitid, joinid
from prosci.util.getpdb import getpdb




def choose_representatives(dbdir, outstream):
    """Pick one representative simulation per PDB code and write a table.

    Every directory matching ``<dbdir>/entries/*`` is inspected. Its name is
    split on "."; the first field is used as the PDB code and the second as
    the simulation identifier. The entry's TM-score is the second
    whitespace-separated token of the first line of its ``scores.table``
    that starts with "TM-score".

    Entries are grouped by "<pdb>.<sim>" and their TM-scores averaged. For
    each PDB code the group with the HIGHEST mean TM-score is selected; ties
    are resolved in favour of the alphabetically first "<pdb>.<sim>".

    Args:
      dbdir: database directory containing an "entries" subdirectory.
      outstream: object with a ``write`` method. Receives one line
        "<pdb>\\t<pdb>.<sim>\\t<mean TM-score>\\n" per PDB code, ordered by
        PDB code.

    Raises:
      ValueError: if an entry's scores.table has no non-negative TM-score.
      IndexError: if an entry name contains no "." separator.
      IOError/OSError: if an entry's scores.table cannot be opened.
    """
    import os
    from glob import glob

    # PDB code -> {"<pdb>.<sim>": [sum of TM-scores, number of entries]}
    totals = {}

    for edir in glob(os.path.join(dbdir, "entries", "*")):
        fields = os.path.basename(edir).split(".")
        pdb = fields[0]
        sim = fields[1]

        table = os.path.join(edir, "scores.table")
        tmscore = None
        with open(table) as f:
            for line in f:
                if line.startswith("TM-score"):
                    tmscore = float(line.split()[1])
                    break
        # "not >=" (rather than "<") also rejects NaN.
        if tmscore is None or not tmscore >= 0:
            raise ValueError("No valid TM-score found in '%s'" % (table,))

        acc = totals.setdefault(pdb, {}).setdefault("%s.%s" % (pdb, sim), [0.0, 0])
        acc[0] += tmscore
        acc[1] += 1

    for pdb in sorted(totals):
        bestsim = None
        bestscore = -1.0

        # Sorted iteration makes the choice deterministic when means are equal.
        for pdbsim in sorted(totals[pdb]):
            total, count = totals[pdb][pdbsim]
            mean = total / count
            if mean > bestscore:
                bestscore = mean
                bestsim = pdbsim

        outstream.write("%s\t%s\t%f\n" % (pdb, bestsim, bestscore))




# Parse the command line. The script takes exactly three positional arguments
# and no options; anything else prints the usage text and exits with status 1.
params = Params()

if not (not params.opts and len(params.args) == 3):
    usage_head = """
iMembrane - database projection component

Input: iMembrane database with CGDB structures
Output: iMembrane database with PDB structures

JOY-style annotation types:
"membrane contact"   Residues are either in contact with the lipid tails
                 (T) or the polar head groups (H) or neither (N).
"membrane layer"     Residues are either in the lipid tail layer (T) or
                 in a peripheral polar head group layer (H) or not
                 within the membrane at all (N).

Residues that cannot be annotated are marked with question marks (?).


USAGE:
"""
    usage_tail = """ <imembranedb_directory> <projecteddb_directory> <PDB_mirror_directory>

  <imembranedb_directory>   Directory created by imem_builddb
  <projecteddb_directory>   Directory created by this script
  <PDB_mirror_directory>    Local mirror of the PDB, created by typing:

  rsync -a --port=33444 rsync.wwpdb.org::ftp_data/structures/divided/pdb/ <PDB_mirror_directory>
\n"""
    sys.stderr.write(usage_head + params.scriptname + usage_tail)
    sys.exit(1)


# Positional arguments: source database, output database, local PDB mirror.
# DBDIR is used exactly as given on the command line; the other two paths are
# converted to absolute paths.
DBDIR = params.args[0]
OUTDIR = os.path.abspath(params.args[1])
PDB_DIR = os.path.abspath(params.args[2])

# Sub-directories of the output database.
ENTRYDIR = OUTDIR + "/entries"
SEQDIR = OUTDIR + "/seq"

# A missing source database is fatal (exit status 2).
if not os.path.exists(DBDIR):
    sys.stderr.write("ERROR: Database directory does not exist: '%s'\n" % DBDIR)
    sys.exit(2)

# A missing PDB mirror is not fatal: only a warning is printed here.
if not os.path.exists(PDB_DIR):
    sys.stderr.write("WARNING: PDB directory does not exist: '%s'. Will attempt to fetch PDB files online.\n" % PDB_DIR)

# Entry codes collected while projecting; both lists are written out at the
# end of the run.
missing_entries, failed_entries = [], []

# Create the output database directory and its "entries" sub-directory,
# parent first, unless they already exist.
for needed_dir in (OUTDIR, ENTRYDIR):
    if not os.path.exists(needed_dir):
        os.mkdir(needed_dir)

# Main projection loop.
#
# Every entry "<DBDIR>/entries/*.*" is projected onto the matching chain of
# the corresponding PDB structure and written to ENTRYDIR. Entries whose PDB
# structure or chain cannot be found are recorded in missing_entries; entries
# whose projection raises an exception are recorded in failed_entries.

# PDB structure loaded most recently; reused while consecutive entries refer
# to the same PDB code (idfields[0]).
wholestruc = None
entrylist = sorted(glob("%s/entries/*.*"%(DBDIR)))
# PDB code for which getpdb() most recently raised NotFoundError.
obsolete = None
for i, entrycode in enumerate(entrylist):
    entrycode = os.path.basename(entrycode)

    # Progress report. The parenthesised single-argument form prints the same
    # text under the Python 2 print statement and the print() function.
    print("%d / %d : %s" % (i+1, len(entrylist), entrycode))

    idfields = splitid(entrycode)

    # The projected entry gets the same identifier with field 4 incremented by one.
    # NOTE(review): field 4 is presumably a projection/generation counter -- confirm against splitid().
    projectioncode = idfields[:]
    projectioncode[4] = str(int(projectioncode[4])+1)
    projectioncode = joinid(projectioncode)

    # Skip entries for which an annotation.tem already exists in the output database.
    if os.path.exists("%s/%s/annotation.tem"%(ENTRYDIR, projectioncode)):
        continue

    # Load the PDB structure unless the one in memory has the right code.
    if not (wholestruc and wholestruc.code == idfields[0]):
        if obsolete == idfields[0]:
            # Already known to be unavailable; do not try to fetch it again.
            continue

        try:
            wholestruc = getpdb(idfields[0], PDB_DIR, online=True)
        except NotFoundError:
            obsolete = idfields[0]
            sys.stderr.write("WARNING: PDB entry '%s' not found. Structure obsolete?\n"%(idfields[0]))
            continue

    if not wholestruc:
        missing_entries.append(entrycode)
        sys.stderr.write("WARNING: PDB entry '%s' seems to be empty\n"%(idfields[0]))
        continue

    struc = wholestruc.get_chain(idfields[2])
    if not struc:
        if idfields[2] != " ":
            missing_entries.append(entrycode)
            sys.stderr.write("WARNING: PDB entry '%s' does not contain chain '%s'.\n"%(idfields[0], idfields[2]))
            continue
        else:
            # CGDB has not labelled this chain properly.
            # Align CGDB chain to all PDB chains and find the one with the best sequence identity.
            #
            with open("%s/entries/%s/contact.atm"%(DBDIR, entrycode)) as f:
                dbseq = Pdb(entrycode, f).get_seq()
            best_seq_identity = 0.0
            best_coverage = 0.0
            best_chain = None
            for c in wholestruc.xchains():
                s = c.get_seq()
                alignment = Ali(">cgdb\nsequence\n%s\n>pdbchain\nsequence\n%s\n"%(dbseq, s))
                alignment.align()

                cgdb_seq = alignment["cgdb"][0].seq
                pdb_seq = alignment["pdbchain"][0].seq

                # Calculate sequence identity, normalised by length of the PDB chain.
                # This will catch cases where CGDB has merged multiple chains into one unlabelled chain.
                # Just in case, calculate coverage of cgdb chain too and use as a second discriminator of the best match.
                #
                seq_identity = sequence_identity(cgdb_seq, pdb_seq, mode="second")[0]
                cgdb_coverage = sequence_identity(cgdb_seq, pdb_seq, mode="first")[1]
                if seq_identity > best_seq_identity or (seq_identity == best_seq_identity and cgdb_coverage > best_coverage):
                    best_seq_identity = seq_identity
                    best_coverage = cgdb_coverage
                    best_chain = c

            if best_seq_identity < 0.9:
                # No chain reaches 90% sequence identity. Just drop this entry.
                missing_entries.append(entrycode)
                sys.stderr.write("WARNING: PDB entry '%s' does not contain chain '%s' and no chain with >90%% sequence identity.\n"%(idfields[0], idfields[2]))
                continue

            struc = best_chain

            # Fix the chain code in the new database
            projectioncode = splitid(projectioncode)
            projectioncode[2] = struc[0].chain
            projectioncode = joinid(projectioncode)

            # Re-check if we should skip this entry because we already have the result.
            # NOTE(review): this tests for the entry directory only, whereas the check
            # above tests for annotation.tem inside it -- confirm this is intended.
            if os.path.exists("%s/%s"%(ENTRYDIR, projectioncode)):
                continue

            sys.stderr.write("NOTICE: projection of entry '%s' saved as '%s'.\n"%(entrycode, projectioncode))

    try:
        p = Projector(DBDIR, entrycode, projectioncode, None, struc)
        p.annotate()
    except Exception:
        # Only real errors count as a failed entry; KeyboardInterrupt and
        # SystemExit propagate instead of being recorded here.
        failed_entries.append(entrycode)

        # Print the exception with its traceback (once) to stderr.
        traceback.print_exc()

        # NOTE(review): the run stops at the first failing entry, so
        # failed_entries never holds more than one code and later entries are
        # not processed -- confirm whether 'continue' was intended.
        break

    p.write_output(ENTRYDIR, link_parent=True, copy_parent=False, write_fasta=True)



# Finalisation: write the bookkeeping lists, the representatives table and the
# sequence files / BLAST database of the projected database.

if not os.path.exists(SEQDIR):
    os.mkdir(SEQDIR)

# One entry code per line. missing.list holds entries without a usable PDB
# structure or chain, failed.list those whose projection raised an exception.
for list_name, codes in (("missing.list", missing_entries),
                         ("failed.list", failed_entries)):
    with open("%s/%s" % (SEQDIR, list_name), 'w') as f:
        f.write("\n".join(codes))
        f.write("\n")

# Choose a single simulation per PDB complex as a representative.
# The one with the highest average TM-score is chosen.
#
with open(os.path.join(OUTDIR, "representatives.table"), "w") as f:
    choose_representatives(OUTDIR, f)

# The commands below use file names relative to SEQDIR.
os.chdir(SEQDIR)
shell_commands = (
    # Concatenate the per-entry FASTA files into one file.
    'cat "%s/"*/sequence.fasta > chains.fasta' % (ENTRYDIR),
    # Extract the FASTA header lines.
    "grep '>' chains.fasta > chains.list",
    # Build a protein BLAST database named "chains".
    "makeblastdb -in chains.fasta -out chains -dbtype prot",
    # Make the database directories and sequence files world-readable.
    'chmod a+rx "%s" "%s" "%s"' % (OUTDIR, ENTRYDIR, SEQDIR),
    'chmod a+r "%s/"*' % (SEQDIR),
)
for command in shell_commands:
    # The shell is needed for the glob patterns and redirections.
    status = subprocess.call(command, shell=True)
    if status != 0:
        # Keep going as before, but do not hide the failure.
        sys.stderr.write("WARNING: command exited with status %d: %s\n" % (status, command))