o
    5tiT                     @   s  d Z ddlZddlZddlZddlZddlZddlmZmZm	Z	m
Z
mZmZ ddlZddlZddlmZmZmZ ddlmZ ddlmZ ddlmZ ddlmZmZmZ dd	lm Z  dd
l!m"Z"m#Z#m$Z$ ddl%m&Z&m'Z'm(Z( ddl)m*Z* ddl+m,Z, ddl-m.Z. ddl/m0Z0 e0e1Z2G dd de"Z3dd Z4G dd dZ5G dd de5Z6G dd de6Z7G dd de6Z8G dd de6Z9G dd  d Z:d$d"d#Z;dS )%z EvaluationModule base class.    N)AnyDictListOptionalTupleUnion)DatasetInfoDownloadConfigDownloadManager)Dataset)ArrowReader)ArrowWriter)FeaturesSequenceValue)#_check_non_null_non_empty_recursive)BaseFileLockFileLockTimeout)copyfunc	temp_seedzip_dict   )config)EvaluationModuleInfo)camelcase_to_snakecase)
get_loggerc                       sB   e Zd ZdZ fddZdd Zdd Zedefd	d
Z	  Z
S )FileFreeLockz-Thread lock until a file **cannot** be lockedc                    s.   t || _t j|g|R i | d | _d S N)r   filelocksuper__init___lock_file_fd)self	lock_fileargskwargs	__class__ C/home/ubuntu/.local/lib/python3.10/site-packages/evaluate/module.pyr!   /   s   

zFileFreeLock.__init__c                 C   sH   z
| j jddd W n ty   | j j| _Y d S w | j   d | _d S )Ng{Gz?g{Gz?)timeoutpoll_interval)r   acquirer   r$   r"   releaser#   r)   r)   r*   _acquire4   s   

zFileFreeLock._acquirec                 C   s
   d | _ d S r   r"   r/   r)   r)   r*   _release?      
zFileFreeLock._releasereturnc                 C   s
   | j d uS r   r1   r/   r)   r)   r*   	is_lockedB   s   
zFileFreeLock.is_locked)__name__
__module____qualname____doc__r!   r0   r2   propertyboolr5   __classcell__r)   r)   r'   r*   r   ,   s    r   c                 C   sP   t | tust| dkr|  S dd }d|| d d  d|| dd   dS )	N   c                 S   s   d dd | D S )Nz, c                 s   s    | ]}t |V  qd S r   )repr).0xr)   r)   r*   	<genexpr>N       z?summarize_if_long_list.<locals>.format_chunk.<locals>.<genexpr>)join)chunkr)   r)   r*   format_chunkM   s   z,summarize_if_long_list.<locals>.format_chunk[   z, ..., ])typelistlen)objrE   r)   r)   r*   summarize_if_long_listI   s   *rN   c                   @   s0  e Zd ZdZdefddZedd Zedefdd	Z	ede
e fd
dZedefddZedefddZedefddZedefddZede
e fddZedefddZede
ee  fddZede
ee  fddZedefddZede
e fddZedefd d!Zd"S )#EvaluationModuleInfoMixinzThis base class exposes some attributes of EvaluationModuleInfo
    at the base level of the EvaluationModule for easy access.
    infoc                 C   s
   || _ d S r   _module_info)r#   rP   r)   r)   r*   r!   X   r3   z"EvaluationModuleInfoMixin.__init__c                 C   s   | j S )zc:class:`evaluate.EvaluationModuleInfo` object containing all the metadata in the evaluation module.rQ   r/   r)   r)   r*   rP   [   s   zEvaluationModuleInfoMixin.infor4   c                 C      | j jS r   )rR   module_namer/   r)   r)   r*   name`      zEvaluationModuleInfoMixin.namec                 C   rS   r   )rR   experiment_idr/   r)   r)   r*   rW   d   rV   z'EvaluationModuleInfoMixin.experiment_idc                 C   rS   r   )rR   descriptionr/   r)   r)   r*   rX   h   rV   z%EvaluationModuleInfoMixin.descriptionc                 C   rS   r   )rR   citationr/   r)   r)   r*   rY   l   rV   z"EvaluationModuleInfoMixin.citationc                 C   rS   r   )rR   featuresr/   r)   r)   r*   rZ   p   rV   z"EvaluationModuleInfoMixin.featuresc                 C   rS   r   )rR   inputs_descriptionr/   r)   r)   r*   r[   t   rV   z,EvaluationModuleInfoMixin.inputs_descriptionc                 C   rS   r   )rR   homepager/   r)   r)   r*   r\   x   rV   z"EvaluationModuleInfoMixin.homepagec                 C   rS   r   )rR   licenser/   r)   r)   r*   r]   |   rV   z!EvaluationModuleInfoMixin.licensec                 C   rS   r   )rR   codebase_urlsr/   r)   r)   r*   r^      rV   z'EvaluationModuleInfoMixin.codebase_urlsc                 C   rS   r   )rR   reference_urlsr/   r)   r)   r*   r_      rV   z(EvaluationModuleInfoMixin.reference_urlsc                 C   rS   r   )rR   
streamabler/   r)   r)   r*   r`      rV   z$EvaluationModuleInfoMixin.streamablec                 C   rS   r   )rR   formatr/   r)   r)   r*   ra      rV   z EvaluationModuleInfoMixin.formatc                 C   rS   r   )rR   module_typer/   r)   r)   r*   rb      rV   z%EvaluationModuleInfoMixin.module_typeN)r6   r7   r8   r9   r   r!   r:   rP   strrU   r   rW   rX   rY   r   rZ   r[   r\   r]   r   r^   r_   r;   r`   ra   rb   r)   r)   r)   r*   rO   S   s>    
rO   c                   @   s  e Zd ZdZ										dCdee d	ed
ee dededee dee dededeee	f fddZ
dd Zdd Zdd ZdDdeeef fddZdeee ee f fddZdd  Zd!d" Zd#d$ Zddd%dee fd&d'Zddd%d(d)Zddd*d+d,Zd-d. Zd/d0 Zd1d2 ZdDd3d4Zdefd5d6Z		dEd7ee  d8ee! fd9d:Z"d;d< Z#ddd%de$ee%f fd=d>Z&d?d@ Z'dAdB Z(dS )FEvaluationModuleaa  A `EvaluationModule` is the base class and common API for metrics, comparisons, and measurements.

    Args:
        config_name (`str`):
            This is used to define a hash specific to a module computation script and prevents the module's data
            to be overridden when the module loading script is modified.
        keep_in_memory (`bool`):
            Keep all predictions and references in memory. Not possible in distributed settings.
        cache_dir (`str`):
            Path to a directory in which temporary prediction/references data will be stored.
            The data directory should be located on a shared file-system in distributed setups.
        num_process (`int`):
            Specify the total number of nodes in a distributed settings.
            This is useful to compute module in distributed setups (in particular non-additive modules like F1).
        process_id (`int`):
            Specify the id of the current process in a distributed setup (between 0 and num_process-1)
            This is useful to compute module in distributed setups (in particular non-additive metrics like F1).
        seed (`int`, optional):
            If specified, this will temporarily set numpy's random seed when [`~evaluate.EvaluationModule.compute`] is run.
        experiment_id (`str`):
            A specific experiment id. This is used if several distributed evaluations share the same file system.
            This is useful to compute module in distributed setups (in particular non-additive metrics like F1).
        hash (`str`):
            Used to identify the evaluation module according to the hashed file contents.
        max_concurrent_cache_files (`int`):
            Max number of concurrent module cache files (default `10000`).
        timeout (`Union[int, float]`):
            Timeout in second for distributed setting synchronization.
    NFr   r   '  d   config_namekeep_in_memory	cache_dirnum_process
process_idseedrW   hashmax_concurrent_cache_filesr+   c                 K   s  |pd| _ |  }t| jj|_| j |_ |pd|_t| | t	|t
r(|dk r,tdt	|t
r5||kr9td|rC|dkrCtd|| _|| _|	| _|| _tj|pVtj| _|  | _|d u rytj ^}}}}|dk rs|| n|d | _n|| _|
| _tt| j | | _ tt| j!| | _!tt| j"| | _"| j j# j$| j%j&7  _$| j!j# j$| j%j&7  _$| j"j# j$| j%j&7  _$d | _'d | _(d | _)d | _*d | _+d | _,d | _-d | _.d | _/d | _0|| _1d S )	Ndefaultdefault_experimentr   z.'process_id' should be a number greater than 0z8'num_process' should be a number greater than process_idr   zPUsing 'keep_in_memory' is not possible in distributed setting (num_process > 1).ip  )2rg   _infor   r(   r6   rT   rW   rO   r!   
isinstanceint
ValueErrorrj   rk   rn   rh   ospath
expanduserr   HF_METRICS_CACHE_data_dir_root_build_data_dirdata_dirnprandom	get_staterl   r+   types
MethodTyper   compute	add_batchadd__func__r9   rP   r[   selected_feature_format
buf_writerwriterwriter_batch_sizedatacache_file_namer   rendez_vous_lock
file_paths	filelocks_hash)r#   rg   rh   ri   rj   rk   rl   rW   rm   rn   r+   r&   rP   _posr)   r)   r*   r!      sP   



zEvaluationModule.__init__c                 C   s   | j du rdS t| j S )zReturn the number of examples (predictions or predictions/references pair)
        currently stored in the evaluation module's cache.
        Nr   )r   rL   r/   r)   r)   r*   __len__   s   zEvaluationModule.__len__c                 C   s0   d| j  d| j d| j d| j dt|  dS )NzEvaluationModule(name: "z", module_type: "z", features: z, usage: """z""", stored examples: ))rU   rb   rZ   r[   rL   r/   r)   r)   r*   __repr__  s   zEvaluationModule.__repr__c                 C   s,   | j }tj|| j| j}tj|dd |S )a  Path of this evaluation module in cache_dir:
        Will be:
            self._data_dir_root/self.name/self.config_name/self.hash (if not none)/
        If any of these element is missing or if ``with_version=False`` the corresponding subfolders are dropped.
        T)exist_ok)ry   ru   rv   rC   rU   rg   makedirs)r#   builder_data_dirr)   r)   r*   rz   
  s   z EvaluationModule._build_data_dirr4   c                 C   s   t j| j| j d| j d| j d}d}t| jD ]^}t	|d }z|j
|d W nI tyt   | jdkrDtd| d| j d	d|| jd krUtd
| j ddtt }t j| j| j d| d| j d| j d}Y qw  ||fS ||fS )zTCreate a new cache file. If the default cache file is used, we generated a new hash.-.arrowN.lockr+   r   ziError in _create_cache_file: another evaluation module instance is already using the local cache file at .. Please specify an experiment_id (currently: E) to avoid collision between distributed evaluation module instances.zCannot acquire lock, too many evaluation module instance are operating concurrently on this file system.You should set a larger value of max_concurrent_cache_files when creating the evaluation module (current value is z).)ru   rv   rC   r{   rW   rj   rk   rangern   r   r-   r   rt   rc   uuiduuid4)r#   r+   	file_pathr   i	file_uuidr)   r)   r*   _create_cache_file  s<   (
$z#EvaluationModule._create_cache_filec              
      s    j dkr jdu rtd jg}n fddt j D }g }t|D ]7\}}|dkr4| j q%t|d }z	|j j	d W n t
yV   td	| d
| ddw || q%||fS )zGet a lock on all the cache files in a distributed setup.
        We wait for timeout second to let all the distributed node finish their tasks (default is 100 seconds).
        r   NzEvaluation module cache file doesn't exist. Please make sure that you call `add` or `add_batch` at least once before calling `compute`.c                    2   g | ]}t j j j d  j d | dqS )r   r   ru   rv   rC   r{   rW   rj   r?   rk   r/   r)   r*   
<listcomp>B      $z9EvaluationModule._get_all_cache_files.<locals>.<listcomp>r   r   r   z#Cannot acquire lock on cached file z for process .)rj   r   rt   r   	enumerateappendr   r   r-   r+   r   )r#   r   r   rk   r   r   r)   r/   r*   _get_all_cache_files6  s2   



z%EvaluationModule._get_all_cache_filesc              
      sr    fddt  jD }|D ](}t|}z	|j jd W n ty1   td| d j dd w |  qd S )Nc                    r   )r   z.arrow.lockr   r   r/   r)   r*   r   Z  r   z?EvaluationModule._check_all_processes_locks.<locals>.<listcomp>r   Expected to find locked file  from process  but it doesn't exist.)	r   rj   r   r-   r+   r   rt   rk   r.   )r#   expected_lock_file_namesexpected_lock_file_name
nofilelockr)   r/   r*   _check_all_processes_locksY  s    

z+EvaluationModule._check_all_processes_locksc              	   C   s   t j| j| j d| j d}t|}z	|j| jd W n t	y1   t
d| d| j dd w |  t j| j| j d| j d}t|}z	|j| jd W n t	yg   t
d| d| j d	d w |  d S )
Nr   z-0.arrow.lockr   r   r   r   	-rdv.lockzCouldn't acquire lock on r   )ru   rv   rC   r{   rW   rj   r   r-   r+   r   rt   rk   r.   r   )r#   r   r   lock_file_namer   r)   r)   r*   _check_rendez_vousi  s(     z#EvaluationModule._check_rendez_vousc                 C   s   | j dur
| j   d| _ | jdur| jdkr| j  | jr5t| jt| j	dd}t
| j | _dS | jdkrq|  \}}ztdt| j	dd}t
d	i |dd |D | _W n tyh   tddw || _|| _dS dS )
zClose all the writing process and load/gather the data
        from all the nodes if main node or all_process is True.
        Nr   )rZ   )rv   rP    c                 S   s   g | ]}d |iqS )filenamer)   )r?   fr)   r)   r*   r         z.EvaluationModule._finalize.<locals>.<listcomp>zError in finalize: another evaluation module instance is already using the local cache file. Please specify an experiment_id to avoid collision between distributed evaluation module instances.r)   )r   finalizer   rk   r.   rh   r   r{   r   r   r   from_bufferr   getvaluer   r   
read_filesFileNotFoundErrorrt   r   r   )r#   readerr   r   r)   r)   r*   	_finalize}  s.   



$
zEvaluationModule._finalizepredictions
referencesc                   s  ||d |du r |du r  fdd  D } | n fdd  D }|r;td| dt    fd	d  D }fd
dD }tdd | D rcjdi |   d_d_	d_
jdkrjjjjd fdd  D }tj jdi ||}W d   n1 sw   Y  jdurd_`d_|S tttjjD ]\}	}
td|
  `d_`d_t|
 |	  q|S dS )a  Compute the evaluation module.

        Usage of positional arguments is not allowed to prevent mistakes.

        Args:
            predictions (`list/array/tensor`, *optional*):
                Predictions.
            references (`list/array/tensor`, *optional*):
                References.
            **kwargs (optional):
                Keyword arguments that will be forwarded to the evaluation module [`~evaluate.EvaluationModule.compute`]
                method (see details in the docstring).

        Return:
            `dict` or `None`

            - Dictionary with the results if this evaluation module is run on the main process (`process_id == 0`).
            - `None` if the evaluation module is not run on the main process (`process_id != 0`).

        ```py
        >>> import evaluate
        >>> accuracy =  evaluate.load("accuracy")
        >>> accuracy.compute(predictions=[0, 1, 1, 0], references=[0, 1, 0, 1])
        ```
        r   Nc                    s   i | ]	}| vr|d qS r   r)   r?   k
all_kwargsr)   r*   
<dictcomp>  s    z,EvaluationModule.compute.<locals>.<dictcomp>c                    s   g | ]}| vr|qS r)   r)   r   r   r)   r*   r         z,EvaluationModule.compute.<locals>.<listcomp>z&Evaluation module inputs are missing: . All required inputs are c                       i | ]}| | qS r)   r)   r?   
input_namer   r)   r*   r         c                    s"   i | ]}|  vr| | qS r)   _feature_namesr   )r&   r#   r)   r*   r     s   " c                 s   s    | ]}|d uV  qd S r   r)   )r?   vr)   r)   r*   rA     rB   z+EvaluationModule.compute.<locals>.<genexpr>r   )rJ   c                    s    i | ]}| j | d d  qS r   )r   r   r/   r)   r*   r          z	Removing r)   )r   updatert   rK   anyvaluesr   r   r   r   r   rk   r   
set_formatrP   ra   r   rl   _computer   reversedzipr   r   loggerr   ru   remover.   )r#   r   r   r&   missing_kwargsmissing_inputsinputscompute_kwargsoutputr   r   r)   )r   r&   r#   r*   r     sL   



zEvaluationModule.computec          
         s  fdd|D }|rt d| dt  ||d|  fdd D  jdu r: _  z)  D ]\}}t|d	krT	j| |d	  q?j
  j  W dS  tjtfy   t fd
d D rtt  fdd D d	 }d dt   d| dt |  d	}n2tjddhkrdj d}d fddjD }	||	7 }ndj dt| dt| }t |dw )a  Add a batch of predictions and references for the evaluation module's stack.

        Args:
            predictions (`list/array/tensor`, *optional*):
                Predictions.
            references (`list/array/tensor`, *optional*):
                References.

        Example:

        ```py
        >>> import evaluate
        >>> accuracy = evaluate.load("accuracy")
        >>> for refs, preds in zip([[0,1],[0,1]], [[1,0],[0,1]]):
        ...     accuracy.add_batch(references=refs, predictions=preds)
        ```
        c                       g | ]
}|   vr|qS r)   r   r   r/   r)   r*   r         z.EvaluationModule.add_batch.<locals>.<listcomp>"Bad inputs for evaluation module: r   r   c                    r   r)   r)   r   batchr)   r*   r      r   z.EvaluationModule.add_batch.<locals>.<dictcomp>Nr   c                 3   s0    | ]}t  | t tt  kV  qd S r   )rL   nextiterr   r?   cr   r)   r*   rA     s   . z-EvaluationModule.add_batch.<locals>.<genexpr>c                    s(   g | ]}t  | t   kr|qS r)   )rL   r   )r   col0r)   r*   r     s   ( zMismatch in the number of z (z) and r   r   r   z@Module inputs don't match the expected format.
Expected format: ,
c                 3   (    | ]}d | dt  |  V  qdS zInput : NrN   r   r   r)   r*   rA     
    
zPPredictions and/or references don't match the expected format.
Expected format: ,
Input predictions: ,
Input references: )rt   rK   r   r   _infer_feature_from_batchr   _init_writeritemsrL   _enforce_nested_string_typeencode_batchwrite_batchpaArrowInvalid	TypeErrorr   r   r   setrC   rN   )
r#   r   r   r&   
bad_inputskeycolumnbad_col	error_msgerror_msg_inputsr)   )r   r   r#   r*   r     sT   
,

zEvaluationModule.add_batch)
prediction	referencec             	      s   fdd|D }|rt d| dt  ||d|  fdd D  jdu r: _  zj  j  j	  W dS  t
jtfyw   d	j d
}d
 fddjD }||7 }t |dw )a  Add one prediction and reference for the evaluation module's stack.

        Args:
            prediction (`list/array/tensor`, *optional*):
                Predictions.
            reference (`list/array/tensor`, *optional*):
                References.

        Example:

        ```py
        >>> import evaluate
        >>> accuracy = evaluate.load("accuracy")
        >>> accuracy.add(references=[0,1], predictions=[1,0])
        ```
        c                    r   r)   r   r   r/   r)   r*   r   5  r   z(EvaluationModule.add.<locals>.<listcomp>r   r   r   c                    r   r)   r)   r   exampler)   r*   r   ;  r   z(EvaluationModule.add.<locals>.<dictcomp>NzKEvaluation module inputs don't match the expected format.
Expected format: r   c                 3   r   r   r   r   r  r)   r*   rA   H  r   z'EvaluationModule.add.<locals>.<genexpr>)rt   rK   r   r   _infer_feature_from_exampler   r   r   encode_examplewriter   r   r   rC   )r#   r  r  r&   r   r  r  r)   )r  r#   r*   r   $  s2   

zEvaluationModule.addc                 C   s2   t | jtr	| jS tdd | D }| |S )Nc                 S   s   g | ]
\}}||d  fqS r   r)   )r?   r   r   r)   r)   r*   r   S  r   z>EvaluationModule._infer_feature_from_batch.<locals>.<listcomp>)rr   rZ   r   dictr   r  )r#   r   r  r)   r)   r*   r   O  s   
z*EvaluationModule._infer_feature_from_batchc              
   C   s   t | jtr	| jS | jD ]}z| || || |W   S  ttfy)   Y qw ddd t| jD }d| dt	|d  dt	|d  }t|d )	N
c                 S   s    g | ]\}}d | d| qS )zFeature option r   r)   )r?   r   featurer)   r)   r*   r   a  r   z@EvaluationModule._infer_feature_from_example.<locals>.<listcomp>zPPredictions and/or references don't match the expected format.
Expected format:
r   r   r   r   )
rr   rZ   r   r   r	  rt   r   rC   r   rN   )r#   r  rZ   feature_stringsr  r)   r)   r*   r  V  s(   





z,EvaluationModule._infer_feature_from_examplec                 C   s4   t | jtrt| jd  }|S t| j }|S )Nr   )rr   rZ   rK   keys)r#   feature_namesr)   r)   r*   r   j  s
   zEvaluationModule._feature_namesc              	   C   s$  | j dkr=| jdkr=tj| j| j d| j  d}t|| _z	| jj	|d W n t
y<   td| d| j dd w | jrQt | _t| j| j| jd	| _n$d | _| jd u s^| jd u rj|  \}}|| _|| _t| j| j| jd
| _| j dkr| jdkr|   | j  d S |   d S d S )Nr   r   r   r   r   zbError in _init_writer: another evalution module instance is already using the local cache file at r   r   )rZ   streamr   )rZ   rv   r   )rj   rk   ru   rv   rC   r{   rW   r   r   r-   TimeoutErrorrt   rh   r   BufferOutputStreamr   r   r   r   r   r   r   r   r   r.   r   )r#   r+   r   r   r   r)   r)   r*   r   q  sH   

 




zEvaluationModule._init_writerc                 C      t )a-  Construct the EvaluationModuleInfo object. See `EvaluationModuleInfo` for details.

        Warning: This function is only called once and the result is cached for all
        following .info() calls.

        Returns:
            info: (EvaluationModuleInfo) The EvaluationModule information
        NotImplementedErrorr/   r)   r)   r*   rq        	zEvaluationModule._infodownload_config
dl_managerc                 C   sN   |du r |du rt  }tj| jd|_d|_t| j|| jd}| 	| dS )a|  Downloads and prepares evaluation module for reading.

        Args:
            download_config ([`DownloadConfig`], *optional*):
                Specific download configuration parameters.
            dl_manager ([`DownloadManager`], *optional*):
                Specific download manager to use.

        Example:

        ```py
        >>> import evaluate
        ```
        N	downloadsF)dataset_namer  r{   )
r	   ru   rv   rC   r{   ri   force_downloadr
   rU   _download_and_prepare)r#   r  r  r)   r)   r*   download_and_prepare  s   
z%EvaluationModule.download_and_preparec                 C   s   dS )ao  Downloads and prepares resources for the evaluation module.

        This is the internal implementation to overwrite called when user calls
        `download_and_prepare`. It should download all required resources for the evaluation module.

        Args:
            dl_manager (:class:`DownloadManager`): `DownloadManager` used to download and cache data.
        Nr)   )r#   r  r)   r)   r*   r    r  z&EvaluationModule._download_and_preparec                K   r  )zOThis method defines the common API for all the evaluation module in the libraryr  )r#   r   r   r&   r)   r)   r*   r     s   zEvaluationModule._computec                 C   s`   t | dr| jd ur| j  t | dr| jd ur| j  t | dr%| `t | dr.| `d S d S )Nr   r   r   r   )hasattrr   r.   r   r   r   r/   r)   r)   r*   __del__  s   



zEvaluationModule.__del__c                    s  t |tr fddt||D S t |ttfr&|d  fdd|D S t |trt |jtrt |ttfrbt|jg|R  D ]\}}|dd D ]}t||d r^ |d |  nqKqAdS t|j|D ]\}\}|D ]}t|r |  nqpqhdS t |t	rt
d| d|du rdS t|dkr|D ]
}t||jr nqt |ts |j|S dS dS t |trtj|jrt |t	std	t| d
dS dS dS )a  
        Recursively checks if there is any Value feature of type string and throws TypeError if corresponding object is not a string.
        Since any Python object can be cast to string this avoids implicitly casting wrong input types (e.g. lists) to string without error.
        c                    s    g | ]\}\}}  ||qS r)   r   )r?   r   
sub_schemaor/   r)   r*   r     r   z@EvaluationModule._enforce_nested_string_type.<locals>.<listcomp>r   c                    s   g | ]}  |qS r)   r"  )r?   r$  r#   r#  r)   r*   r     r   r   Nz+Got a string but expected a list instead: ''zExpected type str but got r   )rr   r  r   rK   tupler   r  r   r   rc   rt   rL   r   r   r   	is_stringpa_typer   rJ   )r#   schemarM   r   dict_tuplessub_objsub_objs
first_elmtr)   r%  r*   r     sT   





z,EvaluationModule._enforce_nested_string_type)
NFNr   r   NNNre   rf   r   NN))r6   r7   r8   r9   r   rc   r;   rs   r   floatr!   r   r   rz   r   r   r   r   r   r   r   r   r  r   r   r   r   r  r   r   r   rq   r	   r
   r  r  r   r   r   r!  r   r)   r)   r)   r*   rd      sx     	


I!#"I<+
)

rd   c                   @      e Zd ZdZdS )Metrica  A Metric is the base class and common API for all metrics.

    Args:
        config_name (`str`):
            This is used to define a hash specific to a metric computation script and prevents the metric's data
            to be overridden when the metric loading script is modified.
        keep_in_memory (`bool`):
            Keep all predictions and references in memory. Not possible in distributed settings.
        cache_dir (`str`):
            Path to a directory in which temporary prediction/references data will be stored.
            The data directory should be located on a shared file-system in distributed setups.
        num_process (`int`):
            Specify the total number of nodes in a distributed settings.
            This is useful to compute metrics in distributed setups (in particular non-additive metrics like F1).
        process_id (`int`):
            Specify the id of the current process in a distributed setup (between 0 and num_process-1)
            This is useful to compute metrics in distributed setups (in particular non-additive metrics like F1).
        seed (`int`, *optional*):
            If specified, this will temporarily set numpy's random seed when [`~evaluate.Metric.compute`] is run.
        experiment_id (`str`):
            A specific experiment id. This is used if several distributed evaluations share the same file system.
            This is useful to compute metrics in distributed setups (in particular non-additive metrics like F1).
        max_concurrent_cache_files (`int`):
            Max number of concurrent metric cache files (default `10000`).
        timeout (`Union[int, float]`):
            Timeout in second for distributed setting synchronization.
    Nr6   r7   r8   r9   r)   r)   r)   r*   r3        r3  c                   @   r2  )
Comparisona  A Comparison is the base class and common API for all comparisons.

    Args:
        config_name (`str`):
            This is used to define a hash specific to a comparison computation script and prevents the comparison's data
            to be overridden when the comparison loading script is modified.
        keep_in_memory (`bool`):
            Keep all predictions and references in memory. Not possible in distributed settings.
        cache_dir (`str`):
            Path to a directory in which temporary prediction/references data will be stored.
            The data directory should be located on a shared file-system in distributed setups.
        num_process (`int`):
            Specify the total number of nodes in a distributed settings.
            This is useful to compute  comparisons in distributed setups (in particular non-additive comparisons).
        process_id (`int`):
            Specify the id of the current process in a distributed setup (between 0 and num_process-1)
            This is useful to compute  comparisons in distributed setups (in particular non-additive comparisons).
        seed (`int`, *optional*):
            If specified, this will temporarily set numpy's random seed when [`~evaluate.Comparison.compute`] is run.
        experiment_id (`str`):
            A specific experiment id. This is used if several distributed evaluations share the same file system.
            This is useful to compute  comparisons in distributed setups (in particular non-additive comparisons).
        max_concurrent_cache_files (`int`):
            Max number of concurrent comparison cache files (default `10000`).
        timeout (`Union[int, float]`):
            Timeout in second for distributed setting synchronization.
    Nr4  r)   r)   r)   r*   r6  ,  r5  r6  c                   @   r2  )Measurementa  A Measurement is the base class and common API for all measurements.

    Args:
        config_name (`str`):
            This is used to define a hash specific to a measurement computation script and prevents the measurement's data
            to be overridden when the measurement loading script is modified.
        keep_in_memory (`bool`):
            Keep all predictions and references in memory. Not possible in distributed settings.
        cache_dir (`str`):
            Path to a directory in which temporary prediction/references data will be stored.
            The data directory should be located on a shared file-system in distributed setups.
        num_process (`int`):
            Specify the total number of nodes in a distributed settings.
            This is useful to compute measurements in distributed setups (in particular non-additive measurements).
        process_id (`int`):
            Specify the id of the current process in a distributed setup (between 0 and num_process-1)
            This is useful to compute measurements in distributed setups (in particular non-additive measurements).
        seed (`int`, *optional*):
            If specified, this will temporarily set numpy's random seed when [`~evaluate.Measurement.compute`] is run.
        experiment_id (`str`):
            A specific experiment id. This is used if several distributed evaluations share the same file system.
            This is useful to compute measurements in distributed setups (in particular non-additive measurements).
        max_concurrent_cache_files (`int`):
            Max number of concurrent measurement cache files (default `10000`).
        timeout (`Union[int, float]`):
            Timeout in second for distributed setting synchronization.
    Nr4  r)   r)   r)   r*   r7  J  r5  r7  c                   @   s<   e Zd ZdddZdddZdddZdd	d
Zdd ZdS )CombinedEvaluationsFc                 C   s   ddl m} d | _t|tr|| _nt|tr%t| | _t| | _g }| jD ]}t|t	r5||}|
| q*|| _| jd u rLdd | jD | _|| _d S )Nr   )loadc                 S   s   g | ]}|j qS r)   )rU   )r?   moduler)   r)   r*   r   {      z0CombinedEvaluations.__init__.<locals>.<listcomp>)loadingr9  evaluation_module_namesrr   rK   evaluation_modulesr  r   r  rc   r   force_prefix)r#   r>  r?  r9  loaded_modulesr:  r)   r)   r*   r!   i  s    





zCombinedEvaluations.__init__Nc                    D   | j D ]}||d|  fdd| D  |jdi   qdS )aY  Add one prediction and reference for each evaluation module's stack.

        Args:
            predictions (`list/array/tensor`, *optional*):
                Predictions.
            references (`list/array/tensor`, *optional*):
                References.

        Example:

        ```py
        >>> import evaluate
        >>> accuracy = evaluate.load("accuracy")
        >>> f1 = evaluate.load("f1")
        >>> clf_metrics = combine(["accuracy", "f1"])
        >>> for ref, pred in zip([0,1,0,1], [1,0,0,1]):
        ...     clf_metrics.add(references=ref, predictions=pred)
        ```
        r   c                    r   r)   r)   r   r   r)   r*   r     r   z+CombinedEvaluations.add.<locals>.<dictcomp>Nr)   )r>  r   r   )r#   r  r  r&   evaluation_moduler)   r   r*   r     s
   
zCombinedEvaluations.addc                    rA  )am  Add a batch of predictions and references for each evaluation module's stack.

        Args:
            predictions (`list/array/tensor`, *optional*):
                Predictions.
            references (`list/array/tensor`, *optional*):
                References.

        Example:
        ```py
        >>> import evaluate
        >>> accuracy = evaluate.load("accuracy")
        >>> f1 = evaluate.load("f1")
        >>> clf_metrics = combine(["accuracy", "f1"])
        >>> for refs, preds in zip([[0,1],[0,1]], [[1,0],[0,1]]):
        ...     clf_metrics.add(references=refs, predictions=preds)
        ```
        r   c                    r   r)   r)   r   r   r)   r*   r     r   z1CombinedEvaluations.add_batch.<locals>.<dictcomp>Nr)   )r>  r   r   )r#   r   r   r&   rB  r)   r   r*   r     s
   
zCombinedEvaluations.add_batchc                 K   s>   g }| j D ]}||d|}||jdi | q| |S )aV  Compute each evaluation module.

        Usage of positional arguments is not allowed to prevent mistakes.

        Args:
            predictions (`list/array/tensor`, *optional*):
                Predictions.
            references (`list/array/tensor`, *optional*):
                References.
            **kwargs (*optional*):
                Keyword arguments that will be forwarded to the evaluation module [`~evaluate.EvaluationModule.compute`]
                method (see details in the docstring).

        Return:
            `dict` or `None`

            - Dictionary with the results if this evaluation module is run on the main process (`process_id == 0`).
            - `None` if the evaluation module is not run on the main process (`process_id != 0`).

        Example:

        ```py
        >>> import evaluate
        >>> accuracy = evaluate.load("accuracy")
        >>> f1 = evaluate.load("f1")
        >>> clf_metrics = combine(["accuracy", "f1"])
        >>> clf_metrics.compute(predictions=[0,1], references=[1,1])
        {'accuracy': 0.5, 'f1': 0.6666666666666666}
        ```
        r   Nr)   )r>  r   r   _merge_results)r#   r   r   r&   resultsrB  r   r)   r)   r*   r     s
   

zCombinedEvaluations.computec              	   C   s   i }t tjdd |D }dd t| D }dd t| j D }dd |D }t| j|D ]B\}}| D ]-\}	}
|	|vrN| j	sN|
||	 < q=||v ra|
|| d||  d|	 < q=|
|| d|	 < q=||v rw||  d	7  < q5|S )
Nc                 S   s   g | ]}|  qS r)   )r  )r?   rr)   r)   r*   r     r   z6CombinedEvaluations._merge_results.<locals>.<listcomp>c                 S   s   h | ]
\}}|d kr|qS r/  r)   r?   itemcountr)   r)   r*   	<setcomp>  r   z5CombinedEvaluations._merge_results.<locals>.<setcomp>c                 S   s   g | ]
\}}|d kr|qS r/  r)   rF  r)   r)   r*   r     s    c                 S   s   i | ]}|d qS r  r)   )r?   rU   r)   r)   r*   r     r;  z6CombinedEvaluations._merge_results.<locals>.<dictcomp>r   r   )
rK   	itertoolschainfrom_iterablecollectionsCounterr   r=  r   r?  )r#   rD  merged_resultsresults_keysduplicate_keysduplicate_namesduplicate_counterrT   resultr   r   r)   r)   r*   rC    s$   z"CombinedEvaluations._merge_resultsFr0  )r6   r7   r8   r!   r   r   r   rC  r)   r)   r)   r*   r8  h  s    



'r8  Fc                 C   s   t | |dS )a  Combines several metrics, comparisons, or measurements into a single `CombinedEvaluations` object that
    can be used like a single evaluation module.

    If two scores have the same name, then they are prefixed with their module names.
    And if two modules have the same name, please use a dictionary to give them different names, otherwise an integer id is appended to the prefix.

    Args:
        evaluations (`Union[list, dict]`):
            A list or dictionary of evaluation modules. The modules can either be passed
            as strings or loaded `EvaluationModule`s. If a dictionary is passed its keys are the names used and the values the modules.
            The names are used as prefix in case there are name overlaps in the returned results of each module or if `force_prefix=True`.
        force_prefix (`bool`, *optional*, defaults to `False`):
            If `True` all scores from the modules are prefixed with their name. If
            a dictionary is passed the keys are used as name otherwise the module's name.

    Examples:

    ```py
    >>> import evaluate
    >>> accuracy = evaluate.load("accuracy")
    >>> f1 = evaluate.load("f1")
    >>> clf_metrics = combine(["accuracy", "f1"])
    ```
    )r?  )r8  )evaluationsr?  r)   r)   r*   combine  s   rW  rU  )<r9   rM  rJ  ru   r   r   typingr   r   r   r   r   r   numpyr|   pyarrowr   datasetsr   r	   r
   datasets.arrow_datasetr   datasets.arrow_readerr   datasets.arrow_writerr   datasets.featuresr   r   r   datasets.features.featuresr   datasets.utils.filelockr   r   r   datasets.utils.py_utilsr   r   r   r   r   rP   r   namingr   utils.loggingr   r6   r   r   rN   rO   rd   r3  r6  r7  r8  rW  r)   r)   r)   r*   <module>   sH    
B    } 	