#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Scoring script for computing pairwise BLEU and multi-ref BLEU over a set of
candidate hypotheses.

See `"Mixture Models for Diverse Machine Translation: Tricks of the Trade"
(Shen et al., 2019) <https://arxiv.org/abs/1902.07816>`_.
"""

import argparse
import random
import sys
from itertools import chain

import numpy as np
import sacrebleu
from sacrebleu import corpus_bleu as _corpus_bleu

def main():
    """Parse the command line and print the requested BLEU statistics.

    With ``--sys``, pairwise BLEU among the system hypotheses is printed and,
    when ``--output`` is also given, a readable dump of the hypotheses is
    written to that path. With ``--ref``, the references are scored against
    the system hypotheses if ``--sys`` was given, and against each other
    otherwise.
    """
    parser = argparse.ArgumentParser(sys.argv[0])
    parser.add_argument(
        "--sys",
        nargs="*",
        default="",
        metavar="FILE",
        help="path to system output",
    )
    parser.add_argument(
        "--ref",
        default="",
        metavar="FILE",
        help="path to references",
    )
    parser.add_argument(
        "--output",
        default="",
        metavar="FILE",
        help="print outputs into a pretty format",
    )
    args = parser.parse_args()
    have_sys = bool(args.sys)

    if have_sys:
        sources, targets, hypos, scores = load_sys(args.sys)
        diversity = pairwise(hypos)
        print("pairwise BLEU: %.2f" % diversity)
        if args.output:
            merge(sources, targets, hypos, scores, args.output)

    # Nothing reference-related to do without a reference file.
    if not args.ref:
        return

    refs = load_ref(args.ref)[2]
    if have_sys:
        multi_ref(refs, hypos)
    else:
        intra_ref(refs)


def dictolist(d):
    """Return the values of ``d`` as a list ordered by ascending key."""
    return [d[key] for key in sorted(d)]


def load_sys(paths):
    """Read system output files and collect sources, targets and hypotheses.

    Only lines starting with ``S-``, ``T-`` or ``D-`` are used; everything
    else is ignored. The integer between the prefix and the first tab is the
    sentence id. ``S-``/``T-`` lines carry their text in the second
    tab-separated field; ``D-`` lines carry a float score in the second field
    and the hypothesis text in the third.

    Args:
        paths: iterable of file paths to read. Hypotheses for the same
            sentence id are accumulated across all files in reading order.

    Returns:
        A 4-tuple ``(src, tgt, hypos, log_probs)`` of lists ordered by
        sentence id. ``hypos`` and ``log_probs`` hold one list per sentence.

    Raises:
        ValueError: if a sentence id or a score cannot be parsed.
        IndexError: if a ``D-`` line has no score field.
    """
    src, tgt, hypos, log_probs = {}, {}, {}, {}
    for path in paths:
        with open(path) as f:
            for line in f:
                # S: source
                # T: target
                # D: detokenized system output
                if not line.startswith(("S-", "T-", "D-")):
                    continue
                # Strip only the line terminator before splitting. A blanket
                # rstrip() would also remove the tab in front of an empty text
                # field, which shifts the fields and makes empty sources,
                # targets or hypotheses crash the parser.
                fields = line.rstrip("\r\n").split("\t")
                i = int(fields[0][2:])
                if line.startswith("D-"):
                    text = fields[2].rstrip() if len(fields) > 2 else ""
                    hypos.setdefault(i, []).append(text)
                    log_probs.setdefault(i, []).append(float(fields[1]))
                    continue
                text = fields[1].rstrip() if len(fields) > 1 else ""
                if line.startswith("S-"):
                    src[i] = text
                else:
                    tgt[i] = text
    return dictolist(src), dictolist(tgt), dictolist(hypos), dictolist(log_probs)


def load_ref(path):
    """Read a reference file into sources, targets and reference groups.

    Lines starting with ``S-`` and ``T-`` contribute the (right-stripped)
    second tab-separated field to ``src`` and ``tgt`` respectively. Each run
    of consecutive lines starting with ``R`` forms one group of references.
    Any other line (for example a blank separator) is skipped and ends the
    current reference group.

    Args:
        path: path of the reference file.

    Returns:
        A 3-tuple ``(src, tgt, refs)`` where ``refs`` is a list with one list
        of reference strings per run of ``R`` lines.

    Raises:
        IndexError: if an ``S-``, ``T-`` or ``R`` line contains no tab.
    """
    src, tgt, refs = [], [], []
    group = None  # reference group currently being filled, if any
    with open(path) as f:
        for line in f:
            if line.startswith("R"):
                if group is None:
                    group = []
                    refs.append(group)
                group.append(line.split("\t")[1].rstrip())
                continue
            # Any non-reference line closes the running group. Unrecognised
            # lines are simply skipped; previously they were never consumed,
            # so the parser looped forever appending empty groups.
            group = None
            if line.startswith("S-"):
                src.append(line.split("\t")[1].rstrip())
            elif line.startswith("T-"):
                tgt.append(line.split("\t")[1].rstrip())
    return src, tgt, refs


def merge(src, tgt, hypos, log_probs, path):
    """Write sources, targets and scored hypotheses to ``path`` for reading.

    For every sentence the source, the target and an empty line are written,
    followed by one tab-indented ``score<TAB>hypothesis`` line per hypothesis
    and a dashed separator line. The inputs are walked in lockstep, so output
    stops at the shortest of them.
    """
    separator = "------------------------------------------------------\n"
    with open(path, "w") as out:
        for source, target, candidates, scores in zip(src, tgt, hypos, log_probs):
            out.write(source + "\n")
            out.write(target + "\n")
            out.write("\n")
            out.writelines(
                "\t%f\t%s\n" % (score, candidate.strip())
                for candidate, score in zip(candidates, scores)
            )
            out.write(separator)


def corpus_bleu(sys_stream, ref_streams):
    """Return the corpus BLEU score of ``sys_stream`` against ``ref_streams``.

    Delegates to sacrebleu's ``corpus_bleu`` with ``tokenize="none"`` and
    returns only the ``score`` attribute of the result.
    """
    return _corpus_bleu(sys_stream, ref_streams, tokenize="none").score


def sentence_bleu(hypothesis, reference):
    """Return a smoothed BLEU score for one hypothesis/reference pair.

    The raw statistics come from sacrebleu's ``corpus_bleu`` (called without a
    ``tokenize`` override). The counts and totals at indices 1-3 are then each
    incremented by one -- presumably the 2- to 4-gram statistics; confirm
    against the installed sacrebleu version -- and the score is recomputed
    with ``smooth_method="exp"``.
    """
    stats = _corpus_bleu(hypothesis, reference)
    # Add-one adjustment, applied in place to the statistics at indices 1..3.
    for order in (1, 2, 3):
        stats.counts[order] += 1
        stats.totals[order] += 1
    smoothed = sacrebleu.BLEU.compute_bleu(
        stats.counts,
        stats.totals,
        stats.sys_len,
        stats.ref_len,
        smooth_method="exp",
    )
    return smoothed.score


def pairwise(sents):
    """Return corpus BLEU over all ordered pairs within each sentence group.

    For every group in ``sents`` each element is scored against every other
    element of the same group (both directions); the pairs from all groups
    are pooled into a single corpus-level BLEU computation.
    """
    pairs = [
        (reference, hypothesis)
        for group in sents
        for i, reference in enumerate(group)
        for j, hypothesis in enumerate(group)
        if i != j
    ]
    references = [reference for reference, _ in pairs]
    hypotheses = [hypothesis for _, hypothesis in pairs]
    return corpus_bleu(hypotheses, [references])


def multi_ref(refs, hypos):
    """Print reference coverage and leave-one-out multi-reference BLEU.

    Two statistics are printed:

    * ``#refs covered``: for each sentence, every hypothesis is matched to the
      reference with the highest sentence-level BLEU (ties broken at random);
      the number of distinct references matched is averaged over sentences.
    * average multi-reference corpus BLEU, where each reference position is
      held out in turn so the figure is comparable to ``intra_ref``.

    Args:
        refs: list with one list of reference strings per sentence.
        hypos: list with one list of hypothesis strings per sentence; must
            have the same length as ``refs``.
    """
    assert len(refs) == len(hypos)

    # count number of refs covered
    ref_cnt = 0
    for rs, hs in zip(refs, hypos):
        covered = set()
        for h in hs:
            scores = [sentence_bleu(h, r) for r in rs]
            top = max(scores)
            # Several references may share the top score; pick one at random
            # so that ties do not systematically favour the first reference.
            best = [k for k, score in enumerate(scores) if score == top]
            covered.add(random.choice(best))
        ref_cnt += len(covered)
    print("#refs covered: %.2f" % (ref_cnt / len(refs)))

    # transpose refs and hypos: afterwards refs[r][i] / hypos[j][i] is the
    # r-th reference / j-th hypothesis of sentence i (zip truncates to the
    # smallest per-sentence count)
    refs = list(zip(*refs))
    hypos = list(zip(*hypos))

    # compute multi-ref corpus BLEU (leave-one-out to be comparable to intra_ref)
    k = len(hypos)
    m = len(refs)
    # Hypotheses are flattened sentence by sentence; each reference is
    # repeated k times so it lines up with all k hypotheses of its sentence.
    flat_hypos = [hypos[j][i] for i in range(len(hypos[0])) for j in range(k)]
    duplicated_refs = [[ref for ref in refs_i for _ in range(k)] for refs_i in refs]
    loo_bleus = []
    for held_out_ref in range(m):
        remaining_refs = (
            duplicated_refs[:held_out_ref] + duplicated_refs[held_out_ref + 1 :]
        )
        assert len(remaining_refs) == m - 1
        loo_bleus.append(corpus_bleu(flat_hypos, remaining_refs))
    print("average multi-reference BLEU (leave-one-out): %.2f" % np.mean(loo_bleus))


def intra_ref(refs):
    """Print pairwise and leave-one-out BLEU among the references themselves.

    First the pairwise BLEU over ``refs`` is printed. Then every reference
    position is in turn treated as the hypothesis and scored against the
    remaining positions; all of these are pooled into one corpus-level
    multi-reference BLEU computation.
    """
    print("ref pairwise BLEU: %.2f" % pairwise(refs))
    # One column per reference position, each holding one entry per sentence.
    columns = list(zip(*refs))
    num_positions = len(columns)
    held_out = []
    others = [[] for _ in range(num_positions - 1)]
    for idx, column in enumerate(columns):
        held_out.extend(column)
        remaining = columns[:idx] + columns[idx + 1 :]
        for stream, ref_column in zip(others, remaining):
            stream.extend(ref_column)
    bleu = corpus_bleu(held_out, others)
    print("multi-reference BLEU (leave-one-out): %.2f" % bleu)


# Run the command-line entry point only when executed as a script, not when
# the module is imported.
if __name__ == "__main__":
    main()
