# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import json
import os
from itertools import combinations
from typing import Any, Callable, Dict, Iterable, List, Optional, Union

import numpy as np
import pandas as pd

from nemo.collections.common.parts.preprocessing import manifest, parsers
from nemo.collections.common.parts.preprocessing.manifest import get_full_path
from nemo.utils import logging, logging_mode


class _Collection(collections.UserList):
    """List of parsed and preprocessed data."""

    OUTPUT_TYPE = None  # Single element output type.


class Text(_Collection):
    """Simple list of preprocessed text entries, result in list of tokens."""

    OUTPUT_TYPE = collections.namedtuple('TextEntity', 'tokens')

    def __init__(self, texts: List[str], parser: parsers.CharParser):
        """Instantiates text manifest and do the preprocessing step.

        Args:
            texts: List of raw texts strings.
            parser: Instance of `CharParser` to convert string to tokens.
        """

        data, output_type = [], self.OUTPUT_TYPE
        for text in texts:
            tokens = parser(text)

            if tokens is None:
                logging.warning("Fail to parse '%s' text line.", text)
                continue

            data.append(output_type(tokens))

        super().__init__(data)


class FromFileText(Text):
    """Another form of texts manifest with reading from file."""

    def __init__(self, file: str, parser: parsers.CharParser):
        """Instantiates text manifest and do the preprocessing step.

        Args:
            file: File path to read from.
            parser: Instance of `CharParser` to convert string to tokens.
        """

        texts = self.__parse_texts(file)

        super().__init__(texts, parser)

    @staticmethod
    def __parse_texts(file: str) -> List[str]:
        if not os.path.exists(file):
            raise ValueError('Provided texts file does not exists!')

        _, ext = os.path.splitext(file)
        if ext == '.csv':
            texts = pd.read_csv(file)['transcript'].tolist()
        elif ext == '.json':  # Not really a correct json.
            texts = list(item['text'] for item in manifest.item_iter(file))
        else:
            with open(file, 'r') as f:
                texts = f.readlines()

        return texts


class AudioText(_Collection):
    """List of audio-transcript text correspondence with preprocessing."""

    OUTPUT_TYPE = collections.namedtuple(
        typename='AudioTextEntity',
        field_names='id audio_file duration text_tokens offset text_raw speaker orig_sr lang',
    )

    def __init__(
        self,
        ids: List[int],
        audio_files: List[str],
        durations: List[float],
        texts: List[str],
        offsets: List[str],
        speakers: List[Optional[int]],
        orig_sampling_rates: List[Optional[int]],
        token_labels: List[Optional[int]],
        langs: List[Optional[str]],
        parser: parsers.CharParser,
        min_duration: Optional[float] = None,
        max_duration: Optional[float] = None,
        max_number: Optional[int] = None,
        do_sort_by_duration: bool = False,
        index_by_file_id: bool = False,
    ):
        """Instantiates audio-text manifest with filters and preprocessing.

        Args:
            ids: List of examples positions.
            audio_files: List of audio files.
            durations: List of float durations.
            texts: List of raw text transcripts.
            offsets: List of duration offsets or None.
            speakers: List of optional speakers ids.
            orig_sampling_rates: List of original sampling rates of audio files.
            langs: List of language ids, one for eadh sample, or None.
            parser: Instance of `CharParser` to convert string to tokens.
            min_duration: Minimum duration to keep entry with (default: None).
            max_duration: Maximum duration to keep entry with (default: None).
            max_number: Maximum number of samples to collect.
            do_sort_by_duration: True if sort samples list by duration. Not compatible with index_by_file_id.
            index_by_file_id: If True, saves a mapping from filename base (ID) to index in data.
        """

        output_type = self.OUTPUT_TYPE
        all_has_duration = True
        data, duration_filtered, num_filtered, total_duration = [], 0.0, 0, 0.0
        if index_by_file_id:
            self.mapping = {}

        for id_, audio_file, duration, offset, text, speaker, orig_sr, token_labels, lang in zip(
            ids, audio_files, durations, offsets, texts, speakers, orig_sampling_rates, token_labels, langs
        ):
            if duration is None:
                all_has_duration = False
            # Duration filters.
            if duration is not None and min_duration is not None and duration < min_duration:
                duration_filtered += duration
                num_filtered += 1
                continue

            if duration is not None and max_duration is not None and duration > max_duration:
                duration_filtered += duration
                num_filtered += 1
                continue

            if token_labels is not None:
                text_tokens = token_labels
            else:
                if text != '':
                    if hasattr(parser, "is_aggregate") and parser.is_aggregate and isinstance(text, str):
                        if lang is not None:
                            text_tokens = parser(text, lang)
                        # for future use if want to add language bypass to audio_to_text classes
                        # elif hasattr(parser, "lang") and parser.lang is not None:
                        #    text_tokens = parser(text, parser.lang)
                        else:
                            raise ValueError("lang required in manifest when using aggregate tokenizers")
                    else:
                        text_tokens = parser(text)
                else:
                    text_tokens = []

                if text_tokens is None:
                    duration_filtered += duration
                    num_filtered += 1
                    continue

            total_duration += duration if duration is not None else 0.0

            data.append(output_type(id_, audio_file, duration, text_tokens, offset, text, speaker, orig_sr, lang))
            if index_by_file_id:
                file_id, _ = os.path.splitext(os.path.basename(audio_file))
                if file_id not in self.mapping:
                    self.mapping[file_id] = []
                self.mapping[file_id].append(len(data) - 1)

            # Max number of entities filter.
            if len(data) == max_number:
                break

        if do_sort_by_duration:
            if index_by_file_id:
                logging.warning("Tried to sort dataset by duration, but cannot since index_by_file_id is set.")
            else:
                data.sort(key=lambda entity: entity.duration)

        logging.info("Dataset loaded with %d files totalling %.2f hours", len(data), total_duration / 3600)
        logging.info("%d files were filtered totalling %.2f hours", num_filtered, duration_filtered / 3600)
        if not all_has_duration:
            logging.info("Not all audios have duration information, the total number of hours is inaccurate.")
        super().__init__(data)


class VideoText(_Collection):
    """List of video-transcript text correspondence with preprocessing."""

    OUTPUT_TYPE = collections.namedtuple(
        typename='AudioTextEntity',
        field_names='id video_file duration text_tokens offset text_raw speaker orig_sr lang',
    )

    def __init__(
        self,
        ids: List[int],
        video_files: List[str],
        durations: List[float],
        texts: List[str],
        offsets: List[str],
        speakers: List[Optional[int]],
        orig_sampling_rates: List[Optional[int]],
        token_labels: List[Optional[int]],
        langs: List[Optional[str]],
        parser: parsers.CharParser,
        min_duration: Optional[float] = None,
        max_duration: Optional[float] = None,
        max_number: Optional[int] = None,
        do_sort_by_duration: bool = False,
        index_by_file_id: bool = False,
    ):
        """Instantiates video-text manifest with filters and preprocessing.

        Args:
            ids: List of examples positions.
            video_files: List of video files.
            durations: List of float durations.
            texts: List of raw text transcripts.
            offsets: List of duration offsets or None.
            speakers: List of optional speakers ids.
            orig_sampling_rates: List of original sampling rates of audio files.
            langs: List of language ids, one for eadh sample, or None.
            parser: Instance of `CharParser` to convert string to tokens.
            min_duration: Minimum duration to keep entry with (default: None).
            max_duration: Maximum duration to keep entry with (default: None).
            max_number: Maximum number of samples to collect.
            do_sort_by_duration: True if sort samples list by duration. Not compatible with index_by_file_id.
            index_by_file_id: If True, saves a mapping from filename base (ID) to index in data.
        """

        output_type = self.OUTPUT_TYPE
        data, duration_filtered, num_filtered, total_duration = [], 0.0, 0, 0.0
        if index_by_file_id:
            self.mapping = {}

        for id_, video_file, duration, offset, text, speaker, orig_sr, token_labels, lang in zip(
            ids, video_files, durations, offsets, texts, speakers, orig_sampling_rates, token_labels, langs
        ):
            # Duration filters.
            if min_duration is not None and duration < min_duration:
                duration_filtered += duration
                num_filtered += 1
                continue

            if max_duration is not None and duration > max_duration:
                duration_filtered += duration
                num_filtered += 1
                continue

            if token_labels is not None:
                text_tokens = token_labels
            else:
                if text != '':
                    if hasattr(parser, "is_aggregate") and parser.is_aggregate and isinstance(text, str):
                        if lang is not None:
                            text_tokens = parser(text, lang)
                        else:
                            raise ValueError("lang required in manifest when using aggregate tokenizers")
                    else:
                        text_tokens = parser(text)
                else:
                    text_tokens = []

                if text_tokens is None:
                    duration_filtered += duration
                    num_filtered += 1
                    continue

            total_duration += duration

            data.append(output_type(id_, video_file, duration, text_tokens, offset, text, speaker, orig_sr, lang))
            if index_by_file_id:
                file_id, _ = os.path.splitext(os.path.basename(video_file))
                if file_id not in self.mapping:
                    self.mapping[file_id] = []
                self.mapping[file_id].append(len(data) - 1)

            # Max number of entities filter.
            if len(data) == max_number:
                break

        if do_sort_by_duration:
            if index_by_file_id:
                logging.warning("Tried to sort dataset by duration, but cannot since index_by_file_id is set.")
            else:
                data.sort(key=lambda entity: entity.duration)

        logging.info("Dataset loaded with %d files totalling %.2f hours", len(data), total_duration / 3600)
        logging.info("%d files were filtered totalling %.2f hours", num_filtered, duration_filtered / 3600)

        super().__init__(data)


class ASRAudioText(AudioText):
    """`AudioText` collector from asr structured json files."""

    def __init__(self, manifests_files: Union[str, List[str]], parse_func: Optional[Callable] = None, *args, **kwargs):
        """Parse lists of audio files, durations and transcripts texts.

        Args:
            manifests_files: Either single string file or list of such -
                manifests to yield items from.
            *args: Args to pass to `AudioText` constructor.
            **kwargs: Kwargs to pass to `AudioText` constructor.
        """

        (
            ids,
            audio_files,
            durations,
            texts,
            offsets,
        ) = (
            [],
            [],
            [],
            [],
            [],
        )

        speakers, orig_srs, token_labels, langs = [], [], [], []
        for item in manifest.item_iter(manifests_files, parse_func=parse_func):
            ids.append(item['id'])
            audio_files.append(item['audio_file'])
            durations.append(item['duration'])
            texts.append(item['text'])
            offsets.append(item['offset'])
            speakers.append(item['speaker'])
            orig_srs.append(item['orig_sr'])
            token_labels.append(item['token_labels'])
            langs.append(item['lang'])
        super().__init__(
            ids, audio_files, durations, texts, offsets, speakers, orig_srs, token_labels, langs, *args, **kwargs
        )


class SpeechLLMAudioTextEntity(object):
    """Class for SpeechLLM dataloader instance."""

    def __init__(self, sid, audio_file, duration, context, answer, offset, speaker, orig_sr, lang) -> None:
        """Initialize the AudioTextEntity for a SpeechLLM dataloader instance."""
        self.id = sid
        self.audio_file = audio_file
        self.duration = duration
        self.context = context
        self.answer = answer
        self.offset = offset
        self.speaker = speaker
        self.orig_sr = orig_sr
        self.lang = lang


class ASRVideoText(VideoText):
    """`VideoText` collector from cv structured json files."""

    def __init__(self, manifests_files: Union[str, List[str]], *args, **kwargs):
        """Parse lists of video files, durations and transcripts texts.

        Args:
            manifests_files: Either single string file or list of such -
                manifests to yield items from.
            *args: Args to pass to `VideoText` constructor.
            **kwargs: Kwargs to pass to `VideoText` constructor.
        """

        (
            ids,
            video_files,
            durations,
            texts,
            offsets,
        ) = (
            [],
            [],
            [],
            [],
            [],
        )
        speakers, orig_srs, token_labels, langs = [], [], [], []
        for item in manifest.item_iter(manifests_files):
            ids.append(item['id'])
            video_files.append(item['video_file'])
            durations.append(item['duration'])
            texts.append(item['text'])
            offsets.append(item['offset'])
            speakers.append(item['speaker'])
            orig_srs.append(item['orig_sr'])
            token_labels.append(item['token_labels'])
            langs.append(item['lang'])
        super().__init__(
            ids, video_files, durations, texts, offsets, speakers, orig_srs, token_labels, langs, *args, **kwargs
        )


class SpeechLLMAudioText(object):
    """List of audio-transcript text correspondence with preprocessing.

    All of the audio, duration, context, answer are optional.
    If answer is not present, text is treated as the answer.
    """

    def __init__(
        self,
        ids: List[int],
        audio_files: List[str],
        durations: List[float],
        context_list: List[str],
        answers: List[str],
        offsets: List[str],
        speakers: List[Optional[int]],
        orig_sampling_rates: List[Optional[int]],
        langs: List[Optional[str]],
        min_duration: Optional[float] = None,
        max_duration: Optional[float] = None,
        max_number: Optional[int] = None,
        do_sort_by_duration: bool = False,
        index_by_file_id: bool = False,
        max_num_samples: Optional[int] = None,
    ):
        """Instantiates audio-context-answer manifest with filters and preprocessing.

        Args:
            ids: List of examples positions.
            audio_files: List of audio files.
            durations: List of float durations.
            context_list: List of raw text transcripts.
            answers: List of raw text transcripts.
            offsets: List of duration offsets or None.
            speakers: List of optional speakers ids.
            orig_sampling_rates: List of original sampling rates of audio files.
            langs: List of language ids, one for eadh sample, or None.
            min_duration: Minimum duration to keep entry with (default: None).
            max_duration: Maximum duration to keep entry with (default: None).
            max_number: Maximum number of samples to collect.
            do_sort_by_duration: True if sort samples list by duration. Not compatible with index_by_file_id.
            index_by_file_id: If True, saves a mapping from filename base (ID) to index in data.
        """

        data, duration_filtered, num_filtered, total_duration = [], 0.0, 0, 0.0
        if index_by_file_id:
            self.mapping = {}

        for id_, audio_file, duration, offset, context, answer, speaker, orig_sr, lang in zip(
            ids, audio_files, durations, offsets, context_list, answers, speakers, orig_sampling_rates, langs
        ):
            # Duration filters.
            if duration is not None:
                curr_min_dur = min(duration) if isinstance(duration, list) else duration
                curr_max_dur = max(duration) if isinstance(duration, list) else duration
                curr_sum_dur = sum(duration) if isinstance(duration, list) else duration
                if min_duration is not None and curr_min_dur < min_duration:
                    duration_filtered += curr_sum_dur
                    num_filtered += 1
                    continue

                if max_duration is not None and curr_max_dur > max_duration:
                    duration_filtered += curr_sum_dur
                    num_filtered += 1
                    continue
                total_duration += curr_sum_dur

            if answer is None:
                duration_filtered += curr_sum_dur
                num_filtered += 1
                continue

            data.append(
                SpeechLLMAudioTextEntity(id_, audio_file, duration, context, answer, offset, speaker, orig_sr, lang)
            )
            if index_by_file_id and audio_file is not None:
                file_id, _ = os.path.splitext(os.path.basename(audio_file))
                if file_id not in self.mapping:
                    self.mapping[file_id] = []
                self.mapping[file_id].append(len(data) - 1)

            # Max number of entities filter.
            if len(data) == max_number:
                break

        if max_num_samples is not None and not index_by_file_id:
            if max_num_samples <= len(data):
                logging.info(f"Subsampling dataset from {len(data)} to {max_num_samples} samples")
                data = data[:max_num_samples]
            else:
                logging.info(f"Oversampling dataset from {len(data)} to {max_num_samples} samples")
                data = data * (max_num_samples // len(data))
                res_num = max_num_samples % len(data)
                res_data = [data[idx] for idx in np.random.choice(len(data), res_num, replace=False)]
                data.extend(res_data)
        elif max_num_samples is not None and index_by_file_id:
            logging.warning("Tried to subsample dataset by max_num_samples, but cannot since index_by_file_id is set.")

        if do_sort_by_duration:
            if index_by_file_id:
                logging.warning("Tried to sort dataset by duration, but cannot since index_by_file_id is set.")
            else:
                data.sort(key=lambda entity: entity.duration)

        logging.info("Dataset loaded with %d files totalling %.2f hours", len(data), total_duration / 3600)
        logging.info("%d files were filtered totalling %.2f hours", num_filtered, duration_filtered / 3600)

        self.data = data

    def __getitem__(self, idx):
        if idx < 0 or idx > len(self.data):
            raise ValueError(f"index out of range [0,{len(self.data)}), got {idx} instead")
        return self.data[idx]

    def __len__(self):
        return len(self.data)


class SpeechLLMAudioTextCollection(SpeechLLMAudioText):
    """`SpeechLLMAudioText` collector from SpeechLLM json files.

    This collector also keeps backward compatibility with SpeechLLMAudioText.
    """

    def __init__(
        self,
        manifests_files: Union[str, List[str]],
        context_file: Optional[Union[List[str], str]] = None,
        context_key: str = "context",
        answer_key: str = "answer",
        *args,
        **kwargs,
    ):
        """Parse lists of audio files, durations and transcripts texts.

        Args:
            manifests_files: Either single string file or list of such -
                manifests to yield items from.
            *args: Args to pass to `AudioText` constructor.
            **kwargs: Kwargs to pass to `AudioText` constructor.
        """
        self.context_key = context_key
        self.answer_key = answer_key

        (
            ids,
            audio_files,
            durations,
            context_list,
            answers,
            offsets,
        ) = (
            [],
            [],
            [],
            [],
            [],
            [],
        )
        speakers, orig_srs, langs = (
            [],
            [],
            [],
        )
        if context_file is not None:
            question_file_list = context_file.split(",") if isinstance(context_file, str) else context_file
            self.context_list = []
            for filepath in question_file_list:
                with open(filepath, 'r') as f:
                    for line in f.readlines():
                        line = line.strip()
                        if line:
                            self.context_list.append(line)
            logging.info(f"Use random text context from {context_file} for {manifests_files}")
        else:
            self.context_list = None

        for item in manifest.item_iter(manifests_files, parse_func=self.__parse_item):
            ids.append(item['id'])
            audio_files.append(item['audio_file'])
            durations.append(item['duration'])
            context_list.append(item['context'])
            answers.append(item['answer'])
            offsets.append(item['offset'])
            speakers.append(item['speaker'])
            orig_srs.append(item['orig_sr'])
            langs.append(item['lang'])
        super().__init__(
            ids, audio_files, durations, context_list, answers, offsets, speakers, orig_srs, langs, *args, **kwargs
        )

    def __parse_item(self, line: str, manifest_file: str) -> Dict[str, Any]:
        item = json.loads(line)

        # Audio file
        if 'audio_filename' in item:
            item['audio_file'] = item.pop('audio_filename')
        elif 'audio_filepath' in item:
            item['audio_file'] = item.pop('audio_filepath')
        elif 'audio_file' not in item:
            item['audio_file'] = None

        # If the audio path is a relative path and does not exist,
        # try to attach the parent directory of manifest to the audio path.
        # Revert to the original path if the new path still doesn't exist.
        # Assume that the audio path is like "wavs/xxxxxx.wav".
        if item['audio_file'] is not None:
            item['audio_file'] = manifest.get_full_path(audio_file=item['audio_file'], manifest_file=manifest_file)

        # Duration.
        if 'duration' not in item:
            item['duration'] = None

        # Answer.
        if self.answer_key in item:
            item['answer'] = item.pop(self.answer_key)
        elif 'text' in item:
            # compatability with ASR manifests that uses 'text' as answer key
            item['answer'] = item.pop('text')
        elif 'text_filepath' in item:
            with open(item.pop('text_filepath'), 'r') as f:
                item['answer'] = f.read()
        else:
            item['answer'] = "na"

        # context.
        if self.context_key in item:
            item['context'] = item.pop(self.context_key)
        elif 'context_filepath' in item:
            with open(item.pop('context_filepath'), 'r') as f:
                item['context'] = f.read()
        elif self.context_list is not None:
            context = np.random.choice(self.context_list).strip()
            item['context'] = context
        elif 'question' in item:
            # compatability with old manifests that uses 'question' as context key
            logging.warning(
                f"Neither `{self.context_key}` is found nor"
                f"`context_file` is set, but found `question` in item: {item}",
                mode=logging_mode.ONCE,
            )
            item['context'] = item.pop('question')
        else:
            # default context if nothing is found
            item['context'] = "what does this audio mean"

        item = dict(
            audio_file=item['audio_file'],
            duration=item['duration'],
            context=str(item['context']),
            answer=str(item['answer']),
            offset=item.get('offset', None),
            speaker=item.get('speaker', None),
            orig_sr=item.get('orig_sample_rate', None),
            lang=item.get('lang', None),
        )
        return item


class SpeechLabel(_Collection):
    """List of audio-label correspondence with preprocessing."""

    OUTPUT_TYPE = collections.namedtuple(
        typename='SpeechLabelEntity',
        field_names='audio_file duration label offset',
    )

    def __init__(
        self,
        audio_files: List[str],
        durations: List[float],
        labels: List[Union[int, str]],
        offsets: List[Optional[float]],
        min_duration: Optional[float] = None,
        max_duration: Optional[float] = None,
        max_number: Optional[int] = None,
        do_sort_by_duration: bool = False,
        index_by_file_id: bool = False,
    ):
        """Instantiates audio-label manifest with filters and preprocessing.

        Args:
            audio_files: List of audio files.
            durations: List of float durations.
            labels: List of labels.
            offsets: List of offsets or None.
            min_duration: Minimum duration to keep entry with (default: None).
            max_duration: Maximum duration to keep entry with (default: None).
            max_number: Maximum number of samples to collect.
            do_sort_by_duration: True if sort samples list by duration.
            index_by_file_id: If True, saves a mapping from filename base (ID) to index in data.
        """

        if index_by_file_id:
            self.mapping = {}
        output_type = self.OUTPUT_TYPE
        data, duration_filtered = [], 0.0
        total_duration = 0.0
        duration_undefined = True

        for audio_file, duration, command, offset in zip(audio_files, durations, labels, offsets):
            # Duration filters.
            if duration is not None and min_duration is not None and duration < min_duration:
                duration_filtered += duration
                continue

            if duration is not None and max_duration is not None and duration > max_duration:
                duration_filtered += duration
                continue

            data.append(output_type(audio_file, duration, command, offset))

            if duration is not None:
                total_duration += duration
                duration_undefined = False

            if index_by_file_id:
                file_id, _ = os.path.splitext(os.path.basename(audio_file))
                self.mapping[file_id] = len(data) - 1

            # Max number of entities filter.
            if len(data) == max_number:
                break

        if do_sort_by_duration:
            if index_by_file_id:
                logging.warning("Tried to sort dataset by duration, but cannot since index_by_file_id is set.")
            else:
                data.sort(key=lambda entity: entity.duration)

        if duration_undefined:
            logging.info(f"Dataset loaded with {len(data)} items. The durations were not provided.")
        else:
            logging.info(f"Filtered duration for loading collection is {duration_filtered / 3600: .2f} hours.")
            logging.info(
                f"Dataset successfully loaded with {len(data)} items "
                f"and total duration provided from manifest is {total_duration / 3600: .2f} hours."
            )

        self.uniq_labels = sorted(set(map(lambda x: x.label, data)))
        logging.info("# {} files loaded accounting to # {} labels".format(len(data), len(self.uniq_labels)))

        super().__init__(data)


class ASRSpeechLabel(SpeechLabel):
    """`SpeechLabel` collector from structured json files."""

    def __init__(
        self,
        manifests_files: Union[str, List[str]],
        is_regression_task=False,
        cal_labels_occurrence=False,
        delimiter=None,
        *args,
        **kwargs,
    ):
        """Parse lists of audio files, durations and transcripts texts.

        Args:
            manifests_files: Either single string file or list of such -
                manifests to yield items from.
            is_regression_task: It's a regression task.
            cal_labels_occurrence: whether to calculate occurence of labels.
            delimiter: separator for labels strings.
            *args: Args to pass to `SpeechLabel` constructor.
            **kwargs: Kwargs to pass to `SpeechLabel` constructor.
        """
        audio_files, durations, labels, offsets = [], [], [], []
        all_labels = []
        for item in manifest.item_iter(manifests_files, parse_func=self.__parse_item):
            audio_files.append(item['audio_file'])
            durations.append(item['duration'])
            if not is_regression_task:
                label = item['label']
                label_list = label.split() if not delimiter else label.split(delimiter)
            else:
                label = float(item['label'])
                label_list = [label]

            labels.append(label)
            offsets.append(item['offset'])
            all_labels.extend(label_list)
        if cal_labels_occurrence:
            self.labels_occurrence = collections.Counter(all_labels)

        super().__init__(audio_files, durations, labels, offsets, *args, **kwargs)

    def __parse_item(self, line: str, manifest_file: str) -> Dict[str, Any]:
        item = json.loads(line)

        # Audio file
        if 'audio_filename' in item:
            item['audio_file'] = item.pop('audio_filename')
        elif 'audio_filepath' in item:
            item['audio_file'] = item.pop('audio_filepath')
        else:
            raise ValueError(f"Manifest file has invalid json line structure: {line} without proper audio file key.")
        item['audio_file'] = manifest.get_full_path(audio_file=item['audio_file'], manifest_file=manifest_file)

        # Duration.
        if 'duration' not in item:
            raise ValueError(f"Manifest file has invalid json line structure: {line} without proper duration key.")

        # Label.
        if 'command' in item:
            item['label'] = item.pop('command')
        elif 'target' in item:
            item['label'] = item.pop('target')
        elif 'label' in item:
            pass
        else:
            raise ValueError(f"Manifest file has invalid json line structure: {line} without proper label key.")

        item = dict(
            audio_file=item['audio_file'],
            duration=item['duration'],
            label=item['label'],
            offset=item.get('offset', None),
        )

        return item


class FeatureSequenceLabel(_Collection):
    """List of feature sequence of label correspondence with preprocessing."""

    OUTPUT_TYPE = collections.namedtuple(
        typename='FeatureSequenceLabelEntity',
        field_names='feature_file seq_label',
    )

    def __init__(
        self,
        feature_files: List[str],
        seq_labels: List[str],
        max_number: Optional[int] = None,
        index_by_file_id: bool = False,
    ):
        """Instantiates feature-SequenceLabel manifest with filters and preprocessing.

        Args:
            feature_files: List of feature files.
            seq_labels: List of sequences of labels.
            max_number: Maximum number of samples to collect.
            index_by_file_id: If True, saves a mapping from filename base (ID) to index in data.
        """

        output_type = self.OUTPUT_TYPE
        data, num_filtered = (
            [],
            0.0,
        )
        self.uniq_labels = set()

        if index_by_file_id:
            self.mapping = {}

        for feature_file, seq_label in zip(feature_files, seq_labels):

            label_tokens, uniq_labels_in_seq = self.relative_speaker_parser(seq_label)

            data.append(output_type(feature_file, label_tokens))
            self.uniq_labels |= uniq_labels_in_seq

            if label_tokens is None:
                num_filtered += 1
                continue

            if index_by_file_id:
                file_id, _ = os.path.splitext(os.path.basename(feature_file))
                self.mapping[feature_file] = len(data) - 1

            # Max number of entities filter.
            if len(data) == max_number:
                break

        logging.info(f"# {len(data)} files loaded including # {len(self.uniq_labels)} unique labels")
        super().__init__(data)

    def relative_speaker_parser(self, seq_label):
        """Convert sequence of speaker labels to relative labels.
        Convert sequence of absolute speaker to sequence of relative speaker [E A C A E E C] -> [0 1 2 1 0 0 2]
        In this seq of label , if label do not appear before, assign new relative labels len(pos);
        else reuse previous assigned relative labels.

        Args:
            seq_label (str): A string of a sequence of labels.

        Return:
            relative_seq_label (List) : A list of relative sequence of labels
            unique_labels_in_seq (Set): A set of unique labels in the sequence
        """
        seq = seq_label.split()
        conversion_dict = dict()
        relative_seq_label = []

        for seg in seq:
            if seg in conversion_dict:
                converted = conversion_dict[seg]
            else:
                converted = len(conversion_dict)
                conversion_dict[seg] = converted

            relative_seq_label.append(converted)

        unique_labels_in_seq = set(conversion_dict.keys())
        return relative_seq_label, unique_labels_in_seq


class ASRFeatureSequenceLabel(FeatureSequenceLabel):
    """`FeatureSequenceLabel` collector from asr structured json files."""

    def __init__(
        self,
        manifests_files: Union[str, List[str]],
        max_number: Optional[int] = None,
        index_by_file_id: bool = False,
    ):
        """Parse lists of feature files and sequences of labels.

        Args:
            manifests_files:
                Either single string file or list of such manifests to yield items from.
            max_number:
                Maximum number of samples to collect; pass to `FeatureSequenceLabel` constructor.
            index_by_file_id:
                If True, saves a mapping from filename base (ID) to index in data;
                pass to `FeatureSequenceLabel` constructor.
        """

        feature_files, seq_labels = [], []
        for item in manifest.item_iter(manifests_files, parse_func=self._parse_item):
            feature_files.append(item['feature_file'])
            seq_labels.append(item['seq_label'])

        super().__init__(feature_files, seq_labels, max_number, index_by_file_id)

    def _parse_item(self, line: str, manifest_file: str) -> Dict[str, Any]:
        item = json.loads(line)

        # Feature file
        if 'feature_filename' in item:
            item['feature_file'] = item.pop('feature_filename')
        elif 'feature_filepath' in item:
            item['feature_file'] = item.pop('feature_filepath')
        else:
            raise ValueError(
                f"Manifest file has invalid json line " f"structure: {line} without proper feature file key."
            )
        item['feature_file'] = os.path.expanduser(item['feature_file'])

        # Seq of Label.
        if 'seq_label' in item:
            item['seq_label'] = item.pop('seq_label')
        else:
            raise ValueError(
                f"Manifest file has invalid json line " f"structure: {line} without proper seq_label key."
            )

        item = dict(
            feature_file=item['feature_file'],
            seq_label=item['seq_label'],
        )

        return item


class DiarizationLabel(_Collection):
    """List of diarization audio-label correspondence with preprocessing."""

    OUTPUT_TYPE = collections.namedtuple(
        typename='DiarizationLabelEntity',
        field_names='audio_file duration rttm_file offset target_spks sess_spk_dict clus_spk_digits rttm_spk_digits',
    )

    def __init__(
        self,
        audio_files: List[str],
        durations: List[float],
        rttm_files: List[str],
        offsets: List[float],
        target_spks_list: List[tuple],
        sess_spk_dicts: List[Dict],
        clus_spk_list: List[tuple],
        rttm_spk_list: List[tuple],
        max_number: Optional[int] = None,
        do_sort_by_duration: bool = False,
        index_by_file_id: bool = False,
    ):
        """Instantiates audio-label manifest with filters and preprocessing.

        Args:
            audio_files:
                List of audio file paths.
            durations:
                List of float durations.
            rttm_files:
                List of RTTM files (Groundtruth diarization annotation file).
            offsets:
                List of offsets or None.
            target_spks (tuple):
                List of tuples containing the two indices of targeted speakers for evaluation.
                Example: [[(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)], [(0, 1), (1, 2), (0, 2)], ...]
            sess_spk_dict (Dict):
                List of Mapping dictionaries between RTTM speakers and speaker labels in the clustering result.
            clus_spk_digits (tuple):
                List of Tuple containing all the speaker indices from the clustering result.
                Example: [(0, 1, 2, 3), (0, 1, 2), ...]
            rttm_spkr_digits (tuple):
                List of tuple containing all the speaker indices in the RTTM file.
                Example: (0, 1, 2), (0, 1), ...]
            max_number: Maximum number of samples to collect
            do_sort_by_duration: True if sort samples list by duration
            index_by_file_id: If True, saves a mapping from filename base (ID) to index in data.
        """

        if index_by_file_id:
            self.mapping = {}
        output_type = self.OUTPUT_TYPE
        data, duration_filtered = [], 0.0

        zipped_items = zip(
            audio_files, durations, rttm_files, offsets, target_spks_list, sess_spk_dicts, clus_spk_list, rttm_spk_list
        )
        for (
            audio_file,
            duration,
            rttm_file,
            offset,
            target_spks,
            sess_spk_dict,
            clus_spk_digits,
            rttm_spk_digits,
        ) in zipped_items:

            if duration is None:
                duration = 0

            data.append(
                output_type(
                    audio_file,
                    duration,
                    rttm_file,
                    offset,
                    target_spks,
                    sess_spk_dict,
                    clus_spk_digits,
                    rttm_spk_digits,
                )
            )

            if index_by_file_id:
                file_id, _ = os.path.splitext(os.path.basename(audio_file))
                self.mapping[file_id] = len(data) - 1

            # Max number of entities filter.
            if len(data) == max_number:
                break

        if do_sort_by_duration:
            if index_by_file_id:
                logging.warning("Tried to sort dataset by duration, but cannot since index_by_file_id is set.")
            else:
                data.sort(key=lambda entity: entity.duration)

        logging.info(
            "Filtered duration for loading collection is %f.",
            duration_filtered,
        )
        logging.info(f"Total {len(data)} session files loaded accounting to # {len(audio_files)} audio clips")

        super().__init__(data)


class DiarizationSpeechLabel(DiarizationLabel):
    """`DiarizationLabel` diarization data sample collector from structured json files."""

    def __init__(
        self,
        manifests_files: Union[str, List[str]],
        emb_dict: Dict,
        clus_label_dict: Dict,
        round_digits: int = 2,
        seq_eval_mode=False,
        pairwise_infer=False,
        *args,
        **kwargs,
    ):
        """
        Parse lists of audio files, durations, RTTM (Diarization annotation) files. Since the diarization
        model infers only two speakers, speaker pairs are generated from the total number of speakers in
        the session.

        Args:
            manifest_filepath (str):
                Path to input manifest JSON files.
            emb_dict (Dict):
                Dictionary containing cluster-average embeddings and speaker mapping information.
            clus_label_dict (Dict):
                Segment-level speaker labels from clustering results.
            round_digit (int):
                Number of digits to round.
            seq_eval_mode (bool):
                If True, F1 score will be calculated for each speaker pair during inference mode.
            pairwise_infer (bool):
                If True, this dataset class operates in inference mode. In inference mode, a set of
                speakers in the input audio is split into multiple pairs of speakers and speaker tuples
                (e.g., 3 speakers: [(0,1), (1,2), (0,2)]) and then fed into the diarization system to
                merge the individual results.
            *args: Args to pass to `SpeechLabel` constructor.
            **kwargs: Kwargs to pass to `SpeechLabel` constructor.
        """
        self.round_digits = round_digits
        self.emb_dict = emb_dict
        self.clus_label_dict = clus_label_dict
        self.seq_eval_mode = seq_eval_mode
        self.pairwise_infer = pairwise_infer
        audio_files, durations, rttm_files, offsets, target_spks_list, sess_spk_dicts, clus_spk_list, rttm_spk_list = (
            [],
            [],
            [],
            [],
            [],
            [],
            [],
            [],
        )

        for item in manifest.item_iter(manifests_files, parse_func=self.__parse_item_rttm):
            # Inference mode
            if self.pairwise_infer:
                clus_speaker_digits = sorted(list(set([x[2] for x in clus_label_dict[item['uniq_id']]])))
                if item['rttm_file']:
                    base_scale_index = max(self.emb_dict.keys())
                    _sess_spk_dict = self.emb_dict[base_scale_index][item['uniq_id']]['mapping']
                    sess_spk_dict = {int(v.split('_')[-1]): k for k, v in _sess_spk_dict.items()}
                    rttm_speaker_digits = [int(v.split('_')[1]) for k, v in _sess_spk_dict.items()]
                    if self.seq_eval_mode:
                        clus_speaker_digits = rttm_speaker_digits
                else:
                    sess_spk_dict = None
                    rttm_speaker_digits = None

            # Training mode
            else:
                rttm_labels = []
                with open(item['rttm_file'], 'r') as f:
                    for line in f.readlines():
                        start, end, speaker = self.split_rttm_line(line, decimals=3)
                        rttm_labels.append('{} {} {}'.format(start, end, speaker))
                speaker_set = set()
                for rttm_line in rttm_labels:
                    spk_str = rttm_line.split()[-1]
                    speaker_set.add(spk_str)
                speaker_list = sorted(list(speaker_set))
                sess_spk_dict = {key: val for key, val in enumerate(speaker_list)}
                target_spks = tuple(sess_spk_dict.keys())
                clus_speaker_digits = target_spks
                rttm_speaker_digits = target_spks

            if len(clus_speaker_digits) <= 2:
                spk_comb_list = [(0, 1)]
            else:
                spk_comb_list = [x for x in combinations(clus_speaker_digits, 2)]

            for target_spks in spk_comb_list:
                audio_files.append(item['audio_file'])
                durations.append(item['duration'])
                rttm_files.append(item['rttm_file'])
                offsets.append(item['offset'])
                target_spks_list.append(target_spks)
                sess_spk_dicts.append(sess_spk_dict)
                clus_spk_list.append(clus_speaker_digits)
                rttm_spk_list.append(rttm_speaker_digits)

        super().__init__(
            audio_files,
            durations,
            rttm_files,
            offsets,
            target_spks_list,
            sess_spk_dicts,
            clus_spk_list,
            rttm_spk_list,
            *args,
            **kwargs,
        )

    def split_rttm_line(self, rttm_line: str, decimals: int = 3):
        """
        Convert a line in RTTM file to speaker label, start and end timestamps.

        An example line of `rttm_line`:
            SPEAKER abc_dev_0123 1 146.903 1.860 <NA> <NA> speaker543 <NA> <NA>

        The above example RTTM line contains the following information:
            session name: abc_dev_0123
            segment start time: 146.903
            segment duration: 1.860
            speaker label: speaker543

        Args:
            rttm_line (str):
                A line in RTTM formatted file containing offset and duration of each segment.
            decimals (int):
                Number of digits to be rounded.

        Returns:
            start (float):
                Start timestamp in floating point number.
            end (float):
                End timestamp in floating point number.
            speaker (str):
                speaker string in RTTM lines.
        """
        rttm = rttm_line.strip().split()
        start = round(float(rttm[3]), decimals)
        end = round(float(rttm[4]), decimals) + round(float(rttm[3]), decimals)
        speaker = rttm[7]
        return start, end, speaker

    def __parse_item_rttm(self, line: str, manifest_file: str) -> Dict[str, Any]:
        """Parse each rttm file and save it to in Dict format"""
        item = json.loads(line)
        if 'audio_filename' in item:
            item['audio_file'] = item.pop('audio_filename')
        elif 'audio_filepath' in item:
            item['audio_file'] = item.pop('audio_filepath')
        else:
            raise ValueError(
                f"Manifest file has invalid json line " f"structure: {line} without proper audio file key."
            )
        item['audio_file'] = os.path.expanduser(item['audio_file'])
        item['uniq_id'] = os.path.splitext(os.path.basename(item['audio_file']))[0]
        if 'duration' not in item:
            raise ValueError(f"Manifest file has invalid json line " f"structure: {line} without proper duration key.")
        item = dict(
            audio_file=item['audio_file'],
            uniq_id=item['uniq_id'],
            duration=item['duration'],
            rttm_file=item['rttm_filepath'],
            offset=item.get('offset', None),
        )
        return item


class EndtoEndDiarizationLabel(_Collection):
    """List of end-to-end diarization audio-label correspondence with preprocessing."""

    OUTPUT_TYPE = collections.namedtuple(
        typename='DiarizationLabelEntity',
        field_names='audio_file uniq_id duration rttm_file offset',
    )

    def __init__(
        self,
        audio_files: List[str],
        uniq_ids: List[str],
        durations: List[float],
        rttm_files: List[str],
        offsets: List[float],
        max_number: Optional[int] = None,
        do_sort_by_duration: bool = False,
        index_by_file_id: bool = False,
    ):
        """
        Instantiates audio-label manifest with filters and preprocessing.

        This method initializes the EndtoEndDiarizationLabel object by processing the input data
        and applying optional filters and sorting.

        Args:
            audio_files (List[str]): List of audio file paths.
            uniq_ids (List[str]): List of unique identifiers for each audio file.
            durations (List[float]): List of float durations for each audio file.
            rttm_files (List[str]): List of RTTM path strings (Groundtruth diarization annotation file).
            offsets (List[float]): List of offsets or None for each audio file.
            max_number (Optional[int]): Maximum number of samples to collect. Defaults to None.
            do_sort_by_duration (bool): If True, sort samples list by duration. Defaults to False.
            index_by_file_id (bool): If True, saves a mapping from filename base (ID) to index in data.
                                     Defaults to False.

        """
        if index_by_file_id:
            self.mapping = {}
        output_type = self.OUTPUT_TYPE
        data, duration_filtered = [], 0.0

        zipped_items = zip(audio_files, uniq_ids, durations, rttm_files, offsets)
        for (
            audio_file,
            uniq_id,
            duration,
            rttm_file,
            offset,
        ) in zipped_items:

            if duration is None:
                duration = 0

            data.append(
                output_type(
                    audio_file,
                    uniq_id,
                    duration,
                    rttm_file,
                    offset,
                )
            )

            if index_by_file_id:
                if isinstance(audio_file, list):
                    if len(audio_file) == 0:
                        raise ValueError(f"Empty audio file list: {audio_file}")
                file_id, _ = os.path.splitext(os.path.basename(audio_file))
                self.mapping[file_id] = len(data) - 1

            # Max number of entities filter.
            if len(data) == max_number:
                break

        if do_sort_by_duration:
            if index_by_file_id:
                logging.warning("Tried to sort dataset by duration, but cannot since index_by_file_id is set.")
            else:
                data.sort(key=lambda entity: entity.duration)

        logging.info(
            "Filtered duration for loading collection is %f.",
            duration_filtered,
        )
        logging.info(f"Total {len(data)} session files loaded accounting to # {len(audio_files)} audio clips")

        super().__init__(data)


class EndtoEndDiarizationSpeechLabel(EndtoEndDiarizationLabel):
    """End-to-end speaker diarization data sample collector from structured json files."""

    def __init__(
        self,
        manifests_files: Union[str, List[str]],
        round_digits: int = 2,
        *args,
        **kwargs,
    ):
        """
        Parse lists of audio files, durations, RTTM (Diarization annotation) files.
        Since diarization model infers only two speakers, speaker pairs are generated
        from the total number of speakers in the session.

        Args:
            manifest_filepath (str):
                Path to input manifest json files.
            round_digit (int):
                Number of digits to be rounded.
            *args: Args to pass to `SpeechLabel` constructor.
            **kwargs: Kwargs to pass to `SpeechLabel` constructor.
        """
        self.round_digits = round_digits
        audio_files, uniq_ids, durations, rttm_files, offsets = (
            [],
            [],
            [],
            [],
            [],
        )

        for item in manifest.item_iter(manifests_files, parse_func=self.__parse_item_rttm):
            # Training mode
            audio_files.append(item['audio_file'])
            uniq_ids.append(item['uniq_id'])
            durations.append(item['duration'])
            rttm_files.append(item['rttm_file'])
            offsets.append(item['offset'])

        super().__init__(
            audio_files,
            uniq_ids,
            durations,
            rttm_files,
            offsets,
            *args,
            **kwargs,
        )

    def __parse_item_rttm(self, line: str, manifest_file: str) -> Dict[str, Any]:
        """Parse each rttm file and save it to in Dict format"""
        item = json.loads(line)

        if 'offset' not in item or item['offset'] is None:
            item['offset'] = 0

        # If the name `audio_file` is not present in the manifest file, replace it.
        if 'audio_file' in item:
            pass
        if 'audio_filename' in item:
            item['audio_file'] = item.pop('audio_filename')
        elif 'audio_filepath' in item:
            item['audio_file'] = item.pop('audio_filepath')
        else:
            raise ValueError(
                f"Manifest file has invalid json line " f"structure: {line} without proper audio file key."
            )

        # Audio file handling depending on the types
        if isinstance(item['audio_file'], list):
            audio_file_list = []
            for single_audio_file in item['audio_file']:
                audio_file_list.append(get_full_path(audio_file=single_audio_file, manifest_file=manifest_file))
            item['audio_file'] = audio_file_list
        elif isinstance(item['audio_file'], str):
            item['audio_file'] = get_full_path(audio_file=item['audio_file'], manifest_file=manifest_file)
            if not os.path.exists(item['audio_file']):
                raise FileNotFoundError(f"Audio file not found: {item['audio_file']}")
        else:
            raise ValueError(
                f"Manifest file has invalid json line "
                f"structure: {line} without proper audio file value: {item['audio_file']}."
            )

        # If the name `rttm_file` is not present in the manifest file, replace it or assign None.
        if 'rttm_file' in item:
            pass
        elif 'rttm_filename' in item:
            item['rttm_file'] = item.pop('rttm_filename')
        elif 'rttm_filepath' in item:
            item['rttm_file'] = item.pop('rttm_filepath')
        else:
            item['rttm_file'] = None

        # If item['rttm_file'] is not None and the RTTM file exists, get the full path
        if item['rttm_file'] is not None:
            item['rttm_file'] = get_full_path(audio_file=item['rttm_file'], manifest_file=manifest_file)
            if not os.path.exists(item['rttm_file']):
                raise FileNotFoundError(f"RTTM file not found: {item['rttm_file']}")

        # Handling `uniq_id` string
        if 'uniq_id' not in item:
            item['uniq_id'] = os.path.splitext(os.path.basename(item['audio_file']))[0]

        if not isinstance(item['uniq_id'], str):
            raise ValueError(f"Manifest file has invalid json line " f"structure: {line} without proper uniq_id key.")

        if 'duration' not in item:
            raise ValueError(f"Manifest file has invalid json line " f"structure: {line} without proper duration key.")

        item = dict(
            audio_file=item['audio_file'],
            uniq_id=item['uniq_id'],
            duration=item['duration'],
            rttm_file=item['rttm_file'],
            offset=item.get('offset', 0),
        )
        return item


class Audio(_Collection):
    """Prepare a list of all audio items, filtered by duration."""

    OUTPUT_TYPE = collections.namedtuple(typename='Audio', field_names='audio_files duration offset text')

    def __init__(
        self,
        audio_files_list: List[Dict[str, str]],
        duration_list: List[float],
        offset_list: List[float],
        text_list: List[str],
        min_duration: Optional[float] = None,
        max_duration: Optional[float] = None,
        max_number: Optional[int] = None,
        do_sort_by_duration: bool = False,
    ):
        """Instantiantes an list of audio files.

        Args:
            audio_files_list: list of dictionaries with mapping from audio_key to audio_filepath
            duration_list: list of durations of input files
            offset_list: list of offsets
            text_list: list of texts
            min_duration: Minimum duration to keep entry with (default: None).
            max_duration: Maximum duration to keep entry with (default: None).
            max_number: Maximum number of samples to collect.
            do_sort_by_duration: True if sort samples list by duration.
        """

        output_type = self.OUTPUT_TYPE
        data, total_duration = [], 0.0
        num_filtered, duration_filtered = 0, 0.0

        for audio_files, duration, offset, text in zip(audio_files_list, duration_list, offset_list, text_list):
            # Duration filters
            if min_duration is not None and duration < min_duration:
                duration_filtered += duration
                num_filtered += 1
                continue

            if max_duration is not None and duration > max_duration:
                duration_filtered += duration
                num_filtered += 1
                continue

            total_duration += duration
            data.append(output_type(audio_files, duration, offset, text))

            # Max number of entities filter
            if len(data) == max_number:
                break

        if do_sort_by_duration:
            data.sort(key=lambda entity: entity.duration)

        logging.info("Dataset loaded with %d files totalling %.2f hours", len(data), total_duration / 3600)
        logging.info("%d files were filtered totalling %.2f hours", num_filtered, duration_filtered / 3600)

        super().__init__(data)


class AudioCollection(Audio):
    """List of audio files from a manifest file."""

    def __init__(
        self,
        manifest_files: Union[str, List[str]],
        audio_to_manifest_key: Dict[str, str],
        *args,
        **kwargs,
    ):
        """Instantiates a list of audio files loaded from a manifest file.

        Args:
            manifest_files: path to a single manifest file or a list of paths
            audio_to_manifest_key: dictionary mapping audio signals to keys of the manifest
        """
        # Support for comma-separated manifests
        if type(manifest_files) == str:
            manifest_files = manifest_files.split(',')

        for audio_key, manifest_key in audio_to_manifest_key.items():
            # Support for comma-separated keys
            if type(manifest_key) == str and ',' in manifest_key:
                audio_to_manifest_key[audio_key] = manifest_key.split(',')

        # Keys from manifest which contain audio
        self.audio_to_manifest_key = audio_to_manifest_key

        # Initialize data
        audio_files_list, duration_list, offset_list, text_list = [], [], [], []

        # Parse manifest files
        for item in manifest.item_iter(manifest_files, parse_func=self.__parse_item):
            audio_files_list.append(item['audio_files'])
            duration_list.append(item['duration'])
            offset_list.append(item['offset'])
            text_list.append(item['text'])

        super().__init__(audio_files_list, duration_list, offset_list, text_list, *args, **kwargs)

    def __parse_item(self, line: str, manifest_file: str) -> Dict[str, Any]:
        """Parse a single line from a manifest file.

        Args:
            line: a string representing a line from a manifest file in JSON format
            manifest_file: path to the manifest file. Used to resolve relative paths.

        Returns:
            Dictionary with audio_files, duration, and offset.
        """

        # Local utility function
        def get_audio_file(item: Dict, manifest_key: Union[str, List[str]]):
            """Get item[key] if key is string, or a list
            of strings by combining item[key[0]], item[key[1]], etc.
            """
            # Prepare audio file(s)
            if manifest_key is None:
                # Support for inference, when a target key is None
                audio_file = None
            elif isinstance(manifest_key, str):
                # Load files from a single manifest key
                audio_file = item[manifest_key]
            elif isinstance(manifest_key, Iterable):
                # Load files from multiple manifest keys
                audio_file = []
                for key in manifest_key:
                    item_key = item[key]
                    if isinstance(item_key, str):
                        audio_file.append(item_key)
                    elif isinstance(item_key, list):
                        audio_file += item_key
                    else:
                        raise ValueError(f'Unexpected type {type(item_key)} of item for key {key}: {item_key}')
            else:
                raise ValueError(f'Unexpected type {type(manifest_key)} of manifest_key: {manifest_key}')

            return audio_file

        # Convert JSON line to a dictionary
        item = json.loads(line)

        # Handle all audio files
        audio_files = {}
        for audio_key, manifest_key in self.audio_to_manifest_key.items():

            audio_file = get_audio_file(item, manifest_key)

            # Get full path to audio file(s)
            if isinstance(audio_file, str):
                # This dictionary entry points to a single file
                audio_files[audio_key] = manifest.get_full_path(audio_file, manifest_file)
            elif isinstance(audio_file, Iterable):
                # This dictionary entry points to multiple files
                # Get the files and keep the list structure for this key
                audio_files[audio_key] = [manifest.get_full_path(f, manifest_file) for f in audio_file]
            elif audio_file is None and audio_key.startswith('target'):
                # For inference, we don't need the target
                audio_files[audio_key] = None
            else:
                raise ValueError(f'Unexpected type {type(audio_file)} of audio_file: {audio_file}')
        item['audio_files'] = audio_files

        # Handle duration
        if 'duration' not in item:
            raise ValueError(f'Duration not available in line: {line}. Manifest file: {manifest_file}')

        # Handle offset
        if 'offset' not in item:
            item['offset'] = 0.0

        # Handle text
        if 'text' not in item:
            item['text'] = None

        return dict(
            audio_files=item['audio_files'], duration=item['duration'], offset=item['offset'], text=item['text']
        )


class FeatureLabel(_Collection):
    """List of feature sequence and their label correspondence with preprocessing."""

    OUTPUT_TYPE = collections.namedtuple(
        typename='FeatureLabelEntity',
        field_names='feature_file label duration',
    )

    def __init__(
        self,
        feature_files: List[str],
        labels: List[str],
        durations: List[float],
        min_duration: Optional[float] = None,
        max_duration: Optional[float] = None,
        max_number: Optional[int] = None,
        do_sort_by_duration: bool = False,
        index_by_file_id: bool = False,
    ):
        """Instantiates feature-SequenceLabel manifest with filters and preprocessing.

        Args:
            feature_files: List of feature files.
            labels: List of labels.
            max_number: Maximum number of samples to collect.
            index_by_file_id: If True, saves a mapping from filename base (ID) to index in data.
        """

        output_type = self.OUTPUT_TYPE
        data = []
        duration_filtered = 0.0
        total_duration = 0.0
        self.uniq_labels = set()

        if index_by_file_id:
            self.mapping = {}

        for feature_file, label, duration in zip(feature_files, labels, durations):
            # Duration filters.
            if min_duration is not None and duration < min_duration:
                duration_filtered += duration
                continue

            if max_duration is not None and duration > max_duration:
                duration_filtered += duration
                continue

            data.append(output_type(feature_file, label, duration))
            self.uniq_labels |= set(label)
            total_duration += duration

            if index_by_file_id:
                file_id, _ = os.path.splitext(os.path.basename(feature_file))
                self.mapping[file_id] = len(data) - 1

            # Max number of entities filter.
            if len(data) == max_number:
                break

        if do_sort_by_duration:
            if index_by_file_id:
                logging.warning("Tried to sort dataset by duration, but cannot since index_by_file_id is set.")
            else:
                data.sort(key=lambda entity: entity.duration)

        logging.info(f"Filtered duration for loading collection is {duration_filtered / 2600:.2f} hours.")
        logging.info(f"Dataset loaded with {len(data)} items, total duration of {total_duration / 3600: .2f} hours.")
        logging.info("# {} files loaded including # {} unique labels".format(len(data), len(self.uniq_labels)))
        super().__init__(data)


class ASRFeatureLabel(FeatureLabel):
    """`FeatureLabel` collector from asr structured json files."""

    def __init__(
        self,
        manifests_files: Union[str, List[str]],
        is_regression_task: bool = False,
        cal_labels_occurrence: bool = False,
        delimiter: Optional[str] = None,
        *args,
        **kwargs,
    ):
        """Parse lists of feature files and sequences of labels.

        Args:
            manifests_files: Either single string file or list of such -
                manifests to yield items from.
            max_number:  Maximum number of samples to collect; pass to `FeatureSequenceLabel` constructor.
            index_by_file_id: If True, saves a mapping from filename base (ID) to index in data;
                              pass to `FeatureSequenceLabel` constructor.
        """

        feature_files, labels, durations = [], [], []
        all_labels = []
        for item in manifest.item_iter(manifests_files, parse_func=self._parse_item):
            feature_files.append(item['feature_file'])
            durations.append(item['duration'])

            if not is_regression_task:
                label = item['label']
                label_list = label.split() if not delimiter else label.split(delimiter)
            else:
                label = float(item['label'])
                label_list = [label]

            labels.append(label)
            all_labels.extend(label_list)
        if cal_labels_occurrence:
            self.labels_occurrence = collections.Counter(all_labels)

        super().__init__(feature_files, labels, durations, *args, **kwargs)

    def _parse_item(self, line: str, manifest_file: str) -> Dict[str, Any]:
        item = json.loads(line)

        # Feature file
        if 'feature_filename' in item:
            item['feature_file'] = item.pop('feature_filename')
        elif 'feature_filepath' in item:
            item['feature_file'] = item.pop('feature_filepath')
        elif 'feature_file' not in item:
            raise ValueError(
                f"Manifest file has invalid json line " f"structure: {line} without proper 'feature_file' key."
            )
        item['feature_file'] = manifest.get_full_path(audio_file=item['feature_file'], manifest_file=manifest_file)

        # Label.
        if 'label' in item:
            item['label'] = item.pop('label')
        else:
            raise ValueError(f"Manifest file has invalid json line structure: {line} without proper 'label' key.")

        item = dict(feature_file=item['feature_file'], label=item['label'], duration=item['duration'])

        return item


class FeatureText(_Collection):
    """List of audio-transcript text correspondence with preprocessing."""

    OUTPUT_TYPE = collections.namedtuple(
        typename='FeatureTextEntity',
        field_names='id feature_file rttm_file duration text_tokens offset text_raw speaker orig_sr lang',
    )

    def __init__(
        self,
        ids: List[int],
        feature_files: List[str],
        rttm_files: List[str],
        durations: List[float],
        texts: List[str],
        offsets: List[str],
        speakers: List[Optional[int]],
        orig_sampling_rates: List[Optional[int]],
        token_labels: List[Optional[int]],
        langs: List[Optional[str]],
        parser: parsers.CharParser,
        min_duration: Optional[float] = None,
        max_duration: Optional[float] = None,
        max_number: Optional[int] = None,
        do_sort_by_duration: bool = False,
        index_by_file_id: bool = False,
    ):
        """Instantiates feature-text manifest with filters and preprocessing.

        Args:
            ids: List of examples positions.
            feature_files: List of audio feature files.
            rttm_files: List of audio rttm files.
            durations: List of float durations.
            texts: List of raw text transcripts.
            offsets: List of duration offsets or None.
            speakers: List of optional speakers ids.
            orig_sampling_rates: List of original sampling rates of audio files.
            langs: List of language ids, one for eadh sample, or None.
            parser: Instance of `CharParser` to convert string to tokens.
            min_duration: Minimum duration to keep entry with (default: None).
            max_duration: Maximum duration to keep entry with (default: None).
            max_number: Maximum number of samples to collect.
            do_sort_by_duration: True if sort samples list by duration. Not compatible with index_by_file_id.
            index_by_file_id: If True, saves a mapping from filename base (ID) to index in data.
        """

        output_type = self.OUTPUT_TYPE
        data, duration_filtered, num_filtered, total_duration = [], 0.0, 0, 0.0
        if index_by_file_id:
            self.mapping = {}

        for id_, feat_file, rttm_file, duration, offset, text, speaker, orig_sr, token_labels, lang in zip(
            ids,
            feature_files,
            rttm_files,
            durations,
            offsets,
            texts,
            speakers,
            orig_sampling_rates,
            token_labels,
            langs,
        ):
            # Duration filters.
            if min_duration is not None and duration < min_duration:
                duration_filtered += duration
                num_filtered += 1
                continue

            if max_duration is not None and duration > max_duration:
                duration_filtered += duration
                num_filtered += 1
                continue

            if token_labels is not None:
                text_tokens = token_labels
            else:
                if text != '':
                    if hasattr(parser, "is_aggregate") and parser.is_aggregate and isinstance(text, str):
                        if lang is not None:
                            text_tokens = parser(text, lang)
                        else:
                            raise ValueError("lang required in manifest when using aggregate tokenizers")
                    else:
                        text_tokens = parser(text)
                else:
                    text_tokens = []

                if text_tokens is None:
                    duration_filtered += duration
                    num_filtered += 1
                    continue

            total_duration += duration

            data.append(
                output_type(id_, feat_file, rttm_file, duration, text_tokens, offset, text, speaker, orig_sr, lang)
            )
            if index_by_file_id:
                file_id, _ = os.path.splitext(os.path.basename(feat_file))
                if file_id not in self.mapping:
                    self.mapping[file_id] = []
                self.mapping[file_id].append(len(data) - 1)

            # Max number of entities filter.
            if len(data) == max_number:
                break

        if do_sort_by_duration:
            if index_by_file_id:
                logging.warning("Tried to sort dataset by duration, but cannot since index_by_file_id is set.")
            else:
                data.sort(key=lambda entity: entity.duration)

        logging.info("Dataset loaded with %d files totalling %.2f hours", len(data), total_duration / 3600)
        logging.info("%d files were filtered totalling %.2f hours", num_filtered, duration_filtered / 3600)

        super().__init__(data)


class ASRFeatureText(FeatureText):
    """`FeatureText` collector from asr structured json files."""

    def __init__(self, manifests_files: Union[str, List[str]], *args, **kwargs):
        """Parse lists of audio files, durations and transcripts texts.

        Args:
            manifests_files: Either single string file or list of such -
                manifests to yield items from.
            *args: Args to pass to `AudioText` constructor.
            **kwargs: Kwargs to pass to `AudioText` constructor.
        """

        (
            ids,
            feature_files,
            rttm_files,
            durations,
            texts,
            offsets,
        ) = (
            [],
            [],
            [],
            [],
            [],
            [],
        )
        speakers, orig_srs, token_labels, langs = [], [], [], []
        for item in manifest.item_iter(manifests_files):
            ids.append(item['id'])
            feature_files.append(item['feature_file'])
            rttm_files.append(item['rttm_file'])
            durations.append(item['duration'])
            texts.append(item['text'])
            offsets.append(item['offset'])
            speakers.append(item['speaker'])
            orig_srs.append(item['orig_sr'])
            token_labels.append(item['token_labels'])
            langs.append(item['lang'])

        super().__init__(
            ids,
            feature_files,
            rttm_files,
            durations,
            texts,
            offsets,
            speakers,
            orig_srs,
            token_labels,
            langs,
            *args,
            **kwargs,
        )
