fix: pydantic type simpleprompt
This commit is contained in:
		
							parent
							
								
									492b81c340
								
							
						
					
					
						commit
						8882f4983c
					
				
							
								
								
									
										32
									
								
								globals.py
									
									
									
									
									
								
							
							
						
						
									
										32
									
								
								globals.py
									
									
									
									
									
								
							@ -6,10 +6,12 @@ from PIL import Image, ImageOps
 | 
			
		||||
from io import BytesIO
 | 
			
		||||
from pydantic import BaseModel as PydanticBaseModel
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class BaseModel(PydanticBaseModel):
 | 
			
		||||
    class Config:
 | 
			
		||||
        arbitrary_types_allowed = True
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Status(Enum):
 | 
			
		||||
    NOT_STARTED = "not-started"
 | 
			
		||||
    RUNNING = "running"
 | 
			
		||||
@ -17,6 +19,7 @@ class Status(Enum):
 | 
			
		||||
    FAILED = "failed"
 | 
			
		||||
    UPLOADING = "uploading"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class StreamingPrompt(BaseModel):
 | 
			
		||||
    workflow_api: Any
 | 
			
		||||
    auth_token: str
 | 
			
		||||
@ -26,6 +29,7 @@ class StreamingPrompt(BaseModel):
 | 
			
		||||
    file_upload_endpoint: Optional[str]
 | 
			
		||||
    workflow: Any
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SimplePrompt(BaseModel):
 | 
			
		||||
    status_endpoint: Optional[str]
 | 
			
		||||
    file_upload_endpoint: Optional[str]
 | 
			
		||||
@ -35,34 +39,39 @@ class SimplePrompt(BaseModel):
 | 
			
		||||
    workflow_api: dict
 | 
			
		||||
    status: Status = Status.NOT_STARTED
 | 
			
		||||
    progress: set = set()
 | 
			
		||||
    last_updated_node: Optional[str] = None,
 | 
			
		||||
    last_updated_node: Optional[str] = (None,)
 | 
			
		||||
    uploading_nodes: set = set()
 | 
			
		||||
    done: bool = False
 | 
			
		||||
    is_realtime: bool = False,
 | 
			
		||||
    start_time: Optional[float] = None,
 | 
			
		||||
    is_realtime: bool = (False,)
 | 
			
		||||
    start_time: Optional[float] = (None,)
 | 
			
		||||
    gpu_event_id: Optional[str] = (None,)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
sockets = dict()
 | 
			
		||||
prompt_metadata: dict[str, SimplePrompt] = {}
 | 
			
		||||
streaming_prompt_metadata: dict[str, StreamingPrompt] = {}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class BinaryEventTypes:
 | 
			
		||||
    PREVIEW_IMAGE = 1
 | 
			
		||||
    UNENCODED_PREVIEW_IMAGE = 2
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
max_output_id_length = 24
 | 
			
		||||
 | 
			
		||||
async def send_image(image_data, sid=None, output_id:str = None):
 | 
			
		||||
 | 
			
		||||
async def send_image(image_data, sid=None, output_id: str = None):
 | 
			
		||||
    max_length = max_output_id_length
 | 
			
		||||
    output_id = output_id[:max_length]
 | 
			
		||||
    padded_output_id = output_id.ljust(max_length, '\x00')
 | 
			
		||||
    encoded_output_id = padded_output_id.encode('ascii', 'replace')
 | 
			
		||||
    padded_output_id = output_id.ljust(max_length, "\x00")
 | 
			
		||||
    encoded_output_id = padded_output_id.encode("ascii", "replace")
 | 
			
		||||
 | 
			
		||||
    image_type = image_data[0]
 | 
			
		||||
    image = image_data[1]
 | 
			
		||||
    max_size = image_data[2]
 | 
			
		||||
    quality = image_data[3]
 | 
			
		||||
    if max_size is not None:
 | 
			
		||||
        if hasattr(Image, 'Resampling'):
 | 
			
		||||
        if hasattr(Image, "Resampling"):
 | 
			
		||||
            resampling = Image.Resampling.BILINEAR
 | 
			
		||||
        else:
 | 
			
		||||
            resampling = Image.ANTIALIAS
 | 
			
		||||
@ -91,12 +100,18 @@ async def send_image(image_data, sid=None, output_id:str = None):
 | 
			
		||||
    preview_bytes = bytesIO.getvalue()
 | 
			
		||||
    await send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def send_socket_catch_exception(function, message):
 | 
			
		||||
    try:
 | 
			
		||||
        await function(message)
 | 
			
		||||
    except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError) as err:
 | 
			
		||||
    except (
 | 
			
		||||
        aiohttp.ClientError,
 | 
			
		||||
        aiohttp.ClientPayloadError,
 | 
			
		||||
        ConnectionResetError,
 | 
			
		||||
    ) as err:
 | 
			
		||||
        print("send error:", err)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def encode_bytes(event, data):
 | 
			
		||||
    if not isinstance(event, int):
 | 
			
		||||
        raise RuntimeError(f"Binary event types must be integers, got {event}")
 | 
			
		||||
@ -106,6 +121,7 @@ def encode_bytes(event, data):
 | 
			
		||||
    message.extend(data)
 | 
			
		||||
    return message
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def send_bytes(event, data, sid=None):
 | 
			
		||||
    message = encode_bytes(event, data)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user