#required libraries
import matplotlib.pyplot as plt
import pandas as pd
import os
import numpy as np
from collections import deque
from gwpy.time import tconvert
from datetime import datetime
from gwpy.time import from_gps
import datetime
import time
from astropy.time import Time
import matplotlib.cbook as cbook
import matplotlib.dates as mdates
from matplotlib.dates import DateFormatter
from gwpy.segments import DataQualityFlag
import json
import csv
from datetime import datetime
import decimal
from matplotlib.dates import DateFormatter, AutoDateLocator
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from matplotlib.dates import DateFormatter
import matplotlib.ticker as mticker
import math as m
import pandas as pd
import matplotlib.ticker as ticker
from datetime import datetime, timedelta
from scipy.stats import norm
import math
import argparse
from calibplot.calfetcher import CalFetcher
from lal import LIGOTimeGPS
import ast
from gwpy.timeseries import TimeSeries
from scipy.optimize import curve_fit
from scipy import stats
from scipy.stats import norm





class mango(object):
    """
    Retrieve Pcal x/y ratio data from InfluxDB through ``CalFetcher``.

    Two parallel measurement sets are available, differing only by a name suffix:

    * GDS reconstructed strain: ``pcalymean``, ``pcalxmean``, ``xyratio``
    * DARM error (derr):        ``pcalymean_darmerr``, ``pcalxmean_darmerr``, ``xyratio_darmerr``

    Methods:
        get_gds_data(gps_start_time, gps_end_time)
        get_derr_data(gps_start_time, gps_end_time)
        get_locked_data(filename)
    """

    # Default InfluxDB configuration file handed to CalFetcher (site/user specific path).
    DEFAULT_CONFIG_FILE = '/home/emmanuel.makelele/influxdb/ligo-calibplot/influx_config_LHO.ini'

    def __init__(self, config_file=None):
        """
        Initialize the mango class.

        Args:
            config_file (str, optional): InfluxDB config file passed to ``CalFetcher``.
                Defaults to ``DEFAULT_CONFIG_FILE``.
        """
        self.config_file = self.DEFAULT_CONFIG_FILE if config_file is None else config_file

    def _fetch_merged(self, gps_start_time, gps_end_time, suffix):
        """
        Fetch the y-mean, x-mean and x/y-ratio measurements sharing *suffix* and merge them.

        The y-mean table drives the merge: x-mean and ratio rows are left-joined onto it
        by exact ``time`` match, so times missing from those tables come back as NaN.

        Args:
            gps_start_time (int): The GPS start time for data retrieval.
            gps_end_time (int): The GPS end time for data retrieval.
            suffix (str): Measurement-name suffix ('' for GDS, '_darmerr' for DARM error).

        Returns:
            pd.DataFrame: The merged table, with the coherence flags renamed to
            ``coh_state_x`` and ``coh_state_y``.

        Raises:
            KeyError: If a key is not found in the fetched data.
        """
        pd.set_option('display.float_format', '{:.12f}'.format)
        # NOTE(review): credentials are hard-coded in source; they should come from the
        # caller's environment or a secrets store. Kept as-is to preserve behavior.
        os.environ['INFLUX_USERNAME'] = 'lhocalib'
        os.environ['INFLUX_PASSWORD'] = 'calibrator'
        fetcher = CalFetcher(self.config_file)
        fields_pcalymean = ['time', 'coh_state', 'lock_state', 'strain_chan', 'data']
        fields_pcalxmean = ['time', 'coh_ok', 'data']
        fields_xyratio = ['time', 'data']
        try:
            pcalxmean_data = fetcher.fetch_data(gps_start_time, gps_end_time, ['pcalxmean' + suffix], fields_pcalxmean)
            pcalymean_data = fetcher.fetch_data(gps_start_time, gps_end_time, ['pcalymean' + suffix], fields_pcalymean)
            xyratio_data = fetcher.fetch_data(gps_start_time, gps_end_time, ['xyratio' + suffix], fields_xyratio)
            # Give the two coherence flags distinct names before they meet in the merge.
            pcalxmean_data = pcalxmean_data.rename(columns={'coh_ok': 'coh_state_x'})
            pcalymean_data = pcalymean_data.rename(columns={'coh_state': 'coh_state_y'})
            merged = pcalymean_data.merge(pcalxmean_data, on='time', how='left')
            return merged.merge(xyratio_data, on='time', how='left')
        except KeyError as ke:
            # Report, then propagate. Returning None here (as before) only deferred the
            # failure to an unrelated TypeError in the caller.
            print('KeyError:', ke)
            raise

    def get_gds_data(self, gps_start_time, gps_end_time):
        """
        Fetches GDS reconstructed strain data from InfluxDB.

        This method retrieves data for Pcal x/y means, x/y ratios, lock state,
        and coherence state between the given GPS start and end times.

        Args:
            gps_start_time (int): The GPS start time for data retrieval.
            gps_end_time (int): The GPS end time for data retrieval.

        Returns:
            pd.DataFrame: A DataFrame containing the fetched GDS data.

        Raises:
            KeyError: If a key is not found in the fetched data.
        """
        return self._fetch_merged(gps_start_time, gps_end_time, '')

    def get_derr_data(self, gps_start_time, gps_end_time):
        """
        Fetches DARM error (derr) data from InfluxDB.

        This method retrieves data for Pcal x/y DARM error means and x/y ratios between
        the given GPS start and end times.

        Args:
            gps_start_time (int): The GPS start time for data retrieval.
            gps_end_time (int): The GPS end time for data retrieval.

        Returns:
            pd.DataFrame: A DataFrame containing the fetched DARM error data.

        Raises:
            KeyError: If a key is not found in the fetched data.
        """
        return self._fetch_merged(gps_start_time, gps_end_time, '_darmerr')

    def get_locked_data(self, filename):
        """
        Keep only locked, coherent rows.

        A row survives when ``lock_state`` does not contain 'not' and neither
        ``coh_state_y`` nor ``coh_state_x`` contains 'NOT_OKAY'. Rows with a missing
        value (NaN/None) in any of these columns are dropped.

        Args:
            filename (pd.DataFrame): Merged data, as returned by ``get_gds_data`` or
                ``get_derr_data``. The parameter name is historical; it is a DataFrame.

        Returns:
            pd.DataFrame: The surviving rows, with the original index preserved.
        """
        def _lacks(column, token):
            # na=True makes missing values count as "contains", so they are rejected;
            # astype(bool) guarantees a plain boolean mask before negation.
            return ~filename[column].str.contains(token, regex=False, na=True).astype(bool)

        mask = (_lacks('lock_state', 'not')
                & _lacks('coh_state_y', 'NOT_OKAY')
                & _lacks('coh_state_x', 'NOT_OKAY'))
        return filename[mask]


def save_data(df, filename):
    """
    Write *df* as CSV (without the index) into the current working directory.

    Args:
        df (pd.DataFrame): The table to write.
        filename (str): File name, resolved relative to the current working directory.

    Returns:
        None. A confirmation line with the full destination path is printed.
    """
    cwd = os.getcwd()
    destination = os.path.join(cwd, filename)
    df.to_csv(destination, index=False)
    print("Data saved to {}".format(destination))


# Argument parser to handle GPS start and end times
parser = argparse.ArgumentParser(description='Fetch and filter Pcal x/y ratio, DARM error, and GDS data.')
parser.add_argument('start_gps', type=int, help='The GPS start time.')
parser.add_argument('end_gps', type=int, help='The GPS end time.')

# Parse the arguments; reject an empty or reversed interval up front rather than
# issuing queries that cannot return any data.
args = parser.parse_args()
if args.end_gps <= args.start_gps:
    parser.error('end_gps must be greater than start_gps')

# Initialize the mango fetcher. Deliberately not named ``m``: that name is the
# ``math`` alias imported at the top of the file.
cal = mango()

# Fetch GDS data and DARM error data
df_gds = cal.get_gds_data(args.start_gps, args.end_gps)
df_derr = cal.get_derr_data(args.start_gps, args.end_gps)

# Keep only locked, coherent samples, then drop rows with any missing value.
gds = cal.get_locked_data(df_gds).dropna()
derr = cal.get_locked_data(df_derr).dropna()
# Intermediate dumps (written with the index, unlike the final save_data() output).
gds.to_csv('testgds.csv')
derr.to_csv('testderr.csv')
## subplot of CC_derr and CC_Strain

# x-axis reference: the earliest GPS time present in either dataset.
start_time = min(derr['time'].min(), gds['time'].min())
# Convert GPS time to UTC with gwpy, which applies the leap-second count in force at
# that epoch instead of assuming a fixed 18 s offset.
start_time_utc = from_gps(float(start_time))

# Days elapsed since start_time for every sample
seconds_per_day = 24 * 3600
days_since_start_derr = (derr['time'] - start_time) / seconds_per_day
days_since_start_gds = (gds['time'] - start_time) / seconds_per_day

# Median, mean and (pandas, ddof=1) standard deviation of each ratio
unfmediand = derr['xyratio_darmerr'].median()
unfmediang = gds['xyratio'].median()
derr_mean = derr['xyratio_darmerr'].mean()
gds_mean = gds['xyratio'].mean()
derr_std = derr['xyratio_darmerr'].std()
gds_std = gds['xyratio'].std()

fig, axs = plt.subplots(2, 1, figsize=(9, 7), dpi=150, sharex=True)  # two stacked panels sharing the x axis

# Top panel: CC from DARM_ERR.
# Backslashes are doubled/raw so the mathtext reaches matplotlib unchanged without
# triggering invalid-escape warnings.
axs[0].scatter(days_since_start_derr, derr['xyratio_darmerr'], marker='o', color='#1f77b4', alpha=0.5,
               label=f'$N$: {len(derr)}\n $\\mu$: {derr_mean:.6f}\n $\\sigma$: {derr_std:.6f}')
axs[0].axhline(unfmediand, color='#ff7f0e', linewidth=2.5, label=f'Median: {unfmediand:.7f}')
axs[0].legend(fontsize=12)
axs[0].set_ylabel(r'$CC_{DARM\_ERR}$', fontsize=18)
axs[0].tick_params(axis='both', which='major', labelsize=16)
# Bottom panel: CC from reconstructed strain
axs[1].scatter(days_since_start_gds, gds['xyratio'], marker='o', color='#2ca02c', alpha=0.5,
               label=f'$N$: {len(gds)}\n $\\mu$: {gds_mean:.6f}\n $\\sigma$: {gds_std:.6f}')
axs[1].axhline(unfmediang, color='#d62728', linewidth=2.5, label=f'Median: {unfmediang:.7f}')
axs[1].legend(fontsize=12)
axs[1].set_xlabel(f'Days since first GPS time ({start_time})\nUTC: {start_time_utc}', fontsize=18)
axs[1].set_ylabel('$CC_{STRAIN}$', fontsize=18)
axs[1].tick_params(axis='both', which='major', labelsize=16)

plt.tight_layout(pad=1.0)  # Adjust padding
plt.savefig('CC_ratio_locked.pdf')
plt.show()


########## Rxy computation

# R_XY = CC_STRAIN / CC_DARM_ERR.
# NOTE(review): the two ratio series are paired by DataFrame row index (inherited from two
# independent fetches), not by timestamp. The offset diagnostic printed below checks that
# pairing; if it is not ~0, a time-based join of gds and derr is needed instead.
Rxy = gds['xyratio'] / derr['xyratio_darmerr']
Rxy_time = (gds['time'] + derr['time']) / 2
# Keep ratio and timestamp aligned: drop pairs where either side is missing (index present
# in only one frame) or the ratio is not finite (division by a zero CC gives +/-inf).
valid = np.isfinite(Rxy) & Rxy_time.notna()
pair_offset = (gds['time'] - derr['time'])[valid].abs()
Rxy = Rxy[valid]
Rxy_time = Rxy_time[valid]
print(len(Rxy))
print('max |t_gds - t_derr| over paired rows [s]:', pair_offset.max())
print(Rxy.describe())
print(gds['xyratio'].describe())
print(derr['xyratio_darmerr'].describe())
Rxy_mean = np.nanmean(Rxy)
Rxy_std = np.nanstd(Rxy)

##### Plot RXY
# Reference time for the x axis: the earliest paired timestamp.
start_time1 = Rxy_time.min()
# GPS -> UTC with gwpy (applies the leap-second count valid at that epoch).
start_time_utc = from_gps(float(start_time1))

# Days elapsed since the reference time
days_since_start_Rxy = (Rxy_time - start_time1) / (24 * 3600)

median = Rxy.median()

fig, axs = plt.subplots(1, 1, figsize=(9, 7), dpi=150)  # single panel

axs.scatter(days_since_start_Rxy, Rxy, marker='o', color='#1f77b4', alpha=0.5,
            label=f'$N$: {len(Rxy)}\n $\\mu$: {Rxy_mean:.6f}\n $\\sigma$: {Rxy_std:.6f}')
axs.axhline(median, color='#ff7f0e', linewidth=2.5, label=f'Median: {median:.7f}')
axs.legend(fontsize=12)
# Label with the same reference time the x values were computed from.
axs.set_xlabel(f'Days since first GPS time ({start_time1})\nUTC: {start_time_utc}', fontsize=18)
axs.set_ylabel('$R_{XY}$', fontsize=18)
axs.tick_params(axis='both', which='major', labelsize=16)

plt.tight_layout(pad=1.0)  # Adjust padding
# No extension: matplotlib appends rcParams['savefig.format'] (png by default).
plt.savefig('Rxy_locked')
plt.show()



#### PLOT RAW ECDF OF DATA:
# ECDF: the i-th smallest value (1-based) has cumulative probability i / N.
sorted_data = np.sort(Rxy)
y_vals = np.arange(1, len(sorted_data) + 1) / len(sorted_data)
plt.figure(figsize=(8, 6))
plt.plot(sorted_data, y_vals, marker='.', linestyle='none')
plt.xlabel('$R_{XY}$')
plt.ylabel('ECDF')
plt.title('Empirical Cumulative Distribution Function (ECDF)')
plt.grid(True)
plt.savefig('ecdf.pdf')
plt.show()

####### Check for exact zero values in Rxy (prints how many there are).
print(int((Rxy == 0).sum()))
     
# Keep only R_XY within +/-0.001 of the mean to trim outliers before fitting.
rxy = Rxy[(Rxy > (Rxy_mean - 0.001)) & (Rxy < (Rxy_mean + 0.001))]
###### filter data until we get a smooth ecdf
res1 = stats.ecdf(rxy)
# ecdf() evaluates the CDF at the *unique* sorted sample values (cdf.quantiles); pair the
# probabilities with those so x and y have equal length even when the sample has ties.
ecdf_x = res1.cdf.quantiles
ecdf_y = res1.cdf.probabilities

##### CDF FIT: least-squares fit of a normal CDF to the empirical CDF.
f = lambda x, mu, sigma: norm(mu, sigma).cdf(x)
# Start from the sample moments: curve_fit's default guess (1, 1) is far from a
# sigma of order 1e-4 and can fail to converge.
mu, sigma = curve_fit(f, ecdf_x, ecdf_y, p0=(np.mean(rxy), np.std(rxy)))[0]
data = f(ecdf_x, mu, sigma)

fig, axs = plt.subplots(1, 1, figsize=(9, 7), dpi=150)
axs.scatter(ecdf_x, ecdf_y, alpha=0.08, label=f'N: {len(rxy)}')
axs.plot(ecdf_x, data, color='r', label=f'$\\mu$:{mu:.5f}\n$\\sigma$:{sigma:.6f}')
#plt.xlim(1.0002,1.00045)
axs.set_xlabel("$R_{XY}$")
axs.set_ylabel("Cumulative probabilities")
axs.legend()
# Save before show(): show() may close the figure, leaving savefig a blank canvas.
plt.savefig('ecdf_fitted.pdf')
plt.show()

#### Histogram of filtered data, on its own figure so it is not drawn onto the ECDF axes.
fig_hist, ax1 = plt.subplots(figsize=(9, 7))
rxy.hist(ax=ax1, bins=121, alpha=0.5, color='blue',
         label=f'$\\mu$: {np.mean(rxy):.5f}\n $\\sigma$: {np.std(rxy):.5f}\n N:{len(rxy)}')
ax1.set_xlabel('$R_{XY}$', fontsize=18)
ax1.set_ylabel('$Frequency$')
ax1.legend()
fig_hist.savefig('Filteredhist.pdf')


# Persist the locked, coherent datasets (index omitted) into the current working directory.
for locked_frame, csv_name in (
    (gds, 'gds_locked_coherent_data.csv'),
    (derr, 'derr_locked_coherent_data.csv'),
):
    save_data(locked_frame, csv_name)
