o
    Gi                    @   s	  d Z ddlZddlZddlZddlmZ ddlmZ ddlm	Z	 ddl
Z
ddlZddlZddlmZ ddlmZmZmZmZmZmZmZmZ dd	lmZmZmZmZmZmZ dd
lm Z  ddl!m"Z" ddl#m$Z$ e rqddl%m&Z& e rddl'm(Z( ddl)m*Z* e+e,Z-i dddddddddddddgdddd d!d"d#d$d%d&d'd(d)d*d+d,d-d.d/d0d1d2i d3d4d5d6d7d8d9d:d;gd<d=d>gd?d@dAdBdCdDdEddFdGdHg dIdJdKdLgdMg dNdOdPdQdRdSdTdUgdVdWd&dXdYgdZd[gd\d]g d^d_d`gdadbdcddg deg dfdgdhgg didjZ.i ddkdliddkdmidndkdoid#dkdpiddkdqid%dkdridsdkdtiddkduiddkdvid!dkdwidxdkdyiddkdziddkd{id5d|d}d~dd|dd~d7ddd~dddd~i d9dkdid<dkdiddkdiddkdidAdkdiddkdidCdkdidEdkdidFdkdidHdkdiddkdiddkdiddkdiddkdiddkdidMdkdiddkdii ddkdiddkdiddkdiddkdiddkdiddkdidSdkdidVdkdiddkdiddkdiddkdiddkdiddkdiddkdiddkdiddkdiddkdidkdidkdidkdidkdidkdidkdidkdidkdidkdidkdidkdidkdidkdidkdid͜Z/ddddddddddddќZ0dddddddddddܜ
dddddddddddddddddddddd
dddddddddddi dddddddddddddddddddddddddddddddddddd ddddddddd	ddd
	ddZ1g dZ2ddddddddddddZ3ddgZ4dZ5dZ6dZ7dZ8d d!gZ9dZ:d"d#gZ;g d$Z<G d%d& d&e=Z>d'd( Z?d)d* Z@d+d, ZAd-d. ZBd/d0 ZC								dd1d2ZDdd3d4ZEd5d6 ZFd7d8 ZGd9d: ZHd;d< ZId=d> ZJd?d@ ZKdAdB ZLdCdD ZMdEdF ZNdGdH ZOddIdJZPdKdL ZQ	ddMdNZRddOdPZSddQdRZTddSdTZUdUdV ZVdWdX ZWdYdZ ZXd[d\ ZYdd]d^ZZd_d` Z[dadb Z\ddcddZ]	eddfdgZ^	h				ddidjZ_	ddkdlZ`ddmdnZadodp Zbdqdr Zcdsdt Zddudv Zedwdx Zfdydz Zgd{d| Zhd}d~ Zi	h			dddZjdd Zkdd Zldd Zmdd Zndd Zodd Zpdd Zqdd Zrdd Zsdd Ztdd Zudd Zvdd Zwdd Zxdd Zydd Zzdd Z{dd Z|dd Z}dd Z~dd ZdS (  z7Conversion script for the Stable Diffusion checkpoints.    N)nullcontext)BytesIO)urlparse   )load_state_dict)DDIMSchedulerDPMSolverMultistepSchedulerEDMDPMSolverMultistepSchedulerEulerAncestralDiscreteSchedulerEulerDiscreteSchedulerHeunDiscreteSchedulerLMSDiscreteSchedulerPNDMScheduler)SAFETENSORS_WEIGHTS_NAMEWEIGHTS_NAME	deprecateis_accelerate_availableis_transformers_availablelogging)DIFFUSERS_REQUEST_TIMEOUT)_get_model_file)empty_device_cache)AutoImageProcessor)init_empty_weights)load_model_dict_into_metav1z?model.diffusion_model.output_blocks.11.0.skip_connection.weightv2zMmodel.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weightxl_basezEconditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias
xl_refinerzEconditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.biasupscalez<model.diffusion_model.input_blocks.10.0.skip_connection.bias
controlnetz!control_model.time_embed.0.weight(controlnet_cond_embedding.conv_in.weightcontrolnet_xladd_embedding.linear_1.weightcontrolnet_xl_largezAdown_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k.weightcontrolnet_xl_midz&down_blocks.1.attentions.0.norm.weightplayground-v2-5edm_mean
inpaintingz-model.diffusion_model.input_blocks.0.0.weightclipzLcond_stage_model.transformer.text_model.embeddings.position_embedding.weight	clip_sdxlzSconditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weightclip_sd3zPtext_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight	open_clipz-cond_stage_model.model.token_embedding.weightopen_clip_sdxlz2conditioner.embedders.1.model.positional_embeddingopen_clip_sdxl_refinerz-conditioner.embedders.0.model.text_projectionopen_clip_sd3zPtext_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weightstable_cascade_stage_bz$down_blocks.1.0.channelwise.0.weightstable_cascade_stage_cclip_txt_mapper.weightsd3z4joint_blocks.0.context_block.adaLN_modulation.1.biaszJmodel.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias
sd35_largez&joint_blocks.37.x_block.mlp.fc1.weightz<model.diffusion_model.joint_blocks.37.x_block.mlp.fc1.weightanimatediffzjdown_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.peanimatediff_v2z9mid_block.motion_modules.0.temporal_transformer.norm.biasanimatediff_sdxl_betaz=up_blocks.2.motion_modules.0.temporal_transformer.norm.weightanimatediff_scribbleanimatediff_rgbz controlnet_cond_embedding.weightauraflow)zdouble_layers.0.attn.w2q.weightzdouble_layers.0.attn.w1q.weightcond_seq_linear.weightt_embedder.mlp.0.weightfluxz,double_blocks.0.img_attn.norm.key_norm.scalezBmodel.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale	ltx-video)z*model.diffusion_model.patchify_proj.weightz=model.diffusion_model.transformer_blocks.27.scale_shift_tablezpatchify_proj.weightz'transformer_blocks.27.scale_shift_tablez"vae.decoder.last_scale_shift_tablez6vae.decoder.up_blocks.9.res_blocks.0.conv1.conv.weightautoencoder-dcz.decoder.stages.1.op_list.0.main.conv.conv.biasautoencoder-dc-sanaencoder.project_in.conv.biasmochi-1-previewz0model.diffusion_model.blocks.0.attn.qkv_x.weightzblocks.0.attn.qkv_x.weighthunyuan-videoz@txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.biasz+model.diffusion_model.cap_embedder.0.weightzcap_embedder.0.weightz8model.diffusion_model.layers.0.adaLN_modulation.0.weightz"layers.0.adaLN_modulation.0.weightz!control_all_x_embedder.2-1.weightz+control_layers.14.adaLN_modulation.0.weight)z#blocks.0.cross_attn.q_linear.weightz!blocks.0.cross_attn.q_linear.biasz$blocks.0.cross_attn.kv_linear.weightz"blocks.0.cross_attn.kv_linear.biasz%model.diffusion_model.head.modulationhead.modulation!decoder.middle.0.residual.0.gammazvace_blocks.0.after_proj.bias#motion_encoder.dec.direction.weightz4double_stream_blocks.0.block.adaLN_modulation.1.bias)net.x_embedder.proj.1.weight3net.blocks.block1.blocks.0.block.attn.to_q.0.weightz net.extra_pos_embedder.pos_emb_h)rG   z$net.blocks.0.self_attn.q_proj.weightz"net.pos_embedder.dim_spatial_rangez9model.diffusion_model.single_stream_modulation.lin.weightz#single_stream_modulation.lin.weight)zWmodel.diffusion_model.av_ca_a2v_gate_adaln_single.emb.timestep_embedder.linear_1.weightz(vae.per_channel_statistics.mean-of-meansz.audio_vae.per_channel_statistics.mean-of-means)instruct-pix2pixlumina2z-image-turboz-image-turbo-controlnetz-image-turbo-controlnet-2.xsanawanwan_vaewan_vacewan_animatehidream
cosmos-1.0
cosmos-2.0flux2ltx2pretrained_model_name_or_pathz(stabilityai/stable-diffusion-xl-base-1.0z+stabilityai/stable-diffusion-xl-refiner-1.0
xl_inpaintz0diffusers/stable-diffusion-xl-1.0-inpainting-0.1z-playgroundai/playground-v2.5-1024px-aestheticz(stabilityai/stable-diffusion-x4-upscalerz1stable-diffusion-v1-5/stable-diffusion-inpaintinginpainting_v2z)stabilityai/stable-diffusion-2-inpaintingz"lllyasviel/control_v11p_sd15_cannyz#diffusers/controlnet-canny-sdxl-1.0z'diffusers/controlnet-canny-sdxl-1.0-midcontrolnet_xl_smallz)diffusers/controlnet-canny-sdxl-1.0-smallz stabilityai/stable-diffusion-2-1z+stable-diffusion-v1-5/stable-diffusion-v1-5zstabilityai/stable-cascadedecoder)rX   	subfolderstable_cascade_stage_b_litedecoder_litez stabilityai/stable-cascade-priorpriorstable_cascade_stage_c_lite
prior_litez/stabilityai/stable-diffusion-3-medium-diffusersz&stabilityai/stable-diffusion-3.5-largesd35_mediumz'stabilityai/stable-diffusion-3.5-mediumanimatediff_v1z&guoyww/animatediff-motion-adapter-v1-5z(guoyww/animatediff-motion-adapter-v1-5-2animatediff_v3z(guoyww/animatediff-motion-adapter-v1-5-3z+guoyww/animatediff-motion-adapter-sdxl-betaz&guoyww/animatediff-sparsectrl-scribblez!guoyww/animatediff-sparsectrl-rgbzfal/AuraFlow-v0.3flux-devzblack-forest-labs/FLUX.1-dev	flux-fillz!black-forest-labs/FLUX.1-Fill-dev
flux-depthz"black-forest-labs/FLUX.1-Depth-devflux-schnellz black-forest-labs/FLUX.1-schnell
flux-2-devzblack-forest-labs/FLUX.2-devzdiffusers/LTX-Video-0.9.0ltx-video-0.9.1zdiffusers/LTX-Video-0.9.1ltx-video-0.9.5zLightricks/LTX-Video-0.9.5ltx-video-0.9.7zLightricks/LTX-Video-0.9.7-devautoencoder-dc-f128c512z,mit-han-lab/dc-ae-f128c512-mix-1.0-diffusersautoencoder-dc-f64c128z+mit-han-lab/dc-ae-f64c128-mix-1.0-diffusersautoencoder-dc-f32c32z*mit-han-lab/dc-ae-f32c32-mix-1.0-diffusersautoencoder-dc-f32c32-sanaz+mit-han-lab/dc-ae-f32c32-sana-1.0-diffuserszgenmo/mochi-1-previewz#hunyuanvideo-community/HunyuanVideorI   ztimbrooks/instruct-pix2pixrJ   zAlpha-VLLM/Lumina-Image-2.0rN   z1Efficient-Large-Model/Sana_1600M_1024px_diffuserswan-t2v-1.3Bz Wan-AI/Wan2.1-T2V-1.3B-Diffuserswan-t2v-14BzWan-AI/Wan2.1-T2V-14B-Diffuserswan-i2v-14Bz$Wan-AI/Wan2.1-I2V-14B-480P-Diffuserswan-animate-14Bz#Wan-AI/Wan2.2-Animate-14B-Diffuserswan-vace-1.3Bz!Wan-AI/Wan2.1-VACE-1.3B-diffuserswan-vace-14Bz Wan-AI/Wan2.1-VACE-14B-diffuserszHiDream-ai/HiDream-I1-Devz)nvidia/Cosmos-1.0-Diffusion-7B-Text2Worldz*nvidia/Cosmos-1.0-Diffusion-14B-Text2Worldz*nvidia/Cosmos-1.0-Diffusion-7B-Video2Worldz+nvidia/Cosmos-1.0-Diffusion-14B-Video2Worldz$nvidia/Cosmos-Predict2-2B-Text2Imagez%nvidia/Cosmos-Predict2-14B-Text2Imagez%nvidia/Cosmos-Predict2-2B-Video2Worldz&nvidia/Cosmos-Predict2-14B-Video2WorldzTongyi-MAI/Z-Image-Turboz'hlky/Z-Image-Turbo-Fun-Controlnet-Unionz+hlky/Z-Image-Turbo-Fun-Controlnet-Union-2.0z+hlky/Z-Image-Turbo-Fun-Controlnet-Union-2.1zLightricks/LTX-2)rS   cosmos-1.0-t2w-7Bcosmos-1.0-t2w-14Bcosmos-1.0-v2w-7Bcosmos-1.0-v2w-14Bcosmos-2.0-t2i-2Bcosmos-2.0-t2i-14Bcosmos-2.0-v2w-2Bcosmos-2.0-v2w-14BrK   rL   z-image-turbo-controlnet-2.0z-image-turbo-controlnet-2.1ltx2-dev   i      )r   r   rY   r&   r   r(   rZ   r    rI   r   r   time_embed.0.weightztime_embed.0.biasztime_embed.2.weightztime_embed.2.biaszinput_blocks.0.0.weightzinput_blocks.0.0.biaszout.0.weightz
out.0.biaszout.2.weightz
out.2.bias)
time_embedding.linear_1.weighttime_embedding.linear_1.biastime_embedding.linear_2.weighttime_embedding.linear_2.biasconv_in.weightconv_in.biaszconv_norm_out.weightzconv_norm_out.biaszconv_out.weightzconv_out.biaszlabel_emb.0.0.weightzlabel_emb.0.0.biaszlabel_emb.0.2.weightzlabel_emb.0.2.bias)zclass_embedding.linear_1.weightzclass_embedding.linear_1.biaszclass_embedding.linear_2.weightzclass_embedding.linear_2.bias)r#   zadd_embedding.linear_1.biaszadd_embedding.linear_2.weightzadd_embedding.linear_2.bias)layersclass_embed_typeaddition_embed_typezinput_hint_block.0.weightzinput_hint_block.0.biaszinput_hint_block.14.weightzinput_hint_block.14.bias)
r   r   r   r   r   r   r!   z&controlnet_cond_embedding.conv_in.biasz)controlnet_cond_embedding.conv_out.weightz'controlnet_cond_embedding.conv_out.biasencoder.conv_in.weightencoder.conv_in.biasencoder.conv_out.weightencoder.conv_out.biaszencoder.conv_norm_out.weightzencoder.norm_out.weightzencoder.conv_norm_out.biaszencoder.norm_out.biasdecoder.conv_in.weightdecoder.conv_in.biasdecoder.conv_out.weightdecoder.conv_out.biaszdecoder.conv_norm_out.weightzdecoder.norm_out.weightzdecoder.conv_norm_out.biaszdecoder.norm_out.biasquant_conv.weightquant_conv.biaspost_quant_conv.weightpost_quant_conv.biaspositional_embeddingztoken_embedding.weightzln_final.weightzln_final.biastext_projection)z/text_model.embeddings.position_embedding.weightz,text_model.embeddings.token_embedding.weightz"text_model.final_layer_norm.weightz text_model.final_layer_norm.biastext_projection.weightz
resblocks.ln_1ln_2z.c_fc.z.c_proj.z.attnz	ln_final.)	ztext_model.encoder.layers.layer_norm1layer_norm2z.fc1.z.fc2.z
.self_attnz(transformer.text_model.final_layer_norm.z8transformer.text_model.embeddings.token_embedding.weightz;transformer.text_model.embeddings.position_embedding.weight)r   transformer)unetr    vaeopenclip)zAcond_stage_model.model.transformer.resblocks.23.attn.in_proj_biaszCcond_stage_model.model.transformer.resblocks.23.attn.in_proj_weightzBcond_stage_model.model.transformer.resblocks.23.attn.out_proj.biaszDcond_stage_model.model.transformer.resblocks.23.attn.out_proj.weightz9cond_stage_model.model.transformer.resblocks.23.ln_1.biasz;cond_stage_model.model.transformer.resblocks.23.ln_1.weightz9cond_stage_model.model.transformer.resblocks.23.ln_2.biasz;cond_stage_model.model.transformer.resblocks.23.ln_2.weightz=cond_stage_model.model.transformer.resblocks.23.mlp.c_fc.biasz?cond_stage_model.model.transformer.resblocks.23.mlp.c_fc.weightz?cond_stage_model.model.transformer.resblocks.23.mlp.c_proj.biaszAcond_stage_model.model.transformer.resblocks.23.mlp.c_proj.weightz&cond_stage_model.model.text_projectionscaled_linearg_QK?g~jt?linear  epsilon      ?FT   leading)beta_schedule
beta_startbeta_endinterpolation_typenum_train_timestepsprediction_typesample_max_valueset_alpha_to_oneskip_prk_stepssteps_offsettimestep_spacingzfirst_stage_model.vae.g{P?      ?model.diffusion_model.zcontrol_model.zcond_stage_model.transformer.z$conditioner.embedders.0.transformer.r   scheduler_type)zhttps://huggingface.co/zhuggingface.co/zhf.co/zhttps://hf.co/c                       s   e Zd Zd fdd	Z  ZS )SingleFileComponentErrorNc                    s   || _ t | j  d S N)messagesuper__init__)selfr   	__class__ W/home/ubuntu/.local/lib/python3.10/site-packages/diffusers/loaders/single_file_utils.pyr     s   z!SingleFileComponentError.__init__r   )__name__
__module____qualname__r   __classcell__r   r   r   r   r     s    r   c                 C   s   t | }|jr|jrdS dS )NTF)r   schemenetloc)urlresultr   r   r   is_valid_url  s   r   c                 C   s0   t j| r
t| sdS t| \}}t|o|S NF)ospathisfiler   !_extract_repo_id_and_weights_namebool)rX   repo_idweight_namer   r   r   _is_single_file_path_or_url  s   r   c                 C   sv   t | stdd}d }d}tD ]}| |d} qt|| }|s%||fS |d d|d }|d}||fS )	NzOInvalid `pretrained_model_name_or_path` provided. Please set it to a valid URL.z#([^/]+)/([^/]+)/(?:blob/main/)?(.+)r    r   /r      )r   
ValueErrorVALID_URL_PREFIXESreplacerematchgroup)rX   patternweights_namer   prefixr   r   r   r   r     s   
r   c                 C   s>   t j| |}d}ttfD ]}t jt j||rd}q|S )NFT)r   r   joinr   r   r   )cached_foldernamerX   weights_existr   r   r   r   "_is_model_weights_in_cached_folder  s   r   c                 C   s   t dd |  D S )Nc                 s   s    | ]}|t v V  qd S r   )SCHEDULER_LEGACY_KWARGS.0kr   r   r   	<genexpr>      z._is_legacy_scheduler_kwargs.<locals>.<genexpr>)anykeys)kwargsr   r   r   _is_legacy_scheduler_kwargs  s   r   c	                 C   sr   |d u r	ddd}t j| r| } nt| \}	}
t|	|
|||||||d	} t| |d}d|v r7|d }d|v s/|S )Nsingle_filepytorch)	file_type	framework)r   force_download	cache_dirproxieslocal_files_onlytokenrevision
user_agent)disable_mmap
state_dict)r   r   r   r   r   r   )pretrained_model_link_or_pathr   r   r   r   r   r   r   r   r   r   
checkpointr   r   r   load_single_file_checkpoint  s*   
r  c                 C   s   t j| r t| d}| } W d    n1 sw   Y  nt| r5|r*tdttj	| t
dj} ntdt| }|S )Nrz|`local_files_only` is set to True, but a URL was provided as `original_config_file`. Please provide a valid local file path.)timeoutzSInvalid `original_config_file` provided. Please set it to a valid file path or URL.)r   r   r   openreadr   r   r   requestsgetr   contentyaml	safe_load)original_config_filer   fporiginal_configr   r   r   fetch_original_config  s   

r  c                 C      t d | v rdS dS )Nr)   TFCHECKPOINT_KEY_NAMESr  r   r   r   is_clip_model     r  c                 C   r  )Nr*   TFr  r  r   r   r   is_clip_sdxl_model
  r  r  c                 C   r  )Nr+   TFr  r  r   r   r   is_clip_sd3_model  r  r  c                 C   r  )Nr,   TFr  r  r   r   r   is_open_clip_model  r  r  c                 C   r  )Nr-   TFr  r  r   r   r   is_open_clip_sdxl_model  r  r  c                 C   r  )Nr/   TFr  r  r   r   r   is_open_clip_sd3_model&  r  r  c                 C   r  )Nr.   TFr  r  r   r   r   is_open_clip_sdxl_refiner_model-  r  r  c                 C   sL   t t|t|t|t|t|t|g}| jdks | jdkr$|r$dS dS )NCLIPTextModelCLIPTextModelWithProjectionTF)r   r  r  r  r  r  r  r   )	class_objr  is_clip_in_checkpointr   r   r   is_clip_model_in_single_file4  s   r   c           	         s  t d  v r4 t d  jd dkr4t d  v r& t d  jd dkr&d}|S t d  v r0d	}|S d}|S t d  v rI t d  jd dkrId}|S t d
  v rSd
}|S t d  v r]d}|S t d  v rgd}|S t d  v rqd}|S t fddt d D rt d  v rt d  v rd}|S t d  v rd}|S d}|S d}|S t d  v r t d  jd dkrd}|S t d  v rʈ t d  jd dkrd}|S t d  v r߈ t d  jd dkrd}|S t d  v r t d  jd dkrd}|S t fddt d D r6t fddt d D r6d  v rd }nd!} | jd d"kr(d}|S  | jd d#kr4d$}|S t fd%dt d& D rHd&}|S t d'  v rt d(  v rZd(}|S t d)  v red)}|S t d*  v rpd*}|S  t d+  jd d,krd+}|S  t d'  jd d-krd.}|S d/}|S t fd0dt d1 D rd2}|S t fd3dt d4 D rt fd5dd6D rd7 v rd7}nd8} | jd d9krd:}|S  | jd d;krd<}|S d=}|S d>}|S t fd?dt d@ D r,dA v }tdBd  D rdC}|S |r dD jd dkrdE}|S dF v r(dG}|S d@}|S t dH  v rvdI}dJ}t dK  v rBdL}|S  | jd dMkrZ | jd dNkrZdO}|S  | jd dMkrr | jd d;krrdP}|S dQ}|S t fdRdt dS D rdS}|S t dT  v rdT}|S t fdUdt dV D rdV}|S t dW  v r t dW  jd dXkrdW}|S t fdYdt dZ D rdZ}|S t fd[dt d\ D rd\}|S t fd]dt d^ D rd^}|S t fd_dt d` D r_da v rda}ndb}t dc  v r* | jd dkrdd}n | jd dekr*df}t dg  v r5dh}|S  | jd dkrCdi}|S  | jd dekr[ | jd djkr[dk}|S dl}|S t dm  v rjdk}|S t dn  v rudn}|S t fdodt dp D r t dp d  j}|d dqkr|d drkrdsndt}|S |d dukr|d drkrdvndw}|S tdx| dyt fdzdt d{ D r t d{ d  j}|d dqkr|d dkrd|nd}}|S |d dukr|d dkrd~nd}|S tdx| dt d  v r. dd }|d u rd}|S |d ur*t|dkr*d}|S d}|S t d  v r9d}|S t fddt d D rKd}|S d}|S )Nr(   r   	   r   r   rZ   r   rY   r&   r   r   c                 3       | ]}| v V  qd S r   r   r   keyr  r   r   r   b  r   z-infer_diffusers_model_type.<locals>.<genexpr>r    r"   r$   r%   r[   r1   r   i   ra      r0   i@  r^   i  c                 3   r#  r   r   r$  r  r   r   r     r   r3   c                 3   s.    | ]}| v r | j d  dkndV  qdS )r"  i $  FNshaper$  r  r   r   r     s    
zmodel.diffusion_model.pos_embed	pos_embedi   i @ rc   c                 3   r#  r   r   r$  r  r   r   r     r   r4   r5   r8   r9   r6   r7   i@     rd   re   c                 3   r#  r   r   r$  r  r   r   r     r   rV   rj   c                 3   r#  r   r   r$  r  r   r   r     r   r=   c                 3   r#  r   r   )r   gr  r   r   r     s    
)guidance_in.in_layer.biasz/model.diffusion_model.guidance_in.in_layer.biasz#model.diffusion_model.img_in.weightimg_in.weighti  rg      rh   rf   ri   c                 3   r#  r   r   r$  r  r   r   r     r   r>   zvae.encoder.conv_in.conv.biasc                 s       | ]}| d V  qdS )z'transformer_blocks.47.scale_shift_tableN)endswithr$  r   r   r   r         rm    vae.encoder.conv_out.conv.weightrl   @vae.decoder.last_time_embedder.timestep_embedder.linear_1.weightrk   r?   z!encoder.project_in.conv.conv.biasz#decoder.project_in.main.conv.weightr@   rq   @       rp   ro   rn   c                 3   r#  r   r   r$  r  r   r   r     r   rB   rC   c                 3   r#  r   r   r$  r  r   r   r     r   r:   rI      c                 3   r#  r   r   r$  r  r   r   r     r   rK   c                 3   r#  r   r   r$  r  r   r   r     r   rJ   c                 3   r#  r   r   r$  r  r   r   r     r   rN   c                 3   r#  r   r   r$  r  r   r   r     r   rO   z,model.diffusion_model.patch_embedding.weightzpatch_embedding.weightrQ   rv   i   rw   rR   ru   rr      rs   rt   rP   rS   c                 3   r#  r   r   r$  r  r   r   r     r   rT   D   i   rx   ry   H   rz   r{   zUnexpected x_embedder shape: z when loading Cosmos 1.0 model.c                 3   r#  r   r   r$  r  r   r   r     r   rU   r|   r}   r~   r   z when loading Cosmos 2.0 model.rM   z*control_noise_refiner.0.before_proj.weightr   g        r   rL   c                 3   r#  r   r   r$  r  r   r   r   +  r   rW   r   r   )r  r(  r   allr   r  torch)	r  
model_typer%  has_vaeencoder_keydecoder_key
target_keyx_embedder_shapebefore_proj_weightr   r  r   infer_diffusers_model_typeG  s  " e  c  a " ^  [  X  U  R  M  K  I  G  A  ;  5  / (

 $  "                	 

|zxvqo
mkd(a(^[XURLIFC
41(/-)&!
rC  c                 C   s   t | }t| }t|}|S r   )rC   DIFFUSERS_DEFAULT_PIPELINE_PATHScopydeepcopy)r  r<  
model_pathr   r   r   fetch_diffusers_config4  s   
rH  c                 C   s   |r|S t | }t| }|S r   )rC  'DIFFUSERS_TO_LDM_DEFAULT_IMAGE_SIZE_MAP)r  
image_sizer<  r   r   r   set_image_size<  s
   rK  c                 C   s   t |  }g d}|D ]B}d|ddd  |v r4| | jdkr3| | d d d d ddf | |< qd|v rN| | jdkrN| | d d d d df | |< qd S )N)zquery.weightz
key.weightzvalue.weight.r   r   zproj_attn.weight)listr   r   splitndim)r  r   	attn_keysr%  r   r   r   conv_attn_to_linearG  s    rR  c                    s"  |durd}t dd| t||d}d| d d v r/| d d d dur/| d d d d n
| d d d	 d |durHd
}t dd| |}nd }| d d d d d }fddd D }g }	d}
tt|D ]}|
d v rwdnd}|	| |t|d kr|
d9 }
qmg }tt|D ]}|
d v rdnd}|| |
d }
qd durtd trd ntd }nd}dt|d d  }dv rՈd nd}dv r߈d nd}|r|du rd d    fddtd D }d}d}d}d}d}d  durtd  trd  nd  d! }d"v rBd" d#krB|d$v r5d%}d&}nd'}d(v s>J d( }|| ||	|d) ||||||||d*}|durgd+}t dd| ||d,< d-v rrd- |d.< d"v rtd" trd" |d/< d0 |d0< ||d1< |S )2R
    Creates a config for the diffusers based on the config of the LDM model.
    NzConfiguring UNet2DConditionModel with the `image_size` argument to `from_single_file`is deprecated and will be ignored in future versions.rJ  1.0.0rJ  unet_configmodelparamsnetwork_configzConfiguring UNet2DConditionModel with the `num_in_channels` argument to `from_single_file`is deprecated and will be ignored in future versions.in_channelsfirst_stage_configddconfigc                       g | ]} d  | qS )model_channelsr   r   mult)unet_paramsr   r   
<listcomp>u      z9create_unet_diffusers_config_from_ldm.<locals>.<listcomp>channel_multr   attention_resolutionsCrossAttnDownBlock2DDownBlock2Dr   CrossAttnUpBlock2D	UpBlock2Dtransformer_depthch_mult	num_headsuse_linear_in_transformerFr^  num_head_channelsc                    s   g | ]} | qS r   r   )r   c)head_dim_multr   r   rb    s    context_dimr   num_classes
sequential)r&  i   	text_time   
projectionadm_in_channelsnum_res_blocks)sample_sizerZ  down_block_typesblock_out_channelslayers_per_blockcross_attention_dimattention_head_dimuse_linear_projectionr   r   addition_time_embed_dim%projection_class_embeddings_input_dimtransformer_layers_per_blockzConfiguring UNet2DConditionModel with the `upcast_attention` argument to `from_single_file`is deprecated and will be ignored in future versions.upcast_attentiondisable_self_attentionsonly_cross_attentionnum_class_embedsout_channelsup_block_types)r   rK  rangelenappend
isinstanceintrN  )r  r  rJ  r  num_in_channelsdeprecation_messagerZ  
vae_paramsr{  rz  
resolutioni
block_typer  r  vae_scale_factorhead_dimr  r   r   r  r  rq  configr   )rp  ra  r   %create_unet_diffusers_config_from_ldmS  s   









r  c                 K   s   |d urd}t dd| t||d}| d d d d }t| |d}|d |d	 |d
 |d |d |d |d |d |d |d |d |d |d d}|S )NzoConfiguring ControlNetModel with the `image_size` argumentis deprecated and will be ignored in future versions.rJ  rT  rU  rW  rX  control_stage_confighint_channelsrZ  rz  r{  r|  r}  r~  r  r   r   r  r  r  )conditioning_channelsrZ  rz  r{  r|  r}  r~  r  r   r   r  r  r  )r   rK  r  )r  r  rJ  r   r  ra  diffusers_unet_configcontrolnet_configr   r   r   +create_controlnet_diffusers_config_from_ldm  s,   r  c              
      s8  |durd}t dd| t||d}d|v r#d|v r#|d }|d }nd}d}| d d	 d
 d	 d  |du rB|durB|durBt}n|du rWd| d d	 v rW| d d	 d }n|du r]t} fdd d D }dgt| }dgt| }	| d  d ||	| d  d |d	}
|dur|dur|
||d |
S )rS  NzmConfiguring AutoencoderKL with the `image_size` argumentis deprecated and will be ignored in future versions.rJ  rT  rU  r'   edm_stdrW  rX  r[  r\  scale_factorc                    r]  )chr   r_  r  r   r   rb    rc  z8create_vae_diffusers_config_from_ldm.<locals>.<listcomp>rk  DownEncoderBlock2DUpDecoderBlock2DrZ  out_ch
z_channelsrx  )	ry  rZ  r  rz  r  r{  latent_channelsr|  scaling_factor)latents_meanlatents_std)r   rK  PLAYGROUND_VAE_SCALING_FACTORLDM_VAE_DEFAULT_SCALING_FACTORr  update)r  r  rJ  r  r  r  r  r{  rz  r  r  r   r  r   $create_vae_diffusers_config_from_ldm  sB   
r  c                 C   sh   | D ]/}| dd dd dd dd d	d
 dd}|r*| |d |d }||||< qd S )Nzin_layers.0norm1zin_layers.2conv1zout_layers.0norm2zout_layers.3conv2zemb_layers.1time_emb_projskip_connectionconv_shortcutoldnewr   r  ldm_keysnew_checkpointr  mappingldm_keydiffusers_keyr   r   r   #update_unet_resnet_ldm_to_diffusers)  s   
r  c                 C   s0   | D ]}| |d |d }||||< qd S )Nr  r  r  r  r   r   r   &update_unet_attention_ldm_to_diffusers8  s   r  c                 C   s8   | D ]}| |d |d  dd}||||< qd S )Nr  r  nin_shortcutr  r  )r   r  r  r  r  r  r   r   r   "update_vae_resnet_ldm_to_diffusers>  s   r  c                 C   s   | D ]l}| |d |d  dd dd dd d	d
 dd dd dd dd dd dd}||||< || j}t|dkrX|| d d d d df ||< qt|dkrn|| d d d d ddf ||< qd S )Nr  r  znorm.weightzgroup_norm.weightz	norm.biaszgroup_norm.biaszq.weightto_q.weightzq.bias	to_q.biaszk.weightto_k.weightzk.bias	to_k.biaszv.weightto_v.weightzv.bias	to_v.biasproj_out.weightto_out.0.weightproj_out.biasto_out.0.biasr   r      )r   r  r(  r  )r   r  r  r  r  r  r(  r   r   r   &update_vae_attentions_ldm_to_diffusersD  s*   
  r  c                 K   s  d| v }|ri }|   D ]}|dr:| | dd}|d ||dd< |d ||dd< |d	 ||dd
< q|drf| | dd}|d ||dd< |d ||dd< |d	 ||dd< q|drx| | }|||dd< q|dr| | }|||dd< q| | ||< q|S i }|   D ]}|dr| | dd}|d ||dd< |d ||dd< |d	 ||dd
< q|dr| | dd}|d ||dd< |d ||dd< |d	 ||dd< q|dr| | }|||dd< q|dr| | }|||dd< q|dr,| | }|||dd< q|dr?| | }|||dd< q| | ||< q|S )Nr2   in_proj_weightr   r   zattn.in_proj_weightr  r   r  r   r  in_proj_biaszattn.in_proj_biasr  r  r  zout_proj.weightzattn.out_proj.weightr  zout_proj.biaszattn.out_proj.biasr  zclip_mapper.weightzclip_txt_pooled_mapper.weightzclip_mapper.biaszclip_txt_pooled_mapper.bias)r   r0  chunkr   )r  r   
is_stage_cr   r%  weightsr   r   r   4convert_stable_cascade_unet_single_file_to_diffusers^  s`   





r  c              
      s  i t |  }t}tdd |D dkrE|rEtd td |D ] }|drCdd|d	d
d  }| 	||
|d< q#n%tdd |D dkrUtd |D ]}||ri| 	||
|d< qWi }td d }	|	 D ]\}
}|vrqv| ||
< qvd|v r|d dv rtd d }| D ]
\}
}| ||
< qd|v r|d dkrtd d }| D ]
\}
}| ||
< qd|v r|d durdv r؈d |d< tdd D }fddt|D }tdd D }fddt|D }tdd D }fddt|D }td
|D ]  d
 |d d
  } d
 |d d
  } fd d!|  D }t||d"  d#d$| d%| d& d"  d'v ru	d"  d'|d$| d(< 	d"  d)|d$| d*<  fd+d!|  D }|rt||d"  d,d$| d-| d& q| D ]9}t|d
 d.}
|d/ d.krt|| |d0| d1|
 d&d2 qt|| |d0| d3|
 d&d2 qt|D ]  |d d
  } |d d
  } fd4d!|  D }t||d5  d#d6| d%| d&  fd7d!|  D }|r.t||d5  d,d6| d-| d& d5  d8v rSd5  d8 |d6| d9< d5  d: |d6| d;< d5  d<v rxd5  d< |d6| d9< d5  d= |d6| d;< q|S )>zN
    Takes a state dict and a config, and returns a converted checkpoint.
    c                 s   r/  	model_emaN
startswithr   r   r   r   r     r1  z.convert_ldm_unet_checkpoint.<locals>.<genexpr>d   z,Checkpoint has both EMA and non-EMA weights.zIn this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag.zmodel.diffusion_modelz
model_ema.r   rL  r   Nc                 s   r/  r  r  r   r   r   r   r     r1  zIn this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA weights (usually better for inference), please make sure to add the `--extract_ema` flag.r   r   r   )timesteprv  r   rt  r  zlabel_emb.weightzclass_embedding.weightc                 S   ,   h | ]}d |v rd |ddd qS input_blocksrL  Nr   r   rO  r   layerr   r   r   	<setcomp>     , z.convert_ldm_unet_checkpoint.<locals>.<setcomp>c                        i | ]   fd dD qS )c                       g | ]}d   |v r|qS input_blocks.r   r$  layer_idr   r   rb        :convert_ldm_unet_checkpoint.<locals>.<dictcomp>.<listcomp>r   r   unet_state_dictr  r   
<dictcomp>      z/convert_ldm_unet_checkpoint.<locals>.<dictcomp>c                 S   r  middle_blockrL  Nr   r  r  r   r   r   r    r  c                    r  )c                    r  middle_block.r   r$  r  r   r   rb    r  r  r   r  r  r  r   r    r  c                 S   r  )output_blocksrL  Nr   r  r  r   r   r   r    r  c                    r  )c                    r  )output_blocks.r   r$  r  r   r   rb    r  r  r   r  r  r  r   r    r  r|  c                    0   g | ]}d   d|v rd   d|vr|qS r  .0.0.opr   r$  r  r   r   rb        *z/convert_ldm_unet_checkpoint.<locals>.<listcomp>r  r  down_blocks.	.resnets.r  r  .0.op.weight.downsamplers.0.conv.weight
.0.op.bias.downsamplers.0.conv.biasc                        g | ]}d   d|v r|qS r  .1r   r$  r  r   r   rb         r  .attentions.r   r   r  mid_block.resnets.r  mid_block.attentions.c                    r  )r  r  r  r   r$  r  r   r   rb    r  r  
up_blocks.c                    r  )r  r  z.1.convr   r$  r  r   r   rb     r  z.1.conv.weight.upsamplers.0.conv.weightz.1.conv.bias.upsamplers.0.conv.biasz.2.conv.weightz.2.conv.bias)rN  r   LDM_UNET_KEYsumloggerwarningr  r   rO  r  r   DIFFUSERS_TO_LDM_MAPPINGitemsr  r  r  r  max)r  r  extract_emar   r   unet_keyr%  flat_ema_keyr  ldm_unet_keysr  r  class_embed_keysaddition_embed_keysnum_input_blocksr  num_middle_blocksmiddle_blocksnum_output_blocksr  block_idlayer_in_block_idresnets
attentionsr   )r  r  r   convert_ldm_unet_checkpoint  s  
















r"  c              
      sB  d| v r| S d| v r|  ni  t |  }t}|D ]}||r+| | ||d< qi }td d }| D ]\}}	|	 vrAq8 |	 ||< q8tdd  D }
 fdd	t	|
D }t	d
|
D ]|d
 |d d
  }d
 |d d
  }fdd| D }t
|| d dd| d| d d d v r d d|d| d<  d d|d| d< fdd| D }|rt|| d dd| d| d qat	|
D ]  d d|d d<  d d|d d< qtd d  D } fd!d	t	|D }| D ]9}t|d
 d"}|d# d"kr@t
|| | d$| d%| dd& qt|| | d$| d'| dd& q d(|d)<  d*|d+< d,d  D }t|}t	d
|d
 D ])}|d
 }d#| } d-| d|d.| d<  d-| d|d.| d< qu|S )/Nr   r   r   r    r   c                 S   r  r  r  r  r   r   r   r  [  r  z0convert_controlnet_checkpoint.<locals>.<setcomp>c                    r  )c                    r  r  r   r$  r  r   r   rb  ^  r  <convert_controlnet_checkpoint.<locals>.<dictcomp>.<listcomp>r   r  controlnet_state_dictr  r   r  ]  r  z1convert_controlnet_checkpoint.<locals>.<dictcomp>r   r|  c                    r  r  r   r$  r  r   r   rb  g  r  z1convert_controlnet_checkpoint.<locals>.<listcomp>r  r  r  r  r  r  r  r   r  c                    r  r  r   r$  r  r   r   rb  y  r  r  r  zzero_convs.z	.0.weightzcontrolnet_down_blocks..weightz.0.bias.biasc                 S   r  r  r  r  r   r   r   r    r  c                    r  )c                    r  r  r   r$  r  r   r   rb    r  r#  r   r  r$  r  r   r    r  r   r   r  r  r  r	  zmiddle_block_out.0.weightzcontrolnet_mid_block.weightzmiddle_block_out.0.biaszcontrolnet_mid_block.biasc                 S   s<   h | ]}d |v rd|vrd|vrd |ddd qS )input_hint_blockzinput_hint_block.0zinput_hint_block.14rL  Nr   r  r  r   r   r   r    s
    zinput_hint_block.z!controlnet_cond_embedding.blocks.)rN  r   LDM_CONTROLNET_KEYr  r  r   r  r  r  r  r  r  r  )r  r  r   r   controlnet_keyr%  r  ldm_controlnet_keysr  r  r  r  r  r  r   r!  r  r  cond_embedding_blocksnum_cond_embedding_blocksidxdiffusers_idxcond_block_idr   )r%  r  r   convert_controlnet_checkpoint=  s   




 




r1  c              	      s"  i t |  }d}tD ]tfdd|D r}q|D ]}||r0| |||d< qi }td }| D ]\}}|vrDq;| ||< q;t	|d }	fddt
|	D }
t
|	D ]Ffdd	|
 D }t||d
 dd ddd d dv rd d|d d< d d|d d< q`dd	 D }d}t
d|d D ]fdd	|D }t||d dd  dd qdd	 D }t||dddd t	|d }fd dt
|D }t
|D ]K|d    fd!d	|  D }t||d"  dd# ddd d$  d%v rHd$  d% |d& d'< d$  d( |d& d)< qd*d	 D }d}t
d|d D ]fd+d	|D }t||d dd  dd qYd,d	 D }t||dddd t| |S )-Nr   c                 3   s    | ]}|  V  qd S r   r  r   )ldm_vae_keyr   r   r     r1  z-convert_ldm_vae_checkpoint.<locals>.<genexpr>r   rz  c                    r  )c                    r  )down.r   r$  r  r   r   rb    r  9convert_ldm_vae_checkpoint.<locals>.<dictcomp>.<listcomp>r   r  vae_state_dictr  r   r        z.convert_ldm_vae_checkpoint.<locals>.<dictcomp>c                    .   g | ]}d   |v rd   d|vr|qS )r3  z.downsampler   r$  r  r   r   rb    s   . z.convert_ldm_vae_checkpoint.<locals>.<listcomp>r3  z.blockr  z.resnetsr  r  zencoder.down.z.downsample.conv.weightencoder.down_blocks.r  z.downsample.conv.biasr  c                 S      g | ]}d |v r|qS )zencoder.mid.blockr   r$  r   r   r   rb    rc  r   r   c                    r  )zencoder.mid.block_r   r$  r  r   r   rb    r  z
mid.block_r  c                 S   r:  )zencoder.mid.attnr   r$  r   r   r   rb    rc  z
mid.attn_1zmid_block.attentions.0r  c                    r  )c                    r  )up.r   r$  r  r   r   rb    r  r4  r   r  r5  r  r   r    r7  c                    r8  )r;  z	.upsampler   r$  )r  r   r   rb    s    (r;  r
  zdecoder.up.z.upsample.conv.weightdecoder.up_blocks.r  z.upsample.conv.biasr  c                 S   r:  )zdecoder.mid.blockr   r$  r   r   r   rb    rc  c                    r  )zdecoder.mid.block_r   r$  r  r   r   rb    r  c                 S   r:  )zdecoder.mid.attnr   r$  r   r   r   rb    rc  )rN  r   LDM_VAE_KEYSr   r  r  r   r  r  r  r  r  r  rR  )r  r  r   vae_keyr%  r  vae_diffusers_ldm_mapr  r  num_down_blocksdown_blocksr   mid_resnetsnum_mid_res_blocksmid_attentionsnum_up_blocks	up_blocksr   )r  r  r2  r6  r   convert_ldm_vae_checkpoint  s   








rG  c                 C   sh   t |  }i }g }|t |r|| |D ]}|D ]}||r0||d}| |||< qq|S )Nr   )rN  r   extendLDM_CLIP_PREFIX_TO_REMOVEr  r  r   r  )r  remove_prefixr   text_model_dictremove_prefixesr%  r   r  r   r   r   convert_ldm_clip_checkpoint$  s   


rM  cond_stage_model.model.c                 C   s4  i }|d }||v rt || jd }nt| jdr| jj}nt}t| }t}t	d d }|
 D ]'\}	}
||
 }
|
|vr@q3|
|v rEq3|
drT||
 j ||	< q3||
 ||	< q3|D ]}||v rdq]||d slq]||d d}	t	d d }|
 D ]\}}|	||d	dd
d}	q~|d	r||}|d |d d f   ||	d < |||d d d f   ||	d < ||d d d d f   ||	d < q]|d
r||}|d |   ||	d < |||d    ||	d < ||d d    ||	d < q]||||	< q]|S )Nr   r   hidden_sizer   r   ztransformer.r   r   z.in_proj_weightz.in_proj_biasz.q_proj.weightr   z.k_proj.weightz.v_proj.weightz.q_proj.biasz.k_proj.biasz.v_proj.bias)r  r(  hasattrr  rO  !LDM_OPEN_CLIP_TEXT_PROJECTION_DIMrN  r    SD_2_TEXT_ENCODER_KEYS_TO_IGNOREr  r  r0  T
contiguousr  r   r  clonedetach)
text_modelr  r   rK  text_proj_keytext_proj_dimr   keys_to_ignoreopenclip_diffusers_ldm_mapr  r  r%   transformer_diffusers_to_ldm_mapnew_keyold_keyweight_valuer   r   r   convert_open_clip_checkpoint6  sV   



$
*

"r`  r   c                 C   s4  |rd|i}nt |}|r8td t|st|r#d}||d< d}nt|r0d}||d< d}nd}||d< d}| jjdi |||d}t rKt	nt
}	|	  | |}
W d    n1 s_w   Y  |
jjjjjd	 }t|rvt|}n~t|r|td
  jd	 |krt|}njt|r|td  jd	 |krt|d}t||d< nNt|rd}t|
||d}n@t|r|td  jd	 |krd}t|
||d}n't|rd}t|
||d}nt|r|td  jd	 |krt|d}ntdt rt|
||d t  n|
j|dd |d ur|
| |
  |
S )NrX   zDetected legacy CLIP loading behavior. Please run `from_single_file` with `local_files_only=False once to update the local cache directory with the necessary CLIP model config files. Attempting to load CLIP model from legacy cache directory.openai/clip-vit-large-patch14r   stabilityai/stable-diffusion-2text_encoder(laion/CLIP-ViT-bigG-14-laion2B-39B-b160kr]   r   r"  r*   r+   z!text_encoders.clip_l.transformer.r   rN  )r   r-   zconditioner.embedders.1.model.zconditioner.embedders.0.model.r/   z!text_encoders.clip_g.transformer.zDThe provided checkpoint does not seem to contain a valid CLIP model.dtypeF)strictr   )rH  r  r  r  r  r  config_classfrom_pretrainedr   r   r   rW  
embeddingsposition_embeddingweightr(  rM  r  r  r;  eyer`  r  r  r  r   r   r   r   toeval)clsr  r]   r  torch_dtyper   is_legacy_loadingclip_configmodel_configctxrW  position_embedding_dimdiffusers_format_checkpointr   r   r   r   $create_diffusers_clip_model_from_ldmx  sv   	






ry  c                 K   sb  | dd }| dd }|d urd}tdd| |d ur$d}tdd| t}t|d}	d|v r3|d nd }
|rBt|d d	 d
d}nd}||d< |	dkrY|d u rX|
dkrVdnd}n|p\d}||d< |	dv rhd}n4|	dkrod}n-|r|d d	  d}|d d	  d}nd}d}||d< ||d< d|d< d|d< d|d< |d kr| ddd!d"ddd d#d$S |d u r| |S |d%krd"|d&< t|}|S |d'krt|}|S |d(krt|}|S |dkrt	|}|S |d)krt
|}|S |d*krt|}|S |d+kr	t|}|S |dkr)d,d-dd.d"ddd/d0d1d2d3d4d5dd6}td9i |}|S td7| d8):Nr   r   a&  Please pass an instance of a Scheduler object directly to the `scheduler` argument in `from_single_file`

Example:

from diffusers import StableDiffusionPipeline, DDIMScheduler

scheduler = DDIMScheduler()
pipe = StableDiffusionPipeline.from_single_file(<checkpoint path>, scheduler=scheduler)
rT  a~  Please configure an instance of a Scheduler with the appropriate `prediction_type` and pass the object directly to the `scheduler` argument in `from_single_file`.

Example:

from diffusers import StableDiffusionPipeline, DDIMScheduler

scheduler = DDIMScheduler(prediction_type="v_prediction")
pipe = StableDiffusionPipeline.from_single_file(<checkpoint path>, scheduler=scheduler)
r  global_steprW  rX  	timestepsr   r   r   iY r   v_prediction)r   r   euler
playgroundedm_dpm_solver_multisteplinear_start
linear_endg{Gz?g(\µ?r   r   r   r   Fclip_sampler   low_res_schedulerg-C6?Tfixed_small)r   r   r   r  r   r   trained_betasvariance_typepndmr   lmsheunzeuler-ancestraldpmddimzdpmsolver++gףp=
?zerog      @r   r   g      T@gMb`?r   midpoint)algorithm_typedynamic_thresholding_ratioeuler_at_finalfinal_sigmas_typelower_order_finalr   r   rhor   
sigma_data	sigma_max	sigma_minsolver_ordersolver_typethresholdingzScheduler of type z doesn't exist!r   )r  r   SCHEDULER_DEFAULT_CONFIGrC  getattrfrom_configr   r   r   r   r
   r   r   r	   r   )rq  r  component_namer  r   r   r   r  scheduler_configr<  rz  r   r   r   	schedulerr   r   r   _legacy_load_scheduler  s   


+
(
%
"




r  c                 C   s   |rd|i}nt |}t|st|rd}||d< d}nt|r)d}||d< d}nd}||d< d}| jdi |||d}|S )	NrX   ra  r   rb  	tokenizerrd  re  r   )rH  r  r  r  rj  )rq  r  r  r   rt  r]   r  r   r   r   _legacy_load_clip_tokenizerh  s    
r  c                 C   s6   ddl m} tjd| |d}|jd| |d}||dS )Nr   )StableDiffusionSafetyCheckerz'CompVis/stable-diffusion-safety-checker)r   rr  )safety_checkerfeature_extractor))pipelines.stable_diffusion.safety_checkerr  r   rj  )r   rr  r  r  r  r   r   r   _legacy_load_safety_checker  s   
r  c                 C   s(   | j ddd\}}tj||gdd}|S Nr   r   dimr  r;  cat)rm  r  shiftscale
new_weightr   r   r   swap_scale_shift     r  c                 C   (   | j ddd\}}tj||gdd}|S r  r  )rm  projgater  r   r   r   swap_proj_gate  r  r  c                 C   sF   g }|   D ]}d|v rt|dd }|| qttt|S )Nzattn2.rL  r   )r   r  rO  r  tuplesortedset)r   attn2_layersr%  	layer_numr   r   r   get_attn2_layers  s   
r  c                 C   s   | d j d }|S )Ncontext_embedder.weightr   r'  )r   caption_projection_dimr   r   r   get_caption_projection_dim  s   r  c                 K   st  i }t |  }|D ]}d|v r| || |dd< q
t tdd | D d d }t| }t| }tdd |  D }| d|d	< | d
|d< | d|d< | d|d< | d|d< | d|d< | d|d< | d|d< | d|d< | d|d< | d|d< | d|d< | d|d< t|D ]}	t	j
| d |	 d!d"d#d$\}
}}t	j
| d |	 d%d"d#d$\}}}t	j
| d |	 d&d"d#d$\}}}t	j
| d |	 d'd"d#d$\}}}t	|
g|d(|	 d)< t	|g|d(|	 d*< t	|g|d(|	 d+< t	|g|d(|	 d,< t	|g|d(|	 d-< t	|g|d(|	 d.< t	|g|d(|	 d/< t	|g|d(|	 d0< t	|g|d(|	 d1< t	|g|d(|	 d2< t	|g|d(|	 d3< t	|g|d(|	 d4< |r| d |	 d5|d(|	 d6< | d |	 d7|d(|	 d8< | d |	 d9|d(|	 d:< | d |	 d;|d(|	 d<< | d |	 d=|d(|	 d>< | d |	 d?|d(|	 d@< |	|d ks| d |	 dA|d(|	 dB< | d |	 dC|d(|	 dD< |	|v rt	j
| d |	 dEd"d#d$\}}}t	j
| d |	 dFd"d#d$\}}}t	|g|d(|	 dG< t	|g|d(|	 dH< t	|g|d(|	 dI< t	|g|d(|	 dJ< t	|g|d(|	 dK< t	|g|d(|	 dL< |r| d |	 dM|d(|	 dN< | d |	 dO|d(|	 dP< | d |	 dQ|d(|	 dR< | d |	 dS|d(|	 dT< | d |	 dU|d(|	 dV< | d |	 dW|d(|	 dX< |	|d ks| d |	 dY|d(|	 dZ< | d |	 d[|d(|	 d\< n&t| d |	 dY|d$|d(|	 dZ< t| d |	 d[|d$|d(|	 d\< | d |	 d]|d(|	 d^< | d |	 d_|d(|	 d`< | d |	 da|d(|	 db< | d |	 dc|d(|	 dd< |	|d ks| d |	 de|d(|	 df< | d |	 dg|d(|	 dh< | d |	 di|d(|	 dj< | d |	 dk|d(|	 dl< q| dm|dn< | do|dp< t| dq|d$|dr< t| ds|d$|dt< |S )uNr   r   c                 s   ,    | ]}d |v rt |ddd V  qdS )joint_blocksrL  r   r   Nr  rO  r   r   r   r   r        * zBconvert_sd3_transformer_checkpoint_to_diffusers.<locals>.<genexpr>r"  r   c                 s       | ]}d |v V  qdS )ln_qNr   r$  r   r   r   r     r   r)  pos_embed.pos_embedx_embedder.proj.weightpos_embed.proj.weightx_embedder.proj.biaspos_embed.proj.biasr<   1time_text_embed.timestep_embedder.linear_1.weightt_embedder.mlp.0.bias/time_text_embed.timestep_embedder.linear_1.biast_embedder.mlp.2.weight1time_text_embed.timestep_embedder.linear_2.weightt_embedder.mlp.2.bias/time_text_embed.timestep_embedder.linear_2.biasr  context_embedder.biaszy_embedder.mlp.0.weight-time_text_embed.text_embedder.linear_1.weightzy_embedder.mlp.0.bias+time_text_embed.text_embedder.linear_1.biaszy_embedder.mlp.2.weight-time_text_embed.text_embedder.linear_2.weightzy_embedder.mlp.2.bias+time_text_embed.text_embedder.linear_2.biaszjoint_blocks.z.x_block.attn.qkv.weightr   r   r  z.context_block.attn.qkv.weightz.x_block.attn.qkv.biasz.context_block.attn.qkv.biastransformer_blocks..attn.to_q.weight.attn.to_q.bias.attn.to_k.weight.attn.to_k.bias.attn.to_v.weight.attn.to_v.biasz.attn.add_q_proj.weightz.attn.add_q_proj.biasz.attn.add_k_proj.weightz.attn.add_k_proj.biasz.attn.add_v_proj.weightz.attn.add_v_proj.biasz.x_block.attn.ln_q.weightz.attn.norm_q.weightz.x_block.attn.ln_k.weightz.attn.norm_k.weightz.context_block.attn.ln_q.weightz.attn.norm_added_q.weightz.context_block.attn.ln_k.weightz.attn.norm_added_k.weightz.x_block.attn.proj.weightz.attn.to_out.0.weightz.x_block.attn.proj.biasz.attn.to_out.0.biasz.context_block.attn.proj.weightz.attn.to_add_out.weightz.context_block.attn.proj.biasz.attn.to_add_out.biasz.x_block.attn2.qkv.weightz.x_block.attn2.qkv.bias.attn2.to_q.weight.attn2.to_q.bias.attn2.to_k.weight.attn2.to_k.bias.attn2.to_v.weight.attn2.to_v.biasz.x_block.attn2.ln_q.weightz.attn2.norm_q.weightz.x_block.attn2.ln_k.weightz.attn2.norm_k.weightz.x_block.attn2.proj.weight.attn2.to_out.0.weightz.x_block.attn2.proj.bias.attn2.to_out.0.biasz".x_block.adaLN_modulation.1.weight.norm1.linear.weightz .x_block.adaLN_modulation.1.biasz.norm1.linear.biasz(.context_block.adaLN_modulation.1.weightz.norm1_context.linear.weightz&.context_block.adaLN_modulation.1.biasz.norm1_context.linear.biasz.x_block.mlp.fc1.weightz.ff.net.0.proj.weightz.x_block.mlp.fc1.biasz.ff.net.0.proj.biasz.x_block.mlp.fc2.weightz.ff.net.2.weightz.x_block.mlp.fc2.biasz.ff.net.2.biasz.context_block.mlp.fc1.weightz.ff_context.net.0.proj.weightz.context_block.mlp.fc1.biasz.ff_context.net.0.proj.biasz.context_block.mlp.fc2.weightz.ff_context.net.2.weightz.context_block.mlp.fc2.biasz.ff_context.net.2.biasfinal_layer.linear.weightr  final_layer.linear.biasr  %final_layer.adaLN_modulation.1.weightnorm_out.linear.weight#final_layer.adaLN_modulation.1.biasnorm_out.linear.bias)rN  r   popr   r  r  r  r   r  r;  r  r  r  )r  r   converted_state_dictr   r   
num_layersdual_attention_layersr  has_qk_normr  sample_qsample_ksample_v	context_q	context_k	context_vsample_q_biassample_k_biassample_v_biascontext_q_biascontext_k_biascontext_v_bias	sample_q2	sample_k2	sample_v2sample_q2_biassample_k2_biassample_v2_biasr   r   r   /convert_sd3_transformer_checkpoint_to_diffusers  s@  




























r  c                 C   s   d| v rdS dS )Nz-text_encoders.t5xxl.transformer.shared.weightTFr   r  r   r   r   is_t5_in_single_filer  s   r  c                 C   sR   t |  }i }dg}|D ]}|D ]}||r%||d}| |||< qq|S )Nz text_encoders.t5xxl.transformer.r   )rN  r   r  r   r  )r  r   rK  rL  r%  r   r  r   r   r   &convert_sd3_t5_checkpoint_to_diffusersy  s   
r  c                    s  |rd|i}nt |}| jjdi |||d}t rtnt}|  | |}W d    n1 s2w   Y  t|}	t rIt||	|d t  n|	|	 | j
d uoW|tjk}
|
r^|j
}ng }|d ur| D ]\ }t fdd|D r|jtj|_qh|S )NrX   re  rf  c                 3   s    | ]
}|  d v V  qdS )rL  N)rO  )r   module_to_keep_in_fp32r   r   r   r     s    z<create_diffusers_t5_model_from_checkpoint.<locals>.<genexpr>r   )rH  ri  rj  r   r   r   r  r   r   r   _keep_in_fp32_modulesr;  float16named_parametersr   dataro  float32)rq  r  r]   r  rr  r   ru  rv  rW  rx  use_keep_in_fp32_moduleskeep_in_fp32_modulesparamr   r  r   )create_diffusers_t5_model_from_checkpoint  s.   


r  c                 K   s\   i }|   D ]%\}}d|v rq|||dddddddd	d
ddd< q|S )Npos_encoderz.norms.0z.norm1z.norms.1z.norm2z.ff_normz.norm3z.attention_blocks.0.attn1z.attention_blocks.1.attn2z.temporal_transformerr   )r  r   )r  r   r  r   vr   r   r   +convert_animatediff_checkpoint_to_diffusers  s   

	r  c           "      K   sz  i }t |  }|D ]}d|v r| || |dd< q
t tdd | D d d }t tdd | D d d }d}d	}d
d }	| d|d< | d|d< | d|d< | d|d< | d|d< | d|d< | d|d< | d|d< tdd | D }
|
r| d|d< | d|d < | d!|d"< | d#|d$< | d%|d&< | d'|d(< | d)|d*< | d+|d,< t|D ]}d-| d.}| d/| d0|| d1< | d/| d2|| d3< | d/| d4|| d5< | d/| d6|| d7< tj| d/| d8d9d:d;\}}}tj| d/| d<d9d:d;\}}}tj| d/| d=d9d:d;\}}}tj| d/| d>d9d:d;\}}}t	|g|| d?< t	|g|| d@< t	|g|| dA< t	|g|| dB< t	|g|| dC< t	|g|| dD< t	|g|| dE< t	|g|| dF< t	|g|| dG< t	|g|| dH< t	|g|| dI< t	|g|| dJ< | d/| dK|| dL< | d/| dM|| dN< | d/| dO|| dP< | d/| dQ|| dR< | d/| dS|| dT< | d/| dU|| dV< | d/| dW|| dX< | d/| dY|| dZ< | d/| d[|| d\< | d/| d]|| d^< | d/| d_|| d`< | d/| da|| db< | d/| dc|| dd< | d/| de|| df< | d/| dg|| dh< | d/| di|| dj< qt|D ]}dk| d.}| dl| dm|| dn< | dl| do|| dp< t
|| }||||f}tj| dl| dq|d:d;\}}}}tj| dl| dr|d:d;\}}} }!t	|g|| d?< t	|g|| d@< t	|g|| dA< t	|g|| dB< t	|g|| dC< t	| g|| dD< t	|g|| ds< t	|!g|| dt< | dl| du|| dL< | dl| dv|| dN< | dl| dw|| dx< | dl| dy|| dz< q| d{|dx< | d||dz< |	| d}|d~< |	| d|d< |S )Nr   r   c                 s   r  double_blocks.rL  r   r   Nr  r   r   r   r   r     r  zCconvert_flux_transformer_checkpoint_to_diffusers.<locals>.<genexpr>r"  r   c                 s   r  single_blocks.rL  r   r   Nr  r   r   r   r   r     r        @   c                 S   r  r  r  rm  r  r  r  r   r   r   r    r  zJconvert_flux_transformer_checkpoint_to_diffusers.<locals>.swap_scale_shiftztime_in.in_layer.weightr  ztime_in.in_layer.biasr  ztime_in.out_layer.weightr  ztime_in.out_layer.biasr  zvector_in.in_layer.weightr  zvector_in.in_layer.biasr  zvector_in.out_layer.weightr  zvector_in.out_layer.biasr  c                 s   r  )guidanceNr   r   r   r   r   r     r   zguidance_in.in_layer.weightz1time_text_embed.guidance_embedder.linear_1.weightr,  z/time_text_embed.guidance_embedder.linear_1.biaszguidance_in.out_layer.weightz1time_text_embed.guidance_embedder.linear_2.weightzguidance_in.out_layer.biasz/time_text_embed.guidance_embedder.linear_2.biastxt_in.weightr  txt_in.biasr  r-  x_embedder.weightimg_in.biasx_embedder.biasr  rL  r  z.img_mod.lin.weightnorm1.linear.weightz.img_mod.lin.biasnorm1.linear.biasz.txt_mod.lin.weightnorm1_context.linear.weightz.txt_mod.lin.biasnorm1_context.linear.bias.img_attn.qkv.weightr   r   r  .txt_attn.qkv.weight.img_attn.qkv.bias.txt_attn.qkv.biasattn.to_q.weightattn.to_q.biasattn.to_k.weightattn.to_k.biasattn.to_v.weightattn.to_v.biasattn.add_q_proj.weightattn.add_q_proj.biasattn.add_k_proj.weightattn.add_k_proj.biasattn.add_v_proj.weightattn.add_v_proj.bias.img_attn.norm.query_norm.scaleattn.norm_q.weight.img_attn.norm.key_norm.scaleattn.norm_k.weight.txt_attn.norm.query_norm.scaleattn.norm_added_q.weight.txt_attn.norm.key_norm.scaleattn.norm_added_k.weight.img_mlp.0.weightff.net.0.proj.weight.img_mlp.0.biasff.net.0.proj.bias.img_mlp.2.weightff.net.2.weight.img_mlp.2.biasff.net.2.bias.txt_mlp.0.weightff_context.net.0.proj.weight.txt_mlp.0.biasff_context.net.0.proj.bias.txt_mlp.2.weightff_context.net.2.weight.txt_mlp.2.biasff_context.net.2.bias.img_attn.proj.weightattn.to_out.0.weight.img_attn.proj.biasattn.to_out.0.bias.txt_attn.proj.weightattn.to_add_out.weight.txt_attn.proj.biasattn.to_add_out.biassingle_transformer_blocks.r  z.modulation.lin.weightznorm.linear.weightz.modulation.lin.biasznorm.linear.bias.linear1.weight.linear1.biasproj_mlp.weightproj_mlp.bias.norm.query_norm.scale.norm.key_norm.scale.linear2.weightr  .linear2.biasr  r  r  r  r  r  r  )rN  r   r  r   r  r   r  r;  r  r  r  rO  )"r  r   r  r   r   r  num_single_layers	mlp_ratio	inner_dimr  has_guidancer  block_prefixr  r  r  r  r  r  r  r  r  r  r  r  mlp_hidden_dim
split_sizeqr  mlpq_biask_biasv_biasmlp_biasr   r   r   0convert_flux_transformer_checkpoint_to_diffusers  s@  



$














&

 rm  c                    s    fddt   D }dddddd}i }t | D ]}|}| D ]
\}}|||}q%||||< qt | D ]}| D ]\}	}
|	|vrMqD|
|| qDq>|S )	Nc                    s    i | ]}d |vr|  |qS )r   r  r$  r  r   r   r  	  r  zCconvert_ltx_transformer_checkpoint_to_diffusers.<locals>.<dictcomp>r   proj_in
time_embednorm_qnorm_k)r   patchify_projadaln_singleq_normk_norm)rN  r   r  r   r  )r  r   r  TRANSFORMER_KEYS_RENAME_DICTTRANSFORMER_SPECIAL_KEYS_REMAPr%  r]  replace_key
rename_keyspecial_keyhandler_fn_inplacer   r  r   /convert_ltx_transformer_checkpoint_to_diffusers	  s(   r}  c                    s   fddt   D }dtfdd}i dddd	d
ddddd
ddddddddddddddddddddddd d!dd"dd	d#d$d%d&d'd(	}d	d)ddd
ddddd*d+d,}i dd	d
d)dddddd
ddddddddddddddddddd d"d-dd.d/d	d*d+d0}|||d1}|d2 jd3 d4kr|| n	d5|v r|| t | D ]}|}	| D ]
\}
}|	|
|}	q||||	< qt | D ]}| D ]\}}||vrq||| qq|S )6Nc                    s    i | ]}d |v r|  |qS )r   rn  r$  r  r   r   r  	  r  z;convert_ltx_vae_checkpoint_to_diffusers.<locals>.<dictcomp>r%  c                 S      | |  d S r   rn  r%  r   r   r   r   remove_keys_	     z=convert_ltx_vae_checkpoint_to_diffusers.<locals>.remove_keys_r   r   up_blocks.0	mid_blockup_blocks.1up_blocks.2up_blocks.1.upsamplers.0up_blocks.3up_blocks.4zup_blocks.2.conv_inup_blocks.5up_blocks.2.upsamplers.0up_blocks.6up_blocks.7zup_blocks.3.conv_inup_blocks.8zup_blocks.3.upsamplers.0zup_blocks.9down_blocks.0down_blocks.1down_blocks.0.downsamplers.0down_blocks.2zdown_blocks.0.conv_outdown_blocks.3down_blocks.4down_blocks.1.downsamplers.0down_blocks.5zdown_blocks.1.conv_outdown_blocks.2.downsamplers.0zconv_shortcut.convr   norm3r  r  )	down_blocks.6down_blocks.7down_blocks.8zdown_blocks.9r  
res_blocksz
norm3.norm$per_channel_statistics.mean-of-means#per_channel_statistics.std-of-meansup_blocks.0.upsamplers.0time_embedderscale_shift_table)r  r  r  r  r  r  r  r  r  last_time_embedderlast_scale_shift_tabler  r  down_blocks.3.downsamplers.0)r  r  r  )per_channel_statistics.channelr  #per_channel_statistics.mean-of-stdsr2  r   r&  r3  )rN  r   strr(  r  r  r   r  )r  r   r  r  VAE_KEYS_RENAME_DICTVAE_091_RENAME_DICTVAE_095_RENAME_DICTVAE_SPECIAL_KEYS_REMAPr%  r]  ry  rz  r{  r|  r   r  r   'convert_ltx_vae_checkpoint_to_diffusers	  s   	
#	

r  c                    s\   fddt   D }dtfdd}dtfdd}i dd	d
d	dddddddddddddddddddddd d!dd"d#d$d%d&d'd(d)d*d+d,d-}d.d/d0}||d1}d2|vrk|| t | D ]}|d d  }	| D ]
\}
}|	|
|}	q}||||	< qqt | D ]}| D ]\}}||vrq||| qq|S )3Nc                       i | ]}|  |qS r   rn  r$  r  r   r   r  
  rc  zBconvert_autoencoder_dc_checkpoint_to_diffusers.<locals>.<dictcomp>r%  c                 S   sj   | | }tj|ddd\}}}| d\}}}| || d< | || d< | || d< d S )Nr   r   r  z.qkv.conv.weightz.to_q.weightz.to_k.weightz.to_v.weight)r  r;  r  
rpartitionsqueeze)r%  r   qkvrg  r   r  parent_module_r   r   r   
remap_qkv_
  s   
zBconvert_autoencoder_dc_checkpoint_to_diffusers.<locals>.remap_qkv_c                 S   s,   |  d\}}}||  || d< d S )Nz.proj.conv.weightz.to_out.weight)r  r  r  )r%  r   r  r  r   r   r   remap_proj_conv_
  s   zHconvert_autoencoder_dc_checkpoint_to_diffusers.<locals>.remap_proj_conv_zmain.r   zop_list.context_moduleattnlocal_moduleconv_outz
aggreg.0.0zto_qkv_multiscale.0.proj_inz
aggreg.0.1zto_qkv_multiscale.0.proj_outzdepth_conv.conv
conv_depthzinverted_conv.convconv_invertedzpoint_conv.conv
conv_pointzpoint_conv.normnormz
conv.conv.zconv.z
conv1.convr  
conv2.convr  z
conv2.normz	proj.normnorm_outencoder.project_in.convzencoder.conv_inzencoder.project_out.0.convzencoder.conv_outzencoder.down_blockszdecoder.conv_inzdecoder.norm_outzdecoder.conv_outzdecoder.up_blocks)zencoder.stageszdecoder.project_in.convzdecoder.project_out.0decoder.project_out.2.convzdecoder.stageszencoder.conv_in.convzdecoder.conv_out.conv)r  r  )zqkv.conv.weightzproj.conv.weightrA   )rN  r   r  r  r  r   r  )r  r   r  r  r  AE_KEYS_RENAME_DICTAE_F32C32_F64C128_F128C512_KEYSAE_SPECIAL_KEYS_REMAPr%  r]  ry  rz  r{  r|  r   r  r   .convert_autoencoder_dc_checkpoint_to_diffusers
  s   	

r  c                 K   s  i }t |  }|D ]}d|v r| || |dd< q
| d|d< | d|d< | d|d< | d	|d
< | d|d< | d|d< | d|d< | d|d< | d|d< | d|d< | d|d< | d|d< | d|d< | d|d< d}t|D ]3}d | d!}d"| d!}| |d# ||d$ < | |d% ||d& < ||d' k r| |d( ||d) < | |d* ||d+ < n| |d( ||d, < | |d* ||d- < | |d. }	|	jd/d0d1\}
}}|
||d2 < |||d3 < |||d4 < | |d5 ||d6 < | |d7 ||d8 < | |d9 ||d: < | |d; ||d< < | |d= }	|	jd/d0d1\}
}}|
||d> < |||d? < |||d@ < | |dA ||dB < | |dC ||dD < ||d' k r| |dE ||dF < | |dG ||dH < t| |dI ||dJ < | |dK ||dL < ||d' k rt| |dM ||dN < | |dO ||dP < qt| dQd0d1|dR< t| dSd0d1|dT< | dU|dV< | dW|dX< | dY|dY< |S )ZNr   r   r  patch_embed.proj.weightr  patch_embed.proj.biasr<   z,time_embed.timestep_embedder.linear_1.weightr  z*time_embed.timestep_embedder.linear_1.biasr  z,time_embed.timestep_embedder.linear_2.weightr  z*time_embed.timestep_embedder.linear_2.biaszt5_y_embedder.to_kv.weightztime_embed.pooler.to_kv.weightzt5_y_embedder.to_kv.biasztime_embed.pooler.to_kv.biaszt5_y_embedder.to_q.weightztime_embed.pooler.to_q.weightzt5_y_embedder.to_q.biasztime_embed.pooler.to_q.biaszt5_y_embedder.to_out.weightztime_embed.pooler.to_out.weightzt5_y_embedder.to_out.biasztime_embed.pooler.to_out.biaszt5_yproj.weightztime_embed.caption_proj.weightzt5_yproj.biasztime_embed.caption_proj.bias0   r  rL  blocks.zmod_x.weightr#  z
mod_x.biasr$  r   zmod_y.weightr%  z
mod_y.biasr&  znorm1_context.linear_1.weightznorm1_context.linear_1.biaszattn.qkv_x.weightr   r   r  zattn1.to_q.weightzattn1.to_k.weightzattn1.to_v.weightzattn.q_norm_x.weightzattn1.norm_q.weightzattn.k_norm_x.weightzattn1.norm_k.weightzattn.proj_x.weightzattn1.to_out.0.weightzattn.proj_x.biaszattn1.to_out.0.biaszattn.qkv_y.weightzattn1.add_q_proj.weightzattn1.add_k_proj.weightzattn1.add_v_proj.weightzattn.q_norm_y.weightzattn1.norm_added_q.weightzattn.k_norm_y.weightzattn1.norm_added_k.weightzattn.proj_y.weightzattn1.to_add_out.weightzattn.proj_y.biaszattn1.to_add_out.biaszmlp_x.w1.weightr@  zmlp_x.w2.weightrD  zmlp_y.w1.weightrH  zmlp_y.w2.weightrL  zfinal_layer.mod.weightr  zfinal_layer.mod.biasr  r  r  r  r  pos_frequencies)rN  r   r  r   r  r  r  r  )r  r   r  r   r   r  r  rd  
old_prefix
qkv_weightrg  r  r   r   r   1convert_mochi_transformer_checkpoint_to_diffusersX
  s   r  c                 K   sN  dd }dd }dd }dd }d	d
 }i dddddddddddddddddddddd d!d"d#d$d%d&d'd(d)d*d+d,d-d.d*d/d0d d1d2d3d4d5d6d7d8}|||||d9}d:d; }	t |  D ]}
|
d d  }| D ]
\}}|||}qw|	| |
| qkt |  D ]}
| D ]\}}||
vrq||
|  qq| S )<Nc                 S   sB   | | }|jddd\}}tj||gdd}||| dd< d S )Nr   r   r  final_layer.adaLN_modulation.1norm_out.linear)r  r  r;  r  r   )r%  r   rm  r  r  r  r   r   r   remap_norm_scale_shift_
  s   
zOconvert_hunyuan_video_transformer_to_diffusers.<locals>.remap_norm_scale_shift_c                 S   s   dd }d| v r7| | }|jddd\}}}|||| dd< |||| dd< |||| dd	< d S | | ||| < d S )
Nc                 S   sX   |  dd}| dd}| dd}| dd}| d	d
}| dd}| dd}|S )Nzindividual_token_refiner.blocksztoken_refiner.refiner_blocksadaLN_modulation.1r  txt_incontext_embeddert_embedder.mlp.0*time_text_embed.timestep_embedder.linear_1t_embedder.mlp.2*time_text_embed.timestep_embedder.linear_2
c_embedderztime_text_embed.text_embedderrh  ff)r   )r%  r]  r   r   r   rz  
  s   zYconvert_hunyuan_video_transformer_to_diffusers.<locals>.remap_txt_in_.<locals>.rename_keyself_attn_qkvr   r   r  	attn.to_q	attn.to_k	attn.to_vr  r  r   )r%  r   rz  rm  to_qto_kto_vr   r   r   remap_txt_in_
  s   

zEconvert_hunyuan_video_transformer_to_diffusers.<locals>.remap_txt_in_c                 S   R   | | }|jddd\}}}||| dd< ||| dd< ||| dd< d S )Nr   r   r  img_attn_qkvr  r  r  r  r%  r   rm  r  r  r  r   r   r   remap_img_attn_qkv_
  
   
zKconvert_hunyuan_video_transformer_to_diffusers.<locals>.remap_img_attn_qkv_c                 S   r  )Nr   r   r  txt_attn_qkvattn.add_q_projattn.add_k_projattn.add_v_projr  r  r   r   r   remap_txt_attn_qkv_
  r  zKconvert_hunyuan_video_transformer_to_diffusers.<locals>.remap_txt_attn_qkv_c                 S   sj  d}d| v rK| | }||||dd|  f}tj||dd\}}}}| ddd}	|||	 d	< |||	 d
< |||	 d< |||	 d< d S d| v r| | }
||||
dd|  f}tj|
|dd\}}}}| ddd}	|||	 d< |||	 d< |||	 d< |||	 d< d S | dd}	|	dd}	|	dd}	|	dd}	| | ||	< d S )Nr  zlinear1.weightr   r   r  single_blockssingle_transformer_blocksrX  r  r  r  z.proj_mlp.weightzlinear1.biasrY  r  r  r  z.proj_mlp.biaslinear2proj_outru  attn.norm_qrv  attn.norm_k)r  sizer;  rO  r   removesuffix)r%  r   rO  linear1_weightrf  rg  r   r  rh  r]  linear1_biasri  rj  rk  rl  r   r   r    remap_single_transformer_blocks_
  s0   

zXconvert_hunyuan_video_transformer_to_diffusers.<locals>.remap_single_transformer_blocks_img_in
x_embedderztime_in.mlp.0r  ztime_in.mlp.2r  zguidance_in.mlp.0z*time_text_embed.guidance_embedder.linear_1zguidance_in.mlp.2z*time_text_embed.guidance_embedder.linear_2zvector_in.in_layerz&time_text_embed.text_embedder.linear_1zvector_in.out_layerz&time_text_embed.text_embedder.linear_2double_blockstransformer_blocksimg_attn_q_normr  img_attn_k_normr  img_attn_projattn.to_out.0txt_attn_q_normattn.norm_added_qtxt_attn_k_normattn.norm_added_ktxt_attn_projattn.to_add_outzimg_mod.linearnorm1.linear	img_norm1
norm1.norm	img_norm2r  r  znorm1_context.linearnorm2_context
ff_contextznorm.linearz	norm.normznorm_out.normr  
net.0.projnet.2ro  )img_mlpztxt_mod.linear	txt_norm1	txt_norm2txt_mlpself_attn_projzmodulation.linearpre_normzfinal_layer.norm_finalfinal_layer.linearfc1fc2input_embedder)r  r  r  r  r  c                 S      |  || |< d S r   rn  r   r^  r]  r   r   r   update_state_dict_:     zJconvert_hunyuan_video_transformer_to_diffusers.<locals>.update_state_dict_)rN  r   r  r   )r  r   r  r  r  r  r  rw  rx  r  r%  r]  ry  rz  r{  r|  r   r   r   .convert_hunyuan_video_transformer_to_diffusers
  s   	
"r  c                 K   s  i }t |  }| dd |d< | dd |d< | dd |d< | dd |d< | dd |d	< | d
d |d< dd }||dd}||dd}t|D ]}ddd}dddd}	| D ])\}
}|	 D ] \}}| d| d|
 d| dd |d| d| d| d< qeq]ddd}| D ]\}
}| d| d|
 dd |d| d| d < qd!d"d#d$d%}d&d'd(d)d*}||fD ]!}| D ]\}}| d| d+| dd |d| d+| d< qqqLt|D ]^}dddd}| D ]\}}| d,| d-| dd |d.| d/| d< q| d,| d0d |d.| d1< d!d"d#d$d*}| D ]\}}| d,| d+| dd |d.| d+| d< q&q| d2d |d3< | d4d }|d ur_t|d d5|d6< nd |d6< | d7|d8< | d9|d:< | d;|d<< |S )=Nregister_tokensr<   ztime_step_proj.linear_1.weightr  ztime_step_proj.linear_1.biasr  ztime_step_proj.linear_2.weightr  ztime_step_proj.linear_2.biasr;   r  c                 S   s<   t  }| D ]}||v rt|dd }|| qt|S )NrL  r   )r  r  rO  addr  )r   
key_prefixr   r   r  r   r   r   calculate_layers]  s   
zNconvert_auraflow_transformer_checkpoint_to_diffusers.<locals>.calculate_layersdouble_layers)r  single_layersr  r	  )mlpXmlpClinear_1linear_2out_projection)c_fc1c_fc2c_projzdouble_layers.rL  r&  zjoint_transformer_blocks.r  norm1_context)modXmodCz	.1.weightz.linear.weightr  r  r  to_out.0)w2qw2kw2vw2o
add_q_proj
add_k_proj
add_v_proj
to_add_out)w1qw1kw1vw1oz.attn.zsingle_layers.z.mlp.rW  z.ff.z.modCX.1.weightr  zfinal_linear.weightr  zmodF.1.weightr  r  positional_encodingr  zinit_x_linear.weightr  zinit_x_linear.biasr  )rN  r   r  r  r  r  )r  r   r  state_dict_keysr  mmdit_layerssingle_dit_layersr  path_mappingweight_mappingorig_k
diffuser_kr   r  x_attn_mappingcontext_attn_mappingattn_mappingr  norm_weightr   r   r   4convert_auraflow_transformer_checkpoint_to_diffusersL  st   


rE  c                 K   sB  i }|  dd  t|  }|D ]}d|v r!|  || |dd< qdddddd	d
ddddd}ddd}ddddd}ddd}dd }	|D ]W}
|
}| D ]
\}}|||}qO| D ]
\}}|||}q^| D ]
\}}|||}qm| D ]
\}}|||}q|d|v r||	|  |
| qG|  |
||< qG|S )Nnorm_final.weightr   r   z#time_caption_embed.caption_embedderz-time_caption_embed.timestep_embedder.linear_1z-time_caption_embed.timestep_embedder.linear_2r  
.to_out.0.rr  rq  r#  r$  linear_3r  )cap_embedderr  r  	attentionz.out.rv  ru  w1w2w3r  r  r  )attention_norm1attention_norm2zcontext_refiner.0.norm1zcontext_refiner.0.norm2zcontext_refiner.1.norm1zcontext_refiner.1.norm2)z!context_refiner.0.attention_norm1z!context_refiner.0.attention_norm2z!context_refiner.1.attention_norm1z!context_refiner.1.attention_norm2norm_out.linear_1norm_out.linear_2)r  r  c                 S   sP   d}d }}t j| |||gdd\}}}|dd||dd||dd|iS )	Ni 	  r   r   r  r  r  r  r  )r;  rO  r   )tensorr  q_dimk_dimv_dimr  r  r  r   r   r    convert_lumina_attn_to_diffusers  s   zFconvert_lumina2_to_diffusers.<locals>.convert_lumina_attn_to_diffusersr  )r  rN  r   r   r  r  )r  r   r  r   r   LUMINA_KEY_MAPATTENTION_NORM_MAPCONTEXT_REFINER_MAPFINAL_LAYER_MAPrV  r%  r  r  r   r   r   convert_lumina2_to_diffusers  sZ   r[  c                 K   s  i }t |  }|D ]}d|v r| || |dd< q
t tdd | D d d }| d | d|d	< | d
|d< | d|d< | d|d< | d|d< | d|d< | d|d< | d|d< | d | d|d< | d|d< | d|d< | d|d < | d!|d"< t|D ].}| d#| d$|d%| d$< tj| d#| d&d'd(d)\}}}	t|g|d%| d*< t|g|d%| d+< t|	g|d%| d,< | d#| d-|d%| d.< | d#| d/|d%| d0< | d#| d1|d%| d2< | d#| d3|d%| d4< tj| d#| d5d6d(d)\}
}tj| d#| d7d6d(d)\}}|
|d%| d8< ||d%| d9< ||d%| d:< ||d%| d;< | d#| d<|d%| d=< | d#| d>|d%| d?< | d#| d@|d%| dA< | d#| dB|d%| dC< | d#| dD|d%| dE< | d#| dF|d%| dG< | d#| dH|d%| dI< q| dJ|dK< | dL|dM< | dN|dO< |S )PNr   r   c                 s   r  )blocksrL  r   r   Nr  r   r   r   r   r     r  z8convert_sana_transformer_to_diffusers.<locals>.<genexpr>r"  r   r)  r  r  r  r  r<   z0time_embed.emb.timestep_embedder.linear_1.weightr  z.time_embed.emb.timestep_embedder.linear_1.biasr  z0time_embed.emb.timestep_embedder.linear_2.weightr  z.time_embed.emb.timestep_embedder.linear_2.biaszt_block.1.weightztime_embed.linear.weightzt_block.1.biasztime_embed.linear.biaszy_embedder.y_embeddingzy_embedder.y_proj.fc1.weightz"caption_projection.linear_1.weightzy_embedder.y_proj.fc1.biasz caption_projection.linear_1.biaszy_embedder.y_proj.fc2.weightz"caption_projection.linear_2.weightzy_embedder.y_proj.fc2.biasz caption_projection.linear_2.biaszattention_y_norm.weightzcaption_norm.weightr  z.scale_shift_tabler  z.attn.qkv.weightr   r   r  z.attn1.to_q.weightz.attn1.to_k.weightz.attn1.to_v.weightz.attn.proj.weightz.attn1.to_out.0.weightz.attn.proj.biasz.attn1.to_out.0.biasz.cross_attn.q_linear.weightr  z.cross_attn.q_linear.biasr  z.cross_attn.kv_linear.weightr   z.cross_attn.kv_linear.biasr  r  r  r  z.cross_attn.proj.weightr  z.cross_attn.proj.biasr  z.mlp.inverted_conv.conv.weightz.ff.conv_inverted.weightz.mlp.inverted_conv.conv.biasz.ff.conv_inverted.biasz.mlp.depth_conv.conv.weightz.ff.conv_depth.weightz.mlp.depth_conv.conv.biasz.ff.conv_depth.biasz.mlp.point_conv.conv.weightz.ff.conv_point.weightr  r  r  r  zfinal_layer.scale_shift_tabler  )	rN  r   r  r   r  r  r;  r  r  )r  r   r  r   r   r  r  r  r  r  linear_sample_klinear_sample_vlinear_sample_k_biaslinear_sample_v_biasr   r   r   %convert_sana_transformer_to_diffusers  s   


$












ra  c                 K   s  dd }dd }dd }dd }i }t |  }|D ]}d	|v r+| || |d	d
< qi dddddddddddddddddddddd d!d"d#d$d%d&d'd(d)d*d+d(d,d-d.d/d0d1d2d3d4d5d*d6}	i }
td7d8 |  D r|	|  |dd gf|
d9< td:d8 |  D r|	|  t |  D ]}|||  qt |  D ]}|}|	 D ]
\}}|||}q| |||< qt | D ]}|
 D ]\}\}}||vrq|||||  q|S );Nc                  S   s   dddddd} t dD ]=}|d }| d	| d
d| dd	| dd| dd	| dd| dd	| dd| dd	| dd| di q| S )Nz&motion_encoder.motion_synthesis_weightzmotion_encoder.conv_in.weightz"motion_encoder.conv_in.act_fn.biaszmotion_encoder.conv_out.weightzmotion_encoder.motion_network)rF   z+motion_encoder.enc.net_app.convs.0.0.weightz)motion_encoder.enc.net_app.convs.0.1.biasz)motion_encoder.enc.net_app.convs.8.weightzmotion_encoder.enc.fc   r   !motion_encoder.enc.net_app.convs.z.conv1.0.weightzmotion_encoder.res_blocks..conv1.weightz.conv1.1.biasz.conv1.act_fn.biasz.conv2.1.weight.conv2.weightz.conv2.2.biasz.conv2.act_fn.biasz.skip.1.weightz.conv_skip.weight)r  r  )mappingsr  conv_idxr   r   r    generate_motion_encoder_mappingsZ  s"   
zNconvert_wan_transformer_to_diffusers.<locals>.generate_motion_encoder_mappingsc                	   S   s   ddddddddd	S )
Nface_adapterz.norm_k.z.norm_q..to_q.z.to_out.conv1_localr  conv3)zface_adapter.fuser_blocksz.k_norm.z.q_norm.z.linear1_q.z	.linear2.zconv1_local.convr  z
conv3.convr   r   r   r   r   generate_face_adapter_mappingsq  s   zLconvert_wan_transformer_to_diffusers.<locals>.generate_face_adapter_mappingsc                 S   s\   | | }|jd d }| ||d }| ||d }|d | ||< ||d  ||< d S )Nr   r   r   )r  r(  r   )r%  r   split_patterntarget_keysrR  	split_idx	new_key_1	new_key_2r   r   r   split_tensor_handler}  s   
zBconvert_wan_transformer_to_diffusers.<locals>.split_tensor_handlerc                 S   s8   d| v rd| v r||  dd d ddf || < d S d S d S )Nrc  r'  r   r   r  r   r   r   reshape_bias_handler  s    zBconvert_wan_transformer_to_diffusers.<locals>.reshape_bias_handlerr   r   ztime_embedding.0z)condition_embedder.time_embedder.linear_1ztime_embedding.2z)condition_embedder.time_embedder.linear_2ztext_embedding.0z)condition_embedder.text_embedder.linear_1ztext_embedding.2z)condition_embedder.text_embedder.linear_2ztime_projection.1zcondition_embedder.time_proj
cross_attnattn2	self_attnattn1z.o.rG  z.q.rj  z.k.z.to_k.z.v.z.to_v.z.k_img.z.add_k_proj.z.v_img.z.add_v_proj.z.norm_k_img.z.norm_added_k.rD   r  z	head.headr  
modulationzffn.net.0.projz	ffn.net.2norm__placeholderr  r  z'condition_embedder.image_embedder.norm1z/condition_embedder.image_embedder.ff.net.0.projz*condition_embedder.image_embedder.ff.net.2z'condition_embedder.image_embedder.norm2ro  )zffn.0zffn.2r  r  rz  zimg_emb.proj.0zimg_emb.proj.1zimg_emb.proj.3zimg_emb.proj.4before_proj
after_projc                 s   r  )ri  Nr   r   r   r   r   r     r   z7convert_wan_transformer_to_diffusers.<locals>.<genexpr>z.linear1_kv.c                 s   r  )motion_encoderNr   r   r   r   r   r     r   )rN  r   r  r   r   r  r  )r  r   rh  rm  rs  rt  r  r   r   rw  SPECIAL_KEYS_HANDLERSr%  r]  ry  rz  r   
handler_fnro  r   r   r   $convert_wan_transformer_to_diffusersY  s   
	
"r  c                 K   s  i }i ddddddddd	d
dddddddddddddddddddddd d!d"d#d$d%d&d'd(d)d*}d+d,d-d.d/d0d1d2d3d4d5
}d6d7d8d9d:d;d<}d=d>d?d@dA}|   D ]\}}||v rt|| }	|||	< qb||v r|| }	|||	< qb||v r|| }	|||	< qb||v r|| }	|||	< qb|dBkr||dC< qb|dDkr||dE< qb|dFkr||dG< qb|dHkr||dI< qb|dJr+|dJdK}	dL|	v r|	dLdM}	nPdN|	v r|	dNdO}	nEdP|	v r|	dPdQ}	n:dR|	v r|	dRdS}	n/dT|	v r|	dTdU}	n#dV|	v r|	dVdW}	ndX|	v r|	dXdY}	ndZ|	v r&|	dZd[}	|||	< qb|d\rJ|d]}
t|
d^ }d_|v r|d`v rKda}|}n)|dbv rWdc}|dd }n|dev rcd^}|df }n|dgv rodh}|di }n|||< qbdL|v rdj| dk| dM}	nMdN|v rdj| dk| dO}	n>dP|v rdj| dk| dQ}	n/dR|v rdj| dk| dS}	n dT|v rdj| dk| dU}	ndV|v rdj| dk| dW}	n|}	|||	< qbdl|v r|ddkr|dldm}	|	dndo}	n|d\dj}	|	dldp}	|||	< qbdq|v sdr|v r?|dhkr|d\| ds}	n$|dtkr%|d\| du}	n|dvkr4|d\| dw}	n|d\dj}	|||	< qb|d\dj}	|||	< qb|||< qb|S )xNz!encoder.middle.0.residual.0.gammaz'encoder.mid_block.resnets.0.norm1.gammaz encoder.middle.0.residual.2.biasz&encoder.mid_block.resnets.0.conv1.biasz"encoder.middle.0.residual.2.weightz(encoder.mid_block.resnets.0.conv1.weightz!encoder.middle.0.residual.3.gammaz'encoder.mid_block.resnets.0.norm2.gammaz encoder.middle.0.residual.6.biasz&encoder.mid_block.resnets.0.conv2.biasz"encoder.middle.0.residual.6.weightz(encoder.mid_block.resnets.0.conv2.weightz!encoder.middle.2.residual.0.gammaz'encoder.mid_block.resnets.1.norm1.gammaz encoder.middle.2.residual.2.biasz&encoder.mid_block.resnets.1.conv1.biasz"encoder.middle.2.residual.2.weightz(encoder.mid_block.resnets.1.conv1.weightz!encoder.middle.2.residual.3.gammaz'encoder.mid_block.resnets.1.norm2.gammaz encoder.middle.2.residual.6.biasz&encoder.mid_block.resnets.1.conv2.biasz"encoder.middle.2.residual.6.weightz(encoder.mid_block.resnets.1.conv2.weightrE   z'decoder.mid_block.resnets.0.norm1.gammaz decoder.middle.0.residual.2.biasz&decoder.mid_block.resnets.0.conv1.biasz"decoder.middle.0.residual.2.weightz(decoder.mid_block.resnets.0.conv1.weightz!decoder.middle.0.residual.3.gammaz'decoder.mid_block.resnets.0.norm2.gammaz decoder.middle.0.residual.6.biasz&decoder.mid_block.resnets.0.conv2.biasz(decoder.mid_block.resnets.0.conv2.weightz'decoder.mid_block.resnets.1.norm1.gammaz&decoder.mid_block.resnets.1.conv1.biasz(decoder.mid_block.resnets.1.conv1.weightz'decoder.mid_block.resnets.1.norm2.gammaz&decoder.mid_block.resnets.1.conv2.biasz(decoder.mid_block.resnets.1.conv2.weight)z"decoder.middle.0.residual.6.weightz!decoder.middle.2.residual.0.gammaz decoder.middle.2.residual.2.biasz"decoder.middle.2.residual.2.weightz!decoder.middle.2.residual.3.gammaz decoder.middle.2.residual.6.biasz"decoder.middle.2.residual.6.weightz)encoder.mid_block.attentions.0.norm.gammaz,encoder.mid_block.attentions.0.to_qkv.weightz*encoder.mid_block.attentions.0.to_qkv.biasz*encoder.mid_block.attentions.0.proj.weightz(encoder.mid_block.attentions.0.proj.biasz)decoder.mid_block.attentions.0.norm.gammaz,decoder.mid_block.attentions.0.to_qkv.weightz*decoder.mid_block.attentions.0.to_qkv.biasz*decoder.mid_block.attentions.0.proj.weightz(decoder.mid_block.attentions.0.proj.bias)
zencoder.middle.1.norm.gammazencoder.middle.1.to_qkv.weightzencoder.middle.1.to_qkv.biaszencoder.middle.1.proj.weightzencoder.middle.1.proj.biaszdecoder.middle.1.norm.gammazdecoder.middle.1.to_qkv.weightzdecoder.middle.1.to_qkv.biaszdecoder.middle.1.proj.weightzdecoder.middle.1.proj.biaszencoder.norm_out.gammar   r   zdecoder.norm_out.gammar   r   )zencoder.head.0.gammazencoder.head.2.biaszencoder.head.2.weightzdecoder.head.0.gammazdecoder.head.2.biaszdecoder.head.2.weightr   r   r   r   )zconv1.weightz
conv1.biaszconv2.weightz
conv2.biaszencoder.conv1.weightr   zencoder.conv1.biasr   zdecoder.conv1.weightr   zdecoder.conv1.biasr   zencoder.downsamples.r9  z.residual.0.gammaz.norm1.gammaz.residual.2.biasz.conv1.biasz.residual.2.weightrd  z.residual.3.gammaz.norm2.gammaz.residual.6.biasz.conv2.biasz.residual.6.weightre  z.shortcut.biasz.conv_shortcut.biasz.shortcut.weightz.conv_shortcut.weightzdecoder.upsamples.rL  r   residual)r   r   r   r   )r        r   r  )r6  r!  
   r6  )         r   r  r<  r  z
.shortcut.z.resnets.0.conv_shortcut.zdecoder.upsamples.4zdecoder.up_blocks.1z.conv_shortcut.z
.resample.z.time_conv.z decoder.up_blocks.0.upsamplers.0rb  z decoder.up_blocks.1.upsamplers.0   z decoder.up_blocks.2.upsamplers.0)r  r  r   rO  r  )r  r   r  middle_key_mappingattention_mappinghead_mappingquant_mappingr%  valuer]  parts	block_idxnew_block_idx
resnet_idxr   r   r   convert_wan_vae_to_diffusers  s>  	
 





































r  c                 K   s8   t |  }|D ]}d|v r| || |dd< q| S )Nr   r   )rN  r   r  r   )r  r   r   r   r   r   r   (convert_hidream_transformer_to_diffusers  s   r  c           "      K   s  i }t |  }|D ]}d|v r| || |dd< q
t tdd | D d d }t tdd | D d d }t tdd | D d d }d	}d
}	dd }
| d|d< | d|d< | d|d< | d|d< t|D ]O}d| d}| d| d|| d< | d| d|| d< | d| d|| d< | d| d|| d< | d| d|d| d< qq| d|d< | d |d!< | d"|d#< | d$|d%< t|D ]}d&| d}tj| d'| d(d)d*d+\}}}tj| d'| d,d)d*d+\}}}tj| d'| d-d)d*d+\}}}tj| d'| d.d)d*d+\}}}t|g|| d/< t|g|| d0< t|g|| d1< t|g|| d2< t|g|| d3< t|g|| d4< t|g|| d5< t|g|| d6< t|g|| d7< t|g|| d8< t|g|| d9< t|g|| d:< | d'| d;|| d<< | d'| d=|| d>< | d'| d?|| d@< | d'| dA|| dB< | d'| dC|| dD< | d'| dE|| dF< | d'| dG|| dH< | d'| dI|| dJ< | d'| dK|| dL< | d'| dM|| dN< | d'| dO|| dP< | d'| dQ|| dR< | d'| dS|| dT< | d'| dU|| dV< | d'| dW|| dX< | d'| dY|| dZ< qt|D ]}d[| d}t	|	| }|	|	|	|f}tj
| d\| d]|d*d+\}}}}tj
| d\| d^|d*d+\}}} }!t|g|| d/< t|g|| d0< t|g|| d1< t|g|| d2< t|g|| d3< t| g|| d4< t|g|| d_< t|!g|| d`< | d\| da|| d<< | d\| db|| d>< | d\| dc|| dd< | d\| de|| df< q| dg|dd< | dh|df< |S )iNr   r   c                 s   r  r  r  r   r   r   r   r     r  zEconvert_chroma_transformer_checkpoint_to_diffusers.<locals>.<genexpr>r"  r   c                 s   r  r  r  r   r   r   r   r     r  c                 s   r  ) distilled_guidance_layer.layers.rL  r   r   Nr  r   r   r   r   r     r  r  r  c                 S   r  r  r  r  r   r   r   r    r  zLconvert_chroma_transformer_checkpoint_to_diffusers.<locals>.swap_scale_shiftz%distilled_guidance_layer.in_proj.biasz'distilled_guidance_layer.in_proj.weightz&distilled_guidance_layer.out_proj.biasz(distilled_guidance_layer.out_proj.weightr  rL  z.in_layer.biaszlinear_1.biasz.in_layer.weightzlinear_1.weightz.out_layer.biaszlinear_2.biasz.out_layer.weightzlinear_2.weightzdistilled_guidance_layer.norms..scaler&  r  r  r  r  r-  r   r!  r"  r  r  r'  r   r   r  r(  r)  r*  r+  r,  r-  r.  r/  r0  r1  r2  r3  r4  r5  r6  r7  r8  r9  r:  r;  r<  r=  r>  r?  r@  rA  rB  rC  rD  rE  rF  rG  rH  rI  rJ  rK  rL  rM  rN  rO  rP  rQ  rR  rS  rT  rU  rV  rW  r  rX  rY  rZ  r[  r\  r]  r^  r  r_  r  r  r  )rN  r   r  r   r  r  r;  r  r  r  rO  )"r  r   r  r   r   r  r`  num_guidance_layersra  rb  r  r  rd  r  r  r  r  r  r  r  r  r  r  r  r  re  rf  rg  r  rh  ri  rj  rk  rl  r   r   r   2convert_chroma_transformer_checkpoint_to_diffusers  s  




$












&

 r  c              	      s   fddt   D }dtfdd}dtfdd}i dd	d
dddddddddddddddddddddd d!d"d#d$d%d&d'd(d)d*d+d,d-d.d/d0d1}|||||d2}i dd	d3dd4d5d6d7d8d9d:d;d<d=d>d?d@dAdBdCdDdEdFddGd#dHd'dIdJdKd!dLd%dMdNdOd.d/d0dP}||||||||dQ}dR}	dS v r|}
|}n|}
|}t | }|D ](}|d d  }||	r||	}|
 D ]
\}}|||}q||||< qt | }|D ]}| D ]\}}||vrq||| qq|S )TNc                    r  r   rn  r$  r  r   r   r  J  rc  zFconvert_cosmos_transformer_checkpoint_to_diffusers.<locals>.<dictcomp>r%  c                 S   r~  r   rn  r  r   r   r   r  L  r  zHconvert_cosmos_transformer_checkpoint_to_diffusers.<locals>.remove_keys_c                 S   sP   t | dd d}| }d| }d| }||| }|| ||< d S )NrL  r   blockblocks.blockr  )r  rO  removeprefixr  )r%  r   block_indexr]  r  
new_prefixr   r   r   rename_transformer_blocks_O  s   

zVconvert_cosmos_transformer_checkpoint_to_diffusers.<locals>.rename_transformer_blocks_zt_embedder.1ztime_embed.t_embedderaffline_normztime_embed.normz.blocks.0.block.attnr  z.blocks.1.block.attnr  z.blocks.2.blockz.ffz.blocks.0.adaLN_modulation.1z.norm1.linear_1z.blocks.0.adaLN_modulation.2z.norm1.linear_2z.blocks.1.adaLN_modulation.1z.norm2.linear_1z.blocks.1.adaLN_modulation.2z.norm2.linear_2z.blocks.2.adaLN_modulation.1z.norm3.linear_1z.blocks.2.adaLN_modulation.2z.norm3.linear_2zto_q.0r  zto_q.1rq  zto_k.0r  zto_k.1rr  zto_v.0r  layer1r
  r  r  patch_embedlearnable_pos_embedrP  rQ  r  )layer2zproj.1r  extra_pos_embedderr  zfinal_layer.adaLN_modulation.2r  )r  zlogvar.0.freqszlogvar.0.phaseszlogvar.1.weightpos_embedder.seqt_embedding_normr\  r  zadaln_modulation_self_attn.1znorm1.linear_1zadaln_modulation_self_attn.2znorm1.linear_2zadaln_modulation_cross_attn.1znorm2.linear_1zadaln_modulation_cross_attn.2znorm2.linear_2zadaln_modulation_mlp.1znorm3.linear_1zadaln_modulation_mlp.2znorm3.linear_2rw  rx  ru  rv  q_projk_projv_projoutput_projr,  ru  rv  zff.net.0.projzff.net.2zpatch_embed.proj)z
mlp.layer1z
mlp.layer2zx_embedder.proj.1zfinal_layer.adaln_modulation.1zfinal_layer.adaln_modulation.2r  )accum_video_sample_counteraccum_image_sample_counteraccum_iterationaccum_train_in_hoursr  zpos_embedder.dim_spatial_rangezpos_embedder.dim_temporal_range_extra_stateznet.rH   )rN  r   r  r  r  r  r   r  )r  r   r  r  r  'TRANSFORMER_KEYS_RENAME_DICT_COSMOS_1_0)TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_1_0'TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0)TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0
PREFIX_KEYrw  rx  r:  r%  r]  ry  rz  r{  r|  r   r  r   2convert_cosmos_transformer_checkpoint_to_diffusersI  s   	
	



r  c                    s|  ddddddddd	d
d
}ddi ddddddddddd
ddddddt dtt tf dd ffdd }dt dtt tf dd f fd!d"}dt dtt tf dd ffd#d$}dtt tf d%t d&t dd fd'd(}|||d)}fd*d+t D }t| D ]}	|	d d  }
| D ]
\}}|
||}
q|||	|
 qt| D ]}	| D ]\}}||	vrq||	| qq|S ),Nr  r  z.time_guidance_embed.timestep_embedder.linear_1z.time_guidance_embed.timestep_embedder.linear_2z.time_guidance_embed.guidance_embedder.linear_1z.time_guidance_embed.guidance_embedder.linear_2z#double_stream_modulation_img.linearz#double_stream_modulation_txt.linearzsingle_stream_modulation.linearr  )
r  r  ztime_in.in_layerztime_in.out_layerzguidance_in.in_layerzguidance_in.out_layerz double_stream_modulation_img.linz double_stream_modulation_txt.linzsingle_stream_modulation.linr  r  r  r  r  r  zff.linear_inzff.linear_outr  r  r  zff_context.linear_inzff_context.linear_out)
zimg_attn.norm.query_normzimg_attn.norm.key_normzimg_attn.projz	img_mlp.0z	img_mlp.2ztxt_attn.norm.query_normztxt_attn.norm.key_normztxt_attn.projz	txt_mlp.0z	txt_mlp.2zattn.to_qkv_mlp_projzattn.to_out)linear1r  znorm.query_normznorm.key_normr%  r   returnc           
         s   d| vrd| vrd| vrd S d}d| v rF|  d}|d }d|dd	 }|d	 }|d
kr0d} | }d||||g}|| }	|	||< d S )Nr&  r'  r  r  r  rL  r   r   r"  r  rm  )rO  r   r  )
r%  r   r  r  r  within_block_name
param_typenew_within_block_namer]  r  )&FLUX2_TRANSFORMER_SINGLE_BLOCK_KEY_MAPr   r   "convert_flux2_single_stream_blocks  s   

z]convert_flux2_transformer_checkpoint_to_diffusers.<locals>.convert_flux2_single_stream_blocksc                    sX   d| vrd S d| v r*| j ddd\}} | }d||g}t|| d}|||< d S )Nr&  adaLN_modulationrL  r   )maxsplitr   )rsplitr   r  r  )r%  r   key_without_param_typer  new_key_without_param_typer]  swapped_weight)(FLUX2_TRANSFORMER_ADA_LAYER_NORM_KEY_MAPr   r   convert_ada_layer_norm_weights  s   zYconvert_flux2_transformer_checkpoint_to_diffusers.<locals>.convert_ada_layer_norm_weightsc                    sh  d| vrd| vrd| vrd S d}d| v r|  d}|d }|d }d|dd	 }|d	 }|d
kr4d}d|v r|| }tj|ddd\}	}
}d|v r^tj|ddd\}	}
}d}d}d}nd|v rstj|ddd\}	}
}d}d}d}d||||g}d||||g}d||||g}|	||< |
||< |||< d S  | }d||||g}|| }|||< d S )Nr&  r'  r  r  r  rL  r   r   r"  r  rm  r  r   r   r  imgr  r  r  txtr  r  r  )rO  r   r  r;  r  )r%  r   r  r  r  modality_block_namer  r  fused_qkv_weightto_q_weightto_k_weightto_v_weight
new_q_name
new_k_name
new_v_name	new_q_key	new_k_key	new_v_keyr  r]  r  )&FLUX2_TRANSFORMER_DOUBLE_BLOCK_KEY_MAPr   r   "convert_flux2_double_stream_blocks  sH   


z]convert_flux2_transformer_checkpoint_to_diffusers.<locals>.convert_flux2_double_stream_blocksr^  r]  c                 S   r  r   rn  r  r   r   r   update_state_dictA  r  zLconvert_flux2_transformer_checkpoint_to_diffusers.<locals>.update_state_dict)r  r  r  c                    r  r   rn  r$  r  r   r   r  J  rc  zEconvert_flux2_transformer_checkpoint_to_diffusers.<locals>.<dictcomp>)r  dictobjectrN  r   r  r   )r  r   "FLUX2_TRANSFORMER_KEYS_RENAME_DICTr  r  r  r  rx  r  r%  r]  ry  rz  r{  r|  r   )r  r  r  r  r   1convert_flux2_transformer_checkpoint_to_diffusers  sd   """"-r  c                    s  dddddddd}d	t d
tt tf dd fdd}d|i}d
tt tf dt dt dd fdd} fddt  D }t| D ]}|d d  }| D ]
\}	}
||	|
}qM|||| qAd| v rj|d}t| D ]}| D ]\}}||vrqv||| qvqp|S )Nzall_final_layer.2-1.zall_x_embedder.2-1.z.attention.to_out.0.biasz.attention.norm_k.weightz.attention.norm_q.weightz.attention.to_out.0.weightr   )zfinal_layer.zx_embedder.z.attention.out.biasz.attention.k_norm.weightz.attention.q_norm.weightz.attention.out.weightr   r%  r   r  c           	      S   sl   d| vrd S | | }tj|ddd\}}}| dd}| dd}| dd}|||< |||< |||< d S )N.attention.qkv.weightr   r   r  z.attention.to_q.weightz.attention.to_k.weightz.attention.to_v.weight)r  r;  r  r   )	r%  r   r  r  r  r  r  r  r  r   r   r   convert_z_image_fused_attentionj  s   
z\convert_z_image_transformer_checkpoint_to_diffusers.<locals>.convert_z_image_fused_attentionr  r^  r]  c                 S   r  r   rn  r  r   r   r   r  }  r  zNconvert_z_image_transformer_checkpoint_to_diffusers.<locals>.update_state_dictc                    r  r   rn  r$  r  r   r   r    rc  zGconvert_z_image_transformer_checkpoint_to_diffusers.<locals>.<dictcomp>rF  )r  r  r  rN  r   r  r   r  )r  r   Z_IMAGE_KEYS_RENAME_DICTr  rx  r  r  r%  r]  ry  rz  r  r{  r|  r   r  r   3convert_z_image_transformer_checkpoint_to_diffusers_  s6   
"
r  c                    sR   |d d u r S |d dkr S |d dkr% fddt   D }|S td)Nadd_control_noise_refinercontrol_noise_refinercontrol_layersc                    s"   i | ]}| d s| |qS )zcontrol_noise_refiner.)r  r  r$  r  r   r   r    s
    

zFconvert_z_image_controlnet_checkpoint_to_diffusers.<locals>.<dictcomp>z&Unknown Z-Image Turbo ControlNet type.)rN  r   r   )r  r  r   r  r   r  r   2convert_z_image_controlnet_checkpoint_to_diffusers  s   

r  c                    s   ddddddddd	d
dd}dt dt dd fdd}dt dd fdd}dt dd fdd}|||d} fddt  D }t| D ]}|d d  }	| D ]
\}
}|	|
|}	qP||||	 qDt| D ]}| D ]\}}||vrwqn||| qnqh|S )Nr   ro  audio_proj_inav_cross_attn_video_scale_shiftav_cross_attn_video_a2v_gateav_cross_attn_audio_scale_shiftav_cross_attn_audio_v2a_gate&video_a2v_cross_attn_scale_shift_table&audio_a2v_cross_attn_scale_shift_tablerq  rr  )r   rs  audio_patchify_proj$av_ca_video_scale_shift_adaln_singleav_ca_a2v_gate_adaln_single$av_ca_audio_scale_shift_adaln_singleav_ca_v2a_gate_adaln_singlescale_shift_table_a2v_ca_videoscale_shift_table_a2v_ca_audioru  rv  r^  r]  r  c                 S   r  r   rn  r  r   r   r   update_state_dict_inplace  r  zHconvert_ltx2_transformer_to_diffusers.<locals>.update_state_dict_inplacer%  c                 S   r~  r   rn  r  r   r   r   remove_keys_inplace  r  zBconvert_ltx2_transformer_to_diffusers.<locals>.remove_keys_inplacec                 S   sh   d| vr
d| vr
d S |  dr| dd}|| }|||< |  dr2| dd}|| }|||< d S )Nr&  r'  zadaln_single.ztime_embed.zaudio_adaln_single.zaudio_time_embed.)r  r   r  )r%  r   r]  r  r   r   r   %convert_ltx2_transformer_adaln_single  s   



zTconvert_ltx2_transformer_to_diffusers.<locals>.convert_ltx2_transformer_adaln_single)video_embeddings_connectoraudio_embeddings_connectorrt  c                    r  r   rn  r$  r  r   r   r    rc  z9convert_ltx2_transformer_to_diffusers.<locals>.<dictcomp>r  rN  r   r  r   )r  r   $LTX_2_0_TRANSFORMER_KEYS_RENAME_DICTr  r  r  &LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAPr  r%  r]  ry  rz  r{  r|  r   r  r   %convert_ltx2_transformer_to_diffusers  s@   r  c                    sF  i ddddddddddd	dd
ddddddddddddddddddddddddd}dt d t d!d fd"d#}d$t d!d fd%d&}||d'} fd(d)t  D }t| D ]}|d d  }| D ]
\}	}
||	|
}qs|||| qgt| D ]}| D ]\}}||vrq||| qq|S )*Nr   r   r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r   r  r  )r  r  r  r^  r]  r  c                 S   r  r   rn  r  r   r   r   r    r  z@convert_ltx2_vae_to_diffusers.<locals>.update_state_dict_inplacer%  c                 S   r~  r   rn  r  r   r   r   r    r  z:convert_ltx2_vae_to_diffusers.<locals>.remove_keys_inplace)r  r  c                    r  r   rn  r$  r  r   r   r    rc  z1convert_ltx2_vae_to_diffusers.<locals>.<dictcomp>r  )r  r   LTX_2_0_VIDEO_VAE_RENAME_DICTr  r  LTX_2_0_VAE_SPECIAL_KEYS_REMAPr  r%  r]  ry  rz  r{  r|  r   r  r   convert_ltx2_vae_to_diffusers  sr   	
r  c           	         s   dddd}dt dt dd fdd	} fd
dt  D }t| D ]}|d d  }| D ]
\}}|||}q0|||| q$|S )Nr   r  r  )z
audio_vae.r  r  r^  r]  r  c                 S   r  r   rn  r  r   r   r   r  4  r  zFconvert_ltx2_audio_vae_to_diffusers.<locals>.update_state_dict_inplacec                    r  r   rn  r$  r  r   r   r  7  rc  z7convert_ltx2_audio_vae_to_diffusers.<locals>.<dictcomp>r  )	r  r   LTX_2_0_AUDIO_VAE_RENAME_DICTr  r  r%  r]  ry  rz  r   r  r   #convert_ltx2_audio_vae_to_diffusers,  s   r  )FNNNNNFN)Fr   )NNN)NN)rN  )r   NNNFr   )r   NNN)__doc__rE  r   r   
contextlibr   ior   urllib.parser   r  r;  r
  models.modeling_utilsr   
schedulersr   r   r	   r
   r   r   r   r   utilsr   r   r   r   r   r   utils.constantsr   utils.hub_utilsr   utils.torch_utilsr   transformersr   
accelerater   models.model_loading_utilsr   
get_loggerr   r  r  rD  rI  r  rR  r  r=  r  r  r  r)  rI  rQ  r   r   	Exceptionr   r   r   r   r   r   r  r  r  r  r  r  r  r  r  r   rC  rH  rK  rR  r  r  r  r  r  r  r  r  r"  r1  rG  rM  r`  ry  r  r  r  r  r  r  r  r  r  r  r  r  rm  r}  r  r  r  r  rE  r[  ra  r  r  r  r  r  r  r  r  r  r  r  r   r   r   r   <module>   s  ( 

!"#$%&,089:;<a	

!$%&'()*+,-./0123456789:;<=>?@ABCU	

]
( n  38 (hEj  C) FkHr \Hix G	 *q &9I=