import numpy as np
import matplotlib.pyplot as plt

# Function to calculate y, given x and 2 model parameters
def func(x, norm, temp):
    """Evaluate the two-parameter model curve at ``x``.

    Computes ``norm / x**5 / (exp(1.435e8 / (x * temp)) - 1)``.  Every
    operation is element-wise, so ``x`` may be a scalar or a NumPy array and
    the result has the matching shape.

    Args:
        x: Point(s) at which the model is evaluated.
        norm: Multiplicative scale factor of the curve.
        temp: Parameter that sets the exponential cut-off.

    Returns:
        The model value(s) at ``x``.

    NOTE(review): the expression has the form of a Planck blackbody curve
    with ``x`` as wavelength and ``temp`` as temperature; the units implied
    by the 1.435e8 constant are not stated in this file -- confirm.
    """
    exponent = 1.435e8 / (x * temp)
    occupation = np.exp(exponent) - 1
    return norm / x**5 / occupation

# Data points, hard-wired into the code.
# The dtype is the builtin `float` (float64): `np.float` was merely an alias
# for it and has been removed from recent NumPy releases, so using the alias
# makes the script fail at start-up there.  Float storage matters because x
# is raised to the 5th power inside func().
x = np.array([1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000, 10000,
              11000, 12000, 13000, 14000, 15000], dtype=float)
# Measured values, one per entry of x (same order, same length).
y = np.array([0.00, 1.62, 13.43, 25.37, 28.97, 26.91, 22.81, 18.54, 14.82,
              11.79, 9.39, 7.52, 6.06, 4.92, 4.03])

# Range to do grid-search and sampling rate: each parameter is sampled at
# *_num evenly spaced values from *_min to *_max, both ends included.
norm_min = 1.0e22
norm_max = 1.5e22
norm_num = 300
temp_min = 4500
temp_max = 7500
temp_num = 300
norm_list = np.linspace(norm_min, norm_max, norm_num)
temp_list = np.linspace(temp_min, temp_max, temp_num)

# Table of log(chi^2) for every (normalization, temperature) pair; rows
# follow norm_list and columns follow temp_list.  "chi^2" here is the plain,
# unweighted sum of squared residuals between the data and the model.
# The log is monotonic, so it does not move the minimum located below; it is
# presumably taken to compress the dynamic range of the plotted surface.
chisq_table = np.zeros((norm_num, temp_num))
for i, norm_now in enumerate(norm_list):
    # The inner loop walks the temperature axis.  Bounding it by the
    # normalization count only works while both grids have the same size;
    # otherwise it either indexes past temp_list or leaves columns at zero.
    for j, temp_now in enumerate(temp_list):
        y_now = func(x, norm_now, temp_now)
        chisq_now = np.sum((y - y_now)**2)
        chisq_table[i, j] = np.log(chisq_now)

# Find model parameters corresponding to lowest chi^2 value.  argmin works on
# the flattened table, so the flat position is converted back to a
# (row, column) = (norm index, temp index) pair.
ind = np.unravel_index(np.argmin(chisq_table), chisq_table.shape)
ind_norm = ind[0]
ind_temp = ind[1]
norm_optimal = norm_list[ind_norm]
temp_optimal = temp_list[ind_temp]
# A single pre-formatted string prints the same "norm temp" line under both
# Python 2 and Python 3.
print("%s %s" % (norm_optimal, temp_optimal))

# Plot the chi^2 surface.  The table is transposed so that normalization runs
# along the horizontal axis and temperature along the vertical axis, and
# origin='lower' puts the first grid point of each axis at the bottom-left.
surface = chisq_table.T
plt.imshow(surface, interpolation='nearest', origin='lower')

# Tick only the two grid edges and the best-fit cell on each axis, labelling
# each tick with the parameter value found at that grid position.
norm_ticks = [0, ind_norm, norm_num-1]
temp_ticks = [0, ind_temp, temp_num-1]
plt.xticks(norm_ticks, [norm_min, norm_optimal, norm_max])
plt.yticks(temp_ticks, [temp_min, temp_optimal, temp_max])

plt.xlabel("Normalization")
plt.ylabel("Temperature")
plt.show()
