feat: add ws streaming input
This commit is contained in:
parent
619a9728c0
commit
3df549c25c
66
comfy-nodes/input_websocket_image.py
Normal file
66
comfy-nodes/input_websocket_image.py
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
import folder_paths
|
||||||
|
from PIL import Image, ImageOps
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from server import PromptServer, BinaryEventTypes
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from globals import streaming_prompt_metadata, max_output_id_length
|
||||||
|
|
||||||
|
class ComfyDeployWebscoketImageInput:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"input_id": (
|
||||||
|
"STRING",
|
||||||
|
{"multiline": False, "default": "input_id"},
|
||||||
|
),
|
||||||
|
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"default_value": ("IMAGE", ),
|
||||||
|
"client_id": (
|
||||||
|
"STRING",
|
||||||
|
{"multiline": False, "default": ""},
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
OUTPUT_NODE = True
|
||||||
|
|
||||||
|
RETURN_TYPES = ("IMAGE", )
|
||||||
|
RETURN_NAMES = ("images",)
|
||||||
|
|
||||||
|
FUNCTION = "run"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def VALIDATE_INPUTS(s, input_id):
|
||||||
|
try:
|
||||||
|
if len(input_id.encode('ascii')) > max_output_id_length:
|
||||||
|
raise ValueError(f"input_id size is greater than {max_output_id_length} bytes")
|
||||||
|
except UnicodeEncodeError:
|
||||||
|
raise ValueError("input_id is not ASCII encodable")
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def run(self, input_id, default_value, seed, client_id):
|
||||||
|
# print(streaming_prompt_metadata[client_id].inputs)
|
||||||
|
if client_id in streaming_prompt_metadata and input_id in streaming_prompt_metadata[client_id].inputs:
|
||||||
|
if isinstance(streaming_prompt_metadata[client_id].inputs[input_id], Image.Image):
|
||||||
|
print("Returning image from websocket input")
|
||||||
|
|
||||||
|
image = streaming_prompt_metadata[client_id].inputs[input_id]
|
||||||
|
|
||||||
|
image = ImageOps.exif_transpose(image)
|
||||||
|
image = image.convert("RGB")
|
||||||
|
image = np.array(image).astype(np.float32) / 255.0
|
||||||
|
image = torch.from_numpy(image)[None,]
|
||||||
|
|
||||||
|
return [image]
|
||||||
|
|
||||||
|
print("Returning default value")
|
||||||
|
return [default_value]
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {"ComfyDeployWebscoketImageInput": ComfyDeployWebscoketImageInput}
|
||||||
|
NODE_DISPLAY_NAME_MAPPINGS = {"ComfyDeployWebscoketImageInput": "Image Websocket Input (ComfyDeploy)"}
|
@ -1,3 +1,4 @@
|
|||||||
|
from io import BytesIO
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
import os
|
import os
|
||||||
import requests
|
import requests
|
||||||
@ -21,14 +22,9 @@ import aiofiles
|
|||||||
from typing import List, Union, Any, Optional
|
from typing import List, Union, Any, Optional
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import copy
|
import copy
|
||||||
|
import struct
|
||||||
|
|
||||||
from globals import sockets
|
from globals import StreamingPrompt, sockets, streaming_prompt_metadata, BaseModel
|
||||||
|
|
||||||
from pydantic import BaseModel as PydanticBaseModel
|
|
||||||
|
|
||||||
class BaseModel(PydanticBaseModel):
|
|
||||||
class Config:
|
|
||||||
arbitrary_types_allowed = True
|
|
||||||
|
|
||||||
class Status(Enum):
|
class Status(Enum):
|
||||||
NOT_STARTED = "not-started"
|
NOT_STARTED = "not-started"
|
||||||
@ -37,14 +33,6 @@ class Status(Enum):
|
|||||||
FAILED = "failed"
|
FAILED = "failed"
|
||||||
UPLOADING = "uploading"
|
UPLOADING = "uploading"
|
||||||
|
|
||||||
class StreamingPrompt(BaseModel):
|
|
||||||
workflow_api: Any
|
|
||||||
auth_token: str
|
|
||||||
inputs: dict[str, Union[str, bytes, Image.Image]]
|
|
||||||
running_prompt_ids: set[str] = set()
|
|
||||||
status_endpoint: str
|
|
||||||
file_upload_endpoint: str
|
|
||||||
|
|
||||||
class SimplePrompt(BaseModel):
|
class SimplePrompt(BaseModel):
|
||||||
status_endpoint: str
|
status_endpoint: str
|
||||||
file_upload_endpoint: str
|
file_upload_endpoint: str
|
||||||
@ -57,14 +45,23 @@ class SimplePrompt(BaseModel):
|
|||||||
is_realtime: bool = False,
|
is_realtime: bool = False,
|
||||||
start_time: Optional[float] = None,
|
start_time: Optional[float] = None,
|
||||||
|
|
||||||
streaming_prompt_metadata: dict[str, StreamingPrompt] = {}
|
|
||||||
|
|
||||||
api = None
|
api = None
|
||||||
api_task = None
|
api_task = None
|
||||||
prompt_metadata: dict[str, SimplePrompt] = {}
|
prompt_metadata: dict[str, SimplePrompt] = {}
|
||||||
cd_enable_log = os.environ.get('CD_ENABLE_LOG', 'false').lower() == 'true'
|
cd_enable_log = os.environ.get('CD_ENABLE_LOG', 'false').lower() == 'true'
|
||||||
cd_enable_run_log = os.environ.get('CD_ENABLE_RUN_LOG', 'false').lower() == 'true'
|
cd_enable_run_log = os.environ.get('CD_ENABLE_RUN_LOG', 'false').lower() == 'true'
|
||||||
|
|
||||||
|
async def clear_current_prompt(sid):
|
||||||
|
prompt_server = server.PromptServer.instance
|
||||||
|
to_delete = list(streaming_prompt_metadata[sid].running_prompt_ids) # Convert set to list
|
||||||
|
|
||||||
|
for id_to_delete in to_delete:
|
||||||
|
print("clearning out prompt: ", id_to_delete)
|
||||||
|
delete_func = lambda a: a[1] == id_to_delete
|
||||||
|
prompt_server.prompt_queue.delete_queue_item(delete_func)
|
||||||
|
|
||||||
|
streaming_prompt_metadata[sid].running_prompt_ids.clear()
|
||||||
|
|
||||||
def post_prompt(json_data):
|
def post_prompt(json_data):
|
||||||
prompt_server = server.PromptServer.instance
|
prompt_server = server.PromptServer.instance
|
||||||
json_data = prompt_server.trigger_on_prompt(json_data)
|
json_data = prompt_server.trigger_on_prompt(json_data)
|
||||||
@ -140,16 +137,24 @@ def send_prompt(sid: str, inputs: StreamingPrompt):
|
|||||||
# Loop through each of the inputs and replace them
|
# Loop through each of the inputs and replace them
|
||||||
for key, value in workflow_api.items():
|
for key, value in workflow_api.items():
|
||||||
if 'inputs' in value:
|
if 'inputs' in value:
|
||||||
|
if (value["class_type"] == "ComfyDeployWebscoketImageOutput"):
|
||||||
|
value['inputs']["client_id"] = sid
|
||||||
|
if (value["class_type"] == "ComfyDeployWebscoketImageInput"):
|
||||||
|
value['inputs']["client_id"] = sid
|
||||||
|
|
||||||
if "input_id" in value['inputs'] and value['inputs']['input_id'] in inputs.inputs:
|
if "input_id" in value['inputs'] and value['inputs']['input_id'] in inputs.inputs:
|
||||||
new_value = inputs.inputs[value['inputs']['input_id']]
|
new_value = inputs.inputs[value['inputs']['input_id']]
|
||||||
|
|
||||||
|
# Lets skip it if its an image
|
||||||
|
if isinstance(new_value, Image.Image):
|
||||||
|
continue
|
||||||
|
|
||||||
value['inputs']["input_id"] = new_value
|
value['inputs']["input_id"] = new_value
|
||||||
|
|
||||||
# Fix for external text default value
|
# Fix for external text default value
|
||||||
if (value["class_type"] == "ComfyUIDeployExternalText"):
|
if (value["class_type"] == "ComfyUIDeployExternalText"):
|
||||||
value['inputs']["default_value"] = new_value
|
value['inputs']["default_value"] = new_value
|
||||||
|
|
||||||
if (value["class_type"] == "ComfyDeployWebscoketImageOutput"):
|
|
||||||
value['inputs']["client_id"] = sid
|
|
||||||
|
|
||||||
print(workflow_api)
|
print(workflow_api)
|
||||||
|
|
||||||
@ -458,6 +463,7 @@ async def websocket_handler(request):
|
|||||||
print("Got input: ", data.get("inputs"))
|
print("Got input: ", data.get("inputs"))
|
||||||
input = data.get('inputs')
|
input = data.get('inputs')
|
||||||
streaming_prompt_metadata[sid].inputs.update(input)
|
streaming_prompt_metadata[sid].inputs.update(input)
|
||||||
|
clear_current_prompt(sid)
|
||||||
send_prompt(sid, streaming_prompt_metadata[sid])
|
send_prompt(sid, streaming_prompt_metadata[sid])
|
||||||
else:
|
else:
|
||||||
# Handle other event types
|
# Handle other event types
|
||||||
@ -465,6 +471,41 @@ async def websocket_handler(request):
|
|||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
print('Failed to decode JSON from message')
|
print('Failed to decode JSON from message')
|
||||||
|
|
||||||
|
if msg.type == aiohttp.WSMsgType.BINARY:
|
||||||
|
data = msg.data
|
||||||
|
event_type, = struct.unpack("<I", data[:4])
|
||||||
|
if event_type == 0: # Image input
|
||||||
|
image_type_code, = struct.unpack("<I", data[4:8])
|
||||||
|
input_id_bytes = data[8:32] # Extract the next 24 bytes for the input ID
|
||||||
|
input_id = input_id_bytes.decode('ascii').strip() # Decode the input ID from ASCII
|
||||||
|
print(event_type)
|
||||||
|
print(image_type_code)
|
||||||
|
print(input_id)
|
||||||
|
image_data = data[32:] # The rest is the image data
|
||||||
|
if image_type_code == 1:
|
||||||
|
image_type = "JPEG"
|
||||||
|
elif image_type_code == 2:
|
||||||
|
image_type = "PNG"
|
||||||
|
elif image_type_code == 3:
|
||||||
|
image_type = "WEBP"
|
||||||
|
else:
|
||||||
|
print("Unknown image type code:", image_type_code)
|
||||||
|
return
|
||||||
|
image = Image.open(BytesIO(image_data))
|
||||||
|
# Check if the input ID already exists and replace the input with the new one
|
||||||
|
if input_id in streaming_prompt_metadata[sid].inputs:
|
||||||
|
# If the input exists, we assume it's an image and attempt to close it to free resources
|
||||||
|
try:
|
||||||
|
existing_image = streaming_prompt_metadata[sid].inputs[input_id]
|
||||||
|
if hasattr(existing_image, 'close'):
|
||||||
|
existing_image.close()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error closing previous image for input ID {input_id}: {e}")
|
||||||
|
streaming_prompt_metadata[sid].inputs[input_id] = image
|
||||||
|
clear_current_prompt(sid)
|
||||||
|
send_prompt(sid, streaming_prompt_metadata[sid])
|
||||||
|
print(f"Received {image_type} image of size {image.size} with input ID {input_id}")
|
||||||
|
|
||||||
if msg.type == aiohttp.WSMsgType.ERROR:
|
if msg.type == aiohttp.WSMsgType.ERROR:
|
||||||
print('ws connection closed with exception %s' % ws.exception())
|
print('ws connection closed with exception %s' % ws.exception())
|
||||||
finally:
|
finally:
|
||||||
|
16
globals.py
16
globals.py
@ -2,10 +2,26 @@ import struct
|
|||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
|
||||||
|
from typing import List, Union, Any, Optional
|
||||||
from PIL import Image, ImageOps
|
from PIL import Image, ImageOps
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
|
||||||
|
from pydantic import BaseModel as PydanticBaseModel
|
||||||
|
|
||||||
|
class BaseModel(PydanticBaseModel):
|
||||||
|
class Config:
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
class StreamingPrompt(BaseModel):
|
||||||
|
workflow_api: Any
|
||||||
|
auth_token: str
|
||||||
|
inputs: dict[str, Union[str, bytes, Image.Image]]
|
||||||
|
running_prompt_ids: set[str] = set()
|
||||||
|
status_endpoint: str
|
||||||
|
file_upload_endpoint: str
|
||||||
|
|
||||||
sockets = dict()
|
sockets = dict()
|
||||||
|
streaming_prompt_metadata: dict[str, StreamingPrompt] = {}
|
||||||
|
|
||||||
class BinaryEventTypes:
|
class BinaryEventTypes:
|
||||||
PREVIEW_IMAGE = 1
|
PREVIEW_IMAGE = 1
|
||||||
|
Loading…
x
Reference in New Issue
Block a user