# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from typing import TypeAlias

from pydantic import Field

from vllm import PoolingParams
from vllm.config import ModelConfig
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
from vllm.entrypoints.pooling.base.protocol import (
    ChatRequestMixin,
    CompletionRequestMixin,
    EmbedRequestMixin,
    PoolingBasicRequestMixin,
)
from vllm.logger import init_logger
from vllm.renderers import TokenizeParams
from vllm.utils import random_uuid

# Module-level logger; in this file it emits the `normalize` deprecation warning.
logger = init_logger(__name__)


def _get_max_total_output_tokens(
    model_config: ModelConfig,
) -> tuple[int | None, int]:
    """Derive the ``(max_total_tokens, max_output_tokens)`` pair for tokenization.

    Args:
        model_config: Supplies ``max_model_len`` and ``pooler_config``.

    Returns:
        * ``(max_model_len, 0)`` when the model has no pooler config.
        * ``(None, 0)`` when the pooler config enables chunked processing.
        * ``(max_model_len, max_model_len - max_embed_len)`` otherwise, where
          a falsy ``max_embed_len`` (``None`` or ``0``) falls back to
          ``max_model_len`` and therefore yields ``0``.
    """
    model_len = model_config.max_model_len
    pooler = model_config.pooler_config

    if pooler is None:
        return model_len, 0

    if pooler.enable_chunked_processing:
        return None, 0

    embed_len = pooler.max_embed_len
    if not embed_len:
        embed_len = model_len

    # The difference is negative whenever max_embed_len exceeds max_model_len.
    # NOTE(review): presumably the consumer subtracts this from the total so
    # that the effective input limit becomes max_embed_len — confirm.
    return model_len, model_len - embed_len


class EmbeddingCompletionRequest(
    PoolingBasicRequestMixin, CompletionRequestMixin, EmbedRequestMixin
):
    # Request model composed from the basic pooling, completion and embed
    # request mixins; it adds the tokenization/pooling parameter builders below.

    def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
        """Assemble the tokenization settings for this request.

        The total/output token limits are derived from ``model_config`` by
        ``_get_max_total_output_tokens``. Lower-casing is read from the
        model's encoder config (``False`` when the config or the key is
        absent), while truncation and special-token handling come from the
        request's own fields.
        """
        total_limit, output_limit = _get_max_total_output_tokens(model_config)
        lower_case = (model_config.encoder_config or {}).get("do_lower_case", False)

        return TokenizeParams(
            # Limits, plus the labels passed alongside each limit.
            max_total_tokens=total_limit,
            max_total_tokens_param="max_model_len",
            max_output_tokens=output_limit,
            max_output_tokens_param="max_model_len - max_embed_len",
            # Per-request / per-model text handling.
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            add_special_tokens=self.add_special_tokens,
            do_lower_case=lower_case,
        )

    def to_pooling_params(self):
        """Convert this request into ``PoolingParams`` for the "embed" task.

        A non-``None`` ``normalize`` is handled as the deprecated spelling of
        ``use_activation``: a warning is issued through
        ``logger.warning_once`` and the value is written onto
        ``self.use_activation`` before the params are built.
        """
        legacy_normalize = self.normalize
        if legacy_normalize is not None:
            logger.warning_once(
                "`normalize` is deprecated and will be removed in v0.17. "
                "Please pass `use_activation` instead."
            )
            # Overwrites whatever `use_activation` the request carried.
            self.use_activation = legacy_normalize

        return PoolingParams(
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            use_activation=self.use_activation,
            dimensions=self.dimensions,
            task="embed",
        )


class EmbeddingChatRequest(
    PoolingBasicRequestMixin, ChatRequestMixin, EmbedRequestMixin
):
    # Request model composed from the basic pooling, chat and embed request
    # mixins; it adds the tokenization/pooling parameter builders below.

    def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
        """Build the ``TokenizeParams`` used to tokenize this chat request.

        Token limits come from ``_get_max_total_output_tokens(model_config)``;
        ``do_lower_case`` is looked up in the model's encoder config and is
        ``False`` when that config or key is missing; the remaining options
        are copied from the request's own fields.
        """
        limits = _get_max_total_output_tokens(model_config)
        total_cap, output_cap = limits

        enc_cfg = model_config.encoder_config
        if not enc_cfg:
            enc_cfg = {}

        return TokenizeParams(
            max_total_tokens=total_cap,
            max_total_tokens_param="max_model_len",
            max_output_tokens=output_cap,
            max_output_tokens_param="max_model_len - max_embed_len",
            do_lower_case=enc_cfg.get("do_lower_case", False),
            add_special_tokens=self.add_special_tokens,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
        )

    def to_pooling_params(self):
        """Produce the ``PoolingParams`` (task ``"embed"``) for this request.

        When ``normalize`` is not ``None`` it is treated as the deprecated
        form of ``use_activation``: ``logger.warning_once`` is called and
        ``self.use_activation`` is replaced with the ``normalize`` value
        before the params object is created.
        """
        deprecated_value = self.normalize
        if deprecated_value is not None:
            logger.warning_once(
                "`normalize` is deprecated and will be removed in v0.17. "
                "Please pass `use_activation` instead."
            )
            # Any `use_activation` given on the request is replaced here.
            self.use_activation = deprecated_value

        return PoolingParams(
            task="embed",
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            dimensions=self.dimensions,
            use_activation=self.use_activation,
        )


# Either of the two embedding request models defined above
# (completion-style or chat-style).
EmbeddingRequest: TypeAlias = EmbeddingCompletionRequest | EmbeddingChatRequest


class EmbeddingResponseData(OpenAIBaseModel):
    # Element type of `EmbeddingResponse.data`.

    # Integer position of this item.
    # NOTE(review): presumably the index of the corresponding input — confirm
    # against the code that builds the response.
    index: int
    # Object-type tag; defaults to "embedding".
    object: str = "embedding"
    # The embedding, either as a list of floats or as a string.
    # NOTE(review): the string form is presumably an encoded vector (e.g.
    # base64) — confirm against the serving code.
    embedding: list[float] | str


class EmbeddingResponse(OpenAIBaseModel):
    # Response body carrying a list of `EmbeddingResponseData` items plus
    # usage information.

    # Response identifier; when not supplied it is generated as
    # "embd-" followed by the result of `random_uuid()`.
    id: str = Field(default_factory=lambda: f"embd-{random_uuid()}")
    # Object-type tag; defaults to "list".
    object: str = "list"
    # Creation time; when not supplied it is `time.time()` truncated to an int
    # (whole seconds since the epoch) at the moment the default is produced.
    created: int = Field(default_factory=lambda: int(time.time()))
    # Model name (required, no default).
    model: str
    # The embedding items.
    data: list[EmbeddingResponseData]
    # Usage information for the request (see `UsageInfo`).
    usage: UsageInfo


class EmbeddingBytesResponse(OpenAIBaseModel):
    # Response variant whose payload is raw bytes rather than float lists.
    # NOTE(review): presumably used to return embeddings as a binary HTTP
    # body — confirm against the serving code.

    # Payload as a list of byte strings.
    content: list[bytes]
    # Optional string-to-string header mapping; defaults to None.
    headers: dict[str, str] | None = None
    # Media type of the payload; defaults to "application/octet-stream".
    media_type: str = "application/octet-stream"
