feat: add ws streaming input

This commit is contained in:
bennykok 2024-03-02 00:47:06 -08:00
parent 619a9728c0
commit 3df549c25c
3 changed files with 144 additions and 21 deletions

View File

@ -0,0 +1,66 @@
import folder_paths
from PIL import Image, ImageOps
import numpy as np
import torch
from server import PromptServer, BinaryEventTypes
import asyncio
from globals import streaming_prompt_metadata, max_output_id_length
class ComfyDeployWebscoketImageInput:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"input_id": (
"STRING",
{"multiline": False, "default": "input_id"},
),
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
},
"optional": {
"default_value": ("IMAGE", ),
"client_id": (
"STRING",
{"multiline": False, "default": ""},
),
}
}
OUTPUT_NODE = True
RETURN_TYPES = ("IMAGE", )
RETURN_NAMES = ("images",)
FUNCTION = "run"
@classmethod
def VALIDATE_INPUTS(s, input_id):
try:
if len(input_id.encode('ascii')) > max_output_id_length:
raise ValueError(f"input_id size is greater than {max_output_id_length} bytes")
except UnicodeEncodeError:
raise ValueError("input_id is not ASCII encodable")
return True
def run(self, input_id, default_value, seed, client_id):
# print(streaming_prompt_metadata[client_id].inputs)
if client_id in streaming_prompt_metadata and input_id in streaming_prompt_metadata[client_id].inputs:
if isinstance(streaming_prompt_metadata[client_id].inputs[input_id], Image.Image):
print("Returning image from websocket input")
image = streaming_prompt_metadata[client_id].inputs[input_id]
image = ImageOps.exif_transpose(image)
image = image.convert("RGB")
image = np.array(image).astype(np.float32) / 255.0
image = torch.from_numpy(image)[None,]
return [image]
print("Returning default value")
return [default_value]
NODE_CLASS_MAPPINGS = {"ComfyDeployWebscoketImageInput": ComfyDeployWebscoketImageInput}
NODE_DISPLAY_NAME_MAPPINGS = {"ComfyDeployWebscoketImageInput": "Image Websocket Input (ComfyDeploy)"}

View File

@ -1,3 +1,4 @@
from io import BytesIO
from aiohttp import web
import os
import requests
@ -21,29 +22,16 @@ import aiofiles
from typing import List, Union, Any, Optional
from PIL import Image
import copy
import struct
from globals import sockets
from globals import StreamingPrompt, sockets, streaming_prompt_metadata, BaseModel
from pydantic import BaseModel as PydanticBaseModel
class BaseModel(PydanticBaseModel):
class Config:
arbitrary_types_allowed = True
class Status(Enum):
NOT_STARTED = "not-started"
RUNNING = "running"
SUCCESS = "success"
FAILED = "failed"
UPLOADING = "uploading"
class StreamingPrompt(BaseModel):
workflow_api: Any
auth_token: str
inputs: dict[str, Union[str, bytes, Image.Image]]
running_prompt_ids: set[str] = set()
status_endpoint: str
file_upload_endpoint: str
class SimplePrompt(BaseModel):
status_endpoint: str
@ -56,8 +44,6 @@ class SimplePrompt(BaseModel):
done: bool = False
is_realtime: bool = False,
start_time: Optional[float] = None,
streaming_prompt_metadata: dict[str, StreamingPrompt] = {}
api = None
api_task = None
@ -65,6 +51,17 @@ prompt_metadata: dict[str, SimplePrompt] = {}
cd_enable_log = os.environ.get('CD_ENABLE_LOG', 'false').lower() == 'true'
cd_enable_run_log = os.environ.get('CD_ENABLE_RUN_LOG', 'false').lower() == 'true'
async def clear_current_prompt(sid):
prompt_server = server.PromptServer.instance
to_delete = list(streaming_prompt_metadata[sid].running_prompt_ids) # Convert set to list
for id_to_delete in to_delete:
print("clearning out prompt: ", id_to_delete)
delete_func = lambda a: a[1] == id_to_delete
prompt_server.prompt_queue.delete_queue_item(delete_func)
streaming_prompt_metadata[sid].running_prompt_ids.clear()
def post_prompt(json_data):
prompt_server = server.PromptServer.instance
json_data = prompt_server.trigger_on_prompt(json_data)
@ -140,16 +137,24 @@ def send_prompt(sid: str, inputs: StreamingPrompt):
# Loop through each of the inputs and replace them
for key, value in workflow_api.items():
if 'inputs' in value:
if (value["class_type"] == "ComfyDeployWebscoketImageOutput"):
value['inputs']["client_id"] = sid
if (value["class_type"] == "ComfyDeployWebscoketImageInput"):
value['inputs']["client_id"] = sid
if "input_id" in value['inputs'] and value['inputs']['input_id'] in inputs.inputs:
new_value = inputs.inputs[value['inputs']['input_id']]
# Lets skip it if its an image
if isinstance(new_value, Image.Image):
continue
value['inputs']["input_id"] = new_value
# Fix for external text default value
if (value["class_type"] == "ComfyUIDeployExternalText"):
value['inputs']["default_value"] = new_value
if (value["class_type"] == "ComfyDeployWebscoketImageOutput"):
value['inputs']["client_id"] = sid
print(workflow_api)
@ -458,13 +463,49 @@ async def websocket_handler(request):
print("Got input: ", data.get("inputs"))
input = data.get('inputs')
streaming_prompt_metadata[sid].inputs.update(input)
clear_current_prompt(sid)
send_prompt(sid, streaming_prompt_metadata[sid])
else:
# Handle other event types
pass
except json.JSONDecodeError:
print('Failed to decode JSON from message')
if msg.type == aiohttp.WSMsgType.BINARY:
data = msg.data
event_type, = struct.unpack("<I", data[:4])
if event_type == 0: # Image input
image_type_code, = struct.unpack("<I", data[4:8])
input_id_bytes = data[8:32] # Extract the next 24 bytes for the input ID
input_id = input_id_bytes.decode('ascii').strip() # Decode the input ID from ASCII
print(event_type)
print(image_type_code)
print(input_id)
image_data = data[32:] # The rest is the image data
if image_type_code == 1:
image_type = "JPEG"
elif image_type_code == 2:
image_type = "PNG"
elif image_type_code == 3:
image_type = "WEBP"
else:
print("Unknown image type code:", image_type_code)
return
image = Image.open(BytesIO(image_data))
# Check if the input ID already exists and replace the input with the new one
if input_id in streaming_prompt_metadata[sid].inputs:
# If the input exists, we assume it's an image and attempt to close it to free resources
try:
existing_image = streaming_prompt_metadata[sid].inputs[input_id]
if hasattr(existing_image, 'close'):
existing_image.close()
except Exception as e:
print(f"Error closing previous image for input ID {input_id}: {e}")
streaming_prompt_metadata[sid].inputs[input_id] = image
clear_current_prompt(sid)
send_prompt(sid, streaming_prompt_metadata[sid])
print(f"Received {image_type} image of size {image.size} with input ID {input_id}")
if msg.type == aiohttp.WSMsgType.ERROR:
print('ws connection closed with exception %s' % ws.exception())
finally:

View File

@ -2,10 +2,26 @@ import struct
import aiohttp
from typing import List, Union, Any, Optional
from PIL import Image, ImageOps
from io import BytesIO
from pydantic import BaseModel as PydanticBaseModel
class BaseModel(PydanticBaseModel):
class Config:
arbitrary_types_allowed = True
class StreamingPrompt(BaseModel):
workflow_api: Any
auth_token: str
inputs: dict[str, Union[str, bytes, Image.Image]]
running_prompt_ids: set[str] = set()
status_endpoint: str
file_upload_endpoint: str
sockets = dict()
streaming_prompt_metadata: dict[str, StreamingPrompt] = {}
class BinaryEventTypes:
PREVIEW_IMAGE = 1