import sys
import os
import imp
import platform
import re
import cPickle as cpkl

from os.path import join as pj, abspath as absp, exists as pexists, basename
from socket import gethostname

from tests.pymods.testsuite import TestSuite, ChainOfTests
from tests.pymods.tools import pprint_table
from tests.pymods.devtools import FileLock, FileLockException

# Explicitly empty: "from <this module> import *" exports no names.
__all__ = []

#############################################################################################################

class AbinitEnvironment(object):
  """
  Container with information on the abinit source tree.
  Provide helper functions to construct the absolute path of directories.

  Attributes set by the constructor:
    uname, hostname, username: information on the machine and on the user.
    tests_dir: directory containing this module.
    home_dir: parent directory of tests_dir (top level directory of the package).
    psps_dir: tests_dir/Psps_for_tests
    fldiff_path: tests_dir/Scripts/fldiff.pl
  """
  def __init__(self):
    self.uname = platform.uname()
    self.hostname = gethostname()
    try:
      self.username = os.getlogin()
    except Exception:
      # os.getlogin can fail (e.g. no controlling terminal); use a placeholder
      # instead of aborting, but do not swallow KeyboardInterrupt/SystemExit.
      self.username = "No_username"

    self.tests_dir = os.path.dirname(absp(__file__))
    self.home_dir = os.path.dirname(self.tests_dir)

    self.psps_dir = pj(self.tests_dir, "Psps_for_tests")
    self.fldiff_path = pj(self.tests_dir, "Scripts", "fldiff.pl")

  def __str__(self):
    return "\n".join( [str(k) + " : " + str(v) for (k, v) in self.__dict__.items()] )

  def apath_of(self, *p):
    """
    Return the absolute path of p where p is one or more pathname components
    relative to the top level directory of the package.
    """
    return pj(self.home_dir, *p)

  def isbuild(self):
    """
    True if we have built the code in the top level directory, i.e. if both
    home_dir/config.h and home_dir/src/98_main/abinit are regular files.
    """
    configh_path = pj(self.home_dir, "config.h")
    abinit_path = pj(self.home_dir, "src", "98_main", "abinit")
    return os.path.isfile(configh_path) and os.path.isfile(abinit_path)

# Module-wide description of the abinit source tree.
abenv = AbinitEnvironment()

# Path of the pickled database of tests, stored inside the tests directory.
database_path = os.path.join(abenv.tests_dir, "test_suite.cpkl")

#############################################################################################################

# Absolute paths of the test-suite directories, in alphabetical order of the
# suite name. Commented-out entries are suites that are currently not run.
_tsuite_dirs = tuple(
  pj(abenv.tests_dir, dir_name)
  for dir_name in sorted([
    #"abirules",
    "atompaw",
    "bigdft",
    #"buildsys",
    "built-in",
    #"cpu",       # Disabled
    "etsf_io",
    "fast",
    "fox",
    "gpu",
    "libxc",
    "mpiio",
    "paral",
    #"physics",   # Disabled !
    "seq",
    #"tutoparal", # TODO
    "tutoplugs",
    "tutorespfn",
    "tutorial",
    "unitary",
    "v1",
    "v2",
    "v3",
    "v4",
    "v5",
    "v6",
    "v67mbpt",
    "v7",
    "vdwxc",
    "wannier90",
  ])
)

class Suite(object):
  """
  Information on one test suite.

  The suite is described by the __init__.py file stored in suite_path, which
  must define keywords, need_cpp_vars and inp_files, and may optionally define
  is_multi_parallel and subsuites. Entries of inp_files starting with "-" are
  disabled tests.
  """

  def __init__(self, suite_path):
    suite_path = absp(suite_path)

    self.suite_path = suite_path
    self.name = basename(suite_path)

    # Load the suite description; the path of __init__.py is used as module name.
    init_path = pj(suite_path, "__init__.py")
    module = imp.load_source(init_path, init_path)

    self.keywords      = set(module.keywords)
    self.need_cpp_vars = set(module.need_cpp_vars)

    # Divide tests into (active|disabled) and use absolute paths.
    # Disabled inputs are listed with a leading "-" that is stripped here.
    input_dir = pj(suite_path, "Input")
    self.inp_paths = [pj(input_dir, p) for p in module.inp_files if not p.startswith("-")]
    self.disabled_inp_paths = [pj(input_dir, p[1:]) for p in module.inp_files if p.startswith("-")]

    # True if the suite contains tests that should be executed with
    # different numbers of MPI processes.
    self.is_multi_parallel = getattr(module, "is_multi_parallel", False)

    self.subsuites = {}
    if hasattr(module, "subsuites"):
      subsuite_names = module.subsuites
      for name in subsuite_names:
        self.subsuites[name] = []

      for name in subsuite_names:
        # Matches both active (tNAME_1.in) and disabled (-tNAME_1.in) inputs.
        pattern = re.compile(r"-?t" + re.escape(name) + r"_\d+\.in")
        for inp in module.inp_files:
          if pattern.match(inp):
            self.subsuites[name].append(pj(input_dir, inp))

      # Every input file (active or disabled) must belong to a subsuite.
      nfound = sum([len(paths) for paths in self.subsuites.values()])
      if nfound != len(module.inp_files):
        err_msg = "At least one input_file does not belong to a subsuite, nfound = %s, __init__.nfiles = %s\n" % (
          nfound, len(module.inp_files))

        for inp in module.inp_files:
          for paths in self.subsuites.values():
            fnames = [basename(p) for p in paths]
            if inp in fnames: break
          else:
            err_msg += "%s not found\n" % inp
        raise ValueError(err_msg)

      # Remove disabled tests (if any). The "-" marking a disabled test is the
      # first character of the file name, so check the basename, not the path end.
      for (sub_name, paths) in self.subsuites.items():
        self.subsuites[sub_name] = [p for p in paths if not basename(p).startswith("-")]

  def has_subsuite(self, subsuite_name):
    "True if this suite defines a subsuite called subsuite_name."
    return subsuite_name in self.subsuites

  def inputs_of_subsuite(self, subsuite_name):
    "Return the absolute paths of the active input files in the subsuite."
    return self.subsuites[subsuite_name]

  def __str__(self):
    return "\n".join( [str(k) + " : " + str(v) for k,v in self.__dict__.items()] )

#############################################################################################################
class TestsDatabase(dict):
  """
  Dictionary mapping the name of an abinit suite to its collection of tests.

  suites is the mapping suite_name --> Suite object. It defines which keys
  add_test_suite accepts and is used by get_test_suite to resolve subsuites.
  """

  def __init__(self, suites):
    super(TestsDatabase, self).__init__()
    self._suites = suites

  def iter_tests(self):
    "Generator over every test of every suite stored in the database."
    for collection in self.values():
      for single_test in collection:
        yield single_test

  @property
  def suite_names(self):
    "Names of all the suites known to the database (stored or not)."
    return [self._suites[key].name for key in self._suites]

  @property
  def authors_snames(self):
    "Set with the authors' second names extracted from all the tests."
    return set([sname for single_test in self.iter_tests()
                      for sname in single_test._authors_snames])

  def test_chains(self):
    "Return a list with all the chained tests (instances of ChainOfTests)."
    return [t for t in self.iter_tests() if isinstance(t, ChainOfTests)]

  def add_test_suite(self, suite_name, test_suite):
    """
    Store test_suite under the key suite_name.
    Raise ValueError if suite_name is unknown or already stored.
    """
    known = self.suite_names
    if suite_name not in known:
      raise ValueError("%s is not a valid suite name" % suite_name)
    if suite_name in self:
      raise ValueError("%s is already in the database" % suite_name)
    self[suite_name] = test_suite

  def get_test_suite(self, suite_name, subsuite_name=None, slice_obj=None):
    """
    Return the tests belonging to the suite suite_name.
    If subsuite_name is given, a new TestSuite built from the inputs of
    suite_name/subsuite_name is returned instead of the stored one.
    slice_obj, if not None, is applied to the result to select a range of tests.
    """
    selected = self[suite_name]

    if subsuite_name is not None:
      owner = self._suites[suite_name]
      if not owner.has_subsuite(subsuite_name):
        raise ValueError("suite %s does not have subsuite %s" % (suite_name, subsuite_name))

      sub_inputs = owner.inputs_of_subsuite(subsuite_name)
      selected = TestSuite(selected.abenv,
                           inp_files=sub_inputs,
                           keywords=owner.keywords,
                           need_cpp_vars=owner.need_cpp_vars)

    if slice_obj is not None:
      selected = selected[slice_obj]
    return selected

#############################################################################################################
class AbinitTests(object):
  """
  Object describing the collection of automatic tests.

  The constructor builds one Suite object for each directory in _tsuite_dirs.
  """
  def __init__(self):
    self.suite_names = tuple( [basename(d) for d in _tsuite_dirs] )
    self.suite_paths = tuple( [pj(abenv.tests_dir, dir_name) for dir_name in _tsuite_dirs] )

    self._suites = dict()
    for (suite_name, suite_path) in self.walk_suites():
      self._suites[suite_name] = Suite(suite_path)

    # Check suite_names and subsuite_names: a name used for both is ambiguous
    # because _suite_args_parser looks for suite names first.
    all_subsuite_names = self.all_subsuite_names
    for suite_name in self.suite_names:
      if suite_name in all_subsuite_names:
        print("Found suite and subsuite with the same name: %s" % suite_name)

  def walk_suites(self):
    "Return list of (suite_name, suite_path) tuples."
    return list(zip(self.suite_names, self.suite_paths))

  def __str__(self):
    return "\n".join( [str(k) + " : " + str(v) for (k,v) in self.__dict__.items()] )

  def get_suite(self, suite_name):
    "Return the Suite object with name suite_name."
    return self._suites[suite_name]

  @property
  def suites(self):
    "All the Suite objects."
    return self._suites.values()

  def multi_parallel_suites(self):
    "Return the suites whose tests are executed with different numbers of MPI processes."
    return [s for s in self.suites if s.is_multi_parallel]

  def suite_of_subsuite(self, subsuite_name):
    """
    Return the suite containing the given subsuite.
    Raise ValueError if no suite defines it.
    """
    for suite in self.suites:
      if suite.has_subsuite(subsuite_name):
        return suite
    raise ValueError("subsuite %s not found" % subsuite_name)

  @property
  def all_subsuite_names(self):
    "List with the names of all the registered subsuites (names must be unique)."
    all_subnames = []
    for suite in self.suites:
      all_subnames.extend(suite.subsuites.keys())

    assert len(all_subnames) == len(set(all_subnames))
    return all_subnames

  def keywords_of_suite(self, suite_name):
    "Set of keywords declared by the suite."
    return self._suites[suite_name].keywords

  def cpp_vars_of_suite(self, suite_name):
    "Set of CPP variables required by the suite."
    return self._suites[suite_name].need_cpp_vars

  def inputs_of_suite(self, suite_name, active=True):
    """
    Return the absolute paths of the active (default) or disabled input files.
    The returned list is owned by the Suite object: copy it before modifying it.
    """
    if active:
      return self._suites[suite_name].inp_paths
    else:
      return self._suites[suite_name].disabled_inp_paths

  def build_database(self, with_disabled=False):
    """
    Return an instance of TestsDatabase containing all the ABINIT automatic tests.
    with_disabled specifies whether disabled tests should be included in the database.
    """
    database = TestsDatabase(self._suites)
    for suite_name in self.suite_names:
      # Work on a copy: extending the list returned by inputs_of_suite would
      # permanently add the disabled inputs to the Suite's active list.
      inp_files = list(self.inputs_of_suite(suite_name, active=True))
      if with_disabled:
        inp_files.extend(self.inputs_of_suite(suite_name, active=False))

      test_suite = TestSuite(abenv,
                             inp_files     = inp_files,
                             keywords      = self.keywords_of_suite(suite_name),
                             need_cpp_vars = self.cpp_vars_of_suite(suite_name))

      database.add_test_suite(suite_name, test_suite)

    return database

  def _save_database(self, database):
    "Pickle database to database_path while holding the file lock."
    lock = FileLock(database_path)
    lock.acquire()
    try:
      # Binary mode is required by the binary pickle protocol (protocol=-1).
      fh = open(database_path, "wb")
      try:
        cpkl.dump(database, fh, protocol=-1)
      finally:
        fh.close()
    finally:
      lock.release()

  def _load_database(self):
    "Unpickle the database stored in database_path while holding the file lock."
    lock = FileLock(database_path)
    lock.acquire()
    try:
      # NOTE: the pickle file is generated locally by _save_database; never
      # point database_path to untrusted data.
      fh = open(database_path, "rb")
      try:
        return cpkl.load(fh)
      finally:
        fh.close()
    finally:
      lock.release()

  def get_database(self, regenerate=False):
    """
    Return an instance of TestsDatabase initialized from an external Cpickle file.
    regenerate: Set it to True to force the regeneration of the database
                and the writing of a new Cpickle file.
    The file is also (re)generated when it does not exist.
    File locking is used to prevent IO from other processes.
    """
    if regenerate or not pexists(database_path):
      print("Regenerating database...")
      database = self.build_database()
      self._save_database(database)
    else:
      print("Loading database from file: " + database_path)
      database = self._load_database()

    return database

  def _suite_args_parser(self, args=None):
    """
    Parse script arguments. Return a mapping (suite_name, subsuite_name) --> [slice objects]
    Three forms are possible
    0)                               Run all tests.
    1) v2[34:35] v1[12:] v5 v6[:45]  Select slices in the suites
    2) v4- v5-                       Exclude suites
    In form 1) an omitted start defaults to 1 and v1[7] is equivalent to v1[7:8].
    If any argument ends with "-", form 2) is used and the arguments without
    the trailing "-" are ignored.

    Raises ValueError for unknown names or malformed slices, and
    NotImplementedError when trying to exclude a subsuite.
    """
    def all_tests():
      # A new list for each key (dict.fromkeys would share one list among all keys).
      return dict([((name, None), [slice(None)]) for name in self.suite_names])

    # Run all tests.
    if not args:
      return all_tests()

    args = [s.replace(" ", "") for s in args]

    exclude_mode = False
    for arg in args:
      if arg.endswith("-"):
        exclude_mode = True
        break

    if exclude_mode:
      d = all_tests()
      for arg in args:
        if not arg.endswith("-"):
          continue
        arg = arg[:-1]
        if arg not in self.suite_names:
          # TODO: support subsuites, e.g. gw1- --> skip tutorial/gw1
          raise NotImplementedError("exclude_mode does not support subsuites")
        d.pop((arg, None), None)  # v4- --> Remove v4 (repeating v4- is harmless)
      return d

    re_slice  = re.compile(r"^\[(\d*):(\d*)\]$")
    re_single = re.compile(r"^\[(\d+)\]$")

    d = {}
    for full_arg in args:
      arg = full_arg
      start_stop = slice(None)
      idx = arg.find("[")
      if idx != -1:
        arg, string = arg[:idx], arg[idx:]
        match = re_slice.search(string)
        if match:
          (start, stop) = match.group(1), match.group(2)
          if not start: start = 1
          if stop:
            start_stop = slice(int(start), int(stop))
          else:
            start_stop = slice(int(start), None)
        else:
          match = re_single.search(string)
          if not match:
            # Report the whole argument, including the malformed bracket part.
            raise ValueError("Wrong or unknow argument: %s" % full_arg)
          start = int(match.group(1))
          start_stop = slice(start, start + 1)

      if arg in self.suite_names:
        key = (arg, None)
      elif arg in self.all_subsuite_names:
        key = (self.suite_of_subsuite(arg).name, arg)
      else:
        raise ValueError("Wrong (suite_name|subsuite_name) : %s" % arg)

      d.setdefault(key, []).append(start_stop)

    return d

  def select_tests(self, suite_args, regenerate=False, keys=None, authors=None):
    """
    Main entry point for client code.
    Return an instance of TestSuite with the tests selected by suite_args
    (see _suite_args_parser), optionally filtered by keywords and authors.
    Entries of keys/authors ending with "-" are excluded instead of selected.
    """
    tests_todo = self._suite_args_parser(suite_args)

    # Load the full database.
    database = self.get_database(regenerate=regenerate)

    # Extract the tests to run as specified by suite_args i.e by the string "v1[1:4] v3 ..."
    # TODO waiting for changes in the naming scheme: slices are ignored for these suites.
    suites_without_slicing = ["paral", "mpiio", "built-in", "seq"]

    tests = TestSuite(abenv, test_list=[])

    # FIXME Not the sorting algorithm we would like to have!
    for (suite_name, subsuite_name) in sorted(tests_todo.keys()):
      for slice_obj in tests_todo[(suite_name, subsuite_name)]:
        if suite_name in suites_without_slicing:
          slice_obj = None
        tests = tests + database.get_test_suite(suite_name, subsuite_name=subsuite_name, slice_obj=slice_obj)

    if keys or authors:
      # Create new suite whose tests contain the specified keywords.
      with_keys, exclude_keys, with_authors, exclude_authors = 4 * (None,)

      if keys:
        with_keys    = [k for k in keys if not k.endswith("-")]
        exclude_keys = [k[:-1] for k in keys if k.endswith("-")]
        print("Extracting tests with keywords = %s, without keywords %s" % (with_keys, exclude_keys))

      if authors:
        with_authors    = [a for a in authors if not a.endswith("-")]
        exclude_authors = [a[:-1] for a in authors if a.endswith("-")]
        print(" Exctracting tests with authors = %s, without authors %s" % (with_authors, exclude_authors))

      tests = tests.select_tests(with_keys    = with_keys,
                                 exclude_keys = exclude_keys,
                                 with_authors = with_authors,
                                 exclude_authors = exclude_authors)
    return tests

  def generate_html_readme(self):
    "Generate the HTML readme files (suite_path/readme.html) after regenerating the database."
    database = self.get_database(regenerate=True)

    for (suite_name, suite_path) in self.walk_suites():
      suite = database.get_test_suite(suite_name)
      readme_fname = pj(suite_path, "readme.html")
      print("  Writing readme file: %s" % readme_fname)
      # Build the text first so that an error does not leave a truncated file.
      text = suite.make_readme(width=160, html=True)
      fh = open(readme_fname, "w")
      try:
        fh.write(text)
      finally:
        fh.close()

  def show_info(self):
    """
    Print the number of active/disabled tests per suite, the known keywords,
    the authors and information on the chains of tests (regenerates the database).
    """
    table = [["Suite", "#Active Tests", "#Disabled Tests"]]
    for suite_name in self.suite_names:
      active_tests   = self.inputs_of_suite(suite_name, active=True)
      disabled_tests = self.inputs_of_suite(suite_name, active=False)
      table.append([suite_name, str(len(active_tests)), str(len(disabled_tests))])

    pprint_table(table)

    print(8*"=" + " KEYWORDS " + 8*"=")
    width = max([len(k) for k in KNOWN_KEYWORDS]) + 5
    for skey in sorted(KNOWN_KEYWORDS):
      print("%s %s" % (skey.ljust(width), KNOWN_KEYWORDS[skey]))

    from pprint import pprint

    # list authors
    database = self.get_database(regenerate=True)
    pprint(database.authors_snames)

    for chain in database.test_chains():
      (string, nlinks) = chain.info_on_chain()
      if nlinks == 0:
        print(15*"*" + " Warning: found 0 explicit links " + 15*"*")
      print(string)

# Module-level AbinitTests instance. Building it creates a Suite for every
# directory in _tsuite_dirs, which loads the __init__.py of each suite.
abitests = AbinitTests()

#############################################################################################################

#: dictionary with test keywords and their brief description.
KNOWN_KEYWORDS = {
# Main executables
'abinit'    : "Test abinit code.",
'anaddb'    : "Test anaddb code",
"atompaw"   : "Atompaw tests",
'band2eps'  : "Test band2eps code",
'cut3d'     : "Test cut3d code",
'mrgscr'    : "Test mrgscr code",
'mrggkk'    : "Test mrggkk code",
'mrgddb'    : "Test mrgddb code",
'lwf'       : "Test lwf code",
'optic'     : "Test optic code",
'ujdet'     : "Test ujdet code",
'aim'       : "Test aim code",
'conducti'  : "Test conducti code",
'fftprof'   : "Test fftprof code and low-level FFT routines",
'wannier90' : "Test the interface with Wannier90",
'macroave'  : "Test macroave code",
'bigdft'    : "Test the interface with Bigdft",
'unitary'   : "Unitary tests",
# keywords describing the test
'GW'        : "GW calculations",
'NC'        : "Norm-conserving calculations",
'DDK'       : "DDK calculations",
"PAW"       : "PAW calculations",
"BSE"       : "Bethe-Salpeter calculations",
"LDAU"      : "LDA+U calculations",
"EPH"       : "Electron-phonon calculations",
"LEXX"      : "Local exact exchange calculations",
"DMFT"      : "Dynamical mean-field theory",
"TDDFT"     : "Time-dependent Density Functional Theory",
"STRING"    : "String method",
"CML"       : "Chemical Markup Language",
"WanT"      : "Interface with Want code",
"XML"       : "XML output (testing prtxml)",
"CIF"       : "Tests producing Crystallographic Information Files",
"positron"  : "positron calculations",
}

#? DOS, LDA, GGA ... GS, RF, SOC
  
#############################################################################################################
