diff --git a/comfy-nodes/output_websocket_image.py b/comfy-nodes/output_websocket_image.py new file mode 100644 index 0000000..6277b0a --- /dev/null +++ b/comfy-nodes/output_websocket_image.py @@ -0,0 +1,64 @@ +import folder_paths +from PIL import Image, ImageOps +import numpy as np +import torch +from server import PromptServer, BinaryEventTypes +import asyncio + +from globals import send_image + +class ComfyDeployWebscoketImageOutput: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "output_id": ( + "STRING", + {"multiline": False, "default": "output_id"}, + ), + "images": ("IMAGE", ), + }, + "optional": { + "client_id": ( + "STRING", + {"multiline": False, "default": ""}, + ), + } + # "hidden": {"client_id": "CLIENT_ID"}, + } + + OUTPUT_NODE = True + + RETURN_TYPES = () + RETURN_NAMES = ("text",) + + FUNCTION = "run" + + CATEGORY = "output" + + def run(self, output_id, images, client_id): + results = [] + prompt_server = PromptServer.instance + loop = prompt_server.loop + + def schedule_coroutine_blocking(target, *args): + future = asyncio.run_coroutine_threadsafe(target(*args), loop) + return future.result() # This makes the call blocking + + for tensor in images: + array = 255.0 * tensor.cpu().numpy() + image = Image.fromarray(np.clip(array, 0, 255).astype(np.uint8)) + + schedule_coroutine_blocking(send_image, ["PNG", image, None], client_id) + print("Image sent") + # loop.run_until_complete(send_image(["PNG", image, None], client_id)) + results.append( + {"source": "websocket", "content-type": "image/png", "type": "output"} + ) + + return {"ui": {"images": results}} + + + +NODE_CLASS_MAPPINGS = {"ComfyDeployWebscoketImageOutput": ComfyDeployWebscoketImageOutput} +NODE_DISPLAY_NAME_MAPPINGS = {"ComfyDeployWebscoketImageOutput": "Image Websocket Output (ComfyDeploy)"} \ No newline at end of file diff --git a/custom_routes.py b/custom_routes.py index 80c54e3..a119b7c 100644 --- a/custom_routes.py +++ b/custom_routes.py @@ -1,3 +1,5 @@ + + from aiohttp import web import os import requests @@ -26,6 +28,24 @@ import hashlib import aiohttp import aiofiles import concurrent.futures +from typing import List, Union, Any +from PIL import Image +import copy + +from globals import sockets + +from pydantic import BaseModel as PydanticBaseModel + +class BaseModel(PydanticBaseModel): + class Config: + arbitrary_types_allowed = True + +class StreamingPrompt(BaseModel): + workflow_api: Any + auth_token: str + inputs: dict[str, Union[str, bytes, Image.Image]] + +streaming_prompt_metadata: dict[str, StreamingPrompt] = {} api = None api_task = None @@ -80,6 +100,50 @@ def randomSeed(num_digits=15): range_end = (10**num_digits) - 1 return random.randint(range_start, range_end) +def send_prompt(sid: str, inputs: StreamingPrompt): + # workflow_api = inputs.workflow_api + workflow_api = copy.deepcopy(inputs.workflow_api) + + # Random seed + for key in workflow_api: + if 'inputs' in workflow_api[key] and 'seed' in workflow_api[key]['inputs']: + workflow_api[key]['inputs']['seed'] = randomSeed() + + print("getting inputs" ,inputs.inputs) + + # Loop through each of the inputs and replace them + for key, value in workflow_api.items(): + if 'inputs' in value: + if "input_id" in value['inputs'] and value['inputs']['input_id'] in inputs.inputs: + new_value = inputs.inputs[value['inputs']['input_id']] + value['inputs']["input_id"] = new_value; + + # Fix for external text default value + if (value["class_type"] == "ComfyUIDeployExternalText"): + value['inputs']["default_value"] = new_value; + + if (value["class_type"] == "ComfyDeployWebscoketImageOutput"): + value['inputs']["client_id"] = sid; + + print(workflow_api) + + prompt_id = str(uuid.uuid4()) + + prompt = { + "prompt": workflow_api, + "client_id": "comfy_deploy_instance", #api.client_id + "prompt_id": prompt_id + } + + try: + res = post_prompt(prompt) + except Exception as e: + error_type = type(e).__name__ + stack_trace_short = traceback.format_exc().strip().split('\n')[-2] + stack_trace = traceback.format_exc().strip() + print(f"error: {error_type}, {e}") + print(f"stack trace: {stack_trace_short}") + @server.PromptServer.instance.routes.post("/comfyui-deploy/run") async def comfy_deploy_run(request): prompt_server = server.PromptServer.instance @@ -148,7 +212,6 @@ async def comfy_deploy_run(request): return web.json_response(res, status=status) -sockets = dict() def get_comfyui_path_from_file_path(file_path): file_path_parts = file_path.split("\\") @@ -179,60 +242,6 @@ async def compute_sha256_checksum(filepath): sha256.update(chunk) return sha256.hexdigest() -# def hash_chunk(start_end, filepath): -# """Hash a specific chunk of the file.""" -# start, end = start_end -# sha256 = hashlib.sha256() -# with open(filepath, 'rb') as f: -# f.seek(start) -# chunk = f.read(end - start) -# sha256.update(chunk) -# return sha256.digest() # Return the digest of the chunk - -# async def compute_sha256_checksum(filepath): -# file_size = os.path.getsize(filepath) -# parts = 1 # Or any other division based on file size or desired concurrency -# part_size = file_size // parts -# start_end_ranges = [(i * part_size, min((i + 1) * part_size, file_size)) for i in range(parts)] - -# print(start_end_ranges, file_size) - -# loop = asyncio.get_running_loop() - -# # Use ThreadPoolExecutor to process chunks in parallel -# with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor: -# futures = [loop.run_in_executor(executor, hash_chunk, start_end, filepath) for start_end in start_end_ranges] -# chunk_hashes = await asyncio.gather(*futures) - -# # Combine the hashes sequentially -# final_sha256 = hashlib.sha256() -# for chunk_hash in chunk_hashes: -# final_sha256.update(chunk_hash) - -# return final_sha256.hexdigest() - -# def hash_chunk(filepath): -# chunk_size = 1024 * 256 # 256KB per chunk -# sha256 = hashlib.sha256() -# with open(filepath, 'rb') as f: -# while True: -# chunk = f.read(chunk_size) -# if not chunk: -# break # End of file -# sha256.update(chunk) -# return sha256.hexdigest() - -# async def compute_sha256_checksum(filepath): -# print("computing sha256 checksum") -# filepath = get_comfyui_path_from_file_path(filepath) - -# loop = asyncio.get_running_loop() - -# with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: -# task = loop.run_in_executor(executor, hash_chunk, filepath) - -# return await task - # This is start uploading the files to Comfy Deploy @server.PromptServer.instance.routes.post('/comfyui-deploy/upload-file') async def upload_file(request): @@ -366,6 +375,48 @@ async def websocket_handler(request): await send_first_time_log(sid) async for msg in ws: + if msg.type == aiohttp.WSMsgType.TEXT: + try: + data = json.loads(msg.data) + print(data) + event_type = data.get('event') + if event_type == 'workflow_endpoint': + _data = data.get('data') + get_workflow_endpoint_url = _data.get('get_workflow_endpoint_url') + + auth_token = _data.get('auth_token') + + async with aiohttp.ClientSession() as session: + headers = {'Authorization': f'Bearer {auth_token}'} + async with session.get(get_workflow_endpoint_url, headers=headers) as response: + if response.status == 200: + workflow = await response.json() + + print(workflow["version"]) + + streaming_prompt_metadata[sid] = StreamingPrompt( + workflow_api=workflow["workflow_api"], + auth_token=auth_token, + inputs={} + ) + + # await send("workflow_api", workflow_api, sid) + else: + error_message = await response.text() + print(f"Failed to fetch workflow endpoint. Status: {response.status}, Error: {error_message}") + # await send("error", {"message": error_message}, sid) + pass + elif event_type == 'input': + print("Got input: ", data.get("inputs")) + input = data.get('inputs') + streaming_prompt_metadata[sid].inputs.update(input) + send_prompt(sid, streaming_prompt_metadata[sid]) + else: + # Handle other event types + pass + except json.JSONDecodeError: + print('Failed to decode JSON from message') + if msg.type == aiohttp.WSMsgType.ERROR: print('ws connection closed with exception %s' % ws.exception()) finally: @@ -387,6 +438,7 @@ async def comfy_deploy_check_status(request): async def send(event, data, sid=None): try: + # message = {"event": event, "data": data} if sid: ws = sockets.get(sid) if ws != None and not ws.closed: # Check if the WebSocket connection is open and not closing @@ -462,7 +514,6 @@ async def send_json_override(self, event, data, sid=None): # await update_run_with_output(prompt_id, data.get('output'), node_id=data.get('node')) # update_run_with_output(prompt_id, data.get('output')) - class Status(Enum): NOT_STARTED = "not-started" RUNNING = "running" diff --git a/globals.py b/globals.py new file mode 100644 index 0000000..6bde3bc --- /dev/null +++ b/globals.py @@ -0,0 +1,63 @@ +import struct + +import aiohttp + +from PIL import Image, ImageOps +from io import BytesIO + +sockets = dict() + +class BinaryEventTypes: + PREVIEW_IMAGE = 1 + UNENCODED_PREVIEW_IMAGE = 2 + +async def send_image(image_data, sid=None): + image_type = image_data[0] + image = image_data[1] + max_size = image_data[2] + if max_size is not None: + if hasattr(Image, 'Resampling'): + resampling = Image.Resampling.BILINEAR + else: + resampling = Image.ANTIALIAS + + image = ImageOps.contain(image, (max_size, max_size), resampling) + type_num = 1 + if image_type == "JPEG": + type_num = 1 + elif image_type == "PNG": + type_num = 2 + + bytesIO = BytesIO() + header = struct.pack(">I", type_num) + bytesIO.write(header) + image.save(bytesIO, format=image_type, quality=95, compress_level=1) + preview_bytes = bytesIO.getvalue() + await send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid) + +async def send_socket_catch_exception(function, message): + try: + await function(message) + except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError) as err: + print("send error:", err) + +def encode_bytes(event, data): + if not isinstance(event, int): + raise RuntimeError(f"Binary event types must be integers, got {event}") + + packed = struct.pack(">I", event) + message = bytearray(packed) + message.extend(data) + return message + +async def send_bytes(event, data, sid=None): + message = encode_bytes(event, data) + + print("sending image to ", event, sid) + + if sid is None: + _sockets = list(sockets.values()) + for ws in _sockets: + await send_socket_catch_exception(ws.send_bytes, message) + elif sid in sockets: + await send_socket_catch_exception(sockets[sid].send_bytes, message) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index f72db6f..d3deafb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,2 @@ -aiofiles \ No newline at end of file +aiofiles +pydantic \ No newline at end of file