Compare commits

..

31 Commits

Author SHA1 Message Date
comfyanonymous 51696e3fdc ComfyUI version 0.3.65
Python Linting / Run Ruff (push) Failing after 33s
Python Linting / Run Pylint (push) Failing after 35s
Build package / Build Test (3.10) (push) Failing after 29s
Build package / Build Test (3.11) (push) Failing after 29s
Build package / Build Test (3.12) (push) Failing after 34s
Build package / Build Test (3.13) (push) Failing after 33s
Build package / Build Test (3.9) (push) Failing after 35s
2025-10-13 23:39:55 -04:00
comfyanonymous dfff7e5332 Better memory estimation for the SD/Flux VAE on AMD. (#10334) 2025-10-13 22:37:19 -04:00
comfyanonymous e4ea393666 Fix loading old stable diffusion ckpt files on newer numpy. (#10333) 2025-10-13 22:18:58 -04:00
comfyanonymous c8674bc6e9 Enable RDNA4 pytorch attention on ROCm 7.0 and up. (#10332) 2025-10-13 21:19:03 -04:00
Alexander Piskun 3dfdcf66b6 convert nodes_hunyuan.py to V3 schema (#10136) 2025-10-13 12:36:26 -07:00
rattus128 95ca2e56c8 WAN2.2: Fix cache VRAM leak on error (#10308)
Same change pattern as 7e8dd275c243ad460ed5015d2e13611d81d2a569
applied to WAN2.2

If this suffers an exception (such as a VRAM oom) it will leave the
encode() and decode() methods which skips the cleanup of the WAN
feature cache. The comfy node cache then ultimately keeps a reference
this object which is in turn reffing large tensors from the failed
execution.

The feature cache is currently setup at a class variable on the
encoder/decoder however, the encode and decode functions always clear
it on both entry and exit of normal execution.

Its likely the design intent is this is usable as a streaming encoder
where the input comes in batches, however the functions as they are
today don't support that.

So simplify by bringing the cache back to local variable, so that if
it does VRAM OOM the cache itself is properly garbage when the
encode()/decode() functions dissappear from the stack.
2025-10-13 15:23:11 -04:00
Daniel Harte 27ffd12c45 add indent=4 kwarg to json.dumps() (#10307) 2025-10-13 12:14:52 -07:00
comfyanonymous e693e4db6a Always set diffusion model to eval() mode. (#10331) 2025-10-13 14:57:27 -04:00
comfyanonymous d68ece7301 Update the extra_model_paths.yaml.example (#10319) 2025-10-12 23:54:41 -04:00
Christian Byrne 894837de9a update extra models paths example (#10316) 2025-10-12 23:35:33 -04:00
ComfyUI Wiki fdc92863b6 Update node docs to 0.3.0 (#10318) 2025-10-12 23:32:02 -04:00
comfyanonymous a125cd84b0 Improve AMD performance. (#10302)
I honestly have no idea why this improves things but it does.
2025-10-12 00:28:01 -04:00
comfyanonymous 84e9ce32c6 Implement the mmaudio VAE. (#10300) 2025-10-11 22:57:23 -04:00
ComfyUI Wiki f43b8ab2a2 Update template to 0.1.95 (#10294) 2025-10-11 10:27:22 -07:00
Alexander Piskun 14d642acd6 feat(api-nodes): add price extractor feature; small fixes to Kling & Pika nodes (#10284) 2025-10-10 16:21:40 -07:00
Alexander Piskun aa895db7e8 feat(GeminiImage-ApiNode): add aspect_ratio and release version of model (#10255) 2025-10-10 16:17:20 -07:00
comfyanonymous cdfc25a160 Fix save audio nodes saving mono audio as stereo. (#10289) 2025-10-10 17:33:51 -04:00
Alexander Piskun 81e4dac107 convert nodes_upscale_model.py to V3 schema (#10149) 2025-10-09 16:08:40 -07:00
Alexander Piskun 90853fb9cd convert nodes_flux to V3 schema (#10122) 2025-10-09 16:07:17 -07:00
comfyanonymous f1dd6e50f8 Fix bug with applying loras on fp8 scaled without fp8 ops. (#10279) 2025-10-09 19:02:40 -04:00
Alexander Piskun fc0fbf141c convert nodes_sd3.py and nodes_slg.py to V3 schema (#10162) 2025-10-09 15:18:23 -07:00
Alexander Piskun f3d5d328a3 fix(v3,api-nodes): V3 schema typing; corrected Pika API nodes (#10265) 2025-10-09 15:15:03 -07:00
comfyanonymous 139addd53c More surgical fix for #10267 (#10276) 2025-10-09 16:37:35 -04:00
Alexander Piskun cbee7d3390 convert nodes_latent.py to V3 schema (#10160) 2025-10-08 23:14:00 -07:00
Alexander Piskun 6732014a0a convert nodes_compositing.py to V3 schema (#10174) 2025-10-08 23:13:15 -07:00
Alexander Piskun 989f715d92 convert nodes_lora_extract.py to V3 schema (#10182) 2025-10-08 23:11:45 -07:00
Alexander Piskun 2ba8d7cce8 convert nodes_model_downscale.py to V3 schema (#10199) 2025-10-08 23:10:23 -07:00
Alexander Piskun 51fb505ffa feat(api-nodes, pylint): use lazy formatting in logging functions (#10248) 2025-10-08 23:06:56 -07:00
Jedrzej Kosinski 72c2071972 Mvly/node update (#10042)
* updated V2V node to allow for control image input
exposing steps in v2v
fixing guidance_scale as input parameter

TODO: allow for motion_intensity as input param.

* refactor: comment out unsupported resolution and adjust default values in video nodes

* set control_after_generate

* adding new defaults

* fixes

* changed control_after_generate back to True

* changed control_after_generate back to False

---------

Co-authored-by: thorsten <thorsten@tripod-digital.co.nz>
2025-10-08 20:30:41 -04:00
comfyanonymous 6e59934089 Refactor model sampling sigmas code. (#10250) 2025-10-08 17:49:02 -04:00
Alexander Piskun 3e0eb8d33f feat(V3-io): allow Enum classes for Combo options (#10237) 2025-10-08 00:14:04 -07:00
54 changed files with 2811 additions and 1157 deletions
View File
+120
View File
@@ -0,0 +1,120 @@
# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
# LICENSE is in incl_licenses directory.
import torch
from torch import nn, sin, pow
from torch.nn import Parameter
import comfy.model_management
class Snake(nn.Module):
'''
Implementation of a sine-based periodic activation function
Shape:
- Input: (B, C, T)
- Output: (B, C, T), same shape as the input
Parameters:
- alpha - trainable parameter
References:
- This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
https://arxiv.org/abs/2006.08195
Examples:
>>> a1 = snake(256)
>>> x = torch.randn(256)
>>> x = a1(x)
'''
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
'''
Initialization.
INPUT:
- in_features: shape of the input
- alpha: trainable parameter
alpha is initialized to 1 by default, higher values = higher-frequency.
alpha will be trained along with the rest of your model.
'''
super(Snake, self).__init__()
self.in_features = in_features
# initialize alpha
self.alpha_logscale = alpha_logscale
if self.alpha_logscale:
self.alpha = Parameter(torch.empty(in_features))
else:
self.alpha = Parameter(torch.empty(in_features))
self.alpha.requires_grad = alpha_trainable
self.no_div_by_zero = 0.000000001
def forward(self, x):
'''
Forward pass of the function.
Applies the function to the input elementwise.
Snake = x + 1/a * sin^2 (xa)
'''
alpha = comfy.model_management.cast_to(self.alpha, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
if self.alpha_logscale:
alpha = torch.exp(alpha)
x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
return x
class SnakeBeta(nn.Module):
'''
A modified Snake function which uses separate parameters for the magnitude of the periodic components
Shape:
- Input: (B, C, T)
- Output: (B, C, T), same shape as the input
Parameters:
- alpha - trainable parameter that controls frequency
- beta - trainable parameter that controls magnitude
References:
- This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
https://arxiv.org/abs/2006.08195
Examples:
>>> a1 = snakebeta(256)
>>> x = torch.randn(256)
>>> x = a1(x)
'''
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
'''
Initialization.
INPUT:
- in_features: shape of the input
- alpha - trainable parameter that controls frequency
- beta - trainable parameter that controls magnitude
alpha is initialized to 1 by default, higher values = higher-frequency.
beta is initialized to 1 by default, higher values = higher-magnitude.
alpha will be trained along with the rest of your model.
'''
super(SnakeBeta, self).__init__()
self.in_features = in_features
# initialize alpha
self.alpha_logscale = alpha_logscale
if self.alpha_logscale:
self.alpha = Parameter(torch.empty(in_features))
self.beta = Parameter(torch.empty(in_features))
else:
self.alpha = Parameter(torch.empty(in_features))
self.beta = Parameter(torch.empty(in_features))
self.alpha.requires_grad = alpha_trainable
self.beta.requires_grad = alpha_trainable
self.no_div_by_zero = 0.000000001
def forward(self, x):
'''
Forward pass of the function.
Applies the function to the input elementwise.
SnakeBeta = x + 1/b * sin^2 (xa)
'''
alpha = comfy.model_management.cast_to(self.alpha, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
beta = comfy.model_management.cast_to(self.beta, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1)
if self.alpha_logscale:
alpha = torch.exp(alpha)
beta = torch.exp(beta)
x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
return x
+157
View File
@@ -0,0 +1,157 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import comfy.model_management
if 'sinc' in dir(torch):
sinc = torch.sinc
else:
# This code is adopted from adefossez's julius.core.sinc under the MIT License
# https://adefossez.github.io/julius/julius/core.html
# LICENSE is in incl_licenses directory.
def sinc(x: torch.Tensor):
"""
Implementation of sinc, i.e. sin(pi * x) / (pi * x)
__Warning__: Different to julius.sinc, the input is multiplied by `pi`!
"""
return torch.where(x == 0,
torch.tensor(1., device=x.device, dtype=x.dtype),
torch.sin(math.pi * x) / math.pi / x)
# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
# https://adefossez.github.io/julius/julius/lowpass.html
# LICENSE is in incl_licenses directory.
def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size]
even = (kernel_size % 2 == 0)
half_size = kernel_size // 2
#For kaiser window
delta_f = 4 * half_width
A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
if A > 50.:
beta = 0.1102 * (A - 8.7)
elif A >= 21.:
beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.)
else:
beta = 0.
window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
# ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
if even:
time = (torch.arange(-half_size, half_size) + 0.5)
else:
time = torch.arange(kernel_size) - half_size
if cutoff == 0:
filter_ = torch.zeros_like(time)
else:
filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
# Normalize filter to have sum = 1, otherwise we will have a small leakage
# of the constant component in the input signal.
filter_ /= filter_.sum()
filter = filter_.view(1, 1, kernel_size)
return filter
class LowPassFilter1d(nn.Module):
def __init__(self,
cutoff=0.5,
half_width=0.6,
stride: int = 1,
padding: bool = True,
padding_mode: str = 'replicate',
kernel_size: int = 12):
# kernel_size should be even number for stylegan3 setup,
# in this implementation, odd number is also possible.
super().__init__()
if cutoff < -0.:
raise ValueError("Minimum cutoff must be larger than zero.")
if cutoff > 0.5:
raise ValueError("A cutoff above 0.5 does not make sense.")
self.kernel_size = kernel_size
self.even = (kernel_size % 2 == 0)
self.pad_left = kernel_size // 2 - int(self.even)
self.pad_right = kernel_size // 2
self.stride = stride
self.padding = padding
self.padding_mode = padding_mode
filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
self.register_buffer("filter", filter)
#input [B, C, T]
def forward(self, x):
_, C, _ = x.shape
if self.padding:
x = F.pad(x, (self.pad_left, self.pad_right),
mode=self.padding_mode)
out = F.conv1d(x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device),
stride=self.stride, groups=C)
return out
class UpSample1d(nn.Module):
def __init__(self, ratio=2, kernel_size=None):
super().__init__()
self.ratio = ratio
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
self.stride = ratio
self.pad = self.kernel_size // ratio - 1
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio,
half_width=0.6 / ratio,
kernel_size=self.kernel_size)
self.register_buffer("filter", filter)
# x: [B, C, T]
def forward(self, x):
_, C, _ = x.shape
x = F.pad(x, (self.pad, self.pad), mode='replicate')
x = self.ratio * F.conv_transpose1d(
x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device), stride=self.stride, groups=C)
x = x[..., self.pad_left:-self.pad_right]
return x
class DownSample1d(nn.Module):
def __init__(self, ratio=2, kernel_size=None):
super().__init__()
self.ratio = ratio
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio,
half_width=0.6 / ratio,
stride=ratio,
kernel_size=self.kernel_size)
def forward(self, x):
xx = self.lowpass(x)
return xx
class Activation1d(nn.Module):
def __init__(self,
activation,
up_ratio: int = 2,
down_ratio: int = 2,
up_kernel_size: int = 12,
down_kernel_size: int = 12):
super().__init__()
self.up_ratio = up_ratio
self.down_ratio = down_ratio
self.act = activation
self.upsample = UpSample1d(up_ratio, up_kernel_size)
self.downsample = DownSample1d(down_ratio, down_kernel_size)
# x: [B,C,T]
def forward(self, x):
x = self.upsample(x)
x = self.act(x)
x = self.downsample(x)
return x
+156
View File
@@ -0,0 +1,156 @@
from typing import Literal
import torch
import torch.nn as nn
from .distributions import DiagonalGaussianDistribution
from .vae import VAE_16k
from .bigvgan import BigVGANVocoder
import logging
try:
import torchaudio
except:
logging.warning("torchaudio missing, MMAudio VAE model will be broken")
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5, *, norm_fn):
return norm_fn(torch.clamp(x, min=clip_val) * C)
def spectral_normalize_torch(magnitudes, norm_fn):
output = dynamic_range_compression_torch(magnitudes, norm_fn=norm_fn)
return output
class MelConverter(nn.Module):
def __init__(
self,
*,
sampling_rate: float,
n_fft: int,
num_mels: int,
hop_size: int,
win_size: int,
fmin: float,
fmax: float,
norm_fn,
):
super().__init__()
self.sampling_rate = sampling_rate
self.n_fft = n_fft
self.num_mels = num_mels
self.hop_size = hop_size
self.win_size = win_size
self.fmin = fmin
self.fmax = fmax
self.norm_fn = norm_fn
# mel = librosa_mel_fn(sr=self.sampling_rate,
# n_fft=self.n_fft,
# n_mels=self.num_mels,
# fmin=self.fmin,
# fmax=self.fmax)
# mel_basis = torch.from_numpy(mel).float()
mel_basis = torch.empty((num_mels, 1 + n_fft // 2))
hann_window = torch.hann_window(self.win_size)
self.register_buffer('mel_basis', mel_basis)
self.register_buffer('hann_window', hann_window)
@property
def device(self):
return self.mel_basis.device
def forward(self, waveform: torch.Tensor, center: bool = False) -> torch.Tensor:
waveform = waveform.clamp(min=-1., max=1.).to(self.device)
waveform = torch.nn.functional.pad(
waveform.unsqueeze(1),
[int((self.n_fft - self.hop_size) / 2),
int((self.n_fft - self.hop_size) / 2)],
mode='reflect')
waveform = waveform.squeeze(1)
spec = torch.stft(waveform,
self.n_fft,
hop_length=self.hop_size,
win_length=self.win_size,
window=self.hann_window,
center=center,
pad_mode='reflect',
normalized=False,
onesided=True,
return_complex=True)
spec = torch.view_as_real(spec)
spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
spec = torch.matmul(self.mel_basis, spec)
spec = spectral_normalize_torch(spec, self.norm_fn)
return spec
class AudioAutoencoder(nn.Module):
def __init__(
self,
*,
# ckpt_path: str,
mode=Literal['16k', '44k'],
need_vae_encoder: bool = True,
):
super().__init__()
assert mode == "16k", "Only 16k mode is supported currently."
self.mel_converter = MelConverter(sampling_rate=16_000,
n_fft=1024,
num_mels=80,
hop_size=256,
win_size=1024,
fmin=0,
fmax=8_000,
norm_fn=torch.log10)
self.vae = VAE_16k().eval()
bigvgan_config = {
"resblock": "1",
"num_mels": 80,
"upsample_rates": [4, 4, 2, 2, 2, 2],
"upsample_kernel_sizes": [8, 8, 4, 4, 4, 4],
"upsample_initial_channel": 1536,
"resblock_kernel_sizes": [3, 7, 11],
"resblock_dilation_sizes": [
[1, 3, 5],
[1, 3, 5],
[1, 3, 5],
],
"activation": "snakebeta",
"snake_logscale": True,
}
self.vocoder = BigVGANVocoder(
bigvgan_config
).eval()
@torch.inference_mode()
def encode_audio(self, x) -> DiagonalGaussianDistribution:
# x: (B * L)
mel = self.mel_converter(x)
dist = self.vae.encode(mel)
return dist
@torch.no_grad()
def decode(self, z):
mel_decoded = self.vae.decode(z)
audio = self.vocoder(mel_decoded)
audio = torchaudio.functional.resample(audio, 16000, 44100)
return audio
@torch.no_grad()
def encode(self, audio):
audio = audio.mean(dim=1)
audio = torchaudio.functional.resample(audio, 44100, 16000)
dist = self.encode_audio(audio)
return dist.mean
+219
View File
@@ -0,0 +1,219 @@
# Copyright (c) 2022 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
# LICENSE is in incl_licenses directory.
import torch
import torch.nn as nn
from types import SimpleNamespace
from . import activations
from .alias_free_torch import Activation1d
import comfy.ops
ops = comfy.ops.disable_weight_init
def get_padding(kernel_size, dilation=1):
return int((kernel_size * dilation - dilation) / 2)
class AMPBlock1(torch.nn.Module):
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5), activation=None):
super(AMPBlock1, self).__init__()
self.h = h
self.convs1 = nn.ModuleList([
ops.Conv1d(channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0])),
ops.Conv1d(channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1])),
ops.Conv1d(channels,
channels,
kernel_size,
1,
dilation=dilation[2],
padding=get_padding(kernel_size, dilation[2]))
])
self.convs2 = nn.ModuleList([
ops.Conv1d(channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1)),
ops.Conv1d(channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1)),
ops.Conv1d(channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1))
])
self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers
if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
self.activations = nn.ModuleList([
Activation1d(
activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
for _ in range(self.num_layers)
])
elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
self.activations = nn.ModuleList([
Activation1d(
activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
for _ in range(self.num_layers)
])
else:
raise NotImplementedError(
"activation incorrectly specified. check the config file and look for 'activation'."
)
def forward(self, x):
acts1, acts2 = self.activations[::2], self.activations[1::2]
for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
xt = a1(x)
xt = c1(xt)
xt = a2(xt)
xt = c2(xt)
x = xt + x
return x
class AMPBlock2(torch.nn.Module):
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3), activation=None):
super(AMPBlock2, self).__init__()
self.h = h
self.convs = nn.ModuleList([
ops.Conv1d(channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0])),
ops.Conv1d(channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]))
])
self.num_layers = len(self.convs) # total number of conv layers
if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
self.activations = nn.ModuleList([
Activation1d(
activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
for _ in range(self.num_layers)
])
elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
self.activations = nn.ModuleList([
Activation1d(
activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
for _ in range(self.num_layers)
])
else:
raise NotImplementedError(
"activation incorrectly specified. check the config file and look for 'activation'."
)
def forward(self, x):
for c, a in zip(self.convs, self.activations):
xt = a(x)
xt = c(xt)
x = xt + x
return x
class BigVGANVocoder(torch.nn.Module):
# this is our main BigVGAN model. Applies anti-aliased periodic activation for resblocks.
def __init__(self, h):
super().__init__()
if isinstance(h, dict):
h = SimpleNamespace(**h)
self.h = h
self.num_kernels = len(h.resblock_kernel_sizes)
self.num_upsamples = len(h.upsample_rates)
# pre conv
self.conv_pre = ops.Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
# define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
resblock = AMPBlock1 if h.resblock == '1' else AMPBlock2
# transposed conv-based upsamplers. does not apply anti-aliasing
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
self.ups.append(
nn.ModuleList([
ops.ConvTranspose1d(h.upsample_initial_channel // (2**i),
h.upsample_initial_channel // (2**(i + 1)),
k,
u,
padding=(k - u) // 2)
]))
# residual blocks using anti-aliased multi-periodicity composition modules (AMP)
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = h.upsample_initial_channel // (2**(i + 1))
for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
self.resblocks.append(resblock(h, ch, k, d, activation=h.activation))
# post conv
if h.activation == "snake": # periodic nonlinearity with snake function and anti-aliasing
activation_post = activations.Snake(ch, alpha_logscale=h.snake_logscale)
self.activation_post = Activation1d(activation=activation_post)
elif h.activation == "snakebeta": # periodic nonlinearity with snakebeta function and anti-aliasing
activation_post = activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale)
self.activation_post = Activation1d(activation=activation_post)
else:
raise NotImplementedError(
"activation incorrectly specified. check the config file and look for 'activation'."
)
self.conv_post = ops.Conv1d(ch, 1, 7, 1, padding=3)
def forward(self, x):
# pre conv
x = self.conv_pre(x)
for i in range(self.num_upsamples):
# upsampling
for i_up in range(len(self.ups[i])):
x = self.ups[i][i_up](x)
# AMP blocks
xs = None
for j in range(self.num_kernels):
if xs is None:
xs = self.resblocks[i * self.num_kernels + j](x)
else:
xs += self.resblocks[i * self.num_kernels + j](x)
x = xs / self.num_kernels
# post conv
x = self.activation_post(x)
x = self.conv_post(x)
x = torch.tanh(x)
return x
+92
View File
@@ -0,0 +1,92 @@
import torch
import numpy as np
class AbstractDistribution:
def sample(self):
raise NotImplementedError()
def mode(self):
raise NotImplementedError()
class DiracDistribution(AbstractDistribution):
def __init__(self, value):
self.value = value
def sample(self):
return self.value
def mode(self):
return self.value
class DiagonalGaussianDistribution(object):
def __init__(self, parameters, deterministic=False):
self.parameters = parameters
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
self.deterministic = deterministic
self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar)
if self.deterministic:
self.var = self.std = torch.zeros_like(self.mean, device=self.parameters.device)
def sample(self):
x = self.mean + self.std * torch.randn(self.mean.shape, device=self.parameters.device)
return x
def kl(self, other=None):
if self.deterministic:
return torch.Tensor([0.])
else:
if other is None:
return 0.5 * torch.sum(torch.pow(self.mean, 2)
+ self.var - 1.0 - self.logvar,
dim=[1, 2, 3])
else:
return 0.5 * torch.sum(
torch.pow(self.mean - other.mean, 2) / other.var
+ self.var / other.var - 1.0 - self.logvar + other.logvar,
dim=[1, 2, 3])
def nll(self, sample, dims=[1,2,3]):
if self.deterministic:
return torch.Tensor([0.])
logtwopi = np.log(2.0 * np.pi)
return 0.5 * torch.sum(
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
dim=dims)
def mode(self):
return self.mean
def normal_kl(mean1, logvar1, mean2, logvar2):
"""
source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
Compute the KL divergence between two gaussians.
Shapes are automatically broadcasted, so batches can be compared to
scalars, among other use cases.
"""
tensor = None
for obj in (mean1, logvar1, mean2, logvar2):
if isinstance(obj, torch.Tensor):
tensor = obj
break
assert tensor is not None, "at least one argument must be a Tensor"
# Force variances to be Tensors. Broadcasting helps convert scalars to
# Tensors, but it does not work for torch.exp().
logvar1, logvar2 = [
x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
for x in (logvar1, logvar2)
]
return 0.5 * (
-1.0
+ logvar2
- logvar1
+ torch.exp(logvar1 - logvar2)
+ ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
)
+358
View File
@@ -0,0 +1,358 @@
import logging
from typing import Optional
import torch
import torch.nn as nn
from .vae_modules import (AttnBlock1D, Downsample1D, ResnetBlock1D,
Upsample1D, nonlinearity)
from .distributions import DiagonalGaussianDistribution
import comfy.ops
ops = comfy.ops.disable_weight_init
log = logging.getLogger()
DATA_MEAN_80D = [
-1.6058, -1.3676, -1.2520, -1.2453, -1.2078, -1.2224, -1.2419, -1.2439, -1.2922, -1.2927,
-1.3170, -1.3543, -1.3401, -1.3836, -1.3907, -1.3912, -1.4313, -1.4152, -1.4527, -1.4728,
-1.4568, -1.5101, -1.5051, -1.5172, -1.5623, -1.5373, -1.5746, -1.5687, -1.6032, -1.6131,
-1.6081, -1.6331, -1.6489, -1.6489, -1.6700, -1.6738, -1.6953, -1.6969, -1.7048, -1.7280,
-1.7361, -1.7495, -1.7658, -1.7814, -1.7889, -1.8064, -1.8221, -1.8377, -1.8417, -1.8643,
-1.8857, -1.8929, -1.9173, -1.9379, -1.9531, -1.9673, -1.9824, -2.0042, -2.0215, -2.0436,
-2.0766, -2.1064, -2.1418, -2.1855, -2.2319, -2.2767, -2.3161, -2.3572, -2.3954, -2.4282,
-2.4659, -2.5072, -2.5552, -2.6074, -2.6584, -2.7107, -2.7634, -2.8266, -2.8981, -2.9673
]
DATA_STD_80D = [
1.0291, 1.0411, 1.0043, 0.9820, 0.9677, 0.9543, 0.9450, 0.9392, 0.9343, 0.9297, 0.9276, 0.9263,
0.9242, 0.9254, 0.9232, 0.9281, 0.9263, 0.9315, 0.9274, 0.9247, 0.9277, 0.9199, 0.9188, 0.9194,
0.9160, 0.9161, 0.9146, 0.9161, 0.9100, 0.9095, 0.9145, 0.9076, 0.9066, 0.9095, 0.9032, 0.9043,
0.9038, 0.9011, 0.9019, 0.9010, 0.8984, 0.8983, 0.8986, 0.8961, 0.8962, 0.8978, 0.8962, 0.8973,
0.8993, 0.8976, 0.8995, 0.9016, 0.8982, 0.8972, 0.8974, 0.8949, 0.8940, 0.8947, 0.8936, 0.8939,
0.8951, 0.8956, 0.9017, 0.9167, 0.9436, 0.9690, 1.0003, 1.0225, 1.0381, 1.0491, 1.0545, 1.0604,
1.0761, 1.0929, 1.1089, 1.1196, 1.1176, 1.1156, 1.1117, 1.1070
]
DATA_MEAN_128D = [
-3.3462, -2.6723, -2.4893, -2.3143, -2.2664, -2.3317, -2.1802, -2.4006, -2.2357, -2.4597,
-2.3717, -2.4690, -2.5142, -2.4919, -2.6610, -2.5047, -2.7483, -2.5926, -2.7462, -2.7033,
-2.7386, -2.8112, -2.7502, -2.9594, -2.7473, -3.0035, -2.8891, -2.9922, -2.9856, -3.0157,
-3.1191, -2.9893, -3.1718, -3.0745, -3.1879, -3.2310, -3.1424, -3.2296, -3.2791, -3.2782,
-3.2756, -3.3134, -3.3509, -3.3750, -3.3951, -3.3698, -3.4505, -3.4509, -3.5089, -3.4647,
-3.5536, -3.5788, -3.5867, -3.6036, -3.6400, -3.6747, -3.7072, -3.7279, -3.7283, -3.7795,
-3.8259, -3.8447, -3.8663, -3.9182, -3.9605, -3.9861, -4.0105, -4.0373, -4.0762, -4.1121,
-4.1488, -4.1874, -4.2461, -4.3170, -4.3639, -4.4452, -4.5282, -4.6297, -4.7019, -4.7960,
-4.8700, -4.9507, -5.0303, -5.0866, -5.1634, -5.2342, -5.3242, -5.4053, -5.4927, -5.5712,
-5.6464, -5.7052, -5.7619, -5.8410, -5.9188, -6.0103, -6.0955, -6.1673, -6.2362, -6.3120,
-6.3926, -6.4797, -6.5565, -6.6511, -6.8130, -6.9961, -7.1275, -7.2457, -7.3576, -7.4663,
-7.6136, -7.7469, -7.8815, -8.0132, -8.1515, -8.3071, -8.4722, -8.7418, -9.3975, -9.6628,
-9.7671, -9.8863, -9.9992, -10.0860, -10.1709, -10.5418, -11.2795, -11.3861
]
DATA_STD_128D = [
2.3804, 2.4368, 2.3772, 2.3145, 2.2803, 2.2510, 2.2316, 2.2083, 2.1996, 2.1835, 2.1769, 2.1659,
2.1631, 2.1618, 2.1540, 2.1606, 2.1571, 2.1567, 2.1612, 2.1579, 2.1679, 2.1683, 2.1634, 2.1557,
2.1668, 2.1518, 2.1415, 2.1449, 2.1406, 2.1350, 2.1313, 2.1415, 2.1281, 2.1352, 2.1219, 2.1182,
2.1327, 2.1195, 2.1137, 2.1080, 2.1179, 2.1036, 2.1087, 2.1036, 2.1015, 2.1068, 2.0975, 2.0991,
2.0902, 2.1015, 2.0857, 2.0920, 2.0893, 2.0897, 2.0910, 2.0881, 2.0925, 2.0873, 2.0960, 2.0900,
2.0957, 2.0958, 2.0978, 2.0936, 2.0886, 2.0905, 2.0845, 2.0855, 2.0796, 2.0840, 2.0813, 2.0817,
2.0838, 2.0840, 2.0917, 2.1061, 2.1431, 2.1976, 2.2482, 2.3055, 2.3700, 2.4088, 2.4372, 2.4609,
2.4731, 2.4847, 2.5072, 2.5451, 2.5772, 2.6147, 2.6529, 2.6596, 2.6645, 2.6726, 2.6803, 2.6812,
2.6899, 2.6916, 2.6931, 2.6998, 2.7062, 2.7262, 2.7222, 2.7158, 2.7041, 2.7485, 2.7491, 2.7451,
2.7485, 2.7233, 2.7297, 2.7233, 2.7145, 2.6958, 2.6788, 2.6439, 2.6007, 2.4786, 2.2469, 2.1877,
2.1392, 2.0717, 2.0107, 1.9676, 1.9140, 1.7102, 0.9101, 0.7164
]
class VAE(nn.Module):
def __init__(
self,
*,
data_dim: int,
embed_dim: int,
hidden_dim: int,
):
super().__init__()
if data_dim == 80:
self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_80D, dtype=torch.float32))
self.data_std = nn.Buffer(torch.tensor(DATA_STD_80D, dtype=torch.float32))
elif data_dim == 128:
self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_128D, dtype=torch.float32))
self.data_std = nn.Buffer(torch.tensor(DATA_STD_128D, dtype=torch.float32))
self.data_mean = self.data_mean.view(1, -1, 1)
self.data_std = self.data_std.view(1, -1, 1)
self.encoder = Encoder1D(
dim=hidden_dim,
ch_mult=(1, 2, 4),
num_res_blocks=2,
attn_layers=[3],
down_layers=[0],
in_dim=data_dim,
embed_dim=embed_dim,
)
self.decoder = Decoder1D(
dim=hidden_dim,
ch_mult=(1, 2, 4),
num_res_blocks=2,
attn_layers=[3],
down_layers=[0],
in_dim=data_dim,
out_dim=data_dim,
embed_dim=embed_dim,
)
self.embed_dim = embed_dim
# self.quant_conv = nn.Conv1d(2 * embed_dim, 2 * embed_dim, 1)
# self.post_quant_conv = nn.Conv1d(embed_dim, embed_dim, 1)
self.initialize_weights()
def initialize_weights(self):
pass
def encode(self, x: torch.Tensor, normalize: bool = True) -> DiagonalGaussianDistribution:
if normalize:
x = self.normalize(x)
moments = self.encoder(x)
posterior = DiagonalGaussianDistribution(moments)
return posterior
def decode(self, z: torch.Tensor, unnormalize: bool = True) -> torch.Tensor:
dec = self.decoder(z)
if unnormalize:
dec = self.unnormalize(dec)
return dec
def normalize(self, x: torch.Tensor) -> torch.Tensor:
return (x - comfy.model_management.cast_to(self.data_mean, dtype=x.dtype, device=x.device)) / comfy.model_management.cast_to(self.data_std, dtype=x.dtype, device=x.device)
def unnormalize(self, x: torch.Tensor) -> torch.Tensor:
return x * comfy.model_management.cast_to(self.data_std, dtype=x.dtype, device=x.device) + comfy.model_management.cast_to(self.data_mean, dtype=x.dtype, device=x.device)
def forward(
self,
x: torch.Tensor,
sample_posterior: bool = True,
rng: Optional[torch.Generator] = None,
normalize: bool = True,
unnormalize: bool = True,
) -> tuple[torch.Tensor, DiagonalGaussianDistribution]:
posterior = self.encode(x, normalize=normalize)
if sample_posterior:
z = posterior.sample(rng)
else:
z = posterior.mode()
dec = self.decode(z, unnormalize=unnormalize)
return dec, posterior
def load_weights(self, src_dict) -> None:
self.load_state_dict(src_dict, strict=True)
@property
def device(self) -> torch.device:
return next(self.parameters()).device
def get_last_layer(self):
return self.decoder.conv_out.weight
def remove_weight_norm(self):
return self
class Encoder1D(nn.Module):
def __init__(self,
*,
dim: int,
ch_mult: tuple[int] = (1, 2, 4, 8),
num_res_blocks: int,
attn_layers: list[int] = [],
down_layers: list[int] = [],
resamp_with_conv: bool = True,
in_dim: int,
embed_dim: int,
double_z: bool = True,
kernel_size: int = 3,
clip_act: float = 256.0):
super().__init__()
self.dim = dim
self.num_layers = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.in_channels = in_dim
self.clip_act = clip_act
self.down_layers = down_layers
self.attn_layers = attn_layers
self.conv_in = ops.Conv1d(in_dim, self.dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
in_ch_mult = (1, ) + tuple(ch_mult)
self.in_ch_mult = in_ch_mult
# downsampling
self.down = nn.ModuleList()
for i_level in range(self.num_layers):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = dim * in_ch_mult[i_level]
block_out = dim * ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(
ResnetBlock1D(in_dim=block_in,
out_dim=block_out,
kernel_size=kernel_size,
use_norm=True))
block_in = block_out
if i_level in attn_layers:
attn.append(AttnBlock1D(block_in))
down = nn.Module()
down.block = block
down.attn = attn
if i_level in down_layers:
down.downsample = Downsample1D(block_in, resamp_with_conv)
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock1D(in_dim=block_in,
out_dim=block_in,
kernel_size=kernel_size,
use_norm=True)
self.mid.attn_1 = AttnBlock1D(block_in)
self.mid.block_2 = ResnetBlock1D(in_dim=block_in,
out_dim=block_in,
kernel_size=kernel_size,
use_norm=True)
# end
self.conv_out = ops.Conv1d(block_in,
2 * embed_dim if double_z else embed_dim,
kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
self.learnable_gain = nn.Parameter(torch.zeros([]))
def forward(self, x):
# downsampling
h = self.conv_in(x)
for i_level in range(self.num_layers):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](h)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
h = h.clamp(-self.clip_act, self.clip_act)
if i_level in self.down_layers:
h = self.down[i_level].downsample(h)
# middle
h = self.mid.block_1(h)
h = self.mid.attn_1(h)
h = self.mid.block_2(h)
h = h.clamp(-self.clip_act, self.clip_act)
# end
h = nonlinearity(h)
h = self.conv_out(h) * (self.learnable_gain + 1)
return h
class Decoder1D(nn.Module):
def __init__(self,
*,
dim: int,
out_dim: int,
ch_mult: tuple[int] = (1, 2, 4, 8),
num_res_blocks: int,
attn_layers: list[int] = [],
down_layers: list[int] = [],
kernel_size: int = 3,
resamp_with_conv: bool = True,
in_dim: int,
embed_dim: int,
clip_act: float = 256.0):
super().__init__()
self.ch = dim
self.num_layers = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.in_channels = in_dim
self.clip_act = clip_act
self.down_layers = [i + 1 for i in down_layers] # each downlayer add one
# compute in_ch_mult, block_in and curr_res at lowest res
block_in = dim * ch_mult[self.num_layers - 1]
# z to block_in
self.conv_in = ops.Conv1d(embed_dim, block_in, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock1D(in_dim=block_in, out_dim=block_in, use_norm=True)
self.mid.attn_1 = AttnBlock1D(block_in)
self.mid.block_2 = ResnetBlock1D(in_dim=block_in, out_dim=block_in, use_norm=True)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_layers)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = dim * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
block.append(ResnetBlock1D(in_dim=block_in, out_dim=block_out, use_norm=True))
block_in = block_out
if i_level in attn_layers:
attn.append(AttnBlock1D(block_in))
up = nn.Module()
up.block = block
up.attn = attn
if i_level in self.down_layers:
up.upsample = Upsample1D(block_in, resamp_with_conv)
self.up.insert(0, up) # prepend to get consistent order
# end
self.conv_out = ops.Conv1d(block_in, out_dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
self.learnable_gain = nn.Parameter(torch.zeros([]))
def forward(self, z):
# z to block_in
h = self.conv_in(z)
# middle
h = self.mid.block_1(h)
h = self.mid.attn_1(h)
h = self.mid.block_2(h)
h = h.clamp(-self.clip_act, self.clip_act)
# upsampling
for i_level in reversed(range(self.num_layers)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
h = h.clamp(-self.clip_act, self.clip_act)
if i_level in self.down_layers:
h = self.up[i_level].upsample(h)
h = nonlinearity(h)
h = self.conv_out(h) * (self.learnable_gain + 1)
return h
def VAE_16k(**kwargs) -> VAE:
return VAE(data_dim=80, embed_dim=20, hidden_dim=384, **kwargs)
def VAE_44k(**kwargs) -> VAE:
return VAE(data_dim=128, embed_dim=40, hidden_dim=512, **kwargs)
def get_my_vae(name: str, **kwargs) -> VAE:
if name == '16k':
return VAE_16k(**kwargs)
if name == '44k':
return VAE_44k(**kwargs)
raise ValueError(f'Unknown model: {name}')
+121
View File
@@ -0,0 +1,121 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.diffusionmodules.model import vae_attention
import math
import comfy.ops
ops = comfy.ops.disable_weight_init
def nonlinearity(x):
# swish
return torch.nn.functional.silu(x) / 0.596
def mp_sum(a, b, t=0.5):
return a.lerp(b, t) / math.sqrt((1 - t)**2 + t**2)
def normalize(x, dim=None, eps=1e-4):
if dim is None:
dim = list(range(1, x.ndim))
norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32)
norm = torch.add(eps, norm, alpha=math.sqrt(norm.numel() / x.numel()))
return x / norm.to(x.dtype)
class ResnetBlock1D(nn.Module):
def __init__(self, *, in_dim, out_dim=None, conv_shortcut=False, kernel_size=3, use_norm=True):
super().__init__()
self.in_dim = in_dim
out_dim = in_dim if out_dim is None else out_dim
self.out_dim = out_dim
self.use_conv_shortcut = conv_shortcut
self.use_norm = use_norm
self.conv1 = ops.Conv1d(in_dim, out_dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
self.conv2 = ops.Conv1d(out_dim, out_dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
if self.in_dim != self.out_dim:
if self.use_conv_shortcut:
self.conv_shortcut = ops.Conv1d(in_dim, out_dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
else:
self.nin_shortcut = ops.Conv1d(in_dim, out_dim, kernel_size=1, padding=0, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# pixel norm
if self.use_norm:
x = normalize(x, dim=1)
h = x
h = nonlinearity(h)
h = self.conv1(h)
h = nonlinearity(h)
h = self.conv2(h)
if self.in_dim != self.out_dim:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
return mp_sum(x, h, t=0.3)
class AttnBlock1D(nn.Module):
def __init__(self, in_channels, num_heads=1):
super().__init__()
self.in_channels = in_channels
self.num_heads = num_heads
self.qkv = ops.Conv1d(in_channels, in_channels * 3, kernel_size=1, padding=0, bias=False)
self.proj_out = ops.Conv1d(in_channels, in_channels, kernel_size=1, padding=0, bias=False)
self.optimized_attention = vae_attention()
def forward(self, x):
h = x
y = self.qkv(h)
y = y.reshape(y.shape[0], -1, 3, y.shape[-1])
q, k, v = normalize(y, dim=1).unbind(2)
h = self.optimized_attention(q, k, v)
h = self.proj_out(h)
return mp_sum(x, h, t=0.3)
class Upsample1D(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
self.conv = ops.Conv1d(in_channels, in_channels, kernel_size=3, padding=1, bias=False)
def forward(self, x):
x = F.interpolate(x, scale_factor=2.0, mode='nearest-exact') # support 3D tensor(B,C,T)
if self.with_conv:
x = self.conv(x)
return x
class Downsample1D(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
# no asymmetric padding in torch conv, must do it ourselves
self.conv1 = ops.Conv1d(in_channels, in_channels, kernel_size=1, padding=0, bias=False)
self.conv2 = ops.Conv1d(in_channels, in_channels, kernel_size=1, padding=0, bias=False)
def forward(self, x):
if self.with_conv:
x = self.conv1(x)
x = F.avg_pool1d(x, kernel_size=2, stride=2)
if self.with_conv:
x = self.conv2(x)
return x
+14 -23
View File
@@ -657,51 +657,51 @@ class WanVAE(nn.Module):
)
def encode(self, x):
self.clear_cache()
conv_idx = [0]
feat_map = [None] * count_conv3d(self.encoder)
x = patchify(x, patch_size=2)
t = x.shape[2]
iter_ = 1 + (t - 1) // 4
for i in range(iter_):
self._enc_conv_idx = [0]
conv_idx = [0]
if i == 0:
out = self.encoder(
x[:, :, :1, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx,
feat_cache=feat_map,
feat_idx=conv_idx,
)
else:
out_ = self.encoder(
x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx,
feat_cache=feat_map,
feat_idx=conv_idx,
)
out = torch.cat([out, out_], 2)
mu, log_var = self.conv1(out).chunk(2, dim=1)
self.clear_cache()
return mu
def decode(self, z):
self.clear_cache()
conv_idx = [0]
feat_map = [None] * count_conv3d(self.decoder)
iter_ = z.shape[2]
x = self.conv2(z)
for i in range(iter_):
self._conv_idx = [0]
conv_idx = [0]
if i == 0:
out = self.decoder(
x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx,
feat_cache=feat_map,
feat_idx=conv_idx,
first_chunk=True,
)
else:
out_ = self.decoder(
x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx,
feat_cache=feat_map,
feat_idx=conv_idx,
)
out = torch.cat([out, out_], 2)
out = unpatchify(out, patch_size=2)
self.clear_cache()
return out
def reparameterize(self, mu, log_var):
@@ -715,12 +715,3 @@ class WanVAE(nn.Module):
return mu
std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
return mu + std * torch.randn_like(std)
def clear_cache(self):
self._conv_num = count_conv3d(self.decoder)
self._conv_idx = [0]
self._feat_map = [None] * self._conv_num
# cache encode
self._enc_conv_num = count_conv3d(self.encoder)
self._enc_conv_idx = [0]
self._enc_feat_map = [None] * self._enc_conv_num
+1 -2
View File
@@ -138,6 +138,7 @@ class BaseModel(torch.nn.Module):
else:
operations = model_config.custom_operations
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
self.diffusion_model.eval()
if comfy.model_management.force_channels_last():
self.diffusion_model.to(memory_format=torch.channels_last)
logging.debug("using channels last mode for diffusion model")
@@ -669,7 +670,6 @@ class Lotus(BaseModel):
class StableCascade_C(BaseModel):
def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
super().__init__(model_config, model_type, device=device, unet_model=StageC)
self.diffusion_model.eval().requires_grad_(False)
def extra_conds(self, **kwargs):
out = {}
@@ -698,7 +698,6 @@ class StableCascade_C(BaseModel):
class StableCascade_B(BaseModel):
def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
super().__init__(model_config, model_type, device=device, unet_model=StageB)
self.diffusion_model.eval().requires_grad_(False)
def extra_conds(self, **kwargs):
out = {}
+5 -8
View File
@@ -332,6 +332,7 @@ except:
SUPPORT_FP8_OPS = args.supports_fp8_compute
try:
if is_amd():
torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD
try:
rocm_version = tuple(map(int, str(torch.version.hip).split(".")[:2]))
except:
@@ -344,9 +345,9 @@ try:
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
ENABLE_PYTORCH_ATTENTION = True
# if torch_version_numeric >= (2, 8):
# if any((a in arch) for a in ["gfx1201"]):
# ENABLE_PYTORCH_ATTENTION = True
if rocm_version >= (7, 0):
if any((a in arch) for a in ["gfx1201"]):
ENABLE_PYTORCH_ATTENTION = True
if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
if any((a in arch) for a in ["gfx1200", "gfx1201", "gfx942", "gfx950"]): # TODO: more arches
SUPPORT_FP8_OPS = True
@@ -925,11 +926,7 @@ def vae_dtype(device=None, allowed_dtypes=[]):
if d == torch.float16 and should_use_fp16(device):
return d
# NOTE: bfloat16 seems to work on AMD for the VAE but is extremely slow in some cases compared to fp32
# slowness still a problem on pytorch nightly 2.9.0.dev20250720+rocm6.4 tested on RDNA3
# also a problem on RDNA4 except fp32 is also slow there.
# This is due to large bf16 convolutions being extremely slow.
if d == torch.bfloat16 and ((not is_amd()) or amd_min_version(device, min_rdna_version=4)) and should_use_bf16(device):
if d == torch.bfloat16 and should_use_bf16(device):
return d
return torch.float32
+25 -7
View File
@@ -123,16 +123,30 @@ def move_weight_functions(m, device):
return memory
class LowVramPatch:
def __init__(self, key, patches):
def __init__(self, key, patches, convert_func=None, set_func=None):
self.key = key
self.patches = patches
self.convert_func = convert_func
self.set_func = set_func
def __call__(self, weight):
intermediate_dtype = weight.dtype
if self.convert_func is not None:
weight = self.convert_func(weight.to(dtype=torch.float32, copy=True), inplace=True)
if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
intermediate_dtype = torch.float32
return comfy.float.stochastic_rounding(comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype), weight.dtype, seed=string_to_seed(self.key))
out = comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype)
if self.set_func is None:
return comfy.float.stochastic_rounding(out, weight.dtype, seed=string_to_seed(self.key))
else:
return self.set_func(out, seed=string_to_seed(self.key), return_weight=True)
return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
out = comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
if self.set_func is not None:
return self.set_func(out, seed=string_to_seed(self.key), return_weight=True).to(dtype=intermediate_dtype)
else:
return out
def get_key_weight(model, key):
set_func = None
@@ -657,13 +671,15 @@ class ModelPatcher:
if force_patch_weights:
self.patch_weight_to_device(weight_key)
else:
m.weight_function = [LowVramPatch(weight_key, self.patches)]
_, set_func, convert_func = get_key_weight(self.model, weight_key)
m.weight_function = [LowVramPatch(weight_key, self.patches, convert_func, set_func)]
patch_counter += 1
if bias_key in self.patches:
if force_patch_weights:
self.patch_weight_to_device(bias_key)
else:
m.bias_function = [LowVramPatch(bias_key, self.patches)]
_, set_func, convert_func = get_key_weight(self.model, bias_key)
m.bias_function = [LowVramPatch(bias_key, self.patches, convert_func, set_func)]
patch_counter += 1
cast_weight = True
@@ -825,10 +841,12 @@ class ModelPatcher:
module_mem += move_weight_functions(m, device_to)
if lowvram_possible:
if weight_key in self.patches:
m.weight_function.append(LowVramPatch(weight_key, self.patches))
_, set_func, convert_func = get_key_weight(self.model, weight_key)
m.weight_function.append(LowVramPatch(weight_key, self.patches, convert_func, set_func))
patch_counter += 1
if bias_key in self.patches:
m.bias_function.append(LowVramPatch(bias_key, self.patches))
_, set_func, convert_func = get_key_weight(self.model, bias_key)
m.bias_function.append(LowVramPatch(bias_key, self.patches, convert_func, set_func))
patch_counter += 1
cast_weight = True
+17 -11
View File
@@ -21,17 +21,23 @@ def rescale_zero_terminal_snr_sigmas(sigmas):
alphas_bar[-1] = 4.8973451890853435e-08
return ((1 - alphas_bar) / alphas_bar) ** 0.5
def reshape_sigma(sigma, noise_dim):
if sigma.nelement() == 1:
return sigma.view(())
else:
return sigma.view(sigma.shape[:1] + (1,) * (noise_dim - 1))
class EPS:
def calculate_input(self, sigma, noise):
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
sigma = reshape_sigma(sigma, noise.ndim)
return noise / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
def calculate_denoised(self, sigma, model_output, model_input):
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
sigma = reshape_sigma(sigma, model_output.ndim)
return model_input - model_output * sigma
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
sigma = reshape_sigma(sigma, noise.ndim)
if max_denoise:
noise = noise * torch.sqrt(1.0 + sigma ** 2.0)
else:
@@ -45,12 +51,12 @@ class EPS:
class V_PREDICTION(EPS):
def calculate_denoised(self, sigma, model_output, model_input):
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
sigma = reshape_sigma(sigma, model_output.ndim)
return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
class EDM(V_PREDICTION):
def calculate_denoised(self, sigma, model_output, model_input):
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
sigma = reshape_sigma(sigma, model_output.ndim)
return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) + model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
class CONST:
@@ -58,15 +64,15 @@ class CONST:
return noise
def calculate_denoised(self, sigma, model_output, model_input):
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
sigma = reshape_sigma(sigma, model_output.ndim)
return model_input - model_output * sigma
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
sigma = reshape_sigma(sigma, noise.ndim)
return sigma * noise + (1.0 - sigma) * latent_image
def inverse_noise_scaling(self, sigma, latent):
sigma = sigma.view(sigma.shape[:1] + (1,) * (latent.ndim - 1))
sigma = reshape_sigma(sigma, latent.ndim)
return latent / (1.0 - sigma)
class X0(EPS):
@@ -80,16 +86,16 @@ class IMG_TO_IMG(X0):
class COSMOS_RFLOW:
def calculate_input(self, sigma, noise):
sigma = (sigma / (sigma + 1))
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
sigma = reshape_sigma(sigma, noise.ndim)
return noise * (1.0 - sigma)
def calculate_denoised(self, sigma, model_output, model_input):
sigma = (sigma / (sigma + 1))
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
sigma = reshape_sigma(sigma, model_output.ndim)
return model_input * (1.0 - sigma) - model_output * sigma
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
sigma = reshape_sigma(sigma, noise.ndim)
noise = noise * sigma
noise += latent_image
return noise
+3 -1
View File
@@ -416,8 +416,10 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None
else:
return weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype)
def set_weight(self, weight, inplace_update=False, seed=None, **kwargs):
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
if return_weight:
return weight
if inplace_update:
self.weight.data.copy_(weight)
else:
+31 -2
View File
@@ -18,6 +18,7 @@ import comfy.ldm.wan.vae2_2
import comfy.ldm.hunyuan3d.vae
import comfy.ldm.ace.vae.music_dcae_pipeline
import comfy.ldm.hunyuan_video.vae
import comfy.ldm.mmaudio.vae.autoencoder
import comfy.pixel_space_convert
import yaml
import math
@@ -275,8 +276,13 @@ class VAE:
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
sd = diffusers_convert.convert_vae_state_dict(sd)
self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower)
self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype)
if model_management.is_amd():
VAE_KL_MEM_RATIO = 2.73
else:
VAE_KL_MEM_RATIO = 1.0
self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) * VAE_KL_MEM_RATIO #These are for AutoencoderKL and need tweaking (should be lower)
self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype) * VAE_KL_MEM_RATIO
self.downscale_ratio = 8
self.upscale_ratio = 8
self.latent_channels = 4
@@ -291,6 +297,7 @@ class VAE:
self.downscale_index_formula = None
self.upscale_index_formula = None
self.extra_1d_channel = None
self.crop_input = True
if config is None:
if "decoder.mid.block_1.mix_factor" in sd:
@@ -542,6 +549,25 @@ class VAE:
self.latent_channels = 3
self.latent_dim = 2
self.output_channels = 3
elif "vocoder.activation_post.downsample.lowpass.filter" in sd: #MMAudio VAE
sample_rate = 16000
if sample_rate == 16000:
mode = '16k'
else:
mode = '44k'
self.first_stage_model = comfy.ldm.mmaudio.vae.autoencoder.AudioAutoencoder(mode=mode)
self.memory_used_encode = lambda shape, dtype: (30 * shape[2]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (90 * shape[2] * 1411.2) * model_management.dtype_size(dtype)
self.latent_channels = 20
self.output_channels = 2
self.upscale_ratio = 512 * (44100 / sample_rate)
self.downscale_ratio = 512 * (44100 / sample_rate)
self.latent_dim = 1
self.process_output = lambda audio: audio
self.process_input = lambda audio: audio
self.working_dtypes = [torch.float32]
self.crop_input = False
else:
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
self.first_stage_model = None
@@ -575,6 +601,9 @@ class VAE:
raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.")
def vae_encode_crop_pixels(self, pixels):
if not self.crop_input:
return pixels
downscale_ratio = self.spacial_compression_encode()
dims = pixels.shape[1:-1]
+5 -1
View File
@@ -39,7 +39,11 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in
pass
ModelCheckpoint.__module__ = "pytorch_lightning.callbacks.model_checkpoint"
from numpy.core.multiarray import scalar
def scalar(*args, **kwargs):
from numpy.core.multiarray import scalar as sc
return sc(*args, **kwargs)
scalar.__module__ = "numpy.core.multiarray"
from numpy import dtype
from numpy.dtypes import Float64DType
from _codecs import encode
+7 -2
View File
@@ -8,8 +8,8 @@ from comfy_api.internal.async_to_sync import create_sync_class
from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents
from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents
from comfy_api.latest._io import _IO as io #noqa: F401
from comfy_api.latest._ui import _UI as ui #noqa: F401
from . import _io as io
from . import _ui as ui
# from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401
from comfy_execution.utils import get_executing_context
from comfy_execution.progress import get_progress_state, PreviewImageTuple
@@ -114,6 +114,8 @@ if TYPE_CHECKING:
ComfyAPISync: Type[comfy_api.latest.generated.ComfyAPISyncStub.ComfyAPISyncStub]
ComfyAPISync = create_sync_class(ComfyAPI_latest)
comfy_io = io # create the new alias for io
__all__ = [
"ComfyAPI",
"ComfyAPISync",
@@ -121,4 +123,7 @@ __all__ = [
"InputImpl",
"Types",
"ComfyExtension",
"io",
"comfy_io",
"ui",
]
+2 -2
View File
@@ -1,6 +1,6 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Optional, Union
from typing import Optional, Union, IO
import io
import av
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
@@ -23,7 +23,7 @@ class VideoInput(ABC):
@abstractmethod
def save_to(
self,
path: str,
path: Union[str, IO[bytes]],
format: VideoContainer = VideoContainer.AUTO,
codec: VideoCodec = VideoCodec.AUTO,
metadata: Optional[dict] = None
+92 -78
View File
@@ -336,11 +336,25 @@ class Combo(ComfyTypeIO):
class Input(WidgetInput):
"""Combo input (dropdown)."""
Type = str
def __init__(self, id: str, options: list[str]=None, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: str=None, control_after_generate: bool=None,
upload: UploadType=None, image_folder: FolderType=None,
remote: RemoteOptions=None,
socketless: bool=None):
def __init__(
self,
id: str,
options: list[str] | list[int] | type[Enum] = None,
display_name: str=None,
optional=False,
tooltip: str=None,
lazy: bool=None,
default: str | int | Enum = None,
control_after_generate: bool=None,
upload: UploadType=None,
image_folder: FolderType=None,
remote: RemoteOptions=None,
socketless: bool=None,
):
if isinstance(options, type) and issubclass(options, Enum):
options = [v.value for v in options]
if isinstance(default, Enum):
default = default.value
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless)
self.multiselect = False
self.options = options
@@ -1568,78 +1582,78 @@ class _UIOutput(ABC):
...
class _IO:
FolderType = FolderType
UploadType = UploadType
RemoteOptions = RemoteOptions
NumberDisplay = NumberDisplay
__all__ = [
"FolderType",
"UploadType",
"RemoteOptions",
"NumberDisplay",
comfytype = staticmethod(comfytype)
Custom = staticmethod(Custom)
Input = Input
WidgetInput = WidgetInput
Output = Output
ComfyTypeI = ComfyTypeI
ComfyTypeIO = ComfyTypeIO
#---------------------------------
"comfytype",
"Custom",
"Input",
"WidgetInput",
"Output",
"ComfyTypeI",
"ComfyTypeIO",
# Supported Types
Boolean = Boolean
Int = Int
Float = Float
String = String
Combo = Combo
MultiCombo = MultiCombo
Image = Image
WanCameraEmbedding = WanCameraEmbedding
Webcam = Webcam
Mask = Mask
Latent = Latent
Conditioning = Conditioning
Sampler = Sampler
Sigmas = Sigmas
Noise = Noise
Guider = Guider
Clip = Clip
ControlNet = ControlNet
Vae = Vae
Model = Model
ClipVision = ClipVision
ClipVisionOutput = ClipVisionOutput
AudioEncoder = AudioEncoder
AudioEncoderOutput = AudioEncoderOutput
StyleModel = StyleModel
Gligen = Gligen
UpscaleModel = UpscaleModel
Audio = Audio
Video = Video
SVG = SVG
LoraModel = LoraModel
LossMap = LossMap
Voxel = Voxel
Mesh = Mesh
Hooks = Hooks
HookKeyframes = HookKeyframes
TimestepsRange = TimestepsRange
LatentOperation = LatentOperation
FlowControl = FlowControl
Accumulation = Accumulation
Load3DCamera = Load3DCamera
Load3D = Load3D
Load3DAnimation = Load3DAnimation
Photomaker = Photomaker
Point = Point
FaceAnalysis = FaceAnalysis
BBOX = BBOX
SEGS = SEGS
AnyType = AnyType
MultiType = MultiType
#---------------------------------
HiddenHolder = HiddenHolder
Hidden = Hidden
NodeInfoV1 = NodeInfoV1
NodeInfoV3 = NodeInfoV3
Schema = Schema
ComfyNode = ComfyNode
NodeOutput = NodeOutput
add_to_dict_v1 = staticmethod(add_to_dict_v1)
add_to_dict_v3 = staticmethod(add_to_dict_v3)
"Boolean",
"Int",
"Float",
"String",
"Combo",
"MultiCombo",
"Image",
"WanCameraEmbedding",
"Webcam",
"Mask",
"Latent",
"Conditioning",
"Sampler",
"Sigmas",
"Noise",
"Guider",
"Clip",
"ControlNet",
"Vae",
"Model",
"ClipVision",
"ClipVisionOutput",
"AudioEncoder",
"AudioEncoderOutput",
"StyleModel",
"Gligen",
"UpscaleModel",
"Audio",
"Video",
"SVG",
"LoraModel",
"LossMap",
"Voxel",
"Mesh",
"Hooks",
"HookKeyframes",
"TimestepsRange",
"LatentOperation",
"FlowControl",
"Accumulation",
"Load3DCamera",
"Load3D",
"Load3DAnimation",
"Photomaker",
"Point",
"FaceAnalysis",
"BBOX",
"SEGS",
"AnyType",
"MultiType",
# Other classes
"HiddenHolder",
"Hidden",
"NodeInfoV1",
"NodeInfoV3",
"Schema",
"ComfyNode",
"NodeOutput",
"add_to_dict_v1",
"add_to_dict_v3",
]
+13 -12
View File
@@ -449,15 +449,16 @@ class PreviewText(_UIOutput):
return {"text": (self.value,)}
class _UI:
SavedResult = SavedResult
SavedImages = SavedImages
SavedAudios = SavedAudios
ImageSaveHelper = ImageSaveHelper
AudioSaveHelper = AudioSaveHelper
PreviewImage = PreviewImage
PreviewMask = PreviewMask
PreviewAudio = PreviewAudio
PreviewVideo = PreviewVideo
PreviewUI3D = PreviewUI3D
PreviewText = PreviewText
__all__ = [
"SavedResult",
"SavedImages",
"SavedAudios",
"ImageSaveHelper",
"AudioSaveHelper",
"PreviewImage",
"PreviewMask",
"PreviewAudio",
"PreviewVideo",
"PreviewUI3D",
"PreviewText",
]
+2 -2
View File
@@ -269,7 +269,7 @@ def tensor_to_bytesio(
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4').
Returns:
Named BytesIO object containing the image data.
Named BytesIO object containing the image data, with pointer set to the start of buffer.
"""
if not mime_type:
mime_type = "image/png"
@@ -431,7 +431,7 @@ async def upload_video_to_comfyapi(
f"Video duration ({actual_duration:.2f}s) exceeds the maximum allowed ({max_duration}s)."
)
except Exception as e:
logging.error(f"Error getting video duration: {e}")
logging.error("Error getting video duration: %s", str(e))
raise ValueError(f"Could not verify video duration from source: {e}") from e
upload_mime_type = f"video/{container.value.lower()}"
+73 -55
View File
@@ -98,7 +98,7 @@ import io
import os
import socket
from aiohttp.client_exceptions import ClientError, ClientResponseError
from typing import Dict, Type, Optional, Any, TypeVar, Generic, Callable, Tuple
from typing import Type, Optional, Any, TypeVar, Generic, Callable
from enum import Enum
import json
from urllib.parse import urljoin, urlparse
@@ -175,7 +175,7 @@ class ApiClient:
max_retries: int = 3,
retry_delay: float = 1.0,
retry_backoff_factor: float = 2.0,
retry_status_codes: Optional[Tuple[int, ...]] = None,
retry_status_codes: Optional[tuple[int, ...]] = None,
session: Optional[aiohttp.ClientSession] = None,
):
self.base_url = base_url
@@ -199,9 +199,9 @@ class ApiClient:
@staticmethod
def _create_json_payload_args(
data: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None,
) -> Dict[str, Any]:
data: Optional[dict[str, Any]] = None,
headers: Optional[dict[str, str]] = None,
) -> dict[str, Any]:
return {
"json": data,
"headers": headers,
@@ -209,11 +209,11 @@ class ApiClient:
def _create_form_data_args(
self,
data: Dict[str, Any] | None,
files: Dict[str, Any] | None,
headers: Optional[Dict[str, str]] = None,
data: dict[str, Any] | None,
files: dict[str, Any] | None,
headers: Optional[dict[str, str]] = None,
multipart_parser: Callable | None = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
if headers and "Content-Type" in headers:
del headers["Content-Type"]
@@ -254,9 +254,9 @@ class ApiClient:
@staticmethod
def _create_urlencoded_form_data_args(
data: Dict[str, Any],
headers: Optional[Dict[str, str]] = None,
) -> Dict[str, Any]:
data: dict[str, Any],
headers: Optional[dict[str, str]] = None,
) -> dict[str, Any]:
headers = headers or {}
headers["Content-Type"] = "application/x-www-form-urlencoded"
return {
@@ -264,7 +264,7 @@ class ApiClient:
"headers": headers,
}
def get_headers(self) -> Dict[str, str]:
def get_headers(self) -> dict[str, str]:
"""Get headers for API requests, including authentication if available"""
headers = {"Content-Type": "application/json", "Accept": "application/json"}
@@ -275,7 +275,7 @@ class ApiClient:
return headers
async def _check_connectivity(self, target_url: str) -> Dict[str, bool]:
async def _check_connectivity(self, target_url: str) -> dict[str, bool]:
"""
Check connectivity to determine if network issues are local or server-related.
@@ -316,14 +316,14 @@ class ApiClient:
self,
method: str,
path: str,
params: Optional[Dict[str, Any]] = None,
data: Optional[Dict[str, Any]] = None,
files: Optional[Dict[str, Any] | list[tuple[str, Any]]] = None,
headers: Optional[Dict[str, str]] = None,
params: Optional[dict[str, Any]] = None,
data: Optional[dict[str, Any]] = None,
files: Optional[dict[str, Any] | list[tuple[str, Any]]] = None,
headers: Optional[dict[str, str]] = None,
content_type: str = "application/json",
multipart_parser: Callable | None = None,
retry_count: int = 0, # Used internally for tracking retries
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""
Make an HTTP request to the API with automatic retries for transient errors.
@@ -359,10 +359,10 @@ class ApiClient:
if params:
params = {k: v for k, v in params.items() if v is not None} # aiohttp fails to serialize None values
logging.debug(f"[DEBUG] Request Headers: {request_headers}")
logging.debug(f"[DEBUG] Files: {files}")
logging.debug(f"[DEBUG] Params: {params}")
logging.debug(f"[DEBUG] Data: {data}")
logging.debug("[DEBUG] Request Headers: %s", request_headers)
logging.debug("[DEBUG] Files: %s", files)
logging.debug("[DEBUG] Params: %s", params)
logging.debug("[DEBUG] Data: %s", data)
if content_type == "application/x-www-form-urlencoded":
payload_args = self._create_urlencoded_form_data_args(data or {}, request_headers)
@@ -485,7 +485,7 @@ class ApiClient:
retry_delay: Initial delay between retries in seconds
retry_backoff_factor: Multiplier for the delay after each retry
"""
headers: Dict[str, str] = {}
headers: dict[str, str] = {}
skip_auto_headers: set[str] = set()
if content_type:
headers["Content-Type"] = content_type
@@ -558,7 +558,7 @@ class ApiClient:
*req_meta,
retry_count: int,
response_content: dict | str = "",
) -> Dict[str, Any]:
) -> dict[str, Any]:
status_code = exc.status
if status_code == 401:
user_friendly = "Unauthorized: Please login first to use this node."
@@ -592,9 +592,9 @@ class ApiClient:
error_message=f"HTTP Error {exc.status}",
)
logging.debug(f"[DEBUG] API Error: {user_friendly} (Status: {status_code})")
logging.debug("[DEBUG] API Error: %s (Status: %s)", user_friendly, status_code)
if response_content:
logging.debug(f"[DEBUG] Response content: {response_content}")
logging.debug("[DEBUG] Response content: %s", response_content)
# Retry if eligible
if status_code in self.retry_status_codes and retry_count < self.max_retries:
@@ -659,7 +659,7 @@ class ApiEndpoint(Generic[T, R]):
method: HttpMethod,
request_model: Type[T],
response_model: Type[R],
query_params: Optional[Dict[str, Any]] = None,
query_params: Optional[dict[str, Any]] = None,
):
"""Initialize an API endpoint definition.
@@ -684,11 +684,11 @@ class SynchronousOperation(Generic[T, R]):
self,
endpoint: ApiEndpoint[T, R],
request: T,
files: Optional[Dict[str, Any] | list[tuple[str, Any]]] = None,
files: Optional[dict[str, Any] | list[tuple[str, Any]]] = None,
api_base: str | None = None,
auth_token: Optional[str] = None,
comfy_api_key: Optional[str] = None,
auth_kwargs: Optional[Dict[str, str]] = None,
auth_kwargs: Optional[dict[str, str]] = None,
timeout: float = 7200.0,
verify_ssl: bool = True,
content_type: str = "application/json",
@@ -729,7 +729,7 @@ class SynchronousOperation(Generic[T, R]):
)
try:
request_dict: Optional[Dict[str, Any]]
request_dict: Optional[dict[str, Any]]
if isinstance(self.request, EmptyRequest):
request_dict = None
else:
@@ -738,11 +738,9 @@ class SynchronousOperation(Generic[T, R]):
if isinstance(v, Enum):
request_dict[k] = v.value
logging.debug(
f"[DEBUG] API Request: {self.endpoint.method.value} {self.endpoint.path}"
)
logging.debug(f"[DEBUG] Request Data: {json.dumps(request_dict, indent=2)}")
logging.debug(f"[DEBUG] Query Params: {self.endpoint.query_params}")
logging.debug("[DEBUG] API Request: %s %s", self.endpoint.method.value, self.endpoint.path)
logging.debug("[DEBUG] Request Data: %s", json.dumps(request_dict, indent=2))
logging.debug("[DEBUG] Query Params: %s", self.endpoint.query_params)
response_json = await client.request(
self.endpoint.method.value,
@@ -757,11 +755,11 @@ class SynchronousOperation(Generic[T, R]):
logging.debug("=" * 50)
logging.debug("[DEBUG] RESPONSE DETAILS:")
logging.debug("[DEBUG] Status Code: 200 (Success)")
logging.debug(f"[DEBUG] Response Body: {json.dumps(response_json, indent=2)}")
logging.debug("[DEBUG] Response Body: %s", json.dumps(response_json, indent=2))
logging.debug("=" * 50)
parsed_response = self.endpoint.response_model.model_validate(response_json)
logging.debug(f"[DEBUG] Parsed Response: {parsed_response}")
logging.debug("[DEBUG] Parsed Response: %s", parsed_response)
return parsed_response
finally:
if owns_client:
@@ -784,14 +782,16 @@ class PollingOperation(Generic[T, R]):
poll_endpoint: ApiEndpoint[EmptyRequest, R],
completed_statuses: list[str],
failed_statuses: list[str],
status_extractor: Callable[[R], str],
progress_extractor: Callable[[R], float] | None = None,
result_url_extractor: Callable[[R], str] | None = None,
*,
status_extractor: Callable[[R], Optional[str]],
progress_extractor: Callable[[R], Optional[float]] | None = None,
result_url_extractor: Callable[[R], Optional[str]] | None = None,
price_extractor: Callable[[R], Optional[float]] | None = None,
request: Optional[T] = None,
api_base: str | None = None,
auth_token: Optional[str] = None,
comfy_api_key: Optional[str] = None,
auth_kwargs: Optional[Dict[str, str]] = None,
auth_kwargs: Optional[dict[str, str]] = None,
poll_interval: float = 5.0,
max_poll_attempts: int = 120, # Default max polling attempts (10 minutes with 5s interval)
max_retries: int = 3, # Max retries per individual API call
@@ -817,10 +817,12 @@ class PollingOperation(Generic[T, R]):
self.status_extractor = status_extractor or (lambda x: getattr(x, "status", None))
self.progress_extractor = progress_extractor
self.result_url_extractor = result_url_extractor
self.price_extractor = price_extractor
self.node_id = node_id
self.completed_statuses = completed_statuses
self.failed_statuses = failed_statuses
self.final_response: Optional[R] = None
self.extracted_price: Optional[float] = None
async def execute(self, client: Optional[ApiClient] = None) -> R:
owns_client = client is None
@@ -842,6 +844,8 @@ class PollingOperation(Generic[T, R]):
def _display_text_on_node(self, text: str):
if not self.node_id:
return
if self.extracted_price is not None:
text = f"Price: {self.extracted_price}$\n{text}"
PromptServer.instance.send_progress_text(text, self.node_id)
def _display_time_progress_on_node(self, time_completed: int | float):
@@ -877,18 +881,19 @@ class PollingOperation(Generic[T, R]):
status = TaskStatus.PENDING
for poll_count in range(1, self.max_poll_attempts + 1):
try:
logging.debug(f"[DEBUG] Polling attempt #{poll_count}")
logging.debug("[DEBUG] Polling attempt #%s", poll_count)
request_dict = (
None if self.request is None else self.request.model_dump(exclude_none=True)
)
request_dict = None if self.request is None else self.request.model_dump(exclude_none=True)
if poll_count == 1:
logging.debug(
f"[DEBUG] Poll Request: {self.poll_endpoint.method.value} {self.poll_endpoint.path}"
"[DEBUG] Poll Request: %s %s",
self.poll_endpoint.method.value,
self.poll_endpoint.path,
)
logging.debug(
f"[DEBUG] Poll Request Data: {json.dumps(request_dict, indent=2) if request_dict else 'None'}"
"[DEBUG] Poll Request Data: %s",
json.dumps(request_dict, indent=2) if request_dict else "None",
)
# Query task status
@@ -903,7 +908,7 @@ class PollingOperation(Generic[T, R]):
# Check if task is complete
status = self._check_task_status(response_obj)
logging.debug(f"[DEBUG] Task Status: {status}")
logging.debug("[DEBUG] Task Status: %s", status)
# If progress extractor is provided, extract progress
if self.progress_extractor:
@@ -911,13 +916,18 @@ class PollingOperation(Generic[T, R]):
if new_progress is not None:
progress.update_absolute(new_progress, total=PROGRESS_BAR_MAX)
if self.price_extractor:
price = self.price_extractor(response_obj)
if price is not None:
self.extracted_price = price
if status == TaskStatus.COMPLETED:
message = "Task completed successfully"
if self.result_url_extractor:
result_url = self.result_url_extractor(response_obj)
if result_url:
message = f"Result URL: {result_url}"
logging.debug(f"[DEBUG] {message}")
logging.debug("[DEBUG] %s", message)
self._display_text_on_node(message)
self.final_response = response_obj
if self.progress_extractor:
@@ -925,7 +935,7 @@ class PollingOperation(Generic[T, R]):
return self.final_response
if status == TaskStatus.FAILED:
message = f"Task failed: {json.dumps(resp)}"
logging.error(f"[DEBUG] {message}")
logging.error("[DEBUG] %s", message)
raise Exception(message)
logging.debug("[DEBUG] Task still pending, continuing to poll...")
# Task pending wait
@@ -939,7 +949,12 @@ class PollingOperation(Generic[T, R]):
raise Exception(
f"Polling aborted after {consecutive_errors} network errors: {str(e)}"
) from e
logging.warning("Network error (%s/%s): %s", consecutive_errors, max_consecutive_errors, str(e))
logging.warning(
"Network error (%s/%s): %s",
consecutive_errors,
max_consecutive_errors,
str(e),
)
await asyncio.sleep(self.poll_interval)
except Exception as e:
# For other errors, increment count and potentially abort
@@ -949,10 +964,13 @@ class PollingOperation(Generic[T, R]):
f"Polling aborted after {consecutive_errors} consecutive errors: {str(e)}"
) from e
logging.error(f"[DEBUG] Polling error: {str(e)}")
logging.error("[DEBUG] Polling error: %s", str(e))
logging.warning(
f"Error during polling (attempt {poll_count}/{self.max_poll_attempts}): {str(e)}. "
f"Will retry in {self.poll_interval} seconds."
"Error during polling (attempt %s/%s): %s. Will retry in %s seconds.",
poll_count,
self.max_poll_attempts,
str(e),
self.poll_interval,
)
await asyncio.sleep(self.poll_interval)
+10 -7
View File
@@ -1,19 +1,22 @@
from __future__ import annotations
from typing import List, Optional
from typing import Optional
from comfy_api_nodes.apis import GeminiGenerationConfig, GeminiContent, GeminiSafetySetting, GeminiSystemInstructionContent, GeminiTool, GeminiVideoMetadata
from pydantic import BaseModel
class GeminiImageConfig(BaseModel):
aspectRatio: Optional[str] = None
class GeminiImageGenerationConfig(GeminiGenerationConfig):
responseModalities: Optional[List[str]] = None
responseModalities: Optional[list[str]] = None
imageConfig: Optional[GeminiImageConfig] = None
class GeminiImageGenerateContentRequest(BaseModel):
contents: List[GeminiContent]
contents: list[GeminiContent]
generationConfig: Optional[GeminiImageGenerationConfig] = None
safetySettings: Optional[List[GeminiSafetySetting]] = None
safetySettings: Optional[list[GeminiSafetySetting]] = None
systemInstruction: Optional[GeminiSystemInstructionContent] = None
tools: Optional[List[GeminiTool]] = None
tools: Optional[list[GeminiTool]] = None
videoMetadata: Optional[GeminiVideoMetadata] = None
+100
View File
@@ -0,0 +1,100 @@
from typing import Optional
from enum import Enum
from pydantic import BaseModel, Field
class Pikaffect(str, Enum):
Cake_ify = "Cake-ify"
Crumble = "Crumble"
Crush = "Crush"
Decapitate = "Decapitate"
Deflate = "Deflate"
Dissolve = "Dissolve"
Explode = "Explode"
Eye_pop = "Eye-pop"
Inflate = "Inflate"
Levitate = "Levitate"
Melt = "Melt"
Peel = "Peel"
Poke = "Poke"
Squish = "Squish"
Ta_da = "Ta-da"
Tear = "Tear"
class PikaBodyGenerate22C2vGenerate22PikascenesPost(BaseModel):
aspectRatio: Optional[float] = Field(None, description='Aspect ratio (width / height)')
duration: Optional[int] = Field(5)
ingredientsMode: str = Field(...)
negativePrompt: Optional[str] = Field(None)
promptText: Optional[str] = Field(None)
resolution: Optional[str] = Field('1080p')
seed: Optional[int] = Field(None)
class PikaGenerateResponse(BaseModel):
video_id: str = Field(...)
class PikaBodyGenerate22I2vGenerate22I2vPost(BaseModel):
duration: Optional[int] = 5
negativePrompt: Optional[str] = Field(None)
promptText: Optional[str] = Field(None)
resolution: Optional[str] = '1080p'
seed: Optional[int] = Field(None)
class PikaBodyGenerate22KeyframeGenerate22PikaframesPost(BaseModel):
duration: Optional[int] = Field(None, ge=5, le=10)
negativePrompt: Optional[str] = Field(None)
promptText: str = Field(...)
resolution: Optional[str] = '1080p'
seed: Optional[int] = Field(None)
class PikaBodyGenerate22T2vGenerate22T2vPost(BaseModel):
aspectRatio: Optional[float] = Field(
1.7777777777777777,
description='Aspect ratio (width / height)',
ge=0.4,
le=2.5,
)
duration: Optional[int] = 5
negativePrompt: Optional[str] = Field(None)
promptText: str = Field(...)
resolution: Optional[str] = '1080p'
seed: Optional[int] = Field(None)
class PikaBodyGeneratePikadditionsGeneratePikadditionsPost(BaseModel):
negativePrompt: Optional[str] = Field(None)
promptText: Optional[str] = Field(None)
seed: Optional[int] = Field(None)
class PikaBodyGeneratePikaffectsGeneratePikaffectsPost(BaseModel):
negativePrompt: Optional[str] = Field(None)
pikaffect: Optional[str] = None
promptText: Optional[str] = Field(None)
seed: Optional[int] = Field(None)
class PikaBodyGeneratePikaswapsGeneratePikaswapsPost(BaseModel):
negativePrompt: Optional[str] = Field(None)
promptText: Optional[str] = Field(None)
seed: Optional[int] = Field(None)
modifyRegionRoi: Optional[str] = Field(None)
class PikaStatusEnum(str, Enum):
queued = "queued"
started = "started"
finished = "finished"
failed = "failed"
class PikaVideoResponse(BaseModel):
id: str = Field(...)
progress: Optional[int] = Field(None)
status: PikaStatusEnum
url: Optional[str] = Field(None)
+3 -3
View File
@@ -21,7 +21,7 @@ def get_log_directory():
try:
os.makedirs(log_dir, exist_ok=True)
except Exception as e:
logger.error(f"Error creating API log directory {log_dir}: {e}")
logger.error("Error creating API log directory %s: %s", log_dir, str(e))
# Fallback to base temp directory if sub-directory creation fails
return base_temp_dir
return log_dir
@@ -122,9 +122,9 @@ def log_request_response(
try:
with open(filepath, "w", encoding="utf-8") as f:
f.write("\n".join(log_content))
logger.debug(f"API log saved to: {filepath}")
logger.debug("API log saved to: %s", filepath)
except Exception as e:
logger.error(f"Error writing API log to {filepath}: {e}")
logger.error("Error writing API log to %s: %s", filepath, str(e))
if __name__ == '__main__':
+8 -8
View File
@@ -249,8 +249,8 @@ class ByteDanceImageNode(comfy_io.ComfyNode):
inputs=[
comfy_io.Combo.Input(
"model",
options=[model.value for model in Text2ImageModelName],
default=Text2ImageModelName.seedream_3.value,
options=Text2ImageModelName,
default=Text2ImageModelName.seedream_3,
tooltip="Model name",
),
comfy_io.String.Input(
@@ -382,8 +382,8 @@ class ByteDanceImageEditNode(comfy_io.ComfyNode):
inputs=[
comfy_io.Combo.Input(
"model",
options=[model.value for model in Image2ImageModelName],
default=Image2ImageModelName.seededit_3.value,
options=Image2ImageModelName,
default=Image2ImageModelName.seededit_3,
tooltip="Model name",
),
comfy_io.Image.Input(
@@ -676,8 +676,8 @@ class ByteDanceTextToVideoNode(comfy_io.ComfyNode):
inputs=[
comfy_io.Combo.Input(
"model",
options=[model.value for model in Text2VideoModelName],
default=Text2VideoModelName.seedance_1_pro.value,
options=Text2VideoModelName,
default=Text2VideoModelName.seedance_1_pro,
tooltip="Model name",
),
comfy_io.String.Input(
@@ -793,8 +793,8 @@ class ByteDanceImageToVideoNode(comfy_io.ComfyNode):
inputs=[
comfy_io.Combo.Input(
"model",
options=[model.value for model in Image2VideoModelName],
default=Image2VideoModelName.seedance_1_pro.value,
options=Image2VideoModelName,
default=Image2VideoModelName.seedance_1_pro,
tooltip="Model name",
),
comfy_io.String.Input(
+18 -6
View File
@@ -26,7 +26,7 @@ from comfy_api_nodes.apis import (
GeminiPart,
GeminiMimeType,
)
from comfy_api_nodes.apis.gemini_api import GeminiImageGenerationConfig, GeminiImageGenerateContentRequest
from comfy_api_nodes.apis.gemini_api import GeminiImageGenerationConfig, GeminiImageGenerateContentRequest, GeminiImageConfig
from comfy_api_nodes.apis.client import (
ApiEndpoint,
HttpMethod,
@@ -63,6 +63,7 @@ class GeminiImageModel(str, Enum):
"""
gemini_2_5_flash_image_preview = "gemini-2.5-flash-image-preview"
gemini_2_5_flash_image = "gemini-2.5-flash-image"
def get_gemini_endpoint(
@@ -538,7 +539,7 @@ class GeminiImage(ComfyNodeABC):
{
"tooltip": "The Gemini model to use for generating responses.",
"options": [model.value for model in GeminiImageModel],
"default": GeminiImageModel.gemini_2_5_flash_image_preview.value,
"default": GeminiImageModel.gemini_2_5_flash_image.value,
},
),
"seed": (
@@ -579,6 +580,14 @@ class GeminiImage(ComfyNodeABC):
# "tooltip": "How many images to generate",
# },
# ),
"aspect_ratio": (
IO.COMBO,
{
"tooltip": "Defaults to matching the output image size to that of your input image, or otherwise generates 1:1 squares.",
"options": ["auto", "1:1", "2:3", "3:2", "3:4", "4:3", "4:5", "5:4", "9:16", "16:9", "21:9"],
"default": "auto",
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
@@ -600,15 +609,17 @@ class GeminiImage(ComfyNodeABC):
images: Optional[IO.IMAGE] = None,
files: Optional[list[GeminiPart]] = None,
n=1,
aspect_ratio: str = "auto",
unique_id: Optional[str] = None,
**kwargs,
):
# Validate inputs
validate_string(prompt, strip_whitespace=True, min_length=1)
# Create parts list with text prompt as the first part
parts: list[GeminiPart] = [create_text_part(prompt)]
# Add other modal parts
if not aspect_ratio:
aspect_ratio = "auto" # for backward compatability with old workflows; to-do remove this in December
image_config = GeminiImageConfig(aspectRatio=aspect_ratio)
if images is not None:
image_parts = create_image_parts(images)
parts.extend(image_parts)
@@ -625,7 +636,8 @@ class GeminiImage(ComfyNodeABC):
),
],
generationConfig=GeminiImageGenerationConfig(
responseModalities=["TEXT","IMAGE"]
responseModalities=["TEXT","IMAGE"],
imageConfig=None if aspect_ratio == "auto" else image_config,
)
),
auth_kwargs=kwargs,
+31 -26
View File
@@ -73,6 +73,7 @@ from comfy_api_nodes.util.validation_utils import (
validate_video_dimensions,
validate_video_duration,
)
from comfy_api.input_impl import VideoFromFile
from comfy_api.input.basic_types import AudioInput
from comfy_api.input.video_types import VideoInput
from comfy_api.latest import ComfyExtension, io as comfy_io
@@ -296,7 +297,7 @@ def validate_video_result_response(response) -> None:
"""Validates that the Kling task result contains a video."""
if not is_valid_video_response(response):
error_msg = f"Kling task {response.data.task_id} succeeded but no video data found in response."
logging.error(f"Error: {error_msg}.\nResponse: {response}")
logging.error("Error: %s.\nResponse: %s", error_msg, response)
raise Exception(error_msg)
@@ -304,7 +305,7 @@ def validate_image_result_response(response) -> None:
"""Validates that the Kling task result contains an image."""
if not is_valid_image_response(response):
error_msg = f"Kling task {response.data.task_id} succeeded but no image data found in response."
logging.error(f"Error: {error_msg}.\nResponse: {response}")
logging.error("Error: %s.\nResponse: %s", error_msg, response)
raise Exception(error_msg)
@@ -511,7 +512,7 @@ async def execute_video_effect(
image_1: torch.Tensor,
image_2: Optional[torch.Tensor] = None,
model_mode: Optional[KlingVideoGenMode] = None,
) -> comfy_io.NodeOutput:
) -> tuple[VideoFromFile, str, str]:
if dual_character:
request_input_field = KlingDualCharacterEffectInput(
model_name=model_name,
@@ -562,7 +563,7 @@ async def execute_video_effect(
validate_video_result_response(final_response)
video = get_video_from_response(final_response)
return comfy_io.NodeOutput(await download_url_to_video_output(str(video.url)), str(video.id), str(video.duration))
return await download_url_to_video_output(str(video.url)), str(video.id), str(video.duration)
async def execute_lipsync(
@@ -647,7 +648,7 @@ class KlingCameraControls(comfy_io.ComfyNode):
category="api node/video/Kling",
description="Allows specifying configuration options for Kling Camera Controls and motion control effects.",
inputs=[
comfy_io.Combo.Input("camera_control_type", options=[i.value for i in KlingCameraControlType]),
comfy_io.Combo.Input("camera_control_type", options=KlingCameraControlType),
comfy_io.Float.Input(
"horizontal_movement",
default=0.0,
@@ -772,7 +773,7 @@ class KlingTextToVideoNode(comfy_io.ComfyNode):
comfy_io.Float.Input("cfg_scale", default=1.0, min=0.0, max=1.0),
comfy_io.Combo.Input(
"aspect_ratio",
options=[i.value for i in KlingVideoGenAspectRatio],
options=KlingVideoGenAspectRatio,
default="16:9",
),
comfy_io.Combo.Input(
@@ -840,7 +841,7 @@ class KlingCameraControlT2VNode(comfy_io.ComfyNode):
comfy_io.Float.Input("cfg_scale", default=0.75, min=0.0, max=1.0),
comfy_io.Combo.Input(
"aspect_ratio",
options=[i.value for i in KlingVideoGenAspectRatio],
options=KlingVideoGenAspectRatio,
default="16:9",
),
comfy_io.Custom("CAMERA_CONTROL").Input(
@@ -903,17 +904,17 @@ class KlingImage2VideoNode(comfy_io.ComfyNode):
comfy_io.String.Input("negative_prompt", multiline=True, tooltip="Negative text prompt"),
comfy_io.Combo.Input(
"model_name",
options=[i.value for i in KlingVideoGenModelName],
options=KlingVideoGenModelName,
default="kling-v2-master",
),
comfy_io.Float.Input("cfg_scale", default=0.8, min=0.0, max=1.0),
comfy_io.Combo.Input("mode", options=[i.value for i in KlingVideoGenMode], default="std"),
comfy_io.Combo.Input("mode", options=KlingVideoGenMode, default=KlingVideoGenMode.std),
comfy_io.Combo.Input(
"aspect_ratio",
options=[i.value for i in KlingVideoGenAspectRatio],
default="16:9",
options=KlingVideoGenAspectRatio,
default=KlingVideoGenAspectRatio.field_16_9,
),
comfy_io.Combo.Input("duration", options=[i.value for i in KlingVideoGenDuration], default="5"),
comfy_io.Combo.Input("duration", options=KlingVideoGenDuration, default=KlingVideoGenDuration.field_5),
],
outputs=[
comfy_io.Video.Output(),
@@ -984,8 +985,8 @@ class KlingCameraControlI2VNode(comfy_io.ComfyNode):
comfy_io.Float.Input("cfg_scale", default=0.75, min=0.0, max=1.0),
comfy_io.Combo.Input(
"aspect_ratio",
options=[i.value for i in KlingVideoGenAspectRatio],
default="16:9",
options=KlingVideoGenAspectRatio,
default=KlingVideoGenAspectRatio.field_16_9,
),
comfy_io.Custom("CAMERA_CONTROL").Input(
"camera_control",
@@ -1271,7 +1272,7 @@ class KlingDualCharacterVideoEffectNode(comfy_io.ComfyNode):
image_1=image_left,
image_2=image_right,
)
return video, duration
return comfy_io.NodeOutput(video, duration)
class KlingSingleImageVideoEffectNode(comfy_io.ComfyNode):
@@ -1320,17 +1321,21 @@ class KlingSingleImageVideoEffectNode(comfy_io.ComfyNode):
model_name: KlingSingleImageEffectModelName,
duration: KlingVideoGenDuration,
) -> comfy_io.NodeOutput:
return await execute_video_effect(
auth_kwargs={
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
node_id=cls.hidden.unique_id,
dual_character=False,
effect_scene=effect_scene,
model_name=model_name,
duration=duration,
image_1=image,
return comfy_io.NodeOutput(
*(
await execute_video_effect(
auth_kwargs={
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
node_id=cls.hidden.unique_id,
dual_character=False,
effect_scene=effect_scene,
model_name=model_name,
duration=duration,
image_1=image,
)
)
)
+9 -9
View File
@@ -181,11 +181,11 @@ class LumaImageGenerationNode(comfy_io.ComfyNode):
),
comfy_io.Combo.Input(
"model",
options=[model.value for model in LumaImageModel],
options=LumaImageModel,
),
comfy_io.Combo.Input(
"aspect_ratio",
options=[ratio.value for ratio in LumaAspectRatio],
options=LumaAspectRatio,
default=LumaAspectRatio.ratio_16_9,
),
comfy_io.Int.Input(
@@ -366,7 +366,7 @@ class LumaImageModifyNode(comfy_io.ComfyNode):
),
comfy_io.Combo.Input(
"model",
options=[model.value for model in LumaImageModel],
options=LumaImageModel,
),
comfy_io.Int.Input(
"seed",
@@ -466,21 +466,21 @@ class LumaTextToVideoGenerationNode(comfy_io.ComfyNode):
),
comfy_io.Combo.Input(
"model",
options=[model.value for model in LumaVideoModel],
options=LumaVideoModel,
),
comfy_io.Combo.Input(
"aspect_ratio",
options=[ratio.value for ratio in LumaAspectRatio],
options=LumaAspectRatio,
default=LumaAspectRatio.ratio_16_9,
),
comfy_io.Combo.Input(
"resolution",
options=[resolution.value for resolution in LumaVideoOutputResolution],
options=LumaVideoOutputResolution,
default=LumaVideoOutputResolution.res_540p,
),
comfy_io.Combo.Input(
"duration",
options=[dur.value for dur in LumaVideoModelOutputDuration],
options=LumaVideoModelOutputDuration,
),
comfy_io.Boolean.Input(
"loop",
@@ -595,7 +595,7 @@ class LumaImageToVideoGenerationNode(comfy_io.ComfyNode):
),
comfy_io.Combo.Input(
"model",
options=[model.value for model in LumaVideoModel],
options=LumaVideoModel,
),
# comfy_io.Combo.Input(
# "aspect_ratio",
@@ -604,7 +604,7 @@ class LumaImageToVideoGenerationNode(comfy_io.ComfyNode):
# ),
comfy_io.Combo.Input(
"resolution",
options=[resolution.value for resolution in LumaVideoOutputResolution],
options=LumaVideoOutputResolution,
default=LumaVideoOutputResolution.res_540p,
),
comfy_io.Combo.Input(
+1 -1
View File
@@ -500,7 +500,7 @@ class MinimaxHailuoVideoNode(comfy_io.ComfyNode):
raise Exception(
f"No video was found in the response. Full response: {file_result.model_dump()}"
)
logging.info(f"Generated video URL: {file_url}")
logging.info("Generated video URL: %s", file_url)
if cls.hidden.unique_id:
if hasattr(file_result.file, "backup_download_url"):
message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}"
+42 -32
View File
@@ -237,7 +237,7 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
audio_stream = None
for stream in input_container.streams:
logging.info(f"Found stream: type={stream.type}, class={type(stream)}")
logging.info("Found stream: type=%s, class=%s", stream.type, type(stream))
if isinstance(stream, av.VideoStream):
# Create output video stream with same parameters
video_stream = output_container.add_stream(
@@ -247,7 +247,7 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
video_stream.height = stream.height
video_stream.pix_fmt = "yuv420p"
logging.info(
f"Added video stream: {stream.width}x{stream.height} @ {stream.average_rate}fps"
"Added video stream: %sx%s @ %sfps", stream.width, stream.height, stream.average_rate
)
elif isinstance(stream, av.AudioStream):
# Create output audio stream with same parameters
@@ -256,9 +256,7 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
)
audio_stream.sample_rate = stream.sample_rate
audio_stream.layout = stream.layout
logging.info(
f"Added audio stream: {stream.sample_rate}Hz, {stream.channels} channels"
)
logging.info("Added audio stream: %sHz, %s channels", stream.sample_rate, stream.channels)
# Calculate target frame count that's divisible by 16
fps = input_container.streams.video[0].average_rate
@@ -288,9 +286,7 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
for packet in video_stream.encode():
output_container.mux(packet)
logging.info(
f"Encoded {frame_count} video frames (target: {target_frames})"
)
logging.info("Encoded %s video frames (target: %s)", frame_count, target_frames)
# Decode and re-encode audio frames
if audio_stream:
@@ -308,7 +304,7 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
for packet in audio_stream.encode():
output_container.mux(packet)
logging.info(f"Encoded {audio_frame_count} audio frames")
logging.info("Encoded %s audio frames", audio_frame_count)
# Close containers
output_container.close()
@@ -335,7 +331,7 @@ def parse_width_height_from_res(resolution: str):
"1:1 (1152 x 1152)": {"width": 1152, "height": 1152},
"4:3 (1536 x 1152)": {"width": 1536, "height": 1152},
"3:4 (1152 x 1536)": {"width": 1152, "height": 1536},
"21:9 (2560 x 1080)": {"width": 2560, "height": 1080},
# "21:9 (2560 x 1080)": {"width": 2560, "height": 1080},
}
return res_map.get(resolution, {"width": 1920, "height": 1080})
@@ -388,11 +384,11 @@ class MoonvalleyImg2VideoNode(comfy_io.ComfyNode):
"negative_prompt",
multiline=True,
default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, "
"artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, "
"flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, "
"cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, "
"blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, "
"wobbly, weird, low quality, plastic, stock footage, video camera, boring",
"artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, "
"flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, "
"cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, "
"blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, "
"wobbly, weird, low quality, plastic, stock footage, video camera, boring",
tooltip="Negative prompt text",
),
comfy_io.Combo.Input(
@@ -403,14 +399,14 @@ class MoonvalleyImg2VideoNode(comfy_io.ComfyNode):
"1:1 (1152 x 1152)",
"4:3 (1536 x 1152)",
"3:4 (1152 x 1536)",
"21:9 (2560 x 1080)",
# "21:9 (2560 x 1080)",
],
default="16:9 (1920 x 1080)",
tooltip="Resolution of the output video",
),
comfy_io.Float.Input(
"prompt_adherence",
default=10.0,
default=4.5,
min=1.0,
max=20.0,
step=1.0,
@@ -424,10 +420,11 @@ class MoonvalleyImg2VideoNode(comfy_io.ComfyNode):
step=1,
display_mode=comfy_io.NumberDisplay.number,
tooltip="Random seed value",
control_after_generate=True,
),
comfy_io.Int.Input(
"steps",
default=100,
default=33,
min=1,
max=100,
step=1,
@@ -468,7 +465,6 @@ class MoonvalleyImg2VideoNode(comfy_io.ComfyNode):
steps=steps,
seed=seed,
guidance_scale=prompt_adherence,
num_frames=128,
width=width_height["width"],
height=width_height["height"],
use_negative_prompts=True,
@@ -526,11 +522,11 @@ class MoonvalleyVideo2VideoNode(comfy_io.ComfyNode):
"negative_prompt",
multiline=True,
default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, "
"artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, "
"flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, "
"cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, "
"blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, "
"wobbly, weird, low quality, plastic, stock footage, video camera, boring",
"artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, "
"flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, "
"cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, "
"blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, "
"wobbly, weird, low quality, plastic, stock footage, video camera, boring",
tooltip="Negative prompt text",
),
comfy_io.Int.Input(
@@ -546,7 +542,7 @@ class MoonvalleyVideo2VideoNode(comfy_io.ComfyNode):
comfy_io.Video.Input(
"video",
tooltip="The reference video used to generate the output video. Must be at least 5 seconds long. "
"Videos longer than 5s will be automatically trimmed. Only MP4 format supported.",
"Videos longer than 5s will be automatically trimmed. Only MP4 format supported.",
),
comfy_io.Combo.Input(
"control_type",
@@ -563,6 +559,15 @@ class MoonvalleyVideo2VideoNode(comfy_io.ComfyNode):
tooltip="Only used if control_type is 'Motion Transfer'",
optional=True,
),
comfy_io.Int.Input(
"steps",
default=33,
min=1,
max=100,
step=1,
display_mode=comfy_io.NumberDisplay.number,
tooltip="Number of inference steps",
),
],
outputs=[comfy_io.Video.Output()],
hidden=[
@@ -582,6 +587,8 @@ class MoonvalleyVideo2VideoNode(comfy_io.ComfyNode):
video: Optional[VideoInput] = None,
control_type: str = "Motion Transfer",
motion_intensity: Optional[int] = 100,
steps=33,
prompt_adherence=4.5,
) -> comfy_io.NodeOutput:
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
@@ -602,6 +609,8 @@ class MoonvalleyVideo2VideoNode(comfy_io.ComfyNode):
negative_prompt=negative_prompt,
seed=seed,
control_params=control_params,
steps=steps,
guidance_scale=prompt_adherence,
)
control = parse_control_parameter(control_type)
@@ -653,11 +662,11 @@ class MoonvalleyTxt2VideoNode(comfy_io.ComfyNode):
"negative_prompt",
multiline=True,
default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, "
"artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, "
"flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, "
"cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, "
"blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, "
"wobbly, weird, low quality, plastic, stock footage, video camera, boring",
"artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, "
"flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, "
"cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, "
"blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, "
"wobbly, weird, low quality, plastic, stock footage, video camera, boring",
tooltip="Negative prompt text",
),
comfy_io.Combo.Input(
@@ -675,7 +684,7 @@ class MoonvalleyTxt2VideoNode(comfy_io.ComfyNode):
),
comfy_io.Float.Input(
"prompt_adherence",
default=10.0,
default=4.0,
min=1.0,
max=20.0,
step=1.0,
@@ -688,11 +697,12 @@ class MoonvalleyTxt2VideoNode(comfy_io.ComfyNode):
max=4294967295,
step=1,
display_mode=comfy_io.NumberDisplay.number,
control_after_generate=True,
tooltip="Random seed value",
),
comfy_io.Int.Input(
"steps",
default=100,
default=33,
min=1,
max=100,
step=1,
+86 -193
View File
@@ -8,30 +8,18 @@ from __future__ import annotations
from io import BytesIO
import logging
from typing import Optional, TypeVar
from enum import Enum
import numpy as np
import torch
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io as comfy_io
from comfy_api.input_impl import VideoFromFile
from comfy_api.latest import ComfyExtension, comfy_io
from comfy_api.input_impl.video_types import VideoCodec, VideoContainer, VideoInput
from comfy_api_nodes.apinode_utils import (
download_url_to_video_output,
tensor_to_bytesio,
validate_string,
)
from comfy_api_nodes.apis import (
PikaBodyGenerate22C2vGenerate22PikascenesPost,
PikaBodyGenerate22I2vGenerate22I2vPost,
PikaBodyGenerate22KeyframeGenerate22PikaframesPost,
PikaBodyGenerate22T2vGenerate22T2vPost,
PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
PikaBodyGeneratePikaswapsGeneratePikaswapsPost,
PikaGenerateResponse,
PikaVideoResponse,
)
from comfy_api_nodes.apis import pika_defs
from comfy_api_nodes.apis.client import (
ApiEndpoint,
EmptyRequest,
@@ -55,116 +43,36 @@ PATH_PIKASCENES = f"/proxy/pika/generate/{PIKA_API_VERSION}/pikascenes"
PATH_VIDEO_GET = "/proxy/pika/videos"
class PikaDurationEnum(int, Enum):
integer_5 = 5
integer_10 = 10
class PikaResolutionEnum(str, Enum):
field_1080p = "1080p"
field_720p = "720p"
class Pikaffect(str, Enum):
Cake_ify = "Cake-ify"
Crumble = "Crumble"
Crush = "Crush"
Decapitate = "Decapitate"
Deflate = "Deflate"
Dissolve = "Dissolve"
Explode = "Explode"
Eye_pop = "Eye-pop"
Inflate = "Inflate"
Levitate = "Levitate"
Melt = "Melt"
Peel = "Peel"
Poke = "Poke"
Squish = "Squish"
Ta_da = "Ta-da"
Tear = "Tear"
class PikaApiError(Exception):
"""Exception for Pika API errors."""
pass
def is_valid_video_response(response: PikaVideoResponse) -> bool:
"""Check if the video response is valid."""
return hasattr(response, "url") and response.url is not None
def is_valid_initial_response(response: PikaGenerateResponse) -> bool:
"""Check if the initial response is valid."""
return hasattr(response, "video_id") and response.video_id is not None
async def poll_for_task_status(
task_id: str,
async def execute_task(
initial_operation: SynchronousOperation[R, pika_defs.PikaGenerateResponse],
auth_kwargs: Optional[dict[str, str]] = None,
node_id: Optional[str] = None,
) -> PikaGenerateResponse:
polling_operation = PollingOperation(
) -> comfy_io.NodeOutput:
task_id = (await initial_operation.execute()).video_id
final_response: pika_defs.PikaVideoResponse = await PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"{PATH_VIDEO_GET}/{task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=PikaVideoResponse,
response_model=pika_defs.PikaVideoResponse,
),
completed_statuses=[
"finished",
],
completed_statuses=["finished"],
failed_statuses=["failed", "cancelled"],
status_extractor=lambda response: (
response.status.value if response.status else None
),
progress_extractor=lambda response: (
response.progress if hasattr(response, "progress") else None
),
status_extractor=lambda response: (response.status.value if response.status else None),
progress_extractor=lambda response: (response.progress if hasattr(response, "progress") else None),
auth_kwargs=auth_kwargs,
result_url_extractor=lambda response: (
response.url if hasattr(response, "url") else None
),
result_url_extractor=lambda response: (response.url if hasattr(response, "url") else None),
node_id=node_id,
estimated_duration=60
)
return await polling_operation.execute()
async def execute_task(
initial_operation: SynchronousOperation[R, PikaGenerateResponse],
auth_kwargs: Optional[dict[str, str]] = None,
node_id: Optional[str] = None,
) -> tuple[VideoFromFile]:
"""Executes the initial operation then polls for the task status until it is completed.
Args:
initial_operation: The initial operation to execute.
auth_kwargs: The authentication token(s) to use for the API call.
Returns:
A tuple containing the video file as a VIDEO output.
"""
initial_response = await initial_operation.execute()
if not is_valid_initial_response(initial_response):
error_msg = f"Pika initial request failed. Code: {initial_response.code}, Message: {initial_response.message}, Data: {initial_response.data}"
estimated_duration=60,
max_poll_attempts=240,
).execute()
if not final_response.url:
error_msg = f"Pika task {task_id} succeeded but no video data found in response:\n{final_response}"
logging.error(error_msg)
raise PikaApiError(error_msg)
task_id = initial_response.video_id
final_response = await poll_for_task_status(task_id, auth_kwargs, node_id=node_id)
if not is_valid_video_response(final_response):
error_msg = (
f"Pika task {task_id} succeeded but no video data found in response."
)
logging.error(error_msg)
raise PikaApiError(error_msg)
video_url = str(final_response.url)
raise Exception(error_msg)
video_url = final_response.url
logging.info("Pika task %s succeeded. Video URL: %s", task_id, video_url)
return (await download_url_to_video_output(video_url),)
return comfy_io.NodeOutput(await download_url_to_video_output(video_url))
def get_base_inputs_types() -> list[comfy_io.Input]:
@@ -173,16 +81,12 @@ def get_base_inputs_types() -> list[comfy_io.Input]:
comfy_io.String.Input("prompt_text", multiline=True),
comfy_io.String.Input("negative_prompt", multiline=True),
comfy_io.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True),
comfy_io.Combo.Input(
"resolution", options=[resolution.value for resolution in PikaResolutionEnum], default="1080p"
),
comfy_io.Combo.Input(
"duration", options=[duration.value for duration in PikaDurationEnum], default=5
),
comfy_io.Combo.Input("resolution", options=["1080p", "720p"], default="1080p"),
comfy_io.Combo.Input("duration", options=[5, 10], default=5),
]
class PikaImageToVideoV2_2(comfy_io.ComfyNode):
class PikaImageToVideo(comfy_io.ComfyNode):
"""Pika 2.2 Image to Video Node."""
@classmethod
@@ -215,14 +119,9 @@ class PikaImageToVideoV2_2(comfy_io.ComfyNode):
resolution: str,
duration: int,
) -> comfy_io.NodeOutput:
# Convert image to BytesIO
image_bytes_io = tensor_to_bytesio(image)
image_bytes_io.seek(0)
pika_files = {"image": ("image.png", image_bytes_io, "image/png")}
# Prepare non-file data
pika_request_data = PikaBodyGenerate22I2vGenerate22I2vPost(
pika_request_data = pika_defs.PikaBodyGenerate22I2vGenerate22I2vPost(
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
@@ -237,8 +136,8 @@ class PikaImageToVideoV2_2(comfy_io.ComfyNode):
endpoint=ApiEndpoint(
path=PATH_IMAGE_TO_VIDEO,
method=HttpMethod.POST,
request_model=PikaBodyGenerate22I2vGenerate22I2vPost,
response_model=PikaGenerateResponse,
request_model=pika_defs.PikaBodyGenerate22I2vGenerate22I2vPost,
response_model=pika_defs.PikaGenerateResponse,
),
request=pika_request_data,
files=pika_files,
@@ -248,7 +147,7 @@ class PikaImageToVideoV2_2(comfy_io.ComfyNode):
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
class PikaTextToVideoNodeV2_2(comfy_io.ComfyNode):
class PikaTextToVideoNode(comfy_io.ComfyNode):
"""Pika Text2Video v2.2 Node."""
@classmethod
@@ -296,10 +195,10 @@ class PikaTextToVideoNodeV2_2(comfy_io.ComfyNode):
endpoint=ApiEndpoint(
path=PATH_TEXT_TO_VIDEO,
method=HttpMethod.POST,
request_model=PikaBodyGenerate22T2vGenerate22T2vPost,
response_model=PikaGenerateResponse,
request_model=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost,
response_model=pika_defs.PikaGenerateResponse,
),
request=PikaBodyGenerate22T2vGenerate22T2vPost(
request=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost(
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
@@ -313,7 +212,7 @@ class PikaTextToVideoNodeV2_2(comfy_io.ComfyNode):
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
class PikaScenesV2_2(comfy_io.ComfyNode):
class PikaScenes(comfy_io.ComfyNode):
"""PikaScenes v2.2 Node."""
@classmethod
@@ -389,7 +288,6 @@ class PikaScenesV2_2(comfy_io.ComfyNode):
image_ingredient_4: Optional[torch.Tensor] = None,
image_ingredient_5: Optional[torch.Tensor] = None,
) -> comfy_io.NodeOutput:
# Convert all passed images to BytesIO
all_image_bytes_io = []
for image in [
image_ingredient_1,
@@ -399,16 +297,14 @@ class PikaScenesV2_2(comfy_io.ComfyNode):
image_ingredient_5,
]:
if image is not None:
image_bytes_io = tensor_to_bytesio(image)
image_bytes_io.seek(0)
all_image_bytes_io.append(image_bytes_io)
all_image_bytes_io.append(tensor_to_bytesio(image))
pika_files = [
("images", (f"image_{i}.png", image_bytes_io, "image/png"))
for i, image_bytes_io in enumerate(all_image_bytes_io)
]
pika_request_data = PikaBodyGenerate22C2vGenerate22PikascenesPost(
pika_request_data = pika_defs.PikaBodyGenerate22C2vGenerate22PikascenesPost(
ingredientsMode=ingredients_mode,
promptText=prompt_text,
negativePrompt=negative_prompt,
@@ -425,8 +321,8 @@ class PikaScenesV2_2(comfy_io.ComfyNode):
endpoint=ApiEndpoint(
path=PATH_PIKASCENES,
method=HttpMethod.POST,
request_model=PikaBodyGenerate22C2vGenerate22PikascenesPost,
response_model=PikaGenerateResponse,
request_model=pika_defs.PikaBodyGenerate22C2vGenerate22PikascenesPost,
response_model=pika_defs.PikaGenerateResponse,
),
request=pika_request_data,
files=pika_files,
@@ -477,22 +373,16 @@ class PikAdditionsNode(comfy_io.ComfyNode):
negative_prompt: str,
seed: int,
) -> comfy_io.NodeOutput:
# Convert video to BytesIO
video_bytes_io = BytesIO()
video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264)
video_bytes_io.seek(0)
# Convert image to BytesIO
image_bytes_io = tensor_to_bytesio(image)
image_bytes_io.seek(0)
pika_files = {
"video": ("video.mp4", video_bytes_io, "video/mp4"),
"image": ("image.png", image_bytes_io, "image/png"),
}
# Prepare non-file data
pika_request_data = PikaBodyGeneratePikadditionsGeneratePikadditionsPost(
pika_request_data = pika_defs.PikaBodyGeneratePikadditionsGeneratePikadditionsPost(
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
@@ -505,8 +395,8 @@ class PikAdditionsNode(comfy_io.ComfyNode):
endpoint=ApiEndpoint(
path=PATH_PIKADDITIONS,
method=HttpMethod.POST,
request_model=PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
response_model=PikaGenerateResponse,
request_model=pika_defs.PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
response_model=pika_defs.PikaGenerateResponse,
),
request=pika_request_data,
files=pika_files,
@@ -529,11 +419,25 @@ class PikaSwapsNode(comfy_io.ComfyNode):
category="api node/video/Pika",
inputs=[
comfy_io.Video.Input("video", tooltip="The video to swap an object in."),
comfy_io.Image.Input("image", tooltip="The image used to replace the masked object in the video."),
comfy_io.Mask.Input("mask", tooltip="Use the mask to define areas in the video to replace"),
comfy_io.String.Input("prompt_text", multiline=True),
comfy_io.String.Input("negative_prompt", multiline=True),
comfy_io.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True),
comfy_io.Image.Input(
"image",
tooltip="The image used to replace the masked object in the video.",
optional=True,
),
comfy_io.Mask.Input(
"mask",
tooltip="Use the mask to define areas in the video to replace.",
optional=True,
),
comfy_io.String.Input("prompt_text", multiline=True, optional=True),
comfy_io.String.Input("negative_prompt", multiline=True, optional=True),
comfy_io.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True, optional=True),
comfy_io.String.Input(
"region_to_modify",
multiline=True,
optional=True,
tooltip="Plaintext description of the object / region to modify.",
),
],
outputs=[comfy_io.Video.Output()],
hidden=[
@@ -548,41 +452,29 @@ class PikaSwapsNode(comfy_io.ComfyNode):
async def execute(
cls,
video: VideoInput,
image: torch.Tensor,
mask: torch.Tensor,
prompt_text: str,
negative_prompt: str,
seed: int,
image: Optional[torch.Tensor] = None,
mask: Optional[torch.Tensor] = None,
prompt_text: str = "",
negative_prompt: str = "",
seed: int = 0,
region_to_modify: str = "",
) -> comfy_io.NodeOutput:
# Convert video to BytesIO
video_bytes_io = BytesIO()
video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264)
video_bytes_io.seek(0)
# Convert mask to binary mask with three channels
mask = torch.round(mask)
mask = mask.repeat(1, 3, 1, 1)
# Convert 3-channel binary mask to BytesIO
mask_bytes_io = BytesIO()
mask_bytes_io.write(mask.numpy().astype(np.uint8))
mask_bytes_io.seek(0)
# Convert image to BytesIO
image_bytes_io = tensor_to_bytesio(image)
image_bytes_io.seek(0)
pika_files = {
"video": ("video.mp4", video_bytes_io, "video/mp4"),
"image": ("image.png", image_bytes_io, "image/png"),
"modifyRegionMask": ("mask.png", mask_bytes_io, "image/png"),
}
if mask is not None:
pika_files["modifyRegionMask"] = ("mask.png", tensor_to_bytesio(mask), "image/png")
if image is not None:
pika_files["image"] = ("image.png", tensor_to_bytesio(image), "image/png")
# Prepare non-file data
pika_request_data = PikaBodyGeneratePikaswapsGeneratePikaswapsPost(
pika_request_data = pika_defs.PikaBodyGeneratePikaswapsGeneratePikaswapsPost(
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
modifyRegionRoi=region_to_modify if region_to_modify else None,
)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
@@ -590,10 +482,10 @@ class PikaSwapsNode(comfy_io.ComfyNode):
}
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_PIKADDITIONS,
path=PATH_PIKASWAPS,
method=HttpMethod.POST,
request_model=PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
response_model=PikaGenerateResponse,
request_model=pika_defs.PikaBodyGeneratePikaswapsGeneratePikaswapsPost,
response_model=pika_defs.PikaGenerateResponse,
),
request=pika_request_data,
files=pika_files,
@@ -616,7 +508,7 @@ class PikaffectsNode(comfy_io.ComfyNode):
inputs=[
comfy_io.Image.Input("image", tooltip="The reference image to apply the Pikaffect to."),
comfy_io.Combo.Input(
"pikaffect", options=[pikaffect.value for pikaffect in Pikaffect], default="Cake-ify"
"pikaffect", options=pika_defs.Pikaffect, default="Cake-ify"
),
comfy_io.String.Input("prompt_text", multiline=True),
comfy_io.String.Input("negative_prompt", multiline=True),
@@ -648,10 +540,10 @@ class PikaffectsNode(comfy_io.ComfyNode):
endpoint=ApiEndpoint(
path=PATH_PIKAFFECTS,
method=HttpMethod.POST,
request_model=PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
response_model=PikaGenerateResponse,
request_model=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
response_model=pika_defs.PikaGenerateResponse,
),
request=PikaBodyGeneratePikaffectsGeneratePikaffectsPost(
request=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost(
pikaffect=pikaffect,
promptText=prompt_text,
negativePrompt=negative_prompt,
@@ -664,7 +556,7 @@ class PikaffectsNode(comfy_io.ComfyNode):
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
class PikaStartEndFrameNode2_2(comfy_io.ComfyNode):
class PikaStartEndFrameNode(comfy_io.ComfyNode):
"""PikaFrames v2.2 Node."""
@classmethod
@@ -699,6 +591,7 @@ class PikaStartEndFrameNode2_2(comfy_io.ComfyNode):
resolution: str,
duration: int,
) -> comfy_io.NodeOutput:
validate_string(prompt_text, field_name="prompt_text", min_length=1)
pika_files = [
("keyFrames", ("image_start.png", tensor_to_bytesio(image_start), "image/png")),
("keyFrames", ("image_end.png", tensor_to_bytesio(image_end), "image/png")),
@@ -711,10 +604,10 @@ class PikaStartEndFrameNode2_2(comfy_io.ComfyNode):
endpoint=ApiEndpoint(
path=PATH_PIKAFRAMES,
method=HttpMethod.POST,
request_model=PikaBodyGenerate22KeyframeGenerate22PikaframesPost,
response_model=PikaGenerateResponse,
request_model=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost,
response_model=pika_defs.PikaGenerateResponse,
),
request=PikaBodyGenerate22KeyframeGenerate22PikaframesPost(
request=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost(
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
@@ -732,13 +625,13 @@ class PikaApiNodesExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
return [
PikaImageToVideoV2_2,
PikaTextToVideoNodeV2_2,
PikaScenesV2_2,
PikaImageToVideo,
PikaTextToVideoNode,
PikaScenes,
PikAdditionsNode,
PikaSwapsNode,
PikaffectsNode,
PikaStartEndFrameNode2_2,
PikaStartEndFrameNode,
]
+11 -11
View File
@@ -85,7 +85,7 @@ class PixverseTemplateNode(comfy_io.ComfyNode):
display_name="PixVerse Template",
category="api node/video/PixVerse",
inputs=[
comfy_io.Combo.Input("template", options=[list(pixverse_templates.keys())]),
comfy_io.Combo.Input("template", options=list(pixverse_templates.keys())),
],
outputs=[comfy_io.Custom(PixverseIO.TEMPLATE).Output(display_name="pixverse_template")],
)
@@ -120,20 +120,20 @@ class PixverseTextToVideoNode(comfy_io.ComfyNode):
),
comfy_io.Combo.Input(
"aspect_ratio",
options=[ratio.value for ratio in PixverseAspectRatio],
options=PixverseAspectRatio,
),
comfy_io.Combo.Input(
"quality",
options=[resolution.value for resolution in PixverseQuality],
options=PixverseQuality,
default=PixverseQuality.res_540p,
),
comfy_io.Combo.Input(
"duration_seconds",
options=[dur.value for dur in PixverseDuration],
options=PixverseDuration,
),
comfy_io.Combo.Input(
"motion_mode",
options=[mode.value for mode in PixverseMotionMode],
options=PixverseMotionMode,
),
comfy_io.Int.Input(
"seed",
@@ -262,16 +262,16 @@ class PixverseImageToVideoNode(comfy_io.ComfyNode):
),
comfy_io.Combo.Input(
"quality",
options=[resolution.value for resolution in PixverseQuality],
options=PixverseQuality,
default=PixverseQuality.res_540p,
),
comfy_io.Combo.Input(
"duration_seconds",
options=[dur.value for dur in PixverseDuration],
options=PixverseDuration,
),
comfy_io.Combo.Input(
"motion_mode",
options=[mode.value for mode in PixverseMotionMode],
options=PixverseMotionMode,
),
comfy_io.Int.Input(
"seed",
@@ -403,16 +403,16 @@ class PixverseTransitionVideoNode(comfy_io.ComfyNode):
),
comfy_io.Combo.Input(
"quality",
options=[resolution.value for resolution in PixverseQuality],
options=PixverseQuality,
default=PixverseQuality.res_540p,
),
comfy_io.Combo.Input(
"duration_seconds",
options=[dur.value for dur in PixverseDuration],
options=PixverseDuration,
),
comfy_io.Combo.Input(
"motion_mode",
options=[mode.value for mode in PixverseMotionMode],
options=PixverseMotionMode,
),
comfy_io.Int.Input(
"seed",
+5 -5
View File
@@ -172,16 +172,16 @@ async def create_generate_task(
logging.info("[ Rodin3D API - Submit Jobs ] Submit Generate Task Success!")
subscription_key = response.jobs.subscription_key
task_uuid = response.uuid
logging.info(f"[ Rodin3D API - Submit Jobs ] UUID: {task_uuid}")
logging.info("[ Rodin3D API - Submit Jobs ] UUID: %s", task_uuid)
return task_uuid, subscription_key
def check_rodin_status(response: Rodin3DCheckStatusResponse) -> str:
all_done = all(job.status == JobStatus.Done for job in response.jobs)
status_list = [str(job.status) for job in response.jobs]
logging.info(f"[ Rodin3D API - CheckStatus ] Generate Status: {status_list}")
logging.info("[ Rodin3D API - CheckStatus ] Generate Status: %s", status_list)
if any(job.status == JobStatus.Failed for job in response.jobs):
logging.error(f"[ Rodin3D API - CheckStatus ] Generate Failed: {status_list}, Please try again.")
logging.error("[ Rodin3D API - CheckStatus ] Generate Failed: %s, Please try again.", status_list)
raise Exception("[ Rodin3D API ] Generate Failed, Please Try again.")
if all_done:
return "DONE"
@@ -235,7 +235,7 @@ async def download_files(url_list, task_uuid):
file_path = os.path.join(save_path, file_name)
if file_path.endswith(".glb"):
model_file_path = file_path
logging.info(f"[ Rodin3D API - download_files ] Downloading file: {file_path}")
logging.info("[ Rodin3D API - download_files ] Downloading file: %s", file_path)
max_retries = 5
for attempt in range(max_retries):
try:
@@ -246,7 +246,7 @@ async def download_files(url_list, task_uuid):
f.write(chunk)
break
except Exception as e:
logging.info(f"[ Rodin3D API - download_files ] Error downloading {file_path}:{e}")
logging.info("[ Rodin3D API - download_files ] Error downloading %s:%s", file_path, str(e))
if attempt < max_retries - 1:
logging.info("Retrying...")
await asyncio.sleep(2)
+6 -6
View File
@@ -200,11 +200,11 @@ class RunwayImageToVideoNodeGen3a(comfy_io.ComfyNode):
),
comfy_io.Combo.Input(
"duration",
options=[model.value for model in Duration],
options=Duration,
),
comfy_io.Combo.Input(
"ratio",
options=[model.value for model in RunwayGen3aAspectRatio],
options=RunwayGen3aAspectRatio,
),
comfy_io.Int.Input(
"seed",
@@ -300,11 +300,11 @@ class RunwayImageToVideoNodeGen4(comfy_io.ComfyNode):
),
comfy_io.Combo.Input(
"duration",
options=[model.value for model in Duration],
options=Duration,
),
comfy_io.Combo.Input(
"ratio",
options=[model.value for model in RunwayGen4TurboAspectRatio],
options=RunwayGen4TurboAspectRatio,
),
comfy_io.Int.Input(
"seed",
@@ -408,11 +408,11 @@ class RunwayFirstLastFrameNode(comfy_io.ComfyNode):
),
comfy_io.Combo.Input(
"duration",
options=[model.value for model in Duration],
options=Duration,
),
comfy_io.Combo.Input(
"ratio",
options=[model.value for model in RunwayGen3aAspectRatio],
options=RunwayGen3aAspectRatio,
),
comfy_io.Int.Input(
"seed",
+5 -5
View File
@@ -82,8 +82,8 @@ class StabilityStableImageUltraNode(comfy_io.ComfyNode):
),
comfy_io.Combo.Input(
"aspect_ratio",
options=[x.value for x in StabilityAspectRatio],
default=StabilityAspectRatio.ratio_1_1.value,
options=StabilityAspectRatio,
default=StabilityAspectRatio.ratio_1_1,
tooltip="Aspect ratio of generated image.",
),
comfy_io.Combo.Input(
@@ -217,12 +217,12 @@ class StabilityStableImageSD_3_5Node(comfy_io.ComfyNode):
),
comfy_io.Combo.Input(
"model",
options=[x.value for x in Stability_SD3_5_Model],
options=Stability_SD3_5_Model,
),
comfy_io.Combo.Input(
"aspect_ratio",
options=[x.value for x in StabilityAspectRatio],
default=StabilityAspectRatio.ratio_1_1.value,
options=StabilityAspectRatio,
default=StabilityAspectRatio.ratio_1_1,
tooltip="Aspect ratio of generated image.",
),
comfy_io.Combo.Input(
+1 -1
View File
@@ -215,7 +215,7 @@ class VeoVideoGenerationNode(comfy_io.ComfyNode):
initial_response = await initial_operation.execute()
operation_name = initial_response.name
logging.info(f"Veo generation started with operation name: {operation_name}")
logging.info("Veo generation started with operation name: %s", operation_name)
# Define status extractor function
def status_extractor(response):
+17 -17
View File
@@ -173,8 +173,8 @@ class ViduTextToVideoNode(comfy_io.ComfyNode):
inputs=[
comfy_io.Combo.Input(
"model",
options=[model.value for model in VideoModelName],
default=VideoModelName.vidu_q1.value,
options=VideoModelName,
default=VideoModelName.vidu_q1,
tooltip="Model name",
),
comfy_io.String.Input(
@@ -205,22 +205,22 @@ class ViduTextToVideoNode(comfy_io.ComfyNode):
),
comfy_io.Combo.Input(
"aspect_ratio",
options=[model.value for model in AspectRatio],
default=AspectRatio.r_16_9.value,
options=AspectRatio,
default=AspectRatio.r_16_9,
tooltip="The aspect ratio of the output video",
optional=True,
),
comfy_io.Combo.Input(
"resolution",
options=[model.value for model in Resolution],
default=Resolution.r_1080p.value,
options=Resolution,
default=Resolution.r_1080p,
tooltip="Supported values may vary by model & duration",
optional=True,
),
comfy_io.Combo.Input(
"movement_amplitude",
options=[model.value for model in MovementAmplitude],
default=MovementAmplitude.auto.value,
options=MovementAmplitude,
default=MovementAmplitude.auto,
tooltip="The movement amplitude of objects in the frame",
optional=True,
),
@@ -278,8 +278,8 @@ class ViduImageToVideoNode(comfy_io.ComfyNode):
inputs=[
comfy_io.Combo.Input(
"model",
options=[model.value for model in VideoModelName],
default=VideoModelName.vidu_q1.value,
options=VideoModelName,
default=VideoModelName.vidu_q1,
tooltip="Model name",
),
comfy_io.Image.Input(
@@ -316,14 +316,14 @@ class ViduImageToVideoNode(comfy_io.ComfyNode):
),
comfy_io.Combo.Input(
"resolution",
options=[model.value for model in Resolution],
default=Resolution.r_1080p.value,
options=Resolution,
default=Resolution.r_1080p,
tooltip="Supported values may vary by model & duration",
optional=True,
),
comfy_io.Combo.Input(
"movement_amplitude",
options=[model.value for model in MovementAmplitude],
options=MovementAmplitude,
default=MovementAmplitude.auto.value,
tooltip="The movement amplitude of objects in the frame",
optional=True,
@@ -388,8 +388,8 @@ class ViduReferenceVideoNode(comfy_io.ComfyNode):
inputs=[
comfy_io.Combo.Input(
"model",
options=[model.value for model in VideoModelName],
default=VideoModelName.vidu_q1.value,
options=VideoModelName,
default=VideoModelName.vidu_q1,
tooltip="Model name",
),
comfy_io.Image.Input(
@@ -424,8 +424,8 @@ class ViduReferenceVideoNode(comfy_io.ComfyNode):
),
comfy_io.Combo.Input(
"aspect_ratio",
options=[model.value for model in AspectRatio],
default=AspectRatio.r_16_9.value,
options=AspectRatio,
default=AspectRatio.r_16_9,
tooltip="The aspect ratio of the output video",
optional=True,
),
+5 -4
View File
@@ -142,9 +142,10 @@ def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=Non
for key, value in metadata.items():
output_container.metadata[key] = value
layout = 'mono' if waveform.shape[0] == 1 else 'stereo'
# Set up the output stream with appropriate properties
if format == "opus":
out_stream = output_container.add_stream("libopus", rate=sample_rate)
out_stream = output_container.add_stream("libopus", rate=sample_rate, layout=layout)
if quality == "64k":
out_stream.bit_rate = 64000
elif quality == "96k":
@@ -156,7 +157,7 @@ def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=Non
elif quality == "320k":
out_stream.bit_rate = 320000
elif format == "mp3":
out_stream = output_container.add_stream("libmp3lame", rate=sample_rate)
out_stream = output_container.add_stream("libmp3lame", rate=sample_rate, layout=layout)
if quality == "V0":
#TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool
out_stream.codec_context.qscale = 1
@@ -165,9 +166,9 @@ def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=Non
elif quality == "320k":
out_stream.bit_rate = 320000
else: #format == "flac":
out_stream = output_container.add_stream("flac", rate=sample_rate)
out_stream = output_container.add_stream("flac", rate=sample_rate, layout=layout)
frame = av.AudioFrame.from_ndarray(waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format='flt', layout='mono' if waveform.shape[0] == 1 else 'stereo')
frame = av.AudioFrame.from_ndarray(waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format='flt', layout=layout)
frame.sample_rate = sample_rate
frame.pts = 0
output_container.mux(out_stream.encode(frame))
+66 -57
View File
@@ -1,6 +1,9 @@
import torch
import comfy.utils
from enum import Enum
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
def resize_mask(mask, shape):
return torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[0], shape[1]), mode="bilinear").squeeze(1)
@@ -101,24 +104,28 @@ def porter_duff_composite(src_image: torch.Tensor, src_alpha: torch.Tensor, dst_
return out_image, out_alpha
class PorterDuffImageComposite:
class PorterDuffImageComposite(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"source": ("IMAGE",),
"source_alpha": ("MASK",),
"destination": ("IMAGE",),
"destination_alpha": ("MASK",),
"mode": ([mode.name for mode in PorterDuffMode], {"default": PorterDuffMode.DST.name}),
},
}
def define_schema(cls):
return io.Schema(
node_id="PorterDuffImageComposite",
display_name="Porter-Duff Image Composite",
category="mask/compositing",
inputs=[
io.Image.Input("source"),
io.Mask.Input("source_alpha"),
io.Image.Input("destination"),
io.Mask.Input("destination_alpha"),
io.Combo.Input("mode", options=[mode.name for mode in PorterDuffMode], default=PorterDuffMode.DST.name),
],
outputs=[
io.Image.Output(),
io.Mask.Output(),
],
)
RETURN_TYPES = ("IMAGE", "MASK")
FUNCTION = "composite"
CATEGORY = "mask/compositing"
def composite(self, source: torch.Tensor, source_alpha: torch.Tensor, destination: torch.Tensor, destination_alpha: torch.Tensor, mode):
@classmethod
def execute(cls, source: torch.Tensor, source_alpha: torch.Tensor, destination: torch.Tensor, destination_alpha: torch.Tensor, mode) -> io.NodeOutput:
batch_size = min(len(source), len(source_alpha), len(destination), len(destination_alpha))
out_images = []
out_alphas = []
@@ -150,45 +157,48 @@ class PorterDuffImageComposite:
out_images.append(out_image)
out_alphas.append(out_alpha.squeeze(2))
result = (torch.stack(out_images), torch.stack(out_alphas))
return result
return io.NodeOutput(torch.stack(out_images), torch.stack(out_alphas))
class SplitImageWithAlpha:
class SplitImageWithAlpha(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": ("IMAGE",),
}
}
def define_schema(cls):
return io.Schema(
node_id="SplitImageWithAlpha",
display_name="Split Image with Alpha",
category="mask/compositing",
inputs=[
io.Image.Input("image"),
],
outputs=[
io.Image.Output(),
io.Mask.Output(),
],
)
CATEGORY = "mask/compositing"
RETURN_TYPES = ("IMAGE", "MASK")
FUNCTION = "split_image_with_alpha"
def split_image_with_alpha(self, image: torch.Tensor):
@classmethod
def execute(cls, image: torch.Tensor) -> io.NodeOutput:
out_images = [i[:,:,:3] for i in image]
out_alphas = [i[:,:,3] if i.shape[2] > 3 else torch.ones_like(i[:,:,0]) for i in image]
result = (torch.stack(out_images), 1.0 - torch.stack(out_alphas))
return result
return io.NodeOutput(torch.stack(out_images), 1.0 - torch.stack(out_alphas))
class JoinImageWithAlpha:
class JoinImageWithAlpha(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": ("IMAGE",),
"alpha": ("MASK",),
}
}
def define_schema(cls):
return io.Schema(
node_id="JoinImageWithAlpha",
display_name="Join Image with Alpha",
category="mask/compositing",
inputs=[
io.Image.Input("image"),
io.Mask.Input("alpha"),
],
outputs=[io.Image.Output()],
)
CATEGORY = "mask/compositing"
RETURN_TYPES = ("IMAGE",)
FUNCTION = "join_image_with_alpha"
def join_image_with_alpha(self, image: torch.Tensor, alpha: torch.Tensor):
@classmethod
def execute(cls, image: torch.Tensor, alpha: torch.Tensor) -> io.NodeOutput:
batch_size = min(len(image), len(alpha))
out_images = []
@@ -196,19 +206,18 @@ class JoinImageWithAlpha:
for i in range(batch_size):
out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2))
result = (torch.stack(out_images),)
return result
return io.NodeOutput(torch.stack(out_images))
NODE_CLASS_MAPPINGS = {
"PorterDuffImageComposite": PorterDuffImageComposite,
"SplitImageWithAlpha": SplitImageWithAlpha,
"JoinImageWithAlpha": JoinImageWithAlpha,
}
class CompositingExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
PorterDuffImageComposite,
SplitImageWithAlpha,
JoinImageWithAlpha,
]
NODE_DISPLAY_NAME_MAPPINGS = {
"PorterDuffImageComposite": "Porter-Duff Image Composite",
"SplitImageWithAlpha": "Split Image with Alpha",
"JoinImageWithAlpha": "Join Image with Alpha",
}
async def comfy_entrypoint() -> CompositingExtension:
return CompositingExtension()
+115 -74
View File
@@ -1,60 +1,80 @@
import node_helpers
import comfy.utils
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
class CLIPTextEncodeFlux:
class CLIPTextEncodeFlux(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {
"clip": ("CLIP", ),
"clip_l": ("STRING", {"multiline": True, "dynamicPrompts": True}),
"t5xxl": ("STRING", {"multiline": True, "dynamicPrompts": True}),
"guidance": ("FLOAT", {"default": 3.5, "min": 0.0, "max": 100.0, "step": 0.1}),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "encode"
def define_schema(cls):
return io.Schema(
node_id="CLIPTextEncodeFlux",
category="advanced/conditioning/flux",
inputs=[
io.Clip.Input("clip"),
io.String.Input("clip_l", multiline=True, dynamic_prompts=True),
io.String.Input("t5xxl", multiline=True, dynamic_prompts=True),
io.Float.Input("guidance", default=3.5, min=0.0, max=100.0, step=0.1),
],
outputs=[
io.Conditioning.Output(),
],
)
CATEGORY = "advanced/conditioning/flux"
def encode(self, clip, clip_l, t5xxl, guidance):
@classmethod
def execute(cls, clip, clip_l, t5xxl, guidance) -> io.NodeOutput:
tokens = clip.tokenize(clip_l)
tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"]
return (clip.encode_from_tokens_scheduled(tokens, add_dict={"guidance": guidance}), )
return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens, add_dict={"guidance": guidance}))
class FluxGuidance:
encode = execute # TODO: remove
class FluxGuidance(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {
"conditioning": ("CONDITIONING", ),
"guidance": ("FLOAT", {"default": 3.5, "min": 0.0, "max": 100.0, "step": 0.1}),
}}
def define_schema(cls):
return io.Schema(
node_id="FluxGuidance",
category="advanced/conditioning/flux",
inputs=[
io.Conditioning.Input("conditioning"),
io.Float.Input("guidance", default=3.5, min=0.0, max=100.0, step=0.1),
],
outputs=[
io.Conditioning.Output(),
],
)
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "append"
CATEGORY = "advanced/conditioning/flux"
def append(self, conditioning, guidance):
@classmethod
def execute(cls, conditioning, guidance) -> io.NodeOutput:
c = node_helpers.conditioning_set_values(conditioning, {"guidance": guidance})
return (c, )
return io.NodeOutput(c)
append = execute # TODO: remove
class FluxDisableGuidance:
class FluxDisableGuidance(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {
"conditioning": ("CONDITIONING", ),
}}
def define_schema(cls):
return io.Schema(
node_id="FluxDisableGuidance",
category="advanced/conditioning/flux",
description="This node completely disables the guidance embed on Flux and Flux like models",
inputs=[
io.Conditioning.Input("conditioning"),
],
outputs=[
io.Conditioning.Output(),
],
)
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "append"
CATEGORY = "advanced/conditioning/flux"
DESCRIPTION = "This node completely disables the guidance embed on Flux and Flux like models"
def append(self, conditioning):
@classmethod
def execute(cls, conditioning) -> io.NodeOutput:
c = node_helpers.conditioning_set_values(conditioning, {"guidance": None})
return (c, )
return io.NodeOutput(c)
append = execute # TODO: remove
PREFERED_KONTEXT_RESOLUTIONS = [
@@ -78,52 +98,73 @@ PREFERED_KONTEXT_RESOLUTIONS = [
]
class FluxKontextImageScale:
class FluxKontextImageScale(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {"image": ("IMAGE", ),
},
}
def define_schema(cls):
return io.Schema(
node_id="FluxKontextImageScale",
category="advanced/conditioning/flux",
description="This node resizes the image to one that is more optimal for flux kontext.",
inputs=[
io.Image.Input("image"),
],
outputs=[
io.Image.Output(),
],
)
RETURN_TYPES = ("IMAGE",)
FUNCTION = "scale"
CATEGORY = "advanced/conditioning/flux"
DESCRIPTION = "This node resizes the image to one that is more optimal for flux kontext."
def scale(self, image):
@classmethod
def execute(cls, image) -> io.NodeOutput:
width = image.shape[2]
height = image.shape[1]
aspect_ratio = width / height
_, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS)
image = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "lanczos", "center").movedim(1, -1)
return (image, )
return io.NodeOutput(image)
scale = execute # TODO: remove
class FluxKontextMultiReferenceLatentMethod:
class FluxKontextMultiReferenceLatentMethod(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {
"conditioning": ("CONDITIONING", ),
"reference_latents_method": (("offset", "index", "uxo/uno"), ),
}}
def define_schema(cls):
return io.Schema(
node_id="FluxKontextMultiReferenceLatentMethod",
category="advanced/conditioning/flux",
inputs=[
io.Conditioning.Input("conditioning"),
io.Combo.Input(
"reference_latents_method",
options=["offset", "index", "uxo/uno"],
),
],
outputs=[
io.Conditioning.Output(),
],
is_experimental=True,
)
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "append"
EXPERIMENTAL = True
CATEGORY = "advanced/conditioning/flux"
def append(self, conditioning, reference_latents_method):
@classmethod
def execute(cls, conditioning, reference_latents_method) -> io.NodeOutput:
if "uxo" in reference_latents_method or "uso" in reference_latents_method:
reference_latents_method = "uxo"
c = node_helpers.conditioning_set_values(conditioning, {"reference_latents_method": reference_latents_method})
return (c, )
return io.NodeOutput(c)
NODE_CLASS_MAPPINGS = {
"CLIPTextEncodeFlux": CLIPTextEncodeFlux,
"FluxGuidance": FluxGuidance,
"FluxDisableGuidance": FluxDisableGuidance,
"FluxKontextImageScale": FluxKontextImageScale,
"FluxKontextMultiReferenceLatentMethod": FluxKontextMultiReferenceLatentMethod,
}
append = execute # TODO: remove
class FluxExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
CLIPTextEncodeFlux,
FluxGuidance,
FluxDisableGuidance,
FluxKontextImageScale,
FluxKontextMultiReferenceLatentMethod,
]
async def comfy_entrypoint() -> FluxExtension:
return FluxExtension()
+150 -91
View File
@@ -2,42 +2,60 @@ import nodes
import node_helpers
import torch
import comfy.model_management
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
class CLIPTextEncodeHunyuanDiT:
class CLIPTextEncodeHunyuanDiT(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {
"clip": ("CLIP", ),
"bert": ("STRING", {"multiline": True, "dynamicPrompts": True}),
"mt5xl": ("STRING", {"multiline": True, "dynamicPrompts": True}),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "encode"
def define_schema(cls):
return io.Schema(
node_id="CLIPTextEncodeHunyuanDiT",
category="advanced/conditioning",
inputs=[
io.Clip.Input("clip"),
io.String.Input("bert", multiline=True, dynamic_prompts=True),
io.String.Input("mt5xl", multiline=True, dynamic_prompts=True),
],
outputs=[
io.Conditioning.Output(),
],
)
CATEGORY = "advanced/conditioning"
def encode(self, clip, bert, mt5xl):
@classmethod
def execute(cls, clip, bert, mt5xl) -> io.NodeOutput:
tokens = clip.tokenize(bert)
tokens["mt5xl"] = clip.tokenize(mt5xl)["mt5xl"]
return (clip.encode_from_tokens_scheduled(tokens), )
return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))
class EmptyHunyuanLatentVideo:
encode = execute # TODO: remove
class EmptyHunyuanLatentVideo(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 25, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "generate"
def define_schema(cls):
return io.Schema(
node_id="EmptyHunyuanLatentVideo",
category="latent/video",
inputs=[
io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("length", default=25, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.Int.Input("batch_size", default=1, min=1, max=4096),
],
outputs=[
io.Latent.Output(),
],
)
CATEGORY = "latent/video"
def generate(self, width, height, length, batch_size=1):
@classmethod
def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput:
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
return ({"samples":latent}, )
return io.NodeOutput({"samples":latent})
generate = execute # TODO: remove
PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = (
"<|start_header_id|>system<|end_header_id|>\n\n<image>\nDescribe the video by detailing the following aspects according to the reference image: "
@@ -50,45 +68,61 @@ PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = (
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
class TextEncodeHunyuanVideo_ImageToVideo:
class TextEncodeHunyuanVideo_ImageToVideo(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {
"clip": ("CLIP", ),
"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
"prompt": ("STRING", {"multiline": True, "dynamicPrompts": True}),
"image_interleave": ("INT", {"default": 2, "min": 1, "max": 512, "tooltip": "How much the image influences things vs the text prompt. Higher number means more influence from the text prompt."}),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "encode"
def define_schema(cls):
return io.Schema(
node_id="TextEncodeHunyuanVideo_ImageToVideo",
category="advanced/conditioning",
inputs=[
io.Clip.Input("clip"),
io.ClipVisionOutput.Input("clip_vision_output"),
io.String.Input("prompt", multiline=True, dynamic_prompts=True),
io.Int.Input(
"image_interleave",
default=2,
min=1,
max=512,
tooltip="How much the image influences things vs the text prompt. Higher number means more influence from the text prompt.",
),
],
outputs=[
io.Conditioning.Output(),
],
)
CATEGORY = "advanced/conditioning"
def encode(self, clip, clip_vision_output, prompt, image_interleave):
@classmethod
def execute(cls, clip, clip_vision_output, prompt, image_interleave) -> io.NodeOutput:
tokens = clip.tokenize(prompt, llama_template=PROMPT_TEMPLATE_ENCODE_VIDEO_I2V, image_embeds=clip_vision_output.mm_projected, image_interleave=image_interleave)
return (clip.encode_from_tokens_scheduled(tokens), )
return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))
class HunyuanImageToVideo:
encode = execute # TODO: remove
class HunyuanImageToVideo(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"vae": ("VAE", ),
"width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 53, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
"guidance_type": (["v1 (concat)", "v2 (replace)", "custom"], )
},
"optional": {"start_image": ("IMAGE", ),
}}
def define_schema(cls):
return io.Schema(
node_id="HunyuanImageToVideo",
category="conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
io.Vae.Input("vae"),
io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("length", default=53, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.Int.Input("batch_size", default=1, min=1, max=4096),
io.Combo.Input("guidance_type", options=["v1 (concat)", "v2 (replace)", "custom"]),
io.Image.Input("start_image", optional=True),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Latent.Output(display_name="latent"),
],
)
RETURN_TYPES = ("CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/video_models"
def encode(self, positive, vae, width, height, length, batch_size, guidance_type, start_image=None):
@classmethod
def execute(cls, positive, vae, width, height, length, batch_size, guidance_type, start_image=None) -> io.NodeOutput:
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
out_latent = {}
@@ -111,51 +145,76 @@ class HunyuanImageToVideo:
positive = node_helpers.conditioning_set_values(positive, cond)
out_latent["samples"] = latent
return (positive, out_latent)
return io.NodeOutput(positive, out_latent)
class EmptyHunyuanImageLatent:
encode = execute # TODO: remove
class EmptyHunyuanImageLatent(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "width": ("INT", {"default": 2048, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
"height": ("INT", {"default": 2048, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "generate"
def define_schema(cls):
return io.Schema(
node_id="EmptyHunyuanImageLatent",
category="latent",
inputs=[
io.Int.Input("width", default=2048, min=64, max=nodes.MAX_RESOLUTION, step=32),
io.Int.Input("height", default=2048, min=64, max=nodes.MAX_RESOLUTION, step=32),
io.Int.Input("batch_size", default=1, min=1, max=4096),
],
outputs=[
io.Latent.Output(),
],
)
CATEGORY = "latent"
def generate(self, width, height, batch_size=1):
@classmethod
def execute(cls, width, height, batch_size=1) -> io.NodeOutput:
latent = torch.zeros([batch_size, 64, height // 32, width // 32], device=comfy.model_management.intermediate_device())
return ({"samples":latent}, )
return io.NodeOutput({"samples":latent})
class HunyuanRefinerLatent:
generate = execute # TODO: remove
class HunyuanRefinerLatent(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"latent": ("LATENT", ),
"noise_augmentation": ("FLOAT", {"default": 0.10, "min": 0.0, "max": 1.0, "step": 0.01}),
}}
def define_schema(cls):
return io.Schema(
node_id="HunyuanRefinerLatent",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Latent.Input("latent"),
io.Float.Input("noise_augmentation", default=0.10, min=0.0, max=1.0, step=0.01),
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
],
)
FUNCTION = "execute"
def execute(self, positive, negative, latent, noise_augmentation):
@classmethod
def execute(cls, positive, negative, latent, noise_augmentation) -> io.NodeOutput:
latent = latent["samples"]
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": latent, "noise_augmentation": noise_augmentation})
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": latent, "noise_augmentation": noise_augmentation})
out_latent = {}
out_latent["samples"] = torch.zeros([latent.shape[0], 32, latent.shape[-3], latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device())
return (positive, negative, out_latent)
return io.NodeOutput(positive, negative, out_latent)
NODE_CLASS_MAPPINGS = {
"CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT,
"TextEncodeHunyuanVideo_ImageToVideo": TextEncodeHunyuanVideo_ImageToVideo,
"EmptyHunyuanLatentVideo": EmptyHunyuanLatentVideo,
"HunyuanImageToVideo": HunyuanImageToVideo,
"EmptyHunyuanImageLatent": EmptyHunyuanImageLatent,
"HunyuanRefinerLatent": HunyuanRefinerLatent,
}
class HunyuanExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
CLIPTextEncodeHunyuanDiT,
TextEncodeHunyuanVideo_ImageToVideo,
EmptyHunyuanLatentVideo,
HunyuanImageToVideo,
EmptyHunyuanImageLatent,
HunyuanRefinerLatent,
]
async def comfy_entrypoint() -> HunyuanExtension:
return HunyuanExtension()
+224 -170
View File
@@ -2,6 +2,8 @@ import comfy.utils
import comfy_extras.nodes_post_processing
import torch
import nodes
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
def reshape_latent_to(target_shape, latent, repeat_batch=True):
@@ -13,17 +15,23 @@ def reshape_latent_to(target_shape, latent, repeat_batch=True):
return latent
class LatentAdd:
class LatentAdd(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}}
def define_schema(cls):
return io.Schema(
node_id="LatentAdd",
category="latent/advanced",
inputs=[
io.Latent.Input("samples1"),
io.Latent.Input("samples2"),
],
outputs=[
io.Latent.Output(),
],
)
RETURN_TYPES = ("LATENT",)
FUNCTION = "op"
CATEGORY = "latent/advanced"
def op(self, samples1, samples2):
@classmethod
def execute(cls, samples1, samples2) -> io.NodeOutput:
samples_out = samples1.copy()
s1 = samples1["samples"]
@@ -31,19 +39,25 @@ class LatentAdd:
s2 = reshape_latent_to(s1.shape, s2)
samples_out["samples"] = s1 + s2
return (samples_out,)
return io.NodeOutput(samples_out)
class LatentSubtract:
class LatentSubtract(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}}
def define_schema(cls):
return io.Schema(
node_id="LatentSubtract",
category="latent/advanced",
inputs=[
io.Latent.Input("samples1"),
io.Latent.Input("samples2"),
],
outputs=[
io.Latent.Output(),
],
)
RETURN_TYPES = ("LATENT",)
FUNCTION = "op"
CATEGORY = "latent/advanced"
def op(self, samples1, samples2):
@classmethod
def execute(cls, samples1, samples2) -> io.NodeOutput:
samples_out = samples1.copy()
s1 = samples1["samples"]
@@ -51,41 +65,49 @@ class LatentSubtract:
s2 = reshape_latent_to(s1.shape, s2)
samples_out["samples"] = s1 - s2
return (samples_out,)
return io.NodeOutput(samples_out)
class LatentMultiply:
class LatentMultiply(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",),
"multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
}}
def define_schema(cls):
return io.Schema(
node_id="LatentMultiply",
category="latent/advanced",
inputs=[
io.Latent.Input("samples"),
io.Float.Input("multiplier", default=1.0, min=-10.0, max=10.0, step=0.01),
],
outputs=[
io.Latent.Output(),
],
)
RETURN_TYPES = ("LATENT",)
FUNCTION = "op"
CATEGORY = "latent/advanced"
def op(self, samples, multiplier):
@classmethod
def execute(cls, samples, multiplier) -> io.NodeOutput:
samples_out = samples.copy()
s1 = samples["samples"]
samples_out["samples"] = s1 * multiplier
return (samples_out,)
return io.NodeOutput(samples_out)
class LatentInterpolate:
class LatentInterpolate(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples1": ("LATENT",),
"samples2": ("LATENT",),
"ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}}
def define_schema(cls):
return io.Schema(
node_id="LatentInterpolate",
category="latent/advanced",
inputs=[
io.Latent.Input("samples1"),
io.Latent.Input("samples2"),
io.Float.Input("ratio", default=1.0, min=0.0, max=1.0, step=0.01),
],
outputs=[
io.Latent.Output(),
],
)
RETURN_TYPES = ("LATENT",)
FUNCTION = "op"
CATEGORY = "latent/advanced"
def op(self, samples1, samples2, ratio):
@classmethod
def execute(cls, samples1, samples2, ratio) -> io.NodeOutput:
samples_out = samples1.copy()
s1 = samples1["samples"]
@@ -104,19 +126,26 @@ class LatentInterpolate:
st = torch.nan_to_num(t / mt)
samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio))
return (samples_out,)
return io.NodeOutput(samples_out)
class LatentConcat:
class LatentConcat(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",), "dim": (["x", "-x", "y", "-y", "t", "-t"], )}}
def define_schema(cls):
return io.Schema(
node_id="LatentConcat",
category="latent/advanced",
inputs=[
io.Latent.Input("samples1"),
io.Latent.Input("samples2"),
io.Combo.Input("dim", options=["x", "-x", "y", "-y", "t", "-t"]),
],
outputs=[
io.Latent.Output(),
],
)
RETURN_TYPES = ("LATENT",)
FUNCTION = "op"
CATEGORY = "latent/advanced"
def op(self, samples1, samples2, dim):
@classmethod
def execute(cls, samples1, samples2, dim) -> io.NodeOutput:
samples_out = samples1.copy()
s1 = samples1["samples"]
@@ -136,22 +165,27 @@ class LatentConcat:
dim = -3
samples_out["samples"] = torch.cat(c, dim=dim)
return (samples_out,)
return io.NodeOutput(samples_out)
class LatentCut:
class LatentCut(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {"samples": ("LATENT",),
"dim": (["x", "y", "t"], ),
"index": ("INT", {"default": 0, "min": -nodes.MAX_RESOLUTION, "max": nodes.MAX_RESOLUTION, "step": 1}),
"amount": ("INT", {"default": 1, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 1})}}
def define_schema(cls):
return io.Schema(
node_id="LatentCut",
category="latent/advanced",
inputs=[
io.Latent.Input("samples"),
io.Combo.Input("dim", options=["x", "y", "t"]),
io.Int.Input("index", default=0, min=-nodes.MAX_RESOLUTION, max=nodes.MAX_RESOLUTION, step=1),
io.Int.Input("amount", default=1, min=1, max=nodes.MAX_RESOLUTION, step=1),
],
outputs=[
io.Latent.Output(),
],
)
RETURN_TYPES = ("LATENT",)
FUNCTION = "op"
CATEGORY = "latent/advanced"
def op(self, samples, dim, index, amount):
@classmethod
def execute(cls, samples, dim, index, amount) -> io.NodeOutput:
samples_out = samples.copy()
s1 = samples["samples"]
@@ -171,19 +205,25 @@ class LatentCut:
amount = min(-index, amount)
samples_out["samples"] = torch.narrow(s1, dim, index, amount)
return (samples_out,)
return io.NodeOutput(samples_out)
class LatentBatch:
class LatentBatch(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}}
def define_schema(cls):
return io.Schema(
node_id="LatentBatch",
category="latent/batch",
inputs=[
io.Latent.Input("samples1"),
io.Latent.Input("samples2"),
],
outputs=[
io.Latent.Output(),
],
)
RETURN_TYPES = ("LATENT",)
FUNCTION = "batch"
CATEGORY = "latent/batch"
def batch(self, samples1, samples2):
@classmethod
def execute(cls, samples1, samples2) -> io.NodeOutput:
samples_out = samples1.copy()
s1 = samples1["samples"]
s2 = samples2["samples"]
@@ -192,20 +232,25 @@ class LatentBatch:
s = torch.cat((s1, s2), dim=0)
samples_out["samples"] = s
samples_out["batch_index"] = samples1.get("batch_index", [x for x in range(0, s1.shape[0])]) + samples2.get("batch_index", [x for x in range(0, s2.shape[0])])
return (samples_out,)
return io.NodeOutput(samples_out)
class LatentBatchSeedBehavior:
class LatentBatchSeedBehavior(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",),
"seed_behavior": (["random", "fixed"],{"default": "fixed"}),}}
def define_schema(cls):
return io.Schema(
node_id="LatentBatchSeedBehavior",
category="latent/advanced",
inputs=[
io.Latent.Input("samples"),
io.Combo.Input("seed_behavior", options=["random", "fixed"], default="fixed"),
],
outputs=[
io.Latent.Output(),
],
)
RETURN_TYPES = ("LATENT",)
FUNCTION = "op"
CATEGORY = "latent/advanced"
def op(self, samples, seed_behavior):
@classmethod
def execute(cls, samples, seed_behavior) -> io.NodeOutput:
samples_out = samples.copy()
latent = samples["samples"]
if seed_behavior == "random":
@@ -215,41 +260,50 @@ class LatentBatchSeedBehavior:
batch_number = samples_out.get("batch_index", [0])[0]
samples_out["batch_index"] = [batch_number] * latent.shape[0]
return (samples_out,)
return io.NodeOutput(samples_out)
class LatentApplyOperation:
class LatentApplyOperation(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",),
"operation": ("LATENT_OPERATION",),
}}
def define_schema(cls):
return io.Schema(
node_id="LatentApplyOperation",
category="latent/advanced/operations",
is_experimental=True,
inputs=[
io.Latent.Input("samples"),
io.LatentOperation.Input("operation"),
],
outputs=[
io.Latent.Output(),
],
)
RETURN_TYPES = ("LATENT",)
FUNCTION = "op"
CATEGORY = "latent/advanced/operations"
EXPERIMENTAL = True
def op(self, samples, operation):
@classmethod
def execute(cls, samples, operation) -> io.NodeOutput:
samples_out = samples.copy()
s1 = samples["samples"]
samples_out["samples"] = operation(latent=s1)
return (samples_out,)
return io.NodeOutput(samples_out)
class LatentApplyOperationCFG:
class LatentApplyOperationCFG(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"operation": ("LATENT_OPERATION",),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
def define_schema(cls):
return io.Schema(
node_id="LatentApplyOperationCFG",
category="latent/advanced/operations",
is_experimental=True,
inputs=[
io.Model.Input("model"),
io.LatentOperation.Input("operation"),
],
outputs=[
io.Model.Output(),
],
)
CATEGORY = "latent/advanced/operations"
EXPERIMENTAL = True
def patch(self, model, operation):
@classmethod
def execute(cls, model, operation) -> io.NodeOutput:
m = model.clone()
def pre_cfg_function(args):
@@ -261,21 +315,25 @@ class LatentApplyOperationCFG:
return conds_out
m.set_model_sampler_pre_cfg_function(pre_cfg_function)
return (m, )
return io.NodeOutput(m)
class LatentOperationTonemapReinhard:
class LatentOperationTonemapReinhard(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "multiplier": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}),
}}
def define_schema(cls):
return io.Schema(
node_id="LatentOperationTonemapReinhard",
category="latent/advanced/operations",
is_experimental=True,
inputs=[
io.Float.Input("multiplier", default=1.0, min=0.0, max=100.0, step=0.01),
],
outputs=[
io.LatentOperation.Output(),
],
)
RETURN_TYPES = ("LATENT_OPERATION",)
FUNCTION = "op"
CATEGORY = "latent/advanced/operations"
EXPERIMENTAL = True
def op(self, multiplier):
@classmethod
def execute(cls, multiplier) -> io.NodeOutput:
def tonemap_reinhard(latent, **kwargs):
latent_vector_magnitude = (torch.linalg.vector_norm(latent, dim=(1)) + 0.0000000001)[:,None]
normalized_latent = latent / latent_vector_magnitude
@@ -291,39 +349,27 @@ class LatentOperationTonemapReinhard:
new_magnitude *= top
return normalized_latent * new_magnitude
return (tonemap_reinhard,)
return io.NodeOutput(tonemap_reinhard)
class LatentOperationSharpen:
class LatentOperationSharpen(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {
"sharpen_radius": ("INT", {
"default": 9,
"min": 1,
"max": 31,
"step": 1
}),
"sigma": ("FLOAT", {
"default": 1.0,
"min": 0.1,
"max": 10.0,
"step": 0.1
}),
"alpha": ("FLOAT", {
"default": 0.1,
"min": 0.0,
"max": 5.0,
"step": 0.01
}),
}}
def define_schema(cls):
return io.Schema(
node_id="LatentOperationSharpen",
category="latent/advanced/operations",
is_experimental=True,
inputs=[
io.Int.Input("sharpen_radius", default=9, min=1, max=31, step=1),
io.Float.Input("sigma", default=1.0, min=0.1, max=10.0, step=0.1),
io.Float.Input("alpha", default=0.1, min=0.0, max=5.0, step=0.01),
],
outputs=[
io.LatentOperation.Output(),
],
)
RETURN_TYPES = ("LATENT_OPERATION",)
FUNCTION = "op"
CATEGORY = "latent/advanced/operations"
EXPERIMENTAL = True
def op(self, sharpen_radius, sigma, alpha):
@classmethod
def execute(cls, sharpen_radius, sigma, alpha) -> io.NodeOutput:
def sharpen(latent, **kwargs):
luminance = (torch.linalg.vector_norm(latent, dim=(1)) + 1e-6)[:,None]
normalized_latent = latent / luminance
@@ -340,19 +386,27 @@ class LatentOperationSharpen:
sharpened = torch.nn.functional.conv2d(padded_image, kernel.repeat(channels, 1, 1).unsqueeze(1), padding=kernel_size // 2, groups=channels)[:,:,sharpen_radius:-sharpen_radius, sharpen_radius:-sharpen_radius]
return luminance * sharpened
return (sharpen,)
return io.NodeOutput(sharpen)
NODE_CLASS_MAPPINGS = {
"LatentAdd": LatentAdd,
"LatentSubtract": LatentSubtract,
"LatentMultiply": LatentMultiply,
"LatentInterpolate": LatentInterpolate,
"LatentConcat": LatentConcat,
"LatentCut": LatentCut,
"LatentBatch": LatentBatch,
"LatentBatchSeedBehavior": LatentBatchSeedBehavior,
"LatentApplyOperation": LatentApplyOperation,
"LatentApplyOperationCFG": LatentApplyOperationCFG,
"LatentOperationTonemapReinhard": LatentOperationTonemapReinhard,
"LatentOperationSharpen": LatentOperationSharpen,
}
class LatentExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
LatentAdd,
LatentSubtract,
LatentMultiply,
LatentInterpolate,
LatentConcat,
LatentCut,
LatentBatch,
LatentBatchSeedBehavior,
LatentApplyOperation,
LatentApplyOperationCFG,
LatentOperationTonemapReinhard,
LatentOperationSharpen,
]
async def comfy_entrypoint() -> LatentExtension:
return LatentExtension()
+42 -28
View File
@@ -5,6 +5,8 @@ import folder_paths
import os
import logging
from enum import Enum
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
CLAMP_QUANTILE = 0.99
@@ -71,32 +73,40 @@ def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora
output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = sd[k].contiguous().half().cpu()
return output_sd
class LoraSave:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
class LoraSave(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="LoraSave",
display_name="Extract and Save Lora",
category="_for_testing",
inputs=[
io.String.Input("filename_prefix", default="loras/ComfyUI_extracted_lora"),
io.Int.Input("rank", default=8, min=1, max=4096, step=1),
io.Combo.Input("lora_type", options=tuple(LORA_TYPES.keys())),
io.Boolean.Input("bias_diff", default=True),
io.Model.Input(
"model_diff",
tooltip="The ModelSubtract output to be converted to a lora.",
optional=True,
),
io.Clip.Input(
"text_encoder_diff",
tooltip="The CLIPSubtract output to be converted to a lora.",
optional=True,
),
],
is_experimental=True,
is_output_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {"required": {"filename_prefix": ("STRING", {"default": "loras/ComfyUI_extracted_lora"}),
"rank": ("INT", {"default": 8, "min": 1, "max": 4096, "step": 1}),
"lora_type": (tuple(LORA_TYPES.keys()),),
"bias_diff": ("BOOLEAN", {"default": True}),
},
"optional": {"model_diff": ("MODEL", {"tooltip": "The ModelSubtract output to be converted to a lora."}),
"text_encoder_diff": ("CLIP", {"tooltip": "The CLIPSubtract output to be converted to a lora."})},
}
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True
CATEGORY = "_for_testing"
def save(self, filename_prefix, rank, lora_type, bias_diff, model_diff=None, text_encoder_diff=None):
def execute(cls, filename_prefix, rank, lora_type, bias_diff, model_diff=None, text_encoder_diff=None) -> io.NodeOutput:
if model_diff is None and text_encoder_diff is None:
return {}
return io.NodeOutput()
lora_type = LORA_TYPES.get(lora_type)
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory())
output_sd = {}
if model_diff is not None:
@@ -108,12 +118,16 @@ class LoraSave:
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
comfy.utils.save_torch_file(output_sd, output_checkpoint, metadata=None)
return {}
return io.NodeOutput()
NODE_CLASS_MAPPINGS = {
"LoraSave": LoraSave
}
NODE_DISPLAY_NAME_MAPPINGS = {
"LoraSave": "Extract and Save Lora"
}
class LoraSaveExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
LoraSave,
]
async def comfy_entrypoint() -> LoraSaveExtension:
return LoraSaveExtension()
+39 -22
View File
@@ -1,24 +1,33 @@
from typing_extensions import override
import comfy.utils
from comfy_api.latest import ComfyExtension, io
class PatchModelAddDownscale:
upscale_methods = ["bicubic", "nearest-exact", "bilinear", "area", "bislerp"]
class PatchModelAddDownscale(io.ComfyNode):
UPSCALE_METHODS = ["bicubic", "nearest-exact", "bilinear", "area", "bislerp"]
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"block_number": ("INT", {"default": 3, "min": 1, "max": 32, "step": 1}),
"downscale_factor": ("FLOAT", {"default": 2.0, "min": 0.1, "max": 9.0, "step": 0.001}),
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
"end_percent": ("FLOAT", {"default": 0.35, "min": 0.0, "max": 1.0, "step": 0.001}),
"downscale_after_skip": ("BOOLEAN", {"default": True}),
"downscale_method": (s.upscale_methods,),
"upscale_method": (s.upscale_methods,),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
def define_schema(cls):
return io.Schema(
node_id="PatchModelAddDownscale",
display_name="PatchModelAddDownscale (Kohya Deep Shrink)",
category="model_patches/unet",
inputs=[
io.Model.Input("model"),
io.Int.Input("block_number", default=3, min=1, max=32, step=1),
io.Float.Input("downscale_factor", default=2.0, min=0.1, max=9.0, step=0.001),
io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001),
io.Float.Input("end_percent", default=0.35, min=0.0, max=1.0, step=0.001),
io.Boolean.Input("downscale_after_skip", default=True),
io.Combo.Input("downscale_method", options=cls.UPSCALE_METHODS),
io.Combo.Input("upscale_method", options=cls.UPSCALE_METHODS),
],
outputs=[
io.Model.Output(),
],
)
CATEGORY = "model_patches/unet"
def patch(self, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method):
@classmethod
def execute(cls, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method) -> io.NodeOutput:
model_sampling = model.get_model_object("model_sampling")
sigma_start = model_sampling.percent_to_sigma(start_percent)
sigma_end = model_sampling.percent_to_sigma(end_percent)
@@ -41,13 +50,21 @@ class PatchModelAddDownscale:
else:
m.set_model_input_block_patch(input_block_patch)
m.set_model_output_block_patch(output_block_patch)
return (m, )
return io.NodeOutput(m)
NODE_CLASS_MAPPINGS = {
"PatchModelAddDownscale": PatchModelAddDownscale,
}
NODE_DISPLAY_NAME_MAPPINGS = {
# Sampling
"PatchModelAddDownscale": "PatchModelAddDownscale (Kohya Deep Shrink)",
"PatchModelAddDownscale": "",
}
class ModelDownscaleExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
PatchModelAddDownscale,
]
async def comfy_entrypoint() -> ModelDownscaleExtension:
return ModelDownscaleExtension()
+1 -1
View File
@@ -25,7 +25,7 @@ class PreviewAny():
value = str(source)
elif source is not None:
try:
value = json.dumps(source)
value = json.dumps(source, indent=4)
except Exception:
try:
value = str(source)
+153 -79
View File
@@ -3,64 +3,83 @@ import comfy.sd
import comfy.model_management
import nodes
import torch
import comfy_extras.nodes_slg
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
from comfy_extras.nodes_slg import SkipLayerGuidanceDiT
class TripleCLIPLoader:
class TripleCLIPLoader(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ), "clip_name2": (folder_paths.get_filename_list("text_encoders"), ), "clip_name3": (folder_paths.get_filename_list("text_encoders"), )
}}
RETURN_TYPES = ("CLIP",)
FUNCTION = "load_clip"
def define_schema(cls):
return io.Schema(
node_id="TripleCLIPLoader",
category="advanced/loaders",
description="[Recipes]\n\nsd3: clip-l, clip-g, t5",
inputs=[
io.Combo.Input("clip_name1", options=folder_paths.get_filename_list("text_encoders")),
io.Combo.Input("clip_name2", options=folder_paths.get_filename_list("text_encoders")),
io.Combo.Input("clip_name3", options=folder_paths.get_filename_list("text_encoders")),
],
outputs=[
io.Clip.Output(),
],
)
CATEGORY = "advanced/loaders"
DESCRIPTION = "[Recipes]\n\nsd3: clip-l, clip-g, t5"
def load_clip(self, clip_name1, clip_name2, clip_name3):
@classmethod
def execute(cls, clip_name1, clip_name2, clip_name3) -> io.NodeOutput:
clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
clip_path3 = folder_paths.get_full_path_or_raise("text_encoders", clip_name3)
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3], embedding_directory=folder_paths.get_folder_paths("embeddings"))
return (clip,)
return io.NodeOutput(clip)
load_clip = execute # TODO: remove
class EmptySD3LatentImage:
def __init__(self):
self.device = comfy.model_management.intermediate_device()
class EmptySD3LatentImage(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="EmptySD3LatentImage",
category="latent/sd3",
inputs=[
io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("batch_size", default=1, min=1, max=4096),
],
outputs=[
io.Latent.Output(),
],
)
@classmethod
def INPUT_TYPES(s):
return {"required": { "width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "generate"
def execute(cls, width, height, batch_size=1) -> io.NodeOutput:
latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=comfy.model_management.intermediate_device())
return io.NodeOutput({"samples":latent})
CATEGORY = "latent/sd3"
def generate(self, width, height, batch_size=1):
latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=self.device)
return ({"samples":latent}, )
generate = execute # TODO: remove
class CLIPTextEncodeSD3:
class CLIPTextEncodeSD3(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {
"clip": ("CLIP", ),
"clip_l": ("STRING", {"multiline": True, "dynamicPrompts": True}),
"clip_g": ("STRING", {"multiline": True, "dynamicPrompts": True}),
"t5xxl": ("STRING", {"multiline": True, "dynamicPrompts": True}),
"empty_padding": (["none", "empty_prompt"], )
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "encode"
def define_schema(cls):
return io.Schema(
node_id="CLIPTextEncodeSD3",
category="advanced/conditioning",
inputs=[
io.Clip.Input("clip"),
io.String.Input("clip_l", multiline=True, dynamic_prompts=True),
io.String.Input("clip_g", multiline=True, dynamic_prompts=True),
io.String.Input("t5xxl", multiline=True, dynamic_prompts=True),
io.Combo.Input("empty_padding", options=["none", "empty_prompt"]),
],
outputs=[
io.Conditioning.Output(),
],
)
CATEGORY = "advanced/conditioning"
def encode(self, clip, clip_l, clip_g, t5xxl, empty_padding):
@classmethod
def execute(cls, clip, clip_l, clip_g, t5xxl, empty_padding) -> io.NodeOutput:
no_padding = empty_padding == "none"
tokens = clip.tokenize(clip_g)
@@ -82,57 +101,112 @@ class CLIPTextEncodeSD3:
tokens["l"] += empty["l"]
while len(tokens["l"]) > len(tokens["g"]):
tokens["g"] += empty["g"]
return (clip.encode_from_tokens_scheduled(tokens), )
return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))
encode = execute # TODO: remove
class ControlNetApplySD3(nodes.ControlNetApplyAdvanced):
class ControlNetApplySD3(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"control_net": ("CONTROL_NET", ),
"vae": ("VAE", ),
"image": ("IMAGE", ),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
}}
CATEGORY = "conditioning/controlnet"
DEPRECATED = True
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="ControlNetApplySD3",
display_name="Apply Controlnet with VAE",
category="conditioning/controlnet",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.ControlNet.Input("control_net"),
io.Vae.Input("vae"),
io.Image.Input("image"),
io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001),
io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
],
is_deprecated=True,
)
@classmethod
def execute(cls, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None) -> io.NodeOutput:
if strength == 0:
return io.NodeOutput(positive, negative)
control_hint = image.movedim(-1, 1)
cnets = {}
out = []
for conditioning in [positive, negative]:
c = []
for t in conditioning:
d = t[1].copy()
prev_cnet = d.get('control', None)
if prev_cnet in cnets:
c_net = cnets[prev_cnet]
else:
c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent),
vae=vae, extra_concat=[])
c_net.set_previous_controlnet(prev_cnet)
cnets[prev_cnet] = c_net
d['control'] = c_net
d['control_apply_to_uncond'] = False
n = [t[0], d]
c.append(n)
out.append(c)
return io.NodeOutput(out[0], out[1])
apply_controlnet = execute # TODO: remove
class SkipLayerGuidanceSD3(comfy_extras.nodes_slg.SkipLayerGuidanceDiT):
class SkipLayerGuidanceSD3(io.ComfyNode):
'''
Enhance guidance towards detailed dtructure by having another set of CFG negative with skipped layers.
Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377)
Experimental implementation by Dango233@StabilityAI.
'''
@classmethod
def INPUT_TYPES(s):
return {"required": {"model": ("MODEL", ),
"layers": ("STRING", {"default": "7, 8, 9", "multiline": False}),
"scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 10.0, "step": 0.1}),
"start_percent": ("FLOAT", {"default": 0.01, "min": 0.0, "max": 1.0, "step": 0.001}),
"end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001})
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "skip_guidance_sd3"
def define_schema(cls):
return io.Schema(
node_id="SkipLayerGuidanceSD3",
category="advanced/guidance",
description="Generic version of SkipLayerGuidance node that can be used on every DiT model.",
inputs=[
io.Model.Input("model"),
io.String.Input("layers", default="7, 8, 9", multiline=False),
io.Float.Input("scale", default=3.0, min=0.0, max=10.0, step=0.1),
io.Float.Input("start_percent", default=0.01, min=0.0, max=1.0, step=0.001),
io.Float.Input("end_percent", default=0.15, min=0.0, max=1.0, step=0.001),
],
outputs=[
io.Model.Output(),
],
is_experimental=True,
)
CATEGORY = "advanced/guidance"
@classmethod
def execute(cls, model, layers, scale, start_percent, end_percent) -> io.NodeOutput:
return SkipLayerGuidanceDiT().execute(model=model, scale=scale, start_percent=start_percent, end_percent=end_percent, double_layers=layers)
def skip_guidance_sd3(self, model, layers, scale, start_percent, end_percent):
return self.skip_guidance(model=model, scale=scale, start_percent=start_percent, end_percent=end_percent, double_layers=layers)
skip_guidance_sd3 = execute # TODO: remove
NODE_CLASS_MAPPINGS = {
"TripleCLIPLoader": TripleCLIPLoader,
"EmptySD3LatentImage": EmptySD3LatentImage,
"CLIPTextEncodeSD3": CLIPTextEncodeSD3,
"ControlNetApplySD3": ControlNetApplySD3,
"SkipLayerGuidanceSD3": SkipLayerGuidanceSD3,
}
class SD3Extension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
TripleCLIPLoader,
EmptySD3LatentImage,
CLIPTextEncodeSD3,
ControlNetApplySD3,
SkipLayerGuidanceSD3,
]
NODE_DISPLAY_NAME_MAPPINGS = {
# Sampling
"ControlNetApplySD3": "Apply Controlnet with VAE",
}
async def comfy_entrypoint() -> SD3Extension:
return SD3Extension()
+66 -42
View File
@@ -1,33 +1,40 @@
import comfy.model_patcher
import comfy.samplers
import re
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
class SkipLayerGuidanceDiT:
class SkipLayerGuidanceDiT(io.ComfyNode):
'''
Enhance guidance towards detailed dtructure by having another set of CFG negative with skipped layers.
Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377)
Original experimental implementation for SD3 by Dango233@StabilityAI.
'''
@classmethod
def INPUT_TYPES(s):
return {"required": {"model": ("MODEL", ),
"double_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}),
"single_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}),
"scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 10.0, "step": 0.1}),
"start_percent": ("FLOAT", {"default": 0.01, "min": 0.0, "max": 1.0, "step": 0.001}),
"end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001}),
"rescaling_scale": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "skip_guidance"
EXPERIMENTAL = True
def define_schema(cls):
return io.Schema(
node_id="SkipLayerGuidanceDiT",
category="advanced/guidance",
description="Generic version of SkipLayerGuidance node that can be used on every DiT model.",
is_experimental=True,
inputs=[
io.Model.Input("model"),
io.String.Input("double_layers", default="7, 8, 9"),
io.String.Input("single_layers", default="7, 8, 9"),
io.Float.Input("scale", default=3.0, min=0.0, max=10.0, step=0.1),
io.Float.Input("start_percent", default=0.01, min=0.0, max=1.0, step=0.001),
io.Float.Input("end_percent", default=0.15, min=0.0, max=1.0, step=0.001),
io.Float.Input("rescaling_scale", default=0.0, min=0.0, max=10.0, step=0.01),
],
outputs=[
io.Model.Output(),
],
)
DESCRIPTION = "Generic version of SkipLayerGuidance node that can be used on every DiT model."
CATEGORY = "advanced/guidance"
def skip_guidance(self, model, scale, start_percent, end_percent, double_layers="", single_layers="", rescaling_scale=0):
@classmethod
def execute(cls, model, scale, start_percent, end_percent, double_layers="", single_layers="", rescaling_scale=0) -> io.NodeOutput:
# check if layer is comma separated integers
def skip(args, extra_args):
return args
@@ -43,7 +50,7 @@ class SkipLayerGuidanceDiT:
single_layers = [int(i) for i in single_layers]
if len(double_layers) == 0 and len(single_layers) == 0:
return (model, )
return io.NodeOutput(model)
def post_cfg_function(args):
model = args["model"]
@@ -76,29 +83,36 @@ class SkipLayerGuidanceDiT:
m = model.clone()
m.set_model_sampler_post_cfg_function(post_cfg_function)
return (m, )
return io.NodeOutput(m)
class SkipLayerGuidanceDiTSimple:
skip_guidance = execute # TODO: remove
class SkipLayerGuidanceDiTSimple(io.ComfyNode):
'''
Simple version of the SkipLayerGuidanceDiT node that only modifies the uncond pass.
'''
@classmethod
def INPUT_TYPES(s):
return {"required": {"model": ("MODEL", ),
"double_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}),
"single_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}),
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "skip_guidance"
EXPERIMENTAL = True
def define_schema(cls):
return io.Schema(
node_id="SkipLayerGuidanceDiTSimple",
category="advanced/guidance",
description="Simple version of the SkipLayerGuidanceDiT node that only modifies the uncond pass.",
is_experimental=True,
inputs=[
io.Model.Input("model"),
io.String.Input("double_layers", default="7, 8, 9"),
io.String.Input("single_layers", default="7, 8, 9"),
io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001),
io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001),
],
outputs=[
io.Model.Output(),
],
)
DESCRIPTION = "Simple version of the SkipLayerGuidanceDiT node that only modifies the uncond pass."
CATEGORY = "advanced/guidance"
def skip_guidance(self, model, start_percent, end_percent, double_layers="", single_layers=""):
@classmethod
def execute(cls, model, start_percent, end_percent, double_layers="", single_layers="") -> io.NodeOutput:
def skip(args, extra_args):
return args
@@ -113,7 +127,7 @@ class SkipLayerGuidanceDiTSimple:
single_layers = [int(i) for i in single_layers]
if len(double_layers) == 0 and len(single_layers) == 0:
return (model, )
return io.NodeOutput(model)
def calc_cond_batch_function(args):
x = args["input"]
@@ -144,9 +158,19 @@ class SkipLayerGuidanceDiTSimple:
m = model.clone()
m.set_model_sampler_calc_cond_batch_function(calc_cond_batch_function)
return (m, )
return io.NodeOutput(m)
NODE_CLASS_MAPPINGS = {
"SkipLayerGuidanceDiT": SkipLayerGuidanceDiT,
"SkipLayerGuidanceDiTSimple": SkipLayerGuidanceDiTSimple,
}
skip_guidance = execute # TODO: remove
class SkipLayerGuidanceExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
SkipLayerGuidanceDiT,
SkipLayerGuidanceDiTSimple,
]
async def comfy_entrypoint() -> SkipLayerGuidanceExtension:
return SkipLayerGuidanceExtension()
+51 -25
View File
@@ -4,6 +4,8 @@ from comfy import model_management
import torch
import comfy.utils
import folder_paths
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
try:
from spandrel_extra_arches import EXTRA_REGISTRY
@@ -13,17 +15,23 @@ try:
except:
pass
class UpscaleModelLoader:
class UpscaleModelLoader(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "model_name": (folder_paths.get_filename_list("upscale_models"), ),
}}
RETURN_TYPES = ("UPSCALE_MODEL",)
FUNCTION = "load_model"
def define_schema(cls):
return io.Schema(
node_id="UpscaleModelLoader",
display_name="Load Upscale Model",
category="loaders",
inputs=[
io.Combo.Input("model_name", options=folder_paths.get_filename_list("upscale_models")),
],
outputs=[
io.UpscaleModel.Output(),
],
)
CATEGORY = "loaders"
def load_model(self, model_name):
@classmethod
def execute(cls, model_name) -> io.NodeOutput:
model_path = folder_paths.get_full_path_or_raise("upscale_models", model_name)
sd = comfy.utils.load_torch_file(model_path, safe_load=True)
if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd:
@@ -33,21 +41,29 @@ class UpscaleModelLoader:
if not isinstance(out, ImageModelDescriptor):
raise Exception("Upscale model must be a single-image model.")
return (out, )
return io.NodeOutput(out)
load_model = execute # TODO: remove
class ImageUpscaleWithModel:
class ImageUpscaleWithModel(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "upscale_model": ("UPSCALE_MODEL",),
"image": ("IMAGE",),
}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "upscale"
def define_schema(cls):
return io.Schema(
node_id="ImageUpscaleWithModel",
display_name="Upscale Image (using Model)",
category="image/upscaling",
inputs=[
io.UpscaleModel.Input("upscale_model"),
io.Image.Input("image"),
],
outputs=[
io.Image.Output(),
],
)
CATEGORY = "image/upscaling"
def upscale(self, upscale_model, image):
@classmethod
def execute(cls, upscale_model, image) -> io.NodeOutput:
device = model_management.get_torch_device()
memory_required = model_management.module_size(upscale_model.model)
@@ -75,9 +91,19 @@ class ImageUpscaleWithModel:
upscale_model.to("cpu")
s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0)
return (s,)
return io.NodeOutput(s)
NODE_CLASS_MAPPINGS = {
"UpscaleModelLoader": UpscaleModelLoader,
"ImageUpscaleWithModel": ImageUpscaleWithModel
}
upscale = execute # TODO: remove
class UpscaleModelExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
UpscaleModelLoader,
ImageUpscaleWithModel,
]
async def comfy_entrypoint() -> UpscaleModelExtension:
return UpscaleModelExtension()
+1 -1
View File
@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.3.64"
__version__ = "0.3.65"
+29 -21
View File
@@ -1,25 +1,5 @@
#Rename this to extra_model_paths.yaml and ComfyUI will load it
#config for a1111 ui
#all you have to do is change the base_path to where yours is installed
a111:
base_path: path/to/stable-diffusion-webui/
checkpoints: models/Stable-diffusion
configs: models/Stable-diffusion
vae: models/VAE
loras: |
models/Lora
models/LyCORIS
upscale_models: |
models/ESRGAN
models/RealESRGAN
models/SwinIR
embeddings: embeddings
hypernetworks: models/hypernetworks
controlnet: models/ControlNet
#config for comfyui
#your base path should be either an existing comfy install or a central folder where you store all of your models, loras, etc.
@@ -28,7 +8,9 @@ a111:
# # You can use is_default to mark that these folders should be listed first, and used as the default dirs for eg downloads
# #is_default: true
# checkpoints: models/checkpoints/
# clip: models/clip/
# text_encoders: |
# models/text_encoders/
# models/clip/ # legacy location still supported
# clip_vision: models/clip_vision/
# configs: models/configs/
# controlnet: models/controlnet/
@@ -39,6 +21,32 @@ a111:
# loras: models/loras/
# upscale_models: models/upscale_models/
# vae: models/vae/
# audio_encoders: models/audio_encoders/
# model_patches: models/model_patches/
#config for a1111 ui
#all you have to do is uncomment this (remove the #) and change the base_path to where yours is installed
#a111:
# base_path: path/to/stable-diffusion-webui/
# checkpoints: models/Stable-diffusion
# configs: models/Stable-diffusion
# vae: models/VAE
# loras: |
# models/Lora
# models/LyCORIS
# upscale_models: |
# models/ESRGAN
# models/RealESRGAN
# models/SwinIR
# embeddings: embeddings
# hypernetworks: models/hypernetworks
# controlnet: models/ControlNet
# For a full list of supported keys (style_models, vae_approx, hypernetworks, photomaker,
# model_patches, audio_encoders, classifiers, etc.) see folder_paths.py.
#other_ui:
# base_path: path/to/ui
-2
View File
@@ -2027,7 +2027,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"DiffControlNetLoader": "Load ControlNet Model (diff)",
"StyleModelLoader": "Load Style Model",
"CLIPVisionLoader": "Load CLIP Vision",
"UpscaleModelLoader": "Load Upscale Model",
"UNETLoader": "Load Diffusion Model",
# Conditioning
"CLIPVisionEncode": "CLIP Vision Encode",
@@ -2065,7 +2064,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"LoadImageOutput": "Load Image (from Outputs)",
"ImageScale": "Upscale Image",
"ImageScaleBy": "Upscale Image By",
"ImageUpscaleWithModel": "Upscale Image (using Model)",
"ImageInvert": "Invert Image",
"ImagePadForOutpaint": "Pad Image for Outpainting",
"ImageBatch": "Batch Images",
+1 -2
View File
@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.3.64"
version = "0.3.65"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.9"
@@ -61,7 +61,6 @@ messages_control.disable = [
# next warnings should be fixed in future
"bad-classmethod-argument", # Class method should have 'cls' as first argument
"wrong-import-order", # Standard imports should be placed before third party imports
"logging-fstring-interpolation", # Use lazy % formatting in logging functions
"ungrouped-imports",
"unnecessary-pass",
"unnecessary-lambda-assignment",
+2 -2
View File
@@ -1,6 +1,6 @@
comfyui-frontend-package==1.27.10
comfyui-workflow-templates==0.1.94
comfyui-embedded-docs==0.2.6
comfyui-workflow-templates==0.1.95
comfyui-embedded-docs==0.3.0
torch
torchsde
torchvision