refactor(plugin): add prompt_metadata types and refactor from dict to data model
This commit is contained in:
		
							parent
							
								
									9d0ded7ecc
								
							
						
					
					
						commit
						25e62af24c
					
				
							
								
								
									
										226
									
								
								custom_routes.py
									
									
									
									
									
								
							
							
						
						
									
										226
									
								
								custom_routes.py
									
									
									
									
									
								
							@ -1,34 +1,24 @@
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
from aiohttp import web
 | 
			
		||||
import os
 | 
			
		||||
import requests
 | 
			
		||||
import folder_paths
 | 
			
		||||
import json
 | 
			
		||||
import numpy as np
 | 
			
		||||
import server
 | 
			
		||||
import re
 | 
			
		||||
import base64
 | 
			
		||||
from PIL import Image
 | 
			
		||||
import io
 | 
			
		||||
import time
 | 
			
		||||
import execution
 | 
			
		||||
import random
 | 
			
		||||
import traceback
 | 
			
		||||
import uuid
 | 
			
		||||
import asyncio
 | 
			
		||||
import atexit
 | 
			
		||||
import logging
 | 
			
		||||
import sys
 | 
			
		||||
from logging.handlers import RotatingFileHandler
 | 
			
		||||
from enum import Enum
 | 
			
		||||
from urllib.parse import quote
 | 
			
		||||
import threading
 | 
			
		||||
import hashlib
 | 
			
		||||
import aiohttp
 | 
			
		||||
import aiofiles
 | 
			
		||||
import concurrent.futures
 | 
			
		||||
from typing import List, Union, Any
 | 
			
		||||
from typing import List, Union, Any, Optional
 | 
			
		||||
from PIL import Image
 | 
			
		||||
import copy
 | 
			
		||||
 | 
			
		||||
@ -39,17 +29,37 @@ from pydantic import BaseModel as PydanticBaseModel
 | 
			
		||||
class BaseModel(PydanticBaseModel):
 | 
			
		||||
    class Config:
 | 
			
		||||
        arbitrary_types_allowed = True
 | 
			
		||||
        
 | 
			
		||||
class Status(Enum):
 | 
			
		||||
    NOT_STARTED = "not-started"
 | 
			
		||||
    RUNNING = "running"
 | 
			
		||||
    SUCCESS = "success"
 | 
			
		||||
    FAILED = "failed"
 | 
			
		||||
    UPLOADING = "uploading"
 | 
			
		||||
 | 
			
		||||
class StreamingPrompt(BaseModel):
 | 
			
		||||
    workflow_api: Any
 | 
			
		||||
    auth_token: str
 | 
			
		||||
    inputs: dict[str, Union[str, bytes, Image.Image]]
 | 
			
		||||
    running_prompt_ids: set[str] = set()
 | 
			
		||||
    status_endpoint: str
 | 
			
		||||
    file_upload_endpoint: str
 | 
			
		||||
    
 | 
			
		||||
class SimplePrompt(BaseModel):
 | 
			
		||||
    status_endpoint: str
 | 
			
		||||
    file_upload_endpoint: str
 | 
			
		||||
    workflow_api: dict
 | 
			
		||||
    status: Status = Status.NOT_STARTED
 | 
			
		||||
    progress: set = set()
 | 
			
		||||
    last_updated_node: Optional[str] = None,
 | 
			
		||||
    uploading_nodes: set = set()
 | 
			
		||||
    done: bool = False
 | 
			
		||||
    
 | 
			
		||||
streaming_prompt_metadata: dict[str, StreamingPrompt] = {}
 | 
			
		||||
 | 
			
		||||
api = None
 | 
			
		||||
api_task = None
 | 
			
		||||
prompt_metadata = {}
 | 
			
		||||
prompt_metadata: dict[str, SimplePrompt] = {}
 | 
			
		||||
cd_enable_log = os.environ.get('CD_ENABLE_LOG', 'false').lower() == 'true'
 | 
			
		||||
cd_enable_run_log = os.environ.get('CD_ENABLE_RUN_LOG', 'false').lower() == 'true'
 | 
			
		||||
 | 
			
		||||
@ -109,7 +119,7 @@ def send_prompt(sid: str, inputs: StreamingPrompt):
 | 
			
		||||
        if 'inputs' in workflow_api[key] and 'seed' in workflow_api[key]['inputs']:
 | 
			
		||||
            workflow_api[key]['inputs']['seed'] = randomSeed()
 | 
			
		||||
            
 | 
			
		||||
    print("getting inputs" ,inputs.inputs)
 | 
			
		||||
    print("getting inputs" , inputs.inputs)
 | 
			
		||||
            
 | 
			
		||||
    # Loop through each of the inputs and replace them
 | 
			
		||||
    for key, value in workflow_api.items():
 | 
			
		||||
@ -137,6 +147,12 @@ def send_prompt(sid: str, inputs: StreamingPrompt):
 | 
			
		||||
    
 | 
			
		||||
    try:
 | 
			
		||||
        res = post_prompt(prompt)
 | 
			
		||||
        inputs.running_prompt_ids.add(prompt_id)
 | 
			
		||||
        prompt_metadata[prompt_id] = SimplePrompt(
 | 
			
		||||
            status_endpoint=inputs.status_endpoint,
 | 
			
		||||
            file_upload_endpoint=inputs.file_upload_endpoint,
 | 
			
		||||
            workflow_api=workflow_api
 | 
			
		||||
        )
 | 
			
		||||
    except Exception as e:
 | 
			
		||||
        error_type = type(e).__name__
 | 
			
		||||
        stack_trace_short = traceback.format_exc().strip().split('\n')[-2]
 | 
			
		||||
@ -164,11 +180,11 @@ async def comfy_deploy_run(request):
 | 
			
		||||
        "prompt_id": prompt_id
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    prompt_metadata[prompt_id] = {
 | 
			
		||||
        'status_endpoint': data.get('status_endpoint'),
 | 
			
		||||
        'file_upload_endpoint': data.get('file_upload_endpoint'),
 | 
			
		||||
        'workflow_api': workflow_api
 | 
			
		||||
    }
 | 
			
		||||
    prompt_metadata[prompt_id] = SimplePrompt(
 | 
			
		||||
        status_endpoint=data.get('status_endpoint'),
 | 
			
		||||
        file_upload_endpoint=data.get('file_upload_endpoint'),
 | 
			
		||||
        workflow_api=workflow_api
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    try:
 | 
			
		||||
        res = post_prompt(prompt)
 | 
			
		||||
@ -244,7 +260,7 @@ async def compute_sha256_checksum(filepath):
 | 
			
		||||
 | 
			
		||||
# This is start uploading the files to Comfy Deploy
 | 
			
		||||
@server.PromptServer.instance.routes.post('/comfyui-deploy/upload-file')
 | 
			
		||||
async def upload_file(request):
 | 
			
		||||
async def upload_file_endpoint(request):
 | 
			
		||||
    data = await request.json()
 | 
			
		||||
 | 
			
		||||
    file_path = data.get("file_path")
 | 
			
		||||
@ -367,28 +383,31 @@ async def websocket_handler(request):
 | 
			
		||||
 | 
			
		||||
    sockets[sid] = ws
 | 
			
		||||
    
 | 
			
		||||
    auth_token = request.rel_url.query.get('token', '')
 | 
			
		||||
    get_workflow_endpoint_url = request.rel_url.query.get('workflow_endpoint', '')
 | 
			
		||||
    auth_token = request.rel_url.query.get('token', None)
 | 
			
		||||
    get_workflow_endpoint_url = request.rel_url.query.get('workflow_endpoint', None)
 | 
			
		||||
    
 | 
			
		||||
    async with aiohttp.ClientSession() as session:
 | 
			
		||||
        headers = {'Authorization': f'Bearer {auth_token}'}
 | 
			
		||||
        async with session.get(get_workflow_endpoint_url, headers=headers) as response:
 | 
			
		||||
            if response.status == 200:
 | 
			
		||||
                workflow = await response.json()
 | 
			
		||||
                
 | 
			
		||||
                print("Loaded workflow version ",workflow["version"])
 | 
			
		||||
                
 | 
			
		||||
                streaming_prompt_metadata[sid] = StreamingPrompt(
 | 
			
		||||
                    workflow_api=workflow["workflow_api"], 
 | 
			
		||||
                    auth_token=auth_token,
 | 
			
		||||
                    inputs={}
 | 
			
		||||
                )
 | 
			
		||||
                
 | 
			
		||||
                # await send("workflow_api", workflow_api, sid)
 | 
			
		||||
            else:
 | 
			
		||||
                error_message = await response.text()
 | 
			
		||||
                print(f"Failed to fetch workflow endpoint. Status: {response.status}, Error: {error_message}")
 | 
			
		||||
                # await send("error", {"message": error_message}, sid)
 | 
			
		||||
    if auth_token is not None and get_workflow_endpoint_url is not None:
 | 
			
		||||
        async with aiohttp.ClientSession() as session:
 | 
			
		||||
            headers = {'Authorization': f'Bearer {auth_token}'}
 | 
			
		||||
            async with session.get(get_workflow_endpoint_url, headers=headers) as response:
 | 
			
		||||
                if response.status == 200:
 | 
			
		||||
                    workflow = await response.json()
 | 
			
		||||
                   
 | 
			
		||||
                    print("Loaded workflow version ",workflow["version"])
 | 
			
		||||
                   
 | 
			
		||||
                    streaming_prompt_metadata[sid] = StreamingPrompt(
 | 
			
		||||
                        workflow_api=workflow["workflow_api"], 
 | 
			
		||||
                        auth_token=auth_token,
 | 
			
		||||
                        inputs={},
 | 
			
		||||
                        status_endpoint=request.rel_url.query.get('status_endpoint', None),
 | 
			
		||||
                        file_upload_endpoint=request.rel_url.query.get('file_upload_endpoint', None),
 | 
			
		||||
                    )
 | 
			
		||||
                    
 | 
			
		||||
                    # await send("workflow_api", workflow_api, sid)
 | 
			
		||||
                else:
 | 
			
		||||
                    error_message = await response.text()
 | 
			
		||||
                    print(f"Failed to fetch workflow endpoint. Status: {response.status}, Error: {error_message}")
 | 
			
		||||
                    # await send("error", {"message": error_message}, sid)
 | 
			
		||||
 | 
			
		||||
    try:
 | 
			
		||||
        # Send initial state to the new client
 | 
			
		||||
@ -422,16 +441,28 @@ async def websocket_handler(request):
 | 
			
		||||
 | 
			
		||||
@server.PromptServer.instance.routes.get('/comfyui-deploy/check-status')
 | 
			
		||||
async def comfy_deploy_check_status(request):
 | 
			
		||||
    prompt_server = server.PromptServer.instance
 | 
			
		||||
    prompt_id = request.rel_url.query.get('prompt_id', None)
 | 
			
		||||
    if prompt_id in prompt_metadata and 'status' in prompt_metadata[prompt_id]:
 | 
			
		||||
    if prompt_id in prompt_metadata:
 | 
			
		||||
        return web.json_response({
 | 
			
		||||
            "status": prompt_metadata[prompt_id]['status'].value
 | 
			
		||||
            "status": prompt_metadata[prompt_id].status.value
 | 
			
		||||
        })
 | 
			
		||||
    else:
 | 
			
		||||
        return web.json_response({
 | 
			
		||||
            "message": "prompt_id not found"
 | 
			
		||||
        })
 | 
			
		||||
        
 | 
			
		||||
@server.PromptServer.instance.routes.get('/comfyui-deploy/check-ws-status')
 | 
			
		||||
async def comfy_deploy_check_ws_status(request):
 | 
			
		||||
    client_id = request.rel_url.query.get('client_id', None)
 | 
			
		||||
    if client_id in streaming_prompt_metadata:
 | 
			
		||||
        remaining_queue = 0  # Initialize remaining queue count
 | 
			
		||||
        for prompt_id in streaming_prompt_metadata[client_id].running_prompt_ids[client_id]:
 | 
			
		||||
            prompt_status = prompt_metadata[prompt_id].status
 | 
			
		||||
            if prompt_status not in [Status.FAILED, Status.SUCCESS]:
 | 
			
		||||
                remaining_queue += 1  # Increment for each prompt still running
 | 
			
		||||
        return web.json_response({"remaining_queue": remaining_queue})
 | 
			
		||||
    else:
 | 
			
		||||
        return web.json_response({"message": "client_id not found"}, status=404)
 | 
			
		||||
 | 
			
		||||
async def send(event, data, sid=None):
 | 
			
		||||
    try:
 | 
			
		||||
@ -451,8 +482,8 @@ async def send(event, data, sid=None):
 | 
			
		||||
logging.basicConfig(level=logging.INFO)
 | 
			
		||||
 | 
			
		||||
prompt_server = server.PromptServer.instance
 | 
			
		||||
 | 
			
		||||
send_json = prompt_server.send_json
 | 
			
		||||
 | 
			
		||||
async def send_json_override(self, event, data, sid=None):
 | 
			
		||||
    # print("INTERNAL:", event, data, sid)
 | 
			
		||||
    prompt_id = data.get('prompt_id')
 | 
			
		||||
@ -476,28 +507,28 @@ async def send_json_override(self, event, data, sid=None):
 | 
			
		||||
        node = data.get('node')
 | 
			
		||||
        
 | 
			
		||||
        if prompt_id in prompt_metadata:
 | 
			
		||||
            if 'progress' not in prompt_metadata[prompt_id]:
 | 
			
		||||
                prompt_metadata[prompt_id]["progress"] = set()
 | 
			
		||||
                
 | 
			
		||||
            prompt_metadata[prompt_id]["progress"].add(node)
 | 
			
		||||
            calculated_progress = len(prompt_metadata[prompt_id]["progress"]) / len(prompt_metadata[prompt_id]['workflow_api'])
 | 
			
		||||
            # if 'progress' not in prompt_metadata[prompt_id]:
 | 
			
		||||
            #     prompt_metadata[prompt_id]["progress"] = set()
 | 
			
		||||
              
 | 
			
		||||
            prompt_metadata[prompt_id].progress.add(node)
 | 
			
		||||
            calculated_progress = len(prompt_metadata[prompt_id].progress) / len(prompt_metadata[prompt_id].workflow_api)
 | 
			
		||||
            # print("calculated_progress", calculated_progress)
 | 
			
		||||
            
 | 
			
		||||
            if 'last_updated_node' in prompt_metadata[prompt_id] and prompt_metadata[prompt_id]['last_updated_node'] == node:
 | 
			
		||||
         
 | 
			
		||||
            if prompt_metadata[prompt_id].last_updated_node is not None and prompt_metadata[prompt_id].last_updated_node == node:
 | 
			
		||||
                return
 | 
			
		||||
            prompt_metadata[prompt_id]['last_updated_node'] = node
 | 
			
		||||
            class_type = prompt_metadata[prompt_id]['workflow_api'][node]['class_type']
 | 
			
		||||
            prompt_metadata[prompt_id].last_updated_node = node
 | 
			
		||||
            class_type = prompt_metadata[prompt_id].workflow_api[node]['class_type']
 | 
			
		||||
            print("updating run live status", class_type)
 | 
			
		||||
            await update_run_live_status(prompt_id, "Executing " + class_type, calculated_progress)
 | 
			
		||||
 | 
			
		||||
    if event == 'execution_cached' and data.get('nodes') is not None:
 | 
			
		||||
        if prompt_id in prompt_metadata:
 | 
			
		||||
            if 'progress' not in prompt_metadata[prompt_id]:
 | 
			
		||||
                prompt_metadata[prompt_id]["progress"] = set()
 | 
			
		||||
            
 | 
			
		||||
            # if 'progress' not in prompt_metadata[prompt_id]:
 | 
			
		||||
            #     prompt_metadata[prompt_id].progress = set()
 | 
			
		||||
           
 | 
			
		||||
            if 'nodes' in data:
 | 
			
		||||
                for node in data.get('nodes', []):
 | 
			
		||||
                    prompt_metadata[prompt_id]["progress"].add(node)
 | 
			
		||||
                    prompt_metadata[prompt_id].progress.add(node)
 | 
			
		||||
            # prompt_metadata[prompt_id]["progress"].update(data.get('nodes'))
 | 
			
		||||
 | 
			
		||||
    if event == 'execution_error':
 | 
			
		||||
@ -511,13 +542,6 @@ async def send_json_override(self, event, data, sid=None):
 | 
			
		||||
        # await update_run_with_output(prompt_id, data.get('output'), node_id=data.get('node'))
 | 
			
		||||
        # update_run_with_output(prompt_id, data.get('output'))
 | 
			
		||||
 | 
			
		||||
class Status(Enum):
 | 
			
		||||
    NOT_STARTED = "not-started"
 | 
			
		||||
    RUNNING = "running"
 | 
			
		||||
    SUCCESS = "success"
 | 
			
		||||
    FAILED = "failed"
 | 
			
		||||
    UPLOADING = "uploading"
 | 
			
		||||
 | 
			
		||||
# Global variable to keep track of the last read line number
 | 
			
		||||
last_read_line_number = 0
 | 
			
		||||
 | 
			
		||||
@ -527,7 +551,7 @@ async def update_run_live_status(prompt_id, live_status, calculated_progress: fl
 | 
			
		||||
    
 | 
			
		||||
    print("progress", calculated_progress)
 | 
			
		||||
    
 | 
			
		||||
    status_endpoint = prompt_metadata[prompt_id]['status_endpoint']
 | 
			
		||||
    status_endpoint = prompt_metadata[prompt_id].status_endpoint
 | 
			
		||||
    body = {
 | 
			
		||||
        "run_id": prompt_id,
 | 
			
		||||
        "live_status": live_status,
 | 
			
		||||
@ -539,19 +563,19 @@ async def update_run_live_status(prompt_id, live_status, calculated_progress: fl
 | 
			
		||||
            pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def update_run(prompt_id, status: Status):
 | 
			
		||||
def update_run(prompt_id: str, status: Status):
 | 
			
		||||
    global last_read_line_number
 | 
			
		||||
 | 
			
		||||
    if prompt_id not in prompt_metadata:
 | 
			
		||||
        return
 | 
			
		||||
 | 
			
		||||
    if ('status' not in prompt_metadata[prompt_id] or prompt_metadata[prompt_id]['status'] != status):
 | 
			
		||||
    if (prompt_metadata[prompt_id].status != status):
 | 
			
		||||
 | 
			
		||||
        # when the status is already failed, we don't want to update it to success
 | 
			
		||||
        if ('status' in prompt_metadata[prompt_id] and prompt_metadata[prompt_id]['status'] == Status.FAILED):
 | 
			
		||||
        if ('status' in prompt_metadata[prompt_id] and prompt_metadata[prompt_id].status == Status.FAILED):
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
        status_endpoint = prompt_metadata[prompt_id]['status_endpoint']
 | 
			
		||||
        status_endpoint = prompt_metadata[prompt_id].status_endpoint
 | 
			
		||||
        body = {
 | 
			
		||||
            "run_id": prompt_id,
 | 
			
		||||
            "status": status.value,
 | 
			
		||||
@ -598,7 +622,7 @@ def update_run(prompt_id, status: Status):
 | 
			
		||||
            stack_trace = traceback.format_exc().strip()
 | 
			
		||||
            print(f"Error occurred while updating run: {e} {stack_trace}")
 | 
			
		||||
        finally:
 | 
			
		||||
            prompt_metadata[prompt_id]['status'] = status
 | 
			
		||||
            prompt_metadata[prompt_id].status = status
 | 
			
		||||
            
 | 
			
		||||
 | 
			
		||||
async def upload_file(prompt_id, filename, subfolder=None, content_type="image/png", type="output"):
 | 
			
		||||
@ -630,7 +654,7 @@ async def upload_file(prompt_id, filename, subfolder=None, content_type="image/p
 | 
			
		||||
 | 
			
		||||
    print("uploading file", file)
 | 
			
		||||
 | 
			
		||||
    file_upload_endpoint = prompt_metadata[prompt_id]['file_upload_endpoint']
 | 
			
		||||
    file_upload_endpoint = prompt_metadata[prompt_id].file_upload_endpoint
 | 
			
		||||
 | 
			
		||||
    filename = quote(filename)
 | 
			
		||||
    prompt_id = quote(prompt_id)
 | 
			
		||||
@ -654,23 +678,37 @@ async def upload_file(prompt_id, filename, subfolder=None, content_type="image/p
 | 
			
		||||
                print("upload file response", response.status)
 | 
			
		||||
 | 
			
		||||
def have_pending_upload(prompt_id):
 | 
			
		||||
    if 'prompt_id' in prompt_metadata and 'uploading_nodes' in prompt_metadata[prompt_id] and len(prompt_metadata[prompt_id]['uploading_nodes']) > 0:
 | 
			
		||||
        print("have pending upload ", len(prompt_metadata[prompt_id]['uploading_nodes']))
 | 
			
		||||
    if prompt_id in prompt_metadata and len(prompt_metadata[prompt_id].uploading_nodes) > 0:
 | 
			
		||||
        print("have pending upload ", len(prompt_metadata[prompt_id].uploading_nodes))
 | 
			
		||||
        return True
 | 
			
		||||
 | 
			
		||||
    print("no pending upload")
 | 
			
		||||
    return False
 | 
			
		||||
 | 
			
		||||
def mark_prompt_done(prompt_id):
 | 
			
		||||
    """
 | 
			
		||||
    Mark the prompt as done in the prompt metadata.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        prompt_id (str): The ID of the prompt to mark as done.
 | 
			
		||||
    """
 | 
			
		||||
    if prompt_id in prompt_metadata:
 | 
			
		||||
        prompt_metadata[prompt_id]["done"] = True
 | 
			
		||||
        prompt_metadata[prompt_id].done = True
 | 
			
		||||
        print("Prompt done")
 | 
			
		||||
 | 
			
		||||
def is_prompt_done(prompt_id):
 | 
			
		||||
    if prompt_id in prompt_metadata and "done" in prompt_metadata[prompt_id]:
 | 
			
		||||
        if prompt_metadata[prompt_id]["done"] == True:
 | 
			
		||||
            return True
 | 
			
		||||
    
 | 
			
		||||
def is_prompt_done(prompt_id: str):
 | 
			
		||||
    """
 | 
			
		||||
    Check if the prompt with the given ID is marked as done.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        prompt_id (str): The ID of the prompt to check.
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        bool: True if the prompt is marked as done, False otherwise.
 | 
			
		||||
    """
 | 
			
		||||
    if prompt_id in prompt_metadata and prompt_metadata[prompt_id].done:
 | 
			
		||||
        return True
 | 
			
		||||
 | 
			
		||||
    return False
 | 
			
		||||
 | 
			
		||||
# Use to handle upload error and send back to ComfyDeploy
 | 
			
		||||
@ -692,17 +730,17 @@ async def handle_error(prompt_id, data, e: Exception):
 | 
			
		||||
    print(f"Error occurred while uploading file: {e}")
 | 
			
		||||
 | 
			
		||||
# Mark the current prompt requires upload, and block it from being marked as success
 | 
			
		||||
async def update_file_status(prompt_id, data, uploading, have_error=False, node_id=None):
 | 
			
		||||
    if 'uploading_nodes' not in prompt_metadata[prompt_id]:
 | 
			
		||||
        prompt_metadata[prompt_id]['uploading_nodes'] = set()
 | 
			
		||||
async def update_file_status(prompt_id: str, data, uploading, have_error=False, node_id=None):
 | 
			
		||||
    # if 'uploading_nodes' not in prompt_metadata[prompt_id]:
 | 
			
		||||
    #     prompt_metadata[prompt_id]['uploading_nodes'] = set()
 | 
			
		||||
 | 
			
		||||
    if node_id is not None:
 | 
			
		||||
        if uploading:
 | 
			
		||||
            prompt_metadata[prompt_id]['uploading_nodes'].add(node_id)
 | 
			
		||||
            prompt_metadata[prompt_id].uploading_nodes.add(node_id)
 | 
			
		||||
        else:
 | 
			
		||||
            prompt_metadata[prompt_id]['uploading_nodes'].discard(node_id)
 | 
			
		||||
            prompt_metadata[prompt_id].uploading_nodes.discard(node_id)
 | 
			
		||||
 | 
			
		||||
    print(prompt_metadata[prompt_id]['uploading_nodes'])
 | 
			
		||||
    print(prompt_metadata[prompt_id].uploading_nodes)
 | 
			
		||||
    # Update the remote status
 | 
			
		||||
 | 
			
		||||
    if have_error:
 | 
			
		||||
@ -714,12 +752,12 @@ async def update_file_status(prompt_id, data, uploading, have_error=False, node_
 | 
			
		||||
 | 
			
		||||
    # if there are still nodes that are uploading, then we set the status to uploading
 | 
			
		||||
    if uploading:
 | 
			
		||||
        if prompt_metadata[prompt_id]['status'] != Status.UPLOADING:
 | 
			
		||||
        if prompt_metadata[prompt_id].status != Status.UPLOADING:
 | 
			
		||||
            update_run(prompt_id, Status.UPLOADING)
 | 
			
		||||
            await send("uploading", {
 | 
			
		||||
                "prompt_id": prompt_id,
 | 
			
		||||
            })
 | 
			
		||||
    
 | 
			
		||||
   
 | 
			
		||||
    # if there are no nodes that are uploading, then we set the status to success
 | 
			
		||||
    elif not uploading and not have_pending_upload(prompt_id) and is_prompt_done(prompt_id=prompt_id):
 | 
			
		||||
        update_run(prompt_id, Status.SUCCESS)
 | 
			
		||||
@ -728,20 +766,20 @@ async def update_file_status(prompt_id, data, uploading, have_error=False, node_
 | 
			
		||||
            "prompt_id": prompt_id,
 | 
			
		||||
        })
 | 
			
		||||
 | 
			
		||||
async def handle_upload(prompt_id, data, key, content_type_key, default_content_type):
 | 
			
		||||
async def handle_upload(prompt_id: str, data, key: str, content_type_key: str, default_content_type: str):
 | 
			
		||||
    items = data.get(key, [])
 | 
			
		||||
    for item in items:
 | 
			
		||||
        await upload_file(
 | 
			
		||||
            prompt_id, 
 | 
			
		||||
            item.get("filename"), 
 | 
			
		||||
            subfolder=item.get("subfolder"), 
 | 
			
		||||
            type=item.get("type"), 
 | 
			
		||||
            prompt_id,
 | 
			
		||||
            item.get("filename"),
 | 
			
		||||
            subfolder=item.get("subfolder"),
 | 
			
		||||
            type=item.get("type"),
 | 
			
		||||
            content_type=item.get(content_type_key, default_content_type)
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Upload files in the background
 | 
			
		||||
async def upload_in_background(prompt_id, data, node_id=None, have_upload=True):
 | 
			
		||||
async def upload_in_background(prompt_id: str, data, node_id=None, have_upload=True):
 | 
			
		||||
    try:
 | 
			
		||||
        await handle_upload(prompt_id, data, 'images', "content_type", "image/png")
 | 
			
		||||
        await handle_upload(prompt_id, data, 'files', "content_type", "image/png")
 | 
			
		||||
@ -755,7 +793,7 @@ async def upload_in_background(prompt_id, data, node_id=None, have_upload=True):
 | 
			
		||||
 | 
			
		||||
async def update_run_with_output(prompt_id, data, node_id=None):
 | 
			
		||||
    if prompt_id in prompt_metadata:
 | 
			
		||||
        status_endpoint = prompt_metadata[prompt_id]['status_endpoint']
 | 
			
		||||
        status_endpoint = prompt_metadata[prompt_id].status_endpoint
 | 
			
		||||
 | 
			
		||||
        body = {
 | 
			
		||||
            "run_id": prompt_id,
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user