"""
Download all 4 validation models from HuggingFace.
Called at container startup (not build time) so the Docker image stays slim.
Uses HF_TOKEN from environment for gated model access.
Idempotent: skips models already in the HF cache (VoxLingua107 files already
present under /tmp/voxlingua107 are skipped as well).
"""
import os
import sys
import logging

# Timestamped, INFO-level logging for the whole download run.
logging.basicConfig(
    format="%(asctime)s %(message)s",
    datefmt="%H:%M:%S",
    level=logging.INFO,
)
log = logging.getLogger("model_download")

# HuggingFace access token; an empty string when the variable is unset.
HF_TOKEN = os.environ.get("HF_TOKEN", "")


def hf_login():
    """Log in to the HuggingFace Hub with ``HF_TOKEN`` when one is configured.

    This is best effort: a missing token or a failed login only produces a
    warning, and the function returns normally in every case.
    """
    if HF_TOKEN:
        try:
            from huggingface_hub import login

            login(token=HF_TOKEN, add_to_git_credential=False)
            log.info("HuggingFace authenticated")
        except Exception as exc:
            # Deliberately broad: any login problem falls back to anonymous access.
            log.warning(f"HF login failed (will try anonymous): {exc}")
    else:
        log.warning("HF_TOKEN not set — gated models may fail to download")


def download_mms_lid():
    """Fetch the MMS LID-256 feature extractor and classification model.

    Both objects are loaded through ``from_pretrained`` and then discarded;
    the function returns nothing.
    """
    from transformers import AutoFeatureExtractor, Wav2Vec2ForSequenceClassification

    repo_id = "facebook/mms-lid-256"
    log.info("[1/4] MMS LID-256 (facebook/mms-lid-256)...")
    # Feature extractor first, then the model, each from the same repository.
    for loader in (AutoFeatureExtractor, Wav2Vec2ForSequenceClassification):
        loader.from_pretrained(repo_id)
    log.info("[1/4] done")


def download_voxlingua():
    """Download the VoxLingua107 SpeechBrain files into ``/tmp/voxlingua107``.

    Files that already exist in the target directory are skipped. Downloading
    is best effort: a failure on one file is logged as a warning and the
    remaining files are still attempted, so download errors never propagate
    to the caller. A final warning lists any files that could not be fetched.
    """
    import huggingface_hub
    from pathlib import Path

    log.info("[2/4] VoxLingua107 (speechbrain/lang-id-voxlingua107-ecapa)...")
    save_dir = Path("/tmp/voxlingua107")
    save_dir.mkdir(parents=True, exist_ok=True)

    required_files = (
        "hyperparams.yaml",
        "embedding_model.ckpt",
        "classifier.ckpt",
        "label_encoder.txt",
        "mean_var_norm_emb.ckpt",
    )
    failed = []
    for fname in required_files:
        if (save_dir / fname).exists():
            continue
        try:
            huggingface_hub.hf_hub_download(
                repo_id="speechbrain/lang-id-voxlingua107-ecapa",
                filename=fname, local_dir=save_dir,
            )
        except Exception as exc:
            # Stay best effort (keep trying the other files), but surface the
            # failure instead of silently discarding it.
            log.warning("[2/4] failed to download %s: %s", fname, exc)
            failed.append(fname)

    if failed:
        # Do not claim success when some files are missing.
        log.warning(
            "[2/4] finished with %d missing file(s): %s",
            len(failed), ", ".join(failed),
        )
    else:
        log.info("[2/4] done")


def download_conformer():
    """Fetch the IndicConformer 600M multilingual model via ``AutoModel``.

    ``HF_TOKEN`` is passed as ``token`` only when it is non-empty. The loaded
    model is discarded; the function returns nothing.
    """
    from transformers import AutoModel

    log.info("[3/4] IndicConformer 600M (ai4bharat/indic-conformer-600m-multilingual)...")
    # Send the token only when one is configured; otherwise omit the argument.
    auth = {"token": HF_TOKEN} if HF_TOKEN else {}
    # NOTE(review): trust_remote_code=True permits running code shipped in the
    # model repository — confirm this repo is trusted.
    AutoModel.from_pretrained(
        "ai4bharat/indic-conformer-600m-multilingual",
        trust_remote_code=True,
        **auth,
    )
    log.info("[3/4] done")


def download_wav2vec():
    """Fetch the Wav2Vec2 Large (960h-lv60-self) processor and CTC model.

    Both objects are loaded through ``from_pretrained`` and then discarded;
    the function returns nothing.
    """
    from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

    repo_id = "facebook/wav2vec2-large-960h-lv60-self"
    log.info("[4/4] Wav2Vec2 Large (facebook/wav2vec2-large-960h-lv60-self)...")
    # Processor first, then the model, each from the same repository.
    for loader in (Wav2Vec2Processor, Wav2Vec2ForCTC):
        loader.from_pretrained(repo_id)
    log.info("[4/4] done")


def main():
    """Authenticate, then run the four model downloads in a fixed order."""
    hf_login()
    download_steps = (
        download_mms_lid,
        download_voxlingua,
        download_conformer,
        download_wav2vec,
    )
    for step in download_steps:
        step()
    log.info("All models ready.")


# Run the downloads only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
