# Copyright The Lightning AI team.
# Licensed under the Apache License, Version 2.0 (the "License");
#     http://www.apache.org/licenses/LICENSE-2.0
#
import glob
import os.path
import re
import warnings
from collections.abc import Sequence
from pprint import pprint
from typing import Union

REQUIREMENT_ROOT = "requirements.txt"
# Plain-text requirement files under ./requirements. With recursive=True, "**" also matches zero
# directories, so the second pattern re-finds the top-level files; duplicates are removed below.
REQUIREMENT_FILES_ALL: list = glob.glob(os.path.join("requirements", "*.txt"))
REQUIREMENT_FILES_ALL += glob.glob(os.path.join("requirements", "**", "*.txt"), recursive=True)
# NOTE(review): without recursive=True, "**" behaves like "*", so this only matches pyproject.toml files
# exactly one directory below the CWD (the root ./pyproject.toml is not included) — confirm intended.
REQUIREMENT_FILES_ALL += glob.glob(os.path.join("**", "pyproject.toml"))
if os.path.isfile(REQUIREMENT_ROOT):
    REQUIREMENT_FILES_ALL += [REQUIREMENT_ROOT]
# Drop duplicates while keeping discovery order, so each file is processed only once.
REQUIREMENT_FILES_ALL = list(dict.fromkeys(REQUIREMENT_FILES_ALL))


def prune_packages_in_requirements(
    packages: Union[str, Sequence[str]], req_files: Union[str, Sequence[str]] = REQUIREMENT_FILES_ALL
) -> None:
    """Drop the given package(s) from each of the given requirement files.

    Both arguments accept either a single string or a sequence of strings; a lone string is
    treated as a one-element list. Every file is handed to ``_prune_packages`` in turn.

    Args:
        packages: A package name or list of package names to remove.
        req_files: A path or list of paths to requirement files to process.

    """
    package_list = [packages] if isinstance(packages, str) else packages
    file_list = [req_files] if isinstance(req_files, str) else req_files
    for path in file_list:
        _prune_packages(path, package_list)


def _prune_packages(req_file: str, packages: Sequence[str]) -> None:
    """Remove all occurrences of the given packages (by line prefix) from a requirements file.

    Args:
        req_file: Path to a requirements file.
        packages: Package names to remove. Lines starting with any of these names will be dropped.

    """
    with open(req_file) as fp:
        lines = fp.readlines()

    if isinstance(packages, str):
        packages = [packages]
    for pkg in packages:
        lines = [ln for ln in lines if not ln.startswith(pkg)]
    pprint(lines)

    with open(req_file, "w") as fp:
        fp.writelines(lines)


def _replace_min_req_in_txt(req_file: str) -> None:
    """Replace all occurrences of '>=' with '==' in a plain text requirements file.

    Args:
        req_file: Path to the requirements.txt-like file to update.

    """
    with open(req_file) as fopen:
        req = fopen.read().replace(">=", "==")
    with open(req_file, "w") as fw:
        fw.write(req)


def _replace_min_req_in_pyproject_toml(proj_file: str = "pyproject.toml") -> None:
    """Replace '>=' with '==' in the [project.dependencies] section of a standard pyproject.toml.

    Only the version-specifier part of each dependency string is changed. An environment marker
    after ``;`` (e.g. ``python_version >= "3.9"``) is kept as-is, since pinning it would restrict
    the dependency to a single interpreter/platform instead of pinning the package version.
    If there are no dependencies, the file is not rewritten.

    Preserves formatting and comments using tomlkit.

    Args:
        proj_file: Path to the pyproject.toml file.

    """
    import tomlkit

    # Load and parse the existing pyproject.toml
    with open(proj_file, encoding="utf-8") as f:
        content = f.read()
    doc = tomlkit.parse(content)

    # todo: consider also replace extras in [dependency-groups] -> extras = [...]
    deps = doc.get("project", {}).get("dependencies")
    if not deps:
        return

    # Replace '>=version' with '==version' in each dependency's specifier (markers untouched)
    for i, req in enumerate(deps):
        spec, sep, marker = req.partition(";")
        deps[i] = spec.replace(">=", "==") + sep + marker

    # Dump back out, preserving layout
    with open(proj_file, "w", encoding="utf-8") as f:
        f.write(tomlkit.dumps(doc))


def replace_oldest_version(req_files: Union[str, Sequence[str]] = REQUIREMENT_FILES_ALL) -> None:
    """Convert minimal version specifiers (>=) to pinned ones (==) in the given requirement files.

    Supports plain ``*.txt`` requirements and files whose basename is ``pyproject.toml``.
    Any other file is skipped with a ``UserWarning``.

    Args:
        req_files: A path or list of paths to requirement files to process. Defaults to
            ``REQUIREMENT_FILES_ALL``, which is collected when the module is imported.

    """
    if isinstance(req_files, str):
        req_files = [req_files]
    for fname in req_files:
        if fname.endswith(".txt"):
            _replace_min_req_in_txt(fname)
        elif os.path.basename(fname) == "pyproject.toml":
            _replace_min_req_in_pyproject_toml(fname)
        else:
            warnings.warn(
                # trailing space: the two literals are concatenated into one sentence pair
                "Only *.txt with plain list of requirements or standard pyproject.toml are supported. "
                f"Provided '{fname}' is not supported.",
                UserWarning,
                stacklevel=2,
            )


def _replace_package_name_in_txt(req_file: str, old_package: str, new_package: str) -> None:
    """Rename a package in a plain text requirements file, preserving version specifiers and markers.

    Args:
        req_file: Path to the requirements.txt-like file to update.
        old_package: The original package name to replace.
        new_package: The new package name to use.

    """
    # load file
    with open(req_file) as fopen:
        requirements = fopen.readlines()
    # replace all occurrences
    for i, req in enumerate(requirements):
        requirements[i] = re.sub(r"^" + re.escape(old_package) + r"(?=[ <=>#]|$)", new_package, req)
    # save file
    with open(req_file, "w") as fw:
        fw.writelines(requirements)


def _replace_package_name_in_pyproject_toml(proj_file: str, old_package: str, new_package: str) -> None:
    """Rename a package in the [project.dependencies] section of a standard pyproject.toml, preserving constraints.

    A dependency is renamed only when it starts with ``old_package`` followed by a requirement
    separator (whitespace, a comparison operator, ``!``, ``~``, extras ``[``, ``(``, marker ``;``,
    URL ``@``) or the end of the string. So ``torch`` matches ``torch[cuda]~=2.1`` but not
    ``torchvision``. If there are no dependencies, the file is not rewritten.

    Args:
        proj_file: Path to the pyproject.toml file.
        old_package: The original package name to replace.
        new_package: The new package name to use (inserted literally).

    """
    import tomlkit

    # Load and parse the existing pyproject.toml
    with open(proj_file, encoding="utf-8") as f:
        content = f.read()
    doc = tomlkit.parse(content)

    # todo: consider also replace extras in [dependency-groups] -> extras = [...]
    deps = doc.get("project", {}).get("dependencies")
    if not deps:
        return

    pattern = re.compile(r"^" + re.escape(old_package) + r"(?=[\s<=>!~;\[(@#]|$)")
    # Swap the leading package name in each dependency; the callable replacement keeps new_package literal
    for i, req in enumerate(deps):
        deps[i] = pattern.sub(lambda _match: new_package, req)

    # Dump back out, preserving layout
    with open(proj_file, "w", encoding="utf-8") as f:
        f.write(tomlkit.dumps(doc))


def replace_package_in_requirements(
    old_package: str, new_package: str, req_files: Union[str, Sequence[str]] = REQUIREMENT_FILES_ALL
) -> None:
    """Rename a package across multiple requirement files while keeping version constraints intact.

    Supports plain ``*.txt`` requirements and files whose basename is ``pyproject.toml``.
    Any other file is skipped with a ``UserWarning``.

    Args:
        old_package: The original package name to replace.
        new_package: The new package name to use.
        req_files: A path or list of paths to requirement files to process. Defaults to
            ``REQUIREMENT_FILES_ALL``, which is collected when the module is imported.

    """
    if isinstance(req_files, str):
        req_files = [req_files]
    for fname in req_files:
        if fname.endswith(".txt"):
            _replace_package_name_in_txt(fname, old_package, new_package)
        elif os.path.basename(fname) == "pyproject.toml":
            _replace_package_name_in_pyproject_toml(fname, old_package, new_package)
        else:
            warnings.warn(
                # trailing space: the two literals are concatenated into one sentence pair
                "Only *.txt with plain list of requirements or standard pyproject.toml are supported. "
                f"Provided '{fname}' is not supported.",
                UserWarning,
                stacklevel=2,
            )
