fix!: skip preview image as save node

This commit is contained in:
bennykok 2024-04-12 13:34:27 +08:00
parent e66712425d
commit 03d12e4099
2 changed files with 34 additions and 26 deletions

View File

@ -13,7 +13,6 @@ import traceback
import uuid
import asyncio
import logging
from enum import Enum
from urllib.parse import quote
import threading
import hashlib
@ -24,30 +23,11 @@ from PIL import Image
import copy
import struct
from globals import StreamingPrompt, sockets, streaming_prompt_metadata, BaseModel
class Status(Enum):
    """Lifecycle states of a prompt run, serialized as kebab-case strings."""

    NOT_STARTED = "not-started"
    RUNNING = "running"
    SUCCESS = "success"
    FAILED = "failed"
    # presumably: output uploads in flight (see handle_upload) — confirm
    UPLOADING = "uploading"
class SimplePrompt(BaseModel):
    """Mutable run state for a single (non-streaming) prompt execution.

    Instances are stored in the module-level ``prompt_metadata`` registry
    and updated as websocket events arrive for the prompt.
    """

    # Endpoint to POST run status updates to.
    status_endpoint: str
    # Endpoint used when uploading the run's output files.
    file_upload_endpoint: str
    # Workflow graph keyed by node id; each entry carries 'class_type'.
    workflow_api: dict
    # Current lifecycle state of the run.
    status: Status = Status.NOT_STARTED
    # Node ids that have reported progress so far.
    progress: set = set()
    # Fix: the original declarations ended with stray trailing commas,
    # which made these defaults one-element tuples (e.g. ``(None,)``)
    # instead of the intended scalar values.
    last_updated_node: Optional[str] = None
    # Node ids with a file upload currently in flight.
    uploading_nodes: set = set()
    done: bool = False
    is_realtime: bool = False
    start_time: Optional[float] = None
from globals import StreamingPrompt, Status, sockets, SimplePrompt, streaming_prompt_metadata, prompt_metadata

# Module-level handles for the API client and its background task;
# None until initialized elsewhere.
api = None
api_task = None

# NOTE(review): this local dict shadows the prompt_metadata imported from
# globals above — confirm which of the two should survive this refactor.
prompt_metadata: dict[str, SimplePrompt] = {}

# Logging feature flags, read from the environment once at import time
# (only the exact string 'true', case-insensitive, enables them).
cd_enable_log = os.environ.get('CD_ENABLE_LOG', 'false').lower() == 'true'
cd_enable_run_log = os.environ.get('CD_ENABLE_RUN_LOG', 'false').lower() == 'true'
@ -653,6 +633,14 @@ async def send_json_override(self, event, data, sid=None):
# await update_run_with_output(prompt_id, data)
if event == 'executed' and 'node' in data and 'output' in data:
print("executed", data)
if prompt_id in prompt_metadata:
node = data.get('node')
class_type = prompt_metadata[prompt_id].workflow_api[node]['class_type']
print("skipping preview image")
if class_type == "PreviewImage":
return
await update_run_with_output(prompt_id, data.get('output'), node_id=data.get('node'))
# await update_run_with_output(prompt_id, data.get('output'), node_id=data.get('node'))
# update_run_with_output(prompt_id, data.get('output'))
@ -892,6 +880,9 @@ async def update_file_status(prompt_id: str, data, uploading, have_error=False,
async def handle_upload(prompt_id: str, data, key: str, content_type_key: str, default_content_type: str):
items = data.get(key, [])
for item in items:
# # Skipping temp files
# if item.get("type") == "temp":
# continue
await upload_file(
prompt_id,
item.get("filename"),
@ -900,7 +891,6 @@ async def handle_upload(prompt_id: str, data, key: str, content_type_key: str, d
content_type=item.get(content_type_key, default_content_type)
)
# Upload files in the background
async def upload_in_background(prompt_id: str, data, node_id=None, have_upload=True):
try:

View File

@ -1,16 +1,21 @@
import struct
from enum import Enum
import aiohttp
from typing import List, Union, Any, Optional
from PIL import Image, ImageOps
from io import BytesIO
from pydantic import BaseModel as PydanticBaseModel
class BaseModel(PydanticBaseModel):
    """Shared pydantic base class for this module's models.

    Turns on ``arbitrary_types_allowed`` so fields may use types pydantic
    has no built-in validator for (e.g. the PIL types imported above).
    """

    class Config:
        # Accept field types pydantic cannot validate natively.
        arbitrary_types_allowed = True
class Status(Enum):
    """Lifecycle states of a prompt run, serialized as kebab-case strings."""

    NOT_STARTED = "not-started"
    RUNNING = "running"
    SUCCESS = "success"
    FAILED = "failed"
    # presumably: output uploads in flight — confirm against consumers
    UPLOADING = "uploading"
class StreamingPrompt(BaseModel):
workflow_api: Any
@ -19,8 +24,21 @@ class StreamingPrompt(BaseModel):
running_prompt_ids: set[str] = set()
status_endpoint: str
file_upload_endpoint: str
class SimplePrompt(BaseModel):
    """Mutable run state for a single (non-streaming) prompt execution.

    Instances are stored in the module-level ``prompt_metadata`` registry
    and updated as websocket events arrive for the prompt.
    """

    # Endpoint to POST run status updates to.
    status_endpoint: str
    # Endpoint used when uploading the run's output files.
    file_upload_endpoint: str
    # Workflow graph keyed by node id; each entry carries 'class_type'.
    workflow_api: dict
    # Current lifecycle state of the run.
    status: Status = Status.NOT_STARTED
    # Node ids that have reported progress so far.
    progress: set = set()
    # Fix: the original declarations ended with stray trailing commas,
    # which made these defaults one-element tuples (e.g. ``(None,)``)
    # instead of the intended scalar values.
    last_updated_node: Optional[str] = None
    # Node ids with a file upload currently in flight.
    uploading_nodes: set = set()
    done: bool = False
    is_realtime: bool = False
    start_time: Optional[float] = None
# Open websocket connections, keyed by id (presumably session id — verify).
sockets = {}

# Run state for plain one-shot prompt executions, keyed by prompt id.
prompt_metadata: dict[str, SimplePrompt] = {}

# Session state for streaming prompt executions, keyed by prompt id.
streaming_prompt_metadata: dict[str, StreamingPrompt] = {}
class BinaryEventTypes: