# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from typing import TypeAlias

from pydantic import BaseModel, Field

from vllm import PoolingParams
from vllm.config import ModelConfig
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
from vllm.entrypoints.pooling.base.protocol import (
    ClassifyRequestMixin,
    PoolingBasicRequestMixin,
)
from vllm.entrypoints.pooling.score.utils import (
    ScoreContentPartParam,
    ScoreInput,
    ScoreInputs,
)
from vllm.renderers import TokenizeParams
from vllm.tasks import PoolingTask
from vllm.utils import random_uuid


class ScoreRequestMixin(PoolingBasicRequestMixin, ClassifyRequestMixin):
    """Tokenization and pooling parameter builders shared by score requests."""

    def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
        """Build the tokenization parameters for this request.

        Args:
            model_config: Supplies `max_model_len` and the optional
                `encoder_config` mapping.

        Returns:
            A `TokenizeParams` whose total token budget is the model length,
            with `max_output_tokens` fixed at 0 and the request's
            `truncate_prompt_tokens` forwarded unchanged.
        """
        # A missing or empty encoder config means lower-casing stays disabled.
        encoder_settings = model_config.encoder_config
        if encoder_settings:
            lower_case = encoder_settings.get("do_lower_case", False)
        else:
            lower_case = False

        return TokenizeParams(
            max_total_tokens=model_config.max_model_len,
            # NOTE(review): presumably names the limit in messages — confirm.
            max_total_tokens_param="max_model_len",
            max_output_tokens=0,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            do_lower_case=lower_case,
        )

    def to_pooling_params(self, task: PoolingTask = "score"):
        """Build the `PoolingParams` for this request.

        Args:
            task: Pooling task to run; defaults to `"score"`.

        Returns:
            A `PoolingParams` carrying `task` together with the request's
            `truncate_prompt_tokens` and `use_activation` values.
        """
        pooling_kwargs = {
            "task": task,
            "truncate_prompt_tokens": self.truncate_prompt_tokens,
            "use_activation": self.use_activation,
        }
        return PoolingParams(**pooling_kwargs)


class ScoreDataRequest(ScoreRequestMixin):
    """Score request that supplies its two inputs as `data_1` and `data_2`."""

    # First set of score inputs.
    data_1: ScoreInputs
    # Second set of score inputs.
    data_2: ScoreInputs


class ScoreQueriesDocumentsRequest(ScoreRequestMixin):
    """Score request that supplies its inputs as `queries` and `documents`.

    The `data_1` / `data_2` properties expose the same values under the
    field names used by `ScoreDataRequest`.
    """

    queries: ScoreInputs
    documents: ScoreInputs

    @property
    def data_1(self) -> ScoreInputs:
        """Return `queries`."""
        return self.queries

    @property
    def data_2(self) -> ScoreInputs:
        """Return `documents`."""
        return self.documents


class ScoreQueriesItemsRequest(ScoreRequestMixin):
    """Score request that supplies its inputs as `queries` and `items`.

    The `data_1` / `data_2` properties expose the same values under the
    field names used by `ScoreDataRequest`.
    """

    queries: ScoreInputs
    items: ScoreInputs

    @property
    def data_1(self) -> ScoreInputs:
        """Return `queries`."""
        return self.queries

    @property
    def data_2(self) -> ScoreInputs:
        """Return `items`."""
        return self.items


class ScoreTextRequest(ScoreRequestMixin):
    """Score request that supplies its inputs as `text_1` and `text_2`.

    The `data_1` / `data_2` properties expose the same values under the
    field names used by `ScoreDataRequest`.
    """

    text_1: ScoreInputs
    text_2: ScoreInputs

    @property
    def data_1(self) -> ScoreInputs:
        """Return `text_1`."""
        return self.text_1

    @property
    def data_2(self) -> ScoreInputs:
        """Return `text_2`."""
        return self.text_2


# Union of the four accepted score request shapes. Every member exposes its
# two inputs as `data_1` / `data_2`, either as fields or as alias properties.
ScoreRequest: TypeAlias = (
    ScoreQueriesDocumentsRequest
    | ScoreQueriesItemsRequest
    | ScoreDataRequest
    | ScoreTextRequest
)


class RerankRequest(PoolingBasicRequestMixin, ClassifyRequestMixin):
    """Rerank request: one `query` evaluated against a set of `documents`.

    `top_n` defaults to 0; how that value is interpreted is decided by the
    code that consumes the request (not visible in this module).
    """

    query: ScoreInput
    documents: ScoreInputs
    top_n: int = Field(default_factory=lambda: 0)

    def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
        """Build the tokenization parameters for this request.

        Args:
            model_config: Supplies `max_model_len` and the optional
                `encoder_config` mapping.

        Returns:
            A `TokenizeParams` whose total token budget is the model length,
            with `max_output_tokens` fixed at 0 and the request's
            `truncate_prompt_tokens` forwarded unchanged.
        """
        # Lower-casing is off unless a non-empty encoder config enables it.
        lower_case = False
        encoder_settings = model_config.encoder_config
        if encoder_settings:
            lower_case = encoder_settings.get("do_lower_case", False)

        return TokenizeParams(
            do_lower_case=lower_case,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            max_output_tokens=0,
            max_total_tokens=model_config.max_model_len,
            max_total_tokens_param="max_model_len",
        )

    def to_pooling_params(self, task: PoolingTask = "score"):
        """Build the `PoolingParams` for this request.

        Args:
            task: Pooling task to run; defaults to `"score"`.

        Returns:
            A `PoolingParams` carrying `task` together with the request's
            `truncate_prompt_tokens` and `use_activation` values.
        """
        truncation = self.truncate_prompt_tokens
        activation = self.use_activation
        return PoolingParams(
            task=task,
            truncate_prompt_tokens=truncation,
            use_activation=activation,
        )


class RerankDocument(BaseModel):
    """Document payload attached to a rerank result.

    Both fields are optional and default to `None`.
    """

    # Plain-text form of the document, if any.
    text: str | None = None
    # Content-part form of the document, if any.
    multi_modal: list[ScoreContentPartParam] | None = None


class RerankResult(BaseModel):
    """Single entry of a rerank response."""

    # NOTE(review): presumably the document's position in the request's
    # `documents` list — confirm against the handler that builds results.
    index: int
    document: RerankDocument
    relevance_score: float


class RerankUsage(BaseModel):
    """Token counts reported with a rerank response."""

    prompt_tokens: int
    total_tokens: int


class RerankResponse(OpenAIBaseModel):
    """Response body for a rerank request.

    All fields are required; unlike `ScoreResponse`, `id` has no default and
    must be supplied by the caller.
    """

    id: str
    model: str
    usage: RerankUsage
    results: list[RerankResult]


class ScoreResponseData(OpenAIBaseModel):
    """Single score entry inside a `ScoreResponse`."""

    # NOTE(review): presumably the position of the scored pair in the
    # request — confirm against the handler that builds the response.
    index: int
    # Object-type tag; defaults to "score".
    object: str = "score"
    score: float


class ScoreResponse(OpenAIBaseModel):
    """Response body for a score request."""

    # Generated per instance when not supplied.
    # NOTE(review): the "embd-" prefix looks inherited from the embedding
    # response — confirm it is intentional before changing it, since clients
    # may already rely on the current value.
    id: str = Field(default_factory=lambda: f"embd-{random_uuid()}")
    # Object-type tag; defaults to "list".
    object: str = "list"
    # Unix timestamp in whole seconds, taken when the instance is created.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: list[ScoreResponseData]
    usage: UsageInfo
