o
    }oi4l                     @   s   d Z ddlZddlmZmZ ddlZddlmZ ddlm	Z	 ddl
mZ ddlmZmZ ddlmZmZ dd	lmZ dd
lmZ ddlmZ ddlmZmZ ddlmZ ddlmZmZ ddl m!Z! ddl"m#Z# dgZ$G dd deZ%dS )z
This file contains code artifacts adapted from the original implementation:
https://github.com/google-research/google-research/blob/master/schema_guided_dst/baseline/train_and_predict.py
    N)ListOptional)Trainer)
DictConfig)
DataLoader)DialogueSGDBERTDatasetDialogueSGDDataProcessor)evaluateget_in_domain_services)write_predictions_to_file)SGDDialogueStateLoss)NLPModel)
SGDDecoder
SGDEncoder)tensor2list)PretrainedModelInfo	typecheck)logging)deprecated_warning
SGDQAModelc                	       s|  e Zd ZdZedd Zd4dedef fddZe	 d	d
 Z
dd Zd5deej dededefddZd5deej dededefddZdeej fddZd5dee defddZd5dee defddZdee dedejjjdefdd Zd!d" Zd#ed$efd%d&Zd4d'ee fd(d)Zd4d*ee fd+d,Zd4d-ee fd.d/Z dededefd0d1Z!e"dee# fd2d3Z$  Z%S )6r   a  
    Dialogue State Tracking Model SGD-QA (https://arxiv.org/abs/2105.08049)

    The SGD-QA model is a fast multi-pass schema-guided state-tracking model, that is trained on the Google schema-guided state tracking dataset (https://arxiv.org/abs/1909.05855).
    The model takes dialogue as input and outputs the dialogue state, which includes slot-value pairs.
    The model consists of two components: a neural natural language understanding model (NLU), and a rule-based state tracker.
    The NLU takes in a dialogue turn and different schema (entity) information options and outputs their match score. The state tracker takes the highest rated entities and composes
    the dialogue state across turns.
    c                 C   s   | j S N)decoderself r   d/home/ubuntu/.local/lib/python3.10/site-packages/nemo/collections/nlp/models/dialogue/sgdqa_model.pyoutput_module6   s   zSGDQAModel.output_moduleNcfgtrainerc                    sZ   t d d| _t j||d t| jjj| jj	j
d| _	t| jjjd| _tdd| _d S )Nr   F)r   r   )hidden_sizedropout)embedding_dimmean)	reduction)r   data_preparedsuper__init__r   
bert_modelconfigr   _cfgencoderr    r   r   r   loss)r   r   r   	__class__r   r   r&   :   s   zSGDQAModel.__init__c                 C   s^   | j |||d}t|tr|d }| j|d\}}| j|||d\}}}}	}
}||||	|
|fS )N	input_idstoken_type_idsattention_maskr   )hidden_states)encoded_utterancetoken_embeddingsutterance_mask)r'   
isinstancetupler*   r   )r   r/   r1   r0   r4   r3   logit_intent_statuslogit_req_slot_statuslogit_cat_slot_statuslogit_cat_slot_value_statuslogit_noncat_slot_statuslogit_spansr   r   r   forwardD   s.   
zSGDQAModel.forwardc                 C   s   |\}}}}}}}	}
}}}}}}}| |||d\}}}}}}| j ||||	||
||||||||d}| jjd d }| d| | jd|dd ||dS )	Nr.   r8   intent_statusr9   requested_slot_statusr:   categorical_slot_statusr;   categorical_slot_value_statusr<   noncategorical_slot_statusr=   noncategorical_slot_value_startnoncategorical_slot_value_end	task_maskr   lr
train_lossT)prog_bar)r+   rH   )r+   
_optimizerparam_groupslog)r   batch	batch_idxexample_id_num
service_idutterance_idsr0   r1   r@   rA   rB   rC   rD   rE   rF   start_char_idxend_char_idxrG   r8   r9   r:   r;   r<   r=   r+   rH   r   r   r   training_step`   s^   	zSGDQAModel.training_stepr   rN   rO   dataloader_idxreturnc                 C   sv   | j |d\}}| d| t| jjtkr+t| jjdkr+| j| d|d|i n
| jd|d|i d|d|iS )z
        Called at every validation step to aggregate and postprocess outputs on each GPU
        Args:
            batch: input batch at validation step
            batch_idx: batch index
            dataloader_idx: dataloader index
        rN   val_loss   tensors)	eval_step_helperrM   typer   val_dataloaderslistlenvalidation_step_outputsappendr   rN   rO   rV   r+   r[   r   r   r   validation_step   s    zSGDQAModel.validation_stepc                 C   sj   | j |d\}}t| jjtkr%t| jjdkr%| j| d|d|i n
| jd|d|i d|d|iS )z
        Called at every test step to aggregate and postprocess outputs on each GPU
        Args:
            batch: input batch at test step
            batch_idx: batch index
            dataloader_idx: dataloader index
        rX   rZ   	test_lossr[   )r\   r]   r   test_dataloadersr_   r`   test_step_outputsrb   rc   r   r   r   	test_step   s
    zSGDQAModel.test_stepc           ;      C   sp  |\}}}}}}}}	}
}}}}}}| |||d\}}}}}}| j ||||||	||
||||||d}g }g }g }g }g }g }g }g }g } g }!| jjr| jjdkr| jj}"t|"D ]R}#|t| |t| |t| |t| |t| |t| |t| |t| | t| |!t| qWtj	|| tj	|| tj	|| tj	|| tj	|| tj	|| tj	|| tj	|| tj	| | tj	|!| n2|| || || || || || || || | | |!| t
|}t
|}t
|}t
|}t
|}t
|}t
|}t
|}t
| }t
|!}tj |}tj |}$tjjdd|}%tj|dd}&tj|%ddd }'tj |}(tjjdd|})tj|dd}*tj|)ddd }+tjjdd},|,|}-tj|-dd\}.}/|/ \}0}1tj|.ddtj|/dd }2tj|1|2 d	ddd}3tj|1|2 d	ddd}4|3|4k|0dd}5t|5tj|2 |2 |2jd
|2}2tj|2d|1d dd}6tj|2d|1d ddd }7t|6|1}8t|6|1}9||||$|&|'|(|*|+|7|8|9||d}:||:fS )a.  
        Helper called at every validation/test step to aggregate and postprocess outputs on each GPU
        Args:
            batch: input batch at step
        Returns:
            loss: averaged batch loss
            tensors: collection of aggregated output tensors across all GPU workers
        r.   r?   rZ   )dim)axisr      )device)rm   dtype)rP   rQ   r@   req_slot_statuscat_slot_statuscat_slot_status_pcat_slot_value_statusnoncat_slot_statusnoncat_slot_status_pnoncat_slot_pnoncat_slot_startnoncat_slot_endnoncat_alignment_startnoncat_alignment_end)r+   r   num_devices
world_sizerangerb   torch
empty_likedistributed
all_gathercatnnSigmoidSoftmaxargmaxmaxunbindsize	unsqueezearange
get_deviceviewrepeatwherezerosrn   floor_dividefmod);r   rN   rP   rQ   rR   r0   r1   r@   rA   rB   rC   rD   rE   rF   rS   rT   rG   r8   r9   r:   r;   r<   r=   r+   all_example_id_numall_service_idall_logit_intent_statusall_logit_req_slot_statusall_logit_cat_slot_statusall_logit_cat_slot_value_statusall_logit_noncat_slot_statusall_logit_spansall_start_char_idxall_end_char_idxr{   indro   cat_slot_status_distrp   rq   rr   noncat_slot_status_distrs   rt   softmaxscoresstart_scores
end_scores
batch_sizemax_num_tokenstotal_scores	start_idxend_idxinvalid_index_maskmax_span_index
max_span_pspan_start_indexspan_end_indexr[   r   r   r   r\      s  	



















zSGDQAModel.eval_step_helperoutputsc           	      C      t dd |D  }| j| dd }| j| }| j|||d}| D ]\}}| j| d| |dd q&| jd	|ddd
 dS )z
        Called at the end of validation to post process outputs into human readable format
        Args:
            outputs: list of individual outputs of each validation step
            dataloader_idx: dataloader index
        c                 S      g | ]}|d  qS )rY   r   .0xr   r   r   
<listcomp>d      z9SGDQAModel.multi_validation_epoch_end.<locals>.<listcomp>Nri   r   split
dataloader_Trank_zero_onlyrY   rJ   r   )r}   stackr"   _validation_names_validation_dlmulti_eval_epoch_end_helperitemsrM   	r   r   rV   avg_lossr   r   metricskvr   r   r   multi_validation_epoch_end]     
z%SGDQAModel.multi_validation_epoch_endc           	      C   r   )z
        Called at the end of test to post process outputs into human readable format
        Args:
            outputs: list of individual outputs of each test step
            dataloader_idx: dataloader index
        c                 S   r   )re   r   r   r   r   r   r   u  r   z3SGDQAModel.multi_test_epoch_end.<locals>.<listcomp>Nri   r   r   Tr   re   r   )r}   r   r"   _test_names_test_dlr   r   rM   r   r   r   r   multi_test_epoch_endn  r   zSGDQAModel.multi_test_epoch_endr   r   c              	   C   s  dt dtdtjdt fdd}dtdtfd	d
}tdd |D }tdd |D }tdd |D }tdd |D }	tdd |D }
tdd |D }tdd |D }tdd |D }tdd |D }tdd |D }tdd |D }tdd |D }tdd |D }tdd |D }| jjj}||j	||}i }z| j
jdur| j
jnd}W n   d}Y | j
jdkrgtj|dd|| jj	j}tj|dd t| jj	j|| jj	j}i }||d < ||d!< ||d"< |	|d#< |
|d$< ||d%< ||d&< ||d'< ||d(< ||d)< ||d*< ||d+< ||d,< ||d-< ttj| jj	j|d.| jd/}|||jd }t|||| jj| jj	jd0|d1 t|| jj	j||| jj	j| jj	jd2}|S )3a0  
        Helper called at the end of evaluation to post process outputs into human readable format
        Args:
            outputs: list of individual outputs of each step
            split: data split
            dataloader: dataloader
        Returns:
            metrics: metrics collection
        r   ids_to_service_names_dictrP   rW   c                    s     fdd}t t|t|S )z
            Constructs string representation of example ID
            Args:
                split: evaluation data split
                ids_to_service_names_dict: id to service name mapping
                example_id_num: tensor example id
            c              
      s.   | \}}}}}}}d ||| | |||S )Nz{}-{}_{:05d}-{:02d}-{}-{}-{}-{})format)	ex_id_numdialog_id_1dialog_id_2turn_idrQ   model_task_idslot_intent_idvalue_idr   r   r   r   format_turn_id  s   zZSGDQAModel.multi_eval_epoch_end_helper.<locals>.get_str_example_id.<locals>.format_turn_id)r_   mapr   )r   r   rP   r   r   r   r   get_str_example_id  s   	zBSGDQAModel.multi_eval_epoch_end_helper.<locals>.get_str_example_idpredictionsr   c                 S   sz   dd t |D }|  D ]-\}}|dkrt||}t |D ]}|dkr.|| || |< q|| d|| |< qq|S )a   
            Combines predicted values to a single example.
            Args:
                predictions: predictions ordered by keys then batch
                batch_size: batch size
            Returns:
                examples_preds: predictions ordered by batch then key
            c                 S   s   g | ]}i qS r   r   )r   r   r   r   r   r     s    zbSGDQAModel.multi_eval_epoch_end_helper.<locals>.combine_predictions_in_example.<locals>.<listcomp>
example_idri   )r|   r   r}   chunkr   )r   r   examples_predsr   r   ir   r   r   combine_predictions_in_example  s   	zNSGDQAModel.multi_eval_epoch_end_helper.<locals>.combine_predictions_in_examplec                 S      g | ]}|d  d qS )r[   rP   r   r   r   r   r   r         z:SGDQAModel.multi_eval_epoch_end_helper.<locals>.<listcomp>c                 S   r   )r[   rQ   r   r   r   r   r   r     r   c                 S   r   )r[   r@   r   r   r   r   r   r     r   c                 S   r   )r[   ro   r   r   r   r   r   r     r   c                 S   r   )r[   rp   r   r   r   r   r   r     r   c                 S   r   )r[   rq   r   r   r   r   r   r     r   c                 S   r   )r[   rr   r   r   r   r   r   r     r   c                 S   r   )r[   rs   r   r   r   r   r   r     r   c                 S   r   )r[   rt   r   r   r   r   r   r     r   c                 S   r   )r[   ru   r   r   r   r   r   r     r   c                 S   r   )r[   rv   r   r   r   r   r   r     r   c                 S   r   )r[   rw   r   r   r   r   r   r     r   c                 S   r   )r[   rx   r   r   r   r   r   r     r   c                 S   r   )r[   ry   r   r   r   r   r   r     r   N r   zpred_res_{}_{}T)exist_okr   rQ   r@   ro   rp   rq   rr   rs   rt   ru   rv   rw   rx   ry   zschema.jsontrainF)
output_dirschemasstate_tracker
eval_debugin_domain_services)joint_acc_across_turnuse_fuzzy_match)strdictr}   Tensorintr   dialogues_processorr   _services_id_to_vocabdatasetr   log_dirglobal_rankospathjoinr   r)   	task_namemakedirsr   get_dialogue_filesdata_dirr
   get_seen_servicesshaper   r   r	   r   r   )r   r   r   r   r   r   rP   rQ   r@   ro   rp   rq   rr   rs   rt   ru   rv   rw   rx   ry   r   r   r   prediction_dirinput_json_filesr   r   r   r   r   r     s   

		z&SGDQAModel.multi_eval_epoch_end_helperc                 C   s8   | j rdS t| jjj| jjj| j| jjd| _d| _ dS )zC
        Preprocessed schema and dialogues and caches this
        N)r   dialogues_example_dir	tokenizerr   T)r$   r   r)   r   r   r   r   r   r   r   r   r   prepare_data  s   
zSGDQAModel.prepare_datar   r   c                 C   sV   t j|st| d|| jj_|| jj_t	d| d t	d| d dS )z
        Update data directories

        Args:
            data_dir: path to data directory
            dialogues_example_dir: path to preprocessed dialogues example directory, if not exists will be created.
        z is not foundz"Setting model.dataset.data_dir to .z/Setting model.dataset.dialogues_example_dir to N)
r   r   exists
ValueErrorr)   r   r   r   r   info)r   r   r   r   r   r   update_data_dirs  s   

zSGDQAModel.update_data_dirstrain_data_configc                 C      |    | j||jd| _d S N)r   r   )r   _setup_dataloader_from_configds_item	_train_dl)r   r  r   r   r   setup_training_data%     zSGDQAModel.setup_training_dataval_data_configc                 C   r  r  )r   r	  r
  r   )r   r  r   r   r   setup_validation_data)  r  z SGDQAModel.setup_validation_datatest_data_configc                 C   r  r  )r   r	  r
  r   )r   r  r   r   r   setup_test_data-  r  zSGDQAModel.setup_test_datac              	   C   sx   | j j}|j}tj|std| dt|| j| jj	| jj
| jj|d}tjjj||j|j|j|j|j|jd}|S )Nz Data directory is not found at: r  )dataset_splitr   r   r   schema_configr   )r   r   
collate_fn	drop_lastshufflenum_workers
pin_memory)r)   r   r   r   r   r  FileNotFoundErrorr   r   
_tokenizerr   r  r}   utilsdatar   r   r  r  r  r  r  )r   r   r   dataset_cfgr   r   dlr   r   r   r	  1  s,   		z(SGDQAModel._setup_dataloader_from_configc                 C   s   g }| tdddd |S )z
        This method returns a list of pre-trained model which can be instantiated directly from NVIDIA's NGC cloud.

        Returns:
            List of available pre-trained models.
        sgdqa_bertbasecasedzrhttps://api.ngc.nvidia.com/v2/models/nvidia/nemo/sgdqa_bertbasecased/versions/1.0.0/files/sgdqa_bertbasecased.nemozDialogue State Tracking model finetuned from NeMo BERT Base Cased on Google SGD dataset which has a joint goal accuracy of 59.72% on dev set and 45.85% on test set.)pretrained_model_namelocationdescription)rb   r   )clsresultr   r   r   list_available_modelsN  s   z SGDQAModel.list_available_modelsr   )r   )&__name__
__module____qualname____doc__propertyr   r   r   r&   r   r>   rU   r   r}   r   r   r   rd   rh   r\   r   r   r   r  r  r   r   r   r  r   r  r  r  r	  classmethodr   r%  __classcell__r   r   r,   r   r   +   s@    



"4" )
 )&r)  r   typingr   r   r}   lightning.pytorchr   	omegaconfr   torch.utils.datar   "nemo.collections.nlp.data.dialoguer   r   /nemo.collections.nlp.data.dialogue.sgd.evaluater	   r
   7nemo.collections.nlp.data.dialogue.sgd.prediction_utilsr   nemo.collections.nlp.lossesr   %nemo.collections.nlp.models.nlp_modelr   nemo.collections.nlp.modulesr   r   &nemo.collections.nlp.parts.utils_funcsr   nemo.core.classes.commonr   r   
nemo.utilsr   nemo.utils.decoratorsr   __all__r   r   r   r   r   <module>   s&   