o
    Ni$                     @  s  U d Z ddlmZ ddlZddlZddlZddlZddlZddlm	Z	m
Z
mZmZ ddlmZmZ ddlmZ ddlmZmZmZmZmZmZmZmZ ddlmZ ddlmZmZm Z m!Z!m"Z"m#Z# dd	l$m%Z%m&Z&m'Z'm(Z(m)Z)m*Z*m+Z+ dd
l,m-Z-m.Z.m/Z/ erddl0Z0ddl1m2Z2m3Z3m4Z4 ddl$m5Z5m6Z6m7Z7m8Z8m9Z9 ede:e7 dZ;g dZ<ej=dkrddini Z>ej?dddddde>G dd dee) Z@[>G dd dZAeA ZBdeCd< e ZDdeCd< [AerddlEmFZF eFdZGedZHeFdZIed ZJG d!d" d"eeGeHeIeJf ZKdd'd(ZLed)d*dd/d0ZMe	dd)d*dd3d0ZM	dd)d*dd6d0ZMeLeMed7dd?d@ZN[M[Le	ddddAddEdFZOeddGdFZO	ddddAddJdFZOddKdLZPejQddPdQZRddUdVZSddYdZZTdd]d^ZUddbdcZVddedfZWddjdkZXddldmZYddqdrZZddwdxZ[ddydzZ\dd{d|Z]dd~dZ^dddZ_dddZ`dddZadddZbdddZcdddZddddZedddZfdddZgdddZhdddZie:de@e:deTeUe!e'jjdeke@ekeVeWe"e'jldeme@emeXeYe"e'jndeoe@eoeZe[ee'jpdee@eefege e'jqde	e@e	e^e_ee'jrde
e@e
e`eaee'jsdee@eedeee"e'jtde(e@e(eheie#e'judi	ZvdeCd< e@eoe\e]ee'jpdZwe@e
ebecee'jsdZxdS )z&Registry for custom pytree node types.    )annotationsN)OrderedDictdefaultdictdeque
namedtuple)
itemgettermethodcaller)Lock)TYPE_CHECKINGAnyCallableClassVarGeneric
NamedTupleTypeVaroverload)	AutoEntryMappingEntryNamedTupleEntryPyTreeEntrySequenceEntryStructSequenceEntry)ChildrenMetaData
PyTreeKindStructSequenceTis_namedtuple_classis_structseq_class)safe_ziptotal_order_sortedunzip2)
Collection	GeneratorIterable)KTVTCustomTreeNodeFlattenFuncUnflattenFuncCustomTreeNodeType)bound)register_pytree_noderegister_pytree_node_classunregister_pytree_nodedict_insertion_ordered   
   slotsT)initrepreqfrozenc                   @  sb   e Zd ZU dZded< ded< ded< ejdkrd	ed
< eZded< e	j
Zded< dZded< dS )PyTreeNodeRegistryEntryz>A dataclass that stores the information of a pytree node type.zbuiltins.type[Collection[T]]typeFlattenFunc[T]flatten_funcUnflattenFunc[T]unflatten_funcr0   zdataclasses.KW_ONLY_zbuiltins.type[PyTreeEntry]path_entry_typer   kind str	namespaceN)__name__
__module____qualname____doc____annotations__sysversion_infor   r?   r   CUSTOMr@   rC    rL   rL   C/home/ubuntu/.local/lib/python3.10/site-packages/optree/registry.pyr8   I   s   
 
r8   c                   @  s$   e Zd ZU dZded< d	ddZdS )
GlobalNamespacerL   zClassVar[tuple[()]]	__slots__returnrB   c                C     dS )Nz<GLOBAL NAMESPACE>rL   selfrL   rL   rM   __repr__`      zGlobalNamespace.__repr__N)rP   rB   )rD   rE   rF   rO   rH   rT   rL   rL   rL   rM   rN   ]   s   
 rN   rB   __GLOBAL_NAMESPACEr	   __REGISTRY_LOCK)	ParamSpec_P_T_GetP_GetTc                   @  s    e Zd ZdddZdddZdS )_CallableWithGetargs_P.argskwargs	_P.kwargsrP   rZ   c                O     t NNotImplementedErrorrS   r^   r`   rL   rL   rM   __call__r   rU   z_CallableWithGet.__call__
_GetP.args_GetP.kwargsr\   c                O  rb   rc   rd   rf   rL   rL   rM   getv   rU   z_CallableWithGet.getN)r^   r_   r`   ra   rP   rZ   )r^   rh   r`   ri   rP   r\   )rD   rE   rF   rg   rj   rL   rL   rL   rM   r]   q   s    
r]   rj   Callable[_GetP, _GetT]rP   DCallable[[Callable[_P, _T]], _CallableWithGet[_P, _T, _GetP, _GetT]]c                  s   d fdd}|S )NfuncCallable[_P, _T]rP   &_CallableWithGet[_P, _T, _GetP, _GetT]c                  s
    | _ | S rc   rj   )rm   rp   rL   rM   	decorator   s   z_add_get.<locals>.decorator)rm   rn   rP   ro   rL   )rj   rq   rL   rp   rM   _add_getz   s   rr   rA   )rC   clsr9   rC   PyTreeNodeRegistryEntry | Nonec               C     d S rc   rL   rs   rC   rL   rL   rM   pytree_node_registry_get      rw   None#dict[type, PyTreeNodeRegistryEntry]c               C  ru   rc   rL   rv   rL   rL   rM   rw      rx   type | NoneDdict[type, PyTreeNodeRegistryEntry] | PyTreeNodeRegistryEntry | Nonec                 s>  |t u rd}| dur| turt| std| dt|ts(td|d| du r^t|dh t  fddt	
 D }W d   n1 sJw   Y  t|r\t|t< t|t< |S |dkrot	|| f}|duro|S t|r| tu rztS | tu rtS t	| }|dur|S t| rt	tS t| rt	tS dS )a  Look up the pytree node registry.

    >>> register_pytree_node.get()  # doctest: +IGNORE_WHITESPACE,ELLIPSIS
    {
        <class 'NoneType'>: PyTreeNodeRegistryEntry(
            type=<class 'NoneType'>,
            flatten_func=<function ...>,
            unflatten_func=<function ...>,
            path_entry_type=<class 'optree.PyTreeEntry'>,
            kind=<PyTreeKind.NONE: 2>,
            namespace=''
        ),
        <class 'tuple'>: PyTreeNodeRegistryEntry(
            type=<class 'tuple'>,
            flatten_func=<function ...>,
            unflatten_func=<function ...>,
            path_entry_type=<class 'optree.SequenceEntry'>,
            kind=<PyTreeKind.TUPLE: 3>,
            namespace=''
        ),
        <class 'list'>: PyTreeNodeRegistryEntry(
            type=<class 'list'>,
            flatten_func=<function ...>,
            unflatten_func=<function ...>,
            path_entry_type=<class 'optree.SequenceEntry'>,
            kind=<PyTreeKind.LIST: 4>,
            namespace=''
        ),
        ...
    }
    >>> register_pytree_node.get(defaultdict)  # doctest: +IGNORE_WHITESPACE,ELLIPSIS
    PyTreeNodeRegistryEntry(
        type=<class 'collections.defaultdict'>,
        flatten_func=<function ...>,
        unflatten_func=<function ...>,
        path_entry_type=<class 'optree.MappingEntry'>,
        kind=<PyTreeKind.DEFAULTDICT: 8>,
        namespace=''
    )
    >>> register_pytree_node.get(frozenset)  # frozenset is considered as a leaf node
    None

    Args:
        cls (type or None, optional): The class of the pytree node to retrieve. If not provided, all
            the registered pytree nodes in the namespace are returned.
        namespace (str, optional): The namespace of the registry to retrieve. If not provided, the
            global namespace is used.

    Returns:
        If the ``cls`` is not provided, a dictionary of all the registered pytree nodes in the
        namespace is returned. If the ``cls`` is provided, the corresponding registry entry is
        returned if the ``cls`` is registered as a pytree node. Otherwise, :data:`None` is returned,
        i.e., the ``cls`` is represented as a leaf node.
    rA   NzExpected a class or None, got .$The namespace must be a string, got c                   s   i | ]}|j  v r|j|qS rL   )rC   r9   ).0handler
namespacesrL   rM   
<dictcomp>   s
    
z,pytree_node_registry_get.<locals>.<dictcomp>)rV   r   inspectisclass	TypeError
isinstancerB   	frozensetrW   _NODETYPE_REGISTRYvalues_Cis_dict_insertion_ordered&_DICT_INSERTION_ORDERED_REGISTRY_ENTRYdict-_DEFAULTDICT_INSERTION_ORDERED_REGISTRY_ENTRYr   rj   r   r   r   )rs   rC   registryr   rL   r   rM   rw      sN   <







)r?   r;   r:   r=   r<   type[Collection[T]]r?   type[PyTreeEntry]c               C  s   t | std| dt |rt|tstd|d|tur0t|ts0td|d|dkr8td|tu rA| }d}n|| f}t	 t
| |||| t| ||||dt|< W d   | S 1 sgw   Y  | S )	a  Extend the set of types that are considered internal nodes in pytrees.

    See also :func:`register_pytree_node_class` and :func:`unregister_pytree_node`.

    The ``namespace`` argument is used to avoid collisions that occur when different libraries
    register the same Python type with different behaviors. It is recommended to add a unique prefix
    to the namespace to avoid conflicts with other libraries. Namespaces can also be used to specify
    the same class in different namespaces for different use cases.

    .. warning::
        For safety reasons, a ``namespace`` must be specified while registering a custom type. It is
        used to isolate the behavior of flattening and unflattening a pytree node type. This is to
        prevent accidental collisions between different libraries that may register the same type.

    Args:
        cls (type): A Python type to treat as an internal pytree node.
        flatten_func (callable): A function to be used during flattening, taking an instance of ``cls``
            and returning a triple or optionally a pair, with (1) an iterable for the children to be
            flattened recursively, and (2) some hashable metadata to be stored in the treespec and
            to be passed to the ``unflatten_func``, and (3) (optional) an iterable for the tree path
            entries to the corresponding children. If the entries are not provided or given by
            :data:`None`, then `range(len(children))` will be used.
        unflatten_func (callable): A function taking two arguments: the metadata that was returned
            by ``flatten_func`` and stored in the treespec, and the unflattened children. The
            function should return an instance of ``cls``.
        path_entry_type (type, optional): The type of the path entry to be used in the treespec.
            (default: :class:`AutoEntry`)
        namespace (str): A non-empty string that uniquely identifies the namespace of the type registry.
            This is used to isolate the registry from other modules that might register a different
            custom behavior for the same type.

    Returns:
        The same type as the input ``cls``.

    Raises:
        TypeError: If the input type is not a class.
        TypeError: If the path entry class is not a subclass of :class:`PyTreeEntry`.
        TypeError: If the namespace is not a string.
        ValueError: If the namespace is an empty string.
        ValueError: If the type is already registered in the registry.

    .. versionadded:: 0.12.0
        The ``path_entry_type`` argument to specify the path entry type used in
        :meth:`PyTreeSpec.accessors` and :func:`tree_flatten_with_accessor`.
        If not provided, :class:`AutoEntry` will be used.

    Examples:
        >>> # Register a Python type with lambda functions
        >>> register_pytree_node(
        ...     set,
        ...     lambda s: (sorted(s), None, None),
        ...     lambda _, children: set(children),
        ...     namespace='set',
        ... )
        <class 'set'>

        >>> # Register a Python type into a namespace
        >>> import torch
        >>> register_pytree_node(
        ...     torch.Tensor,
        ...     flatten_func=lambda tensor: (
        ...         (tensor.cpu().detach().numpy(),),
        ...         {'dtype': tensor.dtype, 'device': tensor.device, 'requires_grad': tensor.requires_grad},
        ...     ),
        ...     unflatten_func=lambda metadata, children: torch.tensor(children[0], **metadata),
        ...     namespace='torch2numpy',
        ... )
        <class 'torch.Tensor'>

        >>> # doctest: +SKIP
        >>> tree = {'weight': torch.ones(size=(1, 2)).cuda(), 'bias': torch.zeros(size=(2,))}
        >>> tree
        {'weight': tensor([[1., 1.]], device='cuda:0'), 'bias': tensor([0., 0.])}

        >>> # Flatten without specifying the namespace
        >>> tree_flatten(tree)  # `torch.Tensor`s are leaf nodes
        ([tensor([0., 0.]), tensor([[1., 1.]], device='cuda:0')], PyTreeSpec({'bias': *, 'weight': *}))

        >>> # Flatten with the namespace
        >>> tree_flatten(tree, namespace='torch2numpy')
        (
            [array([0., 0.], dtype=float32), array([[1., 1.]], dtype=float32)],
            PyTreeSpec(
                {
                    'bias': CustomTreeNode(Tensor[{'dtype': torch.float32, 'device': device(type='cpu'), 'requires_grad': False}], [*]),
                    'weight': CustomTreeNode(Tensor[{'dtype': torch.float32, 'device': device(type='cuda', index=0), 'requires_grad': False}], [*])
                },
                namespace='torch2numpy'
            )
        )

        >>> # Register the same type with a different namespace for different behaviors
        >>> def tensor2flatparam(tensor):
        ...     return [torch.nn.Parameter(tensor.reshape(-1))], tensor.shape, None
        ...
        ... def flatparam2tensor(metadata, children):
        ...     return children[0].reshape(metadata)
        ...
        ... register_pytree_node(
        ...     torch.Tensor,
        ...     flatten_func=tensor2flatparam,
        ...     unflatten_func=flatparam2tensor,
        ...     namespace='tensor2flatparam',
        ... )
        <class 'torch.Tensor'>

        >>> # Flatten with the new namespace
        >>> tree_flatten(tree, namespace='tensor2flatparam')
        (
            [
                Parameter containing: tensor([0., 0.], requires_grad=True),
                Parameter containing: tensor([1., 1.], device='cuda:0', requires_grad=True)
            ],
            PyTreeSpec(
                {
                    'bias': CustomTreeNode(Tensor[torch.Size([2])], [*]),
                    'weight': CustomTreeNode(Tensor[torch.Size([1, 2])], [*])
                },
                namespace='tensor2flatparam'
            )
        )
    Expected a class, got r}   (Expected a subclass of PyTreeEntry, got r~   rA   (The namespace cannot be an empty string.r?   rC   N)r   r   r   
issubclassr   rV   r   rB   
ValueErrorrW   r   register_noder8   r   )rs   r;   r=   r?   rC   registration_keyrL   rL   rM   r,     s@    

r,   r   
str | Nonetype[PyTreeEntry] | None2Callable[[CustomTreeNodeType], CustomTreeNodeType]c               C  ru   rc   rL   rs   r?   rC   rL   rL   rM   r-        r-   c               C  ru   rc   rL   r   rL   rL   rM   r-     r   CustomTreeNodeType | str | NoneGCustomTreeNodeType | Callable[[CustomTreeNodeType], CustomTreeNodeType]c                 s   t u s	t trdurtd dkrtdd  du r&tdt ur6tts6tddkr>td du rLd%fd
d}|S t sYtd ddu rct dttrmt	t
sutddt fdddD st fdddD st dt jd&dd}ttt jd jd'd!d"}| _| _t td# jd$  S )(a  Extend the set of types that are considered internal nodes in pytrees.

    See also :func:`register_pytree_node` and :func:`unregister_pytree_node`.

    The ``namespace`` argument is used to avoid collisions that occur when different libraries
    register the same Python type with different behaviors. It is recommended to add a unique prefix
    to the namespace to avoid conflicts with other libraries. Namespaces can also be used to specify
    the same class in different namespaces for different use cases.

    .. warning::
        For safety reasons, a ``namespace`` must be specified while registering a custom type. It is
        used to isolate the behavior of flattening and unflattening a pytree node type. This is to
        prevent accidental collisions between different libraries that may register the same type.

    Args:
        cls (type, optional): A Python type to treat as an internal pytree node.
        path_entry_type (type, optional): The type of the path entry to be used in the treespec.
            (default: :class:`AutoEntry`)
        namespace (str, optional): A non-empty string that uniquely identifies the namespace of the
            type registry. This is used to isolate the registry from other modules that might
            register a different custom behavior for the same type.

    Returns:
        The same type as the input ``cls`` if the argument presents. Otherwise, return a decorator
        function that registers the class as a pytree node.

    Raises:
        TypeError: If the path entry class is not a subclass of :class:`PyTreeEntry`.
        TypeError: If the namespace is not a string.
        TypeError: If the class does not define the required method pairs.
        ValueError: If the namespace is an empty string.
        ValueError: If the type is already registered in the registry.

    .. versionadded:: 0.12.0
        The ``TREE_PATH_ENTRY_TYPE`` class variable to specify the path entry type used in
        :meth:`PyTreeSpec.accessors` and :func:`tree_flatten_with_accessor`.
        If not provided, :class:`AutoEntry` will be used.

    .. versionadded:: 0.18.0
        Previously, this function looked for methods named ``tree_flatten`` and ``tree_unflatten``
        for the given class. Since version 0.18.0, it prefers methods named ``__tree_flatten__``
        and ``__tree_unflatten__`` instead. The old method names are still supported for
        backward compatibility, but it is recommended to use the new method names.
        The method resolution follows this priority:
        1. If both ``__tree_flatten__`` and ``__tree_unflatten__`` are defined, use them directly.
        2. If both ``tree_flatten`` and ``tree_unflatten`` are defined, wrap them as dunder methods.
        3. If neither complete pair is available, raise a :exc:`TypeError` suggesting the new method names.

    This function is a thin wrapper around :func:`register_pytree_node`, and provides a
    class-oriented interface:

    .. code-block:: python

        @register_pytree_node_class(namespace='foo')
        class Special:
            TREE_PATH_ENTRY_TYPE = GetAttrEntry

            def __init__(self, x, y):
                self.x = x
                self.y = y

            def __tree_flatten__(self):
                return ((self.x, self.y), None, ('x', 'y'))

            @classmethod
            def __tree_unflatten__(cls, metadata, children):
                return cls(*children)

        @register_pytree_node_class('mylist')
        class MyList(UserList):
            TREE_PATH_ENTRY_TYPE = SequenceEntry

            def __tree_flatten__(self):
                return self.data, None, None

            @classmethod
            def __tree_unflatten__(cls, metadata, children):
                return cls(*children)

        # Legacy style (still supported but not recommended)
        @register_pytree_node_class(namespace='legacy')
        class LegacyStyleMyList(UserList):
            def tree_flatten(self):
                # Implementation automatically wrapped as __tree_flatten__
                return self.data, None, None

            @classmethod
            def tree_unflatten(cls, metadata, children):
                # Implementation automatically wrapped as __tree_unflatten__
                return cls(*children)
    Nz?Cannot specify `namespace` when the first argument is a string.rA   r   z<Must specify `namespace` when the first argument is a class.r~   rs   r*   rP   c                  s   t |  dS )Nr   )r-   rs   )rC   r?   rL   rM   rq   6  s
   z-register_pytree_node_class.<locals>.decoratorr   r}   TREE_PATH_ENTRY_TYPEr   c                 3       | ]}t t |d V  qd S rc   callablegetattrr   methodr   rL   rM   	<genexpr>G  s
    
z-register_pytree_node_class.<locals>.<genexpr>)__tree_flatten____tree_unflatten__c                 3  r   rc   r   r   r   rL   rM   r   L  s    
)tree_flattentree_unflattenzh must define both `__tree_flatten__` and `__tree_unflatten__` methods for registration as a pytree node.rS   CustomTreeNode[T]Qtuple[Children[T], MetaData] | tuple[Children[T], MetaData, Iterable[Any] | None]c                S  s   |   S rc   )r   rR   rL   rL   rM   r   W     z4register_pytree_node_class.<locals>.__tree_flatten____func__type[CustomTreeNode[T]]metadatar   childrenChildren[T]c                S  s   |  ||S rc   )r   )rs   r   r   rL   rL   rM   r   ^  s   z6register_pytree_node_class.<locals>.__tree_unflatten__r   r   )rs   r*   rP   r*   )rS   r   rP   r   )rs   r   r   r   r   r   rP   r   )rV   r   rB   r   r   r   r   r   r   r   r   all	functoolswrapsr   classmethodr   r   r   r,   r   )rs   r?   rC   rq   r   r   rL   )rs   rC   r?   rM   r-     sZ   b



c               C  s   t | std| d|turt|tstd|d|dkr&td|tu r/| }d}n|| f}t t	| | t
|W  d   S 1 sKw   Y  dS )a  Remove a type from the pytree node registry.

    See also :func:`register_pytree_node` and :func:`register_pytree_node_class`.

    This function is the inverse operation of function :func:`register_pytree_node`.

    Args:
        cls (type): A Python type to remove from the pytree node registry.
        namespace (str): The namespace of the pytree node registry to remove the type from.

    Returns:
        The removed registry entry.

    Raises:
        TypeError: If the input type is not a class.
        TypeError: If the namespace is not a string.
        ValueError: If the namespace is an empty string.
        ValueError: If the type is a built-in type that cannot be unregistered.
        ValueError: If the type is not found in the registry.

    Examples:
        >>> # Register a Python type with lambda functions
        >>> register_pytree_node(
        ...     set,
        ...     lambda s: (sorted(s), None, None),
        ...     lambda _, children: set(children),
        ...     namespace='temp',
        ... )
        <class 'set'>

        >>> # Unregister the Python type
        >>> unregister_pytree_node(set, namespace='temp')
    r   r}   r~   rA   r   N)r   r   r   rV   r   rB   r   rW   r   unregister_noder   pop)rs   rC   r   rL   rL   rM   r.   w  s   
"$r.   modeboolGenerator[None]c               c  s    |t urt|tstd|d|dkrtd|t u r d}t tj|dd}tt	| | W d   n1 s<w   Y  zdV  W t t|| W d   dS 1 sZw   Y  dS t t|| W d   w 1 stw   Y  w )a  Context manager to temporarily set the dictionary sorting mode.

    This context manager is used to temporarily set the dictionary sorting mode for a specific
    namespace. The dictionary sorting mode is used to determine whether the keys of a dictionary
    should be sorted or keep the insertion order when flattening a pytree.

    >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
    >>> tree_flatten(tree)  # doctest: +IGNORE_WHITESPACE
    (
        [1, 2, 3, 4, 5],
        PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': None, 'd': *})
    )
    >>> with dict_insertion_ordered(True, namespace='some-namespace'):  # doctest: +IGNORE_WHITESPACE
    ...     tree_flatten(tree, namespace='some-namespace')
    (
        [2, 3, 4, 1, 5],
        PyTreeSpec({'b': (*, [*, *]), 'a': *, 'c': None, 'd': *}, namespace='some-namespace')
    )

    .. warning::
        The dictionary sorting mode is a global setting and is **not thread-safe**. It is
        recommended to use this context manager in a single-threaded environment.

    Args:
        mode (bool): The dictionary sorting mode to set.
        namespace (str): The namespace to set the dictionary sorting mode for.
    r~   r}   rA   r   F)inherit_global_namespaceN)
rV   r   rB   r   r   rW   r   r   set_dict_insertion_orderedr   )r   rC   prevrL   rL   rM   r/     s$   (r/   itemsIterable[tuple[KT, VT]]list[tuple[KT, VT]]c                C  s   t | tddS )Nr   )key)r    r   )r   rL   rL   rM   _sorted_items  s   r   r>   tuple[tuple[()], None]c                C  rQ   )N)rL   NrL   )r>   rL   rL   rM   _none_flatten  rU   r   r   Iterable[Any]c                C  s$   t  }tt|||urtdd S )NzExpected no children.)objectnextiterr   )r>   r   sentinelrL   rL   rM   _none_unflatten  s   r   tuptuple[T, ...]tuple[tuple[T, ...], None]c                C     | d fS rc   rL   r   rL   rL   rM   _tuple_flatten     r   Iterable[T]c                C     t |S rc   )tupler>   r   rL   rL   rM   _tuple_unflatten  r   r   lstlist[T]tuple[list[T], None]c                C  r   rc   rL   )r   rL   rL   rM   _list_flatten  r   r   c                C  r   rc   )listr   rL   rL   rM   _list_unflatten  r   r   dctdict[KT, VT]/tuple[tuple[VT, ...], list[KT], tuple[KT, ...]]c                C  s"   t t|  \}}|t||fS rc   )r!   r   r   r   r   keysr   rL   rL   rM   _dict_flatten  s   r   r   list[KT]r   Iterable[VT]c                C     t t| |S rc   r   r   r   r   rL   rL   rM   _dict_unflatten     r   c                C     t |  \}}|t||fS rc   r!   r   r   r   rL   rL   rM   _dict_insertion_ordered_flatten     r   c                C  r   rc   r   r   rL   rL   rM   !_dict_insertion_ordered_unflatten  r   r   OrderedDict[KT, VT]c                C  r   rc   r   r   rL   rL   rM   _ordereddict_flatten  r   r   c                C  r   rc   )r   r   r   rL   rL   rM   _ordereddict_unflatten  r   r   defaultdict[KT, VT]Otuple[tuple[VT, ...], tuple[Callable[[], VT] | None, list[KT]], tuple[KT, ...]]c                C     t | \}}}|| j|f|fS rc   )r   default_factoryr   r   dict_metadataentriesrL   rL   rM   _defaultdict_flatten"     r   r   !tuple[Callable[[], VT], list[KT]]c                C     | \}}t |t||S rc   )r   r   r   r   r   r   rL   rL   rM   _defaultdict_unflatten.     r  c                C  r   rc   )r   r   r   rL   rL   rM   &_defaultdict_insertion_ordered_flatten7  r   r  c                C  r  rc   )r   r   r  rL   rL   rM   (_defaultdict_insertion_ordered_unflattenC  r  r  deqdeque[T]tuple[deque[T], int | None]c                C  s
   | | j fS rc   maxlen)r  rL   rL   rM   _deque_flattenL  s   
r  r  
int | Nonec                C  s   t || dS )Nr
  )r   )r  r   rL   rL   rM   _deque_unflattenP     r  NamedTuple[T])tuple[tuple[T, ...], type[NamedTuple[T]]]c                C     | t | fS rc   r9   r   rL   rL   rM   _namedtuple_flattenT  r  r  type[NamedTuple[T]]c                C  s   | | S rc   rL   rs   r   rL   rL   rM   _namedtuple_unflattenY  r   r  seqStructSequence[T]-tuple[tuple[T, ...], type[StructSequence[T]]]c                C  r  rc   r  )r  rL   rL   rM   _structseq_flatten]  r  r  type[StructSequence[T]]c                C  s   | |S rc   rL   r  rL   rL   rM   _structseq_unflattena  r   r  )r?   r@   z6dict[type | tuple[str, type], PyTreeNodeRegistryEntry]r   rL   )rj   rk   rP   rl   )rs   r9   rC   rB   rP   rt   rc   )rs   ry   rC   rB   rP   rz   )rs   r{   rC   rB   rP   r|   )r;   r:   r=   r<   rs   r   r?   r   rC   rB   rP   r   )rs   r   r?   r   rC   r   rP   r   )rs   r*   r?   r   rC   rB   rP   r*   )rs   r   r?   r   rC   r   rP   r   )rs   r9   rC   rB   rP   r8   )r   r   rC   rB   rP   r   )r   r   rP   r   )r>   ry   rP   r   )r>   ry   r   r   rP   ry   )r   r   rP   r   )r>   ry   r   r   rP   r   )r   r   rP   r   )r>   ry   r   r   rP   r   )r   r   rP   r   )r   r   r   r   rP   r   )r   r   rP   r   )r   r   r   r   rP   r   )r   r   rP   r   )r   r   r   r   rP   r   )r  r  rP   r	  )r  r  r   r   rP   r  )r   r  rP   r  )rs   r  r   r   rP   r  )r  r  rP   r  )rs   r  r   r   rP   r  )yrG   
__future__r   
contextlibdataclassesr   r   rI   collectionsr   r   r   r   operatorr   r   	threadingr	   typingr
   r   r   r   r   r   r   r   	optree._Cr   optree.accessorsr   r   r   r   r   r   optree.typingr   r   r   r   r   r   r   optree.utilsr   r    r!   builtinscollections.abcr"   r#   r$   r%   r&   r'   r(   r)   r9   r*   __all__rJ   SLOTS	dataclassr8   rN   rV   rH   rW   typing_extensionsrX   rY   rZ   r[   r\   r]   rr   rw   r,   r-   r.   contextmanagerr/   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r  r  r  r  r  r  r  r  r  NONEr   TUPLEr   LISTr   DICT
NAMEDTUPLEORDEREDDICTDEFAULTDICTDEQUESTRUCTSEQUENCEr   r   r   rL   rL   rL   rM   <module>   sX  ( $	
	
k &	 
45
.














	

	




	C
