"""Artifact saver."""

from __future__ import annotations

import concurrent.futures
import json
import os
import tempfile
from typing import TYPE_CHECKING, Awaitable, Sequence

import wandb
import wandb.filesync.step_prepare
from wandb import util
from wandb.sdk.artifacts.artifact_manifest import ArtifactManifest
from wandb.sdk.lib.hashutil import B64MD5, b64_to_hex_id, md5_file_b64
from wandb.sdk.lib.paths import URIStr

if TYPE_CHECKING:
    # Everything in this branch exists only for static type checkers; none of
    # these names are bound when the module is actually executed.
    from typing import Protocol

    from wandb.sdk.artifacts.artifact_manifest_entry import ArtifactManifestEntry
    from wandb.sdk.internal.file_pusher import FilePusher
    from wandb.sdk.internal.internal_api import Api as InternalApi
    from wandb.sdk.internal.progress import ProgressFn

    class SaveFn(Protocol):
        """Call signature of a synchronous per-entry save callback."""

        def __call__(
            self, entry: ArtifactManifestEntry, progress_callback: ProgressFn
        ) -> bool:
            """Process `entry`, given `progress_callback`, and return a bool.

            NOTE(review): the meaning of the returned bool is not visible in
            this file -- confirm against the implementations of this protocol.
            """
            pass

    class SaveFnAsync(Protocol):
        """Call signature of a per-entry save callback returning an awaitable."""

        def __call__(
            self, entry: ArtifactManifestEntry, progress_callback: ProgressFn
        ) -> Awaitable[bool]:
            """Same parameters as `SaveFn`; the bool is delivered via an awaitable."""
            pass


class ArtifactSaver:
    _server_artifact: dict | None  # TODO better define this dict

    def __init__(
        self,
        api: InternalApi,
        digest: str,
        manifest_json: dict,
        file_pusher: FilePusher,
        is_user_created: bool = False,
    ) -> None:
        self._api = api
        self._file_pusher = file_pusher
        self._digest = digest
        self._manifest = ArtifactManifest.from_manifest_json(manifest_json)
        self._manifest.storage_policy._api = self._api
        self._is_user_created = is_user_created
        self._server_artifact = None

    def save(
        self,
        entity: str,
        project: str,
        type: str,
        name: str,
        client_id: str,
        sequence_client_id: str,
        distributed_id: str | None = None,
        finalize: bool = True,
        metadata: dict | None = None,
        ttl_duration_seconds: int | None = None,
        description: str | None = None,
        aliases: Sequence[str] | None = None,
        tags: Sequence[str] | None = None,
        use_after_commit: bool = False,
        incremental: bool = False,
        history_step: int | None = None,
        base_id: str | None = None,
    ) -> dict | None:
        return self._save_internal(
            entity,
            project,
            type,
            name,
            client_id,
            sequence_client_id,
            distributed_id,
            finalize,
            metadata,
            ttl_duration_seconds,
            description,
            aliases,
            tags,
            use_after_commit,
            incremental,
            history_step,
            base_id,
        )

    def _save_internal(
        self,
        entity: str,
        project: str,
        type: str,
        name: str,
        client_id: str,
        sequence_client_id: str,
        distributed_id: str | None = None,
        finalize: bool = True,
        metadata: dict | None = None,
        ttl_duration_seconds: int | None = None,
        description: str | None = None,
        aliases: Sequence[str] | None = None,
        tags: Sequence[str] | None = None,
        use_after_commit: bool = False,
        incremental: bool = False,
        history_step: int | None = None,
        base_id: str | None = None,
    ) -> dict | None:
        alias_specs = []
        for alias in aliases or []:
            alias_specs.append({"artifactCollectionName": name, "alias": alias})

        tag_specs = [{"tagName": tag} for tag in tags or []]

        """Returns the server artifact."""
        self._server_artifact, latest = self._api.create_artifact(
            type,
            name,
            self._digest,
            metadata=metadata,
            ttl_duration_seconds=ttl_duration_seconds,
            aliases=alias_specs,
            tags=tag_specs,
            description=description,
            is_user_created=self._is_user_created,
            distributed_id=distributed_id,
            client_id=client_id,
            sequence_client_id=sequence_client_id,
            history_step=history_step,
        )

        assert self._server_artifact is not None  # mypy optionality unwrapper
        artifact_id = self._server_artifact["id"]
        if base_id is None and latest:
            base_id = latest["id"]
        if self._server_artifact["state"] == "COMMITTED":
            if use_after_commit:
                self._api.use_artifact(
                    artifact_id,
                    artifact_entity_name=entity,
                    artifact_project_name=project,
                )
            return self._server_artifact
        if (
            self._server_artifact["state"] != "PENDING"
            # For old servers, see https://github.com/wandb/wandb/pull/6190
            and self._server_artifact["state"] != "DELETED"
        ):
            raise Exception(
                'Unknown artifact state "{}"'.format(self._server_artifact["state"])
            )

        manifest_type = "FULL"
        manifest_filename = "wandb_manifest.json"
        if incremental:
            manifest_type = "INCREMENTAL"
            manifest_filename = "wandb_manifest.incremental.json"
        elif distributed_id:
            manifest_type = "PATCH"
            manifest_filename = "wandb_manifest.patch.json"
        artifact_manifest_id, _ = self._api.create_artifact_manifest(
            manifest_filename,
            "",
            artifact_id,
            base_artifact_id=base_id,
            include_upload=False,
            type=manifest_type,
        )

        step_prepare = wandb.filesync.step_prepare.StepPrepare(
            self._api, 0.1, 0.01, 1000
        )  # TODO: params
        step_prepare.start()

        # Upload Artifact "L1" files, the actual artifact contents
        self._file_pusher.store_manifest_files(
            self._manifest,
            artifact_id,
            lambda entry, progress_callback: self._manifest.storage_policy.store_file(
                artifact_id,
                artifact_manifest_id,
                entry,
                step_prepare,
                progress_callback=progress_callback,
            ),
        )

        def before_commit() -> None:
            self._resolve_client_id_manifest_references()
            with tempfile.NamedTemporaryFile("w+", suffix=".json", delete=False) as fp:
                path = os.path.abspath(fp.name)
                json.dump(self._manifest.to_manifest_json(), fp, indent=4)
            digest = md5_file_b64(path)
            if distributed_id or incremental:
                # If we're in the distributed flow, we want to update the
                # patch manifest we created with our finalized digest.
                _, resp = self._api.update_artifact_manifest(
                    artifact_manifest_id,
                    digest=digest,
                )
            else:
                # In the regular flow, we can recreate the full manifest with the
                # updated digest.
                #
                # NOTE: We do this for backwards compatibility with older backends
                # that don't support the 'updateArtifactManifest' API.
                _, resp = self._api.create_artifact_manifest(
                    manifest_filename,
                    digest,
                    artifact_id,
                    base_artifact_id=base_id,
                )

            # We're duplicating the file upload logic a little, which isn't great.
            upload_url = resp["uploadUrl"]
            upload_headers = resp["uploadHeaders"]
            extra_headers = {}
            for upload_header in upload_headers:
                key, val = upload_header.split(":", 1)
                extra_headers[key] = val
            with open(path, "rb") as fp2:
                self._api.upload_file_retry(
                    upload_url,
                    fp2,
                    extra_headers=extra_headers,
                )

        commit_result: concurrent.futures.Future[None] = concurrent.futures.Future()

        # Queue the commit. It will only happen after all file uploads finish.
        self._file_pusher.commit_artifact(
            artifact_id,
            finalize=finalize,
            before_commit=before_commit,
            result_future=commit_result,
        )

        # Block until all artifact files are uploaded and the
        # artifact is committed.
        try:
            commit_result.result()
        finally:
            step_prepare.shutdown()

        if finalize and use_after_commit:
            self._api.use_artifact(
                artifact_id,
                artifact_entity_name=entity,
                artifact_project_name=project,
            )

        return self._server_artifact

    def _resolve_client_id_manifest_references(self) -> None:
        for entry_path in self._manifest.entries:
            entry = self._manifest.entries[entry_path]
            if entry.ref is not None:
                if entry.ref.startswith("wandb-client-artifact:"):
                    client_id = util.host_from_path(entry.ref)
                    artifact_file_path = util.uri_from_path(entry.ref)
                    artifact_id = self._api._resolve_client_id(client_id)
                    if artifact_id is None:
                        raise RuntimeError(f"Could not resolve client id {client_id}")
                    entry.ref = URIStr(
                        f"wandb-artifact://{b64_to_hex_id(B64MD5(artifact_id))}/{artifact_file_path}"
                    )
