"""Public API: files."""

import io
import os
from typing import Optional

import requests
from wandb_gql import gql
from wandb_gql.client import RetryError

import wandb
from wandb import util
from wandb.apis.attrs import Attrs
from wandb.apis.normalize import normalize_exceptions
from wandb.apis.paginator import SizedPaginator
from wandb.apis.public import utils
from wandb.apis.public.api import Api
from wandb.apis.public.const import RETRY_TIMEDELTA
from wandb.sdk.lib import retry

# GraphQL fragment selecting one page of a run's files plus pagination info.
# It is interpolated into `Files.QUERY`, whose operation must declare the
# variables used here: `$fileNames`, `$fileCursor`, `$fileLimit` (paging /
# name filter for `files`) and `$upload` (forwarded to each node's `url`).
# Each edge carries a `cursor`, and `pageInfo` reports `endCursor` and
# `hasNextPage` for fetching the following page.
FILE_FRAGMENT = """fragment RunFilesFragment on Run {
    files(names: $fileNames, after: $fileCursor, first: $fileLimit) {
        edges {
            node {
                id
                name
                url(upload: $upload)
                directUrl
                sizeBytes
                mimetype
                updatedAt
                md5
            }
            cursor
        }
        pageInfo {
            endCursor
            hasNextPage
        }
    }
}"""


class Files(SizedPaginator["File"]):
    """An iterable, lazily paginated collection of `File` objects for one run.

    Pages are fetched with the `RunFiles` GraphQL query below, `per_page`
    files at a time, optionally restricted to a list of file names.
    """

    QUERY = gql(
        """
        query RunFiles($project: String!, $entity: String!, $name: String!, $fileCursor: String,
            $fileLimit: Int = 50, $fileNames: [String] = [], $upload: Boolean = false) {{
            project(name: $project, entityName: $entity) {{
                internalId
                run(name: $name) {{
                    fileCount
                    ...RunFilesFragment
                }}
            }}
        }}
        {}
        """.format(FILE_FRAGMENT)
    )

    def __init__(self, client, run, names=None, per_page=50, upload=False):
        """Create a paginator over the files of `run`.

        Args:
            client: GraphQL client used to execute `QUERY`.
            run: Run whose `project`, `entity` and `id` identify the files.
            names: Optional list of file names to restrict the listing to.
            per_page: Number of files requested per page.
            upload: Forwarded as the `$upload` argument of each file's `url`.
        """
        self.run = run
        super().__init__(
            client,
            {
                "project": run.project,
                "entity": run.entity,
                "name": run.id,
                "fileNames": names or [],
                "upload": upload,
            },
            per_page,
        )

    def _files_connection(self):
        """Return the `files` connection of the most recently fetched page."""
        return self.last_response["project"]["run"]["files"]

    @property
    def _length(self):
        """Total file count reported by the server; fetches a page if none is loaded."""
        if not self.last_response:
            self._load_page()
        return self.last_response["project"]["run"]["fileCount"]

    @property
    def more(self):
        """Whether another page can be requested (always True before the first fetch)."""
        if not self.last_response:
            return True
        return self._files_connection()["pageInfo"]["hasNextPage"]

    @property
    def cursor(self):
        """Cursor of the last edge on the current page, or None before the first fetch."""
        if not self.last_response:
            return None
        return self._files_connection()["edges"][-1]["cursor"]

    def update_variables(self):
        """Store the page size and the current cursor in the query variables."""
        self.variables.update(fileLimit=self.per_page, fileCursor=self.cursor)

    def convert_objects(self):
        """Wrap every node on the current page in a `File` bound to this run."""
        page_edges = self._files_connection()["edges"]
        return [File(self.client, edge["node"], self.run) for edge in page_edges]

    def __repr__(self):
        run_path = "/".join(self.run.path)
        return f"<Files {run_path} ({len(self)})>"


class File(Attrs):
    """A single file saved by a wandb run.

    Attributes:
        name (string): filename
        url (string): path to file
        direct_url (string): path to file in the bucket
        md5 (string): md5 of file
        mimetype (string): mimetype of file
        updated_at (string): timestamp of last update
        size (int): size of file in bytes
        path_uri (str): path to file in the bucket, currently only available for files stored in S3
    """

    def __init__(self, client, attrs, run=None):
        self.client = client
        self._attrs = attrs
        # May be None when a File is built without its owning run; `delete`
        # must not assume a run is available.
        self.run = run
        # Cached result of the server capability probe used by `delete`;
        # None means the probe has not been performed yet.
        self.server_supports_delete_file_with_project_id: Optional[bool] = None
        super().__init__(dict(attrs))

    @property
    def size(self):
        """Size of the file in bytes; 0 when the server reported no size."""
        size_bytes = self._attrs["sizeBytes"]
        if size_bytes is not None:
            return int(size_bytes)
        return 0

    @property
    def path_uri(self) -> str:
        """Return the URI of the file in its storage bucket.

        Only S3-backed files are supported. For any other storage, or when the
        file has no `directUrl`, a warning is printed and "" is returned.
        """
        path_uri = ""
        try:
            path_uri = utils.parse_s3_url_to_s3_uri(self._attrs["directUrl"])
        except ValueError:
            wandb.termwarn("path_uri is only available for files stored in S3")
        except LookupError:
            wandb.termwarn("Unable to find direct_url of file")
        return path_uri

    @normalize_exceptions
    @retry.retriable(
        retry_timedelta=RETRY_TIMEDELTA,
        check_retry_fn=util.no_retry_auth,
        retryable_exceptions=(RetryError, requests.RequestException),
    )
    def download(
        self,
        root: str = ".",
        replace: bool = False,
        exist_ok: bool = False,
        api: Optional[Api] = None,
    ) -> io.TextIOWrapper:
        """Download a file previously saved by a run from the wandb server.

        Args:
            root (str): Local directory to save the file. Defaults to ".".
            replace (boolean): If `True`, download will overwrite a local file
                if it exists. Defaults to `False`.
            exist_ok (boolean): If `True`, will not raise ValueError if file already
                exists and will not re-download unless replace=True. Defaults to `False`.
            api (Api, optional): If given, the `Api` instance used to download the file.
                A new `wandb.Api()` is created otherwise.

        Returns:
            The local file at `root/<name>`, opened for reading in text mode.
            The caller is responsible for closing it.

        Raises:
            `ValueError` if file already exists, replace=False and exist_ok=False.
        """
        if api is None:
            api = wandb.Api()

        path = os.path.join(root, self.name)
        if os.path.exists(path) and not replace:
            if exist_ok:
                return open(path)
            else:
                raise ValueError(
                    "File already exists, pass replace=True to overwrite or exist_ok=True to leave it as is and don't error."
                )

        util.download_file_from_url(path, self.url, api.api_key)
        return open(path)

    @normalize_exceptions
    def delete(self):
        """Delete this file from the wandb server.

        When the file is bound to a run and the server's `DeleteFilesInput`
        accepts a `projectId` field, the run's project internal id is sent too.
        """
        project_id_mutation_fragment = ""
        project_id_variable_fragment = ""
        variable_values = {
            "files": [self.id],
        }

        # Add projectId to mutation and variables only if we know the run and the
        # server supports it. Older servers reject the field, and a File created
        # with run=None has no project id to send (projectId is optional: `Int`).
        if self.run is not None and self._server_accepts_project_id_for_delete_file():
            variable_values["projectId"] = self.run._project_internal_id
            project_id_variable_fragment = ", $projectId: Int"
            project_id_mutation_fragment = "projectId: $projectId"

        mutation_string = """
            mutation deleteFiles($files: [ID!]!{}) {{
                deleteFiles(input: {{
                    files: $files
                    {}
                }}) {{
                    success
                }}
            }}
            """.format(project_id_variable_fragment, project_id_mutation_fragment)
        mutation = gql(mutation_string)

        self.client.execute(
            mutation,
            variable_values=variable_values,
        )

    def __repr__(self):
        return "<File {} ({}) {}>".format(
            self.name,
            self.mimetype,
            util.to_human_size(self.size, units=util.POW_2_BYTES),
        )

    @normalize_exceptions
    def _server_accepts_project_id_for_delete_file(self) -> bool:
        """Return True if the server's `DeleteFilesInput` type has a `projectId` field.

        This uses GraphQL introspection on the input fields of `DeleteFilesInput`.
        The answer is cached on this instance, so the probe is sent at most once
        per `File` object.
        """
        query_string = """
           query ProbeDeleteFilesProjectIdInput {
                DeleteFilesProjectIdInputType: __type(name:"DeleteFilesInput") {
                    inputFields{
                        name
                    }
                }
            }
        """

        # Only perform the query once to avoid extra network calls
        if self.server_supports_delete_file_with_project_id is None:
            query = gql(query_string)
            res = self.client.execute(query)

            # Per GraphQL introspection, `__type` resolves to null for an unknown
            # type name, so the aliased key can be present with a None value, and
            # `inputFields` may likewise be null. Treat all of these as
            # "projectId not supported" instead of crashing.
            type_info = res.get("DeleteFilesProjectIdInputType") or {}
            input_fields = type_info.get("inputFields") or []
            self.server_supports_delete_file_with_project_id = any(
                field.get("name") == "projectId" for field in input_fields
            )

        return self.server_supports_delete_file_with_project_id
