o
    ig                     @   s  U d dl Z d dlZd dlZd dlZd dlmZ d dlmZ d dlm	Z	m
Z
mZmZmZmZmZmZmZmZmZmZmZmZ d dlZddlmZmZmZmZmZ ddlmZ ddl m!Z! dd	l"m#Z# dd
l$m%Z%m&Z&m'Z'm(Z(m)Z) edZ*edZ+edddZ,edi dZ-ee. e/d< d*ddZ0G dd dee*e+f Z1ej2de	de	de3de1de4f
ddZ5ej2de	de	de3de1de	f
ddZ6ede1dZ7de7dee3ee3e	f f de7fd d!Z8d"gfde7d#e9de7fd$d%Z:de1d&e
e1ge7f de7fd'd(Z;g d)Z<dS )+    N)
ContextVar)Path)AnyCallableDictGenericIterableIteratorListOptionalSequenceSetTupleTypeVarUnioncast   )CupyOpsNumpyOpsOpsParamServerget_current_ops)	Optimizer)Shim)FloatsXd)DATA_VALIDATIONconvert_recursiveis_xp_arraypartialvalidate_fwd_input_outputInTOutTSelfTModel)boundcontext_operators)defaultmodelreturnc                 O   s   | S N )r'   argskwargsr*   r*   ?/home/ubuntu/.local/lib/python3.10/site-packages/thinc/model.py
empty_init-   s   r.   c                   @   s  e Zd ZU dZdZeed< e Z	ejed< e
Zeed< eed< eed< eed< eed	< eed
< eeee f ed< ed  ed< ee ed< eeef ed< eeee f ed< g dZdi i g g i i dddeded	ee deeee f deeee f ded  dee deeef deeed  f deeeef  fddZeded  fddZedee fdd Zedeeef fd!d"Z ede!ed#f fd$d%Z"ede!ed#f fd&d'Z#ede!ed#f fd(d)Z$ede!ed#f fd*d+Z%e&e'j(d,eeef fd-d.Z)dedee fd/d0Z*dedefd1d2Z+d3d4ded5ed6eddfd7d8Z,dedee fd9d:Z-dedee fd;d<Z.dedefd=d>Z/dedee fd?d@Z0ded5ee ddfdAdBZ1dedefdCdDZ2dedefdEdFZ3ded5eddfdGdHZ4dedee fdIdJZ5ded5eddfdKdLZ6dedee fdMdNZ7dedd fdOdPZ8deded  fdQdRZ9ded5ed  ddfdSdTZ:dUe;dVede!e<ef fdWdXZ=ddUee; dYee< dd fdZd[Z>dUe;de!e<ee<ge;f f fd\d]Z?dUe;de<fd^d_Z@d`eAddfdadbZBe'j(dee!eef ef fdcddZCdedfdgedeDd  fdhdiZEdeDd  fdjdkZFddledeDd  fdmdnZGddpdqZHddrded	ee ddfdsdtZIdud dvd defdwdxZJdee!eef e!eef f fdydzZKd{eLdeLfd|d}ZM	dd{eLd~eeeed ef f  deLfddZNdeddfddZOdddZPdeddfddZQdeRfddZSdeeTef ddfddZUdefddZVdeRdd fddZWdeeTef dd fddZXdedd fddZYdddeeTef dedefddZZdddeRdedefddZ[dddededefddZ\dedd fddZ]dedd fddZ^dedd fddZ_dedd fddZ`dedd fddZadedd fddZbdedd fddZcdedd fddZddedd fddZededd fddZfdedd fddZgdedd fddZhdedd fddZidedd fddZjdS )r#   z/Class for implementing Thinc models and layers.r   	global_idglobal_id_locknameopsid_funcinit_params_dims_layers_shims_attrs_has_params)r1   r3   r2   r4   r5   r6   r7   r:   _refsr8   r9   r;   N)r5   dimsparamslayersshimsattrsrefsr2   forwardr=   r>   r?   r@   rA   rB   c                C   s   || _ |du rtt| }t| d| t| d| |
dur|
nt | _t | _t|| _	t|| _
t|	| _t|| _t|| _tj t jd7  _tj| _W d   n1 sXw   Y  i | _| D ]\}}d| j|< |durw| || qddS )zInitialize a new model.Nr4   r5   r   )r1   r   r.   setattrr   r2   r   r6   dictr7   r:   r<   listr8   r9   r#   r0   r/   r3   r;   items	set_param)selfr1   rC   r5   r=   r>   r?   r@   rA   rB   r2   valuer*   r*   r-   __init__U   s.   







zModel.__init__r(   c                 C      | j S )zmA list of child layers of the model. You can append to it to add
        layers but not reassign it.
        )r8   rI   r*   r*   r-   r?   |      zModel.layersc                 C   rL   r)   )r9   rM   r*   r*   r-   r@      s   zModel.shimsc                 C   rL   )zfA dict of the model's attrs. You can write to it to update attrs but
        not reassign it.
        )r:   rM   r*   r*   r-   rA      rN   zModel.attrs.c                 C      t | j S )z8Get the names of registered parameter (including unset).)tupler;   keysrM   r*   r*   r-   param_names      zModel.param_namesc                    s   t  fdd jD S )zHGet the names of parameters with registered gradients (including unset).c                    s   g | ]	}  |r|qS r*   )has_grad).0r1   rM   r*   r-   
<listcomp>   s    z$Model.grad_names.<locals>.<listcomp>)rP   rR   rM   r*   rM   r-   
grad_names   s   zModel.grad_namesc                 C   rO   )z9Get the names of registered dimensions (including unset).)rP   r7   rQ   rM   r*   r*   r-   	dim_names   rS   zModel.dim_namesc                 C   rO   )z>Get the names of registered node references (including unset).)rP   r<   rQ   rM   r*   r*   r-   	ref_names   rS   zModel.ref_names	operatorsc                 c   s(    | j t|}dV  | j | dS )a  Bind arbitrary binary functions to Python operators, for use in any
        `Model` instance. Can (and should) be used as a contextmanager.

        EXAMPLE:
            with Model.define_operators({">>": chain}):
                model = Relu(512) >> Relu(512) >> Softmax()
        N)_context_operatorssetrE   reset)clsrZ   tokenr*   r*   r-   define_operators   s   
zModel.define_operatorsc                 C   $   || j vrdS | j | durdS dS )zCheck whether the model has a dimension of a given name. If the
        dimension is registered but the value is unset, returns None.
        FNT)r7   rI   r1   r*   r*   r-   has_dim   
   
zModel.has_dimc                 C   T   || j vrtd| d| j d| j | }|du r(d| d| j d}t||S )z4Retrieve the value of a dimension of the given name.zCannot get dimension '' for model ''Nz': value unset)r7   KeyErrorr1   
ValueErrorrI   r1   rJ   errr*   r*   r-   get_dim      

zModel.get_dimF)forcerJ   rn   c                C   s   || j vrtd| d| j d| j | }tdd | j D }|duo0||ko0| p0|o0|}|rFd| d| j d| d	| }t||| j |< dS )
zSet a value for a dimension.zCannot set unknown dimension 'rf   '.c                 s   s    | ]	\}}t |V  qd S r)   )bool)rU   xyr*   r*   r-   	<genexpr>   s    z Model.set_dim.<locals>.<genexpr>NzAttempt to change dimension 'z' from z to )r7   rh   r1   anyr;   rG   ri   )rI   r1   rJ   rn   	old_value
has_paramsinvalid_changerk   r*   r*   r-   set_dim   s   

zModel.set_dimc                 C      |  |r
| |S dS )z=Retrieve the value of a dimension of the given name, or None.N)rc   rl   rb   r*   r*   r-   maybe_get_dim      zModel.maybe_get_dimc                 C   ra   )zCheck whether the model has a weights parameter of the given name.

        Returns None if the parameter is registered but currently unset.
        FNT)r;   rb   r*   r*   r-   	has_param   s
   
zModel.has_paramc                 C   sZ   || j vrtd| d| j d| j| j|s%td| d| j d| j| j|S )z%Retrieve a weights parameter by name.zUnknown param: 'rf   ro   zParameter 'z' has not been allocated yet.)r;   rh   r1   r6   r|   r3   	get_paramrb   r*   r*   r-   r}      s   
zModel.get_paramc                 C   ry   )z.Retrieve a weights parameter by name, or None.N)r|   r}   rb   r*   r*   r-   maybe_get_param   r{   zModel.maybe_get_paramc                 C   s6   |du rd| j |< dS | j| j|| d| j |< dS )z Set a weights parameter's value.NT)r;   r6   rH   r3   rI   r1   rJ   r*   r*   r-   rH      s   zModel.set_paramc                 C      | j | j|S )z@Check whether the model has a non-zero gradient for a parameter.)r6   rT   r3   rb   r*   r*   r-   rT         zModel.has_gradc                 C   r   )zGet a gradient from the model.)r6   get_gradr3   rb   r*   r*   r-   r      r   zModel.get_gradc                 C      | j | j|| dS )z#Set a gradient value for the model.N)r6   set_gradr3   r   r*   r*   r-   r        zModel.set_gradc                 C   ry   )z%Retrieve a gradient by name, or None.N)rT   r   rb   r*   r*   r-   maybe_get_grad  r{   zModel.maybe_get_gradc                 C   r   )z1Increment the gradient of a parameter by a value.N)r6   inc_gradr3   r   r*   r*   r-   r     r   zModel.inc_gradc                 C   ra   )zCheck whether the model has a reference of a given name. If the
        reference is registered but the value is unset, returns None.
        FNT)r<   rb   r*   r*   r-   has_ref  rd   zModel.has_refc                 C   re   )z4Retrieve the value of a reference of the given name.zCannot get reference 'rf   ro   Nz': value unset.)r<   rh   r1   ri   rj   r*   r*   r-   get_ref  rm   zModel.get_refc                 C   ry   )z8Retrieve the value of a reference if it exists, or None.N)r   r   rb   r*   r*   r-   maybe_get_ref&  r{   zModel.maybe_get_refc                 C   s8   |du r|| j |< dS ||  v r|| j |< dS td)zSet a value for a reference.Nz)Cannot add reference to node not in tree.)r<   walkri   r   r*   r*   r-   set_ref*  s
   zModel.set_refXis_trainc                 C   s   | j | ||dS )z~Call the model's `forward` function, returning the output and a
        callback to compute the gradients via backpropagation.r   r4   )rI   r   r   r*   r*   r-   __call__3  s   zModel.__call__Yc                 C   s8   t  rt| j| j|| | jdur| j| ||d | S )zFinish initialization of the model, optionally providing a batch of
        example input and output data to perform shape inference.N)r   r   )r   getr   r1   r4   r5   )rI   r   r   r*   r*   r-   
initialize8  s
   
zModel.initializec                 C   s   | j | |ddS )ap  Run the model over a batch of data, returning the output and a
        callback to complete the backward pass. A tuple (Y, finish_update),
        where Y is a batch of output data, and finish_update is a callback that
        takes the gradient with respect to the output and an optimizer function,
        and returns the gradient with respect to the input.
        Tr   r   rI   r   r*   r*   r-   begin_updateA  s   zModel.begin_updatec                 C   s   | j | |ddd S )zCall the model's `forward` function with `is_train=False`, and return
        only the output, instead of the `(output, callback)` tuple.
        Fr   r   r   r   r*   r*   r-   predictJ  s   zModel.predict	optimizerc                 C   sz   |   D ]}|jD ]}|| q	q|   D ]$}|jD ]}||r9||j|f||||\}}||| qqdS )zUpdate parameters with current gradients. The optimizer is called
        with each parameter and gradient of the model.
        N)	r   r@   finish_updaterR   rT   r3   r}   r   rH   )rI   r   nodeshimr1   paramgradr*   r*   r-   r   P  s   


zModel.finish_updatec           	      c   s    i }| j D ]}| j|f}||v r | |||< | |||  qt '}| jD ]
}||| q)| j	D ]
}||| q7dV  W d   n1 sOw   Y  |re|
 D ]\}}| || qZdS dS )zContext manager to temporarily set the model's parameters to
        specified values. The params are a dictionary keyed by model IDs, whose
        values are arrays of weight values.
        N)rR   r3   r}   rH   
contextlib	ExitStackr?   enter_context
use_paramsr@   rG   )	rI   r>   backupr1   keystacklayerr   r   r*   r*   r-   r   _  s(   




zModel.use_paramsbfsorderr   c                C   s@   |dkr|   S |dkr| jddS |dkr| jddS td)zIterate out layers of the model.

        Nodes are returned in breadth-first order by default. Other possible
        orders are "dfs_pre" (depth-first search in preorder) and "dfs_post"
        (depth-first search in postorder).r   dfs_preF)
post_orderdfs_postTz5Invalid order, must be one of: bfs, dfs_pre, dfs_post)	_walk_bfs	_walk_dfsri   )rI   r   r*   r*   r-   r   v  s   z
Model.walkc                 c   sJ    | g}t  }|D ]}t||v rq	|t| |V  ||j q	dS )z/Iterate out layers of the model, breadth-first.N)r\   r3   addextendr?   )rI   queueseenr   r*   r*   r-   r     s   zModel._walk_bfsr   c                 c   s    t  }| g}t| j|t| < |s| V  |rUz%t|t|d  }t||vr;|s-|V  || t|j|t|< W n tyP   |rJ|d V  |  Y nw |sdS dS )z-Iterate out layers of the model, depth-first.N)rE   iterr?   r3   nextappendStopIterationpop)rI   r   r   r   
next_childr*   r*   r-   r     s*   

zModel._walk_dfsr   c                 C   s~   t |  D ]}||jv r|j| ||jv sqt|  }|D ]}|jD ]}||}|dur;||vr;||d q&q!dS )a]  Remove a node from all layers lists, and then update references.
        References that no longer point to a node within the tree will be set
        to `None`. For instance, let's say a node has its grandchild as a reference.
        If the child is removed, the grandchild reference will be left dangling,
        so will be set to None.
        N)rF   r   r?   remover\   rY   r   r   )rI   r   childtreer1   refr*   r*   r-   remove_node  s   



zModel.remove_node)r5   c                C   s   t | d| t | d| d S )Nr4   r5   )rD   )rI   rC   r5   r*   r*   r-   replace_callbacks  s   zModel.replace_callbacksoldnewc                    sj   d}t | jddD ](}|u rd}q
 fdd|jD |_|jD ]}||u r1||  q"q
|S )zzReplace a node anywhere it occurs within the model. Returns a boolean
        indicating whether the replacement was made.Fr   r   Tc                    s   g | ]
}|u r
 n|qS r*   r*   )rU   r   r   r   r*   r-   rV     s    z&Model.replace_node.<locals>.<listcomp>)rF   r   r8   rY   r   r   )rI   r   r   r   r   r1   r*   r   r-   replace_node  s   
zModel.replace_nodec                 C   sH   i }|   D ]}|jD ]}||}||}||f||j|f< qq|S )zGet non-zero gradients of the model's parameters, as a dictionary
        keyed by the parameter ID. The values are (weights, gradients) tuples.
        )r   rW   r}   r   r3   )rI   	gradientsr   r1   r   r   r*   r*   r-   get_gradients  s   


zModel.get_gradientsrI   c                 C   s   |   S )z
        Create a copy of the model, its attributes, and its parameters. Any child
        layers will also be deep-copied. The copy will receive a distinct `model.id`
        value.
        )_copyrM   r*   r*   r-   copy  s   z
Model.copyr   c              
   C   sB  |d u ri }i }| j D ]}| |r| |nd ||< qg }| jD ]%}t||v r6|tt|t|  q!||}||t|< || q!g }| j	D ]$}t||v ra|tt
|t|  qL| }	|	|t|< ||	 qLt| j| j| jt|t| jt| j||d}
| jD ]}|
|| |  qtt|
S )N)r5   r>   r=   rA   r?   r@   )rR   r|   r}   r?   r3   r   r   r#   r   r@   r   r   r1   r4   r5   deepcopyr7   r:   rW   r   r   r"   )rI   r   r>   r1   copied_layersr   copied_layercopied_shimsr   copied_shimcopiedr*   r*   r-   r     s@   








zModel._copygpu_idc                 C   sH   ddl }|jj| | t  W d   dS 1 sw   Y  dS )z)Transfer the model to a given GPU device.r   N)cupy.cuda.devicecudadeviceDevice_to_opsr   )rI   r   cupyr*   r*   r-   to_gpu  s   "zModel.to_gpuc                 C   s   |  t  dS )zTransfer the model to CPU.N)r   r   rM   r*   r*   r-   to_cpu  r   zModel.to_cpuc              
   C   s   |   D ];}||_|jD ]$}||r||||| ||r0||||	| q|j
D ]
}||j|j q4qdS )z Common method for to_cpu/to_gpu.N)r   r2   rR   r|   rH   	asarray_fr}   rT   r   r   r@   	to_devicedevice_type	device_id)rI   r2   r   r1   r   r*   r*   r-   r     s   



zModel._to_opsc                 C   s.   |   }t| jjdd}tt||}t|S )aa  Serialize the model to a bytes representation. Models are usually
        serialized using msgpack, so you should be able to call msgpack.loads()
        on the data and get back a dictionary with the contents.

        Serialization should round-trip identically, i.e. the same bytes should
        result from loading and serializing a model.
        <)
byte_order)to_dictr   r2   to_numpyr   r   srslymsgpack_dumps)rI   msgto_numpy_ler*   r*   r-   to_bytes&  s   
zModel.to_bytespathc                 C   sT   t |tr	t|n|}|d}||   W d   dS 1 s#w   Y  dS )zSerialize the model to disk. Most models will serialize to a single
        file, which should just be the bytes contents of model.to_bytes().
        wbN)
isinstancestrr   openwriter   )rI   r   file_r*   r*   r-   to_disk3  s   "zModel.to_diskc              
   C   s  g g g g d}t |  }dd t|D }t|D ][\}}i }g }|jD ]#}||s1d||< q%||}	|	j|v rC||	j ||< q%|| q%|rRtd| i }
|j	D ]}|
|rc||nd|
|< qW|d ||j|
|d q|D ])}i }|j D ]\}}zt||||||< W q ty   Y qw |d | qx|D ]}|d	 d
d |jD  q|D ]'}i }|jD ]}||rttt ||||< qd||< q|d | q|S )zSerialize the model to a dict representation.

        Serialization should round-trip identically, i.e. the same dict should
        result from loading and serializing a model.
        )nodesrA   r>   r@   c                 S   s   i | ]\}}|j |qS r*   )r3   )rU   ir   r*   r*   r-   
<dictcomp>K  s    z!Model.to_dict.<locals>.<dictcomp>NzCannot get references: r   )indexr1   r=   rB   rA   r@   c                 S   s   g | ]}|  qS r*   )r   )rU   r   r*   r*   r-   rV   i  s    z!Model.to_dict.<locals>.<listcomp>r>   )rF   r   	enumeraterY   r   r   r3   r   ri   rX   rc   rl   r1   rA   rG   serialize_attr	TypeErrorr@   rR   r|   r   r   r   r}   )rI   r   r   	node_to_ir   r   rB   invalid_refsr1   r   r=   dimrA   rJ   r>   r*   r*   r-   r   ;  sP   









zModel.to_dict
bytes_datac                 C   s$   t |}tt| jj|}| |S )ae  Deserialize the model from a bytes representation. Models are usually
        serialized using msgpack, so you should be able to call msgpack.loads()
        on the data and get back a dictionary with the contents.

        Serialization should round-trip identically, i.e. the same bytes should
        result from loading and serializing a model.
        )r   msgpack_loadsr   r   r2   asarray	from_dict)rI   r   r   r*   r*   r-   
from_bytest  s   

zModel.from_bytesc                 C   sR   t |tr	t|n|}|d}| }W d   n1 sw   Y  | |S )zDeserialize the model from disk. Most models will serialize to a single
        file, which should just be the bytes contents of model.to_bytes().
        rbN)r   r   r   r   readr   )rI   r   r   r   r*   r*   r-   	from_disk  s
   

zModel.from_diskr   c                 C   sx  d|  vrd}t|t|  }t|d t|kr tdt|D ]\}}|d | }|d |_|d  D ]\}}|d urG||| q9|d  D ]\}	}
|
d u r]|	|	d  qN|	|	||
  qN|d |  D ]\}}|j
|}t||||}||j
|< qn|d |  D ]\}}|d ur|j| }||| qt|d	 | D ]\}}|j| | qq$| S )
Nr   zMTrying to read a Model that was created with an incompatible version of Thincz.Cannot deserialize model: mismatched structurer1   r=   rB   rA   r>   r@   )rQ   ri   rF   r   lenr   r1   rG   rx   r   rA   r   deserialize_attrr2   r   r   rH   r@   r   )rI   r   rk   r   r   r   infor   rJ   r   	ref_indexattrdefault_valueloaded_value
param_name
shim_bytesr*   r*   r-   r     s:   
zModel.from_dictTstrictr  c                C   sj   t |tr	t|n|}| s| sdS |d}| }W d   n1 s)w   Y  | j||dS )zCheck whether serialized data on disk is compatible with the model.
        If 'strict', the function returns False if the model has an attribute
        already loaded that would be changed.
        Fr   Nr  )r   r   r   is_direxistsr   r   can_from_bytes)rI   r   r  r   r   r*   r*   r-   can_from_disk  s   
zModel.can_from_diskc                C   s2   zt |}W n
 ty   Y dS w | j||dS )zCheck whether the bytes data is compatible with the model. If 'strict',
        the function returns False if the model has an attribute already loaded
        that would be changed.
        Fr  )r   r   ri   can_from_dict)rI   r   r  r   r*   r*   r-   r
    s   zModel.can_from_bytesc             
   C   s  d|  vrdS t|  }t|d t|krdS t|D ]\}}|d | }|r4|d |jkr4 dS t|d | t|jkrD dS |d  D ]\}}||}	|	du r[  dS |	rh|	||krh  dS qJ|d |  D ]&\}
}|
|
}|du r  dS |r|dur||
}|j|jkr  dS qq|r|d |  D ]*\}}||jv rzt|j| |j| ||}W n	 ty   Y qw ||kr  dS qqd	S )
zCheck whether a dictionary is compatible with the model.
        If 'strict', the function returns False if the model has an attribute
        already loaded that would be changed.
        r   Fr1   r@   r=   r>   NrA   T)rQ   rF   r   r   r   r1   r@   rG   rc   rl   r|   r}   shaperA   r   r   )rI   r   r  r   r   r   r   r   rJ   rc   r  r|   r   r  
serializedr*   r*   r-   r    sR   



zModel.can_from_dictotherc                 C   *   d| j  vrtd| j  d | |S )z-Apply the function bound to the '+' operator.+zUndefined operator: +r[   r   r   rI   r  r*   r*   r-   __add__     zModel.__add__c                 C   r  )z-Apply the function bound to the '-' operator.-zUndefined operator: -r  r  r*   r*   r-   __sub__  r  zModel.__sub__c                 C   r  )z-Apply the function bound to the '*' operator.*zUndefined operator: *r  r  r*   r*   r-   __mul__  r  zModel.__mul__c                 C   r  )z-Apply the function bound to the '@' operator.@zUndefined operator: @r  r  r*   r*   r-   
__matmul__  r  zModel.__matmul__c                 C   r  z-Apply the function bound to the '/' operator./zUndefined operator: /r  r  r*   r*   r-   __div__  r  zModel.__div__c                 C   r  r  r  r  r*   r*   r-   __truediv__  r  zModel.__truediv__c                 C   r  )z.Apply the function bound to the '//' operator.z//zUndefined operator: //r  r  r*   r*   r-   __floordiv__  r  zModel.__floordiv__c                 C   r  )z-Apply the function bound to the '%' operator.%zUndefined operator: %r  r  r*   r*   r-   __mod__  r  zModel.__mod__c                 K   r  )z.Apply the function bound to the '**' operator.z**zUndefined operator: **r  )rI   r  r,   r*   r*   r-   __pow__  r  zModel.__pow__c                 C   r  )z.Apply the function bound to the '<<' operator.z<<zUndefined operator: <<r  r  r*   r*   r-   
__lshift__#  r  zModel.__lshift__c                 C   r  )z.Apply the function bound to the '>>' operator.z>>zUndefined operator: >>r  r  r*   r*   r-   
__rshift__)  r  zModel.__rshift__c                 C   r  )z-Apply the function bound to the '&' operator.&zUndefined operator: &r  r  r*   r*   r-   __and__/  r  zModel.__and__c                 C   r  )z-Apply the function bound to the '^' operator.^zUndefined operator: ^r  r  r*   r*   r-   __xor__5  r  zModel.__xor__c                 C   r  )z-Apply the function bound to the '|' operator.|zUndefined operator: |r  r  r*   r*   r-   __or__;  r  zModel.__or__)NN)F)r   r#   r(   Nr)   )r(   N)k__name__
__module____qualname____doc__r/   int__annotations__	threadingLockr0   r%   r[   r   r   r   r   r   r   r
   r   r   rp   	__slots__r   r   r   r   r   rK   propertyr?   r@   rA   r   rR   rW   rX   rY   classmethodr   contextmanagerr`   rc   rl   rx   rz   r|   r}   r~   rH   rT   r   r   r   r   r   r   r   r   r    r!   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r"   r   r   r   r   r   bytesr   r   r   r   r   r   r   r  r
  r  r  r  r  r  r  r  r   r"  r#  r$  r%  r'  r)  r+  r*   r*   r*   r-   r#   1   s  
 	


' 
	 $		 

&	
)
9	$/_rJ   r1   c                 C   
   t |S )zSerialize an attribute value (defaults to msgpack). You can register
    custom serializers using the @serialize_attr.register decorator with the
    type to serialize, e.g.: @serialize_attr.register(MyCustomObject).
    )r   r   r9  rJ   r1   r'   r*   r*   r-   r   B     
r   c                 C   r:  )zDeserialize an attribute value (defaults to msgpack). You can register
    custom deserializers using the @deserialize_attr.register decorator with the
    type to deserialize, e.g.: @deserialize_attr.register(MyCustomObject).
    )r   r   r;  r*   r*   r-   r   K  r<  r   _ModelTmappingc                 C   sL   |   D ]}|j|v r#||j }| D ]\}}||jv r"||j|< qq| S )zWalk over the model's nodes, changing the value of attributes using the
    provided mapping, which maps node names to attr names to attr values.
    )r   r1   rG   rA   )r'   r>  r   rA   r  rJ   r*   r*   r-   change_attr_valuesW  s   



r?  dropout_ratedropc                 C   s0   |   D ]}|D ]}||jv r||j|< qq| S )zWalk over the model's nodes, setting the dropout rate. You can specify
    one or more attribute names, by default it looks for ["dropout_rate"].
    )r   rA   )r'   rA  rA   r   r  r*   r*   r-   set_dropout_rated  s   

rB  wrapperc                 C   s*   t |  D ]
}| ||| q|| S )zORecursively wrap a model and its submodules. The model is updated
    in-place.)rF   r   r   )r'   rC  r   r*   r*   r-   wrap_model_recursiveo  s   rD  )r#   r   r   r?  rB  rD  )r'   r#   r(   r#   )=r   r   	functoolsr2  contextvarsr   pathlibr   typingr   r   r   r   r   r	   r
   r   r   r   r   r   r   r   r   backendsr   r   r   r   r   
optimizersr   r@   r   typesr   utilr   r   r   r   r   r    r!   r"   r%   rE   r1  r.   r#   singledispatchr   r8  r   r   r=  r?  floatrB  rD  __all__r*   r*   r*   r-   <module>   sD   
 @
        & 	