o
    %ݫi                     @   s  d Z ddlZddlZddlmZ ddlmZ ddlmZ eeZ	d(ddZ
ejG dd	 d	ZejG d
d dZG dd dZejG dd dZG dd dZejG dd dZejG dd dZejG dd dZejG dd dZejG dd dZejG dd dZejG dd dZejG d d! d!ZejG d"d# d#ZG d$d% d%ejZejG d&d' d'ZdS ))z
Schedulers for updating hyperparameters (such as learning rate).

Authors
 * Mirco Ravanelli 2020
 * Peter Plantinga 2020
 * Loren Lugosch 2020
 * Shucong Zhang 2023
    N)nn)checkpoints)
get_loggerc                 C   sn   |du rt t| j}n|}|D ]$}| j| d }||kr4|| j| d< || j| d< td||f  qdS )a  Change the learning rate value within an optimizer.

    Arguments
    ---------
    optimizer : torch.optim object
        Updates the learning rate for this optimizer.
    new_lr : float
        The new value to use for the learning rate.
    param_group : list of int
        The param group indices to update. If not provided, all groups updated.

    Example
    -------
    >>> from torch.optim import SGD
    >>> from speechbrain.nnet.linear import Linear
    >>> model = Linear(n_neurons=10, input_size=10)
    >>> optimizer = SGD(model.parameters(), lr=0.1)
    >>> update_learning_rate(optimizer, 0.2)
    >>> optimizer.param_groups[0]["lr"]
    0.2
    Nlrprev_lrzChanging lr from %.2g to %.2g)rangelenparam_groupsloggerinfo)	optimizernew_lrparam_groupgroupsiold_lr r   O/home/ubuntu/.local/lib/python3.10/site-packages/speechbrain/nnet/schedulers.pyupdate_learning_rate   s   r   c                       sH   e Zd ZdZd fdd	Zdd Zejdd Zej	dddZ
  ZS )WarmAndExpDecayLRScheduleaU  Warms up linearly, and then decay exponentially to ('lr' / 'decay_factor') in 'total_steps' steps.


    Arguments
    ---------
    lr : float
        The max learning rate to reach after warmup.
    n_warmup_steps : int
        Number of warmup steps (following a linear increase).
    total_steps : int
        Total number of steps (used to decay).
    decay_factor : float
        Decay factor applied every decay_every steps. (default: 0.01)

    Example
    -------
    >>> from speechbrain.nnet.linear import Linear
    >>> inp_tensor = torch.rand([1,660,3])
    >>> model = Linear(input_size=3, n_neurons=4)
    >>> optim = torch.optim.Adam(model.parameters(), lr=1)
    >>> output = model(inp_tensor)
    >>> scheduler = WarmAndExpDecayLRSchedule(lr=1, n_warmup_steps=2, decay_factor=0.01, total_steps=6)
    >>> scheduler(optim)
    >>> optim.param_groups[0]["lr"]
    0.0
    >>> scheduler(optim)
    >>> optim.param_groups[0]["lr"]
    0.5
    >>> scheduler(optim)
    >>> optim.param_groups[0]["lr"]
    1
    >>> scheduler(optim)
    >>> optim.param_groups[0]["lr"]
    0.31622776601683794
    皙?c                    s<   t t|   || _d| _|| _|| _|| j | _d| _d S Nr   )	superr   __init__base_lr
current_lrn_warmup_stepsdecay_factordecay_stepscurrent_step)selfr   r   total_stepsr   	__class__r   r   r   b   s   
z"WarmAndExpDecayLRSchedule.__init__c                 C   sv   | j | jk r| j| j  | j }n| j| j| j | j | j   }t| j|}|jD ]}||d< q(|| _|  j d7  _ d S )Nr      )r   r   r   r   r   minr	   r   )r    optr   
decayed_lrr   r   r   r   __call__k   s   

z"WarmAndExpDecayLRSchedule.__call__c                 C   *   | j | j| j| j| jd}t|| dS )0Saves the current metrics on the specified path.)r   r   r   r   r   N)r   r   r   r   r   torchsaver    pathdatar   r   r   r,   {      zWarmAndExpDecayLRSchedule.saveFNc                 C   sD   ~~t |}|d | _|d | _|d | _|d | _|d | _dS )Loads the needed information.r   r   r   r   r   N)r+   loadr   r   r   r   r   r    r.   end_of_epochdevicer/   r   r   r   r2      s   




zWarmAndExpDecayLRSchedule.load)r   FN__name__
__module____qualname____doc__r   r(   r   mark_as_saverr,   mark_as_loaderr2   __classcell__r   r   r"   r   r   <   s    $	
r   c                   @   sF   e Zd ZdZ			dddZdd Zejd	d
 Zej	dddZ
dS )NewBobSchedulera  Scheduler with new-bob technique, used for LR annealing.

    The learning rate is annealed based on the validation performance.
    In particular: if (past_loss-current_loss)/past_loss< impr_threshold:
    lr=lr * annealing_factor.

    Arguments
    ---------
    initial_value : float
        The initial hyperparameter value.
    annealing_factor : float
        It is annealing factor used in new_bob strategy.
    improvement_threshold : float
        It is the improvement rate between losses used to perform learning
        annealing in new_bob strategy.
    patient : int
        When the annealing condition is violated patient times,
        the learning rate is finally reduced.

    Example
    -------
    >>> scheduler = NewBobScheduler(initial_value=1.0)
    >>> scheduler(metric_value=10.0)
    (1.0, 1.0)
    >>> scheduler(metric_value=2.0)
    (1.0, 1.0)
    >>> scheduler(metric_value=2.5)
    (1.0, 0.5)
          ?{Gzd?r   c                 C   s*   || _ || _|| _|| _g | _| j| _d S N)hyperparam_valueannealing_factorimprovement_thresholdpatientmetric_valuescurrent_patient)r    initial_valuerD   rE   rF   r   r   r   r      s   zNewBobScheduler.__init__c                 C   s   | j  }}t| jdkr9| jd }|dkrd}n|| | }|| jk r9| jdkr2|| j9 }| j| _n|  jd8  _| j| || _ ||fS )a  Returns the current and new value for the hyperparameter.

        Arguments
        ---------
        metric_value : int
            A number for determining whether to change the hyperparameter value.
        Returns
        -------
        Current and new hyperparam value.
        r   r$   )rC   r   rG   rE   rH   rD   rF   append)r    metric_value	old_value	new_valueprev_metricimprovementr   r   r   r(      s   





zNewBobScheduler.__call__c                 C   "   | j | j| jd}t|| dS )r*   )rC   rG   rH   N)rC   rG   rH   r+   r,   r-   r   r   r   r,      
   zNewBobScheduler.saveFc                 C   .   ~t |}|d | _|d | _|d | _dS )r1   rC   rG   rH   N)r+   r2   rC   rG   rH   r    r.   r4   r/   r   r   r   r2      
   


zNewBobScheduler.loadN)r@   rA   r   Fr8   r9   r:   r;   r   r(   r   r<   r,   r=   r2   r   r   r   r   r?      s    !
 
	r?   c                   @   s    e Zd ZdZdd Zdd ZdS )LinearSchedulera  Scheduler with linear annealing technique.

    The learning rate linearly decays over the specified number of epochs.

    Arguments
    ---------
    initial_value : float
        The value upon initialization.
    final_value : float
        The value used when the epoch count reaches ``epoch_count - 1``.
    epoch_count : int
        Number of epochs.

    Example
    -------
    >>> scheduler = LinearScheduler(1.0, 0.0, 4)
    >>> scheduler(current_epoch=1)
    (1.0, 0.666...)
    >>> scheduler(current_epoch=2)
    (0.666..., 0.333...)
    >>> scheduler(current_epoch=3)
    (0.333..., 0.0)
    >>> scheduler(current_epoch=4)
    (0.0, 0.0)
    c                 C   s   t j|||d | _d S )Nsteps)r+   linspacetolistvalue_at_epoch)r    rI   final_valueepoch_countr   r   r   r     s
   zLinearScheduler.__init__c                 C   s6   t d|d }t|t| jd }| j| | j| fS )a	  Returns the current and new value for the hyperparameter.

        Arguments
        ---------
        current_epoch : int
            Number of times the dataset has been iterated.

        Returns
        -------
        Current and new hyperparam value.
        r   r$   )maxr%   r   r]   )r    current_epoch	old_indexindexr   r   r   r(     s   zLinearScheduler.__call__N)r8   r9   r:   r;   r   r(   r   r   r   r   rX      s    rX   c                   @   sF   e Zd ZdZdd Zdd Zdd Zejdd	 Z	ej
dddZdS )LinearWarmupScheduleraL  Create a schedule with a learning rate that decreases linearly
    from the initial lr set in the optimizer to 0, after
    a warmup period during which it increases linearly
    from 0 to the initial lr set in the optimizer.
    * Ge Li 2022

    Arguments
    ---------
    initial_value : float
        The value upon initialization (lr0).
    num_warmup_steps : int
        Number of warmup steps. The learning rate reaches lr0 at
        ``num_warmup_steps + 1`` step.
    num_training_steps : int
        The total number of training steps.

    Example
    -------
    >>> scheduler = LinearWarmupScheduler(1.0, 2, 4)
    >>> scheduler.get_next_value()
    0.0
    >>> scheduler.get_next_value()
    0.5
    >>> scheduler.get_next_value()
    1.0
    >>> scheduler.get_next_value()
    0.5
    >>> scheduler.get_next_value()
    0.0
    c                 C   s   || _ || _|| _d| _d S r   )lr0num_warmup_stepsnum_training_stepsr   )r    rI   rf   rg   r   r   r   r   H  s   
zLinearWarmupScheduler.__init__c              	   C   sX   || j k rt|ttd| j  | j S | jtdt| j| ttd| j| j    S )a  Returns the current and new value for the hyperparameter.

        Arguments
        ---------
        current_step : int
            Number of steps the model has been updated.

        Returns
        -------
        Current and new hyperparam value.
        r$           )rf   floatr`   re   rg   )r    r   r   r   r   calculate_lrN  s   
z"LinearWarmupScheduler.calculate_lrc                 C   s   |  | j}|  jd7  _|S )z<Returns the next learning rate value for the hyperparameter.r$   )rj   r   )r    rN   r   r   r   get_next_valuef  s   z$LinearWarmupScheduler.get_next_valuec                 C   s&   | j | j| j| jd}t|| dS )r*   )rI   rf   rg   r   N)re   rf   rg   r   r+   r,   r-   r   r   r   r,   l  s   zLinearWarmupScheduler.saveFc                 C   s8   ~t |}|d | _|d | _|d | _|d | _dS )r1   rI   rf   rg   r   N)r+   r2   re   rf   rg   r   rT   r   r   r   r2   w  s   



zLinearWarmupScheduler.loadNrV   )r8   r9   r:   r;   r   rj   rk   r   r<   r,   r=   r2   r   r   r   r   rd   '  s    

rd   c                   @   s<   e Zd ZdZdZdZ	dddZdd Zd	d
 Zdd Z	dS )StepSchedulera  Learning rate scheduler with step annealing technique.

    The hyperparameter's value decays over the epochs with the
    selected ``epoch_decay`` factor.

    ``value = init_value * decay_factor ^ floor((1 + epoch) / decay_drop)``

    Arguments
    ---------
    initial_value : float
        Initial value for the hyperparameter being updated.
    decay_factor : float
        Factor multiplied with the initial_value
    decay_drop : float
        Annealing factor (the decay of the hyperparameter value is faster
        with higher ``decay_drop`` values).
    half_life : int
        A convenience parameter to set decay_factor such that the parameter
        will drop to half its value at the specified epoch. May not
        be used together with decay_factor or decay_drop

    Example
    -------
    >>> scheduler = StepScheduler(initial_value=1.0)
    >>> scheduler(current_epoch=1)
    (1.0, 0.5)
    >>> scheduler(current_epoch=2)
    (0.5, 0.5)
    >>> scheduler(current_epoch=3)
    (0.5, 0.25)
    r@      Nc                 C   sL   || _ |r|s	|rtd| || _d| _d S |p| j| _|p"| j| _d S )NzBhalf_life cannot be used together with decay_factor and decay_drop      ?)rI   
ValueError_compute_half_life_decay_factorr   
decay_dropDEFAULT_DECAY_FACTORDEFAULT_DECAY_DROP)r    rI   r   rq   	half_lifer   r   r   r     s   
zStepScheduler.__init__c                 C   s   t t d | S )Nrm   )mathexplog)r    rt   r   r   r   rp     s   z-StepScheduler._compute_half_life_decay_factorc                 C   s    |  |d }|  |}||fS )zReturns current and new hyperparameter value.

        Arguments
        ---------
        current_epoch : int
            Number of times the dataset has been iterated.

        Returns
        -------
        Current and new hyperparam value.
        r$   )_compute_value)r    ra   current_value
next_valuer   r   r   r(     s   
zStepScheduler.__call__c                 C   s$   | j t| jtd| | j  S Nr$   )rI   ru   powr   floorrq   )r    ra   r   r   r   rx     s   zStepScheduler._compute_valueNNN)
r8   r9   r:   r;   rr   rs   r   rp   r(   rx   r   r   r   r   rl     s     
rl   c                   @   sH   e Zd ZdZdddZdd Zdd Zejd	d
 Z	ej
dddZdS )NoamScheduleraC  The is an implementation of the transformer's learning rate scheduler with warmup.
    Reference: https://arxiv.org/abs/1706.03762

    Note: this scheduler anneals the lr at each update of the model's weight,
    and n_steps must be saved for restarting.

    Arguments
    ---------
    lr_initial : float
        Initial learning rate (i.e. the lr used at epoch 0).
    n_warmup_steps : int
        number of warm-up steps
    model_size : int
        size of transformer embed_dim. It is used to scale the maximum learning rate value reached
        by the scheduler. It is divided by model_size ** (0.5).
        If not specified the maximum learning rate value is instead multiplied by warmup_steps ** (0.5).

    Example
    -------
    >>> from speechbrain.nnet.linear import Linear
    >>> inp_tensor = torch.rand([1,660,3])
    >>> model = Linear(input_size=3, n_neurons=4)
    >>> optim = torch.optim.Adam(model.parameters(), lr=1)
    >>> output = model(inp_tensor)
    >>> scheduler =NoamScheduler(optim.param_groups[0]["lr"], 3)
    >>> curr_lr,next_lr=scheduler(optim)
    >>> optim.param_groups[0]["lr"]
    0.3333333333333333
    >>> curr_lr,next_lr=scheduler(optim)
    >>> optim.param_groups[0]["lr"]
    0.6666666666666666
    >>> curr_lr,next_lr=scheduler(optim)
    >>> optim.param_groups[0]["lr"]
    0.9999999999999999
    Nc                 C   sB   || _ || _|| _g | _d| _|d | _|d ur|d | _d S d S Nr   r@         )
lr_initialr   r   lossesn_steps	normalize)r    r   r   
model_sizer   r   r   r     s   
zNoamScheduler.__init__c                 C   L   |  j d7  _ |jd d }| j|   }|jD ]}||d< q|| _||fS a/  
        Arguments
        ---------
        opt : optimizer
            The optimizer to update using this scheduler.

        Returns
        -------
        current_lr : float
            The learning rate before the update.
        lr : float
            The learning rate after the update.
        r$   r   r   r   r	   r   _get_lr_scaler   r    r&   r   r   r   r   r   r   r(         

zNoamScheduler.__call__c                 C   s*   | j | j}}| jt|d ||d   S Nr         )r   r   r   r%   r    r   r   r   r   r   r     s   zNoamScheduler._get_lr_scalec                 C      | j | jd}t|| dS r*   )r   r   Nr   r   r+   r,   r-   r   r   r   r,   !     zNoamScheduler.saveFc                 C   $   ~t |}|d | _|d | _dS r1   r   r   Nr+   r2   r   r   rT   r   r   r   r2   '     

zNoamScheduler.loadrB   rV   r8   r9   r:   r;   r   r(   r   r   r<   r,   r=   r2   r   r   r   r   r     s    
$

r   c                   @   sJ   e Zd ZdZ	dddZdd Zdd Zejd	d
 Z	ej
dddZdS )NoamIntervalSchedulera  A combination of Noam Scheduler and Interval Scheduler.
    The scheduler behaves as a Noam Scheduler, and anneals the learning rate
    at designed steps with designed decays.

    Note: this scheduler anneals the lr at each update of the model's weight,
    and n_steps must be saved for restarting.

    Arguments
    ---------
    lr_initial : float
        Initial learning rate (i.e. the lr used at epoch 0).
    n_warmup_steps : int
        number of warm-up steps.
    anneal_steps: list
        Pre-designed steps where the learning rate is to be annealed.
    anneal_rates: list
        Pre-designed decay rate for each anneal step.
    model_size : int
        size of transformer embed_dim. It is used to scale the maximum learning rate value reached
        by the scheduler. It is divided by model_size ** (0.5).
        If not specified the maximum learning rate value is instead multiplied by warmup_steps ** (0.5).

    Example
    -------
    >>> from speechbrain.nnet.linear import Linear
    >>> inp_tensor = torch.rand([1,660,3])
    >>> model = Linear(input_size=3, n_neurons=4)
    >>> optim = torch.optim.Adam(model.parameters(), lr=1)
    >>> output = model(inp_tensor)
    >>> scheduler = NoamIntervalScheduler(
    ...    lr_initial=optim.param_groups[0]["lr"],
    ...    n_warmup_steps=3,
    ...    anneal_steps=[6, 9],
    ...    anneal_rates=[0.5, 0.1],
    ... )
    >>> for _ in range(10):
    ...     curr_lr,next_lr=scheduler(optim)
    ...     print(optim.param_groups[0]["lr"])
    0.3333333333333333
    0.6666666666666666
    0.9999999999999999
    0.8660254037844386
    0.7745966692414833
    0.7071067811865475
    0.3273268353539886
    0.3061862178478973
    0.28867513459481287
    0.027386127875258306
    Nc                 C   sN   || _ || _|| _g | _d| _|d | _|| _|| _|d ur%|d | _d S d S r   )r   r   r   r   r   r   anneal_stepsanneal_rates)r    r   r   r   r   r   r   r   r   r   d  s   
zNoamIntervalScheduler.__init__c                 C   r   r   r   r   r   r   r   r(   w  r   zNoamIntervalScheduler.__call__c                 C   s`   | j | j}}| jt|d ||d   }tt| jD ]}| j | j| kr-|| j|  }q|S r   )r   r   r   r%   r   r   r   r   )r    r   r   lr_scaler   r   r   r   r     s   z#NoamIntervalScheduler._get_lr_scalec                 C   r   r   r   r-   r   r   r   r,     r   zNoamIntervalScheduler.saveFc                 C   &   ~~t |}|d | _|d | _dS r   r   r3   r   r   r   r2     
   

zNoamIntervalScheduler.loadrB   r6   r   r   r   r   r   r   0  s    8


r   c                   @   sF   e Zd ZdZdd Zdd Zdd Zejdd	 Z	ej
dddZdS )LinearNoamSchedulera!  The is an implementation of the extended Noam scheduler in the Squeezeformer paper.
    Reference: https://arxiv.org/pdf/2206.00888.pdf

    Note: this scheduler anneals the lr at each update of the model's weight,
    and n_steps must be saved for restarting.

    Arguments
    ---------
    lr_initial : float
        Initial learning rate (i.e. the lr used at epoch 0).
    n_warmup_steps : int
        number of warm-up steps.
    n_keep_steps : int
        after warmp-up steps, number of steps that the lr is kept unchanged.

    Example
    -------
    >>> from speechbrain.nnet.linear import Linear
    >>> inp_tensor = torch.rand([1,660,3])
    >>> model = Linear(input_size=3, n_neurons=4)
    >>> optim = torch.optim.Adam(model.parameters(), lr=1)
    >>> output = model(inp_tensor)
    >>> scheduler =LinearNoamScheduler(optim.param_groups[0]["lr"], 2, 2)
    >>> curr_lr,next_lr=scheduler(optim)
    >>> optim.param_groups[0]["lr"]
    0.5
    >>> curr_lr,next_lr=scheduler(optim)
    >>> optim.param_groups[0]["lr"]
    1.0
    >>> curr_lr,next_lr=scheduler(optim)
    >>> optim.param_groups[0]["lr"]
    1.0
    >>> curr_lr,next_lr=scheduler(optim)
    >>> optim.param_groups[0]["lr"]
    1.0
    >>> curr_lr,next_lr=scheduler(optim)
    >>> optim.param_groups[0]["lr"]
    0.6666666666666666
    c                 C   s(   || _ || _|| _|| _g | _d| _d S r   )r   r   n_keep_stepsr   r   r   )r    r   r   r   r   r   r   r     s   
zLinearNoamScheduler.__init__c                 C   r   r   r   r   r   r   r   r(     r   zLinearNoamScheduler.__call__c                 C   sB   | j | j}}||k r|d | S || j| k rdS ||| j  S )Nrh   rn   )r   r   r   r   r   r   r   r     s   z!LinearNoamScheduler._get_lr_scalec                 C   r   r   r   r-   r   r   r   r,     r   zLinearNoamScheduler.saveFNc                 C   r   r   r   r3   r   r   r   r2     r   zLinearNoamScheduler.loadr6   r   r   r   r   r   r     s    (	
r   c                   @   sH   e Zd ZdZdddZdd Zdd	 Zejd
d Z	ej
dddZdS )CyclicCosineSchedulerac  The is an implementation of the Cyclic-Cosine learning rate scheduler with warmup.

    Reference:  https://openreview.net/pdf?id=BJYwwY9ll

    Note: this scheduler anneals the lr at each update of the model's weight,
    and n_steps must be saved for restarting.

    Arguments
    ---------
    n_warmup_steps : int
        Number of warm up steps.
    lr_initial : float
        Initial learning rate (i.e. the lr used at epoch 0).
    total_steps : int
        Total number of updating steps.

    Example
    -------
    >>> from speechbrain.nnet.linear import Linear
    >>> inp_tensor = torch.rand([1,660,3])
    >>> model = Linear(input_size=3, n_neurons=4)
    >>> optim = torch.optim.Adam(model.parameters(), lr=1)
    >>> output = model(inp_tensor)
    >>> scheduler =CyclicCosineScheduler(3, optim.param_groups[0]["lr"])
    >>> curr_lr,next_lr=scheduler(optim)
    >>> optim.param_groups[0]["lr"]
    0.9999999990130395
    >>> curr_lr,next_lr=scheduler(optim)
    >>> optim.param_groups[0]["lr"]
    0.9999999997532598
    >>> curr_lr,next_lr=scheduler(optim)
    >>> optim.param_groups[0]["lr"]
    1.0
    N順 c                 C   s:   || _ g | _|| _|| _|| _d| _d||d   | _d S )Nr   r$   r   )r   r   
initial_lrr   totalr   r   )r    r   r   r!   r   r   r   r   7  s   zCyclicCosineScheduler.__init__c                 C   s\   |  j d7  _ | jdu r|jd d }n| j}||   }|jD ]}||d< q || _||fS )a9  
        Arguments
        ---------
        opt : list of optimizers
            The optimizers to update using this scheduler.

        Returns
        -------
        current_lr : float
            The learning rate before the update.
        lr : float
            The learning rate after the update.
        r$   Nr   r   )r   r   r	   r   r   r   r   r   r   r(   A  s   


zCyclicCosineScheduler.__call__c                 C   s0   | j | j}}dttj||  | j d  S )Nr@   r$   )r   r   ru   cospir   r   r   r   r   r   _  s   z#CyclicCosineScheduler._get_lr_scalec                 C   r   r   r   r-   r   r   r   r,   e  r   zCyclicCosineScheduler.saveFc                 C   r   r   r   rT   r   r   r   r2   k  r   zCyclicCosineScheduler.load)Nr   rV   r   r   r   r   r   r     s    
#

r   c                   @   sB   e Zd ZdZ	dddZdd	 Zejd
d Zej	dddZ
dS )ReduceLROnPlateaual  Learning rate scheduler which decreases the learning rate if the loss
    function of interest gets stuck on a plateau, or starts to increase.
    The difference from NewBobLRScheduler is that, this one keeps a memory of
    the last step where do not observe improvement, and compares against that
    particular loss value as opposed to the most recent loss.

    Arguments
    ---------
    lr_min : float
        The minimum allowable learning rate.
    factor : float
        Factor with which to reduce the learning rate.
    patience : int
        How many epochs to wait before reducing the learning rate.
    dont_halve_until_epoch : int
        Number of epochs to wait until halving.

    Example
    -------
    >>> from torch.optim import Adam
    >>> from speechbrain.nnet.linear import Linear
    >>> inp_tensor = torch.rand([1,660,3])
    >>> model = Linear(n_neurons=10, input_size=3)
    >>> optim = Adam(lr=1.0, params=model.parameters())
    >>> output = model(inp_tensor)
    >>> scheduler = ReduceLROnPlateau(0.25, 0.5, 2, 1)
    >>> curr_lr,next_lr=scheduler([optim],current_epoch=1, current_loss=10.0)
    >>> curr_lr,next_lr=scheduler([optim],current_epoch=2, current_loss=11.0)
    >>> curr_lr,next_lr=scheduler([optim],current_epoch=3, current_loss=13.0)
    >>> curr_lr,next_lr=scheduler([optim],current_epoch=4, current_loss=14.0)
    >>> next_lr
    0.5
    :0yE>r@   rm   A   c                 C   s.   || _ || _|| _d| _g | _|| _d| _d S )Nr   i )lr_minfactorpatiencepatience_counterr   dont_halve_until_epochanchor)r    r   r   r   r   r   r   r   r     s   
zReduceLROnPlateau.__init__c                 C   s   |D ]D}|j d d }|| jkr|}|| _n*|| jkr$d| _|}|| _n|| jkr8| j| jk r8| jd | _|}n|| j }d| _t|| j}q| j	| ||fS )a  
        Arguments
        ---------
        optim_list : list of optimizers
            The optimizers to update using this scheduler.
        current_epoch : int
            Number of times the dataset has been iterated.
        current_loss : int
            A number for determining whether to change the learning rate.

        Returns
        -------
        current_lr : float
            The learning rate before the update.
        next_lr : float
            The learning rate after the update.
        r   r   r$   )
r	   r   r   r   r   r   r`   r   r   rK   )r    
optim_listra   current_lossr&   r   next_lrr   r   r   r(     s$   



zReduceLROnPlateau.__call__c                 C   rQ   )r*   )r   r   r   N)r   r   r   r+   r,   r-   r   r   r   r,     rR   zReduceLROnPlateau.saveFc                 C   rS   )r1   r   r   r   N)r+   r2   r   r   r   rT   r   r   r   r2     rU   zReduceLROnPlateau.loadN)r   r@   rm   r   rV   rW   r   r   r   r   r   t  s    #
/
	r   c                       sp   e Zd ZdZ							d fd	d
	ZdddZdd Zdd Zdd Ze	j
dd Ze	jdddZ  ZS )CyclicLRSchedulera`
  This implements a cyclical learning rate policy (CLR).
    The method cycles the learning rate between two boundaries with
    some constant frequency, as detailed in this paper (https://arxiv.org/abs/1506.01186).
    The amplitude of the cycle can be scaled on a per-iteration or
    per-cycle basis.

    This class has three built-in policies, as put forth in the paper.
    "triangular":
        A basic triangular cycle w/ no amplitude scaling.
    "triangular2":
        A basic triangular cycle that scales initial amplitude by half each cycle.
    "exp_range":
        A cycle that scales initial amplitude by gamma**(cycle iterations) at each
        cycle iteration.
    For more detail, please see the reference paper.

    Arguments
    ---------
    base_lr : float
        initial learning rate which is the
        lower boundary in the cycle.
    max_lr : float
        upper boundary in the cycle. Functionally,
        it defines the cycle amplitude (max_lr - base_lr).
        The lr at any cycle is the sum of base_lr
        and some scaling of the amplitude; therefore
        max_lr may not actually be reached depending on
        scaling function.
    step_size : int
        number of training iterations per
        half cycle. The authors suggest setting step_size
        2-8 x training iterations in epoch.
    mode : str
        one of {triangular, triangular2, exp_range}.
        Default 'triangular'.
        Values correspond to policies detailed above.
        If scale_fn is not None, this argument is ignored.
    gamma : float
        constant in 'exp_range' scaling function:
        gamma**(cycle iterations)
    scale_fn : lambda function
        Custom scaling policy defined by a single
        argument lambda function, where
        0 <= scale_fn(x) <= 1 for all x >= 0.
        mode parameter is ignored
    scale_mode : str
        {'cycle', 'iterations'}.
        Defines whether scale_fn is evaluated on
        cycle number or cycle iterations (training
        iterations since start of cycle). Default is 'cycle'.

    Example
    -------
    >>> from speechbrain.nnet.linear import Linear
    >>> inp_tensor = torch.rand([1,660,3])
    >>> model = Linear(input_size=3, n_neurons=4)
    >>> optim = torch.optim.Adam(model.parameters(), lr=1)
    >>> output = model(inp_tensor)
    >>> scheduler = CyclicLRScheduler(base_lr=0.1, max_lr=0.3, step_size=2)
    >>> scheduler.on_batch_end(optim)
    >>> optim.param_groups[0]["lr"]
    0.2
    >>> scheduler.on_batch_end(optim)
    >>> optim.param_groups[0]["lr"]
    0.3
    >>> scheduler.on_batch_end(optim)
    >>> optim.param_groups[0]["lr"]
    0.2
    MbP?~jtx?     @@
triangularrn   Ncyclec                    s   t    g | _|| _|| _|| _|| _ | _|d u rG| jdkr)dd | _d| _	n$| jdkr7dd | _d| _	n| jdkrF fdd| _d	| _	n|| _|| _	d
| _
|   d S )Nr   c                 S   s   dS )Nrn   r   xr   r   r   <lambda>B  s    z,CyclicLRScheduler.__init__.<locals>.<lambda>r   triangular2c                 S   s   dd| d   S )Nr$   g       @r   r   r   r   r   r   E  s    	exp_rangec                    s    |  S rB   r   r   gammar   r   r   H  s    
iterationsrh   )r   r   r   r   max_lr	step_sizemoder   scale_fn
scale_modeclr_iterations_reset)r    r   r   r   r   r   r   r   r"   r   r   r   .  s,   






zCyclicLRScheduler.__init__c                 C   s4   |dur|| _ |dur|| _|dur|| _d| _dS )zQResets cycle iterations.
        Optional boundary/step size adjustment.
        Nrh   )r   r   r   r   )r    new_base_lr
new_max_lrnew_step_sizer   r   r   r   Q  s   
zCyclicLRScheduler._resetc                 C   s   | j }| | jd }||fS r{   )r   clrr   )r    epochr   r   r   r   r   r(   ]  s   zCyclicLRScheduler.__call__c                 C   s   t d|d| j   }t|| j d|  d }| jdkr3| j| j| j tdd|  | |  S | j| j| j tdd|  | |  S )zClears iterations.r$   rm   r   r   )	ru   r}   r   absr   r   r   r`   r   )r    r   r   r   r   r   r   r   c  s   
zCyclicLRScheduler.clrc                 C   sF   |  j d7  _ | | j }|jd d }|jD ]}||d< q|| _dS )z
        Arguments
        ---------
        opt : optimizers
            The optimizers to update using this scheduler.
        r$   r   r   N)r   r   r	   r   )r    r&   r   r   r   r   r   r   on_batch_endp  s   


zCyclicLRScheduler.on_batch_endc                 C   r   )r*   )r   r   N)r   r   r+   r,   r-   r   r   r   r,     r   zCyclicLRScheduler.saveFc                 C   r   )r1   r   r   N)r+   r2   r   r   rT   r   r   r   r2     r   zCyclicLRScheduler.load)r   r   r   r   rn   Nr   r~   rV   )r8   r9   r:   r;   r   r   r(   r   r   r   r<   r,   r=   r2   r>   r   r   r"   r   r     s$    H
#
r   c                   @   sN   e Zd ZdZdd Zdd Zdd Zdd	 Zej	d
d Z
ejdddZdS )IntervalSchedulera  A simple scheduler implementation that sets the learning rate to
    specific values after a specific number of steps has been reached.

    Arguments
    ---------
    intervals : list
        a list of dictionaries: {"steps": <number of steps>, "lr": the learning rate}
        'steps' indicates the global step count at which a given
        rate will apply

    Example
    -------
    >>> import torch
    >>> from speechbrain.nnet.schedulers import IntervalScheduler
    >>> from speechbrain.nnet.linear import Linear
    >>> model = Linear(input_size=3, n_neurons=4)
    >>> optim = torch.optim.Adam(model.parameters(), lr=1)
    >>> scheduler = IntervalScheduler(
    ...    intervals=[
    ...        {"steps": 2, "lr": 0.01},
    ...        {"steps": 5, "lr": 0.005},
    ...        {"steps": 9, "lr": 0.001}
    ...    ]
    ... )
    >>> optim.param_groups[0]["lr"]
    1
    >>> for _ in range(10):
    ...     pre, post = scheduler(optim)
    ...     print(f"{pre} -> {post}")
    1 -> 1
    1 -> 0.01
    0.01 -> 0.01
    0.01 -> 0.01
    0.01 -> 0.005
    0.005 -> 0.005
    0.005 -> 0.005
    0.005 -> 0.005
    0.005 -> 0.001
    0.001 -> 0.001
    c                 C   s   || _ d| _g | _|   d S r   )	intervalsr   r   _compute_next)r    r   r   r   r   r     s   zIntervalScheduler.__init__c                 C   sH   |  j d7  _ |jd d }| |}|jD ]}||d< q|| _||fS r   )r   r	   _get_lrr   r   r   r   r   r(     s   


zIntervalScheduler.__call__c                    s    fdd j D  _d S )Nc                    s   g | ]}|d   j kr|qS rY   )r   ).0intervalr    r   r   
<listcomp>  s
    z3IntervalScheduler._compute_next.<locals>.<listcomp>)r   _next_intervalsr   r   r   r   r     s   
zIntervalScheduler._compute_nextc                 C   s6   |}| j r| j d }| j|d kr|d }| j d= |S )Nr   rZ   r   )r   r   )r    r   r   next_intervalr   r   r   r     s   
zIntervalScheduler._get_lrc                 C   r   r   r   r-   r   r   r   r,     r   zIntervalScheduler.saveFc                 C   s,   ~t |}|d | _|d | _|   dS r   )r+   r2   r   r   r   rT   r   r   r   r2     s
   


zIntervalScheduler.loadNrV   )r8   r9   r:   r;   r   r(   r   r   r   r<   r,   r=   r2   r   r   r   r   r     s    )	
r   c                   @   s6   e Zd ZdZdd Zdd Zdd Zejdd	 Z	d
S )InverseSquareRootSchedulerzThe Inverse Square Root Scheduler, as defined in the T5 paper
    https://arxiv.org/pdf/1910.10683.pdf

    Arguments
    ---------
    warmup_steps : int
        The number of steps over which the learning rate will be constant
    c                 C   s   || _ d| _d S r   )warmup_stepsr   )r    r   r   r   r   r     s   
z#InverseSquareRootScheduler.__init__c                 C   sF   |  j d7  _ |jd d }|  }|jD ]}||d< q|| _||fS )zReturns current and new hyperparameter value.

        Arguments
        ---------
        opt : optimizer
            The optimizer to update using this scheduler.

        Returns
        -------
        current and new hyperparam value
        r$   r   r   )r   r	   rx   r   r   r   r   r   r(     s   

z#InverseSquareRootScheduler.__call__c                 C   s   dt t| j| j S r{   )ru   sqrtr`   r   r   r   r   r   r   rx   %  s   z)InverseSquareRootScheduler._compute_valuec                 C      d| j i}t|| dS )r*   r   Nr   r+   r,   r-   r   r   r   r,   (     
zInverseSquareRootScheduler.saveN)
r8   r9   r:   r;   r   r(   rx   r   r<   r,   r   r   r   r   r     s    	r   c                       sL   e Zd ZdZ		d fdd	Zdd Zejdd	 Zej	dddZ
  ZS )WarmCoolDecayLRScheduleaC  Warms up linearly, very slowly decays and cools down linearly again
    at the end of training. This is a three steps scheduler.

    Reference
    ---------
    Scaling Vision Transformers
    arxiv.org/abs/2106.04560

    Arguments
    ---------
    lr : float
        The max learning rate to reach after warmup.
    warmup : int
        Number of warmup steps (following a linear increase).
    cooldown : int
        Number of cooldown steps (following a linear decrease).
    total_steps : int
        Total number of steps (used to decay).
    decay_factor : float
        Decay factor applied every decay_every steps.
    decay_every : int
        Apply the decay factor to the learning rate every decay_every steps.

    Example
    -------
    >>> from speechbrain.nnet.linear import Linear
    >>> inp_tensor = torch.rand([1,660,3])
    >>> model = Linear(input_size=3, n_neurons=4)
    >>> optim = torch.optim.Adam(model.parameters(), lr=1)
    >>> output = model(inp_tensor)
    >>> scheduler = WarmCoolDecayLRSchedule(lr=1, warmup=2, total_steps=6, decay_factor=0.5, decay_every=1, cooldown=1)
    >>> optim.param_groups[0]["lr"]
    1
    >>> scheduler(optim, 1)
    >>> optim.param_groups[0]["lr"]
    0.5
    >>> scheduler(optim, 2)
    >>> optim.param_groups[0]["lr"]
    1.0
    >>> scheduler(optim, 3)
    >>> optim.param_groups[0]["lr"]
    0.5
    >>> scheduler(optim, 4)
    >>> optim.param_groups[0]["lr"]
    0.25
    >>> scheduler(optim, 5)
    >>> optim.param_groups[0]["lr"]
    0.12500000000000003
    >>> scheduler(optim, 6)
    >>> optim.param_groups[0]["lr"]
    0.0
          ?r   c                    s6   t    || _|| _|| _|| _t|| | _d S rB   )	r   r   r   warmupcooldownr!   ru   rw   power)r    r   r   r   r!   r   decay_everyr"   r   r   r   f  s   
	z WarmCoolDecayLRSchedule.__init__c                 C   s   || j k r| j| | j  }n9|| j| j kr9| jt| j| j| j   }|| j }|| j| j  }|||  }n| jt| j|| j    }|jD ]}||d< qJd S )Nr   )r   r   r!   r   ru   rv   r   r	   )r    r&   num_updatesr   r   decreasenr   r   r   r   r(   v  s   



z WarmCoolDecayLRSchedule.__call__c                 C   r)   )r*   )r   r   r   r   r!   N)r   r   r   r   r!   r+   r,   r-   r   r   r   r,     r0   zWarmCoolDecayLRSchedule.saveFc                 C   sB   ~t |}|d | _|d | _|d | _|d | _|d | _dS )r1   r   r   r   r   r!   N)r+   r2   r   r   r   r   r!   rT   r   r   r   r2     s   




zWarmCoolDecayLRSchedule.load)r   r   rV   r7   r   r   r"   r   r   /  s    ;
r   c                       sN   e Zd ZdZ fddZdd Zejdd Zej	dd
dZ
dd Z  ZS )ScheduledLossaT  A convenience class for switching to a different loss function on a
    schedule

    Arguments
    ---------
    schedule : list
        a list of dictionaries with the following keys
            loss_fn: the loss function to use
            steps: the number of steps to apply before switching
                to the next one

    Example
    -------
    >>> loss_fn = ScheduledLoss(
    ...     schedule=[
    ...         {"steps": 3, "loss_fn": nn.MSELoss()},
    ...         {"steps": 2, "loss_fn": nn.L1Loss()},
    ...         {"loss_fn": nn.SmoothL1Loss()}
    ...     ]
    ... )
    >>> x = torch.tensor([1., 2.])
    >>> y = torch.tensor([1.5, 2.5])
    >>> for idx in range(10):
    ...     loss = loss_fn(x, y)
    ...     print(loss.item())
    0.25
    0.25
    0.25
    0.5
    0.5
    0.125
    0.125
    0.125
    0.125
    0.125
    c                    sL   t    t|stdtdd |D rtd|| _d| _|   d S )Nz&At least one schedule item is requiredc                 s   s"    | ]}t |d s|V  qdS )loss_fnN)callableget)r   itemr   r   r   	<genexpr>  s     z)ScheduledLoss.__init__.<locals>.<genexpr>z*Each schedule item needs to have at least r   )r   r   anyro   scheduler   find_next_switch)r    r   r"   r   r   r     s   
zScheduledLoss.__init__c                 O   s2   | j | jkr
|   |  j d7  _ | j|i |S )aE  Computes the loss at the specified step number.

        Arguments
        ---------
        *args : tuple
        **kwargs : dict
            Any arguments passed to this will be passed on to the specified
            loss_fn

        Returns
        -------
        result : torch.Tensor
            the loss value
        r$   )r   next_switchr   current_loss_fn)r    argskwargsr   r   r   forward  s   zScheduledLoss.forwardc                 C   r   )z.Saves the current state on the specified path.r   Nr   r-   r   r   r   r,     r   zScheduledLoss.saveFNc                 C   s    t |}|d | _|   dS )r1   r   N)r+   r2   r   r   r3   r   r   r   r2     s   

zScheduledLoss.loadc                 C   sJ   d}| j D ]}|dtj}||7 }|| jkr"|d | _|| _ dS qdS )zUFinds the threshold at which the next switch will occur
        based on the scheduler   rZ   r   N)r   r   r+   infr   r   r   )r    cumulative_stepsr   
item_stepsr   r   r   r     s   


zScheduledLoss.find_next_switchr6   )r8   r9   r:   r;   r   r   r   r<   r,   r=   r2   r   r>   r   r   r"   r   r     s    %

r   c                       sL   e Zd ZdZ		d fdd	Zdd Zejdd	 Zej	dddZ
  ZS )TriStageLRScheduleaa  Warms up linearly, very slowly decays and cools down linearly again
    at the end of training. This is a three steps scheduler.
    Reference
    https://arxiv.org/pdf/1904.08779.pdf

    Arguments
    ---------
    lr : float
        The max learning rate to reach after warmup.
    warmup_steps : int
        Number of warmup steps (following a linear increase).
    hold_steps : int
        Number of holding steps (lr remains unchanged).
    decay_steps : int
        Number of decay steps.
    total_steps : int
        Total number of steps (used to decay).
    init_lr_scale : float
        The initial learning rate scale during warmup phase.
    final_lr_scale : float
        The final learning rate scale.

    Example
    -------
    >>> from speechbrain.nnet.linear import Linear
    >>> inp_tensor = torch.rand([1,660,3])
    >>> model = Linear(input_size=3, n_neurons=4)
    >>> optim = torch.optim.Adam(model.parameters(), lr=1)
    >>> output = model(inp_tensor)
    >>> scheduler = TriStageLRSchedule(lr=1, warmup_steps=2, hold_steps=2, decay_steps=2, total_steps=6, init_lr_scale=0.01, final_lr_scale=0.05)
    >>> optim.param_groups[0]["lr"]
    1
    >>> scheduler(optim, 1)
    >>> optim.param_groups[0]["lr"]
    0.505
    >>> scheduler(optim, 2)
    >>> optim.param_groups[0]["lr"]
    1
    >>> scheduler(optim, 3)
    >>> optim.param_groups[0]["lr"]
    1
    >>> scheduler(optim, 4)
    >>> optim.param_groups[0]["lr"]
    1.0
    >>> scheduler(optim, 5)
    >>> optim.param_groups[0]["lr"]
    0.223606797749979
    >>> scheduler(optim, 6)
    >>> optim.param_groups[0]["lr"]
    0.05000000000000001
    {Gz?皙?c                    st   t t|   || _|| _|| _|| _|| _|| _|| _	| j| j | _
| j| j
 | j | _t| j	 | j | _d S rB   )r   r   r   peak_lrr   
hold_stepsr   r!   init_lr_scalefinal_lr_scaleinit_lrwarmup_rateru   rw   r   )r    r   r   r   r   r!   r   r   r"   r   r   r   6  s   
zTriStageLRSchedule.__init__c                 C   sp   || j k r| j| j|  }n|| j | j k r| j}n| jt| j || j | j    }|jD ]}||d< q/dS )zLCalculate the learning rate corresponding to the current step (num_updates).r   N)	r   r   r  r   r   ru   rv   r   r	   )r    r&   r   r   r   r   r   r   r(   M  s   


zTriStageLRSchedule.__call__c                 C   s>   | j | j| j| j| j| j| j| j| j| j	d
}t
|| dS )r*   )
r   r   r   r   r!   r   r   r   r  r   N)r   r   r   r   r!   r   r   r   r  r   r+   r,   r-   r   r   r   r,   _  s   zTriStageLRSchedule.saveFNc                 C   sv   ~~t |}|d | _|d | _|d | _|d | _|d | _|d | _|d | _|d | _	|d	 | _
|d
 | _dS )r1   r   r   r   r   r!   r   r   r   r  r   N)r+   r2   r   r   r   r   r!   r   r   r   r  r   r3   r   r   r   r2   p  s   









zTriStageLRSchedule.load)r   r   r6   r7   r   r   r"   r   r      s    ;
r   rB   )r;   ru   r+   r   speechbrain.utilsr   speechbrain.utils.loggerr   r8   r
   r   register_checkpoint_hooksr   r?   rX   rd   rl   r   r   r   r   r   r   r   r   r   Moduler   r   r   r   r   r   <module>   sL    

&Wa1ZN_{eaq +k1r^