import asyncio
import inspect
import os
import time
from pathlib import Path
from typing import (
    Any,
    Awaitable,
    Callable,
    Dict,
    List,
    NamedTuple,
    Optional,
    Tuple,
    TypeVar,
    Union,
)

# Use pyarrow for cloud storage access
import pyarrow.fs as pa_fs
from pydantic import Field, field_validator

from ray.llm._internal.common.base_pydantic import BaseModelExtended
from ray.llm._internal.common.observability.logging import get_logger

T = TypeVar("T")

logger = get_logger(__name__)


def is_remote_path(path: str) -> bool:
    """Check if the path is a remote path.

    Args:
        path: The path to check.

    Returns:
        True if the path is a remote path, False otherwise.
    """
    return path.startswith("s3://") or path.startswith("gs://")


class ExtraFiles(BaseModelExtended):
    bucket_uri: str
    destination_path: str


class CloudMirrorConfig(BaseModelExtended):
    """Unified mirror config for cloud storage (S3 or GCS).

    Args:
        bucket_uri: URI of the bucket (s3:// or gs://)
        extra_files: Additional files to download
    """

    bucket_uri: Optional[str] = None
    extra_files: List[ExtraFiles] = Field(default_factory=list)

    @field_validator("bucket_uri")
    @classmethod
    def check_uri_format(cls, value):
        if value is None:
            return value

        if not is_remote_path(value):
            raise ValueError(
                f'Got invalid value "{value}" for bucket_uri. '
                'Expected a URI that starts with "s3://" or "gs://".'
            )
        return value

    @property
    def storage_type(self) -> str:
        """Returns the storage type ('s3' or 'gcs') based on the URI prefix."""
        if self.bucket_uri is None:
            return None
        elif self.bucket_uri.startswith("s3://"):
            return "s3"
        elif self.bucket_uri.startswith("gs://"):
            return "gcs"
        return None


class LoraMirrorConfig(BaseModelExtended):
    lora_model_id: str
    bucket_uri: str
    max_total_tokens: Optional[int]
    sync_args: Optional[List[str]] = None

    @field_validator("bucket_uri")
    @classmethod
    def check_uri_format(cls, value):
        if value is None:
            return value

        if not is_remote_path(value):
            raise ValueError(
                f'Got invalid value "{value}" for bucket_uri. '
                'Expected a URI that starts with "s3://" or "gs://".'
            )
        return value

    @property
    def _bucket_name_and_path(self) -> str:
        for prefix in ["s3://", "gs://"]:
            if self.bucket_uri.startswith(prefix):
                return self.bucket_uri[len(prefix) :]
        return self.bucket_uri

    @property
    def bucket_name(self) -> str:
        return self._bucket_name_and_path.split("/")[0]

    @property
    def bucket_path(self) -> str:
        return "/".join(self._bucket_name_and_path.split("/")[1:])


class CloudFileSystem:
    """A unified interface for cloud file system operations using PyArrow.

    This class provides a simple interface for common operations on cloud storage
    systems (S3, GCS) using PyArrow's filesystem interface.
    """

    @staticmethod
    def get_fs_and_path(object_uri: str) -> Tuple[pa_fs.FileSystem, str]:
        """Get the appropriate filesystem and path from a URI.

        Args:
            object_uri: URI of the file (s3:// or gs://)
                If URI contains 'anonymous@', anonymous access is used.
                Example: s3://anonymous@bucket/path

        Returns:
            Tuple of (filesystem, path)
        """
        anonymous = False
        # Check for anonymous access pattern
        # e.g. s3://anonymous@bucket/path
        if "@" in object_uri:
            parts = object_uri.split("@", 1)
            # Check if the first part ends with "anonymous"
            if parts[0].endswith("anonymous"):
                anonymous = True
                # Remove the anonymous@ part, keeping the scheme
                scheme = parts[0].split("://")[0]
                object_uri = f"{scheme}://{parts[1]}"

        if object_uri.startswith("s3://"):
            fs = pa_fs.S3FileSystem(anonymous=anonymous)
            path = object_uri[5:]  # Remove "s3://"
        elif object_uri.startswith("gs://"):
            fs = pa_fs.GcsFileSystem(anonymous=anonymous)
            path = object_uri[5:]  # Remove "gs://"
        else:
            raise ValueError(f"Unsupported URI scheme: {object_uri}")

        return fs, path

    @staticmethod
    def get_file(
        object_uri: str, decode_as_utf_8: bool = True
    ) -> Optional[Union[str, bytes]]:
        """Download a file from cloud storage into memory.

        Args:
            object_uri: URI of the file (s3:// or gs://)
            decode_as_utf_8: If True, decode the file as UTF-8

        Returns:
            File contents as string or bytes, or None if file doesn't exist
        """
        try:
            fs, path = CloudFileSystem.get_fs_and_path(object_uri)

            # Check if file exists
            if not fs.get_file_info(path).type == pa_fs.FileType.File:
                logger.info(f"URI {object_uri} does not exist.")
                return None

            # Read file
            with fs.open_input_file(path) as f:
                body = f.read()

            if decode_as_utf_8:
                body = body.decode("utf-8")
            return body
        except Exception as e:
            logger.info(f"Error reading {object_uri}: {e}")
            return None

    @staticmethod
    def list_subfolders(folder_uri: str) -> List[str]:
        """List the immediate subfolders in a cloud directory.

        Args:
            folder_uri: URI of the directory (s3:// or gs://)

        Returns:
            List of subfolder names (without trailing slashes)
        """
        # Ensure that the folder_uri has a trailing slash.
        folder_uri = f"{folder_uri.rstrip('/')}/"

        try:
            fs, path = CloudFileSystem.get_fs_and_path(folder_uri)

            # List directory contents
            file_infos = fs.get_file_info(pa_fs.FileSelector(path, recursive=False))

            # Filter for directories and extract subfolder names
            subfolders = []
            for file_info in file_infos:
                if file_info.type == pa_fs.FileType.Directory:
                    # Extract just the subfolder name without the full path
                    subfolder = os.path.basename(file_info.path.rstrip("/"))
                    subfolders.append(subfolder)

            return subfolders
        except Exception as e:
            logger.info(f"Error listing subfolders in {folder_uri}: {e}")
            return []

    @staticmethod
    def download_files(
        path: str,
        bucket_uri: str,
        substrings_to_include: Optional[List[str]] = None,
    ) -> None:
        """Download files from cloud storage to a local directory.

        Args:
            path: Local directory where files will be downloaded
            bucket_uri: URI of cloud directory
            substrings_to_include: Only include files containing these substrings
        """
        try:
            fs, source_path = CloudFileSystem.get_fs_and_path(bucket_uri)

            # Ensure the destination directory exists
            os.makedirs(path, exist_ok=True)

            # List all files in the bucket
            file_selector = pa_fs.FileSelector(source_path, recursive=True)
            file_infos = fs.get_file_info(file_selector)

            # Download each file
            for file_info in file_infos:
                if file_info.type != pa_fs.FileType.File:
                    continue

                # Get relative path from source prefix
                rel_path = file_info.path[len(source_path) :].lstrip("/")

                # Check if file matches substring filters
                if substrings_to_include:
                    if not any(
                        substring in rel_path for substring in substrings_to_include
                    ):
                        continue

                # Create destination directory if needed
                if "/" in rel_path:
                    dest_dir = os.path.join(path, os.path.dirname(rel_path))
                    os.makedirs(dest_dir, exist_ok=True)

                # Download the file
                dest_path = os.path.join(path, rel_path)
                with fs.open_input_file(file_info.path) as source_file:
                    with open(dest_path, "wb") as dest_file:
                        dest_file.write(source_file.read())

        except Exception as e:
            logger.exception(f"Error downloading files from {bucket_uri}: {e}")
            raise

    @staticmethod
    def download_model(
        destination_path: str, bucket_uri: str, tokenizer_only: bool
    ) -> None:
        """Download a model from cloud storage.

        This downloads a model in the format expected by the HuggingFace transformers
        library.

        Args:
            destination_path: Path where the model will be stored
            bucket_uri: URI of the cloud directory containing the model
            tokenizer_only: If True, only download tokenizer-related files
        """
        try:
            fs, source_path = CloudFileSystem.get_fs_and_path(bucket_uri)

            # Check for hash file
            hash_path = os.path.join(source_path, "hash")
            hash_info = fs.get_file_info(hash_path)

            if hash_info.type == pa_fs.FileType.File:
                # Download and read hash file
                with fs.open_input_file(hash_path) as f:
                    f_hash = f.read().decode("utf-8").strip()
                logger.info(
                    f"Detected hash file in bucket {bucket_uri}. "
                    f"Using {f_hash} as the hash."
                )
            else:
                f_hash = "0000000000000000000000000000000000000000"
                logger.info(
                    f"Hash file does not exist in bucket {bucket_uri}. "
                    f"Using {f_hash} as the hash."
                )

            # Write hash to refs/main
            main_dir = os.path.join(destination_path, "refs")
            os.makedirs(main_dir, exist_ok=True)
            with open(os.path.join(main_dir, "main"), "w") as f:
                f.write(f_hash)

            # Create destination directory
            destination_dir = os.path.join(destination_path, "snapshots", f_hash)
            os.makedirs(destination_dir, exist_ok=True)

            logger.info(f'Downloading model files to directory "{destination_dir}".')

            # Download files
            tokenizer_file_substrings = (
                ["tokenizer", "config.json"] if tokenizer_only else []
            )
            CloudFileSystem.download_files(
                path=destination_dir,
                bucket_uri=bucket_uri,
                substrings_to_include=tokenizer_file_substrings,
            )

        except Exception as e:
            logger.exception(f"Error downloading model from {bucket_uri}: {e}")
            raise

    @staticmethod
    def upload_files(
        local_path: str,
        bucket_uri: str,
    ) -> None:
        """Upload files to cloud storage.

        Args:
            local_path: The local path of the files to upload.
            bucket_uri: The bucket uri to upload the files to, must start with `s3://` or `gs://`.
        """
        try:
            fs, dest_path = CloudFileSystem.get_fs_and_path(bucket_uri)

            pa_fs.copy_files(
                source=local_path,
                destination=dest_path,
                source_filesystem=pa_fs.LocalFileSystem(),
                destination_filesystem=fs,
            )
        except Exception as e:
            logger.exception(f"Error uploading files to {bucket_uri}: {e}")
            raise

    @staticmethod
    def upload_model(
        local_path: str,
        bucket_uri: str,
    ) -> None:
        """Upload a model to cloud storage.

        Args:
            local_path: The local path of the model.
            bucket_uri: The bucket uri to upload the model to, must start with `s3://` or `gs://`.
        """
        try:
            # If refs/main exists, upload as hash, and treat snapshots/<hash> as the model.
            # Otherwise, this is a custom model, we do not assume folder hierarchy.
            refs_main = Path(local_path, "refs", "main")
            if refs_main.exists():
                model_path = os.path.join(
                    local_path, "snapshots", refs_main.read_text().strip()
                )
                CloudFileSystem.upload_files(
                    local_path=model_path, bucket_uri=bucket_uri
                )
                CloudFileSystem.upload_files(
                    local_path=str(refs_main),
                    bucket_uri=os.path.join(bucket_uri, "hash"),
                )
            else:
                CloudFileSystem.upload_files(
                    local_path=local_path, bucket_uri=bucket_uri
                )
            logger.info(f"Uploaded model files to {bucket_uri}.")
        except Exception as e:
            logger.exception(f"Error uploading model to {bucket_uri}: {e}")
            raise


class _CacheEntry(NamedTuple):
    value: Any
    expire_time: Optional[float]


class CloudObjectCache:
    """A cache that works with both sync and async fetch functions.

    The purpose of this data structure is to cache the result of a function call
    usually used to fetch a value from a cloud object store.

    The idea is this:
    - Cloud operations are expensive
    - In LoRA specifically, we would fetch remote storage to download the model weights
    at each request.
    - If the same model is requested many times, we don't want to inflate the time to first token.
    - We control the cache via not only the least recently used eviction policy, but also
    by expiring cache entries after a certain time.
    - If the object is missing, we cache the missing status for a small duration while if
    the object exists, we cache the object for a longer duration.
    """

    def __init__(
        self,
        max_size: int,
        fetch_fn: Union[Callable[[str], Any], Callable[[str], Awaitable[Any]]],
        missing_expire_seconds: Optional[int] = None,
        exists_expire_seconds: Optional[int] = None,
        missing_object_value: Any = object(),
    ):
        """Initialize the cache.

        Args:
            max_size: Maximum number of items to store in cache
            fetch_fn: Function to fetch values (can be sync or async)
            missing_expire_seconds: How long to cache missing objects (None for no expiration)
            exists_expire_seconds: How long to cache existing objects (None for no expiration)
        """
        self._cache: Dict[str, _CacheEntry] = {}
        self._max_size = max_size
        self._fetch_fn = fetch_fn
        self._missing_expire_seconds = missing_expire_seconds
        self._exists_expire_seconds = exists_expire_seconds
        self._is_async = inspect.iscoroutinefunction(fetch_fn) or (
            callable(fetch_fn) and inspect.iscoroutinefunction(fetch_fn.__call__)
        )
        self._missing_object_value = missing_object_value
        # Lock for thread-safe cache access
        self._lock = asyncio.Lock()

    async def aget(self, key: str) -> Any:
        """Async get value from cache or fetch it if needed."""

        if not self._is_async:
            raise ValueError("Cannot use async get() with sync fetch function")

        async with self._lock:
            value, should_fetch = self._check_cache(key)
            if not should_fetch:
                return value

            # Fetch new value
            value = await self._fetch_fn(key)
            self._update_cache(key, value)
            return value

    def get(self, key: str) -> Any:
        """Sync get value from cache or fetch it if needed."""
        if self._is_async:
            raise ValueError("Cannot use sync get() with async fetch function")

        # For sync access, we use a simple check-then-act pattern
        # This is safe because sync functions are not used in async context
        value, should_fetch = self._check_cache(key)
        if not should_fetch:
            return value

        # Fetch new value
        value = self._fetch_fn(key)
        self._update_cache(key, value)
        return value

    def _check_cache(self, key: str) -> tuple[Any, bool]:
        """Check if key exists in cache and is valid.

        Returns:
            Tuple of (value, should_fetch)
            where should_fetch is True if we need to fetch a new value
        """
        now = time.monotonic()

        if key in self._cache:
            value, expire_time = self._cache[key]
            if expire_time is None or now < expire_time:
                return value, False

        return None, True

    def _update_cache(self, key: str, value: Any) -> None:
        """Update cache with new value."""
        now = time.monotonic()

        # Calculate expiration
        expire_time = None
        if (
            self._missing_expire_seconds is not None
            or self._exists_expire_seconds is not None
        ):
            if value is self._missing_object_value:
                expire_time = (
                    now + self._missing_expire_seconds
                    if self._missing_expire_seconds
                    else None
                )
            else:
                expire_time = (
                    now + self._exists_expire_seconds
                    if self._exists_expire_seconds
                    else None
                )

        # Enforce size limit by removing oldest entry if needed
        # This is an O(n) operation but it's fine since the cache size is usually small.
        if len(self._cache) >= self._max_size:
            oldest_key = min(
                self._cache, key=lambda k: self._cache[k].expire_time or float("inf")
            )
            del self._cache[oldest_key]

        self._cache[key] = _CacheEntry(value, expire_time)

    def __len__(self) -> int:
        return len(self._cache)


class CloudModelAccessor:
    """Unified accessor for models stored in cloud storage (S3 or GCS).

    Args:
        model_id: The model id to download or upload.
        mirror_config: The mirror config for the model.
    """

    def __init__(self, model_id: str, mirror_config: CloudMirrorConfig):
        self.model_id = model_id
        self.mirror_config = mirror_config

    def _get_lock_path(self, suffix: str = "") -> Path:
        return Path(
            "~", f"{self.model_id.replace('/', '--')}{suffix}.lock"
        ).expanduser()

    def _get_model_path(self) -> Path:
        if Path(self.model_id).exists():
            return Path(self.model_id)
        # Delayed import to avoid circular dependencies
        from transformers.utils.hub import TRANSFORMERS_CACHE

        return Path(
            TRANSFORMERS_CACHE, f"models--{self.model_id.replace('/', '--')}"
        ).expanduser()


def remote_object_cache(
    max_size: int,
    missing_expire_seconds: Optional[int] = None,
    exists_expire_seconds: Optional[int] = None,
    missing_object_value: Any = None,
) -> Callable[[Callable[..., T]], Callable[..., T]]:
    """A decorator that provides async caching using CloudObjectCache.

    This is a direct replacement for the remote_object_cache/cachetools combination,
    using CloudObjectCache internally to maintain cache state.

    Args:
        max_size: Maximum number of items to store in cache
        missing_expire_seconds: How long to cache missing objects
        exists_expire_seconds: How long to cache existing objects
        missing_object_value: Value to use for missing objects
    """

    def decorator(func: Callable[..., T]) -> Callable[..., T]:
        # Create a single cache instance for this function
        cache = CloudObjectCache(
            max_size=max_size,
            fetch_fn=func,
            missing_expire_seconds=missing_expire_seconds,
            exists_expire_seconds=exists_expire_seconds,
            missing_object_value=missing_object_value,
        )

        async def wrapper(*args, **kwargs):
            # Extract the key from either first positional arg or object_uri kwarg
            key = args[0] if args else kwargs.get("object_uri")
            return await cache.aget(key)

        return wrapper

    return decorator
