"""Centralized logging configuration for Maya server."""

import os
import sys
import json
import logging
import warnings
import traceback
from datetime import timezone
from loguru import logger

from utils import constants


LOG_LEVEL_PREFIXES = {
    "DEBUG": "DEBUG",
    "INFO": "INFO",
    "WARNING": "WARN",
    "ERROR": "ERROR",
    "CRITICAL": "CRITICAL",
    "EXCEPTION": "ERROR",
}


def log_filter(record):
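    """Loguru filter shared by every sink: drop known-noisy records (Daily SDK
    chatter, verbose third-party library logs) and keep everything else,
    always letting genuine error-level records through."""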
    # Check message content filters first, even for ERROR level logs
    msg = record["message"]
    if "Unable to send message" in msg and "before joining" in msg:
        return False  # Normal race condition - message sent before Daily join completes

    # Also check exception message if present
    if record.get("exception"):
        exc_info = record["exception"]
        if exc_info and exc_info.value:
            exc_msg = str(exc_info.value)
            if "Unable to send message" in exc_msg and "before joining" in exc_msg:
                return False

    # Daily SDK noise (filter BEFORE error passthrough — these are harmless at any level)
    if "Failed to send LogLine" in msg:
        return False
    if "ResponseCanceled" in msg or "ResponseCancelled" in msg:
        return False
    if "Metrics failed to get snapshot" in msg:
        return False
    if "daily_core" in msg:
        return False
    if "CallManagerEvent" in msg:
        return False
    if "MediasoupManager" in msg:
        return False

    # Allow all other error-level logs through (note: loguru's exception() emits at ERROR)
    if record["level"].name in ("ERROR", "CRITICAL", "EXCEPTION"):
        return True

    name = record.get("name", "")
    if name and name.startswith(("botocore", "boto3", "urllib3", "s3transfer")):
        return False
    if name and name.startswith("openai"):
        return False
    if name and name.startswith("httpx"):
        return False

    if "Empty audio frame received for STT service" in msg:
        return False
    if "Ignoring not RTVI message" in msg:
        return False
    if "User stopped speaking but no new aggregation received" in msg:
        return False
    if "Unclosed client session" in msg or "Unclosed connector" in msg:
        return False
    if "Task was destroyed but it is pending" in msg:
        return False
    if "Loading JSON file:" in msg:
        return False
    if "Changing event name" in msg:
        return False
    if "Request options:" in msg:
        return False
    return True


CONSOLE_FORMAT = "{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {name}:{function}:{line} | {message}"


def datadog_sink(message):
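    """Format a Loguru record as a single Datadog-friendly JSON line (newline-terminated)."""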
    record = message.record
    level = record["level"].name
    
    timestamp = record["time"].astimezone(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z"
    
    log_entry = {
        "timestamp": timestamp,
        "level": level,
        "logger": record["name"],
        "function": record["function"],
        "line": record["line"],
        "message": record["message"],
        "service": "maya-pipecat",
        "env": os.getenv("DD_ENV", os.getenv("ENVIRONMENT", "development")),
        "version": os.getenv("DD_VERSION", getattr(constants, 'APP_VERSION', "1.0.0")),
    }
    
    # Add all extra fields from logger.bind()
    for key, value in record["extra"].items():
        if value and key not in ("_maya_configured",):  # Skip falsy values and internal keys
            log_entry[key] = value
    
    if level in ("ERROR", "CRITICAL", "EXCEPTION"):
        log_entry["status"] = "error"
        log_entry["error.kind"] = level
        
        if record.get("exception"):
            exc_info = record["exception"]
            if exc_info:
                log_entry["error.type"] = exc_info.type.__name__ if exc_info.type else "Unknown"
                log_entry["error.message"] = str(exc_info.value) if exc_info.value else ""
                log_entry["error.stack"] = "".join(traceback.format_exception(exc_info.type, exc_info.value, exc_info.traceback))
    else:
        log_entry["status"] = "info"
    
    return json.dumps(log_entry, ensure_ascii=False) + "\n"


class DatadogFileSink:
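    """Writes newline-delimited JSON log records to a lazily opened file for the Datadog agent to tail."""
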
    def __init__(self, filepath):
        self.filepath = filepath
        self._file = None
    
    def _ensure_file(self):
        if self._file is None:
            os.makedirs(os.path.dirname(self.filepath), exist_ok=True)
            self._file = open(self.filepath, "a", encoding="utf-8")
    
    def write(self, message):
        json_line = datadog_sink(message)
        self._ensure_file()
        self._file.write(json_line)
        self._file.flush()
    
    def close(self):
        if self._file:
            self._file.close()


class DatadogStdoutSink:
    """Writes JSON-formatted logs to stdout for DD DaemonSet collection.
    DD auto-parses level, message, service from the JSON structure."""
    
    def write(self, message):
        json_line = datadog_sink(message)
        sys.stdout.write(json_line)
        sys.stdout.flush()


class InterceptHandler(logging.Handler):
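    """Forwards stdlib logging records to Loguru so third-party library logs
    share the same sinks, filters, and formatting."""
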
    def emit(self, record):
        try:
            level = logger.level(record.levelname).name
        except ValueError:
            level = record.levelno
        
        frame, depth = sys._getframe(6), 6
        while frame and frame.f_code.co_filename == logging.__file__:
            frame = frame.f_back
            depth += 1
        
        logger.opt(depth=depth, exception=record.exc_info).log(level, record.getMessage())


def log_startup_banner():
    banner = """
╔══════════════════════════════════════════════════════════════════╗
║                     MAYA PIPECAT SERVER                          ║
║                   Starting Application...                        ║
╚══════════════════════════════════════════════════════════════════╝
"""
    print(banner)
    
    logger.info("=" * 60)
    logger.info("MAYA PIPECAT SERVER - STARTUP INITIATED")
    logger.info("=" * 60)
    logger.info(f"STARTUP_CONFIG: env={os.getenv('ENVIRONMENT', 'development')} host={constants.SERVER_HOST} port={constants.SERVER_PORT}")
    logger.info(f"STARTUP_CONFIG: log_dir={constants.LOG_DIR} log_file={constants.LOG_FILE_PATH}")
    
    dd_enabled = "yes" if os.getenv("DD_API_KEY") else "no"
    logger.info(f"STARTUP_CONFIG: datadog_enabled={dd_enabled}")
    
    turso_configured = "yes" if constants.TURSO_DATABASE_URL else "no"
    logger.info(f"STARTUP_CONFIG: turso_configured={turso_configured}")
    
    logger.info("=" * 60)


def log_startup_complete(startup_time_ms: float = None):
    logger.info("=" * 60)
    logger.info("MAYA PIPECAT SERVER - STARTUP COMPLETE")
    if startup_time_ms is not None:
        logger.info(f"STARTUP_TIME: {startup_time_ms:.2f}ms")
    logger.info("Server is ready to accept connections")
    logger.info("=" * 60)


def log_shutdown():
    logger.info("=" * 60)
    logger.info("MAYA PIPECAT SERVER - SHUTDOWN INITIATED")
    logger.info("=" * 60)


def configure_logging():
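    """Configure all log sinks; safe to call more than once (guarded by a flag on the logger).

    Production (DD_AGENT_HOST set): JSON records go to stdout for Datadog collection.
    Local development: human-readable plain text goes to stdout at DEBUG level.
    In both cases a rotating plain-text file sink is added, and stdlib logging is
    intercepted so uvicorn/fastapi/pipecat/aiohttp logs flow through Loguru.
    """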
    if hasattr(logger, '_maya_configured'):
        return
    
    os.makedirs(constants.LOG_DIR, exist_ok=True)
    logger.remove()
    
    is_production = bool(os.getenv("DD_AGENT_HOST"))
    stdout_level = os.getenv("LOG_LEVEL", "INFO" if is_production else "DEBUG")
    
    if is_production:
        # Production: JSON to stdout (DD DaemonSet auto-parses level, message, service, etc.)
        json_stdout_sink = DatadogStdoutSink()
        logger.add(
            json_stdout_sink.write,
            level=stdout_level,
            filter=log_filter,
        )
    else:
        # Local dev: human-readable plain text to stdout
        logger.add(
            sys.stdout,
            level="DEBUG",
            format=CONSOLE_FORMAT,
            filter=log_filter,
            colorize=False,
            enqueue=False,
            backtrace=True,
            diagnose=True,
        )
    
    # Always write plain text to file for local debugging
    logger.add(
        constants.LOG_FILE_PATH,
        level="DEBUG",
        format=CONSOLE_FORMAT,
        filter=log_filter,
        rotation=constants.LOG_ROTATION_SIZE,
        retention=constants.LOG_RETENTION,
        compression="zip",
        enqueue=False,
        backtrace=True,
        diagnose=True,
    )
    
    warnings.filterwarnings("default")
    logging.captureWarnings(True)
    
    logging.basicConfig(handlers=[InterceptHandler()], level=logging.DEBUG, force=True)
    
    for logger_name in ["uvicorn", "uvicorn.error", "uvicorn.access", "fastapi", "pipecat", "aiohttp"]:
        logging.getLogger(logger_name).handlers = [InterceptHandler()]
        logging.getLogger(logger_name).propagate = False
    
    logger._maya_configured = True
    
    log_startup_banner()


def _configure_datadog_logging():
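    """Add an INFO-level JSON file sink for the Datadog agent to tail, falling back
    to a datadog/ subdirectory of LOG_DIR if /var/log/maya-pipecat is not writable."""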
    try:
        datadog_path = "/var/log/maya-pipecat/datadog.log"
        os.makedirs("/var/log/maya-pipecat", exist_ok=True)
        sink = DatadogFileSink(datadog_path)
        logger.add(
            sink.write,
            level="INFO",
            filter=log_filter,
        )
        logger.info(f"DATADOG_LOGGING: enabled=true path={datadog_path}")
    except PermissionError:
        datadog_log_dir = os.path.join(constants.LOG_DIR, "datadog")
        datadog_path = os.path.join(datadog_log_dir, "datadog.log")
        os.makedirs(datadog_log_dir, exist_ok=True)
        sink = DatadogFileSink(datadog_path)
        logger.add(
            sink.write,
            level="INFO",
            filter=log_filter,
        )
        logger.info(f"DATADOG_LOGGING: enabled=true path={datadog_path} (fallback)")


def get_logger(name: str = None, session_id: str = None, device_id: str = None):
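    """Return the Loguru logger, optionally with bound context fields.

    Bound values land in the record's extras and are forwarded to the Datadog
    JSON payload. Example (illustrative values):

        log = get_logger(name="session_manager", session_id="abc123")
        log.info("session ready")
    """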
    bound_logger = logger
    if name:
        bound_logger = bound_logger.bind(name=name)
    if session_id:
        bound_logger = bound_logger.bind(session_id=session_id)
    if device_id:
        bound_logger = bound_logger.bind(device_id=device_id)
    return bound_logger


def log_error(message: str, error: Exception = None, **context):
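    """Log a structured error line; when an exception is passed, its traceback is attached.

    Example (illustrative values):

        log_error("Failed to save session", error=exc, session_id="abc123")
    """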
    error_type = type(error).__name__ if error else "Unknown"
    error_msg = str(error) if error else ""
    
    full_message = f"ERROR: {message}"
    if error:
        full_message += f" | error_type={error_type} error_message={error_msg}"
    
    for key, value in context.items():
        full_message += f" {key}={value}"
    
    if error:
        logger.opt(exception=error).error(full_message)
    else:
        logger.error(full_message)


def log_session_event(event: str, session_id: str, device_id: str = None, **extra):
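    """Log a SESSION_<EVENT> line with session_id (and device_id, if given) bound as context."""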
    parts = [f"SESSION_{event.upper()}: session_id={session_id}"]
    if device_id:
        parts.append(f"device_id={device_id}")
    for key, value in extra.items():
        parts.append(f"{key}={value}")
    logger.bind(session_id=session_id, device_id=device_id or "").info(" ".join(parts))


def log_api_request(endpoint: str, method: str, device_id: str = None, **extra):
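    """Log an API_REQUEST line for an incoming request (endpoint, method, optional device_id and extras)."""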
    parts = [f"API_REQUEST: endpoint={endpoint} method={method}"]
    if device_id:
        parts.append(f"device_id={device_id}")
    for key, value in extra.items():
        parts.append(f"{key}={value}")
    logger.info(" ".join(parts))


def log_api_response(endpoint: str, status_code: int, duration_ms: float = None, **extra):
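    """Log an API_RESPONSE line with status code and optional duration in milliseconds."""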
    parts = [f"API_RESPONSE: endpoint={endpoint} status={status_code}"]
    if duration_ms is not None:
        parts.append(f"duration_ms={duration_ms:.2f}")
    for key, value in extra.items():
        parts.append(f"{key}={value}")
    logger.info(" ".join(parts))


def log_db_operation(operation: str, table: str = None, success: bool = True, duration_ms: float = None, **extra):
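    """Log a DB_<OPERATION> line: successes at DEBUG, failures at ERROR."""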
    status = "success" if success else "failed"
    parts = [f"DB_{operation.upper()}: status={status}"]
    if table:
        parts.append(f"table={table}")
    if duration_ms is not None:
        parts.append(f"duration_ms={duration_ms:.2f}")
    for key, value in extra.items():
        parts.append(f"{key}={value}")
    
    if success:
        logger.debug(" ".join(parts))
    else:
        logger.error(" ".join(parts))


def log_external_call(service: str, operation: str, success: bool = True, duration_ms: float = None, **extra):
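    """Log an EXTERNAL_<SERVICE> line for a call to an external service: successes at INFO, failures at ERROR."""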
    status = "success" if success else "failed"
    parts = [f"EXTERNAL_{service.upper()}: operation={operation} status={status}"]
    if duration_ms is not None:
        parts.append(f"duration_ms={duration_ms:.2f}")
    for key, value in extra.items():
        parts.append(f"{key}={value}")
    
    if success:
        logger.info(" ".join(parts))
    else:
        logger.error(" ".join(parts))
