import json
import os
from datetime import datetime, timedelta
from pathlib import Path

from wandb.errors import AuthenticationError

# Default location of the cached credentials file: credentials.json under
# .config/wandb in the current user's home directory ("~" is expanded here).
DEFAULT_WANDB_CREDENTIALS_FILE = Path(
    os.path.expanduser("~/.config/wandb/credentials.json")
)

# strftime/strptime pattern for the "expires_at" value stored with each set of
# credentials. The timestamps are naive and derived from datetime.utcnow().
_expires_at_fmt = "%Y-%m-%d %H:%M:%S"


def access_token(base_url: str, token_file: Path, credentials_file: Path) -> str:
    """Return an access token for the given server.

    When the credentials file is absent it is first created by exchanging the
    identity token stored in the token file. The credentials for ``base_url``
    are then loaded from the credentials file (which renews them if they have
    expired) and their access token is returned.

    Args:
        base_url (str): The base URL of the server
        token_file (pathlib.Path): The path to the file containing the
        identity token
        credentials_file (pathlib.Path): The path to file used to save
        temporary access tokens

    Returns:
        str: The access token
    """
    credentials_missing = not credentials_file.exists()
    if credentials_missing:
        # First use: bootstrap the credentials file from the identity token.
        _write_credentials_file(base_url, token_file, credentials_file)

    credentials = _fetch_credentials(base_url, token_file, credentials_file)
    token = credentials["access_token"]
    return token


def _write_credentials_file(base_url: str, token_file: Path, credentials_file: Path):
    """Obtain an access token from the server and write it to the credentials file.

    The file is written as ``{"credentials": {base_url: <credentials>}}`` and
    replaces any previous content. Missing parent directories are created, and
    the file is created with owner-only (0o600) permissions from the start so
    the token is never readable by other users, not even briefly.

    Args:
        base_url (str): The base URL of the server
        token_file (pathlib.Path): The path to the file containing the
        identity token
        credentials_file (pathlib.Path): The path to file used to save
        temporary access tokens
    """
    credentials = _create_access_token(base_url, token_file)
    data = {"credentials": {base_url: credentials}}

    # The directory holding the credentials file may not exist yet
    # (e.g. on a fresh machine); open() would fail without it.
    Path(credentials_file).parent.mkdir(parents=True, exist_ok=True)

    def _owner_only_opener(path, flags):
        # Create the file as 0o600 instead of tightening it after the fact.
        return os.open(path, flags, 0o600)

    with open(credentials_file, "w", opener=_owner_only_opener) as file:
        json.dump(data, file, indent=4)

    # The creation mode is ignored when the file already existed (and is
    # filtered by the umask), so enforce owner read/write explicitly as well.
    os.chmod(credentials_file, 0o600)


def _fetch_credentials(base_url: str, token_file: Path, credentials_file: Path) -> dict:
    """Fetch the access token from the credentials file.

    If the stored credentials for ``base_url`` are missing, have no usable
    "expires_at" value, or have expired, a new access token is requested from
    the server and saved to the credentials file. Entries stored for other
    base URLs are preserved.

    A credentials file that cannot be parsed as a JSON object (for example a
    file left empty by an interrupted write) is treated as holding no
    credentials and is regenerated, instead of failing on every call.

    Args:
        base_url (str): The base URL of the server
        token_file (pathlib.Path): The path to the file containing the
        identity token
        credentials_file (pathlib.Path): The path to file used to save
        temporary access tokens

    Returns:
        dict: The credentials including the access token.

    Raises:
        OSError: If the credentials file cannot be opened for reading or
        writing (including FileNotFoundError when it does not exist).
    """
    try:
        with open(credentials_file) as file:
            data = json.load(file)
    except ValueError:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors:
        # the cache content is unusable, so start over from an empty document.
        data = {}

    if not isinstance(data, dict):
        data = {}
    if not isinstance(data.get("credentials"), dict):
        data["credentials"] = {}

    creds = data["credentials"].get(base_url, {})
    if not isinstance(creds, dict):
        creds = {}

    # None means "no trustworthy expiry", which is handled like an expired token.
    expires_at = None
    if "expires_at" in creds:
        try:
            expires_at = datetime.strptime(creds["expires_at"], _expires_at_fmt)
        except (TypeError, ValueError):
            expires_at = None

    # Stored timestamps are naive UTC, so compare against naive UTC "now".
    if expires_at is None or expires_at <= datetime.utcnow():
        creds = _create_access_token(base_url, token_file)
        data["credentials"][base_url] = creds
        # Serialize before opening the file: opening with "w" truncates it, so
        # a serialization error must not be able to wipe the existing content.
        serialized = json.dumps(data, indent=4)
        with open(credentials_file, "w") as file:
            file.write(serialized)

    return creds


def _create_access_token(base_url: str, token_file: Path) -> dict:
    """Exchange an identity token for an access token from the server.

    The identity token is read from ``token_file`` and posted to
    ``<base_url>/oidc/token`` using the JWT bearer grant type. In the returned
    dict the server's relative "expires_in" (seconds) is replaced by an
    absolute "expires_at" timestamp (naive UTC, formatted with
    ``_expires_at_fmt``).

    Args:
        base_url (str): The base URL of the server. A trailing slash is
        tolerated.
        token_file (pathlib.Path): The path to the file containing the
        identity token

    Returns:
        dict: The access token and its expiration.

    Raises:
        FileNotFoundError: If the token file is not found.
        OSError: If there is an issue reading the token file.
        AuthenticationError: If the server fails to provide an access token:
        a non-200 status, a body that is not a JSON object, a missing
        "access_token", or a missing/invalid "expires_in".
    """
    import requests

    # Upper bound (seconds) for the token exchange; without a timeout
    # requests.post can block indefinitely on an unresponsive server.
    request_timeout_seconds = 30

    try:
        with open(token_file) as file:
            token = file.read().strip()
    except FileNotFoundError as e:
        raise FileNotFoundError(f"Identity token file not found: {token_file}") from e
    except OSError as e:
        raise OSError(
            f"Failed to read the identity token from file: {token_file}"
        ) from e

    # Strip trailing slashes so "https://host/" does not yield "//oidc/token".
    url = f"{base_url.rstrip('/')}/oidc/token"
    data = {
        "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
        "assertion": token,
    }
    headers = {"Content-Type": "application/x-www-form-urlencoded"}

    response = requests.post(
        url, data=data, headers=headers, timeout=request_timeout_seconds
    )

    if response.status_code != 200:
        raise AuthenticationError(
            f"Failed to retrieve access token: {response.status_code}, {response.text}"
        )

    # The messages below deliberately omit the response body: a 200 response
    # may contain a valid token that must not end up in logs.
    try:
        resp_json = response.json()
    except ValueError as e:
        raise AuthenticationError(
            "Failed to retrieve access token: response is not valid JSON"
        ) from e

    if not isinstance(resp_json, dict) or "access_token" not in resp_json:
        raise AuthenticationError(
            "Failed to retrieve access token: response contains no access token"
        )

    try:
        lifetime = timedelta(seconds=float(resp_json["expires_in"]))
        expires_at = datetime.utcnow() + lifetime
    except (KeyError, TypeError, ValueError, OverflowError) as e:
        raise AuthenticationError(
            "Failed to retrieve access token: missing or invalid 'expires_in'"
        ) from e

    resp_json["expires_at"] = expires_at.strftime(_expires_at_fmt)
    del resp_json["expires_in"]

    return resp_json
