fix: pydantic type simpleprompt
This commit is contained in:
parent
492b81c340
commit
8882f4983c
32
globals.py
32
globals.py
@ -6,10 +6,12 @@ from PIL import Image, ImageOps
|
|||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from pydantic import BaseModel as PydanticBaseModel
|
from pydantic import BaseModel as PydanticBaseModel
|
||||||
|
|
||||||
|
|
||||||
class BaseModel(PydanticBaseModel):
|
class BaseModel(PydanticBaseModel):
|
||||||
class Config:
|
class Config:
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
|
||||||
class Status(Enum):
|
class Status(Enum):
|
||||||
NOT_STARTED = "not-started"
|
NOT_STARTED = "not-started"
|
||||||
RUNNING = "running"
|
RUNNING = "running"
|
||||||
@ -17,6 +19,7 @@ class Status(Enum):
|
|||||||
FAILED = "failed"
|
FAILED = "failed"
|
||||||
UPLOADING = "uploading"
|
UPLOADING = "uploading"
|
||||||
|
|
||||||
|
|
||||||
class StreamingPrompt(BaseModel):
|
class StreamingPrompt(BaseModel):
|
||||||
workflow_api: Any
|
workflow_api: Any
|
||||||
auth_token: str
|
auth_token: str
|
||||||
@ -26,6 +29,7 @@ class StreamingPrompt(BaseModel):
|
|||||||
file_upload_endpoint: Optional[str]
|
file_upload_endpoint: Optional[str]
|
||||||
workflow: Any
|
workflow: Any
|
||||||
|
|
||||||
|
|
||||||
class SimplePrompt(BaseModel):
|
class SimplePrompt(BaseModel):
|
||||||
status_endpoint: Optional[str]
|
status_endpoint: Optional[str]
|
||||||
file_upload_endpoint: Optional[str]
|
file_upload_endpoint: Optional[str]
|
||||||
@ -35,34 +39,39 @@ class SimplePrompt(BaseModel):
|
|||||||
workflow_api: dict
|
workflow_api: dict
|
||||||
status: Status = Status.NOT_STARTED
|
status: Status = Status.NOT_STARTED
|
||||||
progress: set = set()
|
progress: set = set()
|
||||||
last_updated_node: Optional[str] = None,
|
last_updated_node: Optional[str] = (None,)
|
||||||
uploading_nodes: set = set()
|
uploading_nodes: set = set()
|
||||||
done: bool = False
|
done: bool = False
|
||||||
is_realtime: bool = False,
|
is_realtime: bool = (False,)
|
||||||
start_time: Optional[float] = None,
|
start_time: Optional[float] = (None,)
|
||||||
|
gpu_event_id: Optional[str] = (None,)
|
||||||
|
|
||||||
|
|
||||||
sockets = dict()
|
sockets = dict()
|
||||||
prompt_metadata: dict[str, SimplePrompt] = {}
|
prompt_metadata: dict[str, SimplePrompt] = {}
|
||||||
streaming_prompt_metadata: dict[str, StreamingPrompt] = {}
|
streaming_prompt_metadata: dict[str, StreamingPrompt] = {}
|
||||||
|
|
||||||
|
|
||||||
class BinaryEventTypes:
|
class BinaryEventTypes:
|
||||||
PREVIEW_IMAGE = 1
|
PREVIEW_IMAGE = 1
|
||||||
UNENCODED_PREVIEW_IMAGE = 2
|
UNENCODED_PREVIEW_IMAGE = 2
|
||||||
|
|
||||||
|
|
||||||
max_output_id_length = 24
|
max_output_id_length = 24
|
||||||
|
|
||||||
async def send_image(image_data, sid=None, output_id:str = None):
|
|
||||||
|
async def send_image(image_data, sid=None, output_id: str = None):
|
||||||
max_length = max_output_id_length
|
max_length = max_output_id_length
|
||||||
output_id = output_id[:max_length]
|
output_id = output_id[:max_length]
|
||||||
padded_output_id = output_id.ljust(max_length, '\x00')
|
padded_output_id = output_id.ljust(max_length, "\x00")
|
||||||
encoded_output_id = padded_output_id.encode('ascii', 'replace')
|
encoded_output_id = padded_output_id.encode("ascii", "replace")
|
||||||
|
|
||||||
image_type = image_data[0]
|
image_type = image_data[0]
|
||||||
image = image_data[1]
|
image = image_data[1]
|
||||||
max_size = image_data[2]
|
max_size = image_data[2]
|
||||||
quality = image_data[3]
|
quality = image_data[3]
|
||||||
if max_size is not None:
|
if max_size is not None:
|
||||||
if hasattr(Image, 'Resampling'):
|
if hasattr(Image, "Resampling"):
|
||||||
resampling = Image.Resampling.BILINEAR
|
resampling = Image.Resampling.BILINEAR
|
||||||
else:
|
else:
|
||||||
resampling = Image.ANTIALIAS
|
resampling = Image.ANTIALIAS
|
||||||
@ -91,12 +100,18 @@ async def send_image(image_data, sid=None, output_id:str = None):
|
|||||||
preview_bytes = bytesIO.getvalue()
|
preview_bytes = bytesIO.getvalue()
|
||||||
await send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid)
|
await send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid)
|
||||||
|
|
||||||
|
|
||||||
async def send_socket_catch_exception(function, message):
|
async def send_socket_catch_exception(function, message):
|
||||||
try:
|
try:
|
||||||
await function(message)
|
await function(message)
|
||||||
except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError) as err:
|
except (
|
||||||
|
aiohttp.ClientError,
|
||||||
|
aiohttp.ClientPayloadError,
|
||||||
|
ConnectionResetError,
|
||||||
|
) as err:
|
||||||
print("send error:", err)
|
print("send error:", err)
|
||||||
|
|
||||||
|
|
||||||
def encode_bytes(event, data):
|
def encode_bytes(event, data):
|
||||||
if not isinstance(event, int):
|
if not isinstance(event, int):
|
||||||
raise RuntimeError(f"Binary event types must be integers, got {event}")
|
raise RuntimeError(f"Binary event types must be integers, got {event}")
|
||||||
@ -106,6 +121,7 @@ def encode_bytes(event, data):
|
|||||||
message.extend(data)
|
message.extend(data)
|
||||||
return message
|
return message
|
||||||
|
|
||||||
|
|
||||||
async def send_bytes(event, data, sid=None):
|
async def send_bytes(event, data, sid=None):
|
||||||
message = encode_bytes(event, data)
|
message = encode_bytes(event, data)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user