2025-03-05 00:13:49 -05:00
import json
2026-03-09 20:50:10 -07:00
import comfy . memory_management
2023-09-23 18:47:46 -04:00
import comfy . supported_models
import comfy . supported_models_base
2024-06-19 21:46:37 -04:00
import comfy . utils
2024-06-10 13:26:25 -04:00
import math
2024-03-10 11:37:08 -04:00
import logging
2024-06-19 21:46:37 -04:00
import torch
2023-06-22 13:03:50 -04:00
def count_blocks(state_dict_keys, prefix_string):
    """Count consecutively numbered blocks present among state-dict keys.

    ``prefix_string`` is a ``str.format`` template with a placeholder for the
    block index (for example ``"model.blocks.{}."``). Starting from index 0,
    the template is rendered with increasing indices until no key starts with
    the rendered prefix.

    Args:
        state_dict_keys: Collection of key strings. It is scanned once per
            candidate index, so it must be re-iterable (a list, not a
            one-shot generator).
        prefix_string: Format template that receives the block index.

    Returns:
        The number of consecutive indices, starting at 0, for which at least
        one key starts with the rendered prefix. Gaps end the count: keys for
        indices 0, 1 and 3 yield 2.

    Note:
        If ``prefix_string`` renders to the same text for every index (no
        placeholder) and that text matches a key, the loop never terminates.
        Include a trailing delimiter in the template (``"blocks.{}."``) when
        ``"blocks.1"`` must not also match ``"blocks.10"``.
    """
    count = 0
    while True:
        # Render the prefix once per index instead of once per key.
        block_prefix = prefix_string.format(count)
        if not any(k.startswith(block_prefix) for k in state_dict_keys):
            return count
        count += 1
2026-02-13 12:35:13 -08:00
def any_suffix_in(keys, prefix, main, suffix_list=()):
    """Tell whether ``prefix + main + suffix`` is in ``keys`` for any suffix.

    Args:
        keys: Container of key strings supporting the ``in`` operator.
        prefix: Leading part of the candidate key.
        main: Middle part of the candidate key.
        suffix_list: Iterable of candidate suffixes, tried in order. Defaults
            to an empty (immutable) tuple, for which the result is ``False``.

    Returns:
        ``True`` as soon as one concatenated candidate is found in ``keys``,
        otherwise ``False``.
    """
    return any(
        "{}{}{}".format(prefix, main, suffix) in keys
        for suffix in suffix_list
    )
2023-10-27 14:15:45 -04:00
def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
    """Inspect the ``1.transformer_blocks.`` sub-module found under ``prefix``.

    Args:
        prefix: Key prefix of the block being inspected; the transformer is
            looked up at ``prefix + "1.transformer_blocks."``.
        state_dict_keys: Collection of key strings from ``state_dict``.
        state_dict: Mapping from key to an object exposing ``.shape``.

    Returns:
        ``None`` when no key starts with the transformer prefix. Otherwise a
        5-tuple ``(last_transformer_depth, context_dim,
        use_linear_in_transformer, time_stack, time_stack_cross)`` where:

        * ``last_transformer_depth`` is the number of consecutively numbered
          transformer blocks, as returned by ``count_blocks``;
        * ``context_dim`` is ``shape[1]`` of block 0's ``attn2.to_k.weight``;
        * ``use_linear_in_transformer`` is ``True`` when ``1.proj_in.weight``
          has exactly two dimensions;
        * ``time_stack`` / ``time_stack_cross`` report whether an ``attn1`` /
          ``attn2`` ``to_q.weight`` key exists under ``1.time_stack.0.`` or
          ``1.time_mix_blocks.0.``.

    Raises:
        KeyError: If transformer keys exist but ``0.attn2.to_k.weight`` or
            ``1.proj_in.weight`` is missing from ``state_dict``.
    """
    transformer_prefix = prefix + "1.transformer_blocks."

    # Only existence matters here, so stop at the first matching key rather
    # than filtering and sorting every key.
    if not any(k.startswith(transformer_prefix) for k in state_dict_keys):
        return None

    last_transformer_depth = count_blocks(state_dict_keys, transformer_prefix + '{}')
    context_dim = state_dict['{}0.attn2.to_k.weight'.format(transformer_prefix)].shape[1]
    use_linear_in_transformer = len(state_dict['{}1.proj_in.weight'.format(prefix)].shape) == 2

    time_stack = (
        '{}1.time_stack.0.attn1.to_q.weight'.format(prefix) in state_dict
        or '{}1.time_mix_blocks.0.attn1.to_q.weight'.format(prefix) in state_dict
    )
    time_stack_cross = (
        '{}1.time_stack.0.attn2.to_q.weight'.format(prefix) in state_dict
        or '{}1.time_mix_blocks.0.attn2.to_q.weight'.format(prefix) in state_dict
    )
    return last_transformer_depth, context_dim, use_linear_in_transformer, time_stack, time_stack_cross
2025-03-05 00:13:49 -05:00
def detect_unet_config ( state_dict , key_prefix , metadata = None ) :
2023-06-22 13:03:50 -04:00
state_dict_keys = list ( state_dict . keys ( ) )
2024-06-10 13:26:25 -04:00
if ' {} joint_blocks.0.context_block.attn.qkv.weight ' . format ( key_prefix ) in state_dict_keys : #mmdit model
unet_config = { }
unet_config [ " in_channels " ] = state_dict [ ' {} x_embedder.proj.weight ' . format ( key_prefix ) ] . shape [ 1 ]
patch_size = state_dict [ ' {} x_embedder.proj.weight ' . format ( key_prefix ) ] . shape [ 2 ]
unet_config [ " patch_size " ] = patch_size
2024-06-25 23:40:44 -04:00
final_layer = ' {} final_layer.linear.weight ' . format ( key_prefix )
if final_layer in state_dict :
unet_config [ " out_channels " ] = state_dict [ final_layer ] . shape [ 0 ] / / ( patch_size * patch_size )
2024-06-10 13:26:25 -04:00
unet_config [ " depth " ] = state_dict [ ' {} x_embedder.proj.weight ' . format ( key_prefix ) ] . shape [ 0 ] / / 64
unet_config [ " input_size " ] = None
y_key = ' {} y_embedder.mlp.0.weight ' . format ( key_prefix )
if y_key in state_dict_keys :
unet_config [ " adm_in_channels " ] = state_dict [ y_key ] . shape [ 1 ]
context_key = ' {} context_embedder.weight ' . format ( key_prefix )
if context_key in state_dict_keys :
in_features = state_dict [ context_key ] . shape [ 1 ]
out_features = state_dict [ context_key ] . shape [ 0 ]
unet_config [ " context_embedder_config " ] = { " target " : " torch.nn.Linear " , " params " : { " in_features " : in_features , " out_features " : out_features } }
num_patches_key = ' {} pos_embed ' . format ( key_prefix )
if num_patches_key in state_dict_keys :
num_patches = state_dict [ num_patches_key ] . shape [ 1 ]
unet_config [ " num_patches " ] = num_patches
unet_config [ " pos_embed_max_size " ] = round ( math . sqrt ( num_patches ) )
rms_qk = ' {} joint_blocks.0.context_block.attn.ln_q.weight ' . format ( key_prefix )
if rms_qk in state_dict_keys :
unet_config [ " qk_norm " ] = " rms "
unet_config [ " pos_embed_scaling_factor " ] = None #unused for inference
context_processor = ' {} context_processor.layers.0.attn.qkv.weight ' . format ( key_prefix )
if context_processor in state_dict_keys :
unet_config [ " context_processor_layers " ] = count_blocks ( state_dict_keys , ' {} context_processor.layers. ' . format ( key_prefix ) + ' {} . ' )
2024-10-28 21:58:52 -04:00
unet_config [ " x_block_self_attn_layers " ] = [ ]
for key in state_dict_keys :
if key . startswith ( ' {} joint_blocks. ' . format ( key_prefix ) ) and key . endswith ( ' .x_block.attn2.qkv.weight ' ) :
layer = key [ len ( ' {} joint_blocks. ' . format ( key_prefix ) ) : - len ( ' .x_block.attn2.qkv.weight ' ) ]
unet_config [ " x_block_self_attn_layers " ] . append ( int ( layer ) )
2024-06-10 13:26:25 -04:00
return unet_config
2024-02-16 10:55:08 -05:00
if ' {} clf.1.weight ' . format ( key_prefix ) in state_dict_keys : #stable cascade
unet_config = { }
text_mapper_name = ' {} clip_txt_mapper.weight ' . format ( key_prefix )
if text_mapper_name in state_dict_keys :
unet_config [ ' stable_cascade_stage ' ] = ' c '
w = state_dict [ text_mapper_name ]
if w . shape [ 0 ] == 1536 : #stage c lite
unet_config [ ' c_cond ' ] = 1536
unet_config [ ' c_hidden ' ] = [ 1536 , 1536 ]
unet_config [ ' nhead ' ] = [ 24 , 24 ]
unet_config [ ' blocks ' ] = [ [ 4 , 12 ] , [ 12 , 4 ] ]
elif w . shape [ 0 ] == 2048 : #stage c full
unet_config [ ' c_cond ' ] = 2048
elif ' {} clip_mapper.weight ' . format ( key_prefix ) in state_dict_keys :
unet_config [ ' stable_cascade_stage ' ] = ' b '
2024-02-16 23:41:23 -05:00
w = state_dict [ ' {} down_blocks.1.0.channelwise.0.weight ' . format ( key_prefix ) ]
if w . shape [ - 1 ] == 640 :
unet_config [ ' c_hidden ' ] = [ 320 , 640 , 1280 , 1280 ]
unet_config [ ' nhead ' ] = [ - 1 , - 1 , 20 , 20 ]
unet_config [ ' blocks ' ] = [ [ 2 , 6 , 28 , 6 ] , [ 6 , 28 , 6 , 2 ] ]
unet_config [ ' block_repeat ' ] = [ [ 1 , 1 , 1 , 1 ] , [ 3 , 3 , 2 , 2 ] ]
elif w . shape [ - 1 ] == 576 : #stage b lite
unet_config [ ' c_hidden ' ] = [ 320 , 576 , 1152 , 1152 ]
unet_config [ ' nhead ' ] = [ - 1 , 9 , 18 , 18 ]
unet_config [ ' blocks ' ] = [ [ 2 , 4 , 14 , 4 ] , [ 4 , 14 , 4 , 2 ] ]
unet_config [ ' block_repeat ' ] = [ [ 1 , 1 , 1 , 1 ] , [ 2 , 2 , 2 , 2 ] ]
2024-02-16 10:55:08 -05:00
return unet_config
2024-06-15 12:14:56 -04:00
if ' {} transformer.rotary_pos_emb.inv_freq ' . format ( key_prefix ) in state_dict_keys : #stable audio dit
unet_config = { }
unet_config [ " audio_model " ] = " dit1.0 "
return unet_config
2024-07-11 16:51:06 -04:00
if ' {} double_layers.0.attn.w1q.weight ' . format ( key_prefix ) in state_dict_keys : #aura flow dit
unet_config = { }
unet_config [ " max_seq " ] = state_dict [ ' {} positional_encoding ' . format ( key_prefix ) ] . shape [ 1 ]
unet_config [ " cond_seq_dim " ] = state_dict [ ' {} cond_seq_linear.weight ' . format ( key_prefix ) ] . shape [ 1 ]
2024-07-13 13:51:40 -04:00
double_layers = count_blocks ( state_dict_keys , ' {} double_layers. ' . format ( key_prefix ) + ' {} . ' )
single_layers = count_blocks ( state_dict_keys , ' {} single_layers. ' . format ( key_prefix ) + ' {} . ' )
unet_config [ " n_double_layers " ] = double_layers
unet_config [ " n_layers " ] = double_layers + single_layers
2024-07-11 16:51:06 -04:00
return unet_config
2024-07-25 18:21:08 -04:00
if ' {} mlp_t5.0.weight ' . format ( key_prefix ) in state_dict_keys : #Hunyuan DiT
unet_config = { }
unet_config [ " image_model " ] = " hydit "
unet_config [ " depth " ] = count_blocks ( state_dict_keys , ' {} blocks. ' . format ( key_prefix ) + ' {} . ' )
unet_config [ " hidden_size " ] = state_dict [ ' {} x_embedder.proj.weight ' . format ( key_prefix ) ] . shape [ 0 ]
if unet_config [ " hidden_size " ] == 1408 and unet_config [ " depth " ] == 40 : #DiT-g/2
unet_config [ " mlp_ratio " ] = 4.3637
if state_dict [ ' {} extra_embedder.0.weight ' . format ( key_prefix ) ] . shape [ 1 ] == 3968 :
unet_config [ " size_cond " ] = True
unet_config [ " use_style_cond " ] = True
unet_config [ " image_model " ] = " hydit1 "
return unet_config
2024-12-16 19:35:40 -05:00
if ' {} txt_in.individual_token_refiner.blocks.0.norm1.weight ' . format ( key_prefix ) in state_dict_keys : #Hunyuan Video
dit_config = { }
2025-09-09 23:05:07 -07:00
in_w = state_dict [ ' {} img_in.proj.weight ' . format ( key_prefix ) ]
out_w = state_dict [ ' {} final_layer.linear.weight ' . format ( key_prefix ) ]
2024-12-16 19:35:40 -05:00
dit_config [ " image_model " ] = " hunyuan_video "
2025-09-09 23:05:07 -07:00
dit_config [ " in_channels " ] = in_w . shape [ 1 ] #SkyReels img2video has 32 input channels
dit_config [ " patch_size " ] = list ( in_w . shape [ 2 : ] )
dit_config [ " out_channels " ] = out_w . shape [ 0 ] / / math . prod ( dit_config [ " patch_size " ] )
2025-09-10 20:17:34 -07:00
if any ( s . startswith ( ' {} vector_in. ' . format ( key_prefix ) ) for s in state_dict_keys ) :
2025-09-09 23:05:07 -07:00
dit_config [ " vec_in_dim " ] = 768
else :
dit_config [ " vec_in_dim " ] = None
2025-09-10 20:17:34 -07:00
if len ( dit_config [ " patch_size " ] ) == 2 :
2025-09-09 23:05:07 -07:00
dit_config [ " axes_dim " ] = [ 64 , 64 ]
2025-09-10 20:17:34 -07:00
else :
dit_config [ " axes_dim " ] = [ 16 , 56 , 56 ]
if any ( s . startswith ( ' {} time_r_in. ' . format ( key_prefix ) ) for s in state_dict_keys ) :
dit_config [ " meanflow " ] = True
else :
dit_config [ " meanflow " ] = False
2025-09-09 23:05:07 -07:00
dit_config [ " context_in_dim " ] = state_dict [ ' {} txt_in.input_embedder.weight ' . format ( key_prefix ) ] . shape [ 1 ]
dit_config [ " hidden_size " ] = in_w . shape [ 0 ]
2024-12-16 19:35:40 -05:00
dit_config [ " mlp_ratio " ] = 4.0
2025-09-09 23:05:07 -07:00
dit_config [ " num_heads " ] = in_w . shape [ 0 ] / / 128
2024-12-16 19:35:40 -05:00
dit_config [ " depth " ] = count_blocks ( state_dict_keys , ' {} double_blocks. ' . format ( key_prefix ) + ' {} . ' )
dit_config [ " depth_single_blocks " ] = count_blocks ( state_dict_keys , ' {} single_blocks. ' . format ( key_prefix ) + ' {} . ' )
dit_config [ " theta " ] = 256
dit_config [ " qkv_bias " ] = True
2025-09-09 23:05:07 -07:00
if ' {} byt5_in.fc1.weight ' . format ( key_prefix ) in state_dict :
dit_config [ " byt5 " ] = True
else :
dit_config [ " byt5 " ] = False
2024-12-16 19:35:40 -05:00
guidance_keys = list ( filter ( lambda a : a . startswith ( " {} guidance_in. " . format ( key_prefix ) ) , state_dict_keys ) )
dit_config [ " guidance_embed " ] = len ( guidance_keys ) > 0
2025-11-20 19:44:43 -08:00
# HunyuanVideo 1.5
if ' {} cond_type_embedding.weight ' . format ( key_prefix ) in state_dict_keys :
dit_config [ " use_cond_type_embedding " ] = True
else :
dit_config [ " use_cond_type_embedding " ] = False
if ' {} vision_in.proj.0.weight ' . format ( key_prefix ) in state_dict_keys :
dit_config [ " vision_in_dim " ] = state_dict [ ' {} vision_in.proj.0.weight ' . format ( key_prefix ) ] . shape [ 0 ]
2025-12-09 23:59:16 +02:00
dit_config [ " meanflow_sum " ] = True
2025-11-20 19:44:43 -08:00
else :
dit_config [ " vision_in_dim " ] = None
2025-12-09 23:59:16 +02:00
dit_config [ " meanflow_sum " ] = False
2024-12-16 19:35:40 -05:00
return dit_config
2026-02-13 12:35:13 -08:00
if any_suffix_in ( state_dict_keys , key_prefix , ' double_blocks.0.img_attn.norm.key_norm. ' , [ " weight " , " scale " ] ) and ( ' {} img_in.weight ' . format ( key_prefix ) in state_dict_keys or any_suffix_in ( state_dict_keys , key_prefix , ' distilled_guidance_layer.norms.0. ' , [ " weight " , " scale " ] ) ) : #Flux, Chroma or Chroma Radiance (has no img_in.weight)
2024-08-01 04:03:59 -04:00
dit_config = { }
2025-11-25 07:50:19 -08:00
if ' {} double_stream_modulation_img.lin.weight ' . format ( key_prefix ) in state_dict_keys :
dit_config [ " image_model " ] = " flux2 "
dit_config [ " axes_dim " ] = [ 32 , 32 , 32 , 32 ]
dit_config [ " num_heads " ] = 48
dit_config [ " mlp_ratio " ] = 3.0
dit_config [ " theta " ] = 2000
dit_config [ " out_channels " ] = 128
dit_config [ " global_modulation " ] = True
dit_config [ " mlp_silu_act " ] = True
dit_config [ " qkv_bias " ] = False
dit_config [ " ops_bias " ] = False
dit_config [ " default_ref_method " ] = " index "
dit_config [ " ref_index_scale " ] = 10.0
2025-12-01 17:56:17 -08:00
dit_config [ " txt_ids_dims " ] = [ 3 ]
2025-11-25 07:50:19 -08:00
patch_size = 1
else :
dit_config [ " image_model " ] = " flux "
dit_config [ " axes_dim " ] = [ 16 , 56 , 56 ]
dit_config [ " num_heads " ] = 24
dit_config [ " mlp_ratio " ] = 4.0
dit_config [ " theta " ] = 10000
dit_config [ " out_channels " ] = 16
dit_config [ " qkv_bias " ] = True
2025-12-01 17:56:17 -08:00
dit_config [ " txt_ids_dims " ] = [ ]
2025-11-25 07:50:19 -08:00
patch_size = 2
2024-08-04 15:45:43 -04:00
dit_config [ " in_channels " ] = 16
2025-11-25 07:50:19 -08:00
dit_config [ " hidden_size " ] = 3072
dit_config [ " context_in_dim " ] = 4096
2024-11-21 08:38:23 -05:00
dit_config [ " patch_size " ] = patch_size
in_key = " {} img_in.weight " . format ( key_prefix )
if in_key in state_dict_keys :
2025-11-25 07:50:19 -08:00
w = state_dict [ in_key ]
dit_config [ " in_channels " ] = w . shape [ 1 ] / / ( patch_size * patch_size )
dit_config [ " hidden_size " ] = w . shape [ 0 ]
txt_in_key = " {} txt_in.weight " . format ( key_prefix )
if txt_in_key in state_dict_keys :
w = state_dict [ txt_in_key ]
dit_config [ " context_in_dim " ] = w . shape [ 1 ]
dit_config [ " hidden_size " ] = w . shape [ 0 ]
2025-04-30 20:57:30 -04:00
vec_in_key = ' {} vector_in.in_layer.weight ' . format ( key_prefix )
if vec_in_key in state_dict_keys :
dit_config [ " vec_in_dim " ] = state_dict [ vec_in_key ] . shape [ 1 ]
2025-12-01 17:56:17 -08:00
else :
dit_config [ " vec_in_dim " ] = None
2025-11-25 07:50:19 -08:00
2026-01-10 14:31:31 -08:00
dit_config [ " num_heads " ] = dit_config [ " hidden_size " ] / / sum ( dit_config [ " axes_dim " ] )
2024-08-07 12:59:28 -04:00
dit_config [ " depth " ] = count_blocks ( state_dict_keys , ' {} double_blocks. ' . format ( key_prefix ) + ' {} . ' )
dit_config [ " depth_single_blocks " ] = count_blocks ( state_dict_keys , ' {} single_blocks. ' . format ( key_prefix ) + ' {} . ' )
2026-02-13 12:35:13 -08:00
if any_suffix_in ( state_dict_keys , key_prefix , ' distilled_guidance_layer.0.norms.0. ' , [ " weight " , " scale " ] ) or any_suffix_in ( state_dict_keys , key_prefix , ' distilled_guidance_layer.norms.0. ' , [ " weight " , " scale " ] ) : #Chroma
2025-04-30 20:57:30 -04:00
dit_config [ " image_model " ] = " chroma "
dit_config [ " in_channels " ] = 64
dit_config [ " out_channels " ] = 64
dit_config [ " in_dim " ] = 64
dit_config [ " out_dim " ] = 3072
dit_config [ " hidden_dim " ] = 5120
dit_config [ " n_layers " ] = 5
2026-02-13 12:35:13 -08:00
if any_suffix_in ( state_dict_keys , key_prefix , ' nerf_blocks.0.norm. ' , [ " weight " , " scale " ] ) : #Chroma Radiance
2025-09-13 15:58:43 -06:00
dit_config [ " image_model " ] = " chroma_radiance "
dit_config [ " in_channels " ] = 3
dit_config [ " out_channels " ] = 3
2026-01-21 00:46:11 +01:00
dit_config [ " patch_size " ] = state_dict . get ( ' {} img_in_patch.weight ' . format ( key_prefix ) ) . size ( dim = - 1 )
2025-09-13 15:58:43 -06:00
dit_config [ " nerf_hidden_size " ] = 64
dit_config [ " nerf_mlp_ratio " ] = 4
dit_config [ " nerf_depth " ] = 4
dit_config [ " nerf_max_freqs " ] = 8
2025-10-18 20:19:52 -07:00
dit_config [ " nerf_tile_size " ] = 512
2026-02-13 12:35:13 -08:00
dit_config [ " nerf_final_head_type " ] = " conv " if any_suffix_in ( state_dict_keys , key_prefix , ' nerf_final_layer_conv.norm. ' , [ " weight " , " scale " ] ) else " linear "
2025-09-13 15:58:43 -06:00
dit_config [ " nerf_embedder_dtype " ] = torch . float32
2025-12-16 14:03:17 -08:00
if " {} __x0__ " . format ( key_prefix ) in state_dict_keys : # x0 pred
2025-12-11 14:15:00 -08:00
dit_config [ " use_x0 " ] = True
else :
dit_config [ " use_x0 " ] = False
2025-04-30 20:57:30 -04:00
else :
dit_config [ " guidance_embed " ] = " {} guidance_in.in_layer.weight " . format ( key_prefix ) in state_dict_keys
2025-12-01 17:56:17 -08:00
dit_config [ " yak_mlp " ] = ' {} double_blocks.0.img_mlp.gate_proj.weight ' . format ( key_prefix ) in state_dict_keys
2026-02-13 12:35:13 -08:00
dit_config [ " txt_norm " ] = any_suffix_in ( state_dict_keys , key_prefix , ' txt_norm. ' , [ " weight " , " scale " ] )
2025-12-01 17:56:17 -08:00
if dit_config [ " yak_mlp " ] and dit_config [ " txt_norm " ] : # Ovis model
dit_config [ " txt_ids_dims " ] = [ 1 , 2 ]
2026-02-28 05:04:34 +01:00
if dit_config . get ( " context_in_dim " ) == 3584 and dit_config [ " vec_in_dim " ] is None : # LongCat-Image
dit_config [ " txt_ids_dims " ] = [ 1 , 2 ]
2025-12-01 17:56:17 -08:00
2024-08-01 04:03:59 -04:00
return dit_config
2024-10-26 06:54:00 -04:00
if ' {} t5_yproj.weight ' . format ( key_prefix ) in state_dict_keys : #Genmo mochi preview
dit_config = { }
dit_config [ " image_model " ] = " mochi_preview "
dit_config [ " depth " ] = 48
dit_config [ " patch_size " ] = 2
dit_config [ " num_heads " ] = 24
dit_config [ " hidden_size_x " ] = 3072
dit_config [ " hidden_size_y " ] = 1536
dit_config [ " mlp_ratio_x " ] = 4.0
dit_config [ " mlp_ratio_y " ] = 4.0
dit_config [ " learn_sigma " ] = False
dit_config [ " in_channels " ] = 12
dit_config [ " qk_norm " ] = True
dit_config [ " qkv_bias " ] = False
dit_config [ " out_bias " ] = True
dit_config [ " attn_drop " ] = 0.0
dit_config [ " patch_embed_bias " ] = True
dit_config [ " posenc_preserve_area " ] = True
dit_config [ " timestep_mlp_bias " ] = True
dit_config [ " attend_to_padding " ] = False
dit_config [ " timestep_scale " ] = 1000.0
dit_config [ " use_t5 " ] = True
dit_config [ " t5_feat_dim " ] = 4096
dit_config [ " t5_token_length " ] = 256
dit_config [ " rope_theta " ] = 10000.0
return dit_config
2024-12-20 21:25:00 +01:00
if ' {} adaln_single.emb.timestep_embedder.linear_1.bias ' . format ( key_prefix ) in state_dict_keys and ' {} pos_embed.proj.bias ' . format ( key_prefix ) in state_dict_keys :
# PixArt diffusers
return None
2024-11-22 08:44:42 -05:00
if ' {} adaln_single.emb.timestep_embedder.linear_1.bias ' . format ( key_prefix ) in state_dict_keys : #Lightricks ltxv
dit_config = { }
2026-01-04 22:58:59 -08:00
dit_config [ " image_model " ] = " ltxav " if f ' { key_prefix } audio_adaln_single.linear.weight ' in state_dict_keys else " ltxv "
2025-05-07 18:28:24 -07:00
dit_config [ " num_layers " ] = count_blocks ( state_dict_keys , ' {} transformer_blocks. ' . format ( key_prefix ) + ' {} . ' )
shape = state_dict [ ' {} transformer_blocks.0.attn2.to_k.weight ' . format ( key_prefix ) ] . shape
dit_config [ " attention_head_dim " ] = shape [ 0 ] / / 32
dit_config [ " cross_attention_dim " ] = shape [ 1 ]
2025-03-05 00:13:49 -05:00
if metadata is not None and " config " in metadata :
dit_config . update ( json . loads ( metadata [ " config " ] ) . get ( " transformer " , { } ) )
2024-11-22 08:44:42 -05:00
return dit_config
2024-10-26 06:54:00 -04:00
2025-05-07 05:33:34 -07:00
if ' {} genre_embedder.weight ' . format ( key_prefix ) in state_dict_keys : #ACE-Step model
dit_config = { }
dit_config [ " audio_model " ] = " ace "
dit_config [ " attention_head_dim " ] = 128
dit_config [ " in_channels " ] = 8
dit_config [ " inner_dim " ] = 2560
dit_config [ " max_height " ] = 16
dit_config [ " max_position " ] = 32768
dit_config [ " max_width " ] = 32768
dit_config [ " mlp_ratio " ] = 2.5
dit_config [ " num_attention_heads " ] = 20
dit_config [ " num_layers " ] = 24
dit_config [ " out_channels " ] = 8
dit_config [ " patch_size " ] = [ 16 , 1 ]
dit_config [ " rope_theta " ] = 1000000.0
dit_config [ " speaker_embedding_dim " ] = 512
dit_config [ " text_embedding_dim " ] = 768
dit_config [ " ssl_encoder_depths " ] = [ 8 , 8 ]
dit_config [ " ssl_latent_dims " ] = [ 1024 , 768 ]
dit_config [ " ssl_names " ] = [ " mert " , " m-hubert " ]
dit_config [ " lyric_encoder_vocab_size " ] = 6693
dit_config [ " lyric_hidden_size " ] = 1024
return dit_config
2024-12-20 21:25:00 +01:00
if ' {} t_block.1.weight ' . format ( key_prefix ) in state_dict_keys : # PixArt
patch_size = 2
dit_config = { }
dit_config [ " num_heads " ] = 16
dit_config [ " patch_size " ] = patch_size
dit_config [ " hidden_size " ] = 1152
dit_config [ " in_channels " ] = 4
dit_config [ " depth " ] = count_blocks ( state_dict_keys , ' {} blocks. ' . format ( key_prefix ) + ' {} . ' )
y_key = " {} y_embedder.y_embedding " . format ( key_prefix )
if y_key in state_dict_keys :
dit_config [ " model_max_length " ] = state_dict [ y_key ] . shape [ 0 ]
pe_key = " {} pos_embed " . format ( key_prefix )
if pe_key in state_dict_keys :
dit_config [ " input_size " ] = int ( math . sqrt ( state_dict [ pe_key ] . shape [ 1 ] ) ) * patch_size
dit_config [ " pe_interpolation " ] = dit_config [ " input_size " ] / / ( 512 / / 8 ) # guess
2024-12-27 18:02:21 -05:00
2024-12-20 21:25:00 +01:00
ar_key = " {} ar_embedder.mlp.0.weight " . format ( key_prefix )
if ar_key in state_dict_keys :
dit_config [ " image_model " ] = " pixart_alpha "
dit_config [ " micro_condition " ] = True
else :
dit_config [ " image_model " ] = " pixart_sigma "
dit_config [ " micro_condition " ] = False
return dit_config
2025-02-04 03:56:00 -05:00
if ' {} blocks.block0.blocks.0.block.attn.to_q.0.weight ' . format ( key_prefix ) in state_dict_keys : # Cosmos
2025-01-10 09:11:57 -05:00
dit_config = { }
dit_config [ " image_model " ] = " cosmos "
dit_config [ " max_img_h " ] = 240
dit_config [ " max_img_w " ] = 240
dit_config [ " max_frames " ] = 128
2025-01-14 05:14:10 -05:00
concat_padding_mask = True
dit_config [ " in_channels " ] = ( state_dict [ ' {} x_embedder.proj.1.weight ' . format ( key_prefix ) ] . shape [ 1 ] / / 4 ) - int ( concat_padding_mask )
2025-01-10 09:11:57 -05:00
dit_config [ " out_channels " ] = 16
dit_config [ " patch_spatial " ] = 2
dit_config [ " patch_temporal " ] = 1
dit_config [ " model_channels " ] = state_dict [ ' {} blocks.block0.blocks.0.block.attn.to_q.0.weight ' . format ( key_prefix ) ] . shape [ 0 ]
dit_config [ " block_config " ] = " FA-CA-MLP "
2025-01-14 05:14:10 -05:00
dit_config [ " concat_padding_mask " ] = concat_padding_mask
2025-01-10 09:11:57 -05:00
dit_config [ " pos_emb_cls " ] = " rope3d "
dit_config [ " pos_emb_learnable " ] = False
dit_config [ " pos_emb_interpolation " ] = " crop "
dit_config [ " block_x_format " ] = " THWBD "
dit_config [ " affline_emb_norm " ] = True
dit_config [ " use_adaln_lora " ] = True
dit_config [ " adaln_lora_dim " ] = 256
if dit_config [ " model_channels " ] == 4096 :
# 7B
dit_config [ " num_blocks " ] = 28
dit_config [ " num_heads " ] = 32
dit_config [ " extra_per_block_abs_pos_emb " ] = True
dit_config [ " rope_h_extrapolation_ratio " ] = 1.0
dit_config [ " rope_w_extrapolation_ratio " ] = 1.0
dit_config [ " rope_t_extrapolation_ratio " ] = 2.0
dit_config [ " extra_per_block_abs_pos_emb_type " ] = " learnable "
else : # 5120
# 14B
dit_config [ " num_blocks " ] = 36
dit_config [ " num_heads " ] = 40
dit_config [ " extra_per_block_abs_pos_emb " ] = True
dit_config [ " rope_h_extrapolation_ratio " ] = 2.0
dit_config [ " rope_w_extrapolation_ratio " ] = 2.0
dit_config [ " rope_t_extrapolation_ratio " ] = 2.0
dit_config [ " extra_h_extrapolation_ratio " ] = 2.0
dit_config [ " extra_w_extrapolation_ratio " ] = 2.0
dit_config [ " extra_t_extrapolation_ratio " ] = 2.0
dit_config [ " extra_per_block_abs_pos_emb_type " ] = " learnable "
return dit_config
2026-03-02 15:54:18 -08:00
if ' {} cap_embedder.1.weight ' . format ( key_prefix ) in state_dict_keys and ' {} noise_refiner.0.attention.k_norm.weight ' . format ( key_prefix ) in state_dict_keys : # Lumina 2
2025-02-04 03:56:00 -05:00
dit_config = { }
dit_config [ " image_model " ] = " lumina2 "
dit_config [ " patch_size " ] = 2
dit_config [ " in_channels " ] = 16
2025-11-25 15:41:45 -08:00
w = state_dict [ ' {} cap_embedder.1.weight ' . format ( key_prefix ) ]
dit_config [ " dim " ] = w . shape [ 0 ]
dit_config [ " cap_feat_dim " ] = w . shape [ 1 ]
2025-10-06 19:08:08 -07:00
dit_config [ " n_layers " ] = count_blocks ( state_dict_keys , ' {} layers. ' . format ( key_prefix ) + ' {} . ' )
2025-02-04 03:56:00 -05:00
dit_config [ " qk_norm " ] = True
2025-11-25 15:41:45 -08:00
if dit_config [ " dim " ] == 2304 : # Original Lumina 2
dit_config [ " n_heads " ] = 24
dit_config [ " n_kv_heads " ] = 8
dit_config [ " axes_dims " ] = [ 32 , 32 , 32 ]
dit_config [ " axes_lens " ] = [ 300 , 512 , 512 ]
dit_config [ " rope_theta " ] = 10000.0
dit_config [ " ffn_dim_multiplier " ] = 4.0
2025-12-07 04:44:55 -08:00
ctd_weight = state_dict . get ( ' {} clip_text_pooled_proj.0.weight ' . format ( key_prefix ) , None )
2025-12-20 13:57:22 +08:00
if ctd_weight is not None : # NewBie
2025-12-07 04:44:55 -08:00
dit_config [ " clip_text_dim " ] = ctd_weight . shape [ 0 ]
2025-12-20 13:57:22 +08:00
# NewBie also sets axes_lens = [1024, 512, 512] but it's not used in ComfyUI
2025-11-25 15:41:45 -08:00
elif dit_config [ " dim " ] == 3840 : # Z image
dit_config [ " n_heads " ] = 30
dit_config [ " n_kv_heads " ] = 30
dit_config [ " axes_dims " ] = [ 32 , 48 , 48 ]
dit_config [ " axes_lens " ] = [ 1536 , 512 , 512 ]
dit_config [ " rope_theta " ] = 256.0
dit_config [ " ffn_dim_multiplier " ] = ( 8.0 / 3.0 )
dit_config [ " z_image_modulation " ] = True
dit_config [ " time_scale " ] = 1000.0
2026-01-24 19:32:28 -08:00
try :
dit_config [ " allow_fp16 " ] = torch . std ( state_dict [ ' {} layers. {} .ffn_norm1.weight ' . format ( key_prefix , dit_config [ " n_layers " ] - 2 ) ] , unbiased = False ) . item ( ) < 0.42
except Exception :
pass
2025-11-25 15:41:45 -08:00
if ' {} cap_pad_token ' . format ( key_prefix ) in state_dict_keys :
dit_config [ " pad_tokens_multiple " ] = 32
2026-01-19 20:17:38 -08:00
sig_weight = state_dict . get ( ' {} siglip_embedder.0.weight ' . format ( key_prefix ) , None )
if sig_weight is not None :
dit_config [ " siglip_feat_dim " ] = sig_weight . shape [ 0 ]
2025-11-25 15:41:45 -08:00
2026-03-03 07:43:47 +07:00
dec_cond_key = ' {} dec_net.cond_embed.weight ' . format ( key_prefix )
if dec_cond_key in state_dict_keys : # pixel-space variant
dit_config [ " image_model " ] = " zimage_pixel "
# patch_size and in_channels are derived from x_embedder:
# x_embedder: Linear(patch_size * patch_size * in_channels, dim)
# The decoder also receives the full flat patch, so decoder_in_channels = x_embedder input dim.
x_emb_in = state_dict [ ' {} x_embedder.weight ' . format ( key_prefix ) ] . shape [ 1 ]
dec_out = state_dict [ ' {} dec_net.final_layer.linear.weight ' . format ( key_prefix ) ] . shape [ 0 ]
# patch_size: infer from decoder final layer output matching x_embedder input
# in_channels: infer from dec_net input_embedder (in_features = dec_in_ch + max_freqs^2)
embedder_w = state_dict [ ' {} dec_net.input_embedder.embedder.0.weight ' . format ( key_prefix ) ]
dec_in_ch = dec_out # decoder in == decoder out (same pixel space)
dit_config [ " patch_size " ] = round ( ( x_emb_in / 3 ) * * 0.5 ) # assume RGB (in_channels=3)
dit_config [ " in_channels " ] = 3
dit_config [ " decoder_in_channels " ] = dec_in_ch
dit_config [ " decoder_hidden_size " ] = state_dict [ dec_cond_key ] . shape [ 0 ]
dit_config [ " decoder_num_res_blocks " ] = count_blocks (
state_dict_keys , ' {} dec_net.res_blocks. ' . format ( key_prefix ) + ' {} . '
)
dit_config [ " decoder_max_freqs " ] = int ( ( embedder_w . shape [ 1 ] - dec_in_ch ) * * 0.5 )
if ' {} __x0__ ' . format ( key_prefix ) in state_dict_keys :
dit_config [ " use_x0 " ] = True
2025-02-04 03:56:00 -05:00
return dit_config
2025-02-25 17:20:35 -05:00
if ' {} head.modulation ' . format ( key_prefix ) in state_dict_keys : # Wan 2.1
dit_config = { }
dit_config [ " image_model " ] = " wan2.1 "
dim = state_dict [ ' {} head.modulation ' . format ( key_prefix ) ] . shape [ - 1 ]
2025-07-28 05:00:23 -07:00
out_dim = state_dict [ ' {} head.head.weight ' . format ( key_prefix ) ] . shape [ 0 ] / / 4
2025-02-25 17:20:35 -05:00
dit_config [ " dim " ] = dim
2025-07-28 05:00:23 -07:00
dit_config [ " out_dim " ] = out_dim
2025-02-25 17:20:35 -05:00
dit_config [ " num_heads " ] = dim / / 128
dit_config [ " ffn_dim " ] = state_dict [ ' {} blocks.0.ffn.0.weight ' . format ( key_prefix ) ] . shape [ 0 ]
dit_config [ " num_layers " ] = count_blocks ( state_dict_keys , ' {} blocks. ' . format ( key_prefix ) + ' {} . ' )
dit_config [ " patch_size " ] = ( 1 , 2 , 2 )
dit_config [ " freq_dim " ] = 256
dit_config [ " window_size " ] = ( - 1 , - 1 )
dit_config [ " qk_norm " ] = True
dit_config [ " cross_attn_norm " ] = True
dit_config [ " eps " ] = 1e-6
2025-02-26 01:49:43 -05:00
dit_config [ " in_dim " ] = state_dict [ ' {} patch_embedding.weight ' . format ( key_prefix ) ] . shape [ 1 ]
2025-04-21 11:40:29 -07:00
if ' {} vace_patch_embedding.weight ' . format ( key_prefix ) in state_dict_keys :
dit_config [ " model_type " ] = " vace "
dit_config [ " vace_in_dim " ] = state_dict [ ' {} vace_patch_embedding.weight ' . format ( key_prefix ) ] . shape [ 1 ]
dit_config [ " vace_layers " ] = count_blocks ( state_dict_keys , ' {} vace_blocks. ' . format ( key_prefix ) + ' {} . ' )
2025-05-15 16:02:19 -07:00
elif ' {} control_adapter.conv.weight ' . format ( key_prefix ) in state_dict_keys :
2025-08-15 14:29:58 -07:00
if ' {} img_emb.proj.0.bias ' . format ( key_prefix ) in state_dict_keys :
dit_config [ " model_type " ] = " camera "
else :
dit_config [ " model_type " ] = " camera_2.2 "
2025-08-26 22:10:34 -07:00
elif ' {} casual_audio_encoder.encoder.final_linear.weight ' . format ( key_prefix ) in state_dict_keys :
dit_config [ " model_type " ] = " s2v "
2025-09-16 21:12:48 -07:00
elif ' {} audio_proj.audio_proj_glob_1.layer.bias ' . format ( key_prefix ) in state_dict_keys :
dit_config [ " model_type " ] = " humo "
2025-09-19 00:07:17 -07:00
elif ' {} face_adapter.fuser_blocks.0.k_norm.weight ' . format ( key_prefix ) in state_dict_keys :
dit_config [ " model_type " ] = " animate "
2026-02-28 23:49:12 +02:00
elif ' {} patch_embedding_pose.weight ' . format ( key_prefix ) in state_dict_keys :
dit_config [ " model_type " ] = " scail "
2025-02-25 17:20:35 -05:00
else :
2025-04-21 11:40:29 -07:00
if ' {} img_emb.proj.0.bias ' . format ( key_prefix ) in state_dict_keys :
dit_config [ " model_type " ] = " i2v "
else :
dit_config [ " model_type " ] = " t2v "
2025-04-17 12:04:48 -04:00
flf_weight = state_dict . get ( ' {} img_emb.emb_pos ' . format ( key_prefix ) )
if flf_weight is not None :
dit_config [ " flf_pos_embed_token_number " ] = flf_weight . shape [ 1 ]
2025-08-12 20:26:33 -07:00
ref_conv_weight = state_dict . get ( ' {} ref_conv.weight ' . format ( key_prefix ) )
if ref_conv_weight is not None :
dit_config [ " in_dim_ref_conv " ] = ref_conv_weight . shape [ 1 ]
2026-02-26 06:38:46 +02:00
if metadata is not None and " config " in metadata :
dit_config . update ( json . loads ( metadata [ " config " ] ) . get ( " transformer " , { } ) )
2025-02-25 17:20:35 -05:00
return dit_config
2025-03-19 16:19:50 -04:00
if ' {} latent_in.weight ' . format ( key_prefix ) in state_dict_keys : # Hunyuan 3D
in_shape = state_dict [ ' {} latent_in.weight ' . format ( key_prefix ) ] . shape
dit_config = { }
dit_config [ " image_model " ] = " hunyuan3d2 "
dit_config [ " in_channels " ] = in_shape [ 1 ]
dit_config [ " context_in_dim " ] = state_dict [ ' {} cond_in.weight ' . format ( key_prefix ) ] . shape [ 1 ]
dit_config [ " hidden_size " ] = in_shape [ 0 ]
dit_config [ " mlp_ratio " ] = 4.0
dit_config [ " num_heads " ] = 16
dit_config [ " depth " ] = count_blocks ( state_dict_keys , ' {} double_blocks. ' . format ( key_prefix ) + ' {} . ' )
dit_config [ " depth_single_blocks " ] = count_blocks ( state_dict_keys , ' {} single_blocks. ' . format ( key_prefix ) + ' {} . ' )
dit_config [ " qkv_bias " ] = True
dit_config [ " guidance_embed " ] = " {} guidance_in.in_layer.weight " . format ( key_prefix ) in state_dict_keys
return dit_config
2026-03-02 15:54:18 -08:00
if f " { key_prefix } t_embedder.mlp.2.weight " in state_dict_keys and f " { key_prefix } blocks.0.attn1.k_norm.weight " in state_dict_keys : # Hunyuan 3D 2.1
2025-09-05 03:36:20 +03:00
dit_config = { }
dit_config [ " image_model " ] = " hunyuan3d2_1 "
dit_config [ " in_channels " ] = state_dict [ f " { key_prefix } x_embedder.weight " ] . shape [ 1 ]
dit_config [ " context_dim " ] = 1024
dit_config [ " hidden_size " ] = state_dict [ f " { key_prefix } x_embedder.weight " ] . shape [ 0 ]
dit_config [ " mlp_ratio " ] = 4.0
dit_config [ " num_heads " ] = 16
dit_config [ " depth " ] = count_blocks ( state_dict_keys , f " { key_prefix } blocks. {{ }} " )
dit_config [ " qkv_bias " ] = False
dit_config [ " guidance_cond_proj_dim " ] = None #f"{key_prefix}t_embedder.cond_proj.weight" in state_dict_keys
return dit_config
2025-04-15 17:35:05 -04:00
if ' {} caption_projection.0.linear.weight ' . format ( key_prefix ) in state_dict_keys : # HiDream
dit_config = { }
dit_config [ " image_model " ] = " hidream "
dit_config [ " attention_head_dim " ] = 128
dit_config [ " axes_dims_rope " ] = [ 64 , 32 , 32 ]
dit_config [ " caption_channels " ] = [ 4096 , 4096 ]
dit_config [ " max_resolution " ] = [ 128 , 128 ]
dit_config [ " in_channels " ] = 16
dit_config [ " llama_layers " ] = [ 0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 , 13 , 14 , 15 , 16 , 17 , 18 , 19 , 20 , 21 , 22 , 23 , 24 , 25 , 26 , 27 , 28 , 29 , 30 , 31 , 31 , 31 , 31 , 31 , 31 , 31 , 31 , 31 , 31 , 31 , 31 , 31 , 31 , 31 , 31 , 31 ]
dit_config [ " num_attention_heads " ] = 20
dit_config [ " num_routed_experts " ] = 4
dit_config [ " num_activated_experts " ] = 2
dit_config [ " num_layers " ] = 16
dit_config [ " num_single_layers " ] = 32
dit_config [ " out_channels " ] = 16
dit_config [ " patch_size " ] = 2
dit_config [ " text_emb_dim " ] = 2048
return dit_config
2025-06-13 04:05:23 -07:00
if ' {} blocks.0.mlp.layer1.weight ' . format ( key_prefix ) in state_dict_keys : # Cosmos predict2
dit_config = { }
dit_config [ " image_model " ] = " cosmos_predict2 "
2026-01-21 16:44:28 -08:00
if " {} llm_adapter.blocks.0.cross_attn.q_proj.weight " . format ( key_prefix ) in state_dict_keys :
dit_config [ " image_model " ] = " anima "
2025-06-13 04:05:23 -07:00
dit_config [ " max_img_h " ] = 240
dit_config [ " max_img_w " ] = 240
dit_config [ " max_frames " ] = 128
concat_padding_mask = True
dit_config [ " in_channels " ] = ( state_dict [ ' {} x_embedder.proj.1.weight ' . format ( key_prefix ) ] . shape [ 1 ] / / 4 ) - int ( concat_padding_mask )
dit_config [ " out_channels " ] = 16
dit_config [ " patch_spatial " ] = 2
dit_config [ " patch_temporal " ] = 1
dit_config [ " model_channels " ] = state_dict [ ' {} x_embedder.proj.1.weight ' . format ( key_prefix ) ] . shape [ 0 ]
dit_config [ " concat_padding_mask " ] = concat_padding_mask
dit_config [ " crossattn_emb_channels " ] = 1024
dit_config [ " pos_emb_cls " ] = " rope3d "
dit_config [ " pos_emb_learnable " ] = True
dit_config [ " pos_emb_interpolation " ] = " crop "
dit_config [ " min_fps " ] = 1
dit_config [ " max_fps " ] = 30
dit_config [ " use_adaln_lora " ] = True
dit_config [ " adaln_lora_dim " ] = 256
if dit_config [ " model_channels " ] == 2048 :
dit_config [ " num_blocks " ] = 28
dit_config [ " num_heads " ] = 16
elif dit_config [ " model_channels " ] == 5120 :
dit_config [ " num_blocks " ] = 36
dit_config [ " num_heads " ] = 40
if dit_config [ " in_channels " ] == 16 :
dit_config [ " extra_per_block_abs_pos_emb " ] = False
dit_config [ " rope_h_extrapolation_ratio " ] = 4.0
dit_config [ " rope_w_extrapolation_ratio " ] = 4.0
dit_config [ " rope_t_extrapolation_ratio " ] = 1.0
2025-06-14 18:37:07 -07:00
elif dit_config [ " in_channels " ] == 17 : # img to video
if dit_config [ " model_channels " ] == 2048 :
dit_config [ " extra_per_block_abs_pos_emb " ] = False
dit_config [ " rope_h_extrapolation_ratio " ] = 3.0
dit_config [ " rope_w_extrapolation_ratio " ] = 3.0
dit_config [ " rope_t_extrapolation_ratio " ] = 1.0
elif dit_config [ " model_channels " ] == 5120 :
dit_config [ " rope_h_extrapolation_ratio " ] = 2.0
dit_config [ " rope_w_extrapolation_ratio " ] = 2.0
dit_config [ " rope_t_extrapolation_ratio " ] = 0.8333333333333334
2025-06-13 04:05:23 -07:00
dit_config [ " extra_h_extrapolation_ratio " ] = 1.0
dit_config [ " extra_w_extrapolation_ratio " ] = 1.0
dit_config [ " extra_t_extrapolation_ratio " ] = 1.0
dit_config [ " rope_enable_fps_modulation " ] = False
return dit_config
2025-06-25 16:35:57 -07:00
if ' {} time_caption_embed.timestep_embedder.linear_1.bias ' . format ( key_prefix ) in state_dict_keys : # Omnigen2
dit_config = { }
dit_config [ " image_model " ] = " omnigen2 "
dit_config [ " axes_dim_rope " ] = [ 40 , 40 , 40 ]
dit_config [ " axes_lens " ] = [ 1024 , 1664 , 1664 ]
dit_config [ " ffn_dim_multiplier " ] = None
dit_config [ " hidden_size " ] = 2520
dit_config [ " in_channels " ] = 16
dit_config [ " multiple_of " ] = 256
dit_config [ " norm_eps " ] = 1e-05
dit_config [ " num_attention_heads " ] = 21
dit_config [ " num_kv_heads " ] = 7
dit_config [ " num_layers " ] = 32
dit_config [ " num_refiner_layers " ] = 2
dit_config [ " out_channels " ] = None
dit_config [ " patch_size " ] = 2
dit_config [ " text_feat_dim " ] = 2048
dit_config [ " timestep_scale " ] = 1000.0
return dit_config
2025-08-04 19:53:25 -07:00
if ' {} txt_norm.weight ' . format ( key_prefix ) in state_dict_keys : # Qwen Image
dit_config = { }
dit_config [ " image_model " ] = " qwen_image "
2025-08-21 20:18:04 -07:00
dit_config [ " in_channels " ] = state_dict [ ' {} img_in.weight ' . format ( key_prefix ) ] . shape [ 1 ]
dit_config [ " num_layers " ] = count_blocks ( state_dict_keys , ' {} transformer_blocks. ' . format ( key_prefix ) + ' {} . ' )
2025-12-16 14:03:17 -08:00
if " {} __index_timestep_zero__ " . format ( key_prefix ) in state_dict_keys : # 2511
dit_config [ " default_ref_method " ] = " index_timestep_zero "
2025-12-18 17:21:14 -08:00
if " {} time_text_embed.addition_t_embedding.weight " . format ( key_prefix ) in state_dict_keys : # Layered
dit_config [ " use_additional_t_cond " ] = True
dit_config [ " default_ref_method " ] = " negative_index "
2025-08-04 19:53:25 -07:00
return dit_config
2025-12-06 05:20:22 +02:00
if ' {} visual_transformer_blocks.0.cross_attention.key_norm.weight ' . format ( key_prefix ) in state_dict_keys : # Kandinsky 5
dit_config = { }
model_dim = state_dict [ ' {} visual_embeddings.in_layer.bias ' . format ( key_prefix ) ] . shape [ 0 ]
dit_config [ " model_dim " ] = model_dim
if model_dim in [ 4096 , 2560 ] : # pro video and lite image
dit_config [ " axes_dims " ] = ( 32 , 48 , 48 )
if model_dim == 2560 : # lite image
dit_config [ " rope_scale_factor " ] = ( 1.0 , 1.0 , 1.0 )
elif model_dim == 1792 : # lite video
dit_config [ " axes_dims " ] = ( 16 , 24 , 24 )
dit_config [ " time_dim " ] = state_dict [ ' {} time_embeddings.in_layer.bias ' . format ( key_prefix ) ] . shape [ 0 ]
dit_config [ " image_model " ] = " kandinsky5 "
dit_config [ " ff_dim " ] = state_dict [ ' {} visual_transformer_blocks.0.feed_forward.in_layer.weight ' . format ( key_prefix ) ] . shape [ 0 ]
dit_config [ " visual_embed_dim " ] = state_dict [ ' {} visual_embeddings.in_layer.weight ' . format ( key_prefix ) ] . shape [ 1 ]
dit_config [ " num_text_blocks " ] = count_blocks ( state_dict_keys , ' {} text_transformer_blocks. ' . format ( key_prefix ) + ' {} . ' )
dit_config [ " num_visual_blocks " ] = count_blocks ( state_dict_keys , ' {} visual_transformer_blocks. ' . format ( key_prefix ) + ' {} . ' )
return dit_config
2026-02-02 21:06:18 -08:00
if ' {} encoder.lyric_encoder.layers.0.input_layernorm.weight ' . format ( key_prefix ) in state_dict_keys :
dit_config = { }
dit_config [ " audio_model " ] = " ace1.5 "
2026-04-07 00:13:47 -07:00
head_dim = 128
dit_config [ " hidden_size " ] = state_dict [ ' {} decoder.layers.0.self_attn_norm.weight ' . format ( key_prefix ) ] . shape [ 0 ]
dit_config [ " intermediate_size " ] = state_dict [ ' {} decoder.layers.0.mlp.gate_proj.weight ' . format ( key_prefix ) ] . shape [ 0 ]
dit_config [ " num_heads " ] = state_dict [ ' {} decoder.layers.0.self_attn.q_proj.weight ' . format ( key_prefix ) ] . shape [ 0 ] / / head_dim
dit_config [ " encoder_hidden_size " ] = state_dict [ ' {} encoder.lyric_encoder.layers.0.input_layernorm.weight ' . format ( key_prefix ) ] . shape [ 0 ]
dit_config [ " encoder_num_heads " ] = state_dict [ ' {} encoder.lyric_encoder.layers.0.self_attn.q_proj.weight ' . format ( key_prefix ) ] . shape [ 0 ] / / head_dim
dit_config [ " encoder_intermediate_size " ] = state_dict [ ' {} encoder.lyric_encoder.layers.0.mlp.gate_proj.weight ' . format ( key_prefix ) ] . shape [ 0 ]
dit_config [ " num_dit_layers " ] = count_blocks ( state_dict_keys , ' {} decoder.layers. ' . format ( key_prefix ) + ' {} . ' )
2026-02-02 21:06:18 -08:00
return dit_config
2026-03-29 06:34:10 +03:00
if ' {} encoder.pan_blocks.1.cv4.conv.weight ' . format ( key_prefix ) in state_dict_keys : # RT-DETR_v4
dit_config = { }
dit_config [ " image_model " ] = " RT_DETR_v4 "
dit_config [ " enc_h " ] = state_dict [ ' {} encoder.pan_blocks.1.cv4.conv.weight ' . format ( key_prefix ) ] . shape [ 0 ]
return dit_config
2026-04-11 19:29:31 -07:00
if ' {} layers.0.mlp.linear_fc2.weight ' . format ( key_prefix ) in state_dict_keys : # Ernie Image
dit_config = { }
dit_config [ " image_model " ] = " ernie "
return dit_config
2026-04-23 07:07:43 +03:00
if ' detector.backbone.vision_backbone.trunk.blocks.0.attn.qkv.weight ' in state_dict_keys : # SAM3 / SAM3.1
if ' detector.transformer.decoder.query_embed.weight ' in state_dict_keys :
dit_config = { }
dit_config [ " image_model " ] = " SAM3 "
if ' detector.backbone.vision_backbone.propagation_convs.0.conv_1x1.weight ' in state_dict_keys :
dit_config [ " image_model " ] = " SAM31 "
return dit_config
2024-07-11 11:37:31 -04:00
if ' {} input_blocks.0.0.weight ' . format ( key_prefix ) not in state_dict_keys :
return None
2023-06-22 13:03:50 -04:00
unet_config = {
" use_checkpoint " : False ,
" image_size " : 32 ,
" use_spatial_transformer " : True ,
" legacy " : False
}
y_input = ' {} label_emb.0.0.weight ' . format ( key_prefix )
if y_input in state_dict_keys :
unet_config [ " num_classes " ] = " sequential "
unet_config [ " adm_in_channels " ] = state_dict [ y_input ] . shape [ 1 ]
else :
unet_config [ " adm_in_channels " ] = None
model_channels = state_dict [ ' {} input_blocks.0.0.weight ' . format ( key_prefix ) ] . shape [ 0 ]
in_channels = state_dict [ ' {} input_blocks.0.0.weight ' . format ( key_prefix ) ] . shape [ 1 ]
2024-01-02 14:41:33 -05:00
out_key = ' {} out.2.weight ' . format ( key_prefix )
if out_key in state_dict :
out_channels = state_dict [ out_key ] . shape [ 0 ]
else :
out_channels = 4
2023-06-22 13:03:50 -04:00
num_res_blocks = [ ]
channel_mult = [ ]
transformer_depth = [ ]
2023-10-27 14:15:45 -04:00
transformer_depth_output = [ ]
2023-06-22 13:03:50 -04:00
context_dim = None
use_linear_in_transformer = False
2023-11-23 19:41:33 -05:00
video_model = False
2024-06-10 13:26:25 -04:00
video_model_cross = False
2023-06-22 13:03:50 -04:00
current_res = 1
count = 0
last_res_blocks = 0
last_channel_mult = 0
2023-10-27 14:15:45 -04:00
input_block_count = count_blocks ( state_dict_keys , ' {} input_blocks ' . format ( key_prefix ) + ' . {} . ' )
for count in range ( input_block_count ) :
2023-06-22 13:03:50 -04:00
prefix = ' {} input_blocks. {} . ' . format ( key_prefix , count )
2023-10-27 14:15:45 -04:00
prefix_output = ' {} output_blocks. {} . ' . format ( key_prefix , input_block_count - count - 1 )
2023-06-22 13:03:50 -04:00
block_keys = sorted ( list ( filter ( lambda a : a . startswith ( prefix ) , state_dict_keys ) ) )
if len ( block_keys ) == 0 :
break
2023-10-27 14:15:45 -04:00
block_keys_output = sorted ( list ( filter ( lambda a : a . startswith ( prefix_output ) , state_dict_keys ) ) )
2023-06-22 13:03:50 -04:00
if " {} 0.op.weight " . format ( prefix ) in block_keys : #new layer
num_res_blocks . append ( last_res_blocks )
channel_mult . append ( last_channel_mult )
current_res * = 2
last_res_blocks = 0
last_channel_mult = 0
2023-10-27 14:15:45 -04:00
out = calculate_transformer_depth ( prefix_output , state_dict_keys , state_dict )
if out is not None :
transformer_depth_output . append ( out [ 0 ] )
else :
transformer_depth_output . append ( 0 )
2023-06-22 13:03:50 -04:00
else :
res_block_prefix = " {} 0.in_layers.0.weight " . format ( prefix )
if res_block_prefix in block_keys :
last_res_blocks + = 1
last_channel_mult = state_dict [ " {} 0.out_layers.3.weight " . format ( prefix ) ] . shape [ 0 ] / / model_channels
2023-10-27 14:15:45 -04:00
out = calculate_transformer_depth ( prefix , state_dict_keys , state_dict )
if out is not None :
transformer_depth . append ( out [ 0 ] )
if context_dim is None :
context_dim = out [ 1 ]
use_linear_in_transformer = out [ 2 ]
2023-11-23 19:41:33 -05:00
video_model = out [ 3 ]
2024-06-10 13:26:25 -04:00
video_model_cross = out [ 4 ]
2023-10-27 14:15:45 -04:00
else :
transformer_depth . append ( 0 )
res_block_prefix = " {} 0.in_layers.0.weight " . format ( prefix_output )
if res_block_prefix in block_keys_output :
out = calculate_transformer_depth ( prefix_output , state_dict_keys , state_dict )
if out is not None :
transformer_depth_output . append ( out [ 0 ] )
else :
transformer_depth_output . append ( 0 )
2023-06-22 13:03:50 -04:00
num_res_blocks . append ( last_res_blocks )
channel_mult . append ( last_channel_mult )
2023-10-27 14:15:45 -04:00
if " {} middle_block.1.proj_in.weight " . format ( key_prefix ) in state_dict_keys :
transformer_depth_middle = count_blocks ( state_dict_keys , ' {} middle_block.1.transformer_blocks. ' . format ( key_prefix ) + ' {} ' )
2024-02-28 11:55:06 -05:00
elif " {} middle_block.0.in_layers.0.weight " . format ( key_prefix ) in state_dict_keys :
2023-10-27 14:15:45 -04:00
transformer_depth_middle = - 1
2024-02-28 11:55:06 -05:00
else :
transformer_depth_middle = - 2
2023-06-22 13:03:50 -04:00
unet_config [ " in_channels " ] = in_channels
2024-01-02 01:50:57 -05:00
unet_config [ " out_channels " ] = out_channels
2023-06-22 13:03:50 -04:00
unet_config [ " model_channels " ] = model_channels
unet_config [ " num_res_blocks " ] = num_res_blocks
unet_config [ " transformer_depth " ] = transformer_depth
2023-10-27 14:15:45 -04:00
unet_config [ " transformer_depth_output " ] = transformer_depth_output
2023-06-22 13:03:50 -04:00
unet_config [ " channel_mult " ] = channel_mult
unet_config [ " transformer_depth_middle " ] = transformer_depth_middle
unet_config [ ' use_linear_in_transformer ' ] = use_linear_in_transformer
unet_config [ " context_dim " ] = context_dim
2023-11-23 19:41:33 -05:00
if video_model :
unet_config [ " extra_ff_mix_layer " ] = True
unet_config [ " use_spatial_context " ] = True
unet_config [ " merge_strategy " ] = " learned_with_images "
unet_config [ " merge_factor " ] = 0.0
unet_config [ " video_kernel_size " ] = [ 3 , 1 , 1 ]
unet_config [ " use_temporal_resblock " ] = True
unet_config [ " use_temporal_attention " ] = True
2024-06-10 13:26:25 -04:00
unet_config [ " disable_temporal_crossattention " ] = not video_model_cross
2023-11-23 19:41:33 -05:00
else :
unet_config [ " use_temporal_resblock " ] = False
unet_config [ " use_temporal_attention " ] = False
2026-02-27 02:59:05 +02:00
heatmap_key = ' {} heatmap_head.conv_layers.0.weight ' . format ( key_prefix )
if heatmap_key in state_dict_keys :
unet_config [ " heatmap_head " ] = True
2023-06-22 13:03:50 -04:00
return unet_config
2024-03-31 01:25:16 -04:00
def model_config_from_unet_config(unet_config, state_dict=None):
    """Return an instance of the first supported model config matching *unet_config*.

    Every entry of ``comfy.supported_models.models`` is asked, in order, whether
    it ``matches(unet_config, state_dict)``. The first one that does is
    instantiated with ``unet_config`` and returned. When none matches, an error
    is logged and ``None`` is returned.
    """
    # The generator is lazy, so ``matches`` stops being called after the first hit.
    matched = next(
        (candidate for candidate in comfy.supported_models.models
         if candidate.matches(unet_config, state_dict)),
        None,
    )
    if matched is not None:
        return matched(unet_config)

    logging.error("no match {}".format(unet_config))
    return None
2023-07-05 17:34:45 -04:00
2025-03-05 00:13:49 -05:00
def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False, metadata=None):
    """Detect the model config for the weights stored under *unet_key_prefix*.

    Args:
        state_dict: mapping of weight names to tensors.
        unet_key_prefix: prefix of the keys to inspect inside ``state_dict``.
        use_base_if_no_match: when True and no supported model matches the
            detected config, fall back to ``comfy.supported_models_base.BASE``.
        metadata: optional metadata forwarded to ``detect_unet_config``.

    Returns:
        The model config instance, with ``quant_config`` set on it when
        ``comfy.utils.detect_layer_quantization`` reports something, or ``None``
        when no config could be detected or matched.
    """
    unet_config = detect_unet_config(state_dict, unet_key_prefix, metadata=metadata)
    if unet_config is None:
        return None

    model_config = model_config_from_unet_config(unet_config, state_dict)
    if model_config is None and use_base_if_no_match:
        model_config = comfy.supported_models_base.BASE(unet_config)

    if model_config is None:
        # No match and no fallback requested. Return before the quantization
        # step: setting ``quant_config`` on None would raise AttributeError
        # instead of reporting "no model" to the caller.
        return None

    # Detect per-layer quantization (mixed precision)
    quant_config = comfy.utils.detect_layer_quantization(state_dict, unet_key_prefix)
    if quant_config:
        model_config.quant_config = quant_config
        logging.info("Detected mixed precision quantization")

    return model_config
2023-07-21 22:58:16 -04:00
2024-06-15 12:14:56 -04:00
def unet_prefix_from_state_dict(state_dict):
    """Guess the key prefix to use for the model weights in *state_dict*.

    Returns ``""`` when both ``detector.`` and ``tracker.`` keys are present,
    the most frequent known prefix when more than five keys start with it, and
    ``"model."`` otherwise.
    """
    def _has_key_under(root):
        return any(name.startswith(root) for name in state_dict)

    # SAM3: detector.* and tracker.* at top level, no common prefix
    if _has_key_under("detector.") and _has_key_under("tracker."):
        return ""

    known_prefixes = (
        "model.diffusion_model.",  # ldm/sgm models
        "model.model.",  # audio models
        "net.",  # cosmos
    )
    hits = dict.fromkeys(known_prefixes, 0)
    for name in state_dict:
        # A key is credited to the first prefix it starts with, and only that one.
        owner = next((p for p in known_prefixes if name.startswith(p)), None)
        if owner is not None:
            hits[owner] += 1

    # On a tie, max() keeps the prefix listed first.
    best = max(hits, key=hits.get)
    return best if hits[best] > 5 else "model."  # aura flow and others
2024-06-15 12:14:56 -04:00
2023-10-27 14:15:45 -04:00
def convert_config(unet_config):
    """Return a shallow copy of *unet_config* with per-level lists expanded.

    An integer ``num_res_blocks`` becomes a list with one entry per
    ``channel_mult`` level. When ``attention_resolutions`` is present it is
    removed from the copy and used to build ``transformer_depth`` (one entry per
    res block) and ``transformer_depth_output`` (one extra entry per level).
    ``transformer_depth_middle`` falls back to the last transformer depth when
    it is missing or None. The input dict itself is not modified.
    """
    converted = unet_config.copy()
    res_per_level = converted.get("num_res_blocks", None)
    mults = converted.get("channel_mult", None)

    if isinstance(res_per_level, int):
        res_per_level = [res_per_level] * len(mults)

    if "attention_resolutions" in converted:
        attn_resolutions = converted.pop("attention_resolutions")
        depth = converted.get("transformer_depth", None)
        depth_middle = converted.get("transformer_depth_middle", None)

        if isinstance(depth, int):
            depth = [depth] * len(mults)
        if depth_middle is None:
            depth_middle = depth[-1]

        depth_in = []
        depth_out = []
        for level, n_res in enumerate(res_per_level):
            # Level ``i`` has downsampling factor 2**i; it only gets attention
            # when that factor is listed in attention_resolutions.
            level_depth = depth[level] if (1 << level) in attn_resolutions else 0
            depth_in.extend([level_depth] * n_res)
            depth_out.extend([level_depth] * (n_res + 1))

        converted["transformer_depth"] = depth_in
        converted["transformer_depth_output"] = depth_out
        converted["transformer_depth_middle"] = depth_middle

    converted["num_res_blocks"] = res_per_level
    return converted
2024-02-16 10:55:08 -05:00
def unet_config_from_diffusers_unet(state_dict, dtype=None):
    """Identify which known UNet layout a diffusers-format *state_dict* uses.

    A small fingerprint (transformer depths of the down blocks, context
    dimension, model/input channels and ADM input channels) is read from the
    weights and compared against a table of reference configs. The first
    reference whose values agree on every fingerprint key is passed through
    ``convert_config`` and returned.

    Args:
        state_dict: mapping of diffusers UNet weight names to tensors.
        dtype: value stored under the ``'dtype'`` key of every reference config.

    Returns:
        The converted config dict, or ``None`` when ``conv_in.weight`` is
        missing or no reference config matches.
    """
    if "conv_in.weight" not in state_dict:
        return None

    match = {}
    depth_per_block = []

    n_down = count_blocks(state_dict, "down_blocks.{}")
    for down_idx in range(n_down):
        n_attn = count_blocks(state_dict, "down_blocks.{}.attentions.".format(down_idx) + '{}')
        n_res = count_blocks(state_dict, "down_blocks.{}.resnets.".format(down_idx) + '{}')
        for attn_idx in range(n_attn):
            n_transformer = count_blocks(state_dict, "down_blocks.{}.attentions.{}.transformer_blocks.".format(down_idx, attn_idx) + '{}')
            depth_per_block.append(n_transformer)
            if n_transformer > 0:
                match["context_dim"] = state_dict["down_blocks.{}.attentions.{}.transformer_blocks.0.attn2.to_k.weight".format(down_idx, attn_idx)].shape[1]

        if n_attn == 0:
            # A down block without attention contributes a zero depth per resnet.
            depth_per_block.extend([0] * n_res)

    match["transformer_depth"] = depth_per_block

    conv_in_shape = state_dict["conv_in.weight"].shape
    match["model_channels"] = conv_in_shape[0]
    match["in_channels"] = conv_in_shape[1]
    match["adm_in_channels"] = None
    for embed_key in ("class_embedding.linear_1.weight", "add_embedding.linear_1.weight"):
        # class_embedding takes precedence over add_embedding when both exist.
        if embed_key in state_dict:
            match["adm_in_channels"] = state_dict[embed_key].shape[1]
            break

    SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
            'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
            'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 10,
            'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
            'use_temporal_attention': False, 'use_temporal_resblock': False}

    SDXL_refiner = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
                    'num_classes': 'sequential', 'adm_in_channels': 2560, 'dtype': dtype, 'in_channels': 4, 'model_channels': 384,
                    'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [0, 0, 4, 4, 4, 4, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 4,
                    'use_linear_in_transformer': True, 'context_dim': 1280, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 4, 4, 4, 4, 4, 4, 0, 0, 0],
                    'use_temporal_attention': False, 'use_temporal_resblock': False}

    SD21 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
            'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2],
            'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': True,
            'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
            'use_temporal_attention': False, 'use_temporal_resblock': False}

    SD21_uncliph = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
                    'num_classes': 'sequential', 'adm_in_channels': 2048, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
                    'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1,
                    'use_linear_in_transformer': True, 'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
                    'use_temporal_attention': False, 'use_temporal_resblock': False}

    SD21_unclipl = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
                    'num_classes': 'sequential', 'adm_in_channels': 1536, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
                    'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1,
                    'use_linear_in_transformer': True, 'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
                    'use_temporal_attention': False, 'use_temporal_resblock': False}

    SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': None,
            'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0],
            'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, 'num_heads': 8,
            'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
            'use_temporal_attention': False, 'use_temporal_resblock': False}

    SDXL_mid_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
                     'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
                     'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 0, 0, 1, 1], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 1,
                     'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 0, 0, 0, 1, 1, 1],
                     'use_temporal_attention': False, 'use_temporal_resblock': False}

    SDXL_small_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
                       'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
                       'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 0, 0, 0, 0], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 0,
                       'use_linear_in_transformer': True, 'num_head_channels': 64, 'context_dim': 1, 'transformer_depth_output': [0, 0, 0, 0, 0, 0, 0, 0, 0],
                       'use_temporal_attention': False, 'use_temporal_resblock': False}

    SDXL_diffusers_inpaint = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
                              'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 9, 'model_channels': 320,
                              'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 10,
                              'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
                              'use_temporal_attention': False, 'use_temporal_resblock': False}

    SDXL_diffusers_ip2p = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
                           'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 8, 'model_channels': 320,
                           'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 10,
                           'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
                           'use_temporal_attention': False, 'use_temporal_resblock': False}

    SSD_1B = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
              'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
              'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 4, 4], 'transformer_depth_output': [0, 0, 0, 1, 1, 2, 10, 4, 4],
              'channel_mult': [1, 2, 4], 'transformer_depth_middle': -1, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
              'use_temporal_attention': False, 'use_temporal_resblock': False}

    Segmind_Vega = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
                    'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
                    'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 1, 1, 2, 2], 'transformer_depth_output': [0, 0, 0, 1, 1, 1, 2, 2, 2],
                    'channel_mult': [1, 2, 4], 'transformer_depth_middle': -1, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
                    'use_temporal_attention': False, 'use_temporal_resblock': False}

    KOALA_700M = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
                  'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
                  'num_res_blocks': [1, 1, 1], 'transformer_depth': [0, 2, 5], 'transformer_depth_output': [0, 0, 2, 2, 5, 5],
                  'channel_mult': [1, 2, 4], 'transformer_depth_middle': -2, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
                  'use_temporal_attention': False, 'use_temporal_resblock': False}

    KOALA_1B = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
                'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
                'num_res_blocks': [1, 1, 1], 'transformer_depth': [0, 2, 6], 'transformer_depth_output': [0, 0, 2, 2, 6, 6],
                'channel_mult': [1, 2, 4], 'transformer_depth_middle': 6, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
                'use_temporal_attention': False, 'use_temporal_resblock': False}

    SD09_XS = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
               'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [1, 1, 1],
               'transformer_depth': [1, 1, 1], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -2, 'use_linear_in_transformer': True,
               'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1],
               'use_temporal_attention': False, 'use_temporal_resblock': False, 'disable_self_attentions': [True, False, False]}

    SD_XS = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
             'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [1, 1, 1],
             'transformer_depth': [0, 1, 1], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -2, 'use_linear_in_transformer': False,
             'context_dim': 768, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 1, 1, 1, 1],
             'use_temporal_attention': False, 'use_temporal_resblock': False}

    SD15_diffusers_inpaint = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': None,
                              'dtype': dtype, 'in_channels': 9, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0],
                              'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, 'num_heads': 8,
                              'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
                              'use_temporal_attention': False, 'use_temporal_resblock': False}

    LotusD = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': 4,
              'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0],
              'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, 'num_head_channels': 64,
              'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
              'use_temporal_attention': False, 'use_temporal_resblock': False}

    # Order matters: the first reference that agrees with the fingerprint wins.
    supported_models = [LotusD, SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet,
                        SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SD_XS, SDXL_diffusers_ip2p,
                        SD15_diffusers_inpaint]

    for reference in supported_models:
        differs = any(match[key] != reference[key] for key in match)
        if not differs:
            return convert_config(reference)
    return None
2024-02-16 10:55:08 -05:00
def model_config_from_diffusers_unet(state_dict):
    """Resolve a model config for a UNet state dict in diffusers layout.

    The state dict is first passed to ``unet_config_from_diffusers_unet``;
    whatever UNet config that returns is then handed to
    ``model_config_from_unet_config``.

    Returns ``None`` when no UNet config could be detected, otherwise the
    result of ``model_config_from_unet_config``.
    """
    detected = unet_config_from_diffusers_unet(state_dict)
    if detected is None:
        # Nothing recognised: propagate the "unknown" result to the caller.
        return None
    return model_config_from_unet_config(detected)
2024-06-19 21:46:37 -04:00
def convert_diffusers_mmdit(state_dict, output_prefix=""):
    """Convert a diffusers-layout transformer state dict to the native layout.

    The architecture is recognised from marker keys in ``state_dict`` and the
    matching ``comfy.utils.*_to_diffusers`` helper is asked for a key map.
    Each map entry is keyed by the diffusers key and its value is either:

    * a ``str``: the target key; the tensor is stored under it unchanged, or
    * a tuple ``(target_key, offset[, fun])`` where ``offset`` is ``None`` or
      ``(dim, start, length)`` selecting the slice of the target tensor that
      this source tensor fills, and the optional ``fun`` transforms the source
      tensor before it is written.

    Args:
        state_dict: Mapping of diffusers keys to tensors. Every key that is
            converted is popped from this mapping (it is mutated in place).
        output_prefix: Prefix forwarded to the key-map helper.

    Returns:
        A new dict of converted tensors, or ``None`` when no known layout is
        recognised (in which case ``state_dict`` is left untouched).
    """
    out_sd = {}

    if 'joint_transformer_blocks.0.attn.add_k_proj.weight' in state_dict:  # AuraFlow
        num_joint = count_blocks(state_dict, 'joint_transformer_blocks.{}.')
        num_single = count_blocks(state_dict, 'single_transformer_blocks.{}.')
        sd_map = comfy.utils.auraflow_to_diffusers({"n_double_layers": num_joint, "n_layers": num_joint + num_single}, output_prefix=output_prefix)
    elif 'adaln_single.emb.timestep_embedder.linear_1.bias' in state_dict and 'pos_embed.proj.bias' in state_dict:  # PixArt
        num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
        sd_map = comfy.utils.pixart_to_diffusers({"depth": num_blocks}, output_prefix=output_prefix)
    elif 'noise_refiner.0.attention.norm_k.weight' in state_dict:
        n_layers = count_blocks(state_dict, 'layers.{}.')
        dim = state_dict['noise_refiner.0.attention.to_k.weight'].shape[0]
        sd_map = comfy.utils.z_image_to_diffusers({"n_layers": n_layers, "dim": dim}, output_prefix=output_prefix)
        # For zeta chroma: keys the map does not know about are passed
        # through under their own name instead of being dropped.
        for k in state_dict:
            if k not in sd_map:
                sd_map[k] = k
    elif 'x_embedder.weight' in state_dict:  # Flux
        depth = count_blocks(state_dict, 'transformer_blocks.{}.')
        depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.')
        hidden_size = state_dict["x_embedder.bias"].shape[0]
        sd_map = comfy.utils.flux_to_diffusers({"depth": depth, "depth_single_blocks": depth_single_blocks, "hidden_size": hidden_size}, output_prefix=output_prefix)
    elif 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict and 'pos_embed.proj.weight' in state_dict:  # SD3
        num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
        depth = state_dict["pos_embed.proj.weight"].shape[0] // 64
        sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth, "num_blocks": num_blocks}, output_prefix=output_prefix)
    else:
        return None

    for k in sd_map:
        weight = state_dict.get(k, None)
        if weight is None:
            # The map lists this key but the checkpoint does not contain it.
            continue

        t = sd_map[k]
        if isinstance(t, str):
            # Plain rename: store the tensor as-is under the target key.
            out_sd[t] = weight
            state_dict.pop(k)
            continue

        target_key = t[0]
        offset = t[1]
        fun = t[2] if len(t) > 2 else (lambda a: a)

        if offset is not None:
            cat_dim = offset[0]
            needed = offset[1] + offset[2]

            old_weight = out_sd.get(target_key, None)
            if old_weight is None:
                old_weight = torch.empty_like(weight)
            if old_weight.shape[cat_dim] < needed:
                # Grow the target along the axis being filled, keeping what
                # earlier source tensors already wrote. The copy must follow
                # `cat_dim`; slicing dim 0 unconditionally breaks for any
                # other axis.
                exp = list(weight.shape)
                exp[cat_dim] = needed
                new = torch.empty(exp, device=weight.device, dtype=weight.dtype)
                new.narrow(cat_dim, 0, old_weight.shape[cat_dim]).copy_(old_weight)
                old_weight = new

            # If we are about to write into the very tensor already stored in
            # out_sd, work on a copy when aimdo is enabled.
            # NOTE(review): presumably stored tensors must not be mutated in
            # place in that mode -- confirm against comfy.memory_management.
            if old_weight is out_sd.get(target_key, None) and comfy.memory_management.aimdo_enabled:
                old_weight = old_weight.clone()

            w = old_weight.narrow(cat_dim, offset[1], offset[2])
        else:
            # No slice: the (possibly transformed) tensor fills the whole
            # target and is written back into the source tensor's storage,
            # or into a clone of it when aimdo is enabled.
            if comfy.memory_management.aimdo_enabled:
                weight = weight.clone()
            old_weight = weight
            w = weight

        w[:] = fun(weight)
        out_sd[target_key] = old_weight
        state_dict.pop(k)

    return out_sd