#!/usr/bin/env python
# Copyright (c) 2019, Anthony Latorre <tlatorre at uchicago>
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the Free
# Software Foundation, either version 3 of the License, or (at your option)
# any later version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
# more details.
#
# You should have received a copy of the GNU General Public License along with
# this program. If not, see <https://www.gnu.org/licenses/>.

from __future__ import print_function, division

# On retina (high-DPI) screens the default plots are way too small. By using
# Qt5 and setting QT_AUTO_SCREEN_SCALE_FACTOR=1, Qt5 will scale everything
# using the DPI in ~/.Xresources.
import matplotlib
matplotlib.use("Qt5Agg")

if __name__ == '__main__':
    # Compare single-particle fit results against the Monte Carlo truth.
    # For every input file this prints resolution statistics for the fit
    # whose particle hypothesis matches the true particle, and overlays one
    # histogram per quantity (figures 1..9) across all files.
    import argparse
    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    from sddm import IDP_E_MINUS, IDP_MU_MINUS, SNOMAN_MASS
    from sddm.plot import plot_hist, plot_legend, get_stats

    # The first positional argument of ArgumentParser is `prog`, so the text
    # must be passed as `description` to show up as the help description.
    parser = argparse.ArgumentParser(description="plot fit results")
    parser.add_argument("filenames", nargs='+', help="input files")
    args = parser.parse_args()

    # Histograms drawn for each file. Entry i goes into figure i+1; each entry
    # is (x-axis label, function extracting the plotted quantity from
    # data_true).
    plots = [
        ("Kinetic Energy difference (MeV)", lambda df: df['dT']),
        ("X Position difference (cm)", lambda df: df['dx']),
        ("Y Position difference (cm)", lambda df: df['dy']),
        ("Z Position difference (cm)", lambda df: df['dz']),
        (r"$\theta$ (deg)", lambda df: df['theta']),
        (r"Log Likelihood Ratio ($e/\mu$)", lambda df: df['ratio']),
        (r"Electron Fit time (minutes)", lambda df: df['te']/1e3/60.0),
        (r"Muon Fit time (minutes)", lambda df: df['tm']/1e3/60.0),
        (r"$\Psi$/Nhit", lambda df: df['fit_psi']/df['nhit']),
    ]

    for filename in args.filenames:
        print(filename)

        # pd.read_hdf opens and closes the file on its own. No extra h5py
        # handle is opened: h5py's old default mode ('a') could create the
        # file or hold a write lock on it while it is being read here.
        ev = pd.read_hdf(filename, "ev")
        mcgn = pd.read_hdf(filename, "mcgn")
        fits = pd.read_hdf(filename, "fits")

        # get rid of 2nd events like Michel electrons
        ev = ev.sort_values(['run','gtid']).groupby(['evn'],as_index=False).nth(0)

        # Now, we merge all three datasets together to produce a single
        # dataframe. To do so, we join the ev dataframe with the mcgn frame
        # on the evn column, and then join with the fits on the run and
        # gtid columns.
        #
        # At the end we will have a single dataframe with one row for each
        # fit, i.e. it will look like:
        #
        # >>> data
        #   run   gtid nhit, ... mcgn_x, mcgn_y, mcgn_z, ..., fit_id1, fit_x, fit_y, fit_z, ...
        #
        # Before merging, we prefix the primary seed track table with mcgn_
        # and the fit table with fit_ just to make things easier.
        mcgn = mcgn.add_prefix("mcgn_")
        fits = fits.add_prefix("fit_")

        # merge ev and mcgn on evn
        data = ev.merge(mcgn,left_on=['evn'],right_on=['mcgn_evn'])
        # merge data and fits on run and gtid
        data = data.merge(fits,left_on=['run','gtid'],right_on=['fit_run','fit_gtid'])

        # For this script, we only want the single particle fit results
        data = data[(data.fit_id2 == 0) & (data.fit_id3 == 0)]

        # Select only the best (lowest fit_fmin) fit for a given run, gtid,
        # and particle combo
        data = data.sort_values('fit_fmin').groupby(['run','gtid','fit_id1','fit_id2','fit_id3'],as_index=False).nth(0).reset_index(level=0,drop=True)

        # calculate true kinetic energy and fit - truth residuals
        mass = [SNOMAN_MASS[particle_id] for particle_id in data['mcgn_id'].values]
        data['T'] = data['mcgn_energy'].values - mass
        data['dx'] = data['fit_x'].values - data['mcgn_x'].values
        data['dy'] = data['fit_y'].values - data['mcgn_y'].values
        data['dz'] = data['fit_z'].values - data['mcgn_z'].values
        data['dT'] = data['fit_energy1'].values - data['T'].values

        # Angle (degrees) between the true direction and the fitted direction
        # given by (fit_theta1, fit_phi1). The cosine is clipped to [-1, 1]
        # because rounding error on nearly parallel unit vectors can push it
        # slightly outside that range, where arccos returns NaN.
        true_dir = np.column_stack((data['mcgn_dirx'].values,
                                    data['mcgn_diry'].values,
                                    data['mcgn_dirz'].values))
        theta1 = data['fit_theta1'].values
        phi1 = data['fit_phi1'].values
        fit_dir = np.column_stack((np.sin(theta1)*np.cos(phi1),
                                   np.sin(theta1)*np.sin(phi1),
                                   np.cos(theta1)))
        cos_theta = np.clip((true_dir*fit_dir).sum(axis=-1), -1.0, 1.0)
        data['theta'] = np.degrees(np.arccos(cos_theta))

        # only keep events (run, gtid) which have at least 2 fits
        data = data.groupby(['run','gtid']).filter(lambda x: len(x) > 1)

        # Split into the fit with the true particle hypothesis and the
        # electron / muon hypothesis fits, all indexed by (run, gtid) so the
        # assignments below align event by event.
        data_true = data[data['fit_id1'] == data['mcgn_id']].set_index(['run','gtid'])
        data_e = data[data['fit_id1'] == IDP_E_MINUS].set_index(['run','gtid'])
        data_mu = data[data['fit_id1'] == IDP_MU_MINUS].set_index(['run','gtid'])

        data_true['ratio'] = data_mu['fit_fmin']-data_e['fit_fmin']
        data_true['te'] = data_e['fit_time']
        data_true['tm'] = data_mu['fit_time']
        data_true['Te'] = data_e['fit_energy1']

        if len(data_true) < 2:
            continue

        mean, mean_error, std, std_error = get_stats(data_true.dT)
        print("dT      = %.2g +/- %.2g" % (mean, mean_error))
        print("std(dT) = %.2g +/- %.2g" % (std, std_error))
        for name in ('dx', 'dy', 'dz'):
            mean, mean_error, std, std_error = get_stats(data_true[name])
            print("%s      = %4.2g +/- %.2g" % (name, mean, mean_error))
            print("std(%s) = %4.2g +/- %.2g" % (name, std, std_error))
        _, _, std, std_error = get_stats(data_true.theta)
        print("std(theta) = %4.2g +/- %.2g" % (std, std_error))

        for fig, (xlabel, extract) in enumerate(plots, start=1):
            plt.figure(fig)
            plot_hist(extract(data_true), label=filename)
            plt.xlabel(xlabel)

    for fig in range(1, len(plots) + 1):
        plot_legend(fig)
    plt.show()
