o
    Gi                  
   @   sL  d dl Z d dlZd dlZd dlmZ d dlmZ d dlZd dl	m
Z
mZ ddlmZ ddlmZmZ ddlmZ ddlmZmZmZmZmZmZmZ dd	lmZ d
dlmZmZ e rgd dl m!  m"Z# dZ$ndZ$e%e&Z'e rwd dl(m)Z) e r~d dl*Z*dZ+i dddgdddgdddgdddgdddgdddgddd gd!d"d#gd$d"d%gd&d"d'gd(d)d'gd*d)d+gd,d-d.gd/d-d0gd1d2d0gd3d2d4gd5d6d4gi d7d6d8gd9d8d8gd:d8d6gd;d4d6gd<d4d2gd=d0d2gd>d0d-gd?d.d-gd@d+d)gdAd'd)gdBd%d"gdCd#d"gdDd dgdEddgdFddgdGddgZ,i ddHd8gddHdIgddJd2gddJdKgddJd-gddLdMgddLd)gd!dNdOgd$dNd"gd&dNdPgd(dQdPgd*dQdgd,dRdSgd/dRdgd1dTdgd3dTdUgd5dVdUgi d7dVdgd9ddgd:ddVgd;dUdVgd<dUdTgd=ddTgd>ddRgd?dSdRgd@ddQgdAdPdQgdBd"dNgdCdOdNgdDd)dLgdEdMdLgdFdKdJgdGd8dHgZ-i ddWdgddWdXgddYdTgddYdZgddYdRgdd[d\gdd[dQgd!d]d^gd$d]dNgd&d]d_gd(d`d_gd*d`dLgd,dadbgd/dadJgd1dcdJgd3dcddgd5deddgi d7dedHgd9dHdHgd:dHdegd;dddegd<dddcgd=dJdcgd>dJdagd?dbdagd@dLd`gdAd_d`gdBdNd]gdCd^d]gdDdQd[gdEd\d[gdFdZdYgdGddWgZ.				dndfe/dB dge0ej1B dB dhe2e/ dB die2e3 dB fdjdkZ4G dldm dmeZ5dS )o    N)Callable)T5EncoderModelT5Tokenizer   )PixArtImageProcessor)AutoencoderKLPixArtTransformer2DModel)DPMSolverMultistepScheduler)BACKENDS_MAPPING	deprecateis_bs4_availableis_ftfy_availableis_torch_xla_availableloggingreplace_example_docstring)randn_tensor   )DiffusionPipelineImagePipelineOutputTF)BeautifulSoupa  
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import PixArtAlphaPipeline

        >>> # You can replace the checkpoint id with "PixArt-alpha/PixArt-XL-2-512x512" too.
        >>> pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16)
        >>> # Enable memory optimizations.
        >>> pipe.enable_model_cpu_offload()

        >>> prompt = "A small cactus with a happy face in the Sahara desert."
        >>> image = pipe(prompt).images[0]
        ```
z0.25g      @g      @z0.28g      @z0.32g      @g      @z0.33g      @z0.35g      @z0.4g      @g      @z0.42g      @z0.48g      @g      @z0.5g      @z0.52g      @z0.57g      @z0.6g      @z0.68g      @g      @z0.72g      @z0.78g      @z0.82g      @z0.88g      @z0.94g      @z1.0z1.07z1.13z1.21z1.29z1.38z1.46z1.67z1.75z2.0z2.09z2.4z2.5z3.0z4.0g      p@g      @g      r@g      @g      t@g      @g      v@g      @g      @g      x@g      z@g      @g      |@g      @g      ~@g      `@g      }@g      b@g      {@g      d@g      y@g      f@g      w@g      u@g      h@g      j@g      s@g      l@g      q@g      n@num_inference_stepsdevice	timestepssigmasc                 K   s  |dur|durt d|dur>dtt| jj v }|s(t d| j d| jd||d| | j}t	|}||fS |durpdtt| jj v }|sZt d| j d| jd||d	| | j}t	|}||fS | j|fd
|i| | j}||fS )a  
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`list[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`list[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    NzYOnly one of `timesteps` or `sigmas` can be passed. Please choose one to set custom valuesr   zThe current scheduler class zx's `set_timesteps` does not support custom timestep schedules. Please check whether you are using the correct scheduler.)r   r   r   zv's `set_timesteps` does not support custom sigmas schedules. Please check whether you are using the correct scheduler.)r   r   r    )

ValueErrorsetinspect	signatureset_timesteps
parameterskeys	__class__r   len)	schedulerr   r   r   r   kwargsaccepts_timestepsaccept_sigmasr   r   j/home/ubuntu/.local/lib/python3.10/site-packages/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.pyretrieve_timesteps   s2   r)   c                4       s  e Zd ZdZedZddgZdZde	de
dededef
 fd	d
Z										d=deee B dedededejdB dejdB dejdB dejdB dejdB dedefddZdd Z				d>d d!Zd?d"d#Zd$d% Zd@d&d'Ze ee			(			)				*							+						dAdeee B ded,ed-ee d.ee d/ededB d0edB d1edB d2ed3ej eej  B dB d4ejdB dejdB dejdB dejdB dejdB d5edB d6ed7e!eeejgdf dB d8eded9eded:e"e#B f0d;d<Z$  Z%S )BPixArtAlphaPipelinea  
    Pipeline for text-to-image generation using PixArt-Alpha.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`T5EncoderModel`]):
            Frozen text-encoder. PixArt-Alpha uses
            [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
            [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
        tokenizer (`T5Tokenizer`):
            Tokenizer of class
            [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
        transformer ([`PixArtTransformer2DModel`]):
            A text conditioned `PixArtTransformer2DModel` to denoise the encoded image latents. Initially published as
            [`Transformer2DModel`](https://huggingface.co/PixArt-alpha/PixArt-XL-2-1024-MS/blob/main/transformer/config.json#L2)
            in the config, but the mismatch can be ignored.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
    u5   [#®•©™&@·º½¾¿¡§~\)\(\]\[\}\{\|\\/\*]{1,}	tokenizertext_encoderztext_encoder->transformer->vaevaetransformerr$   c                    sX   t    | j|||||d t| dd r dt| jjjd  nd| _t	| jd| _
d S )N)r+   r,   r-   r.   r$   r-   r         )vae_scale_factor)super__init__register_modulesgetattrr#   r-   configblock_out_channelsr1   r   image_processor)selfr+   r,   r-   r.   r$   r"   r   r(   r3     s   

(zPixArtAlphaPipeline.__init__T r/   NFx   promptdo_classifier_free_guidancenegative_promptnum_images_per_promptr   prompt_embedsnegative_prompt_embedsprompt_attention_masknegative_prompt_attention_maskclean_captionmax_sequence_lengthc              	   K   sh  d|v rd}t dd|dd |du r| j}|}|du rz| j||
d}| j|d|d	d	d
d}|j}| j|dd
dj}|jd |jd krdt||sd| j|dd|d df }t	
d| d|  |j}||}| j|||d}|d }| jdur| jj}n| jdur| jj}nd}|j||d}|j\}}}|d|d}||| |d}|d|}||| d}|r|du rt|tr|g| n|}| j||
d}|jd }| j|d|d	d	d	d
d}|j}	|	|}	| j|j||	d}|d }|r*|jd }|j||d}|d|d}||| |d}|	d|}	|	|| d}	nd}d}	||||	fS )az  
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `list[str]`, *optional*):
                prompt to be encoded
            negative_prompt (`str` or `list[str]`, *optional*):
                The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
                instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
                PixArt-Alpha, this should be "".
            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
                whether to use classifier free guidance or not
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                number of images that should be generated per prompt
            device: (`torch.device`, *optional*):
                torch device to place the resulting embeddings on
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. For PixArt-Alpha, it's should be the embeddings of the ""
                string.
            clean_caption (`bool`, defaults to `False`):
                If `True`, the function will preprocess and clean the provided caption before encoding.
            max_sequence_length (`int`, defaults to 120): Maximum sequence length to use for the prompt.
        mask_featureThe use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version.1.0.0Fstandard_warnN)rE   
max_lengthTpt)paddingrL   
truncationadd_special_tokensreturn_tensorslongest)rN   rQ   r/   zZThe following part of your input was truncated because T5 can only handle sequences up to z	 tokens: )attention_maskr   dtyper   )rN   rL   rO   return_attention_maskrP   rQ   )r   _execution_device_text_preprocessingr+   	input_idsshapetorchequalbatch_decodeloggerwarningrT   tor,   rV   r.   repeatview
isinstancestr)r9   r=   r>   r?   r@   r   rA   rB   rC   rD   rE   rF   r%   deprecation_messagerL   text_inputstext_input_idsuntruncated_idsremoved_textrV   bs_embedseq_len_uncond_tokensuncond_inputr   r   r(   encode_prompt/  s   * 





	

z!PixArtAlphaPipeline.encode_promptc                 C   sX   dt t| jjj v }i }|r||d< dt t| jjj v }|r*||d< |S )Neta	generator)r   r   r   r$   stepr    r!   )r9   rr   rq   accepts_etaextra_step_kwargsaccepts_generatorr   r   r(   prepare_extra_step_kwargs  s   z-PixArtAlphaPipeline.prepare_extra_step_kwargsc
           
      C   s  |d dks|d dkrt d| d| d|d u s(|d ur5t|tr(|dkr5t d| dt| d|d urH|d urHt d| d	| d
|d u rT|d u rTt d|d urkt|tskt|tskt dt| |d ur~|d ur~t d| d| d
|d ur|d urt d| d| d
|d ur|d u rt d|d ur|	d u rt d|d ur|d ur|j|jkrt d|j d|j d|j|	jkrt d|j d|	j dd S d S d S )Nr0   r   z7`height` and `width` have to be divisible by 8 but are z and .z5`callback_steps` has to be a positive integer but is z	 of type zCannot forward both `prompt`: z and `prompt_embeds`: z2. Please make sure to only forward one of the two.zeProvide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.z2`prompt` has to be of type `str` or `list` but is z and `negative_prompt_embeds`: z'Cannot forward both `negative_prompt`: zEMust provide `prompt_attention_mask` when specifying `prompt_embeds`.zWMust provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.zu`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but got: `prompt_embeds` z != `negative_prompt_embeds` z`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but got: `prompt_attention_mask` z% != `negative_prompt_attention_mask` )r   rd   inttypere   listr[   )
r9   r=   heightwidthr?   callback_stepsrA   rB   rC   rD   r   r   r(   check_inputs  sl   z PixArtAlphaPipeline.check_inputsc                    s    rt  sttd d d td d  r0t s0ttd d d td d t|ttfs:|g}dt	f fdd	fd
d|D S )Nbs4rS   zSetting `clean_caption=True`z#Setting `clean_caption` to False...Fftfytextc                    s,    r | }  | } | S |   } | S N)_clean_captionlowerstrip)r   )rE   r9   r   r(   process  s   

z8PixArtAlphaPipeline._text_preprocessing.<locals>.processc                    s   g | ]} |qS r   r   ).0t)r   r   r(   
<listcomp>"  s    z;PixArtAlphaPipeline._text_preprocessing.<locals>.<listcomp>)
r   r_   r`   r
   formatr   rd   tupler{   re   )r9   r   rE   r   )rE   r   r9   r(   rY     s   



z'PixArtAlphaPipeline._text_preprocessingc                 C   s  t |}t|}|  }tdd|}tdd|}tdd|}t|ddj}tdd|}td	d|}td
d|}tdd|}tdd|}tdd|}tdd|}tdd|}tdd|}tdd|}tdd|}tdd|}tdd|}tdd|}tdd|}tdd|}tdd|}tdd|}tdd|}tdd|}td d|}td!d|}t| j	d|}td"d|}t
d#}tt||d$krt|d|}t|}tt|}td%d|}td&d|}td'd|}td(d|}td)d|}td*d|}td+d|}td,d|}td-d|}td.d|}td/d0|}td1d2|}td3d|}|  td4d5|}td6d|}td7d|}td8d|}| S )9Nz<person>personzk\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))r;   zh\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))zhtml.parser)featuresz
@[\w\d]+\bz[\u31c0-\u31ef]+z[\u31f0-\u31ff]+z[\u3200-\u32ff]+z[\u3300-\u33ff]+z[\u3400-\u4dbf]+z[\u4dc0-\u4dff]+z[\u4e00-\u9fff]+z|[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+-u   [`´«»“”¨]"u   [‘’]'z&quot;?z&ampz"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3} z\d:\d\d\s+$z\\nz
#\d{1,3}\bz	#\d{5,}\bz
\b\d{6,}\bz0[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)z
[\"\']{2,}z[\.]{2,}z\s+\.\s+z	(?:\-|\_)r   z\b[a-zA-Z]{1,3}\d{3,15}\bz\b[a-zA-Z]+\d+[a-zA-Z]+\bz\b\d+[a-zA-Z]+\d+\bz!(worldwide\s+)?(free\s+)?shippingz(free\s)?download(\sfree)?z\bclick\b\s(?:for|on)\s\w+z9\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?z\bpage\s+\d+\bz*\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\bu   \b\d+\.?\d*[xх×]\d+\.?\d*\bz
\b\s+\:\s+z: z(\D[,\./])\bz\1 z\s+z^[\"\']([\w\W]+)[\"\']$z\1z^[\'\_,\-\:;]z[\'\_,\-\:\-\+]$z^\.\S+$)re   ulunquote_plusr   r   resubr   r   bad_punct_regexcompiler#   findallr   fix_texthtmlunescape)r9   captionregex2r   r   r(   r   %  s   
	

z"PixArtAlphaPipeline._clean_captionc	           
      C   s   ||t || j t || j f}	t|tr(t||kr(tdt| d| d|d u r5t|	|||d}n||}|| jj	 }|S )Nz/You have passed a list of generators of length z+, but requested an effective batch size of z@. Make sure the batch size matches the length of the generators.)rr   r   rV   )
ry   r1   rd   r{   r#   r   r   ra   r$   init_noise_sigma)
r9   
batch_sizenum_channels_latentsr|   r}   rV   r   rr   latentsr[   r   r   r(   prepare_latents  s    
z#PixArtAlphaPipeline.prepare_latents         @        pilr   r   r   guidance_scaler|   r}   rq   rr   r   output_typereturn_dictcallbackr~   use_resolution_binningreturnc           4      K   sR  d|v rd}t dd|dd |p| jjj| j }|	p!| jjj| j }	|rV| jjjdkr.t}n| jjjdkr8t}n| jjjdkrBt}ntd	||	}}| j	j
||	|d
\}}	| |||	||||||	 |durot|trod}n|dur}t|tr}t|}n|jd }| j}|dk}| j|||||||||||d\}}}}|rtj||gdd}tj||gdd}trd} n|} t| j|| ||\}}| jjj}!| || |!||	|j|||}| ||
}"ddd}#| jjjdkr1t||	g|| d}$tt||	 g|| d}%|$j|j|d}$|%j|j|d}%|r,tj|$|$gdd}$tj|%|%gdd}%|$|%d}#t t||| jj!  d}&| j"|d}'t#|D ]\}(})|rZt|gd n|}*| j$|*|)}*|)}+t%|+s|*j&j'dk},|*j&j'dk}-t|+tr|,s|-rtj(ntj)}.n|,s|-rtj*ntj+}.tj|+g|.|*j&d}+nt|+jdkr|+d |*j&}+|+,|*jd }+| j|*|||+|#ddd }/|r|/-d\}0}1|0||1|0   }/| jjj.d |!kr|/j-dddd }/n|/}/|dkr| jj/|/|)|fi |"ddid }n| jj/|/|)|fi |"ddid }|(t|d ks3|(d |&krR|(d | jj! dkrR|'0  |durR|(| dkrR|(t1| jdd }2||2|)| trYt23  qJW d   n	1 sfw   Y  |dks| j4j5|| j4jj6 ddd }3|r| j	7|3||}3n|}3|dks| j	j8|3|d}3| 9  |s|3fS t:|3dS )uA  
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `list[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                instead.
            negative_prompt (`str` or `list[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            num_inference_steps (`int`, *optional*, defaults to 100):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            timesteps (`list[int]`, *optional*):
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                passed will be used. Must be in descending order.
            sigmas (`list[float]`, *optional*):
                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
                will be used.
            guidance_scale (`float`, *optional*, defaults to 4.5):
                Guidance scale as defined in [Classifier-Free Diffusion
                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
                of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
                `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
                the text `prompt`, usually at the expense of lower image quality.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            height (`int`, *optional*, defaults to self.unet.config.sample_size):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to self.unet.config.sample_size):
                The width in pixels of the generated image.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only
                applies to [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator` or `list[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. For PixArt-Alpha this negative prompt should be "". If not
                provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
            negative_prompt_attention_mask (`torch.Tensor`, *optional*):
                Pre-generated attention mask for negative text embeddings.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generate image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
            clean_caption (`bool`, *optional*, defaults to `True`):
                Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
                be installed. If the dependencies are not installed, the embeddings will be created from the raw
                prompt.
            use_resolution_binning (`bool` defaults to `True`):
                If set to `True`, the requested height and width are first mapped to the closest resolutions using
                `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to
                the requested resolution. Useful for generating non-square images.
            max_sequence_length (`int` defaults to 120): Maximum sequence length to use with the `prompt`.

        Examples:

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
                returned where the first element is a list with the generated images
        rG   rH   rI   FrJ      @       zInvalid sample size)ratiosNr/   r   g      ?)	r?   r@   r   rA   rB   rC   rD   rE   rF   )dimcpu)
resolutionaspect_ratiorU   )totalr   mpsnpu)encoder_hidden_statesencoder_attention_masktimestepadded_cond_kwargsr   r   orderlatent)r   )r   )images);r   r.   r6   sample_sizer1   ASPECT_RATIO_1024_BINASPECT_RATIO_512_BINASPECT_RATIO_256_BINr   r8   classify_height_width_binr   rd   re   r{   r#   r[   rX   rp   r\   catXLA_AVAILABLEr)   r$   in_channelsr   rV   rw   tensorrb   floatra   maxr   progress_bar	enumeratescale_model_input	is_tensorr   rz   float32float64int32int64expandchunkout_channelsrs   updater5   xm	mark_stepr-   decodescaling_factorresize_and_crop_tensorpostprocessmaybe_free_model_hooksr   )4r9   r=   r?   r   r   r   r   r@   r|   r}   rq   rr   r   rA   rC   rB   rD   r   r   r   r~   rE   r   rF   r%   rf   aspect_ratio_binorig_height
orig_widthr   r   r>   timestep_devicelatent_channelsru   r   r   r   num_warmup_stepsr   ir   latent_model_inputcurrent_timestepis_mpsis_npurV   
noise_prednoise_pred_uncondnoise_pred_textstep_idximager   r   r(   __call__  s  n




 


&$6
:

zPixArtAlphaPipeline.__call__)
Tr;   r/   NNNNNFr<   NNNN)Fr   )Nr;   r   NNr   r/   NNr   NNNNNNr   TNr/   TTr<   )&__name__
__module____qualname____doc__r   r   r   _optional_componentsmodel_cpu_offload_seqr   r   r   r   r	   r3   re   r{   boolry   r\   r   Tensorrp   rw   r   rY   r   r   no_gradr   EXAMPLE_DOC_STRINGr   	Generatorr   r   r   r   __classcell__r   r   r:   r(   r*      s   
	

 	

D
s
	
r*   r   )6r   r   r   urllib.parseparser   typingr   r\   transformersr   r   r8   r   modelsr   r   
schedulersr	   utilsr
   r   r   r   r   r   r   utils.torch_utilsr   pipeline_utilsr   r   torch_xla.core.xla_modelcore	xla_modelr   r   
get_loggerr   r_   r   r   r   r   r   r   r   ry   re   r   r{   r   r)   r*   r   r   r   r(   <module>   s  $	
	
 !$	
 !$	
 !(


;