Compare commits
160 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 72212fef66 | |||
| df34f1549a | |||
| 9b0553809c | |||
| 8d7c930246 | |||
| de44b95db6 | |||
| 543888d3d8 | |||
| 70fc0425b3 | |||
| 85e34643f8 | |||
| 5c33872e2f | |||
| 206595f854 | |||
| b288fb0db8 | |||
| f73b176abd | |||
| 103a12cb66 | |||
| 97652d26b8 | |||
| bd1d9bcd5f | |||
| fb763d4333 | |||
| bcbd7884e3 | |||
| 27a0fcccc3 | |||
| ea6cdd2631 | |||
| 2ee7879a0b | |||
| 3493b9cb1f | |||
| c9ebe70072 | |||
| 261421e218 | |||
| a9f1bb10a5 | |||
| b0338e930b | |||
| b71f9bcb71 | |||
| 72855db715 | |||
| f48d05a2d1 | |||
| 4368d8f87f | |||
| 22da0a83e9 | |||
| 50333f1715 | |||
| 26d5b86da8 | |||
| 4f5812b937 | |||
| 1bcb469089 | |||
| 464ba1d614 | |||
| e3018c2a5a | |||
| 3412d53b1d | |||
| e2d1e5dad9 | |||
| 27e067ce50 | |||
| 9b15155972 | |||
| 32a627bf1f | |||
| fe442fac2e | |||
| d2c502e629 | |||
| fea9ea8268 | |||
| f949094b3c | |||
| 4449e14769 | |||
| 885015eecf | |||
| a86aaa4301 | |||
| 2efb2cbc38 | |||
| 15aa9222c4 | |||
| c7bb3e2bce | |||
| e80a14ad50 | |||
| d28b39d93d | |||
| 1c184c29eb | |||
| edde0b5043 | |||
| 0063610177 | |||
| ce0052c087 | |||
| 0eb821a7b6 | |||
| 4aa79dbf2c | |||
| 38f697d953 | |||
| 3aad339b63 | |||
| 491755325c | |||
| 496888fd68 | |||
| b5ac6ed7ce | |||
| b20ba1f27c | |||
| 31a37686d0 | |||
| 88aee596a3 | |||
| 6a193ac557 | |||
| 47f4db3e84 | |||
| 5352abc6d3 | |||
| 39aa06bd5d | |||
| 914c2a2973 | |||
| e633a47ad1 | |||
| f6b93d41a0 | |||
| 95ac7794b7 | |||
| 71ed4a399e | |||
| 3e316c6338 | |||
| 8be0d22ab7 | |||
| 59eddda900 | |||
| 41048c69b4 | |||
| fc247150fe | |||
| fe31ad0276 | |||
| ca4e96a8ae | |||
| 050c67323c | |||
| 497d41fb50 | |||
| ff57793659 | |||
| f7bd5e58dd | |||
| 7ed73d12d1 | |||
| eb39019daa | |||
| bab08f40d1 | |||
| bc49106837 | |||
| 1b2de2642d | |||
| 9fa1036f60 | |||
| 0737b7e0d2 | |||
| 0963493a9c | |||
| e73a9dbe30 | |||
| fe01885acf | |||
| 7139d6d93f | |||
| 2f52e8f05f | |||
| 8d38ea3bbf | |||
| 5a8f502db5 | |||
| 7cd2c4bd6a | |||
| dfa791eb4b | |||
| bddd69618b | |||
| 54d8fdbed0 | |||
| d844d8b13b | |||
| 07a927517c | |||
| f16a70ba67 | |||
| 36b5127fd3 | |||
| 4977f203fa | |||
| bd2ab73976 | |||
| da2efeaec6 | |||
| 7f3b9b16c6 | |||
| d4e353a94e | |||
| ed43784b0d | |||
| 0f2b8525bc | |||
| 20a84166d0 | |||
| ed2e33c69a | |||
| 1702e6df16 | |||
| c308a8840a | |||
| 027c63f63a | |||
| e08ecfbd8a | |||
| 4e5c230f6a | |||
| f0d5d0111f | |||
| ad19a069f6 | |||
| 5d65d6753b | |||
| deebee4ff6 | |||
| fa570cbf59 | |||
| 644b23ac0b | |||
| 72fd4d22b6 | |||
| e4f7ea105f | |||
| c991a5da65 | |||
| 9df8792d4b | |||
| 3da5a07510 | |||
| afa0a45206 | |||
| 615eb52049 | |||
| d5c1954d5c | |||
| e400f26c8f | |||
| 5ca8e2fac3 | |||
| 3294782d19 | |||
| 898d88e10e | |||
| 560d38f34c | |||
| e1d4f36d8d | |||
| 1e3ae1eed8 | |||
| f4231a80b1 | |||
| 2208aa616d | |||
| 629b173837 | |||
| fa340add55 | |||
| 966f3a5206 | |||
| 0552de7c7d | |||
| 5828607ccf | |||
| 735bb4bdb1 | |||
| bf2a1b5b1e | |||
| 42974a448c | |||
| 05df2df489 | |||
| 37d620a6b8 | |||
| 32691b16f4 | |||
| 4c3e57b0ae | |||
| 9126c0cfe4 | |||
| d8c51ba15a |
@@ -22,7 +22,7 @@ body:
|
||||
description: Please confirm you have tried to reproduce the issue with all custom nodes disabled.
|
||||
options:
|
||||
- label: I have tried disabling custom nodes and the issue persists (see [how to disable custom nodes](https://docs.comfy.org/troubleshooting/custom-node-issues#step-1%3A-test-with-all-custom-nodes-disabled) if you need help)
|
||||
required: true
|
||||
required: false
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Expected Behavior
|
||||
|
||||
@@ -18,7 +18,7 @@ body:
|
||||
description: Please confirm you have tried to reproduce the issue with all custom nodes disabled.
|
||||
options:
|
||||
- label: I have tried disabling custom nodes and the issue persists (see [how to disable custom nodes](https://docs.comfy.org/troubleshooting/custom-node-issues#step-1%3A-test-with-all-custom-nodes-disabled) if you need help)
|
||||
required: true
|
||||
required: false
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Your question
|
||||
|
||||
@@ -12,17 +12,17 @@ on:
|
||||
description: 'CUDA version'
|
||||
required: true
|
||||
type: string
|
||||
default: "128"
|
||||
default: "129"
|
||||
python_minor:
|
||||
description: 'Python minor version'
|
||||
required: true
|
||||
type: string
|
||||
default: "12"
|
||||
default: "13"
|
||||
python_patch:
|
||||
description: 'Python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "10"
|
||||
default: "6"
|
||||
|
||||
|
||||
jobs:
|
||||
@@ -66,8 +66,13 @@ jobs:
|
||||
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
|
||||
./python.exe get-pip.py
|
||||
./python.exe -s -m pip install ../cu${{ inputs.cu }}_python_deps/*
|
||||
sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
|
||||
cd ..
|
||||
sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
|
||||
|
||||
rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space
|
||||
rm ./Lib/site-packages/torch/lib/libprotoc.lib
|
||||
rm ./Lib/site-packages/torch/lib/libprotobuf.lib
|
||||
|
||||
cd ..
|
||||
|
||||
git clone --depth 1 https://github.com/comfyanonymous/taesd
|
||||
cp taesd/*.safetensors ./ComfyUI_copy/models/vae_approx/
|
||||
@@ -85,7 +90,7 @@ jobs:
|
||||
|
||||
cd ..
|
||||
|
||||
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=512m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
|
||||
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=768m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
|
||||
mv ComfyUI_windows_portable.7z ComfyUI/ComfyUI_windows_portable_nvidia.7z
|
||||
|
||||
cd ComfyUI_windows_portable
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
name: Execution Tests
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main, master ]
|
||||
pull_request:
|
||||
branches: [ main, master ]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest, windows-latest, macos-latest]
|
||||
runs-on: ${{ matrix.os }}
|
||||
continue-on-error: true
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.12'
|
||||
- name: Install requirements
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
|
||||
pip install -r requirements.txt
|
||||
pip install -r tests-unit/requirements.txt
|
||||
- name: Run Execution Tests
|
||||
run: |
|
||||
python -m pytest tests/execution -v --skip-timing-checks
|
||||
@@ -17,19 +17,19 @@ on:
|
||||
description: 'cuda version'
|
||||
required: true
|
||||
type: string
|
||||
default: "128"
|
||||
default: "129"
|
||||
|
||||
python_minor:
|
||||
description: 'python minor version'
|
||||
required: true
|
||||
type: string
|
||||
default: "12"
|
||||
default: "13"
|
||||
|
||||
python_patch:
|
||||
description: 'python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "10"
|
||||
default: "6"
|
||||
# push:
|
||||
# branches:
|
||||
# - master
|
||||
|
||||
@@ -7,19 +7,19 @@ on:
|
||||
description: 'cuda version'
|
||||
required: true
|
||||
type: string
|
||||
default: "128"
|
||||
default: "129"
|
||||
|
||||
python_minor:
|
||||
description: 'python minor version'
|
||||
required: true
|
||||
type: string
|
||||
default: "12"
|
||||
default: "13"
|
||||
|
||||
python_patch:
|
||||
description: 'python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "10"
|
||||
default: "6"
|
||||
# push:
|
||||
# branches:
|
||||
# - master
|
||||
@@ -64,6 +64,10 @@ jobs:
|
||||
./python.exe get-pip.py
|
||||
./python.exe -s -m pip install ../cu${{ inputs.cu }}_python_deps/*
|
||||
sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
|
||||
|
||||
rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space
|
||||
rm ./Lib/site-packages/torch/lib/libprotoc.lib
|
||||
rm ./Lib/site-packages/torch/lib/libprotobuf.lib
|
||||
cd ..
|
||||
|
||||
git clone --depth 1 https://github.com/comfyanonymous/taesd
|
||||
@@ -82,7 +86,7 @@ jobs:
|
||||
|
||||
cd ..
|
||||
|
||||
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=512m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
|
||||
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=768m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
|
||||
mv ComfyUI_windows_portable.7z ComfyUI/new_ComfyUI_windows_portable_nvidia_cu${{ inputs.cu }}_or_cpu.7z
|
||||
|
||||
cd ComfyUI_windows_portable
|
||||
|
||||
+14
-13
@@ -5,20 +5,21 @@
|
||||
# Inlined the team members for now.
|
||||
|
||||
# Maintainers
|
||||
*.md @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
/tests/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
/tests-unit/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
/notebooks/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
/script_examples/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
/.github/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
/requirements.txt @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
/pyproject.toml @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
*.md @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
|
||||
/tests/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
|
||||
/tests-unit/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
|
||||
/notebooks/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
|
||||
/script_examples/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
|
||||
/.github/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
|
||||
/requirements.txt @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
|
||||
/pyproject.toml @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
|
||||
|
||||
# Python web server
|
||||
/api_server/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
|
||||
/app/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
|
||||
/utils/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
|
||||
/api_server/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne @guill
|
||||
/app/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne @guill
|
||||
/utils/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne @guill
|
||||
|
||||
# Node developers
|
||||
/comfy_extras/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
|
||||
/comfy/comfy_types/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
|
||||
/comfy_extras/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne @guill
|
||||
/comfy/comfy_types/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne @guill
|
||||
/comfy_api_nodes/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne @guill
|
||||
|
||||
@@ -39,7 +39,7 @@ ComfyUI lets you design and execute advanced stable diffusion pipelines using a
|
||||
## Get Started
|
||||
|
||||
#### [Desktop Application](https://www.comfy.org/download)
|
||||
- The easiest way to get started.
|
||||
- The easiest way to get started.
|
||||
- Available on Windows & macOS.
|
||||
|
||||
#### [Windows Portable Package](#installing)
|
||||
@@ -65,17 +65,17 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
|
||||
- [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
|
||||
- [Lumina Image 2.0](https://comfyanonymous.github.io/ComfyUI_examples/lumina2/)
|
||||
- [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/)
|
||||
- [Cosmos Predict2](https://comfyanonymous.github.io/ComfyUI_examples/cosmos_predict2/)
|
||||
- [Qwen Image](https://comfyanonymous.github.io/ComfyUI_examples/qwen_image/)
|
||||
- Image Editing Models
|
||||
- [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/)
|
||||
- [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model)
|
||||
- [HiDream E1.1](https://comfyanonymous.github.io/ComfyUI_examples/hidream/#hidream-e11)
|
||||
- [Qwen Image Edit](https://comfyanonymous.github.io/ComfyUI_examples/qwen_image/#edit-model)
|
||||
- Video Models
|
||||
- [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/)
|
||||
- [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
|
||||
- [LTX-Video](https://comfyanonymous.github.io/ComfyUI_examples/ltxv/)
|
||||
- [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
|
||||
- [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/) and [Cosmos Predict2](https://comfyanonymous.github.io/ComfyUI_examples/cosmos_predict2/)
|
||||
- [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/)
|
||||
- [Wan 2.2](https://comfyanonymous.github.io/ComfyUI_examples/wan22/)
|
||||
- Audio Models
|
||||
@@ -190,7 +190,7 @@ comfy install
|
||||
|
||||
## Manual Install (Windows, Linux)
|
||||
|
||||
python 3.13 is supported but using 3.12 is recommended because some custom nodes and their dependencies might not support it yet.
|
||||
Python 3.13 is very well supported. If you have trouble with some custom node dependencies you can try 3.12
|
||||
|
||||
Git clone this repo.
|
||||
|
||||
@@ -202,7 +202,7 @@ Put your VAE in: models/vae
|
||||
### AMD GPUs (Linux only)
|
||||
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
|
||||
|
||||
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.3```
|
||||
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.4```
|
||||
|
||||
This is the command to install the nightly with ROCm 6.4 which might have some performance improvements:
|
||||
|
||||
@@ -210,33 +210,25 @@ This is the command to install the nightly with ROCm 6.4 which might have some p
|
||||
|
||||
### Intel GPUs (Windows and Linux)
|
||||
|
||||
(Option 1) Intel Arc GPU users can install native PyTorch with torch.xpu support using pip (currently available in PyTorch nightly builds). More information can be found [here](https://pytorch.org/docs/main/notes/get_start_xpu.html)
|
||||
|
||||
1. To install PyTorch nightly, use the following command:
|
||||
(Option 1) Intel Arc GPU users can install native PyTorch with torch.xpu support using pip. More information can be found [here](https://pytorch.org/docs/main/notes/get_start_xpu.html)
|
||||
|
||||
1. To install PyTorch xpu, use the following command:
|
||||
|
||||
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/xpu```
|
||||
|
||||
This is the command to install the Pytorch xpu nightly which might have some performance improvements:
|
||||
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/xpu```
|
||||
|
||||
2. Launch ComfyUI by running `python main.py`
|
||||
|
||||
|
||||
(Option 2) Alternatively, Intel GPUs supported by Intel Extension for PyTorch (IPEX) can leverage IPEX for improved performance.
|
||||
|
||||
1. For Intel® Arc™ A-Series Graphics utilizing IPEX, create a conda environment and use the commands below:
|
||||
|
||||
```
|
||||
conda install libuv
|
||||
pip install torch==2.3.1.post0+cxx11.abi torchvision==0.18.1.post0+cxx11.abi torchaudio==2.3.1.post0+cxx11.abi intel-extension-for-pytorch==2.3.110.post0+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/cn/
|
||||
```
|
||||
|
||||
For other supported Intel GPUs with IPEX, visit [Installation](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu) for more information.
|
||||
|
||||
Additional discussion and help can be found [here](https://github.com/comfyanonymous/ComfyUI/discussions/476).
|
||||
1. visit [Installation](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu) for more information.
|
||||
|
||||
### NVIDIA
|
||||
|
||||
Nvidia users should install stable pytorch using this command:
|
||||
|
||||
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu128```
|
||||
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu129```
|
||||
|
||||
This is the command to install pytorch nightly instead which might have performance improvements.
|
||||
|
||||
@@ -351,7 +343,7 @@ Generate a self-signed certificate (not appropriate for shared/production use) a
|
||||
|
||||
Use `--tls-keyfile key.pem --tls-certfile cert.pem` to enable TLS/SSL, the app will now be accessible with `https://...` instead of `http://...`.
|
||||
|
||||
> Note: Windows users can use [alexisrolland/docker-openssl](https://github.com/alexisrolland/docker-openssl) or one of the [3rd party binary distributions](https://wiki.openssl.org/index.php/Binaries) to run the command example above.
|
||||
> Note: Windows users can use [alexisrolland/docker-openssl](https://github.com/alexisrolland/docker-openssl) or one of the [3rd party binary distributions](https://wiki.openssl.org/index.php/Binaries) to run the command example above.
|
||||
<br/><br/>If you use a container, note that the volume mount `-v` can be a relative path so `... -v ".\:/openssl-certs" ...` would create the key & cert files in the current directory of your command prompt or powershell terminal.
|
||||
|
||||
## Support and dev channel
|
||||
|
||||
+10
-3
@@ -363,10 +363,17 @@ class UserManager():
|
||||
if not overwrite and os.path.exists(path):
|
||||
return web.Response(status=409, text="File already exists")
|
||||
|
||||
body = await request.read()
|
||||
try:
|
||||
body = await request.read()
|
||||
|
||||
with open(path, "wb") as f:
|
||||
f.write(body)
|
||||
with open(path, "wb") as f:
|
||||
f.write(body)
|
||||
except OSError as e:
|
||||
logging.warning(f"Error saving file '{path}': {e}")
|
||||
return web.Response(
|
||||
status=400,
|
||||
reason="Invalid filename. Please avoid special characters like :\\/*?\"<>|"
|
||||
)
|
||||
|
||||
user_path = self.get_request_user_filepath(request, None)
|
||||
if full_info:
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
from .wav2vec2 import Wav2Vec2Model
|
||||
import comfy.model_management
|
||||
import comfy.ops
|
||||
import comfy.utils
|
||||
import logging
|
||||
import torchaudio
|
||||
|
||||
|
||||
class AudioEncoderModel():
|
||||
def __init__(self, config):
|
||||
self.load_device = comfy.model_management.text_encoder_device()
|
||||
offload_device = comfy.model_management.text_encoder_offload_device()
|
||||
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
|
||||
self.model = Wav2Vec2Model(dtype=self.dtype, device=offload_device, operations=comfy.ops.manual_cast)
|
||||
self.model.eval()
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||
self.model_sample_rate = 16000
|
||||
|
||||
def load_sd(self, sd):
|
||||
return self.model.load_state_dict(sd, strict=False)
|
||||
|
||||
def get_sd(self):
|
||||
return self.model.state_dict()
|
||||
|
||||
def encode_audio(self, audio, sample_rate):
|
||||
comfy.model_management.load_model_gpu(self.patcher)
|
||||
audio = torchaudio.functional.resample(audio, sample_rate, self.model_sample_rate)
|
||||
out, all_layers = self.model(audio.to(self.load_device))
|
||||
outputs = {}
|
||||
outputs["encoded_audio"] = out
|
||||
outputs["encoded_audio_all_layers"] = all_layers
|
||||
return outputs
|
||||
|
||||
|
||||
def load_audio_encoder_from_sd(sd, prefix=""):
|
||||
audio_encoder = AudioEncoderModel(None)
|
||||
sd = comfy.utils.state_dict_prefix_replace(sd, {"wav2vec2.": ""})
|
||||
m, u = audio_encoder.load_sd(sd)
|
||||
if len(m) > 0:
|
||||
logging.warning("missing audio encoder: {}".format(m))
|
||||
|
||||
return audio_encoder
|
||||
@@ -0,0 +1,207 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from comfy.ldm.modules.attention import optimized_attention_masked
|
||||
|
||||
|
||||
class LayerNormConv(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride, bias=False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.conv = operations.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=bias, device=device, dtype=dtype)
|
||||
self.layer_norm = operations.LayerNorm(out_channels, elementwise_affine=True, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
return torch.nn.functional.gelu(self.layer_norm(x.transpose(-2, -1)).transpose(-2, -1))
|
||||
|
||||
|
||||
class ConvFeatureEncoder(nn.Module):
|
||||
def __init__(self, conv_dim, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.conv_layers = nn.ModuleList([
|
||||
LayerNormConv(1, conv_dim, kernel_size=10, stride=5, bias=True, device=device, dtype=dtype, operations=operations),
|
||||
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
|
||||
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
|
||||
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
|
||||
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
|
||||
LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
|
||||
LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
|
||||
])
|
||||
|
||||
def forward(self, x):
|
||||
x = x.unsqueeze(1)
|
||||
|
||||
for conv in self.conv_layers:
|
||||
x = conv(x)
|
||||
|
||||
return x.transpose(1, 2)
|
||||
|
||||
|
||||
class FeatureProjection(nn.Module):
|
||||
def __init__(self, conv_dim, embed_dim, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.layer_norm = operations.LayerNorm(conv_dim, eps=1e-05, device=device, dtype=dtype)
|
||||
self.projection = operations.Linear(conv_dim, embed_dim, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.layer_norm(x)
|
||||
x = self.projection(x)
|
||||
return x
|
||||
|
||||
|
||||
class PositionalConvEmbedding(nn.Module):
|
||||
def __init__(self, embed_dim=768, kernel_size=128, groups=16):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv1d(
|
||||
embed_dim,
|
||||
embed_dim,
|
||||
kernel_size=kernel_size,
|
||||
padding=kernel_size // 2,
|
||||
groups=groups,
|
||||
)
|
||||
self.conv = torch.nn.utils.parametrizations.weight_norm(self.conv, name="weight", dim=2)
|
||||
self.activation = nn.GELU()
|
||||
|
||||
def forward(self, x):
|
||||
x = x.transpose(1, 2)
|
||||
x = self.conv(x)[:, :, :-1]
|
||||
x = self.activation(x)
|
||||
x = x.transpose(1, 2)
|
||||
return x
|
||||
|
||||
|
||||
class TransformerEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim=768,
|
||||
num_heads=12,
|
||||
num_layers=12,
|
||||
mlp_ratio=4.0,
|
||||
dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.pos_conv_embed = PositionalConvEmbedding(embed_dim=embed_dim)
|
||||
self.layers = nn.ModuleList([
|
||||
TransformerEncoderLayer(
|
||||
embed_dim=embed_dim,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
device=device, dtype=dtype, operations=operations
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
|
||||
self.layer_norm = operations.LayerNorm(embed_dim, eps=1e-05, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
x = x + self.pos_conv_embed(x)
|
||||
all_x = ()
|
||||
for layer in self.layers:
|
||||
all_x += (x,)
|
||||
x = layer(x, mask)
|
||||
x = self.layer_norm(x)
|
||||
all_x += (x,)
|
||||
return x, all_x
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, embed_dim, num_heads, bias=True, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = embed_dim // num_heads
|
||||
|
||||
self.k_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
|
||||
self.v_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
|
||||
self.q_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
|
||||
self.out_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
assert (mask is None) # TODO?
|
||||
q = self.q_proj(x)
|
||||
k = self.k_proj(x)
|
||||
v = self.v_proj(x)
|
||||
|
||||
out = optimized_attention_masked(q, k, v, self.num_heads)
|
||||
return self.out_proj(out)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, embed_dim, mlp_ratio, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.intermediate_dense = operations.Linear(embed_dim, int(embed_dim * mlp_ratio), device=device, dtype=dtype)
|
||||
self.output_dense = operations.Linear(int(embed_dim * mlp_ratio), embed_dim, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.intermediate_dense(x)
|
||||
x = torch.nn.functional.gelu(x)
|
||||
x = self.output_dense(x)
|
||||
return x
|
||||
|
||||
|
||||
class TransformerEncoderLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim=768,
|
||||
num_heads=12,
|
||||
mlp_ratio=4.0,
|
||||
dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.attention = Attention(embed_dim, num_heads, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
self.layer_norm = operations.LayerNorm(embed_dim, device=device, dtype=dtype)
|
||||
self.feed_forward = FeedForward(embed_dim, mlp_ratio, device=device, dtype=dtype, operations=operations)
|
||||
self.final_layer_norm = operations.LayerNorm(embed_dim, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
residual = x
|
||||
x = self.layer_norm(x)
|
||||
x = self.attention(x, mask=mask)
|
||||
x = residual + x
|
||||
|
||||
x = x + self.feed_forward(self.final_layer_norm(x))
|
||||
return x
|
||||
|
||||
|
||||
class Wav2Vec2Model(nn.Module):
|
||||
"""Complete Wav2Vec 2.0 model."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim=1024,
|
||||
final_dim=256,
|
||||
num_heads=16,
|
||||
num_layers=24,
|
||||
dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
conv_dim = 512
|
||||
self.feature_extractor = ConvFeatureEncoder(conv_dim, device=device, dtype=dtype, operations=operations)
|
||||
self.feature_projection = FeatureProjection(conv_dim, embed_dim, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
self.masked_spec_embed = nn.Parameter(torch.empty(embed_dim, device=device, dtype=dtype))
|
||||
|
||||
self.encoder = TransformerEncoder(
|
||||
embed_dim=embed_dim,
|
||||
num_heads=num_heads,
|
||||
num_layers=num_layers,
|
||||
device=device, dtype=dtype, operations=operations
|
||||
)
|
||||
|
||||
def forward(self, x, mask_time_indices=None, return_dict=False):
|
||||
|
||||
x = torch.mean(x, dim=1)
|
||||
|
||||
x = (x - x.mean()) / torch.sqrt(x.var() + 1e-7)
|
||||
|
||||
features = self.feature_extractor(x)
|
||||
features = self.feature_projection(features)
|
||||
|
||||
batch_size, seq_len, _ = features.shape
|
||||
|
||||
x, all_x = self.encoder(features)
|
||||
|
||||
return x, all_x
|
||||
+4
-1
@@ -132,6 +132,8 @@ parser.add_argument("--reserve-vram", type=float, default=None, help="Set the am
|
||||
|
||||
parser.add_argument("--async-offload", action="store_true", help="Use async weight offloading.")
|
||||
|
||||
parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")
|
||||
|
||||
parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
|
||||
|
||||
parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
|
||||
@@ -141,8 +143,9 @@ class PerformanceFeature(enum.Enum):
|
||||
Fp16Accumulation = "fp16_accumulation"
|
||||
Fp8MatrixMultiplication = "fp8_matrix_mult"
|
||||
CublasOps = "cublas_ops"
|
||||
AutoTune = "autotune"
|
||||
|
||||
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult cublas_ops")
|
||||
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
|
||||
|
||||
parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.")
|
||||
parser.add_argument("--disable-mmap", action="store_true", help="Don't use mmap when loading safetensors.")
|
||||
|
||||
+12
-2
@@ -61,8 +61,12 @@ class CLIPEncoder(torch.nn.Module):
|
||||
def forward(self, x, mask=None, intermediate_output=None):
|
||||
optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)
|
||||
|
||||
all_intermediate = None
|
||||
if intermediate_output is not None:
|
||||
if intermediate_output < 0:
|
||||
if intermediate_output == "all":
|
||||
all_intermediate = []
|
||||
intermediate_output = None
|
||||
elif intermediate_output < 0:
|
||||
intermediate_output = len(self.layers) + intermediate_output
|
||||
|
||||
intermediate = None
|
||||
@@ -70,6 +74,12 @@ class CLIPEncoder(torch.nn.Module):
|
||||
x = l(x, mask, optimized_attention)
|
||||
if i == intermediate_output:
|
||||
intermediate = x.clone()
|
||||
if all_intermediate is not None:
|
||||
all_intermediate.append(x.unsqueeze(1).clone())
|
||||
|
||||
if all_intermediate is not None:
|
||||
intermediate = torch.cat(all_intermediate, dim=1)
|
||||
|
||||
return x, intermediate
|
||||
|
||||
class CLIPEmbeddings(torch.nn.Module):
|
||||
@@ -97,7 +107,7 @@ class CLIPTextModel_(torch.nn.Module):
|
||||
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
|
||||
self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, input_tokens=None, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32):
|
||||
def forward(self, input_tokens=None, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32, embeds_info=[]):
|
||||
if embeds is not None:
|
||||
x = embeds + comfy.ops.cast_to(self.embeddings.position_embedding.weight, dtype=dtype, device=embeds.device)
|
||||
else:
|
||||
|
||||
+20
-4
@@ -50,7 +50,13 @@ class ClipVisionModel():
|
||||
self.image_size = config.get("image_size", 224)
|
||||
self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
|
||||
self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
|
||||
model_class = IMAGE_ENCODERS.get(config.get("model_type", "clip_vision_model"))
|
||||
model_type = config.get("model_type", "clip_vision_model")
|
||||
model_class = IMAGE_ENCODERS.get(model_type)
|
||||
if model_type == "siglip_vision_model":
|
||||
self.return_all_hidden_states = True
|
||||
else:
|
||||
self.return_all_hidden_states = False
|
||||
|
||||
self.load_device = comfy.model_management.text_encoder_device()
|
||||
offload_device = comfy.model_management.text_encoder_offload_device()
|
||||
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
|
||||
@@ -68,12 +74,18 @@ class ClipVisionModel():
|
||||
def encode_image(self, image, crop=True):
|
||||
comfy.model_management.load_model_gpu(self.patcher)
|
||||
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
|
||||
out = self.model(pixel_values=pixel_values, intermediate_output=-2)
|
||||
out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2)
|
||||
|
||||
outputs = Output()
|
||||
outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device())
|
||||
outputs["image_embeds"] = out[2].to(comfy.model_management.intermediate_device())
|
||||
outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device())
|
||||
if self.return_all_hidden_states:
|
||||
all_hs = out[1].to(comfy.model_management.intermediate_device())
|
||||
outputs["penultimate_hidden_states"] = all_hs[:, -2]
|
||||
outputs["all_hidden_states"] = all_hs
|
||||
else:
|
||||
outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device())
|
||||
|
||||
outputs["mm_projected"] = out[3]
|
||||
return outputs
|
||||
|
||||
@@ -124,8 +136,12 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
|
||||
else:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
|
||||
elif "embeddings.patch_embeddings.projection.weight" in sd:
|
||||
|
||||
# Dinov2
|
||||
elif 'encoder.layer.39.layer_scale2.lambda1' in sd:
|
||||
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json")
|
||||
elif 'encoder.layer.23.layer_scale2.lambda1' in sd:
|
||||
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_large.json")
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
@@ -0,0 +1,540 @@
|
||||
from __future__ import annotations
|
||||
from typing import TYPE_CHECKING, Callable
|
||||
import torch
|
||||
import numpy as np
|
||||
import collections
|
||||
from dataclasses import dataclass
|
||||
from abc import ABC, abstractmethod
|
||||
import logging
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_base import BaseModel
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
from comfy.controlnet import ControlBase
|
||||
|
||||
|
||||
class ContextWindowABC(ABC):
|
||||
def __init__(self):
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_tensor(self, full: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Get torch.Tensor applicable to current window.
|
||||
"""
|
||||
raise NotImplementedError("Not implemented.")
|
||||
|
||||
@abstractmethod
|
||||
def add_window(self, full: torch.Tensor, to_add: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Apply torch.Tensor of window to the full tensor, in place. Returns reference to updated full tensor, not a copy.
|
||||
"""
|
||||
raise NotImplementedError("Not implemented.")
|
||||
|
||||
class ContextHandlerABC(ABC):
|
||||
def __init__(self):
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
|
||||
raise NotImplementedError("Not implemented.")
|
||||
|
||||
@abstractmethod
|
||||
def get_resized_cond(self, cond_in: list[dict], x_in: torch.Tensor, window: ContextWindowABC, device=None) -> list:
|
||||
raise NotImplementedError("Not implemented.")
|
||||
|
||||
@abstractmethod
|
||||
def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
|
||||
raise NotImplementedError("Not implemented.")
|
||||
|
||||
|
||||
|
||||
class IndexListContextWindow(ContextWindowABC):
|
||||
def __init__(self, index_list: list[int], dim: int=0):
|
||||
self.index_list = index_list
|
||||
self.context_length = len(index_list)
|
||||
self.dim = dim
|
||||
|
||||
def get_tensor(self, full: torch.Tensor, device=None, dim=None) -> torch.Tensor:
|
||||
if dim is None:
|
||||
dim = self.dim
|
||||
if dim == 0 and full.shape[dim] == 1:
|
||||
return full
|
||||
idx = [slice(None)] * dim + [self.index_list]
|
||||
return full[idx].to(device)
|
||||
|
||||
def add_window(self, full: torch.Tensor, to_add: torch.Tensor, dim=None) -> torch.Tensor:
|
||||
if dim is None:
|
||||
dim = self.dim
|
||||
idx = [slice(None)] * dim + [self.index_list]
|
||||
full[idx] += to_add
|
||||
return full
|
||||
|
||||
|
||||
class IndexListCallbacks:
|
||||
EVALUATE_CONTEXT_WINDOWS = "evaluate_context_windows"
|
||||
COMBINE_CONTEXT_WINDOW_RESULTS = "combine_context_window_results"
|
||||
EXECUTE_START = "execute_start"
|
||||
EXECUTE_CLEANUP = "execute_cleanup"
|
||||
|
||||
def init_callbacks(self):
|
||||
return {}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContextSchedule:
|
||||
name: str
|
||||
func: Callable
|
||||
|
||||
@dataclass
|
||||
class ContextFuseMethod:
|
||||
name: str
|
||||
func: Callable
|
||||
|
||||
ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window'])
|
||||
class IndexListContextHandler(ContextHandlerABC):
|
||||
def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1, closed_loop=False, dim=0):
|
||||
self.context_schedule = context_schedule
|
||||
self.fuse_method = fuse_method
|
||||
self.context_length = context_length
|
||||
self.context_overlap = context_overlap
|
||||
self.context_stride = context_stride
|
||||
self.closed_loop = closed_loop
|
||||
self.dim = dim
|
||||
self._step = 0
|
||||
|
||||
self.callbacks = {}
|
||||
|
||||
def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
|
||||
# for now, assume first dim is batch - should have stored on BaseModel in actual implementation
|
||||
if x_in.size(self.dim) > self.context_length:
|
||||
logging.info(f"Using context windows {self.context_length} for {x_in.size(self.dim)} frames.")
|
||||
return True
|
||||
return False
|
||||
|
||||
def prepare_control_objects(self, control: ControlBase, device=None) -> ControlBase:
|
||||
if control.previous_controlnet is not None:
|
||||
self.prepare_control_objects(control.previous_controlnet, device)
|
||||
return control
|
||||
|
||||
def get_resized_cond(self, cond_in: list[dict], x_in: torch.Tensor, window: IndexListContextWindow, device=None) -> list:
|
||||
if cond_in is None:
|
||||
return None
|
||||
# reuse or resize cond items to match context requirements
|
||||
resized_cond = []
|
||||
# cond object is a list containing a dict - outer list is irrelevant, so just loop through it
|
||||
for actual_cond in cond_in:
|
||||
resized_actual_cond = actual_cond.copy()
|
||||
# now we are in the inner dict - "pooled_output" is a tensor, "control" is a ControlBase object, "model_conds" is dictionary
|
||||
for key in actual_cond:
|
||||
try:
|
||||
cond_item = actual_cond[key]
|
||||
if isinstance(cond_item, torch.Tensor):
|
||||
# check that tensor is the expected length - x.size(0)
|
||||
if self.dim < cond_item.ndim and cond_item.size(self.dim) == x_in.size(self.dim):
|
||||
# if so, it's subsetting time - tell controls the expected indeces so they can handle them
|
||||
actual_cond_item = window.get_tensor(cond_item)
|
||||
resized_actual_cond[key] = actual_cond_item.to(device)
|
||||
else:
|
||||
resized_actual_cond[key] = cond_item.to(device)
|
||||
# look for control
|
||||
elif key == "control":
|
||||
resized_actual_cond[key] = self.prepare_control_objects(cond_item, device)
|
||||
elif isinstance(cond_item, dict):
|
||||
new_cond_item = cond_item.copy()
|
||||
# when in dictionary, look for tensors and CONDCrossAttn [comfy/conds.py] (has cond attr that is a tensor)
|
||||
for cond_key, cond_value in new_cond_item.items():
|
||||
if isinstance(cond_value, torch.Tensor):
|
||||
if cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim):
|
||||
new_cond_item[cond_key] = window.get_tensor(cond_value, device)
|
||||
# if has cond that is a Tensor, check if needs to be subset
|
||||
elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
|
||||
if cond_value.cond.ndim < self.dim and cond_value.cond.size(0) == x_in.size(self.dim):
|
||||
new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(cond_value.cond, device))
|
||||
elif cond_key == "num_video_frames": # for SVD
|
||||
new_cond_item[cond_key] = cond_value._copy_with(cond_value.cond)
|
||||
new_cond_item[cond_key].cond = window.context_length
|
||||
resized_actual_cond[key] = new_cond_item
|
||||
else:
|
||||
resized_actual_cond[key] = cond_item
|
||||
finally:
|
||||
del cond_item # just in case to prevent VRAM issues
|
||||
resized_cond.append(resized_actual_cond)
|
||||
return resized_cond
|
||||
|
||||
def set_step(self, timestep: torch.Tensor, model_options: dict[str]):
|
||||
mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep, rtol=0.0001)
|
||||
matches = torch.nonzero(mask)
|
||||
if torch.numel(matches) == 0:
|
||||
raise Exception("No sample_sigmas matched current timestep; something went wrong.")
|
||||
self._step = int(matches[0].item())
|
||||
|
||||
def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:
|
||||
full_length = x_in.size(self.dim) # TODO: choose dim based on model
|
||||
context_windows = self.context_schedule.func(full_length, self, model_options)
|
||||
context_windows = [IndexListContextWindow(window, dim=self.dim) for window in context_windows]
|
||||
return context_windows
|
||||
|
||||
def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
|
||||
self.set_step(timestep, model_options)
|
||||
context_windows = self.get_context_windows(model, x_in, model_options)
|
||||
enumerated_context_windows = list(enumerate(context_windows))
|
||||
|
||||
conds_final = [torch.zeros_like(x_in) for _ in conds]
|
||||
if self.fuse_method.name == ContextFuseMethods.RELATIVE:
|
||||
counts_final = [torch.ones(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds]
|
||||
else:
|
||||
counts_final = [torch.zeros(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds]
|
||||
biases_final = [([0.0] * x_in.shape[self.dim]) for _ in conds]
|
||||
|
||||
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks):
|
||||
callback(self, model, x_in, conds, timestep, model_options)
|
||||
|
||||
for enum_window in enumerated_context_windows:
|
||||
results = self.evaluate_context_windows(calc_cond_batch, model, x_in, conds, timestep, [enum_window], model_options)
|
||||
for result in results:
|
||||
self.combine_context_window_results(x_in, result.sub_conds_out, result.sub_conds, result.window, result.window_idx, len(enumerated_context_windows), timestep,
|
||||
conds_final, counts_final, biases_final)
|
||||
try:
|
||||
# finalize conds
|
||||
if self.fuse_method.name == ContextFuseMethods.RELATIVE:
|
||||
# relative is already normalized, so return as is
|
||||
del counts_final
|
||||
return conds_final
|
||||
else:
|
||||
# normalize conds via division by context usage counts
|
||||
for i in range(len(conds_final)):
|
||||
conds_final[i] /= counts_final[i]
|
||||
del counts_final
|
||||
return conds_final
|
||||
finally:
|
||||
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_CLEANUP, self.callbacks):
|
||||
callback(self, model, x_in, conds, timestep, model_options)
|
||||
|
||||
def evaluate_context_windows(self, calc_cond_batch: Callable, model: BaseModel, x_in: torch.Tensor, conds, timestep: torch.Tensor, enumerated_context_windows: list[tuple[int, IndexListContextWindow]],
|
||||
model_options, device=None, first_device=None):
|
||||
results: list[ContextResults] = []
|
||||
for window_idx, window in enumerated_context_windows:
|
||||
# allow processing to end between context window executions for faster Cancel
|
||||
comfy.model_management.throw_exception_if_processing_interrupted()
|
||||
|
||||
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks):
|
||||
callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, device, first_device)
|
||||
|
||||
# update exposed params
|
||||
model_options["transformer_options"]["context_window"] = window
|
||||
# get subsections of x, timestep, conds
|
||||
sub_x = window.get_tensor(x_in, device)
|
||||
sub_timestep = window.get_tensor(timestep, device, dim=0)
|
||||
sub_conds = [self.get_resized_cond(cond, x_in, window, device) for cond in conds]
|
||||
|
||||
sub_conds_out = calc_cond_batch(model, sub_conds, sub_x, sub_timestep, model_options)
|
||||
if device is not None:
|
||||
for i in range(len(sub_conds_out)):
|
||||
sub_conds_out[i] = sub_conds_out[i].to(x_in.device)
|
||||
results.append(ContextResults(window_idx, sub_conds_out, sub_conds, window))
|
||||
return results
|
||||
|
||||
|
||||
def combine_context_window_results(self, x_in: torch.Tensor, sub_conds_out, sub_conds, window: IndexListContextWindow, window_idx: int, total_windows: int, timestep: torch.Tensor,
|
||||
conds_final: list[torch.Tensor], counts_final: list[torch.Tensor], biases_final: list[torch.Tensor]):
|
||||
if self.fuse_method.name == ContextFuseMethods.RELATIVE:
|
||||
for pos, idx in enumerate(window.index_list):
|
||||
# bias is the influence of a specific index in relation to the whole context window
|
||||
bias = 1 - abs(idx - (window.index_list[0] + window.index_list[-1]) / 2) / ((window.index_list[-1] - window.index_list[0] + 1e-2) / 2)
|
||||
bias = max(1e-2, bias)
|
||||
# take weighted average relative to total bias of current idx
|
||||
for i in range(len(sub_conds_out)):
|
||||
bias_total = biases_final[i][idx]
|
||||
prev_weight = (bias_total / (bias_total + bias))
|
||||
new_weight = (bias / (bias_total + bias))
|
||||
# account for dims of tensors
|
||||
idx_window = [slice(None)] * self.dim + [idx]
|
||||
pos_window = [slice(None)] * self.dim + [pos]
|
||||
# apply new values
|
||||
conds_final[i][idx_window] = conds_final[i][idx_window] * prev_weight + sub_conds_out[i][pos_window] * new_weight
|
||||
biases_final[i][idx] = bias_total + bias
|
||||
else:
|
||||
# add conds and counts based on weights of fuse method
|
||||
weights = get_context_weights(window.context_length, x_in.shape[self.dim], window.index_list, self, sigma=timestep)
|
||||
weights_tensor = match_weights_to_dim(weights, x_in, self.dim, device=x_in.device)
|
||||
for i in range(len(sub_conds_out)):
|
||||
window.add_window(conds_final[i], sub_conds_out[i] * weights_tensor)
|
||||
window.add_window(counts_final[i], weights_tensor)
|
||||
|
||||
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.COMBINE_CONTEXT_WINDOW_RESULTS, self.callbacks):
|
||||
callback(self, x_in, sub_conds_out, sub_conds, window, window_idx, total_windows, timestep, conds_final, counts_final, biases_final)
|
||||
|
||||
|
||||
def _prepare_sampling_wrapper(executor, model, noise_shape: torch.Tensor, *args, **kwargs):
|
||||
# limit noise_shape length to context_length for more accurate vram use estimation
|
||||
model_options = kwargs.get("model_options", None)
|
||||
if model_options is None:
|
||||
raise Exception("model_options not found in prepare_sampling_wrapper; this should never happen, something went wrong.")
|
||||
handler: IndexListContextHandler = model_options.get("context_handler", None)
|
||||
if handler is not None:
|
||||
noise_shape = list(noise_shape)
|
||||
noise_shape[handler.dim] = min(noise_shape[handler.dim], handler.context_length)
|
||||
return executor(model, noise_shape, *args, **kwargs)
|
||||
|
||||
|
||||
def create_prepare_sampling_wrapper(model: ModelPatcher):
|
||||
model.add_wrapper_with_key(
|
||||
comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING,
|
||||
"ContextWindows_prepare_sampling",
|
||||
_prepare_sampling_wrapper
|
||||
)
|
||||
|
||||
|
||||
def match_weights_to_dim(weights: list[float], x_in: torch.Tensor, dim: int, device=None) -> torch.Tensor:
|
||||
total_dims = len(x_in.shape)
|
||||
weights_tensor = torch.Tensor(weights).to(device=device)
|
||||
for _ in range(dim):
|
||||
weights_tensor = weights_tensor.unsqueeze(0)
|
||||
for _ in range(total_dims - dim - 1):
|
||||
weights_tensor = weights_tensor.unsqueeze(-1)
|
||||
return weights_tensor
|
||||
|
||||
def get_shape_for_dim(x_in: torch.Tensor, dim: int) -> list[int]:
|
||||
total_dims = len(x_in.shape)
|
||||
shape = []
|
||||
for _ in range(dim):
|
||||
shape.append(1)
|
||||
shape.append(x_in.shape[dim])
|
||||
for _ in range(total_dims - dim - 1):
|
||||
shape.append(1)
|
||||
return shape
|
||||
|
||||
class ContextSchedules:
|
||||
UNIFORM_LOOPED = "looped_uniform"
|
||||
UNIFORM_STANDARD = "standard_uniform"
|
||||
STATIC_STANDARD = "standard_static"
|
||||
BATCHED = "batched"
|
||||
|
||||
|
||||
# from https://github.com/neggles/animatediff-cli/blob/main/src/animatediff/pipelines/context.py
|
||||
def create_windows_uniform_looped(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]):
|
||||
windows = []
|
||||
if num_frames < handler.context_length:
|
||||
windows.append(list(range(num_frames)))
|
||||
return windows
|
||||
|
||||
context_stride = min(handler.context_stride, int(np.ceil(np.log2(num_frames / handler.context_length))) + 1)
|
||||
# obtain uniform windows as normal, looping and all
|
||||
for context_step in 1 << np.arange(context_stride):
|
||||
pad = int(round(num_frames * ordered_halving(handler._step)))
|
||||
for j in range(
|
||||
int(ordered_halving(handler._step) * context_step) + pad,
|
||||
num_frames + pad + (0 if handler.closed_loop else -handler.context_overlap),
|
||||
(handler.context_length * context_step - handler.context_overlap),
|
||||
):
|
||||
windows.append([e % num_frames for e in range(j, j + handler.context_length * context_step, context_step)])
|
||||
|
||||
return windows
|
||||
|
||||
def create_windows_uniform_standard(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]):
|
||||
# unlike looped, uniform_straight does NOT allow windows that loop back to the beginning;
|
||||
# instead, they get shifted to the corresponding end of the frames.
|
||||
# in the case that a window (shifted or not) is identical to the previous one, it gets skipped.
|
||||
windows = []
|
||||
if num_frames <= handler.context_length:
|
||||
windows.append(list(range(num_frames)))
|
||||
return windows
|
||||
|
||||
context_stride = min(handler.context_stride, int(np.ceil(np.log2(num_frames / handler.context_length))) + 1)
|
||||
# first, obtain uniform windows as normal, looping and all
|
||||
for context_step in 1 << np.arange(context_stride):
|
||||
pad = int(round(num_frames * ordered_halving(handler._step)))
|
||||
for j in range(
|
||||
int(ordered_halving(handler._step) * context_step) + pad,
|
||||
num_frames + pad + (-handler.context_overlap),
|
||||
(handler.context_length * context_step - handler.context_overlap),
|
||||
):
|
||||
windows.append([e % num_frames for e in range(j, j + handler.context_length * context_step, context_step)])
|
||||
|
||||
# now that windows are created, shift any windows that loop, and delete duplicate windows
|
||||
delete_idxs = []
|
||||
win_i = 0
|
||||
while win_i < len(windows):
|
||||
# if window is rolls over itself, need to shift it
|
||||
is_roll, roll_idx = does_window_roll_over(windows[win_i], num_frames)
|
||||
if is_roll:
|
||||
roll_val = windows[win_i][roll_idx] # roll_val might not be 0 for windows of higher strides
|
||||
shift_window_to_end(windows[win_i], num_frames=num_frames)
|
||||
# check if next window (cyclical) is missing roll_val
|
||||
if roll_val not in windows[(win_i+1) % len(windows)]:
|
||||
# need to insert new window here - just insert window starting at roll_val
|
||||
windows.insert(win_i+1, list(range(roll_val, roll_val + handler.context_length)))
|
||||
# delete window if it's not unique
|
||||
for pre_i in range(0, win_i):
|
||||
if windows[win_i] == windows[pre_i]:
|
||||
delete_idxs.append(win_i)
|
||||
break
|
||||
win_i += 1
|
||||
|
||||
# reverse delete_idxs so that they will be deleted in an order that doesn't break idx correlation
|
||||
delete_idxs.reverse()
|
||||
for i in delete_idxs:
|
||||
windows.pop(i)
|
||||
|
||||
return windows
|
||||
|
||||
|
||||
def create_windows_static_standard(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]):
|
||||
windows = []
|
||||
if num_frames <= handler.context_length:
|
||||
windows.append(list(range(num_frames)))
|
||||
return windows
|
||||
# always return the same set of windows
|
||||
delta = handler.context_length - handler.context_overlap
|
||||
for start_idx in range(0, num_frames, delta):
|
||||
# if past the end of frames, move start_idx back to allow same context_length
|
||||
ending = start_idx + handler.context_length
|
||||
if ending >= num_frames:
|
||||
final_delta = ending - num_frames
|
||||
final_start_idx = start_idx - final_delta
|
||||
windows.append(list(range(final_start_idx, final_start_idx + handler.context_length)))
|
||||
break
|
||||
windows.append(list(range(start_idx, start_idx + handler.context_length)))
|
||||
return windows
|
||||
|
||||
|
||||
def create_windows_batched(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]):
|
||||
windows = []
|
||||
if num_frames <= handler.context_length:
|
||||
windows.append(list(range(num_frames)))
|
||||
return windows
|
||||
# always return the same set of windows;
|
||||
# no overlap, just cut up based on context_length;
|
||||
# last window size will be different if num_frames % opts.context_length != 0
|
||||
for start_idx in range(0, num_frames, handler.context_length):
|
||||
windows.append(list(range(start_idx, min(start_idx + handler.context_length, num_frames))))
|
||||
return windows
|
||||
|
||||
|
||||
def create_windows_default(num_frames: int, handler: IndexListContextHandler):
|
||||
return [list(range(num_frames))]
|
||||
|
||||
|
||||
CONTEXT_MAPPING = {
|
||||
ContextSchedules.UNIFORM_LOOPED: create_windows_uniform_looped,
|
||||
ContextSchedules.UNIFORM_STANDARD: create_windows_uniform_standard,
|
||||
ContextSchedules.STATIC_STANDARD: create_windows_static_standard,
|
||||
ContextSchedules.BATCHED: create_windows_batched,
|
||||
}
|
||||
|
||||
|
||||
def get_matching_context_schedule(context_schedule: str) -> ContextSchedule:
|
||||
func = CONTEXT_MAPPING.get(context_schedule, None)
|
||||
if func is None:
|
||||
raise ValueError(f"Unknown context_schedule '{context_schedule}'.")
|
||||
return ContextSchedule(context_schedule, func)
|
||||
|
||||
|
||||
def get_context_weights(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, sigma: torch.Tensor=None):
|
||||
return handler.fuse_method.func(length, sigma=sigma, handler=handler, full_length=full_length, idxs=idxs)
|
||||
|
||||
|
||||
def create_weights_flat(length: int, **kwargs) -> list[float]:
|
||||
# weight is the same for all
|
||||
return [1.0] * length
|
||||
|
||||
def create_weights_pyramid(length: int, **kwargs) -> list[float]:
|
||||
# weight is based on the distance away from the edge of the context window;
|
||||
# based on weighted average concept in FreeNoise paper
|
||||
if length % 2 == 0:
|
||||
max_weight = length // 2
|
||||
weight_sequence = list(range(1, max_weight + 1, 1)) + list(range(max_weight, 0, -1))
|
||||
else:
|
||||
max_weight = (length + 1) // 2
|
||||
weight_sequence = list(range(1, max_weight, 1)) + [max_weight] + list(range(max_weight - 1, 0, -1))
|
||||
return weight_sequence
|
||||
|
||||
def create_weights_overlap_linear(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, **kwargs):
|
||||
# based on code in Kijai's WanVideoWrapper: https://github.com/kijai/ComfyUI-WanVideoWrapper/blob/dbb2523b37e4ccdf45127e5ae33e31362f755c8e/nodes.py#L1302
|
||||
# only expected overlap is given different weights
|
||||
weights_torch = torch.ones((length))
|
||||
# blend left-side on all except first window
|
||||
if min(idxs) > 0:
|
||||
ramp_up = torch.linspace(1e-37, 1, handler.context_overlap)
|
||||
weights_torch[:handler.context_overlap] = ramp_up
|
||||
# blend right-side on all except last window
|
||||
if max(idxs) < full_length-1:
|
||||
ramp_down = torch.linspace(1, 1e-37, handler.context_overlap)
|
||||
weights_torch[-handler.context_overlap:] = ramp_down
|
||||
return weights_torch
|
||||
|
||||
class ContextFuseMethods:
|
||||
FLAT = "flat"
|
||||
PYRAMID = "pyramid"
|
||||
RELATIVE = "relative"
|
||||
OVERLAP_LINEAR = "overlap-linear"
|
||||
|
||||
LIST = [PYRAMID, FLAT, OVERLAP_LINEAR]
|
||||
LIST_STATIC = [PYRAMID, RELATIVE, FLAT, OVERLAP_LINEAR]
|
||||
|
||||
|
||||
FUSE_MAPPING = {
|
||||
ContextFuseMethods.FLAT: create_weights_flat,
|
||||
ContextFuseMethods.PYRAMID: create_weights_pyramid,
|
||||
ContextFuseMethods.RELATIVE: create_weights_pyramid,
|
||||
ContextFuseMethods.OVERLAP_LINEAR: create_weights_overlap_linear,
|
||||
}
|
||||
|
||||
def get_matching_fuse_method(fuse_method: str) -> ContextFuseMethod:
|
||||
func = FUSE_MAPPING.get(fuse_method, None)
|
||||
if func is None:
|
||||
raise ValueError(f"Unknown fuse_method '{fuse_method}'.")
|
||||
return ContextFuseMethod(fuse_method, func)
|
||||
|
||||
# Returns fraction that has denominator that is a power of 2
|
||||
def ordered_halving(val):
|
||||
# get binary value, padded with 0s for 64 bits
|
||||
bin_str = f"{val:064b}"
|
||||
# flip binary value, padding included
|
||||
bin_flip = bin_str[::-1]
|
||||
# convert binary to int
|
||||
as_int = int(bin_flip, 2)
|
||||
# divide by 1 << 64, equivalent to 2**64, or 18446744073709551616,
|
||||
# or b10000000000000000000000000000000000000000000000000000000000000000 (1 with 64 zero's)
|
||||
return as_int / (1 << 64)
|
||||
|
||||
|
||||
def get_missing_indexes(windows: list[list[int]], num_frames: int) -> list[int]:
|
||||
all_indexes = list(range(num_frames))
|
||||
for w in windows:
|
||||
for val in w:
|
||||
try:
|
||||
all_indexes.remove(val)
|
||||
except ValueError:
|
||||
pass
|
||||
return all_indexes
|
||||
|
||||
|
||||
def does_window_roll_over(window: list[int], num_frames: int) -> tuple[bool, int]:
|
||||
prev_val = -1
|
||||
for i, val in enumerate(window):
|
||||
val = val % num_frames
|
||||
if val < prev_val:
|
||||
return True, i
|
||||
prev_val = val
|
||||
return False, -1
|
||||
|
||||
|
||||
def shift_window_to_start(window: list[int], num_frames: int):
|
||||
start_val = window[0]
|
||||
for i in range(len(window)):
|
||||
# 1) subtract each element by start_val to move vals relative to the start of all frames
|
||||
# 2) add num_frames and take modulus to get adjusted vals
|
||||
window[i] = ((window[i] - start_val) + num_frames) % num_frames
|
||||
|
||||
|
||||
def shift_window_to_end(window: list[int], num_frames: int):
|
||||
# 1) shift window to start
|
||||
shift_window_to_start(window, num_frames)
|
||||
end_val = window[-1]
|
||||
end_delta = num_frames - end_val - 1
|
||||
for i in range(len(window)):
|
||||
# 2) add end_delta to each val to slide windows to end
|
||||
window[i] = window[i] + end_delta
|
||||
+26
-3
@@ -36,6 +36,7 @@ import comfy.ldm.cascade.controlnet
|
||||
import comfy.cldm.mmdit
|
||||
import comfy.ldm.hydit.controlnet
|
||||
import comfy.ldm.flux.controlnet
|
||||
import comfy.ldm.qwen_image.controlnet
|
||||
import comfy.cldm.dit_embedder
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
@@ -236,11 +237,11 @@ class ControlNet(ControlBase):
|
||||
self.cond_hint = None
|
||||
compression_ratio = self.compression_ratio
|
||||
if self.vae is not None:
|
||||
compression_ratio *= self.vae.downscale_ratio
|
||||
compression_ratio *= self.vae.spacial_compression_encode()
|
||||
else:
|
||||
if self.latent_format is not None:
|
||||
raise ValueError("This Controlnet needs a VAE but none was provided, please use a ControlNetApply node with a VAE input and connect it.")
|
||||
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
|
||||
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[-1] * compression_ratio, x_noisy.shape[-2] * compression_ratio, self.upscale_algorithm, "center")
|
||||
self.cond_hint = self.preprocess_image(self.cond_hint)
|
||||
if self.vae is not None:
|
||||
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
|
||||
@@ -252,7 +253,10 @@ class ControlNet(ControlBase):
|
||||
to_concat = []
|
||||
for c in self.extra_concat_orig:
|
||||
c = c.to(self.cond_hint.device)
|
||||
c = comfy.utils.common_upscale(c, self.cond_hint.shape[3], self.cond_hint.shape[2], self.upscale_algorithm, "center")
|
||||
c = comfy.utils.common_upscale(c, self.cond_hint.shape[-1], self.cond_hint.shape[-2], self.upscale_algorithm, "center")
|
||||
if c.ndim < self.cond_hint.ndim:
|
||||
c = c.unsqueeze(2)
|
||||
c = comfy.utils.repeat_to_batch_size(c, self.cond_hint.shape[2], dim=2)
|
||||
to_concat.append(comfy.utils.repeat_to_batch_size(c, self.cond_hint.shape[0]))
|
||||
self.cond_hint = torch.cat([self.cond_hint] + to_concat, dim=1)
|
||||
|
||||
@@ -582,6 +586,22 @@ def load_controlnet_flux_instantx(sd, model_options={}):
|
||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||
return control
|
||||
|
||||
def load_controlnet_qwen_instantx(sd, model_options={}):
|
||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
|
||||
control_latent_channels = sd.get("controlnet_x_embedder.weight").shape[1]
|
||||
|
||||
extra_condition_channels = 0
|
||||
concat_mask = False
|
||||
if control_latent_channels == 68: #inpaint controlnet
|
||||
extra_condition_channels = control_latent_channels - 64
|
||||
concat_mask = True
|
||||
control_model = comfy.ldm.qwen_image.controlnet.QwenImageControlNetModel(extra_condition_channels=extra_condition_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
||||
control_model = controlnet_load_state_dict(control_model, sd)
|
||||
latent_format = comfy.latent_formats.Wan21()
|
||||
extra_conds = []
|
||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||
return control
|
||||
|
||||
def convert_mistoline(sd):
|
||||
return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
|
||||
|
||||
@@ -655,8 +675,11 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
|
||||
return load_controlnet_sd35(controlnet_data, model_options=model_options) #Stability sd3.5 format
|
||||
else:
|
||||
return load_controlnet_mmdit(controlnet_data, model_options=model_options) #SD3 diffusers controlnet
|
||||
elif "transformer_blocks.0.img_mlp.net.0.proj.weight" in controlnet_data:
|
||||
return load_controlnet_qwen_instantx(controlnet_data, model_options=model_options)
|
||||
elif "controlnet_x_embedder.weight" in controlnet_data:
|
||||
return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
|
||||
|
||||
elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
|
||||
return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options)
|
||||
|
||||
|
||||
@@ -31,6 +31,20 @@ class LayerScale(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return x * comfy.model_management.cast_to_device(self.lambda1, x.device, x.dtype)
|
||||
|
||||
class Dinov2MLP(torch.nn.Module):
|
||||
def __init__(self, hidden_size: int, dtype, device, operations):
|
||||
super().__init__()
|
||||
|
||||
mlp_ratio = 4
|
||||
hidden_features = int(hidden_size * mlp_ratio)
|
||||
self.fc1 = operations.Linear(hidden_size, hidden_features, bias = True, device=device, dtype=dtype)
|
||||
self.fc2 = operations.Linear(hidden_features, hidden_size, bias = True, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
|
||||
hidden_state = self.fc1(hidden_state)
|
||||
hidden_state = torch.nn.functional.gelu(hidden_state)
|
||||
hidden_state = self.fc2(hidden_state)
|
||||
return hidden_state
|
||||
|
||||
class SwiGLUFFN(torch.nn.Module):
|
||||
def __init__(self, dim, dtype, device, operations):
|
||||
@@ -50,12 +64,15 @@ class SwiGLUFFN(torch.nn.Module):
|
||||
|
||||
|
||||
class Dino2Block(torch.nn.Module):
|
||||
def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations):
|
||||
def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn):
|
||||
super().__init__()
|
||||
self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations)
|
||||
self.layer_scale1 = LayerScale(dim, dtype, device, operations)
|
||||
self.layer_scale2 = LayerScale(dim, dtype, device, operations)
|
||||
self.mlp = SwiGLUFFN(dim, dtype, device, operations)
|
||||
if use_swiglu_ffn:
|
||||
self.mlp = SwiGLUFFN(dim, dtype, device, operations)
|
||||
else:
|
||||
self.mlp = Dinov2MLP(dim, dtype, device, operations)
|
||||
self.norm1 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
|
||||
self.norm2 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
|
||||
|
||||
@@ -66,9 +83,10 @@ class Dino2Block(torch.nn.Module):
|
||||
|
||||
|
||||
class Dino2Encoder(torch.nn.Module):
|
||||
def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations):
|
||||
def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn):
|
||||
super().__init__()
|
||||
self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations) for _ in range(num_layers)])
|
||||
self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn)
|
||||
for _ in range(num_layers)])
|
||||
|
||||
def forward(self, x, intermediate_output=None):
|
||||
optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
|
||||
@@ -78,8 +96,8 @@ class Dino2Encoder(torch.nn.Module):
|
||||
intermediate_output = len(self.layer) + intermediate_output
|
||||
|
||||
intermediate = None
|
||||
for i, l in enumerate(self.layer):
|
||||
x = l(x, optimized_attention)
|
||||
for i, layer in enumerate(self.layer):
|
||||
x = layer(x, optimized_attention)
|
||||
if i == intermediate_output:
|
||||
intermediate = x.clone()
|
||||
return x, intermediate
|
||||
@@ -128,9 +146,10 @@ class Dinov2Model(torch.nn.Module):
|
||||
dim = config_dict["hidden_size"]
|
||||
heads = config_dict["num_attention_heads"]
|
||||
layer_norm_eps = config_dict["layer_norm_eps"]
|
||||
use_swiglu_ffn = config_dict["use_swiglu_ffn"]
|
||||
|
||||
self.embeddings = Dino2Embeddings(dim, dtype, device, operations)
|
||||
self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations)
|
||||
self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn)
|
||||
self.layernorm = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
{
|
||||
"hidden_size": 1024,
|
||||
"use_mask_token": true,
|
||||
"patch_size": 14,
|
||||
"image_size": 518,
|
||||
"num_channels": 3,
|
||||
"num_attention_heads": 16,
|
||||
"initializer_range": 0.02,
|
||||
"attention_probs_dropout_prob": 0.0,
|
||||
"hidden_dropout_prob": 0.0,
|
||||
"hidden_act": "gelu",
|
||||
"mlp_ratio": 4,
|
||||
"model_type": "dinov2",
|
||||
"num_hidden_layers": 24,
|
||||
"layer_norm_eps": 1e-6,
|
||||
"qkv_bias": true,
|
||||
"use_swiglu_ffn": false,
|
||||
"layerscale_value": 1.0,
|
||||
"drop_path_rate": 0.0,
|
||||
"image_mean": [0.485, 0.456, 0.406],
|
||||
"image_std": [0.229, 0.224, 0.225]
|
||||
}
|
||||
@@ -171,6 +171,16 @@ def offset_first_sigma_for_snr(sigmas, model_sampling, percent_offset=1e-4):
|
||||
return sigmas
|
||||
|
||||
|
||||
def ei_h_phi_1(h: torch.Tensor) -> torch.Tensor:
|
||||
"""Compute the result of h*phi_1(h) in exponential integrator methods."""
|
||||
return torch.expm1(h)
|
||||
|
||||
|
||||
def ei_h_phi_2(h: torch.Tensor) -> torch.Tensor:
|
||||
"""Compute the result of h*phi_2(h) in exponential integrator methods."""
|
||||
return (torch.expm1(h) - h) / h
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
|
||||
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
|
||||
@@ -853,6 +863,11 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_2m_sde_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='heun'):
|
||||
return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
"""DPM-Solver++(3M) SDE."""
|
||||
@@ -925,6 +940,16 @@ def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
|
||||
return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_2m_sde_heun_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='heun'):
|
||||
if len(sigmas) <= 1:
|
||||
return x
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
|
||||
return sample_dpmpp_2m_sde_heun(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
|
||||
if len(sigmas) <= 1:
|
||||
@@ -1535,13 +1560,12 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
|
||||
@torch.no_grad()
|
||||
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
|
||||
"""SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2.
|
||||
arXiv: https://arxiv.org/abs/2305.14267
|
||||
arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)
|
||||
"""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
inject_noise = eta > 0 and s_noise > 0
|
||||
|
||||
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
|
||||
@@ -1549,55 +1573,53 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
|
||||
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
|
||||
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
|
||||
|
||||
fac = 1 / (2 * r)
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
|
||||
if sigmas[i + 1] == 0:
|
||||
x = denoised
|
||||
else:
|
||||
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
|
||||
h = lambda_t - lambda_s
|
||||
h_eta = h * (eta + 1)
|
||||
lambda_s_1 = lambda_s + r * h
|
||||
fac = 1 / (2 * r)
|
||||
sigma_s_1 = sigma_fn(lambda_s_1)
|
||||
continue
|
||||
|
||||
# alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t)
|
||||
alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
|
||||
alpha_t = sigmas[i + 1] * lambda_t.exp()
|
||||
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
|
||||
h = lambda_t - lambda_s
|
||||
h_eta = h * (eta + 1)
|
||||
lambda_s_1 = torch.lerp(lambda_s, lambda_t, r)
|
||||
sigma_s_1 = sigma_fn(lambda_s_1)
|
||||
|
||||
coeff_1, coeff_2 = (-r * h_eta).expm1(), (-h_eta).expm1()
|
||||
if inject_noise:
|
||||
# 0 < r < 1
|
||||
noise_coeff_1 = (-2 * r * h * eta).expm1().neg().sqrt()
|
||||
noise_coeff_2 = (-r * h * eta).exp() * (-2 * (1 - r) * h * eta).expm1().neg().sqrt()
|
||||
noise_1, noise_2 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigmas[i + 1])
|
||||
alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
|
||||
alpha_t = sigmas[i + 1] * lambda_t.exp()
|
||||
|
||||
# Step 1
|
||||
x_2 = sigma_s_1 / sigmas[i] * (-r * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised
|
||||
if inject_noise:
|
||||
x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise
|
||||
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
|
||||
# Step 1
|
||||
x_2 = sigma_s_1 / sigmas[i] * (-r * h * eta).exp() * x - alpha_s_1 * ei_h_phi_1(-r * h_eta) * denoised
|
||||
if inject_noise:
|
||||
sde_noise = (-2 * r * h * eta).expm1().neg().sqrt() * noise_sampler(sigmas[i], sigma_s_1)
|
||||
x_2 = x_2 + sde_noise * sigma_s_1 * s_noise
|
||||
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
|
||||
|
||||
# Step 2
|
||||
denoised_d = (1 - fac) * denoised + fac * denoised_2
|
||||
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_2 * denoised_d
|
||||
if inject_noise:
|
||||
x = x + sigmas[i + 1] * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
|
||||
# Step 2
|
||||
denoised_d = torch.lerp(denoised, denoised_2, fac)
|
||||
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
|
||||
if inject_noise:
|
||||
segment_factor = (r - 1) * h * eta
|
||||
sde_noise = sde_noise * segment_factor.exp()
|
||||
sde_noise = sde_noise + segment_factor.mul(2).expm1().neg().sqrt() * noise_sampler(sigma_s_1, sigmas[i + 1])
|
||||
x = x + sde_noise * sigmas[i + 1] * s_noise
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3):
|
||||
"""SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 3.
|
||||
arXiv: https://arxiv.org/abs/2305.14267
|
||||
arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)
|
||||
"""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
inject_noise = eta > 0 and s_noise > 0
|
||||
|
||||
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
|
||||
@@ -1609,45 +1631,49 @@ def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=Non
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
|
||||
if sigmas[i + 1] == 0:
|
||||
x = denoised
|
||||
else:
|
||||
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
|
||||
h = lambda_t - lambda_s
|
||||
h_eta = h * (eta + 1)
|
||||
lambda_s_1 = lambda_s + r_1 * h
|
||||
lambda_s_2 = lambda_s + r_2 * h
|
||||
sigma_s_1, sigma_s_2 = sigma_fn(lambda_s_1), sigma_fn(lambda_s_2)
|
||||
continue
|
||||
|
||||
# alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t)
|
||||
alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
|
||||
alpha_s_2 = sigma_s_2 * lambda_s_2.exp()
|
||||
alpha_t = sigmas[i + 1] * lambda_t.exp()
|
||||
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
|
||||
h = lambda_t - lambda_s
|
||||
h_eta = h * (eta + 1)
|
||||
lambda_s_1 = torch.lerp(lambda_s, lambda_t, r_1)
|
||||
lambda_s_2 = torch.lerp(lambda_s, lambda_t, r_2)
|
||||
sigma_s_1, sigma_s_2 = sigma_fn(lambda_s_1), sigma_fn(lambda_s_2)
|
||||
|
||||
coeff_1, coeff_2, coeff_3 = (-r_1 * h_eta).expm1(), (-r_2 * h_eta).expm1(), (-h_eta).expm1()
|
||||
if inject_noise:
|
||||
# 0 < r_1 < r_2 < 1
|
||||
noise_coeff_1 = (-2 * r_1 * h * eta).expm1().neg().sqrt()
|
||||
noise_coeff_2 = (-r_1 * h * eta).exp() * (-2 * (r_2 - r_1) * h * eta).expm1().neg().sqrt()
|
||||
noise_coeff_3 = (-r_2 * h * eta).exp() * (-2 * (1 - r_2) * h * eta).expm1().neg().sqrt()
|
||||
noise_1, noise_2, noise_3 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigma_s_2), noise_sampler(sigma_s_2, sigmas[i + 1])
|
||||
alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
|
||||
alpha_s_2 = sigma_s_2 * lambda_s_2.exp()
|
||||
alpha_t = sigmas[i + 1] * lambda_t.exp()
|
||||
|
||||
# Step 1
|
||||
x_2 = sigma_s_1 / sigmas[i] * (-r_1 * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised
|
||||
if inject_noise:
|
||||
x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise
|
||||
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
|
||||
# Step 1
|
||||
x_2 = sigma_s_1 / sigmas[i] * (-r_1 * h * eta).exp() * x - alpha_s_1 * ei_h_phi_1(-r_1 * h_eta) * denoised
|
||||
if inject_noise:
|
||||
sde_noise = (-2 * r_1 * h * eta).expm1().neg().sqrt() * noise_sampler(sigmas[i], sigma_s_1)
|
||||
x_2 = x_2 + sde_noise * sigma_s_1 * s_noise
|
||||
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
|
||||
|
||||
# Step 2
|
||||
x_3 = sigma_s_2 / sigmas[i] * (-r_2 * h * eta).exp() * x - alpha_s_2 * coeff_2 * denoised + (r_2 / r_1) * alpha_s_2 * (coeff_2 / (r_2 * h_eta) + 1) * (denoised_2 - denoised)
|
||||
if inject_noise:
|
||||
x_3 = x_3 + sigma_s_2 * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
|
||||
denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args)
|
||||
# Step 2
|
||||
a3_2 = r_2 / r_1 * ei_h_phi_2(-r_2 * h_eta)
|
||||
a3_1 = ei_h_phi_1(-r_2 * h_eta) - a3_2
|
||||
x_3 = sigma_s_2 / sigmas[i] * (-r_2 * h * eta).exp() * x - alpha_s_2 * (a3_1 * denoised + a3_2 * denoised_2)
|
||||
if inject_noise:
|
||||
segment_factor = (r_1 - r_2) * h * eta
|
||||
sde_noise = sde_noise * segment_factor.exp()
|
||||
sde_noise = sde_noise + segment_factor.mul(2).expm1().neg().sqrt() * noise_sampler(sigma_s_1, sigma_s_2)
|
||||
x_3 = x_3 + sde_noise * sigma_s_2 * s_noise
|
||||
denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args)
|
||||
|
||||
# Step 3
|
||||
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_3 * denoised + (1. / r_2) * alpha_t * (coeff_3 / h_eta + 1) * (denoised_3 - denoised)
|
||||
if inject_noise:
|
||||
x = x + sigmas[i + 1] * (noise_coeff_3 * noise_1 + noise_coeff_2 * noise_2 + noise_coeff_1 * noise_3) * s_noise
|
||||
# Step 3
|
||||
b3 = ei_h_phi_2(-h_eta) / r_2
|
||||
b1 = ei_h_phi_1(-h_eta) - b3
|
||||
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * (b1 * denoised + b3 * denoised_3)
|
||||
if inject_noise:
|
||||
segment_factor = (r_2 - 1) * h * eta
|
||||
sde_noise = sde_noise * segment_factor.exp()
|
||||
sde_noise = sde_noise + segment_factor.mul(2).expm1().neg().sqrt() * noise_sampler(sigma_s_2, sigmas[i + 1])
|
||||
x = x + sde_noise * sigmas[i + 1] * s_noise
|
||||
return x
|
||||
|
||||
|
||||
|
||||
@@ -533,11 +533,21 @@ class Wan22(Wan21):
|
||||
0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744
|
||||
]).view(1, self.latent_channels, 1, 1, 1)
|
||||
|
||||
class HunyuanImage21(LatentFormat):
|
||||
latent_channels = 64
|
||||
latent_dimensions = 2
|
||||
scale_factor = 0.75289
|
||||
|
||||
class Hunyuan3Dv2(LatentFormat):
|
||||
latent_channels = 64
|
||||
latent_dimensions = 1
|
||||
scale_factor = 0.9990943042622529
|
||||
|
||||
class Hunyuan3Dv2_1(LatentFormat):
|
||||
scale_factor = 1.0039506158752403
|
||||
latent_channels = 64
|
||||
latent_dimensions = 1
|
||||
|
||||
class Hunyuan3Dv2mini(LatentFormat):
|
||||
latent_channels = 64
|
||||
latent_dimensions = 1
|
||||
|
||||
+23
-1
@@ -19,6 +19,7 @@ import torch
|
||||
from torch import nn
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
|
||||
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
|
||||
from .attention import LinearTransformerBlock, t2i_modulate
|
||||
@@ -343,7 +344,28 @@ class ACEStepTransformer2DModel(nn.Module):
|
||||
output = self.final_layer(hidden_states, embedded_timestep, output_length)
|
||||
return output
|
||||
|
||||
def forward(
|
||||
def forward(self,
|
||||
x,
|
||||
timestep,
|
||||
attention_mask=None,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
text_attention_mask: Optional[torch.LongTensor] = None,
|
||||
speaker_embeds: Optional[torch.FloatTensor] = None,
|
||||
lyric_token_idx: Optional[torch.LongTensor] = None,
|
||||
lyric_mask: Optional[torch.LongTensor] = None,
|
||||
block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
|
||||
controlnet_scale: Union[float, torch.Tensor] = 1.0,
|
||||
lyrics_strength=1.0,
|
||||
**kwargs
|
||||
):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
|
||||
).execute(x, timestep, attention_mask, context, text_attention_mask, speaker_embeds, lyric_token_idx, lyric_mask, block_controlnet_hidden_states,
|
||||
controlnet_scale, lyrics_strength, **kwargs)
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
x,
|
||||
timestep,
|
||||
|
||||
@@ -632,7 +632,7 @@ class ContinuousTransformer(nn.Module):
|
||||
# Attention layers
|
||||
|
||||
if self.rotary_pos_emb is not None:
|
||||
rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1], dtype=x.dtype, device=x.device)
|
||||
rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1], dtype=torch.float, device=x.device)
|
||||
else:
|
||||
rotary_pos_emb = None
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ import torch.nn.functional as F
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
import comfy.ops
|
||||
import comfy.patcher_extension
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
def modulate(x, shift, scale):
|
||||
@@ -436,6 +437,13 @@ class MMDiT(nn.Module):
|
||||
return x + pos_encoding.reshape(1, -1, self.positional_encoding.shape[-1])
|
||||
|
||||
def forward(self, x, timestep, context, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, transformer_options, **kwargs)
|
||||
|
||||
def _forward(self, x, timestep, context, transformer_options={}, **kwargs):
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
# patchify x, add PE
|
||||
b, c, h, w = x.shape
|
||||
|
||||
@@ -5,6 +5,7 @@ from dataclasses import dataclass
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
from einops import rearrange, repeat
|
||||
import comfy.patcher_extension
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
from comfy.ldm.flux.layers import (
|
||||
@@ -253,6 +254,13 @@ class Chroma(nn.Module):
|
||||
return img
|
||||
|
||||
def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, guidance, control, transformer_options, **kwargs)
|
||||
|
||||
def _forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
|
||||
bs, c, h, w = x.shape
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
|
||||
|
||||
|
||||
@@ -27,6 +27,8 @@ from torchvision import transforms
|
||||
from enum import Enum
|
||||
import logging
|
||||
|
||||
import comfy.patcher_extension
|
||||
|
||||
from .blocks import (
|
||||
FinalLayer,
|
||||
GeneralDITTransformerBlock,
|
||||
@@ -435,6 +437,42 @@ class GeneralDIT(nn.Module):
|
||||
latent_condition_sigma: Optional[torch.Tensor] = None,
|
||||
condition_video_augment_sigma: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
|
||||
).execute(x,
|
||||
timesteps,
|
||||
context,
|
||||
attention_mask,
|
||||
fps,
|
||||
image_size,
|
||||
padding_mask,
|
||||
scalar_feature,
|
||||
data_type,
|
||||
latent_condition,
|
||||
latent_condition_sigma,
|
||||
condition_video_augment_sigma,
|
||||
**kwargs)
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
timesteps: torch.Tensor,
|
||||
context: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
# crossattn_emb: torch.Tensor,
|
||||
# crossattn_mask: Optional[torch.Tensor] = None,
|
||||
fps: Optional[torch.Tensor] = None,
|
||||
image_size: Optional[torch.Tensor] = None,
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
scalar_feature: Optional[torch.Tensor] = None,
|
||||
data_type: Optional[DataType] = DataType.VIDEO,
|
||||
latent_condition: Optional[torch.Tensor] = None,
|
||||
latent_condition_sigma: Optional[torch.Tensor] = None,
|
||||
condition_video_augment_sigma: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
|
||||
@@ -11,6 +11,7 @@ import math
|
||||
from .position_embedding import VideoRopePosition3DEmb, LearnablePosEmbAxis
|
||||
from torchvision import transforms
|
||||
|
||||
import comfy.patcher_extension
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
def apply_rotary_pos_emb(
|
||||
@@ -805,7 +806,21 @@ class MiniTrainDIT(nn.Module):
|
||||
)
|
||||
return x_B_C_Tt_Hp_Wp
|
||||
|
||||
def forward(
|
||||
def forward(self,
|
||||
x: torch.Tensor,
|
||||
timesteps: torch.Tensor,
|
||||
context: torch.Tensor,
|
||||
fps: Optional[torch.Tensor] = None,
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
|
||||
).execute(x, timesteps, context, fps, padding_mask, **kwargs)
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
timesteps: torch.Tensor,
|
||||
|
||||
+42
-11
@@ -6,6 +6,7 @@ import torch
|
||||
from torch import Tensor, nn
|
||||
from einops import rearrange, repeat
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.patcher_extension
|
||||
|
||||
from .layers import (
|
||||
DoubleStreamBlock,
|
||||
@@ -105,6 +106,7 @@ class Flux(nn.Module):
|
||||
if y is None:
|
||||
y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
|
||||
|
||||
patches = transformer_options.get("patches", {})
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||
@@ -116,9 +118,17 @@ class Flux(nn.Module):
|
||||
if guidance is not None:
|
||||
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
|
||||
|
||||
vec = vec + self.vector_in(y[:,:self.params.vec_in_dim])
|
||||
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
|
||||
txt = self.txt_in(txt)
|
||||
|
||||
if "post_input" in patches:
|
||||
for p in patches["post_input"]:
|
||||
out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids})
|
||||
img = out["img"]
|
||||
txt = out["txt"]
|
||||
img_ids = out["img_ids"]
|
||||
txt_ids = out["txt_ids"]
|
||||
|
||||
if img_ids is not None:
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
pe = self.pe_embedder(ids)
|
||||
@@ -157,7 +167,7 @@ class Flux(nn.Module):
|
||||
if i < len(control_i):
|
||||
add = control_i[i]
|
||||
if add is not None:
|
||||
img += add
|
||||
img[:, :add.shape[1]] += add
|
||||
|
||||
if img.dtype == torch.float16:
|
||||
img = torch.nan_to_num(img, nan=0.0, posinf=65504, neginf=-65504)
|
||||
@@ -188,7 +198,7 @@ class Flux(nn.Module):
|
||||
if i < len(control_o):
|
||||
add = control_o[i]
|
||||
if add is not None:
|
||||
img[:, txt.shape[1] :, ...] += add
|
||||
img[:, txt.shape[1] : txt.shape[1] + add.shape[1], ...] += add
|
||||
|
||||
img = img[:, txt.shape[1] :, ...]
|
||||
|
||||
@@ -214,6 +224,13 @@ class Flux(nn.Module):
|
||||
return img, repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, y, guidance, ref_latents, control, transformer_options, **kwargs)
|
||||
|
||||
def _forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
|
||||
bs, c, h_orig, w_orig = x.shape
|
||||
patch_size = self.patch_size
|
||||
|
||||
@@ -224,19 +241,33 @@ class Flux(nn.Module):
|
||||
if ref_latents is not None:
|
||||
h = 0
|
||||
w = 0
|
||||
index = 0
|
||||
ref_latents_method = kwargs.get("ref_latents_method", "offset")
|
||||
for ref in ref_latents:
|
||||
h_offset = 0
|
||||
w_offset = 0
|
||||
if ref.shape[-2] + h > ref.shape[-1] + w:
|
||||
w_offset = w
|
||||
if ref_latents_method == "index":
|
||||
index += 1
|
||||
h_offset = 0
|
||||
w_offset = 0
|
||||
elif ref_latents_method == "uxo":
|
||||
index = 0
|
||||
h_offset = h_len * patch_size + h
|
||||
w_offset = w_len * patch_size + w
|
||||
h += ref.shape[-2]
|
||||
w += ref.shape[-1]
|
||||
else:
|
||||
h_offset = h
|
||||
index = 1
|
||||
h_offset = 0
|
||||
w_offset = 0
|
||||
if ref.shape[-2] + h > ref.shape[-1] + w:
|
||||
w_offset = w
|
||||
else:
|
||||
h_offset = h
|
||||
h = max(h, ref.shape[-2] + h_offset)
|
||||
w = max(w, ref.shape[-1] + w_offset)
|
||||
|
||||
kontext, kontext_ids = self.process_img(ref, index=1, h_offset=h_offset, w_offset=w_offset)
|
||||
kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
|
||||
img = torch.cat([img, kontext], dim=1)
|
||||
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
|
||||
h = max(h, ref.shape[-2] + h_offset)
|
||||
w = max(w, ref.shape[-1] + w_offset)
|
||||
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
|
||||
|
||||
@@ -13,6 +13,7 @@ from comfy.ldm.flux.layers import LastLayer
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
|
||||
@@ -692,7 +693,23 @@ class HiDreamImageTransformer2DModel(nn.Module):
|
||||
raise NotImplementedError
|
||||
return x, x_masks, img_sizes
|
||||
|
||||
def forward(
|
||||
def forward(self,
|
||||
x: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
y: Optional[torch.Tensor] = None,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states_llama3=None,
|
||||
image_cond=None,
|
||||
control = None,
|
||||
transformer_options = {},
|
||||
):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, t, y, context, encoder_hidden_states_llama3, image_cond, control, transformer_options)
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
|
||||
@@ -7,6 +7,7 @@ from comfy.ldm.flux.layers import (
|
||||
SingleStreamBlock,
|
||||
timestep_embedding,
|
||||
)
|
||||
import comfy.patcher_extension
|
||||
|
||||
|
||||
class Hunyuan3Dv2(nn.Module):
|
||||
@@ -67,6 +68,13 @@ class Hunyuan3Dv2(nn.Module):
|
||||
self.final_layer = LastLayer(hidden_size, 1, in_channels, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward(self, x, timestep, context, guidance=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, guidance, transformer_options, **kwargs)
|
||||
|
||||
def _forward(self, x, timestep, context, guidance=None, transformer_options={}, **kwargs):
|
||||
x = x.movedim(-1, -2)
|
||||
timestep = 1.0 - timestep
|
||||
txt = context
|
||||
|
||||
+486
-85
@@ -4,81 +4,458 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
from typing import Union, Tuple, List, Callable, Optional
|
||||
|
||||
import numpy as np
|
||||
from einops import repeat, rearrange
|
||||
import math
|
||||
from tqdm import tqdm
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import logging
|
||||
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
def generate_dense_grid_points(
|
||||
bbox_min: np.ndarray,
|
||||
bbox_max: np.ndarray,
|
||||
octree_resolution: int,
|
||||
indexing: str = "ij",
|
||||
):
|
||||
length = bbox_max - bbox_min
|
||||
num_cells = octree_resolution
|
||||
def fps(src: torch.Tensor, batch: torch.Tensor, sampling_ratio: float, start_random: bool = True):
|
||||
|
||||
x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32)
|
||||
y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32)
|
||||
z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32)
|
||||
[xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing)
|
||||
xyz = np.stack((xs, ys, zs), axis=-1)
|
||||
grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1]
|
||||
# manually create the pointer vector
|
||||
assert src.size(0) == batch.numel()
|
||||
|
||||
return xyz, grid_size, length
|
||||
batch_size = int(batch.max()) + 1
|
||||
deg = src.new_zeros(batch_size, dtype = torch.long)
|
||||
|
||||
deg.scatter_add_(0, batch, torch.ones_like(batch))
|
||||
|
||||
ptr_vec = deg.new_zeros(batch_size + 1)
|
||||
torch.cumsum(deg, 0, out=ptr_vec[1:])
|
||||
|
||||
#return fps_sampling(src, ptr_vec, ratio)
|
||||
sampled_indicies = []
|
||||
|
||||
for b in range(batch_size):
|
||||
# start and the end of each batch
|
||||
start, end = ptr_vec[b].item(), ptr_vec[b + 1].item()
|
||||
# points from the point cloud
|
||||
points = src[start:end]
|
||||
|
||||
num_points = points.size(0)
|
||||
num_samples = max(1, math.ceil(num_points * sampling_ratio))
|
||||
|
||||
selected = torch.zeros(num_samples, device = src.device, dtype = torch.long)
|
||||
distances = torch.full((num_points,), float("inf"), device = src.device)
|
||||
|
||||
# select a random start point
|
||||
if start_random:
|
||||
farthest = torch.randint(0, num_points, (1,), device = src.device)
|
||||
else:
|
||||
farthest = torch.tensor([0], device = src.device, dtype = torch.long)
|
||||
|
||||
for i in range(num_samples):
|
||||
selected[i] = farthest
|
||||
centroid = points[farthest].squeeze(0)
|
||||
dist = torch.norm(points - centroid, dim = 1) # compute euclidean distance
|
||||
distances = torch.minimum(distances, dist)
|
||||
farthest = torch.argmax(distances)
|
||||
|
||||
sampled_indicies.append(torch.arange(start, end)[selected])
|
||||
|
||||
return torch.cat(sampled_indicies, dim = 0)
|
||||
class PointCrossAttention(nn.Module):
|
||||
def __init__(self,
|
||||
num_latents: int,
|
||||
downsample_ratio: float,
|
||||
pc_size: int,
|
||||
pc_sharpedge_size: int,
|
||||
point_feats: int,
|
||||
width: int,
|
||||
heads: int,
|
||||
layers: int,
|
||||
fourier_embedder,
|
||||
normal_pe: bool = False,
|
||||
qkv_bias: bool = False,
|
||||
use_ln_post: bool = True,
|
||||
qk_norm: bool = True):
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.fourier_embedder = fourier_embedder
|
||||
|
||||
self.pc_size = pc_size
|
||||
self.normal_pe = normal_pe
|
||||
self.downsample_ratio = downsample_ratio
|
||||
self.pc_sharpedge_size = pc_sharpedge_size
|
||||
self.num_latents = num_latents
|
||||
self.point_feats = point_feats
|
||||
|
||||
self.input_proj = nn.Linear(self.fourier_embedder.out_dim + point_feats, width)
|
||||
|
||||
self.cross_attn = ResidualCrossAttentionBlock(
|
||||
width = width,
|
||||
heads = heads,
|
||||
qkv_bias = qkv_bias,
|
||||
qk_norm = qk_norm
|
||||
)
|
||||
|
||||
self.self_attn = None
|
||||
if layers > 0:
|
||||
self.self_attn = Transformer(
|
||||
width = width,
|
||||
heads = heads,
|
||||
qkv_bias = qkv_bias,
|
||||
qk_norm = qk_norm,
|
||||
layers = layers
|
||||
)
|
||||
|
||||
if use_ln_post:
|
||||
self.ln_post = nn.LayerNorm(width)
|
||||
else:
|
||||
self.ln_post = None
|
||||
|
||||
def sample_points_and_latents(self, point_cloud: torch.Tensor, features: torch.Tensor):
|
||||
|
||||
"""
|
||||
Subsample points randomly from the point cloud (input_pc)
|
||||
Further sample the subsampled points to get query_pc
|
||||
take the fourier embeddings for both input and query pc
|
||||
|
||||
Mental Note: FPS-sampled points (query_pc) act as latent tokens that attend to and learn from the broader context in input_pc.
|
||||
Goal: get a smaller represenation (query_pc) to represent the entire scence structure by learning from a broader subset (input_pc).
|
||||
More computationally efficient.
|
||||
|
||||
Features are additional information for each point in the cloud
|
||||
"""
|
||||
|
||||
B, _, D = point_cloud.shape
|
||||
|
||||
num_latents = int(self.num_latents)
|
||||
|
||||
num_random_query = self.pc_size / (self.pc_size + self.pc_sharpedge_size) * num_latents
|
||||
num_sharpedge_query = num_latents - num_random_query
|
||||
|
||||
# Split random and sharpedge surface points
|
||||
random_pc, sharpedge_pc = torch.split(point_cloud, [self.pc_size, self.pc_sharpedge_size], dim=1)
|
||||
|
||||
# assert statements
|
||||
assert random_pc.shape[1] <= self.pc_size, "Random surface points size must be less than or equal to pc_size"
|
||||
assert sharpedge_pc.shape[1] <= self.pc_sharpedge_size, "Sharpedge surface points size must be less than or equal to pc_sharpedge_size"
|
||||
|
||||
input_random_pc_size = int(num_random_query * self.downsample_ratio)
|
||||
random_query_pc, random_input_pc, random_idx_pc, random_idx_query = \
|
||||
self.subsample(pc = random_pc, num_query = num_random_query, input_pc_size = input_random_pc_size)
|
||||
|
||||
input_sharpedge_pc_size = int(num_sharpedge_query * self.downsample_ratio)
|
||||
|
||||
if input_sharpedge_pc_size == 0:
|
||||
sharpedge_input_pc = torch.zeros(B, 0, D, dtype = random_input_pc.dtype).to(point_cloud.device)
|
||||
sharpedge_query_pc = torch.zeros(B, 0, D, dtype= random_query_pc.dtype).to(point_cloud.device)
|
||||
|
||||
else:
|
||||
sharpedge_query_pc, sharpedge_input_pc, sharpedge_idx_pc, sharpedge_idx_query = \
|
||||
self.subsample(pc = sharpedge_pc, num_query = num_sharpedge_query, input_pc_size = input_sharpedge_pc_size)
|
||||
|
||||
# concat the random and sharpedges
|
||||
query_pc = torch.cat([random_query_pc, sharpedge_query_pc], dim = 1)
|
||||
input_pc = torch.cat([random_input_pc, sharpedge_input_pc], dim = 1)
|
||||
|
||||
query = self.fourier_embedder(query_pc)
|
||||
data = self.fourier_embedder(input_pc)
|
||||
|
||||
if self.point_feats > 0:
|
||||
random_surface_features, sharpedge_surface_features = torch.split(features, [self.pc_size, self.pc_sharpedge_size], dim = 1)
|
||||
|
||||
input_random_surface_features, query_random_features = \
|
||||
self.handle_features(features = random_surface_features, idx_pc = random_idx_pc, batch_size = B,
|
||||
input_pc_size = input_random_pc_size, idx_query = random_idx_query)
|
||||
|
||||
if input_sharpedge_pc_size == 0:
|
||||
input_sharpedge_surface_features = torch.zeros(B, 0, self.point_feats,
|
||||
dtype = input_random_surface_features.dtype, device = point_cloud.device)
|
||||
|
||||
query_sharpedge_features = torch.zeros(B, 0, self.point_feats,
|
||||
dtype = query_random_features.dtype, device = point_cloud.device)
|
||||
else:
|
||||
|
||||
input_sharpedge_surface_features, query_sharpedge_features = \
|
||||
self.handle_features(idx_pc = sharpedge_idx_pc, features = sharpedge_surface_features,
|
||||
batch_size = B, idx_query = sharpedge_idx_query, input_pc_size = input_sharpedge_pc_size)
|
||||
|
||||
query_features = torch.cat([query_random_features, query_sharpedge_features], dim = 1)
|
||||
input_features = torch.cat([input_random_surface_features, input_sharpedge_surface_features], dim = 1)
|
||||
|
||||
if self.normal_pe:
|
||||
# apply the fourier embeddings on the first 3 dims (xyz)
|
||||
input_features_pe = self.fourier_embedder(input_features[..., :3])
|
||||
query_features_pe = self.fourier_embedder(query_features[..., :3])
|
||||
# replace the first 3 dims with the new PE ones
|
||||
input_features = torch.cat([input_features_pe, input_features[..., :3]], dim = -1)
|
||||
query_features = torch.cat([query_features_pe, query_features[..., :3]], dim = -1)
|
||||
|
||||
# concat at the channels dim
|
||||
query = torch.cat([query, query_features], dim = -1)
|
||||
data = torch.cat([data, input_features], dim = -1)
|
||||
|
||||
# don't return pc_info to avoid unnecessary memory usuage
|
||||
return query.view(B, -1, query.shape[-1]), data.view(B, -1, data.shape[-1])
|
||||
|
||||
def forward(self, point_cloud: torch.Tensor, features: torch.Tensor):
|
||||
|
||||
query, data = self.sample_points_and_latents(point_cloud = point_cloud, features = features)
|
||||
|
||||
# apply projections
|
||||
query = self.input_proj(query)
|
||||
data = self.input_proj(data)
|
||||
|
||||
# apply cross attention between query and data
|
||||
latents = self.cross_attn(query, data)
|
||||
|
||||
if self.self_attn is not None:
|
||||
latents = self.self_attn(latents)
|
||||
|
||||
if self.ln_post is not None:
|
||||
latents = self.ln_post(latents)
|
||||
|
||||
return latents
|
||||
|
||||
|
||||
class VanillaVolumeDecoder:
|
||||
def subsample(self, pc, num_query, input_pc_size: int):
|
||||
|
||||
"""
|
||||
num_query: number of points to keep after FPS
|
||||
input_pc_size: number of points to select before FPS
|
||||
"""
|
||||
|
||||
B, _, D = pc.shape
|
||||
query_ratio = num_query / input_pc_size
|
||||
|
||||
# random subsampling of points inside the point cloud
|
||||
idx_pc = torch.randperm(pc.shape[1], device = pc.device)[:input_pc_size]
|
||||
input_pc = pc[:, idx_pc, :]
|
||||
|
||||
# flatten to allow applying fps across the whole batch
|
||||
flattent_input_pc = input_pc.view(B * input_pc_size, D)
|
||||
|
||||
# construct a batch_down tensor to tell fps
|
||||
# which points belong to which batch
|
||||
N_down = int(flattent_input_pc.shape[0] / B)
|
||||
batch_down = torch.arange(B).to(pc.device)
|
||||
batch_down = torch.repeat_interleave(batch_down, N_down)
|
||||
|
||||
idx_query = fps(flattent_input_pc, batch_down, sampling_ratio = query_ratio)
|
||||
query_pc = flattent_input_pc[idx_query].view(B, -1, D)
|
||||
|
||||
return query_pc, input_pc, idx_pc, idx_query
|
||||
|
||||
def handle_features(self, features, idx_pc, input_pc_size, batch_size: int, idx_query):
|
||||
|
||||
B = batch_size
|
||||
|
||||
input_surface_features = features[:, idx_pc, :]
|
||||
flattent_input_features = input_surface_features.view(B * input_pc_size, -1)
|
||||
query_features = flattent_input_features[idx_query].view(B, -1,
|
||||
flattent_input_features.shape[-1])
|
||||
|
||||
return input_surface_features, query_features
|
||||
|
||||
def normalize_mesh(mesh, scale = 0.9999):
|
||||
"""Normalize mesh to fit in [-scale, scale]. Translate mesh so its center is [0,0,0]"""
|
||||
|
||||
bbox = mesh.bounds
|
||||
center = (bbox[1] + bbox[0]) / 2
|
||||
|
||||
max_extent = (bbox[1] - bbox[0]).max()
|
||||
mesh.apply_translation(-center)
|
||||
mesh.apply_scale((2 * scale) / max_extent)
|
||||
|
||||
return mesh
|
||||
|
||||
def sample_pointcloud(mesh, num = 200000):
|
||||
""" Uniformly sample points from the surface of the mesh """
|
||||
|
||||
points, face_idx = mesh.sample(num, return_index = True)
|
||||
normals = mesh.face_normals[face_idx]
|
||||
return torch.from_numpy(points.astype(np.float32)), torch.from_numpy(normals.astype(np.float32))
|
||||
|
||||
def detect_sharp_edges(mesh, threshold=0.985):
|
||||
"""Return edge indices (a, b) that lie on sharp boundaries of the mesh."""
|
||||
|
||||
V, F = mesh.vertices, mesh.faces
|
||||
VN, FN = mesh.vertex_normals, mesh.face_normals
|
||||
|
||||
sharp_mask = np.ones(V.shape[0])
|
||||
for i in range(3):
|
||||
indices = F[:, i]
|
||||
alignment = np.einsum('ij,ij->i', VN[indices], FN)
|
||||
dot_stack = np.stack((sharp_mask[indices], alignment), axis=-1)
|
||||
sharp_mask[indices] = np.min(dot_stack, axis=-1)
|
||||
|
||||
edge_a = np.concatenate([F[:, 0], F[:, 1], F[:, 2]])
|
||||
edge_b = np.concatenate([F[:, 1], F[:, 2], F[:, 0]])
|
||||
sharp_edges = (sharp_mask[edge_a] < threshold) & (sharp_mask[edge_b] < threshold)
|
||||
|
||||
return edge_a[sharp_edges], edge_b[sharp_edges]
|
||||
|
||||
|
||||
def sharp_sample_pointcloud(mesh, num = 16384):
|
||||
""" Sample points preferentially from sharp edges in the mesh. """
|
||||
|
||||
edge_a, edge_b = detect_sharp_edges(mesh)
|
||||
V, VN = mesh.vertices, mesh.vertex_normals
|
||||
|
||||
va, vb = V[edge_a], V[edge_b]
|
||||
na, nb = VN[edge_a], VN[edge_b]
|
||||
|
||||
edge_lengths = np.linalg.norm(vb - va, axis=-1)
|
||||
weights = edge_lengths / edge_lengths.sum()
|
||||
|
||||
indices = np.searchsorted(np.cumsum(weights), np.random.rand(num))
|
||||
t = np.random.rand(num, 1)
|
||||
|
||||
samples = t * va[indices] + (1 - t) * vb[indices]
|
||||
normals = t * na[indices] + (1 - t) * nb[indices]
|
||||
|
||||
return samples.astype(np.float32), normals.astype(np.float32)
|
||||
|
||||
def load_surface_sharpedge(mesh, num_points=4096, num_sharp_points=4096, sharpedge_flag = True, device = "cuda"):
|
||||
"""Load a surface with optional sharp-edge annotations from a trimesh mesh."""
|
||||
|
||||
import trimesh
|
||||
|
||||
try:
|
||||
mesh_full = trimesh.util.concatenate(mesh.dump())
|
||||
except Exception:
|
||||
mesh_full = trimesh.util.concatenate(mesh)
|
||||
|
||||
mesh_full = normalize_mesh(mesh_full)
|
||||
|
||||
faces = mesh_full.faces
|
||||
vertices = mesh_full.vertices
|
||||
origin_face_count = faces.shape[0]
|
||||
|
||||
mesh_surface = trimesh.Trimesh(vertices=vertices, faces=faces[:origin_face_count])
|
||||
mesh_fill = trimesh.Trimesh(vertices=vertices, faces=faces[origin_face_count:])
|
||||
|
||||
area_surface = mesh_surface.area
|
||||
area_fill = mesh_fill.area
|
||||
total_area = area_surface + area_fill
|
||||
|
||||
sample_num = 499712 // 2
|
||||
fill_ratio = area_fill / total_area if total_area > 0 else 0
|
||||
|
||||
num_fill = int(sample_num * fill_ratio)
|
||||
num_surface = sample_num - num_fill
|
||||
|
||||
surf_pts, surf_normals = sample_pointcloud(mesh_surface, num_surface)
|
||||
fill_pts, fill_normals = (torch.zeros(0, 3), torch.zeros(0, 3)) if num_fill == 0 else sample_pointcloud(mesh_fill, num_fill)
|
||||
|
||||
sharp_pts, sharp_normals = sharp_sample_pointcloud(mesh_surface, sample_num)
|
||||
|
||||
def assemble_tensor(points, normals, label=None):
|
||||
|
||||
data = torch.cat([points, normals], dim=1).half().to(device)
|
||||
|
||||
if label is not None:
|
||||
label_tensor = torch.full((data.shape[0], 1), float(label), dtype=torch.float16).to(device)
|
||||
data = torch.cat([data, label_tensor], dim=1)
|
||||
|
||||
return data
|
||||
|
||||
surface = assemble_tensor(torch.cat([surf_pts.to(device), fill_pts.to(device)], dim=0),
|
||||
torch.cat([surf_normals.to(device), fill_normals.to(device)], dim=0),
|
||||
label = 0 if sharpedge_flag else None)
|
||||
|
||||
sharp_surface = assemble_tensor(torch.from_numpy(sharp_pts), torch.from_numpy(sharp_normals),
|
||||
label = 1 if sharpedge_flag else None)
|
||||
|
||||
rng = np.random.default_rng()
|
||||
|
||||
surface = surface[rng.choice(surface.shape[0], num_points, replace = False)]
|
||||
sharp_surface = sharp_surface[rng.choice(sharp_surface.shape[0], num_sharp_points, replace = False)]
|
||||
|
||||
full = torch.cat([surface, sharp_surface], dim = 0).unsqueeze(0)
|
||||
|
||||
return full
|
||||
|
||||
class SharpEdgeSurfaceLoader:
|
||||
""" Load mesh surface and sharp edge samples. """
|
||||
|
||||
def __init__(self, num_uniform_points = 8192, num_sharp_points = 8192):
|
||||
|
||||
self.num_uniform_points = num_uniform_points
|
||||
self.num_sharp_points = num_sharp_points
|
||||
self.total_points = num_uniform_points + num_sharp_points
|
||||
|
||||
def __call__(self, mesh_input, device = "cuda"):
|
||||
mesh = self._load_mesh(mesh_input)
|
||||
return load_surface_sharpedge(mesh, self.num_uniform_points, self.num_sharp_points, device = device)
|
||||
|
||||
@staticmethod
|
||||
def _load_mesh(mesh_input):
|
||||
import trimesh
|
||||
|
||||
if isinstance(mesh_input, str):
|
||||
mesh = trimesh.load(mesh_input, force="mesh", merge_primitives = True)
|
||||
else:
|
||||
mesh = mesh_input
|
||||
|
||||
if isinstance(mesh, trimesh.Scene):
|
||||
combined = None
|
||||
for obj in mesh.geometry.values():
|
||||
combined = obj if combined is None else combined + obj
|
||||
return combined
|
||||
|
||||
return mesh
|
||||
|
||||
class DiagonalGaussianDistribution:
|
||||
def __init__(self, params: torch.Tensor, feature_dim: int = -1):
|
||||
|
||||
# divide quant channels (8) into mean and log variance
|
||||
self.mean, self.logvar = torch.chunk(params, 2, dim = feature_dim)
|
||||
|
||||
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
|
||||
self.std = torch.exp(0.5 * self.logvar)
|
||||
|
||||
def sample(self):
|
||||
|
||||
eps = torch.randn_like(self.std)
|
||||
z = self.mean + eps * self.std
|
||||
|
||||
return z
|
||||
|
||||
################################################
|
||||
# Volume Decoder
|
||||
################################################
|
||||
|
||||
class VanillaVolumeDecoder():
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
latents: torch.FloatTensor,
|
||||
geo_decoder: Callable,
|
||||
bounds: Union[Tuple[float], List[float], float] = 1.01,
|
||||
num_chunks: int = 10000,
|
||||
octree_resolution: int = None,
|
||||
enable_pbar: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
device = latents.device
|
||||
dtype = latents.dtype
|
||||
batch_size = latents.shape[0]
|
||||
def __call__(self, latents: torch.Tensor, geo_decoder: callable, octree_resolution: int, bounds = 1.01,
|
||||
num_chunks: int = 10_000, enable_pbar: bool = True, **kwargs):
|
||||
|
||||
# 1. generate query points
|
||||
if isinstance(bounds, float):
|
||||
bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
|
||||
|
||||
bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6])
|
||||
xyz_samples, grid_size, length = generate_dense_grid_points(
|
||||
bbox_min=bbox_min,
|
||||
bbox_max=bbox_max,
|
||||
octree_resolution=octree_resolution,
|
||||
indexing="ij"
|
||||
)
|
||||
xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype).contiguous().reshape(-1, 3)
|
||||
bbox_min, bbox_max = torch.tensor(bounds[:3]), torch.tensor(bounds[3:])
|
||||
|
||||
x = torch.linspace(bbox_min[0], bbox_max[0], int(octree_resolution) + 1, dtype = torch.float32)
|
||||
y = torch.linspace(bbox_min[1], bbox_max[1], int(octree_resolution) + 1, dtype = torch.float32)
|
||||
z = torch.linspace(bbox_min[2], bbox_max[2], int(octree_resolution) + 1, dtype = torch.float32)
|
||||
|
||||
[xs, ys, zs] = torch.meshgrid(x, y, z, indexing = "ij")
|
||||
xyz = torch.stack((xs, ys, zs), axis=-1).to(latents.device, dtype = latents.dtype).contiguous().reshape(-1, 3)
|
||||
grid_size = [int(octree_resolution) + 1, int(octree_resolution) + 1, int(octree_resolution) + 1]
|
||||
|
||||
# 2. latents to 3d volume
|
||||
batch_logits = []
|
||||
for start in tqdm(range(0, xyz_samples.shape[0], num_chunks), desc="Volume Decoding",
|
||||
for start in tqdm(range(0, xyz.shape[0], num_chunks), desc="Volume Decoding",
|
||||
disable=not enable_pbar):
|
||||
chunk_queries = xyz_samples[start: start + num_chunks, :]
|
||||
chunk_queries = repeat(chunk_queries, "p c -> b p c", b=batch_size)
|
||||
logits = geo_decoder(queries=chunk_queries, latents=latents)
|
||||
|
||||
chunk_queries = xyz[start: start + num_chunks, :]
|
||||
chunk_queries = chunk_queries.unsqueeze(0).repeat(latents.shape[0], 1, 1)
|
||||
logits = geo_decoder(queries = chunk_queries, latents = latents)
|
||||
batch_logits.append(logits)
|
||||
|
||||
grid_logits = torch.cat(batch_logits, dim=1)
|
||||
grid_logits = grid_logits.view((batch_size, *grid_size)).float()
|
||||
grid_logits = torch.cat(batch_logits, dim = 1)
|
||||
grid_logits = grid_logits.view((latents.shape[0], *grid_size)).float()
|
||||
|
||||
return grid_logits
|
||||
|
||||
|
||||
class FourierEmbedder(nn.Module):
|
||||
"""The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
|
||||
each feature dimension of `x[..., i]` into:
|
||||
@@ -175,13 +552,11 @@ class FourierEmbedder(nn.Module):
|
||||
else:
|
||||
return x
|
||||
|
||||
|
||||
class CrossAttentionProcessor:
|
||||
def __call__(self, attn, q, k, v):
|
||||
out = F.scaled_dot_product_attention(q, k, v)
|
||||
out = comfy.ops.scaled_dot_product_attention(q, k, v)
|
||||
return out
|
||||
|
||||
|
||||
class DropPath(nn.Module):
|
||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||
"""
|
||||
@@ -232,38 +607,41 @@ class MLP(nn.Module):
|
||||
def forward(self, x):
|
||||
return self.drop_path(self.c_proj(self.gelu(self.c_fc(x))))
|
||||
|
||||
|
||||
class QKVMultiheadCrossAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
heads: int,
|
||||
n_data = None,
|
||||
width=None,
|
||||
qk_norm=False,
|
||||
norm_layer=ops.LayerNorm
|
||||
):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
self.n_data = n_data
|
||||
self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
|
||||
self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
|
||||
|
||||
self.attn_processor = CrossAttentionProcessor()
|
||||
|
||||
def forward(self, q, kv):
|
||||
|
||||
_, n_ctx, _ = q.shape
|
||||
bs, n_data, width = kv.shape
|
||||
|
||||
attn_ch = width // self.heads // 2
|
||||
q = q.view(bs, n_ctx, self.heads, -1)
|
||||
|
||||
kv = kv.view(bs, n_data, self.heads, -1)
|
||||
k, v = torch.split(kv, attn_ch, dim=-1)
|
||||
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
|
||||
out = self.attn_processor(self, q, k, v)
|
||||
out = out.transpose(1, 2).reshape(bs, n_ctx, -1)
|
||||
return out
|
||||
|
||||
q, k, v = [t.permute(0, 2, 1, 3) for t in (q, k, v)]
|
||||
out = F.scaled_dot_product_attention(q, k, v)
|
||||
|
||||
out = out.transpose(1, 2).reshape(bs, n_ctx, -1)
|
||||
|
||||
return out
|
||||
|
||||
class MultiheadCrossAttention(nn.Module):
|
||||
def __init__(
|
||||
@@ -306,7 +684,6 @@ class MultiheadCrossAttention(nn.Module):
|
||||
x = self.c_proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class ResidualCrossAttentionBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -366,7 +743,7 @@ class QKVMultiheadAttention(nn.Module):
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
|
||||
q, k, v = [t.permute(0, 2, 1, 3) for t in (q, k, v)]
|
||||
out = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)
|
||||
return out
|
||||
|
||||
@@ -383,8 +760,7 @@ class MultiheadAttention(nn.Module):
|
||||
drop_path_rate: float = 0.0
|
||||
):
|
||||
super().__init__()
|
||||
self.width = width
|
||||
self.heads = heads
|
||||
|
||||
self.c_qkv = ops.Linear(width, width * 3, bias=qkv_bias)
|
||||
self.c_proj = ops.Linear(width, width)
|
||||
self.attention = QKVMultiheadAttention(
|
||||
@@ -491,7 +867,7 @@ class CrossAttentionDecoder(nn.Module):
|
||||
self.query_proj = ops.Linear(self.fourier_embedder.out_dim, width)
|
||||
if self.downsample_ratio != 1:
|
||||
self.latents_proj = ops.Linear(width * downsample_ratio, width)
|
||||
if self.enable_ln_post == False:
|
||||
if not self.enable_ln_post:
|
||||
qk_norm = False
|
||||
self.cross_attn_decoder = ResidualCrossAttentionBlock(
|
||||
width=width,
|
||||
@@ -522,28 +898,44 @@ class CrossAttentionDecoder(nn.Module):
|
||||
|
||||
class ShapeVAE(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
embed_dim: int,
|
||||
width: int,
|
||||
heads: int,
|
||||
num_decoder_layers: int,
|
||||
geo_decoder_downsample_ratio: int = 1,
|
||||
geo_decoder_mlp_expand_ratio: int = 4,
|
||||
geo_decoder_ln_post: bool = True,
|
||||
num_freqs: int = 8,
|
||||
include_pi: bool = True,
|
||||
qkv_bias: bool = True,
|
||||
qk_norm: bool = False,
|
||||
label_type: str = "binary",
|
||||
drop_path_rate: float = 0.0,
|
||||
scale_factor: float = 1.0,
|
||||
self,
|
||||
*,
|
||||
num_latents: int = 4096,
|
||||
embed_dim: int = 64,
|
||||
width: int = 1024,
|
||||
heads: int = 16,
|
||||
num_decoder_layers: int = 16,
|
||||
num_encoder_layers: int = 8,
|
||||
pc_size: int = 81920,
|
||||
pc_sharpedge_size: int = 0,
|
||||
point_feats: int = 4,
|
||||
downsample_ratio: int = 20,
|
||||
geo_decoder_downsample_ratio: int = 1,
|
||||
geo_decoder_mlp_expand_ratio: int = 4,
|
||||
geo_decoder_ln_post: bool = True,
|
||||
num_freqs: int = 8,
|
||||
qkv_bias: bool = False,
|
||||
qk_norm: bool = True,
|
||||
drop_path_rate: float = 0.0,
|
||||
include_pi: bool = False,
|
||||
scale_factor: float = 1.0039506158752403,
|
||||
label_type: str = "binary",
|
||||
):
|
||||
super().__init__()
|
||||
self.geo_decoder_ln_post = geo_decoder_ln_post
|
||||
|
||||
self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
|
||||
|
||||
self.encoder = PointCrossAttention(layers = num_encoder_layers,
|
||||
num_latents = num_latents,
|
||||
downsample_ratio = downsample_ratio,
|
||||
heads = heads,
|
||||
pc_size = pc_size,
|
||||
width = width,
|
||||
point_feats = point_feats,
|
||||
fourier_embedder = self.fourier_embedder,
|
||||
pc_sharpedge_size = pc_sharpedge_size)
|
||||
|
||||
self.post_kl = ops.Linear(embed_dim, width)
|
||||
|
||||
self.transformer = Transformer(
|
||||
@@ -583,5 +975,14 @@ class ShapeVAE(nn.Module):
|
||||
grid_logits = self.volume_decoder(latents, self.geo_decoder, bounds=bounds, num_chunks=num_chunks, octree_resolution=octree_resolution, enable_pbar=enable_pbar)
|
||||
return grid_logits.movedim(-2, -1)
|
||||
|
||||
def encode(self, x):
|
||||
return None
|
||||
def encode(self, surface):
|
||||
|
||||
pc, feats = surface[:, :, :3], surface[:, :, 3:]
|
||||
latents = self.encoder(pc, feats)
|
||||
|
||||
moments = self.pre_kl(latents)
|
||||
posterior = DiagonalGaussianDistribution(moments, feature_dim = -1)
|
||||
|
||||
latents = posterior.sample()
|
||||
|
||||
return latents
|
||||
|
||||
@@ -0,0 +1,659 @@
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
import comfy.model_management
|
||||
|
||||
class GELU(nn.Module):
|
||||
|
||||
def __init__(self, dim_in: int, dim_out: int, operations, device, dtype):
|
||||
super().__init__()
|
||||
self.proj = operations.Linear(dim_in, dim_out, device = device, dtype = dtype)
|
||||
|
||||
def gelu(self, gate: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
if gate.device.type == "mps":
|
||||
return F.gelu(gate.to(dtype = torch.float32)).to(dtype = gate.dtype)
|
||||
|
||||
return F.gelu(gate)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
|
||||
hidden_states = self.proj(hidden_states)
|
||||
hidden_states = self.gelu(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
|
||||
def __init__(self, dim: int, dim_out = None, mult: int = 4,
|
||||
dropout: float = 0.0, inner_dim = None, operations = None, device = None, dtype = None):
|
||||
|
||||
super().__init__()
|
||||
if inner_dim is None:
|
||||
inner_dim = int(dim * mult)
|
||||
|
||||
dim_out = dim_out if dim_out is not None else dim
|
||||
|
||||
act_fn = GELU(dim, inner_dim, operations = operations, device = device, dtype = dtype)
|
||||
|
||||
self.net = nn.ModuleList([])
|
||||
self.net.append(act_fn)
|
||||
|
||||
self.net.append(nn.Dropout(dropout))
|
||||
self.net.append(operations.Linear(inner_dim, dim_out, device = device, dtype = dtype))
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
for module in self.net:
|
||||
hidden_states = module(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
class AddAuxLoss(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, x, loss):
|
||||
# do nothing in forward (no computation)
|
||||
ctx.requires_aux_loss = loss.requires_grad
|
||||
ctx.dtype = loss.dtype
|
||||
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
# add the aux loss gradients
|
||||
grad_loss = None
|
||||
# put the aux grad the same as the main grad loss
|
||||
# aux grad contributes equally
|
||||
if ctx.requires_aux_loss:
|
||||
grad_loss = torch.ones(1, dtype = ctx.dtype, device = grad_output.device)
|
||||
|
||||
return grad_output, grad_loss
|
||||
|
||||
class MoEGate(nn.Module):
|
||||
|
||||
def __init__(self, embed_dim, num_experts=16, num_experts_per_tok=2, aux_loss_alpha=0.01, device = None, dtype = None):
|
||||
|
||||
super().__init__()
|
||||
self.top_k = num_experts_per_tok
|
||||
self.n_routed_experts = num_experts
|
||||
|
||||
self.alpha = aux_loss_alpha
|
||||
|
||||
self.gating_dim = embed_dim
|
||||
self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim), device = device, dtype = dtype))
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
# flatten hidden states
|
||||
hidden_states = hidden_states.view(-1, hidden_states.size(-1))
|
||||
|
||||
# get logits and pass it to softmax
|
||||
logits = F.linear(hidden_states, comfy.model_management.cast_to(self.weight, dtype=hidden_states.dtype, device=hidden_states.device), bias = None)
|
||||
scores = logits.softmax(dim = -1)
|
||||
|
||||
topk_weight, topk_idx = torch.topk(scores, k = self.top_k, dim = -1, sorted = False)
|
||||
|
||||
if self.training and self.alpha > 0.0:
|
||||
scores_for_aux = scores
|
||||
|
||||
# used bincount instead of one hot encoding
|
||||
counts = torch.bincount(topk_idx.view(-1), minlength = self.n_routed_experts).float()
|
||||
ce = counts / topk_idx.numel() # normalized expert usage
|
||||
|
||||
# mean expert score
|
||||
Pi = scores_for_aux.mean(0)
|
||||
|
||||
# expert balance loss
|
||||
aux_loss = (Pi * ce * self.n_routed_experts).sum() * self.alpha
|
||||
else:
|
||||
aux_loss = None
|
||||
|
||||
return topk_idx, topk_weight, aux_loss
|
||||
|
||||
class MoEBlock(nn.Module):
|
||||
def __init__(self, dim, num_experts: int = 6, moe_top_k: int = 2, dropout: float = 0.0,
|
||||
ff_inner_dim: int = None, operations = None, device = None, dtype = None):
|
||||
super().__init__()
|
||||
|
||||
self.moe_top_k = moe_top_k
|
||||
self.num_experts = num_experts
|
||||
|
||||
self.experts = nn.ModuleList([
|
||||
FeedForward(dim, dropout = dropout, inner_dim = ff_inner_dim, operations = operations, device = device, dtype = dtype)
|
||||
for _ in range(num_experts)
|
||||
])
|
||||
|
||||
self.gate = MoEGate(dim, num_experts = num_experts, num_experts_per_tok = moe_top_k, device = device, dtype = dtype)
|
||||
self.shared_experts = FeedForward(dim, dropout = dropout, inner_dim = ff_inner_dim, operations = operations, device = device, dtype = dtype)
|
||||
|
||||
def forward(self, hidden_states) -> torch.Tensor:
|
||||
|
||||
identity = hidden_states
|
||||
orig_shape = hidden_states.shape
|
||||
topk_idx, topk_weight, aux_loss = self.gate(hidden_states)
|
||||
|
||||
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
||||
flat_topk_idx = topk_idx.view(-1)
|
||||
|
||||
if self.training:
|
||||
|
||||
hidden_states = hidden_states.repeat_interleave(self.moe_top_k, dim = 0)
|
||||
y = torch.empty_like(hidden_states, dtype = hidden_states.dtype)
|
||||
|
||||
for i, expert in enumerate(self.experts):
|
||||
tmp = expert(hidden_states[flat_topk_idx == i])
|
||||
y[flat_topk_idx == i] = tmp.to(hidden_states.dtype)
|
||||
|
||||
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim = 1)
|
||||
y = y.view(*orig_shape)
|
||||
|
||||
y = AddAuxLoss.apply(y, aux_loss)
|
||||
else:
|
||||
y = self.moe_infer(hidden_states, flat_expert_indices = flat_topk_idx,flat_expert_weights = topk_weight.view(-1, 1)).view(*orig_shape)
|
||||
|
||||
y = y + self.shared_experts(identity)
|
||||
|
||||
return y
|
||||
|
||||
@torch.no_grad()
|
||||
def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
|
||||
|
||||
expert_cache = torch.zeros_like(x)
|
||||
idxs = flat_expert_indices.argsort()
|
||||
|
||||
# no need for .numpy().cpu() here
|
||||
tokens_per_expert = flat_expert_indices.bincount().cumsum(0)
|
||||
token_idxs = idxs // self.moe_top_k
|
||||
|
||||
for i, end_idx in enumerate(tokens_per_expert):
|
||||
|
||||
start_idx = 0 if i == 0 else tokens_per_expert[i-1]
|
||||
|
||||
if start_idx == end_idx:
|
||||
continue
|
||||
|
||||
expert = self.experts[i]
|
||||
exp_token_idx = token_idxs[start_idx:end_idx]
|
||||
|
||||
expert_tokens = x[exp_token_idx]
|
||||
expert_out = expert(expert_tokens)
|
||||
|
||||
expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
|
||||
|
||||
# use index_add_ with a 1-D index tensor directly avoids building a large [N, D] index map and extra memcopy required by scatter_reduce_
|
||||
# + avoid dtype conversion
|
||||
expert_cache.index_add_(0, exp_token_idx, expert_out)
|
||||
|
||||
return expert_cache
|
||||
|
||||
class Timesteps(nn.Module):
|
||||
def __init__(self, num_channels: int, downscale_freq_shift: float = 0.0,
|
||||
scale: float = 1.0, max_period: int = 10000):
|
||||
super().__init__()
|
||||
|
||||
self.num_channels = num_channels
|
||||
half_dim = num_channels // 2
|
||||
|
||||
# precompute the “inv_freq” vector once
|
||||
exponent = -math.log(max_period) * torch.arange(
|
||||
half_dim, dtype=torch.float32
|
||||
) / (half_dim - downscale_freq_shift)
|
||||
|
||||
inv_freq = torch.exp(exponent)
|
||||
|
||||
# pad
|
||||
if num_channels % 2 == 1:
|
||||
# we’ll pad a zero at the end of the cos-half
|
||||
inv_freq = torch.cat([inv_freq, inv_freq.new_zeros(1)])
|
||||
|
||||
# register to buffer so it moves with the device
|
||||
self.register_buffer("inv_freq", inv_freq, persistent = False)
|
||||
self.scale = scale
|
||||
|
||||
def forward(self, timesteps: torch.Tensor):
|
||||
|
||||
x = timesteps.float().unsqueeze(1) * self.inv_freq.to(timesteps.device).unsqueeze(0)
|
||||
|
||||
|
||||
# fused CUDA kernels for sin and cos
|
||||
sin_emb = x.sin()
|
||||
cos_emb = x.cos()
|
||||
|
||||
emb = torch.cat([sin_emb, cos_emb], dim = 1)
|
||||
|
||||
# scale factor
|
||||
if self.scale != 1.0:
|
||||
emb = emb * self.scale
|
||||
|
||||
# If we padded inv_freq for odd, emb is already wide enough; otherwise:
|
||||
if emb.shape[1] > self.num_channels:
|
||||
emb = emb[:, :self.num_channels]
|
||||
|
||||
return emb
|
||||
|
||||
class TimestepEmbedder(nn.Module):
|
||||
def __init__(self, hidden_size, frequency_embedding_size = 256, cond_proj_dim = None, operations = None, device = None, dtype = None):
|
||||
super().__init__()
|
||||
|
||||
self.mlp = nn.Sequential(
|
||||
operations.Linear(hidden_size, frequency_embedding_size, bias=True, device = device, dtype = dtype),
|
||||
nn.GELU(),
|
||||
operations.Linear(frequency_embedding_size, hidden_size, bias=True, device = device, dtype = dtype),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
|
||||
if cond_proj_dim is not None:
|
||||
self.cond_proj = operations.Linear(cond_proj_dim, frequency_embedding_size, bias=False, device = device, dtype = dtype)
|
||||
|
||||
self.time_embed = Timesteps(hidden_size)
|
||||
|
||||
def forward(self, timesteps, condition):
|
||||
|
||||
timestep_embed = self.time_embed(timesteps).type(self.mlp[0].weight.dtype)
|
||||
|
||||
if condition is not None:
|
||||
cond_embed = self.cond_proj(condition)
|
||||
timestep_embed = timestep_embed + cond_embed
|
||||
|
||||
time_conditioned = self.mlp(timestep_embed)
|
||||
|
||||
# for broadcasting with image tokens
|
||||
return time_conditioned.unsqueeze(1)
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, *, width: int, operations = None, device = None, dtype = None):
|
||||
super().__init__()
|
||||
self.width = width
|
||||
self.fc1 = operations.Linear(width, width * 4, device = device, dtype = dtype)
|
||||
self.fc2 = operations.Linear(width * 4, width, device = device, dtype = dtype)
|
||||
self.gelu = nn.GELU()
|
||||
|
||||
def forward(self, x):
|
||||
return self.fc2(self.gelu(self.fc1(x)))
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
qdim,
|
||||
kdim,
|
||||
num_heads,
|
||||
qkv_bias=True,
|
||||
qk_norm=False,
|
||||
norm_layer=nn.LayerNorm,
|
||||
use_fp16: bool = False,
|
||||
operations = None,
|
||||
dtype = None,
|
||||
device = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.qdim = qdim
|
||||
self.kdim = kdim
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = self.qdim // num_heads
|
||||
|
||||
self.scale = self.head_dim ** -0.5
|
||||
|
||||
self.to_q = operations.Linear(qdim, qdim, bias=qkv_bias, device = device, dtype = dtype)
|
||||
self.to_k = operations.Linear(kdim, qdim, bias=qkv_bias, device = device, dtype = dtype)
|
||||
self.to_v = operations.Linear(kdim, qdim, bias=qkv_bias, device = device, dtype = dtype)
|
||||
|
||||
if use_fp16:
|
||||
eps = 1.0 / 65504
|
||||
else:
|
||||
eps = 1e-6
|
||||
|
||||
if norm_layer == nn.LayerNorm:
|
||||
norm_layer = operations.LayerNorm
|
||||
else:
|
||||
norm_layer = operations.RMSNorm
|
||||
|
||||
self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps = eps, device = device, dtype = dtype) if qk_norm else nn.Identity()
|
||||
self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps = eps, device = device, dtype = dtype) if qk_norm else nn.Identity()
|
||||
self.out_proj = operations.Linear(qdim, qdim, bias=True, device = device, dtype = dtype)
|
||||
|
||||
def forward(self, x, y):
|
||||
|
||||
b, s1, _ = x.shape
|
||||
_, s2, _ = y.shape
|
||||
|
||||
y = y.to(next(self.to_k.parameters()).dtype)
|
||||
|
||||
q = self.to_q(x)
|
||||
k = self.to_k(y)
|
||||
v = self.to_v(y)
|
||||
|
||||
kv = torch.cat((k, v), dim=-1)
|
||||
split_size = kv.shape[-1] // self.num_heads // 2
|
||||
|
||||
kv = kv.view(1, -1, self.num_heads, split_size * 2)
|
||||
k, v = torch.split(kv, split_size, dim=-1)
|
||||
|
||||
q = q.view(b, s1, self.num_heads, self.head_dim)
|
||||
k = k.view(b, s2, self.num_heads, self.head_dim)
|
||||
v = v.reshape(b, s2, self.num_heads * self.head_dim)
|
||||
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
|
||||
x = optimized_attention(
|
||||
q.reshape(b, s1, self.num_heads * self.head_dim),
|
||||
k.reshape(b, s2, self.num_heads * self.head_dim),
|
||||
v,
|
||||
heads=self.num_heads,
|
||||
)
|
||||
|
||||
out = self.out_proj(x)
|
||||
|
||||
return out
|
||||
|
||||
class Attention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
num_heads,
|
||||
qkv_bias = True,
|
||||
qk_norm = False,
|
||||
norm_layer = nn.LayerNorm,
|
||||
use_fp16: bool = False,
|
||||
operations = None,
|
||||
device = None,
|
||||
dtype = None
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = self.dim // num_heads
|
||||
self.scale = self.head_dim ** -0.5
|
||||
|
||||
self.to_q = operations.Linear(dim, dim, bias = qkv_bias, device = device, dtype = dtype)
|
||||
self.to_k = operations.Linear(dim, dim, bias = qkv_bias, device = device, dtype = dtype)
|
||||
self.to_v = operations.Linear(dim, dim, bias = qkv_bias, device = device, dtype = dtype)
|
||||
|
||||
if use_fp16:
|
||||
eps = 1.0 / 65504
|
||||
else:
|
||||
eps = 1e-6
|
||||
|
||||
if norm_layer == nn.LayerNorm:
|
||||
norm_layer = operations.LayerNorm
|
||||
else:
|
||||
norm_layer = operations.RMSNorm
|
||||
|
||||
self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps = eps, device = device, dtype = dtype) if qk_norm else nn.Identity()
|
||||
self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps = eps, device = device, dtype = dtype) if qk_norm else nn.Identity()
|
||||
self.out_proj = operations.Linear(dim, dim, device = device, dtype = dtype)
|
||||
|
||||
def forward(self, x):
|
||||
B, N, _ = x.shape
|
||||
|
||||
query = self.to_q(x)
|
||||
key = self.to_k(x)
|
||||
value = self.to_v(x)
|
||||
|
||||
qkv_combined = torch.cat((query, key, value), dim=-1)
|
||||
split_size = qkv_combined.shape[-1] // self.num_heads // 3
|
||||
|
||||
qkv = qkv_combined.view(1, -1, self.num_heads, split_size * 3)
|
||||
query, key, value = torch.split(qkv, split_size, dim=-1)
|
||||
|
||||
query = query.reshape(B, N, self.num_heads, self.head_dim)
|
||||
key = key.reshape(B, N, self.num_heads, self.head_dim)
|
||||
value = value.reshape(B, N, self.num_heads * self.head_dim)
|
||||
|
||||
query = self.q_norm(query)
|
||||
key = self.k_norm(key)
|
||||
|
||||
x = optimized_attention(
|
||||
query.reshape(B, N, self.num_heads * self.head_dim),
|
||||
key.reshape(B, N, self.num_heads * self.head_dim),
|
||||
value,
|
||||
heads=self.num_heads,
|
||||
)
|
||||
|
||||
x = self.out_proj(x)
|
||||
return x
|
||||
|
||||
class HunYuanDiTBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
c_emb_size,
|
||||
num_heads,
|
||||
text_states_dim=1024,
|
||||
qk_norm=False,
|
||||
norm_layer=nn.LayerNorm,
|
||||
qk_norm_layer=True,
|
||||
qkv_bias=True,
|
||||
skip_connection=True,
|
||||
timested_modulate=False,
|
||||
use_moe: bool = False,
|
||||
num_experts: int = 8,
|
||||
moe_top_k: int = 2,
|
||||
use_fp16: bool = False,
|
||||
operations = None,
|
||||
device = None, dtype = None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# eps can't be 1e-6 in fp16 mode because of numerical stability issues
|
||||
if use_fp16:
|
||||
eps = 1.0 / 65504
|
||||
else:
|
||||
eps = 1e-6
|
||||
|
||||
self.norm1 = norm_layer(hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype)
|
||||
|
||||
self.attn1 = Attention(hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm,
|
||||
norm_layer=qk_norm_layer, use_fp16 = use_fp16, device = device, dtype = dtype, operations = operations)
|
||||
|
||||
self.norm2 = norm_layer(hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype)
|
||||
|
||||
self.timested_modulate = timested_modulate
|
||||
if self.timested_modulate:
|
||||
self.default_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(c_emb_size, hidden_size, bias=True, device = device, dtype = dtype)
|
||||
)
|
||||
|
||||
self.attn2 = CrossAttention(hidden_size, text_states_dim, num_heads=num_heads, qkv_bias=qkv_bias,
|
||||
qk_norm=qk_norm, norm_layer=qk_norm_layer, use_fp16 = use_fp16,
|
||||
device = device, dtype = dtype, operations = operations)
|
||||
|
||||
self.norm3 = norm_layer(hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype)
|
||||
|
||||
if skip_connection:
|
||||
self.skip_norm = norm_layer(hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype)
|
||||
self.skip_linear = operations.Linear(2 * hidden_size, hidden_size, device = device, dtype = dtype)
|
||||
else:
|
||||
self.skip_linear = None
|
||||
|
||||
self.use_moe = use_moe
|
||||
|
||||
if self.use_moe:
|
||||
self.moe = MoEBlock(
|
||||
hidden_size,
|
||||
num_experts = num_experts,
|
||||
moe_top_k = moe_top_k,
|
||||
dropout = 0.0,
|
||||
ff_inner_dim = int(hidden_size * 4.0),
|
||||
device = device, dtype = dtype,
|
||||
operations = operations
|
||||
)
|
||||
else:
|
||||
self.mlp = MLP(width=hidden_size, operations=operations, device = device, dtype = dtype)
|
||||
|
||||
def forward(self, hidden_states, conditioning=None, text_states=None, skip_tensor=None):
|
||||
|
||||
if self.skip_linear is not None:
|
||||
combined = torch.cat([skip_tensor, hidden_states], dim=-1)
|
||||
hidden_states = self.skip_linear(combined)
|
||||
hidden_states = self.skip_norm(hidden_states)
|
||||
|
||||
# self attention
|
||||
if self.timested_modulate:
|
||||
modulation_shift = self.default_modulation(conditioning).unsqueeze(dim=1)
|
||||
hidden_states = hidden_states + modulation_shift
|
||||
|
||||
self_attn_out = self.attn1(self.norm1(hidden_states))
|
||||
hidden_states = hidden_states + self_attn_out
|
||||
|
||||
# cross attention
|
||||
hidden_states = hidden_states + self.attn2(self.norm2(hidden_states), text_states)
|
||||
|
||||
# MLP Layer
|
||||
mlp_input = self.norm3(hidden_states)
|
||||
|
||||
if self.use_moe:
|
||||
hidden_states = hidden_states + self.moe(mlp_input)
|
||||
else:
|
||||
hidden_states = hidden_states + self.mlp(mlp_input)
|
||||
|
||||
return hidden_states
|
||||
|
||||
class FinalLayer(nn.Module):
|
||||
|
||||
def __init__(self, final_hidden_size, out_channels, operations, use_fp16: bool = False, device = None, dtype = None):
|
||||
super().__init__()
|
||||
|
||||
if use_fp16:
|
||||
eps = 1.0 / 65504
|
||||
else:
|
||||
eps = 1e-6
|
||||
|
||||
self.norm_final = operations.LayerNorm(final_hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype)
|
||||
self.linear = operations.Linear(final_hidden_size, out_channels, bias = True, device = device, dtype = dtype)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm_final(x)
|
||||
x = x[:, 1:]
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
class HunYuanDiTPlain(nn.Module):
|
||||
|
||||
# init with the defaults values from https://huggingface.co/tencent/Hunyuan3D-2.1/blob/main/hunyuan3d-dit-v2-1/config.yaml
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 64,
|
||||
hidden_size: int = 2048,
|
||||
context_dim: int = 1024,
|
||||
depth: int = 21,
|
||||
num_heads: int = 16,
|
||||
qk_norm: bool = True,
|
||||
qkv_bias: bool = False,
|
||||
num_moe_layers: int = 6,
|
||||
guidance_cond_proj_dim = 2048,
|
||||
norm_type = 'layer',
|
||||
num_experts: int = 8,
|
||||
moe_top_k: int = 2,
|
||||
use_fp16: bool = False,
|
||||
dtype = None,
|
||||
device = None,
|
||||
operations = None,
|
||||
**kwargs
|
||||
):
|
||||
|
||||
self.dtype = dtype
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.depth = depth
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
norm = operations.LayerNorm if norm_type == 'layer' else operations.RMSNorm
|
||||
qk_norm = operations.RMSNorm
|
||||
|
||||
self.context_dim = context_dim
|
||||
self.guidance_cond_proj_dim = guidance_cond_proj_dim
|
||||
|
||||
self.x_embedder = operations.Linear(in_channels, hidden_size, bias = True, device = device, dtype = dtype)
|
||||
self.t_embedder = TimestepEmbedder(hidden_size, hidden_size * 4, cond_proj_dim = guidance_cond_proj_dim, device = device, dtype = dtype, operations = operations)
|
||||
|
||||
|
||||
# HUnYuanDiT Blocks
|
||||
self.blocks = nn.ModuleList([
|
||||
HunYuanDiTBlock(hidden_size=hidden_size,
|
||||
c_emb_size=hidden_size,
|
||||
num_heads=num_heads,
|
||||
text_states_dim=context_dim,
|
||||
qk_norm=qk_norm,
|
||||
norm_layer = norm,
|
||||
qk_norm_layer = qk_norm,
|
||||
skip_connection=layer > depth // 2,
|
||||
qkv_bias=qkv_bias,
|
||||
use_moe=True if depth - layer <= num_moe_layers else False,
|
||||
num_experts=num_experts,
|
||||
moe_top_k=moe_top_k,
|
||||
use_fp16 = use_fp16,
|
||||
device = device, dtype = dtype, operations = operations)
|
||||
for layer in range(depth)
|
||||
])
|
||||
|
||||
self.depth = depth
|
||||
|
||||
self.final_layer = FinalLayer(hidden_size, self.out_channels, use_fp16 = use_fp16, operations = operations, device = device, dtype = dtype)
|
||||
|
||||
def forward(self, x, t, context, transformer_options = {}, **kwargs):
|
||||
|
||||
x = x.movedim(-1, -2)
|
||||
uncond_emb, cond_emb = context.chunk(2, dim = 0)
|
||||
|
||||
context = torch.cat([cond_emb, uncond_emb], dim = 0)
|
||||
main_condition = context
|
||||
|
||||
t = 1.0 - t
|
||||
|
||||
time_embedded = self.t_embedder(t, condition = kwargs.get('guidance_cond'))
|
||||
|
||||
x = x.to(dtype = next(self.x_embedder.parameters()).dtype)
|
||||
x_embedded = self.x_embedder(x)
|
||||
|
||||
combined = torch.cat([time_embedded, x_embedded], dim=1)
|
||||
|
||||
def block_wrap(args):
|
||||
return block(
|
||||
args["x"],
|
||||
args["t"],
|
||||
args["cond"],
|
||||
skip_tensor=args.get("skip"),)
|
||||
|
||||
skip_stack = []
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
for idx, block in enumerate(self.blocks):
|
||||
if idx <= self.depth // 2:
|
||||
skip_input = None
|
||||
else:
|
||||
skip_input = skip_stack.pop()
|
||||
|
||||
if ("block", idx) in blocks_replace:
|
||||
|
||||
combined = blocks_replace[("block", idx)](
|
||||
{
|
||||
"x": combined,
|
||||
"t": time_embedded,
|
||||
"cond": main_condition,
|
||||
"skip": skip_input,
|
||||
},
|
||||
{"original_block": block_wrap},
|
||||
)
|
||||
else:
|
||||
combined = block(combined, time_embedded, main_condition, skip_tensor=skip_input)
|
||||
|
||||
if idx < self.depth // 2:
|
||||
skip_stack.append(combined)
|
||||
|
||||
output = self.final_layer(combined)
|
||||
output = output.movedim(-2, -1) * (-1.0)
|
||||
|
||||
cond_emb, uncond_emb = output.chunk(2, dim = 0)
|
||||
return torch.cat([uncond_emb, cond_emb])
|
||||
@@ -1,6 +1,7 @@
|
||||
#Based on Flux code because of weird hunyuan video code license.
|
||||
|
||||
import torch
|
||||
import comfy.patcher_extension
|
||||
import comfy.ldm.flux.layers
|
||||
import comfy.ldm.modules.diffusionmodules.mmdit
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
@@ -39,6 +40,7 @@ class HunyuanVideoParams:
|
||||
patch_size: list
|
||||
qkv_bias: bool
|
||||
guidance_embed: bool
|
||||
byt5: bool
|
||||
|
||||
|
||||
class SelfAttentionRef(nn.Module):
|
||||
@@ -160,6 +162,30 @@ class TokenRefiner(nn.Module):
|
||||
x = self.individual_token_refiner(x, c, mask)
|
||||
return x
|
||||
|
||||
|
||||
class ByT5Mapper(nn.Module):
|
||||
def __init__(self, in_dim, out_dim, hidden_dim, out_dim1, use_res=False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.layernorm = operations.LayerNorm(in_dim, dtype=dtype, device=device)
|
||||
self.fc1 = operations.Linear(in_dim, hidden_dim, dtype=dtype, device=device)
|
||||
self.fc2 = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device)
|
||||
self.fc3 = operations.Linear(out_dim, out_dim1, dtype=dtype, device=device)
|
||||
self.use_res = use_res
|
||||
self.act_fn = nn.GELU()
|
||||
|
||||
def forward(self, x):
|
||||
if self.use_res:
|
||||
res = x
|
||||
x = self.layernorm(x)
|
||||
x = self.fc1(x)
|
||||
x = self.act_fn(x)
|
||||
x = self.fc2(x)
|
||||
x2 = self.act_fn(x)
|
||||
x2 = self.fc3(x2)
|
||||
if self.use_res:
|
||||
x2 = x2 + res
|
||||
return x2
|
||||
|
||||
class HunyuanVideo(nn.Module):
|
||||
"""
|
||||
Transformer model for flow matching on sequences.
|
||||
@@ -184,9 +210,13 @@ class HunyuanVideo(nn.Module):
|
||||
self.num_heads = params.num_heads
|
||||
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
|
||||
|
||||
self.img_in = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(None, self.patch_size, self.in_channels, self.hidden_size, conv3d=True, dtype=dtype, device=device, operations=operations)
|
||||
self.img_in = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(None, self.patch_size, self.in_channels, self.hidden_size, conv3d=len(self.patch_size) == 3, dtype=dtype, device=device, operations=operations)
|
||||
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
if params.vec_in_dim is not None:
|
||||
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
else:
|
||||
self.vector_in = None
|
||||
|
||||
self.guidance_in = (
|
||||
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity()
|
||||
)
|
||||
@@ -214,6 +244,18 @@ class HunyuanVideo(nn.Module):
|
||||
]
|
||||
)
|
||||
|
||||
if params.byt5:
|
||||
self.byt5_in = ByT5Mapper(
|
||||
in_dim=1472,
|
||||
out_dim=2048,
|
||||
hidden_dim=2048,
|
||||
out_dim1=self.hidden_size,
|
||||
use_res=False,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
else:
|
||||
self.byt5_in = None
|
||||
|
||||
if final_layer:
|
||||
self.final_layer = LastLayer(self.hidden_size, self.patch_size[-1], self.out_channels, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
@@ -225,7 +267,8 @@ class HunyuanVideo(nn.Module):
|
||||
txt_ids: Tensor,
|
||||
txt_mask: Tensor,
|
||||
timesteps: Tensor,
|
||||
y: Tensor,
|
||||
y: Tensor = None,
|
||||
txt_byt5=None,
|
||||
guidance: Tensor = None,
|
||||
guiding_frame_index=None,
|
||||
ref_latent=None,
|
||||
@@ -249,13 +292,17 @@ class HunyuanVideo(nn.Module):
|
||||
|
||||
if guiding_frame_index is not None:
|
||||
token_replace_vec = self.time_in(timestep_embedding(guiding_frame_index, 256, time_factor=1.0))
|
||||
vec_ = self.vector_in(y[:, :self.params.vec_in_dim])
|
||||
vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1)
|
||||
if self.vector_in is not None:
|
||||
vec_ = self.vector_in(y[:, :self.params.vec_in_dim])
|
||||
vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1)
|
||||
else:
|
||||
vec = torch.cat([(token_replace_vec).unsqueeze(1), (vec).unsqueeze(1)], dim=1)
|
||||
frame_tokens = (initial_shape[-1] // self.patch_size[-1]) * (initial_shape[-2] // self.patch_size[-2])
|
||||
modulation_dims = [(0, frame_tokens, 0), (frame_tokens, None, 1)]
|
||||
modulation_dims_txt = [(0, None, 1)]
|
||||
else:
|
||||
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
|
||||
if self.vector_in is not None:
|
||||
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
|
||||
modulation_dims = None
|
||||
modulation_dims_txt = None
|
||||
|
||||
@@ -268,6 +315,12 @@ class HunyuanVideo(nn.Module):
|
||||
|
||||
txt = self.txt_in(txt, timesteps, txt_mask)
|
||||
|
||||
if self.byt5_in is not None and txt_byt5 is not None:
|
||||
txt_byt5 = self.byt5_in(txt_byt5)
|
||||
txt_byt5_ids = torch.zeros((txt_ids.shape[0], txt_byt5.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
|
||||
txt = torch.cat((txt, txt_byt5), dim=1)
|
||||
txt_ids = torch.cat((txt_ids, txt_byt5_ids), dim=1)
|
||||
|
||||
ids = torch.cat((img_ids, txt_ids), dim=1)
|
||||
pe = self.pe_embedder(ids)
|
||||
|
||||
@@ -327,12 +380,16 @@ class HunyuanVideo(nn.Module):
|
||||
|
||||
img = self.final_layer(img, vec, modulation_dims=modulation_dims) # (N, T, patch_size ** 2 * out_channels)
|
||||
|
||||
shape = initial_shape[-3:]
|
||||
shape = initial_shape[-len(self.patch_size):]
|
||||
for i in range(len(shape)):
|
||||
shape[i] = shape[i] // self.patch_size[i]
|
||||
img = img.reshape([img.shape[0]] + shape + [self.out_channels] + self.patch_size)
|
||||
img = img.permute(0, 4, 1, 5, 2, 6, 3, 7)
|
||||
img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
|
||||
if img.ndim == 8:
|
||||
img = img.permute(0, 4, 1, 5, 2, 6, 3, 7)
|
||||
img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
|
||||
else:
|
||||
img = img.permute(0, 3, 1, 4, 2, 5)
|
||||
img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3])
|
||||
return img
|
||||
|
||||
def img_ids(self, x):
|
||||
@@ -347,9 +404,30 @@ class HunyuanVideo(nn.Module):
|
||||
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
|
||||
return repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
|
||||
|
||||
def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
|
||||
bs, c, t, h, w = x.shape
|
||||
img_ids = self.img_ids(x)
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, ref_latent, control=control, transformer_options=transformer_options)
|
||||
def img_ids_2d(self, x):
|
||||
bs, c, h, w = x.shape
|
||||
patch_size = self.patch_size
|
||||
h_len = ((h + (patch_size[0] // 2)) // patch_size[0])
|
||||
w_len = ((w + (patch_size[1] // 2)) // patch_size[1])
|
||||
img_ids = torch.zeros((h_len, w_len, 2), device=x.device, dtype=x.dtype)
|
||||
img_ids[:, :, 0] = img_ids[:, :, 0] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
|
||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
|
||||
return repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
def forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, y, txt_byt5, guidance, attention_mask, guiding_frame_index, ref_latent, control, transformer_options, **kwargs)
|
||||
|
||||
def _forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
|
||||
bs = x.shape[0]
|
||||
if len(self.patch_size) == 3:
|
||||
img_ids = self.img_ids(x)
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||
else:
|
||||
img_ids = self.img_ids_2d(x)
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 2), device=x.device, dtype=x.dtype)
|
||||
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, guidance, guiding_frame_index, ref_latent, control=control, transformer_options=transformer_options)
|
||||
return out
|
||||
|
||||
@@ -0,0 +1,136 @@
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
|
||||
class PixelShuffle2D(nn.Module):
|
||||
def __init__(self, in_dim, out_dim, op=ops.Conv2d):
|
||||
super().__init__()
|
||||
self.conv = op(in_dim, out_dim >> 2, 3, 1, 1)
|
||||
self.ratio = (in_dim << 2) // out_dim
|
||||
|
||||
def forward(self, x):
|
||||
b, c, h, w = x.shape
|
||||
h2, w2 = h >> 1, w >> 1
|
||||
y = self.conv(x).view(b, -1, h2, 2, w2, 2).permute(0, 3, 5, 1, 2, 4).reshape(b, -1, h2, w2)
|
||||
r = x.view(b, c, h2, 2, w2, 2).permute(0, 3, 5, 1, 2, 4).reshape(b, c << 2, h2, w2)
|
||||
return y + r.view(b, y.shape[1], self.ratio, h2, w2).mean(2)
|
||||
|
||||
|
||||
class PixelUnshuffle2D(nn.Module):
|
||||
def __init__(self, in_dim, out_dim, op=ops.Conv2d):
|
||||
super().__init__()
|
||||
self.conv = op(in_dim, out_dim << 2, 3, 1, 1)
|
||||
self.scale = (out_dim << 2) // in_dim
|
||||
|
||||
def forward(self, x):
|
||||
b, c, h, w = x.shape
|
||||
h2, w2 = h << 1, w << 1
|
||||
y = self.conv(x).view(b, 2, 2, -1, h, w).permute(0, 3, 4, 1, 5, 2).reshape(b, -1, h2, w2)
|
||||
r = x.repeat_interleave(self.scale, 1).view(b, 2, 2, -1, h, w).permute(0, 3, 4, 1, 5, 2).reshape(b, -1, h2, w2)
|
||||
return y + r
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
|
||||
ffactor_spatial, downsample_match_channel=True, **_):
|
||||
super().__init__()
|
||||
self.z_channels = z_channels
|
||||
self.block_out_channels = block_out_channels
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.conv_in = ops.Conv2d(in_channels, block_out_channels[0], 3, 1, 1)
|
||||
|
||||
self.down = nn.ModuleList()
|
||||
ch = block_out_channels[0]
|
||||
depth = (ffactor_spatial >> 1).bit_length()
|
||||
|
||||
for i, tgt in enumerate(block_out_channels):
|
||||
stage = nn.Module()
|
||||
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
|
||||
out_channels=tgt,
|
||||
temb_channels=0,
|
||||
conv_op=ops.Conv2d)
|
||||
for j in range(num_res_blocks)])
|
||||
ch = tgt
|
||||
if i < depth:
|
||||
nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and downsample_match_channel else ch
|
||||
stage.downsample = PixelShuffle2D(ch, nxt, ops.Conv2d)
|
||||
ch = nxt
|
||||
self.down.append(stage)
|
||||
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d)
|
||||
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv2d)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d)
|
||||
|
||||
self.norm_out = ops.GroupNorm(32, ch, 1e-6, True)
|
||||
self.conv_out = ops.Conv2d(ch, z_channels << 1, 3, 1, 1)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv_in(x)
|
||||
|
||||
for stage in self.down:
|
||||
for blk in stage.block:
|
||||
x = blk(x)
|
||||
if hasattr(stage, 'downsample'):
|
||||
x = stage.downsample(x)
|
||||
|
||||
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
|
||||
|
||||
b, c, h, w = x.shape
|
||||
grp = c // (self.z_channels << 1)
|
||||
skip = x.view(b, c // grp, grp, h, w).mean(2)
|
||||
|
||||
return self.conv_out(F.silu(self.norm_out(x))) + skip
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks,
|
||||
ffactor_spatial, upsample_match_channel=True, **_):
|
||||
super().__init__()
|
||||
block_out_channels = block_out_channels[::-1]
|
||||
self.z_channels = z_channels
|
||||
self.block_out_channels = block_out_channels
|
||||
self.num_res_blocks = num_res_blocks
|
||||
|
||||
ch = block_out_channels[0]
|
||||
self.conv_in = ops.Conv2d(z_channels, ch, 3, 1, 1)
|
||||
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d)
|
||||
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv2d)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d)
|
||||
|
||||
self.up = nn.ModuleList()
|
||||
depth = (ffactor_spatial >> 1).bit_length()
|
||||
|
||||
for i, tgt in enumerate(block_out_channels):
|
||||
stage = nn.Module()
|
||||
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
|
||||
out_channels=tgt,
|
||||
temb_channels=0,
|
||||
conv_op=ops.Conv2d)
|
||||
for j in range(num_res_blocks + 1)])
|
||||
ch = tgt
|
||||
if i < depth:
|
||||
nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and upsample_match_channel else ch
|
||||
stage.upsample = PixelUnshuffle2D(ch, nxt, ops.Conv2d)
|
||||
ch = nxt
|
||||
self.up.append(stage)
|
||||
|
||||
self.norm_out = ops.GroupNorm(32, ch, 1e-6, True)
|
||||
self.conv_out = ops.Conv2d(ch, out_channels, 3, 1, 1)
|
||||
|
||||
def forward(self, z):
|
||||
x = self.conv_in(z) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
|
||||
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
|
||||
|
||||
for stage in self.up:
|
||||
for blk in stage.block:
|
||||
x = blk(x)
|
||||
if hasattr(stage, 'upsample'):
|
||||
x = stage.upsample(x)
|
||||
|
||||
return self.conv_out(F.silu(self.norm_out(x)))
|
||||
@@ -1,5 +1,6 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import comfy.patcher_extension
|
||||
import comfy.ldm.modules.attention
|
||||
import comfy.ldm.common_dit
|
||||
from einops import rearrange
|
||||
@@ -420,6 +421,13 @@ class LTXVModel(torch.nn.Module):
|
||||
self.patchifier = SymmetricPatchifier(1)
|
||||
|
||||
def forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, attention_mask, frame_rate, transformer_options, keyframe_idxs, **kwargs)
|
||||
|
||||
def _forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs):
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
|
||||
orig_shape = list(x.shape)
|
||||
|
||||
@@ -11,6 +11,7 @@ import comfy.ldm.common_dit
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder
|
||||
from comfy.ldm.modules.attention import optimized_attention_masked
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
import comfy.patcher_extension
|
||||
|
||||
|
||||
def modulate(x, scale):
|
||||
@@ -590,8 +591,15 @@ class NextDiT(nn.Module):
|
||||
|
||||
return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis
|
||||
|
||||
# def forward(self, x, t, cap_feats, cap_mask):
|
||||
def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
|
||||
).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs)
|
||||
|
||||
# def forward(self, x, t, cap_feats, cap_mask):
|
||||
def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
|
||||
t = 1.0 - timesteps
|
||||
cap_feats = context
|
||||
cap_mask = attention_mask
|
||||
|
||||
@@ -448,7 +448,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
||||
mask = mask.unsqueeze(1)
|
||||
|
||||
if SDP_BATCH_LIMIT >= b:
|
||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
||||
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
||||
if not skip_output_reshape:
|
||||
out = (
|
||||
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
||||
@@ -461,7 +461,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
||||
if mask.shape[0] > 1:
|
||||
m = mask[i : i + SDP_BATCH_LIMIT]
|
||||
|
||||
out[i : i + SDP_BATCH_LIMIT] = torch.nn.functional.scaled_dot_product_attention(
|
||||
out[i : i + SDP_BATCH_LIMIT] = comfy.ops.scaled_dot_product_attention(
|
||||
q[i : i + SDP_BATCH_LIMIT],
|
||||
k[i : i + SDP_BATCH_LIMIT],
|
||||
v[i : i + SDP_BATCH_LIMIT],
|
||||
|
||||
@@ -109,7 +109,7 @@ class PatchEmbed(nn.Module):
|
||||
def modulate(x, shift, scale):
|
||||
if shift is None:
|
||||
shift = torch.zeros_like(scale)
|
||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
return torch.addcmul(shift.unsqueeze(1), x, 1+ scale.unsqueeze(1))
|
||||
|
||||
|
||||
#################################################################################
|
||||
@@ -564,10 +564,7 @@ class DismantledBlock(nn.Module):
|
||||
assert not self.pre_only
|
||||
attn1 = self.attn.post_attention(attn)
|
||||
attn2 = self.attn2.post_attention(attn2)
|
||||
out1 = gate_msa.unsqueeze(1) * attn1
|
||||
out2 = gate_msa2.unsqueeze(1) * attn2
|
||||
x = x + out1
|
||||
x = x + out2
|
||||
x = gate_cat(x, gate_msa, gate_msa2, attn1, attn2)
|
||||
x = x + gate_mlp.unsqueeze(1) * self.mlp(
|
||||
modulate(self.norm2(x), shift_mlp, scale_mlp)
|
||||
)
|
||||
@@ -594,6 +591,11 @@ class DismantledBlock(nn.Module):
|
||||
)
|
||||
return self.post_attention(attn, *intermediates)
|
||||
|
||||
def gate_cat(x, gate_msa, gate_msa2, attn1, attn2):
|
||||
out1 = gate_msa.unsqueeze(1) * attn1
|
||||
out2 = gate_msa2.unsqueeze(1) * attn2
|
||||
x = torch.stack([x, out1, out2], dim=0).sum(dim=0)
|
||||
return x
|
||||
|
||||
def block_mixing(*args, use_checkpoint=True, **kwargs):
|
||||
if use_checkpoint:
|
||||
|
||||
@@ -145,7 +145,7 @@ class Downsample(nn.Module):
|
||||
|
||||
class ResnetBlock(nn.Module):
|
||||
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
|
||||
dropout, temb_channels=512, conv_op=ops.Conv2d):
|
||||
dropout=0.0, temb_channels=512, conv_op=ops.Conv2d):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
@@ -183,7 +183,7 @@ class ResnetBlock(nn.Module):
|
||||
stride=1,
|
||||
padding=0)
|
||||
|
||||
def forward(self, x, temb):
|
||||
def forward(self, x, temb=None):
|
||||
h = x
|
||||
h = self.norm1(h)
|
||||
h = self.swish(h)
|
||||
@@ -285,7 +285,7 @@ def pytorch_attention(q, k, v):
|
||||
)
|
||||
|
||||
try:
|
||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
|
||||
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
|
||||
out = out.transpose(2, 3).reshape(orig_shape)
|
||||
except model_management.OOM_EXCEPTION:
|
||||
logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")
|
||||
|
||||
@@ -0,0 +1,77 @@
|
||||
import torch
|
||||
import math
|
||||
|
||||
from .model import QwenImageTransformer2DModel
|
||||
|
||||
|
||||
class QwenImageControlNetModel(QwenImageTransformer2DModel):
|
||||
def __init__(
|
||||
self,
|
||||
extra_condition_channels=0,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
|
||||
self.main_model_double = 60
|
||||
|
||||
# controlnet_blocks
|
||||
self.controlnet_blocks = torch.nn.ModuleList([])
|
||||
for _ in range(len(self.transformer_blocks)):
|
||||
self.controlnet_blocks.append(operations.Linear(self.inner_dim, self.inner_dim, device=device, dtype=dtype))
|
||||
self.controlnet_x_embedder = operations.Linear(self.in_channels + extra_condition_channels, self.inner_dim, device=device, dtype=dtype)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
timesteps,
|
||||
context,
|
||||
attention_mask=None,
|
||||
guidance: torch.Tensor = None,
|
||||
ref_latents=None,
|
||||
hint=None,
|
||||
transformer_options={},
|
||||
**kwargs
|
||||
):
|
||||
timestep = timesteps
|
||||
encoder_hidden_states = context
|
||||
encoder_hidden_states_mask = attention_mask
|
||||
|
||||
hidden_states, img_ids, orig_shape = self.process_img(x)
|
||||
hint, _, _ = self.process_img(hint)
|
||||
|
||||
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
|
||||
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
|
||||
del ids, txt_ids, img_ids
|
||||
|
||||
hidden_states = self.img_in(hidden_states) + self.controlnet_x_embedder(hint)
|
||||
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
|
||||
encoder_hidden_states = self.txt_in(encoder_hidden_states)
|
||||
|
||||
if guidance is not None:
|
||||
guidance = guidance * 1000
|
||||
|
||||
temb = (
|
||||
self.time_text_embed(timestep, hidden_states)
|
||||
if guidance is None
|
||||
else self.time_text_embed(timestep, guidance, hidden_states)
|
||||
)
|
||||
|
||||
repeat = math.ceil(self.main_model_double / len(self.controlnet_blocks))
|
||||
|
||||
controlnet_block_samples = ()
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
|
||||
controlnet_block_samples = controlnet_block_samples + (self.controlnet_blocks[i](hidden_states),) * repeat
|
||||
|
||||
return {"input": controlnet_block_samples[:self.main_model_double]}
|
||||
+103
-34
@@ -9,6 +9,7 @@ from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
|
||||
from comfy.ldm.modules.attention import optimized_attention_masked
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.patcher_extension
|
||||
|
||||
class GELU(nn.Module):
|
||||
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None):
|
||||
@@ -214,9 +215,9 @@ class QwenImageTransformerBlock(nn.Module):
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
def _modulate(self, x, mod_params):
|
||||
shift, scale, gate = mod_params.chunk(3, dim=-1)
|
||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)
|
||||
def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
shift, scale, gate = torch.chunk(mod_params, 3, dim=-1)
|
||||
return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)), gate.unsqueeze(1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -248,11 +249,11 @@ class QwenImageTransformerBlock(nn.Module):
|
||||
|
||||
img_normed2 = self.img_norm2(hidden_states)
|
||||
img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
|
||||
hidden_states = hidden_states + img_gate2 * self.img_mlp(img_modulated2)
|
||||
hidden_states = torch.addcmul(hidden_states, img_gate2, self.img_mlp(img_modulated2))
|
||||
|
||||
txt_normed2 = self.txt_norm2(encoder_hidden_states)
|
||||
txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2)
|
||||
encoder_hidden_states = encoder_hidden_states + txt_gate2 * self.txt_mlp(txt_modulated2)
|
||||
encoder_hidden_states = torch.addcmul(encoder_hidden_states, txt_gate2, self.txt_mlp(txt_modulated2))
|
||||
|
||||
return encoder_hidden_states, hidden_states
|
||||
|
||||
@@ -275,7 +276,7 @@ class LastLayer(nn.Module):
|
||||
def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
|
||||
emb = self.linear(self.silu(conditioning_embedding))
|
||||
scale, shift = torch.chunk(emb, 2, dim=1)
|
||||
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
|
||||
x = torch.addcmul(shift[:, None, :], self.norm(x), (1 + scale)[:, None, :])
|
||||
return x
|
||||
|
||||
|
||||
@@ -293,6 +294,7 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
guidance_embeds: bool = False,
|
||||
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
|
||||
image_model=None,
|
||||
final_layer=True,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
@@ -300,6 +302,7 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
self.patch_size = patch_size
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels or in_channels
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
@@ -329,46 +332,86 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
|
||||
self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations)
|
||||
self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device)
|
||||
self.gradient_checkpointing = False
|
||||
if final_layer:
|
||||
self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations)
|
||||
self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def pos_embeds(self, x, context):
|
||||
def process_img(self, x, index=0, h_offset=0, w_offset=0):
|
||||
bs, c, t, h, w = x.shape
|
||||
patch_size = self.patch_size
|
||||
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size))
|
||||
orig_shape = hidden_states.shape
|
||||
hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
|
||||
hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5)
|
||||
hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
|
||||
h_len = ((h + (patch_size // 2)) // patch_size)
|
||||
w_len = ((w + (patch_size // 2)) // patch_size)
|
||||
|
||||
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
|
||||
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
h_offset = ((h_offset + (patch_size // 2)) // patch_size)
|
||||
w_offset = ((w_offset + (patch_size // 2)) // patch_size)
|
||||
|
||||
txt_start = round(max(h_len, w_len))
|
||||
txt_ids = torch.linspace(txt_start, txt_start + context.shape[1], steps=context.shape[1], device=x.device, dtype=x.dtype).reshape(1, -1, 1).repeat(bs, 1, 3)
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
return self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
|
||||
img_ids = torch.zeros((h_len, w_len, 3), device=x.device)
|
||||
img_ids[:, :, 0] = img_ids[:, :, 1] + index
|
||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) - (h_len // 2)
|
||||
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) - (w_len // 2)
|
||||
return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape
|
||||
|
||||
def forward(
|
||||
def forward(self, x, timestep, context, attention_mask=None, guidance=None, ref_latents=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, attention_mask, guidance, ref_latents, transformer_options, **kwargs)
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
x,
|
||||
timesteps,
|
||||
context,
|
||||
attention_mask=None,
|
||||
guidance: torch.Tensor = None,
|
||||
ref_latents=None,
|
||||
transformer_options={},
|
||||
control=None,
|
||||
**kwargs
|
||||
):
|
||||
timestep = timesteps
|
||||
encoder_hidden_states = context
|
||||
encoder_hidden_states_mask = attention_mask
|
||||
|
||||
image_rotary_emb = self.pos_embeds(x, context)
|
||||
hidden_states, img_ids, orig_shape = self.process_img(x)
|
||||
num_embeds = hidden_states.shape[1]
|
||||
|
||||
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size))
|
||||
orig_shape = hidden_states.shape
|
||||
hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
|
||||
hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5)
|
||||
hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
|
||||
if ref_latents is not None:
|
||||
h = 0
|
||||
w = 0
|
||||
index = 0
|
||||
index_ref_method = kwargs.get("ref_latents_method", "index") == "index"
|
||||
for ref in ref_latents:
|
||||
if index_ref_method:
|
||||
index += 1
|
||||
h_offset = 0
|
||||
w_offset = 0
|
||||
else:
|
||||
index = 1
|
||||
h_offset = 0
|
||||
w_offset = 0
|
||||
if ref.shape[-2] + h > ref.shape[-1] + w:
|
||||
w_offset = w
|
||||
else:
|
||||
h_offset = h
|
||||
h = max(h, ref.shape[-2] + h_offset)
|
||||
w = max(w, ref.shape[-1] + w_offset)
|
||||
|
||||
kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
|
||||
hidden_states = torch.cat([hidden_states, kontext], dim=1)
|
||||
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
|
||||
|
||||
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
|
||||
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
|
||||
del ids, txt_ids, img_ids
|
||||
|
||||
hidden_states = self.img_in(hidden_states)
|
||||
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
|
||||
@@ -383,18 +426,44 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
else self.time_text_embed(timestep, guidance, hidden_states)
|
||||
)
|
||||
|
||||
for block in self.transformer_blocks:
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
patches = transformer_options.get("patches", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"])
|
||||
return out
|
||||
out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb}, {"original_block": block_wrap})
|
||||
hidden_states = out["img"]
|
||||
encoder_hidden_states = out["txt"]
|
||||
else:
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
|
||||
if "double_block" in patches:
|
||||
for p in patches["double_block"]:
|
||||
out = p({"img": hidden_states, "txt": encoder_hidden_states, "x": x, "block_index": i})
|
||||
hidden_states = out["img"]
|
||||
encoder_hidden_states = out["txt"]
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_i = control.get("input")
|
||||
if i < len(control_i):
|
||||
add = control_i[i]
|
||||
if add is not None:
|
||||
hidden_states[:, :add.shape[1]] += add
|
||||
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
hidden_states = hidden_states.view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
|
||||
hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
|
||||
hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5)
|
||||
return hidden_states.reshape(orig_shape)[:, :, :, :x.shape[-2], :x.shape[-1]]
|
||||
|
||||
+533
-21
@@ -4,13 +4,14 @@ import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import repeat
|
||||
from einops import rearrange
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
from comfy.ldm.flux.math import apply_rope
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
|
||||
|
||||
def sinusoidal_embedding_1d(dim, position):
|
||||
@@ -148,11 +149,14 @@ WAN_CROSSATTENTION_CLASSES = {
|
||||
|
||||
def repeat_e(e, x):
|
||||
repeats = 1
|
||||
if e.shape[1] > 1:
|
||||
repeats = x.shape[1] // e.shape[1]
|
||||
if e.size(1) > 1:
|
||||
repeats = x.size(1) // e.size(1)
|
||||
if repeats == 1:
|
||||
return e
|
||||
return torch.repeat_interleave(e, repeats, dim=1)
|
||||
if repeats * e.size(1) == x.size(1):
|
||||
return torch.repeat_interleave(e, repeats, dim=1)
|
||||
else:
|
||||
return torch.repeat_interleave(e, repeats + 1, dim=1)[:, :x.size(1)]
|
||||
|
||||
|
||||
class WanAttentionBlock(nn.Module):
|
||||
@@ -219,15 +223,15 @@ class WanAttentionBlock(nn.Module):
|
||||
|
||||
# self-attention
|
||||
y = self.self_attn(
|
||||
self.norm1(x) * (1 + repeat_e(e[1], x)) + repeat_e(e[0], x),
|
||||
torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
|
||||
freqs)
|
||||
|
||||
x = x + y * repeat_e(e[2], x)
|
||||
x = torch.addcmul(x, y, repeat_e(e[2], x))
|
||||
|
||||
# cross-attention & ffn
|
||||
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len)
|
||||
y = self.ffn(self.norm2(x) * (1 + repeat_e(e[4], x)) + repeat_e(e[3], x))
|
||||
x = x + y * repeat_e(e[5], x)
|
||||
y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
|
||||
x = torch.addcmul(x, y, repeat_e(e[5], x))
|
||||
return x
|
||||
|
||||
|
||||
@@ -342,7 +346,7 @@ class Head(nn.Module):
|
||||
else:
|
||||
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e.unsqueeze(2)).unbind(2)
|
||||
|
||||
x = (self.head(self.norm(x) * (1 + repeat_e(e[1], x)) + repeat_e(e[0], x)))
|
||||
x = (self.head(torch.addcmul(repeat_e(e[0], x), self.norm(x), 1 + repeat_e(e[1], x))))
|
||||
return x
|
||||
|
||||
|
||||
@@ -391,6 +395,7 @@ class WanModel(torch.nn.Module):
|
||||
cross_attn_norm=True,
|
||||
eps=1e-6,
|
||||
flf_pos_embed_token_number=None,
|
||||
in_dim_ref_conv=None,
|
||||
image_model=None,
|
||||
device=None,
|
||||
dtype=None,
|
||||
@@ -484,6 +489,11 @@ class WanModel(torch.nn.Module):
|
||||
else:
|
||||
self.img_emb = None
|
||||
|
||||
if in_dim_ref_conv is not None:
|
||||
self.ref_conv = operations.Conv2d(in_dim_ref_conv, dim, kernel_size=patch_size[1:], stride=patch_size[1:], device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
else:
|
||||
self.ref_conv = None
|
||||
|
||||
def forward_orig(
|
||||
self,
|
||||
x,
|
||||
@@ -526,6 +536,13 @@ class WanModel(torch.nn.Module):
|
||||
e = e.reshape(t.shape[0], -1, e.shape[-1])
|
||||
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
|
||||
|
||||
full_ref = None
|
||||
if self.ref_conv is not None:
|
||||
full_ref = kwargs.get("reference_latent", None)
|
||||
if full_ref is not None:
|
||||
full_ref = self.ref_conv(full_ref).flatten(2).transpose(1, 2)
|
||||
x = torch.concat((full_ref, x), dim=1)
|
||||
|
||||
# context
|
||||
context = self.text_embedding(context)
|
||||
|
||||
@@ -552,31 +569,56 @@ class WanModel(torch.nn.Module):
|
||||
# head
|
||||
x = self.head(x, e)
|
||||
|
||||
if full_ref is not None:
|
||||
x = x[:, full_ref.shape[1]:]
|
||||
|
||||
# unpatchify
|
||||
x = self.unpatchify(x, grid_sizes)
|
||||
return x
|
||||
|
||||
def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
|
||||
bs, c, t, h, w = x.shape
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
|
||||
|
||||
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None):
|
||||
patch_size = self.patch_size
|
||||
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
|
||||
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
|
||||
w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
|
||||
|
||||
if steps_t is None:
|
||||
steps_t = t_len
|
||||
if steps_h is None:
|
||||
steps_h = h_len
|
||||
if steps_w is None:
|
||||
steps_w = w_len
|
||||
|
||||
img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype)
|
||||
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1)
|
||||
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1)
|
||||
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1)
|
||||
img_ids = img_ids.reshape(1, -1, img_ids.shape[-1])
|
||||
|
||||
freqs = self.rope_embedder(img_ids).movedim(1, 2)
|
||||
return freqs
|
||||
|
||||
def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, clip_fea, time_dim_concat, transformer_options, **kwargs)
|
||||
|
||||
def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
|
||||
bs, c, t, h, w = x.shape
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
|
||||
|
||||
t_len = t
|
||||
if time_dim_concat is not None:
|
||||
time_dim_concat = comfy.ldm.common_dit.pad_to_patch_size(time_dim_concat, self.patch_size)
|
||||
x = torch.cat([x, time_dim_concat], dim=2)
|
||||
t_len = ((x.shape[2] + (patch_size[0] // 2)) // patch_size[0])
|
||||
t_len = x.shape[2]
|
||||
|
||||
img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
||||
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
|
||||
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
|
||||
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
|
||||
img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
|
||||
if self.ref_conv is not None and "reference_latent" in kwargs:
|
||||
t_len += 1
|
||||
|
||||
freqs = self.rope_embedder(img_ids).movedim(1, 2)
|
||||
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype)
|
||||
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, **kwargs)[:, :, :t, :h, :w]
|
||||
|
||||
def unpatchify(self, x, grid_sizes):
|
||||
@@ -749,7 +791,12 @@ class CameraWanModel(WanModel):
|
||||
operations=None,
|
||||
):
|
||||
|
||||
super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
|
||||
if model_type == 'camera':
|
||||
model_type = 'i2v'
|
||||
else:
|
||||
model_type = 't2v'
|
||||
|
||||
super().__init__(model_type=model_type, patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
|
||||
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
|
||||
|
||||
self.control_adapter = WanCamAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:], operation_settings=operation_settings)
|
||||
@@ -807,3 +854,468 @@ class CameraWanModel(WanModel):
|
||||
# unpatchify
|
||||
x = self.unpatchify(x, grid_sizes)
|
||||
return x
|
||||
|
||||
|
||||
class CausalConv1d(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
chan_in,
|
||||
chan_out,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
dilation=1,
|
||||
pad_mode='replicate',
|
||||
operations=None,
|
||||
**kwargs):
|
||||
super().__init__()
|
||||
|
||||
self.pad_mode = pad_mode
|
||||
padding = (kernel_size - 1, 0) # T
|
||||
self.time_causal_padding = padding
|
||||
|
||||
self.conv = operations.Conv1d(
|
||||
chan_in,
|
||||
chan_out,
|
||||
kernel_size,
|
||||
stride=stride,
|
||||
dilation=dilation,
|
||||
**kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
x = torch.nn.functional.pad(x, self.time_causal_padding, mode=self.pad_mode)
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class MotionEncoder_tc(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
in_dim: int,
|
||||
hidden_dim: int,
|
||||
num_heads=int,
|
||||
need_global=True,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,):
|
||||
factory_kwargs = {"dtype": dtype, "device": device}
|
||||
super().__init__()
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.need_global = need_global
|
||||
self.conv1_local = CausalConv1d(in_dim, hidden_dim // 4 * num_heads, 3, stride=1, operations=operations, **factory_kwargs)
|
||||
if need_global:
|
||||
self.conv1_global = CausalConv1d(
|
||||
in_dim, hidden_dim // 4, 3, stride=1, operations=operations, **factory_kwargs)
|
||||
self.norm1 = operations.LayerNorm(
|
||||
hidden_dim // 4,
|
||||
elementwise_affine=False,
|
||||
eps=1e-6,
|
||||
**factory_kwargs)
|
||||
self.act = nn.SiLU()
|
||||
self.conv2 = CausalConv1d(hidden_dim // 4, hidden_dim // 2, 3, stride=2, operations=operations, **factory_kwargs)
|
||||
self.conv3 = CausalConv1d(hidden_dim // 2, hidden_dim, 3, stride=2, operations=operations, **factory_kwargs)
|
||||
|
||||
if need_global:
|
||||
self.final_linear = operations.Linear(hidden_dim, hidden_dim, **factory_kwargs)
|
||||
|
||||
self.norm1 = operations.LayerNorm(
|
||||
hidden_dim // 4,
|
||||
elementwise_affine=False,
|
||||
eps=1e-6,
|
||||
**factory_kwargs)
|
||||
|
||||
self.norm2 = operations.LayerNorm(
|
||||
hidden_dim // 2,
|
||||
elementwise_affine=False,
|
||||
eps=1e-6,
|
||||
**factory_kwargs)
|
||||
|
||||
self.norm3 = operations.LayerNorm(
|
||||
hidden_dim, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
||||
|
||||
self.padding_tokens = nn.Parameter(torch.empty(1, 1, 1, hidden_dim, **factory_kwargs))
|
||||
|
||||
def forward(self, x):
|
||||
x = rearrange(x, 'b t c -> b c t')
|
||||
x_ori = x.clone()
|
||||
b, c, t = x.shape
|
||||
x = self.conv1_local(x)
|
||||
x = rearrange(x, 'b (n c) t -> (b n) t c', n=self.num_heads)
|
||||
x = self.norm1(x)
|
||||
x = self.act(x)
|
||||
x = rearrange(x, 'b t c -> b c t')
|
||||
x = self.conv2(x)
|
||||
x = rearrange(x, 'b c t -> b t c')
|
||||
x = self.norm2(x)
|
||||
x = self.act(x)
|
||||
x = rearrange(x, 'b t c -> b c t')
|
||||
x = self.conv3(x)
|
||||
x = rearrange(x, 'b c t -> b t c')
|
||||
x = self.norm3(x)
|
||||
x = self.act(x)
|
||||
x = rearrange(x, '(b n) t c -> b t n c', b=b)
|
||||
padding = comfy.model_management.cast_to(self.padding_tokens, dtype=x.dtype, device=x.device).repeat(b, x.shape[1], 1, 1)
|
||||
x = torch.cat([x, padding], dim=-2)
|
||||
x_local = x.clone()
|
||||
|
||||
if not self.need_global:
|
||||
return x_local
|
||||
|
||||
x = self.conv1_global(x_ori)
|
||||
x = rearrange(x, 'b c t -> b t c')
|
||||
x = self.norm1(x)
|
||||
x = self.act(x)
|
||||
x = rearrange(x, 'b t c -> b c t')
|
||||
x = self.conv2(x)
|
||||
x = rearrange(x, 'b c t -> b t c')
|
||||
x = self.norm2(x)
|
||||
x = self.act(x)
|
||||
x = rearrange(x, 'b t c -> b c t')
|
||||
x = self.conv3(x)
|
||||
x = rearrange(x, 'b c t -> b t c')
|
||||
x = self.norm3(x)
|
||||
x = self.act(x)
|
||||
x = self.final_linear(x)
|
||||
x = rearrange(x, '(b n) t c -> b t n c', b=b)
|
||||
|
||||
return x, x_local
|
||||
|
||||
|
||||
class CausalAudioEncoder(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
dim=5120,
|
||||
num_layers=25,
|
||||
out_dim=2048,
|
||||
video_rate=8,
|
||||
num_token=4,
|
||||
need_global=False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None):
|
||||
super().__init__()
|
||||
self.encoder = MotionEncoder_tc(
|
||||
in_dim=dim,
|
||||
hidden_dim=out_dim,
|
||||
num_heads=num_token,
|
||||
need_global=need_global, dtype=dtype, device=device, operations=operations)
|
||||
weight = torch.empty((1, num_layers, 1, 1), dtype=dtype, device=device)
|
||||
|
||||
self.weights = torch.nn.Parameter(weight)
|
||||
self.act = torch.nn.SiLU()
|
||||
|
||||
def forward(self, features):
|
||||
# features B * num_layers * dim * video_length
|
||||
weights = self.act(comfy.model_management.cast_to(self.weights, dtype=features.dtype, device=features.device))
|
||||
weights_sum = weights.sum(dim=1, keepdims=True)
|
||||
weighted_feat = ((features * weights) / weights_sum).sum(
|
||||
dim=1) # b dim f
|
||||
weighted_feat = weighted_feat.permute(0, 2, 1) # b f dim
|
||||
res = self.encoder(weighted_feat) # b f n dim
|
||||
return res # b f n dim
|
||||
|
||||
|
||||
class AdaLayerNorm(nn.Module):
|
||||
def __init__(self, embedding_dim, output_dim=None, norm_elementwise_affine=False, norm_eps=1e-5, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
output_dim = output_dim or embedding_dim * 2
|
||||
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = operations.Linear(embedding_dim, output_dim, dtype=dtype, device=device)
|
||||
self.norm = operations.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x, temb):
|
||||
temb = self.linear(self.silu(temb))
|
||||
shift, scale = temb.chunk(2, dim=1)
|
||||
shift = shift[:, None, :]
|
||||
scale = scale[:, None, :]
|
||||
x = self.norm(x) * (1 + scale) + shift
|
||||
return x
|
||||
|
||||
|
||||
class AudioInjector_WAN(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
dim=2048,
|
||||
num_heads=32,
|
||||
inject_layer=[0, 27],
|
||||
root_net=None,
|
||||
enable_adain=False,
|
||||
adain_dim=2048,
|
||||
adain_mode=None,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None):
|
||||
super().__init__()
|
||||
self.enable_adain = enable_adain
|
||||
self.adain_mode = adain_mode
|
||||
self.injected_block_id = {}
|
||||
audio_injector_id = 0
|
||||
for inject_id in inject_layer:
|
||||
self.injected_block_id[inject_id] = audio_injector_id
|
||||
audio_injector_id += 1
|
||||
|
||||
self.injector = nn.ModuleList([
|
||||
WanT2VCrossAttention(
|
||||
dim=dim,
|
||||
num_heads=num_heads,
|
||||
qk_norm=True, operation_settings={"operations": operations, "device": device, "dtype": dtype}
|
||||
) for _ in range(audio_injector_id)
|
||||
])
|
||||
self.injector_pre_norm_feat = nn.ModuleList([
|
||||
operations.LayerNorm(
|
||||
dim,
|
||||
elementwise_affine=False,
|
||||
eps=1e-6, dtype=dtype, device=device
|
||||
) for _ in range(audio_injector_id)
|
||||
])
|
||||
self.injector_pre_norm_vec = nn.ModuleList([
|
||||
operations.LayerNorm(
|
||||
dim,
|
||||
elementwise_affine=False,
|
||||
eps=1e-6, dtype=dtype, device=device
|
||||
) for _ in range(audio_injector_id)
|
||||
])
|
||||
if enable_adain:
|
||||
self.injector_adain_layers = nn.ModuleList([
|
||||
AdaLayerNorm(
|
||||
output_dim=dim * 2, embedding_dim=adain_dim, dtype=dtype, device=device, operations=operations)
|
||||
for _ in range(audio_injector_id)
|
||||
])
|
||||
if adain_mode != "attn_norm":
|
||||
self.injector_adain_output_layers = nn.ModuleList(
|
||||
[operations.Linear(dim, dim, dtype=dtype, device=device) for _ in range(audio_injector_id)])
|
||||
|
||||
def forward(self, x, block_id, audio_emb, audio_emb_global, seq_len):
|
||||
audio_attn_id = self.injected_block_id.get(block_id, None)
|
||||
if audio_attn_id is None:
|
||||
return x
|
||||
|
||||
num_frames = audio_emb.shape[1]
|
||||
input_hidden_states = rearrange(x[:, :seq_len], "b (t n) c -> (b t) n c", t=num_frames)
|
||||
if self.enable_adain and self.adain_mode == "attn_norm":
|
||||
audio_emb_global = rearrange(audio_emb_global, "b t n c -> (b t) n c")
|
||||
adain_hidden_states = self.injector_adain_layers[audio_attn_id](input_hidden_states, temb=audio_emb_global[:, 0])
|
||||
attn_hidden_states = adain_hidden_states
|
||||
else:
|
||||
attn_hidden_states = self.injector_pre_norm_feat[audio_attn_id](input_hidden_states)
|
||||
audio_emb = rearrange(audio_emb, "b t n c -> (b t) n c", t=num_frames)
|
||||
attn_audio_emb = audio_emb
|
||||
residual_out = self.injector[audio_attn_id](x=attn_hidden_states, context=attn_audio_emb)
|
||||
residual_out = rearrange(
|
||||
residual_out, "(b t) n c -> b (t n) c", t=num_frames)
|
||||
x[:, :seq_len] = x[:, :seq_len] + residual_out
|
||||
return x
|
||||
|
||||
|
||||
class FramePackMotioner(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
inner_dim=1024,
|
||||
num_heads=16, # Used to indicate the number of heads in the backbone network; unrelated to this module's design
|
||||
zip_frame_buckets=[
|
||||
1, 2, 16
|
||||
], # Three numbers representing the number of frames sampled for patch operations from the nearest to the farthest frames
|
||||
drop_mode="drop", # If not "drop", it will use "padd", meaning padding instead of deletion
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None):
|
||||
super().__init__()
|
||||
self.proj = operations.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2), dtype=dtype, device=device)
|
||||
self.proj_2x = operations.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4), dtype=dtype, device=device)
|
||||
self.proj_4x = operations.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8), dtype=dtype, device=device)
|
||||
self.zip_frame_buckets = zip_frame_buckets
|
||||
|
||||
self.inner_dim = inner_dim
|
||||
self.num_heads = num_heads
|
||||
|
||||
self.drop_mode = drop_mode
|
||||
|
||||
def forward(self, motion_latents, rope_embedder, add_last_motion=2):
|
||||
lat_height, lat_width = motion_latents.shape[3], motion_latents.shape[4]
|
||||
padd_lat = torch.zeros(motion_latents.shape[0], 16, sum(self.zip_frame_buckets), lat_height, lat_width).to(device=motion_latents.device, dtype=motion_latents.dtype)
|
||||
overlap_frame = min(padd_lat.shape[2], motion_latents.shape[2])
|
||||
if overlap_frame > 0:
|
||||
padd_lat[:, :, -overlap_frame:] = motion_latents[:, :, -overlap_frame:]
|
||||
|
||||
if add_last_motion < 2 and self.drop_mode != "drop":
|
||||
zero_end_frame = sum(self.zip_frame_buckets[:len(self.zip_frame_buckets) - add_last_motion - 1])
|
||||
padd_lat[:, :, -zero_end_frame:] = 0
|
||||
|
||||
clean_latents_4x, clean_latents_2x, clean_latents_post = padd_lat[:, :, -sum(self.zip_frame_buckets):, :, :].split(self.zip_frame_buckets[::-1], dim=2) # 16, 2 ,1
|
||||
|
||||
# patchfy
|
||||
clean_latents_post = self.proj(clean_latents_post).flatten(2).transpose(1, 2)
|
||||
clean_latents_2x = self.proj_2x(clean_latents_2x)
|
||||
l_2x_shape = clean_latents_2x.shape
|
||||
clean_latents_2x = clean_latents_2x.flatten(2).transpose(1, 2)
|
||||
clean_latents_4x = self.proj_4x(clean_latents_4x)
|
||||
l_4x_shape = clean_latents_4x.shape
|
||||
clean_latents_4x = clean_latents_4x.flatten(2).transpose(1, 2)
|
||||
|
||||
if add_last_motion < 2 and self.drop_mode == "drop":
|
||||
clean_latents_post = clean_latents_post[:, :
|
||||
0] if add_last_motion < 2 else clean_latents_post
|
||||
clean_latents_2x = clean_latents_2x[:, :
|
||||
0] if add_last_motion < 1 else clean_latents_2x
|
||||
|
||||
motion_lat = torch.cat([clean_latents_post, clean_latents_2x, clean_latents_4x], dim=1)
|
||||
|
||||
rope_post = rope_embedder.rope_encode(1, lat_height, lat_width, t_start=-1, device=motion_latents.device, dtype=motion_latents.dtype)
|
||||
rope_2x = rope_embedder.rope_encode(1, lat_height, lat_width, t_start=-3, steps_h=l_2x_shape[-2], steps_w=l_2x_shape[-1], device=motion_latents.device, dtype=motion_latents.dtype)
|
||||
rope_4x = rope_embedder.rope_encode(4, lat_height, lat_width, t_start=-19, steps_h=l_4x_shape[-2], steps_w=l_4x_shape[-1], device=motion_latents.device, dtype=motion_latents.dtype)
|
||||
|
||||
rope = torch.cat([rope_post, rope_2x, rope_4x], dim=1)
|
||||
return motion_lat, rope
|
||||
|
||||
|
||||
class WanModel_S2V(WanModel):
|
||||
def __init__(self,
|
||||
model_type='s2v',
|
||||
patch_size=(1, 2, 2),
|
||||
text_len=512,
|
||||
in_dim=16,
|
||||
dim=2048,
|
||||
ffn_dim=8192,
|
||||
freq_dim=256,
|
||||
text_dim=4096,
|
||||
out_dim=16,
|
||||
num_heads=16,
|
||||
num_layers=32,
|
||||
window_size=(-1, -1),
|
||||
qk_norm=True,
|
||||
cross_attn_norm=True,
|
||||
eps=1e-6,
|
||||
audio_dim=1024,
|
||||
num_audio_token=4,
|
||||
enable_adain=True,
|
||||
cond_dim=16,
|
||||
audio_inject_layers=[0, 4, 8, 12, 16, 20, 24, 27, 30, 33, 36, 39],
|
||||
adain_mode="attn_norm",
|
||||
framepack_drop_mode="padd",
|
||||
image_model=None,
|
||||
device=None,
|
||||
dtype=None,
|
||||
operations=None,
|
||||
):
|
||||
|
||||
super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, image_model=image_model, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
self.trainable_cond_mask = operations.Embedding(3, self.dim, device=device, dtype=dtype)
|
||||
|
||||
self.casual_audio_encoder = CausalAudioEncoder(
|
||||
dim=audio_dim,
|
||||
out_dim=self.dim,
|
||||
num_token=num_audio_token,
|
||||
need_global=enable_adain, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
if cond_dim > 0:
|
||||
self.cond_encoder = operations.Conv3d(
|
||||
cond_dim,
|
||||
self.dim,
|
||||
kernel_size=self.patch_size,
|
||||
stride=self.patch_size, device=device, dtype=dtype)
|
||||
|
||||
self.audio_injector = AudioInjector_WAN(
|
||||
dim=self.dim,
|
||||
num_heads=self.num_heads,
|
||||
inject_layer=audio_inject_layers,
|
||||
root_net=self,
|
||||
enable_adain=enable_adain,
|
||||
adain_dim=self.dim,
|
||||
adain_mode=adain_mode,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
self.frame_packer = FramePackMotioner(
|
||||
inner_dim=self.dim,
|
||||
num_heads=self.num_heads,
|
||||
zip_frame_buckets=[1, 2, 16],
|
||||
drop_mode=framepack_drop_mode,
|
||||
dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward_orig(
|
||||
self,
|
||||
x,
|
||||
t,
|
||||
context,
|
||||
audio_embed=None,
|
||||
reference_latent=None,
|
||||
control_video=None,
|
||||
reference_motion=None,
|
||||
clip_fea=None,
|
||||
freqs=None,
|
||||
transformer_options={},
|
||||
**kwargs,
|
||||
):
|
||||
if audio_embed is not None:
|
||||
num_embeds = x.shape[-3] * 4
|
||||
audio_emb_global, audio_emb = self.casual_audio_encoder(audio_embed[:, :, :, :num_embeds])
|
||||
else:
|
||||
audio_emb = None
|
||||
|
||||
# embeddings
|
||||
bs, _, time, height, width = x.shape
|
||||
x = self.patch_embedding(x.float()).to(x.dtype)
|
||||
if control_video is not None:
|
||||
x = x + self.cond_encoder(control_video)
|
||||
|
||||
if t.ndim == 1:
|
||||
t = t.unsqueeze(1).repeat(1, x.shape[2])
|
||||
|
||||
grid_sizes = x.shape[2:]
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
seq_len = x.size(1)
|
||||
|
||||
cond_mask_weight = comfy.model_management.cast_to(self.trainable_cond_mask.weight, dtype=x.dtype, device=x.device).unsqueeze(1).unsqueeze(1)
|
||||
x = x + cond_mask_weight[0]
|
||||
|
||||
if reference_latent is not None:
|
||||
ref = self.patch_embedding(reference_latent.float()).to(x.dtype)
|
||||
ref = ref.flatten(2).transpose(1, 2)
|
||||
freqs_ref = self.rope_encode(reference_latent.shape[-3], reference_latent.shape[-2], reference_latent.shape[-1], t_start=max(30, time + 9), device=x.device, dtype=x.dtype)
|
||||
ref = ref + cond_mask_weight[1]
|
||||
x = torch.cat([x, ref], dim=1)
|
||||
freqs = torch.cat([freqs, freqs_ref], dim=1)
|
||||
t = torch.cat([t, torch.zeros((t.shape[0], reference_latent.shape[-3]), device=t.device, dtype=t.dtype)], dim=1)
|
||||
del ref, freqs_ref
|
||||
|
||||
if reference_motion is not None:
|
||||
motion_encoded, freqs_motion = self.frame_packer(reference_motion, self)
|
||||
motion_encoded = motion_encoded + cond_mask_weight[2]
|
||||
x = torch.cat([x, motion_encoded], dim=1)
|
||||
freqs = torch.cat([freqs, freqs_motion], dim=1)
|
||||
|
||||
t = torch.repeat_interleave(t, 2, dim=1)
|
||||
t = torch.cat([t, torch.zeros((t.shape[0], 3), device=t.device, dtype=t.dtype)], dim=1)
|
||||
del motion_encoded, freqs_motion
|
||||
|
||||
# time embeddings
|
||||
e = self.time_embedding(
|
||||
sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype))
|
||||
e = e.reshape(t.shape[0], -1, e.shape[-1])
|
||||
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
|
||||
|
||||
# context
|
||||
context = self.text_embedding(context)
|
||||
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
for i, block in enumerate(self.blocks):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"])
|
||||
return out
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
|
||||
x = out["img"]
|
||||
else:
|
||||
x = block(x, e=e0, freqs=freqs, context=context)
|
||||
if audio_emb is not None:
|
||||
x = self.audio_injector(x, i, audio_emb, audio_emb_global, seq_len)
|
||||
# head
|
||||
x = self.head(x, e)
|
||||
|
||||
# unpatchify
|
||||
x = self.unpatchify(x, grid_sizes)
|
||||
return x
|
||||
|
||||
@@ -260,6 +260,10 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
key_map["transformer.{}".format(k[:-len(".weight")])] = to #simpletrainer and probably regular diffusers flux lora format
|
||||
key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
|
||||
key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer
|
||||
for k in sdk:
|
||||
hidden_size = model.model_config.unet_config.get("hidden_size", 0)
|
||||
if k.endswith(".weight") and ".linear1." in k:
|
||||
key_map["{}".format(k.replace(".linear1.weight", ".linear1_qkv"))] = (k, (0, 0, hidden_size * 3))
|
||||
|
||||
if isinstance(model, comfy.model_base.GenmoMochi):
|
||||
for k in sdk:
|
||||
@@ -293,6 +297,16 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
key_lora = k[len("diffusion_model."):-len(".weight")]
|
||||
key_map["{}".format(key_lora)] = k
|
||||
|
||||
if isinstance(model, comfy.model_base.QwenImage):
|
||||
for k in sdk:
|
||||
if k.startswith("diffusion_model.") and k.endswith(".weight"): #QwenImage lora format
|
||||
key_lora = k[len("diffusion_model."):-len(".weight")]
|
||||
# Direct mapping for transformer_blocks format (QwenImage LoRA format)
|
||||
key_map["{}".format(key_lora)] = k
|
||||
# Support transformer prefix format
|
||||
key_map["transformer.{}".format(key_lora)] = k
|
||||
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k #SimpleTuner lycoris format
|
||||
|
||||
return key_map
|
||||
|
||||
|
||||
|
||||
@@ -15,10 +15,29 @@ def convert_lora_bfl_control(sd): #BFL loras for Flux
|
||||
def convert_lora_wan_fun(sd): #Wan Fun loras
|
||||
return comfy.utils.state_dict_prefix_replace(sd, {"lora_unet__": "lora_unet_"})
|
||||
|
||||
def convert_uso_lora(sd):
|
||||
sd_out = {}
|
||||
for k in sd:
|
||||
tensor = sd[k]
|
||||
k_to = "diffusion_model.{}".format(k.replace(".down.weight", ".lora_down.weight")
|
||||
.replace(".up.weight", ".lora_up.weight")
|
||||
.replace(".qkv_lora2.", ".txt_attn.qkv.")
|
||||
.replace(".qkv_lora1.", ".img_attn.qkv.")
|
||||
.replace(".proj_lora1.", ".img_attn.proj.")
|
||||
.replace(".proj_lora2.", ".txt_attn.proj.")
|
||||
.replace(".qkv_lora.", ".linear1_qkv.")
|
||||
.replace(".proj_lora.", ".linear2.")
|
||||
.replace(".processor.", ".")
|
||||
)
|
||||
sd_out[k_to] = tensor
|
||||
return sd_out
|
||||
|
||||
|
||||
def convert_lora(sd):
|
||||
if "img_in.lora_A.weight" in sd and "single_blocks.0.norm.key_norm.scale" in sd:
|
||||
return convert_lora_bfl_control(sd)
|
||||
if "lora_unet__blocks_0_cross_attn_k.lora_down.weight" in sd:
|
||||
return convert_lora_wan_fun(sd)
|
||||
if "single_blocks.37.processor.qkv_lora.up.weight" in sd and "double_blocks.18.processor.qkv_lora2.up.weight" in sd:
|
||||
return convert_uso_lora(sd)
|
||||
return sd
|
||||
|
||||
+124
-12
@@ -16,6 +16,8 @@
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
import comfy.ldm.hunyuan3dv2_1
|
||||
import comfy.ldm.hunyuan3dv2_1.hunyuandit
|
||||
import torch
|
||||
import logging
|
||||
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
|
||||
@@ -150,6 +152,7 @@ class BaseModel(torch.nn.Module):
|
||||
logging.debug("adm {}".format(self.adm_channels))
|
||||
self.memory_usage_factor = model_config.memory_usage_factor
|
||||
self.memory_usage_factor_conds = ()
|
||||
self.memory_usage_shape_process = {}
|
||||
|
||||
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
@@ -350,8 +353,15 @@ class BaseModel(torch.nn.Module):
|
||||
input_shapes = [input_shape]
|
||||
for c in self.memory_usage_factor_conds:
|
||||
shape = cond_shapes.get(c, None)
|
||||
if shape is not None and len(shape) > 0:
|
||||
input_shapes += shape
|
||||
if shape is not None:
|
||||
if c in self.memory_usage_shape_process:
|
||||
out = []
|
||||
for s in shape:
|
||||
out.append(self.memory_usage_shape_process[c](s))
|
||||
shape = out
|
||||
|
||||
if len(shape) > 0:
|
||||
input_shapes += shape
|
||||
|
||||
if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
|
||||
dtype = self.get_dtype()
|
||||
@@ -890,6 +900,10 @@ class Flux(BaseModel):
|
||||
for lat in ref_latents:
|
||||
latents.append(self.process_latent_in(lat))
|
||||
out['ref_latents'] = comfy.conds.CONDList(latents)
|
||||
|
||||
ref_latents_method = kwargs.get("reference_latents_method", None)
|
||||
if ref_latents_method is not None:
|
||||
out['ref_latents_method'] = comfy.conds.CONDConstant(ref_latents_method)
|
||||
return out
|
||||
|
||||
def extra_conds_shapes(self, **kwargs):
|
||||
@@ -1098,9 +1112,10 @@ class WAN21(BaseModel):
|
||||
shape_image[1] = extra_channels
|
||||
image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device)
|
||||
else:
|
||||
latent_dim = self.latent_format.latent_channels
|
||||
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||
for i in range(0, image.shape[1], 16):
|
||||
image[:, i: i + 16] = self.process_latent_in(image[:, i: i + 16])
|
||||
for i in range(0, image.shape[1], latent_dim):
|
||||
image[:, i: i + latent_dim] = self.process_latent_in(image[:, i: i + latent_dim])
|
||||
image = utils.resize_to_batch_size(image, noise.shape[0])
|
||||
|
||||
if extra_channels != image.shape[1] + 4:
|
||||
@@ -1124,7 +1139,11 @@ class WAN21(BaseModel):
|
||||
mask = mask.repeat(1, 4, 1, 1, 1)
|
||||
mask = utils.resize_to_batch_size(mask, noise.shape[0])
|
||||
|
||||
return torch.cat((mask, image), dim=1)
|
||||
concat_mask_index = kwargs.get("concat_mask_index", 0)
|
||||
if concat_mask_index != 0:
|
||||
return torch.cat((image[:, :concat_mask_index], mask, image[:, concat_mask_index:]), dim=1)
|
||||
else:
|
||||
return torch.cat((mask, image), dim=1)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
@@ -1140,6 +1159,10 @@ class WAN21(BaseModel):
|
||||
if time_dim_concat is not None:
|
||||
out['time_dim_concat'] = comfy.conds.CONDRegular(self.process_latent_in(time_dim_concat))
|
||||
|
||||
reference_latents = kwargs.get("reference_latents", None)
|
||||
if reference_latents is not None:
|
||||
out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1])[:, :, 0])
|
||||
|
||||
return out
|
||||
|
||||
|
||||
@@ -1189,18 +1212,50 @@ class WAN21_Camera(WAN21):
|
||||
out['camera_conditions'] = comfy.conds.CONDRegular(camera_conditions)
|
||||
return out
|
||||
|
||||
class WAN22(BaseModel):
|
||||
class WAN22_S2V(WAN21):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel_S2V)
|
||||
self.memory_usage_factor_conds = ("reference_latent", "reference_motion")
|
||||
self.memory_usage_shape_process = {"reference_motion": lambda shape: [shape[0], shape[1], 1.5, shape[-2], shape[-1]]}
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
audio_embed = kwargs.get("audio_embed", None)
|
||||
if audio_embed is not None:
|
||||
out['audio_embed'] = comfy.conds.CONDRegular(audio_embed)
|
||||
|
||||
reference_latents = kwargs.get("reference_latents", None)
|
||||
if reference_latents is not None:
|
||||
out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1]))
|
||||
|
||||
reference_motion = kwargs.get("reference_motion", None)
|
||||
if reference_motion is not None:
|
||||
out['reference_motion'] = comfy.conds.CONDRegular(self.process_latent_in(reference_motion))
|
||||
|
||||
control_video = kwargs.get("control_video", None)
|
||||
if control_video is not None:
|
||||
out['control_video'] = comfy.conds.CONDRegular(self.process_latent_in(control_video))
|
||||
return out
|
||||
|
||||
def extra_conds_shapes(self, **kwargs):
|
||||
out = {}
|
||||
ref_latents = kwargs.get("reference_latents", None)
|
||||
if ref_latents is not None:
|
||||
out['reference_latent'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
|
||||
|
||||
reference_motion = kwargs.get("reference_motion", None)
|
||||
if reference_motion is not None:
|
||||
out['reference_motion'] = reference_motion.shape
|
||||
return out
|
||||
|
||||
class WAN22(WAN21):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
|
||||
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
|
||||
self.image_to_video = image_to_video
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
|
||||
denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
|
||||
denoise_mask = kwargs.get("denoise_mask", None)
|
||||
if denoise_mask is not None:
|
||||
out["denoise_mask"] = comfy.conds.CONDRegular(denoise_mask)
|
||||
return out
|
||||
@@ -1229,6 +1284,21 @@ class Hunyuan3Dv2(BaseModel):
|
||||
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
|
||||
return out
|
||||
|
||||
class Hunyuan3Dv2_1(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3dv2_1.hunyuandit.HunYuanDiTPlain)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
|
||||
guidance = kwargs.get("guidance", 5.0)
|
||||
if guidance is not None:
|
||||
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
|
||||
return out
|
||||
|
||||
class HiDream(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hidream.model.HiDreamImageTransformer2DModel)
|
||||
@@ -1313,10 +1383,52 @@ class Omnigen2(BaseModel):
|
||||
class QwenImage(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.qwen_image.model.QwenImageTransformer2DModel)
|
||||
self.memory_usage_factor_conds = ("ref_latents",)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
ref_latents = kwargs.get("reference_latents", None)
|
||||
if ref_latents is not None:
|
||||
latents = []
|
||||
for lat in ref_latents:
|
||||
latents.append(self.process_latent_in(lat))
|
||||
out['ref_latents'] = comfy.conds.CONDList(latents)
|
||||
|
||||
ref_latents_method = kwargs.get("reference_latents_method", None)
|
||||
if ref_latents_method is not None:
|
||||
out['ref_latents_method'] = comfy.conds.CONDConstant(ref_latents_method)
|
||||
return out
|
||||
|
||||
def extra_conds_shapes(self, **kwargs):
|
||||
out = {}
|
||||
ref_latents = kwargs.get("reference_latents", None)
|
||||
if ref_latents is not None:
|
||||
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
|
||||
return out
|
||||
|
||||
class HunyuanImage21(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
if attention_mask is not None:
|
||||
if torch.numel(attention_mask) != attention_mask.sum():
|
||||
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
|
||||
conditioning_byt5small = kwargs.get("conditioning_byt5small", None)
|
||||
if conditioning_byt5small is not None:
|
||||
out['txt_byt5'] = comfy.conds.CONDRegular(conditioning_byt5small)
|
||||
|
||||
guidance = kwargs.get("guidance", 6.0)
|
||||
if guidance is not None:
|
||||
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
|
||||
|
||||
return out
|
||||
|
||||
@@ -136,20 +136,32 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
|
||||
if '{}txt_in.individual_token_refiner.blocks.0.norm1.weight'.format(key_prefix) in state_dict_keys: #Hunyuan Video
|
||||
dit_config = {}
|
||||
in_w = state_dict['{}img_in.proj.weight'.format(key_prefix)]
|
||||
out_w = state_dict['{}final_layer.linear.weight'.format(key_prefix)]
|
||||
dit_config["image_model"] = "hunyuan_video"
|
||||
dit_config["in_channels"] = state_dict['{}img_in.proj.weight'.format(key_prefix)].shape[1] #SkyReels img2video has 32 input channels
|
||||
dit_config["patch_size"] = [1, 2, 2]
|
||||
dit_config["out_channels"] = 16
|
||||
dit_config["vec_in_dim"] = 768
|
||||
dit_config["context_in_dim"] = 4096
|
||||
dit_config["hidden_size"] = 3072
|
||||
dit_config["in_channels"] = in_w.shape[1] #SkyReels img2video has 32 input channels
|
||||
dit_config["patch_size"] = list(in_w.shape[2:])
|
||||
dit_config["out_channels"] = out_w.shape[0] // math.prod(dit_config["patch_size"])
|
||||
if '{}vector_in.in_layer.weight'.format(key_prefix) in state_dict:
|
||||
dit_config["vec_in_dim"] = 768
|
||||
dit_config["axes_dim"] = [16, 56, 56]
|
||||
else:
|
||||
dit_config["vec_in_dim"] = None
|
||||
dit_config["axes_dim"] = [64, 64]
|
||||
|
||||
dit_config["context_in_dim"] = state_dict['{}txt_in.input_embedder.weight'.format(key_prefix)].shape[1]
|
||||
dit_config["hidden_size"] = in_w.shape[0]
|
||||
dit_config["mlp_ratio"] = 4.0
|
||||
dit_config["num_heads"] = 24
|
||||
dit_config["num_heads"] = in_w.shape[0] // 128
|
||||
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
|
||||
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
|
||||
dit_config["axes_dim"] = [16, 56, 56]
|
||||
dit_config["theta"] = 256
|
||||
dit_config["qkv_bias"] = True
|
||||
if '{}byt5_in.fc1.weight'.format(key_prefix) in state_dict:
|
||||
dit_config["byt5"] = True
|
||||
else:
|
||||
dit_config["byt5"] = False
|
||||
|
||||
guidance_keys = list(filter(lambda a: a.startswith("{}guidance_in.".format(key_prefix)), state_dict_keys))
|
||||
dit_config["guidance_embed"] = len(guidance_keys) > 0
|
||||
return dit_config
|
||||
@@ -364,7 +376,12 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["vace_in_dim"] = state_dict['{}vace_patch_embedding.weight'.format(key_prefix)].shape[1]
|
||||
dit_config["vace_layers"] = count_blocks(state_dict_keys, '{}vace_blocks.'.format(key_prefix) + '{}.')
|
||||
elif '{}control_adapter.conv.weight'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["model_type"] = "camera"
|
||||
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["model_type"] = "camera"
|
||||
else:
|
||||
dit_config["model_type"] = "camera_2.2"
|
||||
elif '{}casual_audio_encoder.encoder.final_linear.weight'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["model_type"] = "s2v"
|
||||
else:
|
||||
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["model_type"] = "i2v"
|
||||
@@ -373,6 +390,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
flf_weight = state_dict.get('{}img_emb.emb_pos'.format(key_prefix))
|
||||
if flf_weight is not None:
|
||||
dit_config["flf_pos_embed_token_number"] = flf_weight.shape[1]
|
||||
|
||||
ref_conv_weight = state_dict.get('{}ref_conv.weight'.format(key_prefix))
|
||||
if ref_conv_weight is not None:
|
||||
dit_config["in_dim_ref_conv"] = ref_conv_weight.shape[1]
|
||||
|
||||
return dit_config
|
||||
|
||||
if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D
|
||||
@@ -390,6 +412,20 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
|
||||
return dit_config
|
||||
|
||||
if f"{key_prefix}t_embedder.mlp.2.weight" in state_dict_keys: # Hunyuan 3D 2.1
|
||||
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "hunyuan3d2_1"
|
||||
dit_config["in_channels"] = state_dict[f"{key_prefix}x_embedder.weight"].shape[1]
|
||||
dit_config["context_dim"] = 1024
|
||||
dit_config["hidden_size"] = state_dict[f"{key_prefix}x_embedder.weight"].shape[0]
|
||||
dit_config["mlp_ratio"] = 4.0
|
||||
dit_config["num_heads"] = 16
|
||||
dit_config["depth"] = count_blocks(state_dict_keys, f"{key_prefix}blocks.{{}}")
|
||||
dit_config["qkv_bias"] = False
|
||||
dit_config["guidance_cond_proj_dim"] = None#f"{key_prefix}t_embedder.cond_proj.weight" in state_dict_keys
|
||||
return dit_config
|
||||
|
||||
if '{}caption_projection.0.linear.weight'.format(key_prefix) in state_dict_keys: # HiDream
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "hidream"
|
||||
@@ -484,6 +520,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
if '{}txt_norm.weight'.format(key_prefix) in state_dict_keys: # Qwen Image
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "qwen_image"
|
||||
dit_config["in_channels"] = state_dict['{}img_in.weight'.format(key_prefix)].shape[1]
|
||||
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
|
||||
return dit_config
|
||||
|
||||
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
|
||||
|
||||
+50
-17
@@ -22,6 +22,7 @@ from enum import Enum
|
||||
from comfy.cli_args import args, PerformanceFeature
|
||||
import torch
|
||||
import sys
|
||||
import importlib
|
||||
import platform
|
||||
import weakref
|
||||
import gc
|
||||
@@ -78,7 +79,6 @@ try:
|
||||
torch_version = torch.version.__version__
|
||||
temp = torch_version.split(".")
|
||||
torch_version_numeric = (int(temp[0]), int(temp[1]))
|
||||
xpu_available = (torch_version_numeric[0] < 2 or (torch_version_numeric[0] == 2 and torch_version_numeric[1] <= 4)) and torch.xpu.is_available()
|
||||
except:
|
||||
pass
|
||||
|
||||
@@ -102,10 +102,14 @@ if args.directml is not None:
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex # noqa: F401
|
||||
_ = torch.xpu.device_count()
|
||||
xpu_available = xpu_available or torch.xpu.is_available()
|
||||
except:
|
||||
xpu_available = xpu_available or (hasattr(torch, "xpu") and torch.xpu.is_available())
|
||||
pass
|
||||
|
||||
try:
|
||||
_ = torch.xpu.device_count()
|
||||
xpu_available = torch.xpu.is_available()
|
||||
except:
|
||||
xpu_available = False
|
||||
|
||||
try:
|
||||
if torch.backends.mps.is_available():
|
||||
@@ -286,6 +290,24 @@ def is_amd():
|
||||
return True
|
||||
return False
|
||||
|
||||
def amd_min_version(device=None, min_rdna_version=0):
|
||||
if not is_amd():
|
||||
return False
|
||||
|
||||
if is_device_cpu(device):
|
||||
return False
|
||||
|
||||
arch = torch.cuda.get_device_properties(device).gcnArchName
|
||||
if arch.startswith('gfx') and len(arch) == 7:
|
||||
try:
|
||||
cmp_rdna_version = int(arch[4]) + 2
|
||||
except:
|
||||
cmp_rdna_version = 0
|
||||
if cmp_rdna_version >= min_rdna_version:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
MIN_WEIGHT_MEMORY_RATIO = 0.4
|
||||
if is_nvidia():
|
||||
MIN_WEIGHT_MEMORY_RATIO = 0.0
|
||||
@@ -318,12 +340,13 @@ try:
|
||||
logging.info("AMD arch: {}".format(arch))
|
||||
logging.info("ROCm version: {}".format(rocm_version))
|
||||
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
||||
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
|
||||
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
|
||||
ENABLE_PYTORCH_ATTENTION = True
|
||||
if torch_version_numeric >= (2, 8):
|
||||
if any((a in arch) for a in ["gfx1201"]):
|
||||
ENABLE_PYTORCH_ATTENTION = True
|
||||
if importlib.util.find_spec('triton') is not None: # AMD efficient attention implementation depends on triton. TODO: better way of detecting if it's compiled in or not.
|
||||
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
|
||||
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
|
||||
ENABLE_PYTORCH_ATTENTION = True
|
||||
# if torch_version_numeric >= (2, 8):
|
||||
# if any((a in arch) for a in ["gfx1201"]):
|
||||
# ENABLE_PYTORCH_ATTENTION = True
|
||||
if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
|
||||
if any((a in arch) for a in ["gfx1201", "gfx942", "gfx950"]): # TODO: more arches
|
||||
SUPPORT_FP8_OPS = True
|
||||
@@ -340,7 +363,7 @@ if ENABLE_PYTORCH_ATTENTION:
|
||||
|
||||
PRIORITIZE_FP16 = False # TODO: remove and replace with something that shows exactly which dtype is faster than the other
|
||||
try:
|
||||
if is_nvidia() and PerformanceFeature.Fp16Accumulation in args.fast:
|
||||
if (is_nvidia() or is_amd()) and PerformanceFeature.Fp16Accumulation in args.fast:
|
||||
torch.backends.cuda.matmul.allow_fp16_accumulation = True
|
||||
PRIORITIZE_FP16 = True # TODO: limit to cards where it actually boosts performance
|
||||
logging.info("Enabled fp16 accumulation.")
|
||||
@@ -590,7 +613,13 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
||||
else:
|
||||
minimum_memory_required = max(inference_memory, minimum_memory_required + extra_reserved_memory())
|
||||
|
||||
models = set(models)
|
||||
models_temp = set()
|
||||
for m in models:
|
||||
models_temp.add(m)
|
||||
for mm in m.model_patches_models():
|
||||
models_temp.add(mm)
|
||||
|
||||
models = models_temp
|
||||
|
||||
models_to_load = []
|
||||
|
||||
@@ -896,7 +925,9 @@ def vae_dtype(device=None, allowed_dtypes=[]):
|
||||
|
||||
# NOTE: bfloat16 seems to work on AMD for the VAE but is extremely slow in some cases compared to fp32
|
||||
# slowness still a problem on pytorch nightly 2.9.0.dev20250720+rocm6.4 tested on RDNA3
|
||||
if d == torch.bfloat16 and (not is_amd()) and should_use_bf16(device):
|
||||
# also a problem on RDNA4 except fp32 is also slow there.
|
||||
# This is due to large bf16 convolutions being extremely slow.
|
||||
if d == torch.bfloat16 and ((not is_amd()) or amd_min_version(device, min_rdna_version=4)) and should_use_bf16(device):
|
||||
return d
|
||||
|
||||
return torch.float32
|
||||
@@ -946,10 +977,12 @@ def pick_weight_dtype(dtype, fallback_dtype, device=None):
|
||||
return dtype
|
||||
|
||||
def device_supports_non_blocking(device):
|
||||
if args.force_non_blocking:
|
||||
return True
|
||||
if is_device_mps(device):
|
||||
return False #pytorch bug? mps doesn't support non blocking
|
||||
if is_intel_xpu():
|
||||
return True
|
||||
if is_intel_xpu(): #xpu does support non blocking but it is slower on iGPUs for some reason so disable by default until situation changes
|
||||
return False
|
||||
if args.deterministic: #TODO: figure out why deterministic breaks non blocking from gpu to cpu (previews)
|
||||
return False
|
||||
if directml_enabled:
|
||||
@@ -1282,10 +1315,10 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
|
||||
return False
|
||||
|
||||
if is_intel_xpu():
|
||||
if torch_version_numeric < (2, 6):
|
||||
if torch_version_numeric < (2, 3):
|
||||
return True
|
||||
else:
|
||||
return torch.xpu.get_device_capability(device)['has_bfloat16_conversions']
|
||||
return torch.xpu.is_bf16_supported()
|
||||
|
||||
if is_ascend_npu():
|
||||
return True
|
||||
|
||||
@@ -430,6 +430,12 @@ class ModelPatcher:
|
||||
def set_model_forward_timestep_embed_patch(self, patch):
|
||||
self.set_model_patch(patch, "forward_timestep_embed_patch")
|
||||
|
||||
def set_model_double_block_patch(self, patch):
|
||||
self.set_model_patch(patch, "double_block")
|
||||
|
||||
def set_model_post_input_patch(self, patch):
|
||||
self.set_model_patch(patch, "post_input")
|
||||
|
||||
def add_object_patch(self, name, obj):
|
||||
self.object_patches[name] = obj
|
||||
|
||||
@@ -486,6 +492,30 @@ class ModelPatcher:
|
||||
if hasattr(wrap_func, "to"):
|
||||
self.model_options["model_function_wrapper"] = wrap_func.to(device)
|
||||
|
||||
def model_patches_models(self):
|
||||
to = self.model_options["transformer_options"]
|
||||
models = []
|
||||
if "patches" in to:
|
||||
patches = to["patches"]
|
||||
for name in patches:
|
||||
patch_list = patches[name]
|
||||
for i in range(len(patch_list)):
|
||||
if hasattr(patch_list[i], "models"):
|
||||
models += patch_list[i].models()
|
||||
if "patches_replace" in to:
|
||||
patches = to["patches_replace"]
|
||||
for name in patches:
|
||||
patch_list = patches[name]
|
||||
for k in patch_list:
|
||||
if hasattr(patch_list[k], "models"):
|
||||
models += patch_list[k].models()
|
||||
if "model_function_wrapper" in self.model_options:
|
||||
wrap_func = self.model_options["model_function_wrapper"]
|
||||
if hasattr(wrap_func, "models"):
|
||||
models += wrap_func.models()
|
||||
|
||||
return models
|
||||
|
||||
def model_dtype(self):
|
||||
if hasattr(self.model, "get_dtype"):
|
||||
return self.model.get_dtype()
|
||||
|
||||
@@ -24,8 +24,37 @@ import comfy.float
|
||||
import comfy.rmsnorm
|
||||
import contextlib
|
||||
|
||||
|
||||
def scaled_dot_product_attention(q, k, v, *args, **kwargs):
|
||||
return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
|
||||
|
||||
|
||||
try:
|
||||
if torch.cuda.is_available():
|
||||
from torch.nn.attention import SDPBackend, sdpa_kernel
|
||||
import inspect
|
||||
if "set_priority" in inspect.signature(sdpa_kernel).parameters:
|
||||
SDPA_BACKEND_PRIORITY = [
|
||||
SDPBackend.FLASH_ATTENTION,
|
||||
SDPBackend.EFFICIENT_ATTENTION,
|
||||
SDPBackend.MATH,
|
||||
]
|
||||
|
||||
SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION)
|
||||
|
||||
def scaled_dot_product_attention(q, k, v, *args, **kwargs):
|
||||
with sdpa_kernel(SDPA_BACKEND_PRIORITY, set_priority=True):
|
||||
return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
|
||||
else:
|
||||
logging.warning("Torch version too old to set sdpa backend priority.")
|
||||
except (ModuleNotFoundError, TypeError):
|
||||
logging.warning("Could not set sdpa backend priority.")
|
||||
|
||||
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
|
||||
|
||||
if torch.cuda.is_available() and torch.backends.cudnn.is_available() and PerformanceFeature.AutoTune in args.fast:
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
def cast_to_input(weight, input, non_blocking=False, copy=True):
|
||||
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
|
||||
|
||||
|
||||
@@ -50,6 +50,7 @@ class WrappersMP:
|
||||
OUTER_SAMPLE = "outer_sample"
|
||||
PREPARE_SAMPLING = "prepare_sampling"
|
||||
SAMPLER_SAMPLE = "sampler_sample"
|
||||
PREDICT_NOISE = "predict_noise"
|
||||
CALC_COND_BATCH = "calc_cond_batch"
|
||||
APPLY_MODEL = "apply_model"
|
||||
DIFFUSION_MODEL = "diffusion_model"
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import torch
|
||||
import comfy.model_management
|
||||
import numbers
|
||||
import logging
|
||||
|
||||
RMSNorm = None
|
||||
|
||||
@@ -9,6 +10,7 @@ try:
|
||||
RMSNorm = torch.nn.RMSNorm
|
||||
except:
|
||||
rms_norm_torch = None
|
||||
logging.warning("Please update pytorch to use native RMSNorm")
|
||||
|
||||
|
||||
def rms_norm(x, weight=None, eps=1e-6):
|
||||
|
||||
@@ -149,7 +149,7 @@ def cleanup_models(conds, models):
|
||||
|
||||
cleanup_additional_models(set(control_cleanup))
|
||||
|
||||
def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict):
|
||||
def prepare_model_patcher(model: ModelPatcher, conds, model_options: dict):
|
||||
'''
|
||||
Registers hooks from conds.
|
||||
'''
|
||||
@@ -158,8 +158,8 @@ def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict):
|
||||
for k in conds:
|
||||
get_hooks_from_cond(conds[k], hooks)
|
||||
# add wrappers and callbacks from ModelPatcher to transformer_options
|
||||
model_options["transformer_options"]["wrappers"] = comfy.patcher_extension.copy_nested_dicts(model.wrappers)
|
||||
model_options["transformer_options"]["callbacks"] = comfy.patcher_extension.copy_nested_dicts(model.callbacks)
|
||||
comfy.patcher_extension.merge_nested_dicts(model_options["transformer_options"].setdefault("wrappers", {}), model.wrappers, copy_dict1=False)
|
||||
comfy.patcher_extension.merge_nested_dicts(model_options["transformer_options"].setdefault("callbacks", {}), model.callbacks, copy_dict1=False)
|
||||
# begin registering hooks
|
||||
registered = comfy.hooks.HookGroup()
|
||||
target_dict = comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Model)
|
||||
|
||||
Regular → Executable
+25
-7
@@ -16,6 +16,8 @@ import comfy.sampler_helpers
|
||||
import comfy.model_patcher
|
||||
import comfy.patcher_extension
|
||||
import comfy.hooks
|
||||
import comfy.context_windows
|
||||
import comfy.utils
|
||||
import scipy.stats
|
||||
import numpy
|
||||
|
||||
@@ -60,7 +62,7 @@ def get_area_and_mult(conds, x_in, timestep_in):
|
||||
if "mask_strength" in conds:
|
||||
mask_strength = conds["mask_strength"]
|
||||
mask = conds['mask']
|
||||
assert (mask.shape[1:] == x_in.shape[2:])
|
||||
# assert (mask.shape[1:] == x_in.shape[2:])
|
||||
|
||||
mask = mask[:input_x.shape[0]]
|
||||
if area is not None:
|
||||
@@ -68,7 +70,7 @@ def get_area_and_mult(conds, x_in, timestep_in):
|
||||
mask = mask.narrow(i + 1, area[len(dims) + i], area[i])
|
||||
|
||||
mask = mask * mask_strength
|
||||
mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1)
|
||||
mask = mask.unsqueeze(1).repeat((input_x.shape[0] // mask.shape[0], input_x.shape[1]) + (1, ) * (mask.ndim - 1))
|
||||
else:
|
||||
mask = torch.ones_like(input_x)
|
||||
mult = mask * strength
|
||||
@@ -198,14 +200,20 @@ def finalize_default_conds(model: 'BaseModel', hooked_to_run: dict[comfy.hooks.H
|
||||
hooked_to_run.setdefault(p.hooks, list())
|
||||
hooked_to_run[p.hooks] += [(p, i)]
|
||||
|
||||
def calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
|
||||
def calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options: dict[str]):
|
||||
handler: comfy.context_windows.ContextHandlerABC = model_options.get("context_handler", None)
|
||||
if handler is None or not handler.should_use_context(model, conds, x_in, timestep, model_options):
|
||||
return _calc_cond_batch_outer(model, conds, x_in, timestep, model_options)
|
||||
return handler.execute(_calc_cond_batch_outer, model, conds, x_in, timestep, model_options)
|
||||
|
||||
def _calc_cond_batch_outer(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
|
||||
executor = comfy.patcher_extension.WrapperExecutor.new_executor(
|
||||
_calc_cond_batch,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, model_options, is_model_options=True)
|
||||
)
|
||||
return executor.execute(model, conds, x_in, timestep, model_options)
|
||||
|
||||
def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
|
||||
def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
|
||||
out_conds = []
|
||||
out_counts = []
|
||||
# separate conds by matching hooks
|
||||
@@ -546,7 +554,10 @@ def resolve_areas_and_cond_masks_multidim(conditions, dims, device):
|
||||
if len(mask.shape) == len(dims):
|
||||
mask = mask.unsqueeze(0)
|
||||
if mask.shape[1:] != dims:
|
||||
mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=dims, mode='bilinear', align_corners=False).squeeze(1)
|
||||
if mask.ndim < 4:
|
||||
mask = comfy.utils.common_upscale(mask.unsqueeze(1), dims[-1], dims[-2], 'bilinear', 'none').squeeze(1)
|
||||
else:
|
||||
mask = comfy.utils.common_upscale(mask, dims[-1], dims[-2], 'bilinear', 'none')
|
||||
|
||||
if modified.get("set_area_to_bounds", False): #TODO: handle dim != 2
|
||||
bounds = torch.max(torch.abs(mask),dim=0).values.unsqueeze(0)
|
||||
@@ -718,7 +729,7 @@ class Sampler:
|
||||
|
||||
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
|
||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
|
||||
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
|
||||
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_2m_sde_heun", "dpmpp_2m_sde_heun_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
|
||||
"ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
|
||||
"gradient_estimation", "gradient_estimation_cfg_pp", "er_sde", "seeds_2", "seeds_3", "sa_solver", "sa_solver_pece"]
|
||||
|
||||
@@ -946,7 +957,14 @@ class CFGGuider:
|
||||
self.original_conds[k] = comfy.sampler_helpers.convert_cond(conds[k])
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.predict_noise(*args, **kwargs)
|
||||
return self.outer_predict_noise(*args, **kwargs)
|
||||
|
||||
def outer_predict_noise(self, x, timestep, model_options={}, seed=None):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self.predict_noise,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREDICT_NOISE, self.model_options, is_model_options=True)
|
||||
).execute(x, timestep, model_options, seed)
|
||||
|
||||
def predict_noise(self, x, timestep, model_options={}, seed=None):
|
||||
return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed)
|
||||
|
||||
+49
-10
@@ -17,6 +17,7 @@ import comfy.ldm.wan.vae
|
||||
import comfy.ldm.wan.vae2_2
|
||||
import comfy.ldm.hunyuan3d.vae
|
||||
import comfy.ldm.ace.vae.music_dcae_pipeline
|
||||
import comfy.ldm.hunyuan_video.vae
|
||||
import yaml
|
||||
import math
|
||||
import os
|
||||
@@ -48,6 +49,7 @@ import comfy.text_encoders.hidream
|
||||
import comfy.text_encoders.ace
|
||||
import comfy.text_encoders.omnigen2
|
||||
import comfy.text_encoders.qwen_image
|
||||
import comfy.text_encoders.hunyuan_image
|
||||
|
||||
import comfy.model_patcher
|
||||
import comfy.lora
|
||||
@@ -328,6 +330,19 @@ class VAE:
|
||||
self.first_stage_model = StageC_coder()
|
||||
self.downscale_ratio = 32
|
||||
self.latent_channels = 16
|
||||
elif "decoder.conv_in.weight" in sd and sd['decoder.conv_in.weight'].shape[1] == 64:
|
||||
ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}
|
||||
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
|
||||
self.downscale_ratio = 32
|
||||
self.upscale_ratio = 32
|
||||
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
|
||||
encoder_config={'target': "comfy.ldm.hunyuan_video.vae.Encoder", 'params': ddconfig},
|
||||
decoder_config={'target': "comfy.ldm.hunyuan_video.vae.Decoder", 'params': ddconfig})
|
||||
|
||||
self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
|
||||
self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype)
|
||||
|
||||
elif "decoder.conv_in.weight" in sd:
|
||||
#default SD1.x/SD2.x VAE parameters
|
||||
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
|
||||
@@ -446,17 +461,29 @@ class VAE:
|
||||
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype)
|
||||
self.memory_used_decode = lambda shape, dtype: 7000 * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype)
|
||||
# Hunyuan 3d v2 2.0 & 2.1
|
||||
elif "geo_decoder.cross_attn_decoder.ln_1.bias" in sd:
|
||||
|
||||
self.latent_dim = 1
|
||||
ln_post = "geo_decoder.ln_post.weight" in sd
|
||||
inner_size = sd["geo_decoder.output_proj.weight"].shape[1]
|
||||
downsample_ratio = sd["post_kl.weight"].shape[0] // inner_size
|
||||
mlp_expand = sd["geo_decoder.cross_attn_decoder.mlp.c_fc.weight"].shape[0] // inner_size
|
||||
self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype) # TODO
|
||||
self.memory_used_decode = lambda shape, dtype: (1024 * 1024 * 1024 * 2.0) * model_management.dtype_size(dtype) # TODO
|
||||
ddconfig = {"embed_dim": 64, "num_freqs": 8, "include_pi": False, "heads": 16, "width": 1024, "num_decoder_layers": 16, "qkv_bias": False, "qk_norm": True, "geo_decoder_mlp_expand_ratio": mlp_expand, "geo_decoder_downsample_ratio": downsample_ratio, "geo_decoder_ln_post": ln_post}
|
||||
self.first_stage_model = comfy.ldm.hunyuan3d.vae.ShapeVAE(**ddconfig)
|
||||
|
||||
def estimate_memory(shape, dtype, num_layers = 16, kv_cache_multiplier = 2):
|
||||
batch, num_tokens, hidden_dim = shape
|
||||
dtype_size = model_management.dtype_size(dtype)
|
||||
|
||||
total_mem = batch * num_tokens * hidden_dim * dtype_size * (1 + kv_cache_multiplier * num_layers)
|
||||
return total_mem
|
||||
|
||||
# better memory estimations
|
||||
self.memory_used_encode = lambda shape, dtype, num_layers = 8, kv_cache_multiplier = 0:\
|
||||
estimate_memory(shape, dtype, num_layers, kv_cache_multiplier)
|
||||
|
||||
self.memory_used_decode = lambda shape, dtype, num_layers = 16, kv_cache_multiplier = 2: \
|
||||
estimate_memory(shape, dtype, num_layers, kv_cache_multiplier)
|
||||
|
||||
self.first_stage_model = comfy.ldm.hunyuan3d.vae.ShapeVAE()
|
||||
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
|
||||
|
||||
elif "vocoder.backbone.channel_layers.0.0.bias" in sd: #Ace Step Audio
|
||||
self.first_stage_model = comfy.ldm.ace.vae.music_dcae_pipeline.MusicDCAE(source_sample_rate=44100)
|
||||
self.memory_used_encode = lambda shape, dtype: (shape[2] * 330) * model_management.dtype_size(dtype)
|
||||
@@ -773,6 +800,7 @@ class CLIPType(Enum):
|
||||
ACE = 16
|
||||
OMNIGEN2 = 17
|
||||
QWEN_IMAGE = 18
|
||||
HUNYUAN_IMAGE = 19
|
||||
|
||||
|
||||
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
||||
@@ -794,6 +822,7 @@ class TEModel(Enum):
|
||||
GEMMA_2_2B = 9
|
||||
QWEN25_3B = 10
|
||||
QWEN25_7B = 11
|
||||
BYT5_SMALL_GLYPH = 12
|
||||
|
||||
def detect_te_model(sd):
|
||||
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
|
||||
@@ -811,6 +840,9 @@ def detect_te_model(sd):
|
||||
if 'encoder.block.23.layer.1.DenseReluDense.wi.weight' in sd:
|
||||
return TEModel.T5_XXL_OLD
|
||||
if "encoder.block.0.layer.0.SelfAttention.k.weight" in sd:
|
||||
weight = sd['encoder.block.0.layer.0.SelfAttention.k.weight']
|
||||
if weight.shape[0] == 384:
|
||||
return TEModel.BYT5_SMALL_GLYPH
|
||||
return TEModel.T5_BASE
|
||||
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
|
||||
return TEModel.GEMMA_2_2B
|
||||
@@ -925,8 +957,12 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
clip_target.clip = comfy.text_encoders.omnigen2.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.omnigen2.Omnigen2Tokenizer
|
||||
elif te_model == TEModel.QWEN25_7B:
|
||||
clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer
|
||||
if clip_type == CLIPType.HUNYUAN_IMAGE:
|
||||
clip_target.clip = comfy.text_encoders.hunyuan_image.te(byt5=False, **llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer
|
||||
else:
|
||||
clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer
|
||||
else:
|
||||
# clip_l
|
||||
if clip_type == CLIPType.SD3:
|
||||
@@ -970,6 +1006,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
|
||||
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, llama=llama, **t5_kwargs, **llama_kwargs)
|
||||
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
|
||||
elif clip_type == CLIPType.HUNYUAN_IMAGE:
|
||||
clip_target.clip = comfy.text_encoders.hunyuan_image.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer
|
||||
else:
|
||||
clip_target.clip = sdxl_clip.SDXLClipModel
|
||||
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
||||
|
||||
+11
-5
@@ -204,17 +204,19 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
tokens_embed = self.transformer.get_input_embeddings()(tokens_embed, out_dtype=torch.float32)
|
||||
index = 0
|
||||
pad_extra = 0
|
||||
embeds_info = []
|
||||
for o in other_embeds:
|
||||
emb = o[1]
|
||||
if torch.is_tensor(emb):
|
||||
emb = {"type": "embedding", "data": emb}
|
||||
|
||||
extra = None
|
||||
emb_type = emb.get("type", None)
|
||||
if emb_type == "embedding":
|
||||
emb = emb.get("data", None)
|
||||
else:
|
||||
if hasattr(self.transformer, "preprocess_embed"):
|
||||
emb = self.transformer.preprocess_embed(emb, device=device)
|
||||
emb, extra = self.transformer.preprocess_embed(emb, device=device)
|
||||
else:
|
||||
emb = None
|
||||
|
||||
@@ -229,6 +231,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
tokens_embed = torch.cat([tokens_embed[:, :ind], emb, tokens_embed[:, ind:]], dim=1)
|
||||
attention_mask = attention_mask[:ind] + [1] * emb_shape + attention_mask[ind:]
|
||||
index += emb_shape - 1
|
||||
embeds_info.append({"type": emb_type, "index": ind, "size": emb_shape, "extra": extra})
|
||||
else:
|
||||
index += -1
|
||||
pad_extra += emb_shape
|
||||
@@ -243,11 +246,11 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
attention_masks.append(attention_mask)
|
||||
num_tokens.append(sum(attention_mask))
|
||||
|
||||
return torch.cat(embeds_out), torch.tensor(attention_masks, device=device, dtype=torch.long), num_tokens
|
||||
return torch.cat(embeds_out), torch.tensor(attention_masks, device=device, dtype=torch.long), num_tokens, embeds_info
|
||||
|
||||
def forward(self, tokens):
|
||||
device = self.transformer.get_input_embeddings().weight.device
|
||||
embeds, attention_mask, num_tokens = self.process_tokens(tokens, device)
|
||||
embeds, attention_mask, num_tokens, embeds_info = self.process_tokens(tokens, device)
|
||||
|
||||
attention_mask_model = None
|
||||
if self.enable_attention_masks:
|
||||
@@ -258,7 +261,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
else:
|
||||
intermediate_output = self.layer_idx
|
||||
|
||||
outputs = self.transformer(None, attention_mask_model, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32)
|
||||
outputs = self.transformer(None, attention_mask_model, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32, embeds_info=embeds_info)
|
||||
|
||||
if self.layer == "last":
|
||||
z = outputs[0].float()
|
||||
@@ -531,7 +534,10 @@ class SDTokenizer:
|
||||
min_padding = tokenizer_options.get("{}_min_padding".format(self.embedding_key), self.min_padding)
|
||||
|
||||
text = escape_important(text)
|
||||
parsed_weights = token_weights(text, 1.0)
|
||||
if kwargs.get("disable_weights", False):
|
||||
parsed_weights = [(text, 1.0)]
|
||||
else:
|
||||
parsed_weights = token_weights(text, 1.0)
|
||||
|
||||
# tokenize words
|
||||
tokens = []
|
||||
|
||||
@@ -20,6 +20,7 @@ import comfy.text_encoders.wan
|
||||
import comfy.text_encoders.ace
|
||||
import comfy.text_encoders.omnigen2
|
||||
import comfy.text_encoders.qwen_image
|
||||
import comfy.text_encoders.hunyuan_image
|
||||
|
||||
from . import supported_models_base
|
||||
from . import latent_formats
|
||||
@@ -700,7 +701,7 @@ class Flux(supported_models_base.BASE):
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.Flux
|
||||
|
||||
memory_usage_factor = 2.8
|
||||
memory_usage_factor = 3.1 # TODO: debug why flux mem usage is so weird on windows.
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
|
||||
@@ -1046,6 +1047,18 @@ class WAN21_Camera(WAN21_T2V):
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.WAN21_Camera(self, image_to_video=False, device=device)
|
||||
return out
|
||||
|
||||
class WAN22_Camera(WAN21_T2V):
|
||||
unet_config = {
|
||||
"image_model": "wan2.1",
|
||||
"model_type": "camera_2.2",
|
||||
"in_dim": 36,
|
||||
}
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.WAN21_Camera(self, image_to_video=False, device=device)
|
||||
return out
|
||||
|
||||
class WAN21_Vace(WAN21_T2V):
|
||||
unet_config = {
|
||||
"image_model": "wan2.1",
|
||||
@@ -1060,6 +1073,19 @@ class WAN21_Vace(WAN21_T2V):
|
||||
out = model_base.WAN21_Vace(self, image_to_video=False, device=device)
|
||||
return out
|
||||
|
||||
class WAN22_S2V(WAN21_T2V):
|
||||
unet_config = {
|
||||
"image_model": "wan2.1",
|
||||
"model_type": "s2v",
|
||||
}
|
||||
|
||||
def __init__(self, unet_config):
|
||||
super().__init__(unet_config)
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.WAN22_S2V(self, device=device)
|
||||
return out
|
||||
|
||||
class WAN22_T2V(WAN21_T2V):
|
||||
unet_config = {
|
||||
"image_model": "wan2.1",
|
||||
@@ -1103,6 +1129,17 @@ class Hunyuan3Dv2(supported_models_base.BASE):
|
||||
def clip_target(self, state_dict={}):
|
||||
return None
|
||||
|
||||
class Hunyuan3Dv2_1(Hunyuan3Dv2):
|
||||
unet_config = {
|
||||
"image_model": "hunyuan3d2_1",
|
||||
}
|
||||
|
||||
latent_format = latent_formats.Hunyuan3Dv2_1
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.Hunyuan3Dv2_1(self, device = device)
|
||||
return out
|
||||
|
||||
class Hunyuan3Dv2mini(Hunyuan3Dv2):
|
||||
unet_config = {
|
||||
"image_model": "hunyuan3d2",
|
||||
@@ -1259,7 +1296,31 @@ class QwenImage(supported_models_base.BASE):
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.qwen_image.QwenImageTokenizer, comfy.text_encoders.qwen_image.te(**hunyuan_detect))
|
||||
|
||||
class HunyuanImage21(HunyuanVideo):
|
||||
unet_config = {
|
||||
"image_model": "hunyuan_video",
|
||||
"vec_in_dim": None,
|
||||
}
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2, QwenImage]
|
||||
sampling_settings = {
|
||||
"shift": 5.0,
|
||||
}
|
||||
|
||||
latent_format = latent_formats.HunyuanImage21
|
||||
|
||||
memory_usage_factor = 7.7
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.HunyuanImage21(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect))
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ACEStep, Omnigen2, QwenImage]
|
||||
|
||||
models += [SVD_img2vid]
|
||||
|
||||
@@ -116,7 +116,7 @@ class BertModel_(torch.nn.Module):
|
||||
self.embeddings = BertEmbeddings(config_dict["vocab_size"], config_dict["max_position_embeddings"], config_dict["type_vocab_size"], config_dict["pad_token_id"], embed_dim, layer_norm_eps, dtype, device, operations)
|
||||
self.encoder = BertEncoder(config_dict["num_hidden_layers"], embed_dim, config_dict["intermediate_size"], config_dict["num_attention_heads"], layer_norm_eps, dtype, device, operations)
|
||||
|
||||
def forward(self, input_tokens, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
|
||||
def forward(self, input_tokens, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]):
|
||||
x = self.embeddings(input_tokens, embeds=embeds, dtype=dtype)
|
||||
mask = None
|
||||
if attention_mask is not None:
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
{
|
||||
"d_ff": 3584,
|
||||
"d_kv": 64,
|
||||
"d_model": 1472,
|
||||
"decoder_start_token_id": 0,
|
||||
"dropout_rate": 0.1,
|
||||
"eos_token_id": 1,
|
||||
"dense_act_fn": "gelu_pytorch_tanh",
|
||||
"initializer_factor": 1.0,
|
||||
"is_encoder_decoder": true,
|
||||
"is_gated_act": true,
|
||||
"layer_norm_epsilon": 1e-06,
|
||||
"model_type": "t5",
|
||||
"num_decoder_layers": 4,
|
||||
"num_heads": 6,
|
||||
"num_layers": 12,
|
||||
"output_past": true,
|
||||
"pad_token_id": 0,
|
||||
"relative_attention_num_buckets": 32,
|
||||
"tie_word_embeddings": false,
|
||||
"vocab_size": 1510
|
||||
}
|
||||
@@ -0,0 +1,127 @@
|
||||
{
|
||||
"<extra_id_0>": 259,
|
||||
"<extra_id_100>": 359,
|
||||
"<extra_id_101>": 360,
|
||||
"<extra_id_102>": 361,
|
||||
"<extra_id_103>": 362,
|
||||
"<extra_id_104>": 363,
|
||||
"<extra_id_105>": 364,
|
||||
"<extra_id_106>": 365,
|
||||
"<extra_id_107>": 366,
|
||||
"<extra_id_108>": 367,
|
||||
"<extra_id_109>": 368,
|
||||
"<extra_id_10>": 269,
|
||||
"<extra_id_110>": 369,
|
||||
"<extra_id_111>": 370,
|
||||
"<extra_id_112>": 371,
|
||||
"<extra_id_113>": 372,
|
||||
"<extra_id_114>": 373,
|
||||
"<extra_id_115>": 374,
|
||||
"<extra_id_116>": 375,
|
||||
"<extra_id_117>": 376,
|
||||
"<extra_id_118>": 377,
|
||||
"<extra_id_119>": 378,
|
||||
"<extra_id_11>": 270,
|
||||
"<extra_id_120>": 379,
|
||||
"<extra_id_121>": 380,
|
||||
"<extra_id_122>": 381,
|
||||
"<extra_id_123>": 382,
|
||||
"<extra_id_124>": 383,
|
||||
"<extra_id_12>": 271,
|
||||
"<extra_id_13>": 272,
|
||||
"<extra_id_14>": 273,
|
||||
"<extra_id_15>": 274,
|
||||
"<extra_id_16>": 275,
|
||||
"<extra_id_17>": 276,
|
||||
"<extra_id_18>": 277,
|
||||
"<extra_id_19>": 278,
|
||||
"<extra_id_1>": 260,
|
||||
"<extra_id_20>": 279,
|
||||
"<extra_id_21>": 280,
|
||||
"<extra_id_22>": 281,
|
||||
"<extra_id_23>": 282,
|
||||
"<extra_id_24>": 283,
|
||||
"<extra_id_25>": 284,
|
||||
"<extra_id_26>": 285,
|
||||
"<extra_id_27>": 286,
|
||||
"<extra_id_28>": 287,
|
||||
"<extra_id_29>": 288,
|
||||
"<extra_id_2>": 261,
|
||||
"<extra_id_30>": 289,
|
||||
"<extra_id_31>": 290,
|
||||
"<extra_id_32>": 291,
|
||||
"<extra_id_33>": 292,
|
||||
"<extra_id_34>": 293,
|
||||
"<extra_id_35>": 294,
|
||||
"<extra_id_36>": 295,
|
||||
"<extra_id_37>": 296,
|
||||
"<extra_id_38>": 297,
|
||||
"<extra_id_39>": 298,
|
||||
"<extra_id_3>": 262,
|
||||
"<extra_id_40>": 299,
|
||||
"<extra_id_41>": 300,
|
||||
"<extra_id_42>": 301,
|
||||
"<extra_id_43>": 302,
|
||||
"<extra_id_44>": 303,
|
||||
"<extra_id_45>": 304,
|
||||
"<extra_id_46>": 305,
|
||||
"<extra_id_47>": 306,
|
||||
"<extra_id_48>": 307,
|
||||
"<extra_id_49>": 308,
|
||||
"<extra_id_4>": 263,
|
||||
"<extra_id_50>": 309,
|
||||
"<extra_id_51>": 310,
|
||||
"<extra_id_52>": 311,
|
||||
"<extra_id_53>": 312,
|
||||
"<extra_id_54>": 313,
|
||||
"<extra_id_55>": 314,
|
||||
"<extra_id_56>": 315,
|
||||
"<extra_id_57>": 316,
|
||||
"<extra_id_58>": 317,
|
||||
"<extra_id_59>": 318,
|
||||
"<extra_id_5>": 264,
|
||||
"<extra_id_60>": 319,
|
||||
"<extra_id_61>": 320,
|
||||
"<extra_id_62>": 321,
|
||||
"<extra_id_63>": 322,
|
||||
"<extra_id_64>": 323,
|
||||
"<extra_id_65>": 324,
|
||||
"<extra_id_66>": 325,
|
||||
"<extra_id_67>": 326,
|
||||
"<extra_id_68>": 327,
|
||||
"<extra_id_69>": 328,
|
||||
"<extra_id_6>": 265,
|
||||
"<extra_id_70>": 329,
|
||||
"<extra_id_71>": 330,
|
||||
"<extra_id_72>": 331,
|
||||
"<extra_id_73>": 332,
|
||||
"<extra_id_74>": 333,
|
||||
"<extra_id_75>": 334,
|
||||
"<extra_id_76>": 335,
|
||||
"<extra_id_77>": 336,
|
||||
"<extra_id_78>": 337,
|
||||
"<extra_id_79>": 338,
|
||||
"<extra_id_7>": 266,
|
||||
"<extra_id_80>": 339,
|
||||
"<extra_id_81>": 340,
|
||||
"<extra_id_82>": 341,
|
||||
"<extra_id_83>": 342,
|
||||
"<extra_id_84>": 343,
|
||||
"<extra_id_85>": 344,
|
||||
"<extra_id_86>": 345,
|
||||
"<extra_id_87>": 346,
|
||||
"<extra_id_88>": 347,
|
||||
"<extra_id_89>": 348,
|
||||
"<extra_id_8>": 267,
|
||||
"<extra_id_90>": 349,
|
||||
"<extra_id_91>": 350,
|
||||
"<extra_id_92>": 351,
|
||||
"<extra_id_93>": 352,
|
||||
"<extra_id_94>": 353,
|
||||
"<extra_id_95>": 354,
|
||||
"<extra_id_96>": 355,
|
||||
"<extra_id_97>": 356,
|
||||
"<extra_id_98>": 357,
|
||||
"<extra_id_99>": 358,
|
||||
"<extra_id_9>": 268
|
||||
}
|
||||
@@ -0,0 +1,150 @@
|
||||
{
|
||||
"additional_special_tokens": [
|
||||
"<extra_id_0>",
|
||||
"<extra_id_1>",
|
||||
"<extra_id_2>",
|
||||
"<extra_id_3>",
|
||||
"<extra_id_4>",
|
||||
"<extra_id_5>",
|
||||
"<extra_id_6>",
|
||||
"<extra_id_7>",
|
||||
"<extra_id_8>",
|
||||
"<extra_id_9>",
|
||||
"<extra_id_10>",
|
||||
"<extra_id_11>",
|
||||
"<extra_id_12>",
|
||||
"<extra_id_13>",
|
||||
"<extra_id_14>",
|
||||
"<extra_id_15>",
|
||||
"<extra_id_16>",
|
||||
"<extra_id_17>",
|
||||
"<extra_id_18>",
|
||||
"<extra_id_19>",
|
||||
"<extra_id_20>",
|
||||
"<extra_id_21>",
|
||||
"<extra_id_22>",
|
||||
"<extra_id_23>",
|
||||
"<extra_id_24>",
|
||||
"<extra_id_25>",
|
||||
"<extra_id_26>",
|
||||
"<extra_id_27>",
|
||||
"<extra_id_28>",
|
||||
"<extra_id_29>",
|
||||
"<extra_id_30>",
|
||||
"<extra_id_31>",
|
||||
"<extra_id_32>",
|
||||
"<extra_id_33>",
|
||||
"<extra_id_34>",
|
||||
"<extra_id_35>",
|
||||
"<extra_id_36>",
|
||||
"<extra_id_37>",
|
||||
"<extra_id_38>",
|
||||
"<extra_id_39>",
|
||||
"<extra_id_40>",
|
||||
"<extra_id_41>",
|
||||
"<extra_id_42>",
|
||||
"<extra_id_43>",
|
||||
"<extra_id_44>",
|
||||
"<extra_id_45>",
|
||||
"<extra_id_46>",
|
||||
"<extra_id_47>",
|
||||
"<extra_id_48>",
|
||||
"<extra_id_49>",
|
||||
"<extra_id_50>",
|
||||
"<extra_id_51>",
|
||||
"<extra_id_52>",
|
||||
"<extra_id_53>",
|
||||
"<extra_id_54>",
|
||||
"<extra_id_55>",
|
||||
"<extra_id_56>",
|
||||
"<extra_id_57>",
|
||||
"<extra_id_58>",
|
||||
"<extra_id_59>",
|
||||
"<extra_id_60>",
|
||||
"<extra_id_61>",
|
||||
"<extra_id_62>",
|
||||
"<extra_id_63>",
|
||||
"<extra_id_64>",
|
||||
"<extra_id_65>",
|
||||
"<extra_id_66>",
|
||||
"<extra_id_67>",
|
||||
"<extra_id_68>",
|
||||
"<extra_id_69>",
|
||||
"<extra_id_70>",
|
||||
"<extra_id_71>",
|
||||
"<extra_id_72>",
|
||||
"<extra_id_73>",
|
||||
"<extra_id_74>",
|
||||
"<extra_id_75>",
|
||||
"<extra_id_76>",
|
||||
"<extra_id_77>",
|
||||
"<extra_id_78>",
|
||||
"<extra_id_79>",
|
||||
"<extra_id_80>",
|
||||
"<extra_id_81>",
|
||||
"<extra_id_82>",
|
||||
"<extra_id_83>",
|
||||
"<extra_id_84>",
|
||||
"<extra_id_85>",
|
||||
"<extra_id_86>",
|
||||
"<extra_id_87>",
|
||||
"<extra_id_88>",
|
||||
"<extra_id_89>",
|
||||
"<extra_id_90>",
|
||||
"<extra_id_91>",
|
||||
"<extra_id_92>",
|
||||
"<extra_id_93>",
|
||||
"<extra_id_94>",
|
||||
"<extra_id_95>",
|
||||
"<extra_id_96>",
|
||||
"<extra_id_97>",
|
||||
"<extra_id_98>",
|
||||
"<extra_id_99>",
|
||||
"<extra_id_100>",
|
||||
"<extra_id_101>",
|
||||
"<extra_id_102>",
|
||||
"<extra_id_103>",
|
||||
"<extra_id_104>",
|
||||
"<extra_id_105>",
|
||||
"<extra_id_106>",
|
||||
"<extra_id_107>",
|
||||
"<extra_id_108>",
|
||||
"<extra_id_109>",
|
||||
"<extra_id_110>",
|
||||
"<extra_id_111>",
|
||||
"<extra_id_112>",
|
||||
"<extra_id_113>",
|
||||
"<extra_id_114>",
|
||||
"<extra_id_115>",
|
||||
"<extra_id_116>",
|
||||
"<extra_id_117>",
|
||||
"<extra_id_118>",
|
||||
"<extra_id_119>",
|
||||
"<extra_id_120>",
|
||||
"<extra_id_121>",
|
||||
"<extra_id_122>",
|
||||
"<extra_id_123>",
|
||||
"<extra_id_124>"
|
||||
],
|
||||
"eos_token": {
|
||||
"content": "</s>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"pad_token": {
|
||||
"content": "<pad>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"unk_token": {
|
||||
"content": "<unk>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,100 @@
|
||||
from comfy import sd1_clip
|
||||
import comfy.text_encoders.llama
|
||||
from .qwen_image import QwenImageTokenizer, QwenImageTEModel
|
||||
from transformers import ByT5Tokenizer
|
||||
import os
|
||||
import re
|
||||
|
||||
class ByT5SmallTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "byt5_tokenizer")
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=1472, embedding_key='byt5_small', tokenizer_class=ByT5Tokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_data=tokenizer_data)
|
||||
|
||||
class HunyuanImageTokenizer(QwenImageTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||
self.llama_template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>"
|
||||
# self.llama_template_images = "{}"
|
||||
self.byt5 = ByT5SmallTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
|
||||
out = super().tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||
|
||||
# ByT5 processing for HunyuanImage
|
||||
text_prompt_texts = []
|
||||
pattern_quote_single = r'\'(.*?)\''
|
||||
pattern_quote_double = r'\"(.*?)\"'
|
||||
pattern_quote_chinese_single = r'‘(.*?)’'
|
||||
pattern_quote_chinese_double = r'“(.*?)”'
|
||||
|
||||
matches_quote_single = re.findall(pattern_quote_single, text)
|
||||
matches_quote_double = re.findall(pattern_quote_double, text)
|
||||
matches_quote_chinese_single = re.findall(pattern_quote_chinese_single, text)
|
||||
matches_quote_chinese_double = re.findall(pattern_quote_chinese_double, text)
|
||||
|
||||
text_prompt_texts.extend(matches_quote_single)
|
||||
text_prompt_texts.extend(matches_quote_double)
|
||||
text_prompt_texts.extend(matches_quote_chinese_single)
|
||||
text_prompt_texts.extend(matches_quote_chinese_double)
|
||||
|
||||
if len(text_prompt_texts) > 0:
|
||||
out['byt5'] = self.byt5.tokenize_with_weights(''.join(map(lambda a: 'Text "{}". '.format(a), text_prompt_texts)), return_word_ids, **kwargs)
|
||||
return out
|
||||
|
||||
class Qwen25_7BVLIModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}):
|
||||
llama_scaled_fp8 = model_options.get("qwen_scaled_fp8", None)
|
||||
if llama_scaled_fp8 is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["scaled_fp8"] = llama_scaled_fp8
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||
|
||||
|
||||
class ByT5SmallModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "byt5_config_small_glyph.json")
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, model_options=model_options, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True)
|
||||
|
||||
|
||||
class HunyuanImageTEModel(QwenImageTEModel):
|
||||
def __init__(self, byt5=True, device="cpu", dtype=None, model_options={}):
|
||||
super(QwenImageTEModel, self).__init__(device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options)
|
||||
|
||||
if byt5:
|
||||
self.byt5_small = ByT5SmallModel(device=device, dtype=dtype, model_options=model_options)
|
||||
else:
|
||||
self.byt5_small = None
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
cond, p, extra = super().encode_token_weights(token_weight_pairs)
|
||||
if self.byt5_small is not None and "byt5" in token_weight_pairs:
|
||||
out = self.byt5_small.encode_token_weights(token_weight_pairs["byt5"])
|
||||
extra["conditioning_byt5small"] = out[0]
|
||||
return cond, p, extra
|
||||
|
||||
def set_clip_options(self, options):
|
||||
super().set_clip_options(options)
|
||||
if self.byt5_small is not None:
|
||||
self.byt5_small.set_clip_options(options)
|
||||
|
||||
def reset_clip_options(self):
|
||||
super().reset_clip_options()
|
||||
if self.byt5_small is not None:
|
||||
self.byt5_small.reset_clip_options()
|
||||
|
||||
def load_sd(self, sd):
|
||||
if "encoder.block.0.layer.0.SelfAttention.o.weight" in sd:
|
||||
return self.byt5_small.load_sd(sd)
|
||||
else:
|
||||
return super().load_sd(sd)
|
||||
|
||||
def te(byt5=True, dtype_llama=None, llama_scaled_fp8=None):
|
||||
class QwenImageTEModel_(HunyuanImageTEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
|
||||
model_options = model_options.copy()
|
||||
model_options["qwen_scaled_fp8"] = llama_scaled_fp8
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
super().__init__(byt5=byt5, device=device, dtype=dtype, model_options=model_options)
|
||||
return QwenImageTEModel_
|
||||
@@ -2,12 +2,14 @@ import torch
|
||||
import torch.nn as nn
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Any
|
||||
import math
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
import comfy.model_management
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
import comfy.model_management
|
||||
from . import qwen_vl
|
||||
|
||||
@dataclass
|
||||
class Llama2Config:
|
||||
@@ -25,6 +27,7 @@ class Llama2Config:
|
||||
rms_norm_add = False
|
||||
mlp_activation = "silu"
|
||||
qkv_bias = False
|
||||
rope_dims = None
|
||||
|
||||
@dataclass
|
||||
class Qwen25_3BConfig:
|
||||
@@ -42,6 +45,7 @@ class Qwen25_3BConfig:
|
||||
rms_norm_add = False
|
||||
mlp_activation = "silu"
|
||||
qkv_bias = True
|
||||
rope_dims = None
|
||||
|
||||
@dataclass
|
||||
class Qwen25_7BVLI_Config:
|
||||
@@ -59,6 +63,7 @@ class Qwen25_7BVLI_Config:
|
||||
rms_norm_add = False
|
||||
mlp_activation = "silu"
|
||||
qkv_bias = True
|
||||
rope_dims = [16, 24, 24]
|
||||
|
||||
@dataclass
|
||||
class Gemma2_2B_Config:
|
||||
@@ -76,6 +81,7 @@ class Gemma2_2B_Config:
|
||||
rms_norm_add = True
|
||||
mlp_activation = "gelu_pytorch_tanh"
|
||||
qkv_bias = False
|
||||
rope_dims = None
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
|
||||
@@ -100,27 +106,34 @@ def rotate_half(x):
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def precompute_freqs_cis(head_dim, seq_len, theta, device=None):
|
||||
def precompute_freqs_cis(head_dim, position_ids, theta, rope_dims=None, device=None):
|
||||
theta_numerator = torch.arange(0, head_dim, 2, device=device).float()
|
||||
inv_freq = 1.0 / (theta ** (theta_numerator / head_dim))
|
||||
|
||||
position_ids = torch.arange(0, seq_len, device=device).unsqueeze(0)
|
||||
|
||||
inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
||||
position_ids_expanded = position_ids[:, None, :].float()
|
||||
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
if rope_dims is not None and position_ids.shape[0] > 1:
|
||||
mrope_section = rope_dims * 2
|
||||
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
|
||||
sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
|
||||
else:
|
||||
cos = cos.unsqueeze(1)
|
||||
sin = sin.unsqueeze(1)
|
||||
|
||||
return (cos, sin)
|
||||
|
||||
|
||||
def apply_rope(xq, xk, freqs_cis):
|
||||
cos = freqs_cis[0].unsqueeze(1)
|
||||
sin = freqs_cis[1].unsqueeze(1)
|
||||
org_dtype = xq.dtype
|
||||
cos = freqs_cis[0]
|
||||
sin = freqs_cis[1]
|
||||
q_embed = (xq * cos) + (rotate_half(xq) * sin)
|
||||
k_embed = (xk * cos) + (rotate_half(xk) * sin)
|
||||
return q_embed, k_embed
|
||||
return q_embed.to(org_dtype), k_embed.to(org_dtype)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
@@ -277,7 +290,7 @@ class Llama2_(nn.Module):
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
|
||||
# self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
|
||||
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[]):
|
||||
if embeds is not None:
|
||||
x = embeds
|
||||
else:
|
||||
@@ -286,9 +299,13 @@ class Llama2_(nn.Module):
|
||||
if self.normalize_in:
|
||||
x *= self.config.hidden_size ** 0.5
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(0, x.shape[1], device=x.device).unsqueeze(0)
|
||||
|
||||
freqs_cis = precompute_freqs_cis(self.config.head_dim,
|
||||
x.shape[1],
|
||||
position_ids,
|
||||
self.config.rope_theta,
|
||||
self.config.rope_dims,
|
||||
device=x.device)
|
||||
|
||||
mask = None
|
||||
@@ -372,8 +389,38 @@ class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
|
||||
self.num_layers = config.num_hidden_layers
|
||||
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.visual = qwen_vl.Qwen2VLVisionTransformer(hidden_size=1280, output_hidden_size=config.hidden_size, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
def preprocess_embed(self, embed, device):
|
||||
if embed["type"] == "image":
|
||||
image, grid = qwen_vl.process_qwen2vl_images(embed["data"])
|
||||
return self.visual(image.to(device, dtype=torch.float32), grid), grid
|
||||
return None, None
|
||||
|
||||
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]):
|
||||
grid = None
|
||||
for e in embeds_info:
|
||||
if e.get("type") == "image":
|
||||
grid = e.get("extra", None)
|
||||
position_ids = torch.zeros((3, embeds.shape[1]), device=embeds.device)
|
||||
start = e.get("index")
|
||||
position_ids[:, :start] = torch.arange(0, start, device=embeds.device)
|
||||
end = e.get("size") + start
|
||||
len_max = int(grid.max()) // 2
|
||||
start_next = len_max + start
|
||||
position_ids[:, end:] = torch.arange(start_next, start_next + (embeds.shape[1] - end), device=embeds.device)
|
||||
position_ids[0, start:end] = start
|
||||
max_d = int(grid[0][1]) // 2
|
||||
position_ids[1, start:end] = torch.arange(start, start + max_d, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start]
|
||||
max_d = int(grid[0][2]) // 2
|
||||
position_ids[2, start:end] = torch.arange(start, start + max_d, device=embeds.device).unsqueeze(0).repeat(math.ceil((end - start) / max_d), 1).flatten(0)[:end - start]
|
||||
|
||||
if grid is None:
|
||||
position_ids = None
|
||||
|
||||
return super().forward(x, attention_mask=attention_mask, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=final_layer_norm_intermediate, dtype=dtype, position_ids=position_ids)
|
||||
|
||||
class Gemma2_2B(BaseLlama, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
|
||||
@@ -15,13 +15,27 @@ class QwenImageTokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen25_7b", tokenizer=Qwen25_7BVLITokenizer)
|
||||
self.llama_template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
self.llama_template_images = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None,**kwargs):
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], **kwargs):
|
||||
if llama_template is None:
|
||||
llama_text = self.llama_template.format(text)
|
||||
if len(images) > 0:
|
||||
llama_text = self.llama_template_images.format(text)
|
||||
else:
|
||||
llama_text = self.llama_template.format(text)
|
||||
else:
|
||||
llama_text = llama_template.format(text)
|
||||
return super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, **kwargs)
|
||||
tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
|
||||
key_name = next(iter(tokens))
|
||||
embed_count = 0
|
||||
qwen_tokens = tokens[key_name]
|
||||
for r in qwen_tokens:
|
||||
for i in range(len(r)):
|
||||
if r[i][0] == 151655:
|
||||
if len(images) > embed_count:
|
||||
r[i] = ({"type": "image", "data": images[embed_count], "original_type": "image"},) + r[i][1:]
|
||||
embed_count += 1
|
||||
return tokens
|
||||
|
||||
|
||||
class Qwen25_7BVLIModel(sd1_clip.SDClipModel):
|
||||
|
||||
@@ -0,0 +1,428 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from typing import Optional, Tuple
|
||||
import math
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
|
||||
|
||||
def process_qwen2vl_images(
|
||||
images: torch.Tensor,
|
||||
min_pixels: int = 3136,
|
||||
max_pixels: int = 12845056,
|
||||
patch_size: int = 14,
|
||||
temporal_patch_size: int = 2,
|
||||
merge_size: int = 2,
|
||||
image_mean: list = None,
|
||||
image_std: list = None,
|
||||
):
|
||||
if image_mean is None:
|
||||
image_mean = [0.48145466, 0.4578275, 0.40821073]
|
||||
if image_std is None:
|
||||
image_std = [0.26862954, 0.26130258, 0.27577711]
|
||||
|
||||
batch_size, height, width, channels = images.shape
|
||||
device = images.device
|
||||
# dtype = images.dtype
|
||||
|
||||
images = images.permute(0, 3, 1, 2)
|
||||
|
||||
grid_thw_list = []
|
||||
img = images[0]
|
||||
|
||||
factor = patch_size * merge_size
|
||||
|
||||
h_bar = round(height / factor) * factor
|
||||
w_bar = round(width / factor) * factor
|
||||
|
||||
if h_bar * w_bar > max_pixels:
|
||||
beta = math.sqrt((height * width) / max_pixels)
|
||||
h_bar = max(factor, math.floor(height / beta / factor) * factor)
|
||||
w_bar = max(factor, math.floor(width / beta / factor) * factor)
|
||||
elif h_bar * w_bar < min_pixels:
|
||||
beta = math.sqrt(min_pixels / (height * width))
|
||||
h_bar = math.ceil(height * beta / factor) * factor
|
||||
w_bar = math.ceil(width * beta / factor) * factor
|
||||
|
||||
img_resized = F.interpolate(
|
||||
img.unsqueeze(0),
|
||||
size=(h_bar, w_bar),
|
||||
mode='bilinear',
|
||||
align_corners=False
|
||||
).squeeze(0)
|
||||
|
||||
normalized = img_resized.clone()
|
||||
for c in range(3):
|
||||
normalized[c] = (img_resized[c] - image_mean[c]) / image_std[c]
|
||||
|
||||
grid_h = h_bar // patch_size
|
||||
grid_w = w_bar // patch_size
|
||||
grid_thw = torch.tensor([1, grid_h, grid_w], device=device, dtype=torch.long)
|
||||
|
||||
pixel_values = normalized
|
||||
grid_thw_list.append(grid_thw)
|
||||
image_grid_thw = torch.stack(grid_thw_list)
|
||||
|
||||
grid_t = 1
|
||||
channel = pixel_values.shape[0]
|
||||
pixel_values = pixel_values.unsqueeze(0).repeat(2, 1, 1, 1)
|
||||
|
||||
patches = pixel_values.reshape(
|
||||
grid_t,
|
||||
temporal_patch_size,
|
||||
channel,
|
||||
grid_h // merge_size,
|
||||
merge_size,
|
||||
patch_size,
|
||||
grid_w // merge_size,
|
||||
merge_size,
|
||||
patch_size,
|
||||
)
|
||||
|
||||
patches = patches.permute(0, 3, 6, 4, 7, 2, 1, 5, 8)
|
||||
flatten_patches = patches.reshape(
|
||||
grid_t * grid_h * grid_w,
|
||||
channel * temporal_patch_size * patch_size * patch_size
|
||||
)
|
||||
|
||||
return flatten_patches, image_grid_thw
|
||||
|
||||
|
||||
class VisionPatchEmbed(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 14,
|
||||
temporal_patch_size: int = 2,
|
||||
in_channels: int = 3,
|
||||
embed_dim: int = 3584,
|
||||
device=None,
|
||||
dtype=None,
|
||||
ops=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.temporal_patch_size = temporal_patch_size
|
||||
self.in_channels = in_channels
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
kernel_size = [temporal_patch_size, patch_size, patch_size]
|
||||
self.proj = ops.Conv3d(
|
||||
in_channels,
|
||||
embed_dim,
|
||||
kernel_size=kernel_size,
|
||||
stride=kernel_size,
|
||||
bias=False,
|
||||
device=device,
|
||||
dtype=dtype
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = hidden_states.view(
|
||||
-1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
|
||||
)
|
||||
hidden_states = self.proj(hidden_states)
|
||||
return hidden_states.view(-1, self.embed_dim)
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb_vision(q, k, cos, sin):
|
||||
cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
class VisionRotaryEmbedding(nn.Module):
|
||||
def __init__(self, dim: int, theta: float = 10000.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.theta = theta
|
||||
|
||||
def forward(self, seqlen: int, device) -> torch.Tensor:
|
||||
inv_freq = 1.0 / (self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float, device=device) / self.dim))
|
||||
seq = torch.arange(seqlen, device=inv_freq.device, dtype=inv_freq.dtype)
|
||||
freqs = torch.outer(seq, inv_freq)
|
||||
return freqs
|
||||
|
||||
|
||||
class PatchMerger(nn.Module):
|
||||
def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
self.hidden_size = context_dim * (spatial_merge_size ** 2)
|
||||
self.ln_q = ops.RMSNorm(context_dim, eps=1e-6, device=device, dtype=dtype)
|
||||
self.mlp = nn.Sequential(
|
||||
ops.Linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype),
|
||||
nn.GELU(),
|
||||
ops.Linear(self.hidden_size, dim, device=device, dtype=dtype),
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.ln_q(x).reshape(-1, self.hidden_size)
|
||||
x = self.mlp(x)
|
||||
return x
|
||||
|
||||
|
||||
class VisionAttention(nn.Module):
|
||||
def __init__(self, hidden_size: int, num_heads: int, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = hidden_size // num_heads
|
||||
self.scaling = self.head_dim ** -0.5
|
||||
|
||||
self.qkv = ops.Linear(hidden_size, hidden_size * 3, bias=True, device=device, dtype=dtype)
|
||||
self.proj = ops.Linear(hidden_size, hidden_size, bias=True, device=device, dtype=dtype)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
cu_seqlens=None,
|
||||
optimized_attention=None,
|
||||
) -> torch.Tensor:
|
||||
if hidden_states.dim() == 2:
|
||||
seq_length, _ = hidden_states.shape
|
||||
batch_size = 1
|
||||
hidden_states = hidden_states.unsqueeze(0)
|
||||
else:
|
||||
batch_size, seq_length, _ = hidden_states.shape
|
||||
|
||||
qkv = self.qkv(hidden_states)
|
||||
qkv = qkv.reshape(batch_size, seq_length, 3, self.num_heads, self.head_dim)
|
||||
query_states, key_states, value_states = qkv.reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
|
||||
|
||||
if position_embeddings is not None:
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
|
||||
|
||||
query_states = query_states.transpose(0, 1).unsqueeze(0)
|
||||
key_states = key_states.transpose(0, 1).unsqueeze(0)
|
||||
value_states = value_states.transpose(0, 1).unsqueeze(0)
|
||||
|
||||
lengths = cu_seqlens[1:] - cu_seqlens[:-1]
|
||||
splits = [
|
||||
torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states)
|
||||
]
|
||||
|
||||
attn_outputs = [
|
||||
optimized_attention(q, k, v, self.num_heads, skip_reshape=True)
|
||||
for q, k, v in zip(*splits)
|
||||
]
|
||||
attn_output = torch.cat(attn_outputs, dim=1)
|
||||
attn_output = attn_output.reshape(seq_length, -1)
|
||||
attn_output = self.proj(attn_output)
|
||||
|
||||
return attn_output
|
||||
|
||||
|
||||
class VisionMLP(nn.Module):
|
||||
def __init__(self, hidden_size: int, intermediate_size: int, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
self.gate_proj = ops.Linear(hidden_size, intermediate_size, bias=True, device=device, dtype=dtype)
|
||||
self.up_proj = ops.Linear(hidden_size, intermediate_size, bias=True, device=device, dtype=dtype)
|
||||
self.down_proj = ops.Linear(intermediate_size, hidden_size, bias=True, device=device, dtype=dtype)
|
||||
self.act_fn = nn.SiLU()
|
||||
|
||||
def forward(self, hidden_state):
|
||||
return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
|
||||
|
||||
|
||||
class VisionBlock(nn.Module):
|
||||
def __init__(self, hidden_size: int, intermediate_size: int, num_heads: int, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
self.norm1 = ops.RMSNorm(hidden_size, eps=1e-6, device=device, dtype=dtype)
|
||||
self.norm2 = ops.RMSNorm(hidden_size, eps=1e-6, device=device, dtype=dtype)
|
||||
self.attn = VisionAttention(hidden_size, num_heads, device=device, dtype=dtype, ops=ops)
|
||||
self.mlp = VisionMLP(hidden_size, intermediate_size, device=device, dtype=dtype, ops=ops)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
cu_seqlens=None,
|
||||
optimized_attention=None,
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
hidden_states = self.norm1(hidden_states)
|
||||
hidden_states = self.attn(hidden_states, position_embeddings, cu_seqlens, optimized_attention)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.norm2(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Qwen2VLVisionTransformer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int = 3584,
|
||||
output_hidden_size: int = 3584,
|
||||
intermediate_size: int = 3420,
|
||||
num_heads: int = 16,
|
||||
num_layers: int = 32,
|
||||
patch_size: int = 14,
|
||||
temporal_patch_size: int = 2,
|
||||
spatial_merge_size: int = 2,
|
||||
window_size: int = 112,
|
||||
device=None,
|
||||
dtype=None,
|
||||
ops=None
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.patch_size = patch_size
|
||||
self.spatial_merge_size = spatial_merge_size
|
||||
self.window_size = window_size
|
||||
self.fullatt_block_indexes = [7, 15, 23, 31]
|
||||
|
||||
self.patch_embed = VisionPatchEmbed(
|
||||
patch_size=patch_size,
|
||||
temporal_patch_size=temporal_patch_size,
|
||||
in_channels=3,
|
||||
embed_dim=hidden_size,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
ops=ops,
|
||||
)
|
||||
|
||||
head_dim = hidden_size // num_heads
|
||||
self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
|
||||
|
||||
self.blocks = nn.ModuleList([
|
||||
VisionBlock(hidden_size, intermediate_size, num_heads, device, dtype, ops)
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
|
||||
self.merger = PatchMerger(
|
||||
dim=output_hidden_size,
|
||||
context_dim=hidden_size,
|
||||
spatial_merge_size=spatial_merge_size,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
ops=ops,
|
||||
)
|
||||
|
||||
def get_window_index(self, grid_thw):
|
||||
window_index = []
|
||||
cu_window_seqlens = [0]
|
||||
window_index_id = 0
|
||||
vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size
|
||||
|
||||
for grid_t, grid_h, grid_w in grid_thw:
|
||||
llm_grid_h = grid_h // self.spatial_merge_size
|
||||
llm_grid_w = grid_w // self.spatial_merge_size
|
||||
|
||||
index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w)
|
||||
|
||||
pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
|
||||
pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
|
||||
num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
|
||||
num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
|
||||
|
||||
index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
|
||||
index_padded = index_padded.reshape(
|
||||
grid_t,
|
||||
num_windows_h,
|
||||
vit_merger_window_size,
|
||||
num_windows_w,
|
||||
vit_merger_window_size,
|
||||
)
|
||||
index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
|
||||
grid_t,
|
||||
num_windows_h * num_windows_w,
|
||||
vit_merger_window_size,
|
||||
vit_merger_window_size,
|
||||
)
|
||||
|
||||
seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
|
||||
index_padded = index_padded.reshape(-1)
|
||||
index_new = index_padded[index_padded != -100]
|
||||
window_index.append(index_new + window_index_id)
|
||||
|
||||
cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_size * self.spatial_merge_size + cu_window_seqlens[-1]
|
||||
cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
|
||||
window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
|
||||
|
||||
window_index = torch.cat(window_index, dim=0)
|
||||
return window_index, cu_window_seqlens
|
||||
|
||||
def get_position_embeddings(self, grid_thw, device):
|
||||
pos_ids = []
|
||||
|
||||
for t, h, w in grid_thw:
|
||||
hpos_ids = torch.arange(h, device=device).unsqueeze(1).expand(-1, w)
|
||||
hpos_ids = hpos_ids.reshape(
|
||||
h // self.spatial_merge_size,
|
||||
self.spatial_merge_size,
|
||||
w // self.spatial_merge_size,
|
||||
self.spatial_merge_size,
|
||||
)
|
||||
hpos_ids = hpos_ids.permute(0, 2, 1, 3).flatten()
|
||||
|
||||
wpos_ids = torch.arange(w, device=device).unsqueeze(0).expand(h, -1)
|
||||
wpos_ids = wpos_ids.reshape(
|
||||
h // self.spatial_merge_size,
|
||||
self.spatial_merge_size,
|
||||
w // self.spatial_merge_size,
|
||||
self.spatial_merge_size,
|
||||
)
|
||||
wpos_ids = wpos_ids.permute(0, 2, 1, 3).flatten()
|
||||
|
||||
pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
|
||||
|
||||
pos_ids = torch.cat(pos_ids, dim=0)
|
||||
max_grid_size = grid_thw[:, 1:].max()
|
||||
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size, device)
|
||||
return rotary_pos_emb_full[pos_ids].flatten(1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: torch.Tensor,
|
||||
image_grid_thw: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
optimized_attention = optimized_attention_for_device(pixel_values.device, mask=False, small_input=True)
|
||||
|
||||
hidden_states = self.patch_embed(pixel_values)
|
||||
|
||||
window_index, cu_window_seqlens = self.get_window_index(image_grid_thw)
|
||||
cu_window_seqlens = torch.tensor(cu_window_seqlens, device=hidden_states.device)
|
||||
cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
|
||||
|
||||
position_embeddings = self.get_position_embeddings(image_grid_thw, hidden_states.device)
|
||||
|
||||
seq_len, _ = hidden_states.size()
|
||||
spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size
|
||||
|
||||
hidden_states = hidden_states.reshape(seq_len // spatial_merge_unit, spatial_merge_unit, -1)
|
||||
hidden_states = hidden_states[window_index, :, :]
|
||||
hidden_states = hidden_states.reshape(seq_len, -1)
|
||||
|
||||
position_embeddings = position_embeddings.reshape(seq_len // spatial_merge_unit, spatial_merge_unit, -1)
|
||||
position_embeddings = position_embeddings[window_index, :, :]
|
||||
position_embeddings = position_embeddings.reshape(seq_len, -1)
|
||||
position_embeddings = torch.cat((position_embeddings, position_embeddings), dim=-1)
|
||||
position_embeddings = (position_embeddings.cos(), position_embeddings.sin())
|
||||
|
||||
cu_seqlens = torch.repeat_interleave(image_grid_thw[:, 1] * image_grid_thw[:, 2], image_grid_thw[:, 0]).cumsum(
|
||||
dim=0,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
|
||||
|
||||
for i, block in enumerate(self.blocks):
|
||||
if i in self.fullatt_block_indexes:
|
||||
cu_seqlens_now = cu_seqlens
|
||||
else:
|
||||
cu_seqlens_now = cu_window_seqlens
|
||||
hidden_states = block(hidden_states, position_embeddings, cu_seqlens_now, optimized_attention=optimized_attention)
|
||||
|
||||
hidden_states = self.merger(hidden_states)
|
||||
return hidden_states
|
||||
@@ -199,7 +199,7 @@ class T5Stack(torch.nn.Module):
|
||||
self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device, operations=operations)
|
||||
# self.dropout = nn.Dropout(config.dropout_rate)
|
||||
|
||||
def forward(self, x, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
|
||||
def forward(self, x, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]):
|
||||
mask = None
|
||||
if attention_mask is not None:
|
||||
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
|
||||
|
||||
@@ -97,6 +97,9 @@ class LoKrAdapter(WeightAdapterBase):
|
||||
(mat1, mat2, alpha, None, None, None, None, None, None)
|
||||
)
|
||||
|
||||
def to_train(self):
|
||||
return LokrDiff(self.weights)
|
||||
|
||||
@classmethod
|
||||
def load(
|
||||
cls,
|
||||
|
||||
@@ -96,6 +96,7 @@ class LoRAAdapter(WeightAdapterBase):
|
||||
diffusers3_lora = "{}.lora.up.weight".format(x)
|
||||
mochi_lora = "{}.lora_B".format(x)
|
||||
transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
|
||||
qwen_default_lora = "{}.lora_B.default.weight".format(x)
|
||||
A_name = None
|
||||
|
||||
if regular_lora in lora.keys():
|
||||
@@ -122,6 +123,10 @@ class LoRAAdapter(WeightAdapterBase):
|
||||
A_name = transformers_lora
|
||||
B_name = "{}.lora_linear_layer.down.weight".format(x)
|
||||
mid_name = None
|
||||
elif qwen_default_lora in lora.keys():
|
||||
A_name = qwen_default_lora
|
||||
B_name = "{}.lora_A.default.weight".format(x)
|
||||
mid_name = None
|
||||
|
||||
if A_name is not None:
|
||||
mid = None
|
||||
|
||||
@@ -8,6 +8,7 @@ import av
|
||||
import io
|
||||
import json
|
||||
import numpy as np
|
||||
import math
|
||||
import torch
|
||||
from comfy_api.latest._util import VideoContainer, VideoCodec, VideoComponents
|
||||
|
||||
@@ -282,8 +283,6 @@ class VideoFromComponents(VideoInput):
|
||||
if self.__components.audio:
|
||||
audio_sample_rate = int(self.__components.audio['sample_rate'])
|
||||
audio_stream = output.add_stream('aac', rate=audio_sample_rate)
|
||||
audio_stream.sample_rate = audio_sample_rate
|
||||
audio_stream.format = 'fltp'
|
||||
|
||||
# Encode video
|
||||
for i, frame in enumerate(self.__components.images):
|
||||
@@ -298,27 +297,12 @@ class VideoFromComponents(VideoInput):
|
||||
output.mux(packet)
|
||||
|
||||
if audio_stream and self.__components.audio:
|
||||
# Encode audio
|
||||
samples_per_frame = int(audio_sample_rate / frame_rate)
|
||||
num_frames = self.__components.audio['waveform'].shape[2] // samples_per_frame
|
||||
for i in range(num_frames):
|
||||
start = i * samples_per_frame
|
||||
end = start + samples_per_frame
|
||||
# TODO(Feature) - Add support for stereo audio
|
||||
chunk = (
|
||||
self.__components.audio["waveform"][0, 0, start:end]
|
||||
.unsqueeze(0)
|
||||
.contiguous()
|
||||
.numpy()
|
||||
)
|
||||
audio_frame = av.AudioFrame.from_ndarray(chunk, format='fltp', layout='mono')
|
||||
audio_frame.sample_rate = audio_sample_rate
|
||||
audio_frame.pts = i * samples_per_frame
|
||||
for packet in audio_stream.encode(audio_frame):
|
||||
output.mux(packet)
|
||||
|
||||
# Flush audio
|
||||
for packet in audio_stream.encode(None):
|
||||
output.mux(packet)
|
||||
|
||||
waveform = self.__components.audio['waveform']
|
||||
waveform = waveform[:, :, :math.ceil((audio_sample_rate / frame_rate) * self.__components.images.shape[0])]
|
||||
frame = av.AudioFrame.from_ndarray(waveform.movedim(2, 1).reshape(1, -1).float().numpy(), format='flt', layout='mono' if waveform.shape[1] == 1 else 'stereo')
|
||||
frame.sample_rate = audio_sample_rate
|
||||
frame.pts = 0
|
||||
output.mux(audio_stream.encode(frame))
|
||||
|
||||
# Flush encoder
|
||||
output.mux(audio_stream.encode(None))
|
||||
|
||||
+21
-3
@@ -726,6 +726,18 @@ class SEGS(ComfyTypeIO):
|
||||
class AnyType(ComfyTypeIO):
|
||||
Type = Any
|
||||
|
||||
@comfytype(io_type="MODEL_PATCH")
|
||||
class MODEL_PATCH(ComfyTypeIO):
|
||||
Type = Any
|
||||
|
||||
@comfytype(io_type="AUDIO_ENCODER")
|
||||
class AudioEncoder(ComfyTypeIO):
|
||||
Type = Any
|
||||
|
||||
@comfytype(io_type="AUDIO_ENCODER_OUTPUT")
|
||||
class AudioEncoderOutput(ComfyTypeIO):
|
||||
Type = Any
|
||||
|
||||
@comfytype(io_type="COMFY_MULTITYPED_V3")
|
||||
class MultiType:
|
||||
Type = Any
|
||||
@@ -1178,13 +1190,18 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def validate_inputs(cls, **kwargs) -> bool:
|
||||
"""Optionally, define this function to validate inputs; equivalent to V1's VALIDATE_INPUTS."""
|
||||
def validate_inputs(cls, **kwargs) -> bool | str:
|
||||
"""Optionally, define this function to validate inputs; equivalent to V1's VALIDATE_INPUTS.
|
||||
|
||||
If the function returns a string, it will be used as the validation error message for the node.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def fingerprint_inputs(cls, **kwargs) -> Any:
|
||||
"""Optionally, define this function to fingerprint inputs; equivalent to V1's IS_CHANGED."""
|
||||
"""Optionally, define this function to fingerprint inputs; equivalent to V1's IS_CHANGED.
|
||||
|
||||
If this function returns the same value as last run, the node will not be executed."""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@@ -1580,6 +1597,7 @@ class _IO:
|
||||
Model = Model
|
||||
ClipVision = ClipVision
|
||||
ClipVisionOutput = ClipVisionOutput
|
||||
AudioEncoderOutput = AudioEncoderOutput
|
||||
StyleModel = StyleModel
|
||||
Gligen = Gligen
|
||||
UpscaleModel = UpscaleModel
|
||||
|
||||
@@ -9,7 +9,11 @@ from typing import Type
|
||||
import av
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchaudio
|
||||
try:
|
||||
import torchaudio
|
||||
TORCH_AUDIO_AVAILABLE = True
|
||||
except:
|
||||
TORCH_AUDIO_AVAILABLE = False
|
||||
from PIL import Image as PILImage
|
||||
from PIL.PngImagePlugin import PngInfo
|
||||
|
||||
@@ -302,6 +306,8 @@ class AudioSaveHelper:
|
||||
|
||||
# Resample if necessary
|
||||
if sample_rate != audio["sample_rate"]:
|
||||
if not TORCH_AUDIO_AVAILABLE:
|
||||
raise Exception("torchaudio is not available; cannot resample audio.")
|
||||
waveform = torchaudio.functional.resample(waveform, audio["sample_rate"], sample_rate)
|
||||
|
||||
# Create output with specified format
|
||||
|
||||
+115
-102
@@ -1,4 +1,5 @@
|
||||
from __future__ import annotations
|
||||
import aiohttp
|
||||
import io
|
||||
import logging
|
||||
import mimetypes
|
||||
@@ -21,7 +22,6 @@ from server import PromptServer
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import requests
|
||||
import torch
|
||||
import math
|
||||
import base64
|
||||
@@ -30,7 +30,7 @@ from io import BytesIO
|
||||
import av
|
||||
|
||||
|
||||
def download_url_to_video_output(video_url: str, timeout: int = None) -> VideoFromFile:
|
||||
async def download_url_to_video_output(video_url: str, timeout: int = None) -> VideoFromFile:
|
||||
"""Downloads a video from a URL and returns a `VIDEO` output.
|
||||
|
||||
Args:
|
||||
@@ -39,7 +39,7 @@ def download_url_to_video_output(video_url: str, timeout: int = None) -> VideoFr
|
||||
Returns:
|
||||
A Comfy node `VIDEO` output.
|
||||
"""
|
||||
video_io = download_url_to_bytesio(video_url, timeout)
|
||||
video_io = await download_url_to_bytesio(video_url, timeout)
|
||||
if video_io is None:
|
||||
error_msg = f"Failed to download video from {video_url}"
|
||||
logging.error(error_msg)
|
||||
@@ -62,7 +62,7 @@ def downscale_image_tensor(image, total_pixels=1536 * 1024) -> torch.Tensor:
|
||||
return s
|
||||
|
||||
|
||||
def validate_and_cast_response(
|
||||
async def validate_and_cast_response(
|
||||
response, timeout: int = None, node_id: Union[str, None] = None
|
||||
) -> torch.Tensor:
|
||||
"""Validates and casts a response to a torch.Tensor.
|
||||
@@ -86,35 +86,24 @@ def validate_and_cast_response(
|
||||
image_tensors: list[torch.Tensor] = []
|
||||
|
||||
# Process each image in the data array
|
||||
for image_data in data:
|
||||
image_url = image_data.url
|
||||
b64_data = image_data.b64_json
|
||||
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=timeout)) as session:
|
||||
for img_data in data:
|
||||
img_bytes: bytes
|
||||
if img_data.b64_json:
|
||||
img_bytes = base64.b64decode(img_data.b64_json)
|
||||
elif img_data.url:
|
||||
if node_id:
|
||||
PromptServer.instance.send_progress_text(f"Result URL: {img_data.url}", node_id)
|
||||
async with session.get(img_data.url) as resp:
|
||||
if resp.status != 200:
|
||||
raise ValueError("Failed to download generated image")
|
||||
img_bytes = await resp.read()
|
||||
else:
|
||||
raise ValueError("Invalid image payload – neither URL nor base64 data present.")
|
||||
|
||||
if not image_url and not b64_data:
|
||||
raise ValueError("No image was generated in the response")
|
||||
|
||||
if b64_data:
|
||||
img_data = base64.b64decode(b64_data)
|
||||
img = Image.open(io.BytesIO(img_data))
|
||||
|
||||
elif image_url:
|
||||
if node_id:
|
||||
PromptServer.instance.send_progress_text(
|
||||
f"Result URL: {image_url}", node_id
|
||||
)
|
||||
img_response = requests.get(image_url, timeout=timeout)
|
||||
if img_response.status_code != 200:
|
||||
raise ValueError("Failed to download the image")
|
||||
img = Image.open(io.BytesIO(img_response.content))
|
||||
|
||||
img = img.convert("RGBA")
|
||||
|
||||
# Convert to numpy array, normalize to float32 between 0 and 1
|
||||
img_array = np.array(img).astype(np.float32) / 255.0
|
||||
img_tensor = torch.from_numpy(img_array)
|
||||
|
||||
# Add to list of tensors
|
||||
image_tensors.append(img_tensor)
|
||||
pil_img = Image.open(BytesIO(img_bytes)).convert("RGBA")
|
||||
arr = np.asarray(pil_img).astype(np.float32) / 255.0
|
||||
image_tensors.append(torch.from_numpy(arr))
|
||||
|
||||
return torch.stack(image_tensors, dim=0)
|
||||
|
||||
@@ -175,7 +164,7 @@ def mimetype_to_extension(mime_type: str) -> str:
|
||||
return mime_type.split("/")[-1].lower()
|
||||
|
||||
|
||||
def download_url_to_bytesio(url: str, timeout: int = None) -> BytesIO:
|
||||
async def download_url_to_bytesio(url: str, timeout: int = None) -> BytesIO:
|
||||
"""Downloads content from a URL using requests and returns it as BytesIO.
|
||||
|
||||
Args:
|
||||
@@ -185,9 +174,11 @@ def download_url_to_bytesio(url: str, timeout: int = None) -> BytesIO:
|
||||
Returns:
|
||||
BytesIO object containing the downloaded content.
|
||||
"""
|
||||
response = requests.get(url, stream=True, timeout=timeout)
|
||||
response.raise_for_status() # Raises HTTPError for bad responses (4XX or 5XX)
|
||||
return BytesIO(response.content)
|
||||
timeout_cfg = aiohttp.ClientTimeout(total=timeout) if timeout else None
|
||||
async with aiohttp.ClientSession(timeout=timeout_cfg) as session:
|
||||
async with session.get(url) as resp:
|
||||
resp.raise_for_status() # Raises HTTPError for bad responses (4XX or 5XX)
|
||||
return BytesIO(await resp.read())
|
||||
|
||||
|
||||
def bytesio_to_image_tensor(image_bytesio: BytesIO, mode: str = "RGBA") -> torch.Tensor:
|
||||
@@ -210,15 +201,15 @@ def bytesio_to_image_tensor(image_bytesio: BytesIO, mode: str = "RGBA") -> torch
|
||||
return torch.from_numpy(image_array).unsqueeze(0)
|
||||
|
||||
|
||||
def download_url_to_image_tensor(url: str, timeout: int = None) -> torch.Tensor:
|
||||
async def download_url_to_image_tensor(url: str, timeout: int = None) -> torch.Tensor:
|
||||
"""Downloads an image from a URL and returns a [B, H, W, C] tensor."""
|
||||
image_bytesio = download_url_to_bytesio(url, timeout)
|
||||
image_bytesio = await download_url_to_bytesio(url, timeout)
|
||||
return bytesio_to_image_tensor(image_bytesio)
|
||||
|
||||
|
||||
def process_image_response(response: requests.Response) -> torch.Tensor:
|
||||
def process_image_response(response_content: bytes | str) -> torch.Tensor:
|
||||
"""Uses content from a Response object and converts it to a torch.Tensor"""
|
||||
return bytesio_to_image_tensor(BytesIO(response.content))
|
||||
return bytesio_to_image_tensor(BytesIO(response_content))
|
||||
|
||||
|
||||
def _tensor_to_pil(image: torch.Tensor, total_pixels: int = 2048 * 2048) -> Image.Image:
|
||||
@@ -336,10 +327,10 @@ def text_filepath_to_data_uri(filepath: str) -> str:
|
||||
return f"data:{mime_type};base64,{base64_string}"
|
||||
|
||||
|
||||
def upload_file_to_comfyapi(
|
||||
async def upload_file_to_comfyapi(
|
||||
file_bytes_io: BytesIO,
|
||||
filename: str,
|
||||
upload_mime_type: str,
|
||||
upload_mime_type: Optional[str],
|
||||
auth_kwargs: Optional[dict[str, str]] = None,
|
||||
) -> str:
|
||||
"""
|
||||
@@ -354,7 +345,10 @@ def upload_file_to_comfyapi(
|
||||
Returns:
|
||||
The download URL for the uploaded file.
|
||||
"""
|
||||
request_object = UploadRequest(file_name=filename, content_type=upload_mime_type)
|
||||
if upload_mime_type is None:
|
||||
request_object = UploadRequest(file_name=filename)
|
||||
else:
|
||||
request_object = UploadRequest(file_name=filename, content_type=upload_mime_type)
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/customers/storage",
|
||||
@@ -366,12 +360,8 @@ def upload_file_to_comfyapi(
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
|
||||
response: UploadResponse = operation.execute()
|
||||
upload_response = ApiClient.upload_file(
|
||||
response.upload_url, file_bytes_io, content_type=upload_mime_type
|
||||
)
|
||||
upload_response.raise_for_status()
|
||||
|
||||
response: UploadResponse = await operation.execute()
|
||||
await ApiClient.upload_file(response.upload_url, file_bytes_io, content_type=upload_mime_type)
|
||||
return response.download_url
|
||||
|
||||
|
||||
@@ -399,7 +389,7 @@ def video_to_base64_string(
|
||||
return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8")
|
||||
|
||||
|
||||
def upload_video_to_comfyapi(
|
||||
async def upload_video_to_comfyapi(
|
||||
video: VideoInput,
|
||||
auth_kwargs: Optional[dict[str, str]] = None,
|
||||
container: VideoContainer = VideoContainer.MP4,
|
||||
@@ -439,9 +429,7 @@ def upload_video_to_comfyapi(
|
||||
video.save_to(video_bytes_io, format=container, codec=codec)
|
||||
video_bytes_io.seek(0)
|
||||
|
||||
return upload_file_to_comfyapi(
|
||||
video_bytes_io, filename, upload_mime_type, auth_kwargs
|
||||
)
|
||||
return await upload_file_to_comfyapi(video_bytes_io, filename, upload_mime_type, auth_kwargs)
|
||||
|
||||
|
||||
def audio_tensor_to_contiguous_ndarray(waveform: torch.Tensor) -> np.ndarray:
|
||||
@@ -501,7 +489,7 @@ def audio_ndarray_to_bytesio(
|
||||
return audio_bytes_io
|
||||
|
||||
|
||||
def upload_audio_to_comfyapi(
|
||||
async def upload_audio_to_comfyapi(
|
||||
audio: AudioInput,
|
||||
auth_kwargs: Optional[dict[str, str]] = None,
|
||||
container_format: str = "mp4",
|
||||
@@ -527,7 +515,72 @@ def upload_audio_to_comfyapi(
|
||||
audio_data_np, sample_rate, container_format, codec_name
|
||||
)
|
||||
|
||||
return upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_kwargs)
|
||||
return await upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_kwargs)
|
||||
|
||||
|
||||
def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
|
||||
"""Convert audio to float 32 bits PCM format. Copy-paste from nodes_audio.py file."""
|
||||
if wav.dtype.is_floating_point:
|
||||
return wav
|
||||
elif wav.dtype == torch.int16:
|
||||
return wav.float() / (2 ** 15)
|
||||
elif wav.dtype == torch.int32:
|
||||
return wav.float() / (2 ** 31)
|
||||
raise ValueError(f"Unsupported wav dtype: {wav.dtype}")
|
||||
|
||||
|
||||
def audio_bytes_to_audio_input(audio_bytes: bytes,) -> dict:
|
||||
"""
|
||||
Decode any common audio container from bytes using PyAV and return
|
||||
a Comfy AUDIO dict: {"waveform": [1, C, T] float32, "sample_rate": int}.
|
||||
"""
|
||||
with av.open(io.BytesIO(audio_bytes)) as af:
|
||||
if not af.streams.audio:
|
||||
raise ValueError("No audio stream found in response.")
|
||||
stream = af.streams.audio[0]
|
||||
|
||||
in_sr = int(stream.codec_context.sample_rate)
|
||||
out_sr = in_sr
|
||||
|
||||
frames: list[torch.Tensor] = []
|
||||
n_channels = stream.channels or 1
|
||||
|
||||
for frame in af.decode(streams=stream.index):
|
||||
arr = frame.to_ndarray() # shape can be [C, T] or [T, C] or [T]
|
||||
buf = torch.from_numpy(arr)
|
||||
if buf.ndim == 1:
|
||||
buf = buf.unsqueeze(0) # [T] -> [1, T]
|
||||
elif buf.shape[0] != n_channels and buf.shape[-1] == n_channels:
|
||||
buf = buf.transpose(0, 1).contiguous() # [T, C] -> [C, T]
|
||||
elif buf.shape[0] != n_channels:
|
||||
buf = buf.reshape(-1, n_channels).t().contiguous() # fallback to [C, T]
|
||||
frames.append(buf)
|
||||
|
||||
if not frames:
|
||||
raise ValueError("Decoded zero audio frames.")
|
||||
|
||||
wav = torch.cat(frames, dim=1) # [C, T]
|
||||
wav = f32_pcm(wav)
|
||||
return {"waveform": wav.unsqueeze(0).contiguous(), "sample_rate": out_sr}
|
||||
|
||||
|
||||
def audio_input_to_mp3(audio: AudioInput) -> io.BytesIO:
|
||||
waveform = audio["waveform"].cpu()
|
||||
|
||||
output_buffer = io.BytesIO()
|
||||
output_container = av.open(output_buffer, mode='w', format="mp3")
|
||||
|
||||
out_stream = output_container.add_stream("libmp3lame", rate=audio["sample_rate"])
|
||||
out_stream.bit_rate = 320000
|
||||
|
||||
frame = av.AudioFrame.from_ndarray(waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format='flt', layout='mono' if waveform.shape[0] == 1 else 'stereo')
|
||||
frame.sample_rate = audio["sample_rate"]
|
||||
frame.pts = 0
|
||||
output_container.mux(out_stream.encode(frame))
|
||||
output_container.mux(out_stream.encode(None))
|
||||
output_container.close()
|
||||
output_buffer.seek(0)
|
||||
return output_buffer
|
||||
|
||||
|
||||
def audio_to_base64_string(
|
||||
@@ -544,7 +597,7 @@ def audio_to_base64_string(
|
||||
return base64.b64encode(audio_bytes).decode("utf-8")
|
||||
|
||||
|
||||
def upload_images_to_comfyapi(
|
||||
async def upload_images_to_comfyapi(
|
||||
image: torch.Tensor,
|
||||
max_images=8,
|
||||
auth_kwargs: Optional[dict[str, str]] = None,
|
||||
@@ -561,55 +614,15 @@ def upload_images_to_comfyapi(
|
||||
mime_type: Optional MIME type for the image.
|
||||
"""
|
||||
# if batch, try to upload each file if max_images is greater than 0
|
||||
idx_image = 0
|
||||
download_urls: list[str] = []
|
||||
is_batch = len(image.shape) > 3
|
||||
batch_length = 1
|
||||
if is_batch:
|
||||
batch_length = image.shape[0]
|
||||
while True:
|
||||
curr_image = image
|
||||
if len(image.shape) > 3:
|
||||
curr_image = image[idx_image]
|
||||
# get BytesIO version of image
|
||||
img_binary = tensor_to_bytesio(curr_image, mime_type=mime_type)
|
||||
# first, request upload/download urls from comfy API
|
||||
if not mime_type:
|
||||
request_object = UploadRequest(file_name=img_binary.name)
|
||||
else:
|
||||
request_object = UploadRequest(
|
||||
file_name=img_binary.name, content_type=mime_type
|
||||
)
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/customers/storage",
|
||||
method=HttpMethod.POST,
|
||||
request_model=UploadRequest,
|
||||
response_model=UploadResponse,
|
||||
),
|
||||
request=request_object,
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
response = operation.execute()
|
||||
batch_len = image.shape[0] if is_batch else 1
|
||||
|
||||
upload_response = ApiClient.upload_file(
|
||||
response.upload_url, img_binary, content_type=mime_type
|
||||
)
|
||||
# verify success
|
||||
try:
|
||||
upload_response.raise_for_status()
|
||||
except requests.exceptions.HTTPError as e:
|
||||
raise ValueError(f"Could not upload one or more images: {e}") from e
|
||||
# add download_url to list
|
||||
download_urls.append(response.download_url)
|
||||
|
||||
idx_image += 1
|
||||
# stop uploading additional files if done
|
||||
if is_batch and max_images > 0:
|
||||
if idx_image >= max_images:
|
||||
break
|
||||
if idx_image >= batch_length:
|
||||
break
|
||||
for idx in range(min(batch_len, max_images)):
|
||||
tensor = image[idx] if is_batch else image
|
||||
img_io = tensor_to_bytesio(tensor, mime_type=mime_type)
|
||||
url = await upload_file_to_comfyapi(img_io, img_io.name, mime_type, auth_kwargs)
|
||||
download_urls.append(url)
|
||||
return download_urls
|
||||
|
||||
|
||||
|
||||
Generated
+35
-3
@@ -951,7 +951,11 @@ class MagicPrompt2(str, Enum):
|
||||
|
||||
|
||||
class StyleType1(str, Enum):
|
||||
AUTO = 'AUTO'
|
||||
GENERAL = 'GENERAL'
|
||||
REALISTIC = 'REALISTIC'
|
||||
DESIGN = 'DESIGN'
|
||||
FICTION = 'FICTION'
|
||||
|
||||
|
||||
class ImagenImageGenerationInstance(BaseModel):
|
||||
@@ -1315,6 +1319,7 @@ class KlingTaskStatus(str, Enum):
|
||||
class KlingTextToVideoModelName(str, Enum):
|
||||
kling_v1 = 'kling-v1'
|
||||
kling_v1_6 = 'kling-v1-6'
|
||||
kling_v2_1_master = 'kling-v2-1-master'
|
||||
|
||||
|
||||
class KlingVideoGenAspectRatio(str, Enum):
|
||||
@@ -1347,6 +1352,8 @@ class KlingVideoGenModelName(str, Enum):
|
||||
kling_v1_5 = 'kling-v1-5'
|
||||
kling_v1_6 = 'kling-v1-6'
|
||||
kling_v2_master = 'kling-v2-master'
|
||||
kling_v2_1 = 'kling-v2-1'
|
||||
kling_v2_1_master = 'kling-v2-1-master'
|
||||
|
||||
|
||||
class KlingVideoResult(BaseModel):
|
||||
@@ -1620,13 +1627,14 @@ class MinimaxTaskResultResponse(BaseModel):
|
||||
task_id: str = Field(..., description='The task ID being queried.')
|
||||
|
||||
|
||||
class Model(str, Enum):
|
||||
class MiniMaxModel(str, Enum):
|
||||
T2V_01_Director = 'T2V-01-Director'
|
||||
I2V_01_Director = 'I2V-01-Director'
|
||||
S2V_01 = 'S2V-01'
|
||||
I2V_01 = 'I2V-01'
|
||||
I2V_01_live = 'I2V-01-live'
|
||||
T2V_01 = 'T2V-01'
|
||||
Hailuo_02 = 'MiniMax-Hailuo-02'
|
||||
|
||||
|
||||
class SubjectReferenceItem(BaseModel):
|
||||
@@ -1648,7 +1656,7 @@ class MinimaxVideoGenerationRequest(BaseModel):
|
||||
None,
|
||||
description='URL or base64 encoding of the first frame image. Required when model is I2V-01, I2V-01-Director, or I2V-01-live.',
|
||||
)
|
||||
model: Model = Field(
|
||||
model: MiniMaxModel = Field(
|
||||
...,
|
||||
description='Required. ID of model. Options: T2V-01-Director, I2V-01-Director, S2V-01, I2V-01, I2V-01-live, T2V-01',
|
||||
)
|
||||
@@ -1665,6 +1673,14 @@ class MinimaxVideoGenerationRequest(BaseModel):
|
||||
None,
|
||||
description='Only available when model is S2V-01. The model will generate a video based on the subject uploaded through this parameter.',
|
||||
)
|
||||
duration: Optional[int] = Field(
|
||||
None,
|
||||
description="The length of the output video in seconds."
|
||||
)
|
||||
resolution: Optional[str] = Field(
|
||||
None,
|
||||
description="The dimensions of the video display. 1080p corresponds to 1920 x 1080 pixels, 768p corresponds to 1366 x 768 pixels."
|
||||
)
|
||||
|
||||
|
||||
class MinimaxVideoGenerationResponse(BaseModel):
|
||||
@@ -2664,7 +2680,7 @@ class ReleaseNote(BaseModel):
|
||||
|
||||
|
||||
class RenderingSpeed(str, Enum):
|
||||
BALANCED = 'BALANCED'
|
||||
DEFAULT = 'DEFAULT'
|
||||
TURBO = 'TURBO'
|
||||
QUALITY = 'QUALITY'
|
||||
|
||||
@@ -4906,6 +4922,14 @@ class IdeogramV3EditRequest(BaseModel):
|
||||
None,
|
||||
description='A set of images to use as style references (maximum total size 10MB across all style references). The images should be in JPEG, PNG or WebP format.',
|
||||
)
|
||||
character_reference_images: Optional[List[str]] = Field(
|
||||
None,
|
||||
description='Generations with character reference are subject to the character reference pricing. A set of images to use as character references (maximum total size 10MB across all character references), currently only supports 1 character reference image. The images should be in JPEG, PNG or WebP format.'
|
||||
)
|
||||
character_reference_images_mask: Optional[List[str]] = Field(
|
||||
None,
|
||||
description='Optional masks for character reference images. When provided, must match the number of character_reference_images. Each mask should be a grayscale image of the same dimensions as the corresponding character reference image. The images should be in JPEG, PNG or WebP format.'
|
||||
)
|
||||
|
||||
|
||||
class IdeogramV3Request(BaseModel):
|
||||
@@ -4939,6 +4963,14 @@ class IdeogramV3Request(BaseModel):
|
||||
style_type: Optional[StyleType1] = Field(
|
||||
None, description='The type of style to apply'
|
||||
)
|
||||
character_reference_images: Optional[List[str]] = Field(
|
||||
None,
|
||||
description='Generations with character reference are subject to the character reference pricing. A set of images to use as character references (maximum total size 10MB across all character references), currently only supports 1 character reference image. The images should be in JPEG, PNG or WebP format.'
|
||||
)
|
||||
character_reference_images_mask: Optional[List[str]] = Field(
|
||||
None,
|
||||
description='Optional masks for character reference images. When provided, must match the number of character_reference_images. Each mask should be a grayscale image of the same dimensions as the corresponding character reference image. The images should be in JPEG, PNG or WebP format.'
|
||||
)
|
||||
|
||||
|
||||
class ImagenGenerateImageResponse(BaseModel):
|
||||
|
||||
+366
-535
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from comfy_api_nodes.apis import GeminiGenerationConfig, GeminiContent, GeminiSafetySetting, GeminiSystemInstructionContent, GeminiTool, GeminiVideoMetadata
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class GeminiImageGenerationConfig(GeminiGenerationConfig):
|
||||
responseModalities: Optional[List[str]] = None
|
||||
|
||||
|
||||
class GeminiImageGenerateContentRequest(BaseModel):
|
||||
contents: List[GeminiContent]
|
||||
generationConfig: Optional[GeminiImageGenerationConfig] = None
|
||||
safetySettings: Optional[List[GeminiSafetySetting]] = None
|
||||
systemInstruction: Optional[GeminiSystemInstructionContent] = None
|
||||
tools: Optional[List[GeminiTool]] = None
|
||||
videoMetadata: Optional[GeminiVideoMetadata] = None
|
||||
@@ -125,3 +125,25 @@ class StabilityResultsGetResponse(BaseModel):
|
||||
|
||||
class StabilityAsyncResponse(BaseModel):
|
||||
id: Optional[str] = Field(None)
|
||||
|
||||
|
||||
class StabilityTextToAudioRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
prompt: str = Field(...)
|
||||
duration: int = Field(190, ge=1, le=190)
|
||||
seed: int = Field(0, ge=0, le=4294967294)
|
||||
steps: int = Field(8, ge=4, le=8)
|
||||
output_format: str = Field("wav")
|
||||
|
||||
|
||||
class StabilityAudioToAudioRequest(StabilityTextToAudioRequest):
|
||||
strength: float = Field(0.01, ge=0.01, le=1.0)
|
||||
|
||||
|
||||
class StabilityAudioInpaintRequest(StabilityTextToAudioRequest):
|
||||
mask_start: int = Field(30, ge=0, le=190)
|
||||
mask_end: int = Field(190, ge=0, le=190)
|
||||
|
||||
|
||||
class StabilityAudioResponse(BaseModel):
|
||||
audio: Optional[str] = Field(None)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import io
|
||||
from inspect import cleandoc
|
||||
from typing import Union, Optional
|
||||
@@ -28,7 +29,7 @@ from comfy_api_nodes.apinode_utils import (
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import requests
|
||||
import aiohttp
|
||||
import torch
|
||||
import base64
|
||||
import time
|
||||
@@ -44,18 +45,18 @@ def convert_mask_to_image(mask: torch.Tensor):
|
||||
return mask
|
||||
|
||||
|
||||
def handle_bfl_synchronous_operation(
|
||||
async def handle_bfl_synchronous_operation(
|
||||
operation: SynchronousOperation,
|
||||
timeout_bfl_calls=360,
|
||||
node_id: Union[str, None] = None,
|
||||
):
|
||||
response_api: BFLFluxProGenerateResponse = operation.execute()
|
||||
return _poll_until_generated(
|
||||
response_api: BFLFluxProGenerateResponse = await operation.execute()
|
||||
return await _poll_until_generated(
|
||||
response_api.polling_url, timeout=timeout_bfl_calls, node_id=node_id
|
||||
)
|
||||
|
||||
|
||||
def _poll_until_generated(
|
||||
async def _poll_until_generated(
|
||||
polling_url: str, timeout=360, node_id: Union[str, None] = None
|
||||
):
|
||||
# used bfl-comfy-nodes to verify code implementation:
|
||||
@@ -66,55 +67,56 @@ def _poll_until_generated(
|
||||
retry_404_seconds = 2
|
||||
retry_202_seconds = 2
|
||||
retry_pending_seconds = 1
|
||||
request = requests.Request(method=HttpMethod.GET, url=polling_url)
|
||||
# NOTE: should True loop be replaced with checking if workflow has been interrupted?
|
||||
while True:
|
||||
if node_id:
|
||||
time_elapsed = time.time() - start_time
|
||||
PromptServer.instance.send_progress_text(
|
||||
f"Generating ({time_elapsed:.0f}s)", node_id
|
||||
)
|
||||
|
||||
response = requests.Session().send(request.prepare())
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
if result["status"] == BFLStatus.ready:
|
||||
img_url = result["result"]["sample"]
|
||||
if node_id:
|
||||
PromptServer.instance.send_progress_text(
|
||||
f"Result URL: {img_url}", node_id
|
||||
)
|
||||
img_response = requests.get(img_url)
|
||||
return process_image_response(img_response)
|
||||
elif result["status"] in [
|
||||
BFLStatus.request_moderated,
|
||||
BFLStatus.content_moderated,
|
||||
]:
|
||||
status = result["status"]
|
||||
raise Exception(
|
||||
f"BFL API did not return an image due to: {status}."
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# NOTE: should True loop be replaced with checking if workflow has been interrupted?
|
||||
while True:
|
||||
if node_id:
|
||||
time_elapsed = time.time() - start_time
|
||||
PromptServer.instance.send_progress_text(
|
||||
f"Generating ({time_elapsed:.0f}s)", node_id
|
||||
)
|
||||
elif result["status"] == BFLStatus.error:
|
||||
raise Exception(f"BFL API encountered an error: {result}.")
|
||||
elif result["status"] == BFLStatus.pending:
|
||||
time.sleep(retry_pending_seconds)
|
||||
continue
|
||||
elif response.status_code == 404:
|
||||
if retries_404 < max_retries_404:
|
||||
retries_404 += 1
|
||||
time.sleep(retry_404_seconds)
|
||||
continue
|
||||
raise Exception(
|
||||
f"BFL API could not find task after {max_retries_404} tries."
|
||||
)
|
||||
elif response.status_code == 202:
|
||||
time.sleep(retry_202_seconds)
|
||||
elif time.time() - start_time > timeout:
|
||||
raise Exception(
|
||||
f"BFL API experienced a timeout; could not return request under {timeout} seconds."
|
||||
)
|
||||
else:
|
||||
raise Exception(f"BFL API encountered an error: {response.json()}")
|
||||
|
||||
async with session.get(polling_url) as response:
|
||||
if response.status == 200:
|
||||
result = await response.json()
|
||||
if result["status"] == BFLStatus.ready:
|
||||
img_url = result["result"]["sample"]
|
||||
if node_id:
|
||||
PromptServer.instance.send_progress_text(
|
||||
f"Result URL: {img_url}", node_id
|
||||
)
|
||||
async with session.get(img_url) as img_resp:
|
||||
return process_image_response(await img_resp.content.read())
|
||||
elif result["status"] in [
|
||||
BFLStatus.request_moderated,
|
||||
BFLStatus.content_moderated,
|
||||
]:
|
||||
status = result["status"]
|
||||
raise Exception(
|
||||
f"BFL API did not return an image due to: {status}."
|
||||
)
|
||||
elif result["status"] == BFLStatus.error:
|
||||
raise Exception(f"BFL API encountered an error: {result}.")
|
||||
elif result["status"] == BFLStatus.pending:
|
||||
await asyncio.sleep(retry_pending_seconds)
|
||||
continue
|
||||
elif response.status == 404:
|
||||
if retries_404 < max_retries_404:
|
||||
retries_404 += 1
|
||||
await asyncio.sleep(retry_404_seconds)
|
||||
continue
|
||||
raise Exception(
|
||||
f"BFL API could not find task after {max_retries_404} tries."
|
||||
)
|
||||
elif response.status == 202:
|
||||
await asyncio.sleep(retry_202_seconds)
|
||||
elif time.time() - start_time > timeout:
|
||||
raise Exception(
|
||||
f"BFL API experienced a timeout; could not return request under {timeout} seconds."
|
||||
)
|
||||
else:
|
||||
raise Exception(f"BFL API encountered an error: {response.json()}")
|
||||
|
||||
def convert_image_to_base64(image: torch.Tensor):
|
||||
scaled_image = downscale_image_tensor(image, total_pixels=2048 * 2048)
|
||||
@@ -222,7 +224,7 @@ class FluxProUltraImageNode(ComfyNodeABC):
|
||||
API_NODE = True
|
||||
CATEGORY = "api node/image/BFL"
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
aspect_ratio: str,
|
||||
@@ -266,7 +268,7 @@ class FluxProUltraImageNode(ComfyNodeABC):
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
||||
output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
||||
return (output_image,)
|
||||
|
||||
|
||||
@@ -354,7 +356,7 @@ class FluxKontextProImageNode(ComfyNodeABC):
|
||||
|
||||
BFL_PATH = "/proxy/bfl/flux-kontext-pro/generate"
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
aspect_ratio: str,
|
||||
@@ -397,7 +399,7 @@ class FluxKontextProImageNode(ComfyNodeABC):
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
||||
output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
||||
return (output_image,)
|
||||
|
||||
|
||||
@@ -489,7 +491,7 @@ class FluxProImageNode(ComfyNodeABC):
|
||||
API_NODE = True
|
||||
CATEGORY = "api node/image/BFL"
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
prompt_upsampling,
|
||||
@@ -524,7 +526,7 @@ class FluxProImageNode(ComfyNodeABC):
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
||||
output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
||||
return (output_image,)
|
||||
|
||||
|
||||
@@ -632,7 +634,7 @@ class FluxProExpandNode(ComfyNodeABC):
|
||||
API_NODE = True
|
||||
CATEGORY = "api node/image/BFL"
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
prompt: str,
|
||||
@@ -670,7 +672,7 @@ class FluxProExpandNode(ComfyNodeABC):
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
||||
output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
||||
return (output_image,)
|
||||
|
||||
|
||||
@@ -744,7 +746,7 @@ class FluxProFillNode(ComfyNodeABC):
|
||||
API_NODE = True
|
||||
CATEGORY = "api node/image/BFL"
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
mask: torch.Tensor,
|
||||
@@ -780,7 +782,7 @@ class FluxProFillNode(ComfyNodeABC):
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
||||
output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
||||
return (output_image,)
|
||||
|
||||
|
||||
@@ -879,7 +881,7 @@ class FluxProCannyNode(ComfyNodeABC):
|
||||
API_NODE = True
|
||||
CATEGORY = "api node/image/BFL"
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
control_image: torch.Tensor,
|
||||
prompt: str,
|
||||
@@ -929,7 +931,7 @@ class FluxProCannyNode(ComfyNodeABC):
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
||||
output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
||||
return (output_image,)
|
||||
|
||||
|
||||
@@ -1008,7 +1010,7 @@ class FluxProDepthNode(ComfyNodeABC):
|
||||
API_NODE = True
|
||||
CATEGORY = "api node/image/BFL"
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
control_image: torch.Tensor,
|
||||
prompt: str,
|
||||
@@ -1045,7 +1047,7 @@ class FluxProDepthNode(ComfyNodeABC):
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
||||
output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
||||
return (output_image,)
|
||||
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
+323
-98
@@ -4,8 +4,12 @@ See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/infer
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
import json
|
||||
import time
|
||||
import os
|
||||
import uuid
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from enum import Enum
|
||||
from typing import Optional, Literal
|
||||
|
||||
@@ -22,6 +26,7 @@ from comfy_api_nodes.apis import (
|
||||
GeminiPart,
|
||||
GeminiMimeType,
|
||||
)
|
||||
from comfy_api_nodes.apis.gemini_api import GeminiImageGenerationConfig, GeminiImageGenerateContentRequest
|
||||
from comfy_api_nodes.apis.client import (
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
@@ -32,6 +37,7 @@ from comfy_api_nodes.apinode_utils import (
|
||||
audio_to_base64_string,
|
||||
video_to_base64_string,
|
||||
tensor_to_base64_string,
|
||||
bytesio_to_image_tensor,
|
||||
)
|
||||
|
||||
|
||||
@@ -46,6 +52,16 @@ class GeminiModel(str, Enum):
|
||||
|
||||
gemini_2_5_pro_preview_05_06 = "gemini-2.5-pro-preview-05-06"
|
||||
gemini_2_5_flash_preview_04_17 = "gemini-2.5-flash-preview-04-17"
|
||||
gemini_2_5_pro = "gemini-2.5-pro"
|
||||
gemini_2_5_flash = "gemini-2.5-flash"
|
||||
|
||||
|
||||
class GeminiImageModel(str, Enum):
|
||||
"""
|
||||
Gemini Image Model Names allowed by comfy-api
|
||||
"""
|
||||
|
||||
gemini_2_5_flash_image_preview = "gemini-2.5-flash-image-preview"
|
||||
|
||||
|
||||
def get_gemini_endpoint(
|
||||
@@ -70,6 +86,135 @@ def get_gemini_endpoint(
|
||||
)
|
||||
|
||||
|
||||
def get_gemini_image_endpoint(
|
||||
model: GeminiImageModel,
|
||||
) -> ApiEndpoint[GeminiGenerateContentRequest, GeminiGenerateContentResponse]:
|
||||
"""
|
||||
Get the API endpoint for a given Gemini model.
|
||||
|
||||
Args:
|
||||
model: The Gemini model to use, either as enum or string value.
|
||||
|
||||
Returns:
|
||||
ApiEndpoint configured for the specific Gemini model.
|
||||
"""
|
||||
if isinstance(model, str):
|
||||
model = GeminiImageModel(model)
|
||||
return ApiEndpoint(
|
||||
path=f"{GEMINI_BASE_ENDPOINT}/{model.value}",
|
||||
method=HttpMethod.POST,
|
||||
request_model=GeminiImageGenerateContentRequest,
|
||||
response_model=GeminiGenerateContentResponse,
|
||||
)
|
||||
|
||||
|
||||
def create_image_parts(image_input: torch.Tensor) -> list[GeminiPart]:
|
||||
"""
|
||||
Convert image tensor input to Gemini API compatible parts.
|
||||
|
||||
Args:
|
||||
image_input: Batch of image tensors from ComfyUI.
|
||||
|
||||
Returns:
|
||||
List of GeminiPart objects containing the encoded images.
|
||||
"""
|
||||
image_parts: list[GeminiPart] = []
|
||||
for image_index in range(image_input.shape[0]):
|
||||
image_as_b64 = tensor_to_base64_string(
|
||||
image_input[image_index].unsqueeze(0)
|
||||
)
|
||||
image_parts.append(
|
||||
GeminiPart(
|
||||
inlineData=GeminiInlineData(
|
||||
mimeType=GeminiMimeType.image_png,
|
||||
data=image_as_b64,
|
||||
)
|
||||
)
|
||||
)
|
||||
return image_parts
|
||||
|
||||
|
||||
def create_text_part(text: str) -> GeminiPart:
|
||||
"""
|
||||
Create a text part for the Gemini API request.
|
||||
|
||||
Args:
|
||||
text: The text content to include in the request.
|
||||
|
||||
Returns:
|
||||
A GeminiPart object with the text content.
|
||||
"""
|
||||
return GeminiPart(text=text)
|
||||
|
||||
|
||||
def get_parts_from_response(
|
||||
response: GeminiGenerateContentResponse
|
||||
) -> list[GeminiPart]:
|
||||
"""
|
||||
Extract all parts from the Gemini API response.
|
||||
|
||||
Args:
|
||||
response: The API response from Gemini.
|
||||
|
||||
Returns:
|
||||
List of response parts from the first candidate.
|
||||
"""
|
||||
return response.candidates[0].content.parts
|
||||
|
||||
|
||||
def get_parts_by_type(
|
||||
response: GeminiGenerateContentResponse, part_type: Literal["text"] | str
|
||||
) -> list[GeminiPart]:
|
||||
"""
|
||||
Filter response parts by their type.
|
||||
|
||||
Args:
|
||||
response: The API response from Gemini.
|
||||
part_type: Type of parts to extract ("text" or a MIME type).
|
||||
|
||||
Returns:
|
||||
List of response parts matching the requested type.
|
||||
"""
|
||||
parts = []
|
||||
for part in get_parts_from_response(response):
|
||||
if part_type == "text" and hasattr(part, "text") and part.text:
|
||||
parts.append(part)
|
||||
elif (
|
||||
hasattr(part, "inlineData")
|
||||
and part.inlineData
|
||||
and part.inlineData.mimeType == part_type
|
||||
):
|
||||
parts.append(part)
|
||||
# Skip parts that don't match the requested type
|
||||
return parts
|
||||
|
||||
|
||||
def get_text_from_response(response: GeminiGenerateContentResponse) -> str:
|
||||
"""
|
||||
Extract and concatenate all text parts from the response.
|
||||
|
||||
Args:
|
||||
response: The API response from Gemini.
|
||||
|
||||
Returns:
|
||||
Combined text from all text parts in the response.
|
||||
"""
|
||||
parts = get_parts_by_type(response, "text")
|
||||
return "\n".join([part.text for part in parts])
|
||||
|
||||
|
||||
def get_image_from_response(response: GeminiGenerateContentResponse) -> torch.Tensor:
|
||||
image_tensors: list[torch.Tensor] = []
|
||||
parts = get_parts_by_type(response, "image/png")
|
||||
for part in parts:
|
||||
image_data = base64.b64decode(part.inlineData.data)
|
||||
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
||||
image_tensors.append(returned_image)
|
||||
if len(image_tensors) == 0:
|
||||
return torch.zeros((1,1024,1024,4))
|
||||
return torch.cat(image_tensors, dim=0)
|
||||
|
||||
|
||||
class GeminiNode(ComfyNodeABC):
|
||||
"""
|
||||
Node to generate text responses from a Gemini model.
|
||||
@@ -97,7 +242,7 @@ class GeminiNode(ComfyNodeABC):
|
||||
{
|
||||
"tooltip": "The Gemini model to use for generating responses.",
|
||||
"options": [model.value for model in GeminiModel],
|
||||
"default": GeminiModel.gemini_2_5_pro_preview_05_06.value,
|
||||
"default": GeminiModel.gemini_2_5_pro.value,
|
||||
},
|
||||
),
|
||||
"seed": (
|
||||
@@ -154,59 +299,6 @@ class GeminiNode(ComfyNodeABC):
|
||||
CATEGORY = "api node/text/Gemini"
|
||||
API_NODE = True
|
||||
|
||||
def get_parts_from_response(
|
||||
self, response: GeminiGenerateContentResponse
|
||||
) -> list[GeminiPart]:
|
||||
"""
|
||||
Extract all parts from the Gemini API response.
|
||||
|
||||
Args:
|
||||
response: The API response from Gemini.
|
||||
|
||||
Returns:
|
||||
List of response parts from the first candidate.
|
||||
"""
|
||||
return response.candidates[0].content.parts
|
||||
|
||||
def get_parts_by_type(
|
||||
self, response: GeminiGenerateContentResponse, part_type: Literal["text"] | str
|
||||
) -> list[GeminiPart]:
|
||||
"""
|
||||
Filter response parts by their type.
|
||||
|
||||
Args:
|
||||
response: The API response from Gemini.
|
||||
part_type: Type of parts to extract ("text" or a MIME type).
|
||||
|
||||
Returns:
|
||||
List of response parts matching the requested type.
|
||||
"""
|
||||
parts = []
|
||||
for part in self.get_parts_from_response(response):
|
||||
if part_type == "text" and hasattr(part, "text") and part.text:
|
||||
parts.append(part)
|
||||
elif (
|
||||
hasattr(part, "inlineData")
|
||||
and part.inlineData
|
||||
and part.inlineData.mimeType == part_type
|
||||
):
|
||||
parts.append(part)
|
||||
# Skip parts that don't match the requested type
|
||||
return parts
|
||||
|
||||
def get_text_from_response(self, response: GeminiGenerateContentResponse) -> str:
|
||||
"""
|
||||
Extract and concatenate all text parts from the response.
|
||||
|
||||
Args:
|
||||
response: The API response from Gemini.
|
||||
|
||||
Returns:
|
||||
Combined text from all text parts in the response.
|
||||
"""
|
||||
parts = self.get_parts_by_type(response, "text")
|
||||
return "\n".join([part.text for part in parts])
|
||||
|
||||
def create_video_parts(self, video_input: IO.VIDEO, **kwargs) -> list[GeminiPart]:
|
||||
"""
|
||||
Convert video input to Gemini API compatible parts.
|
||||
@@ -266,44 +358,7 @@ class GeminiNode(ComfyNodeABC):
|
||||
)
|
||||
return audio_parts
|
||||
|
||||
def create_image_parts(self, image_input: torch.Tensor) -> list[GeminiPart]:
|
||||
"""
|
||||
Convert image tensor input to Gemini API compatible parts.
|
||||
|
||||
Args:
|
||||
image_input: Batch of image tensors from ComfyUI.
|
||||
|
||||
Returns:
|
||||
List of GeminiPart objects containing the encoded images.
|
||||
"""
|
||||
image_parts: list[GeminiPart] = []
|
||||
for image_index in range(image_input.shape[0]):
|
||||
image_as_b64 = tensor_to_base64_string(
|
||||
image_input[image_index].unsqueeze(0)
|
||||
)
|
||||
image_parts.append(
|
||||
GeminiPart(
|
||||
inlineData=GeminiInlineData(
|
||||
mimeType=GeminiMimeType.image_png,
|
||||
data=image_as_b64,
|
||||
)
|
||||
)
|
||||
)
|
||||
return image_parts
|
||||
|
||||
def create_text_part(self, text: str) -> GeminiPart:
|
||||
"""
|
||||
Create a text part for the Gemini API request.
|
||||
|
||||
Args:
|
||||
text: The text content to include in the request.
|
||||
|
||||
Returns:
|
||||
A GeminiPart object with the text content.
|
||||
"""
|
||||
return GeminiPart(text=text)
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
model: GeminiModel,
|
||||
@@ -318,11 +373,11 @@ class GeminiNode(ComfyNodeABC):
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
|
||||
# Create parts list with text prompt as the first part
|
||||
parts: list[GeminiPart] = [self.create_text_part(prompt)]
|
||||
parts: list[GeminiPart] = [create_text_part(prompt)]
|
||||
|
||||
# Add other modal parts
|
||||
if images is not None:
|
||||
image_parts = self.create_image_parts(images)
|
||||
image_parts = create_image_parts(images)
|
||||
parts.extend(image_parts)
|
||||
if audio is not None:
|
||||
parts.extend(self.create_audio_parts(audio))
|
||||
@@ -332,7 +387,7 @@ class GeminiNode(ComfyNodeABC):
|
||||
parts.extend(files)
|
||||
|
||||
# Create response
|
||||
response = SynchronousOperation(
|
||||
response = await SynchronousOperation(
|
||||
endpoint=get_gemini_endpoint(model),
|
||||
request=GeminiGenerateContentRequest(
|
||||
contents=[
|
||||
@@ -346,9 +401,29 @@ class GeminiNode(ComfyNodeABC):
|
||||
).execute()
|
||||
|
||||
# Get result output
|
||||
output_text = self.get_text_from_response(response)
|
||||
output_text = get_text_from_response(response)
|
||||
if unique_id and output_text:
|
||||
PromptServer.instance.send_progress_text(output_text, node_id=unique_id)
|
||||
# Not a true chat history like the OpenAI Chat node. It is emulated so the frontend can show a copy button.
|
||||
render_spec = {
|
||||
"node_id": unique_id,
|
||||
"component": "ChatHistoryWidget",
|
||||
"props": {
|
||||
"history": json.dumps(
|
||||
[
|
||||
{
|
||||
"prompt": prompt,
|
||||
"response": output_text,
|
||||
"response_id": str(uuid.uuid4()),
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
]
|
||||
),
|
||||
},
|
||||
}
|
||||
PromptServer.instance.send_sync(
|
||||
"display_component",
|
||||
render_spec,
|
||||
)
|
||||
|
||||
return (output_text or "Empty response from Gemini model...",)
|
||||
|
||||
@@ -437,12 +512,162 @@ class GeminiInputFiles(ComfyNodeABC):
|
||||
return (files,)
|
||||
|
||||
|
||||
class GeminiImage(ComfyNodeABC):
|
||||
"""
|
||||
Node to generate text and image responses from a Gemini model.
|
||||
|
||||
This node allows users to interact with Google's Gemini AI models, providing
|
||||
multimodal inputs (text, images, files) to generate coherent
|
||||
text and image responses. The node works with the latest Gemini models, handling the
|
||||
API communication and response parsing.
|
||||
"""
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
return {
|
||||
"required": {
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Text prompt for generation",
|
||||
},
|
||||
),
|
||||
"model": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"tooltip": "The Gemini model to use for generating responses.",
|
||||
"options": [model.value for model in GeminiImageModel],
|
||||
"default": GeminiImageModel.gemini_2_5_flash_image_preview.value,
|
||||
},
|
||||
),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 42,
|
||||
"min": 0,
|
||||
"max": 0xFFFFFFFFFFFFFFFF,
|
||||
"control_after_generate": True,
|
||||
"tooltip": "When seed is fixed to a specific value, the model makes a best effort to provide the same response for repeated requests. Deterministic output isn't guaranteed. Also, changing the model or parameter settings, such as the temperature, can cause variations in the response even when you use the same seed value. By default, a random seed value is used.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"images": (
|
||||
IO.IMAGE,
|
||||
{
|
||||
"default": None,
|
||||
"tooltip": "Optional image(s) to use as context for the model. To include multiple images, you can use the Batch Images node.",
|
||||
},
|
||||
),
|
||||
"files": (
|
||||
"GEMINI_INPUT_FILES",
|
||||
{
|
||||
"default": None,
|
||||
"tooltip": "Optional file(s) to use as context for the model. Accepts inputs from the Gemini Generate Content Input Files node.",
|
||||
},
|
||||
),
|
||||
# TODO: later we can add this parameter later
|
||||
# "n": (
|
||||
# IO.INT,
|
||||
# {
|
||||
# "default": 1,
|
||||
# "min": 1,
|
||||
# "max": 8,
|
||||
# "step": 1,
|
||||
# "display": "number",
|
||||
# "tooltip": "How many images to generate",
|
||||
# },
|
||||
# ),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE, IO.STRING)
|
||||
FUNCTION = "api_call"
|
||||
CATEGORY = "api node/image/Gemini"
|
||||
DESCRIPTION = "Edit images synchronously via Google API."
|
||||
API_NODE = True
|
||||
|
||||
async def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
model: GeminiImageModel,
|
||||
images: Optional[IO.IMAGE] = None,
|
||||
files: Optional[list[GeminiPart]] = None,
|
||||
n=1,
|
||||
unique_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
# Validate inputs
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
# Create parts list with text prompt as the first part
|
||||
parts: list[GeminiPart] = [create_text_part(prompt)]
|
||||
|
||||
# Add other modal parts
|
||||
if images is not None:
|
||||
image_parts = create_image_parts(images)
|
||||
parts.extend(image_parts)
|
||||
if files is not None:
|
||||
parts.extend(files)
|
||||
|
||||
response = await SynchronousOperation(
|
||||
endpoint=get_gemini_image_endpoint(model),
|
||||
request=GeminiImageGenerateContentRequest(
|
||||
contents=[
|
||||
GeminiContent(
|
||||
role="user",
|
||||
parts=parts,
|
||||
),
|
||||
],
|
||||
generationConfig=GeminiImageGenerationConfig(
|
||||
responseModalities=["TEXT","IMAGE"]
|
||||
)
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
).execute()
|
||||
|
||||
output_image = get_image_from_response(response)
|
||||
output_text = get_text_from_response(response)
|
||||
if unique_id and output_text:
|
||||
# Not a true chat history like the OpenAI Chat node. It is emulated so the frontend can show a copy button.
|
||||
render_spec = {
|
||||
"node_id": unique_id,
|
||||
"component": "ChatHistoryWidget",
|
||||
"props": {
|
||||
"history": json.dumps(
|
||||
[
|
||||
{
|
||||
"prompt": prompt,
|
||||
"response": output_text,
|
||||
"response_id": str(uuid.uuid4()),
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
]
|
||||
),
|
||||
},
|
||||
}
|
||||
PromptServer.instance.send_sync(
|
||||
"display_component",
|
||||
render_spec,
|
||||
)
|
||||
|
||||
output_text = output_text or "Empty response from Gemini model..."
|
||||
return (output_image, output_text,)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"GeminiNode": GeminiNode,
|
||||
"GeminiImageNode": GeminiImage,
|
||||
"GeminiInputFiles": GeminiInputFiles,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"GeminiNode": "Google Gemini",
|
||||
"GeminiImageNode": "Google Gemini Image",
|
||||
"GeminiInputFiles": "Gemini Input Files",
|
||||
}
|
||||
|
||||
+330
-289
@@ -1,8 +1,8 @@
|
||||
from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict
|
||||
from inspect import cleandoc
|
||||
from io import BytesIO
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io as comfy_io
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import io
|
||||
import torch
|
||||
from comfy_api_nodes.apis import (
|
||||
IdeogramGenerateRequest,
|
||||
@@ -212,7 +212,7 @@ V3_RESOLUTIONS= [
|
||||
"1536x640"
|
||||
]
|
||||
|
||||
def download_and_process_images(image_urls):
|
||||
async def download_and_process_images(image_urls):
|
||||
"""Helper function to download and process multiple images from URLs"""
|
||||
|
||||
# Initialize list to store image tensors
|
||||
@@ -220,7 +220,7 @@ def download_and_process_images(image_urls):
|
||||
|
||||
for image_url in image_urls:
|
||||
# Using functions from apinode_utils.py to handle downloading and processing
|
||||
image_bytesio = download_url_to_bytesio(image_url) # Download image content to BytesIO
|
||||
image_bytesio = await download_url_to_bytesio(image_url) # Download image content to BytesIO
|
||||
img_tensor = bytesio_to_image_tensor(image_bytesio, mode="RGB") # Convert to torch.Tensor with RGB mode
|
||||
image_tensors.append(img_tensor)
|
||||
|
||||
@@ -246,90 +246,82 @@ def display_image_urls_on_node(image_urls, node_id):
|
||||
PromptServer.instance.send_progress_text(urls_text, node_id)
|
||||
|
||||
|
||||
class IdeogramV1(ComfyNodeABC):
|
||||
"""
|
||||
Generates images using the Ideogram V1 model.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
class IdeogramV1(comfy_io.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
return {
|
||||
"required": {
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Prompt for the image generation",
|
||||
},
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
node_id="IdeogramV1",
|
||||
display_name="Ideogram V1",
|
||||
category="api node/image/Ideogram",
|
||||
description="Generates images using the Ideogram V1 model.",
|
||||
is_api_node=True,
|
||||
inputs=[
|
||||
comfy_io.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Prompt for the image generation",
|
||||
),
|
||||
"turbo": (
|
||||
IO.BOOLEAN,
|
||||
{
|
||||
"default": False,
|
||||
"tooltip": "Whether to use turbo mode (faster generation, potentially lower quality)",
|
||||
}
|
||||
comfy_io.Boolean.Input(
|
||||
"turbo",
|
||||
default=False,
|
||||
tooltip="Whether to use turbo mode (faster generation, potentially lower quality)",
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"aspect_ratio": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": list(V1_V2_RATIO_MAP.keys()),
|
||||
"default": "1:1",
|
||||
"tooltip": "The aspect ratio for image generation.",
|
||||
},
|
||||
comfy_io.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=list(V1_V2_RATIO_MAP.keys()),
|
||||
default="1:1",
|
||||
tooltip="The aspect ratio for image generation.",
|
||||
optional=True,
|
||||
),
|
||||
"magic_prompt_option": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": ["AUTO", "ON", "OFF"],
|
||||
"default": "AUTO",
|
||||
"tooltip": "Determine if MagicPrompt should be used in generation",
|
||||
},
|
||||
comfy_io.Combo.Input(
|
||||
"magic_prompt_option",
|
||||
options=["AUTO", "ON", "OFF"],
|
||||
default="AUTO",
|
||||
tooltip="Determine if MagicPrompt should be used in generation",
|
||||
optional=True,
|
||||
),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 2147483647,
|
||||
"step": 1,
|
||||
"control_after_generate": True,
|
||||
"display": "number",
|
||||
},
|
||||
comfy_io.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
control_after_generate=True,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
optional=True,
|
||||
),
|
||||
"negative_prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Description of what to exclude from the image",
|
||||
},
|
||||
comfy_io.String.Input(
|
||||
"negative_prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Description of what to exclude from the image",
|
||||
optional=True,
|
||||
),
|
||||
"num_images": (
|
||||
IO.INT,
|
||||
{"default": 1, "min": 1, "max": 8, "step": 1, "display": "number"},
|
||||
comfy_io.Int.Input(
|
||||
"num_images",
|
||||
default=1,
|
||||
min=1,
|
||||
max=8,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
optional=True,
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
],
|
||||
)
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
FUNCTION = "api_call"
|
||||
CATEGORY = "api node/image/Ideogram"
|
||||
DESCRIPTION = cleandoc(__doc__ or "")
|
||||
API_NODE = True
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt,
|
||||
turbo=False,
|
||||
aspect_ratio="1:1",
|
||||
@@ -337,13 +329,15 @@ class IdeogramV1(ComfyNodeABC):
|
||||
seed=0,
|
||||
negative_prompt="",
|
||||
num_images=1,
|
||||
unique_id=None,
|
||||
**kwargs,
|
||||
):
|
||||
# Determine the model based on turbo setting
|
||||
aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None)
|
||||
model = "V_1_TURBO" if turbo else "V_1"
|
||||
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/ideogram/generate",
|
||||
@@ -364,10 +358,10 @@ class IdeogramV1(ComfyNodeABC):
|
||||
negative_prompt=negative_prompt if negative_prompt else None,
|
||||
)
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
|
||||
response = operation.execute()
|
||||
response = await operation.execute()
|
||||
|
||||
if not response.data or len(response.data) == 0:
|
||||
raise Exception("No images were generated in the response")
|
||||
@@ -377,93 +371,86 @@ class IdeogramV1(ComfyNodeABC):
|
||||
if not image_urls:
|
||||
raise Exception("No image URLs were generated in the response")
|
||||
|
||||
display_image_urls_on_node(image_urls, unique_id)
|
||||
return (download_and_process_images(image_urls),)
|
||||
display_image_urls_on_node(image_urls, cls.hidden.unique_id)
|
||||
return comfy_io.NodeOutput(await download_and_process_images(image_urls))
|
||||
|
||||
|
||||
class IdeogramV2(ComfyNodeABC):
|
||||
"""
|
||||
Generates images using the Ideogram V2 model.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
class IdeogramV2(comfy_io.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
return {
|
||||
"required": {
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Prompt for the image generation",
|
||||
},
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
node_id="IdeogramV2",
|
||||
display_name="Ideogram V2",
|
||||
category="api node/image/Ideogram",
|
||||
description="Generates images using the Ideogram V2 model.",
|
||||
is_api_node=True,
|
||||
inputs=[
|
||||
comfy_io.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Prompt for the image generation",
|
||||
),
|
||||
"turbo": (
|
||||
IO.BOOLEAN,
|
||||
{
|
||||
"default": False,
|
||||
"tooltip": "Whether to use turbo mode (faster generation, potentially lower quality)",
|
||||
}
|
||||
comfy_io.Boolean.Input(
|
||||
"turbo",
|
||||
default=False,
|
||||
tooltip="Whether to use turbo mode (faster generation, potentially lower quality)",
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"aspect_ratio": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": list(V1_V2_RATIO_MAP.keys()),
|
||||
"default": "1:1",
|
||||
"tooltip": "The aspect ratio for image generation. Ignored if resolution is not set to AUTO.",
|
||||
},
|
||||
comfy_io.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=list(V1_V2_RATIO_MAP.keys()),
|
||||
default="1:1",
|
||||
tooltip="The aspect ratio for image generation. Ignored if resolution is not set to AUTO.",
|
||||
optional=True,
|
||||
),
|
||||
"resolution": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": list(V1_V1_RES_MAP.keys()),
|
||||
"default": "Auto",
|
||||
"tooltip": "The resolution for image generation. If not set to AUTO, this overrides the aspect_ratio setting.",
|
||||
},
|
||||
comfy_io.Combo.Input(
|
||||
"resolution",
|
||||
options=list(V1_V1_RES_MAP.keys()),
|
||||
default="Auto",
|
||||
tooltip="The resolution for image generation. "
|
||||
"If not set to AUTO, this overrides the aspect_ratio setting.",
|
||||
optional=True,
|
||||
),
|
||||
"magic_prompt_option": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": ["AUTO", "ON", "OFF"],
|
||||
"default": "AUTO",
|
||||
"tooltip": "Determine if MagicPrompt should be used in generation",
|
||||
},
|
||||
comfy_io.Combo.Input(
|
||||
"magic_prompt_option",
|
||||
options=["AUTO", "ON", "OFF"],
|
||||
default="AUTO",
|
||||
tooltip="Determine if MagicPrompt should be used in generation",
|
||||
optional=True,
|
||||
),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 2147483647,
|
||||
"step": 1,
|
||||
"control_after_generate": True,
|
||||
"display": "number",
|
||||
},
|
||||
comfy_io.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
control_after_generate=True,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
optional=True,
|
||||
),
|
||||
"style_type": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": ["AUTO", "GENERAL", "REALISTIC", "DESIGN", "RENDER_3D", "ANIME"],
|
||||
"default": "NONE",
|
||||
"tooltip": "Style type for generation (V2 only)",
|
||||
},
|
||||
comfy_io.Combo.Input(
|
||||
"style_type",
|
||||
options=["AUTO", "GENERAL", "REALISTIC", "DESIGN", "RENDER_3D", "ANIME"],
|
||||
default="NONE",
|
||||
tooltip="Style type for generation (V2 only)",
|
||||
optional=True,
|
||||
),
|
||||
"negative_prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Description of what to exclude from the image",
|
||||
},
|
||||
comfy_io.String.Input(
|
||||
"negative_prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Description of what to exclude from the image",
|
||||
optional=True,
|
||||
),
|
||||
"num_images": (
|
||||
IO.INT,
|
||||
{"default": 1, "min": 1, "max": 8, "step": 1, "display": "number"},
|
||||
comfy_io.Int.Input(
|
||||
"num_images",
|
||||
default=1,
|
||||
min=1,
|
||||
max=8,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
optional=True,
|
||||
),
|
||||
#"color_palette": (
|
||||
# IO.STRING,
|
||||
@@ -473,22 +460,20 @@ class IdeogramV2(ComfyNodeABC):
|
||||
# "tooltip": "Color palette preset name or hex colors with weights",
|
||||
# },
|
||||
#),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
],
|
||||
)
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
FUNCTION = "api_call"
|
||||
CATEGORY = "api node/image/Ideogram"
|
||||
DESCRIPTION = cleandoc(__doc__ or "")
|
||||
API_NODE = True
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt,
|
||||
turbo=False,
|
||||
aspect_ratio="1:1",
|
||||
@@ -499,8 +484,6 @@ class IdeogramV2(ComfyNodeABC):
|
||||
negative_prompt="",
|
||||
num_images=1,
|
||||
color_palette="",
|
||||
unique_id=None,
|
||||
**kwargs,
|
||||
):
|
||||
aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None)
|
||||
resolution = V1_V1_RES_MAP.get(resolution, None)
|
||||
@@ -517,6 +500,10 @@ class IdeogramV2(ComfyNodeABC):
|
||||
else:
|
||||
final_aspect_ratio = aspect_ratio if aspect_ratio != "ASPECT_1_1" else None
|
||||
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/ideogram/generate",
|
||||
@@ -540,10 +527,10 @@ class IdeogramV2(ComfyNodeABC):
|
||||
color_palette=color_palette if color_palette else None,
|
||||
)
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
|
||||
response = operation.execute()
|
||||
response = await operation.execute()
|
||||
|
||||
if not response.data or len(response.data) == 0:
|
||||
raise Exception("No images were generated in the response")
|
||||
@@ -553,108 +540,110 @@ class IdeogramV2(ComfyNodeABC):
|
||||
if not image_urls:
|
||||
raise Exception("No image URLs were generated in the response")
|
||||
|
||||
display_image_urls_on_node(image_urls, unique_id)
|
||||
return (download_and_process_images(image_urls),)
|
||||
display_image_urls_on_node(image_urls, cls.hidden.unique_id)
|
||||
return comfy_io.NodeOutput(await download_and_process_images(image_urls))
|
||||
|
||||
class IdeogramV3(ComfyNodeABC):
|
||||
"""
|
||||
Generates images using the Ideogram V3 model. Supports both regular image generation from text prompts and image editing with mask.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
class IdeogramV3(comfy_io.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
return {
|
||||
"required": {
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Prompt for the image generation or editing",
|
||||
},
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
node_id="IdeogramV3",
|
||||
display_name="Ideogram V3",
|
||||
category="api node/image/Ideogram",
|
||||
description="Generates images using the Ideogram V3 model. "
|
||||
"Supports both regular image generation from text prompts and image editing with mask.",
|
||||
is_api_node=True,
|
||||
inputs=[
|
||||
comfy_io.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Prompt for the image generation or editing",
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"image": (
|
||||
IO.IMAGE,
|
||||
{
|
||||
"default": None,
|
||||
"tooltip": "Optional reference image for image editing.",
|
||||
},
|
||||
comfy_io.Image.Input(
|
||||
"image",
|
||||
tooltip="Optional reference image for image editing.",
|
||||
optional=True,
|
||||
),
|
||||
"mask": (
|
||||
IO.MASK,
|
||||
{
|
||||
"default": None,
|
||||
"tooltip": "Optional mask for inpainting (white areas will be replaced)",
|
||||
},
|
||||
comfy_io.Mask.Input(
|
||||
"mask",
|
||||
tooltip="Optional mask for inpainting (white areas will be replaced)",
|
||||
optional=True,
|
||||
),
|
||||
"aspect_ratio": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": list(V3_RATIO_MAP.keys()),
|
||||
"default": "1:1",
|
||||
"tooltip": "The aspect ratio for image generation. Ignored if resolution is not set to Auto.",
|
||||
},
|
||||
comfy_io.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=list(V3_RATIO_MAP.keys()),
|
||||
default="1:1",
|
||||
tooltip="The aspect ratio for image generation. Ignored if resolution is not set to Auto.",
|
||||
optional=True,
|
||||
),
|
||||
"resolution": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": V3_RESOLUTIONS,
|
||||
"default": "Auto",
|
||||
"tooltip": "The resolution for image generation. If not set to Auto, this overrides the aspect_ratio setting.",
|
||||
},
|
||||
comfy_io.Combo.Input(
|
||||
"resolution",
|
||||
options=V3_RESOLUTIONS,
|
||||
default="Auto",
|
||||
tooltip="The resolution for image generation. "
|
||||
"If not set to Auto, this overrides the aspect_ratio setting.",
|
||||
optional=True,
|
||||
),
|
||||
"magic_prompt_option": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": ["AUTO", "ON", "OFF"],
|
||||
"default": "AUTO",
|
||||
"tooltip": "Determine if MagicPrompt should be used in generation",
|
||||
},
|
||||
comfy_io.Combo.Input(
|
||||
"magic_prompt_option",
|
||||
options=["AUTO", "ON", "OFF"],
|
||||
default="AUTO",
|
||||
tooltip="Determine if MagicPrompt should be used in generation",
|
||||
optional=True,
|
||||
),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 2147483647,
|
||||
"step": 1,
|
||||
"control_after_generate": True,
|
||||
"display": "number",
|
||||
},
|
||||
comfy_io.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
control_after_generate=True,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
optional=True,
|
||||
),
|
||||
"num_images": (
|
||||
IO.INT,
|
||||
{"default": 1, "min": 1, "max": 8, "step": 1, "display": "number"},
|
||||
comfy_io.Int.Input(
|
||||
"num_images",
|
||||
default=1,
|
||||
min=1,
|
||||
max=8,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
optional=True,
|
||||
),
|
||||
"rendering_speed": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": ["BALANCED", "TURBO", "QUALITY"],
|
||||
"default": "BALANCED",
|
||||
"tooltip": "Controls the trade-off between generation speed and quality",
|
||||
},
|
||||
comfy_io.Combo.Input(
|
||||
"rendering_speed",
|
||||
options=["DEFAULT", "TURBO", "QUALITY"],
|
||||
default="DEFAULT",
|
||||
tooltip="Controls the trade-off between generation speed and quality",
|
||||
optional=True,
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
comfy_io.Image.Input(
|
||||
"character_image",
|
||||
tooltip="Image to use as character reference.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Mask.Input(
|
||||
"character_mask",
|
||||
tooltip="Optional mask for character reference image.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
],
|
||||
)
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
FUNCTION = "api_call"
|
||||
CATEGORY = "api node/image/Ideogram"
|
||||
DESCRIPTION = cleandoc(__doc__ or "")
|
||||
API_NODE = True
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt,
|
||||
image=None,
|
||||
mask=None,
|
||||
@@ -663,10 +652,46 @@ class IdeogramV3(ComfyNodeABC):
|
||||
magic_prompt_option="AUTO",
|
||||
seed=0,
|
||||
num_images=1,
|
||||
rendering_speed="BALANCED",
|
||||
unique_id=None,
|
||||
**kwargs,
|
||||
rendering_speed="DEFAULT",
|
||||
character_image=None,
|
||||
character_mask=None,
|
||||
):
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
if rendering_speed == "BALANCED": # for backward compatibility
|
||||
rendering_speed = "DEFAULT"
|
||||
|
||||
character_img_binary = None
|
||||
character_mask_binary = None
|
||||
|
||||
if character_image is not None:
|
||||
input_tensor = character_image.squeeze().cpu()
|
||||
if character_mask is not None:
|
||||
character_mask = resize_mask_to_image(character_mask, character_image, allow_gradient=False)
|
||||
character_mask = 1.0 - character_mask
|
||||
if character_mask.shape[1:] != character_image.shape[1:-1]:
|
||||
raise Exception("Character mask and image must be the same size")
|
||||
|
||||
mask_np = (character_mask.squeeze().cpu().numpy() * 255).astype(np.uint8)
|
||||
mask_img = Image.fromarray(mask_np)
|
||||
mask_byte_arr = BytesIO()
|
||||
mask_img.save(mask_byte_arr, format="PNG")
|
||||
mask_byte_arr.seek(0)
|
||||
character_mask_binary = mask_byte_arr
|
||||
character_mask_binary.name = "mask.png"
|
||||
|
||||
img_np = (input_tensor.numpy() * 255).astype(np.uint8)
|
||||
img = Image.fromarray(img_np)
|
||||
img_byte_arr = BytesIO()
|
||||
img.save(img_byte_arr, format="PNG")
|
||||
img_byte_arr.seek(0)
|
||||
character_img_binary = img_byte_arr
|
||||
character_img_binary.name = "image.png"
|
||||
elif character_mask is not None:
|
||||
raise Exception("Character mask requires character image to be present")
|
||||
|
||||
# Check if both image and mask are provided for editing mode
|
||||
if image is not None and mask is not None:
|
||||
# Edit mode
|
||||
@@ -686,7 +711,7 @@ class IdeogramV3(ComfyNodeABC):
|
||||
# Process image
|
||||
img_np = (input_tensor.numpy() * 255).astype(np.uint8)
|
||||
img = Image.fromarray(img_np)
|
||||
img_byte_arr = io.BytesIO()
|
||||
img_byte_arr = BytesIO()
|
||||
img.save(img_byte_arr, format="PNG")
|
||||
img_byte_arr.seek(0)
|
||||
img_binary = img_byte_arr
|
||||
@@ -695,7 +720,7 @@ class IdeogramV3(ComfyNodeABC):
|
||||
# Process mask - white areas will be replaced
|
||||
mask_np = (mask.squeeze().cpu().numpy() * 255).astype(np.uint8)
|
||||
mask_img = Image.fromarray(mask_np)
|
||||
mask_byte_arr = io.BytesIO()
|
||||
mask_byte_arr = BytesIO()
|
||||
mask_img.save(mask_byte_arr, format="PNG")
|
||||
mask_byte_arr.seek(0)
|
||||
mask_binary = mask_byte_arr
|
||||
@@ -715,6 +740,15 @@ class IdeogramV3(ComfyNodeABC):
|
||||
if num_images > 1:
|
||||
edit_request.num_images = num_images
|
||||
|
||||
files = {
|
||||
"image": img_binary,
|
||||
"mask": mask_binary,
|
||||
}
|
||||
if character_img_binary:
|
||||
files["character_reference_images"] = character_img_binary
|
||||
if character_mask_binary:
|
||||
files["character_mask_binary"] = character_mask_binary
|
||||
|
||||
# Execute the operation for edit mode
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
@@ -724,12 +758,9 @@ class IdeogramV3(ComfyNodeABC):
|
||||
response_model=IdeogramGenerateResponse,
|
||||
),
|
||||
request=edit_request,
|
||||
files={
|
||||
"image": img_binary,
|
||||
"mask": mask_binary,
|
||||
},
|
||||
files=files,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=kwargs,
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
|
||||
elif image is not None or mask is not None:
|
||||
@@ -761,6 +792,14 @@ class IdeogramV3(ComfyNodeABC):
|
||||
if num_images > 1:
|
||||
gen_request.num_images = num_images
|
||||
|
||||
files = {}
|
||||
if character_img_binary:
|
||||
files["character_reference_images"] = character_img_binary
|
||||
if character_mask_binary:
|
||||
files["character_mask_binary"] = character_mask_binary
|
||||
if files:
|
||||
gen_request.style_type = "AUTO"
|
||||
|
||||
# Execute the operation for generation mode
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
@@ -770,11 +809,13 @@ class IdeogramV3(ComfyNodeABC):
|
||||
response_model=IdeogramGenerateResponse,
|
||||
),
|
||||
request=gen_request,
|
||||
auth_kwargs=kwargs,
|
||||
files=files if files else None,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
|
||||
# Execute the operation and process response
|
||||
response = operation.execute()
|
||||
response = await operation.execute()
|
||||
|
||||
if not response.data or len(response.data) == 0:
|
||||
raise Exception("No images were generated in the response")
|
||||
@@ -784,18 +825,18 @@ class IdeogramV3(ComfyNodeABC):
|
||||
if not image_urls:
|
||||
raise Exception("No image URLs were generated in the response")
|
||||
|
||||
display_image_urls_on_node(image_urls, unique_id)
|
||||
return (download_and_process_images(image_urls),)
|
||||
display_image_urls_on_node(image_urls, cls.hidden.unique_id)
|
||||
return comfy_io.NodeOutput(await download_and_process_images(image_urls))
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"IdeogramV1": IdeogramV1,
|
||||
"IdeogramV2": IdeogramV2,
|
||||
"IdeogramV3": IdeogramV3,
|
||||
}
|
||||
class IdeogramExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
|
||||
return [
|
||||
IdeogramV1,
|
||||
IdeogramV2,
|
||||
IdeogramV3,
|
||||
]
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"IdeogramV1": "Ideogram V1",
|
||||
"IdeogramV2": "Ideogram V2",
|
||||
"IdeogramV3": "Ideogram V3",
|
||||
}
|
||||
async def comfy_entrypoint() -> IdeogramExtension:
|
||||
return IdeogramExtension()
|
||||
|
||||
@@ -109,7 +109,7 @@ class KlingApiError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def poll_until_finished(
|
||||
async def poll_until_finished(
|
||||
auth_kwargs: dict[str, str],
|
||||
api_endpoint: ApiEndpoint[Any, R],
|
||||
result_url_extractor: Optional[Callable[[R], str]] = None,
|
||||
@@ -117,7 +117,7 @@ def poll_until_finished(
|
||||
node_id: Optional[str] = None,
|
||||
) -> R:
|
||||
"""Polls the Kling API endpoint until the task reaches a terminal state, then returns the response."""
|
||||
return PollingOperation(
|
||||
return await PollingOperation(
|
||||
poll_endpoint=api_endpoint,
|
||||
completed_statuses=[
|
||||
KlingTaskStatus.succeed.value,
|
||||
@@ -278,18 +278,18 @@ def get_images_urls_from_response(response) -> Optional[str]:
|
||||
return None
|
||||
|
||||
|
||||
def video_result_to_node_output(
|
||||
async def video_result_to_node_output(
|
||||
video: KlingVideoResult,
|
||||
) -> tuple[VideoFromFile, str, str]:
|
||||
"""Converts a KlingVideoResult to a tuple of (VideoFromFile, str, str) to be used as a ComfyUI node output."""
|
||||
return (
|
||||
download_url_to_video_output(video.url),
|
||||
await download_url_to_video_output(str(video.url)),
|
||||
str(video.id),
|
||||
str(video.duration),
|
||||
)
|
||||
|
||||
|
||||
def image_result_to_node_output(
|
||||
async def image_result_to_node_output(
|
||||
images: list[KlingImageResult],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
@@ -297,9 +297,9 @@ def image_result_to_node_output(
|
||||
If multiple images are returned, they will be stacked along the batch dimension.
|
||||
"""
|
||||
if len(images) == 1:
|
||||
return download_url_to_image_tensor(images[0].url)
|
||||
return await download_url_to_image_tensor(str(images[0].url))
|
||||
else:
|
||||
return torch.cat([download_url_to_image_tensor(image.url) for image in images])
|
||||
return torch.cat([await download_url_to_image_tensor(str(image.url)) for image in images])
|
||||
|
||||
|
||||
class KlingNodeBase(ComfyNodeABC):
|
||||
@@ -421,6 +421,8 @@ class KlingTextToVideoNode(KlingNodeBase):
|
||||
"pro mode / 10s duration / kling-v2-master": ("pro", "10", "kling-v2-master"),
|
||||
"standard mode / 5s duration / kling-v2-master": ("std", "5", "kling-v2-master"),
|
||||
"standard mode / 10s duration / kling-v2-master": ("std", "10", "kling-v2-master"),
|
||||
"pro mode / 5s duration / kling-v2-1-master": ("pro", "5", "kling-v2-1-master"),
|
||||
"pro mode / 10s duration / kling-v2-1-master": ("pro", "10", "kling-v2-1-master"),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@@ -467,10 +469,10 @@ class KlingTextToVideoNode(KlingNodeBase):
|
||||
RETURN_NAMES = ("VIDEO", "video_id", "duration")
|
||||
DESCRIPTION = "Kling Text to Video Node"
|
||||
|
||||
def get_response(
|
||||
async def get_response(
|
||||
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
|
||||
) -> KlingText2VideoResponse:
|
||||
return poll_until_finished(
|
||||
return await poll_until_finished(
|
||||
auth_kwargs,
|
||||
ApiEndpoint(
|
||||
path=f"{PATH_TEXT_TO_VIDEO}/{task_id}",
|
||||
@@ -483,7 +485,7 @@ class KlingTextToVideoNode(KlingNodeBase):
|
||||
node_id=node_id,
|
||||
)
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
@@ -519,17 +521,17 @@ class KlingTextToVideoNode(KlingNodeBase):
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
task_creation_response = initial_operation.execute()
|
||||
task_creation_response = await initial_operation.execute()
|
||||
validate_task_creation_response(task_creation_response)
|
||||
|
||||
task_id = task_creation_response.data.task_id
|
||||
final_response = self.get_response(
|
||||
final_response = await self.get_response(
|
||||
task_id, auth_kwargs=kwargs, node_id=unique_id
|
||||
)
|
||||
validate_video_result_response(final_response)
|
||||
|
||||
video = get_video_from_response(final_response)
|
||||
return video_result_to_node_output(video)
|
||||
return await video_result_to_node_output(video)
|
||||
|
||||
|
||||
class KlingCameraControlT2VNode(KlingTextToVideoNode):
|
||||
@@ -581,7 +583,7 @@ class KlingCameraControlT2VNode(KlingTextToVideoNode):
|
||||
|
||||
DESCRIPTION = "Transform text into cinematic videos with professional camera movements that simulate real-world cinematography. Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original text."
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
@@ -591,7 +593,7 @@ class KlingCameraControlT2VNode(KlingTextToVideoNode):
|
||||
unique_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
return super().api_call(
|
||||
return await super().api_call(
|
||||
model_name=KlingVideoGenModelName.kling_v1,
|
||||
cfg_scale=cfg_scale,
|
||||
mode=KlingVideoGenMode.std,
|
||||
@@ -670,10 +672,10 @@ class KlingImage2VideoNode(KlingNodeBase):
|
||||
RETURN_NAMES = ("VIDEO", "video_id", "duration")
|
||||
DESCRIPTION = "Kling Image to Video Node"
|
||||
|
||||
def get_response(
|
||||
async def get_response(
|
||||
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
|
||||
) -> KlingImage2VideoResponse:
|
||||
return poll_until_finished(
|
||||
return await poll_until_finished(
|
||||
auth_kwargs,
|
||||
ApiEndpoint(
|
||||
path=f"{PATH_IMAGE_TO_VIDEO}/{task_id}",
|
||||
@@ -686,7 +688,7 @@ class KlingImage2VideoNode(KlingNodeBase):
|
||||
node_id=node_id,
|
||||
)
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
start_frame: torch.Tensor,
|
||||
prompt: str,
|
||||
@@ -733,17 +735,17 @@ class KlingImage2VideoNode(KlingNodeBase):
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
task_creation_response = initial_operation.execute()
|
||||
task_creation_response = await initial_operation.execute()
|
||||
validate_task_creation_response(task_creation_response)
|
||||
task_id = task_creation_response.data.task_id
|
||||
|
||||
final_response = self.get_response(
|
||||
final_response = await self.get_response(
|
||||
task_id, auth_kwargs=kwargs, node_id=unique_id
|
||||
)
|
||||
validate_video_result_response(final_response)
|
||||
|
||||
video = get_video_from_response(final_response)
|
||||
return video_result_to_node_output(video)
|
||||
return await video_result_to_node_output(video)
|
||||
|
||||
|
||||
class KlingCameraControlI2VNode(KlingImage2VideoNode):
|
||||
@@ -798,7 +800,7 @@ class KlingCameraControlI2VNode(KlingImage2VideoNode):
|
||||
|
||||
DESCRIPTION = "Transform still images into cinematic videos with professional camera movements that simulate real-world cinematography. Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original image."
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
start_frame: torch.Tensor,
|
||||
prompt: str,
|
||||
@@ -809,7 +811,7 @@ class KlingCameraControlI2VNode(KlingImage2VideoNode):
|
||||
unique_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
return super().api_call(
|
||||
return await super().api_call(
|
||||
model_name=KlingVideoGenModelName.kling_v1_5,
|
||||
start_frame=start_frame,
|
||||
cfg_scale=cfg_scale,
|
||||
@@ -897,7 +899,7 @@ class KlingStartEndFrameNode(KlingImage2VideoNode):
|
||||
|
||||
DESCRIPTION = "Generate a video sequence that transitions between your provided start and end images. The node creates all frames in between, producing a smooth transformation from the first frame to the last."
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
start_frame: torch.Tensor,
|
||||
end_frame: torch.Tensor,
|
||||
@@ -912,7 +914,7 @@ class KlingStartEndFrameNode(KlingImage2VideoNode):
|
||||
mode, duration, model_name = KlingStartEndFrameNode.get_mode_string_mapping()[
|
||||
mode
|
||||
]
|
||||
return super().api_call(
|
||||
return await super().api_call(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
model_name=model_name,
|
||||
@@ -964,10 +966,10 @@ class KlingVideoExtendNode(KlingNodeBase):
|
||||
RETURN_NAMES = ("VIDEO", "video_id", "duration")
|
||||
DESCRIPTION = "Kling Video Extend Node. Extend videos made by other Kling nodes. The video_id is created by using other Kling Nodes."
|
||||
|
||||
def get_response(
|
||||
async def get_response(
|
||||
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
|
||||
) -> KlingVideoExtendResponse:
|
||||
return poll_until_finished(
|
||||
return await poll_until_finished(
|
||||
auth_kwargs,
|
||||
ApiEndpoint(
|
||||
path=f"{PATH_VIDEO_EXTEND}/{task_id}",
|
||||
@@ -980,7 +982,7 @@ class KlingVideoExtendNode(KlingNodeBase):
|
||||
node_id=node_id,
|
||||
)
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
@@ -1006,17 +1008,17 @@ class KlingVideoExtendNode(KlingNodeBase):
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
task_creation_response = initial_operation.execute()
|
||||
task_creation_response = await initial_operation.execute()
|
||||
validate_task_creation_response(task_creation_response)
|
||||
task_id = task_creation_response.data.task_id
|
||||
|
||||
final_response = self.get_response(
|
||||
final_response = await self.get_response(
|
||||
task_id, auth_kwargs=kwargs, node_id=unique_id
|
||||
)
|
||||
validate_video_result_response(final_response)
|
||||
|
||||
video = get_video_from_response(final_response)
|
||||
return video_result_to_node_output(video)
|
||||
return await video_result_to_node_output(video)
|
||||
|
||||
|
||||
class KlingVideoEffectsBase(KlingNodeBase):
|
||||
@@ -1025,10 +1027,10 @@ class KlingVideoEffectsBase(KlingNodeBase):
|
||||
RETURN_TYPES = ("VIDEO", "STRING", "STRING")
|
||||
RETURN_NAMES = ("VIDEO", "video_id", "duration")
|
||||
|
||||
def get_response(
|
||||
async def get_response(
|
||||
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
|
||||
) -> KlingVideoEffectsResponse:
|
||||
return poll_until_finished(
|
||||
return await poll_until_finished(
|
||||
auth_kwargs,
|
||||
ApiEndpoint(
|
||||
path=f"{PATH_VIDEO_EFFECTS}/{task_id}",
|
||||
@@ -1041,7 +1043,7 @@ class KlingVideoEffectsBase(KlingNodeBase):
|
||||
node_id=node_id,
|
||||
)
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
dual_character: bool,
|
||||
effect_scene: KlingDualCharacterEffectsScene | KlingSingleImageEffectsScene,
|
||||
@@ -1084,17 +1086,17 @@ class KlingVideoEffectsBase(KlingNodeBase):
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
task_creation_response = initial_operation.execute()
|
||||
task_creation_response = await initial_operation.execute()
|
||||
validate_task_creation_response(task_creation_response)
|
||||
task_id = task_creation_response.data.task_id
|
||||
|
||||
final_response = self.get_response(
|
||||
final_response = await self.get_response(
|
||||
task_id, auth_kwargs=kwargs, node_id=unique_id
|
||||
)
|
||||
validate_video_result_response(final_response)
|
||||
|
||||
video = get_video_from_response(final_response)
|
||||
return video_result_to_node_output(video)
|
||||
return await video_result_to_node_output(video)
|
||||
|
||||
|
||||
class KlingDualCharacterVideoEffectNode(KlingVideoEffectsBase):
|
||||
@@ -1142,7 +1144,7 @@ class KlingDualCharacterVideoEffectNode(KlingVideoEffectsBase):
|
||||
RETURN_TYPES = ("VIDEO", "STRING")
|
||||
RETURN_NAMES = ("VIDEO", "duration")
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
image_left: torch.Tensor,
|
||||
image_right: torch.Tensor,
|
||||
@@ -1153,7 +1155,7 @@ class KlingDualCharacterVideoEffectNode(KlingVideoEffectsBase):
|
||||
unique_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
video, _, duration = super().api_call(
|
||||
video, _, duration = await super().api_call(
|
||||
dual_character=True,
|
||||
effect_scene=effect_scene,
|
||||
model_name=model_name,
|
||||
@@ -1208,7 +1210,7 @@ class KlingSingleImageVideoEffectNode(KlingVideoEffectsBase):
|
||||
|
||||
DESCRIPTION = "Achieve different special effects when generating a video based on the effect_scene."
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
effect_scene: KlingSingleImageEffectsScene,
|
||||
@@ -1217,7 +1219,7 @@ class KlingSingleImageVideoEffectNode(KlingVideoEffectsBase):
|
||||
unique_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
return super().api_call(
|
||||
return await super().api_call(
|
||||
dual_character=False,
|
||||
effect_scene=effect_scene,
|
||||
model_name=model_name,
|
||||
@@ -1253,11 +1255,11 @@ class KlingLipSyncBase(KlingNodeBase):
|
||||
f"Text is too long. Maximum length is {MAX_PROMPT_LENGTH_LIP_SYNC} characters."
|
||||
)
|
||||
|
||||
def get_response(
|
||||
async def get_response(
|
||||
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
|
||||
) -> KlingLipSyncResponse:
|
||||
"""Polls the Kling API endpoint until the task reaches a terminal state."""
|
||||
return poll_until_finished(
|
||||
return await poll_until_finished(
|
||||
auth_kwargs,
|
||||
ApiEndpoint(
|
||||
path=f"{PATH_LIP_SYNC}/{task_id}",
|
||||
@@ -1270,7 +1272,7 @@ class KlingLipSyncBase(KlingNodeBase):
|
||||
node_id=node_id,
|
||||
)
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
video: VideoInput,
|
||||
audio: Optional[AudioInput] = None,
|
||||
@@ -1287,12 +1289,12 @@ class KlingLipSyncBase(KlingNodeBase):
|
||||
self.validate_lip_sync_video(video)
|
||||
|
||||
# Upload video to Comfy API and get download URL
|
||||
video_url = upload_video_to_comfyapi(video, auth_kwargs=kwargs)
|
||||
video_url = await upload_video_to_comfyapi(video, auth_kwargs=kwargs)
|
||||
logging.info("Uploaded video to Comfy API. URL: %s", video_url)
|
||||
|
||||
# Upload the audio file to Comfy API and get download URL
|
||||
if audio:
|
||||
audio_url = upload_audio_to_comfyapi(audio, auth_kwargs=kwargs)
|
||||
audio_url = await upload_audio_to_comfyapi(audio, auth_kwargs=kwargs)
|
||||
logging.info("Uploaded audio to Comfy API. URL: %s", audio_url)
|
||||
else:
|
||||
audio_url = None
|
||||
@@ -1319,17 +1321,17 @@ class KlingLipSyncBase(KlingNodeBase):
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
task_creation_response = initial_operation.execute()
|
||||
task_creation_response = await initial_operation.execute()
|
||||
validate_task_creation_response(task_creation_response)
|
||||
task_id = task_creation_response.data.task_id
|
||||
|
||||
final_response = self.get_response(
|
||||
final_response = await self.get_response(
|
||||
task_id, auth_kwargs=kwargs, node_id=unique_id
|
||||
)
|
||||
validate_video_result_response(final_response)
|
||||
|
||||
video = get_video_from_response(final_response)
|
||||
return video_result_to_node_output(video)
|
||||
return await video_result_to_node_output(video)
|
||||
|
||||
|
||||
class KlingLipSyncAudioToVideoNode(KlingLipSyncBase):
|
||||
@@ -1357,7 +1359,7 @@ class KlingLipSyncAudioToVideoNode(KlingLipSyncBase):
|
||||
|
||||
DESCRIPTION = "Kling Lip Sync Audio to Video Node. Syncs mouth movements in a video file to the audio content of an audio file. When using, ensure that the audio contains clearly distinguishable vocals and that the video contains a distinct face. The audio file should not be larger than 5MB. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length."
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
video: VideoInput,
|
||||
audio: AudioInput,
|
||||
@@ -1365,7 +1367,7 @@ class KlingLipSyncAudioToVideoNode(KlingLipSyncBase):
|
||||
unique_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
return super().api_call(
|
||||
return await super().api_call(
|
||||
video=video,
|
||||
audio=audio,
|
||||
voice_language=voice_language,
|
||||
@@ -1469,7 +1471,7 @@ class KlingLipSyncTextToVideoNode(KlingLipSyncBase):
|
||||
|
||||
DESCRIPTION = "Kling Lip Sync Text to Video Node. Syncs mouth movements in a video file to a text prompt. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length."
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
video: VideoInput,
|
||||
text: str,
|
||||
@@ -1479,7 +1481,7 @@ class KlingLipSyncTextToVideoNode(KlingLipSyncBase):
|
||||
**kwargs,
|
||||
):
|
||||
voice_id, voice_language = KlingLipSyncTextToVideoNode.get_voice_config()[voice]
|
||||
return super().api_call(
|
||||
return await super().api_call(
|
||||
video=video,
|
||||
text=text,
|
||||
voice_language=voice_language,
|
||||
@@ -1533,10 +1535,10 @@ class KlingVirtualTryOnNode(KlingImageGenerationBase):
|
||||
|
||||
DESCRIPTION = "Kling Virtual Try On Node. Input a human image and a cloth image to try on the cloth on the human. You can merge multiple clothing item pictures into one image with a white background."
|
||||
|
||||
def get_response(
|
||||
async def get_response(
|
||||
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
|
||||
) -> KlingVirtualTryOnResponse:
|
||||
return poll_until_finished(
|
||||
return await poll_until_finished(
|
||||
auth_kwargs,
|
||||
ApiEndpoint(
|
||||
path=f"{PATH_VIRTUAL_TRY_ON}/{task_id}",
|
||||
@@ -1549,7 +1551,7 @@ class KlingVirtualTryOnNode(KlingImageGenerationBase):
|
||||
node_id=node_id,
|
||||
)
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
human_image: torch.Tensor,
|
||||
cloth_image: torch.Tensor,
|
||||
@@ -1572,17 +1574,17 @@ class KlingVirtualTryOnNode(KlingImageGenerationBase):
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
task_creation_response = initial_operation.execute()
|
||||
task_creation_response = await initial_operation.execute()
|
||||
validate_task_creation_response(task_creation_response)
|
||||
task_id = task_creation_response.data.task_id
|
||||
|
||||
final_response = self.get_response(
|
||||
final_response = await self.get_response(
|
||||
task_id, auth_kwargs=kwargs, node_id=unique_id
|
||||
)
|
||||
validate_image_result_response(final_response)
|
||||
|
||||
images = get_images_from_response(final_response)
|
||||
return (image_result_to_node_output(images),)
|
||||
return (await image_result_to_node_output(images),)
|
||||
|
||||
|
||||
class KlingImageGenerationNode(KlingImageGenerationBase):
|
||||
@@ -1655,13 +1657,13 @@ class KlingImageGenerationNode(KlingImageGenerationBase):
|
||||
|
||||
DESCRIPTION = "Kling Image Generation Node. Generate an image from a text prompt with an optional reference image."
|
||||
|
||||
def get_response(
|
||||
async def get_response(
|
||||
self,
|
||||
task_id: str,
|
||||
auth_kwargs: Optional[dict[str, str]],
|
||||
node_id: Optional[str] = None,
|
||||
) -> KlingImageGenerationsResponse:
|
||||
return poll_until_finished(
|
||||
return await poll_until_finished(
|
||||
auth_kwargs,
|
||||
ApiEndpoint(
|
||||
path=f"{PATH_IMAGE_GENERATIONS}/{task_id}",
|
||||
@@ -1674,7 +1676,7 @@ class KlingImageGenerationNode(KlingImageGenerationBase):
|
||||
node_id=node_id,
|
||||
)
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
model_name: KlingImageGenModelName,
|
||||
prompt: str,
|
||||
@@ -1690,7 +1692,11 @@ class KlingImageGenerationNode(KlingImageGenerationBase):
|
||||
):
|
||||
self.validate_prompt(prompt, negative_prompt)
|
||||
|
||||
if image is not None:
|
||||
if image is None:
|
||||
image_type = None
|
||||
elif model_name == KlingImageGenModelName.kling_v1:
|
||||
raise ValueError(f"The model {KlingImageGenModelName.kling_v1.value} does not support reference images.")
|
||||
else:
|
||||
image = tensor_to_base64_string(image)
|
||||
|
||||
initial_operation = SynchronousOperation(
|
||||
@@ -1714,17 +1720,17 @@ class KlingImageGenerationNode(KlingImageGenerationBase):
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
task_creation_response = initial_operation.execute()
|
||||
task_creation_response = await initial_operation.execute()
|
||||
validate_task_creation_response(task_creation_response)
|
||||
task_id = task_creation_response.data.task_id
|
||||
|
||||
final_response = self.get_response(
|
||||
final_response = await self.get_response(
|
||||
task_id, auth_kwargs=kwargs, node_id=unique_id
|
||||
)
|
||||
validate_image_result_response(final_response)
|
||||
|
||||
images = get_images_from_response(final_response)
|
||||
return (image_result_to_node_output(images),)
|
||||
return (await image_result_to_node_output(images),)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
|
||||
@@ -38,7 +38,7 @@ from comfy_api_nodes.apinode_utils import (
|
||||
)
|
||||
from server import PromptServer
|
||||
|
||||
import requests
|
||||
import aiohttp
|
||||
import torch
|
||||
from io import BytesIO
|
||||
|
||||
@@ -217,7 +217,7 @@ class LumaImageGenerationNode(ComfyNodeABC):
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str,
|
||||
@@ -234,19 +234,19 @@ class LumaImageGenerationNode(ComfyNodeABC):
|
||||
# handle image_luma_ref
|
||||
api_image_ref = None
|
||||
if image_luma_ref is not None:
|
||||
api_image_ref = self._convert_luma_refs(
|
||||
api_image_ref = await self._convert_luma_refs(
|
||||
image_luma_ref, max_refs=4, auth_kwargs=kwargs,
|
||||
)
|
||||
# handle style_luma_ref
|
||||
api_style_ref = None
|
||||
if style_image is not None:
|
||||
api_style_ref = self._convert_style_image(
|
||||
api_style_ref = await self._convert_style_image(
|
||||
style_image, weight=style_image_weight, auth_kwargs=kwargs,
|
||||
)
|
||||
# handle character_ref images
|
||||
character_ref = None
|
||||
if character_image is not None:
|
||||
download_urls = upload_images_to_comfyapi(
|
||||
download_urls = await upload_images_to_comfyapi(
|
||||
character_image, max_images=4, auth_kwargs=kwargs,
|
||||
)
|
||||
character_ref = LumaCharacterRef(
|
||||
@@ -270,7 +270,7 @@ class LumaImageGenerationNode(ComfyNodeABC):
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_api: LumaGeneration = operation.execute()
|
||||
response_api: LumaGeneration = await operation.execute()
|
||||
|
||||
operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
@@ -286,19 +286,20 @@ class LumaImageGenerationNode(ComfyNodeABC):
|
||||
node_id=unique_id,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_poll = operation.execute()
|
||||
response_poll = await operation.execute()
|
||||
|
||||
img_response = requests.get(response_poll.assets.image)
|
||||
img = process_image_response(img_response)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(response_poll.assets.image) as img_response:
|
||||
img = process_image_response(await img_response.content.read())
|
||||
return (img,)
|
||||
|
||||
def _convert_luma_refs(
|
||||
async def _convert_luma_refs(
|
||||
self, luma_ref: LumaReferenceChain, max_refs: int, auth_kwargs: Optional[dict[str,str]] = None
|
||||
):
|
||||
luma_urls = []
|
||||
ref_count = 0
|
||||
for ref in luma_ref.refs:
|
||||
download_urls = upload_images_to_comfyapi(
|
||||
download_urls = await upload_images_to_comfyapi(
|
||||
ref.image, max_images=1, auth_kwargs=auth_kwargs
|
||||
)
|
||||
luma_urls.append(download_urls[0])
|
||||
@@ -307,13 +308,13 @@ class LumaImageGenerationNode(ComfyNodeABC):
|
||||
break
|
||||
return luma_ref.create_api_model(download_urls=luma_urls, max_refs=max_refs)
|
||||
|
||||
def _convert_style_image(
|
||||
async def _convert_style_image(
|
||||
self, style_image: torch.Tensor, weight: float, auth_kwargs: Optional[dict[str,str]] = None
|
||||
):
|
||||
chain = LumaReferenceChain(
|
||||
first_ref=LumaReference(image=style_image, weight=weight)
|
||||
)
|
||||
return self._convert_luma_refs(chain, max_refs=1, auth_kwargs=auth_kwargs)
|
||||
return await self._convert_luma_refs(chain, max_refs=1, auth_kwargs=auth_kwargs)
|
||||
|
||||
|
||||
class LumaImageModifyNode(ComfyNodeABC):
|
||||
@@ -370,7 +371,7 @@ class LumaImageModifyNode(ComfyNodeABC):
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str,
|
||||
@@ -381,7 +382,7 @@ class LumaImageModifyNode(ComfyNodeABC):
|
||||
**kwargs,
|
||||
):
|
||||
# first, upload image
|
||||
download_urls = upload_images_to_comfyapi(
|
||||
download_urls = await upload_images_to_comfyapi(
|
||||
image, max_images=1, auth_kwargs=kwargs,
|
||||
)
|
||||
image_url = download_urls[0]
|
||||
@@ -402,7 +403,7 @@ class LumaImageModifyNode(ComfyNodeABC):
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_api: LumaGeneration = operation.execute()
|
||||
response_api: LumaGeneration = await operation.execute()
|
||||
|
||||
operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
@@ -418,10 +419,11 @@ class LumaImageModifyNode(ComfyNodeABC):
|
||||
node_id=unique_id,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_poll = operation.execute()
|
||||
response_poll = await operation.execute()
|
||||
|
||||
img_response = requests.get(response_poll.assets.image)
|
||||
img = process_image_response(img_response)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(response_poll.assets.image) as img_response:
|
||||
img = process_image_response(await img_response.content.read())
|
||||
return (img,)
|
||||
|
||||
|
||||
@@ -494,7 +496,7 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC):
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str,
|
||||
@@ -529,7 +531,7 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC):
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_api: LumaGeneration = operation.execute()
|
||||
response_api: LumaGeneration = await operation.execute()
|
||||
|
||||
if unique_id:
|
||||
PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", unique_id)
|
||||
@@ -549,10 +551,11 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC):
|
||||
estimated_duration=LUMA_T2V_AVERAGE_DURATION,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_poll = operation.execute()
|
||||
response_poll = await operation.execute()
|
||||
|
||||
vid_response = requests.get(response_poll.assets.video)
|
||||
return (VideoFromFile(BytesIO(vid_response.content)),)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(response_poll.assets.video) as vid_response:
|
||||
return (VideoFromFile(BytesIO(await vid_response.content.read())),)
|
||||
|
||||
|
||||
class LumaImageToVideoGenerationNode(ComfyNodeABC):
|
||||
@@ -626,7 +629,7 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str,
|
||||
@@ -644,7 +647,7 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
|
||||
raise Exception(
|
||||
"At least one of first_image and last_image requires an input."
|
||||
)
|
||||
keyframes = self._convert_to_keyframes(first_image, last_image, auth_kwargs=kwargs)
|
||||
keyframes = await self._convert_to_keyframes(first_image, last_image, auth_kwargs=kwargs)
|
||||
duration = duration if model != LumaVideoModel.ray_1_6 else None
|
||||
resolution = resolution if model != LumaVideoModel.ray_1_6 else None
|
||||
|
||||
@@ -667,7 +670,7 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_api: LumaGeneration = operation.execute()
|
||||
response_api: LumaGeneration = await operation.execute()
|
||||
|
||||
if unique_id:
|
||||
PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", unique_id)
|
||||
@@ -687,12 +690,13 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
|
||||
estimated_duration=LUMA_I2V_AVERAGE_DURATION,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_poll = operation.execute()
|
||||
response_poll = await operation.execute()
|
||||
|
||||
vid_response = requests.get(response_poll.assets.video)
|
||||
return (VideoFromFile(BytesIO(vid_response.content)),)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(response_poll.assets.video) as vid_response:
|
||||
return (VideoFromFile(BytesIO(await vid_response.content.read())),)
|
||||
|
||||
def _convert_to_keyframes(
|
||||
async def _convert_to_keyframes(
|
||||
self,
|
||||
first_image: torch.Tensor = None,
|
||||
last_image: torch.Tensor = None,
|
||||
@@ -703,12 +707,12 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
|
||||
frame0 = None
|
||||
frame1 = None
|
||||
if first_image is not None:
|
||||
download_urls = upload_images_to_comfyapi(
|
||||
download_urls = await upload_images_to_comfyapi(
|
||||
first_image, max_images=1, auth_kwargs=auth_kwargs,
|
||||
)
|
||||
frame0 = LumaImageReference(type="image", url=download_urls[0])
|
||||
if last_image is not None:
|
||||
download_urls = upload_images_to_comfyapi(
|
||||
download_urls = await upload_images_to_comfyapi(
|
||||
last_image, max_images=1, auth_kwargs=auth_kwargs,
|
||||
)
|
||||
frame1 = LumaImageReference(type="image", url=download_urls[0])
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from inspect import cleandoc
|
||||
from typing import Union
|
||||
import logging
|
||||
import torch
|
||||
@@ -10,7 +11,7 @@ from comfy_api_nodes.apis import (
|
||||
MinimaxFileRetrieveResponse,
|
||||
MinimaxTaskResultResponse,
|
||||
SubjectReferenceItem,
|
||||
Model
|
||||
MiniMaxModel
|
||||
)
|
||||
from comfy_api_nodes.apis.client import (
|
||||
ApiEndpoint,
|
||||
@@ -84,9 +85,8 @@ class MinimaxTextToVideoNode:
|
||||
FUNCTION = "generate_video"
|
||||
CATEGORY = "api node/video/MiniMax"
|
||||
API_NODE = True
|
||||
OUTPUT_NODE = True
|
||||
|
||||
def generate_video(
|
||||
async def generate_video(
|
||||
self,
|
||||
prompt_text,
|
||||
seed=0,
|
||||
@@ -104,12 +104,12 @@ class MinimaxTextToVideoNode:
|
||||
# upload image, if passed in
|
||||
image_url = None
|
||||
if image is not None:
|
||||
image_url = upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs)[0]
|
||||
image_url = (await upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs))[0]
|
||||
|
||||
# TODO: figure out how to deal with subject properly, API returns invalid params when using S2V-01 model
|
||||
subject_reference = None
|
||||
if subject is not None:
|
||||
subject_url = upload_images_to_comfyapi(subject, max_images=1, auth_kwargs=kwargs)[0]
|
||||
subject_url = (await upload_images_to_comfyapi(subject, max_images=1, auth_kwargs=kwargs))[0]
|
||||
subject_reference = [SubjectReferenceItem(image=subject_url)]
|
||||
|
||||
|
||||
@@ -121,7 +121,7 @@ class MinimaxTextToVideoNode:
|
||||
response_model=MinimaxVideoGenerationResponse,
|
||||
),
|
||||
request=MinimaxVideoGenerationRequest(
|
||||
model=Model(model),
|
||||
model=MiniMaxModel(model),
|
||||
prompt=prompt_text,
|
||||
callback_url=None,
|
||||
first_frame_image=image_url,
|
||||
@@ -130,7 +130,7 @@ class MinimaxTextToVideoNode:
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response = video_generate_operation.execute()
|
||||
response = await video_generate_operation.execute()
|
||||
|
||||
task_id = response.task_id
|
||||
if not task_id:
|
||||
@@ -151,7 +151,7 @@ class MinimaxTextToVideoNode:
|
||||
node_id=unique_id,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
task_result = video_generate_operation.execute()
|
||||
task_result = await video_generate_operation.execute()
|
||||
|
||||
file_id = task_result.file_id
|
||||
if file_id is None:
|
||||
@@ -167,7 +167,7 @@ class MinimaxTextToVideoNode:
|
||||
request=EmptyRequest(),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
file_result = file_retrieve_operation.execute()
|
||||
file_result = await file_retrieve_operation.execute()
|
||||
|
||||
file_url = file_result.file.download_url
|
||||
if file_url is None:
|
||||
@@ -182,7 +182,7 @@ class MinimaxTextToVideoNode:
|
||||
message = f"Result URL: {file_url}"
|
||||
PromptServer.instance.send_progress_text(message, unique_id)
|
||||
|
||||
video_io = download_url_to_bytesio(file_url)
|
||||
video_io = await download_url_to_bytesio(file_url)
|
||||
if video_io is None:
|
||||
error_msg = f"Failed to download video from {file_url}"
|
||||
logging.error(error_msg)
|
||||
@@ -251,7 +251,6 @@ class MinimaxImageToVideoNode(MinimaxTextToVideoNode):
|
||||
FUNCTION = "generate_video"
|
||||
CATEGORY = "api node/video/MiniMax"
|
||||
API_NODE = True
|
||||
OUTPUT_NODE = True
|
||||
|
||||
|
||||
class MinimaxSubjectToVideoNode(MinimaxTextToVideoNode):
|
||||
@@ -313,7 +312,181 @@ class MinimaxSubjectToVideoNode(MinimaxTextToVideoNode):
|
||||
FUNCTION = "generate_video"
|
||||
CATEGORY = "api node/video/MiniMax"
|
||||
API_NODE = True
|
||||
OUTPUT_NODE = True
|
||||
|
||||
|
||||
class MinimaxHailuoVideoNode:
|
||||
"""Generates videos from prompt, with optional start frame using the new MiniMax Hailuo-02 model."""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"prompt_text": (
|
||||
"STRING",
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Text prompt to guide the video generation.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 0xFFFFFFFFFFFFFFFF,
|
||||
"control_after_generate": True,
|
||||
"tooltip": "The random seed used for creating the noise.",
|
||||
},
|
||||
),
|
||||
"first_frame_image": (
|
||||
IO.IMAGE,
|
||||
{
|
||||
"tooltip": "Optional image to use as the first frame to generate a video."
|
||||
},
|
||||
),
|
||||
"prompt_optimizer": (
|
||||
IO.BOOLEAN,
|
||||
{
|
||||
"tooltip": "Optimize prompt to improve generation quality when needed.",
|
||||
"default": True,
|
||||
},
|
||||
),
|
||||
"duration": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"tooltip": "The length of the output video in seconds.",
|
||||
"default": 6,
|
||||
"options": [6, 10],
|
||||
},
|
||||
),
|
||||
"resolution": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"tooltip": "The dimensions of the video display. "
|
||||
"1080p corresponds to 1920 x 1080 pixels, 768p corresponds to 1366 x 768 pixels.",
|
||||
"default": "768P",
|
||||
"options": ["768P", "1080P"],
|
||||
},
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("VIDEO",)
|
||||
DESCRIPTION = cleandoc(__doc__ or "")
|
||||
FUNCTION = "generate_video"
|
||||
CATEGORY = "api node/video/MiniMax"
|
||||
API_NODE = True
|
||||
|
||||
async def generate_video(
|
||||
self,
|
||||
prompt_text,
|
||||
seed=0,
|
||||
first_frame_image: torch.Tensor=None, # used for ImageToVideo
|
||||
prompt_optimizer=True,
|
||||
duration=6,
|
||||
resolution="768P",
|
||||
model="MiniMax-Hailuo-02",
|
||||
unique_id: Union[str, None]=None,
|
||||
**kwargs,
|
||||
):
|
||||
if first_frame_image is None:
|
||||
validate_string(prompt_text, field_name="prompt_text")
|
||||
|
||||
if model == "MiniMax-Hailuo-02" and resolution.upper() == "1080P" and duration != 6:
|
||||
raise Exception(
|
||||
"When model is MiniMax-Hailuo-02 and resolution is 1080P, duration is limited to 6 seconds."
|
||||
)
|
||||
|
||||
# upload image, if passed in
|
||||
image_url = None
|
||||
if first_frame_image is not None:
|
||||
image_url = (await upload_images_to_comfyapi(first_frame_image, max_images=1, auth_kwargs=kwargs))[0]
|
||||
|
||||
video_generate_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/minimax/video_generation",
|
||||
method=HttpMethod.POST,
|
||||
request_model=MinimaxVideoGenerationRequest,
|
||||
response_model=MinimaxVideoGenerationResponse,
|
||||
),
|
||||
request=MinimaxVideoGenerationRequest(
|
||||
model=MiniMaxModel(model),
|
||||
prompt=prompt_text,
|
||||
callback_url=None,
|
||||
first_frame_image=image_url,
|
||||
prompt_optimizer=prompt_optimizer,
|
||||
duration=duration,
|
||||
resolution=resolution,
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response = await video_generate_operation.execute()
|
||||
|
||||
task_id = response.task_id
|
||||
if not task_id:
|
||||
raise Exception(f"MiniMax generation failed: {response.base_resp}")
|
||||
|
||||
average_duration = 120 if resolution == "768P" else 240
|
||||
video_generate_operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path="/proxy/minimax/query/video_generation",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=MinimaxTaskResultResponse,
|
||||
query_params={"task_id": task_id},
|
||||
),
|
||||
completed_statuses=["Success"],
|
||||
failed_statuses=["Fail"],
|
||||
status_extractor=lambda x: x.status.value,
|
||||
estimated_duration=average_duration,
|
||||
node_id=unique_id,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
task_result = await video_generate_operation.execute()
|
||||
|
||||
file_id = task_result.file_id
|
||||
if file_id is None:
|
||||
raise Exception("Request was not successful. Missing file ID.")
|
||||
file_retrieve_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/minimax/files/retrieve",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=MinimaxFileRetrieveResponse,
|
||||
query_params={"file_id": int(file_id)},
|
||||
),
|
||||
request=EmptyRequest(),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
file_result = await file_retrieve_operation.execute()
|
||||
|
||||
file_url = file_result.file.download_url
|
||||
if file_url is None:
|
||||
raise Exception(
|
||||
f"No video was found in the response. Full response: {file_result.model_dump()}"
|
||||
)
|
||||
logging.info(f"Generated video URL: {file_url}")
|
||||
if unique_id:
|
||||
if hasattr(file_result.file, "backup_download_url"):
|
||||
message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}"
|
||||
else:
|
||||
message = f"Result URL: {file_url}"
|
||||
PromptServer.instance.send_progress_text(message, unique_id)
|
||||
|
||||
video_io = await download_url_to_bytesio(file_url)
|
||||
if video_io is None:
|
||||
error_msg = f"Failed to download video from {file_url}"
|
||||
logging.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
return (VideoFromFile(video_io),)
|
||||
|
||||
|
||||
# A dictionary that contains all nodes you want to export with their names
|
||||
@@ -322,6 +495,7 @@ NODE_CLASS_MAPPINGS = {
|
||||
"MinimaxTextToVideoNode": MinimaxTextToVideoNode,
|
||||
"MinimaxImageToVideoNode": MinimaxImageToVideoNode,
|
||||
# "MinimaxSubjectToVideoNode": MinimaxSubjectToVideoNode,
|
||||
"MinimaxHailuoVideoNode": MinimaxHailuoVideoNode,
|
||||
}
|
||||
|
||||
# A dictionary that contains the friendly/humanly readable titles for the nodes
|
||||
@@ -329,4 +503,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"MinimaxTextToVideoNode": "MiniMax Text to Video",
|
||||
"MinimaxImageToVideoNode": "MiniMax Image to Video",
|
||||
"MinimaxSubjectToVideoNode": "MiniMax Subject to Video",
|
||||
"MinimaxHailuoVideoNode": "MiniMax Hailuo Video",
|
||||
}
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import logging
|
||||
from typing import Any, Callable, Optional, TypeVar
|
||||
import random
|
||||
import torch
|
||||
from comfy_api_nodes.util.validation_utils import (
|
||||
get_image_dimensions,
|
||||
@@ -95,14 +94,14 @@ def get_video_url_from_response(response) -> Optional[str]:
|
||||
return None
|
||||
|
||||
|
||||
def poll_until_finished(
|
||||
async def poll_until_finished(
|
||||
auth_kwargs: dict[str, str],
|
||||
api_endpoint: ApiEndpoint[Any, R],
|
||||
result_url_extractor: Optional[Callable[[R], str]] = None,
|
||||
node_id: Optional[str] = None,
|
||||
) -> R:
|
||||
"""Polls the Moonvalley API endpoint until the task reaches a terminal state, then returns the response."""
|
||||
return PollingOperation(
|
||||
return await PollingOperation(
|
||||
poll_endpoint=api_endpoint,
|
||||
completed_statuses=[
|
||||
"completed",
|
||||
@@ -208,20 +207,29 @@ def _get_video_dimensions(video: VideoInput) -> tuple[int, int]:
|
||||
def _validate_video_dimensions(width: int, height: int) -> None:
|
||||
"""Validates video dimensions meet Moonvalley V2V requirements."""
|
||||
supported_resolutions = {
|
||||
(1920, 1080), (1080, 1920), (1152, 1152),
|
||||
(1536, 1152), (1152, 1536)
|
||||
(1920, 1080),
|
||||
(1080, 1920),
|
||||
(1152, 1152),
|
||||
(1536, 1152),
|
||||
(1152, 1536),
|
||||
}
|
||||
|
||||
if (width, height) not in supported_resolutions:
|
||||
supported_list = ', '.join([f'{w}x{h}' for w, h in sorted(supported_resolutions)])
|
||||
raise ValueError(f"Resolution {width}x{height} not supported. Supported: {supported_list}")
|
||||
supported_list = ", ".join(
|
||||
[f"{w}x{h}" for w, h in sorted(supported_resolutions)]
|
||||
)
|
||||
raise ValueError(
|
||||
f"Resolution {width}x{height} not supported. Supported: {supported_list}"
|
||||
)
|
||||
|
||||
|
||||
def _validate_container_format(video: VideoInput) -> None:
|
||||
"""Validates video container format is MP4."""
|
||||
container_format = video.get_container_format()
|
||||
if container_format not in ['mp4', 'mov,mp4,m4a,3gp,3g2,mj2']:
|
||||
raise ValueError(f"Only MP4 container format supported. Got: {container_format}")
|
||||
if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]:
|
||||
raise ValueError(
|
||||
f"Only MP4 container format supported. Got: {container_format}"
|
||||
)
|
||||
|
||||
|
||||
def _validate_and_trim_duration(video: VideoInput) -> VideoInput:
|
||||
@@ -244,7 +252,6 @@ def _trim_if_too_long(video: VideoInput, duration: float) -> VideoInput:
|
||||
return video
|
||||
|
||||
|
||||
|
||||
def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
|
||||
"""
|
||||
Returns a new VideoInput object trimmed from the beginning to the specified duration,
|
||||
@@ -302,7 +309,9 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
|
||||
# Calculate target frame count that's divisible by 16
|
||||
fps = input_container.streams.video[0].average_rate
|
||||
estimated_frames = int(duration_sec * fps)
|
||||
target_frames = (estimated_frames // 16) * 16 # Round down to nearest multiple of 16
|
||||
target_frames = (
|
||||
estimated_frames // 16
|
||||
) * 16 # Round down to nearest multiple of 16
|
||||
|
||||
if target_frames == 0:
|
||||
raise ValueError("Video too short: need at least 16 frames for Moonvalley")
|
||||
@@ -394,10 +403,10 @@ class BaseMoonvalleyVideoNode:
|
||||
else:
|
||||
return control_map["Motion Transfer"]
|
||||
|
||||
def get_response(
|
||||
async def get_response(
|
||||
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
|
||||
) -> MoonvalleyPromptResponse:
|
||||
return poll_until_finished(
|
||||
return await poll_until_finished(
|
||||
auth_kwargs,
|
||||
ApiEndpoint(
|
||||
path=f"{API_PROMPTS_ENDPOINT}/{task_id}",
|
||||
@@ -424,7 +433,7 @@ class BaseMoonvalleyVideoNode:
|
||||
MoonvalleyTextToVideoInferenceParams,
|
||||
"negative_prompt",
|
||||
multiline=True,
|
||||
default="low-poly, flat shader, bad rigging, stiff animation, uncanny eyes, low-quality textures, looping glitch, cheap effect, overbloom, bloom spam, default lighting, game asset, stiff face, ugly specular, AI artifacts",
|
||||
default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, wobbly, weird, low quality, plastic, stock footage, video camera, boring",
|
||||
),
|
||||
"resolution": (
|
||||
IO.COMBO,
|
||||
@@ -441,12 +450,11 @@ class BaseMoonvalleyVideoNode:
|
||||
"tooltip": "Resolution of the output video",
|
||||
},
|
||||
),
|
||||
# "length": (IO.COMBO,{"options":['5s','10s'], "default": '5s'}),
|
||||
"prompt_adherence": model_field_to_node_input(
|
||||
IO.FLOAT,
|
||||
MoonvalleyTextToVideoInferenceParams,
|
||||
"guidance_scale",
|
||||
default=7.0,
|
||||
default=10.0,
|
||||
step=1,
|
||||
min=1,
|
||||
max=20,
|
||||
@@ -455,13 +463,12 @@ class BaseMoonvalleyVideoNode:
|
||||
IO.INT,
|
||||
MoonvalleyTextToVideoInferenceParams,
|
||||
"seed",
|
||||
default=random.randint(0, 2**32 - 1),
|
||||
default=9,
|
||||
min=0,
|
||||
max=4294967295,
|
||||
step=1,
|
||||
display="number",
|
||||
tooltip="Random seed value",
|
||||
control_after_generate=True,
|
||||
),
|
||||
"steps": model_field_to_node_input(
|
||||
IO.INT,
|
||||
@@ -507,7 +514,7 @@ class MoonvalleyImg2VideoNode(BaseMoonvalleyVideoNode):
|
||||
RETURN_NAMES = ("video",)
|
||||
DESCRIPTION = "Moonvalley Marey Image to Video Node"
|
||||
|
||||
def generate(
|
||||
async def generate(
|
||||
self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs
|
||||
):
|
||||
image = kwargs.get("image", None)
|
||||
@@ -532,8 +539,10 @@ class MoonvalleyImg2VideoNode(BaseMoonvalleyVideoNode):
|
||||
# Get MIME type from tensor - assuming PNG format for image tensors
|
||||
mime_type = "image/png"
|
||||
|
||||
image_url = upload_images_to_comfyapi(
|
||||
image, max_images=1, auth_kwargs=kwargs, mime_type=mime_type
|
||||
image_url = (
|
||||
await upload_images_to_comfyapi(
|
||||
image, max_images=1, auth_kwargs=kwargs, mime_type=mime_type
|
||||
)
|
||||
)[0]
|
||||
|
||||
request = MoonvalleyTextToVideoRequest(
|
||||
@@ -549,14 +558,14 @@ class MoonvalleyImg2VideoNode(BaseMoonvalleyVideoNode):
|
||||
request=request,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
task_creation_response = initial_operation.execute()
|
||||
task_creation_response = await initial_operation.execute()
|
||||
validate_task_creation_response(task_creation_response)
|
||||
task_id = task_creation_response.id
|
||||
|
||||
final_response = self.get_response(
|
||||
final_response = await self.get_response(
|
||||
task_id, auth_kwargs=kwargs, node_id=unique_id
|
||||
)
|
||||
video = download_url_to_video_output(final_response.output_url)
|
||||
video = await download_url_to_video_output(final_response.output_url)
|
||||
return (video,)
|
||||
|
||||
|
||||
@@ -570,17 +579,39 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
|
||||
return {
|
||||
"required": {
|
||||
"prompt": model_field_to_node_input(
|
||||
IO.STRING, MoonvalleyVideoToVideoRequest, "prompt_text",
|
||||
multiline=True
|
||||
IO.STRING,
|
||||
MoonvalleyVideoToVideoRequest,
|
||||
"prompt_text",
|
||||
multiline=True,
|
||||
),
|
||||
"negative_prompt": model_field_to_node_input(
|
||||
IO.STRING,
|
||||
MoonvalleyVideoToVideoInferenceParams,
|
||||
"negative_prompt",
|
||||
multiline=True,
|
||||
default="low-poly, flat shader, bad rigging, stiff animation, uncanny eyes, low-quality textures, looping glitch, cheap effect, overbloom, bloom spam, default lighting, game asset, stiff face, ugly specular, AI artifacts"
|
||||
default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, wobbly, weird, low quality, plastic, stock footage, video camera, boring",
|
||||
),
|
||||
"seed": model_field_to_node_input(
|
||||
IO.INT,
|
||||
MoonvalleyVideoToVideoInferenceParams,
|
||||
"seed",
|
||||
default=9,
|
||||
min=0,
|
||||
max=4294967295,
|
||||
step=1,
|
||||
display="number",
|
||||
tooltip="Random seed value",
|
||||
control_after_generate=False,
|
||||
),
|
||||
"prompt_adherence": model_field_to_node_input(
|
||||
IO.FLOAT,
|
||||
MoonvalleyVideoToVideoInferenceParams,
|
||||
"guidance_scale",
|
||||
default=10.0,
|
||||
step=1,
|
||||
min=1,
|
||||
max=20,
|
||||
),
|
||||
"seed": model_field_to_node_input(IO.INT,MoonvalleyVideoToVideoInferenceParams, "seed", default=random.randint(0, 2**32 - 1), min=0, max=4294967295, step=1, display="number", tooltip="Random seed value", control_after_generate=True),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
@@ -588,7 +619,14 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
"optional": {
|
||||
"video": (IO.VIDEO, {"default": "", "multiline": False, "tooltip": "The reference video used to generate the output video. Must be at least 5 seconds long. Videos longer than 5s will be automatically trimmed. Only MP4 format supported."}),
|
||||
"video": (
|
||||
IO.VIDEO,
|
||||
{
|
||||
"default": "",
|
||||
"multiline": False,
|
||||
"tooltip": "The reference video used to generate the output video. Must be at least 5 seconds long. Videos longer than 5s will be automatically trimmed. Only MP4 format supported.",
|
||||
},
|
||||
),
|
||||
"control_type": (
|
||||
["Motion Transfer", "Pose Transfer"],
|
||||
{"default": "Motion Transfer"},
|
||||
@@ -602,17 +640,24 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
|
||||
"max": 100,
|
||||
"tooltip": "Only used if control_type is 'Motion Transfer'",
|
||||
},
|
||||
)
|
||||
}
|
||||
),
|
||||
"image": model_field_to_node_input(
|
||||
IO.IMAGE,
|
||||
MoonvalleyTextToVideoRequest,
|
||||
"image_url",
|
||||
tooltip="The reference image used to generate the video",
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("VIDEO",)
|
||||
RETURN_NAMES = ("video",)
|
||||
|
||||
def generate(
|
||||
async def generate(
|
||||
self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs
|
||||
):
|
||||
video = kwargs.get("video")
|
||||
image = kwargs.get("image", None)
|
||||
|
||||
if not video:
|
||||
raise MoonvalleyApiError("video is required")
|
||||
@@ -620,8 +665,16 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
|
||||
video_url = ""
|
||||
if video:
|
||||
validated_video = validate_video_to_video_input(video)
|
||||
video_url = upload_video_to_comfyapi(validated_video, auth_kwargs=kwargs)
|
||||
video_url = await upload_video_to_comfyapi(
|
||||
validated_video, auth_kwargs=kwargs
|
||||
)
|
||||
mime_type = "image/png"
|
||||
|
||||
if not image is None:
|
||||
validate_input_image(image, with_frame_conditioning=True)
|
||||
image_url = await upload_images_to_comfyapi(
|
||||
image=image, auth_kwargs=kwargs, max_images=1, mime_type=mime_type
|
||||
)
|
||||
control_type = kwargs.get("control_type")
|
||||
motion_intensity = kwargs.get("motion_intensity")
|
||||
|
||||
@@ -631,12 +684,12 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
|
||||
# Only include motion_intensity for Motion Transfer
|
||||
control_params = {}
|
||||
if control_type == "Motion Transfer" and motion_intensity is not None:
|
||||
control_params['motion_intensity'] = motion_intensity
|
||||
control_params["motion_intensity"] = motion_intensity
|
||||
|
||||
inference_params=MoonvalleyVideoToVideoInferenceParams(
|
||||
inference_params = MoonvalleyVideoToVideoInferenceParams(
|
||||
negative_prompt=negative_prompt,
|
||||
seed=kwargs.get("seed"),
|
||||
control_params=control_params
|
||||
control_params=control_params,
|
||||
)
|
||||
|
||||
control = self.parseControlParameter(control_type)
|
||||
@@ -647,6 +700,7 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
|
||||
prompt_text=prompt,
|
||||
inference_params=inference_params,
|
||||
)
|
||||
request.image_url = image_url if not image is None else None
|
||||
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
@@ -658,15 +712,15 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
|
||||
request=request,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
task_creation_response = initial_operation.execute()
|
||||
task_creation_response = await initial_operation.execute()
|
||||
validate_task_creation_response(task_creation_response)
|
||||
task_id = task_creation_response.id
|
||||
|
||||
final_response = self.get_response(
|
||||
final_response = await self.get_response(
|
||||
task_id, auth_kwargs=kwargs, node_id=unique_id
|
||||
)
|
||||
|
||||
video = download_url_to_video_output(final_response.output_url)
|
||||
video = await download_url_to_video_output(final_response.output_url)
|
||||
|
||||
return (video,)
|
||||
|
||||
@@ -688,21 +742,21 @@ class MoonvalleyTxt2VideoNode(BaseMoonvalleyVideoNode):
|
||||
del input_types["optional"][param]
|
||||
return input_types
|
||||
|
||||
def generate(
|
||||
async def generate(
|
||||
self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs
|
||||
):
|
||||
validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
|
||||
width_height = self.parseWidthHeightFromRes(kwargs.get("resolution"))
|
||||
|
||||
inference_params=MoonvalleyTextToVideoInferenceParams(
|
||||
negative_prompt=negative_prompt,
|
||||
steps=kwargs.get("steps"),
|
||||
seed=kwargs.get("seed"),
|
||||
guidance_scale=kwargs.get("prompt_adherence"),
|
||||
num_frames=128,
|
||||
width=width_height.get("width"),
|
||||
height=width_height.get("height"),
|
||||
)
|
||||
inference_params = MoonvalleyTextToVideoInferenceParams(
|
||||
negative_prompt=negative_prompt,
|
||||
steps=kwargs.get("steps"),
|
||||
seed=kwargs.get("seed"),
|
||||
guidance_scale=kwargs.get("prompt_adherence"),
|
||||
num_frames=128,
|
||||
width=width_height.get("width"),
|
||||
height=width_height.get("height"),
|
||||
)
|
||||
request = MoonvalleyTextToVideoRequest(
|
||||
prompt_text=prompt, inference_params=inference_params
|
||||
)
|
||||
@@ -717,15 +771,15 @@ class MoonvalleyTxt2VideoNode(BaseMoonvalleyVideoNode):
|
||||
request=request,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
task_creation_response = initial_operation.execute()
|
||||
task_creation_response = await initial_operation.execute()
|
||||
validate_task_creation_response(task_creation_response)
|
||||
task_id = task_creation_response.id
|
||||
|
||||
final_response = self.get_response(
|
||||
final_response = await self.get_response(
|
||||
task_id, auth_kwargs=kwargs, node_id=unique_id
|
||||
)
|
||||
|
||||
video = download_url_to_video_output(final_response.output_url)
|
||||
video = await download_url_to_video_output(final_response.output_url)
|
||||
return (video,)
|
||||
|
||||
|
||||
|
||||
@@ -80,6 +80,9 @@ class SupportedOpenAIModel(str, Enum):
|
||||
gpt_4_1 = "gpt-4.1"
|
||||
gpt_4_1_mini = "gpt-4.1-mini"
|
||||
gpt_4_1_nano = "gpt-4.1-nano"
|
||||
gpt_5 = "gpt-5"
|
||||
gpt_5_mini = "gpt-5-mini"
|
||||
gpt_5_nano = "gpt-5-nano"
|
||||
|
||||
|
||||
class OpenAIDalle2(ComfyNodeABC):
|
||||
@@ -163,7 +166,7 @@ class OpenAIDalle2(ComfyNodeABC):
|
||||
DESCRIPTION = cleandoc(__doc__ or "")
|
||||
API_NODE = True
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
prompt,
|
||||
seed=0,
|
||||
@@ -233,9 +236,9 @@ class OpenAIDalle2(ComfyNodeABC):
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
response = operation.execute()
|
||||
response = await operation.execute()
|
||||
|
||||
img_tensor = validate_and_cast_response(response, node_id=unique_id)
|
||||
img_tensor = await validate_and_cast_response(response, node_id=unique_id)
|
||||
return (img_tensor,)
|
||||
|
||||
|
||||
@@ -311,7 +314,7 @@ class OpenAIDalle3(ComfyNodeABC):
|
||||
DESCRIPTION = cleandoc(__doc__ or "")
|
||||
API_NODE = True
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
prompt,
|
||||
seed=0,
|
||||
@@ -343,9 +346,9 @@ class OpenAIDalle3(ComfyNodeABC):
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
response = operation.execute()
|
||||
response = await operation.execute()
|
||||
|
||||
img_tensor = validate_and_cast_response(response, node_id=unique_id)
|
||||
img_tensor = await validate_and_cast_response(response, node_id=unique_id)
|
||||
return (img_tensor,)
|
||||
|
||||
|
||||
@@ -446,7 +449,7 @@ class OpenAIGPTImage1(ComfyNodeABC):
|
||||
DESCRIPTION = cleandoc(__doc__ or "")
|
||||
API_NODE = True
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
prompt,
|
||||
seed=0,
|
||||
@@ -464,8 +467,6 @@ class OpenAIGPTImage1(ComfyNodeABC):
|
||||
path = "/proxy/openai/images/generations"
|
||||
content_type = "application/json"
|
||||
request_class = OpenAIImageGenerationRequest
|
||||
img_binaries = []
|
||||
mask_binary = None
|
||||
files = []
|
||||
|
||||
if image is not None:
|
||||
@@ -484,14 +485,11 @@ class OpenAIGPTImage1(ComfyNodeABC):
|
||||
img_byte_arr = io.BytesIO()
|
||||
img.save(img_byte_arr, format="PNG")
|
||||
img_byte_arr.seek(0)
|
||||
img_binary = img_byte_arr
|
||||
img_binary.name = f"image_{i}.png"
|
||||
|
||||
img_binaries.append(img_binary)
|
||||
if batch_size == 1:
|
||||
files.append(("image", img_binary))
|
||||
files.append(("image", (f"image_{i}.png", img_byte_arr, "image/png")))
|
||||
else:
|
||||
files.append(("image[]", img_binary))
|
||||
files.append(("image[]", (f"image_{i}.png", img_byte_arr, "image/png")))
|
||||
|
||||
if mask is not None:
|
||||
if image is None:
|
||||
@@ -511,9 +509,7 @@ class OpenAIGPTImage1(ComfyNodeABC):
|
||||
mask_img_byte_arr = io.BytesIO()
|
||||
mask_img.save(mask_img_byte_arr, format="PNG")
|
||||
mask_img_byte_arr.seek(0)
|
||||
mask_binary = mask_img_byte_arr
|
||||
mask_binary.name = "mask.png"
|
||||
files.append(("mask", mask_binary))
|
||||
files.append(("mask", ("mask.png", mask_img_byte_arr, "image/png")))
|
||||
|
||||
# Build the operation
|
||||
operation = SynchronousOperation(
|
||||
@@ -537,9 +533,9 @@ class OpenAIGPTImage1(ComfyNodeABC):
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
response = operation.execute()
|
||||
response = await operation.execute()
|
||||
|
||||
img_tensor = validate_and_cast_response(response, node_id=unique_id)
|
||||
img_tensor = await validate_and_cast_response(response, node_id=unique_id)
|
||||
return (img_tensor,)
|
||||
|
||||
|
||||
@@ -623,7 +619,7 @@ class OpenAIChatNode(OpenAITextNode):
|
||||
|
||||
DESCRIPTION = "Generate text responses from an OpenAI model."
|
||||
|
||||
def get_result_response(
|
||||
async def get_result_response(
|
||||
self,
|
||||
response_id: str,
|
||||
include: Optional[list[Includable]] = None,
|
||||
@@ -639,7 +635,7 @@ class OpenAIChatNode(OpenAITextNode):
|
||||
creation above for more information.
|
||||
|
||||
"""
|
||||
return PollingOperation(
|
||||
return await PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path=f"{RESPONSES_ENDPOINT}/{response_id}",
|
||||
method=HttpMethod.GET,
|
||||
@@ -784,7 +780,7 @@ class OpenAIChatNode(OpenAITextNode):
|
||||
|
||||
self.history[session_id] = new_history
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
persist_context: bool,
|
||||
@@ -815,7 +811,7 @@ class OpenAIChatNode(OpenAITextNode):
|
||||
previous_response_id = None
|
||||
|
||||
# Create response
|
||||
create_response = SynchronousOperation(
|
||||
create_response = await SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=RESPONSES_ENDPOINT,
|
||||
method=HttpMethod.POST,
|
||||
@@ -848,7 +844,7 @@ class OpenAIChatNode(OpenAITextNode):
|
||||
response_id = create_response.id
|
||||
|
||||
# Get result output
|
||||
result_response = self.get_result_response(response_id, auth_kwargs=kwargs)
|
||||
result_response = await self.get_result_response(response_id, auth_kwargs=kwargs)
|
||||
output_text = self.parse_output_text_from_response(result_response)
|
||||
|
||||
# Update history
|
||||
@@ -1002,7 +998,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"OpenAIDalle2": "OpenAI DALL·E 2",
|
||||
"OpenAIDalle3": "OpenAI DALL·E 3",
|
||||
"OpenAIGPTImage1": "OpenAI GPT Image 1",
|
||||
"OpenAIChatNode": "OpenAI Chat",
|
||||
"OpenAIInputFiles": "OpenAI Chat Input Files",
|
||||
"OpenAIChatConfig": "OpenAI Chat Advanced Options",
|
||||
"OpenAIChatNode": "OpenAI ChatGPT",
|
||||
"OpenAIInputFiles": "OpenAI ChatGPT Input Files",
|
||||
"OpenAIChatConfig": "OpenAI ChatGPT Advanced Options",
|
||||
}
|
||||
|
||||
@@ -122,7 +122,7 @@ class PikaNodeBase(ComfyNodeABC):
|
||||
FUNCTION = "api_call"
|
||||
RETURN_TYPES = ("VIDEO",)
|
||||
|
||||
def poll_for_task_status(
|
||||
async def poll_for_task_status(
|
||||
self,
|
||||
task_id: str,
|
||||
auth_kwargs: Optional[dict[str, str]] = None,
|
||||
@@ -152,9 +152,9 @@ class PikaNodeBase(ComfyNodeABC):
|
||||
node_id=node_id,
|
||||
estimated_duration=60
|
||||
)
|
||||
return polling_operation.execute()
|
||||
return await polling_operation.execute()
|
||||
|
||||
def execute_task(
|
||||
async def execute_task(
|
||||
self,
|
||||
initial_operation: SynchronousOperation[R, PikaGenerateResponse],
|
||||
auth_kwargs: Optional[dict[str, str]] = None,
|
||||
@@ -169,14 +169,14 @@ class PikaNodeBase(ComfyNodeABC):
|
||||
Returns:
|
||||
A tuple containing the video file as a VIDEO output.
|
||||
"""
|
||||
initial_response = initial_operation.execute()
|
||||
initial_response = await initial_operation.execute()
|
||||
if not is_valid_initial_response(initial_response):
|
||||
error_msg = f"Pika initial request failed. Code: {initial_response.code}, Message: {initial_response.message}, Data: {initial_response.data}"
|
||||
logging.error(error_msg)
|
||||
raise PikaApiError(error_msg)
|
||||
|
||||
task_id = initial_response.video_id
|
||||
final_response = self.poll_for_task_status(task_id, auth_kwargs)
|
||||
final_response = await self.poll_for_task_status(task_id, auth_kwargs)
|
||||
if not is_valid_video_response(final_response):
|
||||
error_msg = (
|
||||
f"Pika task {task_id} succeeded but no video data found in response."
|
||||
@@ -187,7 +187,7 @@ class PikaNodeBase(ComfyNodeABC):
|
||||
video_url = str(final_response.url)
|
||||
logging.info("Pika task %s succeeded. Video URL: %s", task_id, video_url)
|
||||
|
||||
return (download_url_to_video_output(video_url),)
|
||||
return (await download_url_to_video_output(video_url),)
|
||||
|
||||
|
||||
class PikaImageToVideoV2_2(PikaNodeBase):
|
||||
@@ -212,7 +212,7 @@ class PikaImageToVideoV2_2(PikaNodeBase):
|
||||
|
||||
DESCRIPTION = "Sends an image and prompt to the Pika API v2.2 to generate a video."
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
prompt_text: str,
|
||||
@@ -251,7 +251,7 @@ class PikaImageToVideoV2_2(PikaNodeBase):
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
|
||||
return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
|
||||
|
||||
|
||||
class PikaTextToVideoNodeV2_2(PikaNodeBase):
|
||||
@@ -281,7 +281,7 @@ class PikaTextToVideoNodeV2_2(PikaNodeBase):
|
||||
|
||||
DESCRIPTION = "Sends a text prompt to the Pika API v2.2 to generate a video."
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
prompt_text: str,
|
||||
negative_prompt: str,
|
||||
@@ -311,7 +311,7 @@ class PikaTextToVideoNodeV2_2(PikaNodeBase):
|
||||
content_type="application/x-www-form-urlencoded",
|
||||
)
|
||||
|
||||
return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
|
||||
return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
|
||||
|
||||
|
||||
class PikaScenesV2_2(PikaNodeBase):
|
||||
@@ -361,7 +361,7 @@ class PikaScenesV2_2(PikaNodeBase):
|
||||
|
||||
DESCRIPTION = "Combine your images to create a video with the objects in them. Upload multiple images as ingredients and generate a high-quality video that incorporates all of them."
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
prompt_text: str,
|
||||
negative_prompt: str,
|
||||
@@ -420,7 +420,7 @@ class PikaScenesV2_2(PikaNodeBase):
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
|
||||
return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
|
||||
|
||||
|
||||
class PikAdditionsNode(PikaNodeBase):
|
||||
@@ -462,7 +462,7 @@ class PikAdditionsNode(PikaNodeBase):
|
||||
|
||||
DESCRIPTION = "Add any object or image into your video. Upload a video and specify what you'd like to add to create a seamlessly integrated result."
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
video: VideoInput,
|
||||
image: torch.Tensor,
|
||||
@@ -481,10 +481,10 @@ class PikAdditionsNode(PikaNodeBase):
|
||||
image_bytes_io = tensor_to_bytesio(image)
|
||||
image_bytes_io.seek(0)
|
||||
|
||||
pika_files = [
|
||||
("video", ("video.mp4", video_bytes_io, "video/mp4")),
|
||||
("image", ("image.png", image_bytes_io, "image/png")),
|
||||
]
|
||||
pika_files = {
|
||||
"video": ("video.mp4", video_bytes_io, "video/mp4"),
|
||||
"image": ("image.png", image_bytes_io, "image/png"),
|
||||
}
|
||||
|
||||
# Prepare non-file data
|
||||
pika_request_data = PikaBodyGeneratePikadditionsGeneratePikadditionsPost(
|
||||
@@ -506,7 +506,7 @@ class PikAdditionsNode(PikaNodeBase):
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
|
||||
return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
|
||||
|
||||
|
||||
class PikaSwapsNode(PikaNodeBase):
|
||||
@@ -558,7 +558,7 @@ class PikaSwapsNode(PikaNodeBase):
|
||||
DESCRIPTION = "Swap out any object or region of your video with a new image or object. Define areas to replace either with a mask or coordinates."
|
||||
RETURN_TYPES = ("VIDEO",)
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
video: VideoInput,
|
||||
image: torch.Tensor,
|
||||
@@ -587,11 +587,11 @@ class PikaSwapsNode(PikaNodeBase):
|
||||
image_bytes_io = tensor_to_bytesio(image)
|
||||
image_bytes_io.seek(0)
|
||||
|
||||
pika_files = [
|
||||
("video", ("video.mp4", video_bytes_io, "video/mp4")),
|
||||
("image", ("image.png", image_bytes_io, "image/png")),
|
||||
("modifyRegionMask", ("mask.png", mask_bytes_io, "image/png")),
|
||||
]
|
||||
pika_files = {
|
||||
"video": ("video.mp4", video_bytes_io, "video/mp4"),
|
||||
"image": ("image.png", image_bytes_io, "image/png"),
|
||||
"modifyRegionMask": ("mask.png", mask_bytes_io, "image/png"),
|
||||
}
|
||||
|
||||
# Prepare non-file data
|
||||
pika_request_data = PikaBodyGeneratePikaswapsGeneratePikaswapsPost(
|
||||
@@ -613,7 +613,7 @@ class PikaSwapsNode(PikaNodeBase):
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
|
||||
return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
|
||||
|
||||
|
||||
class PikaffectsNode(PikaNodeBase):
|
||||
@@ -664,7 +664,7 @@ class PikaffectsNode(PikaNodeBase):
|
||||
|
||||
DESCRIPTION = "Generate a video with a specific Pikaffect. Supported Pikaffects: Cake-ify, Crumble, Crush, Decapitate, Deflate, Dissolve, Explode, Eye-pop, Inflate, Levitate, Melt, Peel, Poke, Squish, Ta-da, Tear"
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
pikaffect: str,
|
||||
@@ -693,7 +693,7 @@ class PikaffectsNode(PikaNodeBase):
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
|
||||
return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
|
||||
|
||||
|
||||
class PikaStartEndFrameNode2_2(PikaNodeBase):
|
||||
@@ -718,7 +718,7 @@ class PikaStartEndFrameNode2_2(PikaNodeBase):
|
||||
|
||||
DESCRIPTION = "Generate a video by combining your first and last frame. Upload two images to define the start and end points, and let the AI create a smooth transition between them."
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
image_start: torch.Tensor,
|
||||
image_end: torch.Tensor,
|
||||
@@ -732,10 +732,7 @@ class PikaStartEndFrameNode2_2(PikaNodeBase):
|
||||
) -> tuple[VideoFromFile]:
|
||||
|
||||
pika_files = [
|
||||
(
|
||||
"keyFrames",
|
||||
("image_start.png", tensor_to_bytesio(image_start), "image/png"),
|
||||
),
|
||||
("keyFrames", ("image_start.png", tensor_to_bytesio(image_start), "image/png")),
|
||||
("keyFrames", ("image_end.png", tensor_to_bytesio(image_end), "image/png")),
|
||||
]
|
||||
|
||||
@@ -758,7 +755,7 @@ class PikaStartEndFrameNode2_2(PikaNodeBase):
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
|
||||
return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
|
||||
@@ -30,7 +30,7 @@ from comfy.comfy_types.node_typing import IO, ComfyNodeABC
|
||||
from comfy_api.input_impl import VideoFromFile
|
||||
|
||||
import torch
|
||||
import requests
|
||||
import aiohttp
|
||||
from io import BytesIO
|
||||
|
||||
|
||||
@@ -47,7 +47,7 @@ def get_video_url_from_response(
|
||||
return str(response.Resp.url)
|
||||
|
||||
|
||||
def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None):
|
||||
async def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None):
|
||||
# first, upload image to Pixverse and get image id to use in actual generation call
|
||||
files = {"image": tensor_to_bytesio(image)}
|
||||
operation = SynchronousOperation(
|
||||
@@ -62,7 +62,7 @@ def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None):
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
response_upload: PixverseImageUploadResponse = operation.execute()
|
||||
response_upload: PixverseImageUploadResponse = await operation.execute()
|
||||
|
||||
if response_upload.Resp is None:
|
||||
raise Exception(
|
||||
@@ -164,7 +164,7 @@ class PixverseTextToVideoNode(ComfyNodeABC):
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
aspect_ratio: str,
|
||||
@@ -205,7 +205,7 @@ class PixverseTextToVideoNode(ComfyNodeABC):
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_api = operation.execute()
|
||||
response_api = await operation.execute()
|
||||
|
||||
if response_api.Resp is None:
|
||||
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
|
||||
@@ -229,11 +229,11 @@ class PixverseTextToVideoNode(ComfyNodeABC):
|
||||
result_url_extractor=get_video_url_from_response,
|
||||
estimated_duration=AVERAGE_DURATION_T2V,
|
||||
)
|
||||
response_poll = operation.execute()
|
||||
response_poll = await operation.execute()
|
||||
|
||||
vid_response = requests.get(response_poll.Resp.url)
|
||||
|
||||
return (VideoFromFile(BytesIO(vid_response.content)),)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(response_poll.Resp.url) as vid_response:
|
||||
return (VideoFromFile(BytesIO(await vid_response.content.read())),)
|
||||
|
||||
|
||||
class PixverseImageToVideoNode(ComfyNodeABC):
|
||||
@@ -302,7 +302,7 @@ class PixverseImageToVideoNode(ComfyNodeABC):
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
prompt: str,
|
||||
@@ -316,7 +316,7 @@ class PixverseImageToVideoNode(ComfyNodeABC):
|
||||
**kwargs,
|
||||
):
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
img_id = upload_image_to_pixverse(image, auth_kwargs=kwargs)
|
||||
img_id = await upload_image_to_pixverse(image, auth_kwargs=kwargs)
|
||||
|
||||
# 1080p is limited to 5 seconds duration
|
||||
# only normal motion_mode supported for 1080p or for non-5 second duration
|
||||
@@ -345,7 +345,7 @@ class PixverseImageToVideoNode(ComfyNodeABC):
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_api = operation.execute()
|
||||
response_api = await operation.execute()
|
||||
|
||||
if response_api.Resp is None:
|
||||
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
|
||||
@@ -369,10 +369,11 @@ class PixverseImageToVideoNode(ComfyNodeABC):
|
||||
result_url_extractor=get_video_url_from_response,
|
||||
estimated_duration=AVERAGE_DURATION_I2V,
|
||||
)
|
||||
response_poll = operation.execute()
|
||||
response_poll = await operation.execute()
|
||||
|
||||
vid_response = requests.get(response_poll.Resp.url)
|
||||
return (VideoFromFile(BytesIO(vid_response.content)),)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(response_poll.Resp.url) as vid_response:
|
||||
return (VideoFromFile(BytesIO(await vid_response.content.read())),)
|
||||
|
||||
|
||||
class PixverseTransitionVideoNode(ComfyNodeABC):
|
||||
@@ -436,7 +437,7 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
first_frame: torch.Tensor,
|
||||
last_frame: torch.Tensor,
|
||||
@@ -450,8 +451,8 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
|
||||
**kwargs,
|
||||
):
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
first_frame_id = upload_image_to_pixverse(first_frame, auth_kwargs=kwargs)
|
||||
last_frame_id = upload_image_to_pixverse(last_frame, auth_kwargs=kwargs)
|
||||
first_frame_id = await upload_image_to_pixverse(first_frame, auth_kwargs=kwargs)
|
||||
last_frame_id = await upload_image_to_pixverse(last_frame, auth_kwargs=kwargs)
|
||||
|
||||
# 1080p is limited to 5 seconds duration
|
||||
# only normal motion_mode supported for 1080p or for non-5 second duration
|
||||
@@ -480,7 +481,7 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_api = operation.execute()
|
||||
response_api = await operation.execute()
|
||||
|
||||
if response_api.Resp is None:
|
||||
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
|
||||
@@ -504,10 +505,11 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
|
||||
result_url_extractor=get_video_url_from_response,
|
||||
estimated_duration=AVERAGE_DURATION_T2V,
|
||||
)
|
||||
response_poll = operation.execute()
|
||||
response_poll = await operation.execute()
|
||||
|
||||
vid_response = requests.get(response_poll.Resp.url)
|
||||
return (VideoFromFile(BytesIO(vid_response.content)),)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(response_poll.Resp.url) as vid_response:
|
||||
return (VideoFromFile(BytesIO(await vid_response.content.read())),)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
|
||||
@@ -37,7 +37,7 @@ from io import BytesIO
|
||||
from PIL import UnidentifiedImageError
|
||||
|
||||
|
||||
def handle_recraft_file_request(
|
||||
async def handle_recraft_file_request(
|
||||
image: torch.Tensor,
|
||||
path: str,
|
||||
mask: torch.Tensor=None,
|
||||
@@ -71,13 +71,13 @@ def handle_recraft_file_request(
|
||||
auth_kwargs=auth_kwargs,
|
||||
multipart_parser=recraft_multipart_parser,
|
||||
)
|
||||
response: RecraftImageGenerationResponse = operation.execute()
|
||||
response: RecraftImageGenerationResponse = await operation.execute()
|
||||
all_bytesio = []
|
||||
if response.image is not None:
|
||||
all_bytesio.append(download_url_to_bytesio(response.image.url, timeout=timeout))
|
||||
all_bytesio.append(await download_url_to_bytesio(response.image.url, timeout=timeout))
|
||||
else:
|
||||
for data in response.data:
|
||||
all_bytesio.append(download_url_to_bytesio(data.url, timeout=timeout))
|
||||
all_bytesio.append(await download_url_to_bytesio(data.url, timeout=timeout))
|
||||
|
||||
return all_bytesio
|
||||
|
||||
@@ -395,7 +395,7 @@ class RecraftTextToImageNode:
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
size: str,
|
||||
@@ -439,7 +439,7 @@ class RecraftTextToImageNode:
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response: RecraftImageGenerationResponse = operation.execute()
|
||||
response: RecraftImageGenerationResponse = await operation.execute()
|
||||
images = []
|
||||
urls = []
|
||||
for data in response.data:
|
||||
@@ -451,7 +451,7 @@ class RecraftTextToImageNode:
|
||||
f"Result URL: {urls_string}", unique_id
|
||||
)
|
||||
image = bytesio_to_image_tensor(
|
||||
download_url_to_bytesio(data.url, timeout=1024)
|
||||
await download_url_to_bytesio(data.url, timeout=1024)
|
||||
)
|
||||
if len(image.shape) < 4:
|
||||
image = image.unsqueeze(0)
|
||||
@@ -538,7 +538,7 @@ class RecraftImageToImageNode:
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
prompt: str,
|
||||
@@ -578,7 +578,7 @@ class RecraftImageToImageNode:
|
||||
total = image.shape[0]
|
||||
pbar = ProgressBar(total)
|
||||
for i in range(total):
|
||||
sub_bytes = handle_recraft_file_request(
|
||||
sub_bytes = await handle_recraft_file_request(
|
||||
image=image[i],
|
||||
path="/proxy/recraft/images/imageToImage",
|
||||
request=request,
|
||||
@@ -654,7 +654,7 @@ class RecraftImageInpaintingNode:
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
mask: torch.Tensor,
|
||||
@@ -690,7 +690,7 @@ class RecraftImageInpaintingNode:
|
||||
total = image.shape[0]
|
||||
pbar = ProgressBar(total)
|
||||
for i in range(total):
|
||||
sub_bytes = handle_recraft_file_request(
|
||||
sub_bytes = await handle_recraft_file_request(
|
||||
image=image[i],
|
||||
mask=mask[i:i+1],
|
||||
path="/proxy/recraft/images/inpaint",
|
||||
@@ -779,7 +779,7 @@ class RecraftTextToVectorNode:
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
substyle: str,
|
||||
@@ -821,7 +821,7 @@ class RecraftTextToVectorNode:
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response: RecraftImageGenerationResponse = operation.execute()
|
||||
response: RecraftImageGenerationResponse = await operation.execute()
|
||||
svg_data = []
|
||||
urls = []
|
||||
for data in response.data:
|
||||
@@ -831,7 +831,7 @@ class RecraftTextToVectorNode:
|
||||
PromptServer.instance.send_progress_text(
|
||||
f"Result URL: {' '.join(urls)}", unique_id
|
||||
)
|
||||
svg_data.append(download_url_to_bytesio(data.url, timeout=1024))
|
||||
svg_data.append(await download_url_to_bytesio(data.url, timeout=1024))
|
||||
|
||||
return (SVG(svg_data),)
|
||||
|
||||
@@ -861,7 +861,7 @@ class RecraftVectorizeImageNode:
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
**kwargs,
|
||||
@@ -870,7 +870,7 @@ class RecraftVectorizeImageNode:
|
||||
total = image.shape[0]
|
||||
pbar = ProgressBar(total)
|
||||
for i in range(total):
|
||||
sub_bytes = handle_recraft_file_request(
|
||||
sub_bytes = await handle_recraft_file_request(
|
||||
image=image[i],
|
||||
path="/proxy/recraft/images/vectorize",
|
||||
auth_kwargs=kwargs,
|
||||
@@ -942,7 +942,7 @@ class RecraftReplaceBackgroundNode:
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
prompt: str,
|
||||
@@ -973,7 +973,7 @@ class RecraftReplaceBackgroundNode:
|
||||
total = image.shape[0]
|
||||
pbar = ProgressBar(total)
|
||||
for i in range(total):
|
||||
sub_bytes = handle_recraft_file_request(
|
||||
sub_bytes = await handle_recraft_file_request(
|
||||
image=image[i],
|
||||
path="/proxy/recraft/images/replaceBackground",
|
||||
request=request,
|
||||
@@ -1011,7 +1011,7 @@ class RecraftRemoveBackgroundNode:
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
**kwargs,
|
||||
@@ -1020,7 +1020,7 @@ class RecraftRemoveBackgroundNode:
|
||||
total = image.shape[0]
|
||||
pbar = ProgressBar(total)
|
||||
for i in range(total):
|
||||
sub_bytes = handle_recraft_file_request(
|
||||
sub_bytes = await handle_recraft_file_request(
|
||||
image=image[i],
|
||||
path="/proxy/recraft/images/removeBackground",
|
||||
auth_kwargs=kwargs,
|
||||
@@ -1062,7 +1062,7 @@ class RecraftCrispUpscaleNode:
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
**kwargs,
|
||||
@@ -1071,7 +1071,7 @@ class RecraftCrispUpscaleNode:
|
||||
total = image.shape[0]
|
||||
pbar = ProgressBar(total)
|
||||
for i in range(total):
|
||||
sub_bytes = handle_recraft_file_request(
|
||||
sub_bytes = await handle_recraft_file_request(
|
||||
image=image[i],
|
||||
path=self.RECRAFT_PATH,
|
||||
auth_kwargs=kwargs,
|
||||
|
||||
@@ -9,11 +9,10 @@ from __future__ import annotations
|
||||
from inspect import cleandoc
|
||||
from comfy.comfy_types.node_typing import IO
|
||||
import folder_paths as comfy_paths
|
||||
import requests
|
||||
import aiohttp
|
||||
import os
|
||||
import datetime
|
||||
import shutil
|
||||
import time
|
||||
import asyncio
|
||||
import io
|
||||
import logging
|
||||
import math
|
||||
@@ -66,7 +65,6 @@ def create_task_error(response: Rodin3DGenerateResponse):
|
||||
return hasattr(response, "error")
|
||||
|
||||
|
||||
|
||||
class Rodin3DAPI:
|
||||
"""
|
||||
Generate 3D Assets using Rodin API
|
||||
@@ -123,8 +121,8 @@ class Rodin3DAPI:
|
||||
else:
|
||||
return "Generating"
|
||||
|
||||
def CreateGenerateTask(self, images=None, seed=1, material="PBR", quality="medium", tier="Regular", mesh_mode="Quad", **kwargs):
|
||||
if images == None:
|
||||
async def create_generate_task(self, images=None, seed=1, material="PBR", quality="medium", tier="Regular", mesh_mode="Quad", **kwargs):
|
||||
if images is None:
|
||||
raise Exception("Rodin 3D generate requires at least 1 image.")
|
||||
if len(images) >= 5:
|
||||
raise Exception("Rodin 3D generate requires up to 5 image.")
|
||||
@@ -155,7 +153,7 @@ class Rodin3DAPI:
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
response = operation.execute()
|
||||
response = await operation.execute()
|
||||
|
||||
if create_task_error(response):
|
||||
error_message = f"Rodin3D Create 3D generate Task Failed. Message: {response.message}, error: {response.error}"
|
||||
@@ -168,7 +166,7 @@ class Rodin3DAPI:
|
||||
logging.info(f"[ Rodin3D API - Submit Jobs ] UUID: {task_uuid}")
|
||||
return task_uuid, subscription_key
|
||||
|
||||
def poll_for_task_status(self, subscription_key, **kwargs) -> Rodin3DCheckStatusResponse:
|
||||
async def poll_for_task_status(self, subscription_key, **kwargs) -> Rodin3DCheckStatusResponse:
|
||||
|
||||
path = "/proxy/rodin/api/v2/status"
|
||||
|
||||
@@ -191,11 +189,9 @@ class Rodin3DAPI:
|
||||
|
||||
logging.info("[ Rodin3D API - CheckStatus ] Generate Start!")
|
||||
|
||||
return poll_operation.execute()
|
||||
return await poll_operation.execute()
|
||||
|
||||
|
||||
|
||||
def GetRodinDownloadList(self, uuid, **kwargs) -> Rodin3DDownloadResponse:
|
||||
async def get_rodin_download_list(self, uuid, **kwargs) -> Rodin3DDownloadResponse:
|
||||
logging.info("[ Rodin3D API - Downloading ] Generate Successfully!")
|
||||
|
||||
path = "/proxy/rodin/api/v2/download"
|
||||
@@ -212,53 +208,59 @@ class Rodin3DAPI:
|
||||
auth_kwargs=kwargs
|
||||
)
|
||||
|
||||
return operation.execute()
|
||||
return await operation.execute()
|
||||
|
||||
def GetQualityAndMode(self, PolyCount):
|
||||
if PolyCount == "200K-Triangle":
|
||||
def get_quality_mode(self, poly_count):
|
||||
if poly_count == "200K-Triangle":
|
||||
mesh_mode = "Raw"
|
||||
quality = "medium"
|
||||
else:
|
||||
mesh_mode = "Quad"
|
||||
if PolyCount == "4K-Quad":
|
||||
if poly_count == "4K-Quad":
|
||||
quality = "extra-low"
|
||||
elif PolyCount == "8K-Quad":
|
||||
elif poly_count == "8K-Quad":
|
||||
quality = "low"
|
||||
elif PolyCount == "18K-Quad":
|
||||
elif poly_count == "18K-Quad":
|
||||
quality = "medium"
|
||||
elif PolyCount == "50K-Quad":
|
||||
elif poly_count == "50K-Quad":
|
||||
quality = "high"
|
||||
else:
|
||||
quality = "medium"
|
||||
|
||||
return mesh_mode, quality
|
||||
|
||||
def DownLoadFiles(self, Url_List):
|
||||
Save_path = os.path.join(comfy_paths.get_output_directory(), "Rodin3D", datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
|
||||
os.makedirs(Save_path, exist_ok=True)
|
||||
async def download_files(self, url_list):
|
||||
save_path = os.path.join(comfy_paths.get_output_directory(), "Rodin3D", datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
|
||||
os.makedirs(save_path, exist_ok=True)
|
||||
model_file_path = None
|
||||
for Item in Url_List.list:
|
||||
url = Item.url
|
||||
file_name = Item.name
|
||||
file_path = os.path.join(Save_path, file_name)
|
||||
if file_path.endswith(".glb"):
|
||||
model_file_path = file_path
|
||||
logging.info(f"[ Rodin3D API - download_files ] Downloading file: {file_path}")
|
||||
max_retries = 5
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
with requests.get(url, stream=True) as r:
|
||||
r.raise_for_status()
|
||||
with open(file_path, "wb") as f:
|
||||
shutil.copyfileobj(r.raw, f)
|
||||
break
|
||||
except Exception as e:
|
||||
logging.info(f"[ Rodin3D API - download_files ] Error downloading {file_path}:{e}")
|
||||
if attempt < max_retries - 1:
|
||||
logging.info("Retrying...")
|
||||
time.sleep(2)
|
||||
else:
|
||||
logging.info(f"[ Rodin3D API - download_files ] Failed to download {file_path} after {max_retries} attempts.")
|
||||
async with aiohttp.ClientSession() as session:
|
||||
for i in url_list.list:
|
||||
url = i.url
|
||||
file_name = i.name
|
||||
file_path = os.path.join(save_path, file_name)
|
||||
if file_path.endswith(".glb"):
|
||||
model_file_path = file_path
|
||||
logging.info(f"[ Rodin3D API - download_files ] Downloading file: {file_path}")
|
||||
max_retries = 5
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
async with session.get(url) as resp:
|
||||
resp.raise_for_status()
|
||||
with open(file_path, "wb") as f:
|
||||
async for chunk in resp.content.iter_chunked(32 * 1024):
|
||||
f.write(chunk)
|
||||
break
|
||||
except Exception as e:
|
||||
logging.info(f"[ Rodin3D API - download_files ] Error downloading {file_path}:{e}")
|
||||
if attempt < max_retries - 1:
|
||||
logging.info("Retrying...")
|
||||
await asyncio.sleep(2)
|
||||
else:
|
||||
logging.info(
|
||||
"[ Rodin3D API - download_files ] Failed to download %s after %s attempts.",
|
||||
file_path,
|
||||
max_retries,
|
||||
)
|
||||
|
||||
return model_file_path
|
||||
|
||||
@@ -285,7 +287,7 @@ class Rodin3D_Regular(Rodin3DAPI):
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
Images,
|
||||
Seed,
|
||||
@@ -298,14 +300,17 @@ class Rodin3D_Regular(Rodin3DAPI):
|
||||
m_images = []
|
||||
for i in range(num_images):
|
||||
m_images.append(Images[i])
|
||||
mesh_mode, quality = self.GetQualityAndMode(Polygon_count)
|
||||
task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=Material_Type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs)
|
||||
self.poll_for_task_status(subscription_key, **kwargs)
|
||||
Download_List = self.GetRodinDownloadList(task_uuid, **kwargs)
|
||||
model = self.DownLoadFiles(Download_List)
|
||||
mesh_mode, quality = self.get_quality_mode(Polygon_count)
|
||||
task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type,
|
||||
quality=quality, tier=tier, mesh_mode=mesh_mode,
|
||||
**kwargs)
|
||||
await self.poll_for_task_status(subscription_key, **kwargs)
|
||||
download_list = await self.get_rodin_download_list(task_uuid, **kwargs)
|
||||
model = await self.download_files(download_list)
|
||||
|
||||
return (model,)
|
||||
|
||||
|
||||
class Rodin3D_Detail(Rodin3DAPI):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@@ -328,7 +333,7 @@ class Rodin3D_Detail(Rodin3DAPI):
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
Images,
|
||||
Seed,
|
||||
@@ -341,14 +346,17 @@ class Rodin3D_Detail(Rodin3DAPI):
|
||||
m_images = []
|
||||
for i in range(num_images):
|
||||
m_images.append(Images[i])
|
||||
mesh_mode, quality = self.GetQualityAndMode(Polygon_count)
|
||||
task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=Material_Type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs)
|
||||
self.poll_for_task_status(subscription_key, **kwargs)
|
||||
Download_List = self.GetRodinDownloadList(task_uuid, **kwargs)
|
||||
model = self.DownLoadFiles(Download_List)
|
||||
mesh_mode, quality = self.get_quality_mode(Polygon_count)
|
||||
task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type,
|
||||
quality=quality, tier=tier, mesh_mode=mesh_mode,
|
||||
**kwargs)
|
||||
await self.poll_for_task_status(subscription_key, **kwargs)
|
||||
download_list = await self.get_rodin_download_list(task_uuid, **kwargs)
|
||||
model = await self.download_files(download_list)
|
||||
|
||||
return (model,)
|
||||
|
||||
|
||||
class Rodin3D_Smooth(Rodin3DAPI):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@@ -371,7 +379,7 @@ class Rodin3D_Smooth(Rodin3DAPI):
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
Images,
|
||||
Seed,
|
||||
@@ -384,14 +392,17 @@ class Rodin3D_Smooth(Rodin3DAPI):
|
||||
m_images = []
|
||||
for i in range(num_images):
|
||||
m_images.append(Images[i])
|
||||
mesh_mode, quality = self.GetQualityAndMode(Polygon_count)
|
||||
task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=Material_Type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs)
|
||||
self.poll_for_task_status(subscription_key, **kwargs)
|
||||
Download_List = self.GetRodinDownloadList(task_uuid, **kwargs)
|
||||
model = self.DownLoadFiles(Download_List)
|
||||
mesh_mode, quality = self.get_quality_mode(Polygon_count)
|
||||
task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type,
|
||||
quality=quality, tier=tier, mesh_mode=mesh_mode,
|
||||
**kwargs)
|
||||
await self.poll_for_task_status(subscription_key, **kwargs)
|
||||
download_list = await self.get_rodin_download_list(task_uuid, **kwargs)
|
||||
model = await self.download_files(download_list)
|
||||
|
||||
return (model,)
|
||||
|
||||
|
||||
class Rodin3D_Sketch(Rodin3DAPI):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@@ -423,7 +434,7 @@ class Rodin3D_Sketch(Rodin3DAPI):
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
Images,
|
||||
Seed,
|
||||
@@ -437,10 +448,12 @@ class Rodin3D_Sketch(Rodin3DAPI):
|
||||
material_type = "PBR"
|
||||
quality = "medium"
|
||||
mesh_mode = "Quad"
|
||||
task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=material_type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs)
|
||||
self.poll_for_task_status(subscription_key, **kwargs)
|
||||
Download_List = self.GetRodinDownloadList(task_uuid, **kwargs)
|
||||
model = self.DownLoadFiles(Download_List)
|
||||
task_uuid, subscription_key = await self.create_generate_task(
|
||||
images=m_images, seed=Seed, material=material_type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs
|
||||
)
|
||||
await self.poll_for_task_status(subscription_key, **kwargs)
|
||||
download_list = await self.get_rodin_download_list(task_uuid, **kwargs)
|
||||
model = await self.download_files(download_list)
|
||||
|
||||
return (model,)
|
||||
|
||||
|
||||
+364
-394
@@ -12,6 +12,7 @@ User Guides:
|
||||
"""
|
||||
|
||||
from typing import Union, Optional, Any
|
||||
from typing_extensions import override
|
||||
from enum import Enum
|
||||
|
||||
import torch
|
||||
@@ -46,9 +47,9 @@ from comfy_api_nodes.apinode_utils import (
|
||||
validate_string,
|
||||
download_url_to_image_tensor,
|
||||
)
|
||||
from comfy_api_nodes.mapper_utils import model_field_to_node_input
|
||||
from comfy_api.input_impl import VideoFromFile
|
||||
from comfy.comfy_types.node_typing import IO, ComfyNodeABC
|
||||
from comfy_api.latest import ComfyExtension, io as comfy_io
|
||||
from comfy_api_nodes.util.validation_utils import validate_image_dimensions, validate_image_aspect_ratio
|
||||
|
||||
PATH_IMAGE_TO_VIDEO = "/proxy/runway/image_to_video"
|
||||
PATH_TEXT_TO_IMAGE = "/proxy/runway/text_to_image"
|
||||
@@ -85,28 +86,19 @@ class RunwayGen3aAspectRatio(str, Enum):
|
||||
|
||||
def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]:
|
||||
"""Returns the video URL from the task status response if it exists."""
|
||||
if response.output and len(response.output) > 0:
|
||||
if hasattr(response, "output") and len(response.output) > 0:
|
||||
return response.output[0]
|
||||
return None
|
||||
|
||||
|
||||
# TODO: replace with updated image validation utils (upstream)
|
||||
def validate_input_image(image: torch.Tensor) -> bool:
|
||||
"""
|
||||
Validate the input image is within the size limits for the Runway API.
|
||||
See: https://docs.dev.runwayml.com/assets/inputs/#common-error-reasons
|
||||
"""
|
||||
return image.shape[2] < 8000 and image.shape[1] < 8000
|
||||
|
||||
|
||||
def poll_until_finished(
|
||||
async def poll_until_finished(
|
||||
auth_kwargs: dict[str, str],
|
||||
api_endpoint: ApiEndpoint[Any, TaskStatusResponse],
|
||||
estimated_duration: Optional[int] = None,
|
||||
node_id: Optional[str] = None,
|
||||
) -> TaskStatusResponse:
|
||||
"""Polls the Runway API endpoint until the task reaches a terminal state, then returns the response."""
|
||||
return PollingOperation(
|
||||
return await PollingOperation(
|
||||
poll_endpoint=api_endpoint,
|
||||
completed_statuses=[
|
||||
TaskStatus.SUCCEEDED.value,
|
||||
@@ -115,7 +107,7 @@ def poll_until_finished(
|
||||
TaskStatus.FAILED.value,
|
||||
TaskStatus.CANCELLED.value,
|
||||
],
|
||||
status_extractor=lambda response: (response.status.value),
|
||||
status_extractor=lambda response: response.status.value,
|
||||
auth_kwargs=auth_kwargs,
|
||||
result_url_extractor=get_video_url_from_task_status,
|
||||
estimated_duration=estimated_duration,
|
||||
@@ -134,458 +126,438 @@ def extract_progress_from_task_status(
|
||||
|
||||
def get_image_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]:
|
||||
"""Returns the image URL from the task status response if it exists."""
|
||||
if response.output and len(response.output) > 0:
|
||||
if hasattr(response, "output") and len(response.output) > 0:
|
||||
return response.output[0]
|
||||
return None
|
||||
|
||||
|
||||
class RunwayVideoGenNode(ComfyNodeABC):
|
||||
"""Runway Video Node Base."""
|
||||
async def get_response(
|
||||
task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None, estimated_duration: Optional[int] = None
|
||||
) -> TaskStatusResponse:
|
||||
"""Poll the task status until it is finished then get the response."""
|
||||
return await poll_until_finished(
|
||||
auth_kwargs,
|
||||
ApiEndpoint(
|
||||
path=f"{PATH_GET_TASK_STATUS}/{task_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=TaskStatusResponse,
|
||||
),
|
||||
estimated_duration=estimated_duration,
|
||||
node_id=node_id,
|
||||
)
|
||||
|
||||
RETURN_TYPES = ("VIDEO",)
|
||||
FUNCTION = "api_call"
|
||||
CATEGORY = "api node/video/Runway"
|
||||
API_NODE = True
|
||||
|
||||
def validate_task_created(self, response: RunwayImageToVideoResponse) -> bool:
|
||||
"""
|
||||
Validate the task creation response from the Runway API matches
|
||||
expected format.
|
||||
"""
|
||||
if not bool(response.id):
|
||||
raise RunwayApiError("Invalid initial response from Runway API.")
|
||||
return True
|
||||
async def generate_video(
|
||||
request: RunwayImageToVideoRequest,
|
||||
auth_kwargs: dict[str, str],
|
||||
node_id: Optional[str] = None,
|
||||
estimated_duration: Optional[int] = None,
|
||||
) -> VideoFromFile:
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_IMAGE_TO_VIDEO,
|
||||
method=HttpMethod.POST,
|
||||
request_model=RunwayImageToVideoRequest,
|
||||
response_model=RunwayImageToVideoResponse,
|
||||
),
|
||||
request=request,
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
|
||||
def validate_response(self, response: RunwayImageToVideoResponse) -> bool:
|
||||
"""
|
||||
Validate the successful task status response from the Runway API
|
||||
matches expected format.
|
||||
"""
|
||||
if not response.output or len(response.output) == 0:
|
||||
raise RunwayApiError(
|
||||
"Runway task succeeded but no video data found in response."
|
||||
)
|
||||
return True
|
||||
initial_response = await initial_operation.execute()
|
||||
|
||||
def get_response(
|
||||
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
|
||||
) -> RunwayImageToVideoResponse:
|
||||
"""Poll the task status until it is finished then get the response."""
|
||||
return poll_until_finished(
|
||||
auth_kwargs,
|
||||
ApiEndpoint(
|
||||
path=f"{PATH_GET_TASK_STATUS}/{task_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=TaskStatusResponse,
|
||||
),
|
||||
estimated_duration=AVERAGE_DURATION_FLF_SECONDS,
|
||||
node_id=node_id,
|
||||
final_response = await get_response(initial_response.id, auth_kwargs, node_id, estimated_duration)
|
||||
if not final_response.output:
|
||||
raise RunwayApiError("Runway task succeeded but no video data found in response.")
|
||||
|
||||
video_url = get_video_url_from_task_status(final_response)
|
||||
return await download_url_to_video_output(video_url)
|
||||
|
||||
|
||||
class RunwayImageToVideoNodeGen3a(comfy_io.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
node_id="RunwayImageToVideoNodeGen3a",
|
||||
display_name="Runway Image to Video (Gen3a Turbo)",
|
||||
category="api node/video/Runway",
|
||||
description="Generate a video from a single starting frame using Gen3a Turbo model. "
|
||||
"Before diving in, review these best practices to ensure that "
|
||||
"your input selections will set your generation up for success: "
|
||||
"https://help.runwayml.com/hc/en-us/articles/33927968552339-Creating-with-Act-One-on-Gen-3-Alpha-and-Turbo.",
|
||||
inputs=[
|
||||
comfy_io.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Text prompt for the generation",
|
||||
),
|
||||
comfy_io.Image.Input(
|
||||
"start_frame",
|
||||
tooltip="Start frame to be used for the video",
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
"duration",
|
||||
options=[model.value for model in Duration],
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
"ratio",
|
||||
options=[model.value for model in RunwayGen3aAspectRatio],
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=4294967295,
|
||||
step=1,
|
||||
control_after_generate=True,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
tooltip="Random seed for generation",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
def generate_video(
|
||||
self,
|
||||
request: RunwayImageToVideoRequest,
|
||||
auth_kwargs: dict[str, str],
|
||||
node_id: Optional[str] = None,
|
||||
) -> tuple[VideoFromFile]:
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_IMAGE_TO_VIDEO,
|
||||
method=HttpMethod.POST,
|
||||
request_model=RunwayImageToVideoRequest,
|
||||
response_model=RunwayImageToVideoResponse,
|
||||
),
|
||||
request=request,
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
start_frame: torch.Tensor,
|
||||
duration: str,
|
||||
ratio: str,
|
||||
seed: int,
|
||||
) -> comfy_io.NodeOutput:
|
||||
validate_string(prompt, min_length=1)
|
||||
validate_image_dimensions(start_frame, max_width=7999, max_height=7999)
|
||||
validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
|
||||
|
||||
auth_kwargs = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
|
||||
download_urls = await upload_images_to_comfyapi(
|
||||
start_frame,
|
||||
max_images=1,
|
||||
mime_type="image/png",
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
|
||||
initial_response = initial_operation.execute()
|
||||
self.validate_task_created(initial_response)
|
||||
task_id = initial_response.id
|
||||
|
||||
final_response = self.get_response(task_id, auth_kwargs, node_id)
|
||||
self.validate_response(final_response)
|
||||
|
||||
video_url = get_video_url_from_task_status(final_response)
|
||||
return (download_url_to_video_output(video_url),)
|
||||
return comfy_io.NodeOutput(
|
||||
await generate_video(
|
||||
RunwayImageToVideoRequest(
|
||||
promptText=prompt,
|
||||
seed=seed,
|
||||
model=Model("gen3a_turbo"),
|
||||
duration=Duration(duration),
|
||||
ratio=AspectRatio(ratio),
|
||||
promptImage=RunwayPromptImageObject(
|
||||
root=[
|
||||
RunwayPromptImageDetailedObject(
|
||||
uri=str(download_urls[0]), position="first"
|
||||
)
|
||||
]
|
||||
),
|
||||
),
|
||||
auth_kwargs=auth_kwargs,
|
||||
node_id=cls.hidden.unique_id,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class RunwayImageToVideoNodeGen3a(RunwayVideoGenNode):
|
||||
"""Runway Image to Video Node using Gen3a Turbo model."""
|
||||
|
||||
DESCRIPTION = "Generate a video from a single starting frame using Gen3a Turbo model. Before diving in, review these best practices to ensure that your input selections will set your generation up for success: https://help.runwayml.com/hc/en-us/articles/33927968552339-Creating-with-Act-One-on-Gen-3-Alpha-and-Turbo."
|
||||
class RunwayImageToVideoNodeGen4(comfy_io.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"prompt": model_field_to_node_input(
|
||||
IO.STRING, RunwayImageToVideoRequest, "promptText", multiline=True
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
node_id="RunwayImageToVideoNodeGen4",
|
||||
display_name="Runway Image to Video (Gen4 Turbo)",
|
||||
category="api node/video/Runway",
|
||||
description="Generate a video from a single starting frame using Gen4 Turbo model. "
|
||||
"Before diving in, review these best practices to ensure that "
|
||||
"your input selections will set your generation up for success: "
|
||||
"https://help.runwayml.com/hc/en-us/articles/37327109429011-Creating-with-Gen-4-Video.",
|
||||
inputs=[
|
||||
comfy_io.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Text prompt for the generation",
|
||||
),
|
||||
"start_frame": (
|
||||
IO.IMAGE,
|
||||
{"tooltip": "Start frame to be used for the video"},
|
||||
comfy_io.Image.Input(
|
||||
"start_frame",
|
||||
tooltip="Start frame to be used for the video",
|
||||
),
|
||||
"duration": model_field_to_node_input(
|
||||
IO.COMBO, RunwayImageToVideoRequest, "duration", enum_type=Duration
|
||||
comfy_io.Combo.Input(
|
||||
"duration",
|
||||
options=[model.value for model in Duration],
|
||||
),
|
||||
"ratio": model_field_to_node_input(
|
||||
IO.COMBO,
|
||||
RunwayImageToVideoRequest,
|
||||
comfy_io.Combo.Input(
|
||||
"ratio",
|
||||
enum_type=RunwayGen3aAspectRatio,
|
||||
options=[model.value for model in RunwayGen4TurboAspectRatio],
|
||||
),
|
||||
"seed": model_field_to_node_input(
|
||||
IO.INT,
|
||||
RunwayImageToVideoRequest,
|
||||
comfy_io.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=4294967295,
|
||||
step=1,
|
||||
control_after_generate=True,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
tooltip="Random seed for generation",
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
start_frame: torch.Tensor,
|
||||
duration: str,
|
||||
ratio: str,
|
||||
seed: int,
|
||||
unique_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> tuple[VideoFromFile]:
|
||||
# Validate inputs
|
||||
) -> comfy_io.NodeOutput:
|
||||
validate_string(prompt, min_length=1)
|
||||
validate_input_image(start_frame)
|
||||
validate_image_dimensions(start_frame, max_width=7999, max_height=7999)
|
||||
validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
|
||||
|
||||
# Upload image
|
||||
download_urls = upload_images_to_comfyapi(
|
||||
auth_kwargs = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
|
||||
download_urls = await upload_images_to_comfyapi(
|
||||
start_frame,
|
||||
max_images=1,
|
||||
mime_type="image/png",
|
||||
auth_kwargs=kwargs,
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
if len(download_urls) != 1:
|
||||
raise RunwayApiError("Failed to upload one or more images to comfy api.")
|
||||
|
||||
return self.generate_video(
|
||||
RunwayImageToVideoRequest(
|
||||
promptText=prompt,
|
||||
seed=seed,
|
||||
model=Model("gen3a_turbo"),
|
||||
duration=Duration(duration),
|
||||
ratio=AspectRatio(ratio),
|
||||
promptImage=RunwayPromptImageObject(
|
||||
root=[
|
||||
RunwayPromptImageDetailedObject(
|
||||
uri=str(download_urls[0]), position="first"
|
||||
)
|
||||
]
|
||||
return comfy_io.NodeOutput(
|
||||
await generate_video(
|
||||
RunwayImageToVideoRequest(
|
||||
promptText=prompt,
|
||||
seed=seed,
|
||||
model=Model("gen4_turbo"),
|
||||
duration=Duration(duration),
|
||||
ratio=AspectRatio(ratio),
|
||||
promptImage=RunwayPromptImageObject(
|
||||
root=[
|
||||
RunwayPromptImageDetailedObject(
|
||||
uri=str(download_urls[0]), position="first"
|
||||
)
|
||||
]
|
||||
),
|
||||
),
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
node_id=unique_id,
|
||||
auth_kwargs=auth_kwargs,
|
||||
node_id=cls.hidden.unique_id,
|
||||
estimated_duration=AVERAGE_DURATION_FLF_SECONDS,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class RunwayImageToVideoNodeGen4(RunwayVideoGenNode):
|
||||
"""Runway Image to Video Node using Gen4 Turbo model."""
|
||||
|
||||
DESCRIPTION = "Generate a video from a single starting frame using Gen4 Turbo model. Before diving in, review these best practices to ensure that your input selections will set your generation up for success: https://help.runwayml.com/hc/en-us/articles/37327109429011-Creating-with-Gen-4-Video."
|
||||
class RunwayFirstLastFrameNode(comfy_io.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"prompt": model_field_to_node_input(
|
||||
IO.STRING, RunwayImageToVideoRequest, "promptText", multiline=True
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
node_id="RunwayFirstLastFrameNode",
|
||||
display_name="Runway First-Last-Frame to Video",
|
||||
category="api node/video/Runway",
|
||||
description="Upload first and last keyframes, draft a prompt, and generate a video. "
|
||||
"More complex transitions, such as cases where the Last frame is completely different "
|
||||
"from the First frame, may benefit from the longer 10s duration. "
|
||||
"This would give the generation more time to smoothly transition between the two inputs. "
|
||||
"Before diving in, review these best practices to ensure that your input selections "
|
||||
"will set your generation up for success: "
|
||||
"https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3.",
|
||||
inputs=[
|
||||
comfy_io.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Text prompt for the generation",
|
||||
),
|
||||
"start_frame": (
|
||||
IO.IMAGE,
|
||||
{"tooltip": "Start frame to be used for the video"},
|
||||
comfy_io.Image.Input(
|
||||
"start_frame",
|
||||
tooltip="Start frame to be used for the video",
|
||||
),
|
||||
"duration": model_field_to_node_input(
|
||||
IO.COMBO, RunwayImageToVideoRequest, "duration", enum_type=Duration
|
||||
comfy_io.Image.Input(
|
||||
"end_frame",
|
||||
tooltip="End frame to be used for the video. Supported for gen3a_turbo only.",
|
||||
),
|
||||
"ratio": model_field_to_node_input(
|
||||
IO.COMBO,
|
||||
RunwayImageToVideoRequest,
|
||||
comfy_io.Combo.Input(
|
||||
"duration",
|
||||
options=[model.value for model in Duration],
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
"ratio",
|
||||
enum_type=RunwayGen4TurboAspectRatio,
|
||||
options=[model.value for model in RunwayGen3aAspectRatio],
|
||||
),
|
||||
"seed": model_field_to_node_input(
|
||||
IO.INT,
|
||||
RunwayImageToVideoRequest,
|
||||
comfy_io.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=4294967295,
|
||||
step=1,
|
||||
control_after_generate=True,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
tooltip="Random seed for generation",
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
start_frame: torch.Tensor,
|
||||
duration: str,
|
||||
ratio: str,
|
||||
seed: int,
|
||||
unique_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> tuple[VideoFromFile]:
|
||||
# Validate inputs
|
||||
validate_string(prompt, min_length=1)
|
||||
validate_input_image(start_frame)
|
||||
|
||||
# Upload image
|
||||
download_urls = upload_images_to_comfyapi(
|
||||
start_frame,
|
||||
max_images=1,
|
||||
mime_type="image/png",
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
if len(download_urls) != 1:
|
||||
raise RunwayApiError("Failed to upload one or more images to comfy api.")
|
||||
|
||||
return self.generate_video(
|
||||
RunwayImageToVideoRequest(
|
||||
promptText=prompt,
|
||||
seed=seed,
|
||||
model=Model("gen4_turbo"),
|
||||
duration=Duration(duration),
|
||||
ratio=AspectRatio(ratio),
|
||||
promptImage=RunwayPromptImageObject(
|
||||
root=[
|
||||
RunwayPromptImageDetailedObject(
|
||||
uri=str(download_urls[0]), position="first"
|
||||
)
|
||||
]
|
||||
),
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
node_id=unique_id,
|
||||
)
|
||||
|
||||
|
||||
class RunwayFirstLastFrameNode(RunwayVideoGenNode):
|
||||
"""Runway First-Last Frame Node."""
|
||||
|
||||
DESCRIPTION = "Upload first and last keyframes, draft a prompt, and generate a video. More complex transitions, such as cases where the Last frame is completely different from the First frame, may benefit from the longer 10s duration. This would give the generation more time to smoothly transition between the two inputs. Before diving in, review these best practices to ensure that your input selections will set your generation up for success: https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3."
|
||||
|
||||
def get_response(
|
||||
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
|
||||
) -> RunwayImageToVideoResponse:
|
||||
return poll_until_finished(
|
||||
auth_kwargs,
|
||||
ApiEndpoint(
|
||||
path=f"{PATH_GET_TASK_STATUS}/{task_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=TaskStatusResponse,
|
||||
),
|
||||
estimated_duration=AVERAGE_DURATION_FLF_SECONDS,
|
||||
node_id=node_id,
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"prompt": model_field_to_node_input(
|
||||
IO.STRING, RunwayImageToVideoRequest, "promptText", multiline=True
|
||||
),
|
||||
"start_frame": (
|
||||
IO.IMAGE,
|
||||
{"tooltip": "Start frame to be used for the video"},
|
||||
),
|
||||
"end_frame": (
|
||||
IO.IMAGE,
|
||||
{
|
||||
"tooltip": "End frame to be used for the video. Supported for gen3a_turbo only."
|
||||
},
|
||||
),
|
||||
"duration": model_field_to_node_input(
|
||||
IO.COMBO, RunwayImageToVideoRequest, "duration", enum_type=Duration
|
||||
),
|
||||
"ratio": model_field_to_node_input(
|
||||
IO.COMBO,
|
||||
RunwayImageToVideoRequest,
|
||||
"ratio",
|
||||
enum_type=RunwayGen3aAspectRatio,
|
||||
),
|
||||
"seed": model_field_to_node_input(
|
||||
IO.INT,
|
||||
RunwayImageToVideoRequest,
|
||||
"seed",
|
||||
control_after_generate=True,
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
start_frame: torch.Tensor,
|
||||
end_frame: torch.Tensor,
|
||||
duration: str,
|
||||
ratio: str,
|
||||
seed: int,
|
||||
unique_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> tuple[VideoFromFile]:
|
||||
# Validate inputs
|
||||
) -> comfy_io.NodeOutput:
|
||||
validate_string(prompt, min_length=1)
|
||||
validate_input_image(start_frame)
|
||||
validate_input_image(end_frame)
|
||||
validate_image_dimensions(start_frame, max_width=7999, max_height=7999)
|
||||
validate_image_dimensions(end_frame, max_width=7999, max_height=7999)
|
||||
validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
|
||||
validate_image_aspect_ratio(end_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
|
||||
|
||||
auth_kwargs = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
|
||||
# Upload images
|
||||
stacked_input_images = image_tensor_pair_to_batch(start_frame, end_frame)
|
||||
download_urls = upload_images_to_comfyapi(
|
||||
download_urls = await upload_images_to_comfyapi(
|
||||
stacked_input_images,
|
||||
max_images=2,
|
||||
mime_type="image/png",
|
||||
auth_kwargs=kwargs,
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
if len(download_urls) != 2:
|
||||
raise RunwayApiError("Failed to upload one or more images to comfy api.")
|
||||
|
||||
return self.generate_video(
|
||||
RunwayImageToVideoRequest(
|
||||
promptText=prompt,
|
||||
seed=seed,
|
||||
model=Model("gen3a_turbo"),
|
||||
duration=Duration(duration),
|
||||
ratio=AspectRatio(ratio),
|
||||
promptImage=RunwayPromptImageObject(
|
||||
root=[
|
||||
RunwayPromptImageDetailedObject(
|
||||
uri=str(download_urls[0]), position="first"
|
||||
),
|
||||
RunwayPromptImageDetailedObject(
|
||||
uri=str(download_urls[1]), position="last"
|
||||
),
|
||||
]
|
||||
return comfy_io.NodeOutput(
|
||||
await generate_video(
|
||||
RunwayImageToVideoRequest(
|
||||
promptText=prompt,
|
||||
seed=seed,
|
||||
model=Model("gen3a_turbo"),
|
||||
duration=Duration(duration),
|
||||
ratio=AspectRatio(ratio),
|
||||
promptImage=RunwayPromptImageObject(
|
||||
root=[
|
||||
RunwayPromptImageDetailedObject(
|
||||
uri=str(download_urls[0]), position="first"
|
||||
),
|
||||
RunwayPromptImageDetailedObject(
|
||||
uri=str(download_urls[1]), position="last"
|
||||
),
|
||||
]
|
||||
),
|
||||
),
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
node_id=unique_id,
|
||||
auth_kwargs=auth_kwargs,
|
||||
node_id=cls.hidden.unique_id,
|
||||
estimated_duration=AVERAGE_DURATION_FLF_SECONDS,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class RunwayTextToImageNode(ComfyNodeABC):
|
||||
"""Runway Text to Image Node."""
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "api_call"
|
||||
CATEGORY = "api node/image/Runway"
|
||||
API_NODE = True
|
||||
DESCRIPTION = "Generate an image from a text prompt using Runway's Gen 4 model. You can also include reference images to guide the generation."
|
||||
class RunwayTextToImageNode(comfy_io.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"prompt": model_field_to_node_input(
|
||||
IO.STRING, RunwayTextToImageRequest, "promptText", multiline=True
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
node_id="RunwayTextToImageNode",
|
||||
display_name="Runway Text to Image",
|
||||
category="api node/image/Runway",
|
||||
description="Generate an image from a text prompt using Runway's Gen 4 model. "
|
||||
"You can also include reference image to guide the generation.",
|
||||
inputs=[
|
||||
comfy_io.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Text prompt for the generation",
|
||||
),
|
||||
"ratio": model_field_to_node_input(
|
||||
IO.COMBO,
|
||||
RunwayTextToImageRequest,
|
||||
comfy_io.Combo.Input(
|
||||
"ratio",
|
||||
enum_type=RunwayTextToImageAspectRatioEnum,
|
||||
options=[model.value for model in RunwayTextToImageAspectRatioEnum],
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"reference_image": (
|
||||
IO.IMAGE,
|
||||
{"tooltip": "Optional reference image to guide the generation"},
|
||||
)
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
def validate_task_created(self, response: RunwayTextToImageResponse) -> bool:
|
||||
"""
|
||||
Validate the task creation response from the Runway API matches
|
||||
expected format.
|
||||
"""
|
||||
if not bool(response.id):
|
||||
raise RunwayApiError("Invalid initial response from Runway API.")
|
||||
return True
|
||||
|
||||
def validate_response(self, response: TaskStatusResponse) -> bool:
|
||||
"""
|
||||
Validate the successful task status response from the Runway API
|
||||
matches expected format.
|
||||
"""
|
||||
if not response.output or len(response.output) == 0:
|
||||
raise RunwayApiError(
|
||||
"Runway task succeeded but no image data found in response."
|
||||
)
|
||||
return True
|
||||
|
||||
def get_response(
|
||||
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
|
||||
) -> TaskStatusResponse:
|
||||
"""Poll the task status until it is finished then get the response."""
|
||||
return poll_until_finished(
|
||||
auth_kwargs,
|
||||
ApiEndpoint(
|
||||
path=f"{PATH_GET_TASK_STATUS}/{task_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=TaskStatusResponse,
|
||||
),
|
||||
estimated_duration=AVERAGE_DURATION_T2I_SECONDS,
|
||||
node_id=node_id,
|
||||
comfy_io.Image.Input(
|
||||
"reference_image",
|
||||
tooltip="Optional reference image to guide the generation",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
ratio: str,
|
||||
reference_image: Optional[torch.Tensor] = None,
|
||||
unique_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor]:
|
||||
# Validate inputs
|
||||
) -> comfy_io.NodeOutput:
|
||||
validate_string(prompt, min_length=1)
|
||||
|
||||
auth_kwargs = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
|
||||
# Prepare reference images if provided
|
||||
reference_images = None
|
||||
if reference_image is not None:
|
||||
validate_input_image(reference_image)
|
||||
download_urls = upload_images_to_comfyapi(
|
||||
validate_image_dimensions(reference_image, max_width=7999, max_height=7999)
|
||||
validate_image_aspect_ratio(reference_image, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
|
||||
download_urls = await upload_images_to_comfyapi(
|
||||
reference_image,
|
||||
max_images=1,
|
||||
mime_type="image/png",
|
||||
auth_kwargs=kwargs,
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
if len(download_urls) != 1:
|
||||
raise RunwayApiError("Failed to upload reference image to comfy api.")
|
||||
|
||||
reference_images = [ReferenceImage(uri=str(download_urls[0]))]
|
||||
|
||||
# Create request
|
||||
request = RunwayTextToImageRequest(
|
||||
promptText=prompt,
|
||||
model=Model4.gen4_image,
|
||||
@@ -593,7 +565,6 @@ class RunwayTextToImageNode(ComfyNodeABC):
|
||||
referenceImages=reference_images,
|
||||
)
|
||||
|
||||
# Execute initial request
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_TEXT_TO_IMAGE,
|
||||
@@ -602,34 +573,33 @@ class RunwayTextToImageNode(ComfyNodeABC):
|
||||
response_model=RunwayTextToImageResponse,
|
||||
),
|
||||
request=request,
|
||||
auth_kwargs=kwargs,
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
|
||||
initial_response = initial_operation.execute()
|
||||
self.validate_task_created(initial_response)
|
||||
task_id = initial_response.id
|
||||
initial_response = await initial_operation.execute()
|
||||
|
||||
# Poll for completion
|
||||
final_response = self.get_response(
|
||||
task_id, auth_kwargs=kwargs, node_id=unique_id
|
||||
final_response = await get_response(
|
||||
initial_response.id,
|
||||
auth_kwargs=auth_kwargs,
|
||||
node_id=cls.hidden.unique_id,
|
||||
estimated_duration=AVERAGE_DURATION_T2I_SECONDS,
|
||||
)
|
||||
self.validate_response(final_response)
|
||||
if not final_response.output:
|
||||
raise RunwayApiError("Runway task succeeded but no image data found in response.")
|
||||
|
||||
# Download and return image
|
||||
image_url = get_image_url_from_task_status(final_response)
|
||||
return (download_url_to_image_tensor(image_url),)
|
||||
return comfy_io.NodeOutput(await download_url_to_image_tensor(get_image_url_from_task_status(final_response)))
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"RunwayFirstLastFrameNode": RunwayFirstLastFrameNode,
|
||||
"RunwayImageToVideoNodeGen3a": RunwayImageToVideoNodeGen3a,
|
||||
"RunwayImageToVideoNodeGen4": RunwayImageToVideoNodeGen4,
|
||||
"RunwayTextToImageNode": RunwayTextToImageNode,
|
||||
}
|
||||
class RunwayExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
|
||||
return [
|
||||
RunwayFirstLastFrameNode,
|
||||
RunwayImageToVideoNodeGen3a,
|
||||
RunwayImageToVideoNodeGen4,
|
||||
RunwayTextToImageNode,
|
||||
]
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"RunwayFirstLastFrameNode": "Runway First-Last-Frame to Video",
|
||||
"RunwayImageToVideoNodeGen3a": "Runway Image to Video (Gen3a Turbo)",
|
||||
"RunwayImageToVideoNodeGen4": "Runway Image to Video (Gen4 Turbo)",
|
||||
"RunwayTextToImageNode": "Runway Text to Image",
|
||||
}
|
||||
async def comfy_entrypoint() -> RunwayExtension:
|
||||
return RunwayExtension()
|
||||
|
||||
+681
-320
File diff suppressed because it is too large
Load Diff
@@ -37,8 +37,8 @@ from comfy_api_nodes.apinode_utils import (
|
||||
)
|
||||
|
||||
|
||||
def upload_image_to_tripo(image, **kwargs):
|
||||
urls = upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs)
|
||||
async def upload_image_to_tripo(image, **kwargs):
|
||||
urls = await upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs)
|
||||
return TripoFileReference(TripoUrlReference(url=urls[0], type="jpeg"))
|
||||
|
||||
def get_model_url_from_response(response: TripoTaskResponse) -> str:
|
||||
@@ -49,7 +49,7 @@ def get_model_url_from_response(response: TripoTaskResponse) -> str:
|
||||
raise RuntimeError(f"Failed to get model url from response: {response}")
|
||||
|
||||
|
||||
def poll_until_finished(
|
||||
async def poll_until_finished(
|
||||
kwargs: dict[str, str],
|
||||
response: TripoTaskResponse,
|
||||
) -> tuple[str, str]:
|
||||
@@ -57,7 +57,7 @@ def poll_until_finished(
|
||||
if response.code != 0:
|
||||
raise RuntimeError(f"Failed to generate mesh: {response.error}")
|
||||
task_id = response.data.task_id
|
||||
response_poll = PollingOperation(
|
||||
response_poll = await PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path=f"/proxy/tripo/v2/openapi/task/{task_id}",
|
||||
method=HttpMethod.GET,
|
||||
@@ -80,7 +80,7 @@ def poll_until_finished(
|
||||
).execute()
|
||||
if response_poll.data.status == TripoTaskStatus.SUCCESS:
|
||||
url = get_model_url_from_response(response_poll)
|
||||
bytesio = download_url_to_bytesio(url)
|
||||
bytesio = await download_url_to_bytesio(url)
|
||||
# Save the downloaded model file
|
||||
model_file = f"tripo_model_{task_id}.glb"
|
||||
with open(os.path.join(get_output_directory(), model_file), "wb") as f:
|
||||
@@ -88,6 +88,7 @@ def poll_until_finished(
|
||||
return model_file, task_id
|
||||
raise RuntimeError(f"Failed to generate mesh: {response_poll}")
|
||||
|
||||
|
||||
class TripoTextToModelNode:
|
||||
"""
|
||||
Generates 3D models synchronously based on a text prompt using Tripo's API.
|
||||
@@ -126,11 +127,11 @@ class TripoTextToModelNode:
|
||||
API_NODE = True
|
||||
OUTPUT_NODE = True
|
||||
|
||||
def generate_mesh(self, prompt, negative_prompt=None, model_version=None, style=None, texture=None, pbr=None, image_seed=None, model_seed=None, texture_seed=None, texture_quality=None, face_limit=None, quad=None, **kwargs):
|
||||
async def generate_mesh(self, prompt, negative_prompt=None, model_version=None, style=None, texture=None, pbr=None, image_seed=None, model_seed=None, texture_seed=None, texture_quality=None, face_limit=None, quad=None, **kwargs):
|
||||
style_enum = None if style == "None" else style
|
||||
if not prompt:
|
||||
raise RuntimeError("Prompt is required")
|
||||
response = SynchronousOperation(
|
||||
response = await SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/tripo/v2/openapi/task",
|
||||
method=HttpMethod.POST,
|
||||
@@ -155,7 +156,8 @@ class TripoTextToModelNode:
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
).execute()
|
||||
return poll_until_finished(kwargs, response)
|
||||
return await poll_until_finished(kwargs, response)
|
||||
|
||||
|
||||
class TripoImageToModelNode:
|
||||
"""
|
||||
@@ -195,12 +197,12 @@ class TripoImageToModelNode:
|
||||
API_NODE = True
|
||||
OUTPUT_NODE = True
|
||||
|
||||
def generate_mesh(self, image, model_version=None, style=None, texture=None, pbr=None, model_seed=None, orientation=None, texture_alignment=None, texture_seed=None, texture_quality=None, face_limit=None, quad=None, **kwargs):
|
||||
async def generate_mesh(self, image, model_version=None, style=None, texture=None, pbr=None, model_seed=None, orientation=None, texture_alignment=None, texture_seed=None, texture_quality=None, face_limit=None, quad=None, **kwargs):
|
||||
style_enum = None if style == "None" else style
|
||||
if image is None:
|
||||
raise RuntimeError("Image is required")
|
||||
tripo_file = upload_image_to_tripo(image, **kwargs)
|
||||
response = SynchronousOperation(
|
||||
tripo_file = await upload_image_to_tripo(image, **kwargs)
|
||||
response = await SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/tripo/v2/openapi/task",
|
||||
method=HttpMethod.POST,
|
||||
@@ -225,7 +227,8 @@ class TripoImageToModelNode:
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
).execute()
|
||||
return poll_until_finished(kwargs, response)
|
||||
return await poll_until_finished(kwargs, response)
|
||||
|
||||
|
||||
class TripoMultiviewToModelNode:
|
||||
"""
|
||||
@@ -267,7 +270,7 @@ class TripoMultiviewToModelNode:
|
||||
API_NODE = True
|
||||
OUTPUT_NODE = True
|
||||
|
||||
def generate_mesh(self, image, image_left=None, image_back=None, image_right=None, model_version=None, orientation=None, texture=None, pbr=None, model_seed=None, texture_seed=None, texture_quality=None, texture_alignment=None, face_limit=None, quad=None, **kwargs):
|
||||
async def generate_mesh(self, image, image_left=None, image_back=None, image_right=None, model_version=None, orientation=None, texture=None, pbr=None, model_seed=None, texture_seed=None, texture_quality=None, texture_alignment=None, face_limit=None, quad=None, **kwargs):
|
||||
if image is None:
|
||||
raise RuntimeError("front image for multiview is required")
|
||||
images = []
|
||||
@@ -282,11 +285,11 @@ class TripoMultiviewToModelNode:
|
||||
for image_name in ["image", "image_left", "image_back", "image_right"]:
|
||||
image_ = image_dict[image_name]
|
||||
if image_ is not None:
|
||||
tripo_file = upload_image_to_tripo(image_, **kwargs)
|
||||
tripo_file = await upload_image_to_tripo(image_, **kwargs)
|
||||
images.append(tripo_file)
|
||||
else:
|
||||
images.append(TripoFileEmptyReference())
|
||||
response = SynchronousOperation(
|
||||
response = await SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/tripo/v2/openapi/task",
|
||||
method=HttpMethod.POST,
|
||||
@@ -309,7 +312,8 @@ class TripoMultiviewToModelNode:
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
).execute()
|
||||
return poll_until_finished(kwargs, response)
|
||||
return await poll_until_finished(kwargs, response)
|
||||
|
||||
|
||||
class TripoTextureNode:
|
||||
@classmethod
|
||||
@@ -340,8 +344,8 @@ class TripoTextureNode:
|
||||
OUTPUT_NODE = True
|
||||
AVERAGE_DURATION = 80
|
||||
|
||||
def generate_mesh(self, model_task_id, texture=None, pbr=None, texture_seed=None, texture_quality=None, texture_alignment=None, **kwargs):
|
||||
response = SynchronousOperation(
|
||||
async def generate_mesh(self, model_task_id, texture=None, pbr=None, texture_seed=None, texture_quality=None, texture_alignment=None, **kwargs):
|
||||
response = await SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/tripo/v2/openapi/task",
|
||||
method=HttpMethod.POST,
|
||||
@@ -358,7 +362,7 @@ class TripoTextureNode:
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
).execute()
|
||||
return poll_until_finished(kwargs, response)
|
||||
return await poll_until_finished(kwargs, response)
|
||||
|
||||
|
||||
class TripoRefineNode:
|
||||
@@ -387,8 +391,8 @@ class TripoRefineNode:
|
||||
OUTPUT_NODE = True
|
||||
AVERAGE_DURATION = 240
|
||||
|
||||
def generate_mesh(self, model_task_id, **kwargs):
|
||||
response = SynchronousOperation(
|
||||
async def generate_mesh(self, model_task_id, **kwargs):
|
||||
response = await SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/tripo/v2/openapi/task",
|
||||
method=HttpMethod.POST,
|
||||
@@ -400,7 +404,7 @@ class TripoRefineNode:
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
).execute()
|
||||
return poll_until_finished(kwargs, response)
|
||||
return await poll_until_finished(kwargs, response)
|
||||
|
||||
|
||||
class TripoRigNode:
|
||||
@@ -425,8 +429,8 @@ class TripoRigNode:
|
||||
OUTPUT_NODE = True
|
||||
AVERAGE_DURATION = 180
|
||||
|
||||
def generate_mesh(self, original_model_task_id, **kwargs):
|
||||
response = SynchronousOperation(
|
||||
async def generate_mesh(self, original_model_task_id, **kwargs):
|
||||
response = await SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/tripo/v2/openapi/task",
|
||||
method=HttpMethod.POST,
|
||||
@@ -440,7 +444,8 @@ class TripoRigNode:
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
).execute()
|
||||
return poll_until_finished(kwargs, response)
|
||||
return await poll_until_finished(kwargs, response)
|
||||
|
||||
|
||||
class TripoRetargetNode:
|
||||
@classmethod
|
||||
@@ -475,8 +480,8 @@ class TripoRetargetNode:
|
||||
OUTPUT_NODE = True
|
||||
AVERAGE_DURATION = 30
|
||||
|
||||
def generate_mesh(self, animation, original_model_task_id, **kwargs):
|
||||
response = SynchronousOperation(
|
||||
async def generate_mesh(self, animation, original_model_task_id, **kwargs):
|
||||
response = await SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/tripo/v2/openapi/task",
|
||||
method=HttpMethod.POST,
|
||||
@@ -491,7 +496,8 @@ class TripoRetargetNode:
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
).execute()
|
||||
return poll_until_finished(kwargs, response)
|
||||
return await poll_until_finished(kwargs, response)
|
||||
|
||||
|
||||
class TripoConversionNode:
|
||||
@classmethod
|
||||
@@ -529,10 +535,10 @@ class TripoConversionNode:
|
||||
OUTPUT_NODE = True
|
||||
AVERAGE_DURATION = 30
|
||||
|
||||
def generate_mesh(self, original_model_task_id, format, quad, face_limit, texture_size, texture_format, **kwargs):
|
||||
async def generate_mesh(self, original_model_task_id, format, quad, face_limit, texture_size, texture_format, **kwargs):
|
||||
if not original_model_task_id:
|
||||
raise RuntimeError("original_model_task_id is required")
|
||||
response = SynchronousOperation(
|
||||
response = await SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/tripo/v2/openapi/task",
|
||||
method=HttpMethod.POST,
|
||||
@@ -549,7 +555,8 @@ class TripoConversionNode:
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
).execute()
|
||||
return poll_until_finished(kwargs, response)
|
||||
return await poll_until_finished(kwargs, response)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"TripoTextToModelNode": TripoTextToModelNode,
|
||||
|
||||
+196
-149
@@ -1,17 +1,18 @@
|
||||
import io
|
||||
import logging
|
||||
import base64
|
||||
import requests
|
||||
import aiohttp
|
||||
import torch
|
||||
from io import BytesIO
|
||||
from typing import Optional
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy.comfy_types.node_typing import IO, ComfyNodeABC
|
||||
from comfy_api.latest import ComfyExtension, io as comfy_io
|
||||
from comfy_api.input_impl.video_types import VideoFromFile
|
||||
from comfy_api_nodes.apis import (
|
||||
VeoGenVidRequest,
|
||||
VeoGenVidResponse,
|
||||
VeoGenVidPollRequest,
|
||||
VeoGenVidPollResponse
|
||||
VeoGenVidPollResponse,
|
||||
)
|
||||
from comfy_api_nodes.apis.client import (
|
||||
ApiEndpoint,
|
||||
@@ -22,7 +23,7 @@ from comfy_api_nodes.apis.client import (
|
||||
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
downscale_image_tensor,
|
||||
tensor_to_base64_string
|
||||
tensor_to_base64_string,
|
||||
)
|
||||
|
||||
AVERAGE_DURATION_VIDEO_GEN = 32
|
||||
@@ -50,7 +51,7 @@ def get_video_url_from_response(poll_response: VeoGenVidPollResponse) -> Optiona
|
||||
return None
|
||||
|
||||
|
||||
class VeoVideoGenerationNode(ComfyNodeABC):
|
||||
class VeoVideoGenerationNode(comfy_io.ComfyNode):
|
||||
"""
|
||||
Generates videos from text prompts using Google's Veo API.
|
||||
|
||||
@@ -59,101 +60,93 @@ class VeoVideoGenerationNode(ComfyNodeABC):
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Text description of the video",
|
||||
},
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
node_id="VeoVideoGenerationNode",
|
||||
display_name="Google Veo 2 Video Generation",
|
||||
category="api node/video/Veo",
|
||||
description="Generates videos from text prompts using Google's Veo 2 API",
|
||||
inputs=[
|
||||
comfy_io.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Text description of the video",
|
||||
),
|
||||
"aspect_ratio": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": ["16:9", "9:16"],
|
||||
"default": "16:9",
|
||||
"tooltip": "Aspect ratio of the output video",
|
||||
},
|
||||
comfy_io.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=["16:9", "9:16"],
|
||||
default="16:9",
|
||||
tooltip="Aspect ratio of the output video",
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"negative_prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Negative text prompt to guide what to avoid in the video",
|
||||
},
|
||||
comfy_io.String.Input(
|
||||
"negative_prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Negative text prompt to guide what to avoid in the video",
|
||||
optional=True,
|
||||
),
|
||||
"duration_seconds": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 5,
|
||||
"min": 5,
|
||||
"max": 8,
|
||||
"step": 1,
|
||||
"display": "number",
|
||||
"tooltip": "Duration of the output video in seconds",
|
||||
},
|
||||
comfy_io.Int.Input(
|
||||
"duration_seconds",
|
||||
default=5,
|
||||
min=5,
|
||||
max=8,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
tooltip="Duration of the output video in seconds",
|
||||
optional=True,
|
||||
),
|
||||
"enhance_prompt": (
|
||||
IO.BOOLEAN,
|
||||
{
|
||||
"default": True,
|
||||
"tooltip": "Whether to enhance the prompt with AI assistance",
|
||||
}
|
||||
comfy_io.Boolean.Input(
|
||||
"enhance_prompt",
|
||||
default=True,
|
||||
tooltip="Whether to enhance the prompt with AI assistance",
|
||||
optional=True,
|
||||
),
|
||||
"person_generation": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": ["ALLOW", "BLOCK"],
|
||||
"default": "ALLOW",
|
||||
"tooltip": "Whether to allow generating people in the video",
|
||||
},
|
||||
comfy_io.Combo.Input(
|
||||
"person_generation",
|
||||
options=["ALLOW", "BLOCK"],
|
||||
default="ALLOW",
|
||||
tooltip="Whether to allow generating people in the video",
|
||||
optional=True,
|
||||
),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 0xFFFFFFFF,
|
||||
"step": 1,
|
||||
"display": "number",
|
||||
"control_after_generate": True,
|
||||
"tooltip": "Seed for video generation (0 for random)",
|
||||
},
|
||||
comfy_io.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=0xFFFFFFFF,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed for video generation (0 for random)",
|
||||
optional=True,
|
||||
),
|
||||
"image": (IO.IMAGE, {
|
||||
"default": None,
|
||||
"tooltip": "Optional reference image to guide video generation",
|
||||
}),
|
||||
"model": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": ["veo-2.0-generate-001"],
|
||||
"default": "veo-2.0-generate-001",
|
||||
"tooltip": "Veo 2 model to use for video generation",
|
||||
},
|
||||
comfy_io.Image.Input(
|
||||
"image",
|
||||
tooltip="Optional reference image to guide video generation",
|
||||
optional=True,
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
comfy_io.Combo.Input(
|
||||
"model",
|
||||
options=["veo-2.0-generate-001"],
|
||||
default="veo-2.0-generate-001",
|
||||
tooltip="Veo 2 model to use for video generation",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
RETURN_TYPES = (IO.VIDEO,)
|
||||
FUNCTION = "generate_video"
|
||||
CATEGORY = "api node/video/Veo"
|
||||
DESCRIPTION = "Generates videos from text prompts using Google's Veo 2 API"
|
||||
API_NODE = True
|
||||
|
||||
def generate_video(
|
||||
self,
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt,
|
||||
aspect_ratio="16:9",
|
||||
negative_prompt="",
|
||||
@@ -164,8 +157,6 @@ class VeoVideoGenerationNode(ComfyNodeABC):
|
||||
image=None,
|
||||
model="veo-2.0-generate-001",
|
||||
generate_audio=False,
|
||||
unique_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
# Prepare the instances for the request
|
||||
instances = []
|
||||
@@ -202,6 +193,10 @@ class VeoVideoGenerationNode(ComfyNodeABC):
|
||||
if "veo-3.0" in model:
|
||||
parameters["generateAudio"] = generate_audio
|
||||
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
# Initial request to start video generation
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
@@ -214,10 +209,10 @@ class VeoVideoGenerationNode(ComfyNodeABC):
|
||||
instances=instances,
|
||||
parameters=parameters
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
|
||||
initial_response = initial_operation.execute()
|
||||
initial_response = await initial_operation.execute()
|
||||
operation_name = initial_response.name
|
||||
|
||||
logging.info(f"Veo generation started with operation name: {operation_name}")
|
||||
@@ -248,15 +243,15 @@ class VeoVideoGenerationNode(ComfyNodeABC):
|
||||
request=VeoGenVidPollRequest(
|
||||
operationName=operation_name
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
auth_kwargs=auth,
|
||||
poll_interval=5.0,
|
||||
result_url_extractor=get_video_url_from_response,
|
||||
node_id=unique_id,
|
||||
node_id=cls.hidden.unique_id,
|
||||
estimated_duration=AVERAGE_DURATION_VIDEO_GEN,
|
||||
)
|
||||
|
||||
# Execute the polling operation
|
||||
poll_response = poll_operation.execute()
|
||||
poll_response = await poll_operation.execute()
|
||||
|
||||
# Now check for errors in the final response
|
||||
# Check for error in poll response
|
||||
@@ -281,7 +276,6 @@ class VeoVideoGenerationNode(ComfyNodeABC):
|
||||
raise Exception(error_message)
|
||||
|
||||
# Extract video data
|
||||
video_data = None
|
||||
if poll_response.response and hasattr(poll_response.response, 'videos') and poll_response.response.videos and len(poll_response.response.videos) > 0:
|
||||
video = poll_response.response.videos[0]
|
||||
|
||||
@@ -291,9 +285,9 @@ class VeoVideoGenerationNode(ComfyNodeABC):
|
||||
video_data = base64.b64decode(video.bytesBase64Encoded)
|
||||
elif hasattr(video, 'gcsUri') and video.gcsUri:
|
||||
# Download from URL
|
||||
video_url = video.gcsUri
|
||||
video_response = requests.get(video_url)
|
||||
video_data = video_response.content
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(video.gcsUri) as video_response:
|
||||
video_data = await video_response.content.read()
|
||||
else:
|
||||
raise Exception("Video returned but no data or URL was provided")
|
||||
else:
|
||||
@@ -305,10 +299,10 @@ class VeoVideoGenerationNode(ComfyNodeABC):
|
||||
logging.info("Video generation completed successfully")
|
||||
|
||||
# Convert video data to BytesIO object
|
||||
video_io = io.BytesIO(video_data)
|
||||
video_io = BytesIO(video_data)
|
||||
|
||||
# Return VideoFromFile object
|
||||
return (VideoFromFile(video_io),)
|
||||
return comfy_io.NodeOutput(VideoFromFile(video_io))
|
||||
|
||||
|
||||
class Veo3VideoGenerationNode(VeoVideoGenerationNode):
|
||||
@@ -324,51 +318,104 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode):
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
parent_input = super().INPUT_TYPES()
|
||||
|
||||
# Update model options for Veo 3
|
||||
parent_input["optional"]["model"] = (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": ["veo-3.0-generate-001", "veo-3.0-fast-generate-001"],
|
||||
"default": "veo-3.0-generate-001",
|
||||
"tooltip": "Veo 3 model to use for video generation",
|
||||
},
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
node_id="Veo3VideoGenerationNode",
|
||||
display_name="Google Veo 3 Video Generation",
|
||||
category="api node/video/Veo",
|
||||
description="Generates videos from text prompts using Google's Veo 3 API",
|
||||
inputs=[
|
||||
comfy_io.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Text description of the video",
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=["16:9", "9:16"],
|
||||
default="16:9",
|
||||
tooltip="Aspect ratio of the output video",
|
||||
),
|
||||
comfy_io.String.Input(
|
||||
"negative_prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Negative text prompt to guide what to avoid in the video",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
"duration_seconds",
|
||||
default=8,
|
||||
min=8,
|
||||
max=8,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
tooltip="Duration of the output video in seconds (Veo 3 only supports 8 seconds)",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Boolean.Input(
|
||||
"enhance_prompt",
|
||||
default=True,
|
||||
tooltip="Whether to enhance the prompt with AI assistance",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
"person_generation",
|
||||
options=["ALLOW", "BLOCK"],
|
||||
default="ALLOW",
|
||||
tooltip="Whether to allow generating people in the video",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=0xFFFFFFFF,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed for video generation (0 for random)",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Image.Input(
|
||||
"image",
|
||||
tooltip="Optional reference image to guide video generation",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
"model",
|
||||
options=["veo-3.0-generate-001", "veo-3.0-fast-generate-001"],
|
||||
default="veo-3.0-generate-001",
|
||||
tooltip="Veo 3 model to use for video generation",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Boolean.Input(
|
||||
"generate_audio",
|
||||
default=False,
|
||||
tooltip="Generate audio for the video. Supported by all Veo 3 models.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
# Add generateAudio parameter
|
||||
parent_input["optional"]["generate_audio"] = (
|
||||
IO.BOOLEAN,
|
||||
{
|
||||
"default": False,
|
||||
"tooltip": "Generate audio for the video. Supported by all Veo 3 models.",
|
||||
}
|
||||
)
|
||||
|
||||
# Update duration constraints for Veo 3 (only 8 seconds supported)
|
||||
parent_input["optional"]["duration_seconds"] = (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 8,
|
||||
"min": 8,
|
||||
"max": 8,
|
||||
"step": 1,
|
||||
"display": "number",
|
||||
"tooltip": "Duration of the output video in seconds (Veo 3 only supports 8 seconds)",
|
||||
},
|
||||
)
|
||||
class VeoExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
|
||||
return [
|
||||
VeoVideoGenerationNode,
|
||||
Veo3VideoGenerationNode,
|
||||
]
|
||||
|
||||
return parent_input
|
||||
|
||||
|
||||
# Register the nodes
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"VeoVideoGenerationNode": VeoVideoGenerationNode,
|
||||
"Veo3VideoGenerationNode": Veo3VideoGenerationNode,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"VeoVideoGenerationNode": "Google Veo 2 Video Generation",
|
||||
"Veo3VideoGenerationNode": "Google Veo 3 Video Generation",
|
||||
}
|
||||
async def comfy_entrypoint() -> VeoExtension:
|
||||
return VeoExtension()
|
||||
|
||||
@@ -0,0 +1,622 @@
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Optional, Literal, TypeVar
|
||||
from typing_extensions import override
|
||||
|
||||
import torch
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from comfy_api.latest import ComfyExtension, io as comfy_io
|
||||
from comfy_api_nodes.util.validation_utils import (
|
||||
validate_aspect_ratio_closeness,
|
||||
validate_image_dimensions,
|
||||
validate_image_aspect_ratio_range,
|
||||
get_number_of_images,
|
||||
)
|
||||
from comfy_api_nodes.apis.client import (
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
PollingOperation,
|
||||
EmptyRequest,
|
||||
)
|
||||
from comfy_api_nodes.apinode_utils import download_url_to_video_output, upload_images_to_comfyapi
|
||||
|
||||
|
||||
VIDU_TEXT_TO_VIDEO = "/proxy/vidu/text2video"
|
||||
VIDU_IMAGE_TO_VIDEO = "/proxy/vidu/img2video"
|
||||
VIDU_REFERENCE_VIDEO = "/proxy/vidu/reference2video"
|
||||
VIDU_START_END_VIDEO = "/proxy/vidu/start-end2video"
|
||||
VIDU_GET_GENERATION_STATUS = "/proxy/vidu/tasks/%s/creations"
|
||||
|
||||
R = TypeVar("R")
|
||||
|
||||
class VideoModelName(str, Enum):
|
||||
vidu_q1 = 'viduq1'
|
||||
|
||||
|
||||
class AspectRatio(str, Enum):
|
||||
r_16_9 = "16:9"
|
||||
r_9_16 = "9:16"
|
||||
r_1_1 = "1:1"
|
||||
|
||||
|
||||
class Resolution(str, Enum):
|
||||
r_1080p = "1080p"
|
||||
|
||||
|
||||
class MovementAmplitude(str, Enum):
|
||||
auto = "auto"
|
||||
small = "small"
|
||||
medium = "medium"
|
||||
large = "large"
|
||||
|
||||
|
||||
class TaskCreationRequest(BaseModel):
|
||||
model: VideoModelName = VideoModelName.vidu_q1
|
||||
prompt: Optional[str] = Field(None, max_length=1500)
|
||||
duration: Optional[Literal[5]] = 5
|
||||
seed: Optional[int] = Field(0, ge=0, le=2147483647)
|
||||
aspect_ratio: Optional[AspectRatio] = AspectRatio.r_16_9
|
||||
resolution: Optional[Resolution] = Resolution.r_1080p
|
||||
movement_amplitude: Optional[MovementAmplitude] = MovementAmplitude.auto
|
||||
images: Optional[list[str]] = Field(None, description="Base64 encoded string or image URL")
|
||||
|
||||
|
||||
class TaskStatus(str, Enum):
|
||||
created = "created"
|
||||
queueing = "queueing"
|
||||
processing = "processing"
|
||||
success = "success"
|
||||
failed = "failed"
|
||||
|
||||
|
||||
class TaskCreationResponse(BaseModel):
|
||||
task_id: str = Field(...)
|
||||
state: TaskStatus = Field(...)
|
||||
created_at: str = Field(...)
|
||||
code: Optional[int] = Field(None, description="Error code")
|
||||
|
||||
|
||||
class TaskResult(BaseModel):
|
||||
id: str = Field(..., description="Creation id")
|
||||
url: str = Field(..., description="The URL of the generated results, valid for one hour")
|
||||
cover_url: str = Field(..., description="The cover URL of the generated results, valid for one hour")
|
||||
|
||||
|
||||
class TaskStatusResponse(BaseModel):
|
||||
state: TaskStatus = Field(...)
|
||||
err_code: Optional[str] = Field(None)
|
||||
creations: list[TaskResult] = Field(..., description="Generated results")
|
||||
|
||||
|
||||
async def poll_until_finished(
|
||||
auth_kwargs: dict[str, str],
|
||||
api_endpoint: ApiEndpoint[Any, R],
|
||||
result_url_extractor: Optional[Callable[[R], str]] = None,
|
||||
estimated_duration: Optional[int] = None,
|
||||
node_id: Optional[str] = None,
|
||||
) -> R:
|
||||
return await PollingOperation(
|
||||
poll_endpoint=api_endpoint,
|
||||
completed_statuses=[TaskStatus.success.value],
|
||||
failed_statuses=[TaskStatus.failed.value],
|
||||
status_extractor=lambda response: response.state.value,
|
||||
auth_kwargs=auth_kwargs,
|
||||
result_url_extractor=result_url_extractor,
|
||||
estimated_duration=estimated_duration,
|
||||
node_id=node_id,
|
||||
poll_interval=16.0,
|
||||
max_poll_attempts=256,
|
||||
).execute()
|
||||
|
||||
|
||||
def get_video_url_from_response(response) -> Optional[str]:
|
||||
if response.creations:
|
||||
return response.creations[0].url
|
||||
return None
|
||||
|
||||
|
||||
def get_video_from_response(response) -> TaskResult:
|
||||
if not response.creations:
|
||||
error_msg = f"Vidu request does not contain results. State: {response.state}, Error Code: {response.err_code}"
|
||||
logging.info(error_msg)
|
||||
raise RuntimeError(error_msg)
|
||||
logging.info("Vidu task %s succeeded. Video URL: %s", response.creations[0].id, response.creations[0].url)
|
||||
return response.creations[0]
|
||||
|
||||
|
||||
async def execute_task(
|
||||
vidu_endpoint: str,
|
||||
auth_kwargs: Optional[dict[str, str]],
|
||||
payload: TaskCreationRequest,
|
||||
estimated_duration: int,
|
||||
node_id: str,
|
||||
) -> R:
|
||||
response = await SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=vidu_endpoint,
|
||||
method=HttpMethod.POST,
|
||||
request_model=TaskCreationRequest,
|
||||
response_model=TaskCreationResponse,
|
||||
),
|
||||
request=payload,
|
||||
auth_kwargs=auth_kwargs,
|
||||
).execute()
|
||||
if response.state == TaskStatus.failed:
|
||||
error_msg = f"Vidu request failed. Code: {response.code}"
|
||||
logging.error(error_msg)
|
||||
raise RuntimeError(error_msg)
|
||||
return await poll_until_finished(
|
||||
auth_kwargs,
|
||||
ApiEndpoint(
|
||||
path=VIDU_GET_GENERATION_STATUS % response.task_id,
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=TaskStatusResponse,
|
||||
),
|
||||
result_url_extractor=get_video_url_from_response,
|
||||
estimated_duration=estimated_duration,
|
||||
node_id=node_id,
|
||||
)
|
||||
|
||||
|
||||
class ViduTextToVideoNode(comfy_io.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
node_id="ViduTextToVideoNode",
|
||||
display_name="Vidu Text To Video Generation",
|
||||
category="api node/video/Vidu",
|
||||
description="Generate video from text prompt",
|
||||
inputs=[
|
||||
comfy_io.Combo.Input(
|
||||
"model",
|
||||
options=[model.value for model in VideoModelName],
|
||||
default=VideoModelName.vidu_q1.value,
|
||||
tooltip="Model name",
|
||||
),
|
||||
comfy_io.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
tooltip="A textual description for video generation",
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
"duration",
|
||||
default=5,
|
||||
min=5,
|
||||
max=5,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
tooltip="Duration of the output video in seconds",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed for video generation (0 for random)",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=[model.value for model in AspectRatio],
|
||||
default=AspectRatio.r_16_9.value,
|
||||
tooltip="The aspect ratio of the output video",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
"resolution",
|
||||
options=[model.value for model in Resolution],
|
||||
default=Resolution.r_1080p.value,
|
||||
tooltip="Supported values may vary by model & duration",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
"movement_amplitude",
|
||||
options=[model.value for model in MovementAmplitude],
|
||||
default=MovementAmplitude.auto.value,
|
||||
tooltip="The movement amplitude of objects in the frame",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
prompt: str,
|
||||
duration: int,
|
||||
seed: int,
|
||||
aspect_ratio: str,
|
||||
resolution: str,
|
||||
movement_amplitude: str,
|
||||
) -> comfy_io.NodeOutput:
|
||||
if not prompt:
|
||||
raise ValueError("The prompt field is required and cannot be empty.")
|
||||
payload = TaskCreationRequest(
|
||||
model_name=model,
|
||||
prompt=prompt,
|
||||
duration=duration,
|
||||
seed=seed,
|
||||
aspect_ratio=aspect_ratio,
|
||||
resolution=resolution,
|
||||
movement_amplitude=movement_amplitude,
|
||||
)
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
results = await execute_task(VIDU_TEXT_TO_VIDEO, auth, payload, 320, cls.hidden.unique_id)
|
||||
return comfy_io.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
|
||||
|
||||
|
||||
class ViduImageToVideoNode(comfy_io.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
node_id="ViduImageToVideoNode",
|
||||
display_name="Vidu Image To Video Generation",
|
||||
category="api node/video/Vidu",
|
||||
description="Generate video from image and optional prompt",
|
||||
inputs=[
|
||||
comfy_io.Combo.Input(
|
||||
"model",
|
||||
options=[model.value for model in VideoModelName],
|
||||
default=VideoModelName.vidu_q1.value,
|
||||
tooltip="Model name",
|
||||
),
|
||||
comfy_io.Image.Input(
|
||||
"image",
|
||||
tooltip="An image to be used as the start frame of the generated video",
|
||||
),
|
||||
comfy_io.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="A textual description for video generation",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
"duration",
|
||||
default=5,
|
||||
min=5,
|
||||
max=5,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
tooltip="Duration of the output video in seconds",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed for video generation (0 for random)",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
"resolution",
|
||||
options=[model.value for model in Resolution],
|
||||
default=Resolution.r_1080p.value,
|
||||
tooltip="Supported values may vary by model & duration",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
"movement_amplitude",
|
||||
options=[model.value for model in MovementAmplitude],
|
||||
default=MovementAmplitude.auto.value,
|
||||
tooltip="The movement amplitude of objects in the frame",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
image: torch.Tensor,
|
||||
prompt: str,
|
||||
duration: int,
|
||||
seed: int,
|
||||
resolution: str,
|
||||
movement_amplitude: str,
|
||||
) -> comfy_io.NodeOutput:
|
||||
if get_number_of_images(image) > 1:
|
||||
raise ValueError("Only one input image is allowed.")
|
||||
validate_image_aspect_ratio_range(image, (1, 4), (4, 1))
|
||||
payload = TaskCreationRequest(
|
||||
model_name=model,
|
||||
prompt=prompt,
|
||||
duration=duration,
|
||||
seed=seed,
|
||||
resolution=resolution,
|
||||
movement_amplitude=movement_amplitude,
|
||||
)
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
payload.images = await upload_images_to_comfyapi(
|
||||
image,
|
||||
max_images=1,
|
||||
mime_type="image/png",
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
results = await execute_task(VIDU_IMAGE_TO_VIDEO, auth, payload, 120, cls.hidden.unique_id)
|
||||
return comfy_io.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
|
||||
|
||||
|
||||
class ViduReferenceVideoNode(comfy_io.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
node_id="ViduReferenceVideoNode",
|
||||
display_name="Vidu Reference To Video Generation",
|
||||
category="api node/video/Vidu",
|
||||
description="Generate video from multiple images and prompt",
|
||||
inputs=[
|
||||
comfy_io.Combo.Input(
|
||||
"model",
|
||||
options=[model.value for model in VideoModelName],
|
||||
default=VideoModelName.vidu_q1.value,
|
||||
tooltip="Model name",
|
||||
),
|
||||
comfy_io.Image.Input(
|
||||
"images",
|
||||
tooltip="Images to use as references to generate a video with consistent subjects (max 7 images).",
|
||||
),
|
||||
comfy_io.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
tooltip="A textual description for video generation",
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
"duration",
|
||||
default=5,
|
||||
min=5,
|
||||
max=5,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
tooltip="Duration of the output video in seconds",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed for video generation (0 for random)",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=[model.value for model in AspectRatio],
|
||||
default=AspectRatio.r_16_9.value,
|
||||
tooltip="The aspect ratio of the output video",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
"resolution",
|
||||
options=[model.value for model in Resolution],
|
||||
default=Resolution.r_1080p.value,
|
||||
tooltip="Supported values may vary by model & duration",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
"movement_amplitude",
|
||||
options=[model.value for model in MovementAmplitude],
|
||||
default=MovementAmplitude.auto.value,
|
||||
tooltip="The movement amplitude of objects in the frame",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
images: torch.Tensor,
|
||||
prompt: str,
|
||||
duration: int,
|
||||
seed: int,
|
||||
aspect_ratio: str,
|
||||
resolution: str,
|
||||
movement_amplitude: str,
|
||||
) -> comfy_io.NodeOutput:
|
||||
if not prompt:
|
||||
raise ValueError("The prompt field is required and cannot be empty.")
|
||||
a = get_number_of_images(images)
|
||||
if a > 7:
|
||||
raise ValueError("Too many images, maximum allowed is 7.")
|
||||
for image in images:
|
||||
validate_image_aspect_ratio_range(image, (1, 4), (4, 1))
|
||||
validate_image_dimensions(image, min_width=128, min_height=128)
|
||||
payload = TaskCreationRequest(
|
||||
model_name=model,
|
||||
prompt=prompt,
|
||||
duration=duration,
|
||||
seed=seed,
|
||||
aspect_ratio=aspect_ratio,
|
||||
resolution=resolution,
|
||||
movement_amplitude=movement_amplitude,
|
||||
)
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
payload.images = await upload_images_to_comfyapi(
|
||||
images,
|
||||
max_images=7,
|
||||
mime_type="image/png",
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
results = await execute_task(VIDU_REFERENCE_VIDEO, auth, payload, 120, cls.hidden.unique_id)
|
||||
return comfy_io.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
|
||||
|
||||
|
||||
class ViduStartEndToVideoNode(comfy_io.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
node_id="ViduStartEndToVideoNode",
|
||||
display_name="Vidu Start End To Video Generation",
|
||||
category="api node/video/Vidu",
|
||||
description="Generate a video from start and end frames and a prompt",
|
||||
inputs=[
|
||||
comfy_io.Combo.Input(
|
||||
"model",
|
||||
options=[model.value for model in VideoModelName],
|
||||
default=VideoModelName.vidu_q1.value,
|
||||
tooltip="Model name",
|
||||
),
|
||||
comfy_io.Image.Input(
|
||||
"first_frame",
|
||||
tooltip="Start frame",
|
||||
),
|
||||
comfy_io.Image.Input(
|
||||
"end_frame",
|
||||
tooltip="End frame",
|
||||
),
|
||||
comfy_io.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
tooltip="A textual description for video generation",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
"duration",
|
||||
default=5,
|
||||
min=5,
|
||||
max=5,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
tooltip="Duration of the output video in seconds",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed for video generation (0 for random)",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
"resolution",
|
||||
options=[model.value for model in Resolution],
|
||||
default=Resolution.r_1080p.value,
|
||||
tooltip="Supported values may vary by model & duration",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
"movement_amplitude",
|
||||
options=[model.value for model in MovementAmplitude],
|
||||
default=MovementAmplitude.auto.value,
|
||||
tooltip="The movement amplitude of objects in the frame",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
first_frame: torch.Tensor,
|
||||
end_frame: torch.Tensor,
|
||||
prompt: str,
|
||||
duration: int,
|
||||
seed: int,
|
||||
resolution: str,
|
||||
movement_amplitude: str,
|
||||
) -> comfy_io.NodeOutput:
|
||||
validate_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False)
|
||||
payload = TaskCreationRequest(
|
||||
model_name=model,
|
||||
prompt=prompt,
|
||||
duration=duration,
|
||||
seed=seed,
|
||||
resolution=resolution,
|
||||
movement_amplitude=movement_amplitude,
|
||||
)
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
payload.images = [
|
||||
(await upload_images_to_comfyapi(frame, max_images=1, mime_type="image/png", auth_kwargs=auth))[0]
|
||||
for frame in (first_frame, end_frame)
|
||||
]
|
||||
results = await execute_task(VIDU_START_END_VIDEO, auth, payload, 96, cls.hidden.unique_id)
|
||||
return comfy_io.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
|
||||
|
||||
|
||||
class ViduExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
|
||||
return [
|
||||
ViduTextToVideoNode,
|
||||
ViduImageToVideoNode,
|
||||
ViduReferenceVideoNode,
|
||||
ViduStartEndToVideoNode,
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> ViduExtension:
|
||||
return ViduExtension()
|
||||
@@ -2,7 +2,7 @@ import logging
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from comfy_api.input.video_types import VideoInput
|
||||
from comfy_api.latest import Input
|
||||
|
||||
|
||||
def get_image_dimensions(image: torch.Tensor) -> tuple[int, int]:
|
||||
@@ -53,8 +53,55 @@ def validate_image_aspect_ratio(
|
||||
)
|
||||
|
||||
|
||||
def validate_image_aspect_ratio_range(
|
||||
image: torch.Tensor,
|
||||
min_ratio: tuple[float, float], # e.g. (1, 4)
|
||||
max_ratio: tuple[float, float], # e.g. (4, 1)
|
||||
*,
|
||||
strict: bool = True, # True -> (min, max); False -> [min, max]
|
||||
) -> float:
|
||||
a1, b1 = min_ratio
|
||||
a2, b2 = max_ratio
|
||||
if a1 <= 0 or b1 <= 0 or a2 <= 0 or b2 <= 0:
|
||||
raise ValueError("Ratios must be positive, like (1, 4) or (4, 1).")
|
||||
lo, hi = (a1 / b1), (a2 / b2)
|
||||
if lo > hi:
|
||||
lo, hi = hi, lo
|
||||
a1, b1, a2, b2 = a2, b2, a1, b1 # swap only for error text
|
||||
w, h = get_image_dimensions(image)
|
||||
if w <= 0 or h <= 0:
|
||||
raise ValueError(f"Invalid image dimensions: {w}x{h}")
|
||||
ar = w / h
|
||||
ok = (lo < ar < hi) if strict else (lo <= ar <= hi)
|
||||
if not ok:
|
||||
op = "<" if strict else "≤"
|
||||
raise ValueError(f"Image aspect ratio {ar:.6g} is outside allowed range: {a1}:{b1} {op} ratio {op} {a2}:{b2}")
|
||||
return ar
|
||||
|
||||
|
||||
def validate_aspect_ratio_closeness(
|
||||
start_img,
|
||||
end_img,
|
||||
min_rel: float,
|
||||
max_rel: float,
|
||||
*,
|
||||
strict: bool = False, # True => exclusive, False => inclusive
|
||||
) -> None:
|
||||
w1, h1 = get_image_dimensions(start_img)
|
||||
w2, h2 = get_image_dimensions(end_img)
|
||||
if min(w1, h1, w2, h2) <= 0:
|
||||
raise ValueError("Invalid image dimensions")
|
||||
ar1 = w1 / h1
|
||||
ar2 = w2 / h2
|
||||
# Normalize so it is symmetric (no need to check both ar1/ar2 and ar2/ar1)
|
||||
closeness = max(ar1, ar2) / min(ar1, ar2)
|
||||
limit = max(max_rel, 1.0 / min_rel) # for 0.8..1.25 this is 1.25
|
||||
if (closeness >= limit) if strict else (closeness > limit):
|
||||
raise ValueError(f"Aspect ratios must be close: start/end={ar1/ar2:.4f}, allowed range {min_rel}–{max_rel}.")
|
||||
|
||||
|
||||
def validate_video_dimensions(
|
||||
video: VideoInput,
|
||||
video: Input.Video,
|
||||
min_width: Optional[int] = None,
|
||||
max_width: Optional[int] = None,
|
||||
min_height: Optional[int] = None,
|
||||
@@ -79,7 +126,7 @@ def validate_video_dimensions(
|
||||
|
||||
|
||||
def validate_video_duration(
|
||||
video: VideoInput,
|
||||
video: Input.Video,
|
||||
min_duration: Optional[float] = None,
|
||||
max_duration: Optional[float] = None,
|
||||
):
|
||||
@@ -98,3 +145,23 @@ def validate_video_duration(
|
||||
raise ValueError(
|
||||
f"Video duration must be at most {max_duration}s, got {duration}s"
|
||||
)
|
||||
|
||||
|
||||
def get_number_of_images(images):
|
||||
if isinstance(images, torch.Tensor):
|
||||
return images.shape[0] if images.ndim >= 4 else 1
|
||||
return len(images)
|
||||
|
||||
|
||||
def validate_audio_duration(
|
||||
audio: Input.Audio,
|
||||
min_duration: Optional[float] = None,
|
||||
max_duration: Optional[float] = None,
|
||||
) -> None:
|
||||
sr = int(audio["sample_rate"])
|
||||
dur = int(audio["waveform"].shape[-1]) / sr
|
||||
eps = 1.0 / sr
|
||||
if min_duration is not None and dur + eps < min_duration:
|
||||
raise ValueError(f"Audio duration must be at least {min_duration}s, got {dur + eps:.2f}s")
|
||||
if max_duration is not None and dur - eps > max_duration:
|
||||
raise ValueError(f"Audio duration must be at most {max_duration}s, got {dur - eps:.2f}s")
|
||||
|
||||
@@ -181,8 +181,9 @@ class WebUIProgressHandler(ProgressHandler):
|
||||
}
|
||||
|
||||
# Send a combined progress_state message with all node states
|
||||
# Include client_id to ensure message is only sent to the initiating client
|
||||
self.server_instance.send_sync(
|
||||
"progress_state", {"prompt_id": prompt_id, "nodes": active_nodes}
|
||||
"progress_state", {"prompt_id": prompt_id, "nodes": active_nodes}, self.server_instance.client_id
|
||||
)
|
||||
|
||||
@override
|
||||
|
||||
+47
-33
@@ -1,49 +1,63 @@
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
import comfy.model_management
|
||||
import node_helpers
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
class TextEncodeAceStepAudio:
|
||||
|
||||
class TextEncodeAceStepAudio(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"clip": ("CLIP", ),
|
||||
"tags": ("STRING", {"multiline": True, "dynamicPrompts": True}),
|
||||
"lyrics": ("STRING", {"multiline": True, "dynamicPrompts": True}),
|
||||
"lyrics_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
}}
|
||||
RETURN_TYPES = ("CONDITIONING",)
|
||||
FUNCTION = "encode"
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="TextEncodeAceStepAudio",
|
||||
category="conditioning",
|
||||
inputs=[
|
||||
io.Clip.Input("clip"),
|
||||
io.String.Input("tags", multiline=True, dynamic_prompts=True),
|
||||
io.String.Input("lyrics", multiline=True, dynamic_prompts=True),
|
||||
io.Float.Input("lyrics_strength", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
],
|
||||
outputs=[io.Conditioning.Output()],
|
||||
)
|
||||
|
||||
CATEGORY = "conditioning"
|
||||
|
||||
def encode(self, clip, tags, lyrics, lyrics_strength):
|
||||
@classmethod
|
||||
def execute(cls, clip, tags, lyrics, lyrics_strength) -> io.NodeOutput:
|
||||
tokens = clip.tokenize(tags, lyrics=lyrics)
|
||||
conditioning = clip.encode_from_tokens_scheduled(tokens)
|
||||
conditioning = node_helpers.conditioning_set_values(conditioning, {"lyrics_strength": lyrics_strength})
|
||||
return (conditioning, )
|
||||
return io.NodeOutput(conditioning)
|
||||
|
||||
|
||||
class EmptyAceStepLatentAudio:
|
||||
def __init__(self):
|
||||
self.device = comfy.model_management.intermediate_device()
|
||||
class EmptyAceStepLatentAudio(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="EmptyAceStepLatentAudio",
|
||||
category="latent/audio",
|
||||
inputs=[
|
||||
io.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.1),
|
||||
io.Int.Input(
|
||||
"batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."
|
||||
),
|
||||
],
|
||||
outputs=[io.Latent.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"seconds": ("FLOAT", {"default": 120.0, "min": 1.0, "max": 1000.0, "step": 0.1}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}),
|
||||
}}
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "generate"
|
||||
|
||||
CATEGORY = "latent/audio"
|
||||
|
||||
def generate(self, seconds, batch_size):
|
||||
def execute(cls, seconds, batch_size) -> io.NodeOutput:
|
||||
length = int(seconds * 44100 / 512 / 8)
|
||||
latent = torch.zeros([batch_size, 8, 16, length], device=self.device)
|
||||
return ({"samples": latent, "type": "audio"}, )
|
||||
latent = torch.zeros([batch_size, 8, 16, length], device=comfy.model_management.intermediate_device())
|
||||
return io.NodeOutput({"samples": latent, "type": "audio"})
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"TextEncodeAceStepAudio": TextEncodeAceStepAudio,
|
||||
"EmptyAceStepLatentAudio": EmptyAceStepLatentAudio,
|
||||
}
|
||||
class AceExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
TextEncodeAceStepAudio,
|
||||
EmptyAceStepLatentAudio,
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> AceExtension:
|
||||
return AceExtension()
|
||||
|
||||
@@ -1,8 +1,13 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from tqdm.auto import trange
|
||||
from typing_extensions import override
|
||||
|
||||
import comfy.model_patcher
|
||||
import comfy.samplers
|
||||
import comfy.utils
|
||||
import torch
|
||||
import numpy as np
|
||||
from tqdm.auto import trange
|
||||
from comfy.k_diffusion.sampling import to_d
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -33,30 +38,29 @@ def sample_lcm_upscale(model, x, sigmas, extra_args=None, callback=None, disable
|
||||
return x
|
||||
|
||||
|
||||
class SamplerLCMUpscale:
|
||||
upscale_methods = ["bislerp", "nearest-exact", "bilinear", "area", "bicubic"]
|
||||
class SamplerLCMUpscale(io.ComfyNode):
|
||||
UPSCALE_METHODS = ["bislerp", "nearest-exact", "bilinear", "area", "bicubic"]
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required":
|
||||
{"scale_ratio": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 20.0, "step": 0.01}),
|
||||
"scale_steps": ("INT", {"default": -1, "min": -1, "max": 1000, "step": 1}),
|
||||
"upscale_method": (s.upscale_methods,),
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("SAMPLER",)
|
||||
CATEGORY = "sampling/custom_sampling/samplers"
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="SamplerLCMUpscale",
|
||||
category="sampling/custom_sampling/samplers",
|
||||
inputs=[
|
||||
io.Float.Input("scale_ratio", default=1.0, min=0.1, max=20.0, step=0.01),
|
||||
io.Int.Input("scale_steps", default=-1, min=-1, max=1000, step=1),
|
||||
io.Combo.Input("upscale_method", options=cls.UPSCALE_METHODS),
|
||||
],
|
||||
outputs=[io.Sampler.Output()],
|
||||
)
|
||||
|
||||
FUNCTION = "get_sampler"
|
||||
|
||||
def get_sampler(self, scale_ratio, scale_steps, upscale_method):
|
||||
@classmethod
|
||||
def execute(cls, scale_ratio, scale_steps, upscale_method) -> io.NodeOutput:
|
||||
if scale_steps < 0:
|
||||
scale_steps = None
|
||||
sampler = comfy.samplers.KSAMPLER(sample_lcm_upscale, extra_options={"total_upscale": scale_ratio, "upscale_steps": scale_steps, "upscale_method": upscale_method})
|
||||
return (sampler, )
|
||||
return io.NodeOutput(sampler)
|
||||
|
||||
from comfy.k_diffusion.sampling import to_d
|
||||
import comfy.model_patcher
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_euler_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
|
||||
@@ -82,30 +86,36 @@ def sample_euler_pp(model, x, sigmas, extra_args=None, callback=None, disable=No
|
||||
return x
|
||||
|
||||
|
||||
class SamplerEulerCFGpp:
|
||||
class SamplerEulerCFGpp(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required":
|
||||
{"version": (["regular", "alternative"],),}
|
||||
}
|
||||
RETURN_TYPES = ("SAMPLER",)
|
||||
# CATEGORY = "sampling/custom_sampling/samplers"
|
||||
CATEGORY = "_for_testing"
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="SamplerEulerCFGpp",
|
||||
display_name="SamplerEulerCFG++",
|
||||
category="_for_testing", # "sampling/custom_sampling/samplers"
|
||||
inputs=[
|
||||
io.Combo.Input("version", options=["regular", "alternative"]),
|
||||
],
|
||||
outputs=[io.Sampler.Output()],
|
||||
is_experimental=True,
|
||||
)
|
||||
|
||||
FUNCTION = "get_sampler"
|
||||
|
||||
def get_sampler(self, version):
|
||||
@classmethod
|
||||
def execute(cls, version) -> io.NodeOutput:
|
||||
if version == "alternative":
|
||||
sampler = comfy.samplers.KSAMPLER(sample_euler_pp)
|
||||
else:
|
||||
sampler = comfy.samplers.ksampler("euler_cfg_pp")
|
||||
return (sampler, )
|
||||
return io.NodeOutput(sampler)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"SamplerLCMUpscale": SamplerLCMUpscale,
|
||||
"SamplerEulerCFGpp": SamplerEulerCFGpp,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"SamplerEulerCFGpp": "SamplerEulerCFG++",
|
||||
}
|
||||
class AdvancedSamplersExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
SamplerLCMUpscale,
|
||||
SamplerEulerCFGpp,
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> AdvancedSamplersExtension:
|
||||
return AdvancedSamplersExtension()
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
#from: https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/howto.html
|
||||
import numpy as np
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
def loglinear_interp(t_steps, num_steps):
|
||||
"""
|
||||
@@ -19,25 +23,30 @@ NOISE_LEVELS = {"SD1": [14.6146412293, 6.4745760956, 3.8636745985, 2.694615152
|
||||
"SDXL":[14.6146412293, 6.3184485287, 3.7681790315, 2.1811480769, 1.3405244945, 0.8620721141, 0.5550693289, 0.3798540708, 0.2332364134, 0.1114188177, 0.0291671582],
|
||||
"SVD": [700.00, 54.5, 15.886, 7.977, 4.248, 1.789, 0.981, 0.403, 0.173, 0.034, 0.002]}
|
||||
|
||||
class AlignYourStepsScheduler:
|
||||
class AlignYourStepsScheduler(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required":
|
||||
{"model_type": (["SD1", "SDXL", "SVD"], ),
|
||||
"steps": ("INT", {"default": 10, "min": 1, "max": 10000}),
|
||||
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("SIGMAS",)
|
||||
CATEGORY = "sampling/custom_sampling/schedulers"
|
||||
|
||||
FUNCTION = "get_sigmas"
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="AlignYourStepsScheduler",
|
||||
category="sampling/custom_sampling/schedulers",
|
||||
inputs=[
|
||||
io.Combo.Input("model_type", options=["SD1", "SDXL", "SVD"]),
|
||||
io.Int.Input("steps", default=10, min=1, max=10000),
|
||||
io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01),
|
||||
],
|
||||
outputs=[io.Sigmas.Output()],
|
||||
)
|
||||
|
||||
def get_sigmas(self, model_type, steps, denoise):
|
||||
# Deprecated: use the V3 schema's `execute` method instead of this.
|
||||
return AlignYourStepsScheduler().execute(model_type, steps, denoise).result
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model_type, steps, denoise) -> io.NodeOutput:
|
||||
total_steps = steps
|
||||
if denoise < 1.0:
|
||||
if denoise <= 0.0:
|
||||
return (torch.FloatTensor([]),)
|
||||
return io.NodeOutput(torch.FloatTensor([]))
|
||||
total_steps = round(steps * denoise)
|
||||
|
||||
sigmas = NOISE_LEVELS[model_type][:]
|
||||
@@ -46,8 +55,15 @@ class AlignYourStepsScheduler:
|
||||
|
||||
sigmas = sigmas[-(total_steps + 1):]
|
||||
sigmas[-1] = 0
|
||||
return (torch.FloatTensor(sigmas), )
|
||||
return io.NodeOutput(torch.FloatTensor(sigmas))
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"AlignYourStepsScheduler": AlignYourStepsScheduler,
|
||||
}
|
||||
|
||||
class AlignYourStepsExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
AlignYourStepsScheduler,
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> AlignYourStepsExtension:
|
||||
return AlignYourStepsExtension()
|
||||
|
||||
+51
-21
@@ -1,4 +1,8 @@
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
def project(v0, v1):
|
||||
v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3])
|
||||
@@ -6,22 +10,45 @@ def project(v0, v1):
|
||||
v0_orthogonal = v0 - v0_parallel
|
||||
return v0_parallel, v0_orthogonal
|
||||
|
||||
class APG:
|
||||
class APG(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"model": ("MODEL",),
|
||||
"eta": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01, "tooltip": "Controls the scale of the parallel guidance vector. Default CFG behavior at a setting of 1."}),
|
||||
"norm_threshold": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 50.0, "step": 0.1, "tooltip": "Normalize guidance vector to this value, normalization disable at a setting of 0."}),
|
||||
"momentum": ("FLOAT", {"default": 0.0, "min": -5.0, "max": 1.0, "step": 0.01, "tooltip":"Controls a running average of guidance during diffusion, disabled at a setting of 0."}),
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
CATEGORY = "sampling/custom_sampling"
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="APG",
|
||||
display_name="Adaptive Projected Guidance",
|
||||
category="sampling/custom_sampling",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Float.Input(
|
||||
"eta",
|
||||
default=1.0,
|
||||
min=-10.0,
|
||||
max=10.0,
|
||||
step=0.01,
|
||||
tooltip="Controls the scale of the parallel guidance vector. Default CFG behavior at a setting of 1.",
|
||||
),
|
||||
io.Float.Input(
|
||||
"norm_threshold",
|
||||
default=5.0,
|
||||
min=0.0,
|
||||
max=50.0,
|
||||
step=0.1,
|
||||
tooltip="Normalize guidance vector to this value, normalization disable at a setting of 0.",
|
||||
),
|
||||
io.Float.Input(
|
||||
"momentum",
|
||||
default=0.0,
|
||||
min=-5.0,
|
||||
max=1.0,
|
||||
step=0.01,
|
||||
tooltip="Controls a running average of guidance during diffusion, disabled at a setting of 0.",
|
||||
),
|
||||
],
|
||||
outputs=[io.Model.Output()],
|
||||
)
|
||||
|
||||
def patch(self, model, eta, norm_threshold, momentum):
|
||||
@classmethod
|
||||
def execute(cls, model, eta, norm_threshold, momentum) -> io.NodeOutput:
|
||||
running_avg = 0
|
||||
prev_sigma = None
|
||||
|
||||
@@ -65,12 +92,15 @@ class APG:
|
||||
|
||||
m = model.clone()
|
||||
m.set_model_sampler_pre_cfg_function(pre_cfg_function)
|
||||
return (m,)
|
||||
return io.NodeOutput(m)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"APG": APG,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"APG": "Adaptive Projected Guidance",
|
||||
}
|
||||
class ApgExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
APG,
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> ApgExtension:
|
||||
return ApgExtension()
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
def attention_multiply(attn, model, q, k, v, out):
|
||||
m = model.clone()
|
||||
@@ -16,57 +20,71 @@ def attention_multiply(attn, model, q, k, v, out):
|
||||
return m
|
||||
|
||||
|
||||
class UNetSelfAttentionMultiply:
|
||||
class UNetSelfAttentionMultiply(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"q": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"k": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"v": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="UNetSelfAttentionMultiply",
|
||||
category="_for_testing/attention_experiments",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
io.Float.Input("k", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
io.Float.Input("v", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
io.Float.Input("out", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
],
|
||||
outputs=[io.Model.Output()],
|
||||
is_experimental=True,
|
||||
)
|
||||
|
||||
CATEGORY = "_for_testing/attention_experiments"
|
||||
|
||||
def patch(self, model, q, k, v, out):
|
||||
@classmethod
|
||||
def execute(cls, model, q, k, v, out) -> io.NodeOutput:
|
||||
m = attention_multiply("attn1", model, q, k, v, out)
|
||||
return (m, )
|
||||
return io.NodeOutput(m)
|
||||
|
||||
class UNetCrossAttentionMultiply:
|
||||
|
||||
class UNetCrossAttentionMultiply(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"q": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"k": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"v": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="UNetCrossAttentionMultiply",
|
||||
category="_for_testing/attention_experiments",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
io.Float.Input("k", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
io.Float.Input("v", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
io.Float.Input("out", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
],
|
||||
outputs=[io.Model.Output()],
|
||||
is_experimental=True,
|
||||
)
|
||||
|
||||
CATEGORY = "_for_testing/attention_experiments"
|
||||
|
||||
def patch(self, model, q, k, v, out):
|
||||
@classmethod
|
||||
def execute(cls, model, q, k, v, out) -> io.NodeOutput:
|
||||
m = attention_multiply("attn2", model, q, k, v, out)
|
||||
return (m, )
|
||||
return io.NodeOutput(m)
|
||||
|
||||
class CLIPAttentionMultiply:
|
||||
|
||||
class CLIPAttentionMultiply(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "clip": ("CLIP",),
|
||||
"q": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"k": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"v": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
}}
|
||||
RETURN_TYPES = ("CLIP",)
|
||||
FUNCTION = "patch"
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="CLIPAttentionMultiply",
|
||||
category="_for_testing/attention_experiments",
|
||||
inputs=[
|
||||
io.Clip.Input("clip"),
|
||||
io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
io.Float.Input("k", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
io.Float.Input("v", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
io.Float.Input("out", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
],
|
||||
outputs=[io.Clip.Output()],
|
||||
is_experimental=True,
|
||||
)
|
||||
|
||||
CATEGORY = "_for_testing/attention_experiments"
|
||||
|
||||
def patch(self, clip, q, k, v, out):
|
||||
@classmethod
|
||||
def execute(cls, clip, q, k, v, out) -> io.NodeOutput:
|
||||
m = clip.clone()
|
||||
sd = m.patcher.model_state_dict()
|
||||
|
||||
@@ -79,23 +97,28 @@ class CLIPAttentionMultiply:
|
||||
m.add_patches({key: (None,)}, 0.0, v)
|
||||
if key.endswith("self_attn.out_proj.weight") or key.endswith("self_attn.out_proj.bias"):
|
||||
m.add_patches({key: (None,)}, 0.0, out)
|
||||
return (m, )
|
||||
return io.NodeOutput(m)
|
||||
|
||||
class UNetTemporalAttentionMultiply:
|
||||
|
||||
class UNetTemporalAttentionMultiply(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"self_structural": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"self_temporal": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"cross_structural": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"cross_temporal": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="UNetTemporalAttentionMultiply",
|
||||
category="_for_testing/attention_experiments",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Float.Input("self_structural", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
io.Float.Input("self_temporal", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
io.Float.Input("cross_structural", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
io.Float.Input("cross_temporal", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
],
|
||||
outputs=[io.Model.Output()],
|
||||
is_experimental=True,
|
||||
)
|
||||
|
||||
CATEGORY = "_for_testing/attention_experiments"
|
||||
|
||||
def patch(self, model, self_structural, self_temporal, cross_structural, cross_temporal):
|
||||
@classmethod
|
||||
def execute(cls, model, self_structural, self_temporal, cross_structural, cross_temporal) -> io.NodeOutput:
|
||||
m = model.clone()
|
||||
sd = model.model_state_dict()
|
||||
|
||||
@@ -110,11 +133,18 @@ class UNetTemporalAttentionMultiply:
|
||||
m.add_patches({k: (None,)}, 0.0, cross_temporal)
|
||||
else:
|
||||
m.add_patches({k: (None,)}, 0.0, cross_structural)
|
||||
return (m, )
|
||||
return io.NodeOutput(m)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"UNetSelfAttentionMultiply": UNetSelfAttentionMultiply,
|
||||
"UNetCrossAttentionMultiply": UNetCrossAttentionMultiply,
|
||||
"CLIPAttentionMultiply": CLIPAttentionMultiply,
|
||||
"UNetTemporalAttentionMultiply": UNetTemporalAttentionMultiply,
|
||||
}
|
||||
|
||||
class AttentionMultiplyExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
UNetSelfAttentionMultiply,
|
||||
UNetCrossAttentionMultiply,
|
||||
CLIPAttentionMultiply,
|
||||
UNetTemporalAttentionMultiply,
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> AttentionMultiplyExtension:
|
||||
return AttentionMultiplyExtension()
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user