Compare commits

...

90 Commits

Author SHA1 Message Date
comfyanonymous ee9547ba31 Improve temporal VAE Encode (Tiled) math.
Python Linting / Run Ruff (push) Failing after 41s
Build package / Build Test (3.10) (push) Failing after 34s
Build package / Build Test (3.11) (push) Failing after 31s
Build package / Build Test (3.8) (push) Failing after 31s
Build package / Build Test (3.9) (push) Failing after 38s
2024-12-26 07:18:49 -05:00
comfyanonymous 19a64d6291 Cleanup some mac related code. 2024-12-25 05:32:51 -05:00
comfyanonymous b486885e08 Disable bfloat16 on older mac. 2024-12-25 05:18:50 -05:00
comfyanonymous 0229228f3f Clean up the VAE dtypes code. 2024-12-25 04:50:34 -05:00
comfyanonymous 1ed75ab30e Update nightly pytorch instructions in readme for nvidia. 2024-12-25 03:29:03 -05:00
comfyanonymous 99a1fb6027 Make fast fp8 take a bit less peak memory. 2024-12-24 18:05:19 -05:00
comfyanonymous 73e04987f7 Prevent black images in VAE Decode (Tiled) node.
Overlap should be minimum 1 with tiling 2 for tiled temporal VAE decoding.
2024-12-24 07:36:30 -05:00
comfyanonymous 5388df784a Add temporal tiling to VAE Encode (Tiled) node. 2024-12-24 07:10:09 -05:00
Alexander Piskun 26e0ba8f8c Enable External Event Loop Integration for ComfyUI [refactor] (#6114)
* Refactor main.py to support external event loop integration

* added optional "asyncio_loop" argument to allow using existing event loop

---------

Signed-off-by: bigcat88 <bigcat88@icloud.com>
2024-12-24 06:38:52 -05:00
comfyanonymous bc6dac4327 Add temporal tiling to VAE Decode (Tiled) node.
You can now do tiled VAE decoding on the temporal direction for videos.
2024-12-23 20:03:37 -05:00
Chenlei Hu f18ebbd316 Use raw dir name to serve static web content (#6107) 2024-12-23 03:29:42 -05:00
comfyanonymous 15564688ed Add a try except block so if torch version is weird it won't crash. 2024-12-23 03:22:48 -05:00
Simon Lui c6b9c11ef6 Add oneAPI device selector for xpu and some other changes. (#6112)
* Add oneAPI device selector and some other minor changes.

* Fix device selector variable name.

* Flip minor version check sign.

* Undo changes to README.md.
2024-12-23 03:18:32 -05:00
comfyanonymous e44d0ac7f7 Make --novram completely offload weights.
This flag is mainly used for testing the weight offloading, it shouldn't
actually be used in practice.

Remove useless import.
2024-12-23 01:51:08 -05:00
comfyanonymous 56bc64f351 Comment out some useless code. 2024-12-22 23:51:14 -05:00
zhangp365 f7d83b72e0 fixed a bug in ldm/pixart/blocks.py (#6158) 2024-12-22 23:44:20 -05:00
comfyanonymous 80f07952d2 Fix lowvram issue with ltxv vae. 2024-12-22 23:20:17 -05:00
comfyanonymous 57f330caf9 Relax minimum ratio of weights loaded in memory on nvidia.
This should make it possible to do higher res images/longer videos by
further offloading weights to CPU memory.

Please report an issue if this slows down things on your system.
2024-12-22 03:06:37 -05:00
comfyanonymous 601ff9e3db Add that Hunyuan Video and Pixart are supported to readme.
Clean up the supported models part of the readme.
2024-12-21 11:31:39 -05:00
TechnoByte 341667c4d5 remove minimum step count for AYS (#6137)
The 10 step minimum for the AYS scheduler is pointless, it works well at lower steps, like 8 steps, or even 4 steps.

For example with LCM or DMD2.

Example here: https://i.ibb.co/56CSPMj/image.png
2024-12-21 10:05:09 -05:00
Qiacheng Li 1419dee915 Update README.md for Intel GPUs (#6069) 2024-12-20 18:04:03 -05:00
comfyanonymous da13b6b827 Get rid of meshgrid warning. 2024-12-20 18:02:12 -05:00
comfyanonymous c86cd58573 Remove useless code. 2024-12-20 17:50:03 -05:00
comfyanonymous b5fe39211a Remove some useless code. 2024-12-20 17:43:50 -05:00
comfyanonymous e946667216 Some fixes/cleanups to pixart code.
Commented out the masking related code because it is never used in this
implementation.
2024-12-20 17:10:52 -05:00
Chenlei Hu d7969cb070 Replace print with logging (#6138)
* Replace print with logging

* nit

* nit

* nit

* nit

* nit

* nit
2024-12-20 16:24:55 -05:00
City bddb02660c Add PixArt model support (#6055)
* PixArt initial version

* PixArt Diffusers convert logic

* pos_emb and interpolation logic

* Reduce  duplicate code

* Formatting

* Use optimized attention

* Edit empty token logic

* Basic PixArt LoRA support

* Fix aspect ratio logic

* PixArtAlpha text encode with conds

* Use same detection key logic for PixArt diffusers
2024-12-20 15:25:00 -05:00
comfyanonymous 418eb7062d Support new LTXV VAE.
Python Linting / Run Ruff (push) Failing after 37s
Build package / Build Test (3.10) (push) Failing after 26s
Build package / Build Test (3.11) (push) Failing after 34s
Build package / Build Test (3.8) (push) Failing after 33s
Build package / Build Test (3.9) (push) Failing after 42s
2024-12-20 04:38:29 -05:00
comfyanonymous cac68ca813 Fix some more video tiled encode issues.
The downscale_ratio formula for the temporal had issues with some frame
numbers.
2024-12-19 23:14:03 -05:00
comfyanonymous 52c1d933b2 Fix tiled hunyuan video VAE encode issue.
Some shapes like 1024x1024 with tile_size 256 and overlap 64 had issues.
2024-12-19 22:55:15 -05:00
catboxanon 3cacd3fca5 Support preview images embedded in safetensors metadata (#6119)
* Support preview images embedded in safetensors metadata

* Add unit test for safetensors embedded image previews
2024-12-19 14:01:56 -08:00
comfyanonymous 2dda7c11a3 More proper fix for the memory issue. 2024-12-19 16:21:56 -05:00
comfyanonymous 3ad3248ad7 Fix lowvram bug when using a model multiple times in a row.
The memory system would load an extra 64MB each time until either the
model was completely in memory or OOM.
2024-12-19 16:04:56 -05:00
comfyanonymous c441048a4f Make VAE Encode tiled node work with video VAE. 2024-12-19 05:31:39 -05:00
comfyanonymous 9f4b181ab3 Add fast previews for hunyuan video.
Python Linting / Run Ruff (push) Failing after 33s
Build package / Build Test (3.10) (push) Failing after 33s
Build package / Build Test (3.11) (push) Failing after 31s
Build package / Build Test (3.8) (push) Failing after 33s
Build package / Build Test (3.9) (push) Failing after 31s
2024-12-18 18:24:23 -05:00
comfyanonymous cbbf077593 Small optimizations. 2024-12-18 18:23:28 -05:00
Chenlei Hu 0c04a6ae78 Add .github folder to maintainer owner list (#6027) 2024-12-18 15:06:53 -05:00
Chenlei Hu 416ccc9e45 Update web content to release v1.5.19 (#6105) 2024-12-18 15:06:20 -05:00
comfyanonymous ff2ff02168 Support old diffusion-pipe hunyuan video loras. 2024-12-18 06:23:54 -05:00
comfyanonymous 4c5c4ddeda Fix regression in VAE code on old pytorch versions. 2024-12-18 03:08:28 -05:00
comfyanonymous 79badea452 Add ConditioningStableAudio.
This lets you control the seconds_start and seconds_total parameters for
the Stable Audio model.
2024-12-18 03:01:12 -05:00
comfyanonymous 37e5390f5f Add: --use-sage-attention to enable SageAttention.
You need to have the library installed first.
2024-12-18 01:56:10 -05:00
comfyanonymous a4f59bc65e Pick attention implementation based on device in llama code. 2024-12-18 01:30:20 -05:00
comfyanonymous ca457f7ba1 Properly tokenize the template for hunyuan video. 2024-12-17 16:22:02 -05:00
comfyanonymous cd6f615038 Fix tiled vae not working with some shapes. 2024-12-17 16:22:02 -05:00
Terry Jia 517669aaa3 add preview 3d node (#6070)
* add preview 3d node

* mark 3d nodes as EXPERIMENTAL
2024-12-17 10:42:24 -08:00
comfyanonymous e4e1bff605 Support diffusion-pipe hunyuan video lora format. 2024-12-17 07:14:21 -05:00
comfyanonymous d6656b0c0c Support llama hunyuan video text encoder in scaled fp8 format. 2024-12-17 04:19:22 -05:00
comfyanonymous f4cdedea62 Fix regression with ltxv VAE. 2024-12-17 02:17:31 -05:00
comfyanonymous 39b1fc4ccc Adjust used dtypes for hunyuan video VAE and diffusion model. 2024-12-16 23:31:10 -05:00
comfyanonymous 0b25f47bd9 Add some missing imports. 2024-12-16 19:42:01 -05:00
comfyanonymous bda1482a27 Basic Hunyuan Video model support. 2024-12-16 19:35:40 -05:00
comfyanonymous 19ee5d9d8b Don't expand mask when not necessary.
Expanding seems to slow down inference.
2024-12-16 18:22:50 -05:00
Raphael Walker 61b50720d0 Add support for attention masking in Flux (#5942)
* fix attention OOM in xformers

* allow passing attention mask in flux attention

* allow an attn_mask in flux

* attn masks can be done using replace patches instead of a separate dict

* fix return types

* fix return order

* enumerate

* patch the right keys

* arg names

* fix a silly bug

* fix xformers masks

* replace match with if, elif, else

* mask with image_ref_size

* remove unused import

* remove unused import 2

* fix pytorch/xformers attention

This corrects a weird inconsistency with skip_reshape.
It also allows masks of various shapes to be passed, which will be
automtically expanded (in a memory-efficient way) to a size that is
compatible with xformers or pytorch sdpa respectively.

* fix mask shapes
2024-12-16 18:21:17 -05:00
Alexander Dyadyun 0f954f34af Update README.md (#6071)
The last ROCM 6.2 build was November 22nd, after that date new builds use ROCM 6.2.4.

The builds from the new URL have been tested and work without problems.
2024-12-16 15:24:54 -05:00
Chenlei Hu 5262901c5c Update web content to release v1.5.18 (#6075) 2024-12-16 11:38:24 -08:00
Terry Jia cc550d5908 use String directly to set bg color for load 3d canvas (#6057) 2024-12-16 10:51:40 -08:00
comfyanonymous 6d1a3f7d00 Fix case of ExecutionBlocker not handled correctly with INPUT_IS_LIST. 2024-12-15 08:41:35 -05:00
Alexander Piskun 1b3a650f19 (fix): added "model_type" to photomaker node (#6047) 2024-12-15 00:18:02 -05:00
comfyanonymous e83063bf24 Support conv3d in PatchEmbed. 2024-12-14 05:46:04 -05:00
Dr.Lt.Data 558b7d8b22 fix: prestartup script is not applied due to extra_model_paths.yaml and ensure custom paths are used during startup (#5872)
* fix: The custom nodes installed in the paths specified in `extra_model_paths.yaml` encounter a bug where the prestartup script is not imported.

* Ensure custom paths are used during startup
https://github.com/comfyanonymous/ComfyUI/pull/5794
2024-12-13 18:21:32 -05:00
Alexander Piskun caf2074773 add_model_folder_path: ensure unique paths by removing duplicates (#5998)
* add_model_folder_path: ensure unique paths by removing duplicates

Signed-off-by: bigcat88 <bigcat88@icloud.com>

* refactored "add_model_folder_path" and added tests

---------

Signed-off-by: bigcat88 <bigcat88@icloud.com>
2024-12-13 18:19:22 -05:00
Terry Jia bdf393792d add load 3d node support (#5564)
* add load 3d node support

* remove Preview3D from BE
2024-12-13 18:13:52 -05:00
comfyanonymous 4e14032c02 Make pad_to_patch_size function work on multi dim. 2024-12-13 07:22:05 -05:00
Chenlei Hu 59d58b1158 [Security] Fix potential XSS on /view (#6034) 2024-12-13 04:56:43 -05:00
Chenlei Hu 563291ee51 Enforce all pyflake lint rules (#6033)
* Enforce F821 undefined-name

* Enforce all pyflake lint rules
2024-12-12 19:29:37 -05:00
Chenlei Hu 6c0377f43e Enforce F821 undefined-name (#6032) 2024-12-12 19:24:41 -05:00
Chenlei Hu 2cddbf0821 Lint and fix undefined names (1/N) (#6028) 2024-12-12 18:55:26 -05:00
Chenlei Hu 60749f345d Lint and fix undefined names (3/N) (#6030) 2024-12-12 18:49:40 -05:00
Chenlei Hu d4426dce7c Lint and fix undefined names (2/N) (#6029) 2024-12-12 18:48:21 -05:00
Chenlei Hu d9d7f3c619 Lint all unused variables (#5989)
* Enable F841

* Autofix

* Remove all unused variable assignment
2024-12-12 17:59:16 -05:00
comfyanonymous fd5dfb812c Set initial load devices for te and model to mps device on mac. 2024-12-12 06:00:31 -05:00
Chenlei Hu 3dfdddcc91 Update README (Add new keybinding entries) (#6020) 2024-12-11 15:55:38 -08:00
Hayden 5747bc6457 Optimize model library (#5841)
* Move model manager routes

* Add experiment model manager api

* Fix cache causing returns to be empty

* Fix unable to compare sub-dir caches

* Skip non-existent folders

* Add model preview

* Revert 'Move model manager routes'

* move model_filemanager.py to app/

* Update model_manager.py

3.8 compatibility

---------
2024-12-11 18:12:04 -05:00
yoinked 5bea1d2ec9 Add MaHiRo (improved/alternate CFG) (#5975)
* Add MaHiRo (improved CFG)

long explanation of what it is is [here](https://huggingface.co/spaces/yoinked/blue-arxiv) (2024-1208.1) 


note: if the node name has encoding issues (utf 8/whatever), id suggest to replace the face at the end with `(>w<)`

* add it to nodes.py, add description, and make it a post_cfg function

* fix

* revert the sampler_cfg_function thing

* switch cfg to args["denoised"]
2024-12-11 16:51:51 -05:00
Yoland Yan 5def9fbc83 Update CI workflow to remove Windows testing configuration (#6007)
- Commented out Windows OS from the CI matrix in test-ci.yml.
- Removed the test-win-nightly job to streamline testing on macOS and Linux only.
- Adjusted the matrix strategy to focus on Python versions and CUDA compatibility without Windows support.
2024-12-11 16:48:41 -05:00
comfyanonymous 7a7efe8424 Support loading some checkpoint files with nested dicts. 2024-12-11 08:04:54 -05:00
comfyanonymous 44db978531 Fix a few things in text enc code for models with no eos token. 2024-12-10 23:07:26 -05:00
comfyanonymous 1c8d11e48a Support different types of tokenizers.
Support tokenizers without an eos token.

Pass full sentences to tokenizer for more efficient tokenizing.
2024-12-10 15:03:39 -05:00
Chenlei Hu a220d11e6b Replace pylint with ruff (#5987) 2024-12-09 22:04:23 -05:00
catboxanon 23827ca312 Add cond_scale to sampler_post_cfg_function (#5985) 2024-12-09 20:13:18 -05:00
Chenlei Hu 0fd4e6c778 Lint unused import (#5973)
* Lint unused import

* nit

* Remove unused imports

* revert fix_torch import

* nit
2024-12-09 15:24:39 -05:00
comfyanonymous e2fafe0686 Make CLIP set last layer node work with t5 models. 2024-12-09 03:57:14 -05:00
comfyanonymous 6579632201 Remove unused imports and variables. 2024-12-08 08:08:12 -05:00
comfyanonymous ac2f0523ca Set env vars to disable telemetry in libs used by some custom nodes. 2024-12-07 14:51:45 -05:00
Haoming fbf68c4e52 clamp input (#5928) 2024-12-07 14:00:31 -05:00
Chenlei Hu 93477f8efe Add code owners (#5873)
* Add code owners

* Update owners

* nit

* Inline owners

* Remove team links

* Add Kosinkadink
2024-12-06 22:00:54 -05:00
comfyanonymous 8af9a91e0c A few improvements to #5937. 2024-12-06 05:49:15 -05:00
Michael Kupchick 005d2d3a13 ltxv: add noise to guidance image to ensure generated motion. (#5937) 2024-12-06 05:46:08 -05:00
comfyanonymous 1e21f4c14e Make timestep ranges more usable on rectified flow models.
This breaks some old workflows but should make the nodes actually useful.
2024-12-05 16:40:58 -05:00
172 changed files with 480271 additions and 30119 deletions
+13 -13
View File
@@ -28,17 +28,17 @@ def pull(repo, remote_name='origin', branch='master'):
if repo.index.conflicts is not None:
for conflict in repo.index.conflicts:
print('Conflicts found in:', conflict[0].path)
print('Conflicts found in:', conflict[0].path) # noqa: T201
raise AssertionError('Conflicts, ahhhhh!!')
user = repo.default_signature
tree = repo.index.write_tree()
commit = repo.create_commit('HEAD',
user,
user,
'Merge!',
tree,
[repo.head.target, remote_master_id])
repo.create_commit('HEAD',
user,
user,
'Merge!',
tree,
[repo.head.target, remote_master_id])
# We need to do this or git CLI will think we are still merging.
repo.state_cleanup()
else:
@@ -49,18 +49,18 @@ repo_path = str(sys.argv[1])
repo = pygit2.Repository(repo_path)
ident = pygit2.Signature('comfyui', 'comfy@ui')
try:
print("stashing current changes")
print("stashing current changes") # noqa: T201
repo.stash(ident)
except KeyError:
print("nothing to stash")
print("nothing to stash") # noqa: T201
backup_branch_name = 'backup_branch_{}'.format(datetime.today().strftime('%Y-%m-%d_%H_%M_%S'))
print("creating backup branch: {}".format(backup_branch_name))
print("creating backup branch: {}".format(backup_branch_name)) # noqa: T201
try:
repo.branches.local.create(backup_branch_name, repo.head.peel())
except:
pass
print("checking out master branch")
print("checking out master branch") # noqa: T201
branch = repo.lookup_branch('master')
if branch is None:
ref = repo.lookup_reference('refs/remotes/origin/master')
@@ -72,7 +72,7 @@ else:
ref = repo.lookup_reference(branch.name)
repo.checkout(ref)
print("pulling latest changes")
print("pulling latest changes") # noqa: T201
pull(repo)
if "--stable" in sys.argv:
@@ -94,7 +94,7 @@ if "--stable" in sys.argv:
if latest_tag is not None:
repo.checkout(latest_tag)
print("Done!")
print("Done!") # noqa: T201
self_update = True
if len(sys.argv) > 2:
@@ -3,8 +3,8 @@ name: Python Linting
on: [push, pull_request]
jobs:
pylint:
name: Run Pylint
ruff:
name: Run Ruff
runs-on: ubuntu-latest
steps:
@@ -16,8 +16,8 @@ jobs:
with:
python-version: 3.x
- name: Install Pylint
run: pip install pylint
- name: Install Ruff
run: pip install ruff
- name: Run Pylint
run: pylint --rcfile=.pylintrc $(find . -type f -name "*.py")
- name: Run Ruff
run: ruff check .
+27 -26
View File
@@ -20,7 +20,8 @@ jobs:
strategy:
fail-fast: false
matrix:
os: [macos, linux, windows]
# os: [macos, linux, windows]
os: [macos, linux]
python_version: ["3.9", "3.10", "3.11", "3.12"]
cuda_version: ["12.1"]
torch_version: ["stable"]
@@ -31,9 +32,9 @@ jobs:
- os: linux
runner_label: [self-hosted, Linux]
flags: ""
- os: windows
runner_label: [self-hosted, Windows]
flags: ""
# - os: windows
# runner_label: [self-hosted, Windows]
# flags: ""
runs-on: ${{ matrix.runner_label }}
steps:
- name: Test Workflows
@@ -45,28 +46,28 @@ jobs:
google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
comfyui_flags: ${{ matrix.flags }}
test-win-nightly:
strategy:
fail-fast: true
matrix:
os: [windows]
python_version: ["3.9", "3.10", "3.11", "3.12"]
cuda_version: ["12.1"]
torch_version: ["nightly"]
include:
- os: windows
runner_label: [self-hosted, Windows]
flags: ""
runs-on: ${{ matrix.runner_label }}
steps:
- name: Test Workflows
uses: comfy-org/comfy-action@main
with:
os: ${{ matrix.os }}
python_version: ${{ matrix.python_version }}
torch_version: ${{ matrix.torch_version }}
google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
comfyui_flags: ${{ matrix.flags }}
# test-win-nightly:
# strategy:
# fail-fast: true
# matrix:
# os: [windows]
# python_version: ["3.9", "3.10", "3.11", "3.12"]
# cuda_version: ["12.1"]
# torch_version: ["nightly"]
# include:
# - os: windows
# runner_label: [self-hosted, Windows]
# flags: ""
# runs-on: ${{ matrix.runner_label }}
# steps:
# - name: Test Workflows
# uses: comfy-org/comfy-action@main
# with:
# os: ${{ matrix.os }}
# python_version: ${{ matrix.python_version }}
# torch_version: ${{ matrix.torch_version }}
# google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
# comfyui_flags: ${{ matrix.flags }}
test-unix-nightly:
strategy:
-3
View File
@@ -1,3 +0,0 @@
[MESSAGES CONTROL]
disable=all
enable=eval-used
+23 -1
View File
@@ -1 +1,23 @@
* @comfyanonymous
# Admins
* @comfyanonymous
# Note: Github teams syntax cannot be used here as the repo is not owned by Comfy-Org.
# Inlined the team members for now.
# Maintainers
*.md @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
/tests/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
/tests-unit/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
/notebooks/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
/script_examples/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
/.github/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
# Python web server
/api_server/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
/app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
# Frontend assets
/web/ @huchenlei @webfiltered @pythongosssss
# Extra nodes
/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink
+43 -21
View File
@@ -38,10 +38,21 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
## Features
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
- Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/), [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/), [SD3](https://comfyanonymous.github.io/ComfyUI_examples/sd3/) and [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
- [LTX-Video](https://comfyanonymous.github.io/ComfyUI_examples/ltxv/)
- [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
- [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
- Image Models
- SD1.x, SD2.x,
- [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [SDXL Turbo](https://comfyanonymous.github.io/ComfyUI_examples/sdturbo/)
- [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/)
- [SD3 and SD3.5](https://comfyanonymous.github.io/ComfyUI_examples/sd3/)
- Pixart Alpha and Sigma
- [AuraFlow](https://comfyanonymous.github.io/ComfyUI_examples/aura_flow/)
- [HunyuanDiT](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_dit/)
- [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
- Video Models
- [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/)
- [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
- [LTX-Video](https://comfyanonymous.github.io/ComfyUI_examples/ltxv/)
- [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
- [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
- Asynchronous Queue system
- Many optimizations: Only re-executes the parts of the workflow that changes between executions.
- Smart memory management: can automatically run models on GPUs with as low as 1GB vram.
@@ -61,9 +72,6 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
- [GLIGEN](https://comfyanonymous.github.io/ComfyUI_examples/gligen/)
- [Model Merging](https://comfyanonymous.github.io/ComfyUI_examples/model_merging/)
- [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/)
- [SDXL Turbo](https://comfyanonymous.github.io/ComfyUI_examples/sdturbo/)
- [AuraFlow](https://comfyanonymous.github.io/ComfyUI_examples/aura_flow/)
- [HunyuanDiT](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_dit/)
- Latent previews with [TAESD](#how-to-show-high-quality-previews)
- Starts up very fast.
- Works fully offline: will never download anything.
@@ -101,6 +109,8 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
| `Q` | Toggle visibility of the queue |
| `H` | Toggle visibility of history |
| `R` | Refresh graph |
| `F` | Show/Hide menu |
| `.` | Fit view to selection (Whole graph when nothing is selected) |
| Double-Click LMB | Open node quick search palette |
| `Shift` + Drag | Move multiple wires at once |
| `Ctrl` + `Alt` + LMB | Disconnect all wires from clicked slot |
@@ -145,7 +155,31 @@ AMD users can install rocm and pytorch with pip if you don't have it already ins
This is the command to install the nightly with ROCm 6.2 which might have some performance improvements:
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.2```
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.2.4```
### Intel GPUs (Windows and Linux)
(Option 1) Intel Arc GPU users can install native PyTorch with torch.xpu support using pip (currently available in PyTorch nightly builds). More information can be found [here](https://pytorch.org/docs/main/notes/get_start_xpu.html)
1. To install PyTorch nightly, use the following command:
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/xpu```
2. Launch ComfyUI by running `python main.py`
(Option 2) Alternatively, Intel GPUs supported by Intel Extension for PyTorch (IPEX) can leverage IPEX for improved performance.
1. For Intel® Arc™ A-Series Graphics utilizing IPEX, create a conda environment and use the commands below:
```
conda install libuv
pip install torch==2.3.1.post0+cxx11.abi torchvision==0.18.1.post0+cxx11.abi torchaudio==2.3.1.post0+cxx11.abi intel-extension-for-pytorch==2.3.110.post0+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/cn/
```
For other supported Intel GPUs with IPEX, visit [Installation](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu) for more information.
Additional discussion and help can be found [here](https://github.com/comfyanonymous/ComfyUI/discussions/476).
### NVIDIA
@@ -155,7 +189,7 @@ Nvidia users should install stable pytorch using this command:
This is the command to install pytorch nightly instead which might have performance improvements:
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu124```
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu126```
#### Troubleshooting
@@ -175,17 +209,6 @@ After this you should have everything installed and can proceed to running Comfy
### Others:
#### Intel GPUs
Intel GPU support is available for all Intel GPUs supported by Intel's Extension for Pytorch (IPEX) with the support requirements listed in the [Installation](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu) page. Choose your platform and method of install and follow the instructions. The steps are as follows:
1. Start by installing the drivers or kernel listed or newer in the Installation page of IPEX linked above for Windows and Linux if needed.
1. Follow the instructions to install [Intel's oneAPI Basekit](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit-download.html) for your platform.
1. Install the packages for IPEX using the instructions provided in the Installation page for your platform.
1. Follow the [ComfyUI manual installation](#manual-install-windows-linux) instructions for Windows and Linux and run ComfyUI normally as described above after everything is installed.
Additional discussion and help can be found [here](https://github.com/comfyanonymous/ComfyUI/discussions/476).
#### Apple Mac silicon
You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS version.
@@ -306,4 +329,3 @@ This will use a snapshot of the legacy frontend preserved in the [ComfyUI Legacy
### Which GPU should I buy for this?
[See this page for some recommendations](https://github.com/comfyanonymous/ComfyUI/wiki/Which-GPU-should-I-buy-for-ComfyUI)
@@ -40,7 +40,7 @@ class InternalRoutes:
return web.json_response("".join([(l["t"] + " - " + l["m"]) for l in app.logger.get_logs()]))
@self.routes.get('/logs/raw')
async def get_logs(request):
async def get_raw_logs(request):
self.terminal_service.update_size()
return web.json_response({
"entries": list(app.logger.get_logs()),
+184
View File
@@ -0,0 +1,184 @@
from __future__ import annotations
import os
import base64
import json
import time
import logging
import folder_paths
import glob
import comfy.utils
from aiohttp import web
from PIL import Image
from io import BytesIO
from folder_paths import map_legacy, filter_files_extensions, filter_files_content_types
class ModelFileManager:
def __init__(self) -> None:
self.cache: dict[str, tuple[list[dict], dict[str, float], float]] = {}
def get_cache(self, key: str, default=None) -> tuple[list[dict], dict[str, float], float] | None:
return self.cache.get(key, default)
def set_cache(self, key: str, value: tuple[list[dict], dict[str, float], float]):
self.cache[key] = value
def clear_cache(self):
self.cache.clear()
def add_routes(self, routes):
# NOTE: This is an experiment to replace `/models`
@routes.get("/experiment/models")
async def get_model_folders(request):
model_types = list(folder_paths.folder_names_and_paths.keys())
folder_black_list = ["configs", "custom_nodes"]
output_folders: list[dict] = []
for folder in model_types:
if folder in folder_black_list:
continue
output_folders.append({"name": folder, "folders": folder_paths.get_folder_paths(folder)})
return web.json_response(output_folders)
# NOTE: This is an experiment to replace `/models/{folder}`
@routes.get("/experiment/models/{folder}")
async def get_all_models(request):
folder = request.match_info.get("folder", None)
if not folder in folder_paths.folder_names_and_paths:
return web.Response(status=404)
files = self.get_model_file_list(folder)
return web.json_response(files)
@routes.get("/experiment/models/preview/{folder}/{path_index}/{filename:.*}")
async def get_model_preview(request):
folder_name = request.match_info.get("folder", None)
path_index = int(request.match_info.get("path_index", None))
filename = request.match_info.get("filename", None)
if not folder_name in folder_paths.folder_names_and_paths:
return web.Response(status=404)
folders = folder_paths.folder_names_and_paths[folder_name]
folder = folders[0][path_index]
full_filename = os.path.join(folder, filename)
previews = self.get_model_previews(full_filename)
default_preview = previews[0] if len(previews) > 0 else None
if default_preview is None or (isinstance(default_preview, str) and not os.path.isfile(default_preview)):
return web.Response(status=404)
try:
with Image.open(default_preview) as img:
img_bytes = BytesIO()
img.save(img_bytes, format="WEBP")
img_bytes.seek(0)
return web.Response(body=img_bytes.getvalue(), content_type="image/webp")
except:
return web.Response(status=404)
def get_model_file_list(self, folder_name: str):
folder_name = map_legacy(folder_name)
folders = folder_paths.folder_names_and_paths[folder_name]
output_list: list[dict] = []
for index, folder in enumerate(folders[0]):
if not os.path.isdir(folder):
continue
out = self.cache_model_file_list_(folder)
if out is None:
out = self.recursive_search_models_(folder, index)
self.set_cache(folder, out)
output_list.extend(out[0])
return output_list
def cache_model_file_list_(self, folder: str):
model_file_list_cache = self.get_cache(folder)
if model_file_list_cache is None:
return None
if not os.path.isdir(folder):
return None
if os.path.getmtime(folder) != model_file_list_cache[1]:
return None
for x in model_file_list_cache[1]:
time_modified = model_file_list_cache[1][x]
folder = x
if os.path.getmtime(folder) != time_modified:
return None
return model_file_list_cache
def recursive_search_models_(self, directory: str, pathIndex: int) -> tuple[list[str], dict[str, float], float]:
if not os.path.isdir(directory):
return [], {}, time.perf_counter()
excluded_dir_names = [".git"]
# TODO use settings
include_hidden_files = False
result: list[str] = []
dirs: dict[str, float] = {}
for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True):
subdirs[:] = [d for d in subdirs if d not in excluded_dir_names]
if not include_hidden_files:
subdirs[:] = [d for d in subdirs if not d.startswith(".")]
filenames = [f for f in filenames if not f.startswith(".")]
filenames = filter_files_extensions(filenames, folder_paths.supported_pt_extensions)
for file_name in filenames:
try:
relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory)
result.append(relative_path)
except:
logging.warning(f"Warning: Unable to access {file_name}. Skipping this file.")
continue
for d in subdirs:
path: str = os.path.join(dirpath, d)
try:
dirs[path] = os.path.getmtime(path)
except FileNotFoundError:
logging.warning(f"Warning: Unable to access {path}. Skipping this path.")
continue
return [{"name": f, "pathIndex": pathIndex} for f in result], dirs, time.perf_counter()
def get_model_previews(self, filepath: str) -> list[str | BytesIO]:
dirname = os.path.dirname(filepath)
if not os.path.exists(dirname):
return []
basename = os.path.splitext(filepath)[0]
match_files = glob.glob(f"{basename}.*", recursive=False)
image_files = filter_files_content_types(match_files, "image")
safetensors_file = next(filter(lambda x: x.endswith(".safetensors"), match_files), None)
safetensors_metadata = {}
result: list[str | BytesIO] = []
for filename in image_files:
_basename = os.path.splitext(filename)[0]
if _basename == basename:
result.append(filename)
if _basename == f"{basename}.preview":
result.append(filename)
if safetensors_file:
safetensors_filepath = os.path.join(dirname, safetensors_file)
header = comfy.utils.safetensors_header(safetensors_filepath, max_size=8*1024*1024)
if header:
safetensors_metadata = json.loads(header)
safetensors_images = safetensors_metadata.get("__metadata__", {}).get("ssmd_cover_images", None)
if safetensors_images:
safetensors_images = json.loads(safetensors_images)
for image in safetensors_images:
result.append(BytesIO(base64.b64decode(image)))
return result
def __exit__(self, exc_type, exc_value, traceback):
self.clear_cache()
+2 -2
View File
@@ -38,8 +38,8 @@ class UserManager():
if not os.path.exists(user_directory):
os.makedirs(user_directory, exist_ok=True)
if not args.multi_user:
print("****** User settings have been changed to be stored on the server instead of browser storage. ******")
print("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")
logging.warning("****** User settings have been changed to be stored on the server instead of browser storage. ******")
logging.warning("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")
if args.multi_user:
if os.path.isfile(self.get_users_file()):
-4
View File
@@ -2,11 +2,9 @@
#and modified
import torch
import torch as th
import torch.nn as nn
from ..ldm.modules.diffusionmodules.util import (
zero_module,
timestep_embedding,
)
@@ -162,7 +160,6 @@ class ControlNet(nn.Module):
if isinstance(self.num_classes, int):
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
elif self.num_classes == "continuous":
print("setting up linear c_adm embedding layer")
self.label_emb = nn.Linear(1, time_embed_dim)
elif self.num_classes == "sequential":
assert adm_in_channels is not None
@@ -415,7 +412,6 @@ class ControlNet(nn.Module):
out_output = []
out_middle = []
hs = []
if self.num_classes is not None:
assert y.shape[0] == x.shape[0]
emb = emb + self.label_emb(y)
-2
View File
@@ -1,10 +1,8 @@
import math
from typing import List, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from torch import Tensor
from comfy.ldm.modules.diffusionmodules.mmdit import DismantledBlock, PatchEmbed, VectorEmbedder, TimestepEmbedder, get_2d_sincos_pos_embed_torch
+1 -1
View File
@@ -1,5 +1,5 @@
import torch
from typing import Dict, Optional
from typing import Optional
import comfy.ldm.modules.diffusionmodules.mmdit
class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
+3 -1
View File
@@ -84,7 +84,8 @@ parser.add_argument("--force-channels-last", action="store_true", help="Force ch
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize when loading models with Intel GPUs.")
parser.add_argument("--oneapi-device-selector", type=str, default=None, metavar="SELECTOR_STRING", help="Sets the oneAPI device(s) this instance will use.")
parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize default when loading models with Intel's Extension for Pytorch.")
class LatentPreviewMethod(enum.Enum):
NoPreviews = "none"
@@ -104,6 +105,7 @@ attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
attn_group.add_argument("--use-sage-attention", action="store_true", help="Use sage attention.")
parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
+1 -3
View File
@@ -297,7 +297,6 @@ class ControlLoraOps:
class Linear(torch.nn.Module, comfy.ops.CastWeightBiasOp):
def __init__(self, in_features: int, out_features: int, bias: bool = True,
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
self.in_features = in_features
self.out_features = out_features
@@ -382,7 +381,6 @@ class ControlLora(ControlNet):
self.control_model.to(comfy.model_management.get_torch_device())
diffusion_model = model.diffusion_model
sd = diffusion_model.state_dict()
cm = self.control_model.state_dict()
for k in sd:
weight = sd[k]
@@ -823,7 +821,7 @@ def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
for i in range(4):
for j in range(2):
prefix_replace["adapter.body.{}.resnets.{}.".format(i, j)] = "body.{}.".format(i * 2 + j)
prefix_replace["adapter.body.{}.".format(i, j)] = "body.{}.".format(i * 2)
prefix_replace["adapter.body.{}.".format(i, )] = "body.{}.".format(i * 2)
prefix_replace["adapter."] = ""
t2i_data = comfy.utils.state_dict_prefix_replace(t2i_data, prefix_replace)
keys = t2i_data.keys()
+10 -3
View File
@@ -157,16 +157,23 @@ vae_conversion_map_attn = [
]
def reshape_weight_for_sd(w):
def reshape_weight_for_sd(w, conv3d=False):
# convert HF linear weights to SD conv2d weights
return w.reshape(*w.shape, 1, 1)
if conv3d:
return w.reshape(*w.shape, 1, 1, 1)
else:
return w.reshape(*w.shape, 1, 1)
def convert_vae_state_dict(vae_state_dict):
mapping = {k: k for k in vae_state_dict.keys()}
conv3d = False
for k, v in mapping.items():
for sd_part, hf_part in vae_conversion_map:
v = v.replace(hf_part, sd_part)
if v.endswith(".conv.weight"):
if not conv3d and vae_state_dict[k].ndim == 5:
conv3d = True
mapping[k] = v
for k, v in mapping.items():
if "attentions" in k:
@@ -179,7 +186,7 @@ def convert_vae_state_dict(vae_state_dict):
for weight_name in weights_to_convert:
if f"mid.attn_1.{weight_name}.weight" in k:
logging.debug(f"Reshaping {k} for SD format")
new_state_dict[k] = reshape_weight_for_sd(v)
new_state_dict[k] = reshape_weight_for_sd(v, conv3d=conv3d)
return new_state_dict
+3 -5
View File
@@ -1,10 +1,10 @@
#code taken from: https://github.com/wl-zhao/UniPC and modified
import torch
import torch.nn.functional as F
import math
import logging
from tqdm.auto import trange, tqdm
from tqdm.auto import trange
class NoiseScheduleVP:
@@ -475,7 +475,7 @@ class UniPC:
return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True):
print(f'using unified predictor-corrector with order {order} (solver type: vary coeff)')
logging.info(f'using unified predictor-corrector with order {order} (solver type: vary coeff)')
ns = self.noise_schedule
assert order <= len(model_prev_list)
@@ -519,7 +519,6 @@ class UniPC:
A_p = C_inv_p
if use_corrector:
print('using corrector')
C_inv = torch.linalg.inv(C)
A_c = C_inv
@@ -704,7 +703,6 @@ class UniPC:
):
# t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
# t_T = self.noise_schedule.T if t_start is None else t_start
device = x.device
steps = len(timesteps) - 1
if method == 'multistep':
assert steps >= order
+1
View File
@@ -1,3 +1,4 @@
import math
import torch
from torch import nn
from .ldm.modules.attention import CrossAttention
+3 -2
View File
@@ -5,6 +5,7 @@ import math
import torch
import numpy as np
import itertools
import logging
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher, PatcherInjection
@@ -130,7 +131,7 @@ class WeightHook(Hook):
weights = self.weights
else:
weights = self.weights_clip
k = model.add_hook_patches(hook=self, patches=weights, strength_patch=strength)
model.add_hook_patches(hook=self, patches=weights, strength_patch=strength)
registered.append(self)
return True
# TODO: add logs about any keys that were not applied
@@ -575,7 +576,7 @@ def load_hook_lora_for_models(model: 'ModelPatcher', clip: 'CLIP', lora: dict[st
k1 = set(k1)
for x in loaded:
if (x not in k) and (x not in k1):
print(f"NOT LOADED {x}")
logging.warning(f"NOT LOADED {x}")
return (new_modelpatcher, new_clip, hook_group)
def _combine_hooks_from_values(c_dict: dict[str, HookGroup], values: dict[str, HookGroup], cache: dict[tuple[HookGroup, HookGroup], HookGroup]):
-1
View File
@@ -11,7 +11,6 @@ import numpy as np
# Transfer from the input time (sigma) used in EDM to that (t) used in DEIS.
def edm2t(edm_steps, epsilon_s=1e-3, sigma_min=0.002, sigma_max=80):
vp_sigma = lambda beta_d, beta_min: lambda t: (np.e ** (0.5 * beta_d * (t ** 2) + beta_min * t) - 1) ** 0.5
vp_sigma_inv = lambda beta_d, beta_min: lambda sigma: ((beta_min ** 2 + 2 * beta_d * (sigma ** 2 + 1).log()).sqrt() - beta_min) / beta_d
vp_beta_d = 2 * (np.log(torch.tensor(sigma_min).cpu() ** 2 + 1) / epsilon_s - np.log(torch.tensor(sigma_max).cpu() ** 2 + 1)) / (epsilon_s - 1)
vp_beta_min = np.log(torch.tensor(sigma_max).cpu() ** 2 + 1) - 0.5 * vp_beta_d
+24
View File
@@ -352,3 +352,27 @@ class LTXV(LatentFormat):
]
self.latent_rgb_factors_bias = [-0.0571, -0.1657, -0.2512]
class HunyuanVideo(LatentFormat):
latent_channels = 16
scale_factor = 0.476986
latent_rgb_factors = [
[-0.0395, -0.0331, 0.0445],
[ 0.0696, 0.0795, 0.0518],
[ 0.0135, -0.0945, -0.0282],
[ 0.0108, -0.0250, -0.0765],
[-0.0209, 0.0032, 0.0224],
[-0.0804, -0.0254, -0.0639],
[-0.0991, 0.0271, -0.0669],
[-0.0646, -0.0422, -0.0400],
[-0.0696, -0.0595, -0.0894],
[-0.0799, -0.0208, -0.0375],
[ 0.1166, 0.1627, 0.0962],
[ 0.1165, 0.0432, 0.0407],
[-0.2315, -0.1920, -0.1355],
[-0.0270, 0.0401, -0.0821],
[-0.0616, -0.0997, -0.0727],
[ 0.0249, -0.0469, -0.1703]
]
latent_rgb_factors_bias = [ 0.0259, -0.0192, -0.0761]
+2 -2
View File
@@ -2,7 +2,7 @@
import torch
from torch import nn
from typing import Literal, Dict, Any
from typing import Literal
import math
import comfy.ops
ops = comfy.ops.disable_weight_init
@@ -97,7 +97,7 @@ def get_activation(activation: Literal["elu", "snake", "none"], antialias=False,
raise ValueError(f"Unknown activation {activation}")
if antialias:
act = Activation1d(act)
act = Activation1d(act) # noqa: F821 Activation1d is not defined
return act
+6 -12
View File
@@ -158,7 +158,6 @@ class RotaryEmbedding(nn.Module):
def forward(self, t):
# device = self.inv_freq.device
device = t.device
dtype = t.dtype
# t = t.to(torch.float32)
@@ -170,7 +169,7 @@ class RotaryEmbedding(nn.Module):
if self.scale is None:
return freqs, 1.
power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base
power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base # noqa: F821 seq_len is not defined
scale = comfy.ops.cast_to_input(self.scale, t) ** rearrange(power, 'n -> n 1')
scale = torch.cat((scale, scale), dim = -1)
@@ -229,9 +228,9 @@ class FeedForward(nn.Module):
linear_in = GLU(dim, inner_dim, activation, dtype=dtype, device=device, operations=operations)
else:
linear_in = nn.Sequential(
Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
rearrange('b n d -> b d n') if use_conv else nn.Identity(),
operations.Linear(dim, inner_dim, bias = not no_bias, dtype=dtype, device=device) if not use_conv else operations.Conv1d(dim, inner_dim, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias, dtype=dtype, device=device),
Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
rearrange('b n d -> b d n') if use_conv else nn.Identity(),
activation
)
@@ -246,9 +245,9 @@ class FeedForward(nn.Module):
self.ff = nn.Sequential(
linear_in,
Rearrange('b d n -> b n d') if use_conv else nn.Identity(),
rearrange('b d n -> b n d') if use_conv else nn.Identity(),
linear_out,
Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
rearrange('b n d -> b d n') if use_conv else nn.Identity(),
)
def forward(self, x):
@@ -346,18 +345,13 @@ class Attention(nn.Module):
# determine masking
masks = []
final_attn_mask = None # The mask that will be applied to the attention matrix, taking all masks into account
if input_mask is not None:
input_mask = rearrange(input_mask, 'b j -> b 1 1 j')
masks.append(~input_mask)
# Other masks will be added here later
if len(masks) > 0:
final_attn_mask = ~or_reduce(masks)
n, device = q.shape[-2], q.device
n = q.shape[-2]
causal = self.causal if causal is None else causal
+2 -2
View File
@@ -2,8 +2,8 @@
import torch
import torch.nn as nn
from torch import Tensor, einsum
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
from torch import Tensor
from typing import List, Union
from einops import rearrange
import math
import comfy.ops
-2
View File
@@ -147,7 +147,6 @@ class DoubleAttention(nn.Module):
bsz, seqlen1, _ = c.shape
bsz, seqlen2, _ = x.shape
seqlen = seqlen1 + seqlen2
cq, ck, cv = self.w1q(c), self.w1k(c), self.w1v(c)
cq = cq.view(bsz, seqlen1, self.n_heads, self.head_dim)
@@ -382,7 +381,6 @@ class MMDiT(nn.Module):
pe_new = pe_as_2d.squeeze(0).permute(1, 2, 0).flatten(0, 1)
self.positional_encoding.data = pe_new.unsqueeze(0).contiguous()
self.h_max, self.w_max = target_dim
print("PE extended to", target_dim)
def pe_selection_index_based_on_dim(self, h, w):
h_p, w_p = h // self.patch_size, w // self.patch_size
-1
View File
@@ -16,7 +16,6 @@
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import torch
import torchvision
from torch import nn
from .common import LayerNorm2d_op
+6 -3
View File
@@ -4,9 +4,12 @@ import comfy.ops
def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
if padding_mode == "circular" and (torch.jit.is_tracing() or torch.jit.is_scripting()):
padding_mode = "reflect"
pad_h = (patch_size[0] - img.shape[-2] % patch_size[0]) % patch_size[0]
pad_w = (patch_size[1] - img.shape[-1] % patch_size[1]) % patch_size[1]
return torch.nn.functional.pad(img, (0, pad_w, 0, pad_h), mode=padding_mode)
pad = ()
for i in range(img.ndim - 2):
pad = (0, (patch_size[i] - img.shape[i + 2] % patch_size[i]) % patch_size[i]) + pad
return torch.nn.functional.pad(img, pad, mode=padding_mode)
try:
rms_norm_torch = torch.nn.functional.rms_norm
+1 -3
View File
@@ -6,9 +6,7 @@ import math
from torch import Tensor, nn
from einops import rearrange, repeat
from .layers import (DoubleStreamBlock, EmbedND, LastLayer,
MLPEmbedder, SingleStreamBlock,
timestep_embedding)
from .layers import (timestep_embedding)
from .model import Flux
import comfy.ldm.common_dit
+20 -9
View File
@@ -114,7 +114,7 @@ class Modulation(nn.Module):
class DoubleStreamBlock(nn.Module):
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, dtype=None, device=None, operations=None):
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
super().__init__()
mlp_hidden_dim = int(hidden_size * mlp_ratio)
@@ -141,8 +141,9 @@ class DoubleStreamBlock(nn.Module):
nn.GELU(approximate="tanh"),
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
self.flipped_img_txt = flipped_img_txt
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor):
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None):
img_mod1, img_mod2 = self.img_mod(vec)
txt_mod1, txt_mod2 = self.txt_mod(vec)
@@ -160,12 +161,22 @@ class DoubleStreamBlock(nn.Module):
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
# run actual attention
attn = attention(torch.cat((txt_q, img_q), dim=2),
torch.cat((txt_k, img_k), dim=2),
torch.cat((txt_v, img_v), dim=2), pe=pe)
if self.flipped_img_txt:
# run actual attention
attn = attention(torch.cat((img_q, txt_q), dim=2),
torch.cat((img_k, txt_k), dim=2),
torch.cat((img_v, txt_v), dim=2),
pe=pe, mask=attn_mask)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
else:
# run actual attention
attn = attention(torch.cat((txt_q, img_q), dim=2),
torch.cat((txt_k, img_k), dim=2),
torch.cat((txt_v, img_v), dim=2),
pe=pe, mask=attn_mask)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
# calculate the img bloks
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
@@ -217,7 +228,7 @@ class SingleStreamBlock(nn.Module):
self.mlp_act = nn.GELU(approximate="tanh")
self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None) -> Tensor:
mod, _ = self.modulation(vec)
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
@@ -226,7 +237,7 @@ class SingleStreamBlock(nn.Module):
q, k = self.norm(q, k, v)
# compute attention
attn = attention(q, k, v, pe=pe)
attn = attention(q, k, v, pe=pe, mask=attn_mask)
# compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
x += mod.gate * output
+4 -2
View File
@@ -1,14 +1,15 @@
import torch
from einops import rearrange
from torch import Tensor
from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
q, k = apply_rope(q, k, pe)
heads = q.shape[1]
x = optimized_attention(q, k, v, heads, skip_reshape=True)
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
return x
@@ -33,3 +34,4 @@ def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
+31 -11
View File
@@ -4,6 +4,8 @@ from dataclasses import dataclass
import torch
from torch import Tensor, nn
from einops import rearrange, repeat
import comfy.ldm.common_dit
from .layers import (
DoubleStreamBlock,
@@ -14,9 +16,6 @@ from .layers import (
timestep_embedding,
)
from einops import rearrange, repeat
import comfy.ldm.common_dit
@dataclass
class FluxParams:
in_channels: int
@@ -98,8 +97,9 @@ class Flux(nn.Module):
timesteps: Tensor,
y: Tensor,
guidance: Tensor = None,
control=None,
control = None,
transformer_options={},
attn_mask: Tensor = None,
) -> Tensor:
patches_replace = transformer_options.get("patches_replace", {})
if img.ndim != 3 or txt.ndim != 3:
@@ -124,14 +124,27 @@ class Flux(nn.Module):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"])
out["img"], out["txt"] = block(img=args["img"],
txt=args["txt"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"))
return out
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({"img": img,
"txt": txt,
"vec": vec,
"pe": pe,
"attn_mask": attn_mask},
{"original_block": block_wrap})
txt = out["txt"]
img = out["img"]
else:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
img, txt = block(img=img,
txt=txt,
vec=vec,
pe=pe,
attn_mask=attn_mask)
if control is not None: # Controlnet
control_i = control.get("input")
@@ -146,13 +159,20 @@ class Flux(nn.Module):
if ("single_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"])
out["img"] = block(args["img"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"))
return out
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe}, {"original_block": block_wrap})
out = blocks_replace[("single_block", i)]({"img": img,
"vec": vec,
"pe": pe,
"attn_mask": attn_mask},
{"original_block": block_wrap})
img = out["img"]
else:
img = block(img, vec=vec, pe=pe)
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
if control is not None: # Controlnet
control_o = control.get("output")
@@ -181,5 +201,5 @@ class Flux(nn.Module):
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options)
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]
@@ -461,8 +461,6 @@ class AsymmDiTJoint(nn.Module):
pH, pW = H // self.patch_size, W // self.patch_size
x = self.embed_x(x) # (B, N, D), where N = T * H * W / patch_size ** 2
assert x.ndim == 3
B = x.size(0)
pH, pW = H // self.patch_size, W // self.patch_size
N = T * pH * pW
+1 -1
View File
@@ -1,7 +1,7 @@
#original code from https://github.com/genmoai/models under apache 2.0 license
#adapted to ComfyUI
from typing import Optional, Tuple
from typing import Optional
import torch
import torch.nn as nn
+1 -1
View File
@@ -1,7 +1,7 @@
#original code from https://github.com/genmoai/models under apache 2.0 license
#adapted to ComfyUI
from typing import Callable, List, Optional, Tuple, Union
from typing import List, Optional, Tuple, Union
from functools import partial
import math
+330
View File
@@ -0,0 +1,330 @@
#Based on Flux code because of weird hunyuan video code license.
import torch
import comfy.ldm.flux.layers
import comfy.ldm.modules.diffusionmodules.mmdit
from comfy.ldm.modules.attention import optimized_attention
from dataclasses import dataclass
from einops import repeat
from torch import Tensor, nn
from comfy.ldm.flux.layers import (
DoubleStreamBlock,
EmbedND,
LastLayer,
MLPEmbedder,
SingleStreamBlock,
timestep_embedding
)
import comfy.ldm.common_dit
@dataclass
class HunyuanVideoParams:
in_channels: int
out_channels: int
vec_in_dim: int
context_in_dim: int
hidden_size: int
mlp_ratio: float
num_heads: int
depth: int
depth_single_blocks: int
axes_dim: list
theta: int
patch_size: list
qkv_bias: bool
guidance_embed: bool
class SelfAttentionRef(nn.Module):
def __init__(self, dim: int, qkv_bias: bool = False, dtype=None, device=None, operations=None):
super().__init__()
self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
class TokenRefinerBlock(nn.Module):
def __init__(
self,
hidden_size,
heads,
dtype=None,
device=None,
operations=None
):
super().__init__()
self.heads = heads
mlp_hidden_dim = hidden_size * 4
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device),
)
self.norm1 = operations.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device)
self.self_attn = SelfAttentionRef(hidden_size, True, dtype=dtype, device=device, operations=operations)
self.norm2 = operations.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device)
self.mlp = nn.Sequential(
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
nn.SiLU(),
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
def forward(self, x, c, mask):
mod1, mod2 = self.adaLN_modulation(c).chunk(2, dim=1)
norm_x = self.norm1(x)
qkv = self.self_attn.qkv(norm_x)
q, k, v = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, self.heads, -1).permute(2, 0, 3, 1, 4)
attn = optimized_attention(q, k, v, self.heads, mask=mask, skip_reshape=True)
x = x + self.self_attn.proj(attn) * mod1.unsqueeze(1)
x = x + self.mlp(self.norm2(x)) * mod2.unsqueeze(1)
return x
class IndividualTokenRefiner(nn.Module):
def __init__(
self,
hidden_size,
heads,
num_blocks,
dtype=None,
device=None,
operations=None
):
super().__init__()
self.blocks = nn.ModuleList(
[
TokenRefinerBlock(
hidden_size=hidden_size,
heads=heads,
dtype=dtype,
device=device,
operations=operations
)
for _ in range(num_blocks)
]
)
def forward(self, x, c, mask):
m = None
if mask is not None:
m = mask.view(mask.shape[0], 1, 1, mask.shape[1]).repeat(1, 1, mask.shape[1], 1)
m = m + m.transpose(2, 3)
for block in self.blocks:
x = block(x, c, m)
return x
class TokenRefiner(nn.Module):
def __init__(
self,
text_dim,
hidden_size,
heads,
num_blocks,
dtype=None,
device=None,
operations=None
):
super().__init__()
self.input_embedder = operations.Linear(text_dim, hidden_size, bias=True, dtype=dtype, device=device)
self.t_embedder = MLPEmbedder(256, hidden_size, dtype=dtype, device=device, operations=operations)
self.c_embedder = MLPEmbedder(text_dim, hidden_size, dtype=dtype, device=device, operations=operations)
self.individual_token_refiner = IndividualTokenRefiner(hidden_size, heads, num_blocks, dtype=dtype, device=device, operations=operations)
def forward(
self,
x,
timesteps,
mask,
):
t = self.t_embedder(timestep_embedding(timesteps, 256, time_factor=1.0).to(x.dtype))
# m = mask.float().unsqueeze(-1)
# c = (x.float() * m).sum(dim=1) / m.sum(dim=1) #TODO: the following works when the x.shape is the same length as the tokens but might break otherwise
c = x.sum(dim=1) / x.shape[1]
c = t + self.c_embedder(c.to(x.dtype))
x = self.input_embedder(x)
x = self.individual_token_refiner(x, c, mask)
return x
class HunyuanVideo(nn.Module):
"""
Transformer model for flow matching on sequences.
"""
def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
super().__init__()
self.dtype = dtype
params = HunyuanVideoParams(**kwargs)
self.params = params
self.patch_size = params.patch_size
self.in_channels = params.in_channels
self.out_channels = params.out_channels
if params.hidden_size % params.num_heads != 0:
raise ValueError(
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
)
pe_dim = params.hidden_size // params.num_heads
if sum(params.axes_dim) != pe_dim:
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
self.hidden_size = params.hidden_size
self.num_heads = params.num_heads
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
self.img_in = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(None, self.patch_size, self.in_channels, self.hidden_size, conv3d=True, dtype=dtype, device=device, operations=operations)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
self.guidance_in = (
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity()
)
self.txt_in = TokenRefiner(params.context_in_dim, self.hidden_size, self.num_heads, 2, dtype=dtype, device=device, operations=operations)
self.double_blocks = nn.ModuleList(
[
DoubleStreamBlock(
self.hidden_size,
self.num_heads,
mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias,
flipped_img_txt=True,
dtype=dtype, device=device, operations=operations
)
for _ in range(params.depth)
]
)
self.single_blocks = nn.ModuleList(
[
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
for _ in range(params.depth_single_blocks)
]
)
if final_layer:
self.final_layer = LastLayer(self.hidden_size, self.patch_size[-1], self.out_channels, dtype=dtype, device=device, operations=operations)
def forward_orig(
self,
img: Tensor,
img_ids: Tensor,
txt: Tensor,
txt_ids: Tensor,
txt_mask: Tensor,
timesteps: Tensor,
y: Tensor,
guidance: Tensor = None,
control=None,
transformer_options={},
) -> Tensor:
patches_replace = transformer_options.get("patches_replace", {})
initial_shape = list(img.shape)
# running on sequences img
img = self.img_in(img)
vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
if self.params.guidance_embed:
if guidance is None:
raise ValueError("Didn't get guidance strength for guidance distilled model.")
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
if txt_mask is not None and not torch.is_floating_point(txt_mask):
txt_mask = (txt_mask - 1).to(img.dtype) * torch.finfo(img.dtype).max
txt = self.txt_in(txt, timesteps, txt_mask)
ids = torch.cat((img_ids, txt_ids), dim=1)
pe = self.pe_embedder(ids)
img_len = img.shape[1]
if txt_mask is not None:
attn_mask_len = img_len + txt.shape[1]
attn_mask = torch.zeros((1, 1, attn_mask_len), dtype=img.dtype, device=img.device)
attn_mask[:, 0, img_len:] = txt_mask
else:
attn_mask = None
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.double_blocks):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"])
return out
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask}, {"original_block": block_wrap})
txt = out["txt"]
img = out["img"]
else:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask)
if control is not None: # Controlnet
control_i = control.get("input")
if i < len(control_i):
add = control_i[i]
if add is not None:
img += add
img = torch.cat((img, txt), 1)
for i, block in enumerate(self.single_blocks):
if ("single_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"])
return out
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask}, {"original_block": block_wrap})
img = out["img"]
else:
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
if control is not None: # Controlnet
control_o = control.get("output")
if i < len(control_o):
add = control_o[i]
if add is not None:
img[:, : img_len] += add
img = img[:, : img_len]
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
shape = initial_shape[-3:]
for i in range(len(shape)):
shape[i] = shape[i] // self.patch_size[i]
img = img.reshape([img.shape[0]] + shape + [self.out_channels] + self.patch_size)
img = img.permute(0, 4, 1, 5, 2, 6, 3, 7)
img = img.reshape(initial_shape)
return img
def forward(self, x, timestep, context, y, guidance, attention_mask=None, control=None, transformer_options={}, **kwargs):
bs, c, t, h, w = x.shape
patch_size = self.patch_size
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, control, transformer_options)
return out
-10
View File
@@ -1,24 +1,17 @@
from typing import Any, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import checkpoint
from comfy.ldm.modules.diffusionmodules.mmdit import (
Mlp,
TimestepEmbedder,
PatchEmbed,
RMSNorm,
)
from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
from .poolers import AttentionPool
import comfy.latent_formats
from .models import HunYuanDiTBlock, calc_rope
from .posemb_layers import get_2d_rotary_pos_embed, get_fill_resize_and_crop
class HunYuanControlNet(nn.Module):
@@ -171,9 +164,6 @@ class HunYuanControlNet(nn.Module):
),
)
# Image embedding
num_patches = self.x_embedder.num_patches
# HUnYuanDiT Blocks
self.blocks = nn.ModuleList(
[
-5
View File
@@ -1,8 +1,6 @@
from typing import Any
import torch
import torch.nn as nn
import torch.nn.functional as F
import comfy.ops
from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, TimestepEmbedder, PatchEmbed, RMSNorm
@@ -250,9 +248,6 @@ class HunYuanDiT(nn.Module):
operations.Linear(hidden_size * 4, hidden_size, bias=True, dtype=dtype, device=device),
)
# Image embedding
num_patches = self.x_embedder.num_patches
# HUnYuanDiT Blocks
self.blocks = nn.ModuleList([
HunYuanDiTBlock(hidden_size=hidden_size,
-1
View File
@@ -1,6 +1,5 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.attention import optimized_attention
import comfy.ops
+15 -2
View File
@@ -379,6 +379,7 @@ class LTXVModel(torch.nn.Module):
positional_embedding_max_pos=[20, 2048, 2048],
dtype=None, device=None, operations=None, **kwargs):
super().__init__()
self.generator = None
self.dtype = dtype
self.out_channels = in_channels
self.inner_dim = num_attention_heads * attention_head_dim
@@ -415,7 +416,7 @@ class LTXVModel(torch.nn.Module):
self.patchifier = SymmetricPatchifier(1)
def forward(self, x, timestep, context, attention_mask, frame_rate=25, guiding_latent=None, transformer_options={}, **kwargs):
def forward(self, x, timestep, context, attention_mask, frame_rate=25, guiding_latent=None, guiding_latent_noise_scale=0, transformer_options={}, **kwargs):
patches_replace = transformer_options.get("patches_replace", {})
indices_grid = self.patchifier.get_grid(
@@ -431,10 +432,22 @@ class LTXVModel(torch.nn.Module):
ts = torch.ones([x.shape[0], 1, x.shape[2], x.shape[3], x.shape[4]], device=x.device, dtype=x.dtype)
input_ts = timestep.view([timestep.shape[0]] + [1] * (x.ndim - 1))
ts *= input_ts
ts[:, :, 0] = 0.0
ts[:, :, 0] = guiding_latent_noise_scale * (input_ts[:, :, 0] ** 2)
timestep = self.patchifier.patchify(ts)
input_x = x.clone()
x[:, :, 0] = guiding_latent[:, :, 0]
if guiding_latent_noise_scale > 0:
if self.generator is None:
self.generator = torch.Generator(device=x.device).manual_seed(42)
elif self.generator.device != x.device:
self.generator = torch.Generator(device=x.device).set_state(self.generator.get_state())
noise_shape = [guiding_latent.shape[0], guiding_latent.shape[1], 1, guiding_latent.shape[3], guiding_latent.shape[4]]
scale = guiding_latent_noise_scale * (input_ts ** 2)
guiding_noise = scale * torch.randn(size=noise_shape, device=x.device, generator=self.generator)
x[:, :, 0] = guiding_noise[:, :, 0] + x[:, :, 0] * (1.0 - scale[:, :, 0])
orig_shape = list(x.shape)
+1 -1
View File
@@ -53,7 +53,7 @@ class Patchifier(ABC):
grid_h = torch.arange(h, dtype=torch.float32, device=device)
grid_w = torch.arange(w, dtype=torch.float32, device=device)
grid_f = torch.arange(f, dtype=torch.float32, device=device)
grid = torch.meshgrid(grid_f, grid_h, grid_w)
grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing='ij')
grid = torch.stack(grid, dim=0)
grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
@@ -3,10 +3,12 @@ from torch import nn
from functools import partial
import math
from einops import rearrange
from typing import Any, Mapping, Optional, Tuple, Union, List
from typing import Optional, Tuple, Union
from .conv_nd_factory import make_conv_nd, make_linear_nd
from .pixel_norm import PixelNorm
from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings
import comfy.ops
ops = comfy.ops.disable_weight_init
class Encoder(nn.Module):
r"""
@@ -236,6 +238,7 @@ class Decoder(nn.Module):
patch_size: int = 1,
norm_layer: str = "group_norm",
causal: bool = True,
timestep_conditioning: bool = False,
):
super().__init__()
self.patch_size = patch_size
@@ -250,6 +253,8 @@ class Decoder(nn.Module):
block_params = block_params if isinstance(block_params, dict) else {}
if block_name == "res_x_y":
output_channel = output_channel * block_params.get("multiplier", 2)
if block_name == "compress_all":
output_channel = output_channel * block_params.get("multiplier", 1)
self.conv_in = make_conv_nd(
dims,
@@ -276,6 +281,19 @@ class Decoder(nn.Module):
resnet_eps=1e-6,
resnet_groups=norm_num_groups,
norm_layer=norm_layer,
inject_noise=block_params.get("inject_noise", False),
timestep_conditioning=timestep_conditioning,
)
elif block_name == "attn_res_x":
block = UNetMidBlock3D(
dims=dims,
in_channels=input_channel,
num_layers=block_params["num_layers"],
resnet_groups=norm_num_groups,
norm_layer=norm_layer,
inject_noise=block_params.get("inject_noise", False),
timestep_conditioning=timestep_conditioning,
attention_head_dim=block_params["attention_head_dim"],
)
elif block_name == "res_x_y":
output_channel = output_channel // block_params.get("multiplier", 2)
@@ -286,6 +304,8 @@ class Decoder(nn.Module):
eps=1e-6,
groups=norm_num_groups,
norm_layer=norm_layer,
inject_noise=block_params.get("inject_noise", False),
timestep_conditioning=False,
)
elif block_name == "compress_time":
block = DepthToSpaceUpsample(
@@ -296,11 +316,13 @@ class Decoder(nn.Module):
dims=dims, in_channels=input_channel, stride=(1, 2, 2)
)
elif block_name == "compress_all":
output_channel = output_channel // block_params.get("multiplier", 1)
block = DepthToSpaceUpsample(
dims=dims,
in_channels=input_channel,
stride=(2, 2, 2),
residual=block_params.get("residual", False),
out_channels_reduction_factor=block_params.get("multiplier", 1),
)
else:
raise ValueError(f"unknown layer: {block_name}")
@@ -323,27 +345,75 @@ class Decoder(nn.Module):
self.gradient_checkpointing = False
self.timestep_conditioning = timestep_conditioning
if timestep_conditioning:
self.timestep_scale_multiplier = nn.Parameter(
torch.tensor(1000.0, dtype=torch.float32)
)
self.last_time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(
output_channel * 2, 0, operations=ops,
)
self.last_scale_shift_table = nn.Parameter(torch.empty(2, output_channel))
# def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
def forward(
self,
sample: torch.FloatTensor,
timestep: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
r"""The forward method of the `Decoder` class."""
# assert target_shape is not None, "target_shape must be provided"
batch_size = sample.shape[0]
sample = self.conv_in(sample, causal=self.causal)
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
checkpoint_fn = (
partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
if self.gradient_checkpointing and self.training
else lambda x: x
)
sample = sample.to(upscale_dtype)
scaled_timestep = None
if self.timestep_conditioning:
assert (
timestep is not None
), "should pass timestep with timestep_conditioning=True"
scaled_timestep = timestep * self.timestep_scale_multiplier.to(dtype=sample.dtype, device=sample.device)
for up_block in self.up_blocks:
sample = checkpoint_fn(up_block)(sample, causal=self.causal)
if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
sample = checkpoint_fn(up_block)(
sample, causal=self.causal, timestep=scaled_timestep
)
else:
sample = checkpoint_fn(up_block)(sample, causal=self.causal)
sample = self.conv_norm_out(sample)
if self.timestep_conditioning:
embedded_timestep = self.last_time_embedder(
timestep=scaled_timestep.flatten(),
resolution=None,
aspect_ratio=None,
batch_size=sample.shape[0],
hidden_dtype=sample.dtype,
)
embedded_timestep = embedded_timestep.view(
batch_size, embedded_timestep.shape[-1], 1, 1, 1
)
ada_values = self.last_scale_shift_table[
None, ..., None, None, None
].to(device=sample.device, dtype=sample.dtype) + embedded_timestep.reshape(
batch_size,
2,
-1,
embedded_timestep.shape[-3],
embedded_timestep.shape[-2],
embedded_timestep.shape[-1],
)
shift, scale = ada_values.unbind(dim=1)
sample = sample * (1 + scale) + shift
sample = self.conv_act(sample)
sample = self.conv_out(sample, causal=self.causal)
@@ -379,12 +449,21 @@ class UNetMidBlock3D(nn.Module):
resnet_eps: float = 1e-6,
resnet_groups: int = 32,
norm_layer: str = "group_norm",
inject_noise: bool = False,
timestep_conditioning: bool = False,
):
super().__init__()
resnet_groups = (
resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
)
self.timestep_conditioning = timestep_conditioning
if timestep_conditioning:
self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(
in_channels * 4, 0, operations=ops,
)
self.res_blocks = nn.ModuleList(
[
ResnetBlock3D(
@@ -395,25 +474,48 @@ class UNetMidBlock3D(nn.Module):
groups=resnet_groups,
dropout=dropout,
norm_layer=norm_layer,
inject_noise=inject_noise,
timestep_conditioning=timestep_conditioning,
)
for _ in range(num_layers)
]
)
def forward(
self, hidden_states: torch.FloatTensor, causal: bool = True
self, hidden_states: torch.FloatTensor, causal: bool = True, timestep: Optional[torch.Tensor] = None
) -> torch.FloatTensor:
timestep_embed = None
if self.timestep_conditioning:
assert (
timestep is not None
), "should pass timestep with timestep_conditioning=True"
batch_size = hidden_states.shape[0]
timestep_embed = self.time_embedder(
timestep=timestep.flatten(),
resolution=None,
aspect_ratio=None,
batch_size=batch_size,
hidden_dtype=hidden_states.dtype,
)
timestep_embed = timestep_embed.view(
batch_size, timestep_embed.shape[-1], 1, 1, 1
)
for resnet in self.res_blocks:
hidden_states = resnet(hidden_states, causal=causal)
hidden_states = resnet(hidden_states, causal=causal, timestep=timestep_embed)
return hidden_states
class DepthToSpaceUpsample(nn.Module):
def __init__(self, dims, in_channels, stride, residual=False):
def __init__(
self, dims, in_channels, stride, residual=False, out_channels_reduction_factor=1
):
super().__init__()
self.stride = stride
self.out_channels = math.prod(stride) * in_channels
self.out_channels = (
math.prod(stride) * in_channels // out_channels_reduction_factor
)
self.conv = make_conv_nd(
dims=dims,
in_channels=in_channels,
@@ -423,8 +525,9 @@ class DepthToSpaceUpsample(nn.Module):
causal=True,
)
self.residual = residual
self.out_channels_reduction_factor = out_channels_reduction_factor
def forward(self, x, causal: bool = True):
def forward(self, x, causal: bool = True, timestep: Optional[torch.Tensor] = None):
if self.residual:
# Reshape and duplicate the input to match the output shape
x_in = rearrange(
@@ -434,7 +537,8 @@ class DepthToSpaceUpsample(nn.Module):
p2=self.stride[1],
p3=self.stride[2],
)
x_in = x_in.repeat(1, math.prod(self.stride), 1, 1, 1)
num_repeat = math.prod(self.stride) // self.out_channels_reduction_factor
x_in = x_in.repeat(1, num_repeat, 1, 1, 1)
if self.stride[0] == 2:
x_in = x_in[:, :, 1:, :, :]
x = self.conv(x, causal=causal)
@@ -451,7 +555,6 @@ class DepthToSpaceUpsample(nn.Module):
x = x + x_in
return x
class LayerNorm(nn.Module):
def __init__(self, dim, eps, elementwise_affine=True) -> None:
super().__init__()
@@ -486,11 +589,14 @@ class ResnetBlock3D(nn.Module):
groups: int = 32,
eps: float = 1e-6,
norm_layer: str = "group_norm",
inject_noise: bool = False,
timestep_conditioning: bool = False,
):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.inject_noise = inject_noise
if norm_layer == "group_norm":
self.norm1 = nn.GroupNorm(
@@ -513,6 +619,9 @@ class ResnetBlock3D(nn.Module):
causal=True,
)
if inject_noise:
self.per_channel_scale1 = nn.Parameter(torch.zeros((in_channels, 1, 1)))
if norm_layer == "group_norm":
self.norm2 = nn.GroupNorm(
num_groups=groups, num_channels=out_channels, eps=eps, affine=True
@@ -534,6 +643,9 @@ class ResnetBlock3D(nn.Module):
causal=True,
)
if inject_noise:
self.per_channel_scale2 = nn.Parameter(torch.zeros((in_channels, 1, 1)))
self.conv_shortcut = (
make_linear_nd(
dims=dims, in_channels=in_channels, out_channels=out_channels
@@ -548,29 +660,84 @@ class ResnetBlock3D(nn.Module):
else nn.Identity()
)
self.timestep_conditioning = timestep_conditioning
if timestep_conditioning:
self.scale_shift_table = nn.Parameter(
torch.randn(4, in_channels) / in_channels**0.5
)
def _feed_spatial_noise(
self, hidden_states: torch.FloatTensor, per_channel_scale: torch.FloatTensor
) -> torch.FloatTensor:
spatial_shape = hidden_states.shape[-2:]
device = hidden_states.device
dtype = hidden_states.dtype
# similar to the "explicit noise inputs" method in style-gan
spatial_noise = torch.randn(spatial_shape, device=device, dtype=dtype)[None]
scaled_noise = (spatial_noise * per_channel_scale)[None, :, None, ...]
hidden_states = hidden_states + scaled_noise
return hidden_states
def forward(
self,
input_tensor: torch.FloatTensor,
causal: bool = True,
timestep: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
hidden_states = input_tensor
batch_size = hidden_states.shape[0]
hidden_states = self.norm1(hidden_states)
if self.timestep_conditioning:
assert (
timestep is not None
), "should pass timestep with timestep_conditioning=True"
ada_values = self.scale_shift_table[
None, ..., None, None, None
].to(device=hidden_states.device, dtype=hidden_states.dtype) + timestep.reshape(
batch_size,
4,
-1,
timestep.shape[-3],
timestep.shape[-2],
timestep.shape[-1],
)
shift1, scale1, shift2, scale2 = ada_values.unbind(dim=1)
hidden_states = hidden_states * (1 + scale1) + shift1
hidden_states = self.non_linearity(hidden_states)
hidden_states = self.conv1(hidden_states, causal=causal)
if self.inject_noise:
hidden_states = self._feed_spatial_noise(
hidden_states, self.per_channel_scale1.to(device=hidden_states.device, dtype=hidden_states.dtype)
)
hidden_states = self.norm2(hidden_states)
if self.timestep_conditioning:
hidden_states = hidden_states * (1 + scale2) + shift2
hidden_states = self.non_linearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states, causal=causal)
if self.inject_noise:
hidden_states = self._feed_spatial_noise(
hidden_states, self.per_channel_scale2.to(device=hidden_states.device, dtype=hidden_states.dtype)
)
input_tensor = self.norm3(input_tensor)
batch_size = input_tensor.shape[0]
input_tensor = self.conv_shortcut(input_tensor)
output_tensor = input_tensor + hidden_states
@@ -634,33 +801,71 @@ class processor(nn.Module):
return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)
class VideoVAE(nn.Module):
def __init__(self):
def __init__(self, version=0):
super().__init__()
config = {
"_class_name": "CausalVideoAutoencoder",
"dims": 3,
"in_channels": 3,
"out_channels": 3,
"latent_channels": 128,
"blocks": [
["res_x", 4],
["compress_all", 1],
["res_x_y", 1],
["res_x", 3],
["compress_all", 1],
["res_x_y", 1],
["res_x", 3],
["compress_all", 1],
["res_x", 3],
["res_x", 4],
],
"scaling_factor": 1.0,
"norm_layer": "pixel_norm",
"patch_size": 4,
"latent_log_var": "uniform",
"use_quant_conv": False,
"causal_decoder": False,
}
if version == 0:
config = {
"_class_name": "CausalVideoAutoencoder",
"dims": 3,
"in_channels": 3,
"out_channels": 3,
"latent_channels": 128,
"blocks": [
["res_x", 4],
["compress_all", 1],
["res_x_y", 1],
["res_x", 3],
["compress_all", 1],
["res_x_y", 1],
["res_x", 3],
["compress_all", 1],
["res_x", 3],
["res_x", 4],
],
"scaling_factor": 1.0,
"norm_layer": "pixel_norm",
"patch_size": 4,
"latent_log_var": "uniform",
"use_quant_conv": False,
"causal_decoder": False,
}
else:
config = {
"_class_name": "CausalVideoAutoencoder",
"dims": 3,
"in_channels": 3,
"out_channels": 3,
"latent_channels": 128,
"decoder_blocks": [
["res_x", {"num_layers": 5, "inject_noise": True}],
["compress_all", {"residual": True, "multiplier": 2}],
["res_x", {"num_layers": 6, "inject_noise": True}],
["compress_all", {"residual": True, "multiplier": 2}],
["res_x", {"num_layers": 7, "inject_noise": True}],
["compress_all", {"residual": True, "multiplier": 2}],
["res_x", {"num_layers": 8, "inject_noise": False}]
],
"encoder_blocks": [
["res_x", {"num_layers": 4}],
["compress_all", {}],
["res_x_y", 1],
["res_x", {"num_layers": 3}],
["compress_all", {}],
["res_x_y", 1],
["res_x", {"num_layers": 3}],
["compress_all", {}],
["res_x", {"num_layers": 3}],
["res_x", {"num_layers": 4}]
],
"scaling_factor": 1.0,
"norm_layer": "pixel_norm",
"patch_size": 4,
"latent_log_var": "uniform",
"use_quant_conv": False,
"causal_decoder": False,
"timestep_conditioning": True,
}
double_z = config.get("double_z", True)
latent_log_var = config.get(
@@ -671,7 +876,7 @@ class VideoVAE(nn.Module):
dims=config["dims"],
in_channels=config.get("in_channels", 3),
out_channels=config["latent_channels"],
blocks=config.get("encoder_blocks", config.get("blocks")),
blocks=config.get("encoder_blocks", config.get("encoder_blocks", config.get("blocks"))),
patch_size=config.get("patch_size", 1),
latent_log_var=latent_log_var,
norm_layer=config.get("norm_layer", "group_norm"),
@@ -681,18 +886,22 @@ class VideoVAE(nn.Module):
dims=config["dims"],
in_channels=config["latent_channels"],
out_channels=config.get("out_channels", 3),
blocks=config.get("decoder_blocks", config.get("blocks")),
blocks=config.get("decoder_blocks", config.get("decoder_blocks", config.get("blocks"))),
patch_size=config.get("patch_size", 1),
norm_layer=config.get("norm_layer", "group_norm"),
causal=config.get("causal_decoder", False),
timestep_conditioning=config.get("timestep_conditioning", False),
)
self.timestep_conditioning = config.get("timestep_conditioning", False)
self.per_channel_statistics = processor()
def encode(self, x):
means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
return self.per_channel_statistics.normalize(means)
def decode(self, x):
return self.decoder(self.per_channel_statistics.un_normalize(x))
def decode(self, x, timestep=0.05, noise_scale=0.025):
if self.timestep_conditioning: #TODO: seed
x = torch.randn_like(x) * noise_scale + (1.0 - noise_scale) * x
return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=timestep)
@@ -1,6 +1,5 @@
from typing import Tuple, Union
import torch
from .dual_conv3d import DualConv3d
from .causal_conv3d import CausalConv3d
+18 -9
View File
@@ -1,10 +1,12 @@
import logging
import math
import torch
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, Tuple, Union
from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistribution
from comfy.ldm.util import instantiate_from_config
from comfy.ldm.util import get_obj_from_str, instantiate_from_config
from comfy.ldm.modules.ema import LitEma
import comfy.ops
@@ -52,7 +54,7 @@ class AbstractAutoencoder(torch.nn.Module):
if self.use_ema:
self.model_ema = LitEma(self, decay=ema_decay)
logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
logging.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
def get_input(self, batch) -> Any:
raise NotImplementedError()
@@ -68,14 +70,14 @@ class AbstractAutoencoder(torch.nn.Module):
self.model_ema.store(self.parameters())
self.model_ema.copy_to(self)
if context is not None:
logpy.info(f"{context}: Switched to EMA weights")
logging.info(f"{context}: Switched to EMA weights")
try:
yield None
finally:
if self.use_ema:
self.model_ema.restore(self.parameters())
if context is not None:
logpy.info(f"{context}: Restored training weights")
logging.info(f"{context}: Restored training weights")
def encode(self, *args, **kwargs) -> torch.Tensor:
raise NotImplementedError("encode()-method of abstract base class called")
@@ -84,7 +86,7 @@ class AbstractAutoencoder(torch.nn.Module):
raise NotImplementedError("decode()-method of abstract base class called")
def instantiate_optimizer_from_config(self, params, lr, cfg):
logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
logging.info(f"loading >>> {cfg['target']} <<< optimizer from config")
return get_obj_from_str(cfg["target"])(
params, lr=lr, **cfg.get("params", dict())
)
@@ -112,7 +114,7 @@ class AutoencodingEngine(AbstractAutoencoder):
self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
self.regularization: AbstractRegularizer = instantiate_from_config(
self.regularization = instantiate_from_config(
regularizer_config
)
@@ -160,12 +162,19 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
},
**kwargs,
)
self.quant_conv = comfy.ops.disable_weight_init.Conv2d(
if ddconfig.get("conv3d", False):
conv_op = comfy.ops.disable_weight_init.Conv3d
else:
conv_op = comfy.ops.disable_weight_init.Conv2d
self.quant_conv = conv_op(
(1 + ddconfig["double_z"]) * ddconfig["z_channels"],
(1 + ddconfig["double_z"]) * embed_dim,
1,
)
self.post_quant_conv = comfy.ops.disable_weight_init.Conv2d(embed_dim, ddconfig["z_channels"], 1)
self.post_quant_conv = conv_op(embed_dim, ddconfig["z_channels"], 1)
self.embed_dim = embed_dim
def get_autoencoder_params(self) -> list:
+92 -34
View File
@@ -15,6 +15,9 @@ if model_management.xformers_enabled():
import xformers
import xformers.ops
if model_management.sage_attention_enabled():
from sageattention import sageattn
from comfy.cli_args import args
import comfy.ops
ops = comfy.ops.disable_weight_init
@@ -157,8 +160,6 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
b, _, dim_head = query.shape
dim_head //= heads
scale = dim_head ** -0.5
if skip_reshape:
query = query.reshape(b * heads, -1, dim_head)
value = value.reshape(b * heads, -1, dim_head)
@@ -177,9 +178,8 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
bytes_per_token = torch.finfo(query.dtype).bits//8
batch_x_heads, q_tokens, _ = query.shape
_, _, k_tokens = key.shape
qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
mem_free_total, mem_free_torch = model_management.get_free_memory(query.device, True)
mem_free_total, _ = model_management.get_free_memory(query.device, True)
kv_chunk_size_min = None
kv_chunk_size = None
@@ -230,7 +230,6 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
scale = dim_head ** -0.5
h = heads
if skip_reshape:
q, k, v = map(
lambda t: t.reshape(b * heads, -1, dim_head),
@@ -344,12 +343,9 @@ except:
pass
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
if skip_reshape:
b, _, _, dim_head = q.shape
else:
b, _, dim_head = q.shape
dim_head //= heads
b = q.shape[0]
dim_head = q.shape[-1]
# check to make sure xformers isn't broken
disabled_xformers = False
if BROKEN_XFORMERS:
@@ -364,35 +360,44 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape)
if skip_reshape:
q, k, v = map(
lambda t: t.reshape(b * heads, -1, dim_head),
# b h k d -> b k h d
q, k, v = map(
lambda t: t.permute(0, 2, 1, 3),
(q, k, v),
)
# actually do the reshaping
else:
dim_head //= heads
q, k, v = map(
lambda t: t.reshape(b, -1, heads, dim_head),
(q, k, v),
)
if mask is not None:
# add a singleton batch dimension
if mask.ndim == 2:
mask = mask.unsqueeze(0)
# add a singleton heads dimension
if mask.ndim == 3:
mask = mask.unsqueeze(1)
# pad to a multiple of 8
pad = 8 - mask.shape[-1] % 8
mask_out = torch.empty([q.shape[0], q.shape[2], q.shape[1], mask.shape[-1] + pad], dtype=q.dtype, device=q.device)
# the xformers docs says that it's allowed to have a mask of shape (1, Nq, Nk)
# but when using separated heads, the shape has to be (B, H, Nq, Nk)
# in flux, this matrix ends up being over 1GB
# here, we create a mask with the same batch/head size as the input mask (potentially singleton or full)
mask_out = torch.empty([mask.shape[0], mask.shape[1], q.shape[1], mask.shape[-1] + pad], dtype=q.dtype, device=q.device)
mask_out[..., :mask.shape[-1]] = mask
# doesn't this remove the padding again??
mask = mask_out[..., :mask.shape[-1]]
mask = mask.expand(b, heads, -1, -1)
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
if skip_reshape:
out = (
out.unsqueeze(0)
.reshape(b, heads, -1, dim_head)
.permute(0, 2, 1, 3)
.reshape(b, -1, heads * dim_head)
)
else:
out = (
out.reshape(b, -1, heads * dim_head)
)
out = (
out.reshape(b, -1, heads * dim_head)
)
return out
@@ -414,32 +419,85 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
(q, k, v),
)
if SDP_BATCH_LIMIT >= q.shape[0]:
if mask is not None:
# add a batch dimension if there isn't already one
if mask.ndim == 2:
mask = mask.unsqueeze(0)
# add a heads dimension if there isn't already one
if mask.ndim == 3:
mask = mask.unsqueeze(1)
if SDP_BATCH_LIMIT >= b:
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
out = (
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
)
else:
out = torch.empty((q.shape[0], q.shape[2], heads * dim_head), dtype=q.dtype, layout=q.layout, device=q.device)
for i in range(0, q.shape[0], SDP_BATCH_LIMIT):
out[i : i + SDP_BATCH_LIMIT] = torch.nn.functional.scaled_dot_product_attention(q[i : i + SDP_BATCH_LIMIT], k[i : i + SDP_BATCH_LIMIT], v[i : i + SDP_BATCH_LIMIT], attn_mask=mask, dropout_p=0.0, is_causal=False).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head)
out = torch.empty((b, q.shape[2], heads * dim_head), dtype=q.dtype, layout=q.layout, device=q.device)
for i in range(0, b, SDP_BATCH_LIMIT):
m = mask
if mask is not None:
if mask.shape[0] > 1:
m = mask[i : i + SDP_BATCH_LIMIT]
out[i : i + SDP_BATCH_LIMIT] = torch.nn.functional.scaled_dot_product_attention(
q[i : i + SDP_BATCH_LIMIT],
k[i : i + SDP_BATCH_LIMIT],
v[i : i + SDP_BATCH_LIMIT],
attn_mask=m,
dropout_p=0.0, is_causal=False
).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head)
return out
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
if skip_reshape:
b, _, _, dim_head = q.shape
tensor_layout="HND"
else:
b, _, dim_head = q.shape
dim_head //= heads
q, k, v = map(
lambda t: t.view(b, -1, heads, dim_head),
(q, k, v),
)
tensor_layout="NHD"
if mask is not None:
# add a batch dimension if there isn't already one
if mask.ndim == 2:
mask = mask.unsqueeze(0)
# add a heads dimension if there isn't already one
if mask.ndim == 3:
mask = mask.unsqueeze(1)
out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
if tensor_layout == "HND":
out = (
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
)
else:
out = out.reshape(b, -1, heads * dim_head)
return out
optimized_attention = attention_basic
if model_management.xformers_enabled():
logging.info("Using xformers cross attention")
if model_management.sage_attention_enabled():
logging.info("Using sage attention")
optimized_attention = attention_sage
elif model_management.xformers_enabled():
logging.info("Using xformers attention")
optimized_attention = attention_xformers
elif model_management.pytorch_attention_enabled():
logging.info("Using pytorch cross attention")
logging.info("Using pytorch attention")
optimized_attention = attention_pytorch
else:
if args.use_split_cross_attention:
logging.info("Using split optimization for cross attention")
logging.info("Using split optimization for attention")
optimized_attention = attention_split
else:
logging.info("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
logging.info("Using sub quadratic optimization for attention, if you have memory or speed issues try using: --use-split-cross-attention")
optimized_attention = attention_sub_quad
optimized_attention_masked = optimized_attention
+14 -27
View File
@@ -1,5 +1,4 @@
import logging
import math
from functools import partial
from typing import Dict, Optional, List
import numpy as np
@@ -72,45 +71,33 @@ class PatchEmbed(nn.Module):
strict_img_size: bool = True,
dynamic_img_pad: bool = True,
padding_mode='circular',
conv3d=False,
dtype=None,
device=None,
operations=None,
):
super().__init__()
self.patch_size = (patch_size, patch_size)
try:
len(patch_size)
self.patch_size = patch_size
except:
if conv3d:
self.patch_size = (patch_size, patch_size, patch_size)
else:
self.patch_size = (patch_size, patch_size)
self.padding_mode = padding_mode
if img_size is not None:
self.img_size = (img_size, img_size)
self.grid_size = tuple([s // p for s, p in zip(self.img_size, self.patch_size)])
self.num_patches = self.grid_size[0] * self.grid_size[1]
else:
self.img_size = None
self.grid_size = None
self.num_patches = None
# flatten spatial dim and transpose to channels last, kept for bwd compat
self.flatten = flatten
self.strict_img_size = strict_img_size
self.dynamic_img_pad = dynamic_img_pad
self.proj = operations.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, dtype=dtype, device=device)
if conv3d:
self.proj = operations.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, dtype=dtype, device=device)
else:
self.proj = operations.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, dtype=dtype, device=device)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
# B, C, H, W = x.shape
# if self.img_size is not None:
# if self.strict_img_size:
# _assert(H == self.img_size[0], f"Input height ({H}) doesn't match model ({self.img_size[0]}).")
# _assert(W == self.img_size[1], f"Input width ({W}) doesn't match model ({self.img_size[1]}).")
# elif not self.dynamic_img_pad:
# _assert(
# H % self.patch_size[0] == 0,
# f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})."
# )
# _assert(
# W % self.patch_size[1] == 0,
# f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})."
# )
if self.dynamic_img_pad:
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size, padding_mode=self.padding_mode)
x = self.proj(x)
+148 -64
View File
@@ -3,7 +3,6 @@ import math
import torch
import torch.nn as nn
import numpy as np
from typing import Optional, Any
import logging
from comfy import model_management
@@ -44,51 +43,100 @@ def Normalize(in_channels, num_groups=32):
return ops.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
class VideoConv3d(nn.Module):
def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding_mode='replicate', padding=1, **kwargs):
super().__init__()
self.padding_mode = padding_mode
if padding != 0:
padding = (padding, padding, padding, padding, kernel_size - 1, 0)
else:
kwargs["padding"] = padding
self.padding = padding
self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, dilation=dilation, **kwargs)
def forward(self, x):
if self.padding != 0:
x = torch.nn.functional.pad(x, self.padding, mode=self.padding_mode)
return self.conv(x)
def interpolate_up(x, scale_factor):
try:
return torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode="nearest")
except: #operation not implemented for bf16
orig_shape = list(x.shape)
out_shape = orig_shape[:2]
for i in range(len(orig_shape) - 2):
out_shape.append(round(orig_shape[i + 2] * scale_factor[i]))
out = torch.empty(out_shape, dtype=x.dtype, layout=x.layout, device=x.device)
split = 8
l = out.shape[1] // split
for i in range(0, out.shape[1], l):
out[:,i:i+l] = torch.nn.functional.interpolate(x[:,i:i+l].to(torch.float32), scale_factor=scale_factor, mode="nearest").to(x.dtype)
return out
class Upsample(nn.Module):
def __init__(self, in_channels, with_conv):
def __init__(self, in_channels, with_conv, conv_op=ops.Conv2d, scale_factor=2.0):
super().__init__()
self.with_conv = with_conv
self.scale_factor = scale_factor
if self.with_conv:
self.conv = ops.Conv2d(in_channels,
self.conv = conv_op(in_channels,
in_channels,
kernel_size=3,
stride=1,
padding=1)
def forward(self, x):
try:
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
except: #operation not implemented for bf16
b, c, h, w = x.shape
out = torch.empty((b, c, h*2, w*2), dtype=x.dtype, layout=x.layout, device=x.device)
split = 8
l = out.shape[1] // split
for i in range(0, out.shape[1], l):
out[:,i:i+l] = torch.nn.functional.interpolate(x[:,i:i+l].to(torch.float32), scale_factor=2.0, mode="nearest").to(x.dtype)
del x
x = out
scale_factor = self.scale_factor
if isinstance(scale_factor, (int, float)):
scale_factor = (scale_factor,) * (x.ndim - 2)
if x.ndim == 5 and scale_factor[0] > 1.0:
t = x.shape[2]
if t > 1:
a, b = x.split((1, t - 1), dim=2)
del x
b = interpolate_up(b, scale_factor)
else:
a = x
a = interpolate_up(a.squeeze(2), scale_factor=scale_factor[1:]).unsqueeze(2)
if t > 1:
x = torch.cat((a, b), dim=2)
else:
x = a
else:
x = interpolate_up(x, scale_factor)
if self.with_conv:
x = self.conv(x)
return x
class Downsample(nn.Module):
def __init__(self, in_channels, with_conv):
def __init__(self, in_channels, with_conv, stride=2, conv_op=ops.Conv2d):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
# no asymmetric padding in torch conv, must do it ourselves
self.conv = ops.Conv2d(in_channels,
self.conv = conv_op(in_channels,
in_channels,
kernel_size=3,
stride=2,
stride=stride,
padding=0)
def forward(self, x):
if self.with_conv:
pad = (0,1,0,1)
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
if x.ndim == 4:
pad = (0, 1, 0, 1)
mode = "constant"
x = torch.nn.functional.pad(x, pad, mode=mode, value=0)
elif x.ndim == 5:
pad = (1, 1, 1, 1, 2, 0)
mode = "replicate"
x = torch.nn.functional.pad(x, pad, mode=mode)
x = self.conv(x)
else:
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
@@ -97,7 +145,7 @@ class Downsample(nn.Module):
class ResnetBlock(nn.Module):
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
dropout, temb_channels=512):
dropout, temb_channels=512, conv_op=ops.Conv2d):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
@@ -106,7 +154,7 @@ class ResnetBlock(nn.Module):
self.swish = torch.nn.SiLU(inplace=True)
self.norm1 = Normalize(in_channels)
self.conv1 = ops.Conv2d(in_channels,
self.conv1 = conv_op(in_channels,
out_channels,
kernel_size=3,
stride=1,
@@ -116,20 +164,20 @@ class ResnetBlock(nn.Module):
out_channels)
self.norm2 = Normalize(out_channels)
self.dropout = torch.nn.Dropout(dropout, inplace=True)
self.conv2 = ops.Conv2d(out_channels,
self.conv2 = conv_op(out_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = ops.Conv2d(in_channels,
self.conv_shortcut = conv_op(in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
else:
self.nin_shortcut = ops.Conv2d(in_channels,
self.nin_shortcut = conv_op(in_channels,
out_channels,
kernel_size=1,
stride=1,
@@ -163,7 +211,6 @@ def slice_attention(q, k, v):
mem_free_total = model_management.get_free_memory(q.device)
gb = 1024 ** 3
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
modifier = 3 if q.element_size() == 2 else 2.5
mem_required = tensor_size * modifier
@@ -196,21 +243,25 @@ def slice_attention(q, k, v):
def normal_attention(q, k, v):
# compute attention
b,c,h,w = q.shape
orig_shape = q.shape
b = orig_shape[0]
c = orig_shape[1]
q = q.reshape(b,c,h*w)
q = q.permute(0,2,1) # b,hw,c
k = k.reshape(b,c,h*w) # b,c,hw
v = v.reshape(b,c,h*w)
q = q.reshape(b, c, -1)
q = q.permute(0, 2, 1) # b,hw,c
k = k.reshape(b, c, -1) # b,c,hw
v = v.reshape(b, c, -1)
r1 = slice_attention(q, k, v)
h_ = r1.reshape(b,c,h,w)
h_ = r1.reshape(orig_shape)
del r1
return h_
def xformers_attention(q, k, v):
# compute attention
B, C, H, W = q.shape
orig_shape = q.shape
B = orig_shape[0]
C = orig_shape[1]
q, k, v = map(
lambda t: t.view(B, C, -1).transpose(1, 2).contiguous(),
(q, k, v),
@@ -218,14 +269,16 @@ def xformers_attention(q, k, v):
try:
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
out = out.transpose(1, 2).reshape(B, C, H, W)
except NotImplementedError as e:
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
out = out.transpose(1, 2).reshape(orig_shape)
except NotImplementedError:
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(orig_shape)
return out
def pytorch_attention(q, k, v):
# compute attention
B, C, H, W = q.shape
orig_shape = q.shape
B = orig_shape[0]
C = orig_shape[1]
q, k, v = map(
lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
(q, k, v),
@@ -233,35 +286,35 @@ def pytorch_attention(q, k, v):
try:
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
out = out.transpose(2, 3).reshape(B, C, H, W)
except model_management.OOM_EXCEPTION as e:
out = out.transpose(2, 3).reshape(orig_shape)
except model_management.OOM_EXCEPTION:
logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(orig_shape)
return out
class AttnBlock(nn.Module):
def __init__(self, in_channels):
def __init__(self, in_channels, conv_op=ops.Conv2d):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = ops.Conv2d(in_channels,
self.q = conv_op(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.k = ops.Conv2d(in_channels,
self.k = conv_op(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.v = ops.Conv2d(in_channels,
self.v = conv_op(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.proj_out = ops.Conv2d(in_channels,
self.proj_out = conv_op(in_channels,
in_channels,
kernel_size=1,
stride=1,
@@ -291,8 +344,8 @@ class AttnBlock(nn.Module):
return x+h_
def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
return AttnBlock(in_channels)
def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None, conv_op=ops.Conv2d):
return AttnBlock(in_channels, conv_op=conv_op)
class Model(nn.Module):
@@ -451,6 +504,7 @@ class Encoder(nn.Module):
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
conv3d=False, time_compress=None,
**ignore_kwargs):
super().__init__()
if use_linear_attn: attn_type = "linear"
@@ -461,8 +515,15 @@ class Encoder(nn.Module):
self.resolution = resolution
self.in_channels = in_channels
if conv3d:
conv_op = VideoConv3d
mid_attn_conv_op = ops.Conv3d
else:
conv_op = ops.Conv2d
mid_attn_conv_op = ops.Conv2d
# downsampling
self.conv_in = ops.Conv2d(in_channels,
self.conv_in = conv_op(in_channels,
self.ch,
kernel_size=3,
stride=1,
@@ -481,15 +542,20 @@ class Encoder(nn.Module):
block.append(ResnetBlock(in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
dropout=dropout,
conv_op=conv_op))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type))
attn.append(make_attn(block_in, attn_type=attn_type, conv_op=conv_op))
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions-1:
down.downsample = Downsample(block_in, resamp_with_conv)
stride = 2
if time_compress is not None:
if (self.num_resolutions - 1 - i_level) > math.log2(time_compress):
stride = (1, 2, 2)
down.downsample = Downsample(block_in, resamp_with_conv, stride=stride, conv_op=conv_op)
curr_res = curr_res // 2
self.down.append(down)
@@ -498,16 +564,18 @@ class Encoder(nn.Module):
self.mid.block_1 = ResnetBlock(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
dropout=dropout,
conv_op=conv_op)
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type, conv_op=mid_attn_conv_op)
self.mid.block_2 = ResnetBlock(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
dropout=dropout,
conv_op=conv_op)
# end
self.norm_out = Normalize(block_in)
self.conv_out = ops.Conv2d(block_in,
self.conv_out = conv_op(block_in,
2*z_channels if double_z else z_channels,
kernel_size=3,
stride=1,
@@ -545,9 +613,10 @@ class Decoder(nn.Module):
conv_out_op=ops.Conv2d,
resnet_op=ResnetBlock,
attn_op=AttnBlock,
conv3d=False,
time_compress=None,
**ignorekwargs):
super().__init__()
if use_linear_attn: attn_type = "linear"
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
@@ -557,8 +626,15 @@ class Decoder(nn.Module):
self.give_pre_end = give_pre_end
self.tanh_out = tanh_out
# compute in_ch_mult, block_in and curr_res at lowest res
in_ch_mult = (1,)+tuple(ch_mult)
if conv3d:
conv_op = VideoConv3d
conv_out_op = VideoConv3d
mid_attn_conv_op = ops.Conv3d
else:
conv_op = ops.Conv2d
mid_attn_conv_op = ops.Conv2d
# compute block_in and curr_res at lowest res
block_in = ch*ch_mult[self.num_resolutions-1]
curr_res = resolution // 2**(self.num_resolutions-1)
self.z_shape = (1,z_channels,curr_res,curr_res)
@@ -566,7 +642,7 @@ class Decoder(nn.Module):
self.z_shape, np.prod(self.z_shape)))
# z to block_in
self.conv_in = ops.Conv2d(z_channels,
self.conv_in = conv_op(z_channels,
block_in,
kernel_size=3,
stride=1,
@@ -577,12 +653,14 @@ class Decoder(nn.Module):
self.mid.block_1 = resnet_op(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
self.mid.attn_1 = attn_op(block_in)
dropout=dropout,
conv_op=conv_op)
self.mid.attn_1 = attn_op(block_in, conv_op=mid_attn_conv_op)
self.mid.block_2 = resnet_op(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
dropout=dropout,
conv_op=conv_op)
# upsampling
self.up = nn.ModuleList()
@@ -594,15 +672,21 @@ class Decoder(nn.Module):
block.append(resnet_op(in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
dropout=dropout,
conv_op=conv_op))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(attn_op(block_in))
attn.append(attn_op(block_in, conv_op=conv_op))
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample(block_in, resamp_with_conv)
scale_factor = 2.0
if time_compress is not None:
if i_level > math.log2(time_compress):
scale_factor = (1.0, 2.0, 2.0)
up.upsample = Upsample(block_in, resamp_with_conv, conv_op=conv_op, scale_factor=scale_factor)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
@@ -9,7 +9,6 @@ import logging
from .util import (
checkpoint,
avg_pool_nd,
zero_module,
timestep_embedding,
AlphaBlender,
)
@@ -4,7 +4,6 @@ import numpy as np
from functools import partial
from .util import extract_into_tensor, make_beta_schedule
from comfy.ldm.util import default
class AbstractLowScaleModel(nn.Module):
+4 -4
View File
@@ -8,8 +8,8 @@
# thanks!
import os
import math
import logging
import torch
import torch.nn as nn
import numpy as np
@@ -131,7 +131,7 @@ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timestep
# add one to get the final alpha values right (the ones from first scale to data during sampling)
steps_out = ddim_timesteps + 1
if verbose:
print(f'Selected timesteps for ddim sampler: {steps_out}')
logging.info(f'Selected timesteps for ddim sampler: {steps_out}')
return steps_out
@@ -143,8 +143,8 @@ def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
# according the the formula provided in https://arxiv.org/abs/2010.02502
sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
if verbose:
print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
print(f'For the chosen value of eta, which is {eta}, '
logging.info(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
logging.info(f'For the chosen value of eta, which is {eta}, '
f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
return sigmas, alphas, alphas_prev
@@ -30,10 +30,10 @@ class DiagonalGaussianDistribution(object):
self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar)
if self.deterministic:
self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
self.var = self.std = torch.zeros_like(self.mean, device=self.parameters.device)
def sample(self):
x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
x = self.mean + self.std * torch.randn(self.mean.shape, device=self.parameters.device)
return x
def kl(self, other=None):
+1 -2
View File
@@ -22,7 +22,6 @@ except ImportError:
from typing import Optional, NamedTuple, List
from typing_extensions import Protocol
from torch import Tensor
from typing import List
from comfy import model_management
@@ -172,7 +171,7 @@ def _get_attention_scores_no_kv_chunking(
del attn_scores
except model_management.OOM_EXCEPTION:
logging.warning("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
attn_scores -= attn_scores.max(dim=-1, keepdim=True).values
attn_scores -= attn_scores.max(dim=-1, keepdim=True).values # noqa: F821 attn_scores is not defined
torch.exp(attn_scores, out=attn_scores)
summed = torch.sum(attn_scores, dim=-1, keepdim=True)
attn_scores /= summed
+2 -1
View File
@@ -1,5 +1,5 @@
import functools
from typing import Callable, Iterable, Union
from typing import Iterable, Union
import torch
from einops import rearrange, repeat
@@ -194,6 +194,7 @@ def make_time_attn(
attn_kwargs=None,
alpha: float = 0,
merge_strategy: str = "learned",
conv_op=ops.Conv2d,
):
return partialclass(
AttnVideoBlock, in_channels, alpha=alpha, merge_strategy=merge_strategy
+380
View File
@@ -0,0 +1,380 @@
# Based on:
# https://github.com/PixArt-alpha/PixArt-alpha [Apache 2.0 license]
# https://github.com/PixArt-alpha/PixArt-sigma [Apache 2.0 license]
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, Mlp, timestep_embedding
from comfy.ldm.modules.attention import optimized_attention
# if model_management.xformers_enabled():
# import xformers.ops
# if int((xformers.__version__).split(".")[2].split("+")[0]) >= 28:
# block_diagonal_mask_from_seqlens = xformers.ops.fmha.attn_bias.BlockDiagonalMask.from_seqlens
# else:
# block_diagonal_mask_from_seqlens = xformers.ops.fmha.BlockDiagonalMask.from_seqlens
def modulate(x, shift, scale):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
def t2i_modulate(x, shift, scale):
return x * (1 + scale) + shift
class MultiHeadCrossAttention(nn.Module):
def __init__(self, d_model, num_heads, attn_drop=0., proj_drop=0., dtype=None, device=None, operations=None, **kwargs):
super(MultiHeadCrossAttention, self).__init__()
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
self.d_model = d_model
self.num_heads = num_heads
self.head_dim = d_model // num_heads
self.q_linear = operations.Linear(d_model, d_model, dtype=dtype, device=device)
self.kv_linear = operations.Linear(d_model, d_model*2, dtype=dtype, device=device)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = operations.Linear(d_model, d_model, dtype=dtype, device=device)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, cond, mask=None):
# query/value: img tokens; key: condition; mask: if padding tokens
B, N, C = x.shape
q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
k, v = kv.unbind(2)
assert mask is None # TODO?
# # TODO: xformers needs separate mask logic here
# if model_management.xformers_enabled():
# attn_bias = None
# if mask is not None:
# attn_bias = block_diagonal_mask_from_seqlens([N] * B, mask)
# x = xformers.ops.memory_efficient_attention(q, k, v, p=0, attn_bias=attn_bias)
# else:
# q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v),)
# attn_mask = None
# mask = torch.ones(())
# if mask is not None and len(mask) > 1:
# # Create equivalent of xformer diagonal block mask, still only correct for square masks
# # But depth doesn't matter as tensors can expand in that dimension
# attn_mask_template = torch.ones(
# [q.shape[2] // B, mask[0]],
# dtype=torch.bool,
# device=q.device
# )
# attn_mask = torch.block_diag(attn_mask_template)
#
# # create a mask on the diagonal for each mask in the batch
# for _ in range(B - 1):
# attn_mask = torch.block_diag(attn_mask, attn_mask_template)
# x = optimized_attention(q, k, v, self.num_heads, mask=attn_mask, skip_reshape=True)
x = optimized_attention(q.view(B, -1, C), k.view(B, -1, C), v.view(B, -1, C), self.num_heads, mask=None)
x = self.proj(x)
x = self.proj_drop(x)
return x
class AttentionKVCompress(nn.Module):
"""Multi-head Attention block with KV token compression and qk norm."""
def __init__(self, dim, num_heads=8, qkv_bias=True, sampling='conv', sr_ratio=1, qk_norm=False, dtype=None, device=None, operations=None, **kwargs):
"""
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
qkv_bias (bool: If True, add a learnable bias to query, key, value.
"""
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
self.sampling=sampling # ['conv', 'ave', 'uniform', 'uniform_every']
self.sr_ratio = sr_ratio
if sr_ratio > 1 and sampling == 'conv':
# Avg Conv Init.
self.sr = operations.Conv2d(dim, dim, groups=dim, kernel_size=sr_ratio, stride=sr_ratio, dtype=dtype, device=device)
# self.sr.weight.data.fill_(1/sr_ratio**2)
# self.sr.bias.data.zero_()
self.norm = operations.LayerNorm(dim, dtype=dtype, device=device)
if qk_norm:
self.q_norm = operations.LayerNorm(dim, dtype=dtype, device=device)
self.k_norm = operations.LayerNorm(dim, dtype=dtype, device=device)
else:
self.q_norm = nn.Identity()
self.k_norm = nn.Identity()
def downsample_2d(self, tensor, H, W, scale_factor, sampling=None):
if sampling is None or scale_factor == 1:
return tensor
B, N, C = tensor.shape
if sampling == 'uniform_every':
return tensor[:, ::scale_factor], int(N // scale_factor)
tensor = tensor.reshape(B, H, W, C).permute(0, 3, 1, 2)
new_H, new_W = int(H / scale_factor), int(W / scale_factor)
new_N = new_H * new_W
if sampling == 'ave':
tensor = F.interpolate(
tensor, scale_factor=1 / scale_factor, mode='nearest'
).permute(0, 2, 3, 1)
elif sampling == 'uniform':
tensor = tensor[:, :, ::scale_factor, ::scale_factor].permute(0, 2, 3, 1)
elif sampling == 'conv':
tensor = self.sr(tensor).reshape(B, C, -1).permute(0, 2, 1)
tensor = self.norm(tensor)
else:
raise ValueError
return tensor.reshape(B, new_N, C).contiguous(), new_N
def forward(self, x, mask=None, HW=None, block_id=None):
B, N, C = x.shape # 2 4096 1152
new_N = N
if HW is None:
H = W = int(N ** 0.5)
else:
H, W = HW
qkv = self.qkv(x).reshape(B, N, 3, C)
q, k, v = qkv.unbind(2)
q = self.q_norm(q)
k = self.k_norm(k)
# KV compression
if self.sr_ratio > 1:
k, new_N = self.downsample_2d(k, H, W, self.sr_ratio, sampling=self.sampling)
v, new_N = self.downsample_2d(v, H, W, self.sr_ratio, sampling=self.sampling)
q = q.reshape(B, N, self.num_heads, C // self.num_heads)
k = k.reshape(B, new_N, self.num_heads, C // self.num_heads)
v = v.reshape(B, new_N, self.num_heads, C // self.num_heads)
if mask is not None:
raise NotImplementedError("Attn mask logic not added for self attention")
# This is never called at the moment
# attn_bias = None
# if mask is not None:
# attn_bias = torch.zeros([B * self.num_heads, q.shape[1], k.shape[1]], dtype=q.dtype, device=q.device)
# attn_bias.masked_fill_(mask.squeeze(1).repeat(self.num_heads, 1, 1) == 0, float('-inf'))
# attention 2
q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v),)
x = optimized_attention(q, k, v, self.num_heads, mask=None, skip_reshape=True)
x = x.view(B, N, C)
x = self.proj(x)
return x
class FinalLayer(nn.Module):
"""
The final layer of PixArt.
"""
def __init__(self, hidden_size, patch_size, out_channels, dtype=None, device=None, operations=None):
super().__init__()
self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device)
)
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
class T2IFinalLayer(nn.Module):
"""
The final layer of PixArt.
"""
def __init__(self, hidden_size, patch_size, out_channels, dtype=None, device=None, operations=None):
super().__init__()
self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
self.scale_shift_table = nn.Parameter(torch.randn(2, hidden_size) / hidden_size ** 0.5)
self.out_channels = out_channels
def forward(self, x, t):
shift, scale = (self.scale_shift_table[None].to(dtype=x.dtype, device=x.device) + t[:, None]).chunk(2, dim=1)
x = t2i_modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
class MaskFinalLayer(nn.Module):
"""
The final layer of PixArt.
"""
def __init__(self, final_hidden_size, c_emb_size, patch_size, out_channels, dtype=None, device=None, operations=None):
super().__init__()
self.norm_final = operations.LayerNorm(final_hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.linear = operations.Linear(final_hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operations.Linear(c_emb_size, 2 * final_hidden_size, bias=True, dtype=dtype, device=device)
)
def forward(self, x, t):
shift, scale = self.adaLN_modulation(t).chunk(2, dim=1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
class DecoderLayer(nn.Module):
"""
The final layer of PixArt.
"""
def __init__(self, hidden_size, decoder_hidden_size, dtype=None, device=None, operations=None):
super().__init__()
self.norm_decoder = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.linear = operations.Linear(hidden_size, decoder_hidden_size, bias=True, dtype=dtype, device=device)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device)
)
def forward(self, x, t):
shift, scale = self.adaLN_modulation(t).chunk(2, dim=1)
x = modulate(self.norm_decoder(x), shift, scale)
x = self.linear(x)
return x
class SizeEmbedder(TimestepEmbedder):
"""
Embeds scalar timesteps into vector representations.
"""
def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None):
super().__init__(hidden_size=hidden_size, frequency_embedding_size=frequency_embedding_size, operations=operations)
self.mlp = nn.Sequential(
operations.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device),
nn.SiLU(),
operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
)
self.frequency_embedding_size = frequency_embedding_size
self.outdim = hidden_size
def forward(self, s, bs):
if s.ndim == 1:
s = s[:, None]
assert s.ndim == 2
if s.shape[0] != bs:
s = s.repeat(bs//s.shape[0], 1)
assert s.shape[0] == bs
b, dims = s.shape[0], s.shape[1]
s = rearrange(s, "b d -> (b d)")
s_freq = timestep_embedding(s, self.frequency_embedding_size)
s_emb = self.mlp(s_freq.to(s.dtype))
s_emb = rearrange(s_emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=self.outdim)
return s_emb
class LabelEmbedder(nn.Module):
"""
Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
"""
def __init__(self, num_classes, hidden_size, dropout_prob, dtype=None, device=None, operations=None):
super().__init__()
use_cfg_embedding = dropout_prob > 0
self.embedding_table = operations.Embedding(num_classes + use_cfg_embedding, hidden_size, dtype=dtype, device=device),
self.num_classes = num_classes
self.dropout_prob = dropout_prob
def token_drop(self, labels, force_drop_ids=None):
"""
Drops labels to enable classifier-free guidance.
"""
if force_drop_ids is None:
drop_ids = torch.rand(labels.shape[0]).cuda() < self.dropout_prob
else:
drop_ids = force_drop_ids == 1
labels = torch.where(drop_ids, self.num_classes, labels)
return labels
def forward(self, labels, train, force_drop_ids=None):
use_dropout = self.dropout_prob > 0
if (train and use_dropout) or (force_drop_ids is not None):
labels = self.token_drop(labels, force_drop_ids)
embeddings = self.embedding_table(labels)
return embeddings
class CaptionEmbedder(nn.Module):
"""
Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
"""
def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU(approximate='tanh'), token_num=120, dtype=None, device=None, operations=None):
super().__init__()
self.y_proj = Mlp(
in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size, act_layer=act_layer,
dtype=dtype, device=device, operations=operations,
)
self.register_buffer("y_embedding", nn.Parameter(torch.randn(token_num, in_channels) / in_channels ** 0.5))
self.uncond_prob = uncond_prob
def token_drop(self, caption, force_drop_ids=None):
"""
Drops labels to enable classifier-free guidance.
"""
if force_drop_ids is None:
drop_ids = torch.rand(caption.shape[0]).cuda() < self.uncond_prob
else:
drop_ids = force_drop_ids == 1
caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption)
return caption
def forward(self, caption, train, force_drop_ids=None):
if train:
assert caption.shape[2:] == self.y_embedding.shape
use_dropout = self.uncond_prob > 0
if (train and use_dropout) or (force_drop_ids is not None):
caption = self.token_drop(caption, force_drop_ids)
caption = self.y_proj(caption)
return caption
class CaptionEmbedderDoubleBr(nn.Module):
"""
Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
"""
def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU(approximate='tanh'), token_num=120, dtype=None, device=None, operations=None):
super().__init__()
self.proj = Mlp(
in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size, act_layer=act_layer,
dtype=dtype, device=device, operations=operations,
)
self.embedding = nn.Parameter(torch.randn(1, in_channels) / 10 ** 0.5)
self.y_embedding = nn.Parameter(torch.randn(token_num, in_channels) / 10 ** 0.5)
self.uncond_prob = uncond_prob
def token_drop(self, global_caption, caption, force_drop_ids=None):
"""
Drops labels to enable classifier-free guidance.
"""
if force_drop_ids is None:
drop_ids = torch.rand(global_caption.shape[0]).cuda() < self.uncond_prob
else:
drop_ids = force_drop_ids == 1
global_caption = torch.where(drop_ids[:, None], self.embedding, global_caption)
caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption)
return global_caption, caption
def forward(self, caption, train, force_drop_ids=None):
assert caption.shape[2: ] == self.y_embedding.shape
global_caption = caption.mean(dim=2).squeeze()
use_dropout = self.uncond_prob > 0
if (train and use_dropout) or (force_drop_ids is not None):
global_caption, caption = self.token_drop(global_caption, caption, force_drop_ids)
y_embed = self.proj(global_caption)
return y_embed, caption
+256
View File
@@ -0,0 +1,256 @@
# Based on:
# https://github.com/PixArt-alpha/PixArt-alpha [Apache 2.0 license]
# https://github.com/PixArt-alpha/PixArt-sigma [Apache 2.0 license]
import torch
import torch.nn as nn
from .blocks import (
t2i_modulate,
CaptionEmbedder,
AttentionKVCompress,
MultiHeadCrossAttention,
T2IFinalLayer,
SizeEmbedder,
)
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, PatchEmbed, Mlp, get_1d_sincos_pos_embed_from_grid_torch
def get_2d_sincos_pos_embed_torch(embed_dim, w, h, pe_interpolation=1.0, base_size=16, device=None, dtype=torch.float32):
grid_h, grid_w = torch.meshgrid(
torch.arange(h, device=device, dtype=dtype) / (h/base_size) / pe_interpolation,
torch.arange(w, device=device, dtype=dtype) / (w/base_size) / pe_interpolation,
indexing='ij'
)
emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_h, device=device, dtype=dtype)
emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_w, device=device, dtype=dtype)
emb = torch.cat([emb_w, emb_h], dim=1) # (H*W, D)
return emb
class PixArtMSBlock(nn.Module):
"""
A PixArt block with adaptive layer norm zero (adaLN-Zero) conditioning.
"""
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0., input_size=None,
sampling=None, sr_ratio=1, qk_norm=False, dtype=None, device=None, operations=None, **block_kwargs):
super().__init__()
self.hidden_size = hidden_size
self.norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.attn = AttentionKVCompress(
hidden_size, num_heads=num_heads, qkv_bias=True, sampling=sampling, sr_ratio=sr_ratio,
qk_norm=qk_norm, dtype=dtype, device=device, operations=operations, **block_kwargs
)
self.cross_attn = MultiHeadCrossAttention(
hidden_size, num_heads, dtype=dtype, device=device, operations=operations, **block_kwargs
)
self.norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
# to be compatible with lower version pytorch
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.mlp = Mlp(
in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu,
dtype=dtype, device=device, operations=operations
)
self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size ** 0.5)
def forward(self, x, y, t, mask=None, HW=None, **kwargs):
B, N, C = x.shape
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None].to(dtype=x.dtype, device=x.device) + t.reshape(B, 6, -1)).chunk(6, dim=1)
x = x + (gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa), HW=HW))
x = x + self.cross_attn(x, y, mask)
x = x + (gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))
return x
### Core PixArt Model ###
class PixArtMS(nn.Module):
"""
Diffusion model with a Transformer backbone.
"""
def __init__(
self,
input_size=32,
patch_size=2,
in_channels=4,
hidden_size=1152,
depth=28,
num_heads=16,
mlp_ratio=4.0,
class_dropout_prob=0.1,
learn_sigma=True,
pred_sigma=True,
drop_path: float = 0.,
caption_channels=4096,
pe_interpolation=None,
pe_precision=None,
config=None,
model_max_length=120,
micro_condition=True,
qk_norm=False,
kv_compress_config=None,
dtype=None,
device=None,
operations=None,
**kwargs,
):
nn.Module.__init__(self)
self.dtype = dtype
self.pred_sigma = pred_sigma
self.in_channels = in_channels
self.out_channels = in_channels * 2 if pred_sigma else in_channels
self.patch_size = patch_size
self.num_heads = num_heads
self.pe_interpolation = pe_interpolation
self.pe_precision = pe_precision
self.hidden_size = hidden_size
self.depth = depth
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.t_block = nn.Sequential(
nn.SiLU(),
operations.Linear(hidden_size, 6 * hidden_size, bias=True, dtype=dtype, device=device)
)
self.x_embedder = PatchEmbed(
patch_size=patch_size,
in_chans=in_channels,
embed_dim=hidden_size,
bias=True,
dtype=dtype,
device=device,
operations=operations
)
self.t_embedder = TimestepEmbedder(
hidden_size, dtype=dtype, device=device, operations=operations,
)
self.y_embedder = CaptionEmbedder(
in_channels=caption_channels, hidden_size=hidden_size, uncond_prob=class_dropout_prob,
act_layer=approx_gelu, token_num=model_max_length,
dtype=dtype, device=device, operations=operations,
)
self.micro_conditioning = micro_condition
if self.micro_conditioning:
self.csize_embedder = SizeEmbedder(hidden_size//3, dtype=dtype, device=device, operations=operations)
self.ar_embedder = SizeEmbedder(hidden_size//3, dtype=dtype, device=device, operations=operations)
# For fixed sin-cos embedding:
# num_patches = (input_size // patch_size) * (input_size // patch_size)
# self.base_size = input_size // self.patch_size
# self.register_buffer("pos_embed", torch.zeros(1, num_patches, hidden_size))
drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule
if kv_compress_config is None:
kv_compress_config = {
'sampling': None,
'scale_factor': 1,
'kv_compress_layer': [],
}
self.blocks = nn.ModuleList([
PixArtMSBlock(
hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i],
sampling=kv_compress_config['sampling'],
sr_ratio=int(kv_compress_config['scale_factor']) if i in kv_compress_config['kv_compress_layer'] else 1,
qk_norm=qk_norm,
dtype=dtype,
device=device,
operations=operations,
)
for i in range(depth)
])
self.final_layer = T2IFinalLayer(
hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations
)
def forward_orig(self, x, timestep, y, mask=None, c_size=None, c_ar=None, **kwargs):
"""
Original forward pass of PixArt.
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
t: (N,) tensor of diffusion timesteps
y: (N, 1, 120, C) conditioning
ar: (N, 1): aspect ratio
cs: (N ,2) size conditioning for height/width
"""
B, C, H, W = x.shape
c_res = (H + W) // 2
pe_interpolation = self.pe_interpolation
if pe_interpolation is None or self.pe_precision is not None:
# calculate pe_interpolation on-the-fly
pe_interpolation = round(c_res / (512/8.0), self.pe_precision or 0)
pos_embed = get_2d_sincos_pos_embed_torch(
self.hidden_size,
h=(H // self.patch_size),
w=(W // self.patch_size),
pe_interpolation=pe_interpolation,
base_size=((round(c_res / 64) * 64) // self.patch_size),
device=x.device,
dtype=x.dtype,
).unsqueeze(0)
x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2
t = self.t_embedder(timestep, x.dtype) # (N, D)
if self.micro_conditioning and (c_size is not None and c_ar is not None):
bs = x.shape[0]
c_size = self.csize_embedder(c_size, bs) # (N, D)
c_ar = self.ar_embedder(c_ar, bs) # (N, D)
t = t + torch.cat([c_size, c_ar], dim=1)
t0 = self.t_block(t)
y = self.y_embedder(y, self.training) # (N, D)
if mask is not None:
if mask.shape[0] != y.shape[0]:
mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
mask = mask.squeeze(1).squeeze(1)
y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
y_lens = mask.sum(dim=1).tolist()
else:
y_lens = None
y = y.squeeze(1).view(1, -1, x.shape[-1])
for block in self.blocks:
x = block(x, y, t0, y_lens, (H, W), **kwargs) # (N, T, D)
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
x = self.unpatchify(x, H, W) # (N, out_channels, H, W)
return x
def forward(self, x, timesteps, context, c_size=None, c_ar=None, **kwargs):
B, C, H, W = x.shape
# Fallback for missing microconds
if self.micro_conditioning:
if c_size is None:
c_size = torch.tensor([H*8, W*8], dtype=x.dtype, device=x.device).repeat(B, 1)
if c_ar is None:
c_ar = torch.tensor([H/W], dtype=x.dtype, device=x.device).repeat(B, 1)
## Still accepts the input w/o that dim but returns garbage
if len(context.shape) == 3:
context = context.unsqueeze(1)
## run original forward pass
out = self.forward_orig(x, timesteps, context, c_size=c_size, c_ar=c_ar)
## only return EPS
if self.pred_sigma:
return out[:, :self.in_channels]
return out
def unpatchify(self, x, h, w):
"""
x: (N, T, patch_size**2 * C)
imgs: (N, H, W, C)
"""
c = self.out_channels
p = self.x_embedder.patch_size[0]
h = h // self.patch_size
w = w // self.patch_size
assert h * w == x.shape[1]
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
x = torch.einsum('nhwpqc->nchpwq', x)
imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
return imgs
+3 -3
View File
@@ -1,4 +1,5 @@
import importlib
import logging
import torch
from torch import optim
@@ -23,7 +24,7 @@ def log_txt_as_img(wh, xc, size=10):
try:
draw.text((0, 0), lines, fill="black", font=font)
except UnicodeEncodeError:
print("Cant encode string for logging. Skipping.")
logging.warning("Cant encode string for logging. Skipping.")
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
txts.append(txt)
@@ -65,7 +66,7 @@ def mean_flat(tensor):
def count_params(model, verbose=False):
total_params = sum(p.numel() for p in model.parameters())
if verbose:
print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
logging.info(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
return total_params
@@ -133,7 +134,6 @@ class AdamWwithEMAandWings(optim.Optimizer):
exp_avgs = []
exp_avg_sqs = []
ema_params_with_grad = []
state_sums = []
max_exp_avg_sqs = []
state_steps = []
amsgrad = group['amsgrad']
+26 -1
View File
@@ -344,7 +344,6 @@ def model_lora_keys_unet(model, key_map={}):
key_lora = "lycoris_{}".format(k[:-len(".weight")].replace(".", "_")) #simpletuner lycoris format
key_map[key_lora] = to
if isinstance(model, comfy.model_base.AuraFlow): #Diffusers lora AuraFlow
diffusers_keys = comfy.utils.auraflow_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
for k in diffusers_keys:
@@ -353,6 +352,20 @@ def model_lora_keys_unet(model, key_map={}):
key_lora = "transformer.{}".format(k[:-len(".weight")]) #simpletrainer and probably regular diffusers lora format
key_map[key_lora] = to
if isinstance(model, comfy.model_base.PixArt):
diffusers_keys = comfy.utils.pixart_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
for k in diffusers_keys:
if k.endswith(".weight"):
to = diffusers_keys[k]
key_lora = "transformer.{}".format(k[:-len(".weight")]) #default format
key_map[key_lora] = to
key_lora = "base_model.model.{}".format(k[:-len(".weight")]) #diffusers training script
key_map[key_lora] = to
key_lora = "unet.base_model.model.{}".format(k[:-len(".weight")]) #old reference peft script
key_map[key_lora] = to
if isinstance(model, comfy.model_base.HunyuanDiT):
for k in sdk:
if k.startswith("diffusion_model.") and k.endswith(".weight"):
@@ -374,6 +387,18 @@ def model_lora_keys_unet(model, key_map={}):
key_lora = k[len("diffusion_model."):-len(".weight")]
key_map["{}".format(key_lora)] = k
if isinstance(model, comfy.model_base.HunyuanVideo):
for k in sdk:
if k.startswith("diffusion_model.") and k.endswith(".weight"):
# diffusion-pipe lora format
key_lora = k
key_lora = key_lora.replace("_mod.lin.", "_mod.linear.").replace("_attn.qkv.", "_attn_qkv.").replace("_attn.proj.", "_attn_proj.")
key_lora = key_lora.replace("mlp.0.", "mlp.fc1.").replace("mlp.2.", "mlp.fc2.")
key_lora = key_lora.replace(".modulation.lin.", ".modulation.linear.")
key_lora = key_lora[len("diffusion_model."):-len(".weight")]
key_map["transformer.{}".format(key_lora)] = k
key_map["diffusion_model.{}".format(key_lora)] = k # Old loras
return key_map
+54 -4
View File
@@ -26,11 +26,13 @@ from comfy.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAug
from comfy.ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
import comfy.ldm.genmo.joint_model.asymm_models_joint
import comfy.ldm.aura.mmdit
import comfy.ldm.pixart.pixartms
import comfy.ldm.hydit.models
import comfy.ldm.audio.dit
import comfy.ldm.audio.embedders
import comfy.ldm.flux.model
import comfy.ldm.lightricks.model
import comfy.ldm.hunyuan_video.model
import comfy.model_management
import comfy.patcher_extension
@@ -427,7 +429,6 @@ class SVD_img2vid(BaseModel):
latent_image = kwargs.get("concat_latent_image", None)
noise = kwargs.get("noise", None)
device = kwargs["device"]
if latent_image is None:
latent_image = torch.zeros_like(noise)
@@ -687,6 +688,7 @@ class StableAudio1(BaseModel):
sd["{}{}".format(k, l)] = s[l]
return sd
class HunyuanDiT(BaseModel):
def __init__(self, model_config, model_type=ModelType.V_PREDICTION, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hydit.models.HunYuanDiT)
@@ -711,14 +713,31 @@ class HunyuanDiT(BaseModel):
width = kwargs.get("width", 768)
height = kwargs.get("height", 768)
crop_w = kwargs.get("crop_w", 0)
crop_h = kwargs.get("crop_h", 0)
target_width = kwargs.get("target_width", width)
target_height = kwargs.get("target_height", height)
out['image_meta_size'] = comfy.conds.CONDRegular(torch.FloatTensor([[height, width, target_height, target_width, 0, 0]]))
return out
class PixArt(BaseModel):
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.pixart.pixartms.PixArtMS)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
width = kwargs.get("width", None)
height = kwargs.get("height", None)
if width is not None and height is not None:
out["c_size"] = comfy.conds.CONDRegular(torch.FloatTensor([[height, width]]))
out["c_ar"] = comfy.conds.CONDRegular(torch.FloatTensor([[kwargs.get("aspect_ratio", height/width)]]))
return out
class Flux(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.flux.model.Flux)
@@ -755,7 +774,6 @@ class Flux(BaseModel):
mask = torch.ones_like(noise)[:, :1]
mask = torch.mean(mask, dim=1, keepdim=True)
print(mask.shape)
mask = utils.common_upscale(mask.to(device), noise.shape[-1] * 8, noise.shape[-2] * 8, "bilinear", "center")
mask = mask.view(mask.shape[0], mask.shape[2] // 8, 8, mask.shape[3] // 8, 8).permute(0, 2, 4, 1, 3).reshape(mask.shape[0], -1, mask.shape[2] // 8, mask.shape[3] // 8)
mask = utils.resize_to_batch_size(mask, noise.shape[0])
@@ -769,6 +787,16 @@ class Flux(BaseModel):
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
# upscale the attention mask, since now we
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
shape = kwargs["noise"].shape
mask_ref_size = kwargs["attention_mask_img_shape"]
# the model will pad to the patch size, and then divide
# essentially dividing and rounding up
(h_tok, w_tok) = (math.ceil(shape[2] / self.diffusion_model.patch_size), math.ceil(shape[3] / self.diffusion_model.patch_size))
attention_mask = utils.upscale_dit_mask(attention_mask, mask_ref_size, (h_tok, w_tok))
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([kwargs.get("guidance", 3.5)]))
return out
@@ -804,5 +832,27 @@ class LTXV(BaseModel):
if guiding_latent is not None:
out['guiding_latent'] = comfy.conds.CONDRegular(guiding_latent)
guiding_latent_noise_scale = kwargs.get("guiding_latent_noise_scale", None)
if guiding_latent_noise_scale is not None:
out["guiding_latent_noise_scale"] = comfy.conds.CONDConstant(guiding_latent_noise_scale)
out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25))
return out
class HunyuanVideo(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo)
def encode_adm(self, **kwargs):
return kwargs["pooled_output"]
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([kwargs.get("guidance", 6.0)]))
return out
+54 -2
View File
@@ -133,6 +133,26 @@ def detect_unet_config(state_dict, key_prefix):
unet_config["image_model"] = "hydit1"
return unet_config
if '{}txt_in.individual_token_refiner.blocks.0.norm1.weight'.format(key_prefix) in state_dict_keys: #Hunyuan Video
dit_config = {}
dit_config["image_model"] = "hunyuan_video"
dit_config["in_channels"] = 16
dit_config["patch_size"] = [1, 2, 2]
dit_config["out_channels"] = 16
dit_config["vec_in_dim"] = 768
dit_config["context_in_dim"] = 4096
dit_config["hidden_size"] = 3072
dit_config["mlp_ratio"] = 4.0
dit_config["num_heads"] = 24
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
dit_config["axes_dim"] = [16, 56, 56]
dit_config["theta"] = 256
dit_config["qkv_bias"] = True
guidance_keys = list(filter(lambda a: a.startswith("{}guidance_in.".format(key_prefix)), state_dict_keys))
dit_config["guidance_embed"] = len(guidance_keys) > 0
return dit_config
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys: #Flux
dit_config = {}
dit_config["image_model"] = "flux"
@@ -183,11 +203,42 @@ def detect_unet_config(state_dict, key_prefix):
dit_config["rope_theta"] = 10000.0
return dit_config
if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys and '{}pos_embed.proj.bias'.format(key_prefix) in state_dict_keys:
# PixArt diffusers
return None
if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: #Lightricks ltxv
dit_config = {}
dit_config["image_model"] = "ltxv"
return dit_config
if '{}t_block.1.weight'.format(key_prefix) in state_dict_keys: # PixArt
patch_size = 2
dit_config = {}
dit_config["num_heads"] = 16
dit_config["patch_size"] = patch_size
dit_config["hidden_size"] = 1152
dit_config["in_channels"] = 4
dit_config["depth"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.')
y_key = "{}y_embedder.y_embedding".format(key_prefix)
if y_key in state_dict_keys:
dit_config["model_max_length"] = state_dict[y_key].shape[0]
pe_key = "{}pos_embed".format(key_prefix)
if pe_key in state_dict_keys:
dit_config["input_size"] = int(math.sqrt(state_dict[pe_key].shape[1])) * patch_size
dit_config["pe_interpolation"] = dit_config["input_size"] // (512//8) # guess
ar_key = "{}ar_embedder.mlp.0.weight".format(key_prefix)
if ar_key in state_dict_keys:
dit_config["image_model"] = "pixart_alpha"
dit_config["micro_condition"] = True
else:
dit_config["image_model"] = "pixart_sigma"
dit_config["micro_condition"] = False
return dit_config
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
return None
@@ -216,7 +267,6 @@ def detect_unet_config(state_dict, key_prefix):
num_res_blocks = []
channel_mult = []
attention_resolutions = []
transformer_depth = []
transformer_depth_output = []
context_dim = None
@@ -388,7 +438,6 @@ def convert_config(unet_config):
t_out += [d] * (res + 1)
s *= 2
transformer_depth = t_in
transformer_depth_output = t_out
new_config["transformer_depth"] = t_in
new_config["transformer_depth_output"] = t_out
new_config["transformer_depth_middle"] = transformer_depth_middle
@@ -555,6 +604,9 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""):
num_joint = count_blocks(state_dict, 'joint_transformer_blocks.{}.')
num_single = count_blocks(state_dict, 'single_transformer_blocks.{}.')
sd_map = comfy.utils.auraflow_to_diffusers({"n_double_layers": num_joint, "n_layers": num_joint + num_single}, output_prefix=output_prefix)
elif 'adaln_single.emb.timestep_embedder.linear_1.bias' in state_dict and 'pos_embed.proj.bias' in state_dict: # PixArt
num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
sd_map = comfy.utils.pixart_to_diffusers({"depth": num_blocks}, output_prefix=output_prefix)
elif 'x_embedder.weight' in state_dict: #Flux
depth = count_blocks(state_dict, 'transformer_blocks.{}.')
depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.')
+54 -40
View File
@@ -75,7 +75,7 @@ if args.directml is not None:
try:
import intel_extension_for_pytorch as ipex
_ = torch.xpu.device_count()
xpu_available = torch.xpu.is_available()
xpu_available = xpu_available or torch.xpu.is_available()
except:
xpu_available = xpu_available or (hasattr(torch, "xpu") and torch.xpu.is_available())
@@ -188,38 +188,44 @@ def is_nvidia():
return True
return False
def is_amd():
global cpu_state
if cpu_state == CPUState.GPU:
if torch.version.hip:
return True
return False
MIN_WEIGHT_MEMORY_RATIO = 0.4
if is_nvidia():
MIN_WEIGHT_MEMORY_RATIO = 0.2
ENABLE_PYTORCH_ATTENTION = False
if args.use_pytorch_cross_attention:
ENABLE_PYTORCH_ATTENTION = True
XFORMERS_IS_AVAILABLE = False
VAE_DTYPES = [torch.float32]
try:
if is_nvidia():
if int(torch_version[0]) >= 2:
if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
ENABLE_PYTORCH_ATTENTION = True
if torch.cuda.is_bf16_supported() and torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8:
VAE_DTYPES = [torch.bfloat16] + VAE_DTYPES
if is_intel_xpu():
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
ENABLE_PYTORCH_ATTENTION = True
except:
pass
if is_intel_xpu():
VAE_DTYPES = [torch.bfloat16] + VAE_DTYPES
if args.cpu_vae:
VAE_DTYPES = [torch.float32]
if ENABLE_PYTORCH_ATTENTION:
torch.backends.cuda.enable_math_sdp(True)
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
try:
if int(torch_version[0]) == 2 and int(torch_version[2]) >= 5:
torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)
except:
logging.warning("Warning, could not set allow_fp16_bf16_reduction_math_sdp")
if args.lowvram:
set_vram_to = VRAMState.LOW_VRAM
lowvram_available = True
@@ -314,6 +320,9 @@ class LoadedModel:
def model_memory(self):
return self.model.model_size()
def model_loaded_memory(self):
return self.model.loaded_size()
def model_offloaded_memory(self):
return self.model.model_size() - self.model.loaded_size()
@@ -504,15 +513,18 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
lowvram_model_memory = 0
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM) and not force_full_load:
model_size = loaded_model.model_memory_required(torch_dev)
current_free_mem = get_free_memory(torch_dev)
lowvram_model_memory = max(64 * (1024 * 1024), (current_free_mem - minimum_memory_required), min(current_free_mem * 0.4, current_free_mem - minimum_inference_memory()))
loaded_memory = loaded_model.model_loaded_memory()
current_free_mem = get_free_memory(torch_dev) + loaded_memory
lowvram_model_memory = max(64 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
lowvram_model_memory = max(0.1, lowvram_model_memory - loaded_memory)
if model_size <= lowvram_model_memory: #only switch to lowvram if really necessary
lowvram_model_memory = 0
if vram_set_state == VRAMState.NO_VRAM:
lowvram_model_memory = 64 * 1024 * 1024
lowvram_model_memory = 0.1
cur_loaded_model = loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
current_loaded_models.insert(0, loaded_model)
return
@@ -581,7 +593,7 @@ def unet_offload_device():
def unet_inital_load_device(parameters, dtype):
torch_dev = get_torch_device()
if vram_state == VRAMState.HIGH_VRAM:
if vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.SHARED:
return torch_dev
cpu_dev = torch.device("cpu")
@@ -695,7 +707,7 @@ def text_encoder_initial_device(load_device, offload_device, model_size=0):
return offload_device
if is_device_mps(load_device):
return offload_device
return load_device
mem_l = get_free_memory(load_device)
mem_o = get_free_memory(offload_device)
@@ -738,7 +750,6 @@ def vae_offload_device():
return torch.device("cpu")
def vae_dtype(device=None, allowed_dtypes=[]):
global VAE_DTYPES
if args.fp16_vae:
return torch.float16
elif args.bf16_vae:
@@ -747,12 +758,14 @@ def vae_dtype(device=None, allowed_dtypes=[]):
return torch.float32
for d in allowed_dtypes:
if d == torch.float16 and should_use_fp16(device, prioritize_performance=False):
return d
if d in VAE_DTYPES:
if d == torch.float16 and should_use_fp16(device):
return d
return VAE_DTYPES[0]
# NOTE: bfloat16 seems to work on AMD for the VAE but is extremely slow in some cases compared to fp32
if d == torch.bfloat16 and (not is_amd()) and should_use_bf16(device):
return d
return torch.float32
def get_autocast_device(dev):
if hasattr(dev, 'type'):
@@ -837,6 +850,8 @@ def cast_to_device(tensor, device, dtype, copy=False):
non_blocking = device_supports_non_blocking(device)
return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy)
def sage_attention_enabled():
return args.use_sage_attention
def xformers_enabled():
global directml_enabled
@@ -871,14 +886,19 @@ def pytorch_attention_flash_attention():
return True
return False
def mac_version():
try:
return tuple(int(n) for n in platform.mac_ver()[0].split("."))
except:
return None
def force_upcast_attention_dtype():
upcast = args.force_upcast_attention
try:
macos_version = tuple(int(n) for n in platform.mac_ver()[0].split("."))
if (14, 5) <= macos_version <= (15, 2): # black image bug on recent versions of macOS
upcast = True
except:
pass
macos_version = mac_version()
if macos_version is not None and ((14, 5) <= macos_version <= (15, 2)): # black image bug on recent versions of macOS
upcast = True
if upcast:
return torch.float32
else:
@@ -949,17 +969,13 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
if FORCE_FP16:
return True
if device is not None:
if is_device_mps(device):
return True
if FORCE_FP32:
return False
if directml_enabled:
return False
if mps_mode():
if (device is not None and is_device_mps(device)) or mps_mode():
return True
if cpu_mode():
@@ -1008,17 +1024,15 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
if is_device_cpu(device): #TODO ? bf16 works on CPU but is extremely slow
return False
if device is not None:
if is_device_mps(device):
return True
if FORCE_FP32:
return False
if directml_enabled:
return False
if mps_mode():
if (device is not None and is_device_mps(device)) or mps_mode():
if mac_version() < (14,):
return False
return True
if cpu_mode():
@@ -1077,7 +1091,7 @@ def unload_all_models():
def resolve_lowvram_weight(weight, model, key): #TODO: remove
print("WARNING: The comfy.model_management.resolve_lowvram_weight function will be removed soon, please stop using it.")
logging.warning("The comfy.model_management.resolve_lowvram_weight function will be removed soon, please stop using it.")
return weight
#TODO: might be cleaner to put this somewhere else
+3 -3
View File
@@ -773,7 +773,7 @@ class ModelPatcher:
return self.model.device
def calculate_weight(self, patches, weight, key, intermediate_dtype=torch.float32):
print("WARNING the ModelPatcher.calculate_weight function is deprecated, please use: comfy.lora.calculate_weight instead")
logging.warning("The ModelPatcher.calculate_weight function is deprecated, please use: comfy.lora.calculate_weight instead")
return comfy.lora.calculate_weight(patches, weight, key, intermediate_dtype=intermediate_dtype)
def cleanup(self):
@@ -1029,7 +1029,7 @@ class ModelPatcher:
if cached_weights is not None:
for key in cached_weights:
if key not in model_sd_keys:
print(f"WARNING cached hook could not patch. key does not exist in model: {key}")
logging.warning(f"Cached hook could not patch. Key does not exist in model: {key}")
continue
self.patch_cached_hook_weights(cached_weights=cached_weights, key=key, memory_counter=memory_counter)
else:
@@ -1039,7 +1039,7 @@ class ModelPatcher:
original_weights = self.get_key_patches()
for key in relevant_patches:
if key not in model_sd_keys:
print(f"WARNING cached hook would not patch. key does not exist in model: {key}")
logging.warning(f"Cached hook would not patch. Key does not exist in model: {key}")
continue
self.patch_hook_weight_to_device(hooks=hooks, combined_patches=relevant_patches, key=key, original_weights=original_weights,
memory_counter=memory_counter)
+2 -2
View File
@@ -243,7 +243,7 @@ class ModelSamplingDiscreteFlow(torch.nn.Module):
return 1.0
if percent >= 1.0:
return 0.0
return 1.0 - percent
return time_snr_shift(self.shift, 1.0 - percent)
class StableCascadeSampling(ModelSamplingDiscrete):
def __init__(self, model_config=None):
@@ -336,4 +336,4 @@ class ModelSamplingFlux(torch.nn.Module):
return 1.0
if percent >= 1.0:
return 0.0
return 1.0 - percent
return flux_time_shift(self.shift, 1.0, 1.0 - percent)
+10 -8
View File
@@ -255,9 +255,10 @@ def fp8_linear(self, input):
tensor_2d = True
input = input.unsqueeze(1)
input_shape = input.shape
input_dtype = input.dtype
if len(input.shape) == 3:
w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input.dtype)
w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype)
w = w.t()
scale_weight = self.scale_weight
@@ -269,23 +270,24 @@ def fp8_linear(self, input):
if scale_input is None:
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
inn = input.reshape(-1, input.shape[2]).to(dtype)
input = torch.clamp(input, min=-448, max=448, out=input)
input = input.reshape(-1, input_shape[2]).to(dtype)
else:
scale_input = scale_input.to(input.device)
inn = (input * (1.0 / scale_input).to(input.dtype)).reshape(-1, input.shape[2]).to(dtype)
input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype)
if bias is not None:
o = torch._scaled_mm(inn, w, out_dtype=input.dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
o = torch._scaled_mm(input, w, out_dtype=input_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
else:
o = torch._scaled_mm(inn, w, out_dtype=input.dtype, scale_a=scale_input, scale_b=scale_weight)
o = torch._scaled_mm(input, w, out_dtype=input_dtype, scale_a=scale_input, scale_b=scale_weight)
if isinstance(o, tuple):
o = o[0]
if tensor_2d:
return o.reshape(input.shape[0], -1)
return o.reshape(input_shape[0], -1)
return o.reshape((-1, input.shape[1], self.weight.shape[0]))
return o.reshape((-1, input_shape[1], self.weight.shape[0]))
return None
+1 -1
View File
@@ -113,7 +113,7 @@ class WrapperExecutor:
def _create_next_executor(self) -> 'WrapperExecutor':
new_idx = self.idx + 1
if new_idx > len(self.wrappers):
raise Exception(f"Wrapper idx exceeded available wrappers; something went very wrong.")
raise Exception("Wrapper idx exceeded available wrappers; something went very wrong.")
if self.class_obj is None:
return WrapperExecutor.new_executor(self.original, self.wrappers, new_idx)
return WrapperExecutor.new_class_executor(self.original, self.class_obj, self.wrappers, new_idx)
-2
View File
@@ -1,6 +1,5 @@
from __future__ import annotations
import uuid
import torch
import comfy.model_management
import comfy.conds
import comfy.utils
@@ -104,7 +103,6 @@ def cleanup_additional_models(models):
def prepare_sampling(model: 'ModelPatcher', noise_shape, conds):
device = model.load_device
real_model: 'BaseModel' = None
models, inference_memory = get_additional_models(conds, model.model_dtype())
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
+1 -8
View File
@@ -130,11 +130,6 @@ def can_concat_cond(c1, c2):
return cond_equal_size(c1.conditioning, c2.conditioning)
def cond_cat(c_list):
c_crossattn = []
c_concat = []
c_adm = []
crossattn_max_len = 0
temp = {}
for x in c_list:
for k in x:
@@ -346,7 +341,7 @@ def cfg_function(model, cond_pred, uncond_pred, cond_scale, x, timestep, model_o
cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale
for fn in model_options.get("sampler_post_cfg_function", []):
args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred,
args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "cond_scale": cond_scale, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred,
"sigma": timestep, "model_options": model_options, "input": x}
cfg_result = fn(args)
@@ -608,8 +603,6 @@ def pre_run_control(model, conds):
for t in range(len(conds)):
x = conds[t]
timestep_start = None
timestep_end = None
percent_to_timestep_function = lambda a: s.percent_to_sigma(a)
if 'control' in x:
x['control'].pre_run(model, percent_to_timestep_function)
+134 -20
View File
@@ -12,6 +12,7 @@ from .ldm.audio.autoencoder import AudioOobleckVAE
import comfy.ldm.genmo.vae.model
import comfy.ldm.lightricks.vae.causal_video_autoencoder
import yaml
import math
import comfy.utils
@@ -26,11 +27,13 @@ import comfy.text_encoders.sd2_clip
import comfy.text_encoders.sd3_clip
import comfy.text_encoders.sa_t5
import comfy.text_encoders.aura_t5
import comfy.text_encoders.pixart_t5
import comfy.text_encoders.hydit
import comfy.text_encoders.flux
import comfy.text_encoders.long_clipl
import comfy.text_encoders.genmo
import comfy.text_encoders.lt
import comfy.text_encoders.hunyuan_video
import comfy.model_patcher
import comfy.lora
@@ -108,7 +111,7 @@ class CLIP:
model_management.load_models_gpu([self.patcher], force_full_load=True)
self.layer_idx = None
self.use_clip_schedule = False
logging.debug("CLIP model load device: {}, offload device: {}, current: {}".format(load_device, offload_device, params['device']))
logging.info("CLIP model load device: {}, offload device: {}, current: {}, dtype: {}".format(load_device, offload_device, params['device'], dtype))
def clone(self):
n = CLIP(no_init=True)
@@ -256,6 +259,9 @@ class VAE:
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
self.working_dtypes = [torch.bfloat16, torch.float32]
self.downscale_index_formula = None
self.upscale_index_formula = None
if config is None:
if "decoder.mid.block_1.mix_factor" in sd:
encoder_config = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
@@ -306,8 +312,8 @@ class VAE:
self.upscale_ratio = 4
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
if 'quant_conv.weight' in sd:
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4)
if 'post_quant_conv.weight' in sd:
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
else:
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig},
@@ -335,15 +341,41 @@ class VAE:
self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (1.5 * max(shape[2], 7) * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
self.upscale_ratio = (lambda a: max(0, a * 6 - 5), 8, 8)
self.upscale_index_formula = (6, 8, 8)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 5) / 6)), 8, 8)
self.downscale_index_formula = (6, 8, 8)
self.working_dtypes = [torch.float16, torch.float32]
elif "decoder.up_blocks.0.res_blocks.0.conv1.conv.weight" in sd: #lightricks ltxv
self.first_stage_model = comfy.ldm.lightricks.vae.causal_video_autoencoder.VideoVAE()
tensor_conv1 = sd["decoder.up_blocks.0.res_blocks.0.conv1.conv.weight"]
version = 0
if tensor_conv1.shape[0] == 512:
version = 0
elif tensor_conv1.shape[0] == 1024:
version = 1
self.first_stage_model = comfy.ldm.lightricks.vae.causal_video_autoencoder.VideoVAE(version=version)
self.latent_channels = 128
self.latent_dim = 3
self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (70 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 32, 32)
self.upscale_index_formula = (8, 32, 32)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 32, 32)
self.downscale_index_formula = (8, 32, 32)
self.working_dtypes = [torch.bfloat16, torch.float32]
elif "decoder.conv_in.conv.weight" in sd:
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
ddconfig["conv3d"] = True
ddconfig["time_compress"] = 4
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
self.upscale_index_formula = (4, 8, 8)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
self.downscale_index_formula = (4, 8, 8)
self.latent_dim = 3
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
self.memory_used_decode = lambda shape, dtype: (1500 * shape[2] * shape[3] * shape[4] * (4 * 8 * 8)) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (900 * max(shape[2], 2) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
else:
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
self.first_stage_model = None
@@ -370,13 +402,15 @@ class VAE:
self.output_device = model_management.intermediate_device()
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
logging.debug("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
def vae_encode_crop_pixels(self, pixels):
downscale_ratio = self.spacial_compression_encode()
dims = pixels.shape[1:-1]
for d in range(len(dims)):
x = (dims[d] // self.downscale_ratio) * self.downscale_ratio
x_offset = (dims[d] % self.downscale_ratio) // 2
x = (dims[d] // downscale_ratio) * downscale_ratio
x_offset = (dims[d] % downscale_ratio) // 2
if x != dims[d]:
pixels = pixels.narrow(d + 1, x_offset, x)
return pixels
@@ -397,11 +431,11 @@ class VAE:
def decode_tiled_1d(self, samples, tile_x=128, overlap=32):
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
return comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device)
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))
def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device))
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
steps = pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
@@ -420,6 +454,10 @@ class VAE:
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device)
def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=(1, 64, 64)):
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
def decode(self, samples_in):
pixel_samples = None
try:
@@ -435,7 +473,7 @@ class VAE:
if pixel_samples is None:
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
pixel_samples[x:x+batch_number] = out
except model_management.OOM_EXCEPTION as e:
except model_management.OOM_EXCEPTION:
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
dims = samples_in.ndim - 2
if dims == 1:
@@ -450,7 +488,7 @@ class VAE:
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
return pixel_samples
def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None):
def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
dims = samples.ndim - 2
@@ -468,6 +506,13 @@ class VAE:
elif dims == 2:
output = self.decode_tiled_(samples, **args)
elif dims == 3:
if overlap_t is None:
args["overlap"] = (1, overlap, overlap)
else:
args["overlap"] = (max(1, overlap_t), overlap, overlap)
if tile_t is not None:
args["tile_t"] = max(2, tile_t)
output = self.decode_tiled_3d(samples, **args)
return output.movedim(1, -1)
@@ -490,20 +535,58 @@ class VAE:
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
samples[x:x + batch_number] = out
except model_management.OOM_EXCEPTION as e:
except model_management.OOM_EXCEPTION:
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
if len(pixel_samples.shape) == 3:
if self.latent_dim == 3:
tile = 256
overlap = tile // 4
samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
elif self.latent_dim == 1:
samples = self.encode_tiled_1d(pixel_samples)
else:
samples = self.encode_tiled_(pixel_samples)
return samples
def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
model_management.load_model_gpu(self.patcher)
pixel_samples = pixel_samples.movedim(-1,1)
samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap)
dims = self.latent_dim
pixel_samples = pixel_samples.movedim(-1, 1)
if dims == 3:
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) # TODO: calculate mem required for tile
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
args = {}
if tile_x is not None:
args["tile_x"] = tile_x
if tile_y is not None:
args["tile_y"] = tile_y
if overlap is not None:
args["overlap"] = overlap
if dims == 1:
args.pop("tile_y")
samples = self.encode_tiled_1d(pixel_samples, **args)
elif dims == 2:
samples = self.encode_tiled_(pixel_samples, **args)
elif dims == 3:
if tile_t is not None:
tile_t_latent = max(2, self.downscale_ratio[0](tile_t))
else:
tile_t_latent = 9999
args["tile_t"] = self.upscale_ratio[0](tile_t_latent)
if overlap_t is None:
args["overlap"] = (1, overlap, overlap)
else:
args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), overlap, overlap)
maximum = pixel_samples.shape[2]
maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum))
samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args)
return samples
def get_sd(self):
@@ -515,6 +598,18 @@ class VAE:
except:
return self.upscale_ratio
def spacial_compression_encode(self):
try:
return self.downscale_ratio[-1]
except:
return self.downscale_ratio
def temporal_compression_decode(self):
try:
return round(self.upscale_ratio[0](8192) / 8192)
except:
return None
class StyleModel:
def __init__(self, model, device="cpu"):
self.model = model
@@ -544,6 +639,9 @@ class CLIPType(Enum):
FLUX = 6
MOCHI = 7
LTXV = 8
HUNYUAN_VIDEO = 9
PIXART = 10
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
clip_data = []
@@ -559,6 +657,7 @@ class TEModel(Enum):
T5_XXL = 4
T5_XL = 5
T5_BASE = 6
LLAMA3_8 = 7
def detect_te_model(sd):
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
@@ -575,6 +674,8 @@ def detect_te_model(sd):
return TEModel.T5_XL
if "encoder.block.0.layer.0.SelfAttention.k.weight" in sd:
return TEModel.T5_BASE
if "model.layers.0.post_attention_layernorm.weight" in sd:
return TEModel.LLAMA3_8
return None
@@ -587,6 +688,14 @@ def t5xxl_detect(clip_data):
return {}
def llama_detect(clip_data):
weight_name = "model.layers.0.self_attn.k_proj.weight"
for sd in clip_data:
if weight_name in sd:
return comfy.text_encoders.hunyuan_video.llama_detect(sd)
return {}
def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
clip_data = state_dicts
@@ -625,6 +734,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
elif clip_type == CLIPType.LTXV:
clip_target.clip = comfy.text_encoders.lt.ltxv_te(**t5xxl_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.lt.LTXVT5Tokenizer
elif clip_type == CLIPType.PIXART:
clip_target.clip = comfy.text_encoders.pixart_t5.pixart_te(**t5xxl_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.pixart_t5.PixArtTokenizer
else: #CLIPType.MOCHI
clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer
@@ -652,6 +764,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
elif clip_type == CLIPType.FLUX:
clip_target.clip = comfy.text_encoders.flux.flux_clip(**t5xxl_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.flux.FluxTokenizer
elif clip_type == CLIPType.HUNYUAN_VIDEO:
clip_target.clip = comfy.text_encoders.hunyuan_video.hunyuan_video_clip(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer
else:
clip_target.clip = sdxl_clip.SDXLClipModel
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
@@ -691,7 +806,6 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
config = yaml.safe_load(stream)
model_config_params = config['model']['params']
clip_config = model_config_params['cond_stage_config']
scale_factor = model_config_params['scale_factor']
if "parameterization" in model_config_params:
if model_config_params["parameterization"] == "v":
@@ -861,11 +975,11 @@ def load_diffusion_model(unet_path, model_options={}):
return model
def load_unet(unet_path, dtype=None):
print("WARNING: the load_unet function has been deprecated and will be removed please switch to: load_diffusion_model")
logging.warning("The load_unet function has been deprecated and will be removed please switch to: load_diffusion_model")
return load_diffusion_model(unet_path, model_options={"dtype": dtype})
def load_unet_state_dict(sd, dtype=None):
print("WARNING: the load_unet_state_dict function has been deprecated and will be removed please switch to: load_diffusion_model_state_dict")
logging.warning("The load_unet_state_dict function has been deprecated and will be removed please switch to: load_diffusion_model_state_dict")
return load_diffusion_model_state_dict(sd, model_options={"dtype": dtype})
def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None, extra_keys={}):
+61 -21
View File
@@ -10,6 +10,7 @@ import comfy.clip_model
import json
import logging
import numbers
import re
def gen_empty_tokens(special_tokens, length):
start_token = special_tokens.get("start", None)
@@ -36,7 +37,10 @@ class ClipTokenWeightEncoder:
sections = len(to_encode)
if has_weights or sections == 0:
to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len))
if hasattr(self, "gen_empty_tokens"):
to_encode.append(self.gen_empty_tokens(self.special_tokens, max_token_len))
else:
to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len))
o = self.encode(to_encode)
out, pooled = o[:2]
@@ -90,8 +94,11 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
if textmodel_json_config is None:
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
with open(textmodel_json_config) as f:
config = json.load(f)
if isinstance(textmodel_json_config, dict):
config = textmodel_json_config
else:
with open(textmodel_json_config) as f:
config = json.load(f)
operations = model_options.get("custom_operations", None)
scaled_fp8 = None
@@ -196,11 +203,18 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
attention_mask = None
if self.enable_attention_masks or self.zero_out_masked or self.return_attention_masks:
attention_mask = torch.zeros_like(tokens)
end_token = self.special_tokens.get("end", -1)
end_token = self.special_tokens.get("end", None)
if end_token is None:
cmp_token = self.special_tokens.get("pad", -1)
else:
cmp_token = end_token
for x in range(attention_mask.shape[0]):
for y in range(attention_mask.shape[1]):
attention_mask[x, y] = 1
if tokens[x, y] == end_token:
if tokens[x, y] == cmp_token:
if end_token is None:
attention_mask[x, y] = 0
break
attention_mask_model = None
@@ -326,7 +340,6 @@ def expand_directory_list(directories):
return list(dirs)
def bundled_embed(embed, prefix, suffix): #bundled embedding in lora format
i = 0
out_list = []
for k in embed:
if k.startswith(prefix) and k.endswith(suffix):
@@ -382,7 +395,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
embed_out = safe_load_embed_zip(embed_path)
else:
embed = torch.load(embed_path, map_location="cpu")
except Exception as e:
except Exception:
logging.warning("{}\n\nerror loading embedding, skipping loading: {}".format(traceback.format_exc(), embedding_name))
return None
@@ -411,22 +424,31 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
return embed_out
class SDTokenizer:
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, pad_to_max_length=True, min_length=None, pad_token=None, tokenizer_data={}):
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, tokenizer_data={}):
if tokenizer_path is None:
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path)
self.max_length = max_length
self.min_length = min_length
self.end_token = None
empty = self.tokenizer('')["input_ids"]
self.tokenizer_adds_end_token = has_end_token
if has_start_token:
self.tokens_start = 1
self.start_token = empty[0]
self.end_token = empty[1]
if end_token is not None:
self.end_token = end_token
else:
if has_end_token:
self.end_token = empty[1]
else:
self.tokens_start = 0
self.start_token = None
self.end_token = empty[0]
if end_token is not None:
self.end_token = end_token
else:
self.end_token = empty[0]
if pad_token is not None:
self.pad_token = pad_token
@@ -451,13 +473,16 @@ class SDTokenizer:
Takes a potential embedding name and tries to retrieve it.
Returns a Tuple consisting of the embedding and any leftover string, embedding can be None.
'''
split_embed = embedding_name.split()
embedding_name = split_embed[0]
leftover = ' '.join(split_embed[1:])
embed = load_embed(embedding_name, self.embedding_directory, self.embedding_size, self.embedding_key)
if embed is None:
stripped = embedding_name.strip(',')
if len(stripped) < len(embedding_name):
embed = load_embed(stripped, self.embedding_directory, self.embedding_size, self.embedding_key)
return (embed, embedding_name[len(stripped):])
return (embed, "")
return (embed, "{} {}".format(embedding_name[len(stripped):], leftover))
return (embed, leftover)
def tokenize_with_weights(self, text:str, return_word_ids=False):
@@ -471,13 +496,18 @@ class SDTokenizer:
text = escape_important(text)
parsed_weights = token_weights(text, 1.0)
#tokenize words
# tokenize words
tokens = []
for weighted_segment, weight in parsed_weights:
to_tokenize = unescape_important(weighted_segment).replace("\n", " ").split(' ')
to_tokenize = unescape_important(weighted_segment)
split = re.split(' {0}|\n{0}'.format(self.embedding_identifier), to_tokenize)
to_tokenize = [split[0]]
for i in range(1, len(split)):
to_tokenize.append("{}{}".format(self.embedding_identifier, split[i]))
to_tokenize = [x for x in to_tokenize if x != ""]
for word in to_tokenize:
#if we find an embedding, deal with the embedding
# if we find an embedding, deal with the embedding
if word.startswith(self.embedding_identifier) and self.embedding_directory is not None:
embedding_name = word[len(self.embedding_identifier):].strip('\n')
embed, leftover = self._try_get_embedding(embedding_name)
@@ -493,8 +523,11 @@ class SDTokenizer:
word = leftover
else:
continue
end = 999999999999
if self.tokenizer_adds_end_token:
end = -1
#parse word
tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][self.tokens_start:-1]])
tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][self.tokens_start:end]])
#reshape token array to CLIP input size
batched_tokens = []
@@ -505,18 +538,24 @@ class SDTokenizer:
for i, t_group in enumerate(tokens):
#determine if we're going to try and keep the tokens in a single batch
is_large = len(t_group) >= self.max_word_length
if self.end_token is not None:
has_end_token = 1
else:
has_end_token = 0
while len(t_group) > 0:
if len(t_group) + len(batch) > self.max_length - 1:
remaining_length = self.max_length - len(batch) - 1
if len(t_group) + len(batch) > self.max_length - has_end_token:
remaining_length = self.max_length - len(batch) - has_end_token
#break word in two and add end token
if is_large:
batch.extend([(t,w,i+1) for t,w in t_group[:remaining_length]])
batch.append((self.end_token, 1.0, 0))
if self.end_token is not None:
batch.append((self.end_token, 1.0, 0))
t_group = t_group[remaining_length:]
#add end token and pad
else:
batch.append((self.end_token, 1.0, 0))
if self.end_token is not None:
batch.append((self.end_token, 1.0, 0))
if self.pad_to_max_length:
batch.extend([(self.pad_token, 1.0, 0)] * (remaining_length))
#start new batch
@@ -529,7 +568,8 @@ class SDTokenizer:
t_group = []
#fill last batch
batch.append((self.end_token, 1.0, 0))
if self.end_token is not None:
batch.append((self.end_token, 1.0, 0))
if self.pad_to_max_length:
batch.extend([(self.pad_token, 1.0, 0)] * (self.max_length - len(batch)))
if self.min_length is not None and len(batch) < self.min_length:
+82 -3
View File
@@ -8,10 +8,12 @@ import comfy.text_encoders.sd2_clip
import comfy.text_encoders.sd3_clip
import comfy.text_encoders.sa_t5
import comfy.text_encoders.aura_t5
import comfy.text_encoders.pixart_t5
import comfy.text_encoders.hydit
import comfy.text_encoders.flux
import comfy.text_encoders.genmo
import comfy.text_encoders.lt
import comfy.text_encoders.hunyuan_video
from . import supported_models_base
from . import latent_formats
@@ -224,7 +226,6 @@ class SDXL(supported_models_base.BASE):
def process_clip_state_dict_for_saving(self, state_dict):
replace_prefix = {}
keys_to_replace = {}
state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g")
for k in state_dict:
if k.startswith("clip_l"):
@@ -527,7 +528,6 @@ class SD3(supported_models_base.BASE):
clip_l = False
clip_g = False
t5 = False
dtype_t5 = None
pref = self.text_encoder_key_prefix[0]
if "{}clip_l.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict:
clip_l = True
@@ -593,6 +593,37 @@ class AuraFlow(supported_models_base.BASE):
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(comfy.text_encoders.aura_t5.AuraT5Tokenizer, comfy.text_encoders.aura_t5.AuraT5Model)
class PixArtAlpha(supported_models_base.BASE):
unet_config = {
"image_model": "pixart_alpha",
}
sampling_settings = {
"beta_schedule" : "sqrt_linear",
"linear_start" : 0.0001,
"linear_end" : 0.02,
"timesteps" : 1000,
}
unet_extra_config = {}
latent_format = latent_formats.SD15
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.PixArt(self, device=device)
return out.eval()
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.PixArtT5XXL)
class PixArtSigma(PixArtAlpha):
unet_config = {
"image_model": "pixart_sigma",
}
latent_format = latent_formats.SDXL
class HunyuanDiT(supported_models_base.BASE):
unet_config = {
"image_model": "hydit",
@@ -740,6 +771,54 @@ class LTXV(supported_models_base.BASE):
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.lt.LTXVT5Tokenizer, comfy.text_encoders.lt.ltxv_te(**t5_detect))
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV]
class HunyuanVideo(supported_models_base.BASE):
unet_config = {
"image_model": "hunyuan_video",
}
sampling_settings = {
"shift": 7.0,
}
unet_extra_config = {}
latent_format = latent_formats.HunyuanVideo
memory_usage_factor = 2.0 #TODO
supported_inference_dtypes = [torch.bfloat16, torch.float32]
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.HunyuanVideo(self, device=device)
return out
def process_unet_state_dict(self, state_dict):
out_sd = {}
for k in list(state_dict.keys()):
key_out = k
key_out = key_out.replace("txt_in.t_embedder.mlp.0.", "txt_in.t_embedder.in_layer.").replace("txt_in.t_embedder.mlp.2.", "txt_in.t_embedder.out_layer.")
key_out = key_out.replace("txt_in.c_embedder.linear_1.", "txt_in.c_embedder.in_layer.").replace("txt_in.c_embedder.linear_2.", "txt_in.c_embedder.out_layer.")
key_out = key_out.replace("_mod.linear.", "_mod.lin.").replace("_attn_qkv.", "_attn.qkv.")
key_out = key_out.replace("mlp.fc1.", "mlp.0.").replace("mlp.fc2.", "mlp.2.")
key_out = key_out.replace("_attn_q_norm.weight", "_attn.norm.query_norm.scale").replace("_attn_k_norm.weight", "_attn.norm.key_norm.scale")
key_out = key_out.replace(".q_norm.weight", ".norm.query_norm.scale").replace(".k_norm.weight", ".norm.key_norm.scale")
key_out = key_out.replace("_attn_proj.", "_attn.proj.")
key_out = key_out.replace(".modulation.linear.", ".modulation.lin.")
key_out = key_out.replace("_in.mlp.2.", "_in.out_layer.").replace("_in.mlp.0.", "_in.in_layer.")
out_sd[key_out] = state_dict[k]
return out_sd
def process_unet_state_dict_for_saving(self, state_dict):
replace_prefix = {"": "model.model."}
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}llama.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer, comfy.text_encoders.hunyuan_video.hunyuan_video_clip(**hunyuan_detect))
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo]
models += [SVD_img2vid]
+112
View File
@@ -0,0 +1,112 @@
from comfy import sd1_clip
import comfy.model_management
import comfy.text_encoders.llama
from transformers import LlamaTokenizerFast
import torch
import os
def llama_detect(state_dict, prefix=""):
out = {}
t5_key = "{}model.norm.weight".format(prefix)
if t5_key in state_dict:
out["dtype_llama"] = state_dict[t5_key].dtype
scaled_fp8_key = "{}scaled_fp8".format(prefix)
if scaled_fp8_key in state_dict:
out["llama_scaled_fp8"] = state_dict[scaled_fp8_key].dtype
return out
class LLAMA3Tokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=256):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "llama_tokenizer")
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='llama', tokenizer_class=LlamaTokenizerFast, has_start_token=True, has_end_token=False, pad_to_max_length=False, max_length=99999999, pad_token=128258, end_token=128009, min_length=min_length)
class LLAMAModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}):
llama_scaled_fp8 = model_options.get("llama_scaled_fp8", None)
if llama_scaled_fp8 is not None:
model_options = model_options.copy()
model_options["scaled_fp8"] = llama_scaled_fp8
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 128000, "pad": 128258}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Llama2, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class HunyuanVideoTokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}):
clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
self.llama_template = """<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: 1. The main content and theme of the video.2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.4. background environment, light, style and atmosphere.5. camera angles, movements, and transitions used in the video:<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n""" # 95 tokens
self.llama = LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=1)
def tokenize_with_weights(self, text:str, return_word_ids=False):
out = {}
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
llama_text = "{}{}".format(self.llama_template, text)
out["llama"] = self.llama.tokenize_with_weights(llama_text, return_word_ids)
return out
def untokenize(self, token_weight_pair):
return self.clip_l.untokenize(token_weight_pair)
def state_dict(self):
return {}
class HunyuanVideoClipModel(torch.nn.Module):
def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}):
super().__init__()
dtype_llama = comfy.model_management.pick_weight_dtype(dtype_llama, dtype, device)
clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
self.clip_l = clip_l_class(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
self.llama = LLAMAModel(device=device, dtype=dtype_llama, model_options=model_options)
self.dtypes = set([dtype, dtype_llama])
def set_clip_options(self, options):
self.clip_l.set_clip_options(options)
self.llama.set_clip_options(options)
def reset_clip_options(self):
self.clip_l.reset_clip_options()
self.llama.reset_clip_options()
def encode_token_weights(self, token_weight_pairs):
token_weight_pairs_l = token_weight_pairs["l"]
token_weight_pairs_llama = token_weight_pairs["llama"]
llama_out, llama_pooled, llama_extra_out = self.llama.encode_token_weights(token_weight_pairs_llama)
template_end = 0
for i, v in enumerate(token_weight_pairs_llama[0]):
if v[0] == 128007: # <|end_header_id|>
template_end = i
if llama_out.shape[1] > (template_end + 2):
if token_weight_pairs_llama[0][template_end + 1][0] == 271:
template_end += 2
llama_out = llama_out[:, template_end:]
llama_extra_out["attention_mask"] = llama_extra_out["attention_mask"][:, template_end:]
if llama_extra_out["attention_mask"].sum() == torch.numel(llama_extra_out["attention_mask"]):
llama_extra_out.pop("attention_mask") # attention mask is useless if no masked elements
l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
return llama_out, l_pooled, llama_extra_out
def load_sd(self, sd):
if "text_model.encoder.layers.1.mlp.fc1.weight" in sd:
return self.clip_l.load_sd(sd)
else:
return self.llama.load_sd(sd)
def hunyuan_video_clip(dtype_llama=None, llama_scaled_fp8=None):
class HunyuanVideoClipModel_(HunyuanVideoClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["llama_scaled_fp8"] = llama_scaled_fp8
super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
return HunyuanVideoClipModel_
+226
View File
@@ -0,0 +1,226 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Optional, Any
from comfy.ldm.modules.attention import optimized_attention_for_device
import comfy.model_management
import comfy.ldm.common_dit
import comfy.model_management
@dataclass
class Llama2Config:
vocab_size: int = 128320
hidden_size: int = 4096
intermediate_size: int = 14336
num_hidden_layers: int = 32
num_attention_heads: int = 32
num_key_value_heads: int = 8
max_position_embeddings: int = 8192
rms_norm_eps: float = 1e-5
rope_theta: float = 500000.0
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-5, device=None, dtype=None):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
def forward(self, x: torch.Tensor):
return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps)
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def precompute_freqs_cis(head_dim, seq_len, theta, device=None):
theta_numerator = torch.arange(0, head_dim, 2, device=device).float()
inv_freq = 1.0 / (theta ** (theta_numerator / head_dim))
position_ids = torch.arange(0, seq_len, device=device).unsqueeze(0)
inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
return (cos, sin)
def apply_rope(xq, xk, freqs_cis):
cos = freqs_cis[0].unsqueeze(1)
sin = freqs_cis[1].unsqueeze(1)
q_embed = (xq * cos) + (rotate_half(xq) * sin)
k_embed = (xk * cos) + (rotate_half(xk) * sin)
return q_embed, k_embed
class Attention(nn.Module):
def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
super().__init__()
self.num_heads = config.num_attention_heads
self.num_kv_heads = config.num_key_value_heads
self.hidden_size = config.hidden_size
self.head_dim = self.hidden_size // self.num_heads
ops = ops or nn
self.q_proj = ops.Linear(config.hidden_size, config.hidden_size, bias=False, device=device, dtype=dtype)
self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False, device=device, dtype=dtype)
self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False, device=device, dtype=dtype)
self.o_proj = ops.Linear(config.hidden_size, config.hidden_size, bias=False, device=device, dtype=dtype)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
freqs_cis: Optional[torch.Tensor] = None,
optimized_attention=None,
):
batch_size, seq_length, _ = hidden_states.shape
xq = self.q_proj(hidden_states)
xk = self.k_proj(hidden_states)
xv = self.v_proj(hidden_states)
xq = xq.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
xk = xk.view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2)
xv = xv.view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2)
xq, xk = apply_rope(xq, xk, freqs_cis=freqs_cis)
xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
output = optimized_attention(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True)
return self.o_proj(output)
class MLP(nn.Module):
def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
super().__init__()
ops = ops or nn
self.gate_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype)
self.up_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype)
self.down_proj = ops.Linear(config.intermediate_size, config.hidden_size, bias=False, device=device, dtype=dtype)
def forward(self, x):
return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
class TransformerBlock(nn.Module):
def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
super().__init__()
self.self_attn = Attention(config, device=device, dtype=dtype, ops=ops)
self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
def forward(
self,
x: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
freqs_cis: Optional[torch.Tensor] = None,
optimized_attention=None,
):
# Self Attention
residual = x
x = self.input_layernorm(x)
x = self.self_attn(
hidden_states=x,
attention_mask=attention_mask,
freqs_cis=freqs_cis,
optimized_attention=optimized_attention,
)
x = residual + x
# MLP
residual = x
x = self.post_attention_layernorm(x)
x = self.mlp(x)
x = residual + x
return x
class Llama2_(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
self.config = config
self.vocab_size = config.vocab_size
self.embed_tokens = ops.Embedding(
config.vocab_size,
config.hidden_size,
device=device,
dtype=dtype
)
self.layers = nn.ModuleList([
TransformerBlock(config, device=device, dtype=dtype, ops=ops)
for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
# self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
def forward(self, x, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
x = self.embed_tokens(x, out_dtype=dtype)
freqs_cis = precompute_freqs_cis(self.config.hidden_size // self.config.num_attention_heads,
x.shape[1],
self.config.rope_theta,
device=x.device)
mask = None
if attention_mask is not None:
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
if mask is not None:
mask += causal_mask
else:
mask = causal_mask
optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)
intermediate = None
if intermediate_output is not None:
if intermediate_output < 0:
intermediate_output = len(self.layers) + intermediate_output
for i, layer in enumerate(self.layers):
x = layer(
x=x,
attention_mask=mask,
freqs_cis=freqs_cis,
optimized_attention=optimized_attention,
)
if i == intermediate_output:
intermediate = x.clone()
x = self.norm(x)
if intermediate is not None and final_layer_norm_intermediate:
intermediate = self.norm(intermediate)
return x, intermediate
class Llama2(torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Llama2Config(**config_dict)
self.num_layers = config.num_hidden_layers
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, embeddings):
self.model.embed_tokens = embeddings
def forward(self, input_ids, *args, **kwargs):
return self.model(input_ids, *args, **kwargs)
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
+42
View File
@@ -0,0 +1,42 @@
import os
from comfy import sd1_clip
import comfy.text_encoders.t5
import comfy.text_encoders.sd3_clip
from comfy.sd1_clip import gen_empty_tokens
from transformers import T5TokenizerFast
class T5XXLModel(comfy.text_encoders.sd3_clip.T5XXLModel):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def gen_empty_tokens(self, special_tokens, *args, **kwargs):
# PixArt expects the negative to be all pad tokens
special_tokens = special_tokens.copy()
special_tokens.pop("end")
return gen_empty_tokens(special_tokens, *args, **kwargs)
class PixArtT5XXL(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, name="t5xxl", clip_model=T5XXLModel, model_options=model_options)
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1) # no padding
class PixArtTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
def pixart_te(dtype_t5=None, t5xxl_scaled_fp8=None):
class PixArtTEModel_(PixArtT5XXL):
def __init__(self, device="cpu", dtype=None, model_options={}):
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
if dtype is None:
dtype = dtype_t5
super().__init__(device=device, dtype=dtype, model_options=model_options)
return PixArtTEModel_
-1
View File
@@ -1,4 +1,3 @@
import os
import torch
class SPieceTokenizer:
+5 -1
View File
@@ -172,7 +172,6 @@ class T5LayerSelfAttention(torch.nn.Module):
# self.dropout = nn.Dropout(config.dropout_rate)
def forward(self, x, mask=None, past_bias=None, optimized_attention=None):
normed_hidden_states = self.layer_norm(x)
output, past_bias = self.SelfAttention(self.layer_norm(x), mask=mask, past_bias=past_bias, optimized_attention=optimized_attention)
# x = x + self.dropout(attention_output)
x += output
@@ -209,6 +208,11 @@ class T5Stack(torch.nn.Module):
intermediate = None
optimized_attention = optimized_attention_for_device(x.device, mask=attention_mask is not None, small_input=True)
past_bias = None
if intermediate_output is not None:
if intermediate_output < 0:
intermediate_output = len(self.block) + intermediate_output
for i, l in enumerate(self.block):
x, past_bias = l(x, mask, past_bias, optimized_attention)
if i == intermediate_output:
+164 -8
View File
@@ -26,6 +26,8 @@ import numpy as np
from PIL import Image
import logging
import itertools
from torch.nn.functional import interpolate
from einops import rearrange
def load_torch_file(ckpt, safe_load=False, device=None):
if device is None:
@@ -46,7 +48,13 @@ def load_torch_file(ckpt, safe_load=False, device=None):
if "state_dict" in pl_sd:
sd = pl_sd["state_dict"]
else:
sd = pl_sd
if len(pl_sd) == 1:
key = list(pl_sd.keys())[0]
sd = pl_sd[key]
if not isinstance(sd, dict):
sd = pl_sd
else:
sd = pl_sd
return sd
def save_torch_file(sd, ckpt, metadata=None):
@@ -378,6 +386,77 @@ def mmdit_to_diffusers(mmdit_config, output_prefix=""):
return key_map
PIXART_MAP_BASIC = {
("csize_embedder.mlp.0.weight", "adaln_single.emb.resolution_embedder.linear_1.weight"),
("csize_embedder.mlp.0.bias", "adaln_single.emb.resolution_embedder.linear_1.bias"),
("csize_embedder.mlp.2.weight", "adaln_single.emb.resolution_embedder.linear_2.weight"),
("csize_embedder.mlp.2.bias", "adaln_single.emb.resolution_embedder.linear_2.bias"),
("ar_embedder.mlp.0.weight", "adaln_single.emb.aspect_ratio_embedder.linear_1.weight"),
("ar_embedder.mlp.0.bias", "adaln_single.emb.aspect_ratio_embedder.linear_1.bias"),
("ar_embedder.mlp.2.weight", "adaln_single.emb.aspect_ratio_embedder.linear_2.weight"),
("ar_embedder.mlp.2.bias", "adaln_single.emb.aspect_ratio_embedder.linear_2.bias"),
("x_embedder.proj.weight", "pos_embed.proj.weight"),
("x_embedder.proj.bias", "pos_embed.proj.bias"),
("y_embedder.y_embedding", "caption_projection.y_embedding"),
("y_embedder.y_proj.fc1.weight", "caption_projection.linear_1.weight"),
("y_embedder.y_proj.fc1.bias", "caption_projection.linear_1.bias"),
("y_embedder.y_proj.fc2.weight", "caption_projection.linear_2.weight"),
("y_embedder.y_proj.fc2.bias", "caption_projection.linear_2.bias"),
("t_embedder.mlp.0.weight", "adaln_single.emb.timestep_embedder.linear_1.weight"),
("t_embedder.mlp.0.bias", "adaln_single.emb.timestep_embedder.linear_1.bias"),
("t_embedder.mlp.2.weight", "adaln_single.emb.timestep_embedder.linear_2.weight"),
("t_embedder.mlp.2.bias", "adaln_single.emb.timestep_embedder.linear_2.bias"),
("t_block.1.weight", "adaln_single.linear.weight"),
("t_block.1.bias", "adaln_single.linear.bias"),
("final_layer.linear.weight", "proj_out.weight"),
("final_layer.linear.bias", "proj_out.bias"),
("final_layer.scale_shift_table", "scale_shift_table"),
}
PIXART_MAP_BLOCK = {
("scale_shift_table", "scale_shift_table"),
("attn.proj.weight", "attn1.to_out.0.weight"),
("attn.proj.bias", "attn1.to_out.0.bias"),
("mlp.fc1.weight", "ff.net.0.proj.weight"),
("mlp.fc1.bias", "ff.net.0.proj.bias"),
("mlp.fc2.weight", "ff.net.2.weight"),
("mlp.fc2.bias", "ff.net.2.bias"),
("cross_attn.proj.weight" ,"attn2.to_out.0.weight"),
("cross_attn.proj.bias" ,"attn2.to_out.0.bias"),
}
def pixart_to_diffusers(mmdit_config, output_prefix=""):
key_map = {}
depth = mmdit_config.get("depth", 0)
offset = mmdit_config.get("hidden_size", 1152)
for i in range(depth):
block_from = "transformer_blocks.{}".format(i)
block_to = "{}blocks.{}".format(output_prefix, i)
for end in ("weight", "bias"):
s = "{}.attn1.".format(block_from)
qkv = "{}.attn.qkv.{}".format(block_to, end)
key_map["{}to_q.{}".format(s, end)] = (qkv, (0, 0, offset))
key_map["{}to_k.{}".format(s, end)] = (qkv, (0, offset, offset))
key_map["{}to_v.{}".format(s, end)] = (qkv, (0, offset * 2, offset))
s = "{}.attn2.".format(block_from)
q = "{}.cross_attn.q_linear.{}".format(block_to, end)
kv = "{}.cross_attn.kv_linear.{}".format(block_to, end)
key_map["{}to_q.{}".format(s, end)] = q
key_map["{}to_k.{}".format(s, end)] = (kv, (0, 0, offset))
key_map["{}to_v.{}".format(s, end)] = (kv, (0, offset, offset))
for k in PIXART_MAP_BLOCK:
key_map["{}.{}".format(block_from, k[1])] = "{}.{}".format(block_to, k[0])
for k in PIXART_MAP_BASIC:
key_map[k[1]] = "{}{}".format(output_prefix, k[0])
return key_map
def auraflow_to_diffusers(mmdit_config, output_prefix=""):
n_double_layers = mmdit_config.get("n_double_layers", 0)
@@ -743,7 +822,7 @@ def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
return rows * cols
@torch.inference_mode()
def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", downscale=False, index_formulas=None, pbar=None):
dims = len(tile)
if not (isinstance(upscale_amount, (tuple, list))):
@@ -752,6 +831,12 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_
if not (isinstance(overlap, (tuple, list))):
overlap = [overlap] * dims
if index_formulas is None:
index_formulas = upscale_amount
if not (isinstance(index_formulas, (tuple, list))):
index_formulas = [index_formulas] * dims
def get_upscale(dim, val):
up = upscale_amount[dim]
if callable(up):
@@ -759,10 +844,38 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_
else:
return up * val
def get_downscale(dim, val):
up = upscale_amount[dim]
if callable(up):
return up(val)
else:
return val / up
def get_upscale_pos(dim, val):
up = index_formulas[dim]
if callable(up):
return up(val)
else:
return up * val
def get_downscale_pos(dim, val):
up = index_formulas[dim]
if callable(up):
return up(val)
else:
return val / up
if downscale:
get_scale = get_downscale
get_pos = get_downscale_pos
else:
get_scale = get_upscale
get_pos = get_upscale_pos
def mult_list_upscale(a):
out = []
for i in range(len(a)):
out.append(round(get_upscale(i, a[i])))
out.append(round(get_scale(i, a[i])))
return out
output = torch.empty([samples.shape[0], out_channels] + mult_list_upscale(samples.shape[2:]), device=output_device)
@@ -787,16 +900,18 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_
upscaled = []
for d in range(dims):
pos = max(0, min(s.shape[d + 2] - (overlap[d] + 1), it[d]))
pos = max(0, min(s.shape[d + 2] - overlap[d], it[d]))
l = min(tile[d], s.shape[d + 2] - pos)
s_in = s_in.narrow(d + 2, pos, l)
upscaled.append(round(get_upscale(d, pos)))
upscaled.append(round(get_pos(d, pos)))
ps = function(s_in).to(output_device)
mask = torch.ones_like(ps)
for d in range(2, dims + 2):
feather = round(get_upscale(d - 2, overlap[d - 2]))
feather = round(get_scale(d - 2, overlap[d - 2]))
if feather >= mask.shape[d]:
continue
for t in range(feather):
a = (t + 1) / feather
mask.narrow(d, t, 1).mul_(a)
@@ -818,7 +933,7 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_
return output
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap, upscale_amount, out_channels, output_device, pbar)
return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar)
PROGRESS_BAR_ENABLED = True
def set_progress_bar_enabled(enabled):
@@ -867,5 +982,46 @@ def reshape_mask(input_mask, output_shape):
mask = torch.nn.functional.interpolate(input_mask, size=output_shape[2:], mode=scale_mode)
if mask.shape[1] < output_shape[1]:
mask = mask.repeat((1, output_shape[1]) + (1,) * dims)[:,:output_shape[1]]
mask = comfy.utils.repeat_to_batch_size(mask, output_shape[0])
mask = repeat_to_batch_size(mask, output_shape[0])
return mask
def upscale_dit_mask(mask: torch.Tensor, img_size_in, img_size_out):
hi, wi = img_size_in
ho, wo = img_size_out
# if it's already the correct size, no need to do anything
if (hi, wi) == (ho, wo):
return mask
if mask.ndim == 2:
mask = mask.unsqueeze(0)
if mask.ndim != 3:
raise ValueError(f"Got a mask of shape {list(mask.shape)}, expected [b, q, k] or [q, k]")
txt_tokens = mask.shape[1] - (hi * wi)
# quadrants of the mask
txt_to_txt = mask[:, :txt_tokens, :txt_tokens]
txt_to_img = mask[:, :txt_tokens, txt_tokens:]
img_to_img = mask[:, txt_tokens:, txt_tokens:]
img_to_txt = mask[:, txt_tokens:, :txt_tokens]
# convert to 1d x 2d, interpolate, then back to 1d x 1d
txt_to_img = rearrange (txt_to_img, "b t (h w) -> b t h w", h=hi, w=wi)
txt_to_img = interpolate(txt_to_img, size=img_size_out, mode="bilinear")
txt_to_img = rearrange (txt_to_img, "b t h w -> b t (h w)")
# this one is hard because we have to do it twice
# convert to 1d x 2d, interpolate, then to 2d x 1d, interpolate, then 1d x 1d
img_to_img = rearrange (img_to_img, "b hw (h w) -> b hw h w", h=hi, w=wi)
img_to_img = interpolate(img_to_img, size=img_size_out, mode="bilinear")
img_to_img = rearrange (img_to_img, "b (hk wk) hq wq -> b (hq wq) hk wk", hk=hi, wk=wi)
img_to_img = interpolate(img_to_img, size=img_size_out, mode="bilinear")
img_to_img = rearrange (img_to_img, "b (hq wq) hk wk -> b (hk wk) (hq wq)", hq=ho, wq=wo)
# convert to 2d x 1d, interpolate, then back to 1d x 1d
img_to_txt = rearrange (img_to_txt, "b (h w) t -> b t h w", h=hi, w=wi)
img_to_txt = interpolate(img_to_txt, size=img_size_out, mode="bilinear")
img_to_txt = rearrange (img_to_txt, "b t h w -> b (h w) t")
# reassemble the mask from blocks
out = torch.cat([
torch.cat([txt_to_txt, txt_to_img], dim=2),
torch.cat([img_to_txt, img_to_img], dim=2)],
dim=1
)
return out
@@ -1,5 +1,6 @@
import logging
from spandrel import ModelLoader
def load_state_dict(state_dict):
print("WARNING: comfy_extras.chainner_models is deprecated and has been replaced by the spandrel library.")
logging.warning("comfy_extras.chainner_models is deprecated and has been replaced by the spandrel library.")
return ModelLoader().load_from_state_dict(state_dict).eval()
+1 -2
View File
@@ -2,8 +2,7 @@ import comfy.samplers
import comfy.utils
import torch
import numpy as np
from tqdm.auto import trange, tqdm
import math
from tqdm.auto import trange
@torch.no_grad()
+1 -1
View File
@@ -24,7 +24,7 @@ class AlignYourStepsScheduler:
def INPUT_TYPES(s):
return {"required":
{"model_type": (["SD1", "SDXL", "SVD"], ),
"steps": ("INT", {"default": 10, "min": 10, "max": 10000}),
"steps": ("INT", {"default": 10, "min": 1, "max": 10000}),
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}
}
+23
View File
@@ -8,6 +8,7 @@ import json
import struct
import random
import hashlib
import node_helpers
from comfy.cli_args import args
class EmptyLatentAudio:
@@ -29,6 +30,27 @@ class EmptyLatentAudio:
latent = torch.zeros([batch_size, 64, length], device=self.device)
return ({"samples":latent, "type": "audio"}, )
class ConditioningStableAudio:
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"seconds_start": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.1}),
"seconds_total": ("FLOAT", {"default": 47.0, "min": 0.0, "max": 1000.0, "step": 0.1}),
}}
RETURN_TYPES = ("CONDITIONING","CONDITIONING")
RETURN_NAMES = ("positive", "negative")
FUNCTION = "append"
CATEGORY = "conditioning"
def append(self, positive, negative, seconds_start, seconds_total):
positive = node_helpers.conditioning_set_values(positive, {"seconds_start": seconds_start, "seconds_total": seconds_total})
negative = node_helpers.conditioning_set_values(negative, {"seconds_start": seconds_start, "seconds_total": seconds_total})
return (positive, negative)
class VAEEncodeAudio:
@classmethod
def INPUT_TYPES(s):
@@ -225,4 +247,5 @@ NODE_CLASS_MAPPINGS = {
"SaveAudio": SaveAudio,
"LoadAudio": LoadAudio,
"PreviewAudio": PreviewAudio,
"ConditioningStableAudio": ConditioningStableAudio,
}
+3 -3
View File
@@ -1,4 +1,3 @@
import torch
from nodes import MAX_RESOLUTION
class CLIPTextEncodeSDXLRefiner:
@@ -23,14 +22,15 @@ class CLIPTextEncodeSDXL:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"clip": ("CLIP", ),
"width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
"height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
"crop_w": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION}),
"crop_h": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION}),
"target_width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
"target_height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
"text_g": ("STRING", {"multiline": True, "dynamicPrompts": True}), "clip": ("CLIP", ),
"text_l": ("STRING", {"multiline": True, "dynamicPrompts": True}), "clip": ("CLIP", ),
"text_g": ("STRING", {"multiline": True, "dynamicPrompts": True}),
"text_l": ("STRING", {"multiline": True, "dynamicPrompts": True}),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "encode"
-1
View File
@@ -1,4 +1,3 @@
import numpy as np
import torch
import comfy.utils
from enum import Enum
+3 -3
View File
@@ -1,10 +1,10 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Union
import logging
import torch
from collections.abc import Iterable
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
from comfy.sd import CLIP
import comfy.hooks
@@ -540,7 +540,7 @@ class CreateHookKeyframesInterpolated:
is_first = False
prev_hook_kf.add(comfy.hooks.HookKeyframe(strength=strength, start_percent=percent, guarantee_steps=guarantee_steps))
if print_keyframes:
print(f"Hook Keyframe - start_percent:{percent} = {strength}")
logging.info(f"Hook Keyframe - start_percent:{percent} = {strength}")
return (prev_hook_kf,)
class CreateHookKeyframesFromFloats:
@@ -589,7 +589,7 @@ class CreateHookKeyframesFromFloats:
is_first = False
prev_hook_kf.add(comfy.hooks.HookKeyframe(strength=strength, start_percent=percent, guarantee_steps=guarantee_steps))
if print_keyframes:
print(f"Hook Keyframe - start_percent:{percent} = {strength}")
logging.info(f"Hook Keyframe - start_percent:{percent} = {strength}")
return (prev_hook_kf,)
#------------------------------------------
###########################################
+21
View File
@@ -1,3 +1,8 @@
import nodes
import torch
import comfy.model_management
class CLIPTextEncodeHunyuanDiT:
@classmethod
def INPUT_TYPES(s):
@@ -17,7 +22,23 @@ class CLIPTextEncodeHunyuanDiT:
return (clip.encode_from_tokens_scheduled(tokens), )
class EmptyHunyuanLatentVideo:
@classmethod
def INPUT_TYPES(s):
return {"required": { "width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 25, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "generate"
CATEGORY = "latent/video"
def generate(self, width, height, length, batch_size=1):
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
return ({"samples":latent}, )
NODE_CLASS_MAPPINGS = {
"CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT,
"EmptyHunyuanLatentVideo": EmptyHunyuanLatentVideo,
}
-2
View File
@@ -35,8 +35,6 @@ class HyperTile:
CATEGORY = "model_patches/unet"
def patch(self, model, tile_size, swap_size, max_depth, scale_depth):
model_channels = model.model.model_config.unet_config["model_channels"]
latent_tile_size = max(32, tile_size) // 8
self.temp = None
+124
View File
@@ -0,0 +1,124 @@
import nodes
import folder_paths
import os
def normalize_path(path):
return path.replace('\\', '/')
class Load3D():
@classmethod
def INPUT_TYPES(s):
input_dir = os.path.join(folder_paths.get_input_directory(), "3d")
os.makedirs(input_dir, exist_ok=True)
files = [normalize_path(os.path.join("3d", f)) for f in os.listdir(input_dir) if f.endswith(('.gltf', '.glb', '.obj', '.mtl', '.fbx', '.stl'))]
return {"required": {
"model_file": (sorted(files), {"file_upload": True}),
"image": ("LOAD_3D", {}),
"width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
"show_grid": ([True, False],),
"camera_type": (["perspective", "orthographic"],),
"view": (["front", "right", "top", "isometric"],),
"material": (["original", "normal", "wireframe", "depth"],),
"bg_color": ("STRING", {"default": "#000000", "multiline": False}),
"light_intensity": ("INT", {"default": 10, "min": 1, "max": 20, "step": 1}),
"up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
}}
RETURN_TYPES = ("IMAGE", "MASK", "STRING")
RETURN_NAMES = ("image", "mask", "mesh_path")
FUNCTION = "process"
EXPERIMENTAL = True
CATEGORY = "3d"
def process(self, model_file, image, **kwargs):
imagepath = folder_paths.get_annotated_filepath(image)
load_image_node = nodes.LoadImage()
output_image, output_mask = load_image_node.load_image(image=imagepath)
return output_image, output_mask, model_file,
class Load3DAnimation():
@classmethod
def INPUT_TYPES(s):
input_dir = os.path.join(folder_paths.get_input_directory(), "3d")
os.makedirs(input_dir, exist_ok=True)
files = [normalize_path(os.path.join("3d", f)) for f in os.listdir(input_dir) if f.endswith(('.gltf', '.glb', '.fbx'))]
return {"required": {
"model_file": (sorted(files), {"file_upload": True}),
"image": ("LOAD_3D_ANIMATION", {}),
"width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
"show_grid": ([True, False],),
"camera_type": (["perspective", "orthographic"],),
"view": (["front", "right", "top", "isometric"],),
"material": (["original", "normal", "wireframe", "depth"],),
"bg_color": ("STRING", {"default": "#000000", "multiline": False}),
"light_intensity": ("INT", {"default": 10, "min": 1, "max": 20, "step": 1}),
"up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
"animation_speed": (["0.1", "0.5", "1", "1.5", "2"], {"default": "1"}),
}}
RETURN_TYPES = ("IMAGE", "MASK", "STRING")
RETURN_NAMES = ("image", "mask", "mesh_path")
FUNCTION = "process"
EXPERIMENTAL = True
CATEGORY = "3d"
def process(self, model_file, image, **kwargs):
imagepath = folder_paths.get_annotated_filepath(image)
load_image_node = nodes.LoadImage()
output_image, output_mask = load_image_node.load_image(image=imagepath)
return output_image, output_mask, model_file,
class Preview3D():
@classmethod
def INPUT_TYPES(s):
return {"required": {
"model_file": ("STRING", {"default": "", "multiline": False}),
"show_grid": ([True, False],),
"camera_type": (["perspective", "orthographic"],),
"view": (["front", "right", "top", "isometric"],),
"material": (["original", "normal", "wireframe", "depth"],),
"bg_color": ("STRING", {"default": "#000000", "multiline": False}),
"light_intensity": ("INT", {"default": 10, "min": 1, "max": 20, "step": 1}),
"up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
}}
OUTPUT_NODE = True
RETURN_TYPES = ()
CATEGORY = "3d"
FUNCTION = "process"
EXPERIMENTAL = True
def process(self, model_file, **kwargs):
return {"ui": {"model_file": [model_file]}, "result": ()}
NODE_CLASS_MAPPINGS = {
"Load3D": Load3D,
"Load3DAnimation": Load3DAnimation,
"Preview3D": Preview3D
}
NODE_DISPLAY_NAME_MAPPINGS = {
"Load3D": "Load 3D",
"Load3DAnimation": "Load 3D - Animation",
"Preview3D": "Preview 3D"
}
+7 -4
View File
@@ -32,7 +32,9 @@ class LTXVImgToVideo:
"width": ("INT", {"default": 768, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
"height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
"length": ("INT", {"default": 97, "min": 9, "max": nodes.MAX_RESOLUTION, "step": 8}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
"image_noise_scale": ("FLOAT", {"default": 0.15, "min": 0, "max": 1.0, "step": 0.01, "tooltip": "Amount of noise to apply on conditioning image latent."})
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
@@ -40,12 +42,12 @@ class LTXVImgToVideo:
CATEGORY = "conditioning/video_models"
FUNCTION = "generate"
def generate(self, positive, negative, image, vae, width, height, length, batch_size):
def generate(self, positive, negative, image, vae, width, height, length, batch_size, image_noise_scale):
pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
encode_pixels = pixels[:, :, :, :3]
t = vae.encode(encode_pixels)
positive = node_helpers.conditioning_set_values(positive, {"guiding_latent": t})
negative = node_helpers.conditioning_set_values(negative, {"guiding_latent": t})
positive = node_helpers.conditioning_set_values(positive, {"guiding_latent": t, "guiding_latent_noise_scale": image_noise_scale})
negative = node_helpers.conditioning_set_values(negative, {"guiding_latent": t, "guiding_latent_noise_scale": image_noise_scale})
latent = torch.zeros([batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32], device=comfy.model_management.intermediate_device())
latent[:, :, :t.shape[2]] = t
@@ -109,6 +111,7 @@ class ModelSamplingLTXV:
model_sampling = ModelSamplingAdvanced(model.model.model_config)
model_sampling.set_parameters(shift=shift)
m.add_object_patch("model_sampling", model_sampling)
return (m, )
+41
View File
@@ -0,0 +1,41 @@
import torch
import torch.nn.functional as F
class Mahiro:
@classmethod
def INPUT_TYPES(s):
return {"required": {"model": ("MODEL",),
}}
RETURN_TYPES = ("MODEL",)
RETURN_NAMES = ("patched_model",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
DESCRIPTION = "Modify the guidance to scale more on the 'direction' of the positive prompt rather than the difference between the negative prompt."
def patch(self, model):
m = model.clone()
def mahiro_normd(args):
scale: float = args['cond_scale']
cond_p: torch.Tensor = args['cond_denoised']
uncond_p: torch.Tensor = args['uncond_denoised']
#naive leap
leap = cond_p * scale
#sim with uncond leap
u_leap = uncond_p * scale
cfg = args["denoised"]
merge = (leap + cfg) / 2
normu = torch.sqrt(u_leap.abs()) * u_leap.sign()
normm = torch.sqrt(merge.abs()) * merge.sign()
sim = F.cosine_similarity(normu, normm).mean()
simsc = 2 * (sim+1)
wm = (simsc*cfg + (4-simsc)*leap) / 4
return wm
m.set_model_sampler_post_cfg_function(mahiro_normd)
return (m, )
NODE_CLASS_MAPPINGS = {
"Mahiro": Mahiro
}
NODE_DISPLAY_NAME_MAPPINGS = {
"Mahiro": "Mahiro is so cute that she deserves a better guidance function!! (。・ω・。)",
}
-2
View File
@@ -1,4 +1,3 @@
import folder_paths
import comfy.sd
import comfy.model_sampling
import comfy.latent_formats
@@ -241,7 +240,6 @@ class ModelSamplingContinuousV:
def patch(self, model, sampling, sigma_max, sigma_min):
m = model.clone()
latent_format = None
sigma_data = 1.0
if sampling == "v_prediction":
sampling_type = comfy.model_sampling.V_PREDICTION
-1
View File
@@ -1,4 +1,3 @@
import torch
import comfy.utils
class PatchModelAddDownscale:
+1
View File
@@ -16,6 +16,7 @@ VISION_CONFIG_DICT = {
"patch_size": 14,
"projection_dim": 768,
"hidden_act": "quick_gelu",
"model_type": "clip_vision_model",
}
class MLP(nn.Module):
+24
View File
@@ -0,0 +1,24 @@
from nodes import MAX_RESOLUTION
class CLIPTextEncodePixArtAlpha:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
"height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
# "aspect_ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"text": ("STRING", {"multiline": True, "dynamicPrompts": True}), "clip": ("CLIP", ),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "encode"
CATEGORY = "advanced/conditioning"
DESCRIPTION = "Encodes text and sets the resolution conditioning for PixArt Alpha. Does not apply to PixArt Sigma."
def encode(self, clip, width, height, text):
tokens = clip.tokenize(text)
return (clip.encode_from_tokens_scheduled(tokens, add_dict={"width": width, "height": height}),)
NODE_CLASS_MAPPINGS = {
"CLIPTextEncodePixArtAlpha": CLIPTextEncodePixArtAlpha,
}
-1
View File
@@ -1,4 +1,3 @@
import os
import logging
from spandrel import ModelLoader, ImageModelDescriptor
from comfy import model_management
+1 -3
View File
@@ -1,7 +1,5 @@
from PIL import Image, ImageOps
from io import BytesIO
from PIL import Image
import numpy as np
import struct
import comfy.utils
import time
+8 -4
View File
@@ -17,7 +17,6 @@ from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt,
from comfy_execution.graph_utils import is_link, GraphBuilder
from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID
from comfy_execution.validation import validate_node_input
from comfy.cli_args import args
class ExecutionResult(Enum):
SUCCESS = 0
@@ -145,11 +144,16 @@ def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execut
return {k: v[i if len(v) > i else -1] for k, v in d.items()}
results = []
def process_inputs(inputs, index=None):
def process_inputs(inputs, index=None, input_is_list=False):
if allow_interrupt:
nodes.before_node_execution()
execution_block = None
for k, v in inputs.items():
if input_is_list:
for e in v:
if isinstance(e, ExecutionBlocker):
v = e
break
if isinstance(v, ExecutionBlocker):
execution_block = execution_block_cb(v) if execution_block_cb else v
break
@@ -161,7 +165,7 @@ def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execut
results.append(execution_block)
if input_is_list:
process_inputs(input_data_all, 0)
process_inputs(input_data_all, 0, input_is_list=input_is_list)
elif max_len_input == 0:
process_inputs({})
else:
@@ -761,7 +765,7 @@ def validate_prompt(prompt):
if 'class_type' not in prompt[x]:
error = {
"type": "invalid_prompt",
"message": f"Cannot execute because a node is missing the class_type property.",
"message": "Cannot execute because a node is missing the class_type property.",
"details": f"Node ID '#{x}'",
"extra_info": {}
}
+20 -16
View File
@@ -5,20 +5,24 @@ import ctypes
import logging
torch_spec = importlib.util.find_spec("torch")
for folder in torch_spec.submodule_search_locations:
lib_folder = os.path.join(folder, "lib")
test_file = os.path.join(lib_folder, "fbgemm.dll")
dest = os.path.join(lib_folder, "libomp140.x86_64.dll")
if os.path.exists(dest):
break
with open(test_file, 'rb') as f:
contents = f.read()
if b"libomp140.x86_64.dll" not in contents:
def fix_pytorch_libomp():
"""
Fix PyTorch libomp DLL issue on Windows by copying the correct DLL file if needed.
"""
torch_spec = importlib.util.find_spec("torch")
for folder in torch_spec.submodule_search_locations:
lib_folder = os.path.join(folder, "lib")
test_file = os.path.join(lib_folder, "fbgemm.dll")
dest = os.path.join(lib_folder, "libomp140.x86_64.dll")
if os.path.exists(dest):
break
try:
mydll = ctypes.cdll.LoadLibrary(test_file)
except FileNotFoundError as e:
logging.warning("Detected pytorch version with libomp issue, patching.")
shutil.copyfile(os.path.join(lib_folder, "libiomp5md.dll"), dest)
with open(test_file, "rb") as f:
contents = f.read()
if b"libomp140.x86_64.dll" not in contents:
break
try:
ctypes.cdll.LoadLibrary(test_file)
except FileNotFoundError:
logging.warning("Detected pytorch version with libomp issue, patching.")
shutil.copyfile(os.path.join(lib_folder, "libiomp5md.dll"), dest)
+12 -5
View File
@@ -4,7 +4,7 @@ import os
import time
import mimetypes
import logging
from typing import Set, List, Dict, Tuple, Literal
from typing import Literal
from collections.abc import Collection
supported_pt_extensions: set[str] = {'.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft'}
@@ -133,7 +133,7 @@ def get_directory_by_type(type_name: str) -> str | None:
return get_input_directory()
return None
def filter_files_content_types(files: List[str], content_types: Literal["image", "video", "audio"]) -> List[str]:
def filter_files_content_types(files: list[str], content_types: Literal["image", "video", "audio"]) -> list[str]:
"""
Example:
files = os.listdir(folder_paths.get_input_directory())
@@ -200,10 +200,17 @@ def add_model_folder_path(folder_name: str, full_folder_path: str, is_default: b
global folder_names_and_paths
folder_name = map_legacy(folder_name)
if folder_name in folder_names_and_paths:
if is_default:
folder_names_and_paths[folder_name][0].insert(0, full_folder_path)
paths, _exts = folder_names_and_paths[folder_name]
if full_folder_path in paths:
if is_default and paths[0] != full_folder_path:
# If the path to the folder is not the first in the list, move it to the beginning.
paths.remove(full_folder_path)
paths.insert(0, full_folder_path)
else:
folder_names_and_paths[folder_name][0].append(full_folder_path)
if is_default:
paths.insert(0, full_folder_path)
else:
paths.append(full_folder_path)
else:
folder_names_and_paths[folder_name] = ([full_folder_path], set())
-2
View File
@@ -1,7 +1,5 @@
import torch
from PIL import Image
import struct
import numpy as np
from comfy.cli_args import args, LatentPreviewMethod
from comfy.taesd.taesd import TAESD
import comfy.model_management
+93 -59
View File
@@ -7,10 +7,52 @@ import folder_paths
import time
from comfy.cli_args import args
from app.logger import setup_logger
import itertools
import utils.extra_config
import logging
if __name__ == "__main__":
#NOTE: These do not do anything on core ComfyUI which should already have no communication with the internet, they are for custom nodes.
os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
os.environ['DO_NOT_TRACK'] = '1'
setup_logger(log_level=args.verbose)
def apply_custom_paths():
# extra model paths
extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
if os.path.isfile(extra_model_paths_config_path):
utils.extra_config.load_extra_path_config(extra_model_paths_config_path)
if args.extra_model_paths_config:
for config_path in itertools.chain(*args.extra_model_paths_config):
utils.extra_config.load_extra_path_config(config_path)
# --output-directory, --input-directory, --user-directory
if args.output_directory:
output_dir = os.path.abspath(args.output_directory)
logging.info(f"Setting output directory to: {output_dir}")
folder_paths.set_output_directory(output_dir)
# These are the default folders that checkpoints, clip and vae models will be saved to when using CheckpointSave, etc.. nodes
folder_paths.add_model_folder_path("checkpoints", os.path.join(folder_paths.get_output_directory(), "checkpoints"))
folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip"))
folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae"))
folder_paths.add_model_folder_path("diffusion_models",
os.path.join(folder_paths.get_output_directory(), "diffusion_models"))
folder_paths.add_model_folder_path("loras", os.path.join(folder_paths.get_output_directory(), "loras"))
if args.input_directory:
input_dir = os.path.abspath(args.input_directory)
logging.info(f"Setting input directory to: {input_dir}")
folder_paths.set_input_directory(input_dir)
if args.user_directory:
user_dir = os.path.abspath(args.user_directory)
logging.info(f"Setting user directory to: {user_dir}")
folder_paths.set_user_directory(user_dir)
def execute_prestartup_script():
def execute_script(script_path):
@@ -21,7 +63,7 @@ def execute_prestartup_script():
spec.loader.exec_module(module)
return True
except Exception as e:
print(f"Failed to execute startup-script: {script_path} / {e}")
logging.error(f"Failed to execute startup-script: {script_path} / {e}")
return False
if args.disable_all_custom_nodes:
@@ -43,27 +85,25 @@ def execute_prestartup_script():
success = execute_script(script_path)
node_prestartup_times.append((time.perf_counter() - time_before, module_path, success))
if len(node_prestartup_times) > 0:
print("\nPrestartup times for custom nodes:")
logging.info("\nPrestartup times for custom nodes:")
for n in sorted(node_prestartup_times):
if n[2]:
import_message = ""
else:
import_message = " (PRESTARTUP FAILED)"
print("{:6.1f} seconds{}:".format(n[0], import_message), n[1])
print()
logging.info("{:6.1f} seconds{}: {}".format(n[0], import_message, n[1]))
logging.info("")
apply_custom_paths()
execute_prestartup_script()
# Main code
import asyncio
import itertools
import shutil
import threading
import gc
import logging
import utils.extra_config
if os.name == "nt":
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
@@ -74,6 +114,10 @@ if __name__ == "__main__":
os.environ['HIP_VISIBLE_DEVICES'] = str(args.cuda_device)
logging.info("Set cuda device to: {}".format(args.cuda_device))
if args.oneapi_device_selector is not None:
os.environ['ONEAPI_DEVICE_SELECTOR'] = args.oneapi_device_selector
logging.info("Set oneapi device selector to: {}".format(args.oneapi_device_selector))
if args.deterministic:
if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ:
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"
@@ -82,7 +126,8 @@ if __name__ == "__main__":
if args.windows_standalone_build:
try:
import fix_torch
from fix_torch import fix_pytorch_libomp
fix_pytorch_libomp()
except:
pass
@@ -105,8 +150,10 @@ def cuda_malloc_warning():
if cuda_malloc_warning:
logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
def prompt_worker(q, server):
e = execution.PromptExecutor(server, lru_size=args.cache_lru)
def prompt_worker(q, server_instance):
current_time: float = 0.0
e = execution.PromptExecutor(server_instance, lru_size=args.cache_lru)
last_gc_collect = 0
need_gc = False
gc_collect_interval = 10.0
@@ -121,7 +168,7 @@ def prompt_worker(q, server):
item, item_id = queue_item
execution_start_time = time.perf_counter()
prompt_id = item[1]
server.last_prompt_id = prompt_id
server_instance.last_prompt_id = prompt_id
e.execute(item[2], prompt_id, item[3], item[4])
need_gc = True
@@ -131,8 +178,8 @@ def prompt_worker(q, server):
status_str='success' if e.success else 'error',
completed=e.success,
messages=e.status_messages))
if server.client_id is not None:
server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, server.client_id)
if server_instance.client_id is not None:
server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id)
current_time = time.perf_counter()
execution_time = current_time - execution_start_time
@@ -159,21 +206,23 @@ def prompt_worker(q, server):
last_gc_collect = current_time
need_gc = False
async def run(server, address='', port=8188, verbose=True, call_on_start=None):
async def run(server_instance, address='', port=8188, verbose=True, call_on_start=None):
addresses = []
for addr in address.split(","):
addresses.append((addr, port))
await asyncio.gather(server.start_multi_address(addresses, call_on_start), server.publish_loop())
await asyncio.gather(server_instance.start_multi_address(addresses, call_on_start), server_instance.publish_loop())
def hijack_progress(server):
def hijack_progress(server_instance):
def hook(value, total, preview_image):
comfy.model_management.throw_exception_if_processing_interrupted()
progress = {"value": value, "max": total, "prompt_id": server.last_prompt_id, "node": server.last_node_id}
progress = {"value": value, "max": total, "prompt_id": server_instance.last_prompt_id, "node": server_instance.last_node_id}
server.send_sync("progress", progress, server.client_id)
server_instance.send_sync("progress", progress, server_instance.client_id)
if preview_image is not None:
server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server.client_id)
server_instance.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server_instance.client_id)
comfy.utils.set_progress_bar_global_hook(hook)
@@ -183,7 +232,11 @@ def cleanup_temp():
shutil.rmtree(temp_dir, ignore_errors=True)
if __name__ == "__main__":
def start_comfyui(asyncio_loop=None):
"""
Starts the ComfyUI server using the provided asyncio event loop or creates a new one.
Returns the event loop, server instance, and a function to start the server asynchronously.
"""
if args.temp_directory:
temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp")
logging.info(f"Setting temp directory to: {temp_dir}")
@@ -197,49 +250,20 @@ if __name__ == "__main__":
except:
pass
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
server = server.PromptServer(loop)
q = execution.PromptQueue(server)
extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
if os.path.isfile(extra_model_paths_config_path):
utils.extra_config.load_extra_path_config(extra_model_paths_config_path)
if args.extra_model_paths_config:
for config_path in itertools.chain(*args.extra_model_paths_config):
utils.extra_config.load_extra_path_config(config_path)
if not asyncio_loop:
asyncio_loop = asyncio.new_event_loop()
asyncio.set_event_loop(asyncio_loop)
prompt_server = server.PromptServer(asyncio_loop)
q = execution.PromptQueue(prompt_server)
nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes)
cuda_malloc_warning()
server.add_routes()
hijack_progress(server)
prompt_server.add_routes()
hijack_progress(prompt_server)
threading.Thread(target=prompt_worker, daemon=True, args=(q, server,)).start()
if args.output_directory:
output_dir = os.path.abspath(args.output_directory)
logging.info(f"Setting output directory to: {output_dir}")
folder_paths.set_output_directory(output_dir)
#These are the default folders that checkpoints, clip and vae models will be saved to when using CheckpointSave, etc.. nodes
folder_paths.add_model_folder_path("checkpoints", os.path.join(folder_paths.get_output_directory(), "checkpoints"))
folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip"))
folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae"))
folder_paths.add_model_folder_path("diffusion_models", os.path.join(folder_paths.get_output_directory(), "diffusion_models"))
folder_paths.add_model_folder_path("loras", os.path.join(folder_paths.get_output_directory(), "loras"))
if args.input_directory:
input_dir = os.path.abspath(args.input_directory)
logging.info(f"Setting input directory to: {input_dir}")
folder_paths.set_input_directory(input_dir)
if args.user_directory:
user_dir = os.path.abspath(args.user_directory)
logging.info(f"Setting user directory to: {user_dir}")
folder_paths.set_user_directory(user_dir)
threading.Thread(target=prompt_worker, daemon=True, args=(q, prompt_server,)).start()
if args.quick_test_for_ci:
exit(0)
@@ -256,9 +280,19 @@ if __name__ == "__main__":
webbrowser.open(f"{scheme}://{address}:{port}")
call_on_start = startup_server
async def start_all():
await prompt_server.setup()
await run(prompt_server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start)
# Returning these so that other code can integrate with the ComfyUI loop and server
return asyncio_loop, prompt_server, start_all
if __name__ == "__main__":
# Running directly, just start ComfyUI.
event_loop, _, start_all_func = start_comfyui()
try:
loop.run_until_complete(server.setup())
loop.run_until_complete(run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start))
event_loop.run_until_complete(start_all_func())
except KeyboardInterrupt:
logging.info("\nStopped server")
+1 -1
View File
@@ -32,4 +32,4 @@ def update_windows_updater():
except:
pass
shutil.copy(bat_path, dest_bat_path)
print("Updated the windows standalone package updater.")
print("Updated the windows standalone package updater.") # noqa: T201
+73 -18
View File
@@ -11,7 +11,7 @@ import time
import random
import logging
from PIL import Image, ImageOps, ImageSequence, ImageFile
from PIL import Image, ImageOps, ImageSequence
from PIL.PngImagePlugin import PngInfo
import numpy as np
@@ -291,19 +291,31 @@ class VAEDecodeTiled:
@classmethod
def INPUT_TYPES(s):
return {"required": {"samples": ("LATENT", ), "vae": ("VAE", ),
"tile_size": ("INT", {"default": 512, "min": 128, "max": 4096, "step": 32}),
"tile_size": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 32}),
"overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32}),
"temporal_size": ("INT", {"default": 64, "min": 8, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to decode at a time."}),
"temporal_overlap": ("INT", {"default": 8, "min": 4, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap."}),
}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "decode"
CATEGORY = "_for_testing"
def decode(self, vae, samples, tile_size, overlap=64):
def decode(self, vae, samples, tile_size, overlap=64, temporal_size=64, temporal_overlap=8):
if tile_size < overlap * 4:
overlap = tile_size // 4
if temporal_size < temporal_overlap * 2:
temporal_overlap = temporal_overlap // 2
temporal_compression = vae.temporal_compression_decode()
if temporal_compression is not None:
temporal_size = max(2, temporal_size // temporal_compression)
temporal_overlap = min(1, temporal_size // 2, temporal_overlap // temporal_compression)
else:
temporal_size = None
temporal_overlap = None
compression = vae.spacial_compression_decode()
images = vae.decode_tiled(samples["samples"], tile_x=tile_size // compression, tile_y=tile_size // compression, overlap=overlap // compression)
images = vae.decode_tiled(samples["samples"], tile_x=tile_size // compression, tile_y=tile_size // compression, overlap=overlap // compression, tile_t=temporal_size, overlap_t=temporal_overlap)
if len(images.shape) == 5: #Combine batches
images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1])
return (images, )
@@ -325,16 +337,19 @@ class VAEEncodeTiled:
@classmethod
def INPUT_TYPES(s):
return {"required": {"pixels": ("IMAGE", ), "vae": ("VAE", ),
"tile_size": ("INT", {"default": 512, "min": 320, "max": 4096, "step": 64})
"tile_size": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 64}),
"overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32}),
"temporal_size": ("INT", {"default": 64, "min": 8, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to encode at a time."}),
"temporal_overlap": ("INT", {"default": 8, "min": 4, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap."}),
}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "encode"
CATEGORY = "_for_testing"
def encode(self, vae, pixels, tile_size):
t = vae.encode_tiled(pixels[:,:,:,:3], tile_x=tile_size, tile_y=tile_size, )
return ({"samples":t}, )
def encode(self, vae, pixels, tile_size, overlap, temporal_size=64, temporal_overlap=8):
t = vae.encode_tiled(pixels[:,:,:,:3], tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap)
return ({"samples": t}, )
class VAEEncodeForInpaint:
@classmethod
@@ -644,9 +659,7 @@ class LoraLoader:
if self.loaded_lora[0] == lora_path:
lora = self.loaded_lora[1]
else:
temp = self.loaded_lora
self.loaded_lora = None
del temp
if lora is None:
lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
@@ -899,7 +912,7 @@ class CLIPLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv"], ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart"], ),
}}
RETURN_TYPES = ("CLIP",)
FUNCTION = "load_clip"
@@ -919,6 +932,8 @@ class CLIPLoader:
clip_type = comfy.sd.CLIPType.MOCHI
elif type == "ltxv":
clip_type = comfy.sd.CLIPType.LTXV
elif type == "pixart":
clip_type = comfy.sd.CLIPType.PIXART
else:
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
@@ -931,7 +946,7 @@ class DualCLIPLoader:
def INPUT_TYPES(s):
return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ),
"clip_name2": (folder_paths.get_filename_list("text_encoders"), ),
"type": (["sdxl", "sd3", "flux"], ),
"type": (["sdxl", "sd3", "flux", "hunyuan_video"], ),
}}
RETURN_TYPES = ("CLIP",)
FUNCTION = "load_clip"
@@ -949,6 +964,8 @@ class DualCLIPLoader:
clip_type = comfy.sd.CLIPType.SD3
elif type == "flux":
clip_type = comfy.sd.CLIPType.FLUX
elif type == "hunyuan_video":
clip_type = comfy.sd.CLIPType.HUNYUAN_VIDEO
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
return (clip,)
@@ -1010,23 +1027,58 @@ class StyleModelApply:
"style_model": ("STYLE_MODEL", ),
"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}),
"strength_type": (["multiply"], ),
"strength_type": (["multiply", "attn_bias"], ),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "apply_stylemodel"
CATEGORY = "conditioning/style_model"
def apply_stylemodel(self, clip_vision_output, style_model, conditioning, strength, strength_type):
def apply_stylemodel(self, conditioning, style_model, clip_vision_output, strength, strength_type):
cond = style_model.get_cond(clip_vision_output).flatten(start_dim=0, end_dim=1).unsqueeze(dim=0)
if strength_type == "multiply":
cond *= strength
c = []
n = cond.shape[1]
c_out = []
for t in conditioning:
n = [torch.cat((t[0], cond), dim=1), t[1].copy()]
c.append(n)
return (c, )
(txt, keys) = t
keys = keys.copy()
if strength_type == "attn_bias" and strength != 1.0:
# math.log raises an error if the argument is zero
# torch.log returns -inf, which is what we want
attn_bias = torch.log(torch.Tensor([strength]))
# get the size of the mask image
mask_ref_size = keys.get("attention_mask_img_shape", (1, 1))
n_ref = mask_ref_size[0] * mask_ref_size[1]
n_txt = txt.shape[1]
# grab the existing mask
mask = keys.get("attention_mask", None)
# create a default mask if it doesn't exist
if mask is None:
mask = torch.zeros((txt.shape[0], n_txt + n_ref, n_txt + n_ref), dtype=torch.float16)
# convert the mask dtype, because it might be boolean
# we want it to be interpreted as a bias
if mask.dtype == torch.bool:
# log(True) = log(1) = 0
# log(False) = log(0) = -inf
mask = torch.log(mask.to(dtype=torch.float16))
# now we make the mask bigger to add space for our new tokens
new_mask = torch.zeros((txt.shape[0], n_txt + n + n_ref, n_txt + n + n_ref), dtype=torch.float16)
# copy over the old mask, in quandrants
new_mask[:, :n_txt, :n_txt] = mask[:, :n_txt, :n_txt]
new_mask[:, :n_txt, n_txt+n:] = mask[:, :n_txt, n_txt:]
new_mask[:, n_txt+n:, :n_txt] = mask[:, n_txt:, :n_txt]
new_mask[:, n_txt+n:, n_txt+n:] = mask[:, n_txt:, n_txt:]
# now fill in the attention bias to our redux tokens
new_mask[:, :n_txt, n_txt:n_txt+n] = attn_bias
new_mask[:, n_txt+n:, n_txt:n_txt+n] = attn_bias
keys["attention_mask"] = new_mask.to(txt.device)
keys["attention_mask_img_shape"] = mask_ref_size
c_out.append([torch.cat((txt, cond), dim=1), keys])
return (c_out,)
class unCLIPConditioning:
@classmethod
@@ -2128,6 +2180,7 @@ def init_builtin_extra_nodes():
"nodes_stable3d.py",
"nodes_sdupscale.py",
"nodes_photomaker.py",
"nodes_pixart.py",
"nodes_cond.py",
"nodes_morphology.py",
"nodes_stable_cascade.py",
@@ -2149,8 +2202,10 @@ def init_builtin_extra_nodes():
"nodes_torch_compile.py",
"nodes_mochi.py",
"nodes_slg.py",
"nodes_mahiro.py",
"nodes_lt.py",
"nodes_hooks.py",
"nodes_load_3d.py",
]
import_failed = []

Some files were not shown because too many files have changed in this diff Show More