import numpy as np
import random
from scipy.stats import rv_continuous
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit


'''
Spin magnitude:
spins: gaussian at .7, width .1
'''
'''
num_events=6000
spin_mean=.7
spin_std_dev=.1
y_range = 1
grad=300

xbins = np.linspace(0,y_range,grad*(y_range))

class Spin_mag_pdf(rv_continuous):
    def _pdf(self, x):
        return (1.0/(2*np.pi*spin_std_dev**2)**.5)*np.exp(-(x-spin_mean)**2/(2*spin_std_dev**2))

#random distribution
spin_mag_pdf = Spin_mag_pdf(name="Spin mag distribution", a=0,b=1)
x=np.array([np.random.rand(num_events)])
y_rand=spin_mag_pdf.ppf(x)[0]

#actual distribution
y_act= (1.0/(2*np.pi*spin_std_dev**2)**.5)*np.exp(-(xbins-spin_mean)**2/(2*spin_std_dev**2))

#FIGURE 1 - LINEAR X AXIS
fig1 = plt.figure(1)
ax1 = fig1.add_subplot(111)
plt.hist(y_rand,normed=True,bins=xbins,label="probability")
plt.plot(xbins,y_act,label="expected function")
fig1.suptitle("Expected Spin Magnitude Distribution for BH's")
plt.ylabel("p")
plt.xlabel("spin")

# curve fitting
n, bins = np.histogram(y_rand,normed=True,bins=xbins)
bin_centers = bins[:-1] + 0.5 * (bins[1:] - bins[:-1])
print len(bin_centers)
print len(n)

def f(x, mean, std_dev,a):
    return a*np.exp(-(x-mean)**2/(2.*std_dev**2))
popt, pcov = curve_fit(f, bin_centers, n, p0 = [.5, .01,1])
print popt

text = "gaussian, mean="+str(popt[0])+",std_dev="+str(popt[1])+" vs. mean=.7,std_dev=.1"

plt.figure(1)
plt.plot(bin_centers,f(bin_centers, *popt),label="fit function")
ax1.set_title(text,fontsize=9)
plt.legend(loc='upper left')

plt.show()
'''

'''
m_1,m_2
First generate the total mass M = m_1+m_2 w/ Salpeter. Then generate m_1, m_2 from M and the
symmetric mass ratio eta (not to be confused with the mass ratio q = m_1/m_2, q>=1):
eta = mu/M = (m_1*m_2)/(m_1+m_2)**2, in [0,.25], half-gaussian, peak at .25, width of .05
'''

# FIRST GENERATE THE TOTAL MASS

# Number of random (total mass, eta) pairs to draw.
num_events=20000

# The total-mass distribution below starts at twice this value.
min_bh_mass = 2.5 #default: .5, realistically ~2.5

# Bin edges for masses: grad_mass edges per unit of mass, from 1 to y_range_mass.
y_range_mass = 100
grad_mass=3
xbins_mass = np.linspace(1,y_range_mass,int(round(grad_mass*y_range_mass)))

# Bin edges for eta: grad_eta edges per unit of eta, from 0 to y_range_eta.
y_range_eta = .25
grad_eta=100
# grad_eta*y_range_eta is the float 25.0; np.linspace needs an integer sample
# count, so convert explicitly.
xbins_eta = np.linspace(0,y_range_eta,int(round(grad_eta*y_range_eta)))

class Mass_pdf(rv_continuous):
    """Power-law density (exponent -2.35) for the total mass.

    The density is 1.35 * m_min**1.35 * x**(-2.35) with m_min = 2*min_bh_mass,
    which integrates to 1 over [m_min, inf).  Instances are therefore expected
    to be created with ``a=2*min_bh_mass`` as the lower support bound.

    Closed-form ``_cdf`` and ``_ppf`` are provided so that ``ppf`` does not
    fall back to numerically integrating and root-finding ``_pdf`` for every
    sample.
    """
    def _pdf(self, x):
        return 1.35*((2*min_bh_mass)**(1.35))*(x**(-2.35))

    def _cdf(self, x):
        # Integral of _pdf from m_min to x: 1 - (m_min/x)**1.35
        return 1.0-((2*min_bh_mass)/x)**1.35

    def _ppf(self, q):
        # Inverse of _cdf: m_min * (1-q)**(-1/1.35)
        return (2*min_bh_mass)*(1.0-q)**(-1.0/1.35)
# Random total masses via inverse-transform sampling: uniform numbers in
# [0, 1) are pushed through the distribution's percent-point function.
mass_pdf = Mass_pdf(name="Mass distribution", a=2*min_bh_mass)
x = np.random.rand(1, num_events)  # a single row of num_events uniform draws
y_rand_tot = mass_pdf.ppf(x)[0]  # first (only) row: the random total masses

# NOW GENERATE ETA
eta_std_dev = .05  # width of the Gaussian
eta_mean = .25     # position of the Gaussian peak


class Eta_pdf(rv_continuous):
    """Gaussian density centred on eta_mean, with twice the usual amplitude.

    The factor of two compensates for using only the half of the Gaussian
    that lies at or below the peak.
    """
    def _pdf(self, x):
        variance = eta_std_dev**2
        amplitude = 2.0/(2*np.pi*variance)**.5
        return amplitude*np.exp(-(x-eta_mean)**2/(2*variance))
# Random eta values via inverse-transform sampling on the support [0, .25].
eta_pdf = Eta_pdf(name="Eta distribution", a=0,b=.25)
x=np.array([np.random.rand(num_events)])
y_rand_eta=eta_pdf.ppf(x)[0] #y_rand_eta is the list of random etas
# Analytic density evaluated at the bin edges, for comparison with the samples.
y_act= (2.0/(2*np.pi*eta_std_dev**2)**.5)*np.exp(-(xbins_eta-eta_mean)**2/(2*eta_std_dev**2))
#FIGURE 1 - LINEAR X AXIS
fig1 = plt.figure(1)
ax1 = fig1.add_subplot(111)
# density=True replaces the removed normed=True keyword; with the equal-width
# bins used here it yields the same normalised histogram.
plt.hist(y_rand_eta,density=True,bins=xbins_eta,label="probability")
plt.plot(xbins_eta,y_act,label="expected function")
fig1.suptitle("Expected Eta Distribution for BH's")
plt.ylabel("p")
plt.xlabel("eta")
# curve fitting: normalised bin heights and the matching bin mid-points
n, bins = np.histogram(y_rand_eta,density=True,bins=xbins_eta)
bin_centers = bins[:-1] + 0.5 * (bins[1:] - bins[:-1])
def f(x, mean, std_dev,a):
    """Gaussian model a*exp(-(x-mean)**2/(2*std_dev**2)) used for the curve fit."""
    exponent = -(x-mean)**2/(2.*std_dev**2)
    return a*np.exp(exponent)
# Least-squares fit of the Gaussian model f to the normalised histogram;
# p0 is the initial guess for (mean, std_dev, a).
popt, pcov = curve_fit(f, bin_centers, n, p0 = [.2, .01,1])
# Parenthesised form prints the same text on Python 2 and is valid on Python 3.
print(popt)
text = "gaussian, mean="+str(popt[0])+",std_dev="+str(popt[1])+" vs. mean=.25,std_dev=.05"
# Overlay the fitted curve on figure 1 and show fitted vs. expected parameters.
plt.figure(1)
plt.plot(bin_centers,f(bin_centers, *popt),label="fit function")
ax1.set_title(text,fontsize=9)
plt.legend(loc='upper left')


# CALCULATE M_1 AND M_2 AND PLOT

# Invert M = m_1+m_2 and eta = m_1*m_2/M**2 for the component masses:
#   m_1,2 = (M/2)*(1 +/- sqrt(1-4*eta)),  so that m_1 >= m_2.
# The square root is required; without it m_1*m_2/M**2 does not equal eta.
m_1 = .5*y_rand_tot*(1+(1-4*y_rand_eta)**.5)
m_2 = .5*y_rand_tot*(1-(1-4*y_rand_eta)**.5)
# Points along the m_1 = m_2 reference line.
y_range = 100000
xbins = np.linspace(1,y_range,grad_mass*(y_range))
# FIGURE 2 - scatter of the component masses with the equal-mass line.
fig2 = plt.figure(2)
plt.scatter(m_1,m_2,color="b",label="m_1 vs m_2",s=2)
plt.plot(xbins,xbins,color="g",label="m_1=m_2")
fig2.suptitle("m_1 vs m_2 for BH's")
plt.ylabel("m_2 (M_solar)")
plt.xlabel("m_1 (M_solar)")
plt.legend(loc='upper left')

plt.show()
