feat: add ws streaming input

This commit is contained in:
bennykok 2024-03-02 00:47:06 -08:00
parent 619a9728c0
commit 3df549c25c
3 changed files with 144 additions and 21 deletions

View File

@ -0,0 +1,66 @@
import folder_paths
from PIL import Image, ImageOps
import numpy as np
import torch
from server import PromptServer, BinaryEventTypes
import asyncio
from globals import streaming_prompt_metadata, max_output_id_length
class ComfyDeployWebscoketImageInput:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"input_id": (
"STRING",
{"multiline": False, "default": "input_id"},
),
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
},
"optional": {
"default_value": ("IMAGE", ),
"client_id": (
"STRING",
{"multiline": False, "default": ""},
),
}
}
OUTPUT_NODE = True
RETURN_TYPES = ("IMAGE", )
RETURN_NAMES = ("images",)
FUNCTION = "run"
@classmethod
def VALIDATE_INPUTS(s, input_id):
try:
if len(input_id.encode('ascii')) > max_output_id_length:
raise ValueError(f"input_id size is greater than {max_output_id_length} bytes")
except UnicodeEncodeError:
raise ValueError("input_id is not ASCII encodable")
return True
def run(self, input_id, default_value, seed, client_id):
# print(streaming_prompt_metadata[client_id].inputs)
if client_id in streaming_prompt_metadata and input_id in streaming_prompt_metadata[client_id].inputs:
if isinstance(streaming_prompt_metadata[client_id].inputs[input_id], Image.Image):
print("Returning image from websocket input")
image = streaming_prompt_metadata[client_id].inputs[input_id]
image = ImageOps.exif_transpose(image)
image = image.convert("RGB")
image = np.array(image).astype(np.float32) / 255.0
image = torch.from_numpy(image)[None,]
return [image]
print("Returning default value")
return [default_value]
NODE_CLASS_MAPPINGS = {"ComfyDeployWebscoketImageInput": ComfyDeployWebscoketImageInput}
NODE_DISPLAY_NAME_MAPPINGS = {"ComfyDeployWebscoketImageInput": "Image Websocket Input (ComfyDeploy)"}

View File

@ -1,3 +1,4 @@
from io import BytesIO
from aiohttp import web from aiohttp import web
import os import os
import requests import requests
@ -21,14 +22,9 @@ import aiofiles
from typing import List, Union, Any, Optional from typing import List, Union, Any, Optional
from PIL import Image from PIL import Image
import copy import copy
import struct
from globals import sockets from globals import StreamingPrompt, sockets, streaming_prompt_metadata, BaseModel
from pydantic import BaseModel as PydanticBaseModel
class BaseModel(PydanticBaseModel):
class Config:
arbitrary_types_allowed = True
class Status(Enum): class Status(Enum):
NOT_STARTED = "not-started" NOT_STARTED = "not-started"
@ -37,14 +33,6 @@ class Status(Enum):
FAILED = "failed" FAILED = "failed"
UPLOADING = "uploading" UPLOADING = "uploading"
class StreamingPrompt(BaseModel):
workflow_api: Any
auth_token: str
inputs: dict[str, Union[str, bytes, Image.Image]]
running_prompt_ids: set[str] = set()
status_endpoint: str
file_upload_endpoint: str
class SimplePrompt(BaseModel): class SimplePrompt(BaseModel):
status_endpoint: str status_endpoint: str
file_upload_endpoint: str file_upload_endpoint: str
@ -57,14 +45,23 @@ class SimplePrompt(BaseModel):
is_realtime: bool = False, is_realtime: bool = False,
start_time: Optional[float] = None, start_time: Optional[float] = None,
streaming_prompt_metadata: dict[str, StreamingPrompt] = {}
api = None api = None
api_task = None api_task = None
prompt_metadata: dict[str, SimplePrompt] = {} prompt_metadata: dict[str, SimplePrompt] = {}
cd_enable_log = os.environ.get('CD_ENABLE_LOG', 'false').lower() == 'true' cd_enable_log = os.environ.get('CD_ENABLE_LOG', 'false').lower() == 'true'
cd_enable_run_log = os.environ.get('CD_ENABLE_RUN_LOG', 'false').lower() == 'true' cd_enable_run_log = os.environ.get('CD_ENABLE_RUN_LOG', 'false').lower() == 'true'
async def clear_current_prompt(sid):
prompt_server = server.PromptServer.instance
to_delete = list(streaming_prompt_metadata[sid].running_prompt_ids) # Convert set to list
for id_to_delete in to_delete:
print("clearning out prompt: ", id_to_delete)
delete_func = lambda a: a[1] == id_to_delete
prompt_server.prompt_queue.delete_queue_item(delete_func)
streaming_prompt_metadata[sid].running_prompt_ids.clear()
def post_prompt(json_data): def post_prompt(json_data):
prompt_server = server.PromptServer.instance prompt_server = server.PromptServer.instance
json_data = prompt_server.trigger_on_prompt(json_data) json_data = prompt_server.trigger_on_prompt(json_data)
@ -140,16 +137,24 @@ def send_prompt(sid: str, inputs: StreamingPrompt):
# Loop through each of the inputs and replace them # Loop through each of the inputs and replace them
for key, value in workflow_api.items(): for key, value in workflow_api.items():
if 'inputs' in value: if 'inputs' in value:
if (value["class_type"] == "ComfyDeployWebscoketImageOutput"):
value['inputs']["client_id"] = sid
if (value["class_type"] == "ComfyDeployWebscoketImageInput"):
value['inputs']["client_id"] = sid
if "input_id" in value['inputs'] and value['inputs']['input_id'] in inputs.inputs: if "input_id" in value['inputs'] and value['inputs']['input_id'] in inputs.inputs:
new_value = inputs.inputs[value['inputs']['input_id']] new_value = inputs.inputs[value['inputs']['input_id']]
# Lets skip it if its an image
if isinstance(new_value, Image.Image):
continue
value['inputs']["input_id"] = new_value value['inputs']["input_id"] = new_value
# Fix for external text default value # Fix for external text default value
if (value["class_type"] == "ComfyUIDeployExternalText"): if (value["class_type"] == "ComfyUIDeployExternalText"):
value['inputs']["default_value"] = new_value value['inputs']["default_value"] = new_value
if (value["class_type"] == "ComfyDeployWebscoketImageOutput"):
value['inputs']["client_id"] = sid
print(workflow_api) print(workflow_api)
@ -458,6 +463,7 @@ async def websocket_handler(request):
print("Got input: ", data.get("inputs")) print("Got input: ", data.get("inputs"))
input = data.get('inputs') input = data.get('inputs')
streaming_prompt_metadata[sid].inputs.update(input) streaming_prompt_metadata[sid].inputs.update(input)
clear_current_prompt(sid)
send_prompt(sid, streaming_prompt_metadata[sid]) send_prompt(sid, streaming_prompt_metadata[sid])
else: else:
# Handle other event types # Handle other event types
@ -465,6 +471,41 @@ async def websocket_handler(request):
except json.JSONDecodeError: except json.JSONDecodeError:
print('Failed to decode JSON from message') print('Failed to decode JSON from message')
if msg.type == aiohttp.WSMsgType.BINARY:
data = msg.data
event_type, = struct.unpack("<I", data[:4])
if event_type == 0: # Image input
image_type_code, = struct.unpack("<I", data[4:8])
input_id_bytes = data[8:32] # Extract the next 24 bytes for the input ID
input_id = input_id_bytes.decode('ascii').strip() # Decode the input ID from ASCII
print(event_type)
print(image_type_code)
print(input_id)
image_data = data[32:] # The rest is the image data
if image_type_code == 1:
image_type = "JPEG"
elif image_type_code == 2:
image_type = "PNG"
elif image_type_code == 3:
image_type = "WEBP"
else:
print("Unknown image type code:", image_type_code)
return
image = Image.open(BytesIO(image_data))
# Check if the input ID already exists and replace the input with the new one
if input_id in streaming_prompt_metadata[sid].inputs:
# If the input exists, we assume it's an image and attempt to close it to free resources
try:
existing_image = streaming_prompt_metadata[sid].inputs[input_id]
if hasattr(existing_image, 'close'):
existing_image.close()
except Exception as e:
print(f"Error closing previous image for input ID {input_id}: {e}")
streaming_prompt_metadata[sid].inputs[input_id] = image
clear_current_prompt(sid)
send_prompt(sid, streaming_prompt_metadata[sid])
print(f"Received {image_type} image of size {image.size} with input ID {input_id}")
if msg.type == aiohttp.WSMsgType.ERROR: if msg.type == aiohttp.WSMsgType.ERROR:
print('ws connection closed with exception %s' % ws.exception()) print('ws connection closed with exception %s' % ws.exception())
finally: finally:

View File

@ -2,10 +2,26 @@ import struct
import aiohttp import aiohttp
from typing import List, Union, Any, Optional
from PIL import Image, ImageOps from PIL import Image, ImageOps
from io import BytesIO from io import BytesIO
from pydantic import BaseModel as PydanticBaseModel
class BaseModel(PydanticBaseModel):
class Config:
arbitrary_types_allowed = True
class StreamingPrompt(BaseModel):
workflow_api: Any
auth_token: str
inputs: dict[str, Union[str, bytes, Image.Image]]
running_prompt_ids: set[str] = set()
status_endpoint: str
file_upload_endpoint: str
sockets = dict() sockets = dict()
streaming_prompt_metadata: dict[str, StreamingPrompt] = {}
class BinaryEventTypes: class BinaryEventTypes:
PREVIEW_IMAGE = 1 PREVIEW_IMAGE = 1