"""
Transcription functions
-----------------------

"""
from __future__ import annotations

import dataclasses
import logging
import os
import queue
import threading
import typing
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Dict

import numpy as np
import sqlalchemy
from _kalpy.fstext import ConstFst, VectorFst
from _kalpy.lat import CompactLatticeWriter
from _kalpy.lm import ConstArpaLm
from _kalpy.util import BaseFloatMatrixWriter, Int32VectorWriter, ReadKaldiObject
from kalpy.data import KaldiMapping, MatrixArchive, Segment
from kalpy.decoder.decode_graph import DecodeGraphCompiler
from kalpy.feat.data import FeatureArchive
from kalpy.feat.fmllr import FmllrComputer
from kalpy.fstext.lexicon import LexiconCompiler
from kalpy.gmm.data import LatticeArchive
from kalpy.gmm.decode import GmmDecoder, GmmRescorer
from kalpy.lm.rescore import LmRescorer
from kalpy.utils import generate_write_specifier
from sqlalchemy.orm import joinedload, subqueryload

from montreal_forced_aligner import config
from montreal_forced_aligner.abc import KaldiFunction, MetaDict
from montreal_forced_aligner.data import Language, MfaArguments, PhoneType
from montreal_forced_aligner.db import File, Job, Phone, SoundFile, Speaker, Utterance
from montreal_forced_aligner.diarization.multiprocessing import UtteranceFileLoader
from montreal_forced_aligner.tokenization.simple import SimpleTokenizer
from montreal_forced_aligner.transcription.models import MfaFasterWhisperPipeline, load_model
from montreal_forced_aligner.utils import thread_logger
from montreal_forced_aligner.vad.models import MfaVAD

# Optional neural-ASR dependencies. torch/speechbrain are only needed for the
# SpeechBrain/Whisper backends: their noisy loggers are silenced before import,
# and if the import fails (ImportError, or an OSError — presumably from native
# library loading) the module still imports with these features disabled.
try:
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        torch_logger = logging.getLogger("speechbrain.utils.torch_audio_backend")
        torch_logger.setLevel(logging.ERROR)
        torch_logger = logging.getLogger("speechbrain.utils.train_logger")
        torch_logger.setLevel(logging.ERROR)
        transformers_logger = logging.getLogger("transformers.modeling_utils")
        transformers_logger.setLevel(logging.ERROR)
        transformers_logger = logging.getLogger(
            "speechbrain.lobes.models.huggingface_transformers.huggingface"
        )
        transformers_logger.setLevel(logging.ERROR)
        import torch

        CUDA_AVAILABLE = torch.cuda.is_available()
        try:
            # Older speechbrain location for the pretrained interfaces
            from speechbrain.pretrained import EncoderASR, WhisperASR
        except (ImportError, ModuleNotFoundError):  # speechbrain 1.0
            from speechbrain.inference.ASR import EncoderASR, WhisperASR
    FOUND_SPEECHBRAIN = True
except (ImportError, OSError):
    # Placeholders so module-level names (and __all__) stay importable
    FOUND_SPEECHBRAIN = False
    CUDA_AVAILABLE = False
    WhisperASR = None
    EncoderASR = None


# Public names exported by ``from ... import *``
__all__ = [
    "FmllrRescoreFunction",
    "FinalFmllrFunction",
    "InitialFmllrFunction",
    "CarpaLmRescoreFunction",
    "DecodeFunction",
    "LmRescoreFunction",
    "CreateHclgFunction",
    "WhisperASR",
    "EncoderASR",
    "SpeechbrainAsrArguments",
    "SpeechbrainAsrCudaArguments",
    "WhisperArguments",
    "WhisperCudaArguments",
    "SpeechbrainAsrFunction",
    "WhisperAsrFunction",
    "FOUND_SPEECHBRAIN",
    "CUDA_AVAILABLE",
]

# Shared logger used across MFA
logger = logging.getLogger("mfa")


@dataclass
class CreateHclgArguments(MfaArguments):
    """
    Arguments for :class:`~montreal_forced_aligner.transcription.multiprocessing.CreateHclgFunction`

    Parameters
    ----------
    job_name: int
        Integer ID of the job
    session: :class:`sqlalchemy.orm.scoped_session` or str
        SqlAlchemy scoped session or string for database connections
    log_path: :class:`~pathlib.Path`
        Path to save logging information during the run
    lexicon_compiler: :class:`~kalpy.fstext.lexicon.LexiconCompiler`
        Lexicon compiler passed to each decode graph compiler
    working_directory: :class:`~pathlib.Path`
        Current working directory
    small_arpa_path: :class:`~pathlib.Path`
        Path to small ARPA file
    medium_arpa_path: :class:`~pathlib.Path`
        Path to medium ARPA file
    big_arpa_path: :class:`~pathlib.Path`
        Path to big ARPA file
    model_path: :class:`~pathlib.Path`
        Acoustic model path
    tree_path: :class:`~pathlib.Path`
        Path to the tree file passed to each decode graph compiler
    hclg_options: dict[str, Any]
        HCLG options
    """

    lexicon_compiler: LexiconCompiler
    working_directory: Path
    small_arpa_path: Path
    medium_arpa_path: Path
    big_arpa_path: Path
    model_path: Path
    tree_path: Path
    hclg_options: MetaDict


@dataclass
class SpeechbrainAsrArguments(MfaArguments):
    """
    Arguments for :class:`~montreal_forced_aligner.transcription.multiprocessing.SpeechbrainAsrFunction`
    when the model has to be loaded by the worker

    Parameters
    ----------
    job_name: int
        Integer ID of the job
    session: :class:`sqlalchemy.orm.scoped_session` or str
        SqlAlchemy scoped session or string for database connections
    log_path: :class:`~pathlib.Path`
        Path to save logging information during the run
    working_directory: :class:`~pathlib.Path`
        Current working directory
    architecture: str
        Model architecture, used to build the
        ``speechbrain/asr-{architecture}-commonvoice-14-{iso_code}`` model name
    language: :class:`~montreal_forced_aligner.data.Language`
        Language whose ISO code selects the model
    tokenizer: :class:`~montreal_forced_aligner.tokenization.simple.SimpleTokenizer`, optional
        Tokenizer applied to each transcript, if given
    """

    working_directory: Path
    architecture: str
    language: Language
    tokenizer: typing.Optional[SimpleTokenizer]


@dataclass
class SpeechbrainAsrCudaArguments(MfaArguments):
    """
    Arguments for :class:`~montreal_forced_aligner.transcription.multiprocessing.SpeechbrainAsrFunction`
    when a pre-loaded model is supplied (CUDA mode)

    Parameters
    ----------
    job_name: int
        Integer ID of the job
    session: :class:`sqlalchemy.orm.scoped_session` or str
        SqlAlchemy scoped session or string for database connections
    log_path: :class:`~pathlib.Path`
        Path to save logging information during the run
    working_directory: :class:`~pathlib.Path`
        Current working directory
    model: EncoderASR or WhisperASR
        Already loaded SpeechBrain ASR model, used as-is
    tokenizer: :class:`~montreal_forced_aligner.tokenization.simple.SimpleTokenizer`, optional
        Tokenizer applied to each transcript, if given
    """

    working_directory: Path
    model: typing.Union[EncoderASR, WhisperASR]
    tokenizer: typing.Optional[SimpleTokenizer]


@dataclass
class WhisperArguments(MfaArguments):
    """
    Arguments for :class:`~montreal_forced_aligner.transcription.multiprocessing.WhisperAsrFunction`
    when the model has to be loaded by the worker

    Parameters
    ----------
    job_name: int
        Integer ID of the job
    session: :class:`sqlalchemy.orm.scoped_session` or str
        SqlAlchemy scoped session or string for database connections
    log_path: :class:`~pathlib.Path`
        Path to save logging information during the run
    working_directory: :class:`~pathlib.Path`
        Current working directory
    architecture: str
        Whisper architecture to load
    language: :class:`~montreal_forced_aligner.data.Language`
        Transcription language (``Language.unknown`` leaves it unspecified)
    tokenizer: :class:`~montreal_forced_aligner.tokenization.simple.SimpleTokenizer`, optional
        Tokenizer applied to transcripts, if given
    cuda: bool
        Whether to run on CUDA (only honored when CUDA is available)
    export_directory: :class:`~pathlib.Path`, optional
        Optional export directory
    """

    working_directory: Path
    architecture: str
    language: Language
    tokenizer: typing.Optional[SimpleTokenizer]
    cuda: bool
    export_directory: typing.Optional[Path]


@dataclass
class WhisperCudaArguments(MfaArguments):
    """
    Arguments for :class:`~montreal_forced_aligner.transcription.multiprocessing.WhisperAsrFunction`
    when a pre-loaded model is supplied

    Parameters
    ----------
    job_name: int
        Integer ID of the job
    session: :class:`sqlalchemy.orm.scoped_session` or str
        SqlAlchemy scoped session or string for database connections
    log_path: :class:`~pathlib.Path`
        Path to save logging information during the run
    working_directory: :class:`~pathlib.Path`
        Current working directory
    model: :class:`~montreal_forced_aligner.transcription.models.MfaFasterWhisperPipeline`
        Already loaded Whisper pipeline, used as-is
    tokenizer: :class:`~montreal_forced_aligner.tokenization.simple.SimpleTokenizer`, optional
        Tokenizer applied to transcripts, if given
    cuda: bool
        Whether to run on CUDA (only honored when CUDA is available)
    export_directory: :class:`~pathlib.Path`, optional
        Optional export directory
    """

    working_directory: Path
    model: MfaFasterWhisperPipeline
    tokenizer: typing.Optional[SimpleTokenizer]
    cuda: bool
    export_directory: typing.Optional[Path]


@dataclass
class DecodeArguments(MfaArguments):
    """
    Arguments for :class:`~montreal_forced_aligner.transcription.multiprocessing.DecodeFunction`

    Parameters
    ----------
    job_name: int
        Integer ID of the job
    session: :class:`sqlalchemy.orm.scoped_session` or str
        SqlAlchemy scoped session or string for database connections
    log_path: :class:`~pathlib.Path`
        Path to save logging information during the run
    working_directory: :class:`~pathlib.Path`
        Working directory
    model_path: :class:`~pathlib.Path`
        Path to model file
    decode_options: dict[str, Any]
        Decoding options; an optional ``boost_silence`` entry is applied to the
        silence phones rather than passed to the decoder
    hclg_paths: dict[str, Path]
        Per dictionary HCLG.fst paths, keyed by dictionary name
    """

    working_directory: Path
    model_path: Path
    decode_options: MetaDict
    # DecodeFunction looks these up with the dictionary's ``name``
    hclg_paths: Dict[str, Path]


@dataclass
class DecodePhoneArguments(MfaArguments):
    """
    Arguments for :class:`~montreal_forced_aligner.validation.corpus_validator.DecodePhoneFunction`

    Parameters
    ----------
    job_name: int
        Integer ID of the job
    session: :class:`sqlalchemy.orm.scoped_session` or str
        SqlAlchemy scoped session or string for database connections
    log_path: :class:`~pathlib.Path`
        Path to save logging information during the run
    working_directory: :class:`~pathlib.Path`
        Working directory
    model_path: :class:`~pathlib.Path`
        Path to model file
    hclg_path: :class:`~pathlib.Path`
        HCLG.fst path
    decode_options: dict[str, Any]
        Decoding options
    """

    working_directory: Path
    model_path: Path
    hclg_path: Path
    decode_options: MetaDict


@dataclass
class LmRescoreArguments(MfaArguments):
    """
    Arguments for :class:`~montreal_forced_aligner.transcription.multiprocessing.LmRescoreFunction`

    Parameters
    ----------
    job_name: int
        Integer ID of the job
    session: :class:`sqlalchemy.orm.scoped_session` or str
        SqlAlchemy scoped session or string for database connections
    log_path: :class:`~pathlib.Path`
        Path to save logging information during the run
    working_directory: :class:`~pathlib.Path`
        Working directory
    lm_rescore_options: dict[str, Any]
        Rescoring options
    old_g_paths: dict[int, Path]
        Mapping of dictionaries to small G.fst paths
    new_g_paths: dict[int, Path]
        Mapping of dictionaries to medium G.fst paths
    """

    working_directory: Path
    lm_rescore_options: MetaDict
    old_g_paths: Dict[int, Path]
    new_g_paths: Dict[int, Path]


@dataclass
class CarpaLmRescoreArguments(MfaArguments):
    """
    Arguments for :class:`~montreal_forced_aligner.transcription.multiprocessing.CarpaLmRescoreFunction`

    Parameters
    ----------
    job_name: int
        Integer ID of the job
    session: :class:`sqlalchemy.orm.scoped_session` or str
        SqlAlchemy scoped session or string for database connections
    log_path: :class:`~pathlib.Path`
        Path to save logging information during the run
    working_directory: :class:`~pathlib.Path`
        Working directory
    lm_rescore_options: dict[str, Any]
        Rescoring options
    old_g_paths: dict[int, Path]
        Mapping of dictionaries to medium G.fst paths
    new_g_paths: dict[int, Path]
        Mapping of dictionaries to G.carpa paths
    """

    working_directory: Path
    lm_rescore_options: MetaDict
    old_g_paths: Dict[int, Path]
    new_g_paths: Dict[int, Path]


@dataclass
class InitialFmllrArguments(MfaArguments):
    """
    Arguments for :class:`~montreal_forced_aligner.transcription.multiprocessing.InitialFmllrFunction`

    Parameters
    ----------
    job_name: int
        Integer ID of the job
    session: :class:`sqlalchemy.orm.scoped_session` or str
        SqlAlchemy scoped session or string for database connections
    log_path: :class:`~pathlib.Path`
        Path to save logging information during the run
    working_directory: :class:`~pathlib.Path`
        Working directory
    ali_model_path: :class:`~pathlib.Path`
        Path to the alignment model file
    model_path: :class:`~pathlib.Path`
        Path to model file
    fmllr_options: dict[str, Any]
        fMLLR options
    """

    working_directory: Path
    ali_model_path: Path
    model_path: Path
    fmllr_options: MetaDict


@dataclass
class FinalFmllrArguments(MfaArguments):
    """
    Arguments for :class:`~montreal_forced_aligner.transcription.multiprocessing.FinalFmllrFunction`

    Parameters
    ----------
    job_name: int
        Integer ID of the job
    session: :class:`sqlalchemy.orm.scoped_session` or str
        SqlAlchemy scoped session or string for database connections
    log_path: :class:`~pathlib.Path`
        Path to save logging information during the run
    working_directory: :class:`~pathlib.Path`
        Working directory
    ali_model_path: :class:`~pathlib.Path`
        Path to the alignment model file
    model_path: :class:`~pathlib.Path`
        Path to model file
    fmllr_options: dict[str, Any]
        fMLLR options
    """

    working_directory: Path
    ali_model_path: Path
    model_path: Path
    fmllr_options: MetaDict


@dataclass
class FmllrRescoreArguments(MfaArguments):
    """
    Arguments for :class:`~montreal_forced_aligner.transcription.multiprocessing.FmllrRescoreFunction`

    Parameters
    ----------
    job_name: int
        Integer ID of the job
    session: :class:`sqlalchemy.orm.scoped_session` or str
        SqlAlchemy scoped session or string for database connections
    log_path: :class:`~pathlib.Path`
        Path to save logging information during the run
    working_directory: :class:`~pathlib.Path`
        Working directory
    model_path: :class:`~pathlib.Path`
        Path to model file
    rescore_options: dict[str, Any]
        Rescoring options
    """

    working_directory: Path
    model_path: Path
    rescore_options: MetaDict


class CreateHclgFunction(KaldiFunction):
    """
    Create HCLG.fst file

    For one job this writes, into the working directory, the HCLG graph built
    from the small ARPA model together with that model's G.fst, a G.fst built
    from the medium ARPA model, and a ``.carpa`` model built from the big ARPA
    file. The callback receives ``(hclg_exists, hclg_path)``.

    See Also
    --------
    :meth:`.Transcriber.create_hclgs`
        Main function that calls this function in parallel
    :meth:`.Transcriber.create_hclgs_arguments`
        Job method for generating arguments for this function
    :kaldi_src:`add-self-loops`
        Relevant Kaldi binary
    :openfst_src:`fstconvert`
        Relevant OpenFst binary

    Parameters
    ----------
    args: :class:`~montreal_forced_aligner.transcription.multiprocessing.CreateHclgArguments`
        Arguments for the function
    """

    def __init__(self, args: CreateHclgArguments):
        super().__init__(args)
        # Language model inputs
        self.small_arpa_path = args.small_arpa_path
        self.medium_arpa_path = args.medium_arpa_path
        self.big_arpa_path = args.big_arpa_path
        # Acoustic/lexical inputs shared by every compiler instance
        self.model_path = args.model_path
        self.tree_path = args.tree_path
        self.lexicon_compiler = args.lexicon_compiler
        self.hclg_options = args.hclg_options
        self.working_directory = args.working_directory

    def _new_compiler(self) -> DecodeGraphCompiler:
        """Build a fresh graph compiler from this job's model, tree and lexicon."""
        return DecodeGraphCompiler(
            self.model_path, self.tree_path, self.lexicon_compiler, **self.hclg_options
        )

    def _run(self) -> None:
        """Run the function"""
        with thread_logger("kalpy.decode_graph", self.log_path, job_name=self.job_name):
            out_dir = self.working_directory
            job = self.job_name
            hclg_path = out_dir.joinpath(f"HCLG.{job}.fst")
            small_g_path = out_dir.joinpath(f"G_small.{job}.fst")
            medium_g_path = out_dir.joinpath(f"G_med.{job}.fst")
            carpa_path = out_dir.joinpath(f"G.{job}.carpa")

            # Each stage gets its own compiler, released before the next is built
            # so only one compiled graph is held in memory at a time.
            compiler = self._new_compiler()
            compiler.export_hclg(self.small_arpa_path, hclg_path)
            compiler.export_g(small_g_path)
            del compiler

            compiler = self._new_compiler()
            compiler.compile_g_fst(self.medium_arpa_path)
            compiler.export_g(medium_g_path)
            del compiler

            compiler = self._new_compiler()
            compiler.compile_g_carpa(self.big_arpa_path, carpa_path)
            del compiler

            self.callback((hclg_path.exists(), hclg_path))


class DecodeFunction(KaldiFunction):
    """
    Multiprocessing function for performing decoding

    For every dictionary of the job, the dictionary's HCLG graph is loaded and
    lattices, words and alignments are exported, optionally after boosting the
    silence phones by the ``boost_silence`` decode option.

    See Also
    --------
    :meth:`.TranscriberMixin.transcribe_utterances`
        Main function that calls this function in parallel
    :meth:`.TranscriberMixin.decode_arguments`
        Job method for generating arguments for this function
    :kaldi_src:`gmm-latgen-faster`
        Relevant Kaldi binary

    Parameters
    ----------
    args: :class:`~montreal_forced_aligner.transcription.multiprocessing.DecodeArguments`
        Arguments for the function
    """

    def __init__(self, args: DecodeArguments):
        super().__init__(args)
        self.working_directory = args.working_directory
        self.hclg_paths = args.hclg_paths
        self.decode_options = args.decode_options
        self.model_path = args.model_path

    def _run(self) -> None:
        """Run the function"""
        with self.session() as session, thread_logger(
            "kalpy.decode", self.log_path, job_name=self.job_name
        ) as decode_logger:
            job: Job = (
                session.query(Job)
                .options(joinedload(Job.corpus, innerjoin=True), subqueryload(Job.dictionaries))
                .filter(Job.id == self.job_name)
                .first()
            )
            silence_phones = [
                x
                for x, in session.query(Phone.mapping_id).filter(
                    Phone.phone_type.in_([PhoneType.silence])
                )
            ]

            # Work on a copy: popping from ``self.decode_options`` in place would
            # mutate the (possibly shared) argument dict and, inside the loop,
            # silently drop the silence boost for every dictionary after the first.
            decode_options = dict(self.decode_options)
            boost_silence = decode_options.pop("boost_silence", 1.0)

            for d in job.dictionaries:
                decode_logger.debug(f"Decoding for dictionary {d.name} ({d.id})")
                decode_logger.debug(f"Decoding with model: {self.model_path}")

                feature_archive = job.construct_feature_archive(self.working_directory, d.name)

                lat_path = job.construct_path(self.working_directory, "lat", "ark", d.name)
                alignment_file_name = job.construct_path(
                    self.working_directory, "ali", "ark", d.name
                )
                words_path = job.construct_path(self.working_directory, "words", "ark", d.name)
                hclg_fst = ConstFst.Read(str(self.hclg_paths[d.name]))
                # A new decoder is built per dictionary, so the boost is reapplied each time
                decoder = GmmDecoder(self.model_path, hclg_fst, **decode_options)
                if boost_silence != 1.0:
                    decoder.boost_silence(boost_silence, silence_phones)
                decoder.export_lattices(
                    lat_path,
                    feature_archive,
                    word_file_name=words_path,
                    alignment_file_name=alignment_file_name,
                    callback=self.callback,
                )


class SpeechbrainAsrFunction(KaldiFunction):
    """
    Multiprocessing function for transcribing utterances with a SpeechBrain ASR model

    With :class:`SpeechbrainAsrCudaArguments` the supplied model is used as-is and
    run on CUDA. Otherwise the model
    ``speechbrain/asr-{architecture}-commonvoice-14-{iso_code}`` is loaded (and
    downloaded if needed) as an ``EncoderASR`` for ``wav2vec2`` architectures and
    as a ``WhisperASR`` otherwise. Each transcript (tokenized when a tokenizer is
    given) is reported to the callback as ``(utterance_id, text)``.

    See Also
    --------
    :meth:`.TranscriberMixin.transcribe_utterances`
        Main function that calls this function in parallel
    :meth:`.TranscriberMixin.decode_arguments`
        Job method for generating arguments for this function

    Parameters
    ----------
    args: :class:`~montreal_forced_aligner.transcription.multiprocessing.SpeechbrainAsrArguments` or :class:`~montreal_forced_aligner.transcription.multiprocessing.SpeechbrainAsrCudaArguments`
        Arguments for the function
    """

    def __init__(self, args: typing.Union[SpeechbrainAsrArguments, SpeechbrainAsrCudaArguments]):
        super().__init__(args)
        self.working_directory = args.working_directory
        self.cuda = isinstance(args, SpeechbrainAsrCudaArguments)
        self.model = None
        self.tokenizer = args.tokenizer
        if self.cuda:
            self.model = args.model
        else:
            self.model = (
                f"speechbrain/asr-{args.architecture}-commonvoice-14-{args.language.iso_code}"
            )

    def _run(self) -> None:
        """Run the function"""
        run_opts = None
        if self.cuda:
            run_opts = {"device": "cuda"}
        model = self.model
        if isinstance(model, str):
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                if "wav2vec2" in model:
                    # Download models if needed
                    model = EncoderASR.from_hparams(
                        source=model,
                        savedir=os.path.join(
                            config.TEMPORARY_DIRECTORY,
                            "models",
                            "EncoderASR",
                            model,
                        ),
                        huggingface_cache_dir=os.path.join(
                            config.TEMPORARY_DIRECTORY, "models", "hf_cache"
                        ),
                        run_opts=run_opts,
                    )
                else:
                    # Download models if needed
                    model = WhisperASR.from_hparams(
                        source=model,
                        savedir=os.path.join(
                            config.TEMPORARY_DIRECTORY,
                            "models",
                            "WhisperASR",
                            model,
                        ),
                        huggingface_cache_dir=os.path.join(
                            config.TEMPORARY_DIRECTORY, "models", "hf_cache"
                        ),
                        run_opts=run_opts,
                    )
        return_q = queue.Queue(2)
        finished_adding = threading.Event()
        stopped = threading.Event()
        loader = UtteranceFileLoader(
            self.job_name,
            self.session,
            return_q,
            stopped,
            finished_adding,
            model=model,
            for_xvector=False,
        )
        loader.start()
        exception = None
        current_index = 0
        while True:
            try:
                batch = return_q.get(timeout=1)
            except queue.Empty:
                # The final batch may be put after our get() timed out but before
                # finished_adding is checked; checking the event first and then
                # re-checking the queue ensures that batch is not dropped.
                if finished_adding.is_set() and return_q.empty():
                    break
                continue
            if stopped.is_set():
                # Drain remaining batches after an error so the loader can finish
                continue
            if isinstance(batch, Exception):
                exception = batch
                stopped.set()
                continue

            audio, lens = batch.signal
            predicted_words, predicted_tokens = model.transcribe_batch(audio, lens)
            for i, u_id in enumerate(batch.utterance_id):
                text = predicted_words[i]
                if self.tokenizer is not None:
                    text = self.tokenizer(text)[0]
                self.callback((int(u_id), text))
            del predicted_words
            del predicted_tokens
            del audio
            del lens
            current_index += 1
            if current_index > 10:
                # Periodically release cached GPU memory; only meaningful on CUDA
                if self.cuda:
                    torch.cuda.empty_cache()
                current_index = 0

        loader.join()
        if exception:
            raise exception


@dataclasses.dataclass
class WhisperSegmentationData:
    """
    VAD segmentation of a single utterance, produced by :class:`WhisperUtteranceLoader`

    Attributes
    ----------
    utterance_id: int
        Utterance database ID
    segments: list[dict[str, float]]
        Segments returned by the VAD model's ``segment_for_whisper``
    export_path: :class:`~pathlib.Path`, optional
        Export path for the utterance's file, or ``None`` when no export
        directory is configured
    speaker_name: str
        Name of the utterance's speaker
    utterance_begin: float
        Utterance begin time
    utterance_end: float
        Utterance end time
    file_id: int
        Database ID of the utterance's file
    file_duration: float
        Duration of the utterance's sound file
    """

    utterance_id: int
    segments: typing.List[typing.Dict[str, float]]
    export_path: typing.Optional[Path]
    speaker_name: str
    utterance_begin: float
    utterance_end: float
    file_id: int
    file_duration: float


class WhisperUtteranceLoader(threading.Thread):
    """
    Helper thread that loads utterance audio and segments it with the Whisper VAD

    Utterances are read in ``(file_id, begin)`` order, optionally restricted to one
    job. For each one, the audio segment is loaded as float32 and segmented with the
    model's VAD, and a :class:`WhisperSegmentationData` is put on ``return_q``. If
    an error occurs, the exception is put on the queue instead. ``finished_adding``
    is always set once loading ends.

    Parameters
    ----------
    job_name: int, optional
        Job identifier; when ``None``, utterances from all jobs are loaded
    session: sqlalchemy.orm.scoped_session
        Session factory
    return_q: :class:`~queue.Queue`
        Queue to put segmentation data (or an exception) on
    stopped: :class:`~threading.Event`
        Check for whether the process to exit gracefully
    finished_adding: :class:`~threading.Event`
        Set once this loader has finished putting items on ``return_q``
    model: :class:`~montreal_forced_aligner.transcription.models.MfaFasterWhisperPipeline`
        Pipeline whose ``vad_model`` and ``_vad_params`` are used for segmentation
    export_directory: :class:`~pathlib.Path`, optional
        If given, utterances whose file already has a ``.lab`` or ``.TextGrid``
        export in this directory are skipped
    """

    def __init__(
        self,
        job_name: typing.Optional[int],
        session: sqlalchemy.orm.scoped_session,
        return_q: queue.Queue,
        stopped: threading.Event,
        finished_adding: threading.Event,
        model: MfaFasterWhisperPipeline,
        export_directory: typing.Optional[Path] = None,
    ):
        super().__init__()
        self.job_name = job_name
        self.session = session
        self.return_q = return_q
        self.stopped = stopped
        self.finished_adding = finished_adding
        self.model = model
        self.export_directory = export_directory

    def run(self) -> None:
        """
        Run the waveform loading job
        """

        with self.session() as session:
            try:
                utterances = (
                    session.query(
                        Utterance.id,
                        SoundFile.sound_file_path,
                        Utterance.begin,
                        Utterance.end,
                        Utterance.channel,
                        File.relative_path,
                        File.name,
                        Speaker.name,
                        File.id,
                        SoundFile.duration,
                    )
                    .join(Utterance.speaker)
                    .join(Utterance.file)
                    .join(File.sound_file)
                )
                if self.job_name is not None:
                    utterances = utterances.filter(Utterance.job_id == self.job_name)
                utterances = utterances.order_by(Utterance.file_id, Utterance.begin)
                # No separate count() is needed for the empty case: the loop simply
                # does nothing and ``finally`` still signals completion.
                for (
                    utterance_id,
                    sound_file_path,
                    begin,
                    end,
                    channel,
                    relative_path,
                    file_name,
                    speaker_name,
                    file_id,
                    file_duration,
                ) in utterances:
                    if self.stopped.is_set():
                        break
                    segment = Segment(sound_file_path, begin, end, channel)
                    export_path = None
                    if self.export_directory is not None:
                        export_path = self.export_directory.joinpath(relative_path, file_name)
                        # Skip files that have already been exported
                        if any(export_path.with_suffix(x).exists() for x in [".lab", ".TextGrid"]):
                            continue
                    audio = segment.load_audio().astype(np.float32)
                    segments = self.model.vad_model.segment_for_whisper(
                        audio, **self.model._vad_params
                    )
                    return_data = WhisperSegmentationData(
                        utterance_id,
                        segments,
                        export_path,
                        speaker_name,
                        begin,
                        end,
                        file_id,
                        file_duration,
                    )
                    self.return_q.put(return_data)
            except Exception as e:
                # Hand the error to the consumer instead of dying silently in the thread
                self.return_q.put(e)
            finally:
                self.finished_adding.set()


class WhisperUtteranceVAD(threading.Thread):
    """
    Helper thread that runs Whisper VAD segmentation on already-loaded utterance audio

    Items are taken from ``job_q`` as
    ``(utterance_id, audio, export_path, speaker_name, begin, end)`` tuples and
    returned on ``return_q`` as
    ``(utterance_id, segments, export_path, speaker_name, begin, end)``. Any
    exception, whether received from ``job_q`` or raised during segmentation, is
    forwarded on ``return_q``. An exception received from ``job_q`` also sets
    ``stopped``.

    Parameters
    ----------
    job_name: int
        Job identifier
    job_q: :class:`~queue.Queue`
        Queue of loaded utterance audio (or exceptions) to segment
    return_q: :class:`~queue.Queue`
        Queue to put segmentation results (or exceptions) on
    stopped: :class:`~threading.Event`
        Check for whether the process to exit gracefully
    finished_adding: :class:`~threading.Event`
        Set by the producer once nothing more will be added to ``job_q``
    model: :class:`~montreal_forced_aligner.transcription.models.MfaFasterWhisperPipeline`
        Pipeline whose ``vad_model`` and ``_vad_params`` are used for segmentation
    export_directory: :class:`~pathlib.Path`, optional
        Export directory (stored on the instance)
    """

    def __init__(
        self,
        job_name: int,
        job_q: queue.Queue,
        return_q: queue.Queue,
        stopped: threading.Event,
        finished_adding: threading.Event,
        model: MfaFasterWhisperPipeline,
        export_directory: typing.Optional[Path] = None,
    ):
        super().__init__()
        self.job_name = job_name
        self.job_q = job_q
        self.return_q = return_q
        self.stopped = stopped
        self.finished_adding = finished_adding
        self.model = model
        self.export_directory = export_directory

    def run(self) -> None:
        """
        Segment queued utterance audio until the producer is finished and the queue is drained
        """

        while True:
            try:
                batch = self.job_q.get(timeout=1)
            except queue.Empty:
                # The last item may be queued after get() timed out but before
                # finished_adding is checked; checking the event first and then
                # re-checking the queue ensures that item is not dropped.
                if self.finished_adding.is_set() and self.job_q.empty():
                    break
                continue
            if self.stopped.is_set():
                # Keep draining so the producer is never blocked on a full queue
                continue
            if isinstance(batch, Exception):
                self.return_q.put(batch)
                self.stopped.set()
                continue
            try:
                utterance_id, audio, export_path, speaker_name, begin, end = batch
                segments = self.model.vad_model.segment_for_whisper(
                    audio, **self.model._vad_params
                )
                self.return_q.put((utterance_id, segments, export_path, speaker_name, begin, end))
            except Exception as e:
                self.return_q.put(e)


class WhisperAsrFunction(KaldiFunction):
    """
    Multiprocessing function for transcribing utterances with a Whisper model

    A :class:`WhisperUtteranceLoader` thread feeds segmentation results for each
    utterance through a queue. The segments are transcribed with the Whisper pipeline
    and every transcribed utterance is reported through the callback. When an export
    path is available, the transcripts of each file are also written to a TextGrid
    with one tier per speaker.

    See Also
    --------
    :meth:`.TranscriberMixin.transcribe_utterances`
        Main function that calls this function in parallel
    :meth:`.TranscriberMixin.decode_arguments`
        Job method for generating arguments for this function

    Parameters
    ----------
    args: :class:`~montreal_forced_aligner.transcription.multiprocessing.WhisperArguments` or :class:`~montreal_forced_aligner.transcription.multiprocessing.WhisperCudaArguments`
        Arguments for the function
    """

    def __init__(self, args: typing.Union[WhisperArguments, WhisperCudaArguments]):
        super().__init__(args)
        self.working_directory = args.working_directory
        self.cuda = args.cuda and CUDA_AVAILABLE
        self.architecture = None
        self.model = None
        self.language = None
        self.export_directory = args.export_directory
        if isinstance(args, WhisperCudaArguments):
            # The model instance is supplied directly by the arguments
            self.model = args.model
        else:
            # The model is built in ``_run`` from the language and architecture
            self.language = args.language
            self.architecture = args.architecture
        self.tokenizer = args.tokenizer

    def _load_model(self) -> MfaFasterWhisperPipeline:
        """Build the VAD model and the Whisper pipeline for this job"""
        language = None
        if self.language is not Language.unknown:
            language = self.language.iso_code
        run_opts = None
        if self.cuda:
            run_opts = {"device": "cuda"}
        vad_model = MfaVAD.from_hparams(
            source="speechbrain/vad-crdnn-libriparty",
            savedir=os.path.join(config.TEMPORARY_DIRECTORY, "models", "VAD"),
            run_opts=run_opts,
        )
        model = load_model(
            self.architecture,
            device="cuda" if self.cuda else "cpu",
            language=language,
            vad_model=vad_model,
            vad_options=None,
            compute_type="float16" if self.cuda else "int8",
            download_root=os.path.join(
                config.TEMPORARY_DIRECTORY,
                "models",
                "Whisper",
            ),
            threads=config.NUM_JOBS,
        )
        if self.cuda:
            model.to("cuda")
        return model

    @staticmethod
    def _export_textgrid(file_data, export_data) -> None:
        """Write one file's collected intervals to its export path (``.TextGrid`` suffix), one tier per speaker"""
        from praatio import textgrid

        file_data.export_path.parent.mkdir(parents=True, exist_ok=True)
        tg = textgrid.Textgrid()
        tg.minTimestamp = 0
        tg.maxTimestamp = file_data.file_duration
        for speaker_name, intervals in export_data.items():
            tg.addTier(
                textgrid.IntervalTier(
                    speaker_name,
                    intervals,
                    minT=0,
                    maxT=file_data.file_duration,
                )
            )
        tg.save(
            str(file_data.export_path.with_suffix(".TextGrid")),
            includeBlankSpaces=True,
            format="short_textgrid",
        )

    def _run(self) -> None:
        """Run the function"""
        from praatio import textgrid

        model = self.model
        if model is None:
            model = self._load_model()
        return_q = queue.Queue(100)
        finished_adding = threading.Event()
        stopped = threading.Event()
        loader = WhisperUtteranceLoader(
            self.job_name if not self.cuda else None,
            self.session,
            return_q,
            stopped,
            finished_adding,
            model,
            export_directory=self.export_directory,
        )
        loader.start()
        exception = None
        # Segmentation data of the first utterance seen for the file being collected
        current_data = None
        # Speaker name -> intervals collected for the current file
        export_data = {}
        while True:
            try:
                vad_result: WhisperSegmentationData = return_q.get(timeout=1)
            except queue.Empty:
                # The loader may put its last result right before setting the flag,
                # so only stop once the queue is also confirmed empty
                if finished_adding.is_set() and return_q.empty():
                    break
                continue
            if stopped.is_set():
                continue
            if isinstance(vad_result, Exception):
                exception = vad_result
                stopped.set()
                continue
            try:
                result = model.transcribe(
                    vad_result.segments,
                    [vad_result.utterance_id] * len(vad_result.segments),
                    batch_size=config.NUM_JOBS,
                )
                if current_data is None:
                    current_data = vad_result
                elif current_data.file_id != vad_result.file_id:
                    # Moving on to a new file: flush the previous one (if exporting)
                    # and always reset the per-file state
                    if current_data.export_path is not None:
                        self._export_textgrid(current_data, export_data)
                    export_data = {}
                    current_data = vad_result
                speaker_intervals = export_data.setdefault(vad_result.speaker_name, [])
                for utterance_id, segments in result.items():
                    for seg in segments:
                        seg["text"] = seg["text"].strip()
                        if self.tokenizer is not None:
                            seg["text"] = self.tokenizer(seg["text"])[0]
                        begin = round(vad_result.utterance_begin + seg["start"], 3)
                        end = round(vad_result.utterance_begin + seg["end"], 3)
                        if speaker_intervals and begin < speaker_intervals[-1].end:
                            # Clip overlaps with the previous interval on this speaker's tier
                            begin = speaker_intervals[-1].end
                        if end > begin:
                            # Clipping can leave an empty or inverted interval, which
                            # cannot be placed on an interval tier; skip it for export only
                            speaker_intervals.append(
                                textgrid.constants.Interval(
                                    begin,
                                    end,
                                    seg["text"],
                                )
                            )
                    self.callback((utterance_id, segments))
            except Exception as e:
                exception = e
                stopped.set()

        loader.join()
        if exception:
            raise exception
        # The last file never triggers a file change above, so flush it here
        if current_data is not None and current_data.export_path is not None:
            self._export_textgrid(current_data, export_data)


class LmRescoreFunction(KaldiFunction):
    """
    Multiprocessing function to rescore lattices by replacing the small G.fst with the medium G.fst

    Each dictionary's lattice archive is moved aside and rescored into its
    original location; the moved-aside copy is deleted afterwards.

    See Also
    --------
    :meth:`.TranscriberMixin.transcribe_utterances`
        Main function that calls this function in parallel
    :meth:`.TranscriberMixin.lm_rescore_arguments`
        Job method for generating arguments for this function
    :kaldi_src:`lattice-lmrescore-pruned`
        Relevant Kaldi binary
    :openfst_src:`fstproject`
        Relevant OpenFst binary

    Parameters
    ----------
    args: :class:`~montreal_forced_aligner.transcription.multiprocessing.LmRescoreArguments`
        Arguments for the function
    """

    def __init__(self, args: LmRescoreArguments):
        super().__init__(args)
        self.working_directory = args.working_directory
        self.old_g_paths = args.old_g_paths
        self.new_g_paths = args.new_g_paths
        self.lm_rescore_options = args.lm_rescore_options

    def _run(self) -> None:
        """Run the function"""
        with self.session() as session, thread_logger(
            "kalpy.lm", self.log_path, job_name=self.job_name
        ):
            job_query = (
                session.query(Job)
                .options(joinedload(Job.corpus, innerjoin=True), subqueryload(Job.dictionaries))
                .filter(Job.id == self.job_name)
            )
            job: Job = job_query.first()
            for dictionary in job.dictionaries:
                self._rescore_dictionary(job, dictionary.name)

    def _rescore_dictionary(self, job: Job, dictionary_name: str) -> None:
        """Rescore one dictionary's lattice archive in place with the new G.fst"""
        final_path = job.construct_path(self.working_directory, "lat", "ark", dictionary_name)
        staged_path = job.construct_path(
            self.working_directory, "lat.tmp", "ark", dictionary_name
        )
        # Move the current lattices aside so the rescored output can replace them
        os.rename(final_path, staged_path)
        source_g_path = self.old_g_paths[dictionary_name]
        target_g_path = self.new_g_paths[dictionary_name]
        source_g = VectorFst.Read(str(source_g_path))
        target_lm = VectorFst.Read(str(target_g_path))
        rescorer = LmRescorer(source_g, **self.lm_rescore_options)
        staged_archive = LatticeArchive(staged_path, determinized=True)
        rescorer.export_lattices(final_path, staged_archive, target_lm, callback=self.callback)
        staged_archive.close()
        os.remove(staged_path)


class CarpaLmRescoreFunction(KaldiFunction):
    """
    Multiprocessing function to rescore lattices by replacing medium G.fst with large G.carpa

    Each dictionary's lattice archive is moved aside and rescored into its
    original location; the moved-aside copy is deleted afterwards.

    See Also
    --------
    :meth:`.TranscriberMixin.transcribe_utterances`
        Main function that calls this function in parallel
    :meth:`.TranscriberMixin.carpa_lm_rescore_arguments`
        Job method for generating arguments for this function
    :openfst_src:`fstproject`
        Relevant OpenFst binary
    :kaldi_src:`lattice-lmrescore`
        Relevant Kaldi binary
    :kaldi_src:`lattice-lmrescore-const-arpa`
        Relevant Kaldi binary

    Parameters
    ----------
    args: CarpaLmRescoreArguments
        Arguments
    """

    def __init__(self, args: CarpaLmRescoreArguments):
        super().__init__(args)
        self.working_directory = args.working_directory
        self.lm_rescore_options = args.lm_rescore_options
        self.old_g_paths = args.old_g_paths
        self.new_g_paths = args.new_g_paths

    def _run(self) -> None:
        """Run the function"""
        with self.session() as session, thread_logger(
            "kalpy.lm", self.log_path, job_name=self.job_name
        ):
            job_query = (
                session.query(Job)
                .options(joinedload(Job.corpus, innerjoin=True), subqueryload(Job.dictionaries))
                .filter(Job.id == self.job_name)
            )
            job: Job = job_query.first()
            for dictionary in job.dictionaries:
                self._rescore_dictionary(job, dictionary.name)

    def _rescore_dictionary(self, job: Job, dictionary_name: str) -> None:
        """Rescore one dictionary's lattice archive in place with the const-ARPA LM"""
        final_path = job.construct_path(self.working_directory, "lat", "ark", dictionary_name)
        staged_path = job.construct_path(
            self.working_directory, "lat.tmp", "ark", dictionary_name
        )
        # Move the current lattices aside so the rescored output can replace them
        os.rename(final_path, staged_path)
        source_g_path = self.old_g_paths[dictionary_name]
        target_lm_path = self.new_g_paths[dictionary_name]
        source_g = VectorFst.Read(str(source_g_path))
        target_lm = ConstArpaLm()
        ReadKaldiObject(str(target_lm_path), target_lm)
        rescorer = LmRescorer(source_g, **self.lm_rescore_options)
        staged_archive = LatticeArchive(staged_path, determinized=True)
        rescorer.export_lattices(final_path, staged_archive, target_lm, callback=self.callback)
        staged_archive.close()
        os.remove(staged_path)


class InitialFmllrFunction(KaldiFunction):
    """
    Multiprocessing function for running initial fMLLR calculation

    For every dictionary of the job, speaker transforms are estimated from the
    lattices in the working directory and the job's features. The transforms are
    then rewritten as an ark/scp pair in the corpus's current subset directory.

    See Also
    --------
    :meth:`.TranscriberMixin.transcribe_fmllr`
        Main function that calls this function in parallel
    :meth:`.TranscriberMixin.initial_fmllr_arguments`
        Job method for generating arguments for this function
    :kaldi_src:`lattice-to-post`
        Relevant Kaldi binary
    :kaldi_src:`weight-silence-post`
        Relevant Kaldi binary
    :kaldi_src:`gmm-post-to-gpost`
        Relevant Kaldi binary
    :kaldi_src:`gmm-est-fmllr-gpost`
        Relevant Kaldi binary

    Parameters
    ----------
    args: :class:`~montreal_forced_aligner.transcription.multiprocessing.InitialFmllrArguments`
        Arguments for the function
    """

    def __init__(self, args: InitialFmllrArguments):
        super().__init__(args)
        self.working_directory = args.working_directory
        self.ali_model_path = args.ali_model_path
        self.model_path = args.model_path
        self.fmllr_options = args.fmllr_options

    def _run(self) -> None:
        """Run the function"""
        with self.session() as session, thread_logger(
            "kalpy.fmllr", self.log_path, job_name=self.job_name
        ) as fmllr_logger:
            fmllr_logger.debug(f"Using acoustic model: {self.model_path}\n")
            job: typing.Optional[Job] = session.get(
                Job, self.job_name, options=[joinedload(Job.dictionaries), joinedload(Job.corpus)]
            )
            lda_mat_path = self.working_directory.joinpath("lda.mat")
            if not lda_mat_path.exists():
                lda_mat_path = None
            # The silence/OOV phone ids do not depend on the dictionary, so query them once
            silence_phones = [
                x
                for x, in session.query(Phone.mapping_id).filter(
                    Phone.phone_type.in_([PhoneType.silence, PhoneType.oov])
                )
            ]
            for d in job.dictionaries:
                feat_path = job.construct_path(
                    job.corpus.current_subset_directory, "feats", "scp", dictionary_id=d.name
                )
                utt2spk_path = job.construct_path(
                    job.corpus.current_subset_directory, "utt2spk", "scp", dictionary_id=d.name
                )
                spk2utt_path = job.construct_path(
                    job.corpus.current_subset_directory, "spk2utt", "scp", dictionary_id=d.name
                )
                utt2spk = KaldiMapping()
                utt2spk.load(utt2spk_path)
                spk2utt = KaldiMapping(list_mapping=True)
                spk2utt.load(spk2utt_path)
                feature_archive = FeatureArchive(
                    feat_path,
                    utt2spk=utt2spk,
                    lda_mat_file_name=lda_mat_path,
                    deltas=True,
                )
                computer = FmllrComputer(
                    self.ali_model_path,
                    self.model_path,
                    silence_phones,
                    spk2utt=spk2utt,
                    **self.fmllr_options,
                )
                lat_path = job.construct_path(self.working_directory, "lat", "ark", d.name)
                fmllr_logger.debug(f"Processing {lat_path} with features from {feat_path}")
                lattice_archive = LatticeArchive(lat_path, determinized=False)
                # Transforms are first written to the working directory...
                temp_trans_path = job.construct_path(
                    self.working_directory, "trans", "ark", d.name
                )
                computer.export_transforms(
                    temp_trans_path,
                    feature_archive,
                    lattice_archive,
                    callback=self.callback,
                )
                feature_archive.close()
                lattice_archive.close()
                del feature_archive
                del lattice_archive
                del computer
                # ...then copied into an ark/scp pair in the current subset directory
                trans_archive = MatrixArchive(temp_trans_path)
                write_specifier = generate_write_specifier(
                    job.construct_path(
                        job.corpus.current_subset_directory, "trans", "ark", dictionary_id=d.name
                    ),
                    write_scp=True,
                )
                writer = BaseFloatMatrixWriter(write_specifier)
                for speaker, trans in trans_archive:
                    writer.Write(str(speaker), trans)
                writer.Close()
                trans_archive.close()
                del trans_archive
                os.remove(temp_trans_path)


class FinalFmllrFunction(KaldiFunction):

    """
    Multiprocessing function for running final fMLLR estimation

    For every dictionary of the job, speaker transforms are estimated from the
    lattices in the working directory. Any transforms already present in the
    corpus's current subset directory are updated, and the result replaces them as
    an ark/scp pair.

    See Also
    --------
    :meth:`.TranscriberMixin.transcribe_fmllr`
        Main function that calls this function in parallel
    :meth:`.TranscriberMixin.final_fmllr_arguments`
        Job method for generating arguments for this function
    :kaldi_src:`lattice-determinize-pruned`
        Relevant Kaldi binary
    :kaldi_src:`lattice-to-post`
        Relevant Kaldi binary
    :kaldi_src:`weight-silence-post`
        Relevant Kaldi binary
    :kaldi_src:`gmm-est-fmllr`
        Relevant Kaldi binary
    :kaldi_src:`compose-transforms`
        Relevant Kaldi binary

    Parameters
    ----------
    args: :class:`~montreal_forced_aligner.transcription.multiprocessing.FinalFmllrArguments`
        Arguments for the function
    """

    def __init__(self, args: FinalFmllrArguments):
        super().__init__(args)
        self.working_directory = args.working_directory
        self.ali_model_path = args.ali_model_path
        self.model_path = args.model_path
        self.fmllr_options = args.fmllr_options

    def _run(self) -> None:
        """Run the function"""
        with self.session() as session, thread_logger(
            "kalpy.fmllr", self.log_path, job_name=self.job_name
        ) as fmllr_logger:
            fmllr_logger.debug(f"Using acoustic model: {self.model_path}\n")
            job: typing.Optional[Job] = session.get(
                Job, self.job_name, options=[joinedload(Job.dictionaries), joinedload(Job.corpus)]
            )
            lda_mat_path = self.working_directory.joinpath("lda.mat")
            if not lda_mat_path.exists():
                lda_mat_path = None
            # The silence/OOV phone ids do not depend on the dictionary, so query them once
            silence_phones = [
                x
                for x, in session.query(Phone.mapping_id).filter(
                    Phone.phone_type.in_([PhoneType.silence, PhoneType.oov])
                )
            ]
            for d in job.dictionaries:
                feat_path = job.construct_path(
                    job.corpus.current_subset_directory, "feats", "scp", dictionary_id=d.name
                )
                fmllr_trans_path = job.construct_path(
                    job.corpus.current_subset_directory, "trans", "scp", dictionary_id=d.name
                )
                previous_transform_archive = None
                if not fmllr_trans_path.exists():
                    fmllr_logger.debug("Computing transforms from scratch")
                    fmllr_trans_path = None
                else:
                    fmllr_logger.debug(f"Updating previous transforms {fmllr_trans_path}")
                    previous_transform_archive = MatrixArchive(fmllr_trans_path)
                utt2spk_path = job.construct_path(
                    job.corpus.current_subset_directory, "utt2spk", "scp", dictionary_id=d.name
                )
                spk2utt_path = job.construct_path(
                    job.corpus.current_subset_directory, "spk2utt", "scp", dictionary_id=d.name
                )
                utt2spk = KaldiMapping()
                utt2spk.load(utt2spk_path)
                spk2utt = KaldiMapping(list_mapping=True)
                spk2utt.load(spk2utt_path)
                feature_archive = FeatureArchive(
                    feat_path,
                    utt2spk=utt2spk,
                    lda_mat_file_name=lda_mat_path,
                    transform_file_name=fmllr_trans_path,
                    deltas=True,
                )
                computer = FmllrComputer(
                    self.ali_model_path,
                    self.model_path,
                    silence_phones,
                    spk2utt=spk2utt,
                    **self.fmllr_options,
                )
                lat_path = job.construct_path(self.working_directory, "lat", "ark", d.name)
                fmllr_logger.debug(f"Processing {lat_path} with features from {feat_path}")
                lattice_archive = LatticeArchive(lat_path, determinized=False)
                temp_trans_path = job.construct_path(
                    self.working_directory, "trans", "ark", d.name
                )
                computer.export_transforms(
                    temp_trans_path,
                    feature_archive,
                    lattice_archive,
                    previous_transform_archive=previous_transform_archive,
                    callback=self.callback,
                )
                # Close every archive explicitly (not just ``del``) before the
                # underlying files are removed below
                feature_archive.close()
                lattice_archive.close()
                if previous_transform_archive is not None:
                    previous_transform_archive.close()
                del previous_transform_archive
                del feature_archive
                del lattice_archive
                del computer
                if fmllr_trans_path is not None:
                    # The new transforms are written to these same scp/ark paths below
                    os.remove(fmllr_trans_path)
                    os.remove(fmllr_trans_path.with_suffix(".ark"))
                trans_archive = MatrixArchive(temp_trans_path)
                write_specifier = generate_write_specifier(
                    job.construct_path(
                        job.corpus.current_subset_directory, "trans", "ark", dictionary_id=d.name
                    ),
                    write_scp=True,
                )
                writer = BaseFloatMatrixWriter(write_specifier)
                for speaker, trans in trans_archive:
                    writer.Write(str(speaker), trans)
                writer.Close()
                trans_archive.close()
                del trans_archive
                os.remove(temp_trans_path)


class FmllrRescoreFunction(KaldiFunction):
    """
    Multiprocessing function to rescore lattices following fMLLR estimation

    Each dictionary's lattice archive is moved aside and rescored into its original
    location, using features that include the speaker transforms when they exist.

    See Also
    --------
    :meth:`.TranscriberMixin.transcribe_fmllr`
        Main function that calls this function in parallel
    :meth:`.TranscriberMixin.fmllr_rescore_arguments`
        Job method for generating arguments for this function
    :kaldi_src:`gmm-rescore-lattice`
        Relevant Kaldi binary
    :kaldi_src:`lattice-determinize-pruned`
        Relevant Kaldi binary

    Parameters
    ----------
    args: :class:`~montreal_forced_aligner.transcription.multiprocessing.FmllrRescoreArguments`
        Arguments for the function
    """

    def __init__(self, args: FmllrRescoreArguments):
        super().__init__(args)
        self.working_directory = args.working_directory
        self.model_path = args.model_path
        self.rescore_options = args.rescore_options

    def _run(self) -> None:
        """Run the function"""
        with self.session() as session, thread_logger(
            "kalpy.decode", self.log_path, job_name=self.job_name
        ) as decode_logger:
            job: Job = (
                session.query(Job)
                .options(joinedload(Job.corpus, innerjoin=True), subqueryload(Job.dictionaries))
                .filter(Job.id == self.job_name)
                .first()
            )
            rescorer = GmmRescorer(self.model_path, **self.rescore_options)
            # The LDA matrix path does not depend on the dictionary
            lda_mat_path = self.working_directory.joinpath("lda.mat")
            if not lda_mat_path.exists():
                lda_mat_path = None
            for d in job.dictionaries:
                decode_logger.debug(f"Aligning for dictionary {d.name} ({d.id})")
                decode_logger.debug(f"Aligning with model: {self.model_path}")
                fst_path = job.construct_path(self.working_directory, "fsts", "ark", d.name)
                decode_logger.debug(f"Training graph archive: {fst_path}")

                fmllr_path = job.construct_path(
                    job.corpus.current_subset_directory, "trans", "scp", d.name
                )
                if not fmllr_path.exists():
                    fmllr_path = None
                feat_path = job.construct_path(
                    job.corpus.current_subset_directory, "feats", "scp", dictionary_id=d.name
                )
                utt2spk_path = job.construct_path(
                    job.corpus.current_subset_directory, "utt2spk", "scp", d.name
                )
                utt2spk = KaldiMapping()
                utt2spk.load(utt2spk_path)
                decode_logger.debug(f"Feature path: {feat_path}")
                decode_logger.debug(f"LDA transform path: {lda_mat_path}")
                decode_logger.debug(f"Speaker transform path: {fmllr_path}")
                decode_logger.debug(f"utt2spk path: {utt2spk_path}")
                feature_archive = FeatureArchive(
                    feat_path,
                    utt2spk=utt2spk,
                    lda_mat_file_name=lda_mat_path,
                    transform_file_name=fmllr_path,
                    deltas=True,
                )
                lat_path = job.construct_path(self.working_directory, "lat", "ark", d.name)
                tmp_lat_path = job.construct_path(self.working_directory, "lat.tmp", "ark", d.name)
                # Move the current lattices aside so the rescored output can replace them
                os.rename(lat_path, tmp_lat_path)
                lattice_archive = LatticeArchive(tmp_lat_path, determinized=True)
                rescorer.export_lattices(
                    lat_path, lattice_archive, feature_archive, callback=self.callback
                )
                lattice_archive.close()
                # Previously left open; release it like the other fMLLR functions do
                feature_archive.close()
                os.remove(tmp_lat_path)


@dataclass
class PerSpeakerDecodeArguments(MfaArguments):
    """Arguments for :class:`~montreal_forced_aligner.validation.corpus_validator.PerSpeakerDecodeFunction`"""

    # Directory where lattice/alignment/word archives are written and an optional lda.mat is read
    working_directory: Path
    # Acoustic model passed to the GMM decoder
    model_path: Path
    # Decision tree path; NOTE(review): not read by PerSpeakerDecodeFunction in this file
    tree_path: Path
    # Decoder options; a "boost_silence" entry is applied to silence phones rather than passed on
    decode_options: MetaDict
    # NOTE(review): presumably the n-gram order of the per-speaker models — not read by the decoder here
    order: int
    # NOTE(review): presumably the n-gram smoothing method — not read by the decoder here
    method: str


class PerSpeakerDecodeFunction(KaldiFunction):
    """
    Multiprocessing function to test utterance transcriptions with utterance and speaker ngram models

    For every dictionary of the job, each speaker's utterances are decoded with a
    graph read from ``<dictionary temp directory>/<speaker_id>.fst``. Lattices,
    alignments and word sequences are written to archives in the working directory.

    See Also
    --------
    :kaldi_src:`compile-train-graphs-fsts`
        Relevant Kaldi binary
    :kaldi_src:`gmm-latgen-faster`
        Relevant Kaldi binary
    :kaldi_src:`lattice-oracle`
        Relevant Kaldi binary
    :openfst_src:`farcompilestrings`
        Relevant OpenFst binary
    :ngram_src:`ngramcount`
        Relevant OpenGrm-Ngram binary
    :ngram_src:`ngrammake`
        Relevant OpenGrm-Ngram binary
    :ngram_src:`ngramshrink`
        Relevant OpenGrm-Ngram binary

    Parameters
    ----------
    args: :class:`~montreal_forced_aligner.validation.corpus_validator.PerSpeakerDecodeArguments`
        Arguments for the function
    """

    def __init__(self, args: PerSpeakerDecodeArguments):
        super().__init__(args)
        self.working_directory = args.working_directory
        self.model_path = args.model_path
        self.decode_options = args.decode_options
        self.tree_path = args.tree_path
        self.order = args.order
        self.method = args.method
        self.word_symbols_paths = {}

    def _run(self) -> None:
        """Run the function"""
        with self.session() as session, thread_logger(
            "kalpy.decode", self.log_path, job_name=self.job_name
        ) as decode_logger:
            job: Job = (
                session.query(Job)
                .options(joinedload(Job.corpus, innerjoin=True), subqueryload(Job.dictionaries))
                .filter(Job.id == self.job_name)
                .first()
            )
            silence_phones = [
                x
                for x, in session.query(Phone.mapping_id).filter(
                    Phone.phone_type.in_([PhoneType.silence])
                )
            ]
            # Work on a copy: popping from ``self.decode_options`` directly would lose
            # ``boost_silence`` for every dictionary after the first
            decode_options = dict(self.decode_options)
            boost_silence = decode_options.pop("boost_silence", 1.0)
            lda_mat_path = self.working_directory.joinpath("lda.mat")
            if not lda_mat_path.exists():
                lda_mat_path = None
            # Speaker id -> set of utterance ids of this job spoken by that speaker
            speaker_utterances = {}
            for utterance_id, speaker_id in (
                session.query(Utterance.id, Utterance.speaker_id)
                .filter(Utterance.job_id == job.id)
                .order_by(Utterance.speaker_id, Utterance.id)
            ):
                speaker_utterances.setdefault(speaker_id, set()).add(utterance_id)

            for d in job.dictionaries:
                decode_logger.debug(f"Decoding for dictionary {d.name} ({d.id})")
                decode_logger.debug(f"Decoding with model: {self.model_path}")

                fmllr_path = job.construct_path(
                    job.corpus.current_subset_directory, "trans", "scp", d.name
                )
                if not fmllr_path.exists():
                    fmllr_path = None
                feat_path = job.construct_path(
                    job.corpus.current_subset_directory, "feats", "scp", dictionary_id=d.name
                )
                utt2spk_path = job.construct_path(
                    job.corpus.current_subset_directory, "utt2spk", "scp", d.name
                )
                utt2spk = KaldiMapping()
                utt2spk.load(utt2spk_path)
                decode_logger.debug(f"Feature path: {feat_path}")
                decode_logger.debug(f"LDA transform path: {lda_mat_path}")
                decode_logger.debug(f"Speaker transform path: {fmllr_path}")
                decode_logger.debug(f"utt2spk path: {utt2spk_path}")
                feature_archive = FeatureArchive(
                    feat_path,
                    utt2spk=utt2spk,
                    lda_mat_file_name=lda_mat_path,
                    transform_file_name=fmllr_path,
                    deltas=True,
                )

                lat_path = job.construct_path(self.working_directory, "lat", "ark", d.name)
                alignment_file_name = job.construct_path(
                    self.working_directory, "ali", "ark", d.name
                )
                words_path = job.construct_path(self.working_directory, "words", "ark", d.name)
                writer = CompactLatticeWriter(generate_write_specifier(lat_path, write_scp=False))
                alignment_writer = Int32VectorWriter(
                    generate_write_specifier(alignment_file_name, write_scp=False)
                )
                word_writer = Int32VectorWriter(
                    generate_write_specifier(words_path, write_scp=False)
                )
                # Build one decoder per speaker (previously it was rebuilt and the whole
                # archive re-decoded for every single utterance, writing duplicates)
                for speaker_id, utterance_ids in speaker_utterances.items():
                    lm_path = os.path.join(d.temp_directory, f"{speaker_id}.fst")
                    hclg_fst = ConstFst.Read(str(lm_path))
                    decoder = GmmDecoder(self.model_path, hclg_fst, **decode_options)
                    if boost_silence != 1.0:
                        decoder.boost_silence(boost_silence, silence_phones)
                    # ``decode_utterances`` walks the whole feature archive, so keep only
                    # this speaker's utterances. TODO(review): a per-utterance decode API
                    # would avoid decoding other speakers' utterances here.
                    for transcription in decoder.decode_utterances(feature_archive):
                        if transcription is None:
                            continue
                        utt_id = int(transcription.utterance_id.split("-")[-1])
                        if utt_id not in utterance_ids:
                            continue
                        self.callback((utt_id, transcription.likelihood))
                        writer.Write(str(transcription.utterance_id), transcription.lattice)
                        alignment_writer.Write(
                            str(transcription.utterance_id), transcription.alignment
                        )
                        word_writer.Write(str(transcription.utterance_id), transcription.words)
                writer.Close()
                alignment_writer.Close()
                word_writer.Close()
                feature_archive.close()


class DecodePhoneFunction(KaldiFunction):
    """
    Multiprocessing function for performing phone-level decoding

    See Also
    --------
    :meth:`.TranscriberMixin.transcribe_utterances`
        Main function that calls this function in parallel
    :meth:`.TranscriberMixin.decode_arguments`
        Job method for generating arguments for this function
    :kaldi_src:`gmm-latgen-faster`
        Relevant Kaldi binary

    Parameters
    ----------
    args: :class:`~montreal_forced_aligner.transcription.multiprocessing.DecodePhoneArguments`
        Arguments for the function
    """

    def __init__(self, args: DecodePhoneArguments):
        super().__init__(args)
        # Directory holding the job's feature archives and receiving decode output
        self.working_directory = args.working_directory
        # Path to the decoding graph, read once per run as a ConstFst
        self.hclg_path = args.hclg_path
        # Keyword options forwarded to GmmDecoder; may contain "boost_silence"
        self.decode_options = args.decode_options
        # Acoustic model path passed to GmmDecoder
        self.model_path = args.model_path

    def _run(self) -> None:
        """
        Decode every dictionary partition of the job against ``self.hclg_path``.

        For each dictionary attached to the job, a :class:`GmmDecoder` is built
        from ``self.model_path`` and the HCLG graph. Lattices, alignments and
        word sequences are exported to ``lat``/``ali``/``words`` archives in
        the working directory, with ``self.callback`` forwarded for progress
        reporting.
        """
        with self.session() as session, thread_logger(
            "kalpy.decode", self.log_path, job_name=self.job_name
        ) as decode_logger:
            job: Job = (
                session.query(Job)
                .options(joinedload(Job.corpus, innerjoin=True), subqueryload(Job.dictionaries))
                .filter(Job.id == self.job_name)
                .first()
            )
            # Integer mapping ids of all silence phones, needed by boost_silence
            silence_phones = [
                x
                for x, in session.query(Phone.mapping_id).filter(
                    Phone.phone_type.in_([PhoneType.silence])
                )
            ]
            # Pop boost_silence from a private copy, once, before the loop.
            # Popping from self.decode_options inside the loop both mutated the
            # caller's options and dropped the value for every dictionary after
            # the first one.
            decode_options = dict(self.decode_options)
            boost_silence = decode_options.pop("boost_silence", 1.0)
            hclg_fst = ConstFst.Read(str(self.hclg_path))
            for d in job.dictionaries:
                decode_logger.debug(f"Decoding for dictionary {d.name} ({d.id})")
                decode_logger.debug(f"Decoding with model: {self.model_path}")
                feature_archive = job.construct_feature_archive(self.working_directory, d.name)
                lat_path = job.construct_path(self.working_directory, "lat", "ark", d.name)
                alignment_file_name = job.construct_path(
                    self.working_directory, "ali", "ark", d.name
                )
                words_path = job.construct_path(self.working_directory, "words", "ark", d.name)

                decoder = GmmDecoder(self.model_path, hclg_fst, **decode_options)
                if boost_silence != 1.0:
                    decoder.boost_silence(boost_silence, silence_phones)
                decoder.export_lattices(
                    lat_path,
                    feature_archive,
                    word_file_name=words_path,
                    alignment_file_name=alignment_file_name,
                    callback=self.callback,
                )
