Compare commits
	
		
			No commits in common. "main" and "simplify-js-imports" have entirely different histories.
		
	
	
		
			main
			...
			simplify-j
		
	
		
							
								
								
									
										3
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							@ -1,3 +1,2 @@
 | 
				
			|||||||
__pycache__
 | 
					__pycache__
 | 
				
			||||||
.DS_Store
 | 
					.DS_Store
 | 
				
			||||||
file-hash-cache.json
 | 
					 | 
				
			||||||
							
								
								
									
										504
									
								
								builder/modal-builder/src/main.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										504
									
								
								builder/modal-builder/src/main.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,504 @@
 | 
				
			|||||||
 | 
					from typing import Union, Optional, Dict, List
 | 
				
			||||||
 | 
					from pydantic import BaseModel, Field, field_validator
 | 
				
			||||||
 | 
					from fastapi import FastAPI, HTTPException, WebSocket, BackgroundTasks, WebSocketDisconnect
 | 
				
			||||||
 | 
					from fastapi.responses import JSONResponse
 | 
				
			||||||
 | 
					from fastapi.logger import logger as fastapi_logger
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
 | 
					from enum import Enum
 | 
				
			||||||
 | 
					import json
 | 
				
			||||||
 | 
					import subprocess
 | 
				
			||||||
 | 
					import time
 | 
				
			||||||
 | 
					from contextlib import asynccontextmanager
 | 
				
			||||||
 | 
					import asyncio
 | 
				
			||||||
 | 
					import threading
 | 
				
			||||||
 | 
					import signal
 | 
				
			||||||
 | 
					import logging
 | 
				
			||||||
 | 
					from fastapi.logger import logger as fastapi_logger
 | 
				
			||||||
 | 
					import requests
 | 
				
			||||||
 | 
					from urllib.parse import parse_qs
 | 
				
			||||||
 | 
					from starlette.middleware.base import BaseHTTPMiddleware
 | 
				
			||||||
 | 
					from starlette.types import ASGIApp, Scope, Receive, Send
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from concurrent.futures import ThreadPoolExecutor
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# executor = ThreadPoolExecutor(max_workers=5)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					gunicorn_error_logger = logging.getLogger("gunicorn.error")
 | 
				
			||||||
 | 
					gunicorn_logger = logging.getLogger("gunicorn")
 | 
				
			||||||
 | 
					uvicorn_access_logger = logging.getLogger("uvicorn.access")
 | 
				
			||||||
 | 
					uvicorn_access_logger.handlers = gunicorn_error_logger.handlers
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					fastapi_logger.handlers = gunicorn_error_logger.handlers
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if __name__ != "__main__":
 | 
				
			||||||
 | 
					    fastapi_logger.setLevel(gunicorn_logger.level)
 | 
				
			||||||
 | 
					else:
 | 
				
			||||||
 | 
					    fastapi_logger.setLevel(logging.DEBUG)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					logger = logging.getLogger("uvicorn")
 | 
				
			||||||
 | 
					logger.setLevel(logging.INFO)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					last_activity_time = time.time()
 | 
				
			||||||
 | 
					global_timeout = 60 * 4
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					machine_id_websocket_dict = {}
 | 
				
			||||||
 | 
					machine_id_status = {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					fly_instance_id = os.environ.get('FLY_ALLOC_ID', 'local').split('-')[0]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class FlyReplayMiddleware(BaseHTTPMiddleware):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    If the wrong instance was picked by the fly.io load balancer we use the fly-replay header
 | 
				
			||||||
 | 
					    to repeat the request again on the right instance.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    This only works if the right instance is provided as a query_string parameter.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self, app: ASGIApp) -> None:
 | 
				
			||||||
 | 
					        self.app = app
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
 | 
				
			||||||
 | 
					        query_string = scope.get('query_string', b'').decode()
 | 
				
			||||||
 | 
					        query_params = parse_qs(query_string)
 | 
				
			||||||
 | 
					        target_instance = query_params.get(
 | 
				
			||||||
 | 
					            'fly_instance_id', [fly_instance_id])[0]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        async def send_wrapper(message):
 | 
				
			||||||
 | 
					            if target_instance != fly_instance_id:
 | 
				
			||||||
 | 
					                if message['type'] == 'websocket.close' and 'Invalid session' in message['reason']:
 | 
				
			||||||
 | 
					                    # fly.io only seems to look at the fly-replay header if websocket is accepted
 | 
				
			||||||
 | 
					                    message = {'type': 'websocket.accept'}
 | 
				
			||||||
 | 
					                if 'headers' not in message:
 | 
				
			||||||
 | 
					                    message['headers'] = []
 | 
				
			||||||
 | 
					                message['headers'].append(
 | 
				
			||||||
 | 
					                    [b'fly-replay', f'instance={target_instance}'.encode()])
 | 
				
			||||||
 | 
					            await send(message)
 | 
				
			||||||
 | 
					        await self.app(scope, receive, send_wrapper)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					async def check_inactivity():
 | 
				
			||||||
 | 
					    global last_activity_time
 | 
				
			||||||
 | 
					    while True:
 | 
				
			||||||
 | 
					        # logger.info("Checking inactivity...")
 | 
				
			||||||
 | 
					        if time.time() - last_activity_time > global_timeout:
 | 
				
			||||||
 | 
					            if len(machine_id_status) == 0:
 | 
				
			||||||
 | 
					                # The application has been inactive for more than 60 seconds.
 | 
				
			||||||
 | 
					                # Scale it down to zero here.
 | 
				
			||||||
 | 
					                logger.info(
 | 
				
			||||||
 | 
					                    f"No activity for {global_timeout} seconds, exiting...")
 | 
				
			||||||
 | 
					                # os._exit(0)
 | 
				
			||||||
 | 
					                os.kill(os.getpid(), signal.SIGINT)
 | 
				
			||||||
 | 
					                break
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                pass
 | 
				
			||||||
 | 
					                # logger.info(f"Timeout but still in progress")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        await asyncio.sleep(1)  # Check every second
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@asynccontextmanager
 | 
				
			||||||
 | 
					async def lifespan(app: FastAPI):
 | 
				
			||||||
 | 
					    thread = run_in_new_thread(check_inactivity())
 | 
				
			||||||
 | 
					    yield
 | 
				
			||||||
 | 
					    logger.info("Cancelling")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					app = FastAPI(lifespan=lifespan)
 | 
				
			||||||
 | 
					app.add_middleware(FlyReplayMiddleware)
 | 
				
			||||||
 | 
					# MODAL_ORG = os.environ.get("MODAL_ORG")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@app.get("/")
 | 
				
			||||||
 | 
					def read_root():
 | 
				
			||||||
 | 
					    global last_activity_time
 | 
				
			||||||
 | 
					    last_activity_time = time.time()
 | 
				
			||||||
 | 
					    logger.info(f"Extended inactivity time to {global_timeout}")
 | 
				
			||||||
 | 
					    return {"Hello": "World"}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# create a post route called /create takes in a json of example
 | 
				
			||||||
 | 
					# {
 | 
				
			||||||
 | 
					#     name: "my first image",
 | 
				
			||||||
 | 
					#     deps: {
 | 
				
			||||||
 | 
					#         "comfyui": "d0165d819afe76bd4e6bdd710eb5f3e571b6a804",
 | 
				
			||||||
 | 
					#         "git_custom_nodes": {
 | 
				
			||||||
 | 
					#             "https://github.com/cubiq/ComfyUI_IPAdapter_plus": {
 | 
				
			||||||
 | 
					#                 "hash": "2ca0c6dd0b2ad64b1c480828638914a564331dcd",
 | 
				
			||||||
 | 
					#                 "disabled": true
 | 
				
			||||||
 | 
					#             },
 | 
				
			||||||
 | 
					#             "https://github.com/ltdrdata/ComfyUI-Manager.git": {
 | 
				
			||||||
 | 
					#                 "hash": "9c86f62b912f4625fe2b929c7fc61deb9d16f6d3",
 | 
				
			||||||
 | 
					#                 "disabled": false
 | 
				
			||||||
 | 
					#             },
 | 
				
			||||||
 | 
					#         },
 | 
				
			||||||
 | 
					#         "file_custom_nodes": []
 | 
				
			||||||
 | 
					#     }
 | 
				
			||||||
 | 
					# }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class GitCustomNodes(BaseModel):
 | 
				
			||||||
 | 
					    hash: str
 | 
				
			||||||
 | 
					    disabled: bool
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class FileCustomNodes(BaseModel):
 | 
				
			||||||
 | 
					    filename: str
 | 
				
			||||||
 | 
					    disabled: bool
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Snapshot(BaseModel):
 | 
				
			||||||
 | 
					    comfyui: str
 | 
				
			||||||
 | 
					    git_custom_nodes: Dict[str, GitCustomNodes]
 | 
				
			||||||
 | 
					    file_custom_nodes: List[FileCustomNodes]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Model(BaseModel):
 | 
				
			||||||
 | 
					    name: str
 | 
				
			||||||
 | 
					    type: str
 | 
				
			||||||
 | 
					    base: str
 | 
				
			||||||
 | 
					    save_path: str
 | 
				
			||||||
 | 
					    description: str
 | 
				
			||||||
 | 
					    reference: str
 | 
				
			||||||
 | 
					    filename: str
 | 
				
			||||||
 | 
					    url: str
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class GPUType(str, Enum):
 | 
				
			||||||
 | 
					    T4 = "T4"
 | 
				
			||||||
 | 
					    A10G = "A10G"
 | 
				
			||||||
 | 
					    A100 = "A100"
 | 
				
			||||||
 | 
					    L4 = "L4"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Item(BaseModel):
 | 
				
			||||||
 | 
					    machine_id: str
 | 
				
			||||||
 | 
					    name: str
 | 
				
			||||||
 | 
					    snapshot: Snapshot
 | 
				
			||||||
 | 
					    models: List[Model]
 | 
				
			||||||
 | 
					    callback_url: str
 | 
				
			||||||
 | 
					    gpu: GPUType = Field(default=GPUType.T4)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @field_validator('gpu')
 | 
				
			||||||
 | 
					    @classmethod
 | 
				
			||||||
 | 
					    def check_gpu(cls, value):
 | 
				
			||||||
 | 
					        if value not in GPUType.__members__:
 | 
				
			||||||
 | 
					            raise ValueError(
 | 
				
			||||||
 | 
					                f"Invalid GPU option. Choose from: {', '.join(GPUType.__members__.keys())}")
 | 
				
			||||||
 | 
					        return GPUType(value)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@app.websocket("/ws/{machine_id}")
 | 
				
			||||||
 | 
					async def websocket_endpoint(websocket: WebSocket, machine_id: str):
 | 
				
			||||||
 | 
					    await websocket.accept()
 | 
				
			||||||
 | 
					    machine_id_websocket_dict[machine_id] = websocket
 | 
				
			||||||
 | 
					    # Send existing logs
 | 
				
			||||||
 | 
					    if machine_id in machine_logs_cache:
 | 
				
			||||||
 | 
					        combined_logs = "\n".join(
 | 
				
			||||||
 | 
					            log_entry['logs'] for log_entry in machine_logs_cache[machine_id])
 | 
				
			||||||
 | 
					        await websocket.send_text(json.dumps({"event": "LOGS", "data": {
 | 
				
			||||||
 | 
					            "machine_id": machine_id,
 | 
				
			||||||
 | 
					            "logs": combined_logs,
 | 
				
			||||||
 | 
					            "timestamp": time.time()
 | 
				
			||||||
 | 
					        }}))
 | 
				
			||||||
 | 
					    try:
 | 
				
			||||||
 | 
					        while True:
 | 
				
			||||||
 | 
					            data = await websocket.receive_text()
 | 
				
			||||||
 | 
					            global last_activity_time
 | 
				
			||||||
 | 
					            last_activity_time = time.time()
 | 
				
			||||||
 | 
					            logger.info(f"Extended inactivity time to {global_timeout}")
 | 
				
			||||||
 | 
					            # You can handle received messages here if needed
 | 
				
			||||||
 | 
					    except WebSocketDisconnect:
 | 
				
			||||||
 | 
					        if machine_id in machine_id_websocket_dict:
 | 
				
			||||||
 | 
					            machine_id_websocket_dict.pop(machine_id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# @app.get("/test")
 | 
				
			||||||
 | 
					# async def test():
 | 
				
			||||||
 | 
					#     machine_id_status["123"] = True
 | 
				
			||||||
 | 
					#     global last_activity_time
 | 
				
			||||||
 | 
					#     last_activity_time = time.time()
 | 
				
			||||||
 | 
					#     logger.info(f"Extended inactivity time to {global_timeout}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#     await asyncio.sleep(10)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#     machine_id_status["123"] = False
 | 
				
			||||||
 | 
					#     machine_id_status.pop("123")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#     return {"Hello": "World"}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@app.post("/create")
 | 
				
			||||||
 | 
					async def create_machine(item: Item):
 | 
				
			||||||
 | 
					    global last_activity_time
 | 
				
			||||||
 | 
					    last_activity_time = time.time()
 | 
				
			||||||
 | 
					    logger.info(f"Extended inactivity time to {global_timeout}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if item.machine_id in machine_id_status and machine_id_status[item.machine_id]:
 | 
				
			||||||
 | 
					        return JSONResponse(status_code=400, content={"error": "Build already in progress."})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Run the building logic in a separate thread
 | 
				
			||||||
 | 
					    # future = executor.submit(build_logic, item)
 | 
				
			||||||
 | 
					    task = asyncio.create_task(build_logic(item))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return JSONResponse(status_code=200, content={"message": "Build Queued", "build_machine_instance_id": fly_instance_id})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class StopAppItem(BaseModel):
 | 
				
			||||||
 | 
					    machine_id: str
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def find_app_id(app_list, app_name):
 | 
				
			||||||
 | 
					    for app in app_list:
 | 
				
			||||||
 | 
					        if app['Name'] == app_name:
 | 
				
			||||||
 | 
					            return app['App ID']
 | 
				
			||||||
 | 
					    return None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@app.post("/stop-app")
 | 
				
			||||||
 | 
					async def stop_app(item: StopAppItem):
 | 
				
			||||||
 | 
					    # cmd = f"modal app list | grep {item.machine_id} | awk -F '│' '{{print $2}}'"
 | 
				
			||||||
 | 
					    cmd = f"modal app list --json"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    env = os.environ.copy()
 | 
				
			||||||
 | 
					    env["COLUMNS"] = "10000"  # Set the width to a large value
 | 
				
			||||||
 | 
					    find_id_process = await asyncio.subprocess.create_subprocess_shell(cmd,
 | 
				
			||||||
 | 
					                                                                      stdout=asyncio.subprocess.PIPE,
 | 
				
			||||||
 | 
					                                                                      stderr=asyncio.subprocess.PIPE,
 | 
				
			||||||
 | 
					                                                                      env=env)
 | 
				
			||||||
 | 
					    await find_id_process.wait()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    stdout, stderr = await find_id_process.communicate()
 | 
				
			||||||
 | 
					    if stdout:
 | 
				
			||||||
 | 
					        app_id = stdout.decode().strip()
 | 
				
			||||||
 | 
					        app_list = json.loads(app_id)
 | 
				
			||||||
 | 
					        app_id = find_app_id(app_list, item.machine_id)
 | 
				
			||||||
 | 
					        logger.info(f"cp_process stdout: {app_id}")
 | 
				
			||||||
 | 
					    if stderr:
 | 
				
			||||||
 | 
					        logger.info(f"cp_process stderr: {stderr.decode()}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    cp_process = await asyncio.subprocess.create_subprocess_exec("modal", "app", "stop", app_id,
 | 
				
			||||||
 | 
					                                                                 stdout=asyncio.subprocess.PIPE,
 | 
				
			||||||
 | 
					                                                                 stderr=asyncio.subprocess.PIPE,)
 | 
				
			||||||
 | 
					    await cp_process.wait()
 | 
				
			||||||
 | 
					    logger.info(f"Stopping app {item.machine_id}")
 | 
				
			||||||
 | 
					    stdout, stderr = await cp_process.communicate()
 | 
				
			||||||
 | 
					    if stdout:
 | 
				
			||||||
 | 
					        logger.info(f"cp_process stdout: {stdout.decode()}")
 | 
				
			||||||
 | 
					    if stderr:
 | 
				
			||||||
 | 
					        logger.info(f"cp_process stderr: {stderr.decode()}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if cp_process.returncode == 0:
 | 
				
			||||||
 | 
					        return JSONResponse(status_code=200, content={"status": "success"})
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        return JSONResponse(status_code=500, content={"status": "error", "error": stderr.decode()})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Initialize the logs cache
 | 
				
			||||||
 | 
					machine_logs_cache = {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					async def build_logic(item: Item):
 | 
				
			||||||
 | 
					    # Deploy to modal
 | 
				
			||||||
 | 
					    folder_path = f"/app/builds/{item.machine_id}"
 | 
				
			||||||
 | 
					    machine_id_status[item.machine_id] = True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Ensure the os path is same as the current directory
 | 
				
			||||||
 | 
					    # os.chdir(os.path.dirname(os.path.realpath(__file__)))
 | 
				
			||||||
 | 
					    # print(
 | 
				
			||||||
 | 
					    #     f"builder - Current working directory: {os.getcwd()}"
 | 
				
			||||||
 | 
					    # )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Copy the app template
 | 
				
			||||||
 | 
					    # os.system(f"cp -r template {folder_path}")
 | 
				
			||||||
 | 
					    cp_process = await asyncio.subprocess.create_subprocess_exec("cp", "-r", "/app/src/template", folder_path)
 | 
				
			||||||
 | 
					    await cp_process.wait()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Write the config file
 | 
				
			||||||
 | 
					    config = {
 | 
				
			||||||
 | 
					        "name": item.name,
 | 
				
			||||||
 | 
					        "deploy_test": os.environ.get("DEPLOY_TEST_FLAG", "False"),
 | 
				
			||||||
 | 
					        "gpu": item.gpu,
 | 
				
			||||||
 | 
					        "civitai_token": os.environ.get("CIVITAI_TOKEN", "")
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    with open(f"{folder_path}/config.py", "w") as f:
 | 
				
			||||||
 | 
					        f.write("config = " + json.dumps(config))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    with open(f"{folder_path}/data/snapshot.json", "w") as f:
 | 
				
			||||||
 | 
					        f.write(item.snapshot.json())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    with open(f"{folder_path}/data/models.json", "w") as f:
 | 
				
			||||||
 | 
					        models_json_list = [model.dict() for model in item.models]
 | 
				
			||||||
 | 
					        models_json_string = json.dumps(models_json_list)
 | 
				
			||||||
 | 
					        f.write(models_json_string)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # os.chdir(folder_path)
 | 
				
			||||||
 | 
					    # process = subprocess.Popen(f"modal deploy {folder_path}/app.py", stdout=subprocess.PIPE, stderr=subprocess.STDOUT, shell=True)
 | 
				
			||||||
 | 
					    process = await asyncio.subprocess.create_subprocess_shell(
 | 
				
			||||||
 | 
					        f"modal deploy app.py",
 | 
				
			||||||
 | 
					        stdout=asyncio.subprocess.PIPE,
 | 
				
			||||||
 | 
					        stderr=asyncio.subprocess.PIPE,
 | 
				
			||||||
 | 
					        cwd=folder_path,
 | 
				
			||||||
 | 
					        env={**os.environ, "COLUMNS": "10000"}
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    url = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if item.machine_id not in machine_logs_cache:
 | 
				
			||||||
 | 
					        machine_logs_cache[item.machine_id] = []
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    machine_logs = machine_logs_cache[item.machine_id]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    url_queue = asyncio.Queue()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def read_stream(stream, isStderr, url_queue: asyncio.Queue):
 | 
				
			||||||
 | 
					        while True:
 | 
				
			||||||
 | 
					            line = await stream.readline()
 | 
				
			||||||
 | 
					            if line:
 | 
				
			||||||
 | 
					                l = line.decode('utf-8').strip()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                if l == "":
 | 
				
			||||||
 | 
					                    continue
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                if not isStderr:
 | 
				
			||||||
 | 
					                    logger.info(l)
 | 
				
			||||||
 | 
					                    machine_logs.append({
 | 
				
			||||||
 | 
					                        "logs": l,
 | 
				
			||||||
 | 
					                        "timestamp": time.time()
 | 
				
			||||||
 | 
					                    })
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    if item.machine_id in machine_id_websocket_dict:
 | 
				
			||||||
 | 
					                        await machine_id_websocket_dict[item.machine_id].send_text(json.dumps({"event": "LOGS", "data": {
 | 
				
			||||||
 | 
					                            "machine_id": item.machine_id,
 | 
				
			||||||
 | 
					                            "logs": l,
 | 
				
			||||||
 | 
					                            "timestamp": time.time()
 | 
				
			||||||
 | 
					                        }}))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    if "Created comfyui_api =>" in l or ((l.startswith("https://") or l.startswith("│")) and l.endswith(".modal.run")):
 | 
				
			||||||
 | 
					                        if "Created comfyui_api =>" in l:
 | 
				
			||||||
 | 
					                            url = l.split("=>")[1].strip()
 | 
				
			||||||
 | 
					                        # making sure it is a url
 | 
				
			||||||
 | 
					                        elif "comfyui-api" in l:
 | 
				
			||||||
 | 
					                            # Some case it only prints the url on a blank line
 | 
				
			||||||
 | 
					                            if l.startswith("│"):
 | 
				
			||||||
 | 
					                                url = l.split("│")[1].strip()
 | 
				
			||||||
 | 
					                            else:
 | 
				
			||||||
 | 
					                                url = l
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                        if url:
 | 
				
			||||||
 | 
					                            machine_logs.append({
 | 
				
			||||||
 | 
					                                "logs": f"App image built, url: {url}",
 | 
				
			||||||
 | 
					                                "timestamp": time.time()
 | 
				
			||||||
 | 
					                            })
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                            await url_queue.put(url)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                            if item.machine_id in machine_id_websocket_dict:
 | 
				
			||||||
 | 
					                                await machine_id_websocket_dict[item.machine_id].send_text(json.dumps({"event": "LOGS", "data": {
 | 
				
			||||||
 | 
					                                    "machine_id": item.machine_id,
 | 
				
			||||||
 | 
					                                    "logs": f"App image built, url: {url}",
 | 
				
			||||||
 | 
					                                    "timestamp": time.time()
 | 
				
			||||||
 | 
					                                }}))
 | 
				
			||||||
 | 
					                                await machine_id_websocket_dict[item.machine_id].send_text(json.dumps({"event": "FINISHED", "data": {
 | 
				
			||||||
 | 
					                                    "status": "succuss",
 | 
				
			||||||
 | 
					                                }}))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                else:
 | 
				
			||||||
 | 
					                    # is error
 | 
				
			||||||
 | 
					                    logger.error(l)
 | 
				
			||||||
 | 
					                    machine_logs.append({
 | 
				
			||||||
 | 
					                        "logs": l,
 | 
				
			||||||
 | 
					                        "timestamp": time.time()
 | 
				
			||||||
 | 
					                    })
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    if item.machine_id in machine_id_websocket_dict:
 | 
				
			||||||
 | 
					                        await machine_id_websocket_dict[item.machine_id].send_text(json.dumps({"event": "LOGS", "data": {
 | 
				
			||||||
 | 
					                            "machine_id": item.machine_id,
 | 
				
			||||||
 | 
					                            "logs": l,
 | 
				
			||||||
 | 
					                            "timestamp": time.time()
 | 
				
			||||||
 | 
					                        }}))
 | 
				
			||||||
 | 
					                        await machine_id_websocket_dict[item.machine_id].send_text(json.dumps({"event": "FINISHED", "data": {
 | 
				
			||||||
 | 
					                            "status": "failed",
 | 
				
			||||||
 | 
					                        }}))
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                break
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    stdout_task = asyncio.create_task(
 | 
				
			||||||
 | 
					        read_stream(process.stdout, False, url_queue))
 | 
				
			||||||
 | 
					    stderr_task = asyncio.create_task(
 | 
				
			||||||
 | 
					        read_stream(process.stderr, True, url_queue))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    await asyncio.wait([stdout_task, stderr_task])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Wait for the subprocess to finish
 | 
				
			||||||
 | 
					    await process.wait()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if not url_queue.empty():
 | 
				
			||||||
 | 
					        # The queue is not empty, you can get an item
 | 
				
			||||||
 | 
					        url = await url_queue.get()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Close the ws connection and also pop the item
 | 
				
			||||||
 | 
					    if item.machine_id in machine_id_websocket_dict and machine_id_websocket_dict[item.machine_id] is not None:
 | 
				
			||||||
 | 
					        await machine_id_websocket_dict[item.machine_id].close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if item.machine_id in machine_id_websocket_dict:
 | 
				
			||||||
 | 
					        machine_id_websocket_dict.pop(item.machine_id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if item.machine_id in machine_id_status:
 | 
				
			||||||
 | 
					        machine_id_status[item.machine_id] = False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Check for errors
 | 
				
			||||||
 | 
					    if process.returncode != 0:
 | 
				
			||||||
 | 
					        logger.info("An error occurred.")
 | 
				
			||||||
 | 
					        # Send a post request with the json body machine_id to the callback url
 | 
				
			||||||
 | 
					        machine_logs.append({
 | 
				
			||||||
 | 
					            "logs": "Unable to build the app image.",
 | 
				
			||||||
 | 
					            "timestamp": time.time()
 | 
				
			||||||
 | 
					        })
 | 
				
			||||||
 | 
					        requests.post(item.callback_url, json={
 | 
				
			||||||
 | 
					                      "machine_id": item.machine_id, "build_log": json.dumps(machine_logs)})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if item.machine_id in machine_logs_cache:
 | 
				
			||||||
 | 
					            del machine_logs_cache[item.machine_id]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return
 | 
				
			||||||
 | 
					        # return JSONResponse(status_code=400, content={"error": "Unable to build the app image."})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # app_suffix = "comfyui-app"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if url is None:
 | 
				
			||||||
 | 
					        machine_logs.append({
 | 
				
			||||||
 | 
					            "logs": "App image built, but url is None, unable to parse the url.",
 | 
				
			||||||
 | 
					            "timestamp": time.time()
 | 
				
			||||||
 | 
					        })
 | 
				
			||||||
 | 
					        requests.post(item.callback_url, json={
 | 
				
			||||||
 | 
					                      "machine_id": item.machine_id, "build_log": json.dumps(machine_logs)})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if item.machine_id in machine_logs_cache:
 | 
				
			||||||
 | 
					            del machine_logs_cache[item.machine_id]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return
 | 
				
			||||||
 | 
					        # return JSONResponse(status_code=400, content={"error": "App image built, but url is None, unable to parse the url."})
 | 
				
			||||||
 | 
					    # example https://bennykok--my-app-comfyui-app.modal.run/
 | 
				
			||||||
 | 
					    # my_url = f"https://{MODAL_ORG}--{item.container_id}-{app_suffix}.modal.run"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    requests.post(item.callback_url, json={
 | 
				
			||||||
 | 
					                  "machine_id": item.machine_id, "endpoint": url, "build_log": json.dumps(machine_logs)})
 | 
				
			||||||
 | 
					    if item.machine_id in machine_logs_cache:
 | 
				
			||||||
 | 
					        del machine_logs_cache[item.machine_id]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    logger.info("done")
 | 
				
			||||||
 | 
					    logger.info(url)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def start_loop(loop):
 | 
				
			||||||
 | 
					    asyncio.set_event_loop(loop)
 | 
				
			||||||
 | 
					    loop.run_forever()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def run_in_new_thread(coroutine):
 | 
				
			||||||
 | 
					    new_loop = asyncio.new_event_loop()
 | 
				
			||||||
 | 
					    t = threading.Thread(target=start_loop, args=(new_loop,), daemon=True)
 | 
				
			||||||
 | 
					    t.start()
 | 
				
			||||||
 | 
					    asyncio.run_coroutine_threadsafe(coroutine, new_loop)
 | 
				
			||||||
 | 
					    return t
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if __name__ == "__main__":
 | 
				
			||||||
 | 
					    import uvicorn
 | 
				
			||||||
 | 
					    # , log_level="debug"
 | 
				
			||||||
 | 
					    uvicorn.run("main:app", host="0.0.0.0", port=8080, lifespan="on")
 | 
				
			||||||
@ -1,448 +0,0 @@
 | 
				
			|||||||
import modal
 | 
					 | 
				
			||||||
from typing import Union, Optional, Dict, List
 | 
					 | 
				
			||||||
from pydantic import BaseModel, Field, field_validator
 | 
					 | 
				
			||||||
from fastapi import FastAPI, HTTPException, WebSocket, BackgroundTasks, WebSocketDisconnect
 | 
					 | 
				
			||||||
from fastapi.responses import JSONResponse
 | 
					 | 
				
			||||||
from fastapi.logger import logger as fastapi_logger
 | 
					 | 
				
			||||||
import os
 | 
					 | 
				
			||||||
from enum import Enum
 | 
					 | 
				
			||||||
import json
 | 
					 | 
				
			||||||
import subprocess
 | 
					 | 
				
			||||||
import time
 | 
					 | 
				
			||||||
from contextlib import asynccontextmanager
 | 
					 | 
				
			||||||
import asyncio
 | 
					 | 
				
			||||||
import threading
 | 
					 | 
				
			||||||
import signal
 | 
					 | 
				
			||||||
import logging
 | 
					 | 
				
			||||||
from fastapi.logger import logger as fastapi_logger
 | 
					 | 
				
			||||||
import requests
 | 
					 | 
				
			||||||
from urllib.parse import parse_qs
 | 
					 | 
				
			||||||
from starlette.middleware.base import BaseHTTPMiddleware
 | 
					 | 
				
			||||||
from starlette.types import ASGIApp, Scope, Receive, Send
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# Modal应用实例
 | 
					 | 
				
			||||||
modal_app = modal.App(name="comfyui-deploy")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
gunicorn_error_logger = logging.getLogger("gunicorn.error")
 | 
					 | 
				
			||||||
gunicorn_logger = logging.getLogger("gunicorn")
 | 
					 | 
				
			||||||
uvicorn_access_logger = logging.getLogger("uvicorn.access")
 | 
					 | 
				
			||||||
uvicorn_access_logger.handlers = gunicorn_error_logger.handlers
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
fastapi_logger.handlers = gunicorn_error_logger.handlers
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
if __name__ != "__main__":
 | 
					 | 
				
			||||||
    fastapi_logger.setLevel(gunicorn_logger.level)
 | 
					 | 
				
			||||||
else:
 | 
					 | 
				
			||||||
    fastapi_logger.setLevel(logging.DEBUG)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
logger = logging.getLogger("uvicorn")
 | 
					 | 
				
			||||||
logger.setLevel(logging.INFO)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
last_activity_time = time.time()
 | 
					 | 
				
			||||||
global_timeout = 60 * 4
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
machine_id_websocket_dict = {}
 | 
					 | 
				
			||||||
machine_id_status = {}
 | 
					 | 
				
			||||||
machine_logs_cache = {}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
fly_instance_id = os.environ.get('FLY_ALLOC_ID', 'local').split('-')[0]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class FlyReplayMiddleware(BaseHTTPMiddleware):
 | 
					 | 
				
			||||||
    def __init__(self, app: ASGIApp) -> None:
 | 
					 | 
				
			||||||
        super().__init__(app)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
 | 
					 | 
				
			||||||
        query_string = scope.get('query_string', b'').decode()
 | 
					 | 
				
			||||||
        query_params = parse_qs(query_string)
 | 
					 | 
				
			||||||
        target_instance = query_params.get('fly_instance_id', [fly_instance_id])[0]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        async def send_wrapper(message):
 | 
					 | 
				
			||||||
            if target_instance != fly_instance_id:
 | 
					 | 
				
			||||||
                if message['type'] == 'websocket.close' and 'Invalid session' in message.get('reason', ''):
 | 
					 | 
				
			||||||
                    message = {'type': 'websocket.accept'}
 | 
					 | 
				
			||||||
                if 'headers' not in message:
 | 
					 | 
				
			||||||
                    message['headers'] = []
 | 
					 | 
				
			||||||
                message['headers'].append([b'fly-replay', f'instance={target_instance}'.encode()])
 | 
					 | 
				
			||||||
            await send(message)
 | 
					 | 
				
			||||||
        await self.app(scope, receive, send_wrapper)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
async def check_inactivity():
 | 
					 | 
				
			||||||
    global last_activity_time
 | 
					 | 
				
			||||||
    while True:
 | 
					 | 
				
			||||||
        if time.time() - last_activity_time > global_timeout:
 | 
					 | 
				
			||||||
            if len(machine_id_status) == 0:
 | 
					 | 
				
			||||||
                logger.info(f"No activity for {global_timeout} seconds, exiting...")
 | 
					 | 
				
			||||||
                os.kill(os.getpid(), signal.SIGINT)
 | 
					 | 
				
			||||||
                break
 | 
					 | 
				
			||||||
        await asyncio.sleep(1)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
@asynccontextmanager
 | 
					 | 
				
			||||||
async def lifespan(app: FastAPI):
 | 
					 | 
				
			||||||
    thread = run_in_new_thread(check_inactivity())
 | 
					 | 
				
			||||||
    yield
 | 
					 | 
				
			||||||
    logger.info("Cancelling")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# FastAPI实例
 | 
					 | 
				
			||||||
fastapi_app = FastAPI(lifespan=lifespan)
 | 
					 | 
				
			||||||
fastapi_app.add_middleware(FlyReplayMiddleware)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class GitCustomNodes(BaseModel):
 | 
					 | 
				
			||||||
    hash: str
 | 
					 | 
				
			||||||
    disabled: bool
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class FileCustomNodes(BaseModel):
 | 
					 | 
				
			||||||
    filename: str
 | 
					 | 
				
			||||||
    disabled: bool
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class Snapshot(BaseModel):
 | 
					 | 
				
			||||||
    comfyui: str
 | 
					 | 
				
			||||||
    git_custom_nodes: Dict[str, GitCustomNodes]
 | 
					 | 
				
			||||||
    file_custom_nodes: List[FileCustomNodes]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class Model(BaseModel):
 | 
					 | 
				
			||||||
    name: str
 | 
					 | 
				
			||||||
    type: str
 | 
					 | 
				
			||||||
    base: str
 | 
					 | 
				
			||||||
    save_path: str
 | 
					 | 
				
			||||||
    description: str
 | 
					 | 
				
			||||||
    reference: str
 | 
					 | 
				
			||||||
    filename: str
 | 
					 | 
				
			||||||
    url: str
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class GPUType(str, Enum):
 | 
					 | 
				
			||||||
    T4 = "T4"
 | 
					 | 
				
			||||||
    A10G = "A10G"
 | 
					 | 
				
			||||||
    A100 = "A100"
 | 
					 | 
				
			||||||
    L4 = "L4"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class Item(BaseModel):
 | 
					 | 
				
			||||||
    machine_id: str
 | 
					 | 
				
			||||||
    name: str
 | 
					 | 
				
			||||||
    snapshot: Snapshot
 | 
					 | 
				
			||||||
    models: List[Model]
 | 
					 | 
				
			||||||
    callback_url: str
 | 
					 | 
				
			||||||
    gpu: GPUType = Field(default=GPUType.T4)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @field_validator('gpu')
 | 
					 | 
				
			||||||
    @classmethod
 | 
					 | 
				
			||||||
    def check_gpu(cls, value):
 | 
					 | 
				
			||||||
        if value not in GPUType.__members__:
 | 
					 | 
				
			||||||
            raise ValueError(f"Invalid GPU option. Choose from: {', '.join(GPUType.__members__.keys())}")
 | 
					 | 
				
			||||||
        return GPUType(value)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class StopAppItem(BaseModel):
 | 
					 | 
				
			||||||
    machine_id: str
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
@fastapi_app.get("/")
 | 
					 | 
				
			||||||
def read_root():
 | 
					 | 
				
			||||||
    global last_activity_time
 | 
					 | 
				
			||||||
    last_activity_time = time.time()
 | 
					 | 
				
			||||||
    logger.info(f"Extended inactivity time to {global_timeout}")
 | 
					 | 
				
			||||||
    return {"Hello": "World"}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
@fastapi_app.websocket("/ws/{machine_id}")
 | 
					 | 
				
			||||||
async def websocket_endpoint(websocket: WebSocket, machine_id: str):
 | 
					 | 
				
			||||||
    await websocket.accept()
 | 
					 | 
				
			||||||
    machine_id_websocket_dict[machine_id] = websocket
 | 
					 | 
				
			||||||
    if machine_id in machine_logs_cache:
 | 
					 | 
				
			||||||
        combined_logs = "\n".join(log_entry['logs'] for log_entry in machine_logs_cache[machine_id])
 | 
					 | 
				
			||||||
        await websocket.send_text(json.dumps({
 | 
					 | 
				
			||||||
            "event": "LOGS", 
 | 
					 | 
				
			||||||
            "data": {
 | 
					 | 
				
			||||||
                "machine_id": machine_id,
 | 
					 | 
				
			||||||
                "logs": combined_logs,
 | 
					 | 
				
			||||||
                "timestamp": time.time()
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
        }))
 | 
					 | 
				
			||||||
    try:
 | 
					 | 
				
			||||||
        while True:
 | 
					 | 
				
			||||||
            data = await websocket.receive_text()
 | 
					 | 
				
			||||||
            global last_activity_time
 | 
					 | 
				
			||||||
            last_activity_time = time.time()
 | 
					 | 
				
			||||||
            logger.info(f"Extended inactivity time to {global_timeout}")
 | 
					 | 
				
			||||||
    except WebSocketDisconnect:
 | 
					 | 
				
			||||||
        if machine_id in machine_id_websocket_dict:
 | 
					 | 
				
			||||||
            del machine_id_websocket_dict[machine_id]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
@fastapi_app.post("/create")
 | 
					 | 
				
			||||||
async def create_machine(item: Item):
 | 
					 | 
				
			||||||
    global last_activity_time
 | 
					 | 
				
			||||||
    last_activity_time = time.time()
 | 
					 | 
				
			||||||
    logger.info(f"Extended inactivity time to {global_timeout}")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if item.machine_id in machine_id_status and machine_id_status[item.machine_id]:
 | 
					 | 
				
			||||||
        return JSONResponse(status_code=400, content={"error": "Build already in progress."})
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    task = asyncio.create_task(build_logic(item))
 | 
					 | 
				
			||||||
    return JSONResponse(
 | 
					 | 
				
			||||||
        status_code=200, 
 | 
					 | 
				
			||||||
        content={
 | 
					 | 
				
			||||||
            "message": "Build Queued",
 | 
					 | 
				
			||||||
            "build_machine_instance_id": fly_instance_id
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def find_app_id(app_list, app_name):
 | 
					 | 
				
			||||||
    for app in app_list:
 | 
					 | 
				
			||||||
        if app['Name'] == app_name:
 | 
					 | 
				
			||||||
            return app['App ID']
 | 
					 | 
				
			||||||
    return None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
@fastapi_app.post("/stop-app")
 | 
					 | 
				
			||||||
async def stop_app(item: StopAppItem):
 | 
					 | 
				
			||||||
    cmd = f"modal app list --json"
 | 
					 | 
				
			||||||
    env = os.environ.copy()
 | 
					 | 
				
			||||||
    env["COLUMNS"] = "10000"
 | 
					 | 
				
			||||||
    
 | 
					 | 
				
			||||||
    find_id_process = await asyncio.subprocess.create_subprocess_shell(
 | 
					 | 
				
			||||||
        cmd,
 | 
					 | 
				
			||||||
        stdout=asyncio.subprocess.PIPE,
 | 
					 | 
				
			||||||
        stderr=asyncio.subprocess.PIPE,
 | 
					 | 
				
			||||||
        env=env
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
    
 | 
					 | 
				
			||||||
    stdout, stderr = await find_id_process.communicate()
 | 
					 | 
				
			||||||
    if stdout:
 | 
					 | 
				
			||||||
        app_list = json.loads(stdout.decode().strip())
 | 
					 | 
				
			||||||
        app_id = find_app_id(app_list, item.machine_id)
 | 
					 | 
				
			||||||
        logger.info(f"cp_process stdout: {app_id}")
 | 
					 | 
				
			||||||
    if stderr:
 | 
					 | 
				
			||||||
        logger.info(f"cp_process stderr: {stderr.decode()}")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    cp_process = await asyncio.subprocess.create_subprocess_exec(
 | 
					 | 
				
			||||||
        "modal", "app", "stop", app_id,
 | 
					 | 
				
			||||||
        stdout=asyncio.subprocess.PIPE,
 | 
					 | 
				
			||||||
        stderr=asyncio.subprocess.PIPE,
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
    
 | 
					 | 
				
			||||||
    await cp_process.wait()
 | 
					 | 
				
			||||||
    stdout, stderr = await cp_process.communicate()
 | 
					 | 
				
			||||||
    
 | 
					 | 
				
			||||||
    if stdout:
 | 
					 | 
				
			||||||
        logger.info(f"cp_process stdout: {stdout.decode()}")
 | 
					 | 
				
			||||||
    if stderr:
 | 
					 | 
				
			||||||
        logger.info(f"cp_process stderr: {stderr.decode()}")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if cp_process.returncode == 0:
 | 
					 | 
				
			||||||
        return JSONResponse(status_code=200, content={"status": "success"})
 | 
					 | 
				
			||||||
    else:
 | 
					 | 
				
			||||||
        return JSONResponse(
 | 
					 | 
				
			||||||
            status_code=500, 
 | 
					 | 
				
			||||||
            content={"status": "error", "error": stderr.decode()}
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
async def build_logic(item: Item):
 | 
					 | 
				
			||||||
    folder_path = f"/app/builds/{item.machine_id}"
 | 
					 | 
				
			||||||
    machine_id_status[item.machine_id] = True
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    cp_process = await asyncio.subprocess.create_subprocess_exec(
 | 
					 | 
				
			||||||
        "cp", "-r", "/app/src/template", folder_path
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
    await cp_process.wait()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    config = {
 | 
					 | 
				
			||||||
        "name": item.name,
 | 
					 | 
				
			||||||
        "deploy_test": os.environ.get("DEPLOY_TEST_FLAG", "False"),
 | 
					 | 
				
			||||||
        "gpu": item.gpu,
 | 
					 | 
				
			||||||
        "civitai_token": os.environ.get("CIVITAI_TOKEN", "833b4ded5c7757a06a803763500bab58")
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
    
 | 
					 | 
				
			||||||
    with open(f"{folder_path}/config.py", "w") as f:
 | 
					 | 
				
			||||||
        f.write("config = " + json.dumps(config))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    with open(f"{folder_path}/data/snapshot.json", "w") as f:
 | 
					 | 
				
			||||||
        f.write(item.snapshot.json())
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    with open(f"{folder_path}/data/models.json", "w") as f:
 | 
					 | 
				
			||||||
        models_json_list = [model.dict() for model in item.models]
 | 
					 | 
				
			||||||
        f.write(json.dumps(models_json_list))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    process = await asyncio.subprocess.create_subprocess_shell(
 | 
					 | 
				
			||||||
        f"modal deploy app.py",
 | 
					 | 
				
			||||||
        stdout=asyncio.subprocess.PIPE,
 | 
					 | 
				
			||||||
        stderr=asyncio.subprocess.PIPE,
 | 
					 | 
				
			||||||
        cwd=folder_path,
 | 
					 | 
				
			||||||
        env={**os.environ, "COLUMNS": "10000"}
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if item.machine_id not in machine_logs_cache:
 | 
					 | 
				
			||||||
        machine_logs_cache[item.machine_id] = []
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    machine_logs = machine_logs_cache[item.machine_id]
 | 
					 | 
				
			||||||
    url_queue = asyncio.Queue()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    async def read_stream(stream, isStderr, url_queue: asyncio.Queue):
 | 
					 | 
				
			||||||
        while True:
 | 
					 | 
				
			||||||
            line = await stream.readline()
 | 
					 | 
				
			||||||
            if not line:
 | 
					 | 
				
			||||||
                break
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            l = line.decode('utf-8').strip()
 | 
					 | 
				
			||||||
            if not l:
 | 
					 | 
				
			||||||
                continue
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            if not isStderr:
 | 
					 | 
				
			||||||
                logger.info(l)
 | 
					 | 
				
			||||||
                machine_logs.append({
 | 
					 | 
				
			||||||
                    "logs": l,
 | 
					 | 
				
			||||||
                    "timestamp": time.time()
 | 
					 | 
				
			||||||
                })
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                if item.machine_id in machine_id_websocket_dict:
 | 
					 | 
				
			||||||
                    await machine_id_websocket_dict[item.machine_id].send_text(
 | 
					 | 
				
			||||||
                        json.dumps({
 | 
					 | 
				
			||||||
                            "event": "LOGS",
 | 
					 | 
				
			||||||
                            "data": {
 | 
					 | 
				
			||||||
                                "machine_id": item.machine_id,
 | 
					 | 
				
			||||||
                                "logs": l,
 | 
					 | 
				
			||||||
                                "timestamp": time.time()
 | 
					 | 
				
			||||||
                            }
 | 
					 | 
				
			||||||
                        })
 | 
					 | 
				
			||||||
                    )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                if "Created comfyui_api =>" in l or ((l.startswith("https://") or l.startswith("│")) and l.endswith(".modal.run")):
 | 
					 | 
				
			||||||
                    if "Created comfyui_api =>" in l:
 | 
					 | 
				
			||||||
                        url = l.split("=>")[1].strip()
 | 
					 | 
				
			||||||
                    elif "comfyui-api" in l:
 | 
					 | 
				
			||||||
                        url = l.split("│")[1].strip() if l.startswith("│") else l
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                    if url:
 | 
					 | 
				
			||||||
                        machine_logs.append({
 | 
					 | 
				
			||||||
                            "logs": f"App image built, url: {url}",
 | 
					 | 
				
			||||||
                            "timestamp": time.time()
 | 
					 | 
				
			||||||
                        })
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                        await url_queue.put(url)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                        if item.machine_id in machine_id_websocket_dict:
 | 
					 | 
				
			||||||
                            await machine_id_websocket_dict[item.machine_id].send_text(
 | 
					 | 
				
			||||||
                                json.dumps({
 | 
					 | 
				
			||||||
                                    "event": "LOGS",
 | 
					 | 
				
			||||||
                                    "data": {
 | 
					 | 
				
			||||||
                                        "machine_id": item.machine_id,
 | 
					 | 
				
			||||||
                                        "logs": f"App image built, url: {url}",
 | 
					 | 
				
			||||||
                                        "timestamp": time.time()
 | 
					 | 
				
			||||||
                                    }
 | 
					 | 
				
			||||||
                                })
 | 
					 | 
				
			||||||
                            )
 | 
					 | 
				
			||||||
                            await machine_id_websocket_dict[item.machine_id].send_text(
 | 
					 | 
				
			||||||
                                json.dumps({
 | 
					 | 
				
			||||||
                                    "event": "FINISHED",
 | 
					 | 
				
			||||||
                                    "data": {
 | 
					 | 
				
			||||||
                                        "status": "success",
 | 
					 | 
				
			||||||
                                    }
 | 
					 | 
				
			||||||
                                })
 | 
					 | 
				
			||||||
                            )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            else:
 | 
					 | 
				
			||||||
                logger.error(l)
 | 
					 | 
				
			||||||
                machine_logs.append({
 | 
					 | 
				
			||||||
                    "logs": l,
 | 
					 | 
				
			||||||
                    "timestamp": time.time()
 | 
					 | 
				
			||||||
                })
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                if item.machine_id in machine_id_websocket_dict:
 | 
					 | 
				
			||||||
                    await machine_id_websocket_dict[item.machine_id].send_text(
 | 
					 | 
				
			||||||
                        json.dumps({
 | 
					 | 
				
			||||||
                            "event": "LOGS",
 | 
					 | 
				
			||||||
                            "data": {
 | 
					 | 
				
			||||||
                                "machine_id": item.machine_id,
 | 
					 | 
				
			||||||
                                "logs": l,
 | 
					 | 
				
			||||||
                                "timestamp": time.time()
 | 
					 | 
				
			||||||
                            }
 | 
					 | 
				
			||||||
                        })
 | 
					 | 
				
			||||||
                    )
 | 
					 | 
				
			||||||
                    await machine_id_websocket_dict[item.machine_id].send_text(
 | 
					 | 
				
			||||||
                        json.dumps({
 | 
					 | 
				
			||||||
                            "event": "FINISHED",
 | 
					 | 
				
			||||||
                            "data": {
 | 
					 | 
				
			||||||
                                "status": "failed",
 | 
					 | 
				
			||||||
                            }
 | 
					 | 
				
			||||||
                        })
 | 
					 | 
				
			||||||
                    )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    stdout_task = asyncio.create_task(read_stream(process.stdout, False, url_queue))
 | 
					 | 
				
			||||||
    stderr_task = asyncio.create_task(read_stream(process.stderr, True, url_queue))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    await asyncio.wait([stdout_task, stderr_task])
 | 
					 | 
				
			||||||
    await process.wait()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    url = await url_queue.get() if not url_queue.empty() else None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if item.machine_id in machine_id_websocket_dict and machine_id_websocket_dict[item.machine_id] is not None:
 | 
					 | 
				
			||||||
        await machine_id_websocket_dict[item.machine_id].close()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if item.machine_id in machine_id_websocket_dict:
 | 
					 | 
				
			||||||
        del machine_id_websocket_dict[item.machine_id]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if item.machine_id in machine_id_status:
 | 
					 | 
				
			||||||
        machine_id_status[item.machine_id] = False
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if process.returncode != 0:
 | 
					 | 
				
			||||||
        logger.info("An error occurred.")
 | 
					 | 
				
			||||||
        machine_logs.append({
 | 
					 | 
				
			||||||
            "logs": "Unable to build the app image.",
 | 
					 | 
				
			||||||
            "timestamp": time.time()
 | 
					 | 
				
			||||||
        })
 | 
					 | 
				
			||||||
        requests.post(
 | 
					 | 
				
			||||||
            item.callback_url,
 | 
					 | 
				
			||||||
            json={
 | 
					 | 
				
			||||||
                "machine_id": item.machine_id,
 | 
					 | 
				
			||||||
                "build_log": json.dumps(machine_logs)
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        if item.machine_id in machine_logs_cache:
 | 
					 | 
				
			||||||
            del machine_logs_cache[item.machine_id]
 | 
					 | 
				
			||||||
        return
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if url is None:
 | 
					 | 
				
			||||||
        machine_logs.append({
 | 
					 | 
				
			||||||
            "logs": "App image built, but url is None, unable to parse the url.",
 | 
					 | 
				
			||||||
            "timestamp": time.time()
 | 
					 | 
				
			||||||
        })
 | 
					 | 
				
			||||||
        requests.post(
 | 
					 | 
				
			||||||
            item.callback_url,
 | 
					 | 
				
			||||||
            json={
 | 
					 | 
				
			||||||
                "machine_id": item.machine_id,
 | 
					 | 
				
			||||||
                "build_log": json.dumps(machine_logs)
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        if item.machine_id in machine_logs_cache:
 | 
					 | 
				
			||||||
            del machine_logs_cache[item.machine_id]
 | 
					 | 
				
			||||||
        return
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    requests.post(
 | 
					 | 
				
			||||||
        item.callback_url,
 | 
					 | 
				
			||||||
        json={
 | 
					 | 
				
			||||||
            "machine_id": item.machine_id,
 | 
					 | 
				
			||||||
            "endpoint": url,
 | 
					 | 
				
			||||||
            "build_log": json.dumps(machine_logs)
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
    
 | 
					 | 
				
			||||||
    if item.machine_id in machine_logs_cache:
 | 
					 | 
				
			||||||
        del machine_logs_cache[item.machine_id]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    logger.info("done")
 | 
					 | 
				
			||||||
    logger.info(url)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def start_loop(loop):
 | 
					 | 
				
			||||||
    asyncio.set_event_loop(loop)
 | 
					 | 
				
			||||||
    loop.run_forever()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def run_in_new_thread(coroutine):
 | 
					 | 
				
			||||||
    new_loop = asyncio.new_event_loop()
 | 
					 | 
				
			||||||
    t = threading.Thread(target=start_loop, args=(new_loop,), daemon=True)
 | 
					 | 
				
			||||||
    t.start()
 | 
					 | 
				
			||||||
    asyncio.run_coroutine_threadsafe(coroutine, new_loop)
 | 
					 | 
				
			||||||
    return t
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# Modal endpoint
 | 
					 | 
				
			||||||
@modal_app.function()
 | 
					 | 
				
			||||||
@modal.asgi_app()
 | 
					 | 
				
			||||||
def app():
 | 
					 | 
				
			||||||
    return fastapi_app
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
if __name__ == "__main__":
 | 
					 | 
				
			||||||
    import uvicorn
 | 
					 | 
				
			||||||
    uvicorn.run(fastapi_app, host="0.0.0.0", port=8080, lifespan="on")
 | 
					 | 
				
			||||||
@ -307,5 +307,4 @@ def comfyui_app():
 | 
				
			|||||||
        },
 | 
					        },
 | 
				
			||||||
    )()
 | 
					    )()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    proxy_app = make_simple_proxy_app(ProxyContext(config)) # Assign to variable
 | 
					    return make_simple_proxy_app(ProxyContext(config))
 | 
				
			||||||
    return proxy_app # Return the variable
 | 
					 | 
				
			||||||
@ -1,57 +0,0 @@
 | 
				
			|||||||
import os
 | 
					 | 
				
			||||||
import io
 | 
					 | 
				
			||||||
import torchaudio
 | 
					 | 
				
			||||||
from folder_paths import get_annotated_filepath
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class ComfyUIDeployExternalAudio:
 | 
					 | 
				
			||||||
    RETURN_TYPES = ("AUDIO",)
 | 
					 | 
				
			||||||
    RETURN_NAMES = ("audio",)
 | 
					 | 
				
			||||||
    FUNCTION = "load_audio"
 | 
					 | 
				
			||||||
    
 | 
					 | 
				
			||||||
    @classmethod
 | 
					 | 
				
			||||||
    def INPUT_TYPES(cls):
 | 
					 | 
				
			||||||
        return {
 | 
					 | 
				
			||||||
            "required": {
 | 
					 | 
				
			||||||
                "input_id": (
 | 
					 | 
				
			||||||
                    "STRING",
 | 
					 | 
				
			||||||
                    {"multiline": False, "default": "input_audio"},
 | 
					 | 
				
			||||||
                ),
 | 
					 | 
				
			||||||
                "audio_file": ("STRING", {"default": ""}),
 | 
					 | 
				
			||||||
            },
 | 
					 | 
				
			||||||
            "optional": {
 | 
					 | 
				
			||||||
                "default_value": ("AUDIO",),
 | 
					 | 
				
			||||||
                "display_name": (
 | 
					 | 
				
			||||||
                    "STRING",
 | 
					 | 
				
			||||||
                    {"multiline": False, "default": ""},
 | 
					 | 
				
			||||||
                ),
 | 
					 | 
				
			||||||
                "description": (
 | 
					 | 
				
			||||||
                    "STRING",
 | 
					 | 
				
			||||||
                    {"multiline": False, "default": ""},
 | 
					 | 
				
			||||||
                ),
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
    
 | 
					 | 
				
			||||||
    @classmethod
 | 
					 | 
				
			||||||
    def VALIDATE_INPUTS(s, audio_file, **kwargs):
 | 
					 | 
				
			||||||
        return True
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def load_audio(self, input_id, audio_file, default_value=None, display_name=None, description=None):
 | 
					 | 
				
			||||||
        if audio_file and audio_file != "":
 | 
					 | 
				
			||||||
            if audio_file.startswith(('http://', 'https://')):
 | 
					 | 
				
			||||||
                # Handle URL input
 | 
					 | 
				
			||||||
                import requests
 | 
					 | 
				
			||||||
                response = requests.get(audio_file)
 | 
					 | 
				
			||||||
                audio_data = io.BytesIO(response.content)
 | 
					 | 
				
			||||||
                waveform, sample_rate = torchaudio.load(audio_data)
 | 
					 | 
				
			||||||
            else:
 | 
					 | 
				
			||||||
                # Handle local file
 | 
					 | 
				
			||||||
                audio_path = get_annotated_filepath(audio_file)
 | 
					 | 
				
			||||||
                waveform, sample_rate = torchaudio.load(audio_path)
 | 
					 | 
				
			||||||
            
 | 
					 | 
				
			||||||
            audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
 | 
					 | 
				
			||||||
            return (audio,)
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            return (default_value,)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
NODE_CLASS_MAPPINGS = {"ComfyUIDeployExternalAudio": ComfyUIDeployExternalAudio}
 | 
					 | 
				
			||||||
NODE_DISPLAY_NAME_MAPPINGS = {"ComfyUIDeployExternalAudio": "External Audio (ComfyUI Deploy)"}
 | 
					 | 
				
			||||||
@ -21,9 +21,8 @@ class ComfyUIDeployExternalImage:
 | 
				
			|||||||
                ),
 | 
					                ),
 | 
				
			||||||
                "description": (
 | 
					                "description": (
 | 
				
			||||||
                    "STRING",
 | 
					                    "STRING",
 | 
				
			||||||
                    {"multiline": False, "default": ""},
 | 
					                    {"multiline": True, "default": ""},
 | 
				
			||||||
                ),
 | 
					                ),
 | 
				
			||||||
                "default_value_url": ("STRING", {"image_preview": True, "default": ""}),
 | 
					 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -34,44 +33,32 @@ class ComfyUIDeployExternalImage:
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    CATEGORY = "image"
 | 
					    CATEGORY = "image"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def run(self, input_id, default_value=None, display_name=None, description=None, default_value_url=None):
 | 
					    def run(self, input_id, default_value=None, display_name=None, description=None):
 | 
				
			||||||
        image = default_value
 | 
					        image = default_value
 | 
				
			||||||
        
 | 
					        try:
 | 
				
			||||||
        # Try both input_id and default_value_url
 | 
					            if input_id.startswith('http'):
 | 
				
			||||||
        urls_to_try = [url for url in [input_id, default_value_url] if url]
 | 
					                import requests
 | 
				
			||||||
        
 | 
					                from io import BytesIO
 | 
				
			||||||
        print(default_value_url)
 | 
					                print("Fetching image from url: ", input_id)
 | 
				
			||||||
        
 | 
					                response = requests.get(input_id)
 | 
				
			||||||
        for url in urls_to_try:
 | 
					                image = Image.open(BytesIO(response.content))
 | 
				
			||||||
            try:
 | 
					            elif input_id.startswith('data:image/png;base64,') or input_id.startswith('data:image/jpeg;base64,') or input_id.startswith('data:image/jpg;base64,'):
 | 
				
			||||||
                if url.startswith('http'):
 | 
					                import base64
 | 
				
			||||||
                    import requests
 | 
					                from io import BytesIO
 | 
				
			||||||
                    from io import BytesIO
 | 
					                print("Decoding base64 image")
 | 
				
			||||||
                    print(f"Fetching image from url: {url}")
 | 
					                base64_image = input_id[input_id.find(",")+1:]
 | 
				
			||||||
                    response = requests.get(url)
 | 
					                decoded_image = base64.b64decode(base64_image)
 | 
				
			||||||
                    image = Image.open(BytesIO(response.content))
 | 
					                image = Image.open(BytesIO(decoded_image))
 | 
				
			||||||
                    break
 | 
					            else:
 | 
				
			||||||
                elif url.startswith(('data:image/png;base64,', 'data:image/jpeg;base64,', 'data:image/jpg;base64,')):
 | 
					                raise ValueError("Invalid image url provided.")
 | 
				
			||||||
                    import base64
 | 
					
 | 
				
			||||||
                    from io import BytesIO
 | 
					            image = ImageOps.exif_transpose(image)
 | 
				
			||||||
                    print("Decoding base64 image")
 | 
					            image = image.convert("RGB")
 | 
				
			||||||
                    base64_image = url[url.find(",")+1:]
 | 
					            image = np.array(image).astype(np.float32) / 255.0
 | 
				
			||||||
                    decoded_image = base64.b64decode(base64_image)
 | 
					            image = torch.from_numpy(image)[None,]
 | 
				
			||||||
                    image = Image.open(BytesIO(decoded_image))
 | 
					            return [image]
 | 
				
			||||||
                    break
 | 
					        except:
 | 
				
			||||||
            except:
 | 
					            return [image]
 | 
				
			||||||
                continue
 | 
					 | 
				
			||||||
        
 | 
					 | 
				
			||||||
        if image is not None:
 | 
					 | 
				
			||||||
            try:
 | 
					 | 
				
			||||||
                image = ImageOps.exif_transpose(image)
 | 
					 | 
				
			||||||
                image = image.convert("RGB")
 | 
					 | 
				
			||||||
                image = np.array(image).astype(np.float32) / 255.0
 | 
					 | 
				
			||||||
                image = torch.from_numpy(image)[None,]
 | 
					 | 
				
			||||||
            except:
 | 
					 | 
				
			||||||
                pass
 | 
					 | 
				
			||||||
                
 | 
					 | 
				
			||||||
        return [image]
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
NODE_CLASS_MAPPINGS = {"ComfyUIDeployExternalImage": ComfyUIDeployExternalImage}
 | 
					NODE_CLASS_MAPPINGS = {"ComfyUIDeployExternalImage": ComfyUIDeployExternalImage}
 | 
				
			||||||
 | 
				
			|||||||
@ -1,92 +0,0 @@
 | 
				
			|||||||
import os
 | 
					 | 
				
			||||||
import json
 | 
					 | 
				
			||||||
import numpy as np
 | 
					 | 
				
			||||||
from PIL import Image
 | 
					 | 
				
			||||||
from PIL.PngImagePlugin import PngInfo
 | 
					 | 
				
			||||||
import folder_paths
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class ComfyDeployOutputImage:
 | 
					 | 
				
			||||||
    def __init__(self):
 | 
					 | 
				
			||||||
        self.output_dir = folder_paths.get_output_directory()
 | 
					 | 
				
			||||||
        self.type = "output"
 | 
					 | 
				
			||||||
        self.prefix_append = ""
 | 
					 | 
				
			||||||
        self.compress_level = 4
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @classmethod
 | 
					 | 
				
			||||||
    def INPUT_TYPES(s):
 | 
					 | 
				
			||||||
        return {
 | 
					 | 
				
			||||||
            "required": {
 | 
					 | 
				
			||||||
                "images": ("IMAGE", {"tooltip": "The images to save."}),
 | 
					 | 
				
			||||||
                "filename_prefix": (
 | 
					 | 
				
			||||||
                    "STRING",
 | 
					 | 
				
			||||||
                    {
 | 
					 | 
				
			||||||
                        "default": "ComfyUI",
 | 
					 | 
				
			||||||
                        "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes.",
 | 
					 | 
				
			||||||
                    },
 | 
					 | 
				
			||||||
                ),
 | 
					 | 
				
			||||||
                "file_type": (["png", "jpg", "webp"], {"default": "webp"}),
 | 
					 | 
				
			||||||
                "quality": ("INT", {"default": 80, "min": 1, "max": 100, "step": 1}),
 | 
					 | 
				
			||||||
            },
 | 
					 | 
				
			||||||
            "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    RETURN_TYPES = ()
 | 
					 | 
				
			||||||
    FUNCTION = "run"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    OUTPUT_NODE = True
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    CATEGORY = "output"
 | 
					 | 
				
			||||||
    DESCRIPTION = "Saves the input images to your ComfyUI output directory."
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def run(
 | 
					 | 
				
			||||||
        self,
 | 
					 | 
				
			||||||
        images,
 | 
					 | 
				
			||||||
        filename_prefix="ComfyUI",
 | 
					 | 
				
			||||||
        file_type="png",
 | 
					 | 
				
			||||||
        quality=80,
 | 
					 | 
				
			||||||
        prompt=None,
 | 
					 | 
				
			||||||
        extra_pnginfo=None,
 | 
					 | 
				
			||||||
    ):
 | 
					 | 
				
			||||||
        filename_prefix += self.prefix_append
 | 
					 | 
				
			||||||
        full_output_folder, filename, counter, subfolder, filename_prefix = (
 | 
					 | 
				
			||||||
            folder_paths.get_save_image_path(
 | 
					 | 
				
			||||||
                filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        results = list()
 | 
					 | 
				
			||||||
        for batch_number, image in enumerate(images):
 | 
					 | 
				
			||||||
            i = 255.0 * image.cpu().numpy()
 | 
					 | 
				
			||||||
            img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
 | 
					 | 
				
			||||||
            metadata = PngInfo()
 | 
					 | 
				
			||||||
            if prompt is not None:
 | 
					 | 
				
			||||||
                metadata.add_text("prompt", json.dumps(prompt))
 | 
					 | 
				
			||||||
            if extra_pnginfo is not None:
 | 
					 | 
				
			||||||
                for x in extra_pnginfo:
 | 
					 | 
				
			||||||
                    metadata.add_text(x, json.dumps(extra_pnginfo[x]))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
 | 
					 | 
				
			||||||
            file = f"{filename_with_batch_num}_{counter:05}_.{file_type}"
 | 
					 | 
				
			||||||
            file_path = os.path.join(full_output_folder, file)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            if file_type == "png":
 | 
					 | 
				
			||||||
                img.save(
 | 
					 | 
				
			||||||
                    file_path, pnginfo=metadata, compress_level=self.compress_level
 | 
					 | 
				
			||||||
                )
 | 
					 | 
				
			||||||
            elif file_type == "jpg":
 | 
					 | 
				
			||||||
                img.save(file_path, quality=quality, optimize=True)
 | 
					 | 
				
			||||||
            elif file_type == "webp":
 | 
					 | 
				
			||||||
                img.save(file_path, quality=quality)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            results.append(
 | 
					 | 
				
			||||||
                {"filename": file, "subfolder": subfolder, "type": self.type}
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
            counter += 1
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        return {"ui": {"images": results}}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
NODE_CLASS_MAPPINGS = {"ComfyDeployOutputImage": ComfyDeployOutputImage}
 | 
					 | 
				
			||||||
NODE_DISPLAY_NAME_MAPPINGS = {
 | 
					 | 
				
			||||||
    "ComfyDeployOutputImage": "Image Output (ComfyDeploy)"
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
@ -386,9 +386,6 @@ def apply_inputs_to_workflow(workflow_api: Any, inputs: Any, sid: str = None):
 | 
				
			|||||||
                if value["class_type"] == "ComfyUIDeployExternalFaceModel":
 | 
					                if value["class_type"] == "ComfyUIDeployExternalFaceModel":
 | 
				
			||||||
                    value["inputs"]["face_model_url"] = new_value
 | 
					                    value["inputs"]["face_model_url"] = new_value
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                if value["class_type"] == "ComfyUIDeployExternalAudio":
 | 
					 | 
				
			||||||
                    value["inputs"]["audio_file"] = new_value
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
def send_prompt(sid: str, inputs: StreamingPrompt):
 | 
					def send_prompt(sid: str, inputs: StreamingPrompt):
 | 
				
			||||||
    # workflow_api = inputs.workflow_api
 | 
					    # workflow_api = inputs.workflow_api
 | 
				
			||||||
@ -1283,8 +1280,6 @@ async def send_json_override(self, event, data, sid=None):
 | 
				
			|||||||
        if prompt_id in prompt_metadata:
 | 
					        if prompt_id in prompt_metadata:
 | 
				
			||||||
            prompt_metadata[prompt_id].start_time = time.perf_counter()
 | 
					            prompt_metadata[prompt_id].start_time = time.perf_counter()
 | 
				
			||||||
            
 | 
					            
 | 
				
			||||||
        logger.info("Executing prompt: " + prompt_id)
 | 
					 | 
				
			||||||
            
 | 
					 | 
				
			||||||
        asyncio.create_task(update_run(prompt_id, Status.RUNNING))
 | 
					        asyncio.create_task(update_run(prompt_id, Status.RUNNING))
 | 
				
			||||||
        
 | 
					        
 | 
				
			||||||
        
 | 
					        
 | 
				
			||||||
 | 
				
			|||||||
@ -2,7 +2,7 @@
 | 
				
			|||||||
name = "comfyui-deploy"
 | 
					name = "comfyui-deploy"
 | 
				
			||||||
description = "Open source comfyui deployment platform, a vercel for generative workflow infra."
 | 
					description = "Open source comfyui deployment platform, a vercel for generative workflow infra."
 | 
				
			||||||
version = "1.1.0"
 | 
					version = "1.1.0"
 | 
				
			||||||
license = { file = "LICENSE" }
 | 
					license = "LICENSE"
 | 
				
			||||||
dependencies = ["aiofiles", "pydantic", "opencv-python", "imageio-ffmpeg"]
 | 
					dependencies = ["aiofiles", "pydantic", "opencv-python", "imageio-ffmpeg"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
[project.urls]
 | 
					[project.urls]
 | 
				
			||||||
 | 
				
			|||||||
@ -50,14 +50,6 @@ function sendEventToCD(event, data) {
 | 
				
			|||||||
  window.parent.postMessage(JSON.stringify(message), "*");
 | 
					  window.parent.postMessage(JSON.stringify(message), "*");
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
function sendDirectEventToCD(event, data) {
 | 
					 | 
				
			||||||
  const message = {
 | 
					 | 
				
			||||||
    type: event,
 | 
					 | 
				
			||||||
    data: data,
 | 
					 | 
				
			||||||
  };
 | 
					 | 
				
			||||||
  window.parent.postMessage(message, "*");
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
function dispatchAPIEventData(data) {
 | 
					function dispatchAPIEventData(data) {
 | 
				
			||||||
  const msg = JSON.parse(data);
 | 
					  const msg = JSON.parse(data);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -494,13 +486,6 @@ const ext = {
 | 
				
			|||||||
      return r;
 | 
					      return r;
 | 
				
			||||||
    };
 | 
					    };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if (
 | 
					 | 
				
			||||||
      nodeData?.input?.optional?.default_value_url?.[1]?.image_preview === true
 | 
					 | 
				
			||||||
    ) {
 | 
					 | 
				
			||||||
      nodeData.input.optional.default_value_url = ["IMAGEPREVIEW"];
 | 
					 | 
				
			||||||
      console.log(nodeData.input.optional.default_value_url);
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    // const origonNodeCreated = nodeType.prototype.onNodeCreated;
 | 
					    // const origonNodeCreated = nodeType.prototype.onNodeCreated;
 | 
				
			||||||
    // nodeType.prototype.onNodeCreated = function () {
 | 
					    // nodeType.prototype.onNodeCreated = function () {
 | 
				
			||||||
    //   const r = origonNodeCreated
 | 
					    //   const r = origonNodeCreated
 | 
				
			||||||
@ -627,78 +612,6 @@ const ext = {
 | 
				
			|||||||
    ComfyDeploy.category = "deploy";
 | 
					    ComfyDeploy.category = "deploy";
 | 
				
			||||||
  },
 | 
					  },
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  getCustomWidgets() {
 | 
					 | 
				
			||||||
    return {
 | 
					 | 
				
			||||||
      IMAGEPREVIEW(node, inputName, inputData) {
 | 
					 | 
				
			||||||
        // Find or create the URL input widget
 | 
					 | 
				
			||||||
        const urlWidget = node.addWidget(
 | 
					 | 
				
			||||||
          "string",
 | 
					 | 
				
			||||||
          inputName,
 | 
					 | 
				
			||||||
          /* value=*/ "",
 | 
					 | 
				
			||||||
          () => {},
 | 
					 | 
				
			||||||
          { serialize: true },
 | 
					 | 
				
			||||||
        );
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        const buttonWidget = node.addWidget(
 | 
					 | 
				
			||||||
          "button",
 | 
					 | 
				
			||||||
          "Open Assets Browser",
 | 
					 | 
				
			||||||
          /* value=*/ "",
 | 
					 | 
				
			||||||
          () => {
 | 
					 | 
				
			||||||
            sendEventToCD("assets", {
 | 
					 | 
				
			||||||
              node: node.id,
 | 
					 | 
				
			||||||
              inputName: inputName,
 | 
					 | 
				
			||||||
            });
 | 
					 | 
				
			||||||
            // console.log("load image");
 | 
					 | 
				
			||||||
          },
 | 
					 | 
				
			||||||
          { serialize: false },
 | 
					 | 
				
			||||||
        );
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        console.log(node.widgets);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        console.log("urlWidget", urlWidget);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        // Add image preview functionality
 | 
					 | 
				
			||||||
        function showImage(url) {
 | 
					 | 
				
			||||||
          const img = new Image();
 | 
					 | 
				
			||||||
          img.onload = () => {
 | 
					 | 
				
			||||||
            node.imgs = [img];
 | 
					 | 
				
			||||||
            app.graph.setDirtyCanvas(true);
 | 
					 | 
				
			||||||
            node.setSizeForImage?.();
 | 
					 | 
				
			||||||
          };
 | 
					 | 
				
			||||||
          img.onerror = () => {
 | 
					 | 
				
			||||||
            node.imgs = [];
 | 
					 | 
				
			||||||
            app.graph.setDirtyCanvas(true);
 | 
					 | 
				
			||||||
          };
 | 
					 | 
				
			||||||
          img.src = url;
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        // Set up URL widget value handling
 | 
					 | 
				
			||||||
        let default_value = urlWidget.value;
 | 
					 | 
				
			||||||
        Object.defineProperty(urlWidget, "value", {
 | 
					 | 
				
			||||||
          set: function (value) {
 | 
					 | 
				
			||||||
            this._real_value = value;
 | 
					 | 
				
			||||||
            // Preview image when URL changes
 | 
					 | 
				
			||||||
            if (value) {
 | 
					 | 
				
			||||||
              showImage(value);
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
          },
 | 
					 | 
				
			||||||
          get: function () {
 | 
					 | 
				
			||||||
            return this._real_value || default_value;
 | 
					 | 
				
			||||||
          },
 | 
					 | 
				
			||||||
        });
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        // Show initial image if URL exists
 | 
					 | 
				
			||||||
        requestAnimationFrame(() => {
 | 
					 | 
				
			||||||
          if (urlWidget.value) {
 | 
					 | 
				
			||||||
            showImage(urlWidget.value);
 | 
					 | 
				
			||||||
          }
 | 
					 | 
				
			||||||
        });
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        return { widget: urlWidget };
 | 
					 | 
				
			||||||
      },
 | 
					 | 
				
			||||||
    };
 | 
					 | 
				
			||||||
  },
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  async setup() {
 | 
					  async setup() {
 | 
				
			||||||
    // const graphCanvas = document.getElementById("graph-canvas");
 | 
					    // const graphCanvas = document.getElementById("graph-canvas");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -726,7 +639,6 @@ const ext = {
 | 
				
			|||||||
            }
 | 
					            }
 | 
				
			||||||
            console.log("loadGraphData");
 | 
					            console.log("loadGraphData");
 | 
				
			||||||
            app.loadGraphData(comfyUIWorkflow);
 | 
					            app.loadGraphData(comfyUIWorkflow);
 | 
				
			||||||
            sendEventToCD("graph_loaded");
 | 
					 | 
				
			||||||
          }
 | 
					          }
 | 
				
			||||||
        } else if (message.type === "deploy") {
 | 
					        } else if (message.type === "deploy") {
 | 
				
			||||||
          // deployWorkflow();
 | 
					          // deployWorkflow();
 | 
				
			||||||
@ -741,35 +653,11 @@ const ext = {
 | 
				
			|||||||
            console.warn("api.handlePromptGenerated is not a function");
 | 
					            console.warn("api.handlePromptGenerated is not a function");
 | 
				
			||||||
          }
 | 
					          }
 | 
				
			||||||
          sendEventToCD("cd_plugin_onQueuePrompt", prompt);
 | 
					          sendEventToCD("cd_plugin_onQueuePrompt", prompt);
 | 
				
			||||||
        } else if (message.type === "configure_queue_buttons") {
 | 
					 | 
				
			||||||
          addQueueButtons(message.data);
 | 
					 | 
				
			||||||
        } else if (message.type === "configure_menu_right_buttons") {
 | 
					 | 
				
			||||||
          addMenuRightButtons(message.data);
 | 
					 | 
				
			||||||
        } else if (message.type === "configure_menu_buttons") {
 | 
					 | 
				
			||||||
          addMenuButtons(message.data);
 | 
					 | 
				
			||||||
        } else if (message.type === "get_prompt") {
 | 
					        } else if (message.type === "get_prompt") {
 | 
				
			||||||
          const prompt = await app.graphToPrompt();
 | 
					          const prompt = await app.graphToPrompt();
 | 
				
			||||||
          sendEventToCD("cd_plugin_onGetPrompt", prompt);
 | 
					          sendEventToCD("cd_plugin_onGetPrompt", prompt);
 | 
				
			||||||
        } else if (message.type === "event") {
 | 
					        } else if (message.type === "event") {
 | 
				
			||||||
          dispatchAPIEventData(message.data);
 | 
					          dispatchAPIEventData(message.data);
 | 
				
			||||||
        } else if (message.type === "update_widget") {
 | 
					 | 
				
			||||||
          // New handler for updating widget values
 | 
					 | 
				
			||||||
          const { nodeId, widgetName, value } = message.data;
 | 
					 | 
				
			||||||
          const node = app.graph.getNodeById(nodeId);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
          if (!node) {
 | 
					 | 
				
			||||||
            console.warn(`Node with ID ${nodeId} not found`);
 | 
					 | 
				
			||||||
            return;
 | 
					 | 
				
			||||||
          }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
          const widget = node.widgets?.find((w) => w.name === widgetName);
 | 
					 | 
				
			||||||
          if (!widget) {
 | 
					 | 
				
			||||||
            console.warn(`Widget ${widgetName} not found in node ${nodeId}`);
 | 
					 | 
				
			||||||
            return;
 | 
					 | 
				
			||||||
          }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
          widget.value = value;
 | 
					 | 
				
			||||||
          app.graph.setDirtyCanvas(true);
 | 
					 | 
				
			||||||
        } else if (message.type === "add_node") {
 | 
					        } else if (message.type === "add_node") {
 | 
				
			||||||
          console.log("add node", message.data);
 | 
					          console.log("add node", message.data);
 | 
				
			||||||
          app.graph.beforeChange();
 | 
					          app.graph.beforeChange();
 | 
				
			||||||
@ -867,9 +755,9 @@ const ext = {
 | 
				
			|||||||
        );
 | 
					        );
 | 
				
			||||||
        await app.ui.settings.setSettingValueAsync(
 | 
					        await app.ui.settings.setSettingValueAsync(
 | 
				
			||||||
          "Comfy.Sidebar.Location",
 | 
					          "Comfy.Sidebar.Location",
 | 
				
			||||||
          "left",
 | 
					          "right",
 | 
				
			||||||
        );
 | 
					        );
 | 
				
			||||||
        // localStorage.setItem("Comfy.MenuPosition.Docked", "true");
 | 
					        localStorage.setItem("Comfy.MenuPosition.Docked", "true");
 | 
				
			||||||
        console.log("native mode manmanman");
 | 
					        console.log("native mode manmanman");
 | 
				
			||||||
      } catch (error) {
 | 
					      } catch (error) {
 | 
				
			||||||
        console.error("Error setting validation to false", error);
 | 
					        console.error("Error setting validation to false", error);
 | 
				
			||||||
@ -1872,7 +1760,7 @@ app.extensionManager.registerSidebarTab({
 | 
				
			|||||||
      <div style="padding: 20px;">
 | 
					      <div style="padding: 20px;">
 | 
				
			||||||
        <h3>Comfy Deploy</h3>
 | 
					        <h3>Comfy Deploy</h3>
 | 
				
			||||||
        <div id="deploy-container" style="margin-bottom: 20px;"></div>
 | 
					        <div id="deploy-container" style="margin-bottom: 20px;"></div>
 | 
				
			||||||
        <div id="workflows-container" style="display: none;">
 | 
					        <div id="workflows-container">
 | 
				
			||||||
          <h4>Your Workflows</h4>
 | 
					          <h4>Your Workflows</h4>
 | 
				
			||||||
          <div id="workflows-loading" style="display: flex; justify-content: center; align-items: center; height: 100px;">
 | 
					          <div id="workflows-loading" style="display: flex; justify-content: center; align-items: center; height: 100px;">
 | 
				
			||||||
            ${loadingIcon}
 | 
					            ${loadingIcon}
 | 
				
			||||||
@ -1972,16 +1860,10 @@ async function loadWorkflowApi(versionId) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
const orginal_fetch_api = api.fetchApi;
 | 
					const orginal_fetch_api = api.fetchApi;
 | 
				
			||||||
api.fetchApi = async (route, options) => {
 | 
					api.fetchApi = async (route, options) => {
 | 
				
			||||||
  // console.log("Fetch API called with args:", route, options, ext.native_mode);
 | 
					  console.log("Fetch API called with args:", route, options, ext.native_mode);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  if (route.startsWith("/prompt") && ext.native_mode) {
 | 
					  if (route.startsWith("/prompt") && ext.native_mode) {
 | 
				
			||||||
    const info = await getSelectedWorkflowInfo();
 | 
					    const info = await getSelectedWorkflowInfo();
 | 
				
			||||||
 | 
					 | 
				
			||||||
    if (!info.workflow_id) {
 | 
					 | 
				
			||||||
      console.log("No workflow id found, fallback to original fetch");
 | 
					 | 
				
			||||||
      return await orginal_fetch_api.call(api, route, options);
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    console.log("info", info);
 | 
					    console.log("info", info);
 | 
				
			||||||
    if (info) {
 | 
					    if (info) {
 | 
				
			||||||
      const body = JSON.parse(options.body);
 | 
					      const body = JSON.parse(options.body);
 | 
				
			||||||
@ -1995,7 +1877,6 @@ api.fetchApi = async (route, options) => {
 | 
				
			|||||||
        workflow_id: info.workflow_id,
 | 
					        workflow_id: info.workflow_id,
 | 
				
			||||||
        native_run_api_endpoint: info.native_run_api_endpoint,
 | 
					        native_run_api_endpoint: info.native_run_api_endpoint,
 | 
				
			||||||
        gpu_event_id: info.gpu_event_id,
 | 
					        gpu_event_id: info.gpu_event_id,
 | 
				
			||||||
        gpu: info.gpu,
 | 
					 | 
				
			||||||
      };
 | 
					      };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      return await fetch("/comfyui-deploy/run", {
 | 
					      return await fetch("/comfyui-deploy/run", {
 | 
				
			||||||
@ -2011,306 +1892,3 @@ api.fetchApi = async (route, options) => {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
  return await orginal_fetch_api.call(api, route, options);
 | 
					  return await orginal_fetch_api.call(api, route, options);
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
 | 
					 | 
				
			||||||
// Intercept window drag and drop events
 | 
					 | 
				
			||||||
const originalDropHandler = document.ondrop;
 | 
					 | 
				
			||||||
document.ondrop = async (e) => {
 | 
					 | 
				
			||||||
  console.log("Drop event intercepted:", e);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  // Prevent default browser behavior
 | 
					 | 
				
			||||||
  e.preventDefault();
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  // Handle files if present
 | 
					 | 
				
			||||||
  if (e.dataTransfer?.files?.length > 0) {
 | 
					 | 
				
			||||||
    const files = Array.from(e.dataTransfer.files);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    // Send file data to parent directly as JSON
 | 
					 | 
				
			||||||
    sendDirectEventToCD("file_drop", {
 | 
					 | 
				
			||||||
      files: files,
 | 
					 | 
				
			||||||
      x: e.clientX,
 | 
					 | 
				
			||||||
      y: e.clientY,
 | 
					 | 
				
			||||||
      timestamp: Date.now(),
 | 
					 | 
				
			||||||
    });
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  // Call original handler if exists
 | 
					 | 
				
			||||||
  if (originalDropHandler) {
 | 
					 | 
				
			||||||
    originalDropHandler(e);
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
};
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
const originalDragEnterHandler = document.ondragenter;
 | 
					 | 
				
			||||||
document.ondragenter = (e) => {
 | 
					 | 
				
			||||||
  // Prevent default to allow drop
 | 
					 | 
				
			||||||
  e.preventDefault();
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  // Send dragenter event to parent directly as JSON
 | 
					 | 
				
			||||||
  sendDirectEventToCD("file_dragenter", {
 | 
					 | 
				
			||||||
    x: e.clientX,
 | 
					 | 
				
			||||||
    y: e.clientY,
 | 
					 | 
				
			||||||
    timestamp: Date.now(),
 | 
					 | 
				
			||||||
  });
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  if (originalDragEnterHandler) {
 | 
					 | 
				
			||||||
    originalDragEnterHandler(e);
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
};
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
const originalDragLeaveHandler = document.ondragleave;
 | 
					 | 
				
			||||||
document.ondragleave = (e) => {
 | 
					 | 
				
			||||||
  // Prevent default to allow drop
 | 
					 | 
				
			||||||
  e.preventDefault();
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  // Send dragleave event to parent directly as JSON
 | 
					 | 
				
			||||||
  sendDirectEventToCD("file_dragleave", {
 | 
					 | 
				
			||||||
    x: e.clientX,
 | 
					 | 
				
			||||||
    y: e.clientY,
 | 
					 | 
				
			||||||
    timestamp: Date.now(),
 | 
					 | 
				
			||||||
  });
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  if (originalDragLeaveHandler) {
 | 
					 | 
				
			||||||
    originalDragLeaveHandler(e);
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
};
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
const originalDragOverHandler = document.ondragover;
 | 
					 | 
				
			||||||
document.ondragover = (e) => {
 | 
					 | 
				
			||||||
  // Prevent default to allow drop
 | 
					 | 
				
			||||||
  e.preventDefault();
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  // Send dragover event to parent directly as JSON
 | 
					 | 
				
			||||||
  sendDirectEventToCD("file_dragover", {
 | 
					 | 
				
			||||||
    x: e.clientX,
 | 
					 | 
				
			||||||
    y: e.clientY,
 | 
					 | 
				
			||||||
    timestamp: Date.now(),
 | 
					 | 
				
			||||||
  });
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  if (originalDragOverHandler) {
 | 
					 | 
				
			||||||
    originalDragOverHandler(e);
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
};
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// Function to create a single button
 | 
					 | 
				
			||||||
function createQueueButton(config) {
 | 
					 | 
				
			||||||
  const button = document.createElement("button");
 | 
					 | 
				
			||||||
  button.id = `cd-button-${config.id}`;
 | 
					 | 
				
			||||||
  button.className =
 | 
					 | 
				
			||||||
    "p-button p-component p-button-icon-only p-button-secondary p-button-text";
 | 
					 | 
				
			||||||
  button.innerHTML = `
 | 
					 | 
				
			||||||
    <span class="p-button-icon pi ${config.icon}"></span>
 | 
					 | 
				
			||||||
    <span class="p-button-label"> </span>
 | 
					 | 
				
			||||||
  `;
 | 
					 | 
				
			||||||
  button.onclick = () => {
 | 
					 | 
				
			||||||
    const eventData =
 | 
					 | 
				
			||||||
      typeof config.eventData === "function"
 | 
					 | 
				
			||||||
        ? config.eventData()
 | 
					 | 
				
			||||||
        : config.eventData || {};
 | 
					 | 
				
			||||||
    sendEventToCD(config.event, eventData);
 | 
					 | 
				
			||||||
  };
 | 
					 | 
				
			||||||
  button.setAttribute("data-pd-tooltip", config.tooltip);
 | 
					 | 
				
			||||||
  return button;
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// Function to add buttons to queue group
 | 
					 | 
				
			||||||
function addQueueButtons(buttonConfigs = DEFAULT_BUTTONS) {
 | 
					 | 
				
			||||||
  const queueButtonGroup = document.querySelector(".queue-button-group.flex");
 | 
					 | 
				
			||||||
  if (!queueButtonGroup) return;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  // Remove any existing CD buttons
 | 
					 | 
				
			||||||
  const existingButtons =
 | 
					 | 
				
			||||||
    queueButtonGroup.querySelectorAll('[id^="cd-button-"]');
 | 
					 | 
				
			||||||
  existingButtons.forEach((button) => button.remove());
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  // Add new buttons
 | 
					 | 
				
			||||||
  buttonConfigs.forEach((config) => {
 | 
					 | 
				
			||||||
    const button = createQueueButton(config);
 | 
					 | 
				
			||||||
    queueButtonGroup.appendChild(button);
 | 
					 | 
				
			||||||
  });
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// addMenuRightButtons([
 | 
					 | 
				
			||||||
//   {
 | 
					 | 
				
			||||||
//     id: "cd-button-save-image",
 | 
					 | 
				
			||||||
//     icon: "pi-save",
 | 
					 | 
				
			||||||
//     label: "Snapshot",
 | 
					 | 
				
			||||||
//     tooltip: "Save the current image to your output directory.",
 | 
					 | 
				
			||||||
//     event: "save_image",
 | 
					 | 
				
			||||||
//     eventData: () => ({}),
 | 
					 | 
				
			||||||
//   },
 | 
					 | 
				
			||||||
// ]);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// addMenuLeftButtons([
 | 
					 | 
				
			||||||
//   {
 | 
					 | 
				
			||||||
//     id: "cd-button-back",
 | 
					 | 
				
			||||||
//     icon: `<svg width="16" height="16" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
 | 
					 | 
				
			||||||
//       <path d="M15 18L9 12L15 6" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
 | 
					 | 
				
			||||||
//     </svg>`,
 | 
					 | 
				
			||||||
//     tooltip: "Go back to the previous page.",
 | 
					 | 
				
			||||||
//     event: "back",
 | 
					 | 
				
			||||||
//     eventData: () => ({}),
 | 
					 | 
				
			||||||
//   },
 | 
					 | 
				
			||||||
// ]);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// addMenuButtons({
 | 
					 | 
				
			||||||
//   containerSelector: "body > div.comfyui-body-top > div",
 | 
					 | 
				
			||||||
//   buttonConfigs: [
 | 
					 | 
				
			||||||
//     {
 | 
					 | 
				
			||||||
//       id: "cd-button-workflow-1",
 | 
					 | 
				
			||||||
//       icon: `<svg xmlns="http://www.w3.org/2000/svg" width="20" height="20" viewBox="0 0 24 24"><path fill="none" stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="m16 3l4 4l-4 4m-6-4h10M8 13l-4 4l4 4m-4-4h9"/></svg>`,
 | 
					 | 
				
			||||||
//       label: "Workflow",
 | 
					 | 
				
			||||||
//       tooltip: "Go to Workflow 1",
 | 
					 | 
				
			||||||
//       event: "workflow_1",
 | 
					 | 
				
			||||||
//       // btnClasses: "",
 | 
					 | 
				
			||||||
//       eventData: () => ({}),
 | 
					 | 
				
			||||||
//     },
 | 
					 | 
				
			||||||
//     {
 | 
					 | 
				
			||||||
//       id: "cd-button-workflow-3",
 | 
					 | 
				
			||||||
//       // icon: `<svg xmlns="http://www.w3.org/2000/svg" width="20" height="20" viewBox="0 0 24 24"><path fill="none" stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="m16 3l4 4l-4 4m-6-4h10M8 13l-4 4l4 4m-4-4h9"/></svg>`,
 | 
					 | 
				
			||||||
//       label: "v1",
 | 
					 | 
				
			||||||
//       tooltip: "Go to Workflow 1",
 | 
					 | 
				
			||||||
//       event: "workflow_1",
 | 
					 | 
				
			||||||
//       // btnClasses: "",
 | 
					 | 
				
			||||||
//       eventData: () => ({}),
 | 
					 | 
				
			||||||
//     },
 | 
					 | 
				
			||||||
//     {
 | 
					 | 
				
			||||||
//       id: "cd-button-workflow-2",
 | 
					 | 
				
			||||||
//       icon: `<svg xmlns="http://www.w3.org/2000/svg" width="20" height="20" viewBox="0 0 24 24"><g fill="none" stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="2"><path d="M12 3v6"/><circle cx="12" cy="12" r="3"/><path d="M12 15v6"/></g></svg>`,
 | 
					 | 
				
			||||||
//       label: "Commit",
 | 
					 | 
				
			||||||
//       tooltip: "Commit the current workflow",
 | 
					 | 
				
			||||||
//       event: "commit",
 | 
					 | 
				
			||||||
//       style: {
 | 
					 | 
				
			||||||
//         backgroundColor: "oklch(.476 .114 61.907)",
 | 
					 | 
				
			||||||
//       },
 | 
					 | 
				
			||||||
//       eventData: () => ({}),
 | 
					 | 
				
			||||||
//     },
 | 
					 | 
				
			||||||
//   ],
 | 
					 | 
				
			||||||
//   buttonIdPrefix: "cd-button-workflow-",
 | 
					 | 
				
			||||||
//   insertBefore:
 | 
					 | 
				
			||||||
//     "body > div.comfyui-body-top > div > div.flex-grow.min-w-0.app-drag.h-full",
 | 
					 | 
				
			||||||
//   // containerStyle: { order: "3" }
 | 
					 | 
				
			||||||
// });
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// addMenuButtons({
 | 
					 | 
				
			||||||
//   containerSelector:
 | 
					 | 
				
			||||||
//     "body > div.comfyui-body-top > div > div.flex-grow.min-w-0.app-drag.h-full",
 | 
					 | 
				
			||||||
//   clearContainer: true,
 | 
					 | 
				
			||||||
//   buttonConfigs: [],
 | 
					 | 
				
			||||||
//   buttonIdPrefix: "cd-button-p-",
 | 
					 | 
				
			||||||
//   containerStyle: { order: "-1" },
 | 
					 | 
				
			||||||
// });
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// Function to add buttons to a menu container
 | 
					 | 
				
			||||||
function addMenuButtons(options) {
 | 
					 | 
				
			||||||
  const {
 | 
					 | 
				
			||||||
    containerSelector,
 | 
					 | 
				
			||||||
    buttonConfigs,
 | 
					 | 
				
			||||||
    buttonIdPrefix = "cd-button-",
 | 
					 | 
				
			||||||
    containerClass = "comfyui-button-group",
 | 
					 | 
				
			||||||
    containerStyle = {},
 | 
					 | 
				
			||||||
    clearContainer = false,
 | 
					 | 
				
			||||||
    insertBefore = null, // New option to specify selector for insertion point
 | 
					 | 
				
			||||||
  } = options;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  const menuContainer = document.querySelector(containerSelector);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  if (!menuContainer) return;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  // Remove any existing CD buttons
 | 
					 | 
				
			||||||
  const existingButtons = document.querySelectorAll(
 | 
					 | 
				
			||||||
    `[id^="${buttonIdPrefix}"]`,
 | 
					 | 
				
			||||||
  );
 | 
					 | 
				
			||||||
  existingButtons.forEach((button) => button.remove());
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  const container = document.createElement("div");
 | 
					 | 
				
			||||||
  container.className = containerClass;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  // Apply container styles
 | 
					 | 
				
			||||||
  Object.assign(container.style, containerStyle);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  // Clear existing content if specified
 | 
					 | 
				
			||||||
  if (clearContainer) {
 | 
					 | 
				
			||||||
    menuContainer.innerHTML = "";
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  // Create and add buttons
 | 
					 | 
				
			||||||
  buttonConfigs.forEach((config) => {
 | 
					 | 
				
			||||||
    const button = createMenuButton({
 | 
					 | 
				
			||||||
      ...config,
 | 
					 | 
				
			||||||
      idPrefix: buttonIdPrefix,
 | 
					 | 
				
			||||||
    });
 | 
					 | 
				
			||||||
    container.appendChild(button);
 | 
					 | 
				
			||||||
  });
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  // Insert before specified element if provided, otherwise append
 | 
					 | 
				
			||||||
  if (insertBefore) {
 | 
					 | 
				
			||||||
    const targetElement = menuContainer.querySelector(insertBefore);
 | 
					 | 
				
			||||||
    if (targetElement) {
 | 
					 | 
				
			||||||
      menuContainer.insertBefore(container, targetElement);
 | 
					 | 
				
			||||||
    } else {
 | 
					 | 
				
			||||||
      menuContainer.appendChild(container);
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
  } else {
 | 
					 | 
				
			||||||
    menuContainer.appendChild(container);
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
function createMenuButton(config) {
 | 
					 | 
				
			||||||
  const {
 | 
					 | 
				
			||||||
    id,
 | 
					 | 
				
			||||||
    icon,
 | 
					 | 
				
			||||||
    label,
 | 
					 | 
				
			||||||
    btnClasses = "",
 | 
					 | 
				
			||||||
    tooltip,
 | 
					 | 
				
			||||||
    event,
 | 
					 | 
				
			||||||
    eventData,
 | 
					 | 
				
			||||||
    idPrefix,
 | 
					 | 
				
			||||||
    style = {},
 | 
					 | 
				
			||||||
  } = config;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  const button = document.createElement("button");
 | 
					 | 
				
			||||||
  button.id = `${idPrefix}${id}`;
 | 
					 | 
				
			||||||
  button.className = `comfyui-button ${btnClasses}`;
 | 
					 | 
				
			||||||
  Object.assign(button.style, style);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  // Only add icon if provided
 | 
					 | 
				
			||||||
  const iconHtml = icon
 | 
					 | 
				
			||||||
    ? icon.startsWith("<svg")
 | 
					 | 
				
			||||||
      ? icon
 | 
					 | 
				
			||||||
      : `<span class="p-button-icon pi ${icon}"></span>`
 | 
					 | 
				
			||||||
    : "";
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  button.innerHTML = `
 | 
					 | 
				
			||||||
    ${iconHtml}
 | 
					 | 
				
			||||||
    ${label ? `<span class="p-button-label text-sm">${label}</span>` : ""}
 | 
					 | 
				
			||||||
  `;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  button.onclick = () => {
 | 
					 | 
				
			||||||
    const data =
 | 
					 | 
				
			||||||
      typeof eventData === "function" ? eventData() : eventData || {};
 | 
					 | 
				
			||||||
    sendEventToCD(event, data);
 | 
					 | 
				
			||||||
  };
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  if (tooltip) {
 | 
					 | 
				
			||||||
    button.setAttribute("data-pd-tooltip", tooltip);
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
  return button;
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// Refactored menu button functions
 | 
					 | 
				
			||||||
function addMenuLeftButtons(buttonConfigs) {
 | 
					 | 
				
			||||||
  addMenuButtons({
 | 
					 | 
				
			||||||
    containerSelector: "body > div.comfyui-body-top > div",
 | 
					 | 
				
			||||||
    buttonConfigs,
 | 
					 | 
				
			||||||
    buttonIdPrefix: "cd-button-left-",
 | 
					 | 
				
			||||||
    containerStyle: { order: "-1" },
 | 
					 | 
				
			||||||
  });
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
function addMenuRightButtons(buttonConfigs) {
 | 
					 | 
				
			||||||
  addMenuButtons({
 | 
					 | 
				
			||||||
    containerSelector: ".comfyui-menu-right .flex",
 | 
					 | 
				
			||||||
    buttonConfigs,
 | 
					 | 
				
			||||||
    buttonIdPrefix: "cd-button-",
 | 
					 | 
				
			||||||
    containerStyle: {},
 | 
					 | 
				
			||||||
  });
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user