diff --git a/custom_routes.py b/custom_routes.py index e402c91..b320b5b 100644 --- a/custom_routes.py +++ b/custom_routes.py @@ -1,34 +1,24 @@ - - from aiohttp import web import os import requests import folder_paths import json -import numpy as np import server -import re -import base64 from PIL import Image -import io import time import execution import random import traceback import uuid import asyncio -import atexit import logging -import sys -from logging.handlers import RotatingFileHandler from enum import Enum from urllib.parse import quote import threading import hashlib import aiohttp import aiofiles -import concurrent.futures -from typing import List, Union, Any +from typing import List, Union, Any, Optional from PIL import Image import copy @@ -39,17 +29,37 @@ from pydantic import BaseModel as PydanticBaseModel class BaseModel(PydanticBaseModel): class Config: arbitrary_types_allowed = True + +class Status(Enum): + NOT_STARTED = "not-started" + RUNNING = "running" + SUCCESS = "success" + FAILED = "failed" + UPLOADING = "uploading" class StreamingPrompt(BaseModel): workflow_api: Any auth_token: str inputs: dict[str, Union[str, bytes, Image.Image]] + running_prompt_ids: set[str] = set() + status_endpoint: str + file_upload_endpoint: str + +class SimplePrompt(BaseModel): + status_endpoint: str + file_upload_endpoint: str + workflow_api: dict + status: Status = Status.NOT_STARTED + progress: set = set() + last_updated_node: Optional[str] = None + uploading_nodes: set = set() + done: bool = False streaming_prompt_metadata: dict[str, StreamingPrompt] = {} api = None api_task = None -prompt_metadata = {} +prompt_metadata: dict[str, SimplePrompt] = {} cd_enable_log = os.environ.get('CD_ENABLE_LOG', 'false').lower() == 'true' cd_enable_run_log = os.environ.get('CD_ENABLE_RUN_LOG', 'false').lower() == 'true' @@ -109,7 +119,7 @@ def send_prompt(sid: str, inputs: StreamingPrompt): if 'inputs' in workflow_api[key] and 'seed' in workflow_api[key]['inputs']: 
workflow_api[key]['inputs']['seed'] = randomSeed() - print("getting inputs" ,inputs.inputs) + print("getting inputs" , inputs.inputs) # Loop through each of the inputs and replace them for key, value in workflow_api.items(): @@ -137,6 +147,12 @@ def send_prompt(sid: str, inputs: StreamingPrompt): try: res = post_prompt(prompt) + inputs.running_prompt_ids.add(prompt_id) + prompt_metadata[prompt_id] = SimplePrompt( + status_endpoint=inputs.status_endpoint, + file_upload_endpoint=inputs.file_upload_endpoint, + workflow_api=workflow_api + ) except Exception as e: error_type = type(e).__name__ stack_trace_short = traceback.format_exc().strip().split('\n')[-2] @@ -164,11 +180,11 @@ async def comfy_deploy_run(request): "prompt_id": prompt_id } - prompt_metadata[prompt_id] = { - 'status_endpoint': data.get('status_endpoint'), - 'file_upload_endpoint': data.get('file_upload_endpoint'), - 'workflow_api': workflow_api - } + prompt_metadata[prompt_id] = SimplePrompt( + status_endpoint=data.get('status_endpoint'), + file_upload_endpoint=data.get('file_upload_endpoint'), + workflow_api=workflow_api + ) try: res = post_prompt(prompt) @@ -244,7 +260,7 @@ async def compute_sha256_checksum(filepath): # This is start uploading the files to Comfy Deploy @server.PromptServer.instance.routes.post('/comfyui-deploy/upload-file') -async def upload_file(request): +async def upload_file_endpoint(request): data = await request.json() file_path = data.get("file_path") @@ -367,28 +383,31 @@ async def websocket_handler(request): sockets[sid] = ws - auth_token = request.rel_url.query.get('token', '') - get_workflow_endpoint_url = request.rel_url.query.get('workflow_endpoint', '') + auth_token = request.rel_url.query.get('token', None) + get_workflow_endpoint_url = request.rel_url.query.get('workflow_endpoint', None) - async with aiohttp.ClientSession() as session: - headers = {'Authorization': f'Bearer {auth_token}'} - async with session.get(get_workflow_endpoint_url, headers=headers) as 
response: - if response.status == 200: - workflow = await response.json() - - print("Loaded workflow version ",workflow["version"]) - - streaming_prompt_metadata[sid] = StreamingPrompt( - workflow_api=workflow["workflow_api"], - auth_token=auth_token, - inputs={} - ) - - # await send("workflow_api", workflow_api, sid) - else: - error_message = await response.text() - print(f"Failed to fetch workflow endpoint. Status: {response.status}, Error: {error_message}") - # await send("error", {"message": error_message}, sid) + if auth_token is not None and get_workflow_endpoint_url is not None: + async with aiohttp.ClientSession() as session: + headers = {'Authorization': f'Bearer {auth_token}'} + async with session.get(get_workflow_endpoint_url, headers=headers) as response: + if response.status == 200: + workflow = await response.json() + + print("Loaded workflow version ",workflow["version"]) + + streaming_prompt_metadata[sid] = StreamingPrompt( + workflow_api=workflow["workflow_api"], + auth_token=auth_token, + inputs={}, + status_endpoint=request.rel_url.query.get('status_endpoint', None), + file_upload_endpoint=request.rel_url.query.get('file_upload_endpoint', None), + ) + + # await send("workflow_api", workflow_api, sid) + else: + error_message = await response.text() + print(f"Failed to fetch workflow endpoint. 
Status: {response.status}, Error: {error_message}") + # await send("error", {"message": error_message}, sid) try: # Send initial state to the new client @@ -422,16 +441,28 @@ async def websocket_handler(request): @server.PromptServer.instance.routes.get('/comfyui-deploy/check-status') async def comfy_deploy_check_status(request): - prompt_server = server.PromptServer.instance prompt_id = request.rel_url.query.get('prompt_id', None) - if prompt_id in prompt_metadata and 'status' in prompt_metadata[prompt_id]: + if prompt_id in prompt_metadata: return web.json_response({ - "status": prompt_metadata[prompt_id]['status'].value + "status": prompt_metadata[prompt_id].status.value }) else: return web.json_response({ "message": "prompt_id not found" }) + +@server.PromptServer.instance.routes.get('/comfyui-deploy/check-ws-status') +async def comfy_deploy_check_ws_status(request): + client_id = request.rel_url.query.get('client_id', None) + if client_id in streaming_prompt_metadata: + remaining_queue = 0 # Initialize remaining queue count + for prompt_id in streaming_prompt_metadata[client_id].running_prompt_ids: + prompt_status = prompt_metadata[prompt_id].status + if prompt_status not in [Status.FAILED, Status.SUCCESS]: + remaining_queue += 1 # Increment for each prompt still running + return web.json_response({"remaining_queue": remaining_queue}) + else: + return web.json_response({"message": "client_id not found"}, status=404) async def send(event, data, sid=None): try: @@ -451,8 +482,8 @@ async def send(event, data, sid=None): logging.basicConfig(level=logging.INFO) prompt_server = server.PromptServer.instance - send_json = prompt_server.send_json + async def send_json_override(self, event, data, sid=None): # print("INTERNAL:", event, data, sid) prompt_id = data.get('prompt_id') @@ -476,28 +507,28 @@ async def send_json_override(self, event, data, sid=None): node = data.get('node') if prompt_id in prompt_metadata: - if 'progress' not in 
prompt_metadata[prompt_id]: - prompt_metadata[prompt_id]["progress"] = set() - - prompt_metadata[prompt_id]["progress"].add(node) - calculated_progress = len(prompt_metadata[prompt_id]["progress"]) / len(prompt_metadata[prompt_id]['workflow_api']) + # if 'progress' not in prompt_metadata[prompt_id]: + # prompt_metadata[prompt_id]["progress"] = set() + + prompt_metadata[prompt_id].progress.add(node) + calculated_progress = len(prompt_metadata[prompt_id].progress) / len(prompt_metadata[prompt_id].workflow_api) # print("calculated_progress", calculated_progress) - - if 'last_updated_node' in prompt_metadata[prompt_id] and prompt_metadata[prompt_id]['last_updated_node'] == node: + + if prompt_metadata[prompt_id].last_updated_node is not None and prompt_metadata[prompt_id].last_updated_node == node: return - prompt_metadata[prompt_id]['last_updated_node'] = node - class_type = prompt_metadata[prompt_id]['workflow_api'][node]['class_type'] + prompt_metadata[prompt_id].last_updated_node = node + class_type = prompt_metadata[prompt_id].workflow_api[node]['class_type'] print("updating run live status", class_type) await update_run_live_status(prompt_id, "Executing " + class_type, calculated_progress) if event == 'execution_cached' and data.get('nodes') is not None: if prompt_id in prompt_metadata: - if 'progress' not in prompt_metadata[prompt_id]: - prompt_metadata[prompt_id]["progress"] = set() - + # if 'progress' not in prompt_metadata[prompt_id]: + # prompt_metadata[prompt_id].progress = set() + if 'nodes' in data: for node in data.get('nodes', []): - prompt_metadata[prompt_id]["progress"].add(node) + prompt_metadata[prompt_id].progress.add(node) # prompt_metadata[prompt_id]["progress"].update(data.get('nodes')) if event == 'execution_error': @@ -511,13 +542,6 @@ async def send_json_override(self, event, data, sid=None): # await update_run_with_output(prompt_id, data.get('output'), node_id=data.get('node')) # update_run_with_output(prompt_id, data.get('output')) -class 
Status(Enum): - NOT_STARTED = "not-started" - RUNNING = "running" - SUCCESS = "success" - FAILED = "failed" - UPLOADING = "uploading" - # Global variable to keep track of the last read line number last_read_line_number = 0 @@ -527,7 +551,7 @@ async def update_run_live_status(prompt_id, live_status, calculated_progress: fl print("progress", calculated_progress) - status_endpoint = prompt_metadata[prompt_id]['status_endpoint'] + status_endpoint = prompt_metadata[prompt_id].status_endpoint body = { "run_id": prompt_id, "live_status": live_status, @@ -539,19 +563,19 @@ async def update_run_live_status(prompt_id, live_status, calculated_progress: fl pass -def update_run(prompt_id, status: Status): +def update_run(prompt_id: str, status: Status): global last_read_line_number if prompt_id not in prompt_metadata: return - if ('status' not in prompt_metadata[prompt_id] or prompt_metadata[prompt_id]['status'] != status): + if (prompt_metadata[prompt_id].status != status): # when the status is already failed, we don't want to update it to success - if (prompt_metadata[prompt_id].status == Status.FAILED): return - status_endpoint = prompt_metadata[prompt_id]['status_endpoint'] + status_endpoint = prompt_metadata[prompt_id].status_endpoint body = { "run_id": prompt_id, "status": status.value, @@ -598,7 +622,7 @@ def update_run(prompt_id, status: Status): stack_trace = traceback.format_exc().strip() print(f"Error occurred while updating run: {e} {stack_trace}") finally: - prompt_metadata[prompt_id]['status'] = status + prompt_metadata[prompt_id].status = status async def upload_file(prompt_id, filename, subfolder=None, content_type="image/png", type="output"): @@ -630,7 +654,7 @@ async def upload_file(prompt_id, filename, subfolder=None, content_type="image/p print("uploading file", file) - file_upload_endpoint = 
prompt_metadata[prompt_id]['file_upload_endpoint'] + file_upload_endpoint = prompt_metadata[prompt_id].file_upload_endpoint filename = quote(filename) prompt_id = quote(prompt_id) @@ -654,23 +678,37 @@ async def upload_file(prompt_id, filename, subfolder=None, content_type="image/p print("upload file response", response.status) def have_pending_upload(prompt_id): - if 'prompt_id' in prompt_metadata and 'uploading_nodes' in prompt_metadata[prompt_id] and len(prompt_metadata[prompt_id]['uploading_nodes']) > 0: - print("have pending upload ", len(prompt_metadata[prompt_id]['uploading_nodes'])) + if prompt_id in prompt_metadata and len(prompt_metadata[prompt_id].uploading_nodes) > 0: + print("have pending upload ", len(prompt_metadata[prompt_id].uploading_nodes)) return True print("no pending upload") return False def mark_prompt_done(prompt_id): + """ + Mark the prompt as done in the prompt metadata. + + Args: + prompt_id (str): The ID of the prompt to mark as done. + """ if prompt_id in prompt_metadata: - prompt_metadata[prompt_id]["done"] = True + prompt_metadata[prompt_id].done = True print("Prompt done") -def is_prompt_done(prompt_id): - if prompt_id in prompt_metadata and "done" in prompt_metadata[prompt_id]: - if prompt_metadata[prompt_id]["done"] == True: - return True - +def is_prompt_done(prompt_id: str): + """ + Check if the prompt with the given ID is marked as done. + + Args: + prompt_id (str): The ID of the prompt to check. + + Returns: + bool: True if the prompt is marked as done, False otherwise. 
+ """ + if prompt_id in prompt_metadata and prompt_metadata[prompt_id].done: + return True + return False # Use to handle upload error and send back to ComfyDeploy @@ -692,17 +730,17 @@ async def handle_error(prompt_id, data, e: Exception): print(f"Error occurred while uploading file: {e}") # Mark the current prompt requires upload, and block it from being marked as success -async def update_file_status(prompt_id, data, uploading, have_error=False, node_id=None): - if 'uploading_nodes' not in prompt_metadata[prompt_id]: - prompt_metadata[prompt_id]['uploading_nodes'] = set() +async def update_file_status(prompt_id: str, data, uploading, have_error=False, node_id=None): + # if 'uploading_nodes' not in prompt_metadata[prompt_id]: + # prompt_metadata[prompt_id]['uploading_nodes'] = set() if node_id is not None: if uploading: - prompt_metadata[prompt_id]['uploading_nodes'].add(node_id) + prompt_metadata[prompt_id].uploading_nodes.add(node_id) else: - prompt_metadata[prompt_id]['uploading_nodes'].discard(node_id) + prompt_metadata[prompt_id].uploading_nodes.discard(node_id) - print(prompt_metadata[prompt_id]['uploading_nodes']) + print(prompt_metadata[prompt_id].uploading_nodes) # Update the remote status if have_error: @@ -714,12 +752,12 @@ async def update_file_status(prompt_id, data, uploading, have_error=False, node_ # if there are still nodes that are uploading, then we set the status to uploading if uploading: - if prompt_metadata[prompt_id]['status'] != Status.UPLOADING: + if prompt_metadata[prompt_id].status != Status.UPLOADING: update_run(prompt_id, Status.UPLOADING) await send("uploading", { "prompt_id": prompt_id, }) - + # if there are no nodes that are uploading, then we set the status to success elif not uploading and not have_pending_upload(prompt_id) and is_prompt_done(prompt_id=prompt_id): update_run(prompt_id, Status.SUCCESS) @@ -728,20 +766,20 @@ async def update_file_status(prompt_id, data, uploading, have_error=False, node_ "prompt_id": prompt_id, }) 
-async def handle_upload(prompt_id, data, key, content_type_key, default_content_type): +async def handle_upload(prompt_id: str, data, key: str, content_type_key: str, default_content_type: str): items = data.get(key, []) for item in items: await upload_file( - prompt_id, - item.get("filename"), - subfolder=item.get("subfolder"), - type=item.get("type"), + prompt_id, + item.get("filename"), + subfolder=item.get("subfolder"), + type=item.get("type"), content_type=item.get(content_type_key, default_content_type) ) # Upload files in the background -async def upload_in_background(prompt_id, data, node_id=None, have_upload=True): +async def upload_in_background(prompt_id: str, data, node_id=None, have_upload=True): try: await handle_upload(prompt_id, data, 'images', "content_type", "image/png") await handle_upload(prompt_id, data, 'files', "content_type", "image/png") @@ -755,7 +793,7 @@ async def upload_in_background(prompt_id, data, node_id=None, have_upload=True): async def update_run_with_output(prompt_id, data, node_id=None): if prompt_id in prompt_metadata: - status_endpoint = prompt_metadata[prompt_id]['status_endpoint'] + status_endpoint = prompt_metadata[prompt_id].status_endpoint body = { "run_id": prompt_id,