#!/usr/bin/env python2
import sys
from termcolor import colored
import shlex
from subprocess import Popen, PIPE
import itertools
import re
import numpy as np
import os
from shutil import copy2
import matplotlib.pyplot as plt
import json
from math import *
from collections import OrderedDict
import csv
import argparse
def GetDuckDir():
    """Return the absolute directory that contains this script.

    The script path is passed through ``os.path.realpath`` first, so the
    result refers to the real file location even when the script is reached
    through a symbolic link.
    """
    script_path = os.path.realpath(__file__)
    duck_dir, _script_name = os.path.split(script_path)
    return duck_dir

def nNucl(molbaselines):
    """Return the first whitespace-separated field of the second line as a float.

    molbaselines -- the lines of a molecule file, one string per line.
    Judging by the function name this field is the number of nuclei
    (assumption; the file format is not described in this script).
    """
    second_line = molbaselines[1]
    first_field = second_line.split()[0]
    return float(first_field)

def isMononucle(molbaselines):
    """Return True when nNucl() reports exactly one for these molecule lines."""
    count = nNucl(molbaselines)
    return count == 1

def openfileindir(path,readwrite):
    """Open *path* in mode *readwrite*, creating its parent directory if needed.

    A path without a directory component is opened directly.  Returns the
    open file object; closing it is the caller's responsibility.
    """
    parent = os.path.dirname(path)
    parent_missing = parent != "" and not os.path.exists(parent)
    if parent_missing:
        os.makedirs(parent)
    return open(path, readwrite)
def outfile(Outdic,item,index=None):
    """Open the output file configured for *item*, or return None if disabled.

    Outdic -- mapping holding one entry per output item (each with the keys
              "Enabled" and "Format") and optionally a "Parent" directory.
    item   -- key of the output item inside Outdic.
    index  -- when not None, it is substituted into "Format" with
              str.format to build the file name.

    The file is opened for writing through openfileindir(), so a missing
    parent directory is created.
    """
    settings = Outdic[item]
    if not settings["Enabled"]:
        return None
    filename = settings["Format"]
    if index is not None:
        filename = filename.format(index)
    target = os.path.join(Outdic["Parent"], filename) if "Parent" in Outdic else filename
    return openfileindir(target, 'w')

def runDuck(mol,basis,x,molbaselines,molbase,basisbase):
    """Run GoDuck once for the scan value *x* and return its result.

    A temporary molecule file ``<molbase><mol>.<x>`` and a copy of the basis
    file ``<basisbase><mol>.<x>.<basis>`` are created inside the directory
    returned by GetDuckDir(), ``./GoDuck <mol>.<x> <basis>`` is executed
    there, and both temporary files are removed afterwards.

    mol          -- molecule name.
    basis        -- basis-set name.
    x            -- scan value (substituted into the molecule file).
    molbaselines -- lines of the reference molecule file.
    molbase      -- path prefix of molecule files, relative to the GoDuck dir.
    basisbase    -- path prefix of basis files, relative to the GoDuck dir.

    Molecule file rewriting: the first three lines are copied unchanged.
    For a one-nucleus system only the fourth line is kept, with its first
    field replaced by x.  Otherwise each following line keeps its first
    field and gets the coordinates ``0. 0. +|x|/2`` (fourth line) or
    ``0. 0. -|x|/2`` (fifth line); a sixth line raises IndexError.

    Returns (exit code, stdout, stderr).  Only stdout is piped, so the
    third element is always None and GoDuck's stderr goes straight to the
    terminal.

    The working directory is restored and the temporary files are removed
    even when an exception is raised.
    """
    molname='.'.join([mol,str(x)])
    # Loop-invariant; only evaluated when there are lines past the header.
    mononuclear=len(molbaselines)>3 and isMononucle(molbaselines)
    lstw=list()
    for i,line in enumerate(molbaselines):
        if i<3:
            lstw.append(line)
        elif mononuclear:
            if i==3:
                lstw.append(' '.join([str(x)]+line.split()[1:]))
        else:
            half=float(abs(x))/2.0
            coords=[half,-half]
            lstw.append(' '.join([line.split()[0],'0.','0.',str(coords[i-3])]))
    currdir=os.getcwd()
    junkfiles=list()
    os.chdir(GetDuckDir())
    try:
        # Generate the molecule file for this scan value.
        molfile=molbase+molname
        junkfiles.append(molfile)
        with open(molfile,'w') as n:
            n.write(os.linesep.join(lstw))
        # Copy the basis file under the name matching the temporary molecule.
        basisfile=basisbase+'.'.join([mol,basis])
        newbasisfile=basisbase+'.'.join([molname,basis])
        junkfiles.append(newbasisfile)
        copy2(basisfile,newbasisfile)
        # Start the GoDuck child process; an argv list avoids any re-parsing
        # of names that contain spaces or quotes.
        Duck=Popen(["./GoDuck",molname,basis],stdout=PIPE)
        (DuckOut, DuckErr) = Duck.communicate()
        excode=Duck.returncode
    finally:
        # Registered paths may not exist if a step above failed early.
        for junk in junkfiles:
            if os.path.exists(junk):
                os.remove(junk)
        os.chdir(currdir)
    return (excode,DuckOut,DuckErr)

def addvalue(dic,key,x,y):
    """Append *y* to the list stored under *key* in *dic* and echo the point.

    The list is created on first use.  *x* is only printed, not stored.
    """
    series = dic.setdefault(key, list())
    series.append(y)
    print(key)
    print(x,y)

def main(mol):
    """Scan a molecule with GoDuck, collect the requested values and plot them.

    mol -- molecule name; ``examples/molecule.<mol>`` is read from the
           directory returned by GetDuckDir().

    ``PyOptions.json`` is read from the current working directory.  Keys
    used: 'Basis', 'Methods' (per method: 'Enabled', 'Labels', optional
    'Index'), 'Scan' ('Start', 'Stop', 'Step') and the optional 'Outputs'
    section ('DuckOutput' with 'Enabled'/'Multiple', 'DataOutput' with
    'Enabled', 'FigureOutput' with 'Enabled'/'Path').

    GoDuck is run once per scan value.  For every enabled method the part
    of the GoDuck output between "<method> calculation" and
    "Total CPU time for " is searched for each enabled label; values that
    cannot be found are stored as NaN.  The series are then plotted (one
    subplot per 'Graph' number) and optionally written to data files and
    to a figure file.

    The interpreter is terminated with sys.exit when GoDuck returns a
    non-zero exit code (-1), when an index expression does not evaluate to
    an int (-2), or when a series gets out of step with the scan.
    """
    #get basepath for files
    molbase='examples/molecule.'
    basisbase=molbase.replace('molecule','basis')
    with open('PyOptions.json','r') as jfile:
        options=json.loads(jfile.read())
    basis=str(options['Basis'])
    #Get the methods to analyse
    methodsdic=options['Methods']
    #Get the scan values; np.arange excludes its end point, hence Stop+Step
    scandic=options['Scan']
    scan=np.arange(scandic['Start'],scandic['Stop']+scandic['Step'],scandic['Step'])
    print(scan)
    mymethods=dict()
    for method,methoddatas in methodsdic.iteritems():
        if methoddatas['Enabled']:
            mymethods[method]=methoddatas
    graphdic=dict()
    errorconvstring="Convergence failed"
    with open(os.path.join(GetDuckDir(),molbase+mol),'r') as b:
        molbaselines=b.read().splitlines()
    mononuclear=isMononucle(molbaselines)
    if mononuclear:
        print('monoatomic system: variation of the nuclear charge')
    else:
        print('polyatomic system: variation is on the distance')
    #handle of the GoDuck output file (kept open across the scan in single-file mode)
    duckoutf=None
    for x in scan:
        (DuckExit,DuckOut,DuckErr)=runDuck(mol,basis,x,molbaselines,molbase,basisbase)
        #print DuckOut on file or not
        if "Outputs" in options:
            outdat=options["Outputs"]
            if 'DuckOutput' in outdat:
                outopt=outdat["DuckOutput"]
                if outopt['Enabled']:
                    if outopt['Multiple']:
                        #one file per scan value; outfile() expects the whole
                        #"Outputs" mapping, not the item entry itself
                        duckoutf=outfile(outdat,"DuckOutput",x)
                    else:
                        if x==scan[0]:
                            duckoutf=outfile(outdat,"DuckOutput")
                        #parentheses are required: without them the value and
                        #the line breaks are only written in the 'Distance' case
                        duckoutf.write(('Z' if mononuclear else 'Distance')+' '+str(x)+os.linesep+os.linesep)
                    duckoutf.write(DuckOut)
                    if outopt['Multiple']:
                        duckoutf.close()
        print("GoDuk exit code " + str(DuckExit))
        if DuckExit !=0:
            #if GoDuck is not happy
            print(DuckErr)
            sys.exit(-1)
        #get all data for the method
        for method,methoddatas in mymethods.iteritems():
            isnan=False
            if '{0}' in  method:
                #explicit list of indices; both spellings of the key are accepted
                indexlist=methoddatas.get('Index',methoddatas.get('index'))
                if indexlist is not None:
                    #the loop variable must not be called x: in Python 2 a list
                    #comprehension rebinds names of the enclosing function
                    methodheaders=[method.format(str(idx)) for idx in indexlist]
                else:
                    #no explicit index: use the highest index found in the output
                    try:
                        print(method)
                        reglist=re.findall(r'(\d+)'.join([re.escape(s) for s in method.split('{0}')]),DuckOut)
                        print(reglist)
                        final=max([(int(found[0]) if isinstance(found,tuple) else int(found))  for found in reglist])
                        print(final)
                        methodheaders=[method.format(str(final))]
                    except Exception:
                        #typically ValueError from max() when nothing matched
                        isnan=True
                        methodheaders=[None]
                    method=method.replace('{0}','')
            else:
                methodheaders=list([method])
            for methodheader in methodheaders:
                if len(methodheaders)!=1:
                    method=methodheader
                lbldic=methoddatas['Labels']
                print(methodheader)
                if methodheader is None:
                    methodtxt=''
                else:
                    #keep the output from the method banner up to its timing line
                    it=itertools.dropwhile(lambda line: methodheader + ' calculation' not in line , DuckOut.splitlines())
                    it=itertools.takewhile(lambda line: 'Total CPU time for ' not in line, it)
                    methodtxt=os.linesep.join(it)
                if errorconvstring in methodtxt:
                    print(colored(' '.join([method, errorconvstring, '!!!!!']),'red'))
                    isnan=True
                if methodtxt=='':
                    print(colored('No data' +os.linesep+  'RHF scf not converged or method not enabled','red'))
                    isnan=True
                #find the expected values
                for label,labeldatas in lbldic.iteritems():
                    if isinstance(labeldatas,dict):
                        indexed=('Index' in labeldatas)
                        enabled=labeldatas['Enabled']
                        graph=labeldatas['Graph'] if 'Graph' in labeldatas else 1
                    else:
                        #a bare value is the enabled flag itself
                        enabled=labeldatas
                        graph=1
                        indexed=False
                    if enabled:
                        if graph not in graphdic:
                            graphdic[graph]=OrderedDict()
                        y=graphdic[graph]
                        if not indexed:
                            v=np.nan
                            print(method)
                            print(label)
                            if not isnan:
                                try:
                                    m=re.search(r'\s+'.join([re.escape(w) for w in label.split()]) + r"\s+(?:"+re.escape("(eV):")+r"\s+)?(?:=\s+)?(-?\d+\.?\d*)",methodtxt)
                                    #stored as float, like the indexed branch, so
                                    #the series can be plotted numerically
                                    v=float(m.group(1))
                                except (AttributeError,ValueError):
                                    #no match (m is None) or text that is not a number
                                    v=np.nan
                            addvalue(y,(method,label),x,v)
                        else:
                            #locate the table header line containing the label;
                            #data rows are taken to start two lines below it
                            startindex=-1
                            columnindex=-1
                            linedtxt=methodtxt.split(os.linesep)
                            for n,line in enumerate(linedtxt):
                                if all(token in line for token in ['|',' '+label+' ','#']):
                                    startindex=n+2
                                    columnindex=[s.strip() for s in line.split('|')].index(label)
                                    break
                            #second field of the second line of input/molecule,
                            #presumably the number of electrons (see the name nel)
                            with open(os.path.join(GetDuckDir(),'input','molecule'),'r') as molfile:
                                    molfile.readline()
                                    line=molfile.readline()
                                    nel=int(line.split()[1])
                                    print(nel)
                            #these names are available to the index expressions
                            #evaluated below; do not remove them as unused
                            HOMO=int(nel/2)
                            HO=HOMO
                            LUMO=HOMO+1
                            BV=LUMO
                            for i in labeldatas['Index']:
                                if type(i) is str or type(i) is unicode:
                                    #SECURITY NOTE: eval() runs arbitrary code from
                                    #PyOptions.json; only use trusted option files
                                    ival=eval(i)
                                    if type(ival) is not int:
                                        print('Index '+ str(i) + 'must be integer')
                                        sys.exit(-2)
                                else:
                                    ival=i
                                v=np.nan
                                if not isnan:
                                    try:
                                        if startindex!=-1 and columnindex!=-1:
                                            line=linedtxt[startindex+ival-1]
                                            v=float(line.split('|')[columnindex].split()[0])
                                            print(method)
                                            print(label)
                                            print(i)
                                        else:
                                            v=np.nan
                                    except (IndexError,ValueError,TypeError):
                                        #row or column missing, or cell not numeric
                                        v=np.nan
                                key=(method,label,i)
                                addvalue(y,key,x,v)
                                #sanity check: one value per scan point so far
                                tpl=(x,scan.tolist().index(x)+1,len(y[key]))
                                print(tpl)
                                if tpl[1]-tpl[2]:
                                    sys.exit('Number of values collected for '+str(key)+' does not match the scan position')
    #single-file GoDuck output stays open during the scan; close it now
    if duckoutf is not None and not duckoutf.closed:
        duckoutf.close()
    #define graph grid (a 1x1 grid when nothing was collected)
    maxgraph=max(graphdic.keys()) if graphdic else 1
    maxrow=int(round(sqrt(maxgraph)))
    maxcol=int(ceil(float(maxgraph)/float(maxrow)))
    #data file output is the same for every graph
    outputs=options['Outputs'] if "Outputs" in options else dict()
    dataout="DataOutput" in outputs and outputs['DataOutput']['Enabled']
    #define label ls
    for graph,y in graphdic.iteritems():
        datas=list()
        datas.append(["#x"]+scan.tolist())
        if len(y.keys())!=0:
            plt.subplot(maxrow,maxcol,graph)
            plt.xlabel('Z' if mononuclear else 'Distance '+mol)
            #method and/or label go to the axis title when shared by all curves
            ylbls=list([basis])
            for i in range(0,2):
                lst=list(set([key[i] for key in y.keys()]))
                if len(lst)==1:
                    ylbls.append(lst[0])
            plt.ylabel(' '.join(ylbls))
            print('Legend')
            print(list(y.keys()))
            for key,values in y.iteritems():
                #the legend only carries what is not already in the axis title
                legend=list()
                for el in key[0:2]:
                    if el not in ylbls:
                        legend.append(el)
                if len(key)>2:
                    legend.append(str(key[2]))
                #plot curves
                lbl=' '.join(legend)
                plt.plot(scan,values,'-o',label=lbl)
                if dataout:
                    fmtlegendf='{0}({1})'
                    datas.append([fmtlegendf.format("y",lbl)]+values)
            #generate legends
            plt.legend()
            if dataout:
                #transpose: one row per scan value, one column per curve
                csvdatas=zip(*datas)
                with outfile(outputs,"DataOutput",graph) as csvf:
                    writer = csv.writer(csvf, delimiter=' ')
                    writer.writerow(['#']+ylbls)
                    writer.writerows(csvdatas)
    #save and show graph
    if "FigureOutput" in outputs:
        figout=outputs["FigureOutput"]
        if figout["Enabled"]:
            plt.savefig(figout['Path'])
    plt.show()
# Command-line entry point: optionally copy the sample option file, then run
# the scan for the given molecule.
if __name__ == '__main__':
    parser=argparse.ArgumentParser()
    parser.add_argument("mol",nargs='?', help="molecule to compute",type=str)
    # The short and long flags must be separate option strings; a single
    # "-c,--copy" string registers one option with that literal name, which
    # makes "--copy" an unrecognized argument.
    parser.add_argument("-c","--copy", help="Copy sample option file",action="store_true",dest="copy")
    args = parser.parse_args()
    if len(sys.argv)==1:
        parser.print_help()
    else:
        if args.copy:
            copy2(os.path.join(GetDuckDir(),"PyOptions.template.json"),"PyOptions.json")
            # Let the user adjust the fresh option file before the run starts.
            if args.mol is not None:
                os.system("vim PyOptions.json")
        if args.mol is not None:
            main(args.mol)