"""
Conv-TasNet separation with per-chunk speaker assignment via TitaNet.
Fixes the permutation problem: for each chunk, TitaNet verifies
which separated source matches the reference speaker.
"""
import torch
import torchaudio
import numpy as np
import soundfile as sf
from pathlib import Path
from nemo.collections.asr.models import EncDecSpeakerLabelModel
from asteroid.models import ConvTasNet
import asteroid.models.base_models as _bm
import time

# --- Configuration ----------------------------------------------------------
OUT = Path("/home/ubuntu/sam_audio_test/tse_v3_out")
# parents=True so the run does not fail when the parent directory is missing.
OUT.mkdir(parents=True, exist_ok=True)

MIXTURE = "/home/ubuntu/clip_00.mp3"                  # two-speaker mixture to separate
REFERENCE = "/home/ubuntu/clip00_reference_45s.wav"   # clean sample of the target speaker
CHUNK_SEC = 10     # separation window length (seconds)
OVERLAP_SEC = 1    # crossfade overlap between consecutive chunks (seconds)
GAP_SEC = 0.15     # silence inserted between segments in the "compact" output
SR = 16000         # working sample rate (matches the 16 kHz Conv-TasNet checkpoint)

# Fall back to CPU so the script still runs on machines without a GPU;
# torch.cuda.empty_cache() below is a no-op when CUDA is never initialized.
device = "cuda" if torch.cuda.is_available() else "cpu"

# ---- TitaNet speaker-verification model ------------------------------------
print("=== Loading TitaNet ===")
spk_model = (
    EncDecSpeakerLabelModel
    .from_pretrained("nvidia/speakerverification_en_titanet_large")
    .eval()
    .to(device)
)

# Embed the reference recording of the target speaker once up front.
ref_audio, _ref_sr = torchaudio.load(REFERENCE)
if _ref_sr != SR:
    ref_audio = torchaudio.functional.resample(ref_audio, _ref_sr, SR)
if ref_audio.shape[0] > 1:
    ref_audio = ref_audio.mean(dim=0, keepdim=True)  # downmix to mono
with torch.inference_mode():
    _, ref_emb = spk_model.forward(
        input_signal=ref_audio.to(device),
        input_signal_length=torch.tensor([ref_audio.shape[1]], device=device),
    )
ref_emb = ref_emb.squeeze()
print(f"  Ref embedding from {ref_audio.shape[1]/SR:.1f}s")

# ---- Conv-TasNet separation model ------------------------------------------
print("\n=== Loading Conv-TasNet 16kHz ===")
# Newer torch versions default to weights_only=True, which rejects the pickled
# objects inside this asteroid checkpoint.  Temporarily patch torch.load (as
# seen by asteroid.models.base_models) to force weights_only=False.
# SECURITY NOTE: weights_only=False allows arbitrary pickle code execution —
# only acceptable because the checkpoint comes from a trusted hub repo.
_orig_load = torch.load

def _pl(*a, **kw):
    """torch.load wrapper that forces weights_only=False."""
    kw["weights_only"] = False
    return _orig_load(*a, **kw)

_bm.torch.load = _pl
try:
    sep_model = ConvTasNet.from_pretrained("JorisCos/ConvTasNet_Libri2Mix_sepclean_16k")
finally:
    # Always restore, even if from_pretrained raises, so torch.load is not
    # left patched for the rest of the process.
    _bm.torch.load = _orig_load
sep_model = sep_model.eval().to(device)

# ---- Mixture audio ---------------------------------------------------------
print("\n=== Loading mixture ===")
mix_audio, mix_sr = torchaudio.load(MIXTURE)
# Resample to the working rate, then downmix multi-channel audio to mono.
mix_audio = (
    torchaudio.functional.resample(mix_audio, mix_sr, SR)
    if mix_sr != SR
    else mix_audio
)
mix_audio = (
    mix_audio.mean(dim=0, keepdim=True)
    if mix_audio.shape[0] > 1
    else mix_audio
)
mix = mix_audio.squeeze().numpy()
total_sec = len(mix) / SR
print(f"  {total_sec:.1f}s at {SR}Hz")

# Separate with per-chunk speaker assignment
print("\n=== Separating with per-chunk speaker assignment ===")
chunk_samples = CHUNK_SEC * SR
overlap_samples = OVERLAP_SEC * SR
step_samples = chunk_samples - overlap_samples

# Overlap-add accumulators; weight_map records the total crossfade weight
# applied at every sample so the sum can be renormalized afterwards.
target_stream = np.zeros(len(mix))
other_stream = np.zeros(len(mix))
weight_map = np.zeros(len(mix))
cos_sim = torch.nn.CosineSimilarity(dim=-1)

pos = 0
chunk_idx = 0
t_start = time.time()
assignments = []  # (chunk_idx, chosen_source_index, score_src0, score_src1)

while pos < len(mix):
    end = min(pos + chunk_samples, len(mix))
    chunk = mix[pos:end]
    if len(chunk) < SR:
        # NOTE(review): silently drops a trailing chunk shorter than 1 s —
        # presumably too short for a reliable embedding; confirm the tail
        # loss is acceptable.
        break

    # Separate the chunk into two estimated sources (batch of 1).
    chunk_t = torch.tensor(chunk, dtype=torch.float32).unsqueeze(0).to(device)
    with torch.inference_mode():
        est = sep_model(chunk_t)

    src0 = est[0, 0].cpu().numpy()[:len(chunk)]
    src1 = est[0, 1].cpu().numpy()[:len(chunk)]

    # Score each separated source against the reference speaker embedding
    # (cosine similarity of TitaNet embeddings).
    scores = []
    for src in [src0, src1]:
        if np.max(np.abs(src)) < 0.001:
            # Near-silent source: sentinel score so it never wins.
            scores.append(-1.0)
            continue
        src_t = torch.tensor(src, dtype=torch.float32).unsqueeze(0).to(device)
        src_len = torch.tensor([len(src)], device=device)
        with torch.inference_mode():
            _, src_emb = spk_model.forward(input_signal=src_t, input_signal_length=src_len)
        score = cos_sim(ref_emb.unsqueeze(0), src_emb.squeeze().unsqueeze(0)).item()
        scores.append(score)

    # Assign: higher similarity = target speaker.  This resolves the
    # per-chunk permutation ambiguity of the separator.
    if scores[0] >= scores[1]:
        target_src, other_src = src0, src1
        assign = 0
    else:
        target_src, other_src = src1, src0
        assign = 1

    assignments.append((chunk_idx, assign, scores[0], scores[1]))

    # Linear crossfade window: ramp in over the first `fade` samples when
    # this chunk overlaps the previous one, ramp out over the last `fade`
    # samples when another chunk follows.  Adjacent ramps sum to 1 in the
    # overlap; any residual imbalance is removed by the weight_map
    # normalization after the loop.
    fade = min(overlap_samples, len(chunk))
    window = np.ones(len(chunk))
    if pos > 0:
        window[:fade] = np.linspace(0, 1, fade)
    if end < len(mix):
        window[-fade:] = np.linspace(1, 0, fade)

    target_stream[pos:pos+len(chunk)] += target_src * window
    other_stream[pos:pos+len(chunk)] += other_src * window
    weight_map[pos:pos+len(chunk)] += window

    chunk_idx += 1
    del chunk_t, est
    # Keep GPU memory bounded across many chunks.
    torch.cuda.empty_cache()
    pos += step_samples

# ---- Overlap-add normalization and run statistics --------------------------
covered = weight_map > 0
target_stream[covered] /= weight_map[covered]
other_stream[covered] /= weight_map[covered]

elapsed = time.time() - t_start
print(f"  {chunk_idx} chunks in {elapsed:.1f}s")

# How often the separator's source ordering flipped between adjacent chunks.
swaps = sum(
    int(cur[1] != prev[1])
    for prev, cur in zip(assignments, assignments[1:])
)
print(f"  Source swaps corrected: {swaps}/{len(assignments)} chunks")
avg_target_sim = np.mean([max(s0, s1) for _, _, s0, s1 in assignments])
avg_other_sim = np.mean([min(s0, s1) for _, _, s0, s1 in assignments])
print(f"  Avg target similarity: {avg_target_sim:.3f}")
print(f"  Avg other similarity:  {avg_other_sim:.3f}")

# ---- Save full and "compact" (silence-stripped) versions -------------------
print("\n=== Saving ===")


def _runs(mask, value):
    """Yield (start, end) index pairs of maximal runs where mask == value."""
    j = 0
    n = len(mask)
    while j < n:
        if mask[j] == value:
            start = j
            while j < n and mask[j] == value:
                j += 1
            yield start, j
        else:
            j += 1


for label, audio in [("target_speaker", target_stream), ("other_speaker", other_stream)]:
    # Peak-normalize to 0.9 full scale so both outputs have comparable level.
    peak = np.max(np.abs(audio))
    if peak > 0:
        audio = audio / peak * 0.9

    sf.write(str(OUT / f"{label}_full.wav"), audio, SR)

    # Compact version: frame-level RMS energy VAD, then concatenate segments.
    threshold = 10 ** (-35 / 20)   # -35 dBFS energy gate
    frame_len = int(0.02 * SR)     # 20 ms frames
    hop = frame_len // 2           # 50% hop (10 ms)
    nf = (len(audio) - frame_len) // hop + 1
    rms = np.array([np.sqrt(np.mean(audio[j*hop:j*hop+frame_len]**2)) for j in range(nf)])
    is_speech = rms > threshold
    min_sil = int(0.3 / (hop / SR))   # bridge silences shorter than 0.3 s
    min_sp = int(0.15 / (hop / SR))   # drop speech bursts shorter than 0.15 s

    # Bridge short silence runs.  BUGFIX: the previous `for j in range(...)`
    # version re-derived the run start at every index (the inner while did not
    # advance the for counter), so the last min_sil-1 frames of every LONG
    # silence run were wrongly marked as speech; scanning run-by-run fixes it.
    for start, stop in list(_runs(is_speech, False)):
        if (stop - start) < min_sil:
            is_speech[start:stop] = True

    # Drop too-short speech runs.
    for start, stop in list(_runs(is_speech, True)):
        if (stop - start) < min_sp:
            is_speech[start:stop] = False

    # Cut each remaining speech run with 20 ms of padding on both sides.
    segs = []
    for start, stop in _runs(is_speech, True):
        ss = max(0, start*hop - int(0.02*SR))
        se = min(len(audio), stop*hop + int(0.02*SR))
        segs.append(audio[ss:se])

    # Join segments with a fixed silent gap between them.
    gap = np.zeros(int(GAP_SEC * SR))
    parts = []
    for i, seg in enumerate(segs):
        parts.append(seg)
        if i < len(segs) - 1:
            parts.append(gap)
    compact = np.concatenate(parts) if parts else np.zeros(1)
    sf.write(str(OUT / f"{label}_compact.wav"), compact, SR)
    print(f"  {label}: full={len(audio)/SR:.1f}s | compact={len(compact)/SR:.1f}s ({len(segs)} segs)")

print("\n=== DONE ===")
