fix: pydantic type simpleprompt

This commit is contained in:
nick 2024-10-04 19:14:37 -07:00
parent 492b81c340
commit 8882f4983c

View File

@ -6,10 +6,12 @@ from PIL import Image, ImageOps
from io import BytesIO
from pydantic import BaseModel as PydanticBaseModel
class BaseModel(PydanticBaseModel):
class Config:
arbitrary_types_allowed = True
class Status(Enum):
NOT_STARTED = "not-started"
RUNNING = "running"
@ -17,6 +19,7 @@ class Status(Enum):
FAILED = "failed"
UPLOADING = "uploading"
class StreamingPrompt(BaseModel):
workflow_api: Any
auth_token: str
@ -26,6 +29,7 @@ class StreamingPrompt(BaseModel):
file_upload_endpoint: Optional[str]
workflow: Any
class SimplePrompt(BaseModel):
status_endpoint: Optional[str]
file_upload_endpoint: Optional[str]
@ -35,34 +39,39 @@ class SimplePrompt(BaseModel):
workflow_api: dict
status: Status = Status.NOT_STARTED
progress: set = set()
last_updated_node: Optional[str] = None,
last_updated_node: Optional[str] = (None,)
uploading_nodes: set = set()
done: bool = False
is_realtime: bool = False,
start_time: Optional[float] = None,
is_realtime: bool = (False,)
start_time: Optional[float] = (None,)
gpu_event_id: Optional[str] = (None,)
sockets = dict()
prompt_metadata: dict[str, SimplePrompt] = {}
streaming_prompt_metadata: dict[str, StreamingPrompt] = {}
class BinaryEventTypes:
PREVIEW_IMAGE = 1
UNENCODED_PREVIEW_IMAGE = 2
max_output_id_length = 24
async def send_image(image_data, sid=None, output_id: str = None):
max_length = max_output_id_length
output_id = output_id[:max_length]
padded_output_id = output_id.ljust(max_length, '\x00')
encoded_output_id = padded_output_id.encode('ascii', 'replace')
padded_output_id = output_id.ljust(max_length, "\x00")
encoded_output_id = padded_output_id.encode("ascii", "replace")
image_type = image_data[0]
image = image_data[1]
max_size = image_data[2]
quality = image_data[3]
if max_size is not None:
if hasattr(Image, 'Resampling'):
if hasattr(Image, "Resampling"):
resampling = Image.Resampling.BILINEAR
else:
resampling = Image.ANTIALIAS
@ -91,12 +100,18 @@ async def send_image(image_data, sid=None, output_id:str = None):
preview_bytes = bytesIO.getvalue()
await send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid)
async def send_socket_catch_exception(function, message):
try:
await function(message)
except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError) as err:
except (
aiohttp.ClientError,
aiohttp.ClientPayloadError,
ConnectionResetError,
) as err:
print("send error:", err)
def encode_bytes(event, data):
if not isinstance(event, int):
raise RuntimeError(f"Binary event types must be integers, got {event}")
@ -106,6 +121,7 @@ def encode_bytes(event, data):
message.extend(data)
return message
async def send_bytes(event, data, sid=None):
message = encode_bytes(event, data)