#!/usr/bin/env python3
# Copyright (c) 2019, Anthony Latorre <tlatorre at uchicago>
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the Free
# Software Foundation, either version 3 of the License, or (at your option)
# any later version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
# more details.
#
# You should have received a copy of the GNU General Public License along with
# this program. If not, see <https://www.gnu.org/licenses/>.
"""
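Script to do final dark matter search analysis. To run it just run:

    $ ./dm-search [list of data fit results] --mc [list of atmospheric MC files] --muon-mc [list of muon MC files] --weights [list of GENIE reweight HDF5 files] --mcpl [list of GENIE MCPL files] --run-info [run_info.log file] --steps [steps]

The --mc, --muon-mc, --weights, --mcpl and --run-info options are required.

After running you will get a plot showing the limits for back to back dark
matter at a range of energies.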
"""
from __future__ import print_function, division
import numpy as np
from scipy.stats import iqr, poisson
from matplotlib.lines import Line2D
from scipy.stats import iqr, norm, beta, percentileofscore
from scipy.special import spence
from sddm.stats import *
from sddm.dc import estimate_errors, EPSILON, truncnorm_scaled
import emcee
from sddm import printoptions
from sddm.utils import fast_cdf, correct_energy_bias
from scipy.integrate import quad
from sddm.dm import *
from sddm import SNOMAN_MASS, AV_RADIUS
import nlopt
from itertools import chain

# Likelihood Fit Parameters
# 0 - Atmospheric Neutrino Flux Scale
# 1 - Electron energy bias
# 2 - Electron energy resolution
# 3 - Muon energy bias
# 4 - Muon energy resolution
# 5 - External Muon scale
# 6 - Dark Matter Scale
#
# The numbers above are the indices into the fit parameter vector `x` used by
# the likelihood, and into FIT_PARS and the PRIORS* lists below.

# Number of events to use in the dark matter Monte Carlo histogram when fitting
# Ideally this would be as big as possible, but the bigger it is, the more time
# the fit takes.
DM_SAMPLES = 10000

# Dark matter masses to scan, keyed by the particle id combo of the two decay
# products (2020 and 2222, i.e. two particles of id 20 or two of id 22; see
# particle_id below). Each entry is a log-spaced grid of 101 masses ending at
# 1e3 (presumably MeV, to match the histogram energy axis -- TODO confirm).
DM_MASSES = {2020: np.logspace(np.log10(22),np.log10(1e3),101),
             2222: np.logspace(np.log10(318),np.log10(1e3),101)}

# p-value used for the discovery threshold. get_limits() divides this by the
# number of histogram bins to account for the look elsewhere effect.
DISCOVERY_P_VALUE = 0.05

# Human readable names of the fit parameters, in parameter vector order. These
# are used as the titles of the pull plots.
FIT_PARS = [
    'Atmospheric Neutrino Flux Scale',
    'Electron energy bias',
    'Electron energy resolution',
    'Muon energy bias',
    'Muon energy resolution',
    'External Muon scale',
    'Dark Matter Scale']

# Uncertainty on the energy scale
#
# - the muon energy scale and resolution terms come directly from measurements
#   on stopping muons, so those are known well.
# - for electrons, we only have Michel electrons at the low end of our energy
#   range, and therefore we don't really have any good way of constraining the
#   energy scale or resolution. However, if we assume that the ~7% energy bias
#   in the muons is from the single PE distribution (it seems likely to me that
#   that is a major part of the bias), then the energy scale should be roughly
#   the same. Since the Michel electron distributions are consistent, we leave
#   the mean value at 0, but to be conservative, we set the error to 10%.
# - The energy resolution for muons was pretty much spot on, and so we expect
#   the same from electrons. In addition, the Michel spectrum is consistent so
#   at that energy level we don't see anything which leads us to expect a major
#   difference. To be conservative, and because I don't think it really affects
#   the analysis at all, I'll leave the uncertainty here at 10% anyways.
#
# NOTE(review): the discussion above quotes a mean of 0 and 10% uncertainties
# for the electron terms, but the values below use an electron energy bias of
# 0.015 +/- 0.03 and an electron resolution uncertainty of 0.05 -- confirm
# which one is current and update the other.

# Central values of the Gaussian priors on the first six fit parameters. The
# dark matter scale entry is not used by the prior term in the likelihood.
PRIORS = [
    1.0,   # Atmospheric Neutrino Scale
    0.015, # Electron energy scale
    0.0,   # Electron energy resolution
    0.053, # Muon energy scale
    0.0,   # Muon energy resolution
    0.0,   # Muon scale
    0.0,   # Dark Matter Scale
]

# Standard deviations of the Gaussian priors, in the same order as PRIORS.
PRIOR_UNCERTAINTIES = [
    0.2,   # Atmospheric Neutrino Scale
    0.03,  # Electron energy scale
    0.05,  # Electron energy resolution
    0.01,  # Muon energy scale
    0.013, # Muon energy resolution
    10.0,  # Muon scale
    np.inf,# Dark Matter Scale
]

# Lower bounds for the fit parameters
# (the likelihood returns infinity below these, and they are also used as the
# nlopt lower bounds and the lower edges when sampling the priors)
PRIORS_LOW = [
    EPSILON,
    -10,
    EPSILON,
    -10,
    EPSILON,
    0,
    0,
]

# Upper bounds for the fit parameters
# (the likelihood returns infinity above these, and they are also used as the
# nlopt upper bounds and the upper edges when sampling the priors)
PRIORS_HIGH = [
    10,
    10,
    10,
    10,
    10,
    1e9,
    1000,
]

# Maps a single particle id to the LaTeX symbol used in the plot titles.
particle_id = {20: 'e', 22: r'\mu'}

def plot_hist2(hists, bins, color=None):
    """
    Draws one step histogram per particle id combo on the current figure.

    Args:
        - hists: dictionary mapping particle id combo -> bin contents
        - bins: dictionary mapping particle id combo -> bin edges
        - color: matplotlib color passed to plt.hist()

    The histograms are laid out on a 2x3 grid: the single particle ids (20,
    22) go on the top row and the two particle ids (2020, 2022, 2222) on the
    bottom row. The x axis is logarithmic and fixed to 10-10000.
    """
    # (particle id combo, position in the 2x3 subplot grid)
    layout = ((20,1),(22,2),(2020,4),(2022,5),(2222,6))

    # Candidate tick positions: one major tick per decade and minor ticks at
    # the integer multiples in between.
    decades = np.array([10,100,1000,10000])
    in_between = np.unique(list(chain(*list(range(i,i*10,i) for i in decades[:-1]))))
    in_between = np.setdiff1d(in_between,decades)

    for pid, slot in layout:
        plt.subplot(2,3,slot)

        edges = bins[pid]
        centers = (edges[1:] + edges[:-1])/2
        plt.hist(centers, bins=edges, histtype='step', weights=hists[pid],color=color)

        ax = plt.gca()
        ax.set_xscale("log")
        # Only draw ticks up to the last bin edge of this histogram
        last_edge = edges[-1]
        ax.set_xticks(decades[decades <= last_edge])
        ax.set_xticks(in_between[in_between <= last_edge],minor=True)
        ax.set_xlim(10,10000)
        plt.xlabel("Energy (MeV)")

        # Title is the LaTeX symbol of each particle in the combo, e.g. 2022
        # -> e\mu
        symbols = [particle_id[int(''.join(pair))] for pair in grouper(str(pid),2)]
        plt.title('$' + ''.join(symbols) + '$')

    if len(hists):
        plt.tight_layout()

def get_mc_hists(data,x,bins,scale=1.0,reweight=False):
    """
    Returns the expected Monte Carlo histograms for a single dataframe of
    simulated events (e.g. the atmospheric neutrino background).

    Args:
        - data: pandas dataframe of the Monte Carlo events
        - x: fit parameters
        - bins: dictionary mapping particle id combo -> histogram bins
        - scale: multiply histograms by an overall scale factor
        - reweight: forwarded to get_mc_hists_fast(); if True each event is
                    multiplied by its `weight` column

    The events are first split by particle id combo and then handed to
    get_mc_hists_fast() which applies the energy bias and resolution
    corrections and histograms the results.

    Returns a dictionary mapping particle id combo -> histogram.
    """
    events_by_id = {pid: data[data.id == pid] for pid in (20,22,2020,2022,2222)}
    return get_mc_hists_fast(events_by_id,x,bins,scale,reweight)

def get_mc_hists_fast(df_dict,x,bins,scale=1.0,reweight=False):
    """
    Same as get_mc_hists() but the first argument is a dictionary mapping
    particle id -> dataframe. This is much faster than selecting the events
    from the dataframe every time.

    Args:
        - df_dict: dictionary mapping particle id combo -> dataframe of events
        - x: fit parameters
        - bins: dictionary mapping particle id combo -> histogram bins
        - scale: multiply histograms by an overall scale factor
        - reweight: if True, multiply each event by its `weight` column

    If the dataframe has a `flux_weight` column each event is also multiplied
    by it.

    Returns a dictionary mapping particle id combo -> histogram.
    """
    # Indices into x of the (energy bias, energy resolution) parameters
    electron = (1,2)
    muon = (3,4)

    # Which particles make up each particle id combo
    composition = (
        (20,(electron,)),
        (22,(muon,)),
        (2020,(electron,electron)),
        (2022,(electron,muon)),
        (2222,(muon,muon)),
    )

    histograms = {}

    for pid, particles in composition:
        events = df_dict[pid]

        bias1, res1 = particles[0]
        e1 = events.energy1.values

        if len(particles) == 1:
            ke = e1*(1+x[bias1])
            resolution = e1*max(EPSILON,x[res1])
        else:
            bias2, res2 = particles[1]
            e2 = events.energy2.values
            ke = e1*(1+x[bias1]) + e2*(1+x[bias2])
            # Add the two particles' resolutions in quadrature
            resolution = np.sqrt((e1*max(EPSILON,x[res1]))**2 + (e2*max(EPSILON,x[res2]))**2)

        # Evaluate each event's cumulative distribution at every bin edge
        cdf = fast_cdf(bins[pid][:,np.newaxis],ke,resolution)

        if reweight:
            cdf = cdf*events.weight.values

        if 'flux_weight' in events.columns:
            cdf *= events.flux_weight.values

        # The content of a bin is the difference of the cdf between its two
        # edges, summed over all events
        contents = np.sum(cdf[1:] - cdf[:-1],axis=-1)
        contents *= scale
        histograms[pid] = contents

    return histograms

def get_data_hists(data,bins,scale=1.0):
    """
    Returns the data histogrammed into `bins`.

    Args:
        - data: pandas dataframe with `id` and `ke` columns
        - bins: dictionary mapping particle id combo -> histogram bins
        - scale: multiply histograms by an overall scale factor

    Returns a dictionary mapping particle id combo -> histogram.
    """
    def histogram(pid):
        energies = data[data.id == pid].ke.values
        counts, _ = np.histogram(energies,bins=bins[pid])
        return counts*scale

    return {pid: histogram(pid) for pid in (20,22,2020,2022,2222)}

def make_nll(dm_particle_id, dm_mass, dm_energy, data, muons, mc, atmo_scale_factor, muon_scale_factor, bins, reweight=False, print_nll=False, dm_sample=None, fast=False):
    """
    Builds the negative log likelihood function for the dark matter fit.

    Args:
        - dm_particle_id: particle id combo of the dark matter decay products
        - dm_mass, dm_energy: passed to get_dm_sample() when no `dm_sample`
          is given
        - data: pandas dataframe representing the data to fit
        - muons: pandas dataframe with the external muon events
        - mc: pandas dataframe with the atmospheric Monte Carlo events
        - atmo_scale_factor: the atmospheric histograms are divided by this
        - muon_scale_factor: the muon histograms are divided by this
        - bins: dictionary mapping particle id combo -> histogram bins
        - reweight: if True, weight the atmospheric Monte Carlo events by
          their `weight` column
        - print_nll: if True, print the value every time the function is
          evaluated
        - dm_sample: dataframe of dark matter events; generated with
          get_dm_sample() if None
        - fast: if True, the expected histograms are computed once at the
          prior central values instead of on every call, i.e. the energy bias
          and resolution parameters have no effect on the histograms

    Returns a function nll(x, grad=None) which returns the negative log
    likelihood for the fit parameters `x`, or infinity if any parameter is
    outside of [PRIORS_LOW, PRIORS_HIGH]. The `grad` argument is not used; it
    is there so the function can be handed to nlopt as an objective.
    """
    particle_ids = (20,22,2020,2022,2222)

    def split_by_id(frame):
        # Group once up front, and make sure every particle id combo has an
        # entry even if it has no events.
        groups = dict(tuple(frame.groupby('id')))
        for pid in particle_ids:
            if pid not in groups:
                groups[pid] = frame.iloc[:0]
        return groups

    atmo_by_id = split_by_id(mc)
    muon_by_id = split_by_id(muons)

    observed = get_data_hists(data,bins)

    if dm_sample is None:
        dm_sample = get_dm_sample(DM_SAMPLES,dm_particle_id,dm_mass,dm_energy)

    dm_by_id = {pid: dm_sample[dm_sample.id == pid] for pid in particle_ids}

    def expected_hists(pars):
        # Returns the (atmospheric, muon, dark matter) histograms for the fit
        # parameters `pars`, each normalized by its own scale factor.
        return (
            get_mc_hists_fast(atmo_by_id,pars,bins,scale=1/atmo_scale_factor,reweight=reweight),
            get_mc_hists_fast(muon_by_id,pars,bins,scale=1/muon_scale_factor),
            get_mc_hists_fast(dm_by_id,pars,bins,scale=1/len(dm_sample)),
        )

    # In fast mode the histograms are computed once at the prior central
    # values and reused on every call.
    frozen_hists = expected_hists(np.array(PRIORS)) if fast else None

    def nll(x, grad=None):
        if (x < PRIORS_LOW).any() or (x > PRIORS_HIGH).any():
            return np.inf

        # Get the Monte Carlo histograms. Unless we are in fast mode we need
        # to do this within the likelihood function since we apply the energy
        # resolution parameters to the Monte Carlo.
        mc_hists, muon_hists, dm_hists = frozen_hists if fast else expected_hists(x)

        # Calculate the negative log of the likelihood of observing the data
        # given the fit parameters
        total = 0
        for pid, oi in observed.items():
            ei = mc_hists[pid]*x[0] + muon_hists[pid]*x[5] + dm_hists[pid]*x[6] + EPSILON
            N = ei.sum()
            total -= -N - np.sum(gammaln(oi+1)) + np.sum(oi*np.log(ei))

        # Add the priors
        total -= norm.logpdf(x[:6],PRIORS[:6],PRIOR_UNCERTAINTIES[:6]).sum()

        if print_nll:
            # Print the result
            print("nll = %.2f" % total)

        return total

    return nll

def _merge_genie_weights(data_mc, weights, universe):
    """
    Returns `data_mc` merged with the GENIE weights for `universe`. Events
    which don't have a weight are given a weight of 1.
    """
    merged = pd.merge(data_mc,weights[universe],how='left',on=['run','unique_id'])
    merged.weight = merged.weight.fillna(1.0)
    return merged

def _mcmc_then_minimize(nll, steps, walkers, thin, fix_dm_scale=False):
    """
    Samples `nll` with an MCMC started from the priors and then polishes the
    best sample with nlopt.

    If `fix_dm_scale` is True the dark matter scale is pinned to its lower
    bound during the nlopt minimization (the MCMC is unaffected).

    Returns a tuple (xopt, samples) where samples is the flattened and thinned
    MCMC chain.
    """
    pos = np.empty((walkers, len(PRIORS)),dtype=np.double)
    for i in range(pos.shape[0]):
        pos[i] = sample_priors()

    nwalkers, ndim = pos.shape

    # We use the KDEMove here because I think it should sample the likelihood
    # better. Because we have energy scale parameters and we are doing a binned
    # likelihood, the likelihood is discontinuous. There can also be several
    # local minima. The author of emcee recommends using the KDEMove with a lot
    # of workers to try and properly sample a multimodal distribution. In
    # addition, I've found that the autocorrelation time for the KDEMove is
    # much better than the other moves.
    sampler = emcee.EnsembleSampler(nwalkers, ndim, lambda x: -nll(x), moves=emcee.moves.KDEMove())
    with np.errstate(invalid='ignore'):
        sampler.run_mcmc(pos, steps)

    print("Mean acceptance fraction: {0:.3f}".format(np.mean(sampler.acceptance_fraction)))

    # The autocorrelation time is purely diagnostic, so a failure to compute
    # it is reported but must not abort the fit.
    try:
        print("autocorrelation time: ", sampler.get_autocorr_time(quiet=True))
    except Exception as e:
        print(e)

    samples = sampler.get_chain(flat=True,thin=thin)

    # Now, we use nlopt to find the best set of parameters. We start at the
    # best starting point from the MCMC and then run the SBPLX routine.
    x0 = sampler.get_chain(flat=True)[sampler.get_log_prob(flat=True).argmax()]
    opt = nlopt.opt(nlopt.LN_SBPLX, len(x0))
    opt.set_min_objective(nll)
    low = np.array(PRIORS_LOW)
    high = np.array(PRIORS_HIGH)
    if fix_dm_scale:
        x0[6] = low[6]
        high[6] = low[6]
    opt.set_lower_bounds(low)
    opt.set_upper_bounds(high)
    opt.set_ftol_abs(1e-10)
    opt.set_initial_step([0.01]*len(x0))
    xopt = opt.optimize(x0)

    return xopt, samples

def do_fit(dm_particle_id,dm_mass,dm_energy,data,muon,data_mc,weights,atmo_scale_factor,muon_scale_factor,bins,steps,print_nll=False,walkers=100,thin=10,refit=True,universe=None,fast=False):
    """
    Run the fit and return the minimum along with samples from running an MCMC
    starting near the minimum.

    Args:
        - dm_particle_id: particle id combo of the dark matter decay products
        - dm_mass, dm_energy: passed to get_dm_sample() to generate the dark
                              matter Monte Carlo
        - data: pandas dataframe representing the data to fit
        - muon: pandas dataframe representing the expected background from
                external muons
        - data_mc: pandas dataframe representing the expected background from
                   atmospheric neutrino events
        - weights: dictionary mapping universe -> pandas dataframe with the
                   GENIE weights
        - atmo_scale_factor: the atmospheric histograms are divided by this
        - muon_scale_factor: the muon histograms are divided by this
        - bins: dictionary mapping particle id combo -> histogram bins
        - steps: the number of MCMC steps to run
        - print_nll: print the likelihood every time it is evaluated
        - walkers: number of MCMC walkers
        - thin: only keep every `thin`-th MCMC step in the returned samples
        - refit: if True, refit with the Monte Carlo weighted by the GENIE
                 weights of the selected universe
        - universe: GENIE universe to use. If None, a first fit is done with
                    the unweighted Monte Carlo and the universe with the
                    smallest negative log likelihood at that minimum is used.
        - fast: passed to make_nll()

    Returns a tuple (xopt, universe, samples) where samples is the flattened
    and thinned MCMC chain from the last fit, i.e. one row per retained sample
    and one column per fit parameter.

    Raises a ValueError if a universe is given and refit is False, since in
    that case no fit would be run at all.
    """
    if universe is not None and not refit:
        raise ValueError("do_fit() called with a fixed universe and refit=False: there is nothing to fit")

    dm_sample = get_dm_sample(DM_SAMPLES,dm_particle_id,dm_mass,dm_energy)

    if universe is None:
        # Note: print_nll has to be passed by keyword here. Passing it
        # positionally puts it into the `reweight` slot of make_nll().
        nll = make_nll(dm_particle_id,dm_mass,dm_energy,data,muon,data_mc,atmo_scale_factor,muon_scale_factor,bins,print_nll=print_nll,dm_sample=dm_sample,fast=fast)

        # If we are refitting, we want to do the first fit assuming no dark
        # matter to make sure we get the best GENIE systematics for the null
        # hypothesis.
        xopt, samples = _mcmc_then_minimize(nll,steps,walkers,thin,fix_dm_scale=refit)

        # Get the total number of "universes" simulated in the GENIE reweight tool
        nuniverses = max(weights.keys())+1

        nlls = []
        for candidate in range(nuniverses):
            data_mc_with_weights = _merge_genie_weights(data_mc,weights,candidate)

            candidate_nll = make_nll(dm_particle_id,dm_mass,dm_energy,data,muon,data_mc_with_weights,atmo_scale_factor,muon_scale_factor,bins,reweight=True,print_nll=print_nll,dm_sample=dm_sample,fast=fast)
            nlls.append(candidate_nll(xopt))

        universe = np.argmin(nlls)

    if refit:
        data_mc_with_weights = _merge_genie_weights(data_mc,weights,universe)

        # Create a new negative log likelihood function with the weighted Monte Carlo.
        nll = make_nll(dm_particle_id,dm_mass,dm_energy,data,muon,data_mc_with_weights,atmo_scale_factor,muon_scale_factor,bins,reweight=True,print_nll=print_nll,dm_sample=dm_sample,fast=fast)

        # Now, we refit with the Monte Carlo weighted by the most likely GENIE
        # systematics.
        xopt, samples = _mcmc_then_minimize(nll,steps,walkers,thin)

    return xopt, universe, samples

def sample_priors():
    """
    Draws one random set of fit parameters from the priors.

    The first 6 parameters are drawn with truncnorm_scaled() (a truncated
    normal distribution) using the prior central values, uncertainties and
    bounds. The last parameter, the dark matter scale, is drawn uniformly
    between its lower and upper bound.

    Returns an array with one entry per fit parameter.
    """
    constrained = truncnorm_scaled(PRIORS_LOW[:6],PRIORS_HIGH[:6],PRIORS[:6],PRIOR_UNCERTAINTIES[:6])
    dm_scale = np.random.uniform(PRIORS_LOW[6],PRIORS_HIGH[6])
    return np.concatenate((constrained,[dm_scale]))

def get_dm_sample(n,dm_particle_id,dm_mass,dm_energy):
    """
    Returns a dataframe containing events from a dark matter particle.

    Args:

        - n: int
            number of events
        - dm_particle_id: int
            the particle id of the DM particle (2020 or 2222)
        - dm_mass: float
            The mass of the DM particle, passed to gen_decay()
        - dm_energy: float
            The energy of the DM particle, passed to gen_decay()

    The returned dataframe has the columns energy1, energy2, ke, id1, id2 and
    id. energy1 and energy2 are the kinetic energies of the two decay products
    smeared by a 5% Gaussian energy resolution. ke is the sum of the two
    kinetic energies before smearing.
    """
    id1 = dm_particle_id//100
    id2 = dm_particle_id % 100
    m1 = SNOMAN_MASS[id1]
    m2 = SNOMAN_MASS[id2]
    # Note: we use the builtin int for the id columns. np.int was just an
    # alias for it and was removed in numpy 1.24.
    data = np.empty(n,dtype=[('energy1',np.double),('energy2',np.double),('ke',np.double),('id1',int),('id2',int),('id',int)])
    for i, (v1, v2) in enumerate(islice(gen_decay(dm_mass,dm_energy,m1,m2),n)):
        # The first component of each vector is the total energy, so we
        # subtract the mass to get the kinetic energy
        E1 = v1[0]
        E2 = v2[0]
        T1 = E1 - m1
        T2 = E2 - m2
        data[i] = T1, T2, T1 + T2, id1, id2, dm_particle_id

    # FIXME: Get electron and muon resolution
    data['energy1'] += norm.rvs(scale=data['energy1']*0.05)
    data['energy2'] += norm.rvs(scale=data['energy2']*0.05)

    return pd.DataFrame(data)

def get_limits(dm_masses,data,muon,data_mc,atmo_scale_factor,muon_scale_factor,bins,steps,print_nll,walkers,thin,universe=None,fast=False,weights=None):
    """
    Fits the data for every dark matter mass and returns the limits, best fit
    values, and discovery thresholds.

    Args:
        - dm_masses: dictionary mapping dark matter particle id (2020 or
                     2222) -> array of dark matter masses
        - data: pandas dataframe representing the data to fit
        - muon: pandas dataframe representing the expected background from
                external muons
        - data_mc: pandas dataframe representing the expected background from
                   atmospheric neutrino events
        - atmo_scale_factor: the atmospheric histograms are divided by this
        - muon_scale_factor: the muon histograms are divided by this
        - bins: dictionary mapping particle id combo -> histogram bins
        - steps, print_nll, walkers, thin, fast: passed to do_fit()
        - universe: GENIE universe to use. If None, do_fit() picks one during
                    the first fit and it is reused for all following fits.
        - weights: dictionary mapping universe -> pandas dataframe with the
                   GENIE weights. If None, the module level `weights`
                   variable set up by the main script is used.

    Returns a tuple (limits, best_fit, discovery_array). Each one is a
    dictionary mapping dark matter particle id -> array with one entry per
    mass in `dm_masses`:

        - limits: 90th percentile of the dark matter scale MCMC samples
        - best_fit: best fit dark matter scale
        - discovery_array: dark matter scale above which we would claim a
                           discovery
    """
    if weights is None:
        # For backwards compatibility fall back to the global weights which
        # is what this function always used before it had a weights argument.
        weights = globals().get('weights')
        if weights is None:
            raise NameError("get_limits() needs the GENIE weights: pass `weights` or define a module level `weights`")

    limits = {}
    best_fit = {}
    discovery_array = {}
    for dm_particle_id in (2020,2222):
        limits[dm_particle_id] = np.empty(len(dm_masses[dm_particle_id]))
        best_fit[dm_particle_id] = np.empty(len(dm_masses[dm_particle_id]))
        discovery_array[dm_particle_id] = np.empty(len(dm_masses[dm_particle_id]))
        for i, dm_mass in enumerate(dm_masses[dm_particle_id]):
            dm_energy = dm_mass
            xopt, universe, samples = do_fit(dm_particle_id,dm_mass,dm_energy,data,muon,data_mc,weights,atmo_scale_factor,muon_scale_factor,bins,steps,print_nll,walkers,thin,refit=True,universe=universe,fast=fast)

            data_mc_with_weights = pd.merge(data_mc,weights[universe],how='left',on=['run','unique_id'])
            data_mc_with_weights.weight = data_mc_with_weights.weight.fillna(1.0)

            limit = np.percentile(samples[:,6],90)
            limits[dm_particle_id][i] = limit

            # Here, to determine if there is a discovery we make an approximate
            # calculation of the number of events which would be significant.
            #
            # We expect the likelihood to be approximately that of a Poisson
            # distribution with n background events and we are searching for a
            # signal s. n is constrained by the rest of the histograms, and so
            # we can treat it as being approximately fixed. In this case, the
            # likelihood looks approximately like:
            #
            #     P(s) = e^(-(s+n))(s+n)**i/i!
            #
            # Where i is the actual number of events. Under the null hypothesis
            # (i.e. no dark matter), we expect i to be Poisson distributed with
            # mean n. Therefore s should have the same distribution but offset
            # by n. Therefore, to determine the threshold, we simply look for
            # the threshold we expect in n and then subtract n.
            dm_sample = get_dm_sample(DM_SAMPLES,dm_particle_id,dm_mass,dm_energy)

            # To calculate `n` we approximately want the number of events in
            # the bin which most of the dark matter events will fall. However,
            # to smoothly transition between bins, we multiply the normalized
            # dark matter histogram with the expected MC histogram and then
            # take the sum. In the case that the dark matter events all fall
            # into a single bin, this gives us that bin, but smoothly
            # interpolates between the bins.
            dm_hists = get_mc_hists(dm_sample,xopt,bins,scale=1/len(dm_sample))
            frac = dm_hists[dm_particle_id].sum()
            dm_hists[dm_particle_id] /= frac
            mc_hists = get_mc_hists(data_mc_with_weights,xopt,bins,scale=xopt[0]/atmo_scale_factor,reweight=True)
            muon_hists = get_mc_hists(muon,xopt,bins,scale=xopt[5]/muon_scale_factor)
            n = (dm_hists[dm_particle_id]*(mc_hists[dm_particle_id] + muon_hists[dm_particle_id])).sum()
            # Set our discovery threshold to the p-value we want divided by the
            # number of bins. The idea here is that the number of bins is
            # approximately equal to the number of trials so we need to
            # increase our discovery threshold to account for the look
            # elsewhere effect.
            threshold = DISCOVERY_P_VALUE/(len(bins[dm_particle_id])-1)
            discovery = poisson.ppf(1-threshold,n) + 1 - n
            # Here, we scale the discovery threshold by the fraction of the
            # dark matter hist in the histogram range. The idea is that if only
            # a small fraction of the dark matter histogram falls into the
            # histogram range, the total number of dark matter events returned
            # by the fit can be larger by this amount. I noticed this when
            # testing under the null hypothesis that the majority of the
            # "discoveries" were on the edge of the histogram.
            discovery_array[dm_particle_id][i] = discovery/frac
            best_fit[dm_particle_id][i] = xopt[6]

    return limits, best_fit, discovery_array

if __name__ == '__main__':
    import argparse
    import numpy as np
    import pandas as pd
    import sys
    import h5py
    from sddm.plot_energy import *
    from sddm.plot import *
    from sddm import setup_matplotlib
    import nlopt
    from sddm.renormalize import *

    parser = argparse.ArgumentParser("plot fit results")
    parser.add_argument("filenames", nargs='+', help="input files")
    parser.add_argument("--save", action='store_true', default=False, help="save corner plots for backgrounds")
    parser.add_argument("--mc", nargs='+', required=True, help="atmospheric MC files")
    parser.add_argument("--muon-mc", nargs='+', required=True, help="muon MC files")
    parser.add_argument("--steps", type=int, default=1000, help="number of steps in the MCMC chain")
    parser.add_argument("--pull", type=int, default=0, help="plot pull plots")
    parser.add_argument("--weights", nargs='+', required=True, help="GENIE reweight HDF5 files")
    parser.add_argument("--print-nll", action='store_true', default=False, help="print nll values")
    parser.add_argument("--walkers", type=int, default=100, help="number of walkers")
    parser.add_argument("--thin", type=int, default=10, help="number of steps to thin")
    parser.add_argument("--test", type=int, default=0, help="run tests to check discovery threshold")
    parser.add_argument("--run-list", default=None, help="run list")
    parser.add_argument("--mcpl", nargs='+', required=True, help="GENIE MCPL files")
    parser.add_argument("--run-info", required=True, help="run_info.log autosno file")
    parser.add_argument("--universe", type=int, default=None, help="GENIE universe for systematics")
    parser.add_argument("--fast", action='store_true', default=False, help="run fast version of likelihood without energy bias and resolution")
    args = parser.parse_args()

    setup_matplotlib(args.save)

    import matplotlib.pyplot as plt

    rhdr = pd.concat([read_hdf(filename, "rhdr").assign(filename=filename) for filename in args.filenames],ignore_index=True)

    if args.run_list is not None:
        run_list = np.genfromtxt(args.run_list)
        rhdr = rhdr[rhdr.run.isin(run_list)]

    # Loop over runs to prevent using too much memory
    evs = []
    for run, df in rhdr.groupby('run'):
        evs.append(get_events(df.filename.values, merge_fits=True))
    ev = pd.concat(evs).reset_index()

    livetime = 0.0
    livetime_pulse_gt = 0.0
    for _ev in evs:
        if not np.isnan(_ev.attrs['time_10_mhz']):
            livetime += _ev.attrs['time_10_mhz']
        else:
            livetime += _ev.attrs['time_pulse_gt']
        livetime_pulse_gt += _ev.attrs['time_pulse_gt']

    print("livetime            = %.2f" % livetime)
    print("livetime (pulse gt) = %.2f" % livetime_pulse_gt)

    if args.run_info:
        livetime_run_info = 0.0
        run_info = np.genfromtxt(args.run_info,usecols=range(4),dtype=(np.int,np.int,np.double,np.double))
        for run in set(ev.run.values):
            for i in range(run_info.shape[0]):
                if run_info[i][0] == run:
                    livetime_run_info += run_info[i][3]
        print("livetime (run info) = %.2f" % livetime_run_info)

    ev = correct_energy_bias(ev)

    # Note: We loop over the MC filenames here instead of just passing the
    # whole list to get_events() because I had to rerun some of the MC events
    # using SNOMAN and so most of the runs actually have two different files
    # and otherwise the GTIDs will clash
    ev_mcs = []
    for filename in args.mc:
        ev_mcs.append(get_events([filename], merge_fits=True, mc=True))
    ev_mc = pd.concat([ev_mc for ev_mc in ev_mcs if len(ev_mc) > 0]).reset_index()

    if (~rhdr.run.isin(ev_mc.run)).any():
        print_warning("Error! The following runs have no Monte Carlo: %s" % \
            np.unique(rhdr.run[~rhdr.run.isin(ev_mc.run)].values))

    muon_mc = get_events(args.muon_mc, merge_fits=True, mc=True)
    weights = pd.concat([read_hdf(filename, "weights") for filename in args.weights],ignore_index=True)

    # Add the "flux_weight" column to the ev_mc data since I stupidly simulated
    # the muon neutrino flux for the tau neutrino flux in GENIE. Doh!
    mcpl = load_mcpl_files(args.mcpl)
    ev_mc = renormalize_data(ev_mc,mcpl)

    # Merge weights with MCPL dataframe to get the unique id column in the
    # weights dataframe since that is what we use to merge with the Monte
    # Carlo.
    weights = pd.merge(weights,mcpl[['run','evn','unique_id']],on=['run','evn'],how='left')

    # There are a handful of weights which turn out to be slightly negative for
    # some reason. For example:
    #
    # run  evn  universe    weight
    # 10970   25       597 -0.000055
    # 11389   87       729 -0.021397
    # 11701  204         2 -0.000268
    # 11919  120        82 -0.002245
    # 11976  163        48 -0.000306
    # 11976  163       710 -0.000022
    # 12131   76       175 -0.000513
    # 12207   70       255 -0.002925
    # 12207   70       282 -0.014856
    # 12207   70       368 -0.030593
    # 12207   70       453 -0.019011
    # 12207   70       520 -0.020748
    # 12207   70       834 -0.028754
    # 12207   70       942 -0.020309
    # 12233  230       567 -0.000143
    # 12618  168       235 -0.000020
    # 13428  128        42 -0.083639
    # 14264   23       995 -0.017637
    # 15034   69       624 -0.000143
    # 15752  154       957 -0.006827
    weights = weights[weights.weight > 0]

    weights = dict(tuple(weights.groupby('universe')))

    ev_mc = correct_energy_bias(ev_mc)
    muon_mc = correct_energy_bias(muon_mc)

    # Set all prompt events in the MC to be muons
    muon_mc.loc[muon_mc.prompt,'muon'] = True

    # 00-orphan cut
    ev = ev[(ev.gtid & 0xff) != 0]
    ev_mc = ev_mc[(ev_mc.gtid & 0xff) != 0]
    muon_mc = muon_mc[(muon_mc.gtid & 0xff) != 0]

    # remove events 200 microseconds after a muon
    ev = ev.groupby('run',group_keys=False).apply(muon_follower_cut)

    # Get rid of events which don't have a successful fit
    ev = ev[~np.isnan(ev.fmin)]
    ev_mc = ev_mc[~np.isnan(ev_mc.fmin)]
    muon_mc = muon_mc[~np.isnan(muon_mc.fmin)]

    # require (r < av radius)
    ev = ev[ev.r < AV_RADIUS]
    ev_mc = ev_mc[ev_mc.r < AV_RADIUS]
    muon_mc = muon_mc[muon_mc.r < AV_RADIUS]

    fiducial_volume = (4/3)*np.pi*(AV_RADIUS)**3

    # require psi < 6
    ev = ev[ev.psi < 6]
    ev_mc = ev_mc[ev_mc.psi < 6]
    muon_mc = muon_mc[muon_mc.psi < 6]

    data = ev[ev.signal & ev.prompt & ~ev.atm]
    data_atm = ev[ev.signal & ev.prompt & ev.atm]

    # Right now we use the muon Monte Carlo in the fit. If you want to use the
    # actual data, you can comment the next two lines and then uncomment the
    # two after that.
    muon = muon_mc[muon_mc.muon & muon_mc.prompt & ~muon_mc.atm]
    muon_atm = muon_mc[muon_mc.muon & muon_mc.prompt & muon_mc.atm]
    #muon = ev[ev.muon & ev.prompt & ~ev.atm]
    #muon_atm = ev[ev.muon & ev.prompt & ev.atm]

    if not args.pull and not args.test:
        ev_mc = ev_mc[ev_mc.run.isin(rhdr.run)]

    data_mc = ev_mc[ev_mc.signal & ev_mc.prompt & ~ev_mc.atm]
    data_atm_mc = ev_mc[ev_mc.signal & ev_mc.prompt & ev_mc.atm]

    bins = {20:np.logspace(np.log10(20),np.log10(10e3),21),
            22:np.logspace(np.log10(20),np.log10(10e3),21)[:-5],
            2020:np.logspace(np.log10(20),np.log10(10e3),21),
            2022:np.logspace(np.log10(20),np.log10(10e3),21)[:-5],
            2222:np.logspace(np.log10(20),np.log10(10e3),21)[:-5]}

    atmo_scale_factor = 100.0
    muon_scale_factor = len(muon) + len(muon_atm)

    if args.pull:
        # Calibration check on toy experiments: repeatedly draw "true"
        # parameters from the priors, build a fake data set by sampling the
        # Monte Carlo, fit it, and record the percentile at which the true
        # value falls among the posterior samples. The percentiles are then
        # histogrammed per parameter and the script exits.

        def _smear_energies(events, pars):
            """Apply the energy bias and resolution in `pars` to `events` in place.

            Rings with id 20 are scaled using pars[1] (bias) and pars[2]
            (resolution); rings with id 22 use pars[3] and pars[4]. Only the
            first two rings (energy1, energy2) are smeared. Afterwards the
            'ke' column is recomputed as the sum of the three ring energies,
            with missing energies counted as zero.

            The random draws happen in the order id1/20, id1/22, id2/20,
            id2/22 so that results with a fixed seed do not depend on how
            this code is factored.
            """
            for id_col, energy_col in (('id1','energy1'),('id2','energy2')):
                for particle, bias, resolution in ((20,pars[1],pars[2]),(22,pars[3],pars[4])):
                    mask = events[id_col] == particle
                    events.loc[mask,energy_col] *= (1+bias+np.random.randn(np.count_nonzero(mask))*resolution)
            events['ke'] = events['energy1'].fillna(0) + events['energy2'].fillna(0) + events['energy3'].fillna(0)

        # One list of percentiles per fit parameter
        pull = [[] for _ in FIT_PARS]

        # Set the random seed so we get reproducible results here
        np.random.seed(0)

        for _ in range(args.pull):
            xtrue = sample_priors()

            # Calculate expected number of events
            N = data_mc.flux_weight.sum()*xtrue[0]/atmo_scale_factor
            N_atm = data_atm_mc.flux_weight.sum()*xtrue[0]/atmo_scale_factor
            N_muon = len(muon)*xtrue[5]/muon_scale_factor
            N_muon_atm = len(muon_atm)*xtrue[5]/muon_scale_factor
            N_dm = xtrue[6]

            # Calculate observed number of events
            n = np.random.poisson(N)
            n_atm = np.random.poisson(N_atm)
            n_muon = np.random.poisson(N_muon)
            n_muon_atm = np.random.poisson(N_muon_atm)
            # NOTE(review): n_dm is drawn but no dark matter events are added
            # to the toy data below, even though xtrue[6] is compared against
            # the fitted dark matter scale. Confirm whether dark matter events
            # should be injected here. The draw is kept so that the random
            # number stream (and hence the seeded results) is unchanged.
            n_dm = np.random.poisson(N_dm)

            dm_particle_id = np.random.choice([2020,2222])
            dm_mass = np.random.uniform(20,10e3)
            dm_energy = dm_mass

            # Sample data from Monte Carlo
            data = pd.concat((data_mc.sample(n=n,weights='flux_weight',replace=True), muon.sample(n=n_muon,replace=True)))
            data_atm = pd.concat((data_atm_mc.sample(n=n_atm,weights='flux_weight',replace=True), muon_atm.sample(n=n_muon_atm,replace=True)))

            # Smear the energies by the additional energy resolution
            _smear_energies(data, xtrue)
            _smear_energies(data_atm, xtrue)

            xopt, universe, samples = do_fit(dm_particle_id,dm_mass,dm_energy,data,muon,data_mc,weights,atmo_scale_factor,muon_scale_factor,bins,args.steps,args.print_nll,args.walkers,args.thin,refit=False)

            for par, percentiles in enumerate(pull):
                # The "pull plots" we make here are actually produced via a
                # procedure called "Simulation Based Calibration".
                #
                # See https://arxiv.org/abs/1804.06788.
                percentiles.append(percentileofscore(samples[:,par],xtrue[par]))

        # Ten equal-width percentile bins. Kept under its own name so the
        # histogram does not overwrite the energy `bins` used by the fit.
        percentile_edges = np.linspace(0,100,11)

        fig = plt.figure()
        axes = []
        for row, name in enumerate(FIT_PARS):
            axes.append(plt.subplot(4,2,row+1))
            plt.hist(pull[row],bins=percentile_edges,histtype='step')
            # For a calibrated fit the percentiles are uniform, so every bin
            # expects the same number of entries.
            expected = len(pull[row])/(len(percentile_edges)-1)
            plt.axhline(expected,color='k',ls='--',alpha=0.25)
            # Shade the central 99% Poisson interval around the expectation
            plt.axhspan(poisson.ppf(0.005,expected), poisson.ppf(0.995,expected), facecolor='0.5', alpha=0.25)
            plt.title(name)
        for ax in axes:
            despine(ax=ax,left=True,trim=True)
            ax.get_yaxis().set_visible(False)
        plt.tight_layout()

        if args.save:
            fig.savefig("dm_search_pull_plot.pdf")
            fig.savefig("dm_search_pull_plot.eps")
        else:
            plt.show()

        sys.exit(0)

    if args.test:
        # Discovery-threshold check on toy experiments: generate args.test
        # background-only fake data sets from the Monte Carlo, run the full
        # limit calculation on each, and count how often the best fit exceeds
        # the discovery threshold. The observed rate is printed next to
        # DISCOVERY_P_VALUE and the script exits.

        def _smear_energies(events, pars):
            """Apply the energy bias and resolution in `pars` to `events` in place.

            Rings with id 20 are scaled using pars[1] (bias) and pars[2]
            (resolution); rings with id 22 use pars[3] and pars[4]. Only the
            first two rings (energy1, energy2) are smeared. Afterwards the
            'ke' column is recomputed as the sum of the three ring energies,
            with missing energies counted as zero.

            The random draws happen in the order id1/20, id1/22, id2/20,
            id2/22 so that results with a fixed seed do not depend on how
            this code is factored.
            """
            for id_col, energy_col in (('id1','energy1'),('id2','energy2')):
                for particle, bias, resolution in ((20,pars[1],pars[2]),(22,pars[3],pars[4])):
                    mask = events[id_col] == particle
                    events.loc[mask,energy_col] *= (1+bias+np.random.randn(np.count_nonzero(mask))*resolution)
            events['ke'] = events['energy1'].fillna(0) + events['energy2'].fillna(0) + events['energy3'].fillna(0)

        # Set the random seed so we get reproducible results here
        np.random.seed(0)

        # Attach the first set of weights to the Monte Carlo and fold in the
        # flux weight, giving a single per-event sampling weight.
        data_mc_with_weights = pd.merge(data_mc,weights[0],how='left',on=['run','unique_id'])
        data_atm_mc_with_weights = pd.merge(data_atm_mc,weights[0],how='left',on=['run','unique_id'])

        data_mc_with_weights.weight *= data_mc_with_weights.flux_weight
        data_atm_mc_with_weights.weight *= data_atm_mc_with_weights.flux_weight

        discoveries = 0

        for _ in range(args.test):
            xtrue = sample_priors()

            # Calculate expected number of events
            N = data_mc.flux_weight.sum()*xtrue[0]/atmo_scale_factor
            N_atm = data_atm_mc.flux_weight.sum()*xtrue[0]/atmo_scale_factor
            N_muon = len(muon)*xtrue[5]/muon_scale_factor
            N_muon_atm = len(muon_atm)*xtrue[5]/muon_scale_factor

            # Calculate observed number of events
            n = np.random.poisson(N)
            n_atm = np.random.poisson(N_atm)
            n_muon = np.random.poisson(N_muon)
            n_muon_atm = np.random.poisson(N_muon_atm)

            # Sample data from Monte Carlo
            data = pd.concat((data_mc_with_weights.sample(n=n,replace=True,weights='weight'), muon.sample(n=n_muon,replace=True)))
            data_atm = pd.concat((data_atm_mc_with_weights.sample(n=n_atm,replace=True,weights='weight'), muon_atm.sample(n=n_muon_atm,replace=True)))

            # Smear the energies by the additional energy resolution
            _smear_energies(data, xtrue)
            _smear_energies(data_atm, xtrue)

            # Pass the same positional arguments as the main limit calculation
            # at the end of the script. Previously args.universe was left out
            # here, which shifted args.fast into the preceding positional
            # slot.
            # NOTE(review): get_limits' signature is not visible from this
            # part of the file -- confirm that args.universe is the intended
            # value for toy experiments.
            limits, best_fit, discovery_array = get_limits(DM_MASSES,data,muon,data_mc,atmo_scale_factor,muon_scale_factor,bins,args.steps,args.print_nll,args.walkers,args.thin,args.universe,args.fast)

            # A channel counts as a discovery if the best fit exceeds the
            # threshold at any mass point; each channel is counted separately.
            for dm_particle_id in (2020,2222):
                if (best_fit[dm_particle_id] > discovery_array[dm_particle_id]).any():
                    discoveries += 1

        print("expected %.2f discoveries" % DISCOVERY_P_VALUE)
        print("actually got %i/%i = %.2f discoveries" % (discoveries,args.test,discoveries/args.test))

        sys.exit(0)

    # Run the limit calculation on the real data and draw two figures: the
    # event rate limit per channel, and the best fit compared with the
    # discovery threshold.
    limits, best_fit, discovery_array = get_limits(DM_MASSES,data,muon,data_mc,atmo_scale_factor,muon_scale_factor,bins,args.steps,args.print_nll,args.walkers,args.thin,args.universe,args.fast)

    def _channel_label(dm_particle_id):
        """Return the legend label for a dark matter channel id.

        The id is split into two-digit groups, each group is looked up in
        `particle_id`, and the results are joined inside a math-mode string.
        """
        return '$' + ''.join([particle_id[int(''.join(x))] for x in grouper(str(dm_particle_id),2)]) + '$'

    # (line colour, channel id) for every channel that gets plotted
    channels = tuple(zip(('C0','C1'),(2020,2222)))

    fig = plt.figure()
    for color, dm_particle_id in channels:
        # Scale the limit by 100**3 and by the number of seconds in a year
        # and divide by fiducial volume and livetime.
        # NOTE(review): presumably fiducial_volume is in cm^3 and livetime in
        # seconds, giving events/m^3/year -- confirm where they are defined.
        plt.plot(DM_MASSES[dm_particle_id],np.array(limits[dm_particle_id])*100**3*3600*24*365/fiducial_volume/livetime,color=color,label=_channel_label(dm_particle_id))
    plt.gca().set_xscale("log")
    despine(fig,trim=True)
    plt.xlabel("Energy (MeV)")
    # Raw string: "\m" is not a valid escape sequence in a normal string
    plt.ylabel(r"Event Rate Limit (events/$\mathrm{m}^3$/year)")
    plt.legend()
    plt.tight_layout()

    if args.save:
        plt.savefig("dm_search_limit.pdf")
        plt.savefig("dm_search_limit.eps")
    else:
        # The figure is displayed by the plt.show() call at the very end
        plt.suptitle("Dark Matter Limits")

    fig = plt.figure()
    for color, dm_particle_id in channels:
        # Solid line: best fit. Dashed line: discovery threshold.
        plt.plot(DM_MASSES[dm_particle_id],best_fit[dm_particle_id],color=color,label=_channel_label(dm_particle_id))
        plt.plot(DM_MASSES[dm_particle_id],discovery_array[dm_particle_id],color=color,ls='--')
    plt.gca().set_xscale("log")
    despine(fig,trim=True)
    plt.xlabel("Energy (MeV)")
    plt.ylabel("Event Rate Limit (events)")
    plt.legend()
    plt.tight_layout()

    if args.save:
        plt.savefig("dm_best_fit_with_discovery_threshold.pdf")
        plt.savefig("dm_best_fit_with_discovery_threshold.eps")
    else:
        plt.suptitle("Best Fit Dark Matter")
        plt.show()
