2023-02-25 20:57:40 +00:00
import os
import sys
import asyncio
2023-08-28 13:52:22 +09:00
import traceback
2023-02-25 20:57:40 +00:00
import nodes
2023-03-19 11:29:03 -04:00
import folder_paths
2023-02-27 19:43:55 -05:00
import execution
2023-02-25 20:57:40 +00:00
import uuid
2023-08-20 19:55:48 +01:00
import urllib
2023-02-25 20:57:40 +00:00
import json
2023-03-03 19:05:39 +00:00
import glob
2023-05-30 20:43:29 -05:00
import struct
2024-04-30 20:17:02 -04:00
import ssl
2024-09-11 01:00:31 -04:00
import socket
import ipaddress
2023-07-19 17:37:27 -04:00
from PIL import Image , ImageOps
2023-08-29 18:34:43 +10:00
from PIL . PngImagePlugin import PngInfo
2023-05-09 03:37:36 +09:00
from io import BytesIO
2024-03-11 12:30:11 -04:00
import aiohttp
from aiohttp import web
2024-03-11 13:54:56 -04:00
import logging
2023-02-25 20:57:40 +00:00
2023-03-06 14:09:23 -05:00
import mimetypes
2023-04-06 15:06:22 -04:00
from comfy . cli_args import args
2023-05-29 02:48:50 -04:00
import comfy . utils
2023-06-01 23:26:23 -05:00
import comfy . model_management
2024-07-16 18:27:09 -04:00
import node_helpers
2024-07-16 11:26:11 -04:00
from app . frontend_management import FrontendManager
2024-01-08 22:06:44 +00:00
from app . user_manager import UserManager
2024-08-13 12:48:52 -07:00
from model_filemanager import download_model , DownloadModelStatus
from typing import Optional
2024-08-20 22:25:06 -07:00
from api_server . routes . internal . internal_routes import InternalRoutes
2024-07-16 11:26:11 -04:00
2023-05-30 20:43:29 -05:00
class BinaryEventTypes:
    """Integer identifiers for the binary event kinds the server can send."""

    # Identifier for a preview-image event.
    PREVIEW_IMAGE = 1
    # Identifier for an unencoded preview-image event; PromptServer.send()
    # routes events with this id to send_image().
    UNENCODED_PREVIEW_IMAGE = 2
2023-05-30 20:43:29 -05:00
2023-06-15 11:01:06 -04:00
async def send_socket_catch_exception(function, message):
    """Await ``function(message)`` and log, rather than raise, send failures.

    Only aiohttp client errors and ``ConnectionResetError`` are swallowed;
    any other exception propagates to the caller.
    """
    try:
        await function(message)
    except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError) as err:
        logging.warning(f"send error: {err}")
2023-05-30 20:43:29 -05:00
2024-08-30 12:46:37 -04:00
def get_comfyui_version():
    """Return the ``git describe --tags`` string for this checkout.

    pygit2 is tried first. If that fails for any reason the ``git``
    executable is invoked instead. When both attempts fail a warning is
    logged and ``"unknown"`` is returned. The result is whitespace-stripped.
    """
    repo_path = os.path.dirname(os.path.realpath(__file__))
    comfyui_version = "unknown"
    try:
        import pygit2
        repo = pygit2.Repository(repo_path)
        comfyui_version = repo.describe(describe_strategy=pygit2.GIT_DESCRIBE_TAGS)
    except Exception:
        # pygit2 missing or unable to describe the repo: fall back to the CLI.
        try:
            import subprocess
            raw = subprocess.check_output(["git", "describe", "--tags"], cwd=repo_path)
            comfyui_version = raw.decode('utf-8')
        except Exception as e:
            logging.warning(f"Failed to get ComfyUI version: {e}")
    return comfyui_version.strip()
2023-03-09 19:38:43 +00:00
@web.middleware
async def cache_control(request: web.Request, handler):
    """Middleware: default ``Cache-Control: no-cache`` on .js and .css responses.

    The header is only set when the handler did not already provide one.
    """
    response: web.Response = await handler(request)
    if request.path.endswith(('.js', '.css')):
        response.headers.setdefault('Cache-Control', 'no-cache')
    return response
2023-04-06 15:06:22 -04:00
def create_cors_middleware(allowed_origin: str):
    """Build a middleware that stamps CORS headers on every response.

    Args:
        allowed_origin: value sent as ``Access-Control-Allow-Origin``.

    Returns:
        An aiohttp middleware coroutine.
    """
    cors_headers = (
        ('Access-Control-Allow-Origin', allowed_origin),
        ('Access-Control-Allow-Methods', 'POST, GET, DELETE, PUT, OPTIONS'),
        ('Access-Control-Allow-Headers', 'Content-Type, Authorization'),
        ('Access-Control-Allow-Credentials', 'true'),
    )

    @web.middleware
    async def cors_middleware(request: web.Request, handler):
        if request.method == "OPTIONS":
            # Pre-flight request. Reply successfully without calling the handler:
            response = web.Response()
        else:
            response = await handler(request)

        for header_name, header_value in cors_headers:
            response.headers[header_name] = header_value
        return response

    return cors_middleware
2023-04-05 13:08:08 -04:00
2024-09-11 01:00:31 -04:00
def is_loopback(host):
    """Report whether ``host`` denotes a loopback address.

    Args:
        host: an IP literal or a host name, or ``None``.

    Returns:
        ``False`` for ``None``. For an IP literal, whether that address is a
        loopback address. For a host name, the name is resolved for IPv4 and
        then IPv6: ``True`` when at least one address was found and no
        non-loopback address was encountered. If a non-loopback address is
        encountered, the value accumulated up to that point is returned
        (``True`` only if a loopback address had already been seen).
        Names that cannot be resolved or encoded yield ``False``.
    """
    if host is None:
        return False

    try:
        return ipaddress.ip_address(host).is_loopback
    except ValueError:
        # Not an IP literal; treat it as a host name and resolve it below.
        pass

    loopback = False
    for family in (socket.AF_INET, socket.AF_INET6):
        try:
            addr_infos = socket.getaddrinfo(host, None, family, socket.SOCK_STREAM)
        except (socket.gaierror, UnicodeError):
            # gaierror: the name does not resolve for this family.
            # UnicodeError: the name cannot be IDNA-encoded (empty or
            # over-long label), so it cannot resolve to anything.
            continue
        for *_, sockaddr in addr_infos:
            if not ipaddress.ip_address(sockaddr[0]).is_loopback:
                return loopback
            loopback = True
    return loopback
2024-09-08 18:08:28 -04:00
def create_origin_only_middleware():
    """Build a middleware that rejects mismatched Host/Origin on loopback hosts.

    Returns:
        An aiohttp middleware coroutine. It answers 403 when both ``Host``
        and ``Origin`` headers are present, the host is a loopback host, and
        the two domains differ. OPTIONS requests that pass the check get an
        empty response; everything else is forwarded to the handler.
    """
    @web.middleware
    async def origin_only_middleware(request: web.Request, handler):
        # This check prevents a random website from queueing comfy workflows by
        # making a POST to 127.0.0.1, which browsers do not block.
        # In that case the Host and Origin hostnames won't match.
        # The proper fix would be a cookie, but this covers the problem in the meantime.
        headers = request.headers
        if 'Host' in headers and 'Origin' in headers:
            host_domain = headers['Host'].lower()
            origin_parts = urllib.parse.urlparse(headers['Origin'])
            origin_domain = origin_parts.netloc.lower()
            host_parts = urllib.parse.urlsplit('//' + host_domain)

            # Limit the check to when the host domain is localhost; this makes it
            # slightly less safe but should still prevent the exploit.
            loopback = is_loopback(host_parts.hostname)

            # If one side carries no port, compare bare hostnames on the other
            # side too, to handle browsers that omit the port.
            if origin_parts.port is None:
                host_domain = host_parts.hostname
            if host_parts.port is None:
                origin_domain = origin_parts.hostname

            # Empty or missing domains are never compared.
            if loopback and host_domain and origin_domain and host_domain != origin_domain:
                logging.warning("WARNING: request with non matching host and origin {} != {}, returning 403".format(host_domain, origin_domain))
                return web.Response(status=403)

        if request.method == "OPTIONS":
            return web.Response()
        return await handler(request)

    return origin_only_middleware
2023-02-25 20:57:40 +00:00
class PromptServer ( ) :
def __init__(self, loop):
    """Set up server state, middlewares and the aiohttp application.

    Args:
        loop: stored on the instance as ``self.loop``.
    """
    # Expose this instance on the class.
    PromptServer.instance = self

    mimetypes.init()
    mimetypes.types_map['.js'] = 'application/javascript; charset=utf-8'

    self.user_manager = UserManager()
    self.internal_routes = InternalRoutes()
    self.supports = ["custom_nodes_from_web"]

    self.prompt_queue = None
    self.loop = loop
    self.messages = asyncio.Queue()
    # Stays None until setup() creates the session.
    self.client_session: Optional[aiohttp.ClientSession] = None
    self.number = 0

    # An explicit CORS header setting replaces the Host/Origin check.
    middlewares = [cache_control]
    if args.enable_cors_header:
        middlewares.append(create_cors_middleware(args.enable_cors_header))
    else:
        middlewares.append(create_origin_only_middleware())

    # args.max_upload_size is multiplied by 1024 * 1024 before being handed
    # to aiohttp as client_max_size.
    max_upload_size = round(args.max_upload_size * 1024 * 1024)
    self.app = web.Application(client_max_size=max_upload_size, middlewares=middlewares)
    self.sockets = {}

    # An explicit front-end root takes precedence over the managed front end.
    if args.front_end_root is None:
        self.web_root = FrontendManager.init_frontend(args.front_end_version)
    else:
        self.web_root = args.front_end_root
    logging.info(f"[Prompt Server] web root: {self.web_root}")

    routes = web.RouteTableDef()
    self.routes = routes
    self.last_node_id = None
    self.client_id = None

    self.on_prompt_handlers = []
2023-02-25 20:57:40 +00:00
@routes.get('/ws')
async def websocket_handler(request):
    """Accept a websocket client and keep it registered until it disconnects.

    The client id comes from the ``clientId`` query parameter; a fresh one is
    generated when it is absent or empty.
    """
    ws = web.WebSocketResponse()
    await ws.prepare(request)

    sid = request.rel_url.query.get('clientId', '')
    if not sid:
        sid = uuid.uuid4().hex
    else:
        # Reusing existing session, remove old
        self.sockets.pop(sid, None)

    self.sockets[sid] = ws

    try:
        # Send initial state to the new client
        await self.send("status", {"status": self.get_queue_info(), 'sid': sid}, sid)
        # On reconnect if we are the currently executing client send the current node
        if self.client_id == sid and self.last_node_id is not None:
            await self.send("executing", {"node": self.last_node_id}, sid)

        # Incoming messages are only inspected for errors.
        async for msg in ws:
            if msg.type == aiohttp.WSMsgType.ERROR:
                logging.warning('ws connection closed with exception %s' % ws.exception())
    finally:
        self.sockets.pop(sid, None)
    return ws
@routes.get("/")
async def get_root(request):
    """Serve ``index.html`` from the web root with caching disabled."""
    index_path = os.path.join(self.web_root, "index.html")
    response = web.FileResponse(index_path)
    for header_name, header_value in (
        ('Cache-Control', 'no-cache'),
        ("Pragma", "no-cache"),
        ("Expires", "0"),
    ):
        response.headers[header_name] = header_value
    return response
2023-02-25 18:36:29 -05:00
2023-03-12 21:36:42 +00:00
@routes.get("/embeddings")
async def get_embeddings(request):
    """Return the embedding file names, without extension, as a JSON list.

    Declared ``async`` because aiohttp deprecates plain-function handlers.
    The single argument is the incoming request (previously misnamed
    ``self``); it is not used.
    """
    embeddings = folder_paths.get_filename_list("embeddings")
    return web.json_response([os.path.splitext(name)[0] for name in embeddings])
2024-09-17 09:22:05 +01:00
@routes.get("/models")
async def list_model_types(request):
    """Return the keys of ``folder_paths.folder_names_and_paths`` as a JSON list.

    Declared ``async`` because aiohttp deprecates plain-function handlers.
    """
    model_types = list(folder_paths.folder_names_and_paths.keys())
    return web.json_response(model_types)
2023-03-12 21:36:42 +00:00
2024-08-20 23:04:42 -07:00
@routes.get("/models/{folder}")
async def get_models(request):
    """Return the file list of one model folder as JSON, or 404 if unknown."""
    folder = request.match_info.get("folder", None)
    if folder not in folder_paths.folder_names_and_paths:
        return web.Response(status=404)
    return web.json_response(folder_paths.get_filename_list(folder))
2023-03-03 19:05:39 +00:00
@routes.get("/extensions")
async def get_extensions(request):
    """Return the URL paths of all front-end extension ``.js`` files as JSON.

    Covers ``extensions/**/*.js`` under the web root, then every ``.js`` file
    under each directory in ``nodes.EXTENSION_WEB_DIRS``, which is exposed
    under ``/extensions/<quoted name>/``.
    """
    core_pattern = os.path.join(glob.escape(self.web_root), 'extensions/**/*.js')
    # Backslashes are replaced so the result is a URL path.
    extensions = [
        "/" + os.path.relpath(path, self.web_root).replace("\\", "/")
        for path in glob.glob(core_pattern, recursive=True)
    ]

    for name, web_dir in nodes.EXTENSION_WEB_DIRS.items():
        node_pattern = os.path.join(glob.escape(web_dir), '**/*.js')
        for path in glob.glob(node_pattern, recursive=True):
            extensions.append(
                "/extensions/" + urllib.parse.quote(name) + "/"
                + os.path.relpath(path, web_dir).replace("\\", "/"))

    return web.json_response(extensions)
2023-03-03 19:05:39 +00:00
2023-05-09 03:37:36 +09:00
def get_dir_by_type(dir_type):
    """Map a directory type name to its directory.

    Args:
        dir_type: ``"input"``, ``"temp"``, ``"output"`` or ``None``
            (``None`` is treated as ``"input"``).

    Returns:
        ``(directory, dir_type)``. ``directory`` is ``None`` when
        ``dir_type`` is not one of the known names; previously that case
        raised ``UnboundLocalError``.
    """
    if dir_type is None:
        dir_type = "input"

    if dir_type == "input":
        type_dir = folder_paths.get_input_directory()
    elif dir_type == "temp":
        type_dir = folder_paths.get_temp_directory()
    elif dir_type == "output":
        type_dir = folder_paths.get_output_directory()
    else:
        # Unknown type: report "no directory" instead of crashing.
        type_dir = None

    return type_dir, dir_type
2024-07-02 01:32:23 -04:00
2024-07-02 01:30:33 -04:00
def compare_image_hash(filepath, image):
    """Tell whether the file at ``filepath`` has the same content as ``image``.

    Used to detect an identical upload that already exists (fix for #3465).

    Args:
        filepath: path of the possibly existing file.
        image: uploaded field whose ``file`` attribute is a readable stream;
            the stream is rewound to position 0 after hashing.

    Returns:
        ``True`` when the file exists and both digests match, else ``False``.
    """
    hasher = node_helpers.hasher()
    if not os.path.exists(filepath):
        return False

    existing_digest = hasher()
    upload_digest = hasher()
    with open(filepath, "rb") as existing:
        existing_digest.update(existing.read())
        upload_digest.update(image.file.read())
        image.file.seek(0)
    return existing_digest.hexdigest() == upload_digest.hexdigest()
2024-07-02 01:32:23 -04:00
2023-05-08 14:13:06 -04:00
def image_upload(post, image_save_function=None):
    """Store an uploaded image and describe where it was stored.

    Args:
        post: form data; reads ``image``, ``overwrite``, ``type`` and
            ``subfolder``.
        image_save_function: optional ``callable(image, post, filepath)``
            that writes the file itself. If it returns a value other than
            ``None`` (for example an error response), that value is returned
            to the client instead of the success payload.

    Returns:
        A JSON response with ``name``, ``subfolder`` and ``type`` on success;
        a 400 response when the image, its filename or the target directory
        is missing or the target path escapes the upload directory.
    """
    image = post.get("image")
    overwrite = post.get("overwrite")
    image_is_duplicate = False

    image_upload_type = post.get("type")
    upload_dir, image_upload_type = get_dir_by_type(image_upload_type)

    # A plain string field has no ``file`` attribute; treat it like a missing
    # image. An unresolved upload directory is a bad request as well.
    if upload_dir is None or not (image and getattr(image, "file", None)):
        return web.Response(status=400)

    filename = image.filename
    if not filename:
        return web.Response(status=400)

    subfolder = post.get("subfolder", "")
    full_output_folder = os.path.join(upload_dir, os.path.normpath(subfolder))
    filepath = os.path.abspath(os.path.join(full_output_folder, filename))

    # Reject any target that resolves outside the upload directory.
    if os.path.commonpath((upload_dir, filepath)) != upload_dir:
        return web.Response(status=400)

    # exist_ok avoids failing when a concurrent upload creates the folder first.
    os.makedirs(full_output_folder, exist_ok=True)

    split = os.path.splitext(filename)

    if overwrite not in ("true", "1"):
        # Without overwrite, pick the first free "name (i).ext" unless an
        # identical file already exists under the candidate name.
        i = 1
        while os.path.exists(filepath):
            # compare hash to prevent saving of duplicates with same name, fix for #3465
            if compare_image_hash(filepath, image):
                image_is_duplicate = True
                break
            filename = f"{split[0]} ({i}){split[1]}"
            filepath = os.path.join(full_output_folder, filename)
            i += 1

    if not image_is_duplicate:
        if image_save_function is not None:
            save_result = image_save_function(image, post, filepath)
            if save_result is not None:
                # The save function rejected the upload; do not report success.
                return save_result
        else:
            with open(filepath, "wb") as f:
                f.write(image.file.read())

    return web.json_response({"name": filename, "subfolder": subfolder, "type": image_upload_type})
2023-05-08 14:13:06 -04:00
@routes.post("/upload/image")
async def upload_image(request):
    """Handle an image upload by delegating the posted form to ``image_upload``."""
    form = await request.post()
    return image_upload(form)
2023-06-24 16:45:41 +09:00
2023-05-09 03:37:36 +09:00
@routes.post("/upload/mask")
async def upload_mask(request):
    """Handle a mask upload: save the referenced image with the mask's alpha.

    The form's ``original_ref`` field is JSON naming the original image
    (``filename`` and optional ``type`` / ``subfolder``). The uploaded
    image's alpha channel replaces the original's alpha, and the result is
    written to the path chosen by ``image_upload``.
    """
    post = await request.post()

    def image_save_function(image, post, filepath):
        """Write the original image with the uploaded alpha to ``filepath``.

        Returns an error response when the reference is rejected, otherwise
        ``None``. Nothing is written if the referenced file does not exist.
        """
        original_ref = json.loads(post.get("original_ref"))
        filename, output_dir = folder_paths.annotated_filepath(original_ref['filename'])

        # validation for security: prevent accessing arbitrary path
        # (an empty filename is rejected too instead of raising IndexError)
        if not filename or filename[0] == '/' or '..' in filename:
            return web.Response(status=400)

        if output_dir is None:
            dir_type = original_ref.get("type", "output")
            output_dir = folder_paths.get_directory_by_type(dir_type)

        if output_dir is None:
            return web.Response(status=400)

        if original_ref.get("subfolder", "") != "":
            full_output_dir = os.path.join(output_dir, original_ref["subfolder"])
            # The subfolder must stay inside the resolved directory.
            if os.path.commonpath((os.path.abspath(full_output_dir), output_dir)) != output_dir:
                return web.Response(status=403)
            output_dir = full_output_dir

        file = os.path.join(output_dir, filename)

        if os.path.isfile(file):
            with Image.open(file) as original_pil:
                # Carry the original's text chunks over to the new file.
                metadata = PngInfo()
                if hasattr(original_pil, 'text'):
                    for key in original_pil.text:
                        metadata.add_text(key, original_pil.text[key])
                original_pil = original_pil.convert('RGBA')
                mask_pil = Image.open(image.file).convert('RGBA')

                # alpha copy
                new_alpha = mask_pil.getchannel('A')
                original_pil.putalpha(new_alpha)
                original_pil.save(filepath, compress_level=4, pnginfo=metadata)

    return image_upload(post, image_save_function)
2023-03-08 22:07:44 +00:00
2023-03-12 19:51:39 +01:00
@routes.get("/view")
async def view_image(request):
    """Serve an image, optionally as a re-encoded preview or a single channel set.

    Query parameters read: ``filename`` (required), ``type`` (default
    ``"output"``), ``subfolder``, ``preview`` (``"<format>;<quality>"``) and
    ``channel`` (``"rgb"``, ``"a"``; anything else serves the file as is).

    Returns 400 for a rejected filename or unknown directory type, 403 for a
    subfolder outside the resolved directory, 404 when ``filename`` is
    absent or the file does not exist.
    """
    query = request.rel_url.query
    if "filename" in query:
        filename = query["filename"]
        filename, output_dir = folder_paths.annotated_filepath(filename)

        # validation for security: prevent accessing arbitrary path
        # (an empty filename is rejected too instead of raising IndexError)
        if not filename or filename[0] == '/' or '..' in filename:
            return web.Response(status=400)

        if output_dir is None:
            dir_type = query.get("type", "output")
            output_dir = folder_paths.get_directory_by_type(dir_type)

        if output_dir is None:
            return web.Response(status=400)

        if "subfolder" in query:
            full_output_dir = os.path.join(output_dir, query["subfolder"])
            # The subfolder must stay inside the resolved directory.
            if os.path.commonpath((os.path.abspath(full_output_dir), output_dir)) != output_dir:
                return web.Response(status=403)
            output_dir = full_output_dir

        filename = os.path.basename(filename)
        file = os.path.join(output_dir, filename)

        if os.path.isfile(file):
            disposition = {"Content-Disposition": f"filename=\"{filename}\""}

            if 'preview' in query:
                with Image.open(file) as img:
                    preview_info = query['preview'].split(';')
                    image_format = preview_info[0]
                    # Unsupported formats, or any request involving the alpha
                    # channel, are encoded as webp.
                    if image_format not in ['webp', 'jpeg'] or 'a' in query.get('channel', ''):
                        image_format = 'webp'

                    quality = 90
                    if preview_info[-1].isdigit():
                        quality = int(preview_info[-1])

                    buffer = BytesIO()
                    if image_format in ['jpeg'] or query.get('channel', '') == 'rgb':
                        img = img.convert("RGB")
                    img.save(buffer, format=image_format, quality=quality)
                    buffer.seek(0)

                    return web.Response(body=buffer.read(), content_type=f'image/{image_format}',
                                        headers=disposition)

            if 'channel' not in query:
                channel = 'rgba'
            else:
                channel = query["channel"]

            if channel == 'rgb':
                with Image.open(file) as img:
                    if img.mode == "RGBA":
                        r, g, b, a = img.split()
                        new_img = Image.merge('RGB', (r, g, b))
                    else:
                        new_img = img.convert("RGB")

                    buffer = BytesIO()
                    new_img.save(buffer, format='PNG')
                    buffer.seek(0)

                    return web.Response(body=buffer.read(), content_type='image/png',
                                        headers=disposition)

            elif channel == 'a':
                with Image.open(file) as img:
                    if img.mode == "RGBA":
                        _, _, _, a = img.split()
                    else:
                        # No alpha channel in the source: use a fully opaque one.
                        a = Image.new('L', img.size, 255)

                    # alpha img
                    alpha_img = Image.new('RGBA', img.size)
                    alpha_img.putalpha(a)
                    alpha_buffer = BytesIO()
                    alpha_img.save(alpha_buffer, format='PNG')
                    alpha_buffer.seek(0)

                    return web.Response(body=alpha_buffer.read(), content_type='image/png',
                                        headers=disposition)
            else:
                return web.FileResponse(file, headers=disposition)

    return web.Response(status=404)
2023-02-25 18:36:29 -05:00
2023-05-29 02:48:50 -04:00
@routes.get("/view_metadata/{folder_name}")
async def view_metadata(request):
    """Return the ``__metadata__`` entry of a ``.safetensors`` header as JSON.

    Responds 404 when the folder name or ``filename`` query parameter is
    missing, the file name does not end in ``.safetensors``, the file cannot
    be located, no header is returned, or the header has no ``__metadata__``.
    """
    not_found = 404

    folder_name = request.match_info.get("folder_name", None)
    if folder_name is None:
        return web.Response(status=not_found)

    query = request.rel_url.query
    if "filename" not in query:
        return web.Response(status=not_found)

    filename = query["filename"]
    if not filename.endswith(".safetensors"):
        return web.Response(status=not_found)

    safetensors_path = folder_paths.get_full_path(folder_name, filename)
    if safetensors_path is None:
        return web.Response(status=not_found)

    header = comfy.utils.safetensors_header(safetensors_path, max_size=1024 * 1024)
    if header is None:
        return web.Response(status=not_found)

    parsed_header = json.loads(header)
    if "__metadata__" not in parsed_header:
        return web.Response(status=not_found)
    return web.json_response(parsed_header["__metadata__"])
2023-06-01 23:26:23 -05:00
@routes.get("/system_stats")
async def system_stats(request):
    """Return system and device information as JSON.

    The ``system`` section reports OS name, RAM totals, ComfyUI / Python /
    PyTorch versions, whether the interpreter lives in a ``python_embeded``
    directory, and ``sys.argv``. ``devices`` holds one entry for the torch
    device reported by ``comfy.model_management``.
    """
    mm = comfy.model_management

    device = mm.get_torch_device()
    device_name = mm.get_torch_device_name(device)

    cpu_device = mm.torch.device("cpu")
    ram_total = mm.get_total_memory(cpu_device)
    ram_free = mm.get_free_memory(cpu_device)

    vram_total, torch_vram_total = mm.get_total_memory(device, torch_total_too=True)
    vram_free, torch_vram_free = mm.get_free_memory(device, torch_free_too=True)

    # Name of the directory that contains the running interpreter.
    executable_dir_name = os.path.split(os.path.split(sys.executable)[0])[1]

    system_section = {
        "os": os.name,
        "ram_total": ram_total,
        "ram_free": ram_free,
        "comfyui_version": get_comfyui_version(),
        "python_version": sys.version,
        "pytorch_version": mm.torch_version,
        "embedded_python": executable_dir_name == "python_embeded",
        "argv": sys.argv
    }
    device_section = {
        "name": device_name,
        "type": device.type,
        "index": device.index,
        "vram_total": vram_total,
        "vram_free": vram_free,
        "torch_vram_total": torch_vram_total,
        "torch_vram_free": torch_vram_free,
    }
    return web.json_response({"system": system_section, "devices": [device_section]})
2023-02-25 20:57:40 +00:00
@routes.get("/prompt")
async def get_prompt(request):
    """Return the current queue info (see ``get_queue_info``) as JSON."""
    queue_info = self.get_queue_info()
    return web.json_response(queue_info)
2023-02-25 18:36:29 -05:00
2023-05-19 22:40:28 -04:00
def node_info(node_class):
    """Describe one registered node class as a JSON-serialisable dict.

    Args:
        node_class: key into ``nodes.NODE_CLASS_MAPPINGS``.

    Returns:
        A dict with ``input``, ``input_order``, ``output``,
        ``output_is_list``, ``output_name``, ``name``, ``display_name``,
        ``description``, ``python_module``, ``category`` and
        ``output_node``; ``output_tooltips``, ``deprecated`` and
        ``experimental`` are added only when the class defines/sets them.

    Raises:
        KeyError: if ``node_class`` is not registered.
    """
    obj_class = nodes.NODE_CLASS_MAPPINGS[node_class]
    # Evaluate INPUT_TYPES() once so 'input' and 'input_order' always agree
    # and the call is not paid for twice.
    input_types = obj_class.INPUT_TYPES()

    info = {}
    info['input'] = input_types
    info['input_order'] = {key: list(value.keys()) for key, value in input_types.items()}
    info['output'] = obj_class.RETURN_TYPES
    info['output_is_list'] = (
        obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST')
        else [False] * len(obj_class.RETURN_TYPES)
    )
    # Output names default to the output types.
    info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
    info['name'] = node_class
    info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS.get(node_class, node_class)
    info['description'] = obj_class.DESCRIPTION if hasattr(obj_class, 'DESCRIPTION') else ''
    info['python_module'] = getattr(obj_class, "RELATIVE_PYTHON_MODULE", "nodes")
    info['category'] = obj_class.CATEGORY if hasattr(obj_class, 'CATEGORY') else 'sd'
    info['output_node'] = bool(hasattr(obj_class, 'OUTPUT_NODE') and obj_class.OUTPUT_NODE == True)

    if hasattr(obj_class, 'OUTPUT_TOOLTIPS'):
        info['output_tooltips'] = obj_class.OUTPUT_TOOLTIPS
    if getattr(obj_class, "DEPRECATED", False):
        info['deprecated'] = True
    if getattr(obj_class, "EXPERIMENTAL", False):
        info['experimental'] = True
    return info
2023-02-25 20:57:40 +00:00
@routes.get("/object_info")
async def get_object_info(request):
    """Return ``node_info`` for every registered node class as JSON.

    Nodes whose info cannot be built are logged with a traceback and left
    out of the result. The whole scan runs inside
    ``folder_paths.cache_helper``.
    """
    with folder_paths.cache_helper:
        node_infos = {}
        for x in nodes.NODE_CLASS_MAPPINGS:
            try:
                node_infos[x] = node_info(x)
            except Exception:
                logging.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.")
                logging.error(traceback.format_exc())
        return web.json_response(node_infos)
2023-05-19 22:40:28 -04:00
@routes.get("/object_info/{node_class}")
async def get_object_info_node(request):
    """Return ``node_info`` for one node class; an empty object if unknown."""
    node_class = request.match_info.get("node_class", None)
    is_known = (node_class is not None) and (node_class in nodes.NODE_CLASS_MAPPINGS)
    payload = {node_class: node_info(node_class)} if is_known else {}
    return web.json_response(payload)
2023-02-25 18:36:29 -05:00
2023-02-25 20:57:40 +00:00
@routes.get("/history")
async def get_history(request):
    """Return the prompt history as JSON.

    The optional ``max_items`` query parameter is passed to the queue as an
    integer. A value that is not an integer yields a 400 response instead of
    an unhandled ``ValueError``.
    """
    max_items = request.rel_url.query.get("max_items", None)
    if max_items is not None:
        try:
            max_items = int(max_items)
        except ValueError:
            return web.Response(status=400)
    return web.json_response(self.prompt_queue.get_history(max_items=max_items))
2023-02-25 18:36:29 -05:00
2023-06-12 14:34:30 -04:00
# NOTE: this handler shares its name with the "/history" handler; each one is
# registered by its decorator when it is defined.
@routes.get("/history/{prompt_id}")
async def get_history(request):
    """Return the history for the ``prompt_id`` path segment as JSON."""
    wanted_id = request.match_info.get("prompt_id", None)
    history = self.prompt_queue.get_history(prompt_id=wanted_id)
    return web.json_response(history)
2023-02-25 20:57:40 +00:00
@routes.get("/queue")
async def get_queue(request):
    """Return the running and pending queue entries as JSON."""
    current_queue = self.prompt_queue.get_current_queue()
    return web.json_response({
        'queue_running': current_queue[0],
        'queue_pending': current_queue[1],
    })
2023-02-25 18:36:29 -05:00
2023-02-25 20:57:40 +00:00
@routes.post("/prompt")
async def post_prompt(request):
    """Validate a posted prompt and put it on the prompt queue.

    The JSON body is first passed through ``trigger_on_prompt``. Keys read:
    ``number``, ``front``, ``prompt``, ``extra_data`` and ``client_id``.

    Returns the prompt id, queue number and node errors on success, or a
    400 JSON error when the prompt is missing or fails validation.
    """
    logging.info("got prompt")
    json_data = self.trigger_on_prompt(await request.json())

    if "number" in json_data:
        number = float(json_data['number'])
    else:
        # No explicit number: use the server counter, negated when the
        # request sets "front", then advance the counter.
        number = self.number
        if "front" in json_data and json_data['front']:
            number = -number
        self.number += 1

    if "prompt" not in json_data:
        return web.json_response({"error": "no prompt", "node_errors": []}, status=400)

    prompt = json_data["prompt"]
    valid = execution.validate_prompt(prompt)

    extra_data = json_data["extra_data"] if "extra_data" in json_data else {}
    if "client_id" in json_data:
        extra_data["client_id"] = json_data["client_id"]

    if not valid[0]:
        logging.warning("invalid prompt: {}".format(valid[1]))
        return web.json_response({"error": valid[1], "node_errors": valid[3]}, status=400)

    prompt_id = str(uuid.uuid4())
    outputs_to_execute = valid[2]
    self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute))
    return web.json_response({"prompt_id": prompt_id, "number": number, "node_errors": valid[3]})
2023-02-25 20:57:40 +00:00
@routes.post("/queue")
async def post_queue(request):
    """Clear the queue and/or delete queue items named in the JSON body.

    A truthy ``clear`` wipes the queue; ``delete`` is a list of ids, and
    every queue item whose element at index 1 equals an id is deleted.
    Always answers 200.
    """
    json_data = await request.json()

    if "clear" in json_data and json_data["clear"]:
        self.prompt_queue.wipe_queue()

    if "delete" in json_data:
        for id_to_delete in json_data['delete']:
            matches_id = lambda a: a[1] == id_to_delete
            self.prompt_queue.delete_queue_item(matches_id)

    return web.Response(status=200)
2023-03-03 15:20:49 +00:00
@routes.post("/interrupt")
async def post_interrupt(request):
    """Call ``nodes.interrupt_processing()`` and answer 200."""
    nodes.interrupt_processing()
    ok = web.Response(status=200)
    return ok
2024-01-04 14:28:11 -05:00
@routes.post("/free")
async def post_free(request):
    """Set the ``unload_models`` / ``free_memory`` queue flags from the JSON body.

    A flag is only set when its value in the body is truthy. Always
    answers 200.
    """
    json_data = await request.json()
    for flag_name in ("unload_models", "free_memory"):
        flag_value = json_data.get(flag_name, False)
        if flag_value:
            self.prompt_queue.set_flag(flag_name, flag_value)
    return web.Response(status=200)
2023-02-25 20:57:40 +00:00
@routes.post("/history")
async def post_history(request):
    """Clear the history and/or delete history items named in the JSON body.

    A truthy ``clear`` wipes the history; ``delete`` is a list of ids passed
    one by one to ``delete_history_item``. Always answers 200.
    """
    json_data = await request.json()

    if "clear" in json_data and json_data["clear"]:
        self.prompt_queue.wipe_history()

    if "delete" in json_data:
        for id_to_delete in json_data['delete']:
            self.prompt_queue.delete_history_item(id_to_delete)

    return web.Response(status=200)
2024-08-13 12:48:52 -07:00
# Internal route. Should not be depended upon and is subject to change at any time.
# TODO(robinhuang): Move to internal route table class once we refactor PromptServer to pass around Websocket.
@routes.post("/internal/models/download")
async def download_handler(request):
    """Download a model and return the final download status as JSON.

    JSON body keys: ``url``, ``model_directory``, ``model_filename`` and the
    optional ``progress_interval`` (seconds between progress reports,
    default 1.0). Progress is pushed with ``send_json`` under the
    ``download_progress`` event.

    Answers 400 when a required key is missing or empty, and 500 when the
    client session has not been created yet.
    """
    async def report_progress(filename: str, status: DownloadModelStatus):
        # Send the status together with the download path.
        payload = status.to_dict()
        payload['download_path'] = filename
        await self.send_json("download_progress", payload)

    data = await request.json()
    url = data.get('url')
    model_directory = data.get('model_directory')
    model_filename = data.get('model_filename')
    progress_interval = data.get('progress_interval', 1.0)  # In seconds, how often to report download progress.

    if not (url and model_directory and model_filename):
        return web.json_response({"status": "error", "message": "Missing URL or folder path or filename"}, status=400)

    session = self.client_session
    if session is None:
        logging.error("Client session is not initialized")
        return web.Response(status=500)

    task = asyncio.create_task(download_model(
        lambda url: session.get(url),
        model_filename, url, model_directory, report_progress, progress_interval))
    result = await task
    return web.json_response(result.to_dict())
async def setup(self):
    """Create the aiohttp client session stored on self.client_session.

    The session is built with a ClientTimeout whose total is None.
    """
    self.client_session = aiohttp.ClientSession(
        timeout=aiohttp.ClientTimeout(total=None),  # no timeout
    )
2024-06-19 10:39:17 -04:00
2023-04-01 12:44:29 +01:00
def add_routes(self):
    """Register all routes and static directories on the application.

    In order: the user manager's routes are added to self.routes, the
    internal sub-application is mounted under '/internal', every
    non-static route is registered a second time under an '/api' prefix,
    the unprefixed routes are registered, each extension web directory is
    served under '/extensions/<quoted name>', and finally the web root is
    served at '/'.
    """
    self.user_manager.add_routes(self.routes)
    self.app.add_subapp('/internal', self.internal_routes.get_app())

    # Prefix every route with /api for easier matching for delegation.
    # This is very useful for frontend dev server, which need to forward
    # everything except serving of static files.
    # Currently both the old endpoints without prefix and new endpoints with
    # prefix are supported.
    api_routes = web.RouteTableDef()
    for route in self.routes:
        # Custom nodes might add extra static routes. Only process non-static
        # routes to add /api prefix.
        if not isinstance(route, web.RouteDef):
            continue
        # Route options must be given to route(); the decorator it returns
        # only accepts the handler, so forwarding them there raises a
        # TypeError for any route declared with options. "name" is dropped
        # for the alias because route names have to be unique within the
        # router and the unprefixed route keeps the original name.
        alias_kwargs = {key: value for key, value in route.kwargs.items() if key != "name"}
        api_routes.route(route.method, "/api" + route.path, **alias_kwargs)(route.handler)

    self.app.add_routes(api_routes)
    self.app.add_routes(self.routes)

    for name, web_dir in nodes.EXTENSION_WEB_DIRS.items():
        extension_prefix = '/extensions/' + urllib.parse.quote(name)
        self.app.add_routes([
            web.static(extension_prefix, web_dir),
        ])

    self.app.add_routes([
        web.static('/', self.web_root),
    ])
def get_queue_info(self):
    """Return the queue status as {'exec_info': {'queue_remaining': <value>}}.

    The value is whatever self.prompt_queue.get_tasks_remaining() returns.
    """
    remaining = self.prompt_queue.get_tasks_remaining()
    return {'exec_info': {'queue_remaining': remaining}}
async def send(self, event, data, sid=None):
    """Dispatch one outgoing message to the matching send_* method.

    Unencoded preview image events go to send_image, bytes/bytearray
    payloads go to send_bytes, and everything else goes to send_json.
    """
    if event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE:
        await self.send_image(data, sid=sid)
        return

    if isinstance(data, (bytes, bytearray)):
        await self.send_bytes(event, data, sid)
        return

    await self.send_json(event, data, sid)
def encode_bytes(self, event, data):
    """Build a binary message: a 4-byte big-endian event id followed by data.

    Raises RuntimeError when event is not an int. Returns a bytearray.
    """
    if not isinstance(event, int):
        raise RuntimeError(f"Binary event types must be integers, got {event}")

    # ">I" packs the event id as a big-endian unsigned 32-bit header.
    message = bytearray(struct.pack(">I", event))
    message.extend(data)
    return message
2023-07-19 17:37:27 -04:00
async def send_image(self, image_data, sid=None):
    """Encode a preview image and send it as a PREVIEW_IMAGE binary event.

    image_data is indexed as (image type, image, max size). When the max
    size is not None the image is first fitted into a max_size x max_size
    box with ImageOps.contain. The payload is a 4-byte big-endian type
    number (2 for "PNG", 1 otherwise) followed by the image saved in the
    requested format.
    """
    image_type = image_data[0]
    image = image_data[1]
    max_size = image_data[2]

    if max_size is not None:
        # Use Image.Resampling when this Pillow provides it, else the
        # module-level constant.
        resampling = Image.Resampling.BILINEAR if hasattr(Image, 'Resampling') else Image.ANTIALIAS
        image = ImageOps.contain(image, (max_size, max_size), resampling)

    # "JPEG" and any unrecognised type share type number 1.
    type_num = 2 if image_type == "PNG" else 1

    buffer = BytesIO()
    buffer.write(struct.pack(">I", type_num))
    image.save(buffer, format=image_type, quality=95, compress_level=1)

    await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE, buffer.getvalue(), sid=sid)
2023-05-30 20:43:29 -05:00
async def send_bytes(self, event, data, sid=None):
    """Send a binary message to every socket, or only to the socket sid.

    The message is built with encode_bytes. Nothing is sent when sid is
    given but is not a key of self.sockets.
    """
    message = self.encode_bytes(event, data)

    if sid is None:
        # Snapshot the sockets so the loop does not iterate the live dict.
        targets = list(self.sockets.values())
    elif sid in self.sockets:
        targets = [self.sockets[sid]]
    else:
        return

    for ws in targets:
        await send_socket_catch_exception(ws.send_bytes, message)
2023-05-30 20:43:29 -05:00
async def send_json(self, event, data, sid=None):
    """Send {"type": event, "data": data} to every socket, or only to sid.

    Nothing is sent when sid is given but is not a key of self.sockets.
    """
    message = {"type": event, "data": data}

    if sid is None:
        # Snapshot the sockets so the loop does not iterate the live dict.
        targets = list(self.sockets.values())
    elif sid in self.sockets:
        targets = [self.sockets[sid]]
    else:
        return

    for ws in targets:
        await send_socket_catch_exception(ws.send_json, message)
2023-02-25 20:57:40 +00:00
def send_sync(self, event, data, sid=None):
    """Queue an (event, data, sid) message from synchronous code.

    The tuple is handed to self.messages.put_nowait through the loop's
    call_soon_threadsafe.
    """
    message = (event, data, sid)
    self.loop.call_soon_threadsafe(self.messages.put_nowait, message)
2023-02-25 18:36:29 -05:00
2023-02-25 20:57:40 +00:00
def queue_updated(self):
    """Queue a "status" message carrying the current queue info."""
    status = {"status": self.get_queue_info()}
    self.send_sync("status", status)
async def publish_loop(self):
    """Forever take queued messages from self.messages and send them.

    Each queued item is unpacked as the positional arguments of send.
    """
    while True:
        pending = await self.messages.get()
        await self.send(*pending)
2023-03-12 15:44:16 -04:00
async def start(self, address, port, verbose=True, call_on_start=None):
    """Start serving the application on address:port.

    TLS is used when both args.tls_keyfile and args.tls_certfile are set.
    The address and port are stored on the instance once the site has
    started. With verbose set, the GUI URL is logged. When call_on_start
    is given it is called with (scheme, address, port).
    """
    runner = web.AppRunner(self.app, access_log=None)
    await runner.setup()

    if args.tls_keyfile and args.tls_certfile:
        scheme = "https"
        # NOTE(review): SSLContext() appears to ignore extra keyword
        # arguments, so verify_mode may have no effect here - confirm.
        ssl_ctx = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_SERVER, verify_mode=ssl.CERT_NONE)
        ssl_ctx.load_cert_chain(certfile=args.tls_certfile,
                                keyfile=args.tls_keyfile)
    else:
        scheme = "http"
        ssl_ctx = None

    site = web.TCPSite(runner, address, port, ssl_context=ssl_ctx)
    await site.start()

    self.address = address
    self.port = port

    if verbose:
        logging.info("Starting server\n")
        logging.info(f"To see the GUI go to: {scheme}://{address}:{port}")

    if call_on_start is not None:
        call_on_start(scheme, address, port)
2023-03-12 15:44:16 -04:00
2023-08-28 13:52:22 +09:00
def add_on_prompt_handler(self, handler):
    """Append handler to the handlers that trigger_on_prompt runs."""
    handlers = self.on_prompt_handlers
    handlers.append(handler)
def trigger_on_prompt(self, json_data):
    """Run json_data through every registered on-prompt handler in order.

    Each handler receives the current data and its return value becomes
    the input of the next handler. A handler that raises is skipped: the
    failure and its traceback are logged as warnings and the data it was
    given is passed on unchanged. Returns the data after the last handler.
    """
    for handler in self.on_prompt_handlers:
        try:
            json_data = handler(json_data)
        except Exception:
            # Best effort: one faulty handler must not block the prompt.
            logging.warning("[ERROR] An error occurred during the on_prompt_handler processing")
            logging.warning(traceback.format_exc())

    return json_data