import sys

from google import genai

from ddtrace import config
from ddtrace.contrib.internal.google_genai._utils import GoogleGenAIAsyncStreamHandler
from ddtrace.contrib.internal.google_genai._utils import GoogleGenAIStreamHandler
from ddtrace.contrib.internal.trace_utils import unwrap
from ddtrace.contrib.internal.trace_utils import wrap
from ddtrace.llmobs._integrations import GoogleGenAIIntegration
from ddtrace.llmobs._integrations.base_stream_handler import make_traced_stream
from ddtrace.llmobs._integrations.google_utils import extract_provider_and_model_name


# Register an (initially empty) settings entry for this integration; patch()
# reads it back as ``config.google_genai`` when building the integration.
config._add("google_genai", dict())


def _supported_versions():
    return {"google.genai": ">=1.21.1"}


def get_version() -> str:
    """Return ``genai.__version__``, or an empty string when the attribute is missing."""
    try:
        return genai.__version__
    except AttributeError:
        return ""


def traced_generate(func, instance, args, kwargs):
    """Trace a synchronous ``generate_content`` call.

    Opens a span named ``<ClassName>.<method>`` tagged with the provider and
    model extracted from ``kwargs``, invokes the wrapped function, and always
    calls ``llmobs_set_tags`` with the request and the response (``None`` when
    the call failed).

    Returns whatever the wrapped function returns. Any exception it raises is
    recorded on the span and re-raised unchanged.
    """
    integration = genai._datadog_integration
    provider_name, model_name = extract_provider_and_model_name(kwargs=kwargs)
    with integration.trace(
        "%s.%s" % (instance.__class__.__name__, func.__name__),
        provider=provider_name,
        model=model_name,
        submit_to_llmobs=True,
    ) as span:
        resp = None
        try:
            resp = func(*args, **kwargs)
            return resp
        except Exception:
            # The ``finally`` below runs before the ``with`` block exits, so any
            # error handling done by the span's context-manager exit comes too
            # late for llmobs_set_tags. Record the error now, as the streaming
            # wrappers do.
            span.set_exc_info(*sys.exc_info())
            raise
        finally:
            integration.llmobs_set_tags(span, args=args, kwargs=kwargs, response=resp, operation="llm")


async def traced_async_generate(func, instance, args, kwargs):
    """Trace an asynchronous ``generate_content`` call.

    Opens a span named ``<ClassName>.<method>`` tagged with the provider and
    model extracted from ``kwargs``, awaits the wrapped coroutine, and always
    calls ``llmobs_set_tags`` with the request and the response (``None`` when
    the call failed).

    Returns the awaited result. Any exception raised is recorded on the span
    and re-raised unchanged.
    """
    integration = genai._datadog_integration
    provider_name, model_name = extract_provider_and_model_name(kwargs=kwargs)
    with integration.trace(
        "%s.%s" % (instance.__class__.__name__, func.__name__),
        provider=provider_name,
        model=model_name,
        submit_to_llmobs=True,
    ) as span:
        resp = None
        try:
            resp = await func(*args, **kwargs)
            return resp
        except Exception:
            # The ``finally`` below runs before the ``with`` block exits, so any
            # error handling done by the span's context-manager exit comes too
            # late for llmobs_set_tags. Record the error now, as the streaming
            # wrappers do.
            span.set_exc_info(*sys.exc_info())
            raise
        finally:
            integration.llmobs_set_tags(span, args=args, kwargs=kwargs, response=resp, operation="llm")


def traced_generate_stream(func, instance, args, kwargs):
    """Trace a synchronous ``generate_content_stream`` call.

    The span is started here but finished here only on failure. On success the
    result of the wrapped call is passed to ``make_traced_stream`` together
    with a ``GoogleGenAIStreamHandler`` that holds the span. Presumably that
    handler tags and finishes the span once the stream is consumed (TODO
    confirm in ``_utils``).

    If the wrapped call, or building the traced stream, raises: the error is
    recorded on the span, ``llmobs_set_tags`` is called with
    ``response=None``, the span is finished, and the exception is re-raised.
    """
    integration = genai._datadog_integration
    provider_name, model_name = extract_provider_and_model_name(kwargs=kwargs)
    span_name = "%s.%s" % (instance.__class__.__name__, func.__name__)
    span = integration.trace(span_name, provider=provider_name, model=model_name, submit_to_llmobs=True)
    try:
        stream = func(*args, **kwargs)
        handler = GoogleGenAIStreamHandler(integration, span, args, kwargs)
        return make_traced_stream(stream, handler)
    except Exception:
        # No ``with`` block owns this span, so close it out by hand on failure.
        span.set_exc_info(*sys.exc_info())
        integration.llmobs_set_tags(span, args=args, kwargs=kwargs, response=None, operation="llm")
        span.finish()
        raise


async def traced_async_generate_stream(func, instance, args, kwargs):
    """Trace an asynchronous ``generate_content_stream`` call.

    The stream object is obtained by awaiting the wrapped coroutine. It is then
    passed to ``make_traced_stream`` together with a
    ``GoogleGenAIAsyncStreamHandler`` that holds the span. The span is finished
    here only on failure; presumably the handler tags and finishes it once the
    stream is consumed (TODO confirm in ``_utils``).

    If awaiting the call, or building the traced stream, raises: the error is
    recorded on the span, ``llmobs_set_tags`` is called with
    ``response=None``, the span is finished, and the exception is re-raised.
    """
    integration = genai._datadog_integration
    provider_name, model_name = extract_provider_and_model_name(kwargs=kwargs)
    span_name = "%s.%s" % (instance.__class__.__name__, func.__name__)
    span = integration.trace(span_name, provider=provider_name, model=model_name, submit_to_llmobs=True)
    try:
        stream = await func(*args, **kwargs)
        handler = GoogleGenAIAsyncStreamHandler(integration, span, args, kwargs)
        return make_traced_stream(stream, handler)
    except Exception:
        # No ``with`` block owns this span, so close it out by hand on failure.
        span.set_exc_info(*sys.exc_info())
        integration.llmobs_set_tags(span, args=args, kwargs=kwargs, response=None, operation="llm")
        span.finish()
        raise


def traced_embed_content(func, instance, args, kwargs):
    """Trace a synchronous ``embed_content`` call.

    Opens a span named ``<ClassName>.<method>`` tagged with the provider and
    model extracted from ``kwargs``, invokes the wrapped function, and always
    calls ``llmobs_set_tags`` with ``operation="embedding"`` and the response
    (``None`` when the call failed).

    Returns whatever the wrapped function returns. Any exception it raises is
    recorded on the span and re-raised unchanged.
    """
    integration = genai._datadog_integration
    provider_name, model_name = extract_provider_and_model_name(kwargs=kwargs)
    with integration.trace(
        "%s.%s" % (instance.__class__.__name__, func.__name__),
        provider=provider_name,
        model=model_name,
        submit_to_llmobs=True,
    ) as span:
        resp = None
        try:
            resp = func(*args, **kwargs)
            return resp
        except Exception:
            # The ``finally`` below runs before the ``with`` block exits, so any
            # error handling done by the span's context-manager exit comes too
            # late for llmobs_set_tags. Record the error now, as the streaming
            # wrappers do.
            span.set_exc_info(*sys.exc_info())
            raise
        finally:
            integration.llmobs_set_tags(span, args=args, kwargs=kwargs, response=resp, operation="embedding")


async def traced_async_embed_content(func, instance, args, kwargs):
    """Trace an asynchronous ``embed_content`` call.

    Opens a span named ``<ClassName>.<method>`` tagged with the provider and
    model extracted from ``kwargs``, awaits the wrapped coroutine, and always
    calls ``llmobs_set_tags`` with ``operation="embedding"`` and the response
    (``None`` when the call failed).

    Returns the awaited result. Any exception raised is recorded on the span
    and re-raised unchanged.
    """
    integration = genai._datadog_integration
    provider_name, model_name = extract_provider_and_model_name(kwargs=kwargs)
    with integration.trace(
        "%s.%s" % (instance.__class__.__name__, func.__name__),
        provider=provider_name,
        model=model_name,
        submit_to_llmobs=True,
    ) as span:
        resp = None
        try:
            resp = await func(*args, **kwargs)
            return resp
        except Exception:
            # The ``finally`` below runs before the ``with`` block exits, so any
            # error handling done by the span's context-manager exit comes too
            # late for llmobs_set_tags. Record the error now, as the streaming
            # wrappers do.
            span.set_exc_info(*sys.exc_info())
            raise
        finally:
            integration.llmobs_set_tags(span, args=args, kwargs=kwargs, response=resp, operation="embedding")


def patch():
    """Instrument the ``google.genai`` model methods.

    Does nothing if ``genai._datadog_patch`` is already truthy. Otherwise it:

    - sets that flag;
    - attaches a ``GoogleGenAIIntegration`` built from ``config.google_genai``
      as ``genai._datadog_integration``, where the ``traced_*`` wrappers look
      it up;
    - wraps the sync and async ``generate_content``,
      ``generate_content_stream`` and ``embed_content`` methods.
    """
    if getattr(genai, "_datadog_patch", False):
        return

    genai._datadog_patch = True
    genai._datadog_integration = GoogleGenAIIntegration(integration_config=config.google_genai)

    # (attribute path inside google.genai, tracing wrapper) pairs, applied in order.
    targets = (
        ("models.Models.generate_content", traced_generate),
        ("models.Models.generate_content_stream", traced_generate_stream),
        ("models.AsyncModels.generate_content", traced_async_generate),
        ("models.AsyncModels.generate_content_stream", traced_async_generate_stream),
        ("models.Models.embed_content", traced_embed_content),
        ("models.AsyncModels.embed_content", traced_async_embed_content),
    )
    for target, wrapper in targets:
        wrap("google.genai", target, wrapper)


def unpatch():
    """Remove the instrumentation installed by ``patch()``.

    Does nothing unless ``genai._datadog_patch`` is truthy. Otherwise it clears
    the flag, unwraps the six model methods, and deletes
    ``genai._datadog_integration``.
    """
    if not getattr(genai, "_datadog_patch", False):
        return

    genai._datadog_patch = False

    models = genai.models
    # (owning class, method name) pairs, unwrapped in the same order patch() wrapped them.
    wrapped_methods = (
        (models.Models, "generate_content"),
        (models.Models, "generate_content_stream"),
        (models.AsyncModels, "generate_content"),
        (models.AsyncModels, "generate_content_stream"),
        (models.Models, "embed_content"),
        (models.AsyncModels, "embed_content"),
    )
    for owner, method_name in wrapped_methods:
        unwrap(owner, method_name)

    del genai._datadog_integration
