o
    ߥiVz                     @   s  d dl Z d dlmZmZmZmZmZmZ d dlZd dl	Z
d dlZd dlmZ d dlZd dlmZ d dlmZmZmZmZmZmZmZmZmZmZ d dlmZ d dlmZ d dl m!Z!m"Z"m#Z#m$Z$m%Z%m&Z&m'Z' d dl(m)Z)m*Z*m+Z+ d dl,m-Z- d d	l.m/Z/m0Z0 d d
l1m2Z2 d dl3T d dl4T d dl5T d dl6m7Z7m8Z8 d dl9m:Z: e: Z;dZ<e2j=e8j>e-j>dG dd de0Z?	dddZ@G dd deZAdS )    N)AnyCallableDictListOptionalUnion)
AutoencoderKLControlNetModelDiffusionPipelineEulerAncestralDiscreteSchedulerEulerDiscreteScheduler(StableDiffusionControlNetImg2ImgPipeline(StableDiffusionControlNetInpaintPipelineStableDiffusionInpaintPipelineStableDiffusionPipelineUNet2DConditionModel)MultiControlNetModel)StableDiffusionPipelineOutput)	deprecateis_accelerate_availableis_accelerate_versionis_compiled_moduleloggingrandn_tensorreplace_example_docstring)load_objload_objs_as_meshessave_obj)Models)Tensor
TorchModel)MODELS)*)	ModelFileTasks)
get_loggera  
    Examples:
        ```py
        >>> from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, DDIMScheduler
        >>> from diffusers.utils import load_image
        >>> import numpy as np
        >>> import torch

        >>> init_image = load_image(image_path)
        >>> init_image = init_image.resize((512, 512))
        >>> generator = torch.Generator(device="cpu").manual_seed(1)
        >>> mask_image = load_image(mask_path)
        >>> mask_image = mask_image.resize((512, 512))
        >>> def make_inpaint_condition(image, image_mask):
        ...     image = np.array(image.convert("RGB")).astype(np.float32) / 255.0
        ...     image_mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0
        ...     assert image.shape[0:1] == image_mask.shape[0:1], "image and image_mask must have the same image size"
        ...     image[image_mask > 0.5] = -1.0  # set as masked pixel
        ...     image = np.expand_dims(image, 0).transpose(0, 3, 1, 2)
        ...     image = torch.from_numpy(image)
        ...     return image
        >>> control_image = make_inpaint_condition(init_image, mask_image)
        >>> controlnet = ControlNetModel.from_pretrained(
        ...     "lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16
        ... )
        >>> pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
        ...     "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
        ... )
        >>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
        >>> pipe.enable_model_cpu_offload()
        >>> image = pipe(
        ...     "a handsome man with ray-ban sunglasses",
        ...     num_inference_steps=20,
        ...     generator=generator,
        ...     eta=1.0,
        ...     image=init_image,
        ...     mask_image=mask_image,
        ...     control_image=control_image,
        ... ).images[0]
        ```
)module_namec                       s@   e Zd Z fddZdd Zdd Zddd	Zdd
dZ  ZS )Tex2Texturec                    s   t  j|d|i| tj r td| _td| j nt	d t
  |d }|d }|d }|dtj}tj||d	| j| _tj||d	| j| _tj|| j|d
| j| _td dS )a  The Tex2Texture is modified based on TEXTure and Text2Tex, publicly available at
                https://github.com/TEXTurePaper/TEXTurePaper &
                https://github.com/daveredrum/Text2Tex
        Args:
            model_dir: the root directory of the model files
        	model_dircudazUse GPU: {}zno gpu avaiablez/base_model/z/control_model/z/inpaint_model/torch_dtype)r*   )
controlnetr*   zmodel load overN)super__init__torchr)   is_availabledeviceloggerinfoformatprintexitgetfloat16r	   from_pretrainedtor+   r   inpaintmodel%StableDiffusionControlinpaintPipelinepipe)selfr(   argskwargs
model_pathcontrolmodel_pathinpaintmodel_pathr*   	__class__ l/home/ubuntu/.local/lib/python3.10/site-packages/modelscope/models/cv/text_texture_generation/Tex2Texture.pyr-   V   s8   
zTex2Texture.__init__c                 C   s0   t || jd\}}}t|g| jd}||||fS )Nr0   )r   r0   r   )r=   	mesh_pathvertsfacesauxmeshrE   rE   rF   	init_mesht   s   zTex2Texture.init_meshc           	      C   s   |  }| jd }|jdd|d}|| }|dd d df |dd d df  }| }d| }|d|}||}| |||fS )Nr      dim   g?)	get_bounding_boxesverts_packedshapemeanrepeatoffset_vertsmax	unsqueezescale_verts)	r=   rL   bbox	num_vertsmesh_centerlensmax_lenscalenew_meshrE   rE   rF   normalize_meshy   s   $
zTex2Texture.normalize_meshnormalized.objc                 C   s@   t d |}t|||jd|j|j|jt|j d  d d S )Nz!=> saving normalized mesh file...   r   )rI   rJ   decimal_places	verts_uvs	faces_uvstexture_map)r4   r   	verts_idxrf   textures_idxtexture_imageslistkeys)r=   rI   rJ   rK   pathobj_pathrE   rE   rF   save_normalized_obj   s   
zTex2Texture.save_normalized_objc           	      C   sD   |  |\}}}}| |\}}}}| |||| ||||||fS N)rM   rb   rp   )	r=   rH   	save_pathrL   rI   rJ   rK   r]   r`   rE   rE   rF   mesh_normalized   s   zTex2Texture.mesh_normalized)rc   )	__name__
__module____qualname__r-   rM   rb   rp   rs   __classcell__rE   rE   rC   rF   r'   R   s    
r'   Fc                    sf  | d u rt d|d u rt dt| tjrt|tjs&tdt| d| jdkr;| jd dks6J d| d} |jdkrH|dd}|jdkr_|jd d	krZ|d}n|d	}| jd
kri|jd
ksmJ d| jdd  |jdd  ksJ d| jd |jd ksJ d| 	 dk s| 
 d	krt d|	 dk s|
 d	krt dd||dk < d	||dk< | jtjd} nt|tjrtdt|  dt| tjjtjfr| g} t| trt| d tjjr fdd| D } dd | D } tj| dd} nt| tr"t| d tjr"tjdd | D dd} | ddd	d} t| jtjdd d } t|tjjtjfrF|g}t|trut|d tjjru fdd|D }tjdd |D dd}|tjd }nt|trt|d tjrtjdd |D dd}d||dk < d	||dk< t|}| |dk  }|r||| fS ||fS ) Nz"`image` input cannot be undefined.z'`mask_image` input cannot be undefined.z,`image` is a torch.Tensor but `mask` (type: z is not   r   z2Image outside a batch should be of shape (3, H, W)rN   rQ      z%Image and Mask must have 4 dimensionsz4Image and Mask must have the same spatial dimensionsz,Image and Mask must have the same batch sizez Image should be in [-1, 1] rangezMask should be in [0, 1] range      ?)dtypez,`mask` is a torch.Tensor but `image` (type: c                    "   g | ]}|j  ftjjd qS )resampleresizePILImageLANCZOS.0iheightwidthrE   rF   
<listcomp>       z1prepare_mask_and_masked_image.<locals>.<listcomp>c                 S   s(   g | ]}t |d dddf qS )RGBNnparrayconvertr   rE   rE   rF   r      s   ( )axisc                 S   s   g | ]
}|d d d f qS rq   rE   r   rE   rE   rF   r      s    g     _@      ?c                    r~   r   r   r   r   rE   rF   r      r   c                 S   s*   g | ]}t |d ddddf qS )LNr   r   mrE   rE   rF   r      s   * g     o@c                 S   s   g | ]}|d d d d f qS rq   rE   r   rE   rE   rF   r      s    )
ValueError
isinstancer.   r   	TypeErrortypendimrT   rY   minrX   r9   float32r   r   r   ndarrayrl   concatenate	transpose
from_numpyastype)imagemaskr   r   return_imagemasked_imagerE   r   rF   prepare_mask_and_masked_image   s   




  

r   c                2   @   sx  e Zd Ze ee																								
d$deee	e f deej
ejjf deej
ejjf deejejjeje	ej e	ejj e	ej f dee dee dedededeeee	e f  dee dedeeeje	ej f  deej deej deej dee dedeeeeejgdf  dedeeeef  d eee	e f d!ef.d"d#ZdS )%r;   Nr   2         @rQ           pilTr|   Fpromptr   
mask_imagecontrol_imager   r   strengthnum_inference_stepsguidance_scalenegative_promptnum_images_per_prompteta	generatorlatentsprompt_embedsnegative_prompt_embedsoutput_typereturn_dictcallbackcallback_stepscross_attention_kwargscontrolnet_conditioning_scale
guess_modec           =      C   s  |  |||\}}| ||||||
|||	 |dur"t|tr"d}n|dur0t|tr0t|}n|jd }| j}|	dk}t| j	rE| j	j
n| j	}t|trZt|trZ|gt|j }t|trc|jjn|jd jj}|pm|}|durx|ddnd}| j|||||
|||d}t|tr| j||||| |||j||d	}n't|trg }|D ]}| j||||| |||j||d	}|| q|}nJ t||||d	d
\} }!}"| jj||d | j|||d\}#}|#dd || }$|dk}%| jjj}&| jjj}'|'dk}(| j|| |&|||j||||"|$|%d	|(d})|(r|)\}}*}+n|)\}}*|  | |!|| |||j|||	\} },| !||}-| j"|d}.t#|#D ]\}/}0|rUt$%|gd n|}1| j&|1|0}1|ru|ru|}2| j&|2|0}2|'dd }3n|1}2|}3| j	|2|0|3|||dd\}4}5|r|rdd |4D }4t$%t$(|5|5g}5|'dkrt$j%|1| |,gdd}1| j|1|0|||4|5ddd }6|r|6'd\}7}8|7|	|8|7   }6| jj)|6|0|fi |-ddid }|'dkr|+dd }9| dd }:|/t|#d k r| j*|9|*t$+|0g}9d|: |9 |:|  }|/t|#d ks#|/d | jj, dkr9|.-  |dur9|/| dkr9||/|0| qFW d   n	1 sFw   Y  t.| drh| j/durh| j0d | j	0d t$j12  |dks| jj3|| jjj4 ddd }| 5|||j\}};n|}d};|;du rd	g|jd  }<ndd |;D }<| j6j7|||<d}t.| dr| j/dur| j/8  |s||;fS t9||;dS )uV  
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                instead.
            image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`,
                    `List[List[torch.FloatTensor]]`, or `List[List[PIL.Image.Image]]`):
                The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
                the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
                also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If
                height and/or width are passed, `image` is resized according to them. If multiple ControlNets are
                specified in init, images must be passed as a list such that each element of the list can be correctly
                batched for input to a single controlnet.
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image.
            strength (`float`, *optional*, defaults to 1.):
                Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be
                between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the
                `strength`. The number of denoising steps depends on the amount of noise initially added. When
                `strength` is 1, added noise will be maximum and the denoising process will run for the full number of
                iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked
                portion of the reference `image`.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will ge generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generate image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
            controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 0.5):
                The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
                to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
                corresponding scale as a list. Note that by default, we use a smaller conditioning scale for inpainting
                than for [`~StableDiffusionControlNetPipeline.__call__`].
            guess_mode (`bool`, *optional*, defaults to `False`):
                In this mode, the ControlNet encoder will try best to recognize the content of the input image even if
                you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended.

        Examples:

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
            When returning a tuple, the first element is a list with the generated images, and the second element is a
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.
        NrQ   r   r   r`   )r   r   
lora_scale)	r   r   r   
batch_sizer   r0   r}   do_classifier_free_guidancer   FT)r   rG   )r   r   r0   ry   )r   timestepis_strength_maxreturn_noisereturn_image_latents)totalrN   )encoder_hidden_statescontrolnet_condconditioning_scaler   r   c                 S   s    g | ]}t t ||gqS rE   )r.   cat
zeros_like)r   drE   rE   rF   r   ;  s    zBStableDiffusionControlinpaintPipeline.__call__.<locals>.<listcomp>	   rO   )r   r   down_block_additional_residualsmid_block_additional_residualr   r   final_offload_hookcpulatent)r   c                 S   s   g | ]}| qS rE   rE   )r   has_nsfwrE   rE   rF   r     s    )r   do_denormalize)imagesnsfw_content_detected):_default_height_widthcheck_inputsr   strrl   lenrT   _execution_devicer   r+   	_orig_modr   floatnetsr	   configglobal_pool_conditionsr6   _encode_promptprepare_control_imager}   appendr   	schedulerset_timestepsget_timestepsrV   vaelatent_channelsunetin_channelsprepare_latentsprepare_mask_latentsprepare_extra_step_kwargsprogress_bar	enumerater.   r   scale_model_inputchunkr   step	add_noisetensororderupdatehasattrr   r9   r)   empty_cachedecodescaling_factorrun_safety_checkerimage_processorpostprocessoffloadr   )=r=   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r0   r   r+   r   text_encoder_lora_scalecontrol_imagescontrol_image_r   r   
init_image	timestepslatent_timestepr   num_channels_latentsnum_channels_unetr   latents_outputsnoiseimage_latentsmasked_image_latentsextra_step_kwargsr   r   tlatent_model_inputcontrol_model_inputcontrolnet_prompt_embedsdown_block_res_samplesmid_block_res_sample
noise_prednoise_pred_uncondnoise_pred_textinit_latents_proper	init_maskhas_nsfw_conceptr   rE   rE   rF   __call__  s  }











 [




z.StableDiffusionControlinpaintPipeline.__call__)NNNNNNr   r   r   NrQ   r   NNNNr   TNrQ   Nr|   F)rt   ru   rv   r.   no_gradr   EXAMPLE_DOC_STRINGr   r   r   r   r   r   FloatTensorr   r   r   intr   	Generatorboolr   r   r   r  rE   rE   rE   rF   r;     s    	
r;   )F)Bostypingr   r   r   r   r   r   cv2numpyr   r   	PIL.Imager   r.   torchvision.transforms
transforms	diffusersr   r	   r
   r   r   r   r   r   r   r   .diffusers.pipelines.controlnet.multicontrolnetr   $diffusers.pipelines.stable_diffusionr   diffusers.utilsr   r   r   r   r   r   r   pytorch3d.ior   r   r   modelscope.metainfor   modelscope.models.baser   r    modelscope.models.builderr!   8modelscope.models.cv.text_texture_generation.lib2.camera;modelscope.models.cv.text_texture_generation.lib2.init_view2modelscope.models.cv.text_texture_generation.utilsmodelscope.utils.constantr#   r$   modelscope.utils.loggerr%   r1   r  register_moduletext_texture_generationr'   r   r;   rE   rE   rE   rF   <module>   s@    0$+I

l