#!/usr/bin/env python
# Copyright (c) 2019, Anthony Latorre <tlatorre at uchicago>
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the Free
# Software Foundation, either version 3 of the License, or (at your option)
# any later version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
# more details.
#
# You should have received a copy of the GNU General Public License along with
# this program. If not, see <https://www.gnu.org/licenses/>.
"""
Script to combine the fit results from jobs submitted to the grid. It's
expected to be run from a cron job:

    PATH=/usr/bin:$HOME/local/bin
    SDDM_DATA=$HOME/sddm/src
    DQXX_DIR=$HOME/dqxx

    0 * * * * module load hdf5; module load py-h5py; module load zlib; cat-grid-jobs --loglevel debug --logfile cat.log --output-dir $HOME/fit_results

The script loops through every (filename, uuid) pair in the database and tries
to combine the fit results for each pair into a single output file.
"""

from __future__ import print_function, division
import os
import sys
import numpy as np
from datetime import datetime
import h5py
from os.path import join, split
from subprocess import check_call
from sddm import splitext, which
from sddm.logger import Logger
import subprocess

log = Logger()

def _count_fits(fit_results):
    """
    Return the total number of fit rows across the hdf5 files in `fit_results`.

    Missing files and files without a 'git_sha1' attribute are skipped with a
    warning, the same way they are skipped when the results are merged.
    """
    total_fits = 0
    for fit_result_filename in fit_results:
        if not os.path.exists(fit_result_filename):
            log.warn("File '%s' does not exist!" % fit_result_filename)
            continue

        with h5py.File(fit_result_filename,'r') as f:
            if 'git_sha1' not in f.attrs:
                log.warn("No git sha1 found for '%s'. Skipping..." % fit_result_filename)
                continue
            total_fits += f['fits'].shape[0]
    return total_fits

def _output_is_current(output, filename, total_fits):
    """
    Return True if the existing combined file `output` does not need to be
    regenerated, i.e. it was written with version >= 3, it was built from the
    reprocessed zdab if `filename` is a reprocessed zdab, and it already holds
    at least `total_fits` fit results.
    """
    with h5py.File(output,'r') as fout:
        if 'version' not in fout.attrs or fout.attrs['version'] < 3:
            # Written by an older version of this script; regenerate it.
            return False
        if 'reprocessed' in filename and 'reprocessed' not in fout.attrs:
            # A reprocessed zdab is available now but the existing output was
            # not built from it.
            return False
        return 'fits' in fout and fout['fits'].shape[0] >= total_fits

def cat_grid_jobs(conn, output_dir, zdab_dir=None):
    """
    Combine the fit results for every (filename, uuid) pair in the `state`
    table into a single hdf5 file per pair.

    For each pair this:
      1. collects the per-event hdf5 result files for rows in the 'SUCCESS'
         state,
      2. runs zdab-cat on the source zdab (using a 'reprocessed' version of a
         'reduced' file when one is found) to write the event list to
         `<output_dir>/<root>_<uuid>_fit_results.hdf5`,
      3. replaces the 'fits' dataset in that file with the concatenation of
         all the per-event fit results.

    Existing combined files that are already up to date (see
    `_output_is_current`) are left alone.

    Args:
        conn: sqlite3 connection to the job database. Columns are read by
            name (row['gtid']), so it needs a row factory like sqlite3.Row.
        output_dir: directory containing the per-event fit results (under
            '<root>_<uuid>/') and where the combined files are written.
        zdab_dir: optional extra directory to search for reprocessed zdabs.
    """
    zdab_cat = which("zdab-cat")

    if zdab_cat is None:
        log.warn("couldn't find zdab-cat in path!")
        return

    c = conn.cursor()

    # Let sqlite de-duplicate the (filename, uuid) pairs.
    results = c.execute('SELECT DISTINCT filename, uuid FROM state').fetchall()

    for filename, uuid in results:
        head, tail = split(filename)
        root, ext = splitext(tail)

        # Note: We assume here that the output directory is the same as the
        # directory where the fit results are stored.
        new_dir = "%s_%s" % (root,uuid)

        # First, find all hdf5 result files
        fit_results = []
        for row in c.execute("SELECT gtid, particle_id FROM state WHERE state = 'SUCCESS' AND filename = ? AND uuid = ?", (filename, uuid)).fetchall():
            # all output files are prefixed with FILENAME_GTID_PARTICLEID_UUID
            prefix = "%s_%08i_%i_%s" % (root,row['gtid'],row['particle_id'],uuid)
            fit_results.append(join(output_dir, new_dir, "%s.hdf5" % prefix))

        if not fit_results:
            log.verbose("No fit results found for %s (%s)" % (tail, uuid))
            continue

        output = join(output_dir,"%s_%s_fit_results.hdf5" % (root,uuid))

        if 'reduced' in root:
            directories = [head]
            if zdab_dir is not None:
                directories += [zdab_dir]
            # NOTE(review): there is no break, so a later match overrides an
            # earlier one (zdab_dir wins over head, '.zdab.gz' over `ext`).
            # Confirm this priority is intended.
            for directory in directories:
                for extension in [ext, '.zdab', '.zdab.gz']:
                    # Use the reprocessed version of the file if possible
                    reprocessed_filename = join(directory,root.replace('reduced','reprocessed')) + extension

                    if os.path.exists(reprocessed_filename):
                        log.verbose("Found reprocessed file '%s'. Using that instead of '%s'" % (reprocessed_filename,tail))
                        filename = reprocessed_filename

        if os.path.exists(output):
            total_fits = _count_fits(fit_results)
            if _output_is_current(output, filename, total_fits):
                log.verbose("skipping %s because there are already %i fit results" % (tail,total_fits))
                continue

        if not os.path.exists(filename):
            log.warn("File '%s' does not exist!" % filename)
            continue

        # First we get the full event list along with the data cleaning word, FTP
        # position, FTK, and RSP energy from the original zdab and then add the fit
        # results.
        #
        # Note: We send stderr to /dev/null since there can be a lot of warnings
        # about PMT types and fit results
        with open(os.devnull, 'w') as f:
            log.debug("zdab-cat %s -o %s" % (filename,output))
            try:
                check_call([zdab_cat,filename,"-o",output],stderr=f)
            except subprocess.CalledProcessError as e:
                log.warn(str(e))
                continue

        total_events = 0
        events_with_fit = 0
        total_fits = 0

        with h5py.File(output,"a") as fout:
            # Mark a version in case we need to reprocess all the files
            fout.attrs['version'] = 3

            # Mark the file as being reprocessed so we know in the future if we
            # already used the reprocessed version instead of the reduced
            # version
            if 'reprocessed' in filename:
                fout.attrs['reprocessed'] = 1

            fits = []

            total_events = fout['ev'].shape[0]
            for fit_result_filename in fit_results:
                fit_result_tail = split(fit_result_filename)[1]

                if not os.path.exists(fit_result_filename):
                    log.warn("File '%s' does not exist!" % fit_result_filename)
                    continue

                # Open read-only: older h5py versions default to append mode,
                # which needs write access and can modify the input file.
                with h5py.File(fit_result_filename,'r') as f:
                    if 'git_sha1' not in f.attrs:
                        log.warn("No git sha1 found for %s. Skipping..." % fit_result_tail)
                        continue

                    # Check to see if the git sha1 match
                    if fout.attrs['git_sha1'] != f.attrs['git_sha1']:
                        log.debug("git_sha1 is %s for current version but %s for %s" % (fout.attrs['git_sha1'],f.attrs['git_sha1'],fit_result_tail))

                    fits.append(f['fits'][:])
                    total_fits += fits[-1].shape[0]

            if fits:
                all_fits = np.concatenate(fits)
                # Count unique (run, gtid) over all files together. The same
                # event can have several result files (one per particle_id),
                # so summing per-file counts would double count it.
                events_with_fit = len(np.unique(all_fits[['run','gtid']]))
                # The zdab-cat output may not contain a 'fits' dataset.
                if 'fits' in fout:
                    del fout['fits']
                fout.create_dataset('fits',data=all_fits)

        log.notice("%s (%s): added %i fit results from %i events to a total of %i events" % (tail, uuid, total_fits, events_with_fit, total_events))

if __name__ == '__main__':
    import argparse
    import sqlite3

    # The first positional argument of ArgumentParser is `prog`, so the help
    # text must be passed as `description`.
    parser = argparse.ArgumentParser(description="concatenate fit results from grid jobs into a single file")
    parser.add_argument("--db", type=str, help="database file", default=None)
    parser.add_argument('--loglevel',
                        help="logging level (debug, verbose, notice, warning)",
                        default='notice')
    parser.add_argument('--logfile', default=None,
                        help="filename for log file")
    parser.add_argument('--output-dir', default=None,
                        help="output directory for fit results")
    parser.add_argument('--zdab-dir', default=None,
                        help="extra directory to search for zdab files")
    args = parser.parse_args()

    log.set_verbosity(args.loglevel)

    if args.logfile:
        log.set_logfile(args.logfile)

    home = os.path.expanduser("~")

    if args.db is None:
        args.db = join(home,'state.db')

    if args.output_dir is None:
        args.output_dir = home
    elif not os.path.exists(args.output_dir):
        log.debug("mkdir %s" % args.output_dir)
        # makedirs also creates any missing parent directories.
        os.makedirs(args.output_dir)

    conn = sqlite3.connect(args.db)

    # cat_grid_jobs() reads columns by name (e.g. row['gtid']).
    conn.row_factory = sqlite3.Row

    try:
        cat_grid_jobs(conn, args.output_dir, args.zdab_dir)
    finally:
        conn.close()
