# Copyright 2020 Hirofumi Inaguma
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""RNN common arguments."""


def add_arguments_rnn_encoder_common(group):
    """Register the shared RNN encoder options on ``group``.

    Adds ``--etype``, ``--elayers``, ``--eunits`` (short form ``-u``),
    ``--eprojs`` and ``--subsample``, in that order.

    Args:
        group: Object exposing an argparse-style ``add_argument`` method
            (for example a parser or an argument group).

    Returns:
        The same ``group`` object that was passed in.
    """
    # Both cell families share the same eight name variants, so the choice
    # list is generated instead of being spelled out. Order matters because
    # it is what appears in the usage/help text.
    # NOTE(review): presumably "b" = bidirectional, "p" = projection and
    # "vgg" = VGG front-end -- confirm against the encoder implementation.
    variants = ("{}", "b{}", "{}p", "b{}p", "vgg{}p", "vggb{}p", "vgg{}", "vggb{}")
    encoder_types = [v.format(cell) for cell in ("lstm", "gru") for v in variants]

    register = group.add_argument
    register(
        "--etype",
        type=str,
        default="blstmp",
        choices=encoder_types,
        help="Type of encoder network architecture",
    )
    register("--elayers", type=int, default=4, help="Number of encoder layers")
    register("--eunits", "-u", type=int, default=300, help="Number of encoder hidden units")
    register("--eprojs", type=int, default=320, help="Number of encoder projection units")
    register(
        "--subsample",
        type=str,
        default="1",
        help=(
            "Subsample input frames x_y_z means subsample every x frame "
            "at 1st layer, every y frame at 2nd layer etc."
        ),
    )
    return group


def add_arguments_rnn_decoder_common(group):
    """Register the shared RNN decoder options on ``group``.

    Adds ``--dtype``, ``--dlayers``, ``--dunits``, ``--dropout-rate-decoder``,
    ``--sampling-probability`` and ``--lsm-type``, in that order.

    Args:
        group: Object exposing an argparse-style ``add_argument`` method
            (for example a parser or an argument group).

    Returns:
        The same ``group`` object that was passed in.
    """
    # (flag, keyword arguments) pairs, listed in registration order.
    option_table = (
        (
            "--dtype",
            {
                "type": str,
                "default": "lstm",
                "choices": ["lstm", "gru"],
                "help": "Type of decoder network architecture",
            },
        ),
        ("--dlayers", {"type": int, "default": 1, "help": "Number of decoder layers"}),
        ("--dunits", {"type": int, "default": 320, "help": "Number of decoder hidden units"}),
        (
            "--dropout-rate-decoder",
            {"type": float, "default": 0.0, "help": "Dropout rate for the decoder"},
        ),
        (
            "--sampling-probability",
            {
                "type": float,
                "default": 0.0,
                "help": "Ratio of predicted labels fed back to decoder",
            },
        ),
        (
            "--lsm-type",
            {
                # nargs="?" with const="" means a bare "--lsm-type" yields "",
                # the same value as omitting the option altogether.
                "nargs": "?",
                "const": "",
                "type": str,
                "default": "",
                "choices": ["", "unigram"],
                "help": "Apply label smoothing with a specified distribution type",
            },
        ),
    )
    for flag, options in option_table:
        group.add_argument(flag, **options)
    return group


def add_arguments_rnn_attention_common(group):
    """Define common arguments for RNN attention.

    Adds ``--atype``, ``--adim``, ``--awin``, ``--aheads``, ``--aconv-chans``,
    ``--aconv-filts`` and ``--dropout-rate``, in that order.

    Args:
        group: Object exposing an argparse-style ``add_argument`` method
            (for example a parser or an argument group).

    Returns:
        The same ``group`` object that was passed in.
    """
    group.add_argument(
        "--atype",
        default="dot",
        type=str,
        choices=[
            "noatt",
            "dot",
            "add",
            "location",
            "coverage",
            "coverage_location",
            "location2d",
            "location_recurrent",
            "multi_head_dot",
            "multi_head_add",
            "multi_head_loc",
            "multi_head_multi_res_loc",
        ],
        help="Type of attention architecture",
    )
    group.add_argument(
        "--adim",
        default=320,
        type=int,
        help="Number of attention transformation dimensions",
    )
    group.add_argument("--awin", default=5, type=int, help="Window size for location2d attention")
    group.add_argument(
        "--aheads",
        default=4,
        type=int,
        help="Number of heads for multi head attention",
    )
    # The two help texts below are built from adjacent literals rather than a
    # backslash continuation inside the string: the continuation embedded the
    # next line's indentation in the stored help text.
    group.add_argument(
        "--aconv-chans",
        default=-1,
        type=int,
        help="Number of attention convolution channels "
        "(negative value indicates no location-aware attention)",
    )
    group.add_argument(
        "--aconv-filts",
        default=100,
        type=int,
        help="Number of attention convolution filters "
        "(negative value indicates no location-aware attention)",
    )
    # NOTE: this option is registered with the attention arguments although
    # its help text describes it as the encoder dropout rate.
    group.add_argument(
        "--dropout-rate",
        default=0.0,
        type=float,
        help="Dropout rate for the encoder",
    )
    return group
