fix!: skipping preview image as save node
This commit is contained in:
		
							parent
							
								
									e66712425d
								
							
						
					
					
						commit
						03d12e4099
					
				@ -13,7 +13,6 @@ import traceback
 | 
				
			|||||||
import uuid
 | 
					import uuid
 | 
				
			||||||
import asyncio
 | 
					import asyncio
 | 
				
			||||||
import logging
 | 
					import logging
 | 
				
			||||||
from enum import Enum
 | 
					 | 
				
			||||||
from urllib.parse import quote
 | 
					from urllib.parse import quote
 | 
				
			||||||
import threading
 | 
					import threading
 | 
				
			||||||
import hashlib
 | 
					import hashlib
 | 
				
			||||||
@ -24,30 +23,11 @@ from PIL import Image
 | 
				
			|||||||
import copy
 | 
					import copy
 | 
				
			||||||
import struct
 | 
					import struct
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from globals import StreamingPrompt, sockets, streaming_prompt_metadata, BaseModel
 | 
					from globals import StreamingPrompt, Status, sockets, SimplePrompt, streaming_prompt_metadata, prompt_metadata
 | 
				
			||||||
 | 
					 | 
				
			||||||
class Status(Enum):
 | 
					 | 
				
			||||||
    NOT_STARTED = "not-started"
 | 
					 | 
				
			||||||
    RUNNING = "running"
 | 
					 | 
				
			||||||
    SUCCESS = "success"
 | 
					 | 
				
			||||||
    FAILED = "failed"
 | 
					 | 
				
			||||||
    UPLOADING = "uploading"
 | 
					 | 
				
			||||||
    
 | 
					 | 
				
			||||||
class SimplePrompt(BaseModel):
 | 
					 | 
				
			||||||
    status_endpoint: str
 | 
					 | 
				
			||||||
    file_upload_endpoint: str
 | 
					 | 
				
			||||||
    workflow_api: dict
 | 
					 | 
				
			||||||
    status: Status = Status.NOT_STARTED
 | 
					 | 
				
			||||||
    progress: set = set()
 | 
					 | 
				
			||||||
    last_updated_node: Optional[str] = None,
 | 
					 | 
				
			||||||
    uploading_nodes: set = set()
 | 
					 | 
				
			||||||
    done: bool = False
 | 
					 | 
				
			||||||
    is_realtime: bool = False,
 | 
					 | 
				
			||||||
    start_time: Optional[float] = None,
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
api = None
 | 
					api = None
 | 
				
			||||||
api_task = None
 | 
					api_task = None
 | 
				
			||||||
prompt_metadata: dict[str, SimplePrompt] = {}
 | 
					
 | 
				
			||||||
cd_enable_log = os.environ.get('CD_ENABLE_LOG', 'false').lower() == 'true'
 | 
					cd_enable_log = os.environ.get('CD_ENABLE_LOG', 'false').lower() == 'true'
 | 
				
			||||||
cd_enable_run_log = os.environ.get('CD_ENABLE_RUN_LOG', 'false').lower() == 'true'
 | 
					cd_enable_run_log = os.environ.get('CD_ENABLE_RUN_LOG', 'false').lower() == 'true'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -653,6 +633,14 @@ async def send_json_override(self, event, data, sid=None):
 | 
				
			|||||||
        # await update_run_with_output(prompt_id, data)
 | 
					        # await update_run_with_output(prompt_id, data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if event == 'executed' and 'node' in data and 'output' in data:
 | 
					    if event == 'executed' and 'node' in data and 'output' in data:
 | 
				
			||||||
 | 
					        print("executed", data)
 | 
				
			||||||
 | 
					        if prompt_id in prompt_metadata:
 | 
				
			||||||
 | 
					            node = data.get('node')
 | 
				
			||||||
 | 
					            class_type = prompt_metadata[prompt_id].workflow_api[node]['class_type']
 | 
				
			||||||
 | 
					            print("skipping preview image")
 | 
				
			||||||
 | 
					            if class_type == "PreviewImage":
 | 
				
			||||||
 | 
					                return
 | 
				
			||||||
 | 
					            
 | 
				
			||||||
        await update_run_with_output(prompt_id, data.get('output'), node_id=data.get('node'))
 | 
					        await update_run_with_output(prompt_id, data.get('output'), node_id=data.get('node'))
 | 
				
			||||||
        # await update_run_with_output(prompt_id, data.get('output'), node_id=data.get('node'))
 | 
					        # await update_run_with_output(prompt_id, data.get('output'), node_id=data.get('node'))
 | 
				
			||||||
        # update_run_with_output(prompt_id, data.get('output'))
 | 
					        # update_run_with_output(prompt_id, data.get('output'))
 | 
				
			||||||
@ -892,6 +880,9 @@ async def update_file_status(prompt_id: str, data, uploading, have_error=False,
 | 
				
			|||||||
async def handle_upload(prompt_id: str, data, key: str, content_type_key: str, default_content_type: str):
 | 
					async def handle_upload(prompt_id: str, data, key: str, content_type_key: str, default_content_type: str):
 | 
				
			||||||
    items = data.get(key, [])
 | 
					    items = data.get(key, [])
 | 
				
			||||||
    for item in items:
 | 
					    for item in items:
 | 
				
			||||||
 | 
					        # # Skipping temp files
 | 
				
			||||||
 | 
					        # if item.get("type") == "temp":
 | 
				
			||||||
 | 
					        #     continue
 | 
				
			||||||
        await upload_file(
 | 
					        await upload_file(
 | 
				
			||||||
            prompt_id,
 | 
					            prompt_id,
 | 
				
			||||||
            item.get("filename"),
 | 
					            item.get("filename"),
 | 
				
			||||||
@ -900,7 +891,6 @@ async def handle_upload(prompt_id: str, data, key: str, content_type_key: str, d
 | 
				
			|||||||
            content_type=item.get(content_type_key, default_content_type)
 | 
					            content_type=item.get(content_type_key, default_content_type)
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
# Upload files in the background
 | 
					# Upload files in the background
 | 
				
			||||||
async def upload_in_background(prompt_id: str, data, node_id=None, have_upload=True):
 | 
					async def upload_in_background(prompt_id: str, data, node_id=None, have_upload=True):
 | 
				
			||||||
    try:
 | 
					    try:
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										24
									
								
								globals.py
									
									
									
									
									
								
							
							
						
						
									
										24
									
								
								globals.py
									
									
									
									
									
								
							@ -1,16 +1,21 @@
 | 
				
			|||||||
import struct
 | 
					import struct
 | 
				
			||||||
 | 
					from enum import Enum
 | 
				
			||||||
import aiohttp
 | 
					import aiohttp
 | 
				
			||||||
 | 
					 | 
				
			||||||
from typing import List, Union, Any, Optional
 | 
					from typing import List, Union, Any, Optional
 | 
				
			||||||
from PIL import Image, ImageOps
 | 
					from PIL import Image, ImageOps
 | 
				
			||||||
from io import BytesIO
 | 
					from io import BytesIO
 | 
				
			||||||
 | 
					 | 
				
			||||||
from pydantic import BaseModel as PydanticBaseModel
 | 
					from pydantic import BaseModel as PydanticBaseModel
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class BaseModel(PydanticBaseModel):
 | 
					class BaseModel(PydanticBaseModel):
 | 
				
			||||||
    class Config:
 | 
					    class Config:
 | 
				
			||||||
        arbitrary_types_allowed = True
 | 
					        arbitrary_types_allowed = True
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					class Status(Enum):
 | 
				
			||||||
 | 
					    NOT_STARTED = "not-started"
 | 
				
			||||||
 | 
					    RUNNING = "running"
 | 
				
			||||||
 | 
					    SUCCESS = "success"
 | 
				
			||||||
 | 
					    FAILED = "failed"
 | 
				
			||||||
 | 
					    UPLOADING = "uploading"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class StreamingPrompt(BaseModel):
 | 
					class StreamingPrompt(BaseModel):
 | 
				
			||||||
    workflow_api: Any
 | 
					    workflow_api: Any
 | 
				
			||||||
@ -19,8 +24,21 @@ class StreamingPrompt(BaseModel):
 | 
				
			|||||||
    running_prompt_ids: set[str] = set()
 | 
					    running_prompt_ids: set[str] = set()
 | 
				
			||||||
    status_endpoint: str
 | 
					    status_endpoint: str
 | 
				
			||||||
    file_upload_endpoint: str
 | 
					    file_upload_endpoint: str
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					class SimplePrompt(BaseModel):
 | 
				
			||||||
 | 
					    status_endpoint: str
 | 
				
			||||||
 | 
					    file_upload_endpoint: str
 | 
				
			||||||
 | 
					    workflow_api: dict
 | 
				
			||||||
 | 
					    status: Status = Status.NOT_STARTED
 | 
				
			||||||
 | 
					    progress: set = set()
 | 
				
			||||||
 | 
					    last_updated_node: Optional[str] = None,
 | 
				
			||||||
 | 
					    uploading_nodes: set = set()
 | 
				
			||||||
 | 
					    done: bool = False
 | 
				
			||||||
 | 
					    is_realtime: bool = False,
 | 
				
			||||||
 | 
					    start_time: Optional[float] = None,
 | 
				
			||||||
 | 
					
 | 
				
			||||||
sockets = dict()
 | 
					sockets = dict()
 | 
				
			||||||
 | 
					prompt_metadata: dict[str, SimplePrompt] = {}
 | 
				
			||||||
streaming_prompt_metadata: dict[str, StreamingPrompt] = {}
 | 
					streaming_prompt_metadata: dict[str, StreamingPrompt] = {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class BinaryEventTypes:
 | 
					class BinaryEventTypes:
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user