Compare commits
97 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 72212fef66 | |||
| df34f1549a | |||
| 9b0553809c | |||
| 8d7c930246 | |||
| de44b95db6 | |||
| 543888d3d8 | |||
| 70fc0425b3 | |||
| 85e34643f8 | |||
| 5c33872e2f | |||
| 206595f854 | |||
| b288fb0db8 | |||
| f73b176abd | |||
| 103a12cb66 | |||
| 97652d26b8 | |||
| bd1d9bcd5f | |||
| fb763d4333 | |||
| bcbd7884e3 | |||
| 27a0fcccc3 | |||
| ea6cdd2631 | |||
| 2ee7879a0b | |||
| 3493b9cb1f | |||
| c9ebe70072 | |||
| 261421e218 | |||
| a9f1bb10a5 | |||
| b0338e930b | |||
| b71f9bcb71 | |||
| 72855db715 | |||
| f48d05a2d1 | |||
| 4368d8f87f | |||
| 22da0a83e9 | |||
| 50333f1715 | |||
| 26d5b86da8 | |||
| 4f5812b937 | |||
| 1bcb469089 | |||
| 464ba1d614 | |||
| e3018c2a5a | |||
| 3412d53b1d | |||
| e2d1e5dad9 | |||
| 27e067ce50 | |||
| 9b15155972 | |||
| 32a627bf1f | |||
| fe442fac2e | |||
| d2c502e629 | |||
| fea9ea8268 | |||
| f949094b3c | |||
| 4449e14769 | |||
| 885015eecf | |||
| a86aaa4301 | |||
| 2efb2cbc38 | |||
| 15aa9222c4 | |||
| c7bb3e2bce | |||
| e80a14ad50 | |||
| d28b39d93d | |||
| 1c184c29eb | |||
| edde0b5043 | |||
| 0063610177 | |||
| ce0052c087 | |||
| 0eb821a7b6 | |||
| 4aa79dbf2c | |||
| 38f697d953 | |||
| 3aad339b63 | |||
| 491755325c | |||
| 496888fd68 | |||
| b5ac6ed7ce | |||
| b20ba1f27c | |||
| 31a37686d0 | |||
| 88aee596a3 | |||
| 6a193ac557 | |||
| 47f4db3e84 | |||
| 5352abc6d3 | |||
| 39aa06bd5d | |||
| 914c2a2973 | |||
| e633a47ad1 | |||
| f6b93d41a0 | |||
| 95ac7794b7 | |||
| 71ed4a399e | |||
| 3e316c6338 | |||
| 8be0d22ab7 | |||
| 59eddda900 | |||
| 41048c69b4 | |||
| fc247150fe | |||
| fe31ad0276 | |||
| ca4e96a8ae | |||
| 050c67323c | |||
| 497d41fb50 | |||
| ff57793659 | |||
| f7bd5e58dd | |||
| 7ed73d12d1 | |||
| eb39019daa | |||
| bab08f40d1 | |||
| bc49106837 | |||
| 1b2de2642d | |||
| 9fa1036f60 | |||
| 0737b7e0d2 | |||
| 0963493a9c | |||
| e73a9dbe30 | |||
| fe01885acf |
@@ -0,0 +1,30 @@
|
||||
name: Execution Tests
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main, master ]
|
||||
pull_request:
|
||||
branches: [ main, master ]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest, windows-latest, macos-latest]
|
||||
runs-on: ${{ matrix.os }}
|
||||
continue-on-error: true
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.12'
|
||||
- name: Install requirements
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
|
||||
pip install -r requirements.txt
|
||||
pip install -r tests-unit/requirements.txt
|
||||
- name: Run Execution Tests
|
||||
run: |
|
||||
python -m pytest tests/execution -v --skip-timing-checks
|
||||
@@ -65,18 +65,17 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
|
||||
- [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
|
||||
- [Lumina Image 2.0](https://comfyanonymous.github.io/ComfyUI_examples/lumina2/)
|
||||
- [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/)
|
||||
- [Cosmos Predict2](https://comfyanonymous.github.io/ComfyUI_examples/cosmos_predict2/)
|
||||
- [Qwen Image](https://comfyanonymous.github.io/ComfyUI_examples/qwen_image/)
|
||||
- Image Editing Models
|
||||
- [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/)
|
||||
- [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model)
|
||||
- [HiDream E1.1](https://comfyanonymous.github.io/ComfyUI_examples/hidream/#hidream-e11)
|
||||
- [Qwen Image Edit](https://comfyanonymous.github.io/ComfyUI_examples/qwen_image/#edit-model)
|
||||
- Video Models
|
||||
- [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/)
|
||||
- [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
|
||||
- [LTX-Video](https://comfyanonymous.github.io/ComfyUI_examples/ltxv/)
|
||||
- [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
|
||||
- [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/) and [Cosmos Predict2](https://comfyanonymous.github.io/ComfyUI_examples/cosmos_predict2/)
|
||||
- [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/)
|
||||
- [Wan 2.2](https://comfyanonymous.github.io/ComfyUI_examples/wan22/)
|
||||
- Audio Models
|
||||
@@ -191,7 +190,7 @@ comfy install
|
||||
|
||||
## Manual Install (Windows, Linux)
|
||||
|
||||
python 3.13 is supported but using 3.12 is recommended because some custom nodes and their dependencies might not support it yet.
|
||||
Python 3.13 is very well supported. If you have trouble with some custom node dependencies you can try 3.12
|
||||
|
||||
Git clone this repo.
|
||||
|
||||
|
||||
+10
-3
@@ -363,10 +363,17 @@ class UserManager():
|
||||
if not overwrite and os.path.exists(path):
|
||||
return web.Response(status=409, text="File already exists")
|
||||
|
||||
body = await request.read()
|
||||
try:
|
||||
body = await request.read()
|
||||
|
||||
with open(path, "wb") as f:
|
||||
f.write(body)
|
||||
with open(path, "wb") as f:
|
||||
f.write(body)
|
||||
except OSError as e:
|
||||
logging.warning(f"Error saving file '{path}': {e}")
|
||||
return web.Response(
|
||||
status=400,
|
||||
reason="Invalid filename. Please avoid special characters like :\\/*?\"<>|"
|
||||
)
|
||||
|
||||
user_path = self.get_request_user_filepath(request, None)
|
||||
if full_info:
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
from .wav2vec2 import Wav2Vec2Model
|
||||
import comfy.model_management
|
||||
import comfy.ops
|
||||
import comfy.utils
|
||||
import logging
|
||||
import torchaudio
|
||||
|
||||
|
||||
class AudioEncoderModel():
|
||||
def __init__(self, config):
|
||||
self.load_device = comfy.model_management.text_encoder_device()
|
||||
offload_device = comfy.model_management.text_encoder_offload_device()
|
||||
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
|
||||
self.model = Wav2Vec2Model(dtype=self.dtype, device=offload_device, operations=comfy.ops.manual_cast)
|
||||
self.model.eval()
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||
self.model_sample_rate = 16000
|
||||
|
||||
def load_sd(self, sd):
|
||||
return self.model.load_state_dict(sd, strict=False)
|
||||
|
||||
def get_sd(self):
|
||||
return self.model.state_dict()
|
||||
|
||||
def encode_audio(self, audio, sample_rate):
|
||||
comfy.model_management.load_model_gpu(self.patcher)
|
||||
audio = torchaudio.functional.resample(audio, sample_rate, self.model_sample_rate)
|
||||
out, all_layers = self.model(audio.to(self.load_device))
|
||||
outputs = {}
|
||||
outputs["encoded_audio"] = out
|
||||
outputs["encoded_audio_all_layers"] = all_layers
|
||||
return outputs
|
||||
|
||||
|
||||
def load_audio_encoder_from_sd(sd, prefix=""):
|
||||
audio_encoder = AudioEncoderModel(None)
|
||||
sd = comfy.utils.state_dict_prefix_replace(sd, {"wav2vec2.": ""})
|
||||
m, u = audio_encoder.load_sd(sd)
|
||||
if len(m) > 0:
|
||||
logging.warning("missing audio encoder: {}".format(m))
|
||||
|
||||
return audio_encoder
|
||||
@@ -0,0 +1,207 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from comfy.ldm.modules.attention import optimized_attention_masked
|
||||
|
||||
|
||||
class LayerNormConv(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride, bias=False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.conv = operations.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=bias, device=device, dtype=dtype)
|
||||
self.layer_norm = operations.LayerNorm(out_channels, elementwise_affine=True, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
return torch.nn.functional.gelu(self.layer_norm(x.transpose(-2, -1)).transpose(-2, -1))
|
||||
|
||||
|
||||
class ConvFeatureEncoder(nn.Module):
|
||||
def __init__(self, conv_dim, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.conv_layers = nn.ModuleList([
|
||||
LayerNormConv(1, conv_dim, kernel_size=10, stride=5, bias=True, device=device, dtype=dtype, operations=operations),
|
||||
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
|
||||
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
|
||||
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
|
||||
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
|
||||
LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
|
||||
LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
|
||||
])
|
||||
|
||||
def forward(self, x):
|
||||
x = x.unsqueeze(1)
|
||||
|
||||
for conv in self.conv_layers:
|
||||
x = conv(x)
|
||||
|
||||
return x.transpose(1, 2)
|
||||
|
||||
|
||||
class FeatureProjection(nn.Module):
|
||||
def __init__(self, conv_dim, embed_dim, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.layer_norm = operations.LayerNorm(conv_dim, eps=1e-05, device=device, dtype=dtype)
|
||||
self.projection = operations.Linear(conv_dim, embed_dim, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.layer_norm(x)
|
||||
x = self.projection(x)
|
||||
return x
|
||||
|
||||
|
||||
class PositionalConvEmbedding(nn.Module):
|
||||
def __init__(self, embed_dim=768, kernel_size=128, groups=16):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv1d(
|
||||
embed_dim,
|
||||
embed_dim,
|
||||
kernel_size=kernel_size,
|
||||
padding=kernel_size // 2,
|
||||
groups=groups,
|
||||
)
|
||||
self.conv = torch.nn.utils.parametrizations.weight_norm(self.conv, name="weight", dim=2)
|
||||
self.activation = nn.GELU()
|
||||
|
||||
def forward(self, x):
|
||||
x = x.transpose(1, 2)
|
||||
x = self.conv(x)[:, :, :-1]
|
||||
x = self.activation(x)
|
||||
x = x.transpose(1, 2)
|
||||
return x
|
||||
|
||||
|
||||
class TransformerEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim=768,
|
||||
num_heads=12,
|
||||
num_layers=12,
|
||||
mlp_ratio=4.0,
|
||||
dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.pos_conv_embed = PositionalConvEmbedding(embed_dim=embed_dim)
|
||||
self.layers = nn.ModuleList([
|
||||
TransformerEncoderLayer(
|
||||
embed_dim=embed_dim,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
device=device, dtype=dtype, operations=operations
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
|
||||
self.layer_norm = operations.LayerNorm(embed_dim, eps=1e-05, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
x = x + self.pos_conv_embed(x)
|
||||
all_x = ()
|
||||
for layer in self.layers:
|
||||
all_x += (x,)
|
||||
x = layer(x, mask)
|
||||
x = self.layer_norm(x)
|
||||
all_x += (x,)
|
||||
return x, all_x
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, embed_dim, num_heads, bias=True, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = embed_dim // num_heads
|
||||
|
||||
self.k_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
|
||||
self.v_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
|
||||
self.q_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
|
||||
self.out_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
assert (mask is None) # TODO?
|
||||
q = self.q_proj(x)
|
||||
k = self.k_proj(x)
|
||||
v = self.v_proj(x)
|
||||
|
||||
out = optimized_attention_masked(q, k, v, self.num_heads)
|
||||
return self.out_proj(out)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, embed_dim, mlp_ratio, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.intermediate_dense = operations.Linear(embed_dim, int(embed_dim * mlp_ratio), device=device, dtype=dtype)
|
||||
self.output_dense = operations.Linear(int(embed_dim * mlp_ratio), embed_dim, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.intermediate_dense(x)
|
||||
x = torch.nn.functional.gelu(x)
|
||||
x = self.output_dense(x)
|
||||
return x
|
||||
|
||||
|
||||
class TransformerEncoderLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim=768,
|
||||
num_heads=12,
|
||||
mlp_ratio=4.0,
|
||||
dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.attention = Attention(embed_dim, num_heads, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
self.layer_norm = operations.LayerNorm(embed_dim, device=device, dtype=dtype)
|
||||
self.feed_forward = FeedForward(embed_dim, mlp_ratio, device=device, dtype=dtype, operations=operations)
|
||||
self.final_layer_norm = operations.LayerNorm(embed_dim, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
residual = x
|
||||
x = self.layer_norm(x)
|
||||
x = self.attention(x, mask=mask)
|
||||
x = residual + x
|
||||
|
||||
x = x + self.feed_forward(self.final_layer_norm(x))
|
||||
return x
|
||||
|
||||
|
||||
class Wav2Vec2Model(nn.Module):
|
||||
"""Complete Wav2Vec 2.0 model."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim=1024,
|
||||
final_dim=256,
|
||||
num_heads=16,
|
||||
num_layers=24,
|
||||
dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
conv_dim = 512
|
||||
self.feature_extractor = ConvFeatureEncoder(conv_dim, device=device, dtype=dtype, operations=operations)
|
||||
self.feature_projection = FeatureProjection(conv_dim, embed_dim, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
self.masked_spec_embed = nn.Parameter(torch.empty(embed_dim, device=device, dtype=dtype))
|
||||
|
||||
self.encoder = TransformerEncoder(
|
||||
embed_dim=embed_dim,
|
||||
num_heads=num_heads,
|
||||
num_layers=num_layers,
|
||||
device=device, dtype=dtype, operations=operations
|
||||
)
|
||||
|
||||
def forward(self, x, mask_time_indices=None, return_dict=False):
|
||||
|
||||
x = torch.mean(x, dim=1)
|
||||
|
||||
x = (x - x.mean()) / torch.sqrt(x.var() + 1e-7)
|
||||
|
||||
features = self.feature_extractor(x)
|
||||
features = self.feature_projection(features)
|
||||
|
||||
batch_size, seq_len, _ = features.shape
|
||||
|
||||
x, all_x = self.encoder(features)
|
||||
|
||||
return x, all_x
|
||||
+2
-1
@@ -143,8 +143,9 @@ class PerformanceFeature(enum.Enum):
|
||||
Fp16Accumulation = "fp16_accumulation"
|
||||
Fp8MatrixMultiplication = "fp8_matrix_mult"
|
||||
CublasOps = "cublas_ops"
|
||||
AutoTune = "autotune"
|
||||
|
||||
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult cublas_ops")
|
||||
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
|
||||
|
||||
parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.")
|
||||
parser.add_argument("--disable-mmap", action="store_true", help="Don't use mmap when loading safetensors.")
|
||||
|
||||
+11
-1
@@ -61,8 +61,12 @@ class CLIPEncoder(torch.nn.Module):
|
||||
def forward(self, x, mask=None, intermediate_output=None):
|
||||
optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)
|
||||
|
||||
all_intermediate = None
|
||||
if intermediate_output is not None:
|
||||
if intermediate_output < 0:
|
||||
if intermediate_output == "all":
|
||||
all_intermediate = []
|
||||
intermediate_output = None
|
||||
elif intermediate_output < 0:
|
||||
intermediate_output = len(self.layers) + intermediate_output
|
||||
|
||||
intermediate = None
|
||||
@@ -70,6 +74,12 @@ class CLIPEncoder(torch.nn.Module):
|
||||
x = l(x, mask, optimized_attention)
|
||||
if i == intermediate_output:
|
||||
intermediate = x.clone()
|
||||
if all_intermediate is not None:
|
||||
all_intermediate.append(x.unsqueeze(1).clone())
|
||||
|
||||
if all_intermediate is not None:
|
||||
intermediate = torch.cat(all_intermediate, dim=1)
|
||||
|
||||
return x, intermediate
|
||||
|
||||
class CLIPEmbeddings(torch.nn.Module):
|
||||
|
||||
+20
-4
@@ -50,7 +50,13 @@ class ClipVisionModel():
|
||||
self.image_size = config.get("image_size", 224)
|
||||
self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
|
||||
self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
|
||||
model_class = IMAGE_ENCODERS.get(config.get("model_type", "clip_vision_model"))
|
||||
model_type = config.get("model_type", "clip_vision_model")
|
||||
model_class = IMAGE_ENCODERS.get(model_type)
|
||||
if model_type == "siglip_vision_model":
|
||||
self.return_all_hidden_states = True
|
||||
else:
|
||||
self.return_all_hidden_states = False
|
||||
|
||||
self.load_device = comfy.model_management.text_encoder_device()
|
||||
offload_device = comfy.model_management.text_encoder_offload_device()
|
||||
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
|
||||
@@ -68,12 +74,18 @@ class ClipVisionModel():
|
||||
def encode_image(self, image, crop=True):
|
||||
comfy.model_management.load_model_gpu(self.patcher)
|
||||
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
|
||||
out = self.model(pixel_values=pixel_values, intermediate_output=-2)
|
||||
out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2)
|
||||
|
||||
outputs = Output()
|
||||
outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device())
|
||||
outputs["image_embeds"] = out[2].to(comfy.model_management.intermediate_device())
|
||||
outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device())
|
||||
if self.return_all_hidden_states:
|
||||
all_hs = out[1].to(comfy.model_management.intermediate_device())
|
||||
outputs["penultimate_hidden_states"] = all_hs[:, -2]
|
||||
outputs["all_hidden_states"] = all_hs
|
||||
else:
|
||||
outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device())
|
||||
|
||||
outputs["mm_projected"] = out[3]
|
||||
return outputs
|
||||
|
||||
@@ -124,8 +136,12 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
|
||||
else:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
|
||||
elif "embeddings.patch_embeddings.projection.weight" in sd:
|
||||
|
||||
# Dinov2
|
||||
elif 'encoder.layer.39.layer_scale2.lambda1' in sd:
|
||||
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json")
|
||||
elif 'encoder.layer.23.layer_scale2.lambda1' in sd:
|
||||
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_large.json")
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
+26
-3
@@ -36,6 +36,7 @@ import comfy.ldm.cascade.controlnet
|
||||
import comfy.cldm.mmdit
|
||||
import comfy.ldm.hydit.controlnet
|
||||
import comfy.ldm.flux.controlnet
|
||||
import comfy.ldm.qwen_image.controlnet
|
||||
import comfy.cldm.dit_embedder
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
@@ -236,11 +237,11 @@ class ControlNet(ControlBase):
|
||||
self.cond_hint = None
|
||||
compression_ratio = self.compression_ratio
|
||||
if self.vae is not None:
|
||||
compression_ratio *= self.vae.downscale_ratio
|
||||
compression_ratio *= self.vae.spacial_compression_encode()
|
||||
else:
|
||||
if self.latent_format is not None:
|
||||
raise ValueError("This Controlnet needs a VAE but none was provided, please use a ControlNetApply node with a VAE input and connect it.")
|
||||
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
|
||||
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[-1] * compression_ratio, x_noisy.shape[-2] * compression_ratio, self.upscale_algorithm, "center")
|
||||
self.cond_hint = self.preprocess_image(self.cond_hint)
|
||||
if self.vae is not None:
|
||||
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
|
||||
@@ -252,7 +253,10 @@ class ControlNet(ControlBase):
|
||||
to_concat = []
|
||||
for c in self.extra_concat_orig:
|
||||
c = c.to(self.cond_hint.device)
|
||||
c = comfy.utils.common_upscale(c, self.cond_hint.shape[3], self.cond_hint.shape[2], self.upscale_algorithm, "center")
|
||||
c = comfy.utils.common_upscale(c, self.cond_hint.shape[-1], self.cond_hint.shape[-2], self.upscale_algorithm, "center")
|
||||
if c.ndim < self.cond_hint.ndim:
|
||||
c = c.unsqueeze(2)
|
||||
c = comfy.utils.repeat_to_batch_size(c, self.cond_hint.shape[2], dim=2)
|
||||
to_concat.append(comfy.utils.repeat_to_batch_size(c, self.cond_hint.shape[0]))
|
||||
self.cond_hint = torch.cat([self.cond_hint] + to_concat, dim=1)
|
||||
|
||||
@@ -582,6 +586,22 @@ def load_controlnet_flux_instantx(sd, model_options={}):
|
||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||
return control
|
||||
|
||||
def load_controlnet_qwen_instantx(sd, model_options={}):
|
||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
|
||||
control_latent_channels = sd.get("controlnet_x_embedder.weight").shape[1]
|
||||
|
||||
extra_condition_channels = 0
|
||||
concat_mask = False
|
||||
if control_latent_channels == 68: #inpaint controlnet
|
||||
extra_condition_channels = control_latent_channels - 64
|
||||
concat_mask = True
|
||||
control_model = comfy.ldm.qwen_image.controlnet.QwenImageControlNetModel(extra_condition_channels=extra_condition_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
||||
control_model = controlnet_load_state_dict(control_model, sd)
|
||||
latent_format = comfy.latent_formats.Wan21()
|
||||
extra_conds = []
|
||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||
return control
|
||||
|
||||
def convert_mistoline(sd):
|
||||
return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
|
||||
|
||||
@@ -655,8 +675,11 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
|
||||
return load_controlnet_sd35(controlnet_data, model_options=model_options) #Stability sd3.5 format
|
||||
else:
|
||||
return load_controlnet_mmdit(controlnet_data, model_options=model_options) #SD3 diffusers controlnet
|
||||
elif "transformer_blocks.0.img_mlp.net.0.proj.weight" in controlnet_data:
|
||||
return load_controlnet_qwen_instantx(controlnet_data, model_options=model_options)
|
||||
elif "controlnet_x_embedder.weight" in controlnet_data:
|
||||
return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
|
||||
|
||||
elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
|
||||
return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options)
|
||||
|
||||
|
||||
@@ -31,6 +31,20 @@ class LayerScale(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return x * comfy.model_management.cast_to_device(self.lambda1, x.device, x.dtype)
|
||||
|
||||
class Dinov2MLP(torch.nn.Module):
|
||||
def __init__(self, hidden_size: int, dtype, device, operations):
|
||||
super().__init__()
|
||||
|
||||
mlp_ratio = 4
|
||||
hidden_features = int(hidden_size * mlp_ratio)
|
||||
self.fc1 = operations.Linear(hidden_size, hidden_features, bias = True, device=device, dtype=dtype)
|
||||
self.fc2 = operations.Linear(hidden_features, hidden_size, bias = True, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
|
||||
hidden_state = self.fc1(hidden_state)
|
||||
hidden_state = torch.nn.functional.gelu(hidden_state)
|
||||
hidden_state = self.fc2(hidden_state)
|
||||
return hidden_state
|
||||
|
||||
class SwiGLUFFN(torch.nn.Module):
|
||||
def __init__(self, dim, dtype, device, operations):
|
||||
@@ -50,12 +64,15 @@ class SwiGLUFFN(torch.nn.Module):
|
||||
|
||||
|
||||
class Dino2Block(torch.nn.Module):
|
||||
def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations):
|
||||
def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn):
|
||||
super().__init__()
|
||||
self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations)
|
||||
self.layer_scale1 = LayerScale(dim, dtype, device, operations)
|
||||
self.layer_scale2 = LayerScale(dim, dtype, device, operations)
|
||||
self.mlp = SwiGLUFFN(dim, dtype, device, operations)
|
||||
if use_swiglu_ffn:
|
||||
self.mlp = SwiGLUFFN(dim, dtype, device, operations)
|
||||
else:
|
||||
self.mlp = Dinov2MLP(dim, dtype, device, operations)
|
||||
self.norm1 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
|
||||
self.norm2 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
|
||||
|
||||
@@ -66,9 +83,10 @@ class Dino2Block(torch.nn.Module):
|
||||
|
||||
|
||||
class Dino2Encoder(torch.nn.Module):
|
||||
def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations):
|
||||
def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn):
|
||||
super().__init__()
|
||||
self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations) for _ in range(num_layers)])
|
||||
self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn)
|
||||
for _ in range(num_layers)])
|
||||
|
||||
def forward(self, x, intermediate_output=None):
|
||||
optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
|
||||
@@ -78,8 +96,8 @@ class Dino2Encoder(torch.nn.Module):
|
||||
intermediate_output = len(self.layer) + intermediate_output
|
||||
|
||||
intermediate = None
|
||||
for i, l in enumerate(self.layer):
|
||||
x = l(x, optimized_attention)
|
||||
for i, layer in enumerate(self.layer):
|
||||
x = layer(x, optimized_attention)
|
||||
if i == intermediate_output:
|
||||
intermediate = x.clone()
|
||||
return x, intermediate
|
||||
@@ -128,9 +146,10 @@ class Dinov2Model(torch.nn.Module):
|
||||
dim = config_dict["hidden_size"]
|
||||
heads = config_dict["num_attention_heads"]
|
||||
layer_norm_eps = config_dict["layer_norm_eps"]
|
||||
use_swiglu_ffn = config_dict["use_swiglu_ffn"]
|
||||
|
||||
self.embeddings = Dino2Embeddings(dim, dtype, device, operations)
|
||||
self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations)
|
||||
self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn)
|
||||
self.layernorm = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
{
|
||||
"hidden_size": 1024,
|
||||
"use_mask_token": true,
|
||||
"patch_size": 14,
|
||||
"image_size": 518,
|
||||
"num_channels": 3,
|
||||
"num_attention_heads": 16,
|
||||
"initializer_range": 0.02,
|
||||
"attention_probs_dropout_prob": 0.0,
|
||||
"hidden_dropout_prob": 0.0,
|
||||
"hidden_act": "gelu",
|
||||
"mlp_ratio": 4,
|
||||
"model_type": "dinov2",
|
||||
"num_hidden_layers": 24,
|
||||
"layer_norm_eps": 1e-6,
|
||||
"qkv_bias": true,
|
||||
"use_swiglu_ffn": false,
|
||||
"layerscale_value": 1.0,
|
||||
"drop_path_rate": 0.0,
|
||||
"image_mean": [0.485, 0.456, 0.406],
|
||||
"image_std": [0.229, 0.224, 0.225]
|
||||
}
|
||||
@@ -171,6 +171,16 @@ def offset_first_sigma_for_snr(sigmas, model_sampling, percent_offset=1e-4):
|
||||
return sigmas
|
||||
|
||||
|
||||
def ei_h_phi_1(h: torch.Tensor) -> torch.Tensor:
|
||||
"""Compute the result of h*phi_1(h) in exponential integrator methods."""
|
||||
return torch.expm1(h)
|
||||
|
||||
|
||||
def ei_h_phi_2(h: torch.Tensor) -> torch.Tensor:
|
||||
"""Compute the result of h*phi_2(h) in exponential integrator methods."""
|
||||
return (torch.expm1(h) - h) / h
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
|
||||
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
|
||||
@@ -853,6 +863,11 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_2m_sde_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='heun'):
|
||||
return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
"""DPM-Solver++(3M) SDE."""
|
||||
@@ -925,6 +940,16 @@ def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
|
||||
return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_2m_sde_heun_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='heun'):
|
||||
if len(sigmas) <= 1:
|
||||
return x
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
|
||||
return sample_dpmpp_2m_sde_heun(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
|
||||
if len(sigmas) <= 1:
|
||||
@@ -1535,13 +1560,12 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
|
||||
@torch.no_grad()
|
||||
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
|
||||
"""SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2.
|
||||
arXiv: https://arxiv.org/abs/2305.14267
|
||||
arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)
|
||||
"""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
inject_noise = eta > 0 and s_noise > 0
|
||||
|
||||
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
|
||||
@@ -1549,55 +1573,53 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
|
||||
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
|
||||
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
|
||||
|
||||
fac = 1 / (2 * r)
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
|
||||
if sigmas[i + 1] == 0:
|
||||
x = denoised
|
||||
else:
|
||||
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
|
||||
h = lambda_t - lambda_s
|
||||
h_eta = h * (eta + 1)
|
||||
lambda_s_1 = lambda_s + r * h
|
||||
fac = 1 / (2 * r)
|
||||
sigma_s_1 = sigma_fn(lambda_s_1)
|
||||
continue
|
||||
|
||||
# alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t)
|
||||
alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
|
||||
alpha_t = sigmas[i + 1] * lambda_t.exp()
|
||||
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
|
||||
h = lambda_t - lambda_s
|
||||
h_eta = h * (eta + 1)
|
||||
lambda_s_1 = torch.lerp(lambda_s, lambda_t, r)
|
||||
sigma_s_1 = sigma_fn(lambda_s_1)
|
||||
|
||||
coeff_1, coeff_2 = (-r * h_eta).expm1(), (-h_eta).expm1()
|
||||
if inject_noise:
|
||||
# 0 < r < 1
|
||||
noise_coeff_1 = (-2 * r * h * eta).expm1().neg().sqrt()
|
||||
noise_coeff_2 = (-r * h * eta).exp() * (-2 * (1 - r) * h * eta).expm1().neg().sqrt()
|
||||
noise_1, noise_2 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigmas[i + 1])
|
||||
alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
|
||||
alpha_t = sigmas[i + 1] * lambda_t.exp()
|
||||
|
||||
# Step 1
|
||||
x_2 = sigma_s_1 / sigmas[i] * (-r * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised
|
||||
if inject_noise:
|
||||
x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise
|
||||
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
|
||||
# Step 1
|
||||
x_2 = sigma_s_1 / sigmas[i] * (-r * h * eta).exp() * x - alpha_s_1 * ei_h_phi_1(-r * h_eta) * denoised
|
||||
if inject_noise:
|
||||
sde_noise = (-2 * r * h * eta).expm1().neg().sqrt() * noise_sampler(sigmas[i], sigma_s_1)
|
||||
x_2 = x_2 + sde_noise * sigma_s_1 * s_noise
|
||||
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
|
||||
|
||||
# Step 2
|
||||
denoised_d = (1 - fac) * denoised + fac * denoised_2
|
||||
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_2 * denoised_d
|
||||
if inject_noise:
|
||||
x = x + sigmas[i + 1] * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
|
||||
# Step 2
|
||||
denoised_d = torch.lerp(denoised, denoised_2, fac)
|
||||
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
|
||||
if inject_noise:
|
||||
segment_factor = (r - 1) * h * eta
|
||||
sde_noise = sde_noise * segment_factor.exp()
|
||||
sde_noise = sde_noise + segment_factor.mul(2).expm1().neg().sqrt() * noise_sampler(sigma_s_1, sigmas[i + 1])
|
||||
x = x + sde_noise * sigmas[i + 1] * s_noise
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3):
|
||||
"""SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 3.
|
||||
arXiv: https://arxiv.org/abs/2305.14267
|
||||
arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)
|
||||
"""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
inject_noise = eta > 0 and s_noise > 0
|
||||
|
||||
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
|
||||
@@ -1609,45 +1631,49 @@ def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=Non
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
|
||||
if sigmas[i + 1] == 0:
|
||||
x = denoised
|
||||
else:
|
||||
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
|
||||
h = lambda_t - lambda_s
|
||||
h_eta = h * (eta + 1)
|
||||
lambda_s_1 = lambda_s + r_1 * h
|
||||
lambda_s_2 = lambda_s + r_2 * h
|
||||
sigma_s_1, sigma_s_2 = sigma_fn(lambda_s_1), sigma_fn(lambda_s_2)
|
||||
continue
|
||||
|
||||
# alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t)
|
||||
alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
|
||||
alpha_s_2 = sigma_s_2 * lambda_s_2.exp()
|
||||
alpha_t = sigmas[i + 1] * lambda_t.exp()
|
||||
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
|
||||
h = lambda_t - lambda_s
|
||||
h_eta = h * (eta + 1)
|
||||
lambda_s_1 = torch.lerp(lambda_s, lambda_t, r_1)
|
||||
lambda_s_2 = torch.lerp(lambda_s, lambda_t, r_2)
|
||||
sigma_s_1, sigma_s_2 = sigma_fn(lambda_s_1), sigma_fn(lambda_s_2)
|
||||
|
||||
coeff_1, coeff_2, coeff_3 = (-r_1 * h_eta).expm1(), (-r_2 * h_eta).expm1(), (-h_eta).expm1()
|
||||
if inject_noise:
|
||||
# 0 < r_1 < r_2 < 1
|
||||
noise_coeff_1 = (-2 * r_1 * h * eta).expm1().neg().sqrt()
|
||||
noise_coeff_2 = (-r_1 * h * eta).exp() * (-2 * (r_2 - r_1) * h * eta).expm1().neg().sqrt()
|
||||
noise_coeff_3 = (-r_2 * h * eta).exp() * (-2 * (1 - r_2) * h * eta).expm1().neg().sqrt()
|
||||
noise_1, noise_2, noise_3 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigma_s_2), noise_sampler(sigma_s_2, sigmas[i + 1])
|
||||
alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
|
||||
alpha_s_2 = sigma_s_2 * lambda_s_2.exp()
|
||||
alpha_t = sigmas[i + 1] * lambda_t.exp()
|
||||
|
||||
# Step 1
|
||||
x_2 = sigma_s_1 / sigmas[i] * (-r_1 * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised
|
||||
if inject_noise:
|
||||
x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise
|
||||
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
|
||||
# Step 1
|
||||
x_2 = sigma_s_1 / sigmas[i] * (-r_1 * h * eta).exp() * x - alpha_s_1 * ei_h_phi_1(-r_1 * h_eta) * denoised
|
||||
if inject_noise:
|
||||
sde_noise = (-2 * r_1 * h * eta).expm1().neg().sqrt() * noise_sampler(sigmas[i], sigma_s_1)
|
||||
x_2 = x_2 + sde_noise * sigma_s_1 * s_noise
|
||||
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
|
||||
|
||||
# Step 2
|
||||
x_3 = sigma_s_2 / sigmas[i] * (-r_2 * h * eta).exp() * x - alpha_s_2 * coeff_2 * denoised + (r_2 / r_1) * alpha_s_2 * (coeff_2 / (r_2 * h_eta) + 1) * (denoised_2 - denoised)
|
||||
if inject_noise:
|
||||
x_3 = x_3 + sigma_s_2 * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
|
||||
denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args)
|
||||
# Step 2
|
||||
a3_2 = r_2 / r_1 * ei_h_phi_2(-r_2 * h_eta)
|
||||
a3_1 = ei_h_phi_1(-r_2 * h_eta) - a3_2
|
||||
x_3 = sigma_s_2 / sigmas[i] * (-r_2 * h * eta).exp() * x - alpha_s_2 * (a3_1 * denoised + a3_2 * denoised_2)
|
||||
if inject_noise:
|
||||
segment_factor = (r_1 - r_2) * h * eta
|
||||
sde_noise = sde_noise * segment_factor.exp()
|
||||
sde_noise = sde_noise + segment_factor.mul(2).expm1().neg().sqrt() * noise_sampler(sigma_s_1, sigma_s_2)
|
||||
x_3 = x_3 + sde_noise * sigma_s_2 * s_noise
|
||||
denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args)
|
||||
|
||||
# Step 3
|
||||
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_3 * denoised + (1. / r_2) * alpha_t * (coeff_3 / h_eta + 1) * (denoised_3 - denoised)
|
||||
if inject_noise:
|
||||
x = x + sigmas[i + 1] * (noise_coeff_3 * noise_1 + noise_coeff_2 * noise_2 + noise_coeff_1 * noise_3) * s_noise
|
||||
# Step 3
|
||||
b3 = ei_h_phi_2(-h_eta) / r_2
|
||||
b1 = ei_h_phi_1(-h_eta) - b3
|
||||
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * (b1 * denoised + b3 * denoised_3)
|
||||
if inject_noise:
|
||||
segment_factor = (r_2 - 1) * h * eta
|
||||
sde_noise = sde_noise * segment_factor.exp()
|
||||
sde_noise = sde_noise + segment_factor.mul(2).expm1().neg().sqrt() * noise_sampler(sigma_s_2, sigmas[i + 1])
|
||||
x = x + sde_noise * sigmas[i + 1] * s_noise
|
||||
return x
|
||||
|
||||
|
||||
|
||||
@@ -533,11 +533,21 @@ class Wan22(Wan21):
|
||||
0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744
|
||||
]).view(1, self.latent_channels, 1, 1, 1)
|
||||
|
||||
class HunyuanImage21(LatentFormat):
|
||||
latent_channels = 64
|
||||
latent_dimensions = 2
|
||||
scale_factor = 0.75289
|
||||
|
||||
class Hunyuan3Dv2(LatentFormat):
|
||||
latent_channels = 64
|
||||
latent_dimensions = 1
|
||||
scale_factor = 0.9990943042622529
|
||||
|
||||
class Hunyuan3Dv2_1(LatentFormat):
|
||||
scale_factor = 1.0039506158752403
|
||||
latent_channels = 64
|
||||
latent_dimensions = 1
|
||||
|
||||
class Hunyuan3Dv2mini(LatentFormat):
|
||||
latent_channels = 64
|
||||
latent_dimensions = 1
|
||||
|
||||
+23
-1
@@ -19,6 +19,7 @@ import torch
|
||||
from torch import nn
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
|
||||
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
|
||||
from .attention import LinearTransformerBlock, t2i_modulate
|
||||
@@ -343,7 +344,28 @@ class ACEStepTransformer2DModel(nn.Module):
|
||||
output = self.final_layer(hidden_states, embedded_timestep, output_length)
|
||||
return output
|
||||
|
||||
def forward(
|
||||
def forward(self,
|
||||
x,
|
||||
timestep,
|
||||
attention_mask=None,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
text_attention_mask: Optional[torch.LongTensor] = None,
|
||||
speaker_embeds: Optional[torch.FloatTensor] = None,
|
||||
lyric_token_idx: Optional[torch.LongTensor] = None,
|
||||
lyric_mask: Optional[torch.LongTensor] = None,
|
||||
block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
|
||||
controlnet_scale: Union[float, torch.Tensor] = 1.0,
|
||||
lyrics_strength=1.0,
|
||||
**kwargs
|
||||
):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
|
||||
).execute(x, timestep, attention_mask, context, text_attention_mask, speaker_embeds, lyric_token_idx, lyric_mask, block_controlnet_hidden_states,
|
||||
controlnet_scale, lyrics_strength, **kwargs)
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
x,
|
||||
timestep,
|
||||
|
||||
@@ -632,7 +632,7 @@ class ContinuousTransformer(nn.Module):
|
||||
# Attention layers
|
||||
|
||||
if self.rotary_pos_emb is not None:
|
||||
rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1], dtype=x.dtype, device=x.device)
|
||||
rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1], dtype=torch.float, device=x.device)
|
||||
else:
|
||||
rotary_pos_emb = None
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ import torch.nn.functional as F
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
import comfy.ops
|
||||
import comfy.patcher_extension
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
def modulate(x, shift, scale):
|
||||
@@ -436,6 +437,13 @@ class MMDiT(nn.Module):
|
||||
return x + pos_encoding.reshape(1, -1, self.positional_encoding.shape[-1])
|
||||
|
||||
def forward(self, x, timestep, context, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, transformer_options, **kwargs)
|
||||
|
||||
def _forward(self, x, timestep, context, transformer_options={}, **kwargs):
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
# patchify x, add PE
|
||||
b, c, h, w = x.shape
|
||||
|
||||
@@ -5,6 +5,7 @@ from dataclasses import dataclass
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
from einops import rearrange, repeat
|
||||
import comfy.patcher_extension
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
from comfy.ldm.flux.layers import (
|
||||
@@ -253,6 +254,13 @@ class Chroma(nn.Module):
|
||||
return img
|
||||
|
||||
def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, guidance, control, transformer_options, **kwargs)
|
||||
|
||||
def _forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
|
||||
bs, c, h, w = x.shape
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
|
||||
|
||||
|
||||
@@ -27,6 +27,8 @@ from torchvision import transforms
|
||||
from enum import Enum
|
||||
import logging
|
||||
|
||||
import comfy.patcher_extension
|
||||
|
||||
from .blocks import (
|
||||
FinalLayer,
|
||||
GeneralDITTransformerBlock,
|
||||
@@ -435,6 +437,42 @@ class GeneralDIT(nn.Module):
|
||||
latent_condition_sigma: Optional[torch.Tensor] = None,
|
||||
condition_video_augment_sigma: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
|
||||
).execute(x,
|
||||
timesteps,
|
||||
context,
|
||||
attention_mask,
|
||||
fps,
|
||||
image_size,
|
||||
padding_mask,
|
||||
scalar_feature,
|
||||
data_type,
|
||||
latent_condition,
|
||||
latent_condition_sigma,
|
||||
condition_video_augment_sigma,
|
||||
**kwargs)
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
timesteps: torch.Tensor,
|
||||
context: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
# crossattn_emb: torch.Tensor,
|
||||
# crossattn_mask: Optional[torch.Tensor] = None,
|
||||
fps: Optional[torch.Tensor] = None,
|
||||
image_size: Optional[torch.Tensor] = None,
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
scalar_feature: Optional[torch.Tensor] = None,
|
||||
data_type: Optional[DataType] = DataType.VIDEO,
|
||||
latent_condition: Optional[torch.Tensor] = None,
|
||||
latent_condition_sigma: Optional[torch.Tensor] = None,
|
||||
condition_video_augment_sigma: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
|
||||
@@ -11,6 +11,7 @@ import math
|
||||
from .position_embedding import VideoRopePosition3DEmb, LearnablePosEmbAxis
|
||||
from torchvision import transforms
|
||||
|
||||
import comfy.patcher_extension
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
def apply_rotary_pos_emb(
|
||||
@@ -805,7 +806,21 @@ class MiniTrainDIT(nn.Module):
|
||||
)
|
||||
return x_B_C_Tt_Hp_Wp
|
||||
|
||||
def forward(
|
||||
def forward(self,
|
||||
x: torch.Tensor,
|
||||
timesteps: torch.Tensor,
|
||||
context: torch.Tensor,
|
||||
fps: Optional[torch.Tensor] = None,
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
|
||||
).execute(x, timesteps, context, fps, padding_mask, **kwargs)
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
timesteps: torch.Tensor,
|
||||
|
||||
+28
-5
@@ -6,6 +6,7 @@ import torch
|
||||
from torch import Tensor, nn
|
||||
from einops import rearrange, repeat
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.patcher_extension
|
||||
|
||||
from .layers import (
|
||||
DoubleStreamBlock,
|
||||
@@ -105,6 +106,7 @@ class Flux(nn.Module):
|
||||
if y is None:
|
||||
y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
|
||||
|
||||
patches = transformer_options.get("patches", {})
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||
@@ -116,9 +118,17 @@ class Flux(nn.Module):
|
||||
if guidance is not None:
|
||||
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
|
||||
|
||||
vec = vec + self.vector_in(y[:,:self.params.vec_in_dim])
|
||||
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
|
||||
txt = self.txt_in(txt)
|
||||
|
||||
if "post_input" in patches:
|
||||
for p in patches["post_input"]:
|
||||
out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids})
|
||||
img = out["img"]
|
||||
txt = out["txt"]
|
||||
img_ids = out["img_ids"]
|
||||
txt_ids = out["txt_ids"]
|
||||
|
||||
if img_ids is not None:
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
pe = self.pe_embedder(ids)
|
||||
@@ -157,7 +167,7 @@ class Flux(nn.Module):
|
||||
if i < len(control_i):
|
||||
add = control_i[i]
|
||||
if add is not None:
|
||||
img += add
|
||||
img[:, :add.shape[1]] += add
|
||||
|
||||
if img.dtype == torch.float16:
|
||||
img = torch.nan_to_num(img, nan=0.0, posinf=65504, neginf=-65504)
|
||||
@@ -188,7 +198,7 @@ class Flux(nn.Module):
|
||||
if i < len(control_o):
|
||||
add = control_o[i]
|
||||
if add is not None:
|
||||
img[:, txt.shape[1] :, ...] += add
|
||||
img[:, txt.shape[1] : txt.shape[1] + add.shape[1], ...] += add
|
||||
|
||||
img = img[:, txt.shape[1] :, ...]
|
||||
|
||||
@@ -214,6 +224,13 @@ class Flux(nn.Module):
|
||||
return img, repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, y, guidance, ref_latents, control, transformer_options, **kwargs)
|
||||
|
||||
def _forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
|
||||
bs, c, h_orig, w_orig = x.shape
|
||||
patch_size = self.patch_size
|
||||
|
||||
@@ -225,12 +242,18 @@ class Flux(nn.Module):
|
||||
h = 0
|
||||
w = 0
|
||||
index = 0
|
||||
index_ref_method = kwargs.get("ref_latents_method", "offset") == "index"
|
||||
ref_latents_method = kwargs.get("ref_latents_method", "offset")
|
||||
for ref in ref_latents:
|
||||
if index_ref_method:
|
||||
if ref_latents_method == "index":
|
||||
index += 1
|
||||
h_offset = 0
|
||||
w_offset = 0
|
||||
elif ref_latents_method == "uxo":
|
||||
index = 0
|
||||
h_offset = h_len * patch_size + h
|
||||
w_offset = w_len * patch_size + w
|
||||
h += ref.shape[-2]
|
||||
w += ref.shape[-1]
|
||||
else:
|
||||
index = 1
|
||||
h_offset = 0
|
||||
|
||||
@@ -13,6 +13,7 @@ from comfy.ldm.flux.layers import LastLayer
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
|
||||
@@ -692,7 +693,23 @@ class HiDreamImageTransformer2DModel(nn.Module):
|
||||
raise NotImplementedError
|
||||
return x, x_masks, img_sizes
|
||||
|
||||
def forward(
|
||||
def forward(self,
|
||||
x: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
y: Optional[torch.Tensor] = None,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states_llama3=None,
|
||||
image_cond=None,
|
||||
control = None,
|
||||
transformer_options = {},
|
||||
):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, t, y, context, encoder_hidden_states_llama3, image_cond, control, transformer_options)
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
|
||||
@@ -7,6 +7,7 @@ from comfy.ldm.flux.layers import (
|
||||
SingleStreamBlock,
|
||||
timestep_embedding,
|
||||
)
|
||||
import comfy.patcher_extension
|
||||
|
||||
|
||||
class Hunyuan3Dv2(nn.Module):
|
||||
@@ -67,6 +68,13 @@ class Hunyuan3Dv2(nn.Module):
|
||||
self.final_layer = LastLayer(hidden_size, 1, in_channels, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward(self, x, timestep, context, guidance=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, guidance, transformer_options, **kwargs)
|
||||
|
||||
def _forward(self, x, timestep, context, guidance=None, transformer_options={}, **kwargs):
|
||||
x = x.movedim(-1, -2)
|
||||
timestep = 1.0 - timestep
|
||||
txt = context
|
||||
|
||||
+485
-84
@@ -4,81 +4,458 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
from typing import Union, Tuple, List, Callable, Optional
|
||||
|
||||
import numpy as np
|
||||
from einops import repeat, rearrange
|
||||
import math
|
||||
from tqdm import tqdm
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import logging
|
||||
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
def generate_dense_grid_points(
|
||||
bbox_min: np.ndarray,
|
||||
bbox_max: np.ndarray,
|
||||
octree_resolution: int,
|
||||
indexing: str = "ij",
|
||||
):
|
||||
length = bbox_max - bbox_min
|
||||
num_cells = octree_resolution
|
||||
def fps(src: torch.Tensor, batch: torch.Tensor, sampling_ratio: float, start_random: bool = True):
|
||||
|
||||
x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32)
|
||||
y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32)
|
||||
z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32)
|
||||
[xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing)
|
||||
xyz = np.stack((xs, ys, zs), axis=-1)
|
||||
grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1]
|
||||
# manually create the pointer vector
|
||||
assert src.size(0) == batch.numel()
|
||||
|
||||
return xyz, grid_size, length
|
||||
batch_size = int(batch.max()) + 1
|
||||
deg = src.new_zeros(batch_size, dtype = torch.long)
|
||||
|
||||
deg.scatter_add_(0, batch, torch.ones_like(batch))
|
||||
|
||||
ptr_vec = deg.new_zeros(batch_size + 1)
|
||||
torch.cumsum(deg, 0, out=ptr_vec[1:])
|
||||
|
||||
#return fps_sampling(src, ptr_vec, ratio)
|
||||
sampled_indicies = []
|
||||
|
||||
for b in range(batch_size):
|
||||
# start and the end of each batch
|
||||
start, end = ptr_vec[b].item(), ptr_vec[b + 1].item()
|
||||
# points from the point cloud
|
||||
points = src[start:end]
|
||||
|
||||
num_points = points.size(0)
|
||||
num_samples = max(1, math.ceil(num_points * sampling_ratio))
|
||||
|
||||
selected = torch.zeros(num_samples, device = src.device, dtype = torch.long)
|
||||
distances = torch.full((num_points,), float("inf"), device = src.device)
|
||||
|
||||
# select a random start point
|
||||
if start_random:
|
||||
farthest = torch.randint(0, num_points, (1,), device = src.device)
|
||||
else:
|
||||
farthest = torch.tensor([0], device = src.device, dtype = torch.long)
|
||||
|
||||
for i in range(num_samples):
|
||||
selected[i] = farthest
|
||||
centroid = points[farthest].squeeze(0)
|
||||
dist = torch.norm(points - centroid, dim = 1) # compute euclidean distance
|
||||
distances = torch.minimum(distances, dist)
|
||||
farthest = torch.argmax(distances)
|
||||
|
||||
sampled_indicies.append(torch.arange(start, end)[selected])
|
||||
|
||||
return torch.cat(sampled_indicies, dim = 0)
|
||||
class PointCrossAttention(nn.Module):
|
||||
def __init__(self,
|
||||
num_latents: int,
|
||||
downsample_ratio: float,
|
||||
pc_size: int,
|
||||
pc_sharpedge_size: int,
|
||||
point_feats: int,
|
||||
width: int,
|
||||
heads: int,
|
||||
layers: int,
|
||||
fourier_embedder,
|
||||
normal_pe: bool = False,
|
||||
qkv_bias: bool = False,
|
||||
use_ln_post: bool = True,
|
||||
qk_norm: bool = True):
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.fourier_embedder = fourier_embedder
|
||||
|
||||
self.pc_size = pc_size
|
||||
self.normal_pe = normal_pe
|
||||
self.downsample_ratio = downsample_ratio
|
||||
self.pc_sharpedge_size = pc_sharpedge_size
|
||||
self.num_latents = num_latents
|
||||
self.point_feats = point_feats
|
||||
|
||||
self.input_proj = nn.Linear(self.fourier_embedder.out_dim + point_feats, width)
|
||||
|
||||
self.cross_attn = ResidualCrossAttentionBlock(
|
||||
width = width,
|
||||
heads = heads,
|
||||
qkv_bias = qkv_bias,
|
||||
qk_norm = qk_norm
|
||||
)
|
||||
|
||||
self.self_attn = None
|
||||
if layers > 0:
|
||||
self.self_attn = Transformer(
|
||||
width = width,
|
||||
heads = heads,
|
||||
qkv_bias = qkv_bias,
|
||||
qk_norm = qk_norm,
|
||||
layers = layers
|
||||
)
|
||||
|
||||
if use_ln_post:
|
||||
self.ln_post = nn.LayerNorm(width)
|
||||
else:
|
||||
self.ln_post = None
|
||||
|
||||
def sample_points_and_latents(self, point_cloud: torch.Tensor, features: torch.Tensor):
|
||||
|
||||
"""
|
||||
Subsample points randomly from the point cloud (input_pc)
|
||||
Further sample the subsampled points to get query_pc
|
||||
take the fourier embeddings for both input and query pc
|
||||
|
||||
Mental Note: FPS-sampled points (query_pc) act as latent tokens that attend to and learn from the broader context in input_pc.
|
||||
Goal: get a smaller represenation (query_pc) to represent the entire scence structure by learning from a broader subset (input_pc).
|
||||
More computationally efficient.
|
||||
|
||||
Features are additional information for each point in the cloud
|
||||
"""
|
||||
|
||||
B, _, D = point_cloud.shape
|
||||
|
||||
num_latents = int(self.num_latents)
|
||||
|
||||
num_random_query = self.pc_size / (self.pc_size + self.pc_sharpedge_size) * num_latents
|
||||
num_sharpedge_query = num_latents - num_random_query
|
||||
|
||||
# Split random and sharpedge surface points
|
||||
random_pc, sharpedge_pc = torch.split(point_cloud, [self.pc_size, self.pc_sharpedge_size], dim=1)
|
||||
|
||||
# assert statements
|
||||
assert random_pc.shape[1] <= self.pc_size, "Random surface points size must be less than or equal to pc_size"
|
||||
assert sharpedge_pc.shape[1] <= self.pc_sharpedge_size, "Sharpedge surface points size must be less than or equal to pc_sharpedge_size"
|
||||
|
||||
input_random_pc_size = int(num_random_query * self.downsample_ratio)
|
||||
random_query_pc, random_input_pc, random_idx_pc, random_idx_query = \
|
||||
self.subsample(pc = random_pc, num_query = num_random_query, input_pc_size = input_random_pc_size)
|
||||
|
||||
input_sharpedge_pc_size = int(num_sharpedge_query * self.downsample_ratio)
|
||||
|
||||
if input_sharpedge_pc_size == 0:
|
||||
sharpedge_input_pc = torch.zeros(B, 0, D, dtype = random_input_pc.dtype).to(point_cloud.device)
|
||||
sharpedge_query_pc = torch.zeros(B, 0, D, dtype= random_query_pc.dtype).to(point_cloud.device)
|
||||
|
||||
else:
|
||||
sharpedge_query_pc, sharpedge_input_pc, sharpedge_idx_pc, sharpedge_idx_query = \
|
||||
self.subsample(pc = sharpedge_pc, num_query = num_sharpedge_query, input_pc_size = input_sharpedge_pc_size)
|
||||
|
||||
# concat the random and sharpedges
|
||||
query_pc = torch.cat([random_query_pc, sharpedge_query_pc], dim = 1)
|
||||
input_pc = torch.cat([random_input_pc, sharpedge_input_pc], dim = 1)
|
||||
|
||||
query = self.fourier_embedder(query_pc)
|
||||
data = self.fourier_embedder(input_pc)
|
||||
|
||||
if self.point_feats > 0:
|
||||
random_surface_features, sharpedge_surface_features = torch.split(features, [self.pc_size, self.pc_sharpedge_size], dim = 1)
|
||||
|
||||
input_random_surface_features, query_random_features = \
|
||||
self.handle_features(features = random_surface_features, idx_pc = random_idx_pc, batch_size = B,
|
||||
input_pc_size = input_random_pc_size, idx_query = random_idx_query)
|
||||
|
||||
if input_sharpedge_pc_size == 0:
|
||||
input_sharpedge_surface_features = torch.zeros(B, 0, self.point_feats,
|
||||
dtype = input_random_surface_features.dtype, device = point_cloud.device)
|
||||
|
||||
query_sharpedge_features = torch.zeros(B, 0, self.point_feats,
|
||||
dtype = query_random_features.dtype, device = point_cloud.device)
|
||||
else:
|
||||
|
||||
input_sharpedge_surface_features, query_sharpedge_features = \
|
||||
self.handle_features(idx_pc = sharpedge_idx_pc, features = sharpedge_surface_features,
|
||||
batch_size = B, idx_query = sharpedge_idx_query, input_pc_size = input_sharpedge_pc_size)
|
||||
|
||||
query_features = torch.cat([query_random_features, query_sharpedge_features], dim = 1)
|
||||
input_features = torch.cat([input_random_surface_features, input_sharpedge_surface_features], dim = 1)
|
||||
|
||||
if self.normal_pe:
|
||||
# apply the fourier embeddings on the first 3 dims (xyz)
|
||||
input_features_pe = self.fourier_embedder(input_features[..., :3])
|
||||
query_features_pe = self.fourier_embedder(query_features[..., :3])
|
||||
# replace the first 3 dims with the new PE ones
|
||||
input_features = torch.cat([input_features_pe, input_features[..., :3]], dim = -1)
|
||||
query_features = torch.cat([query_features_pe, query_features[..., :3]], dim = -1)
|
||||
|
||||
# concat at the channels dim
|
||||
query = torch.cat([query, query_features], dim = -1)
|
||||
data = torch.cat([data, input_features], dim = -1)
|
||||
|
||||
# don't return pc_info to avoid unnecessary memory usuage
|
||||
return query.view(B, -1, query.shape[-1]), data.view(B, -1, data.shape[-1])
|
||||
|
||||
def forward(self, point_cloud: torch.Tensor, features: torch.Tensor):
|
||||
|
||||
query, data = self.sample_points_and_latents(point_cloud = point_cloud, features = features)
|
||||
|
||||
# apply projections
|
||||
query = self.input_proj(query)
|
||||
data = self.input_proj(data)
|
||||
|
||||
# apply cross attention between query and data
|
||||
latents = self.cross_attn(query, data)
|
||||
|
||||
if self.self_attn is not None:
|
||||
latents = self.self_attn(latents)
|
||||
|
||||
if self.ln_post is not None:
|
||||
latents = self.ln_post(latents)
|
||||
|
||||
return latents
|
||||
|
||||
|
||||
class VanillaVolumeDecoder:
|
||||
def subsample(self, pc, num_query, input_pc_size: int):
|
||||
|
||||
"""
|
||||
num_query: number of points to keep after FPS
|
||||
input_pc_size: number of points to select before FPS
|
||||
"""
|
||||
|
||||
B, _, D = pc.shape
|
||||
query_ratio = num_query / input_pc_size
|
||||
|
||||
# random subsampling of points inside the point cloud
|
||||
idx_pc = torch.randperm(pc.shape[1], device = pc.device)[:input_pc_size]
|
||||
input_pc = pc[:, idx_pc, :]
|
||||
|
||||
# flatten to allow applying fps across the whole batch
|
||||
flattent_input_pc = input_pc.view(B * input_pc_size, D)
|
||||
|
||||
# construct a batch_down tensor to tell fps
|
||||
# which points belong to which batch
|
||||
N_down = int(flattent_input_pc.shape[0] / B)
|
||||
batch_down = torch.arange(B).to(pc.device)
|
||||
batch_down = torch.repeat_interleave(batch_down, N_down)
|
||||
|
||||
idx_query = fps(flattent_input_pc, batch_down, sampling_ratio = query_ratio)
|
||||
query_pc = flattent_input_pc[idx_query].view(B, -1, D)
|
||||
|
||||
return query_pc, input_pc, idx_pc, idx_query
|
||||
|
||||
def handle_features(self, features, idx_pc, input_pc_size, batch_size: int, idx_query):
|
||||
|
||||
B = batch_size
|
||||
|
||||
input_surface_features = features[:, idx_pc, :]
|
||||
flattent_input_features = input_surface_features.view(B * input_pc_size, -1)
|
||||
query_features = flattent_input_features[idx_query].view(B, -1,
|
||||
flattent_input_features.shape[-1])
|
||||
|
||||
return input_surface_features, query_features
|
||||
|
||||
def normalize_mesh(mesh, scale = 0.9999):
|
||||
"""Normalize mesh to fit in [-scale, scale]. Translate mesh so its center is [0,0,0]"""
|
||||
|
||||
bbox = mesh.bounds
|
||||
center = (bbox[1] + bbox[0]) / 2
|
||||
|
||||
max_extent = (bbox[1] - bbox[0]).max()
|
||||
mesh.apply_translation(-center)
|
||||
mesh.apply_scale((2 * scale) / max_extent)
|
||||
|
||||
return mesh
|
||||
|
||||
def sample_pointcloud(mesh, num = 200000):
|
||||
""" Uniformly sample points from the surface of the mesh """
|
||||
|
||||
points, face_idx = mesh.sample(num, return_index = True)
|
||||
normals = mesh.face_normals[face_idx]
|
||||
return torch.from_numpy(points.astype(np.float32)), torch.from_numpy(normals.astype(np.float32))
|
||||
|
||||
def detect_sharp_edges(mesh, threshold=0.985):
|
||||
"""Return edge indices (a, b) that lie on sharp boundaries of the mesh."""
|
||||
|
||||
V, F = mesh.vertices, mesh.faces
|
||||
VN, FN = mesh.vertex_normals, mesh.face_normals
|
||||
|
||||
sharp_mask = np.ones(V.shape[0])
|
||||
for i in range(3):
|
||||
indices = F[:, i]
|
||||
alignment = np.einsum('ij,ij->i', VN[indices], FN)
|
||||
dot_stack = np.stack((sharp_mask[indices], alignment), axis=-1)
|
||||
sharp_mask[indices] = np.min(dot_stack, axis=-1)
|
||||
|
||||
edge_a = np.concatenate([F[:, 0], F[:, 1], F[:, 2]])
|
||||
edge_b = np.concatenate([F[:, 1], F[:, 2], F[:, 0]])
|
||||
sharp_edges = (sharp_mask[edge_a] < threshold) & (sharp_mask[edge_b] < threshold)
|
||||
|
||||
return edge_a[sharp_edges], edge_b[sharp_edges]
|
||||
|
||||
|
||||
def sharp_sample_pointcloud(mesh, num = 16384):
|
||||
""" Sample points preferentially from sharp edges in the mesh. """
|
||||
|
||||
edge_a, edge_b = detect_sharp_edges(mesh)
|
||||
V, VN = mesh.vertices, mesh.vertex_normals
|
||||
|
||||
va, vb = V[edge_a], V[edge_b]
|
||||
na, nb = VN[edge_a], VN[edge_b]
|
||||
|
||||
edge_lengths = np.linalg.norm(vb - va, axis=-1)
|
||||
weights = edge_lengths / edge_lengths.sum()
|
||||
|
||||
indices = np.searchsorted(np.cumsum(weights), np.random.rand(num))
|
||||
t = np.random.rand(num, 1)
|
||||
|
||||
samples = t * va[indices] + (1 - t) * vb[indices]
|
||||
normals = t * na[indices] + (1 - t) * nb[indices]
|
||||
|
||||
return samples.astype(np.float32), normals.astype(np.float32)
|
||||
|
||||
def load_surface_sharpedge(mesh, num_points=4096, num_sharp_points=4096, sharpedge_flag = True, device = "cuda"):
|
||||
"""Load a surface with optional sharp-edge annotations from a trimesh mesh."""
|
||||
|
||||
import trimesh
|
||||
|
||||
try:
|
||||
mesh_full = trimesh.util.concatenate(mesh.dump())
|
||||
except Exception:
|
||||
mesh_full = trimesh.util.concatenate(mesh)
|
||||
|
||||
mesh_full = normalize_mesh(mesh_full)
|
||||
|
||||
faces = mesh_full.faces
|
||||
vertices = mesh_full.vertices
|
||||
origin_face_count = faces.shape[0]
|
||||
|
||||
mesh_surface = trimesh.Trimesh(vertices=vertices, faces=faces[:origin_face_count])
|
||||
mesh_fill = trimesh.Trimesh(vertices=vertices, faces=faces[origin_face_count:])
|
||||
|
||||
area_surface = mesh_surface.area
|
||||
area_fill = mesh_fill.area
|
||||
total_area = area_surface + area_fill
|
||||
|
||||
sample_num = 499712 // 2
|
||||
fill_ratio = area_fill / total_area if total_area > 0 else 0
|
||||
|
||||
num_fill = int(sample_num * fill_ratio)
|
||||
num_surface = sample_num - num_fill
|
||||
|
||||
surf_pts, surf_normals = sample_pointcloud(mesh_surface, num_surface)
|
||||
fill_pts, fill_normals = (torch.zeros(0, 3), torch.zeros(0, 3)) if num_fill == 0 else sample_pointcloud(mesh_fill, num_fill)
|
||||
|
||||
sharp_pts, sharp_normals = sharp_sample_pointcloud(mesh_surface, sample_num)
|
||||
|
||||
def assemble_tensor(points, normals, label=None):
|
||||
|
||||
data = torch.cat([points, normals], dim=1).half().to(device)
|
||||
|
||||
if label is not None:
|
||||
label_tensor = torch.full((data.shape[0], 1), float(label), dtype=torch.float16).to(device)
|
||||
data = torch.cat([data, label_tensor], dim=1)
|
||||
|
||||
return data
|
||||
|
||||
surface = assemble_tensor(torch.cat([surf_pts.to(device), fill_pts.to(device)], dim=0),
|
||||
torch.cat([surf_normals.to(device), fill_normals.to(device)], dim=0),
|
||||
label = 0 if sharpedge_flag else None)
|
||||
|
||||
sharp_surface = assemble_tensor(torch.from_numpy(sharp_pts), torch.from_numpy(sharp_normals),
|
||||
label = 1 if sharpedge_flag else None)
|
||||
|
||||
rng = np.random.default_rng()
|
||||
|
||||
surface = surface[rng.choice(surface.shape[0], num_points, replace = False)]
|
||||
sharp_surface = sharp_surface[rng.choice(sharp_surface.shape[0], num_sharp_points, replace = False)]
|
||||
|
||||
full = torch.cat([surface, sharp_surface], dim = 0).unsqueeze(0)
|
||||
|
||||
return full
|
||||
|
||||
class SharpEdgeSurfaceLoader:
|
||||
""" Load mesh surface and sharp edge samples. """
|
||||
|
||||
def __init__(self, num_uniform_points = 8192, num_sharp_points = 8192):
|
||||
|
||||
self.num_uniform_points = num_uniform_points
|
||||
self.num_sharp_points = num_sharp_points
|
||||
self.total_points = num_uniform_points + num_sharp_points
|
||||
|
||||
def __call__(self, mesh_input, device = "cuda"):
|
||||
mesh = self._load_mesh(mesh_input)
|
||||
return load_surface_sharpedge(mesh, self.num_uniform_points, self.num_sharp_points, device = device)
|
||||
|
||||
@staticmethod
|
||||
def _load_mesh(mesh_input):
|
||||
import trimesh
|
||||
|
||||
if isinstance(mesh_input, str):
|
||||
mesh = trimesh.load(mesh_input, force="mesh", merge_primitives = True)
|
||||
else:
|
||||
mesh = mesh_input
|
||||
|
||||
if isinstance(mesh, trimesh.Scene):
|
||||
combined = None
|
||||
for obj in mesh.geometry.values():
|
||||
combined = obj if combined is None else combined + obj
|
||||
return combined
|
||||
|
||||
return mesh
|
||||
|
||||
class DiagonalGaussianDistribution:
|
||||
def __init__(self, params: torch.Tensor, feature_dim: int = -1):
|
||||
|
||||
# divide quant channels (8) into mean and log variance
|
||||
self.mean, self.logvar = torch.chunk(params, 2, dim = feature_dim)
|
||||
|
||||
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
|
||||
self.std = torch.exp(0.5 * self.logvar)
|
||||
|
||||
def sample(self):
|
||||
|
||||
eps = torch.randn_like(self.std)
|
||||
z = self.mean + eps * self.std
|
||||
|
||||
return z
|
||||
|
||||
################################################
|
||||
# Volume Decoder
|
||||
################################################
|
||||
|
||||
class VanillaVolumeDecoder():
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
latents: torch.FloatTensor,
|
||||
geo_decoder: Callable,
|
||||
bounds: Union[Tuple[float], List[float], float] = 1.01,
|
||||
num_chunks: int = 10000,
|
||||
octree_resolution: int = None,
|
||||
enable_pbar: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
device = latents.device
|
||||
dtype = latents.dtype
|
||||
batch_size = latents.shape[0]
|
||||
def __call__(self, latents: torch.Tensor, geo_decoder: callable, octree_resolution: int, bounds = 1.01,
|
||||
num_chunks: int = 10_000, enable_pbar: bool = True, **kwargs):
|
||||
|
||||
# 1. generate query points
|
||||
if isinstance(bounds, float):
|
||||
bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
|
||||
|
||||
bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6])
|
||||
xyz_samples, grid_size, length = generate_dense_grid_points(
|
||||
bbox_min=bbox_min,
|
||||
bbox_max=bbox_max,
|
||||
octree_resolution=octree_resolution,
|
||||
indexing="ij"
|
||||
)
|
||||
xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype).contiguous().reshape(-1, 3)
|
||||
bbox_min, bbox_max = torch.tensor(bounds[:3]), torch.tensor(bounds[3:])
|
||||
|
||||
x = torch.linspace(bbox_min[0], bbox_max[0], int(octree_resolution) + 1, dtype = torch.float32)
|
||||
y = torch.linspace(bbox_min[1], bbox_max[1], int(octree_resolution) + 1, dtype = torch.float32)
|
||||
z = torch.linspace(bbox_min[2], bbox_max[2], int(octree_resolution) + 1, dtype = torch.float32)
|
||||
|
||||
[xs, ys, zs] = torch.meshgrid(x, y, z, indexing = "ij")
|
||||
xyz = torch.stack((xs, ys, zs), axis=-1).to(latents.device, dtype = latents.dtype).contiguous().reshape(-1, 3)
|
||||
grid_size = [int(octree_resolution) + 1, int(octree_resolution) + 1, int(octree_resolution) + 1]
|
||||
|
||||
# 2. latents to 3d volume
|
||||
batch_logits = []
|
||||
for start in tqdm(range(0, xyz_samples.shape[0], num_chunks), desc="Volume Decoding",
|
||||
for start in tqdm(range(0, xyz.shape[0], num_chunks), desc="Volume Decoding",
|
||||
disable=not enable_pbar):
|
||||
chunk_queries = xyz_samples[start: start + num_chunks, :]
|
||||
chunk_queries = repeat(chunk_queries, "p c -> b p c", b=batch_size)
|
||||
logits = geo_decoder(queries=chunk_queries, latents=latents)
|
||||
|
||||
chunk_queries = xyz[start: start + num_chunks, :]
|
||||
chunk_queries = chunk_queries.unsqueeze(0).repeat(latents.shape[0], 1, 1)
|
||||
logits = geo_decoder(queries = chunk_queries, latents = latents)
|
||||
batch_logits.append(logits)
|
||||
|
||||
grid_logits = torch.cat(batch_logits, dim=1)
|
||||
grid_logits = grid_logits.view((batch_size, *grid_size)).float()
|
||||
grid_logits = torch.cat(batch_logits, dim = 1)
|
||||
grid_logits = grid_logits.view((latents.shape[0], *grid_size)).float()
|
||||
|
||||
return grid_logits
|
||||
|
||||
|
||||
class FourierEmbedder(nn.Module):
|
||||
"""The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
|
||||
each feature dimension of `x[..., i]` into:
|
||||
@@ -175,13 +552,11 @@ class FourierEmbedder(nn.Module):
|
||||
else:
|
||||
return x
|
||||
|
||||
|
||||
class CrossAttentionProcessor:
|
||||
def __call__(self, attn, q, k, v):
|
||||
out = comfy.ops.scaled_dot_product_attention(q, k, v)
|
||||
return out
|
||||
|
||||
|
||||
class DropPath(nn.Module):
|
||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||
"""
|
||||
@@ -232,38 +607,41 @@ class MLP(nn.Module):
|
||||
def forward(self, x):
|
||||
return self.drop_path(self.c_proj(self.gelu(self.c_fc(x))))
|
||||
|
||||
|
||||
class QKVMultiheadCrossAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
heads: int,
|
||||
n_data = None,
|
||||
width=None,
|
||||
qk_norm=False,
|
||||
norm_layer=ops.LayerNorm
|
||||
):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
self.n_data = n_data
|
||||
self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
|
||||
self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
|
||||
|
||||
self.attn_processor = CrossAttentionProcessor()
|
||||
|
||||
def forward(self, q, kv):
|
||||
|
||||
_, n_ctx, _ = q.shape
|
||||
bs, n_data, width = kv.shape
|
||||
|
||||
attn_ch = width // self.heads // 2
|
||||
q = q.view(bs, n_ctx, self.heads, -1)
|
||||
|
||||
kv = kv.view(bs, n_data, self.heads, -1)
|
||||
k, v = torch.split(kv, attn_ch, dim=-1)
|
||||
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
|
||||
out = self.attn_processor(self, q, k, v)
|
||||
out = out.transpose(1, 2).reshape(bs, n_ctx, -1)
|
||||
return out
|
||||
|
||||
q, k, v = [t.permute(0, 2, 1, 3) for t in (q, k, v)]
|
||||
out = F.scaled_dot_product_attention(q, k, v)
|
||||
|
||||
out = out.transpose(1, 2).reshape(bs, n_ctx, -1)
|
||||
|
||||
return out
|
||||
|
||||
class MultiheadCrossAttention(nn.Module):
|
||||
def __init__(
|
||||
@@ -306,7 +684,6 @@ class MultiheadCrossAttention(nn.Module):
|
||||
x = self.c_proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class ResidualCrossAttentionBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -366,7 +743,7 @@ class QKVMultiheadAttention(nn.Module):
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
|
||||
q, k, v = [t.permute(0, 2, 1, 3) for t in (q, k, v)]
|
||||
out = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)
|
||||
return out
|
||||
|
||||
@@ -383,8 +760,7 @@ class MultiheadAttention(nn.Module):
|
||||
drop_path_rate: float = 0.0
|
||||
):
|
||||
super().__init__()
|
||||
self.width = width
|
||||
self.heads = heads
|
||||
|
||||
self.c_qkv = ops.Linear(width, width * 3, bias=qkv_bias)
|
||||
self.c_proj = ops.Linear(width, width)
|
||||
self.attention = QKVMultiheadAttention(
|
||||
@@ -491,7 +867,7 @@ class CrossAttentionDecoder(nn.Module):
|
||||
self.query_proj = ops.Linear(self.fourier_embedder.out_dim, width)
|
||||
if self.downsample_ratio != 1:
|
||||
self.latents_proj = ops.Linear(width * downsample_ratio, width)
|
||||
if self.enable_ln_post == False:
|
||||
if not self.enable_ln_post:
|
||||
qk_norm = False
|
||||
self.cross_attn_decoder = ResidualCrossAttentionBlock(
|
||||
width=width,
|
||||
@@ -522,28 +898,44 @@ class CrossAttentionDecoder(nn.Module):
|
||||
|
||||
class ShapeVAE(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
embed_dim: int,
|
||||
width: int,
|
||||
heads: int,
|
||||
num_decoder_layers: int,
|
||||
geo_decoder_downsample_ratio: int = 1,
|
||||
geo_decoder_mlp_expand_ratio: int = 4,
|
||||
geo_decoder_ln_post: bool = True,
|
||||
num_freqs: int = 8,
|
||||
include_pi: bool = True,
|
||||
qkv_bias: bool = True,
|
||||
qk_norm: bool = False,
|
||||
label_type: str = "binary",
|
||||
drop_path_rate: float = 0.0,
|
||||
scale_factor: float = 1.0,
|
||||
self,
|
||||
*,
|
||||
num_latents: int = 4096,
|
||||
embed_dim: int = 64,
|
||||
width: int = 1024,
|
||||
heads: int = 16,
|
||||
num_decoder_layers: int = 16,
|
||||
num_encoder_layers: int = 8,
|
||||
pc_size: int = 81920,
|
||||
pc_sharpedge_size: int = 0,
|
||||
point_feats: int = 4,
|
||||
downsample_ratio: int = 20,
|
||||
geo_decoder_downsample_ratio: int = 1,
|
||||
geo_decoder_mlp_expand_ratio: int = 4,
|
||||
geo_decoder_ln_post: bool = True,
|
||||
num_freqs: int = 8,
|
||||
qkv_bias: bool = False,
|
||||
qk_norm: bool = True,
|
||||
drop_path_rate: float = 0.0,
|
||||
include_pi: bool = False,
|
||||
scale_factor: float = 1.0039506158752403,
|
||||
label_type: str = "binary",
|
||||
):
|
||||
super().__init__()
|
||||
self.geo_decoder_ln_post = geo_decoder_ln_post
|
||||
|
||||
self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
|
||||
|
||||
self.encoder = PointCrossAttention(layers = num_encoder_layers,
|
||||
num_latents = num_latents,
|
||||
downsample_ratio = downsample_ratio,
|
||||
heads = heads,
|
||||
pc_size = pc_size,
|
||||
width = width,
|
||||
point_feats = point_feats,
|
||||
fourier_embedder = self.fourier_embedder,
|
||||
pc_sharpedge_size = pc_sharpedge_size)
|
||||
|
||||
self.post_kl = ops.Linear(embed_dim, width)
|
||||
|
||||
self.transformer = Transformer(
|
||||
@@ -583,5 +975,14 @@ class ShapeVAE(nn.Module):
|
||||
grid_logits = self.volume_decoder(latents, self.geo_decoder, bounds=bounds, num_chunks=num_chunks, octree_resolution=octree_resolution, enable_pbar=enable_pbar)
|
||||
return grid_logits.movedim(-2, -1)
|
||||
|
||||
def encode(self, x):
|
||||
return None
|
||||
def encode(self, surface):
|
||||
|
||||
pc, feats = surface[:, :, :3], surface[:, :, 3:]
|
||||
latents = self.encoder(pc, feats)
|
||||
|
||||
moments = self.pre_kl(latents)
|
||||
posterior = DiagonalGaussianDistribution(moments, feature_dim = -1)
|
||||
|
||||
latents = posterior.sample()
|
||||
|
||||
return latents
|
||||
|
||||
@@ -0,0 +1,659 @@
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
import comfy.model_management
|
||||
|
||||
class GELU(nn.Module):
|
||||
|
||||
def __init__(self, dim_in: int, dim_out: int, operations, device, dtype):
|
||||
super().__init__()
|
||||
self.proj = operations.Linear(dim_in, dim_out, device = device, dtype = dtype)
|
||||
|
||||
def gelu(self, gate: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
if gate.device.type == "mps":
|
||||
return F.gelu(gate.to(dtype = torch.float32)).to(dtype = gate.dtype)
|
||||
|
||||
return F.gelu(gate)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
|
||||
hidden_states = self.proj(hidden_states)
|
||||
hidden_states = self.gelu(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
|
||||
def __init__(self, dim: int, dim_out = None, mult: int = 4,
|
||||
dropout: float = 0.0, inner_dim = None, operations = None, device = None, dtype = None):
|
||||
|
||||
super().__init__()
|
||||
if inner_dim is None:
|
||||
inner_dim = int(dim * mult)
|
||||
|
||||
dim_out = dim_out if dim_out is not None else dim
|
||||
|
||||
act_fn = GELU(dim, inner_dim, operations = operations, device = device, dtype = dtype)
|
||||
|
||||
self.net = nn.ModuleList([])
|
||||
self.net.append(act_fn)
|
||||
|
||||
self.net.append(nn.Dropout(dropout))
|
||||
self.net.append(operations.Linear(inner_dim, dim_out, device = device, dtype = dtype))
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
for module in self.net:
|
||||
hidden_states = module(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
class AddAuxLoss(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, x, loss):
|
||||
# do nothing in forward (no computation)
|
||||
ctx.requires_aux_loss = loss.requires_grad
|
||||
ctx.dtype = loss.dtype
|
||||
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
# add the aux loss gradients
|
||||
grad_loss = None
|
||||
# put the aux grad the same as the main grad loss
|
||||
# aux grad contributes equally
|
||||
if ctx.requires_aux_loss:
|
||||
grad_loss = torch.ones(1, dtype = ctx.dtype, device = grad_output.device)
|
||||
|
||||
return grad_output, grad_loss
|
||||
|
||||
class MoEGate(nn.Module):
|
||||
|
||||
def __init__(self, embed_dim, num_experts=16, num_experts_per_tok=2, aux_loss_alpha=0.01, device = None, dtype = None):
|
||||
|
||||
super().__init__()
|
||||
self.top_k = num_experts_per_tok
|
||||
self.n_routed_experts = num_experts
|
||||
|
||||
self.alpha = aux_loss_alpha
|
||||
|
||||
self.gating_dim = embed_dim
|
||||
self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim), device = device, dtype = dtype))
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
# flatten hidden states
|
||||
hidden_states = hidden_states.view(-1, hidden_states.size(-1))
|
||||
|
||||
# get logits and pass it to softmax
|
||||
logits = F.linear(hidden_states, comfy.model_management.cast_to(self.weight, dtype=hidden_states.dtype, device=hidden_states.device), bias = None)
|
||||
scores = logits.softmax(dim = -1)
|
||||
|
||||
topk_weight, topk_idx = torch.topk(scores, k = self.top_k, dim = -1, sorted = False)
|
||||
|
||||
if self.training and self.alpha > 0.0:
|
||||
scores_for_aux = scores
|
||||
|
||||
# used bincount instead of one hot encoding
|
||||
counts = torch.bincount(topk_idx.view(-1), minlength = self.n_routed_experts).float()
|
||||
ce = counts / topk_idx.numel() # normalized expert usage
|
||||
|
||||
# mean expert score
|
||||
Pi = scores_for_aux.mean(0)
|
||||
|
||||
# expert balance loss
|
||||
aux_loss = (Pi * ce * self.n_routed_experts).sum() * self.alpha
|
||||
else:
|
||||
aux_loss = None
|
||||
|
||||
return topk_idx, topk_weight, aux_loss
|
||||
|
||||
class MoEBlock(nn.Module):
|
||||
def __init__(self, dim, num_experts: int = 6, moe_top_k: int = 2, dropout: float = 0.0,
|
||||
ff_inner_dim: int = None, operations = None, device = None, dtype = None):
|
||||
super().__init__()
|
||||
|
||||
self.moe_top_k = moe_top_k
|
||||
self.num_experts = num_experts
|
||||
|
||||
self.experts = nn.ModuleList([
|
||||
FeedForward(dim, dropout = dropout, inner_dim = ff_inner_dim, operations = operations, device = device, dtype = dtype)
|
||||
for _ in range(num_experts)
|
||||
])
|
||||
|
||||
self.gate = MoEGate(dim, num_experts = num_experts, num_experts_per_tok = moe_top_k, device = device, dtype = dtype)
|
||||
self.shared_experts = FeedForward(dim, dropout = dropout, inner_dim = ff_inner_dim, operations = operations, device = device, dtype = dtype)
|
||||
|
||||
def forward(self, hidden_states) -> torch.Tensor:
|
||||
|
||||
identity = hidden_states
|
||||
orig_shape = hidden_states.shape
|
||||
topk_idx, topk_weight, aux_loss = self.gate(hidden_states)
|
||||
|
||||
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
||||
flat_topk_idx = topk_idx.view(-1)
|
||||
|
||||
if self.training:
|
||||
|
||||
hidden_states = hidden_states.repeat_interleave(self.moe_top_k, dim = 0)
|
||||
y = torch.empty_like(hidden_states, dtype = hidden_states.dtype)
|
||||
|
||||
for i, expert in enumerate(self.experts):
|
||||
tmp = expert(hidden_states[flat_topk_idx == i])
|
||||
y[flat_topk_idx == i] = tmp.to(hidden_states.dtype)
|
||||
|
||||
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim = 1)
|
||||
y = y.view(*orig_shape)
|
||||
|
||||
y = AddAuxLoss.apply(y, aux_loss)
|
||||
else:
|
||||
y = self.moe_infer(hidden_states, flat_expert_indices = flat_topk_idx,flat_expert_weights = topk_weight.view(-1, 1)).view(*orig_shape)
|
||||
|
||||
y = y + self.shared_experts(identity)
|
||||
|
||||
return y
|
||||
|
||||
@torch.no_grad()
|
||||
def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
|
||||
|
||||
expert_cache = torch.zeros_like(x)
|
||||
idxs = flat_expert_indices.argsort()
|
||||
|
||||
# no need for .numpy().cpu() here
|
||||
tokens_per_expert = flat_expert_indices.bincount().cumsum(0)
|
||||
token_idxs = idxs // self.moe_top_k
|
||||
|
||||
for i, end_idx in enumerate(tokens_per_expert):
|
||||
|
||||
start_idx = 0 if i == 0 else tokens_per_expert[i-1]
|
||||
|
||||
if start_idx == end_idx:
|
||||
continue
|
||||
|
||||
expert = self.experts[i]
|
||||
exp_token_idx = token_idxs[start_idx:end_idx]
|
||||
|
||||
expert_tokens = x[exp_token_idx]
|
||||
expert_out = expert(expert_tokens)
|
||||
|
||||
expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
|
||||
|
||||
# use index_add_ with a 1-D index tensor directly avoids building a large [N, D] index map and extra memcopy required by scatter_reduce_
|
||||
# + avoid dtype conversion
|
||||
expert_cache.index_add_(0, exp_token_idx, expert_out)
|
||||
|
||||
return expert_cache
|
||||
|
||||
class Timesteps(nn.Module):
|
||||
def __init__(self, num_channels: int, downscale_freq_shift: float = 0.0,
|
||||
scale: float = 1.0, max_period: int = 10000):
|
||||
super().__init__()
|
||||
|
||||
self.num_channels = num_channels
|
||||
half_dim = num_channels // 2
|
||||
|
||||
# precompute the “inv_freq” vector once
|
||||
exponent = -math.log(max_period) * torch.arange(
|
||||
half_dim, dtype=torch.float32
|
||||
) / (half_dim - downscale_freq_shift)
|
||||
|
||||
inv_freq = torch.exp(exponent)
|
||||
|
||||
# pad
|
||||
if num_channels % 2 == 1:
|
||||
# we’ll pad a zero at the end of the cos-half
|
||||
inv_freq = torch.cat([inv_freq, inv_freq.new_zeros(1)])
|
||||
|
||||
# register to buffer so it moves with the device
|
||||
self.register_buffer("inv_freq", inv_freq, persistent = False)
|
||||
self.scale = scale
|
||||
|
||||
def forward(self, timesteps: torch.Tensor):
|
||||
|
||||
x = timesteps.float().unsqueeze(1) * self.inv_freq.to(timesteps.device).unsqueeze(0)
|
||||
|
||||
|
||||
# fused CUDA kernels for sin and cos
|
||||
sin_emb = x.sin()
|
||||
cos_emb = x.cos()
|
||||
|
||||
emb = torch.cat([sin_emb, cos_emb], dim = 1)
|
||||
|
||||
# scale factor
|
||||
if self.scale != 1.0:
|
||||
emb = emb * self.scale
|
||||
|
||||
# If we padded inv_freq for odd, emb is already wide enough; otherwise:
|
||||
if emb.shape[1] > self.num_channels:
|
||||
emb = emb[:, :self.num_channels]
|
||||
|
||||
return emb
|
||||
|
||||
class TimestepEmbedder(nn.Module):
|
||||
def __init__(self, hidden_size, frequency_embedding_size = 256, cond_proj_dim = None, operations = None, device = None, dtype = None):
|
||||
super().__init__()
|
||||
|
||||
self.mlp = nn.Sequential(
|
||||
operations.Linear(hidden_size, frequency_embedding_size, bias=True, device = device, dtype = dtype),
|
||||
nn.GELU(),
|
||||
operations.Linear(frequency_embedding_size, hidden_size, bias=True, device = device, dtype = dtype),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
|
||||
if cond_proj_dim is not None:
|
||||
self.cond_proj = operations.Linear(cond_proj_dim, frequency_embedding_size, bias=False, device = device, dtype = dtype)
|
||||
|
||||
self.time_embed = Timesteps(hidden_size)
|
||||
|
||||
def forward(self, timesteps, condition):
|
||||
|
||||
timestep_embed = self.time_embed(timesteps).type(self.mlp[0].weight.dtype)
|
||||
|
||||
if condition is not None:
|
||||
cond_embed = self.cond_proj(condition)
|
||||
timestep_embed = timestep_embed + cond_embed
|
||||
|
||||
time_conditioned = self.mlp(timestep_embed)
|
||||
|
||||
# for broadcasting with image tokens
|
||||
return time_conditioned.unsqueeze(1)
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, *, width: int, operations = None, device = None, dtype = None):
|
||||
super().__init__()
|
||||
self.width = width
|
||||
self.fc1 = operations.Linear(width, width * 4, device = device, dtype = dtype)
|
||||
self.fc2 = operations.Linear(width * 4, width, device = device, dtype = dtype)
|
||||
self.gelu = nn.GELU()
|
||||
|
||||
def forward(self, x):
|
||||
return self.fc2(self.gelu(self.fc1(x)))
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
qdim,
|
||||
kdim,
|
||||
num_heads,
|
||||
qkv_bias=True,
|
||||
qk_norm=False,
|
||||
norm_layer=nn.LayerNorm,
|
||||
use_fp16: bool = False,
|
||||
operations = None,
|
||||
dtype = None,
|
||||
device = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.qdim = qdim
|
||||
self.kdim = kdim
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = self.qdim // num_heads
|
||||
|
||||
self.scale = self.head_dim ** -0.5
|
||||
|
||||
self.to_q = operations.Linear(qdim, qdim, bias=qkv_bias, device = device, dtype = dtype)
|
||||
self.to_k = operations.Linear(kdim, qdim, bias=qkv_bias, device = device, dtype = dtype)
|
||||
self.to_v = operations.Linear(kdim, qdim, bias=qkv_bias, device = device, dtype = dtype)
|
||||
|
||||
if use_fp16:
|
||||
eps = 1.0 / 65504
|
||||
else:
|
||||
eps = 1e-6
|
||||
|
||||
if norm_layer == nn.LayerNorm:
|
||||
norm_layer = operations.LayerNorm
|
||||
else:
|
||||
norm_layer = operations.RMSNorm
|
||||
|
||||
self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps = eps, device = device, dtype = dtype) if qk_norm else nn.Identity()
|
||||
self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps = eps, device = device, dtype = dtype) if qk_norm else nn.Identity()
|
||||
self.out_proj = operations.Linear(qdim, qdim, bias=True, device = device, dtype = dtype)
|
||||
|
||||
def forward(self, x, y):
|
||||
|
||||
b, s1, _ = x.shape
|
||||
_, s2, _ = y.shape
|
||||
|
||||
y = y.to(next(self.to_k.parameters()).dtype)
|
||||
|
||||
q = self.to_q(x)
|
||||
k = self.to_k(y)
|
||||
v = self.to_v(y)
|
||||
|
||||
kv = torch.cat((k, v), dim=-1)
|
||||
split_size = kv.shape[-1] // self.num_heads // 2
|
||||
|
||||
kv = kv.view(1, -1, self.num_heads, split_size * 2)
|
||||
k, v = torch.split(kv, split_size, dim=-1)
|
||||
|
||||
q = q.view(b, s1, self.num_heads, self.head_dim)
|
||||
k = k.view(b, s2, self.num_heads, self.head_dim)
|
||||
v = v.reshape(b, s2, self.num_heads * self.head_dim)
|
||||
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
|
||||
x = optimized_attention(
|
||||
q.reshape(b, s1, self.num_heads * self.head_dim),
|
||||
k.reshape(b, s2, self.num_heads * self.head_dim),
|
||||
v,
|
||||
heads=self.num_heads,
|
||||
)
|
||||
|
||||
out = self.out_proj(x)
|
||||
|
||||
return out
|
||||
|
||||
class Attention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
num_heads,
|
||||
qkv_bias = True,
|
||||
qk_norm = False,
|
||||
norm_layer = nn.LayerNorm,
|
||||
use_fp16: bool = False,
|
||||
operations = None,
|
||||
device = None,
|
||||
dtype = None
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = self.dim // num_heads
|
||||
self.scale = self.head_dim ** -0.5
|
||||
|
||||
self.to_q = operations.Linear(dim, dim, bias = qkv_bias, device = device, dtype = dtype)
|
||||
self.to_k = operations.Linear(dim, dim, bias = qkv_bias, device = device, dtype = dtype)
|
||||
self.to_v = operations.Linear(dim, dim, bias = qkv_bias, device = device, dtype = dtype)
|
||||
|
||||
if use_fp16:
|
||||
eps = 1.0 / 65504
|
||||
else:
|
||||
eps = 1e-6
|
||||
|
||||
if norm_layer == nn.LayerNorm:
|
||||
norm_layer = operations.LayerNorm
|
||||
else:
|
||||
norm_layer = operations.RMSNorm
|
||||
|
||||
self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps = eps, device = device, dtype = dtype) if qk_norm else nn.Identity()
|
||||
self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps = eps, device = device, dtype = dtype) if qk_norm else nn.Identity()
|
||||
self.out_proj = operations.Linear(dim, dim, device = device, dtype = dtype)
|
||||
|
||||
def forward(self, x):
|
||||
B, N, _ = x.shape
|
||||
|
||||
query = self.to_q(x)
|
||||
key = self.to_k(x)
|
||||
value = self.to_v(x)
|
||||
|
||||
qkv_combined = torch.cat((query, key, value), dim=-1)
|
||||
split_size = qkv_combined.shape[-1] // self.num_heads // 3
|
||||
|
||||
qkv = qkv_combined.view(1, -1, self.num_heads, split_size * 3)
|
||||
query, key, value = torch.split(qkv, split_size, dim=-1)
|
||||
|
||||
query = query.reshape(B, N, self.num_heads, self.head_dim)
|
||||
key = key.reshape(B, N, self.num_heads, self.head_dim)
|
||||
value = value.reshape(B, N, self.num_heads * self.head_dim)
|
||||
|
||||
query = self.q_norm(query)
|
||||
key = self.k_norm(key)
|
||||
|
||||
x = optimized_attention(
|
||||
query.reshape(B, N, self.num_heads * self.head_dim),
|
||||
key.reshape(B, N, self.num_heads * self.head_dim),
|
||||
value,
|
||||
heads=self.num_heads,
|
||||
)
|
||||
|
||||
x = self.out_proj(x)
|
||||
return x
|
||||
|
||||
class HunYuanDiTBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
c_emb_size,
|
||||
num_heads,
|
||||
text_states_dim=1024,
|
||||
qk_norm=False,
|
||||
norm_layer=nn.LayerNorm,
|
||||
qk_norm_layer=True,
|
||||
qkv_bias=True,
|
||||
skip_connection=True,
|
||||
timested_modulate=False,
|
||||
use_moe: bool = False,
|
||||
num_experts: int = 8,
|
||||
moe_top_k: int = 2,
|
||||
use_fp16: bool = False,
|
||||
operations = None,
|
||||
device = None, dtype = None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# eps can't be 1e-6 in fp16 mode because of numerical stability issues
|
||||
if use_fp16:
|
||||
eps = 1.0 / 65504
|
||||
else:
|
||||
eps = 1e-6
|
||||
|
||||
self.norm1 = norm_layer(hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype)
|
||||
|
||||
self.attn1 = Attention(hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm,
|
||||
norm_layer=qk_norm_layer, use_fp16 = use_fp16, device = device, dtype = dtype, operations = operations)
|
||||
|
||||
self.norm2 = norm_layer(hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype)
|
||||
|
||||
self.timested_modulate = timested_modulate
|
||||
if self.timested_modulate:
|
||||
self.default_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(c_emb_size, hidden_size, bias=True, device = device, dtype = dtype)
|
||||
)
|
||||
|
||||
self.attn2 = CrossAttention(hidden_size, text_states_dim, num_heads=num_heads, qkv_bias=qkv_bias,
|
||||
qk_norm=qk_norm, norm_layer=qk_norm_layer, use_fp16 = use_fp16,
|
||||
device = device, dtype = dtype, operations = operations)
|
||||
|
||||
self.norm3 = norm_layer(hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype)
|
||||
|
||||
if skip_connection:
|
||||
self.skip_norm = norm_layer(hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype)
|
||||
self.skip_linear = operations.Linear(2 * hidden_size, hidden_size, device = device, dtype = dtype)
|
||||
else:
|
||||
self.skip_linear = None
|
||||
|
||||
self.use_moe = use_moe
|
||||
|
||||
if self.use_moe:
|
||||
self.moe = MoEBlock(
|
||||
hidden_size,
|
||||
num_experts = num_experts,
|
||||
moe_top_k = moe_top_k,
|
||||
dropout = 0.0,
|
||||
ff_inner_dim = int(hidden_size * 4.0),
|
||||
device = device, dtype = dtype,
|
||||
operations = operations
|
||||
)
|
||||
else:
|
||||
self.mlp = MLP(width=hidden_size, operations=operations, device = device, dtype = dtype)
|
||||
|
||||
def forward(self, hidden_states, conditioning=None, text_states=None, skip_tensor=None):
|
||||
|
||||
if self.skip_linear is not None:
|
||||
combined = torch.cat([skip_tensor, hidden_states], dim=-1)
|
||||
hidden_states = self.skip_linear(combined)
|
||||
hidden_states = self.skip_norm(hidden_states)
|
||||
|
||||
# self attention
|
||||
if self.timested_modulate:
|
||||
modulation_shift = self.default_modulation(conditioning).unsqueeze(dim=1)
|
||||
hidden_states = hidden_states + modulation_shift
|
||||
|
||||
self_attn_out = self.attn1(self.norm1(hidden_states))
|
||||
hidden_states = hidden_states + self_attn_out
|
||||
|
||||
# cross attention
|
||||
hidden_states = hidden_states + self.attn2(self.norm2(hidden_states), text_states)
|
||||
|
||||
# MLP Layer
|
||||
mlp_input = self.norm3(hidden_states)
|
||||
|
||||
if self.use_moe:
|
||||
hidden_states = hidden_states + self.moe(mlp_input)
|
||||
else:
|
||||
hidden_states = hidden_states + self.mlp(mlp_input)
|
||||
|
||||
return hidden_states
|
||||
|
||||
class FinalLayer(nn.Module):
|
||||
|
||||
def __init__(self, final_hidden_size, out_channels, operations, use_fp16: bool = False, device = None, dtype = None):
|
||||
super().__init__()
|
||||
|
||||
if use_fp16:
|
||||
eps = 1.0 / 65504
|
||||
else:
|
||||
eps = 1e-6
|
||||
|
||||
self.norm_final = operations.LayerNorm(final_hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype)
|
||||
self.linear = operations.Linear(final_hidden_size, out_channels, bias = True, device = device, dtype = dtype)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm_final(x)
|
||||
x = x[:, 1:]
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
class HunYuanDiTPlain(nn.Module):
|
||||
|
||||
# init with the defaults values from https://huggingface.co/tencent/Hunyuan3D-2.1/blob/main/hunyuan3d-dit-v2-1/config.yaml
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 64,
|
||||
hidden_size: int = 2048,
|
||||
context_dim: int = 1024,
|
||||
depth: int = 21,
|
||||
num_heads: int = 16,
|
||||
qk_norm: bool = True,
|
||||
qkv_bias: bool = False,
|
||||
num_moe_layers: int = 6,
|
||||
guidance_cond_proj_dim = 2048,
|
||||
norm_type = 'layer',
|
||||
num_experts: int = 8,
|
||||
moe_top_k: int = 2,
|
||||
use_fp16: bool = False,
|
||||
dtype = None,
|
||||
device = None,
|
||||
operations = None,
|
||||
**kwargs
|
||||
):
|
||||
|
||||
self.dtype = dtype
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.depth = depth
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
norm = operations.LayerNorm if norm_type == 'layer' else operations.RMSNorm
|
||||
qk_norm = operations.RMSNorm
|
||||
|
||||
self.context_dim = context_dim
|
||||
self.guidance_cond_proj_dim = guidance_cond_proj_dim
|
||||
|
||||
self.x_embedder = operations.Linear(in_channels, hidden_size, bias = True, device = device, dtype = dtype)
|
||||
self.t_embedder = TimestepEmbedder(hidden_size, hidden_size * 4, cond_proj_dim = guidance_cond_proj_dim, device = device, dtype = dtype, operations = operations)
|
||||
|
||||
|
||||
# HUnYuanDiT Blocks
|
||||
self.blocks = nn.ModuleList([
|
||||
HunYuanDiTBlock(hidden_size=hidden_size,
|
||||
c_emb_size=hidden_size,
|
||||
num_heads=num_heads,
|
||||
text_states_dim=context_dim,
|
||||
qk_norm=qk_norm,
|
||||
norm_layer = norm,
|
||||
qk_norm_layer = qk_norm,
|
||||
skip_connection=layer > depth // 2,
|
||||
qkv_bias=qkv_bias,
|
||||
use_moe=True if depth - layer <= num_moe_layers else False,
|
||||
num_experts=num_experts,
|
||||
moe_top_k=moe_top_k,
|
||||
use_fp16 = use_fp16,
|
||||
device = device, dtype = dtype, operations = operations)
|
||||
for layer in range(depth)
|
||||
])
|
||||
|
||||
self.depth = depth
|
||||
|
||||
self.final_layer = FinalLayer(hidden_size, self.out_channels, use_fp16 = use_fp16, operations = operations, device = device, dtype = dtype)
|
||||
|
||||
def forward(self, x, t, context, transformer_options = {}, **kwargs):
|
||||
|
||||
x = x.movedim(-1, -2)
|
||||
uncond_emb, cond_emb = context.chunk(2, dim = 0)
|
||||
|
||||
context = torch.cat([cond_emb, uncond_emb], dim = 0)
|
||||
main_condition = context
|
||||
|
||||
t = 1.0 - t
|
||||
|
||||
time_embedded = self.t_embedder(t, condition = kwargs.get('guidance_cond'))
|
||||
|
||||
x = x.to(dtype = next(self.x_embedder.parameters()).dtype)
|
||||
x_embedded = self.x_embedder(x)
|
||||
|
||||
combined = torch.cat([time_embedded, x_embedded], dim=1)
|
||||
|
||||
def block_wrap(args):
|
||||
return block(
|
||||
args["x"],
|
||||
args["t"],
|
||||
args["cond"],
|
||||
skip_tensor=args.get("skip"),)
|
||||
|
||||
skip_stack = []
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
for idx, block in enumerate(self.blocks):
|
||||
if idx <= self.depth // 2:
|
||||
skip_input = None
|
||||
else:
|
||||
skip_input = skip_stack.pop()
|
||||
|
||||
if ("block", idx) in blocks_replace:
|
||||
|
||||
combined = blocks_replace[("block", idx)](
|
||||
{
|
||||
"x": combined,
|
||||
"t": time_embedded,
|
||||
"cond": main_condition,
|
||||
"skip": skip_input,
|
||||
},
|
||||
{"original_block": block_wrap},
|
||||
)
|
||||
else:
|
||||
combined = block(combined, time_embedded, main_condition, skip_tensor=skip_input)
|
||||
|
||||
if idx < self.depth // 2:
|
||||
skip_stack.append(combined)
|
||||
|
||||
output = self.final_layer(combined)
|
||||
output = output.movedim(-2, -1) * (-1.0)
|
||||
|
||||
cond_emb, uncond_emb = output.chunk(2, dim = 0)
|
||||
return torch.cat([uncond_emb, cond_emb])
|
||||
@@ -1,6 +1,7 @@
|
||||
#Based on Flux code because of weird hunyuan video code license.
|
||||
|
||||
import torch
|
||||
import comfy.patcher_extension
|
||||
import comfy.ldm.flux.layers
|
||||
import comfy.ldm.modules.diffusionmodules.mmdit
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
@@ -39,6 +40,7 @@ class HunyuanVideoParams:
|
||||
patch_size: list
|
||||
qkv_bias: bool
|
||||
guidance_embed: bool
|
||||
byt5: bool
|
||||
|
||||
|
||||
class SelfAttentionRef(nn.Module):
|
||||
@@ -160,6 +162,30 @@ class TokenRefiner(nn.Module):
|
||||
x = self.individual_token_refiner(x, c, mask)
|
||||
return x
|
||||
|
||||
|
||||
class ByT5Mapper(nn.Module):
|
||||
def __init__(self, in_dim, out_dim, hidden_dim, out_dim1, use_res=False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.layernorm = operations.LayerNorm(in_dim, dtype=dtype, device=device)
|
||||
self.fc1 = operations.Linear(in_dim, hidden_dim, dtype=dtype, device=device)
|
||||
self.fc2 = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device)
|
||||
self.fc3 = operations.Linear(out_dim, out_dim1, dtype=dtype, device=device)
|
||||
self.use_res = use_res
|
||||
self.act_fn = nn.GELU()
|
||||
|
||||
def forward(self, x):
|
||||
if self.use_res:
|
||||
res = x
|
||||
x = self.layernorm(x)
|
||||
x = self.fc1(x)
|
||||
x = self.act_fn(x)
|
||||
x = self.fc2(x)
|
||||
x2 = self.act_fn(x)
|
||||
x2 = self.fc3(x2)
|
||||
if self.use_res:
|
||||
x2 = x2 + res
|
||||
return x2
|
||||
|
||||
class HunyuanVideo(nn.Module):
|
||||
"""
|
||||
Transformer model for flow matching on sequences.
|
||||
@@ -184,9 +210,13 @@ class HunyuanVideo(nn.Module):
|
||||
self.num_heads = params.num_heads
|
||||
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
|
||||
|
||||
self.img_in = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(None, self.patch_size, self.in_channels, self.hidden_size, conv3d=True, dtype=dtype, device=device, operations=operations)
|
||||
self.img_in = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(None, self.patch_size, self.in_channels, self.hidden_size, conv3d=len(self.patch_size) == 3, dtype=dtype, device=device, operations=operations)
|
||||
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
if params.vec_in_dim is not None:
|
||||
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
else:
|
||||
self.vector_in = None
|
||||
|
||||
self.guidance_in = (
|
||||
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity()
|
||||
)
|
||||
@@ -214,6 +244,18 @@ class HunyuanVideo(nn.Module):
|
||||
]
|
||||
)
|
||||
|
||||
if params.byt5:
|
||||
self.byt5_in = ByT5Mapper(
|
||||
in_dim=1472,
|
||||
out_dim=2048,
|
||||
hidden_dim=2048,
|
||||
out_dim1=self.hidden_size,
|
||||
use_res=False,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
else:
|
||||
self.byt5_in = None
|
||||
|
||||
if final_layer:
|
||||
self.final_layer = LastLayer(self.hidden_size, self.patch_size[-1], self.out_channels, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
@@ -225,7 +267,8 @@ class HunyuanVideo(nn.Module):
|
||||
txt_ids: Tensor,
|
||||
txt_mask: Tensor,
|
||||
timesteps: Tensor,
|
||||
y: Tensor,
|
||||
y: Tensor = None,
|
||||
txt_byt5=None,
|
||||
guidance: Tensor = None,
|
||||
guiding_frame_index=None,
|
||||
ref_latent=None,
|
||||
@@ -249,13 +292,17 @@ class HunyuanVideo(nn.Module):
|
||||
|
||||
if guiding_frame_index is not None:
|
||||
token_replace_vec = self.time_in(timestep_embedding(guiding_frame_index, 256, time_factor=1.0))
|
||||
vec_ = self.vector_in(y[:, :self.params.vec_in_dim])
|
||||
vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1)
|
||||
if self.vector_in is not None:
|
||||
vec_ = self.vector_in(y[:, :self.params.vec_in_dim])
|
||||
vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1)
|
||||
else:
|
||||
vec = torch.cat([(token_replace_vec).unsqueeze(1), (vec).unsqueeze(1)], dim=1)
|
||||
frame_tokens = (initial_shape[-1] // self.patch_size[-1]) * (initial_shape[-2] // self.patch_size[-2])
|
||||
modulation_dims = [(0, frame_tokens, 0), (frame_tokens, None, 1)]
|
||||
modulation_dims_txt = [(0, None, 1)]
|
||||
else:
|
||||
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
|
||||
if self.vector_in is not None:
|
||||
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
|
||||
modulation_dims = None
|
||||
modulation_dims_txt = None
|
||||
|
||||
@@ -268,6 +315,12 @@ class HunyuanVideo(nn.Module):
|
||||
|
||||
txt = self.txt_in(txt, timesteps, txt_mask)
|
||||
|
||||
if self.byt5_in is not None and txt_byt5 is not None:
|
||||
txt_byt5 = self.byt5_in(txt_byt5)
|
||||
txt_byt5_ids = torch.zeros((txt_ids.shape[0], txt_byt5.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
|
||||
txt = torch.cat((txt, txt_byt5), dim=1)
|
||||
txt_ids = torch.cat((txt_ids, txt_byt5_ids), dim=1)
|
||||
|
||||
ids = torch.cat((img_ids, txt_ids), dim=1)
|
||||
pe = self.pe_embedder(ids)
|
||||
|
||||
@@ -327,12 +380,16 @@ class HunyuanVideo(nn.Module):
|
||||
|
||||
img = self.final_layer(img, vec, modulation_dims=modulation_dims) # (N, T, patch_size ** 2 * out_channels)
|
||||
|
||||
shape = initial_shape[-3:]
|
||||
shape = initial_shape[-len(self.patch_size):]
|
||||
for i in range(len(shape)):
|
||||
shape[i] = shape[i] // self.patch_size[i]
|
||||
img = img.reshape([img.shape[0]] + shape + [self.out_channels] + self.patch_size)
|
||||
img = img.permute(0, 4, 1, 5, 2, 6, 3, 7)
|
||||
img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
|
||||
if img.ndim == 8:
|
||||
img = img.permute(0, 4, 1, 5, 2, 6, 3, 7)
|
||||
img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
|
||||
else:
|
||||
img = img.permute(0, 3, 1, 4, 2, 5)
|
||||
img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3])
|
||||
return img
|
||||
|
||||
def img_ids(self, x):
|
||||
@@ -347,9 +404,30 @@ class HunyuanVideo(nn.Module):
|
||||
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
|
||||
return repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
|
||||
|
||||
def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
|
||||
bs, c, t, h, w = x.shape
|
||||
img_ids = self.img_ids(x)
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, ref_latent, control=control, transformer_options=transformer_options)
|
||||
def img_ids_2d(self, x):
|
||||
bs, c, h, w = x.shape
|
||||
patch_size = self.patch_size
|
||||
h_len = ((h + (patch_size[0] // 2)) // patch_size[0])
|
||||
w_len = ((w + (patch_size[1] // 2)) // patch_size[1])
|
||||
img_ids = torch.zeros((h_len, w_len, 2), device=x.device, dtype=x.dtype)
|
||||
img_ids[:, :, 0] = img_ids[:, :, 0] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
|
||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
|
||||
return repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
def forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, y, txt_byt5, guidance, attention_mask, guiding_frame_index, ref_latent, control, transformer_options, **kwargs)
|
||||
|
||||
def _forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
|
||||
bs = x.shape[0]
|
||||
if len(self.patch_size) == 3:
|
||||
img_ids = self.img_ids(x)
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||
else:
|
||||
img_ids = self.img_ids_2d(x)
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 2), device=x.device, dtype=x.dtype)
|
||||
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, guidance, guiding_frame_index, ref_latent, control=control, transformer_options=transformer_options)
|
||||
return out
|
||||
|
||||
@@ -0,0 +1,136 @@
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
|
||||
class PixelShuffle2D(nn.Module):
|
||||
def __init__(self, in_dim, out_dim, op=ops.Conv2d):
|
||||
super().__init__()
|
||||
self.conv = op(in_dim, out_dim >> 2, 3, 1, 1)
|
||||
self.ratio = (in_dim << 2) // out_dim
|
||||
|
||||
def forward(self, x):
|
||||
b, c, h, w = x.shape
|
||||
h2, w2 = h >> 1, w >> 1
|
||||
y = self.conv(x).view(b, -1, h2, 2, w2, 2).permute(0, 3, 5, 1, 2, 4).reshape(b, -1, h2, w2)
|
||||
r = x.view(b, c, h2, 2, w2, 2).permute(0, 3, 5, 1, 2, 4).reshape(b, c << 2, h2, w2)
|
||||
return y + r.view(b, y.shape[1], self.ratio, h2, w2).mean(2)
|
||||
|
||||
|
||||
class PixelUnshuffle2D(nn.Module):
|
||||
def __init__(self, in_dim, out_dim, op=ops.Conv2d):
|
||||
super().__init__()
|
||||
self.conv = op(in_dim, out_dim << 2, 3, 1, 1)
|
||||
self.scale = (out_dim << 2) // in_dim
|
||||
|
||||
def forward(self, x):
|
||||
b, c, h, w = x.shape
|
||||
h2, w2 = h << 1, w << 1
|
||||
y = self.conv(x).view(b, 2, 2, -1, h, w).permute(0, 3, 4, 1, 5, 2).reshape(b, -1, h2, w2)
|
||||
r = x.repeat_interleave(self.scale, 1).view(b, 2, 2, -1, h, w).permute(0, 3, 4, 1, 5, 2).reshape(b, -1, h2, w2)
|
||||
return y + r
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
|
||||
ffactor_spatial, downsample_match_channel=True, **_):
|
||||
super().__init__()
|
||||
self.z_channels = z_channels
|
||||
self.block_out_channels = block_out_channels
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.conv_in = ops.Conv2d(in_channels, block_out_channels[0], 3, 1, 1)
|
||||
|
||||
self.down = nn.ModuleList()
|
||||
ch = block_out_channels[0]
|
||||
depth = (ffactor_spatial >> 1).bit_length()
|
||||
|
||||
for i, tgt in enumerate(block_out_channels):
|
||||
stage = nn.Module()
|
||||
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
|
||||
out_channels=tgt,
|
||||
temb_channels=0,
|
||||
conv_op=ops.Conv2d)
|
||||
for j in range(num_res_blocks)])
|
||||
ch = tgt
|
||||
if i < depth:
|
||||
nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and downsample_match_channel else ch
|
||||
stage.downsample = PixelShuffle2D(ch, nxt, ops.Conv2d)
|
||||
ch = nxt
|
||||
self.down.append(stage)
|
||||
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d)
|
||||
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv2d)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d)
|
||||
|
||||
self.norm_out = ops.GroupNorm(32, ch, 1e-6, True)
|
||||
self.conv_out = ops.Conv2d(ch, z_channels << 1, 3, 1, 1)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv_in(x)
|
||||
|
||||
for stage in self.down:
|
||||
for blk in stage.block:
|
||||
x = blk(x)
|
||||
if hasattr(stage, 'downsample'):
|
||||
x = stage.downsample(x)
|
||||
|
||||
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
|
||||
|
||||
b, c, h, w = x.shape
|
||||
grp = c // (self.z_channels << 1)
|
||||
skip = x.view(b, c // grp, grp, h, w).mean(2)
|
||||
|
||||
return self.conv_out(F.silu(self.norm_out(x))) + skip
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks,
|
||||
ffactor_spatial, upsample_match_channel=True, **_):
|
||||
super().__init__()
|
||||
block_out_channels = block_out_channels[::-1]
|
||||
self.z_channels = z_channels
|
||||
self.block_out_channels = block_out_channels
|
||||
self.num_res_blocks = num_res_blocks
|
||||
|
||||
ch = block_out_channels[0]
|
||||
self.conv_in = ops.Conv2d(z_channels, ch, 3, 1, 1)
|
||||
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d)
|
||||
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv2d)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d)
|
||||
|
||||
self.up = nn.ModuleList()
|
||||
depth = (ffactor_spatial >> 1).bit_length()
|
||||
|
||||
for i, tgt in enumerate(block_out_channels):
|
||||
stage = nn.Module()
|
||||
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
|
||||
out_channels=tgt,
|
||||
temb_channels=0,
|
||||
conv_op=ops.Conv2d)
|
||||
for j in range(num_res_blocks + 1)])
|
||||
ch = tgt
|
||||
if i < depth:
|
||||
nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and upsample_match_channel else ch
|
||||
stage.upsample = PixelUnshuffle2D(ch, nxt, ops.Conv2d)
|
||||
ch = nxt
|
||||
self.up.append(stage)
|
||||
|
||||
self.norm_out = ops.GroupNorm(32, ch, 1e-6, True)
|
||||
self.conv_out = ops.Conv2d(ch, out_channels, 3, 1, 1)
|
||||
|
||||
def forward(self, z):
|
||||
x = self.conv_in(z) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
|
||||
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
|
||||
|
||||
for stage in self.up:
|
||||
for blk in stage.block:
|
||||
x = blk(x)
|
||||
if hasattr(stage, 'upsample'):
|
||||
x = stage.upsample(x)
|
||||
|
||||
return self.conv_out(F.silu(self.norm_out(x)))
|
||||
@@ -1,5 +1,6 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import comfy.patcher_extension
|
||||
import comfy.ldm.modules.attention
|
||||
import comfy.ldm.common_dit
|
||||
from einops import rearrange
|
||||
@@ -420,6 +421,13 @@ class LTXVModel(torch.nn.Module):
|
||||
self.patchifier = SymmetricPatchifier(1)
|
||||
|
||||
def forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, attention_mask, frame_rate, transformer_options, keyframe_idxs, **kwargs)
|
||||
|
||||
def _forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs):
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
|
||||
orig_shape = list(x.shape)
|
||||
|
||||
@@ -11,6 +11,7 @@ import comfy.ldm.common_dit
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder
|
||||
from comfy.ldm.modules.attention import optimized_attention_masked
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
import comfy.patcher_extension
|
||||
|
||||
|
||||
def modulate(x, scale):
|
||||
@@ -590,8 +591,15 @@ class NextDiT(nn.Module):
|
||||
|
||||
return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis
|
||||
|
||||
# def forward(self, x, t, cap_feats, cap_mask):
|
||||
def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
|
||||
).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs)
|
||||
|
||||
# def forward(self, x, t, cap_feats, cap_mask):
|
||||
def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
|
||||
t = 1.0 - timesteps
|
||||
cap_feats = context
|
||||
cap_mask = attention_mask
|
||||
|
||||
@@ -109,7 +109,7 @@ class PatchEmbed(nn.Module):
|
||||
def modulate(x, shift, scale):
|
||||
if shift is None:
|
||||
shift = torch.zeros_like(scale)
|
||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
return torch.addcmul(shift.unsqueeze(1), x, 1+ scale.unsqueeze(1))
|
||||
|
||||
|
||||
#################################################################################
|
||||
@@ -564,10 +564,7 @@ class DismantledBlock(nn.Module):
|
||||
assert not self.pre_only
|
||||
attn1 = self.attn.post_attention(attn)
|
||||
attn2 = self.attn2.post_attention(attn2)
|
||||
out1 = gate_msa.unsqueeze(1) * attn1
|
||||
out2 = gate_msa2.unsqueeze(1) * attn2
|
||||
x = x + out1
|
||||
x = x + out2
|
||||
x = gate_cat(x, gate_msa, gate_msa2, attn1, attn2)
|
||||
x = x + gate_mlp.unsqueeze(1) * self.mlp(
|
||||
modulate(self.norm2(x), shift_mlp, scale_mlp)
|
||||
)
|
||||
@@ -594,6 +591,11 @@ class DismantledBlock(nn.Module):
|
||||
)
|
||||
return self.post_attention(attn, *intermediates)
|
||||
|
||||
def gate_cat(x, gate_msa, gate_msa2, attn1, attn2):
|
||||
out1 = gate_msa.unsqueeze(1) * attn1
|
||||
out2 = gate_msa2.unsqueeze(1) * attn2
|
||||
x = torch.stack([x, out1, out2], dim=0).sum(dim=0)
|
||||
return x
|
||||
|
||||
def block_mixing(*args, use_checkpoint=True, **kwargs):
|
||||
if use_checkpoint:
|
||||
|
||||
@@ -145,7 +145,7 @@ class Downsample(nn.Module):
|
||||
|
||||
class ResnetBlock(nn.Module):
|
||||
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
|
||||
dropout, temb_channels=512, conv_op=ops.Conv2d):
|
||||
dropout=0.0, temb_channels=512, conv_op=ops.Conv2d):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
@@ -183,7 +183,7 @@ class ResnetBlock(nn.Module):
|
||||
stride=1,
|
||||
padding=0)
|
||||
|
||||
def forward(self, x, temb):
|
||||
def forward(self, x, temb=None):
|
||||
h = x
|
||||
h = self.norm1(h)
|
||||
h = self.swish(h)
|
||||
|
||||
@@ -0,0 +1,77 @@
|
||||
import torch
|
||||
import math
|
||||
|
||||
from .model import QwenImageTransformer2DModel
|
||||
|
||||
|
||||
class QwenImageControlNetModel(QwenImageTransformer2DModel):
|
||||
def __init__(
|
||||
self,
|
||||
extra_condition_channels=0,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
|
||||
self.main_model_double = 60
|
||||
|
||||
# controlnet_blocks
|
||||
self.controlnet_blocks = torch.nn.ModuleList([])
|
||||
for _ in range(len(self.transformer_blocks)):
|
||||
self.controlnet_blocks.append(operations.Linear(self.inner_dim, self.inner_dim, device=device, dtype=dtype))
|
||||
self.controlnet_x_embedder = operations.Linear(self.in_channels + extra_condition_channels, self.inner_dim, device=device, dtype=dtype)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
timesteps,
|
||||
context,
|
||||
attention_mask=None,
|
||||
guidance: torch.Tensor = None,
|
||||
ref_latents=None,
|
||||
hint=None,
|
||||
transformer_options={},
|
||||
**kwargs
|
||||
):
|
||||
timestep = timesteps
|
||||
encoder_hidden_states = context
|
||||
encoder_hidden_states_mask = attention_mask
|
||||
|
||||
hidden_states, img_ids, orig_shape = self.process_img(x)
|
||||
hint, _, _ = self.process_img(hint)
|
||||
|
||||
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
|
||||
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
|
||||
del ids, txt_ids, img_ids
|
||||
|
||||
hidden_states = self.img_in(hidden_states) + self.controlnet_x_embedder(hint)
|
||||
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
|
||||
encoder_hidden_states = self.txt_in(encoder_hidden_states)
|
||||
|
||||
if guidance is not None:
|
||||
guidance = guidance * 1000
|
||||
|
||||
temb = (
|
||||
self.time_text_embed(timestep, hidden_states)
|
||||
if guidance is None
|
||||
else self.time_text_embed(timestep, guidance, hidden_states)
|
||||
)
|
||||
|
||||
repeat = math.ceil(self.main_model_double / len(self.controlnet_blocks))
|
||||
|
||||
controlnet_block_samples = ()
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
|
||||
controlnet_block_samples = controlnet_block_samples + (self.controlnet_blocks[i](hidden_states),) * repeat
|
||||
|
||||
return {"input": controlnet_block_samples[:self.main_model_double]}
|
||||
@@ -9,6 +9,7 @@ from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
|
||||
from comfy.ldm.modules.attention import optimized_attention_masked
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.patcher_extension
|
||||
|
||||
class GELU(nn.Module):
|
||||
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None):
|
||||
@@ -214,9 +215,9 @@ class QwenImageTransformerBlock(nn.Module):
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
def _modulate(self, x, mod_params):
|
||||
shift, scale, gate = mod_params.chunk(3, dim=-1)
|
||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)
|
||||
def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
shift, scale, gate = torch.chunk(mod_params, 3, dim=-1)
|
||||
return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)), gate.unsqueeze(1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -248,11 +249,11 @@ class QwenImageTransformerBlock(nn.Module):
|
||||
|
||||
img_normed2 = self.img_norm2(hidden_states)
|
||||
img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
|
||||
hidden_states = hidden_states + img_gate2 * self.img_mlp(img_modulated2)
|
||||
hidden_states = torch.addcmul(hidden_states, img_gate2, self.img_mlp(img_modulated2))
|
||||
|
||||
txt_normed2 = self.txt_norm2(encoder_hidden_states)
|
||||
txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2)
|
||||
encoder_hidden_states = encoder_hidden_states + txt_gate2 * self.txt_mlp(txt_modulated2)
|
||||
encoder_hidden_states = torch.addcmul(encoder_hidden_states, txt_gate2, self.txt_mlp(txt_modulated2))
|
||||
|
||||
return encoder_hidden_states, hidden_states
|
||||
|
||||
@@ -275,7 +276,7 @@ class LastLayer(nn.Module):
|
||||
def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
|
||||
emb = self.linear(self.silu(conditioning_embedding))
|
||||
scale, shift = torch.chunk(emb, 2, dim=1)
|
||||
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
|
||||
x = torch.addcmul(shift[:, None, :], self.norm(x), (1 + scale)[:, None, :])
|
||||
return x
|
||||
|
||||
|
||||
@@ -293,6 +294,7 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
guidance_embeds: bool = False,
|
||||
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
|
||||
image_model=None,
|
||||
final_layer=True,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
@@ -300,6 +302,7 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
self.patch_size = patch_size
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels or in_channels
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
@@ -329,9 +332,9 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
|
||||
self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations)
|
||||
self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device)
|
||||
self.gradient_checkpointing = False
|
||||
if final_layer:
|
||||
self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations)
|
||||
self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def process_img(self, x, index=0, h_offset=0, w_offset=0):
|
||||
bs, c, t, h, w = x.shape
|
||||
@@ -353,7 +356,14 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) - (w_len // 2)
|
||||
return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape
|
||||
|
||||
def forward(
|
||||
def forward(self, x, timestep, context, attention_mask=None, guidance=None, ref_latents=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, attention_mask, guidance, ref_latents, transformer_options, **kwargs)
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
x,
|
||||
timesteps,
|
||||
@@ -362,6 +372,7 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
guidance: torch.Tensor = None,
|
||||
ref_latents=None,
|
||||
transformer_options={},
|
||||
control=None,
|
||||
**kwargs
|
||||
):
|
||||
timestep = timesteps
|
||||
@@ -416,6 +427,7 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
)
|
||||
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
patches = transformer_options.get("patches", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
@@ -436,6 +448,19 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
|
||||
if "double_block" in patches:
|
||||
for p in patches["double_block"]:
|
||||
out = p({"img": hidden_states, "txt": encoder_hidden_states, "x": x, "block_index": i})
|
||||
hidden_states = out["img"]
|
||||
encoder_hidden_states = out["txt"]
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_i = control.get("input")
|
||||
if i < len(control_i):
|
||||
add = control_i[i]
|
||||
if add is not None:
|
||||
hidden_states[:, :add.shape[1]] += add
|
||||
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
|
||||
+509
-21
@@ -4,13 +4,14 @@ import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import repeat
|
||||
from einops import rearrange
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
from comfy.ldm.flux.math import apply_rope
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
|
||||
|
||||
def sinusoidal_embedding_1d(dim, position):
|
||||
@@ -148,11 +149,14 @@ WAN_CROSSATTENTION_CLASSES = {
|
||||
|
||||
def repeat_e(e, x):
|
||||
repeats = 1
|
||||
if e.shape[1] > 1:
|
||||
repeats = x.shape[1] // e.shape[1]
|
||||
if e.size(1) > 1:
|
||||
repeats = x.size(1) // e.size(1)
|
||||
if repeats == 1:
|
||||
return e
|
||||
return torch.repeat_interleave(e, repeats, dim=1)
|
||||
if repeats * e.size(1) == x.size(1):
|
||||
return torch.repeat_interleave(e, repeats, dim=1)
|
||||
else:
|
||||
return torch.repeat_interleave(e, repeats + 1, dim=1)[:, :x.size(1)]
|
||||
|
||||
|
||||
class WanAttentionBlock(nn.Module):
|
||||
@@ -219,15 +223,15 @@ class WanAttentionBlock(nn.Module):
|
||||
|
||||
# self-attention
|
||||
y = self.self_attn(
|
||||
self.norm1(x) * (1 + repeat_e(e[1], x)) + repeat_e(e[0], x),
|
||||
torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
|
||||
freqs)
|
||||
|
||||
x = x + y * repeat_e(e[2], x)
|
||||
x = torch.addcmul(x, y, repeat_e(e[2], x))
|
||||
|
||||
# cross-attention & ffn
|
||||
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len)
|
||||
y = self.ffn(self.norm2(x) * (1 + repeat_e(e[4], x)) + repeat_e(e[3], x))
|
||||
x = x + y * repeat_e(e[5], x)
|
||||
y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
|
||||
x = torch.addcmul(x, y, repeat_e(e[5], x))
|
||||
return x
|
||||
|
||||
|
||||
@@ -342,7 +346,7 @@ class Head(nn.Module):
|
||||
else:
|
||||
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e.unsqueeze(2)).unbind(2)
|
||||
|
||||
x = (self.head(self.norm(x) * (1 + repeat_e(e[1], x)) + repeat_e(e[0], x)))
|
||||
x = (self.head(torch.addcmul(repeat_e(e[0], x), self.norm(x), 1 + repeat_e(e[1], x))))
|
||||
return x
|
||||
|
||||
|
||||
@@ -572,30 +576,49 @@ class WanModel(torch.nn.Module):
|
||||
x = self.unpatchify(x, grid_sizes)
|
||||
return x
|
||||
|
||||
def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
|
||||
bs, c, t, h, w = x.shape
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
|
||||
|
||||
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None):
|
||||
patch_size = self.patch_size
|
||||
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
|
||||
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
|
||||
w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
|
||||
|
||||
if steps_t is None:
|
||||
steps_t = t_len
|
||||
if steps_h is None:
|
||||
steps_h = h_len
|
||||
if steps_w is None:
|
||||
steps_w = w_len
|
||||
|
||||
img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype)
|
||||
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1)
|
||||
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1)
|
||||
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1)
|
||||
img_ids = img_ids.reshape(1, -1, img_ids.shape[-1])
|
||||
|
||||
freqs = self.rope_embedder(img_ids).movedim(1, 2)
|
||||
return freqs
|
||||
|
||||
def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, clip_fea, time_dim_concat, transformer_options, **kwargs)
|
||||
|
||||
def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
|
||||
bs, c, t, h, w = x.shape
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
|
||||
|
||||
t_len = t
|
||||
if time_dim_concat is not None:
|
||||
time_dim_concat = comfy.ldm.common_dit.pad_to_patch_size(time_dim_concat, self.patch_size)
|
||||
x = torch.cat([x, time_dim_concat], dim=2)
|
||||
t_len = ((x.shape[2] + (patch_size[0] // 2)) // patch_size[0])
|
||||
t_len = x.shape[2]
|
||||
|
||||
if self.ref_conv is not None and "reference_latent" in kwargs:
|
||||
t_len += 1
|
||||
|
||||
img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
||||
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
|
||||
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
|
||||
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
|
||||
img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
|
||||
|
||||
freqs = self.rope_embedder(img_ids).movedim(1, 2)
|
||||
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype)
|
||||
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, **kwargs)[:, :, :t, :h, :w]
|
||||
|
||||
def unpatchify(self, x, grid_sizes):
|
||||
@@ -831,3 +854,468 @@ class CameraWanModel(WanModel):
|
||||
# unpatchify
|
||||
x = self.unpatchify(x, grid_sizes)
|
||||
return x
|
||||
|
||||
|
||||
class CausalConv1d(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
chan_in,
|
||||
chan_out,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
dilation=1,
|
||||
pad_mode='replicate',
|
||||
operations=None,
|
||||
**kwargs):
|
||||
super().__init__()
|
||||
|
||||
self.pad_mode = pad_mode
|
||||
padding = (kernel_size - 1, 0) # T
|
||||
self.time_causal_padding = padding
|
||||
|
||||
self.conv = operations.Conv1d(
|
||||
chan_in,
|
||||
chan_out,
|
||||
kernel_size,
|
||||
stride=stride,
|
||||
dilation=dilation,
|
||||
**kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
x = torch.nn.functional.pad(x, self.time_causal_padding, mode=self.pad_mode)
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class MotionEncoder_tc(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
in_dim: int,
|
||||
hidden_dim: int,
|
||||
num_heads=int,
|
||||
need_global=True,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,):
|
||||
factory_kwargs = {"dtype": dtype, "device": device}
|
||||
super().__init__()
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.need_global = need_global
|
||||
self.conv1_local = CausalConv1d(in_dim, hidden_dim // 4 * num_heads, 3, stride=1, operations=operations, **factory_kwargs)
|
||||
if need_global:
|
||||
self.conv1_global = CausalConv1d(
|
||||
in_dim, hidden_dim // 4, 3, stride=1, operations=operations, **factory_kwargs)
|
||||
self.norm1 = operations.LayerNorm(
|
||||
hidden_dim // 4,
|
||||
elementwise_affine=False,
|
||||
eps=1e-6,
|
||||
**factory_kwargs)
|
||||
self.act = nn.SiLU()
|
||||
self.conv2 = CausalConv1d(hidden_dim // 4, hidden_dim // 2, 3, stride=2, operations=operations, **factory_kwargs)
|
||||
self.conv3 = CausalConv1d(hidden_dim // 2, hidden_dim, 3, stride=2, operations=operations, **factory_kwargs)
|
||||
|
||||
if need_global:
|
||||
self.final_linear = operations.Linear(hidden_dim, hidden_dim, **factory_kwargs)
|
||||
|
||||
self.norm1 = operations.LayerNorm(
|
||||
hidden_dim // 4,
|
||||
elementwise_affine=False,
|
||||
eps=1e-6,
|
||||
**factory_kwargs)
|
||||
|
||||
self.norm2 = operations.LayerNorm(
|
||||
hidden_dim // 2,
|
||||
elementwise_affine=False,
|
||||
eps=1e-6,
|
||||
**factory_kwargs)
|
||||
|
||||
self.norm3 = operations.LayerNorm(
|
||||
hidden_dim, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
||||
|
||||
self.padding_tokens = nn.Parameter(torch.empty(1, 1, 1, hidden_dim, **factory_kwargs))
|
||||
|
||||
def forward(self, x):
|
||||
x = rearrange(x, 'b t c -> b c t')
|
||||
x_ori = x.clone()
|
||||
b, c, t = x.shape
|
||||
x = self.conv1_local(x)
|
||||
x = rearrange(x, 'b (n c) t -> (b n) t c', n=self.num_heads)
|
||||
x = self.norm1(x)
|
||||
x = self.act(x)
|
||||
x = rearrange(x, 'b t c -> b c t')
|
||||
x = self.conv2(x)
|
||||
x = rearrange(x, 'b c t -> b t c')
|
||||
x = self.norm2(x)
|
||||
x = self.act(x)
|
||||
x = rearrange(x, 'b t c -> b c t')
|
||||
x = self.conv3(x)
|
||||
x = rearrange(x, 'b c t -> b t c')
|
||||
x = self.norm3(x)
|
||||
x = self.act(x)
|
||||
x = rearrange(x, '(b n) t c -> b t n c', b=b)
|
||||
padding = comfy.model_management.cast_to(self.padding_tokens, dtype=x.dtype, device=x.device).repeat(b, x.shape[1], 1, 1)
|
||||
x = torch.cat([x, padding], dim=-2)
|
||||
x_local = x.clone()
|
||||
|
||||
if not self.need_global:
|
||||
return x_local
|
||||
|
||||
x = self.conv1_global(x_ori)
|
||||
x = rearrange(x, 'b c t -> b t c')
|
||||
x = self.norm1(x)
|
||||
x = self.act(x)
|
||||
x = rearrange(x, 'b t c -> b c t')
|
||||
x = self.conv2(x)
|
||||
x = rearrange(x, 'b c t -> b t c')
|
||||
x = self.norm2(x)
|
||||
x = self.act(x)
|
||||
x = rearrange(x, 'b t c -> b c t')
|
||||
x = self.conv3(x)
|
||||
x = rearrange(x, 'b c t -> b t c')
|
||||
x = self.norm3(x)
|
||||
x = self.act(x)
|
||||
x = self.final_linear(x)
|
||||
x = rearrange(x, '(b n) t c -> b t n c', b=b)
|
||||
|
||||
return x, x_local
|
||||
|
||||
|
||||
class CausalAudioEncoder(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
dim=5120,
|
||||
num_layers=25,
|
||||
out_dim=2048,
|
||||
video_rate=8,
|
||||
num_token=4,
|
||||
need_global=False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None):
|
||||
super().__init__()
|
||||
self.encoder = MotionEncoder_tc(
|
||||
in_dim=dim,
|
||||
hidden_dim=out_dim,
|
||||
num_heads=num_token,
|
||||
need_global=need_global, dtype=dtype, device=device, operations=operations)
|
||||
weight = torch.empty((1, num_layers, 1, 1), dtype=dtype, device=device)
|
||||
|
||||
self.weights = torch.nn.Parameter(weight)
|
||||
self.act = torch.nn.SiLU()
|
||||
|
||||
def forward(self, features):
|
||||
# features B * num_layers * dim * video_length
|
||||
weights = self.act(comfy.model_management.cast_to(self.weights, dtype=features.dtype, device=features.device))
|
||||
weights_sum = weights.sum(dim=1, keepdims=True)
|
||||
weighted_feat = ((features * weights) / weights_sum).sum(
|
||||
dim=1) # b dim f
|
||||
weighted_feat = weighted_feat.permute(0, 2, 1) # b f dim
|
||||
res = self.encoder(weighted_feat) # b f n dim
|
||||
return res # b f n dim
|
||||
|
||||
|
||||
class AdaLayerNorm(nn.Module):
|
||||
def __init__(self, embedding_dim, output_dim=None, norm_elementwise_affine=False, norm_eps=1e-5, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
output_dim = output_dim or embedding_dim * 2
|
||||
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = operations.Linear(embedding_dim, output_dim, dtype=dtype, device=device)
|
||||
self.norm = operations.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x, temb):
|
||||
temb = self.linear(self.silu(temb))
|
||||
shift, scale = temb.chunk(2, dim=1)
|
||||
shift = shift[:, None, :]
|
||||
scale = scale[:, None, :]
|
||||
x = self.norm(x) * (1 + scale) + shift
|
||||
return x
|
||||
|
||||
|
||||
class AudioInjector_WAN(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
dim=2048,
|
||||
num_heads=32,
|
||||
inject_layer=[0, 27],
|
||||
root_net=None,
|
||||
enable_adain=False,
|
||||
adain_dim=2048,
|
||||
adain_mode=None,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None):
|
||||
super().__init__()
|
||||
self.enable_adain = enable_adain
|
||||
self.adain_mode = adain_mode
|
||||
self.injected_block_id = {}
|
||||
audio_injector_id = 0
|
||||
for inject_id in inject_layer:
|
||||
self.injected_block_id[inject_id] = audio_injector_id
|
||||
audio_injector_id += 1
|
||||
|
||||
self.injector = nn.ModuleList([
|
||||
WanT2VCrossAttention(
|
||||
dim=dim,
|
||||
num_heads=num_heads,
|
||||
qk_norm=True, operation_settings={"operations": operations, "device": device, "dtype": dtype}
|
||||
) for _ in range(audio_injector_id)
|
||||
])
|
||||
self.injector_pre_norm_feat = nn.ModuleList([
|
||||
operations.LayerNorm(
|
||||
dim,
|
||||
elementwise_affine=False,
|
||||
eps=1e-6, dtype=dtype, device=device
|
||||
) for _ in range(audio_injector_id)
|
||||
])
|
||||
self.injector_pre_norm_vec = nn.ModuleList([
|
||||
operations.LayerNorm(
|
||||
dim,
|
||||
elementwise_affine=False,
|
||||
eps=1e-6, dtype=dtype, device=device
|
||||
) for _ in range(audio_injector_id)
|
||||
])
|
||||
if enable_adain:
|
||||
self.injector_adain_layers = nn.ModuleList([
|
||||
AdaLayerNorm(
|
||||
output_dim=dim * 2, embedding_dim=adain_dim, dtype=dtype, device=device, operations=operations)
|
||||
for _ in range(audio_injector_id)
|
||||
])
|
||||
if adain_mode != "attn_norm":
|
||||
self.injector_adain_output_layers = nn.ModuleList(
|
||||
[operations.Linear(dim, dim, dtype=dtype, device=device) for _ in range(audio_injector_id)])
|
||||
|
||||
def forward(self, x, block_id, audio_emb, audio_emb_global, seq_len):
|
||||
audio_attn_id = self.injected_block_id.get(block_id, None)
|
||||
if audio_attn_id is None:
|
||||
return x
|
||||
|
||||
num_frames = audio_emb.shape[1]
|
||||
input_hidden_states = rearrange(x[:, :seq_len], "b (t n) c -> (b t) n c", t=num_frames)
|
||||
if self.enable_adain and self.adain_mode == "attn_norm":
|
||||
audio_emb_global = rearrange(audio_emb_global, "b t n c -> (b t) n c")
|
||||
adain_hidden_states = self.injector_adain_layers[audio_attn_id](input_hidden_states, temb=audio_emb_global[:, 0])
|
||||
attn_hidden_states = adain_hidden_states
|
||||
else:
|
||||
attn_hidden_states = self.injector_pre_norm_feat[audio_attn_id](input_hidden_states)
|
||||
audio_emb = rearrange(audio_emb, "b t n c -> (b t) n c", t=num_frames)
|
||||
attn_audio_emb = audio_emb
|
||||
residual_out = self.injector[audio_attn_id](x=attn_hidden_states, context=attn_audio_emb)
|
||||
residual_out = rearrange(
|
||||
residual_out, "(b t) n c -> b (t n) c", t=num_frames)
|
||||
x[:, :seq_len] = x[:, :seq_len] + residual_out
|
||||
return x
|
||||
|
||||
|
||||
class FramePackMotioner(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
inner_dim=1024,
|
||||
num_heads=16, # Used to indicate the number of heads in the backbone network; unrelated to this module's design
|
||||
zip_frame_buckets=[
|
||||
1, 2, 16
|
||||
], # Three numbers representing the number of frames sampled for patch operations from the nearest to the farthest frames
|
||||
drop_mode="drop", # If not "drop", it will use "padd", meaning padding instead of deletion
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None):
|
||||
super().__init__()
|
||||
self.proj = operations.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2), dtype=dtype, device=device)
|
||||
self.proj_2x = operations.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4), dtype=dtype, device=device)
|
||||
self.proj_4x = operations.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8), dtype=dtype, device=device)
|
||||
self.zip_frame_buckets = zip_frame_buckets
|
||||
|
||||
self.inner_dim = inner_dim
|
||||
self.num_heads = num_heads
|
||||
|
||||
self.drop_mode = drop_mode
|
||||
|
||||
def forward(self, motion_latents, rope_embedder, add_last_motion=2):
|
||||
lat_height, lat_width = motion_latents.shape[3], motion_latents.shape[4]
|
||||
padd_lat = torch.zeros(motion_latents.shape[0], 16, sum(self.zip_frame_buckets), lat_height, lat_width).to(device=motion_latents.device, dtype=motion_latents.dtype)
|
||||
overlap_frame = min(padd_lat.shape[2], motion_latents.shape[2])
|
||||
if overlap_frame > 0:
|
||||
padd_lat[:, :, -overlap_frame:] = motion_latents[:, :, -overlap_frame:]
|
||||
|
||||
if add_last_motion < 2 and self.drop_mode != "drop":
|
||||
zero_end_frame = sum(self.zip_frame_buckets[:len(self.zip_frame_buckets) - add_last_motion - 1])
|
||||
padd_lat[:, :, -zero_end_frame:] = 0
|
||||
|
||||
clean_latents_4x, clean_latents_2x, clean_latents_post = padd_lat[:, :, -sum(self.zip_frame_buckets):, :, :].split(self.zip_frame_buckets[::-1], dim=2) # 16, 2 ,1
|
||||
|
||||
# patchfy
|
||||
clean_latents_post = self.proj(clean_latents_post).flatten(2).transpose(1, 2)
|
||||
clean_latents_2x = self.proj_2x(clean_latents_2x)
|
||||
l_2x_shape = clean_latents_2x.shape
|
||||
clean_latents_2x = clean_latents_2x.flatten(2).transpose(1, 2)
|
||||
clean_latents_4x = self.proj_4x(clean_latents_4x)
|
||||
l_4x_shape = clean_latents_4x.shape
|
||||
clean_latents_4x = clean_latents_4x.flatten(2).transpose(1, 2)
|
||||
|
||||
if add_last_motion < 2 and self.drop_mode == "drop":
|
||||
clean_latents_post = clean_latents_post[:, :
|
||||
0] if add_last_motion < 2 else clean_latents_post
|
||||
clean_latents_2x = clean_latents_2x[:, :
|
||||
0] if add_last_motion < 1 else clean_latents_2x
|
||||
|
||||
motion_lat = torch.cat([clean_latents_post, clean_latents_2x, clean_latents_4x], dim=1)
|
||||
|
||||
rope_post = rope_embedder.rope_encode(1, lat_height, lat_width, t_start=-1, device=motion_latents.device, dtype=motion_latents.dtype)
|
||||
rope_2x = rope_embedder.rope_encode(1, lat_height, lat_width, t_start=-3, steps_h=l_2x_shape[-2], steps_w=l_2x_shape[-1], device=motion_latents.device, dtype=motion_latents.dtype)
|
||||
rope_4x = rope_embedder.rope_encode(4, lat_height, lat_width, t_start=-19, steps_h=l_4x_shape[-2], steps_w=l_4x_shape[-1], device=motion_latents.device, dtype=motion_latents.dtype)
|
||||
|
||||
rope = torch.cat([rope_post, rope_2x, rope_4x], dim=1)
|
||||
return motion_lat, rope
|
||||
|
||||
|
||||
class WanModel_S2V(WanModel):
|
||||
def __init__(self,
|
||||
model_type='s2v',
|
||||
patch_size=(1, 2, 2),
|
||||
text_len=512,
|
||||
in_dim=16,
|
||||
dim=2048,
|
||||
ffn_dim=8192,
|
||||
freq_dim=256,
|
||||
text_dim=4096,
|
||||
out_dim=16,
|
||||
num_heads=16,
|
||||
num_layers=32,
|
||||
window_size=(-1, -1),
|
||||
qk_norm=True,
|
||||
cross_attn_norm=True,
|
||||
eps=1e-6,
|
||||
audio_dim=1024,
|
||||
num_audio_token=4,
|
||||
enable_adain=True,
|
||||
cond_dim=16,
|
||||
audio_inject_layers=[0, 4, 8, 12, 16, 20, 24, 27, 30, 33, 36, 39],
|
||||
adain_mode="attn_norm",
|
||||
framepack_drop_mode="padd",
|
||||
image_model=None,
|
||||
device=None,
|
||||
dtype=None,
|
||||
operations=None,
|
||||
):
|
||||
|
||||
super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, image_model=image_model, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
self.trainable_cond_mask = operations.Embedding(3, self.dim, device=device, dtype=dtype)
|
||||
|
||||
self.casual_audio_encoder = CausalAudioEncoder(
|
||||
dim=audio_dim,
|
||||
out_dim=self.dim,
|
||||
num_token=num_audio_token,
|
||||
need_global=enable_adain, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
if cond_dim > 0:
|
||||
self.cond_encoder = operations.Conv3d(
|
||||
cond_dim,
|
||||
self.dim,
|
||||
kernel_size=self.patch_size,
|
||||
stride=self.patch_size, device=device, dtype=dtype)
|
||||
|
||||
self.audio_injector = AudioInjector_WAN(
|
||||
dim=self.dim,
|
||||
num_heads=self.num_heads,
|
||||
inject_layer=audio_inject_layers,
|
||||
root_net=self,
|
||||
enable_adain=enable_adain,
|
||||
adain_dim=self.dim,
|
||||
adain_mode=adain_mode,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
self.frame_packer = FramePackMotioner(
|
||||
inner_dim=self.dim,
|
||||
num_heads=self.num_heads,
|
||||
zip_frame_buckets=[1, 2, 16],
|
||||
drop_mode=framepack_drop_mode,
|
||||
dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward_orig(
|
||||
self,
|
||||
x,
|
||||
t,
|
||||
context,
|
||||
audio_embed=None,
|
||||
reference_latent=None,
|
||||
control_video=None,
|
||||
reference_motion=None,
|
||||
clip_fea=None,
|
||||
freqs=None,
|
||||
transformer_options={},
|
||||
**kwargs,
|
||||
):
|
||||
if audio_embed is not None:
|
||||
num_embeds = x.shape[-3] * 4
|
||||
audio_emb_global, audio_emb = self.casual_audio_encoder(audio_embed[:, :, :, :num_embeds])
|
||||
else:
|
||||
audio_emb = None
|
||||
|
||||
# embeddings
|
||||
bs, _, time, height, width = x.shape
|
||||
x = self.patch_embedding(x.float()).to(x.dtype)
|
||||
if control_video is not None:
|
||||
x = x + self.cond_encoder(control_video)
|
||||
|
||||
if t.ndim == 1:
|
||||
t = t.unsqueeze(1).repeat(1, x.shape[2])
|
||||
|
||||
grid_sizes = x.shape[2:]
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
seq_len = x.size(1)
|
||||
|
||||
cond_mask_weight = comfy.model_management.cast_to(self.trainable_cond_mask.weight, dtype=x.dtype, device=x.device).unsqueeze(1).unsqueeze(1)
|
||||
x = x + cond_mask_weight[0]
|
||||
|
||||
if reference_latent is not None:
|
||||
ref = self.patch_embedding(reference_latent.float()).to(x.dtype)
|
||||
ref = ref.flatten(2).transpose(1, 2)
|
||||
freqs_ref = self.rope_encode(reference_latent.shape[-3], reference_latent.shape[-2], reference_latent.shape[-1], t_start=max(30, time + 9), device=x.device, dtype=x.dtype)
|
||||
ref = ref + cond_mask_weight[1]
|
||||
x = torch.cat([x, ref], dim=1)
|
||||
freqs = torch.cat([freqs, freqs_ref], dim=1)
|
||||
t = torch.cat([t, torch.zeros((t.shape[0], reference_latent.shape[-3]), device=t.device, dtype=t.dtype)], dim=1)
|
||||
del ref, freqs_ref
|
||||
|
||||
if reference_motion is not None:
|
||||
motion_encoded, freqs_motion = self.frame_packer(reference_motion, self)
|
||||
motion_encoded = motion_encoded + cond_mask_weight[2]
|
||||
x = torch.cat([x, motion_encoded], dim=1)
|
||||
freqs = torch.cat([freqs, freqs_motion], dim=1)
|
||||
|
||||
t = torch.repeat_interleave(t, 2, dim=1)
|
||||
t = torch.cat([t, torch.zeros((t.shape[0], 3), device=t.device, dtype=t.dtype)], dim=1)
|
||||
del motion_encoded, freqs_motion
|
||||
|
||||
# time embeddings
|
||||
e = self.time_embedding(
|
||||
sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype))
|
||||
e = e.reshape(t.shape[0], -1, e.shape[-1])
|
||||
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
|
||||
|
||||
# context
|
||||
context = self.text_embedding(context)
|
||||
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
for i, block in enumerate(self.blocks):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"])
|
||||
return out
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
|
||||
x = out["img"]
|
||||
else:
|
||||
x = block(x, e=e0, freqs=freqs, context=context)
|
||||
if audio_emb is not None:
|
||||
x = self.audio_injector(x, i, audio_emb, audio_emb_global, seq_len)
|
||||
# head
|
||||
x = self.head(x, e)
|
||||
|
||||
# unpatchify
|
||||
x = self.unpatchify(x, grid_sizes)
|
||||
return x
|
||||
|
||||
@@ -260,6 +260,10 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
key_map["transformer.{}".format(k[:-len(".weight")])] = to #simpletrainer and probably regular diffusers flux lora format
|
||||
key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
|
||||
key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer
|
||||
for k in sdk:
|
||||
hidden_size = model.model_config.unet_config.get("hidden_size", 0)
|
||||
if k.endswith(".weight") and ".linear1." in k:
|
||||
key_map["{}".format(k.replace(".linear1.weight", ".linear1_qkv"))] = (k, (0, 0, hidden_size * 3))
|
||||
|
||||
if isinstance(model, comfy.model_base.GenmoMochi):
|
||||
for k in sdk:
|
||||
|
||||
@@ -15,10 +15,29 @@ def convert_lora_bfl_control(sd): #BFL loras for Flux
|
||||
def convert_lora_wan_fun(sd): #Wan Fun loras
|
||||
return comfy.utils.state_dict_prefix_replace(sd, {"lora_unet__": "lora_unet_"})
|
||||
|
||||
def convert_uso_lora(sd):
|
||||
sd_out = {}
|
||||
for k in sd:
|
||||
tensor = sd[k]
|
||||
k_to = "diffusion_model.{}".format(k.replace(".down.weight", ".lora_down.weight")
|
||||
.replace(".up.weight", ".lora_up.weight")
|
||||
.replace(".qkv_lora2.", ".txt_attn.qkv.")
|
||||
.replace(".qkv_lora1.", ".img_attn.qkv.")
|
||||
.replace(".proj_lora1.", ".img_attn.proj.")
|
||||
.replace(".proj_lora2.", ".txt_attn.proj.")
|
||||
.replace(".qkv_lora.", ".linear1_qkv.")
|
||||
.replace(".proj_lora.", ".linear2.")
|
||||
.replace(".processor.", ".")
|
||||
)
|
||||
sd_out[k_to] = tensor
|
||||
return sd_out
|
||||
|
||||
|
||||
def convert_lora(sd):
|
||||
if "img_in.lora_A.weight" in sd and "single_blocks.0.norm.key_norm.scale" in sd:
|
||||
return convert_lora_bfl_control(sd)
|
||||
if "lora_unet__blocks_0_cross_attn_k.lora_down.weight" in sd:
|
||||
return convert_lora_wan_fun(sd)
|
||||
if "single_blocks.37.processor.qkv_lora.up.weight" in sd and "double_blocks.18.processor.qkv_lora2.up.weight" in sd:
|
||||
return convert_uso_lora(sd)
|
||||
return sd
|
||||
|
||||
+93
-11
@@ -16,6 +16,8 @@
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
import comfy.ldm.hunyuan3dv2_1
|
||||
import comfy.ldm.hunyuan3dv2_1.hunyuandit
|
||||
import torch
|
||||
import logging
|
||||
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
|
||||
@@ -150,6 +152,7 @@ class BaseModel(torch.nn.Module):
|
||||
logging.debug("adm {}".format(self.adm_channels))
|
||||
self.memory_usage_factor = model_config.memory_usage_factor
|
||||
self.memory_usage_factor_conds = ()
|
||||
self.memory_usage_shape_process = {}
|
||||
|
||||
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
@@ -350,8 +353,15 @@ class BaseModel(torch.nn.Module):
|
||||
input_shapes = [input_shape]
|
||||
for c in self.memory_usage_factor_conds:
|
||||
shape = cond_shapes.get(c, None)
|
||||
if shape is not None and len(shape) > 0:
|
||||
input_shapes += shape
|
||||
if shape is not None:
|
||||
if c in self.memory_usage_shape_process:
|
||||
out = []
|
||||
for s in shape:
|
||||
out.append(self.memory_usage_shape_process[c](s))
|
||||
shape = out
|
||||
|
||||
if len(shape) > 0:
|
||||
input_shapes += shape
|
||||
|
||||
if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
|
||||
dtype = self.get_dtype()
|
||||
@@ -1102,9 +1112,10 @@ class WAN21(BaseModel):
|
||||
shape_image[1] = extra_channels
|
||||
image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device)
|
||||
else:
|
||||
latent_dim = self.latent_format.latent_channels
|
||||
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||
for i in range(0, image.shape[1], 16):
|
||||
image[:, i: i + 16] = self.process_latent_in(image[:, i: i + 16])
|
||||
for i in range(0, image.shape[1], latent_dim):
|
||||
image[:, i: i + latent_dim] = self.process_latent_in(image[:, i: i + latent_dim])
|
||||
image = utils.resize_to_batch_size(image, noise.shape[0])
|
||||
|
||||
if extra_channels != image.shape[1] + 4:
|
||||
@@ -1201,18 +1212,50 @@ class WAN21_Camera(WAN21):
|
||||
out['camera_conditions'] = comfy.conds.CONDRegular(camera_conditions)
|
||||
return out
|
||||
|
||||
class WAN22(BaseModel):
|
||||
class WAN22_S2V(WAN21):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel_S2V)
|
||||
self.memory_usage_factor_conds = ("reference_latent", "reference_motion")
|
||||
self.memory_usage_shape_process = {"reference_motion": lambda shape: [shape[0], shape[1], 1.5, shape[-2], shape[-1]]}
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
audio_embed = kwargs.get("audio_embed", None)
|
||||
if audio_embed is not None:
|
||||
out['audio_embed'] = comfy.conds.CONDRegular(audio_embed)
|
||||
|
||||
reference_latents = kwargs.get("reference_latents", None)
|
||||
if reference_latents is not None:
|
||||
out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1]))
|
||||
|
||||
reference_motion = kwargs.get("reference_motion", None)
|
||||
if reference_motion is not None:
|
||||
out['reference_motion'] = comfy.conds.CONDRegular(self.process_latent_in(reference_motion))
|
||||
|
||||
control_video = kwargs.get("control_video", None)
|
||||
if control_video is not None:
|
||||
out['control_video'] = comfy.conds.CONDRegular(self.process_latent_in(control_video))
|
||||
return out
|
||||
|
||||
def extra_conds_shapes(self, **kwargs):
|
||||
out = {}
|
||||
ref_latents = kwargs.get("reference_latents", None)
|
||||
if ref_latents is not None:
|
||||
out['reference_latent'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
|
||||
|
||||
reference_motion = kwargs.get("reference_motion", None)
|
||||
if reference_motion is not None:
|
||||
out['reference_motion'] = reference_motion.shape
|
||||
return out
|
||||
|
||||
class WAN22(WAN21):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
|
||||
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
|
||||
self.image_to_video = image_to_video
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
|
||||
denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
|
||||
denoise_mask = kwargs.get("denoise_mask", None)
|
||||
if denoise_mask is not None:
|
||||
out["denoise_mask"] = comfy.conds.CONDRegular(denoise_mask)
|
||||
return out
|
||||
@@ -1241,6 +1284,21 @@ class Hunyuan3Dv2(BaseModel):
|
||||
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
|
||||
return out
|
||||
|
||||
class Hunyuan3Dv2_1(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3dv2_1.hunyuandit.HunYuanDiTPlain)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
|
||||
guidance = kwargs.get("guidance", 5.0)
|
||||
if guidance is not None:
|
||||
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
|
||||
return out
|
||||
|
||||
class HiDream(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hidream.model.HiDreamImageTransformer2DModel)
|
||||
@@ -1350,3 +1408,27 @@ class QwenImage(BaseModel):
|
||||
if ref_latents is not None:
|
||||
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
|
||||
return out
|
||||
|
||||
class HunyuanImage21(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
if attention_mask is not None:
|
||||
if torch.numel(attention_mask) != attention_mask.sum():
|
||||
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
|
||||
conditioning_byt5small = kwargs.get("conditioning_byt5small", None)
|
||||
if conditioning_byt5small is not None:
|
||||
out['txt_byt5'] = comfy.conds.CONDRegular(conditioning_byt5small)
|
||||
|
||||
guidance = kwargs.get("guidance", 6.0)
|
||||
if guidance is not None:
|
||||
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
|
||||
|
||||
return out
|
||||
|
||||
@@ -136,20 +136,32 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
|
||||
if '{}txt_in.individual_token_refiner.blocks.0.norm1.weight'.format(key_prefix) in state_dict_keys: #Hunyuan Video
|
||||
dit_config = {}
|
||||
in_w = state_dict['{}img_in.proj.weight'.format(key_prefix)]
|
||||
out_w = state_dict['{}final_layer.linear.weight'.format(key_prefix)]
|
||||
dit_config["image_model"] = "hunyuan_video"
|
||||
dit_config["in_channels"] = state_dict['{}img_in.proj.weight'.format(key_prefix)].shape[1] #SkyReels img2video has 32 input channels
|
||||
dit_config["patch_size"] = [1, 2, 2]
|
||||
dit_config["out_channels"] = 16
|
||||
dit_config["vec_in_dim"] = 768
|
||||
dit_config["context_in_dim"] = 4096
|
||||
dit_config["hidden_size"] = 3072
|
||||
dit_config["in_channels"] = in_w.shape[1] #SkyReels img2video has 32 input channels
|
||||
dit_config["patch_size"] = list(in_w.shape[2:])
|
||||
dit_config["out_channels"] = out_w.shape[0] // math.prod(dit_config["patch_size"])
|
||||
if '{}vector_in.in_layer.weight'.format(key_prefix) in state_dict:
|
||||
dit_config["vec_in_dim"] = 768
|
||||
dit_config["axes_dim"] = [16, 56, 56]
|
||||
else:
|
||||
dit_config["vec_in_dim"] = None
|
||||
dit_config["axes_dim"] = [64, 64]
|
||||
|
||||
dit_config["context_in_dim"] = state_dict['{}txt_in.input_embedder.weight'.format(key_prefix)].shape[1]
|
||||
dit_config["hidden_size"] = in_w.shape[0]
|
||||
dit_config["mlp_ratio"] = 4.0
|
||||
dit_config["num_heads"] = 24
|
||||
dit_config["num_heads"] = in_w.shape[0] // 128
|
||||
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
|
||||
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
|
||||
dit_config["axes_dim"] = [16, 56, 56]
|
||||
dit_config["theta"] = 256
|
||||
dit_config["qkv_bias"] = True
|
||||
if '{}byt5_in.fc1.weight'.format(key_prefix) in state_dict:
|
||||
dit_config["byt5"] = True
|
||||
else:
|
||||
dit_config["byt5"] = False
|
||||
|
||||
guidance_keys = list(filter(lambda a: a.startswith("{}guidance_in.".format(key_prefix)), state_dict_keys))
|
||||
dit_config["guidance_embed"] = len(guidance_keys) > 0
|
||||
return dit_config
|
||||
@@ -368,6 +380,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["model_type"] = "camera"
|
||||
else:
|
||||
dit_config["model_type"] = "camera_2.2"
|
||||
elif '{}casual_audio_encoder.encoder.final_linear.weight'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["model_type"] = "s2v"
|
||||
else:
|
||||
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["model_type"] = "i2v"
|
||||
@@ -398,6 +412,20 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
|
||||
return dit_config
|
||||
|
||||
if f"{key_prefix}t_embedder.mlp.2.weight" in state_dict_keys: # Hunyuan 3D 2.1
|
||||
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "hunyuan3d2_1"
|
||||
dit_config["in_channels"] = state_dict[f"{key_prefix}x_embedder.weight"].shape[1]
|
||||
dit_config["context_dim"] = 1024
|
||||
dit_config["hidden_size"] = state_dict[f"{key_prefix}x_embedder.weight"].shape[0]
|
||||
dit_config["mlp_ratio"] = 4.0
|
||||
dit_config["num_heads"] = 16
|
||||
dit_config["depth"] = count_blocks(state_dict_keys, f"{key_prefix}blocks.{{}}")
|
||||
dit_config["qkv_bias"] = False
|
||||
dit_config["guidance_cond_proj_dim"] = None#f"{key_prefix}t_embedder.cond_proj.weight" in state_dict_keys
|
||||
return dit_config
|
||||
|
||||
if '{}caption_projection.0.linear.weight'.format(key_prefix) in state_dict_keys: # HiDream
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "hidream"
|
||||
@@ -492,6 +520,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
if '{}txt_norm.weight'.format(key_prefix) in state_dict_keys: # Qwen Image
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "qwen_image"
|
||||
dit_config["in_channels"] = state_dict['{}img_in.weight'.format(key_prefix)].shape[1]
|
||||
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
|
||||
return dit_config
|
||||
|
||||
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
|
||||
|
||||
@@ -22,6 +22,7 @@ from enum import Enum
|
||||
from comfy.cli_args import args, PerformanceFeature
|
||||
import torch
|
||||
import sys
|
||||
import importlib
|
||||
import platform
|
||||
import weakref
|
||||
import gc
|
||||
@@ -289,6 +290,24 @@ def is_amd():
|
||||
return True
|
||||
return False
|
||||
|
||||
def amd_min_version(device=None, min_rdna_version=0):
|
||||
if not is_amd():
|
||||
return False
|
||||
|
||||
if is_device_cpu(device):
|
||||
return False
|
||||
|
||||
arch = torch.cuda.get_device_properties(device).gcnArchName
|
||||
if arch.startswith('gfx') and len(arch) == 7:
|
||||
try:
|
||||
cmp_rdna_version = int(arch[4]) + 2
|
||||
except:
|
||||
cmp_rdna_version = 0
|
||||
if cmp_rdna_version >= min_rdna_version:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
MIN_WEIGHT_MEMORY_RATIO = 0.4
|
||||
if is_nvidia():
|
||||
MIN_WEIGHT_MEMORY_RATIO = 0.0
|
||||
@@ -321,12 +340,13 @@ try:
|
||||
logging.info("AMD arch: {}".format(arch))
|
||||
logging.info("ROCm version: {}".format(rocm_version))
|
||||
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
||||
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
|
||||
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
|
||||
ENABLE_PYTORCH_ATTENTION = True
|
||||
# if torch_version_numeric >= (2, 8):
|
||||
# if any((a in arch) for a in ["gfx1201"]):
|
||||
# ENABLE_PYTORCH_ATTENTION = True
|
||||
if importlib.util.find_spec('triton') is not None: # AMD efficient attention implementation depends on triton. TODO: better way of detecting if it's compiled in or not.
|
||||
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
|
||||
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
|
||||
ENABLE_PYTORCH_ATTENTION = True
|
||||
# if torch_version_numeric >= (2, 8):
|
||||
# if any((a in arch) for a in ["gfx1201"]):
|
||||
# ENABLE_PYTORCH_ATTENTION = True
|
||||
if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
|
||||
if any((a in arch) for a in ["gfx1201", "gfx942", "gfx950"]): # TODO: more arches
|
||||
SUPPORT_FP8_OPS = True
|
||||
@@ -593,7 +613,13 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
||||
else:
|
||||
minimum_memory_required = max(inference_memory, minimum_memory_required + extra_reserved_memory())
|
||||
|
||||
models = set(models)
|
||||
models_temp = set()
|
||||
for m in models:
|
||||
models_temp.add(m)
|
||||
for mm in m.model_patches_models():
|
||||
models_temp.add(mm)
|
||||
|
||||
models = models_temp
|
||||
|
||||
models_to_load = []
|
||||
|
||||
@@ -899,7 +925,9 @@ def vae_dtype(device=None, allowed_dtypes=[]):
|
||||
|
||||
# NOTE: bfloat16 seems to work on AMD for the VAE but is extremely slow in some cases compared to fp32
|
||||
# slowness still a problem on pytorch nightly 2.9.0.dev20250720+rocm6.4 tested on RDNA3
|
||||
if d == torch.bfloat16 and (not is_amd()) and should_use_bf16(device):
|
||||
# also a problem on RDNA4 except fp32 is also slow there.
|
||||
# This is due to large bf16 convolutions being extremely slow.
|
||||
if d == torch.bfloat16 and ((not is_amd()) or amd_min_version(device, min_rdna_version=4)) and should_use_bf16(device):
|
||||
return d
|
||||
|
||||
return torch.float32
|
||||
|
||||
@@ -430,6 +430,12 @@ class ModelPatcher:
|
||||
def set_model_forward_timestep_embed_patch(self, patch):
|
||||
self.set_model_patch(patch, "forward_timestep_embed_patch")
|
||||
|
||||
def set_model_double_block_patch(self, patch):
|
||||
self.set_model_patch(patch, "double_block")
|
||||
|
||||
def set_model_post_input_patch(self, patch):
|
||||
self.set_model_patch(patch, "post_input")
|
||||
|
||||
def add_object_patch(self, name, obj):
|
||||
self.object_patches[name] = obj
|
||||
|
||||
@@ -486,6 +492,30 @@ class ModelPatcher:
|
||||
if hasattr(wrap_func, "to"):
|
||||
self.model_options["model_function_wrapper"] = wrap_func.to(device)
|
||||
|
||||
def model_patches_models(self):
|
||||
to = self.model_options["transformer_options"]
|
||||
models = []
|
||||
if "patches" in to:
|
||||
patches = to["patches"]
|
||||
for name in patches:
|
||||
patch_list = patches[name]
|
||||
for i in range(len(patch_list)):
|
||||
if hasattr(patch_list[i], "models"):
|
||||
models += patch_list[i].models()
|
||||
if "patches_replace" in to:
|
||||
patches = to["patches_replace"]
|
||||
for name in patches:
|
||||
patch_list = patches[name]
|
||||
for k in patch_list:
|
||||
if hasattr(patch_list[k], "models"):
|
||||
models += patch_list[k].models()
|
||||
if "model_function_wrapper" in self.model_options:
|
||||
wrap_func = self.model_options["model_function_wrapper"]
|
||||
if hasattr(wrap_func, "models"):
|
||||
models += wrap_func.models()
|
||||
|
||||
return models
|
||||
|
||||
def model_dtype(self):
|
||||
if hasattr(self.model, "get_dtype"):
|
||||
return self.model.get_dtype()
|
||||
|
||||
@@ -52,6 +52,9 @@ except (ModuleNotFoundError, TypeError):
|
||||
|
||||
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
|
||||
|
||||
if torch.cuda.is_available() and torch.backends.cudnn.is_available() and PerformanceFeature.AutoTune in args.fast:
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
def cast_to_input(weight, input, non_blocking=False, copy=True):
|
||||
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
|
||||
|
||||
|
||||
@@ -50,6 +50,7 @@ class WrappersMP:
|
||||
OUTER_SAMPLE = "outer_sample"
|
||||
PREPARE_SAMPLING = "prepare_sampling"
|
||||
SAMPLER_SAMPLE = "sampler_sample"
|
||||
PREDICT_NOISE = "predict_noise"
|
||||
CALC_COND_BATCH = "calc_cond_batch"
|
||||
APPLY_MODEL = "apply_model"
|
||||
DIFFUSION_MODEL = "diffusion_model"
|
||||
|
||||
Regular → Executable
+16
-5
@@ -17,6 +17,7 @@ import comfy.model_patcher
|
||||
import comfy.patcher_extension
|
||||
import comfy.hooks
|
||||
import comfy.context_windows
|
||||
import comfy.utils
|
||||
import scipy.stats
|
||||
import numpy
|
||||
|
||||
@@ -61,7 +62,7 @@ def get_area_and_mult(conds, x_in, timestep_in):
|
||||
if "mask_strength" in conds:
|
||||
mask_strength = conds["mask_strength"]
|
||||
mask = conds['mask']
|
||||
assert (mask.shape[1:] == x_in.shape[2:])
|
||||
# assert (mask.shape[1:] == x_in.shape[2:])
|
||||
|
||||
mask = mask[:input_x.shape[0]]
|
||||
if area is not None:
|
||||
@@ -69,7 +70,7 @@ def get_area_and_mult(conds, x_in, timestep_in):
|
||||
mask = mask.narrow(i + 1, area[len(dims) + i], area[i])
|
||||
|
||||
mask = mask * mask_strength
|
||||
mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1)
|
||||
mask = mask.unsqueeze(1).repeat((input_x.shape[0] // mask.shape[0], input_x.shape[1]) + (1, ) * (mask.ndim - 1))
|
||||
else:
|
||||
mask = torch.ones_like(input_x)
|
||||
mult = mask * strength
|
||||
@@ -553,7 +554,10 @@ def resolve_areas_and_cond_masks_multidim(conditions, dims, device):
|
||||
if len(mask.shape) == len(dims):
|
||||
mask = mask.unsqueeze(0)
|
||||
if mask.shape[1:] != dims:
|
||||
mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=dims, mode='bilinear', align_corners=False).squeeze(1)
|
||||
if mask.ndim < 4:
|
||||
mask = comfy.utils.common_upscale(mask.unsqueeze(1), dims[-1], dims[-2], 'bilinear', 'none').squeeze(1)
|
||||
else:
|
||||
mask = comfy.utils.common_upscale(mask, dims[-1], dims[-2], 'bilinear', 'none')
|
||||
|
||||
if modified.get("set_area_to_bounds", False): #TODO: handle dim != 2
|
||||
bounds = torch.max(torch.abs(mask),dim=0).values.unsqueeze(0)
|
||||
@@ -725,7 +729,7 @@ class Sampler:
|
||||
|
||||
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
|
||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
|
||||
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
|
||||
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_2m_sde_heun", "dpmpp_2m_sde_heun_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
|
||||
"ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
|
||||
"gradient_estimation", "gradient_estimation_cfg_pp", "er_sde", "seeds_2", "seeds_3", "sa_solver", "sa_solver_pece"]
|
||||
|
||||
@@ -953,7 +957,14 @@ class CFGGuider:
|
||||
self.original_conds[k] = comfy.sampler_helpers.convert_cond(conds[k])
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.predict_noise(*args, **kwargs)
|
||||
return self.outer_predict_noise(*args, **kwargs)
|
||||
|
||||
def outer_predict_noise(self, x, timestep, model_options={}, seed=None):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self.predict_noise,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREDICT_NOISE, self.model_options, is_model_options=True)
|
||||
).execute(x, timestep, model_options, seed)
|
||||
|
||||
def predict_noise(self, x, timestep, model_options={}, seed=None):
|
||||
return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed)
|
||||
|
||||
+49
-10
@@ -17,6 +17,7 @@ import comfy.ldm.wan.vae
|
||||
import comfy.ldm.wan.vae2_2
|
||||
import comfy.ldm.hunyuan3d.vae
|
||||
import comfy.ldm.ace.vae.music_dcae_pipeline
|
||||
import comfy.ldm.hunyuan_video.vae
|
||||
import yaml
|
||||
import math
|
||||
import os
|
||||
@@ -48,6 +49,7 @@ import comfy.text_encoders.hidream
|
||||
import comfy.text_encoders.ace
|
||||
import comfy.text_encoders.omnigen2
|
||||
import comfy.text_encoders.qwen_image
|
||||
import comfy.text_encoders.hunyuan_image
|
||||
|
||||
import comfy.model_patcher
|
||||
import comfy.lora
|
||||
@@ -328,6 +330,19 @@ class VAE:
|
||||
self.first_stage_model = StageC_coder()
|
||||
self.downscale_ratio = 32
|
||||
self.latent_channels = 16
|
||||
elif "decoder.conv_in.weight" in sd and sd['decoder.conv_in.weight'].shape[1] == 64:
|
||||
ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}
|
||||
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
|
||||
self.downscale_ratio = 32
|
||||
self.upscale_ratio = 32
|
||||
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
|
||||
encoder_config={'target': "comfy.ldm.hunyuan_video.vae.Encoder", 'params': ddconfig},
|
||||
decoder_config={'target': "comfy.ldm.hunyuan_video.vae.Decoder", 'params': ddconfig})
|
||||
|
||||
self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
|
||||
self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype)
|
||||
|
||||
elif "decoder.conv_in.weight" in sd:
|
||||
#default SD1.x/SD2.x VAE parameters
|
||||
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
|
||||
@@ -446,17 +461,29 @@ class VAE:
|
||||
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype)
|
||||
self.memory_used_decode = lambda shape, dtype: 7000 * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype)
|
||||
# Hunyuan 3d v2 2.0 & 2.1
|
||||
elif "geo_decoder.cross_attn_decoder.ln_1.bias" in sd:
|
||||
|
||||
self.latent_dim = 1
|
||||
ln_post = "geo_decoder.ln_post.weight" in sd
|
||||
inner_size = sd["geo_decoder.output_proj.weight"].shape[1]
|
||||
downsample_ratio = sd["post_kl.weight"].shape[0] // inner_size
|
||||
mlp_expand = sd["geo_decoder.cross_attn_decoder.mlp.c_fc.weight"].shape[0] // inner_size
|
||||
self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype) # TODO
|
||||
self.memory_used_decode = lambda shape, dtype: (1024 * 1024 * 1024 * 2.0) * model_management.dtype_size(dtype) # TODO
|
||||
ddconfig = {"embed_dim": 64, "num_freqs": 8, "include_pi": False, "heads": 16, "width": 1024, "num_decoder_layers": 16, "qkv_bias": False, "qk_norm": True, "geo_decoder_mlp_expand_ratio": mlp_expand, "geo_decoder_downsample_ratio": downsample_ratio, "geo_decoder_ln_post": ln_post}
|
||||
self.first_stage_model = comfy.ldm.hunyuan3d.vae.ShapeVAE(**ddconfig)
|
||||
|
||||
def estimate_memory(shape, dtype, num_layers = 16, kv_cache_multiplier = 2):
|
||||
batch, num_tokens, hidden_dim = shape
|
||||
dtype_size = model_management.dtype_size(dtype)
|
||||
|
||||
total_mem = batch * num_tokens * hidden_dim * dtype_size * (1 + kv_cache_multiplier * num_layers)
|
||||
return total_mem
|
||||
|
||||
# better memory estimations
|
||||
self.memory_used_encode = lambda shape, dtype, num_layers = 8, kv_cache_multiplier = 0:\
|
||||
estimate_memory(shape, dtype, num_layers, kv_cache_multiplier)
|
||||
|
||||
self.memory_used_decode = lambda shape, dtype, num_layers = 16, kv_cache_multiplier = 2: \
|
||||
estimate_memory(shape, dtype, num_layers, kv_cache_multiplier)
|
||||
|
||||
self.first_stage_model = comfy.ldm.hunyuan3d.vae.ShapeVAE()
|
||||
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
|
||||
|
||||
elif "vocoder.backbone.channel_layers.0.0.bias" in sd: #Ace Step Audio
|
||||
self.first_stage_model = comfy.ldm.ace.vae.music_dcae_pipeline.MusicDCAE(source_sample_rate=44100)
|
||||
self.memory_used_encode = lambda shape, dtype: (shape[2] * 330) * model_management.dtype_size(dtype)
|
||||
@@ -773,6 +800,7 @@ class CLIPType(Enum):
|
||||
ACE = 16
|
||||
OMNIGEN2 = 17
|
||||
QWEN_IMAGE = 18
|
||||
HUNYUAN_IMAGE = 19
|
||||
|
||||
|
||||
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
||||
@@ -794,6 +822,7 @@ class TEModel(Enum):
|
||||
GEMMA_2_2B = 9
|
||||
QWEN25_3B = 10
|
||||
QWEN25_7B = 11
|
||||
BYT5_SMALL_GLYPH = 12
|
||||
|
||||
def detect_te_model(sd):
|
||||
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
|
||||
@@ -811,6 +840,9 @@ def detect_te_model(sd):
|
||||
if 'encoder.block.23.layer.1.DenseReluDense.wi.weight' in sd:
|
||||
return TEModel.T5_XXL_OLD
|
||||
if "encoder.block.0.layer.0.SelfAttention.k.weight" in sd:
|
||||
weight = sd['encoder.block.0.layer.0.SelfAttention.k.weight']
|
||||
if weight.shape[0] == 384:
|
||||
return TEModel.BYT5_SMALL_GLYPH
|
||||
return TEModel.T5_BASE
|
||||
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
|
||||
return TEModel.GEMMA_2_2B
|
||||
@@ -925,8 +957,12 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
clip_target.clip = comfy.text_encoders.omnigen2.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.omnigen2.Omnigen2Tokenizer
|
||||
elif te_model == TEModel.QWEN25_7B:
|
||||
clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer
|
||||
if clip_type == CLIPType.HUNYUAN_IMAGE:
|
||||
clip_target.clip = comfy.text_encoders.hunyuan_image.te(byt5=False, **llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer
|
||||
else:
|
||||
clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer
|
||||
else:
|
||||
# clip_l
|
||||
if clip_type == CLIPType.SD3:
|
||||
@@ -970,6 +1006,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
|
||||
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, llama=llama, **t5_kwargs, **llama_kwargs)
|
||||
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
|
||||
elif clip_type == CLIPType.HUNYUAN_IMAGE:
|
||||
clip_target.clip = comfy.text_encoders.hunyuan_image.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer
|
||||
else:
|
||||
clip_target.clip = sdxl_clip.SDXLClipModel
|
||||
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
||||
|
||||
@@ -20,6 +20,7 @@ import comfy.text_encoders.wan
|
||||
import comfy.text_encoders.ace
|
||||
import comfy.text_encoders.omnigen2
|
||||
import comfy.text_encoders.qwen_image
|
||||
import comfy.text_encoders.hunyuan_image
|
||||
|
||||
from . import supported_models_base
|
||||
from . import latent_formats
|
||||
@@ -700,7 +701,7 @@ class Flux(supported_models_base.BASE):
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.Flux
|
||||
|
||||
memory_usage_factor = 2.8
|
||||
memory_usage_factor = 3.1 # TODO: debug why flux mem usage is so weird on windows.
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
|
||||
@@ -1072,6 +1073,19 @@ class WAN21_Vace(WAN21_T2V):
|
||||
out = model_base.WAN21_Vace(self, image_to_video=False, device=device)
|
||||
return out
|
||||
|
||||
class WAN22_S2V(WAN21_T2V):
|
||||
unet_config = {
|
||||
"image_model": "wan2.1",
|
||||
"model_type": "s2v",
|
||||
}
|
||||
|
||||
def __init__(self, unet_config):
|
||||
super().__init__(unet_config)
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.WAN22_S2V(self, device=device)
|
||||
return out
|
||||
|
||||
class WAN22_T2V(WAN21_T2V):
|
||||
unet_config = {
|
||||
"image_model": "wan2.1",
|
||||
@@ -1115,6 +1129,17 @@ class Hunyuan3Dv2(supported_models_base.BASE):
|
||||
def clip_target(self, state_dict={}):
|
||||
return None
|
||||
|
||||
class Hunyuan3Dv2_1(Hunyuan3Dv2):
|
||||
unet_config = {
|
||||
"image_model": "hunyuan3d2_1",
|
||||
}
|
||||
|
||||
latent_format = latent_formats.Hunyuan3Dv2_1
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.Hunyuan3Dv2_1(self, device = device)
|
||||
return out
|
||||
|
||||
class Hunyuan3Dv2mini(Hunyuan3Dv2):
|
||||
unet_config = {
|
||||
"image_model": "hunyuan3d2",
|
||||
@@ -1271,7 +1296,31 @@ class QwenImage(supported_models_base.BASE):
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.qwen_image.QwenImageTokenizer, comfy.text_encoders.qwen_image.te(**hunyuan_detect))
|
||||
|
||||
class HunyuanImage21(HunyuanVideo):
|
||||
unet_config = {
|
||||
"image_model": "hunyuan_video",
|
||||
"vec_in_dim": None,
|
||||
}
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2, QwenImage]
|
||||
sampling_settings = {
|
||||
"shift": 5.0,
|
||||
}
|
||||
|
||||
latent_format = latent_formats.HunyuanImage21
|
||||
|
||||
memory_usage_factor = 7.7
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.HunyuanImage21(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect))
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ACEStep, Omnigen2, QwenImage]
|
||||
|
||||
models += [SVD_img2vid]
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
{
|
||||
"d_ff": 3584,
|
||||
"d_kv": 64,
|
||||
"d_model": 1472,
|
||||
"decoder_start_token_id": 0,
|
||||
"dropout_rate": 0.1,
|
||||
"eos_token_id": 1,
|
||||
"dense_act_fn": "gelu_pytorch_tanh",
|
||||
"initializer_factor": 1.0,
|
||||
"is_encoder_decoder": true,
|
||||
"is_gated_act": true,
|
||||
"layer_norm_epsilon": 1e-06,
|
||||
"model_type": "t5",
|
||||
"num_decoder_layers": 4,
|
||||
"num_heads": 6,
|
||||
"num_layers": 12,
|
||||
"output_past": true,
|
||||
"pad_token_id": 0,
|
||||
"relative_attention_num_buckets": 32,
|
||||
"tie_word_embeddings": false,
|
||||
"vocab_size": 1510
|
||||
}
|
||||
@@ -0,0 +1,127 @@
|
||||
{
|
||||
"<extra_id_0>": 259,
|
||||
"<extra_id_100>": 359,
|
||||
"<extra_id_101>": 360,
|
||||
"<extra_id_102>": 361,
|
||||
"<extra_id_103>": 362,
|
||||
"<extra_id_104>": 363,
|
||||
"<extra_id_105>": 364,
|
||||
"<extra_id_106>": 365,
|
||||
"<extra_id_107>": 366,
|
||||
"<extra_id_108>": 367,
|
||||
"<extra_id_109>": 368,
|
||||
"<extra_id_10>": 269,
|
||||
"<extra_id_110>": 369,
|
||||
"<extra_id_111>": 370,
|
||||
"<extra_id_112>": 371,
|
||||
"<extra_id_113>": 372,
|
||||
"<extra_id_114>": 373,
|
||||
"<extra_id_115>": 374,
|
||||
"<extra_id_116>": 375,
|
||||
"<extra_id_117>": 376,
|
||||
"<extra_id_118>": 377,
|
||||
"<extra_id_119>": 378,
|
||||
"<extra_id_11>": 270,
|
||||
"<extra_id_120>": 379,
|
||||
"<extra_id_121>": 380,
|
||||
"<extra_id_122>": 381,
|
||||
"<extra_id_123>": 382,
|
||||
"<extra_id_124>": 383,
|
||||
"<extra_id_12>": 271,
|
||||
"<extra_id_13>": 272,
|
||||
"<extra_id_14>": 273,
|
||||
"<extra_id_15>": 274,
|
||||
"<extra_id_16>": 275,
|
||||
"<extra_id_17>": 276,
|
||||
"<extra_id_18>": 277,
|
||||
"<extra_id_19>": 278,
|
||||
"<extra_id_1>": 260,
|
||||
"<extra_id_20>": 279,
|
||||
"<extra_id_21>": 280,
|
||||
"<extra_id_22>": 281,
|
||||
"<extra_id_23>": 282,
|
||||
"<extra_id_24>": 283,
|
||||
"<extra_id_25>": 284,
|
||||
"<extra_id_26>": 285,
|
||||
"<extra_id_27>": 286,
|
||||
"<extra_id_28>": 287,
|
||||
"<extra_id_29>": 288,
|
||||
"<extra_id_2>": 261,
|
||||
"<extra_id_30>": 289,
|
||||
"<extra_id_31>": 290,
|
||||
"<extra_id_32>": 291,
|
||||
"<extra_id_33>": 292,
|
||||
"<extra_id_34>": 293,
|
||||
"<extra_id_35>": 294,
|
||||
"<extra_id_36>": 295,
|
||||
"<extra_id_37>": 296,
|
||||
"<extra_id_38>": 297,
|
||||
"<extra_id_39>": 298,
|
||||
"<extra_id_3>": 262,
|
||||
"<extra_id_40>": 299,
|
||||
"<extra_id_41>": 300,
|
||||
"<extra_id_42>": 301,
|
||||
"<extra_id_43>": 302,
|
||||
"<extra_id_44>": 303,
|
||||
"<extra_id_45>": 304,
|
||||
"<extra_id_46>": 305,
|
||||
"<extra_id_47>": 306,
|
||||
"<extra_id_48>": 307,
|
||||
"<extra_id_49>": 308,
|
||||
"<extra_id_4>": 263,
|
||||
"<extra_id_50>": 309,
|
||||
"<extra_id_51>": 310,
|
||||
"<extra_id_52>": 311,
|
||||
"<extra_id_53>": 312,
|
||||
"<extra_id_54>": 313,
|
||||
"<extra_id_55>": 314,
|
||||
"<extra_id_56>": 315,
|
||||
"<extra_id_57>": 316,
|
||||
"<extra_id_58>": 317,
|
||||
"<extra_id_59>": 318,
|
||||
"<extra_id_5>": 264,
|
||||
"<extra_id_60>": 319,
|
||||
"<extra_id_61>": 320,
|
||||
"<extra_id_62>": 321,
|
||||
"<extra_id_63>": 322,
|
||||
"<extra_id_64>": 323,
|
||||
"<extra_id_65>": 324,
|
||||
"<extra_id_66>": 325,
|
||||
"<extra_id_67>": 326,
|
||||
"<extra_id_68>": 327,
|
||||
"<extra_id_69>": 328,
|
||||
"<extra_id_6>": 265,
|
||||
"<extra_id_70>": 329,
|
||||
"<extra_id_71>": 330,
|
||||
"<extra_id_72>": 331,
|
||||
"<extra_id_73>": 332,
|
||||
"<extra_id_74>": 333,
|
||||
"<extra_id_75>": 334,
|
||||
"<extra_id_76>": 335,
|
||||
"<extra_id_77>": 336,
|
||||
"<extra_id_78>": 337,
|
||||
"<extra_id_79>": 338,
|
||||
"<extra_id_7>": 266,
|
||||
"<extra_id_80>": 339,
|
||||
"<extra_id_81>": 340,
|
||||
"<extra_id_82>": 341,
|
||||
"<extra_id_83>": 342,
|
||||
"<extra_id_84>": 343,
|
||||
"<extra_id_85>": 344,
|
||||
"<extra_id_86>": 345,
|
||||
"<extra_id_87>": 346,
|
||||
"<extra_id_88>": 347,
|
||||
"<extra_id_89>": 348,
|
||||
"<extra_id_8>": 267,
|
||||
"<extra_id_90>": 349,
|
||||
"<extra_id_91>": 350,
|
||||
"<extra_id_92>": 351,
|
||||
"<extra_id_93>": 352,
|
||||
"<extra_id_94>": 353,
|
||||
"<extra_id_95>": 354,
|
||||
"<extra_id_96>": 355,
|
||||
"<extra_id_97>": 356,
|
||||
"<extra_id_98>": 357,
|
||||
"<extra_id_99>": 358,
|
||||
"<extra_id_9>": 268
|
||||
}
|
||||
@@ -0,0 +1,150 @@
|
||||
{
|
||||
"additional_special_tokens": [
|
||||
"<extra_id_0>",
|
||||
"<extra_id_1>",
|
||||
"<extra_id_2>",
|
||||
"<extra_id_3>",
|
||||
"<extra_id_4>",
|
||||
"<extra_id_5>",
|
||||
"<extra_id_6>",
|
||||
"<extra_id_7>",
|
||||
"<extra_id_8>",
|
||||
"<extra_id_9>",
|
||||
"<extra_id_10>",
|
||||
"<extra_id_11>",
|
||||
"<extra_id_12>",
|
||||
"<extra_id_13>",
|
||||
"<extra_id_14>",
|
||||
"<extra_id_15>",
|
||||
"<extra_id_16>",
|
||||
"<extra_id_17>",
|
||||
"<extra_id_18>",
|
||||
"<extra_id_19>",
|
||||
"<extra_id_20>",
|
||||
"<extra_id_21>",
|
||||
"<extra_id_22>",
|
||||
"<extra_id_23>",
|
||||
"<extra_id_24>",
|
||||
"<extra_id_25>",
|
||||
"<extra_id_26>",
|
||||
"<extra_id_27>",
|
||||
"<extra_id_28>",
|
||||
"<extra_id_29>",
|
||||
"<extra_id_30>",
|
||||
"<extra_id_31>",
|
||||
"<extra_id_32>",
|
||||
"<extra_id_33>",
|
||||
"<extra_id_34>",
|
||||
"<extra_id_35>",
|
||||
"<extra_id_36>",
|
||||
"<extra_id_37>",
|
||||
"<extra_id_38>",
|
||||
"<extra_id_39>",
|
||||
"<extra_id_40>",
|
||||
"<extra_id_41>",
|
||||
"<extra_id_42>",
|
||||
"<extra_id_43>",
|
||||
"<extra_id_44>",
|
||||
"<extra_id_45>",
|
||||
"<extra_id_46>",
|
||||
"<extra_id_47>",
|
||||
"<extra_id_48>",
|
||||
"<extra_id_49>",
|
||||
"<extra_id_50>",
|
||||
"<extra_id_51>",
|
||||
"<extra_id_52>",
|
||||
"<extra_id_53>",
|
||||
"<extra_id_54>",
|
||||
"<extra_id_55>",
|
||||
"<extra_id_56>",
|
||||
"<extra_id_57>",
|
||||
"<extra_id_58>",
|
||||
"<extra_id_59>",
|
||||
"<extra_id_60>",
|
||||
"<extra_id_61>",
|
||||
"<extra_id_62>",
|
||||
"<extra_id_63>",
|
||||
"<extra_id_64>",
|
||||
"<extra_id_65>",
|
||||
"<extra_id_66>",
|
||||
"<extra_id_67>",
|
||||
"<extra_id_68>",
|
||||
"<extra_id_69>",
|
||||
"<extra_id_70>",
|
||||
"<extra_id_71>",
|
||||
"<extra_id_72>",
|
||||
"<extra_id_73>",
|
||||
"<extra_id_74>",
|
||||
"<extra_id_75>",
|
||||
"<extra_id_76>",
|
||||
"<extra_id_77>",
|
||||
"<extra_id_78>",
|
||||
"<extra_id_79>",
|
||||
"<extra_id_80>",
|
||||
"<extra_id_81>",
|
||||
"<extra_id_82>",
|
||||
"<extra_id_83>",
|
||||
"<extra_id_84>",
|
||||
"<extra_id_85>",
|
||||
"<extra_id_86>",
|
||||
"<extra_id_87>",
|
||||
"<extra_id_88>",
|
||||
"<extra_id_89>",
|
||||
"<extra_id_90>",
|
||||
"<extra_id_91>",
|
||||
"<extra_id_92>",
|
||||
"<extra_id_93>",
|
||||
"<extra_id_94>",
|
||||
"<extra_id_95>",
|
||||
"<extra_id_96>",
|
||||
"<extra_id_97>",
|
||||
"<extra_id_98>",
|
||||
"<extra_id_99>",
|
||||
"<extra_id_100>",
|
||||
"<extra_id_101>",
|
||||
"<extra_id_102>",
|
||||
"<extra_id_103>",
|
||||
"<extra_id_104>",
|
||||
"<extra_id_105>",
|
||||
"<extra_id_106>",
|
||||
"<extra_id_107>",
|
||||
"<extra_id_108>",
|
||||
"<extra_id_109>",
|
||||
"<extra_id_110>",
|
||||
"<extra_id_111>",
|
||||
"<extra_id_112>",
|
||||
"<extra_id_113>",
|
||||
"<extra_id_114>",
|
||||
"<extra_id_115>",
|
||||
"<extra_id_116>",
|
||||
"<extra_id_117>",
|
||||
"<extra_id_118>",
|
||||
"<extra_id_119>",
|
||||
"<extra_id_120>",
|
||||
"<extra_id_121>",
|
||||
"<extra_id_122>",
|
||||
"<extra_id_123>",
|
||||
"<extra_id_124>"
|
||||
],
|
||||
"eos_token": {
|
||||
"content": "</s>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"pad_token": {
|
||||
"content": "<pad>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"unk_token": {
|
||||
"content": "<unk>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,100 @@
|
||||
from comfy import sd1_clip
|
||||
import comfy.text_encoders.llama
|
||||
from .qwen_image import QwenImageTokenizer, QwenImageTEModel
|
||||
from transformers import ByT5Tokenizer
|
||||
import os
|
||||
import re
|
||||
|
||||
class ByT5SmallTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "byt5_tokenizer")
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=1472, embedding_key='byt5_small', tokenizer_class=ByT5Tokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_data=tokenizer_data)
|
||||
|
||||
class HunyuanImageTokenizer(QwenImageTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||
self.llama_template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>"
|
||||
# self.llama_template_images = "{}"
|
||||
self.byt5 = ByT5SmallTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
|
||||
out = super().tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||
|
||||
# ByT5 processing for HunyuanImage
|
||||
text_prompt_texts = []
|
||||
pattern_quote_single = r'\'(.*?)\''
|
||||
pattern_quote_double = r'\"(.*?)\"'
|
||||
pattern_quote_chinese_single = r'‘(.*?)’'
|
||||
pattern_quote_chinese_double = r'“(.*?)”'
|
||||
|
||||
matches_quote_single = re.findall(pattern_quote_single, text)
|
||||
matches_quote_double = re.findall(pattern_quote_double, text)
|
||||
matches_quote_chinese_single = re.findall(pattern_quote_chinese_single, text)
|
||||
matches_quote_chinese_double = re.findall(pattern_quote_chinese_double, text)
|
||||
|
||||
text_prompt_texts.extend(matches_quote_single)
|
||||
text_prompt_texts.extend(matches_quote_double)
|
||||
text_prompt_texts.extend(matches_quote_chinese_single)
|
||||
text_prompt_texts.extend(matches_quote_chinese_double)
|
||||
|
||||
if len(text_prompt_texts) > 0:
|
||||
out['byt5'] = self.byt5.tokenize_with_weights(''.join(map(lambda a: 'Text "{}". '.format(a), text_prompt_texts)), return_word_ids, **kwargs)
|
||||
return out
|
||||
|
||||
class Qwen25_7BVLIModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}):
|
||||
llama_scaled_fp8 = model_options.get("qwen_scaled_fp8", None)
|
||||
if llama_scaled_fp8 is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["scaled_fp8"] = llama_scaled_fp8
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||
|
||||
|
||||
class ByT5SmallModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "byt5_config_small_glyph.json")
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, model_options=model_options, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True)
|
||||
|
||||
|
||||
class HunyuanImageTEModel(QwenImageTEModel):
|
||||
def __init__(self, byt5=True, device="cpu", dtype=None, model_options={}):
|
||||
super(QwenImageTEModel, self).__init__(device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options)
|
||||
|
||||
if byt5:
|
||||
self.byt5_small = ByT5SmallModel(device=device, dtype=dtype, model_options=model_options)
|
||||
else:
|
||||
self.byt5_small = None
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
cond, p, extra = super().encode_token_weights(token_weight_pairs)
|
||||
if self.byt5_small is not None and "byt5" in token_weight_pairs:
|
||||
out = self.byt5_small.encode_token_weights(token_weight_pairs["byt5"])
|
||||
extra["conditioning_byt5small"] = out[0]
|
||||
return cond, p, extra
|
||||
|
||||
def set_clip_options(self, options):
|
||||
super().set_clip_options(options)
|
||||
if self.byt5_small is not None:
|
||||
self.byt5_small.set_clip_options(options)
|
||||
|
||||
def reset_clip_options(self):
|
||||
super().reset_clip_options()
|
||||
if self.byt5_small is not None:
|
||||
self.byt5_small.reset_clip_options()
|
||||
|
||||
def load_sd(self, sd):
|
||||
if "encoder.block.0.layer.0.SelfAttention.o.weight" in sd:
|
||||
return self.byt5_small.load_sd(sd)
|
||||
else:
|
||||
return super().load_sd(sd)
|
||||
|
||||
def te(byt5=True, dtype_llama=None, llama_scaled_fp8=None):
|
||||
class QwenImageTEModel_(HunyuanImageTEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
|
||||
model_options = model_options.copy()
|
||||
model_options["qwen_scaled_fp8"] = llama_scaled_fp8
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
super().__init__(byt5=byt5, device=device, dtype=dtype, model_options=model_options)
|
||||
return QwenImageTEModel_
|
||||
@@ -128,11 +128,12 @@ def precompute_freqs_cis(head_dim, position_ids, theta, rope_dims=None, device=N
|
||||
|
||||
|
||||
def apply_rope(xq, xk, freqs_cis):
|
||||
org_dtype = xq.dtype
|
||||
cos = freqs_cis[0]
|
||||
sin = freqs_cis[1]
|
||||
q_embed = (xq * cos) + (rotate_half(xq) * sin)
|
||||
k_embed = (xk * cos) + (rotate_half(xk) * sin)
|
||||
return q_embed, k_embed
|
||||
return q_embed.to(org_dtype), k_embed.to(org_dtype)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
|
||||
@@ -97,6 +97,9 @@ class LoKrAdapter(WeightAdapterBase):
|
||||
(mat1, mat2, alpha, None, None, None, None, None, None)
|
||||
)
|
||||
|
||||
def to_train(self):
|
||||
return LokrDiff(self.weights)
|
||||
|
||||
@classmethod
|
||||
def load(
|
||||
cls,
|
||||
|
||||
@@ -8,6 +8,7 @@ import av
|
||||
import io
|
||||
import json
|
||||
import numpy as np
|
||||
import math
|
||||
import torch
|
||||
from comfy_api.latest._util import VideoContainer, VideoCodec, VideoComponents
|
||||
|
||||
@@ -282,8 +283,6 @@ class VideoFromComponents(VideoInput):
|
||||
if self.__components.audio:
|
||||
audio_sample_rate = int(self.__components.audio['sample_rate'])
|
||||
audio_stream = output.add_stream('aac', rate=audio_sample_rate)
|
||||
audio_stream.sample_rate = audio_sample_rate
|
||||
audio_stream.format = 'fltp'
|
||||
|
||||
# Encode video
|
||||
for i, frame in enumerate(self.__components.images):
|
||||
@@ -298,27 +297,12 @@ class VideoFromComponents(VideoInput):
|
||||
output.mux(packet)
|
||||
|
||||
if audio_stream and self.__components.audio:
|
||||
# Encode audio
|
||||
samples_per_frame = int(audio_sample_rate / frame_rate)
|
||||
num_frames = self.__components.audio['waveform'].shape[2] // samples_per_frame
|
||||
for i in range(num_frames):
|
||||
start = i * samples_per_frame
|
||||
end = start + samples_per_frame
|
||||
# TODO(Feature) - Add support for stereo audio
|
||||
chunk = (
|
||||
self.__components.audio["waveform"][0, 0, start:end]
|
||||
.unsqueeze(0)
|
||||
.contiguous()
|
||||
.numpy()
|
||||
)
|
||||
audio_frame = av.AudioFrame.from_ndarray(chunk, format='fltp', layout='mono')
|
||||
audio_frame.sample_rate = audio_sample_rate
|
||||
audio_frame.pts = i * samples_per_frame
|
||||
for packet in audio_stream.encode(audio_frame):
|
||||
output.mux(packet)
|
||||
|
||||
# Flush audio
|
||||
for packet in audio_stream.encode(None):
|
||||
output.mux(packet)
|
||||
|
||||
waveform = self.__components.audio['waveform']
|
||||
waveform = waveform[:, :, :math.ceil((audio_sample_rate / frame_rate) * self.__components.images.shape[0])]
|
||||
frame = av.AudioFrame.from_ndarray(waveform.movedim(2, 1).reshape(1, -1).float().numpy(), format='flt', layout='mono' if waveform.shape[1] == 1 else 'stereo')
|
||||
frame.sample_rate = audio_sample_rate
|
||||
frame.pts = 0
|
||||
output.mux(audio_stream.encode(frame))
|
||||
|
||||
# Flush encoder
|
||||
output.mux(audio_stream.encode(None))
|
||||
|
||||
+21
-3
@@ -726,6 +726,18 @@ class SEGS(ComfyTypeIO):
|
||||
class AnyType(ComfyTypeIO):
|
||||
Type = Any
|
||||
|
||||
@comfytype(io_type="MODEL_PATCH")
|
||||
class MODEL_PATCH(ComfyTypeIO):
|
||||
Type = Any
|
||||
|
||||
@comfytype(io_type="AUDIO_ENCODER")
|
||||
class AudioEncoder(ComfyTypeIO):
|
||||
Type = Any
|
||||
|
||||
@comfytype(io_type="AUDIO_ENCODER_OUTPUT")
|
||||
class AudioEncoderOutput(ComfyTypeIO):
|
||||
Type = Any
|
||||
|
||||
@comfytype(io_type="COMFY_MULTITYPED_V3")
|
||||
class MultiType:
|
||||
Type = Any
|
||||
@@ -1178,13 +1190,18 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def validate_inputs(cls, **kwargs) -> bool:
|
||||
"""Optionally, define this function to validate inputs; equivalent to V1's VALIDATE_INPUTS."""
|
||||
def validate_inputs(cls, **kwargs) -> bool | str:
|
||||
"""Optionally, define this function to validate inputs; equivalent to V1's VALIDATE_INPUTS.
|
||||
|
||||
If the function returns a string, it will be used as the validation error message for the node.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def fingerprint_inputs(cls, **kwargs) -> Any:
|
||||
"""Optionally, define this function to fingerprint inputs; equivalent to V1's IS_CHANGED."""
|
||||
"""Optionally, define this function to fingerprint inputs; equivalent to V1's IS_CHANGED.
|
||||
|
||||
If this function returns the same value as last run, the node will not be executed."""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@@ -1580,6 +1597,7 @@ class _IO:
|
||||
Model = Model
|
||||
ClipVision = ClipVision
|
||||
ClipVisionOutput = ClipVisionOutput
|
||||
AudioEncoderOutput = AudioEncoderOutput
|
||||
StyleModel = StyleModel
|
||||
Gligen = Gligen
|
||||
UpscaleModel = UpscaleModel
|
||||
|
||||
@@ -518,6 +518,71 @@ async def upload_audio_to_comfyapi(
|
||||
return await upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_kwargs)
|
||||
|
||||
|
||||
def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
|
||||
"""Convert audio to float 32 bits PCM format. Copy-paste from nodes_audio.py file."""
|
||||
if wav.dtype.is_floating_point:
|
||||
return wav
|
||||
elif wav.dtype == torch.int16:
|
||||
return wav.float() / (2 ** 15)
|
||||
elif wav.dtype == torch.int32:
|
||||
return wav.float() / (2 ** 31)
|
||||
raise ValueError(f"Unsupported wav dtype: {wav.dtype}")
|
||||
|
||||
|
||||
def audio_bytes_to_audio_input(audio_bytes: bytes,) -> dict:
|
||||
"""
|
||||
Decode any common audio container from bytes using PyAV and return
|
||||
a Comfy AUDIO dict: {"waveform": [1, C, T] float32, "sample_rate": int}.
|
||||
"""
|
||||
with av.open(io.BytesIO(audio_bytes)) as af:
|
||||
if not af.streams.audio:
|
||||
raise ValueError("No audio stream found in response.")
|
||||
stream = af.streams.audio[0]
|
||||
|
||||
in_sr = int(stream.codec_context.sample_rate)
|
||||
out_sr = in_sr
|
||||
|
||||
frames: list[torch.Tensor] = []
|
||||
n_channels = stream.channels or 1
|
||||
|
||||
for frame in af.decode(streams=stream.index):
|
||||
arr = frame.to_ndarray() # shape can be [C, T] or [T, C] or [T]
|
||||
buf = torch.from_numpy(arr)
|
||||
if buf.ndim == 1:
|
||||
buf = buf.unsqueeze(0) # [T] -> [1, T]
|
||||
elif buf.shape[0] != n_channels and buf.shape[-1] == n_channels:
|
||||
buf = buf.transpose(0, 1).contiguous() # [T, C] -> [C, T]
|
||||
elif buf.shape[0] != n_channels:
|
||||
buf = buf.reshape(-1, n_channels).t().contiguous() # fallback to [C, T]
|
||||
frames.append(buf)
|
||||
|
||||
if not frames:
|
||||
raise ValueError("Decoded zero audio frames.")
|
||||
|
||||
wav = torch.cat(frames, dim=1) # [C, T]
|
||||
wav = f32_pcm(wav)
|
||||
return {"waveform": wav.unsqueeze(0).contiguous(), "sample_rate": out_sr}
|
||||
|
||||
|
||||
def audio_input_to_mp3(audio: AudioInput) -> io.BytesIO:
|
||||
waveform = audio["waveform"].cpu()
|
||||
|
||||
output_buffer = io.BytesIO()
|
||||
output_container = av.open(output_buffer, mode='w', format="mp3")
|
||||
|
||||
out_stream = output_container.add_stream("libmp3lame", rate=audio["sample_rate"])
|
||||
out_stream.bit_rate = 320000
|
||||
|
||||
frame = av.AudioFrame.from_ndarray(waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format='flt', layout='mono' if waveform.shape[0] == 1 else 'stereo')
|
||||
frame.sample_rate = audio["sample_rate"]
|
||||
frame.pts = 0
|
||||
output_container.mux(out_stream.encode(frame))
|
||||
output_container.mux(out_stream.encode(None))
|
||||
output_container.close()
|
||||
output_buffer.seek(0)
|
||||
return output_buffer
|
||||
|
||||
|
||||
def audio_to_base64_string(
|
||||
audio: AudioInput, container_format: str = "mp4", codec_name: str = "aac"
|
||||
) -> str:
|
||||
|
||||
Generated
+21
-1
@@ -951,7 +951,11 @@ class MagicPrompt2(str, Enum):
|
||||
|
||||
|
||||
class StyleType1(str, Enum):
|
||||
AUTO = 'AUTO'
|
||||
GENERAL = 'GENERAL'
|
||||
REALISTIC = 'REALISTIC'
|
||||
DESIGN = 'DESIGN'
|
||||
FICTION = 'FICTION'
|
||||
|
||||
|
||||
class ImagenImageGenerationInstance(BaseModel):
|
||||
@@ -2676,7 +2680,7 @@ class ReleaseNote(BaseModel):
|
||||
|
||||
|
||||
class RenderingSpeed(str, Enum):
|
||||
BALANCED = 'BALANCED'
|
||||
DEFAULT = 'DEFAULT'
|
||||
TURBO = 'TURBO'
|
||||
QUALITY = 'QUALITY'
|
||||
|
||||
@@ -4918,6 +4922,14 @@ class IdeogramV3EditRequest(BaseModel):
|
||||
None,
|
||||
description='A set of images to use as style references (maximum total size 10MB across all style references). The images should be in JPEG, PNG or WebP format.',
|
||||
)
|
||||
character_reference_images: Optional[List[str]] = Field(
|
||||
None,
|
||||
description='Generations with character reference are subject to the character reference pricing. A set of images to use as character references (maximum total size 10MB across all character references), currently only supports 1 character reference image. The images should be in JPEG, PNG or WebP format.'
|
||||
)
|
||||
character_reference_images_mask: Optional[List[str]] = Field(
|
||||
None,
|
||||
description='Optional masks for character reference images. When provided, must match the number of character_reference_images. Each mask should be a grayscale image of the same dimensions as the corresponding character reference image. The images should be in JPEG, PNG or WebP format.'
|
||||
)
|
||||
|
||||
|
||||
class IdeogramV3Request(BaseModel):
|
||||
@@ -4951,6 +4963,14 @@ class IdeogramV3Request(BaseModel):
|
||||
style_type: Optional[StyleType1] = Field(
|
||||
None, description='The type of style to apply'
|
||||
)
|
||||
character_reference_images: Optional[List[str]] = Field(
|
||||
None,
|
||||
description='Generations with character reference are subject to the character reference pricing. A set of images to use as character references (maximum total size 10MB across all character references), currently only supports 1 character reference image. The images should be in JPEG, PNG or WebP format.'
|
||||
)
|
||||
character_reference_images_mask: Optional[List[str]] = Field(
|
||||
None,
|
||||
description='Optional masks for character reference images. When provided, must match the number of character_reference_images. Each mask should be a grayscale image of the same dimensions as the corresponding character reference image. The images should be in JPEG, PNG or WebP format.'
|
||||
)
|
||||
|
||||
|
||||
class ImagenGenerateImageResponse(BaseModel):
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from comfy_api_nodes.apis import GeminiGenerationConfig, GeminiContent, GeminiSafetySetting, GeminiSystemInstructionContent, GeminiTool, GeminiVideoMetadata
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class GeminiImageGenerationConfig(GeminiGenerationConfig):
|
||||
responseModalities: Optional[List[str]] = None
|
||||
|
||||
|
||||
class GeminiImageGenerateContentRequest(BaseModel):
|
||||
contents: List[GeminiContent]
|
||||
generationConfig: Optional[GeminiImageGenerationConfig] = None
|
||||
safetySettings: Optional[List[GeminiSafetySetting]] = None
|
||||
systemInstruction: Optional[GeminiSystemInstructionContent] = None
|
||||
tools: Optional[List[GeminiTool]] = None
|
||||
videoMetadata: Optional[GeminiVideoMetadata] = None
|
||||
@@ -125,3 +125,25 @@ class StabilityResultsGetResponse(BaseModel):
|
||||
|
||||
class StabilityAsyncResponse(BaseModel):
|
||||
id: Optional[str] = Field(None)
|
||||
|
||||
|
||||
class StabilityTextToAudioRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
prompt: str = Field(...)
|
||||
duration: int = Field(190, ge=1, le=190)
|
||||
seed: int = Field(0, ge=0, le=4294967294)
|
||||
steps: int = Field(8, ge=4, le=8)
|
||||
output_format: str = Field("wav")
|
||||
|
||||
|
||||
class StabilityAudioToAudioRequest(StabilityTextToAudioRequest):
|
||||
strength: float = Field(0.01, ge=0.01, le=1.0)
|
||||
|
||||
|
||||
class StabilityAudioInpaintRequest(StabilityTextToAudioRequest):
|
||||
mask_start: int = Field(30, ge=0, le=190)
|
||||
mask_end: int = Field(190, ge=0, le=190)
|
||||
|
||||
|
||||
class StabilityAudioResponse(BaseModel):
|
||||
audio: Optional[str] = Field(None)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
+318
-95
@@ -4,8 +4,12 @@ See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/infer
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
import json
|
||||
import time
|
||||
import os
|
||||
import uuid
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from enum import Enum
|
||||
from typing import Optional, Literal
|
||||
|
||||
@@ -22,6 +26,7 @@ from comfy_api_nodes.apis import (
|
||||
GeminiPart,
|
||||
GeminiMimeType,
|
||||
)
|
||||
from comfy_api_nodes.apis.gemini_api import GeminiImageGenerationConfig, GeminiImageGenerateContentRequest
|
||||
from comfy_api_nodes.apis.client import (
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
@@ -32,6 +37,7 @@ from comfy_api_nodes.apinode_utils import (
|
||||
audio_to_base64_string,
|
||||
video_to_base64_string,
|
||||
tensor_to_base64_string,
|
||||
bytesio_to_image_tensor,
|
||||
)
|
||||
|
||||
|
||||
@@ -50,6 +56,14 @@ class GeminiModel(str, Enum):
|
||||
gemini_2_5_flash = "gemini-2.5-flash"
|
||||
|
||||
|
||||
class GeminiImageModel(str, Enum):
|
||||
"""
|
||||
Gemini Image Model Names allowed by comfy-api
|
||||
"""
|
||||
|
||||
gemini_2_5_flash_image_preview = "gemini-2.5-flash-image-preview"
|
||||
|
||||
|
||||
def get_gemini_endpoint(
|
||||
model: GeminiModel,
|
||||
) -> ApiEndpoint[GeminiGenerateContentRequest, GeminiGenerateContentResponse]:
|
||||
@@ -72,6 +86,135 @@ def get_gemini_endpoint(
|
||||
)
|
||||
|
||||
|
||||
def get_gemini_image_endpoint(
|
||||
model: GeminiImageModel,
|
||||
) -> ApiEndpoint[GeminiGenerateContentRequest, GeminiGenerateContentResponse]:
|
||||
"""
|
||||
Get the API endpoint for a given Gemini model.
|
||||
|
||||
Args:
|
||||
model: The Gemini model to use, either as enum or string value.
|
||||
|
||||
Returns:
|
||||
ApiEndpoint configured for the specific Gemini model.
|
||||
"""
|
||||
if isinstance(model, str):
|
||||
model = GeminiImageModel(model)
|
||||
return ApiEndpoint(
|
||||
path=f"{GEMINI_BASE_ENDPOINT}/{model.value}",
|
||||
method=HttpMethod.POST,
|
||||
request_model=GeminiImageGenerateContentRequest,
|
||||
response_model=GeminiGenerateContentResponse,
|
||||
)
|
||||
|
||||
|
||||
def create_image_parts(image_input: torch.Tensor) -> list[GeminiPart]:
|
||||
"""
|
||||
Convert image tensor input to Gemini API compatible parts.
|
||||
|
||||
Args:
|
||||
image_input: Batch of image tensors from ComfyUI.
|
||||
|
||||
Returns:
|
||||
List of GeminiPart objects containing the encoded images.
|
||||
"""
|
||||
image_parts: list[GeminiPart] = []
|
||||
for image_index in range(image_input.shape[0]):
|
||||
image_as_b64 = tensor_to_base64_string(
|
||||
image_input[image_index].unsqueeze(0)
|
||||
)
|
||||
image_parts.append(
|
||||
GeminiPart(
|
||||
inlineData=GeminiInlineData(
|
||||
mimeType=GeminiMimeType.image_png,
|
||||
data=image_as_b64,
|
||||
)
|
||||
)
|
||||
)
|
||||
return image_parts
|
||||
|
||||
|
||||
def create_text_part(text: str) -> GeminiPart:
|
||||
"""
|
||||
Create a text part for the Gemini API request.
|
||||
|
||||
Args:
|
||||
text: The text content to include in the request.
|
||||
|
||||
Returns:
|
||||
A GeminiPart object with the text content.
|
||||
"""
|
||||
return GeminiPart(text=text)
|
||||
|
||||
|
||||
def get_parts_from_response(
|
||||
response: GeminiGenerateContentResponse
|
||||
) -> list[GeminiPart]:
|
||||
"""
|
||||
Extract all parts from the Gemini API response.
|
||||
|
||||
Args:
|
||||
response: The API response from Gemini.
|
||||
|
||||
Returns:
|
||||
List of response parts from the first candidate.
|
||||
"""
|
||||
return response.candidates[0].content.parts
|
||||
|
||||
|
||||
def get_parts_by_type(
|
||||
response: GeminiGenerateContentResponse, part_type: Literal["text"] | str
|
||||
) -> list[GeminiPart]:
|
||||
"""
|
||||
Filter response parts by their type.
|
||||
|
||||
Args:
|
||||
response: The API response from Gemini.
|
||||
part_type: Type of parts to extract ("text" or a MIME type).
|
||||
|
||||
Returns:
|
||||
List of response parts matching the requested type.
|
||||
"""
|
||||
parts = []
|
||||
for part in get_parts_from_response(response):
|
||||
if part_type == "text" and hasattr(part, "text") and part.text:
|
||||
parts.append(part)
|
||||
elif (
|
||||
hasattr(part, "inlineData")
|
||||
and part.inlineData
|
||||
and part.inlineData.mimeType == part_type
|
||||
):
|
||||
parts.append(part)
|
||||
# Skip parts that don't match the requested type
|
||||
return parts
|
||||
|
||||
|
||||
def get_text_from_response(response: GeminiGenerateContentResponse) -> str:
|
||||
"""
|
||||
Extract and concatenate all text parts from the response.
|
||||
|
||||
Args:
|
||||
response: The API response from Gemini.
|
||||
|
||||
Returns:
|
||||
Combined text from all text parts in the response.
|
||||
"""
|
||||
parts = get_parts_by_type(response, "text")
|
||||
return "\n".join([part.text for part in parts])
|
||||
|
||||
|
||||
def get_image_from_response(response: GeminiGenerateContentResponse) -> torch.Tensor:
|
||||
image_tensors: list[torch.Tensor] = []
|
||||
parts = get_parts_by_type(response, "image/png")
|
||||
for part in parts:
|
||||
image_data = base64.b64decode(part.inlineData.data)
|
||||
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
||||
image_tensors.append(returned_image)
|
||||
if len(image_tensors) == 0:
|
||||
return torch.zeros((1,1024,1024,4))
|
||||
return torch.cat(image_tensors, dim=0)
|
||||
|
||||
|
||||
class GeminiNode(ComfyNodeABC):
|
||||
"""
|
||||
Node to generate text responses from a Gemini model.
|
||||
@@ -156,59 +299,6 @@ class GeminiNode(ComfyNodeABC):
|
||||
CATEGORY = "api node/text/Gemini"
|
||||
API_NODE = True
|
||||
|
||||
def get_parts_from_response(
|
||||
self, response: GeminiGenerateContentResponse
|
||||
) -> list[GeminiPart]:
|
||||
"""
|
||||
Extract all parts from the Gemini API response.
|
||||
|
||||
Args:
|
||||
response: The API response from Gemini.
|
||||
|
||||
Returns:
|
||||
List of response parts from the first candidate.
|
||||
"""
|
||||
return response.candidates[0].content.parts
|
||||
|
||||
def get_parts_by_type(
|
||||
self, response: GeminiGenerateContentResponse, part_type: Literal["text"] | str
|
||||
) -> list[GeminiPart]:
|
||||
"""
|
||||
Filter response parts by their type.
|
||||
|
||||
Args:
|
||||
response: The API response from Gemini.
|
||||
part_type: Type of parts to extract ("text" or a MIME type).
|
||||
|
||||
Returns:
|
||||
List of response parts matching the requested type.
|
||||
"""
|
||||
parts = []
|
||||
for part in self.get_parts_from_response(response):
|
||||
if part_type == "text" and hasattr(part, "text") and part.text:
|
||||
parts.append(part)
|
||||
elif (
|
||||
hasattr(part, "inlineData")
|
||||
and part.inlineData
|
||||
and part.inlineData.mimeType == part_type
|
||||
):
|
||||
parts.append(part)
|
||||
# Skip parts that don't match the requested type
|
||||
return parts
|
||||
|
||||
def get_text_from_response(self, response: GeminiGenerateContentResponse) -> str:
|
||||
"""
|
||||
Extract and concatenate all text parts from the response.
|
||||
|
||||
Args:
|
||||
response: The API response from Gemini.
|
||||
|
||||
Returns:
|
||||
Combined text from all text parts in the response.
|
||||
"""
|
||||
parts = self.get_parts_by_type(response, "text")
|
||||
return "\n".join([part.text for part in parts])
|
||||
|
||||
def create_video_parts(self, video_input: IO.VIDEO, **kwargs) -> list[GeminiPart]:
|
||||
"""
|
||||
Convert video input to Gemini API compatible parts.
|
||||
@@ -268,43 +358,6 @@ class GeminiNode(ComfyNodeABC):
|
||||
)
|
||||
return audio_parts
|
||||
|
||||
def create_image_parts(self, image_input: torch.Tensor) -> list[GeminiPart]:
|
||||
"""
|
||||
Convert image tensor input to Gemini API compatible parts.
|
||||
|
||||
Args:
|
||||
image_input: Batch of image tensors from ComfyUI.
|
||||
|
||||
Returns:
|
||||
List of GeminiPart objects containing the encoded images.
|
||||
"""
|
||||
image_parts: list[GeminiPart] = []
|
||||
for image_index in range(image_input.shape[0]):
|
||||
image_as_b64 = tensor_to_base64_string(
|
||||
image_input[image_index].unsqueeze(0)
|
||||
)
|
||||
image_parts.append(
|
||||
GeminiPart(
|
||||
inlineData=GeminiInlineData(
|
||||
mimeType=GeminiMimeType.image_png,
|
||||
data=image_as_b64,
|
||||
)
|
||||
)
|
||||
)
|
||||
return image_parts
|
||||
|
||||
def create_text_part(self, text: str) -> GeminiPart:
|
||||
"""
|
||||
Create a text part for the Gemini API request.
|
||||
|
||||
Args:
|
||||
text: The text content to include in the request.
|
||||
|
||||
Returns:
|
||||
A GeminiPart object with the text content.
|
||||
"""
|
||||
return GeminiPart(text=text)
|
||||
|
||||
async def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
@@ -320,11 +373,11 @@ class GeminiNode(ComfyNodeABC):
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
|
||||
# Create parts list with text prompt as the first part
|
||||
parts: list[GeminiPart] = [self.create_text_part(prompt)]
|
||||
parts: list[GeminiPart] = [create_text_part(prompt)]
|
||||
|
||||
# Add other modal parts
|
||||
if images is not None:
|
||||
image_parts = self.create_image_parts(images)
|
||||
image_parts = create_image_parts(images)
|
||||
parts.extend(image_parts)
|
||||
if audio is not None:
|
||||
parts.extend(self.create_audio_parts(audio))
|
||||
@@ -348,9 +401,29 @@ class GeminiNode(ComfyNodeABC):
|
||||
).execute()
|
||||
|
||||
# Get result output
|
||||
output_text = self.get_text_from_response(response)
|
||||
output_text = get_text_from_response(response)
|
||||
if unique_id and output_text:
|
||||
PromptServer.instance.send_progress_text(output_text, node_id=unique_id)
|
||||
# Not a true chat history like the OpenAI Chat node. It is emulated so the frontend can show a copy button.
|
||||
render_spec = {
|
||||
"node_id": unique_id,
|
||||
"component": "ChatHistoryWidget",
|
||||
"props": {
|
||||
"history": json.dumps(
|
||||
[
|
||||
{
|
||||
"prompt": prompt,
|
||||
"response": output_text,
|
||||
"response_id": str(uuid.uuid4()),
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
]
|
||||
),
|
||||
},
|
||||
}
|
||||
PromptServer.instance.send_sync(
|
||||
"display_component",
|
||||
render_spec,
|
||||
)
|
||||
|
||||
return (output_text or "Empty response from Gemini model...",)
|
||||
|
||||
@@ -439,12 +512,162 @@ class GeminiInputFiles(ComfyNodeABC):
|
||||
return (files,)
|
||||
|
||||
|
||||
class GeminiImage(ComfyNodeABC):
|
||||
"""
|
||||
Node to generate text and image responses from a Gemini model.
|
||||
|
||||
This node allows users to interact with Google's Gemini AI models, providing
|
||||
multimodal inputs (text, images, files) to generate coherent
|
||||
text and image responses. The node works with the latest Gemini models, handling the
|
||||
API communication and response parsing.
|
||||
"""
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
return {
|
||||
"required": {
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Text prompt for generation",
|
||||
},
|
||||
),
|
||||
"model": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"tooltip": "The Gemini model to use for generating responses.",
|
||||
"options": [model.value for model in GeminiImageModel],
|
||||
"default": GeminiImageModel.gemini_2_5_flash_image_preview.value,
|
||||
},
|
||||
),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 42,
|
||||
"min": 0,
|
||||
"max": 0xFFFFFFFFFFFFFFFF,
|
||||
"control_after_generate": True,
|
||||
"tooltip": "When seed is fixed to a specific value, the model makes a best effort to provide the same response for repeated requests. Deterministic output isn't guaranteed. Also, changing the model or parameter settings, such as the temperature, can cause variations in the response even when you use the same seed value. By default, a random seed value is used.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"images": (
|
||||
IO.IMAGE,
|
||||
{
|
||||
"default": None,
|
||||
"tooltip": "Optional image(s) to use as context for the model. To include multiple images, you can use the Batch Images node.",
|
||||
},
|
||||
),
|
||||
"files": (
|
||||
"GEMINI_INPUT_FILES",
|
||||
{
|
||||
"default": None,
|
||||
"tooltip": "Optional file(s) to use as context for the model. Accepts inputs from the Gemini Generate Content Input Files node.",
|
||||
},
|
||||
),
|
||||
# TODO: later we can add this parameter later
|
||||
# "n": (
|
||||
# IO.INT,
|
||||
# {
|
||||
# "default": 1,
|
||||
# "min": 1,
|
||||
# "max": 8,
|
||||
# "step": 1,
|
||||
# "display": "number",
|
||||
# "tooltip": "How many images to generate",
|
||||
# },
|
||||
# ),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE, IO.STRING)
|
||||
FUNCTION = "api_call"
|
||||
CATEGORY = "api node/image/Gemini"
|
||||
DESCRIPTION = "Edit images synchronously via Google API."
|
||||
API_NODE = True
|
||||
|
||||
async def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
model: GeminiImageModel,
|
||||
images: Optional[IO.IMAGE] = None,
|
||||
files: Optional[list[GeminiPart]] = None,
|
||||
n=1,
|
||||
unique_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
# Validate inputs
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
# Create parts list with text prompt as the first part
|
||||
parts: list[GeminiPart] = [create_text_part(prompt)]
|
||||
|
||||
# Add other modal parts
|
||||
if images is not None:
|
||||
image_parts = create_image_parts(images)
|
||||
parts.extend(image_parts)
|
||||
if files is not None:
|
||||
parts.extend(files)
|
||||
|
||||
response = await SynchronousOperation(
|
||||
endpoint=get_gemini_image_endpoint(model),
|
||||
request=GeminiImageGenerateContentRequest(
|
||||
contents=[
|
||||
GeminiContent(
|
||||
role="user",
|
||||
parts=parts,
|
||||
),
|
||||
],
|
||||
generationConfig=GeminiImageGenerationConfig(
|
||||
responseModalities=["TEXT","IMAGE"]
|
||||
)
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
).execute()
|
||||
|
||||
output_image = get_image_from_response(response)
|
||||
output_text = get_text_from_response(response)
|
||||
if unique_id and output_text:
|
||||
# Not a true chat history like the OpenAI Chat node. It is emulated so the frontend can show a copy button.
|
||||
render_spec = {
|
||||
"node_id": unique_id,
|
||||
"component": "ChatHistoryWidget",
|
||||
"props": {
|
||||
"history": json.dumps(
|
||||
[
|
||||
{
|
||||
"prompt": prompt,
|
||||
"response": output_text,
|
||||
"response_id": str(uuid.uuid4()),
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
]
|
||||
),
|
||||
},
|
||||
}
|
||||
PromptServer.instance.send_sync(
|
||||
"display_component",
|
||||
render_spec,
|
||||
)
|
||||
|
||||
output_text = output_text or "Empty response from Gemini model..."
|
||||
return (output_image, output_text,)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"GeminiNode": GeminiNode,
|
||||
"GeminiImageNode": GeminiImage,
|
||||
"GeminiInputFiles": GeminiInputFiles,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"GeminiNode": "Google Gemini",
|
||||
"GeminiImageNode": "Google Gemini Image",
|
||||
"GeminiInputFiles": "Gemini Input Files",
|
||||
}
|
||||
|
||||
+325
-284
@@ -1,8 +1,8 @@
|
||||
from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict
|
||||
from inspect import cleandoc
|
||||
from io import BytesIO
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io as comfy_io
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import io
|
||||
import torch
|
||||
from comfy_api_nodes.apis import (
|
||||
IdeogramGenerateRequest,
|
||||
@@ -246,90 +246,82 @@ def display_image_urls_on_node(image_urls, node_id):
|
||||
PromptServer.instance.send_progress_text(urls_text, node_id)
|
||||
|
||||
|
||||
class IdeogramV1(ComfyNodeABC):
|
||||
"""
|
||||
Generates images using the Ideogram V1 model.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
class IdeogramV1(comfy_io.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
return {
|
||||
"required": {
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Prompt for the image generation",
|
||||
},
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
node_id="IdeogramV1",
|
||||
display_name="Ideogram V1",
|
||||
category="api node/image/Ideogram",
|
||||
description="Generates images using the Ideogram V1 model.",
|
||||
is_api_node=True,
|
||||
inputs=[
|
||||
comfy_io.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Prompt for the image generation",
|
||||
),
|
||||
"turbo": (
|
||||
IO.BOOLEAN,
|
||||
{
|
||||
"default": False,
|
||||
"tooltip": "Whether to use turbo mode (faster generation, potentially lower quality)",
|
||||
}
|
||||
comfy_io.Boolean.Input(
|
||||
"turbo",
|
||||
default=False,
|
||||
tooltip="Whether to use turbo mode (faster generation, potentially lower quality)",
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"aspect_ratio": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": list(V1_V2_RATIO_MAP.keys()),
|
||||
"default": "1:1",
|
||||
"tooltip": "The aspect ratio for image generation.",
|
||||
},
|
||||
comfy_io.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=list(V1_V2_RATIO_MAP.keys()),
|
||||
default="1:1",
|
||||
tooltip="The aspect ratio for image generation.",
|
||||
optional=True,
|
||||
),
|
||||
"magic_prompt_option": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": ["AUTO", "ON", "OFF"],
|
||||
"default": "AUTO",
|
||||
"tooltip": "Determine if MagicPrompt should be used in generation",
|
||||
},
|
||||
comfy_io.Combo.Input(
|
||||
"magic_prompt_option",
|
||||
options=["AUTO", "ON", "OFF"],
|
||||
default="AUTO",
|
||||
tooltip="Determine if MagicPrompt should be used in generation",
|
||||
optional=True,
|
||||
),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 2147483647,
|
||||
"step": 1,
|
||||
"control_after_generate": True,
|
||||
"display": "number",
|
||||
},
|
||||
comfy_io.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
control_after_generate=True,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
optional=True,
|
||||
),
|
||||
"negative_prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Description of what to exclude from the image",
|
||||
},
|
||||
comfy_io.String.Input(
|
||||
"negative_prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Description of what to exclude from the image",
|
||||
optional=True,
|
||||
),
|
||||
"num_images": (
|
||||
IO.INT,
|
||||
{"default": 1, "min": 1, "max": 8, "step": 1, "display": "number"},
|
||||
comfy_io.Int.Input(
|
||||
"num_images",
|
||||
default=1,
|
||||
min=1,
|
||||
max=8,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
optional=True,
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
],
|
||||
)
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
FUNCTION = "api_call"
|
||||
CATEGORY = "api node/image/Ideogram"
|
||||
DESCRIPTION = cleandoc(__doc__ or "")
|
||||
API_NODE = True
|
||||
|
||||
async def api_call(
|
||||
self,
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt,
|
||||
turbo=False,
|
||||
aspect_ratio="1:1",
|
||||
@@ -337,13 +329,15 @@ class IdeogramV1(ComfyNodeABC):
|
||||
seed=0,
|
||||
negative_prompt="",
|
||||
num_images=1,
|
||||
unique_id=None,
|
||||
**kwargs,
|
||||
):
|
||||
# Determine the model based on turbo setting
|
||||
aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None)
|
||||
model = "V_1_TURBO" if turbo else "V_1"
|
||||
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/ideogram/generate",
|
||||
@@ -364,7 +358,7 @@ class IdeogramV1(ComfyNodeABC):
|
||||
negative_prompt=negative_prompt if negative_prompt else None,
|
||||
)
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
|
||||
response = await operation.execute()
|
||||
@@ -377,93 +371,86 @@ class IdeogramV1(ComfyNodeABC):
|
||||
if not image_urls:
|
||||
raise Exception("No image URLs were generated in the response")
|
||||
|
||||
display_image_urls_on_node(image_urls, unique_id)
|
||||
return (await download_and_process_images(image_urls),)
|
||||
display_image_urls_on_node(image_urls, cls.hidden.unique_id)
|
||||
return comfy_io.NodeOutput(await download_and_process_images(image_urls))
|
||||
|
||||
|
||||
class IdeogramV2(ComfyNodeABC):
|
||||
"""
|
||||
Generates images using the Ideogram V2 model.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
class IdeogramV2(comfy_io.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
return {
|
||||
"required": {
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Prompt for the image generation",
|
||||
},
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
node_id="IdeogramV2",
|
||||
display_name="Ideogram V2",
|
||||
category="api node/image/Ideogram",
|
||||
description="Generates images using the Ideogram V2 model.",
|
||||
is_api_node=True,
|
||||
inputs=[
|
||||
comfy_io.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Prompt for the image generation",
|
||||
),
|
||||
"turbo": (
|
||||
IO.BOOLEAN,
|
||||
{
|
||||
"default": False,
|
||||
"tooltip": "Whether to use turbo mode (faster generation, potentially lower quality)",
|
||||
}
|
||||
comfy_io.Boolean.Input(
|
||||
"turbo",
|
||||
default=False,
|
||||
tooltip="Whether to use turbo mode (faster generation, potentially lower quality)",
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"aspect_ratio": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": list(V1_V2_RATIO_MAP.keys()),
|
||||
"default": "1:1",
|
||||
"tooltip": "The aspect ratio for image generation. Ignored if resolution is not set to AUTO.",
|
||||
},
|
||||
comfy_io.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=list(V1_V2_RATIO_MAP.keys()),
|
||||
default="1:1",
|
||||
tooltip="The aspect ratio for image generation. Ignored if resolution is not set to AUTO.",
|
||||
optional=True,
|
||||
),
|
||||
"resolution": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": list(V1_V1_RES_MAP.keys()),
|
||||
"default": "Auto",
|
||||
"tooltip": "The resolution for image generation. If not set to AUTO, this overrides the aspect_ratio setting.",
|
||||
},
|
||||
comfy_io.Combo.Input(
|
||||
"resolution",
|
||||
options=list(V1_V1_RES_MAP.keys()),
|
||||
default="Auto",
|
||||
tooltip="The resolution for image generation. "
|
||||
"If not set to AUTO, this overrides the aspect_ratio setting.",
|
||||
optional=True,
|
||||
),
|
||||
"magic_prompt_option": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": ["AUTO", "ON", "OFF"],
|
||||
"default": "AUTO",
|
||||
"tooltip": "Determine if MagicPrompt should be used in generation",
|
||||
},
|
||||
comfy_io.Combo.Input(
|
||||
"magic_prompt_option",
|
||||
options=["AUTO", "ON", "OFF"],
|
||||
default="AUTO",
|
||||
tooltip="Determine if MagicPrompt should be used in generation",
|
||||
optional=True,
|
||||
),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 2147483647,
|
||||
"step": 1,
|
||||
"control_after_generate": True,
|
||||
"display": "number",
|
||||
},
|
||||
comfy_io.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
control_after_generate=True,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
optional=True,
|
||||
),
|
||||
"style_type": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": ["AUTO", "GENERAL", "REALISTIC", "DESIGN", "RENDER_3D", "ANIME"],
|
||||
"default": "NONE",
|
||||
"tooltip": "Style type for generation (V2 only)",
|
||||
},
|
||||
comfy_io.Combo.Input(
|
||||
"style_type",
|
||||
options=["AUTO", "GENERAL", "REALISTIC", "DESIGN", "RENDER_3D", "ANIME"],
|
||||
default="NONE",
|
||||
tooltip="Style type for generation (V2 only)",
|
||||
optional=True,
|
||||
),
|
||||
"negative_prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Description of what to exclude from the image",
|
||||
},
|
||||
comfy_io.String.Input(
|
||||
"negative_prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Description of what to exclude from the image",
|
||||
optional=True,
|
||||
),
|
||||
"num_images": (
|
||||
IO.INT,
|
||||
{"default": 1, "min": 1, "max": 8, "step": 1, "display": "number"},
|
||||
comfy_io.Int.Input(
|
||||
"num_images",
|
||||
default=1,
|
||||
min=1,
|
||||
max=8,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
optional=True,
|
||||
),
|
||||
#"color_palette": (
|
||||
# IO.STRING,
|
||||
@@ -473,22 +460,20 @@ class IdeogramV2(ComfyNodeABC):
|
||||
# "tooltip": "Color palette preset name or hex colors with weights",
|
||||
# },
|
||||
#),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
],
|
||||
)
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
FUNCTION = "api_call"
|
||||
CATEGORY = "api node/image/Ideogram"
|
||||
DESCRIPTION = cleandoc(__doc__ or "")
|
||||
API_NODE = True
|
||||
|
||||
async def api_call(
|
||||
self,
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt,
|
||||
turbo=False,
|
||||
aspect_ratio="1:1",
|
||||
@@ -499,8 +484,6 @@ class IdeogramV2(ComfyNodeABC):
|
||||
negative_prompt="",
|
||||
num_images=1,
|
||||
color_palette="",
|
||||
unique_id=None,
|
||||
**kwargs,
|
||||
):
|
||||
aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None)
|
||||
resolution = V1_V1_RES_MAP.get(resolution, None)
|
||||
@@ -517,6 +500,10 @@ class IdeogramV2(ComfyNodeABC):
|
||||
else:
|
||||
final_aspect_ratio = aspect_ratio if aspect_ratio != "ASPECT_1_1" else None
|
||||
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/ideogram/generate",
|
||||
@@ -540,7 +527,7 @@ class IdeogramV2(ComfyNodeABC):
|
||||
color_palette=color_palette if color_palette else None,
|
||||
)
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
|
||||
response = await operation.execute()
|
||||
@@ -553,108 +540,110 @@ class IdeogramV2(ComfyNodeABC):
|
||||
if not image_urls:
|
||||
raise Exception("No image URLs were generated in the response")
|
||||
|
||||
display_image_urls_on_node(image_urls, unique_id)
|
||||
return (await download_and_process_images(image_urls),)
|
||||
display_image_urls_on_node(image_urls, cls.hidden.unique_id)
|
||||
return comfy_io.NodeOutput(await download_and_process_images(image_urls))
|
||||
|
||||
class IdeogramV3(ComfyNodeABC):
|
||||
"""
|
||||
Generates images using the Ideogram V3 model. Supports both regular image generation from text prompts and image editing with mask.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
class IdeogramV3(comfy_io.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
return {
|
||||
"required": {
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Prompt for the image generation or editing",
|
||||
},
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
node_id="IdeogramV3",
|
||||
display_name="Ideogram V3",
|
||||
category="api node/image/Ideogram",
|
||||
description="Generates images using the Ideogram V3 model. "
|
||||
"Supports both regular image generation from text prompts and image editing with mask.",
|
||||
is_api_node=True,
|
||||
inputs=[
|
||||
comfy_io.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Prompt for the image generation or editing",
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"image": (
|
||||
IO.IMAGE,
|
||||
{
|
||||
"default": None,
|
||||
"tooltip": "Optional reference image for image editing.",
|
||||
},
|
||||
comfy_io.Image.Input(
|
||||
"image",
|
||||
tooltip="Optional reference image for image editing.",
|
||||
optional=True,
|
||||
),
|
||||
"mask": (
|
||||
IO.MASK,
|
||||
{
|
||||
"default": None,
|
||||
"tooltip": "Optional mask for inpainting (white areas will be replaced)",
|
||||
},
|
||||
comfy_io.Mask.Input(
|
||||
"mask",
|
||||
tooltip="Optional mask for inpainting (white areas will be replaced)",
|
||||
optional=True,
|
||||
),
|
||||
"aspect_ratio": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": list(V3_RATIO_MAP.keys()),
|
||||
"default": "1:1",
|
||||
"tooltip": "The aspect ratio for image generation. Ignored if resolution is not set to Auto.",
|
||||
},
|
||||
comfy_io.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=list(V3_RATIO_MAP.keys()),
|
||||
default="1:1",
|
||||
tooltip="The aspect ratio for image generation. Ignored if resolution is not set to Auto.",
|
||||
optional=True,
|
||||
),
|
||||
"resolution": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": V3_RESOLUTIONS,
|
||||
"default": "Auto",
|
||||
"tooltip": "The resolution for image generation. If not set to Auto, this overrides the aspect_ratio setting.",
|
||||
},
|
||||
comfy_io.Combo.Input(
|
||||
"resolution",
|
||||
options=V3_RESOLUTIONS,
|
||||
default="Auto",
|
||||
tooltip="The resolution for image generation. "
|
||||
"If not set to Auto, this overrides the aspect_ratio setting.",
|
||||
optional=True,
|
||||
),
|
||||
"magic_prompt_option": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": ["AUTO", "ON", "OFF"],
|
||||
"default": "AUTO",
|
||||
"tooltip": "Determine if MagicPrompt should be used in generation",
|
||||
},
|
||||
comfy_io.Combo.Input(
|
||||
"magic_prompt_option",
|
||||
options=["AUTO", "ON", "OFF"],
|
||||
default="AUTO",
|
||||
tooltip="Determine if MagicPrompt should be used in generation",
|
||||
optional=True,
|
||||
),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 2147483647,
|
||||
"step": 1,
|
||||
"control_after_generate": True,
|
||||
"display": "number",
|
||||
},
|
||||
comfy_io.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
control_after_generate=True,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
optional=True,
|
||||
),
|
||||
"num_images": (
|
||||
IO.INT,
|
||||
{"default": 1, "min": 1, "max": 8, "step": 1, "display": "number"},
|
||||
comfy_io.Int.Input(
|
||||
"num_images",
|
||||
default=1,
|
||||
min=1,
|
||||
max=8,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
optional=True,
|
||||
),
|
||||
"rendering_speed": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": ["BALANCED", "TURBO", "QUALITY"],
|
||||
"default": "BALANCED",
|
||||
"tooltip": "Controls the trade-off between generation speed and quality",
|
||||
},
|
||||
comfy_io.Combo.Input(
|
||||
"rendering_speed",
|
||||
options=["DEFAULT", "TURBO", "QUALITY"],
|
||||
default="DEFAULT",
|
||||
tooltip="Controls the trade-off between generation speed and quality",
|
||||
optional=True,
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
comfy_io.Image.Input(
|
||||
"character_image",
|
||||
tooltip="Image to use as character reference.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Mask.Input(
|
||||
"character_mask",
|
||||
tooltip="Optional mask for character reference image.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
],
|
||||
)
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
FUNCTION = "api_call"
|
||||
CATEGORY = "api node/image/Ideogram"
|
||||
DESCRIPTION = cleandoc(__doc__ or "")
|
||||
API_NODE = True
|
||||
|
||||
async def api_call(
|
||||
self,
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt,
|
||||
image=None,
|
||||
mask=None,
|
||||
@@ -663,10 +652,46 @@ class IdeogramV3(ComfyNodeABC):
|
||||
magic_prompt_option="AUTO",
|
||||
seed=0,
|
||||
num_images=1,
|
||||
rendering_speed="BALANCED",
|
||||
unique_id=None,
|
||||
**kwargs,
|
||||
rendering_speed="DEFAULT",
|
||||
character_image=None,
|
||||
character_mask=None,
|
||||
):
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
if rendering_speed == "BALANCED": # for backward compatibility
|
||||
rendering_speed = "DEFAULT"
|
||||
|
||||
character_img_binary = None
|
||||
character_mask_binary = None
|
||||
|
||||
if character_image is not None:
|
||||
input_tensor = character_image.squeeze().cpu()
|
||||
if character_mask is not None:
|
||||
character_mask = resize_mask_to_image(character_mask, character_image, allow_gradient=False)
|
||||
character_mask = 1.0 - character_mask
|
||||
if character_mask.shape[1:] != character_image.shape[1:-1]:
|
||||
raise Exception("Character mask and image must be the same size")
|
||||
|
||||
mask_np = (character_mask.squeeze().cpu().numpy() * 255).astype(np.uint8)
|
||||
mask_img = Image.fromarray(mask_np)
|
||||
mask_byte_arr = BytesIO()
|
||||
mask_img.save(mask_byte_arr, format="PNG")
|
||||
mask_byte_arr.seek(0)
|
||||
character_mask_binary = mask_byte_arr
|
||||
character_mask_binary.name = "mask.png"
|
||||
|
||||
img_np = (input_tensor.numpy() * 255).astype(np.uint8)
|
||||
img = Image.fromarray(img_np)
|
||||
img_byte_arr = BytesIO()
|
||||
img.save(img_byte_arr, format="PNG")
|
||||
img_byte_arr.seek(0)
|
||||
character_img_binary = img_byte_arr
|
||||
character_img_binary.name = "image.png"
|
||||
elif character_mask is not None:
|
||||
raise Exception("Character mask requires character image to be present")
|
||||
|
||||
# Check if both image and mask are provided for editing mode
|
||||
if image is not None and mask is not None:
|
||||
# Edit mode
|
||||
@@ -686,7 +711,7 @@ class IdeogramV3(ComfyNodeABC):
|
||||
# Process image
|
||||
img_np = (input_tensor.numpy() * 255).astype(np.uint8)
|
||||
img = Image.fromarray(img_np)
|
||||
img_byte_arr = io.BytesIO()
|
||||
img_byte_arr = BytesIO()
|
||||
img.save(img_byte_arr, format="PNG")
|
||||
img_byte_arr.seek(0)
|
||||
img_binary = img_byte_arr
|
||||
@@ -695,7 +720,7 @@ class IdeogramV3(ComfyNodeABC):
|
||||
# Process mask - white areas will be replaced
|
||||
mask_np = (mask.squeeze().cpu().numpy() * 255).astype(np.uint8)
|
||||
mask_img = Image.fromarray(mask_np)
|
||||
mask_byte_arr = io.BytesIO()
|
||||
mask_byte_arr = BytesIO()
|
||||
mask_img.save(mask_byte_arr, format="PNG")
|
||||
mask_byte_arr.seek(0)
|
||||
mask_binary = mask_byte_arr
|
||||
@@ -715,6 +740,15 @@ class IdeogramV3(ComfyNodeABC):
|
||||
if num_images > 1:
|
||||
edit_request.num_images = num_images
|
||||
|
||||
files = {
|
||||
"image": img_binary,
|
||||
"mask": mask_binary,
|
||||
}
|
||||
if character_img_binary:
|
||||
files["character_reference_images"] = character_img_binary
|
||||
if character_mask_binary:
|
||||
files["character_mask_binary"] = character_mask_binary
|
||||
|
||||
# Execute the operation for edit mode
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
@@ -724,12 +758,9 @@ class IdeogramV3(ComfyNodeABC):
|
||||
response_model=IdeogramGenerateResponse,
|
||||
),
|
||||
request=edit_request,
|
||||
files={
|
||||
"image": img_binary,
|
||||
"mask": mask_binary,
|
||||
},
|
||||
files=files,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=kwargs,
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
|
||||
elif image is not None or mask is not None:
|
||||
@@ -761,6 +792,14 @@ class IdeogramV3(ComfyNodeABC):
|
||||
if num_images > 1:
|
||||
gen_request.num_images = num_images
|
||||
|
||||
files = {}
|
||||
if character_img_binary:
|
||||
files["character_reference_images"] = character_img_binary
|
||||
if character_mask_binary:
|
||||
files["character_mask_binary"] = character_mask_binary
|
||||
if files:
|
||||
gen_request.style_type = "AUTO"
|
||||
|
||||
# Execute the operation for generation mode
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
@@ -770,7 +809,9 @@ class IdeogramV3(ComfyNodeABC):
|
||||
response_model=IdeogramGenerateResponse,
|
||||
),
|
||||
request=gen_request,
|
||||
auth_kwargs=kwargs,
|
||||
files=files if files else None,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
|
||||
# Execute the operation and process response
|
||||
@@ -784,18 +825,18 @@ class IdeogramV3(ComfyNodeABC):
|
||||
if not image_urls:
|
||||
raise Exception("No image URLs were generated in the response")
|
||||
|
||||
display_image_urls_on_node(image_urls, unique_id)
|
||||
return (await download_and_process_images(image_urls),)
|
||||
display_image_urls_on_node(image_urls, cls.hidden.unique_id)
|
||||
return comfy_io.NodeOutput(await download_and_process_images(image_urls))
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"IdeogramV1": IdeogramV1,
|
||||
"IdeogramV2": IdeogramV2,
|
||||
"IdeogramV3": IdeogramV3,
|
||||
}
|
||||
class IdeogramExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
|
||||
return [
|
||||
IdeogramV1,
|
||||
IdeogramV2,
|
||||
IdeogramV3,
|
||||
]
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"IdeogramV1": "Ideogram V1",
|
||||
"IdeogramV2": "Ideogram V2",
|
||||
"IdeogramV3": "Ideogram V3",
|
||||
}
|
||||
async def comfy_entrypoint() -> IdeogramExtension:
|
||||
return IdeogramExtension()
|
||||
|
||||
@@ -998,7 +998,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"OpenAIDalle2": "OpenAI DALL·E 2",
|
||||
"OpenAIDalle3": "OpenAI DALL·E 3",
|
||||
"OpenAIGPTImage1": "OpenAI GPT Image 1",
|
||||
"OpenAIChatNode": "OpenAI Chat",
|
||||
"OpenAIInputFiles": "OpenAI Chat Input Files",
|
||||
"OpenAIChatConfig": "OpenAI Chat Advanced Options",
|
||||
"OpenAIChatNode": "OpenAI ChatGPT",
|
||||
"OpenAIInputFiles": "OpenAI ChatGPT Input Files",
|
||||
"OpenAIChatConfig": "OpenAI ChatGPT Advanced Options",
|
||||
}
|
||||
|
||||
+357
-387
@@ -12,6 +12,7 @@ User Guides:
|
||||
"""
|
||||
|
||||
from typing import Union, Optional, Any
|
||||
from typing_extensions import override
|
||||
from enum import Enum
|
||||
|
||||
import torch
|
||||
@@ -46,9 +47,9 @@ from comfy_api_nodes.apinode_utils import (
|
||||
validate_string,
|
||||
download_url_to_image_tensor,
|
||||
)
|
||||
from comfy_api_nodes.mapper_utils import model_field_to_node_input
|
||||
from comfy_api.input_impl import VideoFromFile
|
||||
from comfy.comfy_types.node_typing import IO, ComfyNodeABC
|
||||
from comfy_api.latest import ComfyExtension, io as comfy_io
|
||||
from comfy_api_nodes.util.validation_utils import validate_image_dimensions, validate_image_aspect_ratio
|
||||
|
||||
PATH_IMAGE_TO_VIDEO = "/proxy/runway/image_to_video"
|
||||
PATH_TEXT_TO_IMAGE = "/proxy/runway/text_to_image"
|
||||
@@ -85,20 +86,11 @@ class RunwayGen3aAspectRatio(str, Enum):
|
||||
|
||||
def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]:
|
||||
"""Returns the video URL from the task status response if it exists."""
|
||||
if response.output and len(response.output) > 0:
|
||||
if hasattr(response, "output") and len(response.output) > 0:
|
||||
return response.output[0]
|
||||
return None
|
||||
|
||||
|
||||
# TODO: replace with updated image validation utils (upstream)
|
||||
def validate_input_image(image: torch.Tensor) -> bool:
|
||||
"""
|
||||
Validate the input image is within the size limits for the Runway API.
|
||||
See: https://docs.dev.runwayml.com/assets/inputs/#common-error-reasons
|
||||
"""
|
||||
return image.shape[2] < 8000 and image.shape[1] < 8000
|
||||
|
||||
|
||||
async def poll_until_finished(
|
||||
auth_kwargs: dict[str, str],
|
||||
api_endpoint: ApiEndpoint[Any, TaskStatusResponse],
|
||||
@@ -134,458 +126,438 @@ def extract_progress_from_task_status(
|
||||
|
||||
def get_image_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]:
|
||||
"""Returns the image URL from the task status response if it exists."""
|
||||
if response.output and len(response.output) > 0:
|
||||
if hasattr(response, "output") and len(response.output) > 0:
|
||||
return response.output[0]
|
||||
return None
|
||||
|
||||
|
||||
class RunwayVideoGenNode(ComfyNodeABC):
|
||||
"""Runway Video Node Base."""
|
||||
async def get_response(
|
||||
task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None, estimated_duration: Optional[int] = None
|
||||
) -> TaskStatusResponse:
|
||||
"""Poll the task status until it is finished then get the response."""
|
||||
return await poll_until_finished(
|
||||
auth_kwargs,
|
||||
ApiEndpoint(
|
||||
path=f"{PATH_GET_TASK_STATUS}/{task_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=TaskStatusResponse,
|
||||
),
|
||||
estimated_duration=estimated_duration,
|
||||
node_id=node_id,
|
||||
)
|
||||
|
||||
RETURN_TYPES = ("VIDEO",)
|
||||
FUNCTION = "api_call"
|
||||
CATEGORY = "api node/video/Runway"
|
||||
API_NODE = True
|
||||
|
||||
def validate_task_created(self, response: RunwayImageToVideoResponse) -> bool:
|
||||
"""
|
||||
Validate the task creation response from the Runway API matches
|
||||
expected format.
|
||||
"""
|
||||
if not bool(response.id):
|
||||
raise RunwayApiError("Invalid initial response from Runway API.")
|
||||
return True
|
||||
async def generate_video(
|
||||
request: RunwayImageToVideoRequest,
|
||||
auth_kwargs: dict[str, str],
|
||||
node_id: Optional[str] = None,
|
||||
estimated_duration: Optional[int] = None,
|
||||
) -> VideoFromFile:
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_IMAGE_TO_VIDEO,
|
||||
method=HttpMethod.POST,
|
||||
request_model=RunwayImageToVideoRequest,
|
||||
response_model=RunwayImageToVideoResponse,
|
||||
),
|
||||
request=request,
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
|
||||
def validate_response(self, response: RunwayImageToVideoResponse) -> bool:
|
||||
"""
|
||||
Validate the successful task status response from the Runway API
|
||||
matches expected format.
|
||||
"""
|
||||
if not response.output or len(response.output) == 0:
|
||||
raise RunwayApiError(
|
||||
"Runway task succeeded but no video data found in response."
|
||||
)
|
||||
return True
|
||||
initial_response = await initial_operation.execute()
|
||||
|
||||
async def get_response(
|
||||
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
|
||||
) -> RunwayImageToVideoResponse:
|
||||
"""Poll the task status until it is finished then get the response."""
|
||||
return await poll_until_finished(
|
||||
auth_kwargs,
|
||||
ApiEndpoint(
|
||||
path=f"{PATH_GET_TASK_STATUS}/{task_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=TaskStatusResponse,
|
||||
),
|
||||
estimated_duration=AVERAGE_DURATION_FLF_SECONDS,
|
||||
node_id=node_id,
|
||||
final_response = await get_response(initial_response.id, auth_kwargs, node_id, estimated_duration)
|
||||
if not final_response.output:
|
||||
raise RunwayApiError("Runway task succeeded but no video data found in response.")
|
||||
|
||||
video_url = get_video_url_from_task_status(final_response)
|
||||
return await download_url_to_video_output(video_url)
|
||||
|
||||
|
||||
class RunwayImageToVideoNodeGen3a(comfy_io.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
node_id="RunwayImageToVideoNodeGen3a",
|
||||
display_name="Runway Image to Video (Gen3a Turbo)",
|
||||
category="api node/video/Runway",
|
||||
description="Generate a video from a single starting frame using Gen3a Turbo model. "
|
||||
"Before diving in, review these best practices to ensure that "
|
||||
"your input selections will set your generation up for success: "
|
||||
"https://help.runwayml.com/hc/en-us/articles/33927968552339-Creating-with-Act-One-on-Gen-3-Alpha-and-Turbo.",
|
||||
inputs=[
|
||||
comfy_io.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Text prompt for the generation",
|
||||
),
|
||||
comfy_io.Image.Input(
|
||||
"start_frame",
|
||||
tooltip="Start frame to be used for the video",
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
"duration",
|
||||
options=[model.value for model in Duration],
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
"ratio",
|
||||
options=[model.value for model in RunwayGen3aAspectRatio],
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=4294967295,
|
||||
step=1,
|
||||
control_after_generate=True,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
tooltip="Random seed for generation",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
async def generate_video(
|
||||
self,
|
||||
request: RunwayImageToVideoRequest,
|
||||
auth_kwargs: dict[str, str],
|
||||
node_id: Optional[str] = None,
|
||||
) -> tuple[VideoFromFile]:
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_IMAGE_TO_VIDEO,
|
||||
method=HttpMethod.POST,
|
||||
request_model=RunwayImageToVideoRequest,
|
||||
response_model=RunwayImageToVideoResponse,
|
||||
),
|
||||
request=request,
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
start_frame: torch.Tensor,
|
||||
duration: str,
|
||||
ratio: str,
|
||||
seed: int,
|
||||
) -> comfy_io.NodeOutput:
|
||||
validate_string(prompt, min_length=1)
|
||||
validate_image_dimensions(start_frame, max_width=7999, max_height=7999)
|
||||
validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
|
||||
|
||||
auth_kwargs = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
|
||||
download_urls = await upload_images_to_comfyapi(
|
||||
start_frame,
|
||||
max_images=1,
|
||||
mime_type="image/png",
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
|
||||
initial_response = await initial_operation.execute()
|
||||
self.validate_task_created(initial_response)
|
||||
task_id = initial_response.id
|
||||
|
||||
final_response = await self.get_response(task_id, auth_kwargs, node_id)
|
||||
self.validate_response(final_response)
|
||||
|
||||
video_url = get_video_url_from_task_status(final_response)
|
||||
return (await download_url_to_video_output(video_url),)
|
||||
return comfy_io.NodeOutput(
|
||||
await generate_video(
|
||||
RunwayImageToVideoRequest(
|
||||
promptText=prompt,
|
||||
seed=seed,
|
||||
model=Model("gen3a_turbo"),
|
||||
duration=Duration(duration),
|
||||
ratio=AspectRatio(ratio),
|
||||
promptImage=RunwayPromptImageObject(
|
||||
root=[
|
||||
RunwayPromptImageDetailedObject(
|
||||
uri=str(download_urls[0]), position="first"
|
||||
)
|
||||
]
|
||||
),
|
||||
),
|
||||
auth_kwargs=auth_kwargs,
|
||||
node_id=cls.hidden.unique_id,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class RunwayImageToVideoNodeGen3a(RunwayVideoGenNode):
|
||||
"""Runway Image to Video Node using Gen3a Turbo model."""
|
||||
|
||||
DESCRIPTION = "Generate a video from a single starting frame using Gen3a Turbo model. Before diving in, review these best practices to ensure that your input selections will set your generation up for success: https://help.runwayml.com/hc/en-us/articles/33927968552339-Creating-with-Act-One-on-Gen-3-Alpha-and-Turbo."
|
||||
class RunwayImageToVideoNodeGen4(comfy_io.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"prompt": model_field_to_node_input(
|
||||
IO.STRING, RunwayImageToVideoRequest, "promptText", multiline=True
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
node_id="RunwayImageToVideoNodeGen4",
|
||||
display_name="Runway Image to Video (Gen4 Turbo)",
|
||||
category="api node/video/Runway",
|
||||
description="Generate a video from a single starting frame using Gen4 Turbo model. "
|
||||
"Before diving in, review these best practices to ensure that "
|
||||
"your input selections will set your generation up for success: "
|
||||
"https://help.runwayml.com/hc/en-us/articles/37327109429011-Creating-with-Gen-4-Video.",
|
||||
inputs=[
|
||||
comfy_io.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Text prompt for the generation",
|
||||
),
|
||||
"start_frame": (
|
||||
IO.IMAGE,
|
||||
{"tooltip": "Start frame to be used for the video"},
|
||||
comfy_io.Image.Input(
|
||||
"start_frame",
|
||||
tooltip="Start frame to be used for the video",
|
||||
),
|
||||
"duration": model_field_to_node_input(
|
||||
IO.COMBO, RunwayImageToVideoRequest, "duration", enum_type=Duration
|
||||
comfy_io.Combo.Input(
|
||||
"duration",
|
||||
options=[model.value for model in Duration],
|
||||
),
|
||||
"ratio": model_field_to_node_input(
|
||||
IO.COMBO,
|
||||
RunwayImageToVideoRequest,
|
||||
comfy_io.Combo.Input(
|
||||
"ratio",
|
||||
enum_type=RunwayGen3aAspectRatio,
|
||||
options=[model.value for model in RunwayGen4TurboAspectRatio],
|
||||
),
|
||||
"seed": model_field_to_node_input(
|
||||
IO.INT,
|
||||
RunwayImageToVideoRequest,
|
||||
comfy_io.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=4294967295,
|
||||
step=1,
|
||||
control_after_generate=True,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
tooltip="Random seed for generation",
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
async def api_call(
|
||||
self,
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
start_frame: torch.Tensor,
|
||||
duration: str,
|
||||
ratio: str,
|
||||
seed: int,
|
||||
unique_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> tuple[VideoFromFile]:
|
||||
# Validate inputs
|
||||
) -> comfy_io.NodeOutput:
|
||||
validate_string(prompt, min_length=1)
|
||||
validate_input_image(start_frame)
|
||||
validate_image_dimensions(start_frame, max_width=7999, max_height=7999)
|
||||
validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
|
||||
|
||||
auth_kwargs = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
|
||||
# Upload image
|
||||
download_urls = await upload_images_to_comfyapi(
|
||||
start_frame,
|
||||
max_images=1,
|
||||
mime_type="image/png",
|
||||
auth_kwargs=kwargs,
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
if len(download_urls) != 1:
|
||||
raise RunwayApiError("Failed to upload one or more images to comfy api.")
|
||||
|
||||
return await self.generate_video(
|
||||
RunwayImageToVideoRequest(
|
||||
promptText=prompt,
|
||||
seed=seed,
|
||||
model=Model("gen3a_turbo"),
|
||||
duration=Duration(duration),
|
||||
ratio=AspectRatio(ratio),
|
||||
promptImage=RunwayPromptImageObject(
|
||||
root=[
|
||||
RunwayPromptImageDetailedObject(
|
||||
uri=str(download_urls[0]), position="first"
|
||||
)
|
||||
]
|
||||
return comfy_io.NodeOutput(
|
||||
await generate_video(
|
||||
RunwayImageToVideoRequest(
|
||||
promptText=prompt,
|
||||
seed=seed,
|
||||
model=Model("gen4_turbo"),
|
||||
duration=Duration(duration),
|
||||
ratio=AspectRatio(ratio),
|
||||
promptImage=RunwayPromptImageObject(
|
||||
root=[
|
||||
RunwayPromptImageDetailedObject(
|
||||
uri=str(download_urls[0]), position="first"
|
||||
)
|
||||
]
|
||||
),
|
||||
),
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
node_id=unique_id,
|
||||
auth_kwargs=auth_kwargs,
|
||||
node_id=cls.hidden.unique_id,
|
||||
estimated_duration=AVERAGE_DURATION_FLF_SECONDS,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class RunwayImageToVideoNodeGen4(RunwayVideoGenNode):
|
||||
"""Runway Image to Video Node using Gen4 Turbo model."""
|
||||
|
||||
DESCRIPTION = "Generate a video from a single starting frame using Gen4 Turbo model. Before diving in, review these best practices to ensure that your input selections will set your generation up for success: https://help.runwayml.com/hc/en-us/articles/37327109429011-Creating-with-Gen-4-Video."
|
||||
class RunwayFirstLastFrameNode(comfy_io.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"prompt": model_field_to_node_input(
|
||||
IO.STRING, RunwayImageToVideoRequest, "promptText", multiline=True
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
node_id="RunwayFirstLastFrameNode",
|
||||
display_name="Runway First-Last-Frame to Video",
|
||||
category="api node/video/Runway",
|
||||
description="Upload first and last keyframes, draft a prompt, and generate a video. "
|
||||
"More complex transitions, such as cases where the Last frame is completely different "
|
||||
"from the First frame, may benefit from the longer 10s duration. "
|
||||
"This would give the generation more time to smoothly transition between the two inputs. "
|
||||
"Before diving in, review these best practices to ensure that your input selections "
|
||||
"will set your generation up for success: "
|
||||
"https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3.",
|
||||
inputs=[
|
||||
comfy_io.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Text prompt for the generation",
|
||||
),
|
||||
"start_frame": (
|
||||
IO.IMAGE,
|
||||
{"tooltip": "Start frame to be used for the video"},
|
||||
comfy_io.Image.Input(
|
||||
"start_frame",
|
||||
tooltip="Start frame to be used for the video",
|
||||
),
|
||||
"duration": model_field_to_node_input(
|
||||
IO.COMBO, RunwayImageToVideoRequest, "duration", enum_type=Duration
|
||||
comfy_io.Image.Input(
|
||||
"end_frame",
|
||||
tooltip="End frame to be used for the video. Supported for gen3a_turbo only.",
|
||||
),
|
||||
"ratio": model_field_to_node_input(
|
||||
IO.COMBO,
|
||||
RunwayImageToVideoRequest,
|
||||
comfy_io.Combo.Input(
|
||||
"duration",
|
||||
options=[model.value for model in Duration],
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
"ratio",
|
||||
enum_type=RunwayGen4TurboAspectRatio,
|
||||
options=[model.value for model in RunwayGen3aAspectRatio],
|
||||
),
|
||||
"seed": model_field_to_node_input(
|
||||
IO.INT,
|
||||
RunwayImageToVideoRequest,
|
||||
comfy_io.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=4294967295,
|
||||
step=1,
|
||||
control_after_generate=True,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
tooltip="Random seed for generation",
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
async def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
start_frame: torch.Tensor,
|
||||
duration: str,
|
||||
ratio: str,
|
||||
seed: int,
|
||||
unique_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> tuple[VideoFromFile]:
|
||||
# Validate inputs
|
||||
validate_string(prompt, min_length=1)
|
||||
validate_input_image(start_frame)
|
||||
|
||||
# Upload image
|
||||
download_urls = await upload_images_to_comfyapi(
|
||||
start_frame,
|
||||
max_images=1,
|
||||
mime_type="image/png",
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
if len(download_urls) != 1:
|
||||
raise RunwayApiError("Failed to upload one or more images to comfy api.")
|
||||
|
||||
return await self.generate_video(
|
||||
RunwayImageToVideoRequest(
|
||||
promptText=prompt,
|
||||
seed=seed,
|
||||
model=Model("gen4_turbo"),
|
||||
duration=Duration(duration),
|
||||
ratio=AspectRatio(ratio),
|
||||
promptImage=RunwayPromptImageObject(
|
||||
root=[
|
||||
RunwayPromptImageDetailedObject(
|
||||
uri=str(download_urls[0]), position="first"
|
||||
)
|
||||
]
|
||||
),
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
node_id=unique_id,
|
||||
)
|
||||
|
||||
|
||||
class RunwayFirstLastFrameNode(RunwayVideoGenNode):
|
||||
"""Runway First-Last Frame Node."""
|
||||
|
||||
DESCRIPTION = "Upload first and last keyframes, draft a prompt, and generate a video. More complex transitions, such as cases where the Last frame is completely different from the First frame, may benefit from the longer 10s duration. This would give the generation more time to smoothly transition between the two inputs. Before diving in, review these best practices to ensure that your input selections will set your generation up for success: https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3."
|
||||
|
||||
async def get_response(
|
||||
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
|
||||
) -> RunwayImageToVideoResponse:
|
||||
return await poll_until_finished(
|
||||
auth_kwargs,
|
||||
ApiEndpoint(
|
||||
path=f"{PATH_GET_TASK_STATUS}/{task_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=TaskStatusResponse,
|
||||
),
|
||||
estimated_duration=AVERAGE_DURATION_FLF_SECONDS,
|
||||
node_id=node_id,
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"prompt": model_field_to_node_input(
|
||||
IO.STRING, RunwayImageToVideoRequest, "promptText", multiline=True
|
||||
),
|
||||
"start_frame": (
|
||||
IO.IMAGE,
|
||||
{"tooltip": "Start frame to be used for the video"},
|
||||
),
|
||||
"end_frame": (
|
||||
IO.IMAGE,
|
||||
{
|
||||
"tooltip": "End frame to be used for the video. Supported for gen3a_turbo only."
|
||||
},
|
||||
),
|
||||
"duration": model_field_to_node_input(
|
||||
IO.COMBO, RunwayImageToVideoRequest, "duration", enum_type=Duration
|
||||
),
|
||||
"ratio": model_field_to_node_input(
|
||||
IO.COMBO,
|
||||
RunwayImageToVideoRequest,
|
||||
"ratio",
|
||||
enum_type=RunwayGen3aAspectRatio,
|
||||
),
|
||||
"seed": model_field_to_node_input(
|
||||
IO.INT,
|
||||
RunwayImageToVideoRequest,
|
||||
"seed",
|
||||
control_after_generate=True,
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
|
||||
async def api_call(
|
||||
self,
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
start_frame: torch.Tensor,
|
||||
end_frame: torch.Tensor,
|
||||
duration: str,
|
||||
ratio: str,
|
||||
seed: int,
|
||||
unique_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> tuple[VideoFromFile]:
|
||||
# Validate inputs
|
||||
) -> comfy_io.NodeOutput:
|
||||
validate_string(prompt, min_length=1)
|
||||
validate_input_image(start_frame)
|
||||
validate_input_image(end_frame)
|
||||
validate_image_dimensions(start_frame, max_width=7999, max_height=7999)
|
||||
validate_image_dimensions(end_frame, max_width=7999, max_height=7999)
|
||||
validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
|
||||
validate_image_aspect_ratio(end_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
|
||||
|
||||
auth_kwargs = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
|
||||
# Upload images
|
||||
stacked_input_images = image_tensor_pair_to_batch(start_frame, end_frame)
|
||||
download_urls = await upload_images_to_comfyapi(
|
||||
stacked_input_images,
|
||||
max_images=2,
|
||||
mime_type="image/png",
|
||||
auth_kwargs=kwargs,
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
if len(download_urls) != 2:
|
||||
raise RunwayApiError("Failed to upload one or more images to comfy api.")
|
||||
|
||||
return await self.generate_video(
|
||||
RunwayImageToVideoRequest(
|
||||
promptText=prompt,
|
||||
seed=seed,
|
||||
model=Model("gen3a_turbo"),
|
||||
duration=Duration(duration),
|
||||
ratio=AspectRatio(ratio),
|
||||
promptImage=RunwayPromptImageObject(
|
||||
root=[
|
||||
RunwayPromptImageDetailedObject(
|
||||
uri=str(download_urls[0]), position="first"
|
||||
),
|
||||
RunwayPromptImageDetailedObject(
|
||||
uri=str(download_urls[1]), position="last"
|
||||
),
|
||||
]
|
||||
return comfy_io.NodeOutput(
|
||||
await generate_video(
|
||||
RunwayImageToVideoRequest(
|
||||
promptText=prompt,
|
||||
seed=seed,
|
||||
model=Model("gen3a_turbo"),
|
||||
duration=Duration(duration),
|
||||
ratio=AspectRatio(ratio),
|
||||
promptImage=RunwayPromptImageObject(
|
||||
root=[
|
||||
RunwayPromptImageDetailedObject(
|
||||
uri=str(download_urls[0]), position="first"
|
||||
),
|
||||
RunwayPromptImageDetailedObject(
|
||||
uri=str(download_urls[1]), position="last"
|
||||
),
|
||||
]
|
||||
),
|
||||
),
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
node_id=unique_id,
|
||||
auth_kwargs=auth_kwargs,
|
||||
node_id=cls.hidden.unique_id,
|
||||
estimated_duration=AVERAGE_DURATION_FLF_SECONDS,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class RunwayTextToImageNode(ComfyNodeABC):
|
||||
"""Runway Text to Image Node."""
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "api_call"
|
||||
CATEGORY = "api node/image/Runway"
|
||||
API_NODE = True
|
||||
DESCRIPTION = "Generate an image from a text prompt using Runway's Gen 4 model. You can also include reference images to guide the generation."
|
||||
class RunwayTextToImageNode(comfy_io.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"prompt": model_field_to_node_input(
|
||||
IO.STRING, RunwayTextToImageRequest, "promptText", multiline=True
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
node_id="RunwayTextToImageNode",
|
||||
display_name="Runway Text to Image",
|
||||
category="api node/image/Runway",
|
||||
description="Generate an image from a text prompt using Runway's Gen 4 model. "
|
||||
"You can also include reference image to guide the generation.",
|
||||
inputs=[
|
||||
comfy_io.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Text prompt for the generation",
|
||||
),
|
||||
"ratio": model_field_to_node_input(
|
||||
IO.COMBO,
|
||||
RunwayTextToImageRequest,
|
||||
comfy_io.Combo.Input(
|
||||
"ratio",
|
||||
enum_type=RunwayTextToImageAspectRatioEnum,
|
||||
options=[model.value for model in RunwayTextToImageAspectRatioEnum],
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"reference_image": (
|
||||
IO.IMAGE,
|
||||
{"tooltip": "Optional reference image to guide the generation"},
|
||||
)
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
def validate_task_created(self, response: RunwayTextToImageResponse) -> bool:
|
||||
"""
|
||||
Validate the task creation response from the Runway API matches
|
||||
expected format.
|
||||
"""
|
||||
if not bool(response.id):
|
||||
raise RunwayApiError("Invalid initial response from Runway API.")
|
||||
return True
|
||||
|
||||
def validate_response(self, response: TaskStatusResponse) -> bool:
|
||||
"""
|
||||
Validate the successful task status response from the Runway API
|
||||
matches expected format.
|
||||
"""
|
||||
if not response.output or len(response.output) == 0:
|
||||
raise RunwayApiError(
|
||||
"Runway task succeeded but no image data found in response."
|
||||
)
|
||||
return True
|
||||
|
||||
async def get_response(
|
||||
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
|
||||
) -> TaskStatusResponse:
|
||||
"""Poll the task status until it is finished then get the response."""
|
||||
return await poll_until_finished(
|
||||
auth_kwargs,
|
||||
ApiEndpoint(
|
||||
path=f"{PATH_GET_TASK_STATUS}/{task_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=TaskStatusResponse,
|
||||
),
|
||||
estimated_duration=AVERAGE_DURATION_T2I_SECONDS,
|
||||
node_id=node_id,
|
||||
comfy_io.Image.Input(
|
||||
"reference_image",
|
||||
tooltip="Optional reference image to guide the generation",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
async def api_call(
|
||||
self,
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
ratio: str,
|
||||
reference_image: Optional[torch.Tensor] = None,
|
||||
unique_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor]:
|
||||
# Validate inputs
|
||||
) -> comfy_io.NodeOutput:
|
||||
validate_string(prompt, min_length=1)
|
||||
|
||||
auth_kwargs = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
|
||||
# Prepare reference images if provided
|
||||
reference_images = None
|
||||
if reference_image is not None:
|
||||
validate_input_image(reference_image)
|
||||
validate_image_dimensions(reference_image, max_width=7999, max_height=7999)
|
||||
validate_image_aspect_ratio(reference_image, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
|
||||
download_urls = await upload_images_to_comfyapi(
|
||||
reference_image,
|
||||
max_images=1,
|
||||
mime_type="image/png",
|
||||
auth_kwargs=kwargs,
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
if len(download_urls) != 1:
|
||||
raise RunwayApiError("Failed to upload reference image to comfy api.")
|
||||
|
||||
reference_images = [ReferenceImage(uri=str(download_urls[0]))]
|
||||
|
||||
# Create request
|
||||
request = RunwayTextToImageRequest(
|
||||
promptText=prompt,
|
||||
model=Model4.gen4_image,
|
||||
@@ -593,7 +565,6 @@ class RunwayTextToImageNode(ComfyNodeABC):
|
||||
referenceImages=reference_images,
|
||||
)
|
||||
|
||||
# Execute initial request
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_TEXT_TO_IMAGE,
|
||||
@@ -602,34 +573,33 @@ class RunwayTextToImageNode(ComfyNodeABC):
|
||||
response_model=RunwayTextToImageResponse,
|
||||
),
|
||||
request=request,
|
||||
auth_kwargs=kwargs,
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
|
||||
initial_response = await initial_operation.execute()
|
||||
self.validate_task_created(initial_response)
|
||||
task_id = initial_response.id
|
||||
|
||||
# Poll for completion
|
||||
final_response = await self.get_response(
|
||||
task_id, auth_kwargs=kwargs, node_id=unique_id
|
||||
final_response = await get_response(
|
||||
initial_response.id,
|
||||
auth_kwargs=auth_kwargs,
|
||||
node_id=cls.hidden.unique_id,
|
||||
estimated_duration=AVERAGE_DURATION_T2I_SECONDS,
|
||||
)
|
||||
self.validate_response(final_response)
|
||||
if not final_response.output:
|
||||
raise RunwayApiError("Runway task succeeded but no image data found in response.")
|
||||
|
||||
# Download and return image
|
||||
image_url = get_image_url_from_task_status(final_response)
|
||||
return (await download_url_to_image_tensor(image_url),)
|
||||
return comfy_io.NodeOutput(await download_url_to_image_tensor(get_image_url_from_task_status(final_response)))
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"RunwayFirstLastFrameNode": RunwayFirstLastFrameNode,
|
||||
"RunwayImageToVideoNodeGen3a": RunwayImageToVideoNodeGen3a,
|
||||
"RunwayImageToVideoNodeGen4": RunwayImageToVideoNodeGen4,
|
||||
"RunwayTextToImageNode": RunwayTextToImageNode,
|
||||
}
|
||||
class RunwayExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
|
||||
return [
|
||||
RunwayFirstLastFrameNode,
|
||||
RunwayImageToVideoNodeGen3a,
|
||||
RunwayImageToVideoNodeGen4,
|
||||
RunwayTextToImageNode,
|
||||
]
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"RunwayFirstLastFrameNode": "Runway First-Last-Frame to Video",
|
||||
"RunwayImageToVideoNodeGen3a": "Runway Image to Video (Gen3a Turbo)",
|
||||
"RunwayImageToVideoNodeGen4": "Runway Image to Video (Gen4 Turbo)",
|
||||
"RunwayTextToImageNode": "Runway Text to Image",
|
||||
}
|
||||
async def comfy_entrypoint() -> RunwayExtension:
|
||||
return RunwayExtension()
|
||||
|
||||
+675
-313
File diff suppressed because it is too large
Load Diff
+190
-142
@@ -1,17 +1,18 @@
|
||||
import io
|
||||
import logging
|
||||
import base64
|
||||
import aiohttp
|
||||
import torch
|
||||
from io import BytesIO
|
||||
from typing import Optional
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy.comfy_types.node_typing import IO, ComfyNodeABC
|
||||
from comfy_api.latest import ComfyExtension, io as comfy_io
|
||||
from comfy_api.input_impl.video_types import VideoFromFile
|
||||
from comfy_api_nodes.apis import (
|
||||
VeoGenVidRequest,
|
||||
VeoGenVidResponse,
|
||||
VeoGenVidPollRequest,
|
||||
VeoGenVidPollResponse
|
||||
VeoGenVidPollResponse,
|
||||
)
|
||||
from comfy_api_nodes.apis.client import (
|
||||
ApiEndpoint,
|
||||
@@ -22,7 +23,7 @@ from comfy_api_nodes.apis.client import (
|
||||
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
downscale_image_tensor,
|
||||
tensor_to_base64_string
|
||||
tensor_to_base64_string,
|
||||
)
|
||||
|
||||
AVERAGE_DURATION_VIDEO_GEN = 32
|
||||
@@ -50,7 +51,7 @@ def get_video_url_from_response(poll_response: VeoGenVidPollResponse) -> Optiona
|
||||
return None
|
||||
|
||||
|
||||
class VeoVideoGenerationNode(ComfyNodeABC):
|
||||
class VeoVideoGenerationNode(comfy_io.ComfyNode):
|
||||
"""
|
||||
Generates videos from text prompts using Google's Veo API.
|
||||
|
||||
@@ -59,101 +60,93 @@ class VeoVideoGenerationNode(ComfyNodeABC):
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Text description of the video",
|
||||
},
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
node_id="VeoVideoGenerationNode",
|
||||
display_name="Google Veo 2 Video Generation",
|
||||
category="api node/video/Veo",
|
||||
description="Generates videos from text prompts using Google's Veo 2 API",
|
||||
inputs=[
|
||||
comfy_io.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Text description of the video",
|
||||
),
|
||||
"aspect_ratio": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": ["16:9", "9:16"],
|
||||
"default": "16:9",
|
||||
"tooltip": "Aspect ratio of the output video",
|
||||
},
|
||||
comfy_io.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=["16:9", "9:16"],
|
||||
default="16:9",
|
||||
tooltip="Aspect ratio of the output video",
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"negative_prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Negative text prompt to guide what to avoid in the video",
|
||||
},
|
||||
comfy_io.String.Input(
|
||||
"negative_prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Negative text prompt to guide what to avoid in the video",
|
||||
optional=True,
|
||||
),
|
||||
"duration_seconds": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 5,
|
||||
"min": 5,
|
||||
"max": 8,
|
||||
"step": 1,
|
||||
"display": "number",
|
||||
"tooltip": "Duration of the output video in seconds",
|
||||
},
|
||||
comfy_io.Int.Input(
|
||||
"duration_seconds",
|
||||
default=5,
|
||||
min=5,
|
||||
max=8,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
tooltip="Duration of the output video in seconds",
|
||||
optional=True,
|
||||
),
|
||||
"enhance_prompt": (
|
||||
IO.BOOLEAN,
|
||||
{
|
||||
"default": True,
|
||||
"tooltip": "Whether to enhance the prompt with AI assistance",
|
||||
}
|
||||
comfy_io.Boolean.Input(
|
||||
"enhance_prompt",
|
||||
default=True,
|
||||
tooltip="Whether to enhance the prompt with AI assistance",
|
||||
optional=True,
|
||||
),
|
||||
"person_generation": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": ["ALLOW", "BLOCK"],
|
||||
"default": "ALLOW",
|
||||
"tooltip": "Whether to allow generating people in the video",
|
||||
},
|
||||
comfy_io.Combo.Input(
|
||||
"person_generation",
|
||||
options=["ALLOW", "BLOCK"],
|
||||
default="ALLOW",
|
||||
tooltip="Whether to allow generating people in the video",
|
||||
optional=True,
|
||||
),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 0xFFFFFFFF,
|
||||
"step": 1,
|
||||
"display": "number",
|
||||
"control_after_generate": True,
|
||||
"tooltip": "Seed for video generation (0 for random)",
|
||||
},
|
||||
comfy_io.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=0xFFFFFFFF,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed for video generation (0 for random)",
|
||||
optional=True,
|
||||
),
|
||||
"image": (IO.IMAGE, {
|
||||
"default": None,
|
||||
"tooltip": "Optional reference image to guide video generation",
|
||||
}),
|
||||
"model": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": ["veo-2.0-generate-001"],
|
||||
"default": "veo-2.0-generate-001",
|
||||
"tooltip": "Veo 2 model to use for video generation",
|
||||
},
|
||||
comfy_io.Image.Input(
|
||||
"image",
|
||||
tooltip="Optional reference image to guide video generation",
|
||||
optional=True,
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
comfy_io.Combo.Input(
|
||||
"model",
|
||||
options=["veo-2.0-generate-001"],
|
||||
default="veo-2.0-generate-001",
|
||||
tooltip="Veo 2 model to use for video generation",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
RETURN_TYPES = (IO.VIDEO,)
|
||||
FUNCTION = "generate_video"
|
||||
CATEGORY = "api node/video/Veo"
|
||||
DESCRIPTION = "Generates videos from text prompts using Google's Veo 2 API"
|
||||
API_NODE = True
|
||||
|
||||
async def generate_video(
|
||||
self,
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt,
|
||||
aspect_ratio="16:9",
|
||||
negative_prompt="",
|
||||
@@ -164,8 +157,6 @@ class VeoVideoGenerationNode(ComfyNodeABC):
|
||||
image=None,
|
||||
model="veo-2.0-generate-001",
|
||||
generate_audio=False,
|
||||
unique_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
# Prepare the instances for the request
|
||||
instances = []
|
||||
@@ -202,6 +193,10 @@ class VeoVideoGenerationNode(ComfyNodeABC):
|
||||
if "veo-3.0" in model:
|
||||
parameters["generateAudio"] = generate_audio
|
||||
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
# Initial request to start video generation
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
@@ -214,7 +209,7 @@ class VeoVideoGenerationNode(ComfyNodeABC):
|
||||
instances=instances,
|
||||
parameters=parameters
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
|
||||
initial_response = await initial_operation.execute()
|
||||
@@ -248,10 +243,10 @@ class VeoVideoGenerationNode(ComfyNodeABC):
|
||||
request=VeoGenVidPollRequest(
|
||||
operationName=operation_name
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
auth_kwargs=auth,
|
||||
poll_interval=5.0,
|
||||
result_url_extractor=get_video_url_from_response,
|
||||
node_id=unique_id,
|
||||
node_id=cls.hidden.unique_id,
|
||||
estimated_duration=AVERAGE_DURATION_VIDEO_GEN,
|
||||
)
|
||||
|
||||
@@ -304,10 +299,10 @@ class VeoVideoGenerationNode(ComfyNodeABC):
|
||||
logging.info("Video generation completed successfully")
|
||||
|
||||
# Convert video data to BytesIO object
|
||||
video_io = io.BytesIO(video_data)
|
||||
video_io = BytesIO(video_data)
|
||||
|
||||
# Return VideoFromFile object
|
||||
return (VideoFromFile(video_io),)
|
||||
return comfy_io.NodeOutput(VideoFromFile(video_io))
|
||||
|
||||
|
||||
class Veo3VideoGenerationNode(VeoVideoGenerationNode):
|
||||
@@ -323,51 +318,104 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode):
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
parent_input = super().INPUT_TYPES()
|
||||
|
||||
# Update model options for Veo 3
|
||||
parent_input["optional"]["model"] = (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": ["veo-3.0-generate-001", "veo-3.0-fast-generate-001"],
|
||||
"default": "veo-3.0-generate-001",
|
||||
"tooltip": "Veo 3 model to use for video generation",
|
||||
},
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
node_id="Veo3VideoGenerationNode",
|
||||
display_name="Google Veo 3 Video Generation",
|
||||
category="api node/video/Veo",
|
||||
description="Generates videos from text prompts using Google's Veo 3 API",
|
||||
inputs=[
|
||||
comfy_io.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Text description of the video",
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=["16:9", "9:16"],
|
||||
default="16:9",
|
||||
tooltip="Aspect ratio of the output video",
|
||||
),
|
||||
comfy_io.String.Input(
|
||||
"negative_prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Negative text prompt to guide what to avoid in the video",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
"duration_seconds",
|
||||
default=8,
|
||||
min=8,
|
||||
max=8,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
tooltip="Duration of the output video in seconds (Veo 3 only supports 8 seconds)",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Boolean.Input(
|
||||
"enhance_prompt",
|
||||
default=True,
|
||||
tooltip="Whether to enhance the prompt with AI assistance",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
"person_generation",
|
||||
options=["ALLOW", "BLOCK"],
|
||||
default="ALLOW",
|
||||
tooltip="Whether to allow generating people in the video",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=0xFFFFFFFF,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed for video generation (0 for random)",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Image.Input(
|
||||
"image",
|
||||
tooltip="Optional reference image to guide video generation",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
"model",
|
||||
options=["veo-3.0-generate-001", "veo-3.0-fast-generate-001"],
|
||||
default="veo-3.0-generate-001",
|
||||
tooltip="Veo 3 model to use for video generation",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Boolean.Input(
|
||||
"generate_audio",
|
||||
default=False,
|
||||
tooltip="Generate audio for the video. Supported by all Veo 3 models.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
# Add generateAudio parameter
|
||||
parent_input["optional"]["generate_audio"] = (
|
||||
IO.BOOLEAN,
|
||||
{
|
||||
"default": False,
|
||||
"tooltip": "Generate audio for the video. Supported by all Veo 3 models.",
|
||||
}
|
||||
)
|
||||
|
||||
# Update duration constraints for Veo 3 (only 8 seconds supported)
|
||||
parent_input["optional"]["duration_seconds"] = (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 8,
|
||||
"min": 8,
|
||||
"max": 8,
|
||||
"step": 1,
|
||||
"display": "number",
|
||||
"tooltip": "Duration of the output video in seconds (Veo 3 only supports 8 seconds)",
|
||||
},
|
||||
)
|
||||
class VeoExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
|
||||
return [
|
||||
VeoVideoGenerationNode,
|
||||
Veo3VideoGenerationNode,
|
||||
]
|
||||
|
||||
return parent_input
|
||||
|
||||
|
||||
# Register the nodes
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"VeoVideoGenerationNode": VeoVideoGenerationNode,
|
||||
"Veo3VideoGenerationNode": Veo3VideoGenerationNode,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"VeoVideoGenerationNode": "Google Veo 2 Video Generation",
|
||||
"Veo3VideoGenerationNode": "Google Veo 3 Video Generation",
|
||||
}
|
||||
async def comfy_entrypoint() -> VeoExtension:
|
||||
return VeoExtension()
|
||||
|
||||
@@ -2,7 +2,7 @@ import logging
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from comfy_api.input.video_types import VideoInput
|
||||
from comfy_api.latest import Input
|
||||
|
||||
|
||||
def get_image_dimensions(image: torch.Tensor) -> tuple[int, int]:
|
||||
@@ -101,7 +101,7 @@ def validate_aspect_ratio_closeness(
|
||||
|
||||
|
||||
def validate_video_dimensions(
|
||||
video: VideoInput,
|
||||
video: Input.Video,
|
||||
min_width: Optional[int] = None,
|
||||
max_width: Optional[int] = None,
|
||||
min_height: Optional[int] = None,
|
||||
@@ -126,7 +126,7 @@ def validate_video_dimensions(
|
||||
|
||||
|
||||
def validate_video_duration(
|
||||
video: VideoInput,
|
||||
video: Input.Video,
|
||||
min_duration: Optional[float] = None,
|
||||
max_duration: Optional[float] = None,
|
||||
):
|
||||
@@ -151,3 +151,17 @@ def get_number_of_images(images):
|
||||
if isinstance(images, torch.Tensor):
|
||||
return images.shape[0] if images.ndim >= 4 else 1
|
||||
return len(images)
|
||||
|
||||
|
||||
def validate_audio_duration(
|
||||
audio: Input.Audio,
|
||||
min_duration: Optional[float] = None,
|
||||
max_duration: Optional[float] = None,
|
||||
) -> None:
|
||||
sr = int(audio["sample_rate"])
|
||||
dur = int(audio["waveform"].shape[-1]) / sr
|
||||
eps = 1.0 / sr
|
||||
if min_duration is not None and dur + eps < min_duration:
|
||||
raise ValueError(f"Audio duration must be at least {min_duration}s, got {dur + eps:.2f}s")
|
||||
if max_duration is not None and dur - eps > max_duration:
|
||||
raise ValueError(f"Audio duration must be at most {max_duration}s, got {dur - eps:.2f}s")
|
||||
|
||||
@@ -181,8 +181,9 @@ class WebUIProgressHandler(ProgressHandler):
|
||||
}
|
||||
|
||||
# Send a combined progress_state message with all node states
|
||||
# Include client_id to ensure message is only sent to the initiating client
|
||||
self.server_instance.send_sync(
|
||||
"progress_state", {"prompt_id": prompt_id, "nodes": active_nodes}
|
||||
"progress_state", {"prompt_id": prompt_id, "nodes": active_nodes}, self.server_instance.client_id
|
||||
)
|
||||
|
||||
@override
|
||||
|
||||
+47
-33
@@ -1,49 +1,63 @@
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
import comfy.model_management
|
||||
import node_helpers
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
class TextEncodeAceStepAudio:
|
||||
|
||||
class TextEncodeAceStepAudio(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"clip": ("CLIP", ),
|
||||
"tags": ("STRING", {"multiline": True, "dynamicPrompts": True}),
|
||||
"lyrics": ("STRING", {"multiline": True, "dynamicPrompts": True}),
|
||||
"lyrics_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
}}
|
||||
RETURN_TYPES = ("CONDITIONING",)
|
||||
FUNCTION = "encode"
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="TextEncodeAceStepAudio",
|
||||
category="conditioning",
|
||||
inputs=[
|
||||
io.Clip.Input("clip"),
|
||||
io.String.Input("tags", multiline=True, dynamic_prompts=True),
|
||||
io.String.Input("lyrics", multiline=True, dynamic_prompts=True),
|
||||
io.Float.Input("lyrics_strength", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
],
|
||||
outputs=[io.Conditioning.Output()],
|
||||
)
|
||||
|
||||
CATEGORY = "conditioning"
|
||||
|
||||
def encode(self, clip, tags, lyrics, lyrics_strength):
|
||||
@classmethod
|
||||
def execute(cls, clip, tags, lyrics, lyrics_strength) -> io.NodeOutput:
|
||||
tokens = clip.tokenize(tags, lyrics=lyrics)
|
||||
conditioning = clip.encode_from_tokens_scheduled(tokens)
|
||||
conditioning = node_helpers.conditioning_set_values(conditioning, {"lyrics_strength": lyrics_strength})
|
||||
return (conditioning, )
|
||||
return io.NodeOutput(conditioning)
|
||||
|
||||
|
||||
class EmptyAceStepLatentAudio:
|
||||
def __init__(self):
|
||||
self.device = comfy.model_management.intermediate_device()
|
||||
class EmptyAceStepLatentAudio(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="EmptyAceStepLatentAudio",
|
||||
category="latent/audio",
|
||||
inputs=[
|
||||
io.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.1),
|
||||
io.Int.Input(
|
||||
"batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."
|
||||
),
|
||||
],
|
||||
outputs=[io.Latent.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"seconds": ("FLOAT", {"default": 120.0, "min": 1.0, "max": 1000.0, "step": 0.1}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}),
|
||||
}}
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "generate"
|
||||
|
||||
CATEGORY = "latent/audio"
|
||||
|
||||
def generate(self, seconds, batch_size):
|
||||
def execute(cls, seconds, batch_size) -> io.NodeOutput:
|
||||
length = int(seconds * 44100 / 512 / 8)
|
||||
latent = torch.zeros([batch_size, 8, 16, length], device=self.device)
|
||||
return ({"samples": latent, "type": "audio"}, )
|
||||
latent = torch.zeros([batch_size, 8, 16, length], device=comfy.model_management.intermediate_device())
|
||||
return io.NodeOutput({"samples": latent, "type": "audio"})
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"TextEncodeAceStepAudio": TextEncodeAceStepAudio,
|
||||
"EmptyAceStepLatentAudio": EmptyAceStepLatentAudio,
|
||||
}
|
||||
class AceExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
TextEncodeAceStepAudio,
|
||||
EmptyAceStepLatentAudio,
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> AceExtension:
|
||||
return AceExtension()
|
||||
|
||||
@@ -1,8 +1,13 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from tqdm.auto import trange
|
||||
from typing_extensions import override
|
||||
|
||||
import comfy.model_patcher
|
||||
import comfy.samplers
|
||||
import comfy.utils
|
||||
import torch
|
||||
import numpy as np
|
||||
from tqdm.auto import trange
|
||||
from comfy.k_diffusion.sampling import to_d
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -33,30 +38,29 @@ def sample_lcm_upscale(model, x, sigmas, extra_args=None, callback=None, disable
|
||||
return x
|
||||
|
||||
|
||||
class SamplerLCMUpscale:
|
||||
upscale_methods = ["bislerp", "nearest-exact", "bilinear", "area", "bicubic"]
|
||||
class SamplerLCMUpscale(io.ComfyNode):
|
||||
UPSCALE_METHODS = ["bislerp", "nearest-exact", "bilinear", "area", "bicubic"]
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required":
|
||||
{"scale_ratio": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 20.0, "step": 0.01}),
|
||||
"scale_steps": ("INT", {"default": -1, "min": -1, "max": 1000, "step": 1}),
|
||||
"upscale_method": (s.upscale_methods,),
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("SAMPLER",)
|
||||
CATEGORY = "sampling/custom_sampling/samplers"
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="SamplerLCMUpscale",
|
||||
category="sampling/custom_sampling/samplers",
|
||||
inputs=[
|
||||
io.Float.Input("scale_ratio", default=1.0, min=0.1, max=20.0, step=0.01),
|
||||
io.Int.Input("scale_steps", default=-1, min=-1, max=1000, step=1),
|
||||
io.Combo.Input("upscale_method", options=cls.UPSCALE_METHODS),
|
||||
],
|
||||
outputs=[io.Sampler.Output()],
|
||||
)
|
||||
|
||||
FUNCTION = "get_sampler"
|
||||
|
||||
def get_sampler(self, scale_ratio, scale_steps, upscale_method):
|
||||
@classmethod
|
||||
def execute(cls, scale_ratio, scale_steps, upscale_method) -> io.NodeOutput:
|
||||
if scale_steps < 0:
|
||||
scale_steps = None
|
||||
sampler = comfy.samplers.KSAMPLER(sample_lcm_upscale, extra_options={"total_upscale": scale_ratio, "upscale_steps": scale_steps, "upscale_method": upscale_method})
|
||||
return (sampler, )
|
||||
return io.NodeOutput(sampler)
|
||||
|
||||
from comfy.k_diffusion.sampling import to_d
|
||||
import comfy.model_patcher
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_euler_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
|
||||
@@ -82,30 +86,36 @@ def sample_euler_pp(model, x, sigmas, extra_args=None, callback=None, disable=No
|
||||
return x
|
||||
|
||||
|
||||
class SamplerEulerCFGpp:
|
||||
class SamplerEulerCFGpp(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required":
|
||||
{"version": (["regular", "alternative"],),}
|
||||
}
|
||||
RETURN_TYPES = ("SAMPLER",)
|
||||
# CATEGORY = "sampling/custom_sampling/samplers"
|
||||
CATEGORY = "_for_testing"
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="SamplerEulerCFGpp",
|
||||
display_name="SamplerEulerCFG++",
|
||||
category="_for_testing", # "sampling/custom_sampling/samplers"
|
||||
inputs=[
|
||||
io.Combo.Input("version", options=["regular", "alternative"]),
|
||||
],
|
||||
outputs=[io.Sampler.Output()],
|
||||
is_experimental=True,
|
||||
)
|
||||
|
||||
FUNCTION = "get_sampler"
|
||||
|
||||
def get_sampler(self, version):
|
||||
@classmethod
|
||||
def execute(cls, version) -> io.NodeOutput:
|
||||
if version == "alternative":
|
||||
sampler = comfy.samplers.KSAMPLER(sample_euler_pp)
|
||||
else:
|
||||
sampler = comfy.samplers.ksampler("euler_cfg_pp")
|
||||
return (sampler, )
|
||||
return io.NodeOutput(sampler)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"SamplerLCMUpscale": SamplerLCMUpscale,
|
||||
"SamplerEulerCFGpp": SamplerEulerCFGpp,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"SamplerEulerCFGpp": "SamplerEulerCFG++",
|
||||
}
|
||||
class AdvancedSamplersExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
SamplerLCMUpscale,
|
||||
SamplerEulerCFGpp,
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> AdvancedSamplersExtension:
|
||||
return AdvancedSamplersExtension()
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
#from: https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/howto.html
|
||||
import numpy as np
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
def loglinear_interp(t_steps, num_steps):
|
||||
"""
|
||||
@@ -19,25 +23,30 @@ NOISE_LEVELS = {"SD1": [14.6146412293, 6.4745760956, 3.8636745985, 2.694615152
|
||||
"SDXL":[14.6146412293, 6.3184485287, 3.7681790315, 2.1811480769, 1.3405244945, 0.8620721141, 0.5550693289, 0.3798540708, 0.2332364134, 0.1114188177, 0.0291671582],
|
||||
"SVD": [700.00, 54.5, 15.886, 7.977, 4.248, 1.789, 0.981, 0.403, 0.173, 0.034, 0.002]}
|
||||
|
||||
class AlignYourStepsScheduler:
|
||||
class AlignYourStepsScheduler(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required":
|
||||
{"model_type": (["SD1", "SDXL", "SVD"], ),
|
||||
"steps": ("INT", {"default": 10, "min": 1, "max": 10000}),
|
||||
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("SIGMAS",)
|
||||
CATEGORY = "sampling/custom_sampling/schedulers"
|
||||
|
||||
FUNCTION = "get_sigmas"
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="AlignYourStepsScheduler",
|
||||
category="sampling/custom_sampling/schedulers",
|
||||
inputs=[
|
||||
io.Combo.Input("model_type", options=["SD1", "SDXL", "SVD"]),
|
||||
io.Int.Input("steps", default=10, min=1, max=10000),
|
||||
io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01),
|
||||
],
|
||||
outputs=[io.Sigmas.Output()],
|
||||
)
|
||||
|
||||
def get_sigmas(self, model_type, steps, denoise):
|
||||
# Deprecated: use the V3 schema's `execute` method instead of this.
|
||||
return AlignYourStepsScheduler().execute(model_type, steps, denoise).result
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model_type, steps, denoise) -> io.NodeOutput:
|
||||
total_steps = steps
|
||||
if denoise < 1.0:
|
||||
if denoise <= 0.0:
|
||||
return (torch.FloatTensor([]),)
|
||||
return io.NodeOutput(torch.FloatTensor([]))
|
||||
total_steps = round(steps * denoise)
|
||||
|
||||
sigmas = NOISE_LEVELS[model_type][:]
|
||||
@@ -46,8 +55,15 @@ class AlignYourStepsScheduler:
|
||||
|
||||
sigmas = sigmas[-(total_steps + 1):]
|
||||
sigmas[-1] = 0
|
||||
return (torch.FloatTensor(sigmas), )
|
||||
return io.NodeOutput(torch.FloatTensor(sigmas))
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"AlignYourStepsScheduler": AlignYourStepsScheduler,
|
||||
}
|
||||
|
||||
class AlignYourStepsExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
AlignYourStepsScheduler,
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> AlignYourStepsExtension:
|
||||
return AlignYourStepsExtension()
|
||||
|
||||
+51
-21
@@ -1,4 +1,8 @@
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
def project(v0, v1):
|
||||
v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3])
|
||||
@@ -6,22 +10,45 @@ def project(v0, v1):
|
||||
v0_orthogonal = v0 - v0_parallel
|
||||
return v0_parallel, v0_orthogonal
|
||||
|
||||
class APG:
|
||||
class APG(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"model": ("MODEL",),
|
||||
"eta": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01, "tooltip": "Controls the scale of the parallel guidance vector. Default CFG behavior at a setting of 1."}),
|
||||
"norm_threshold": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 50.0, "step": 0.1, "tooltip": "Normalize guidance vector to this value, normalization disable at a setting of 0."}),
|
||||
"momentum": ("FLOAT", {"default": 0.0, "min": -5.0, "max": 1.0, "step": 0.01, "tooltip":"Controls a running average of guidance during diffusion, disabled at a setting of 0."}),
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
CATEGORY = "sampling/custom_sampling"
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="APG",
|
||||
display_name="Adaptive Projected Guidance",
|
||||
category="sampling/custom_sampling",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Float.Input(
|
||||
"eta",
|
||||
default=1.0,
|
||||
min=-10.0,
|
||||
max=10.0,
|
||||
step=0.01,
|
||||
tooltip="Controls the scale of the parallel guidance vector. Default CFG behavior at a setting of 1.",
|
||||
),
|
||||
io.Float.Input(
|
||||
"norm_threshold",
|
||||
default=5.0,
|
||||
min=0.0,
|
||||
max=50.0,
|
||||
step=0.1,
|
||||
tooltip="Normalize guidance vector to this value, normalization disable at a setting of 0.",
|
||||
),
|
||||
io.Float.Input(
|
||||
"momentum",
|
||||
default=0.0,
|
||||
min=-5.0,
|
||||
max=1.0,
|
||||
step=0.01,
|
||||
tooltip="Controls a running average of guidance during diffusion, disabled at a setting of 0.",
|
||||
),
|
||||
],
|
||||
outputs=[io.Model.Output()],
|
||||
)
|
||||
|
||||
def patch(self, model, eta, norm_threshold, momentum):
|
||||
@classmethod
|
||||
def execute(cls, model, eta, norm_threshold, momentum) -> io.NodeOutput:
|
||||
running_avg = 0
|
||||
prev_sigma = None
|
||||
|
||||
@@ -65,12 +92,15 @@ class APG:
|
||||
|
||||
m = model.clone()
|
||||
m.set_model_sampler_pre_cfg_function(pre_cfg_function)
|
||||
return (m,)
|
||||
return io.NodeOutput(m)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"APG": APG,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"APG": "Adaptive Projected Guidance",
|
||||
}
|
||||
class ApgExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
APG,
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> ApgExtension:
|
||||
return ApgExtension()
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
def attention_multiply(attn, model, q, k, v, out):
|
||||
m = model.clone()
|
||||
@@ -16,57 +20,71 @@ def attention_multiply(attn, model, q, k, v, out):
|
||||
return m
|
||||
|
||||
|
||||
class UNetSelfAttentionMultiply:
|
||||
class UNetSelfAttentionMultiply(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"q": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"k": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"v": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="UNetSelfAttentionMultiply",
|
||||
category="_for_testing/attention_experiments",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
io.Float.Input("k", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
io.Float.Input("v", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
io.Float.Input("out", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
],
|
||||
outputs=[io.Model.Output()],
|
||||
is_experimental=True,
|
||||
)
|
||||
|
||||
CATEGORY = "_for_testing/attention_experiments"
|
||||
|
||||
def patch(self, model, q, k, v, out):
|
||||
@classmethod
|
||||
def execute(cls, model, q, k, v, out) -> io.NodeOutput:
|
||||
m = attention_multiply("attn1", model, q, k, v, out)
|
||||
return (m, )
|
||||
return io.NodeOutput(m)
|
||||
|
||||
class UNetCrossAttentionMultiply:
|
||||
|
||||
class UNetCrossAttentionMultiply(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"q": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"k": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"v": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="UNetCrossAttentionMultiply",
|
||||
category="_for_testing/attention_experiments",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
io.Float.Input("k", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
io.Float.Input("v", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
io.Float.Input("out", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
],
|
||||
outputs=[io.Model.Output()],
|
||||
is_experimental=True,
|
||||
)
|
||||
|
||||
CATEGORY = "_for_testing/attention_experiments"
|
||||
|
||||
def patch(self, model, q, k, v, out):
|
||||
@classmethod
|
||||
def execute(cls, model, q, k, v, out) -> io.NodeOutput:
|
||||
m = attention_multiply("attn2", model, q, k, v, out)
|
||||
return (m, )
|
||||
return io.NodeOutput(m)
|
||||
|
||||
class CLIPAttentionMultiply:
|
||||
|
||||
class CLIPAttentionMultiply(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "clip": ("CLIP",),
|
||||
"q": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"k": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"v": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
}}
|
||||
RETURN_TYPES = ("CLIP",)
|
||||
FUNCTION = "patch"
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="CLIPAttentionMultiply",
|
||||
category="_for_testing/attention_experiments",
|
||||
inputs=[
|
||||
io.Clip.Input("clip"),
|
||||
io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
io.Float.Input("k", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
io.Float.Input("v", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
io.Float.Input("out", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
],
|
||||
outputs=[io.Clip.Output()],
|
||||
is_experimental=True,
|
||||
)
|
||||
|
||||
CATEGORY = "_for_testing/attention_experiments"
|
||||
|
||||
def patch(self, clip, q, k, v, out):
|
||||
@classmethod
|
||||
def execute(cls, clip, q, k, v, out) -> io.NodeOutput:
|
||||
m = clip.clone()
|
||||
sd = m.patcher.model_state_dict()
|
||||
|
||||
@@ -79,23 +97,28 @@ class CLIPAttentionMultiply:
|
||||
m.add_patches({key: (None,)}, 0.0, v)
|
||||
if key.endswith("self_attn.out_proj.weight") or key.endswith("self_attn.out_proj.bias"):
|
||||
m.add_patches({key: (None,)}, 0.0, out)
|
||||
return (m, )
|
||||
return io.NodeOutput(m)
|
||||
|
||||
class UNetTemporalAttentionMultiply:
|
||||
|
||||
class UNetTemporalAttentionMultiply(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"self_structural": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"self_temporal": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"cross_structural": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"cross_temporal": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="UNetTemporalAttentionMultiply",
|
||||
category="_for_testing/attention_experiments",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Float.Input("self_structural", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
io.Float.Input("self_temporal", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
io.Float.Input("cross_structural", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
io.Float.Input("cross_temporal", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
],
|
||||
outputs=[io.Model.Output()],
|
||||
is_experimental=True,
|
||||
)
|
||||
|
||||
CATEGORY = "_for_testing/attention_experiments"
|
||||
|
||||
def patch(self, model, self_structural, self_temporal, cross_structural, cross_temporal):
|
||||
@classmethod
|
||||
def execute(cls, model, self_structural, self_temporal, cross_structural, cross_temporal) -> io.NodeOutput:
|
||||
m = model.clone()
|
||||
sd = model.model_state_dict()
|
||||
|
||||
@@ -110,11 +133,18 @@ class UNetTemporalAttentionMultiply:
|
||||
m.add_patches({k: (None,)}, 0.0, cross_temporal)
|
||||
else:
|
||||
m.add_patches({k: (None,)}, 0.0, cross_structural)
|
||||
return (m, )
|
||||
return io.NodeOutput(m)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"UNetSelfAttentionMultiply": UNetSelfAttentionMultiply,
|
||||
"UNetCrossAttentionMultiply": UNetCrossAttentionMultiply,
|
||||
"CLIPAttentionMultiply": CLIPAttentionMultiply,
|
||||
"UNetTemporalAttentionMultiply": UNetTemporalAttentionMultiply,
|
||||
}
|
||||
|
||||
class AttentionMultiplyExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
UNetSelfAttentionMultiply,
|
||||
UNetCrossAttentionMultiply,
|
||||
CLIPAttentionMultiply,
|
||||
UNetTemporalAttentionMultiply,
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> AttentionMultiplyExtension:
|
||||
return AttentionMultiplyExtension()
|
||||
|
||||
@@ -0,0 +1,44 @@
|
||||
import folder_paths
|
||||
import comfy.audio_encoders.audio_encoders
|
||||
import comfy.utils
|
||||
|
||||
|
||||
class AudioEncoderLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "audio_encoder_name": (folder_paths.get_filename_list("audio_encoders"), ),
|
||||
}}
|
||||
RETURN_TYPES = ("AUDIO_ENCODER",)
|
||||
FUNCTION = "load_model"
|
||||
|
||||
CATEGORY = "loaders"
|
||||
|
||||
def load_model(self, audio_encoder_name):
|
||||
audio_encoder_name = folder_paths.get_full_path_or_raise("audio_encoders", audio_encoder_name)
|
||||
sd = comfy.utils.load_torch_file(audio_encoder_name, safe_load=True)
|
||||
audio_encoder = comfy.audio_encoders.audio_encoders.load_audio_encoder_from_sd(sd)
|
||||
if audio_encoder is None:
|
||||
raise RuntimeError("ERROR: audio encoder file is invalid and does not contain a valid model.")
|
||||
return (audio_encoder,)
|
||||
|
||||
|
||||
class AudioEncoderEncode:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "audio_encoder": ("AUDIO_ENCODER",),
|
||||
"audio": ("AUDIO",),
|
||||
}}
|
||||
RETURN_TYPES = ("AUDIO_ENCODER_OUTPUT",)
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "conditioning"
|
||||
|
||||
def encode(self, audio_encoder, audio):
|
||||
output = audio_encoder.encode_audio(audio["waveform"], audio["sample_rate"])
|
||||
return (output,)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"AudioEncoderLoader": AudioEncoderLoader,
|
||||
"AudioEncoderEncode": AudioEncoderEncode,
|
||||
}
|
||||
@@ -0,0 +1,498 @@
|
||||
from __future__ import annotations
|
||||
from typing import TYPE_CHECKING, Union
|
||||
from comfy_api.latest import io, ComfyExtension
|
||||
import comfy.patcher_extension
|
||||
import logging
|
||||
import torch
|
||||
import comfy.model_patcher
|
||||
if TYPE_CHECKING:
|
||||
from uuid import UUID
|
||||
|
||||
|
||||
def easycache_forward_wrapper(executor, *args, **kwargs):
|
||||
# get values from args
|
||||
x: torch.Tensor = args[0]
|
||||
transformer_options: dict[str] = args[-1]
|
||||
if not isinstance(transformer_options, dict):
|
||||
transformer_options = kwargs.get("transformer_options")
|
||||
if not transformer_options:
|
||||
transformer_options = args[-2]
|
||||
easycache: EasyCacheHolder = transformer_options["easycache"]
|
||||
sigmas = transformer_options["sigmas"]
|
||||
uuids = transformer_options["uuids"]
|
||||
if sigmas is not None and easycache.is_past_end_timestep(sigmas):
|
||||
return executor(*args, **kwargs)
|
||||
# prepare next x_prev
|
||||
has_first_cond_uuid = easycache.has_first_cond_uuid(uuids)
|
||||
next_x_prev = x
|
||||
input_change = None
|
||||
do_easycache = easycache.should_do_easycache(sigmas)
|
||||
if do_easycache:
|
||||
easycache.check_metadata(x)
|
||||
# if first cond marked this step for skipping, skip it and use appropriate cached values
|
||||
if easycache.skip_current_step:
|
||||
if easycache.verbose:
|
||||
logging.info(f"EasyCache [verbose] - was marked to skip this step by {easycache.first_cond_uuid}. Present uuids: {uuids}")
|
||||
return easycache.apply_cache_diff(x, uuids)
|
||||
if easycache.initial_step:
|
||||
easycache.first_cond_uuid = uuids[0]
|
||||
has_first_cond_uuid = easycache.has_first_cond_uuid(uuids)
|
||||
easycache.initial_step = False
|
||||
if has_first_cond_uuid:
|
||||
if easycache.has_x_prev_subsampled():
|
||||
input_change = (easycache.subsample(x, uuids, clone=False) - easycache.x_prev_subsampled).flatten().abs().mean()
|
||||
if easycache.has_output_prev_norm() and easycache.has_relative_transformation_rate():
|
||||
approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm
|
||||
easycache.cumulative_change_rate += approx_output_change_rate
|
||||
if easycache.cumulative_change_rate < easycache.reuse_threshold:
|
||||
if easycache.verbose:
|
||||
logging.info(f"EasyCache [verbose] - skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
|
||||
# other conds should also skip this step, and instead use their cached values
|
||||
easycache.skip_current_step = True
|
||||
return easycache.apply_cache_diff(x, uuids)
|
||||
else:
|
||||
if easycache.verbose:
|
||||
logging.info(f"EasyCache [verbose] - NOT skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
|
||||
easycache.cumulative_change_rate = 0.0
|
||||
|
||||
output: torch.Tensor = executor(*args, **kwargs)
|
||||
if has_first_cond_uuid and easycache.has_output_prev_norm():
|
||||
output_change = (easycache.subsample(output, uuids, clone=False) - easycache.output_prev_subsampled).flatten().abs().mean()
|
||||
if easycache.verbose:
|
||||
output_change_rate = output_change / easycache.output_prev_norm
|
||||
easycache.output_change_rates.append(output_change_rate.item())
|
||||
if easycache.has_relative_transformation_rate():
|
||||
approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm
|
||||
easycache.approx_output_change_rates.append(approx_output_change_rate.item())
|
||||
if easycache.verbose:
|
||||
logging.info(f"EasyCache [verbose] - approx_output_change_rate: {approx_output_change_rate}")
|
||||
if input_change is not None:
|
||||
easycache.relative_transformation_rate = output_change / input_change
|
||||
if easycache.verbose:
|
||||
logging.info(f"EasyCache [verbose] - output_change_rate: {output_change_rate}")
|
||||
# TODO: allow cache_diff to be offloaded
|
||||
easycache.update_cache_diff(output, next_x_prev, uuids)
|
||||
if has_first_cond_uuid:
|
||||
easycache.x_prev_subsampled = easycache.subsample(next_x_prev, uuids)
|
||||
easycache.output_prev_subsampled = easycache.subsample(output, uuids)
|
||||
easycache.output_prev_norm = output.flatten().abs().mean()
|
||||
if easycache.verbose:
|
||||
logging.info(f"EasyCache [verbose] - x_prev_subsampled: {easycache.x_prev_subsampled.shape}")
|
||||
return output
|
||||
|
||||
def lazycache_predict_noise_wrapper(executor, *args, **kwargs):
|
||||
# get values from args
|
||||
x: torch.Tensor = args[0]
|
||||
timestep: float = args[1]
|
||||
model_options: dict[str] = args[2]
|
||||
easycache: LazyCacheHolder = model_options["transformer_options"]["easycache"]
|
||||
if easycache.is_past_end_timestep(timestep):
|
||||
return executor(*args, **kwargs)
|
||||
# prepare next x_prev
|
||||
next_x_prev = x
|
||||
input_change = None
|
||||
do_easycache = easycache.should_do_easycache(timestep)
|
||||
if do_easycache:
|
||||
easycache.check_metadata(x)
|
||||
if easycache.has_x_prev_subsampled():
|
||||
if easycache.has_x_prev_subsampled():
|
||||
input_change = (easycache.subsample(x, clone=False) - easycache.x_prev_subsampled).flatten().abs().mean()
|
||||
if easycache.has_output_prev_norm() and easycache.has_relative_transformation_rate():
|
||||
approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm
|
||||
easycache.cumulative_change_rate += approx_output_change_rate
|
||||
if easycache.cumulative_change_rate < easycache.reuse_threshold:
|
||||
if easycache.verbose:
|
||||
logging.info(f"LazyCache [verbose] - skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
|
||||
# other conds should also skip this step, and instead use their cached values
|
||||
easycache.skip_current_step = True
|
||||
return easycache.apply_cache_diff(x)
|
||||
else:
|
||||
if easycache.verbose:
|
||||
logging.info(f"LazyCache [verbose] - NOT skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
|
||||
easycache.cumulative_change_rate = 0.0
|
||||
output: torch.Tensor = executor(*args, **kwargs)
|
||||
if easycache.has_output_prev_norm():
|
||||
output_change = (easycache.subsample(output, clone=False) - easycache.output_prev_subsampled).flatten().abs().mean()
|
||||
if easycache.verbose:
|
||||
output_change_rate = output_change / easycache.output_prev_norm
|
||||
easycache.output_change_rates.append(output_change_rate.item())
|
||||
if easycache.has_relative_transformation_rate():
|
||||
approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm
|
||||
easycache.approx_output_change_rates.append(approx_output_change_rate.item())
|
||||
if easycache.verbose:
|
||||
logging.info(f"LazyCache [verbose] - approx_output_change_rate: {approx_output_change_rate}")
|
||||
if input_change is not None:
|
||||
easycache.relative_transformation_rate = output_change / input_change
|
||||
if easycache.verbose:
|
||||
logging.info(f"LazyCache [verbose] - output_change_rate: {output_change_rate}")
|
||||
# TODO: allow cache_diff to be offloaded
|
||||
easycache.update_cache_diff(output, next_x_prev)
|
||||
easycache.x_prev_subsampled = easycache.subsample(next_x_prev)
|
||||
easycache.output_prev_subsampled = easycache.subsample(output)
|
||||
easycache.output_prev_norm = output.flatten().abs().mean()
|
||||
if easycache.verbose:
|
||||
logging.info(f"LazyCache [verbose] - x_prev_subsampled: {easycache.x_prev_subsampled.shape}")
|
||||
return output
|
||||
|
||||
def easycache_calc_cond_batch_wrapper(executor, *args, **kwargs):
|
||||
model_options = args[-1]
|
||||
easycache: EasyCacheHolder = model_options["transformer_options"]["easycache"]
|
||||
easycache.skip_current_step = False
|
||||
# TODO: check if first_cond_uuid is active at this timestep; otherwise, EasyCache needs to be partially reset
|
||||
return executor(*args, **kwargs)
|
||||
|
||||
def easycache_sample_wrapper(executor, *args, **kwargs):
|
||||
"""
|
||||
This OUTER_SAMPLE wrapper makes sure easycache is prepped for current run, and all memory usage is cleared at the end.
|
||||
"""
|
||||
try:
|
||||
guider = executor.class_obj
|
||||
orig_model_options = guider.model_options
|
||||
guider.model_options = comfy.model_patcher.create_model_options_clone(orig_model_options)
|
||||
# clone and prepare timesteps
|
||||
guider.model_options["transformer_options"]["easycache"] = guider.model_options["transformer_options"]["easycache"].clone().prepare_timesteps(guider.model_patcher.model.model_sampling)
|
||||
easycache: Union[EasyCacheHolder, LazyCacheHolder] = guider.model_options['transformer_options']['easycache']
|
||||
logging.info(f"{easycache.name} enabled - threshold: {easycache.reuse_threshold}, start_percent: {easycache.start_percent}, end_percent: {easycache.end_percent}")
|
||||
return executor(*args, **kwargs)
|
||||
finally:
|
||||
easycache = guider.model_options['transformer_options']['easycache']
|
||||
output_change_rates = easycache.output_change_rates
|
||||
approx_output_change_rates = easycache.approx_output_change_rates
|
||||
if easycache.verbose:
|
||||
logging.info(f"{easycache.name} [verbose] - output_change_rates {len(output_change_rates)}: {output_change_rates}")
|
||||
logging.info(f"{easycache.name} [verbose] - approx_output_change_rates {len(approx_output_change_rates)}: {approx_output_change_rates}")
|
||||
total_steps = len(args[3])-1
|
||||
# catch division by zero for log statement; sucks to crash after all sampling is done
|
||||
try:
|
||||
speedup = total_steps/(total_steps-easycache.total_steps_skipped)
|
||||
except ZeroDivisionError:
|
||||
speedup = 1.0
|
||||
logging.info(f"{easycache.name} - skipped {easycache.total_steps_skipped}/{total_steps} steps ({speedup:.2f}x speedup).")
|
||||
easycache.reset()
|
||||
guider.model_options = orig_model_options
|
||||
|
||||
|
||||
class EasyCacheHolder:
|
||||
def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool=False):
|
||||
self.name = "EasyCache"
|
||||
self.reuse_threshold = reuse_threshold
|
||||
self.start_percent = start_percent
|
||||
self.end_percent = end_percent
|
||||
self.subsample_factor = subsample_factor
|
||||
self.offload_cache_diff = offload_cache_diff
|
||||
self.verbose = verbose
|
||||
# timestep values
|
||||
self.start_t = 0.0
|
||||
self.end_t = 0.0
|
||||
# control values
|
||||
self.relative_transformation_rate: float = None
|
||||
self.cumulative_change_rate = 0.0
|
||||
self.initial_step = True
|
||||
self.skip_current_step = False
|
||||
# cache values
|
||||
self.first_cond_uuid = None
|
||||
self.x_prev_subsampled: torch.Tensor = None
|
||||
self.output_prev_subsampled: torch.Tensor = None
|
||||
self.output_prev_norm: torch.Tensor = None
|
||||
self.uuid_cache_diffs: dict[UUID, torch.Tensor] = {}
|
||||
self.output_change_rates = []
|
||||
self.approx_output_change_rates = []
|
||||
self.total_steps_skipped = 0
|
||||
# how to deal with mismatched dims
|
||||
self.allow_mismatch = True
|
||||
self.cut_from_start = True
|
||||
self.state_metadata = None
|
||||
|
||||
def is_past_end_timestep(self, timestep: float) -> bool:
|
||||
return not (timestep[0] > self.end_t).item()
|
||||
|
||||
def should_do_easycache(self, timestep: float) -> bool:
|
||||
return (timestep[0] <= self.start_t).item()
|
||||
|
||||
def has_x_prev_subsampled(self) -> bool:
|
||||
return self.x_prev_subsampled is not None
|
||||
|
||||
def has_output_prev_subsampled(self) -> bool:
|
||||
return self.output_prev_subsampled is not None
|
||||
|
||||
def has_output_prev_norm(self) -> bool:
|
||||
return self.output_prev_norm is not None
|
||||
|
||||
def has_relative_transformation_rate(self) -> bool:
|
||||
return self.relative_transformation_rate is not None
|
||||
|
||||
def prepare_timesteps(self, model_sampling):
|
||||
self.start_t = model_sampling.percent_to_sigma(self.start_percent)
|
||||
self.end_t = model_sampling.percent_to_sigma(self.end_percent)
|
||||
return self
|
||||
|
||||
def subsample(self, x: torch.Tensor, uuids: list[UUID], clone: bool = True) -> torch.Tensor:
|
||||
batch_offset = x.shape[0] // len(uuids)
|
||||
uuid_idx = uuids.index(self.first_cond_uuid)
|
||||
if self.subsample_factor > 1:
|
||||
to_return = x[uuid_idx*batch_offset:(uuid_idx+1)*batch_offset, ..., ::self.subsample_factor, ::self.subsample_factor]
|
||||
if clone:
|
||||
return to_return.clone()
|
||||
return to_return
|
||||
to_return = x[uuid_idx*batch_offset:(uuid_idx+1)*batch_offset, ...]
|
||||
if clone:
|
||||
return to_return.clone()
|
||||
return to_return
|
||||
|
||||
def apply_cache_diff(self, x: torch.Tensor, uuids: list[UUID]):
|
||||
if self.first_cond_uuid in uuids:
|
||||
self.total_steps_skipped += 1
|
||||
batch_offset = x.shape[0] // len(uuids)
|
||||
for i, uuid in enumerate(uuids):
|
||||
# if cached dims don't match x dims, cut off excess and hope for the best (cosmos world2video)
|
||||
if x.shape[1:] != self.uuid_cache_diffs[uuid].shape[1:]:
|
||||
if not self.allow_mismatch:
|
||||
raise ValueError(f"Cached dims {self.uuid_cache_diffs[uuid].shape} don't match x dims {x.shape} - this is no good")
|
||||
slicing = []
|
||||
skip_this_dim = True
|
||||
for dim_u, dim_x in zip(self.uuid_cache_diffs[uuid].shape, x.shape):
|
||||
if skip_this_dim:
|
||||
skip_this_dim = False
|
||||
continue
|
||||
if dim_u != dim_x:
|
||||
if self.cut_from_start:
|
||||
slicing.append(slice(dim_x-dim_u, None))
|
||||
else:
|
||||
slicing.append(slice(None, dim_u))
|
||||
else:
|
||||
slicing.append(slice(None))
|
||||
slicing = [slice(i*batch_offset,(i+1)*batch_offset)] + slicing
|
||||
x = x[slicing]
|
||||
x += self.uuid_cache_diffs[uuid].to(x.device)
|
||||
return x
|
||||
|
||||
def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor, uuids: list[UUID]):
|
||||
# if output dims don't match x dims, cut off excess and hope for the best (cosmos world2video)
|
||||
if output.shape[1:] != x.shape[1:]:
|
||||
if not self.allow_mismatch:
|
||||
raise ValueError(f"Output dims {output.shape} don't match x dims {x.shape} - this is no good")
|
||||
slicing = []
|
||||
skip_dim = True
|
||||
for dim_o, dim_x in zip(output.shape, x.shape):
|
||||
if not skip_dim and dim_o != dim_x:
|
||||
if self.cut_from_start:
|
||||
slicing.append(slice(dim_x-dim_o, None))
|
||||
else:
|
||||
slicing.append(slice(None, dim_o))
|
||||
else:
|
||||
slicing.append(slice(None))
|
||||
skip_dim = False
|
||||
x = x[slicing]
|
||||
diff = output - x
|
||||
batch_offset = diff.shape[0] // len(uuids)
|
||||
for i, uuid in enumerate(uuids):
|
||||
self.uuid_cache_diffs[uuid] = diff[i*batch_offset:(i+1)*batch_offset, ...]
|
||||
|
||||
def has_first_cond_uuid(self, uuids: list[UUID]) -> bool:
|
||||
return self.first_cond_uuid in uuids
|
||||
|
||||
def check_metadata(self, x: torch.Tensor) -> bool:
|
||||
metadata = (x.device, x.dtype, x.shape[1:])
|
||||
if self.state_metadata is None:
|
||||
self.state_metadata = metadata
|
||||
return True
|
||||
if metadata == self.state_metadata:
|
||||
return True
|
||||
logging.warn(f"{self.name} - Tensor shape, dtype or device changed, resetting state")
|
||||
self.reset()
|
||||
return False
|
||||
|
||||
def reset(self):
|
||||
self.relative_transformation_rate = 0.0
|
||||
self.cumulative_change_rate = 0.0
|
||||
self.initial_step = True
|
||||
self.skip_current_step = False
|
||||
self.output_change_rates = []
|
||||
self.first_cond_uuid = None
|
||||
del self.x_prev_subsampled
|
||||
self.x_prev_subsampled = None
|
||||
del self.output_prev_subsampled
|
||||
self.output_prev_subsampled = None
|
||||
del self.output_prev_norm
|
||||
self.output_prev_norm = None
|
||||
del self.uuid_cache_diffs
|
||||
self.uuid_cache_diffs = {}
|
||||
self.total_steps_skipped = 0
|
||||
self.state_metadata = None
|
||||
return self
|
||||
|
||||
def clone(self):
|
||||
return EasyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose)
|
||||
|
||||
|
||||
class EasyCacheNode(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="EasyCache",
|
||||
display_name="EasyCache",
|
||||
description="Native EasyCache implementation.",
|
||||
category="advanced/debug/model",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Model.Input("model", tooltip="The model to add EasyCache to."),
|
||||
io.Float.Input("reuse_threshold", min=0.0, default=0.2, max=3.0, step=0.01, tooltip="The threshold for reusing cached steps."),
|
||||
io.Float.Input("start_percent", min=0.0, default=0.15, max=1.0, step=0.01, tooltip="The relative sampling step to begin use of EasyCache."),
|
||||
io.Float.Input("end_percent", min=0.0, default=0.95, max=1.0, step=0.01, tooltip="The relative sampling step to end use of EasyCache."),
|
||||
io.Boolean.Input("verbose", default=False, tooltip="Whether to log verbose information."),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(tooltip="The model with EasyCache."),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model: io.Model.Type, reuse_threshold: float, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput:
|
||||
model = model.clone()
|
||||
model.model_options["transformer_options"]["easycache"] = EasyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose)
|
||||
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "easycache", easycache_sample_wrapper)
|
||||
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, "easycache", easycache_calc_cond_batch_wrapper)
|
||||
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, "easycache", easycache_forward_wrapper)
|
||||
return io.NodeOutput(model)
|
||||
|
||||
|
||||
class LazyCacheHolder:
|
||||
def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool=False):
|
||||
self.name = "LazyCache"
|
||||
self.reuse_threshold = reuse_threshold
|
||||
self.start_percent = start_percent
|
||||
self.end_percent = end_percent
|
||||
self.subsample_factor = subsample_factor
|
||||
self.offload_cache_diff = offload_cache_diff
|
||||
self.verbose = verbose
|
||||
# timestep values
|
||||
self.start_t = 0.0
|
||||
self.end_t = 0.0
|
||||
# control values
|
||||
self.relative_transformation_rate: float = None
|
||||
self.cumulative_change_rate = 0.0
|
||||
self.initial_step = True
|
||||
# cache values
|
||||
self.x_prev_subsampled: torch.Tensor = None
|
||||
self.output_prev_subsampled: torch.Tensor = None
|
||||
self.output_prev_norm: torch.Tensor = None
|
||||
self.cache_diff: torch.Tensor = None
|
||||
self.output_change_rates = []
|
||||
self.approx_output_change_rates = []
|
||||
self.total_steps_skipped = 0
|
||||
self.state_metadata = None
|
||||
|
||||
def has_cache_diff(self) -> bool:
|
||||
return self.cache_diff is not None
|
||||
|
||||
def is_past_end_timestep(self, timestep: float) -> bool:
|
||||
return not (timestep[0] > self.end_t).item()
|
||||
|
||||
def should_do_easycache(self, timestep: float) -> bool:
|
||||
return (timestep[0] <= self.start_t).item()
|
||||
|
||||
def has_x_prev_subsampled(self) -> bool:
|
||||
return self.x_prev_subsampled is not None
|
||||
|
||||
def has_output_prev_subsampled(self) -> bool:
|
||||
return self.output_prev_subsampled is not None
|
||||
|
||||
def has_output_prev_norm(self) -> bool:
|
||||
return self.output_prev_norm is not None
|
||||
|
||||
def has_relative_transformation_rate(self) -> bool:
|
||||
return self.relative_transformation_rate is not None
|
||||
|
||||
def prepare_timesteps(self, model_sampling):
|
||||
self.start_t = model_sampling.percent_to_sigma(self.start_percent)
|
||||
self.end_t = model_sampling.percent_to_sigma(self.end_percent)
|
||||
return self
|
||||
|
||||
def subsample(self, x: torch.Tensor, clone: bool = True) -> torch.Tensor:
|
||||
if self.subsample_factor > 1:
|
||||
to_return = x[..., ::self.subsample_factor, ::self.subsample_factor]
|
||||
if clone:
|
||||
return to_return.clone()
|
||||
return to_return
|
||||
if clone:
|
||||
return x.clone()
|
||||
return x
|
||||
|
||||
def apply_cache_diff(self, x: torch.Tensor):
|
||||
self.total_steps_skipped += 1
|
||||
return x + self.cache_diff.to(x.device)
|
||||
|
||||
def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor):
|
||||
self.cache_diff = output - x
|
||||
|
||||
def check_metadata(self, x: torch.Tensor) -> bool:
|
||||
metadata = (x.device, x.dtype, x.shape)
|
||||
if self.state_metadata is None:
|
||||
self.state_metadata = metadata
|
||||
return True
|
||||
if metadata == self.state_metadata:
|
||||
return True
|
||||
logging.warn(f"{self.name} - Tensor shape, dtype or device changed, resetting state")
|
||||
self.reset()
|
||||
return False
|
||||
|
||||
def reset(self):
|
||||
self.relative_transformation_rate = 0.0
|
||||
self.cumulative_change_rate = 0.0
|
||||
self.initial_step = True
|
||||
self.output_change_rates = []
|
||||
self.approx_output_change_rates = []
|
||||
del self.cache_diff
|
||||
self.cache_diff = None
|
||||
del self.x_prev_subsampled
|
||||
self.x_prev_subsampled = None
|
||||
del self.output_prev_subsampled
|
||||
self.output_prev_subsampled = None
|
||||
del self.output_prev_norm
|
||||
self.output_prev_norm = None
|
||||
self.total_steps_skipped = 0
|
||||
self.state_metadata = None
|
||||
return self
|
||||
|
||||
def clone(self):
|
||||
return LazyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose)
|
||||
|
||||
class LazyCacheNode(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="LazyCache",
|
||||
display_name="LazyCache",
|
||||
description="A homebrew version of EasyCache - even 'easier' version of EasyCache to implement. Overall works worse than EasyCache, but better in some rare cases AND universal compatibility with everything in ComfyUI.",
|
||||
category="advanced/debug/model",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Model.Input("model", tooltip="The model to add LazyCache to."),
|
||||
io.Float.Input("reuse_threshold", min=0.0, default=0.2, max=3.0, step=0.01, tooltip="The threshold for reusing cached steps."),
|
||||
io.Float.Input("start_percent", min=0.0, default=0.15, max=1.0, step=0.01, tooltip="The relative sampling step to begin use of LazyCache."),
|
||||
io.Float.Input("end_percent", min=0.0, default=0.95, max=1.0, step=0.01, tooltip="The relative sampling step to end use of LazyCache."),
|
||||
io.Boolean.Input("verbose", default=False, tooltip="Whether to log verbose information."),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(tooltip="The model with LazyCache."),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model: io.Model.Type, reuse_threshold: float, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput:
|
||||
model = model.clone()
|
||||
model.model_options["transformer_options"]["easycache"] = LazyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose)
|
||||
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "lazycache", easycache_sample_wrapper)
|
||||
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.PREDICT_NOISE, "lazycache", lazycache_predict_noise_wrapper)
|
||||
return io.NodeOutput(model)
|
||||
|
||||
|
||||
class EasyCacheExtension(ComfyExtension):
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
EasyCacheNode,
|
||||
LazyCacheNode,
|
||||
]
|
||||
|
||||
def comfy_entrypoint():
|
||||
return EasyCacheExtension()
|
||||
@@ -105,7 +105,7 @@ class FluxKontextMultiReferenceLatentMethod:
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"conditioning": ("CONDITIONING", ),
|
||||
"reference_latents_method": (("offset", "index"), ),
|
||||
"reference_latents_method": (("offset", "index", "uxo/uno"), ),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING",)
|
||||
@@ -115,6 +115,8 @@ class FluxKontextMultiReferenceLatentMethod:
|
||||
CATEGORY = "advanced/conditioning/flux"
|
||||
|
||||
def append(self, conditioning, reference_latents_method):
|
||||
if "uxo" in reference_latents_method or "uso" in reference_latents_method:
|
||||
reference_latents_method = "uxo"
|
||||
c = node_helpers.conditioning_set_values(conditioning, {"reference_latents_method": reference_latents_method})
|
||||
return (c, )
|
||||
|
||||
|
||||
@@ -113,6 +113,20 @@ class HunyuanImageToVideo:
|
||||
out_latent["samples"] = latent
|
||||
return (positive, out_latent)
|
||||
|
||||
class EmptyHunyuanImageLatent:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "width": ("INT", {"default": 2048, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
|
||||
"height": ("INT", {"default": 2048, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "generate"
|
||||
|
||||
CATEGORY = "latent"
|
||||
|
||||
def generate(self, width, height, batch_size=1):
|
||||
latent = torch.zeros([batch_size, 64, height // 32, width // 32], device=comfy.model_management.intermediate_device())
|
||||
return ({"samples":latent}, )
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
@@ -120,4 +134,5 @@ NODE_CLASS_MAPPINGS = {
|
||||
"TextEncodeHunyuanVideo_ImageToVideo": TextEncodeHunyuanVideo_ImageToVideo,
|
||||
"EmptyHunyuanLatentVideo": EmptyHunyuanLatentVideo,
|
||||
"HunyuanImageToVideo": HunyuanImageToVideo,
|
||||
"EmptyHunyuanImageLatent": EmptyHunyuanImageLatent,
|
||||
}
|
||||
|
||||
@@ -8,13 +8,16 @@ import folder_paths
|
||||
import comfy.model_management
|
||||
from comfy.cli_args import args
|
||||
|
||||
|
||||
class EmptyLatentHunyuan3Dv2:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"resolution": ("INT", {"default": 3072, "min": 1, "max": 8192}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}),
|
||||
}}
|
||||
return {
|
||||
"required": {
|
||||
"resolution": ("INT", {"default": 3072, "min": 1, "max": 8192}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "generate"
|
||||
|
||||
@@ -24,7 +27,6 @@ class EmptyLatentHunyuan3Dv2:
|
||||
latent = torch.zeros([batch_size, 64, resolution], device=comfy.model_management.intermediate_device())
|
||||
return ({"samples": latent, "type": "hunyuan3dv2"}, )
|
||||
|
||||
|
||||
class Hunyuan3Dv2Conditioning:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@@ -81,7 +83,6 @@ class VOXEL:
|
||||
def __init__(self, data):
|
||||
self.data = data
|
||||
|
||||
|
||||
class VAEDecodeHunyuan3D:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@@ -99,7 +100,6 @@ class VAEDecodeHunyuan3D:
|
||||
voxels = VOXEL(vae.decode(samples["samples"], vae_options={"num_chunks": num_chunks, "octree_resolution": octree_resolution}))
|
||||
return (voxels, )
|
||||
|
||||
|
||||
def voxel_to_mesh(voxels, threshold=0.5, device=None):
|
||||
if device is None:
|
||||
device = torch.device("cpu")
|
||||
@@ -230,13 +230,9 @@ def voxel_to_mesh_surfnet(voxels, threshold=0.5, device=None):
|
||||
[0, 0, 1], [1, 0, 1], [0, 1, 1], [1, 1, 1]
|
||||
], device=device)
|
||||
|
||||
corner_values = torch.zeros((cell_positions.shape[0], 8), device=device)
|
||||
for c, (dz, dy, dx) in enumerate(corner_offsets):
|
||||
corner_values[:, c] = padded[
|
||||
cell_positions[:, 0] + dz,
|
||||
cell_positions[:, 1] + dy,
|
||||
cell_positions[:, 2] + dx
|
||||
]
|
||||
pos = cell_positions.unsqueeze(1) + corner_offsets.unsqueeze(0)
|
||||
z_idx, y_idx, x_idx = pos.unbind(-1)
|
||||
corner_values = padded[z_idx, y_idx, x_idx]
|
||||
|
||||
corner_signs = corner_values > threshold
|
||||
has_inside = torch.any(corner_signs, dim=1)
|
||||
|
||||
@@ -625,6 +625,37 @@ class ImageFlip:
|
||||
|
||||
return (image,)
|
||||
|
||||
class ImageScaleToMaxDimension:
|
||||
upscale_methods = ["area", "lanczos", "bilinear", "nearest-exact", "bilinear", "bicubic"]
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"image": ("IMAGE",),
|
||||
"upscale_method": (s.upscale_methods,),
|
||||
"largest_size": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1})}}
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "upscale"
|
||||
|
||||
CATEGORY = "image/upscaling"
|
||||
|
||||
def upscale(self, image, upscale_method, largest_size):
|
||||
height = image.shape[1]
|
||||
width = image.shape[2]
|
||||
|
||||
if height > width:
|
||||
width = round((width / height) * largest_size)
|
||||
height = largest_size
|
||||
elif width > height:
|
||||
height = round((height / width) * largest_size)
|
||||
width = largest_size
|
||||
else:
|
||||
height = largest_size
|
||||
width = largest_size
|
||||
|
||||
samples = image.movedim(-1, 1)
|
||||
s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled")
|
||||
s = s.movedim(1, -1)
|
||||
return (s,)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"ImageCrop": ImageCrop,
|
||||
@@ -639,4 +670,5 @@ NODE_CLASS_MAPPINGS = {
|
||||
"GetImageSize": GetImageSize,
|
||||
"ImageRotate": ImageRotate,
|
||||
"ImageFlip": ImageFlip,
|
||||
"ImageScaleToMaxDimension": ImageScaleToMaxDimension,
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import comfy.utils
|
||||
import comfy_extras.nodes_post_processing
|
||||
import torch
|
||||
import nodes
|
||||
|
||||
|
||||
def reshape_latent_to(target_shape, latent, repeat_batch=True):
|
||||
@@ -105,6 +106,73 @@ class LatentInterpolate:
|
||||
samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio))
|
||||
return (samples_out,)
|
||||
|
||||
class LatentConcat:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",), "dim": (["x", "-x", "y", "-y", "t", "-t"], )}}
|
||||
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "op"
|
||||
|
||||
CATEGORY = "latent/advanced"
|
||||
|
||||
def op(self, samples1, samples2, dim):
|
||||
samples_out = samples1.copy()
|
||||
|
||||
s1 = samples1["samples"]
|
||||
s2 = samples2["samples"]
|
||||
s2 = comfy.utils.repeat_to_batch_size(s2, s1.shape[0])
|
||||
|
||||
if "-" in dim:
|
||||
c = (s2, s1)
|
||||
else:
|
||||
c = (s1, s2)
|
||||
|
||||
if "x" in dim:
|
||||
dim = -1
|
||||
elif "y" in dim:
|
||||
dim = -2
|
||||
elif "t" in dim:
|
||||
dim = -3
|
||||
|
||||
samples_out["samples"] = torch.cat(c, dim=dim)
|
||||
return (samples_out,)
|
||||
|
||||
class LatentCut:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"samples": ("LATENT",),
|
||||
"dim": (["x", "y", "t"], ),
|
||||
"index": ("INT", {"default": 0, "min": -nodes.MAX_RESOLUTION, "max": nodes.MAX_RESOLUTION, "step": 1}),
|
||||
"amount": ("INT", {"default": 1, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 1})}}
|
||||
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "op"
|
||||
|
||||
CATEGORY = "latent/advanced"
|
||||
|
||||
def op(self, samples, dim, index, amount):
|
||||
samples_out = samples.copy()
|
||||
|
||||
s1 = samples["samples"]
|
||||
|
||||
if "x" in dim:
|
||||
dim = s1.ndim - 1
|
||||
elif "y" in dim:
|
||||
dim = s1.ndim - 2
|
||||
elif "t" in dim:
|
||||
dim = s1.ndim - 3
|
||||
|
||||
if index >= 0:
|
||||
index = min(index, s1.shape[dim] - 1)
|
||||
amount = min(s1.shape[dim] - index, amount)
|
||||
else:
|
||||
index = max(index, -s1.shape[dim])
|
||||
amount = min(-index, amount)
|
||||
|
||||
samples_out["samples"] = torch.narrow(s1, dim, index, amount)
|
||||
return (samples_out,)
|
||||
|
||||
class LatentBatch:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@@ -279,6 +347,8 @@ NODE_CLASS_MAPPINGS = {
|
||||
"LatentSubtract": LatentSubtract,
|
||||
"LatentMultiply": LatentMultiply,
|
||||
"LatentInterpolate": LatentInterpolate,
|
||||
"LatentConcat": LatentConcat,
|
||||
"LatentCut": LatentCut,
|
||||
"LatentBatch": LatentBatch,
|
||||
"LatentBatchSeedBehavior": LatentBatchSeedBehavior,
|
||||
"LatentApplyOperation": LatentApplyOperation,
|
||||
|
||||
@@ -166,7 +166,7 @@ class LTXVAddGuide:
|
||||
negative = self.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors)
|
||||
|
||||
mask = torch.full(
|
||||
(noise_mask.shape[0], 1, guiding_latent.shape[2], 1, 1),
|
||||
(noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]),
|
||||
1.0 - strength,
|
||||
dtype=noise_mask.dtype,
|
||||
device=noise_mask.device,
|
||||
|
||||
@@ -0,0 +1,343 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import folder_paths
|
||||
import comfy.utils
|
||||
import comfy.ops
|
||||
import comfy.model_management
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.latent_formats
|
||||
|
||||
|
||||
class BlockWiseControlBlock(torch.nn.Module):
|
||||
# [linear, gelu, linear]
|
||||
def __init__(self, dim: int = 3072, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.x_rms = operations.RMSNorm(dim, eps=1e-6)
|
||||
self.y_rms = operations.RMSNorm(dim, eps=1e-6)
|
||||
self.input_proj = operations.Linear(dim, dim)
|
||||
self.act = torch.nn.GELU()
|
||||
self.output_proj = operations.Linear(dim, dim)
|
||||
|
||||
def forward(self, x, y):
|
||||
x, y = self.x_rms(x), self.y_rms(y)
|
||||
x = self.input_proj(x + y)
|
||||
x = self.act(x)
|
||||
x = self.output_proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class QwenImageBlockWiseControlNet(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_layers: int = 60,
|
||||
in_dim: int = 64,
|
||||
additional_in_dim: int = 0,
|
||||
dim: int = 3072,
|
||||
device=None, dtype=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
self.additional_in_dim = additional_in_dim
|
||||
self.img_in = operations.Linear(in_dim + additional_in_dim, dim, device=device, dtype=dtype)
|
||||
self.controlnet_blocks = torch.nn.ModuleList(
|
||||
[
|
||||
BlockWiseControlBlock(dim, device=device, dtype=dtype, operations=operations)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
def process_input_latent_image(self, latent_image):
|
||||
latent_image[:, :16] = comfy.latent_formats.Wan21().process_in(latent_image[:, :16])
|
||||
patch_size = 2
|
||||
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(latent_image, (1, patch_size, patch_size))
|
||||
orig_shape = hidden_states.shape
|
||||
hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
|
||||
hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5)
|
||||
hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
|
||||
return self.img_in(hidden_states)
|
||||
|
||||
def control_block(self, img, controlnet_conditioning, block_id):
|
||||
return self.controlnet_blocks[block_id](img, controlnet_conditioning)
|
||||
|
||||
|
||||
class SigLIPMultiFeatProjModel(torch.nn.Module):
|
||||
"""
|
||||
SigLIP Multi-Feature Projection Model for processing style features from different layers
|
||||
and projecting them into a unified hidden space.
|
||||
|
||||
Args:
|
||||
siglip_token_nums (int): Number of SigLIP tokens, default 257
|
||||
style_token_nums (int): Number of style tokens, default 256
|
||||
siglip_token_dims (int): Dimension of SigLIP tokens, default 1536
|
||||
hidden_size (int): Hidden layer size, default 3072
|
||||
context_layer_norm (bool): Whether to use context layer normalization, default False
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
siglip_token_nums: int = 729,
|
||||
style_token_nums: int = 64,
|
||||
siglip_token_dims: int = 1152,
|
||||
hidden_size: int = 3072,
|
||||
context_layer_norm: bool = True,
|
||||
device=None, dtype=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# High-level feature processing (layer -2)
|
||||
self.high_embedding_linear = nn.Sequential(
|
||||
operations.Linear(siglip_token_nums, style_token_nums),
|
||||
nn.SiLU()
|
||||
)
|
||||
self.high_layer_norm = (
|
||||
operations.LayerNorm(siglip_token_dims) if context_layer_norm else nn.Identity()
|
||||
)
|
||||
self.high_projection = operations.Linear(siglip_token_dims, hidden_size, bias=True)
|
||||
|
||||
# Mid-level feature processing (layer -11)
|
||||
self.mid_embedding_linear = nn.Sequential(
|
||||
operations.Linear(siglip_token_nums, style_token_nums),
|
||||
nn.SiLU()
|
||||
)
|
||||
self.mid_layer_norm = (
|
||||
operations.LayerNorm(siglip_token_dims) if context_layer_norm else nn.Identity()
|
||||
)
|
||||
self.mid_projection = operations.Linear(siglip_token_dims, hidden_size, bias=True)
|
||||
|
||||
# Low-level feature processing (layer -20)
|
||||
self.low_embedding_linear = nn.Sequential(
|
||||
operations.Linear(siglip_token_nums, style_token_nums),
|
||||
nn.SiLU()
|
||||
)
|
||||
self.low_layer_norm = (
|
||||
operations.LayerNorm(siglip_token_dims) if context_layer_norm else nn.Identity()
|
||||
)
|
||||
self.low_projection = operations.Linear(siglip_token_dims, hidden_size, bias=True)
|
||||
|
||||
def forward(self, siglip_outputs):
|
||||
"""
|
||||
Forward pass function
|
||||
|
||||
Args:
|
||||
siglip_outputs: Output from SigLIP model, containing hidden_states
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Concatenated multi-layer features with shape [bs, 3*style_token_nums, hidden_size]
|
||||
"""
|
||||
dtype = next(self.high_embedding_linear.parameters()).dtype
|
||||
|
||||
# Process high-level features (layer -2)
|
||||
high_embedding = self._process_layer_features(
|
||||
siglip_outputs[2],
|
||||
self.high_embedding_linear,
|
||||
self.high_layer_norm,
|
||||
self.high_projection,
|
||||
dtype
|
||||
)
|
||||
|
||||
# Process mid-level features (layer -11)
|
||||
mid_embedding = self._process_layer_features(
|
||||
siglip_outputs[1],
|
||||
self.mid_embedding_linear,
|
||||
self.mid_layer_norm,
|
||||
self.mid_projection,
|
||||
dtype
|
||||
)
|
||||
|
||||
# Process low-level features (layer -20)
|
||||
low_embedding = self._process_layer_features(
|
||||
siglip_outputs[0],
|
||||
self.low_embedding_linear,
|
||||
self.low_layer_norm,
|
||||
self.low_projection,
|
||||
dtype
|
||||
)
|
||||
|
||||
# Concatenate features from all layersmodel_patch
|
||||
return torch.cat((high_embedding, mid_embedding, low_embedding), dim=1)
|
||||
|
||||
def _process_layer_features(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
embedding_linear: nn.Module,
|
||||
layer_norm: nn.Module,
|
||||
projection: nn.Module,
|
||||
dtype: torch.dtype
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Helper function to process features from a single layer
|
||||
|
||||
Args:
|
||||
hidden_states: Input hidden states [bs, seq_len, dim]
|
||||
embedding_linear: Embedding linear layer
|
||||
layer_norm: Layer normalization
|
||||
projection: Projection layer
|
||||
dtype: Target data type
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Processed features [bs, style_token_nums, hidden_size]
|
||||
"""
|
||||
# Transform dimensions: [bs, seq_len, dim] -> [bs, dim, seq_len] -> [bs, dim, style_token_nums] -> [bs, style_token_nums, dim]
|
||||
embedding = embedding_linear(
|
||||
hidden_states.to(dtype).transpose(1, 2)
|
||||
).transpose(1, 2)
|
||||
|
||||
# Apply layer normalization
|
||||
embedding = layer_norm(embedding)
|
||||
|
||||
# Project to target hidden space
|
||||
embedding = projection(embedding)
|
||||
|
||||
return embedding
|
||||
|
||||
class ModelPatchLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "name": (folder_paths.get_filename_list("model_patches"), ),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL_PATCH",)
|
||||
FUNCTION = "load_model_patch"
|
||||
EXPERIMENTAL = True
|
||||
|
||||
CATEGORY = "advanced/loaders"
|
||||
|
||||
def load_model_patch(self, name):
|
||||
model_patch_path = folder_paths.get_full_path_or_raise("model_patches", name)
|
||||
sd = comfy.utils.load_torch_file(model_patch_path, safe_load=True)
|
||||
dtype = comfy.utils.weight_dtype(sd)
|
||||
|
||||
if 'controlnet_blocks.0.y_rms.weight' in sd:
|
||||
additional_in_dim = sd["img_in.weight"].shape[1] - 64
|
||||
model = QwenImageBlockWiseControlNet(additional_in_dim=additional_in_dim, device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
|
||||
elif 'feature_embedder.mid_layer_norm.bias' in sd:
|
||||
sd = comfy.utils.state_dict_prefix_replace(sd, {"feature_embedder.": ""}, filter_keys=True)
|
||||
model = SigLIPMultiFeatProjModel(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
|
||||
|
||||
model.load_state_dict(sd)
|
||||
model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
|
||||
return (model,)
|
||||
|
||||
|
||||
class DiffSynthCnetPatch:
|
||||
def __init__(self, model_patch, vae, image, strength, mask=None):
|
||||
self.model_patch = model_patch
|
||||
self.vae = vae
|
||||
self.image = image
|
||||
self.strength = strength
|
||||
self.mask = mask
|
||||
self.encoded_image = model_patch.model.process_input_latent_image(self.encode_latent_cond(image))
|
||||
self.encoded_image_size = (image.shape[1], image.shape[2])
|
||||
|
||||
def encode_latent_cond(self, image):
|
||||
latent_image = self.vae.encode(image)
|
||||
if self.model_patch.model.additional_in_dim > 0:
|
||||
if self.mask is None:
|
||||
mask_ = torch.ones_like(latent_image)[:, :self.model_patch.model.additional_in_dim // 4]
|
||||
else:
|
||||
mask_ = comfy.utils.common_upscale(self.mask.mean(dim=1, keepdim=True), latent_image.shape[-1], latent_image.shape[-2], "bilinear", "none")
|
||||
|
||||
return torch.cat([latent_image, mask_], dim=1)
|
||||
else:
|
||||
return latent_image
|
||||
|
||||
def __call__(self, kwargs):
|
||||
x = kwargs.get("x")
|
||||
img = kwargs.get("img")
|
||||
block_index = kwargs.get("block_index")
|
||||
spacial_compression = self.vae.spacial_compression_encode()
|
||||
if self.encoded_image is None or self.encoded_image_size != (x.shape[-2] * spacial_compression, x.shape[-1] * spacial_compression):
|
||||
image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center")
|
||||
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
|
||||
self.encoded_image = self.model_patch.model.process_input_latent_image(self.encode_latent_cond(image_scaled.movedim(1, -1)))
|
||||
self.encoded_image_size = (image_scaled.shape[-2], image_scaled.shape[-1])
|
||||
comfy.model_management.load_models_gpu(loaded_models)
|
||||
|
||||
img[:, :self.encoded_image.shape[1]] += (self.model_patch.model.control_block(img[:, :self.encoded_image.shape[1]], self.encoded_image.to(img.dtype), block_index) * self.strength)
|
||||
kwargs['img'] = img
|
||||
return kwargs
|
||||
|
||||
def to(self, device_or_dtype):
|
||||
if isinstance(device_or_dtype, torch.device):
|
||||
self.encoded_image = self.encoded_image.to(device_or_dtype)
|
||||
return self
|
||||
|
||||
def models(self):
|
||||
return [self.model_patch]
|
||||
|
||||
class QwenImageDiffsynthControlnet:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"model_patch": ("MODEL_PATCH",),
|
||||
"vae": ("VAE",),
|
||||
"image": ("IMAGE",),
|
||||
"strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
|
||||
},
|
||||
"optional": {"mask": ("MASK",)}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "diffsynth_controlnet"
|
||||
EXPERIMENTAL = True
|
||||
|
||||
CATEGORY = "advanced/loaders/qwen"
|
||||
|
||||
def diffsynth_controlnet(self, model, model_patch, vae, image, strength, mask=None):
|
||||
model_patched = model.clone()
|
||||
image = image[:, :, :, :3]
|
||||
if mask is not None:
|
||||
if mask.ndim == 3:
|
||||
mask = mask.unsqueeze(1)
|
||||
if mask.ndim == 4:
|
||||
mask = mask.unsqueeze(2)
|
||||
mask = 1.0 - mask
|
||||
|
||||
model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask))
|
||||
return (model_patched,)
|
||||
|
||||
|
||||
class UsoStyleProjectorPatch:
|
||||
def __init__(self, model_patch, encoded_image):
|
||||
self.model_patch = model_patch
|
||||
self.encoded_image = encoded_image
|
||||
|
||||
def __call__(self, kwargs):
|
||||
txt_ids = kwargs.get("txt_ids")
|
||||
txt = kwargs.get("txt")
|
||||
siglip_embedding = self.model_patch.model(self.encoded_image.to(txt.dtype)).to(txt.dtype)
|
||||
txt = torch.cat([siglip_embedding, txt], dim=1)
|
||||
kwargs['txt'] = txt
|
||||
kwargs['txt_ids'] = torch.cat([torch.zeros(siglip_embedding.shape[0], siglip_embedding.shape[1], 3, dtype=txt_ids.dtype, device=txt_ids.device), txt_ids], dim=1)
|
||||
return kwargs
|
||||
|
||||
def to(self, device_or_dtype):
|
||||
if isinstance(device_or_dtype, torch.device):
|
||||
self.encoded_image = self.encoded_image.to(device_or_dtype)
|
||||
return self
|
||||
|
||||
def models(self):
|
||||
return [self.model_patch]
|
||||
|
||||
|
||||
class USOStyleReference:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"model": ("MODEL",),
|
||||
"model_patch": ("MODEL_PATCH",),
|
||||
"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "apply_patch"
|
||||
EXPERIMENTAL = True
|
||||
|
||||
CATEGORY = "advanced/model_patches/flux"
|
||||
|
||||
def apply_patch(self, model, model_patch, clip_vision_output):
|
||||
encoded_image = torch.stack((clip_vision_output.all_hidden_states[:, -20], clip_vision_output.all_hidden_states[:, -11], clip_vision_output.penultimate_hidden_states))
|
||||
model_patched = model.clone()
|
||||
model_patched.set_model_post_input_patch(UsoStyleProjectorPatch(model_patch, encoded_image))
|
||||
return (model_patched,)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"ModelPatchLoader": ModelPatchLoader,
|
||||
"QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet,
|
||||
"USOStyleReference": USOStyleReference,
|
||||
}
|
||||
@@ -1,98 +1,109 @@
|
||||
# Primitive nodes that are evaluated at backend.
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, IO
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
class String(ComfyNodeABC):
|
||||
class String(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
return {
|
||||
"required": {"value": (IO.STRING, {})},
|
||||
}
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="PrimitiveString",
|
||||
display_name="String",
|
||||
category="utils/primitive",
|
||||
inputs=[
|
||||
io.String.Input("value"),
|
||||
],
|
||||
outputs=[io.String.Output()],
|
||||
)
|
||||
|
||||
RETURN_TYPES = (IO.STRING,)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "utils/primitive"
|
||||
|
||||
def execute(self, value: str) -> tuple[str]:
|
||||
return (value,)
|
||||
|
||||
|
||||
class StringMultiline(ComfyNodeABC):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
return {
|
||||
"required": {"value": (IO.STRING, {"multiline": True,},)},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.STRING,)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "utils/primitive"
|
||||
|
||||
def execute(self, value: str) -> tuple[str]:
|
||||
return (value,)
|
||||
def execute(cls, value: str) -> io.NodeOutput:
|
||||
return io.NodeOutput(value)
|
||||
|
||||
|
||||
class Int(ComfyNodeABC):
|
||||
class StringMultiline(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
return {
|
||||
"required": {"value": (IO.INT, {"min": -sys.maxsize, "max": sys.maxsize, "control_after_generate": True})},
|
||||
}
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="PrimitiveStringMultiline",
|
||||
display_name="String (Multiline)",
|
||||
category="utils/primitive",
|
||||
inputs=[
|
||||
io.String.Input("value", multiline=True),
|
||||
],
|
||||
outputs=[io.String.Output()],
|
||||
)
|
||||
|
||||
RETURN_TYPES = (IO.INT,)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "utils/primitive"
|
||||
|
||||
def execute(self, value: int) -> tuple[int]:
|
||||
return (value,)
|
||||
|
||||
|
||||
class Float(ComfyNodeABC):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
return {
|
||||
"required": {"value": (IO.FLOAT, {"min": -sys.maxsize, "max": sys.maxsize})},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.FLOAT,)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "utils/primitive"
|
||||
|
||||
def execute(self, value: float) -> tuple[float]:
|
||||
return (value,)
|
||||
def execute(cls, value: str) -> io.NodeOutput:
|
||||
return io.NodeOutput(value)
|
||||
|
||||
|
||||
class Boolean(ComfyNodeABC):
|
||||
class Int(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
return {
|
||||
"required": {"value": (IO.BOOLEAN, {})},
|
||||
}
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="PrimitiveInt",
|
||||
display_name="Int",
|
||||
category="utils/primitive",
|
||||
inputs=[
|
||||
io.Int.Input("value", min=-sys.maxsize, max=sys.maxsize, control_after_generate=True),
|
||||
],
|
||||
outputs=[io.Int.Output()],
|
||||
)
|
||||
|
||||
RETURN_TYPES = (IO.BOOLEAN,)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "utils/primitive"
|
||||
|
||||
def execute(self, value: bool) -> tuple[bool]:
|
||||
return (value,)
|
||||
@classmethod
|
||||
def execute(cls, value: int) -> io.NodeOutput:
|
||||
return io.NodeOutput(value)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"PrimitiveString": String,
|
||||
"PrimitiveStringMultiline": StringMultiline,
|
||||
"PrimitiveInt": Int,
|
||||
"PrimitiveFloat": Float,
|
||||
"PrimitiveBoolean": Boolean,
|
||||
}
|
||||
class Float(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="PrimitiveFloat",
|
||||
display_name="Float",
|
||||
category="utils/primitive",
|
||||
inputs=[
|
||||
io.Float.Input("value", min=-sys.maxsize, max=sys.maxsize),
|
||||
],
|
||||
outputs=[io.Float.Output()],
|
||||
)
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"PrimitiveString": "String",
|
||||
"PrimitiveStringMultiline": "String (Multiline)",
|
||||
"PrimitiveInt": "Int",
|
||||
"PrimitiveFloat": "Float",
|
||||
"PrimitiveBoolean": "Boolean",
|
||||
}
|
||||
@classmethod
|
||||
def execute(cls, value: float) -> io.NodeOutput:
|
||||
return io.NodeOutput(value)
|
||||
|
||||
|
||||
class Boolean(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="PrimitiveBoolean",
|
||||
display_name="Boolean",
|
||||
category="utils/primitive",
|
||||
inputs=[
|
||||
io.Boolean.Input("value"),
|
||||
],
|
||||
outputs=[io.Boolean.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, value: bool) -> io.NodeOutput:
|
||||
return io.NodeOutput(value)
|
||||
|
||||
|
||||
class PrimitivesExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
String,
|
||||
StringMultiline,
|
||||
Int,
|
||||
Float,
|
||||
Boolean,
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> PrimitivesExtension:
|
||||
return PrimitivesExtension()
|
||||
|
||||
@@ -17,55 +17,61 @@
|
||||
"""
|
||||
|
||||
import torch
|
||||
import nodes
|
||||
from typing_extensions import override
|
||||
|
||||
import comfy.utils
|
||||
import nodes
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
class StableCascade_EmptyLatentImage:
|
||||
def __init__(self, device="cpu"):
|
||||
self.device = device
|
||||
class StableCascade_EmptyLatentImage(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="StableCascade_EmptyLatentImage",
|
||||
category="latent/stable_cascade",
|
||||
inputs=[
|
||||
io.Int.Input("width", default=1024, min=256, max=nodes.MAX_RESOLUTION, step=8),
|
||||
io.Int.Input("height", default=1024, min=256, max=nodes.MAX_RESOLUTION, step=8),
|
||||
io.Int.Input("compression", default=42, min=4, max=128, step=1),
|
||||
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||
],
|
||||
outputs=[
|
||||
io.Latent.Output(display_name="stage_c"),
|
||||
io.Latent.Output(display_name="stage_b"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"width": ("INT", {"default": 1024, "min": 256, "max": nodes.MAX_RESOLUTION, "step": 8}),
|
||||
"height": ("INT", {"default": 1024, "min": 256, "max": nodes.MAX_RESOLUTION, "step": 8}),
|
||||
"compression": ("INT", {"default": 42, "min": 4, "max": 128, "step": 1}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})
|
||||
}}
|
||||
RETURN_TYPES = ("LATENT", "LATENT")
|
||||
RETURN_NAMES = ("stage_c", "stage_b")
|
||||
FUNCTION = "generate"
|
||||
|
||||
CATEGORY = "latent/stable_cascade"
|
||||
|
||||
def generate(self, width, height, compression, batch_size=1):
|
||||
def execute(cls, width, height, compression, batch_size=1):
|
||||
c_latent = torch.zeros([batch_size, 16, height // compression, width // compression])
|
||||
b_latent = torch.zeros([batch_size, 4, height // 4, width // 4])
|
||||
return ({
|
||||
return io.NodeOutput({
|
||||
"samples": c_latent,
|
||||
}, {
|
||||
"samples": b_latent,
|
||||
})
|
||||
|
||||
class StableCascade_StageC_VAEEncode:
|
||||
def __init__(self, device="cpu"):
|
||||
self.device = device
|
||||
|
||||
class StableCascade_StageC_VAEEncode(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="StableCascade_StageC_VAEEncode",
|
||||
category="latent/stable_cascade",
|
||||
inputs=[
|
||||
io.Image.Input("image"),
|
||||
io.Vae.Input("vae"),
|
||||
io.Int.Input("compression", default=42, min=4, max=128, step=1),
|
||||
],
|
||||
outputs=[
|
||||
io.Latent.Output(display_name="stage_c"),
|
||||
io.Latent.Output(display_name="stage_b"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"image": ("IMAGE",),
|
||||
"vae": ("VAE", ),
|
||||
"compression": ("INT", {"default": 42, "min": 4, "max": 128, "step": 1}),
|
||||
}}
|
||||
RETURN_TYPES = ("LATENT", "LATENT")
|
||||
RETURN_NAMES = ("stage_c", "stage_b")
|
||||
FUNCTION = "generate"
|
||||
|
||||
CATEGORY = "latent/stable_cascade"
|
||||
|
||||
def generate(self, image, vae, compression):
|
||||
def execute(cls, image, vae, compression):
|
||||
width = image.shape[-2]
|
||||
height = image.shape[-3]
|
||||
out_width = (width // compression) * vae.downscale_ratio
|
||||
@@ -75,51 +81,59 @@ class StableCascade_StageC_VAEEncode:
|
||||
|
||||
c_latent = vae.encode(s[:,:,:,:3])
|
||||
b_latent = torch.zeros([c_latent.shape[0], 4, (height // 8) * 2, (width // 8) * 2])
|
||||
return ({
|
||||
return io.NodeOutput({
|
||||
"samples": c_latent,
|
||||
}, {
|
||||
"samples": b_latent,
|
||||
})
|
||||
|
||||
class StableCascade_StageB_Conditioning:
|
||||
|
||||
class StableCascade_StageB_Conditioning(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "conditioning": ("CONDITIONING",),
|
||||
"stage_c": ("LATENT",),
|
||||
}}
|
||||
RETURN_TYPES = ("CONDITIONING",)
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="StableCascade_StageB_Conditioning",
|
||||
category="conditioning/stable_cascade",
|
||||
inputs=[
|
||||
io.Conditioning.Input("conditioning"),
|
||||
io.Latent.Input("stage_c"),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
FUNCTION = "set_prior"
|
||||
|
||||
CATEGORY = "conditioning/stable_cascade"
|
||||
|
||||
def set_prior(self, conditioning, stage_c):
|
||||
@classmethod
|
||||
def execute(cls, conditioning, stage_c):
|
||||
c = []
|
||||
for t in conditioning:
|
||||
d = t[1].copy()
|
||||
d['stable_cascade_prior'] = stage_c['samples']
|
||||
d["stable_cascade_prior"] = stage_c["samples"]
|
||||
n = [t[0], d]
|
||||
c.append(n)
|
||||
return (c, )
|
||||
return io.NodeOutput(c)
|
||||
|
||||
class StableCascade_SuperResolutionControlnet:
|
||||
def __init__(self, device="cpu"):
|
||||
self.device = device
|
||||
|
||||
class StableCascade_SuperResolutionControlnet(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="StableCascade_SuperResolutionControlnet",
|
||||
category="_for_testing/stable_cascade",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Image.Input("image"),
|
||||
io.Vae.Input("vae"),
|
||||
],
|
||||
outputs=[
|
||||
io.Image.Output(display_name="controlnet_input"),
|
||||
io.Latent.Output(display_name="stage_c"),
|
||||
io.Latent.Output(display_name="stage_b"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"image": ("IMAGE",),
|
||||
"vae": ("VAE", ),
|
||||
}}
|
||||
RETURN_TYPES = ("IMAGE", "LATENT", "LATENT")
|
||||
RETURN_NAMES = ("controlnet_input", "stage_c", "stage_b")
|
||||
FUNCTION = "generate"
|
||||
|
||||
EXPERIMENTAL = True
|
||||
CATEGORY = "_for_testing/stable_cascade"
|
||||
|
||||
def generate(self, image, vae):
|
||||
def execute(cls, image, vae):
|
||||
width = image.shape[-2]
|
||||
height = image.shape[-3]
|
||||
batch_size = image.shape[0]
|
||||
@@ -127,15 +141,22 @@ class StableCascade_SuperResolutionControlnet:
|
||||
|
||||
c_latent = torch.zeros([batch_size, 16, height // 16, width // 16])
|
||||
b_latent = torch.zeros([batch_size, 4, height // 2, width // 2])
|
||||
return (controlnet_input, {
|
||||
return io.NodeOutput(controlnet_input, {
|
||||
"samples": c_latent,
|
||||
}, {
|
||||
"samples": b_latent,
|
||||
})
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"StableCascade_EmptyLatentImage": StableCascade_EmptyLatentImage,
|
||||
"StableCascade_StageB_Conditioning": StableCascade_StageB_Conditioning,
|
||||
"StableCascade_StageC_VAEEncode": StableCascade_StageC_VAEEncode,
|
||||
"StableCascade_SuperResolutionControlnet": StableCascade_SuperResolutionControlnet,
|
||||
}
|
||||
|
||||
class StableCascadeExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
StableCascade_EmptyLatentImage,
|
||||
StableCascade_StageB_Conditioning,
|
||||
StableCascade_StageC_VAEEncode,
|
||||
StableCascade_SuperResolutionControlnet,
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> StableCascadeExtension:
|
||||
return StableCascadeExtension()
|
||||
|
||||
+237
-212
@@ -1,77 +1,91 @@
|
||||
import re
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy.comfy_types.node_typing import IO
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
class StringConcatenate():
|
||||
|
||||
class StringConcatenate(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"string_a": (IO.STRING, {"multiline": True}),
|
||||
"string_b": (IO.STRING, {"multiline": True}),
|
||||
"delimiter": (IO.STRING, {"multiline": False, "default": ""})
|
||||
}
|
||||
}
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="StringConcatenate",
|
||||
display_name="Concatenate",
|
||||
category="utils/string",
|
||||
inputs=[
|
||||
io.String.Input("string_a", multiline=True),
|
||||
io.String.Input("string_b", multiline=True),
|
||||
io.String.Input("delimiter", multiline=False, default=""),
|
||||
],
|
||||
outputs=[
|
||||
io.String.Output(),
|
||||
]
|
||||
)
|
||||
|
||||
RETURN_TYPES = (IO.STRING,)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "utils/string"
|
||||
|
||||
def execute(self, string_a, string_b, delimiter, **kwargs):
|
||||
return delimiter.join((string_a, string_b)),
|
||||
|
||||
class StringSubstring():
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"string": (IO.STRING, {"multiline": True}),
|
||||
"start": (IO.INT, {}),
|
||||
"end": (IO.INT, {}),
|
||||
}
|
||||
}
|
||||
def execute(cls, string_a, string_b, delimiter):
|
||||
return io.NodeOutput(delimiter.join((string_a, string_b)))
|
||||
|
||||
RETURN_TYPES = (IO.STRING,)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "utils/string"
|
||||
|
||||
def execute(self, string, start, end, **kwargs):
|
||||
return string[start:end],
|
||||
|
||||
class StringLength():
|
||||
class StringSubstring(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"string": (IO.STRING, {"multiline": True})
|
||||
}
|
||||
}
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="StringSubstring",
|
||||
display_name="Substring",
|
||||
category="utils/string",
|
||||
inputs=[
|
||||
io.String.Input("string", multiline=True),
|
||||
io.Int.Input("start"),
|
||||
io.Int.Input("end"),
|
||||
],
|
||||
outputs=[
|
||||
io.String.Output(),
|
||||
]
|
||||
)
|
||||
|
||||
RETURN_TYPES = (IO.INT,)
|
||||
RETURN_NAMES = ("length",)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "utils/string"
|
||||
|
||||
def execute(self, string, **kwargs):
|
||||
length = len(string)
|
||||
|
||||
return length,
|
||||
|
||||
class CaseConverter():
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"string": (IO.STRING, {"multiline": True}),
|
||||
"mode": (IO.COMBO, {"options": ["UPPERCASE", "lowercase", "Capitalize", "Title Case"]})
|
||||
}
|
||||
}
|
||||
def execute(cls, string, start, end):
|
||||
return io.NodeOutput(string[start:end])
|
||||
|
||||
RETURN_TYPES = (IO.STRING,)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "utils/string"
|
||||
|
||||
def execute(self, string, mode, **kwargs):
|
||||
class StringLength(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="StringLength",
|
||||
display_name="Length",
|
||||
category="utils/string",
|
||||
inputs=[
|
||||
io.String.Input("string", multiline=True),
|
||||
],
|
||||
outputs=[
|
||||
io.Int.Output(display_name="length"),
|
||||
]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, string):
|
||||
return io.NodeOutput(len(string))
|
||||
|
||||
|
||||
class CaseConverter(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="CaseConverter",
|
||||
display_name="Case Converter",
|
||||
category="utils/string",
|
||||
inputs=[
|
||||
io.String.Input("string", multiline=True),
|
||||
io.Combo.Input("mode", options=["UPPERCASE", "lowercase", "Capitalize", "Title Case"]),
|
||||
],
|
||||
outputs=[
|
||||
io.String.Output(),
|
||||
]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, string, mode):
|
||||
if mode == "UPPERCASE":
|
||||
result = string.upper()
|
||||
elif mode == "lowercase":
|
||||
@@ -83,24 +97,27 @@ class CaseConverter():
|
||||
else:
|
||||
result = string
|
||||
|
||||
return result,
|
||||
return io.NodeOutput(result)
|
||||
|
||||
|
||||
class StringTrim():
|
||||
class StringTrim(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"string": (IO.STRING, {"multiline": True}),
|
||||
"mode": (IO.COMBO, {"options": ["Both", "Left", "Right"]})
|
||||
}
|
||||
}
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="StringTrim",
|
||||
display_name="Trim",
|
||||
category="utils/string",
|
||||
inputs=[
|
||||
io.String.Input("string", multiline=True),
|
||||
io.Combo.Input("mode", options=["Both", "Left", "Right"]),
|
||||
],
|
||||
outputs=[
|
||||
io.String.Output(),
|
||||
]
|
||||
)
|
||||
|
||||
RETURN_TYPES = (IO.STRING,)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "utils/string"
|
||||
|
||||
def execute(self, string, mode, **kwargs):
|
||||
@classmethod
|
||||
def execute(cls, string, mode):
|
||||
if mode == "Both":
|
||||
result = string.strip()
|
||||
elif mode == "Left":
|
||||
@@ -110,70 +127,78 @@ class StringTrim():
|
||||
else:
|
||||
result = string
|
||||
|
||||
return result,
|
||||
return io.NodeOutput(result)
|
||||
|
||||
class StringReplace():
|
||||
|
||||
class StringReplace(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"string": (IO.STRING, {"multiline": True}),
|
||||
"find": (IO.STRING, {"multiline": True}),
|
||||
"replace": (IO.STRING, {"multiline": True})
|
||||
}
|
||||
}
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="StringReplace",
|
||||
display_name="Replace",
|
||||
category="utils/string",
|
||||
inputs=[
|
||||
io.String.Input("string", multiline=True),
|
||||
io.String.Input("find", multiline=True),
|
||||
io.String.Input("replace", multiline=True),
|
||||
],
|
||||
outputs=[
|
||||
io.String.Output(),
|
||||
]
|
||||
)
|
||||
|
||||
RETURN_TYPES = (IO.STRING,)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "utils/string"
|
||||
|
||||
def execute(self, string, find, replace, **kwargs):
|
||||
result = string.replace(find, replace)
|
||||
return result,
|
||||
|
||||
|
||||
class StringContains():
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"string": (IO.STRING, {"multiline": True}),
|
||||
"substring": (IO.STRING, {"multiline": True}),
|
||||
"case_sensitive": (IO.BOOLEAN, {"default": True})
|
||||
}
|
||||
}
|
||||
def execute(cls, string, find, replace):
|
||||
return io.NodeOutput(string.replace(find, replace))
|
||||
|
||||
RETURN_TYPES = (IO.BOOLEAN,)
|
||||
RETURN_NAMES = ("contains",)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "utils/string"
|
||||
|
||||
def execute(self, string, substring, case_sensitive, **kwargs):
|
||||
class StringContains(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="StringContains",
|
||||
display_name="Contains",
|
||||
category="utils/string",
|
||||
inputs=[
|
||||
io.String.Input("string", multiline=True),
|
||||
io.String.Input("substring", multiline=True),
|
||||
io.Boolean.Input("case_sensitive", default=True),
|
||||
],
|
||||
outputs=[
|
||||
io.Boolean.Output(display_name="contains"),
|
||||
]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, string, substring, case_sensitive):
|
||||
if case_sensitive:
|
||||
contains = substring in string
|
||||
else:
|
||||
contains = substring.lower() in string.lower()
|
||||
|
||||
return contains,
|
||||
return io.NodeOutput(contains)
|
||||
|
||||
|
||||
class StringCompare():
|
||||
class StringCompare(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"string_a": (IO.STRING, {"multiline": True}),
|
||||
"string_b": (IO.STRING, {"multiline": True}),
|
||||
"mode": (IO.COMBO, {"options": ["Starts With", "Ends With", "Equal"]}),
|
||||
"case_sensitive": (IO.BOOLEAN, {"default": True})
|
||||
}
|
||||
}
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="StringCompare",
|
||||
display_name="Compare",
|
||||
category="utils/string",
|
||||
inputs=[
|
||||
io.String.Input("string_a", multiline=True),
|
||||
io.String.Input("string_b", multiline=True),
|
||||
io.Combo.Input("mode", options=["Starts With", "Ends With", "Equal"]),
|
||||
io.Boolean.Input("case_sensitive", default=True),
|
||||
],
|
||||
outputs=[
|
||||
io.Boolean.Output(),
|
||||
]
|
||||
)
|
||||
|
||||
RETURN_TYPES = (IO.BOOLEAN,)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "utils/string"
|
||||
|
||||
def execute(self, string_a, string_b, mode, case_sensitive, **kwargs):
|
||||
@classmethod
|
||||
def execute(cls, string_a, string_b, mode, case_sensitive):
|
||||
if case_sensitive:
|
||||
a = string_a
|
||||
b = string_b
|
||||
@@ -182,31 +207,34 @@ class StringCompare():
|
||||
b = string_b.lower()
|
||||
|
||||
if mode == "Equal":
|
||||
return a == b,
|
||||
return io.NodeOutput(a == b)
|
||||
elif mode == "Starts With":
|
||||
return a.startswith(b),
|
||||
return io.NodeOutput(a.startswith(b))
|
||||
elif mode == "Ends With":
|
||||
return a.endswith(b),
|
||||
return io.NodeOutput(a.endswith(b))
|
||||
|
||||
class RegexMatch():
|
||||
|
||||
class RegexMatch(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"string": (IO.STRING, {"multiline": True}),
|
||||
"regex_pattern": (IO.STRING, {"multiline": True}),
|
||||
"case_insensitive": (IO.BOOLEAN, {"default": True}),
|
||||
"multiline": (IO.BOOLEAN, {"default": False}),
|
||||
"dotall": (IO.BOOLEAN, {"default": False})
|
||||
}
|
||||
}
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="RegexMatch",
|
||||
display_name="Regex Match",
|
||||
category="utils/string",
|
||||
inputs=[
|
||||
io.String.Input("string", multiline=True),
|
||||
io.String.Input("regex_pattern", multiline=True),
|
||||
io.Boolean.Input("case_insensitive", default=True),
|
||||
io.Boolean.Input("multiline", default=False),
|
||||
io.Boolean.Input("dotall", default=False),
|
||||
],
|
||||
outputs=[
|
||||
io.Boolean.Output(display_name="matches"),
|
||||
]
|
||||
)
|
||||
|
||||
RETURN_TYPES = (IO.BOOLEAN,)
|
||||
RETURN_NAMES = ("matches",)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "utils/string"
|
||||
|
||||
def execute(self, string, regex_pattern, case_insensitive, multiline, dotall, **kwargs):
|
||||
@classmethod
|
||||
def execute(cls, string, regex_pattern, case_insensitive, multiline, dotall):
|
||||
flags = 0
|
||||
|
||||
if case_insensitive:
|
||||
@@ -223,29 +251,32 @@ class RegexMatch():
|
||||
except re.error:
|
||||
result = False
|
||||
|
||||
return result,
|
||||
return io.NodeOutput(result)
|
||||
|
||||
|
||||
class RegexExtract():
|
||||
class RegexExtract(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"string": (IO.STRING, {"multiline": True}),
|
||||
"regex_pattern": (IO.STRING, {"multiline": True}),
|
||||
"mode": (IO.COMBO, {"options": ["First Match", "All Matches", "First Group", "All Groups"]}),
|
||||
"case_insensitive": (IO.BOOLEAN, {"default": True}),
|
||||
"multiline": (IO.BOOLEAN, {"default": False}),
|
||||
"dotall": (IO.BOOLEAN, {"default": False}),
|
||||
"group_index": (IO.INT, {"default": 1, "min": 0, "max": 100})
|
||||
}
|
||||
}
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="RegexExtract",
|
||||
display_name="Regex Extract",
|
||||
category="utils/string",
|
||||
inputs=[
|
||||
io.String.Input("string", multiline=True),
|
||||
io.String.Input("regex_pattern", multiline=True),
|
||||
io.Combo.Input("mode", options=["First Match", "All Matches", "First Group", "All Groups"]),
|
||||
io.Boolean.Input("case_insensitive", default=True),
|
||||
io.Boolean.Input("multiline", default=False),
|
||||
io.Boolean.Input("dotall", default=False),
|
||||
io.Int.Input("group_index", default=1, min=0, max=100),
|
||||
],
|
||||
outputs=[
|
||||
io.String.Output(),
|
||||
]
|
||||
)
|
||||
|
||||
RETURN_TYPES = (IO.STRING,)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "utils/string"
|
||||
|
||||
def execute(self, string, regex_pattern, mode, case_insensitive, multiline, dotall, group_index, **kwargs):
|
||||
@classmethod
|
||||
def execute(cls, string, regex_pattern, mode, case_insensitive, multiline, dotall, group_index):
|
||||
join_delimiter = "\n"
|
||||
|
||||
flags = 0
|
||||
@@ -294,32 +325,33 @@ class RegexExtract():
|
||||
except re.error:
|
||||
result = ""
|
||||
|
||||
return result,
|
||||
return io.NodeOutput(result)
|
||||
|
||||
|
||||
class RegexReplace():
|
||||
DESCRIPTION = "Find and replace text using regex patterns."
|
||||
class RegexReplace(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"string": (IO.STRING, {"multiline": True}),
|
||||
"regex_pattern": (IO.STRING, {"multiline": True}),
|
||||
"replace": (IO.STRING, {"multiline": True}),
|
||||
},
|
||||
"optional": {
|
||||
"case_insensitive": (IO.BOOLEAN, {"default": True}),
|
||||
"multiline": (IO.BOOLEAN, {"default": False}),
|
||||
"dotall": (IO.BOOLEAN, {"default": False, "tooltip": "When enabled, the dot (.) character will match any character including newline characters. When disabled, dots won't match newlines."}),
|
||||
"count": (IO.INT, {"default": 0, "min": 0, "max": 100, "tooltip": "Maximum number of replacements to make. Set to 0 to replace all occurrences (default). Set to 1 to replace only the first match, 2 for the first two matches, etc."}),
|
||||
}
|
||||
}
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="RegexReplace",
|
||||
display_name="Regex Replace",
|
||||
category="utils/string",
|
||||
description="Find and replace text using regex patterns.",
|
||||
inputs=[
|
||||
io.String.Input("string", multiline=True),
|
||||
io.String.Input("regex_pattern", multiline=True),
|
||||
io.String.Input("replace", multiline=True),
|
||||
io.Boolean.Input("case_insensitive", default=True, optional=True),
|
||||
io.Boolean.Input("multiline", default=False, optional=True),
|
||||
io.Boolean.Input("dotall", default=False, optional=True, tooltip="When enabled, the dot (.) character will match any character including newline characters. When disabled, dots won't match newlines."),
|
||||
io.Int.Input("count", default=0, min=0, max=100, optional=True, tooltip="Maximum number of replacements to make. Set to 0 to replace all occurrences (default). Set to 1 to replace only the first match, 2 for the first two matches, etc."),
|
||||
],
|
||||
outputs=[
|
||||
io.String.Output(),
|
||||
]
|
||||
)
|
||||
|
||||
RETURN_TYPES = (IO.STRING,)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "utils/string"
|
||||
|
||||
def execute(self, string, regex_pattern, replace, case_insensitive=True, multiline=False, dotall=False, count=0, **kwargs):
|
||||
@classmethod
|
||||
def execute(cls, string, regex_pattern, replace, case_insensitive=True, multiline=False, dotall=False, count=0):
|
||||
flags = 0
|
||||
|
||||
if case_insensitive:
|
||||
@@ -329,32 +361,25 @@ class RegexReplace():
|
||||
if dotall:
|
||||
flags |= re.DOTALL
|
||||
result = re.sub(regex_pattern, replace, string, count=count, flags=flags)
|
||||
return result,
|
||||
return io.NodeOutput(result)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"StringConcatenate": StringConcatenate,
|
||||
"StringSubstring": StringSubstring,
|
||||
"StringLength": StringLength,
|
||||
"CaseConverter": CaseConverter,
|
||||
"StringTrim": StringTrim,
|
||||
"StringReplace": StringReplace,
|
||||
"StringContains": StringContains,
|
||||
"StringCompare": StringCompare,
|
||||
"RegexMatch": RegexMatch,
|
||||
"RegexExtract": RegexExtract,
|
||||
"RegexReplace": RegexReplace,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"StringConcatenate": "Concatenate",
|
||||
"StringSubstring": "Substring",
|
||||
"StringLength": "Length",
|
||||
"CaseConverter": "Case Converter",
|
||||
"StringTrim": "Trim",
|
||||
"StringReplace": "Replace",
|
||||
"StringContains": "Contains",
|
||||
"StringCompare": "Compare",
|
||||
"RegexMatch": "Regex Match",
|
||||
"RegexExtract": "Regex Extract",
|
||||
"RegexReplace": "Regex Replace",
|
||||
}
|
||||
class StringExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
StringConcatenate,
|
||||
StringSubstring,
|
||||
StringLength,
|
||||
CaseConverter,
|
||||
StringTrim,
|
||||
StringReplace,
|
||||
StringContains,
|
||||
StringCompare,
|
||||
RegexMatch,
|
||||
RegexExtract,
|
||||
RegexReplace,
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> StringExtension:
|
||||
return StringExtension()
|
||||
|
||||
+132
-154
@@ -5,52 +5,49 @@ import av
|
||||
import torch
|
||||
import folder_paths
|
||||
import json
|
||||
from typing import Optional, Literal
|
||||
from typing import Optional
|
||||
from typing_extensions import override
|
||||
from fractions import Fraction
|
||||
from comfy.comfy_types import IO, FileLocator, ComfyNodeABC
|
||||
from comfy_api.latest import Input, InputImpl, Types
|
||||
from comfy_api.input import AudioInput, ImageInput, VideoInput
|
||||
from comfy_api.input_impl import VideoFromComponents, VideoFromFile
|
||||
from comfy_api.util import VideoCodec, VideoComponents, VideoContainer
|
||||
from comfy_api.latest import ComfyExtension, io, ui
|
||||
from comfy.cli_args import args
|
||||
|
||||
class SaveWEBM:
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
self.type = "output"
|
||||
self.prefix_append = ""
|
||||
class SaveWEBM(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SaveWEBM",
|
||||
category="image/video",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Image.Input("images"),
|
||||
io.String.Input("filename_prefix", default="ComfyUI"),
|
||||
io.Combo.Input("codec", options=["vp9", "av1"]),
|
||||
io.Float.Input("fps", default=24.0, min=0.01, max=1000.0, step=0.01),
|
||||
io.Float.Input("crf", default=32.0, min=0, max=63.0, step=1, tooltip="Higher crf means lower quality with a smaller file size, lower crf means higher quality higher filesize."),
|
||||
],
|
||||
outputs=[],
|
||||
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
|
||||
is_output_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required":
|
||||
{"images": ("IMAGE", ),
|
||||
"filename_prefix": ("STRING", {"default": "ComfyUI"}),
|
||||
"codec": (["vp9", "av1"],),
|
||||
"fps": ("FLOAT", {"default": 24.0, "min": 0.01, "max": 1000.0, "step": 0.01}),
|
||||
"crf": ("FLOAT", {"default": 32.0, "min": 0, "max": 63.0, "step": 1, "tooltip": "Higher crf means lower quality with a smaller file size, lower crf means higher quality higher filesize."}),
|
||||
},
|
||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "save_images"
|
||||
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "image/video"
|
||||
|
||||
EXPERIMENTAL = True
|
||||
|
||||
def save_images(self, images, codec, fps, filename_prefix, crf, prompt=None, extra_pnginfo=None):
|
||||
filename_prefix += self.prefix_append
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
|
||||
def execute(cls, images, codec, fps, filename_prefix, crf) -> io.NodeOutput:
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
|
||||
filename_prefix, folder_paths.get_output_directory(), images[0].shape[1], images[0].shape[0]
|
||||
)
|
||||
|
||||
file = f"{filename}_{counter:05}_.webm"
|
||||
container = av.open(os.path.join(full_output_folder, file), mode="w")
|
||||
|
||||
if prompt is not None:
|
||||
container.metadata["prompt"] = json.dumps(prompt)
|
||||
if cls.hidden.prompt is not None:
|
||||
container.metadata["prompt"] = json.dumps(cls.hidden.prompt)
|
||||
|
||||
if extra_pnginfo is not None:
|
||||
for x in extra_pnginfo:
|
||||
container.metadata[x] = json.dumps(extra_pnginfo[x])
|
||||
if cls.hidden.extra_pnginfo is not None:
|
||||
for x in cls.hidden.extra_pnginfo:
|
||||
container.metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x])
|
||||
|
||||
codec_map = {"vp9": "libvpx-vp9", "av1": "libsvtav1"}
|
||||
stream = container.add_stream(codec_map[codec], rate=Fraction(round(fps * 1000), 1000))
|
||||
@@ -69,63 +66,46 @@ class SaveWEBM:
|
||||
container.mux(stream.encode())
|
||||
container.close()
|
||||
|
||||
results: list[FileLocator] = [{
|
||||
"filename": file,
|
||||
"subfolder": subfolder,
|
||||
"type": self.type
|
||||
}]
|
||||
return io.NodeOutput(ui=ui.PreviewVideo([ui.SavedResult(file, subfolder, io.FolderType.output)]))
|
||||
|
||||
return {"ui": {"images": results, "animated": (True,)}} # TODO: frontend side
|
||||
|
||||
class SaveVideo(ComfyNodeABC):
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
self.type: Literal["output"] = "output"
|
||||
self.prefix_append = ""
|
||||
class SaveVideo(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SaveVideo",
|
||||
display_name="Save Video",
|
||||
category="image/video",
|
||||
description="Saves the input images to your ComfyUI output directory.",
|
||||
inputs=[
|
||||
io.Video.Input("video", tooltip="The video to save."),
|
||||
io.String.Input("filename_prefix", default="video/ComfyUI", tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."),
|
||||
io.Combo.Input("format", options=VideoContainer.as_input(), default="auto", tooltip="The format to save the video as."),
|
||||
io.Combo.Input("codec", options=VideoCodec.as_input(), default="auto", tooltip="The codec to use for the video."),
|
||||
],
|
||||
outputs=[],
|
||||
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
|
||||
is_output_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"video": (IO.VIDEO, {"tooltip": "The video to save."}),
|
||||
"filename_prefix": ("STRING", {"default": "video/ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."}),
|
||||
"format": (Types.VideoContainer.as_input(), {"default": "auto", "tooltip": "The format to save the video as."}),
|
||||
"codec": (Types.VideoCodec.as_input(), {"default": "auto", "tooltip": "The codec to use for the video."}),
|
||||
},
|
||||
"hidden": {
|
||||
"prompt": "PROMPT",
|
||||
"extra_pnginfo": "EXTRA_PNGINFO"
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "save_video"
|
||||
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "image/video"
|
||||
DESCRIPTION = "Saves the input images to your ComfyUI output directory."
|
||||
|
||||
def save_video(self, video: Input.Video, filename_prefix, format, codec, prompt=None, extra_pnginfo=None):
|
||||
filename_prefix += self.prefix_append
|
||||
def execute(cls, video: VideoInput, filename_prefix, format, codec) -> io.NodeOutput:
|
||||
width, height = video.get_dimensions()
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
|
||||
filename_prefix,
|
||||
self.output_dir,
|
||||
folder_paths.get_output_directory(),
|
||||
width,
|
||||
height
|
||||
)
|
||||
results: list[FileLocator] = list()
|
||||
saved_metadata = None
|
||||
if not args.disable_metadata:
|
||||
metadata = {}
|
||||
if extra_pnginfo is not None:
|
||||
metadata.update(extra_pnginfo)
|
||||
if prompt is not None:
|
||||
metadata["prompt"] = prompt
|
||||
if cls.hidden.extra_pnginfo is not None:
|
||||
metadata.update(cls.hidden.extra_pnginfo)
|
||||
if cls.hidden.prompt is not None:
|
||||
metadata["prompt"] = cls.hidden.prompt
|
||||
if len(metadata) > 0:
|
||||
saved_metadata = metadata
|
||||
file = f"{filename}_{counter:05}_.{Types.VideoContainer.get_extension(format)}"
|
||||
file = f"{filename}_{counter:05}_.{VideoContainer.get_extension(format)}"
|
||||
video.save_to(
|
||||
os.path.join(full_output_folder, file),
|
||||
format=format,
|
||||
@@ -133,83 +113,82 @@ class SaveVideo(ComfyNodeABC):
|
||||
metadata=saved_metadata
|
||||
)
|
||||
|
||||
results.append({
|
||||
"filename": file,
|
||||
"subfolder": subfolder,
|
||||
"type": self.type
|
||||
})
|
||||
counter += 1
|
||||
return io.NodeOutput(ui=ui.PreviewVideo([ui.SavedResult(file, subfolder, io.FolderType.output)]))
|
||||
|
||||
return { "ui": { "images": results, "animated": (True,) } }
|
||||
|
||||
class CreateVideo(ComfyNodeABC):
|
||||
class CreateVideo(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"images": (IO.IMAGE, {"tooltip": "The images to create a video from."}),
|
||||
"fps": ("FLOAT", {"default": 30.0, "min": 1.0, "max": 120.0, "step": 1.0}),
|
||||
},
|
||||
"optional": {
|
||||
"audio": (IO.AUDIO, {"tooltip": "The audio to add to the video."}),
|
||||
}
|
||||
}
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="CreateVideo",
|
||||
display_name="Create Video",
|
||||
category="image/video",
|
||||
description="Create a video from images.",
|
||||
inputs=[
|
||||
io.Image.Input("images", tooltip="The images to create a video from."),
|
||||
io.Float.Input("fps", default=30.0, min=1.0, max=120.0, step=1.0),
|
||||
io.Audio.Input("audio", optional=True, tooltip="The audio to add to the video."),
|
||||
],
|
||||
outputs=[
|
||||
io.Video.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
RETURN_TYPES = (IO.VIDEO,)
|
||||
FUNCTION = "create_video"
|
||||
|
||||
CATEGORY = "image/video"
|
||||
DESCRIPTION = "Create a video from images."
|
||||
|
||||
def create_video(self, images: Input.Image, fps: float, audio: Optional[Input.Audio] = None):
|
||||
return (InputImpl.VideoFromComponents(
|
||||
Types.VideoComponents(
|
||||
images=images,
|
||||
audio=audio,
|
||||
frame_rate=Fraction(fps),
|
||||
)
|
||||
),)
|
||||
|
||||
class GetVideoComponents(ComfyNodeABC):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"video": (IO.VIDEO, {"tooltip": "The video to extract components from."}),
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = (IO.IMAGE, IO.AUDIO, IO.FLOAT)
|
||||
RETURN_NAMES = ("images", "audio", "fps")
|
||||
FUNCTION = "get_components"
|
||||
def execute(cls, images: ImageInput, fps: float, audio: Optional[AudioInput] = None) -> io.NodeOutput:
|
||||
return io.NodeOutput(
|
||||
VideoFromComponents(VideoComponents(images=images, audio=audio, frame_rate=Fraction(fps)))
|
||||
)
|
||||
|
||||
CATEGORY = "image/video"
|
||||
DESCRIPTION = "Extracts all components from a video: frames, audio, and framerate."
|
||||
class GetVideoComponents(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="GetVideoComponents",
|
||||
display_name="Get Video Components",
|
||||
category="image/video",
|
||||
description="Extracts all components from a video: frames, audio, and framerate.",
|
||||
inputs=[
|
||||
io.Video.Input("video", tooltip="The video to extract components from."),
|
||||
],
|
||||
outputs=[
|
||||
io.Image.Output(display_name="images"),
|
||||
io.Audio.Output(display_name="audio"),
|
||||
io.Float.Output(display_name="fps"),
|
||||
],
|
||||
)
|
||||
|
||||
def get_components(self, video: Input.Video):
|
||||
@classmethod
|
||||
def execute(cls, video: VideoInput) -> io.NodeOutput:
|
||||
components = video.get_components()
|
||||
|
||||
return (components.images, components.audio, float(components.frame_rate))
|
||||
return io.NodeOutput(components.images, components.audio, float(components.frame_rate))
|
||||
|
||||
class LoadVideo(ComfyNodeABC):
|
||||
class LoadVideo(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
def define_schema(cls):
|
||||
input_dir = folder_paths.get_input_directory()
|
||||
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
|
||||
files = folder_paths.filter_files_content_types(files, ["video"])
|
||||
return {"required":
|
||||
{"file": (sorted(files), {"video_upload": True})},
|
||||
}
|
||||
|
||||
CATEGORY = "image/video"
|
||||
|
||||
RETURN_TYPES = (IO.VIDEO,)
|
||||
FUNCTION = "load_video"
|
||||
def load_video(self, file):
|
||||
video_path = folder_paths.get_annotated_filepath(file)
|
||||
return (InputImpl.VideoFromFile(video_path),)
|
||||
return io.Schema(
|
||||
node_id="LoadVideo",
|
||||
display_name="Load Video",
|
||||
category="image/video",
|
||||
inputs=[
|
||||
io.Combo.Input("file", options=sorted(files), upload=io.UploadType.video),
|
||||
],
|
||||
outputs=[
|
||||
io.Video.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def IS_CHANGED(cls, file):
|
||||
def execute(cls, file) -> io.NodeOutput:
|
||||
video_path = folder_paths.get_annotated_filepath(file)
|
||||
return io.NodeOutput(VideoFromFile(video_path))
|
||||
|
||||
@classmethod
|
||||
def fingerprint_inputs(s, file):
|
||||
video_path = folder_paths.get_annotated_filepath(file)
|
||||
mod_time = os.path.getmtime(video_path)
|
||||
# Instead of hashing the file, we can just use the modification time to avoid
|
||||
@@ -217,24 +196,23 @@ class LoadVideo(ComfyNodeABC):
|
||||
return mod_time
|
||||
|
||||
@classmethod
|
||||
def VALIDATE_INPUTS(cls, file):
|
||||
def validate_inputs(s, file):
|
||||
if not folder_paths.exists_annotated_filepath(file):
|
||||
return "Invalid video file: {}".format(file)
|
||||
|
||||
return True
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"SaveWEBM": SaveWEBM,
|
||||
"SaveVideo": SaveVideo,
|
||||
"CreateVideo": CreateVideo,
|
||||
"GetVideoComponents": GetVideoComponents,
|
||||
"LoadVideo": LoadVideo,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"SaveVideo": "Save Video",
|
||||
"CreateVideo": "Create Video",
|
||||
"GetVideoComponents": "Get Video Components",
|
||||
"LoadVideo": "Load Video",
|
||||
}
|
||||
class VideoExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
SaveWEBM,
|
||||
SaveVideo,
|
||||
CreateVideo,
|
||||
GetVideoComponents,
|
||||
LoadVideo,
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> VideoExtension:
|
||||
return VideoExtension()
|
||||
|
||||
+239
-8
@@ -139,16 +139,21 @@ class Wan22FunControlToVideo(io.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def execute(cls, positive, negative, vae, width, height, length, batch_size, ref_image=None, start_image=None, control_video=None) -> io.NodeOutput:
|
||||
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent)
|
||||
spacial_scale = vae.spacial_compression_encode()
|
||||
latent_channels = vae.latent_channels
|
||||
latent = torch.zeros([batch_size, latent_channels, ((length - 1) // 4) + 1, height // spacial_scale, width // spacial_scale], device=comfy.model_management.intermediate_device())
|
||||
concat_latent = torch.zeros([batch_size, latent_channels, ((length - 1) // 4) + 1, height // spacial_scale, width // spacial_scale], device=comfy.model_management.intermediate_device())
|
||||
if latent_channels == 48:
|
||||
concat_latent = comfy.latent_formats.Wan22().process_out(concat_latent)
|
||||
else:
|
||||
concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent)
|
||||
concat_latent = concat_latent.repeat(1, 2, 1, 1, 1)
|
||||
mask = torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1]))
|
||||
|
||||
if start_image is not None:
|
||||
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
concat_latent_image = vae.encode(start_image[:, :, :, :3])
|
||||
concat_latent[:,16:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]
|
||||
concat_latent[:,latent_channels:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]
|
||||
mask[:, :, :start_image.shape[0] + 3] = 0.0
|
||||
|
||||
ref_latent = None
|
||||
@@ -159,11 +164,11 @@ class Wan22FunControlToVideo(io.ComfyNode):
|
||||
if control_video is not None:
|
||||
control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
concat_latent_image = vae.encode(control_video[:, :, :, :3])
|
||||
concat_latent[:,:16,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]
|
||||
concat_latent[:,:latent_channels,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]
|
||||
|
||||
mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2)
|
||||
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent, "concat_mask": mask, "concat_mask_index": 16})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent, "concat_mask": mask, "concat_mask_index": 16})
|
||||
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent, "concat_mask": mask, "concat_mask_index": latent_channels})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent, "concat_mask": mask, "concat_mask_index": latent_channels})
|
||||
|
||||
if ref_latent is not None:
|
||||
positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True)
|
||||
@@ -201,7 +206,8 @@ class WanFirstLastFrameToVideo(io.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_start_image=None, clip_vision_end_image=None) -> io.NodeOutput:
|
||||
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
spacial_scale = vae.spacial_compression_encode()
|
||||
latent = torch.zeros([batch_size, vae.latent_channels, ((length - 1) // 4) + 1, height // spacial_scale, width // spacial_scale], device=comfy.model_management.intermediate_device())
|
||||
if start_image is not None:
|
||||
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
if end_image is not None:
|
||||
@@ -786,6 +792,229 @@ class WanTrackToVideo(io.ComfyNode):
|
||||
return io.NodeOutput(positive, negative, out_latent)
|
||||
|
||||
|
||||
def linear_interpolation(features, input_fps, output_fps, output_len=None):
|
||||
"""
|
||||
features: shape=[1, T, 512]
|
||||
input_fps: fps for audio, f_a
|
||||
output_fps: fps for video, f_m
|
||||
output_len: video length
|
||||
"""
|
||||
features = features.transpose(1, 2) # [1, 512, T]
|
||||
seq_len = features.shape[2] / float(input_fps) # T/f_a
|
||||
if output_len is None:
|
||||
output_len = int(seq_len * output_fps) # f_m*T/f_a
|
||||
output_features = torch.nn.functional.interpolate(
|
||||
features, size=output_len, align_corners=True,
|
||||
mode='linear') # [1, 512, output_len]
|
||||
return output_features.transpose(1, 2) # [1, output_len, 512]
|
||||
|
||||
|
||||
def get_sample_indices(original_fps,
|
||||
total_frames,
|
||||
target_fps,
|
||||
num_sample,
|
||||
fixed_start=None):
|
||||
required_duration = num_sample / target_fps
|
||||
required_origin_frames = int(np.ceil(required_duration * original_fps))
|
||||
if required_duration > total_frames / original_fps:
|
||||
raise ValueError("required_duration must be less than video length")
|
||||
|
||||
if not fixed_start is None and fixed_start >= 0:
|
||||
start_frame = fixed_start
|
||||
else:
|
||||
max_start = total_frames - required_origin_frames
|
||||
if max_start < 0:
|
||||
raise ValueError("video length is too short")
|
||||
start_frame = np.random.randint(0, max_start + 1)
|
||||
start_time = start_frame / original_fps
|
||||
|
||||
end_time = start_time + required_duration
|
||||
time_points = np.linspace(start_time, end_time, num_sample, endpoint=False)
|
||||
|
||||
frame_indices = np.round(np.array(time_points) * original_fps).astype(int)
|
||||
frame_indices = np.clip(frame_indices, 0, total_frames - 1)
|
||||
return frame_indices
|
||||
|
||||
|
||||
def get_audio_embed_bucket_fps(audio_embed, fps=16, batch_frames=81, m=0, video_rate=30):
|
||||
num_layers, audio_frame_num, audio_dim = audio_embed.shape
|
||||
|
||||
if num_layers > 1:
|
||||
return_all_layers = True
|
||||
else:
|
||||
return_all_layers = False
|
||||
|
||||
scale = video_rate / fps
|
||||
|
||||
min_batch_num = int(audio_frame_num / (batch_frames * scale)) + 1
|
||||
|
||||
bucket_num = min_batch_num * batch_frames
|
||||
padd_audio_num = math.ceil(min_batch_num * batch_frames / fps * video_rate) - audio_frame_num
|
||||
batch_idx = get_sample_indices(
|
||||
original_fps=video_rate,
|
||||
total_frames=audio_frame_num + padd_audio_num,
|
||||
target_fps=fps,
|
||||
num_sample=bucket_num,
|
||||
fixed_start=0)
|
||||
batch_audio_eb = []
|
||||
audio_sample_stride = int(video_rate / fps)
|
||||
for bi in batch_idx:
|
||||
if bi < audio_frame_num:
|
||||
|
||||
chosen_idx = list(
|
||||
range(bi - m * audio_sample_stride, bi + (m + 1) * audio_sample_stride, audio_sample_stride))
|
||||
chosen_idx = [0 if c < 0 else c for c in chosen_idx]
|
||||
chosen_idx = [
|
||||
audio_frame_num - 1 if c >= audio_frame_num else c
|
||||
for c in chosen_idx
|
||||
]
|
||||
|
||||
if return_all_layers:
|
||||
frame_audio_embed = audio_embed[:, chosen_idx].flatten(
|
||||
start_dim=-2, end_dim=-1)
|
||||
else:
|
||||
frame_audio_embed = audio_embed[0][chosen_idx].flatten()
|
||||
else:
|
||||
frame_audio_embed = torch.zeros([audio_dim * (2 * m + 1)], device=audio_embed.device) if not return_all_layers \
|
||||
else torch.zeros([num_layers, audio_dim * (2 * m + 1)], device=audio_embed.device)
|
||||
batch_audio_eb.append(frame_audio_embed)
|
||||
batch_audio_eb = torch.cat([c.unsqueeze(0) for c in batch_audio_eb], dim=0)
|
||||
|
||||
return batch_audio_eb, min_batch_num
|
||||
|
||||
|
||||
def wan_sound_to_video(positive, negative, vae, width, height, length, batch_size, frame_offset=0, ref_image=None, audio_encoder_output=None, control_video=None, ref_motion=None, ref_motion_latent=None):
|
||||
latent_t = ((length - 1) // 4) + 1
|
||||
if audio_encoder_output is not None:
|
||||
feat = torch.cat(audio_encoder_output["encoded_audio_all_layers"])
|
||||
video_rate = 30
|
||||
fps = 16
|
||||
feat = linear_interpolation(feat, input_fps=50, output_fps=video_rate)
|
||||
batch_frames = latent_t * 4
|
||||
audio_embed_bucket, num_repeat = get_audio_embed_bucket_fps(feat, fps=fps, batch_frames=batch_frames, m=0, video_rate=video_rate)
|
||||
audio_embed_bucket = audio_embed_bucket.unsqueeze(0)
|
||||
if len(audio_embed_bucket.shape) == 3:
|
||||
audio_embed_bucket = audio_embed_bucket.permute(0, 2, 1)
|
||||
elif len(audio_embed_bucket.shape) == 4:
|
||||
audio_embed_bucket = audio_embed_bucket.permute(0, 2, 3, 1)
|
||||
|
||||
audio_embed_bucket = audio_embed_bucket[:, :, :, frame_offset:frame_offset + batch_frames]
|
||||
if audio_embed_bucket.shape[3] > 0:
|
||||
positive = node_helpers.conditioning_set_values(positive, {"audio_embed": audio_embed_bucket})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"audio_embed": audio_embed_bucket * 0.0})
|
||||
frame_offset += batch_frames
|
||||
|
||||
if ref_image is not None:
|
||||
ref_image = comfy.utils.common_upscale(ref_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
ref_latent = vae.encode(ref_image[:, :, :, :3])
|
||||
positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True)
|
||||
negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [ref_latent]}, append=True)
|
||||
|
||||
if ref_motion is not None:
|
||||
if ref_motion.shape[0] > 73:
|
||||
ref_motion = ref_motion[-73:]
|
||||
|
||||
ref_motion = comfy.utils.common_upscale(ref_motion.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
|
||||
if ref_motion.shape[0] < 73:
|
||||
r = torch.ones([73, height, width, 3]) * 0.5
|
||||
r[-ref_motion.shape[0]:] = ref_motion
|
||||
ref_motion = r
|
||||
|
||||
ref_motion_latent = vae.encode(ref_motion[:, :, :, :3])
|
||||
|
||||
if ref_motion_latent is not None:
|
||||
ref_motion_latent = ref_motion_latent[:, :, -19:]
|
||||
positive = node_helpers.conditioning_set_values(positive, {"reference_motion": ref_motion_latent})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"reference_motion": ref_motion_latent})
|
||||
|
||||
latent = torch.zeros([batch_size, 16, latent_t, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
|
||||
control_video_out = comfy.latent_formats.Wan21().process_out(torch.zeros_like(latent))
|
||||
if control_video is not None:
|
||||
control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
control_video = vae.encode(control_video[:, :, :, :3])
|
||||
control_video_out[:, :, :control_video.shape[2]] = control_video
|
||||
|
||||
# TODO: check if zero is better than none if none provided
|
||||
positive = node_helpers.conditioning_set_values(positive, {"control_video": control_video_out})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"control_video": control_video_out})
|
||||
|
||||
out_latent = {}
|
||||
out_latent["samples"] = latent
|
||||
return positive, negative, out_latent, frame_offset
|
||||
|
||||
|
||||
class WanSoundImageToVideo(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="WanSoundImageToVideo",
|
||||
category="conditioning/video_models",
|
||||
inputs=[
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
io.Vae.Input("vae"),
|
||||
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
io.Int.Input("length", default=77, min=1, max=nodes.MAX_RESOLUTION, step=4),
|
||||
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||
io.AudioEncoderOutput.Input("audio_encoder_output", optional=True),
|
||||
io.Image.Input("ref_image", optional=True),
|
||||
io.Image.Input("control_video", optional=True),
|
||||
io.Image.Input("ref_motion", optional=True),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(display_name="positive"),
|
||||
io.Conditioning.Output(display_name="negative"),
|
||||
io.Latent.Output(display_name="latent"),
|
||||
],
|
||||
is_experimental=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, positive, negative, vae, width, height, length, batch_size, ref_image=None, audio_encoder_output=None, control_video=None, ref_motion=None) -> io.NodeOutput:
|
||||
positive, negative, out_latent, frame_offset = wan_sound_to_video(positive, negative, vae, width, height, length, batch_size, ref_image=ref_image, audio_encoder_output=audio_encoder_output,
|
||||
control_video=control_video, ref_motion=ref_motion)
|
||||
return io.NodeOutput(positive, negative, out_latent)
|
||||
|
||||
|
||||
class WanSoundImageToVideoExtend(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="WanSoundImageToVideoExtend",
|
||||
category="conditioning/video_models",
|
||||
inputs=[
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
io.Vae.Input("vae"),
|
||||
io.Int.Input("length", default=77, min=1, max=nodes.MAX_RESOLUTION, step=4),
|
||||
io.Latent.Input("video_latent"),
|
||||
io.AudioEncoderOutput.Input("audio_encoder_output", optional=True),
|
||||
io.Image.Input("ref_image", optional=True),
|
||||
io.Image.Input("control_video", optional=True),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(display_name="positive"),
|
||||
io.Conditioning.Output(display_name="negative"),
|
||||
io.Latent.Output(display_name="latent"),
|
||||
],
|
||||
is_experimental=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, positive, negative, vae, length, video_latent, ref_image=None, audio_encoder_output=None, control_video=None) -> io.NodeOutput:
|
||||
video_latent = video_latent["samples"]
|
||||
width = video_latent.shape[-1] * 8
|
||||
height = video_latent.shape[-2] * 8
|
||||
batch_size = video_latent.shape[0]
|
||||
frame_offset = video_latent.shape[-3] * 4
|
||||
positive, negative, out_latent, frame_offset = wan_sound_to_video(positive, negative, vae, width, height, length, batch_size, frame_offset=frame_offset, ref_image=ref_image, audio_encoder_output=audio_encoder_output,
|
||||
control_video=control_video, ref_motion=None, ref_motion_latent=video_latent)
|
||||
return io.NodeOutput(positive, negative, out_latent)
|
||||
|
||||
|
||||
class Wan22ImageToVideoLatent(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
@@ -844,6 +1073,8 @@ class WanExtension(ComfyExtension):
|
||||
TrimVideoLatent,
|
||||
WanCameraImageToVideo,
|
||||
WanPhantomSubjectToVideo,
|
||||
WanSoundImageToVideo,
|
||||
WanSoundImageToVideoExtend,
|
||||
Wan22ImageToVideoLatent,
|
||||
]
|
||||
|
||||
|
||||
+1
-1
@@ -1,3 +1,3 @@
|
||||
# This file is automatically generated by the build process when version is
|
||||
# updated in pyproject.toml.
|
||||
__version__ = "0.3.51"
|
||||
__version__ = "0.3.59"
|
||||
|
||||
@@ -46,6 +46,10 @@ folder_names_and_paths["photomaker"] = ([os.path.join(models_dir, "photomaker")]
|
||||
|
||||
folder_names_and_paths["classifiers"] = ([os.path.join(models_dir, "classifiers")], {""})
|
||||
|
||||
folder_names_and_paths["model_patches"] = ([os.path.join(models_dir, "model_patches")], supported_pt_extensions)
|
||||
|
||||
folder_names_and_paths["audio_encoders"] = ([os.path.join(models_dir, "audio_encoders")], supported_pt_extensions)
|
||||
|
||||
output_directory = os.path.join(base_path, "output")
|
||||
temp_directory = os.path.join(base_path, "temp")
|
||||
input_directory = os.path.join(base_path, "input")
|
||||
|
||||
@@ -112,7 +112,7 @@ import gc
|
||||
|
||||
|
||||
if os.name == "nt":
|
||||
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
|
||||
os.environ['MIMALLOC_PURGE_DELAY'] = '0'
|
||||
|
||||
if __name__ == "__main__":
|
||||
if args.default_device is not None:
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
"""Server middleware modules"""
|
||||
@@ -0,0 +1,52 @@
|
||||
"""Cache control middleware for ComfyUI server"""
|
||||
|
||||
from aiohttp import web
|
||||
from typing import Callable, Awaitable
|
||||
|
||||
# Time in seconds
|
||||
ONE_HOUR: int = 3600
|
||||
ONE_DAY: int = 86400
|
||||
IMG_EXTENSIONS = (
|
||||
".jpg",
|
||||
".jpeg",
|
||||
".png",
|
||||
".ppm",
|
||||
".bmp",
|
||||
".pgm",
|
||||
".tif",
|
||||
".tiff",
|
||||
".webp",
|
||||
)
|
||||
|
||||
|
||||
@web.middleware
|
||||
async def cache_control(
|
||||
request: web.Request, handler: Callable[[web.Request], Awaitable[web.Response]]
|
||||
) -> web.Response:
|
||||
"""Cache control middleware that sets appropriate cache headers based on file type and response status"""
|
||||
response: web.Response = await handler(request)
|
||||
|
||||
if (
|
||||
request.path.endswith(".js")
|
||||
or request.path.endswith(".css")
|
||||
or request.path.endswith("index.json")
|
||||
):
|
||||
response.headers.setdefault("Cache-Control", "no-cache")
|
||||
return response
|
||||
|
||||
# Early return for non-image files - no cache headers needed
|
||||
if not request.path.lower().endswith(IMG_EXTENSIONS):
|
||||
return response
|
||||
|
||||
# Handle image files
|
||||
if response.status == 404:
|
||||
response.headers.setdefault("Cache-Control", f"public, max-age={ONE_HOUR}")
|
||||
elif response.status in (200, 201, 202, 203, 204, 205, 206, 301, 308):
|
||||
# Success responses and permanent redirects - cache for 1 day
|
||||
response.headers.setdefault("Cache-Control", f"public, max-age={ONE_DAY}")
|
||||
elif response.status in (302, 303, 307):
|
||||
# Temporary redirects - no cache
|
||||
response.headers.setdefault("Cache-Control", "no-cache")
|
||||
# Note: 304 Not Modified falls through - no cache headers set
|
||||
|
||||
return response
|
||||
@@ -925,7 +925,7 @@ class CLIPLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
|
||||
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image"], ),
|
||||
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image"], ),
|
||||
},
|
||||
"optional": {
|
||||
"device": (["default", "cpu"], {"advanced": True}),
|
||||
@@ -953,7 +953,7 @@ class DualCLIPLoader:
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ),
|
||||
"clip_name2": (folder_paths.get_filename_list("text_encoders"), ),
|
||||
"type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream"], ),
|
||||
"type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image"], ),
|
||||
},
|
||||
"optional": {
|
||||
"device": (["default", "cpu"], {"advanced": True}),
|
||||
@@ -963,7 +963,7 @@ class DualCLIPLoader:
|
||||
|
||||
CATEGORY = "advanced/loaders"
|
||||
|
||||
DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5\nhidream: at least one of t5 or llama, recommended t5 and llama"
|
||||
DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5\nhidream: at least one of t5 or llama, recommended t5 and llama\nhunyuan_image: qwen2.5vl 7b and byt5 small"
|
||||
|
||||
def load_clip(self, clip_name1, clip_name2, type, device="default"):
|
||||
clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)
|
||||
@@ -2322,6 +2322,9 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_tcfg.py",
|
||||
"nodes_context_windows.py",
|
||||
"nodes_qwen.py",
|
||||
"nodes_model_patch.py",
|
||||
"nodes_easycache.py",
|
||||
"nodes_audio_encoder.py",
|
||||
]
|
||||
|
||||
import_failed = []
|
||||
@@ -2341,6 +2344,7 @@ async def init_builtin_api_nodes():
|
||||
"nodes_veo2.py",
|
||||
"nodes_kling.py",
|
||||
"nodes_bfl.py",
|
||||
"nodes_bytedance.py",
|
||||
"nodes_luma.py",
|
||||
"nodes_recraft.py",
|
||||
"nodes_pixverse.py",
|
||||
|
||||
+1
-1
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "ComfyUI"
|
||||
version = "0.3.51"
|
||||
version = "0.3.59"
|
||||
readme = "README.md"
|
||||
license = { file = "LICENSE" }
|
||||
requires-python = ">=3.9"
|
||||
|
||||
+2
-2
@@ -1,5 +1,5 @@
|
||||
comfyui-frontend-package==1.25.9
|
||||
comfyui-workflow-templates==0.1.62
|
||||
comfyui-frontend-package==1.25.11
|
||||
comfyui-workflow-templates==0.1.78
|
||||
comfyui-embedded-docs==0.2.6
|
||||
torch
|
||||
torchsde
|
||||
|
||||
@@ -3,11 +3,7 @@ from urllib import request
|
||||
|
||||
#This is the ComfyUI api prompt format.
|
||||
|
||||
#If you want it for a specific workflow you can "enable dev mode options"
|
||||
#in the settings of the UI (gear beside the "Queue Size: ") this will enable
|
||||
#a button on the UI to save workflows in api format.
|
||||
|
||||
#keep in mind ComfyUI is pre alpha software so this format will change a bit.
|
||||
#If you want it for a specific workflow you can "File -> Export (API)" in the interface.
|
||||
|
||||
#this is the one for the default workflow
|
||||
prompt_text = """
|
||||
|
||||
@@ -39,20 +39,15 @@ from typing import Optional, Union
|
||||
from api_server.routes.internal.internal_routes import InternalRoutes
|
||||
from protocol import BinaryEventTypes
|
||||
|
||||
# Import cache control middleware
|
||||
from middleware.cache_middleware import cache_control
|
||||
|
||||
async def send_socket_catch_exception(function, message):
|
||||
try:
|
||||
await function(message)
|
||||
except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError, BrokenPipeError, ConnectionError) as err:
|
||||
logging.warning("send error: {}".format(err))
|
||||
|
||||
@web.middleware
|
||||
async def cache_control(request: web.Request, handler):
|
||||
response: web.Response = await handler(request)
|
||||
if request.path.endswith('.js') or request.path.endswith('.css') or request.path.endswith('index.json'):
|
||||
response.headers.setdefault('Cache-Control', 'no-cache')
|
||||
return response
|
||||
|
||||
|
||||
@web.middleware
|
||||
async def compress_body(request: web.Request, handler):
|
||||
accept_encoding = request.headers.get("Accept-Encoding", "")
|
||||
@@ -729,7 +724,34 @@ class PromptServer():
|
||||
|
||||
@routes.post("/interrupt")
|
||||
async def post_interrupt(request):
|
||||
nodes.interrupt_processing()
|
||||
try:
|
||||
json_data = await request.json()
|
||||
except json.JSONDecodeError:
|
||||
json_data = {}
|
||||
|
||||
# Check if a specific prompt_id was provided for targeted interruption
|
||||
prompt_id = json_data.get('prompt_id')
|
||||
if prompt_id:
|
||||
currently_running, _ = self.prompt_queue.get_current_queue()
|
||||
|
||||
# Check if the prompt_id matches any currently running prompt
|
||||
should_interrupt = False
|
||||
for item in currently_running:
|
||||
# item structure: (number, prompt_id, prompt, extra_data, outputs_to_execute)
|
||||
if item[1] == prompt_id:
|
||||
logging.info(f"Interrupting prompt {prompt_id}")
|
||||
should_interrupt = True
|
||||
break
|
||||
|
||||
if should_interrupt:
|
||||
nodes.interrupt_processing()
|
||||
else:
|
||||
logging.info(f"Prompt {prompt_id} is not currently running, skipping interrupt")
|
||||
else:
|
||||
# No prompt_id provided, do a global interrupt
|
||||
logging.info("Global interrupt (no prompt_id specified)")
|
||||
nodes.interrupt_processing()
|
||||
|
||||
return web.Response(status=200)
|
||||
|
||||
@routes.post("/free")
|
||||
|
||||
@@ -0,0 +1,255 @@
|
||||
"""Tests for server cache control middleware"""
|
||||
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
from aiohttp.test_utils import make_mocked_request
|
||||
from typing import Dict, Any
|
||||
|
||||
from middleware.cache_middleware import cache_control, ONE_HOUR, ONE_DAY, IMG_EXTENSIONS
|
||||
|
||||
pytestmark = pytest.mark.asyncio # Apply asyncio mark to all tests
|
||||
|
||||
# Test configuration data
|
||||
CACHE_SCENARIOS = [
|
||||
# Image file scenarios
|
||||
{
|
||||
"name": "image_200_status",
|
||||
"path": "/test.jpg",
|
||||
"status": 200,
|
||||
"expected_cache": f"public, max-age={ONE_DAY}",
|
||||
"should_have_header": True,
|
||||
},
|
||||
{
|
||||
"name": "image_404_status",
|
||||
"path": "/missing.jpg",
|
||||
"status": 404,
|
||||
"expected_cache": f"public, max-age={ONE_HOUR}",
|
||||
"should_have_header": True,
|
||||
},
|
||||
# JavaScript/CSS scenarios
|
||||
{
|
||||
"name": "js_no_cache",
|
||||
"path": "/script.js",
|
||||
"status": 200,
|
||||
"expected_cache": "no-cache",
|
||||
"should_have_header": True,
|
||||
},
|
||||
{
|
||||
"name": "css_no_cache",
|
||||
"path": "/styles.css",
|
||||
"status": 200,
|
||||
"expected_cache": "no-cache",
|
||||
"should_have_header": True,
|
||||
},
|
||||
{
|
||||
"name": "index_json_no_cache",
|
||||
"path": "/api/index.json",
|
||||
"status": 200,
|
||||
"expected_cache": "no-cache",
|
||||
"should_have_header": True,
|
||||
},
|
||||
# Non-matching files
|
||||
{
|
||||
"name": "html_no_header",
|
||||
"path": "/index.html",
|
||||
"status": 200,
|
||||
"expected_cache": None,
|
||||
"should_have_header": False,
|
||||
},
|
||||
{
|
||||
"name": "txt_no_header",
|
||||
"path": "/data.txt",
|
||||
"status": 200,
|
||||
"expected_cache": None,
|
||||
"should_have_header": False,
|
||||
},
|
||||
{
|
||||
"name": "api_endpoint_no_header",
|
||||
"path": "/api/endpoint",
|
||||
"status": 200,
|
||||
"expected_cache": None,
|
||||
"should_have_header": False,
|
||||
},
|
||||
{
|
||||
"name": "pdf_no_header",
|
||||
"path": "/file.pdf",
|
||||
"status": 200,
|
||||
"expected_cache": None,
|
||||
"should_have_header": False,
|
||||
},
|
||||
]
|
||||
|
||||
# Status code scenarios for images
|
||||
IMAGE_STATUS_SCENARIOS = [
|
||||
# Success statuses get long cache
|
||||
{"status": 200, "expected": f"public, max-age={ONE_DAY}"},
|
||||
{"status": 201, "expected": f"public, max-age={ONE_DAY}"},
|
||||
{"status": 202, "expected": f"public, max-age={ONE_DAY}"},
|
||||
{"status": 204, "expected": f"public, max-age={ONE_DAY}"},
|
||||
{"status": 206, "expected": f"public, max-age={ONE_DAY}"},
|
||||
# Permanent redirects get long cache
|
||||
{"status": 301, "expected": f"public, max-age={ONE_DAY}"},
|
||||
{"status": 308, "expected": f"public, max-age={ONE_DAY}"},
|
||||
# Temporary redirects get no cache
|
||||
{"status": 302, "expected": "no-cache"},
|
||||
{"status": 303, "expected": "no-cache"},
|
||||
{"status": 307, "expected": "no-cache"},
|
||||
# 404 gets short cache
|
||||
{"status": 404, "expected": f"public, max-age={ONE_HOUR}"},
|
||||
]
|
||||
|
||||
# Case sensitivity test paths
|
||||
CASE_SENSITIVITY_PATHS = ["/image.JPG", "/photo.PNG", "/pic.JpEg"]
|
||||
|
||||
# Edge case test paths
|
||||
EDGE_CASE_PATHS = [
|
||||
{
|
||||
"name": "query_strings_ignored",
|
||||
"path": "/image.jpg?v=123&size=large",
|
||||
"expected": f"public, max-age={ONE_DAY}",
|
||||
},
|
||||
{
|
||||
"name": "multiple_dots_in_path",
|
||||
"path": "/image.min.jpg",
|
||||
"expected": f"public, max-age={ONE_DAY}",
|
||||
},
|
||||
{
|
||||
"name": "nested_paths_with_images",
|
||||
"path": "/static/images/photo.jpg",
|
||||
"expected": f"public, max-age={ONE_DAY}",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
class TestCacheControl:
|
||||
"""Test cache control middleware functionality"""
|
||||
|
||||
@pytest.fixture
|
||||
def status_handler_factory(self):
|
||||
"""Create a factory for handlers that return specific status codes"""
|
||||
|
||||
def factory(status: int, headers: Dict[str, str] = None):
|
||||
async def handler(request):
|
||||
return web.Response(status=status, headers=headers or {})
|
||||
|
||||
return handler
|
||||
|
||||
return factory
|
||||
|
||||
@pytest.fixture
|
||||
def mock_handler(self, status_handler_factory):
|
||||
"""Create a mock handler that returns a response with 200 status"""
|
||||
return status_handler_factory(200)
|
||||
|
||||
@pytest.fixture
|
||||
def handler_with_existing_cache(self, status_handler_factory):
|
||||
"""Create a handler that returns response with existing Cache-Control header"""
|
||||
return status_handler_factory(200, {"Cache-Control": "max-age=3600"})
|
||||
|
||||
async def assert_cache_header(
|
||||
self,
|
||||
response: web.Response,
|
||||
expected_cache: str = None,
|
||||
should_have_header: bool = True,
|
||||
):
|
||||
"""Helper to assert cache control headers"""
|
||||
if should_have_header:
|
||||
assert "Cache-Control" in response.headers
|
||||
if expected_cache:
|
||||
assert response.headers["Cache-Control"] == expected_cache
|
||||
else:
|
||||
assert "Cache-Control" not in response.headers
|
||||
|
||||
# Parameterized tests
|
||||
@pytest.mark.parametrize("scenario", CACHE_SCENARIOS, ids=lambda x: x["name"])
|
||||
async def test_cache_control_scenarios(
|
||||
self, scenario: Dict[str, Any], status_handler_factory
|
||||
):
|
||||
"""Test various cache control scenarios"""
|
||||
handler = status_handler_factory(scenario["status"])
|
||||
request = make_mocked_request("GET", scenario["path"])
|
||||
response = await cache_control(request, handler)
|
||||
|
||||
assert response.status == scenario["status"]
|
||||
await self.assert_cache_header(
|
||||
response, scenario["expected_cache"], scenario["should_have_header"]
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("ext", IMG_EXTENSIONS)
|
||||
async def test_all_image_extensions(self, ext: str, mock_handler):
|
||||
"""Test all defined image extensions are handled correctly"""
|
||||
request = make_mocked_request("GET", f"/image{ext}")
|
||||
response = await cache_control(request, mock_handler)
|
||||
|
||||
assert response.status == 200
|
||||
assert "Cache-Control" in response.headers
|
||||
assert response.headers["Cache-Control"] == f"public, max-age={ONE_DAY}"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"status_scenario", IMAGE_STATUS_SCENARIOS, ids=lambda x: f"status_{x['status']}"
|
||||
)
|
||||
async def test_image_status_codes(
|
||||
self, status_scenario: Dict[str, Any], status_handler_factory
|
||||
):
|
||||
"""Test different status codes for image requests"""
|
||||
handler = status_handler_factory(status_scenario["status"])
|
||||
request = make_mocked_request("GET", "/image.jpg")
|
||||
response = await cache_control(request, handler)
|
||||
|
||||
assert response.status == status_scenario["status"]
|
||||
assert "Cache-Control" in response.headers
|
||||
assert response.headers["Cache-Control"] == status_scenario["expected"]
|
||||
|
||||
@pytest.mark.parametrize("path", CASE_SENSITIVITY_PATHS)
|
||||
async def test_case_insensitive_image_extension(self, path: str, mock_handler):
|
||||
"""Test that image extensions are matched case-insensitively"""
|
||||
request = make_mocked_request("GET", path)
|
||||
response = await cache_control(request, mock_handler)
|
||||
|
||||
assert "Cache-Control" in response.headers
|
||||
assert response.headers["Cache-Control"] == f"public, max-age={ONE_DAY}"
|
||||
|
||||
@pytest.mark.parametrize("edge_case", EDGE_CASE_PATHS, ids=lambda x: x["name"])
|
||||
async def test_edge_cases(self, edge_case: Dict[str, str], mock_handler):
|
||||
"""Test edge cases like query strings, nested paths, etc."""
|
||||
request = make_mocked_request("GET", edge_case["path"])
|
||||
response = await cache_control(request, mock_handler)
|
||||
|
||||
assert "Cache-Control" in response.headers
|
||||
assert response.headers["Cache-Control"] == edge_case["expected"]
|
||||
|
||||
# Header preservation tests (special cases not covered by parameterization)
|
||||
async def test_js_preserves_existing_headers(self, handler_with_existing_cache):
|
||||
"""Test that .js files preserve existing Cache-Control headers"""
|
||||
request = make_mocked_request("GET", "/script.js")
|
||||
response = await cache_control(request, handler_with_existing_cache)
|
||||
|
||||
# setdefault should preserve existing header
|
||||
assert response.headers["Cache-Control"] == "max-age=3600"
|
||||
|
||||
async def test_css_preserves_existing_headers(self, handler_with_existing_cache):
|
||||
"""Test that .css files preserve existing Cache-Control headers"""
|
||||
request = make_mocked_request("GET", "/styles.css")
|
||||
response = await cache_control(request, handler_with_existing_cache)
|
||||
|
||||
# setdefault should preserve existing header
|
||||
assert response.headers["Cache-Control"] == "max-age=3600"
|
||||
|
||||
async def test_image_preserves_existing_headers(self, status_handler_factory):
|
||||
"""Test that image cache headers preserve existing Cache-Control"""
|
||||
handler = status_handler_factory(200, {"Cache-Control": "private, no-cache"})
|
||||
request = make_mocked_request("GET", "/image.jpg")
|
||||
response = await cache_control(request, handler)
|
||||
|
||||
# setdefault should preserve existing header
|
||||
assert response.headers["Cache-Control"] == "private, no-cache"
|
||||
|
||||
async def test_304_not_modified_inherits_cache(self, status_handler_factory):
|
||||
"""Test that 304 Not Modified doesn't set cache headers for images"""
|
||||
handler = status_handler_factory(304, {"Cache-Control": "max-age=7200"})
|
||||
request = make_mocked_request("GET", "/not-modified.jpg")
|
||||
response = await cache_control(request, handler)
|
||||
|
||||
assert response.status == 304
|
||||
# Should preserve existing cache header, not override
|
||||
assert response.headers["Cache-Control"] == "max-age=7200"
|
||||
@@ -6,6 +6,7 @@ def pytest_addoption(parser):
|
||||
parser.addoption('--output_dir', action="store", default='tests/inference/samples', help='Output directory for generated images')
|
||||
parser.addoption("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)")
|
||||
parser.addoption("--port", type=int, default=8188, help="Set the listen port.")
|
||||
parser.addoption("--skip-timing-checks", action="store_true", default=False, help="Skip timing-related assertions in tests (useful for CI environments with variable performance)")
|
||||
|
||||
# This initializes args at the beginning of the test session
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
@@ -19,6 +20,11 @@ def args_pytest(pytestconfig):
|
||||
|
||||
return args
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def skip_timing_checks(pytestconfig):
|
||||
"""Fixture that returns whether timing checks should be skipped."""
|
||||
return pytestconfig.getoption("--skip-timing-checks")
|
||||
|
||||
def pytest_collection_modifyitems(items):
|
||||
# Modifies items so tests run in the correct order
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user