"""
Base classes for transcription validators.
All validators inherit from BaseValidator and produce ValidationResult objects.
"""
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import List, Optional, Dict, Any
from enum import Enum


class ValidatorStatus(Enum):
    """Lifecycle states a validator can be in.

    Each member's value is the lowercase string form of its name.
    """

    # No setup attempt has been recorded yet.
    NOT_INITIALIZED = "not_initialized"
    # Setup completed successfully.
    READY = "ready"
    # Setup failed or raised.
    ERROR = "error"
    # The validator is switched off.
    DISABLED = "disabled"


@dataclass
class WordAlignment:
    """Timing (and optional confidence) of one word within an audio segment."""

    word: str
    # Word boundaries, both expressed in seconds.
    start_time: float
    end_time: float
    # Score in the 0-1 range when the producer supplies one, otherwise None.
    confidence: Optional[float] = None

    @property
    def duration(self) -> float:
        """Length of the word in seconds (end_time minus start_time)."""
        begin, finish = self.start_time, self.end_time
        return finish - begin


@dataclass
class ValidationResult:
    """Outcome produced by one validator for one audio segment."""

    # Which validator produced this result, and for which file.
    validator_name: str
    audio_path: str

    # Whether the validator succeeded; error_message carries details if not.
    success: bool
    error_message: Optional[str] = None

    # Text produced by ASR-based validators.
    transcription: Optional[str] = None

    # Per-word timings produced by alignment-based validators.
    word_alignments: List[WordAlignment] = field(default_factory=list)

    # Confidence / score values.
    overall_confidence: Optional[float] = None
    alignment_score: Optional[float] = None  # Average alignment confidence

    # Timing statistics, in seconds.
    processing_time_sec: Optional[float] = None
    audio_duration_sec: Optional[float] = None

    # Raw validator output kept for debugging; not emitted by to_dict().
    raw_output: Optional[Dict[str, Any]] = None

    @property
    def words_per_second(self) -> Optional[float]:
        """Speaking rate: whitespace-separated words per second of audio.

        Returns None when the transcription is missing/empty or when the
        audio duration is missing, zero, or not greater than zero.
        """
        text = self.transcription
        seconds = self.audio_duration_sec
        if not text or not seconds:
            return None
        if seconds > 0:
            return len(text.split()) / seconds
        return None

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict view of this result for serialization.

        Word alignments are expanded into dictionaries that also carry the
        derived ``duration``; the derived ``words_per_second`` is included
        as well.
        """
        alignment_rows = []
        for item in self.word_alignments:
            alignment_rows.append(
                {
                    "word": item.word,
                    "start_time": item.start_time,
                    "end_time": item.end_time,
                    "confidence": item.confidence,
                    "duration": item.duration,
                }
            )

        payload: Dict[str, Any] = {
            "validator_name": self.validator_name,
            "audio_path": self.audio_path,
            "success": self.success,
            "error_message": self.error_message,
            "transcription": self.transcription,
            "word_alignments": alignment_rows,
            "overall_confidence": self.overall_confidence,
            "alignment_score": self.alignment_score,
            "processing_time_sec": self.processing_time_sec,
            "audio_duration_sec": self.audio_duration_sec,
            "words_per_second": self.words_per_second,
        }
        return payload


class BaseValidator(ABC):
    """
    Abstract base class for all transcription validators.

    Each validator should:
    1. Initialize model/resources in __init__ or setup()
    2. Implement validate() to process a single audio file
    3. Return ValidationResult with relevant metrics

    Note: ``enabled`` is stored and reported by get_info(), but none of the
    methods defined on this base class consult it.
    """

    name: str = "base_validator"
    description: str = "Base validator class"

    def __init__(self, enabled: bool = True, **kwargs):
        """
        Initialize the validator.

        Args:
            enabled: Whether this validator is active
            **kwargs: Validator-specific configuration, kept as ``self.config``
        """
        self.enabled = enabled
        self.status = ValidatorStatus.NOT_INITIALIZED
        self.config = kwargs
        self._model = None

    @abstractmethod
    def setup(self) -> bool:
        """
        Initialize models and resources.
        Invoked by ensure_setup() while the status is NOT_INITIALIZED.

        Returns:
            True if setup successful, False otherwise
        """
        pass

    @abstractmethod
    def validate(
        self,
        audio_path: str,
        reference_text: Optional[str] = None,
        language: str = "te"  # ISO 639-1 code
    ) -> ValidationResult:
        """
        Validate/process a single audio file.

        Args:
            audio_path: Path to audio file
            reference_text: Optional reference transcription for alignment
            language: Language code (te=Telugu, hi=Hindi, kn=Kannada, etc.)

        Returns:
            ValidationResult with transcription/alignment data
        """
        pass

    def validate_batch(
        self,
        audio_paths: List[str],
        reference_texts: Optional[List[str]] = None,
        language: str = "te"
    ) -> List[ValidationResult]:
        """
        Validate multiple audio files.
        Default implementation calls validate() for each file.
        Override for batch-optimized processing.

        Args:
            audio_paths: Paths of the audio files to process, in order
            reference_texts: Optional reference transcriptions, one per audio
                path. None or an empty list means "no references", in which
                case every file is validated with ``reference_text=None``.
            language: Language code forwarded to every validate() call

        Returns:
            One ValidationResult per entry of ``audio_paths``, in input order

        Raises:
            ValueError: If a non-empty ``reference_texts`` does not have the
                same length as ``audio_paths``
        """
        if reference_texts:
            # zip() stops at the shorter input, which would silently skip
            # files and return fewer results than inputs; fail loudly instead.
            if len(reference_texts) != len(audio_paths):
                raise ValueError(
                    f"[{self.name}] reference_texts has {len(reference_texts)} "
                    f"item(s) but audio_paths has {len(audio_paths)}"
                )
            ref_texts = reference_texts
        else:
            ref_texts = [None] * len(audio_paths)

        return [
            self.validate(audio_path, ref_text, language)
            for audio_path, ref_text in zip(audio_paths, ref_texts)
        ]

    def ensure_setup(self) -> bool:
        """
        Ensure validator is set up before use.

        setup() is attempted only while the status is NOT_INITIALIZED; any
        other status is left untouched, so a failed setup is not retried.

        Returns:
            True if the validator status is READY, False otherwise
        """
        if self.status == ValidatorStatus.NOT_INITIALIZED:
            try:
                success = self.setup()
                self.status = ValidatorStatus.READY if success else ValidatorStatus.ERROR
            except Exception as e:
                # A raising setup() is recorded as ERROR rather than propagated.
                self.status = ValidatorStatus.ERROR
                print(f"[{self.name}] Setup failed: {e}")
                return False
        return self.status == ValidatorStatus.READY

    def cleanup(self):
        """Release resources by dropping the model reference. Override if needed."""
        self._model = None

    def get_info(self) -> Dict[str, Any]:
        """
        Get validator information.

        Returns:
            Dict with name, description, enabled flag, status value (string)
            and the configuration kwargs passed at construction
        """
        return {
            "name": self.name,
            "description": self.description,
            "enabled": self.enabled,
            "status": self.status.value,
            "config": self.config
        }


# Lookup table from lowercase language name to its two-letter code.
LANGUAGE_CODES = {
    "telugu": "te",
    "hindi": "hi",
    "kannada": "kn",
    "tamil": "ta",
    "malayalam": "ml",
    "bengali": "bn",
    "gujarati": "gu",
    "marathi": "mr",
    "punjabi": "pa",
    "odia": "or",
    "english": "en",
}
# Codes are accepted as input too: map every code to itself. The
# comprehension is fully built before update() mutates the table.
LANGUAGE_CODES.update({code: code for code in LANGUAGE_CODES.values()})


def normalize_language_code(language: str) -> str:
    """Map a language name or code to its LANGUAGE_CODES entry.

    The lookup is case-insensitive. Input with no entry in LANGUAGE_CODES is
    returned lowercased.
    """
    key = language.lower()
    return LANGUAGE_CODES.get(key, key)
