o
    oiJ                     @   s  d dl Z d dlmZmZmZmZmZmZ d dlZd dl	Z	d dl	m
Z
mZmZmZ d dlm  mZ d dlZd dlmZmZ d dlmZ d dlmZ d dlmZmZ edd	d
ZG dd de	jjZG dd de	jjZG dd de	jj Z!ded fddZ"G dd dej#Z$G dd de$Z%G dd de$Z&G dd de	jj Z'dd Z(G dd  d ejZ)G d!d" d"ejZ*G d#d$ d$e*Z+G d%d& d&e*Z,G d'd( d(ej#Z-G d)d* d*ej#Z.G d+d, d,ej#Z/dS )-    N)AnyDictOptionalTypeVarUnionoverload)Tensordevicedtypenn)get_tile_indsundo_layout)
QuantState)GlobalOptimManager)*INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPINGOutlierTracerTztorch.nn.Module)boundc                       s   e Zd ZdZ								ddededee dee d	ed
ededee ddf fddZ	dddZ
	 dddZdedefddZ  ZS )StableEmbeddinga  
    Custom embedding layer designed to improve stability during training for NLP tasks by using 32-bit optimizer states. It is designed to reduce gradient variations that can result from quantization. This embedding layer is initialized with Xavier uniform initialization followed by layer normalization.

    Example:

    ```
    # Initialize StableEmbedding layer with vocabulary size 1000, embedding dimension 300
    embedding_layer = StableEmbedding(num_embeddings=1000, embedding_dim=300)

    # Reset embedding parameters
    embedding_layer.reset_parameters()

    # Perform a forward pass with input tensor
    input_tensor = torch.tensor([1, 2, 3])
    output_embedding = embedding_layer(input_tensor)
    ```

    Attributes:
        norm (`torch.nn.LayerNorm`): Layer normalization applied after the embedding.

    Methods:
        reset_parameters(): Reset embedding parameters using Xavier uniform initialization.
        forward(input: Tensor) -> Tensor: Forward pass through the stable embedding layer.
    N       @Fnum_embeddingsembedding_dimpadding_idxmax_norm	norm_typescale_grad_by_freqsparse_weightreturnc                    sJ   t  |||||||||	|

 tjj||	d| _t | dddi dS a  
        Args:
            num_embeddings (`int`):
                The number of unique embeddings (vocabulary size).
            embedding_dim (`int`):
                The dimensionality of the embedding.
            padding_idx (`Optional[int]`):
                Pads the output with zeros at the given index.
            max_norm (`Optional[float]`):
                Renormalizes embeddings to have a maximum L2 norm.
            norm_type (`float`, defaults to `2.0`):
                The p-norm to compute for the `max_norm` option.
            scale_grad_by_freq (`bool`, defaults to `False`):
                Scale gradient by frequency during backpropagation.
            sparse (`bool`, defaults to `False`):
                Computes dense gradients. Set to `True` to compute sparse gradients instead.
            _weight (`Optional[Tensor]`):
                Pretrained embeddings.
        )r	   weight
optim_bits    N)	super__init__torchr   	LayerNormnormr   get_instanceregister_module_override)selfr   r   r   r   r   r   r   r   r	   r
   	__class__ K/home/ubuntu/.local/lib/python3.10/site-packages/bitsandbytes/nn/modules.pyr$   3   s    zStableEmbedding.__init__c                 C      t jj| j |   d S Nr%   r   initxavier_uniform_r    _fill_padding_idx_with_zeror*   r-   r-   r.   reset_parametersb      z StableEmbedding.reset_parametersc                 C   N   | j d ur%t  | j| j  d W d    d S 1 sw   Y  d S d S Nr   r   r%   no_gradr    fill_r5   r-   r-   r.   r4   m   
   

"z+StableEmbedding._fill_padding_idx_with_zeroinputc              	   C   sD   t || j| j| j| j| j| j}|t	
 }| || jjS r0   )F	embeddingr    r   r   r   r   r   tor%   get_default_dtyper'   r
   r*   r>   embr-   r-   r.   forwardr   s   zStableEmbedding.forward)NNr   FFNNNr   N)__name__
__module____qualname____doc__intr   floatboolr   r$   r6   r4   rE   __classcell__r-   r-   r+   r.   r      sB    	
/
r   c                       s   e Zd ZdZ							ddededee dee d	ed
ededee dee	 ddf fddZ
dddZ	 dddZdedefddZ  ZS )	EmbeddingzS
    Embedding class to store and retrieve word embeddings from their indices.
    Nr   Fr   r   r   r   r   r   r   r   r	   r   c
           
         s8   t  j|||||||||	d	 t | dddi dS r   )r#   r$   r   r(   r)   )
r*   r   r   r   r   r   r   r   r   r	   r+   r-   r.   r$      s   zEmbedding.__init__c                 C   r/   r0   r1   r5   r-   r-   r.   r6      r7   zEmbedding.reset_parametersc                 C   r8   r9   r:   r5   r-   r-   r.   r4      r=   z%Embedding._fill_padding_idx_with_zeror>   c              	   C   s&   t || j| j| j| j| j| j}|S r0   )r?   r@   r    r   r   r   r   r   rC   r-   r-   r.   rE      s   
zEmbedding.forward)NNr   FFNNrF   )rG   rH   rI   rJ   rK   r   rL   rM   r   r	   r$   r6   r4   rE   rN   r-   r-   r+   r.   rO      sD    	

,
rO   c                       s|  e Zd Zddddddejddf	deej dee ded	e	d
e
dejded de	dd fddZdd Zdd Zdd Zdd Ze			d.dejdee
ef de	ded dd f
ddZdd  Zd/d!eeeee
f  d"e	fd#d$Ze	%	%	%d0d&ed!eeeef  d'eeee
f  d"e	def
d(d)Zed1d&ed'eee
f d"e	defd*d)Zed1d&ed+ed"e	defd,d)Z fd-d)Z  ZS )2
Params4bitNF@   Tfp4dataquant_state	blocksizecompress_statistics
quant_typequant_storagemodule
Linear4bitbnb_quantizedr   c
                 C   sV   |d u r	t d}t j| ||}
||
_||
_||
_||
_||
_|	|
_	||
_
||
_|
S r9   )r%   emptyr   _make_subclassrU   rV   rW   rT   rX   r[   rS   rY   )clsrS   requires_gradrT   rU   rV   rW   rX   rY   r[   r*   r-   r-   r.   __new__   s   
zParams4bit.__new__c                 C   s"   | j  }| j|d< | j|d< |S )NrS   r_   )__dict__copyrS   r_   r*   stater-   r-   r.   __getstate__   s   


zParams4bit.__getstate__c                 C   s^   |d | _ |d | _|d | _|d | _|d | _|d | _|d | _|d | _|d	 | _d S )
Nr_   rU   rV   rW   rT   rS   rX   r[   rY   )	r_   rU   rV   rW   rT   rS   rX   r[   rY   rc   r-   r-   r.   __setstate__   s   







zParams4bit.__setstate__c                 C   sH   t | t | }|  }|| t|d |_t|d |_|S )NrT   rS   )typer`   re   rf   rb   deepcopyrT   rS   )r*   memonew_instancerd   r-   r-   r.   __deepcopy__   s   
zParams4bit.__deepcopy__c                 C   s(   t | t | }|  }|| |S r0   )rg   r`   re   rf   )r*   rj   rd   r-   r-   r.   __copy__  s   
zParams4bit.__copy__cudaquantized_statsr_   c                 K   st   t j| ||}||_tj||d|_|jj|_|jj	|_
|jj|_d|_|j|_||_|jd ur8|j|j_|S )N)qs_dictr	   T)r%   r   r]   rA   r_   r   	from_dictrT   rU   nestedrV   rW   r[   r
   rX   rY   )r^   rS   rn   r_   r	   rY   kwargsr*   r-   r-   r.   from_prequantized  s   





zParams4bit.from_prequantizedc                 C   sZ   | j  |}tjj|| j| j| j| j	d\}}|| _ || _
| jd ur(|| j_
d| _| S )N)rU   rV   rW   rX   T)rS   
contiguousrA   bnb
functionalquantize_4bitrU   rV   rW   rX   rT   rY   r[   )r*   r	   ww_4bitrT   r-   r-   r.   	_quantize&  s   

zParams4bit._quantizer	   non_blockingc                 C   s    | j |d u rd|dS ||dS )Nrm   )r	   r{   )rA   )r*   r	   r{   r-   r-   r.   rm   6  s    zParams4bit.cuda.r*   r
   c                 C      d S r0   r-   r*   r	   r
   r{   r-   r-   r.   rA   9     zParams4bit.toc                 C   r|   r0   r-   r*   r
   r{   r-   r-   r.   rA   A     tensorc                 C   r|   r0   r-   r*   r   r{   r-   r-   r.   rA   D  r   c              	      s   t jjj|i |\}}}}|d ur|jdkr| js| |S | jd ur*| j| t	t
 j|||d| j| j| j| j| j| jd}|S )Nrm   r	   r
   r{   )r_   rT   rU   rV   rW   rX   )r%   _C_nn	_parse_torg   r[   rz   rT   rA   rP   r#   r_   rU   rV   rW   rX   r*   argsrr   r	   r
   r{   convert_to_format	new_paramr+   r-   r.   rA   G  s   


)Frm   NNF....)rG   rH   rI   r%   uint8r   r   r   rK   rM   strr
   r`   re   rf   rk   rl   classmethodr   r   rs   rz   r   r	   rm   r   r   rA   rN   r-   r-   r+   r.   rP      s    	


"&rP   rY   )Embedding4bitrZ   c                 C   sr   t | jdd d urd S t | dd d u rtd | jjd dks"J t| jts2t| j| jdd| _| j| j_d S )NrT   zhFP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.   T)rX   r[   )	getattrr    warningswarnshape
isinstancerP   rX   rT   )rY   r-   r-   r.   'fix_4bit_weight_quant_state_from_module]  s   r   c                       sT   e Zd ZdZddddejdf fdd	Zdd Z fd	d
Zdej	fddZ
  ZS )rZ   a  
    This class is the base module for the 4-bit quantization algorithm presented in [QLoRA](https://arxiv.org/abs/2305.14314).
    QLoRA 4-bit linear layers uses blockwise k-bit quantization under the hood, with the possibility of selecting various
    compute datatypes such as FP4 and NF4.

    In order to quantize a linear layer one should first load the original fp16 / bf16 weights into
    the Linear4bit module, then call `quantized_module.to("cuda")` to quantize the fp16 / bf16 weights.

    Example:

    ```python
    import torch
    import torch.nn as nn

    import bitsandbytes as bnb
    from bnb.nn import Linear4bit

    fp16_model = nn.Sequential(
        nn.Linear(64, 64),
        nn.Linear(64, 64)
    )

    quantized_model = nn.Sequential(
        Linear4bit(64, 64),
        Linear4bit(64, 64)
    )

    quantized_model.load_state_dict(fp16_model.state_dict())
    quantized_model = quantized_model.to(0) # Quantization happens here
    ```
    TNrR   c	           	         sH   t  |||| t| jjd|||| d| _|| _d| _d| _|| _dS )aw  
        Initialize Linear4bit class.

        Args:
            input_features (`str`):
                Number of input features of the linear layer.
            output_features (`str`):
                Number of output features of the linear layer.
            bias (`bool`, defaults to `True`):
                Whether the linear class uses the bias term as well.
        Fr_   rV   rW   rX   rY   N)	r#   r$   rP   r    rS   compute_dtypecompute_type_is_setrT   rX   )	r*   input_featuresoutput_featuresbiasr   rV   rW   rX   r	   r+   r-   r.   r$     s   	
zLinear4bit.__init__c                 C   s   |j tjtjfv r|j | _d S |j tjkrM| jtjkr0| |jd kr0t	d tj
ddd | jtjkrO| |jd krQt	d tj
ddd d S d S d S d S )NzInput type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference.ignorez.*inference.)messagezInput type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference or training speed.z.*inference or training)r
   r%   float32bfloat16r   float16numelr   r   r   filterwarnings)r*   xr-   r-   r.   set_compute_type  s   	zLinear4bit.set_compute_typec                    sd   t  ||| t| jdddur.| jjjdd D ]\}}|r#|n| ||d | < qdS dS )zc
        save weight and bias,
        then fill state_dict with components of quant_state
        rT   NT)packedzweight.)r#   _save_to_state_dictr   r    rT   as_dictitemsdetach)r*   destinationprefix	keep_varskvr+   r-   r.   r     s   zLinear4bit._save_to_state_dictr   c                 C   s   t |  | jd ur| jj|jkr| jj|j| j_| js%| | d| _|j}| jd ur3|| j}| jd u r:d n| j| j}tj	|| j
 || j
jd|S )NT)r   rT   )r   r   r
   rS   rA   r   r   r   ru   matmul_4bitr    trT   )r*   r   	inp_dtyper   r-   r-   r.   rE     s   

"zLinear4bit.forward)rG   rH   rI   rJ   r%   r   r$   r   r   r   rE   rN   r-   r-   r+   r.   rZ   n  s    $%rZ   c                       .   e Zd ZdZdddejdf fdd	Z  ZS )	LinearFP4z'
    Implements the FP4 data type.
    TNc              
         t  |||||d|| dS )Q  
        Args:
            input_features (`str`):
                Number of input features of the linear layer.
            output_features (`str`):
                Number of output features of the linear layer.
            bias (`bool`, defaults to `True`):
                Whether the linear class uses the bias term as well.
        rR   Nr#   r$   r*   r   r   r   r   rV   rX   r	   r+   r-   r.   r$        zLinearFP4.__init__rG   rH   rI   rJ   r%   r   r$   rN   r-   r-   r+   r.   r     s    r   c                       r   )	LinearNF4a"  Implements the NF4 data type.

    Constructs a quantization data type where each bin has equal area under a standard normal distribution N(0, 1) that
    is normalized into the range [-1, 1].

    For more information read the paper: QLoRA: Efficient Finetuning of Quantized LLMs (https://arxiv.org/abs/2305.14314)

    Implementation of the NF4 data type in bitsandbytes can be found in the `create_normal_map` function in
    the `functional.py` file: https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L236.
    TNc              
      r   )r   nf4Nr   r   r+   r-   r.   r$     r   zLinearNF4.__init__r   r-   r-   r+   r.   r     s    r   c                       s   e Zd Z					ddeej deej deej fddZ fd	d
Zdd Ze				dde
deeeef  deeeef  dede
f
ddZe	dde
deeef dede
fddZe	dde
dedede
fddZ fddZ  ZS )
Int8ParamsNTFrS   CBSCBc                 C   s8   |d u r	t d}t j| ||}||_||_||_|S r9   )r%   r\   r   r]   r   r   has_fp16_weights)r^   rS   r_   r   r   r   objr-   r-   r.   r`   7  s   
zInt8Params.__new__c                    sN   | j r	t |S | j  |}tj|\}}}|| _|| _	|| _
| S r0   )r   r#   rm   rS   rt   halfru   rv   int8_vectorwise_quantr   r   )r*   r	   Br   r   _r+   r-   r.   rm   G  s   zInt8Params.cudac              
   C   sD   t | jt | t| j|| j| jt| j|t| j|d}|S )N)rS   r_   r   r   r   )	rg   r`   rb   rh   rS   r_   r   r   r   )r*   ri   rj   r-   r-   r.   rk   T  s   zInt8Params.__deepcopy__.r*   r	   r
   r{   r   c                 C   r|   r0   r-   r}   r-   r-   r.   rA   `  r~   zInt8Params.toc                 C   r|   r0   r-   r   r-   r-   r.   rA   h  r   r   c                 C   r|   r0   r-   r   r-   r-   r.   rA   k  r   c                    sz   t jjj|i |\}}}}|d ur#|jdkr#| jjjdkr#| |S tt	 j
|||d| j| jd}| j|_| j|_|S )Nrm   cpur   )r_   r   )r%   r   r   r   rg   rS   r	   rm   r   r#   rA   r_   r   r   r   r   r+   r-   r.   rA   n  s    
)NTFNNr   r   )rG   rH   rI   r   r%   r   r`   rm   rk   r   r   r   rK   r	   r
   r   rM   rA   rN   r-   r-   r+   r.   r   6  sF    
&r   c           
      C   s   |  | d}|d u rd S | | dd}t|tjr!| }t|tr1|tvr1td| t|tr>|tv r>t| }|dkrTt	||j
}	t||	| | d< d S d S )Nr    weight_formatrowz'Expected supported weight format - got )getpopr   r%   r   itemrK   r   
ValueErrorr   r	   r   )

state_dictr   local_metadatastrictmissing_keysunexpected_keys
error_msgsr    r   tile_indicesr-   r-   r.   maybe_rearrange_weight  s   r   c                       s<   e Zd ZdZd fdd	Zdd Zdedefd	d
Z  ZS )Embedding8bita  
    This class implements [LLM.int8()](https://arxiv.org/abs/2208.07339) algorithm for embedding layer

    Quantization API is similar to Linear8bitLt:
    ```python
    import torch
    import torch.nn as nn

    from bitsandbytes.nn import Embedding8bit

    fp16_module = nn.Embedding(128, 64)
    int8_module = Embedding8bit(128, 64)

    int8_module.load_state_dict(fp16_module.state_dict())

    int8_module = int8_module.to(0) # Quantization happens here
    ```
    Nc                    s8   t  j||||d | jjj| _t| jjddd| _d S )Nr	   r
   Fr   r_   )r#   r$   r    rS   r
   r   )r*   r   r   r	   r
   r+   r-   r.   r$     s   zEmbedding8bit.__init__c                 C      t d)Nz.Saving Embedding8bit module is not implementedNotImplementedErrorr*   r   r   r   r-   r-   r.   r        z!Embedding8bit._save_to_state_dictr>   r   c                 C   s   t | jds
td| jj}| jj}|j| j| jfksJ |j| jfks&J t	||}t	||
| jd}||d  }|| jS )Nr   zKEmbedding layer is not quantized. Please call .cuda() or .to(device) first.r   g     _@)hasattrr    RuntimeErrorrS   r   r   r   r   r?   r@   viewrA   r
   )r*   r>   rows	row_statscompressed_outputcompressed_output_statsoutputr-   r-   r.   rE     s   zEmbedding8bit.forward)NN)	rG   rH   rI   rJ   r$   r   r   rE   rN   r-   r-   r+   r.   r     s
    r   c                       sT   e Zd ZdZddejdf fdd	ZdefddZd	d
 Z	dedefddZ
  ZS )r   a3  
    This is the base class similar to Linear4bit. It implements the 4-bit quantization algorithm presented in
    [QLoRA](https://arxiv.org/abs/2305.14314) for embeddings.

    Quantization API is similar to Linear4bit:
    ```python
    import torch
    import torch.nn as nn

    from bitsandbytes.nn import Embedding4bit

    fp16_module = nn.Embedding(128, 64)
    quantized_module = Embedding4bit(128, 64)

    quantized_module.load_state_dict(fp16_module.state_dict())

    quantized_module = quantized_module.to(0) # Quantization happens here
    ```
    NrR   c                    sn   t  j||||d | jjj| _t| jjdd ||| d| _| jj}|| dkr5td| d| d d S d S )Nr   Fr   r   zEmbedding size z  is not divisible by block size z#. This will lead to slow inference.)	r#   r$   r    rS   r
   rP   rU   r   r   )r*   r   r   r
   rW   rX   r	   rU   r+   r-   r.   r$     s    		zEmbedding4bit.__init__r>   c           	      C   sR  | j | jjj dksJ | jjtj| j| j  d d}tj	j
j|| j| j d |ddd}|j| | j  d dfksCJ | j | jj }| jjj}|j| j| fksZJ tj	j
j|| j||dd}|j| | fkswJ t| jj}||_tg |j| j R |_tj
||}|jg |j| j R ksJ || jS )Nr      r   r    r>   r   )r   r    rT   rU   rS   r   r%   r   r   r   rv   r@   r   r   absmaxrb   rh   Sizeru   dequantize_4bitrA   r
   )	r*   r>   w_4bit_uint8output_4bitblocks_per_embr   output_absmaxoutput_quant_stater   r-   r-   r.    _forward_with_partial_dequantize  s6   $ 
z.Embedding4bit._forward_with_partial_dequantizec                 C   r   )Nz.Saving Embedding4bit module is not implementedr   r   r-   r-   r.   r     r   z!Embedding4bit._save_to_state_dictr   c                 C   sV   t |  | j| jjj dkr| |S tj| jj	| jj}t
jjj||d| jS )Nr   r   )r   r   r    rT   rU   r   ru   rv   r   rS   r%   r   r@   rA   r
   )r*   r>   dequantized_weightr-   r-   r.   rE     s   
zEmbedding4bit.forward)rG   rH   rI   rJ   r%   r   r$   r   r   r   rE   rN   r-   r-   r+   r.   r     s    !r   c                       &   e Zd Zdejdf fdd	Z  ZS )EmbeddingFP4Nc                       t  j|||d||d d S )NrR   r
   rW   rX   r	   r   r*   r   r   r
   rX   r	   r+   r-   r.   r$   )     
zEmbeddingFP4.__init__rG   rH   rI   r%   r   r$   rN   r-   r-   r+   r.   r   (  
    r   c                       r   )EmbeddingNF4Nc                    r   )Nr   r   r   r   r+   r-   r.   r$   <  r   zEmbeddingNF4.__init__r   r-   r-   r+   r.   r   ;  r   r   c                       sf   e Zd ZdZ					ddedef fddZ fd	d
Z fddZdd Zde	j
fddZ  ZS )Linear8bitLtaZ  
    This class is the base module for the [LLM.int8()](https://arxiv.org/abs/2208.07339) algorithm.
    To read more about it, have a look at the paper.

    In order to quantize a linear layer one should first load the original fp16 / bf16 weights into
    the Linear8bitLt module, then call `int8_module.to("cuda")` to quantize the fp16 weights.

    Example:

    ```python
    import torch
    import torch.nn as nn

    import bitsandbytes as bnb
    from bnb.nn import Linear8bitLt

    fp16_model = nn.Sequential(
        nn.Linear(64, 64),
        nn.Linear(64, 64)
    )

    int8_model = nn.Sequential(
        Linear8bitLt(64, 64, has_fp16_weights=False),
        Linear8bitLt(64, 64, has_fp16_weights=False)
    )

    int8_model.load_state_dict(fp16_model.state_dict())
    int8_model = int8_model.to(0) # Quantization happens here
    ```
    T        Nr   r   c                    sh   t  |||| t | _|| _|| j_|| j_|dkr#|s#d| j_t	| j
j||d| _
| t dS )ay  
        Initialize Linear8bitLt class.

        Args:
            input_features (`int`):
                Number of input features of the linear layer.
            output_features (`int`):
                Number of output features of the linear layer.
            bias (`bool`, defaults to `True`):
                Whether the linear class uses the bias term as well.
        r   Tr   N)r#   r$   ru   MatmulLtStaterd   index	thresholdr   use_poolr   r    rS   "_register_load_state_dict_pre_hookr   )r*   r   r   r   r   r  r   r	   r+   r-   r.   r$   n  s   
zLinear8bitLt.__init__c           	         s   t  ||| d}t| j|}t| j|}||  }|d }| jjsW|d ur=|r+|n| ||< tjdtj	d||< d S |d urY|rE|n| ||< tjdtj	d||< d S d S d S )Nr   r   r   )r
   )
r#   r   r   r    rd   r   r   r%   r   r   )	r*   r   r   r   scb_nameparam_from_weightparam_from_statekey_nameformat_namer+   r-   r.   r     s   
z Linear8bitLt._save_to_state_dictc              	      s   t  ||||||| t|}|D ]4}	|	t|d  }
|
dkrF| jjd u r*td||	 }| jj| | jjd urA| jj| j_|	|	 qd S )Nr   zLoading a quantized checkpoint into non-quantized Linear8bitLt is not supported. Please call module.cuda() before module.load_state_dict())
r#   _load_from_state_dictlistlenr    r   r   copy_rd   remove)r*   r   r   r   r   r   r   r   unexpected_copykey
input_nameinput_paramr+   r-   r.   r	    s0   
	
z"Linear8bitLt._load_from_state_dictc                 C   ,   | j j| j_| j j| j_d | j _d | j _d S r0   r    r   rd   r   r5   r-   r-   r.   init_8bit_state     zLinear8bitLt.init_8bit_stater   c                 C   s   | j | j_| jjd ur|   | jd ur%| jj|jkr%| jj	|j| j_t
j|| j| j| jd}| jjsA| jjd urA| jj| j_|S N)r   rd   )trainingrd   is_trainingr    r   r  r   r
   rS   rA   ru   matmulr   r*   r   outr-   r-   r.   rE     s   
zLinear8bitLt.forward)TTr   NN)rG   rH   rI   rJ   rK   r$   r   r	  r  r%   r   rE   rN   r-   r-   r+   r.   r   N  s     #"'r   c                       s6   e Zd Zd fdd	Zdd Zdd Zd	d
 Z  ZS )OutlierAwareLinearTNc                    s"   t  |||| d | _d| _d S r   )r#   r$   outlier_dimis_quantized)r*   r   r   r   r	   r+   r-   r.   r$     s   
zOutlierAwareLinear.__init__c                 C   r   )NzJPlease override the `forward_with_outliers(self, x, outlier_idx)` functionr   )r*   r   outlier_idxr-   r-   r.   forward_with_outliers  r   z(OutlierAwareLinear.forward_with_outliersc                 C   r   )NzEPlease override the `quantize_weights(self, w, outlier_idx)` functionr   )r*   rx   r  r-   r-   r.   quantize_weight  r   z"OutlierAwareLinear.quantize_weightc                 C   sf   | j d u rt }| std || j}|| _ | js1| | j| j }| jj	
| d| _d S d S )NzTPlease use OutlierTracer.initialize(model) before using the OutlierAwareLinear layerT)r  r   r(   is_initializedprintget_outliersr    r  r!  rS   r  )r*   r   tracerr  rx   r-   r-   r.   rE     s   

zOutlierAwareLinear.forward)TN)rG   rH   rI   r$   r   r!  rE   rN   r-   r-   r+   r.   r    s
    r  c                       s:   e Zd Z						d fdd	Zdd Zd	d
 Z  ZS )SwitchBackLinearBnbTFr   Nc	           	         sf   t  |||| t | _|| _|| j_|| j_|| j_|dkr'|s'd| j_	t
| jj||d| _d S )Nr   Tr   )r#   r$   ru   r   rd   r   r  r   memory_efficient_backwardr  r   r    rS   )	r*   r   r   r   r   r'  r  r   r	   r+   r-   r.   r$     s   
zSwitchBackLinearBnb.__init__c                 C   r  r0   r  r5   r-   r-   r.   r    r  z#SwitchBackLinearBnb.init_8bit_statec                 C   sF   | j | j_| jjd ur|   tj| | j d | jd| j	 }d S r  )
r  rd   r  r    r   r  ru   matmul_mixedr   r   r  r-   r-   r.   rE     s   
(zSwitchBackLinearBnb.forward)TTFr   NN)rG   rH   rI   r$   r  rE   rN   r-   r-   r+   r.   r&    s    r&  )0rb   typingr   r   r   r   r   r   r   r%   r   r	   r
   r   torch.nn.functionalrv   r?   bitsandbytesru    bitsandbytes.autograd._functionsr   r   bitsandbytes.functionalr   bitsandbytes.optimr   bitsandbytes.utilsr   r   r   rO   r   	ParameterrP   r   LinearrZ   r   r   r   r   r   r   r   r   r   r  r&  r-   r-   r-   r.   <module>   s<    jO y$+I/d 