#!/usr/bin/env python
# Copyright (c) 2019, Anthony Latorre <tlatorre at uchicago>
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the Free
# Software Foundation, either version 3 of the License, or (at your option)
# any later version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
# more details.
#
# You should have received a copy of the GNU General Public License along with
# this program. If not, see <https://www.gnu.org/licenses/>.
"""
Script to plot the fit results for stopping muons and Michels. To run it just
run:

    $ ./plot-muon [list of fit results]

Currently it will plot energy distributions for external muons, stopping muons,
and michel electrons.
"""
from __future__ import print_function, division
import numpy as np
from scipy.stats import iqr, poisson
from scipy.stats import iqr, norm, beta
from sddm.stats import *
import emcee
from sddm.dc import estimate_errors, EPSILON
import nlopt
from sddm import printoptions

particle_id = {20: 'e', 22: 'u'}

# Absolute tolerance for the minimizer.
# Since we're minimizing the negative log likelihood, we really only care about
# the value of the minimum to within ~0.05 (10% of the one sigma shift).
# However, I have noticed before that setting a higher tolerance can sometimes
# cause the fit to get stuck in a local minimum, so we set it here to a very
# small value.
FTOL_ABS = 1e-10

def print_particle_probs(data):
    """
    Print the most probable fraction of events for each particle ID
    hypothesis together with its uncertainty, in percent.

    For each ID in (20, 22, 2020, 2022, 2222) the number of rows of `data`
    with that `id` is counted. A Dirichlet distribution with
    alpha = counts + 1 (i.e. the posterior for multinomial counts under a
    uniform prior) is then built. Its mode (via `dirichlet_mode`) and
    standard deviation (via `dirichlet.var`) are printed for every ID.

    Args:
        data: DataFrame with an `id` column holding the fit particle ID.
    """
    fit_ids = (20,22,2020,2022,2222)

    counts = [len(data[data.id == fit_id]) for fit_id in fit_ids]

    # One pseudo-count per hypothesis on top of the observed counts.
    alpha = np.ones_like(counts) + counts

    most_probable = dirichlet_mode(alpha)
    spread = np.sqrt(dirichlet.var(alpha))

    for fit_id, m, s in zip(fit_ids, most_probable, spread):
        # Split the ID into two-digit codes (e.g. 2022 -> "20", "22") and map
        # each code to its particle letter, giving e.g. "eu".
        label = ''.join(particle_id[int(''.join(pair))] for pair in grouper(str(fit_id),2))
        print("P(%s) = %.2f +/- %.2f" % (label,m*100,s*100))

def make_nll(data, mc, bins, print_nll=False):
    """
    Build the negative log likelihood of the data energy histogram given a
    smeared and scaled Monte Carlo prediction.

    Args:
        data: DataFrame of data events; its `ke` column is histogrammed.
        mc: DataFrame of Monte Carlo events; its `ke` column is smeared.
        bins: numpy array of histogram bin edges (indexed with
            `bins[:,np.newaxis]`, so it must be an ndarray).
        print_nll: if True, print the likelihood value on every evaluation.

    Returns:
        A function `nll(x, grad=None)` with x = (normalization, energy bias,
        energy resolution). Each MC event with energy T is treated as a
        Gaussian with mean T*(1+x[1]) and standard deviation T*x[2]. The
        expected histogram is the sum of those Gaussians integrated over each
        bin, multiplied by x[0]. The function returns the Poisson negative
        log likelihood summed over bins, or `np.inf` when the normalization
        is negative or the resolution is not strictly positive. `grad` exists
        only to match the nlopt objective signature and is ignored.
    """
    from scipy.special import gammaln

    data_hist = np.histogram(data.ke.values,bins=bins)[0]

    # sum(log(o_i!)) depends only on the data, so compute it once instead of
    # on every likelihood evaluation.
    log_factorial_sum = np.sum(gammaln(data_hist+1))

    mc_ke = mc.ke.values
    edges = bins[:,np.newaxis]

    def nll(x, grad=None):
        # A zero resolution means a zero-width Gaussian, for which
        # scipy's norm.cdf returns NaN, so require it to be strictly positive.
        if x[0] < 0 or x[2] <= 0:
            return np.inf

        # Get the Monte Carlo histograms. We need to do this within the
        # likelihood function since we apply the energy resolution parameters
        # to the Monte Carlo.
        cdf = norm.cdf(edges,mc_ke*(1+x[1]),mc_ke*x[2])
        mc_hist = np.sum(cdf[1:] - cdf[:-1],axis=-1)*x[0]

        # Calculate the negative log of the likelihood of observing the data
        # given the fit parameters. EPSILON keeps log() finite in empty bins.
        ei = mc_hist + EPSILON
        value = ei.sum() + log_factorial_sum - np.sum(data_hist*np.log(ei))

        if print_nll:
            # Print the result
            print("nll = %.2f" % value)

        return value
    return nll

if __name__ == '__main__':
    import argparse
    import numpy as np
    import pandas as pd
    import sys
    import h5py
    from sddm.plot_energy import *
    from sddm.plot import *
    from sddm import setup_matplotlib
    from sddm.utils import correct_energy_bias

    parser = argparse.ArgumentParser("plot fit results")
    parser.add_argument("filenames", nargs='+', help="input files")
    parser.add_argument("--save", action='store_true', default=False, help="save corner plots for backgrounds")
    parser.add_argument("--mc", nargs='+', required=True, help="atmospheric MC files")
    parser.add_argument("--print-nll", action='store_true', default=False, help="print nll values")
    parser.add_argument("--steps", type=int, default=1000, help="number of steps in the MCMC chain")
    parser.add_argument("--walkers", type=int, default=100, help="number of walkers")
    args = parser.parse_args()

    setup_matplotlib(args.save)

    import matplotlib.pyplot as plt

    # Loop over runs to prevent using too much memory
    evs = []
    rhdr = pd.concat([read_hdf(filename, "rhdr").assign(filename=filename) for filename in args.filenames],ignore_index=True)
    for run, df in rhdr.groupby('run'):
        evs.append(get_events(df.filename.values, merge_fits=True))
    ev = pd.concat(evs)
    ev = correct_energy_bias(ev)

    # Note: We loop over the MC filenames here instead of just passing the
    # whole list to get_events() because I had to rerun some of the MC events
    # using SNOMAN and so most of the runs actually have two different files
    # and otherwise the GTIDs will clash
    ev_mcs = []
    for filename in args.mc:
        ev_mcs.append(get_events([filename], merge_fits=True, mc=True))
    ev_mc = pd.concat(ev_mcs)
    ev_mc = correct_energy_bias(ev_mc)

    # Drop events without fits
    ev = ev[~np.isnan(ev.fmin)]
    ev_mc = ev_mc[~np.isnan(ev_mc.fmin)]

    ev = ev.reset_index()
    ev_mc = ev_mc.reset_index()

    # First, do basic data cleaning which is done for all events.
    ev = ev[ev.signal | ev.muon]
    ev_mc = ev_mc[ev_mc.signal | ev_mc.muon]

    # 00-orphan cut
    ev = ev[(ev.gtid & 0xff) != 0]
    ev_mc = ev_mc[(ev_mc.gtid & 0xff) != 0]

    # Now, we select events tagged by the muon tag which should tag only
    # external muons. We keep the sample of muons since it's needed later to
    # identify Michel electrons and to apply the muon follower cut
    muons = ev[ev.muon]
    muons_mc = ev_mc[ev_mc.muon]

    # Try to identify Michel electrons. Currently, the event selection is based
    # on Richie's thesis. Here, we do the following:
    #
    # 1. Apply more data cleaning cuts to potential Michel electrons
    # 2. Nhit >= 100
    # 3. It must be > 800 ns and less than 20 microseconds from a prompt event
    #    or a muon
    michel = ev[ev.michel]
    michel_mc = ev_mc[ev_mc.michel]

    # remove events 200 microseconds after a muon
    ev = ev.groupby('run',group_keys=False).apply(muon_follower_cut)

    ev = ev[ev.psi < 6]
    ev_mc = ev_mc[ev_mc.psi < 6]

    handles = [Line2D([0], [0], color='C0'),
               Line2D([0], [0], color='C1')]
    labels = ('Data','Monte Carlo')

    # For the Michel energy plot, we only look at events where the
    # corresponding muon had less than 2500 nhit. The reason for only looking
    # at Michel electrons from muons with less than 2500 nhit is that there
    # is significant ringing and afterpulsing after a large muon which can
    # cause the reconstruction to overestimate the energy.
    # NOTE(review): muon_gtid is matched against the gtids of all runs at once;
    # if gtids can repeat between runs this may pair a Michel with a muon from
    # another run -- confirm.
    michel_low_nhit = michel[michel.muon_gtid.isin(ev.gtid.values) & (michel.muon_nhit < 2500)]
    michel_low_nhit_mc = michel_mc[michel_mc.muon_gtid.isin(ev_mc.gtid.values) & (michel_mc.muon_nhit < 2500)]

    # The Monte Carlo is normalized to the number of data events below, so
    # there must be at least one selected Monte Carlo Michel electron.
    if len(michel_low_nhit_mc) == 0:
        print("error: no Monte Carlo Michel electrons passed the selection", file=sys.stderr)
        sys.exit(1)

    # Ratio of data to Monte Carlo events: the starting normalization for the
    # fit and the weight used to overlay the Monte Carlo histogram.
    data_mc_ratio = len(michel_low_nhit)/len(michel_low_nhit_mc)

    bins = np.linspace(0,100,41)

    nll = make_nll(michel_low_nhit,michel_low_nhit_mc,bins,args.print_nll)

    # Fit parameters: (normalization, energy bias, energy resolution).
    x0 = np.array([data_mc_ratio,0.0,EPSILON])
    opt = nlopt.opt(nlopt.LN_SBPLX, len(x0))
    opt.set_min_objective(nll)
    low = np.array([EPSILON,-1,EPSILON])
    high = np.array([1e9,1.0,1e9])
    opt.set_lower_bounds(low)
    opt.set_upper_bounds(high)
    opt.set_ftol_abs(FTOL_ABS)
    opt.set_initial_step([0.01]*len(x0))

    xopt = opt.optimize(x0)
    print("xopt = ", xopt)
    nll_xopt = nll(xopt)
    print("nll(xopt) = ", nll_xopt)

    stepsizes = estimate_errors(nll,xopt,low,high)

    # NOTE(review): the estimated spread for the energy bias is replaced by a
    # fixed 0.1 when seeding the walkers; the reason is not visible here.
    stepsizes[1] = 0.1

    with printoptions(precision=3, suppress=True):
        print("Errors: ", stepsizes)

    # Seed every walker at the best fit point plus Gaussian jitter scaled by
    # the step sizes, clipped to the parameter bounds.
    pos = xopt + np.random.randn(args.walkers, len(x0))*stepsizes
    pos = np.clip(pos,low,high)

    nwalkers, ndim = pos.shape

    sampler = emcee.EnsembleSampler(nwalkers, ndim, lambda x: -nll(x), moves=emcee.moves.KDEMove())
    with np.errstate(invalid='ignore'):
        sampler.run_mcmc(pos, args.steps)

    print("Mean acceptance fraction: {0:.3f}".format(np.mean(sampler.acceptance_fraction)))

    try:
        print("autocorrelation time: ", sampler.get_autocorr_time(quiet=True))
    except Exception as e:
        print(e)

    # Both are already flattened over walkers and steps.
    log_prob = sampler.get_log_prob(flat=True)
    samples = sampler.get_chain(flat=True)

    plt.figure()
    for i, xlabel in enumerate(("Normalization","Energy Bias","Energy Resolution")):
        plt.subplot(2,2,i+1)
        plt.hist(samples[:,i],bins=100,histtype='step')
        plt.xlabel(xlabel)
        despine(ax=plt.gca(),left=True,trim=True)
        plt.gca().get_yaxis().set_visible(False)
    if args.save:
        plt.savefig("michel_electron_fit_posterior.pdf")
        plt.savefig("michel_electron_fit_posterior.eps")
    else:
        plt.suptitle("Fit Parameters")

    # Report the parameters of the highest-probability sample with the
    # standard deviation of the posterior samples as the error.
    best = samples[log_prob.argmax()]
    mode_bias, error_bias = best[1], np.std(samples[:,1])
    mode_resolution, error_resolution = best[2], np.std(samples[:,2])

    print("Energy bias = %.2g +/- %.2g" % (mode_bias,error_bias))
    print("Energy resolution = %.2g +/- %.2g" % (mode_resolution,error_resolution))

    print("Particle ID probability for Michel electrons:")
    print("Data")
    print_particle_probs(michel_low_nhit)
    print("Monte Carlo")
    print_particle_probs(michel_low_nhit_mc)

    fig = plt.figure()
    plot_hist2_data_mc(michel_low_nhit,michel_low_nhit_mc)
    despine(fig,trim=True)
    fig.legend(handles,labels,loc='upper right')
    if args.save:
        plt.savefig("michel_electrons.pdf")
        plt.savefig("michel_electrons.eps")
    else:
        plt.suptitle("Michel Electrons")

    fig = plt.figure()
    plt.hist(michel_low_nhit.ke.values, bins=bins, histtype='step', color='C0', label="Data")
    plt.hist(michel_low_nhit_mc.ke.values, weights=np.tile(data_mc_ratio,len(michel_low_nhit_mc.ke.values)), bins=bins, histtype='step', color='C1', label="Monte Carlo")
    hist = np.histogram(michel_low_nhit.ke.values,bins=bins)[0]
    hist_mc = np.histogram(michel_low_nhit_mc.ke.values,bins=bins)[0]
    # Use a distinct name here: assigning to `norm` at this scope would rebind
    # the module-global scipy.stats.norm used by the likelihood function.
    if hist_mc.sum() > 0:
        norm_factor = hist.sum()/hist_mc.sum()
    else:
        norm_factor = 1.0
    p = get_multinomial_prob(hist,hist_mc,norm_factor)
    plt.text(0.95,0.95,"p = %.2f" % p,horizontalalignment='right',verticalalignment='top',transform=plt.gca().transAxes)
    despine(fig,trim=True)
    plt.xlabel("Energy (MeV)")
    plt.tight_layout()
    plt.legend()
    if args.save:
        plt.savefig("michel_electrons_ke.pdf")
        plt.savefig("michel_electrons_ke.eps")
    else:
        plt.title("Michel Electrons")
        plt.show()
