from __future__ import annotations

import logging
from collections.abc import Sequence
from datetime import datetime
from typing import Any

import wandb
from wandb.sdk.integration_utils.auto_logging import Response
from wandb.sdk.lib.runid import generate_id

logger = logging.getLogger(__name__)


def subset_dict(
    original_dict: dict[str, Any], keys_subset: Sequence[str]
) -> dict[str, Any]:
    """Return a new dictionary holding only the requested keys that exist in the source.

    Keys listed in ``keys_subset`` but missing from ``original_dict`` are skipped
    silently. The result follows the order of ``keys_subset``.

    :param original_dict: Dictionary to pick entries from.
    :param keys_subset: Keys to keep.
    :return: A new dictionary containing the selected entries.
    """
    selected: dict[str, Any] = {}
    for wanted in keys_subset:
        if wanted in original_dict:
            selected[wanted] = original_dict[wanted]
    return selected


def reorder_and_convert_dict_list_to_table(
    data: list[dict[str, Any]], order: list[str]
) -> tuple[list[str], list[list[Any]]]:
    """Turn a list of row dictionaries into a column header plus a matrix of row values.

    Columns named in ``order`` come first, in that order. Duplicates are dropped, and
    these columns are kept even if no row contains them. Every other key found in
    ``data`` follows in first-seen order. Cells missing from a row are filled with
    ``None``.

    :param data: Rows, each mapping a column name to a value.
    :param order: Column names to place first, in this order.
    :return: A ``(columns, rows)`` pair; each row lists values aligned with ``columns``.
    """
    # A dict keeps insertion order, so it serves as an ordered set of column names.
    # Re-adding an existing key does not move it.
    ordered_columns = dict.fromkeys(order)
    for row_dict in data:
        ordered_columns.update(dict.fromkeys(row_dict))

    columns = list(ordered_columns)
    rows = [[row_dict.get(column) for column in columns] for row_dict in data]
    return columns, rows


def flatten_dict(
    dictionary: dict[str, Any], parent_key: str = "", sep: str = "-"
) -> dict[str, Any]:
    """Collapse nested dictionaries into a single level, joining key paths with ``sep``.

    Only ``dict`` values (including subclasses) are descended into. Every other value
    becomes a leaf stored under its joined path. Empty nested dictionaries contribute
    no entries. If two paths collide, the later value wins, but the key keeps the
    position where it first appeared.

    :param dictionary: The (possibly nested) dictionary to flatten.
    :param parent_key: Prefix joined in front of every top-level key; falsy means no prefix.
    :param sep: Separator placed between path components.
    :return: A new, single-level dictionary.
    """
    flat: dict[str, Any] = {}

    def _walk(node: dict[str, Any], prefix: str) -> None:
        # Depth-first traversal that writes leaves straight into the shared result.
        for name, item in node.items():
            path = f"{prefix}{sep}{name}" if prefix else name
            if isinstance(item, dict):
                _walk(item, path)
            else:
                flat[path] = item

    _walk(dictionary, parent_key)
    return flat


def collect_common_keys(list_of_dicts: list[dict[str, Any]]) -> dict[str, list[Any]]:
    """Collect the keys shared by every dictionary, gathering each key's values into a list.

    For each common key, the values appear in the same order as the dictionaries in
    ``list_of_dicts``. The common keys follow the key order of the first dictionary,
    so the result (and any table built from it) has a deterministic column order.

    :param list_of_dicts: The list of dictionaries to inspect.
    :return: A dictionary mapping each common key to its list of values. An empty input
        yields an empty dictionary.
    """
    if not list_of_dicts:
        # set.intersection() with no arguments raises TypeError, so handle this up front.
        return {}
    first, *rest = list_of_dicts
    # Iterate the first dict rather than a set: set order for str keys varies between
    # interpreter runs because of hash randomization.
    common_keys = [key for key in first if all(key in d for d in rest)]
    return {key: [d[key] for d in list_of_dicts] for key in common_keys}


class CohereRequestResponseResolver:
    """Class to resolve the request/response from the Cohere API and convert it to a dictionary that can be logged.

    Each supported response type (``Generations``, ``Chat``, ``Classifications``,
    ``SummarizeResponse``, ``Reranking``) is parsed into a list of row dictionaries.
    These rows are merged with the request kwargs, timing information and selected
    client settings, then packed into a ``wandb.Table`` keyed by the response type name.
    """

    def __call__(
        self,
        args: Sequence[Any],
        kwargs: dict[str, Any],
        response: Response,
        start_time: float,
        time_elapsed: float,
    ) -> dict[str, Any] | None:
        """Process the response from the Cohere API and convert it to a dictionary that can be logged.

        :param args: The arguments of the original function; ``args[0]`` is the client.
        :param kwargs: The keyword arguments of the original function (not modified).
        :param response: The response from the Cohere API.
        :param start_time: The start time of the request, as a POSIX timestamp.
        :param time_elapsed: The time elapsed for the request, in seconds.
        :return: ``{response_type: wandb.Table}`` on success. Returns ``None`` if the
            response type is unsupported or resolution fails; failures are logged, not raised.
        """
        try:
            # Each endpoint maps to one specific response type. The type is checked by
            # class name so the Cohere package does not need to be imported here.
            # It may make more sense to pass the invoked symbol from the AutologAPI instead
            response_type = type(response).__name__

            if response_type == "Generations":
                parsed_response = self._resolve_generate_response(response)
                # TODO: Remove hard-coded default model name
                table_column_order = [
                    "start_time",
                    "query_id",
                    "model",
                    "prompt",
                    "text",
                    "token_likelihoods",
                    "likelihood",
                    "time_elapsed_(seconds)",
                    "end_time",
                ]
                default_model = "command"
            elif response_type == "Chat":
                parsed_response = self._resolve_chat_response(response)
                table_column_order = [
                    "start_time",
                    "query_id",
                    "model",
                    "conversation_id",
                    "response_id",
                    "query",
                    "text",
                    "prompt",
                    "preamble",
                    "chat_history",
                    "chatlog",
                    "time_elapsed_(seconds)",
                    "end_time",
                ]
                default_model = "command"
            elif response_type == "Classifications":
                parsed_response = self._resolve_classify_response(response)
                kwargs = self._resolve_classify_kwargs(kwargs)
                table_column_order = [
                    "start_time",
                    "query_id",
                    "model",
                    "id",
                    "input",
                    "prediction",
                    "confidence",
                    "time_elapsed_(seconds)",
                    "end_time",
                ]
                default_model = "embed-english-v2.0"
            elif response_type == "SummarizeResponse":
                parsed_response = self._resolve_summarize_response(response)
                table_column_order = [
                    "start_time",
                    "query_id",
                    "model",
                    "response_id",
                    "text",
                    "additional_command",
                    "summary",
                    "time_elapsed_(seconds)",
                    "end_time",
                    "length",
                    "format",
                ]
                default_model = "summarize-xlarge"
            elif response_type == "Reranking":
                parsed_response = self._resolve_rerank_response(response)
                table_column_order = [
                    "start_time",
                    "query_id",
                    "model",
                    "id",
                    "query",
                    "top_n",
                    # This is a nested dict key that got flattened
                    "document-text",
                    "relevance_score",
                    "index",
                    "time_elapsed_(seconds)",
                    "end_time",
                ]
                default_model = "rerank-english-v2.0"
            else:
                # No parser, column order or default model exists for this type, so
                # stop here instead of building a table from undefined settings.
                logger.info("Unsupported Cohere response object: %s", response)
                return None

            return self._resolve(
                args,
                kwargs,
                parsed_response,
                start_time,
                time_elapsed,
                response_type,
                table_column_order,
                default_model,
            )
        except Exception as e:
            # Logging is best effort: never let a resolution failure reach the caller.
            logger.warning("Failed to resolve request/response: %s", e)
        return None

    # These helper functions process the response from different endpoints of the Cohere API.
    # Since the response objects for different endpoints have different structures,
    # we need different logic to process them.

    def _resolve_generate_response(self, response: Response) -> list[dict[str, Any]]:
        """Build one row per generation in a ``Generations`` response."""
        return_list = []
        for _response in response:
            # Cohere's own helper on each generation returns a dict of response data,
            # including colored token_likelihoods markup
            _response_dict = _response._visualize_helper()
            try:
                _response_dict["token_likelihoods"] = wandb.Html(
                    _response_dict["token_likelihoods"]
                )
            except (KeyError, ValueError):
                # Deliberate best effort: keep the row even if the likelihoods are
                # missing or cannot be rendered as HTML.
                pass
            return_list.append(_response_dict)

        return return_list

    def _resolve_chat_response(self, response: Response) -> list[dict[str, Any]]:
        """Build a single row from the selected attributes of a ``Chat`` response."""
        return [
            subset_dict(
                response.__dict__,
                [
                    "response_id",
                    "generation_id",
                    "query",
                    "text",
                    "conversation_id",
                    "prompt",
                    "chatlog",
                    "preamble",
                ],
            )
        ]

    def _resolve_classify_response(self, response: Response) -> list[dict[str, Any]]:
        """Build one flattened row per classification in a ``Classifications`` response."""
        # The labels key is a dict returning the scores for the classification probability for each label provided
        # We flatten this nested dict for ease of consumption in the wandb UI
        return [flatten_dict(_response.__dict__) for _response in response]

    def _resolve_classify_kwargs(self, kwargs: dict[str, Any]) -> dict[str, Any]:
        """Split classify ``examples`` into ``example_texts`` and ``example_labels`` columns.

        A list of example objects renders poorly in the W&B UI, so each example's
        ``text`` and ``label`` are extracted into their own list-valued columns.
        The work is done on a shallow copy, so the caller's kwargs are left untouched.
        When no ``examples`` were passed, the copy is returned unchanged.

        :param kwargs: The keyword arguments passed to the classify call.
        :return: A new kwargs dictionary suitable for logging.
        """
        resolved = dict(kwargs)
        examples = resolved.pop("examples", None)
        if examples is None:
            return resolved
        resolved["example_texts"] = [example.text for example in examples]
        resolved["example_labels"] = [example.label for example in examples]
        return resolved

    def _resolve_summarize_response(self, response: Response) -> list[dict[str, Any]]:
        """Build a single row holding the id and summary of a ``SummarizeResponse``."""
        return [{"response_id": response.id, "summary": response.summary}]

    def _resolve_rerank_response(self, response: Response) -> list[dict[str, Any]]:
        """Aggregate every reranked document of a ``Reranking`` response into one row."""
        # The documents key contains a dict containing the content of the document which is at least "text"
        # We flatten this nested dict for ease of consumption in the wandb UI
        flattened_response_dicts = [
            flatten_dict(_response.__dict__) for _response in response
        ]
        # ReRank returns each document provided a top_n value so we aggregate into one view so users can paginate a row
        # As opposed to each row being one of the top_n responses
        return_dict = collect_common_keys(flattened_response_dicts)
        return_dict["id"] = response.id
        return [return_dict]

    def _resolve(
        self,
        args: Sequence[Any],
        kwargs: dict[str, Any],
        parsed_response: list[dict[str, Any]],
        start_time: float,
        time_elapsed: float,
        response_type: str,
        table_column_order: list[str],
        default_model: str,
    ) -> dict[str, Any]:
        """Merge the parsed response rows with request metadata and pack them into a ``wandb.Table``.

        Each row receives a shared ``query_id``, the request kwargs, timing columns and
        selected client settings. When a row does not name a ``model``,
        ``default_model`` is filled in.

        :param args: The arguments passed to the API client; ``args[0]`` is the client object.
        :param kwargs: The keyword arguments passed to the API client.
        :param parsed_response: The parsed response rows from the API.
        :param start_time: The start time of the API request, as a POSIX timestamp.
        :param time_elapsed: The time elapsed during the API request, in seconds.
        :param response_type: The type name of the API response; used as the result key.
        :param table_column_order: Columns to place first in the resulting table.
        :param default_model: The model name to use if not specified in the row.
        :return: ``{response_type: wandb.Table}``.
        """
        # One id per API call, shared by all rows produced from this call
        query_id = generate_id(length=16)
        # Args[0] is the client object where we can grab specific metadata about the underlying API status
        parsed_args = subset_dict(
            args[0].__dict__,
            ["api_version", "batch_size", "max_retries", "num_workers", "timeout"],
        )

        start_time_dt = datetime.fromtimestamp(start_time)
        end_time_dt = datetime.fromtimestamp(start_time + time_elapsed)

        timings = {
            "start_time": start_time_dt,
            "end_time": end_time_dt,
            "time_elapsed_(seconds)": time_elapsed,
        }

        packed_data = []
        for _parsed_response in parsed_response:
            # Later sources win on key collisions: response data overrides kwargs,
            # and client settings override everything else.
            _packed_dict = {
                "query_id": query_id,
                **kwargs,
                **_parsed_response,
                **timings,
                **parsed_args,
            }
            if "model" not in _packed_dict:
                _packed_dict["model"] = default_model
            packed_data.append(_packed_dict)

        columns, data = reorder_and_convert_dict_list_to_table(
            packed_data, table_column_order
        )

        request_response_table = wandb.Table(data=data, columns=columns)

        return {response_type: request_response_table}
