Compare commits

...

74 Commits

Author SHA1 Message Date
comfyanonymous 50549aa252 ComfyUI v0.16.2
Python Linting / Run Ruff (push) Failing after 34s
Python Linting / Run Pylint (push) Failing after 25s
Build package / Build Test (3.10) (push) Failing after 35s
Build package / Build Test (3.11) (push) Failing after 32s
Build package / Build Test (3.12) (push) Failing after 30s
Build package / Build Test (3.13) (push) Failing after 35s
Build package / Build Test (3.14) (push) Failing after 38s
2026-03-05 13:41:06 -05:00
comfyanonymous 1c3b651c0a Refactor. (#12794) 2026-03-05 13:35:56 -05:00
ComfyUI Wiki 5073da57ad chore: update workflow templates to v0.9.10 (#12793) 2026-03-05 10:22:38 -08:00
rattus 42e0e023ee ops: Handle CPU weight in VBAR caster (#12792)
This shouldn't happen but custom nodes gets there. Handle it as best
we can.
2026-03-05 10:22:17 -08:00
rattus 6481569ad4 comfy-aimdo 0.2.7 (#12791)
Comfy-aimdo 0.2.7 fixes a crash when a spurious cudaAsyncFree comes in
and would cause an infinite stack overflow (via detours hooks).

A lock is also introduced on the link list holding the free sections
to avoid any possibility of threaded miscellaneous cuda allocations
being the root cause.
2026-03-05 09:04:24 -08:00
comfyanonymous 6ef82a89b8 ComfyUI v0.16.1
Python Linting / Run Ruff (push) Failing after 43s
Python Linting / Run Pylint (push) Failing after 26s
Build package / Build Test (3.10) (push) Failing after 36s
Build package / Build Test (3.11) (push) Failing after 31s
Build package / Build Test (3.12) (push) Failing after 37s
Build package / Build Test (3.13) (push) Failing after 39s
Build package / Build Test (3.14) (push) Failing after 38s
2026-03-05 10:38:33 -05:00
ComfyUI Wiki da29b797ce Update workflow templates to v0.9.8 (#12788) 2026-03-05 07:23:23 -08:00
Alexander Piskun 9cdfd7403b feat(api-nodes): enable Kling 3.0 Motion Control (#12785) 2026-03-05 07:12:38 -08:00
Alexander Piskun bd21363563 feat(api-nodes-xAI): updated models, pricing, added features (#12756) 2026-03-05 04:29:39 -08:00
comfyanonymous e04d0dbeb8 ComfyUI v0.16.0
Python Linting / Run Ruff (push) Failing after 36s
Python Linting / Run Pylint (push) Failing after 40s
Build package / Build Test (3.10) (push) Failing after 38s
Build package / Build Test (3.11) (push) Failing after 33s
Build package / Build Test (3.12) (push) Failing after 32s
Build package / Build Test (3.13) (push) Failing after 34s
Build package / Build Test (3.14) (push) Failing after 30s
2026-03-05 04:06:29 -05:00
ComfyUI Wiki c8428541a6 chore: update workflow templates to v0.9.7 (#12780) 2026-03-05 03:58:25 -05:00
comfyanonymous 4941671b5a Fix cuda getting initialized in cpu mode. (#12779) 2026-03-05 02:39:51 -05:00
ComfyUI Wiki c5fe8ace68 chore: update workflow templates to v0.9.6 (#12778) 2026-03-05 02:37:35 -05:00
comfyanonymous f2ee7f2d36 Fix cublas ops on dynamic vram. (#12776) 2026-03-05 01:21:55 -05:00
comfyanonymous 43c64b6308 Support the LTXAV 2.3 model. (#12773) 2026-03-04 20:06:20 -05:00
comfyanonymous ac4a943ff3 Initial load device should be cpu when using dynamic vram. (#12766) 2026-03-04 16:33:14 -05:00
rattus 8811db52db comfy-aimdo 0.2.6 (#12764)
Comfy Aimdo 0.2.6 fixes a GPU virtual address leak. This would manfiest
as an error after a number of workflow runs.
2026-03-04 12:12:37 -08:00
Jukka Seppänen 0a7446ade4 Pass tokens when loading text gen model for text generation (#12755)
Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2026-03-04 08:59:56 -08:00
rattus 9b85cf9558 Comfy Aimdo 0.2.5 + Fix offload performance in DynamicVram (#12754)
* ops: dont unpin nothing

This was calling into aimdo in the none case (offloaded weight). Whats worse,
is aimdo syncs for unpinning an offloaded weight, as that is the corner case of
a weight getting evicted by its own use which does require a sync. But this
was heppening every offloaded weight causing slowdown.

* mp: fix get_free_memory policy

The ModelPatcherDynamic get_free_memory was deducting the model from
to try and estimate the conceptual free memory with doing any
offloading. This is kind of what the old memory_memory_required
was estimating in ModelPatcher load logic, however in practical
reality, between over-estimates and padding, the loader usually
underloaded models enough such that sampling could send CFG +/-
through together even when partially loaded.

So don't regress from the status quo and instead go all in on the
idea that offloading is less of an issue than debatching. Tell the
sampler it can use everything.
2026-03-04 07:49:13 -08:00
rattus d531e3fb2a model_patcher: Improve dynamic offload heuristic (#12759)
Define a threshold below which a weight loading takes priority. This
actually makes the offload consistent with non-dynamic, because what
happens, is when non-dynamic fills ints to_load list, it will fill-up
any left-over pieces that could fix large weights with small weights
and load them, even though they were lower priority. This actually
improves performance because the timy weights dont cost any VRAM and
arent worth the control overhead of the DMA etc.
2026-03-04 07:47:44 -08:00
Arthur R Longbottom eb011733b6 Fix VideoFromComponents.save_to crash when writing to BytesIO (#12683)
* Fix VideoFromComponents.save_to crash when writing to BytesIO

When `get_container_format()` or `get_stream_source()` is called on a
tensor-based video (VideoFromComponents), it calls `save_to(BytesIO())`.
Since BytesIO has no file extension, `av.open` can't infer the output
format and throws `ValueError: Could not determine output format`.

The sibling class `VideoFromFile` already handles this correctly via
`get_open_write_kwargs()`, which detects BytesIO and sets the format
explicitly. `VideoFromComponents` just never got the same treatment.

This surfaces when any downstream node validates the container format
of a tensor-based video, like TopazVideoEnhance or any node that calls
`validate_container_format_is_mp4()`.

Three-line fix in `comfy_api/latest/_input_impl/video_types.py`.

* Add docstring to save_to to satisfy CI coverage check
2026-03-04 00:29:00 -05:00
rattus ac6513e142 DynamicVram: Add casting / fix torch Buffer weights (#12749)
* respect model dtype in non-comfy caster

* utils: factor out parent and name functionality of set_attr

* utils: implement set_attr_buffer for torch buffers

* ModelPatcherDynamic: Implement torch Buffer loading

If there is a buffer in dynamic - force load it.
2026-03-03 18:19:40 -08:00
Terry Jia b6ddc590ed CURVE type (#12581)
* CURVE type

* fix: update typed wrapper unwrap keys to __type__ and __value__

* code improve

* code improve
2026-03-03 16:58:53 -08:00
comfyanonymous f719a9d928 Adjust memory usage factor of zeta model. (#12746) 2026-03-03 17:35:22 -05:00
rattus 174fd6759d main: Load aimdo after logger is setup (#12743)
This was too early. Aimdo can use the logger in error paths and this
causes a rogue default init if aimdo has something to log.
2026-03-03 08:51:15 -08:00
rattus 09bcbddfcf ModelPatcherDynamic: Force load all non-comfy weights (#12739)
* model_management: Remove non-comfy dynamic _v caster

* Force pre-load non-comfy weights to GPU in ModelPatcherDynamic

Non-comfy weights may expect to be pre-cast to the target
device without in-model casting. Previously they were allocated in
the vbar with _v which required the _v fault path in cast_to.
Instead, back up the original CPU weight and move it directly to GPU
at load time.
2026-03-03 08:50:33 -08:00
xeinherjer dff0a4a158 Fix VAEDecodeAudioTiled ignoring tile_size input (#12735) (#12738) 2026-03-02 20:17:51 -05:00
Lodestone 9ebee0a217 Feat: z-image pixel space (model still training atm) (#12709)
* draft zeta (z-image pixel space)

* revert gitignore

* model loaded and able to run however vector direction still wrong tho

* flip the vector direction to original again this time

* Move wrongly positioned Z image pixel space class

* inherit Radiance LatentFormat class

* Fix parameters in classes for Zeta x0 dino

* remove arbitrary nn.init instances

* Remove unused import of lru_cache

---------

Co-authored-by: silveroxides <ishimarukaito@gmail.com>
2026-03-02 19:43:47 -05:00
comfyanonymous 57dd6c1aad Support loading zeta chroma weights properly. (#12734) 2026-03-02 18:54:18 -05:00
ComfyUI Wiki f1f8996e15 chore: update workflow templates to v0.9.5 (#12732) 2026-03-02 09:13:42 -08:00
Alexander Piskun afb54219fa feat(api-nodes): allow to use "IMAGE+TEXT" in NanoBanana2 (#12729) 2026-03-01 23:24:33 -08:00
rattus 7175c11a4e comfy aimdo 0.2.4 (#12727)
Comfy Aimdo 0.2.4 fixes a VRAM buffer alignment issue that happens in
someworkflows where action is able to bypass the pytorch allocator
and go straight to the cuda hook.
2026-03-01 22:21:41 -08:00
rattus dfbf99a061 model_mangament: make dynamic --disable-smart-memory work (#12724)
This was previously considering the pool of dynamic models as one giant
entity for the sake of smart memory, but that isnt really the useful
or what a user would reasonably expect. Make Dynamic VRAM properly purge
its models just like the old --disable-smart-memory but conditioning
the dynamic-for-dynamic bypass on smart memory.

Re-enable dynamic smart memory.
2026-03-01 19:18:56 -08:00
comfyanonymous 602f6bd82c Make --disable-smart-memory disable dynamic vram. (#12722) 2026-03-01 15:28:39 -05:00
rattus c0d472e5b9 comfy-aimdo 0.2.3 (#12720) 2026-03-01 11:14:56 -08:00
drozbay 4d79f4f028 fix: handle substep sigmas in context window set_step (#12719)
Multi-step samplers (eg. dpmpp_2s_ancestral) call the model at intermediate sigma values not present in the schedule. This caused set_step to crash with "No sample_sigmas matched current timestep" when context windows were enabled.

The fix is to keep self._step from the last exact match when a substep sigma is encountered, since substeps are still logically part of their parent step and should use the same context windows.

Co-authored-by: ozbayb <17261091+ozbayb@users.noreply.github.com>
2026-03-01 09:38:30 -08:00
Christian Byrne 850e8b42ff feat: add text preview support to jobs API (#12169)
* feat: add text preview support to jobs API

Amp-Thread-ID: https://ampcode.com/threads/T-019c0be0-9fc6-71ac-853a-7c7cc846b375
Co-authored-by: Amp <amp@ampcode.com>

* test: update tests to expect text as previewable media type

Amp-Thread-ID: https://ampcode.com/threads/T-019c0be0-9fc6-71ac-853a-7c7cc846b375

---------
2026-02-28 21:38:19 -08:00
Christian Byrne d159142615 refactor: rename Mahiro CFG to Similarity-Adaptive Guidance (#12172)
* refactor: rename Mahiro CFG to Similarity-Adaptive Guidance

Rename the display name to better describe what the node does:
adaptively blends guidance based on cosine similarity between
positive and negative conditions.

Amp-Thread-ID: https://ampcode.com/threads/T-019c0d36-8b43-745f-b7b2-e35b53f17fa1
Co-authored-by: Amp <amp@ampcode.com>

* feat: add search aliases for old mahiro name

Amp-Thread-ID: https://ampcode.com/threads/T-019c0d36-8b43-745f-b7b2-e35b53f17fa1

* rename: Similarity-Adaptive Guidance → Positive-Biased Guidance (per reviewer)

- display_name changed to 'Positive-Biased Guidance' to avoid SAG acronym collision
- search_aliases expanded: mahiro, mahiro cfg, similarity-adaptive guidance, positive-biased cfg
- ruff format applied

---------

Co-authored-by: Amp <amp@ampcode.com>
Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2026-02-28 20:59:24 -08:00
comfyanonymous 1080bd442a Disable dynamic vram on wsl. (#12706) 2026-02-28 22:23:28 -05:00
comfyanonymous 17106cb124 Move parsing of requirements logic to function. (#12701) 2026-02-28 22:21:32 -05:00
rattus 48bb0bd18a cli_args: Default comfy to DynamicVram mode (#12658) 2026-02-28 16:52:30 -05:00
rattus 5f41584e96 Disable dynamic_vram when weight hooks applied (#12653)
* sd: add support for clip model reconstruction

* nodes: SetClipHooks: Demote the dynamic model patcher

* mp: Make dynamic_disable more robust

The backup need to not be cloned. In addition add a delegate object
to ModelPatcherDynamic so that non-cloning code can do
ModelPatcherDynamic demotion

* sampler_helpers: Demote to non-dynamic model patcher when hooking

* code rabbit review comments
2026-02-28 16:50:18 -05:00
Jukka Seppänen 1f6744162f feat: Support SCAIL WanVideo model (#12614) 2026-02-28 16:49:12 -05:00
fappaz 95e1059661 fix(ace15): handle missing lm_metadata in memory estimation during checkpoint export #12669 (#12686) 2026-02-28 01:18:40 -05:00
Christian Byrne 80d49441e5 refactor: use AspectRatio enum members as ASPECT_RATIOS dict keys (#12689)
Amp-Thread-ID: https://ampcode.com/threads/T-019ca1cb-0150-7549-8b1b-6713060d3408

Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2026-02-27 20:53:46 -08:00
comfyanonymous 9d0e114ee3 PyOpenGL-accelerate is not necessary. (#12692) 2026-02-27 23:34:58 -05:00
Talmaj ac4412d0fa Native LongCat-Image implementation (#12597) 2026-02-27 23:04:34 -05:00
comfyanonymous 94f1a1cc9d Limit overlap in image tile and combine nodes to prevent issues. (#12688) 2026-02-27 20:16:24 -05:00
rattus e721e24136 ops: implement lora requanting for non QuantizedTensor fp8 (#12668)
Allow non QuantizedTensor layer to set want_requant to get the post lora
calculation stochastic cast down to the original input dtype.

This is then used by the legacy fp8 Linear implementation to set the
compute_dtype to the preferred lora dtype but then want_requant it back
down to fp8.

This fixes the issue with --fast fp8_matrix_mult is combined with
--fast dynamic_vram which doing a lora on an fp8_ non QT model.
2026-02-27 19:05:51 -05:00
Reiner "Tiles" Prokein 25ec3d96a3 Class WanVAE, def encode, feat_map is using self.decoder instead of self.encoder (#12682) 2026-02-27 19:03:45 -05:00
Christian Byrne 1f1ec377ce feat: add ResolutionSelector node for aspect ratio and megapixel-based resolution calculation (#12199)
Amp-Thread-ID: https://ampcode.com/threads/T-019c179e-cd8c-768f-ae66-207c7a53c01d

Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2026-02-27 09:13:57 -08:00
pythongosssss 0a7f8e11b6 fix torch.cat requiring inputs to all be same dimensions (#12673) 2026-02-27 08:13:24 -08:00
vickytsang 35e9fce775 Enable Pytorch Attention for gfx950 (#12641) 2026-02-26 20:16:12 -05:00
Jukka Seppänen c7f7d52b68 feat: Support SDPose-OOD (#12661) 2026-02-26 19:59:05 -05:00
rattus 08b26ed7c2 bug_report template: Push harder for logs (#12657)
We get a lot od bug reports without logs, especially for performance
issues.
2026-02-26 18:59:24 -05:00
fappaz b233dbe0bc feat(ace-step): add ACE-Step 1.5 lycoris key alias mapping for LoKR #12638 (#12665) 2026-02-26 18:19:19 -05:00
comfyanonymous 3811780e4f Portable with cu128 isn't useful anymore. (#12666)
Users should either use the cu126 one or the regular one (cu130 at the moment)

The cu128 portable is still included in the latest github release but I will stop including it as soon as it becomes slightly annoying to deal with. This might happen as soon as next week.
2026-02-26 17:12:29 -05:00
comfyanonymous 3dd10a59c0 ComfyUI v0.15.1
Python Linting / Run Ruff (push) Failing after 33s
Python Linting / Run Pylint (push) Failing after 32s
Build package / Build Test (3.10) (push) Failing after 36s
Build package / Build Test (3.11) (push) Failing after 38s
Build package / Build Test (3.12) (push) Failing after 36s
Build package / Build Test (3.13) (push) Failing after 37s
Build package / Build Test (3.14) (push) Failing after 33s
2026-02-26 15:59:22 -05:00
ComfyUI Wiki 88d05fe483 chore: update workflow templates to v0.9.4 (#12664) 2026-02-26 15:52:45 -05:00
Alexander Piskun fd41ec97cc feat(api-nodes): add NanoBanana2 (#12660) 2026-02-26 15:52:10 -05:00
rattus 420e900f69 main: load aimdo earlier (#12655)
Some custom node packs are naughty, and violate the
dont-load-torch-on-load rule. This causes aimdo to lose preference on
its allocator hook on linux.

Go super early on the aimdo first-stage init before custom nodes
are mentioned at all.
2026-02-26 15:19:38 -05:00
pythongosssss 38ca94599f pyopengl-accelerate can cause object to be numpy ints instead of bare ints, which the glDeleteTextures function does not accept, explicitly cast to int (#12650) 2026-02-26 03:07:35 -08:00
Christian Byrne 74b5a337dc fix: move essentials_category to correct replacement nodes (#12568)
Move essentials_category from deprecated/incorrect nodes to their replacements:
- ImageBatch → BatchImagesNode (ImageBatch is deprecated)
- Blur → removed (should use subgraph blueprint)
- GetVideoComponents → Video Slice

Amp-Thread-ID: https://ampcode.com/threads/T-019c8340-4da2-723b-a09f-83895c5bbda5
2026-02-26 01:00:32 -08:00
comfyanonymous 8a4d85c708 Cleanups to the last PR. (#12646) 2026-02-26 01:30:31 -05:00
Tavi Halperin a4522017c5 feat: per-guide attention strength control in self-attention (#12518)
Implements per-guide attention attenuation via log-space additive bias
in self-attention. Each guide reference tracks its own strength and
optional spatial mask in conditioning metadata (guide_attention_entries).
2026-02-26 01:25:23 -05:00
Jukka Seppänen 907e5dcbbf initial FlowRVS support (#12637) 2026-02-25 23:38:46 -05:00
comfyanonymous 7253531670 Fix ltxav te mem estimation. (#12643) 2026-02-25 23:13:47 -05:00
comfyanonymous e14b04478c Fix LTXAV text enc min length. (#12640)
Should have been 1024 instead of 512
2026-02-25 22:36:02 -05:00
Christian Byrne eb8737d675 Update requirements.txt (#12642) 2026-02-25 18:30:48 -08:00
rattus 0467f690a8 comfy aimdo 0.2.2 (#12635)
Comfy Aimdo 0.2.2 moves the cuda allocator hook from the cudart API to
the cuda driver API on windows. This is needed to handle Windows+cu13
where cudart is statically linked.
2026-02-25 16:50:05 -05:00
rattus 4f5b7dbf1f Fix Aimdo fallback on probe to not use zero-copy sft (#12634)
* utils: dont use comfy sft loader in aimdo fallback

This was going to the raw command line switch and should respect main.py
probe of whether aimdo actually loaded successfully.

* ops: dont use deferred linear load in Aimdo fallback

Avoid changes of behaviour on --fast dynamic_vram when aimdo doesnt work.
2026-02-25 16:49:48 -05:00
rattus 3ebe1ac22e Disable dynamic_vram when using torch compiler (#12612)
* mp: attach re-construction arguments to model patcher

When making a model-patcher from a unet or ckpt, attach a callable
function that can be called to replay the model construction. This
can be used to deep clone model patcher WRT the actual model.

Originally written by Kosinkadink
https://github.com/Comfy-Org/ComfyUI/commit/f4b99bc62389af315013dda85f24f2bbd262b686

* mp: Add disable_dynamic clone argument

Add a clone argument that lets a caller clone a ModelPatcher but disable
dynamic to demote the clone to regular MP. This is useful for legacy
features where dynamic_vram support is missing or TBD.

* torch_compile: disable dynamic_vram

This is a bigger feature. Disable for the interim to preserve
functionality.
2026-02-24 19:13:46 -05:00
rattus befa83d434 comfy aimdo 0.2.1 (#12620)
Changes:

throttle VRAM threshold checks to restore performance in high-layer-rate
conditions.
2026-02-24 16:02:26 -05:00
Jedrzej Kosinski 33f83d53ae Fix KeyError when prompt entries lack class_type key (#12595)
Skip entries in the prompt dict that don't contain a class_type key
in apply_replacements(), preventing crashes on metadata or non-node
entries.

Fixes Comfy-Org/ComfyUI#12517
2026-02-24 16:02:05 -05:00
69 changed files with 3937 additions and 467 deletions
+1 -1
View File
@@ -16,7 +16,7 @@ body:
## Very Important
Please make sure that you post ALL your ComfyUI logs in the bug report. A bug report without logs will likely be ignored.
Please make sure that you post ALL your ComfyUI logs in the bug report **even if there is no crash**. Just paste everything. The startup log (everything before "To see the GUI go to: ...") contains critical information to developers trying to help. For a performance issue or crash, paste everything from "got prompt" to the end, including the crash. More is better - always. A bug report without logs will likely be ignored.
- type: checkboxes
id: custom-nodes-test
attributes:
-2
View File
@@ -189,8 +189,6 @@ The portable above currently comes with python 3.13 and pytorch cuda 13.0. Updat
[Experimental portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z)
[Portable with pytorch cuda 12.8 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu128.7z).
[Portable with pytorch cuda 12.6 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu126.7z) (Supports Nvidia 10 series and older GPUs).
#### How do I share models between another UI and ComfyUI?
+3 -39
View File
@@ -17,7 +17,7 @@ from importlib.metadata import version
import requests
from typing_extensions import NotRequired
from utils.install_util import get_missing_requirements_message, requirements_path
from utils.install_util import get_missing_requirements_message, get_required_packages_versions
from comfy.cli_args import DEFAULT_VERSION_STRING
import app.logger
@@ -45,25 +45,7 @@ def get_installed_frontend_version():
def get_required_frontend_version():
"""Get the required frontend version from requirements.txt."""
try:
with open(requirements_path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if line.startswith("comfyui-frontend-package=="):
version_str = line.split("==")[-1]
if not is_valid_version(version_str):
logging.error(f"Invalid version format in requirements.txt: {version_str}")
return None
return version_str
logging.error("comfyui-frontend-package not found in requirements.txt")
return None
except FileNotFoundError:
logging.error("requirements.txt not found. Cannot determine required frontend version.")
return None
except Exception as e:
logging.error(f"Error reading requirements.txt: {e}")
return None
return get_required_packages_versions().get("comfyui-frontend-package", None)
def check_frontend_version():
@@ -217,25 +199,7 @@ class FrontendManager:
@classmethod
def get_required_templates_version(cls) -> str:
"""Get the required workflow templates version from requirements.txt."""
try:
with open(requirements_path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if line.startswith("comfyui-workflow-templates=="):
version_str = line.split("==")[-1]
if not is_valid_version(version_str):
logging.error(f"Invalid templates version format in requirements.txt: {version_str}")
return None
return version_str
logging.error("comfyui-workflow-templates not found in requirements.txt")
return None
except FileNotFoundError:
logging.error("requirements.txt not found. Cannot determine required templates version.")
return None
except Exception as e:
logging.error(f"Error reading requirements.txt: {e}")
return None
return get_required_packages_versions().get("comfyui-workflow-templates", None)
@classmethod
def default_frontend_path(cls) -> str:
+2
View File
@@ -46,6 +46,8 @@ class NodeReplaceManager:
connections: dict[str, list[tuple[str, str, int]]] = {}
need_replacement: set[str] = set()
for node_number, node_struct in prompt.items():
if "class_type" not in node_struct or "inputs" not in node_struct:
continue
class_type = node_struct["class_type"]
# need replacement if not in NODE_CLASS_MAPPINGS and has replacement
if class_type not in nodes.NODE_CLASS_MAPPINGS.keys() and self.has_replacement(class_type):
+2 -2
View File
@@ -146,6 +146,7 @@ parser.add_argument("--reserve-vram", type=float, default=None, help="Set the am
parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.")
parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.")
parser.add_argument("--disable-dynamic-vram", action="store_true", help="Disable dynamic VRAM and use estimate based model loading.")
parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")
@@ -159,7 +160,6 @@ class PerformanceFeature(enum.Enum):
Fp8MatrixMultiplication = "fp8_matrix_mult"
CublasOps = "cublas_ops"
AutoTune = "autotune"
DynamicVRAM = "dynamic_vram"
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
@@ -260,4 +260,4 @@ else:
args.fast = set(args.fast)
def enables_dynamic_vram():
return PerformanceFeature.DynamicVRAM in args.fast and not args.highvram and not args.gpu_only
return not args.disable_dynamic_vram and not args.highvram and not args.gpu_only and not args.novram and not args.cpu
+20 -1
View File
@@ -4,6 +4,25 @@ import comfy.utils
import logging
def is_equal(x, y):
if torch.is_tensor(x) and torch.is_tensor(y):
return torch.equal(x, y)
elif isinstance(x, dict) and isinstance(y, dict):
if x.keys() != y.keys():
return False
return all(is_equal(x[k], y[k]) for k in x)
elif isinstance(x, (list, tuple)) and isinstance(y, (list, tuple)):
if type(x) is not type(y) or len(x) != len(y):
return False
return all(is_equal(a, b) for a, b in zip(x, y))
else:
try:
return x == y
except Exception:
logging.warning("comparison issue with COND")
return False
class CONDRegular:
def __init__(self, cond):
self.cond = cond
@@ -84,7 +103,7 @@ class CONDConstant(CONDRegular):
return self._copy_with(self.cond)
def can_concat(self, other):
if self.cond != other.cond:
if not is_equal(self.cond, other.cond):
return False
return True
+1 -1
View File
@@ -214,7 +214,7 @@ class IndexListContextHandler(ContextHandlerABC):
mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep[0], rtol=0.0001)
matches = torch.nonzero(mask)
if torch.numel(matches) == 0:
raise Exception("No sample_sigmas matched current timestep; something went wrong.")
return # substep from multi-step sampler: keep self._step from the last full step
self._step = int(matches[0].item())
def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:
+7
View File
@@ -776,3 +776,10 @@ class ChromaRadiance(LatentFormat):
def process_out(self, latent):
return latent
class ZImagePixelSpace(ChromaRadiance):
"""Pixel-space latent format for ZImage DCT variant.
No VAE encoding/decoding — the model operates directly on RGB pixels.
"""
pass
+164 -30
View File
@@ -2,11 +2,16 @@ from typing import Tuple
import torch
import torch.nn as nn
from comfy.ldm.lightricks.model import (
ADALN_BASE_PARAMS_COUNT,
ADALN_CROSS_ATTN_PARAMS_COUNT,
CrossAttention,
FeedForward,
AdaLayerNormSingle,
PixArtAlphaTextProjection,
NormSingleLinearTextProjection,
LTXVModel,
apply_cross_attention_adaln,
compute_prompt_timestep,
)
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
@@ -87,6 +92,8 @@ class BasicAVTransformerBlock(nn.Module):
v_context_dim=None,
a_context_dim=None,
attn_precision=None,
apply_gated_attention=False,
cross_attention_adaln=False,
dtype=None,
device=None,
operations=None,
@@ -94,6 +101,7 @@ class BasicAVTransformerBlock(nn.Module):
super().__init__()
self.attn_precision = attn_precision
self.cross_attention_adaln = cross_attention_adaln
self.attn1 = CrossAttention(
query_dim=v_dim,
@@ -101,6 +109,7 @@ class BasicAVTransformerBlock(nn.Module):
dim_head=vd_head,
context_dim=None,
attn_precision=self.attn_precision,
apply_gated_attention=apply_gated_attention,
dtype=dtype,
device=device,
operations=operations,
@@ -111,6 +120,7 @@ class BasicAVTransformerBlock(nn.Module):
dim_head=ad_head,
context_dim=None,
attn_precision=self.attn_precision,
apply_gated_attention=apply_gated_attention,
dtype=dtype,
device=device,
operations=operations,
@@ -122,6 +132,7 @@ class BasicAVTransformerBlock(nn.Module):
heads=v_heads,
dim_head=vd_head,
attn_precision=self.attn_precision,
apply_gated_attention=apply_gated_attention,
dtype=dtype,
device=device,
operations=operations,
@@ -132,6 +143,7 @@ class BasicAVTransformerBlock(nn.Module):
heads=a_heads,
dim_head=ad_head,
attn_precision=self.attn_precision,
apply_gated_attention=apply_gated_attention,
dtype=dtype,
device=device,
operations=operations,
@@ -144,6 +156,7 @@ class BasicAVTransformerBlock(nn.Module):
heads=a_heads,
dim_head=ad_head,
attn_precision=self.attn_precision,
apply_gated_attention=apply_gated_attention,
dtype=dtype,
device=device,
operations=operations,
@@ -156,6 +169,7 @@ class BasicAVTransformerBlock(nn.Module):
heads=a_heads,
dim_head=ad_head,
attn_precision=self.attn_precision,
apply_gated_attention=apply_gated_attention,
dtype=dtype,
device=device,
operations=operations,
@@ -168,11 +182,16 @@ class BasicAVTransformerBlock(nn.Module):
a_dim, dim_out=a_dim, glu=True, dtype=dtype, device=device, operations=operations
)
self.scale_shift_table = nn.Parameter(torch.empty(6, v_dim, device=device, dtype=dtype))
num_ada_params = ADALN_CROSS_ATTN_PARAMS_COUNT if cross_attention_adaln else ADALN_BASE_PARAMS_COUNT
self.scale_shift_table = nn.Parameter(torch.empty(num_ada_params, v_dim, device=device, dtype=dtype))
self.audio_scale_shift_table = nn.Parameter(
torch.empty(6, a_dim, device=device, dtype=dtype)
torch.empty(num_ada_params, a_dim, device=device, dtype=dtype)
)
if cross_attention_adaln:
self.prompt_scale_shift_table = nn.Parameter(torch.empty(2, v_dim, device=device, dtype=dtype))
self.audio_prompt_scale_shift_table = nn.Parameter(torch.empty(2, a_dim, device=device, dtype=dtype))
self.scale_shift_table_a2v_ca_audio = nn.Parameter(
torch.empty(5, a_dim, device=device, dtype=dtype)
)
@@ -215,10 +234,30 @@ class BasicAVTransformerBlock(nn.Module):
return (*scale_shift_ada_values, *gate_ada_values)
def _apply_text_cross_attention(
self, x, context, attn, scale_shift_table, prompt_scale_shift_table,
timestep, prompt_timestep, attention_mask, transformer_options,
):
"""Apply text cross-attention, with optional ADaLN modulation."""
if self.cross_attention_adaln:
shift_q, scale_q, gate = self.get_ada_values(
scale_shift_table, x.shape[0], timestep, slice(6, 9)
)
return apply_cross_attention_adaln(
x, context, attn, shift_q, scale_q, gate,
prompt_scale_shift_table, prompt_timestep,
attention_mask, transformer_options,
)
return attn(
comfy.ldm.common_dit.rms_norm(x), context=context,
mask=attention_mask, transformer_options=transformer_options,
)
def forward(
self, x: Tuple[torch.Tensor, torch.Tensor], v_context=None, a_context=None, attention_mask=None, v_timestep=None, a_timestep=None,
v_pe=None, a_pe=None, v_cross_pe=None, a_cross_pe=None, v_cross_scale_shift_timestep=None, a_cross_scale_shift_timestep=None,
v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None,
v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None, self_attention_mask=None,
v_prompt_timestep=None, a_prompt_timestep=None,
) -> Tuple[torch.Tensor, torch.Tensor]:
run_vx = transformer_options.get("run_vx", True)
run_ax = transformer_options.get("run_ax", True)
@@ -234,13 +273,17 @@ class BasicAVTransformerBlock(nn.Module):
vshift_msa, vscale_msa = (self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 2)))
norm_vx = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_msa) + vshift_msa
del vshift_msa, vscale_msa
attn1_out = self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options)
attn1_out = self.attn1(norm_vx, pe=v_pe, mask=self_attention_mask, transformer_options=transformer_options)
del norm_vx
# video cross-attention
vgate_msa = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(2, 3))[0]
vx.addcmul_(attn1_out, vgate_msa)
del vgate_msa, attn1_out
vx.add_(self.attn2(comfy.ldm.common_dit.rms_norm(vx), context=v_context, mask=attention_mask, transformer_options=transformer_options))
vx.add_(self._apply_text_cross_attention(
vx, v_context, self.attn2, self.scale_shift_table,
getattr(self, 'prompt_scale_shift_table', None),
v_timestep, v_prompt_timestep, attention_mask, transformer_options,)
)
# audio
if run_ax:
@@ -254,7 +297,11 @@ class BasicAVTransformerBlock(nn.Module):
agate_msa = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(2, 3))[0]
ax.addcmul_(attn1_out, agate_msa)
del agate_msa, attn1_out
ax.add_(self.audio_attn2(comfy.ldm.common_dit.rms_norm(ax), context=a_context, mask=attention_mask, transformer_options=transformer_options))
ax.add_(self._apply_text_cross_attention(
ax, a_context, self.audio_attn2, self.audio_scale_shift_table,
getattr(self, 'audio_prompt_scale_shift_table', None),
a_timestep, a_prompt_timestep, attention_mask, transformer_options,)
)
# video - audio cross attention.
if run_a2v or run_v2a:
@@ -351,6 +398,9 @@ class LTXAVModel(LTXVModel):
use_middle_indices_grid=False,
timestep_scale_multiplier=1000.0,
av_ca_timestep_scale_multiplier=1.0,
apply_gated_attention=False,
caption_proj_before_connector=False,
cross_attention_adaln=False,
dtype=None,
device=None,
operations=None,
@@ -362,6 +412,7 @@ class LTXAVModel(LTXVModel):
self.audio_attention_head_dim = audio_attention_head_dim
self.audio_num_attention_heads = audio_num_attention_heads
self.audio_positional_embedding_max_pos = audio_positional_embedding_max_pos
self.apply_gated_attention = apply_gated_attention
# Calculate audio dimensions
self.audio_inner_dim = audio_num_attention_heads * audio_attention_head_dim
@@ -386,6 +437,8 @@ class LTXAVModel(LTXVModel):
vae_scale_factors=vae_scale_factors,
use_middle_indices_grid=use_middle_indices_grid,
timestep_scale_multiplier=timestep_scale_multiplier,
caption_proj_before_connector=caption_proj_before_connector,
cross_attention_adaln=cross_attention_adaln,
dtype=dtype,
device=device,
operations=operations,
@@ -400,14 +453,28 @@ class LTXAVModel(LTXVModel):
)
# Audio-specific AdaLN
audio_embedding_coefficient = ADALN_CROSS_ATTN_PARAMS_COUNT if self.cross_attention_adaln else ADALN_BASE_PARAMS_COUNT
self.audio_adaln_single = AdaLayerNormSingle(
self.audio_inner_dim,
embedding_coefficient=audio_embedding_coefficient,
use_additional_conditions=False,
dtype=dtype,
device=device,
operations=self.operations,
)
if self.cross_attention_adaln:
self.audio_prompt_adaln_single = AdaLayerNormSingle(
self.audio_inner_dim,
embedding_coefficient=2,
use_additional_conditions=False,
dtype=dtype,
device=device,
operations=self.operations,
)
else:
self.audio_prompt_adaln_single = None
num_scale_shift_values = 4
self.av_ca_video_scale_shift_adaln_single = AdaLayerNormSingle(
self.inner_dim,
@@ -443,35 +510,73 @@ class LTXAVModel(LTXVModel):
)
# Audio caption projection
self.audio_caption_projection = PixArtAlphaTextProjection(
in_features=self.caption_channels,
hidden_size=self.audio_inner_dim,
dtype=dtype,
device=device,
operations=self.operations,
)
if self.caption_proj_before_connector:
if self.caption_projection_first_linear:
self.audio_caption_projection = NormSingleLinearTextProjection(
in_features=self.caption_channels,
hidden_size=self.audio_inner_dim,
dtype=dtype,
device=device,
operations=self.operations,
)
else:
self.audio_caption_projection = lambda a: a
else:
self.audio_caption_projection = PixArtAlphaTextProjection(
in_features=self.caption_channels,
hidden_size=self.audio_inner_dim,
dtype=dtype,
device=device,
operations=self.operations,
)
connector_split_rope = kwargs.get("rope_type", "split") == "split"
connector_gated_attention = kwargs.get("connector_apply_gated_attention", False)
attention_head_dim = kwargs.get("connector_attention_head_dim", 128)
num_attention_heads = kwargs.get("connector_num_attention_heads", 30)
num_layers = kwargs.get("connector_num_layers", 2)
self.audio_embeddings_connector = Embeddings1DConnector(
split_rope=True,
attention_head_dim=kwargs.get("audio_connector_attention_head_dim", attention_head_dim),
num_attention_heads=kwargs.get("audio_connector_num_attention_heads", num_attention_heads),
num_layers=num_layers,
split_rope=connector_split_rope,
double_precision_rope=True,
apply_gated_attention=connector_gated_attention,
dtype=dtype,
device=device,
operations=self.operations,
)
self.video_embeddings_connector = Embeddings1DConnector(
split_rope=True,
attention_head_dim=attention_head_dim,
num_attention_heads=num_attention_heads,
num_layers=num_layers,
split_rope=connector_split_rope,
double_precision_rope=True,
apply_gated_attention=connector_gated_attention,
dtype=dtype,
device=device,
operations=self.operations,
)
def preprocess_text_embeds(self, context):
if context.shape[-1] == self.caption_channels * 2:
return context
out_vid = self.video_embeddings_connector(context)[0]
out_audio = self.audio_embeddings_connector(context)[0]
def preprocess_text_embeds(self, context, unprocessed=False):
# LTXv2 fully processed context has dimension of self.caption_channels * 2
# LTXv2.3 fully processed context has dimension of self.cross_attention_dim + self.audio_cross_attention_dim
if not unprocessed:
if context.shape[-1] in (self.cross_attention_dim + self.audio_cross_attention_dim, self.caption_channels * 2):
return context
if context.shape[-1] == self.cross_attention_dim + self.audio_cross_attention_dim:
context_vid = context[:, :, :self.cross_attention_dim]
context_audio = context[:, :, self.cross_attention_dim:]
else:
context_vid = context
context_audio = context
if self.caption_proj_before_connector:
context_vid = self.caption_projection(context_vid)
context_audio = self.audio_caption_projection(context_audio)
out_vid = self.video_embeddings_connector(context_vid)[0]
out_audio = self.audio_embeddings_connector(context_audio)[0]
return torch.concat((out_vid, out_audio), dim=-1)
def _init_transformer_blocks(self, device, dtype, **kwargs):
@@ -487,6 +592,8 @@ class LTXAVModel(LTXVModel):
ad_head=self.audio_attention_head_dim,
v_context_dim=self.cross_attention_dim,
a_context_dim=self.audio_cross_attention_dim,
apply_gated_attention=self.apply_gated_attention,
cross_attention_adaln=self.cross_attention_adaln,
dtype=dtype,
device=device,
operations=self.operations,
@@ -608,6 +715,10 @@ class LTXAVModel(LTXVModel):
v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame)
v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame)
v_prompt_timestep = compute_prompt_timestep(
self.prompt_adaln_single, timestep_scaled, batch_size, hidden_dtype
)
# Prepare audio timestep
a_timestep = kwargs.get("a_timestep")
if a_timestep is not None:
@@ -618,25 +729,25 @@ class LTXAVModel(LTXVModel):
# Cross-attention timesteps - compress these too
av_ca_audio_scale_shift_timestep, _ = self.av_ca_audio_scale_shift_adaln_single(
a_timestep_flat,
timestep.max().expand_as(a_timestep_flat),
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
av_ca_video_scale_shift_timestep, _ = self.av_ca_video_scale_shift_adaln_single(
timestep_flat,
a_timestep.max().expand_as(timestep_flat),
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
av_ca_a2v_gate_noise_timestep, _ = self.av_ca_a2v_gate_adaln_single(
timestep_flat * av_ca_factor,
a_timestep.max().expand_as(timestep_flat) * av_ca_factor,
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
av_ca_v2a_gate_noise_timestep, _ = self.av_ca_v2a_gate_adaln_single(
a_timestep_flat * av_ca_factor,
timestep.max().expand_as(a_timestep_flat) * av_ca_factor,
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
@@ -660,29 +771,40 @@ class LTXAVModel(LTXVModel):
# Audio timesteps
a_timestep = a_timestep.view(batch_size, -1, a_timestep.shape[-1])
a_embedded_timestep = a_embedded_timestep.view(batch_size, -1, a_embedded_timestep.shape[-1])
a_prompt_timestep = compute_prompt_timestep(
self.audio_prompt_adaln_single, a_timestep_scaled, batch_size, hidden_dtype
)
else:
a_timestep = timestep_scaled
a_embedded_timestep = kwargs.get("embedded_timestep")
cross_av_timestep_ss = []
a_prompt_timestep = None
return [v_timestep, a_timestep, cross_av_timestep_ss], [
return [v_timestep, a_timestep, cross_av_timestep_ss, v_prompt_timestep, a_prompt_timestep], [
v_embedded_timestep,
a_embedded_timestep,
]
], None
def _prepare_context(self, context, batch_size, x, attention_mask=None):
vx = x[0]
ax = x[1]
video_dim = vx.shape[-1]
audio_dim = ax.shape[-1]
v_context_dim = self.caption_channels if self.caption_proj_before_connector is False else video_dim
a_context_dim = self.caption_channels if self.caption_proj_before_connector is False else audio_dim
v_context, a_context = torch.split(
context, int(context.shape[-1] / 2), len(context.shape) - 1
context, [v_context_dim, a_context_dim], len(context.shape) - 1
)
v_context, attention_mask = super()._prepare_context(
v_context, batch_size, vx, attention_mask
)
if self.audio_caption_projection is not None:
if self.caption_proj_before_connector is False:
a_context = self.audio_caption_projection(a_context)
a_context = a_context.view(batch_size, -1, ax.shape[-1])
a_context = a_context.view(batch_size, -1, audio_dim)
return [v_context, a_context], attention_mask
@@ -726,7 +848,7 @@ class LTXAVModel(LTXVModel):
return [(v_pe, av_cross_video_freq_cis), (a_pe, av_cross_audio_freq_cis)]
def _process_transformer_blocks(
self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs
self, x, context, attention_mask, timestep, pe, transformer_options={}, self_attention_mask=None, **kwargs
):
vx = x[0]
ax = x[1]
@@ -744,6 +866,9 @@ class LTXAVModel(LTXVModel):
av_ca_v2a_gate_noise_timestep,
) = timestep[2]
v_prompt_timestep = timestep[3]
a_prompt_timestep = timestep[4]
"""Process transformer blocks for LTXAV."""
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
@@ -770,6 +895,9 @@ class LTXAVModel(LTXVModel):
v_cross_gate_timestep=args["v_cross_gate_timestep"],
a_cross_gate_timestep=args["a_cross_gate_timestep"],
transformer_options=args["transformer_options"],
self_attention_mask=args.get("self_attention_mask"),
v_prompt_timestep=args.get("v_prompt_timestep"),
a_prompt_timestep=args.get("a_prompt_timestep"),
)
return out
@@ -790,6 +918,9 @@ class LTXAVModel(LTXVModel):
"v_cross_gate_timestep": av_ca_a2v_gate_noise_timestep,
"a_cross_gate_timestep": av_ca_v2a_gate_noise_timestep,
"transformer_options": transformer_options,
"self_attention_mask": self_attention_mask,
"v_prompt_timestep": v_prompt_timestep,
"a_prompt_timestep": a_prompt_timestep,
},
{"original_block": block_wrap},
)
@@ -811,6 +942,9 @@ class LTXAVModel(LTXVModel):
v_cross_gate_timestep=av_ca_a2v_gate_noise_timestep,
a_cross_gate_timestep=av_ca_v2a_gate_noise_timestep,
transformer_options=transformer_options,
self_attention_mask=self_attention_mask,
v_prompt_timestep=v_prompt_timestep,
a_prompt_timestep=a_prompt_timestep,
)
return [vx, ax]
@@ -50,6 +50,7 @@ class BasicTransformerBlock1D(nn.Module):
d_head,
context_dim=None,
attn_precision=None,
apply_gated_attention=False,
dtype=None,
device=None,
operations=None,
@@ -63,6 +64,7 @@ class BasicTransformerBlock1D(nn.Module):
heads=n_heads,
dim_head=d_head,
context_dim=None,
apply_gated_attention=apply_gated_attention,
dtype=dtype,
device=device,
operations=operations,
@@ -121,6 +123,7 @@ class Embeddings1DConnector(nn.Module):
positional_embedding_max_pos=[4096],
causal_temporal_positioning=False,
num_learnable_registers: Optional[int] = 128,
apply_gated_attention=False,
dtype=None,
device=None,
operations=None,
@@ -145,6 +148,7 @@ class Embeddings1DConnector(nn.Module):
num_attention_heads,
attention_head_dim,
context_dim=cross_attention_dim,
apply_gated_attention=apply_gated_attention,
dtype=dtype,
device=device,
operations=operations,
+411 -31
View File
@@ -1,6 +1,7 @@
from abc import ABC, abstractmethod
from enum import Enum
import functools
import logging
import math
from typing import Dict, Optional, Tuple
@@ -14,6 +15,8 @@ import comfy.ldm.common_dit
from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
logger = logging.getLogger(__name__)
def _log_base(x, base):
return np.log(x) / np.log(base)
@@ -272,6 +275,30 @@ class PixArtAlphaTextProjection(nn.Module):
return hidden_states
class NormSingleLinearTextProjection(nn.Module):
"""Text projection for 20B models - single linear with RMSNorm (no activation)."""
def __init__(
self, in_features, hidden_size, dtype=None, device=None, operations=None
):
super().__init__()
if operations is None:
operations = comfy.ops.disable_weight_init
self.in_norm = operations.RMSNorm(
in_features, eps=1e-6, elementwise_affine=False
)
self.linear_1 = operations.Linear(
in_features, hidden_size, bias=True, dtype=dtype, device=device
)
self.hidden_size = hidden_size
self.in_features = in_features
def forward(self, caption):
caption = self.in_norm(caption)
caption = caption * (self.hidden_size / self.in_features) ** 0.5
return self.linear_1(caption)
class GELU_approx(nn.Module):
def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=None):
super().__init__()
@@ -340,6 +367,7 @@ class CrossAttention(nn.Module):
dim_head=64,
dropout=0.0,
attn_precision=None,
apply_gated_attention=False,
dtype=None,
device=None,
operations=None,
@@ -359,6 +387,12 @@ class CrossAttention(nn.Module):
self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
self.to_v = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
# Optional per-head gating
if apply_gated_attention:
self.to_gate_logits = operations.Linear(query_dim, heads, bias=True, dtype=dtype, device=device)
else:
self.to_gate_logits = None
self.to_out = nn.Sequential(
operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout)
)
@@ -380,16 +414,30 @@ class CrossAttention(nn.Module):
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
else:
out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
# Apply per-head gating if enabled
if self.to_gate_logits is not None:
gate_logits = self.to_gate_logits(x) # (B, T, H)
b, t, _ = out.shape
out = out.view(b, t, self.heads, self.dim_head)
gates = 2.0 * torch.sigmoid(gate_logits) # zero-init -> identity
out = out * gates.unsqueeze(-1)
out = out.view(b, t, self.heads * self.dim_head)
return self.to_out(out)
# 6 base ADaLN params (shift/scale/gate for MSA + MLP), +3 for cross-attention Q (shift/scale/gate)
ADALN_BASE_PARAMS_COUNT = 6
ADALN_CROSS_ATTN_PARAMS_COUNT = 9
class BasicTransformerBlock(nn.Module):
def __init__(
self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None
self, dim, n_heads, d_head, context_dim=None, attn_precision=None, cross_attention_adaln=False, dtype=None, device=None, operations=None
):
super().__init__()
self.attn_precision = attn_precision
self.cross_attention_adaln = cross_attention_adaln
self.attn1 = CrossAttention(
query_dim=dim,
heads=n_heads,
@@ -413,18 +461,25 @@ class BasicTransformerBlock(nn.Module):
operations=operations,
)
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
num_ada_params = ADALN_CROSS_ATTN_PARAMS_COUNT if cross_attention_adaln else ADALN_BASE_PARAMS_COUNT
self.scale_shift_table = nn.Parameter(torch.empty(num_ada_params, dim, device=device, dtype=dtype))
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
if cross_attention_adaln:
self.prompt_scale_shift_table = nn.Parameter(torch.empty(2, dim, device=device, dtype=dtype))
attn1_input = comfy.ldm.common_dit.rms_norm(x)
attn1_input = torch.addcmul(attn1_input, attn1_input, scale_msa).add_(shift_msa)
attn1_input = self.attn1(attn1_input, pe=pe, transformer_options=transformer_options)
x.addcmul_(attn1_input, gate_msa)
del attn1_input
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}, self_attention_mask=None, prompt_timestep=None):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None, :6].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)[:, :, :6, :]).unbind(dim=2)
x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)
x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe, mask=self_attention_mask, transformer_options=transformer_options) * gate_msa
if self.cross_attention_adaln:
shift_q_mca, scale_q_mca, gate_mca = (self.scale_shift_table[None, None, 6:9].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)[:, :, 6:9, :]).unbind(dim=2)
x += apply_cross_attention_adaln(
x, context, self.attn2, shift_q_mca, scale_q_mca, gate_mca,
self.prompt_scale_shift_table, prompt_timestep, attention_mask, transformer_options,
)
else:
x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)
y = comfy.ldm.common_dit.rms_norm(x)
y = torch.addcmul(y, y, scale_mlp).add_(shift_mlp)
@@ -432,6 +487,47 @@ class BasicTransformerBlock(nn.Module):
return x
def compute_prompt_timestep(adaln_module, timestep_scaled, batch_size, hidden_dtype):
"""Compute a single global prompt timestep for cross-attention ADaLN.
Uses the max across tokens (matching JAX max_per_segment) and broadcasts
over text tokens. Returns None when *adaln_module* is None.
"""
if adaln_module is None:
return None
ts_input = (
timestep_scaled.max(dim=1, keepdim=True).values.flatten()
if timestep_scaled.dim() > 1
else timestep_scaled.flatten()
)
prompt_ts, _ = adaln_module(
ts_input,
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
return prompt_ts.view(batch_size, 1, prompt_ts.shape[-1])
def apply_cross_attention_adaln(
x, context, attn, q_shift, q_scale, q_gate,
prompt_scale_shift_table, prompt_timestep,
attention_mask=None, transformer_options={},
):
"""Apply cross-attention with ADaLN modulation (shift/scale/gate on Q and KV).
Q params (q_shift, q_scale, q_gate) are pre-extracted by the caller so
that both regular tensors and CompressedTimestep are supported.
"""
batch_size = x.shape[0]
shift_kv, scale_kv = (
prompt_scale_shift_table[None, None].to(device=x.device, dtype=x.dtype)
+ prompt_timestep.reshape(batch_size, prompt_timestep.shape[1], 2, -1)
).unbind(dim=2)
attn_input = comfy.ldm.common_dit.rms_norm(x) * (1 + q_scale) + q_shift
encoder_hidden_states = context * (1 + scale_kv) + shift_kv
return attn(attn_input, context=encoder_hidden_states, mask=attention_mask, transformer_options=transformer_options) * q_gate
def get_fractional_positions(indices_grid, max_pos):
n_pos_dims = indices_grid.shape[1]
assert n_pos_dims == len(max_pos), f'Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})'
@@ -553,6 +649,9 @@ class LTXBaseModel(torch.nn.Module, ABC):
vae_scale_factors: tuple = (8, 32, 32),
use_middle_indices_grid=False,
timestep_scale_multiplier = 1000.0,
caption_proj_before_connector=False,
cross_attention_adaln=False,
caption_projection_first_linear=True,
dtype=None,
device=None,
operations=None,
@@ -579,6 +678,9 @@ class LTXBaseModel(torch.nn.Module, ABC):
self.causal_temporal_positioning = causal_temporal_positioning
self.operations = operations
self.timestep_scale_multiplier = timestep_scale_multiplier
self.caption_proj_before_connector = caption_proj_before_connector
self.cross_attention_adaln = cross_attention_adaln
self.caption_projection_first_linear = caption_projection_first_linear
# Common dimensions
self.inner_dim = num_attention_heads * attention_head_dim
@@ -606,17 +708,37 @@ class LTXBaseModel(torch.nn.Module, ABC):
self.in_channels, self.inner_dim, bias=True, dtype=dtype, device=device
)
embedding_coefficient = ADALN_CROSS_ATTN_PARAMS_COUNT if self.cross_attention_adaln else ADALN_BASE_PARAMS_COUNT
self.adaln_single = AdaLayerNormSingle(
self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations
self.inner_dim, embedding_coefficient=embedding_coefficient, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations
)
self.caption_projection = PixArtAlphaTextProjection(
in_features=self.caption_channels,
hidden_size=self.inner_dim,
dtype=dtype,
device=device,
operations=self.operations,
)
if self.cross_attention_adaln:
self.prompt_adaln_single = AdaLayerNormSingle(
self.inner_dim, embedding_coefficient=2, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations
)
else:
self.prompt_adaln_single = None
if self.caption_proj_before_connector:
if self.caption_projection_first_linear:
self.caption_projection = NormSingleLinearTextProjection(
in_features=self.caption_channels,
hidden_size=self.inner_dim,
dtype=dtype,
device=device,
operations=self.operations,
)
else:
self.caption_projection = lambda a: a
else:
self.caption_projection = PixArtAlphaTextProjection(
in_features=self.caption_channels,
hidden_size=self.inner_dim,
dtype=dtype,
device=device,
operations=self.operations,
)
@abstractmethod
def _init_model_components(self, device, dtype, **kwargs):
@@ -638,8 +760,16 @@ class LTXBaseModel(torch.nn.Module, ABC):
"""Process input data. Must be implemented by subclasses."""
pass
def _build_guide_self_attention_mask(self, x, transformer_options, merged_args):
"""Build self-attention mask for per-guide attention attenuation.
Base implementation returns None (no attenuation). Subclasses that
support guide-based attention control should override this.
"""
return None
@abstractmethod
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, **kwargs):
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, self_attention_mask=None, **kwargs):
"""Process transformer blocks. Must be implemented by subclasses."""
pass
@@ -654,9 +784,9 @@ class LTXBaseModel(torch.nn.Module, ABC):
if grid_mask is not None:
timestep = timestep[:, grid_mask]
timestep = timestep * self.timestep_scale_multiplier
timestep_scaled = timestep * self.timestep_scale_multiplier
timestep, embedded_timestep = self.adaln_single(
timestep.flatten(),
timestep_scaled.flatten(),
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
@@ -666,14 +796,18 @@ class LTXBaseModel(torch.nn.Module, ABC):
timestep = timestep.view(batch_size, -1, timestep.shape[-1])
embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1])
return timestep, embedded_timestep
prompt_timestep = compute_prompt_timestep(
self.prompt_adaln_single, timestep_scaled, batch_size, hidden_dtype
)
return timestep, embedded_timestep, prompt_timestep
def _prepare_context(self, context, batch_size, x, attention_mask=None):
"""Prepare context for transformer blocks."""
if self.caption_projection is not None:
if self.caption_proj_before_connector is False:
context = self.caption_projection(context)
context = context.view(batch_size, -1, x.shape[-1])
context = context.view(batch_size, -1, x.shape[-1])
return context, attention_mask
def _precompute_freqs_cis(
@@ -781,16 +915,25 @@ class LTXBaseModel(torch.nn.Module, ABC):
merged_args.update(additional_args)
# Prepare timestep and context
timestep, embedded_timestep = self._prepare_timestep(timestep, batch_size, input_dtype, **merged_args)
timestep, embedded_timestep, prompt_timestep = self._prepare_timestep(timestep, batch_size, input_dtype, **merged_args)
merged_args["prompt_timestep"] = prompt_timestep
context, attention_mask = self._prepare_context(context, batch_size, x, attention_mask)
# Prepare attention mask and positional embeddings
attention_mask = self._prepare_attention_mask(attention_mask, input_dtype)
pe = self._prepare_positional_embeddings(pixel_coords, frame_rate, input_dtype)
# Build self-attention mask for per-guide attenuation
self_attention_mask = self._build_guide_self_attention_mask(
x, transformer_options, merged_args
)
# Process transformer blocks
x = self._process_transformer_blocks(
x, context, attention_mask, timestep, pe, transformer_options=transformer_options, **merged_args
x, context, attention_mask, timestep, pe,
transformer_options=transformer_options,
self_attention_mask=self_attention_mask,
**merged_args,
)
# Process output
@@ -814,7 +957,9 @@ class LTXVModel(LTXBaseModel):
causal_temporal_positioning=False,
vae_scale_factors=(8, 32, 32),
use_middle_indices_grid=False,
timestep_scale_multiplier = 1000.0,
timestep_scale_multiplier=1000.0,
caption_proj_before_connector=False,
cross_attention_adaln=False,
dtype=None,
device=None,
operations=None,
@@ -833,6 +978,8 @@ class LTXVModel(LTXBaseModel):
vae_scale_factors=vae_scale_factors,
use_middle_indices_grid=use_middle_indices_grid,
timestep_scale_multiplier=timestep_scale_multiplier,
caption_proj_before_connector=caption_proj_before_connector,
cross_attention_adaln=cross_attention_adaln,
dtype=dtype,
device=device,
operations=operations,
@@ -841,7 +988,6 @@ class LTXVModel(LTXBaseModel):
def _init_model_components(self, device, dtype, **kwargs):
"""Initialize LTXV-specific components."""
# No additional components needed for LTXV beyond base class
pass
def _init_transformer_blocks(self, device, dtype, **kwargs):
@@ -853,6 +999,7 @@ class LTXVModel(LTXBaseModel):
self.num_attention_heads,
self.attention_head_dim,
context_dim=self.cross_attention_dim,
cross_attention_adaln=self.cross_attention_adaln,
dtype=dtype,
device=device,
operations=self.operations,
@@ -890,26 +1037,257 @@ class LTXVModel(LTXBaseModel):
pixel_coords = pixel_coords[:, :, grid_mask, ...]
kf_grid_mask = grid_mask[-keyframe_idxs.shape[2]:]
# Compute per-guide surviving token counts from guide_attention_entries.
# Each entry tracks one guide reference; they are appended in order and
# their pre_filter_counts partition the kf_grid_mask.
guide_entries = kwargs.get("guide_attention_entries", None)
if guide_entries:
total_pfc = sum(e["pre_filter_count"] for e in guide_entries)
if total_pfc != len(kf_grid_mask):
raise ValueError(
f"guide pre_filter_counts ({total_pfc}) != "
f"keyframe grid mask length ({len(kf_grid_mask)})"
)
resolved_entries = []
offset = 0
for entry in guide_entries:
pfc = entry["pre_filter_count"]
entry_mask = kf_grid_mask[offset:offset + pfc]
surviving = int(entry_mask.sum().item())
resolved_entries.append({
**entry,
"surviving_count": surviving,
})
offset += pfc
additional_args["resolved_guide_entries"] = resolved_entries
keyframe_idxs = keyframe_idxs[..., kf_grid_mask, :]
pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs
# Total surviving guide tokens (all guides)
additional_args["num_guide_tokens"] = keyframe_idxs.shape[2]
x = self.patchify_proj(x)
return x, pixel_coords, additional_args
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs):
def _build_guide_self_attention_mask(self, x, transformer_options, merged_args):
"""Build self-attention mask for per-guide attention attenuation.
Reads resolved_guide_entries from merged_args (computed in _process_input)
to build a log-space additive bias mask that attenuates noisy ↔ guide
attention for each guide reference independently.
Returns None if no attenuation is needed (all strengths == 1.0 and no
spatial masks, or no guide tokens).
"""
if isinstance(x, list):
# AV model: x = [vx, ax]; use vx for token count and device
total_tokens = x[0].shape[1]
device = x[0].device
dtype = x[0].dtype
else:
total_tokens = x.shape[1]
device = x.device
dtype = x.dtype
num_guide_tokens = merged_args.get("num_guide_tokens", 0)
if num_guide_tokens == 0:
return None
resolved_entries = merged_args.get("resolved_guide_entries", None)
if not resolved_entries:
return None
# Check if any attenuation is actually needed
needs_attenuation = any(
e["strength"] < 1.0 or e.get("pixel_mask") is not None
for e in resolved_entries
)
if not needs_attenuation:
return None
# Build per-guide-token weights for all tracked guide tokens.
# Guides are appended in order at the end of the sequence.
guide_start = total_tokens - num_guide_tokens
all_weights = []
total_tracked = 0
for entry in resolved_entries:
surviving = entry["surviving_count"]
if surviving == 0:
continue
strength = entry["strength"]
pixel_mask = entry.get("pixel_mask")
latent_shape = entry.get("latent_shape")
if pixel_mask is not None and latent_shape is not None:
f_lat, h_lat, w_lat = latent_shape
per_token = self._downsample_mask_to_latent(
pixel_mask.to(device=device, dtype=dtype),
f_lat, h_lat, w_lat,
)
# per_token shape: (B, f_lat*h_lat*w_lat).
# Collapse batch dim — the mask is assumed identical across the
# batch; validate and take the first element to get (1, tokens).
if per_token.shape[0] > 1:
ref = per_token[0]
for bi in range(1, per_token.shape[0]):
if not torch.equal(ref, per_token[bi]):
logger.warning(
"pixel_mask differs across batch elements; "
"using first element only."
)
break
per_token = per_token[:1]
# `surviving` is the post-grid_mask token count.
# Clamp to surviving to handle any mismatch safely.
n_weights = min(per_token.shape[1], surviving)
weights = per_token[:, :n_weights] * strength # (1, n_weights)
else:
weights = torch.full(
(1, surviving), strength, device=device, dtype=dtype
)
all_weights.append(weights)
total_tracked += weights.shape[1]
if not all_weights:
return None
# Concatenate per-token weights for all tracked guides
tracked_weights = torch.cat(all_weights, dim=1) # (1, total_tracked)
# Check if any weight is actually < 1.0 (otherwise no attenuation needed)
if (tracked_weights >= 1.0).all():
return None
# Build the mask: guide tokens are at the end of the sequence.
# Tracked guides come first (in order), untracked follow.
return self._build_self_attention_mask(
total_tokens, num_guide_tokens, total_tracked,
tracked_weights, guide_start, device, dtype,
)
@staticmethod
def _downsample_mask_to_latent(mask, f_lat, h_lat, w_lat):
"""Downsample a pixel-space mask to per-token latent weights.
Args:
mask: (B, 1, F_pix, H_pix, W_pix) pixel-space mask with values in [0, 1].
f_lat: Number of latent frames (pre-dilation original count).
h_lat: Latent height (pre-dilation original height).
w_lat: Latent width (pre-dilation original width).
Returns:
(B, F_lat * H_lat * W_lat) flattened per-token weights.
"""
b = mask.shape[0]
f_pix = mask.shape[2]
# Spatial downsampling: area interpolation per frame
spatial_down = torch.nn.functional.interpolate(
rearrange(mask, "b 1 f h w -> (b f) 1 h w"),
size=(h_lat, w_lat),
mode="area",
)
spatial_down = rearrange(spatial_down, "(b f) 1 h w -> b 1 f h w", b=b)
# Temporal downsampling: first pixel frame maps to first latent frame,
# remaining pixel frames are averaged in groups for causal temporal structure.
first_frame = spatial_down[:, :, :1, :, :]
if f_pix > 1 and f_lat > 1:
remaining_pix = f_pix - 1
remaining_lat = f_lat - 1
t = remaining_pix // remaining_lat
if t < 1:
# Fewer pixel frames than latent frames — upsample by repeating
# the available pixel frames via nearest interpolation.
rest_flat = rearrange(
spatial_down[:, :, 1:, :, :],
"b 1 f h w -> (b h w) 1 f",
)
rest_up = torch.nn.functional.interpolate(
rest_flat, size=remaining_lat, mode="nearest",
)
rest = rearrange(
rest_up, "(b h w) 1 f -> b 1 f h w",
b=b, h=h_lat, w=w_lat,
)
else:
# Trim trailing pixel frames that don't fill a complete group
usable = remaining_lat * t
rest = rearrange(
spatial_down[:, :, 1:1 + usable, :, :],
"b 1 (f t) h w -> b 1 f t h w",
t=t,
)
rest = rest.mean(dim=3)
latent_mask = torch.cat([first_frame, rest], dim=2)
elif f_lat > 1:
# Single pixel frame but multiple latent frames — repeat the
# single frame across all latent frames.
latent_mask = first_frame.expand(-1, -1, f_lat, -1, -1)
else:
latent_mask = first_frame
return rearrange(latent_mask, "b 1 f h w -> b (f h w)")
@staticmethod
def _build_self_attention_mask(total_tokens, num_guide_tokens, tracked_count,
tracked_weights, guide_start, device, dtype):
"""Build a log-space additive self-attention bias mask.
Attenuates attention between noisy tokens and tracked guide tokens.
Untracked guide tokens (at the end of the guide portion) keep full attention.
Args:
total_tokens: Total sequence length.
num_guide_tokens: Total guide tokens (all guides) at end of sequence.
tracked_count: Number of tracked guide tokens (first in the guide portion).
tracked_weights: (1, tracked_count) tensor, values in [0, 1].
guide_start: Index where guide tokens begin in the sequence.
device: Target device.
dtype: Target dtype.
Returns:
(1, 1, total_tokens, total_tokens) additive bias mask.
0.0 = full attention, negative = attenuated, finfo.min = effectively fully masked.
"""
finfo = torch.finfo(dtype)
mask = torch.zeros((1, 1, total_tokens, total_tokens), device=device, dtype=dtype)
tracked_end = guide_start + tracked_count
# Convert weights to log-space bias
w = tracked_weights.to(device=device, dtype=dtype) # (1, tracked_count)
log_w = torch.full_like(w, finfo.min)
positive_mask = w > 0
if positive_mask.any():
log_w[positive_mask] = torch.log(w[positive_mask].clamp(min=finfo.tiny))
# noisy → tracked guides: each noisy row gets the same per-guide weight
mask[:, :, :guide_start, guide_start:tracked_end] = log_w.view(1, 1, 1, -1)
# tracked guides → noisy: each guide row broadcasts its weight across noisy cols
mask[:, :, guide_start:tracked_end, :guide_start] = log_w.view(1, 1, -1, 1)
return mask
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, self_attention_mask=None, **kwargs):
"""Process transformer blocks for LTXV."""
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
prompt_timestep = kwargs.get("prompt_timestep", None)
for i, block in enumerate(self.transformer_blocks):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"])
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"], self_attention_mask=args.get("self_attention_mask"), prompt_timestep=args.get("prompt_timestep"))
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options, "self_attention_mask": self_attention_mask, "prompt_timestep": prompt_timestep}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(
@@ -919,6 +1297,8 @@ class LTXVModel(LTXBaseModel):
timestep=timestep,
pe=pe,
transformer_options=transformer_options,
self_attention_mask=self_attention_mask,
prompt_timestep=prompt_timestep,
)
return x
+5 -2
View File
@@ -13,7 +13,7 @@ from comfy.ldm.lightricks.vae.causal_audio_autoencoder import (
CausalityAxis,
CausalAudioAutoencoder,
)
from comfy.ldm.lightricks.vocoders.vocoder import Vocoder
from comfy.ldm.lightricks.vocoders.vocoder import Vocoder, VocoderWithBWE
LATENT_DOWNSAMPLE_FACTOR = 4
@@ -141,7 +141,10 @@ class AudioVAE(torch.nn.Module):
vocoder_sd = utils.state_dict_prefix_replace(state_dict, {"vocoder.": ""}, filter_keys=True)
self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder)
self.vocoder = Vocoder(config=component_config.vocoder)
if "bwe" in component_config.vocoder:
self.vocoder = VocoderWithBWE(config=component_config.vocoder)
else:
self.vocoder = Vocoder(config=component_config.vocoder)
self.autoencoder.load_state_dict(vae_sd, strict=False)
self.vocoder.load_state_dict(vocoder_sd, strict=False)
@@ -822,26 +822,23 @@ class CausalAudioAutoencoder(nn.Module):
super().__init__()
if config is None:
config = self._guess_config()
config = self.get_default_config()
# Extract encoder and decoder configs from the new format
model_config = config.get("model", {}).get("params", {})
variables_config = config.get("variables", {})
self.sampling_rate = variables_config.get(
"sampling_rate",
model_config.get("sampling_rate", config.get("sampling_rate", 16000)),
self.sampling_rate = model_config.get(
"sampling_rate", config.get("sampling_rate", 16000)
)
encoder_config = model_config.get("encoder", model_config.get("ddconfig", {}))
decoder_config = model_config.get("decoder", encoder_config)
# Load mel spectrogram parameters
self.mel_bins = encoder_config.get("mel_bins", 64)
self.mel_hop_length = model_config.get("preprocessing", {}).get("stft", {}).get("hop_length", 160)
self.n_fft = model_config.get("preprocessing", {}).get("stft", {}).get("filter_length", 1024)
self.mel_hop_length = config.get("preprocessing", {}).get("stft", {}).get("hop_length", 160)
self.n_fft = config.get("preprocessing", {}).get("stft", {}).get("filter_length", 1024)
# Store causality configuration at VAE level (not just in encoder internals)
causality_axis_value = encoder_config.get("causality_axis", CausalityAxis.WIDTH.value)
causality_axis_value = encoder_config.get("causality_axis", CausalityAxis.HEIGHT.value)
self.causality_axis = CausalityAxis.str_to_enum(causality_axis_value)
self.is_causal = self.causality_axis == CausalityAxis.HEIGHT
@@ -850,44 +847,38 @@ class CausalAudioAutoencoder(nn.Module):
self.per_channel_statistics = processor()
def _guess_config(self):
encoder_config = {
# Required parameters - based on ltx-video-av-1679000 model metadata
"ch": 128,
"out_ch": 8,
"ch_mult": [1, 2, 4], # Based on metadata: [1, 2, 4] not [1, 2, 4, 8]
"num_res_blocks": 2,
"attn_resolutions": [], # Based on metadata: empty list, no attention
"dropout": 0.0,
"resamp_with_conv": True,
"in_channels": 2, # stereo
"resolution": 256,
"z_channels": 8,
def get_default_config(self):
ddconfig = {
"double_z": True,
"attn_type": "vanilla",
"mid_block_add_attention": False, # Based on metadata: false
"mel_bins": 64,
"z_channels": 8,
"resolution": 256,
"downsample_time": False,
"in_channels": 2,
"out_ch": 2,
"ch": 128,
"ch_mult": [1, 2, 4],
"num_res_blocks": 2,
"attn_resolutions": [],
"dropout": 0.0,
"mid_block_add_attention": False,
"norm_type": "pixel",
"causality_axis": "height", # Based on metadata
"mel_bins": 64, # Based on metadata: mel_bins = 64
}
decoder_config = {
# Inherits encoder config, can override specific params
**encoder_config,
"out_ch": 2, # Stereo audio output (2 channels)
"give_pre_end": False,
"tanh_out": False,
"causality_axis": "height",
}
config = {
"_class_name": "CausalAudioAutoencoder",
"sampling_rate": 16000,
"model": {
"params": {
"encoder": encoder_config,
"decoder": decoder_config,
"ddconfig": ddconfig,
"sampling_rate": 16000,
}
},
"preprocessing": {
"stft": {
"filter_length": 1024,
"hop_length": 160,
},
},
}
return config
@@ -15,6 +15,9 @@ from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed
ops = comfy.ops.disable_weight_init
def in_meta_context():
return torch.device("meta") == torch.empty(0).device
def mark_conv3d_ended(module):
tid = threading.get_ident()
for _, m in module.named_modules():
@@ -350,6 +353,10 @@ class Decoder(nn.Module):
output_channel = output_channel * block_params.get("multiplier", 2)
if block_name == "compress_all":
output_channel = output_channel * block_params.get("multiplier", 1)
if block_name == "compress_space":
output_channel = output_channel * block_params.get("multiplier", 1)
if block_name == "compress_time":
output_channel = output_channel * block_params.get("multiplier", 1)
self.conv_in = make_conv_nd(
dims,
@@ -395,17 +402,21 @@ class Decoder(nn.Module):
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_time":
output_channel = output_channel // block_params.get("multiplier", 1)
block = DepthToSpaceUpsample(
dims=dims,
in_channels=input_channel,
stride=(2, 1, 1),
out_channels_reduction_factor=block_params.get("multiplier", 1),
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_space":
output_channel = output_channel // block_params.get("multiplier", 1)
block = DepthToSpaceUpsample(
dims=dims,
in_channels=input_channel,
stride=(1, 2, 2),
out_channels_reduction_factor=block_params.get("multiplier", 1),
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_all":
@@ -455,6 +466,15 @@ class Decoder(nn.Module):
output_channel * 2, 0, operations=ops,
)
self.last_scale_shift_table = nn.Parameter(torch.empty(2, output_channel))
else:
self.register_buffer(
"last_scale_shift_table",
torch.tensor(
[0.0, 0.0],
device="cpu" if in_meta_context() else None
).unsqueeze(1).expand(2, output_channel),
persistent=False,
)
# def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
@@ -883,6 +903,15 @@ class ResnetBlock3D(nn.Module):
self.scale_shift_table = nn.Parameter(
torch.randn(4, in_channels) / in_channels**0.5
)
else:
self.register_buffer(
"scale_shift_table",
torch.tensor(
[0.0, 0.0, 0.0, 0.0],
device="cpu" if in_meta_context() else None
).unsqueeze(1).expand(4, in_channels),
persistent=False,
)
self.temporal_cache_state={}
@@ -1012,9 +1041,6 @@ class processor(nn.Module):
super().__init__()
self.register_buffer("std-of-means", torch.empty(128))
self.register_buffer("mean-of-means", torch.empty(128))
self.register_buffer("mean-of-stds", torch.empty(128))
self.register_buffer("mean-of-stds_over_std-of-means", torch.empty(128))
self.register_buffer("channel", torch.empty(128))
def un_normalize(self, x):
return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)) + self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)
@@ -1027,9 +1053,12 @@ class VideoVAE(nn.Module):
super().__init__()
if config is None:
config = self.guess_config(version)
config = self.get_default_config(version)
self.config = config
self.timestep_conditioning = config.get("timestep_conditioning", False)
self.decode_noise_scale = config.get("decode_noise_scale", 0.025)
self.decode_timestep = config.get("decode_timestep", 0.05)
double_z = config.get("double_z", True)
latent_log_var = config.get(
"latent_log_var", "per_channel" if double_z else "none"
@@ -1044,6 +1073,7 @@ class VideoVAE(nn.Module):
latent_log_var=latent_log_var,
norm_layer=config.get("norm_layer", "group_norm"),
spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
base_channels=config.get("encoder_base_channels", 128),
)
self.decoder = Decoder(
@@ -1051,6 +1081,7 @@ class VideoVAE(nn.Module):
in_channels=config["latent_channels"],
out_channels=config.get("out_channels", 3),
blocks=config.get("decoder_blocks", config.get("decoder_blocks", config.get("blocks"))),
base_channels=config.get("decoder_base_channels", 128),
patch_size=config.get("patch_size", 1),
norm_layer=config.get("norm_layer", "group_norm"),
causal=config.get("causal_decoder", False),
@@ -1060,7 +1091,7 @@ class VideoVAE(nn.Module):
self.per_channel_statistics = processor()
def guess_config(self, version):
def get_default_config(self, version):
if version == 0:
config = {
"_class_name": "CausalVideoAutoencoder",
@@ -1167,8 +1198,7 @@ class VideoVAE(nn.Module):
means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
return self.per_channel_statistics.normalize(means)
def decode(self, x, timestep=0.05, noise_scale=0.025):
def decode(self, x):
if self.timestep_conditioning: #TODO: seed
x = torch.randn_like(x) * noise_scale + (1.0 - noise_scale) * x
return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=timestep)
x = torch.randn_like(x) * self.decode_noise_scale + (1.0 - self.decode_noise_scale) * x
return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep)
+511 -12
View File
@@ -3,6 +3,7 @@ import torch.nn.functional as F
import torch.nn as nn
import comfy.ops
import numpy as np
import math
ops = comfy.ops.disable_weight_init
@@ -12,6 +13,307 @@ def get_padding(kernel_size, dilation=1):
return int((kernel_size * dilation - dilation) / 2)
# ---------------------------------------------------------------------------
# Anti-aliased resampling helpers (kaiser-sinc filters) for BigVGAN v2
# Adopted from https://github.com/NVIDIA/BigVGAN
# ---------------------------------------------------------------------------
def _sinc(x: torch.Tensor):
return torch.where(
x == 0,
torch.tensor(1.0, device=x.device, dtype=x.dtype),
torch.sin(math.pi * x) / math.pi / x,
)
def kaiser_sinc_filter1d(cutoff, half_width, kernel_size):
even = kernel_size % 2 == 0
half_size = kernel_size // 2
delta_f = 4 * half_width
A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
if A > 50.0:
beta = 0.1102 * (A - 8.7)
elif A >= 21.0:
beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
else:
beta = 0.0
window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
if even:
time = torch.arange(-half_size, half_size) + 0.5
else:
time = torch.arange(kernel_size) - half_size
if cutoff == 0:
filter_ = torch.zeros_like(time)
else:
filter_ = 2 * cutoff * window * _sinc(2 * cutoff * time)
filter_ /= filter_.sum()
filter = filter_.view(1, 1, kernel_size)
return filter
class LowPassFilter1d(nn.Module):
def __init__(
self,
cutoff=0.5,
half_width=0.6,
stride=1,
padding=True,
padding_mode="replicate",
kernel_size=12,
):
super().__init__()
if cutoff < -0.0:
raise ValueError("Minimum cutoff must be larger than zero.")
if cutoff > 0.5:
raise ValueError("A cutoff above 0.5 does not make sense.")
self.kernel_size = kernel_size
self.even = kernel_size % 2 == 0
self.pad_left = kernel_size // 2 - int(self.even)
self.pad_right = kernel_size // 2
self.stride = stride
self.padding = padding
self.padding_mode = padding_mode
filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
self.register_buffer("filter", filter)
def forward(self, x):
_, C, _ = x.shape
if self.padding:
x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
return F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
class UpSample1d(nn.Module):
def __init__(self, ratio=2, kernel_size=None, persistent=True, window_type="kaiser"):
super().__init__()
self.ratio = ratio
self.stride = ratio
if window_type == "hann":
# Hann-windowed sinc filter — identical to torchaudio.functional.resample
# with its default parameters (rolloff=0.99, lowpass_filter_width=6).
# Uses replicate boundary padding, matching the reference resampler exactly.
rolloff = 0.99
lowpass_filter_width = 6
width = math.ceil(lowpass_filter_width / rolloff)
self.kernel_size = 2 * width * ratio + 1
self.pad = width
self.pad_left = 2 * width * ratio
self.pad_right = self.kernel_size - ratio
t = (torch.arange(self.kernel_size) / ratio - width) * rolloff
t_clamped = t.clamp(-lowpass_filter_width, lowpass_filter_width)
window = torch.cos(t_clamped * math.pi / lowpass_filter_width / 2) ** 2
filter = (torch.sinc(t) * window * rolloff / ratio).view(1, 1, -1)
else:
# Kaiser-windowed sinc filter (BigVGAN default).
self.kernel_size = (
int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
)
self.pad = self.kernel_size // ratio - 1
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
self.pad_right = (
self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
)
filter = kaiser_sinc_filter1d(
cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
)
self.register_buffer("filter", filter, persistent=persistent)
def forward(self, x):
_, C, _ = x.shape
x = F.pad(x, (self.pad, self.pad), mode="replicate")
x = self.ratio * F.conv_transpose1d(
x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
)
x = x[..., self.pad_left : -self.pad_right]
return x
class DownSample1d(nn.Module):
def __init__(self, ratio=2, kernel_size=None):
super().__init__()
self.ratio = ratio
self.kernel_size = (
int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
)
self.lowpass = LowPassFilter1d(
cutoff=0.5 / ratio,
half_width=0.6 / ratio,
stride=ratio,
kernel_size=self.kernel_size,
)
def forward(self, x):
return self.lowpass(x)
class Activation1d(nn.Module):
def __init__(
self,
activation,
up_ratio=2,
down_ratio=2,
up_kernel_size=12,
down_kernel_size=12,
):
super().__init__()
self.act = activation
self.upsample = UpSample1d(up_ratio, up_kernel_size)
self.downsample = DownSample1d(down_ratio, down_kernel_size)
def forward(self, x):
x = self.upsample(x)
x = self.act(x)
x = self.downsample(x)
return x
# ---------------------------------------------------------------------------
# BigVGAN v2 activations (Snake / SnakeBeta)
# ---------------------------------------------------------------------------
class Snake(nn.Module):
def __init__(
self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True
):
super().__init__()
self.alpha_logscale = alpha_logscale
self.alpha = nn.Parameter(
torch.zeros(in_features)
if alpha_logscale
else torch.ones(in_features) * alpha
)
self.alpha.requires_grad = alpha_trainable
self.eps = 1e-9
def forward(self, x):
a = self.alpha.unsqueeze(0).unsqueeze(-1)
if self.alpha_logscale:
a = torch.exp(a)
return x + (1.0 / (a + self.eps)) * torch.sin(x * a).pow(2)
class SnakeBeta(nn.Module):
def __init__(
self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True
):
super().__init__()
self.alpha_logscale = alpha_logscale
self.alpha = nn.Parameter(
torch.zeros(in_features)
if alpha_logscale
else torch.ones(in_features) * alpha
)
self.alpha.requires_grad = alpha_trainable
self.beta = nn.Parameter(
torch.zeros(in_features)
if alpha_logscale
else torch.ones(in_features) * alpha
)
self.beta.requires_grad = alpha_trainable
self.eps = 1e-9
def forward(self, x):
a = self.alpha.unsqueeze(0).unsqueeze(-1)
b = self.beta.unsqueeze(0).unsqueeze(-1)
if self.alpha_logscale:
a = torch.exp(a)
b = torch.exp(b)
return x + (1.0 / (b + self.eps)) * torch.sin(x * a).pow(2)
# ---------------------------------------------------------------------------
# BigVGAN v2 AMPBlock (Anti-aliased Multi-Periodicity)
# ---------------------------------------------------------------------------
class AMPBlock1(torch.nn.Module):
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), activation="snake"):
super().__init__()
act_cls = SnakeBeta if activation == "snakebeta" else Snake
self.convs1 = nn.ModuleList(
[
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]),
),
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]),
),
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[2],
padding=get_padding(kernel_size, dilation[2]),
),
]
)
self.convs2 = nn.ModuleList(
[
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
),
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
),
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
),
]
)
self.acts1 = nn.ModuleList(
[Activation1d(act_cls(channels)) for _ in range(len(self.convs1))]
)
self.acts2 = nn.ModuleList(
[Activation1d(act_cls(channels)) for _ in range(len(self.convs2))]
)
def forward(self, x):
for c1, c2, a1, a2 in zip(self.convs1, self.convs2, self.acts1, self.acts2):
xt = a1(x)
xt = c1(xt)
xt = a2(xt)
xt = c2(xt)
x = x + xt
return x
# ---------------------------------------------------------------------------
# HiFi-GAN residual blocks
# ---------------------------------------------------------------------------
class ResBlock1(torch.nn.Module):
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
super(ResBlock1, self).__init__()
@@ -119,6 +421,7 @@ class Vocoder(torch.nn.Module):
"""
Vocoder model for synthesizing audio from spectrograms, based on: https://github.com/jik876/hifi-gan.
Supports both HiFi-GAN (resblock "1"/"2") and BigVGAN v2 (resblock "AMP1").
"""
def __init__(self, config=None):
@@ -128,19 +431,39 @@ class Vocoder(torch.nn.Module):
config = self.get_default_config()
resblock_kernel_sizes = config.get("resblock_kernel_sizes", [3, 7, 11])
upsample_rates = config.get("upsample_rates", [6, 5, 2, 2, 2])
upsample_kernel_sizes = config.get("upsample_kernel_sizes", [16, 15, 8, 4, 4])
upsample_rates = config.get("upsample_rates", [5, 4, 2, 2, 2])
upsample_kernel_sizes = config.get("upsample_kernel_sizes", [16, 16, 8, 4, 4])
resblock_dilation_sizes = config.get("resblock_dilation_sizes", [[1, 3, 5], [1, 3, 5], [1, 3, 5]])
upsample_initial_channel = config.get("upsample_initial_channel", 1024)
stereo = config.get("stereo", True)
resblock = config.get("resblock", "1")
activation = config.get("activation", "snake")
use_bias_at_final = config.get("use_bias_at_final", True)
# "output_sample_rate" is not present in recent checkpoint configs.
# When absent (None), AudioVAE.output_sample_rate computes it as:
# sample_rate * vocoder.upsample_factor / mel_hop_length
# where upsample_factor = product of all upsample stride lengths,
# and mel_hop_length is loaded from the autoencoder config at
# preprocessing.stft.hop_length (see CausalAudioAutoencoder).
self.output_sample_rate = config.get("output_sample_rate")
self.resblock = config.get("resblock", "1")
self.use_tanh_at_final = config.get("use_tanh_at_final", True)
self.apply_final_activation = config.get("apply_final_activation", True)
self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_rates)
in_channels = 128 if stereo else 64
self.conv_pre = ops.Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)
resblock_class = ResBlock1 if resblock == "1" else ResBlock2
if self.resblock == "1":
resblock_cls = ResBlock1
elif self.resblock == "2":
resblock_cls = ResBlock2
elif self.resblock == "AMP1":
resblock_cls = AMPBlock1
else:
raise ValueError(f"Unknown resblock type: {self.resblock}")
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
@@ -157,25 +480,40 @@ class Vocoder(torch.nn.Module):
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = upsample_initial_channel // (2 ** (i + 1))
for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
self.resblocks.append(resblock_class(ch, k, d))
for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
if self.resblock == "AMP1":
self.resblocks.append(resblock_cls(ch, k, d, activation=activation))
else:
self.resblocks.append(resblock_cls(ch, k, d))
out_channels = 2 if stereo else 1
self.conv_post = ops.Conv1d(ch, out_channels, 7, 1, padding=3)
if self.resblock == "AMP1":
act_cls = SnakeBeta if activation == "snakebeta" else Snake
self.act_post = Activation1d(act_cls(ch))
else:
self.act_post = nn.LeakyReLU()
self.conv_post = ops.Conv1d(
ch, out_channels, 7, 1, padding=3, bias=use_bias_at_final
)
self.upsample_factor = np.prod([self.ups[i].stride[0] for i in range(len(self.ups))])
def get_default_config(self):
"""Generate default configuration for the vocoder."""
config = {
"resblock_kernel_sizes": [3, 7, 11],
"upsample_rates": [6, 5, 2, 2, 2],
"upsample_kernel_sizes": [16, 15, 8, 4, 4],
"upsample_rates": [5, 4, 2, 2, 2],
"upsample_kernel_sizes": [16, 16, 8, 4, 4],
"resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
"upsample_initial_channel": 1024,
"stereo": True,
"resblock": "1",
"activation": "snake",
"use_bias_at_final": True,
"use_tanh_at_final": True,
}
return config
@@ -196,8 +534,10 @@ class Vocoder(torch.nn.Module):
assert x.shape[1] == 2, "Input must have 2 channels for stereo"
x = torch.cat((x[:, 0, :, :], x[:, 1, :, :]), dim=1)
x = self.conv_pre(x)
for i in range(self.num_upsamples):
x = F.leaky_relu(x, LRELU_SLOPE)
if self.resblock != "AMP1":
x = F.leaky_relu(x, LRELU_SLOPE)
x = self.ups[i](x)
xs = None
for j in range(self.num_kernels):
@@ -206,8 +546,167 @@ class Vocoder(torch.nn.Module):
else:
xs += self.resblocks[i * self.num_kernels + j](x)
x = xs / self.num_kernels
x = F.leaky_relu(x)
x = self.act_post(x)
x = self.conv_post(x)
x = torch.tanh(x)
if self.apply_final_activation:
if self.use_tanh_at_final:
x = torch.tanh(x)
else:
x = torch.clamp(x, -1, 1)
return x
class _STFTFn(nn.Module):
"""Implements STFT as a convolution with precomputed DFT × Hann-window bases.
The DFT basis rows (real and imaginary parts interleaved) multiplied by the causal
Hann window are stored as buffers and loaded from the checkpoint. Using the exact
bfloat16 bases from training ensures the mel values fed to the BWE generator are
bit-identical to what it was trained on.
"""
def __init__(self, filter_length: int, hop_length: int, win_length: int):
super().__init__()
self.hop_length = hop_length
self.win_length = win_length
n_freqs = filter_length // 2 + 1
self.register_buffer("forward_basis", torch.zeros(n_freqs * 2, 1, filter_length))
self.register_buffer("inverse_basis", torch.zeros(n_freqs * 2, 1, filter_length))
def forward(self, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Compute magnitude and phase spectrogram from a batch of waveforms.
Applies causal (left-only) padding of win_length - hop_length samples so that
each output frame depends only on past and present input — no lookahead.
The STFT is computed by convolving the padded signal with forward_basis.
Args:
y: Waveform tensor of shape (B, T).
Returns:
magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames).
phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames).
Computed in float32 for numerical stability, then cast back to
the input dtype.
"""
if y.dim() == 2:
y = y.unsqueeze(1) # (B, 1, T)
left_pad = max(0, self.win_length - self.hop_length) # causal: left-only
y = F.pad(y, (left_pad, 0))
spec = F.conv1d(y, self.forward_basis, stride=self.hop_length, padding=0)
n_freqs = spec.shape[1] // 2
real, imag = spec[:, :n_freqs], spec[:, n_freqs:]
magnitude = torch.sqrt(real ** 2 + imag ** 2)
phase = torch.atan2(imag.float(), real.float()).to(real.dtype)
return magnitude, phase
class MelSTFT(nn.Module):
"""Causal log-mel spectrogram module whose buffers are loaded from the checkpoint.
Computes a log-mel spectrogram by running the causal STFT (_STFTFn) on the input
waveform and projecting the linear magnitude spectrum onto the mel filterbank.
The module's state dict layout matches the 'mel_stft.*' keys stored in the checkpoint
(mel_basis, stft_fn.forward_basis, stft_fn.inverse_basis).
"""
def __init__(
self,
filter_length: int,
hop_length: int,
win_length: int,
n_mel_channels: int,
sampling_rate: int,
mel_fmin: float,
mel_fmax: float,
):
super().__init__()
self.stft_fn = _STFTFn(filter_length, hop_length, win_length)
n_freqs = filter_length // 2 + 1
self.register_buffer("mel_basis", torch.zeros(n_mel_channels, n_freqs))
def mel_spectrogram(
self, y: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Compute log-mel spectrogram and auxiliary spectral quantities.
Args:
y: Waveform tensor of shape (B, T).
Returns:
log_mel: Log-compressed mel spectrogram, shape (B, n_mel_channels, T_frames).
Computed as log(clamp(mel_basis @ magnitude, min=1e-5)).
magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames).
phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames).
energy: Per-frame energy (L2 norm over frequency), shape (B, T_frames).
"""
magnitude, phase = self.stft_fn(y)
energy = torch.norm(magnitude, dim=1)
mel = torch.matmul(self.mel_basis.to(magnitude.dtype), magnitude)
log_mel = torch.log(torch.clamp(mel, min=1e-5))
return log_mel, magnitude, phase, energy
class VocoderWithBWE(torch.nn.Module):
"""Vocoder with bandwidth extension (BWE) for higher sample rate output.
Chains a base vocoder (mel → low-rate waveform) with a BWE stage that upsamples
to a higher rate. The BWE computes a mel spectrogram from the low-rate waveform.
"""
def __init__(self, config):
super().__init__()
vocoder_config = config["vocoder"]
bwe_config = config["bwe"]
self.vocoder = Vocoder(config=vocoder_config)
self.bwe_generator = Vocoder(
config={**bwe_config, "apply_final_activation": False}
)
self.input_sample_rate = bwe_config["input_sampling_rate"]
self.output_sample_rate = bwe_config["output_sampling_rate"]
self.hop_length = bwe_config["hop_length"]
self.mel_stft = MelSTFT(
filter_length=bwe_config["n_fft"],
hop_length=bwe_config["hop_length"],
win_length=bwe_config["n_fft"],
n_mel_channels=bwe_config["num_mels"],
sampling_rate=bwe_config["input_sampling_rate"],
mel_fmin=0.0,
mel_fmax=bwe_config["input_sampling_rate"] / 2.0,
)
self.resampler = UpSample1d(
ratio=bwe_config["output_sampling_rate"] // bwe_config["input_sampling_rate"],
persistent=False,
window_type="hann",
)
def _compute_mel(self, audio):
"""Compute log-mel spectrogram from waveform using causal STFT bases."""
B, C, T = audio.shape
flat = audio.reshape(B * C, -1) # (B*C, T)
mel, _, _, _ = self.mel_stft.mel_spectrogram(flat) # (B*C, n_mels, T_frames)
return mel.reshape(B, C, mel.shape[1], mel.shape[2]) # (B, C, n_mels, T_frames)
def forward(self, mel_spec):
x = self.vocoder(mel_spec)
_, _, T_low = x.shape
T_out = T_low * self.output_sample_rate // self.input_sample_rate
remainder = T_low % self.hop_length
if remainder != 0:
x = F.pad(x, (0, self.hop_length - remainder))
mel = self._compute_mel(x)
residual = self.bwe_generator(mel)
skip = self.resampler(x)
assert residual.shape == skip.shape, f"residual {residual.shape} != skip {skip.shape}"
return torch.clamp(residual + skip, -1, 1)[..., :T_out]
+265
View File
@@ -14,6 +14,7 @@ from comfy.ldm.flux.layers import EmbedND
from comfy.ldm.flux.math import apply_rope
import comfy.patcher_extension
import comfy.utils
from comfy.ldm.chroma_radiance.layers import NerfEmbedder
def invert_slices(slices, length):
@@ -858,3 +859,267 @@ class NextDiT(nn.Module):
img = self.unpatchify(img, img_size, cap_size, return_tensor=x_is_tensor)[:, :, :h, :w]
return -img
#############################################################################
# Pixel Space Decoder Components #
#############################################################################
def _modulate_shift_scale(x, shift, scale):
return x * (1 + scale) + shift
class PixelResBlock(nn.Module):
"""
Residual block with AdaLN modulation, zero-initialised so it starts as
an identity at the beginning of training.
"""
def __init__(self, channels: int, dtype=None, device=None, operations=None):
super().__init__()
self.in_ln = operations.LayerNorm(channels, eps=1e-6, dtype=dtype, device=device)
self.mlp = nn.Sequential(
operations.Linear(channels, channels, bias=True, dtype=dtype, device=device),
nn.SiLU(),
operations.Linear(channels, channels, bias=True, dtype=dtype, device=device),
)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operations.Linear(channels, 3 * channels, bias=True, dtype=dtype, device=device),
)
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
shift, scale, gate = self.adaLN_modulation(y).chunk(3, dim=-1)
h = _modulate_shift_scale(self.in_ln(x), shift, scale)
h = self.mlp(h)
return x + gate * h
class DCTFinalLayer(nn.Module):
"""Zero-initialised output projection (adopted from DiT)."""
def __init__(self, model_channels: int, out_channels: int, dtype=None, device=None, operations=None):
super().__init__()
self.norm_final = operations.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.linear = operations.Linear(model_channels, out_channels, bias=True, dtype=dtype, device=device)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.linear(self.norm_final(x))
class SimpleMLPAdaLN(nn.Module):
"""
Small MLP decoder head for the pixel-space variant.
Takes per-patch pixel values and a per-patch conditioning vector from the
transformer backbone and predicts the denoised pixel values.
x : [B*N, P^2, C] noisy pixel values per patch position
c : [B*N, dim] backbone hidden state per patch (conditioning)
→ [B*N, P^2, C]
"""
def __init__(
self,
in_channels: int,
model_channels: int,
out_channels: int,
z_channels: int,
num_res_blocks: int,
max_freqs: int = 8,
dtype=None,
device=None,
operations=None,
):
super().__init__()
self.dtype = dtype
# Project backbone hidden state → per-patch conditioning
self.cond_embed = operations.Linear(z_channels, model_channels, dtype=dtype, device=device)
# Input projection with DCT positional encoding
self.input_embedder = NerfEmbedder(
in_channels=in_channels,
hidden_size_input=model_channels,
max_freqs=max_freqs,
dtype=dtype,
device=device,
operations=operations,
)
# Residual blocks
self.res_blocks = nn.ModuleList([
PixelResBlock(model_channels, dtype=dtype, device=device, operations=operations) for _ in range(num_res_blocks)
])
# Output projection
self.final_layer = DCTFinalLayer(model_channels, out_channels, dtype=dtype, device=device, operations=operations)
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
# x: [B*N, 1, P^2*C], c: [B*N, dim]
original_dtype = x.dtype
weight_dtype = self.cond_embed.weight.dtype if hasattr(self.cond_embed, "weight") and self.cond_embed.weight is not None else (self.dtype or x.dtype)
x = self.input_embedder(x) # [B*N, 1, model_channels]
y = self.cond_embed(c.to(weight_dtype)).unsqueeze(1) # [B*N, 1, model_channels]
x = x.to(weight_dtype)
for block in self.res_blocks:
x = block(x, y)
return self.final_layer(x).to(original_dtype) # [B*N, 1, P^2*C]
#############################################################################
# NextDiT Pixel Space #
#############################################################################
class NextDiTPixelSpace(NextDiT):
"""
Pixel-space variant of NextDiT.
Identical transformer backbone to NextDiT, but the output head is replaced
with a small MLP decoder (SimpleMLPAdaLN) that operates on raw pixel values
per patch rather than a single affine projection.
Key differences vs NextDiT:
• ``final_layer`` is removed; ``dec_net`` (SimpleMLPAdaLN) is used instead.
• ``_forward`` stores the raw patchified pixel values before the backbone
embedding and feeds them to ``dec_net`` together with the per-patch
backbone hidden states.
• Supports optional x0 prediction via ``use_x0``.
"""
def __init__(
self,
# decoder-specific
decoder_hidden_size: int = 3840,
decoder_num_res_blocks: int = 4,
decoder_max_freqs: int = 8,
decoder_in_channels: int = None, # full flattened patch size (patch_size^2 * in_channels)
use_x0: bool = False,
# all NextDiT args forwarded unchanged
**kwargs,
):
super().__init__(**kwargs)
# Remove the latent-space final layer not used in pixel space
del self.final_layer
patch_size = kwargs.get("patch_size", 2)
in_channels = kwargs.get("in_channels", 4)
dim = kwargs.get("dim", 4096)
# decoder_in_channels is the full flattened patch: patch_size^2 * in_channels
dec_in_ch = decoder_in_channels if decoder_in_channels is not None else patch_size ** 2 * in_channels
self.dec_net = SimpleMLPAdaLN(
in_channels=dec_in_ch,
model_channels=decoder_hidden_size,
out_channels=dec_in_ch,
z_channels=dim,
num_res_blocks=decoder_num_res_blocks,
max_freqs=decoder_max_freqs,
dtype=kwargs.get("dtype"),
device=kwargs.get("device"),
operations=kwargs.get("operations"),
)
if use_x0:
self.register_buffer("__x0__", torch.tensor([]))
# ------------------------------------------------------------------
# Forward — mirrors NextDiT._forward exactly, replacing final_layer
# with the pixel-space dec_net decoder.
# ------------------------------------------------------------------
def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, ref_latents=[], ref_contexts=[], siglip_feats=[], transformer_options={}, **kwargs):
omni = len(ref_latents) > 0
if omni:
timesteps = torch.cat([timesteps * 0, timesteps], dim=0)
t = 1.0 - timesteps
cap_feats = context
cap_mask = attention_mask
bs, c, h, w = x.shape
x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
t = self.t_embedder(t * self.time_scale, dtype=x.dtype)
adaln_input = t
if self.clip_text_pooled_proj is not None:
pooled = kwargs.get("clip_text_pooled", None)
if pooled is not None:
pooled = self.clip_text_pooled_proj(pooled)
else:
pooled = torch.zeros((x.shape[0], self.clip_text_dim), device=x.device, dtype=x.dtype)
adaln_input = self.time_text_embed(torch.cat((t, pooled), dim=-1))
# ---- capture raw pixel patches before patchify_and_embed embeds them ----
pH = pW = self.patch_size
B, C, H, W = x.shape
pixel_patches = (
x.view(B, C, H // pH, pH, W // pW, pW)
.permute(0, 2, 4, 3, 5, 1) # [B, Ht, Wt, pH, pW, C]
.flatten(3) # [B, Ht, Wt, pH*pW*C]
.flatten(1, 2) # [B, N, pH*pW*C]
)
N = pixel_patches.shape[1]
# decoder sees one token per patch: [B*N, 1, P^2*C]
pixel_values = pixel_patches.reshape(B * N, 1, pH * pW * C)
patches = transformer_options.get("patches", {})
x_is_tensor = isinstance(x, torch.Tensor)
img, mask, img_size, cap_size, freqs_cis, timestep_zero_index = self.patchify_and_embed(
x, cap_feats, cap_mask, adaln_input, num_tokens,
ref_latents=ref_latents, ref_contexts=ref_contexts,
siglip_feats=siglip_feats, transformer_options=transformer_options
)
freqs_cis = freqs_cis.to(img.device)
transformer_options["total_blocks"] = len(self.layers)
transformer_options["block_type"] = "double"
img_input = img
for i, layer in enumerate(self.layers):
transformer_options["block_index"] = i
img = layer(img, mask, freqs_cis, adaln_input, timestep_zero_index=timestep_zero_index, transformer_options=transformer_options)
if "double_block" in patches:
for p in patches["double_block"]:
out = p({"img": img[:, cap_size[0]:], "img_input": img_input[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
if "img" in out:
img[:, cap_size[0]:] = out["img"]
if "txt" in out:
img[:, :cap_size[0]] = out["txt"]
# ---- pixel-space decoder (replaces final_layer + unpatchify) ----
# img may have padding tokens beyond N; only the first N are real image patches
img_hidden = img[:, cap_size[0]:cap_size[0] + N, :] # [B, N, dim]
decoder_cond = img_hidden.reshape(B * N, self.dim) # [B*N, dim]
output = self.dec_net(pixel_values, decoder_cond) # [B*N, 1, P^2*C]
output = output.reshape(B, N, -1) # [B, N, P^2*C]
# prepend zero cap placeholder so unpatchify indexing works unchanged
cap_placeholder = torch.zeros(
B, cap_size[0], output.shape[-1], device=output.device, dtype=output.dtype
)
img_out = self.unpatchify(
torch.cat([cap_placeholder, output], dim=1),
img_size, cap_size, return_tensor=x_is_tensor
)[:, :, :h, :w]
return -img_out
def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
# _forward returns neg_x0 = -x0 (negated decoder output).
#
# Reference inference (working_inference_reference.py):
# out = _forward(img, t) # = -x0
# pred = (img - out) / t # = (img + x0) / t [_apply_x0_residual]
# img += (t_prev - t_curr) * pred # Euler step
#
# ComfyUI's Euler sampler does the same:
# x_next = x + (sigma_next - sigma) * model_output
# So model_output must equal pred = (x - neg_x0) / t = (x - (-x0)) / t = (x + x0) / t
neg_x0 = comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs)
return (x - neg_x0) / timesteps.view(-1, 1, 1, 1)
@@ -18,6 +18,8 @@ import comfy.patcher_extension
import comfy.ops
ops = comfy.ops.disable_weight_init
from ..sdpose import HeatmapHead
class TimestepBlock(nn.Module):
"""
Any module where forward() takes timestep embeddings as a second argument.
@@ -441,6 +443,7 @@ class UNetModel(nn.Module):
disable_temporal_crossattention=False,
max_ddpm_temb_period=10000,
attn_precision=None,
heatmap_head=False,
device=None,
operations=ops,
):
@@ -827,6 +830,9 @@ class UNetModel(nn.Module):
#nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
)
if heatmap_head:
self.heatmap_head = HeatmapHead(device=device, dtype=self.dtype, operations=operations)
def forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
+130
View File
@@ -0,0 +1,130 @@
import torch
import numpy as np
from scipy.ndimage import gaussian_filter
class HeatmapHead(torch.nn.Module):
def __init__(
self,
in_channels=640,
out_channels=133,
input_size=(768, 1024),
heatmap_scale=4,
deconv_out_channels=(640,),
deconv_kernel_sizes=(4,),
conv_out_channels=(640,),
conv_kernel_sizes=(1,),
final_layer_kernel_size=1,
device=None, dtype=None, operations=None
):
super().__init__()
self.heatmap_size = (input_size[0] // heatmap_scale, input_size[1] // heatmap_scale)
self.scale_factor = ((np.array(input_size) - 1) / (np.array(self.heatmap_size) - 1)).astype(np.float32)
# Deconv layers
if deconv_out_channels:
deconv_layers = []
for out_ch, kernel_size in zip(deconv_out_channels, deconv_kernel_sizes):
if kernel_size == 4:
padding, output_padding = 1, 0
elif kernel_size == 3:
padding, output_padding = 1, 1
elif kernel_size == 2:
padding, output_padding = 0, 0
else:
raise ValueError(f'Unsupported kernel size {kernel_size}')
deconv_layers.extend([
operations.ConvTranspose2d(in_channels, out_ch, kernel_size,
stride=2, padding=padding, output_padding=output_padding, bias=False, device=device, dtype=dtype),
torch.nn.InstanceNorm2d(out_ch, device=device, dtype=dtype),
torch.nn.SiLU(inplace=True)
])
in_channels = out_ch
self.deconv_layers = torch.nn.Sequential(*deconv_layers)
else:
self.deconv_layers = torch.nn.Identity()
# Conv layers
if conv_out_channels:
conv_layers = []
for out_ch, kernel_size in zip(conv_out_channels, conv_kernel_sizes):
padding = (kernel_size - 1) // 2
conv_layers.extend([
operations.Conv2d(in_channels, out_ch, kernel_size,
stride=1, padding=padding, device=device, dtype=dtype),
torch.nn.InstanceNorm2d(out_ch, device=device, dtype=dtype),
torch.nn.SiLU(inplace=True)
])
in_channels = out_ch
self.conv_layers = torch.nn.Sequential(*conv_layers)
else:
self.conv_layers = torch.nn.Identity()
self.final_layer = operations.Conv2d(in_channels, out_channels, kernel_size=final_layer_kernel_size, padding=final_layer_kernel_size // 2, device=device, dtype=dtype)
def forward(self, x): # Decode heatmaps to keypoints
heatmaps = self.final_layer(self.conv_layers(self.deconv_layers(x)))
heatmaps_np = heatmaps.float().cpu().numpy() # (B, K, H, W)
B, K, H, W = heatmaps_np.shape
batch_keypoints = []
batch_scores = []
for b in range(B):
hm = heatmaps_np[b].copy() # (K, H, W)
# --- vectorised argmax ---
flat = hm.reshape(K, -1)
idx = np.argmax(flat, axis=1)
scores = flat[np.arange(K), idx].copy()
y_locs, x_locs = np.unravel_index(idx, (H, W))
keypoints = np.stack([x_locs, y_locs], axis=-1).astype(np.float32) # (K, 2) in heatmap space
invalid = scores <= 0.
keypoints[invalid] = -1
# --- DARK sub-pixel refinement (UDP) ---
# 1. Gaussian blur with max-preserving normalisation
border = 5 # (kernel-1)//2 for kernel=11
for k in range(K):
origin_max = np.max(hm[k])
dr = np.zeros((H + 2 * border, W + 2 * border), dtype=np.float32)
dr[border:-border, border:-border] = hm[k].copy()
dr = gaussian_filter(dr, sigma=2.0)
hm[k] = dr[border:-border, border:-border].copy()
cur_max = np.max(hm[k])
if cur_max > 0:
hm[k] *= origin_max / cur_max
# 2. Log-space for Taylor expansion
np.clip(hm, 1e-3, 50., hm)
np.log(hm, hm)
# 3. Hessian-based Newton step
hm_pad = np.pad(hm, ((0, 0), (1, 1), (1, 1)), mode='edge').flatten()
index = keypoints[:, 0] + 1 + (keypoints[:, 1] + 1) * (W + 2)
index += (W + 2) * (H + 2) * np.arange(0, K)
index = index.astype(int).reshape(-1, 1)
i_ = hm_pad[index]
ix1 = hm_pad[index + 1]
iy1 = hm_pad[index + W + 2]
ix1y1 = hm_pad[index + W + 3]
ix1_y1_ = hm_pad[index - W - 3]
ix1_ = hm_pad[index - 1]
iy1_ = hm_pad[index - 2 - W]
dx = 0.5 * (ix1 - ix1_)
dy = 0.5 * (iy1 - iy1_)
derivative = np.concatenate([dx, dy], axis=1).reshape(K, 2, 1)
dxx = ix1 - 2 * i_ + ix1_
dyy = iy1 - 2 * i_ + iy1_
dxy = 0.5 * (ix1y1 - ix1 - iy1 + i_ + i_ - ix1_ - iy1_ + ix1_y1_)
hessian = np.concatenate([dxx, dxy, dxy, dyy], axis=1).reshape(K, 2, 2)
hessian = np.linalg.inv(hessian + np.finfo(np.float32).eps * np.eye(2))
keypoints -= np.einsum('imn,ink->imk', hessian, derivative).squeeze(axis=-1)
# --- restore to input image space ---
keypoints = keypoints * self.scale_factor
keypoints[invalid] = -1
batch_keypoints.append(keypoints)
batch_scores.append(scores)
return batch_keypoints, batch_scores
+115
View File
@@ -1621,3 +1621,118 @@ class HumoWanModel(WanModel):
# unpatchify
x = self.unpatchify(x, grid_sizes)
return x
class SCAILWanModel(WanModel):
def __init__(self, model_type="scail", patch_size=(1, 2, 2), in_dim=20, dim=5120, operations=None, device=None, dtype=None, **kwargs):
super().__init__(model_type='i2v', patch_size=patch_size, in_dim=in_dim, dim=dim, operations=operations, device=device, dtype=dtype, **kwargs)
self.patch_embedding_pose = operations.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=torch.float32)
def forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, pose_latents=None, reference_latent=None, **kwargs):
if reference_latent is not None:
x = torch.cat((reference_latent, x), dim=2)
# embeddings
x = self.patch_embedding(x.float()).to(x.dtype)
grid_sizes = x.shape[2:]
transformer_options["grid_sizes"] = grid_sizes
x = x.flatten(2).transpose(1, 2)
scail_pose_seq_len = 0
if pose_latents is not None:
scail_x = self.patch_embedding_pose(pose_latents.float()).to(x.dtype)
scail_x = scail_x.flatten(2).transpose(1, 2)
scail_pose_seq_len = scail_x.shape[1]
x = torch.cat([x, scail_x], dim=1)
del scail_x
# time embeddings
e = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype))
e = e.reshape(t.shape[0], -1, e.shape[-1])
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
# context
context = self.text_embedding(context)
context_img_len = None
if clip_fea is not None:
if self.img_emb is not None:
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
context = torch.cat([context_clip, context], dim=1)
context_img_len = clip_fea.shape[-2]
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
# head
x = self.head(x, e)
if scail_pose_seq_len > 0:
x = x[:, :-scail_pose_seq_len]
# unpatchify
x = self.unpatchify(x, grid_sizes)
if reference_latent is not None:
x = x[:, :, reference_latent.shape[2]:]
return x
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, pose_latents=None, reference_latent=None, transformer_options={}):
main_freqs = super().rope_encode(t, h, w, t_start=t_start, steps_t=steps_t, steps_h=steps_h, steps_w=steps_w, device=device, dtype=dtype, transformer_options=transformer_options)
if pose_latents is None:
return main_freqs
ref_t_patches = 0
if reference_latent is not None:
ref_t_patches = (reference_latent.shape[2] + (self.patch_size[0] // 2)) // self.patch_size[0]
F_pose, H_pose, W_pose = pose_latents.shape[-3], pose_latents.shape[-2], pose_latents.shape[-1]
# if pose is at half resolution, scale_y/scale_x=2 stretches the position range to cover the same RoPE extent as the main frames
h_scale = h / H_pose
w_scale = w / W_pose
# 120 w-offset and shift 0.5 to place positions at midpoints (0.5, 2.5, ...) to match the original code
h_shift = (h_scale - 1) / 2
w_shift = (w_scale - 1) / 2
pose_transformer_options = {"rope_options": {"shift_y": h_shift, "shift_x": 120.0 + w_shift, "scale_y": h_scale, "scale_x": w_scale}}
pose_freqs = super().rope_encode(F_pose, H_pose, W_pose, t_start=t_start+ref_t_patches, device=device, dtype=dtype, transformer_options=pose_transformer_options)
return torch.cat([main_freqs, pose_freqs], dim=1)
def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, pose_latents=None, **kwargs):
bs, c, t, h, w = x.shape
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
if pose_latents is not None:
pose_latents = comfy.ldm.common_dit.pad_to_patch_size(pose_latents, self.patch_size)
t_len = t
if time_dim_concat is not None:
time_dim_concat = comfy.ldm.common_dit.pad_to_patch_size(time_dim_concat, self.patch_size)
x = torch.cat([x, time_dim_concat], dim=2)
t_len = x.shape[2]
reference_latent = None
if "reference_latent" in kwargs:
reference_latent = comfy.ldm.common_dit.pad_to_patch_size(kwargs.pop("reference_latent"), self.patch_size)
t_len += reference_latent.shape[2]
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent)
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, **kwargs)[:, :, :t, :h, :w]
+3 -2
View File
@@ -459,6 +459,7 @@ class WanVAE(nn.Module):
attn_scales=[],
temperal_downsample=[True, True, False],
image_channels=3,
conv_out_channels=3,
dropout=0.0):
super().__init__()
self.dim = dim
@@ -474,7 +475,7 @@ class WanVAE(nn.Module):
attn_scales, self.temperal_downsample, dropout)
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
self.decoder = Decoder3d(dim, z_dim, image_channels, dim_mult, num_res_blocks,
self.decoder = Decoder3d(dim, z_dim, conv_out_channels, dim_mult, num_res_blocks,
attn_scales, self.temperal_upsample, dropout)
def encode(self, x):
@@ -484,7 +485,7 @@ class WanVAE(nn.Module):
iter_ = 1 + (t - 1) // 4
feat_map = None
if iter_ > 1:
feat_map = [None] * count_conv3d(self.decoder)
feat_map = [None] * count_conv3d(self.encoder)
## 对encode输入的x,按时间拆分为1、4、4、4....
for i in range(iter_):
conv_idx = [0]
+1
View File
@@ -337,6 +337,7 @@ def model_lora_keys_unet(model, key_map={}):
if k.startswith("diffusion_model.decoder.") and k.endswith(".weight"):
key_lora = k[len("diffusion_model.decoder."):-len(".weight")]
key_map["base_model.model.{}".format(key_lora)] = k # Official base model loras
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k # LyCORIS/LoKR format
return key_map
+80 -1
View File
@@ -76,6 +76,7 @@ class ModelType(Enum):
FLUX = 8
IMG_TO_IMG = 9
FLOW_COSMOS = 10
IMG_TO_IMG_FLOW = 11
def model_sampling(model_config, model_type):
@@ -108,6 +109,8 @@ def model_sampling(model_config, model_type):
elif model_type == ModelType.FLOW_COSMOS:
c = comfy.model_sampling.COSMOS_RFLOW
s = comfy.model_sampling.ModelSamplingCosmosRFlow
elif model_type == ModelType.IMG_TO_IMG_FLOW:
c = comfy.model_sampling.IMG_TO_IMG_FLOW
class ModelSampling(s, c):
pass
@@ -922,6 +925,25 @@ class Flux(BaseModel):
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))])
return out
class LongCatImage(Flux):
def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
transformer_options = transformer_options.copy()
rope_opts = transformer_options.get("rope_options", {})
rope_opts = dict(rope_opts)
rope_opts.setdefault("shift_t", 1.0)
rope_opts.setdefault("shift_y", 512.0)
rope_opts.setdefault("shift_x", 512.0)
transformer_options["rope_options"] = rope_opts
return super()._apply_model(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs)
def encode_adm(self, **kwargs):
return None
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
out.pop('guidance', None)
return out
class Flux2(Flux):
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
@@ -971,6 +993,10 @@ class LTXV(BaseModel):
if keyframe_idxs is not None:
out['keyframe_idxs'] = comfy.conds.CONDRegular(keyframe_idxs)
guide_attention_entries = kwargs.get("guide_attention_entries", None)
if guide_attention_entries is not None:
out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries)
return out
def process_timestep(self, timestep, x, denoise_mask=None, **kwargs):
@@ -995,7 +1021,7 @@ class LTXAV(BaseModel):
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
if hasattr(self.diffusion_model, "preprocess_text_embeds"):
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference()))
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference()), unprocessed=kwargs.get("unprocessed_ltxav_embeds", False))
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25))
@@ -1023,6 +1049,10 @@ class LTXAV(BaseModel):
if latent_shapes is not None:
out['latent_shapes'] = comfy.conds.CONDConstant(latent_shapes)
guide_attention_entries = kwargs.get("guide_attention_entries", None)
if guide_attention_entries is not None:
out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries)
return out
def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs):
@@ -1233,6 +1263,11 @@ class Lumina2(BaseModel):
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))])
return out
class ZImagePixelSpace(Lumina2):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
BaseModel.__init__(self, model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiTPixelSpace)
self.memory_usage_factor_conds = ("ref_latents",)
class WAN21(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
@@ -1466,6 +1501,50 @@ class WAN22(WAN21):
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
return latent_image
class WAN21_FlowRVS(WAN21):
def __init__(self, model_config, model_type=ModelType.IMG_TO_IMG_FLOW, image_to_video=False, device=None):
model_config.unet_config["model_type"] = "t2v"
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
self.image_to_video = image_to_video
class WAN21_SCAIL(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.SCAILWanModel)
self.memory_usage_factor_conds = ("reference_latent", "pose_latents")
self.memory_usage_shape_process = {"pose_latents": lambda shape: [shape[0], shape[1], 1.5, shape[-2], shape[-1]]}
self.image_to_video = image_to_video
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
reference_latents = kwargs.get("reference_latents", None)
if reference_latents is not None:
ref_latent = self.process_latent_in(reference_latents[-1])
ref_mask = torch.ones_like(ref_latent[:, :4])
ref_latent = torch.cat([ref_latent, ref_mask], dim=1)
out['reference_latent'] = comfy.conds.CONDRegular(ref_latent)
pose_latents = kwargs.get("pose_video_latent", None)
if pose_latents is not None:
pose_latents = self.process_latent_in(pose_latents)
pose_mask = torch.ones_like(pose_latents[:, :4])
pose_latents = torch.cat([pose_latents, pose_mask], dim=1)
out['pose_latents'] = comfy.conds.CONDRegular(pose_latents)
return out
def extra_conds_shapes(self, **kwargs):
out = {}
ref_latents = kwargs.get("reference_latents", None)
if ref_latents is not None:
out['reference_latent'] = list([1, 20, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
pose_latents = kwargs.get("pose_video_latent", None)
if pose_latents is not None:
out['pose_latents'] = [pose_latents.shape[0], 20, *pose_latents.shape[2:]]
return out
class Hunyuan3Dv2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2)
+44 -4
View File
@@ -279,6 +279,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["txt_norm"] = any_suffix_in(state_dict_keys, key_prefix, 'txt_norm.', ["weight", "scale"])
if dit_config["yak_mlp"] and dit_config["txt_norm"]: # Ovis model
dit_config["txt_ids_dims"] = [1, 2]
if dit_config.get("context_in_dim") == 3584 and dit_config["vec_in_dim"] is None: # LongCat-Image
dit_config["txt_ids_dims"] = [1, 2]
return dit_config
@@ -421,7 +423,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["extra_per_block_abs_pos_emb_type"] = "learnable"
return dit_config
if '{}cap_embedder.1.weight'.format(key_prefix) in state_dict_keys: # Lumina 2
if '{}cap_embedder.1.weight'.format(key_prefix) in state_dict_keys and '{}noise_refiner.0.attention.k_norm.weight'.format(key_prefix) in state_dict_keys: # Lumina 2
dit_config = {}
dit_config["image_model"] = "lumina2"
dit_config["patch_size"] = 2
@@ -462,6 +464,29 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
if sig_weight is not None:
dit_config["siglip_feat_dim"] = sig_weight.shape[0]
dec_cond_key = '{}dec_net.cond_embed.weight'.format(key_prefix)
if dec_cond_key in state_dict_keys: # pixel-space variant
dit_config["image_model"] = "zimage_pixel"
# patch_size and in_channels are derived from x_embedder:
# x_embedder: Linear(patch_size * patch_size * in_channels, dim)
# The decoder also receives the full flat patch, so decoder_in_channels = x_embedder input dim.
x_emb_in = state_dict['{}x_embedder.weight'.format(key_prefix)].shape[1]
dec_out = state_dict['{}dec_net.final_layer.linear.weight'.format(key_prefix)].shape[0]
# patch_size: infer from decoder final layer output matching x_embedder input
# in_channels: infer from dec_net input_embedder (in_features = dec_in_ch + max_freqs^2)
embedder_w = state_dict['{}dec_net.input_embedder.embedder.0.weight'.format(key_prefix)]
dec_in_ch = dec_out # decoder in == decoder out (same pixel space)
dit_config["patch_size"] = round((x_emb_in / 3) ** 0.5) # assume RGB (in_channels=3)
dit_config["in_channels"] = 3
dit_config["decoder_in_channels"] = dec_in_ch
dit_config["decoder_hidden_size"] = state_dict[dec_cond_key].shape[0]
dit_config["decoder_num_res_blocks"] = count_blocks(
state_dict_keys, '{}dec_net.res_blocks.'.format(key_prefix) + '{}.'
)
dit_config["decoder_max_freqs"] = int((embedder_w.shape[1] - dec_in_ch) ** 0.5)
if '{}__x0__'.format(key_prefix) in state_dict_keys:
dit_config["use_x0"] = True
return dit_config
if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
@@ -496,6 +521,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["model_type"] = "humo"
elif '{}face_adapter.fuser_blocks.0.k_norm.weight'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "animate"
elif '{}patch_embedding_pose.weight'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "scail"
else:
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "i2v"
@@ -509,6 +536,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
if ref_conv_weight is not None:
dit_config["in_dim_ref_conv"] = ref_conv_weight.shape[1]
if metadata is not None and "config" in metadata:
dit_config.update(json.loads(metadata["config"]).get("transformer", {}))
return dit_config
if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D
@@ -526,8 +556,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
return dit_config
if f"{key_prefix}t_embedder.mlp.2.weight" in state_dict_keys: # Hunyuan 3D 2.1
if f"{key_prefix}t_embedder.mlp.2.weight" in state_dict_keys and f"{key_prefix}blocks.0.attn1.k_norm.weight" in state_dict_keys: # Hunyuan 3D 2.1
dit_config = {}
dit_config["image_model"] = "hunyuan3d2_1"
dit_config["in_channels"] = state_dict[f"{key_prefix}x_embedder.weight"].shape[1]
@@ -792,6 +821,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
unet_config["use_temporal_resblock"] = False
unet_config["use_temporal_attention"] = False
heatmap_key = '{}heatmap_head.conv_layers.0.weight'.format(key_prefix)
if heatmap_key in state_dict_keys:
unet_config["heatmap_head"] = True
return unet_config
def model_config_from_unet_config(unet_config, state_dict=None):
@@ -1012,7 +1045,7 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
LotusD = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': 4,
'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0],
'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, 'num_heads': 8,
'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, 'num_head_channels': 64,
'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
'use_temporal_attention': False, 'use_temporal_resblock': False}
@@ -1044,6 +1077,13 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""):
elif 'adaln_single.emb.timestep_embedder.linear_1.bias' in state_dict and 'pos_embed.proj.bias' in state_dict: # PixArt
num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
sd_map = comfy.utils.pixart_to_diffusers({"depth": num_blocks}, output_prefix=output_prefix)
elif 'noise_refiner.0.attention.norm_k.weight' in state_dict:
n_layers = count_blocks(state_dict, 'layers.{}.')
dim = state_dict['noise_refiner.0.attention.to_k.weight'].shape[0]
sd_map = comfy.utils.z_image_to_diffusers({"n_layers": n_layers, "dim": dim}, output_prefix=output_prefix)
for k in state_dict: # For zeta chroma
if k not in sd_map:
sd_map[k] = k
elif 'x_embedder.weight' in state_dict: #Flux
depth = count_blocks(state_dict, 'transformer_blocks.{}.')
depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.')
+29 -50
View File
@@ -32,9 +32,6 @@ import comfy.memory_management
import comfy.utils
import comfy.quant_ops
import comfy_aimdo.torch
import comfy_aimdo.model_vbar
class VRAMState(Enum):
DISABLED = 0 #No vram present: no need to move models to vram
NO_VRAM = 1 #Very low vram: enable all the options to save vram
@@ -180,6 +177,14 @@ def is_ixuca():
return True
return False
def is_wsl():
version = platform.uname().release
if version.endswith("-Microsoft"):
return True
elif version.endswith("microsoft-standard-WSL2"):
return True
return False
def get_torch_device():
global directml_enabled
global cpu_state
@@ -350,7 +355,7 @@ AMD_ENABLE_MIOPEN_ENV = 'COMFYUI_ENABLE_MIOPEN'
try:
if is_amd():
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName.split(':')[0]
if not (any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH)):
if os.getenv(AMD_ENABLE_MIOPEN_ENV) != '1':
torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD
@@ -378,7 +383,7 @@ try:
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
if aotriton_supported(arch): # AMD efficient attention implementation depends on aotriton.
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx950", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
ENABLE_PYTORCH_ATTENTION = True
if rocm_version >= (7, 0):
if any((a in arch) for a in ["gfx1200", "gfx1201"]):
@@ -631,12 +636,11 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
if not DISABLE_SMART_MEMORY:
memory_to_free = memory_required - get_free_memory(device)
ram_to_free = ram_required - get_free_ram()
if current_loaded_models[i].model.is_dynamic() and for_dynamic:
#don't actually unload dynamic models for the sake of other dynamic models
#as that works on-demand.
memory_required -= current_loaded_models[i].model.loaded_size()
memory_to_free = 0
if current_loaded_models[i].model.is_dynamic() and for_dynamic:
#don't actually unload dynamic models for the sake of other dynamic models
#as that works on-demand.
memory_required -= current_loaded_models[i].model.loaded_size()
memory_to_free = 0
if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free):
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
unloaded_model.append(i)
@@ -792,6 +796,8 @@ def archive_model_dtypes(model):
for name, module in model.named_modules():
for param_name, param in module.named_parameters(recurse=False):
setattr(module, f"{param_name}_comfy_model_dtype", param.dtype)
for buf_name, buf in module.named_buffers(recurse=False):
setattr(module, f"{buf_name}_comfy_model_dtype", buf.dtype)
def cleanup_models():
@@ -824,11 +830,14 @@ def unet_offload_device():
return torch.device("cpu")
def unet_inital_load_device(parameters, dtype):
cpu_dev = torch.device("cpu")
if comfy.memory_management.aimdo_enabled:
return cpu_dev
torch_dev = get_torch_device()
if vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.SHARED:
return torch_dev
cpu_dev = torch.device("cpu")
if DISABLE_SMART_MEMORY or vram_state == VRAMState.NO_VRAM:
return cpu_dev
@@ -836,7 +845,7 @@ def unet_inital_load_device(parameters, dtype):
mem_dev = get_free_memory(torch_dev)
mem_cpu = get_free_memory(cpu_dev)
if mem_dev > mem_cpu and model_size < mem_dev and comfy.memory_management.aimdo_enabled:
if mem_dev > mem_cpu and model_size < mem_dev:
return torch_dev
else:
return cpu_dev
@@ -939,6 +948,9 @@ def text_encoder_device():
return torch.device("cpu")
def text_encoder_initial_device(load_device, offload_device, model_size=0):
if comfy.memory_management.aimdo_enabled:
return offload_device
if load_device == offload_device or model_size <= 1024 * 1024 * 1024:
return offload_device
@@ -1199,43 +1211,6 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None):
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None):
if hasattr(weight, "_v"):
#Unexpected usage patterns. There is no reason these don't work but they
#have no testing and no callers do this.
assert r is None
assert stream is None
cast_geometry = comfy.memory_management.tensors_to_geometries([ weight ])
if dtype is None:
dtype = weight._model_dtype
signature = comfy_aimdo.model_vbar.vbar_fault(weight._v)
if signature is not None:
if comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
v_tensor = weight._v_tensor
else:
raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device)
v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0]
weight._v_tensor = v_tensor
weight._v_signature = signature
#Send it over
v_tensor.copy_(weight, non_blocking=non_blocking)
return v_tensor.to(dtype=dtype)
r = torch.empty_like(weight, dtype=dtype, device=device)
if weight.dtype != r.dtype and weight.dtype != weight._model_dtype:
#Offloaded casting could skip this, however it would make the quantizations
#inconsistent between loaded and offloaded weights. So force the double casting
#that would happen in regular flow to make offload deterministic.
cast_buffer = torch.empty_like(weight, dtype=weight._model_dtype, device=device)
cast_buffer.copy_(weight, non_blocking=non_blocking)
weight = cast_buffer
r.copy_(weight, non_blocking=non_blocking)
return r
if device is None or weight.device == device:
if not copy:
if dtype is None or weight.dtype == dtype:
@@ -1691,12 +1666,16 @@ def lora_compute_dtype(device):
return dtype
def synchronize():
if cpu_mode():
return
if is_intel_xpu():
torch.xpu.synchronize()
elif torch.cuda.is_available():
torch.cuda.synchronize()
def soft_empty_cache(force=False):
if cpu_mode():
return
global cpu_state
if cpu_state == CPUState.MPS:
torch.mps.empty_cache()
+83 -49
View File
@@ -241,6 +241,7 @@ class ModelPatcher:
self.patches = {}
self.backup = {}
self.backup_buffers = {}
self.object_patches = {}
self.object_patches_backup = {}
self.weight_wrapper_patches = {}
@@ -271,6 +272,7 @@ class ModelPatcher:
self.is_clip = False
self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed
self.cached_patcher_init: tuple[Callable, tuple] | None = None
if not hasattr(self.model, 'model_loaded_weight_memory'):
self.model.model_loaded_weight_memory = 0
@@ -305,10 +307,30 @@ class ModelPatcher:
return self.model.lowvram_patch_counter
def get_free_memory(self, device):
return comfy.model_management.get_free_memory(device)
#Prioritize batching (incl. CFG/conds etc) over keeping the model resident. In
#the vast majority of setups a little bit of offloading on the giant model more
#than pays for CFG. So return everything both torch and Aimdo could give us
aimdo_mem = 0
if comfy.memory_management.aimdo_enabled:
aimdo_mem = comfy_aimdo.model_vbar.vbars_analyze()
return comfy.model_management.get_free_memory(device) + aimdo_mem
def clone(self):
n = self.__class__(self.model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update)
def get_clone_model_override(self):
return self.model, (self.backup, self.backup_buffers, self.object_patches_backup, self.pinned)
def clone(self, disable_dynamic=False, model_override=None):
class_ = self.__class__
if self.is_dynamic() and disable_dynamic:
class_ = ModelPatcher
if model_override is None:
if self.cached_patcher_init is None:
raise RuntimeError("Cannot create non-dynamic delegate: cached_patcher_init is not initialized.")
temp_model_patcher = self.cached_patcher_init[0](*self.cached_patcher_init[1], disable_dynamic=True)
model_override = temp_model_patcher.get_clone_model_override()
if model_override is None:
model_override = self.get_clone_model_override()
n = class_(model_override[0], self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update)
n.patches = {}
for k in self.patches:
n.patches[k] = self.patches[k][:]
@@ -317,13 +339,12 @@ class ModelPatcher:
n.object_patches = self.object_patches.copy()
n.weight_wrapper_patches = self.weight_wrapper_patches.copy()
n.model_options = comfy.utils.deepcopy_list_dict(self.model_options)
n.backup = self.backup
n.object_patches_backup = self.object_patches_backup
n.parent = self
n.pinned = self.pinned
n.force_cast_weights = self.force_cast_weights
n.backup, n.backup_buffers, n.object_patches_backup, n.pinned = model_override[1]
# attachments
n.attachments = {}
for k in self.attachments:
@@ -362,6 +383,8 @@ class ModelPatcher:
n.is_clip = self.is_clip
n.hook_mode = self.hook_mode
n.cached_patcher_init = self.cached_patcher_init
for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE):
callback(self, n)
return n
@@ -682,7 +705,7 @@ class ModelPatcher:
for key in list(self.pinned):
self.unpin_weight(key)
def _load_list(self, prio_comfy_cast_weights=False, default_device=None):
def _load_list(self, for_dynamic=False, default_device=None):
loading = []
for n, m in self.model.named_modules():
default = False
@@ -710,8 +733,13 @@ class ModelPatcher:
return 0
module_offload_mem += check_module_offload_mem("{}.weight".format(n))
module_offload_mem += check_module_offload_mem("{}.bias".format(n))
prepend = (not hasattr(m, "comfy_cast_weights"),) if prio_comfy_cast_weights else ()
loading.append(prepend + (module_offload_mem, module_mem, n, m, params))
# Dynamic: small weights (<64KB) first, then larger weights prioritized by size.
# Non-dynamic: prioritize by module offload cost.
if for_dynamic:
sort_criteria = (module_offload_mem >= 64 * 1024, -module_offload_mem)
else:
sort_criteria = (module_offload_mem,)
loading.append(sort_criteria + (module_mem, n, m, params))
return loading
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
@@ -1419,12 +1447,9 @@ class ModelPatcherDynamic(ModelPatcher):
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
super().__init__(model, load_device, offload_device, size, weight_inplace_update)
#this is now way more dynamic and we dont support the same base model for both Dynamic
#and non-dynamic patchers.
if hasattr(self.model, "model_loaded_weight_memory"):
del self.model.model_loaded_weight_memory
if not hasattr(self.model, "dynamic_vbars"):
self.model.dynamic_vbars = {}
self.non_dynamic_delegate_model = None
assert load_device is not None
def is_dynamic(self):
@@ -1444,15 +1469,7 @@ class ModelPatcherDynamic(ModelPatcher):
def loaded_size(self):
vbar = self._vbar_get()
if vbar is None:
return 0
return vbar.loaded_size()
def get_free_memory(self, device):
#NOTE: on high condition / batch counts, estimate should have already vacated
#all non-dynamic models so this is safe even if its not 100% true that this
#would all be avaiable for inference use.
return comfy.model_management.get_total_memory(device) - self.model_size()
return (vbar.loaded_size() if vbar is not None else 0) + self.model.model_loaded_weight_memory
#Pinning is deferred to ops time. Assert against this API to avoid pin leaks.
@@ -1487,6 +1504,7 @@ class ModelPatcherDynamic(ModelPatcher):
num_patches = 0
allocated_size = 0
self.model.model_loaded_weight_memory = 0
with self.use_ejected():
self.unpatch_hooks()
@@ -1495,15 +1513,11 @@ class ModelPatcherDynamic(ModelPatcher):
if vbar is not None:
vbar.prioritize()
#We force reserve VRAM for the non comfy-weight so we dont have to deal
#with pin and unpin syncrhonization which can be expensive for small weights
#with a high layer rate (e.g. autoregressive LLMs).
#prioritize the non-comfy weights (note the order reverse).
loading = self._load_list(prio_comfy_cast_weights=True, default_device=device_to)
loading.sort(reverse=True)
loading = self._load_list(for_dynamic=True, default_device=device_to)
loading.sort()
for x in loading:
_, _, _, n, m, params = x
*_, module_mem, n, m, params = x
def set_dirty(item, dirty):
if dirty or not hasattr(item, "_v_signature"):
@@ -1541,6 +1555,9 @@ class ModelPatcherDynamic(ModelPatcher):
if key in self.backup:
comfy.utils.set_attr_param(self.model, key, self.backup[key].weight)
self.patch_weight_to_device(key, device_to=device_to)
weight, _, _ = get_key_weight(self.model, key)
if weight is not None:
self.model.model_loaded_weight_memory += weight.numel() * weight.element_size()
if hasattr(m, "comfy_cast_weights"):
m.comfy_cast_weights = True
@@ -1566,21 +1583,26 @@ class ModelPatcherDynamic(ModelPatcher):
for param in params:
key = key_param_name_to_key(n, param)
weight, _, _ = get_key_weight(self.model, key)
weight.seed_key = key
set_dirty(weight, dirty)
geometry = weight
model_dtype = getattr(m, param + "_comfy_model_dtype", None) or weight.dtype
geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
weight_size = geometry.numel() * geometry.element_size()
if vbar is not None and not hasattr(weight, "_v"):
weight._v = vbar.alloc(weight_size)
weight._model_dtype = model_dtype
allocated_size += weight_size
vbar.set_watermark_limit(allocated_size)
if key not in self.backup:
self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight, False)
model_dtype = getattr(m, param + "_comfy_model_dtype", None)
casted_weight = weight.to(dtype=model_dtype, device=device_to)
comfy.utils.set_attr_param(self.model, key, casted_weight)
self.model.model_loaded_weight_memory += casted_weight.numel() * casted_weight.element_size()
move_weight_functions(m, device_to)
logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.")
for key, buf in self.model.named_buffers(recurse=True):
if key not in self.backup_buffers:
self.backup_buffers[key] = buf
module, buf_name = comfy.utils.resolve_attr(self.model, key)
model_dtype = getattr(module, buf_name + "_comfy_model_dtype", None)
casted_buf = buf.to(dtype=model_dtype, device=device_to)
comfy.utils.set_attr_buffer(self.model, key, casted_buf)
self.model.model_loaded_weight_memory += casted_buf.numel() * casted_buf.element_size()
force_load_stat = f" Force pre-loaded {len(self.backup)} weights: {self.model.model_loaded_weight_memory // 1024} KB." if len(self.backup) > 0 else ""
logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.{force_load_stat}")
self.model.device = device_to
self.model.current_weight_patches_uuid = self.patches_uuid
@@ -1596,12 +1618,23 @@ class ModelPatcherDynamic(ModelPatcher):
assert self.load_device != torch.device("cpu")
vbar = self._vbar_get()
return 0 if vbar is None else vbar.free_memory(memory_to_free)
freed = 0 if vbar is None else vbar.free_memory(memory_to_free)
if freed < memory_to_free:
for key in list(self.backup.keys()):
bk = self.backup.pop(key)
comfy.utils.set_attr_param(self.model, key, bk.weight)
for key in list(self.backup_buffers.keys()):
comfy.utils.set_attr_buffer(self.model, key, self.backup_buffers.pop(key))
freed += self.model.model_loaded_weight_memory
self.model.model_loaded_weight_memory = 0
return freed
def partially_unload_ram(self, ram_to_unload):
loading = self._load_list(prio_comfy_cast_weights=True, default_device=self.offload_device)
loading = self._load_list(for_dynamic=True, default_device=self.offload_device)
for x in loading:
_, _, _, _, m, _ = x
*_, m, _ = x
ram_to_unload -= comfy.pinned_memory.unpin_memory(m)
if ram_to_unload <= 0:
return
@@ -1623,11 +1656,6 @@ class ModelPatcherDynamic(ModelPatcher):
for m in self.model.modules():
move_weight_functions(m, device_to)
keys = list(self.backup.keys())
for k in keys:
bk = self.backup[k]
comfy.utils.set_attr_param(self.model, k, bk.weight)
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
assert not force_patch_weights #See above
with self.use_ejected(skip_and_inject_on_exit_only=True):
@@ -1659,4 +1687,10 @@ class ModelPatcherDynamic(ModelPatcher):
def unpatch_hooks(self, whitelist_keys_set: set[str]=None) -> None:
pass
def get_non_dynamic_delegate(self):
model_patcher = self.clone(disable_dynamic=True, model_override=self.non_dynamic_delegate_model)
self.non_dynamic_delegate_model = model_patcher.get_clone_model_override()
return model_patcher
CoreModelPatcher = ModelPatcher
+10
View File
@@ -83,6 +83,16 @@ class IMG_TO_IMG(X0):
def calculate_input(self, sigma, noise):
return noise
class IMG_TO_IMG_FLOW(CONST):
def calculate_denoised(self, sigma, model_output, model_input):
return model_output
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
return latent_image
def inverse_noise_scaling(self, sigma, latent):
return 1.0 - latent
class COSMOS_RFLOW:
def calculate_input(self, sigma, noise):
sigma = (sigma / (sigma + 1))
+41 -21
View File
@@ -19,7 +19,7 @@
import torch
import logging
import comfy.model_management
from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram
from comfy.cli_args import args, PerformanceFeature
import comfy.float
import json
import comfy.memory_management
@@ -80,6 +80,21 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant):
#vbar doesn't support CPU weights, but some custom nodes have weird paths
#that might switch the layer to the CPU and expect it to work. We have to take
#a clone conservatively as we are mmapped and some SFT files are packed misaligned
#If you are a custom node author reading this, please move your layer to the GPU
#or declare your ModelPatcher as CPU in the first place.
if comfy.model_management.is_device_cpu(device):
weight = s.weight.to(dtype=dtype, copy=True)
if isinstance(weight, QuantizedTensor):
weight = weight.dequantize()
bias = None
if s.bias is not None:
bias = s.bias.to(dtype=bias_dtype, copy=True)
return weight, bias, (None, None, None)
offload_stream = None
xfer_dest = None
@@ -167,17 +182,15 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
x = to_dequant(x, dtype)
if not resident and lowvram_fn is not None:
x = to_dequant(x, dtype if compute_dtype is None else compute_dtype)
#FIXME: this is not accurate, we need to be sensitive to the compute dtype
x = lowvram_fn(x)
if (isinstance(orig, QuantizedTensor) and
(want_requant and len(fns) == 0 or update_weight)):
if (want_requant and len(fns) == 0 or update_weight):
seed = comfy.utils.string_to_seed(s.seed_key)
y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed)
if want_requant and len(fns) == 0:
#The layer actually wants our freshly saved QT
x = y
elif update_weight:
y = comfy.float.stochastic_rounding(x, orig.dtype, seed = comfy.utils.string_to_seed(s.seed_key))
if isinstance(orig, QuantizedTensor):
y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed)
else:
y = comfy.float.stochastic_rounding(x, orig.dtype, seed=seed)
if want_requant and len(fns) == 0:
x = y
if update_weight:
orig.copy_(y)
for f in fns:
@@ -271,8 +284,8 @@ def uncast_bias_weight(s, weight, bias, offload_stream):
return
os, weight_a, bias_a = offload_stream
device=None
#FIXME: This is not good RTTI
if not isinstance(weight_a, torch.Tensor):
#FIXME: This is really bad RTTI
if weight_a is not None and not isinstance(weight_a, torch.Tensor):
comfy_aimdo.model_vbar.vbar_unpin(s._v)
device = weight_a
if os is None:
@@ -296,7 +309,7 @@ class disable_weight_init:
class Linear(torch.nn.Linear, CastWeightBiasOp):
def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
if not comfy.model_management.WINDOWS or not enables_dynamic_vram():
if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
super().__init__(in_features, out_features, bias, device, dtype)
return
@@ -317,7 +330,7 @@ class disable_weight_init:
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys, error_msgs):
if not comfy.model_management.WINDOWS or not enables_dynamic_vram():
if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs)
assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
@@ -617,7 +630,8 @@ def fp8_linear(self, input):
if input.ndim != 2:
return None
w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
lora_compute_dtype=comfy.model_management.lora_compute_dtype(input.device)
w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True, compute_dtype=lora_compute_dtype, want_requant=True)
scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
@@ -661,23 +675,29 @@ class fp8_ops(manual_cast):
CUBLAS_IS_AVAILABLE = False
try:
from cublas_ops import CublasLinear
from cublas_ops import CublasLinear, cublas_half_matmul
CUBLAS_IS_AVAILABLE = True
except ImportError:
pass
if CUBLAS_IS_AVAILABLE:
class cublas_ops(disable_weight_init):
class Linear(CublasLinear, disable_weight_init.Linear):
class cublas_ops(manual_cast):
class Linear(CublasLinear, manual_cast.Linear):
def reset_parameters(self):
return None
def forward_comfy_cast_weights(self, input):
return super().forward(input)
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = cublas_half_matmul(input, weight, bias, self._epilogue_str, self.has_bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
return super().forward(*args, **kwargs)
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
# ==============================================================================
# Mixed Precision Operations
+12
View File
@@ -66,6 +66,18 @@ def convert_cond(cond):
out.append(temp)
return out
def cond_has_hooks(cond):
for c in cond:
temp = c[1]
if "hooks" in temp:
return True
if "control" in temp:
control = temp["control"]
extra_hooks = control.get_extra_hooks()
if len(extra_hooks) > 0:
return True
return False
def get_additional_models(conds, dtype):
"""loads additional models in conditioning"""
cnets: list[ControlBase] = []
+2
View File
@@ -946,6 +946,8 @@ class CFGGuider:
def inner_set_conds(self, conds):
for k in conds:
if self.model_patcher.is_dynamic() and comfy.sampler_helpers.cond_has_hooks(conds[k]):
self.model_patcher = self.model_patcher.get_non_dynamic_delegate()
self.original_conds[k] = comfy.sampler_helpers.convert_cond(conds[k])
def __call__(self, *args, **kwargs):
+57 -20
View File
@@ -60,6 +60,7 @@ import comfy.text_encoders.jina_clip_2
import comfy.text_encoders.newbie
import comfy.text_encoders.anima
import comfy.text_encoders.ace15
import comfy.text_encoders.longcat_image
import comfy.model_patcher
import comfy.lora
@@ -203,7 +204,7 @@ def load_bypass_lora_for_models(model, clip, lora, strength_model, strength_clip
class CLIP:
def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, state_dict=[], model_options={}):
def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, state_dict=[], model_options={}, disable_dynamic=False):
if no_init:
return
params = target.params.copy()
@@ -232,7 +233,8 @@ class CLIP:
model_management.archive_model_dtypes(self.cond_stage_model)
self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.patcher = comfy.model_patcher.CoreModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
ModelPatcher = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher
self.patcher = ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
#Match torch.float32 hardcode upcast in TE implemention
self.patcher.set_model_compute_dtype(torch.float32)
self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
@@ -266,9 +268,9 @@ class CLIP:
logging.info("CLIP/text encoder model load device: {}, offload device: {}, current: {}, dtype: {}".format(load_device, offload_device, params['device'], dtype))
self.tokenizer_options = {}
def clone(self):
def clone(self, disable_dynamic=False):
n = CLIP(no_init=True)
n.patcher = self.patcher.clone()
n.patcher = self.patcher.clone(disable_dynamic=disable_dynamic)
n.cond_stage_model = self.cond_stage_model
n.tokenizer = self.tokenizer
n.layer_idx = self.layer_idx
@@ -426,7 +428,7 @@ class CLIP:
def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None):
self.cond_stage_model.reset_clip_options()
self.load_model()
self.load_model(tokens)
self.cond_stage_model.set_clip_options({"layer": None})
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed)
@@ -694,8 +696,9 @@ class VAE:
self.latent_dim = 3
self.latent_channels = 16
self.output_channels = sd["encoder.conv1.weight"].shape[1]
self.conv_out_channels = sd["decoder.head.2.weight"].shape[0]
self.pad_channel_value = 1.0
ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "image_channels": self.output_channels, "dropout": 0.0}
ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "image_channels": self.output_channels, "conv_out_channels": self.conv_out_channels, "dropout": 0.0}
self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig)
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
self.memory_used_encode = lambda shape, dtype: (1500 if shape[2]<=4 else 6000) * shape[3] * shape[4] * model_management.dtype_size(dtype)
@@ -1159,16 +1162,24 @@ class CLIPType(Enum):
KANDINSKY5_IMAGE = 23
NEWBIE = 24
FLUX2 = 25
LONGCAT_IMAGE = 26
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
def load_clip_model_patcher(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}, disable_dynamic=False):
clip = load_clip(ckpt_paths, embedding_directory, clip_type, model_options, disable_dynamic)
return clip.patcher
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}, disable_dynamic=False):
clip_data = []
for p in ckpt_paths:
sd, metadata = comfy.utils.load_torch_file(p, safe_load=True, return_metadata=True)
if model_options.get("custom_operations", None) is None:
sd, metadata = comfy.utils.convert_old_quants(sd, model_prefix="", metadata=metadata)
clip_data.append(sd)
return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options)
clip = load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options, disable_dynamic=disable_dynamic)
clip.patcher.cached_patcher_init = (load_clip_model_patcher, (ckpt_paths, embedding_directory, clip_type, model_options))
return clip
class TEModel(Enum):
@@ -1273,7 +1284,7 @@ def llama_detect(clip_data):
return {}
def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}, disable_dynamic=False):
clip_data = state_dicts
class EmptyClass:
@@ -1371,6 +1382,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
if clip_type == CLIPType.HUNYUAN_IMAGE:
clip_target.clip = comfy.text_encoders.hunyuan_image.te(byt5=False, **llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer
elif clip_type == CLIPType.LONGCAT_IMAGE:
clip_target.clip = comfy.text_encoders.longcat_image.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.longcat_image.LongCatImageTokenizer
else:
clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer
@@ -1453,7 +1467,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage
elif clip_type == CLIPType.LTXV:
clip_target.clip = comfy.text_encoders.lt.ltxav_te(**llama_detect(clip_data))
clip_target.clip = comfy.text_encoders.lt.ltxav_te(**llama_detect(clip_data), **comfy.text_encoders.lt.sd_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.lt.LTXAVGemmaTokenizer
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
elif clip_type == CLIPType.NEWBIE:
@@ -1490,7 +1504,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
parameters += comfy.utils.calculate_parameters(c)
tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options)
clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, state_dict=clip_data, model_options=model_options)
clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, state_dict=clip_data, model_options=model_options, disable_dynamic=disable_dynamic)
return clip
def load_gligen(ckpt_path):
@@ -1530,14 +1544,34 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
return (model, clip, vae)
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}):
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, disable_dynamic=False):
sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata)
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic)
if out is None:
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd)))
if output_model and out[0] is not None:
out[0].cached_patcher_init = (load_checkpoint_guess_config_model_only, (ckpt_path, embedding_directory, model_options, te_model_options))
if output_clip and out[1] is not None:
out[1].patcher.cached_patcher_init = (load_checkpoint_guess_config_clip_only, (ckpt_path, embedding_directory, model_options, te_model_options))
return out
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None):
def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
model, *_ = load_checkpoint_guess_config(ckpt_path, False, False, False,
embedding_directory=embedding_directory,
model_options=model_options,
te_model_options=te_model_options,
disable_dynamic=disable_dynamic)
return model
def load_checkpoint_guess_config_clip_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
_, clip, *_ = load_checkpoint_guess_config(ckpt_path, False, True, False,
embedding_directory=embedding_directory, output_model=False,
model_options=model_options,
te_model_options=te_model_options,
disable_dynamic=disable_dynamic)
return clip.patcher
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None, disable_dynamic=False):
clip = None
clipvision = None
vae = None
@@ -1586,7 +1620,8 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
if output_model:
inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
ModelPatcher = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher
model_patcher = ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
model.load_model_weights(sd, diffusion_model_prefix, assign=model_patcher.is_dynamic())
if output_vae:
@@ -1621,7 +1656,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
clip_sd = model_config.process_clip_state_dict(sd)
if len(clip_sd) > 0:
parameters = comfy.utils.calculate_parameters(clip_sd)
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, state_dict=clip_sd, model_options=te_model_options)
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, state_dict=clip_sd, model_options=te_model_options, disable_dynamic=disable_dynamic)
else:
logging.warning("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.")
@@ -1637,7 +1672,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
return (model_patcher, clip, vae, clipvision)
def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable_dynamic=False):
"""
Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats.
@@ -1721,7 +1756,8 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
model_config.optimizations["fp8"] = True
model = model_config.get_model(new_sd, "")
model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=load_device, offload_device=offload_device)
ModelPatcher = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher
model_patcher = ModelPatcher(model, load_device=load_device, offload_device=offload_device)
if not model_management.is_device_cpu(offload_device):
model.to(offload_device)
model.load_model_weights(new_sd, "", assign=model_patcher.is_dynamic())
@@ -1730,12 +1766,13 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
logging.info("left over keys in diffusion model: {}".format(left_over))
return model_patcher
def load_diffusion_model(unet_path, model_options={}):
def load_diffusion_model(unet_path, model_options={}, disable_dynamic=False):
sd, metadata = comfy.utils.load_torch_file(unet_path, return_metadata=True)
model = load_diffusion_model_state_dict(sd, model_options=model_options, metadata=metadata)
model = load_diffusion_model_state_dict(sd, model_options=model_options, metadata=metadata, disable_dynamic=disable_dynamic)
if model is None:
logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path))
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd)))
model.cached_patcher_init = (load_diffusion_model, (unet_path, model_options))
return model
def load_unet(unet_path, dtype=None):
+69 -2
View File
@@ -25,6 +25,7 @@ import comfy.text_encoders.kandinsky5
import comfy.text_encoders.z_image
import comfy.text_encoders.anima
import comfy.text_encoders.ace15
import comfy.text_encoders.longcat_image
from . import supported_models_base
from . import latent_formats
@@ -525,7 +526,8 @@ class LotusD(SD20):
}
unet_extra_config = {
"num_classes": 'sequential'
"num_classes": 'sequential',
"num_head_channels": 64,
}
def get_model(self, state_dict, prefix="", device=None):
@@ -1116,6 +1118,20 @@ class ZImage(Lumina2):
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.z_image.ZImageTokenizer, comfy.text_encoders.z_image.te(**hunyuan_detect))
class ZImagePixelSpace(ZImage):
unet_config = {
"image_model": "zimage_pixel",
}
# Pixel-space model: no spatial compression, operates on raw RGB patches.
latent_format = latent_formats.ZImagePixelSpace
# Much lower memory than latent-space models (no VAE, small patches).
memory_usage_factor = 0.03 # TODO: figure out the optimal value for this.
def get_model(self, state_dict, prefix="", device=None):
return model_base.ZImagePixelSpace(self, device=device)
class WAN21_T2V(supported_models_base.BASE):
unet_config = {
"image_model": "wan2.1",
@@ -1256,6 +1272,26 @@ class WAN22_T2V(WAN21_T2V):
out = model_base.WAN22(self, image_to_video=True, device=device)
return out
class WAN21_FlowRVS(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
"model_type": "flow_rvs",
}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.WAN21_FlowRVS(self, image_to_video=True, device=device)
return out
class WAN21_SCAIL(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
"model_type": "scail",
}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.WAN21_SCAIL(self, image_to_video=False, device=device)
return out
class Hunyuan3Dv2(supported_models_base.BASE):
unet_config = {
"image_model": "hunyuan3d2",
@@ -1667,6 +1703,37 @@ class ACEStep15(supported_models_base.BASE):
return supported_models_base.ClipTarget(comfy.text_encoders.ace15.ACE15Tokenizer, comfy.text_encoders.ace15.te(**detect))
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
class LongCatImage(supported_models_base.BASE):
unet_config = {
"image_model": "flux",
"guidance_embed": False,
"vec_in_dim": None,
"context_in_dim": 3584,
"txt_ids_dims": [1, 2],
}
sampling_settings = {
}
unet_extra_config = {}
latent_format = latent_formats.Flux
memory_usage_factor = 2.5
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.LongCatImage(self, device=device)
return out
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.longcat_image.LongCatImageTokenizer, comfy.text_encoders.longcat_image.te(**hunyuan_detect))
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
models += [SVD_img2vid]
+2 -2
View File
@@ -328,14 +328,14 @@ class ACE15TEModel(torch.nn.Module):
return getattr(self, self.lm_model).load_sd(sd)
def memory_estimation_function(self, token_weight_pairs, device=None):
lm_metadata = token_weight_pairs["lm_metadata"]
lm_metadata = token_weight_pairs.get("lm_metadata", {})
constant = self.constant
if comfy.model_management.should_use_bf16(device):
constant *= 0.5
token_weight_pairs = token_weight_pairs.get("lm_prompt", [])
num_tokens = sum(map(lambda a: len(a), token_weight_pairs))
num_tokens += lm_metadata['min_tokens']
num_tokens += lm_metadata.get("min_tokens", 0)
return num_tokens * constant * 1024 * 1024
def te(dtype_llama=None, llama_quantization_metadata=None, lm_model="qwen3_2b"):
+184
View File
@@ -0,0 +1,184 @@
import re
import numbers
import torch
from comfy import sd1_clip
from comfy.text_encoders.qwen_image import Qwen25_7BVLITokenizer, Qwen25_7BVLIModel
import logging
logger = logging.getLogger(__name__)
QUOTE_PAIRS = [("'", "'"), ('"', '"'), ("\u2018", "\u2019"), ("\u201c", "\u201d")]
QUOTE_PATTERN = "|".join(
[
re.escape(q1) + r"[^" + re.escape(q1 + q2) + r"]*?" + re.escape(q2)
for q1, q2 in QUOTE_PAIRS
]
)
WORD_INTERNAL_QUOTE_RE = re.compile(r"[a-zA-Z]+'[a-zA-Z]+")
def split_quotation(prompt):
matches = WORD_INTERNAL_QUOTE_RE.findall(prompt)
mapping = []
for i, word_src in enumerate(set(matches)):
word_tgt = "longcat_$##$_longcat" * (i + 1)
prompt = prompt.replace(word_src, word_tgt)
mapping.append((word_src, word_tgt))
parts = re.split(f"({QUOTE_PATTERN})", prompt)
result = []
for part in parts:
for word_src, word_tgt in mapping:
part = part.replace(word_tgt, word_src)
if not part:
continue
is_quoted = bool(re.match(QUOTE_PATTERN, part))
result.append((part, is_quoted))
return result
class LongCatImageBaseTokenizer(Qwen25_7BVLITokenizer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.max_length = 512
def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
parts = split_quotation(text)
all_tokens = []
for part_text, is_quoted in parts:
if is_quoted:
for char in part_text:
ids = self.tokenizer(char, add_special_tokens=False)["input_ids"]
all_tokens.extend(ids)
else:
ids = self.tokenizer(part_text, add_special_tokens=False)["input_ids"]
all_tokens.extend(ids)
if len(all_tokens) > self.max_length:
all_tokens = all_tokens[: self.max_length]
logger.warning(f"Truncated prompt to {self.max_length} tokens")
output = [(t, 1.0) for t in all_tokens]
# Pad to max length
self.pad_tokens(output, self.max_length - len(output))
return [output]
class LongCatImageTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(
embedding_directory=embedding_directory,
tokenizer_data=tokenizer_data,
name="qwen25_7b",
tokenizer=LongCatImageBaseTokenizer,
)
self.longcat_template_prefix = "<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt based on an image content, suitable for input to a text-to-image model.<|im_end|>\n<|im_start|>user\n"
self.longcat_template_suffix = "<|im_end|>\n<|im_start|>assistant\n"
def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
skip_template = False
if text.startswith("<|im_start|>"):
skip_template = True
if text.startswith("<|start_header_id|>"):
skip_template = True
if text == "":
text = " "
base_tok = getattr(self, "qwen25_7b")
if skip_template:
tokens = super().tokenize_with_weights(
text, return_word_ids=return_word_ids, disable_weights=True, **kwargs
)
else:
prefix_ids = base_tok.tokenizer(
self.longcat_template_prefix, add_special_tokens=False
)["input_ids"]
suffix_ids = base_tok.tokenizer(
self.longcat_template_suffix, add_special_tokens=False
)["input_ids"]
prompt_tokens = base_tok.tokenize_with_weights(
text, return_word_ids=return_word_ids, **kwargs
)
prompt_pairs = prompt_tokens[0]
prefix_pairs = [(t, 1.0) for t in prefix_ids]
suffix_pairs = [(t, 1.0) for t in suffix_ids]
combined = prefix_pairs + prompt_pairs + suffix_pairs
tokens = {"qwen25_7b": [combined]}
return tokens
class LongCatImageTEModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(
device=device,
dtype=dtype,
name="qwen25_7b",
clip_model=Qwen25_7BVLIModel,
model_options=model_options,
)
def encode_token_weights(self, token_weight_pairs, template_end=-1):
out, pooled, extra = super().encode_token_weights(token_weight_pairs)
tok_pairs = token_weight_pairs["qwen25_7b"][0]
count_im_start = 0
if template_end == -1:
for i, v in enumerate(tok_pairs):
elem = v[0]
if not torch.is_tensor(elem):
if isinstance(elem, numbers.Integral):
if elem == 151644 and count_im_start < 2:
template_end = i
count_im_start += 1
if out.shape[1] > (template_end + 3):
if tok_pairs[template_end + 1][0] == 872:
if tok_pairs[template_end + 2][0] == 198:
template_end += 3
if template_end == -1:
template_end = 0
suffix_start = None
for i in range(len(tok_pairs) - 1, -1, -1):
elem = tok_pairs[i][0]
if not torch.is_tensor(elem) and isinstance(elem, numbers.Integral):
if elem == 151645:
suffix_start = i
break
out = out[:, template_end:]
if "attention_mask" in extra:
extra["attention_mask"] = extra["attention_mask"][:, template_end:]
if extra["attention_mask"].sum() == torch.numel(extra["attention_mask"]):
extra.pop("attention_mask")
if suffix_start is not None:
suffix_len = len(tok_pairs) - suffix_start
if suffix_len > 0 and out.shape[1] > suffix_len:
out = out[:, :-suffix_len]
if "attention_mask" in extra:
extra["attention_mask"] = extra["attention_mask"][:, :-suffix_len]
if extra["attention_mask"].sum() == torch.numel(
extra["attention_mask"]
):
extra.pop("attention_mask")
return out, pooled, extra
def te(dtype_llama=None, llama_quantization_metadata=None):
class LongCatImageTEModel_(LongCatImageTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["quantization_metadata"] = llama_quantization_metadata
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(device=device, dtype=dtype, model_options=model_options)
return LongCatImageTEModel_
+59 -18
View File
@@ -6,6 +6,7 @@ import comfy.text_encoders.genmo
import torch
import comfy.utils
import math
import itertools
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
@@ -72,7 +73,7 @@ class Gemma3_12BTokenizer(Gemma3_Tokenizer, sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = tokenizer_data.get("spiece_model", None)
special_tokens = {"<image_soft_token>": 262144, "<end_of_turn>": 106}
super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_left=True, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False, "special_tokens": special_tokens}, tokenizer_data=tokenizer_data)
super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1024, pad_left=True, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False, "special_tokens": special_tokens}, tokenizer_data=tokenizer_data)
class LTXAVGemmaTokenizer(sd1_clip.SD1Tokenizer):
@@ -96,18 +97,39 @@ class Gemma3_12BModel(sd1_clip.SDClipModel):
comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106]) # 106 is <end_of_turn>
class DualLinearProjection(torch.nn.Module):
def __init__(self, in_dim, out_dim_video, out_dim_audio, dtype=None, device=None, operations=None):
super().__init__()
self.audio_aggregate_embed = operations.Linear(in_dim, out_dim_audio, bias=True, dtype=dtype, device=device)
self.video_aggregate_embed = operations.Linear(in_dim, out_dim_video, bias=True, dtype=dtype, device=device)
def forward(self, x):
source_dim = x.shape[-1]
x = x.movedim(1, -1)
x = (x * torch.rsqrt(torch.mean(x**2, dim=2, keepdim=True) + 1e-6)).flatten(start_dim=2)
video = self.video_aggregate_embed(x * math.sqrt(self.video_aggregate_embed.out_features / source_dim))
audio = self.audio_aggregate_embed(x * math.sqrt(self.audio_aggregate_embed.out_features / source_dim))
return torch.cat((video, audio), dim=-1)
class LTXAVTEModel(torch.nn.Module):
def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}):
def __init__(self, dtype_llama=None, device="cpu", dtype=None, text_projection_type="single_linear", model_options={}):
super().__init__()
self.dtypes = set()
self.dtypes.add(dtype)
self.compat_mode = False
self.text_projection_type = text_projection_type
self.gemma3_12b = Gemma3_12BModel(device=device, dtype=dtype_llama, model_options=model_options, layer="all", layer_idx=None)
self.dtypes.add(dtype_llama)
operations = self.gemma3_12b.operations # TODO
self.text_embedding_projection = operations.Linear(3840 * 49, 3840, bias=False, dtype=dtype, device=device)
if self.text_projection_type == "single_linear":
self.text_embedding_projection = operations.Linear(3840 * 49, 3840, bias=False, dtype=dtype, device=device)
elif self.text_projection_type == "dual_linear":
self.text_embedding_projection = DualLinearProjection(3840 * 49, 4096, 2048, dtype=dtype, device=device, operations=operations)
def enable_compat_mode(self): # TODO: remove
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
@@ -147,18 +169,25 @@ class LTXAVTEModel(torch.nn.Module):
out_device = out.device
if comfy.model_management.should_use_bf16(self.execution_device):
out = out.to(device=self.execution_device, dtype=torch.bfloat16)
out = out.movedim(1, -1).to(self.execution_device)
out = 8.0 * (out - out.mean(dim=(1, 2), keepdim=True)) / (out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True) + 1e-6)
out = out.reshape((out.shape[0], out.shape[1], -1))
out = self.text_embedding_projection(out)
out = out.float()
if self.compat_mode:
out_vid = self.video_embeddings_connector(out)[0]
out_audio = self.audio_embeddings_connector(out)[0]
out = torch.concat((out_vid, out_audio), dim=-1)
if self.text_projection_type == "single_linear":
out = out.movedim(1, -1).to(self.execution_device)
out = 8.0 * (out - out.mean(dim=(1, 2), keepdim=True)) / (out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True) + 1e-6)
out = out.reshape((out.shape[0], out.shape[1], -1))
out = self.text_embedding_projection(out)
return out.to(out_device), pooled
if self.compat_mode:
out_vid = self.video_embeddings_connector(out)[0]
out_audio = self.audio_embeddings_connector(out)[0]
out = torch.concat((out_vid, out_audio), dim=-1)
extra = {}
else:
extra = {"unprocessed_ltxav_embeds": True}
elif self.text_projection_type == "dual_linear":
out = self.text_embedding_projection(out)
extra = {"unprocessed_ltxav_embeds": True}
return out.to(device=out_device, dtype=torch.float), pooled, extra
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
return self.gemma3_12b.generate(tokens["gemma3_12b"], do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed)
@@ -167,7 +196,7 @@ class LTXAVTEModel(torch.nn.Module):
if "model.layers.47.self_attn.q_norm.weight" in sd:
return self.gemma3_12b.load_sd(sd)
else:
sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight"}, filter_keys=True)
sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight", "text_embedding_projection.": "text_embedding_projection."}, filter_keys=True)
if len(sdo) == 0:
sdo = sd
@@ -199,11 +228,13 @@ class LTXAVTEModel(torch.nn.Module):
constant /= 2.0
token_weight_pairs = token_weight_pairs.get("gemma3_12b", [])
num_tokens = sum(map(lambda a: len(a), token_weight_pairs))
num_tokens = max(num_tokens, 64)
m = min([sum(1 for _ in itertools.takewhile(lambda x: x[0] == 0, sub)) for sub in token_weight_pairs])
num_tokens = sum(map(lambda a: len(a), token_weight_pairs)) - m
num_tokens = max(num_tokens, 642)
return num_tokens * constant * 1024 * 1024
def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
def ltxav_te(dtype_llama=None, llama_quantization_metadata=None, text_projection_type="single_linear"):
class LTXAVTEModel_(LTXAVTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_quantization_metadata is not None:
@@ -211,9 +242,19 @@ def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
model_options["llama_quantization_metadata"] = llama_quantization_metadata
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, text_projection_type=text_projection_type, model_options=model_options)
return LTXAVTEModel_
def sd_detect(state_dict_list, prefix=""):
for sd in state_dict_list:
if "{}text_embedding_projection.audio_aggregate_embed.bias".format(prefix) in sd:
return {"text_projection_type": "dual_linear"}
if "{}text_embedding_projection.weight".format(prefix) in sd or "{}text_embedding_projection.aggregate_embed.weight".format(prefix) in sd:
return {"text_projection_type": "single_linear"}
return {}
def gemma3_te(dtype_llama=None, llama_quantization_metadata=None):
class Gemma3_12BModel_(Gemma3_12BModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
+17 -6
View File
@@ -29,7 +29,7 @@ import itertools
from torch.nn.functional import interpolate
from tqdm.auto import trange
from einops import rearrange
from comfy.cli_args import args, enables_dynamic_vram
from comfy.cli_args import args
import json
import time
import mmap
@@ -113,7 +113,7 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
metadata = None
if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
try:
if enables_dynamic_vram():
if comfy.memory_management.aimdo_enabled:
sd, metadata = load_safetensors(ckpt)
if not return_metadata:
metadata = None
@@ -869,20 +869,31 @@ def safetensors_header(safetensors_path, max_size=100*1024*1024):
ATTR_UNSET={}
def set_attr(obj, attr, value):
def resolve_attr(obj, attr):
attrs = attr.split(".")
for name in attrs[:-1]:
obj = getattr(obj, name)
prev = getattr(obj, attrs[-1], ATTR_UNSET)
return obj, attrs[-1]
def set_attr(obj, attr, value):
obj, name = resolve_attr(obj, attr)
prev = getattr(obj, name, ATTR_UNSET)
if value is ATTR_UNSET:
delattr(obj, attrs[-1])
delattr(obj, name)
else:
setattr(obj, attrs[-1], value)
setattr(obj, name, value)
return prev
def set_attr_param(obj, attr, value):
return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False))
def set_attr_buffer(obj, attr, value):
obj, name = resolve_attr(obj, attr)
prev = getattr(obj, name, ATTR_UNSET)
persistent = name not in getattr(obj, "_non_persistent_buffers_set", set())
obj.register_buffer(name, value, persistent=persistent)
return prev
def copy_to_param(obj, attr, value):
# inplace update tensor instead of replacing it
attrs = attr.split(".")
@@ -401,6 +401,7 @@ class VideoFromComponents(VideoInput):
codec: VideoCodec = VideoCodec.AUTO,
metadata: Optional[dict] = None,
):
"""Save the video to a file path or BytesIO buffer."""
if format != VideoContainer.AUTO and format != VideoContainer.MP4:
raise ValueError("Only MP4 format is supported for now")
if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
@@ -408,6 +409,10 @@ class VideoFromComponents(VideoInput):
extra_kwargs = {}
if isinstance(format, VideoContainer) and format != VideoContainer.AUTO:
extra_kwargs["format"] = format.value
elif isinstance(path, io.BytesIO):
# BytesIO has no file extension, so av.open can't infer the format.
# Default to mp4 since that's the only supported format anyway.
extra_kwargs["format"] = "mp4"
with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}, **extra_kwargs) as output:
# Add metadata before writing any streams
if metadata is not None:
+18 -1
View File
@@ -1224,9 +1224,10 @@ class BoundingBox(ComfyTypeIO):
class Input(WidgetInput):
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
socketless: bool=True, default: dict=None, component: str=None):
socketless: bool=True, default: dict=None, component: str=None, force_input: bool=None):
super().__init__(id, display_name, optional, tooltip, None, default, socketless)
self.component = component
self.force_input = force_input
if default is None:
self.default = {"x": 0, "y": 0, "width": 512, "height": 512}
@@ -1234,9 +1235,24 @@ class BoundingBox(ComfyTypeIO):
d = super().as_dict()
if self.component:
d["component"] = self.component
if self.force_input is not None:
d["forceInput"] = self.force_input
return d
@comfytype(io_type="CURVE")
class Curve(ComfyTypeIO):
CurvePoint = tuple[float, float]
Type = list[CurvePoint]
class Input(WidgetInput):
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
socketless: bool=True, default: list[tuple[float, float]]=None, advanced: bool=None):
super().__init__(id, display_name, optional, tooltip, None, default, socketless, None, None, None, None, advanced)
if default is None:
self.default = [(0.0, 0.0), (1.0, 1.0)]
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
DYNAMIC_INPUT_LOOKUP[io_type] = func
@@ -2223,5 +2239,6 @@ __all__ = [
"PriceBadgeDepends",
"PriceBadge",
"BoundingBox",
"Curve",
"NodeReplace",
]
+6
View File
@@ -127,9 +127,15 @@ class GeminiImageConfig(BaseModel):
imageOutputOptions: GeminiImageOutputOptions = Field(default_factory=GeminiImageOutputOptions)
class GeminiThinkingConfig(BaseModel):
includeThoughts: bool | None = Field(None)
thinkingLevel: str = Field(...)
class GeminiImageGenerationConfig(GeminiGenerationConfig):
responseModalities: list[str] | None = Field(None)
imageConfig: GeminiImageConfig | None = Field(None)
thinkingConfig: GeminiThinkingConfig | None = Field(None)
class GeminiImageGenerateContentRequest(BaseModel):
+11 -3
View File
@@ -7,7 +7,8 @@ class ImageGenerationRequest(BaseModel):
aspect_ratio: str = Field(...)
n: int = Field(...)
seed: int = Field(...)
response_for: str = Field("url")
response_format: str = Field("url")
resolution: str = Field(...)
class InputUrlObject(BaseModel):
@@ -16,12 +17,13 @@ class InputUrlObject(BaseModel):
class ImageEditRequest(BaseModel):
model: str = Field(...)
image: InputUrlObject = Field(...)
images: list[InputUrlObject] = Field(...)
prompt: str = Field(...)
resolution: str = Field(...)
n: int = Field(...)
seed: int = Field(...)
response_for: str = Field("url")
response_format: str = Field("url")
aspect_ratio: str | None = Field(...)
class VideoGenerationRequest(BaseModel):
@@ -47,8 +49,13 @@ class ImageResponseObject(BaseModel):
revised_prompt: str | None = Field(None)
class UsageObject(BaseModel):
cost_in_usd_ticks: int | None = Field(None)
class ImageGenerationResponse(BaseModel):
data: list[ImageResponseObject] = Field(...)
usage: UsageObject | None = Field(None)
class VideoGenerationResponse(BaseModel):
@@ -65,3 +72,4 @@ class VideoStatusResponse(BaseModel):
status: str | None = Field(None)
video: VideoResponseObject | None = Field(None)
model: str | None = Field(None)
usage: UsageObject | None = Field(None)
+1
View File
@@ -148,3 +148,4 @@ class MotionControlRequest(BaseModel):
keep_original_sound: str = Field(...)
character_orientation: str = Field(...)
mode: str = Field(..., description="'pro' or 'std'")
model_name: str = Field(...)
+1 -1
View File
@@ -186,7 +186,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="ByteDanceSeedreamNode",
display_name="ByteDance Seedream 5.0",
display_name="ByteDance Seedream 4.5 & 5.0",
category="api node/image/ByteDance",
description="Unified text-to-image generation and precise single-sentence editing at up to 4K resolution.",
inputs=[
+188 -14
View File
@@ -29,6 +29,7 @@ from comfy_api_nodes.apis.gemini import (
GeminiRole,
GeminiSystemInstructionContent,
GeminiTextPart,
GeminiThinkingConfig,
Modality,
)
from comfy_api_nodes.util import (
@@ -55,6 +56,21 @@ GEMINI_IMAGE_SYS_PROMPT = (
"Prioritize generating the visual representation above any text, formatting, or conversational requests."
)
GEMINI_IMAGE_2_PRICE_BADGE = IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model", "resolution"]),
expr="""
(
$m := widgets.model;
$r := widgets.resolution;
$isFlash := $contains($m, "nano banana 2");
$flashPrices := {"1k": 0.0696, "2k": 0.0696, "4k": 0.123};
$proPrices := {"1k": 0.134, "2k": 0.134, "4k": 0.24};
$prices := $isFlash ? $flashPrices : $proPrices;
{"type":"usd","usd": $lookup($prices, $r), "format":{"suffix":"/Image","approximate":true}}
)
""",
)
class GeminiModel(str, Enum):
"""
@@ -229,6 +245,10 @@ def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | N
input_tokens_price = 2
output_text_tokens_price = 12.0
output_image_tokens_price = 120.0
elif response.modelVersion == "gemini-3.1-flash-image-preview":
input_tokens_price = 0.5
output_text_tokens_price = 3.0
output_image_tokens_price = 60.0
else:
return None
final_price = response.usageMetadata.promptTokenCount * input_tokens_price
@@ -686,7 +706,7 @@ class GeminiImage2(IO.ComfyNode):
),
IO.Combo.Input(
"model",
options=["gemini-3-pro-image-preview"],
options=["gemini-3-pro-image-preview", "Nano Banana 2 (Gemini 3.1 Flash Image)"],
),
IO.Int.Input(
"seed",
@@ -750,19 +770,7 @@ class GeminiImage2(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["resolution"]),
expr="""
(
$r := widgets.resolution;
($contains($r,"1k") or $contains($r,"2k"))
? {"type":"usd","usd":0.134,"format":{"suffix":"/Image","approximate":true}}
: $contains($r,"4k")
? {"type":"usd","usd":0.24,"format":{"suffix":"/Image","approximate":true}}
: {"type":"text","text":"Token-based"}
)
""",
),
price_badge=GEMINI_IMAGE_2_PRICE_BADGE,
)
@classmethod
@@ -779,6 +787,8 @@ class GeminiImage2(IO.ComfyNode):
system_prompt: str = "",
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
if model == "Nano Banana 2 (Gemini 3.1 Flash Image)":
model = "gemini-3.1-flash-image-preview"
parts: list[GeminiPart] = [GeminiPart(text=prompt)]
if images is not None:
@@ -815,6 +825,169 @@ class GeminiImage2(IO.ComfyNode):
return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response))
class GeminiNanoBanana2(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="GeminiNanoBanana2",
display_name="Nano Banana 2",
category="api node/image/Gemini",
description="Generate or edit images synchronously via Google Vertex API.",
inputs=[
IO.String.Input(
"prompt",
multiline=True,
tooltip="Text prompt describing the image to generate or the edits to apply. "
"Include any constraints, styles, or details the model should follow.",
default="",
),
IO.Combo.Input(
"model",
options=["Nano Banana 2 (Gemini 3.1 Flash Image)"],
),
IO.Int.Input(
"seed",
default=42,
min=0,
max=0xFFFFFFFFFFFFFFFF,
control_after_generate=True,
tooltip="When the seed is fixed to a specific value, the model makes a best effort to provide "
"the same response for repeated requests. Deterministic output isn't guaranteed. "
"Also, changing the model or parameter settings, such as the temperature, "
"can cause variations in the response even when you use the same seed value. "
"By default, a random seed value is used.",
),
IO.Combo.Input(
"aspect_ratio",
options=[
"auto",
"1:1",
"2:3",
"3:2",
"3:4",
"4:3",
"4:5",
"5:4",
"9:16",
"16:9",
"21:9",
# "1:4",
# "4:1",
# "8:1",
# "1:8",
],
default="auto",
tooltip="If set to 'auto', matches your input image's aspect ratio; "
"if no image is provided, a 16:9 square is usually generated.",
),
IO.Combo.Input(
"resolution",
options=[
# "512px",
"1K",
"2K",
"4K",
],
tooltip="Target output resolution. For 2K/4K the native Gemini upscaler is used.",
),
IO.Combo.Input(
"response_modalities",
options=["IMAGE", "IMAGE+TEXT"],
advanced=True,
),
IO.Combo.Input(
"thinking_level",
options=["MINIMAL", "HIGH"],
),
IO.Image.Input(
"images",
optional=True,
tooltip="Optional reference image(s). "
"To include multiple images, use the Batch Images node (up to 14).",
),
IO.Custom("GEMINI_INPUT_FILES").Input(
"files",
optional=True,
tooltip="Optional file(s) to use as context for the model. "
"Accepts inputs from the Gemini Generate Content Input Files node.",
),
IO.String.Input(
"system_prompt",
multiline=True,
default=GEMINI_IMAGE_SYS_PROMPT,
optional=True,
tooltip="Foundational instructions that dictate an AI's behavior.",
advanced=True,
),
],
outputs=[
IO.Image.Output(),
IO.String.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=GEMINI_IMAGE_2_PRICE_BADGE,
)
@classmethod
async def execute(
cls,
prompt: str,
model: str,
seed: int,
aspect_ratio: str,
resolution: str,
response_modalities: str,
thinking_level: str,
images: Input.Image | None = None,
files: list[GeminiPart] | None = None,
system_prompt: str = "",
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
if model == "Nano Banana 2 (Gemini 3.1 Flash Image)":
model = "gemini-3.1-flash-image-preview"
parts: list[GeminiPart] = [GeminiPart(text=prompt)]
if images is not None:
if get_number_of_images(images) > 14:
raise ValueError("The current maximum number of supported images is 14.")
parts.extend(await create_image_parts(cls, images))
if files is not None:
parts.extend(files)
image_config = GeminiImageConfig(imageSize=resolution)
if aspect_ratio != "auto":
image_config.aspectRatio = aspect_ratio
gemini_system_prompt = None
if system_prompt:
gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None)
response = await sync_op(
cls,
ApiEndpoint(path=f"/proxy/vertexai/gemini/{model}", method="POST"),
data=GeminiImageGenerateContentRequest(
contents=[
GeminiContent(role=GeminiRole.user, parts=parts),
],
generationConfig=GeminiImageGenerationConfig(
responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]),
imageConfig=image_config,
thinkingConfig=GeminiThinkingConfig(thinkingLevel=thinking_level),
),
systemInstruction=gemini_system_prompt,
),
response_model=GeminiGenerateContentResponse,
price_extractor=calculate_tokens_price,
)
return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response))
class GeminiExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@@ -822,6 +995,7 @@ class GeminiExtension(ComfyExtension):
GeminiNode,
GeminiImage,
GeminiImage2,
GeminiNanoBanana2,
GeminiInputFiles,
]
+76 -16
View File
@@ -27,6 +27,12 @@ from comfy_api_nodes.util import (
)
def _extract_grok_price(response) -> float | None:
if response.usage and response.usage.cost_in_usd_ticks is not None:
return response.usage.cost_in_usd_ticks / 10_000_000_000
return None
class GrokImageNode(IO.ComfyNode):
@classmethod
@@ -37,7 +43,10 @@ class GrokImageNode(IO.ComfyNode):
category="api node/image/Grok",
description="Generate images using Grok based on a text prompt",
inputs=[
IO.Combo.Input("model", options=["grok-imagine-image-beta"]),
IO.Combo.Input(
"model",
options=["grok-imagine-image-pro", "grok-imagine-image", "grok-imagine-image-beta"],
),
IO.String.Input(
"prompt",
multiline=True,
@@ -81,6 +90,7 @@ class GrokImageNode(IO.ComfyNode):
tooltip="Seed to determine if node should re-run; "
"actual results are nondeterministic regardless of seed.",
),
IO.Combo.Input("resolution", options=["1K", "2K"], optional=True),
],
outputs=[
IO.Image.Output(),
@@ -92,8 +102,13 @@ class GrokImageNode(IO.ComfyNode):
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["number_of_images"]),
expr="""{"type":"usd","usd":0.033 * widgets.number_of_images}""",
depends_on=IO.PriceBadgeDepends(widgets=["model", "number_of_images"]),
expr="""
(
$rate := $contains(widgets.model, "pro") ? 0.07 : 0.02;
{"type":"usd","usd": $rate * widgets.number_of_images}
)
""",
),
)
@@ -105,6 +120,7 @@ class GrokImageNode(IO.ComfyNode):
aspect_ratio: str,
number_of_images: int,
seed: int,
resolution: str = "1K",
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
response = await sync_op(
@@ -116,8 +132,10 @@ class GrokImageNode(IO.ComfyNode):
aspect_ratio=aspect_ratio,
n=number_of_images,
seed=seed,
resolution=resolution.lower(),
),
response_model=ImageGenerationResponse,
price_extractor=_extract_grok_price,
)
if len(response.data) == 1:
return IO.NodeOutput(await download_url_to_image_tensor(response.data[0].url))
@@ -138,14 +156,17 @@ class GrokImageEditNode(IO.ComfyNode):
category="api node/image/Grok",
description="Modify an existing image based on a text prompt",
inputs=[
IO.Combo.Input("model", options=["grok-imagine-image-beta"]),
IO.Image.Input("image"),
IO.Combo.Input(
"model",
options=["grok-imagine-image-pro", "grok-imagine-image", "grok-imagine-image-beta"],
),
IO.Image.Input("image", display_name="images"),
IO.String.Input(
"prompt",
multiline=True,
tooltip="The text prompt used to generate the image",
),
IO.Combo.Input("resolution", options=["1K"]),
IO.Combo.Input("resolution", options=["1K", "2K"]),
IO.Int.Input(
"number_of_images",
default=1,
@@ -166,6 +187,27 @@ class GrokImageEditNode(IO.ComfyNode):
tooltip="Seed to determine if node should re-run; "
"actual results are nondeterministic regardless of seed.",
),
IO.Combo.Input(
"aspect_ratio",
options=[
"auto",
"1:1",
"2:3",
"3:2",
"3:4",
"4:3",
"9:16",
"16:9",
"9:19.5",
"19.5:9",
"9:20",
"20:9",
"1:2",
"2:1",
],
optional=True,
tooltip="Only allowed when multiple images are connected to the image input.",
),
],
outputs=[
IO.Image.Output(),
@@ -177,8 +219,13 @@ class GrokImageEditNode(IO.ComfyNode):
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["number_of_images"]),
expr="""{"type":"usd","usd":0.002 + 0.033 * widgets.number_of_images}""",
depends_on=IO.PriceBadgeDepends(widgets=["model", "number_of_images"]),
expr="""
(
$rate := $contains(widgets.model, "pro") ? 0.07 : 0.02;
{"type":"usd","usd": 0.002 + $rate * widgets.number_of_images}
)
""",
),
)
@@ -191,22 +238,32 @@ class GrokImageEditNode(IO.ComfyNode):
resolution: str,
number_of_images: int,
seed: int,
aspect_ratio: str = "auto",
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
if get_number_of_images(image) != 1:
raise ValueError("Only one input image is supported.")
if model == "grok-imagine-image-pro":
if get_number_of_images(image) > 1:
raise ValueError("The pro model supports only 1 input image.")
elif get_number_of_images(image) > 3:
raise ValueError("A maximum of 3 input images is supported.")
if aspect_ratio != "auto" and get_number_of_images(image) == 1:
raise ValueError(
"Custom aspect ratio is only allowed when multiple images are connected to the image input."
)
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/xai/v1/images/edits", method="POST"),
data=ImageEditRequest(
model=model,
image=InputUrlObject(url=f"data:image/png;base64,{tensor_to_base64_string(image)}"),
images=[InputUrlObject(url=f"data:image/png;base64,{tensor_to_base64_string(i)}") for i in image],
prompt=prompt,
resolution=resolution.lower(),
n=number_of_images,
seed=seed,
aspect_ratio=None if aspect_ratio == "auto" else aspect_ratio,
),
response_model=ImageGenerationResponse,
price_extractor=_extract_grok_price,
)
if len(response.data) == 1:
return IO.NodeOutput(await download_url_to_image_tensor(response.data[0].url))
@@ -227,7 +284,7 @@ class GrokVideoNode(IO.ComfyNode):
category="api node/video/Grok",
description="Generate video from a prompt or an image",
inputs=[
IO.Combo.Input("model", options=["grok-imagine-video-beta"]),
IO.Combo.Input("model", options=["grok-imagine-video", "grok-imagine-video-beta"]),
IO.String.Input(
"prompt",
multiline=True,
@@ -275,10 +332,11 @@ class GrokVideoNode(IO.ComfyNode):
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["duration"], inputs=["image"]),
depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"], inputs=["image"]),
expr="""
(
$base := 0.181 * widgets.duration;
$rate := widgets.resolution = "720p" ? 0.07 : 0.05;
$base := $rate * widgets.duration;
{"type":"usd","usd": inputs.image.connected ? $base + 0.002 : $base}
)
""",
@@ -321,6 +379,7 @@ class GrokVideoNode(IO.ComfyNode):
ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"),
status_extractor=lambda r: r.status if r.status is not None else "complete",
response_model=VideoStatusResponse,
price_extractor=_extract_grok_price,
)
return IO.NodeOutput(await download_url_to_video_output(response.video.url))
@@ -335,7 +394,7 @@ class GrokVideoEditNode(IO.ComfyNode):
category="api node/video/Grok",
description="Edit an existing video based on a text prompt.",
inputs=[
IO.Combo.Input("model", options=["grok-imagine-video-beta"]),
IO.Combo.Input("model", options=["grok-imagine-video", "grok-imagine-video-beta"]),
IO.String.Input(
"prompt",
multiline=True,
@@ -364,7 +423,7 @@ class GrokVideoEditNode(IO.ComfyNode):
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd": 0.191, "format": {"suffix": "/sec", "approximate": true}}""",
expr="""{"type":"usd","usd": 0.06, "format": {"suffix": "/sec", "approximate": true}}""",
),
)
@@ -398,6 +457,7 @@ class GrokVideoEditNode(IO.ComfyNode):
ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"),
status_extractor=lambda r: r.status if r.status is not None else "complete",
response_model=VideoStatusResponse,
price_extractor=_extract_grok_price,
)
return IO.NodeOutput(await download_url_to_video_output(response.video.url))
+3
View File
@@ -2747,6 +2747,7 @@ class MotionControl(IO.ComfyNode):
"but the character orientation matches the reference image (camera/other details via prompt).",
),
IO.Combo.Input("mode", options=["pro", "std"]),
IO.Combo.Input("model", options=["kling-v3", "kling-v2-6"], optional=True),
],
outputs=[
IO.Video.Output(),
@@ -2777,6 +2778,7 @@ class MotionControl(IO.ComfyNode):
keep_original_sound: bool,
character_orientation: str,
mode: str,
model: str = "kling-v2-6",
) -> IO.NodeOutput:
validate_string(prompt, max_length=2500)
validate_image_dimensions(reference_image, min_width=340, min_height=340)
@@ -2797,6 +2799,7 @@ class MotionControl(IO.ComfyNode):
keep_original_sound="yes" if keep_original_sound else "no",
character_orientation=character_orientation,
mode=mode,
model_name=model,
),
)
if response.code:
+45 -8
View File
@@ -20,7 +20,7 @@ class JobStatus:
# Media types that can be previewed in the frontend
PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio', '3d'})
PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio', '3d', 'text'})
# 3D file extensions for preview fallback (no dedicated media_type exists)
THREE_D_EXTENSIONS = frozenset({'.obj', '.fbx', '.gltf', '.glb', '.usdz'})
@@ -75,6 +75,23 @@ def normalize_outputs(outputs: dict) -> dict:
normalized[node_id] = normalized_node
return normalized
# Text preview truncation limit (1024 characters) to prevent preview_output bloat
TEXT_PREVIEW_MAX_LENGTH = 1024
def _create_text_preview(value: str) -> dict:
"""Create a text preview dict with optional truncation.
Returns:
dict with 'content' and optionally 'truncated' flag
"""
if len(value) <= TEXT_PREVIEW_MAX_LENGTH:
return {'content': value}
return {
'content': value[:TEXT_PREVIEW_MAX_LENGTH],
'truncated': True
}
def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str]]:
"""Extract create_time and workflow_id from extra_data.
@@ -221,23 +238,43 @@ def get_outputs_summary(outputs: dict) -> tuple[int, Optional[dict]]:
continue
for item in items:
normalized = normalize_output_item(item)
if normalized is None:
continue
if not isinstance(item, dict):
# Handle text outputs (non-dict items like strings or tuples)
normalized = normalize_output_item(item)
if normalized is None:
# Not a 3D file string — check for text preview
if media_type == 'text':
count += 1
if preview_output is None:
if isinstance(item, tuple):
text_value = item[0] if item else ''
else:
text_value = str(item)
text_preview = _create_text_preview(text_value)
enriched = {
**text_preview,
'nodeId': node_id,
'mediaType': media_type
}
if fallback_preview is None:
fallback_preview = enriched
continue
# normalize_output_item returned a dict (e.g. 3D file)
item = normalized
count += 1
if preview_output is not None:
continue
if isinstance(normalized, dict) and is_previewable(media_type, normalized):
if is_previewable(media_type, item):
enriched = {
**normalized,
**item,
'nodeId': node_id,
}
if 'mediaType' not in normalized:
if 'mediaType' not in item:
enriched['mediaType'] = media_type
if normalized.get('type') == 'output':
if item.get('type') == 'output':
preview_output = enriched
elif fallback_preview is None:
fallback_preview = enriched
+1 -1
View File
@@ -96,7 +96,7 @@ class VAEEncodeAudio(IO.ComfyNode):
def vae_decode_audio(vae, samples, tile=None, overlap=None):
if tile is not None:
audio = vae.decode_tiled(samples["samples"], tile_y=tile, overlap=overlap).movedim(-1, 1)
audio = vae.decode_tiled(samples["samples"], tile_x=tile, tile_y=tile, overlap=overlap).movedim(-1, 1)
else:
audio = vae.decode(samples["samples"]).movedim(-1, 1)
+12 -11
View File
@@ -717,11 +717,11 @@ def _render_shader_batch(
gl.glUseProgram(0)
for tex in input_textures:
gl.glDeleteTextures(tex)
gl.glDeleteTextures(int(tex))
for tex in output_textures:
gl.glDeleteTextures(tex)
gl.glDeleteTextures(int(tex))
for tex in ping_pong_textures:
gl.glDeleteTextures(tex)
gl.glDeleteTextures(int(tex))
if fbo is not None:
gl.glDeleteFramebuffers(1, [fbo])
for pp_fbo in ping_pong_fbos:
@@ -865,14 +865,15 @@ class GLSLShader(io.ComfyNode):
cls, image_list: list[torch.Tensor], output_batch: torch.Tensor
) -> dict[str, list]:
"""Build UI output with input and output images for client-side shader execution."""
combined_inputs = torch.cat(image_list, dim=0)
input_images_ui = ui.ImageSaveHelper.save_images(
combined_inputs,
filename_prefix="GLSLShader_input",
folder_type=io.FolderType.temp,
cls=None,
compress_level=1,
)
input_images_ui = []
for img in image_list:
input_images_ui.extend(ui.ImageSaveHelper.save_images(
img,
filename_prefix="GLSLShader_input",
folder_type=io.FolderType.temp,
cls=None,
compress_level=1,
))
output_images_ui = ui.ImageSaveHelper.save_images(
output_batch,
+1 -1
View File
@@ -248,7 +248,7 @@ class SetClipHooks:
def apply_hooks(self, clip: CLIP, schedule_clip: bool, apply_to_conds: bool, hooks: comfy.hooks.HookGroup=None):
if hooks is not None:
clip = clip.clone()
clip = clip.clone(disable_dynamic=True)
if apply_to_conds:
clip.apply_hooks_to_conds = hooks
clip.patcher.forced_hooks = hooks.clone()
+3 -31
View File
@@ -706,8 +706,8 @@ class SplitImageToTileList(IO.ComfyNode):
@staticmethod
def get_grid_coords(width, height, tile_width, tile_height, overlap):
coords = []
stride_x = max(1, tile_width - overlap)
stride_y = max(1, tile_height - overlap)
stride_x = round(max(tile_width * 0.25, tile_width - overlap))
stride_y = round(max(tile_width * 0.25, tile_height - overlap))
y = 0
while y < height:
@@ -764,34 +764,6 @@ class ImageMergeTileList(IO.ComfyNode):
],
)
@staticmethod
def get_grid_coords(width, height, tile_width, tile_height, overlap):
coords = []
stride_x = max(1, tile_width - overlap)
stride_y = max(1, tile_height - overlap)
y = 0
while y < height:
x = 0
y_end = min(y + tile_height, height)
y_start = max(0, y_end - tile_height)
while x < width:
x_end = min(x + tile_width, width)
x_start = max(0, x_end - tile_width)
coords.append((x_start, y_start, x_end, y_end))
if x_end >= width:
break
x += stride_x
if y_end >= height:
break
y += stride_y
return coords
@classmethod
def execute(cls, image_list, final_width, final_height, overlap):
w = final_width[0]
@@ -804,7 +776,7 @@ class ImageMergeTileList(IO.ComfyNode):
device = first_tile.device
dtype = first_tile.dtype
coords = cls.get_grid_coords(w, h, t_w, t_h, ovlp)
coords = SplitImageToTileList.get_grid_coords(w, h, t_w, t_h, ovlp)
canvas = torch.zeros((b, h, w, c), device=device, dtype=dtype)
weights = torch.zeros((b, h, w, 1), device=device, dtype=dtype)
+45 -2
View File
@@ -134,6 +134,36 @@ class LTXVImgToVideoInplace(io.ComfyNode):
generate = execute # TODO: remove
def _append_guide_attention_entry(positive, negative, pre_filter_count, latent_shape, strength=1.0):
"""Append a guide_attention_entry to both positive and negative conditioning.
Each entry tracks one guide reference for per-reference attention control.
Entries are derived independently from each conditioning to avoid cross-contamination.
"""
new_entry = {
"pre_filter_count": pre_filter_count,
"strength": strength,
"pixel_mask": None,
"latent_shape": latent_shape,
}
results = []
for cond in (positive, negative):
# Read existing entries from this specific conditioning
existing = []
for t in cond:
found = t[1].get("guide_attention_entries", None)
if found is not None:
existing = found
break
# Shallow copy and append (no deepcopy needed — entries contain
# only scalars and None for pixel_mask at this call site).
entries = [*existing, new_entry]
results.append(node_helpers.conditioning_set_values(
cond, {"guide_attention_entries": entries}
))
return results[0], results[1]
def conditioning_get_any_value(conditioning, key, default=None):
for t in conditioning:
if key in t[1]:
@@ -324,6 +354,13 @@ class LTXVAddGuide(io.ComfyNode):
scale_factors,
)
# Track this guide for per-reference attention control.
pre_filter_count = t.shape[2] * t.shape[3] * t.shape[4]
guide_latent_shape = list(t.shape[2:]) # [F, H, W]
positive, negative = _append_guide_attention_entry(
positive, negative, pre_filter_count, guide_latent_shape, strength=strength,
)
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
generate = execute # TODO: remove
@@ -359,8 +396,14 @@ class LTXVCropGuides(io.ComfyNode):
latent_image = latent_image[:, :, :-num_keyframes]
noise_mask = noise_mask[:, :, :-num_keyframes]
positive = node_helpers.conditioning_set_values(positive, {"keyframe_idxs": None})
negative = node_helpers.conditioning_set_values(negative, {"keyframe_idxs": None})
positive = node_helpers.conditioning_set_values(positive, {
"keyframe_idxs": None,
"guide_attention_entries": None,
})
negative = node_helpers.conditioning_set_values(negative, {
"keyframe_idxs": None,
"guide_attention_entries": None,
})
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
+16 -8
View File
@@ -10,7 +10,7 @@ class Mahiro(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="Mahiro",
display_name="Mahiro CFG",
display_name="Positive-Biased Guidance",
category="_for_testing",
description="Modify the guidance to scale more on the 'direction' of the positive prompt rather than the difference between the negative prompt.",
inputs=[
@@ -20,27 +20,35 @@ class Mahiro(io.ComfyNode):
io.Model.Output(display_name="patched_model"),
],
is_experimental=True,
search_aliases=[
"mahiro",
"mahiro cfg",
"similarity-adaptive guidance",
"positive-biased cfg",
],
)
@classmethod
def execute(cls, model) -> io.NodeOutput:
m = model.clone()
def mahiro_normd(args):
scale: float = args['cond_scale']
cond_p: torch.Tensor = args['cond_denoised']
uncond_p: torch.Tensor = args['uncond_denoised']
#naive leap
scale: float = args["cond_scale"]
cond_p: torch.Tensor = args["cond_denoised"]
uncond_p: torch.Tensor = args["uncond_denoised"]
# naive leap
leap = cond_p * scale
#sim with uncond leap
# sim with uncond leap
u_leap = uncond_p * scale
cfg = args["denoised"]
merge = (leap + cfg) / 2
normu = torch.sqrt(u_leap.abs()) * u_leap.sign()
normm = torch.sqrt(merge.abs()) * merge.sign()
sim = F.cosine_similarity(normu, normm).mean()
simsc = 2 * (sim+1)
wm = (simsc*cfg + (4-simsc)*leap) / 4
simsc = 2 * (sim + 1)
wm = (simsc * cfg + (4 - simsc) * leap) / 4
return wm
m.set_model_sampler_post_cfg_function(mahiro_normd)
return io.NodeOutput(m)
+3 -1
View File
@@ -52,7 +52,7 @@ class ModelSamplingDiscrete:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"sampling": (["eps", "v_prediction", "lcm", "x0", "img_to_img"],),
"sampling": (["eps", "v_prediction", "lcm", "x0", "img_to_img", "img_to_img_flow"],),
"zsnr": ("BOOLEAN", {"default": False, "advanced": True}),
}}
@@ -76,6 +76,8 @@ class ModelSamplingDiscrete:
sampling_type = comfy.model_sampling.X0
elif sampling == "img_to_img":
sampling_type = comfy.model_sampling.IMG_TO_IMG
elif sampling == "img_to_img_flow":
sampling_type = comfy.model_sampling.IMG_TO_IMG_FLOW
class ModelSamplingAdvanced(sampling_base, sampling_type):
pass
+1 -1
View File
@@ -79,7 +79,6 @@ class Blur(io.ComfyNode):
node_id="ImageBlur",
display_name="Image Blur",
category="image/postprocessing",
essentials_category="Image Tools",
inputs=[
io.Image.Input("image"),
io.Int.Input("blur_radius", default=1, min=1, max=31, step=1),
@@ -568,6 +567,7 @@ class BatchImagesNode(io.ComfyNode):
node_id="BatchImagesNode",
display_name="Batch Images",
category="image",
essentials_category="Image Tools",
search_aliases=["batch", "image batch", "batch images", "combine images", "merge images", "stack images"],
inputs=[
io.Autogrow.Input("images", template=autogrow_template)
+86
View File
@@ -0,0 +1,86 @@
from __future__ import annotations
import math
from enum import Enum
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
class AspectRatio(str, Enum):
SQUARE = "1:1 (Square)"
PHOTO_H = "3:2 (Photo)"
STANDARD_H = "4:3 (Standard)"
WIDESCREEN_H = "16:9 (Widescreen)"
ULTRAWIDE_H = "21:9 (Ultrawide)"
PHOTO_V = "2:3 (Portrait Photo)"
STANDARD_V = "3:4 (Portrait Standard)"
WIDESCREEN_V = "9:16 (Portrait Widescreen)"
ASPECT_RATIOS: dict[AspectRatio, tuple[int, int]] = {
AspectRatio.SQUARE: (1, 1),
AspectRatio.PHOTO_H: (3, 2),
AspectRatio.STANDARD_H: (4, 3),
AspectRatio.WIDESCREEN_H: (16, 9),
AspectRatio.ULTRAWIDE_H: (21, 9),
AspectRatio.PHOTO_V: (2, 3),
AspectRatio.STANDARD_V: (3, 4),
AspectRatio.WIDESCREEN_V: (9, 16),
}
class ResolutionSelector(io.ComfyNode):
"""Calculate width and height from aspect ratio and megapixel target."""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ResolutionSelector",
display_name="Resolution Selector",
category="utils",
description="Calculate width and height from aspect ratio and megapixel target. Useful for setting up Empty Latent Image dimensions.",
inputs=[
io.Combo.Input(
"aspect_ratio",
options=AspectRatio,
default=AspectRatio.SQUARE,
tooltip="The aspect ratio for the output dimensions.",
),
io.Float.Input(
"megapixels",
default=1.0,
min=0.1,
max=16.0,
step=0.1,
tooltip="Target total megapixels. 1.0 MP ≈ 1024×1024 for square.",
),
],
outputs=[
io.Int.Output(
"width", tooltip="Calculated width in pixels (multiple of 8)."
),
io.Int.Output(
"height", tooltip="Calculated height in pixels (multiple of 8)."
),
],
)
@classmethod
def execute(cls, aspect_ratio: str, megapixels: float) -> io.NodeOutput:
w_ratio, h_ratio = ASPECT_RATIOS[aspect_ratio]
total_pixels = megapixels * 1024 * 1024
scale = math.sqrt(total_pixels / (w_ratio * h_ratio))
width = round(w_ratio * scale / 8) * 8
height = round(h_ratio * scale / 8) * 8
return io.NodeOutput(width, height)
class ResolutionExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
ResolutionSelector,
]
async def comfy_entrypoint() -> ResolutionExtension:
return ResolutionExtension()
+740
View File
@@ -0,0 +1,740 @@
import torch
import comfy.utils
import numpy as np
import math
import colorsys
from tqdm import tqdm
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
from comfy_extras.nodes_lotus import LotusConditioning
def _preprocess_keypoints(kp_raw, sc_raw):
"""Insert neck keypoint and remap from MMPose to OpenPose ordering.
Returns (kp, sc) where kp has shape (134, 2) and sc has shape (134,).
Layout:
0-17 body (18 kp, OpenPose order)
18-23 feet (6 kp)
24-91 face (68 kp)
92-112 right hand (21 kp)
113-133 left hand (21 kp)
"""
kp = np.array(kp_raw, dtype=np.float32)
sc = np.array(sc_raw, dtype=np.float32)
if len(kp) >= 17:
neck = (kp[5] + kp[6]) / 2
neck_score = min(sc[5], sc[6]) if sc[5] > 0.3 and sc[6] > 0.3 else 0
kp = np.insert(kp, 17, neck, axis=0)
sc = np.insert(sc, 17, neck_score)
mmpose_idx = np.array([17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3])
openpose_idx = np.array([ 1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17])
tmp_kp, tmp_sc = kp.copy(), sc.copy()
tmp_kp[openpose_idx] = kp[mmpose_idx]
tmp_sc[openpose_idx] = sc[mmpose_idx]
kp, sc = tmp_kp, tmp_sc
return kp, sc
def _to_openpose_frames(all_keypoints, all_scores, height, width):
"""Convert raw keypoint lists to a list of OpenPose-style frame dicts.
Each frame dict contains:
canvas_width, canvas_height, people: list of person dicts with keys:
pose_keypoints_2d - 18 body kp as flat [x,y,score,...] (absolute pixels)
foot_keypoints_2d - 6 foot kp as flat [x,y,score,...] (absolute pixels)
face_keypoints_2d - 70 face kp as flat [x,y,score,...] (absolute pixels)
indices 0-67: 68 face landmarks
index 68: right eye (body[14])
index 69: left eye (body[15])
hand_right_keypoints_2d - 21 right-hand kp (absolute pixels)
hand_left_keypoints_2d - 21 left-hand kp (absolute pixels)
"""
def _flatten(kp_slice, sc_slice):
return np.stack([kp_slice[:, 0], kp_slice[:, 1], sc_slice], axis=1).flatten().tolist()
frames = []
for img_idx in range(len(all_keypoints)):
people = []
for kp_raw, sc_raw in zip(all_keypoints[img_idx], all_scores[img_idx]):
kp, sc = _preprocess_keypoints(kp_raw, sc_raw)
# 70 face kp = 68 face landmarks + REye (body[14]) + LEye (body[15])
face_kp = np.concatenate([kp[24:92], kp[[14, 15]]], axis=0)
face_sc = np.concatenate([sc[24:92], sc[[14, 15]]], axis=0)
people.append({
"pose_keypoints_2d": _flatten(kp[0:18], sc[0:18]),
"foot_keypoints_2d": _flatten(kp[18:24], sc[18:24]),
"face_keypoints_2d": _flatten(face_kp, face_sc),
"hand_right_keypoints_2d": _flatten(kp[92:113], sc[92:113]),
"hand_left_keypoints_2d": _flatten(kp[113:134], sc[113:134]),
})
frames.append({"canvas_width": width, "canvas_height": height, "people": people})
return frames
class KeypointDraw:
"""
Pose keypoint drawing class that supports both numpy and cv2 backends.
"""
def __init__(self):
try:
import cv2
self.draw = cv2
except ImportError:
self.draw = self
# Hand connections (same for both hands)
self.hand_edges = [
[0, 1], [1, 2], [2, 3], [3, 4], # thumb
[0, 5], [5, 6], [6, 7], [7, 8], # index
[0, 9], [9, 10], [10, 11], [11, 12], # middle
[0, 13], [13, 14], [14, 15], [15, 16], # ring
[0, 17], [17, 18], [18, 19], [19, 20], # pinky
]
# Body connections - matching DWPose limbSeq (1-indexed, converted to 0-indexed)
self.body_limbSeq = [
[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10],
[10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17],
[1, 16], [16, 18]
]
# Colors matching DWPose
self.colors = [
[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0],
[85, 255, 0], [0, 255, 0], [0, 255, 85], [0, 255, 170], [0, 255, 255],
[0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255],
[170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]
]
@staticmethod
def circle(canvas_np, center, radius, color, **kwargs):
"""Draw a filled circle using NumPy vectorized operations."""
cx, cy = center
h, w = canvas_np.shape[:2]
radius_int = int(np.ceil(radius))
y_min, y_max = max(0, cy - radius_int), min(h, cy + radius_int + 1)
x_min, x_max = max(0, cx - radius_int), min(w, cx + radius_int + 1)
if y_max <= y_min or x_max <= x_min:
return
y, x = np.ogrid[y_min:y_max, x_min:x_max]
mask = (x - cx)**2 + (y - cy)**2 <= radius**2
canvas_np[y_min:y_max, x_min:x_max][mask] = color
@staticmethod
def line(canvas_np, pt1, pt2, color, thickness=1, **kwargs):
"""Draw line using Bresenham's algorithm with NumPy operations."""
x0, y0, x1, y1 = *pt1, *pt2
h, w = canvas_np.shape[:2]
dx, dy = abs(x1 - x0), abs(y1 - y0)
sx, sy = (1 if x0 < x1 else -1), (1 if y0 < y1 else -1)
err, x, y, line_points = dx - dy, x0, y0, []
while True:
line_points.append((x, y))
if x == x1 and y == y1:
break
e2 = 2 * err
if e2 > -dy:
err, x = err - dy, x + sx
if e2 < dx:
err, y = err + dx, y + sy
if thickness > 1:
radius, radius_int = (thickness / 2.0) + 0.5, int(np.ceil((thickness / 2.0) + 0.5))
for px, py in line_points:
y_min, y_max, x_min, x_max = max(0, py - radius_int), min(h, py + radius_int + 1), max(0, px - radius_int), min(w, px + radius_int + 1)
if y_max > y_min and x_max > x_min:
yy, xx = np.ogrid[y_min:y_max, x_min:x_max]
canvas_np[y_min:y_max, x_min:x_max][(xx - px)**2 + (yy - py)**2 <= radius**2] = color
else:
line_points = np.array(line_points)
valid = (line_points[:, 1] >= 0) & (line_points[:, 1] < h) & (line_points[:, 0] >= 0) & (line_points[:, 0] < w)
if (valid_points := line_points[valid]).size:
canvas_np[valid_points[:, 1], valid_points[:, 0]] = color
@staticmethod
def fillConvexPoly(canvas_np, pts, color, **kwargs):
"""Fill polygon using vectorized scanline algorithm."""
if len(pts) < 3:
return
pts = np.array(pts, dtype=np.int32)
h, w = canvas_np.shape[:2]
y_min, y_max, x_min, x_max = max(0, pts[:, 1].min()), min(h, pts[:, 1].max() + 1), max(0, pts[:, 0].min()), min(w, pts[:, 0].max() + 1)
if y_max <= y_min or x_max <= x_min:
return
yy, xx = np.mgrid[y_min:y_max, x_min:x_max]
mask = np.zeros((y_max - y_min, x_max - x_min), dtype=bool)
for i in range(len(pts)):
p1, p2 = pts[i], pts[(i + 1) % len(pts)]
y1, y2 = p1[1], p2[1]
if y1 == y2:
continue
if y1 > y2:
p1, p2, y1, y2 = p2, p1, p2[1], p1[1]
if not (edge_mask := (yy >= y1) & (yy < y2)).any():
continue
mask ^= edge_mask & (xx >= p1[0] + (yy - y1) * (p2[0] - p1[0]) / (y2 - y1))
canvas_np[y_min:y_max, x_min:x_max][mask] = color
@staticmethod
def ellipse2Poly(center, axes, angle, arc_start, arc_end, delta=1, **kwargs):
"""Python implementation of cv2.ellipse2Poly."""
axes = (axes[0] + 0.5, axes[1] + 0.5) # to better match cv2 output
angle = angle % 360
if arc_start > arc_end:
arc_start, arc_end = arc_end, arc_start
while arc_start < 0:
arc_start, arc_end = arc_start + 360, arc_end + 360
while arc_end > 360:
arc_end, arc_start = arc_end - 360, arc_start - 360
if arc_end - arc_start > 360:
arc_start, arc_end = 0, 360
angle_rad = math.radians(angle)
alpha, beta = math.cos(angle_rad), math.sin(angle_rad)
pts = []
for i in range(arc_start, arc_end + delta, delta):
theta_rad = math.radians(min(i, arc_end))
x, y = axes[0] * math.cos(theta_rad), axes[1] * math.sin(theta_rad)
pts.append([int(round(center[0] + x * alpha - y * beta)), int(round(center[1] + x * beta + y * alpha))])
unique_pts, prev_pt = [], (float('inf'), float('inf'))
for pt in pts:
if (pt_tuple := tuple(pt)) != prev_pt:
unique_pts.append(pt)
prev_pt = pt_tuple
return unique_pts if len(unique_pts) > 1 else [[center[0], center[1]], [center[0], center[1]]]
def draw_wholebody_keypoints(self, canvas, keypoints, scores=None, threshold=0.3,
draw_body=True, draw_feet=True, draw_face=True, draw_hands=True, stick_width=4, face_point_size=3):
"""
Draw wholebody keypoints (134 keypoints after processing) in DWPose style.
Expected keypoint format (after neck insertion and remapping):
- Body: 0-17 (18 keypoints in OpenPose format, neck at index 1)
- Foot: 18-23 (6 keypoints)
- Face: 24-91 (68 landmarks)
- Right hand: 92-112 (21 keypoints)
- Left hand: 113-133 (21 keypoints)
Args:
canvas: The canvas to draw on (numpy array)
keypoints: Array of keypoint coordinates
scores: Optional confidence scores for each keypoint
threshold: Minimum confidence threshold for drawing keypoints
Returns:
canvas: The canvas with keypoints drawn
"""
H, W, C = canvas.shape
# Draw body limbs
if draw_body and len(keypoints) >= 18:
for i, limb in enumerate(self.body_limbSeq):
# Convert from 1-indexed to 0-indexed
idx1, idx2 = limb[0] - 1, limb[1] - 1
if idx1 >= 18 or idx2 >= 18:
continue
if scores is not None:
if scores[idx1] < threshold or scores[idx2] < threshold:
continue
Y = [keypoints[idx1][0], keypoints[idx2][0]]
X = [keypoints[idx1][1], keypoints[idx2][1]]
mX, mY = (X[0] + X[1]) / 2, (Y[0] + Y[1]) / 2
length = math.sqrt((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2)
if length < 1:
continue
angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
polygon = self.draw.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stick_width), int(angle), 0, 360, 1)
self.draw.fillConvexPoly(canvas, polygon, self.colors[i % len(self.colors)])
# Draw body keypoints
if draw_body and len(keypoints) >= 18:
for i in range(18):
if scores is not None and scores[i] < threshold:
continue
x, y = int(keypoints[i][0]), int(keypoints[i][1])
if 0 <= x < W and 0 <= y < H:
self.draw.circle(canvas, (x, y), 4, self.colors[i % len(self.colors)], thickness=-1)
# Draw foot keypoints (18-23, 6 keypoints)
if draw_feet and len(keypoints) >= 24:
for i in range(18, 24):
if scores is not None and scores[i] < threshold:
continue
x, y = int(keypoints[i][0]), int(keypoints[i][1])
if 0 <= x < W and 0 <= y < H:
self.draw.circle(canvas, (x, y), 4, self.colors[i % len(self.colors)], thickness=-1)
# Draw right hand (92-112)
if draw_hands and len(keypoints) >= 113:
eps = 0.01
for ie, edge in enumerate(self.hand_edges):
idx1, idx2 = 92 + edge[0], 92 + edge[1]
if scores is not None:
if scores[idx1] < threshold or scores[idx2] < threshold:
continue
x1, y1 = int(keypoints[idx1][0]), int(keypoints[idx1][1])
x2, y2 = int(keypoints[idx2][0]), int(keypoints[idx2][1])
if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
if 0 <= x1 < W and 0 <= y1 < H and 0 <= x2 < W and 0 <= y2 < H:
# HSV to RGB conversion for rainbow colors
r, g, b = colorsys.hsv_to_rgb(ie / float(len(self.hand_edges)), 1.0, 1.0)
color = (int(r * 255), int(g * 255), int(b * 255))
self.draw.line(canvas, (x1, y1), (x2, y2), color, thickness=2)
# Draw right hand keypoints
for i in range(92, 113):
if scores is not None and scores[i] < threshold:
continue
x, y = int(keypoints[i][0]), int(keypoints[i][1])
if x > eps and y > eps and 0 <= x < W and 0 <= y < H:
self.draw.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1)
# Draw left hand (113-133)
if draw_hands and len(keypoints) >= 134:
eps = 0.01
for ie, edge in enumerate(self.hand_edges):
idx1, idx2 = 113 + edge[0], 113 + edge[1]
if scores is not None:
if scores[idx1] < threshold or scores[idx2] < threshold:
continue
x1, y1 = int(keypoints[idx1][0]), int(keypoints[idx1][1])
x2, y2 = int(keypoints[idx2][0]), int(keypoints[idx2][1])
if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
if 0 <= x1 < W and 0 <= y1 < H and 0 <= x2 < W and 0 <= y2 < H:
# HSV to RGB conversion for rainbow colors
r, g, b = colorsys.hsv_to_rgb(ie / float(len(self.hand_edges)), 1.0, 1.0)
color = (int(r * 255), int(g * 255), int(b * 255))
self.draw.line(canvas, (x1, y1), (x2, y2), color, thickness=2)
# Draw left hand keypoints
for i in range(113, 134):
if scores is not None and i < len(scores) and scores[i] < threshold:
continue
x, y = int(keypoints[i][0]), int(keypoints[i][1])
if x > eps and y > eps and 0 <= x < W and 0 <= y < H:
self.draw.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1)
# Draw face keypoints (24-91) - white dots only, no lines
if draw_face and len(keypoints) >= 92:
eps = 0.01
for i in range(24, 92):
if scores is not None and scores[i] < threshold:
continue
x, y = int(keypoints[i][0]), int(keypoints[i][1])
if x > eps and y > eps and 0 <= x < W and 0 <= y < H:
self.draw.circle(canvas, (x, y), face_point_size, (255, 255, 255), thickness=-1)
return canvas
class SDPoseDrawKeypoints(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="SDPoseDrawKeypoints",
category="image/preprocessors",
search_aliases=["openpose", "pose detection", "preprocessor", "keypoints", "pose"],
inputs=[
io.Custom("POSE_KEYPOINT").Input("keypoints"),
io.Boolean.Input("draw_body", default=True),
io.Boolean.Input("draw_hands", default=True),
io.Boolean.Input("draw_face", default=True),
io.Boolean.Input("draw_feet", default=False),
io.Int.Input("stick_width", default=4, min=1, max=10, step=1),
io.Int.Input("face_point_size", default=3, min=1, max=10, step=1),
io.Float.Input("score_threshold", default=0.3, min=0.0, max=1.0, step=0.01),
],
outputs=[
io.Image.Output(),
],
)
@classmethod
def execute(cls, keypoints, draw_body, draw_hands, draw_face, draw_feet, stick_width, face_point_size, score_threshold) -> io.NodeOutput:
if not keypoints:
return io.NodeOutput(torch.zeros((1, 64, 64, 3), dtype=torch.float32))
height = keypoints[0]["canvas_height"]
width = keypoints[0]["canvas_width"]
def _parse(flat, n):
arr = np.array(flat, dtype=np.float32).reshape(n, 3)
return arr[:, :2], arr[:, 2]
def _zeros(n):
return np.zeros((n, 2), dtype=np.float32), np.zeros(n, dtype=np.float32)
pose_outputs = []
drawer = KeypointDraw()
for frame in tqdm(keypoints, desc="Drawing keypoints on frames"):
canvas = np.zeros((height, width, 3), dtype=np.uint8)
for person in frame["people"]:
body_kp, body_sc = _parse(person["pose_keypoints_2d"], 18)
foot_raw = person.get("foot_keypoints_2d")
foot_kp, foot_sc = _parse(foot_raw, 6) if foot_raw else _zeros(6)
face_kp, face_sc = _parse(person["face_keypoints_2d"], 70)
face_kp, face_sc = face_kp[:68], face_sc[:68] # drop appended eye kp; body already draws them
rhand_kp, rhand_sc = _parse(person["hand_right_keypoints_2d"], 21)
lhand_kp, lhand_sc = _parse(person["hand_left_keypoints_2d"], 21)
kp = np.concatenate([body_kp, foot_kp, face_kp, rhand_kp, lhand_kp], axis=0)
sc = np.concatenate([body_sc, foot_sc, face_sc, rhand_sc, lhand_sc], axis=0)
canvas = drawer.draw_wholebody_keypoints(
canvas, kp, sc,
threshold=score_threshold,
draw_body=draw_body, draw_feet=draw_feet,
draw_face=draw_face, draw_hands=draw_hands,
stick_width=stick_width, face_point_size=face_point_size,
)
pose_outputs.append(canvas)
pose_outputs_np = np.stack(pose_outputs) if len(pose_outputs) > 1 else np.expand_dims(pose_outputs[0], 0)
final_pose_output = torch.from_numpy(pose_outputs_np).float() / 255.0
return io.NodeOutput(final_pose_output)
class SDPoseKeypointExtractor(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="SDPoseKeypointExtractor",
category="image/preprocessors",
search_aliases=["openpose", "pose detection", "preprocessor", "keypoints", "sdpose"],
description="Extract pose keypoints from images using the SDPose model: https://huggingface.co/Comfy-Org/SDPose/tree/main/checkpoints",
inputs=[
io.Model.Input("model"),
io.Vae.Input("vae"),
io.Image.Input("image"),
io.Int.Input("batch_size", default=16, min=1, max=10000, step=1),
io.BoundingBox.Input("bboxes", optional=True, force_input=True, tooltip="Optional bounding boxes for more accurate detections. Required for multi-person detection."),
],
outputs=[
io.Custom("POSE_KEYPOINT").Output("keypoints", tooltip="Keypoints in OpenPose frame format (canvas_width, canvas_height, people)"),
],
)
@classmethod
def execute(cls, model, vae, image, batch_size, bboxes=None) -> io.NodeOutput:
height, width = image.shape[-3], image.shape[-2]
context = LotusConditioning().execute().result[0]
# Use output_block_patch to capture the last 640-channel feature
def output_patch(h, hsp, transformer_options):
nonlocal captured_feat
if h.shape[1] == 640: # Capture the features for wholebody
captured_feat = h.clone()
return h, hsp
model_clone = model.clone()
model_clone.model_options["transformer_options"] = {"patches": {"output_block_patch": [output_patch]}}
if not hasattr(model.model.diffusion_model, 'heatmap_head'):
raise ValueError("The provided model does not have a heatmap_head. Please use SDPose model from here https://huggingface.co/Comfy-Org/SDPose/tree/main/checkpoints.")
head = model.model.diffusion_model.heatmap_head
total_images = image.shape[0]
captured_feat = None
model_h = int(head.heatmap_size[0]) * 4 # e.g. 192 * 4 = 768
model_w = int(head.heatmap_size[1]) * 4 # e.g. 256 * 4 = 1024
def _run_on_latent(latent_batch):
"""Run one forward pass and return (keypoints_list, scores_list) for the batch."""
nonlocal captured_feat
captured_feat = None
_ = comfy.sample.sample(
model_clone,
noise=torch.zeros_like(latent_batch),
steps=1, cfg=1.0,
sampler_name="euler", scheduler="simple",
positive=context, negative=context,
latent_image=latent_batch, disable_noise=True, disable_pbar=True,
)
return head(captured_feat) # keypoints_batch, scores_batch
# all_keypoints / all_scores are lists-of-lists:
# outer index = input image index
# inner index = detected person (one per bbox, or one for full-image)
all_keypoints = [] # shape: [n_images][n_persons]
all_scores = [] # shape: [n_images][n_persons]
pbar = comfy.utils.ProgressBar(total_images)
if bboxes is not None:
if not isinstance(bboxes, list):
bboxes = [[bboxes]]
elif len(bboxes) == 0:
bboxes = [None] * total_images
# --- bbox-crop mode: one forward pass per crop -------------------------
for img_idx in tqdm(range(total_images), desc="Extracting keypoints from crops"):
img = image[img_idx:img_idx + 1] # (1, H, W, C)
# Broadcasting: if fewer bbox lists than images, repeat the last one.
img_bboxes = bboxes[min(img_idx, len(bboxes) - 1)] if bboxes else None
img_keypoints = []
img_scores = []
if img_bboxes:
for bbox in img_bboxes:
x1 = max(0, int(bbox["x"]))
y1 = max(0, int(bbox["y"]))
x2 = min(width, int(bbox["x"] + bbox["width"]))
y2 = min(height, int(bbox["y"] + bbox["height"]))
if x2 <= x1 or y2 <= y1:
continue
crop_h_px, crop_w_px = y2 - y1, x2 - x1
crop = img[:, y1:y2, x1:x2, :] # (1, crop_h, crop_w, C)
# scale to fit inside (model_h, model_w) while preserving aspect ratio, then pad to exact model size.
scale = min(model_h / crop_h_px, model_w / crop_w_px)
scaled_h, scaled_w = int(round(crop_h_px * scale)), int(round(crop_w_px * scale))
pad_top, pad_left = (model_h - scaled_h) // 2, (model_w - scaled_w) // 2
crop_chw = crop.permute(0, 3, 1, 2).float() # BHWC → BCHW
scaled = comfy.utils.common_upscale(crop_chw, scaled_w, scaled_h, upscale_method="bilinear", crop="disabled")
padded = torch.zeros(1, scaled.shape[1], model_h, model_w, dtype=scaled.dtype, device=scaled.device)
padded[:, :, pad_top:pad_top + scaled_h, pad_left:pad_left + scaled_w] = scaled
crop_resized = padded.permute(0, 2, 3, 1) # BCHW → BHWC
latent_crop = vae.encode(crop_resized)
kp_batch, sc_batch = _run_on_latent(latent_crop)
kp, sc = kp_batch[0], sc_batch[0] # (K, 2), coords in model pixel space
# remove padding offset, undo scale, offset to full-image coordinates.
kp = kp.copy() if isinstance(kp, np.ndarray) else np.array(kp, dtype=np.float32)
kp[..., 0] = (kp[..., 0] - pad_left) / scale + x1
kp[..., 1] = (kp[..., 1] - pad_top) / scale + y1
img_keypoints.append(kp)
img_scores.append(sc)
else:
# No bboxes for this image run on the full image
latent_img = vae.encode(img)
kp_batch, sc_batch = _run_on_latent(latent_img)
img_keypoints.append(kp_batch[0])
img_scores.append(sc_batch[0])
all_keypoints.append(img_keypoints)
all_scores.append(img_scores)
pbar.update(1)
else: # full-image mode, batched
tqdm_pbar = tqdm(total=total_images, desc="Extracting keypoints")
for batch_start in range(0, total_images, batch_size):
batch_end = min(batch_start + batch_size, total_images)
latent_batch = vae.encode(image[batch_start:batch_end])
kp_batch, sc_batch = _run_on_latent(latent_batch)
for kp, sc in zip(kp_batch, sc_batch):
all_keypoints.append([kp])
all_scores.append([sc])
tqdm_pbar.update(1)
pbar.update(batch_end - batch_start)
openpose_frames = _to_openpose_frames(all_keypoints, all_scores, height, width)
return io.NodeOutput(openpose_frames)
def get_face_bboxes(kp2ds, scale, image_shape):
h, w = image_shape
kp2ds_face = kp2ds.copy()[1:] * (w, h)
min_x, min_y = np.min(kp2ds_face, axis=0)
max_x, max_y = np.max(kp2ds_face, axis=0)
initial_width = max_x - min_x
initial_height = max_y - min_y
if initial_width <= 0 or initial_height <= 0:
return [0, 0, 0, 0]
initial_area = initial_width * initial_height
expanded_area = initial_area * scale
new_width = np.sqrt(expanded_area * (initial_width / initial_height))
new_height = np.sqrt(expanded_area * (initial_height / initial_width))
delta_width = (new_width - initial_width) / 2
delta_height = (new_height - initial_height) / 4
expanded_min_x = max(min_x - delta_width, 0)
expanded_max_x = min(max_x + delta_width, w)
expanded_min_y = max(min_y - 3 * delta_height, 0)
expanded_max_y = min(max_y + delta_height, h)
return [int(expanded_min_x), int(expanded_max_x), int(expanded_min_y), int(expanded_max_y)]
class SDPoseFaceBBoxes(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="SDPoseFaceBBoxes",
category="image/preprocessors",
search_aliases=["face bbox", "face bounding box", "pose", "keypoints"],
inputs=[
io.Custom("POSE_KEYPOINT").Input("keypoints"),
io.Float.Input("scale", default=1.5, min=1.0, max=10.0, step=0.1, tooltip="Multiplier for the bounding box area around each detected face."),
io.Boolean.Input("force_square", default=True, tooltip="Expand the shorter bbox axis so the crop region is always square."),
],
outputs=[
io.BoundingBox.Output("bboxes", tooltip="Face bounding boxes per frame, compatible with SDPoseKeypointExtractor bboxes input."),
],
)
@classmethod
def execute(cls, keypoints, scale, force_square) -> io.NodeOutput:
all_bboxes = []
for frame in keypoints:
h = frame["canvas_height"]
w = frame["canvas_width"]
frame_bboxes = []
for person in frame["people"]:
face_flat = person.get("face_keypoints_2d", [])
if not face_flat:
continue
# Parse absolute-pixel face keypoints (70 kp: 68 landmarks + REye + LEye)
face_arr = np.array(face_flat, dtype=np.float32).reshape(-1, 3)
face_xy = face_arr[:, :2] # (70, 2) in absolute pixels
kp_norm = face_xy / np.array([w, h], dtype=np.float32)
kp_padded = np.vstack([np.zeros((1, 2), dtype=np.float32), kp_norm]) # (71, 2)
x1, x2, y1, y2 = get_face_bboxes(kp_padded, scale, (h, w))
if x2 > x1 and y2 > y1:
if force_square:
bw, bh = x2 - x1, y2 - y1
if bw != bh:
side = max(bw, bh)
cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
half = side // 2
x1 = max(0, cx - half)
y1 = max(0, cy - half)
x2 = min(w, x1 + side)
y2 = min(h, y1 + side)
# Re-anchor if clamped
x1 = max(0, x2 - side)
y1 = max(0, y2 - side)
frame_bboxes.append({"x": x1, "y": y1, "width": x2 - x1, "height": y2 - y1})
all_bboxes.append(frame_bboxes)
return io.NodeOutput(all_bboxes)
class CropByBBoxes(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="CropByBBoxes",
category="image/preprocessors",
search_aliases=["crop", "face crop", "bbox crop", "pose", "bounding box"],
description="Crop and resize regions from the input image batch based on provided bounding boxes.",
inputs=[
io.Image.Input("image"),
io.BoundingBox.Input("bboxes", force_input=True),
io.Int.Input("output_width", default=512, min=64, max=4096, step=8, tooltip="Width each crop is resized to."),
io.Int.Input("output_height", default=512, min=64, max=4096, step=8, tooltip="Height each crop is resized to."),
io.Int.Input("padding", default=0, min=0, max=1024, step=1, tooltip="Extra padding in pixels added on each side of the bbox before cropping."),
],
outputs=[
io.Image.Output(tooltip="All crops stacked into a single image batch."),
],
)
@classmethod
def execute(cls, image, bboxes, output_width, output_height, padding) -> io.NodeOutput:
total_frames = image.shape[0]
img_h = image.shape[1]
img_w = image.shape[2]
num_ch = image.shape[3]
if not isinstance(bboxes, list):
bboxes = [[bboxes]]
elif len(bboxes) == 0:
return io.NodeOutput(image)
crops = []
for frame_idx in range(total_frames):
frame_bboxes = bboxes[min(frame_idx, len(bboxes) - 1)]
if not frame_bboxes:
continue
frame_chw = image[frame_idx].permute(2, 0, 1).unsqueeze(0) # BHWC → BCHW (1, C, H, W)
# Union all bboxes for this frame into a single crop region
x1 = min(int(b["x"]) for b in frame_bboxes)
y1 = min(int(b["y"]) for b in frame_bboxes)
x2 = max(int(b["x"] + b["width"]) for b in frame_bboxes)
y2 = max(int(b["y"] + b["height"]) for b in frame_bboxes)
if padding > 0:
x1 = max(0, x1 - padding)
y1 = max(0, y1 - padding)
x2 = min(img_w, x2 + padding)
y2 = min(img_h, y2 + padding)
x1, x2 = max(0, x1), min(img_w, x2)
y1, y2 = max(0, y1), min(img_h, y2)
# Fallback for empty/degenerate crops
if x2 <= x1 or y2 <= y1:
fallback_size = int(min(img_h, img_w) * 0.3)
fb_x1 = max(0, (img_w - fallback_size) // 2)
fb_y1 = max(0, int(img_h * 0.1))
fb_x2 = min(img_w, fb_x1 + fallback_size)
fb_y2 = min(img_h, fb_y1 + fallback_size)
if fb_x2 <= fb_x1 or fb_y2 <= fb_y1:
crops.append(torch.zeros(1, num_ch, output_height, output_width, dtype=image.dtype, device=image.device))
continue
x1, y1, x2, y2 = fb_x1, fb_y1, fb_x2, fb_y2
crop_chw = frame_chw[:, :, y1:y2, x1:x2] # (1, C, crop_h, crop_w)
resized = comfy.utils.common_upscale(crop_chw, output_width, output_height, upscale_method="bilinear", crop="disabled")
crops.append(resized)
if not crops:
return io.NodeOutput(image)
out_images = torch.cat(crops, dim=0).permute(0, 2, 3, 1) # (N, H, W, C)
return io.NodeOutput(out_images)
class SDPoseExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
SDPoseKeypointExtractor,
SDPoseDrawKeypoints,
SDPoseFaceBBoxes,
CropByBBoxes,
]
async def comfy_entrypoint() -> SDPoseExtension:
return SDPoseExtension()
+1 -1
View File
@@ -25,7 +25,7 @@ class TorchCompileModel(io.ComfyNode):
@classmethod
def execute(cls, model, backend) -> io.NodeOutput:
m = model.clone()
m = model.clone(disable_dynamic=True)
set_torch_compile_wrapper(model=m, backend=backend, options={"guard_filter_fn": skip_torch_compile_dict})
return io.NodeOutput(m)
+1 -1
View File
@@ -147,7 +147,6 @@ class GetVideoComponents(io.ComfyNode):
search_aliases=["extract frames", "split video", "video to images", "demux"],
display_name="Get Video Components",
category="image/video",
essentials_category="Video Tools",
description="Extracts all components from a video: frames, audio, and framerate.",
inputs=[
io.Video.Input("video", tooltip="The video to extract components from."),
@@ -218,6 +217,7 @@ class VideoSlice(io.ComfyNode):
"start time",
],
category="image/video",
essentials_category="Video Tools",
inputs=[
io.Video.Input("video"),
io.Float.Input(
+58
View File
@@ -1456,6 +1456,63 @@ class WanInfiniteTalkToVideo(io.ComfyNode):
return io.NodeOutput(model_patched, positive, negative, out_latent, trim_image)
class WanSCAILToVideo(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="WanSCAILToVideo",
category="conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Vae.Input("vae"),
io.Int.Input("width", default=512, min=32, max=nodes.MAX_RESOLUTION, step=32),
io.Int.Input("height", default=896, min=32, max=nodes.MAX_RESOLUTION, step=32),
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.Int.Input("batch_size", default=1, min=1, max=4096),
io.ClipVisionOutput.Input("clip_vision_output", optional=True),
io.Image.Input("reference_image", optional=True),
io.Image.Input("pose_video", optional=True, tooltip="Video used for pose conditioning. Will be downscaled to half the resolution of the main video."),
io.Float.Input("pose_strength", default=1.0, min=0.0, max=10.0, step=0.01, tooltip="Strength of the pose latent."),
io.Float.Input("pose_start", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="Start step to use pose conditioning."),
io.Float.Input("pose_end", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="End step to use pose conditioning."),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent", tooltip="Empty latent of the generation size."),
],
is_experimental=True,
)
@classmethod
def execute(cls, positive, negative, vae, width, height, length, batch_size, pose_strength, pose_start, pose_end, reference_image=None, clip_vision_output=None, pose_video=None) -> io.NodeOutput:
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
ref_latent = None
if reference_image is not None:
reference_image = comfy.utils.common_upscale(reference_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
ref_latent = vae.encode(reference_image[:, :, :, :3])
if ref_latent is not None:
positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True)
negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [torch.zeros_like(ref_latent)]}, append=True)
if clip_vision_output is not None:
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
if pose_video is not None:
pose_video = comfy.utils.common_upscale(pose_video[:length].movedim(-1, 1), width // 2, height // 2, "area", "center").movedim(1, -1)
pose_video_latent = vae.encode(pose_video[:, :, :, :3]) * pose_strength
positive = node_helpers.conditioning_set_values_with_timestep_range(positive, {"pose_video_latent": pose_video_latent}, pose_start, pose_end)
negative = node_helpers.conditioning_set_values_with_timestep_range(negative, {"pose_video_latent": pose_video_latent}, pose_start, pose_end)
out_latent = {}
out_latent["samples"] = latent
return io.NodeOutput(positive, negative, out_latent)
class WanExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
@@ -1476,6 +1533,7 @@ class WanExtension(ComfyExtension):
WanAnimateToVideo,
Wan22ImageToVideoLatent,
WanInfiniteTalkToVideo,
WanSCAILToVideo,
]
async def comfy_entrypoint() -> WanExtension:
+1 -1
View File
@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.15.0"
__version__ = "0.16.2"
+8 -6
View File
@@ -876,12 +876,14 @@ async def validate_inputs(prompt_id, prompt, item, validated):
continue
else:
try:
# Unwraps values wrapped in __value__ key. This is used to pass
# list widget value to execution, as by default list value is
# reserved to represent the connection between nodes.
if isinstance(val, dict) and "__value__" in val:
val = val["__value__"]
inputs[x] = val
# Unwraps values wrapped in __value__ key or typed wrapper.
# This is used to pass list widget values to execution,
# as by default list value is reserved to represent the
# connection between nodes.
if isinstance(val, dict):
if "__value__" in val:
val = val["__value__"]
inputs[x] = val
if input_type == "INT":
val = int(val)
+6 -6
View File
@@ -16,7 +16,6 @@ from comfy_execution.progress import get_progress_state
from comfy_execution.utils import get_executing_context
from comfy_api import feature_flags
if __name__ == "__main__":
#NOTE: These do not do anything on core ComfyUI, they are for custom nodes.
os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
@@ -24,6 +23,11 @@ if __name__ == "__main__":
setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)
import comfy_aimdo.control
if enables_dynamic_vram():
comfy_aimdo.control.init()
if os.name == "nt":
os.environ['MIMALLOC_PURGE_DELAY'] = '0'
@@ -173,10 +177,6 @@ import gc
if 'torch' in sys.modules:
logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.")
import comfy_aimdo.control
if enables_dynamic_vram():
comfy_aimdo.control.init()
import comfy.utils
@@ -192,7 +192,7 @@ import hook_breaker_ac10a0
import comfy.memory_management
import comfy.model_patcher
if enables_dynamic_vram():
if enables_dynamic_vram() and comfy.model_management.is_nvidia() and not comfy.model_management.is_wsl():
if comfy.model_management.torch_version_numeric < (2, 8):
logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows")
elif comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index):
+31
View File
@@ -1,5 +1,6 @@
import hashlib
import torch
import logging
from comfy.cli_args import args
@@ -21,6 +22,36 @@ def conditioning_set_values(conditioning, values={}, append=False):
return c
def conditioning_set_values_with_timestep_range(conditioning, values={}, start_percent=0.0, end_percent=1.0):
"""
Apply values to conditioning only during [start_percent, end_percent], keeping the
original conditioning active outside that range. Respects existing per-entry ranges.
"""
if start_percent > end_percent:
logging.warning(f"start_percent ({start_percent}) must be <= end_percent ({end_percent})")
return conditioning
EPS = 1e-5 # the sampler gates entries with strict > / <, shift boundaries slightly to ensure only one conditioning is active per timestep
c = []
for t in conditioning:
cond_start = t[1].get("start_percent", 0.0)
cond_end = t[1].get("end_percent", 1.0)
intersect_start = max(start_percent, cond_start)
intersect_end = min(end_percent, cond_end)
if intersect_start >= intersect_end: # no overlap: emit unchanged
c.append(t)
continue
if intersect_start > cond_start: # part before the requested range
c.extend(conditioning_set_values([t], {"start_percent": cond_start, "end_percent": intersect_start - EPS}))
c.extend(conditioning_set_values([t], {**values, "start_percent": intersect_start, "end_percent": intersect_end}))
if intersect_end < cond_end: # part after the requested range
c.extend(conditioning_set_values([t], {"start_percent": intersect_end + EPS, "end_percent": cond_end}))
return c
def pillow(fn, arg):
prev_value = None
try:
+3 -2
View File
@@ -976,7 +976,7 @@ class CLIPLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis"], ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image"], ),
},
"optional": {
"device": (["default", "cpu"], {"advanced": True}),
@@ -1925,7 +1925,6 @@ class ImageInvert:
class ImageBatch:
SEARCH_ALIASES = ["combine images", "merge images", "stack images"]
ESSENTIALS_CATEGORY = "Image Tools"
@classmethod
def INPUT_TYPES(s):
@@ -2436,6 +2435,7 @@ async def init_builtin_extra_nodes():
"nodes_audio_encoder.py",
"nodes_rope.py",
"nodes_logic.py",
"nodes_resolution.py",
"nodes_nop.py",
"nodes_kandinsky5.py",
"nodes_wanmove.py",
@@ -2448,6 +2448,7 @@ async def init_builtin_extra_nodes():
"nodes_toolkit.py",
"nodes_replacements.py",
"nodes_nag.py",
"nodes_sdpose.py",
]
import_failed = []
+1 -1
View File
@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.15.0"
version = "0.16.2"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.10"
+3 -4
View File
@@ -1,5 +1,5 @@
comfyui-frontend-package==1.39.16
comfyui-workflow-templates==0.9.3
comfyui-frontend-package==1.39.19
comfyui-workflow-templates==0.9.10
comfyui-embedded-docs==0.4.3
torch
torchsde
@@ -22,7 +22,7 @@ alembic
SQLAlchemy
av>=14.2.0
comfy-kitchen>=0.2.7
comfy-aimdo>=0.2.0
comfy-aimdo>=0.2.7
requests
#non essential dependencies:
@@ -31,5 +31,4 @@ spandrel
pydantic~=2.0
pydantic-settings~=2.0
PyOpenGL
PyOpenGL-accelerate
glfw
@@ -49,6 +49,12 @@ def mock_provider(mock_releases):
return provider
@pytest.fixture(autouse=True)
def clear_cache():
import utils.install_util
utils.install_util.PACKAGE_VERSIONS = {}
def test_get_release(mock_provider, mock_releases):
version = "1.0.0"
release = mock_provider.get_release(version)
@@ -0,0 +1,112 @@
import torch
from comfy.model_detection import detect_unet_config, model_config_from_unet_config
import comfy.supported_models
def _make_longcat_comfyui_sd():
"""Minimal ComfyUI-format state dict for pre-converted LongCat-Image weights."""
sd = {}
H = 32 # Reduce hidden state dimension to reduce memory usage
C_IN = 16
C_CTX = 3584
sd["img_in.weight"] = torch.empty(H, C_IN * 4)
sd["img_in.bias"] = torch.empty(H)
sd["txt_in.weight"] = torch.empty(H, C_CTX)
sd["txt_in.bias"] = torch.empty(H)
sd["time_in.in_layer.weight"] = torch.empty(H, 256)
sd["time_in.in_layer.bias"] = torch.empty(H)
sd["time_in.out_layer.weight"] = torch.empty(H, H)
sd["time_in.out_layer.bias"] = torch.empty(H)
sd["final_layer.adaLN_modulation.1.weight"] = torch.empty(2 * H, H)
sd["final_layer.adaLN_modulation.1.bias"] = torch.empty(2 * H)
sd["final_layer.linear.weight"] = torch.empty(C_IN * 4, H)
sd["final_layer.linear.bias"] = torch.empty(C_IN * 4)
for i in range(19):
sd[f"double_blocks.{i}.img_attn.norm.key_norm.weight"] = torch.empty(128)
sd[f"double_blocks.{i}.img_attn.qkv.weight"] = torch.empty(3 * H, H)
sd[f"double_blocks.{i}.img_mod.lin.weight"] = torch.empty(H, H)
for i in range(38):
sd[f"single_blocks.{i}.modulation.lin.weight"] = torch.empty(H, H)
return sd
def _make_flux_schnell_comfyui_sd():
"""Minimal ComfyUI-format state dict for standard Flux Schnell."""
sd = {}
H = 32 # Reduce hidden state dimension to reduce memory usage
C_IN = 16
sd["img_in.weight"] = torch.empty(H, C_IN * 4)
sd["img_in.bias"] = torch.empty(H)
sd["txt_in.weight"] = torch.empty(H, 4096)
sd["txt_in.bias"] = torch.empty(H)
sd["double_blocks.0.img_attn.norm.key_norm.weight"] = torch.empty(128)
sd["double_blocks.0.img_attn.qkv.weight"] = torch.empty(3 * H, H)
sd["double_blocks.0.img_mod.lin.weight"] = torch.empty(H, H)
for i in range(19):
sd[f"double_blocks.{i}.img_attn.norm.key_norm.weight"] = torch.empty(128)
for i in range(38):
sd[f"single_blocks.{i}.modulation.lin.weight"] = torch.empty(H, H)
return sd
class TestModelDetection:
"""Verify that first-match model detection selects the correct model
based on list ordering and unet_config specificity."""
def test_longcat_before_schnell_in_models_list(self):
"""LongCatImage must appear before FluxSchnell in the models list."""
models = comfy.supported_models.models
longcat_idx = next(i for i, m in enumerate(models) if m.__name__ == "LongCatImage")
schnell_idx = next(i for i, m in enumerate(models) if m.__name__ == "FluxSchnell")
assert longcat_idx < schnell_idx, (
f"LongCatImage (index {longcat_idx}) must come before "
f"FluxSchnell (index {schnell_idx}) in the models list"
)
def test_longcat_comfyui_detected_as_longcat(self):
sd = _make_longcat_comfyui_sd()
unet_config = detect_unet_config(sd, "")
assert unet_config is not None
assert unet_config["image_model"] == "flux"
assert unet_config["context_in_dim"] == 3584
assert unet_config["vec_in_dim"] is None
assert unet_config["guidance_embed"] is False
assert unet_config["txt_ids_dims"] == [1, 2]
model_config = model_config_from_unet_config(unet_config, sd)
assert model_config is not None
assert type(model_config).__name__ == "LongCatImage"
def test_longcat_comfyui_keys_pass_through_unchanged(self):
"""Pre-converted weights should not be transformed by process_unet_state_dict."""
sd = _make_longcat_comfyui_sd()
unet_config = detect_unet_config(sd, "")
model_config = model_config_from_unet_config(unet_config, sd)
processed = model_config.process_unet_state_dict(dict(sd))
assert "img_in.weight" in processed
assert "txt_in.weight" in processed
assert "time_in.in_layer.weight" in processed
assert "final_layer.linear.weight" in processed
def test_flux_schnell_comfyui_detected_as_flux_schnell(self):
sd = _make_flux_schnell_comfyui_sd()
unet_config = detect_unet_config(sd, "")
assert unet_config is not None
assert unet_config["image_model"] == "flux"
assert unet_config["context_in_dim"] == 4096
assert unet_config["txt_ids_dims"] == []
model_config = model_config_from_unet_config(unet_config, sd)
assert model_config is not None
assert type(model_config).__name__ == "FluxSchnell"
+3 -3
View File
@@ -38,13 +38,13 @@ class TestIsPreviewable:
"""Unit tests for is_previewable()"""
def test_previewable_media_types(self):
"""Images, video, audio, 3d media types should be previewable."""
for media_type in ['images', 'video', 'audio', '3d']:
"""Images, video, audio, 3d, text media types should be previewable."""
for media_type in ['images', 'video', 'audio', '3d', 'text']:
assert is_previewable(media_type, {}) is True
def test_non_previewable_media_types(self):
"""Other media types should not be previewable."""
for media_type in ['latents', 'text', 'metadata', 'files']:
for media_type in ['latents', 'metadata', 'files']:
assert is_previewable(media_type, {}) is False
def test_3d_extensions_previewable(self):
+33
View File
@@ -1,5 +1,7 @@
from pathlib import Path
import sys
import logging
import re
# The path to the requirements.txt file
requirements_path = Path(__file__).parents[1] / "requirements.txt"
@@ -16,3 +18,34 @@ Please install the updated requirements.txt file by running:
{sys.executable} {extra}-m pip install -r {requirements_path}
If you are on the portable package you can run: update\\update_comfyui.bat to solve this problem.
""".strip()
def is_valid_version(version: str) -> bool:
"""Validate if a string is a valid semantic version (X.Y.Z format)."""
pattern = r"^(\d+)\.(\d+)\.(\d+)$"
return bool(re.match(pattern, version))
PACKAGE_VERSIONS = {}
def get_required_packages_versions():
if len(PACKAGE_VERSIONS) > 0:
return PACKAGE_VERSIONS.copy()
out = PACKAGE_VERSIONS
try:
with open(requirements_path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip().replace(">=", "==")
s = line.split("==")
if len(s) == 2:
version_str = s[-1]
if not is_valid_version(version_str):
logging.error(f"Invalid version format in requirements.txt: {version_str}")
continue
out[s[0]] = version_str
return out.copy()
except FileNotFoundError:
logging.error("requirements.txt not found.")
return None
except Exception as e:
logging.error(f"Error reading requirements.txt: {e}")
return None