diff --git a/globals.py b/globals.py index ea705ee..e8eb46f 100644 --- a/globals.py +++ b/globals.py @@ -6,10 +6,12 @@ from PIL import Image, ImageOps from io import BytesIO from pydantic import BaseModel as PydanticBaseModel + class BaseModel(PydanticBaseModel): class Config: arbitrary_types_allowed = True - + + class Status(Enum): NOT_STARTED = "not-started" RUNNING = "running" @@ -17,6 +19,7 @@ class Status(Enum): FAILED = "failed" UPLOADING = "uploading" + class StreamingPrompt(BaseModel): workflow_api: Any auth_token: str @@ -25,44 +28,50 @@ class StreamingPrompt(BaseModel): status_endpoint: Optional[str] file_upload_endpoint: Optional[str] workflow: Any - + + class SimplePrompt(BaseModel): status_endpoint: Optional[str] file_upload_endpoint: Optional[str] - + token: Optional[str] - + workflow_api: dict status: Status = Status.NOT_STARTED progress: set = set() - last_updated_node: Optional[str] = None, + last_updated_node: Optional[str] = (None,) uploading_nodes: set = set() done: bool = False - is_realtime: bool = False, - start_time: Optional[float] = None, + is_realtime: bool = (False,) + start_time: Optional[float] = (None,) + gpu_event_id: Optional[str] = (None,) + sockets = dict() prompt_metadata: dict[str, SimplePrompt] = {} streaming_prompt_metadata: dict[str, StreamingPrompt] = {} + class BinaryEventTypes: PREVIEW_IMAGE = 1 UNENCODED_PREVIEW_IMAGE = 2 - + + max_output_id_length = 24 -async def send_image(image_data, sid=None, output_id:str = None): + +async def send_image(image_data, sid=None, output_id: str = None): max_length = max_output_id_length output_id = output_id[:max_length] - padded_output_id = output_id.ljust(max_length, '\x00') - encoded_output_id = padded_output_id.encode('ascii', 'replace') - + padded_output_id = output_id.ljust(max_length, "\x00") + encoded_output_id = padded_output_id.encode("ascii", "replace") + image_type = image_data[0] image = image_data[1] max_size = image_data[2] quality = image_data[3] if max_size is not None: - if hasattr(Image, 'Resampling'): + if hasattr(Image, "Resampling"): resampling = Image.Resampling.BILINEAR else: resampling = Image.ANTIALIAS @@ -86,17 +95,23 @@ async def send_image(image_data, sid=None, output_id:str = None): position_after = bytesIO.tell() bytes_written = position_after - position_before print(f"Bytes written: {bytes_written}") - + image.save(bytesIO, format=image_type, quality=quality, compress_level=1) preview_bytes = bytesIO.getvalue() await send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid) - + + async def send_socket_catch_exception(function, message): try: await function(message) - except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError) as err: + except ( + aiohttp.ClientError, + aiohttp.ClientPayloadError, + ConnectionResetError, + ) as err: print("send error:", err) + def encode_bytes(event, data): if not isinstance(event, int): raise RuntimeError(f"Binary event types must be integers, got {event}") @@ -106,9 +121,10 @@ def encode_bytes(event, data): message.extend(data) return message + async def send_bytes(event, data, sid=None): message = encode_bytes(event, data) - + print("sending image to ", event, sid) if sid is None: @@ -116,4 +132,4 @@ async def send_bytes(event, data, sid=None): for ws in _sockets: await send_socket_catch_exception(ws.send_bytes, message) elif sid in sockets: - await send_socket_catch_exception(sockets[sid].send_bytes, message) \ No newline at end of file + await send_socket_catch_exception(sockets[sid].send_bytes, message)