#!/usr/bin/env python
# -*- coding: utf-8 -*-

import sys
import os.path
import re

if __name__ == "__main__":
  # Running as a script: put <script dir>/../lib-python on the import path so
  # that the prosci package imported below can be found.
  _script_dir = os.path.abspath(os.path.dirname(sys.argv[0]))
  sys.path.append(os.path.join(_script_dir, os.pardir, "lib-python"))

from prosci.util.pdb import Pdb
from prosci.common import splitpath



# Record types copied verbatim into every per-chain output file.
# END is matched up to a word boundary so that an unpadded "END" line is kept
# as well as the fixed-width "END   " form (ENDMDL is listed separately).
re_default = re.compile(r"^(?:HEADER|TITLE|OBSLTE|COMPND|SOURCE|KEYWDS|EXPDTA|AUTHOR|REVDAT|SPRSDE|JRNL|REMARK|DBREF|SEQADV|FTNOTE|FORMUL|MODEL|ENDMDL|END\b)")
# Chain-specific record types; the trailing comment gives the 0-based index of
# the chain identifier within the line.
re_atom    = re.compile(r"^(?:ATOM|HETATM|SHEET|TER)") # chain ID at index 21
re_helix   = re.compile(r"^HELIX") # chain ID at index 19
re_seqres  = re.compile(r"^SEQRES") # chain ID at index 11
re_het     = re.compile(r"^HET    ") # chain ID at index 12

# (pattern, 0-based chain-ID column) pairs, tried in this order.
_chain_id_columns = (
  (re_atom, 21),
  (re_helix, 19),
  (re_seqres, 11),
  (re_het, 12),
)

def splitstructure(infile, prefix, chaincodes, extension=".pdb"):
  """Split PDB-format lines into one output file per chain.

  Lines matching re_default are written to every output file. ATOM, HETATM,
  SHEET, TER, HELIX, SEQRES and HET lines are written only to the file of the
  chain named in their chain-ID column. All other lines are dropped, as are
  chain records whose chain is not in *chaincodes*.

  infile     -- iterable of text lines (e.g. an open file)
  prefix     -- output filename prefix; chain code + extension are appended
  chaincodes -- chain identifiers to extract (duplicates are ignored)
  extension  -- output filename extension

  Returns the list of output filenames in first-seen chain order, or [] if
  there are no chain codes or *infile* is empty/false.
  """
  # Deduplicate while keeping the caller's order so the result is deterministic.
  codes = []
  for code in chaincodes:
    if code not in codes:
      codes.append(code)

  if not (codes and infile):
    return []

  outfnames = [prefix + code + extension for code in codes]
  outfiles = []
  try:
    # Opened inside the try so already-opened files are closed if a later
    # open() fails.
    for fname in outfnames:
      outfiles.append(open(fname, 'w'))
    by_code = dict(zip(codes, outfiles))

    for line in infile:
      if re_default.match(line):
        for f in outfiles:
          f.write(line)
        continue

      for pattern, column in _chain_id_columns:
        if pattern.match(line):
          # Short records (e.g. a bare "TER") have no chain-ID column; treat
          # them as chain ' ' instead of raising IndexError.
          record = line.rstrip("\r\n")
          chain = record[column] if len(record) > column else ' '
          target = by_code.get(chain)
          if target is not None:
            target.write(line)
          break
  finally:
    for f in outfiles:
      f.close()

  return outfnames



def writechainseq(a_chain, basename, pdb_code, chaincode, options):
    """Write one sequence file per requested format for a single chain.

    Formats are selected by the single-letter keys present in *options*
    and processed in sorted order:
      'a' -> <basename><chaincode>.ali   (header + structure line + sequence)
      'f' -> <basename><chaincode>.fasta (FASTA)
      'j' -> <basename><chaincode>.ali   (">P1;" header titled with pdb_code only)
      'm' -> <basename><chaincode>.ali   (">P1;" header + "structure:" line)
    'a', 'j' and 'm' share the same .ali path, so a later one overwrites an
    earlier one.

    a_chain must provide get_seq() and, for 'a'/'j', get_structure_lign().

    Returns the list of paths written, in processing order.
    """
    title = pdb_code + chaincode
    written = []

    for fmt in sorted(set(options) & set("ajmf")):
        if fmt == 'f':
            path = basename + chaincode + ".fasta"
            body = ">%s\n%s\n" % (title, a_chain.get_seq())
        else:
            path = basename + chaincode + ".ali"
            if fmt == 'a':
                body = ">%s\n%s\n%s\n" % (title, a_chain.get_structure_lign(), a_chain.get_seq())
            elif fmt == 'j':
                body = ">P1;%s\n%s\n%s*\n" % (pdb_code, a_chain.get_structure_lign(), a_chain.get_seq())
            else:
                body = ">P1;%s\nstructure:%s:FIRST:@:END:::::\n%s*\n" % (title, title, a_chain.get_seq())

        handle = open(path, 'w')
        handle.write(body)
        handle.close()

        written.append(path)

    return written



def splitchains(files, options, doprint=False):
  """Split each PDB file in *files* into per-chain structure/sequence files.

  files   -- iterable of PDB file paths
  options -- option container supporting ``in`` (and ``[]`` for "seqtitle"
             and "prefix"). Keys read here: 'p', 'a', 'j', 'm', 'f', 'F',
             "noappend", "list", "allowLigands", "prefix", "seqtitle".
  doprint -- if true, print the name of every output file that exists

  With "list" in options, only the chain codes of each file are printed.
  Output files that were expected but do not exist are reported on stderr.
  """
  seq_formats = set("ajmf") & set(options)

  for pdb_file in files:
    _path, basename, ext = splitpath(pdb_file)
    pdb_code = basename

    # NOTE(review): this handle is passed to Pdb and never closed here; it is
    # not known from this file whether Pdb reads it eagerly, so it is left open.
    a = Pdb(pdb_code, open(pdb_file), allowLigands=("allowLigands" in options))
    chains = a.get_chain_codes()

    if "list" in options:
      for c in chains:
        print(c)
      # Go on to the next input file (previously `return` ignored the rest).
      continue

    if "seqtitle" in options:
      pdb_code = options["seqtitle"]
    if "prefix" in options:
      basename = options["prefix"]

    outfiles = []

    # Don't split when there is no usable chain ID (none, or a single blank
    # one), or when there is a single chain and --noappend was given. In that
    # case, sequence files are written only when forced with -F.
    unsplittable = (not chains
                    or (len(chains) == 1 and (not chains[0] or chains[0] == " "))
                    or (len(chains) == 1 and "noappend" in options))
    if unsplittable:
      if 'F' in options and seq_formats:
        outfiles += writechainseq(a, basename, pdb_code, "", options)
    else:
      if 'p' in options:
        infile = open(pdb_file)
        try:
          outfiles += splitstructure(infile, basename, chains, ext)
        finally:
          infile.close()

      if seq_formats:
        for c in chains:
          outfiles += writechainseq(a.get_chain(c), basename, pdb_code, c, options)

    for f in outfiles:
      if os.path.isfile(f):
        if doprint:
          print(f)
      else:
        # write() takes a single string; format it before passing it in.
        sys.stderr.write("ERROR creating file: %s\n" % f)



if __name__ == "__main__":

    import prosci.shell

    ################################
    # Command line option handling #
    ################################

    params = prosci.shell.Params(
        allowed=['p', 'a', 'j', 'm', 'f', 'F', "noappend", "quiet", "list", "allowLigands"],
        withargument=["prefix", "seqtitle"])

    given = set(params.opts)

    # There is work to do if at least one output format was requested, or if
    # only the chain list was asked for.
    has_task = bool(set("pajmf") & given) or "list" in params.opts

    if not params.args or not has_task:
        # Missing input files or nothing to do: show usage and fail.
        usage = """
Split a PDB or ATM file into chains.
Output PDB structure files and/or ALI sequence files.

USAGE:
    """+params.scriptname+""" OPTIONS pdb_file...

SHORT OPTIONS:
    OUTPUT FORMAT
    You must specify at least one of the following:
    
    -p     Write chains into PDB files (file extension of input files will
           be used)
    -a     Write chains into iMembrane-style ALI files (.ali)
    -j     Write chains into JOY-style       ALI files (.ali)
    -m     Write chains into MODELLER-style  ALI files (.ali)
    -f     Write chains into FASTA files (.fasta)
    
    -F     Force writing .ali files, even when only a single, unnamed chain
           is present. PDB output will not be written, as it is identical
           to the input.
    
    --allowLigands
           Allow ligands in output. By default all ligands (HETATM) are removed
    
    --list
           List chains only. Don't produce any output.
    
    --prefix PREFIX
           Set filename prefix to PREFIX instead of the extensionless basename
           of the input file
    
    --seqtitle TITLE
           Set title of any sequence file entries to TITLE, instead of the
           extensionless basename of the input file.
    
    --noappend
           If only a single, named chain is present, write .ali file with
           same name as PDB file. Do not append chain identifier to name.
           PDB output will not be written, as it is identical to the input.
    
    --quiet
           Don't print the names of output files.
    
        \n"""
        sys.stdout.write(usage)
        sys.exit(1)

    # -a, -j and -m each write <prefix><chain>.ali, so at most one is allowed.
    if len(set("ajm") & given) > 1:
        conflict = """
The following options are mutually exclusive: -a, -j, -m
Please pick one of them.

For a list of all options, type:
    """+params.scriptname+"""
        \n"""
        sys.stderr.write(conflict)
        sys.exit(2)

    ###############
    # File output #
    ###############

    quiet = params.isOpt("quiet")
    splitchains(params.args, params.opts, doprint=not quiet)
