"""Class for generating pronunciations from G2P models"""
from __future__ import annotations

import csv
import functools
import itertools
import logging
import multiprocessing as mp
import os
import queue
import statistics
import time
import typing
import unicodedata
from dataclasses import dataclass
from multiprocessing.pool import ThreadPool
from pathlib import Path

import pynini
import pywrapfst
from kalpy.gmm.data import to_tg_interval
from praatio import textgrid
from pynini import Fst, TokenType
from pynini.lib import rewrite
from pywrapfst import SymbolTable
from sqlalchemy.orm import selectinload
from tqdm.rich import tqdm

from montreal_forced_aligner import config
from montreal_forced_aligner.abc import DatabaseMixin, KaldiFunction, TopLevelMfaWorker
from montreal_forced_aligner.corpus.text_corpus import DictionaryTextCorpusMixin, TextCorpusMixin
from montreal_forced_aligner.data import (
    GeneratedPronunciation,
    MfaArguments,
    TextgridFormats,
    WordType,
    WorkflowType,
)
from montreal_forced_aligner.db import File, Utterance, Word, bulk_update
from montreal_forced_aligner.exceptions import PyniniGenerationError
from montreal_forced_aligner.g2p.mixins import G2PTopLevelMixin
from montreal_forced_aligner.helper import comma_join, mfa_open, score_g2p
from montreal_forced_aligner.models import G2PModel
from montreal_forced_aligner.textgrid import construct_output_path
from montreal_forced_aligner.utils import run_kaldi_function

# Alias that is only defined while static type checking; it does not exist at runtime
if typing.TYPE_CHECKING:
    SpeakerCharacterType = typing.Union[str, int]


# Names exported by a star-import of this module
__all__ = [
    "Rewriter",
    "PhonetisaurusRewriter",
    "RewriterWorker",
    "PyniniGenerator",
    "PyniniCorpusGenerator",
    "PyniniWordListGenerator",
    "PyniniValidator",
]

# Module-level logger registered under the "mfa" name
logger = logging.getLogger("mfa")


def threshold_lattice_to_dfa(
    lattice: pynini.Fst, threshold: float = 1.0, state_multiplier: int = 2
) -> pynini.Fst:
    """Determinize a lattice of output strings into a (possibly pruned) weighted DFA.

    The lattice (such as the one produced by ``rewrite_lattice``) is handed to
    :func:`pynini.determinize` together with a weight threshold and a cap on the
    number of states.  The state cap guards against blowup during
    determinization; it is computed as
    ``256 + state_multiplier * lattice.num_states()`` and is meant as a sensible
    default rather than an inherently meaningful value.

    Parameters
    ----------
    lattice: :class:`~pynini.Fst`
        Epsilon-free non-deterministic finite acceptor.
    threshold: float
        Value used to build the pruning weight in the lattice's own weight type;
        1.0 is intended to keep optimal paths only, larger values keep more paths
    state_multiplier: int
        Multiplier applied to the number of states of the input lattice when
        computing the state cap for determinization.

    Returns
    -------
    :class:`~pynini.Fst`
        Result of the determinization.
    """
    state_limit = state_multiplier * lattice.num_states() + 256
    pruning_weight = pynini.Weight(lattice.weight_type(), threshold)
    return pynini.determinize(lattice, nstate=state_limit, weight=pruning_weight)


def generate_scored_paths(lattice, output_token_type) -> typing.List[GeneratedPronunciation]:
    """Collect every path of a lattice as a scored pronunciation.

    Parameters
    ----------
    lattice
        Object exposing ``paths(output_token_type=...)``, e.g. a :class:`~pynini.Fst`
    output_token_type
        Token type or symbol table used to render the output strings

    Returns
    -------
    list[GeneratedPronunciation]
        One entry per path, in iteration order, holding the output string and
        the path weight converted to ``float``
    """
    scored = []
    path_iterator = lattice.paths(output_token_type=output_token_type)
    while True:
        if path_iterator.done():
            break
        weight = float(path_iterator.weight())
        scored.append(GeneratedPronunciation(path_iterator.ostring(), weight))
        path_iterator.next()
    return scored


def optimal_rewrites(
    string: pynini.FstLike,
    rule: pynini.Fst,
    input_token_type: typing.Optional[TokenType] = None,
    output_token_type: typing.Optional[TokenType] = None,
    threshold: float = 1,
) -> typing.List[GeneratedPronunciation]:
    """Return the rewrites that survive threshold pruning.

    The rewrite lattice for ``string`` is determinized and pruned through
    :func:`threshold_lattice_to_dfa` (with a state multiplier of 4), and every
    remaining path is returned with its score.

    Parameters
    ----------
    string: pynini.FstLike
        Input string or FST.
    rule: pynini.Fst
        Input rule WFST.
    input_token_type: pynini.TokenType, optional
        Optional input token type, or symbol table.
    output_token_type: pynini.TokenType, optional
        Optional output token type, or symbol table.
    threshold: float
        Threshold for weights passed on to the pruning step

    Returns
    -------
    list[GeneratedPronunciation]
        Scored output strings.
    """
    pruned = threshold_lattice_to_dfa(
        rewrite.rewrite_lattice(string, rule, input_token_type), threshold, 4
    )
    return generate_scored_paths(pruned, output_token_type)


def scored_top_rewrites(
    string: pynini.FstLike,
    rule: pynini.Fst,
    nshortest: int,
    input_token_type: typing.Optional[pynini.TokenType] = None,
    output_token_type: typing.Optional[pynini.TokenType] = None,
) -> typing.List[GeneratedPronunciation]:
    """Return the top n rewrites together with their scores.

    Parameters
    ----------
    string: pynini.FstLike
        Input string or FST.
    rule: pynini.Fst
        Input rule WFST.
    nshortest: int
        The maximum number of rewrites to return.
    input_token_type: pynini.TokenType, optional
        Optional input token type, or symbol table.
    output_token_type: pynini.TokenType, optional
        Optional output token type, or symbol table.

    Returns
    -------
    list[GeneratedPronunciation]
        Scored output strings.
    """
    shortest = rewrite.lattice_to_nshortest(
        rewrite.rewrite_lattice(string, rule, input_token_type), nshortest
    )
    return generate_scored_paths(shortest, output_token_type)


class Rewriter:
    """
    Callable helper that rewrites grapheme strings into scored pronunciations

    Parameters
    ----------
    fst: pynini.Fst
        G2P FST model
    grapheme_symbol_table: pynini.SymbolTable
        Token type passed to :func:`pynini.accep` when building word acceptors
    phone_symbol_table: pynini.SymbolTable
        Phone symbol table, used as the output token type of the rewrite
    num_pronunciations: int
        Number of pronunciations, defaults to 0.  Thresholding is used unless this is positive
    threshold: float
        Threshold to use for pruning rewrite lattice, defaults to 1, only used if
        num_pronunciations is not positive
    graphemes: set[str], optional
        Allowed graphemes; when given, characters outside the set are removed from words
    strict: bool
        When set together with ``graphemes``, words containing a character outside
        the set are rejected instead of being filtered
    unicode_decomposition: bool
        Flag for applying NFKD normalization to the input before rewriting
    """

    def __init__(
        self,
        fst: Fst,
        grapheme_symbol_table: SymbolTable,
        phone_symbol_table: SymbolTable,
        num_pronunciations: int = 0,
        threshold: float = 1,
        graphemes: typing.Set[str] = None,
        strict: bool = False,
        unicode_decomposition: bool = False,
    ):
        self.graphemes = graphemes
        self.grapheme_symbol_table = grapheme_symbol_table
        self.phone_symbol_table = phone_symbol_table
        self.strict = strict
        self.unicode_decomposition = unicode_decomposition
        # Keyword arguments common to both rewriting strategies
        shared_kwargs = {
            "rule": fst,
            "input_token_type": None,
            "output_token_type": self.phone_symbol_table,
        }
        if num_pronunciations > 0:
            self.rewrite = functools.partial(
                scored_top_rewrites, nshortest=num_pronunciations, **shared_kwargs
            )
        else:
            self.rewrite = functools.partial(
                optimal_rewrites, threshold=threshold, **shared_kwargs
            )

    def create_word_fst(self, word: str) -> typing.Optional[pynini.Fst]:
        """
        Build the acceptor for a single word

        Returns None when the rewriter is strict and the word has a character
        outside of the allowed grapheme set.
        """
        allowed = self.graphemes
        if allowed is not None:
            if self.strict:
                for character in word:
                    if character not in allowed:
                        return None
            word = "".join(character for character in word if character in allowed)
        return pynini.accep(word, token_type=self.grapheme_symbol_table)

    def __call__(self, graphemes: str) -> typing.List[GeneratedPronunciation]:  # pragma: no cover
        """Rewrite a word, or a space-separated sequence of words, into pronunciations"""
        if self.unicode_decomposition:
            graphemes = unicodedata.normalize("NFKD", graphemes)
        if " " not in graphemes:
            fst = self.create_word_fst(graphemes)
            if not fst:
                return []
            hypotheses = self.rewrite(fst)
        else:
            # Rewrite each word on its own, then combine the per-word hypotheses
            per_word = []
            for token in graphemes.split():
                token_fst = self.create_word_fst(token)
                if token_fst:
                    per_word.append(self.rewrite(token_fst))
            hypotheses = []
            for combination in itertools.product(*per_word):
                text = " ".join(part.pronunciation for part in combination)
                total = sum(part.score for part in combination)
                hypotheses.append(GeneratedPronunciation(text, total))
            hypotheses.sort(key=lambda candidate: candidate.score)
        return [candidate for candidate in hypotheses if candidate.pronunciation]


class PhonetisaurusRewriter(Rewriter):
    """
    Rewriter that builds word acceptors over multi-grapheme symbols

    Parameters
    ----------
    fst: pynini.Fst
        G2P FST model
    grapheme_symbol_table: pynini.SymbolTable
        Grapheme symbol table
    phone_symbol_table: pynini.SymbolTable
        Phone symbol table
    num_pronunciations: int
        Number of pronunciations, default to 0.  If this is 0, thresholding is used
    threshold: float
        Threshold to use for pruning rewrite lattice, defaults to 1.5, only used if num_pronunciations is 0
    grapheme_order: int
        Maximum number of graphemes to consider single segment
    sequence_separator: str
        Separator to use between grapheme symbols
    graphemes: set[str], optional
        Allowed graphemes; when given, characters outside the set are removed from words
    strict: bool
        When set together with ``graphemes``, words with a character that is not a
        member of the grapheme symbol table are rejected
    unicode_decomposition: bool
        Flag for applying NFKD normalization to the input before rewriting
    """

    def __init__(
        self,
        fst: Fst,
        grapheme_symbol_table: SymbolTable,
        phone_symbol_table: SymbolTable,
        num_pronunciations: int = 0,
        threshold: float = 1.5,
        grapheme_order: int = 2,
        sequence_separator: str = "|",
        graphemes: typing.Set[str] = None,
        strict: bool = False,
        unicode_decomposition: bool = False,
    ):
        super().__init__(
            fst=fst,
            grapheme_symbol_table=grapheme_symbol_table,
            phone_symbol_table=phone_symbol_table,
            num_pronunciations=num_pronunciations,
            threshold=threshold,
            graphemes=graphemes,
            strict=strict,
            unicode_decomposition=unicode_decomposition,
        )
        self.sequence_separator = sequence_separator
        self.grapheme_order = grapheme_order

    def create_word_fst(self, word: str) -> typing.Optional[pynini.Fst]:
        """
        Build an acceptor whose arcs cover spans of up to ``grapheme_order`` graphemes

        Returns None when the word is rejected by the strict check or has no
        graphemes left after filtering.
        """
        if self.graphemes is not None:
            if self.strict:
                for character in word:
                    if not self.grapheme_symbol_table.member(character):
                        return None
            word = [character for character in word if character in self.graphemes]
        if not word:
            return None
        word_length = len(word)
        fst = pynini.Fst()
        one = pywrapfst.Weight.one(fst.weight_type())
        furthest = 0
        # One state per grapheme position; an arc from position p to position q
        # consumes the graphemes in word[p:q] when that span is a known symbol
        for position in range(word_length):
            source = fst.add_state()
            for span in range(1, self.grapheme_order + 1):
                target = position + span
                if target > word_length:
                    # Longer spans starting here would run past the word as well
                    break
                symbol = self.sequence_separator.join(word[position:target])
                label = self.grapheme_symbol_table.find(symbol)
                if label != pywrapfst.NO_LABEL:
                    fst.add_arc(source, pywrapfst.Arc(label, label, one, target))
                furthest = max(furthest, target)
        # Add the states that are only ever reached as arc targets
        while fst.num_states() <= furthest:
            fst.add_state()
        fst.set_start(0)
        fst.set_final(word_length, one)
        fst.set_input_symbols(self.grapheme_symbol_table)
        fst.set_output_symbols(self.grapheme_symbol_table)
        return fst

    def __call__(self, graphemes: str) -> typing.List[GeneratedPronunciation]:  # pragma: no cover
        """Rewrite the input and replace sequence separators in the output with spaces"""
        converted = []
        for hypothesis in super().__call__(graphemes):
            if not hypothesis.pronunciation:
                continue
            spaced = hypothesis.pronunciation.replace(self.sequence_separator, " ")
            converted.append(GeneratedPronunciation(spaced, hypothesis.score))
        return converted


class RewriterWorker(mp.Process):
    """
    Rewriter process

    Parameters
    ----------
    job_queue: :class:`~multiprocessing.Queue`
        Queue to pull words from
    return_queue: :class:`~multiprocessing.Queue`
        Queue to put ``(word, pronunciations)`` tuples on; on an unexpected error
        a ``(word, exception)`` tuple is put instead
    rewriter: :class:`~montreal_forced_aligner.g2p.generator.Rewriter`
        Function to generate pronunciations of words
    stopped: :class:`~multiprocessing.Event`
        Stop check; once set, remaining jobs are drained without being processed
    finished_adding: :class:`~multiprocessing.Event`
        Set by the producer once no more words will be added to ``job_queue``

    Attributes
    ----------
    finished: :class:`~multiprocessing.Event`
        Set when :meth:`run` exits, whether normally or through an exception
    """

    def __init__(
        self,
        job_queue: mp.Queue,
        return_queue: mp.Queue,
        rewriter: Rewriter,
        stopped: mp.Event,
        finished_adding: mp.Event,
    ):
        super().__init__()
        self.job_queue = job_queue
        self.return_queue = return_queue
        self.rewriter = rewriter
        self.stopped = stopped
        self.finished_adding = finished_adding
        self.finished = mp.Event()

    def run(self) -> None:
        """Run the rewriting function"""
        try:
            while True:
                try:
                    word = self.job_queue.get(timeout=0.01)
                except queue.Empty:
                    # An empty queue only means "done" once the producer has finished
                    if not self.finished_adding.is_set():
                        continue
                    break
                if self.stopped.is_set():
                    # Keep draining the queue, but do no further work
                    continue
                try:
                    rep = self.rewriter(word)
                    self.return_queue.put((word, rep))
                except rewrite.Error:
                    # Best effort: words that cannot be rewritten are skipped
                    pass
                except Exception as e:  # noqa
                    self.stopped.set()
                    # Report in the same (word, result) shape as successful results
                    # so the consumer can unpack every item uniformly
                    self.return_queue.put((word, e))
                    raise
        finally:
            # Always signal completion, otherwise a consumer polling ``finished``
            # would wait forever after this process dies on an error
            self.finished.set()


@dataclass
class G2PArguments(MfaArguments):
    """
    Arguments for :class:`~montreal_forced_aligner.g2p.generator.G2PFunction`

    Parameters
    ----------
    job_name: int
        Integer ID of the job
    session: :class:`sqlalchemy.orm.scoped_session` or str
        SqlAlchemy scoped session or string for database connections
    log_path: :class:`~pathlib.Path`
        Path to save logging information during the run
    rewriter: :class:`~montreal_forced_aligner.g2p.generator.Rewriter`
        Callable used to rewrite graphemes to phones
    """

    rewriter: Rewriter


class G2PFunction(KaldiFunction):
    """
    Multiprocessing function for generating a pronunciation per utterance

    Parameters
    ----------
    args: :class:`~montreal_forced_aligner.g2p.generator.G2PArguments`
        Arguments for the function, including the rewriter to apply
    """

    def __init__(self, args: G2PArguments):
        super().__init__(args)
        self.rewriter = args.rewriter

    def _run(self):
        """
        Run the function

        For every utterance of this job with non-empty normalized text, the
        top-ranked hypothesis of the rewriter is passed to ``self.callback`` as
        ``(utterance_id, hypothesis)``.  Utterances that cannot be rewritten, or
        that yield no hypothesis at all, are recorded in the log file and skipped.
        """

        with mfa_open(self.log_path, "w") as log_file, self.session() as session:
            query = (
                session.query(Utterance.id, Utterance.normalized_text)
                .filter(Utterance.job_id == self.job_name)
                .filter(Utterance.normalized_text != "")
            )
            for utterance_id, text in query:
                try:
                    hypotheses = self.rewriter(text)
                except rewrite.Error:
                    log_file.write(f"Error on generating pronunciation for {text}\n")
                    continue
                if not hypotheses:
                    # The rewriter returns an empty list when nothing usable is left
                    # of the text; indexing it would raise IndexError and kill the job
                    log_file.write(f"No pronunciation generated for {text}\n")
                    continue
                self.callback((utterance_id, hypotheses[0]))


def clean_up_word(word: str, graphemes: typing.Set[str]) -> typing.Tuple[str, typing.Set[str]]:
    """
    Strip a word of every character that is not an allowed grapheme

    Parameters
    ----------
    word : str
        Input string
    graphemes: set[str]
        Set of allowable graphemes

    Returns
    -------
    str
        Word with only the allowed characters, in their original order
    Set[str]
        Characters of the word that were not in ``graphemes``
    """
    kept = [character for character in word if character in graphemes]
    dropped = {character for character in word if character not in graphemes}
    return "".join(kept), dropped


class OrthographyGenerator(G2PTopLevelMixin):
    """
    Abstract mixin class for generating "pronunciations" based off the orthographic word

    See Also
    --------
    :class:`~montreal_forced_aligner.g2p.mixins.G2PTopLevelMixin`
        For top level G2P generation parameters
    """

    def generate_pronunciations(
        self,
    ) -> typing.Generator[typing.Tuple[str, typing.List[typing.Tuple[str, float]]]]:
        """
        Generate pronunciations for the word set

        Each word is spelled out character by character, separated by spaces,
        and paired with a score of 0.0.

        Yields
        -------
        str, list[tuple[str, float]]]
            Tuple of word with their scored pronunciations
        """
        for word in self.words_to_g2p:
            spelled_out = " ".join(word)
            scored = [(spelled_out, 0.0)]
            yield word, scored


class PyniniGenerator(G2PTopLevelMixin):
    """
    Class for generating pronunciations from a Pynini G2P model

    Parameters
    ----------
    word_list: list[str], optional
        Words to generate pronunciations for, defaults to an empty list
    g2p_model_path: :class:`~pathlib.Path`
        Path to G2P model
    strict_graphemes: bool
        Flag for whether to be strict with missing graphemes and skip words containing new graphemes

    See Also
    --------
    :class:`~montreal_forced_aligner.g2p.mixins.G2PTopLevelMixin`
        For top level G2P generation parameters

    Attributes
    ----------
    g2p_model: G2PModel
        G2P model
    rewriter: Rewriter, optional
        Rewriter built by :meth:`setup`; None until then
    """

    def __init__(
        self,
        word_list: typing.List[str] = None,
        g2p_model_path: Path = None,
        strict_graphemes: bool = False,
        **kwargs,
    ):
        self.strict_graphemes = strict_graphemes
        super().__init__(**kwargs)
        self.g2p_model = G2PModel(
            g2p_model_path, root_directory=getattr(self, "workflow_directory", None)
        )
        self.output_token_type = "utf8"
        self.input_token_type = "utf8"
        self.rewriter = None
        if word_list is None:
            word_list = []
        self.word_list = word_list

    @property
    def words_to_g2p(self) -> typing.List[str]:
        """Words to produce pronunciations"""
        return self.word_list

    @property
    def data_source_identifier(self) -> str:
        """Dummy "validation" data source"""
        return "validation"

    @property
    def working_directory(self) -> typing.Optional[Path]:
        """Working directory, None for the base generator"""
        return None

    @property
    def data_directory(self) -> Path:
        """Data directory, same as the working directory"""
        return self.working_directory

    def setup(self):
        """Load the FST of the G2P model and build the rewriter matching its architecture"""
        self.fst = pynini.Fst.read(self.g2p_model.fst_path)
        if self.g2p_model.meta["architecture"] == "phonetisaurus":
            # Phonetisaurus models ship symbol tables for both graphemes and phones
            self.output_token_type = pywrapfst.SymbolTable.read_text(self.g2p_model.sym_path)
            self.input_token_type = pywrapfst.SymbolTable.read_text(
                self.g2p_model.grapheme_sym_path
            )
            self.fst.set_input_symbols(self.input_token_type)
            self.fst.set_output_symbols(self.output_token_type)
            self.rewriter = PhonetisaurusRewriter(
                self.fst,
                self.input_token_type,
                self.output_token_type,
                num_pronunciations=self.num_pronunciations,
                threshold=self.g2p_threshold,
                grapheme_order=self.g2p_model.meta["grapheme_order"],
                graphemes=self.g2p_model.meta["graphemes"],
            )
        else:
            # The phone symbol table is optional here; "utf8" stays in place without it
            if self.g2p_model.sym_path is not None and os.path.exists(self.g2p_model.sym_path):
                self.output_token_type = pywrapfst.SymbolTable.read_text(self.g2p_model.sym_path)

            self.rewriter = Rewriter(
                self.fst,
                self.input_token_type,
                self.output_token_type,
                num_pronunciations=self.num_pronunciations,
                threshold=self.g2p_threshold,
                graphemes=self.g2p_model.meta["graphemes"],
            )

    @staticmethod
    def _update_score_range(
        prons: typing.List[GeneratedPronunciation],
        min_score: typing.Optional[float],
        max_score: typing.Optional[float],
    ) -> typing.Tuple[typing.Optional[float], typing.Optional[float]]:
        """Fold the scores of ``prons`` into the running (minimum, maximum) pair"""
        scores = [x.score for x in prons]
        if min_score is not None:
            scores.append(min_score)
        if max_score is not None:
            scores.append(max_score)
        if not scores:
            # No pronunciations yet and nothing to compare against
            return min_score, max_score
        return min(scores), max(scores)

    def generate_pronunciations(
        self,
    ) -> typing.Generator[typing.Tuple[str, typing.List[GeneratedPronunciation]]]:
        """
        Generate pronunciations

        Words are cleaned of graphemes unknown to the model first; words that end
        up empty (or, with ``strict_graphemes``, that lost any grapheme) are
        skipped.  Rewriting happens in the current process, or in
        :class:`RewriterWorker` processes when multiprocessing is enabled and
        there are enough words.

        Yields
        -------
        str, list[GeneratedPronunciation]
            Tuple of word with their scored pronunciations

        Raises
        ------
        :class:`~montreal_forced_aligner.exceptions.PyniniGenerationError`
            If any worker process reported an unexpected error
        """

        num_words = len(self.words_to_g2p)
        begin = time.time()
        missing_graphemes = set()
        if self.rewriter is None:
            self.setup()
        logger.info("Generating pronunciations...")
        skipped_words = 0
        min_score = None
        max_score = None
        graphemes = self.g2p_model.meta["graphemes"]
        if not config.USE_MP or num_words < 30 or config.NUM_JOBS == 1:
            with tqdm(total=num_words, disable=config.QUIET) as pbar:
                for word in self.words_to_g2p:
                    w, m = clean_up_word(word, graphemes)
                    pbar.update(1)
                    missing_graphemes = missing_graphemes | m
                    if self.strict_graphemes and m:
                        skipped_words += 1
                        continue
                    if not w:
                        skipped_words += 1
                        continue
                    try:
                        prons = self.rewriter(w)
                    except rewrite.Error:
                        continue
                    min_score, max_score = self._update_score_range(
                        prons, min_score, max_score
                    )
                    yield w, prons
                logger.debug(
                    f"Skipping {skipped_words} words for containing the following graphemes: "
                    f"{comma_join(sorted(missing_graphemes))}"
                )
        else:
            stopped = mp.Event()
            finished_adding = mp.Event()
            job_queue = mp.Queue()
            return_queue = mp.Queue()
            procs = []
            for _ in range(config.NUM_JOBS):
                p = RewriterWorker(
                    job_queue,
                    return_queue,
                    self.rewriter,
                    stopped,
                    finished_adding,
                )
                procs.append(p)
                p.start()
            for word in self.words_to_g2p:
                w, m = clean_up_word(word, graphemes)
                missing_graphemes = missing_graphemes | m
                if self.strict_graphemes and m:
                    skipped_words += 1
                    continue
                if not w:
                    skipped_words += 1
                    continue
                job_queue.put(w)
            logger.debug(
                f"Skipping {skipped_words} words for containing the following graphemes: "
                f"{comma_join(sorted(missing_graphemes))}"
            )
            finished_adding.set()
            error_dict = {}
            num_words -= skipped_words
            with tqdm(total=num_words, disable=config.QUIET) as pbar:
                while True:
                    try:
                        item = return_queue.get(timeout=1)
                    except queue.Empty:
                        # A worker that died without flagging completion must not
                        # keep this loop polling forever
                        if all(
                            proc.finished.is_set() or not proc.is_alive() for proc in procs
                        ):
                            break
                        continue
                    if isinstance(item, Exception):
                        # Error reported without the word that triggered it
                        error_dict[f"<worker error {len(error_dict) + 1}>"] = item
                        continue
                    word, result = item
                    if isinstance(result, Exception):
                        # Record errors before the stop check below; the worker sets
                        # ``stopped`` right before reporting, so checking it first
                        # would drop the very error that caused the stop
                        error_dict[word] = result
                        continue
                    if stopped.is_set():
                        continue
                    pbar.update(1)
                    min_score, max_score = self._update_score_range(
                        result, min_score, max_score
                    )
                    yield word, result

            for p in procs:
                p.join()
            if error_dict:
                raise PyniniGenerationError(error_dict)
        logger.debug(f"Minimum score: {min_score}")
        logger.debug(f"Maximum score: {max_score}")
        logger.debug(f"Processed {num_words} in {time.time() - begin:.3f} seconds")

    def generate_dict_pronunciations(self):
        """Generate pronunciations and return them as a dict of word to pronunciation strings"""
        return {k: [x.pronunciation for x in v] for k, v in self.generate_pronunciations()}


class PyniniConsoleGenerator(PyniniGenerator):
    """Pynini generator variant with a fixed data directory of ``-`` and a no-op cleanup"""

    @property
    def data_directory(self) -> Path:
        """Placeholder path ``-``"""
        placeholder = Path("-")
        return placeholder

    @property
    def working_directory(self) -> Path:
        """``g2p_stdin`` folder inside the configured temporary directory"""
        temporary_root = config.TEMPORARY_DIRECTORY
        return temporary_root.joinpath("g2p_stdin")

    def cleanup(self) -> None:
        """Do nothing"""
        return None


class PyniniValidator(PyniniGenerator, TopLevelMfaWorker):
    """
    Class for running validation for G2P model training

    Parameters
    ----------
    word_list: list[str]
        List of words to generate pronunciations

    See Also
    --------
    :class:`~montreal_forced_aligner.g2p.generator.PyniniGenerator`
        For parameters to generate pronunciations

    Attributes
    ----------
    wer: float, optional
        Word error rate in percent, set by :meth:`compute_validation_errors`
    ler: float, optional
        Label error rate in percent, set by :meth:`compute_validation_errors`
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @property
    def evaluation_csv_path(self) -> Path:
        """Path to working directory's CSV file"""
        return self.working_directory.joinpath("pronunciation_evaluation.csv")

    @property
    def working_directory(self) -> Path:
        """Working directory, named after the current workflow inside the output directory"""
        return self.output_directory.joinpath(self._current_workflow)

    def setup(self) -> None:
        """Set up the G2P validator"""
        TopLevelMfaWorker.setup(self)
        if self.initialized:
            return
        self._current_workflow = "validation"
        os.makedirs(self.working_log_directory, exist_ok=True)
        self.g2p_model.validate(self.words_to_g2p)
        PyniniGenerator.setup(self)
        self.initialized = True
        self.wer = None
        self.ler = None

    @staticmethod
    def _evaluation_row(
        word: str,
        gold_pronunciations: typing.Iterable[str],
        hypotheses: typing.Iterable[str],
        accuracy: int,
        error_rate: float,
        length: float,
    ) -> typing.Dict[str, typing.Any]:
        """Build one row of the evaluation CSV"""
        return {
            "Word": word,
            "Gold pronunciations": ", ".join(gold_pronunciations),
            "Hypothesis pronunciations": ", ".join(hypotheses),
            "Accuracy": accuracy,
            "Error rate": error_rate,
            "Length": length,
        }

    def compute_validation_errors(
        self,
        gold_values: typing.Dict[str, typing.Set[str]],
        hypothesis_values: typing.Dict[str, typing.List[str]],
    ):
        """
        Computes validation errors

        Writes one row per gold word to :attr:`evaluation_csv_path` and stores the
        word error rate and label error rate (both in percent) on ``self.wer`` and
        ``self.ler``.  A word counts as correct when any of its hypotheses is one
        of its gold pronunciations.

        Parameters
        ----------
        gold_values: dict[str, set[str]]
            Gold pronunciations
        hypothesis_values: dict[str, list[str]]
            Hypothesis pronunciations
        """
        begin = time.time()
        # Word-level measures.
        correct = 0
        incorrect = 0
        # Label-level measures.
        total_edits = 0
        total_length = 0
        # Since the edit distance algorithm is quadratic, let's do this with
        # multiprocessing.
        logger.debug("Processing results for hypotheses")
        to_comp = []
        indices = []
        hyp_pron_count = 0
        gold_pron_count = 0
        output = []
        for word, gold_pronunciations in gold_values.items():
            if word not in hypothesis_values:
                # Nothing was generated: every gold label counts as an edit
                incorrect += 1
                gold_length = statistics.mean(len(x.split()) for x in gold_pronunciations)
                total_edits += gold_length
                total_length += gold_length
                output.append(
                    self._evaluation_row(word, gold_pronunciations, (), 0, 1.0, gold_length)
                )
                continue
            hyp = hypothesis_values[word]
            if not isinstance(hyp, list):
                hyp = [hyp]
            for h in hyp:
                if h in gold_pronunciations:
                    correct += 1
                    # Lengths are counted in labels (phones), as everywhere else in
                    # this method, not in characters of the pronunciation string
                    matched_length = len(h.split())
                    total_length += matched_length
                    output.append(
                        self._evaluation_row(
                            word, gold_pronunciations, hyp, 1, 0.0, matched_length
                        )
                    )
                    break
            else:
                incorrect += 1
                indices.append(word)
                to_comp.append((gold_pronunciations, hyp))  # Multiple hypotheses to compare
                logger.debug(
                    f"Incorrect: for the word {word}: gold is {gold_pronunciations}, hypothesized are: {hyp}"
                )
            hyp_pron_count += len(hyp)
            gold_pron_count += len(gold_pronunciations)
        if hypothesis_values:
            logger.debug(
                f"Generated an average of {hyp_pron_count / len(hypothesis_values)} variants "
                f"The gold set had an average of {gold_pron_count / len(hypothesis_values)} variants."
            )
        with ThreadPool(config.NUM_JOBS) as pool:
            scored = pool.starmap(score_g2p, to_comp)
        for word, (gold_pronunciations, hyp), (edits, length) in zip(indices, to_comp, scored):
            # These are the words without any matching hypothesis, so accuracy is 0
            output.append(
                self._evaluation_row(word, gold_pronunciations, hyp, 0, edits / length, length)
            )
            total_edits += edits
            total_length += length
        with mfa_open(self.evaluation_csv_path, "w") as f:
            writer = csv.DictWriter(
                f,
                fieldnames=[
                    "Word",
                    "Gold pronunciations",
                    "Hypothesis pronunciations",
                    "Accuracy",
                    "Error rate",
                    "Length",
                ],
            )
            writer.writeheader()
            for line in output:
                writer.writerow(line)
        total_words = correct + incorrect
        if not total_words or not total_length:
            logger.warning("No labels were available for evaluation; error rates default to 0")
        # Guard the divisions so an empty evaluation set does not raise ZeroDivisionError
        self.wer = 100 * incorrect / total_words if total_words else 0.0
        self.ler = 100 * total_edits / total_length if total_length else 0.0
        logger.info(f"WER:\t{self.wer:.2f}")
        logger.info(f"LER:\t{self.ler:.2f}")
        logger.debug(
            f"Computation of errors for {len(gold_values)} words took {time.time() - begin:.3f} seconds"
        )

    def evaluate_g2p_model(self, gold_pronunciations: typing.Dict[str, typing.Set[str]]) -> None:
        """
        Evaluate a G2P model on the word list

        Parameters
        ----------
        gold_pronunciations: dict[str, set[str]]
            Gold pronunciations
        """
        hypotheses = {
            w: [x.pronunciation for x in prons] for w, prons in self.generate_pronunciations()
        }
        self.compute_validation_errors(gold_pronunciations, hypotheses)


class PyniniWordListGenerator(PyniniValidator, DatabaseMixin):
    """
    Top-level worker for generating pronunciations from a word list and a Pynini G2P model

    Parameters
    ----------
    word_list_path: :class:`~pathlib.Path`
        Path to word list file

    See Also
    --------
    :class:`~montreal_forced_aligner.g2p.generator.PyniniGenerator`
        For Pynini G2P generation parameters
    :class:`~montreal_forced_aligner.abc.TopLevelMfaWorker`
        For top-level parameters

    Attributes
    ----------
    word_list: list[str]
        Word list to generate pronunciations
    """

    def __init__(self, word_list_path: Path, **kwargs):
        self.word_list_path = word_list_path
        super().__init__(**kwargs)

    @property
    def data_directory(self) -> Path:
        """Data directory, same as the working directory"""
        return self.working_directory

    @property
    def data_source_identifier(self) -> str:
        """Name of the word list file without its extension"""
        file_name = os.path.basename(self.word_list_path)
        stem, _extension = os.path.splitext(file_name)
        return stem

    def setup(self) -> None:
        """Read the word list file and set up the G2P generator"""
        if self.initialized:
            return
        with mfa_open(self.word_list_path, "r") as f:
            # Every whitespace-separated token of every line is a word
            for line in f:
                for token in line.strip().split():
                    self.word_list.append(token)
        if not self.include_bracketed:
            unbracketed = []
            for candidate in self.word_list:
                if not self.check_bracketed(candidate):
                    unbracketed.append(candidate)
            self.word_list = unbracketed
        super().setup()
        self.g2p_model.validate(self.words_to_g2p)
        self.initialized = True


class PyniniCorpusGenerator(PyniniGenerator, TextCorpusMixin, TopLevelMfaWorker):
    """
    Top-level worker for generating pronunciations from a corpus and a Pynini G2P model

    Parameters
    ----------
    per_utterance: bool
        When True, :meth:`export_pronunciations` writes one transcription file per
        corpus file (via :meth:`export_file_pronunciations`) instead of delegating to
        the parent class export; defaults to False

    See Also
    --------
    :class:`~montreal_forced_aligner.g2p.generator.PyniniGenerator`
        For Pynini G2P generation parameters
    :class:`~montreal_forced_aligner.corpus.text_corpus.TextCorpusMixin`
        For corpus parsing parameters
    :class:`~montreal_forced_aligner.abc.TopLevelMfaWorker`
        For top-level parameters
    """

    def __init__(self, per_utterance: bool = False, **kwargs):
        super().__init__(**kwargs)
        self.per_utterance = per_utterance

    def setup(self) -> None:
        """Set up the pronunciation generator"""
        if self.initialized:
            return
        self._load_corpus()
        self.initialize_jobs()
        super().setup()
        self._create_dummy_dictionary()
        self.normalize_text()
        self.create_new_current_workflow(WorkflowType.g2p)
        self.g2p_model.validate(self.words_to_g2p)
        self.initialized = True

    @property
    def working_directory(self) -> Path:
        """The ``g2p`` subdirectory of the output directory"""
        return self.output_directory.joinpath("g2p")

    def g2p_arguments(self) -> typing.List[G2PArguments]:
        """
        Generate one set of G2P arguments per job

        Returns
        -------
        list[G2PArguments]
            Arguments for each entry in ``self.jobs``
        """
        return [
            G2PArguments(
                j.id,
                # Attribute name depends on the threading configuration; falls back to ""
                getattr(self, "session" if config.USE_THREADING else "db_string", ""),
                self.working_log_directory.joinpath(f"g2p_utterances.{j.id}.log"),
                self.rewriter,
            )
            for j in self.jobs
        ]

    def export_file_pronunciations(self, output_file_path: Path):
        """
        Generate and export per-utterance G2P

        Each utterance's ``transcription_text`` is updated with its generated
        pronunciation, then one output file is written per corpus file. Files without
        any utterances are skipped.

        Parameters
        ----------
        output_file_path: :class:`~pathlib.Path`
            Output directory to save utterance pronunciations

        """
        # export_pronunciations accepts str as well, so normalize before calling Path methods
        output_file_path = Path(output_file_path)
        output_file_path.mkdir(parents=True, exist_ok=True)
        if self.num_pronunciations != 1:
            logger.warning(
                "Number of pronunciations is hard-coded to 1 for generating per-utterance pronunciations"
            )
            self.num_pronunciations = 1
        begin = time.time()
        if self.rewriter is None:
            self.setup()
        logger.info("Generating pronunciations...")
        update_mapping = []
        for utt_id, generated in run_kaldi_function(
            G2PFunction, self.g2p_arguments(), total_count=self.num_utterances
        ):
            update_mapping.append({"id": utt_id, "transcription_text": generated.pronunciation})
        with self.session() as session:
            bulk_update(session, Utterance, update_mapping)
            session.commit()

        logger.debug(f"Processed {self.num_utterances} in {time.time() - begin:.3f} seconds")
        logger.info("Exporting files...")
        with self.session() as session:
            files = session.query(File).options(
                selectinload(File.utterances), selectinload(File.speakers)
            )
            for file in files:
                utterance_count = len(file.utterances)
                if utterance_count == 0:
                    # Must be checked before indexing file.utterances below, and must skip
                    # the export so output_format is never unbound or stale from a prior file
                    logger.debug(f"Could not find any utterances for {file.name}")
                    continue
                if file.sound_file is not None:
                    duration = file.sound_file.duration
                else:
                    # No audio available, fall back to the end of the last utterance
                    duration = file.utterances[-1].end

                if (
                    utterance_count == 1
                    and file.utterances[0].begin == 0
                    and file.utterances[0].end == duration
                ):
                    # A single utterance spanning the whole file is written as plain text
                    output_format = "lab"
                else:
                    output_format = TextgridFormats.SHORT_TEXTGRID
                output_path = construct_output_path(
                    file.name,
                    file.relative_path,
                    output_file_path,
                    output_format=output_format,
                )
                data = file.construct_transcription_tiers()
                if output_format == "lab":
                    for intervals in data.values():
                        with mfa_open(output_path, "w") as f:
                            f.write(intervals["transcription"][0].label)
                else:
                    tg = textgrid.Textgrid()
                    tg.minTimestamp = 0
                    tg.maxTimestamp = round(duration, 5)
                    for speaker in file.speakers:
                        speaker_name = speaker.name
                        intervals = data[speaker_name]["transcription"]
                        tier = textgrid.IntervalTier(
                            speaker_name,
                            [to_tg_interval(x, tg.maxTimestamp) for x in intervals],
                            minT=0,
                            maxT=round(duration, 5),
                        )

                        tg.addTier(tier)
                    tg.save(output_path, includeBlankSpaces=True, format=output_format)

    @property
    def words_to_g2p(self) -> typing.List[str]:
        """Words to produce pronunciations"""
        word_list = self.corpus_word_set
        if not self.include_bracketed:
            word_list = [x for x in word_list if not self.check_bracketed(x)]
        return word_list

    def export_pronunciations(
        self,
        output_file_path: typing.Union[str, Path],
        export_scores: bool = False,
        ensure_sorted: bool = False,
    ) -> None:
        """
        Export generated pronunciations

        Parameters
        ----------
        output_file_path: str or :class:`~pathlib.Path`
            Output directory when ``per_utterance`` is set, otherwise passed through to
            the parent class export
        export_scores: bool
            Passed through to the parent class export; unused when ``per_utterance`` is set
        ensure_sorted: bool
            Passed through to the parent class export; unused when ``per_utterance`` is set
        """
        if self.per_utterance:
            self.export_file_pronunciations(output_file_path)
        else:
            super().export_pronunciations(
                output_file_path, export_scores=export_scores, ensure_sorted=ensure_sorted
            )


class PyniniDictionaryCorpusGenerator(
    PyniniGenerator, DictionaryTextCorpusMixin, TopLevelMfaWorker
):
    """
    Top-level worker for generating pronunciations from a corpus and a Pynini G2P model

    See Also
    --------
    :class:`~montreal_forced_aligner.g2p.generator.PyniniGenerator`
        For Pynini G2P generation parameters
    :class:`~montreal_forced_aligner.corpus.text_corpus.DictionaryTextCorpusMixin`
        For corpus parsing parameters
    :class:`~montreal_forced_aligner.abc.TopLevelMfaWorker`
        For top-level parameters
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Cache for words_to_g2p, filled on first access
        self._word_list = None

    @property
    def data_source_identifier(self) -> str:
        """Corpus name"""
        return os.path.basename(self.corpus_directory)

    def setup(self) -> None:
        """Set up the pronunciation generator"""
        if self.initialized:
            return
        self.load_corpus()
        super().setup()
        self.g2p_model.validate(self.words_to_g2p)
        self.initialized = True

    @property
    def words_to_g2p(self) -> typing.List[str]:
        """
        Words to produce pronunciations

        Queries words of type ``WordType.oov`` (excluding ``self.oov_word``) ordered by
        descending count; the result is cached after the first access.
        """
        if self._word_list is None:
            with self.session() as session:
                query = (
                    session.query(Word.word)
                    .filter(Word.word_type == WordType.oov, Word.word != self.oov_word)
                    .order_by(Word.count.desc())
                )
                # Materialize while the session is still open; the query only
                # executes when iterated
                self._word_list = [x for x, in query]
        return self._word_list
