# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time

from lightning.pytorch.callbacks import Callback
from lightning.pytorch.utilities import rank_zero_only

# from sacrebleu import corpus_bleu


class LogEpochTimeCallback(Callback):
    """Simple callback that logs how long each epoch takes, in seconds, to a pytorch lightning log"""

    @rank_zero_only
    def on_train_epoch_start(self, trainer, pl_module):
        self.epoch_start = time.time()

    @rank_zero_only
    def on_train_epoch_end(self, trainer, pl_module):
        curr_time = time.time()
        duration = curr_time - self.epoch_start
        trainer.logger.log_metrics({"epoch_time": duration}, step=trainer.global_step)


# class MachineTranslationLogEvalCallback(Callback):
#     def _on_eval_end(self, trainer, pl_module, mode):
#         counts = np.array(self._non_pad_tokens)
#         eval_loss = np.sum(np.array(self._losses) * counts) / np.sum(counts)
#         sacre_bleu = corpus_bleu(self._translations, [self._ground_truths], tokenize="13a")
#         print(f"{mode} results for process with global rank {pl_module.global_rank}".upper())
#         for i in range(pl_module.num_examples[mode]):
#             print('\u0332'.join(f"EXAMPLE {i}:"))  # Underline output
#             sent_id = np.random.randint(len(self._translations))
#             print(f"Ground truth: {self._ground_truths[sent_id]}\n")
#             print(f"Translation: {self._translations[sent_id]}\n")
#             print()
#         print("-" * 50)
#         print(f"loss: {eval_loss:.3f}")
#         print(f"SacreBLEU: {sacre_bleu}")
#         print("-" * 50)

#     @rank_zero_only
#     def on_test_end(self, trainer, pl_module):
#         self._on_eval_end(trainer, pl_module, "test")

#     @rank_zero_only
#     def on_validation_end(self, trainer, pl_module):
#         self._on_eval_end(trainer, pl_module, "val")

#     @rank_zero_only
#     def on_sanity_check_end(self, trainer, pl_module):
#         self._on_eval_end(trainer, pl_module, "val")

#     def _on_eval_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx, mode):
#         self._translations.extend(outputs['translations'])
#         self._ground_truths.extend(outputs['ground_truths'])
#         self._non_pad_tokens.append(outputs['num_non_pad_tokens'])
#         self._losses.append(outputs[f'{mode}_loss'])

#     @rank_zero_only
#     def on_test_batch_end(self, trainer, pl_module, batch, outputs, batch_idx, dataloader_idx):
#         self._on_eval_batch_end(trainer, pl_module, batch, outputs, batch_idx, dataloader_idx, 'test')

#     @rank_zero_only
#     def on_validation_batch_end(self, trainer, pl_module, batch, outputs, batch_idx, dataloader_idx):
#         self._on_eval_batch_end(trainer, pl_module, batch, outputs, batch_idx, dataloader_idx, 'val')

#     def _on_eval_start(self, trainer, pl_module):
#         self._translations = []
#         self._ground_truths = []
#         self._losses = []
#         self._non_pad_tokens = []

#     @rank_zero_only
#     def on_test_start(self, trainer, pl_module):
#         self._on_eval_start(trainer, pl_module)

#     @rank_zero_only
#     def on_validation_start(self, trainer, pl_module):
#         self._on_eval_start(trainer, pl_module)

#     @rank_zero_only
#     def on_sanity_check_start(self, trainer, pl_module):
#         self._on_eval_start(trainer, pl_module)
