o
    `۷iJ                     @   s   d Z ddlZddlmZmZ ddlmZ ddlmZm	Z	 ddl
mZmZ ddlmZmZ ddlmZ dd	lmZ dd
lmZ ddlmZmZmZmZmZ eeZdZdZ dZ!G dd deZ"G dd deZ#dS )aO  Asynchronous Proximal Policy Optimization (APPO)

The algorithm is described in [1] (under the name of "IMPACT"):

Detailed documentation:
https://docs.ray.io/en/master/rllib-algorithms.html#appo

[1] IMPACT: Importance Weighted Asynchronous Architectures with Clipped Target Networks.
Luo et al. 2020
https://arxiv.org/pdf/1912.00167
    N)OptionalType)Self)DEPRECATED_VALUEdeprecation_warning)AlgorithmConfigNotProvided)IMPALAIMPALAConfig)RLModuleSpec)Policy)override)LAST_TARGET_UPDATE_TSLEARNER_STATS_KEYNUM_AGENT_STEPS_SAMPLEDNUM_ENV_STEPS_SAMPLEDNUM_TARGET_UPDATESmean_kl_losscurr_kl_coeffold_action_distc                        s$  e Zd ZdZd fdd	Zeeeeeeeeeeeeeeeeeedde	e
 de	e
 de	e d	e	e d
e	e
 de	e de	e de	e de	e de	e de	e
 de	e de	e de	e def fddZeed  fddZeedd ZeedefddZeee fddZ  ZS )!
APPOConfiga
  Defines a configuration class from which an APPO Algorithm can be built.

    .. testcode::

        from ray.rllib.algorithms.appo import APPOConfig
        config = (
            APPOConfig()
            .training(lr=0.01, grad_clip=30.0, train_batch_size_per_learner=50)
        )
        config = config.learners(num_learners=1)
        config = config.env_runners(num_env_runners=1)
        config = config.environment("CartPole-v1")

        # Build an Algorithm object from the config and run 1 training iteration.
        algo = config.build()
        algo.train()
        del algo

    .. testcode::

        from ray.rllib.algorithms.appo import APPOConfig
        from ray import tune

        config = APPOConfig()
        # Update the config object.
        config = config.training(lr=tune.grid_search([0.001,]))
        # Set the config object's env.
        config = config.environment(env="CartPole-v1")
        # Use to_dict() to get the old-style python config dict when running with tune.
        tune.Tuner(
            "APPO",
            run_config=tune.RunConfig(
                stop={"training_iteration": 1},
                verbose=0,
            ),
            param_space=config.to_dict(),

        ).fit()

    .. testoutput::
        :hide:

        ...
    Nc                    s   ddi| _ t j|ptd d| _d| _d| _d| _d| _d| _	d| _
d	| _d| _d
| _d| _d| _d| _d| _d| _d| _d| _d| _d| _d| _d| _d| _d| _d| _d| _d| _d| _d| _d| _ d| _!d| _"d| _#d| _$d| _%t&| _'t&| _(dS )z"Initializes a APPOConfig instance.typeStochasticSampling)
algo_classTg      ?g?Fg{Gz?g       @             g      D@global_normadamgMb@?gGz?        g?g      ?Nr   d      i,  ))exploration_configsuper__init__APPOvtraceuse_gaelambda_
clip_paramuse_kl_losskl_coeff	kl_targettarget_worker_clippinguse_circular_buffercircular_buffer_num_batches$circular_buffer_iterations_per_batchsimple_queue_sizenum_env_runnerstarget_network_update_freqbroadcast_interval	grad_clipgrad_clip_byopt_typelrdecaymomentumepsilonvf_loss_coeffentropy_coefftaulr_scheduleentropy_coeff_schedulenum_gpusnum_multi_gpu_tower_stacksminibatch_buffer_sizereplay_proportionreplay_buffer_num_slotslearner_queue_sizelearner_queue_timeoutr   target_update_frequency
use_critic)selfr   	__class__ T/home/ubuntu/vllm_env/lib/python3.10/site-packages/ray/rllib/algorithms/appo/appo.pyr%   V   sN   

zAPPOConfig.__init__)r'   r(   r)   r*   r+   r,   r-   r4   r?   r.   r/   r0   r1   r2   rI   rJ   r'   r(   r)   r*   r+   r,   r-   r4   r?   r.   r/   r0   r1   r2   returnc                   s  |t krtdddd |t krtdddd t jd	i | |tur&|| _|tur-|| _|tur4|| _|tur;|| _|turB|| _	|turI|| _
|turP|| _|turW|| _|	tur^|	| _|
ture|
| _|turl|| _|turs|| _|turz|| _|tur|| _| S )
ud  Sets the training related configuration.

        Args:
            vtrace: Whether to use V-trace weighted advantages. If false, PPO GAE
                advantages will be used instead.
            use_gae: If true, use the Generalized Advantage Estimator (GAE)
                with a value function, see https://arxiv.org/pdf/1506.02438.pdf.
                Only applies if vtrace=False.
            lambda_: GAE (lambda) parameter.
            clip_param: PPO surrogate slipping parameter.
            use_kl_loss: Whether to use the KL-term in the loss function.
            kl_coeff: Coefficient for weighting the KL-loss term.
            kl_target: Target term for the KL-term to reach (via adjusting the
                `kl_coeff` automatically).
            target_network_update_freq: NOTE: This parameter is only applicable on
                the new API stack. The frequency with which to update the target
                policy network from the main trained policy network. The metric
                used is `NUM_ENV_STEPS_TRAINED_LIFETIME` and the unit is `n` (see [1]
                4.1.1), where: `n = [circular_buffer_num_batches (N)] *
                [circular_buffer_iterations_per_batch (K)] * [train batch size]`
                For example, if you set `target_network_update_freq=2`, and N=4, K=2,
                and `train_batch_size_per_learner=500`, then the target net is updated
                every 2*4*2*500=8000 trained env steps (every 16 batch updates on each
                learner).
                The authors in [1] suggests that this setting is robust to a range of
                choices (try values between 0.125 and 4).
            target_network_update_freq: The frequency to update the target policy and
                tune the kl loss coefficients that are used during training. After
                setting this parameter, the algorithm waits for at least
                `target_network_update_freq` number of environment samples to be trained
                on before updating the target networks and tune the kl loss
                coefficients. NOTE: This parameter is only applicable when using the
                Learner API (enable_rl_module_and_learner=True).
            tau: The factor by which to update the target policy network towards
                the current policy network. Can range between 0 and 1.
                e.g. updated_param = tau * current_param + (1 - tau) * target_param
            target_worker_clipping: The maximum value for the target-worker-clipping
                used for computing the IS ratio, described in [1]
                IS = min(π(i) / π(target), ρ) * (π / π(i))
            use_circular_buffer: Whether to use a circular buffer for storing
                training batches. If false, a simple Queue will be used. Defaults to
                True.
            circular_buffer_num_batches: The number of train batches that fit
                into the circular buffer. Each such train batch can be sampled for
                training max. `circular_buffer_iterations_per_batch` times.
            circular_buffer_iterations_per_batch: The number of times any train
                batch in the circular buffer can be sampled for training. A batch gets
                evicted from the buffer either if it's the oldest batch in the buffer
                and a new batch is added OR if the batch reaches this max. number of
                being sampled.
            simple_queue_size: The size of the simple queue (if `use_circular_buffer`
                is False) for storing training batches.

        Returns:
            This updated AlgorithmConfig object.
        rI   r4   T)oldnewerrorrJ   zM`use_critic` no longer supported! APPO always uses a value function (critic).)rQ   helprS   NrN   )r   r   r$   trainingr   r'   r(   r)   r*   r+   r,   r-   r4   r?   r.   r/   r0   r1   r2   )rK   r'   r(   r)   r*   r+   r,   r-   r4   r?   r.   r/   r0   r1   r2   rI   rJ   kwargsrL   rN   rO   rU      sV   OzAPPOConfig.trainingc                    sb   t    | jr-| jdks| jdkr| d | jdkr!| d | jdkr/| d d S d S d S )Nr   r    a  `minibatch_buffer_size/replay_proportion` not valid on new API stack with APPO! Use `circular_buffer_num_batches` for the number of train batches in the circular buffer. To change the maximum number of times any batch may be sampled, set `circular_buffer_iterations_per_batch`.aO  `num_multi_gpu_tower_stacks` not supported on new API stack with APPO! In order to train on multi-GPU, use `config.learners(num_learners=[number of GPUs], num_gpus_per_learner=1)`. To scale the throughput of batch-to-GPU-pre-loading on each of your `Learners`, set `num_gpu_loader_threads` to a higher number (recommended values: 1-8).r"   aE  `learner_queue_size` not supported on new API stack with APPO! In order set the size of the circular buffer (which acts as a 'learner queue'), use `config.training(circular_buffer_num_batches=..)`. To change the maximum number of times any batch may be sampled, set `config.training(circular_buffer_iterations_per_batch=..)`.)r$   validateenable_rl_module_and_learnerrD   rE   _value_errorrC   rG   rK   rL   rN   rO   rW   "  s    


	zAPPOConfig.validatec                 C   s>   | j dkrddlm} |S | j dv rtdtd| j  d)Ntorchr   )APPOTorchLearner)tf2tfzPTensorFlow is no longer supported on the new API stack! Use `framework='torch'`.The framework z+ is not supported. Use `framework='torch'`.)framework_str2ray.rllib.algorithms.appo.torch.appo_torch_learnerr\   
ValueError)rK   r\   rN   rN   rO   get_default_learner_classE  s   

z$APPOConfig.get_default_learner_classc                 C   s4   | j dkrddlm} n	td| j  dt|dS )Nr[   r   )APPOTorchRLModuler_   z/ is not supported. Use either 'torch' or 'tf2'.)module_class)r`   4ray.rllib.algorithms.appo.torch.appo_torch_rl_modulerd   rb   r   )rK   RLModulerN   rN   rO   get_default_rl_module_specX  s   

z%APPOConfig.get_default_rl_module_specc                    s   t  jddiB S )Nvf_share_layersF)r$   _model_config_auto_includesrZ   rL   rN   rO   rj   f  s   z&APPOConfig._model_config_auto_includesNrP   N)__name__
__module____qualname____doc__r%   r   r
   r   r   r   boolfloatintr   rU   rW   rc   r   rh   propertyr   rj   __classcell__rN   rN   rL   rO   r   (   sz    -K	
 "
r   c                       sr   e Zd Z fddZeed fddZeeedefddZ	eeed	e
deee  fd
dZ  ZS )r&   c                    s:   t  j|g|R i | | jjs| jdd  dS dS )zInitializes an APPO instance.c                 S      |   S rk   update_targetp_rN   rN   rO   <lambda>v      zAPPO.__init__.<locals>.<lambda>N)r$   r%   configrX   
env_runnerforeach_policy_to_train)rK   r~   argsrV   rL   rN   rO   r%   m  s   zAPPO.__init__rP   Nc                    s   | j jr	t  S t   | jt }| j| j jdkrtnt }| j j	| j j
 }|| |krT| jt  d7  < || jt< | jdd  | j jrT fdd}| j|  S )Nagent_stepsr   c                 S   rv   rk   rw   ry   rN   rN   rO   r|     r}   z$APPO.training_step.<locals>.<lambda>c                    sj   t  vsJ dt  f| v r+ | t  d}|d us$J  |f| | d S td| d S )Nz'{} should be nested under policy id keyklzNo data for {}, not updating kl)r   formatget	update_klloggerwarning)pipi_idr   train_resultsrN   rO   update  s   
z"APPO.training_step.<locals>.update)r~   rX   r$   training_step	_countersr   count_steps_byr   r   
num_epochsrD   r   r   r   r+   )rK   last_updatecur_tstarget_update_freqr   rL   r   rO   r   x  s$   



zAPPO.training_stepc                 C   s   t  S rk   )r   )clsrN   rN   rO   get_default_config  s   zAPPO.get_default_configr~   c                 C   sV   |d dkrddl m} |S |d dkr#|jrtdddlm} |S ddlm} |S )	N	frameworkr[   r   )APPOTorchPolicyr^   zWRLlib's RLModule and Learner API is not supported for tf1. Use framework='tf2' instead.)APPOTF1Policy)APPOTF2Policy)+ray.rllib.algorithms.appo.appo_torch_policyr   rX   rb   (ray.rllib.algorithms.appo.appo_tf_policyr   r   )r   r~   r   r   r   rN   rN   rO   get_default_policy_class  s   zAPPO.get_default_policy_classrl   )rm   rn   ro   r%   r   r	   r   classmethodr   r   r   r   r   r   r   ru   rN   rN   rL   rO   r&   l  s    0
r&   )$rp   loggingtypingr   r   typing_extensionsr   ray._common.deprecationr   r   %ray.rllib.algorithms.algorithm_configr   r   "ray.rllib.algorithms.impala.impalar	   r
   "ray.rllib.core.rl_module.rl_moduler   ray.rllib.policy.policyr   ray.rllib.utils.annotationsr   ray.rllib.utils.metricsr   r   r   r   r   	getLoggerrm   r   LEARNER_RESULTS_KL_KEY!LEARNER_RESULTS_CURR_KL_COEFF_KEYOLD_ACTION_DIST_KEYr   r&   rN   rN   rN   rO   <module>   s&    
  F