#!/usr/bin/env python
# Copyright (c) 2019, Anthony Latorre <tlatorre at uchicago>
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the Free
# Software Foundation, either version 3 of the License, or (at your option)
# any later version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
# more details.
#
# You should have received a copy of the GNU General Public License along with
# this program. If not, see <https://www.gnu.org/licenses/>.
"""
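Script to plot the fit results for stopping muons and Michel electrons. To run
it:

    $ ./plot-muon [list of fit results] --mc [list of atmospheric MC files]

It plots data/Monte Carlo comparisons of the energy distributions for external
muons, stopping muons, and Michel electrons, and uses stopping muons to estimate
the data/Monte Carlo energy bias and resolution. Pass --save to write the plots
to PDF and EPS files instead of displaying them.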
"""
from __future__ import print_function, division
import numpy as np
from scipy.stats import iqr, poisson
from scipy.stats import iqr, norm, beta
from sddm.stats import *
import nlopt
from sddm.dc import estimate_errors

# Short labels for the two-digit particle codes that make up a fit ID; e.g. the
# fit ID 2022 is printed as 'eu' by print_particle_probs(). ('e' and 'u'
# presumably stand for electron and muon.)
particle_id = {20: 'e', 22: 'u'}

def print_particle_probs(data):
    """
    Print the estimated probability of each particle ID hypothesis in `data`.

    The number of events with each fit ID (20, 22, 2020, 2022, 2222) is
    counted, one is added to every count, and the result is used as the
    concentration parameters of a Dirichlet distribution. The mode and the
    standard deviation of that distribution are printed as percentages.
    """
    fit_ids = (20,22,2020,2022,2222)
    counts = [len(data[data.id == fit_id]) for fit_id in fit_ids]

    # alpha = counts + 1 (a flat Dirichlet prior on the particle fractions)
    alpha = np.ones_like(counts) + counts

    probs = dirichlet_mode(alpha)
    errs = np.sqrt(dirichlet.var(alpha))

    for idx, fit_id in enumerate(fit_ids):
        # Every pair of digits in the fit ID is one particle code,
        # e.g. 2022 -> ('2','0'), ('2','2') -> 'eu'.
        label = ''.join(particle_id[int(''.join(pair))] for pair in grouper(str(fit_id),2))
        print("P(%s) = %.2f +/- %.2f" % (label,probs[idx]*100,errs[idx]*100))

def fit_straight_line(y, yerr):
    """
    Fit a constant (i.e. a horizontal straight line) to the points `y`.

    The best fit value minimizes the Gaussian negative log likelihood of the
    points given their uncertainties `yerr`; it is found with nlopt's
    derivative-free Subplex (LN_SBPLX) algorithm, and its uncertainty comes
    from `estimate_errors`.

    Points where `y` or `yerr` is not finite, or where `yerr` is not positive,
    are skipped. Such points appear e.g. when a median is computed from an
    empty bin; keeping them would make the likelihood NaN for every parameter
    value and spoil the minimization.

    Args:
        y: sequence of data points.
        yerr: sequence of 1-sigma uncertainties on `y` (same length as `y`).

    Returns:
        Tuple (value, error) with the best fit constant and its uncertainty,
        or (nan, nan) if none of the points are usable.
    """
    y = np.asarray(y, dtype=float)
    yerr = np.asarray(yerr, dtype=float)

    # Drop points that would make the likelihood NaN/undefined.
    valid = np.isfinite(y) & np.isfinite(yerr) & (yerr > 0)
    y = y[valid]
    yerr = yerr[valid]

    if y.size == 0:
        return np.nan, np.nan

    def nll(x, grad=None):
        # nlopt passes a `grad` argument; it is unused since SBPLX is
        # derivative-free.
        return -norm.logpdf(x,y,yerr).sum()

    x0 = np.array([np.mean(y)])
    opt = nlopt.opt(nlopt.LN_SBPLX, len(x0))
    opt.set_min_objective(nll)
    low = np.array([-1e9])
    high = np.array([1e9])
    opt.set_lower_bounds(low)
    opt.set_upper_bounds(high)
    opt.set_ftol_abs(1e-10)
    opt.set_initial_step([0.01])

    xopt = opt.optimize(x0)

    stepsizes = estimate_errors(nll,xopt,low,high)

    return xopt[0], stepsizes[0]

if __name__ == '__main__':
    import argparse
    import numpy as np
    import pandas as pd
    import sys
    import h5py
    from sddm.plot_energy import *
    from sddm.plot import *
    from sddm import setup_matplotlib
    from sddm.utils import correct_energy_bias

    parser = argparse.ArgumentParser("plot fit results")
    parser.add_argument("filenames", nargs='+', help="input files")
    parser.add_argument("--save", action='store_true', default=False, help="save the plots as PDF and EPS files instead of displaying them")
    parser.add_argument("--mc", nargs='+', required=True, help="atmospheric MC files")
    args = parser.parse_args()

    setup_matplotlib(args.save)

    import matplotlib.pyplot as plt

    def mc_scale(n_data, n_mc):
        """
        Return the weight that normalizes a Monte Carlo histogram with `n_mc`
        entries to `n_data` data events. Returns 1.0 when the Monte Carlo
        sample is empty so an empty selection doesn't raise ZeroDivisionError
        (same fallback as the Michel p-value calculation below).
        """
        return n_data/n_mc if n_mc > 0 else 1.0

    # Loop over runs to prevent using too much memory
    evs = []
    rhdr = pd.concat([read_hdf(filename, "rhdr").assign(filename=filename) for filename in args.filenames],ignore_index=True)
    for run, df in rhdr.groupby('run'):
        evs.append(get_events(df.filename.values, merge_fits=True))
    ev = pd.concat(evs)
    ev = correct_energy_bias(ev)
    ev_mc = get_events(args.mc, merge_fits=True, mc=True)
    ev_mc = correct_energy_bias(ev_mc)

    # Drop events without fits
    ev = ev[~np.isnan(ev.fmin)]
    ev_mc = ev_mc[~np.isnan(ev_mc.fmin)]

    # Set all prompt events in the MC to be muons
    ev_mc.loc[ev_mc.prompt,'muon'] = True

    # First, do basic data cleaning which is done for all events.
    ev = ev[ev.signal | ev.muon]
    ev_mc = ev_mc[ev_mc.signal | ev_mc.muon]

    # 00-orphan cut
    ev = ev[(ev.gtid & 0xff) != 0]
    ev_mc = ev_mc[(ev_mc.gtid & 0xff) != 0]

    # Now, we select events tagged by the muon tag which should tag only
    # external muons. We keep the sample of muons since it's needed later to
    # identify Michel electrons and to apply the muon follower cut
    muons = ev[ev.muon]
    muons_mc = ev_mc[ev_mc.muon]

    # Try to identify Michel electrons. Currently, the event selection is based
    # on Richie's thesis. Here, we do the following:
    #
    # 1. Apply more data cleaning cuts to potential Michel electrons
    # 2. Nhit >= 100
    # 3. It must be > 800 ns and less than 20 microseconds from a prompt event
    #    or a muon
    #
    # NOTE: the `michel` column itself is computed outside this script; here
    # we only select on it.
    michel = ev[ev.michel]
    michel_mc = ev_mc[ev_mc.michel]

    # remove events 200 microseconds after a muon
    ev = ev.groupby('run',group_keys=False).apply(muon_follower_cut)

    muons = muons[muons.psi < 6]
    muons_mc = muons_mc[muons_mc.psi < 6]

    handles = [Line2D([0], [0], color='C0'),
               Line2D([0], [0], color='C1')]
    labels = ('Data','Monte Carlo')

    fig = plt.figure()
    plot_hist2_data_mc(muons,muons_mc)
    despine(fig,trim=True)
    if len(muons):
        plt.tight_layout()
    fig.legend(handles,labels,loc='upper right')
    if args.save:
        plt.savefig("external_muons.pdf")
        plt.savefig("external_muons.eps")
    else:
        plt.suptitle("External Muons")

    # Plot the energy and angular distribution for external muons
    fig = plt.figure()
    plt.subplot(2,1,1)
    plt.hist(muons.ke.values, bins=np.logspace(3,7,100), histtype='step', color='C0', label="Data")
    scale = mc_scale(len(muons.ke.values),len(muons_mc.ke.values))
    plt.hist(muons_mc.ke.values, weights=np.tile(scale,len(muons_mc.ke.values)), bins=np.logspace(3,7,100), histtype='step', color='C1', label="Monte Carlo")
    plt.legend()
    plt.xlabel("Energy (MeV)")
    plt.gca().set_xscale("log")
    plt.subplot(2,1,2)
    plt.hist(np.cos(muons.theta.values), bins=np.linspace(-1,1,100), histtype='step', color='C0', label="Data")
    scale = mc_scale(len(muons.theta.values),len(muons_mc.theta.values))
    plt.hist(np.cos(muons_mc.theta.values), weights=np.tile(scale,len(muons_mc.theta.values)), bins=np.linspace(-1,1,100), histtype='step', color='C1', label="Monte Carlo")
    plt.legend()
    despine(fig,trim=True)
    plt.xlabel(r"$\cos(\theta)$")
    plt.tight_layout()
    if args.save:
        plt.savefig("muon_energy_cos_theta.pdf")
        plt.savefig("muon_energy_cos_theta.eps")
    else:
        plt.suptitle("Muons")

    # Stopping muons are muons with a matching Michel electron
    stopping_muons = pd.merge(ev[ev.muon & ev.stopping_muon],michel,left_on=['run','gtid'],right_on=['run','muon_gtid'],suffixes=('','_michel'))
    stopping_muons_mc = pd.merge(ev_mc[ev_mc.muon & ev_mc.stopping_muon],michel_mc,left_on=['run','gtid'],right_on=['run','muon_gtid'],suffixes=('','_michel'))

    stopping_muons = stopping_muons[stopping_muons.ke < 10e3]
    stopping_muons_mc = stopping_muons_mc[stopping_muons_mc.ke < 10e3]

    stopping_muons = stopping_muons[stopping_muons.cos_theta < -0.5]
    stopping_muons_mc = stopping_muons_mc[stopping_muons_mc.cos_theta < -0.5]

    # project muon to PSUP
    stopping_muons['dx'] = stopping_muons.apply(get_dx,axis=1)
    stopping_muons_mc['dx'] = stopping_muons_mc.apply(get_dx,axis=1)
    # energy based on distance travelled
    stopping_muons['T_dx'] = dx_to_energy(stopping_muons.dx)
    stopping_muons_mc['T_dx'] = dx_to_energy(stopping_muons_mc.dx)
    stopping_muons['dT'] = stopping_muons['ke'] - stopping_muons['T_dx']
    stopping_muons_mc['dT'] = stopping_muons_mc['ke'] - stopping_muons_mc['T_dx']

    # Plot the energy and angular distribution for stopping muons
    fig = plt.figure()
    plt.subplot(2,1,1)
    plt.hist(stopping_muons.ke.values, bins=np.logspace(3,7,100), histtype='step', color='C0', label="Data")
    scale = mc_scale(len(stopping_muons.ke.values),len(stopping_muons_mc.ke.values))
    plt.hist(stopping_muons_mc.ke.values, weights=np.tile(scale,len(stopping_muons_mc.ke.values)), bins=np.logspace(3,7,100), histtype='step', color='C1', label="Monte Carlo")
    plt.legend()
    plt.xlabel("Energy (MeV)")
    plt.gca().set_xscale("log")
    plt.subplot(2,1,2)
    plt.hist(np.cos(stopping_muons.theta.values), bins=np.linspace(-1,1,100), histtype='step', color='C0', label="Data")
    scale = mc_scale(len(stopping_muons.theta.values),len(stopping_muons_mc.theta.values))
    plt.hist(np.cos(stopping_muons_mc.theta.values), weights=np.tile(scale,len(stopping_muons_mc.theta.values)), bins=np.linspace(-1,1,100), histtype='step', color='C1', label="Monte Carlo")
    plt.legend()
    despine(fig,trim=True)
    plt.xlabel(r"$\cos(\theta)$")
    plt.tight_layout()
    if args.save:
        plt.savefig("stopping_muon_energy_cos_theta.pdf")
        plt.savefig("stopping_muon_energy_cos_theta.eps")
    else:
        plt.suptitle("Stopping Muons")

    print(stopping_muons[['run','gtid','ke','T_dx','dT','gtid_michel','r_michel','ftp_r_michel','id','r']])

    print("Particle ID probability for Stopping Muons:")
    print("Data")
    print_particle_probs(stopping_muons)
    print("Monte Carlo")
    print_particle_probs(stopping_muons_mc)

    fig = plt.figure()
    plot_hist2_data_mc(stopping_muons,stopping_muons_mc)
    despine(fig,trim=True)
    # Guard on the sample actually being plotted in this figure.
    if len(stopping_muons):
        plt.tight_layout()
    fig.legend(handles,labels,loc='upper right')
    if args.save:
        plt.savefig("stopping_muons.pdf")
        plt.savefig("stopping_muons.eps")
    else:
        plt.suptitle("Stopping Muons")

    fig = plt.figure()
    plt.hist((stopping_muons['ke']-stopping_muons['T_dx'])*100/stopping_muons['T_dx'], bins=np.linspace(-100,100,200), histtype='step', color='C0', label="Data")
    plt.hist((stopping_muons_mc['ke']-stopping_muons_mc['T_dx'])*100/stopping_muons_mc['T_dx'], bins=np.linspace(-100,100,200), histtype='step', color='C1', label="Monte Carlo")
    plt.legend()
    despine(fig,trim=True)
    plt.xlabel(r"Fractional energy difference (\%)")
    plt.title("Fractional energy difference for Stopping Muons")
    plt.tight_layout()
    if args.save:
        plt.savefig("stopping_muon_fractional_energy_difference.pdf")
        plt.savefig("stopping_muon_fractional_energy_difference.eps")
    else:
        plt.title("Stopping Muon Fractional Energy Difference")

    # 9 bins between 50 MeV and 2 GeV in the energy estimated from the
    # distance travelled (T_dx)
    bins = np.linspace(50,2000,10)

    pd_bins = pd.cut(stopping_muons['T_dx'],bins)
    pd_bins_mc = pd.cut(stopping_muons_mc['T_dx'],bins)

    T = (bins[1:] + bins[:-1])/2

    # observed=False keeps empty bins (as NaN rows) so the results always have
    # one row per bin and line up with T.
    dT = stopping_muons.groupby(pd_bins, observed=False)['dT'].agg(['mean','sem','std',std_err,median,median_err,iqr_std,iqr_std_err])
    dT_mc = stopping_muons_mc.groupby(pd_bins_mc, observed=False)['dT'].agg(['mean','sem','std',std_err,median,median_err,iqr_std,iqr_std_err])

    y = (dT['median']*100/T-dT_mc['median']*100/T).values
    yerr = np.sqrt((dT['median_err']*100/T)**2+(dT_mc['median_err']*100/T)**2).values
    mean, std = fit_straight_line(y,yerr)

    print("Data energy bias = %.2f +/- %.2f" % (mean, std))

    fig, (a0, a1) = plt.subplots(2, 1, gridspec_kw={'height_ratios': [3, 1]})
    a0.errorbar(T, dT['median']*100/T, yerr=dT['median_err']*100/T, fmt='o', color='C0', label="Data")
    a0.errorbar(T, dT_mc['median']*100/T, yerr=dT_mc['median_err']*100/T, fmt='o', color='C1', label="Monte Carlo")
    despine(ax=a0,trim=True)
    a0.set_ylabel(r"Energy bias (\%)")
    a0.legend()
    a1.errorbar(T, dT['median']*100/T-dT_mc['median']*100/T, yerr=np.sqrt((dT['median_err']*100/T)**2+(dT_mc['median_err']*100/T)**2), fmt='o', color='C0')
    a1.hlines(mean,T[0],T[-1],linestyles='--',color='r')
    a1.set_ylim(0,25)
    despine(ax=a1,trim=True)
    a1.set_xlabel("Kinetic Energy (MeV)")
    a1.set_ylabel(r"Difference (\%)")
    plt.tight_layout()
    if args.save:
        plt.savefig("stopping_muon_energy_bias.pdf")
        plt.savefig("stopping_muon_energy_bias.eps")
    else:
        plt.suptitle("Stopping Muon Energy Bias")

    y = (dT['iqr_std']*100/T-dT_mc['iqr_std']*100/T).values
    yerr = np.sqrt((dT['iqr_std_err']*100/T)**2+(dT_mc['iqr_std_err']*100/T)**2).values
    mean, std = fit_straight_line(y,yerr)

    print("Data energy resolution = %.2f +/- %.2f" % (mean, std))

    fig, (a0, a1) = plt.subplots(2, 1, gridspec_kw={'height_ratios': [3, 1]})
    a0.errorbar(T, dT['iqr_std']*100/T, yerr=dT['iqr_std_err']*100/T, fmt='o', color='C0', label="Data")
    a0.errorbar(T, dT_mc['iqr_std']*100/T, yerr=dT_mc['iqr_std_err']*100/T, fmt='o', color='C1', label="Monte Carlo")
    a0.set_ylabel(r"Energy resolution (\%)")
    despine(ax=a0,trim=True)
    a0.legend()
    a1.errorbar(T, dT['iqr_std']*100/T-dT_mc['iqr_std']*100/T, yerr=np.sqrt((dT['iqr_std_err']*100/T)**2+(dT_mc['iqr_std_err']*100/T)**2), fmt='o', color='C0')
    a1.hlines(mean,T[0],T[-1],linestyles='--',color='r')
    despine(ax=a1,trim=True)
    a1.set_xlabel("Kinetic Energy (MeV)")
    a1.set_ylabel(r"Difference (\%)")
    plt.tight_layout()
    if args.save:
        plt.savefig("stopping_muon_energy_resolution.pdf")
        plt.savefig("stopping_muon_energy_resolution.eps")
    else:
        # suptitle (like the energy bias figure) so the title goes on the
        # figure rather than on the lower difference panel.
        plt.suptitle("Stopping Muon Energy Resolution")

    # For the Michel energy plot, we only look at events where the
    # corresponding muon had less than 2500 nhit.  The reason for only looking
    # at Michel electrons from muons with less than 2500 nhit is because there
    # is significant ringing and afterpulsing after a large muon which can
    # cause the reconstruction to overestimate the energy.
    michel_low_nhit = michel[michel.muon_gtid.isin(stopping_muons.gtid.values) & (michel.muon_nhit < 2500)]
    michel_low_nhit_mc = michel_mc[michel_mc.muon_gtid.isin(stopping_muons_mc.gtid.values) & (michel_mc.muon_nhit < 2500)]

    print("Particle ID probability for Michel electrons:")
    print("Data")
    print_particle_probs(michel_low_nhit)
    print("Monte Carlo")
    print_particle_probs(michel_low_nhit_mc)

    fig = plt.figure()
    plot_hist2_data_mc(michel_low_nhit,michel_low_nhit_mc)
    despine(fig,trim=True)
    fig.legend(handles,labels,loc='upper right')
    if args.save:
        plt.savefig("michel_electrons.pdf")
        plt.savefig("michel_electrons.eps")
    else:
        plt.suptitle("Michel Electrons")

    fig = plt.figure()
    bins = np.linspace(0,100,41)
    plt.hist(michel_low_nhit.ke.values, bins=bins, histtype='step', color='C0', label="Data")
    plt.hist(michel_low_nhit_mc.ke.values, weights=np.tile(mc_scale(len(michel_low_nhit),len(michel_low_nhit_mc)),len(michel_low_nhit_mc.ke.values)), bins=bins, histtype='step', color='C1', label="Monte Carlo")
    hist = np.histogram(michel_low_nhit.ke.values,bins=bins)[0]
    hist_mc = np.histogram(michel_low_nhit_mc.ke.values,bins=bins)[0]
    if hist_mc.sum() > 0:
        scale = hist.sum()/hist_mc.sum()
    else:
        scale = 1.0
    p = get_multinomial_prob(hist,hist_mc,scale)
    plt.text(0.95,0.95,"p = %.2f" % p,horizontalalignment='right',verticalalignment='top',transform=plt.gca().transAxes)
    despine(fig,trim=True)
    plt.xlabel("Energy (MeV)")
    plt.tight_layout()
    plt.legend()
    if args.save:
        plt.savefig("michel_electrons_ke.pdf")
        plt.savefig("michel_electrons_ke.eps")
    else:
        plt.title("Michel Electrons")
        plt.show()
