"""
CTC scoring: compute log-likelihood of reference text given audio log-probabilities.
This is NOT ASR — we don't decode. We ask "how likely is this text given this audio?"
"""
from __future__ import annotations

import logging

import torch

# Module-level logger, named after this module so callers can configure it
# through the standard logging hierarchy.
logger = logging.getLogger(__name__)


def compute_ctc_score(
    log_probs: torch.Tensor,
    target_tokens: list[int],
    blank_id: int = 0,
) -> tuple[float, float]:
    """
    Compute CTC log-likelihood of target_tokens given log_probs.

    Args:
        log_probs: [T, V] log-softmax output from CTC model (on any device)
        target_tokens: list of token IDs for the reference text; per the
            PyTorch CTCLoss contract these must not contain ``blank_id``
        blank_id: blank token index (usually 0)

    Returns:
        (raw_score, normalized_score) where normalized = raw / len(tokens),
        both rounded to 4 decimals. Higher is better.
        - Empty target: (0.0, 0.0).
        - Target that no alignment can produce (too few frames, or zero
          probability under log_probs): (-inf, -inf).
        Rough empirical range: raw ~ -50 to -5, norm ~ -2.0 to -0.1
    """
    T = log_probs.shape[0]
    S = len(target_tokens)

    if S == 0:
        return (0.0, 0.0)

    # A valid CTC path needs one frame per token PLUS one blank frame between
    # every pair of identical consecutive tokens (otherwise they would collapse
    # into a single token). With fewer frames the target is unreachable.
    repeats = sum(
        1 for prev, cur in zip(target_tokens, target_tokens[1:]) if prev == cur
    )
    if T < S + repeats:
        return (float('-inf'), float('-inf'))

    # Move to CPU for CTCLoss (it's fast and avoids device issues)
    log_probs_cpu = log_probs.detach().float().cpu()  # [T, V]

    # zero_infinity must stay False: with True, an impossible alignment
    # (infinite loss) is replaced by 0, which would read as a perfect score.
    ctc_loss_fn = torch.nn.CTCLoss(
        blank=blank_id, reduction='none', zero_infinity=False,
    )

    # CTCLoss expects: log_probs [T, N, C], targets [N, S], input_lengths [N], target_lengths [N]
    loss = ctc_loss_fn(
        log_probs_cpu.unsqueeze(1),                          # [T, 1, V]
        torch.tensor([target_tokens], dtype=torch.long),     # [1, S]
        torch.tensor([T], dtype=torch.long),                 # [1]
        torch.tensor([S], dtype=torch.long),                 # [1]
    )

    # An infinite loss becomes -inf here, matching the early-return convention.
    raw_score = -loss.item()
    normalized_score = raw_score / S

    return (round(raw_score, 4), round(normalized_score, 4))


def character_error_rate(reference: str, hypothesis: str) -> float:
    """
    Compute Character Error Rate between reference and hypothesis.

    CER = Levenshtein distance (insert/delete/substitute, unit cost) between
    the whitespace-stripped strings, divided by the stripped reference length.
    Returns 0.0 (perfect match) to 1.0+ (completely wrong); values above 1.0
    occur when the hypothesis has many insertions.
    Used as CTC score fallback when log-probabilities aren't accessible.

    Edge cases (after stripping surrounding whitespace):
        - empty reference and empty hypothesis -> 0.0
        - empty reference, non-empty hypothesis -> 1.0
        - non-empty reference, empty hypothesis -> 1.0
    """
    # Strip BEFORE the emptiness checks so a whitespace-only reference is
    # treated as empty rather than reaching a zero-length denominator.
    ref = reference.strip()
    hyp = hypothesis.strip()

    if not ref:
        return 0.0 if not hyp else 1.0
    if not hyp:
        return 1.0

    # Single-row Levenshtein DP: before processing ref[i-1], row[j] is the
    # distance between ref[:i-1] and hyp[:j]; `diag` carries the previous
    # row's value at column j-1.
    row = list(range(len(hyp) + 1))
    for i, ref_ch in enumerate(ref, start=1):
        diag = row[0]
        row[0] = i
        for j, hyp_ch in enumerate(hyp, start=1):
            above = row[j]
            if ref_ch == hyp_ch:
                row[j] = diag
            else:
                row[j] = 1 + min(diag, above, row[j - 1])
            diag = above

    return row[-1] / len(ref)
