"""
Validator Runner
================

Coordinates running multiple validators on audio segments.
Supports toggling individual validators and collecting aggregated results.
"""
import json
import os
import time
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

from .base import BaseValidator, ValidationResult, normalize_language_code


@dataclass
class RunnerConfig:
    """Settings that decide which validators run, with what defaults, and where output goes."""

    # --- validator toggles ---------------------------------------------------
    enable_indicmfa: bool = True  # IndicMFA forced-alignment validator
    enable_indic_conformer: bool = True  # IndicConformer ASR validator

    # --- defaults shared by every validator ----------------------------------
    language: str = "te"  # language code used when a call does not supply one
    device: str = "auto"  # "cuda", "cpu" or "auto"

    # --- persistence ---------------------------------------------------------
    output_dir: str = "./validation_results"  # directory for saved JSON results
    # NOTE(review): stored but not read by the runner code in this module.
    save_results: bool = True

    # --- extra keyword arguments forwarded to each validator's constructor ---
    indicmfa_options: Dict[str, Any] = field(default_factory=dict)
    indic_conformer_options: Dict[str, Any] = field(default_factory=dict)


@dataclass
class AggregatedResult:
    """Combined outcome of every validator that ran against one audio file."""
    audio_path: str
    reference_text: Optional[str]
    language: str

    # Validator name -> the ValidationResult that validator produced
    results: Dict[str, ValidationResult] = field(default_factory=dict)

    # Total elapsed seconds for running all selected validators on this file
    total_processing_time_sec: float = 0.0

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialise to a plain dict suitable for ``json.dump``.

        Each per-validator result is converted through its own ``to_dict``;
        key order is fixed so saved JSON files have a stable layout.
        """
        serialized_results: Dict[str, Any] = {}
        for validator_name, validator_result in self.results.items():
            serialized_results[validator_name] = validator_result.to_dict()

        payload: Dict[str, Any] = {
            "audio_path": self.audio_path,
            "reference_text": self.reference_text,
            "language": self.language,
            "total_processing_time_sec": self.total_processing_time_sec,
            "results": serialized_results,
        }
        return payload


class ValidatorRunner:
    """
    Runs multiple validators on audio files.

    Validators are constructed lazily on the first call to ``validate``,
    ``validate_batch`` (via ``validate``) or ``get_validator_info``, so
    creating a runner does not itself instantiate any validator.

    Example:
    ```python
    runner = ValidatorRunner(
        enable_indicmfa=True,
        enable_indic_conformer=True,
        language="te"
    )

    result = runner.validate(
        audio_path="audio.flac",
        reference_text="reference transcription",
        language="te"
    )
    ```
    """

    def __init__(
        self,
        enable_indicmfa: bool = True,
        enable_indic_conformer: bool = True,
        language: str = "te",
        device: str = "auto",
        output_dir: str = "./validation_results",
        **kwargs
    ):
        """
        Initialize the validator runner.

        Args:
            enable_indicmfa: Enable IndicMFA validator (forced alignment)
            enable_indic_conformer: Enable IndicConformer validator (ASR)
            language: Default language code
            device: Device for ML models ("cuda", "cpu", "auto")
            output_dir: Directory to save results
            **kwargs: Per-validator option dicts such as ``indicmfa_options``
                and ``indic_conformer_options``. Keys ending in ``_options``
                are forwarded to ``RunnerConfig``; any other key is ignored
                and reported with a warning.
        """
        option_kwargs = {k: v for k, v in kwargs.items() if k.endswith('_options')}
        ignored = sorted(k for k in kwargs if k not in option_kwargs)
        if ignored:
            # Surface these so typos in caller options are not silently lost.
            print(f"[Runner] Warning: ignoring unknown options: {', '.join(ignored)}")

        self.config = RunnerConfig(
            enable_indicmfa=enable_indicmfa,
            enable_indic_conformer=enable_indic_conformer,
            language=language,
            device=device,
            output_dir=output_dir,
            **option_kwargs
        )

        self.validators: Dict[str, BaseValidator] = {}
        self._initialized = False

    def _init_validators(self):
        """
        Instantiate every enabled validator once.

        A validator whose module cannot be imported is skipped with a message
        rather than aborting the whole runner. Runner-level defaults
        (``enabled``, ``device``) are merged with the per-validator options
        dict, and the options dict wins on conflicts. This avoids a
        "multiple values for keyword argument" TypeError.
        """
        if self._initialized:
            return

        print("[Runner] Initializing validators...")

        # IndicMFA (Forced Alignment)
        if self.config.enable_indicmfa:
            try:
                from .indicmfa_validator import IndicMFAValidator
                mfa_kwargs = {"enabled": True, **self.config.indicmfa_options}
                self.validators["indicmfa"] = IndicMFAValidator(**mfa_kwargs)
                print("[Runner] ✓ IndicMFA validator added")
            except ImportError as e:
                print(f"[Runner] ✗ IndicMFA unavailable: {e}")

        # IndicConformer (ASR)
        if self.config.enable_indic_conformer:
            try:
                from .indic_conformer_validator import IndicConformerValidator
                conformer_kwargs = {
                    "enabled": True,
                    "device": self.config.device,
                    **self.config.indic_conformer_options,
                }
                self.validators["indic_conformer"] = IndicConformerValidator(
                    **conformer_kwargs
                )
                print("[Runner] ✓ IndicConformer validator added")
            except ImportError as e:
                print(f"[Runner] ✗ IndicConformer unavailable: {e}")

        print(f"[Runner] {len(self.validators)} validators initialized")
        self._initialized = True

    def get_validator_info(self) -> Dict[str, Any]:
        """Return ``{validator_name: validator.get_info()}`` for every initialized validator."""
        self._init_validators()
        return {
            name: validator.get_info()
            for name, validator in self.validators.items()
        }

    def validate(
        self,
        audio_path: str,
        reference_text: Optional[str] = None,
        language: Optional[str] = None,
        validators: Optional[List[str]] = None
    ) -> AggregatedResult:
        """
        Run validation on a single audio file.

        Args:
            audio_path: Path to audio file
            reference_text: Reference transcription (required for MFA)
            language: Language code (defaults to config)
            validators: Specific validators to run. ``None`` runs all
                initialized validators; an explicit list (even an empty one)
                is honoured as given. Unknown names are skipped with a warning.

        Returns:
            AggregatedResult with results from all validators. A validator
            that raises is recorded as a failed ValidationResult instead of
            propagating the exception.
        """
        self._init_validators()

        lang = normalize_language_code(language or self.config.language)
        # perf_counter is monotonic, so durations are immune to clock changes.
        start_time = time.perf_counter()

        result = AggregatedResult(
            audio_path=audio_path,
            reference_text=reference_text,
            language=lang
        )

        if validators is None:
            validators_to_run = list(self.validators)
        else:
            validators_to_run = list(validators)

        for name in validators_to_run:
            validator = self.validators.get(name)
            if validator is None:
                print(f"[Runner] Warning: Unknown validator '{name}'")
                continue

            if not validator.enabled:
                continue

            print(f"[Runner] Running {name}...")

            try:
                validation_result = validator.validate(
                    audio_path=audio_path,
                    reference_text=reference_text,
                    language=lang
                )
            except Exception as e:
                # Isolate per-validator failures so one crash does not lose
                # the results of the others.
                print(f"[Runner] ✗ {name} failed: {e}")
                result.results[name] = ValidationResult(
                    validator_name=name,
                    audio_path=audio_path,
                    success=False,
                    error_message=str(e)
                )
                continue

            result.results[name] = validation_result
            status = "✓" if validation_result.success else "✗"
            proc_time = validation_result.processing_time_sec or 0.0
            print(f"[Runner] {status} {name}: {proc_time:.2f}s")

        result.total_processing_time_sec = time.perf_counter() - start_time
        return result

    def validate_batch(
        self,
        audio_paths: List[str],
        reference_texts: Optional[List[str]] = None,
        language: Optional[str] = None,
        validators: Optional[List[str]] = None,
        progress_callback: Optional[Callable[[int, int], None]] = None
    ) -> List[AggregatedResult]:
        """
        Run validation on multiple audio files, one after another.

        Args:
            audio_paths: List of audio file paths
            reference_texts: Reference transcriptions, parallel to
                ``audio_paths``. ``None`` or empty means no references.
            language: Language code
            validators: Specific validators to run
            progress_callback: Called with (current, total) before each file;
                when omitted, a progress line is printed instead.

        Returns:
            List of AggregatedResults, in the same order as ``audio_paths``.

        Raises:
            ValueError: If ``reference_texts`` is non-empty and its length
                differs from ``audio_paths``. ``zip`` would otherwise
                silently drop the unmatched files.
        """
        total = len(audio_paths)
        if not reference_texts:
            ref_texts: List[Optional[str]] = [None] * total
        elif len(reference_texts) != total:
            raise ValueError(
                f"reference_texts has {len(reference_texts)} entries but "
                f"audio_paths has {total}"
            )
        else:
            ref_texts = list(reference_texts)

        print(f"[Runner] Processing {total} files...")

        results: List[AggregatedResult] = []
        for i, (audio_path, ref_text) in enumerate(zip(audio_paths, ref_texts), start=1):
            if progress_callback:
                progress_callback(i, total)
            else:
                print(f"[Runner] File {i}/{total}: {Path(audio_path).name}")

            results.append(self.validate(
                audio_path=audio_path,
                reference_text=ref_text,
                language=language,
                validators=validators
            ))

        return results

    def save_results(
        self,
        results: List[AggregatedResult],
        filename: Optional[str] = None
    ) -> str:
        """
        Save validation results to a JSON file under ``config.output_dir``.

        Args:
            results: List of AggregatedResults
            filename: Optional filename (auto-generated from the current
                time if None)

        Returns:
            Path to saved file
        """
        os.makedirs(self.config.output_dir, exist_ok=True)

        # Read the clock once so the filename stamp and stored timestamp agree.
        now = datetime.now()
        if filename is None:
            filename = f"validation_results_{now.strftime('%Y%m%d_%H%M%S')}.json"

        output_path = os.path.join(self.config.output_dir, filename)

        output_data = {
            "timestamp": now.isoformat(),
            "config": {
                "language": self.config.language,
                "validators": list(self.validators.keys())
            },
            "results_count": len(results),
            "results": [r.to_dict() for r in results]
        }

        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(output_data, f, ensure_ascii=False, indent=2)

        print(f"[Runner] Results saved to: {output_path}")
        return output_path

    def cleanup(self):
        """
        Release resources from all validators (best-effort).

        Each validator's ``cleanup`` is attempted even if an earlier one
        raises; failures are reported rather than propagated.
        """
        for name, validator in self.validators.items():
            try:
                validator.cleanup()
            except Exception as e:
                print(f"[Runner] ✗ {name} cleanup failed: {e}")
        print("[Runner] Cleanup complete")


def run_all_validators(
    audio_path: str,
    reference_text: Optional[str] = None,
    language: str = "te",
    **kwargs
) -> AggregatedResult:
    """
    Convenience function to run all validators on a single file.

    A fresh ValidatorRunner is created for the call. Its ``cleanup`` is
    always invoked, even when validation raises, so validator resources are
    not leaked.

    Args:
        audio_path: Path to audio file
        reference_text: Reference transcription
        language: Language code
        **kwargs: Additional options for ValidatorRunner

    Returns:
        AggregatedResult
    """
    runner = ValidatorRunner(language=language, **kwargs)
    try:
        return runner.validate(
            audio_path=audio_path,
            reference_text=reference_text,
            language=language
        )
    finally:
        runner.cleanup()


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Run transcription validators")
    parser.add_argument("audio_path", help="Path to audio file")
    parser.add_argument("--reference", "-r", help="Reference transcription")
    parser.add_argument("--language", "-l", default="te", help="Language code")
    parser.add_argument("--output", "-o", default="./validation_results",
                        help="Output directory")
    parser.add_argument("--disable-mfa", action="store_true",
                        help="Skip the IndicMFA forced-alignment validator")
    parser.add_argument("--disable-conformer", action="store_true",
                        help="Skip the IndicConformer ASR validator")

    args = parser.parse_args()

    runner = ValidatorRunner(
        enable_indicmfa=not args.disable_mfa,
        enable_indic_conformer=not args.disable_conformer,
        language=args.language,
        output_dir=args.output
    )

    # Always release validator resources, even if validation or saving fails.
    try:
        result = runner.validate(
            audio_path=args.audio_path,
            reference_text=args.reference,
            language=args.language
        )

        # Print summary
        print(f"\n{'='*60}")
        print("Validation Summary")
        print(f"{'='*60}")
        print(f"Audio: {args.audio_path}")
        print(f"Language: {args.language}")
        print(f"Total time: {result.total_processing_time_sec:.2f}s")
        print()

        preview_len = 100
        for name, vr in result.results.items():
            status = "✓" if vr.success else "✗"
            print(f"{status} {name}:")
            if vr.success:
                if vr.transcription:
                    preview = vr.transcription[:preview_len]
                    # Only mark the preview as truncated when it actually is.
                    if len(vr.transcription) > preview_len:
                        preview += "..."
                    print(f"   Transcription: {preview}")
                if vr.word_alignments:
                    print(f"   Word alignments: {len(vr.word_alignments)} words")
                # A confidence of 0.0 is still a value worth reporting;
                # only skip it when it is absent.
                if vr.overall_confidence is not None:
                    print(f"   Confidence: {vr.overall_confidence:.3f}")
            else:
                print(f"   Error: {vr.error_message}")

        # Save results
        runner.save_results([result])
    finally:
        runner.cleanup()
