import torch
import torchaudio
import torch.nn.functional as F
from pesq import pesq
from pystoi import stoi

from .other import si_sdr, pad_spec

# Settings
# Target sample rate: loaded audio whose rate differs is resampled to this
# value, and the same value is handed to the PESQ and STOI metrics.
sr = 16000
# Forwarded as the `snr` keyword of model.get_pc_sampler in evaluate_model.
snr = 0.5
# Forwarded as the `N` keyword of model.get_pc_sampler in evaluate_model
# (presumably the number of reverse steps -- confirm against the sampler).
N = 30
# Forwarded as the `corrector_steps` keyword of model.get_pc_sampler in
# evaluate_model.
corrector_steps = 1


def evaluate_model(model, num_eval_files):
    """Score the model on a subset of its validation set with a PC sampler.

    ``num_eval_files`` (clean, noisy, mixture) file triples are picked at
    evenly spaced positions of ``model.data_module.valid_set``. For each
    triple a sample is drawn from
    ``model.get_pc_sampler('reverse_diffusion', 'ald', ...)`` configured with
    the module-level ``N``, ``corrector_steps`` and ``snr`` settings, turned
    back into a waveform with ``model.to_audio`` and compared with the clean
    reference.

    Args:
        model: Object exposing ``data_module.valid_set`` (with
            ``clean_files``, ``noisy_files`` and ``mixture_files``),
            ``_stft``, ``_forward_transform``, ``get_pc_sampler`` and
            ``to_audio``. Inputs are moved to CUDA before being passed to it.
        num_eval_files: Number of validation files to evaluate; also the
            divisor used when averaging the metrics.

    Returns:
        Tuple ``(pesq, si_sdr, estoi)``: the wide-band PESQ, SI-SDR and
        extended STOI sums, each divided by ``num_eval_files``.
    """

    def _load_at_target_rate(path):
        # Load one file and bring it to the module-level sample rate `sr`.
        wav, file_sr = torchaudio.load(path)
        if file_sr != sr:
            wav = torchaudio.transforms.Resample(file_sr, sr)(wav)
        return wav

    valid_set = model.data_module.valid_set

    # Select test files uniformly across the validation files.
    total_num_files = len(valid_set.clean_files)
    indices = torch.linspace(
        0, total_num_files - 1, num_eval_files, dtype=torch.int
    ).tolist()
    clean_files = [valid_set.clean_files[i] for i in indices]
    noisy_files = [valid_set.noisy_files[i] for i in indices]
    mixture_files = [valid_set.mixture_files[i] for i in indices]

    pesq_total = 0
    si_sdr_total = 0
    estoi_total = 0
    for clean_file, noisy_file, mixture_file in zip(
            clean_files, noisy_files, mixture_files):
        x = _load_at_target_rate(clean_file)
        y = _load_at_target_rate(noisy_file)
        m = _load_at_target_rate(mixture_file)

        # The three recordings may differ in length; trim to the shortest.
        min_leng = min(x.shape[-1], y.shape[-1], m.shape[-1])
        x = x[..., :min_leng]
        y = y[..., :min_leng]
        m = m[..., :min_leng]

        T_orig = x.size(1)

        # Normalize per utterance: both network inputs are scaled by the peak
        # of the noisy signal, and the estimate is scaled back further below.
        norm_factor = y.abs().max()
        y = y / norm_factor
        m = m / norm_factor

        # Prepare DNN input
        Y = torch.unsqueeze(model._forward_transform(model._stft(y.cuda())), 0)
        Y = pad_spec(Y)

        M = torch.unsqueeze(model._forward_transform(model._stft(m.cuda())), 0)
        M = pad_spec(M)

        # Reverse sampling
        sampler = model.get_pc_sampler(
            'reverse_diffusion', 'ald', Y.cuda(), M.cuda(), N=N,
            corrector_steps=corrector_steps, snr=snr)
        sample, _ = sampler()

        x_hat = model.to_audio(sample.squeeze(), T_orig)
        x_hat = x_hat * norm_factor

        x_hat = x_hat.squeeze().cpu().numpy()
        x = x.squeeze().cpu().numpy()

        si_sdr_total += si_sdr(x, x_hat)
        pesq_total += pesq(sr, x, x_hat, 'wb')
        estoi_total += stoi(x, x_hat, sr, extended=True)

    return (pesq_total / num_eval_files,
            si_sdr_total / num_eval_files,
            estoi_total / num_eval_files)


def evaluate_model2(model, num_eval_files, inference_N, inference_start=0.5):
    """Score the model on validation files with an explicit reverse loop.

    ``num_eval_files`` (clean, noisy, mixture) file triples are picked at
    evenly spaced positions of ``model.data_module.valid_set``. For each
    triple the reverse process starts at time ``inference_start`` from the
    mixture spectrogram plus Gaussian noise scaled by ``model.sde._std`` and
    is integrated with Euler-Maruyama steps over ``inference_N`` time values
    linearly spaced from ``inference_start`` down to 0.03. The result is
    turned back into a waveform with ``model.to_audio`` and compared with the
    clean reference.

    Args:
        model: Object exposing ``data_module.valid_set`` (with
            ``clean_files``, ``noisy_files`` and ``mixture_files``),
            ``_stft``, ``_forward_transform``, ``sde`` (with ``_std`` and
            ``sde``), ``forward`` and ``to_audio``. Inputs are moved to CUDA
            before being passed to it.
        num_eval_files: Number of validation files to evaluate; also the
            divisor used when averaging the metrics.
        inference_N: Number of reverse time steps.
        inference_start: Time value at which the reverse process starts.

    Returns:
        Tuple ``(pesq, si_sdr, estoi, total_loss)``: the wide-band PESQ,
        SI-SDR and extended STOI sums, each divided by ``num_eval_files``.
        ``total_loss`` is never accumulated, so the fourth entry is always
        zero; it is kept so callers unpacking four values keep working.
    """

    def _load_at_target_rate(path):
        # Load one file and bring it to the module-level sample rate `sr`.
        wav, file_sr = torchaudio.load(path)
        if file_sr != sr:
            wav = torchaudio.transforms.Resample(file_sr, sr)(wav)
        return wav

    valid_set = model.data_module.valid_set

    # Select test files uniformly across the validation files.
    total_num_files = len(valid_set.clean_files)
    indices = torch.linspace(
        0, total_num_files - 1, num_eval_files, dtype=torch.int
    ).tolist()
    clean_files = [valid_set.clean_files[i] for i in indices]
    noisy_files = [valid_set.noisy_files[i] for i in indices]
    mixture_files = [valid_set.mixture_files[i] for i in indices]

    pesq_total = 0
    si_sdr_total = 0
    estoi_total = 0
    # Placeholder: nothing is ever added to it (see Returns).
    total_loss = 0
    for clean_file, noisy_file, mixture_file in zip(
            clean_files, noisy_files, mixture_files):
        x = _load_at_target_rate(clean_file)
        y = _load_at_target_rate(noisy_file)
        m = _load_at_target_rate(mixture_file)

        # Required only for BWE, where the dataset has clean and noisy files
        # of different lengths: trim everything to the shortest recording.
        min_leng = min(x.shape[-1], y.shape[-1], m.shape[-1])
        x = x[..., :min_leng]
        y = y[..., :min_leng]
        m = m[..., :min_leng]

        T_orig = x.size(1)

        # Normalize per utterance by the peak of the mixture. Only the
        # network inputs are scaled; the clean reference stays untouched and
        # the estimate is scaled back before the metrics are computed.
        norm_factor = m.abs().max()
        y = y / norm_factor
        m = m / norm_factor

        # Prepare DNN input
        Y = torch.unsqueeze(model._forward_transform(model._stft(y.cuda())), 0)
        Y = pad_spec(Y)

        M = torch.unsqueeze(model._forward_transform(model._stft(m.cuda())), 0)
        M = pad_spec(M)

        x = x.squeeze().cpu().numpy()

        timesteps = torch.linspace(
            inference_start, 0.03, inference_N, device=M.device)

        # Prior sampling starting from inference_start.
        std = model.sde._std(
            inference_start * torch.ones((M.shape[0],), device=M.device))
        X_t = M + torch.randn_like(M) * std[:, None, None, None]

        # Reverse steps by Euler-Maruyama.
        last_step = len(timesteps) - 1
        with torch.no_grad():
            for i, t in enumerate(timesteps):
                if i != last_step:
                    dt = t - timesteps[i + 1]
                else:
                    # No later time value exists: step from the last time
                    # value all the way down to 0.
                    dt = timesteps[-1]

                f, g = model.sde.sde(X_t, t, M)
                vec_t = torch.ones(M.shape[0], device=M.device) * t
                score = model.forward(
                    X_t, vec_t, M, Y, vec_t[:, None, None, None])
                # Mean of the state after this step, mu(x_{t-1}).
                mean_x_tm1 = X_t - (f - g**2 * score) * dt

                if i == last_step:
                    # The output is the noise-free mean of the final step.
                    X_t = mean_x_tm1
                else:
                    # Noise is shaped after the state it is added to.
                    z = torch.randn_like(X_t)
                    X_t = mean_x_tm1 + z * g * torch.sqrt(dt)

        x_hat = model.to_audio(X_t.squeeze(), T_orig)
        x_hat = x_hat * norm_factor
        x_hat = x_hat.squeeze().cpu().numpy()

        si_sdr_total += si_sdr(x, x_hat)
        pesq_total += pesq(sr, x, x_hat, 'wb')
        estoi_total += stoi(x, x_hat, sr, extended=True)

    return (pesq_total / num_eval_files,
            si_sdr_total / num_eval_files,
            estoi_total / num_eval_files,
            total_loss / num_eval_files)


def convert_to_audio(X, deemp, T_orig, model, norm_factor):
    """Weight a spectrogram with ``deemp`` and render it as a NumPy waveform.

    Args:
        X: Spectrogram tensor; all size-1 dimensions are squeezed away first.
        deemp: Weights multiplied onto the spectrogram. For a 1-D ``deemp``
            the indexing below lines its entries up with the second-to-last
            axis of the squeezed spectrogram (assumes ``deemp`` is 1-D --
            confirm against callers).
        T_orig: Forwarded as the second argument of ``model.to_audio``.
        model: Object providing ``to_audio(spec, T_orig)``.
        norm_factor: Factor multiplied onto the result of ``model.to_audio``.

    Returns:
        The scaled output of ``model.to_audio``, squeezed, moved to the CPU
        and converted to a NumPy array.
    """
    spec = X.squeeze()

    # Pick the index expression matching the rank of the squeezed input;
    # every rank other than 4 or 3 uses the two-axis form.
    whole_axis = slice(None)
    index_by_rank = {
        4: (None, None, whole_axis, None),
        3: (None, whole_axis, None),
    }
    index = index_by_rank.get(len(spec.shape), (whole_axis, None))
    spec = spec * deemp[index].to(device=spec.device)

    audio = model.to_audio(spec.squeeze(), T_orig) * norm_factor
    return audio.squeeze().cpu().numpy()