fix!: skipping preview image as save node

This commit is contained in:
bennykok 2024-04-12 13:34:27 +08:00
parent e66712425d
commit 03d12e4099
2 changed files with 34 additions and 26 deletions

View File

@ -13,7 +13,6 @@ import traceback
import uuid import uuid
import asyncio import asyncio
import logging import logging
from enum import Enum
from urllib.parse import quote from urllib.parse import quote
import threading import threading
import hashlib import hashlib
@ -24,30 +23,11 @@ from PIL import Image
import copy import copy
import struct import struct
from globals import StreamingPrompt, sockets, streaming_prompt_metadata, BaseModel from globals import StreamingPrompt, Status, sockets, SimplePrompt, streaming_prompt_metadata, prompt_metadata
class Status(Enum):
NOT_STARTED = "not-started"
RUNNING = "running"
SUCCESS = "success"
FAILED = "failed"
UPLOADING = "uploading"
class SimplePrompt(BaseModel):
status_endpoint: str
file_upload_endpoint: str
workflow_api: dict
status: Status = Status.NOT_STARTED
progress: set = set()
last_updated_node: Optional[str] = None,
uploading_nodes: set = set()
done: bool = False
is_realtime: bool = False,
start_time: Optional[float] = None,
api = None api = None
api_task = None api_task = None
prompt_metadata: dict[str, SimplePrompt] = {}
cd_enable_log = os.environ.get('CD_ENABLE_LOG', 'false').lower() == 'true' cd_enable_log = os.environ.get('CD_ENABLE_LOG', 'false').lower() == 'true'
cd_enable_run_log = os.environ.get('CD_ENABLE_RUN_LOG', 'false').lower() == 'true' cd_enable_run_log = os.environ.get('CD_ENABLE_RUN_LOG', 'false').lower() == 'true'
@ -653,6 +633,14 @@ async def send_json_override(self, event, data, sid=None):
# await update_run_with_output(prompt_id, data) # await update_run_with_output(prompt_id, data)
if event == 'executed' and 'node' in data and 'output' in data: if event == 'executed' and 'node' in data and 'output' in data:
print("executed", data)
if prompt_id in prompt_metadata:
node = data.get('node')
class_type = prompt_metadata[prompt_id].workflow_api[node]['class_type']
print("skipping preview image")
if class_type == "PreviewImage":
return
await update_run_with_output(prompt_id, data.get('output'), node_id=data.get('node')) await update_run_with_output(prompt_id, data.get('output'), node_id=data.get('node'))
# await update_run_with_output(prompt_id, data.get('output'), node_id=data.get('node')) # await update_run_with_output(prompt_id, data.get('output'), node_id=data.get('node'))
# update_run_with_output(prompt_id, data.get('output')) # update_run_with_output(prompt_id, data.get('output'))
@ -892,6 +880,9 @@ async def update_file_status(prompt_id: str, data, uploading, have_error=False,
async def handle_upload(prompt_id: str, data, key: str, content_type_key: str, default_content_type: str): async def handle_upload(prompt_id: str, data, key: str, content_type_key: str, default_content_type: str):
items = data.get(key, []) items = data.get(key, [])
for item in items: for item in items:
# # Skipping temp files
# if item.get("type") == "temp":
# continue
await upload_file( await upload_file(
prompt_id, prompt_id,
item.get("filename"), item.get("filename"),
@ -900,7 +891,6 @@ async def handle_upload(prompt_id: str, data, key: str, content_type_key: str, d
content_type=item.get(content_type_key, default_content_type) content_type=item.get(content_type_key, default_content_type)
) )
# Upload files in the background # Upload files in the background
async def upload_in_background(prompt_id: str, data, node_id=None, have_upload=True): async def upload_in_background(prompt_id: str, data, node_id=None, have_upload=True):
try: try:

View File

@ -1,16 +1,21 @@
import struct import struct
from enum import Enum
import aiohttp import aiohttp
from typing import List, Union, Any, Optional from typing import List, Union, Any, Optional
from PIL import Image, ImageOps from PIL import Image, ImageOps
from io import BytesIO from io import BytesIO
from pydantic import BaseModel as PydanticBaseModel from pydantic import BaseModel as PydanticBaseModel
class BaseModel(PydanticBaseModel): class BaseModel(PydanticBaseModel):
class Config: class Config:
arbitrary_types_allowed = True arbitrary_types_allowed = True
class Status(Enum):
NOT_STARTED = "not-started"
RUNNING = "running"
SUCCESS = "success"
FAILED = "failed"
UPLOADING = "uploading"
class StreamingPrompt(BaseModel): class StreamingPrompt(BaseModel):
workflow_api: Any workflow_api: Any
@ -19,8 +24,21 @@ class StreamingPrompt(BaseModel):
running_prompt_ids: set[str] = set() running_prompt_ids: set[str] = set()
status_endpoint: str status_endpoint: str
file_upload_endpoint: str file_upload_endpoint: str
class SimplePrompt(BaseModel):
status_endpoint: str
file_upload_endpoint: str
workflow_api: dict
status: Status = Status.NOT_STARTED
progress: set = set()
last_updated_node: Optional[str] = None,
uploading_nodes: set = set()
done: bool = False
is_realtime: bool = False,
start_time: Optional[float] = None,
sockets = dict() sockets = dict()
prompt_metadata: dict[str, SimplePrompt] = {}
streaming_prompt_metadata: dict[str, StreamingPrompt] = {} streaming_prompt_metadata: dict[str, StreamingPrompt] = {}
class BinaryEventTypes: class BinaryEventTypes: