Compare commits

..

8 Commits

Author SHA1 Message Date
karrix
02f499a4dc add: name 2024-10-23 19:07:42 +08:00
karrix
67d27bd4aa add: any output 2024-10-23 18:47:58 +08:00
karrix
8ee2f88e72 tweak 2024-10-22 17:51:04 +08:00
karrix
c17a029173 tweak: back to support all 2024-10-22 17:34:33 +08:00
karrix
c82f0c2bf0 tweak 2024-10-22 14:56:06 +08:00
karrix
fdbf24207f tweak 2024-10-22 14:04:37 +08:00
karrix
0a9d0d3e3e tweak 2024-10-22 13:57:19 +08:00
karrix
d41c4de352 init 2024-10-22 12:54:33 +08:00
17 changed files with 666 additions and 1577 deletions

3
.gitignore vendored
View File

@ -1,3 +1,2 @@
__pycache__ __pycache__
.DS_Store .DS_Store
file-hash-cache.json

View File

@ -0,0 +1,504 @@
from typing import Union, Optional, Dict, List
from pydantic import BaseModel, Field, field_validator
from fastapi import FastAPI, HTTPException, WebSocket, BackgroundTasks, WebSocketDisconnect
from fastapi.responses import JSONResponse
from fastapi.logger import logger as fastapi_logger
import os
from enum import Enum
import json
import subprocess
import time
from contextlib import asynccontextmanager
import asyncio
import threading
import signal
import logging
from fastapi.logger import logger as fastapi_logger
import requests
from urllib.parse import parse_qs
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp, Scope, Receive, Send
from concurrent.futures import ThreadPoolExecutor
# executor = ThreadPoolExecutor(max_workers=5)
gunicorn_error_logger = logging.getLogger("gunicorn.error")
gunicorn_logger = logging.getLogger("gunicorn")
uvicorn_access_logger = logging.getLogger("uvicorn.access")
uvicorn_access_logger.handlers = gunicorn_error_logger.handlers
fastapi_logger.handlers = gunicorn_error_logger.handlers
if __name__ != "__main__":
fastapi_logger.setLevel(gunicorn_logger.level)
else:
fastapi_logger.setLevel(logging.DEBUG)
logger = logging.getLogger("uvicorn")
logger.setLevel(logging.INFO)
last_activity_time = time.time()
global_timeout = 60 * 4
machine_id_websocket_dict = {}
machine_id_status = {}
fly_instance_id = os.environ.get('FLY_ALLOC_ID', 'local').split('-')[0]
class FlyReplayMiddleware(BaseHTTPMiddleware):
"""
If the wrong instance was picked by the fly.io load balancer we use the fly-replay header
to repeat the request again on the right instance.
This only works if the right instance is provided as a query_string parameter.
"""
def __init__(self, app: ASGIApp) -> None:
self.app = app
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
query_string = scope.get('query_string', b'').decode()
query_params = parse_qs(query_string)
target_instance = query_params.get(
'fly_instance_id', [fly_instance_id])[0]
async def send_wrapper(message):
if target_instance != fly_instance_id:
if message['type'] == 'websocket.close' and 'Invalid session' in message['reason']:
# fly.io only seems to look at the fly-replay header if websocket is accepted
message = {'type': 'websocket.accept'}
if 'headers' not in message:
message['headers'] = []
message['headers'].append(
[b'fly-replay', f'instance={target_instance}'.encode()])
await send(message)
await self.app(scope, receive, send_wrapper)
async def check_inactivity():
global last_activity_time
while True:
# logger.info("Checking inactivity...")
if time.time() - last_activity_time > global_timeout:
if len(machine_id_status) == 0:
# The application has been inactive for more than 60 seconds.
# Scale it down to zero here.
logger.info(
f"No activity for {global_timeout} seconds, exiting...")
# os._exit(0)
os.kill(os.getpid(), signal.SIGINT)
break
else:
pass
# logger.info(f"Timeout but still in progress")
await asyncio.sleep(1) # Check every second
@asynccontextmanager
async def lifespan(app: FastAPI):
thread = run_in_new_thread(check_inactivity())
yield
logger.info("Cancelling")
#
app = FastAPI(lifespan=lifespan)
app.add_middleware(FlyReplayMiddleware)
# MODAL_ORG = os.environ.get("MODAL_ORG")
@app.get("/")
def read_root():
global last_activity_time
last_activity_time = time.time()
logger.info(f"Extended inactivity time to {global_timeout}")
return {"Hello": "World"}
# create a post route called /create takes in a json of example
# {
# name: "my first image",
# deps: {
# "comfyui": "d0165d819afe76bd4e6bdd710eb5f3e571b6a804",
# "git_custom_nodes": {
# "https://github.com/cubiq/ComfyUI_IPAdapter_plus": {
# "hash": "2ca0c6dd0b2ad64b1c480828638914a564331dcd",
# "disabled": true
# },
# "https://github.com/ltdrdata/ComfyUI-Manager.git": {
# "hash": "9c86f62b912f4625fe2b929c7fc61deb9d16f6d3",
# "disabled": false
# },
# },
# "file_custom_nodes": []
# }
# }
class GitCustomNodes(BaseModel):
hash: str
disabled: bool
class FileCustomNodes(BaseModel):
filename: str
disabled: bool
class Snapshot(BaseModel):
comfyui: str
git_custom_nodes: Dict[str, GitCustomNodes]
file_custom_nodes: List[FileCustomNodes]
class Model(BaseModel):
name: str
type: str
base: str
save_path: str
description: str
reference: str
filename: str
url: str
class GPUType(str, Enum):
T4 = "T4"
A10G = "A10G"
A100 = "A100"
L4 = "L4"
class Item(BaseModel):
machine_id: str
name: str
snapshot: Snapshot
models: List[Model]
callback_url: str
gpu: GPUType = Field(default=GPUType.T4)
@field_validator('gpu')
@classmethod
def check_gpu(cls, value):
if value not in GPUType.__members__:
raise ValueError(
f"Invalid GPU option. Choose from: {', '.join(GPUType.__members__.keys())}")
return GPUType(value)
@app.websocket("/ws/{machine_id}")
async def websocket_endpoint(websocket: WebSocket, machine_id: str):
await websocket.accept()
machine_id_websocket_dict[machine_id] = websocket
# Send existing logs
if machine_id in machine_logs_cache:
combined_logs = "\n".join(
log_entry['logs'] for log_entry in machine_logs_cache[machine_id])
await websocket.send_text(json.dumps({"event": "LOGS", "data": {
"machine_id": machine_id,
"logs": combined_logs,
"timestamp": time.time()
}}))
try:
while True:
data = await websocket.receive_text()
global last_activity_time
last_activity_time = time.time()
logger.info(f"Extended inactivity time to {global_timeout}")
# You can handle received messages here if needed
except WebSocketDisconnect:
if machine_id in machine_id_websocket_dict:
machine_id_websocket_dict.pop(machine_id)
# @app.get("/test")
# async def test():
# machine_id_status["123"] = True
# global last_activity_time
# last_activity_time = time.time()
# logger.info(f"Extended inactivity time to {global_timeout}")
# await asyncio.sleep(10)
# machine_id_status["123"] = False
# machine_id_status.pop("123")
# return {"Hello": "World"}
@app.post("/create")
async def create_machine(item: Item):
global last_activity_time
last_activity_time = time.time()
logger.info(f"Extended inactivity time to {global_timeout}")
if item.machine_id in machine_id_status and machine_id_status[item.machine_id]:
return JSONResponse(status_code=400, content={"error": "Build already in progress."})
# Run the building logic in a separate thread
# future = executor.submit(build_logic, item)
task = asyncio.create_task(build_logic(item))
return JSONResponse(status_code=200, content={"message": "Build Queued", "build_machine_instance_id": fly_instance_id})
class StopAppItem(BaseModel):
machine_id: str
def find_app_id(app_list, app_name):
for app in app_list:
if app['Name'] == app_name:
return app['App ID']
return None
@app.post("/stop-app")
async def stop_app(item: StopAppItem):
# cmd = f"modal app list | grep {item.machine_id} | awk -F '│' '{{print $2}}'"
cmd = f"modal app list --json"
env = os.environ.copy()
env["COLUMNS"] = "10000" # Set the width to a large value
find_id_process = await asyncio.subprocess.create_subprocess_shell(cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
env=env)
await find_id_process.wait()
stdout, stderr = await find_id_process.communicate()
if stdout:
app_id = stdout.decode().strip()
app_list = json.loads(app_id)
app_id = find_app_id(app_list, item.machine_id)
logger.info(f"cp_process stdout: {app_id}")
if stderr:
logger.info(f"cp_process stderr: {stderr.decode()}")
cp_process = await asyncio.subprocess.create_subprocess_exec("modal", "app", "stop", app_id,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,)
await cp_process.wait()
logger.info(f"Stopping app {item.machine_id}")
stdout, stderr = await cp_process.communicate()
if stdout:
logger.info(f"cp_process stdout: {stdout.decode()}")
if stderr:
logger.info(f"cp_process stderr: {stderr.decode()}")
if cp_process.returncode == 0:
return JSONResponse(status_code=200, content={"status": "success"})
else:
return JSONResponse(status_code=500, content={"status": "error", "error": stderr.decode()})
# Initialize the logs cache
machine_logs_cache = {}
async def build_logic(item: Item):
# Deploy to modal
folder_path = f"/app/builds/{item.machine_id}"
machine_id_status[item.machine_id] = True
# Ensure the os path is same as the current directory
# os.chdir(os.path.dirname(os.path.realpath(__file__)))
# print(
# f"builder - Current working directory: {os.getcwd()}"
# )
# Copy the app template
# os.system(f"cp -r template {folder_path}")
cp_process = await asyncio.subprocess.create_subprocess_exec("cp", "-r", "/app/src/template", folder_path)
await cp_process.wait()
# Write the config file
config = {
"name": item.name,
"deploy_test": os.environ.get("DEPLOY_TEST_FLAG", "False"),
"gpu": item.gpu,
"civitai_token": os.environ.get("CIVITAI_TOKEN", "")
}
with open(f"{folder_path}/config.py", "w") as f:
f.write("config = " + json.dumps(config))
with open(f"{folder_path}/data/snapshot.json", "w") as f:
f.write(item.snapshot.json())
with open(f"{folder_path}/data/models.json", "w") as f:
models_json_list = [model.dict() for model in item.models]
models_json_string = json.dumps(models_json_list)
f.write(models_json_string)
# os.chdir(folder_path)
# process = subprocess.Popen(f"modal deploy {folder_path}/app.py", stdout=subprocess.PIPE, stderr=subprocess.STDOUT, shell=True)
process = await asyncio.subprocess.create_subprocess_shell(
f"modal deploy app.py",
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
cwd=folder_path,
env={**os.environ, "COLUMNS": "10000"}
)
url = None
if item.machine_id not in machine_logs_cache:
machine_logs_cache[item.machine_id] = []
machine_logs = machine_logs_cache[item.machine_id]
url_queue = asyncio.Queue()
async def read_stream(stream, isStderr, url_queue: asyncio.Queue):
while True:
line = await stream.readline()
if line:
l = line.decode('utf-8').strip()
if l == "":
continue
if not isStderr:
logger.info(l)
machine_logs.append({
"logs": l,
"timestamp": time.time()
})
if item.machine_id in machine_id_websocket_dict:
await machine_id_websocket_dict[item.machine_id].send_text(json.dumps({"event": "LOGS", "data": {
"machine_id": item.machine_id,
"logs": l,
"timestamp": time.time()
}}))
if "Created comfyui_api =>" in l or ((l.startswith("https://") or l.startswith("")) and l.endswith(".modal.run")):
if "Created comfyui_api =>" in l:
url = l.split("=>")[1].strip()
# making sure it is a url
elif "comfyui-api" in l:
# Some case it only prints the url on a blank line
if l.startswith(""):
url = l.split("")[1].strip()
else:
url = l
if url:
machine_logs.append({
"logs": f"App image built, url: {url}",
"timestamp": time.time()
})
await url_queue.put(url)
if item.machine_id in machine_id_websocket_dict:
await machine_id_websocket_dict[item.machine_id].send_text(json.dumps({"event": "LOGS", "data": {
"machine_id": item.machine_id,
"logs": f"App image built, url: {url}",
"timestamp": time.time()
}}))
await machine_id_websocket_dict[item.machine_id].send_text(json.dumps({"event": "FINISHED", "data": {
"status": "succuss",
}}))
else:
# is error
logger.error(l)
machine_logs.append({
"logs": l,
"timestamp": time.time()
})
if item.machine_id in machine_id_websocket_dict:
await machine_id_websocket_dict[item.machine_id].send_text(json.dumps({"event": "LOGS", "data": {
"machine_id": item.machine_id,
"logs": l,
"timestamp": time.time()
}}))
await machine_id_websocket_dict[item.machine_id].send_text(json.dumps({"event": "FINISHED", "data": {
"status": "failed",
}}))
else:
break
stdout_task = asyncio.create_task(
read_stream(process.stdout, False, url_queue))
stderr_task = asyncio.create_task(
read_stream(process.stderr, True, url_queue))
await asyncio.wait([stdout_task, stderr_task])
# Wait for the subprocess to finish
await process.wait()
if not url_queue.empty():
# The queue is not empty, you can get an item
url = await url_queue.get()
# Close the ws connection and also pop the item
if item.machine_id in machine_id_websocket_dict and machine_id_websocket_dict[item.machine_id] is not None:
await machine_id_websocket_dict[item.machine_id].close()
if item.machine_id in machine_id_websocket_dict:
machine_id_websocket_dict.pop(item.machine_id)
if item.machine_id in machine_id_status:
machine_id_status[item.machine_id] = False
# Check for errors
if process.returncode != 0:
logger.info("An error occurred.")
# Send a post request with the json body machine_id to the callback url
machine_logs.append({
"logs": "Unable to build the app image.",
"timestamp": time.time()
})
requests.post(item.callback_url, json={
"machine_id": item.machine_id, "build_log": json.dumps(machine_logs)})
if item.machine_id in machine_logs_cache:
del machine_logs_cache[item.machine_id]
return
# return JSONResponse(status_code=400, content={"error": "Unable to build the app image."})
# app_suffix = "comfyui-app"
if url is None:
machine_logs.append({
"logs": "App image built, but url is None, unable to parse the url.",
"timestamp": time.time()
})
requests.post(item.callback_url, json={
"machine_id": item.machine_id, "build_log": json.dumps(machine_logs)})
if item.machine_id in machine_logs_cache:
del machine_logs_cache[item.machine_id]
return
# return JSONResponse(status_code=400, content={"error": "App image built, but url is None, unable to parse the url."})
# example https://bennykok--my-app-comfyui-app.modal.run/
# my_url = f"https://{MODAL_ORG}--{item.container_id}-{app_suffix}.modal.run"
requests.post(item.callback_url, json={
"machine_id": item.machine_id, "endpoint": url, "build_log": json.dumps(machine_logs)})
if item.machine_id in machine_logs_cache:
del machine_logs_cache[item.machine_id]
logger.info("done")
logger.info(url)
def start_loop(loop):
asyncio.set_event_loop(loop)
loop.run_forever()
def run_in_new_thread(coroutine):
new_loop = asyncio.new_event_loop()
t = threading.Thread(target=start_loop, args=(new_loop,), daemon=True)
t.start()
asyncio.run_coroutine_threadsafe(coroutine, new_loop)
return t
if __name__ == "__main__":
import uvicorn
# , log_level="debug"
uvicorn.run("main:app", host="0.0.0.0", port=8080, lifespan="on")

View File

@ -1,448 +0,0 @@
import modal
from typing import Union, Optional, Dict, List
from pydantic import BaseModel, Field, field_validator
from fastapi import FastAPI, HTTPException, WebSocket, BackgroundTasks, WebSocketDisconnect
from fastapi.responses import JSONResponse
from fastapi.logger import logger as fastapi_logger
import os
from enum import Enum
import json
import subprocess
import time
from contextlib import asynccontextmanager
import asyncio
import threading
import signal
import logging
from fastapi.logger import logger as fastapi_logger
import requests
from urllib.parse import parse_qs
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp, Scope, Receive, Send
# Modal应用实例
modal_app = modal.App(name="comfyui-deploy")
gunicorn_error_logger = logging.getLogger("gunicorn.error")
gunicorn_logger = logging.getLogger("gunicorn")
uvicorn_access_logger = logging.getLogger("uvicorn.access")
uvicorn_access_logger.handlers = gunicorn_error_logger.handlers
fastapi_logger.handlers = gunicorn_error_logger.handlers
if __name__ != "__main__":
fastapi_logger.setLevel(gunicorn_logger.level)
else:
fastapi_logger.setLevel(logging.DEBUG)
logger = logging.getLogger("uvicorn")
logger.setLevel(logging.INFO)
last_activity_time = time.time()
global_timeout = 60 * 4
machine_id_websocket_dict = {}
machine_id_status = {}
machine_logs_cache = {}
fly_instance_id = os.environ.get('FLY_ALLOC_ID', 'local').split('-')[0]
class FlyReplayMiddleware(BaseHTTPMiddleware):
def __init__(self, app: ASGIApp) -> None:
super().__init__(app)
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
query_string = scope.get('query_string', b'').decode()
query_params = parse_qs(query_string)
target_instance = query_params.get('fly_instance_id', [fly_instance_id])[0]
async def send_wrapper(message):
if target_instance != fly_instance_id:
if message['type'] == 'websocket.close' and 'Invalid session' in message.get('reason', ''):
message = {'type': 'websocket.accept'}
if 'headers' not in message:
message['headers'] = []
message['headers'].append([b'fly-replay', f'instance={target_instance}'.encode()])
await send(message)
await self.app(scope, receive, send_wrapper)
async def check_inactivity():
global last_activity_time
while True:
if time.time() - last_activity_time > global_timeout:
if len(machine_id_status) == 0:
logger.info(f"No activity for {global_timeout} seconds, exiting...")
os.kill(os.getpid(), signal.SIGINT)
break
await asyncio.sleep(1)
@asynccontextmanager
async def lifespan(app: FastAPI):
thread = run_in_new_thread(check_inactivity())
yield
logger.info("Cancelling")
# FastAPI实例
fastapi_app = FastAPI(lifespan=lifespan)
fastapi_app.add_middleware(FlyReplayMiddleware)
class GitCustomNodes(BaseModel):
hash: str
disabled: bool
class FileCustomNodes(BaseModel):
filename: str
disabled: bool
class Snapshot(BaseModel):
comfyui: str
git_custom_nodes: Dict[str, GitCustomNodes]
file_custom_nodes: List[FileCustomNodes]
class Model(BaseModel):
name: str
type: str
base: str
save_path: str
description: str
reference: str
filename: str
url: str
class GPUType(str, Enum):
T4 = "T4"
A10G = "A10G"
A100 = "A100"
L4 = "L4"
class Item(BaseModel):
machine_id: str
name: str
snapshot: Snapshot
models: List[Model]
callback_url: str
gpu: GPUType = Field(default=GPUType.T4)
@field_validator('gpu')
@classmethod
def check_gpu(cls, value):
if value not in GPUType.__members__:
raise ValueError(f"Invalid GPU option. Choose from: {', '.join(GPUType.__members__.keys())}")
return GPUType(value)
class StopAppItem(BaseModel):
machine_id: str
@fastapi_app.get("/")
def read_root():
global last_activity_time
last_activity_time = time.time()
logger.info(f"Extended inactivity time to {global_timeout}")
return {"Hello": "World"}
@fastapi_app.websocket("/ws/{machine_id}")
async def websocket_endpoint(websocket: WebSocket, machine_id: str):
await websocket.accept()
machine_id_websocket_dict[machine_id] = websocket
if machine_id in machine_logs_cache:
combined_logs = "\n".join(log_entry['logs'] for log_entry in machine_logs_cache[machine_id])
await websocket.send_text(json.dumps({
"event": "LOGS",
"data": {
"machine_id": machine_id,
"logs": combined_logs,
"timestamp": time.time()
}
}))
try:
while True:
data = await websocket.receive_text()
global last_activity_time
last_activity_time = time.time()
logger.info(f"Extended inactivity time to {global_timeout}")
except WebSocketDisconnect:
if machine_id in machine_id_websocket_dict:
del machine_id_websocket_dict[machine_id]
@fastapi_app.post("/create")
async def create_machine(item: Item):
global last_activity_time
last_activity_time = time.time()
logger.info(f"Extended inactivity time to {global_timeout}")
if item.machine_id in machine_id_status and machine_id_status[item.machine_id]:
return JSONResponse(status_code=400, content={"error": "Build already in progress."})
task = asyncio.create_task(build_logic(item))
return JSONResponse(
status_code=200,
content={
"message": "Build Queued",
"build_machine_instance_id": fly_instance_id
}
)
def find_app_id(app_list, app_name):
for app in app_list:
if app['Name'] == app_name:
return app['App ID']
return None
@fastapi_app.post("/stop-app")
async def stop_app(item: StopAppItem):
cmd = f"modal app list --json"
env = os.environ.copy()
env["COLUMNS"] = "10000"
find_id_process = await asyncio.subprocess.create_subprocess_shell(
cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
env=env
)
stdout, stderr = await find_id_process.communicate()
if stdout:
app_list = json.loads(stdout.decode().strip())
app_id = find_app_id(app_list, item.machine_id)
logger.info(f"cp_process stdout: {app_id}")
if stderr:
logger.info(f"cp_process stderr: {stderr.decode()}")
cp_process = await asyncio.subprocess.create_subprocess_exec(
"modal", "app", "stop", app_id,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
await cp_process.wait()
stdout, stderr = await cp_process.communicate()
if stdout:
logger.info(f"cp_process stdout: {stdout.decode()}")
if stderr:
logger.info(f"cp_process stderr: {stderr.decode()}")
if cp_process.returncode == 0:
return JSONResponse(status_code=200, content={"status": "success"})
else:
return JSONResponse(
status_code=500,
content={"status": "error", "error": stderr.decode()}
)
async def build_logic(item: Item):
folder_path = f"/app/builds/{item.machine_id}"
machine_id_status[item.machine_id] = True
cp_process = await asyncio.subprocess.create_subprocess_exec(
"cp", "-r", "/app/src/template", folder_path
)
await cp_process.wait()
config = {
"name": item.name,
"deploy_test": os.environ.get("DEPLOY_TEST_FLAG", "False"),
"gpu": item.gpu,
"civitai_token": os.environ.get("CIVITAI_TOKEN", "833b4ded5c7757a06a803763500bab58")
}
with open(f"{folder_path}/config.py", "w") as f:
f.write("config = " + json.dumps(config))
with open(f"{folder_path}/data/snapshot.json", "w") as f:
f.write(item.snapshot.json())
with open(f"{folder_path}/data/models.json", "w") as f:
models_json_list = [model.dict() for model in item.models]
f.write(json.dumps(models_json_list))
process = await asyncio.subprocess.create_subprocess_shell(
f"modal deploy app.py",
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
cwd=folder_path,
env={**os.environ, "COLUMNS": "10000"}
)
if item.machine_id not in machine_logs_cache:
machine_logs_cache[item.machine_id] = []
machine_logs = machine_logs_cache[item.machine_id]
url_queue = asyncio.Queue()
async def read_stream(stream, isStderr, url_queue: asyncio.Queue):
while True:
line = await stream.readline()
if not line:
break
l = line.decode('utf-8').strip()
if not l:
continue
if not isStderr:
logger.info(l)
machine_logs.append({
"logs": l,
"timestamp": time.time()
})
if item.machine_id in machine_id_websocket_dict:
await machine_id_websocket_dict[item.machine_id].send_text(
json.dumps({
"event": "LOGS",
"data": {
"machine_id": item.machine_id,
"logs": l,
"timestamp": time.time()
}
})
)
if "Created comfyui_api =>" in l or ((l.startswith("https://") or l.startswith("")) and l.endswith(".modal.run")):
if "Created comfyui_api =>" in l:
url = l.split("=>")[1].strip()
elif "comfyui-api" in l:
url = l.split("")[1].strip() if l.startswith("") else l
if url:
machine_logs.append({
"logs": f"App image built, url: {url}",
"timestamp": time.time()
})
await url_queue.put(url)
if item.machine_id in machine_id_websocket_dict:
await machine_id_websocket_dict[item.machine_id].send_text(
json.dumps({
"event": "LOGS",
"data": {
"machine_id": item.machine_id,
"logs": f"App image built, url: {url}",
"timestamp": time.time()
}
})
)
await machine_id_websocket_dict[item.machine_id].send_text(
json.dumps({
"event": "FINISHED",
"data": {
"status": "success",
}
})
)
else:
logger.error(l)
machine_logs.append({
"logs": l,
"timestamp": time.time()
})
if item.machine_id in machine_id_websocket_dict:
await machine_id_websocket_dict[item.machine_id].send_text(
json.dumps({
"event": "LOGS",
"data": {
"machine_id": item.machine_id,
"logs": l,
"timestamp": time.time()
}
})
)
await machine_id_websocket_dict[item.machine_id].send_text(
json.dumps({
"event": "FINISHED",
"data": {
"status": "failed",
}
})
)
stdout_task = asyncio.create_task(read_stream(process.stdout, False, url_queue))
stderr_task = asyncio.create_task(read_stream(process.stderr, True, url_queue))
await asyncio.wait([stdout_task, stderr_task])
await process.wait()
url = await url_queue.get() if not url_queue.empty() else None
if item.machine_id in machine_id_websocket_dict and machine_id_websocket_dict[item.machine_id] is not None:
await machine_id_websocket_dict[item.machine_id].close()
if item.machine_id in machine_id_websocket_dict:
del machine_id_websocket_dict[item.machine_id]
if item.machine_id in machine_id_status:
machine_id_status[item.machine_id] = False
if process.returncode != 0:
logger.info("An error occurred.")
machine_logs.append({
"logs": "Unable to build the app image.",
"timestamp": time.time()
})
requests.post(
item.callback_url,
json={
"machine_id": item.machine_id,
"build_log": json.dumps(machine_logs)
}
)
if item.machine_id in machine_logs_cache:
del machine_logs_cache[item.machine_id]
return
if url is None:
machine_logs.append({
"logs": "App image built, but url is None, unable to parse the url.",
"timestamp": time.time()
})
requests.post(
item.callback_url,
json={
"machine_id": item.machine_id,
"build_log": json.dumps(machine_logs)
}
)
if item.machine_id in machine_logs_cache:
del machine_logs_cache[item.machine_id]
return
requests.post(
item.callback_url,
json={
"machine_id": item.machine_id,
"endpoint": url,
"build_log": json.dumps(machine_logs)
}
)
if item.machine_id in machine_logs_cache:
del machine_logs_cache[item.machine_id]
logger.info("done")
logger.info(url)
def start_loop(loop):
asyncio.set_event_loop(loop)
loop.run_forever()
def run_in_new_thread(coroutine):
new_loop = asyncio.new_event_loop()
t = threading.Thread(target=start_loop, args=(new_loop,), daemon=True)
t.start()
asyncio.run_coroutine_threadsafe(coroutine, new_loop)
return t
# Modal endpoint
@modal_app.function()
@modal.asgi_app()
def app():
return fastapi_app
if __name__ == "__main__":
import uvicorn
uvicorn.run(fastapi_app, host="0.0.0.0", port=8080, lifespan="on")

View File

@ -307,5 +307,4 @@ def comfyui_app():
}, },
)() )()
proxy_app = make_simple_proxy_app(ProxyContext(config)) # Assign to variable return make_simple_proxy_app(ProxyContext(config))
return proxy_app # Return the variable

View File

@ -0,0 +1,46 @@
import json
class AnyType(str):
"""A special class that is always equal in not equal comparisons. Credit to pythongosssss"""
def __ne__(self, __value: object) -> bool:
return False
any = AnyType("*")
class ComfyDeployStdOutputAny:
@classmethod
def INPUT_TYPES(cls): # pylint: disable = invalid-name, missing-function-docstring
return {
"required": {
"name": ("STRING", {"default": "ComfyUI"}),
"source": (any, {}), # Use "*" to accept any input type
},
}
CATEGORY = "output"
RETURN_TYPES = ()
FUNCTION = "run"
OUTPUT_NODE = True
def run(self, name, source=None):
value = "None"
if source is not None:
try:
value = json.dumps(source)
except Exception:
try:
value = str(source)
except Exception:
value = "source exists, but could not be serialized."
return {"ui": {name: (value,)}}
NODE_CLASS_MAPPINGS = {"ComfyDeployStdOutputAny": ComfyDeployStdOutputAny}
NODE_DISPLAY_NAME_MAPPINGS = {
"ComfyDeployStdOutputAny": "Standard Any Output (ComfyDeploy)"
}

View File

@ -6,7 +6,7 @@ from PIL.PngImagePlugin import PngInfo
import folder_paths import folder_paths
class ComfyDeployOutputImage: class ComfyDeployStdOutputImage:
def __init__(self): def __init__(self):
self.output_dir = folder_paths.get_output_directory() self.output_dir = folder_paths.get_output_directory()
self.type = "output" self.type = "output"
@ -86,7 +86,7 @@ class ComfyDeployOutputImage:
return {"ui": {"images": results}} return {"ui": {"images": results}}
NODE_CLASS_MAPPINGS = {"ComfyDeployOutputImage": ComfyDeployOutputImage} NODE_CLASS_MAPPINGS = {"ComfyDeployStdOutputImage": ComfyDeployStdOutputImage}
NODE_DISPLAY_NAME_MAPPINGS = { NODE_DISPLAY_NAME_MAPPINGS = {
"ComfyDeployOutputImage": "Image Output (ComfyDeploy)" "ComfyDeployStdOutputImage": "Standard Image Output (ComfyDeploy)"
} }

View File

@ -1,57 +0,0 @@
import os
import io
import torchaudio
from folder_paths import get_annotated_filepath
class ComfyUIDeployExternalAudio:
RETURN_TYPES = ("AUDIO",)
RETURN_NAMES = ("audio",)
FUNCTION = "load_audio"
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"input_id": (
"STRING",
{"multiline": False, "default": "input_audio"},
),
"audio_file": ("STRING", {"default": ""}),
},
"optional": {
"default_value": ("AUDIO",),
"display_name": (
"STRING",
{"multiline": False, "default": ""},
),
"description": (
"STRING",
{"multiline": False, "default": ""},
),
}
}
@classmethod
def VALIDATE_INPUTS(s, audio_file, **kwargs):
return True
def load_audio(self, input_id, audio_file, default_value=None, display_name=None, description=None):
if audio_file and audio_file != "":
if audio_file.startswith(('http://', 'https://')):
# Handle URL input
import requests
response = requests.get(audio_file)
audio_data = io.BytesIO(response.content)
waveform, sample_rate = torchaudio.load(audio_data)
else:
# Handle local file
audio_path = get_annotated_filepath(audio_file)
waveform, sample_rate = torchaudio.load(audio_path)
audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
return (audio,)
else:
return (default_value,)
NODE_CLASS_MAPPINGS = {"ComfyUIDeployExternalAudio": ComfyUIDeployExternalAudio}
NODE_DISPLAY_NAME_MAPPINGS = {"ComfyUIDeployExternalAudio": "External Audio (ComfyUI Deploy)"}

View File

@ -21,9 +21,8 @@ class ComfyUIDeployExternalImage:
), ),
"description": ( "description": (
"STRING", "STRING",
{"multiline": False, "default": ""}, {"multiline": True, "default": ""},
), ),
"default_value_url": ("STRING", {"image_preview": True, "default": ""}),
} }
} }
@ -34,44 +33,32 @@ class ComfyUIDeployExternalImage:
CATEGORY = "image" CATEGORY = "image"
def run(self, input_id, default_value=None, display_name=None, description=None, default_value_url=None): def run(self, input_id, default_value=None, display_name=None, description=None):
image = default_value image = default_value
try:
# Try both input_id and default_value_url if input_id.startswith('http'):
urls_to_try = [url for url in [input_id, default_value_url] if url] import requests
from io import BytesIO
print(default_value_url) print("Fetching image from url: ", input_id)
response = requests.get(input_id)
for url in urls_to_try: image = Image.open(BytesIO(response.content))
try: elif input_id.startswith('data:image/png;base64,') or input_id.startswith('data:image/jpeg;base64,') or input_id.startswith('data:image/jpg;base64,'):
if url.startswith('http'): import base64
import requests from io import BytesIO
from io import BytesIO print("Decoding base64 image")
print(f"Fetching image from url: {url}") base64_image = input_id[input_id.find(",")+1:]
response = requests.get(url) decoded_image = base64.b64decode(base64_image)
image = Image.open(BytesIO(response.content)) image = Image.open(BytesIO(decoded_image))
break else:
elif url.startswith(('data:image/png;base64,', 'data:image/jpeg;base64,', 'data:image/jpg;base64,')): raise ValueError("Invalid image url provided.")
import base64
from io import BytesIO image = ImageOps.exif_transpose(image)
print("Decoding base64 image") image = image.convert("RGB")
base64_image = url[url.find(",")+1:] image = np.array(image).astype(np.float32) / 255.0
decoded_image = base64.b64decode(base64_image) image = torch.from_numpy(image)[None,]
image = Image.open(BytesIO(decoded_image)) return [image]
break except:
except: return [image]
continue
if image is not None:
try:
image = ImageOps.exif_transpose(image)
image = image.convert("RGB")
image = np.array(image).astype(np.float32) / 255.0
image = torch.from_numpy(image)[None,]
except:
pass
return [image]
NODE_CLASS_MAPPINGS = {"ComfyUIDeployExternalImage": ComfyUIDeployExternalImage} NODE_CLASS_MAPPINGS = {"ComfyUIDeployExternalImage": ComfyUIDeployExternalImage}

View File

@ -64,42 +64,32 @@ class ComfyUIDeployExternalLora:
import os import os
import uuid import uuid
if lora_url: if lora_url and lora_url.startswith("http"):
if lora_url.startswith("http"): if lora_save_name:
if lora_save_name: existing_loras = folder_paths.get_filename_list("loras")
existing_loras = folder_paths.get_filename_list("loras") # Check if lora_save_name exists in the list
# Check if lora_save_name exists in the list if lora_save_name in existing_loras:
if lora_save_name in existing_loras: print(f"using lora: {lora_save_name}")
print(f"using lora: {lora_save_name}") return (lora_save_name,)
return (lora_save_name,)
else:
lora_save_name = str(uuid.uuid4()) + ".safetensors"
print(lora_save_name)
print(folder_paths.folder_names_and_paths["loras"][0][0])
destination_path = os.path.join(
folder_paths.folder_names_and_paths["loras"][0][0], lora_save_name
)
print(destination_path)
print(
"Downloading external lora - "
+ lora_url
+ " to "
+ destination_path
)
response = requests.get(
lora_url,
headers={"User-Agent": "Mozilla/5.0"},
allow_redirects=True,
)
with open(destination_path, "wb") as out_file:
out_file.write(response.content)
print(f"Ext Lora loading: {lora_url} to {lora_save_name}")
return (lora_save_name,)
else: else:
print(f"Ext Lora loading: {lora_url}") lora_save_name = str(uuid.uuid4()) + ".safetensors"
return (lora_url,) print(lora_save_name)
print(folder_paths.folder_names_and_paths["loras"][0][0])
destination_path = os.path.join(
folder_paths.folder_names_and_paths["loras"][0][0], lora_save_name
)
print(destination_path)
print("Downloading external lora - " + lora_url + " to " + destination_path)
response = requests.get(
lora_url,
headers={"User-Agent": "Mozilla/5.0"},
allow_redirects=True,
)
with open(destination_path, "wb") as out_file:
out_file.write(response.content)
return (lora_save_name,)
else: else:
print(f"Ext Lora loading: {default_lora_name}") print(f"using lora: {default_lora_name}")
return (default_lora_name,) return (default_lora_name,)

View File

@ -1,53 +0,0 @@
import re
class StringFunction:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"action": (["append", "replace"], {}),
"tidy_tags": (["yes", "no"], {}),
},
"optional": {
"text_a": ("STRING", {"multiline": True, "dynamicPrompts": False}),
"text_b": ("STRING", {"multiline": True, "dynamicPrompts": False}),
"text_c": ("STRING", {"multiline": True, "dynamicPrompts": False}),
},
}
RETURN_TYPES = ("STRING",)
FUNCTION = "exec"
CATEGORY = "utils"
OUTPUT_NODE = True
def exec(self, action, tidy_tags, text_a="", text_b="", text_c=""):
tidy_tags = tidy_tags == "yes"
out = ""
if action == "append":
out = (", " if tidy_tags else "").join(
filter(None, [text_a, text_b, text_c])
)
else:
if text_c is None:
text_c = ""
if text_b.startswith("/") and text_b.endswith("/"):
regex = text_b[1:-1]
out = re.sub(regex, text_c, text_a)
else:
out = text_a.replace(text_b, text_c)
if tidy_tags:
out = re.sub(r"\s{2,}", " ", out)
out = out.replace(" ,", ",")
out = re.sub(r",{2,}", ",", out)
out = out.strip()
return {"ui": {"text": (out,)}, "result": (out,)}
NODE_CLASS_MAPPINGS = {
"ComfyUIDeployStringCombine": StringFunction,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"ComfyUIDeployStringCombine": "String Combine (ComfyUI Deploy)",
}

View File

@ -26,10 +26,7 @@ import copy
import struct import struct
from aiohttp import web, ClientSession, ClientError, ClientTimeout, ClientResponseError from aiohttp import web, ClientSession, ClientError, ClientTimeout, ClientResponseError
import atexit import atexit
from model_management import get_torch_device
import torch
import psutil
from collections import OrderedDict
# Global session # Global session
client_session = None client_session = None
@ -386,9 +383,6 @@ def apply_inputs_to_workflow(workflow_api: Any, inputs: Any, sid: str = None):
if value["class_type"] == "ComfyUIDeployExternalFaceModel": if value["class_type"] == "ComfyUIDeployExternalFaceModel":
value["inputs"]["face_model_url"] = new_value value["inputs"]["face_model_url"] = new_value
if value["class_type"] == "ComfyUIDeployExternalAudio":
value["inputs"]["audio_file"] = new_value
def send_prompt(sid: str, inputs: StreamingPrompt): def send_prompt(sid: str, inputs: StreamingPrompt):
# workflow_api = inputs.workflow_api # workflow_api = inputs.workflow_api
@ -1126,138 +1120,9 @@ async def proxy_to_comfydeploy(request):
prompt_server = server.PromptServer.instance prompt_server = server.PromptServer.instance
NODE_EXECUTION_TIMES = {} # New dictionary to store node execution times
CURRENT_START_EXECUTION_DATA = None
def get_peak_memory():
device = get_torch_device()
if device.type == 'cuda':
return torch.cuda.max_memory_allocated(device)
elif device.type == 'mps':
# Return system memory usage for MPS devices
return psutil.Process().memory_info().rss
return 0
def reset_peak_memory_record():
device = get_torch_device()
if device.type == 'cuda':
torch.cuda.reset_max_memory_allocated(device)
# MPS doesn't need reset as we're not tracking its memory
def handle_execute(class_type, last_node_id, prompt_id, server, unique_id):
if not CURRENT_START_EXECUTION_DATA:
return
start_time = CURRENT_START_EXECUTION_DATA["nodes_start_perf_time"].get(unique_id)
start_vram = CURRENT_START_EXECUTION_DATA["nodes_start_vram"].get(unique_id)
if start_time:
end_time = time.perf_counter()
execution_time = end_time - start_time
end_vram = get_peak_memory()
vram_used = end_vram - start_vram
global NODE_EXECUTION_TIMES
# print(f"end_vram - start_vram: {end_vram} - {start_vram} = {vram_used}")
NODE_EXECUTION_TIMES[unique_id] = {
"time": execution_time,
"class_type": class_type,
"vram_used": vram_used
}
# print(f"#{unique_id} [{class_type}]: {execution_time:.2f}s - vram {vram_used}b")
try:
origin_execute = execution.execute
def swizzle_execute(
server,
dynprompt,
caches,
current_item,
extra_data,
executed,
prompt_id,
execution_list,
pending_subgraph_results,
):
unique_id = current_item
class_type = dynprompt.get_node(unique_id)["class_type"]
last_node_id = server.last_node_id
result = origin_execute(
server,
dynprompt,
caches,
current_item,
extra_data,
executed,
prompt_id,
execution_list,
pending_subgraph_results,
)
handle_execute(class_type, last_node_id, prompt_id, server, unique_id)
return result
execution.execute = swizzle_execute
except Exception as e:
pass
def format_table(headers, data):
# Calculate column widths
widths = [len(h) for h in headers]
for row in data:
for i, cell in enumerate(row):
widths[i] = max(widths[i], len(str(cell)))
# Create separator line
separator = '+' + '+'.join('-' * (w + 2) for w in widths) + '+'
# Format header
result = [separator]
header_row = '|' + '|'.join(f' {h:<{w}} ' for w, h in zip(widths, headers)) + '|'
result.append(header_row)
result.append(separator)
# Format data rows
for row in data:
data_row = '|' + '|'.join(f' {str(cell):<{w}} ' for w, cell in zip(widths, row)) + '|'
result.append(data_row)
result.append(separator)
return '\n'.join(result)
origin_func = server.PromptServer.send_sync
def swizzle_send_sync(self, event, data, sid=None):
# print(f"swizzle_send_sync, event: {event}, data: {data}")
global CURRENT_START_EXECUTION_DATA
if event == "execution_start":
global NODE_EXECUTION_TIMES
NODE_EXECUTION_TIMES = {} # Reset execution times at start
CURRENT_START_EXECUTION_DATA = dict(
start_perf_time=time.perf_counter(),
nodes_start_perf_time={},
nodes_start_vram={},
)
origin_func(self, event=event, data=data, sid=sid)
if event == "executing" and data and CURRENT_START_EXECUTION_DATA:
if data.get("node") is not None:
node_id = data.get("node")
CURRENT_START_EXECUTION_DATA["nodes_start_perf_time"][node_id] = (
time.perf_counter()
)
reset_peak_memory_record()
CURRENT_START_EXECUTION_DATA["nodes_start_vram"][node_id] = (
get_peak_memory()
)
server.PromptServer.send_sync = swizzle_send_sync
send_json = prompt_server.send_json send_json = prompt_server.send_json
async def send_json_override(self, event, data, sid=None): async def send_json_override(self, event, data, sid=None):
# logger.info("INTERNAL:", event, data, sid) # logger.info("INTERNAL:", event, data, sid)
prompt_id = data.get("prompt_id") prompt_id = data.get("prompt_id")
@ -1280,62 +1145,10 @@ async def send_json_override(self, event, data, sid=None):
asyncio.create_task(update_run_ws_event(prompt_id, event, data)) asyncio.create_task(update_run_ws_event(prompt_id, event, data))
if event == "execution_start": if event == "execution_start":
await update_run(prompt_id, Status.RUNNING)
if prompt_id in prompt_metadata: if prompt_id in prompt_metadata:
prompt_metadata[prompt_id].start_time = time.perf_counter() prompt_metadata[prompt_id].start_time = time.perf_counter()
logger.info("Executing prompt: " + prompt_id)
asyncio.create_task(update_run(prompt_id, Status.RUNNING))
if event == "executing" and data and CURRENT_START_EXECUTION_DATA:
if data.get("node") is None:
start_perf_time = CURRENT_START_EXECUTION_DATA.get("start_perf_time")
new_data = data.copy()
if start_perf_time is not None:
execution_time = time.perf_counter() - start_perf_time
new_data["execution_time"] = int(execution_time * 1000)
# Replace the print statements with tabulate
headers = ["Node ID", "Type", "Time (s)", "VRAM (GB)"]
table_data = []
node_execution_array = [] # New array to store execution data
for node_id, node_data in NODE_EXECUTION_TIMES.items():
vram_gb = node_data['vram_used'] / (1024**3) # Convert bytes to GB
table_data.append([
f"#{node_id}",
node_data['class_type'],
f"{node_data['time']:.2f}",
f"{vram_gb:.2f}"
])
# Add to our new array format
node_execution_array.append({
"id": node_id,
**node_data,
})
# Add total execution time as the last row
table_data.append([
"TOTAL",
"-",
f"{execution_time:.2f}",
"-"
])
prompt_id = data.get("prompt_id")
asyncio.create_task(update_run_with_output(
prompt_id,
node_execution_array, # Send the array instead of the OrderedDict
))
print(node_execution_array)
# print("\n=== Node Execution Times ===")
logger.info("Printing Node Execution Times")
logger.info(format_table(headers, table_data))
# print("========================\n")
# the last executing event is none, then the workflow is finished # the last executing event is none, then the workflow is finished
if event == "executing" and data.get("node") is None: if event == "executing" and data.get("node") is None:
@ -1347,11 +1160,11 @@ async def send_json_override(self, event, data, sid=None):
if prompt_metadata[prompt_id].start_time is not None: if prompt_metadata[prompt_id].start_time is not None:
elapsed_time = current_time - prompt_metadata[prompt_id].start_time elapsed_time = current_time - prompt_metadata[prompt_id].start_time
logger.info(f"Elapsed time: {elapsed_time} seconds") logger.info(f"Elapsed time: {elapsed_time} seconds")
asyncio.create_task(send( await send(
"elapsed_time", "elapsed_time",
{"prompt_id": prompt_id, "elapsed_time": elapsed_time}, {"prompt_id": prompt_id, "elapsed_time": elapsed_time},
sid=sid, sid=sid,
)) )
if event == "executing" and data.get("node") is not None: if event == "executing" and data.get("node") is not None:
node = data.get("node") node = data.get("node")
@ -1375,7 +1188,7 @@ async def send_json_override(self, event, data, sid=None):
prompt_metadata[prompt_id].last_updated_node = node prompt_metadata[prompt_id].last_updated_node = node
class_type = prompt_metadata[prompt_id].workflow_api[node]["class_type"] class_type = prompt_metadata[prompt_id].workflow_api[node]["class_type"]
logger.info(f"At: {round(calculated_progress * 100)}% - {class_type}") logger.info(f"At: {round(calculated_progress * 100)}% - {class_type}")
asyncio.create_task(send( await send(
"live_status", "live_status",
{ {
"prompt_id": prompt_id, "prompt_id": prompt_id,
@ -1383,10 +1196,10 @@ async def send_json_override(self, event, data, sid=None):
"progress": calculated_progress, "progress": calculated_progress,
}, },
sid=sid, sid=sid,
)) )
asyncio.create_task(update_run_live_status( await update_run_live_status(
prompt_id, "Executing " + class_type, calculated_progress prompt_id, "Executing " + class_type, calculated_progress
)) )
if event == "execution_cached" and data.get("nodes") is not None: if event == "execution_cached" and data.get("nodes") is not None:
if prompt_id in prompt_metadata: if prompt_id in prompt_metadata:
@ -1414,8 +1227,7 @@ async def send_json_override(self, event, data, sid=None):
"node_class": class_type, "node_class": class_type,
} }
if class_type == "PreviewImage": if class_type == "PreviewImage":
pass logger.info("Skipping preview image")
# logger.info("Skipping preview image")
else: else:
await update_run_with_output( await update_run_with_output(
prompt_id, prompt_id,
@ -1427,10 +1239,9 @@ async def send_json_override(self, event, data, sid=None):
comfy_message_queues[prompt_id].put_nowait( comfy_message_queues[prompt_id].put_nowait(
{"event": "output_ready", "data": data} {"event": "output_ready", "data": data}
) )
# logger.info(f"Executed {class_type} {data}") logger.info(f"Executed {class_type} {data}")
else: else:
pass logger.info(f"Executed {data}")
# logger.info(f"Executed {data}")
# Global variable to keep track of the last read line number # Global variable to keep track of the last read line number
@ -1902,23 +1713,17 @@ async def upload_in_background(
# await handle_upload(prompt_id, data, 'files', "content_type", "image/png") # await handle_upload(prompt_id, data, 'files', "content_type", "image/png")
# await handle_upload(prompt_id, data, 'gifs', "format", "image/gif") # await handle_upload(prompt_id, data, 'gifs', "format", "image/gif")
# await handle_upload(prompt_id, data, 'mesh', "format", "application/octet-stream") # await handle_upload(prompt_id, data, 'mesh', "format", "application/octet-stream")
upload_tasks = [
file_upload_endpoint = prompt_metadata[prompt_id].file_upload_endpoint handle_upload(prompt_id, data, "images", "content_type", "image/png"),
handle_upload(prompt_id, data, "files", "content_type", "image/png"),
if file_upload_endpoint is not None and file_upload_endpoint != "": handle_upload(prompt_id, data, "gifs", "format", "image/gif"),
upload_tasks = [ handle_upload(
handle_upload(prompt_id, data, "images", "content_type", "image/png"), prompt_id, data, "mesh", "format", "application/octet-stream"
handle_upload(prompt_id, data, "files", "content_type", "image/png"), ),
handle_upload(prompt_id, data, "gifs", "format", "image/gif"), ]
handle_upload(
prompt_id, data, "mesh", "format", "application/octet-stream" await asyncio.gather(*upload_tasks)
),
]
await asyncio.gather(*upload_tasks)
else:
print("No file upload endpoint, skipping file upload")
status_endpoint = prompt_metadata[prompt_id].status_endpoint status_endpoint = prompt_metadata[prompt_id].status_endpoint
token = prompt_metadata[prompt_id].token token = prompt_metadata[prompt_id].token
gpu_event_id = prompt_metadata[prompt_id].gpu_event_id or None gpu_event_id = prompt_metadata[prompt_id].gpu_event_id or None

View File

@ -2,7 +2,7 @@
name = "comfyui-deploy" name = "comfyui-deploy"
description = "Open source comfyui deployment platform, a vercel for generative workflow infra." description = "Open source comfyui deployment platform, a vercel for generative workflow infra."
version = "1.1.0" version = "1.1.0"
license = { file = "LICENSE" } license = "LICENSE"
dependencies = ["aiofiles", "pydantic", "opencv-python", "imageio-ffmpeg"] dependencies = ["aiofiles", "pydantic", "opencv-python", "imageio-ffmpeg"]
[project.urls] [project.urls]

View File

@ -3,5 +3,4 @@ pydantic
opencv-python opencv-python
imageio-ffmpeg imageio-ffmpeg
brotli brotli
tabulate
# logfire # logfire

4
web-plugin/api.js Normal file
View File

@ -0,0 +1,4 @@
/** @typedef {import('../../../web/scripts/api.js').api} API*/
import { api as _api } from '../../scripts/api.js';
/** @type {API} */
export const api = _api;

4
web-plugin/app.js Normal file
View File

@ -0,0 +1,4 @@
/** @typedef {import('../../../web/scripts/app.js').ComfyApp} ComfyApp*/
import { app as _app } from '../../scripts/app.js';
/** @type {ComfyApp} */
export const app = _app;

View File

@ -1,44 +1,8 @@
import { app } from "../../scripts/app.js"; import { app } from "./app.js";
import { api } from "../../scripts/api.js"; import { api } from "./api.js";
// import { LGraphNode } from "../../scripts/widgets.js"; import { ComfyWidgets, LGraphNode } from "./widgets.js";
LGraphNode = LiteGraph.LGraphNode;
import { ComfyDialog, $el } from "../../scripts/ui.js";
import { generateDependencyGraph } from "https://esm.sh/comfyui-json@0.1.25"; import { generateDependencyGraph } from "https://esm.sh/comfyui-json@0.1.25";
import { ComfyDeploy } from "https://esm.sh/comfydeploy@2.0.0-beta.69"; import { ComfyDeploy } from "https://esm.sh/comfydeploy@0.0.19-beta.30";
const styles = `
.comfydeploy-menu-item {
background: linear-gradient(to right, rgba(74, 144, 226, 0.9), rgba(103, 178, 111, 0.9)) !important;
color: white !important;
position: relative !important;
padding-left: 20px !important;
}
.comfydeploy-menu-item:hover {
filter: brightness(1.1) !important;
cursor: pointer !important;
}
.comfydeploy-menu-item::before {
content: '';
position: absolute;
left: 4px;
top: 50%;
transform: translateY(-50%);
width: 12px;
height: 12px;
background-image: url('https://www.comfydeploy.com/icon.svg');
background-size: contain;
background-repeat: no-repeat;
background-position: center;
}
`;
// Add stylesheet to document
const styleSheet = document.createElement("style");
styleSheet.textContent = styles;
document.head.appendChild(styleSheet);
const loadingIcon = `<svg xmlns="http://www.w3.org/2000/svg" width="32" height="32" viewBox="0 0 24 24"><g fill="none" stroke="#888888" stroke-linecap="round" stroke-width="2"><path stroke-dasharray="60" stroke-dashoffset="60" stroke-opacity=".3" d="M12 3C16.9706 3 21 7.02944 21 12C21 16.9706 16.9706 21 12 21C7.02944 21 3 16.9706 3 12C3 7.02944 7.02944 3 12 3Z"><animate fill="freeze" attributeName="stroke-dashoffset" dur="1.3s" values="60;0"/></path><path stroke-dasharray="15" stroke-dashoffset="15" d="M12 3C16.9706 3 21 7.02944 21 12"><animate fill="freeze" attributeName="stroke-dashoffset" dur="0.3s" values="15;0"/><animateTransform attributeName="transform" dur="1.5s" repeatCount="indefinite" type="rotate" values="0 12 12;360 12 12"/></path></g></svg>`; const loadingIcon = `<svg xmlns="http://www.w3.org/2000/svg" width="32" height="32" viewBox="0 0 24 24"><g fill="none" stroke="#888888" stroke-linecap="round" stroke-width="2"><path stroke-dasharray="60" stroke-dashoffset="60" stroke-opacity=".3" d="M12 3C16.9706 3 21 7.02944 21 12C21 16.9706 16.9706 21 12 21C7.02944 21 3 16.9706 3 12C3 7.02944 7.02944 3 12 3Z"><animate fill="freeze" attributeName="stroke-dashoffset" dur="1.3s" values="60;0"/></path><path stroke-dasharray="15" stroke-dashoffset="15" d="M12 3C16.9706 3 21 7.02944 21 12"><animate fill="freeze" attributeName="stroke-dashoffset" dur="0.3s" values="15;0"/><animateTransform attributeName="transform" dur="1.5s" repeatCount="indefinite" type="rotate" values="0 12 12;360 12 12"/></path></g></svg>`;
@ -50,14 +14,6 @@ function sendEventToCD(event, data) {
window.parent.postMessage(JSON.stringify(message), "*"); window.parent.postMessage(JSON.stringify(message), "*");
} }
function sendDirectEventToCD(event, data) {
const message = {
type: event,
data: data,
};
window.parent.postMessage(message, "*");
}
function dispatchAPIEventData(data) { function dispatchAPIEventData(data) {
const msg = JSON.parse(data); const msg = JSON.parse(data);
@ -166,146 +122,6 @@ function setSelectedWorkflowInfo(info) {
context.selectedWorkflowInfo = info; context.selectedWorkflowInfo = info;
} }
const VALID_TYPES = [
"STRING",
"combo",
"number",
"toggle",
"BOOLEAN",
"text",
"string",
];
function hideWidget(node, widget, suffix = "") {
if (widget.type?.startsWith(CONVERTED_TYPE)) return;
widget.origType = widget.type;
widget.origComputeSize = widget.computeSize;
widget.origSerializeValue = widget.serializeValue;
widget.computeSize = () => [0, -4];
widget.type = CONVERTED_TYPE + suffix;
widget.serializeValue = () => {
if (!node.inputs) {
return void 0;
}
let node_input = node.inputs.find((i) => i.widget?.name === widget.name);
if (!node_input || !node_input.link) {
return void 0;
}
return widget.origSerializeValue
? widget.origSerializeValue()
: widget.value;
};
if (widget.linkedWidgets) {
for (const w of widget.linkedWidgets) {
hideWidget(node, w, ":" + widget.name);
}
}
}
function getWidgetType(config) {
let type = config[0];
if (type instanceof Array) {
type = "COMBO";
}
return { type };
}
const GET_CONFIG = Symbol();
function convertToInput(node, widget, config) {
console.log(node);
if (node.type == "LoadImage") {
var inputNode = LiteGraph.createNode("ComfyUIDeployExternalImage");
console.log(widget);
const currentOutputsLinks = node.outputs[0].links;
// const index = node.inputs.findIndex((x) => x.name == widget.name);
// console.log(node.widgets_values, index);
// inputNode.configure({
// widgets_values: ["input_text", widget.value],
// });
inputNode.pos = node.pos;
inputNode.id = ++app.graph.last_node_id;
// inputNode.pos[0] += node.size[0] + 40;
node.pos[0] -= inputNode.size[0] + 20;
console.log(inputNode);
console.log(app.graph);
app.graph.add(inputNode);
const links = app.graph.links;
console.log(currentOutputsLinks);
for (let i = 0; i < currentOutputsLinks.length; i++) {
const link = currentOutputsLinks[i];
const llink = links[link];
console.log(links[link]);
setTimeout(
() => inputNode.connect(0, llink.target_id, llink.target_slot),
100,
);
}
node.connect(0, inputNode, 0);
return null;
}
hideWidget(node, widget);
const { type } = getWidgetType(config);
const sz = node.size;
const inputIsOptional = !!widget.options?.inputIsOptional;
const input = node.addInput(widget.name, type, {
widget: { name: widget.name, [GET_CONFIG]: () => config },
...(inputIsOptional ? { shape: LiteGraph.SlotShape.HollowCircle } : {}),
});
for (const widget2 of node.widgets) {
widget2.last_y += LiteGraph.NODE_SLOT_HEIGHT;
}
node.setSize([Math.max(sz[0], node.size[0]), Math.max(sz[1], node.size[1])]);
if (type == "STRING") {
var inputNode = LiteGraph.createNode("ComfyUIDeployExternalText");
console.log(widget);
const index = node.inputs.findIndex((x) => x.name == widget.name);
console.log(node.widgets_values, index);
inputNode.configure({
widgets_values: ["input_text", widget.value],
});
inputNode.id = ++app.graph.last_node_id;
inputNode.pos = node.pos;
inputNode.pos[0] -= node.size[0] + 40;
console.log(inputNode);
console.log(app.graph);
app.graph.add(inputNode);
inputNode.connect(0, node, index);
}
return input;
}
const CONVERTED_TYPE = "converted-widget";
function getConfig(widgetName) {
const { nodeData } = this.constructor;
return (
nodeData?.input?.required?.[widgetName] ??
nodeData?.input?.optional?.[widgetName]
);
}
function isConvertibleWidget(widget, config) {
return (
(VALID_TYPES.includes(widget.type) || VALID_TYPES.includes(config[0])) &&
!widget.options?.forceInput
);
}
var __defProp = Object.defineProperty;
var __name = (target, value) =>
__defProp(target, "name", { value, configurable: true });
/** @typedef {import('../../../web/types/comfy.js').ComfyExtension} ComfyExtension*/ /** @typedef {import('../../../web/types/comfy.js').ComfyExtension} ComfyExtension*/
/** @type {ComfyExtension} */ /** @type {ComfyExtension} */
const ext = { const ext = {
@ -416,125 +232,6 @@ const ext = {
} }
}, },
async beforeRegisterNodeDef(nodeType, nodeData, app2) {
const origGetExtraMenuOptions = nodeType.prototype.getExtraMenuOptions;
nodeType.prototype.getExtraMenuOptions = function (_, options) {
const r = origGetExtraMenuOptions
? origGetExtraMenuOptions.apply(this, arguments)
: void 0;
if (this.widgets) {
let toInput = [];
let toWidget = [];
for (const w of this.widgets) {
if (w.options?.forceInput) {
continue;
}
if (w.type === CONVERTED_TYPE) {
toWidget.push({
content: `Convert ${w.name} to widget`,
callback: /* @__PURE__ */ __name(
() => convertToWidget(this, w),
"callback",
),
});
} else {
const config = getConfig.call(this, w.name) ?? [
w.type,
w.options || {},
];
if (isConvertibleWidget(w, config)) {
toInput.push({
content: `Convert ${w.name} to external input`,
callback: /* @__PURE__ */ __name(
() => convertToInput(this, w, config),
"callback",
),
className: "comfydeploy-menu-item",
});
}
}
}
if (toInput.length) {
if (true) {
options.push();
let optionIndex = options.findIndex((o) => o.content === "Outputs");
if (optionIndex === -1) optionIndex = options.length;
else optionIndex++;
options.splice(
0,
0,
// {
// content: "[ComfyDeploy] Convert to External Input",
// submenu: {
// options: toInput,
// },
// className: "comfydeploy-menu-item"
// },
...toInput,
null,
);
} else {
options.push(...toInput, null);
}
}
// if (toWidget.length) {
// if (useConversionSubmenusSetting.value) {
// options.push({
// content: "Convert Input to Widget",
// submenu: {
// options: toWidget,
// },
// });
// } else {
// options.push(...toWidget, null);
// }
// }
}
return r;
};
if (
nodeData?.input?.optional?.default_value_url?.[1]?.image_preview === true
) {
nodeData.input.optional.default_value_url = ["IMAGEPREVIEW"];
console.log(nodeData.input.optional.default_value_url);
}
// const origonNodeCreated = nodeType.prototype.onNodeCreated;
// nodeType.prototype.onNodeCreated = function () {
// const r = origonNodeCreated
// ? origonNodeCreated.apply(this, arguments)
// : void 0;
// if (!this.widgets) {
// return;
// }
// console.log(this.widgets);
// this.widgets.forEach(element => {
// if (element.type != "customtext") return
// console.log(element.element);
// const parent = element.element.parentElement
// console.log(element.element.parentElement)
// const btn = document.createElement("button");
// // const div = document.createElement("div");
// // parent.removeChild(element.element)
// // div.appendChild(element.element)
// // parent.appendChild(div)
// // element.element = div
// // console.log(element.element);
// // btn.style = element.element.style
// });
// return r
// };
},
registerCustomNodes() { registerCustomNodes() {
/** @type {LGraphNode}*/ /** @type {LGraphNode}*/
class ComfyDeploy extends LGraphNode { class ComfyDeploy extends LGraphNode {
@ -627,78 +324,6 @@ const ext = {
ComfyDeploy.category = "deploy"; ComfyDeploy.category = "deploy";
}, },
getCustomWidgets() {
return {
IMAGEPREVIEW(node, inputName, inputData) {
// Find or create the URL input widget
const urlWidget = node.addWidget(
"string",
inputName,
/* value=*/ "",
() => {},
{ serialize: true },
);
const buttonWidget = node.addWidget(
"button",
"Open Assets Browser",
/* value=*/ "",
() => {
sendEventToCD("assets", {
node: node.id,
inputName: inputName,
});
// console.log("load image");
},
{ serialize: false },
);
console.log(node.widgets);
console.log("urlWidget", urlWidget);
// Add image preview functionality
function showImage(url) {
const img = new Image();
img.onload = () => {
node.imgs = [img];
app.graph.setDirtyCanvas(true);
node.setSizeForImage?.();
};
img.onerror = () => {
node.imgs = [];
app.graph.setDirtyCanvas(true);
};
img.src = url;
}
// Set up URL widget value handling
let default_value = urlWidget.value;
Object.defineProperty(urlWidget, "value", {
set: function (value) {
this._real_value = value;
// Preview image when URL changes
if (value) {
showImage(value);
}
},
get: function () {
return this._real_value || default_value;
},
});
// Show initial image if URL exists
requestAnimationFrame(() => {
if (urlWidget.value) {
showImage(urlWidget.value);
}
});
return { widget: urlWidget };
},
};
},
async setup() { async setup() {
// const graphCanvas = document.getElementById("graph-canvas"); // const graphCanvas = document.getElementById("graph-canvas");
@ -726,7 +351,6 @@ const ext = {
} }
console.log("loadGraphData"); console.log("loadGraphData");
app.loadGraphData(comfyUIWorkflow); app.loadGraphData(comfyUIWorkflow);
sendEventToCD("graph_loaded");
} }
} else if (message.type === "deploy") { } else if (message.type === "deploy") {
// deployWorkflow(); // deployWorkflow();
@ -741,35 +365,11 @@ const ext = {
console.warn("api.handlePromptGenerated is not a function"); console.warn("api.handlePromptGenerated is not a function");
} }
sendEventToCD("cd_plugin_onQueuePrompt", prompt); sendEventToCD("cd_plugin_onQueuePrompt", prompt);
} else if (message.type === "configure_queue_buttons") {
addQueueButtons(message.data);
} else if (message.type === "configure_menu_right_buttons") {
addMenuRightButtons(message.data);
} else if (message.type === "configure_menu_buttons") {
addMenuButtons(message.data);
} else if (message.type === "get_prompt") { } else if (message.type === "get_prompt") {
const prompt = await app.graphToPrompt(); const prompt = await app.graphToPrompt();
sendEventToCD("cd_plugin_onGetPrompt", prompt); sendEventToCD("cd_plugin_onGetPrompt", prompt);
} else if (message.type === "event") { } else if (message.type === "event") {
dispatchAPIEventData(message.data); dispatchAPIEventData(message.data);
} else if (message.type === "update_widget") {
// New handler for updating widget values
const { nodeId, widgetName, value } = message.data;
const node = app.graph.getNodeById(nodeId);
if (!node) {
console.warn(`Node with ID ${nodeId} not found`);
return;
}
const widget = node.widgets?.find((w) => w.name === widgetName);
if (!widget) {
console.warn(`Widget ${widgetName} not found in node ${nodeId}`);
return;
}
widget.value = value;
app.graph.setDirtyCanvas(true);
} else if (message.type === "add_node") { } else if (message.type === "add_node") {
console.log("add node", message.data); console.log("add node", message.data);
app.graph.beforeChange(); app.graph.beforeChange();
@ -863,13 +463,13 @@ const ext = {
await app.ui.settings.setSettingValueAsync("Comfy.UseNewMenu", "Top"); await app.ui.settings.setSettingValueAsync("Comfy.UseNewMenu", "Top");
await app.ui.settings.setSettingValueAsync( await app.ui.settings.setSettingValueAsync(
"Comfy.Sidebar.Size", "Comfy.Sidebar.Size",
"small", "small"
); );
await app.ui.settings.setSettingValueAsync( await app.ui.settings.setSettingValueAsync(
"Comfy.Sidebar.Location", "Comfy.Sidebar.Location",
"left", "right"
); );
// localStorage.setItem("Comfy.MenuPosition.Docked", "true"); localStorage.setItem("Comfy.MenuPosition.Docked", "true");
console.log("native mode manmanman"); console.log("native mode manmanman");
} catch (error) { } catch (error) {
console.error("Error setting validation to false", error); console.error("Error setting validation to false", error);
@ -1402,6 +1002,8 @@ function addButton() {
app.registerExtension(ext); app.registerExtension(ext);
import { ComfyDialog, $el } from "../../scripts/ui.js";
export class InfoDialog extends ComfyDialog { export class InfoDialog extends ComfyDialog {
constructor() { constructor() {
super(); super();
@ -1872,7 +1474,7 @@ app.extensionManager.registerSidebarTab({
<div style="padding: 20px;"> <div style="padding: 20px;">
<h3>Comfy Deploy</h3> <h3>Comfy Deploy</h3>
<div id="deploy-container" style="margin-bottom: 20px;"></div> <div id="deploy-container" style="margin-bottom: 20px;"></div>
<div id="workflows-container" style="display: none;"> <div id="workflows-container">
<h4>Your Workflows</h4> <h4>Your Workflows</h4>
<div id="workflows-loading" style="display: flex; justify-content: center; align-items: center; height: 100px;"> <div id="workflows-loading" style="display: flex; justify-content: center; align-items: center; height: 100px;">
${loadingIcon} ${loadingIcon}
@ -1972,16 +1574,10 @@ async function loadWorkflowApi(versionId) {
const orginal_fetch_api = api.fetchApi; const orginal_fetch_api = api.fetchApi;
api.fetchApi = async (route, options) => { api.fetchApi = async (route, options) => {
// console.log("Fetch API called with args:", route, options, ext.native_mode); console.log("Fetch API called with args:", route, options, ext.native_mode);
if (route.startsWith("/prompt") && ext.native_mode) { if (route.startsWith("/prompt") && ext.native_mode) {
const info = await getSelectedWorkflowInfo(); const info = await getSelectedWorkflowInfo();
if (!info.workflow_id) {
console.log("No workflow id found, fallback to original fetch");
return await orginal_fetch_api.call(api, route, options);
}
console.log("info", info); console.log("info", info);
if (info) { if (info) {
const body = JSON.parse(options.body); const body = JSON.parse(options.body);
@ -1995,7 +1591,6 @@ api.fetchApi = async (route, options) => {
workflow_id: info.workflow_id, workflow_id: info.workflow_id,
native_run_api_endpoint: info.native_run_api_endpoint, native_run_api_endpoint: info.native_run_api_endpoint,
gpu_event_id: info.gpu_event_id, gpu_event_id: info.gpu_event_id,
gpu: info.gpu,
}; };
return await fetch("/comfyui-deploy/run", { return await fetch("/comfyui-deploy/run", {
@ -2011,306 +1606,3 @@ api.fetchApi = async (route, options) => {
return await orginal_fetch_api.call(api, route, options); return await orginal_fetch_api.call(api, route, options);
}; };
// Intercept window drag and drop events
const originalDropHandler = document.ondrop;
document.ondrop = async (e) => {
console.log("Drop event intercepted:", e);
// Prevent default browser behavior
e.preventDefault();
// Handle files if present
if (e.dataTransfer?.files?.length > 0) {
const files = Array.from(e.dataTransfer.files);
// Send file data to parent directly as JSON
sendDirectEventToCD("file_drop", {
files: files,
x: e.clientX,
y: e.clientY,
timestamp: Date.now(),
});
}
// Call original handler if exists
if (originalDropHandler) {
originalDropHandler(e);
}
};
const originalDragEnterHandler = document.ondragenter;
document.ondragenter = (e) => {
// Prevent default to allow drop
e.preventDefault();
// Send dragenter event to parent directly as JSON
sendDirectEventToCD("file_dragenter", {
x: e.clientX,
y: e.clientY,
timestamp: Date.now(),
});
if (originalDragEnterHandler) {
originalDragEnterHandler(e);
}
};
const originalDragLeaveHandler = document.ondragleave;
document.ondragleave = (e) => {
// Prevent default to allow drop
e.preventDefault();
// Send dragleave event to parent directly as JSON
sendDirectEventToCD("file_dragleave", {
x: e.clientX,
y: e.clientY,
timestamp: Date.now(),
});
if (originalDragLeaveHandler) {
originalDragLeaveHandler(e);
}
};
const originalDragOverHandler = document.ondragover;
document.ondragover = (e) => {
// Prevent default to allow drop
e.preventDefault();
// Send dragover event to parent directly as JSON
sendDirectEventToCD("file_dragover", {
x: e.clientX,
y: e.clientY,
timestamp: Date.now(),
});
if (originalDragOverHandler) {
originalDragOverHandler(e);
}
};
// Function to create a single button
function createQueueButton(config) {
const button = document.createElement("button");
button.id = `cd-button-${config.id}`;
button.className =
"p-button p-component p-button-icon-only p-button-secondary p-button-text";
button.innerHTML = `
<span class="p-button-icon pi ${config.icon}"></span>
<span class="p-button-label">&nbsp;</span>
`;
button.onclick = () => {
const eventData =
typeof config.eventData === "function"
? config.eventData()
: config.eventData || {};
sendEventToCD(config.event, eventData);
};
button.setAttribute("data-pd-tooltip", config.tooltip);
return button;
}
// Function to add buttons to queue group
function addQueueButtons(buttonConfigs = DEFAULT_BUTTONS) {
const queueButtonGroup = document.querySelector(".queue-button-group.flex");
if (!queueButtonGroup) return;
// Remove any existing CD buttons
const existingButtons =
queueButtonGroup.querySelectorAll('[id^="cd-button-"]');
existingButtons.forEach((button) => button.remove());
// Add new buttons
buttonConfigs.forEach((config) => {
const button = createQueueButton(config);
queueButtonGroup.appendChild(button);
});
}
// addMenuRightButtons([
// {
// id: "cd-button-save-image",
// icon: "pi-save",
// label: "Snapshot",
// tooltip: "Save the current image to your output directory.",
// event: "save_image",
// eventData: () => ({}),
// },
// ]);
// addMenuLeftButtons([
// {
// id: "cd-button-back",
// icon: `<svg width="16" height="16" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
// <path d="M15 18L9 12L15 6" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
// </svg>`,
// tooltip: "Go back to the previous page.",
// event: "back",
// eventData: () => ({}),
// },
// ]);
// addMenuButtons({
// containerSelector: "body > div.comfyui-body-top > div",
// buttonConfigs: [
// {
// id: "cd-button-workflow-1",
// icon: `<svg xmlns="http://www.w3.org/2000/svg" width="20" height="20" viewBox="0 0 24 24"><path fill="none" stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="m16 3l4 4l-4 4m-6-4h10M8 13l-4 4l4 4m-4-4h9"/></svg>`,
// label: "Workflow",
// tooltip: "Go to Workflow 1",
// event: "workflow_1",
// // btnClasses: "",
// eventData: () => ({}),
// },
// {
// id: "cd-button-workflow-3",
// // icon: `<svg xmlns="http://www.w3.org/2000/svg" width="20" height="20" viewBox="0 0 24 24"><path fill="none" stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="m16 3l4 4l-4 4m-6-4h10M8 13l-4 4l4 4m-4-4h9"/></svg>`,
// label: "v1",
// tooltip: "Go to Workflow 1",
// event: "workflow_1",
// // btnClasses: "",
// eventData: () => ({}),
// },
// {
// id: "cd-button-workflow-2",
// icon: `<svg xmlns="http://www.w3.org/2000/svg" width="20" height="20" viewBox="0 0 24 24"><g fill="none" stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="2"><path d="M12 3v6"/><circle cx="12" cy="12" r="3"/><path d="M12 15v6"/></g></svg>`,
// label: "Commit",
// tooltip: "Commit the current workflow",
// event: "commit",
// style: {
// backgroundColor: "oklch(.476 .114 61.907)",
// },
// eventData: () => ({}),
// },
// ],
// buttonIdPrefix: "cd-button-workflow-",
// insertBefore:
// "body > div.comfyui-body-top > div > div.flex-grow.min-w-0.app-drag.h-full",
// // containerStyle: { order: "3" }
// });
// addMenuButtons({
// containerSelector:
// "body > div.comfyui-body-top > div > div.flex-grow.min-w-0.app-drag.h-full",
// clearContainer: true,
// buttonConfigs: [],
// buttonIdPrefix: "cd-button-p-",
// containerStyle: { order: "-1" },
// });
// Function to add buttons to a menu container
function addMenuButtons(options) {
const {
containerSelector,
buttonConfigs,
buttonIdPrefix = "cd-button-",
containerClass = "comfyui-button-group",
containerStyle = {},
clearContainer = false,
insertBefore = null, // New option to specify selector for insertion point
} = options;
const menuContainer = document.querySelector(containerSelector);
if (!menuContainer) return;
// Remove any existing CD buttons
const existingButtons = document.querySelectorAll(
`[id^="${buttonIdPrefix}"]`,
);
existingButtons.forEach((button) => button.remove());
const container = document.createElement("div");
container.className = containerClass;
// Apply container styles
Object.assign(container.style, containerStyle);
// Clear existing content if specified
if (clearContainer) {
menuContainer.innerHTML = "";
}
// Create and add buttons
buttonConfigs.forEach((config) => {
const button = createMenuButton({
...config,
idPrefix: buttonIdPrefix,
});
container.appendChild(button);
});
// Insert before specified element if provided, otherwise append
if (insertBefore) {
const targetElement = menuContainer.querySelector(insertBefore);
if (targetElement) {
menuContainer.insertBefore(container, targetElement);
} else {
menuContainer.appendChild(container);
}
} else {
menuContainer.appendChild(container);
}
}
function createMenuButton(config) {
const {
id,
icon,
label,
btnClasses = "",
tooltip,
event,
eventData,
idPrefix,
style = {},
} = config;
const button = document.createElement("button");
button.id = `${idPrefix}${id}`;
button.className = `comfyui-button ${btnClasses}`;
Object.assign(button.style, style);
// Only add icon if provided
const iconHtml = icon
? icon.startsWith("<svg")
? icon
: `<span class="p-button-icon pi ${icon}"></span>`
: "";
button.innerHTML = `
${iconHtml}
${label ? `<span class="p-button-label text-sm">${label}</span>` : ""}
`;
button.onclick = () => {
const data =
typeof eventData === "function" ? eventData() : eventData || {};
sendEventToCD(event, data);
};
if (tooltip) {
button.setAttribute("data-pd-tooltip", tooltip);
}
return button;
}
// Refactored menu button functions
function addMenuLeftButtons(buttonConfigs) {
addMenuButtons({
containerSelector: "body > div.comfyui-body-top > div",
buttonConfigs,
buttonIdPrefix: "cd-button-left-",
containerStyle: { order: "-1" },
});
}
function addMenuRightButtons(buttonConfigs) {
addMenuButtons({
containerSelector: ".comfyui-menu-right .flex",
buttonConfigs,
buttonIdPrefix: "cd-button-",
containerStyle: {},
});
}

18
web-plugin/widgets.js Normal file
View File

@ -0,0 +1,18 @@
// /** @typedef {import('../../../web/scripts/api.js').api} API*/
// import { api as _api } from "../../scripts/api.js";
// /** @type {API} */
// export const api = _api;
/** @typedef {typeof import('../../../web/scripts/widgets.js').ComfyWidgets} Widgets*/
import { ComfyWidgets as _ComfyWidgets } from "../../scripts/widgets.js";
/**
* @type {Widgets}
*/
export const ComfyWidgets = _ComfyWidgets;
// import { LGraphNode as _LGraphNode } from "../../types/litegraph.js";
/** @typedef {typeof import('../../../web/types/litegraph.js').LGraphNode} LGraphNode*/
/** @type {LGraphNode}*/
export const LGraphNode = LiteGraph.LGraphNode;