o
    Gi                  
   @   s  d dl Z d dlZd dlmZmZ d dlZd dlZd dlZd dl	m
Z
mZmZmZ ddlmZmZ ddlmZ ddlmZmZ ddlmZ dd	lmZ dd
lmZ ddlmZmZmZm Z  ddl!m"Z" ddl#m$Z$ eZ%eZ&eddryd dl	m&Z&m%Z% e rd dl'm(  m)Z* dZ+ndZ+e,e-Z.dZ/			d,de0de1de1de1fddZ2				d-de0dB de3ej4B dB d e5e0 dB d!e5e1 dB fd"d#Z6	$d.d%ej7d&ej8dB d'e3fd(d)Z9G d*d+ d+eZ:dS )/    N)AnyCallable)ByT5TokenizerPreTrainedModelProcessorMixinT5EncoderModel   )MultiPipelineCallbacksPipelineCallback)VaeImageProcessor)AutoencoderKLGlmImageTransformer2DModel)GlmImageKVCache)DiffusionPipeline)FlowMatchEulerDiscreteScheduler)is_torch_xla_availableis_transformers_versionloggingreplace_example_docstring)randn_tensor   )GlmImagePipelineOutputz>=z
5.0.0.dev0) GlmImageForConditionalGenerationGlmImageProcessorTFa  
    Examples:
        ```python
        >>> import torch
        >>> from diffusers import GlmImagePipeline

        >>> pipe = GlmImagePipeline.from_pretrained("zai-org/GLM-Image", torch_dtype=torch.bfloat16)
        >>> pipe.to("cuda")

        >>> prompt = "A photo of an astronaut riding a horse on mars"
        >>> image = pipe(prompt).images[0]
        >>> image.save("output.png")
        ```
         ?      ?base_seq_len
base_shift	max_shiftreturnc                 C   s   | | d }|| | }|S )Ng      ? )image_seq_lenr   r   r   mmur!   r!   d/home/ubuntu/.local/lib/python3.10/site-packages/diffusers/pipelines/glm_image/pipeline_glm_image.pycalculate_shiftD   s   r&   num_inference_stepsdevice	timestepssigmasc                 K   sT  dt t| jj v }dt t| jj v }|durF|durF|s/|s/td| j d| jd|||d| | j}t	|}||fS |duro|du ro|sYtd| j d| jd||d| | j}t	|}||fS |du r|dur|std| j d	| jd||d
| | j}t	|}||fS | j|fd|i| | j}||fS )a  
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`list[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`list[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    r)   r*   NzThe current scheduler class z's `set_timesteps` does not support custom timestep or sigma schedules. Please check whether you are using the correct scheduler.)r)   r*   r(   zx's `set_timesteps` does not support custom timestep schedules. Please check whether you are using the correct scheduler.)r)   r(   zv's `set_timesteps` does not support custom sigmas schedules. Please check whether you are using the correct scheduler.)r*   r(   r(   r!   )
setinspect	signatureset_timesteps
parameterskeys
ValueError	__class__r)   len)	schedulerr'   r(   r)   r*   kwargsaccepts_timestepsaccepts_sigmasr!   r!   r%   retrieve_timestepsP   s@   r8   sampleencoder_output	generatorsample_modec                 C   sR   t | dr|dkr| j|S t | dr|dkr| j S t | dr%| jS td)Nlatent_distr9   argmaxlatentsz3Could not access latents of provided encoder_output)hasattrr=   r9   moder?   AttributeError)r:   r;   r<   r!   r!   r%   retrieve_latents   s   

rC   c                4       s  e Zd ZdZg ZdZddgZdedede	de
d	ed
edef fddZedefddZedejdedededejf
ddZedejdededejfddZedeejj eeejj  B dedeeejj  fdd Z	!	!	!dZd"eee B d#ed$edeeejj  d!B d%ejd!B d&ejd!B fd'd(Zd)d* Z	!	+	!	!d[d"eee B d,ed%ejd!B d-ej d!B fd.d/Z!	0	1	!	!	!	!	+d\d"eee B d2ed3edejd!B d4ejd!B d%ejd!B d-ej d!B d,efd5d6Z"d]d7d8Z#	!	!	!	!	!	!d^d9d:Z$e%d;d< Z&e%d=d> Z'e%d?d@ Z(e%dAdB Z)e%dCdD Z*e%dEdF Z+e, e-e.d!d!d!d!dGd!d!dHd1d!d!d!d!d!d!d!dIdJd0d!d!dgd+fd"eee B d!B dejejjB e/j0B eej B eejj B ee/j0 B d!B d#ed!B d$ed!B dKedLee d!B dMee1 d!B dNe1d3ed&ejeej B d!B dejd!B dejd!B d4ejd!B dOejd!B dPeej d!B dQeej d!B dRe2eef dSedTedUe3ee4f d!B dVe5eee3gd!f e6B e7B d!B dWee d,ede8e2B f0dXdYZ9  Z:S )_GlmImagePipelineaT  
    Pipeline for text-to-image generation using GLM-Image.

    This pipeline integrates both the AR (autoregressive) model for token generation and the DiT (diffusion
    transformer) model for image decoding.

    Args:
        tokenizer (`PreTrainedTokenizer`):
            Tokenizer for the text encoder.
        processor (`AutoProcessor`):
            Processor for the AR model to handle chat templates and tokenization.
        text_encoder ([`T5EncoderModel`]):
            Frozen text-encoder for glyph embeddings.
        vision_language_encoder ([`GlmImageForConditionalGeneration`]):
            The AR model that generates image tokens from text prompts.
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        transformer ([`GlmImageTransformer2DModel`]):
            A text conditioned transformer to denoise the encoded image latents (DiT).
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
    z7vision_language_encoder->text_encoder->transformer->vaer?   prompt_embeds	tokenizer	processortext_encodervision_language_encodervaetransformerr4   c              	      s   t    | j|||||||d t| dd r"dt| jjjd  nd| _t	| jd| _
t| drE| jd urEt| jjdrE| jjj| _d S d	| _d S )
N)rF   rG   rH   rI   rJ   rK   r4   rJ      r      )vae_scale_factorrK   sample_size   )super__init__register_modulesgetattrr3   rJ   configblock_out_channelsrN   r   image_processorr@   rK   rO   default_sample_size)selfrF   rG   rH   rI   rJ   rK   r4   r2   r!   r%   rR      s*   

(	
zGlmImagePipeline.__init__is_text_to_imagec                 C   s   g }g }t | jd D ]}| |  \}}}|t||  |t|t|f q|s<|d d }d}	|d \}
}nt|}|d }t|dd  }	|d \}
}||	|
|fS )Nr   r   )rangeshapetolistappendintsum)image_grid_thwr[   
grid_sizesgrid_hwithwmax_new_tokenslarge_image_start_offsettarget_grid_htarget_grid_wtotal_tokensr!   r!   r%   _compute_generation_params   s   z+GlmImagePipeline._compute_generation_paramsoutputsinput_lengthrk   large_image_tokensr    c                 C   s(   | d |d  }|}|| }||| S )Nr   r!   )rp   rq   rk   rr   generated_tokenslarge_image_startlarge_image_endr!   r!   r%   _extract_large_image_tokens   s   z,GlmImagePipeline._extract_large_image_tokens	token_idstoken_htoken_wc                 C   sB   |  dd||} tjjj|  dddjtjd} |  dd} | S )Nr   rL   nearest)scale_factorrA   dtyper\   )viewtorchnn
functionalinterpolatefloattolong)rw   rx   ry   r!   r!   r%   _upsample_token_ids   s   z$GlmImagePipeline._upsample_token_idsimage
batch_sizec              	   C   s  | du s
t | dkrdS | d }|dkr7t|ttfs t| gS t | dkr0tdt |  dt| d gS t|ttfsItdt|j dt | |kr\tdt |  d	| d
t | d }t| D ]\}}t ||krtd| d| dt | dqfdd | D S )a:  
        Validate and normalize image inputs to List[List[PIL.Image]].

        Rules:
        - batch_size > 1: Only accepts List[List[PIL.Image]], each sublist must have equal length
        - batch_size == 1: Accepts List[PIL.Image] for legacy compatibility (converted to [[img1, img2, ...]])
        - Other formats raise ValueError

        Args:
            image: Input images in various formats
            batch_size: Number of prompts in the batch

        Returns:
            Normalized images as List[List[PIL.Image]], or None if no images provided
        Nr   r   zOFor batch_size=1 with List[List[PIL.Image]] format, expected 1 image list, got .zJFor batch_size > 1, images must be List[List[PIL.Image]] format. Got List[zA] instead. Each prompt requires its own list of condition images.zNumber of image lists (z) must match batch size (z).zHAll prompts must have the same number of condition images. Prompt 0 has z images, but prompt z has z images.c                 S   s   g | ]}t |qS r!   )list).0imgsr!   r!   r%   
<listcomp>?  s    zCGlmImagePipeline._validate_and_normalize_images.<locals>.<listcomp>)r3   
isinstancer   tupler1   type__name__	enumerate)r   r   first_elementnum_input_images_per_promptidxr   r!   r!   r%   _validate_and_normalize_images  sB   
z/GlmImagePipeline._validate_and_normalize_imagesNpromptheightwidthr(   r;   c           7   	   C   s  |p| j }t|tr|gn|}t|}|du }	g }
t|D ](\}}g }|	s4|| D ]
}|d|d q)|d|d |
d|dg q| jj|
d|d	krQdnd
||ddd|}|	d}|	d}|	rjdnt|d }|dur{|d 
 }n|jd }|d| }| j||	d\}}}}d}d}|	s4g }t|D ]}|| }|t|||  q|| }t|dkr4| j|d |j}tj|dd}| j||} |jdd }!t| |!}"g }#t|"D ] \}$}%||$  \}&}'}(| |%t|'t|(})|#|)d qtj|#dd}| }*|*ddd	f d |*ddd	f< |*dddf d |*dddf< |*}|durS| }+t|+ |durS|jdkrStj|+ | jj di ||dd},g }-|d jd }.t|D ]}| !|,||d	  |.||| }/| |/||}0|-|0 qntj|-dd}0d}1d}2|dur|durt"t||}2|jdd }3g }4t|D ]}$|$| }5|5| }6|4t#|3|5|6  qt"t||4}1|0|1|2fS )a  
        Generate prior tokens for the DiT model using the AR model.

        Args:
            prompt: Single prompt or list of prompts
            height: Target image height
            width: Target image width
            image: Normalized image input as List[List[PIL.Image]]. Should be pre-validated
                   using _validate_and_normalize_images() before calling this method.
            device: Target device
            generator: Random generator for reproducibility

        Returns:
            Tuple of:
                - prior_token_ids: Tensor of shape (batch_size, num_tokens) with upsampled prior tokens
                - prior_token_image_ids_per_sample: List of tensors, one per sample. Each tensor contains
                    the upsampled prior token ids for all condition images in that sample. None for t2i.
                - source_image_grid_thw_per_sample: List of tensors, one per sample. Each tensor has shape
                    (num_condition_images, 3) with upsampled grid info. None for t2i.
        Nr   )r   r   text)r   r   user)rolecontentTr   Fpt)tokenizepaddingtarget_htarget_wreturn_dictreturn_tensorsrc   images_per_sampler   )rc   r[   pixel_valuesdimr\   rL   cuda)rj   	do_sample	input_idsr!   )$_execution_devicer   strr3   r   r`   rG   apply_chat_templater   getitemr^   ro   r]   extendrI   get_image_featurespooler_outputr   catget_image_tokensprodr_   splitr   ra   squeezecloneinitial_seedmanual_seedr   r   generaterv   r   rb   )7rY   r   r   r   r   r(   r;   prompt_listr   r[   all_messagesr   pr   imginputsrc   r   num_condition_imagesnum_grids_per_samplefirst_sample_gridsrj   large_image_offsetrx   ry   prior_token_image_idssource_image_grid_thwsource_indices
sample_idxbasesource_gridsprior_token_image_embedprior_token_image_ids_d32split_sizesprior_ids_per_sourceupsampled_prior_idsrf   	prior_idsrg   rh   ri   	upsampledupsampled_gridsseedrp   all_prior_token_idsmax_input_lengthprior_token_ids_d32prior_token_ids prior_token_image_ids_per_sample source_image_grid_thw_per_sampletokens_per_imagetokens_per_sample	start_idxend_idxr!   r!   r%   generate_prior_tokensA  s   




  




z&GlmImagePipeline.generate_prior_tokensc                 C   s\   t |tr|g}g }|D ]}td|td| td| td| }|| q|S )zQExtract glyph texts from prompt(s). Returns a list of lists for batch processing.z	'([^']*)'z\u201c([^\u201c\u201d]*)\u201dz	"([^"]*)"u   「([^「」]*)」)r   r   refindallr`   )rY   r   all_ocr_textsr   	ocr_textsr!   r!   r%   get_glyph_texts  s   




z GlmImagePipeline.get_glyph_texts   max_sequence_lengthr}   c                    sp  |pj }|p
jj}|}g }|D ]Z}t|dkrdg}j||ddj  fdd D  tdd  D tj	fd	d D |d
}tj	fdd D |d
 j |d}	|	j
|  d}
||
 qtdd |D }g }|D ]-}|d|k rtj|d||d |d||jd}tj||gdd}|| q|tj|dd}
|
j||dS )z2Get glyph embeddings for each prompt in the batch.r    T)
max_length
truncationc                    s*   g | ]}j jgt d  d  | qS r   rL   rF   pad_token_idr3   r   
input_ids_)r   rY   r!   r%   r     s    z6GlmImagePipeline._get_glyph_embeds.<locals>.<listcomp>c                 s   s    | ]}t |V  qd S Nr3   r   r!   r!   r%   	<genexpr>      z5GlmImagePipeline._get_glyph_embeds.<locals>.<genexpr>c                    s,   g | ]}d gt | dg t |   qS )r   r   r   r   )r   r!   r%   r     s   , r(   c                    s&   g | ]}|j jg t|   qS r!   r   r   )r   rY   r!   r%   r     s    )attention_maskc                 s   s    | ]}| d V  qdS )r   N)size)r   embr!   r!   r%   r     s    r   rL   r(   r}   r   )r   rH   r}   r   r3   rF   r   maxr   tensorlast_hidden_statebool	unsqueezer`   r   zerosr   r   )rY   r   r   r(   r}   all_glyph_textsall_glyph_embedsglyph_textsr   rp   glyph_embedsmax_seq_lenpadded_embedsr   padr!   )r   r   rY   r%   _get_glyph_embeds  sP   

,z"GlmImagePipeline._get_glyph_embedsTr   do_classifier_free_guidancenum_images_per_promptnegative_prompt_embedsc	                 C   s   |p| j }t|tr|gn|}|durt|}	n|jd }	|du r)| ||||}|dkr4|j|dd}|r[|du r[d}
t|
trF|	|
g n|
}
| |
|||}|dkr[|j|dd}||fS )aX  
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `list[str]`, *optional*):
                prompt to be encoded
            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
                Whether to use classifier free guidance or not.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                Number of images that should be generated per prompt. torch device to place the resulting embeddings on
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            device: (`torch.device`, *optional*):
                torch device
            dtype: (`torch.dtype`, *optional*):
                torch dtype
            max_sequence_length (`int`, defaults to `2048`):
                Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results.
        Nr   r   r   r   )r   r   r   r3   r^   r  repeat_interleave)rY   r   r  r  rE   r  r(   r}   r   r   negative_promptr!   r!   r%   encode_prompt!  s    


zGlmImagePipeline.encode_promptc	           
      C   sv   |d ur	| |S ||t|| j t|| j f}	t|tr1t||kr1tdt| d| dt|	|||d}|S )Nz/You have passed a list of generators of length z+, but requested an effective batch size of z@. Make sure the batch size matches the length of the generators.)r;   r(   r}   )r   ra   rN   r   r   r3   r1   r   )
rY   r   num_channels_latentsr   r   r}   r(   r;   r?   r^   r!   r!   r%   prepare_latentsZ  s   
z GlmImagePipeline.prepare_latentsc                    s  |d ur| j  jjj d  dks!|d ur2| jjjd  dkr2td j d  d| d| d|d urSt fdd	|D sStd
 j d fdd|D  |d urf|d urftd| d| d|d u rr|d u rrtd|d urt|tst|t	stdt
| |d u r|d u rtd|d ur|d ur|j|jkrtd|j d|j d||	g}tdd	 |D }|dkr|t|k rtd|d u d|	d u d|dkr|d u rtd|dkr|
d u rtd|d ur|d u r|d u rtdd S d S d S )NrL   r   z-`height` and `width` have to be divisible by    z	 but are z and r   c                 3   s    | ]}| j v V  qd S r   _callback_tensor_inputsr   krY   r!   r%   r     s    

z0GlmImagePipeline.check_inputs.<locals>.<genexpr>z2`callback_on_step_end_tensor_inputs` has to be in z, but found c                    s   g | ]	}| j vr|qS r!   r  r  r  r!   r%   r     s    z1GlmImagePipeline.check_inputs.<locals>.<listcomp>zCannot forward both `prompt`: z and `prompt_embeds`: z2. Please make sure to only forward one of the two.zeProvide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.z2`prompt` has to be of type `str` or `list` but is ziProvide either `prompt` or `prior_token_ids`. Cannot leave both `prompt` and `prior_token_ids` undefined.zu`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but got: `prompt_embeds` z != `negative_prompt_embeds` c                 s   s    | ]}|d uV  qd S r   r!   )r   xr!   r!   r%   r     r   zv`prior_token_image_ids` and `source_image_grid_thw` must be provided together for i2i mode. Got prior_token_image_ids=z, source_image_grid_thw=zi`prior_token_ids` must be provided when `prior_token_image_ids` and `source_image_grid_thw` are provided.z`image` must be provided when `prior_token_image_ids` and `source_image_grid_thw` are provided for i2i mode, as the images are needed for VAE encoding to build the KV cache.zI`prompt_embeds` or `prompt` must also be provided with `prior_token_ids`.)rN   rK   rU   
patch_sizer1   allr  r   r   r   r   r^   rb   r3   )rY   r   r   r   "callback_on_step_end_tensor_inputsrE   r  r   r   r   r   prior_image_inputsnum_prior_image_inputsr!   r  r%   check_inputsl  sr   zGlmImagePipeline.check_inputsc                 C      | j S r   _guidance_scaler  r!   r!   r%   guidance_scale     zGlmImagePipeline.guidance_scalec                 C   s
   | j dkS )Nr   r  r  r!   r!   r%   r    s   
z,GlmImagePipeline.do_classifier_free_guidancec                 C   r  r   )_num_timestepsr  r!   r!   r%   num_timesteps  r  zGlmImagePipeline.num_timestepsc                 C   r  r   )_attention_kwargsr  r!   r!   r%   attention_kwargs  r  z!GlmImagePipeline.attention_kwargsc                 C   r  r   )_current_timestepr  r!   r!   r%   current_timestep  r  z!GlmImagePipeline.current_timestepc                 C   r  r   )
_interruptr  r!   r!   r%   	interrupt  r  zGlmImagePipeline.interrupt2   g      ?)r   r   pilr'   r)   r*   r  r   r   r   crops_coords_top_leftoutput_typer   r#  callback_on_step_endr  c           D      C   s  t |ttfr
|j}| ||||||||||
 || _|| _d| _d| _|dur0t |t	r0d}n|dur>t |t
r>t|}n|jd }| j}| ||}t |
t
rU|
d n|
}|du rj| j||||||d\}}}n|}|}d}|durg }|D ]P}g } |D ]D}!t |!tjjr|!jddd n|!jdd \}"}#| j| jjj }$|"|$ |$ }"|#|$ |$ }#| jj|!|"|#d}!| |! |p|"}|p|#}q~||  qx| j|| j|	||||| jd	\}}| jjj}%| j||	 |%|||j||
|d
}t| jjj d}&|dur|&!d t"#| j$jj%&d| j$jj'dd}'t"#| j$jj(&d| j$jj'dd}(|'j)||jd}'|(j)||jd}(t*|D ]{})||) }||) }*||) }+|+j+dd, },t"-|*|,}-t.||-D ]T\}.}/|.j)||jd}.t/| j$0|.|
dd}0|0|' |( }0| j|0t"1|dddddf |/t"j2|/dt"j3dt"j4d|dt"j#|.jdd g|dt"j4d|d||&d	}1qX|&5  q7||f}2t"j#|2g|j|d}2t"j#|g|j|d}|26||	 d}2|6||	 d}|| j || j  | jjjd  }3|du rt78| j9jj:d|d dd nt7;|}|<t7j=<t7j>}|du r|| j9jj: n|}t?|3| j9j@dd| j9j@dd| j9j@dd}4tA| j9|||||4d \}}t|| _B| jj}5tCt||| j9jD  d}6|	dkrf|jE|	dd}t"j2|dt"j3d}7t"j2|d!t"j3d}8| jF|d"}9tG|D ]\}:};| jHrq|;| _|)|5}<|;I|jd d }=|dur|&!d# | j|<|||7|=|2||d|&d$
d J }>| jr|dur|&!d% | j|<|||8|=|2||d|&d$
d J }?|?| jK|>|?   }@n|>}@| j9jL|@|;|dd&d }|dur i }A|D ]
}BtM |B |A|B< q|| |:| j9jN|: |A}C|COd'|}|COd(|}|:t|d ks;|:d |6kr?|:d | j9jD dkr?|9P  tQrFtRS  qW d   n	1 sSw   Y  d| _|&T  |d)ks|)| j$j}t"#| j$jj%&d| j$jj'dd)|jU|j}'t"#| j$jj(&d| j$jj'dd)|jU|j}(||( |' }| j$jV|d|
d*d }| jjW||d+}n|}| X  |s|fS tY|d,S )-a  
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `list[str]`, *optional*):
                The prompt or prompts to guide the image generation. Must contain shape info in the format '<sop>H
                W<eop>' where H and W are token dimensions (d32). Example: "A beautiful sunset<sop>36 24<eop>"
                generates a 1152x768 image.
            image: Optional condition images for image-to-image generation.
            height (`int`, *optional*):
                The height in pixels. If not provided, derived from prompt shape info.
            width (`int`, *optional*):
                The width in pixels. If not provided, derived from prompt shape info.
            num_inference_steps (`int`, *optional*, defaults to `50`):
                The number of denoising steps for DiT.
            guidance_scale (`float`, *optional*, defaults to `1.5`):
                Guidance scale for classifier-free guidance.
            num_images_per_prompt (`int`, *optional*, defaults to `1`):
                The number of images to generate per prompt.
            generator (`torch.Generator`, *optional*):
                Random generator for reproducibility.
            output_type (`str`, *optional*, defaults to `"pil"`):
                Output format: "pil", "np", or "latent".

        Examples:

        Returns:
            [`GlmImagePipelineOutput`] or `tuple`: Generated images.
        NFr   r   )r   r   r   r   r(   r;   r\   rL   )r   r   )r  rE   r  r   r(   r}   )r   r  r   r   r}   r(   r;   r?   )
num_layerswriter   r   r>   )r;   r<   .r|   )r   r   r   )	hidden_statesencoder_hidden_statesprior_token_idprior_token_droptimesteptarget_sizecrop_coordsr#  	kv_caches)r}   r(   g      ?base_image_seq_lenr   r   r   r   r   )r$   T)totalread)
r0  r1  r2  r3  r4  r5  r6  r#  r   r7  skip)r   r?   rE   latent)r   r;   )r+  )images)Zr   r
   r	   tensor_inputsr  r  r"  r$  r&  r   r   r3   r^   r   r   r   PILImager   rN   rK   rU   r  rW   
preprocessr`   r  r  r}   in_channelsr  r   r-  set_moder   r   rJ   latents_meanr~   latent_channelslatents_stdr   r]   r   r_   r   ziprC   encode
zeros_like	full_liker   r   next_samplerepeatnplinspacer4   num_train_timestepsarrayastypeint64float32r&   r   r8   r   r   orderr	  progress_barr   r'  expandr   r  steplocalsr*   popupdateXLA_AVAILABLExm	mark_stepclearr(   decodepostprocessmaybe_free_model_hooksr   )DrY   r   r   r   r   r'   r)   r*   r  r  r;   r?   rE   r  r   r   r   r*  r+  r   r#  r,  r  r   r   r(   normalized_imagear_generatorr   r   preprocessed_imagesprompt_imagesprompt_preprocessedr   image_heightimage_widthmultiple_ofrE  r7  rD  rF  
prompt_idxprompt_prior_idsprompt_grid_thwr   prior_ids_per_imagecondition_imagecondition_image_prior_token_idcondition_latent_r5  r"   r$   transformer_dtypenum_warmup_stepsprior_token_drop_condprior_token_drop_uncondrU  rf   rg   latent_model_inputr4  noise_pred_condnoise_pred_uncond
noise_predcallback_kwargsr  callback_outputsr!   r!   r%   __call__  s  C


0






""

 








6?

zGlmImagePipeline.__call__)NNN)Nr   NN)Tr   NNNNr   r   )NNNNNN);r   
__module____qualname____doc___optional_componentsmodel_cpu_offload_seqr  r   r   r   r   r   r   r   rR   staticmethodr   ro   r   Tensorra   rv   r   r   r?  r@  r   r   r(   	Generatorr   r   r}   r  r  r  r  propertyr  r  r!  r#  r%  r'  no_gradr   EXAMPLE_DOC_STRINGrM  ndarrayr   r   dictr   r   r
   r	   r   r|  __classcell__r!   r!   rZ   r%   rD      s     >

 

9
	

9
K










 !"rD   )r   r   r   )NNNN)Nr9   );r,   r   typingr   r   numpyrM  r?  r   transformersr   r   r   r   	callbacksr	   r
   rW   r   modelsr   r   )models.transformers.transformer_glm_imager   pipelines.pipeline_utilsr   
schedulersr   utilsr   r   r   r   utils.torch_utilsr   pipeline_outputr   r   r   torch_xla.core.xla_modelcore	xla_modelr\  r[  
get_loggerr   loggerr  ra   r   r&   r   r(   r   r8   r  r  rC   rD   r!   r!   r!   r%   <module>   sx   





E
