import pandas as pd
import numpy as np
import lal as lal
import lalsimulation as lalsim
# Physical constants in SI units (used e.g. by apply_modulation).
# NOTE(review): G appears to be the CODATA-1986 value (CODATA 2018 gives
# 6.67430e-11) and MSol is rounded; htilde_of_f converts masses with
# lal.MSUN_SI instead, so different code paths use slightly different constants.
clight = 2.99792458e8                # speed of light, m/s
G = 6.67259e-11                      # Newton's constant, m^3/kg/s^2
MSol = 1.989e30                      # solar mass, kg
### Function Definitions

def _spin_components(magnitude, polar, azimuth):
    """Convert a spin given as (magnitude, polar, azimuth) into Cartesian (x, y, z)."""
    sin_polar = np.sin(polar)
    return (magnitude * sin_polar * np.cos(azimuth),
            magnitude * sin_polar * np.sin(azimuth),
            magnitude * np.cos(polar))


def read_bbh_file(filename, num_lines, rng=None):
    """Read binary-black-hole parameters from a CSV file into per-row dicts.

    Parameters
    ----------
    filename : str or file-like
        Anything accepted by ``pandas.read_csv``. The file must contain the
        columns ``m1``, ``m2``, ``spin mag 1``, ``spin polar 1``,
        ``spin azimuth 1`` (and the same three for body 2), ``distance``,
        ``inclination``, ``chi1L``, ``chi2L``, ``chip``, ``alpha``, ``psi``,
        ``ra``, ``declination``, ``coalesence phase`` (sic) and ``phase``.
    num_lines : int or None
        Maximum number of data rows to read (passed as ``nrows``); ``None``
        reads the whole file. If the file holds fewer rows, only the rows
        actually present are returned.
    rng : object with a ``random()`` method, optional
        Source of the placeholder redshift (e.g. ``np.random.default_rng(0)``).
        Defaults to the global ``np.random`` state, as before.

    Returns
    -------
    dict
        Maps the row position as a string ("0", "1", ...) to a parameter dict.

    Notes
    -----
    The redshift is a placeholder drawn uniformly from [0, 1). The CSV masses
    are treated as source-frame masses (NOTE(review): confirm with whoever
    produces the file); detector-frame masses are ``m_source * (1 + z)``.
    The coalescence phase is stored under ``'coalescence_phase'`` so it no
    longer overwrites the polarization angle ``'psi'``.
    """
    df = pd.read_csv(filename, nrows=num_lines)
    draw_redshift = np.random.random if rng is None else rng.random

    all_param_dicts = {}
    # Iterate over the rows actually read: the file may contain fewer than num_lines.
    for i in range(len(df)):
        row = df.iloc[i]
        redshift = draw_redshift()
        s1x, s1y, s1z = _spin_components(row["spin mag 1"], row["spin polar 1"], row["spin azimuth 1"])
        s2x, s2y, s2z = _spin_components(row["spin mag 2"], row["spin polar 2"], row["spin azimuth 2"])
        params = {
            'redshift': redshift,
            # Cosmological redshift makes the observed (detector-frame) masses larger.
            'm1_det': row["m1"] * (1. + redshift),
            'm2_det': row["m2"] * (1. + redshift),
            'm1_source': row["m1"],
            'm2_source': row["m2"],
            's1x': s1x,
            's1y': s1y,
            's1z': s1z,
            's2x': s2x,
            's2y': s2y,
            's2z': s2z,
            'luminosity_distance': row["distance"],
            'iota': row["inclination"],
            'chi1L': row["chi1L"],
            'chi2L': row["chi2L"],
            'chip': row["chip"],
            'alpha': row["alpha"],
            'psi': row["psi"],
            'ra': row["ra"],
            'dec': row["declination"],
            # Kept under its own key so it does not clobber 'psi'.
            'coalescence_phase': row["coalesence phase"],
            'phiRef': row["phase"],
        }
        all_param_dicts[str(i)] = params
    return all_param_dicts

def read_file(filename, columns):
    """Read a whitespace-delimited numeric table into a dict of numpy arrays.

    The first line of the file is treated as a header and skipped. Each
    remaining non-blank line is split on whitespace, and its first
    ``len(columns)`` fields are parsed as floats; extra fields are ignored.

    Parameters
    ----------
    filename : str
        Path of the text file to read.
    columns : sequence of str
        Names assigned, in order, to the leading fields of every data row.

    Returns
    -------
    dict
        Maps each column name to a 1-D float array (empty if the file has no
        data rows).

    Raises
    ------
    IndexError
        If a data row has fewer fields than ``columns``.
    ValueError
        If a field cannot be parsed as a float.
    """
    values = {name: [] for name in columns}
    with open(filename, "r") as handle:
        next(handle, None)  # skip the header line
        for line in handle:
            fields = line.split()
            if not fields:
                continue  # tolerate blank lines, e.g. a trailing empty line
            for i, name in enumerate(columns):
                values[name].append(float(fields[i]))
    return {name: np.array(vals) for name, vals in values.items()}

def htilde_of_f(fmin, fmax, deltaF, m1, m2, chi1L, chi2L, chip, thetaJ, alpha, dist, phi_c,
                fref=20., version=1):
    """Generate frequency-domain IMRPhenomP plus and cross polarizations.

    Parameters
    ----------
    fmin, fmax, deltaF : float
        Frequency band and resolution, passed straight to ``lalsim.SimIMRPhenomP``.
    m1, m2 : float
        Component masses in solar masses (converted with ``lal.MSUN_SI``).
    chi1L, chi2L, chip, thetaJ, alpha : float
        PhenomP spin/orientation arguments, passed through unchanged.
    dist : float
        Distance in Mpc (converted to metres with ``lal.PC_SI``).
    phi_c : float
        Passed to ``SimIMRPhenomP`` as its phase argument and also used in an
        extra phase factor applied afterwards (see the note in the body).
    fref : float, optional
        Reference frequency passed to ``SimIMRPhenomP`` (default 20., the
        previously hard-coded value).
    version : int, optional
        Integer passed to ``SimIMRPhenomP`` after ``fref`` (default 1, the
        previously hard-coded value). NOTE(review): presumably the
        IMRPhenomP version selector; confirm against the installed
        LALSimulation signature.

    Returns
    -------
    hplus, hcross : numpy.ndarray
        Complex data arrays of the two frequency series.
    """
    dist_si = 1e6 * lal.PC_SI * dist   # Mpc -> m
    m1_si = m1 * lal.MSUN_SI           # M_sun -> kg
    m2_si = m2 * lal.MSUN_SI

    H = lalsim.SimIMRPhenomP(chi1L, chi2L, chip, thetaJ, m1_si, m2_si, dist_si, alpha, phi_c,
                             deltaF, fmin, fmax, fref, version, None)

    # NOTE(review): phi_c is already passed to SimIMRPhenomP above; this factor
    # applies an additional phase of pi*phi_c (not phi_c or 2*pi*phi_c). Kept
    # to preserve existing outputs; confirm it is intended.
    phase_factor = np.exp(1j * np.pi * phi_c)
    hplus = H[0].data.data * phase_factor
    hcross = H[1].data.data * phase_factor

    return hplus, hcross

def compute_waveform_at_detector_L(fmin, fmax, deltaF, m1, m2, chi1L, chi2L, chip, thetaJ, alpha, dist, phi_c, ra, dec,
                                   psi=2., epoch=1167559934.62):
    '''
    Project an IMRPhenomP waveform onto the LIGO Livingston detector (frequency domain).

    Reference: http://software.ligo.org/docs/lalsuite/lalsimulation/_l_a_l_sim_i_m_r_phenom_p_8c.html

    m1, m2: in units of M_sol
    dist: in units of MPC
    chi1L: Component of dimensionless spin 1 along Lhat; s1z for now
    chi2L: Component of dimensionless spin 2 along Lhat; s2z for now
    chip: Effective spin in the orbital plane; psi for now
    thetaJ: Angle between J0 and line of sight (z-direction); iota for now
    alpha: Initial value of alpha angle (azimuthal precession angle); 0 for now
    phi_c: Orbital phase at the peak of the underlying non precessing model (rad); phiRef for now
    ra, dec: sky position passed to lal.TimeDelayFromEarthCenter and
        lal.ComputeDetAMResponse (presumably radians; confirm)
    psi: angle passed as the polarization argument of lal.ComputeDetAMResponse
        (default 2., the previously hard-coded value)
    epoch: GPS time used for the detector time delay and Greenwich sidereal time.
        NOTE(review): this default was previously labelled "for GW150914", but
        GW150914 is at GPS ~1126259462; this value lies ~2 s before GW170104
        (GPS ~1167559936). Confirm which event is intended.

    Returns the complex detector strain F+ * hp + Fx * hc, phase-shifted by the
    Earth-centre-to-Livingston time delay plus 1 s.
    '''
    hp, hc = htilde_of_f(fmin, fmax, deltaF, m1, m2, chi1L, chi2L, chip, thetaJ, alpha, dist, phi_c)

    # NOTE(review): assumes htilde_of_f returns int(fmax/deltaF)+1 bins starting
    # at 0 Hz; otherwise the broadcast with fseries below fails.
    fseries = np.linspace(0, fmax, int(fmax / deltaF) + 1)

    epoch_GPS = lal.LIGOTimeGPS(epoch)
    gmst = lal.GreenwichMeanSiderealTime(epoch_GPS)
    IFO_cached = lal.CachedDetectors[lal.LALDetectorIndexLLODIFF]
    timedelay_L = lal.TimeDelayFromEarthCenter(IFO_cached.location, ra, dec, epoch_GPS)
    F_plus_L, F_cross_L = lal.ComputeDetAMResponse(IFO_cached.response, ra, dec, psi, gmst)
    htilde_L = F_plus_L * hp + F_cross_L * hc
    # The extra 1 s shift is intended to avoid wrap-around of the signal in the time domain.
    htilde_L *= np.exp(1j * np.pi * 2 * fseries * (timedelay_L + 1))
    return htilde_L


def apply_modulation(waveform_fft,freqs,mtot,lamda_amp,lamda_freq):
    '''
    Multiply a frequency-domain waveform by an exponential modulation factor.

    The factor is exp((lamda_amp + 1j*lamda_freq) * x), where
    x = (2*pi*G*mtot*MSol*f)**(2/3) / clight**2 is dimensionless for f in Hz.
    lamda_amp therefore rescales the amplitude and lamda_freq adds a phase,
    both growing with frequency.

    Parameters
    ----------
    waveform_fft : array_like (complex)
        Frequency-domain waveform samples, one per entry of freqs.
    freqs : array_like
        Frequencies in Hz; lists and scalars are accepted as well as arrays.
    mtot : float
        Total mass in solar masses (converted to kg with the module constant MSol).
    lamda_amp, lamda_freq : float
        Real and imaginary coefficients of the exponent; both zero leaves the
        waveform unchanged.

    Returns
    -------
    numpy.ndarray
        waveform_fft * exp((lamda_amp + 1j*lamda_freq) * x).

    NOTE(review): the exponent uses 2*pi*f. The usual post-Newtonian parameter
    for a gravitational-wave frequency f is (pi*G*M*f)**(2/3)/c**2, so confirm
    whether f is meant as the orbital or the GW frequency.
    '''
    freqs = np.asarray(freqs, dtype=float)
    exp_arg = ((2. * G * mtot * MSol * np.pi * freqs) ** (2. / 3)) / (clight ** 2)
    mult_factor = np.exp((lamda_amp + 1.j * lamda_freq) * exp_arg)
    return waveform_fft * mult_factor


def generate_mod_waveform_fft_at_detector(fmin,fmax,deltaF,m1,m2,chi1L,chi2L,chip,thetaJ,alpha,dist,phi_c,
                                          lamda_amp,lamda_freq,ra,dec,psi=2.,epoch=1167559934.62):
    '''
    Build an IMRPhenomP waveform, apply the exponential modulation, then project
    it onto the LIGO Livingston detector (frequency domain).

    Waveform arguments (fmin ... phi_c) are passed unchanged to htilde_of_f
    (masses in M_sol, dist in Mpc). lamda_amp and lamda_freq are passed to
    apply_modulation with the total mass m1 + m2.

    ra, dec: sky position passed to lal.TimeDelayFromEarthCenter and
        lal.ComputeDetAMResponse (presumably radians; confirm)
    psi: angle passed as the polarization argument of lal.ComputeDetAMResponse
        (default 2., the previously hard-coded value)
    epoch: GPS time used for the detector time delay and Greenwich sidereal time.
        NOTE(review): this default lies ~2 s before GW170104 (GPS ~1167559936),
        not GW150914 (GPS ~1126259462); confirm which event is intended.

    Returns the complex modulated detector strain, phase-shifted by the
    Earth-centre-to-Livingston time delay plus 1 s.
    '''
    # Raw polarizations at the source.
    hp, hc = htilde_of_f(fmin, fmax, deltaF, m1, m2, chi1L, chi2L, chip, thetaJ, alpha, dist, phi_c)

    # NOTE(review): assumes htilde_of_f returns int(fmax/deltaF)+1 bins starting at 0 Hz.
    fseries = np.linspace(0, fmax, int(fmax / deltaF) + 1)

    # Apply the same modulation to both polarizations.
    mtot = m1 + m2
    hp_mod = apply_modulation(hp, fseries, mtot, lamda_amp, lamda_freq)
    hc_mod = apply_modulation(hc, fseries, mtot, lamda_amp, lamda_freq)

    # Detector response at Livingston.
    epoch_GPS = lal.LIGOTimeGPS(epoch)
    gmst = lal.GreenwichMeanSiderealTime(epoch_GPS)
    IFO_cached = lal.CachedDetectors[lal.LALDetectorIndexLLODIFF]
    timedelay_L = lal.TimeDelayFromEarthCenter(IFO_cached.location, ra, dec, epoch_GPS)
    F_plus_L, F_cross_L = lal.ComputeDetAMResponse(IFO_cached.response, ra, dec, psi, gmst)
    htilde_mod_L = F_plus_L * hp_mod + F_cross_L * hc_mod
    # The extra 1 s shift is intended to avoid wrap-around of the signal in the time domain.
    htilde_mod_L *= np.exp(1j * np.pi * 2 * fseries * (timedelay_L + 1))

    return htilde_mod_L

def shift_time(time,waveform,deltaF):
    '''
    Shift a time axis so that t=0 falls where the waveform's instantaneous
    frequency is largest.

    The instantaneous frequency is estimated from the numerical gradient of
    the unwrapped phase of the complex waveform. This is not the amplitude
    peak in general.

    Parameters
    ----------
    time : numpy.ndarray
        Sample times, same length as waveform.
    waveform : numpy.ndarray (complex)
        Complex time series whose phase is analysed.
    deltaF : float
        Frequency resolution; deltaF * waveform.size is used as the sampling rate.

    Returns
    -------
    numpy.ndarray
        time - time[i], where i is the first index of maximal instantaneous frequency.
    '''
    sample_rate = deltaF * waveform.size
    phase = np.absolute(np.unwrap(np.angle(waveform)))
    f_gw = np.gradient(phase) * sample_rate / (2. * np.pi)
    # np.argmax returns the first maximum, like the previous where(...)[0][0],
    # but without a Python-level max() over the whole array.
    return time - time[np.argmax(f_gw)]



