refactor(plugin): add prompt_metadata types and refactor from dict to data model
This commit is contained in:
parent
9d0ded7ecc
commit
25e62af24c
226
custom_routes.py
226
custom_routes.py
@ -1,34 +1,24 @@
|
|||||||
|
|
||||||
|
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
import os
|
import os
|
||||||
import requests
|
import requests
|
||||||
import folder_paths
|
import folder_paths
|
||||||
import json
|
import json
|
||||||
import numpy as np
|
|
||||||
import server
|
import server
|
||||||
import re
|
|
||||||
import base64
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import io
|
|
||||||
import time
|
import time
|
||||||
import execution
|
import execution
|
||||||
import random
|
import random
|
||||||
import traceback
|
import traceback
|
||||||
import uuid
|
import uuid
|
||||||
import asyncio
|
import asyncio
|
||||||
import atexit
|
|
||||||
import logging
|
import logging
|
||||||
import sys
|
|
||||||
from logging.handlers import RotatingFileHandler
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from urllib.parse import quote
|
from urllib.parse import quote
|
||||||
import threading
|
import threading
|
||||||
import hashlib
|
import hashlib
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import aiofiles
|
import aiofiles
|
||||||
import concurrent.futures
|
from typing import List, Union, Any, Optional
|
||||||
from typing import List, Union, Any
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import copy
|
import copy
|
||||||
|
|
||||||
@ -39,17 +29,37 @@ from pydantic import BaseModel as PydanticBaseModel
|
|||||||
class BaseModel(PydanticBaseModel):
|
class BaseModel(PydanticBaseModel):
|
||||||
class Config:
|
class Config:
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
class Status(Enum):
|
||||||
|
NOT_STARTED = "not-started"
|
||||||
|
RUNNING = "running"
|
||||||
|
SUCCESS = "success"
|
||||||
|
FAILED = "failed"
|
||||||
|
UPLOADING = "uploading"
|
||||||
|
|
||||||
class StreamingPrompt(BaseModel):
|
class StreamingPrompt(BaseModel):
|
||||||
workflow_api: Any
|
workflow_api: Any
|
||||||
auth_token: str
|
auth_token: str
|
||||||
inputs: dict[str, Union[str, bytes, Image.Image]]
|
inputs: dict[str, Union[str, bytes, Image.Image]]
|
||||||
|
running_prompt_ids: set[str] = set()
|
||||||
|
status_endpoint: str
|
||||||
|
file_upload_endpoint: str
|
||||||
|
|
||||||
|
class SimplePrompt(BaseModel):
|
||||||
|
status_endpoint: str
|
||||||
|
file_upload_endpoint: str
|
||||||
|
workflow_api: dict
|
||||||
|
status: Status = Status.NOT_STARTED
|
||||||
|
progress: set = set()
|
||||||
|
last_updated_node: Optional[str] = None,
|
||||||
|
uploading_nodes: set = set()
|
||||||
|
done: bool = False
|
||||||
|
|
||||||
streaming_prompt_metadata: dict[str, StreamingPrompt] = {}
|
streaming_prompt_metadata: dict[str, StreamingPrompt] = {}
|
||||||
|
|
||||||
api = None
|
api = None
|
||||||
api_task = None
|
api_task = None
|
||||||
prompt_metadata = {}
|
prompt_metadata: dict[str, SimplePrompt] = {}
|
||||||
cd_enable_log = os.environ.get('CD_ENABLE_LOG', 'false').lower() == 'true'
|
cd_enable_log = os.environ.get('CD_ENABLE_LOG', 'false').lower() == 'true'
|
||||||
cd_enable_run_log = os.environ.get('CD_ENABLE_RUN_LOG', 'false').lower() == 'true'
|
cd_enable_run_log = os.environ.get('CD_ENABLE_RUN_LOG', 'false').lower() == 'true'
|
||||||
|
|
||||||
@ -109,7 +119,7 @@ def send_prompt(sid: str, inputs: StreamingPrompt):
|
|||||||
if 'inputs' in workflow_api[key] and 'seed' in workflow_api[key]['inputs']:
|
if 'inputs' in workflow_api[key] and 'seed' in workflow_api[key]['inputs']:
|
||||||
workflow_api[key]['inputs']['seed'] = randomSeed()
|
workflow_api[key]['inputs']['seed'] = randomSeed()
|
||||||
|
|
||||||
print("getting inputs" ,inputs.inputs)
|
print("getting inputs" , inputs.inputs)
|
||||||
|
|
||||||
# Loop through each of the inputs and replace them
|
# Loop through each of the inputs and replace them
|
||||||
for key, value in workflow_api.items():
|
for key, value in workflow_api.items():
|
||||||
@ -137,6 +147,12 @@ def send_prompt(sid: str, inputs: StreamingPrompt):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
res = post_prompt(prompt)
|
res = post_prompt(prompt)
|
||||||
|
inputs.running_prompt_ids.add(prompt_id)
|
||||||
|
prompt_metadata[prompt_id] = SimplePrompt(
|
||||||
|
status_endpoint=inputs.status_endpoint,
|
||||||
|
file_upload_endpoint=inputs.file_upload_endpoint,
|
||||||
|
workflow_api=workflow_api
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_type = type(e).__name__
|
error_type = type(e).__name__
|
||||||
stack_trace_short = traceback.format_exc().strip().split('\n')[-2]
|
stack_trace_short = traceback.format_exc().strip().split('\n')[-2]
|
||||||
@ -164,11 +180,11 @@ async def comfy_deploy_run(request):
|
|||||||
"prompt_id": prompt_id
|
"prompt_id": prompt_id
|
||||||
}
|
}
|
||||||
|
|
||||||
prompt_metadata[prompt_id] = {
|
prompt_metadata[prompt_id] = SimplePrompt(
|
||||||
'status_endpoint': data.get('status_endpoint'),
|
status_endpoint=data.get('status_endpoint'),
|
||||||
'file_upload_endpoint': data.get('file_upload_endpoint'),
|
file_upload_endpoint=data.get('file_upload_endpoint'),
|
||||||
'workflow_api': workflow_api
|
workflow_api=workflow_api
|
||||||
}
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
res = post_prompt(prompt)
|
res = post_prompt(prompt)
|
||||||
@ -244,7 +260,7 @@ async def compute_sha256_checksum(filepath):
|
|||||||
|
|
||||||
# This is start uploading the files to Comfy Deploy
|
# This is start uploading the files to Comfy Deploy
|
||||||
@server.PromptServer.instance.routes.post('/comfyui-deploy/upload-file')
|
@server.PromptServer.instance.routes.post('/comfyui-deploy/upload-file')
|
||||||
async def upload_file(request):
|
async def upload_file_endpoint(request):
|
||||||
data = await request.json()
|
data = await request.json()
|
||||||
|
|
||||||
file_path = data.get("file_path")
|
file_path = data.get("file_path")
|
||||||
@ -367,28 +383,31 @@ async def websocket_handler(request):
|
|||||||
|
|
||||||
sockets[sid] = ws
|
sockets[sid] = ws
|
||||||
|
|
||||||
auth_token = request.rel_url.query.get('token', '')
|
auth_token = request.rel_url.query.get('token', None)
|
||||||
get_workflow_endpoint_url = request.rel_url.query.get('workflow_endpoint', '')
|
get_workflow_endpoint_url = request.rel_url.query.get('workflow_endpoint', None)
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
if auth_token is not None and get_workflow_endpoint_url is not None:
|
||||||
headers = {'Authorization': f'Bearer {auth_token}'}
|
async with aiohttp.ClientSession() as session:
|
||||||
async with session.get(get_workflow_endpoint_url, headers=headers) as response:
|
headers = {'Authorization': f'Bearer {auth_token}'}
|
||||||
if response.status == 200:
|
async with session.get(get_workflow_endpoint_url, headers=headers) as response:
|
||||||
workflow = await response.json()
|
if response.status == 200:
|
||||||
|
workflow = await response.json()
|
||||||
print("Loaded workflow version ",workflow["version"])
|
|
||||||
|
print("Loaded workflow version ",workflow["version"])
|
||||||
streaming_prompt_metadata[sid] = StreamingPrompt(
|
|
||||||
workflow_api=workflow["workflow_api"],
|
streaming_prompt_metadata[sid] = StreamingPrompt(
|
||||||
auth_token=auth_token,
|
workflow_api=workflow["workflow_api"],
|
||||||
inputs={}
|
auth_token=auth_token,
|
||||||
)
|
inputs={},
|
||||||
|
status_endpoint=request.rel_url.query.get('status_endpoint', None),
|
||||||
# await send("workflow_api", workflow_api, sid)
|
file_upload_endpoint=request.rel_url.query.get('file_upload_endpoint', None),
|
||||||
else:
|
)
|
||||||
error_message = await response.text()
|
|
||||||
print(f"Failed to fetch workflow endpoint. Status: {response.status}, Error: {error_message}")
|
# await send("workflow_api", workflow_api, sid)
|
||||||
# await send("error", {"message": error_message}, sid)
|
else:
|
||||||
|
error_message = await response.text()
|
||||||
|
print(f"Failed to fetch workflow endpoint. Status: {response.status}, Error: {error_message}")
|
||||||
|
# await send("error", {"message": error_message}, sid)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Send initial state to the new client
|
# Send initial state to the new client
|
||||||
@ -422,16 +441,28 @@ async def websocket_handler(request):
|
|||||||
|
|
||||||
@server.PromptServer.instance.routes.get('/comfyui-deploy/check-status')
|
@server.PromptServer.instance.routes.get('/comfyui-deploy/check-status')
|
||||||
async def comfy_deploy_check_status(request):
|
async def comfy_deploy_check_status(request):
|
||||||
prompt_server = server.PromptServer.instance
|
|
||||||
prompt_id = request.rel_url.query.get('prompt_id', None)
|
prompt_id = request.rel_url.query.get('prompt_id', None)
|
||||||
if prompt_id in prompt_metadata and 'status' in prompt_metadata[prompt_id]:
|
if prompt_id in prompt_metadata:
|
||||||
return web.json_response({
|
return web.json_response({
|
||||||
"status": prompt_metadata[prompt_id]['status'].value
|
"status": prompt_metadata[prompt_id].status.value
|
||||||
})
|
})
|
||||||
else:
|
else:
|
||||||
return web.json_response({
|
return web.json_response({
|
||||||
"message": "prompt_id not found"
|
"message": "prompt_id not found"
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@server.PromptServer.instance.routes.get('/comfyui-deploy/check-ws-status')
|
||||||
|
async def comfy_deploy_check_ws_status(request):
|
||||||
|
client_id = request.rel_url.query.get('client_id', None)
|
||||||
|
if client_id in streaming_prompt_metadata:
|
||||||
|
remaining_queue = 0 # Initialize remaining queue count
|
||||||
|
for prompt_id in streaming_prompt_metadata[client_id].running_prompt_ids[client_id]:
|
||||||
|
prompt_status = prompt_metadata[prompt_id].status
|
||||||
|
if prompt_status not in [Status.FAILED, Status.SUCCESS]:
|
||||||
|
remaining_queue += 1 # Increment for each prompt still running
|
||||||
|
return web.json_response({"remaining_queue": remaining_queue})
|
||||||
|
else:
|
||||||
|
return web.json_response({"message": "client_id not found"}, status=404)
|
||||||
|
|
||||||
async def send(event, data, sid=None):
|
async def send(event, data, sid=None):
|
||||||
try:
|
try:
|
||||||
@ -451,8 +482,8 @@ async def send(event, data, sid=None):
|
|||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
prompt_server = server.PromptServer.instance
|
prompt_server = server.PromptServer.instance
|
||||||
|
|
||||||
send_json = prompt_server.send_json
|
send_json = prompt_server.send_json
|
||||||
|
|
||||||
async def send_json_override(self, event, data, sid=None):
|
async def send_json_override(self, event, data, sid=None):
|
||||||
# print("INTERNAL:", event, data, sid)
|
# print("INTERNAL:", event, data, sid)
|
||||||
prompt_id = data.get('prompt_id')
|
prompt_id = data.get('prompt_id')
|
||||||
@ -476,28 +507,28 @@ async def send_json_override(self, event, data, sid=None):
|
|||||||
node = data.get('node')
|
node = data.get('node')
|
||||||
|
|
||||||
if prompt_id in prompt_metadata:
|
if prompt_id in prompt_metadata:
|
||||||
if 'progress' not in prompt_metadata[prompt_id]:
|
# if 'progress' not in prompt_metadata[prompt_id]:
|
||||||
prompt_metadata[prompt_id]["progress"] = set()
|
# prompt_metadata[prompt_id]["progress"] = set()
|
||||||
|
|
||||||
prompt_metadata[prompt_id]["progress"].add(node)
|
prompt_metadata[prompt_id].progress.add(node)
|
||||||
calculated_progress = len(prompt_metadata[prompt_id]["progress"]) / len(prompt_metadata[prompt_id]['workflow_api'])
|
calculated_progress = len(prompt_metadata[prompt_id].progress) / len(prompt_metadata[prompt_id].workflow_api)
|
||||||
# print("calculated_progress", calculated_progress)
|
# print("calculated_progress", calculated_progress)
|
||||||
|
|
||||||
if 'last_updated_node' in prompt_metadata[prompt_id] and prompt_metadata[prompt_id]['last_updated_node'] == node:
|
if prompt_metadata[prompt_id].last_updated_node is not None and prompt_metadata[prompt_id].last_updated_node == node:
|
||||||
return
|
return
|
||||||
prompt_metadata[prompt_id]['last_updated_node'] = node
|
prompt_metadata[prompt_id].last_updated_node = node
|
||||||
class_type = prompt_metadata[prompt_id]['workflow_api'][node]['class_type']
|
class_type = prompt_metadata[prompt_id].workflow_api[node]['class_type']
|
||||||
print("updating run live status", class_type)
|
print("updating run live status", class_type)
|
||||||
await update_run_live_status(prompt_id, "Executing " + class_type, calculated_progress)
|
await update_run_live_status(prompt_id, "Executing " + class_type, calculated_progress)
|
||||||
|
|
||||||
if event == 'execution_cached' and data.get('nodes') is not None:
|
if event == 'execution_cached' and data.get('nodes') is not None:
|
||||||
if prompt_id in prompt_metadata:
|
if prompt_id in prompt_metadata:
|
||||||
if 'progress' not in prompt_metadata[prompt_id]:
|
# if 'progress' not in prompt_metadata[prompt_id]:
|
||||||
prompt_metadata[prompt_id]["progress"] = set()
|
# prompt_metadata[prompt_id].progress = set()
|
||||||
|
|
||||||
if 'nodes' in data:
|
if 'nodes' in data:
|
||||||
for node in data.get('nodes', []):
|
for node in data.get('nodes', []):
|
||||||
prompt_metadata[prompt_id]["progress"].add(node)
|
prompt_metadata[prompt_id].progress.add(node)
|
||||||
# prompt_metadata[prompt_id]["progress"].update(data.get('nodes'))
|
# prompt_metadata[prompt_id]["progress"].update(data.get('nodes'))
|
||||||
|
|
||||||
if event == 'execution_error':
|
if event == 'execution_error':
|
||||||
@ -511,13 +542,6 @@ async def send_json_override(self, event, data, sid=None):
|
|||||||
# await update_run_with_output(prompt_id, data.get('output'), node_id=data.get('node'))
|
# await update_run_with_output(prompt_id, data.get('output'), node_id=data.get('node'))
|
||||||
# update_run_with_output(prompt_id, data.get('output'))
|
# update_run_with_output(prompt_id, data.get('output'))
|
||||||
|
|
||||||
class Status(Enum):
|
|
||||||
NOT_STARTED = "not-started"
|
|
||||||
RUNNING = "running"
|
|
||||||
SUCCESS = "success"
|
|
||||||
FAILED = "failed"
|
|
||||||
UPLOADING = "uploading"
|
|
||||||
|
|
||||||
# Global variable to keep track of the last read line number
|
# Global variable to keep track of the last read line number
|
||||||
last_read_line_number = 0
|
last_read_line_number = 0
|
||||||
|
|
||||||
@ -527,7 +551,7 @@ async def update_run_live_status(prompt_id, live_status, calculated_progress: fl
|
|||||||
|
|
||||||
print("progress", calculated_progress)
|
print("progress", calculated_progress)
|
||||||
|
|
||||||
status_endpoint = prompt_metadata[prompt_id]['status_endpoint']
|
status_endpoint = prompt_metadata[prompt_id].status_endpoint
|
||||||
body = {
|
body = {
|
||||||
"run_id": prompt_id,
|
"run_id": prompt_id,
|
||||||
"live_status": live_status,
|
"live_status": live_status,
|
||||||
@ -539,19 +563,19 @@ async def update_run_live_status(prompt_id, live_status, calculated_progress: fl
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def update_run(prompt_id, status: Status):
|
def update_run(prompt_id: str, status: Status):
|
||||||
global last_read_line_number
|
global last_read_line_number
|
||||||
|
|
||||||
if prompt_id not in prompt_metadata:
|
if prompt_id not in prompt_metadata:
|
||||||
return
|
return
|
||||||
|
|
||||||
if ('status' not in prompt_metadata[prompt_id] or prompt_metadata[prompt_id]['status'] != status):
|
if (prompt_metadata[prompt_id].status != status):
|
||||||
|
|
||||||
# when the status is already failed, we don't want to update it to success
|
# when the status is already failed, we don't want to update it to success
|
||||||
if ('status' in prompt_metadata[prompt_id] and prompt_metadata[prompt_id]['status'] == Status.FAILED):
|
if ('status' in prompt_metadata[prompt_id] and prompt_metadata[prompt_id].status == Status.FAILED):
|
||||||
return
|
return
|
||||||
|
|
||||||
status_endpoint = prompt_metadata[prompt_id]['status_endpoint']
|
status_endpoint = prompt_metadata[prompt_id].status_endpoint
|
||||||
body = {
|
body = {
|
||||||
"run_id": prompt_id,
|
"run_id": prompt_id,
|
||||||
"status": status.value,
|
"status": status.value,
|
||||||
@ -598,7 +622,7 @@ def update_run(prompt_id, status: Status):
|
|||||||
stack_trace = traceback.format_exc().strip()
|
stack_trace = traceback.format_exc().strip()
|
||||||
print(f"Error occurred while updating run: {e} {stack_trace}")
|
print(f"Error occurred while updating run: {e} {stack_trace}")
|
||||||
finally:
|
finally:
|
||||||
prompt_metadata[prompt_id]['status'] = status
|
prompt_metadata[prompt_id].status = status
|
||||||
|
|
||||||
|
|
||||||
async def upload_file(prompt_id, filename, subfolder=None, content_type="image/png", type="output"):
|
async def upload_file(prompt_id, filename, subfolder=None, content_type="image/png", type="output"):
|
||||||
@ -630,7 +654,7 @@ async def upload_file(prompt_id, filename, subfolder=None, content_type="image/p
|
|||||||
|
|
||||||
print("uploading file", file)
|
print("uploading file", file)
|
||||||
|
|
||||||
file_upload_endpoint = prompt_metadata[prompt_id]['file_upload_endpoint']
|
file_upload_endpoint = prompt_metadata[prompt_id].file_upload_endpoint
|
||||||
|
|
||||||
filename = quote(filename)
|
filename = quote(filename)
|
||||||
prompt_id = quote(prompt_id)
|
prompt_id = quote(prompt_id)
|
||||||
@ -654,23 +678,37 @@ async def upload_file(prompt_id, filename, subfolder=None, content_type="image/p
|
|||||||
print("upload file response", response.status)
|
print("upload file response", response.status)
|
||||||
|
|
||||||
def have_pending_upload(prompt_id):
|
def have_pending_upload(prompt_id):
|
||||||
if 'prompt_id' in prompt_metadata and 'uploading_nodes' in prompt_metadata[prompt_id] and len(prompt_metadata[prompt_id]['uploading_nodes']) > 0:
|
if prompt_id in prompt_metadata and len(prompt_metadata[prompt_id].uploading_nodes) > 0:
|
||||||
print("have pending upload ", len(prompt_metadata[prompt_id]['uploading_nodes']))
|
print("have pending upload ", len(prompt_metadata[prompt_id].uploading_nodes))
|
||||||
return True
|
return True
|
||||||
|
|
||||||
print("no pending upload")
|
print("no pending upload")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def mark_prompt_done(prompt_id):
|
def mark_prompt_done(prompt_id):
|
||||||
|
"""
|
||||||
|
Mark the prompt as done in the prompt metadata.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt_id (str): The ID of the prompt to mark as done.
|
||||||
|
"""
|
||||||
if prompt_id in prompt_metadata:
|
if prompt_id in prompt_metadata:
|
||||||
prompt_metadata[prompt_id]["done"] = True
|
prompt_metadata[prompt_id].done = True
|
||||||
print("Prompt done")
|
print("Prompt done")
|
||||||
|
|
||||||
def is_prompt_done(prompt_id):
|
def is_prompt_done(prompt_id: str):
|
||||||
if prompt_id in prompt_metadata and "done" in prompt_metadata[prompt_id]:
|
"""
|
||||||
if prompt_metadata[prompt_id]["done"] == True:
|
Check if the prompt with the given ID is marked as done.
|
||||||
return True
|
|
||||||
|
Args:
|
||||||
|
prompt_id (str): The ID of the prompt to check.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the prompt is marked as done, False otherwise.
|
||||||
|
"""
|
||||||
|
if prompt_id in prompt_metadata and prompt_metadata[prompt_id].done:
|
||||||
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Use to handle upload error and send back to ComfyDeploy
|
# Use to handle upload error and send back to ComfyDeploy
|
||||||
@ -692,17 +730,17 @@ async def handle_error(prompt_id, data, e: Exception):
|
|||||||
print(f"Error occurred while uploading file: {e}")
|
print(f"Error occurred while uploading file: {e}")
|
||||||
|
|
||||||
# Mark the current prompt requires upload, and block it from being marked as success
|
# Mark the current prompt requires upload, and block it from being marked as success
|
||||||
async def update_file_status(prompt_id, data, uploading, have_error=False, node_id=None):
|
async def update_file_status(prompt_id: str, data, uploading, have_error=False, node_id=None):
|
||||||
if 'uploading_nodes' not in prompt_metadata[prompt_id]:
|
# if 'uploading_nodes' not in prompt_metadata[prompt_id]:
|
||||||
prompt_metadata[prompt_id]['uploading_nodes'] = set()
|
# prompt_metadata[prompt_id]['uploading_nodes'] = set()
|
||||||
|
|
||||||
if node_id is not None:
|
if node_id is not None:
|
||||||
if uploading:
|
if uploading:
|
||||||
prompt_metadata[prompt_id]['uploading_nodes'].add(node_id)
|
prompt_metadata[prompt_id].uploading_nodes.add(node_id)
|
||||||
else:
|
else:
|
||||||
prompt_metadata[prompt_id]['uploading_nodes'].discard(node_id)
|
prompt_metadata[prompt_id].uploading_nodes.discard(node_id)
|
||||||
|
|
||||||
print(prompt_metadata[prompt_id]['uploading_nodes'])
|
print(prompt_metadata[prompt_id].uploading_nodes)
|
||||||
# Update the remote status
|
# Update the remote status
|
||||||
|
|
||||||
if have_error:
|
if have_error:
|
||||||
@ -714,12 +752,12 @@ async def update_file_status(prompt_id, data, uploading, have_error=False, node_
|
|||||||
|
|
||||||
# if there are still nodes that are uploading, then we set the status to uploading
|
# if there are still nodes that are uploading, then we set the status to uploading
|
||||||
if uploading:
|
if uploading:
|
||||||
if prompt_metadata[prompt_id]['status'] != Status.UPLOADING:
|
if prompt_metadata[prompt_id].status != Status.UPLOADING:
|
||||||
update_run(prompt_id, Status.UPLOADING)
|
update_run(prompt_id, Status.UPLOADING)
|
||||||
await send("uploading", {
|
await send("uploading", {
|
||||||
"prompt_id": prompt_id,
|
"prompt_id": prompt_id,
|
||||||
})
|
})
|
||||||
|
|
||||||
# if there are no nodes that are uploading, then we set the status to success
|
# if there are no nodes that are uploading, then we set the status to success
|
||||||
elif not uploading and not have_pending_upload(prompt_id) and is_prompt_done(prompt_id=prompt_id):
|
elif not uploading and not have_pending_upload(prompt_id) and is_prompt_done(prompt_id=prompt_id):
|
||||||
update_run(prompt_id, Status.SUCCESS)
|
update_run(prompt_id, Status.SUCCESS)
|
||||||
@ -728,20 +766,20 @@ async def update_file_status(prompt_id, data, uploading, have_error=False, node_
|
|||||||
"prompt_id": prompt_id,
|
"prompt_id": prompt_id,
|
||||||
})
|
})
|
||||||
|
|
||||||
async def handle_upload(prompt_id, data, key, content_type_key, default_content_type):
|
async def handle_upload(prompt_id: str, data, key: str, content_type_key: str, default_content_type: str):
|
||||||
items = data.get(key, [])
|
items = data.get(key, [])
|
||||||
for item in items:
|
for item in items:
|
||||||
await upload_file(
|
await upload_file(
|
||||||
prompt_id,
|
prompt_id,
|
||||||
item.get("filename"),
|
item.get("filename"),
|
||||||
subfolder=item.get("subfolder"),
|
subfolder=item.get("subfolder"),
|
||||||
type=item.get("type"),
|
type=item.get("type"),
|
||||||
content_type=item.get(content_type_key, default_content_type)
|
content_type=item.get(content_type_key, default_content_type)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Upload files in the background
|
# Upload files in the background
|
||||||
async def upload_in_background(prompt_id, data, node_id=None, have_upload=True):
|
async def upload_in_background(prompt_id: str, data, node_id=None, have_upload=True):
|
||||||
try:
|
try:
|
||||||
await handle_upload(prompt_id, data, 'images', "content_type", "image/png")
|
await handle_upload(prompt_id, data, 'images', "content_type", "image/png")
|
||||||
await handle_upload(prompt_id, data, 'files', "content_type", "image/png")
|
await handle_upload(prompt_id, data, 'files', "content_type", "image/png")
|
||||||
@ -755,7 +793,7 @@ async def upload_in_background(prompt_id, data, node_id=None, have_upload=True):
|
|||||||
|
|
||||||
async def update_run_with_output(prompt_id, data, node_id=None):
|
async def update_run_with_output(prompt_id, data, node_id=None):
|
||||||
if prompt_id in prompt_metadata:
|
if prompt_id in prompt_metadata:
|
||||||
status_endpoint = prompt_metadata[prompt_id]['status_endpoint']
|
status_endpoint = prompt_metadata[prompt_id].status_endpoint
|
||||||
|
|
||||||
body = {
|
body = {
|
||||||
"run_id": prompt_id,
|
"run_id": prompt_id,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user