feat: add streaming endpoint
This commit is contained in:
		
							parent
							
								
									9b24b12006
								
							
						
					
					
						commit
						af0fac7afc
					
				
							
								
								
									
										343
									
								
								custom_routes.py
									
									
									
									
									
								
							
							
						
						
									
										343
									
								
								custom_routes.py
									
									
									
									
									
								
							@ -18,13 +18,45 @@ import threading
 | 
			
		||||
import hashlib
 | 
			
		||||
import aiohttp
 | 
			
		||||
import aiofiles
 | 
			
		||||
from typing import List, Union, Any, Optional
 | 
			
		||||
from typing import Dict, List, Union, Any, Optional
 | 
			
		||||
from PIL import Image
 | 
			
		||||
import copy
 | 
			
		||||
import struct
 | 
			
		||||
 | 
			
		||||
from logging import basicConfig, getLogger
 | 
			
		||||
import logfire
 | 
			
		||||
if os.environ.get('LOGFIRE_TOKEN', None) is not None:
 | 
			
		||||
    logfire.configure()
 | 
			
		||||
# basicConfig(handlers=[logfire.LogfireLoggingHandler()])
 | 
			
		||||
logfire_handler = logfire.LogfireLoggingHandler()
 | 
			
		||||
logger = getLogger("comfy-deploy")
 | 
			
		||||
logger.addHandler(logfire_handler)
 | 
			
		||||
 | 
			
		||||
from globals import StreamingPrompt, Status, sockets, SimplePrompt, streaming_prompt_metadata, prompt_metadata
 | 
			
		||||
 | 
			
		||||
class EventEmitter:
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        self.listeners = {}
 | 
			
		||||
 | 
			
		||||
    def on(self, event, listener):
 | 
			
		||||
        if event not in self.listeners:
 | 
			
		||||
            self.listeners[event] = []
 | 
			
		||||
        self.listeners[event].append(listener)
 | 
			
		||||
 | 
			
		||||
    def off(self, event, listener):
 | 
			
		||||
        if event in self.listeners:
 | 
			
		||||
            self.listeners[event].remove(listener)
 | 
			
		||||
            if not self.listeners[event]:
 | 
			
		||||
                del self.listeners[event]
 | 
			
		||||
 | 
			
		||||
    def emit(self, event, *args, **kwargs):
 | 
			
		||||
        if event in self.listeners:
 | 
			
		||||
            for listener in self.listeners[event]:
 | 
			
		||||
                listener(*args, **kwargs)
 | 
			
		||||
 | 
			
		||||
# Create a global event emitter instance
 | 
			
		||||
event_emitter = EventEmitter()
 | 
			
		||||
 | 
			
		||||
api = None
 | 
			
		||||
api_task = None
 | 
			
		||||
 | 
			
		||||
@ -32,18 +64,18 @@ cd_enable_log = os.environ.get('CD_ENABLE_LOG', 'false').lower() == 'true'
 | 
			
		||||
cd_enable_run_log = os.environ.get('CD_ENABLE_RUN_LOG', 'false').lower() == 'true'
 | 
			
		||||
bypass_upload = os.environ.get('CD_BYPASS_UPLOAD', 'false').lower() == 'true'
 | 
			
		||||
 | 
			
		||||
print("CD_BYPASS_UPLOAD", bypass_upload)
 | 
			
		||||
logger.info(f"CD_BYPASS_UPLOAD {bypass_upload}")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def clear_current_prompt(sid):
 | 
			
		||||
    prompt_server = server.PromptServer.instance
 | 
			
		||||
    to_delete = list(streaming_prompt_metadata[sid].running_prompt_ids)  # Convert set to list
 | 
			
		||||
 | 
			
		||||
    print("clearning out prompt: ", to_delete)
 | 
			
		||||
    logger.info("clearning out prompt: ", to_delete)
 | 
			
		||||
    for id_to_delete in to_delete:
 | 
			
		||||
        delete_func = lambda a: a[1] == id_to_delete
 | 
			
		||||
        prompt_server.prompt_queue.delete_queue_item(delete_func)
 | 
			
		||||
        print("deleted prompt: ", id_to_delete, prompt_server.prompt_queue.get_tasks_remaining())
 | 
			
		||||
        logger.info("deleted prompt: ", id_to_delete, prompt_server.prompt_queue.get_tasks_remaining())
 | 
			
		||||
 | 
			
		||||
    streaming_prompt_metadata[sid].running_prompt_ids.clear()
 | 
			
		||||
 | 
			
		||||
@ -84,7 +116,7 @@ def post_prompt(json_data):
 | 
			
		||||
            }
 | 
			
		||||
            return response
 | 
			
		||||
        else:
 | 
			
		||||
            print("invalid prompt:", valid[1])
 | 
			
		||||
            logger.info("invalid prompt:", valid[1])
 | 
			
		||||
            return {"error": valid[1], "node_errors": valid[3]}
 | 
			
		||||
    else:
 | 
			
		||||
        return {"error": "no prompt", "node_errors": []}
 | 
			
		||||
@ -158,11 +190,11 @@ def send_prompt(sid: str, inputs: StreamingPrompt):
 | 
			
		||||
    # Random seed
 | 
			
		||||
    apply_random_seed_to_workflow(workflow_api)
 | 
			
		||||
 | 
			
		||||
    print("getting inputs" , inputs.inputs)
 | 
			
		||||
    logger.info("getting inputs" , inputs.inputs)
 | 
			
		||||
 | 
			
		||||
    apply_inputs_to_workflow(workflow_api, inputs.inputs, sid=sid)
 | 
			
		||||
 | 
			
		||||
    print(workflow_api)
 | 
			
		||||
    logger.info(workflow_api)
 | 
			
		||||
 | 
			
		||||
    prompt_id = str(uuid.uuid4())
 | 
			
		||||
 | 
			
		||||
@ -185,12 +217,11 @@ def send_prompt(sid: str, inputs: StreamingPrompt):
 | 
			
		||||
        error_type = type(e).__name__
 | 
			
		||||
        stack_trace_short = traceback.format_exc().strip().split('\n')[-2]
 | 
			
		||||
        stack_trace = traceback.format_exc().strip()
 | 
			
		||||
        print(f"error: {error_type}, {e}")
 | 
			
		||||
        print(f"stack trace: {stack_trace_short}")
 | 
			
		||||
        logger.info(f"error: {error_type}, {e}")
 | 
			
		||||
        logger.info(f"stack trace: {stack_trace_short}")
 | 
			
		||||
 | 
			
		||||
@server.PromptServer.instance.routes.post("/comfyui-deploy/run")
 | 
			
		||||
async def comfy_deploy_run(request):
 | 
			
		||||
    prompt_server = server.PromptServer.instance
 | 
			
		||||
    data = await request.json()
 | 
			
		||||
 | 
			
		||||
    # In older version, we use workflow_api, but this has inputs already swapped in nextjs frontend, which is tricky
 | 
			
		||||
@ -221,8 +252,8 @@ async def comfy_deploy_run(request):
 | 
			
		||||
        error_type = type(e).__name__
 | 
			
		||||
        stack_trace_short = traceback.format_exc().strip().split('\n')[-2]
 | 
			
		||||
        stack_trace = traceback.format_exc().strip()
 | 
			
		||||
        print(f"error: {error_type}, {e}")
 | 
			
		||||
        print(f"stack trace: {stack_trace_short}")
 | 
			
		||||
        logger.info(f"error: {error_type}, {e}")
 | 
			
		||||
        logger.info(f"stack trace: {stack_trace_short}")
 | 
			
		||||
        await update_run_with_output(prompt_id, {
 | 
			
		||||
            "error": {
 | 
			
		||||
                "error_type": error_type,
 | 
			
		||||
@ -234,13 +265,6 @@ async def comfy_deploy_run(request):
 | 
			
		||||
        return web.Response(status=500, reason=f"{error_type}: {e}, {stack_trace_short}")
 | 
			
		||||
 | 
			
		||||
    status = 200
 | 
			
		||||
    # if "error" in res:
 | 
			
		||||
    #     status = 400
 | 
			
		||||
    #     await update_run_with_output(prompt_id, {
 | 
			
		||||
    #         "error": {
 | 
			
		||||
    #             **res
 | 
			
		||||
    #         }
 | 
			
		||||
    #     })
 | 
			
		||||
 | 
			
		||||
    if "node_errors" in res and res["node_errors"]:
 | 
			
		||||
        # Even tho there are node_errors it can still be run
 | 
			
		||||
@ -257,24 +281,133 @@ async def comfy_deploy_run(request):
 | 
			
		||||
 | 
			
		||||
    return web.json_response(res, status=status)
 | 
			
		||||
 | 
			
		||||
async def stream_prompt(data):
 | 
			
		||||
    # In older version, we use workflow_api, but this has inputs already swapped in nextjs frontend, which is tricky
 | 
			
		||||
    workflow_api = data.get("workflow_api_raw")
 | 
			
		||||
    # The prompt id generated from comfy deploy, can be None
 | 
			
		||||
    prompt_id = data.get("prompt_id")
 | 
			
		||||
    inputs = data.get("inputs")
 | 
			
		||||
 | 
			
		||||
    # Now it handles directly in here
 | 
			
		||||
    apply_random_seed_to_workflow(workflow_api)
 | 
			
		||||
    apply_inputs_to_workflow(workflow_api, inputs)
 | 
			
		||||
 | 
			
		||||
    prompt = {
 | 
			
		||||
        "prompt": workflow_api,
 | 
			
		||||
        "client_id": "comfy_deploy_instance", #api.client_id
 | 
			
		||||
        "prompt_id": prompt_id
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    prompt_metadata[prompt_id] = SimplePrompt(
 | 
			
		||||
        status_endpoint=data.get('status_endpoint'),
 | 
			
		||||
        file_upload_endpoint=data.get('file_upload_endpoint'),
 | 
			
		||||
        workflow_api=workflow_api
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    logfire.info("Begin prompt", prompt=prompt)
 | 
			
		||||
 | 
			
		||||
    try:
 | 
			
		||||
        res = post_prompt(prompt)
 | 
			
		||||
    except Exception as e:
 | 
			
		||||
        error_type = type(e).__name__
 | 
			
		||||
        stack_trace_short = traceback.format_exc().strip().split('\n')[-2]
 | 
			
		||||
        stack_trace = traceback.format_exc().strip()
 | 
			
		||||
        logger.info(f"error: {error_type}, {e}")
 | 
			
		||||
        logger.info(f"stack trace: {stack_trace_short}")
 | 
			
		||||
        await update_run_with_output(prompt_id, {
 | 
			
		||||
            "error": {
 | 
			
		||||
                "error_type": error_type,
 | 
			
		||||
                "stack_trace": stack_trace
 | 
			
		||||
            }
 | 
			
		||||
        })
 | 
			
		||||
         # When there are critical errors, the prompt is actually not run
 | 
			
		||||
        await update_run(prompt_id, Status.FAILED)
 | 
			
		||||
        # return web.Response(status=500, reason=f"{error_type}: {e}, {stack_trace_short}")
 | 
			
		||||
        raise Exception("Prompt failed")
 | 
			
		||||
 | 
			
		||||
    status = 200
 | 
			
		||||
 | 
			
		||||
    if "node_errors" in res and res["node_errors"]:
 | 
			
		||||
        # Even tho there are node_errors it can still be run
 | 
			
		||||
        status = 400
 | 
			
		||||
        await update_run_with_output(prompt_id, {
 | 
			
		||||
            "error": {
 | 
			
		||||
                **res
 | 
			
		||||
            }
 | 
			
		||||
        })
 | 
			
		||||
 | 
			
		||||
        # When there are critical errors, the prompt is actually not run
 | 
			
		||||
        if "error" in res:
 | 
			
		||||
            await update_run(prompt_id, Status.FAILED)
 | 
			
		||||
            raise Exception("Prompt failed")
 | 
			
		||||
 | 
			
		||||
    return res
 | 
			
		||||
    # return web.json_response(res, status=status)
 | 
			
		||||
 | 
			
		||||
comfy_message_queues: Dict[str, asyncio.Queue] = {}
 | 
			
		||||
 | 
			
		||||
@server.PromptServer.instance.routes.post('/comfyui-deploy/run/streaming')
 | 
			
		||||
async def stream_response(request):
 | 
			
		||||
    response = web.StreamResponse(status=200, reason='OK', headers={'Content-Type': 'text/event-stream'})
 | 
			
		||||
    await response.prepare(request)
 | 
			
		||||
 | 
			
		||||
    pending = True
 | 
			
		||||
    data = await request.json()
 | 
			
		||||
 | 
			
		||||
    prompt_id = data.get("prompt_id")
 | 
			
		||||
    comfy_message_queues[prompt_id] = asyncio.Queue()
 | 
			
		||||
 | 
			
		||||
    with logfire.span('Streaming Run'):
 | 
			
		||||
        logfire.info('Streaming prompt')
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            result = await stream_prompt(data=data)
 | 
			
		||||
            await response.write(json.dumps(result).encode('utf-8'))
 | 
			
		||||
            await response.drain()  # Ensure the buffer is flushed
 | 
			
		||||
 | 
			
		||||
            while pending:
 | 
			
		||||
                if prompt_id in comfy_message_queues:
 | 
			
		||||
                    if not comfy_message_queues[prompt_id].empty():
 | 
			
		||||
                        data = await comfy_message_queues[prompt_id].get()
 | 
			
		||||
 | 
			
		||||
                        logfire.info(data["event"], data=json.dumps(data))
 | 
			
		||||
                        # logger.info("listener", data)
 | 
			
		||||
                        await response.write(json.dumps(data).encode('utf-8'))
 | 
			
		||||
                        await response.drain()  # Ensure the buffer is flushed
 | 
			
		||||
 | 
			
		||||
                        if data["event"] == "status":
 | 
			
		||||
                            if data["data"]["status"] in (Status.FAILED.value, Status.SUCCESS.value):
 | 
			
		||||
                                pending = False
 | 
			
		||||
 | 
			
		||||
                await asyncio.sleep(0.1)  # Adjust the sleep duration as needed
 | 
			
		||||
        except asyncio.CancelledError:
 | 
			
		||||
            logfire.info("Streaming was cancelled")
 | 
			
		||||
            raise
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            logfire.error("Streaming error", error=e)
 | 
			
		||||
        finally:
 | 
			
		||||
            # event_emitter.off("send_json", task)
 | 
			
		||||
            await response.write_eof()
 | 
			
		||||
            comfy_message_queues.pop(prompt_id, None)
 | 
			
		||||
            return response
 | 
			
		||||
 | 
			
		||||
def get_comfyui_path_from_file_path(file_path):
 | 
			
		||||
    file_path_parts = file_path.split("\\")
 | 
			
		||||
 | 
			
		||||
    if file_path_parts[0] == "input":
 | 
			
		||||
        print("matching input")
 | 
			
		||||
        logger.info("matching input")
 | 
			
		||||
        file_path = os.path.join(folder_paths.get_directory_by_type("input"), *file_path_parts[1:])
 | 
			
		||||
    elif file_path_parts[0] == "models":
 | 
			
		||||
        print("matching models")
 | 
			
		||||
        logger.info("matching models")
 | 
			
		||||
        file_path = folder_paths.get_full_path(file_path_parts[1], os.path.join(*file_path_parts[2:]))
 | 
			
		||||
 | 
			
		||||
    print(file_path)
 | 
			
		||||
    logger.info(file_path)
 | 
			
		||||
 | 
			
		||||
    return file_path
 | 
			
		||||
 | 
			
		||||
# Form ComfyUI Manager
 | 
			
		||||
async def compute_sha256_checksum(filepath):
 | 
			
		||||
    print("computing sha256 checksum")
 | 
			
		||||
    logger.info("computing sha256 checksum")
 | 
			
		||||
    chunk_size = 1024 * 256  # Example: 256KB
 | 
			
		||||
    filepath = get_comfyui_path_from_file_path(filepath)
 | 
			
		||||
    """Compute the SHA256 checksum of a file, in chunks, asynchronously"""
 | 
			
		||||
@ -297,7 +430,7 @@ async def get_installed_models(request):
 | 
			
		||||
        file_list = folder_paths.get_filename_list(key)
 | 
			
		||||
        value_json_compatible = (value[0], list(value[1]), file_list)
 | 
			
		||||
        new_dict[key] = value_json_compatible
 | 
			
		||||
    # print(new_dict)
 | 
			
		||||
    # logger.info(new_dict)
 | 
			
		||||
    return web.json_response(new_dict)
 | 
			
		||||
 | 
			
		||||
# This is start uploading the files to Comfy Deploy
 | 
			
		||||
@ -307,7 +440,7 @@ async def upload_file_endpoint(request):
 | 
			
		||||
 | 
			
		||||
    file_path = data.get("file_path")
 | 
			
		||||
 | 
			
		||||
    print("Original file path", file_path)
 | 
			
		||||
    logger.info("Original file path", file_path)
 | 
			
		||||
 | 
			
		||||
    file_path = get_comfyui_path_from_file_path(file_path)
 | 
			
		||||
 | 
			
		||||
@ -429,7 +562,7 @@ async def get_file_hash(request):
 | 
			
		||||
            file_hash = await compute_sha256_checksum(full_file_path)
 | 
			
		||||
            end_time = time.time()
 | 
			
		||||
            elapsed_time = end_time - start_time
 | 
			
		||||
            print(f"Cache miss -> Execution time: {elapsed_time} seconds")
 | 
			
		||||
            logger.info(f"Cache miss -> Execution time: {elapsed_time} seconds")
 | 
			
		||||
 | 
			
		||||
            # Update the in-memory cache
 | 
			
		||||
            file_hash_cache[full_file_path] = file_hash
 | 
			
		||||
@ -449,6 +582,8 @@ async def update_realtime_run_status(realtime_id: str, status_endpoint: str, sta
 | 
			
		||||
        "run_id": realtime_id,
 | 
			
		||||
        "status": status.value,
 | 
			
		||||
    }
 | 
			
		||||
    if (status_endpoint is None):
 | 
			
		||||
        return
 | 
			
		||||
    # requests.post(status_endpoint, json=body)
 | 
			
		||||
    async with aiohttp.ClientSession() as session:
 | 
			
		||||
        async with session.post(status_endpoint, json=body) as response:
 | 
			
		||||
@ -479,7 +614,7 @@ async def websocket_handler(request):
 | 
			
		||||
                if response.status == 200:
 | 
			
		||||
                    workflow = await response.json()
 | 
			
		||||
 | 
			
		||||
                    print("Loaded workflow version ",workflow["version"])
 | 
			
		||||
                    logger.info(f"Loaded workflow version ${workflow['version']}")
 | 
			
		||||
 | 
			
		||||
                    streaming_prompt_metadata[sid] = StreamingPrompt(
 | 
			
		||||
                        workflow_api=workflow["workflow_api"],
 | 
			
		||||
@ -493,7 +628,7 @@ async def websocket_handler(request):
 | 
			
		||||
                    # await send("workflow_api", workflow_api, sid)
 | 
			
		||||
                else:
 | 
			
		||||
                    error_message = await response.text()
 | 
			
		||||
                    print(f"Failed to fetch workflow endpoint. Status: {response.status}, Error: {error_message}")
 | 
			
		||||
                    logger.info(f"Failed to fetch workflow endpoint. Status: {response.status}, Error: {error_message}")
 | 
			
		||||
                    # await send("error", {"message": error_message}, sid)
 | 
			
		||||
 | 
			
		||||
    try:
 | 
			
		||||
@ -508,10 +643,10 @@ async def websocket_handler(request):
 | 
			
		||||
            if msg.type == aiohttp.WSMsgType.TEXT:
 | 
			
		||||
                try:
 | 
			
		||||
                    data = json.loads(msg.data)
 | 
			
		||||
                    print(data)
 | 
			
		||||
                    logger.info(data)
 | 
			
		||||
                    event_type = data.get('event')
 | 
			
		||||
                    if event_type == 'input':
 | 
			
		||||
                        print("Got input: ", data.get("inputs"))
 | 
			
		||||
                        logger.info(f"Got input: ${data.get('inputs')}")
 | 
			
		||||
                        input = data.get('inputs')
 | 
			
		||||
                        streaming_prompt_metadata[sid].inputs.update(input)
 | 
			
		||||
                    elif event_type == 'queue_prompt':
 | 
			
		||||
@ -521,7 +656,7 @@ async def websocket_handler(request):
 | 
			
		||||
                        # Handle other event types
 | 
			
		||||
                        pass
 | 
			
		||||
                except json.JSONDecodeError:
 | 
			
		||||
                    print('Failed to decode JSON from message')
 | 
			
		||||
                    logger.info('Failed to decode JSON from message')
 | 
			
		||||
 | 
			
		||||
            if msg.type == aiohttp.WSMsgType.BINARY:
 | 
			
		||||
                data = msg.data
 | 
			
		||||
@ -530,9 +665,9 @@ async def websocket_handler(request):
 | 
			
		||||
                    image_type_code, = struct.unpack("<I", data[4:8])
 | 
			
		||||
                    input_id_bytes = data[8:32]  # Extract the next 24 bytes for the input ID
 | 
			
		||||
                    input_id = input_id_bytes.decode('ascii').strip()  # Decode the input ID from ASCII
 | 
			
		||||
                    print(event_type)
 | 
			
		||||
                    print(image_type_code)
 | 
			
		||||
                    print(input_id)
 | 
			
		||||
                    logger.info(event_type)
 | 
			
		||||
                    logger.info(image_type_code)
 | 
			
		||||
                    logger.info(input_id)
 | 
			
		||||
                    image_data = data[32:]  # The rest is the image data
 | 
			
		||||
                    if image_type_code == 1:
 | 
			
		||||
                        image_type = "JPEG"
 | 
			
		||||
@ -541,7 +676,7 @@ async def websocket_handler(request):
 | 
			
		||||
                    elif image_type_code == 3:
 | 
			
		||||
                        image_type = "WEBP"
 | 
			
		||||
                    else:
 | 
			
		||||
                        print("Unknown image type code:", image_type_code)
 | 
			
		||||
                        logger.info(f"Unknown image type code: ${image_type_code}")
 | 
			
		||||
                        return
 | 
			
		||||
                    image = Image.open(BytesIO(image_data))
 | 
			
		||||
                    # Check if the input ID already exists and replace the input with the new one
 | 
			
		||||
@ -552,14 +687,14 @@ async def websocket_handler(request):
 | 
			
		||||
                            if hasattr(existing_image, 'close'):
 | 
			
		||||
                                existing_image.close()
 | 
			
		||||
                        except Exception as e:
 | 
			
		||||
                            print(f"Error closing previous image for input ID {input_id}: {e}")
 | 
			
		||||
                            logger.info(f"Error closing previous image for input ID {input_id}: {e}")
 | 
			
		||||
                    streaming_prompt_metadata[sid].inputs[input_id] = image
 | 
			
		||||
                    # clear_current_prompt(sid)
 | 
			
		||||
                    # send_prompt(sid, streaming_prompt_metadata[sid])
 | 
			
		||||
                    print(f"Received {image_type} image of size {image.size} with input ID {input_id}")
 | 
			
		||||
                    logger.info(f"Received {image_type} image of size {image.size} with input ID {input_id}")
 | 
			
		||||
 | 
			
		||||
            if msg.type == aiohttp.WSMsgType.ERROR:
 | 
			
		||||
                print('ws connection closed with exception %s' % ws.exception())
 | 
			
		||||
                logger.info('ws connection closed with exception %s' % ws.exception())
 | 
			
		||||
    finally:
 | 
			
		||||
        sockets.pop(sid, None)
 | 
			
		||||
 | 
			
		||||
@ -604,16 +739,16 @@ async def send(event, data, sid=None):
 | 
			
		||||
                if not ws.closed:  # Check if the WebSocket connection is open and not closing
 | 
			
		||||
                    await ws.send_json({ 'event': event, 'data': data })
 | 
			
		||||
    except Exception as e:
 | 
			
		||||
        print(f"Exception: {e}")
 | 
			
		||||
        logger.info(f"Exception: {e}")
 | 
			
		||||
        traceback.print_exc()
 | 
			
		||||
 | 
			
		||||
logging.basicConfig(level=logging.INFO)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
prompt_server = server.PromptServer.instance
 | 
			
		||||
send_json = prompt_server.send_json
 | 
			
		||||
 | 
			
		||||
async def send_json_override(self, event, data, sid=None):
 | 
			
		||||
    # print("INTERNAL:", event, data, sid)
 | 
			
		||||
    # logger.info("INTERNAL:", event, data, sid)
 | 
			
		||||
    prompt_id = data.get('prompt_id')
 | 
			
		||||
 | 
			
		||||
    target_sid = sid
 | 
			
		||||
@ -626,8 +761,19 @@ async def send_json_override(self, event, data, sid=None):
 | 
			
		||||
        asyncio.create_task(self.send_json_original(event, data, sid))
 | 
			
		||||
    ])
 | 
			
		||||
 | 
			
		||||
    if prompt_id in comfy_message_queues:
 | 
			
		||||
        comfy_message_queues[prompt_id].put_nowait({
 | 
			
		||||
            "event": event,
 | 
			
		||||
            "data": data
 | 
			
		||||
        })
 | 
			
		||||
 | 
			
		||||
    # event_emitter.emit("send_json", {
 | 
			
		||||
    #     "event": event,
 | 
			
		||||
    #     "data": data
 | 
			
		||||
    # })
 | 
			
		||||
 | 
			
		||||
    if event == 'execution_start':
 | 
			
		||||
        update_run(prompt_id, Status.RUNNING)
 | 
			
		||||
        await update_run(prompt_id, Status.RUNNING)
 | 
			
		||||
 | 
			
		||||
        if prompt_id in prompt_metadata:
 | 
			
		||||
            prompt_metadata[prompt_id].start_time = time.perf_counter()
 | 
			
		||||
@ -636,12 +782,12 @@ async def send_json_override(self, event, data, sid=None):
 | 
			
		||||
    if event == 'executing' and data.get('node') is None:
 | 
			
		||||
        mark_prompt_done(prompt_id=prompt_id)
 | 
			
		||||
        if not have_pending_upload(prompt_id):
 | 
			
		||||
            update_run(prompt_id, Status.SUCCESS)
 | 
			
		||||
            await update_run(prompt_id, Status.SUCCESS)
 | 
			
		||||
            if prompt_id in prompt_metadata:
 | 
			
		||||
                current_time = time.perf_counter()
 | 
			
		||||
                if prompt_metadata[prompt_id].start_time is not None:
 | 
			
		||||
                    elapsed_time = current_time - prompt_metadata[prompt_id].start_time
 | 
			
		||||
                    print(f"Elapsed time: {elapsed_time} seconds")
 | 
			
		||||
                    logger.info(f"Elapsed time: {elapsed_time} seconds")
 | 
			
		||||
                    await send("elapsed_time", {
 | 
			
		||||
                        "prompt_id": prompt_id,
 | 
			
		||||
                        "elapsed_time": elapsed_time
 | 
			
		||||
@ -656,13 +802,13 @@ async def send_json_override(self, event, data, sid=None):
 | 
			
		||||
 | 
			
		||||
            prompt_metadata[prompt_id].progress.add(node)
 | 
			
		||||
            calculated_progress = len(prompt_metadata[prompt_id].progress) / len(prompt_metadata[prompt_id].workflow_api)
 | 
			
		||||
            # print("calculated_progress", calculated_progress)
 | 
			
		||||
            # logger.info("calculated_progress", calculated_progress)
 | 
			
		||||
 | 
			
		||||
            if prompt_metadata[prompt_id].last_updated_node is not None and prompt_metadata[prompt_id].last_updated_node == node:
 | 
			
		||||
                return
 | 
			
		||||
            prompt_metadata[prompt_id].last_updated_node = node
 | 
			
		||||
            class_type = prompt_metadata[prompt_id].workflow_api[node]['class_type']
 | 
			
		||||
            print("updating run live status", class_type)
 | 
			
		||||
            logger.info(f"updating run live status {class_type}")
 | 
			
		||||
            await send("live_status", {
 | 
			
		||||
                "prompt_id": prompt_id,
 | 
			
		||||
                "current_node": class_type,
 | 
			
		||||
@ -683,17 +829,17 @@ async def send_json_override(self, event, data, sid=None):
 | 
			
		||||
    if event == 'execution_error':
 | 
			
		||||
        # Careful this might not be fully awaited.
 | 
			
		||||
        await update_run_with_output(prompt_id, data)
 | 
			
		||||
        update_run(prompt_id, Status.FAILED)
 | 
			
		||||
        await update_run(prompt_id, Status.FAILED)
 | 
			
		||||
        # await update_run_with_output(prompt_id, data)
 | 
			
		||||
 | 
			
		||||
    if event == 'executed' and 'node' in data and 'output' in data:
 | 
			
		||||
        print("executed", data)
 | 
			
		||||
        logger.info(f"executed {data}")
 | 
			
		||||
        if prompt_id in prompt_metadata:
 | 
			
		||||
            node = data.get('node')
 | 
			
		||||
            class_type = prompt_metadata[prompt_id].workflow_api[node]['class_type']
 | 
			
		||||
            print("executed", class_type)
 | 
			
		||||
            logger.info(f"executed {class_type}")
 | 
			
		||||
            if class_type == "PreviewImage":
 | 
			
		||||
                print("skipping preview image")
 | 
			
		||||
                logger.info("skipping preview image")
 | 
			
		||||
                return
 | 
			
		||||
 | 
			
		||||
        await update_run_with_output(prompt_id, data.get('output'), node_id=data.get('node'))
 | 
			
		||||
@ -710,21 +856,36 @@ async def update_run_live_status(prompt_id, live_status, calculated_progress: fl
 | 
			
		||||
    if prompt_metadata[prompt_id].is_realtime is True:
 | 
			
		||||
        return
 | 
			
		||||
 | 
			
		||||
    print("progress", calculated_progress)
 | 
			
		||||
    
 | 
			
		||||
    status_endpoint = prompt_metadata[prompt_id].status_endpoint
 | 
			
		||||
 | 
			
		||||
    if (status_endpoint is None):
 | 
			
		||||
        return
 | 
			
		||||
 | 
			
		||||
    logger.info(f"progress {calculated_progress}")
 | 
			
		||||
 | 
			
		||||
    body = {
 | 
			
		||||
        "run_id": prompt_id,
 | 
			
		||||
        "live_status": live_status,
 | 
			
		||||
        "progress": calculated_progress
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    if prompt_id in comfy_message_queues:
 | 
			
		||||
        comfy_message_queues[prompt_id].put_nowait({
 | 
			
		||||
            "event": "live_status",
 | 
			
		||||
            "data": {
 | 
			
		||||
                "prompt_id": prompt_id,
 | 
			
		||||
                "live_status": live_status,
 | 
			
		||||
                "progress": calculated_progress
 | 
			
		||||
            }
 | 
			
		||||
        })
 | 
			
		||||
 | 
			
		||||
    # requests.post(status_endpoint, json=body)
 | 
			
		||||
    async with aiohttp.ClientSession() as session:
 | 
			
		||||
        async with session.post(status_endpoint, json=body) as response:
 | 
			
		||||
            pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def update_run(prompt_id: str, status: Status):
 | 
			
		||||
async def update_run(prompt_id: str, status: Status):
 | 
			
		||||
    global last_read_line_number
 | 
			
		||||
 | 
			
		||||
    if prompt_id not in prompt_metadata:
 | 
			
		||||
@ -747,18 +908,22 @@ def update_run(prompt_id: str, status: Status):
 | 
			
		||||
            "run_id": prompt_id,
 | 
			
		||||
            "status": status.value,
 | 
			
		||||
        }
 | 
			
		||||
        print(f"Status: {status.value}")
 | 
			
		||||
        logger.info(f"Status: {status.value}")
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            requests.post(status_endpoint, json=body)
 | 
			
		||||
            # requests.post(status_endpoint, json=body)
 | 
			
		||||
            if (status_endpoint is not None):
 | 
			
		||||
                async with aiohttp.ClientSession() as session:
 | 
			
		||||
                    async with session.post(status_endpoint, json=body) as response:
 | 
			
		||||
                        pass
 | 
			
		||||
 | 
			
		||||
            if cd_enable_run_log and (status == Status.SUCCESS or status == Status.FAILED):
 | 
			
		||||
            if (status_endpoint is not None) and cd_enable_run_log and (status == Status.SUCCESS or status == Status.FAILED):
 | 
			
		||||
                try:
 | 
			
		||||
                    with open(comfyui_file_path, 'r') as log_file:
 | 
			
		||||
                        # log_data = log_file.read()
 | 
			
		||||
                        # Move to the last read line
 | 
			
		||||
                        all_log_data = log_file.read()  # Read all log data
 | 
			
		||||
                        print("All log data before skipping:", all_log_data)  # Log all data before skipping
 | 
			
		||||
                        # logger.info("All log data before skipping: ")  # Log all data before skipping
 | 
			
		||||
                        log_file.seek(0)  # Reset file pointer to the beginning
 | 
			
		||||
 | 
			
		||||
                        for _ in range(last_read_line_number):
 | 
			
		||||
@ -766,9 +931,9 @@ def update_run(prompt_id: str, status: Status):
 | 
			
		||||
                        log_data = log_file.read()
 | 
			
		||||
                        # Update the last read line number
 | 
			
		||||
                        last_read_line_number += log_data.count('\n')
 | 
			
		||||
                        print("last_read_line_number", last_read_line_number)
 | 
			
		||||
                        print("log_data", log_data)
 | 
			
		||||
                        print("log_data.count(n)", log_data.count('\n'))
 | 
			
		||||
                        # logger.info("last_read_line_number", last_read_line_number)
 | 
			
		||||
                        # logger.info("log_data", log_data)
 | 
			
		||||
                        # logger.info("log_data.count(n)", log_data.count('\n'))
 | 
			
		||||
 | 
			
		||||
                        body = {
 | 
			
		||||
                            "run_id": prompt_id,
 | 
			
		||||
@ -779,16 +944,28 @@ def update_run(prompt_id: str, status: Status):
 | 
			
		||||
                                }
 | 
			
		||||
                            ]
 | 
			
		||||
                        }
 | 
			
		||||
                        requests.post(status_endpoint, json=body)
 | 
			
		||||
 | 
			
		||||
                        async with aiohttp.ClientSession() as session:
 | 
			
		||||
                            async with session.post(status_endpoint, json=body) as response:
 | 
			
		||||
                                pass
 | 
			
		||||
                        # requests.post(status_endpoint, json=body)
 | 
			
		||||
                except Exception as log_error:
 | 
			
		||||
                    print(f"Error reading log file: {log_error}")
 | 
			
		||||
                    logger.info(f"Error reading log file: {log_error}")
 | 
			
		||||
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            error_type = type(e).__name__
 | 
			
		||||
            stack_trace = traceback.format_exc().strip()
 | 
			
		||||
            print(f"Error occurred while updating run: {e} {stack_trace}")
 | 
			
		||||
            logger.info(f"Error occurred while updating run: {e} {stack_trace}")
 | 
			
		||||
        finally:
 | 
			
		||||
            prompt_metadata[prompt_id].status = status
 | 
			
		||||
            if prompt_id in comfy_message_queues:
 | 
			
		||||
                comfy_message_queues[prompt_id].put_nowait({
 | 
			
		||||
                    "event": "status",
 | 
			
		||||
                    "data": {
 | 
			
		||||
                        "prompt_id": prompt_id,
 | 
			
		||||
                        "status": status.value,
 | 
			
		||||
                    }
 | 
			
		||||
                })
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def upload_file(prompt_id, filename, subfolder=None, content_type="image/png", type="output"):
 | 
			
		||||
@ -806,7 +983,7 @@ async def upload_file(prompt_id, filename, subfolder=None, content_type="image/p
 | 
			
		||||
        output_dir = folder_paths.get_directory_by_type(type)
 | 
			
		||||
 | 
			
		||||
    if output_dir is None:
 | 
			
		||||
        print(filename, "Upload failed: output_dir is None")
 | 
			
		||||
        logger.info(f"{filename} Upload failed: output_dir is None")
 | 
			
		||||
        return
 | 
			
		||||
 | 
			
		||||
    if subfolder != None:
 | 
			
		||||
@ -818,7 +995,7 @@ async def upload_file(prompt_id, filename, subfolder=None, content_type="image/p
 | 
			
		||||
    filename = os.path.basename(filename)
 | 
			
		||||
    file = os.path.join(output_dir, filename)
 | 
			
		||||
 | 
			
		||||
    print("uploading file", file)
 | 
			
		||||
    logger.info(f"uploading file {file}")
 | 
			
		||||
 | 
			
		||||
    file_upload_endpoint = prompt_metadata[prompt_id].file_upload_endpoint
 | 
			
		||||
 | 
			
		||||
@ -831,7 +1008,7 @@ async def upload_file(prompt_id, filename, subfolder=None, content_type="image/p
 | 
			
		||||
    start_time = time.time()  # Start timing here
 | 
			
		||||
    result = requests.get(target_url)
 | 
			
		||||
    end_time = time.time()  # End timing after the request is complete
 | 
			
		||||
    print("Time taken for getting file upload endpoint: {:.2f} seconds".format(end_time - start_time))
 | 
			
		||||
    logger.info("Time taken for getting file upload endpoint: {:.2f} seconds".format(end_time - start_time))
 | 
			
		||||
    ok = result.json()
 | 
			
		||||
 | 
			
		||||
    start_time = time.time()  # Start timing here
 | 
			
		||||
@ -846,16 +1023,16 @@ async def upload_file(prompt_id, filename, subfolder=None, content_type="image/p
 | 
			
		||||
        # response = requests.put(ok.get("url"), headers=headers, data=data)
 | 
			
		||||
        async with aiohttp.ClientSession() as session:
 | 
			
		||||
            async with session.put(ok.get("url"), headers=headers, data=data) as response:
 | 
			
		||||
                print("Upload file response", response.status)
 | 
			
		||||
                logger.info(f"Upload file response {response.status}")
 | 
			
		||||
                end_time = time.time()  # End timing after the request is complete
 | 
			
		||||
                print("Upload time: {:.2f} seconds".format(end_time - start_time))
 | 
			
		||||
                logger.info("Upload time: {:.2f} seconds".format(end_time - start_time))
 | 
			
		||||
 | 
			
		||||
def have_pending_upload(prompt_id):
 | 
			
		||||
    if prompt_id in prompt_metadata and len(prompt_metadata[prompt_id].uploading_nodes) > 0:
 | 
			
		||||
        print("have pending upload ", len(prompt_metadata[prompt_id].uploading_nodes))
 | 
			
		||||
        logger.info(f"have pending upload {len(prompt_metadata[prompt_id].uploading_nodes)}")
 | 
			
		||||
        return True
 | 
			
		||||
 | 
			
		||||
    print("no pending upload")
 | 
			
		||||
    logger.info("no pending upload")
 | 
			
		||||
    return False
 | 
			
		||||
 | 
			
		||||
def mark_prompt_done(prompt_id):
 | 
			
		||||
@ -867,7 +1044,7 @@ def mark_prompt_done(prompt_id):
 | 
			
		||||
    """
 | 
			
		||||
    if prompt_id in prompt_metadata:
 | 
			
		||||
        prompt_metadata[prompt_id].done = True
 | 
			
		||||
        print("Prompt done")
 | 
			
		||||
        logger.info("Prompt done")
 | 
			
		||||
 | 
			
		||||
def is_prompt_done(prompt_id: str):
 | 
			
		||||
    """
 | 
			
		||||
@ -899,8 +1076,8 @@ async def handle_error(prompt_id, data, e: Exception):
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    await update_file_status(prompt_id, data, False, have_error=True)
 | 
			
		||||
    print(body)
 | 
			
		||||
    print(f"Error occurred while uploading file: {e}")
 | 
			
		||||
    logger.info(body)
 | 
			
		||||
    logger.info(f"Error occurred while uploading file: {e}")
 | 
			
		||||
 | 
			
		||||
# Mark the current prompt requires upload, and block it from being marked as success
 | 
			
		||||
async def update_file_status(prompt_id: str, data, uploading, have_error=False, node_id=None):
 | 
			
		||||
@ -913,11 +1090,11 @@ async def update_file_status(prompt_id: str, data, uploading, have_error=False,
 | 
			
		||||
        else:
 | 
			
		||||
            prompt_metadata[prompt_id].uploading_nodes.discard(node_id)
 | 
			
		||||
 | 
			
		||||
    print(prompt_metadata[prompt_id].uploading_nodes)
 | 
			
		||||
    logger.info(prompt_metadata[prompt_id].uploading_nodes)
 | 
			
		||||
    # Update the remote status
 | 
			
		||||
 | 
			
		||||
    if have_error:
 | 
			
		||||
        update_run(prompt_id, Status.FAILED)
 | 
			
		||||
        await update_run(prompt_id, Status.FAILED)
 | 
			
		||||
        await send("failed", {
 | 
			
		||||
            "prompt_id": prompt_id,
 | 
			
		||||
        })
 | 
			
		||||
@ -926,15 +1103,15 @@ async def update_file_status(prompt_id: str, data, uploading, have_error=False,
 | 
			
		||||
    # if there are still nodes that are uploading, then we set the status to uploading
 | 
			
		||||
    if uploading:
 | 
			
		||||
        if prompt_metadata[prompt_id].status != Status.UPLOADING:
 | 
			
		||||
            update_run(prompt_id, Status.UPLOADING)
 | 
			
		||||
            await update_run(prompt_id, Status.UPLOADING)
 | 
			
		||||
            await send("uploading", {
 | 
			
		||||
                "prompt_id": prompt_id,
 | 
			
		||||
            })
 | 
			
		||||
 | 
			
		||||
    # if there are no nodes that are uploading, then we set the status to success
 | 
			
		||||
    elif not uploading and not have_pending_upload(prompt_id) and is_prompt_done(prompt_id=prompt_id):
 | 
			
		||||
        update_run(prompt_id, Status.SUCCESS)
 | 
			
		||||
        # print("Status: SUCCUSS")
 | 
			
		||||
        await update_run(prompt_id, Status.SUCCESS)
 | 
			
		||||
        # logger.info("Status: SUCCUSS")
 | 
			
		||||
        await send("success", {
 | 
			
		||||
            "prompt_id": prompt_id,
 | 
			
		||||
        })
 | 
			
		||||
@ -991,10 +1168,10 @@ async def update_run_with_output(prompt_id, data, node_id=None):
 | 
			
		||||
        "output_data": data
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    if not bypass_upload:
 | 
			
		||||
    if not bypass_upload and prompt_metadata[prompt_id].file_upload_endpoint is not None:
 | 
			
		||||
        try:
 | 
			
		||||
            logger.info(f"\nhave_upload {have_upload} {node_id}")
 | 
			
		||||
            have_upload = 'images' in data or 'files' in data or 'gifs' in data or 'mesh' in data
 | 
			
		||||
            print("\nhave_upload", have_upload, node_id)
 | 
			
		||||
 | 
			
		||||
            if have_upload:
 | 
			
		||||
                await update_file_status(prompt_id, data, True, node_id=node_id)
 | 
			
		||||
@ -1005,7 +1182,11 @@ async def update_run_with_output(prompt_id, data, node_id=None):
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            await handle_error(prompt_id, data, e)
 | 
			
		||||
 | 
			
		||||
    requests.post(status_endpoint, json=body)
 | 
			
		||||
    # requests.post(status_endpoint, json=body)
 | 
			
		||||
    if status_endpoint is not None:
 | 
			
		||||
        async with aiohttp.ClientSession() as session:
 | 
			
		||||
            async with session.post(status_endpoint, json=body) as response:
 | 
			
		||||
                pass
 | 
			
		||||
 | 
			
		||||
    await send('outputs_uploaded', {
 | 
			
		||||
        "prompt_id": prompt_id
 | 
			
		||||
 | 
			
		||||
@ -22,12 +22,13 @@ class StreamingPrompt(BaseModel):
 | 
			
		||||
    auth_token: str
 | 
			
		||||
    inputs: dict[str, Union[str, bytes, Image.Image]]
 | 
			
		||||
    running_prompt_ids: set[str] = set()
 | 
			
		||||
    status_endpoint: str
 | 
			
		||||
    file_upload_endpoint: str
 | 
			
		||||
    status_endpoint: Optional[str]
 | 
			
		||||
    file_upload_endpoint: Optional[str]
 | 
			
		||||
    
 | 
			
		||||
class SimplePrompt(BaseModel):
 | 
			
		||||
    status_endpoint: str
 | 
			
		||||
    file_upload_endpoint: str
 | 
			
		||||
    status_endpoint: Optional[str]
 | 
			
		||||
    file_upload_endpoint: Optional[str]
 | 
			
		||||
    
 | 
			
		||||
    workflow_api: dict
 | 
			
		||||
    status: Status = Status.NOT_STARTED
 | 
			
		||||
    progress: set = set()
 | 
			
		||||
 | 
			
		||||
@ -2,3 +2,4 @@ aiofiles
 | 
			
		||||
pydantic
 | 
			
		||||
opencv-python
 | 
			
		||||
imageio-ffmpeg
 | 
			
		||||
logfire
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user