#!/usr/bin/env python

import os
import sys
import time

from subprocess import Popen, PIPE
from socket import gethostname
from os.path import join as pj, abspath as absp, exists as pexists, basename, isfile

from subprocesswithtimeout import SubProcessWithTimeout

try: 
  from ConfigParser import SafeConfigParser, NoOptionError
except ImportError: # The ConfigParser module has been renamed to configparser in Python 3
  from configparser import SafeConfigParser, NoOptionError

__version__ = "0.1"
__author__ = "Matteo Giantomassi"

# Public names exported by "from <this module> import *".
__all__ = ["JobRunner", "TimeBomb", "OMPEnvironment"]

# Module-wide debug switch. When True, dbg_print emits its arguments and
# JobRunner instances start in verbose mode.
_debug = False

if _debug:
  def dbg_print(*args):
    """Print the arguments, space-separated, on a single line."""
    # Join explicitly: print(args) would show the tuple repr, e.g. ('msg',).
    print(" ".join(str(arg) for arg in args))
else:
  def dbg_print(*args):
    """No-op used when debugging is disabled."""
    pass

#############################################################################################################

# Keywords recognised by mpicfg_parser. Each value is the tuple
#   (parser applied to the raw string, default value i.e. NO MPI, section, description)
# Every keyword currently lives in the [mpi] section, is parsed with str and
# defaults to the empty string, so only the descriptions are spelled out here.
CFG_KEYWORDS = dict(
  (keyword, (str, "", "mpi", description))
  for keyword, description in (
    ("mpi_flavor",  "Name of the MPI implementation e.g. openmpi, mpich2 ..."),
    ("mpi_version", "Version of the MPI implementation e.g. 1.2.4"),
    ("mpi_prefix",  "Top directory of the MPI library. e.g. /shared/openmpi-ifc10"),
    ("mpirun_np",   "String specifying how to execute a binary with N processors"),
    ("poe",         "poe location"),
    ("poe_args",    "arguments passed to poe"),
  )
)

def mpicfg_parser(fname, defaults=None):
  """Parse the MPI-related keywords of an INI configuration file.

  Args:
    fname: file name (or list of file names) handed to ConfigParser.read.
      Files that cannot be read are skipped by ConfigParser.read, in which
      case every keyword gets its default value.
    defaults: optional mapping forwarded to the SafeConfigParser constructor
      (values of the DEFAULT section).

  Returns:
    dict with one entry per keyword of CFG_KEYWORDS. The value is read from
    the keyword's section, or taken from CFG_KEYWORDS when the section or the
    option is missing, and is then converted with the keyword's parser.

  Raises:
    ValueError: if a keyword's parser fails on the value.
  """
  dbg_print("Parsing [MPI] section in file : " + str(fname))

  parser = SafeConfigParser(defaults)
  parser.read(fname)

  d = dict()
  for key, spec in CFG_KEYWORDS.items():
    line_parser, default, section = spec[:3]

    # has_option is False both when the section is absent and when the option
    # is not given (in the section or in DEFAULT): use the declared default.
    if parser.has_option(section, key):
      value = parser.get(section, key)
    else:
      value = default

    try:
      d[key] = line_parser(value)
    except Exception as exc:
      # Catch Exception (not a bare except) so that KeyboardInterrupt and
      # SystemExit are not disguised as configuration errors.
      raise ValueError("Wrong line: key = %s d[key] = %s (%s)" % (key, value, exc))

  return d

#############################################################################################################
class JobRunnerError(Exception):
  """Describes a command, launched by JobRunner.run, that did not succeed.

  Attributes:
    return_code: exit status of the command (may be None when an exception
      interrupted the run before a status was available).
    cmd: the command line that was executed.
    run_etime: seconds elapsed before the failure was detected.
    prev_errmsg: message of the exception that interrupted the run, if any.
  """

  def __init__(self, return_code, cmd, run_etime, prev_errmsg=None):
    self.cmd = cmd
    self.return_code = return_code
    self.prev_errmsg = prev_errmsg
    self.run_etime = run_etime

  def __str__(self):
    pieces = ["Command %s\n returned exit_code: %s\n" % (self.cmd, self.return_code)]
    if self.prev_errmsg:
      pieces.append("Previous exception: %s" % self.prev_errmsg)
    return "".join(pieces)

#############################################################################################################

class JobRunner(object):
  """Base Class used to manage the execution of jobs in an MPI environment.

  The runner is configured with a dictionary whose items become instance
  attributes: the keywords of CFG_KEYWORDS (mpirun_np, poe, poe_args, ...),
  "ompenv" (a mapping merged into the child's environment) and "timebomb"
  (an object with a TimeBomb-like run() method used to launch the command).
  run() does not raise on failures: it appends a JobRunnerError to
  self.exceptions instead.
  """

  @classmethod
  def fromdict(cls, kwargs, ompenv=None, timebomb=None):
    """Build a runner from the dict kwargs (its entries take precedence)."""
    d = dict(ompenv=ompenv, timebomb=timebomb)
    d.update(kwargs)
    return cls(d)

  @classmethod
  def fromfile(cls, fname, timebomb=None):
    """Build a runner from the [mpi] and (optional) [openmp] sections of fname."""
    d = mpicfg_parser(fname)
    d["ompenv"] = OMPEnvironment.from_file(fname, allow_empty=True)
    d["timebomb"] = timebomb
    return cls(d)

  @classmethod
  def sequential(cls, ompenv=None, timebomb=None):
    """Build a runner that executes the binary directly (no MPI launcher)."""
    return cls(dict(ompenv=ompenv, timebomb=timebomb))

  @classmethod
  def generic_mpi(cls, ompenv=None, timebomb=None):
    """Build a runner that launches jobs with "mpirun -np N".

    It should work, provided that the shell environment is properly defined.
    """
    return cls(dict(mpirun_np="mpirun -np", ompenv=ompenv, timebomb=timebomb))

  def __init__(self, dic):
    self.__debug = _debug

    # Failures recorded by run().
    self.exceptions = []

    for k, v in dic.items():
      if k in self.__dict__:
        raise ValueError(" key %s is already in self.__dict__, cannot overwrite" % k)
      self.__dict__[k] = v

    if self.has_poe and self.has_mpirun:
      raise ValueError("poe and mpirun are mutually exclusive")

  def __str__(self):
    if self.__debug:
      return "\n".join([str(k) + " : " + str(v) for (k, v) in self.__dict__.items()])

    string = ""
    for key in CFG_KEYWORDS:
      attr = getattr(self, str(key), None)
      if attr:
        string += "%s = %s\n" % (key, attr)
    if string:
      string = "[MPI setup]\n" + string

    if self.has_ompenv:
      string += "[OpenMP]\n" + str(self.ompenv)
    return string

  def set_timebomb(self, timebomb):
    """Attach timebomb; raises ValueError if one is already set."""
    if self.has_timebomb:
      raise ValueError("timebomb is already defined")
    self.timebomb = timebomb

  @property
  def has_mpirun(self):
    """True if a non-empty mpirun_np launcher string is configured."""
    return bool(getattr(self, "mpirun_np", None))

  @property
  def has_poe(self):
    """True if a non-empty poe location is configured."""
    return bool(getattr(self, "poe", None))

  @property
  def has_timebomb(self):
    """True if a (truthy) timebomb is attached."""
    return bool(getattr(self, "timebomb", None))

  @property
  def has_ompenv(self):
    """True if a non-empty OpenMP environment is attached."""
    return bool(getattr(self, "ompenv", None))

  def set_ompenv(self, ompenv):
    """Attach ompenv; raises ValueError if a non-empty one is already set."""
    if self.has_ompenv:
      raise ValueError("ompenv is already defined")
    self.ompenv = ompenv

  def run(self, mpi_nprocs, bin_path, stdin_fname, stdout_fname, stderr_fname, cwd=None):
    """Execute bin_path through the shell with redirected stdin/stdout/stderr.

    Args:
      mpi_nprocs: number of MPI processes; must be 1 when neither mpirun nor
        poe is configured.
      bin_path: executable to run.
      stdin_fname, stdout_fname, stderr_fname: files used for the shell
        redirections.
      cwd: working directory of the child (None = current directory).

    Returns:
      Elapsed wall-clock time in seconds.

    A non-zero exit status, or an exception raised while launching or waiting
    for the command, is recorded as a JobRunnerError in self.exceptions.
    """
    env = os.environ.copy()
    if self.has_ompenv:
      env.update(self.ompenv)

    if self.has_mpirun:
      args = [self.mpirun_np, str(mpi_nprocs),
              bin_path, "<", stdin_fname, ">", stdout_fname, "2>", stderr_fname]

    elif self.has_poe:
      # example ${poe} abinit ${poe_args} -procs 4
      # poe_args is optional: runners built with fromdict may not define it.
      poe_args = getattr(self, "poe_args", "")
      args = [self.poe, bin_path, poe_args, " -procs " + str(mpi_nprocs),
              " <", stdin_fname, ">", stdout_fname, "2>", stderr_fname]
    else:
      assert mpi_nprocs == 1
      args = [bin_path, "<", stdin_fname, ">", stdout_fname, "2>", stderr_fname]

    # A single string is handed to the shell so that it performs the redirections.
    cmd = " ".join(args)
    if self.__debug:
      # print(...) with a single argument behaves the same in Python 2 and 3.
      print("About to execute command:\n" + cmd)

    start_time = time.time()
    ret_code = None

    try:
      if self.has_timebomb:
        _, ret_code = self.timebomb.run(cmd, shell=True, cwd=cwd, env=env)
      else:
        ret_code = Popen(cmd, shell=True, cwd=cwd, env=env).wait()

      run_etime = time.time() - start_time

      if ret_code != 0:
        self.exceptions.append(JobRunnerError(ret_code, cmd, run_etime))

    except Exception:
      # Catch Exception rather than everything: KeyboardInterrupt/SystemExit
      # must still stop the caller instead of being recorded as a job failure.
      run_etime = time.time() - start_time
      prev_errmsg = str(sys.exc_info()[1])
      self.exceptions.append(JobRunnerError(ret_code, cmd, run_etime, prev_errmsg))

    return run_etime

#############################################################################################################
class TimeBomb(object):
  """Run a command, optionally enforcing a time limit.

  Strategies, chosen in run():
    * exec_path given: the command is prefixed with "<exec_path> <timeout>"
      (presumably a timeout(1)-like executable taking the limit as its first
      argument — confirm for the executable in use);
    * otherwise: SubProcessWithTimeout is used, with `delay` forwarded to it.
  A timeout <= 0 runs the command with a plain Popen and no time limit.
  """

  # Types treated as a single command string (Python 2 also has unicode).
  try:
    _string_types = (basestring,)
  except NameError:
    _string_types = (str,)

  def __init__(self, timeout, delay=.05, exec_path=None):
    self.timeout = float(timeout)
    self.delay = float(delay)
    self.exec_path = exec_path

  def run(self, args,
          bufsize=0, executable=None, stdin=None, stdout=None, stderr=None, preexec_fn=None,
          close_fds=False, shell=False, cwd=None, env=None, universal_newlines=False, startupinfo=None, creationflags=0):
    """Run args; accepts the same arguments as subprocess.Popen.

    The Popen-based branches wait for the child to terminate. The
    SubProcessWithTimeout branch presumably does the same, or kills the child
    when the limit expires.

    Returns:
      (p, ret_code): the process object and its return code.
    """
    # Keyword arguments forwarded unchanged to Popen / SubProcessWithTimeout.run.
    popen_kwargs = dict(
      bufsize=bufsize, executable=executable, stdin=stdin, stdout=stdout, stderr=stderr,
      preexec_fn=preexec_fn, close_fds=close_fds, shell=shell, cwd=cwd, env=env,
      universal_newlines=universal_newlines, startupinfo=startupinfo,
      creationflags=creationflags)

    if self.exec_path:
      # An external timeout executable is available: prefix the command with it.
      if self.timeout > 0.:
        dbg_print("Using timeout function: " + self.exec_path)
        if isinstance(args, self._string_types):
          args = " ".join([self.exec_path, str(self.timeout), args])
        else:
          # list() so that tuples (or other sequences) of arguments also work.
          args = [self.exec_path, str(self.timeout)] + list(args)

      p = Popen(args, **popen_kwargs)
      ret_code = p.wait()

    elif self.timeout > 0.0:
      dbg_print("Using SubprocesswithTimeout and timeout_time : " + str(self.timeout))
      runner = SubProcessWithTimeout(self.timeout, delay=self.delay)
      (p, ret_code) = runner.run(args, **popen_kwargs)

    else:
      dbg_print("Using Popen (no timeout_time)")
      p = Popen(args, **popen_kwargs)
      ret_code = p.wait()

    return (p, ret_code)

#############################################################################################################
class OMPEnvironment(dict):
  """Dictionary of OpenMP environment variables (name -> string value).

  Only the names listed in _keys are accepted.
  """
  #: OMP3 variables. see https://computing.llnl.gov/tutorials/openMP/#EnvironmentVariables
  _keys = [
    "OMP_SCHEDULE",
    "OMP_NUM_THREADS",
    "OMP_DYNAMIC",
    "OMP_PROC_BIND",
    "OMP_NESTED",
    "OMP_STACKSIZE",
    "OMP_WAIT_POLICY",
    "OMP_MAX_ACTIVE_LEVELS",
    "OMP_THREAD_LIMIT",
  ]

  def __init__(self, *args, **kwargs):
    """
    Constructor method inherited from dictionary; values are converted to str:

    >>> OMPEnvironment(OMP_NUM_THREADS=1)
    {'OMP_NUM_THREADS': '1'}

    To create an instance from the INI file fname, use:
       OMPEnvironment.from_file(fname)

    Raises:
      ValueError: if a key is not one of the names in _keys (one line per
        unknown key).
    """
    super(OMPEnvironment, self).__init__(*args, **kwargs)

    # Values are stored as strings; JobRunner merges them into a process environment.
    for key in list(self):
      self[key] = str(self[key])

    unknown = [key for key in self if key not in self._keys]
    if unknown:
      raise ValueError("\n".join("unknown option %s" % key for key in unknown))

  @classmethod
  def from_file(cls, fname, allow_empty=False):
    """Build an instance from the [openmp] section of the INI file fname.

    Option names are matched case-insensitively against _keys (the parser
    lower-cases them); values are stored as strings.

    Args:
      fname: INI file name. ConfigParser.read skips unreadable files, so a
        missing file behaves like a file without an [openmp] section.
      allow_empty: if True, return an empty instance instead of raising when
        the section is missing or sets no variable.

    Raises:
      ValueError: if the section is missing or empty (and allow_empty is
        False), or if it contains an unknown option (one line per option).
    """
    parser = SafeConfigParser()
    parser.read(fname)

    inst = cls()

    if "openmp" not in parser.sections():
      if not allow_empty:
        raise ValueError("%s does not contain any [openmp] section" % fname)
      return inst

    # Consistency check. Note that we only check if the option name is correct,
    # we do not check whether the value is correct or not.
    unknown = [key for key in parser.options("openmp") if key.upper() not in cls._keys]
    if unknown:
      raise ValueError("\n".join("unknown option %s, maybe a typo" % key for key in unknown))

    for key in cls._keys:
      # has_option/get apply the parser's optionxform (lower-casing), so the
      # upper-case name matches entries written in any case.
      if parser.has_option("openmp", key):
        inst[key] = str(parser.get("openmp", key))

    if not allow_empty and not inst:
      raise ValueError("Refusing to return with an empty dict")

    return inst

#############################################################################################################
#if __name__ == "__main__":
