"""A simple GraphQL client for sending queries and mutations.

Note: This was originally wandb/vendor/gql-0.2.0/wandb_gql/transport/requests.py
The only substantial change is to reuse a requests.Session object.
"""

from __future__ import annotations

from typing import Any, Callable

from wandb_gql.transport.http import HTTPTransport
from wandb_graphql.execution import ExecutionResult
from wandb_graphql.language import ast
from wandb_graphql.language.printer import print_ast

from wandb._analytics import tracked_func


class GraphQLSession(HTTPTransport):
    """GraphQL HTTP transport that reuses a single ``requests.Session``.

    Every call to :meth:`execute` is sent as a POST through the same session
    object, so the session's auth and proxy configuration apply to all
    requests made by this transport.
    """

    def __init__(
        self,
        url: str,
        auth: tuple[str, str] | Callable | None = None,
        use_json: bool = False,
        timeout: int | float | None = None,
        proxies: dict[str, str] | None = None,
        **kwargs: Any,
    ) -> None:
        """Set up a session for sending GraphQL queries and mutations.

        Args:
            url (str): The GraphQL URL
            auth (tuple or callable): Auth tuple or callable for Basic/Digest/Custom HTTP Auth
            use_json (bool): Send request body as JSON instead of form-urlencoded
            timeout (int, float): Specifies a default timeout for requests (Default: None)
            proxies (dict): Optional proxy mapping merged into the session's
                proxies; skipped when empty or None.
            **kwargs: Forwarded unchanged to ``HTTPTransport.__init__``.
        """
        # Imported lazily so that `requests` is only loaded when a session
        # is actually constructed.
        import requests

        super().__init__(url, **kwargs)
        self.session = requests.Session()
        if proxies:
            self.session.proxies.update(proxies)
        self.session.auth = auth
        self.default_timeout = timeout
        self.use_json = use_json

    def execute(
        self,
        document: ast.Node,
        variable_values: dict[str, Any] | None = None,
        timeout: int | float | None = None,
    ) -> ExecutionResult:
        """POST a GraphQL document to ``self.url`` and return its result.

        Args:
            document: Parsed GraphQL document; serialized with ``print_ast``.
            variable_values: Variables for the operation. ``None`` (or an
                empty mapping) is sent as ``{}``.
            timeout: Per-request timeout. Any falsy value (``None`` or ``0``)
                falls back to the default timeout given at construction.

        Returns:
            An ``ExecutionResult`` built from the response's ``"data"`` and
            ``"errors"`` entries.

        Raises:
            RuntimeError: If the decoded response body is not a JSON object,
                or contains neither ``"data"`` nor ``"errors"``.

        Any exception raised by ``raise_for_status()`` (non-success HTTP
        status) or by decoding the body as JSON propagates to the caller.
        """
        query_str = print_ast(document)
        payload = {"query": query_str, "variables": variable_values or {}}

        # requests encodes `json=` as a JSON body and `data=` as a form body.
        data_key = "json" if self.use_json else "data"

        # Copy so per-request additions never mutate the transport's headers.
        headers = self.headers.copy() if self.headers else {}

        # If we're tracking a calling python function, include it in the headers
        if func_info := tracked_func():
            headers.update(func_info.to_headers())

        post_args = {
            "headers": headers or None,
            "cookies": self.cookies,
            "timeout": timeout or self.default_timeout,
            data_key: payload,
        }
        response = self.session.post(self.url, **post_args)
        response.raise_for_status()

        result = response.json()
        # A body that is valid JSON but not an object (e.g. a list or a bare
        # string) has no "data"/"errors" keys; report it with the same error
        # as any other incompatible response instead of an AttributeError.
        if not isinstance(result, dict):
            raise RuntimeError(f"Received non-compatible response: {result}")

        data, errors = result.get("data"), result.get("errors")
        if data is None and errors is None:
            raise RuntimeError(f"Received non-compatible response: {result}")
        return ExecutionResult(data=data, errors=errors)
