import struct
from enum import Enum
import aiohttp
from typing import List, Union, Any, Optional
from PIL import Image, ImageOps
from io import BytesIO
from pydantic import BaseModel as PydanticBaseModel


class BaseModel(PydanticBaseModel):
    """Project-wide pydantic base model.

    ``arbitrary_types_allowed`` lets fields hold non-pydantic types such as
    ``PIL.Image.Image`` (used by ``StreamingPrompt.inputs``).
    """

    class Config:
        arbitrary_types_allowed = True


class Status(Enum):
    """Status values tracked on a prompt (see ``SimplePrompt.status``)."""

    NOT_STARTED = "not-started"
    RUNNING = "running"
    SUCCESS = "success"
    FAILED = "failed"
    UPLOADING = "uploading"


class StreamingPrompt(BaseModel):
    """State kept for a streaming prompt session."""

    workflow_api: Any
    auth_token: str
    # Named inputs; values may be text, raw bytes, or PIL images.
    inputs: dict[str, Union[str, bytes, Image.Image]]
    # NOTE: relies on pydantic copying mutable defaults per instance.
    running_prompt_ids: set[str] = set()
    status_endpoint: Optional[str]
    file_upload_endpoint: Optional[str]
    workflow: Any
    gpu_event_id: Optional[str] = None


class SimplePrompt(BaseModel):
    """State tracked for a submitted prompt: endpoints, status, progress, timing."""

    status_endpoint: Optional[str]
    file_upload_endpoint: Optional[str]
    token: Optional[str]
    workflow_api: dict
    status: Status = Status.NOT_STARTED
    # NOTE: relies on pydantic copying mutable defaults per instance.
    progress: set = set()
    last_updated_node: Optional[str] = None
    uploading_nodes: set = set()
    done: bool = False
    is_realtime: bool = False
    start_time: Optional[float] = None
    gpu_event_id: Optional[str] = None


# Connected sockets keyed by session id. Values must expose an async
# ``send_bytes`` (presumably aiohttp websockets -- confirm against callers).
sockets = dict()
# Prompt state keyed by id (presumably prompt id -- verify against callers).
prompt_metadata: dict[str, SimplePrompt] = {}
streaming_prompt_metadata: dict[str, StreamingPrompt] = {}


class BinaryEventTypes:
    """Integer tags prefixed to every binary websocket message."""

    PREVIEW_IMAGE = 1
    UNENCODED_PREVIEW_IMAGE = 2


# Width, in bytes, of the fixed-size output-id field in preview payloads.
max_output_id_length = 24

# Format number written in the preview header; unknown formats map to 1.
_IMAGE_TYPE_NUMBERS = {"JPEG": 1, "PNG": 2, "WEBP": 3}


async def send_image(image_data, sid=None, output_id: Optional[str] = None):
    """Encode a preview image and send it as a ``PREVIEW_IMAGE`` binary event.

    Payload layout (``send_bytes`` additionally prepends the 4-byte event tag):
    4-byte big-endian format number, ``max_output_id_length`` bytes of
    NUL-padded ASCII output id, then the encoded image bytes.

    Args:
        image_data: Indexable ``(format, image, max_size, quality)`` where
            ``format`` is a PIL format name ("JPEG", "PNG", "WEBP"), ``image``
            is a ``PIL.Image.Image``, ``max_size`` bounds both dimensions
            (``None`` skips resizing), and ``quality`` is passed to ``save``.
        sid: Target socket id; ``None`` broadcasts to every socket.
        output_id: Identifier embedded in the payload, truncated to
            ``max_output_id_length`` characters. ``None`` sends an all-NUL id.
    """
    image_type = image_data[0]
    image = image_data[1]
    max_size = image_data[2]
    quality = image_data[3]

    # Fixed-width id field: truncate, NUL-pad, then encode one byte per char
    # ("replace" turns non-ASCII characters into "?", keeping the width).
    encoded_output_id = (
        (output_id or "")[:max_output_id_length]
        .ljust(max_output_id_length, "\x00")
        .encode("ascii", "replace")
    )

    if max_size is not None:
        if hasattr(Image, "Resampling"):
            resampling = Image.Resampling.BILINEAR
        else:
            # Older Pillow without the Resampling enum: use the same filter
            # via the module-level constant.
            resampling = Image.BILINEAR
        image = ImageOps.contain(image, (max_size, max_size), resampling)

    type_num = _IMAGE_TYPE_NUMBERS.get(image_type, 1)

    buffer = BytesIO()
    buffer.write(struct.pack(">I", type_num))  # 4 bytes: image format number
    buffer.write(encoded_output_id)  # max_output_id_length bytes: output id
    # compress_level is a PNG save option (1 = fastest); quality applies to
    # lossy formats such as JPEG/WEBP.
    image.save(buffer, format=image_type, quality=quality, compress_level=1)

    await send_bytes(BinaryEventTypes.PREVIEW_IMAGE, buffer.getvalue(), sid=sid)


async def send_socket_catch_exception(function, message):
    """Await ``function(message)``, printing (not raising) connection errors.

    Sending is best-effort: a client that dropped must not abort delivery to
    the remaining sockets.
    """
    try:
        await function(message)
    except (
        aiohttp.ClientError,
        aiohttp.ClientPayloadError,
        ConnectionResetError,
    ) as err:
        print("send error:", err)


def encode_bytes(event, data):
    """Return ``data`` prefixed with ``event`` as a 4-byte big-endian int.

    Raises:
        RuntimeError: if ``event`` is not an ``int``.
    """
    if not isinstance(event, int):
        raise RuntimeError(f"Binary event types must be integers, got {event}")

    packed = struct.pack(">I", event)
    message = bytearray(packed)
    message.extend(data)
    return message


async def send_bytes(event, data, sid=None):
    """Send an encoded binary event to one socket, or to all when ``sid`` is None.

    An unknown ``sid`` is silently ignored.
    """
    message = encode_bytes(event, data)

    print("sending image to ", event, sid)

    if sid is None:
        # Snapshot the sockets so additions/removals during the awaits below
        # do not disturb iteration.
        _sockets = list(sockets.values())
        for ws in _sockets:
            await send_socket_catch_exception(ws.send_bytes, message)
    elif sid in sockets:
        await send_socket_catch_exception(sockets[sid].send_bytes, message)