from scipy import interpolate
from scipy.fft import fft, ifft
from scipy.constants import c, epsilon_0
import numpy as np, math
from matplotlib import pyplot as plt

# --- Pulse and focusing parameters -----------------------------------------
# Values look like SI units (wavelengths in metres, F documented in metres);
# the 10**9 factors below then convert to nanometres -- TODO confirm units.
lambda_0 = 800 * 10**(-9)  # central wavelength
_two_pi = 2*np.pi
omega_0 = _two_pi*c / lambda_0  # central angular frequency
delta_lambda = 50.0 * 10**(-9)  # spectral bandwidth

# Pulse duration from 1.47e-3 * lambda_0**2 / delta_lambda with both lengths
# scaled by 10**9, then scaled by 10**(-15) (presumably femtoseconds -> seconds).
_lambda_0_scaled = lambda_0 * 10**9
_delta_lambda_scaled = delta_lambda * 10**9
_delta_tau_scaled = 1.47 * 10**(-3) * _lambda_0_scaled**2 / _delta_lambda_scaled
delta_Tau = _delta_tau_scaled * 10**(-15)

# 4*ln(2)/delta_Tau, which equals ((4*ln(2))/(2*pi)) * (2*pi/delta_Tau)
# = 0.441 * (2*pi/delta_Tau).
delta_omega_PFT = 4*np.log(2) / delta_Tau

F = 2.0  # focal length in meters, as from Liu 2017

def G(omegas):
    """Return the Gaussian spectral envelope evaluated at ``omegas``.

    Computes ``exp(-(delta_Tau**2 / (8*ln 2)) * (omegas - omega_0)**2)`` using
    the module-level ``delta_Tau`` and ``omega_0``.  The envelope is not
    normalised: it equals 1 where ``omegas == omega_0``.

    Parameters
    ----------
    omegas : scalar or numpy array of angular frequencies.

    Returns
    -------
    Same shape as ``omegas``.
    """
    detuning = omegas - omega_0
    width_factor = delta_Tau**2 / (8*np.log(2))
    return np.exp(-width_factor * detuning**2)

# Coupling coefficient used by phase_c: 0.1 * 1e-15 / 1e-3
# (presumably 0.1 fs/mm expressed in s/m -- TODO confirm units).
xsi = 0.1 * 10**(-15) / 10**(-3)


def phase_c(xs, omegas):
    """Return the phase matrix ``xsi * x * (omega - omega_0)``.

    Parameters
    ----------
    xs : array reshaped to a column of length ``xs.shape[0]``.
    omegas : array reshaped to a row of length ``omegas.shape[0]``.

    Returns
    -------
    Array of shape ``(xs.shape[0], omegas.shape[0])``; row index follows
    ``xs`` and column index follows ``omegas``.
    """
    x_column = xs.reshape(xs.shape[0], 1)
    detuning_row = (omegas - omega_0).reshape(1, omegas.shape[0])
    return xsi * x_column * detuning_row

# Field amplitude multiplying G(omega).  The leading 0.2*10**4 is presumably a
# fluence of 0.2 J/cm^2 expressed in J/m^2 -- TODO confirm against the derivation.
_E0_numerator = 0.2*10**4 * 8 * np.sqrt(np.log(2))
_E0_denominator = delta_Tau*np.sqrt(np.pi)*c*epsilon_0/2.0
E0 = np.sqrt(_E0_numerator / _E0_denominator) * np.sqrt(2*np.pi*np.log(2)) / delta_omega_PFT

def f(xi, omega):
    """Return the prefactors multiplying the Fraunhofer integral.

    These are the prefactors from Eq. (5) of Li et al. (2017).  Uses the
    module-level focal length ``F`` and the speed of light ``c``.

    Parameters
    ----------
    xi : array reshaped to a column of length ``xi.shape[0]``.
    omega : array of angular frequencies, reshaped to a row of length
        ``omega.shape[0]``.

    Returns
    -------
    Complex array of shape ``(xi.shape[0], omega.shape[0])``.
    """
    i_k = 1j * (omega/c)  # shape (omega.shape[0], )

    # Depends on omega only: omega * exp(i*(omega/c)*F) / (i*2*pi*c*F).
    omega_only = omega * np.exp(i_k * F) / (1j * 2*np.pi*c*F)

    # Quadratic phase exp(i*(omega/c)*xi**2/(2*F)), shape (xi.shape[0], omega.shape[0]).
    i_k_row = i_k.reshape(1, omega.shape[0])
    xi_column = xi.reshape(xi.shape[0], 1)
    quadratic_phase = np.exp(i_k_row * xi_column**2 / (2*F))

    # omega_only broadcasts along the rows of quadratic_phase.
    return omega_only * quadratic_phase

# Transverse x grid on which the input field is sampled before the FFT over x.
x0 = 0.0
delta_x = 196.0 # obtained from N=10, N_x=8*10^3, xi_max=10um, F=2m
xmin_PFT = x0 - delta_x
xmax_PFT = x0 + delta_x
num_xs_PFT = 8 * 10**3
# np.linspace includes both endpoints, so adjacent samples are
# (xmax_PFT - xmin_PFT) / (num_xs_PFT - 1) apart, not (...) / num_xs_PFT.
# retstep=True returns the spacing the grid really has, which is what the
# FFT scaling and np.fft.fftfreq need.
xs_PFT, sampling_spacing_xs_PFT = np.linspace(xmin_PFT, xmax_PFT, num_xs_PFT, retstep=True)

# Angular-frequency grid, symmetric about omega_0 and spanning +/- N*delta_omega_PFT.
num_omegas_focus = 5 * 10**2
maximum_time = 100.0 * 10**(-15)
# NOTE(review): N looks chosen so that the time window obtained after the
# inverse FFT over omega covers at least maximum_time -- confirm the derivation.
N = math.ceil( (np.pi*num_omegas_focus)/(2*delta_omega_PFT*maximum_time) ) - 1
omega_max_focus = omega_0 + N*delta_omega_PFT
omega_min_focus = omega_0 - N*delta_omega_PFT
# np.linspace includes both endpoints, so the true spacing is
# (omega_max_focus - omega_min_focus) / (num_omegas_focus - 1); retstep=True
# returns exactly that value instead of the slightly too small range / num.
omegas_focus, sampling_spacing_omegas_focus = np.linspace(
    omega_min_focus, omega_max_focus, num_omegas_focus, retstep=True
) # omegas_focus is shape (num_omegas_focus, )

# Input field on the (x, omega) grid: the amplitude E0 * G(omega) times the
# phase factor exp(i * phase_c(x, omega)).  phase_c uses xsi, the PFT coefficient.
# E0 * G(omegas_focus) has shape (num_omegas_focus, ) and broadcasts along the
# rows of the (num_xs_PFT, num_omegas_focus) phase matrix.
Es_x_omega = np.multiply(
    E0 * G(omegas_focus),
    np.exp(1j*phase_c(xs_PFT, omegas_focus)),
)  # the original call was missing this closing parenthesis
# Es_x_omega is shape (num_xs_PFT, num_omegas_focus): axis 0 (going down a
# column) carries the x-dependence, axis 1 (going along a row) the omega-dependence.

# Perform the FFT with respect to x (axis 0) for all omega columns in a single
# call, going from x to Kappa (a scaled spatial frequency).
# Multiplying by the sample spacing \Delta scales the DFT sum like the
# continuous Fourier integral it approximates.
# Real and imaginary parts are stored in separate real arrays of shape
# (num_xs_PFT, num_omegas_focus).
Bprime_data = fft(Es_x_omega, axis=0)
Bprime_data_real = np.real(Bprime_data) * sampling_spacing_xs_PFT
Bprime_data_imag = np.imag(Bprime_data) * sampling_spacing_xs_PFT
print("We have done all {} ffts".format(Es_x_omega.shape[1]))
# Bprime is a function of (Kappa, omega): axis 0 carries the Kappa dependence,
# axis 1 the omega dependence.
# Get the Kappas.  np.fft.fftfreq returns the frequencies in FFT order (zero
# first, then the positive ones, then the negative ones), see
# https://numpy.org/doc/stable/reference/generated/numpy.fft.fftfreq.html
returned_freqs = np.fft.fftfreq(num_xs_PFT, sampling_spacing_xs_PFT) # shape (num_xs_PFT, )
Kappas_ugly = 2*np.pi * returned_freqs # shape (num_xs_PFT, ), in FFT order
Kappas_pretty = 2*np.pi * np.fft.fftshift(returned_freqs) # same values, ascending order

# Reorder the rows of the data with the same permutation that turns
# Kappas_ugly into Kappas_pretty.  fftshift along axis 0 applies exactly that
# permutation, without building a (num_xs_PFT, num_xs_PFT) comparison matrix.
Bprime_data_real_pretty = np.fft.fftshift(Bprime_data_real, axes=0)
Bprime_data_imag_pretty = np.fft.fftshift(Bprime_data_imag, axes=0)
print(Bprime_data_real_pretty.shape) # shape (num_xs_PFT, num_omegas_focus), similarly for Bprime_data_imag_pretty

# One real-valued spline over the (Kappa, omega) grid for the real part of the
# data and one for the imaginary part.  RectBivariateSpline is used rather than
# interpolate.interp2d() because it builds (and can be queried) faster.
Bprime_real, Bprime_imag = (
    interpolate.RectBivariateSpline(Kappas_pretty, omegas_focus, component)
    for component in (Bprime_data_real_pretty, Bprime_data_imag_pretty)
)
print("We have the interpolators!")

# Goal: plot E versus time (horizontal axis) and xi (vertical axis).
# xi grid: num_xis samples, symmetric about zero, from -5e-6 to +5e-6
# (5 um if xi is in metres, as the "xi [m]" axis label suggests).
num_xis = 5000
xi_max = 5.0 * 10**(-6)
xi_min = -xi_max
xis = np.linspace(xi_min, xi_max, num=num_xis)
print("We are preparing now!")

# Evaluate both splines at Kappa = omega*xi/(c*F) for every (xi, omega) pair.
# The query points are built by broadcasting and passed to .ev(), which
# evaluates the spline point-wise, so no Python loop over the
# xis.shape[0] * omegas_focus.shape[0] pairs is needed and no (1, 1) array is
# assigned into a scalar slot.
kappa_query = omegas_focus[np.newaxis, :] * xis[:, np.newaxis] / (c*F) # shape (xis.shape[0], omegas_focus.shape[0])
omega_query = np.broadcast_to(omegas_focus, kappa_query.shape) # omega of each query point
Es_Kappa_omega_without_prefactor = (
    Bprime_real.ev(kappa_query, omega_query)
    + 1j*Bprime_imag.ev(kappa_query, omega_query)
) # complex, shape (xis.shape[0], omegas_focus.shape[0])
print("We have done all {} iterations in querying the interpolators".format(Es_Kappa_omega_without_prefactor.size))

# Apply the prefactors f(xi, omega) element-wise.  Both factors, and therefore
# the product, have shape (xis.shape[0], omegas_focus.shape[0]).
Es_Kappa_omega = f(xis, omegas_focus) * Es_Kappa_omega_without_prefactor

# Do the inverse FT of Es_Kappa_omega w.r.t. omega to go from the frequency
# domain (omega) to the time domain (t).
# Axis 0 carries the xi dependence and axis 1 the omega dependence, so the
# inverse FFT is taken along axis 1 for all rows (all xi) in a single call.
# The 1st multiplication is by Delta (the omega spacing over 2*pi); the 2nd is
# by N, which cancels the 1/N that ifft applies by default.
Es_Kappa_time = (
    ifft(Es_Kappa_omega, axis=1)
    * (sampling_spacing_omegas_focus/(2*np.pi))
    * num_omegas_focus
) # shape (xis.shape[0], omegas_focus.shape[0])
print("We have done all {} iffts".format(Es_Kappa_omega.shape[0]))

# Time axis belonging to the inverse FFT over omega.  np.fft.fftfreq returns it
# in FFT order; here the "frequencies" are times because the transform goes
# from the frequency domain to the time domain.
returned_times_ugly = np.fft.fftfreq(num_omegas_focus, d=(sampling_spacing_omegas_focus/(2*np.pi))) # shape (num_omegas_focus, )
returned_times_pretty = np.fft.fftshift(returned_times_ugly) # same values in ascending order
# Apply the same reordering to the time axis (axis 1) of the field so that its
# columns line up with returned_times_pretty.
Es_Kappa_time = np.fft.fftshift(Es_Kappa_time, axes=1)

# --- Figure 1: real part of the field versus time (horizontal) and xi (vertical).
returned_times_pretty_mesh, xis_mesh = np.meshgrid(returned_times_pretty, xis)
fig, ax = plt.subplots()
# The mesh handle must not be named `c`: that would rebind the speed of light
# imported from scipy.constants and make the c*epsilon_0 product below fail.
field_mesh = ax.pcolormesh(returned_times_pretty_mesh, xis_mesh, np.real(Es_Kappa_time), cmap='viridis')
fig.colorbar(field_mesh, ax=ax, label=r'$[V/m]$')
ax.set_xlabel("t [s]")
ax.set_ylabel("xi [m]")

# --- Figure 2: c * epsilon_0 * Re(E)^2 on the same (t, xi) axes.
intensity = (c*epsilon_0) * np.real(Es_Kappa_time)**2
fig, ax = plt.subplots()
# Without `extent`, imshow would draw in pixel-index coordinates, which does
# not match the physical axis labels or the pcolormesh drawn on the same axes.
# origin='lower' puts row 0 (the smallest xi) at the bottom, like pcolormesh.
ax.imshow(
    intensity,
    cmap='viridis',
    origin='lower',
    aspect='auto',
    extent=(returned_times_pretty[0], returned_times_pretty[-1], xis[0], xis[-1]),
)
ax.set_xlabel("t [s]")
ax.set_ylabel("xi [m]")

intensity_mesh = ax.pcolormesh(returned_times_pretty_mesh, xis_mesh, intensity, cmap='viridis')
# With the field in V/m, c * epsilon_0 * E^2 is in W/m^2 (the label used to say V/m).
fig.colorbar(intensity_mesh, ax=ax, label=r'$[W/m^2]$')


# Pasted IPython traceback excerpt (not code), kept as a comment so the file parses:
#     160 fig, ax = plt.subplots()
# --> 161 ax.imshow(np.multiply(np.real(Es_Kappa_time)**2, (c*epsilon_0)), cmap='viridis')






