Compare commits
22 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| cb459573c8 | |||
| 35183543e0 | |||
| a246cc02b2 | |||
| a50c32d63f | |||
| 6125b80979 | |||
| c8fcbd66ee | |||
| 26dd7eb421 | |||
| e77b34dfea | |||
| ef73070ea4 | |||
| d30c609f5a | |||
| 5087f1d497 | |||
| a31681564d | |||
| 855849c658 | |||
| fe2511468d | |||
| 3be0175166 | |||
| b8315e66cb | |||
| ab1050bec3 | |||
| fb23935c11 | |||
| 85fc35e8fa | |||
| 223364743c | |||
| affe881354 | |||
| f5030e26fd |
@@ -183,7 +183,7 @@ class AceStepAttention(nn.Module):
|
||||
else:
|
||||
attn_bias = window_bias
|
||||
|
||||
attn_output = optimized_attention(query_states, key_states, value_states, self.num_heads, attn_bias, skip_reshape=True)
|
||||
attn_output = optimized_attention(query_states, key_states, value_states, self.num_heads, attn_bias, skip_reshape=True, low_precision_attention=False)
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output
|
||||
@@ -1035,8 +1035,7 @@ class AceStepConditionGenerationModel(nn.Module):
|
||||
audio_codes = torch.nn.functional.pad(audio_codes, (0, math.ceil(src_latents.shape[1] / 5) - audio_codes.shape[1]), "constant", 35847)
|
||||
lm_hints_5Hz = self.tokenizer.quantizer.get_output_from_indices(audio_codes, dtype=text_hidden_states.dtype)
|
||||
else:
|
||||
assert False
|
||||
# TODO ?
|
||||
lm_hints_5Hz, indices = self.tokenizer.tokenize(refer_audio_acoustic_hidden_states_packed)
|
||||
|
||||
lm_hints = self.detokenizer(lm_hints_5Hz)
|
||||
|
||||
|
||||
@@ -524,6 +524,9 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
||||
|
||||
@wrap_attn
|
||||
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
if kwargs.get("low_precision_attention", True) is False:
|
||||
return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=skip_reshape, skip_output_reshape=skip_output_reshape, **kwargs)
|
||||
|
||||
exception_fallback = False
|
||||
if skip_reshape:
|
||||
b, _, _, dim_head = q.shape
|
||||
|
||||
@@ -332,6 +332,12 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
key_map["{}".format(key_lora)] = k
|
||||
key_map["transformer.{}".format(key_lora)] = k
|
||||
|
||||
if isinstance(model, comfy.model_base.ACEStep15):
|
||||
for k in sdk:
|
||||
if k.startswith("diffusion_model.decoder.") and k.endswith(".weight"):
|
||||
key_lora = k[len("diffusion_model.decoder."):-len(".weight")]
|
||||
key_map["base_model.model.{}".format(key_lora)] = k # Official base model loras
|
||||
|
||||
return key_map
|
||||
|
||||
|
||||
|
||||
+12
-7
@@ -1548,6 +1548,7 @@ class ACEStep15(BaseModel):
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
device = kwargs["device"]
|
||||
noise = kwargs["noise"]
|
||||
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
@@ -1571,15 +1572,19 @@ class ACEStep15(BaseModel):
|
||||
1.4844e-01, 9.4727e-02, 3.8477e-01, -1.2578e+00, -3.3203e-01,
|
||||
-8.5547e-01, 4.3359e-01, 4.2383e-01, -8.9453e-01, -5.0391e-01,
|
||||
-5.6152e-02, -2.9219e+00, -2.4658e-02, 5.0391e-01, 9.8438e-01,
|
||||
7.2754e-02, -2.1582e-01, 6.3672e-01, 1.0000e+00]]], device=device).movedim(-1, 1).repeat(1, 1, 750)
|
||||
7.2754e-02, -2.1582e-01, 6.3672e-01, 1.0000e+00]]], device=device).movedim(-1, 1).repeat(1, 1, noise.shape[2])
|
||||
pass_audio_codes = True
|
||||
else:
|
||||
refer_audio = refer_audio[-1]
|
||||
refer_audio = refer_audio[-1][:, :, :noise.shape[2]]
|
||||
pass_audio_codes = False
|
||||
|
||||
if pass_audio_codes:
|
||||
audio_codes = kwargs.get("audio_codes", None)
|
||||
if audio_codes is not None:
|
||||
out['audio_codes'] = comfy.conds.CONDRegular(torch.tensor(audio_codes, device=device))
|
||||
refer_audio = refer_audio[:, :, :750]
|
||||
|
||||
out['refer_audio'] = comfy.conds.CONDRegular(refer_audio)
|
||||
|
||||
audio_codes = kwargs.get("audio_codes", None)
|
||||
if audio_codes is not None:
|
||||
out['audio_codes'] = comfy.conds.CONDRegular(torch.tensor(audio_codes, device=device))
|
||||
|
||||
return out
|
||||
|
||||
class Omnigen2(BaseModel):
|
||||
|
||||
@@ -1724,11 +1724,9 @@ def soft_empty_cache(force=False):
|
||||
elif is_mlu():
|
||||
torch.mlu.empty_cache()
|
||||
elif torch.cuda.is_available():
|
||||
if comfy.memory_management.aimdo_allocator is None:
|
||||
#Pytorch 2.7 and earlier crashes if you try and empty_cache when mempools exist
|
||||
torch.cuda.synchronize()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
torch.cuda.synchronize()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
|
||||
def unload_all_models():
|
||||
free_memory(1e30, get_torch_device())
|
||||
|
||||
@@ -1400,7 +1400,7 @@ class ModelPatcher:
|
||||
continue
|
||||
key = "diffusion_model." + k
|
||||
unet_state_dict[k] = LazyCastingParam(self, key, comfy.utils.get_attr(self.model, key))
|
||||
return self.model.state_dict_for_saving(unet_state_dict)
|
||||
return self.model.state_dict_for_saving(unet_state_dict, clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict)
|
||||
|
||||
def __del__(self):
|
||||
self.unpin_all_weights()
|
||||
|
||||
@@ -54,6 +54,8 @@ try:
|
||||
SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION)
|
||||
|
||||
def scaled_dot_product_attention(q, k, v, *args, **kwargs):
|
||||
if q.nelement() < 1024 * 128: # arbitrary number, for small inputs cudnn attention seems slower
|
||||
return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
|
||||
with sdpa_kernel(SDPA_BACKEND_PRIORITY, set_priority=True):
|
||||
return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
|
||||
else:
|
||||
|
||||
+12
-5
@@ -554,6 +554,8 @@ class VAE:
|
||||
elif "decoder.layers.1.layers.0.beta" in sd:
|
||||
config = {}
|
||||
param_key = None
|
||||
self.upscale_ratio = 2048
|
||||
self.downscale_ratio = 2048
|
||||
if "decoder.layers.2.layers.1.weight_v" in sd:
|
||||
param_key = "decoder.layers.2.layers.1.weight_v"
|
||||
if "decoder.layers.2.layers.1.parametrizations.weight.original1" in sd:
|
||||
@@ -562,6 +564,8 @@ class VAE:
|
||||
if sd[param_key].shape[-1] == 12:
|
||||
config["strides"] = [2, 4, 4, 6, 10]
|
||||
self.audio_sample_rate = 48000
|
||||
self.upscale_ratio = 1920
|
||||
self.downscale_ratio = 1920
|
||||
|
||||
self.first_stage_model = AudioOobleckVAE(**config)
|
||||
self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype)
|
||||
@@ -569,8 +573,6 @@ class VAE:
|
||||
self.latent_channels = 64
|
||||
self.output_channels = 2
|
||||
self.pad_channel_value = "replicate"
|
||||
self.upscale_ratio = 2048
|
||||
self.downscale_ratio = 2048
|
||||
self.latent_dim = 1
|
||||
self.process_output = lambda audio: audio
|
||||
self.process_input = lambda audio: audio
|
||||
@@ -870,7 +872,7 @@ class VAE:
|
||||
/ 3.0)
|
||||
return output
|
||||
|
||||
def decode_tiled_1d(self, samples, tile_x=128, overlap=32):
|
||||
def decode_tiled_1d(self, samples, tile_x=256, overlap=32):
|
||||
if samples.ndim == 3:
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
|
||||
else:
|
||||
@@ -974,7 +976,7 @@ class VAE:
|
||||
if overlap is not None:
|
||||
args["overlap"] = overlap
|
||||
|
||||
if dims == 1:
|
||||
if dims == 1 or self.extra_1d_channel is not None:
|
||||
args.pop("tile_y")
|
||||
output = self.decode_tiled_1d(samples, **args)
|
||||
elif dims == 2:
|
||||
@@ -1442,7 +1444,12 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
tokenizer_data["gemma_spiece_model"] = clip_data_gemma.get("spiece_model", None)
|
||||
tokenizer_data["jina_spiece_model"] = clip_data_jina.get("spiece_model", None)
|
||||
elif clip_type == CLIPType.ACE:
|
||||
clip_target.clip = comfy.text_encoders.ace15.te(**llama_detect(clip_data))
|
||||
te_models = [detect_te_model(clip_data[0]), detect_te_model(clip_data[1])]
|
||||
if TEModel.QWEN3_4B in te_models:
|
||||
model_type = "qwen3_4b"
|
||||
else:
|
||||
model_type = "qwen3_2b"
|
||||
clip_target.clip = comfy.text_encoders.ace15.te(lm_model=model_type, **llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.ace15.ACE15Tokenizer
|
||||
else:
|
||||
clip_target.clip = sdxl_clip.SDXLClipModel
|
||||
|
||||
@@ -1625,8 +1625,16 @@ class ACEStep15(supported_models_base.BASE):
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_2b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.ace15.ACE15Tokenizer, comfy.text_encoders.ace15.te(**hunyuan_detect))
|
||||
detect_2b = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_2b.transformer.".format(pref))
|
||||
detect_4b = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref))
|
||||
if "dtype_llama" in detect_2b:
|
||||
detect = detect_2b
|
||||
detect["lm_model"] = "qwen3_2b"
|
||||
elif "dtype_llama" in detect_4b:
|
||||
detect = detect_4b
|
||||
detect["lm_model"] = "qwen3_4b"
|
||||
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.ace15.ACE15Tokenizer, comfy.text_encoders.ace15.te(**detect))
|
||||
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
|
||||
|
||||
+110
-30
@@ -3,6 +3,8 @@ import comfy.text_encoders.llama
|
||||
from comfy import sd1_clip
|
||||
import torch
|
||||
import math
|
||||
import yaml
|
||||
import comfy.utils
|
||||
|
||||
|
||||
def sample_manual_loop_no_classes(
|
||||
@@ -18,6 +20,7 @@ def sample_manual_loop_no_classes(
|
||||
min_tokens: int = 1,
|
||||
max_new_tokens: int = 2048,
|
||||
audio_start_id: int = 151669, # The cutoff ID for audio codes
|
||||
audio_end_id: int = 215669,
|
||||
eos_token_id: int = 151645,
|
||||
):
|
||||
device = model.execution_device
|
||||
@@ -42,6 +45,8 @@ def sample_manual_loop_no_classes(
|
||||
for x in range(model_config.num_hidden_layers):
|
||||
past_key_values.append((torch.empty([embeds.shape[0], model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim], device=device, dtype=execution_dtype), torch.empty([embeds.shape[0], model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim], device=device, dtype=execution_dtype), 0))
|
||||
|
||||
progress_bar = comfy.utils.ProgressBar(max_new_tokens)
|
||||
|
||||
for step in range(max_new_tokens):
|
||||
outputs = model.transformer(None, attention_mask, embeds=embeds.to(execution_dtype), num_tokens=num_tokens, intermediate_output=None, dtype=execution_dtype, embeds_info=embeds_info, past_key_values=past_key_values)
|
||||
next_token_logits = model.transformer.logits(outputs[0])[:, -1]
|
||||
@@ -54,8 +59,10 @@ def sample_manual_loop_no_classes(
|
||||
if eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step:
|
||||
eos_score = cfg_logits[:, eos_token_id].clone()
|
||||
|
||||
remove_logit_value = torch.finfo(cfg_logits.dtype).min
|
||||
# Only generate audio tokens
|
||||
cfg_logits[:, :audio_start_id] = float('-inf')
|
||||
cfg_logits[:, :audio_start_id] = remove_logit_value
|
||||
cfg_logits[:, audio_end_id:] = remove_logit_value
|
||||
|
||||
if eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step:
|
||||
cfg_logits[:, eos_token_id] = eos_score
|
||||
@@ -63,7 +70,7 @@ def sample_manual_loop_no_classes(
|
||||
if top_k is not None and top_k > 0:
|
||||
top_k_vals, _ = torch.topk(cfg_logits, top_k)
|
||||
min_val = top_k_vals[..., -1, None]
|
||||
cfg_logits[cfg_logits < min_val] = float('-inf')
|
||||
cfg_logits[cfg_logits < min_val] = remove_logit_value
|
||||
|
||||
if top_p is not None and top_p < 1.0:
|
||||
sorted_logits, sorted_indices = torch.sort(cfg_logits, descending=True)
|
||||
@@ -72,7 +79,7 @@ def sample_manual_loop_no_classes(
|
||||
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
||||
sorted_indices_to_remove[..., 0] = 0
|
||||
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
|
||||
cfg_logits[indices_to_remove] = float('-inf')
|
||||
cfg_logits[indices_to_remove] = remove_logit_value
|
||||
|
||||
if temperature > 0:
|
||||
cfg_logits = cfg_logits / temperature
|
||||
@@ -90,13 +97,12 @@ def sample_manual_loop_no_classes(
|
||||
attention_mask = torch.cat([attention_mask, torch.ones((2, 1), device=device, dtype=attention_mask.dtype)], dim=1)
|
||||
|
||||
output_audio_codes.append(token - audio_start_id)
|
||||
progress_bar.update_absolute(step)
|
||||
|
||||
return output_audio_codes
|
||||
|
||||
|
||||
def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=1024, seed=0):
|
||||
cfg_scale = 2.0
|
||||
|
||||
def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=1024, seed=0, cfg_scale=2.0, temperature=0.85, top_p=0.9, top_k=0):
|
||||
positive = [[token for token, _ in inner_list] for inner_list in positive]
|
||||
negative = [[token for token, _ in inner_list] for inner_list in negative]
|
||||
positive = positive[0]
|
||||
@@ -113,34 +119,80 @@ def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=102
|
||||
positive = [model.special_tokens["pad"]] * pos_pad + positive
|
||||
|
||||
paddings = [pos_pad, neg_pad]
|
||||
return sample_manual_loop_no_classes(model, [positive, negative], paddings, cfg_scale=cfg_scale, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
|
||||
return sample_manual_loop_no_classes(model, [positive, negative], paddings, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
|
||||
|
||||
|
||||
class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen3_06b", tokenizer=Qwen3Tokenizer)
|
||||
|
||||
def _metas_to_cot(self, *, return_yaml: bool = False, **kwargs) -> str:
|
||||
user_metas = {
|
||||
k: kwargs.pop(k)
|
||||
for k in ("bpm", "duration", "keyscale", "timesignature", "language", "caption")
|
||||
if k in kwargs
|
||||
}
|
||||
timesignature = user_metas.get("timesignature")
|
||||
if isinstance(timesignature, str) and timesignature.endswith("/4"):
|
||||
user_metas["timesignature"] = timesignature.rsplit("/", 1)[0]
|
||||
user_metas = {
|
||||
k: v if not isinstance(v, str) or not v.isdigit() else int(v)
|
||||
for k, v in user_metas.items()
|
||||
if v not in {"unspecified", None}
|
||||
}
|
||||
if len(user_metas):
|
||||
meta_yaml = yaml.dump(user_metas, allow_unicode=True, sort_keys=True).strip()
|
||||
else:
|
||||
meta_yaml = ""
|
||||
return f"<think>\n{meta_yaml}\n</think>" if not return_yaml else meta_yaml
|
||||
|
||||
def _metas_to_cap(self, **kwargs) -> str:
|
||||
use_keys = ("bpm", "duration", "keyscale", "timesignature")
|
||||
user_metas = { k: kwargs.pop(k, "N/A") for k in use_keys }
|
||||
duration = user_metas["duration"]
|
||||
if duration == "N/A":
|
||||
user_metas["duration"] = "30 seconds"
|
||||
elif isinstance(duration, (str, int, float)):
|
||||
user_metas["duration"] = f"{math.ceil(float(duration))} seconds"
|
||||
else:
|
||||
raise TypeError("Unexpected type for duration key, must be str, int or float")
|
||||
return "\n".join(f"- {k}: {user_metas[k]}" for k in use_keys)
|
||||
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
|
||||
out = {}
|
||||
lyrics = kwargs.get("lyrics", "")
|
||||
bpm = kwargs.get("bpm", 120)
|
||||
duration = kwargs.get("duration", 120)
|
||||
keyscale = kwargs.get("keyscale", "C major")
|
||||
timesignature = kwargs.get("timesignature", 2)
|
||||
language = kwargs.get("language", "en")
|
||||
language = kwargs.get("language")
|
||||
seed = kwargs.get("seed", 0)
|
||||
|
||||
generate_audio_codes = kwargs.get("generate_audio_codes", True)
|
||||
cfg_scale = kwargs.get("cfg_scale", 2.0)
|
||||
temperature = kwargs.get("temperature", 0.85)
|
||||
top_p = kwargs.get("top_p", 0.9)
|
||||
top_k = kwargs.get("top_k", 0.0)
|
||||
|
||||
|
||||
duration = math.ceil(duration)
|
||||
meta_lm = 'bpm: {}\nduration: {}\nkeyscale: {}\ntimesignature: {}'.format(bpm, duration, keyscale, timesignature)
|
||||
lm_template = "<|im_start|>system\n# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n<|im_end|>\n<|im_start|>user\n# Caption\n{}\n{}\n<|im_end|>\n<|im_start|>assistant\n<think>\n{}\n</think>\n\n<|im_end|>\n"
|
||||
kwargs["duration"] = duration
|
||||
|
||||
meta_cap = '- bpm: {}\n- timesignature: {}\n- keyscale: {}\n- duration: {}\n'.format(bpm, timesignature, keyscale, duration)
|
||||
out["lm_prompt"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, meta_lm), disable_weights=True)
|
||||
out["lm_prompt_negative"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, ""), disable_weights=True)
|
||||
cot_text = self._metas_to_cot(caption = text, **kwargs)
|
||||
meta_cap = self._metas_to_cap(**kwargs)
|
||||
|
||||
out["lyrics"] = self.qwen3_06b.tokenize_with_weights("# Languages\n{}\n\n# Lyric{}<|endoftext|><|endoftext|>".format(language, lyrics), return_word_ids, disable_weights=True, **kwargs)
|
||||
out["qwen3_06b"] = self.qwen3_06b.tokenize_with_weights("# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n# Caption\n{}# Metas\n{}<|endoftext|>\n<|endoftext|>".format(text, meta_cap), return_word_ids, **kwargs)
|
||||
out["lm_metadata"] = {"min_tokens": duration * 5, "seed": seed}
|
||||
lm_template = "<|im_start|>system\n# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n<|im_end|>\n<|im_start|>user\n# Caption\n{}\n# Lyric\n{}\n<|im_end|>\n<|im_start|>assistant\n{}\n<|im_end|>\n"
|
||||
|
||||
out["lm_prompt"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, cot_text), disable_weights=True)
|
||||
out["lm_prompt_negative"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, "<think>\n</think>"), disable_weights=True)
|
||||
|
||||
out["lyrics"] = self.qwen3_06b.tokenize_with_weights("# Languages\n{}\n\n# Lyric\n{}<|endoftext|><|endoftext|>".format(language if language is not None else "", lyrics), return_word_ids, disable_weights=True, **kwargs)
|
||||
out["qwen3_06b"] = self.qwen3_06b.tokenize_with_weights("# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n# Caption\n{}\n# Metas\n{}\n<|endoftext|>\n<|endoftext|>".format(text, meta_cap), return_word_ids, **kwargs)
|
||||
out["lm_metadata"] = {"min_tokens": duration * 5,
|
||||
"seed": seed,
|
||||
"generate_audio_codes": generate_audio_codes,
|
||||
"cfg_scale": cfg_scale,
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
"top_k": top_k,
|
||||
}
|
||||
return out
|
||||
|
||||
|
||||
@@ -157,14 +209,34 @@ class Qwen3_2B_ACE15(sd1_clip.SDClipModel):
|
||||
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_2B_ACE15_lm, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||
|
||||
class Qwen3_4B_ACE15(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
|
||||
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_4B_ACE15_lm, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||
|
||||
class ACE15TEModel(torch.nn.Module):
|
||||
def __init__(self, device="cpu", dtype=None, dtype_llama=None, model_options={}):
|
||||
def __init__(self, device="cpu", dtype=None, dtype_llama=None, lm_model=None, model_options={}):
|
||||
super().__init__()
|
||||
if dtype_llama is None:
|
||||
dtype_llama = dtype
|
||||
|
||||
model = None
|
||||
self.constant = 0.4375
|
||||
if lm_model == "qwen3_4b":
|
||||
model = Qwen3_4B_ACE15
|
||||
self.constant = 0.5625
|
||||
elif lm_model == "qwen3_2b":
|
||||
model = Qwen3_2B_ACE15
|
||||
|
||||
self.lm_model = lm_model
|
||||
self.qwen3_06b = Qwen3_06BModel(device=device, dtype=dtype, model_options=model_options)
|
||||
self.qwen3_2b = Qwen3_2B_ACE15(device=device, dtype=dtype_llama, model_options=model_options)
|
||||
if model is not None:
|
||||
setattr(self, self.lm_model, model(device=device, dtype=dtype_llama, model_options=model_options))
|
||||
|
||||
self.dtypes = set([dtype, dtype_llama])
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
@@ -176,18 +248,26 @@ class ACE15TEModel(torch.nn.Module):
|
||||
self.qwen3_06b.set_clip_options({"layer": [0]})
|
||||
lyrics_embeds, _, extra_l = self.qwen3_06b.encode_token_weights(token_weight_pairs_lyrics)
|
||||
|
||||
lm_metadata = token_weight_pairs["lm_metadata"]
|
||||
audio_codes = generate_audio_codes(self.qwen3_2b, token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["min_tokens"], seed=lm_metadata["seed"])
|
||||
out = {"conditioning_lyrics": lyrics_embeds[:, 0]}
|
||||
|
||||
return base_out, None, {"conditioning_lyrics": lyrics_embeds[:, 0], "audio_codes": [audio_codes]}
|
||||
lm_metadata = token_weight_pairs["lm_metadata"]
|
||||
if lm_metadata["generate_audio_codes"]:
|
||||
audio_codes = generate_audio_codes(getattr(self, self.lm_model, self.qwen3_06b), token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["min_tokens"], seed=lm_metadata["seed"], cfg_scale=lm_metadata["cfg_scale"], temperature=lm_metadata["temperature"], top_p=lm_metadata["top_p"], top_k=lm_metadata["top_k"])
|
||||
out["audio_codes"] = [audio_codes]
|
||||
|
||||
return base_out, None, out
|
||||
|
||||
def set_clip_options(self, options):
|
||||
self.qwen3_06b.set_clip_options(options)
|
||||
self.qwen3_2b.set_clip_options(options)
|
||||
lm_model = getattr(self, self.lm_model, None)
|
||||
if lm_model is not None:
|
||||
lm_model.set_clip_options(options)
|
||||
|
||||
def reset_clip_options(self):
|
||||
self.qwen3_06b.reset_clip_options()
|
||||
self.qwen3_2b.reset_clip_options()
|
||||
lm_model = getattr(self, self.lm_model, None)
|
||||
if lm_model is not None:
|
||||
lm_model.reset_clip_options()
|
||||
|
||||
def load_sd(self, sd):
|
||||
if "model.layers.0.post_attention_layernorm.weight" in sd:
|
||||
@@ -195,11 +275,11 @@ class ACE15TEModel(torch.nn.Module):
|
||||
if shape[0] == 1024:
|
||||
return self.qwen3_06b.load_sd(sd)
|
||||
else:
|
||||
return self.qwen3_2b.load_sd(sd)
|
||||
return getattr(self, self.lm_model).load_sd(sd)
|
||||
|
||||
def memory_estimation_function(self, token_weight_pairs, device=None):
|
||||
lm_metadata = token_weight_pairs["lm_metadata"]
|
||||
constant = 0.4375
|
||||
constant = self.constant
|
||||
if comfy.model_management.should_use_bf16(device):
|
||||
constant *= 0.5
|
||||
|
||||
@@ -208,11 +288,11 @@ class ACE15TEModel(torch.nn.Module):
|
||||
num_tokens += lm_metadata['min_tokens']
|
||||
return num_tokens * constant * 1024 * 1024
|
||||
|
||||
def te(dtype_llama=None, llama_quantization_metadata=None):
|
||||
def te(dtype_llama=None, llama_quantization_metadata=None, lm_model="qwen3_2b"):
|
||||
class ACE15TEModel_(ACE15TEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["llama_quantization_metadata"] = llama_quantization_metadata
|
||||
super().__init__(device=device, dtype_llama=dtype_llama, dtype=dtype, model_options=model_options)
|
||||
super().__init__(device=device, dtype_llama=dtype_llama, lm_model=lm_model, dtype=dtype, model_options=model_options)
|
||||
return ACE15TEModel_
|
||||
|
||||
@@ -6,6 +6,7 @@ import math
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
import comfy.model_management
|
||||
import comfy.ops
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.clip_model
|
||||
|
||||
@@ -149,6 +150,29 @@ class Qwen3_2B_ACE15_lm_Config:
|
||||
final_norm: bool = True
|
||||
lm_head: bool = False
|
||||
|
||||
@dataclass
|
||||
class Qwen3_4B_ACE15_lm_Config:
|
||||
vocab_size: int = 217204
|
||||
hidden_size: int = 2560
|
||||
intermediate_size: int = 9728
|
||||
num_hidden_layers: int = 36
|
||||
num_attention_heads: int = 32
|
||||
num_key_value_heads: int = 8
|
||||
max_position_embeddings: int = 40960
|
||||
rms_norm_eps: float = 1e-6
|
||||
rope_theta: float = 1000000.0
|
||||
transformer_type: str = "llama"
|
||||
head_dim = 128
|
||||
rms_norm_add = False
|
||||
mlp_activation = "silu"
|
||||
qkv_bias = False
|
||||
rope_dims = None
|
||||
q_norm = "gemma3"
|
||||
k_norm = "gemma3"
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
lm_head: bool = False
|
||||
|
||||
@dataclass
|
||||
class Qwen3_4BConfig:
|
||||
vocab_size: int = 151936
|
||||
@@ -627,10 +651,10 @@ class Llama2_(nn.Module):
|
||||
mask = None
|
||||
if attention_mask is not None:
|
||||
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, seq_len, attention_mask.shape[-1])
|
||||
mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
|
||||
mask = mask.masked_fill(mask.to(torch.bool), torch.finfo(x.dtype).min / 4)
|
||||
|
||||
if seq_len > 1:
|
||||
causal_mask = torch.empty(past_len + seq_len, past_len + seq_len, dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
|
||||
causal_mask = torch.empty(past_len + seq_len, past_len + seq_len, dtype=x.dtype, device=x.device).fill_(torch.finfo(x.dtype).min / 4).triu_(1)
|
||||
if mask is not None:
|
||||
mask += causal_mask
|
||||
else:
|
||||
@@ -738,6 +762,21 @@ class BaseLlama:
|
||||
def forward(self, input_ids, *args, **kwargs):
|
||||
return self.model(input_ids, *args, **kwargs)
|
||||
|
||||
class BaseQwen3:
|
||||
def logits(self, x):
|
||||
input = x[:, -1:]
|
||||
module = self.model.embed_tokens
|
||||
|
||||
offload_stream = None
|
||||
if module.comfy_cast_weights:
|
||||
weight, _, offload_stream = comfy.ops.cast_bias_weight(module, input, offloadable=True)
|
||||
else:
|
||||
weight = self.model.embed_tokens.weight.to(x)
|
||||
|
||||
x = torch.nn.functional.linear(input, weight, None)
|
||||
|
||||
comfy.ops.uncast_bias_weight(module, weight, None, offload_stream)
|
||||
return x
|
||||
|
||||
class Llama2(BaseLlama, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
@@ -766,7 +805,7 @@ class Qwen25_3B(BaseLlama, torch.nn.Module):
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class Qwen3_06B(BaseLlama, torch.nn.Module):
|
||||
class Qwen3_06B(BaseLlama, BaseQwen3, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = Qwen3_06BConfig(**config_dict)
|
||||
@@ -775,7 +814,7 @@ class Qwen3_06B(BaseLlama, torch.nn.Module):
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class Qwen3_06B_ACE15(BaseLlama, torch.nn.Module):
|
||||
class Qwen3_06B_ACE15(BaseLlama, BaseQwen3, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = Qwen3_06B_ACE15_Config(**config_dict)
|
||||
@@ -784,7 +823,7 @@ class Qwen3_06B_ACE15(BaseLlama, torch.nn.Module):
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class Qwen3_2B_ACE15_lm(BaseLlama, torch.nn.Module):
|
||||
class Qwen3_2B_ACE15_lm(BaseLlama, BaseQwen3, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = Qwen3_2B_ACE15_lm_Config(**config_dict)
|
||||
@@ -793,10 +832,7 @@ class Qwen3_2B_ACE15_lm(BaseLlama, torch.nn.Module):
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
def logits(self, x):
|
||||
return torch.nn.functional.linear(x[:, -1:], self.model.embed_tokens.weight.to(x), None)
|
||||
|
||||
class Qwen3_4B(BaseLlama, torch.nn.Module):
|
||||
class Qwen3_4B(BaseLlama, BaseQwen3, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = Qwen3_4BConfig(**config_dict)
|
||||
@@ -805,7 +841,16 @@ class Qwen3_4B(BaseLlama, torch.nn.Module):
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class Qwen3_8B(BaseLlama, torch.nn.Module):
|
||||
class Qwen3_4B_ACE15_lm(BaseLlama, BaseQwen3, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = Qwen3_4B_ACE15_lm_Config(**config_dict)
|
||||
self.num_layers = config.num_hidden_layers
|
||||
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class Qwen3_8B(BaseLlama, BaseQwen3, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = Qwen3_8BConfig(**config_dict)
|
||||
|
||||
+9
-5
@@ -82,14 +82,12 @@ _TYPES = {
|
||||
def load_safetensors(ckpt):
|
||||
f = open(ckpt, "rb")
|
||||
mapping = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
|
||||
mv = memoryview(mapping)
|
||||
|
||||
header_size = struct.unpack("<Q", mapping[:8])[0]
|
||||
header = json.loads(mapping[8:8+header_size].decode("utf-8"))
|
||||
|
||||
with warnings.catch_warnings():
|
||||
#We are working with read-only RAM by design
|
||||
warnings.filterwarnings("ignore", message="The given buffer is not writable")
|
||||
data_area = torch.frombuffer(mapping, dtype=torch.uint8)[8 + header_size:]
|
||||
mv = mv[8 + header_size:]
|
||||
|
||||
sd = {}
|
||||
for name, info in header.items():
|
||||
@@ -97,7 +95,13 @@ def load_safetensors(ckpt):
|
||||
continue
|
||||
|
||||
start, end = info["data_offsets"]
|
||||
sd[name] = data_area[start:end].view(_TYPES[info["dtype"]]).view(info["shape"])
|
||||
if start == end:
|
||||
sd[name] = torch.empty(info["shape"], dtype =_TYPES[info["dtype"]])
|
||||
else:
|
||||
with warnings.catch_warnings():
|
||||
#We are working with read-only RAM by design
|
||||
warnings.filterwarnings("ignore", message="The given buffer is not writable")
|
||||
sd[name] = torch.frombuffer(mv[start:end], dtype=_TYPES[info["dtype"]]).view(info["shape"])
|
||||
|
||||
return sd, header.get("__metadata__", {}),
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ from comfy_api.internal.singleton import ProxiedSingleton
|
||||
from comfy_api.internal.async_to_sync import create_sync_class
|
||||
from ._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
|
||||
from ._input_impl import VideoFromFile, VideoFromComponents
|
||||
from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL
|
||||
from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL, File3D
|
||||
from . import _io_public as io
|
||||
from . import _ui_public as ui
|
||||
from comfy_execution.utils import get_executing_context
|
||||
@@ -105,6 +105,7 @@ class Types:
|
||||
VideoComponents = VideoComponents
|
||||
MESH = MESH
|
||||
VOXEL = VOXEL
|
||||
File3D = File3D
|
||||
|
||||
ComfyAPI = ComfyAPI_latest
|
||||
|
||||
|
||||
+51
-1
@@ -27,7 +27,7 @@ if TYPE_CHECKING:
|
||||
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
|
||||
prune_dict, shallow_clone_class)
|
||||
from comfy_execution.graph_utils import ExecutionBlocker
|
||||
from ._util import MESH, VOXEL, SVG as _SVG
|
||||
from ._util import MESH, VOXEL, SVG as _SVG, File3D
|
||||
|
||||
|
||||
class FolderType(str, Enum):
|
||||
@@ -667,6 +667,49 @@ class Voxel(ComfyTypeIO):
|
||||
class Mesh(ComfyTypeIO):
|
||||
Type = MESH
|
||||
|
||||
|
||||
@comfytype(io_type="FILE_3D")
|
||||
class File3DAny(ComfyTypeIO):
|
||||
"""General 3D file type - accepts any supported 3D format."""
|
||||
Type = File3D
|
||||
|
||||
|
||||
@comfytype(io_type="FILE_3D_GLB")
|
||||
class File3DGLB(ComfyTypeIO):
|
||||
"""GLB format 3D file - binary glTF, best for web and cross-platform."""
|
||||
Type = File3D
|
||||
|
||||
|
||||
@comfytype(io_type="FILE_3D_GLTF")
|
||||
class File3DGLTF(ComfyTypeIO):
|
||||
"""GLTF format 3D file - JSON-based glTF with external resources."""
|
||||
Type = File3D
|
||||
|
||||
|
||||
@comfytype(io_type="FILE_3D_FBX")
|
||||
class File3DFBX(ComfyTypeIO):
|
||||
"""FBX format 3D file - best for game engines and animation."""
|
||||
Type = File3D
|
||||
|
||||
|
||||
@comfytype(io_type="FILE_3D_OBJ")
|
||||
class File3DOBJ(ComfyTypeIO):
|
||||
"""OBJ format 3D file - simple geometry format."""
|
||||
Type = File3D
|
||||
|
||||
|
||||
@comfytype(io_type="FILE_3D_STL")
|
||||
class File3DSTL(ComfyTypeIO):
|
||||
"""STL format 3D file - best for 3D printing."""
|
||||
Type = File3D
|
||||
|
||||
|
||||
@comfytype(io_type="FILE_3D_USDZ")
|
||||
class File3DUSDZ(ComfyTypeIO):
|
||||
"""USDZ format 3D file - Apple AR format."""
|
||||
Type = File3D
|
||||
|
||||
|
||||
@comfytype(io_type="HOOKS")
|
||||
class Hooks(ComfyTypeIO):
|
||||
if TYPE_CHECKING:
|
||||
@@ -2037,6 +2080,13 @@ __all__ = [
|
||||
"LossMap",
|
||||
"Voxel",
|
||||
"Mesh",
|
||||
"File3DAny",
|
||||
"File3DGLB",
|
||||
"File3DGLTF",
|
||||
"File3DFBX",
|
||||
"File3DOBJ",
|
||||
"File3DSTL",
|
||||
"File3DUSDZ",
|
||||
"Hooks",
|
||||
"HookKeyframes",
|
||||
"TimestepsRange",
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from .video_types import VideoContainer, VideoCodec, VideoComponents
|
||||
from .geometry_types import VOXEL, MESH
|
||||
from .geometry_types import VOXEL, MESH, File3D
|
||||
from .image_types import SVG
|
||||
|
||||
__all__ = [
|
||||
@@ -9,5 +9,6 @@ __all__ = [
|
||||
"VideoComponents",
|
||||
"VOXEL",
|
||||
"MESH",
|
||||
"File3D",
|
||||
"SVG",
|
||||
]
|
||||
|
||||
@@ -1,3 +1,8 @@
|
||||
import shutil
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import IO
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@@ -10,3 +15,75 @@ class MESH:
|
||||
def __init__(self, vertices: torch.Tensor, faces: torch.Tensor):
|
||||
self.vertices = vertices
|
||||
self.faces = faces
|
||||
|
||||
|
||||
class File3D:
|
||||
"""Class representing a 3D file from a file path or binary stream.
|
||||
|
||||
Supports both disk-backed (file path) and memory-backed (BytesIO) storage.
|
||||
"""
|
||||
|
||||
def __init__(self, source: str | IO[bytes], file_format: str = ""):
|
||||
self._source = source
|
||||
self._format = file_format or self._infer_format()
|
||||
|
||||
def _infer_format(self) -> str:
|
||||
if isinstance(self._source, str):
|
||||
return Path(self._source).suffix.lstrip(".").lower()
|
||||
return ""
|
||||
|
||||
@property
|
||||
def format(self) -> str:
|
||||
return self._format
|
||||
|
||||
@format.setter
|
||||
def format(self, value: str) -> None:
|
||||
self._format = value.lstrip(".").lower() if value else ""
|
||||
|
||||
@property
|
||||
def is_disk_backed(self) -> bool:
|
||||
return isinstance(self._source, str)
|
||||
|
||||
def get_source(self) -> str | IO[bytes]:
|
||||
if isinstance(self._source, str):
|
||||
return self._source
|
||||
if hasattr(self._source, "seek"):
|
||||
self._source.seek(0)
|
||||
return self._source
|
||||
|
||||
def get_data(self) -> BytesIO:
|
||||
if isinstance(self._source, str):
|
||||
with open(self._source, "rb") as f:
|
||||
result = BytesIO(f.read())
|
||||
return result
|
||||
if hasattr(self._source, "seek"):
|
||||
self._source.seek(0)
|
||||
if isinstance(self._source, BytesIO):
|
||||
return self._source
|
||||
return BytesIO(self._source.read())
|
||||
|
||||
def save_to(self, path: str) -> str:
|
||||
dest = Path(path)
|
||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if isinstance(self._source, str):
|
||||
if Path(self._source).resolve() != dest.resolve():
|
||||
shutil.copy2(self._source, dest)
|
||||
else:
|
||||
if hasattr(self._source, "seek"):
|
||||
self._source.seek(0)
|
||||
with open(dest, "wb") as f:
|
||||
f.write(self._source.read())
|
||||
return str(dest)
|
||||
|
||||
def get_bytes(self) -> bytes:
|
||||
if isinstance(self._source, str):
|
||||
return Path(self._source).read_bytes()
|
||||
if hasattr(self._source, "seek"):
|
||||
self._source.seek(0)
|
||||
return self._source.read()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
if isinstance(self._source, str):
|
||||
return f"File3D(source={self._source!r}, format={self._format!r})"
|
||||
return f"File3D(<stream>, format={self._format!r})"
|
||||
|
||||
@@ -109,14 +109,19 @@ class MeshyTextureRequest(BaseModel):
|
||||
|
||||
class MeshyModelsUrls(BaseModel):
|
||||
glb: str = Field("")
|
||||
fbx: str = Field("")
|
||||
usdz: str = Field("")
|
||||
obj: str = Field("")
|
||||
|
||||
|
||||
class MeshyRiggedModelsUrls(BaseModel):
|
||||
rigged_character_glb_url: str = Field("")
|
||||
rigged_character_fbx_url: str = Field("")
|
||||
|
||||
|
||||
class MeshyAnimatedModelsUrls(BaseModel):
|
||||
animation_glb_url: str = Field("")
|
||||
animation_fbx_url: str = Field("")
|
||||
|
||||
|
||||
class MeshyResultTextureUrls(BaseModel):
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
import os
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
@@ -14,7 +12,7 @@ from comfy_api_nodes.apis.hunyuan3d import (
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
download_url_to_bytesio,
|
||||
download_url_to_file_3d,
|
||||
downscale_image_tensor_by_max_side,
|
||||
poll_op,
|
||||
sync_op,
|
||||
@@ -22,14 +20,13 @@ from comfy_api_nodes.util import (
|
||||
validate_image_dimensions,
|
||||
validate_string,
|
||||
)
|
||||
from folder_paths import get_output_directory
|
||||
|
||||
|
||||
def get_glb_obj_from_response(response_objs: list[ResultFile3D]) -> ResultFile3D:
|
||||
def get_file_from_response(response_objs: list[ResultFile3D], file_type: str) -> ResultFile3D | None:
|
||||
for i in response_objs:
|
||||
if i.Type.lower() == "glb":
|
||||
if i.Type.lower() == file_type.lower():
|
||||
return i
|
||||
raise ValueError("No GLB file found in response. Please report this to the developers.")
|
||||
return None
|
||||
|
||||
|
||||
class TencentTextToModelNode(IO.ComfyNode):
|
||||
@@ -74,7 +71,9 @@ class TencentTextToModelNode(IO.ComfyNode):
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"),
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
IO.File3DOBJ.Output(display_name="OBJ"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
@@ -124,19 +123,20 @@ class TencentTextToModelNode(IO.ComfyNode):
|
||||
)
|
||||
if response.Error:
|
||||
raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}")
|
||||
task_id = response.JobId
|
||||
result = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-pro/query", method="POST"),
|
||||
data=To3DProTaskQueryRequest(JobId=response.JobId),
|
||||
data=To3DProTaskQueryRequest(JobId=task_id),
|
||||
response_model=To3DProTaskResultResponse,
|
||||
status_extractor=lambda r: r.Status,
|
||||
)
|
||||
model_file = f"hunyuan_model_{response.JobId}.glb"
|
||||
await download_url_to_bytesio(
|
||||
get_glb_obj_from_response(result.ResultFile3Ds).Url,
|
||||
os.path.join(get_output_directory(), model_file),
|
||||
glb_result = get_file_from_response(result.ResultFile3Ds, "glb")
|
||||
obj_result = get_file_from_response(result.ResultFile3Ds, "obj")
|
||||
file_glb = await download_url_to_file_3d(glb_result.Url, "glb", task_id=task_id) if glb_result else None
|
||||
return IO.NodeOutput(
|
||||
file_glb, file_glb, await download_url_to_file_3d(obj_result.Url, "obj", task_id=task_id) if obj_result else None
|
||||
)
|
||||
return IO.NodeOutput(model_file)
|
||||
|
||||
|
||||
class TencentImageToModelNode(IO.ComfyNode):
|
||||
@@ -184,7 +184,9 @@ class TencentImageToModelNode(IO.ComfyNode):
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"),
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
IO.File3DOBJ.Output(display_name="OBJ"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
@@ -269,19 +271,20 @@ class TencentImageToModelNode(IO.ComfyNode):
|
||||
)
|
||||
if response.Error:
|
||||
raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}")
|
||||
task_id = response.JobId
|
||||
result = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-pro/query", method="POST"),
|
||||
data=To3DProTaskQueryRequest(JobId=response.JobId),
|
||||
data=To3DProTaskQueryRequest(JobId=task_id),
|
||||
response_model=To3DProTaskResultResponse,
|
||||
status_extractor=lambda r: r.Status,
|
||||
)
|
||||
model_file = f"hunyuan_model_{response.JobId}.glb"
|
||||
await download_url_to_bytesio(
|
||||
get_glb_obj_from_response(result.ResultFile3Ds).Url,
|
||||
os.path.join(get_output_directory(), model_file),
|
||||
glb_result = get_file_from_response(result.ResultFile3Ds, "glb")
|
||||
obj_result = get_file_from_response(result.ResultFile3Ds, "obj")
|
||||
file_glb = await download_url_to_file_3d(glb_result.Url, "glb", task_id=task_id) if glb_result else None
|
||||
return IO.NodeOutput(
|
||||
file_glb, file_glb, await download_url_to_file_3d(obj_result.Url, "obj", task_id=task_id) if obj_result else None
|
||||
)
|
||||
return IO.NodeOutput(model_file)
|
||||
|
||||
|
||||
class TencentHunyuan3DExtension(ComfyExtension):
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
import os
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
@@ -20,13 +18,12 @@ from comfy_api_nodes.apis.meshy import (
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
download_url_to_bytesio,
|
||||
download_url_to_file_3d,
|
||||
poll_op,
|
||||
sync_op,
|
||||
upload_images_to_comfyapi,
|
||||
validate_string,
|
||||
)
|
||||
from folder_paths import get_output_directory
|
||||
|
||||
|
||||
class MeshyTextToModelNode(IO.ComfyNode):
|
||||
@@ -79,8 +76,10 @@ class MeshyTextToModelNode(IO.ComfyNode):
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"),
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
IO.Custom("MESHY_TASK_ID").Output(display_name="meshy_task_id"),
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
IO.File3DFBX.Output(display_name="FBX"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
@@ -122,16 +121,20 @@ class MeshyTextToModelNode(IO.ComfyNode):
|
||||
seed=seed,
|
||||
),
|
||||
)
|
||||
task_id = response.result
|
||||
result = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/meshy/openapi/v2/text-to-3d/{response.result}"),
|
||||
ApiEndpoint(path=f"/proxy/meshy/openapi/v2/text-to-3d/{task_id}"),
|
||||
response_model=MeshyModelResult,
|
||||
status_extractor=lambda r: r.status,
|
||||
progress_extractor=lambda r: r.progress,
|
||||
)
|
||||
model_file = f"meshy_model_{response.result}.glb"
|
||||
await download_url_to_bytesio(result.model_urls.glb, os.path.join(get_output_directory(), model_file))
|
||||
return IO.NodeOutput(model_file, response.result)
|
||||
return IO.NodeOutput(
|
||||
f"{task_id}.glb",
|
||||
task_id,
|
||||
await download_url_to_file_3d(result.model_urls.glb, "glb", task_id=task_id),
|
||||
await download_url_to_file_3d(result.model_urls.fbx, "fbx", task_id=task_id),
|
||||
)
|
||||
|
||||
|
||||
class MeshyRefineNode(IO.ComfyNode):
|
||||
@@ -167,8 +170,10 @@ class MeshyRefineNode(IO.ComfyNode):
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"),
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
IO.Custom("MESHY_TASK_ID").Output(display_name="meshy_task_id"),
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
IO.File3DFBX.Output(display_name="FBX"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
@@ -210,16 +215,20 @@ class MeshyRefineNode(IO.ComfyNode):
|
||||
ai_model=model,
|
||||
),
|
||||
)
|
||||
task_id = response.result
|
||||
result = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/meshy/openapi/v2/text-to-3d/{response.result}"),
|
||||
ApiEndpoint(path=f"/proxy/meshy/openapi/v2/text-to-3d/{task_id}"),
|
||||
response_model=MeshyModelResult,
|
||||
status_extractor=lambda r: r.status,
|
||||
progress_extractor=lambda r: r.progress,
|
||||
)
|
||||
model_file = f"meshy_model_{response.result}.glb"
|
||||
await download_url_to_bytesio(result.model_urls.glb, os.path.join(get_output_directory(), model_file))
|
||||
return IO.NodeOutput(model_file, response.result)
|
||||
return IO.NodeOutput(
|
||||
f"{task_id}.glb",
|
||||
task_id,
|
||||
await download_url_to_file_3d(result.model_urls.glb, "glb", task_id=task_id),
|
||||
await download_url_to_file_3d(result.model_urls.fbx, "fbx", task_id=task_id),
|
||||
)
|
||||
|
||||
|
||||
class MeshyImageToModelNode(IO.ComfyNode):
|
||||
@@ -303,8 +312,10 @@ class MeshyImageToModelNode(IO.ComfyNode):
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"),
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
IO.Custom("MESHY_TASK_ID").Output(display_name="meshy_task_id"),
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
IO.File3DFBX.Output(display_name="FBX"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
@@ -368,16 +379,20 @@ class MeshyImageToModelNode(IO.ComfyNode):
|
||||
seed=seed,
|
||||
),
|
||||
)
|
||||
task_id = response.result
|
||||
result = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/meshy/openapi/v1/image-to-3d/{response.result}"),
|
||||
ApiEndpoint(path=f"/proxy/meshy/openapi/v1/image-to-3d/{task_id}"),
|
||||
response_model=MeshyModelResult,
|
||||
status_extractor=lambda r: r.status,
|
||||
progress_extractor=lambda r: r.progress,
|
||||
)
|
||||
model_file = f"meshy_model_{response.result}.glb"
|
||||
await download_url_to_bytesio(result.model_urls.glb, os.path.join(get_output_directory(), model_file))
|
||||
return IO.NodeOutput(model_file, response.result)
|
||||
return IO.NodeOutput(
|
||||
f"{task_id}.glb",
|
||||
task_id,
|
||||
await download_url_to_file_3d(result.model_urls.glb, "glb", task_id=task_id),
|
||||
await download_url_to_file_3d(result.model_urls.fbx, "fbx", task_id=task_id),
|
||||
)
|
||||
|
||||
|
||||
class MeshyMultiImageToModelNode(IO.ComfyNode):
|
||||
@@ -464,8 +479,10 @@ class MeshyMultiImageToModelNode(IO.ComfyNode):
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"),
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
IO.Custom("MESHY_TASK_ID").Output(display_name="meshy_task_id"),
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
IO.File3DFBX.Output(display_name="FBX"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
@@ -531,16 +548,20 @@ class MeshyMultiImageToModelNode(IO.ComfyNode):
|
||||
seed=seed,
|
||||
),
|
||||
)
|
||||
task_id = response.result
|
||||
result = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/meshy/openapi/v1/multi-image-to-3d/{response.result}"),
|
||||
ApiEndpoint(path=f"/proxy/meshy/openapi/v1/multi-image-to-3d/{task_id}"),
|
||||
response_model=MeshyModelResult,
|
||||
status_extractor=lambda r: r.status,
|
||||
progress_extractor=lambda r: r.progress,
|
||||
)
|
||||
model_file = f"meshy_model_{response.result}.glb"
|
||||
await download_url_to_bytesio(result.model_urls.glb, os.path.join(get_output_directory(), model_file))
|
||||
return IO.NodeOutput(model_file, response.result)
|
||||
return IO.NodeOutput(
|
||||
f"{task_id}.glb",
|
||||
task_id,
|
||||
await download_url_to_file_3d(result.model_urls.glb, "glb", task_id=task_id),
|
||||
await download_url_to_file_3d(result.model_urls.fbx, "fbx", task_id=task_id),
|
||||
)
|
||||
|
||||
|
||||
class MeshyRigModelNode(IO.ComfyNode):
|
||||
@@ -571,8 +592,10 @@ class MeshyRigModelNode(IO.ComfyNode):
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"),
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
IO.Custom("MESHY_RIGGED_TASK_ID").Output(display_name="rig_task_id"),
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
IO.File3DFBX.Output(display_name="FBX"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
@@ -606,18 +629,20 @@ class MeshyRigModelNode(IO.ComfyNode):
|
||||
texture_image_url=texture_image_url,
|
||||
),
|
||||
)
|
||||
task_id = response.result
|
||||
result = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/meshy/openapi/v1/rigging/{response.result}"),
|
||||
ApiEndpoint(path=f"/proxy/meshy/openapi/v1/rigging/{task_id}"),
|
||||
response_model=MeshyRiggedResult,
|
||||
status_extractor=lambda r: r.status,
|
||||
progress_extractor=lambda r: r.progress,
|
||||
)
|
||||
model_file = f"meshy_model_{response.result}.glb"
|
||||
await download_url_to_bytesio(
|
||||
result.result.rigged_character_glb_url, os.path.join(get_output_directory(), model_file)
|
||||
return IO.NodeOutput(
|
||||
f"{task_id}.glb",
|
||||
task_id,
|
||||
await download_url_to_file_3d(result.result.rigged_character_glb_url, "glb", task_id=task_id),
|
||||
await download_url_to_file_3d(result.result.rigged_character_fbx_url, "fbx", task_id=task_id),
|
||||
)
|
||||
return IO.NodeOutput(model_file, response.result)
|
||||
|
||||
|
||||
class MeshyAnimateModelNode(IO.ComfyNode):
|
||||
@@ -640,7 +665,9 @@ class MeshyAnimateModelNode(IO.ComfyNode):
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"),
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
IO.File3DFBX.Output(display_name="FBX"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
@@ -669,16 +696,19 @@ class MeshyAnimateModelNode(IO.ComfyNode):
|
||||
action_id=action_id,
|
||||
),
|
||||
)
|
||||
task_id = response.result
|
||||
result = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/meshy/openapi/v1/animations/{response.result}"),
|
||||
ApiEndpoint(path=f"/proxy/meshy/openapi/v1/animations/{task_id}"),
|
||||
response_model=MeshyAnimationResult,
|
||||
status_extractor=lambda r: r.status,
|
||||
progress_extractor=lambda r: r.progress,
|
||||
)
|
||||
model_file = f"meshy_model_{response.result}.glb"
|
||||
await download_url_to_bytesio(result.result.animation_glb_url, os.path.join(get_output_directory(), model_file))
|
||||
return IO.NodeOutput(model_file, response.result)
|
||||
return IO.NodeOutput(
|
||||
f"{task_id}.glb",
|
||||
await download_url_to_file_3d(result.result.animation_glb_url, "glb", task_id=task_id),
|
||||
await download_url_to_file_3d(result.result.animation_fbx_url, "fbx", task_id=task_id),
|
||||
)
|
||||
|
||||
|
||||
class MeshyTextureNode(IO.ComfyNode):
|
||||
@@ -715,8 +745,10 @@ class MeshyTextureNode(IO.ComfyNode):
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"),
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
IO.Custom("MODEL_TASK_ID").Output(display_name="meshy_task_id"),
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
IO.File3DFBX.Output(display_name="FBX"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
@@ -760,16 +792,20 @@ class MeshyTextureNode(IO.ComfyNode):
|
||||
image_style_url=image_style_url,
|
||||
),
|
||||
)
|
||||
task_id = response.result
|
||||
result = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/meshy/openapi/v1/retexture/{response.result}"),
|
||||
ApiEndpoint(path=f"/proxy/meshy/openapi/v1/retexture/{task_id}"),
|
||||
response_model=MeshyModelResult,
|
||||
status_extractor=lambda r: r.status,
|
||||
progress_extractor=lambda r: r.progress,
|
||||
)
|
||||
model_file = f"meshy_model_{response.result}.glb"
|
||||
await download_url_to_bytesio(result.model_urls.glb, os.path.join(get_output_directory(), model_file))
|
||||
return IO.NodeOutput(model_file, response.result)
|
||||
return IO.NodeOutput(
|
||||
f"{task_id}.glb",
|
||||
task_id,
|
||||
await download_url_to_file_3d(result.model_urls.glb, "glb", task_id=task_id),
|
||||
await download_url_to_file_3d(result.model_urls.fbx, "fbx", task_id=task_id),
|
||||
)
|
||||
|
||||
|
||||
class MeshyExtension(ComfyExtension):
|
||||
|
||||
@@ -10,7 +10,6 @@ import folder_paths as comfy_paths
|
||||
import os
|
||||
import logging
|
||||
import math
|
||||
from typing import Optional
|
||||
from io import BytesIO
|
||||
from typing_extensions import override
|
||||
from PIL import Image
|
||||
@@ -28,8 +27,9 @@ from comfy_api_nodes.util import (
|
||||
poll_op,
|
||||
ApiEndpoint,
|
||||
download_url_to_bytesio,
|
||||
download_url_to_file_3d,
|
||||
)
|
||||
from comfy_api.latest import ComfyExtension, IO
|
||||
from comfy_api.latest import ComfyExtension, IO, Types
|
||||
|
||||
|
||||
COMMON_PARAMETERS = [
|
||||
@@ -177,7 +177,7 @@ def check_rodin_status(response: Rodin3DCheckStatusResponse) -> str:
|
||||
return "DONE"
|
||||
return "Generating"
|
||||
|
||||
def extract_progress(response: Rodin3DCheckStatusResponse) -> Optional[int]:
|
||||
def extract_progress(response: Rodin3DCheckStatusResponse) -> int | None:
|
||||
if not response.jobs:
|
||||
return None
|
||||
completed_count = sum(1 for job in response.jobs if job.status == JobStatus.Done)
|
||||
@@ -207,17 +207,25 @@ async def get_rodin_download_list(uuid: str, cls: type[IO.ComfyNode]) -> Rodin3D
|
||||
)
|
||||
|
||||
|
||||
async def download_files(url_list, task_uuid: str):
|
||||
async def download_files(url_list, task_uuid: str) -> tuple[str | None, Types.File3D | None]:
|
||||
result_folder_name = f"Rodin3D_{task_uuid}"
|
||||
save_path = os.path.join(comfy_paths.get_output_directory(), result_folder_name)
|
||||
os.makedirs(save_path, exist_ok=True)
|
||||
model_file_path = None
|
||||
file_3d = None
|
||||
|
||||
for i in url_list.list:
|
||||
file_path = os.path.join(save_path, i.name)
|
||||
if file_path.endswith(".glb"):
|
||||
if i.name.lower().endswith(".glb"):
|
||||
model_file_path = os.path.join(result_folder_name, i.name)
|
||||
await download_url_to_bytesio(i.url, file_path)
|
||||
return model_file_path
|
||||
file_3d = await download_url_to_file_3d(i.url, "glb")
|
||||
# Save to disk for backward compatibility
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(file_3d.get_bytes())
|
||||
else:
|
||||
await download_url_to_bytesio(i.url, file_path)
|
||||
|
||||
return model_file_path, file_3d
|
||||
|
||||
|
||||
class Rodin3D_Regular(IO.ComfyNode):
|
||||
@@ -234,7 +242,10 @@ class Rodin3D_Regular(IO.ComfyNode):
|
||||
IO.Image.Input("Images"),
|
||||
*COMMON_PARAMETERS,
|
||||
],
|
||||
outputs=[IO.String.Output(display_name="3D Model Path")],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="3D Model Path"), # for backward compatibility only
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
@@ -271,9 +282,9 @@ class Rodin3D_Regular(IO.ComfyNode):
|
||||
)
|
||||
await poll_for_task_status(subscription_key, cls)
|
||||
download_list = await get_rodin_download_list(task_uuid, cls)
|
||||
model = await download_files(download_list, task_uuid)
|
||||
model_path, file_3d = await download_files(download_list, task_uuid)
|
||||
|
||||
return IO.NodeOutput(model)
|
||||
return IO.NodeOutput(model_path, file_3d)
|
||||
|
||||
|
||||
class Rodin3D_Detail(IO.ComfyNode):
|
||||
@@ -290,7 +301,10 @@ class Rodin3D_Detail(IO.ComfyNode):
|
||||
IO.Image.Input("Images"),
|
||||
*COMMON_PARAMETERS,
|
||||
],
|
||||
outputs=[IO.String.Output(display_name="3D Model Path")],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="3D Model Path"), # for backward compatibility only
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
@@ -327,9 +341,9 @@ class Rodin3D_Detail(IO.ComfyNode):
|
||||
)
|
||||
await poll_for_task_status(subscription_key, cls)
|
||||
download_list = await get_rodin_download_list(task_uuid, cls)
|
||||
model = await download_files(download_list, task_uuid)
|
||||
model_path, file_3d = await download_files(download_list, task_uuid)
|
||||
|
||||
return IO.NodeOutput(model)
|
||||
return IO.NodeOutput(model_path, file_3d)
|
||||
|
||||
|
||||
class Rodin3D_Smooth(IO.ComfyNode):
|
||||
@@ -346,7 +360,10 @@ class Rodin3D_Smooth(IO.ComfyNode):
|
||||
IO.Image.Input("Images"),
|
||||
*COMMON_PARAMETERS,
|
||||
],
|
||||
outputs=[IO.String.Output(display_name="3D Model Path")],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="3D Model Path"), # for backward compatibility only
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
@@ -382,9 +399,9 @@ class Rodin3D_Smooth(IO.ComfyNode):
|
||||
)
|
||||
await poll_for_task_status(subscription_key, cls)
|
||||
download_list = await get_rodin_download_list(task_uuid, cls)
|
||||
model = await download_files(download_list, task_uuid)
|
||||
model_path, file_3d = await download_files(download_list, task_uuid)
|
||||
|
||||
return IO.NodeOutput(model)
|
||||
return IO.NodeOutput(model_path, file_3d)
|
||||
|
||||
|
||||
class Rodin3D_Sketch(IO.ComfyNode):
|
||||
@@ -408,7 +425,10 @@ class Rodin3D_Sketch(IO.ComfyNode):
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[IO.String.Output(display_name="3D Model Path")],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="3D Model Path"), # for backward compatibility only
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
@@ -441,9 +461,9 @@ class Rodin3D_Sketch(IO.ComfyNode):
|
||||
)
|
||||
await poll_for_task_status(subscription_key, cls)
|
||||
download_list = await get_rodin_download_list(task_uuid, cls)
|
||||
model = await download_files(download_list, task_uuid)
|
||||
model_path, file_3d = await download_files(download_list, task_uuid)
|
||||
|
||||
return IO.NodeOutput(model)
|
||||
return IO.NodeOutput(model_path, file_3d)
|
||||
|
||||
|
||||
class Rodin3D_Gen2(IO.ComfyNode):
|
||||
@@ -475,7 +495,10 @@ class Rodin3D_Gen2(IO.ComfyNode):
|
||||
),
|
||||
IO.Boolean.Input("TAPose", default=False),
|
||||
],
|
||||
outputs=[IO.String.Output(display_name="3D Model Path")],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="3D Model Path"), # for backward compatibility only
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
@@ -511,9 +534,9 @@ class Rodin3D_Gen2(IO.ComfyNode):
|
||||
)
|
||||
await poll_for_task_status(subscription_key, cls)
|
||||
download_list = await get_rodin_download_list(task_uuid, cls)
|
||||
model = await download_files(download_list, task_uuid)
|
||||
model_path, file_3d = await download_files(download_list, task_uuid)
|
||||
|
||||
return IO.NodeOutput(model)
|
||||
return IO.NodeOutput(model_path, file_3d)
|
||||
|
||||
|
||||
class Rodin3DExtension(ComfyExtension):
|
||||
|
||||
@@ -1,10 +1,6 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis.tripo import (
|
||||
TripoAnimateRetargetRequest,
|
||||
TripoAnimateRigRequest,
|
||||
@@ -26,12 +22,11 @@ from comfy_api_nodes.apis.tripo import (
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
download_url_as_bytesio,
|
||||
download_url_to_file_3d,
|
||||
poll_op,
|
||||
sync_op,
|
||||
upload_images_to_comfyapi,
|
||||
)
|
||||
from folder_paths import get_output_directory
|
||||
|
||||
|
||||
def get_model_url_from_response(response: TripoTaskResponse) -> str:
|
||||
@@ -45,7 +40,7 @@ def get_model_url_from_response(response: TripoTaskResponse) -> str:
|
||||
async def poll_until_finished(
|
||||
node_cls: type[IO.ComfyNode],
|
||||
response: TripoTaskResponse,
|
||||
average_duration: Optional[int] = None,
|
||||
average_duration: int | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
"""Polls the Tripo API endpoint until the task reaches a terminal state, then returns the response."""
|
||||
if response.code != 0:
|
||||
@@ -69,12 +64,8 @@ async def poll_until_finished(
|
||||
)
|
||||
if response_poll.data.status == TripoTaskStatus.SUCCESS:
|
||||
url = get_model_url_from_response(response_poll)
|
||||
bytesio = await download_url_as_bytesio(url)
|
||||
# Save the downloaded model file
|
||||
model_file = f"tripo_model_{task_id}.glb"
|
||||
with open(os.path.join(get_output_directory(), model_file), "wb") as f:
|
||||
f.write(bytesio.getvalue())
|
||||
return IO.NodeOutput(model_file, task_id)
|
||||
file_glb = await download_url_to_file_3d(url, "glb", task_id=task_id)
|
||||
return IO.NodeOutput(f"{task_id}.glb", task_id, file_glb)
|
||||
raise RuntimeError(f"Failed to generate mesh: {response_poll}")
|
||||
|
||||
|
||||
@@ -107,8 +98,9 @@ class TripoTextToModelNode(IO.ComfyNode):
|
||||
IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"),
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"),
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
@@ -155,18 +147,18 @@ class TripoTextToModelNode(IO.ComfyNode):
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
negative_prompt: Optional[str] = None,
|
||||
negative_prompt: str | None = None,
|
||||
model_version=None,
|
||||
style: Optional[str] = None,
|
||||
texture: Optional[bool] = None,
|
||||
pbr: Optional[bool] = None,
|
||||
image_seed: Optional[int] = None,
|
||||
model_seed: Optional[int] = None,
|
||||
texture_seed: Optional[int] = None,
|
||||
texture_quality: Optional[str] = None,
|
||||
geometry_quality: Optional[str] = None,
|
||||
face_limit: Optional[int] = None,
|
||||
quad: Optional[bool] = None,
|
||||
style: str | None = None,
|
||||
texture: bool | None = None,
|
||||
pbr: bool | None = None,
|
||||
image_seed: int | None = None,
|
||||
model_seed: int | None = None,
|
||||
texture_seed: int | None = None,
|
||||
texture_quality: str | None = None,
|
||||
geometry_quality: str | None = None,
|
||||
face_limit: int | None = None,
|
||||
quad: bool | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
style_enum = None if style == "None" else style
|
||||
if not prompt:
|
||||
@@ -232,8 +224,9 @@ class TripoImageToModelNode(IO.ComfyNode):
|
||||
IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"),
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"),
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
@@ -279,19 +272,19 @@ class TripoImageToModelNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image: torch.Tensor,
|
||||
model_version: Optional[str] = None,
|
||||
style: Optional[str] = None,
|
||||
texture: Optional[bool] = None,
|
||||
pbr: Optional[bool] = None,
|
||||
model_seed: Optional[int] = None,
|
||||
image: Input.Image,
|
||||
model_version: str | None = None,
|
||||
style: str | None = None,
|
||||
texture: bool | None = None,
|
||||
pbr: bool | None = None,
|
||||
model_seed: int | None = None,
|
||||
orientation=None,
|
||||
texture_seed: Optional[int] = None,
|
||||
texture_quality: Optional[str] = None,
|
||||
geometry_quality: Optional[str] = None,
|
||||
texture_alignment: Optional[str] = None,
|
||||
face_limit: Optional[int] = None,
|
||||
quad: Optional[bool] = None,
|
||||
texture_seed: int | None = None,
|
||||
texture_quality: str | None = None,
|
||||
geometry_quality: str | None = None,
|
||||
texture_alignment: str | None = None,
|
||||
face_limit: int | None = None,
|
||||
quad: bool | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
style_enum = None if style == "None" else style
|
||||
if image is None:
|
||||
@@ -368,8 +361,9 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
|
||||
IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"),
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"),
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
@@ -411,21 +405,21 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image: torch.Tensor,
|
||||
image_left: Optional[torch.Tensor] = None,
|
||||
image_back: Optional[torch.Tensor] = None,
|
||||
image_right: Optional[torch.Tensor] = None,
|
||||
model_version: Optional[str] = None,
|
||||
orientation: Optional[str] = None,
|
||||
texture: Optional[bool] = None,
|
||||
pbr: Optional[bool] = None,
|
||||
model_seed: Optional[int] = None,
|
||||
texture_seed: Optional[int] = None,
|
||||
texture_quality: Optional[str] = None,
|
||||
geometry_quality: Optional[str] = None,
|
||||
texture_alignment: Optional[str] = None,
|
||||
face_limit: Optional[int] = None,
|
||||
quad: Optional[bool] = None,
|
||||
image: Input.Image,
|
||||
image_left: Input.Image | None = None,
|
||||
image_back: Input.Image | None = None,
|
||||
image_right: Input.Image | None = None,
|
||||
model_version: str | None = None,
|
||||
orientation: str | None = None,
|
||||
texture: bool | None = None,
|
||||
pbr: bool | None = None,
|
||||
model_seed: int | None = None,
|
||||
texture_seed: int | None = None,
|
||||
texture_quality: str | None = None,
|
||||
geometry_quality: str | None = None,
|
||||
texture_alignment: str | None = None,
|
||||
face_limit: int | None = None,
|
||||
quad: bool | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
if image is None:
|
||||
raise RuntimeError("front image for multiview is required")
|
||||
@@ -487,8 +481,9 @@ class TripoTextureNode(IO.ComfyNode):
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"),
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"),
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
@@ -512,11 +507,11 @@ class TripoTextureNode(IO.ComfyNode):
|
||||
async def execute(
|
||||
cls,
|
||||
model_task_id,
|
||||
texture: Optional[bool] = None,
|
||||
pbr: Optional[bool] = None,
|
||||
texture_seed: Optional[int] = None,
|
||||
texture_quality: Optional[str] = None,
|
||||
texture_alignment: Optional[str] = None,
|
||||
texture: bool | None = None,
|
||||
pbr: bool | None = None,
|
||||
texture_seed: int | None = None,
|
||||
texture_quality: str | None = None,
|
||||
texture_alignment: str | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
response = await sync_op(
|
||||
cls,
|
||||
@@ -547,8 +542,9 @@ class TripoRefineNode(IO.ComfyNode):
|
||||
IO.Custom("MODEL_TASK_ID").Input("model_task_id", tooltip="Must be a v1.4 Tripo model"),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"),
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"),
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
@@ -583,8 +579,9 @@ class TripoRigNode(IO.ComfyNode):
|
||||
category="api node/3d/Tripo",
|
||||
inputs=[IO.Custom("MODEL_TASK_ID").Input("original_model_task_id")],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"),
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
IO.Custom("RIG_TASK_ID").Output(display_name="rig task_id"),
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
@@ -642,8 +639,9 @@ class TripoRetargetNode(IO.ComfyNode):
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"),
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
IO.Custom("RETARGET_TASK_ID").Output(display_name="retarget task_id"),
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
|
||||
@@ -28,6 +28,7 @@ from .conversions import (
|
||||
from .download_helpers import (
|
||||
download_url_as_bytesio,
|
||||
download_url_to_bytesio,
|
||||
download_url_to_file_3d,
|
||||
download_url_to_image_tensor,
|
||||
download_url_to_video_output,
|
||||
)
|
||||
@@ -69,6 +70,7 @@ __all__ = [
|
||||
# Download helpers
|
||||
"download_url_as_bytesio",
|
||||
"download_url_to_bytesio",
|
||||
"download_url_to_file_3d",
|
||||
"download_url_to_image_tensor",
|
||||
"download_url_to_video_output",
|
||||
# Conversions
|
||||
|
||||
@@ -11,7 +11,8 @@ import torch
|
||||
from aiohttp.client_exceptions import ClientError, ContentTypeError
|
||||
|
||||
from comfy_api.latest import IO as COMFY_IO
|
||||
from comfy_api.latest import InputImpl
|
||||
from comfy_api.latest import InputImpl, Types
|
||||
from folder_paths import get_output_directory
|
||||
|
||||
from . import request_logger
|
||||
from ._helpers import (
|
||||
@@ -261,3 +262,38 @@ def _generate_operation_id(method: str, url: str, attempt: int) -> str:
|
||||
except Exception:
|
||||
slug = "download"
|
||||
return f"{method}_{slug}_try{attempt}_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
|
||||
async def download_url_to_file_3d(
|
||||
url: str,
|
||||
file_format: str,
|
||||
*,
|
||||
task_id: str | None = None,
|
||||
timeout: float | None = None,
|
||||
max_retries: int = 5,
|
||||
cls: type[COMFY_IO.ComfyNode] = None,
|
||||
) -> Types.File3D:
|
||||
"""Downloads a 3D model file from a URL into memory as BytesIO.
|
||||
|
||||
If task_id is provided, also writes the file to disk in the output directory
|
||||
for backward compatibility with the old save-to-disk behavior.
|
||||
"""
|
||||
file_format = file_format.lstrip(".").lower()
|
||||
data = BytesIO()
|
||||
await download_url_to_bytesio(
|
||||
url,
|
||||
data,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
cls=cls,
|
||||
)
|
||||
|
||||
if task_id is not None:
|
||||
# This is only for backward compatability with current behavior when every 3D node is output node
|
||||
# All new API nodes should not use "task_id" and instead users should use "SaveGLB" node to save results
|
||||
output_dir = Path(get_output_directory())
|
||||
output_path = output_dir / f"{task_id}.{file_format}"
|
||||
output_path.write_bytes(data.getvalue())
|
||||
data.seek(0)
|
||||
|
||||
return Types.File3D(source=data, file_format=file_format)
|
||||
|
||||
@@ -44,13 +44,18 @@ class TextEncodeAceStepAudio15(io.ComfyNode):
|
||||
io.Combo.Input("timesignature", options=['2', '3', '4', '6']),
|
||||
io.Combo.Input("language", options=["en", "ja", "zh", "es", "de", "fr", "pt", "ru", "it", "nl", "pl", "tr", "vi", "cs", "fa", "id", "ko", "uk", "hu", "ar", "sv", "ro", "el"]),
|
||||
io.Combo.Input("keyscale", options=[f"{root} {quality}" for quality in ["major", "minor"] for root in ["C", "C#", "Db", "D", "D#", "Eb", "E", "F", "F#", "Gb", "G", "G#", "Ab", "A", "A#", "Bb", "B"]]),
|
||||
io.Boolean.Input("generate_audio_codes", default=True, tooltip="Enable the LLM that generates audio codes. This can be slow but will increase the quality of the generated audio. Turn this off if you are giving the model an audio reference.", advanced=True),
|
||||
io.Float.Input("cfg_scale", default=2.0, min=0.0, max=100.0, step=0.1, advanced=True),
|
||||
io.Float.Input("temperature", default=0.85, min=0.0, max=2.0, step=0.01, advanced=True),
|
||||
io.Float.Input("top_p", default=0.9, min=0.0, max=2000.0, step=0.01, advanced=True),
|
||||
io.Int.Input("top_k", default=0, min=0, max=100, advanced=True),
|
||||
],
|
||||
outputs=[io.Conditioning.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale) -> io.NodeOutput:
|
||||
tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed)
|
||||
def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k) -> io.NodeOutput:
|
||||
tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed, generate_audio_codes=generate_audio_codes, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k)
|
||||
conditioning = clip.encode_from_tokens_scheduled(tokens)
|
||||
return io.NodeOutput(conditioning)
|
||||
|
||||
@@ -100,14 +105,15 @@ class EmptyAceStep15LatentAudio(io.ComfyNode):
|
||||
latent = torch.zeros([batch_size, 64, length], device=comfy.model_management.intermediate_device())
|
||||
return io.NodeOutput({"samples": latent, "type": "audio"})
|
||||
|
||||
class ReferenceTimbreAudio(io.ComfyNode):
|
||||
class ReferenceAudio(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ReferenceTimbreAudio",
|
||||
display_name="Reference Audio",
|
||||
category="advanced/conditioning/audio",
|
||||
is_experimental=True,
|
||||
description="This node sets the reference audio for timbre (for ace step 1.5)",
|
||||
description="This node sets the reference audio for ace step 1.5",
|
||||
inputs=[
|
||||
io.Conditioning.Input("conditioning"),
|
||||
io.Latent.Input("latent", optional=True),
|
||||
@@ -131,7 +137,7 @@ class AceExtension(ComfyExtension):
|
||||
EmptyAceStepLatentAudio,
|
||||
TextEncodeAceStepAudio15,
|
||||
EmptyAceStep15LatentAudio,
|
||||
ReferenceTimbreAudio,
|
||||
ReferenceAudio,
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> AceExtension:
|
||||
|
||||
@@ -94,6 +94,19 @@ class VAEEncodeAudio(IO.ComfyNode):
|
||||
encode = execute # TODO: remove
|
||||
|
||||
|
||||
def vae_decode_audio(vae, samples, tile=None, overlap=None):
|
||||
if tile is not None:
|
||||
audio = vae.decode_tiled(samples["samples"], tile_y=tile, overlap=overlap).movedim(-1, 1)
|
||||
else:
|
||||
audio = vae.decode(samples["samples"]).movedim(-1, 1)
|
||||
|
||||
std = torch.std(audio, dim=[1, 2], keepdim=True) * 5.0
|
||||
std[std < 1.0] = 1.0
|
||||
audio /= std
|
||||
vae_sample_rate = getattr(vae, "audio_sample_rate", 44100)
|
||||
return {"waveform": audio, "sample_rate": vae_sample_rate if "sample_rate" not in samples else samples["sample_rate"]}
|
||||
|
||||
|
||||
class VAEDecodeAudio(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
@@ -111,16 +124,33 @@ class VAEDecodeAudio(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def execute(cls, vae, samples) -> IO.NodeOutput:
|
||||
audio = vae.decode(samples["samples"]).movedim(-1, 1)
|
||||
std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0
|
||||
std[std < 1.0] = 1.0
|
||||
audio /= std
|
||||
vae_sample_rate = getattr(vae, "audio_sample_rate", 44100)
|
||||
return IO.NodeOutput({"waveform": audio, "sample_rate": vae_sample_rate if "sample_rate" not in samples else samples["sample_rate"]})
|
||||
return IO.NodeOutput(vae_decode_audio(vae, samples))
|
||||
|
||||
decode = execute # TODO: remove
|
||||
|
||||
|
||||
class VAEDecodeAudioTiled(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="VAEDecodeAudioTiled",
|
||||
search_aliases=["latent to audio"],
|
||||
display_name="VAE Decode Audio (Tiled)",
|
||||
category="latent/audio",
|
||||
inputs=[
|
||||
IO.Latent.Input("samples"),
|
||||
IO.Vae.Input("vae"),
|
||||
IO.Int.Input("tile_size", default=512, min=32, max=8192, step=8),
|
||||
IO.Int.Input("overlap", default=64, min=0, max=1024, step=8),
|
||||
],
|
||||
outputs=[IO.Audio.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, vae, samples, tile_size, overlap) -> IO.NodeOutput:
|
||||
return IO.NodeOutput(vae_decode_audio(vae, samples, tile_size, overlap))
|
||||
|
||||
|
||||
class SaveAudio(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
@@ -675,6 +705,7 @@ class AudioExtension(ComfyExtension):
|
||||
EmptyLatentAudio,
|
||||
VAEEncodeAudio,
|
||||
VAEDecodeAudio,
|
||||
VAEDecodeAudioTiled,
|
||||
SaveAudio,
|
||||
SaveAudioMP3,
|
||||
SaveAudioOpus,
|
||||
|
||||
@@ -618,18 +618,31 @@ class SaveGLB(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="SaveGLB",
|
||||
display_name="Save 3D Model",
|
||||
search_aliases=["export 3d model", "save mesh"],
|
||||
category="3d",
|
||||
is_output_node=True,
|
||||
inputs=[
|
||||
IO.Mesh.Input("mesh"),
|
||||
IO.MultiType.Input(
|
||||
IO.Mesh.Input("mesh"),
|
||||
types=[
|
||||
IO.File3DGLB,
|
||||
IO.File3DGLTF,
|
||||
IO.File3DOBJ,
|
||||
IO.File3DFBX,
|
||||
IO.File3DSTL,
|
||||
IO.File3DUSDZ,
|
||||
IO.File3DAny,
|
||||
],
|
||||
tooltip="Mesh or 3D file to save",
|
||||
),
|
||||
IO.String.Input("filename_prefix", default="mesh/ComfyUI"),
|
||||
],
|
||||
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, mesh, filename_prefix) -> IO.NodeOutput:
|
||||
def execute(cls, mesh: Types.MESH | Types.File3D, filename_prefix: str) -> IO.NodeOutput:
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory())
|
||||
results = []
|
||||
|
||||
@@ -641,15 +654,27 @@ class SaveGLB(IO.ComfyNode):
|
||||
for x in cls.hidden.extra_pnginfo:
|
||||
metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x])
|
||||
|
||||
for i in range(mesh.vertices.shape[0]):
|
||||
f = f"{filename}_{counter:05}_.glb"
|
||||
save_glb(mesh.vertices[i], mesh.faces[i], os.path.join(full_output_folder, f), metadata)
|
||||
if isinstance(mesh, Types.File3D):
|
||||
# Handle File3D input - save BytesIO data to output folder
|
||||
ext = mesh.format or "glb"
|
||||
f = f"{filename}_{counter:05}_.{ext}"
|
||||
mesh.save_to(os.path.join(full_output_folder, f))
|
||||
results.append({
|
||||
"filename": f,
|
||||
"subfolder": subfolder,
|
||||
"type": "output"
|
||||
})
|
||||
counter += 1
|
||||
else:
|
||||
# Handle Mesh input - save vertices and faces as GLB
|
||||
for i in range(mesh.vertices.shape[0]):
|
||||
f = f"{filename}_{counter:05}_.glb"
|
||||
save_glb(mesh.vertices[i], mesh.faces[i], os.path.join(full_output_folder, f), metadata)
|
||||
results.append({
|
||||
"filename": f,
|
||||
"subfolder": subfolder,
|
||||
"type": "output"
|
||||
})
|
||||
counter += 1
|
||||
return IO.NodeOutput(ui={"3d": results})
|
||||
|
||||
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import nodes
|
||||
import folder_paths
|
||||
import os
|
||||
import uuid
|
||||
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import IO, ComfyExtension, InputImpl, UI
|
||||
from comfy_api.latest import IO, UI, ComfyExtension, InputImpl, Types
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
@@ -44,6 +45,7 @@ class Load3D(IO.ComfyNode):
|
||||
IO.Image.Output(display_name="normal"),
|
||||
IO.Load3DCamera.Output(display_name="camera_info"),
|
||||
IO.Video.Output(display_name="recording_video"),
|
||||
IO.File3DAny.Output(display_name="model_3d"),
|
||||
],
|
||||
)
|
||||
|
||||
@@ -65,7 +67,8 @@ class Load3D(IO.ComfyNode):
|
||||
|
||||
video = InputImpl.VideoFromFile(recording_video_path)
|
||||
|
||||
return IO.NodeOutput(output_image, output_mask, model_file, normal_image, image['camera_info'], video)
|
||||
file_3d = Types.File3D(folder_paths.get_annotated_filepath(model_file))
|
||||
return IO.NodeOutput(output_image, output_mask, model_file, normal_image, image['camera_info'], video, file_3d)
|
||||
|
||||
process = execute # TODO: remove
|
||||
|
||||
@@ -81,7 +84,19 @@ class Preview3D(IO.ComfyNode):
|
||||
is_experimental=True,
|
||||
is_output_node=True,
|
||||
inputs=[
|
||||
IO.String.Input("model_file", default="", multiline=False),
|
||||
IO.MultiType.Input(
|
||||
IO.String.Input("model_file", default="", multiline=False),
|
||||
types=[
|
||||
IO.File3DGLB,
|
||||
IO.File3DGLTF,
|
||||
IO.File3DFBX,
|
||||
IO.File3DOBJ,
|
||||
IO.File3DSTL,
|
||||
IO.File3DUSDZ,
|
||||
IO.File3DAny,
|
||||
],
|
||||
tooltip="3D model file or path string",
|
||||
),
|
||||
IO.Load3DCamera.Input("camera_info", optional=True),
|
||||
IO.Image.Input("bg_image", optional=True),
|
||||
],
|
||||
@@ -89,10 +104,15 @@ class Preview3D(IO.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model_file, **kwargs) -> IO.NodeOutput:
|
||||
def execute(cls, model_file: str | Types.File3D, **kwargs) -> IO.NodeOutput:
|
||||
if isinstance(model_file, Types.File3D):
|
||||
filename = f"preview3d_{uuid.uuid4().hex}.{model_file.format}"
|
||||
model_file.save_to(os.path.join(folder_paths.get_output_directory(), filename))
|
||||
else:
|
||||
filename = model_file
|
||||
camera_info = kwargs.get("camera_info", None)
|
||||
bg_image = kwargs.get("bg_image", None)
|
||||
return IO.NodeOutput(ui=UI.PreviewUI3D(model_file, camera_info, bg_image=bg_image))
|
||||
return IO.NodeOutput(ui=UI.PreviewUI3D(filename, camera_info, bg_image=bg_image))
|
||||
|
||||
process = execute # TODO: remove
|
||||
|
||||
|
||||
+1
-1
@@ -1,3 +1,3 @@
|
||||
# This file is automatically generated by the build process when version is
|
||||
# updated in pyproject.toml.
|
||||
__version__ = "0.12.0"
|
||||
__version__ = "0.12.3"
|
||||
|
||||
@@ -192,7 +192,10 @@ import comfy_aimdo.control
|
||||
import comfy_aimdo.torch
|
||||
|
||||
if enables_dynamic_vram():
|
||||
if comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index):
|
||||
if comfy.model_management.torch_version_numeric < (2, 8):
|
||||
logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows")
|
||||
comfy.memory_management.aimdo_allocator = None
|
||||
elif comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index):
|
||||
if args.verbose == 'DEBUG':
|
||||
comfy_aimdo.control.set_log_debug()
|
||||
elif args.verbose == 'CRITICAL':
|
||||
@@ -208,7 +211,7 @@ if enables_dynamic_vram():
|
||||
comfy.memory_management.aimdo_allocator = comfy_aimdo.torch.get_torch_allocator()
|
||||
logging.info("DynamicVRAM support detected and enabled")
|
||||
else:
|
||||
logging.info("No working comfy-aimdo install detected. DynamicVRAM support disabled. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows")
|
||||
logging.warning("No working comfy-aimdo install detected. DynamicVRAM support disabled. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows")
|
||||
comfy.memory_management.aimdo_allocator = None
|
||||
|
||||
|
||||
|
||||
+1
-1
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "ComfyUI"
|
||||
version = "0.12.0"
|
||||
version = "0.12.3"
|
||||
readme = "README.md"
|
||||
license = { file = "LICENSE" }
|
||||
requires-python = ">=3.10"
|
||||
|
||||
Reference in New Issue
Block a user