fix(plugin): output_id is also included in the binary data back

This commit is contained in:
bennykok 2024-02-29 11:39:54 -08:00
parent 32c6d1215b
commit 410d03cd2b
2 changed files with 28 additions and 3 deletions

View File

@ -5,7 +5,7 @@ import torch
from server import PromptServer, BinaryEventTypes from server import PromptServer, BinaryEventTypes
import asyncio import asyncio
from globals import send_image from globals import send_image, max_output_id_length
class ComfyDeployWebscoketImageOutput: class ComfyDeployWebscoketImageOutput:
@classmethod @classmethod
@ -37,6 +37,16 @@ class ComfyDeployWebscoketImageOutput:
FUNCTION = "run" FUNCTION = "run"
CATEGORY = "output" CATEGORY = "output"
@classmethod
def VALIDATE_INPUTS(s, output_id):
try:
if len(output_id.encode('ascii')) > max_output_id_length:
raise ValueError(f"output_id size is greater than {max_output_id_length} bytes")
except UnicodeEncodeError:
raise ValueError("output_id is not ASCII encodable")
return True
def run(self, output_id, images, file_type, quality, client_id): def run(self, output_id, images, file_type, quality, client_id):
prompt_server = PromptServer.instance prompt_server = PromptServer.instance
@ -50,7 +60,7 @@ class ComfyDeployWebscoketImageOutput:
array = 255.0 * tensor.cpu().numpy() array = 255.0 * tensor.cpu().numpy()
image = Image.fromarray(np.clip(array, 0, 255).astype(np.uint8)) image = Image.fromarray(np.clip(array, 0, 255).astype(np.uint8))
schedule_coroutine_blocking(send_image, [file_type, image, None, quality], client_id) schedule_coroutine_blocking(send_image, [file_type, image, None, quality], client_id, output_id)
print("Image sent") print("Image sent")
return {"ui": {}} return {"ui": {}}

View File

@ -10,8 +10,15 @@ sockets = dict()
class BinaryEventTypes: class BinaryEventTypes:
PREVIEW_IMAGE = 1 PREVIEW_IMAGE = 1
UNENCODED_PREVIEW_IMAGE = 2 UNENCODED_PREVIEW_IMAGE = 2
max_output_id_length = 24
async def send_image(image_data, sid=None): async def send_image(image_data, sid=None, output_id:str = None):
max_length = max_output_id_length
output_id = output_id[:max_length]
padded_output_id = output_id.ljust(max_length, '\x00')
encoded_output_id = padded_output_id.encode('ascii', 'replace')
image_type = image_data[0] image_type = image_data[0]
image = image_data[1] image = image_data[1]
max_size = image_data[2] max_size = image_data[2]
@ -33,7 +40,15 @@ async def send_image(image_data, sid=None):
bytesIO = BytesIO() bytesIO = BytesIO()
header = struct.pack(">I", type_num) header = struct.pack(">I", type_num)
# 4 bytes for the type
bytesIO.write(header) bytesIO.write(header)
# 10 bytes for the output_id
position_before = bytesIO.tell()
bytesIO.write(encoded_output_id)
position_after = bytesIO.tell()
bytes_written = position_after - position_before
print(f"Bytes written: {bytes_written}")
image.save(bytesIO, format=image_type, quality=quality, compress_level=1) image.save(bytesIO, format=image_type, quality=quality, compress_level=1)
preview_bytes = bytesIO.getvalue() preview_bytes = bytesIO.getvalue()
await send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid) await send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid)