#!/usr/bin/env python3
# Copyright 2018  David Snyder
# Apache 2.0

# This script computes the minimum detection cost function, which is a common
# error metric used in speaker recognition.  Compared to equal error-rate,
# which assigns equal weight to false negatives and false positives, this
# error-rate is usually used to assess performance in settings where achieving
# a low false positive rate is more important than achieving a low false
# negative rate.  See the NIST 2016 Speaker Recognition Evaluation Plan at
# https://www.nist.gov/sites/default/files/documents/2016/10/07/sre16_eval_plan_v1.3.pdf
# for more details about the metric.
from __future__ import print_function
from operator import itemgetter
import sys, argparse, os


def GetArgs():
    """Parse and validate the command-line arguments of this script.

    The raw command line is echoed to stderr before parsing.  The parsed
    namespace is passed through CheckArgs() and the result is returned.
    """
    description = (
        "Compute the minimum "
        "detection cost function along with the threshold at which it occurs. "
        "Usage: sid/compute_min_dcf.py [options...] <scores-file> "
        "<trials-file> "
        "E.g., sid/compute_min_dcf.py --p-target 0.01 --c-miss 1 --c-fa 1 "
        "exp/scores/trials data/test/trials"
    )
    parser = argparse.ArgumentParser(
        description=description,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Prior probability of a target trial.
    parser.add_argument(
        "--p-target",
        type=float,
        dest="p_target",
        default=0.01,
        help="The prior probability of the target speaker in a trial.",
    )

    # The two cost options only differ in flag, destination and help text.
    cost_options = (
        ("--c-miss", "c_miss", "Cost of a missed detection.  This is usually not changed."),
        ("--c-fa", "c_fa", "Cost of a spurious detection.  This is usually not changed."),
    )
    for flag, dest, help_text in cost_options:
        parser.add_argument(flag, type=float, dest=dest, default=1, help=help_text)

    # Positional inputs: the scores file first, then the trials file.
    parser.add_argument(
        "scores_filename",
        help="Input scores file, with columns of the form <utt1> <utt2> <score>",
    )
    parser.add_argument(
        "trials_filename",
        help="Input trials file, with columns of the form <utt1> <utt2> <target/nontarget>",
    )

    # Log the invocation before parsing, so it is recorded even if parsing fails.
    sys.stderr.write(" ".join(sys.argv) + "\n")
    return CheckArgs(parser.parse_args())


def CheckArgs(args):
    """Validate the parsed options and return them unchanged.

    Checks, in order, that c_fa > 0, that c_miss > 0 and that p_target lies
    strictly between 0 and 1.  Raises Exception describing the first check
    that fails; otherwise returns the same args object.
    """
    if args.c_fa <= 0:
        problem = "--c-fa must be greater than 0"
    elif args.c_miss <= 0:
        problem = "--c-miss must be greater than 0"
    elif args.p_target <= 0 or args.p_target >= 1:
        problem = "--p-target must be greater than 0 and less than 1"
    else:
        # Every check passed.
        return args
    raise Exception(problem)


# Creates a list of false-negative rates, a list of false-positive rates
# and a list of decision thresholds that give those error-rates.
def ComputeErrorRates(scores, labels):
    """Compute false-negative and false-positive rates at every score.

    The scores are sorted in ascending order and each sorted score is used
    as a decision threshold.  For position i in the sorted order:

      * fnrs[i] is the fraction of target trials among the first i + 1
        sorted trials (i.e. targets rejected when everything up to and
        including trial i is rejected);
      * fprs[i] is one minus the fraction of non-target trials among the
        first i + 1 sorted trials (i.e. non-targets that are still accepted).

    Args:
        scores: iterable of numeric scores, one per trial.
        labels: sequence indexable by trial position; 1 marks a target
            trial and 0 a non-target trial.

    Returns:
        A tuple (fnrs, fprs, thresholds) where fnrs and fprs are lists and
        thresholds is a tuple holding the scores in ascending order.

    Raises:
        ValueError: if no scores are given.
        ZeroDivisionError: if the labels contain no target trials or no
            non-target trials, so one of the rates is undefined.  The
            exception type matches what the bare division used to raise.
    """
    # Sort the scores from smallest to largest while remembering each score's
    # original position, so the labels can be reordered the same way.  The
    # sort is stable, so tied scores keep their input order.
    ranked = sorted(enumerate(scores), key=itemgetter(1))
    if not ranked:
        raise ValueError("Cannot compute error rates: no scores were given")
    sorted_indexes, thresholds = zip(*ranked)
    labels = [labels[i] for i in sorted_indexes]

    positives = sum(labels)
    negatives = len(labels) - positives
    if positives == 0:
        raise ZeroDivisionError(
            "Cannot compute the false-negative rate: there are no target trials"
        )
    if negatives == 0:
        raise ZeroDivisionError(
            "Cannot compute the false-positive rate: there are no nontarget trials"
        )

    # Walk the trials in ascending score order keeping running counts of the
    # targets (false negatives) and non-targets (true negatives) rejected so
    # far, and turn the counts into rates as we go.
    fnrs = []
    fprs = []
    false_negatives = 0
    true_negatives = 0
    for label in labels:
        false_negatives = false_negatives + label
        true_negatives = true_negatives + 1 - label
        fnrs.append(false_negatives / float(positives))
        # One minus the true-negative rate is the false-positive rate.
        fprs.append(1 - true_negatives / float(negatives))
    return fnrs, fprs, thresholds


# Computes the minimum of the detection cost function.  The comments refer to
# equations in Section 3 of the NIST 2016 Speaker Recognition Evaluation Plan.
def ComputeMinDcf(fnrs, fprs, thresholds, p_target, c_miss, c_fa):
    """Return the minimum normalized detection cost and its threshold.

    For every index the detection cost is the weighted sum of the
    false-negative and false-positive rates.  The smallest cost is divided
    by the smaller of c_miss * p_target and c_fa * (1 - p_target).  When
    several indexes share the smallest cost, the first one is reported.

    Returns:
        A tuple (min_dcf, threshold).
    """
    best_cost = float("inf")
    best_threshold = thresholds[0]
    for idx, miss_rate in enumerate(fnrs):
        # Equation (2): weighted sum of the miss and false-alarm rates.
        miss_term = c_miss * miss_rate * p_target
        false_alarm_term = c_fa * fprs[idx] * (1 - p_target)
        cost = miss_term + false_alarm_term
        if cost < best_cost:
            best_cost, best_threshold = cost, thresholds[idx]

    # Equations (3) and (4): normalize the best cost found above.
    default_cost = min(c_miss * p_target, c_fa * (1 - p_target))
    return best_cost / default_cost, best_threshold


def compute_min_dcf(scores_filename, trials_filename, c_miss=1, c_fa=1, p_target=0.01):
    """Compute the minimum detection cost for a scores file and a trials file.

    Args:
        scores_filename: path to a text file whose lines have three
            whitespace-separated fields: <utt1> <utt2> <score>.
        trials_filename: path to a text file whose lines have three
            whitespace-separated fields: <utt1> <utt2> <target/nontarget>.
        c_miss: cost of a missed detection.
        c_fa: cost of a spurious detection.
        p_target: prior probability of a target trial.

    Returns:
        A tuple (mindcf, threshold) as produced by ComputeMinDcf().

    Raises:
        Exception: if a pair in the scores file has no entry in the trials
            file.
        ValueError: if a line does not split into exactly three fields or a
            score cannot be converted to float.
    """
    # Read both files inside context managers so the handles are closed
    # deterministically instead of being left for the garbage collector.
    with open(scores_filename, "r") as scores_file:
        score_lines = scores_file.readlines()
    with open(trials_filename, "r") as trials_file:
        trial_lines = trials_file.readlines()

    # Map "<utt1> <utt2>" to the label string given in the trials file.
    trials = {}
    for line in trial_lines:
        utt1, utt2, target = line.rstrip().split()
        trials[utt1 + " " + utt2] = target

    scores = []
    labels = []
    for line in score_lines:
        utt1, utt2, score = line.rstrip().split()
        trial = utt1 + " " + utt2
        if trial not in trials:
            raise Exception("Missing entry for " + utt1 + " and " + utt2 + " " + scores_filename)
        scores.append(float(score))
        # Only the exact string "target" counts as a positive label.
        labels.append(1 if trials[trial] == "target" else 0)

    fnrs, fprs, thresholds = ComputeErrorRates(scores, labels)
    mindcf, threshold = ComputeMinDcf(fnrs, fprs, thresholds, p_target, c_miss, c_fa)
    return mindcf, threshold


def main():
    """Command-line entry point: compute the minDCF and print it to stdout."""
    args = GetArgs()
    mindcf, threshold = compute_min_dcf(
        args.scores_filename,
        args.trials_filename,
        c_miss=args.c_miss,
        c_fa=args.c_fa,
        p_target=args.p_target,
    )
    # One summary line: the cost and threshold with four decimals, followed
    # by the settings that were used.
    report = (
        "minDCF is {0:.4f} at threshold {1:.4f} (p-target={2}, c-miss={3}, "
        "c-fa={4})\n"
    ).format(mindcf, threshold, args.p_target, args.c_miss, args.c_fa)
    sys.stdout.write(report)


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
