"""
Per-worker pipeline orchestrator:
claim video -> download tar -> polish audio -> batch-cycle through segments -> pack results -> upload -> mark done.
"""
from __future__ import annotations

import asyncio
import logging
import tempfile
import time
from pathlib import Path
from typing import Optional

from .audio_polish import polish_all_segments, PolishedSegment
from .batch_cycle import BatchCycleEngine, BatchResult
from .config import WORKER_BATCH_SIZE, BATCH_INTERVAL_SECONDS, EnvConfig
from .db import MockDB, PostgresDB, VideoTask, WorkerStats
from .providers.base import BaseProvider, TranscriptionRequest
from .r2_client import R2Client, ExtractedVideo

logger = logging.getLogger(__name__)


class Pipeline:
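    """Run the end-to-end pipeline for a single worker.

    Illustrative usage (construction of the collaborators is assumed to
    happen in the worker entrypoint):

        pipeline = Pipeline(config, db, r2, primary, fallback, worker_id, stats)
        ok = await pipeline.process_video(task)
    """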

    def __init__(
        self,
        config: EnvConfig,
        db: MockDB | PostgresDB,
        r2: R2Client,
        primary_provider: BaseProvider,
        fallback_provider: Optional[BaseProvider],
        worker_id: str,
        stats: WorkerStats,
    ):
        self.config = config
        self.db = db
        self.r2 = r2
        self.primary = primary_provider
        self.fallback = fallback_provider
        self.worker_id = worker_id
        self.stats = stats

    async def process_video(self, task: VideoTask) -> bool:
        """Process a single video end-to-end. Returns True on success."""
        video_id = task.video_id
        language = task.language

        # Reuse prefetch dir if available, otherwise create fresh
        if task.prefetch_dir and task.prefetch_dir.exists():
            work_dir = task.prefetch_dir
            logger.info(f"[{video_id}] Using prefetched download at {work_dir}")
        else:
            work_dir = Path(tempfile.mkdtemp(prefix=f"worker_{self.worker_id}_"))

        try:
            logger.info(f"[{video_id}] Starting pipeline, language={language}")
            self.stats.current_video_id = video_id

            loop = asyncio.get_running_loop()

            # STEP 1: Download tar (skip if prefetched tar already exists)
            tar_path = work_dir / f"{video_id}.tar"
            if tar_path.exists():
                logger.info(f"[{video_id}] Tar already downloaded (prefetch), extracting...")
            else:
                logger.info(f"[{video_id}] Downloading tar from R2...")
                tar_path = await loop.run_in_executor(
                    None, self.r2.download_tar, video_id, work_dir)
            extracted = await loop.run_in_executor(
                None, self.r2.extract_tar, tar_path, video_id)
            logger.info(f"[{video_id}] Extracted {len(extracted.segment_paths)} segments")

            if not extracted.segment_paths:
                logger.warning(f"[{video_id}] No segments found, marking done")
                await self.db.mark_video_done(video_id)
                return True

            # Use language from metadata if available, fall back to task language
            language = extracted.language or language

            # STEP 2: Polish audio segments (uses threads internally; run in the executor to avoid blocking the event loop)
            logger.info(f"[{video_id}] Polishing {len(extracted.segment_paths)} segments...")
            polished = await loop.run_in_executor(
                None, polish_all_segments, extracted.segment_paths)
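            # polish_all_segments flags unusable audio via trim_meta.discarded; keep only usable segments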
            valid_segments = [p for p in polished if not p.trim_meta.discarded]
            discarded = len(polished) - len(valid_segments)
            logger.info(f"[{video_id}] Polished: {len(valid_segments)} valid, {discarded} discarded")

            if not valid_segments:
                logger.warning(f"[{video_id}] All segments discarded after polishing")
                await self.db.mark_video_done(video_id)
                return True

            self.stats.segments_remaining = len(valid_segments)

            # STEP 3: Batch cycle through segments
            batch_engine = BatchCycleEngine(
                primary_provider=self.primary,
                fallback_provider=self.fallback,
                worker_id=self.worker_id,
                video_id=video_id,
            )

            all_transcription_jsons: dict[str, dict] = {}
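            # Ceiling division: the final batch may hold fewer than WORKER_BATCH_SIZE segments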
            total_batches = (len(valid_segments) + WORKER_BATCH_SIZE - 1) // WORKER_BATCH_SIZE

            for batch_idx in range(total_batches):
                batch_start_time = time.monotonic()
                start = batch_idx * WORKER_BATCH_SIZE
                end = min(start + WORKER_BATCH_SIZE, len(valid_segments))
                batch_segments = valid_segments[start:end]

                # Build requests
                requests = []
                audio_durations = {}
                trim_metas = {}

                for seg in batch_segments:
                    seg_id = seg.trim_meta.original_file
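                    # Split segments share an original_file; append the split index so segment IDs stay unique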
                    if seg.trim_meta.was_split:
                        seg_id = f"{seg.trim_meta.original_file}_split{seg.trim_meta.split_index}"

                    requests.append(TranscriptionRequest(
                        segment_id=seg_id,
                        audio_base64=seg.base64_audio,
                        language_code=language,
                        original_file=seg.trim_meta.original_file,
                    ))
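                    # trim_meta durations are in milliseconds; convert to seconds for run_batch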
                    audio_durations[seg_id] = seg.trim_meta.final_duration_ms / 1000
                    trim_metas[seg_id] = {
                        "abrupt_start": seg.trim_meta.abrupt_start,
                        "abrupt_end": seg.trim_meta.abrupt_end,
                        "was_split": seg.trim_meta.was_split,
                    }

                logger.info(f"[{video_id}] Batch {batch_idx+1}/{total_batches}: {len(requests)} segments")
                batch_result = await batch_engine.run_batch(
                    requests=requests,
                    expected_language=language,
                    audio_durations=audio_durations,
                    trim_metas=trim_metas,
                )

                # Collect transcription JSONs for tar packing
                for resp in batch_result.responses:
                    if resp.transcription_data:
                        all_transcription_jsons[resp.segment_id] = resp.transcription_data

                # Insert results to DB
                if batch_result.transcription_records:
                    await self.db.insert_results(batch_result.transcription_records)
                if batch_result.flag_records:
                    await self.db.insert_flags(batch_result.flag_records)

                # Update stats
                self.stats.segments_sent += batch_result.segments_sent
                self.stats.segments_completed += batch_result.segments_returned
                self.stats.segments_failed += batch_result.segments_error
                self.stats.segments_429 += batch_result.segments_429
                self.stats.cache_hits += batch_result.cache_hits
                self.stats.batches_completed += 1
                self.stats.segments_remaining = len(valid_segments) - end
                self.stats.avg_batch_latency_ms = batch_result.avg_latency_ms

                # Accumulate token totals for cost tracking
                for r in batch_result.responses:
                    if r.token_usage.total_tokens > 0:
                        self.stats.total_input_tokens += r.token_usage.input_tokens
                        self.stats.total_output_tokens += r.token_usage.output_tokens
                        self.stats.total_cached_tokens += r.token_usage.cached_tokens

                # Compute active RPM/TPM (requests and tokens per minute) over this batch's wall-clock time
                batch_elapsed_s = time.monotonic() - batch_start_time
                if batch_elapsed_s > 0:
                    self.stats.active_rpm = (batch_result.segments_sent / batch_elapsed_s) * 60
                    total_tokens = sum(
                        r.token_usage.total_tokens for r in batch_result.responses
                        if r.token_usage.total_tokens > 0
                    )
                    self.stats.active_tpm = (total_tokens / batch_elapsed_s) * 60

                # Rate limit: if the batch finished early, wait out the rest of the batch interval
                remaining_wait = BATCH_INTERVAL_SECONDS - batch_elapsed_s
                if remaining_wait > 0 and batch_idx < total_batches - 1:
                    logger.debug(f"[{video_id}] Waiting {remaining_wait:.1f}s for next batch")
                    await asyncio.sleep(remaining_wait)

            # STEP 4: Pack results and upload (run in the executor so the event loop isn't blocked)
            logger.info(f"[{video_id}] Packing results tar...")
            path_index = {p.name: p for p in extracted.segment_paths}
            valid_paths = [path_index[s.trim_meta.original_file]
                           for s in valid_segments if s.trim_meta.original_file in path_index]

            result_tar = await loop.run_in_executor(
                None,
                self.r2.pack_results_tar,
                video_id, extracted.work_dir, valid_paths, all_transcription_jsons,
                {**extracted.metadata, "transcription_summary": {
                    "total_segments": len(polished),
                    "valid_segments": len(valid_segments),
                    "discarded_segments": discarded,
                    "transcribed_segments": len(all_transcription_jsons),
                    "worker_id": self.worker_id,
                }},
            )

            logger.info(f"[{video_id}] Uploading result tar to R2...")
            await loop.run_in_executor(None, self.r2.upload_tar, result_tar, video_id)

            # STEP 5: Mark done
            await self.db.mark_video_done(video_id)
            self.stats.current_video_id = None
            logger.info(f"[{video_id}] Pipeline complete!")
            return True

        except Exception as e:
            logger.error(f"[{video_id}] Pipeline failed: {e}", exc_info=True)
            await self.db.mark_video_failed(video_id, str(e))
            await self.db.set_worker_error(self.worker_id, str(e))
            self.stats.errors.append(str(e))
            self.stats.current_video_id = None
            return False

        finally:
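            # Always remove the working directory, including a consumed prefetch dir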
            self.r2.cleanup(work_dir)
