o
    ۷i)                     @   s(  d dl mZ d dlmZ d dlmZ d dlZd dlZd dl	m
Z
 d dlmZ d dlmZmZ dd	lmZ dd
lmZmZ ddlmZmZ ddlmZmZmZmZ ddlmZ ddlm Z  ddl!m"Z" ddl#m$Z$ e rwd dl%m&  m'Z( dZ)ndZ)e*e+Z,dZ-eG dd deZ.G dd de"Z/dS )    )	dataclass)partial)AnyN)Image)tqdm)CLIPTextModelCLIPTokenizer   )PipelineImageInput)AutoencoderKLUNet2DConditionModel)DDIMSchedulerLCMScheduler)
BaseOutputis_torch_xla_availableloggingreplace_example_docstring)is_scipy_available)randn_tensor   )DiffusionPipeline   )MarigoldImageProcessorTFaE  
Examples:
```py
>>> import diffusers
>>> import torch

>>> pipe = diffusers.MarigoldDepthPipeline.from_pretrained(
...     "prs-eth/marigold-depth-v1-1", variant="fp16", torch_dtype=torch.float16
... ).to("cuda")

>>> image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg")
>>> depth = pipe(image)

>>> vis = pipe.image_processor.visualize_depth(depth.prediction)
>>> vis[0].save("einstein_depth.png")

>>> depth_16bit = pipe.image_processor.export_depth_to_16bit_png(depth.prediction)
>>> depth_16bit[0].save("einstein_depth_16bit.png")
```
c                   @   sD   e Zd ZU dZejejB ed< dejB ejB ed< dejB ed< dS )MarigoldDepthOutputug  
    Output class for Marigold monocular depth prediction pipeline.

    Args:
        prediction (`np.ndarray`, `torch.Tensor`):
            Predicted depth maps with values in the range [0, 1]. The shape is `numimages × 1 × height × width` for
            `torch.Tensor` or `numimages × height × width × 1` for `np.ndarray`.
        uncertainty (`None`, `np.ndarray`, `torch.Tensor`):
            Uncertainty maps computed from the ensemble, with values in the range [0, 1]. The shape is `numimages × 1 ×
            height × width` for `torch.Tensor` or `numimages × height × width × 1` for `np.ndarray`.
        latent (`None`, `torch.Tensor`):
            Latent features corresponding to the predictions, compatible with the `latents` argument of the pipeline.
            The shape is `numimages * numensemble × 4 × latentheight × latentwidth`.
    
predictionNuncertaintylatent)	__name__
__module____qualname____doc__npndarraytorchTensor__annotations__ r&   r&   j/home/ubuntu/vllm_env/lib/python3.10/site-packages/diffusers/pipelines/marigold/pipeline_marigold_depth.pyr   R   s
   
 r   c                !       s<  e Zd ZdZdZdZ					d>dededee	B d	e
d
ededB dedB dedB dedB dedB f fddZdedededededededeeef dB dejdB dejeej B dB dededefdd Zejjd?d!d"Ze ee		#			$	$	#				%	&	&	d@dededB dededB d'ededededeeef dB dejeej B dB dejeej B dB deded(ed)efd*d+ZdejdejdB dejdB dededeejejf fd,d-Z d.ejdejfd/d0Z!e"			&	1	2	3	4	5dAd6ejdededed7ed8e#d9ed:e#d;edeejejdB f fd<d=Z$  Z%S )BMarigoldDepthPipelinea7  
    Pipeline for monocular depth estimation using the Marigold method: https://marigoldmonodepth.github.io.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        unet (`UNet2DConditionModel`):
            Conditional U-Net to denoise the depth latent, conditioned on image latent.
        vae (`AutoencoderKL`):
            Variational Auto-Encoder (VAE) Model to encode and decode images and predictions to and from latent
            representations.
        scheduler (`DDIMScheduler` or `LCMScheduler`):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents.
        text_encoder (`CLIPTextModel`):
            Text-encoder, for empty text embedding.
        tokenizer (`CLIPTokenizer`):
            CLIP tokenizer.
        prediction_type (`str`, *optional*):
            Type of predictions made by the model.
        scale_invariant (`bool`, *optional*):
            A model property specifying whether the predicted depth maps are scale-invariant. This value must be set in
            the model config. When used together with the `shift_invariant=True` flag, the model is also called
            "affine-invariant". NB: overriding this value is not supported.
        shift_invariant (`bool`, *optional*):
            A model property specifying whether the predicted depth maps are shift-invariant. This value must be set in
            the model config. When used together with the `scale_invariant=True` flag, the model is also called
            "affine-invariant". NB: overriding this value is not supported.
        default_denoising_steps (`int`, *optional*):
            The minimum number of denoising diffusion steps that are required to produce a prediction of reasonable
            quality with the given model. This value must be set in the model config. When the pipeline is called
            without explicitly setting `num_inference_steps`, the default value is used. This is required to ensure
            reasonable results with various model flavors compatible with the pipeline, such as those relying on very
            short denoising schedules (`LCMScheduler`) and those with full diffusion schedules (`DDIMScheduler`).
        default_processing_resolution (`int`, *optional*):
            The recommended value of the `processing_resolution` parameter of the pipeline. This value must be set in
            the model config. When the pipeline is called without explicitly setting `processing_resolution`, the
            default value is used. This is required to ensure reasonable results with various model flavors trained
            with varying optimal processing resolution values.
    ztext_encoder->unet->vae)depth	disparityNTunetvae	schedulertext_encoder	tokenizerprediction_typescale_invariantshift_invariantdefault_denoising_stepsdefault_processing_resolutionc                    s   t    || jvrtd| d| j d | j|||||d | j||||	|
d t| dd r<dt| j	j
jd  nd	| _|| _|| _|	| _|
| _d | _t| jd
| _d S )Nz*Potentially unsupported `prediction_type='z&'`; values supported by the pipeline: .)r+   r,   r-   r.   r/   )r0   r1   r2   r3   r4   r,   r   r      )vae_scale_factor)super__init__supported_prediction_typesloggerwarningregister_modulesregister_to_configgetattrlenr,   configblock_out_channelsr7   r1   r2   r3   r4   empty_text_embeddingr   image_processor)selfr+   r,   r-   r.   r/   r0   r1   r2   r3   r4   	__class__r&   r'   r9      s8   

(zMarigoldDepthPipeline.__init__imagenum_inference_stepsensemble_sizeprocessing_resolutionresample_method_inputresample_method_output
batch_sizeensembling_kwargslatents	generatoroutput_typeoutput_uncertaintyreturnc              	      s  dt | jjjd  }|| jkrtd| j d| d|d u r$td|dk r,td|dk r4td|dkr=td	 |dkrN| jsG| j	rNt
 sNtd
|dkrX|rXtd|d u r`td|dk rhtd|| j dkrxtd| j d|dvrtd|dvrtd|dk rtd|dvrtd|	d ur d urtd|d urt|tstdd|v r|d dvrtdd}d\}}t|ts|g}t|D ]}\}}t|tjst|r	|jdvrtd| d|j d|jd d  \}}d}|jd!kr|jd }nt|tjr|j\}}d}ntd"| d#t| d|d u r0||}}n||f||fkrJtd$| d%||f d&||f ||7 }q|	d urt|	s^td'|	 d!krntd(|	j d|dkrt||}|| | }|| | }|dks|dkrtd)| d*| d+||}}|| j d | j }|| j d | j }|| | jjj||f}|	j|krtd,|	j d-| d d ur
t trt  || krtd.t fd/d0 D std1|S t tjs
td2t  d|S )3Nr   r   z/`vae_scale_factor` computed at initialization (z) differs from the actual one (z).zW`num_inference_steps` is not specified and could not be resolved from the model config.z'`num_inference_steps` must be positive.z!`ensemble_size` must be positive.zk`ensemble_size` == 2 results are similar to no ensembling (1); consider increasing the value to at least 3.z9Make sure to install scipy if you want to use ensembling.zpComputing uncertainty by setting `output_uncertainty=True` also requires setting `ensemble_size` greater than 1.zY`processing_resolution` is not specified and could not be resolved from the model config.r   zx`processing_resolution` must be non-negative: 0 for native resolution, or any positive value for downsampled processing.z.`processing_resolution` must be a multiple of r5   )nearestnearest-exactbilinearbicubicareazy`resample_method_input` takes string values compatible with PIL library: nearest, nearest-exact, bilinear, bicubic, area.zz`resample_method_output` takes string values compatible with PIL library: nearest, nearest-exact, bilinear, bicubic, area.z`batch_size` must be positive.)ptr!   z*`output_type` must be one of `pt` or `np`.z2`latents` and `generator` cannot be used together.z)`ensembling_kwargs` must be a dictionary.	reductionmeanmedianzF`ensembling_kwargs['reduction']` can be either `'mean'` or `'median'`.)NN)r   r	      z`image[z(]` has unsupported dimensions or shape: r_   zUnsupported `image[z	]` type: zInput `image[z]` has incompatible dimensions z with the previous images z!`latents` must be a torch.Tensor.z/`latents` has unsupported dimensions or shape: z*Extreme aspect ratio of the input image: [z x ]z`latents` has unexpected shape=z
 expected=z^The number of generators must match the total number of ensemble members for all input images.c                 3   s$    | ]}|j j d  j jkV  qdS )r   N)devicetype).0grQ   r&   r'   	<genexpr>A  s   " z5MarigoldDepthPipeline.check_inputs.<locals>.<genexpr>z;`generator` device placement is not consistent in the list.zUnsupported generator type: )r@   r,   rA   rB   r7   
ValueErrorr;   r<   r1   r2   r   ImportError
isinstancedictlist	enumerater!   r"   r#   	is_tensorndimshaper   sizerc   dimmaxlatent_channelsall	Generator)rE   rH   rI   rJ   rK   rL   rM   rN   rO   rP   rQ   rR   rS   actual_vae_scale_factor
num_imagesWHiimgH_iW_iN_imax_orignew_Hnew_Wwhshape_expectedr&   rf   r'   check_inputs   s   












z"MarigoldDepthPipeline.check_inputsc                 C   s   t | ds	i | _nt| jtstdt| j dtdi | j}|d||d< |d||d< |d ur>t|fi |S |d urKtdd|i|S td)	N_progress_bar_configz=`self._progress_bar_config` should be of type `dict`, but is r5   descleavetotalz/Either `total` or `iterable` has to be defined.r&   )hasattrr   rj   rk   rh   rc   getr   )rE   iterabler   r   r   progress_bar_configr&   r&   r'   progress_barH  s   
z"MarigoldDepthPipeline.progress_barr   rW   r!   Fmatch_input_resolutionoutput_latentreturn_dictc           $         s&  j }j}|du rj}|du rj}|||||| |
||}jdu rEd}j|djjddd}|j	|}
|d _j|||||\}}}||
|| \}~jj	||d d	d	}g }jtd||  dd
dD ]\}|||   }||   }|jd }|d| }jj||d jjjdddD ]'} tj||gd	d}!j|!| |ddd }"jj|"| ||dj}trt  q|| q{tj|dd~~~~~~~!~"tj fddtdjd  D dd|sdj|d}#|d	krHj||gjd	d R  fddt|D t  \}#tjddrFtj|#dd}#nd}#|rgjj!||dd|#durgrgjj!|#||dd}#|dkrj"|#durrj"|#}##  |s|#fS t$|#dS )aA  
        Function invoked when calling the pipeline.

        Args:
            image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`),
                `list[torch.Tensor]`: An input image or images used as an input for the depth estimation task. For
                arrays and tensors, the expected value range is between `[0, 1]`. Passing a batch of images is possible
                by providing a four-dimensional array or a tensor. Additionally, a list of images of two- or
                three-dimensional arrays or tensors can be passed. In the latter case, all list elements must have the
                same width and height.
            num_inference_steps (`int`, *optional*, defaults to `None`):
                Number of denoising diffusion steps during inference. The default value `None` results in automatic
                selection.
            ensemble_size (`int`, defaults to `1`):
                Number of ensemble predictions. Higher values result in measurable improvements and visual degradation.
            processing_resolution (`int`, *optional*, defaults to `None`):
                Effective processing resolution. When set to `0`, matches the larger input image dimension. This
                produces crisper predictions, but may also lead to the overall loss of global context. The default
                value `None` resolves to the optimal value from the model config.
            match_input_resolution (`bool`, *optional*, defaults to `True`):
                When enabled, the output prediction is resized to match the input dimensions. When disabled, the longer
                side of the output will equal to `processing_resolution`.
            resample_method_input (`str`, *optional*, defaults to `"bilinear"`):
                Resampling method used to resize input images to `processing_resolution`. The accepted values are:
                `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`.
            resample_method_output (`str`, *optional*, defaults to `"bilinear"`):
                Resampling method used to resize output predictions to match the input resolution. The accepted values
                are `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`.
            batch_size (`int`, *optional*, defaults to `1`):
                Batch size; only matters when setting `ensemble_size` or passing a tensor of images.
            ensembling_kwargs (`dict`, *optional*, defaults to `None`)
                Extra dictionary with arguments for precise ensembling control. The following options are available:
                - reduction (`str`, *optional*, defaults to `"median"`): Defines the ensembling function applied in
                  every pixel location, can be either `"median"` or `"mean"`.
                - regularizer_strength (`float`, *optional*, defaults to `0.02`): Strength of the regularizer that
                  pulls the aligned predictions to the unit range from 0 to 1.
                - max_iter (`int`, *optional*, defaults to `2`): Maximum number of the alignment solver steps. Refer to
                  `scipy.optimize.minimize` function, `options` argument.
                - tol (`float`, *optional*, defaults to `1e-3`): Alignment solver tolerance. The solver stops when the
                  tolerance is reached.
                - max_res (`int`, *optional*, defaults to `None`): Resolution at which the alignment is performed;
                  `None` matches the `processing_resolution`.
            latents (`torch.Tensor`, or `list[torch.Tensor]`, *optional*, defaults to `None`):
                Latent noise tensors to replace the random initialization. These can be taken from the previous
                function call's output.
            generator (`torch.Generator`, or `list[torch.Generator]`, *optional*, defaults to `None`):
                Random number generator object to ensure reproducibility.
            output_type (`str`, *optional*, defaults to `"np"`):
                Preferred format of the output's `prediction` and the optional `uncertainty` fields. The accepted
                values are: `"np"` (numpy array) or `"pt"` (torch tensor).
            output_uncertainty (`bool`, *optional*, defaults to `False`):
                When enabled, the output's `uncertainty` field contains the predictive uncertainty map, provided that
                the `ensemble_size` argument is set to a value above 2.
            output_latent (`bool`, *optional*, defaults to `False`):
                When enabled, the output's `latent` field contains the latent codes corresponding to the predictions
                within the ensemble. These codes can be saved, modified, and used for subsequent calls with the
                `latents` argument.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.marigold.MarigoldDepthOutput`] instead of a plain tuple.

        Examples:

        Returns:
            [`~pipelines.marigold.MarigoldDepthOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.marigold.MarigoldDepthOutput`] is returned, otherwise a
                `tuple` is returned where the first element is the prediction, the second element is the uncertainty
                (or `None`), and the third is the latent (or `None`).
        N 
do_not_padTrZ   )padding
max_length
truncationreturn_tensorsr   )rb   dtyper   zMarigold predictions...)r   r   )rb   FzDiffusion steps...rr   )encoder_hidden_statesr   rf   c                    s"   g | ]} ||   qS r&   )decode_predictionrd   r{   )rN   pred_latentrE   r&   r'   
<listcomp>(  s    z2MarigoldDepthPipeline.__call__.<locals>.<listcomp>c                    s0   g | ]}j | jjfi  pi qS r&   )ensemble_depthr1   r2   r   )rO   rS   r   rE   r&   r'   r   =  s    )is_aar!   )r   r   r   )%_execution_devicer   r3   r4   r   rC   r/   model_max_length	input_idstor.   rD   
preprocessprepare_latentsrepeatr   rangerp   r-   set_timesteps	timestepsr#   catr+   stepprev_sampleXLA_AVAILABLExm	mark_stepappendunpad_imagereshapezipresize_antialiaspt_to_numpymaybe_free_model_hooksr   )$rE   rH   rI   rJ   rK   r   rL   rM   rN   rO   rP   rQ   rR   rS   r   r   rb   r   rx   prompttext_inputstext_input_idsr   original_resolutionimage_latentbatch_empty_text_embeddingpred_latentsr{   batch_image_latentbatch_pred_latenteffective_batch_sizetexttbatch_latentnoiser   r&   )rN   rO   rS   r   r   rE   r'   __call__[  s   Z
	









zMarigoldDepthPipeline.__call__c                    s~   dd t j fddtdjd  D dd}|jjj }|j|dd}|}|d u r;t|j||j	|j
d}||fS )Nc                 S   s,   t | dr
| j S t | dr| jS td)Nlatent_distrP   z3Could not access latents of provided encoder_output)r   r   moderP   AttributeError)encoder_outputr&   r&   r'   retrieve_latentsu  s
   


z?MarigoldDepthPipeline.prepare_latents.<locals>.retrieve_latentsc              	      s(   g | ]}j ||   qS r&   )r,   encoder   rN   rH   r   rE   r&   r'   r   ~  s    z9MarigoldDepthPipeline.prepare_latents.<locals>.<listcomp>r   r   )rQ   rb   r   )r#   r   r   rp   r,   rA   scaling_factorrepeat_interleaver   rb   r   )rE   rH   rP   rQ   rJ   rN   r   r   r&   r   r'   r   m  s$   z%MarigoldDepthPipeline.prepare_latentsr   c                 C   s   |  dks|jd | jjjkrtd| jjj d|j d| jj|| jjj ddd }|jdd	d
}t	
|dd}|d d }|S )Nr_   r   z Expecting 4D tensor of shape [B,z,H,W]; got r5   F)r   r   Trr   keepdimg            ?g       @)rr   rp   r,   rA   rt   rh   decoder   r]   r#   clip)rE   r   r   r&   r&   r'   r     s    z'MarigoldDepthPipeline.decode_predictionr^   {Gz?r   MbP?   r)   r[   regularizer_strengthmax_itertolmax_resc	              
      s  |   dks| jd dkrtd| j ddvr"td d	s*
r*tddtjf	
fd	d
dtjdtjdtjf	
fdd 	ddtjdtdttjtjdB f ffdddtjdtjdt	f fdddtjffdd}		p
}
| jd |
r|	| } | |} | |d\} }| 
 }	r
r|  }n		rd}ntd|| jdd}| | | } |r|| }| |fS )a	  
        Ensembles the depth maps represented by the `depth` tensor with expected shape `(B, 1, H, W)`, where B is the
        number of ensemble members for a given prediction of size `(H x W)`. Even though the function is designed for
        depth maps, it can also be used with disparity maps as long as the input tensor values are non-negative. The
        alignment happens when the predictions have one or more degrees of freedom, that is when they are either
        affine-invariant (`scale_invariant=True` and `shift_invariant=True`), or just scale-invariant (only
        `scale_invariant=True`). For absolute predictions (`scale_invariant=False` and `shift_invariant=False`)
        alignment is skipped and only ensembling is performed.

        Args:
            depth (`torch.Tensor`):
                Input ensemble depth maps.
            scale_invariant (`bool`, *optional*, defaults to `True`):
                Whether to treat predictions as scale-invariant.
            shift_invariant (`bool`, *optional*, defaults to `True`):
                Whether to treat predictions as shift-invariant.
            output_uncertainty (`bool`, *optional*, defaults to `False`):
                Whether to output uncertainty map.
            reduction (`str`, *optional*, defaults to `"median"`):
                Reduction method used to ensemble aligned predictions. The accepted values are: `"mean"` and
                `"median"`.
            regularizer_strength (`float`, *optional*, defaults to `0.02`):
                Strength of the regularizer that pulls the aligned predictions to the unit range from 0 to 1.
            max_iter (`int`, *optional*, defaults to `2`):
                Maximum number of the alignment solver steps. Refer to `scipy.optimize.minimize` function, `options`
                argument.
            tol (`float`, *optional*, defaults to `1e-3`):
                Alignment solver tolerance. The solver stops when the tolerance is reached.
            max_res (`int`, *optional*, defaults to `1024`):
                Resolution at which the alignment is performed; `None` matches the `processing_resolution`.
        Returns:
            A tensor of aligned and ensembled depth maps and optionally a tensor of uncertainties of the same shape:
            `(1, 1, H, W)`.
        r_   r   z,Expecting 4D tensor of shape [B,1,H,W]; got r5   r\   Unrecognized reduction method: z1Pure shift-invariant ensembling is not supported.r)   c                    s   |   djddj}|   djddj}r5r5d|| jdd }| | }t||f  }nrFd|jdd }|  }nt	d|
tj}|S )Nr   r   r   ư>minUnrecognized alignment.)r   r   valuesrs   clampr#   r   cpunumpyrh   astyper!   float64)r)   init_mininit_maxinit_sinit_tparamrJ   r1   r2   r&   r'   
init_param  s   
z8MarigoldDepthPipeline.ensemble_depth.<locals>.init_paramr   rT   c                    s   r0r0t |d\}}t||  ddd}t||  ddd}| | | }|S rFt||  ddd}| | }|S td)Nr   r   r   )r!   splitr#   
from_numpyr   viewrh   )r)   r   sr   outr   r&   r'   align  s   z3MarigoldDepthPipeline.ensemble_depth.<locals>.alignFdepth_alignedreturn_uncertaintyNc                    s   d } dkrt j| ddd}|rt j| ddd}||fS  dkr=t j| dddj}|r9t jt | | dddj}||fS td  d)Nr]   r   Tr   r^   r   r5   )r#   r]   stdr^   r   absrh   )r   r   r   r   )r[   r&   r'   ensemble  s   z6MarigoldDepthPipeline.ensemble_depth.<locals>.ensemblec                    s   d} || }t t D ]\}}|| ||  }||d    7 }qdkrN|dd\}}|   }	d|    }
||	|
  7 }|S )Ng        r   r   Fr   r   )	r#   combinationsaranger]   sqrtitemr   r   rs   )r   r)   costr   r{   jdiffr   _err_nearerr_far)r   r   rJ   r   r&   r'   cost_fn  s   
z5MarigoldDepthPipeline.ensemble_depth.<locals>.cost_fnc                    sr   dd l }| tj}d ur t|jdd  kr t|d}|}|jj	t
 |d|dddd}|jS )	Nr   r   rV   )r)   BFGSF)maxiterdisp)methodr   options)scipyr   r#   float32rs   rp   r   resize_to_max_edgeoptimizeminimizer   x)r)   r  depth_to_alignr   res)r  r   r   r   r   r&   r'   compute_param  s   
z;MarigoldDepthPipeline.ensemble_depth.<locals>.compute_paramr   r   r   r   r   )F)rr   rp   rh   r#   r$   r!   r"   booltuplefloatrs   r   r   )r)   r1   r2   rS   r[   r   r   r   r   r  requires_aligningr   r   	depth_max	depth_mindepth_ranger&   )r   r  r   rJ   r   r   r   r[   r   r1   r2   r   r'   r     sF   .$$


z$MarigoldDepthPipeline.ensemble_depth)NTTNN)NNNT)Nr   NTrW   rW   r   NNNr!   FFT)TTFr^   r   r   r   r   )&r   r   r   r    model_cpu_offload_seqr:   r   r   r   r   r   r   strr  intr9   r
   rk   r   r#   r$   rv   rl   r   compilerdisabler   no_gradr   EXAMPLE_DOC_STRINGr   r  r   r   staticmethodr  r   __classcell__r&   r&   rF   r'   r(   h   s2   )		
/	

 	
  
%	
r(   )0dataclassesr   	functoolsr   typingr   r   r!   r#   PILr   	tqdm.autor   transformersr   r   rD   r
   modelsr   r   
schedulersr   r   utilsr   r   r   r   utils.import_utilsr   utils.torch_utilsr   pipeline_utilsr   marigold_image_processingr   torch_xla.core.xla_modelcore	xla_modelr   r   
get_loggerr   r;   r"  r   r(   r&   r&   r&   r'   <module>   s2   
