# ! /usr/bin/python
# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from math import ceil
from typing import Any, Dict, List, Optional, Union

import torch
from omegaconf import DictConfig, OmegaConf, open_dict
from torch.utils.data import DataLoader

from nemo.collections.asr.data import audio_to_text_dataset
from nemo.collections.asr.data.audio_to_text_dali import DALIOutputs
from nemo.collections.asr.metrics.wer import WER
from nemo.collections.asr.models.asr_model import ASRModel, ExportableEncDecModel
from nemo.collections.asr.parts.mixins import (
    ASRBPEMixin,
    ASRModuleMixin,
    ASRTranscriptionMixin,
    TranscribeConfig,
    TranscriptionReturnType,
)
from nemo.collections.asr.parts.preprocessing.perturb import process_augmentations
from nemo.collections.asr.parts.submodules.ctc_decoding import CTCBPEDecoding, CTCBPEDecodingConfig
from nemo.collections.asr.parts.utils.slu_utils import SequenceGenerator, SequenceGeneratorConfig, get_seq_mask
from nemo.collections.common.losses import SmoothedNLLLoss
from nemo.core.classes.common import PretrainedModelInfo, typecheck
from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, LogprobsType, NeuralType, SpectrogramType
from nemo.utils import logging, model_utils

__all__ = ["SLUIntentSlotBPEModel"]


class SLUIntentSlotBPEModel(ASRModel, ExportableEncDecModel, ASRModuleMixin, ASRBPEMixin, ASRTranscriptionMixin):
    """Model for end-to-end speech intent classification and slot filling, which is formulated as a speech-to-sequence task"""

    def __init__(self, cfg: DictConfig, trainer=None):
        """Assemble the tokenizer, audio encoder, token decoder, loss, sequence generator and metric.

        Args:
            cfg: Model configuration. It must contain a ``tokenizer`` section; the ``preprocessor``,
                ``encoder``, ``decoder``, ``embedding``, ``classifier``, ``loss`` and
                ``sequence_generator`` sections are read below. ``spec_augment`` and ``decoding``
                are optional.
            trainer: Optional trainer object, forwarded unchanged to the parent constructor.

        Raises:
            ValueError: If ``cfg`` has no ``tokenizer`` entry.
        """
        # Normalise the incoming config into a Hydra 1.0 compatible DictConfig at the current version.
        cfg = model_utils.maybe_update_config_version(model_utils.convert_model_config_to_dict_config(cfg))

        if 'tokenizer' not in cfg:
            raise ValueError("`cfg` must have `tokenizer` config to create a tokenizer !")

        # The tokenizer is set up before the parent constructor is invoked.
        self._setup_tokenizer(cfg.tokenizer)

        super().__init__(cfg=cfg, trainer=trainer)

        # Audio front-end, acoustic encoder and sequence decoder are all instantiated from config.
        self.preprocessor = self.from_config_dict(self.cfg.preprocessor)
        self.encoder = self.from_config_dict(self.cfg.encoder)
        self.decoder = self.from_config_dict(self.cfg.decoder)

        # Spectrogram augmentation is optional; it stays None when not configured.
        has_spec_augment = hasattr(self._cfg, 'spec_augment') and self._cfg.spec_augment is not None
        self.spec_augmentation = self.from_config_dict(self._cfg.spec_augment) if has_spec_augment else None

        # Optional optimization flags, then adapter modules setup.
        self.setup_optimization_flags()
        self.setup_adapters()

        self.vocabulary = self.tokenizer.tokenizer.get_vocab()
        vocab_size = len(self.vocabulary)

        # Embedding layer and token classifier are both sized by the tokenizer vocabulary.
        self.cfg.embedding["vocab_size"] = vocab_size
        self.embedding = self.from_config_dict(self.cfg.embedding)

        self.cfg.classifier["num_classes"] = vocab_size
        self.classifier = self.from_config_dict(self.cfg.classifier)

        self.loss = SmoothedNLLLoss(label_smoothing=self.cfg.loss.label_smoothing)

        # Generator used to produce token sequences when no target sequence is supplied.
        self.sequence_generator = SequenceGenerator(
            cfg=self.cfg.sequence_generator,
            embedding=self.embedding,
            decoder=self.decoder,
            log_softmax=self.classifier,
            tokenizer=self.tokenizer,
        )

        # Fall back to the default decoding config when the model config does not provide one.
        if self.cfg.get('decoding', None) is None:
            with open_dict(self.cfg):
                self.cfg.decoding = OmegaConf.structured(CTCBPEDecodingConfig)

        self.decoding = CTCBPEDecoding(self.cfg.decoding, tokenizer=self.tokenizer)

        # Metric that relies on the decoding object created above.
        self.wer = WER(
            decoding=self.decoding,
            use_cer=self._cfg.get('use_cer', False),
            dist_sync_on_step=True,
            log_prediction=self._cfg.get("log_prediction", False),
            fold_consecutive=False,
        )

    @property
    def input_types(self) -> Optional[Dict[str, NeuralType]]:
        """Describe the neural types of the model inputs; every entry is optional.

        The audio element type carries the preprocessor's sample rate when the preprocessor
        exposes a ``_sample_rate`` attribute, and is a plain ``AudioSignal`` otherwise.

        Returns:
            A mapping from input name to its ``NeuralType``.
        """
        if hasattr(self.preprocessor, '_sample_rate'):
            input_signal_eltype = AudioSignal(freq=self.preprocessor._sample_rate)
        else:
            input_signal_eltype = AudioSignal()
        return {
            "input_signal": NeuralType(('B', 'T'), input_signal_eltype, optional=True),
            "input_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True),
            # Target semantics are token ids, not audio samples, so they are typed as labels
            # (the same element type used for ``greedy_predictions`` in ``output_types``).
            "target_semantics": NeuralType(('B', 'T'), LabelsType(), optional=True),
            "target_semantics_length": NeuralType(tuple('B'), LengthsType(), optional=True),
            "processed_signal": NeuralType(('B', 'D', 'T'), SpectrogramType(), optional=True),
            "processed_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True),
            "sample_id": NeuralType(tuple('B'), LengthsType(), optional=True),
        }

    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        """Describe the neural types of the three values returned by ``forward``; all are optional.

        Returns:
            A mapping with the keys ``log_probs``, ``lengths`` and ``greedy_predictions``.
        """
        log_probs_type = NeuralType(('B', 'T', 'D'), LogprobsType(), optional=True)
        lengths_type = NeuralType(tuple('B'), LengthsType(), optional=True)
        predictions_type = NeuralType(('B', 'T'), LabelsType(), optional=True)
        return dict(
            log_probs=log_probs_type,
            lengths=lengths_type,
            greedy_predictions=predictions_type,
        )

    def set_decoding_strategy(self, cfg: SequenceGeneratorConfig):
        """Replace the sequence generator with one built from ``cfg``.

        The maximum sequence length of the current generator is written into ``cfg`` (the passed
        object is modified in place) before the new generator is created.

        Args:
            cfg: Configuration for the new sequence generator.
        """
        current_max_len = self.sequence_generator.generator.max_seq_length
        cfg.max_sequence_length = current_max_len
        self.sequence_generator = SequenceGenerator(
            cfg=cfg,
            embedding=self.embedding,
            decoder=self.decoder,
            log_softmax=self.classifier,
            tokenizer=self.tokenizer,
        )

    @typecheck()
    def forward(
        self,
        input_signal=None,
        input_signal_length=None,
        target_semantics=None,
        target_semantics_length=None,
        processed_signal=None,
        processed_signal_length=None,
    ):
        """
        Run the encoder and, when targets are given, the teacher-forced decoder.

        Exactly one of the pairs (``input_signal``, ``input_signal_length``) and
        (``processed_signal``, ``processed_signal_length``) must be supplied.

        Params:
            input_signal: Batch of raw audio signals, of shape [B, T].

            input_signal_length: Vector of length B with the individual lengths of the audio sequences.

            target_semantics: Batch of semantic tokens, of shape [B, L]. When omitted the model runs in
            inference-only mode and generates tokens with the sequence generator.

            target_semantics_length: Vector of length B with the individual lengths of the semantic sequences.

            processed_signal: Batch of already processed audio signals, of shape (B, D, T).

            processed_signal_length: Vector of length B with the individual lengths of the processed audio
            sequences.

        Returns:
            A tuple of 3 elements -
            1) The log probabilities tensor of shape [B, T, D] (None in inference-only mode).
            2) The lengths of the predicted sequences, of shape [B] (None in inference-only mode).
            3) The token predictions of the model of shape [B, T].

        Raises:
            ValueError: If both or neither of the raw / processed signal pairs are provided.
        """
        has_input_signal = input_signal is not None and input_signal_length is not None
        has_processed_signal = processed_signal is not None and processed_signal_length is not None
        # Both flags equal means either both pairs or no pair was supplied.
        if has_input_signal == has_processed_signal:
            raise ValueError(
                f"{self} Arguments ``input_signal`` and ``input_signal_length`` are mutually exclusive "
                " with ``processed_signal`` and ``processed_signal_len`` arguments."
            )

        if not has_processed_signal:
            processed_signal, processed_signal_length = self.preprocessor(
                input_signal=input_signal,
                length=input_signal_length,
            )

        # Augmentation is only applied while the module is in training mode.
        if self.training and self.spec_augmentation is not None:
            processed_signal = self.spec_augmentation(input_spec=processed_signal, length=processed_signal_length)

        encoder_out, encoder_out_len = self.encoder(audio_signal=processed_signal, length=processed_signal_length)
        encoder_states = encoder_out.transpose(1, 2)  # BxDxT -> BxTxD
        encoder_mask = get_seq_mask(encoder_states, encoder_out_len)

        if target_semantics is None:  # in inference-only mode
            return None, None, self.sequence_generator(encoder_states, encoder_mask)

        # Decoder input is the target sequence without its last token, hence the length minus one.
        decoder_input_tokens = target_semantics[:, :-1]
        decoder_input = self.embedding(decoder_input_tokens)
        decoder_input_mask = get_seq_mask(decoder_input, target_semantics_length - 1)

        decoder_out = self.decoder(
            encoder_states=encoder_states,
            encoder_mask=encoder_mask,
            decoder_states=decoder_input,
            decoder_mask=decoder_input_mask,
        )
        log_probs = self.classifier(decoder_out)

        predictions = log_probs.argmax(dim=-1, keepdim=False)
        pred_len = self.sequence_generator.get_seq_length(predictions)
        return log_probs, pred_len, predictions

    # PTL-specific methods
    def training_step(self, batch, batch_nb):
        """Compute the training loss for one batch and collect values to log.

        Args:
            batch: Sequence of 4 elements (signal, signal length, semantics, semantics length) or
                5 elements (the same plus a sample id, which is ignored here). A ``DALIOutputs``
                batch that already holds a processed signal is routed to the processed-signal
                inputs of ``forward``.
            batch_nb: Index of the batch; used to decide when the training WER is computed.

        Returns:
            A dict with ``loss`` (the loss tensor) and ``log`` (train loss, learning rate(s) and,
            every ``log_every_n_steps`` batches, the batch WER).
        """
        if len(batch) == 4:
            signal, signal_len, semantics, semantics_len = batch
        else:
            signal, signal_len, semantics, semantics_len, _ = batch

        # Same dispatch as validation_pass: preprocessed DALI features must not be sent through
        # the raw-audio path.
        if isinstance(batch, DALIOutputs) and batch.has_processed_signal:
            log_probs, pred_len, predictions = self.forward(
                processed_signal=signal,
                processed_signal_length=signal_len,
                target_semantics=semantics,
                target_semantics_length=semantics_len,
            )
        else:
            log_probs, pred_len, predictions = self.forward(
                input_signal=signal,
                input_signal_length=signal_len,
                target_semantics=semantics,
                target_semantics_length=semantics_len,
            )

        # Labels are the target sequence without its first token, so they are one token shorter.
        eos_semantics = semantics[:, 1:]
        eos_semantics_len = semantics_len - 1

        loss_value = self.loss(log_probs=log_probs, labels=eos_semantics, lengths=eos_semantics_len)

        tensorboard_logs = {'train_loss': loss_value.item()}
        if len(self._optimizer.param_groups) == 1:
            tensorboard_logs['learning_rate'] = self._optimizer.param_groups[0]['lr']
        else:
            for i, group in enumerate(self._optimizer.param_groups):
                tensorboard_logs[f'learning_rate_g{i}'] = group['lr']

        if hasattr(self, '_trainer') and self._trainer is not None:
            log_every_n_steps = self._trainer.log_every_n_steps
        else:
            log_every_n_steps = 1

        # The WER is only computed on logging steps.
        if (batch_nb + 1) % log_every_n_steps == 0:
            self.wer.update(
                predictions=predictions,
                targets=eos_semantics,
                predictions_lengths=pred_len,
                targets_lengths=eos_semantics_len,
            )
            wer, _, _ = self.wer.compute()
            self.wer.reset()
            tensorboard_logs.update({'training_batch_wer': wer})

        return {'loss': loss_value, 'log': tensorboard_logs}

    def predict(
        self, input_signal, input_signal_length, processed_signal=None, processed_signal_length=None, dataloader_idx=0
    ) -> List[str]:
        """Encode the audio, generate tokens with the sequence generator and decode them to semantics.

        Exactly one of the pairs (``input_signal``, ``input_signal_length``) and
        (``processed_signal``, ``processed_signal_length``) must be non-None.

        Args:
            input_signal: Batch of raw audio signals, or None when a processed signal is given.
            input_signal_length: Lengths of the raw audio signals, or None.
            processed_signal: Batch of already processed audio signals, or None.
            processed_signal_length: Lengths of the processed signals, or None.
            dataloader_idx: Not used by this method.

        Returns:
            The value produced by ``sequence_generator.decode_semantics_from_tokens`` for the
            generated tokens.

        Raises:
            ValueError: If both or neither of the raw / processed signal pairs are provided.
        """
        has_input_signal = input_signal is not None and input_signal_length is not None
        has_processed_signal = processed_signal is not None and processed_signal_length is not None
        # Both flags equal means either both pairs or no pair was supplied.
        if has_input_signal == has_processed_signal:
            raise ValueError(
                f"{self} Arguments ``input_signal`` and ``input_signal_length`` are mutually exclusive "
                " with ``processed_signal`` and ``processed_signal_len`` arguments."
            )

        if not has_processed_signal:
            processed_signal, processed_signal_length = self.preprocessor(
                input_signal=input_signal,
                length=input_signal_length,
            )

        # Augmentation is only applied while the module is in training mode.
        if self.training and self.spec_augmentation is not None:
            processed_signal = self.spec_augmentation(input_spec=processed_signal, length=processed_signal_length)

        encoder_out, encoder_out_len = self.encoder(audio_signal=processed_signal, length=processed_signal_length)
        encoder_states = encoder_out.transpose(1, 2)  # BxDxT -> BxTxD
        encoder_mask = get_seq_mask(encoder_states, encoder_out_len)

        generated_tokens = self.sequence_generator(encoder_states, encoder_mask)
        return self.sequence_generator.decode_semantics_from_tokens(generated_tokens)

    def validation_pass(self, batch, batch_idx, dataloader_idx=0):
        """Run a teacher-forced forward pass on one batch and compute loss and WER statistics.

        Args:
            batch: Sequence of 4 elements (signal, signal length, semantics, semantics length) or
                5 elements (the same plus a sample id). A ``DALIOutputs`` batch with a processed
                signal is fed to the processed-signal inputs of ``forward``.
            batch_idx: Not used by this method.
            dataloader_idx: Not used by this method.

        Returns:
            A dict with the keys ``val_loss``, ``val_wer_num``, ``val_wer_denom`` and ``val_wer``.
        """
        if len(batch) == 4:
            signal, signal_len, semantics, semantics_len = batch
        else:
            signal, signal_len, semantics, semantics_len, sample_id = batch

        # Pick the pair of forward() arguments that matches the kind of signal in the batch.
        if isinstance(batch, DALIOutputs) and batch.has_processed_signal:
            signal_kwargs = {'processed_signal': signal, 'processed_signal_length': signal_len}
        else:
            signal_kwargs = {'input_signal': signal, 'input_signal_length': signal_len}

        log_probs, pred_len, predictions = self.forward(
            target_semantics=semantics,
            target_semantics_length=semantics_len,
            **signal_kwargs,
        )

        # Labels are the target sequence without its first token, so they are one token shorter.
        labels = semantics[:, 1:]
        labels_len = semantics_len - 1

        loss_value = self.loss(log_probs=log_probs, labels=labels, lengths=labels_len)

        self.wer.update(
            predictions=predictions,
            targets=labels,
            predictions_lengths=pred_len,
            targets_lengths=labels_len,
        )
        wer, wer_num, wer_denom = self.wer.compute()
        self.wer.reset()

        return {
            'val_loss': loss_value,
            'val_wer_num': wer_num,
            'val_wer_denom': wer_denom,
            'val_wer': wer,
        }

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """Run ``validation_pass`` on a batch and record its metrics.

        Args:
            batch: Batch forwarded to ``validation_pass``.
            batch_idx: Index of the batch, forwarded to ``validation_pass``.
            dataloader_idx: Index of the dataloader the batch came from; selects the output list
                when more than one validation dataloader is in use.

        Returns:
            The metrics dict produced by ``validation_pass``.
        """
        metrics = self.validation_pass(batch, batch_idx, dataloader_idx)
        val_dataloaders = self.trainer.val_dataloaders
        # isinstance (instead of an exact type comparison) also accepts list subclasses.
        if isinstance(val_dataloaders, list) and len(val_dataloaders) > 1:
            self.validation_step_outputs[dataloader_idx].append(metrics)
        else:
            self.validation_step_outputs.append(metrics)
        return metrics

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        """Run ``validation_pass`` on a test batch and record its metrics under ``test_`` keys.

        Args:
            batch: Batch forwarded to ``validation_pass``.
            batch_idx: Index of the batch, forwarded to ``validation_pass``.
            dataloader_idx: Index of the dataloader the batch came from; selects the output list
                when more than one test dataloader is in use.

        Returns:
            The metrics dict of ``validation_pass`` with ``val_`` replaced by ``test_`` in every key.
        """
        logs = self.validation_pass(batch, batch_idx, dataloader_idx=dataloader_idx)
        test_logs = {name.replace("val_", "test_"): value for name, value in logs.items()}
        test_dataloaders = self.trainer.test_dataloaders
        # isinstance (instead of an exact type comparison) also accepts list subclasses.
        if isinstance(test_dataloaders, list) and len(test_dataloaders) > 1:
            self.test_step_outputs[dataloader_idx].append(test_logs)
        else:
            self.test_step_outputs.append(test_logs)
        return test_logs

    def test_dataloader(self):
        """Return the test dataloader, substituting (and storing) an empty list when none is set.

        A ``None`` dataloader is no longer supported in PTL2.0, hence the empty-list fallback.
        """
        if self._test_dl is not None:
            return self._test_dl

        self._test_dl = []
        return self._test_dl

    def _setup_dataloader_from_config(self, config: Optional[Dict]):
        """Build a dataset / dataloader from a dataset config.

        Args:
            config: Dict-like dataset config. ``shuffle`` and (for non-DALI datasets) ``batch_size``
                are required; ``augmentor``, ``use_dali``, ``is_tarred``, ``shuffle_n``,
                ``drop_last``, ``num_workers`` and ``pin_memory`` are optional.

        Returns:
            The DALI dataset itself when ``use_dali`` is set, ``None`` (after a warning) when a
            required manifest / tarred path is explicitly ``None``, otherwise a
            ``torch.utils.data.DataLoader``.
        """
        augmentor = process_augmentations(config['augmentor']) if 'augmentor' in config else None
        shuffle = config['shuffle']
        cuda_available = torch.cuda.is_available()

        # DALI datasets are returned directly, without being wrapped in a DataLoader.
        if config.get('use_dali', False):
            return audio_to_text_dataset.get_dali_bpe_dataset(
                config=config,
                tokenizer=self.tokenizer,
                shuffle=shuffle,
                device_id=self.local_rank if cuda_available else None,
                global_rank=self.global_rank,
                world_size=self.world_size,
                preprocessor_cfg=self._cfg.preprocessor,
            )

        def _explicitly_none(key):
            # True only when the key is present in the config and its value is None.
            return key in config and config[key] is None

        # Instantiate tarred dataset loader or normal dataset loader
        if config.get('is_tarred', False):
            if _explicitly_none('tarred_audio_filepaths') or _explicitly_none('manifest_filepath'):
                logging.warning(
                    "Could not load dataset as `manifest_filepath` was None or "
                    f"`tarred_audio_filepaths` is None. Provided config : {config}"
                )
                return None

            if shuffle:
                shuffle_n = config.get('shuffle_n', 4 * config['batch_size'])
            else:
                shuffle_n = 0
            dataset = audio_to_text_dataset.get_tarred_dataset(
                config=config,
                tokenizer=self.tokenizer,
                shuffle_n=shuffle_n,
                global_rank=self.global_rank,
                world_size=self.world_size,
                augmentor=augmentor,
            )
            # The tarred dataset received shuffle_n above; the DataLoader itself must not shuffle.
            shuffle = False
        else:
            if _explicitly_none('manifest_filepath'):
                logging.warning(f"Could not load dataset as `manifest_filepath` was None. Provided config : {config}")
                return None

            dataset = audio_to_text_dataset.get_bpe_dataset(
                config=config, tokenizer=self.tokenizer, augmentor=augmentor
            )

        # Locate the object providing collate_fn: the dataset itself, the first entry of a list of
        # datasets, or the first entry of a list of lists of datasets.
        collate_owner = dataset
        if not hasattr(dataset, 'collate_fn'):
            first_entry = dataset.datasets[0]
            collate_owner = first_entry if hasattr(first_entry, 'collate_fn') else first_entry.datasets[0]
        collate_fn = collate_owner.collate_fn

        return DataLoader(
            dataset=dataset,
            batch_size=config['batch_size'],
            collate_fn=collate_fn,
            drop_last=config.get('drop_last', False),
            shuffle=shuffle,
            num_workers=config.get('num_workers', 0),
            pin_memory=config.get('pin_memory', False),
        )

    def setup_training_data(self, train_data_config: Optional[Union[DictConfig, Dict]]):
        """
        Create the training data loader from a Dict-like config.

        ``shuffle`` defaults to True when the config does not specify it. For iterable datasets,
        a float ``limit_train_batches`` on the trainer is converted into an absolute number of batches.

        Args:
            train_data_config: A config that contains the information regarding construction
                of an ASR Training dataset.

        Supported Datasets:
            -   :class:`~nemo.collections.asr.data.audio_to_text.AudioToCharDataset`
            -   :class:`~nemo.collections.asr.data.audio_to_text.AudioToBPEDataset`
            -   :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToCharDataset`
            -   :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToBPEDataset`
            -   :class:`~nemo.collections.asr.data.audio_to_text_dali.AudioToCharDALIDataset`
        """
        if 'shuffle' not in train_data_config:
            train_data_config['shuffle'] = True

        # preserve config
        self._update_dataset_config(dataset_name='train', config=train_data_config)

        self._train_dl = self._setup_dataloader_from_config(config=train_data_config)

        # Everything below only concerns dataloaders wrapping an IterableDataset: there the length of
        # the dataloader is the total number of samples rather than the number of batches, which
        # messes up the tqdm progress bar, so the number of steps is set manually.
        train_dl = self._train_dl
        if train_dl is None or not hasattr(train_dl, 'dataset'):
            return
        if not isinstance(train_dl.dataset, torch.utils.data.IterableDataset):
            return

        if self._trainer is None:
            logging.warning(
                "Model Trainer was not set before constructing the dataset, incorrect number of "
                "training batches will be used. Please set the trainer and rebuild the dataset."
            )
            return

        # An int limit is assumed to be something sane chosen by the user (<= # training batches) and is
        # left alone; a float (including 1.0) is scaled by the number of batches per rank.
        if isinstance(self._trainer.limit_train_batches, float):
            batches_per_rank = ceil((len(train_dl.dataset) / self.world_size) / train_data_config['batch_size'])
            self._trainer.limit_train_batches = int(self._trainer.limit_train_batches * batches_per_rank)

    def setup_validation_data(self, val_data_config: Optional[Union[DictConfig, Dict]]):
        """
        Create the validation data loader from a Dict-like config.

        ``shuffle`` defaults to False when the config does not specify it.

        Args:
            val_data_config: A config that contains the information regarding construction
                of an ASR validation dataset.

        Supported Datasets:
            -   :class:`~nemo.collections.asr.data.audio_to_text.AudioToCharDataset`
            -   :class:`~nemo.collections.asr.data.audio_to_text.AudioToBPEDataset`
            -   :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToCharDataset`
            -   :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToBPEDataset`
            -   :class:`~nemo.collections.asr.data.audio_to_text_dali.AudioToCharDALIDataset`
        """
        shuffle_missing = 'shuffle' not in val_data_config
        if shuffle_missing:
            val_data_config['shuffle'] = False

        # Store the (possibly updated) config on the model before the loader is built.
        self._update_dataset_config(dataset_name='validation', config=val_data_config)

        validation_loader = self._setup_dataloader_from_config(config=val_data_config)
        self._validation_dl = validation_loader

    def setup_test_data(self, test_data_config: Optional[Union[DictConfig, Dict]]):
        """
        Create the test data loader from a Dict-like config.

        ``shuffle`` defaults to False when the config does not specify it.

        Args:
            test_data_config: A config that contains the information regarding construction
                of an ASR test dataset.

        Supported Datasets:
            -   :class:`~nemo.collections.asr.data.audio_to_text.AudioToCharDataset`
            -   :class:`~nemo.collections.asr.data.audio_to_text.AudioToBPEDataset`
            -   :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToCharDataset`
            -   :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToBPEDataset`
            -   :class:`~nemo.collections.asr.data.audio_to_text_dali.AudioToCharDALIDataset`
        """
        shuffle_missing = 'shuffle' not in test_data_config
        if shuffle_missing:
            test_data_config['shuffle'] = False

        # Store the (possibly updated) config on the model before the loader is built.
        self._update_dataset_config(dataset_name='test', config=test_data_config)

        test_loader = self._setup_dataloader_from_config(config=test_data_config)
        self._test_dl = test_loader

    def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoader':
        """
        Setup function for a temporary data loader which wraps the provided audio file.

        Args:
            config: A python dictionary which contains the following keys:
            manifest_filepath: (str, optional) path of an existing manifest. When absent, the manifest
                is expected at ``<temp_dir>/manifest.json`` and ``paths2audio_files`` must be given.
            paths2audio_files: (a list) of paths to audio files. The files should be relatively short fragments. \
                Recommended length per file is between 5 and 25 seconds.
            batch_size: (int) batch size to use during inference. \
                Bigger will result in better throughput performance but would use more memory.
            temp_dir: (str) A temporary directory where the audio manifest is temporarily
                stored.
            num_workers: (int) number of workers. Depends of the batch_size and machine. \
                0 - only the main process will load batches, 1 - one worker (not main process)

        Returns:
            A pytorch DataLoader for the given audio file(s).
        """

        if 'manifest_filepath' in config:
            manifest_filepath = config['manifest_filepath']
            batch_size = config['batch_size']
        else:
            manifest_filepath = os.path.join(config['temp_dir'], 'manifest.json')
            # Never use a batch larger than the number of files to process.
            batch_size = min(config['batch_size'], len(config['paths2audio_files']))

        # os.cpu_count() may return None when the CPU count cannot be determined; treat that as a
        # single CPU so the default worker count degrades to 0 instead of raising TypeError.
        default_num_workers = min(batch_size, (os.cpu_count() or 1) - 1)

        # Tolerate model configs without a `validation_ds` section by falling back to the same
        # default that is used when the section lacks `use_start_end_token`.
        validation_cfg = self.cfg.get('validation_ds', None)
        if validation_cfg is not None:
            use_start_end_token = validation_cfg.get('use_start_end_token', False)
        else:
            use_start_end_token = False

        dl_config = {
            'manifest_filepath': manifest_filepath,
            'sample_rate': self.preprocessor._sample_rate,
            'batch_size': batch_size,
            'shuffle': False,
            'num_workers': config.get('num_workers', default_num_workers),
            'pin_memory': True,
            'use_start_end_token': use_start_end_token,
        }

        temporary_datalayer = self._setup_dataloader_from_config(config=DictConfig(dl_config))
        return temporary_datalayer

    @torch.no_grad()
    def transcribe(
        self,
        audio: Union[List[str], DataLoader],
        batch_size: int = 4,
        return_hypotheses: bool = False,
        num_workers: int = 0,
        verbose: bool = True,
    ) -> TranscriptionReturnType:
        """
        Transcribe audio into SLU semantics by delegating to the parent class ``transcribe``.
        Use this method for debugging and prototyping.

        Args:
            audio: (a single or list) of paths to audio files or a np.ndarray audio array.
                Can also be a dataloader object that provides values that can be consumed by the model.
                Recommended length per file is between 5 and 25 seconds. \
                But it is possible to pass a few hours long file if enough GPU memory is available.
            batch_size: (int) batch size to use during inference.
                Bigger will result in better throughput performance but would use more memory.
            return_hypotheses: (bool) Either return hypotheses or text
                With hypotheses can do some postprocessing like getting timestamp or rescoring
            num_workers: (int) number of workers for DataLoader
            verbose: (bool) whether to display tqdm progress bar

        Returns:
            Whatever the parent ``transcribe`` returns for these arguments.
        """
        forwarded_args = dict(
            audio=audio,
            batch_size=batch_size,
            return_hypotheses=return_hypotheses,
            num_workers=num_workers,
            verbose=verbose,
        )
        return super().transcribe(**forwarded_args)

    """ Transcription related methods """

    def _transcribe_forward(self, batch: Any, trcfg: TranscribeConfig):
        """Run ``predict`` on a batch and wrap the result in a dict.

        Args:
            batch: Indexable batch; element 0 is used as the input signal and element 1 as its length.
            trcfg: Transcription config; not used by this method.

        Returns:
            A dict with the single key ``predictions``.
        """
        signal, signal_length = batch[0], batch[1]
        return {'predictions': self.predict(input_signal=signal, input_signal_length=signal_length)}

    def _transcribe_output_processing(self, outputs, trcfg: TranscribeConfig) -> List[str]:
        """Extract the predictions from a forward-output dict.

        The ``predictions`` entry is removed from ``outputs`` (the dict is modified in place).

        Args:
            outputs: Dict containing a ``predictions`` key.
            trcfg: Transcription config; not used by this method.

        Returns:
            The value that was stored under ``predictions``.
        """
        return outputs.pop('predictions')

    @classmethod
    def list_available_models(cls) -> List[PretrainedModelInfo]:
        """
        This method returns a list of pre-trained model which can be instantiated directly from NVIDIA's NGC cloud.

        Returns:
            List of available pre-trained models.
        """
        results = []

        model = PretrainedModelInfo(
            pretrained_model_name="slu_conformer_transformer_large_slurp",
            description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:slu_conformer_transformer_large_slurp",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/slu_conformer_transformer_large_slurp/versions/1.13.0/files/slu_conformer_transformer_large_slurp.nemo",
        )
        results.append(model)

        # The list was previously built but never returned, so callers always received None.
        return results

    @property
    def wer(self):
        """Return the metric object stored in ``_wer``."""
        return self._wer

    @wer.setter
    def wer(self, wer):
        """Store the given metric object in ``_wer``."""
        self._wer = wer
