feat(plugin): add output ws image node
This commit is contained in:
parent
4ce2c98ae9
commit
ddbf6848a7
64
comfy-nodes/output_websocket_image.py
Normal file
64
comfy-nodes/output_websocket_image.py
Normal file
@ -0,0 +1,64 @@
|
||||
import folder_paths
|
||||
from PIL import Image, ImageOps
|
||||
import numpy as np
|
||||
import torch
|
||||
from server import PromptServer, BinaryEventTypes
|
||||
import asyncio
|
||||
|
||||
from globals import send_image
|
||||
|
||||
class ComfyDeployWebscoketImageOutput:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"output_id": (
|
||||
"STRING",
|
||||
{"multiline": False, "default": "output_id"},
|
||||
),
|
||||
"images": ("IMAGE", ),
|
||||
},
|
||||
"optional": {
|
||||
"client_id": (
|
||||
"STRING",
|
||||
{"multiline": False, "default": ""},
|
||||
),
|
||||
}
|
||||
# "hidden": {"client_id": "CLIENT_ID"},
|
||||
}
|
||||
|
||||
OUTPUT_NODE = True
|
||||
|
||||
RETURN_TYPES = ()
|
||||
RETURN_NAMES = ("text",)
|
||||
|
||||
FUNCTION = "run"
|
||||
|
||||
CATEGORY = "output"
|
||||
|
||||
def run(self, output_id, images, client_id):
|
||||
results = []
|
||||
prompt_server = PromptServer.instance
|
||||
loop = prompt_server.loop
|
||||
|
||||
def schedule_coroutine_blocking(target, *args):
|
||||
future = asyncio.run_coroutine_threadsafe(target(*args), loop)
|
||||
return future.result() # This makes the call blocking
|
||||
|
||||
for tensor in images:
|
||||
array = 255.0 * tensor.cpu().numpy()
|
||||
image = Image.fromarray(np.clip(array, 0, 255).astype(np.uint8))
|
||||
|
||||
schedule_coroutine_blocking(send_image, ["PNG", image, None], client_id)
|
||||
print("Image sent")
|
||||
# loop.run_until_complete(send_image(["PNG", image, None], client_id))
|
||||
results.append(
|
||||
{"source": "websocket", "content-type": "image/png", "type": "output"}
|
||||
)
|
||||
|
||||
return {"ui": {"images": results}}
|
||||
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {"ComfyDeployWebscoketImageOutput": ComfyDeployWebscoketImageOutput}
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {"ComfyDeployWebscoketImageOutput": "Image Websocket Output (ComfyDeploy)"}
|
163
custom_routes.py
163
custom_routes.py
@ -1,3 +1,5 @@
|
||||
|
||||
|
||||
from aiohttp import web
|
||||
import os
|
||||
import requests
|
||||
@ -26,6 +28,24 @@ import hashlib
|
||||
import aiohttp
|
||||
import aiofiles
|
||||
import concurrent.futures
|
||||
from typing import List, Union, Any
|
||||
from PIL import Image
|
||||
import copy
|
||||
|
||||
from globals import sockets
|
||||
|
||||
from pydantic import BaseModel as PydanticBaseModel
|
||||
|
||||
class BaseModel(PydanticBaseModel):
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
class StreamingPrompt(BaseModel):
|
||||
workflow_api: Any
|
||||
auth_token: str
|
||||
inputs: dict[str, Union[str, bytes, Image.Image]]
|
||||
|
||||
streaming_prompt_metadata: dict[str, StreamingPrompt] = {}
|
||||
|
||||
api = None
|
||||
api_task = None
|
||||
@ -80,6 +100,50 @@ def randomSeed(num_digits=15):
|
||||
range_end = (10**num_digits) - 1
|
||||
return random.randint(range_start, range_end)
|
||||
|
||||
def send_prompt(sid: str, inputs: StreamingPrompt):
|
||||
# workflow_api = inputs.workflow_api
|
||||
workflow_api = copy.deepcopy(inputs.workflow_api)
|
||||
|
||||
# Random seed
|
||||
for key in workflow_api:
|
||||
if 'inputs' in workflow_api[key] and 'seed' in workflow_api[key]['inputs']:
|
||||
workflow_api[key]['inputs']['seed'] = randomSeed()
|
||||
|
||||
print("getting inputs" ,inputs.inputs)
|
||||
|
||||
# Loop through each of the inputs and replace them
|
||||
for key, value in workflow_api.items():
|
||||
if 'inputs' in value:
|
||||
if "input_id" in value['inputs'] and value['inputs']['input_id'] in inputs.inputs:
|
||||
new_value = inputs.inputs[value['inputs']['input_id']]
|
||||
value['inputs']["input_id"] = new_value;
|
||||
|
||||
# Fix for external text default value
|
||||
if (value["class_type"] == "ComfyUIDeployExternalText"):
|
||||
value['inputs']["default_value"] = new_value;
|
||||
|
||||
if (value["class_type"] == "ComfyDeployWebscoketImageOutput"):
|
||||
value['inputs']["client_id"] = sid;
|
||||
|
||||
print(workflow_api)
|
||||
|
||||
prompt_id = str(uuid.uuid4())
|
||||
|
||||
prompt = {
|
||||
"prompt": workflow_api,
|
||||
"client_id": "comfy_deploy_instance", #api.client_id
|
||||
"prompt_id": prompt_id
|
||||
}
|
||||
|
||||
try:
|
||||
res = post_prompt(prompt)
|
||||
except Exception as e:
|
||||
error_type = type(e).__name__
|
||||
stack_trace_short = traceback.format_exc().strip().split('\n')[-2]
|
||||
stack_trace = traceback.format_exc().strip()
|
||||
print(f"error: {error_type}, {e}")
|
||||
print(f"stack trace: {stack_trace_short}")
|
||||
|
||||
@server.PromptServer.instance.routes.post("/comfyui-deploy/run")
|
||||
async def comfy_deploy_run(request):
|
||||
prompt_server = server.PromptServer.instance
|
||||
@ -148,7 +212,6 @@ async def comfy_deploy_run(request):
|
||||
|
||||
return web.json_response(res, status=status)
|
||||
|
||||
sockets = dict()
|
||||
|
||||
def get_comfyui_path_from_file_path(file_path):
|
||||
file_path_parts = file_path.split("\\")
|
||||
@ -179,60 +242,6 @@ async def compute_sha256_checksum(filepath):
|
||||
sha256.update(chunk)
|
||||
return sha256.hexdigest()
|
||||
|
||||
# def hash_chunk(start_end, filepath):
|
||||
# """Hash a specific chunk of the file."""
|
||||
# start, end = start_end
|
||||
# sha256 = hashlib.sha256()
|
||||
# with open(filepath, 'rb') as f:
|
||||
# f.seek(start)
|
||||
# chunk = f.read(end - start)
|
||||
# sha256.update(chunk)
|
||||
# return sha256.digest() # Return the digest of the chunk
|
||||
|
||||
# async def compute_sha256_checksum(filepath):
|
||||
# file_size = os.path.getsize(filepath)
|
||||
# parts = 1 # Or any other division based on file size or desired concurrency
|
||||
# part_size = file_size // parts
|
||||
# start_end_ranges = [(i * part_size, min((i + 1) * part_size, file_size)) for i in range(parts)]
|
||||
|
||||
# print(start_end_ranges, file_size)
|
||||
|
||||
# loop = asyncio.get_running_loop()
|
||||
|
||||
# # Use ThreadPoolExecutor to process chunks in parallel
|
||||
# with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
|
||||
# futures = [loop.run_in_executor(executor, hash_chunk, start_end, filepath) for start_end in start_end_ranges]
|
||||
# chunk_hashes = await asyncio.gather(*futures)
|
||||
|
||||
# # Combine the hashes sequentially
|
||||
# final_sha256 = hashlib.sha256()
|
||||
# for chunk_hash in chunk_hashes:
|
||||
# final_sha256.update(chunk_hash)
|
||||
|
||||
# return final_sha256.hexdigest()
|
||||
|
||||
# def hash_chunk(filepath):
|
||||
# chunk_size = 1024 * 256 # 256KB per chunk
|
||||
# sha256 = hashlib.sha256()
|
||||
# with open(filepath, 'rb') as f:
|
||||
# while True:
|
||||
# chunk = f.read(chunk_size)
|
||||
# if not chunk:
|
||||
# break # End of file
|
||||
# sha256.update(chunk)
|
||||
# return sha256.hexdigest()
|
||||
|
||||
# async def compute_sha256_checksum(filepath):
|
||||
# print("computing sha256 checksum")
|
||||
# filepath = get_comfyui_path_from_file_path(filepath)
|
||||
|
||||
# loop = asyncio.get_running_loop()
|
||||
|
||||
# with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
|
||||
# task = loop.run_in_executor(executor, hash_chunk, filepath)
|
||||
|
||||
# return await task
|
||||
|
||||
# This is start uploading the files to Comfy Deploy
|
||||
@server.PromptServer.instance.routes.post('/comfyui-deploy/upload-file')
|
||||
async def upload_file(request):
|
||||
@ -366,6 +375,48 @@ async def websocket_handler(request):
|
||||
await send_first_time_log(sid)
|
||||
|
||||
async for msg in ws:
|
||||
if msg.type == aiohttp.WSMsgType.TEXT:
|
||||
try:
|
||||
data = json.loads(msg.data)
|
||||
print(data)
|
||||
event_type = data.get('event')
|
||||
if event_type == 'workflow_endpoint':
|
||||
_data = data.get('data')
|
||||
get_workflow_endpoint_url = _data.get('get_workflow_endpoint_url')
|
||||
|
||||
auth_token = _data.get('auth_token')
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
headers = {'Authorization': f'Bearer {auth_token}'}
|
||||
async with session.get(get_workflow_endpoint_url, headers=headers) as response:
|
||||
if response.status == 200:
|
||||
workflow = await response.json()
|
||||
|
||||
print(workflow["version"])
|
||||
|
||||
streaming_prompt_metadata[sid] = StreamingPrompt(
|
||||
workflow_api=workflow["workflow_api"],
|
||||
auth_token=auth_token,
|
||||
inputs={}
|
||||
)
|
||||
|
||||
# await send("workflow_api", workflow_api, sid)
|
||||
else:
|
||||
error_message = await response.text()
|
||||
print(f"Failed to fetch workflow endpoint. Status: {response.status}, Error: {error_message}")
|
||||
# await send("error", {"message": error_message}, sid)
|
||||
pass
|
||||
elif event_type == 'input':
|
||||
print("Got input: ", data.get("inputs"))
|
||||
input = data.get('inputs')
|
||||
streaming_prompt_metadata[sid].inputs.update(input)
|
||||
send_prompt(sid, streaming_prompt_metadata[sid])
|
||||
else:
|
||||
# Handle other event types
|
||||
pass
|
||||
except json.JSONDecodeError:
|
||||
print('Failed to decode JSON from message')
|
||||
|
||||
if msg.type == aiohttp.WSMsgType.ERROR:
|
||||
print('ws connection closed with exception %s' % ws.exception())
|
||||
finally:
|
||||
@ -387,6 +438,7 @@ async def comfy_deploy_check_status(request):
|
||||
|
||||
async def send(event, data, sid=None):
|
||||
try:
|
||||
# message = {"event": event, "data": data}
|
||||
if sid:
|
||||
ws = sockets.get(sid)
|
||||
if ws != None and not ws.closed: # Check if the WebSocket connection is open and not closing
|
||||
@ -462,7 +514,6 @@ async def send_json_override(self, event, data, sid=None):
|
||||
# await update_run_with_output(prompt_id, data.get('output'), node_id=data.get('node'))
|
||||
# update_run_with_output(prompt_id, data.get('output'))
|
||||
|
||||
|
||||
class Status(Enum):
|
||||
NOT_STARTED = "not-started"
|
||||
RUNNING = "running"
|
||||
|
63
globals.py
Normal file
63
globals.py
Normal file
@ -0,0 +1,63 @@
|
||||
import struct
|
||||
|
||||
import aiohttp
|
||||
|
||||
from PIL import Image, ImageOps
|
||||
from io import BytesIO
|
||||
|
||||
sockets = dict()
|
||||
|
||||
class BinaryEventTypes:
|
||||
PREVIEW_IMAGE = 1
|
||||
UNENCODED_PREVIEW_IMAGE = 2
|
||||
|
||||
async def send_image(image_data, sid=None):
|
||||
image_type = image_data[0]
|
||||
image = image_data[1]
|
||||
max_size = image_data[2]
|
||||
if max_size is not None:
|
||||
if hasattr(Image, 'Resampling'):
|
||||
resampling = Image.Resampling.BILINEAR
|
||||
else:
|
||||
resampling = Image.ANTIALIAS
|
||||
|
||||
image = ImageOps.contain(image, (max_size, max_size), resampling)
|
||||
type_num = 1
|
||||
if image_type == "JPEG":
|
||||
type_num = 1
|
||||
elif image_type == "PNG":
|
||||
type_num = 2
|
||||
|
||||
bytesIO = BytesIO()
|
||||
header = struct.pack(">I", type_num)
|
||||
bytesIO.write(header)
|
||||
image.save(bytesIO, format=image_type, quality=95, compress_level=1)
|
||||
preview_bytes = bytesIO.getvalue()
|
||||
await send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid)
|
||||
|
||||
async def send_socket_catch_exception(function, message):
|
||||
try:
|
||||
await function(message)
|
||||
except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError) as err:
|
||||
print("send error:", err)
|
||||
|
||||
def encode_bytes(event, data):
|
||||
if not isinstance(event, int):
|
||||
raise RuntimeError(f"Binary event types must be integers, got {event}")
|
||||
|
||||
packed = struct.pack(">I", event)
|
||||
message = bytearray(packed)
|
||||
message.extend(data)
|
||||
return message
|
||||
|
||||
async def send_bytes(event, data, sid=None):
|
||||
message = encode_bytes(event, data)
|
||||
|
||||
print("sending image to ", event, sid)
|
||||
|
||||
if sid is None:
|
||||
_sockets = list(sockets.values())
|
||||
for ws in _sockets:
|
||||
await send_socket_catch_exception(ws.send_bytes, message)
|
||||
elif sid in sockets:
|
||||
await send_socket_catch_exception(sockets[sid].send_bytes, message)
|
@ -1 +1,2 @@
|
||||
aiofiles
|
||||
aiofiles
|
||||
pydantic
|
Loading…
x
Reference in New Issue
Block a user