"""
Implementation of learning rate schedules.

Taken and modified from PyTorch v1.0.1 source
https://github.com/pytorch/pytorch/blob/v1.1.0/torch/optim/lr_scheduler.py
"""

import argparse
from torch.optim import Optimizer
import math

from deepspeed.utils import logger

LR_SCHEDULE = 'lr_schedule'
LR_RANGE_TEST = 'LRRangeTest'
ONE_CYCLE = 'OneCycle'
WARMUP_LR = 'WarmupLR'
WARMUP_DECAY_LR = 'WarmupDecayLR'
WARMUP_COSINE_LR = 'WarmupCosineLR'
VALID_LR_SCHEDULES = [LR_RANGE_TEST, ONE_CYCLE, WARMUP_LR, WARMUP_DECAY_LR, WARMUP_COSINE_LR]

LR_RANGE_TEST_MIN_LR = 'lr_range_test_min_lr'
LR_RANGE_TEST_STEP_RATE = 'lr_range_test_step_rate'
LR_RANGE_TEST_STEP_SIZE = 'lr_range_test_step_size'
LR_RANGE_TEST_STAIRCASE = 'lr_range_test_staircase'
EDGE_VALUE = 'edge_value'
MID_VALUE = 'mid_value'

CYCLE_FIRST_STEP_SIZE = 'cycle_first_step_size'
CYCLE_FIRST_STAIR_COUNT = 'cycle_first_stair_count'
CYCLE_SECOND_STEP_SIZE = 'cycle_second_step_size'
CYCLE_SECOND_STAIR_COUNT = 'cycle_second_stair_count'
DECAY_STEP_SIZE = 'decay_step_size'

CYCLE_MIN_LR = 'cycle_min_lr'
CYCLE_MAX_LR = 'cycle_max_lr'
DECAY_LR_RATE = 'decay_lr_rate'

CYCLE_MIN_MOM = 'cycle_min_mom'
CYCLE_MAX_MOM = 'cycle_max_mom'
DECAY_MOM_RATE = 'decay_mom_rate'

WARMUP_MIN_LR = 'warmup_min_lr'
WARMUP_MAX_LR = 'warmup_max_lr'
WARMUP_NUM_STEPS = 'warmup_num_steps'
WARMUP_TYPE = 'warmup_type'
WARMUP_LOG_RATE = 'log'
WARMUP_LINEAR_RATE = 'linear'

WARMUP_MIN_RATIO = 'warmup_min_ratio'
COS_MIN_RATIO = 'cos_min_ratio'

TOTAL_NUM_STEPS = 'total_num_steps'


def add_tuning_arguments(parser):
    group = parser.add_argument_group('Convergence Tuning', 'Convergence tuning configurations')

    # LR scheduler
    group.add_argument('--lr_schedule', type=str, default=None, help='LR schedule for training.')

    # Learning rate range test
    group.add_argument('--lr_range_test_min_lr', type=float, default=0.001, help='Starting lr value.')
    group.add_argument('--lr_range_test_step_rate', type=float, default=1.0, help='scaling rate for LR range test.')
    group.add_argument('--lr_range_test_step_size', type=int, default=1000, help='training steps per LR change.')
    group.add_argument('--lr_range_test_staircase',
                       type=bool,
                       default=False,
                       help='use staircase scaling for LR range test.')

    # OneCycle schedule
    group.add_argument('--cycle_first_step_size',
                       type=int,
                       default=1000,
                       help='size of first step of 1Cycle schedule (training steps).')
    group.add_argument('--cycle_first_stair_count',
                       type=int,
                       default=-1,
                       help='first stair count for 1Cycle schedule.')
    group.add_argument('--cycle_second_step_size',
                       type=int,
                       default=-1,
                       help='size of second step of 1Cycle schedule (default first_step_size).')
    group.add_argument('--cycle_second_stair_count',
                       type=int,
                       default=-1,
                       help='second stair count for 1Cycle schedule.')
    group.add_argument('--decay_step_size',
                       type=int,
                       default=1000,
                       help='size of intervals for applying post cycle decay (training steps).')

    # 1Cycle LR
    group.add_argument('--cycle_min_lr', type=float, default=0.01, help='1Cycle LR lower bound.')
    group.add_argument('--cycle_max_lr', type=float, default=0.1, help='1Cycle LR upper bound.')
    group.add_argument('--decay_lr_rate', type=float, default=0.0, help='post cycle LR decay rate.')

    # 1Cycle momentum
    group.add_argument('--cycle_momentum', default=False, action='store_true', help='Enable 1Cycle momentum schedule.')
    group.add_argument('--cycle_min_mom', type=float, default=0.8, help='1Cycle momentum lower bound.')
    group.add_argument('--cycle_max_mom', type=float, default=0.9, help='1Cycle momentum upper bound.')
    group.add_argument('--decay_mom_rate', type=float, default=0.0, help='post cycle momentum decay rate.')

    # Warmup LR
    group.add_argument('--warmup_min_lr', type=float, default=0, help='WarmupLR minimum/initial LR value')
    group.add_argument('--warmup_max_lr', type=float, default=0.001, help='WarmupLR maximum LR value.')
    group.add_argument('--warmup_num_steps', type=int, default=1000, help='WarmupLR step count for LR warmup.')
    group.add_argument('--warmup_type',
                       type=str,
                       default=WARMUP_LOG_RATE,
                       help='WarmupLR increasing function during warmup')

    # WarmupCosineLR
    group.add_argument('--warmup_min_ratio', type=float, default=0.01, help='Cosine LR lower bound.')
    group.add_argument('--cos_min_ratio', type=float, default=0.01, help='Cosine LR lower bound.')

    return parser


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser = add_tuning_arguments(parser)
    lr_sched_args, unknown_args = parser.parse_known_args()
    return lr_sched_args, unknown_args


def override_lr_range_test_params(args, params):
    if hasattr(args, LR_RANGE_TEST_MIN_LR) and args.lr_range_test_min_lr is not None:
        params[LR_RANGE_TEST_MIN_LR] = args.lr_range_test_min_lr

    if hasattr(args, LR_RANGE_TEST_STEP_RATE) and args.lr_range_test_step_rate is not None:
        params[LR_RANGE_TEST_STEP_RATE] = args.lr_range_test_step_rate

    if hasattr(args, LR_RANGE_TEST_STEP_SIZE) and args.lr_range_test_step_size is not None:
        params[LR_RANGE_TEST_STEP_SIZE] = args.lr_range_test_step_size

    if hasattr(args, LR_RANGE_TEST_STAIRCASE) and args.lr_range_test_staircase is not None:
        params[LR_RANGE_TEST_STAIRCASE] = args.lr_range_test_staircase

def override_1cycle_params(args, params):
    if hasattr(args, CYCLE_FIRST_STEP_SIZE) and args.cycle_first_step_size is not None:
        params[CYCLE_FIRST_STEP_SIZE] = args.cycle_first_step_size

    if hasattr(args, CYCLE_FIRST_STAIR_COUNT) and args.cycle_first_stair_count is not None:
        params[CYCLE_FIRST_STAIR_COUNT] = args.cycle_first_stair_count

    if hasattr(args, CYCLE_SECOND_STEP_SIZE) and args.cycle_second_step_size is not None:
        params[CYCLE_SECOND_STEP_SIZE] = args.cycle_second_step_size

    if hasattr(args, CYCLE_SECOND_STAIR_COUNT) and args.cycle_second_stair_count is not None:
        params[CYCLE_SECOND_STAIR_COUNT] = args.cycle_second_stair_count

    if hasattr(args, DECAY_STEP_SIZE) and args.decay_step_size is not None:
        params[DECAY_STEP_SIZE] = args.decay_step_size

    # 1Cycle LR params
    if hasattr(args, CYCLE_MIN_LR) and args.cycle_min_lr is not None:
        params[CYCLE_MIN_LR] = args.cycle_min_lr

    if hasattr(args, CYCLE_MAX_LR) and args.cycle_max_lr is not None:
        params[CYCLE_MAX_LR] = args.cycle_max_lr

    if hasattr(args, DECAY_LR_RATE) and args.decay_lr_rate is not None:
        params[DECAY_LR_RATE] = args.decay_lr_rate

    # 1Cycle MOM params
    if hasattr(args, CYCLE_MIN_MOM) and args.cycle_min_mom is not None:
        params[CYCLE_MIN_MOM] = args.cycle_min_mom

    if hasattr(args, CYCLE_MAX_MOM) and args.cycle_max_mom is not None:
        params[CYCLE_MAX_MOM] = args.cycle_max_mom

    if hasattr(args, DECAY_MOM_RATE) and args.decay_mom_rate is not None:
        params[DECAY_MOM_RATE] = args.decay_mom_rate


def override_warmupLR_params(args, params):
    if hasattr(args, WARMUP_MIN_LR) and args.warmup_min_lr is not None:
        params[WARMUP_MIN_LR] = args.warmup_min_lr

    if hasattr(args, WARMUP_MAX_LR) and args.warmup_max_lr is not None:
        params[WARMUP_MAX_LR] = args.warmup_max_lr

    if hasattr(args, WARMUP_NUM_STEPS) and args.warmup_num_steps is not None:
        params[WARMUP_NUM_STEPS] = args.warmup_num_steps

    if hasattr(args, WARMUP_TYPE) and args.warmup_type is not None:
        params[WARMUP_TYPE] = args.warmup_type


def override_params(args, params):
    # LR range test params
    override_lr_range_test_params(args, params)

    # 1Cycle params
    override_1cycle_params(args, params)

    # WarmupLR params
    override_warmupLR_params(args, params)

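# A hedged sketch (not part of the original module) of how the helpers above
# compose; the argv list below is hypothetical:
#
#   parser = argparse.ArgumentParser()
#   parser = add_tuning_arguments(parser)
#   args, _ = parser.parse_known_args(['--lr_schedule', 'WarmupLR', '--warmup_num_steps', '500'])
#   params = {}
#   override_params(args, params)  # params now carries the CLI values, e.g. params['warmup_num_steps'] == 500
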
def get_config_from_args(args):
    if not hasattr(args, LR_SCHEDULE) or args.lr_schedule is None:
        return None, '--{} not specified on command line'.format(LR_SCHEDULE)

    if args.lr_schedule not in VALID_LR_SCHEDULES:
        return None, '{} is not supported LR schedule'.format(args.lr_schedule)

    config = {}
    config['type'] = args.lr_schedule
    config['params'] = {}

    if args.lr_schedule == LR_RANGE_TEST:
        override_lr_range_test_params(args, config['params'])
    elif args.lr_schedule == ONE_CYCLE:
        override_1cycle_params(args, config['params'])
    else:
        override_warmupLR_params(args, config['params'])

    return config, None

def get_lr_from_config(config):
    if 'type' not in config:
        return None, 'LR schedule type not defined in config'

    if 'params' not in config:
        return None, 'LR schedule params not defined in config'

    lr_schedule = config['type']
    lr_params = config['params']

    if lr_schedule not in VALID_LR_SCHEDULES:
        return None, '{} is not a valid LR schedule'.format(lr_schedule)

    if lr_schedule == LR_RANGE_TEST:
        return lr_params[LR_RANGE_TEST_MIN_LR], ''
    if lr_schedule == ONE_CYCLE:
        return lr_params[CYCLE_MIN_LR], ''
    # Warmup LR schedules
    return lr_params[WARMUP_MIN_LR], ''


def update_lr(param_groups, lrs):
    for param_group, lr in zip(param_groups, lrs):
        param_group['lr'] = lr
    return [group['lr'] for group in param_groups]

"""
Only optimizers that are subclass of torch.optim.Optimizer are supported. So check the passed optimizer and wrapped
optimizer to see if requirement is satisfied.
"""


def get_torch_optimizer(optimizer):
    if isinstance(optimizer, Optimizer):
        return optimizer

    if hasattr(optimizer, 'optimizer') and isinstance(optimizer.optimizer, Optimizer):
        return optimizer.optimizer

    raise TypeError('{} is not a subclass of torch.optim.Optimizer'.format(type(optimizer).__name__))

class LRRangeTest(object):
    """Sets the learning rate of each parameter group according to
    learning rate range test (LRRT) policy. The policy increases learning
    rate starting from a base value with a constant frequency, as detailed in
    the paper `A disciplined approach to neural network hyper-parameters: Part 1 <https://arxiv.org/abs/1803.09820>`_

    LRRT policy is used for finding maximum LR that trains a model without divergence, and can be used to
    configure the LR boundaries for Cyclic LR schedules.

    LRRT changes the learning rate after every batch.
    `step` should be called after a batch has been used for training.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        lr_range_test_min_lr (float or list): Initial learning rate which is the
            lower boundary in the range test for each parameter group.
        lr_range_test_step_size (int): Interval of training steps to increase learning rate. Default: 2000
        lr_range_test_step_rate (float): Scaling rate for range test. Default: 1.0
        lr_range_test_staircase (bool): Scale in staircase fashion, rather than continuous. Default: False.
        last_batch_iteration (int): The index of the last batch. This parameter is used when
            resuming a training job. Since `step()` should be invoked after each
            batch instead of after each epoch, this number represents the total
            number of *batches* computed, not the total number of epochs computed.
            When last_batch_iteration=-1, the schedule is started from the beginning.
            Default: -1

    Example:
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> scheduler = LRRangeTest(optimizer)
        >>> data_loader = torch.utils.data.DataLoader(...)
        >>> for epoch in range(10):
        >>>     for batch in data_loader:
        >>>         train_batch(...)
        >>>         scheduler.step()

        _A disciplined approach to neural network hyper-parameters: Part 1 -- learning rate, batch size, momentum, and weight decay:
        https://arxiv.org/abs/1803.09820
    """

    def __init__(self,
                 optimizer: Optimizer,
                 lr_range_test_min_lr: float = 1e-3,
                 lr_range_test_step_size: int = 2000,
                 lr_range_test_step_rate: float = 1.0,
                 lr_range_test_staircase: bool = False,
                 last_batch_iteration: int = -1):

        self.optimizer = get_torch_optimizer(optimizer)

        if isinstance(lr_range_test_min_lr, list) or isinstance(lr_range_test_min_lr, tuple):
            if len(lr_range_test_min_lr) != len(self.optimizer.param_groups):
                raise ValueError('expected {} lr_range_test_min_lr, got {}'.format(len(self.optimizer.param_groups),
                                                                                   len(lr_range_test_min_lr)))
            self.min_lr = list(lr_range_test_min_lr)
        else:
            self.min_lr = [lr_range_test_min_lr] * len(self.optimizer.param_groups)

        self.step_size = lr_range_test_step_size
        self.step_rate = lr_range_test_step_rate
        self.last_batch_iteration = last_batch_iteration
        self.staircase = lr_range_test_staircase
        self.interval_fn = self._staircase_interval if lr_range_test_staircase else self._continuous_interval

        if last_batch_iteration == -1:
            self._last_lr = update_lr(self.optimizer.param_groups, self.min_lr)

    def _staircase_interval(self):
        return math.floor(float(self.last_batch_iteration + 1) / self.step_size)

    def _continuous_interval(self):
        return float(self.last_batch_iteration + 1) / self.step_size

    def _get_increase(self):
        return (1 + self.step_rate * self.interval_fn())

    def get_lr(self):
        lr_increase = self._get_increase()
        return [lr_range_test_min_lr * lr_increase for lr_range_test_min_lr in self.min_lr]

    def get_last_lr(self):
        """ Return last computed learning rate by current scheduler.
        """
        assert getattr(self, '_last_lr', None) is not None, "need to call step() first"
        return self._last_lr

    def step(self, batch_iteration=None):
        if batch_iteration is None:
            batch_iteration = self.last_batch_iteration + 1
        self.last_batch_iteration = batch_iteration
        self._last_lr = update_lr(self.optimizer.param_groups, self.get_lr())

    def state_dict(self):
        return {'last_batch_iteration': self.last_batch_iteration}

    def load_state_dict(self, sd):
        self.last_batch_iteration = sd['last_batch_iteration']

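# A minimal usage sketch for LRRangeTest; the model/optimizer below are
# hypothetical, not part of the original module. With the default continuous
# interval, the LR after k calls to step() is
#     min_lr * (1 + lr_range_test_step_rate * k / lr_range_test_step_size)
# so with min_lr=0.001, step_size=100 and step_rate=1.0 the LR has grown to
# 0.001 * (1 + 300/100) = 0.004 after 300 steps:
#
#   import torch
#   model = torch.nn.Linear(4, 2)
#   optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
#   scheduler = LRRangeTest(optimizer, lr_range_test_min_lr=0.001,
#                           lr_range_test_step_size=100)
#   for _ in range(300):
#       # train_batch(...)
#       scheduler.step()
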
class OneCycle(object):
    """Sets the learning rate of each parameter group according to
    1Cycle learning rate policy (1CLR). 1CLR is a variation of the
    Cyclical Learning Rate (CLR) policy that involves one cycle followed by
    decay. The policy simultaneously cycles the learning rate (and momentum)
    between two boundaries with a constant frequency, as detailed in
    the paper `A disciplined approach to neural network hyper-parameters`_.

    1CLR policy changes the learning rate after every batch.
    `step` should be called after a batch has been used for training.

    This implementation was adapted from the github repo: `PyTorch <https://github.com/pytorch/pytorch>`_.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        cycle_min_lr (float or list): Initial learning rate which is the
            lower boundary in the cycle for each parameter group.
        cycle_max_lr (float or list): Upper learning rate boundaries in the cycle
            for each parameter group. Functionally,
            it defines the cycle amplitude (cycle_max_lr - cycle_min_lr).
            The lr at any cycle is the sum of cycle_min_lr
            and some scaling of the amplitude; therefore
            cycle_max_lr may not actually be reached depending on
            scaling function.
        decay_lr_rate (float): Decay rate for learning rate. Default: 0.
        cycle_first_step_size (int): Number of training iterations in the
            increasing half of a cycle. Default: 2000
        cycle_second_step_size (int): Number of training iterations in the
            decreasing half of a cycle. If cycle_second_step_size is None,
            it is set to cycle_first_step_size. Default: None
        cycle_first_stair_count (int): Number of stairs in first half of cycle phase. This means
            lr/mom are changed in staircase fashion. Default 0, means staircase disabled.
        cycle_second_stair_count (int): Number of stairs in second half of cycle phase. This means
            lr/mom are changed in staircase fashion. Default 0, means staircase disabled.
        decay_step_size (int): Intervals for applying decay in decay phase. Default: 0, means no decay.
        cycle_momentum (bool): If ``True``, momentum is cycled inversely
            to learning rate between 'cycle_min_mom' and 'cycle_max_mom'.
            Default: True
        cycle_min_mom (float or list): Initial momentum which is the
            lower boundary in the cycle for each parameter group.
            Default: 0.8
        cycle_max_mom (float or list): Upper momentum boundaries in the cycle
            for each parameter group. Functionally,
            it defines the cycle amplitude (cycle_max_mom - cycle_min_mom).
            The momentum at any cycle is the difference of cycle_max_mom
            and some scaling of the amplitude; therefore
            cycle_min_mom may not actually be reached depending on
            scaling function. Default: 0.9
        decay_mom_rate (float): Decay rate for momentum. Default: 0.
        last_batch_iteration (int): The index of the last batch. This parameter is used when
            resuming a training job. Since `step()` should be invoked after each
            batch instead of after each epoch, this number represents the total
            number of *batches* computed, not the total number of epochs computed.
            When last_batch_iteration=-1, the schedule is started from the beginning.
            Default: -1

    Example:
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> scheduler = OneCycle(optimizer, 0.0001, 0.0010)
        >>> data_loader = torch.utils.data.DataLoader(...)
        >>> for epoch in range(10):
        >>>     for batch in data_loader:
        >>>         train_batch(...)
        >>>         scheduler.step()


    .. _A disciplined approach to neural network hyper-parameters: Part 1 -- learning rate, batch size, momentum, and weight decay: https://arxiv.org/abs/1803.09820
    """

    def __init__(self,
                 optimizer,
                 cycle_min_lr,
                 cycle_max_lr,
                 decay_lr_rate=0.,
                 cycle_first_step_size=2000,
                 cycle_second_step_size=None,
                 cycle_first_stair_count=0,
                 cycle_second_stair_count=None,
                 decay_step_size=0,
                 cycle_momentum=True,
                 cycle_min_mom=0.8,
                 cycle_max_mom=0.9,
                 decay_mom_rate=0.,
                 last_batch_iteration=-1):

        self.optimizer = get_torch_optimizer(optimizer)

        # Initialize cycle shape
        self._initialize_cycle(cycle_first_step_size, cycle_second_step_size, cycle_first_stair_count,
                               cycle_second_stair_count, decay_step_size)

        # Initialize cycle lr
        self._initialize_lr(self.optimizer, cycle_min_lr, cycle_max_lr, decay_lr_rate, last_batch_iteration)

        # Initialize cyclic momentum
        self.cycle_momentum = cycle_momentum
        if cycle_momentum:
            self._initialize_momentum(self.optimizer, cycle_min_mom, cycle_max_mom, decay_mom_rate,
                                      last_batch_iteration)

        # Initialize batch iteration tracker
        self.last_batch_iteration = last_batch_iteration

    # Configure cycle shape
    def _initialize_cycle(self, cycle_first_step_size, cycle_second_step_size, cycle_first_stair_count,
                          cycle_second_stair_count, decay_step_size):
        cycle_first_step_size = float(cycle_first_step_size)
        cycle_second_step_size = float(
            cycle_second_step_size) if cycle_second_step_size is not None else cycle_first_step_size

        self.total_size = cycle_first_step_size + cycle_second_step_size
        self.step_ratio = cycle_first_step_size / self.total_size
        self.first_stair_count = cycle_first_stair_count
        self.second_stair_count = cycle_first_stair_count if cycle_second_stair_count is None else cycle_second_stair_count
        self.decay_step_size = decay_step_size

        if math.isclose(self.decay_step_size, 0):
            self.skip_lr_decay = True
            self.skip_mom_decay = True
        else:
            self.skip_lr_decay = False
            self.skip_mom_decay = False

    # Configure lr schedule
    def _initialize_lr(self, optimizer, cycle_min_lr, cycle_max_lr, decay_lr_rate, last_batch_iteration):
        self.min_lrs = [cycle_min_lr] * len(optimizer.param_groups)
        if last_batch_iteration == -1:
            for lr, group in zip(self.min_lrs, optimizer.param_groups):
                group['lr'] = lr

        self.max_lrs = [cycle_max_lr] * len(optimizer.param_groups)
        self.decay_lr_rate = decay_lr_rate

        if math.isclose(self.decay_lr_rate, 0):
            self.skip_lr_decay = True

    # Configure momentum schedule
    def _initialize_momentum(self, optimizer, cycle_min_mom, cycle_max_mom, decay_mom_rate, last_batch_iteration):
        if 'betas' not in optimizer.defaults:
            optimizer_name = type(optimizer).__name__
            logger.warning(f"cycle_momentum is disabled because optimizer {optimizer_name} "
                           f"does not support momentum, no betas attribute in defaults")
            self.cycle_momentum = False
            return

        self.decay_mom_rate = decay_mom_rate
        self.min_moms = [(cycle_min_mom, 0.99)] * len(optimizer.param_groups)
        self.max_moms = [(cycle_max_mom, 0.99)] * len(optimizer.param_groups)

        if last_batch_iteration == -1:
            for momentum, group in zip(self.min_moms, optimizer.param_groups):
                group['betas'] = momentum

        if math.isclose(self.decay_mom_rate, 0):
            self.skip_mom_decay = True

    def _get_scale_factor(self):
        batch_iteration = (self.last_batch_iteration + 1)
        cycle = math.floor(1 + batch_iteration / self.total_size)
        x = 1. + batch_iteration / self.total_size - cycle
        if x <= self.step_ratio:
            scale_factor = x / self.step_ratio
        else:
            scale_factor = (x - 1) / (self.step_ratio - 1)

        return scale_factor

    def _get_cycle_mom(self):
        scale_factor = self._get_scale_factor()
        momentums = []
        for base_betas, max_betas in zip(self.min_moms, self.max_moms):
            cycle_min_mom = base_betas[0]
            cycle_max_mom = max_betas[0]
            base_height = (cycle_max_mom - cycle_min_mom) * scale_factor
            momentum = cycle_max_mom - base_height
            momentums.append((momentum, base_betas[1]))
        return momentums

    def _get_cycle_lr(self):
        scale_factor = self._get_scale_factor()
        lrs = []
        for cycle_min_lr, cycle_max_lr in zip(self.min_lrs, self.max_lrs):
            base_height = (cycle_max_lr - cycle_min_lr) * scale_factor
            lr = cycle_min_lr + base_height
            lrs.append(lr)

        return lrs

    def _get_decay_mom(self, decay_batch_iteration):
        if self.skip_mom_decay:
            return self.max_moms

        decay_interval = decay_batch_iteration / self.decay_step_size
        mom_decay_factor = (1 + self.decay_mom_rate * decay_interval)
        momentums = [(beta0 * mom_decay_factor, beta1) for beta0, beta1 in self.max_moms]

        return momentums

    def _get_decay_lr(self, decay_batch_iteration):
        """Calculates the learning rate at batch index. This function is used
        after the cycle completes and post cycle decaying of lr/mom is enabled.
        This function treats `self.last_batch_iteration` as the last batch index.
        """
        if self.skip_lr_decay:
            return self.min_lrs

        decay_interval = decay_batch_iteration / self.decay_step_size
        lr_decay_factor = (1 + self.decay_lr_rate * decay_interval)
        lrs = [lr / lr_decay_factor for lr in self.min_lrs]

        return lrs

    def get_lr(self):
        """Calculates the learning rate at batch index. This function treats
        `self.last_batch_iteration` as the last batch index.
        """
        if self.last_batch_iteration < self.total_size:
            return self._get_cycle_lr()
        return self._get_decay_lr(self.last_batch_iteration - self.total_size + 1)

    def get_mom(self):
        """Calculates the momentum at batch index. This function treats
        `self.last_batch_iteration` as the last batch index.
        """
        if not self.cycle_momentum:
            return None

        if self.last_batch_iteration < self.total_size:
            return self._get_cycle_mom()
        return self._get_decay_mom(self.last_batch_iteration - self.total_size + 1)

    def get_last_lr(self):
        """ Return last computed learning rate by current scheduler.
        """
        assert getattr(self, '_last_lr', None) is not None, "need to call step() first"
        return self._last_lr

    def step(self, batch_iteration=None):
        """ Updates the optimizer with the learning rate for the last batch index.
        `self.last_batch_iteration` is treated as the last batch index.

        If self.cycle_momentum is true, also updates optimizer momentum.
        """
        if batch_iteration is None:
            batch_iteration = self.last_batch_iteration + 1

        self.last_batch_iteration = batch_iteration
        self._last_lr = update_lr(self.optimizer.param_groups, self.get_lr())

        if self.cycle_momentum:
            momentums = self.get_mom()
            for param_group, momentum in zip(self.optimizer.param_groups, momentums):
                param_group['betas'] = momentum

    def state_dict(self):
        return {'last_batch_iteration': self.last_batch_iteration}

    def load_state_dict(self, sd):
        self.last_batch_iteration = sd['last_batch_iteration']

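# A hedged sketch of the 1Cycle shape (the numbers are illustrative, not from
# the original source). With cycle_first_step_size=100 and the second step
# defaulting to the same size, total_size is 200 and step_ratio is 0.5: the LR
# rises linearly from cycle_min_lr to cycle_max_lr over batches 0..99, falls
# back to cycle_min_lr over batches 100..199, and afterwards is divided by
# (1 + decay_lr_rate * decay_steps / decay_step_size) when decay is enabled:
#
#   import torch
#   model = torch.nn.Linear(4, 2)
#   optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # Adam has 'betas', so momentum cycling works
#   scheduler = OneCycle(optimizer, cycle_min_lr=0.001, cycle_max_lr=0.01,
#                        cycle_first_step_size=100)
#   for _ in range(400):
#       # train_batch(...)
#       scheduler.step()
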
class WarmupLR(object):
    """Increase the learning rate of each parameter group from min lr to max lr
        over warmup_num_steps steps, and then fix at max lr.

        Args:
            optimizer (Optimizer): Wrapped optimizer.
            warmup_min_lr (float or list): minimum learning rate. Default: 0
            warmup_max_lr (float or list): maximum learning rate. Default: 0.001
            warmup_num_steps (int): number of steps to warm up from min_lr to max_lr. Default: 1000
            warmup_type {‘log’, ‘linear’}: increasing function from min_lr to max_lr during warmup. Default: log
            last_batch_iteration (int): The index of the last batch. Default: -1.
        Example:
            >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
            >>> scheduler = WarmupLR(optimizer)
            >>> data_loader = torch.utils.data.DataLoader(...)
            >>> for epoch in range(10):
            >>>     for batch in data_loader:
            >>>         train_batch(...)
            >>>         scheduler.step()

    """

    def __init__(self,
                 optimizer: Optimizer,
                 warmup_min_lr: float = 0.0,
                 warmup_max_lr: float = 0.001,
                 warmup_num_steps: int = 1000,
                 warmup_type: str = WARMUP_LOG_RATE,
                 last_batch_iteration: int = -1):

        self.optimizer = get_torch_optimizer(optimizer)

        self.min_lrs = self._format_param(self.optimizer, warmup_min_lr, "min_lr")
        self.max_lrs = self._format_param(self.optimizer, warmup_max_lr, "max_lr")
        self.delta_lrs = [big - small for big, small in zip(self.max_lrs, self.min_lrs)]
        self.warmup_num_steps = max(2, warmup_num_steps)
        # Currently only support linear and log function
        if warmup_type not in {WARMUP_LOG_RATE, WARMUP_LINEAR_RATE}:
            logger.warning(f"Using unknown warmup_type: {warmup_type}. The increasing function "
                           f"is set to default (log)")
            warmup_type = WARMUP_LOG_RATE
        self.warmup_type = warmup_type
        self.inverse_log_warm_up = 1.0 / math.log(self.warmup_num_steps)
        self.last_batch_iteration = last_batch_iteration
        # Initialize lr in optimizer
        if last_batch_iteration == -1:
            self._last_lr = update_lr(self.optimizer.param_groups, self.get_lr())

    def get_lr(self):
        if self.last_batch_iteration < 0:
            logger.warning("Attempting to get learning rate from scheduler before it has started")
            return self.min_lrs
        gamma = self._get_gamma()
        return [min_lr + (delta_lr * gamma) for min_lr, delta_lr in zip(self.min_lrs, self.delta_lrs)]

    def get_last_lr(self):
        """ Return last computed learning rate by current scheduler.
        """
        assert getattr(self, '_last_lr', None) is not None, "need to call step() first"
        return self._last_lr

    def step(self, last_batch_iteration=None):
        if last_batch_iteration is None:
            last_batch_iteration = self.last_batch_iteration + 1
        self.last_batch_iteration = last_batch_iteration
        self._last_lr = update_lr(self.optimizer.param_groups, self.get_lr())

    def state_dict(self):
        return {'last_batch_iteration': self.last_batch_iteration}

    def load_state_dict(self, sd):
        self.last_batch_iteration = sd['last_batch_iteration']

    def _get_gamma(self):
        if self.last_batch_iteration < self.warmup_num_steps:
            if self.warmup_type == WARMUP_LOG_RATE:
                return self.inverse_log_warm_up * math.log(self.last_batch_iteration + 1)
            elif self.warmup_type == WARMUP_LINEAR_RATE:
                return self.last_batch_iteration / self.warmup_num_steps
        return 1.0

    def _format_param(self, optimizer, param_value, param_name):
        if isinstance(param_value, list) or isinstance(param_value, tuple):
            if len(param_value) != len(optimizer.param_groups):
                raise ValueError('expected {} value for {}, got {}'.format(len(optimizer.param_groups), param_name,
                                                                           param_value))
            return list(param_value)
        return [param_value] * len(optimizer.param_groups)

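# Worked example of the warmup curve (illustrative values, not from the
# original source). With warmup_min_lr=0, warmup_max_lr=0.1 and
# warmup_num_steps=1000, _get_gamma() gives:
#     log warmup:    lr(t) = 0.1 * log(t + 1) / log(1000)   -> ~0.1 at t = 999
#     linear warmup: lr(t) = 0.1 * t / 1000
# Past warmup_num_steps the gamma saturates at 1.0, so the LR stays fixed at
# warmup_max_lr.
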
class WarmupDecayLR(WarmupLR):
    """Increase the learning rate of each parameter group from min lr to max lr
        over warmup_num_steps steps, and then decay at linear rate over the remaining training steps.

        Args:
            optimizer (Optimizer): Wrapped optimizer.
            total_num_steps (int): total number of training steps
            warmup_min_lr (float or list): minimum learning rate. Default: 0
            warmup_max_lr (float or list): maximum learning rate. Default: 0.001
            warmup_num_steps (int): number of steps to warm up from min_lr to max_lr. Default: 1000
            warmup_type {‘log’, ‘linear’}: increasing function from min_lr to max_lr during warmup. Default: log
            last_batch_iteration (int): The index of the last batch. Default: -1.
        Example:
            >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
            >>> scheduler = WarmupDecayLR(optimizer, 1000000)
            >>> data_loader = torch.utils.data.DataLoader(...)
            >>> for epoch in range(10):
            >>>     for batch in data_loader:
            >>>         train_batch(...)
            >>>         scheduler.step()

    """

    def __init__(self,
                 optimizer: Optimizer,
                 total_num_steps: int,
                 warmup_min_lr: float = 0.0,
                 warmup_max_lr: float = 0.001,
                 warmup_num_steps: int = 1000,
                 warmup_type: str = WARMUP_LOG_RATE,
                 last_batch_iteration: int = -1):

        self.total_num_steps = total_num_steps
        super(WarmupDecayLR, self).__init__(optimizer, warmup_min_lr, warmup_max_lr, warmup_num_steps, warmup_type,
                                            last_batch_iteration)
        if self.total_num_steps < self.warmup_num_steps:
            logger.warning('total_num_steps {} is less than warmup_num_steps {}'.format(
                total_num_steps, warmup_num_steps))

    def _get_gamma(self):
        if self.last_batch_iteration < self.warmup_num_steps:
            if self.warmup_type == WARMUP_LOG_RATE:
                return self.inverse_log_warm_up * math.log(self.last_batch_iteration + 1)
            elif self.warmup_type == WARMUP_LINEAR_RATE:
                return self.last_batch_iteration / self.warmup_num_steps
        return max(
            0.0,
            float(self.total_num_steps - self.last_batch_iteration) /
            float(max(1.0, self.total_num_steps - self.warmup_num_steps)))

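# Worked example of the post-warmup decay (illustrative numbers). With
# total_num_steps=10000 and warmup_num_steps=1000, _get_gamma() after warmup is
#     max(0.0, (10000 - t) / (10000 - 1000))
# i.e. the LR decays linearly from warmup_max_lr at t=1000 to zero at t=10000.
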
ededefddZ	dd Z
dddZdd Zdd Zdd Zdd Zdd ZdS )r	   u  Increase the learning rate of each parameter group from min lr ratio to max lr ratio
        over warmup_num_steps steps, and then decay at cosine rate over the remaining training steps to min cosine ratio.

        Args:
            optimizer (Optimizer): Wrapped optimizer.
            total_num_steps (int): total number of training steps
            warmup_min_ratio (float or list): warmup start learning rate ratio. Default: 0
            warmup_num_steps (int): number of steps to warm up from warmup_min_ratio to 1.0. Default: 1000
            warmup_type {‘log’, ‘linear’}: increasing function from min_lr to max_lr during warmup. Default: log
            cos_min_ratio (float): cosine end learning rate ratio. Default: 0.0001
            last_batch_iteration (int): The index of the last batch. Default: -1.
        Example:
            >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
            >>> scheduler = WarmupCosineLR(optimizer, 1000000)
            >>> data_loader = torch.utils.data.DataLoader(...)
            >>> for epoch in range(10):
            >>>     for batch in data_loader:
            >>>         train_batch(...)
            >>>         scheduler.step()

    """

    def __init__(self,
                 optimizer: Optimizer,
                 total_num_steps: int,
                 warmup_min_ratio: float = 0.0,
                 warmup_num_steps: int = 1000,
                 cos_min_ratio: float = 0.0001,
                 warmup_type: str = WARMUP_LOG_RATE,
                 last_batch_iteration: int = -1):

        self.optimizer = get_torch_optimizer(optimizer)

        self.total_num_steps = total_num_steps
        self.last_batch_iteration = last_batch_iteration
        self.cos_min_ratio = cos_min_ratio

        self.warmup_type = warmup_type
        self.warmup_min_ratio = warmup_min_ratio
        self.warmup_num_steps = max(2, warmup_num_steps)
        self.inverse_log_warm_up = 1.0 / math.log(self.warmup_num_steps)

        if self.total_num_steps < self.warmup_num_steps:
            logger.warning('total_num_steps {} is less than warmup_num_steps {}'.format(
                total_num_steps, warmup_num_steps))

        self.org_lrs = [group['lr'] for group in self.optimizer.param_groups]
        if last_batch_iteration == -1:
            self._last_lr = update_lr(self.optimizer.param_groups, self.get_lr())

    def get_lr_ratio(self):
        if self.last_batch_iteration < 0:
            logger.warning("Attempting to get learning rate from scheduler before it has started")
            return [0.0]

        if self.last_batch_iteration < self.warmup_num_steps:
            if self.warmup_type == WARMUP_LOG_RATE:
                ratio = self.inverse_log_warm_up * math.log(self.last_batch_iteration + 1)
            elif self.warmup_type == WARMUP_LINEAR_RATE:
                ratio = self.last_batch_iteration / self.warmup_num_steps
            ratio_delta = 1. - self.warmup_min_ratio
            ratio = self.warmup_min_ratio + ratio * ratio_delta
            return ratio

        real_last_step = self.last_batch_iteration - self.warmup_num_steps + 1
        real_total_steps = self.total_num_steps - self.warmup_num_steps
        ratio_delta = 1. - self.cos_min_ratio
        ratio = (1 + math.cos(math.pi * real_last_step / real_total_steps)) / 2
        ratio = max(0.0, self.cos_min_ratio + ratio_delta * ratio)
        return ratio

    def step(self, last_batch_iteration=None):
        if last_batch_iteration is None:
            last_batch_iteration = self.last_batch_iteration + 1
        self.last_batch_iteration = last_batch_iteration
        self._last_lr = update_lr(self.optimizer.param_groups, self.get_lr())

    def get_lr(self):
        if self.last_batch_iteration < 0:
            logger.warning("Attempting to get learning rate from scheduler before it has started")
            return [0.0]
        lr_ratio = self.get_lr_ratio()
        return [org_lr * lr_ratio for org_lr in self.org_lrs]

    def get_last_lr(self):
        """ Return last computed learning rate by current scheduler.
        """
        assert getattr(self, '_last_lr', None) is not None, "need to call step() first"
        return self._last_lr

    def state_dict(self):
        return {'last_batch_iteration': self.last_batch_iteration}

    def load_state_dict(self, sd):
        self.last_batch_iteration = sd['last_batch_iteration']

    def _format_param(self, optimizer, param_value, param_name):
        if isinstance(param_value, list) or isinstance(param_value, tuple):
            if len(param_value) != len(optimizer.param_groups):
                raise ValueError('expected {} value for {}, got {}'.format(len(optimizer.param_groups), param_name,
                                                                           param_value))
            return list(param_value)
        return [param_value] * len(optimizer.param_groups)

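# Worked example of the cosine phase (illustrative numbers). After warmup, with
# total_num_steps=10000, warmup_num_steps=1000 and cos_min_ratio=0.1, the LR
# ratio at batch t is approximately
#     0.1 + 0.9 * (1 + cos(pi * (t - 1000) / 9000)) / 2
# so every parameter group anneals from its original LR down to ~10% of it by
# t=10000.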
)8r   r=   torch.optimr   r   deepspeed.utilsr   r`   rc   rd   	WARMUP_LRWARMUP_DECAY_LRWARMUP_COSINE_LRrb   rE   rF   rG   rH   
EDGE_VALUE	MID_VALUErN   rO   rP   rQ   rR   rS   rT   rU   rV   rW   rX   rZ   r[   r\   r]   r6   r   WARMUP_MIN_RATIOCOS_MIN_RATIOTOTAL_NUM_STEPSr;   rB   rL   rY   r^   r_   rf   ri   ru   rz   objectr   r   r   r   r	   r9   r9   r9   r:   <module>   sp   @%
b  Z3