#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
#


"""Low level iteration functions for tar archives."""

import random
import re
import tarfile
from typing import Any, Callable, Dict, Iterable, Iterator, Optional, Set, Tuple

import braceexpand

from . import filters, gopen
from .handlers import reraise_exception

# When True, group_by_keys prints the prefix, suffix and current sample keys
# for every file entry it processes (debugging aid).
trace = False
# A top-level tar member (no "/" in its name) whose name starts with
# meta_prefix and ends with meta_suffix is skipped by tar_file_iterator.
meta_prefix, meta_suffix = "__", "__"


def base_plus_ext(path):
    """Split a path into its base and its full (possibly multi-part) extension.

    For example ``"dir/sample.seg.png"`` becomes ``("dir/sample", "seg.png")``:
    everything after the first dot that follows the base is kept together as
    the extension.

    Args:
        path: Path with extensions.

    Returns:
        A ``(base, extensions)`` tuple, or ``(None, None)`` when the pattern
        does not match (e.g. a name containing no dot at all).
    """
    found = re.match(r"^((?:.*/|)[^.]+)[.]([^/]*)$", path)
    return (None, None) if found is None else found.group(1, 2)


def valid_sample(sample: Dict[str, Any]) -> bool:
    """Tell whether ``sample`` is a usable sample dictionary.

    A sample is rejected when it is not a dict (this includes ``None``),
    when it has no keys, or when it carries a truthy ``"__bad__"`` entry.

    Args:
        sample: A dictionary representing a sample.

    Returns:
        True if the sample is valid, False otherwise.
    """
    if not isinstance(sample, dict):
        return False
    if not sample:
        return False
    return not sample.get("__bad__", False)


# FIXME: UNUSED
def shardlist(urls, *, shuffle=False):
    """Generate a list of URLs, possibly shuffled.

    Args:
        urls: A brace-expandable URL pattern string (e.g. ``"shard-{000..009}.tar"``)
            or an iterable of URLs.
        shuffle: Whether to shuffle the URLs (in a private copy; the caller's
            sequence is never reordered).

    Yields:
        Dictionary ``{"url": url}`` for each URL.
    """
    if isinstance(urls, str):
        # braceexpand returns a lazy generator; random.shuffle needs a mutable
        # sequence, so materialize it (previously shuffle=True crashed here).
        urls = list(braceexpand.braceexpand(urls))
    else:
        # Copy so shuffling never mutates the caller's list.
        urls = list(urls)
    if shuffle:
        random.shuffle(urls)
    for url in urls:
        yield dict(url=url)


def url_opener(
    data: Iterable[Dict[str, Any]],
    handler: Callable[[Exception], bool] = reraise_exception,
    **kw: Any,
):
    """Open URLs and yield a stream of url+stream pairs.

    Each input sample is updated in place with a ``stream`` entry holding the
    object returned by ``gopen.gopen(url, **kw)`` and is then yielded.

    Args:
        data: Iterator over dict(url=...).
        handler: Exception handler, called with any exception raised while
            opening a URL (its ``args`` extended with the URL). Returning True
            skips that sample; returning False stops iteration.
        **kw: Keyword arguments forwarded to gopen.gopen.

    Yields:
        The input samples, each augmented with an opened ``stream``.
    """
    for sample in data:
        assert isinstance(sample, dict), sample
        assert "url" in sample
        url = sample["url"]
        # Only the open itself is guarded: exceptions raised at the yield
        # point belong to the consumer, not to URL opening.
        try:
            stream = gopen.gopen(url, **kw)
        except Exception as exn:
            exn.args = exn.args + (url,)
            if handler(exn):
                continue
            break
        sample.update(stream=stream)
        yield sample


def tar_file_iterator(
    fileobj: Any,
    skip_meta: Optional[str] = r"__[^/]*__($|/)",
    handler: Callable[[Exception], bool] = reraise_exception,
    select_files: Optional[Callable[[str], bool]] = None,
    rename_files: Optional[Callable[[str], str]] = None,
) -> Iterator[Dict[str, Any]]:
    """Iterate over tar file, yielding filename, content pairs for the given tar stream.

    The archive is opened in streaming mode (``"r|*"``), so each member's
    data is read immediately, before advancing to the next member.

    Args:
        fileobj: Readable binary file-like object holding the tar data; it is
            passed as ``fileobj=`` to ``tarfile.open``.
        skip_meta: Regexp (applied with ``re.match`` to the member name) for
            members that are skipped entirely; None disables this check.
        handler: Exception handler. The exception's first argument is suffixed
            with ``" @ <fileobj>"``; returning True skips the member,
            returning False stops iteration.
        select_files: Predicate on the (possibly renamed) file name; members
            for which it returns False are skipped.
        rename_files: Function applied to each member name before selection.

    Yields:
        Dicts ``{"fname": name, "data": bytes}`` for each regular-file member
        that is not skipped.
    """
    stream = tarfile.open(fileobj=fileobj, mode="r|*")
    for tarinfo in stream:
        # In streaming mode TarFile appends every TarInfo it reads to
        # ``members``. Clear it for every member -- not only the yielded
        # ones -- so memory stays bounded when many members are skipped.
        stream.members = []
        fname = tarinfo.name
        try:
            if not tarinfo.isreg():
                continue
            if fname is None:
                continue
            if "/" not in fname and fname.startswith(meta_prefix) and fname.endswith(meta_suffix):
                # skipping metadata for now
                continue
            if skip_meta is not None and re.match(skip_meta, fname):
                continue
            if rename_files:
                fname = rename_files(fname)
            if select_files is not None and not select_files(fname):
                continue
            data = stream.extractfile(tarinfo).read()
        except Exception as exn:
            if hasattr(exn, "args") and len(exn.args) > 0:
                exn.args = (str(exn.args[0]) + " @ " + str(fileobj),) + exn.args[1:]
            if handler(exn):
                continue
            break
        # Reached only when the try body completed without continue/raise.
        yield dict(fname=fname, data=data)
    del stream


def tar_file_expander(
    data: Iterable[Dict[str, Any]],
    handler: Callable[[Exception], bool] = reraise_exception,
    select_files: Optional[Callable[[str], bool]] = None,
    rename_files: Optional[Callable[[str], str]] = None,
    eof_value: Optional[Any] = {},
) -> Iterator[Dict[str, Any]]:
    """Expand tar files.

    For each opened shard, yields the file entries produced by
    ``tar_file_iterator`` tagged with ``__url__`` (and ``__local_path__`` when
    the source has a ``local_path``), followed by an end-of-shard marker.

    Args:
        data: Iterator over dicts with ``url`` and ``stream`` entries (and an
            optional ``local_path``).
        handler: Exception handler. For dict sources the exception's ``args``
            are extended with the source's stream and url; returning True
            skips to the next source, returning False stops iteration.
        select_files: Select files from tarfiles by name (permits skipping files).
        rename_files: Function to rename files.
        eof_value: Value to yield at the end of each shard; None disables the
            marker. A dict marker is yielded as a fresh shallow copy.

    Yields:
        A stream of samples.
    """
    for source in data:
        try:
            # Validate before reading any keys so malformed sources are routed
            # through ``handler`` instead of escaping it.
            assert isinstance(source, dict)
            assert "stream" in source
            url = source["url"]
            local_path = source.get("local_path")
            for sample in tar_file_iterator(
                source["stream"],
                handler=handler,
                select_files=select_files,
                rename_files=rename_files,
            ):
                assert isinstance(sample, dict) and "data" in sample and "fname" in sample
                sample["__url__"] = url
                if local_path is not None:
                    sample["__local_path__"] = local_path
                yield sample
            # we yield an EOF marker at the end of each shard so that
            # samples from different shards don't get mixed up
            if eof_value is not None:
                # The default ``{}`` is one object shared by all calls; copy it
                # so a consumer mutating a marker cannot alter later markers.
                yield dict(eof_value) if isinstance(eof_value, dict) else eof_value
        except Exception as exn:
            if isinstance(source, dict):
                exn.args = exn.args + (source.get("stream"), source.get("url"))
            if handler(exn):
                continue
            break


def group_by_keys(
    data: Iterable[Dict[str, Any]],
    keys: Callable[[str], Tuple[str, str]] = base_plus_ext,
    lcase: bool = True,
    suffixes: Optional[Set[str]] = None,
    handler: Callable[[Exception], bool] = reraise_exception,
) -> Iterator[Dict[str, Any]]:
    """Group tarfile contents by keys and yield samples.

    Consecutive file entries whose names share the same key (as computed by
    ``keys``) are merged into one sample dict
    ``{"__key__": key, "__url__": url, suffix: data, ...}``. An entry equal to
    ``{}`` acts as an end-of-shard marker and flushes the pending sample.
    Only samples accepted by ``valid_sample`` are yielded; the pending sample
    is also flushed when iteration ends or is stopped by ``handler``.

    Args:
        data: Iterator over tarfile contents: dicts with ``fname``, ``data``
            and ``__url__`` (optionally ``__local_path__``), interleaved with
            ``{}`` end-of-shard markers.
        keys: Function that takes a file name and returns a key and a suffix;
            entries whose key is None are ignored.
        lcase: Whether to lowercase the suffix.
        suffixes: Set of suffixes to keep; None keeps every suffix.
        handler: Exception handler. The exception's ``args`` are extended with
            the offending entry's ``fname`` and ``__url__``; returning True
            skips the entry, returning False stops grouping.

    Raises:
        ValueError: Passed to ``handler`` (and re-raised by the default
            handler) if a kept suffix occurs twice in one sample, i.e.
            duplicate file names in the tar file.

    Yields:
        Iterator over samples.
    """
    current_sample = None
    for filesample in data:
        try:
            assert isinstance(filesample, dict)
            if filesample == {}:
                # end-of-shard marker: never merge samples across shards
                if valid_sample(current_sample):
                    yield current_sample
                current_sample = None
                continue
            fname, value = filesample["fname"], filesample["data"]
            prefix, suffix = keys(fname)
            if trace:
                print(
                    prefix,
                    suffix,
                    current_sample.keys() if isinstance(current_sample, dict) else None,
                )
            if prefix is None:
                continue
            if lcase:
                suffix = suffix.lower()
            if current_sample is None or prefix != current_sample["__key__"]:
                if valid_sample(current_sample):
                    yield current_sample
                current_sample = dict(__key__=prefix, __url__=filesample["__url__"])
            if suffix in current_sample:
                raise ValueError(f"{fname}: duplicate file name in tar file {suffix} {current_sample.keys()}")
            if suffixes is None or suffix in suffixes:
                current_sample[suffix] = value
            local_path = filesample.get("__local_path__")
            if local_path is not None:
                current_sample["__local_path__"] = local_path
        except Exception as exn:
            # File entries carry "fname" and "__url__" (not "stream"/"url"),
            # so attach those as context for the handler.
            if isinstance(filesample, dict):
                exn.args = exn.args + (filesample.get("fname"), filesample.get("__url__"))
            if handler(exn):
                continue
            break
    if valid_sample(current_sample):
        yield current_sample


def tarfile_samples(
    src: Iterable[Dict[str, Any]],
    handler: Callable[[Exception], bool] = reraise_exception,
    select_files: Optional[Callable[[str], bool]] = None,
    rename_files: Optional[Callable[[str], str]] = None,
) -> Iterable[Dict[str, Any]]:
    """Turn a stream of shard descriptors into a stream of grouped samples.

    Three generator stages are chained lazily: ``url_opener`` opens each
    shard's ``url``, ``tar_file_expander`` yields the files inside each tar
    archive, and ``group_by_keys`` merges files sharing a key into samples.
    The same ``handler`` is handed to every stage.

    Args:
        src: Iterable of dicts, each with a ``"url"`` entry.
        handler: Exception handler shared by all stages.
        select_files: Predicate choosing which tar members to keep.
        rename_files: Function applied to tar member names.

    Returns:
        Iterator over grouped samples.
    """
    opened_shards = url_opener(src, handler=handler)
    tar_members = tar_file_expander(
        opened_shards,
        handler=handler,
        select_files=select_files,
        rename_files=rename_files,
    )
    return group_by_keys(tar_members, handler=handler)


# tarfile_samples wrapped with filters.pipelinefilter (presumably so it can be
# used as a stage in a filter pipeline -- confirm against filters module).
tarfile_to_samples = filters.pipelinefilter(tarfile_samples)
