#!/usr/bin/env python3
"""
Detailed Pipeline Profiler for Veena3 TTS.

Instruments every stage of the TTS pipeline with microsecond precision:
- Prompt building
- vLLM prefill (prompt encoding)
- vLLM decode (per-token generation)
- Token parsing (BiCodec extraction)
- BiCodec audio decode (GPU)
- Crossfade / audio post-processing
- WAV header creation

For streaming: tracks token-by-token timing to find bottleneck.
For concurrency: measures contention points under parallel load.

Usage:
    # Single request breakdown
    python scripts/profile_pipeline.py

    # Concurrency profiling (1, 5, 10, 20 parallel requests)
    python scripts/profile_pipeline.py --concurrency 1,5,10,20

    # Long text profiling
    python scripts/profile_pipeline.py --text long

    # Streaming profiling
    python scripts/profile_pipeline.py --stream
"""

from __future__ import annotations

import argparse
import asyncio
import logging
import os
import re
import sys
import time
import uuid
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple

# === Path setup (same as local_server.py) ===
# Make the vendored sparktts / AP-BWE trees and the repo root importable
# before any project modules are imported below.
REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
EXTERNAL_DIR = os.path.join(REPO_ROOT, "external")
for path in [os.path.join(EXTERNAL_DIR, "sparktts"), os.path.join(EXTERNAL_DIR, "AP-BWE"), REPO_ROOT]:
    if path not in sys.path:
        sys.path.insert(0, path)

# Disable auth for local profiling — presumably read at import time by the
# project modules, so it is set before any of them are imported (TODO confirm).
os.environ["AUTH_BYPASS_MODE"] = "true"

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)-8s | %(message)s",
    datefmt="%H:%M:%S",
)
logger = logging.getLogger("profiler")

# Test texts; keys match the --text CLI choices (short / medium / long).
TEXTS = {
    "short": "Hello, this is a quick test.",
    "medium": (
        "The quick brown fox jumps over the lazy dog. "
        "This sentence tests basic text normalization and generation."
    ),
    "long": (
        "In the heart of a bustling city, where towering skyscrapers cast long shadows "
        "over crowded streets, there lived an old bookkeeper named Margaret. She had spent "
        "forty years cataloging the stories of others, yet never found time to write her own. "
        "One rainy Tuesday afternoon, she discovered a leather-bound journal tucked behind "
        "a shelf of dusty encyclopedias. Its pages were blank, waiting."
    ),
}


# === Timing Utilities ===

@dataclass
class StageTimer:
    """Accumulated wall-clock timing for a single named pipeline stage."""
    name: str
    calls: int = 0
    total_ms: float = 0.0
    min_ms: float = float("inf")
    max_ms: float = 0.0
    samples: List[float] = field(default_factory=list)

    def record(self, ms: float):
        """Fold one measurement (in milliseconds) into the running aggregates."""
        self.calls += 1
        self.total_ms += ms
        if ms < self.min_ms:
            self.min_ms = ms
        if ms > self.max_ms:
            self.max_ms = ms
        self.samples.append(ms)

    @property
    def avg_ms(self) -> float:
        """Mean milliseconds per call (0 when the stage was never called)."""
        if self.calls > 0:
            return self.total_ms / self.calls
        return 0

    def report_line(self) -> str:
        """One formatted line summarizing this stage for the report output."""
        if not self.calls:
            return f"  {self.name:<35} (not called)"
        if self.calls == 1:
            return f"  {self.name:<35} {self.total_ms:>8.2f}ms"
        parts = [
            f"  {self.name:<35} {self.total_ms:>8.2f}ms total  ",
            f"avg={self.avg_ms:.2f}ms  min={self.min_ms:.2f}ms  ",
            f"max={self.max_ms:.2f}ms  calls={self.calls}",
        ]
        return "".join(parts)


@dataclass
class RequestProfile:
    """Complete timing profile for one TTS request.

    Collects named ``StageTimer``s plus request-level metadata, and renders
    a human-readable breakdown (sorted by total stage time) via ``report()``.
    """
    request_id: str
    text: str
    speaker: str
    stream: bool
    stages: Dict[str, StageTimer] = field(default_factory=dict)

    # Metadata filled during profiling
    total_ms: float = 0.0
    prompt_tokens: int = 0
    generated_tokens: int = 0
    semantic_tokens: int = 0
    global_tokens: int = 0
    audio_bytes: int = 0
    audio_seconds: float = 0.0

    # Streaming-specific
    ttfb_ms: float = 0.0
    chunks_emitted: int = 0
    token_generation_rates: List[float] = field(default_factory=list)  # tokens/sec samples
    decode_latencies: List[float] = field(default_factory=list)  # ms per BiCodec decode

    def stage(self, name: str) -> StageTimer:
        """Return the timer for *name*, creating it on first use."""
        if name not in self.stages:
            self.stages[name] = StageTimer(name=name)
        return self.stages[name]

    def report(self) -> str:
        """Render this profile as a multi-line, human-readable report string."""
        import statistics  # single local import (was duplicated per section)

        lines = [
            f"\n{'=' * 80}",
            f"  REQUEST PROFILE: {self.request_id}",
            f"  text={len(self.text)} chars, speaker={self.speaker}, stream={self.stream}",
            f"{'=' * 80}",
            f"  TOTAL: {self.total_ms:.1f}ms",
            f"  Generated: {self.generated_tokens} tokens ({self.semantic_tokens} semantic, {self.global_tokens} global)",
            f"  Audio: {self.audio_seconds:.2f}s ({self.audio_bytes} bytes)",
        ]
        if self.stream:
            lines.append(f"  TTFB: {self.ttfb_ms:.1f}ms, Chunks: {self.chunks_emitted}")
        if self.generated_tokens > 0 and self.total_ms > 0:
            tps = self.generated_tokens / (self.total_ms / 1000)
            lines.append(f"  Token gen rate: {tps:.1f} tok/s")
        if self.audio_seconds > 0:
            # Real-time factor: wall time / audio duration (lower is better).
            rtf = (self.total_ms / 1000) / self.audio_seconds
            lines.append(f"  RTF: {rtf:.3f}")

        lines.append("\n  STAGE BREAKDOWN")
        lines.append(f"  {'-' * 76}")

        # Order stages by total time (descending) so hotspots appear first.
        sorted_stages = sorted(self.stages.values(), key=lambda s: s.total_ms, reverse=True)
        for s in sorted_stages:
            lines.append(s.report_line())

        # Token generation rate over time (streaming)
        if self.token_generation_rates:
            lines.append(f"\n  TOKEN GEN RATE (tok/s over time):")
            # Show the first 5 and last 5 samples when the series is long
            rates = self.token_generation_rates
            if len(rates) <= 12:
                lines.append(f"    {' → '.join(f'{r:.0f}' for r in rates)}")
            else:
                first = ' → '.join(f'{r:.0f}' for r in rates[:5])
                last = ' → '.join(f'{r:.0f}' for r in rates[-5:])
                lines.append(f"    {first} → ... → {last}")
            lines.append(f"    avg={statistics.mean(rates):.0f}  min={min(rates):.0f}  max={max(rates):.0f}")

        # BiCodec decode latencies
        if self.decode_latencies:
            lines.append(f"\n  BICODEC DECODE LATENCIES (ms per call):")
            lines.append(
                f"    avg={statistics.mean(self.decode_latencies):.2f}  "
                f"min={min(self.decode_latencies):.2f}  "
                f"max={max(self.decode_latencies):.2f}  "
                f"calls={len(self.decode_latencies)}"
            )

        lines.append(f"{'=' * 80}")
        return "\n".join(lines)


# === Non-Streaming Profiler ===

async def profile_non_streaming(
    runtime,
    text: str,
    speaker: str = "lipakshi",
    temperature: float = 0.8,
    top_k: int = 50,
    top_p: float = 1.0,
    max_tokens: int = 4096,
    repetition_penalty: float = 1.05,
    seed: Optional[int] = None,
) -> RequestProfile:
    """Profile a single non-streaming request with stage-level timing.

    Runs the full pipeline (prompt build → vLLM generation → regex token
    extraction → validation → BiCodec decode → WAV header) against *runtime*
    and records each stage's wall-clock duration in a RequestProfile.

    Args:
        runtime: Initialized TTS runtime; must expose ``prompt_builder``,
            ``model`` (engine + tokenizer) and ``bicodec_decoder``.
        text: Input text to synthesize.
        speaker: Speaker voice name passed to the prompt builder.
        temperature, top_k, top_p, max_tokens, repetition_penalty, seed:
            Sampling knobs forwarded to vLLM's SamplingParams.

    Returns:
        A RequestProfile. Returns early (partial stages recorded) when
        generation yields nothing, no BiCodec tokens are found, or audio
        decoding returns None.
    """
    # Local imports: vLLM and project modules need the sys.path setup above.
    from vllm import SamplingParams
    from veena3modal.core.constants import TRAINING_STOP_TOKEN_IDS, AUDIO_SAMPLE_RATE
    from veena3modal.audio.utils import add_wav_header

    profile = RequestProfile(
        request_id=f"prof-{uuid.uuid4().hex[:8]}",
        text=text,
        speaker=speaker,
        stream=False,
    )

    t_total_start = time.perf_counter()

    # === Stage 1: Prompt Building ===
    t0 = time.perf_counter()
    prompt = runtime.prompt_builder.build_prefix(speaker, text)
    profile.stage("1_prompt_build").record((time.perf_counter() - t0) * 1000)
    # Prompt token count is informational only (encoding is not a timed stage).
    profile.prompt_tokens = len(runtime.model.tokenizer.encode(prompt, add_special_tokens=False))

    # === Stage 2: Sampling Config ===
    t0 = time.perf_counter()
    sampling_params = SamplingParams(
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        max_tokens=max_tokens,
        repetition_penalty=repetition_penalty,
        # NOTE(review): vLLM's `stop` expects strings while `stop_token_ids`
        # expects integer ids; the streaming profiler below uses
        # `stop_token_ids=[CODE_END_TOKEN_ID]`. Confirm TRAINING_STOP_TOKEN_IDS
        # holds the kind of values `stop` expects.
        stop=TRAINING_STOP_TOKEN_IDS,
        skip_special_tokens=False,
        seed=seed,
    )
    profile.stage("2_sampling_config").record((time.perf_counter() - t0) * 1000)

    # === Stage 3: vLLM Token Generation ===
    request_id = f"prof-{uuid.uuid4().hex[:8]}"
    t_gen_start = time.perf_counter()
    t_first_output = None  # set on the first yield (marks end of prefill)
    iteration_count = 0
    prev_token_count = 0

    results_generator = runtime.model.engine.generate(
        prompt=prompt,
        sampling_params=sampling_params,
        request_id=request_id,
    )

    final_output = None
    async for request_output in results_generator:
        iteration_count += 1
        now = time.perf_counter()

        if t_first_output is None:
            # Time to first yield ≈ prompt prefill latency.
            t_first_output = now
            prefill_ms = (t_first_output - t_gen_start) * 1000
            profile.stage("3a_vllm_prefill").record(prefill_ms)

        # Track per-iteration token rate. This is a *cumulative* average
        # (tokens so far / elapsed), and vLLM may emit several tokens per
        # yield, so individual samples are approximate.
        current_count = len(request_output.outputs[0].token_ids)
        new_tokens = current_count - prev_token_count
        if new_tokens > 0 and iteration_count > 1:
            dt = now - t_gen_start
            if dt > 0:
                profile.token_generation_rates.append(current_count / dt)
        prev_token_count = current_count
        final_output = request_output

    t_gen_end = time.perf_counter()
    total_gen_ms = (t_gen_end - t_gen_start) * 1000
    # Decode-phase time = total generation minus the prefill stage
    # (the .get() default contributes 0 when prefill was never recorded).
    decode_only_ms = total_gen_ms - (profile.stages.get("3a_vllm_prefill", StageTimer("")).total_ms)
    profile.stage("3b_vllm_decode").record(decode_only_ms)
    profile.stage("3_vllm_total").record(total_gen_ms)
    profile.generated_tokens = len(final_output.outputs[0].token_ids) if final_output else 0

    if not final_output:
        # Generation produced nothing; return the partial profile.
        profile.total_ms = (time.perf_counter() - t_total_start) * 1000
        return profile

    generated_text = final_output.outputs[0].text

    # === Stage 4: Token Extraction (regex) ===
    # Pull BiCodec semantic/global codebook indices out of the generated text.
    t0 = time.perf_counter()
    semantic_matches = re.findall(r"<\|bicodec_semantic_(\d+)\|>", generated_text)
    semantic_ids = [int(t) for t in semantic_matches]
    global_matches = re.findall(r"<\|bicodec_global_(\d+)\|>", generated_text)
    global_ids = [int(t) for t in global_matches]
    profile.stage("4_token_extraction").record((time.perf_counter() - t0) * 1000)
    profile.semantic_tokens = len(semantic_ids)
    profile.global_tokens = len(global_ids)

    if not semantic_ids or not global_ids:
        # Nothing decodable; return the partial profile.
        profile.total_ms = (time.perf_counter() - t_total_start) * 1000
        return profile

    # === Stage 5: Token Validation ===
    t0 = time.perf_counter()
    runtime.bicodec_decoder.validate_tokens(semantic_ids, global_ids)
    profile.stage("5_token_validation").record((time.perf_counter() - t0) * 1000)

    # === Stage 6: BiCodec Decode (GPU) ===
    t0 = time.perf_counter()
    audio_bytes = runtime.bicodec_decoder.decode_to_bytes(semantic_ids, global_ids)
    decode_ms = (time.perf_counter() - t0) * 1000
    profile.stage("6_bicodec_decode").record(decode_ms)
    profile.decode_latencies.append(decode_ms)

    if audio_bytes is None:
        # Decoder failed; return the partial profile.
        profile.total_ms = (time.perf_counter() - t_total_start) * 1000
        return profile

    # === Stage 7: WAV Header ===
    t0 = time.perf_counter()
    wav_bytes = add_wav_header(audio_bytes, sample_rate=AUDIO_SAMPLE_RATE)
    profile.stage("7_wav_header").record((time.perf_counter() - t0) * 1000)

    profile.audio_bytes = len(wav_bytes)
    profile.audio_seconds = (len(audio_bytes)) / (AUDIO_SAMPLE_RATE * 2)  # int16 = 2 bytes/sample
    profile.total_ms = (time.perf_counter() - t_total_start) * 1000

    return profile


# === Streaming Profiler ===

async def profile_streaming(
    runtime,
    text: str,
    speaker: str = "lipakshi",
    temperature: float = 0.8,
    top_k: int = 50,
    top_p: float = 1.0,
    max_tokens: int = 4096,
    repetition_penalty: float = 1.05,
    seed: Optional[int] = None,
) -> RequestProfile:
    """Profile a single streaming request with per-token and per-decode timing.

    Mirrors the streaming TTS path: token ids are parsed incrementally as
    vLLM yields them, BiCodec decode runs every DECODE_INTERVAL new semantic
    tokens (once all global tokens have arrived), and emitted audio chunks
    are crossfaded. Also samples inter-yield latency to derive per-token
    generation rates.

    Args: same as profile_non_streaming.

    Returns:
        A RequestProfile that additionally fills ttfb_ms, chunks_emitted,
        token_generation_rates and decode_latencies.
    """
    # Local imports: vLLM and project modules need the sys.path setup above.
    from vllm import SamplingParams
    from veena3modal.core.constants import CODE_END_TOKEN_ID, DEFAULT_MIN_TOKENS
    from veena3modal.core.token_utils import BiCodecTokenParser
    from veena3modal.audio.crossfade import crossfade_bytes_int16

    profile = RequestProfile(
        request_id=f"prof-stream-{uuid.uuid4().hex[:8]}",
        text=text,
        speaker=speaker,
        stream=True,
    )

    t_total_start = time.perf_counter()

    # === Stage 1: Prompt Building ===
    t0 = time.perf_counter()
    prompt = runtime.prompt_builder.build_prefix(speaker, text)
    profile.stage("1_prompt_build").record((time.perf_counter() - t0) * 1000)

    # === Stage 2: Parser Init ===
    t0 = time.perf_counter()
    tokenizer = runtime.model.tokenizer
    token_parser = BiCodecTokenParser(tokenizer)
    profile.stage("2_parser_init").record((time.perf_counter() - t0) * 1000)

    # === Stage 3: Sampling Config ===
    t0 = time.perf_counter()
    sampling_params = SamplingParams(
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        max_tokens=max_tokens,
        min_tokens=DEFAULT_MIN_TOKENS,
        repetition_penalty=repetition_penalty,
        stop_token_ids=[CODE_END_TOKEN_ID],
        seed=seed,
    )
    profile.stage("3_sampling_config").record((time.perf_counter() - t0) * 1000)

    # Incremental parse buffers: BiCodec semantic/global token ids seen so far.
    semantic_buffer: List[int] = []
    global_buffer: List[int] = []
    processed_token_count = 0  # how many generated ids have been parsed already

    EXPECTED_GLOBAL_COUNT = 32  # global tokens required before any decode
    MIN_SEMANTIC_FIRST = 16  # semantic tokens required for the first decode
    DECODE_INTERVAL = 24  # decode every N new semantic tokens after that

    request_id = f"prof-stream-{uuid.uuid4().hex[:8]}"

    # === Stage 4: vLLM Generation (streaming) ===
    t_gen_start = time.perf_counter()
    t_first_output = None  # first yield => prefill complete
    t_globals_complete = None  # when all EXPECTED_GLOBAL_COUNT globals arrived
    t_first_decode = None  # first audible chunk emitted => TTFB
    last_decode_count = 0  # semantic count at the most recent decode
    total_samples_emitted = 0  # int16 samples already handed downstream
    previous_chunk_tail: Optional[bytes] = None  # crossfade carry-over bytes
    iteration_count = 0

    results_generator = runtime.model.engine.generate(
        prompt=prompt,
        sampling_params=sampling_params,
        request_id=request_id,
    )

    # Per-token timing: time between consecutive yields from vLLM.
    # vLLM may emit more than one token per yield, so this is approximate.
    t_last_yield = t_gen_start
    inter_token_times: List[float] = []

    async for request_output in results_generator:
        now = time.perf_counter()
        iteration_count += 1

        # Track inter-yield timing (approximation of per-token time)
        inter_ms = (now - t_last_yield) * 1000
        inter_token_times.append(inter_ms)
        t_last_yield = now

        if t_first_output is None:
            # Time to first yield ≈ prompt prefill latency.
            t_first_output = now
            prefill_ms = (t_first_output - t_gen_start) * 1000
            profile.stage("4a_vllm_prefill").record(prefill_ms)

        generated_ids = request_output.outputs[0].token_ids

        # === Token Parsing === (only the ids added since the previous yield)
        t_parse = time.perf_counter()
        new_token_ids = generated_ids[processed_token_count:]
        processed_token_count = len(generated_ids)
        token_parser.parse_incremental(new_token_ids, semantic_buffer, global_buffer)
        profile.stage("4b_token_parsing").record((time.perf_counter() - t_parse) * 1000)

        # Track globals completion. Recorded as elapsed-since-gen-start, not a
        # per-call duration like the other stages.
        if t_globals_complete is None and len(global_buffer) >= EXPECTED_GLOBAL_COUNT:
            t_globals_complete = now
            globals_ms = (t_globals_complete - t_gen_start) * 1000
            profile.stage("4c_global_token_gen").record(globals_ms)

        # === BiCodec Decode (when enough tokens) ===
        if len(global_buffer) >= EXPECTED_GLOBAL_COUNT:
            sem_count = len(semantic_buffer)
            if sem_count >= MIN_SEMANTIC_FIRST and sem_count - last_decode_count >= DECODE_INTERVAL:
                last_decode_count = sem_count

                # Re-decodes the full semantic buffer each time (no sliding
                # window), so per-call decode cost grows with utterance length.
                t_decode = time.perf_counter()
                audio_bytes = runtime.bicodec_decoder.decode_streaming(
                    semantic_ids=semantic_buffer,
                    global_ids=global_buffer[:EXPECTED_GLOBAL_COUNT],
                    use_sliding_window=False,
                    trim_warmup=False,
                )
                decode_ms = (time.perf_counter() - t_decode) * 1000
                profile.stage("5_bicodec_decode").record(decode_ms)
                profile.decode_latencies.append(decode_ms)

                if audio_bytes:
                    # Emit only samples not already emitted (int16 => 2 bytes).
                    total_decoded_samples = len(audio_bytes) // 2
                    if total_decoded_samples > total_samples_emitted:
                        new_bytes = audio_bytes[total_samples_emitted * 2:]

                        t_cf = time.perf_counter()
                        # NOTE(review): 16000 Hz is hardcoded here while the
                        # non-streaming path uses AUDIO_SAMPLE_RATE — confirm
                        # these agree.
                        to_emit, previous_chunk_tail = crossfade_bytes_int16(
                            previous_chunk_tail, new_bytes,
                            sample_rate_hz=16000, crossfade_ms=50,
                        )
                        profile.stage("6_crossfade").record((time.perf_counter() - t_cf) * 1000)

                        emitted_samples = len(to_emit) // 2
                        total_samples_emitted += emitted_samples

                        if to_emit:
                            profile.chunks_emitted += 1
                            if t_first_decode is None:
                                # First audible chunk => time-to-first-byte.
                                t_first_decode = time.perf_counter()
                                profile.ttfb_ms = (t_first_decode - t_total_start) * 1000

    t_gen_end = time.perf_counter()
    profile.stage("4_vllm_total").record((t_gen_end - t_gen_start) * 1000)

    # Final decode: flush semantic tokens generated after the last
    # interval-triggered decode.
    if len(semantic_buffer) > last_decode_count and len(global_buffer) >= EXPECTED_GLOBAL_COUNT:
        t_decode = time.perf_counter()
        audio_bytes = runtime.bicodec_decoder.decode_streaming(
            semantic_ids=semantic_buffer,
            global_ids=global_buffer[:EXPECTED_GLOBAL_COUNT],
            use_sliding_window=False, trim_warmup=False,
        )
        decode_ms = (time.perf_counter() - t_decode) * 1000
        profile.stage("5_bicodec_decode").record(decode_ms)
        profile.decode_latencies.append(decode_ms)

        if audio_bytes:
            total_decoded_samples = len(audio_bytes) // 2
            if total_decoded_samples > total_samples_emitted:
                new_bytes = audio_bytes[total_samples_emitted * 2:]
                to_emit, previous_chunk_tail = crossfade_bytes_int16(
                    previous_chunk_tail, new_bytes, sample_rate_hz=16000, crossfade_ms=50,
                )
                total_samples_emitted += len(to_emit) // 2
                if to_emit:
                    profile.chunks_emitted += 1

    # The crossfader holds back a tail buffer; count it as the final chunk.
    if previous_chunk_tail:
        total_samples_emitted += len(previous_chunk_tail) // 2
        profile.chunks_emitted += 1

    profile.generated_tokens = processed_token_count
    profile.semantic_tokens = len(semantic_buffer)
    profile.global_tokens = len(global_buffer)
    profile.audio_seconds = total_samples_emitted / 16000  # int16 mono @ 16 kHz assumed
    profile.audio_bytes = total_samples_emitted * 2
    profile.total_ms = (time.perf_counter() - t_total_start) * 1000

    # Compute token generation rates from inter-token times:
    # skip the first sample (prefill) and convert each decode-phase gap to tok/s.
    if len(inter_token_times) > 2:
        decode_times = inter_token_times[1:]  # Skip prefill
        for dt in decode_times:
            if dt > 0:
                profile.token_generation_rates.append(1000.0 / dt)  # ms → tok/s

    # Record inter-token timing stats as pseudo-stages so they appear in the
    # stage breakdown alongside real stage timings.
    if inter_token_times:
        import statistics
        decode_times = inter_token_times[1:] if len(inter_token_times) > 1 else inter_token_times
        if decode_times:
            profile.stage("4d_inter_token_avg").record(statistics.mean(decode_times))
            # Index int(n * 0.99) is a simple nearest-rank p99 approximation.
            profile.stage("4e_inter_token_p99").record(
                sorted(decode_times)[int(len(decode_times) * 0.99)]
            )

    return profile


# === Concurrency Profiler ===

async def profile_concurrent(
    runtime,
    num_requests: int,
    text: str,
    speaker: str = "lipakshi",
    stream: bool = False,
) -> List[RequestProfile]:
    """Run *num_requests* profiled TTS requests in parallel.

    Speakers are rotated round-robin across requests; the ``speaker``
    argument is currently unused (kept for interface stability).
    """
    rotation = ["lipakshi", "vardan", "reet", "Nandini", "krishna", "anika"]
    profiler = profile_streaming if stream else profile_non_streaming

    async def _run(idx: int) -> RequestProfile:
        # Round-robin voice selection keeps concurrent requests heterogeneous.
        return await profiler(runtime, text, rotation[idx % len(rotation)])

    tasks = [_run(idx) for idx in range(num_requests)]
    return list(await asyncio.gather(*tasks))


def print_concurrency_summary(profiles: List[RequestProfile], concurrency: int, stream: bool):
    """Print an aggregated summary for a batch of concurrent request profiles.

    Covers wall time / throughput / latency percentiles, TTFB (streaming),
    per-stage averages, BiCodec decode contention and token-generation rates.

    Args:
        profiles: Per-request profiles collected by ``profile_concurrent``.
        concurrency: Concurrency level used for the run (header only).
        stream: Whether the profiles came from streaming requests.
    """
    import statistics

    if not profiles:
        # Guard: the original crashed on max()/mean() of empty sequences.
        print(f"\n{'#' * 80}")
        print(f"  CONCURRENCY={concurrency}: no profiles collected")
        print(f"{'#' * 80}")
        return

    mode = "streaming" if stream else "non-streaming"
    totals = [p.total_ms for p in profiles]
    sorted_totals = sorted(totals)  # sort once; reused for p50/p99
    wall_time = max(totals)

    print(f"\n{'#' * 80}")
    print(f"  CONCURRENCY={concurrency} ({mode}), {len(profiles)} requests")
    print(f"{'#' * 80}")
    print(f"  Wall time: {wall_time:.0f}ms")
    print(f"  Throughput: {len(profiles) / (wall_time / 1000):.2f} req/s")
    print(f"  Latency: avg={statistics.mean(totals):.0f}ms  "
          f"p50={sorted_totals[len(totals)//2]:.0f}ms  "
          f"p99={sorted_totals[int(len(totals)*0.99)]:.0f}ms")

    if stream:
        # Only requests that actually emitted audio have a meaningful TTFB.
        ttfbs = [p.ttfb_ms for p in profiles if p.ttfb_ms > 0]
        if ttfbs:
            sorted_ttfbs = sorted(ttfbs)
            print(f"  TTFB: avg={statistics.mean(ttfbs):.0f}ms  "
                  f"p50={sorted_ttfbs[len(ttfbs)//2]:.0f}ms  "
                  f"p95={sorted_ttfbs[int(len(ttfbs)*0.95)]:.0f}ms")

    # Aggregate per-stage totals across all requests
    stage_totals: Dict[str, List[float]] = {}
    for p in profiles:
        for name, s in p.stages.items():
            stage_totals.setdefault(name, []).append(s.total_ms)

    mean_total = statistics.mean(totals)
    print(f"\n  STAGE BREAKDOWN (avg across {len(profiles)} requests):")
    print(f"  {'-' * 76}")
    # Slowest stages first
    for name, times in sorted(stage_totals.items(), key=lambda x: -statistics.mean(x[1])):
        avg = statistics.mean(times)
        # Share of the mean request latency (0 when latency is degenerate)
        pct = (avg / mean_total) * 100 if mean_total else 0
        print(f"  {name:<35} avg={avg:>8.2f}ms  min={min(times):>8.2f}ms  max={max(times):>8.2f}ms  ({pct:.1f}%)")

    # BiCodec decode contention across all requests
    all_decode_lats = [ms for p in profiles for ms in p.decode_latencies]
    if all_decode_lats:
        lat_sorted = sorted(all_decode_lats)
        n_lat = len(all_decode_lats)
        print(f"\n  BICODEC DECODE (all calls across {len(profiles)} requests):")
        print(f"    total calls: {n_lat}")
        print(f"    avg={statistics.mean(all_decode_lats):.2f}ms  "
              f"p50={lat_sorted[n_lat//2]:.2f}ms  "
              f"p99={lat_sorted[int(n_lat*0.99)]:.2f}ms  "
              f"max={max(all_decode_lats):.2f}ms")

    # Token generation rates across all requests
    all_rates = [r for p in profiles for r in p.token_generation_rates]
    if all_rates:
        print(f"\n  TOKEN GEN RATE (all samples):")
        print(f"    avg={statistics.mean(all_rates):.0f} tok/s  "
              f"min={min(all_rates):.0f}  max={max(all_rates):.0f}")


# === GPU Memory Snapshot ===

def gpu_snapshot() -> str:
    """Return a one-line CUDA memory summary; never raises (best-effort)."""
    try:
        import torch

        if not torch.cuda.is_available():
            return "No GPU"
        # Report in decimal gigabytes (1e9 bytes), matching nvidia-smi style.
        gib = 1e9
        used = torch.cuda.memory_allocated() / gib
        held = torch.cuda.memory_reserved() / gib
        capacity = torch.cuda.get_device_properties(0).total_memory / gib
        return f"GPU: {used:.1f}GB alloc, {held:.1f}GB reserved / {capacity:.1f}GB total"
    except Exception:
        # torch missing or CUDA query failed — degrade to a placeholder.
        return "GPU info unavailable"


# === Main ===

async def main_async(args):
    """Drive a full profiling session: init the runtime, then run single
    and concurrent profiles according to the parsed CLI *args*."""
    logger.info("Initializing TTS runtime for profiling...")
    logger.info(gpu_snapshot())

    from veena3modal.local_server import resolve_model_paths, DEFAULT_LOCAL_MODEL_DIR

    llm_path, bicodec_path = resolve_model_paths(args.model_path)

    from veena3modal.services.tts_runtime import initialize_runtime, get_runtime, is_initialized

    if not is_initialized():
        initialize_runtime(
            model_path=llm_path,
            bicodec_path=bicodec_path,
            device="cuda",
            # gpu_memory_utilization uses VLLM_CONFIG default (0.25 after Tier 1 optimization)
        )

    runtime = get_runtime()
    logger.info(f"Runtime ready: {runtime.model_version}")
    logger.info(gpu_snapshot())

    text = TEXTS.get(args.text, TEXTS["short"])
    levels = [int(c) for c in args.concurrency.split(",")]

    # === Single Request Profiles ===
    if not args.concurrency_only:
        logger.info(f"\n{'=' * 80}")
        logger.info(f"SINGLE REQUEST PROFILING (text={args.text}, {len(text)} chars)")
        logger.info(f"{'=' * 80}")

        if not args.stream:
            logger.info("\nProfiling non-streaming...")
            print((await profile_non_streaming(runtime, text, "lipakshi")).report())

        if not args.no_stream:
            logger.info("\nProfiling streaming...")
            print((await profile_streaming(runtime, text, "lipakshi")).report())

    # === Concurrency Profiles ===
    # --stream skips the non-streaming mode; --no-stream skips streaming.
    enabled_modes = [mode for mode, skip in ((False, args.stream), (True, args.no_stream)) if not skip]
    for level in levels:
        for stream_mode in enabled_modes:
            label = 'streaming' if stream_mode else 'non-streaming'
            logger.info(f"\nProfiling {level} concurrent {label} requests...")
            batch = await profile_concurrent(runtime, level, text, stream=stream_mode)
            print_concurrency_summary(batch, level, stream_mode)

    logger.info(f"\n{gpu_snapshot()}")
    logger.info("Profiling complete.")


def main():
    """CLI entry point: parse arguments and run the async profiling session."""
    cli = argparse.ArgumentParser(description="Veena3 TTS Pipeline Profiler")
    cli.add_argument("--model-path", default=os.path.join(REPO_ROOT, "models", "spark_tts_4speaker"))
    cli.add_argument("--text", default="short", choices=["short", "medium", "long"])
    cli.add_argument("--concurrency", default="1,5,10,20", help="Concurrency levels (comma-separated)")
    cli.add_argument("--stream", action="store_true", help="Stream mode only")
    cli.add_argument("--no-stream", action="store_true", help="Non-stream mode only")
    cli.add_argument("--concurrency-only", action="store_true", help="Skip single-request profiles")
    asyncio.run(main_async(cli.parse_args()))


if __name__ == "__main__":
    main()
