o
    Gi_                     @   s   d dl Z d dlmZ d dlZd dlmZ ddlmZmZ ddl	m
Z
 ddlmZ e
eZG dd	 d	eZG d
d deeZG dd dejZG dd dejZG dd dejZG dd dejZG dd dejZG dd dejZG dd dejZdS )    N)Callable   )ConfigMixinregister_to_config)logging   )
ModelMixinc                       s   e Zd ZdZded f fddZddejdee dB d	eej fd
dZ					dde
ejB dededede
dB f
ddZede
ejB dB fddZ  ZS )MultiAdaptera  
    MultiAdapter is a wrapper model that contains multiple adapter models and merges their outputs according to
    user-assigned weighting.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for common methods such as downloading
    or saving.

    Args:
        adapters (`list[T2IAdapter]`, *optional*, defaults to None):
            A list of `T2IAdapter` model instances.
    adapters
T2IAdapterc                    s   t t|   t|| _t|| _t|dkrtdt|dkr&td|d j	}|d j
}tdt|D ],}|| j	|ksG|| j
|krctd| d| d| d|| j	 d| d	|| j
 q7|| _	|| _
d S )
Nr   zExpecting at least one adapterr   zQFor a single adapter, please use the `T2IAdapter` class instead of `MultiAdapter`zjExpecting all adapters to have the same downscaling behavior, but got:
adapters[0].total_downscale_factor=z
adapters[0].downscale_factor=z

adapter[`z`].total_downscale_factor=z`].downscale_factor=)superr	   __init__lennum_adapternn
ModuleListr
   
ValueErrortotal_downscale_factordownscale_factorrange)selfr
   $first_adapter_total_downscale_factorfirst_adapter_downscale_factoridx	__class__ L/home/ubuntu/.local/lib/python3.10/site-packages/diffusers/models/adapter.pyr   )   s<   




zMultiAdapter.__init__Nxsadapter_weightsreturnc           	      C   s   |du rt d| j g| j }nt |}d}t||| jD ]6\}}}||}|du r@|}tt|D ]
}|||  ||< q4qtt|D ]}||  |||  7  < qFq|S )az  
        Args:
            xs (`torch.Tensor`):
                A tensor of shape (batch, channel, height, width) representing input images for multiple adapter
                models, concatenated along dimension 1(channel dimension). The `channel` dimension should be equal to
                `num_adapter` * number of channel per image.

            adapter_weights (`list[float]`, *optional*, defaults to None):
                A list of floats representing the weights which will be multiplied by each adapter's output before
                summing them together. If `None`, equal weights will be used for all adapters.
        Nr   )torchtensorr   zipr
   r   r   )	r   r   r   accume_statexwadapterfeaturesir   r   r   forwardL   s   
zMultiAdapter.forwardTsave_directoryis_main_processsave_functionsafe_serializationvariantc           	      C   sB   d}|}| j D ]}|j|||||d |d7 }|d|  }qdS )a6  
        Save a model and its configuration file to a specified directory, allowing it to be re-loaded with the
        `[`~models.adapter.MultiAdapter.from_pretrained`]` class method.

        Args:
            save_directory (`str` or `os.PathLike`):
                The directory where the model will be saved. If the directory does not exist, it will be created.
            is_main_process (`bool`, optional, defaults=True):
                Indicates whether current process is the main process or not. Useful for distributed training (e.g.,
                TPUs) and need to call this function on all processes. In this case, set `is_main_process=True` only
                for the main process to avoid race conditions.
            save_function (`Callable`):
                Function used to save the state dictionary. Useful for distributed training (e.g., TPUs) to replace
                `torch.save` with another method. Can also be configured using`DIFFUSERS_SAVE_MODE` environment
                variable.
            safe_serialization (`bool`, optional, defaults=True):
                If `True`, save the model using `safetensors`. If `False`, save the model with `pickle`.
            variant (`str`, *optional*):
                If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
        r   )r,   r-   r.   r/   r   _N)r
   save_pretrained)	r   r+   r,   r-   r.   r/   r   model_path_to_saver'   r   r   r   r1   i   s   
zMultiAdapter.save_pretrainedpretrained_model_pathc                 K   s   d}g }|}t j|r+tj|fi |}|| |d7 }|d|  }t j|stt| d| d t|dkrOt	dt j
| d|d  d| |S )	a  
        Instantiate a pretrained `MultiAdapter` model from multiple pre-trained adapter models.

        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
        the model, set it back to training mode using `model.train()`.

        Warnings:
            *Weights from XXX not initialized from pretrained model* means that the weights of XXX are not pretrained
            with the rest of the model. It is up to you to train those weights with a downstream fine-tuning. *Weights
            from XXX not used in YYY* means that the layer XXX is not used by YYY, so those weights are discarded.

        Args:
            pretrained_model_path (`os.PathLike`):
                A path to a *directory* containing model weights saved using
                [`~diffusers.models.adapter.MultiAdapter.save_pretrained`], e.g., `./my_model_directory/adapter`.
            torch_dtype (`torch.dtype`, *optional*):
                Override the default `torch.dtype` and load the model under this dtype.
            output_loading_info(`bool`, *optional*, defaults to `False`):
                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
            device_map (`str` or `dict[str, int | str | torch.device]`, *optional*):
                A map that specifies where each submodule should go. It doesn't need to be refined to each
                parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
                same device.

                To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
                more information about each option see [designing a device
                map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
            max_memory (`Dict`, *optional*):
                A dictionary mapping device identifiers to their maximum memory. Default to the maximum memory
                available for each GPU and the available CPU RAM if unset.
            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
                Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
                also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
                model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
                setting this argument to `True` will raise an error.
            variant (`str`, *optional*):
                If specified, load weights from a `variant` file (*e.g.* pytorch_model.<variant>.bin). `variant` will
                be ignored when using `from_flax`.
            use_safetensors (`bool`, *optional*, defaults to `None`):
                If `None`, the `safetensors` weights will be downloaded if available **and** if`safetensors` library is
                installed. If `True`, the model will be forcibly loaded from`safetensors` weights. If `False`,
                `safetensors` is not used.
        r   r   r0   z adapters loaded from .zNo T2IAdapters found under z. Expected at least _0)ospathisdirr   from_pretrainedappendloggerinfor   r   dirname)clsr3   kwargsr   r
   model_path_to_loadr'   r   r   r   r9      s   -
zMultiAdapter.from_pretrainedN)TNTN)__name__
__module____qualname____doc__listr   r!   Tensorfloatr*   strr6   PathLikeboolr   r1   classmethodr9   __classcell__r   r   r   r   r	      s*    (# 
*"r	   c                       s   e Zd ZdZedg ddddfdedee d	ed
edef
 fddZde	j
dee	j
 fddZedd Zedd Z  ZS )r   a\  
    A simple ResNet-like model that accepts images containing control signals such as keyposes and depth. The model
    generates multiple feature maps that are used as additional conditioning in [`UNet2DConditionModel`]. The model's
    architecture follows the original implementation of
    [Adapter](https://github.com/TencentARC/T2I-Adapter/blob/686de4681515662c0ac2ffa07bf5dda83af1038a/ldm/modules/encoders/adapter.py#L97)
     and
     [AdapterLight](https://github.com/TencentARC/T2I-Adapter/blob/686de4681515662c0ac2ffa07bf5dda83af1038a/ldm/modules/encoders/adapter.py#L235).

    This model inherits from [`ModelMixin`]. Check the superclass documentation for the common methods, such as
    downloading or saving.

    Args:
        in_channels (`int`, *optional*, defaults to `3`):
            The number of channels in the adapter's input (*control image*). Set it to 1 if you're using a gray scale
            image.
        channels (`list[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
            The number of channels in each downsample block's output hidden state. The `len(block_out_channels)`
            determines the number of downsample blocks in the adapter.
        num_res_blocks (`int`, *optional*, defaults to `2`):
            Number of ResNet blocks in each downsample block.
        downscale_factor (`int`, *optional*, defaults to `8`):
            A factor that determines the total downscale factor of the Adapter.
        adapter_type (`str`, *optional*, defaults to `full_adapter`):
            Adapter type (`full_adapter` or `full_adapter_xl` or `light_adapter`) to use.
       @       rR   r      full_adapterin_channelschannelsnum_res_blocksr   adapter_typec                    sn   t    |dkrt||||| _d S |dkr!t||||| _d S |dkr/t||||| _d S td| d)NrT   full_adapter_xllight_adapterzUnsupported adapter_type: 'zH'. Choose either 'full_adapter' or 'full_adapter_xl' or 'light_adapter'.)r   r   FullAdapterr'   FullAdapterXLLightAdapterr   )r   rU   rV   rW   r   rX   r   r   r   r      s   
	
zT2IAdapter.__init__r%   r    c                 C   s
   |  |S )a  
        This function processes the input tensor `x` through the adapter model and returns a list of feature tensors,
        each representing information extracted at a different scale from the input. The length of the list is
        determined by the number of downsample blocks in the Adapter, as specified by the `channels` and
        `num_res_blocks` parameters during initialization.
        )r'   r   r%   r   r   r   r*   
  s   
zT2IAdapter.forwardc                 C   s   | j jS rA   )r'   r   r   r   r   r   r     s   z!T2IAdapter.total_downscale_factorc                 C   s
   | j jjS )zThe downscale factor applied in the T2I-Adapter's initial pixel unshuffle operation. If an input image's dimensions are
        not evenly divisible by the downscale_factor then an exception will be raised.
        )r'   	unshuffler   r_   r   r   r   r     s   
zT2IAdapter.downscale_factor)rB   rC   rD   rE   r   intrF   rI   r   r!   rG   r*   propertyr   r   rM   r   r   r   r   r      s0    	
r   c                	       ^   e Zd ZdZdg dddfdedee ded	ef fd
dZdejdeej fddZ	  Z
S )r[   2
    See [`T2IAdapter`] for more information.
    rN   rO   r   rS   rU   rV   rW   r   c                    s   t    ||d  }t|| _tj| d ddd| _tt d  d g fddt	dt
 D | _|dt
 d   | _d S )Nr   r   rN   r   kernel_sizepaddingc                    s(   g | ]}t  |d    | ddqS r   Tdown)AdapterBlock.0r)   rV   rW   r   r   
<listcomp>8      z(FullAdapter.__init__.<locals>.<listcomp>)r   r   r   PixelUnshuffler`   Conv2dconv_inr   rk   r   r   bodyr   r   rU   rV   rW   r   r   rn   r   r   '  s   

zFullAdapter.__init__r%   r    c                 C   :   |  |}| |}g }| jD ]}||}|| q|S )a  
        This method processes the input tensor `x` through the FullAdapter model and performs operations including
        pixel unshuffling, convolution, and a stack of AdapterBlocks. It returns a list of feature tensors, each
        capturing information at a different stage of processing within the FullAdapter model. The number of feature
        tensors in the list is determined by the number of downsample blocks specified during initialization.
        r`   rs   rt   r:   r   r%   r(   blockr   r   r   r*   A  s   


zFullAdapter.forwardrB   rC   rD   rE   ra   rF   r   r!   rG   r*   rM   r   r   r   r   r[   "       "r[   c                	       rc   )r\   rd   rN   rO   r      rU   rV   rW   r   c              	      s   t    ||d  }t|| _tj||d ddd| _g | _tt	|D ]<}|dkr=| j
t||d  || | q&|dkrT| j
t||d  || |dd q&| j
t|| || | q&t| j| _|d | _d S )Nr   r   rN   r   re   Tri   )r   r   r   rq   r`   rr   rs   rt   r   r   r:   rk   r   r   )r   rU   rV   rW   r   r)   r   r   r   r   Y  s   
"&zFullAdapterXL.__init__r%   r    c                 C   rv   )z
        This method takes the tensor x as input and processes it through FullAdapterXL model. It consists of operations
        including unshuffling pixels, applying convolution layer and appending each block into list of feature tensors.
        rw   rx   r   r   r   r*   u  s   


zFullAdapterXL.forwardrz   r   r   r   r   r\   T  s     "r\   c                	       J   e Zd ZdZddedededef fddZd	ejd
ejfddZ	  Z
S )rk   a#  
    An AdapterBlock is a helper model that contains multiple ResNet-like blocks. It is used in the `FullAdapter` and
    `FullAdapterXL` models.

    Args:
        in_channels (`int`):
            Number of channels of AdapterBlock's input.
        out_channels (`int`):
            Number of channels of AdapterBlock's output.
        num_res_blocks (`int`):
            Number of ResNet blocks in the AdapterBlock.
        down (`bool`, *optional*, defaults to `False`):
            If `True`, perform downsampling on AdapterBlock's input.
    FrU   out_channelsrW   rj   c                    sh   t    d | _|rtjdddd| _d | _| kr#tj| dd| _tj fddt|D  | _	d S )Nr   Trf   stride	ceil_moder   rf   c                       g | ]}t  qS r   )AdapterResnetBlockrm   r0   r~   r   r   ro         z)AdapterBlock.__init__.<locals>.<listcomp>)
r   r   
downsampler   	AvgPool2din_convrr   
Sequentialr   resnetsr   rU   r~   rW   rj   r   r   r   r     s   

zAdapterBlock.__init__r%   r    c                 C   s6   | j dur
|  |}| jdur| |}| |}|S )a  
        This method takes tensor x as input and performs operations downsampling and convolutional layers if the
        self.downsample and self.in_conv properties of AdapterBlock model are specified. Then it applies a series of
        residual blocks to the input tensor.
        N)r   r   r   r^   r   r   r   r*     s   




zAdapterBlock.forwardFrB   rC   rD   rE   ra   rK   r   r!   rG   r*   rM   r   r   r   r   rk     s     rk   c                       <   e Zd ZdZdef fddZdejdejfddZ  Z	S )	r   z
    An `AdapterResnetBlock` is a helper model that implements a ResNet-like block.

    Args:
        channels (`int`):
            Number of channels of AdapterResnetBlock's input and output.
    rV   c                    s>   t    tj||ddd| _t | _tj||dd| _d S )NrN   r   re   r   r   r   r   rr   block1ReLUactblock2r   rV   r   r   r   r     s   

zAdapterResnetBlock.__init__r%   r    c                 C   "   |  | |}| |}|| S )z
        This method takes input tensor x and applies a convolutional layer, ReLU activation, and another convolutional
        layer on the input tensor. It returns addition with the input tensor.
        r   r   r   r   r%   hr   r   r   r*        
zAdapterResnetBlock.forward
rB   rC   rD   rE   ra   r   r!   rG   r*   rM   r   r   r   r   r     s    r   c                	       rc   )r]   rd   rN   )rP   rQ   rR      rS   rU   rV   rW   r   c              	      s   t    ||d  }t|| _tt| d g fddtt d D t d  d dd| _	|dt   | _
d S )	Nr   r   c                    s(   g | ]}t  |  |d   ddqS rh   )LightAdapterBlockrl   rn   r   r   ro     rp   z)LightAdapter.__init__.<locals>.<listcomp>r   Tri   )r   r   r   rq   r`   r   r   r   r   rt   r   ru   r   rn   r   r     s   
zLightAdapter.__init__r%   r    c                 C   s0   |  |}g }| jD ]}||}|| q
|S )z
        This method takes the input tensor x and performs downscaling and appends it in list of feature tensors. Each
        feature tensor corresponds to a different level of processing within the LightAdapter.
        )r`   rt   r:   rx   r   r   r   r*     s   

zLightAdapter.forwardrz   r   r   r   r   r]     r{   r]   c                	       r}   )r   a<  
    A `LightAdapterBlock` is a helper model that contains multiple `LightAdapterResnetBlocks`. It is used in the
    `LightAdapter` model.

    Args:
        in_channels (`int`):
            Number of channels of LightAdapterBlock's input.
        out_channels (`int`):
            Number of channels of LightAdapterBlock's output.
        num_res_blocks (`int`):
            Number of LightAdapterResnetBlocks in the LightAdapterBlock.
        down (`bool`, *optional*, defaults to `False`):
            If `True`, perform downsampling on LightAdapterBlock's input.
    FrU   r~   rW   rj   c                    st   t    |d  d | _|rtjdddd| _tj| dd| _tj fddt|D  | _	tj |dd| _
d S )	Nr   r   Tr   r   r   c                    r   r   )LightAdapterResnetBlockr   mid_channelsr   r   ro     r   z.LightAdapterBlock.__init__.<locals>.<listcomp>)r   r   r   r   r   rr   r   r   r   r   out_convr   r   r   r   r     s   
zLightAdapterBlock.__init__r%   r    c                 C   s6   | j dur
|  |}| |}| |}| |}|S )z
        This method takes tensor x as input and performs downsampling if required. Then it applies in convolution
        layer, a sequence of residual blocks, and out convolutional layer.
        N)r   r   r   r   r^   r   r   r   r*     s   




zLightAdapterBlock.forwardr   r   r   r   r   r   r     s     r   c                       r   )	r   a  
    A `LightAdapterResnetBlock` is a helper model that implements a ResNet-like block with a slightly different
    architecture than `AdapterResnetBlock`.

    Args:
        channels (`int`):
            Number of channels of LightAdapterResnetBlock's input and output.
    rV   c                    s@   t    tj||ddd| _t | _tj||ddd| _d S )NrN   r   re   r   r   r   r   r   r   8  s   

z LightAdapterResnetBlock.__init__r%   r    c                 C   r   )z
        This function takes input tensor x and processes it through one convolutional layer, ReLU activation, and
        another convolutional layer and adds it to input tensor.
        r   r   r   r   r   r*   >  r   zLightAdapterResnetBlock.forwardr   r   r   r   r   r   .  s    	r   )r6   typingr   r!   torch.nnr   configuration_utilsr   r   utilsr   modeling_utilsr   
get_loggerrB   r;   r	   r   Moduler[   r\   rk   r   r]   r   r   r   r   r   r   <module>   s$   
 =J220/+