#!/usr/bin/env python
# Copyright (c) 2019, Anthony Latorre <tlatorre at uchicago>
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the Free
# Software Foundation, either version 3 of the License, or (at your option)
# any later version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
# more details.
#
# You should have received a copy of the GNU General Public License along with
# this program. If not, see <https://www.gnu.org/licenses/>.
"""
Script to plot the nhit and time difference distributions for neutrons in data
and atmospheric Monte Carlo. To run it just run:

    $ ./plot-neutrons [list of fit results] --mc [list of atmospheric MC fit results]
"""
from __future__ import print_function, division
import numpy as np
from scipy.stats import iqr, poisson
from scipy.stats import iqr, norm, beta
from sddm.stats import *
import emcee
from sddm.dc import estimate_errors, EPSILON
import nlopt
from sddm import printoptions

# Maps integer particle ID codes to short string labels.
# NOTE(review): presumably 20 -> electron ('e') and 22 -> muon ('u'); confirm
# against the fitter's ID convention. Not referenced in the visible part of
# this script.
particle_id = {20: 'e', 22: 'u'}

if __name__ == '__main__':
    import argparse
    import numpy as np
    import pandas as pd
    import sys
    import h5py
    # NOTE(review): the star imports below are expected to provide read_hdf,
    # get_events, muon_follower_cut, despine and plot_legend -- confirm.
    from sddm.plot_energy import *
    from sddm.plot import *
    from sddm import setup_matplotlib
    from sddm.utils import correct_energy_bias

    parser = argparse.ArgumentParser("plot fit results")
    parser.add_argument("filenames", nargs='+', help="input files")
    parser.add_argument("--save", action='store_true', default=False, help="save corner plots for backgrounds")
    parser.add_argument("--mc", nargs='+', required=True, help="atmospheric MC files")
    args = parser.parse_args()

    setup_matplotlib(args.save)

    # Imported only after setup_matplotlib() has run, as in the original
    # script (presumably so that it can configure matplotlib first).
    import matplotlib.pyplot as plt

    # Loop over runs to prevent using too much memory: the run headers are
    # read first so the input files can be grouped by run and loaded one run
    # at a time.
    rhdr = pd.concat(
        [read_hdf(filename, "rhdr").assign(filename=filename) for filename in args.filenames],
        ignore_index=True)
    evs = []
    for _, run_hdr in rhdr.groupby('run'):
        evs.append(get_events(run_hdr.filename.values, merge_fits=True))
    ev = pd.concat(evs)
    ev = correct_energy_bias(ev)

    # Note: We loop over the MC filenames here instead of just passing the
    # whole list to get_events() because I had to rerun some of the MC events
    # using SNOMAN and so most of the runs actually have two different files
    # and otherwise the GTIDs will clash
    ev_mcs = []
    for filename in args.mc:
        ev_mcs.append(get_events([filename], merge_fits=True, mc=True))
    ev_mc = pd.concat(ev_mcs)
    ev_mc = correct_energy_bias(ev_mc)

    ev = ev.reset_index()
    ev_mc = ev_mc.reset_index()

    # remove events 200 microseconds after a muon (applied to data only)
    ev = ev.groupby('run',group_keys=False).apply(muon_follower_cut)

    # 00-orphan cut: drop events whose GTID has a zero low byte
    ev = ev[(ev.gtid & 0xff) != 0]
    ev_mc = ev_mc[(ev_mc.gtid & 0xff) != 0]

    # Events flagged as neutron candidates
    neutrons = ev[ev.neutron]
    neutrons_mc = ev_mc[ev_mc.neutron]

    # Prompt atmospheric signal candidates
    atm = ev[ev.signal & ev.prompt & ev.atm]
    atm_mc = ev_mc[ev_mc.signal & ev_mc.prompt & ev_mc.atm]

    # Drop events without fits
    atm = atm[~np.isnan(atm.fmin)]
    atm_mc = atm_mc[~np.isnan(atm_mc.fmin)]

    # Keep only events with psi < 6
    atm = atm[atm.psi < 6]
    atm_mc = atm_mc[atm_mc.psi < 6]

    # Pair each prompt event with its neutron(s): a neutron's atm_gtid is
    # matched to the prompt event's gtid within the same run (inner join).
    # Columns coming from the neutron side get the '_neutron' suffix.
    merge_kwargs = dict(left_on=['run','gtid'], right_on=['run','atm_gtid'], suffixes=('','_neutron'))
    atm = pd.merge(atm, neutrons, **merge_kwargs)
    atm_mc = pd.merge(atm_mc, neutrons_mc, **merge_kwargs)

    # The message now matches the inclusive threshold used by the selection.
    print("neutrons with nhit >= 100")
    print(atm[atm.nhit_neutron >= 100][['run','gtid_neutron','nhit_neutron','nhit_cal_neutron']])

    # The Monte Carlo histograms are normalized to the number of data pairs,
    # which is impossible without any Monte Carlo pairs. Fail with a clear
    # message rather than a bare ZeroDivisionError.
    if len(atm_mc) == 0:
        sys.exit("plot-neutrons: no Monte Carlo neutron events passed the cuts, "
                 "cannot normalize the Monte Carlo histograms")

    # One identical weight per MC entry so that the MC histogram integrates to
    # the same number of entries as the data.
    weights = np.full(len(atm_mc), len(atm)/len(atm_mc))

    # Figure 1: calibrated nhit of the neutron events
    fig = plt.figure(1)
    nhit_bins = np.linspace(0,100,101)
    plt.hist(atm.nhit_cal_neutron.values,bins=nhit_bins,histtype='step',color='C0',label="Data")
    plt.hist(atm_mc.nhit_cal_neutron.values,bins=nhit_bins,weights=weights,histtype='step',color='C1',label="Monte Carlo")
    plt.xlabel("Nhit")
    despine(fig,trim=True)
    plt.tight_layout()
    plot_legend(1)
    if args.save:
        plt.savefig("neutron_nhit_cal.pdf")
        plt.savefig("neutron_nhit_cal.eps")
    else:
        plt.title("Neutron Nhit Distribution")

    # Figure 2: time between the prompt event and the neutron.
    # NOTE(review): gtr is presumably in ns, so dividing by 1e6 gives ms,
    # consistent with the axis label below -- confirm.
    fig = plt.figure(2)
    dt_bins = np.linspace(20e-3,250,101)
    dt_data = (atm.gtr_neutron - atm.gtr)/1e6
    dt_mc = (atm_mc.gtr_neutron - atm_mc.gtr)/1e6
    plt.hist(dt_data,bins=dt_bins,histtype='step',color='C0',label="Data")
    plt.hist(dt_mc,bins=dt_bins,weights=weights,histtype='step',color='C1',label="Monte Carlo")
    plt.xlabel(r"$\Delta$ t (ms)")
    despine(fig,trim=True)
    plt.tight_layout()
    # The legend must go on figure 2; the original passed 1 here (copy-paste
    # from the block above), which left this figure without a legend.
    # NOTE(review): assumes plot_legend() takes the figure number -- confirm.
    plot_legend(2)
    if args.save:
        plt.savefig("neutron_delta_t.pdf")
        plt.savefig("neutron_delta_t.eps")
    else:
        plt.title(r"Neutron $\Delta t$ Distribution")
        plt.show()
