feat(plugin): add output ws image node

This commit is contained in:
bennykok 2024-02-23 13:38:00 -08:00
parent 4ce2c98ae9
commit ddbf6848a7
4 changed files with 236 additions and 57 deletions

View File

@ -0,0 +1,64 @@
import folder_paths
from PIL import Image, ImageOps
import numpy as np
import torch
from server import PromptServer, BinaryEventTypes
import asyncio
from globals import send_image
class ComfyDeployWebscoketImageOutput:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"output_id": (
"STRING",
{"multiline": False, "default": "output_id"},
),
"images": ("IMAGE", ),
},
"optional": {
"client_id": (
"STRING",
{"multiline": False, "default": ""},
),
}
# "hidden": {"client_id": "CLIENT_ID"},
}
OUTPUT_NODE = True
RETURN_TYPES = ()
RETURN_NAMES = ("text",)
FUNCTION = "run"
CATEGORY = "output"
def run(self, output_id, images, client_id):
results = []
prompt_server = PromptServer.instance
loop = prompt_server.loop
def schedule_coroutine_blocking(target, *args):
future = asyncio.run_coroutine_threadsafe(target(*args), loop)
return future.result() # This makes the call blocking
for tensor in images:
array = 255.0 * tensor.cpu().numpy()
image = Image.fromarray(np.clip(array, 0, 255).astype(np.uint8))
schedule_coroutine_blocking(send_image, ["PNG", image, None], client_id)
print("Image sent")
# loop.run_until_complete(send_image(["PNG", image, None], client_id))
results.append(
{"source": "websocket", "content-type": "image/png", "type": "output"}
)
return {"ui": {"images": results}}
NODE_CLASS_MAPPINGS = {"ComfyDeployWebscoketImageOutput": ComfyDeployWebscoketImageOutput}
NODE_DISPLAY_NAME_MAPPINGS = {"ComfyDeployWebscoketImageOutput": "Image Websocket Output (ComfyDeploy)"}

View File

@ -1,3 +1,5 @@
from aiohttp import web
import os
import requests
@ -26,6 +28,24 @@ import hashlib
import aiohttp
import aiofiles
import concurrent.futures
from typing import List, Union, Any
from PIL import Image
import copy
from globals import sockets
from pydantic import BaseModel as PydanticBaseModel
class BaseModel(PydanticBaseModel):
class Config:
arbitrary_types_allowed = True
class StreamingPrompt(BaseModel):
workflow_api: Any
auth_token: str
inputs: dict[str, Union[str, bytes, Image.Image]]
streaming_prompt_metadata: dict[str, StreamingPrompt] = {}
api = None
api_task = None
@ -80,6 +100,50 @@ def randomSeed(num_digits=15):
range_end = (10**num_digits) - 1
return random.randint(range_start, range_end)
def send_prompt(sid: str, inputs: StreamingPrompt):
# workflow_api = inputs.workflow_api
workflow_api = copy.deepcopy(inputs.workflow_api)
# Random seed
for key in workflow_api:
if 'inputs' in workflow_api[key] and 'seed' in workflow_api[key]['inputs']:
workflow_api[key]['inputs']['seed'] = randomSeed()
print("getting inputs" ,inputs.inputs)
# Loop through each of the inputs and replace them
for key, value in workflow_api.items():
if 'inputs' in value:
if "input_id" in value['inputs'] and value['inputs']['input_id'] in inputs.inputs:
new_value = inputs.inputs[value['inputs']['input_id']]
value['inputs']["input_id"] = new_value;
# Fix for external text default value
if (value["class_type"] == "ComfyUIDeployExternalText"):
value['inputs']["default_value"] = new_value;
if (value["class_type"] == "ComfyDeployWebscoketImageOutput"):
value['inputs']["client_id"] = sid;
print(workflow_api)
prompt_id = str(uuid.uuid4())
prompt = {
"prompt": workflow_api,
"client_id": "comfy_deploy_instance", #api.client_id
"prompt_id": prompt_id
}
try:
res = post_prompt(prompt)
except Exception as e:
error_type = type(e).__name__
stack_trace_short = traceback.format_exc().strip().split('\n')[-2]
stack_trace = traceback.format_exc().strip()
print(f"error: {error_type}, {e}")
print(f"stack trace: {stack_trace_short}")
@server.PromptServer.instance.routes.post("/comfyui-deploy/run")
async def comfy_deploy_run(request):
prompt_server = server.PromptServer.instance
@ -148,7 +212,6 @@ async def comfy_deploy_run(request):
return web.json_response(res, status=status)
sockets = dict()
def get_comfyui_path_from_file_path(file_path):
file_path_parts = file_path.split("\\")
@ -179,60 +242,6 @@ async def compute_sha256_checksum(filepath):
sha256.update(chunk)
return sha256.hexdigest()
# def hash_chunk(start_end, filepath):
# """Hash a specific chunk of the file."""
# start, end = start_end
# sha256 = hashlib.sha256()
# with open(filepath, 'rb') as f:
# f.seek(start)
# chunk = f.read(end - start)
# sha256.update(chunk)
# return sha256.digest() # Return the digest of the chunk
# async def compute_sha256_checksum(filepath):
# file_size = os.path.getsize(filepath)
# parts = 1 # Or any other division based on file size or desired concurrency
# part_size = file_size // parts
# start_end_ranges = [(i * part_size, min((i + 1) * part_size, file_size)) for i in range(parts)]
# print(start_end_ranges, file_size)
# loop = asyncio.get_running_loop()
# # Use ThreadPoolExecutor to process chunks in parallel
# with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
# futures = [loop.run_in_executor(executor, hash_chunk, start_end, filepath) for start_end in start_end_ranges]
# chunk_hashes = await asyncio.gather(*futures)
# # Combine the hashes sequentially
# final_sha256 = hashlib.sha256()
# for chunk_hash in chunk_hashes:
# final_sha256.update(chunk_hash)
# return final_sha256.hexdigest()
# def hash_chunk(filepath):
# chunk_size = 1024 * 256 # 256KB per chunk
# sha256 = hashlib.sha256()
# with open(filepath, 'rb') as f:
# while True:
# chunk = f.read(chunk_size)
# if not chunk:
# break # End of file
# sha256.update(chunk)
# return sha256.hexdigest()
# async def compute_sha256_checksum(filepath):
# print("computing sha256 checksum")
# filepath = get_comfyui_path_from_file_path(filepath)
# loop = asyncio.get_running_loop()
# with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
# task = loop.run_in_executor(executor, hash_chunk, filepath)
# return await task
# This is start uploading the files to Comfy Deploy
@server.PromptServer.instance.routes.post('/comfyui-deploy/upload-file')
async def upload_file(request):
@ -366,6 +375,48 @@ async def websocket_handler(request):
await send_first_time_log(sid)
async for msg in ws:
if msg.type == aiohttp.WSMsgType.TEXT:
try:
data = json.loads(msg.data)
print(data)
event_type = data.get('event')
if event_type == 'workflow_endpoint':
_data = data.get('data')
get_workflow_endpoint_url = _data.get('get_workflow_endpoint_url')
auth_token = _data.get('auth_token')
async with aiohttp.ClientSession() as session:
headers = {'Authorization': f'Bearer {auth_token}'}
async with session.get(get_workflow_endpoint_url, headers=headers) as response:
if response.status == 200:
workflow = await response.json()
print(workflow["version"])
streaming_prompt_metadata[sid] = StreamingPrompt(
workflow_api=workflow["workflow_api"],
auth_token=auth_token,
inputs={}
)
# await send("workflow_api", workflow_api, sid)
else:
error_message = await response.text()
print(f"Failed to fetch workflow endpoint. Status: {response.status}, Error: {error_message}")
# await send("error", {"message": error_message}, sid)
pass
elif event_type == 'input':
print("Got input: ", data.get("inputs"))
input = data.get('inputs')
streaming_prompt_metadata[sid].inputs.update(input)
send_prompt(sid, streaming_prompt_metadata[sid])
else:
# Handle other event types
pass
except json.JSONDecodeError:
print('Failed to decode JSON from message')
if msg.type == aiohttp.WSMsgType.ERROR:
print('ws connection closed with exception %s' % ws.exception())
finally:
@ -387,6 +438,7 @@ async def comfy_deploy_check_status(request):
async def send(event, data, sid=None):
try:
# message = {"event": event, "data": data}
if sid:
ws = sockets.get(sid)
if ws != None and not ws.closed: # Check if the WebSocket connection is open and not closing
@ -462,7 +514,6 @@ async def send_json_override(self, event, data, sid=None):
# await update_run_with_output(prompt_id, data.get('output'), node_id=data.get('node'))
# update_run_with_output(prompt_id, data.get('output'))
class Status(Enum):
NOT_STARTED = "not-started"
RUNNING = "running"

63
globals.py Normal file
View File

@ -0,0 +1,63 @@
import struct
import aiohttp
from PIL import Image, ImageOps
from io import BytesIO
sockets = dict()
class BinaryEventTypes:
PREVIEW_IMAGE = 1
UNENCODED_PREVIEW_IMAGE = 2
async def send_image(image_data, sid=None):
image_type = image_data[0]
image = image_data[1]
max_size = image_data[2]
if max_size is not None:
if hasattr(Image, 'Resampling'):
resampling = Image.Resampling.BILINEAR
else:
resampling = Image.ANTIALIAS
image = ImageOps.contain(image, (max_size, max_size), resampling)
type_num = 1
if image_type == "JPEG":
type_num = 1
elif image_type == "PNG":
type_num = 2
bytesIO = BytesIO()
header = struct.pack(">I", type_num)
bytesIO.write(header)
image.save(bytesIO, format=image_type, quality=95, compress_level=1)
preview_bytes = bytesIO.getvalue()
await send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid)
async def send_socket_catch_exception(function, message):
try:
await function(message)
except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError) as err:
print("send error:", err)
def encode_bytes(event, data):
if not isinstance(event, int):
raise RuntimeError(f"Binary event types must be integers, got {event}")
packed = struct.pack(">I", event)
message = bytearray(packed)
message.extend(data)
return message
async def send_bytes(event, data, sid=None):
message = encode_bytes(event, data)
print("sending image to ", event, sid)
if sid is None:
_sockets = list(sockets.values())
for ws in _sockets:
await send_socket_catch_exception(ws.send_bytes, message)
elif sid in sockets:
await send_socket_catch_exception(sockets[sid].send_bytes, message)

View File

@ -1 +1,2 @@
aiofiles
aiofiles
pydantic