# -*- coding:utf-8 -*-

from __future__ import division, print_function, absolute_import

import time
import random
import logging
from collections import OrderedDict


import model
import prepro
import mecab_system_eval
import nagisa_utils as utils

from tagger import Tagger

# Configure the root logger to emit INFO-and-above records as the bare message
# text (no timestamp or level prefix). Per the standard library, basicConfig is
# a no-op when the root logger already has handlers installed.
logging.basicConfig(level=logging.INFO, format='%(message)s')
# Module-level logger named after this module.
logger = logging.getLogger(__name__)


def fit(train_file, dev_file, test_file, model_name,
        dict_file=None, emb_file=None, delimiter='\t', newline='EOS',
        layers=1, min_count=2, decay=1, epoch=10, window_size=3,
        dim_uni=32, dim_bi=16, dim_word=16, dim_ctype=8, dim_tagemb=16,
        dim_hidden=100, learning_rate=0.1, dropout_rate=0.3, seed=1234):

    """Train a joint word segmentation and sequence labeling (e.g, POS-tagging, NER) model.

    args:
        - train_file (str): Path to a train file.
        - dev_file (str): Path to a development file used for model selection.
        - test_file (str):  Path to a test file for evaluation.
        - model_name (str): Prefix of the output model files.
        - dict_file (str, optional): Path to a dictionary file.
        - emb_file (str, optional): Path to a pre-trained embedding file (word2vec format).
          When given, the word dimensionality reported by the embedding loader
          replaces 'dim_word'.
        - delimiter (str, optional): Separate word and tag in each line by 'delimiter'.
        - newline (str, optional):  Separate lines in the file by 'newline'.
        - layers (int, optional): RNN Layer size.
        - min_count (int, optional): Ignores all words with total frequency lower than this.
        - decay (int, optional): Number of consecutive epochs without an improvement
          of the development score after which the learning rate is halved.
        - epoch (int, optional): Epoch size.
        - window_size (int, optional): Window size of the context characters for word segmentation.
        - dim_uni (int, optional): Dimensionality of the char-unigram vectors.
        - dim_bi (int, optional): Dimensionality of the char-bigram vectors.
        - dim_word (int, optional): Dimensionality of the word vectors.
        - dim_ctype (int, optional): Dimensionality of the character-type vectors.
        - dim_tagemb (int, optional): Dimensionality of the tag vectors.
        - dim_hidden (int, optional): Dimensionality of the BiLSTM's hidden layer.
        - learning_rate (float, optional): Learning rate of SGD.
        - dropout_rate (float, optional): Dropout rate of the input vector for BiLSTMs.
        - seed (int, optional): Random seed (applied to Python's 'random' module).

    return:
        - Nothing. The output file paths are derived from 'model_name':
          model_name.vocabs, model_name.params, model_name.hp and
          model_name_epoch.params (the parameters of the latest epoch).

    """

    random.seed(seed)

    # Built from a list of pairs rather than a dict literal: on interpreters
    # where plain dicts are unordered, OrderedDict({...}) would receive the
    # entries in arbitrary order and the logged/dumped order would be random.
    hp = OrderedDict([
        ('LAYERS', layers),
        ('THRESHOLD', min_count),
        ('DECAY', decay),
        ('EPOCH', epoch),
        ('WINDOW_SIZE', window_size),
        ('DIM_UNI', dim_uni),
        ('DIM_BI', dim_bi),
        ('DIM_WORD', dim_word),
        ('DIM_CTYPE', dim_ctype),
        ('DIM_TAGEMB', dim_tagemb),
        ('DIM_HIDDEN', dim_hidden),
        ('LEARNING_RATE', learning_rate),
        ('DROPOUT_RATE', dropout_rate),
        ('SEED', seed),

        ('TRAINSET', train_file),
        ('TESTSET', test_file),
        ('DEVSET', dev_file),
        ('DICTIONARY', dict_file),
        ('EMBEDDING', emb_file),

        ('HYPERPARAMS', model_name+'.hp'),
        ('MODEL', model_name+'.params'),
        ('VOCAB', model_name+'.vocabs'),
        ('EPOCH_MODEL', model_name+'_epoch.params'),
    ])

    # Preprocess
    vocabs = prepro.create_vocabs_from_trainset(trainset=hp['TRAINSET'],
                                                threshold=hp['THRESHOLD'],
                                                fn_dictionary=hp['DICTIONARY'],
                                                fn_vocabs=hp['VOCAB'],
                                                delimiter=delimiter,
                                                newline=newline)

    embs = None
    if emb_file is not None:
        # The pre-trained vectors dictate the word dimensionality, so the
        # value returned by the loader overrides the 'dim_word' argument.
        embs, emb_dim_word = prepro.embedding_loader(fn_embedding=hp['EMBEDDING'],
                                                     word2id=vocabs[2])
        hp['DIM_WORD'] = emb_dim_word

    # All three splits are converted with the vocabularies built from the
    # training set only.
    train_data = prepro.from_file(filename=hp['TRAINSET'],
                                  window_size=hp['WINDOW_SIZE'],
                                  vocabs=vocabs,
                                  delimiter=delimiter,
                                  newline=newline)
    test_data = prepro.from_file(filename=hp['TESTSET'],
                                 window_size=hp['WINDOW_SIZE'],
                                 vocabs=vocabs,
                                 delimiter=delimiter,
                                 newline=newline)
    dev_data = prepro.from_file(filename=hp['DEVSET'],
                                window_size=hp['WINDOW_SIZE'],
                                vocabs=vocabs,
                                delimiter=delimiter,
                                newline=newline)

    # Update hyper-parameters with the dataset and vocabulary sizes
    # (vocabs holds, in order: char-unigram, char-bigram, word, POS-tag).
    hp['NUM_TRAIN']         = len(train_data.ws_data)
    hp['NUM_TEST']          = len(test_data.ws_data)
    hp['NUM_DEV']           = len(dev_data.ws_data)
    hp['VOCAB_SIZE_UNI']    = len(vocabs[0])
    hp['VOCAB_SIZE_BI']     = len(vocabs[1])
    hp['VOCAB_SIZE_WORD']   = len(vocabs[2])
    hp['VOCAB_SIZE_POSTAG'] = len(vocabs[3])

    # Construct networks
    _model = model.Model(hp=hp, embs=embs)

    # Start training
    _start(hp, model=_model, train_data=train_data, test_data=test_data, dev_data=dev_data)


def _evaluation(hp, fn_model, data):
    """Evaluate the parameters stored in 'fn_model' on 'data'.

    A Tagger is loaded from hp['VOCAB'], 'fn_model' and hp['HYPERPARAMS'];
    every sentence of 'data' is re-tagged from its raw character string and
    compared with the reference analysis through mecab_system_eval.

    args:
        - hp (dict): Hyper-parameters; the 'VOCAB' and 'HYPERPARAMS' entries are read.
        - fn_model (str): Path to the parameter file to evaluate.
        - data: Dataset exposing 'ws_data', 'words' and 'pos_data'.

    return:
        - (ws_f, pos_f): The third and sixth values reported by
          mecab_system_eval.calculate_fvalues (presumably the word
          segmentation and POS-tagging F-values).

    """
    tagger = Tagger(vocabs=hp['VOCAB'], params=fn_model, hp=hp['HYPERPARAMS'])

    def to_eval_rows(tokens, tags):
        # One [surface, "surface<TAB>tag"] pair per token; both fields are
        # UTF-8 encoded when mecab_system_eval reports Python 3.
        rows = []
        for token, tag in zip(tokens, tags):
            surface = token
            feature = token+"\t"+tag
            if mecab_system_eval.PY_3 is True:
                surface = surface.encode("UTF-8")
                feature = feature.encode("UTF-8")
            rows.append([surface, feature])
        return rows

    predicted = []
    reference = []
    for idx in range(len(data.ws_data)):
        # Reference analysis: gold words with their tag ids mapped to tags.
        gold_words = data.words[idx]
        gold_tags = [tagger.id2pos[tag_id] for tag_id in data.pos_data[idx][1]]
        reference.append(to_eval_rows(gold_words, gold_tags))

        # System analysis: tag the unsegmented sentence from scratch.
        result = tagger.tagging(''.join(gold_words))
        predicted.append(to_eval_rows(result.words, result.postags))

    counts = mecab_system_eval.mecab_eval(predicted, reference)
    _, _, seg_f, _, _, tag_f = mecab_system_eval.calculate_fvalues(counts)
    return seg_f, tag_f


def _format_log_value(value, width=5):
    """Render one cell of the training log in about 'width' characters.

    Plain truncation of str(value) silently corrupts numbers: a learning rate
    of 6.1e-06 would be shown as '6.103'. Floats that cannot be cut safely
    (scientific notation, or no decimal point within 'width' characters) are
    rendered as a one-digit exponent form instead, and integers are never cut.
    """
    text = str(value)
    if len(text) <= width:
        return text
    if isinstance(value, int):
        return text
    if isinstance(value, float) and ('e' in text or '.' not in text[:width]):
        return '{:.0e}'.format(value)
    return text[:width]


def _start(hp, model, train_data, test_data, dev_data):
    """Run the training loop and log one line of statistics per epoch.

    The hyper-parameters are logged and dumped to hp['HYPERPARAMS'] first.
    In every epoch the training sentences are visited in random order and the
    model is updated twice per sentence (word segmentation loss, then
    POS-tagging loss). After the epoch the parameters are saved to
    hp['EPOCH_MODEL'] and evaluated on the development data. When the
    development word-segmentation score improves, the parameters are also
    saved to hp['MODEL'] and evaluated on the test data; otherwise, after
    hp['DECAY'] non-improving epochs in a row, the learning rate is halved.

    args:
        - hp (dict): Hyper-parameters and file paths.
        - model: Network object exposing encode_ws, score_sentence, forward,
          get_POStagging_loss, trainer and model.
        - train_data, test_data, dev_data: Datasets exposing 'ws_data' and 'pos_data'.

    return:
        - Nothing.

    """
    for k, v in hp.items():
        logging.info('[nagisa] {}: {}'.format(k, v))

    row_format = '{:5}\t{:5}\t{:5}\t{:5}\t{:8}\t{:8}\t{:8}\t{:8}'
    logging.info(row_format.format(
        'Epoch', 'LR', 'Loss', 'Time_m', 'DevWS_f1',
        'DevPOS_f1', 'TestWS_f1', 'TestPOS_f1'))

    utils.dump_data(hp, hp['HYPERPARAMS'])

    decay_counter  = 0
    best_dev_score = -1.0
    # Defined up front so the log line cannot hit an unbound name if the very
    # first development score fails to beat the initial best (e.g. NaN).
    test_ws_f = test_pos_f = 0.0
    indice = list(range(len(train_data.ws_data)))
    for e in range(1, hp['EPOCH']+1):
        t = time.time()
        losses = 0.
        random.shuffle(indice)
        for i in indice:
            # Word Segmentation
            X = train_data.ws_data[i][0]
            Y = train_data.ws_data[i][1]

            obs = model.encode_ws(X, train=True)
            gold_score = model.score_sentence(obs, Y)
            forward_score = model.forward(obs)
            loss = forward_score-gold_score
            # Update
            loss.backward()
            model.trainer.update()
            losses += loss.value()

            # POS-tagging
            X = train_data.pos_data[i][0]
            Y = train_data.pos_data[i][1]
            loss = model.get_POStagging_loss(X, Y)
            losses += loss.value()
            # Update
            loss.backward()
            model.trainer.update()

        # Evaluate the parameters of this epoch on the development data.
        model.model.save(hp['EPOCH_MODEL'])
        dev_ws_f, dev_pos_f = _evaluation(hp, fn_model=hp['EPOCH_MODEL'], data=dev_data)

        # Model selection looks at the word-segmentation score only.
        if dev_ws_f > best_dev_score:
            best_dev_score = dev_ws_f
            decay_counter = 0
            model.model.save(hp['MODEL'])
            test_ws_f, test_pos_f = _evaluation(hp, fn_model=hp['MODEL'], data=test_data)
        else:
            decay_counter += 1
            if decay_counter >= hp['DECAY']:
                model.trainer.learning_rate = model.trainer.learning_rate/2
                decay_counter = 0

        # Average over sentences; an empty training set yields 0 instead of
        # a ZeroDivisionError.
        losses = losses/len(indice) if indice else 0.
        cells = [e, model.trainer.learning_rate, losses, (time.time()-t)/60,
                 dev_ws_f, dev_pos_f, test_ws_f, test_pos_f]

        logging.info(row_format.format(*[_format_log_value(c) for c in cells]))

