#!/usr/bin/env python
# Copyright (c) 2019, Anthony Latorre <tlatorre at uchicago>
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the Free
# Software Foundation, either version 3 of the License, or (at your option)
# any later version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
# more details.
#
# You should have received a copy of the GNU General Public License along with
# this program. If not, see <https://www.gnu.org/licenses/>.
"""
This is a short script to help manage submitting jobs to the grid. To submit
jobs first run the submit-grid-jobs script passing the filename of the zdab you
want to fit:

    $ submit-grid-jobs ~/zdabs/SNOCR_0000010000_000_p4_reduced.xzdab.gz

This will add a database entry for each gtid and particle id combo. To then
actually submit the jobs to the grid run:

    $ submit-grid-jobs-queue

which will loop through the database entries and create a submit file for all
jobs marked as "NEW" or "RETRY" in the database and then submit it.

This script is also meant to be run as part of a cron job to monitor the status of the jobs. To do so, add something like the following to your crontab:

    PATH=/usr/bin:$HOME/local/bin
    SDDM_DATA=$HOME/sddm/src
    DQXX_DIR=$HOME/dqxx

    0 0 * * * submit-grid-jobs-queue --auto --logfile ~/submit.log

Currently this script will *not* automatically resubmit any jobs, and jobs on hold are left untouched so that they can be marked as failed or retry manually in the database.
"""

from __future__ import print_function, division
import string
from os.path import split, join, abspath
import uuid
from subprocess import check_call, check_output
import os
import sys
from datetime import datetime
import subprocess
import json
import h5py
import numpy as np
from sddm.logger import Logger
from sddm import which, splitext

# Logger instance shared by every function in this script; its verbosity and
# log file are configured in the __main__ section below.
log = Logger()

# Template for the HTCondor submit file. Placeholders starting with '@' are
# filled in by create_submit_file() via MyTemplate, while the $(name) macros
# are left alone and expanded by condor from the columns of each line in the
# queue list.
#
# Note: every submit command (including +ProjectName) has to come *before* the
# queue statement since condor only applies the commands it has already seen
# to the jobs being queued.
CONDOR_TEMPLATE = \
"""
# We need the job to run our executable script, with the
#  input.txt filename as an argument, and to transfer the
#  relevant input and output files:
executable = @executable
arguments = $(zdab) -o $(output_filename) --gtid $(gtid) -p $(particle_combo) --max-time $(max_time)
transfer_input_files = $(input_filename), @transfer_input_files, $(dqxx_filename)

error = $(prefix).error
output = $(prefix).output
log = $(prefix).log

initialdir = $(initial_dir)

# The below are good base requirements for first testing jobs on OSG,
#  if you don't have a good idea of memory and disk usage.
requirements = (HAS_MODULES == True) && (OSGVO_OS_STRING == "RHEL 7") && (OpSys == "LINUX")
request_cpus = 1
request_memory = 1 GB
request_disk = 1 GB

priority = $(priority)

max_retries = 5
on_exit_hold = ( ExitCode == 1 || ExitCode == 134 ) || (NumJobCompletions > 4 && ExitCode =!= 0)
max_idle = 1000

+ProjectName = "SNOplus"

# Queue one job with the above specifications for each line in the list below.
queue input_filename, prefix, zdab, output_filename, gtid, particle_combo, max_time, dqxx_filename, initial_dir, priority from (
@queue
)
""".strip()

# Data files the fitter needs at run time. create_submit_file() looks for each
# of them in the SDDM_DATA directory and lists them in transfer_input_files.
# The DQXX file is not listed here since it is added separately for each job.
INPUT_FILES = [
    "muE_water_liquid.txt",
    "pmt_response_qoca_d2o_20060216.dat",
    "pmt_response_qoca_salt_20060420.dat",
    "rsp_rayleigh.dat",
    "e_water_liquid.txt",
    "pmt_pcath_response.dat",
    "pmt.txt",
    "muE_deuterium_oxide_liquid.txt",
    "pmt_response.dat",
    "proton_water_liquid.txt",
]

class MyTemplate(string.Template):
    """
    string.Template which uses '@' instead of '$' to mark placeholders, so
    that the $(name) macros in the condor submit file pass through untouched.
    """
    delimiter = "@"

def create_submit_file(results, sddm_data, dqxx_dir):
    """
    Creates a submit file and returns the file as a string.

    Arguments:

        results    rows from the state table; each row must support lookup
                   by the keys filename, run, uuid, gtid, particle_id,
                   max_time and priority
        sddm_data  directory containing the files listed in INPUT_FILES
        dqxx_dir   directory containing the DQXX_%010i.dat files (one per
                   run)

    For every row a directory named ROOT_UUID is created in the current
    working directory (unless it already exists) and is passed to condor as
    the initial directory of that job.

    Exits the program with status 1 if the fit or fit-wrapper programs can't
    be found in the path.
    """
    # Look up the programs before touching the filesystem so that a missing
    # program is reported even when results is empty, and before any job
    # directories have been created.
    executable = which("fit")

    if executable is None:
        # NOTE(review): confirm that Logger.warn accepts the `file` keyword.
        log.warn("Couldn't find fit in path!",file=sys.stderr)
        sys.exit(1)

    wrapper = which("fit-wrapper")

    if wrapper is None:
        log.warn("Couldn't find fit-wrapper in path!",file=sys.stderr)
        sys.exit(1)

    queue = []
    for row in results:
        head, tail = split(row['filename'])
        root, ext = splitext(tail)

        # all output files are prefixed with FILENAME_GTID_PARTICLEID_UUID
        prefix = "%s_%08i_%i_%s" % (root,row['gtid'],row['particle_id'],row['uuid'])

        # fit output filename
        output = "%s.hdf5" % prefix

        dqxx_filename = join(dqxx_dir,"DQXX_%010i.dat" % row['run'])

        # directory the job's log, stdout, stderr and fit output end up in
        new_dir = "%s_%s" % (root,row['uuid'])

        if not os.path.isdir(new_dir):
            log.debug("mkdir %s" % new_dir)
            os.mkdir(new_dir)

        # The order of the fields has to match the list of variables in the
        # queue statement of CONDOR_TEMPLATE.
        queue.append(",".join(map(str,[row['filename'],prefix,tail,output,row['gtid'],row['particle_id'],"%f" % row['max_time'],dqxx_filename,new_dir,row['priority']])))

    template = MyTemplate(CONDOR_TEMPLATE)

    # The fit-wrapper script is what condor executes, so the fit program
    # itself is shipped along with the data files as an input file.
    transfer_input_files = ",".join([executable] + [join(sddm_data,filename) for filename in INPUT_FILES])

    submit_string = template.safe_substitute(
            executable=wrapper,
            transfer_input_files=transfer_input_files,
            queue='\n'.join(queue))

    return submit_string

def remove_job(row):
    """
    Remove a particular job from the job queue. Returns 0 on success, -1 on
    failure.

    The job is looked up with get_entry() and -1 is returned if it can't be
    found in the queue or if condor_rm exits with a non-zero status.

    NOTE(review): this relies on the entry returned by get_entry() containing
    the ClusterId and ProcId attributes.
    """
    entry = get_entry(row)

    if entry == -1:
        return -1

    # A single submit file queues many jobs which all share the same
    # ClusterId, so the job has to be addressed as ClusterId.ProcId (just like
    # in release_job). Passing only the ClusterId to condor_rm would remove
    # every job in the cluster.
    job_id = "%s.%s" % (entry['ClusterId'],entry['ProcId'])

    try:
        log.debug("condor_rm %s" % job_id)
        check_call(['condor_rm',job_id])
    except subprocess.CalledProcessError:
        return -1

    return 0

def get_entry(row):
    """
    Returns a entry from the condor_q -json output for a given row.

    The job is identified by the name of its stdout file (the Out attribute)
    which is unique since it contains the uuid of the row. The returned entry
    is a dictionary with the keys Out, JobStatus, ClusterId and ProcId.

    Returns -1 if condor_q doesn't print anything or if no entry matches.
    Raises subprocess.CalledProcessError if condor_q exits with a non-zero
    status.
    """
    head, tail = split(row['filename'])
    root, ext = splitext(tail)

    # all output files are prefixed with FILENAME_GTID_PARTICLEID_UUID
    prefix = "%s_%08i_%i_%s" % (root,row['gtid'],row['particle_id'],row['uuid'])

    out = "%s.output" % prefix

    # remove_job() and release_job() need the ClusterId and ProcId of the job,
    # so they have to be requested explicitly here: according to the condor_q
    # documentation only the listed attributes are printed.
    attributes = "Out,JobStatus,ClusterId,ProcId"
    constraint = 'Out == "%s"' % out

    log.debug('condor_q -json --attributes %s --constraint \'%s\'' % (attributes,constraint))
    output = check_output(["condor_q","-json","--attributes",attributes,"--constraint",constraint])

    # nothing printed means no job matched the constraint
    if not output:
        return -1

    for entry in json.loads(output):
        if entry['Out'] == out:
            return entry

    return -1

def release_job(row):
    """
    Release a particular job by running condor_release on it. Returns 0 on
    success, -1 on failure.

    Failure means that either get_entry() couldn't find the job in the queue
    or that condor_release exited with a non-zero status.
    """
    entry = get_entry(row)

    if entry == -1:
        return -1

    # jobs are addressed as ClusterId.ProcId
    job_id = "%s.%s" % (entry['ClusterId'],entry['ProcId'])

    try:
        log.debug("condor_release %s" % job_id)
        check_call(['condor_release',job_id])
    except subprocess.CalledProcessError:
        return -1

    return 0

def get_job_status(row, data=None):
    """
    Check to see if a given grid job is finished. Returns the following statuses:

        0    Unexpanded
        1    Idle
        2    Running
        3    Removed
        4    Completed
        5    Held
        6    Submission_err
        7    Job failed
        8    Success

    If the job is in the condor queue, the status is the JobStatus entry from
    condor_q. The values here come from
    http://pages.cs.wisc.edu/~adesmet/status.html.

    If the job is not in the queue, the status is worked out from the files in
    the job directory:

        - no condor log file                          -> 2
        - log file without "return value 0"           -> 7
        - missing HDF5 file or no git_sha1 attribute  -> 7
        - otherwise                                   -> 8

    `data` is the parsed output of condor_q -json which has to contain the
    Out and JobStatus attributes for each job. It can be passed in to avoid
    running condor_q once per row. If it is None, condor_q is run for this
    particular job.
    """
    head, tail = split(row['filename'])
    root, ext = splitext(tail)

    # directory with the log and output files of the job
    new_dir = "%s_%s" % (root,row['uuid'])

    # all output files are prefixed with FILENAME_GTID_PARTICLEID_UUID
    prefix = "%s_%08i_%i_%s" % (root,row['gtid'],row['particle_id'],row['uuid'])

    out = "%s.output" % prefix

    if data is None:
        log.debug('condor_q -json --attributes Out,JobStatus --constraint \'Out == "%s"\'' % out)
        output = check_output(["condor_q","-json","--attributes","Out,JobStatus","--constraint",'Out == "%s"' % out])

        if output:
            data = json.loads(output)
        else:
            data = []

    for entry in data:
        if entry['Out'] == out:
            return entry['JobStatus']

    # If there's no entry from condor_q the job is done. Now, we check to see
    # if it completed successfully. Note: Jobs often don't complete
    # successfully because I've noticed that even though I have specified in my
    # submit file that the node should have modules, many of them don't!
    #
    # Update: With the new queue statement, I have no way of knowing if a job
    # hasn't been submitted yet, or if it is done. Therefore, we assume here
    # that if the log file doesn't exist, it hasn't run yet.

    log_file = join(new_dir,"%s.log" % prefix)

    try:
        with open(log_file) as f:
            returned_zero = "return value 0" in f.read()
    except IOError:
        log.debug("Log file '%s' doesn't exist. Assuming job hasn't started running." % log_file)
        return 2

    if not returned_zero:
        log.warn("Log file '%s' doesn't contain the string 'return value 0'. Assuming job failed." % log_file)
        return 7

    hdf5_file = join(new_dir,"%s.hdf5" % prefix)

    try:
        # Open the file explicitly in read-only mode. Older versions of h5py
        # default to append mode which creates the file if it doesn't exist
        # instead of raising an exception.
        with h5py.File(hdf5_file,'r') as f:
            has_git_sha1 = 'git_sha1' in f.attrs
    except IOError:
        log.warn("HDF5 file '%s' doesn't exist. Assuming job failed." % hdf5_file)
        return 7

    if not has_git_sha1:
        log.warn("No git_sha1 attribute in HDF5 file '%s'. Assuming job failed." % hdf5_file)
        return 7

    # Job completed successfully
    return 8

def main(conn):
    """
    Loop over every job in the database, update the state of the jobs marked
    as RUNNING and log how many jobs are in each state.

    `conn` is an open sqlite3 connection to the state database. Rows have to
    be indexable by column name (i.e. conn.row_factory = sqlite3.Row) since
    every running row is passed on to get_job_status().

    A job in the RUNNING state is marked as

        FAILED   if it was removed from the queue or if it failed
        SUCCESS  if it completed successfully

    and is left alone if it is still in the queue or on hold. Jobs in any
    other state are not modified.

    Raises subprocess.CalledProcessError if condor_q exits with a non-zero
    status.
    """
    c = conn.cursor()

    results = c.execute('SELECT id, filename, run, uuid, gtid, particle_id, max_time, state, nretry, submit_file, priority FROM state ORDER BY priority DESC, timestamp ASC')

    # number of jobs in each state
    stats = {}

    # Ask condor for the status of all jobs at once instead of running
    # condor_q for every single row.
    log.debug("condor_q -json --attributes Out,JobStatus")
    output = check_output(["condor_q","-json","--attributes","Out,JobStatus"])

    if output:
        data = json.loads(output)
    else:
        data = []

    for row in results.fetchall():
        # job_id and job_uuid are used instead of id and uuid to avoid
        # shadowing the builtin function and the uuid module
        job_id, filename, run, job_uuid, gtid, particle_id, max_time, state, nretry, submit_file, priority = row

        stats[state] = stats.get(state,0) + 1

        if state == 'NEW':
            pass
        elif state == 'RUNNING':
            # check to see if it's completed
            job_status = get_job_status(row, data=data)

            if job_status in (0,1,2,4):
                # nothing to do!
                log.verbose("Still waiting for job %i to finish" % job_id)
            elif job_status == 3:
                c.execute("UPDATE state SET state = 'FAILED', message = ? WHERE id = ?", ("job was removed",job_id))
            elif job_status == 8:
                # Success!
                log.notice("Job %i completed successfully!" % job_id)
                c.execute("UPDATE state SET state = 'SUCCESS' WHERE id = ?", (job_id,))
            elif job_status == 5:
                # For now, I don't do anything for held jobs. I can mark them
                # manually as failed or retry in the database.
                pass
            elif job_status == 7:
                c.execute("UPDATE state SET state = 'FAILED', message = ? WHERE id = ?", ("job failed", job_id))
            else:
                # Don't know what to do here for Submission_err or for any
                # status not listed in get_job_status()
                log.warn("Job %i is in the state %i. Don't know what to do." % (job_id, job_status))
        elif state == 'RETRY':
            pass
        elif state in ('SUCCESS','FAILED'):
            # Nothing to do here
            pass
        else:
            log.warn("Job %i is in the unknown state '%s'." % (job_id,state))

        # commit after every row so that the updates made so far are kept if
        # something goes wrong with a later row
        conn.commit()

    log.notice("Stats on jobs in the database:")
    # items() instead of iteritems() so that this works with both python 2
    # and python 3
    for state, value in stats.items():
        log.notice("    %s: %i" % (state,value))

# Command line entry point. With --auto the script only updates the state of
# the jobs in the database (see main()). Otherwise it creates a condor submit
# file for all jobs in the NEW or RETRY state, submits it (unless --dry-run is
# given) and marks these jobs as RUNNING.
if __name__ == '__main__':
    import argparse
    import sqlite3
    import traceback

    parser = argparse.ArgumentParser("submit grid jobs", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--db", type=str, help="database file", default=None)
    parser.add_argument('--loglevel',
                        help="logging level (debug, verbose, notice, warning)",
                        default='notice')
    parser.add_argument('--logfile', default=None,
                        help="filename for log file")
    # NOTE(review): --max-retries is parsed but not used anywhere below.
    parser.add_argument('--max-retries', type=int, default=2, help="maximum number of times to try and resubmit a grid job")
    parser.add_argument('-n', type=int, default=None, help="number of jobs to create submit file for")
    parser.add_argument('--dry-run', action='store_true', default=False, help="create the submit file but don't submit it")
    parser.add_argument('--auto', action='store_true', default=False, help="automatically loop over database entries and submit grid jobs")
    args = parser.parse_args()

    log.set_verbosity(args.loglevel)

    if args.logfile:
        log.set_logfile(args.logfile)

    # the database defaults to ~/state.db
    if args.db is None:
        home = os.path.expanduser("~")
        args.db = join(home,'state.db')

    conn = sqlite3.connect(args.db)

    # rows have to be indexable by column name, see create_submit_file() and
    # get_job_status()
    conn.row_factory = sqlite3.Row

    c = conn.cursor()

    if 'SDDM_DATA' not in os.environ:
        # NOTE(review): confirm that Logger.warn accepts the `file` keyword.
        log.warn("Please set the SDDM_DATA environment variable to point to the fitter source code location", file=sys.stderr)
        sys.exit(1)

    sddm_data = os.environ['SDDM_DATA']

    if 'DQXX_DIR' not in os.environ:
        log.warn("Please set the DQXX_DIR environment variable to point to the directory with the DQXX files", file=sys.stderr)
        sys.exit(1)

    dqxx_dir = os.environ['DQXX_DIR']

    # get absolute paths since we are going to create a new directory
    sddm_data = abspath(sddm_data)
    dqxx_dir = abspath(dqxx_dir)

    zdab_cat = which("zdab-cat")

    if zdab_cat is None:
        log.warn("Couldn't find zdab-cat in path!",file=sys.stderr)
        sys.exit(1)

    if args.auto:
        # Only update the state of the jobs in the database. Any exception is
        # written to the log since there is nobody watching the terminal when
        # running from cron.
        try:
            main(conn)
            conn.commit()
            conn.close()
        except Exception:
            log.warn(traceback.format_exc())
            sys.exit(1)
        sys.exit(0)

    # String literals are written with single quotes since SQLite treats
    # double quoted strings as identifiers first and only falls back to
    # interpreting them as strings.
    cmd = "SELECT * FROM state WHERE state IN ('NEW','RETRY') ORDER BY priority DESC, timestamp ASC"
    params = ()

    if args.n:
        cmd += ' LIMIT ?'
        params = (args.n,)

    results = c.execute(cmd, params).fetchall()

    if len(results) == 0:
        print("No more jobs!")
        sys.exit(0)

    submit_string = create_submit_file(results, sddm_data, dqxx_dir)

    # Uses the datetime class imported at the top of the file. It is not
    # imported again here since `import datetime` would rebind the name to the
    # module for the whole file.
    date_string = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")

    submit_filename = "condor_submit_%s.submit" % date_string

    print("Writing %s" % submit_filename)
    with open(submit_filename, "w") as f:
        f.write(submit_string)

    if not args.dry_run:
        print("Submitting %s" % submit_filename)
        log.debug("condor_submit %s" % submit_filename)
        # If condor_submit fails, the CalledProcessError is not caught so
        # that none of the jobs are marked as RUNNING.
        check_call(["condor_submit",submit_filename])

        for row in results:
            c.execute("UPDATE state SET state = 'RUNNING', nretry = COALESCE(nretry + 1,0) WHERE id = ?", (row['id'],))
        conn.commit()

    conn.close()
