diff --git a/custom_routes.py b/custom_routes.py index 30586e6..72c5382 100644 --- a/custom_routes.py +++ b/custom_routes.py @@ -18,13 +18,47 @@ import threading import hashlib import aiohttp import aiofiles -from typing import List, Union, Any, Optional +from typing import Dict, List, Union, Any, Optional from PIL import Image import copy import struct +from logging import basicConfig, getLogger +import logfire +# if os.environ.get('LOGFIRE_TOKEN', None) is not None: +logfire.configure( + send_to_logfire="if-token-present" +) +# basicConfig(handlers=[logfire.LogfireLoggingHandler()]) +logfire_handler = logfire.LogfireLoggingHandler() +logger = getLogger("comfy-deploy") +logger.addHandler(logfire_handler) + from globals import StreamingPrompt, Status, sockets, SimplePrompt, streaming_prompt_metadata, prompt_metadata +class EventEmitter: + def __init__(self): + self.listeners = {} + + def on(self, event, listener): + if event not in self.listeners: + self.listeners[event] = [] + self.listeners[event].append(listener) + + def off(self, event, listener): + if event in self.listeners: + self.listeners[event].remove(listener) + if not self.listeners[event]: + del self.listeners[event] + + def emit(self, event, *args, **kwargs): + if event in self.listeners: + for listener in self.listeners[event]: + listener(*args, **kwargs) + +# Create a global event emitter instance +event_emitter = EventEmitter() + api = None api_task = None @@ -32,19 +66,19 @@ cd_enable_log = os.environ.get('CD_ENABLE_LOG', 'false').lower() == 'true' cd_enable_run_log = os.environ.get('CD_ENABLE_RUN_LOG', 'false').lower() == 'true' bypass_upload = os.environ.get('CD_BYPASS_UPLOAD', 'false').lower() == 'true' -print("CD_BYPASS_UPLOAD", bypass_upload) +logger.info(f"CD_BYPASS_UPLOAD {bypass_upload}") def clear_current_prompt(sid): prompt_server = server.PromptServer.instance to_delete = list(streaming_prompt_metadata[sid].running_prompt_ids) # Convert set to list - - print("clearning out prompt: ", to_delete) + + logger.info("clearning out prompt: ", to_delete) for id_to_delete in to_delete: delete_func = lambda a: a[1] == id_to_delete prompt_server.prompt_queue.delete_queue_item(delete_func) - print("deleted prompt: ", id_to_delete, prompt_server.prompt_queue.get_tasks_remaining()) - + logger.info("deleted prompt: ", id_to_delete, prompt_server.prompt_queue.get_tasks_remaining()) + streaming_prompt_metadata[sid].running_prompt_ids.clear() def post_prompt(json_data): @@ -84,7 +118,7 @@ def post_prompt(json_data): } return response else: - print("invalid prompt:", valid[1]) + logger.info("invalid prompt:", valid[1]) return {"error": valid[1], "node_errors": valid[3]} else: return {"error": "no prompt", "node_errors": []} @@ -109,69 +143,69 @@ def apply_random_seed_to_workflow(workflow_api): workflow_api[key]['inputs']['seed'] = randomSeed(8); continue workflow_api[key]['inputs']['seed'] = randomSeed(); - + def apply_inputs_to_workflow(workflow_api: Any, inputs: Any, sid: str = None): # Loop through each of the inputs and replace them for key, value in workflow_api.items(): if 'inputs' in value: - + # Support websocket if sid is not None: if (value["class_type"] == "ComfyDeployWebscoketImageOutput"): value['inputs']["client_id"] = sid if (value["class_type"] == "ComfyDeployWebscoketImageInput"): value['inputs']["client_id"] = sid - + if "input_id" in value['inputs'] and inputs is not None and value['inputs']['input_id'] in inputs: new_value = inputs[value['inputs']['input_id']] - + # Lets skip it if its an image if 
isinstance(new_value, Image.Image): continue - + # Backward compactibility value['inputs']["input_id"] = new_value - + # Fix for external text default value if (value["class_type"] == "ComfyUIDeployExternalText"): value['inputs']["default_value"] = new_value - + if (value["class_type"] == "ComfyUIDeployExternalCheckpoint"): value['inputs']["default_value"] = new_value - + if (value["class_type"] == "ComfyUIDeployExternalImageBatch"): value['inputs']["images"] = new_value - + if value["class_type"] == "ComfyUIDeployExternalLora": value["inputs"]["default_lora_name"] = new_value - + if value["class_type"] == "ComfyUIDeployExternalSlider": value["inputs"]["default_value"] = new_value - + if value["class_type"] == "ComfyUIDeployExternalBoolean": value["inputs"]["default_value"] = new_value def send_prompt(sid: str, inputs: StreamingPrompt): # workflow_api = inputs.workflow_api workflow_api = copy.deepcopy(inputs.workflow_api) - + # Random seed apply_random_seed_to_workflow(workflow_api) - - print("getting inputs" , inputs.inputs) - + + logger.info("getting inputs" , inputs.inputs) + apply_inputs_to_workflow(workflow_api, inputs.inputs, sid=sid) - - print(workflow_api) - + + logger.info(workflow_api) + prompt_id = str(uuid.uuid4()) - + prompt = { "prompt": workflow_api, "client_id": sid, #"comfy_deploy_instance", #api.client_id "prompt_id": prompt_id } - + try: res = post_prompt(prompt) inputs.running_prompt_ids.add(prompt_id) @@ -185,12 +219,11 @@ def send_prompt(sid: str, inputs: StreamingPrompt): error_type = type(e).__name__ stack_trace_short = traceback.format_exc().strip().split('\n')[-2] stack_trace = traceback.format_exc().strip() - print(f"error: {error_type}, {e}") - print(f"stack trace: {stack_trace_short}") - + logger.info(f"error: {error_type}, {e}") + logger.info(f"stack trace: {stack_trace_short}") + @server.PromptServer.instance.routes.post("/comfyui-deploy/run") async def comfy_deploy_run(request): - prompt_server = server.PromptServer.instance data = await request.json() # In older version, we use workflow_api, but this has inputs already swapped in nextjs frontend, which is tricky @@ -221,8 +254,8 @@ async def comfy_deploy_run(request): error_type = type(e).__name__ stack_trace_short = traceback.format_exc().strip().split('\n')[-2] stack_trace = traceback.format_exc().strip() - print(f"error: {error_type}, {e}") - print(f"stack trace: {stack_trace_short}") + logger.info(f"error: {error_type}, {e}") + logger.info(f"stack trace: {stack_trace_short}") await update_run_with_output(prompt_id, { "error": { "error_type": error_type, @@ -234,14 +267,7 @@ async def comfy_deploy_run(request): return web.Response(status=500, reason=f"{error_type}: {e}, {stack_trace_short}") status = 200 - # if "error" in res: - # status = 400 - # await update_run_with_output(prompt_id, { - # "error": { - # **res - # } - # }) - + if "node_errors" in res and res["node_errors"]: # Even tho there are node_errors it can still be run status = 400 @@ -257,24 +283,134 @@ async def comfy_deploy_run(request): return web.json_response(res, status=status) +async def stream_prompt(data): + # In older version, we use workflow_api, but this has inputs already swapped in nextjs frontend, which is tricky + workflow_api = data.get("workflow_api_raw") + # The prompt id generated from comfy deploy, can be None + prompt_id = data.get("prompt_id") + inputs = data.get("inputs") + + # Now it handles directly in here + apply_random_seed_to_workflow(workflow_api) + apply_inputs_to_workflow(workflow_api, inputs) + + prompt = { + 
"prompt": workflow_api, + "client_id": "comfy_deploy_instance", #api.client_id + "prompt_id": prompt_id + } + + prompt_metadata[prompt_id] = SimplePrompt( + status_endpoint=data.get('status_endpoint'), + file_upload_endpoint=data.get('file_upload_endpoint'), + workflow_api=workflow_api + ) + + logfire.info("Begin prompt", prompt=prompt) + + try: + res = post_prompt(prompt) + except Exception as e: + error_type = type(e).__name__ + stack_trace_short = traceback.format_exc().strip().split('\n')[-2] + stack_trace = traceback.format_exc().strip() + logger.info(f"error: {error_type}, {e}") + logger.info(f"stack trace: {stack_trace_short}") + await update_run_with_output(prompt_id, { + "error": { + "error_type": error_type, + "stack_trace": stack_trace + } + }) + # When there are critical errors, the prompt is actually not run + await update_run(prompt_id, Status.FAILED) + # return web.Response(status=500, reason=f"{error_type}: {e}, {stack_trace_short}") + # raise Exception("Prompt failed") + + status = 200 + + if "node_errors" in res and res["node_errors"]: + # Even tho there are node_errors it can still be run + status = 400 + await update_run_with_output(prompt_id, { + "error": { + **res + } + }) + + # When there are critical errors, the prompt is actually not run + if "error" in res: + await update_run(prompt_id, Status.FAILED) + # raise Exception("Prompt failed") + + return res + # return web.json_response(res, status=status) + +comfy_message_queues: Dict[str, asyncio.Queue] = {} + +@server.PromptServer.instance.routes.post('/comfyui-deploy/run/streaming') +async def stream_response(request): + response = web.StreamResponse(status=200, reason='OK', headers={'Content-Type': 'text/event-stream'}) + await response.prepare(request) + + pending = True + data = await request.json() + + prompt_id = data.get("prompt_id") + comfy_message_queues[prompt_id] = asyncio.Queue() + + with logfire.span('Streaming Run'): + logfire.info('Streaming prompt') + + try: + result = await stream_prompt(data=data) + await response.write(f"event: event_update\ndata: {json.dumps(result)}\n\n".encode('utf-8')) + # await response.write(.encode('utf-8')) + await response.drain() # Ensure the buffer is flushed + + while pending: + if prompt_id in comfy_message_queues: + if not comfy_message_queues[prompt_id].empty(): + data = await comfy_message_queues[prompt_id].get() + + logfire.info(data["event"], data=json.dumps(data)) + # logger.info("listener", data) + await response.write(f"event: event_update\ndata: {json.dumps(data)}\n\n".encode('utf-8')) + await response.drain() # Ensure the buffer is flushed + + if data["event"] == "status": + if data["data"]["status"] in (Status.FAILED.value, Status.SUCCESS.value): + pending = False + + await asyncio.sleep(0.1) # Adjust the sleep duration as needed + except asyncio.CancelledError: + logfire.info("Streaming was cancelled") + raise + except Exception as e: + logfire.error("Streaming error", error=e) + finally: + # event_emitter.off("send_json", task) + await response.write_eof() + comfy_message_queues.pop(prompt_id, None) + return response def get_comfyui_path_from_file_path(file_path): file_path_parts = file_path.split("\\") if file_path_parts[0] == "input": - print("matching input") + logger.info("matching input") file_path = os.path.join(folder_paths.get_directory_by_type("input"), *file_path_parts[1:]) elif file_path_parts[0] == "models": - print("matching models") + logger.info("matching models") file_path = folder_paths.get_full_path(file_path_parts[1], 
os.path.join(*file_path_parts[2:])) - print(file_path) + logger.info(file_path) return file_path # Form ComfyUI Manager async def compute_sha256_checksum(filepath): - print("computing sha256 checksum") + logger.info("computing sha256 checksum") chunk_size = 1024 * 256 # Example: 256KB filepath = get_comfyui_path_from_file_path(filepath) """Compute the SHA256 checksum of a file, in chunks, asynchronously""" @@ -297,7 +433,7 @@ async def get_installed_models(request): file_list = folder_paths.get_filename_list(key) value_json_compatible = (value[0], list(value[1]), file_list) new_dict[key] = value_json_compatible - # print(new_dict) + # logger.info(new_dict) return web.json_response(new_dict) # This is start uploading the files to Comfy Deploy @@ -307,7 +443,7 @@ async def upload_file_endpoint(request): file_path = data.get("file_path") - print("Original file path", file_path) + logger.info("Original file path", file_path) file_path = get_comfyui_path_from_file_path(file_path) @@ -321,7 +457,7 @@ async def upload_file_endpoint(request): try: base = folder_paths.base_path file_path = os.path.join(base, file_path) - + if os.path.exists(file_path): file_size = os.path.getsize(file_path) file_extension = os.path.splitext(file_path)[1] @@ -378,11 +514,11 @@ async def upload_file_endpoint(request): return web.json_response({ "error": f"An error occurred while fetching data from {get_url}: {str(e)}" }, status=500) - + return web.json_response({ "error": f"File not uploaded" }, status=500) - + script_dir = os.path.dirname(os.path.abspath(__file__)) # Assuming the cache file is stored in the same directory as this script @@ -416,7 +552,7 @@ async def get_file_hash(request): return web.json_response({ "error": "file_path is required" }, status=400) - + try: base = folder_paths.base_path full_file_path = os.path.join(base, file_path) @@ -429,11 +565,11 @@ async def get_file_hash(request): file_hash = await compute_sha256_checksum(full_file_path) end_time = time.time() elapsed_time = end_time - start_time - print(f"Cache miss -> Execution time: {elapsed_time} seconds") + logger.info(f"Cache miss -> Execution time: {elapsed_time} seconds") # Update the in-memory cache file_hash_cache[full_file_path] = file_hash - + save_cache() return web.json_response({ @@ -443,12 +579,14 @@ async def get_file_hash(request): return web.json_response({ "error": str(e) }, status=500) - + async def update_realtime_run_status(realtime_id: str, status_endpoint: str, status: Status): body = { "run_id": realtime_id, "status": status.value, } + if (status_endpoint is None): + return # requests.post(status_endpoint, json=body) async with aiohttp.ClientSession() as session: async with session.post(status_endpoint, json=body) as response: @@ -466,34 +604,34 @@ async def websocket_handler(request): sid = uuid.uuid4().hex sockets[sid] = ws - + auth_token = request.rel_url.query.get('token', None) get_workflow_endpoint_url = request.rel_url.query.get('workflow_endpoint', None) realtime_id = request.rel_url.query.get('realtime_id', None) status_endpoint = request.rel_url.query.get('status_endpoint', None) - + if auth_token is not None and get_workflow_endpoint_url is not None: async with aiohttp.ClientSession() as session: headers = {'Authorization': f'Bearer {auth_token}'} async with session.get(get_workflow_endpoint_url, headers=headers) as response: if response.status == 200: workflow = await response.json() - - print("Loaded workflow version ",workflow["version"]) - + + logger.info(f"Loaded workflow version 
${workflow['version']}") + streaming_prompt_metadata[sid] = StreamingPrompt( - workflow_api=workflow["workflow_api"], + workflow_api=workflow["workflow_api"], auth_token=auth_token, inputs={}, status_endpoint=status_endpoint, file_upload_endpoint=request.rel_url.query.get('file_upload_endpoint', None), ) - + await update_realtime_run_status(realtime_id, status_endpoint, Status.RUNNING) # await send("workflow_api", workflow_api, sid) else: error_message = await response.text() - print(f"Failed to fetch workflow endpoint. Status: {response.status}, Error: {error_message}") + logger.info(f"Failed to fetch workflow endpoint. Status: {response.status}, Error: {error_message}") # await send("error", {"message": error_message}, sid) try: @@ -503,15 +641,15 @@ async def websocket_handler(request): # Make sure when its connected via client, the full log is not being sent if cd_enable_log and get_workflow_endpoint_url is None: await send_first_time_log(sid) - + async for msg in ws: if msg.type == aiohttp.WSMsgType.TEXT: try: data = json.loads(msg.data) - print(data) + logger.info(data) event_type = data.get('event') if event_type == 'input': - print("Got input: ", data.get("inputs")) + logger.info(f"Got input: ${data.get('inputs')}") input = data.get('inputs') streaming_prompt_metadata[sid].inputs.update(input) elif event_type == 'queue_prompt': @@ -521,8 +659,8 @@ async def websocket_handler(request): # Handle other event types pass except json.JSONDecodeError: - print('Failed to decode JSON from message') - + logger.info('Failed to decode JSON from message') + if msg.type == aiohttp.WSMsgType.BINARY: data = msg.data event_type, = struct.unpack(" 0: - print("have pending upload ", len(prompt_metadata[prompt_id].uploading_nodes)) + logger.info(f"have pending upload {len(prompt_metadata[prompt_id].uploading_nodes)}") return True - print("no pending upload") + logger.info("no pending upload") return False def mark_prompt_done(prompt_id): @@ -867,7 +1047,7 @@ def mark_prompt_done(prompt_id): """ if prompt_id in prompt_metadata: prompt_metadata[prompt_id].done = True - print("Prompt done") + logger.info("Prompt done") def is_prompt_done(prompt_id: str): """ @@ -899,8 +1079,8 @@ async def handle_error(prompt_id, data, e: Exception): } } await update_file_status(prompt_id, data, False, have_error=True) - print(body) - print(f"Error occurred while uploading file: {e}") + logger.info(body) + logger.info(f"Error occurred while uploading file: {e}") # Mark the current prompt requires upload, and block it from being marked as success async def update_file_status(prompt_id: str, data, uploading, have_error=False, node_id=None): @@ -913,11 +1093,11 @@ async def update_file_status(prompt_id: str, data, uploading, have_error=False, else: prompt_metadata[prompt_id].uploading_nodes.discard(node_id) - print(prompt_metadata[prompt_id].uploading_nodes) + logger.info(prompt_metadata[prompt_id].uploading_nodes) # Update the remote status if have_error: - update_run(prompt_id, Status.FAILED) + await update_run(prompt_id, Status.FAILED) await send("failed", { "prompt_id": prompt_id, }) @@ -926,15 +1106,15 @@ async def update_file_status(prompt_id: str, data, uploading, have_error=False, # if there are still nodes that are uploading, then we set the status to uploading if uploading: if prompt_metadata[prompt_id].status != Status.UPLOADING: - update_run(prompt_id, Status.UPLOADING) + await update_run(prompt_id, Status.UPLOADING) await send("uploading", { "prompt_id": prompt_id, }) - + # if there are no nodes that are 
uploading, then we set the status to success elif not uploading and not have_pending_upload(prompt_id) and is_prompt_done(prompt_id=prompt_id): - update_run(prompt_id, Status.SUCCESS) - # print("Status: SUCCUSS") + await update_run(prompt_id, Status.SUCCESS) + # logger.info("Status: SUCCUSS") await send("success", { "prompt_id": prompt_id, }) @@ -945,7 +1125,7 @@ async def handle_upload(prompt_id: str, data, key: str, content_type_key: str, d # # Skipping temp files if item.get("type") == "temp": continue - + file_type = item.get(content_type_key, default_content_type) file_extension = os.path.splitext(item.get("filename"))[1] if file_extension in ['.jpg', '.jpeg']: @@ -954,7 +1134,7 @@ async def handle_upload(prompt_id: str, data, key: str, content_type_key: str, d file_type = 'image/png' elif file_extension == '.webp': file_type = 'image/webp' - + await upload_file( prompt_id, item.get("filename"), @@ -971,7 +1151,7 @@ async def upload_in_background(prompt_id: str, data, node_id=None, have_upload=T # This will also be mp4 await handle_upload(prompt_id, data, 'gifs', "format", "image/gif") await handle_upload(prompt_id, data, 'mesh', "format", "application/octet-stream") - + if have_upload: await update_file_status(prompt_id, data, False, node_id=node_id) except Exception as e: @@ -980,10 +1160,10 @@ async def upload_in_background(prompt_id: str, data, node_id=None, have_upload=T async def update_run_with_output(prompt_id, data, node_id=None): if prompt_id not in prompt_metadata: return - + if prompt_metadata[prompt_id].is_realtime is True: return - + status_endpoint = prompt_metadata[prompt_id].status_endpoint body = { @@ -997,7 +1177,7 @@ async def update_run_with_output(prompt_id, data, node_id=None): if have_upload_media: try: - print("\nhave_upload", have_upload_media, node_id) + logger.info(f"\nhave_upload {have_upload} {node_id}") if have_upload_media: await update_file_status(prompt_id, data, True, node_id=node_id) @@ -1008,7 +1188,11 @@ async def update_run_with_output(prompt_id, data, node_id=None): except Exception as e: await handle_error(prompt_id, data, e) - requests.post(status_endpoint, json=body) + # requests.post(status_endpoint, json=body) + if status_endpoint is not None: + async with aiohttp.ClientSession() as session: + async with session.post(status_endpoint, json=body) as response: + pass await send('outputs_uploaded', { "prompt_id": prompt_id diff --git a/globals.py b/globals.py index d608209..fd700a1 100644 --- a/globals.py +++ b/globals.py @@ -22,12 +22,13 @@ class StreamingPrompt(BaseModel): auth_token: str inputs: dict[str, Union[str, bytes, Image.Image]] running_prompt_ids: set[str] = set() - status_endpoint: str - file_upload_endpoint: str + status_endpoint: Optional[str] + file_upload_endpoint: Optional[str] class SimplePrompt(BaseModel): - status_endpoint: str - file_upload_endpoint: str + status_endpoint: Optional[str] + file_upload_endpoint: Optional[str] + workflow_api: dict status: Status = Status.NOT_STARTED progress: set = set() diff --git a/requirements.txt b/requirements.txt index db40ae4..cedfa7c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ aiofiles pydantic opencv-python -imageio-ffmpeg \ No newline at end of file +imageio-ffmpeg +logfire \ No newline at end of file diff --git a/web-plugin/index.js b/web-plugin/index.js index a459bb8..95f325b 100644 --- a/web-plugin/index.js +++ b/web-plugin/index.js @@ -13,6 +13,78 @@ function sendEventToCD(event, data) { window.parent.postMessage(JSON.stringify(message), "*"); } 
+function dispatchAPIEventData(data) { + const msg = JSON.parse(data); + + // Custom parse error + if (msg.error) { + let message = msg.error.message; + if (msg.error.details) + message += ": " + msg.error.details; + for (const [nodeID, nodeError] of Object.entries( + msg.node_errors, + )) { + message += "\n" + nodeError.class_type + ":"; + for (const errorReason of nodeError.errors) { + message += + "\n - " + errorReason.message + ": " + errorReason.details; + } + } + + app.ui.dialog.show(message); + if (msg.node_errors) { + app.lastNodeErrors = msg.node_errors; + app.canvas.draw(true, true); + } + } + + switch (msg.event) { + case "error": + break; + case "status": + if (msg.data.sid) { + // this.clientId = msg.data.sid; + // window.name = this.clientId; // use window name so it isnt reused when duplicating tabs + // sessionStorage.setItem("clientId", this.clientId); // store in session storage so duplicate tab can load correct workflow + } + api.dispatchEvent(new CustomEvent("status", { detail: msg.data.status })); + break; + case "progress": + api.dispatchEvent(new CustomEvent("progress", { detail: msg.data })); + break; + case "executing": + api.dispatchEvent( + new CustomEvent("executing", { detail: msg.data.node }), + ); + break; + case "executed": + api.dispatchEvent(new CustomEvent("executed", { detail: msg.data })); + break; + case "execution_start": + api.dispatchEvent( + new CustomEvent("execution_start", { detail: msg.data }), + ); + break; + case "execution_error": + api.dispatchEvent( + new CustomEvent("execution_error", { detail: msg.data }), + ); + break; + case "execution_cached": + api.dispatchEvent( + new CustomEvent("execution_cached", { detail: msg.data }), + ); + break; + default: + api.dispatchEvent(new CustomEvent(msg.type, { detail: msg.data })); + // default: + // if (this.#registered.has(msg.type)) { + // } else { + // throw new Error(`Unknown message type ${msg.type}`); + // } + } +} + /** @typedef {import('../../../web/types/comfy.js').ComfyExtension} ComfyExtension*/ /** @type {ComfyExtension} */ const ext = { @@ -33,11 +105,10 @@ const ext = { sendEventToCD("cd_plugin_onInit"); - app.queuePrompt = ((originalFunction) => - async () => { - // const prompt = await app.graphToPrompt(); - sendEventToCD("cd_plugin_onQueuePromptTrigger"); - })(app.queuePrompt); + app.queuePrompt = ((originalFunction) => async () => { + // const prompt = await app.graphToPrompt(); + sendEventToCD("cd_plugin_onQueuePromptTrigger"); + })(app.queuePrompt); // // Intercept the onkeydown event // window.addEventListener( @@ -208,7 +279,12 @@ const ext = { } else if (message.type === "queue_prompt") { const prompt = await app.graphToPrompt(); sendEventToCD("cd_plugin_onQueuePrompt", prompt); + } else if (message.type === "event") { + dispatchAPIEventData(message.data); } + // else if (message.type === "refresh") { + // sendEventToCD("cd_plugin_onRefresh"); + // } } catch (error) { // console.error("Error processing message:", error); }
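Note on the logging changes at the top of custom_routes.py: the print calls are replaced with a "comfy-deploy" stdlib logger whose records are forwarded to Logfire through LogfireLoggingHandler, and logfire.configure(send_to_logfire="if-token-present") only ships data when a LOGFIRE_TOKEN is set in the environment. A minimal sketch of the same pattern in isolation; the level and the do_work function are illustrative, not part of the diff:

from logging import getLogger

import logfire

# no-op exporter unless LOGFIRE_TOKEN is present in the environment
logfire.configure(send_to_logfire="if-token-present")

logger = getLogger("comfy-deploy")
logger.addHandler(logfire.LogfireLoggingHandler())
logger.setLevel("INFO")  # illustrative; the diff does not set an explicit level

def do_work(prompt_id: str) -> None:
    # structured span + event, mirroring stream_response() in the diff
    with logfire.span("Streaming Run"):
        logfire.info("Begin prompt", prompt=prompt_id)
        # plain stdlib logging also reaches Logfire through the handler
        logger.info("processing prompt %s", prompt_id)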
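The new EventEmitter in custom_routes.py is a small synchronous pub/sub helper with a module-level event_emitter singleton; in this diff it is only referenced by a commented-out off("send_json", ...) call. A usage sketch, assuming the EventEmitter class defined above is in scope; the event name and listener are illustrative:

def on_send_json(event_name: str, data: dict) -> None:
    # listeners run synchronously, in registration order
    print("send_json:", event_name, data)

emitter = EventEmitter()
emitter.on("send_json", on_send_json)
emitter.emit("send_json", "status", {"status": "running"})
emitter.off("send_json", on_send_json)  # also drops the now-empty "send_json" key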
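The new POST /comfyui-deploy/run/streaming route queues the prompt and then replays per-prompt events from comfy_message_queues as server-sent events ("event: event_update" followed by a JSON data line) until a status event reports success or failure. A rough client sketch, assuming a local ComfyUI listening on 127.0.0.1:8188; the payload field names come from stream_prompt() above, while the callback URLs and input values are hypothetical:

import asyncio
import json
import uuid

import aiohttp

async def run_streaming(workflow_api: dict, inputs: dict) -> None:
    payload = {
        "workflow_api_raw": workflow_api,
        "prompt_id": str(uuid.uuid4()),
        "inputs": inputs,
        "status_endpoint": "https://example.com/api/status",       # hypothetical callback
        "file_upload_endpoint": "https://example.com/api/upload",  # hypothetical callback
    }
    async with aiohttp.ClientSession() as session:
        async with session.post(
            "http://127.0.0.1:8188/comfyui-deploy/run/streaming", json=payload
        ) as resp:
            # parse the SSE stream line by line; the server closes it after
            # a terminal status event (success or failed) has been written
            async for raw_line in resp.content:
                line = raw_line.decode("utf-8").strip()
                if line.startswith("data:"):
                    print(json.loads(line[len("data:"):]))

# asyncio.run(run_streaming(my_workflow_api, {"input_text": "hello"}))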