o
    Gi                     @   s  d dl mZmZ d dlZd dlZd dlmZ d dlm	Z	 ddl
mZ ddlmZ ddlmZ dd	lmZ dd
lmZ ddlmZmZ ddlmZ ddlmZmZ ddlmZmZmZm Z m!Z!m"Z" ddl#m$Z$ e rud dl%m&  m'Z( dZ)ndZ)e*e+Z,dZ-G dd deeZ.dS )    )AnyCallableN)AutoTokenizer)SmolLM3ForCausalLM   )VaeImageProcessor)FluxLoraLoaderMixin)AutoencoderKLWan)BriaFiboTransformer2DModel)BriaFiboPipelineOutput)calculate_shiftretrieve_timesteps)DiffusionPipeline)FlowMatchEulerDiscreteSchedulerKarrasDiffusionSchedulers)USE_PEFT_BACKENDis_torch_xla_availableloggingreplace_example_docstringscale_lora_layersunscale_lora_layers)randn_tensorTFa  
    Example:
    ```python
    import torch
    from diffusers import BriaFiboPipeline
    from diffusers.modular_pipelines import ModularPipeline

    torch.set_grad_enabled(False)
    vlm_pipe = ModularPipeline.from_pretrained("briaai/FIBO-VLM-prompt-to-JSON", trust_remote_code=True)

    pipe = BriaFiboPipeline.from_pretrained(
        "briaai/FIBO",
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    )
    pipe.enable_model_cpu_offload()

    with torch.inference_mode():
        # 1. Create a prompt to generate an initial image
        output = vlm_pipe(prompt="a beautiful dog")
        json_prompt_generate = output.values["json_prompt"]

        # Generate the image from the structured json prompt
        results_generate = pipe(prompt=json_prompt_generate, num_inference_steps=50, guidance_scale=5)
        results_generate.images[0].save("image_generate.png")
    ```
c                '   @   s  e Zd ZdZdZddgZdedeeB de	de
d	ef
d
dZ				dHdeee B dededejdB dejdB f
ddZedIddZ								dJdeee B dejdB dededeee B dB dejdB dejdB dededB fddZed d! Zed"d# Zed$d% Zed&d' Zed(d) Zed*d+ Zed,d- Z ed.d/ Z!ed0d1 Z"		2dKd3d4Z#ed5d6 Z$e% e&e'dddd7ddddddddd8d9dddgdd2fdeee B d:edB d;edB d<ed=ee dedeee B dB dedB d>ej(eej( B dB dejdB dejdB dejdB d?edB d@e)dAe*ee+f dB dBe,eegdf dB dCee def$dDdEZ-					dLdFdGZ.dS )MBriaFiboPipelinea  
    Args:
        transformer (`BriaFiboTransformer2DModel`):
            The transformer model for 2D diffusion modeling.
        scheduler (`FlowMatchEulerDiscreteScheduler` or `KarrasDiffusionSchedulers`):
            Scheduler to be used with `transformer` to denoise the encoded latents.
        vae (`AutoencoderKLWan`):
            Variational Auto-Encoder for encoding and decoding images to and from latent representations.
        text_encoder (`SmolLM3ForCausalLM`):
            Text encoder for processing input prompts.
        tokenizer (`AutoTokenizer`):
            Tokenizer used for processing the input text prompts for the text_encoder.
    z=text_encoder->text_encoder_2->image_encoder->transformer->vaelatentsprompt_embedstransformer	schedulervaetext_encoder	tokenizerc                 C   s6   | j |||||d d| _t| jd d| _d| _d S )N)r   r   r   r   r         )vae_scale_factor@   )register_modulesr"   r   image_processordefault_sample_size)selfr   r   r   r   r    r(   d/home/ubuntu/.local/lib/python3.10/site-packages/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py__init__^   s   
zBriaFiboPipeline.__init__      Npromptnum_images_per_promptmax_sequence_lengthdevicedtypec                    s   p| j  |p
| jj}t|tr|gn|}|stdt|}d} d ur' ntd}t|tjs7t|}t	dd |D rRtj
|df|tj|d}	t|	}
n6| j|d|d	d	d
d}|j|}	|j|}
tdd |D rtjdd |D tj|d}||	|< d|
|< | j|	|
d	d}|j}tj|d |d gdd}|j |d}|jdd}t fdd|D }|
jddj d}
|||
fS )Nz7`prompt` must be a non-empty string or list of strings.i  cpuc                 s       | ]}|d kV  qdS  Nr(   .0pr(   r(   r)   	<genexpr>       z5BriaFiboPipeline.get_prompt_embeds.<locals>.<genexpr>r+   r1   r0   longestTpt)padding
max_length
truncationadd_special_tokensreturn_tensorsc                 s   r3   r4   r(   r6   r(   r(   r)   r9      r:   c                 S   s   g | ]}|d kqS )r5   r(   r6   r(   r(   r)   
<listcomp>   s    z6BriaFiboPipeline.get_prompt_embeds.<locals>.<listcomp>)attention_maskoutput_hidden_statesdimr0   r1   r   c                 3   s&    | ]}|j d dj dV  qdS )r   rH   r0   N)repeat_interleavetor7   layerr0   r.   r(   r)   r9      s    
rK   )_execution_devicer   r1   
isinstancestr
ValueErrorlentorchr0   allfulllong	ones_liker   	input_idsrM   rD   anytensorboolhidden_statescatrL   tuple)r'   r-   r.   r/   r0   r1   
batch_sizebot_token_idtext_encoder_devicer[   rD   	tokenized
empty_rowsencoder_outputsr_   r   r(   rP   r)   get_prompt_embedsr   sR   


z"BriaFiboPipeline.get_prompt_embedsc           	      C   s   | j \}}}|d u rtj||f| j| jd}n	|j| j| jd}||k r(td||kr[|| }tj|||f| j| jd}tj| |gdd} tj||f| j| jd}tj||gdd}| |fS )Nr;   rJ   zE`max_tokens` must be greater or equal to the current sequence length.r+   rH   )	shaperV   onesr1   r0   rM   rT   zerosr`   )	r   
max_tokensrD   rb   seq_lenrI   
pad_lengthr>   mask_paddingr(   r(   r)   pad_embedding   s"   zBriaFiboPipeline.pad_embedding     guidance_scalenegative_promptnegative_prompt_embeds
lora_scalec
              
      s  |pj }|	durttr|	_jdurtrtj|	 t|tr&|gn|}|dur1t|}
n|j	d }
d}d}|du r[j
||||d\}}}|jjjd}fdd|D }|dkrt|trl|d du rld}|pod}t|trz|
|g n|}|durt|t|urtd	t| d
t| d|
t|krtd| dt| d| d|
 d	j
||||d\}}}|jjjd}fdd|D }jdurttrtrtj|	 |dur|j|j|jd}|dur2|dur|j|j|jd}t|j	d |j	d  j| |d\}} fdd|D }j| |d\}} fdd|D }n|j	d  j| |d\}}d}jj}t|j	d  dj||d}|||||||fS )a  
        Args:
            prompt (`str` or `list[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            guidance_scale (`float`):
                Guidance scale for classifier free guidance.
            negative_prompt (`str` or `list[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
        Nr   )r-   r.   r/   r0   r1   c                       g | ]
}|j  jjd qS rw   rM   r   r1   r7   r]   r'   r(   r)   rC         z2BriaFiboPipeline.encode_prompt.<locals>.<listcomp>r+   r5   z?`negative_prompt` should be the same type to `prompt`, but got z != .z`negative_prompt`: z has batch size z, but `prompt`: zT. Please make sure that passed `negative_prompt` matches the batch size of `prompt`.c                    rx   ry   rz   r{   r|   r(   r)   rC      r}   rJ   )rD   c                       g | ]
} | d  qS r   rp   rN   rl   r'   r(   r)   rC   5  r}   c                    r   r   r   rN   r   r(   r)   rC   :  r}   r   )rQ   rR   r   _lora_scaler   r   r   rS   rU   ri   rh   rM   r   r1   listtype	TypeErrorrT   r   r0   maxrp   rV   rk   )r'   r-   r0   r.   rs   rt   r   ru   r/   rv   rb   prompt_attention_masknegative_prompt_attention_maskprompt_layersnegative_prompt_layersr1   text_idsr(   r   r)   encode_prompt   s   
"







zBriaFiboPipeline.encode_promptc                 C      | j S N)_guidance_scaler|   r(   r(   r)   rs   O     zBriaFiboPipeline.guidance_scalec                 C   r   r   )_joint_attention_kwargsr|   r(   r(   r)   joint_attention_kwargsW  r   z'BriaFiboPipeline.joint_attention_kwargsc                 C   r   r   )_num_timestepsr|   r(   r(   r)   num_timesteps[  r   zBriaFiboPipeline.num_timestepsc                 C   r   r   )
_interruptr|   r(   r(   r)   	interrupt_  r   zBriaFiboPipeline.interruptc                 C   sh   | j \}}}|| }|| }| ||d |d |d dd} | dddddd} | ||d ||} | S )Nr!      r   r   r+   rq   )ri   viewpermutereshaper   heightwidthr"   rb   num_patcheschannelsr(   r(   r)   _unpack_latentsc  s    z BriaFiboPipeline._unpack_latentsc           	      C   s|   t ||d}|d t |d d d f  |d< |d t |d d d f  |d< |j\}}}||| |}|j||dS )Nr   ).r+   ).r!   rJ   )rV   rk   arangeri   r   rM   )	rb   r   r   r0   r1   latent_image_idslatent_image_id_heightlatent_image_id_widthlatent_image_id_channelsr(   r(   r)   _prepare_latent_image_idsq  s   ""z*BriaFiboPipeline._prepare_latent_image_idsc                 C   s@   | j \}}}|| }|| }| ||||} | dddd} | S )Nr   r   r+   r!   )ri   r   r   r   r(   r(   r)   _unpack_latents_no_patch  s   z)BriaFiboPipeline._unpack_latents_no_patchc                 C   s&   |  dddd} | ||| |} | S )Nr   r!   r   r+   )r   r   r   rb   num_channels_latentsr   r   r(   r(   r)   _pack_latents_no_patch  s   z'BriaFiboPipeline._pack_latents_no_patchc                 C   sR   |  |||d d|d d} | dddddd} | ||d |d  |d } | S )Nr!   r   r   r+   r   rq   )r   r   r   r   r(   r(   r)   _pack_latents  s   zBriaFiboPipeline._pack_latentsFc
                 C   s   t || j }t || j }||||f}
|d ur*| |||||}|j||d|fS t|trBt||krBtdt| d| dt|
|||d}|	rf| 	|||||}| ||d |d ||}||fS | 
|||||}| |||||}||fS )NrJ   z/You have passed a list of generators of length z+, but requested an effective batch size of z@. Make sure the batch size matches the length of the generators.)	generatorr0   r1   r!   )intr"   r   rM   rR   r   rU   rT   r   r   r   )r'   rb   r   r   r   r1   r0   r   r   do_patchingri   r   r(   r(   r)   prepare_latents  s(   z BriaFiboPipeline.prepare_latentsc                 C   s(   t d| | }t |dkdt j }|S )Nz
bi,bj->bijr+   g        )rV   einsumwhereinf)rD   attention_matrixr(   r(   r)   _prepare_attention_mask  s
   z(BriaFiboPipeline._prepare_attention_mask   pilTr   r   num_inference_steps	timestepsr   output_typereturn_dictr   callback_on_step_end"callback_on_step_end_tensor_inputsc           5         s@  |p| j | j }|p| j | j }| j||||||d || _|| _d| _|dur0t|tr0d}n|dur>t|tr>t	|}n|j
d }| j}| jdurR| jddnd}| j|||||||||d	\}}}}}|j
d }|dkrtj||gdd}fd	d
tt	D tj||gdd}t	| jjt	| jj }t	|krt	| d nd g|t	   | jjj}|rt|d }| |||||j||	|
|	\}
}tj|
j
d |
j
d g|
j|
jd}|dkr|dd}tj||gdd}| |}|jddj| jjd}| jdu ri | _|| jd< |r/|| jd  || jd   } n
|| j || j  } t !dd| |}!t"| | j#jj$| j#jj%| j#jj&| j#jj'}"t(| j#||d|!|"d\}}t)t	||| j#j*  d}#t	|| _+t	|j
dkr|d }t	|j
dkr|d }| j,|d}$t-|D ]\}%}&| j.rq|dkrt|
gd n|
}'|&/|'j
d j|'j|'jd}(| j|'|(|| jd||dd })|dkr|)0d\}*}+|*| j1|+|*   })|
j},| j#j2|)|&|
ddd }
|
j|,krtj3j45 r|
|,}
|dur3i }-|D ]
}.t6 |. |-|.< q|| |%|&|-}/|/7d|
}
|/7d|}|/7d|}|%t	|d ksN|%d |#krR|%d | j#j* dkrR|$8  t9rYt:;  qW d   n	1 sfw   Y  |dkrs|
}0n|r| <|
||| j}
n	| =|
||| j}
|
jdd}
|
d j}1|
d j},t>| j?jj@Ad| j?jjBddd|1|, dt>| j?jjCAd| j?jjBddd|1|,  fdd
|
D }2tj|2dd}2g }0|2D ]!}3| j?jD|3dddd }4| jEjF|4jGdd|d}4|0H|4 qt	|0dkr
|0d }0nt jI|0dd}0| J  |s|0fS tK|0dS ) aH  
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `list[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                instead.
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image. This is set to 1024 by default for the best results.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image. This is set to 1024 by default for the best results.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            timesteps (`list[int]`, *optional*):
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                passed will be used. Must be in descending order.
            guidance_scale (`float`, *optional*, defaults to 5.0):
                Guidance scale as defined in [Classifier-Free Diffusion
                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
                of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
                `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
                the text `prompt`, usually at the expense of lower image quality.
            negative_prompt (`str` or `list[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `list[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will ge generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generate image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
                of a plain tuple.
            joint_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            callback_on_step_end (`Callable`, *optional*):
                A function that calls at the end of each denoising steps during the inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            max_sequence_length (`int` defaults to 3000): Maximum sequence length to use with the `prompt`.
            do_patching (`bool`, *optional*, defaults to `False`): Whether to use patching.
        Examples:
          Returns:
            [`~pipelines.flux.BriaFiboPipelineOutput`] or `tuple`: [`~pipelines.flux.BriaFiboPipelineOutput`] if
            `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
            generated images.
        )r-   r   r   r   r   r/   FNr+   r   scale)	r-   rt   rs   r   ru   r0   r/   r.   rv   rH   c                    s&   g | ]}t j | | gd dqS )r   rH   )rV   r`   )r7   i)r   r   r(   r)   rC   `  s    z-BriaFiboPipeline.__call__.<locals>.<listcomp>rF   r   r;   r!   rw   rD   g      ?)r   r0   r   sigmasmur   )totalrJ   )r_   timestepencoder_hidden_statestext_encoder_layersr   r   txt_idsimg_ids)r   r   r   ru   latentc                    s   g | ]}|   qS r(   r(   )r7   r   )latents_meanlatents_stdr(   r)   rC     s    )r   )axis)images)Lr&   r"   check_inputsr   r   r   rR   rS   r   rU   ri   rQ   r   getr   rV   r`   ranger   transformer_blockssingle_transformer_blocksconfigin_channelsr   r   r1   rj   r0   repeatr   	unsqueezerM   nplinspacer   r   base_image_seq_lenmax_image_seq_len
base_shift	max_shiftr   r   orderr   progress_bar	enumerater   expandchunkrs   stepbackendsmpsis_availablelocalspopupdateXLA_AVAILABLExm	mark_stepr   r   r]   r   r   r   z_dimr   decoder%   postprocesssqueezeappendstackmaybe_free_model_hooksr   )5r'   r-   r   r   r   r   rs   rt   r.   r   r   r   ru   r   r   r   r   r   r/   r   rb   r0   rv   r   r   r   prompt_batch_sizetotal_num_layers_transformerr   r   latent_attention_maskrD   rm   r   r   num_warmup_stepsr   r   tlatent_model_inputr   
noise_prednoise_pred_uncondnoise_pred_textlatents_dtypecallback_kwargskcallback_outputsimagelatents_devicelatents_scaledscaled_latent
curr_imager(   )r   r   r   r   r)   __call__  s^  `	








	
	


6
8

&

zBriaFiboPipeline.__call__c	           	         s^  |d dks|d dkrt d| d| d|d ur8t fdd|D s8t d j d	 fd
d|D  |d urK|d urKt d| d| d|d u rW|d u rWt d|d urnt|tsnt|tsnt dt| |d ur|d urt d| d| d|d ur|d ur|j|jkrt d|j d|j d|d ur|dkrt d| d S d S )Nr    r   z8`height` and `width` have to be divisible by 16 but are z and r~   c                 3   s    | ]}| j v V  qd S r   _callback_tensor_inputsr7   r   r|   r(   r)   r9   $  s    

z0BriaFiboPipeline.check_inputs.<locals>.<genexpr>z2`callback_on_step_end_tensor_inputs` has to be in z, but found c                    s   g | ]	}| j vr|qS r(   r  r  r|   r(   r)   rC   (  s    z1BriaFiboPipeline.check_inputs.<locals>.<listcomp>zCannot forward both `prompt`: z and `prompt_embeds`: z2. Please make sure to only forward one of the two.zeProvide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.z2`prompt` has to be of type `str` or `list` but is z'Cannot forward both `negative_prompt`: z and `negative_prompt_embeds`: zu`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but got: `prompt_embeds` z != `negative_prompt_embeds` rr   z9`max_sequence_length` cannot be greater than 3000 but is )rT   rW   r  rR   rS   r   r   ri   )	r'   r-   r   r   rt   r   ru   r   r/   r(   r|   r)   r     sH   zBriaFiboPipeline.check_inputs)r+   r,   NNr   )Nr+   rq   NNNrr   N)NF)NNNNN)/__name__
__module____qualname____doc__model_cpu_offload_seqr  r
   r   r   r	   r   r   r*   rS   r   r   rV   r0   r1   rh   staticmethodrp   floatFloatTensorr   propertyrs   r   r   r   r   r   r   r   r   r   r   no_gradr   EXAMPLE_DOC_STRING	Generatorr^   dictr   r   r  r   r(   r(   r(   r)   r   L   s*   


<
	

 









%
	
	
  Qr   )/typingr   r   numpyr   rV   transformersr   ,transformers.models.smollm3.modeling_smollm3r   r%   r   loadersr   &models.autoencoders.autoencoder_kl_wanr	   )models.transformers.transformer_bria_fibor
   #pipelines.bria_fibo.pipeline_outputr   pipelines.flux.pipeline_fluxr   r   pipelines.pipeline_utilsr   
schedulersr   r   utilsr   r   r   r   r   r   utils.torch_utilsr   torch_xla.core.xla_modelcore	xla_modelr   r   
get_loggerr  loggerr  r   r(   r(   r(   r)   <module>   s,   
 
