import warnings
from functools import partial
from typing import Dict, List, Literal, Optional, Tuple, Type, Union

from lhotse import fastcopy
from lhotse.array import Array, TemporalArray
from lhotse.audio import Recording
from lhotse.cut import Cut
from lhotse.shar.utils import to_shar_placeholder
from lhotse.shar.writers.array import ArrayTarWriter
from lhotse.shar.writers.audio import AudioTarWriter
from lhotse.shar.writers.cut import JsonlShardWriter
from lhotse.utils import Pathlike, ifnone

WriterName = Literal["wav", "flac", "mp3", "lilcom", "numpy", "jsonl"]
FieldWriterInstance = Union[AudioTarWriter, ArrayTarWriter]
FieldWriter = Type[FieldWriterInstance]


class SharWriter:
    """
    SharWriter writes cuts and their corresponding data into multiple shards,
    also recognized as the Lhotse Shar format.
    Each shard is numbered and represented as a collection of one text manifest and
    one or more binary tarfiles.
    Each tarfile contains a single type of data, e.g., recordings, features, or custom fields.

    The main idea behind Lhotse Shar format is to optimize dataloading with sequential reads,
    while keeping the data composition more flexible than e.g. WebDataset tar archives do.
    To achieve this, Lhotse Shar keeps each data type in a separate archive, along a single
    CutSet JSONL manifest.
    This way, the metadata can be investigated without iterating through the binary data.
    The format also allows iteration over a subset of fields, or extension of existing data
    with new fields.

    The user has to specify which fields should be saved, and what compression to use for each of them.
    Currently we support ``wav``, ``flac``, ``opus``, and ``mp3`` compression for ``recording`` and custom audio fields,
    and ``lilcom`` or ``numpy`` for ``features`` and custom array fields.

    Example::

        >>> cuts = CutSet(...)  # cuts have 'recording' and 'features'
        >>> with SharWriter("some_dir", shard_size=100, fields={"recording": "opus", "features": "lilcom"}) as w:
        ...     for cut in cuts:
        ...         w.write(cut)

    .. note:: Different audio backends in Lhotse may use different encoders for the same audio formats.
        It is advisable to use the same audio backend for saving and loading audio data in Shar and other formats.
        See: :class:`lhotse.audio.recording.Recording`.

    It would create a directory ``some_dir`` with files such as ``some_dir/cuts.000000.jsonl.gz``,
    ``some_dir/recording.000000.tar``, ``some_dir/features.000000.tar``,
    and then the same names but numbered with ``000001``, etc.
    The starting shard offset can be set using ``shard_offset`` parameter. The writer starts from 0 by default.

    When ``shard_size`` is set to ``None``, we will disable automatic sharding and the
    shard number suffix will be omitted from the file names.

    The option ``warn_unused_fields`` will emit a warning when cuts have some data attached to them
    (e.g., recording, features, or custom arrays) but saving it was not specified via ``fields``.

    The option ``include_cuts`` controls whether we store the cuts alongside ``fields`` (true by default).
    Turning it off is useful when extending existing dataset with new fields/feature types,
    but the original cuts do not require any modification.

    See also: :class:`~lhotse.shar.writers.tar.TarWriter`, :class:`~lhotse.shar.writers.audio.AudioTarWriter`,
        :class:`~lhotse.shar.writers.array.ArrayTarWriter`.
    """

    def __init__(
        self,
        output_dir: Pathlike,
        fields: Dict[str, str],
        shard_size: Optional[int] = 1000,
        warn_unused_fields: bool = True,
        include_cuts: bool = True,
        shard_suffix: Optional[str] = None,
        shard_offset: int = 0,
    ) -> None:
        self.output_dir = str(output_dir)
        self.shard_size = shard_size
        self.fields = fields
        self.warn_unused_fields = warn_unused_fields
        self.include_cuts = include_cuts
        if self.sharding_enabled:
            assert (
                shard_suffix is None
            ), f"shard_suffix must be None when shard_size is specified (got: '{shard_suffix}')."
            self.shard_suffix = ".%06d"
        else:
            self.shard_suffix = ifnone(shard_suffix, "")
        self.initial_shard_offset = shard_offset

        self.writers = {}
        if include_cuts:
            self.writers["cuts"] = JsonlShardWriter(
                pattern=_create_cuts_output_url(self.output_dir, self.shard_suffix),
                shard_size=self.shard_size,
                shard_offset=self.initial_shard_offset,
            )
        for field, writer_type in self.fields.items():
            make_writer_fn, ext = resolve_writer(writer_type)
            self.writers[field] = make_writer_fn(
                pattern=f"{self.output_dir}/{field}{self.shard_suffix}{ext}",
                shard_size=self.shard_size,
                shard_offset=self.initial_shard_offset,
            )

    @property
    def sharding_enabled(self) -> bool:
        return self.shard_size is not None and self.shard_size > 0

    @property
    def output_paths(self) -> Dict[str, List[str]]:
        return {k: w.output_paths for k, w in self.writers.items()}

    def __enter__(self):
        for w in self.writers.values():
            w.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def close(self):
        for w in self.writers.values():
            w.close()

    def write(self, cut: Cut) -> None:

        # Handle audio.
        if "recording" in self.fields:
            if cut.has_recording:
                data = cut.load_audio()
                recording = to_shar_placeholder(cut.recording, cut)
                cut_channels = _aslist(cut.channel)
                if recording.channel_ids != cut_channels:
                    # If recording is multi-channel but the cut refers to a subset of them,
                    # we have to update the recording manifest accordingly
                    recording.sources[0].channels = cut_channels
                    recording.channel_ids = cut_channels
                self.writers["recording"].write(
                    cut.id,
                    data,
                    cut.sampling_rate,
                    manifest=recording,
                    original_format=cut.recording.source_format,
                )
                cut = fastcopy(cut, recording=recording)
            else:
                self.writers["recording"].write_placeholder(cut.id)
        elif cut.has_recording and self.warn_unused_fields:
            warnings.warn(
                "Found cut with 'recording' field that is not specified for Shar writing."
            )

        # Handle features.
        if "features" in self.fields:
            if cut.has_features:
                data = cut.load_features()
                features = to_shar_placeholder(cut.features, cut)
                self.writers["features"].write(cut.id, data, manifest=features)
                cut = fastcopy(cut, features=features)
            else:
                self.writers["features"].write_placeholder(cut.id)
        elif cut.has_features and self.warn_unused_fields:
            warnings.warn(
                "Found cut with 'features' field that is not specified for Shar writing."
            )

        # Handle custom data and non-data attributes.
        for key in self.fields:

            # Skip fields already taken care of.
            if key in ["recording", "features"]:
                continue

            # Check if the custom attribute is available: if yes
            if cut.has_custom(key):
                val = getattr(cut, key)
                if not isinstance(val, (Array, TemporalArray, Recording)):
                    assert isinstance(
                        self.writers[key], JsonlShardWriter
                    ), f"Expected writer type 'jsonl' (got '{self.fields[key]}') for non-data field '{key}'."
                    self.writers[key].write({"cut_id": cut.id, key: val})
                else:
                    data = cut.load_custom(key)
                    placeholder_obj = to_shar_placeholder(val, cut)
                    channel_selector_key = f"{key}_channel_selector"
                    kwargs = {}
                    if isinstance(val, Recording):
                        kwargs["sampling_rate"] = val.sampling_rate
                        if cut.has_custom(channel_selector_key):
                            # override custom recording channels since the audio was loaded via cut
                            # and used the channel selector
                            placeholder_obj.sources[0].channels = cut.custom[
                                channel_selector_key
                            ]
                            placeholder_obj.channel_ids = cut.custom[
                                channel_selector_key
                            ]
                    self.writers[key].write(
                        cut.id, data, manifest=placeholder_obj, **kwargs
                    )
                    cut = fastcopy(cut, custom=cut.custom.copy())
                    cut.custom.pop(channel_selector_key, None)  # no longer needed
                    setattr(cut, key, placeholder_obj)
            else:
                self.writers[key].write_placeholder(cut.id)

        # Warn in case there is some data that wasn't requested to be saved.
        for key, val in ifnone(cut.custom, {}).items():
            if not isinstance(val, (Array, TemporalArray, Recording)):
                continue
            if key not in self.fields:
                if self.warn_unused_fields:
                    warnings.warn(
                        f"Found cut with '{key}' field that is not specified for Shar writing."
                    )
                continue

        # We will write only the relevant subset of the binary data alongside cut,
        # so we need to update the offset (start).
        cut = fastcopy(cut, start=0)

        if "cuts" in self.writers:
            self.writers["cuts"].write(cut)


def resolve_writer(name: str) -> Tuple[FieldWriter, str]:
    opts = {
        "wav": (partial(AudioTarWriter, format="wav"), ".tar"),
        "flac": (partial(AudioTarWriter, format="flac"), ".tar"),
        "mp3": (partial(AudioTarWriter, format="mp3"), ".tar"),
        "opus": (partial(AudioTarWriter, format="opus"), ".tar"),
        "original": (partial(AudioTarWriter, format="original"), ".tar"),
        "lilcom": (partial(ArrayTarWriter, compression="lilcom"), ".tar"),
        "numpy": (partial(ArrayTarWriter, compression="numpy"), ".tar"),
        "jsonl": (JsonlShardWriter, ".jsonl.gz"),
    }
    assert (
        name in opts
    ), f"Unknown field type (got: '{name}', we support only: {', '.join(opts)}"
    return opts[name]


def _create_cuts_output_url(base_output_url: str, shard_suffix: str) -> str:

    # special case where we want to ensure the CutSet actually gets gzipped
    if base_output_url.startswith("pipe:"):
        base_output_url = base_output_url.replace("pipe:", "pipe:gzip -c | ")

    return f"{base_output_url}/cuts{shard_suffix}.jsonl.gz"


def _aslist(x):
    if isinstance(x, list):
        return x
    return [x]
