"""Parallel processing tools to help speed up certain tasks like data
preprocessing.

Authors
 * Sylvain de Langen 2023
"""

import itertools
import multiprocessing
from collections import deque
from concurrent.futures import Executor, ProcessPoolExecutor
from threading import Condition
from typing import Any, Callable, Iterable, Optional

from tqdm.auto import tqdm


def _chunk_process_wrapper(fn, chunk):
    return list(map(fn, chunk))


class CancelFuturesOnExit:
    """Context manager that .cancel()s all elements of a list upon exit.
    This is used to abort futures faster when raising an exception.

    Cancellation is attempted on every exit, whether or not an exception was
    raised inside the `with` block.

    Arguments
    ---------
    future_list: Iterable
        Iterable of objects exposing a `.cancel()` method (e.g.
        `concurrent.futures.Future`). It is iterated over at exit time, so
        elements added while inside the `with` block are cancelled as well.
    """

    def __init__(self, future_list):
        self.future_list = future_list

    def __enter__(self):
        # return the manager so that `with CancelFuturesOnExit(...) as cm`
        # binds something useful rather than `None`
        return self

    def __exit__(self, _type, _value, _traceback):
        for future in self.future_list:
            future.cancel()

        # implicitly returns None: an exception raised inside the `with` block
        # is never suppressed


class _ParallelMapper:
    """Internal class for `parallel_map`, arguments match the constructor's."""

    def __init__(
        self,
        fn: Callable[[Any], Any],
        source: Iterable[Any],
        process_count: int,
        chunk_size: int,
        queue_size: int,
        executor: Optional[Executor],
        progress_bar: bool,
        progress_bar_kwargs: dict,
    ):
        self.future_chunks = deque()
        self.cv = Condition()
        self.just_finished_count = 0
        """Number of jobs that were just done processing, guarded by
        `self.cv`."""
        self.remote_exception = None
        """Set by a worker when it encounters an exception, guarded by
        `self.cv`."""

        self.fn = fn
        self.source = source
        self.process_count = process_count
        self.chunk_size = chunk_size
        self.queue_size = queue_size
        self.executor = executor

        self.known_len = len(source) if hasattr(source, "__len__") else None
        self.source_it = iter(source)
        self.depleted_source = False

        if progress_bar:
            tqdm_final_kwargs = {"total": self.known_len}
            tqdm_final_kwargs.update(progress_bar_kwargs)
            self.pbar = tqdm(**tqdm_final_kwargs)
        else:
            self.pbar = None

    def run(self):
        """Spins up an executor (if none were provided), then yields all
        processed chunks in order."""
        with CancelFuturesOnExit(self.future_chunks):
            if self.executor is not None:
                # just use the executor we were provided
                yield from self._map_all()
            else:
                # start and shut down a process pool executor -- ok for
                # long-running tasks
                with ProcessPoolExecutor(
                    max_workers=self.process_count
                ) as pool:
                    self.executor = pool
                    yield from self._map_all()

    def _bump_processed_count(self, future):
        """Notifies the main thread of the finished job, bumping the number of
        jobs it should requeue. Updates the progress bar based on the returned
        chunk length.

        Arguments
        ---------
        future: concurrent.futures.Future
            A future holding a processed chunk (of type `list`).

        Returns
        -------
        None
        """
        if future.cancelled():
            # the scheduler wants us to stop or something else happened, give up
            return

        future_exception = future.exception()

        # wake up dispatcher thread to refill the queue
        with self.cv:
            if future_exception is not None:
                # signal to the main thread that it should raise
                self.remote_exception = future_exception

            self.just_finished_count += 1
            self.cv.notify()

        if future_exception is None:
            # update progress bar with the length of the output as the progress
            # bar is over element count, not chunk count.
            if self.pbar is not None:
                self.pbar.update(len(future.result()))

    def _enqueue_job(self):
        """Pulls a chunk from the source iterable and submits it to the
        pool; must be run from the main thread.

        Returns
        -------
        `True` if any job was submitted (that is, if there was any chunk
        left to process), `False` otherwise.
        """
        # immediately deplete the input stream of chunk_size elems (or less)
        chunk = list(itertools.islice(self.source_it, self.chunk_size))

        # empty chunk? then we finished iterating over the input stream
        if len(chunk) == 0:
            self.depleted_source = True
            return False

        future = self.executor.submit(_chunk_process_wrapper, self.fn, chunk)
        future.add_done_callback(self._bump_processed_count)
        self.future_chunks.append(future)

        return True

    def _map_all(self):
        """Performs all the parallel mapping logic.

        Yields
        ------
        The items from source processed by fn
        """

        # initial queue fill
        for _ in range(self.queue_size):
            if not self._enqueue_job():
                break

        # consume & requeue logic
        while (not self.depleted_source) or (len(self.future_chunks) != 0):
            with self.cv:
                # if `cv.notify` was called by a worker _after_ the `with cv`
                # block last iteration, then `just_finished_count` would be
                # incremented, but this `cv.wait` would not wake up -- skip it.
                while self.just_finished_count == 0:
                    # wait to be woken up by a worker thread, which could mean:
                    # - that a chunk was processed: try to yield any
                    # - that a call failed with an exception: raise it
                    # - nothing; it could be a spurious CV wakeup: keep looping
                    self.cv.wait()

                if self.remote_exception is not None:
                    raise self.remote_exception

                # store the amount to requeue, avoiding data races
                to_queue_count = self.just_finished_count
                self.just_finished_count = 0

            # try to enqueue as many jobs as there were just finished.
            # when the input is finished, the queue will not be refilled.
            for _ in range(to_queue_count):
                if not self._enqueue_job():
                    break

            # yield from left to right as long as there is enough ready
            # e.g. | done | done | !done | done | !done | !done
            # would yield from the first two. we might deplete the entire queue
            # at that point, the `depleted_source` loop check is needed as such.
            while len(self.future_chunks) != 0 and self.future_chunks[0].done():
                yield from self.future_chunks.popleft().result()

        if self.pbar is not None:
            self.pbar.close()


def parallel_map(
    fn: Callable[[Any], Any],
    source: Iterable[Any],
    process_count: int = multiprocessing.cpu_count(),
    chunk_size: int = 8,
    queue_size: int = 128,
    executor: Optional[Executor] = None,
    progress_bar: bool = True,
    progress_bar_kwargs: Optional[dict] = None,
):
    """Maps iterable items with a function, processing chunks of items in
    parallel with multiple processes and displaying progress with tqdm.

    Processed elements will always be returned in the original, correct order.
    Unlike `ProcessPoolExecutor.map`, elements are produced AND consumed lazily.

    As this is a generator function, nothing happens (no argument is checked,
    no process is spawned) until the result is first iterated over.

    Arguments
    ---------
    fn: Callable
        The function that is called for every element in the source list.
        The output is an iterator over the source list after fn(elem) is called.

    source: Iterable
        Iterator whose elements are passed through the mapping function.

    process_count: int
        The number of processes to spawn. Ignored if a custom executor is
        provided.
        For CPU-bound tasks, it is generally not useful to exceed logical core
        count.
        For IO-bound tasks, it may make sense to do so in order to limit the
        amount of time spent in iowait.

    chunk_size: int
        How many elements are fed to the worker processes at once. A value of 8
        is generally fine. Low values may increase overhead and reduce CPU
        occupancy.

    queue_size: int
        Number of chunks to be waited for on the main process at a time.
        Low values increase the chance of the queue being starved, forcing
        workers to idle.
        Very high values may cause high memory usage, especially if the source
        iterable yields large objects.

    executor: Optional[Executor]
        Allows providing an existing executor (preferably a
        ProcessPoolExecutor). If None (the default), a process pool will be
        spawned for this mapping task and will be shut down after.

    progress_bar: bool
        Whether to show a tqdm progress bar.

    progress_bar_kwargs: Optional[dict]
        A dict of keyword arguments that is forwarded to tqdm when
        `progress_bar == True`. Allows overriding the defaults or e.g.
        specifying `total` when it cannot be inferred from the source iterable.
        If None (the default), `{"smoothing": 0.02}` is used.

    Yields
    ------
    The items from source processed by fn
    """
    # the default is built per call rather than in the signature, so that no
    # mutable object is shared between calls
    if progress_bar_kwargs is None:
        progress_bar_kwargs = {"smoothing": 0.02}

    mapper = _ParallelMapper(
        fn,
        source,
        process_count,
        chunk_size,
        queue_size,
        executor,
        progress_bar,
        progress_bar_kwargs,
    )
    yield from mapper.run()
