diff --git a/comfy-nodes/input_websocket_image.py b/comfy-nodes/input_websocket_image.py new file mode 100644 index 0000000..2656e0a --- /dev/null +++ b/comfy-nodes/input_websocket_image.py @@ -0,0 +1,66 @@ +import folder_paths +from PIL import Image, ImageOps +import numpy as np +import torch +from server import PromptServer, BinaryEventTypes +import asyncio + +from globals import streaming_prompt_metadata, max_output_id_length + +class ComfyDeployWebscoketImageInput: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "input_id": ( + "STRING", + {"multiline": False, "default": "input_id"}, + ), + "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), + }, + "optional": { + "default_value": ("IMAGE", ), + "client_id": ( + "STRING", + {"multiline": False, "default": ""}, + ), + } + } + + OUTPUT_NODE = True + + RETURN_TYPES = ("IMAGE", ) + RETURN_NAMES = ("images",) + + FUNCTION = "run" + + @classmethod + def VALIDATE_INPUTS(s, input_id): + try: + if len(input_id.encode('ascii')) > max_output_id_length: + raise ValueError(f"input_id size is greater than {max_output_id_length} bytes") + except UnicodeEncodeError: + raise ValueError("input_id is not ASCII encodable") + + return True + + def run(self, input_id, default_value, seed, client_id): + # print(streaming_prompt_metadata[client_id].inputs) + if client_id in streaming_prompt_metadata and input_id in streaming_prompt_metadata[client_id].inputs: + if isinstance(streaming_prompt_metadata[client_id].inputs[input_id], Image.Image): + print("Returning image from websocket input") + + image = streaming_prompt_metadata[client_id].inputs[input_id] + + image = ImageOps.exif_transpose(image) + image = image.convert("RGB") + image = np.array(image).astype(np.float32) / 255.0 + image = torch.from_numpy(image)[None,] + + return [image] + + print("Returning default value") + return [default_value] + +NODE_CLASS_MAPPINGS = {"ComfyDeployWebscoketImageInput": ComfyDeployWebscoketImageInput} +NODE_DISPLAY_NAME_MAPPINGS = {"ComfyDeployWebscoketImageInput": "Image Websocket Input (ComfyDeploy)"} \ No newline at end of file diff --git a/custom_routes.py b/custom_routes.py index b705f33..b443ff3 100644 --- a/custom_routes.py +++ b/custom_routes.py @@ -1,3 +1,4 @@ +from io import BytesIO from aiohttp import web import os import requests @@ -21,29 +22,16 @@ import aiofiles from typing import List, Union, Any, Optional from PIL import Image import copy +import struct -from globals import sockets +from globals import StreamingPrompt, sockets, streaming_prompt_metadata, BaseModel -from pydantic import BaseModel as PydanticBaseModel - -class BaseModel(PydanticBaseModel): - class Config: - arbitrary_types_allowed = True - class Status(Enum): NOT_STARTED = "not-started" RUNNING = "running" SUCCESS = "success" FAILED = "failed" UPLOADING = "uploading" - -class StreamingPrompt(BaseModel): - workflow_api: Any - auth_token: str - inputs: dict[str, Union[str, bytes, Image.Image]] - running_prompt_ids: set[str] = set() - status_endpoint: str - file_upload_endpoint: str class SimplePrompt(BaseModel): status_endpoint: str @@ -56,8 +44,6 @@ class SimplePrompt(BaseModel): done: bool = False is_realtime: bool = False, start_time: Optional[float] = None, - -streaming_prompt_metadata: dict[str, StreamingPrompt] = {} api = None api_task = None @@ -65,6 +51,17 @@ prompt_metadata: dict[str, SimplePrompt] = {} cd_enable_log = os.environ.get('CD_ENABLE_LOG', 'false').lower() == 'true' cd_enable_run_log = os.environ.get('CD_ENABLE_RUN_LOG', 'false').lower() == 'true' +async def clear_current_prompt(sid): + prompt_server = server.PromptServer.instance + to_delete = list(streaming_prompt_metadata[sid].running_prompt_ids) # Convert set to list + + for id_to_delete in to_delete: + print("clearning out prompt: ", id_to_delete) + delete_func = lambda a: a[1] == id_to_delete + prompt_server.prompt_queue.delete_queue_item(delete_func) + + streaming_prompt_metadata[sid].running_prompt_ids.clear() + def post_prompt(json_data): prompt_server = server.PromptServer.instance json_data = prompt_server.trigger_on_prompt(json_data) @@ -140,16 +137,24 @@ def send_prompt(sid: str, inputs: StreamingPrompt): # Loop through each of the inputs and replace them for key, value in workflow_api.items(): if 'inputs' in value: + if (value["class_type"] == "ComfyDeployWebscoketImageOutput"): + value['inputs']["client_id"] = sid + if (value["class_type"] == "ComfyDeployWebscoketImageInput"): + value['inputs']["client_id"] = sid + if "input_id" in value['inputs'] and value['inputs']['input_id'] in inputs.inputs: new_value = inputs.inputs[value['inputs']['input_id']] + + # Lets skip it if its an image + if isinstance(new_value, Image.Image): + continue + value['inputs']["input_id"] = new_value # Fix for external text default value if (value["class_type"] == "ComfyUIDeployExternalText"): value['inputs']["default_value"] = new_value - - if (value["class_type"] == "ComfyDeployWebscoketImageOutput"): - value['inputs']["client_id"] = sid + print(workflow_api) @@ -458,13 +463,49 @@ async def websocket_handler(request): print("Got input: ", data.get("inputs")) input = data.get('inputs') streaming_prompt_metadata[sid].inputs.update(input) + clear_current_prompt(sid) send_prompt(sid, streaming_prompt_metadata[sid]) else: # Handle other event types pass except json.JSONDecodeError: print('Failed to decode JSON from message') - + + if msg.type == aiohttp.WSMsgType.BINARY: + data = msg.data + event_type, = struct.unpack("