fix: pydantic type simpleprompt
This commit is contained in:
		
							parent
							
								
									492b81c340
								
							
						
					
					
						commit
						8882f4983c
					
				
							
								
								
									
										52
									
								
								globals.py
									
									
									
									
									
								
							
							
						
						
									
										52
									
								
								globals.py
									
									
									
									
									
								
							@ -6,10 +6,12 @@ from PIL import Image, ImageOps
 | 
				
			|||||||
from io import BytesIO
 | 
					from io import BytesIO
 | 
				
			||||||
from pydantic import BaseModel as PydanticBaseModel
 | 
					from pydantic import BaseModel as PydanticBaseModel
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class BaseModel(PydanticBaseModel):
 | 
					class BaseModel(PydanticBaseModel):
 | 
				
			||||||
    class Config:
 | 
					    class Config:
 | 
				
			||||||
        arbitrary_types_allowed = True
 | 
					        arbitrary_types_allowed = True
 | 
				
			||||||
        
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class Status(Enum):
 | 
					class Status(Enum):
 | 
				
			||||||
    NOT_STARTED = "not-started"
 | 
					    NOT_STARTED = "not-started"
 | 
				
			||||||
    RUNNING = "running"
 | 
					    RUNNING = "running"
 | 
				
			||||||
@ -17,6 +19,7 @@ class Status(Enum):
 | 
				
			|||||||
    FAILED = "failed"
 | 
					    FAILED = "failed"
 | 
				
			||||||
    UPLOADING = "uploading"
 | 
					    UPLOADING = "uploading"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class StreamingPrompt(BaseModel):
 | 
					class StreamingPrompt(BaseModel):
 | 
				
			||||||
    workflow_api: Any
 | 
					    workflow_api: Any
 | 
				
			||||||
    auth_token: str
 | 
					    auth_token: str
 | 
				
			||||||
@ -25,44 +28,50 @@ class StreamingPrompt(BaseModel):
 | 
				
			|||||||
    status_endpoint: Optional[str]
 | 
					    status_endpoint: Optional[str]
 | 
				
			||||||
    file_upload_endpoint: Optional[str]
 | 
					    file_upload_endpoint: Optional[str]
 | 
				
			||||||
    workflow: Any
 | 
					    workflow: Any
 | 
				
			||||||
    
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class SimplePrompt(BaseModel):
 | 
					class SimplePrompt(BaseModel):
 | 
				
			||||||
    status_endpoint: Optional[str]
 | 
					    status_endpoint: Optional[str]
 | 
				
			||||||
    file_upload_endpoint: Optional[str]
 | 
					    file_upload_endpoint: Optional[str]
 | 
				
			||||||
    
 | 
					
 | 
				
			||||||
    token: Optional[str]
 | 
					    token: Optional[str]
 | 
				
			||||||
    
 | 
					
 | 
				
			||||||
    workflow_api: dict
 | 
					    workflow_api: dict
 | 
				
			||||||
    status: Status = Status.NOT_STARTED
 | 
					    status: Status = Status.NOT_STARTED
 | 
				
			||||||
    progress: set = set()
 | 
					    progress: set = set()
 | 
				
			||||||
    last_updated_node: Optional[str] = None,
 | 
					    last_updated_node: Optional[str] = (None,)
 | 
				
			||||||
    uploading_nodes: set = set()
 | 
					    uploading_nodes: set = set()
 | 
				
			||||||
    done: bool = False
 | 
					    done: bool = False
 | 
				
			||||||
    is_realtime: bool = False,
 | 
					    is_realtime: bool = (False,)
 | 
				
			||||||
    start_time: Optional[float] = None,
 | 
					    start_time: Optional[float] = (None,)
 | 
				
			||||||
 | 
					    gpu_event_id: Optional[str] = (None,)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
sockets = dict()
 | 
					sockets = dict()
 | 
				
			||||||
prompt_metadata: dict[str, SimplePrompt] = {}
 | 
					prompt_metadata: dict[str, SimplePrompt] = {}
 | 
				
			||||||
streaming_prompt_metadata: dict[str, StreamingPrompt] = {}
 | 
					streaming_prompt_metadata: dict[str, StreamingPrompt] = {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class BinaryEventTypes:
 | 
					class BinaryEventTypes:
 | 
				
			||||||
    PREVIEW_IMAGE = 1
 | 
					    PREVIEW_IMAGE = 1
 | 
				
			||||||
    UNENCODED_PREVIEW_IMAGE = 2
 | 
					    UNENCODED_PREVIEW_IMAGE = 2
 | 
				
			||||||
    
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
max_output_id_length = 24
 | 
					max_output_id_length = 24
 | 
				
			||||||
 | 
					
 | 
				
			||||||
async def send_image(image_data, sid=None, output_id:str = None):
 | 
					
 | 
				
			||||||
 | 
					async def send_image(image_data, sid=None, output_id: str = None):
 | 
				
			||||||
    max_length = max_output_id_length
 | 
					    max_length = max_output_id_length
 | 
				
			||||||
    output_id = output_id[:max_length]
 | 
					    output_id = output_id[:max_length]
 | 
				
			||||||
    padded_output_id = output_id.ljust(max_length, '\x00')
 | 
					    padded_output_id = output_id.ljust(max_length, "\x00")
 | 
				
			||||||
    encoded_output_id = padded_output_id.encode('ascii', 'replace')
 | 
					    encoded_output_id = padded_output_id.encode("ascii", "replace")
 | 
				
			||||||
    
 | 
					
 | 
				
			||||||
    image_type = image_data[0]
 | 
					    image_type = image_data[0]
 | 
				
			||||||
    image = image_data[1]
 | 
					    image = image_data[1]
 | 
				
			||||||
    max_size = image_data[2]
 | 
					    max_size = image_data[2]
 | 
				
			||||||
    quality = image_data[3]
 | 
					    quality = image_data[3]
 | 
				
			||||||
    if max_size is not None:
 | 
					    if max_size is not None:
 | 
				
			||||||
        if hasattr(Image, 'Resampling'):
 | 
					        if hasattr(Image, "Resampling"):
 | 
				
			||||||
            resampling = Image.Resampling.BILINEAR
 | 
					            resampling = Image.Resampling.BILINEAR
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            resampling = Image.ANTIALIAS
 | 
					            resampling = Image.ANTIALIAS
 | 
				
			||||||
@ -86,17 +95,23 @@ async def send_image(image_data, sid=None, output_id:str = None):
 | 
				
			|||||||
    position_after = bytesIO.tell()
 | 
					    position_after = bytesIO.tell()
 | 
				
			||||||
    bytes_written = position_after - position_before
 | 
					    bytes_written = position_after - position_before
 | 
				
			||||||
    print(f"Bytes written: {bytes_written}")
 | 
					    print(f"Bytes written: {bytes_written}")
 | 
				
			||||||
    
 | 
					
 | 
				
			||||||
    image.save(bytesIO, format=image_type, quality=quality, compress_level=1)
 | 
					    image.save(bytesIO, format=image_type, quality=quality, compress_level=1)
 | 
				
			||||||
    preview_bytes = bytesIO.getvalue()
 | 
					    preview_bytes = bytesIO.getvalue()
 | 
				
			||||||
    await send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid)
 | 
					    await send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid)
 | 
				
			||||||
        
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
async def send_socket_catch_exception(function, message):
 | 
					async def send_socket_catch_exception(function, message):
 | 
				
			||||||
    try:
 | 
					    try:
 | 
				
			||||||
        await function(message)
 | 
					        await function(message)
 | 
				
			||||||
    except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError) as err:
 | 
					    except (
 | 
				
			||||||
 | 
					        aiohttp.ClientError,
 | 
				
			||||||
 | 
					        aiohttp.ClientPayloadError,
 | 
				
			||||||
 | 
					        ConnectionResetError,
 | 
				
			||||||
 | 
					    ) as err:
 | 
				
			||||||
        print("send error:", err)
 | 
					        print("send error:", err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def encode_bytes(event, data):
 | 
					def encode_bytes(event, data):
 | 
				
			||||||
    if not isinstance(event, int):
 | 
					    if not isinstance(event, int):
 | 
				
			||||||
        raise RuntimeError(f"Binary event types must be integers, got {event}")
 | 
					        raise RuntimeError(f"Binary event types must be integers, got {event}")
 | 
				
			||||||
@ -106,9 +121,10 @@ def encode_bytes(event, data):
 | 
				
			|||||||
    message.extend(data)
 | 
					    message.extend(data)
 | 
				
			||||||
    return message
 | 
					    return message
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
async def send_bytes(event, data, sid=None):
 | 
					async def send_bytes(event, data, sid=None):
 | 
				
			||||||
    message = encode_bytes(event, data)
 | 
					    message = encode_bytes(event, data)
 | 
				
			||||||
    
 | 
					
 | 
				
			||||||
    print("sending image to ", event, sid)
 | 
					    print("sending image to ", event, sid)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if sid is None:
 | 
					    if sid is None:
 | 
				
			||||||
@ -116,4 +132,4 @@ async def send_bytes(event, data, sid=None):
 | 
				
			|||||||
        for ws in _sockets:
 | 
					        for ws in _sockets:
 | 
				
			||||||
            await send_socket_catch_exception(ws.send_bytes, message)
 | 
					            await send_socket_catch_exception(ws.send_bytes, message)
 | 
				
			||||||
    elif sid in sockets:
 | 
					    elif sid in sockets:
 | 
				
			||||||
        await send_socket_catch_exception(sockets[sid].send_bytes, message)
 | 
					        await send_socket_catch_exception(sockets[sid].send_bytes, message)
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user