"""tensorboard watcher."""

import glob
import logging
import os
import queue
import socket
import sys
import threading
import time
from typing import TYPE_CHECKING, Any, Dict, List, Optional

import wandb
from wandb import util
from wandb.plot import CustomChart
from wandb.sdk.lib import filesystem

from . import run as internal_run

if TYPE_CHECKING:
    from queue import PriorityQueue

    from tensorboard.backend.event_processing.event_file_loader import EventFileLoader
    from tensorboard.compat.proto.event_pb2 import ProtoEvent

    from wandb.proto.wandb_internal_pb2 import RunRecord
    from wandb.sdk.lib.filesystem import FilesDict

    from ..interface.interface_queue import InterfaceQueue
    from .settings_static import SettingsStatic

    HistoryDict = Dict[str, Any]

# Seconds a directory watcher thread keeps polling after shutdown has been
# requested, to give some time for tensorboard data to be flushed.
SHUTDOWN_DELAY = 5
# Seconds a directory watcher sleeps after a recoverable load error before it
# polls again (skipped while shutting down).
ERROR_DELAY = 5
# Substring used to recognise remote paths; tfevents files whose path contains
# it are not symlinked into the run's files directory.
REMOTE_FILE_TOKEN = "://"
logger = logging.getLogger(__name__)


def _link_and_save_file(
    path: str, base_path: str, interface: "InterfaceQueue", settings: "SettingsStatic"
) -> None:
    """Symlink a file into the run's files directory and publish it.

    The link is created at ``settings.files_dir`` joined with ``path`` taken
    relative to ``base_path``; missing parent directories are created. The
    relative name is then published through ``interface`` with the "live"
    policy.

    Args:
        path: The file to link.
        base_path: Directory that ``path`` is made relative to when computing
            the name inside the files directory.
        interface: Interface used to publish the file.
        settings: Settings object providing ``files_dir``.
    """
    # TODO(jhr): should this logic be merged with Run.save()
    files_dir = settings.files_dir
    file_name = os.path.relpath(path, base_path)
    abs_path = os.path.abspath(path)
    wandb_path = os.path.join(files_dir, file_name)
    filesystem.mkdir_exists_ok(os.path.dirname(wandb_path))
    if os.path.islink(wandb_path):
        # We overwrite existing symlinks because namespaces can change in Tensorboard
        if abs_path != os.readlink(wandb_path):
            os.remove(wandb_path)
            os.symlink(abs_path, wandb_path)
        # Otherwise the link already points at the right target. This must be
        # decided here rather than with os.path.exists(), which follows the
        # link and reports a dangling one as missing; re-creating it would
        # then fail with FileExistsError.
    elif not os.path.exists(wandb_path):
        os.symlink(abs_path, wandb_path)
    # TODO(jhr): need to figure out policy, live/throttled?
    interface.publish_files(
        dict(files=[(filesystem.GlobStr(glob.escape(file_name)), "live")])
    )


def is_tfevents_file_created_by(
    path: str, hostname: Optional[str], start_time: Optional[float]
) -> bool:
    """Tell whether ``path`` names a tfevents file, optionally from a given writer.

    The basename is split on "." and must contain a "tfevents" component.
    The component after it is read as the creation time and the components
    after that as the (possibly dotted) hostname.

    tensorboard tfevents filename format:
        https://github.com/tensorflow/tensorboard/blob/f3f26b46981da5bd46a5bb93fcf02d9eb7608bc1/tensorboard/summary/writer/event_file_writer.py#L81
    tensorflow tfevents filename format:
        https://github.com/tensorflow/tensorflow/blob/8f597046dc30c14b5413813d02c0e0aed399c177/tensorflow/core/util/events_writer.cc#L68

    Args:
        path: File path to inspect.
        hostname: When not None, the hostname components in the filename must
            match this value.
        start_time: When not None, the creation time in the filename must not
            be earlier than this value (compared as integers).

    Returns:
        True if the file passes every requested check, False otherwise.

    Raises:
        ValueError: If ``path`` is empty.
    """
    if not path:
        raise ValueError("Path must be a nonempty string")

    name = os.path.basename(path)
    ignored_suffixes = (".profile-empty", ".sagemaker-uploaded")
    if name.endswith(ignored_suffixes):
        return False

    pieces = name.split(".")
    if "tfevents" not in pieces:
        return False
    marker = pieces.index("tfevents")

    if hostname is not None:
        # The hostname may itself contain dots, so compare component-wise.
        # A filename that is too short yields a shorter slice and fails the
        # comparison.
        host_pieces = hostname.split(".")
        host_start = marker + 2
        if pieces[host_start : host_start + len(host_pieces)] != host_pieces:
            return False

    if start_time is None:
        return True

    try:
        created_time = int(pieces[marker + 1])
    except (ValueError, IndexError):
        return False
    # Only accept files that are at least as new as our start time.
    # TODO: we should also check the PID (also contained in the tfevents
    #     filename). Can we assume that our parent pid is the user process
    #     that wrote these files?
    return created_time >= int(start_time)


class TBWatcher:
    """Entry point that tracks tensorboard log directories for a run.

    Holds one ``TBDirWatcher`` per registered log directory and a single,
    lazily started ``TBEventConsumer`` that reads from the priority queue the
    directory watchers write to.
    """

    _logdirs: "Dict[str, TBDirWatcher]"
    _watcher_queue: "PriorityQueue"

    def __init__(
        self,
        settings: "SettingsStatic",
        run_proto: "RunRecord",
        interface: "InterfaceQueue",
        force: bool = False,
    ) -> None:
        self._settings = settings
        self._run_proto = run_proto
        self._interface = interface
        self._force = force
        self._logdirs = {}
        self._consumer: Optional[TBEventConsumer] = None
        # TODO(jhr): do we need locking in this queue?
        self._watcher_queue = queue.PriorityQueue()
        wandb.tensorboard.reset_state()  # type: ignore

    def _calculate_namespace(self, logdir: str, rootdir: str) -> Optional[str]:
        """Derive the namespace used to nest values logged from ``logdir``.

        The namespace is ``logdir`` with the file name (when ``logdir`` is a
        file) and ``rootdir`` removed, stripped of "/". When ``rootdir`` is
        empty it is inferred from the common prefix of all known log
        directories plus ``logdir``.
        """
        known_dirs = [*self._logdirs, logdir]
        filename = os.path.basename(logdir) if os.path.isfile(logdir) else ""

        root_was_inferred = rootdir == ""
        if root_was_inferred:
            rootdir = util.to_forward_slash_path(
                os.path.dirname(os.path.commonprefix(known_dirs))
            )

        # Tensorboard loads all tfevents files in a directory and prepends
        # their values with the path. Passing namespace to log allows us
        # to nest the values in wandb
        # Note that we strip '/' instead of os.sep, because elsewhere we've
        # converted paths to forward slash.
        namespace: Optional[str]
        namespace = logdir.replace(filename, "").replace(rootdir, "").strip("/")

        # TODO: revisit this heuristic, it exists because we don't know the
        # root log directory until more than one tfevents file is written to
        if (
            root_was_inferred
            and len(known_dirs) == 1
            and namespace not in ["train", "validation"]
        ):
            namespace = None

        return namespace

    def add(self, logdir: str, save: bool, root_dir: str) -> None:
        """Start watching ``logdir`` unless it is already being watched.

        The consumer thread is created and started on the first call.
        """
        logdir = util.to_forward_slash_path(logdir)
        root_dir = util.to_forward_slash_path(root_dir)
        if logdir in self._logdirs:
            return
        namespace = self._calculate_namespace(logdir, root_dir)
        # TODO(jhr): implement the deferred tbdirwatcher to find namespace

        if self._consumer is None:
            consumer = TBEventConsumer(
                self, self._watcher_queue, self._run_proto, self._settings
            )
            self._consumer = consumer
            consumer.start()

        dir_watcher = TBDirWatcher(
            self, logdir, save, namespace, self._watcher_queue, self._force
        )
        self._logdirs[logdir] = dir_watcher
        dir_watcher.start()

    def finish(self) -> None:
        """Shut down every directory watcher, then the consumer.

        All watchers are asked to shut down before any of them is joined.
        """
        for dir_watcher in self._logdirs.values():
            dir_watcher.shutdown()
        for dir_watcher in self._logdirs.values():
            dir_watcher.finish()
        if self._consumer is not None:
            self._consumer.finish()


class TBDirWatcher:
    """Watch one tensorboard log directory from a background thread.

    Events produced by tensorboard's ``DirectoryWatcher`` are passed to
    ``process_event``; those carrying a summary are wrapped in ``Event`` and
    put on the shared queue.
    """

    def __init__(
        self,
        tbwatcher: "TBWatcher",
        logdir: str,
        save: bool,
        namespace: Optional[str],
        queue: "PriorityQueue",
        force: bool = False,
    ) -> None:
        install_hint = "Please install tensorboard package"
        self.directory_watcher = util.get_module(
            "tensorboard.backend.event_processing.directory_watcher",
            required=install_hint,
        )
        self.tf_compat = util.get_module("tensorboard.compat", required=install_hint)

        # Plain state first; _loader() below reads self._tbwatcher.
        self._tbwatcher = tbwatcher
        self._logdir = logdir
        self._namespace = namespace
        self._queue = queue
        self._force = force
        self._hostname = socket.gethostname()
        self._first_event_timestamp = None
        self._file_version = None

        self._shutdown = threading.Event()
        self._process_events_lock = threading.Lock()
        self._generator = self.directory_watcher.DirectoryWatcher(
            logdir, self._loader(save, namespace), self._is_our_tfevents_file
        )
        self._thread = threading.Thread(target=self._thread_except_body)

    def start(self) -> None:
        """Start the polling thread."""
        self._thread.start()

    def _is_our_tfevents_file(self, path: str) -> bool:
        """Check if a path has been modified since launch and contains tfevents.

        With ``force`` set, the hostname and start-time checks are skipped.
        """
        if not path:
            raise ValueError("Path must be a nonempty string")
        path = self.tf_compat.tf.compat.as_str_any(path)
        if self._force:
            hostname, start_time = None, None
        else:
            hostname = self._hostname
            start_time = self._tbwatcher._settings.x_start_time
        return is_tfevents_file_created_by(path, hostname, start_time)

    def _loader(
        self, save: bool = True, namespace: Optional[str] = None
    ) -> "EventFileLoader":
        """Incredibly hacky class generator to optionally save / prefix tfevent files."""
        publish_interface = self._tbwatcher._interface
        publish_settings = self._tbwatcher._settings
        try:
            from tensorboard.backend.event_processing import event_file_loader
        except ImportError:
            raise Exception("Please install tensorboard package")

        class EventFileLoader(event_file_loader.EventFileLoader):
            def __init__(self, file_path: str) -> None:
                super().__init__(file_path)
                if not save:
                    return
                if REMOTE_FILE_TOKEN in file_path:
                    logger.warning("Not persisting remote tfevent file: %s", file_path)
                    return
                # TODO: save plugins?
                event_dir = os.path.dirname(file_path)
                parent_dir, leaf = os.path.split(event_dir)
                # When the containing directory is the namespace itself, link
                # relative to its parent so the namespace stays in the name.
                strip_leaf = bool(namespace) and leaf == namespace
                _link_and_save_file(
                    path=file_path,
                    base_path=parent_dir if strip_leaf else event_dir,
                    interface=publish_interface,
                    settings=publish_settings,
                )

        return EventFileLoader

    def _process_events(self, shutdown_call: bool = False) -> None:
        """Load and process every event currently available.

        Recoverable loader errors are logged at debug level. Unless shutting
        down (or called from ``shutdown``), the method then sleeps for
        ``ERROR_DELAY`` seconds.
        """
        try:
            with self._process_events_lock:
                for loaded_event in self._generator.Load():
                    self.process_event(loaded_event)
        except (
            self.directory_watcher.DirectoryDeletedError,
            StopIteration,
            RuntimeError,
            OSError,
        ) as e:
            # When listing s3 the directory may not yet exist, or could be empty
            logger.debug("Encountered tensorboard directory watcher error: %s", e)
            shutting_down = self._shutdown.is_set() or shutdown_call
            if not shutting_down:
                time.sleep(ERROR_DELAY)

    def _thread_except_body(self) -> None:
        """Thread target: run the polling loop, logging any exception before re-raising."""
        try:
            self._thread_body()
        except Exception:
            logger.exception("generic exception in TBDirWatcher thread")
            raise

    def _thread_body(self) -> None:
        """Check for new events every second.

        After shutdown is requested the loop keeps polling until
        ``SHUTDOWN_DELAY`` seconds have passed since it first noticed the
        request.
        """
        deadline: Optional[float] = None
        while True:
            self._process_events()
            if self._shutdown.is_set():
                now = time.time()
                if deadline and now > deadline:
                    break
                if not deadline:
                    deadline = now + SHUTDOWN_DELAY
            time.sleep(1)

    def process_event(self, event: "ProtoEvent") -> None:
        """Record event metadata and queue the event if it carries a summary."""
        if self._first_event_timestamp is None:
            self._first_event_timestamp = event.wall_time

        if event.HasField("file_version"):
            self._file_version = event.file_version

        if not event.HasField("summary"):
            return
        self._queue.put(Event(event, self._namespace))

    def shutdown(self) -> None:
        """Process pending events once, then signal the thread to stop."""
        self._process_events(shutdown_call=True)
        self._shutdown.set()

    def finish(self) -> None:
        """Shut down and wait for the polling thread to exit."""
        self.shutdown()
        self._thread.join()


class Event:
    """An event wrapper to enable priority queueing.

    Instances are ordered by the wrapped event's ``wall_time``; the namespace
    and the wrapper's creation time are carried along.
    """

    def __init__(self, event: "ProtoEvent", namespace: Optional[str]):
        self.created_at = time.time()
        self.namespace = namespace
        self.event = event

    def __lt__(self, other: "Event") -> bool:
        # Earlier wall_time sorts first.
        return bool(self.event.wall_time < other.event.wall_time)


class TBEventConsumer:
    """Consume tfevents from a priority queue.

    There should always only be one of these per run_manager.  We wait for 10 seconds of
    queued events to reduce the chance of multiple tfevent files triggering out of order
    steps.
    """

    def __init__(
        self,
        tbwatcher: "TBWatcher",
        queue: "PriorityQueue",
        run_proto: "RunRecord",
        settings: "SettingsStatic",
        delay: int = 10,
    ) -> None:
        """Set up the consumer thread (not started) and its internal run.

        Args:
            tbwatcher: Owning watcher; its interface is used for publishing.
            queue: Priority queue of ``Event`` objects to consume.
            run_proto: Run record used to build the internal run object.
            settings: Settings used to build the internal run object.
            delay: Seconds to wait after ``start()`` before events are logged.
        """
        self._tbwatcher = tbwatcher
        self._queue = queue
        self._thread = threading.Thread(target=self._thread_except_body)
        self._shutdown = threading.Event()
        self.tb_history = TBHistory()
        self._delay = delay

        # This is a bit of a hack to get file saving to work as it does in the user
        # process. Since we don't have a real run object, we have to define the
        # datatypes callback ourselves.
        def datatypes_cb(fname: filesystem.GlobStr) -> None:
            files: FilesDict = dict(files=[(fname, "now")])
            self._tbwatcher._interface.publish_files(files)

        # this is only used for logging artifacts
        self._internal_run = internal_run.InternalRun(run_proto, settings, datatypes_cb)
        self._internal_run._set_internal_run_interface(self._tbwatcher._interface)

    def start(self) -> None:
        """Record the start time (used for the initial delay) and start the thread."""
        self._start_time = time.time()
        self._thread.start()

    def finish(self) -> None:
        """Stop the consumer thread and process anything still queued.

        The thread flushes the history itself before exiting. Events found in
        the queue afterwards are handled here, followed by one more flush so
        that the row they were building is published rather than dropped.
        """
        self._delay = 0
        self._shutdown.set()
        self._thread.join()

        drained_any = False
        while True:
            try:
                event = self._queue.get_nowait()
            except queue.Empty:
                break
            if event:
                self._consume(event)
                drained_any = True

        # Only flush when something was handled here: the thread has already
        # flushed its own pending row and a second flush must not repeat it.
        if drained_any:
            self.tb_history._flush()
            self._publish_pending_rows()

    def _thread_except_body(self) -> None:
        """Thread target: run the consume loop, logging any exception before re-raising."""
        try:
            self._thread_body()
        except Exception:
            logger.exception("generic exception in TBEventConsumer thread")
            raise

    def _thread_body(self) -> None:
        """Consume queued events until shutdown is requested and the queue is idle."""
        while True:
            try:
                event = self._queue.get(True, 1)
            except queue.Empty:
                if self._shutdown.is_set():
                    break
                continue
            # Wait self._delay seconds from consumer start before logging events
            if (
                time.time() < self._start_time + self._delay
                and not self._shutdown.is_set()
            ):
                self._queue.put(event)
                time.sleep(0.1)
                continue
            if event:
                self._consume(event)
        # flush uncommitted data
        self.tb_history._flush()
        self._publish_pending_rows()

    def _consume(self, event: "Event") -> None:
        """Log one event into the history and publish any rows it completed."""
        self._handle_event(event, history=self.tb_history)
        self._publish_pending_rows()

    def _publish_pending_rows(self) -> None:
        """Publish every row the history has finished since the last call."""
        for row in self.tb_history._get_and_reset():
            self._save_row(row)

    def _handle_event(
        self, event: "Event", history: Optional["TBHistory"] = None
    ) -> None:
        """Forward the wrapped tensorboard event to ``wandb.tensorboard._log``."""
        wandb.tensorboard._log(  # type: ignore
            event.event,
            step=event.event.step,
            namespace=event.namespace,
            history=history,
        )

    def _save_row(self, row: "HistoryDict") -> None:
        """Publish one history row, expanding any custom charts it contains.

        For each ``CustomChart`` value the chart's config entry is published
        and the chart is replaced in ``row`` by its table under the chart's
        table key. ``row`` is modified in place.
        """
        chart_keys = set()
        for k, v in row.items():
            if isinstance(v, CustomChart):
                chart_keys.add(k)
                v.set_key(k)
                self._tbwatcher._interface.publish_config(
                    key=v.spec.config_key,
                    val=v.spec.config_value,
                )

        # Mutate the row only after iteration over it has finished.
        for k in chart_keys:
            chart = row.pop(k)
            if isinstance(chart, CustomChart):
                row[chart.spec.table_key] = chart.table

        self._tbwatcher._interface.publish_history(
            self._internal_run,
            row,
            publish_step=False,
        )


class TBHistory:
    """Accumulate logged values into rows, one row per flushed step.

    ``_data`` is the row currently being built, ``_added`` the finished rows
    waiting to be collected with ``_get_and_reset``. ``_step_size`` is a
    running sum of ``sys.getsizeof`` over the values tracked for the current
    row.
    """

    _data: "HistoryDict"
    _added: "List[HistoryDict]"

    def __init__(self) -> None:
        self._step = 0
        self._step_size = 0
        self._data = dict()
        self._added = []

    def _flush(self) -> None:
        """Finish the current row, if it has any data, and queue it in ``_added``.

        The row is stamped with ``_step`` and a fresh empty row is started,
        so calling this again without new data is a no-op and rows that were
        already queued are never modified afterwards.
        """
        if not self._data:
            return
        # A single tensorboard step may have too much data
        # we just drop the largest keys in the step if it does.
        # TODO: we could flush the data across multiple steps
        if self._step_size > util.MAX_LINE_BYTES:
            sizes = [(k, sys.getsizeof(v)) for k, v in self._data.items()]
            sizes.sort(key=lambda t: t[1], reverse=True)
            freed = 0
            dropped_keys = []
            for key, size in sizes:
                # TODO: (cvp) Added a buffer of 100KiB, this feels rather brittle.
                if self._step_size - freed < util.MAX_LINE_BYTES - 100000:
                    break
                freed += size
                dropped_keys.append(key)
                del self._data[key]
            wandb.termwarn(
                f"Step {self._step} exceeds max data limit, dropping {len(dropped_keys)} of the largest keys:"
            )
            print("\t" + ("\n\t".join(dropped_keys)))  # noqa: T201
        self._data["_step"] = self._step
        self._added.append(self._data)
        # Detach the queued row from the live one. Without this a second
        # flush would queue the same dict again and overwrite its "_step".
        self._data = dict()
        self._step += 1
        self._step_size = 0

    def add(self, d: "HistoryDict") -> None:
        """Flush the current row and start a new one from ``d``."""
        self._flush()
        self._data = self._track_history_dict(d)

    def _track_history_dict(self, d: "HistoryDict") -> "HistoryDict":
        """Return a shallow copy of ``d`` and add its values' sizes to ``_step_size``."""
        tracked = dict(d)
        self._step_size += sum(sys.getsizeof(v) for v in tracked.values())
        return tracked

    def _row_update(self, d: "HistoryDict") -> None:
        """Merge ``d`` into the current row without flushing."""
        self._data.update(self._track_history_dict(d))

    def _get_and_reset(self) -> "List[HistoryDict]":
        """Return the finished rows and clear the pending list."""
        added = self._added[:]
        self._added = []
        return added
