o
    Gi%                     @   s  d dl mZ d dlmZ d dlZd dlZd dlmZ d dl	m
Z
 d dlmZmZ ddlmZ dd	lmZmZ dd
lmZmZ ddlmZmZmZmZ ddlmZ ddlmZ ddlm Z  e rkd dl!m"  m#Z$ dZ%ndZ%e&e'Z(dZ)eG dd deZ*G dd deZ+dS )    )	dataclass)AnyN)Image)tqdm)CLIPTextModelCLIPTokenizer   )PipelineImageInput)AutoencoderKLUNet2DConditionModel)DDIMSchedulerLCMScheduler)
BaseOutputis_torch_xla_availableloggingreplace_example_docstring)randn_tensor   )DiffusionPipeline   )MarigoldImageProcessorTFa  
Examples:
```py
>>> import diffusers
>>> import torch

>>> pipe = diffusers.MarigoldNormalsPipeline.from_pretrained(
...     "prs-eth/marigold-normals-v1-1", variant="fp16", torch_dtype=torch.float16
... ).to("cuda")

>>> image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg")
>>> normals = pipe(image)

>>> vis = pipe.image_processor.visualize_normals(normals.prediction)
>>> vis[0].save("einstein_normals.png")
```
c                   @   sD   e Zd ZU dZejejB ed< dejB ejB ed< dejB ed< dS )MarigoldNormalsOutputug  
    Output class for Marigold monocular normals prediction pipeline.

    Args:
        prediction (`np.ndarray`, `torch.Tensor`):
            Predicted normals with values in the range [-1, 1]. The shape is `numimages × 3 × height × width` for
            `torch.Tensor` or `numimages × height × width × 3` for `np.ndarray`.
        uncertainty (`None`, `np.ndarray`, `torch.Tensor`):
            Uncertainty maps computed from the ensemble, with values in the range [0, 1]. The shape is `numimages × 1 ×
            height × width` for `torch.Tensor` or `numimages × height × width × 1` for `np.ndarray`.
        latent (`None`, `torch.Tensor`):
            Latent features corresponding to the predictions, compatible with the `latents` argument of the pipeline.
            The shape is `numimages * numensemble × 4 × latentheight × latentwidth`.
    
predictionNuncertaintylatent)	__name__
__module____qualname____doc__npndarraytorchTensor__annotations__ r$   r$   j/home/ubuntu/.local/lib/python3.10/site-packages/diffusers/pipelines/marigold/pipeline_marigold_normals.pyr   M   s
   
 r   c                !       s,  e Zd ZdZdZdZ				d9dededee	B d	e
d
ededB dedB dedB dedB f fddZdedededededededeeef dB dejdB dejeej B dB dededefddZejjd:d d!Ze ee		"			#	#	"				$	%	%	d;dededB dededB d&ededededeeef dB dejeej B dB dejeej B dB deded'ed(efd)d*ZdejdejdB dejdB dededeejejf fd+d,Z d-ejdejfd.d/Z!e"d<d1ejd2e#dejfd3d4Z$e"	5d=d1ejded6edeejejdB f fd7d8Z%  Z&S )>MarigoldNormalsPipelinea0	  
    Pipeline for monocular normals estimation using the Marigold method: https://marigoldmonodepth.github.io.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        unet (`UNet2DConditionModel`):
            Conditional U-Net to denoise the normals latent, conditioned on image latent.
        vae (`AutoencoderKL`):
            Variational Auto-Encoder (VAE) Model to encode and decode images and predictions to and from latent
            representations.
        scheduler (`DDIMScheduler` or `LCMScheduler`):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents.
        text_encoder (`CLIPTextModel`):
            Text-encoder, for empty text embedding.
        tokenizer (`CLIPTokenizer`):
            CLIP tokenizer.
        prediction_type (`str`, *optional*):
            Type of predictions made by the model.
        use_full_z_range (`bool`, *optional*):
            Whether the normals predicted by this model utilize the full range of the Z dimension, or only its positive
            half.
        default_denoising_steps (`int`, *optional*):
            The minimum number of denoising diffusion steps that are required to produce a prediction of reasonable
            quality with the given model. This value must be set in the model config. When the pipeline is called
            without explicitly setting `num_inference_steps`, the default value is used. This is required to ensure
            reasonable results with various model flavors compatible with the pipeline, such as those relying on very
            short denoising schedules (`LCMScheduler`) and those with full diffusion schedules (`DDIMScheduler`).
        default_processing_resolution (`int`, *optional*):
            The recommended value of the `processing_resolution` parameter of the pipeline. This value must be set in
            the model config. When the pipeline is called without explicitly setting `processing_resolution`, the
            default value is used. This is required to ensure reasonable results with various model flavors trained
            with varying optimal processing resolution values.
    ztext_encoder->unet->vae)normalsNTunetvae	schedulertext_encoder	tokenizerprediction_typeuse_full_z_rangedefault_denoising_stepsdefault_processing_resolutionc
           
         s   t    || jvrtd| d| j d | j|||||d | j||||	d t| dd r;dt| j	j
jd  nd	| _|| _|| _|	| _d | _t| jd
| _d S )Nz*Potentially unsupported `prediction_type='z&'`; values supported by the pipeline: .)r(   r)   r*   r+   r,   )r-   r.   r/   r0   r)   r   r      )vae_scale_factor)super__init__supported_prediction_typesloggerwarningregister_modulesregister_to_configgetattrlenr)   configblock_out_channelsr3   r.   r/   r0   empty_text_embeddingr   image_processor)
selfr(   r)   r*   r+   r,   r-   r.   r/   r0   	__class__r$   r%   r5      s4   

(z MarigoldNormalsPipeline.__init__imagenum_inference_stepsensemble_sizeprocessing_resolutionresample_method_inputresample_method_output
batch_sizeensembling_kwargslatents	generatoroutput_typeoutput_uncertaintyreturnc              	      s  dt | jjjd  }|| jkrtd| j d| d|d u r$td|dk r,td|dk r4td|dkr=td	 |dkrG|rGtd
|d u rOtd|dk rWtd|| j dkrgtd| j d|dvrotd|dvrwtd|dk rtd|dvrtd|	d ur d urtd|d urt|t	stdd|v r|d dvrtdd}d\}}t|t
s|g}t|D ]{\}}t|tjst|r|jdvrtd| d|j d|jdd  \}}d}|jd kr|jd }nt|tjr|j\}}d}ntd!| d"t| d|d u r||}}n||f||fkr7td#| d$||f d%||f ||7 }q|	d urt|	sKtd&|	 d kr[td'|	j d|dkrt||}|| | }|| | }|dks{|dkrtd(| d)| d*||}}|| j d | j }|| j d | j }|| | jjj||f}|	j|krtd+|	j d,| d d urt t
rt  || krtd-t fd.d/ D std0|S t tjstd1t  d|S )2Nr   r   z/`vae_scale_factor` computed at initialization (z) differs from the actual one (z).zW`num_inference_steps` is not specified and could not be resolved from the model config.z'`num_inference_steps` must be positive.z!`ensemble_size` must be positive.zk`ensemble_size` == 2 results are similar to no ensembling (1); consider increasing the value to at least 3.zpComputing uncertainty by setting `output_uncertainty=True` also requires setting `ensemble_size` greater than 1.zY`processing_resolution` is not specified and could not be resolved from the model config.r   zx`processing_resolution` must be non-negative: 0 for native resolution, or any positive value for downsampled processing.z.`processing_resolution` must be a multiple of r1   )nearestznearest-exactbilinearbicubicareazy`resample_method_input` takes string values compatible with PIL library: nearest, nearest-exact, bilinear, bicubic, area.zz`resample_method_output` takes string values compatible with PIL library: nearest, nearest-exact, bilinear, bicubic, area.z`batch_size` must be positive.)ptr   z*`output_type` must be one of `pt` or `np`.z2`latents` and `generator` cannot be used together.z)`ensembling_kwargs` must be a dictionary.	reductionclosestmeanzG`ensembling_kwargs['reduction']` can be either `'closest'` or `'mean'`.)NN)r   r      z`image[z(]` has unsupported dimensions or shape: rZ   zUnsupported `image[z	]` type: zInput `image[z]` has incompatible dimensions z with the previous images z!`latents` must be a torch.Tensor.z/`latents` has unsupported dimensions or shape: z*Extreme aspect ratio of the input image: [z x ]z`latents` has unexpected shape=z
 expected=z^The number of generators must match the total number of ensemble members for all input images.c                 3   s$    | ]}|j j d  j jkV  qdS )r   N)devicetype).0grM   r$   r%   	<genexpr>2  s   " z7MarigoldNormalsPipeline.check_inputs.<locals>.<genexpr>z;`generator` device placement is not consistent in the list.zUnsupported generator type: )r<   r)   r=   r>   r3   
ValueErrorr7   r8   
isinstancedictlist	enumerater   r    r!   	is_tensorndimshaper   sizer^   dimmaxlatent_channelsall	Generator)rA   rD   rE   rF   rG   rH   rI   rJ   rK   rL   rM   rN   rO   actual_vae_scale_factor
num_imagesWHiimgH_iW_iN_imax_orignew_Hnew_Wwhshape_expectedr$   ra   r%   check_inputs   s   













z$MarigoldNormalsPipeline.check_inputsc                 C   s   t | ds	i | _nt| jtstdt| j dtdi | j}|d||d< |d||d< |d ur>t|fi |S |d urKtdd|i|S td)	N_progress_bar_configz=`self._progress_bar_config` should be of type `dict`, but is r1   descleavetotalz/Either `total` or `iterable` has to be defined.r$   )hasattrr   rd   re   rc   r^   getr   )rA   iterabler   r   r   progress_bar_configr$   r$   r%   progress_bar9  s   
z$MarigoldNormalsPipeline.progress_barr   rR   r   Fmatch_input_resolutionoutput_latentreturn_dictc           $         s0  j }j}|du rj}|du rj}|||||| |
||}jdu rEd}j|djjddd}|j	|}
|d _j|||||\}}}||
|| \}~jj	||d d	d	}g }jtd||  dd
dD ]\}|||   }||   }|jd }|d| }jj||d jjjdddD ]'} tj||gd	d}!j|!| |ddd }"jj|"| ||dj}trt  q|| q{tj|dd~~~~~~~!~"tj fddtdjd  D dd|sdj|d}#|d	krHj||gjd	d R  fddt|D t  \}#tjddrFtj|#dd}#nd}#|rljj!||dd"|#durlrljj!|#||dd}#|dkrj#|#durrj#|#}#$  |s|#fS t%|#dS )a^  
        Function invoked when calling the pipeline.

        Args:
            image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`),
                `list[torch.Tensor]`: An input image or images used as an input for the normals estimation task. For
                arrays and tensors, the expected value range is between `[0, 1]`. Passing a batch of images is possible
                by providing a four-dimensional array or a tensor. Additionally, a list of images of two- or
                three-dimensional arrays or tensors can be passed. In the latter case, all list elements must have the
                same width and height.
            num_inference_steps (`int`, *optional*, defaults to `None`):
                Number of denoising diffusion steps during inference. The default value `None` results in automatic
                selection.
            ensemble_size (`int`, defaults to `1`):
                Number of ensemble predictions. Higher values result in measurable improvements and visual degradation.
            processing_resolution (`int`, *optional*, defaults to `None`):
                Effective processing resolution. When set to `0`, matches the larger input image dimension. This
                produces crisper predictions, but may also lead to the overall loss of global context. The default
                value `None` resolves to the optimal value from the model config.
            match_input_resolution (`bool`, *optional*, defaults to `True`):
                When enabled, the output prediction is resized to match the input dimensions. When disabled, the longer
                side of the output will equal to `processing_resolution`.
            resample_method_input (`str`, *optional*, defaults to `"bilinear"`):
                Resampling method used to resize input images to `processing_resolution`. The accepted values are:
                `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`.
            resample_method_output (`str`, *optional*, defaults to `"bilinear"`):
                Resampling method used to resize output predictions to match the input resolution. The accepted values
                are `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`.
            batch_size (`int`, *optional*, defaults to `1`):
                Batch size; only matters when setting `ensemble_size` or passing a tensor of images.
            ensembling_kwargs (`dict`, *optional*, defaults to `None`)
                Extra dictionary with arguments for precise ensembling control. The following options are available:
                - reduction (`str`, *optional*, defaults to `"closest"`): Defines the ensembling function applied in
                  every pixel location, can be either `"closest"` or `"mean"`.
            latents (`torch.Tensor`, *optional*, defaults to `None`):
                Latent noise tensors to replace the random initialization. These can be taken from the previous
                function call's output.
            generator (`torch.Generator`, or `list[torch.Generator]`, *optional*, defaults to `None`):
                Random number generator object to ensure reproducibility.
            output_type (`str`, *optional*, defaults to `"np"`):
                Preferred format of the output's `prediction` and the optional `uncertainty` fields. The accepted
                values are: `"np"` (numpy array) or `"pt"` (torch tensor).
            output_uncertainty (`bool`, *optional*, defaults to `False`):
                When enabled, the output's `uncertainty` field contains the predictive uncertainty map, provided that
                the `ensemble_size` argument is set to a value above 2.
            output_latent (`bool`, *optional*, defaults to `False`):
                When enabled, the output's `latent` field contains the latent codes corresponding to the predictions
                within the ensemble. These codes can be saved, modified, and used for subsequent calls with the
                `latents` argument.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.marigold.MarigoldNormalsOutput`] instead of a plain tuple.

        Examples:

        Returns:
            [`~pipelines.marigold.MarigoldNormalsOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.marigold.MarigoldNormalsOutput`] is returned, otherwise a
                `tuple` is returned where the first element is the prediction, the second element is the uncertainty
                (or `None`), and the third is the latent (or `None`).
        N 
do_not_padTrU   )padding
max_length
truncationreturn_tensorsr   )r]   dtyper   zMarigold predictions...)r   r   )r]   FzDiffusion steps...rl   )encoder_hidden_statesr   ra   c                    s"   g | ]} ||   qS r$   )decode_predictionr_   ru   )rJ   pred_latentrA   r$   r%   
<listcomp>  s    z4MarigoldNormalsPipeline.__call__.<locals>.<listcomp>c                    s(   g | ]}j | fi  pi qS r$   )ensemble_normalsr   )rK   rO   r   rA   r$   r%   r   &      )is_aar   )r   r   r   )&_execution_devicer   r/   r0   r   r?   r,   model_max_length	input_idstor+   r@   
preprocessprepare_latentsrepeatr   rangerj   r*   set_timesteps	timestepsr!   catr(   stepprev_sampleXLA_AVAILABLExm	mark_stepappendunpad_imagereshapezipresize_antialiasnormalize_normalspt_to_numpymaybe_free_model_hooksr   )$rA   rD   rE   rF   rG   r   rH   rI   rJ   rK   rL   rM   rN   rO   r   r   r]   r   rr   prompttext_inputstext_input_idsr   original_resolutionimage_latentbatch_empty_text_embeddingpred_latentsru   batch_image_latentbatch_pred_latenteffective_batch_sizetexttbatch_latentnoiser   r$   )rJ   rK   rO   r   r   rA   r%   __call__L  s   R
	









z MarigoldNormalsPipeline.__call__c                    s~   dd t j fddtdjd  D dd}|jjj }|j|dd}|}|d u r;t|j||j	|j
d}||fS )Nc                 S   s,   t | dr
| j S t | dr| jS td)Nlatent_distrL   z3Could not access latents of provided encoder_output)r   r   moderL   AttributeError)encoder_outputr$   r$   r%   retrieve_latents[  s
   


zAMarigoldNormalsPipeline.prepare_latents.<locals>.retrieve_latentsc              	      s(   g | ]}j ||   qS r$   )r)   encoder   rJ   rD   r   rA   r$   r%   r   d  r   z;MarigoldNormalsPipeline.prepare_latents.<locals>.<listcomp>r   r   )rM   r]   r   )r!   r   r   rj   r)   r=   scaling_factorrepeat_interleaver   r]   r   )rA   rD   rL   rM   rF   rJ   r   r   r$   r   r%   r   S  s$   z'MarigoldNormalsPipeline.prepare_latentsr   c                 C   s   |  dks|jd | jjjkrtd| jjj d|j d| jj|| jjj ddd }t	|d	d
}| j
s\|d d dd d d d f  d9  < |d d dd d d d f  d7  < | |}|S )NrZ   r   z Expecting 4D tensor of shape [B,z,H,W]; got r1   F)r   r   g      g      ?r   g      ?)rl   rj   r)   r=   rn   rc   decoder   r!   clipr.   r   )rA   r   r   r$   r$   r%   r   x  s    $$
z)MarigoldNormalsPipeline.decode_predictionư>r'   epsc                 C   sP   |   dks| jd dkrtd| j dtj| ddd}| |j|d } | S )	NrZ   r   r   ,Expecting 4D tensor of shape [B,3,H,W]; got r1   Trl   keepdim)min)rl   rj   rc   r!   normclamp)r'   r   r   r$   r$   r%   r     s
   z)MarigoldNormalsPipeline.normalize_normalsrX   rV   c                 C   s   |   dks| jd dkrtd| j d|dvr"td| d| jdd	d
}t|}||  jdd	d
}|dd}d}|rO| }|jdd	d
t	j
 }|dkrW||fS |jdd	d
}|dddd}t| d|}||fS )a2  
        Ensembles the normals maps represented by the `normals` tensor with expected shape `(B, 3, H, W)`, where B is
        the number of ensemble members for a given prediction of size `(H x W)`.

        Args:
            normals (`torch.Tensor`):
                Input ensemble normals maps.
            output_uncertainty (`bool`, *optional*, defaults to `False`):
                Whether to output uncertainty map.
            reduction (`str`, *optional*, defaults to `"closest"`):
                Reduction method used to ensemble aligned predictions. The accepted values are: `"closest"` and
                `"mean"`.

        Returns:
            A tensor of aligned and ensembled normals maps with shape `(1, 3, H, W)` and optionally a tensor of
            uncertainties of shape `(1, 1, H, W)`.
        rZ   r   r   r   r1   rW   zUnrecognized reduction method: r   Tr   NrY   )rl   rj   rc   rY   r&   r   sumr   arccosr   piargmaxr   r!   gather)r'   rO   rV   mean_normalssim_cosr   closest_indicesclosest_normalsr$   r$   r%   r     s$   
z(MarigoldNormalsPipeline.ensemble_normals)NTNN)NNNT)Nr   NTrR   rR   r   NNNr   FFT)r   )rX   )'r   r   r   r   model_cpu_offload_seqr6   r   r
   r   r   r   r   strboolintr5   r	   re   r   r!   r"   rp   rf   r   compilerdisabler   no_gradr   EXAMPLE_DOC_STRINGr   tupler   r   staticmethodfloatr   r   __classcell__r$   r$   rB   r%   r&   c   s
   $		
,	

 	
  
%	r&   ),dataclassesr   typingr   numpyr   r!   PILr   	tqdm.autor   transformersr   r   r@   r	   modelsr
   r   
schedulersr   r   utilsr   r   r   r   utils.torch_utilsr   pipeline_utilsr   marigold_image_processingr   torch_xla.core.xla_modelcore	xla_modelr   r   
get_loggerr   r7   r   r   r&   r$   r$   r$   r%   <module>   s.   
