feat(plugin): add output ws image node
This commit is contained in:
parent
4ce2c98ae9
commit
ddbf6848a7
64
comfy-nodes/output_websocket_image.py
Normal file
64
comfy-nodes/output_websocket_image.py
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
import folder_paths
|
||||||
|
from PIL import Image, ImageOps
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from server import PromptServer, BinaryEventTypes
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from globals import send_image
|
||||||
|
|
||||||
|
class ComfyDeployWebscoketImageOutput:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"output_id": (
|
||||||
|
"STRING",
|
||||||
|
{"multiline": False, "default": "output_id"},
|
||||||
|
),
|
||||||
|
"images": ("IMAGE", ),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"client_id": (
|
||||||
|
"STRING",
|
||||||
|
{"multiline": False, "default": ""},
|
||||||
|
),
|
||||||
|
}
|
||||||
|
# "hidden": {"client_id": "CLIENT_ID"},
|
||||||
|
}
|
||||||
|
|
||||||
|
OUTPUT_NODE = True
|
||||||
|
|
||||||
|
RETURN_TYPES = ()
|
||||||
|
RETURN_NAMES = ("text",)
|
||||||
|
|
||||||
|
FUNCTION = "run"
|
||||||
|
|
||||||
|
CATEGORY = "output"
|
||||||
|
|
||||||
|
def run(self, output_id, images, client_id):
|
||||||
|
results = []
|
||||||
|
prompt_server = PromptServer.instance
|
||||||
|
loop = prompt_server.loop
|
||||||
|
|
||||||
|
def schedule_coroutine_blocking(target, *args):
|
||||||
|
future = asyncio.run_coroutine_threadsafe(target(*args), loop)
|
||||||
|
return future.result() # This makes the call blocking
|
||||||
|
|
||||||
|
for tensor in images:
|
||||||
|
array = 255.0 * tensor.cpu().numpy()
|
||||||
|
image = Image.fromarray(np.clip(array, 0, 255).astype(np.uint8))
|
||||||
|
|
||||||
|
schedule_coroutine_blocking(send_image, ["PNG", image, None], client_id)
|
||||||
|
print("Image sent")
|
||||||
|
# loop.run_until_complete(send_image(["PNG", image, None], client_id))
|
||||||
|
results.append(
|
||||||
|
{"source": "websocket", "content-type": "image/png", "type": "output"}
|
||||||
|
)
|
||||||
|
|
||||||
|
return {"ui": {"images": results}}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {"ComfyDeployWebscoketImageOutput": ComfyDeployWebscoketImageOutput}
|
||||||
|
NODE_DISPLAY_NAME_MAPPINGS = {"ComfyDeployWebscoketImageOutput": "Image Websocket Output (ComfyDeploy)"}
|
163
custom_routes.py
163
custom_routes.py
@ -1,3 +1,5 @@
|
|||||||
|
|
||||||
|
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
import os
|
import os
|
||||||
import requests
|
import requests
|
||||||
@ -26,6 +28,24 @@ import hashlib
|
|||||||
import aiohttp
|
import aiohttp
|
||||||
import aiofiles
|
import aiofiles
|
||||||
import concurrent.futures
|
import concurrent.futures
|
||||||
|
from typing import List, Union, Any
|
||||||
|
from PIL import Image
|
||||||
|
import copy
|
||||||
|
|
||||||
|
from globals import sockets
|
||||||
|
|
||||||
|
from pydantic import BaseModel as PydanticBaseModel
|
||||||
|
|
||||||
|
class BaseModel(PydanticBaseModel):
|
||||||
|
class Config:
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
class StreamingPrompt(BaseModel):
|
||||||
|
workflow_api: Any
|
||||||
|
auth_token: str
|
||||||
|
inputs: dict[str, Union[str, bytes, Image.Image]]
|
||||||
|
|
||||||
|
streaming_prompt_metadata: dict[str, StreamingPrompt] = {}
|
||||||
|
|
||||||
api = None
|
api = None
|
||||||
api_task = None
|
api_task = None
|
||||||
@ -80,6 +100,50 @@ def randomSeed(num_digits=15):
|
|||||||
range_end = (10**num_digits) - 1
|
range_end = (10**num_digits) - 1
|
||||||
return random.randint(range_start, range_end)
|
return random.randint(range_start, range_end)
|
||||||
|
|
||||||
|
def send_prompt(sid: str, inputs: StreamingPrompt):
|
||||||
|
# workflow_api = inputs.workflow_api
|
||||||
|
workflow_api = copy.deepcopy(inputs.workflow_api)
|
||||||
|
|
||||||
|
# Random seed
|
||||||
|
for key in workflow_api:
|
||||||
|
if 'inputs' in workflow_api[key] and 'seed' in workflow_api[key]['inputs']:
|
||||||
|
workflow_api[key]['inputs']['seed'] = randomSeed()
|
||||||
|
|
||||||
|
print("getting inputs" ,inputs.inputs)
|
||||||
|
|
||||||
|
# Loop through each of the inputs and replace them
|
||||||
|
for key, value in workflow_api.items():
|
||||||
|
if 'inputs' in value:
|
||||||
|
if "input_id" in value['inputs'] and value['inputs']['input_id'] in inputs.inputs:
|
||||||
|
new_value = inputs.inputs[value['inputs']['input_id']]
|
||||||
|
value['inputs']["input_id"] = new_value;
|
||||||
|
|
||||||
|
# Fix for external text default value
|
||||||
|
if (value["class_type"] == "ComfyUIDeployExternalText"):
|
||||||
|
value['inputs']["default_value"] = new_value;
|
||||||
|
|
||||||
|
if (value["class_type"] == "ComfyDeployWebscoketImageOutput"):
|
||||||
|
value['inputs']["client_id"] = sid;
|
||||||
|
|
||||||
|
print(workflow_api)
|
||||||
|
|
||||||
|
prompt_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
prompt = {
|
||||||
|
"prompt": workflow_api,
|
||||||
|
"client_id": "comfy_deploy_instance", #api.client_id
|
||||||
|
"prompt_id": prompt_id
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
res = post_prompt(prompt)
|
||||||
|
except Exception as e:
|
||||||
|
error_type = type(e).__name__
|
||||||
|
stack_trace_short = traceback.format_exc().strip().split('\n')[-2]
|
||||||
|
stack_trace = traceback.format_exc().strip()
|
||||||
|
print(f"error: {error_type}, {e}")
|
||||||
|
print(f"stack trace: {stack_trace_short}")
|
||||||
|
|
||||||
@server.PromptServer.instance.routes.post("/comfyui-deploy/run")
|
@server.PromptServer.instance.routes.post("/comfyui-deploy/run")
|
||||||
async def comfy_deploy_run(request):
|
async def comfy_deploy_run(request):
|
||||||
prompt_server = server.PromptServer.instance
|
prompt_server = server.PromptServer.instance
|
||||||
@ -148,7 +212,6 @@ async def comfy_deploy_run(request):
|
|||||||
|
|
||||||
return web.json_response(res, status=status)
|
return web.json_response(res, status=status)
|
||||||
|
|
||||||
sockets = dict()
|
|
||||||
|
|
||||||
def get_comfyui_path_from_file_path(file_path):
|
def get_comfyui_path_from_file_path(file_path):
|
||||||
file_path_parts = file_path.split("\\")
|
file_path_parts = file_path.split("\\")
|
||||||
@ -179,60 +242,6 @@ async def compute_sha256_checksum(filepath):
|
|||||||
sha256.update(chunk)
|
sha256.update(chunk)
|
||||||
return sha256.hexdigest()
|
return sha256.hexdigest()
|
||||||
|
|
||||||
# def hash_chunk(start_end, filepath):
|
|
||||||
# """Hash a specific chunk of the file."""
|
|
||||||
# start, end = start_end
|
|
||||||
# sha256 = hashlib.sha256()
|
|
||||||
# with open(filepath, 'rb') as f:
|
|
||||||
# f.seek(start)
|
|
||||||
# chunk = f.read(end - start)
|
|
||||||
# sha256.update(chunk)
|
|
||||||
# return sha256.digest() # Return the digest of the chunk
|
|
||||||
|
|
||||||
# async def compute_sha256_checksum(filepath):
|
|
||||||
# file_size = os.path.getsize(filepath)
|
|
||||||
# parts = 1 # Or any other division based on file size or desired concurrency
|
|
||||||
# part_size = file_size // parts
|
|
||||||
# start_end_ranges = [(i * part_size, min((i + 1) * part_size, file_size)) for i in range(parts)]
|
|
||||||
|
|
||||||
# print(start_end_ranges, file_size)
|
|
||||||
|
|
||||||
# loop = asyncio.get_running_loop()
|
|
||||||
|
|
||||||
# # Use ThreadPoolExecutor to process chunks in parallel
|
|
||||||
# with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
|
|
||||||
# futures = [loop.run_in_executor(executor, hash_chunk, start_end, filepath) for start_end in start_end_ranges]
|
|
||||||
# chunk_hashes = await asyncio.gather(*futures)
|
|
||||||
|
|
||||||
# # Combine the hashes sequentially
|
|
||||||
# final_sha256 = hashlib.sha256()
|
|
||||||
# for chunk_hash in chunk_hashes:
|
|
||||||
# final_sha256.update(chunk_hash)
|
|
||||||
|
|
||||||
# return final_sha256.hexdigest()
|
|
||||||
|
|
||||||
# def hash_chunk(filepath):
|
|
||||||
# chunk_size = 1024 * 256 # 256KB per chunk
|
|
||||||
# sha256 = hashlib.sha256()
|
|
||||||
# with open(filepath, 'rb') as f:
|
|
||||||
# while True:
|
|
||||||
# chunk = f.read(chunk_size)
|
|
||||||
# if not chunk:
|
|
||||||
# break # End of file
|
|
||||||
# sha256.update(chunk)
|
|
||||||
# return sha256.hexdigest()
|
|
||||||
|
|
||||||
# async def compute_sha256_checksum(filepath):
|
|
||||||
# print("computing sha256 checksum")
|
|
||||||
# filepath = get_comfyui_path_from_file_path(filepath)
|
|
||||||
|
|
||||||
# loop = asyncio.get_running_loop()
|
|
||||||
|
|
||||||
# with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
|
|
||||||
# task = loop.run_in_executor(executor, hash_chunk, filepath)
|
|
||||||
|
|
||||||
# return await task
|
|
||||||
|
|
||||||
# This is start uploading the files to Comfy Deploy
|
# This is start uploading the files to Comfy Deploy
|
||||||
@server.PromptServer.instance.routes.post('/comfyui-deploy/upload-file')
|
@server.PromptServer.instance.routes.post('/comfyui-deploy/upload-file')
|
||||||
async def upload_file(request):
|
async def upload_file(request):
|
||||||
@ -366,6 +375,48 @@ async def websocket_handler(request):
|
|||||||
await send_first_time_log(sid)
|
await send_first_time_log(sid)
|
||||||
|
|
||||||
async for msg in ws:
|
async for msg in ws:
|
||||||
|
if msg.type == aiohttp.WSMsgType.TEXT:
|
||||||
|
try:
|
||||||
|
data = json.loads(msg.data)
|
||||||
|
print(data)
|
||||||
|
event_type = data.get('event')
|
||||||
|
if event_type == 'workflow_endpoint':
|
||||||
|
_data = data.get('data')
|
||||||
|
get_workflow_endpoint_url = _data.get('get_workflow_endpoint_url')
|
||||||
|
|
||||||
|
auth_token = _data.get('auth_token')
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
headers = {'Authorization': f'Bearer {auth_token}'}
|
||||||
|
async with session.get(get_workflow_endpoint_url, headers=headers) as response:
|
||||||
|
if response.status == 200:
|
||||||
|
workflow = await response.json()
|
||||||
|
|
||||||
|
print(workflow["version"])
|
||||||
|
|
||||||
|
streaming_prompt_metadata[sid] = StreamingPrompt(
|
||||||
|
workflow_api=workflow["workflow_api"],
|
||||||
|
auth_token=auth_token,
|
||||||
|
inputs={}
|
||||||
|
)
|
||||||
|
|
||||||
|
# await send("workflow_api", workflow_api, sid)
|
||||||
|
else:
|
||||||
|
error_message = await response.text()
|
||||||
|
print(f"Failed to fetch workflow endpoint. Status: {response.status}, Error: {error_message}")
|
||||||
|
# await send("error", {"message": error_message}, sid)
|
||||||
|
pass
|
||||||
|
elif event_type == 'input':
|
||||||
|
print("Got input: ", data.get("inputs"))
|
||||||
|
input = data.get('inputs')
|
||||||
|
streaming_prompt_metadata[sid].inputs.update(input)
|
||||||
|
send_prompt(sid, streaming_prompt_metadata[sid])
|
||||||
|
else:
|
||||||
|
# Handle other event types
|
||||||
|
pass
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
print('Failed to decode JSON from message')
|
||||||
|
|
||||||
if msg.type == aiohttp.WSMsgType.ERROR:
|
if msg.type == aiohttp.WSMsgType.ERROR:
|
||||||
print('ws connection closed with exception %s' % ws.exception())
|
print('ws connection closed with exception %s' % ws.exception())
|
||||||
finally:
|
finally:
|
||||||
@ -387,6 +438,7 @@ async def comfy_deploy_check_status(request):
|
|||||||
|
|
||||||
async def send(event, data, sid=None):
|
async def send(event, data, sid=None):
|
||||||
try:
|
try:
|
||||||
|
# message = {"event": event, "data": data}
|
||||||
if sid:
|
if sid:
|
||||||
ws = sockets.get(sid)
|
ws = sockets.get(sid)
|
||||||
if ws != None and not ws.closed: # Check if the WebSocket connection is open and not closing
|
if ws != None and not ws.closed: # Check if the WebSocket connection is open and not closing
|
||||||
@ -462,7 +514,6 @@ async def send_json_override(self, event, data, sid=None):
|
|||||||
# await update_run_with_output(prompt_id, data.get('output'), node_id=data.get('node'))
|
# await update_run_with_output(prompt_id, data.get('output'), node_id=data.get('node'))
|
||||||
# update_run_with_output(prompt_id, data.get('output'))
|
# update_run_with_output(prompt_id, data.get('output'))
|
||||||
|
|
||||||
|
|
||||||
class Status(Enum):
|
class Status(Enum):
|
||||||
NOT_STARTED = "not-started"
|
NOT_STARTED = "not-started"
|
||||||
RUNNING = "running"
|
RUNNING = "running"
|
||||||
|
63
globals.py
Normal file
63
globals.py
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
import struct
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
|
from PIL import Image, ImageOps
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
|
sockets = dict()
|
||||||
|
|
||||||
|
class BinaryEventTypes:
|
||||||
|
PREVIEW_IMAGE = 1
|
||||||
|
UNENCODED_PREVIEW_IMAGE = 2
|
||||||
|
|
||||||
|
async def send_image(image_data, sid=None):
|
||||||
|
image_type = image_data[0]
|
||||||
|
image = image_data[1]
|
||||||
|
max_size = image_data[2]
|
||||||
|
if max_size is not None:
|
||||||
|
if hasattr(Image, 'Resampling'):
|
||||||
|
resampling = Image.Resampling.BILINEAR
|
||||||
|
else:
|
||||||
|
resampling = Image.ANTIALIAS
|
||||||
|
|
||||||
|
image = ImageOps.contain(image, (max_size, max_size), resampling)
|
||||||
|
type_num = 1
|
||||||
|
if image_type == "JPEG":
|
||||||
|
type_num = 1
|
||||||
|
elif image_type == "PNG":
|
||||||
|
type_num = 2
|
||||||
|
|
||||||
|
bytesIO = BytesIO()
|
||||||
|
header = struct.pack(">I", type_num)
|
||||||
|
bytesIO.write(header)
|
||||||
|
image.save(bytesIO, format=image_type, quality=95, compress_level=1)
|
||||||
|
preview_bytes = bytesIO.getvalue()
|
||||||
|
await send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid)
|
||||||
|
|
||||||
|
async def send_socket_catch_exception(function, message):
|
||||||
|
try:
|
||||||
|
await function(message)
|
||||||
|
except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError) as err:
|
||||||
|
print("send error:", err)
|
||||||
|
|
||||||
|
def encode_bytes(event, data):
|
||||||
|
if not isinstance(event, int):
|
||||||
|
raise RuntimeError(f"Binary event types must be integers, got {event}")
|
||||||
|
|
||||||
|
packed = struct.pack(">I", event)
|
||||||
|
message = bytearray(packed)
|
||||||
|
message.extend(data)
|
||||||
|
return message
|
||||||
|
|
||||||
|
async def send_bytes(event, data, sid=None):
|
||||||
|
message = encode_bytes(event, data)
|
||||||
|
|
||||||
|
print("sending image to ", event, sid)
|
||||||
|
|
||||||
|
if sid is None:
|
||||||
|
_sockets = list(sockets.values())
|
||||||
|
for ws in _sockets:
|
||||||
|
await send_socket_catch_exception(ws.send_bytes, message)
|
||||||
|
elif sid in sockets:
|
||||||
|
await send_socket_catch_exception(sockets[sid].send_bytes, message)
|
@ -1 +1,2 @@
|
|||||||
aiofiles
|
aiofiles
|
||||||
|
pydantic
|
Loading…
x
Reference in New Issue
Block a user