o
    cij                     @   s  d dl Z d dlZd dlmZ d dlZd dlZd dlm	Z	 d dl
mZ d dlmZ d dlmZmZmZ d dlmZmZ d dlmZmZ d d	lmZ d d
lmZmZmZmZmZ e \Z Z!Z"e Z#eG dd deZ$eG dd de$Z%ede&fddZ'eG dd de$Z(eG dd de%Z)eG dd de$Z*eG dd de$Z+eG dd de$Z,eG dd de$Z-eG dd de$Z.eG d d! d!e$Z/eG d"d# d#e$Z0dS )$    N)log)Optional)ActionDistribution)ModelV2)MIN_LOG_NN_OUTPUTMAX_LOG_NN_OUTPUTSMALL_NUMBER)OldAPIStackoverride)try_import_tftry_import_tfp)get_base_struct_from_space)
TensorTypeListUnionTupleModelConfigDictc                       sp   e Zd ZdZeedee def fddZ	defddZ
eedefd	d
ZeedefddZ  ZS )TFActionDistributionz9TF-specific extensions for building action distributions.inputsmodelc                    s*   t  || |  | _| | j| _d S N)super__init___build_sample_op	sample_oplogpsampled_action_logp_op)selfr   r   	__class__ V/home/ubuntu/.local/lib/python3.10/site-packages/ray/rllib/models/tf/tf_action_dist.pyr      s   
zTFActionDistribution.__init__returnc                 C   s   t )zImplement this instead of sample(), to enable op reuse.

        This is needed since the sample op is non-deterministic and is shared
        between sample() and sampled_action_logp().
        )NotImplementedErrorr   r    r    r!   r      s   z%TFActionDistribution._build_sample_opc                 C      | j S )z+Draw a sample from the action distribution.)r   r$   r    r    r!   sample&      zTFActionDistribution.samplec                 C   r%   )z2Returns the log probability of the sampled action.)r   r$   r    r    r!   sampled_action_logp+   r'   z(TFActionDistribution.sampled_action_logp)__name__
__module____qualname____doc__r
   r   r   r   r   r   r   r&   r(   __classcell__r    r    r   r!   r      s    r   c                       s   e Zd ZdZ	ddee dedef fddZe	e
d	efd
dZe	e
ded	efddZe	e
d	efddZe	e
de
d	efddZe	ed	efddZee	e
dd Z  ZS )Categoricalz4Categorical distribution for discrete action spaces.N      ?r   r   temperaturec                    s&   |dksJ dt  || | d S )N        (Categorical `temperature` must be > 0.0!r   r   r   r   r   r0   r   r    r!   r   5   s   zCategorical.__init__r"   c                 C   s   t jj| jddS N   axis)tfmathargmaxr   r$   r    r    r!   deterministic_sample=      z Categorical.deterministic_samplexc                 C   s   t jj| jt |t jd S )N)logitslabels)r9   nn(sparse_softmax_cross_entropy_with_logitsr   castint32r   r>   r    r    r!   r   A   s   zCategorical.logpc                 C   sX   | j tj| j ddd }t|}tj|ddd}|| }tj|tj||  ddS Nr6   Tr8   keepdimsr7   r   r9   
reduce_maxexp
reduce_sumr:   r   )r   a0ea0z0p0r    r    r!   entropyG   s
   
zCategorical.entropyotherc           	      C   s   | j tj| j ddd }|j tj|j ddd }t|}t|}tj|ddd}tj|ddd}|| }tj||tj| | tj|  ddS rF   rI   )	r   rR   rM   a1rN   ea1rO   z1rP   r    r    r!   klO   s   

.zCategorical.klc                 C   s   t jt j| jdddS r5   )r9   squeezerandomcategoricalr   r$   r    r    r!   r   Z      zCategorical._build_sample_opc                 C   r%   r   naction_spacemodel_configr    r    r!   required_model_output_shape^   r'   z'Categorical.required_model_output_shapeNr/   )r)   r*   r+   r,   r   r   r   floatr   r
   r   r<   r   rQ   rV   r   r   staticmethodr`   r-   r    r    r   r!   r.   1   s.    
r.   tc                    s   G  fdddt }|S )zGCategorical distribution class that has customized default temperature.c                       s"   e Zd Zdf fdd	Z  ZS )zJget_categorical_class_with_temperature.<locals>.CategoricalWithTemperatureNc                    s   t  ||| d S r   r3   r4   r   r    r!   r   i   s   zSget_categorical_class_with_temperature.<locals>.CategoricalWithTemperature.__init__)r)   r*   r+   r   r-   r    rd   r   r!   CategoricalWithTemperatureh   s    rf   r.   )rd   rf   r    re   r!   &get_categorical_class_with_temperatured   s   rh   c                   @   s  e Zd ZdZ	ddee dedeee e	j
eedf f fddZeed	efd
dZeeded	efddZeed	efddZeed	efddZeeded	efddZeeded	efddZeed	efddZeeedejded	eee	j
f fddZdS )MultiCategoricalz>MultiCategorical distribution for MultiDiscrete action spaces.Nr   r   
input_lens.c                    st   t | |   fddtj||ddD | _|| _| jd u r,tjdd | jD | _| 	 | _
| | j
| _d S )Nc                    s   g | ]}t | qS r    rg   ).0input_r   r    r!   
<listcomp>|   s    z-MultiCategorical.__init__.<locals>.<listcomp>r6   r7   c                 S   s   g | ]}|j jd  qS )r6   )r   shape)rk   cr    r    r!   rn      s    )r   r   r9   splitcatsr^   gymspacesMultiDiscreter   r   r   r   )r   r   r   rj   r^   r    rm   r!   r   s   s   


zMultiCategorical.__init__r"   c                 C   sT   t jdd | jD dd}t| jtjjr(t t 	|dgt
| jj | jjS |S )Nc                 S      g | ]}|  qS r    r<   rk   catr    r    r!   rn          z9MultiCategorical.deterministic_sample.<locals>.<listcomp>r6   r7   r9   stackrr   
isinstancer^   rs   rt   BoxrC   reshapelistro   dtype)r   sample_r    r    r!   r<      s   z%MultiCategorical.deterministic_sampleactionsc                 C   s   t |tjr<t | jtjjrt|dtt	
| jjg}nt | jtjjr0|d t| jf tjt|tjdd}tdd t| j|D }tj|ddS )Nr{   r6   r7   c                 S      g | ]	\}}| |qS r    )r   )rk   ry   actr    r    r!   rn          z)MultiCategorical.logp.<locals>.<listcomp>r   )r~   r9   Tensorr^   rs   rt   r   r   intnpprodro   ru   	set_shapelenrr   unstackrC   rD   r}   ziprL   )r   r   logpsr    r    r!   r      s   zMultiCategorical.logpc                 C   s   t jdd | jD ddS )Nc                 S   rv   r    rQ   rx   r    r    r!   rn      rz   z2MultiCategorical.multi_entropy.<locals>.<listcomp>r6   r7   )r9   r}   rr   r$   r    r    r!   multi_entropy   rZ   zMultiCategorical.multi_entropyc                 C   s   t j|  ddS r5   )r9   rL   r   r$   r    r    r!   rQ      r=   zMultiCategorical.entropyrR   c                 C   s"   t jdd t| j|jD ddS )Nc                 S   r   r    rV   )rk   ry   oth_catr    r    r!   rn      r   z-MultiCategorical.multi_kl.<locals>.<listcomp>r6   r7   )r9   r}   r   rr   r   rR   r    r    r!   multi_kl   s   zMultiCategorical.multi_klc                 C   s   t j| |ddS r5   )r9   rL   r   r   r    r    r!   rV      s   zMultiCategorical.klc                 C   sV   t jdd | jD dd}t| jtjjr)t jt 	|dgt
| jj | jjdS |S )Nc                 S   rv   r    r&   rx   r    r    r!   rn      rz   z5MultiCategorical._build_sample_op.<locals>.<listcomp>r6   r7   r{   r   r|   )r   r   r    r    r!   r      s   z!MultiCategorical._build_sample_opr^   r_   c                 C   s   t | tjjr?| jjdsJ t| j	}t
| j}t| j	|ks&J t| j|ks0J tj| jtjd|| d  S t| jS )Nr   r   r6   )r~   rs   rt   r   r   name
startswithr   minlowmaxhighallr   ro   rD   sumnvec)r^   r_   low_high_r    r    r!   r`      s   z,MultiCategorical.required_model_output_shaper   )r)   r*   r+   r,   r   r   r   r   r   r   ndarrayr   r   r
   r   r<   r   r   rQ   r   rV   r   r   rc   rs   Spacer   r`   r    r    r    r!   ri   o   sB    
		ri   c                
       s|   e Zd ZdZ				ddee dededee	j
j f fdd	Zeed
ef fddZeeded
efddZ  ZS )SlateMultiCategoricalaf  MultiCategorical distribution for MultiDiscrete action spaces.

    The action space must be uniform, meaning all nvec items have the same size, e.g.
    MultiDiscrete([10, 10, 10]), where 10 is the number of candidates to pick from
    and 3 is the slate size (pick 3 out of 10). When picking candidates, no candidate
    must be picked more than once.
    Nr/   r   r   r0   r^   c                    s`   |dksJ dt  || | | _t jtjjr)t fdd jjD s+J | _	d S )Nr1   r2   c                 3   s     | ]}| j jd  kV  qdS )r   N)r^   r   )rk   r\   r$   r    r!   	<genexpr>   s    
z1SlateMultiCategorical.__init__.<locals>.<genexpr>)
r   r   r^   r~   rs   rt   ru   r   r   
all_slates)r   r   r   r0   r^   r   r   r$   r!   r      s   
zSlateMultiCategorical.__init__r"   c                    s   t   }t| j|S r   )r   r<   r9   gatherr   r   r&   r   r    r!   r<      s   
z*SlateMultiCategorical.deterministic_sampler>   c                 C   s   t | jd d df S )Nr   )r9   	ones_liker   rE   r    r    r!   r      s   zSlateMultiCategorical.logp)Nr/   NN)r)   r*   r+   r,   r   r   r   rb   r   rs   rt   ru   r   r
   r   r<   r   r-   r    r    r   r!   r      s&    
r   c                
       s   e Zd ZdZ	ddee dedef fddZe	e
d	efd
dZe	e
ded	efddZe	ed	efddZee	e
dejded	eeejf fddZ  ZS )GumbelSoftmaxa  GumbelSoftmax distr. (for differentiable sampling in discr. actions

    The Gumbel Softmax distribution [1] (also known as the Concrete [2]
    distribution) is a close cousin of the relaxed one-hot categorical
    distribution, whose tfp implementation we will use here plus
    adjusted `sample_...` and `log_prob` methods. See discussion at [0].

    [0] https://stackoverflow.com/questions/56226133/
    soft-actor-critic-with-discrete-action-space

    [1] Categorical Reparametrization with Gumbel-Softmax (Jang et al, 2017):
    https://arxiv.org/abs/1611.01144
    [2] The Concrete Distribution: A Continuous Relaxation of Discrete Random
    Variables (Maddison et al, 2017) https://arxiv.org/abs/1611.00712
    Nr/   r   r   r0   c                    sD   |dksJ t jj||d| _tj| jjj| _	t
 || dS )aA  Initializes a GumbelSoftmax distribution.

        Args:
            temperature: Temperature parameter. For low temperatures,
                the expected value approaches a categorical random variable.
                For high temperatures, the expected value approaches a uniform
                distribution.
        r1   )r0   r?   N)tfpdistributionsRelaxedOneHotCategoricaldistr9   rA   softmax_distributionr?   probsr   r   r4   r   r    r!   r     s   zGumbelSoftmax.__init__r"   c                 C   r%   r   )r   r$   r    r    r!   r<     r'   z"GumbelSoftmax.deterministic_sampler>   c                 C   sz   |j | jjj kr*tj|| jjj  d tjd}|j | jjj ks*J |j | jjj ftj| tjj	| jjdd dd S )Nr{   r   r7   )
ro   r   r?   r9   one_hotas_listfloat32rL   rA   log_softmax)r   r>   valuesr    r    r!   r   #  s   zGumbelSoftmax.logpc                 C   
   | j  S r   r   r&   r$   r    r    r!   r   6     
zGumbelSoftmax._build_sample_opr^   r_   c                 C   r%   r   r[   r]   r    r    r!   r`   :  s   z)GumbelSoftmax.required_model_output_shapera   )r)   r*   r+   r,   r   r   r   rb   r   r
   r   r<   r   r   r   rc   rs   r   r   r   r   r   r   r`   r-   r    r    r   r!   r      s2    r   c                
       s   e Zd ZdZdddee dedeej	j
 f fddZeed	efd
dZeeded	efddZeeded	efddZeed	efddZeed	efddZeeedej
ded	eeejf fddZ  ZS )DiagGaussianzAction distribution where each vector element is a gaussian.

    The first half of the input vector defines the gaussian means, and the
    second half the gaussian standard deviations.
    N)r^   r   r   r^   c                   sN   t j|ddd\}}|| _|| _t || _|o|jdk| _t 	|| d S )N   r6   r7   r    )
r9   rq   meanlog_stdrK   stdro   zero_action_dimr   r   )r   r   r   r^   r   r   r   r    r!   r   J  s   zDiagGaussian.__init__r"   c                 C   r%   r   )r   r$   r    r    r!   r<   Y     z!DiagGaussian.deterministic_sampler>   c                 C   s   t t|jd dkrtj|dd}dtjtjt|tj| j	 | j
 dd dtdtj  tt|d tj  tj| jdd S )Nr   r6   r7   g            ?       @)r   r9   ro   expand_dimsrL   r:   squarerC   r   r   r   r   r   pir   rE   r    r    r!   r   ]  s   "*zDiagGaussian.logprR   c                 C   s\   t |tsJ tj|j| j tj| jtj| j|j  dtj|j   d ddS )Nr   r   r6   r7   )	r~   r   r9   rL   r   r:   r   r   r   r   r    r    r!   rV   k  s    zDiagGaussian.klc                 C   s*   t j| jdtdtj tj   ddS )Nr   r   r6   r7   )r9   rL   r   r   r   r   er$   r    r    r!   rQ   w  s   *zDiagGaussian.entropyc                 C   s8   | j | jtjt| j   }| jrtj|ddS |S Nr{   r7   )r   r   r9   rX   normalro   r   rW   r   r    r    r!   r   {  s    zDiagGaussian._build_sample_opr_   c                 C      t j| jt jdd S Nr   r   r   r   ro   rD   r]   r    r    r!   r`        z(DiagGaussian.required_model_output_shape)r)   r*   r+   r,   r   r   r   r   rs   rt   r   r   r
   r   r<   r   rV   rQ   r   r   rc   r   r   r   r   r   r`   r-   r    r    r   r!   r   B  s:    
r   c                
       s  e Zd ZdZ		d#dee dededef fdd	Ze	e
d
efddZe	ed
efddZe	e
ded
efddZdd Ze	e
d
efddZe	e
de
d
efddZded
efddZded
efddZee	e
dejd ed
eeejf fd!d"Z  ZS )$SquashedGaussianzA tanh-squashed Gaussian distribution defined by: mean, std, low, high.

    The distribution will never return low or high exactly, but
    `low`+SMALL_NUMBER or `high`-SMALL_NUMBER respectively.
          r/   r   r   r   r   c                    s~   t dusJ tj|ddd\}}t|tt}t|}t jj||d| _	t
t
||s0J || _|| _t || dS )zParameterizes the distribution via `inputs`.

        Args:
            low: The lowest possible sampling value
                (excluding this value).
            high: The highest possible sampling value
                (excluding this value).
        Nr   r{   r7   )locscale)r   r9   rq   clip_by_valuer   r   rK   r   Normaldistrr   r   lessr   r   r   r   )r   r   r   r   r   r   r   r   r   r    r!   r     s   
zSquashedGaussian.__init__r"   c                 C      | j  }| |S r   )r   r   _squashr   r   r    r    r!   r<        

z%SquashedGaussian.deterministic_samplec                 C      |  | j S r   )r   r   r&   r$   r    r    r!   r        z!SquashedGaussian._build_sample_opr>   c                 C   st   t | || jj}| j|}t |dd}t j|dd}t j	
|}|t jt j	d|d  t dd }|S )Nid   r{   r7   r6   r   )r9   rC   	_unsquashr   r   r   log_probr   rL   r:   tanhr   r   )r   r>   unsquashed_valueslog_prob_gaussianunsquashed_values_tanhdr   r    r    r!   r     s   zSquashedGaussian.logpc                 C   sF   | j  }| |}|tj| j |tjd||  t  ddfS Nr6   r{   r7   )	r   r&   r   r9   rL   r   r:   r   r   )r   zr   r    r    r!   sample_logp  s   

"zSquashedGaussian.sample_logpc                 C      t d)Nz)Entropy not defined for SquashedGaussian!
ValueErrorr$   r    r    r!   rQ        zSquashedGaussian.entropyrR   c                 C   r   )Nz$KL not defined for SquashedGaussian!r   r   r    r    r!   rV     r   zSquashedGaussian.kl
raw_valuesc                 C   s8   t j|d d | j| j  | j }t || j| jS )Nr/   r   )r9   r:   r   r   r   r   )r   r   squashedr    r    r!   r     s   
zSquashedGaussian._squashr   c                 C   sD   || j  | j| j   d d }t|dt dt }tj|}|S )Nr   r/   r   )r   r   r9   r   r   r:   atanh)r   r   normed_valuessave_normed_values
unsquashedr    r    r!   r     s   zSquashedGaussian._unsquashr^   r_   c                 C   r   r   r   r]   r    r    r!   r`     r   z,SquashedGaussian.required_model_output_shape)r   r/   )r)   r*   r+   r,   r   r   r   rb   r   r
   r   r<   r   r   r   r   rQ   rV   r   r   rc   rs   r   r   r   r   r   r   r`   r-   r    r    r   r!   r     sF    
	r   c                
       s   e Zd ZdZ		ddee dededef fdd	Ze	e
d
efddZe	ed
efddZe	e
ded
efddZded
efddZded
efddZee	e
dejded
eeejf fddZ  ZS )BetaaB  
    A Beta distribution is defined on the interval [0, 1] and parameterized by
    shape parameters alpha and beta (also called concentration parameters).

    PDF(x; alpha, beta) = x**(alpha - 1) (1 - x)**(beta - 1) / Z
        with Z = Gamma(alpha) Gamma(beta) / Gamma(alpha + beta)
        and Gamma(n) = (n - 1)!
    r1   r/   r   r   r   r   c                    sx   t |tttt }t jt j|d d }|| _|| _t j|ddd\}}t	j
j||d| _t || d S )Nr/   r   r{   r7   )concentration1concentration0)r9   r   r   r   r:   rK   r   r   rq   r   r   r   r   r   r   )r   r   r   r   r   alphabetar   r    r!   r     s   zBeta.__init__r"   c                 C   r   r   )r   r   r   r   r    r    r!   r<   	  r   zBeta.deterministic_samplec                 C   r   r   )r   r   r&   r$   r    r    r!   r     r   zBeta._build_sample_opr>   c                 C   s"   |  |}tjj| j|ddS r   )r   r9   r:   rL   r   r   )r   r>   r   r    r    r!   r     s   
z	Beta.logpr   c                 C   s   || j | j  | j S r   )r   r   )r   r   r    r    r!   r        zBeta._squashr   c                 C   s   || j  | j| j   S r   )r   r   )r   r   r    r    r!   r     r   zBeta._unsquashr^   r_   c                 C   r   r   r   r]   r    r    r!   r`     r   z Beta.required_model_output_shape)r1   r/   )r)   r*   r+   r,   r   r   r   rb   r   r
   r   r<   r   r   r   r   r   rc   rs   r   r   r   r   r   r   r`   r-   r    r    r   r!   r     s<    r   c                
   @   s   e Zd ZdZeedefddZeededefddZ	eedefdd	Z
eeed
ejdedeeejf fddZdS )DeterministiczAction distribution that returns the input values directly.

    This is similar to DiagGaussian with standard deviation zero (thus only
    requiring the "mean" values as NN output).
    r"   c                 C   r%   r   r   r$   r    r    r!   r<   -  r   z"Deterministic.deterministic_sampler>   c                 C   s   t | jS r   )r9   
zeros_liker   rE   r    r    r!   r   1  s   zDeterministic.logpc                 C   r%   r   r   r$   r    r    r!   r   5  r   zDeterministic._build_sample_opr^   r_   c                 C      t j| jt jdS Nr   r   r]   r    r    r!   r`   9     z)Deterministic.required_model_output_shapeN)r)   r*   r+   r,   r
   r   r   r<   r   r   r   rc   rs   r   r   r   r   r   r   r`   r    r    r    r!   r   %  s"    r   c                   @   s   e Zd ZdZdd Zeedd Zeedd Zeedd	 Z	eed
d Z
eedd Zeedd Zeedd ZdS )MultiActionDistributionzAction distribution that operates on a set of actions.

    Args:
        inputs (Tensor list): A list of tensors from which to compute samples.
    c                   sZ   t | | t|| _tj|tjd| _tj	|| jdd}t
 fdd||| _d S )Nr   r6   r7   c                    s   | |fi  S r   r    )r   rl   kwargsr   r    r!   <lambda>S  s    z2MultiActionDistribution.__init__.<locals>.<lambda>)r   r   r   action_space_structr   arrayrD   rj   r9   rq   treemap_structureflat_child_distributions)r   r   r   child_distributionsrj   r^   r   split_inputsr    r   r!   r   I  s   

z MultiActionDistribution.__init__c                 C   s   t |tjtjfrUg }| jD ]=}t |tr|d qt |tr0|j	d ur0|t
|j	j q| }t|jdkrA|d q|t|d  qtj||dd}nt|}dd }t||| j}tdd |S )Nr6   r7   c                 S   s>   t |trtt| jdkrtj| ddn| tj} || S r   )	r~   r.   r9   rC   r   ro   rW   rD   r   )valr   r    r    r!   map_p  s
   
"
z*MultiActionDistribution.logp.<locals>.map_c                 S      | | S r   r    abr    r    r!   r   |      z.MultiActionDistribution.logp.<locals>.<lambda>)r~   r9   r   r   r   r  r.   appendri   r^   r   ro   r&   r   rq   r  flattenr  	functoolsreduce)r   r>   split_indicesr   r&   split_xr	  
flat_logpsr    r    r!   r   X  s$   




zMultiActionDistribution.logpc                 C   s(   dd t | j|jD }tdd |S )Nc                 S   r   r    r   )rk   dor    r    r!   rn     s    z.MultiActionDistribution.kl.<locals>.<listcomp>c                 S   r
  r   r    r  r    r    r!   r     r  z,MultiActionDistribution.kl.<locals>.<lambda>)r   r  r  r  )r   rR   kl_listr    r    r!   rV   ~  s   zMultiActionDistribution.klc                 C   s    dd | j D }tdd |S )Nc                 S   rv   r    r   )rk   r  r    r    r!   rn     rz   z3MultiActionDistribution.entropy.<locals>.<listcomp>c                 S   r
  r   r    r  r    r    r!   r     r  z1MultiActionDistribution.entropy.<locals>.<lambda>)r  r  r  )r   entropy_listr    r    r!   rQ     s   zMultiActionDistribution.entropyc                 C       t | j| j}t dd |S )Nc                 S      |   S r   r   sr    r    r!   r     r  z0MultiActionDistribution.sample.<locals>.<lambda>r  unflatten_asr  r  r  r   r  r    r    r!   r&     s   zMultiActionDistribution.samplec                 C   r  )Nc                 S   r  r   rw   r  r    r    r!   r     r  z>MultiActionDistribution.deterministic_sample.<locals>.<lambda>r  r   r    r    r!   r<     s   z,MultiActionDistribution.deterministic_samplec                 C   s2   | j d  }| j dd  D ]}|| 7 }q|S )Nr   r6   )r  r(   )r   prp   r    r    r!   r(     s   z+MultiActionDistribution.sampled_action_logpc                 C   r   r   )r   r   rj   rD   )r   r^   r_   r    r    r!   r`     r=   z3MultiActionDistribution.required_model_output_shapeN)r)   r*   r+   r,   r   r
   r   r   rV   rQ   r&   r<   r   r(   r`   r    r    r    r!   r   A  s"    
%
	



r   c                
       s   e Zd ZdZdee def fddZee	defddZ
ee	d	edefd
dZee	defddZee	de	defddZeedefddZeee	dejdedeeejf fddZ  ZS )	DirichletzDirichlet distribution for continuous actions that are between
    [0,1] and sum to 1.

    e.g. actions that represent resource allocation.r   r   c                    s<   d| _ t|| j  }tjj|ddd| _t || dS )aB  Input is a tensor of logits. The exponential of logits is used to
        parametrize the Dirichlet distribution as all parameters need to be
        positive. An arbitrary small epsilon is added to the concentration
        parameters to be zero due to numerical error.

        See issue #4440 for more details.
        gHz>TF)concentrationvalidate_argsallow_nan_statsN)	epsilonr9   rK   tf1r   r"  r   r   r   )r   r   r   r#  r   r    r!   r     s   zDirichlet.__init__r"   c                 C   s   t j| jjS r   )r9   rA   r   r   r#  r$   r    r    r!   r<     r   zDirichlet.deterministic_sampler>   c                 C   s.   t || j}|t j|ddd }| j|S )Nr{   TrG   )r9   maximumr&  rL   r   r   rE   r    r    r!   r     s   zDirichlet.logpc                 C   r   r   )r   rQ   r$   r    r    r!   rQ     r   zDirichlet.entropyrR   c                 C   s   | j |j S r   )r   kl_divergencer   r    r    r!   rV     s   zDirichlet.klc                 C   r   r   r   r$   r    r    r!   r     r   zDirichlet._build_sample_opr^   r_   c                 C   r   r   r   r]   r    r    r!   r`     r   z%Dirichlet.required_model_output_shape)r)   r*   r+   r,   r   r   r   r   r
   r   r<   r   rQ   rV   r   r   rc   rs   r   r   r   r   r   r   r`   r-   r    r    r   r!   r"    s,    r"  )1r  	gymnasiumrs   r:   r   numpyr   r  typingr   ray.rllib.models.action_distr   ray.rllib.models.modelv2r   ray.rllib.utilsr   r   r   ray.rllib.utils.annotationsr	   r
   ray.rllib.utils.frameworkr   r   "ray.rllib.utils.spaces.space_utilsr   ray.rllib.utils.typingr   r   r   r   r   r'  r9   tfvr   r   r.   rb   rh   ri   r   r   r   r   r   r   r   r"  r    r    r    r!   <module>   sN    2
_*GGb7g