from ddtrace.llmobs._constants import CACHE_READ_INPUT_TOKENS_METRIC_KEY
from ddtrace.llmobs._constants import CACHE_WRITE_INPUT_TOKENS_METRIC_KEY
from ddtrace.llmobs._constants import INPUT_TOKENS_METRIC_KEY


_MODEL_TYPE_IDENTIFIERS = (
    "foundation-model/",
    "custom-model/",
    "provisioned-model/",
    "imported-model/",
    "prompt/",
    "endpoint/",
    "inference-profile/",
    "default-prompt-router/",
)

_AI21 = "ai21"
_AMAZON = "amazon"
_ANTHROPIC = "anthropic"
_COHERE = "cohere"
_META = "meta"
_STABILITY = "stability"
_MODEL_PROVIDERS = (_AI21, _AMAZON, _ANTHROPIC, _COHERE, _META, _STABILITY)


def _fallback_provider(original_model_id: str) -> str:
    """If the original model ID contains a known model provider, return it, otherwise fallback to custom."""
    for provider in _MODEL_PROVIDERS:
        if provider in original_model_id:
            return provider
    return "custom"


def parse_model_id(model_id: str):
    """Best effort to extract and return the model provider and model name from the bedrock model ID.
    model_id can be a 1/2 period-separated string or a full AWS ARN, based on the following formats:
    1. Base model: "{model_provider}.{model_name}"
    2. Cross-region model: "{region}.{model_provider}.{model_name}"
    3. Other: Prefixed by AWS ARN "arn:aws{+region?}:bedrock:{region}:{account-id}:"
        a. Foundation model: ARN prefix + "foundation-model/{region?}.{model_provider}.{model_name}"
        b. Custom model: ARN prefix + "custom-model/{model_provider}.{model_name}"
        c. Provisioned model: ARN prefix + "provisioned-model/{model-id}"
        d. Imported model: ARN prefix + "imported-model/{model-id}"
        e. Prompt management: ARN prefix + "prompt/{prompt-id}"
        f. Sagemaker: ARN prefix + "endpoint/{model-id}"
        g. Inference profile: ARN prefix + "{application-?}inference-profile/{model-id}"
        h. Default prompt router: ARN prefix + "default-prompt-router/{prompt-id}"
    If model provider cannot be inferred from the model_id formatting, then default to "custom"
    """
    original_model_id = model_id

    # Non-ARN form: a plain dotted identifier such as "anthropic.claude-v2"
    # or "us.anthropic.claude-v2"; the last two segments are provider and name.
    if not model_id.startswith("arn:aws"):
        segments = model_id.split(".")
        if len(segments) >= 2:
            return segments[-2], segments[-1]
        # A single segment carries no provider information.
        return _fallback_provider(original_model_id), segments[0]

    # ARN form: locate the resource-type marker and inspect what follows it.
    for marker in _MODEL_TYPE_IDENTIFIERS:
        if marker not in model_id:
            continue
        remainder = model_id.rsplit(marker, 1)[-1]
        # Only these resource types embed a dotted "{provider}.{name}" suffix.
        if marker not in ("foundation-model/", "custom-model/", "inference-profile/"):
            return _fallback_provider(original_model_id), remainder
        segments = remainder.split(".")
        if len(segments) >= 2:
            return segments[-2], segments[-1]
        return _fallback_provider(original_model_id), remainder

    # ARN with no recognized resource marker: nothing reliable to extract.
    return _fallback_provider(original_model_id), "custom"


def normalize_input_tokens(usage_metrics: dict) -> None:
    """
    `input_tokens` in bedrock's response usage metadata is the number of non-cached tokens. We normalize it to mean
    the total tokens sent to the model to be consistent with other model providers.

    Args:
        usage_metrics: Dictionary containing token usage metrics that will be modified in-place
    """
    has_cache_usage = (
        CACHE_READ_INPUT_TOKENS_METRIC_KEY in usage_metrics
        or CACHE_WRITE_INPUT_TOKENS_METRIC_KEY in usage_metrics
    )
    # Without cache metrics the reported input_tokens is already the total.
    if not has_cache_usage:
        return
    usage_metrics[INPUT_TOKENS_METRIC_KEY] = sum(
        usage_metrics.get(key, 0)
        for key in (
            INPUT_TOKENS_METRIC_KEY,
            CACHE_READ_INPUT_TOKENS_METRIC_KEY,
            CACHE_WRITE_INPUT_TOKENS_METRIC_KEY,
        )
    )
