#!/usr/bin/env python3
"""
Resize model embeddings from 16K → 59,554 for the extended Indic tokenizer.

This script:
1. Loads the original model weights
2. Resizes embedding layer + decoder head
3. Initializes new embeddings from byte-fallback compositions
4. Saves the extended model + tokenizer

Run after downloading model weights:
  python3 resize_embeddings.py

Requires: model weights at the standard HF cache location or via from_pretrained download.
"""

import json
import logging
import os
import shutil
import time

import torch

logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(message)s")
logger = logging.getLogger(__name__)

# Base checkpoint to extend (Hugging Face Hub id).
MODEL_NAME = "CohereLabs/cohere-transcribe-03-2026"
# Directory holding the merged (original + Indic) tokenizer.json.
TOKENIZER_DIR = "/workspace/training/tokenizer_extension/merged_48k"
# Destination for the resized model, config, and tokenizer files.
OUTPUT_DIR = "/workspace/training/tokenizer_extension/extended_model"
# JSON with "original_vocab_size" / "final_vocab_size" keys (read in main()).
VOCAB_INFO = os.path.join(TOKENIZER_DIR, "vocab_info.json")


def main():
    """Extend the model vocabulary to match the merged Indic tokenizer.

    Steps:
      1. Read original/final vocab sizes from ``VOCAB_INFO``.
      2. Load the base model plus the original and extended tokenizers.
      3. Grow the token embedding; each new row is seeded from the mean of
         the original tokenizer's byte-fallback embeddings for that token.
      4. Rebuild the output head so it shares the resized embedding weight
         (preserving the model's original weight tying).
      5. Save the model (bfloat16), config, tokenizer, and remote-code files
         to ``OUTPUT_DIR``.

    Raises:
        RuntimeError: if no embedding layer of the original vocab size is found.
    """
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    # Must be set before transformers is imported so the HF cache lands here.
    os.environ["HF_HOME"] = "/workspace/.hf_home"

    with open(VOCAB_INFO) as f:
        vocab_info = json.load(f)

    orig_vocab_size = vocab_info["original_vocab_size"]
    final_vocab_size = vocab_info["final_vocab_size"]
    logger.info("Resizing: %s → %s", orig_vocab_size, final_vocab_size)

    model, orig_tokenizer, ext_tokenizer = _load_model_and_tokenizers()

    embed_layer, output_layer = _find_vocab_layers(model, orig_vocab_size)
    if not embed_layer:
        raise RuntimeError("Could not find embedding layer")

    # All tokens with ID >= orig_vocab_size are new.
    ext_vocab = ext_tokenizer.get_vocab()
    new_tokens = {tok: idx for tok, idx in ext_vocab.items() if idx >= orig_vocab_size}
    logger.info("New tokens to initialize: %d", len(new_tokens))

    embed_name, embed_mod = embed_layer
    new_embed = _resize_embedding(
        embed_mod, orig_tokenizer, new_tokens, orig_vocab_size, final_vocab_size
    )
    _set_module(model, embed_name, new_embed)
    logger.info("Embedding resized: %s → %s", orig_vocab_size, final_vocab_size)

    _retie_output_head(model, output_layer, new_embed, orig_vocab_size, final_vocab_size)

    # Keep the config consistent with the new vocabulary size.
    model.config.vocab_size = final_vocab_size
    if hasattr(model.config, "head") and isinstance(model.config.head, dict):
        model.config.head["num_classes"] = final_vocab_size

    _save_extended_model(model)

    logger.info("Done!")
    logger.info("\nTo use in training, update config.yaml:")
    logger.info('  model_name: "%s"', OUTPUT_DIR)
    logger.info("  max_tokens: 384  # Indic sequences are now ~10x shorter, 384 for safety")


def _load_model_and_tokenizers():
    """Load the base model plus the original and extended fast tokenizers."""
    # Imported here (not at module top) so HF_HOME is already set.
    from transformers import AutoModelForSpeechSeq2Seq, PreTrainedTokenizerFast

    logger.info("Loading model...")
    t0 = time.time()
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        MODEL_NAME,
        trust_remote_code=True,
        dtype=torch.float32,
    )
    logger.info("Model loaded in %.1fs", time.time() - t0)

    # NOTE(review): pinned snapshot path — breaks if the cached revision changes.
    orig_tokenizer = PreTrainedTokenizerFast(
        tokenizer_file=(
            "/workspace/.hf_home/hub/models--CohereLabs--cohere-transcribe-03-2026/"
            "snapshots/90cf6a1e8427d6ab5e0060f53c095c245a20da4e/tokenizer.json"
        )
    )
    ext_tokenizer = PreTrainedTokenizerFast(
        tokenizer_file=os.path.join(TOKENIZER_DIR, "tokenizer.json")
    )
    return model, orig_tokenizer, ext_tokenizer


def _find_vocab_layers(model, orig_vocab_size):
    """Locate the vocab-sized embedding and output-head modules.

    Returns a pair ``(embed_layer, output_layer)`` where each element is a
    ``(qualified_name, module)`` tuple or ``None`` if not found. A module
    qualifies by its vocab-sized dimension: ``nn.Embedding`` with
    ``num_embeddings == orig_vocab_size`` or ``nn.Linear`` with
    ``out_features == orig_vocab_size``.
    """
    embed_layer = output_layer = None
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Embedding) and module.num_embeddings == orig_vocab_size:
            embed_layer = (name, module)
            logger.info("Embedding: %s [%d x %d]", name, module.num_embeddings, module.embedding_dim)
        if isinstance(module, torch.nn.Linear) and module.out_features == orig_vocab_size:
            output_layer = (name, module)
            logger.info("Output head: %s [%d x %d]", name, module.in_features, module.out_features)
    return embed_layer, output_layer


def _resize_embedding(embed_mod, orig_tokenizer, new_tokens, orig_vocab_size, final_vocab_size):
    """Build a larger embedding: old rows copied verbatim, new rows seeded.

    Each new token is re-encoded with the original tokenizer (byte fallback)
    and its row is the mean of those fallback rows; a token that encodes to
    nothing gets random noise scaled by the old weight's std.
    """
    old_weight = embed_mod.weight.data.clone()
    embed_dim = embed_mod.embedding_dim

    new_embed = torch.nn.Embedding(final_vocab_size, embed_dim)
    new_embed.weight.data[:orig_vocab_size] = old_weight

    logger.info("Initializing new embeddings from byte-fallback compositions...")
    for token, idx in new_tokens.items():
        # Undo the "▁" word-boundary marker before re-encoding with the
        # original tokenizer.
        text = token.replace("▁", " ").strip()
        if not text:
            text = " "
        fallback_ids = orig_tokenizer.encode(text, add_special_tokens=False)
        if fallback_ids:
            # Average the byte-fallback embeddings.
            new_embed.weight.data[idx] = old_weight[fallback_ids].mean(dim=0)
        else:
            new_embed.weight.data[idx] = torch.randn(embed_dim) * old_weight.std()
    return new_embed


def _retie_output_head(model, output_layer, new_embed, orig_vocab_size, final_vocab_size):
    """Replace the output head with a Linear sharing the embedding Parameter.

    In the original model: ``log_softmax.mlp.layer0.weight`` is
    ``transf_decoder._embedding.token_embedding.weight``; this preserves
    that tying after the resize. Bias entries for new tokens start at -2.0
    (pushes their logits down initially).
    """
    logger.info("Re-tying output head weights to embedding (preserving original weight sharing)...")
    if not output_layer:
        logger.warning("Could not find output head to re-tie")
        return
    out_name, out_mod = output_layer
    embed_dim = new_embed.embedding_dim
    new_out = torch.nn.Linear(embed_dim, final_vocab_size, bias=out_mod.bias is not None)
    new_out.weight = new_embed.weight  # share the Parameter object
    if out_mod.bias is not None:
        new_out.bias.data[:orig_vocab_size] = out_mod.bias.data
        new_out.bias.data[orig_vocab_size:] = -2.0
    _set_module(model, out_name, new_out)
    # Verify the swap kept the weight sharing intact.
    tied = model.log_softmax.mlp.layer0.weight is model.transf_decoder._embedding.token_embedding.weight
    logger.info("Output head re-tied: weight sharing = %s", tied)


def _save_extended_model(model):
    """Save weights (bfloat16), config, tokenizer, and remote-code files."""
    logger.info("Saving extended model to %s (bfloat16)...", OUTPUT_DIR)
    model = model.to(torch.bfloat16)

    # Include BOTH the embedding weight AND the output head weight (even
    # though they're tied): the model's __init__ re-ties them, but
    # from_pretrained loads weights AFTER init — a missing output-head key
    # would leave the head randomly initialized instead of tied.
    state_dict = model.state_dict()
    assert "transf_decoder._embedding.token_embedding.weight" in state_dict
    if "log_softmax.mlp.layer0.weight" not in state_dict:
        # Clone to avoid safetensors' shared-memory error on tied tensors.
        state_dict["log_softmax.mlp.layer0.weight"] = state_dict[
            "transf_decoder._embedding.token_embedding.weight"
        ].clone()
        logger.info("Added tied output head weight to state_dict (cloned)")
    else:
        # Present, but may still share memory — clone for the same reason.
        state_dict["log_softmax.mlp.layer0.weight"] = state_dict["log_softmax.mlp.layer0.weight"].clone()
        logger.info("Cloned tied output head weight for safetensors compatibility")

    from safetensors.torch import save_file
    save_file(state_dict, os.path.join(OUTPUT_DIR, "model.safetensors"))
    model.config.save_pretrained(OUTPUT_DIR)

    # Tokenizer files travel alongside the model.
    for fname in os.listdir(TOKENIZER_DIR):
        shutil.copy2(os.path.join(TOKENIZER_DIR, fname), OUTPUT_DIR)

    # Modeling/processing files needed for trust_remote_code loading.
    model_dir = (
        "/workspace/.hf_home/hub/models--CohereLabs--cohere-transcribe-03-2026/"
        "snapshots/90cf6a1e8427d6ab5e0060f53c095c245a20da4e"
    )
    for fname in ["modeling_cohere_asr.py", "processing_cohere_asr.py",
                  "configuration_cohere_asr.py", "tokenization_cohere_asr.py",
                  "preprocessor_config.json", "processor_config.json"]:
        src = os.path.join(model_dir, fname)
        if os.path.exists(src):
            shutil.copy2(src, OUTPUT_DIR)
            logger.info("  Copied %s", fname)


def _set_module(model, name, new_module):
    parts = name.split(".")
    parent = model
    for part in parts[:-1]:
        parent = getattr(parent, part)
    setattr(parent, parts[-1], new_module)


# Script entry point: python3 resize_embeddings.py
if __name__ == "__main__":
    main()
