From 03d12e4099b03e46b902354bca98979fe9e07829 Mon Sep 17 00:00:00 2001 From: bennykok Date: Fri, 12 Apr 2024 13:34:27 +0800 Subject: [PATCH] fix!: skipping preview image as save node --- custom_routes.py | 36 +++++++++++++----------------------- globals.py | 24 +++++++++++++++++++++--- 2 files changed, 34 insertions(+), 26 deletions(-) diff --git a/custom_routes.py b/custom_routes.py index 6858923..018bc4c 100644 --- a/custom_routes.py +++ b/custom_routes.py @@ -13,7 +13,6 @@ import traceback import uuid import asyncio import logging -from enum import Enum from urllib.parse import quote import threading import hashlib @@ -24,30 +23,11 @@ from PIL import Image import copy import struct -from globals import StreamingPrompt, sockets, streaming_prompt_metadata, BaseModel - -class Status(Enum): - NOT_STARTED = "not-started" - RUNNING = "running" - SUCCESS = "success" - FAILED = "failed" - UPLOADING = "uploading" - -class SimplePrompt(BaseModel): - status_endpoint: str - file_upload_endpoint: str - workflow_api: dict - status: Status = Status.NOT_STARTED - progress: set = set() - last_updated_node: Optional[str] = None, - uploading_nodes: set = set() - done: bool = False - is_realtime: bool = False, - start_time: Optional[float] = None, +from globals import StreamingPrompt, Status, sockets, SimplePrompt, streaming_prompt_metadata, prompt_metadata api = None api_task = None -prompt_metadata: dict[str, SimplePrompt] = {} + cd_enable_log = os.environ.get('CD_ENABLE_LOG', 'false').lower() == 'true' cd_enable_run_log = os.environ.get('CD_ENABLE_RUN_LOG', 'false').lower() == 'true' @@ -653,6 +633,14 @@ async def send_json_override(self, event, data, sid=None): # await update_run_with_output(prompt_id, data) if event == 'executed' and 'node' in data and 'output' in data: + print("executed", data) + if prompt_id in prompt_metadata: + node = data.get('node') + class_type = prompt_metadata[prompt_id].workflow_api[node]['class_type'] + print("skipping preview image") + if class_type == "PreviewImage": + return + await update_run_with_output(prompt_id, data.get('output'), node_id=data.get('node')) # await update_run_with_output(prompt_id, data.get('output'), node_id=data.get('node')) # update_run_with_output(prompt_id, data.get('output')) @@ -892,6 +880,9 @@ async def update_file_status(prompt_id: str, data, uploading, have_error=False, async def handle_upload(prompt_id: str, data, key: str, content_type_key: str, default_content_type: str): items = data.get(key, []) for item in items: + # # Skipping temp files + # if item.get("type") == "temp": + # continue await upload_file( prompt_id, item.get("filename"), @@ -900,7 +891,6 @@ async def handle_upload(prompt_id: str, data, key: str, content_type_key: str, d content_type=item.get(content_type_key, default_content_type) ) - # Upload files in the background async def upload_in_background(prompt_id: str, data, node_id=None, have_upload=True): try: diff --git a/globals.py b/globals.py index a5575df..d608209 100644 --- a/globals.py +++ b/globals.py @@ -1,16 +1,21 @@ import struct - +from enum import Enum import aiohttp - from typing import List, Union, Any, Optional from PIL import Image, ImageOps from io import BytesIO - from pydantic import BaseModel as PydanticBaseModel class BaseModel(PydanticBaseModel): class Config: arbitrary_types_allowed = True + +class Status(Enum): + NOT_STARTED = "not-started" + RUNNING = "running" + SUCCESS = "success" + FAILED = "failed" + UPLOADING = "uploading" class StreamingPrompt(BaseModel): workflow_api: Any @@ -19,8 +24,21 @@ class StreamingPrompt(BaseModel): running_prompt_ids: set[str] = set() status_endpoint: str file_upload_endpoint: str + +class SimplePrompt(BaseModel): + status_endpoint: str + file_upload_endpoint: str + workflow_api: dict + status: Status = Status.NOT_STARTED + progress: set = set() + last_updated_node: Optional[str] = None, + uploading_nodes: set = set() + done: bool = False + is_realtime: bool = False, + start_time: Optional[float] = None, sockets = dict() +prompt_metadata: dict[str, SimplePrompt] = {} streaming_prompt_metadata: dict[str, StreamingPrompt] = {} class BinaryEventTypes: