refactor(plugin): add prompt_metadata types and refactor from dict to data model
This commit is contained in:
parent
9d0ded7ecc
commit
25e62af24c
168
custom_routes.py
168
custom_routes.py
@ -1,34 +1,24 @@
|
||||
|
||||
|
||||
from aiohttp import web
|
||||
import os
|
||||
import requests
|
||||
import folder_paths
|
||||
import json
|
||||
import numpy as np
|
||||
import server
|
||||
import re
|
||||
import base64
|
||||
from PIL import Image
|
||||
import io
|
||||
import time
|
||||
import execution
|
||||
import random
|
||||
import traceback
|
||||
import uuid
|
||||
import asyncio
|
||||
import atexit
|
||||
import logging
|
||||
import sys
|
||||
from logging.handlers import RotatingFileHandler
|
||||
from enum import Enum
|
||||
from urllib.parse import quote
|
||||
import threading
|
||||
import hashlib
|
||||
import aiohttp
|
||||
import aiofiles
|
||||
import concurrent.futures
|
||||
from typing import List, Union, Any
|
||||
from typing import List, Union, Any, Optional
|
||||
from PIL import Image
|
||||
import copy
|
||||
|
||||
@ -40,16 +30,36 @@ class BaseModel(PydanticBaseModel):
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
class Status(Enum):
|
||||
NOT_STARTED = "not-started"
|
||||
RUNNING = "running"
|
||||
SUCCESS = "success"
|
||||
FAILED = "failed"
|
||||
UPLOADING = "uploading"
|
||||
|
||||
class StreamingPrompt(BaseModel):
|
||||
workflow_api: Any
|
||||
auth_token: str
|
||||
inputs: dict[str, Union[str, bytes, Image.Image]]
|
||||
running_prompt_ids: set[str] = set()
|
||||
status_endpoint: str
|
||||
file_upload_endpoint: str
|
||||
|
||||
class SimplePrompt(BaseModel):
|
||||
status_endpoint: str
|
||||
file_upload_endpoint: str
|
||||
workflow_api: dict
|
||||
status: Status = Status.NOT_STARTED
|
||||
progress: set = set()
|
||||
last_updated_node: Optional[str] = None,
|
||||
uploading_nodes: set = set()
|
||||
done: bool = False
|
||||
|
||||
streaming_prompt_metadata: dict[str, StreamingPrompt] = {}
|
||||
|
||||
api = None
|
||||
api_task = None
|
||||
prompt_metadata = {}
|
||||
prompt_metadata: dict[str, SimplePrompt] = {}
|
||||
cd_enable_log = os.environ.get('CD_ENABLE_LOG', 'false').lower() == 'true'
|
||||
cd_enable_run_log = os.environ.get('CD_ENABLE_RUN_LOG', 'false').lower() == 'true'
|
||||
|
||||
@ -137,6 +147,12 @@ def send_prompt(sid: str, inputs: StreamingPrompt):
|
||||
|
||||
try:
|
||||
res = post_prompt(prompt)
|
||||
inputs.running_prompt_ids.add(prompt_id)
|
||||
prompt_metadata[prompt_id] = SimplePrompt(
|
||||
status_endpoint=inputs.status_endpoint,
|
||||
file_upload_endpoint=inputs.file_upload_endpoint,
|
||||
workflow_api=workflow_api
|
||||
)
|
||||
except Exception as e:
|
||||
error_type = type(e).__name__
|
||||
stack_trace_short = traceback.format_exc().strip().split('\n')[-2]
|
||||
@ -164,11 +180,11 @@ async def comfy_deploy_run(request):
|
||||
"prompt_id": prompt_id
|
||||
}
|
||||
|
||||
prompt_metadata[prompt_id] = {
|
||||
'status_endpoint': data.get('status_endpoint'),
|
||||
'file_upload_endpoint': data.get('file_upload_endpoint'),
|
||||
'workflow_api': workflow_api
|
||||
}
|
||||
prompt_metadata[prompt_id] = SimplePrompt(
|
||||
status_endpoint=data.get('status_endpoint'),
|
||||
file_upload_endpoint=data.get('file_upload_endpoint'),
|
||||
workflow_api=workflow_api
|
||||
)
|
||||
|
||||
try:
|
||||
res = post_prompt(prompt)
|
||||
@ -244,7 +260,7 @@ async def compute_sha256_checksum(filepath):
|
||||
|
||||
# This is start uploading the files to Comfy Deploy
|
||||
@server.PromptServer.instance.routes.post('/comfyui-deploy/upload-file')
|
||||
async def upload_file(request):
|
||||
async def upload_file_endpoint(request):
|
||||
data = await request.json()
|
||||
|
||||
file_path = data.get("file_path")
|
||||
@ -367,9 +383,10 @@ async def websocket_handler(request):
|
||||
|
||||
sockets[sid] = ws
|
||||
|
||||
auth_token = request.rel_url.query.get('token', '')
|
||||
get_workflow_endpoint_url = request.rel_url.query.get('workflow_endpoint', '')
|
||||
auth_token = request.rel_url.query.get('token', None)
|
||||
get_workflow_endpoint_url = request.rel_url.query.get('workflow_endpoint', None)
|
||||
|
||||
if auth_token is not None and get_workflow_endpoint_url is not None:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
headers = {'Authorization': f'Bearer {auth_token}'}
|
||||
async with session.get(get_workflow_endpoint_url, headers=headers) as response:
|
||||
@ -381,7 +398,9 @@ async def websocket_handler(request):
|
||||
streaming_prompt_metadata[sid] = StreamingPrompt(
|
||||
workflow_api=workflow["workflow_api"],
|
||||
auth_token=auth_token,
|
||||
inputs={}
|
||||
inputs={},
|
||||
status_endpoint=request.rel_url.query.get('status_endpoint', None),
|
||||
file_upload_endpoint=request.rel_url.query.get('file_upload_endpoint', None),
|
||||
)
|
||||
|
||||
# await send("workflow_api", workflow_api, sid)
|
||||
@ -422,17 +441,29 @@ async def websocket_handler(request):
|
||||
|
||||
@server.PromptServer.instance.routes.get('/comfyui-deploy/check-status')
|
||||
async def comfy_deploy_check_status(request):
|
||||
prompt_server = server.PromptServer.instance
|
||||
prompt_id = request.rel_url.query.get('prompt_id', None)
|
||||
if prompt_id in prompt_metadata and 'status' in prompt_metadata[prompt_id]:
|
||||
if prompt_id in prompt_metadata:
|
||||
return web.json_response({
|
||||
"status": prompt_metadata[prompt_id]['status'].value
|
||||
"status": prompt_metadata[prompt_id].status.value
|
||||
})
|
||||
else:
|
||||
return web.json_response({
|
||||
"message": "prompt_id not found"
|
||||
})
|
||||
|
||||
@server.PromptServer.instance.routes.get('/comfyui-deploy/check-ws-status')
|
||||
async def comfy_deploy_check_ws_status(request):
|
||||
client_id = request.rel_url.query.get('client_id', None)
|
||||
if client_id in streaming_prompt_metadata:
|
||||
remaining_queue = 0 # Initialize remaining queue count
|
||||
for prompt_id in streaming_prompt_metadata[client_id].running_prompt_ids[client_id]:
|
||||
prompt_status = prompt_metadata[prompt_id].status
|
||||
if prompt_status not in [Status.FAILED, Status.SUCCESS]:
|
||||
remaining_queue += 1 # Increment for each prompt still running
|
||||
return web.json_response({"remaining_queue": remaining_queue})
|
||||
else:
|
||||
return web.json_response({"message": "client_id not found"}, status=404)
|
||||
|
||||
async def send(event, data, sid=None):
|
||||
try:
|
||||
# message = {"event": event, "data": data}
|
||||
@ -451,8 +482,8 @@ async def send(event, data, sid=None):
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
prompt_server = server.PromptServer.instance
|
||||
|
||||
send_json = prompt_server.send_json
|
||||
|
||||
async def send_json_override(self, event, data, sid=None):
|
||||
# print("INTERNAL:", event, data, sid)
|
||||
prompt_id = data.get('prompt_id')
|
||||
@ -476,28 +507,28 @@ async def send_json_override(self, event, data, sid=None):
|
||||
node = data.get('node')
|
||||
|
||||
if prompt_id in prompt_metadata:
|
||||
if 'progress' not in prompt_metadata[prompt_id]:
|
||||
prompt_metadata[prompt_id]["progress"] = set()
|
||||
# if 'progress' not in prompt_metadata[prompt_id]:
|
||||
# prompt_metadata[prompt_id]["progress"] = set()
|
||||
|
||||
prompt_metadata[prompt_id]["progress"].add(node)
|
||||
calculated_progress = len(prompt_metadata[prompt_id]["progress"]) / len(prompt_metadata[prompt_id]['workflow_api'])
|
||||
prompt_metadata[prompt_id].progress.add(node)
|
||||
calculated_progress = len(prompt_metadata[prompt_id].progress) / len(prompt_metadata[prompt_id].workflow_api)
|
||||
# print("calculated_progress", calculated_progress)
|
||||
|
||||
if 'last_updated_node' in prompt_metadata[prompt_id] and prompt_metadata[prompt_id]['last_updated_node'] == node:
|
||||
if prompt_metadata[prompt_id].last_updated_node is not None and prompt_metadata[prompt_id].last_updated_node == node:
|
||||
return
|
||||
prompt_metadata[prompt_id]['last_updated_node'] = node
|
||||
class_type = prompt_metadata[prompt_id]['workflow_api'][node]['class_type']
|
||||
prompt_metadata[prompt_id].last_updated_node = node
|
||||
class_type = prompt_metadata[prompt_id].workflow_api[node]['class_type']
|
||||
print("updating run live status", class_type)
|
||||
await update_run_live_status(prompt_id, "Executing " + class_type, calculated_progress)
|
||||
|
||||
if event == 'execution_cached' and data.get('nodes') is not None:
|
||||
if prompt_id in prompt_metadata:
|
||||
if 'progress' not in prompt_metadata[prompt_id]:
|
||||
prompt_metadata[prompt_id]["progress"] = set()
|
||||
# if 'progress' not in prompt_metadata[prompt_id]:
|
||||
# prompt_metadata[prompt_id].progress = set()
|
||||
|
||||
if 'nodes' in data:
|
||||
for node in data.get('nodes', []):
|
||||
prompt_metadata[prompt_id]["progress"].add(node)
|
||||
prompt_metadata[prompt_id].progress.add(node)
|
||||
# prompt_metadata[prompt_id]["progress"].update(data.get('nodes'))
|
||||
|
||||
if event == 'execution_error':
|
||||
@ -511,13 +542,6 @@ async def send_json_override(self, event, data, sid=None):
|
||||
# await update_run_with_output(prompt_id, data.get('output'), node_id=data.get('node'))
|
||||
# update_run_with_output(prompt_id, data.get('output'))
|
||||
|
||||
class Status(Enum):
|
||||
NOT_STARTED = "not-started"
|
||||
RUNNING = "running"
|
||||
SUCCESS = "success"
|
||||
FAILED = "failed"
|
||||
UPLOADING = "uploading"
|
||||
|
||||
# Global variable to keep track of the last read line number
|
||||
last_read_line_number = 0
|
||||
|
||||
@ -527,7 +551,7 @@ async def update_run_live_status(prompt_id, live_status, calculated_progress: fl
|
||||
|
||||
print("progress", calculated_progress)
|
||||
|
||||
status_endpoint = prompt_metadata[prompt_id]['status_endpoint']
|
||||
status_endpoint = prompt_metadata[prompt_id].status_endpoint
|
||||
body = {
|
||||
"run_id": prompt_id,
|
||||
"live_status": live_status,
|
||||
@ -539,19 +563,19 @@ async def update_run_live_status(prompt_id, live_status, calculated_progress: fl
|
||||
pass
|
||||
|
||||
|
||||
def update_run(prompt_id, status: Status):
|
||||
def update_run(prompt_id: str, status: Status):
|
||||
global last_read_line_number
|
||||
|
||||
if prompt_id not in prompt_metadata:
|
||||
return
|
||||
|
||||
if ('status' not in prompt_metadata[prompt_id] or prompt_metadata[prompt_id]['status'] != status):
|
||||
if (prompt_metadata[prompt_id].status != status):
|
||||
|
||||
# when the status is already failed, we don't want to update it to success
|
||||
if ('status' in prompt_metadata[prompt_id] and prompt_metadata[prompt_id]['status'] == Status.FAILED):
|
||||
if ('status' in prompt_metadata[prompt_id] and prompt_metadata[prompt_id].status == Status.FAILED):
|
||||
return
|
||||
|
||||
status_endpoint = prompt_metadata[prompt_id]['status_endpoint']
|
||||
status_endpoint = prompt_metadata[prompt_id].status_endpoint
|
||||
body = {
|
||||
"run_id": prompt_id,
|
||||
"status": status.value,
|
||||
@ -598,7 +622,7 @@ def update_run(prompt_id, status: Status):
|
||||
stack_trace = traceback.format_exc().strip()
|
||||
print(f"Error occurred while updating run: {e} {stack_trace}")
|
||||
finally:
|
||||
prompt_metadata[prompt_id]['status'] = status
|
||||
prompt_metadata[prompt_id].status = status
|
||||
|
||||
|
||||
async def upload_file(prompt_id, filename, subfolder=None, content_type="image/png", type="output"):
|
||||
@ -630,7 +654,7 @@ async def upload_file(prompt_id, filename, subfolder=None, content_type="image/p
|
||||
|
||||
print("uploading file", file)
|
||||
|
||||
file_upload_endpoint = prompt_metadata[prompt_id]['file_upload_endpoint']
|
||||
file_upload_endpoint = prompt_metadata[prompt_id].file_upload_endpoint
|
||||
|
||||
filename = quote(filename)
|
||||
prompt_id = quote(prompt_id)
|
||||
@ -654,21 +678,35 @@ async def upload_file(prompt_id, filename, subfolder=None, content_type="image/p
|
||||
print("upload file response", response.status)
|
||||
|
||||
def have_pending_upload(prompt_id):
|
||||
if 'prompt_id' in prompt_metadata and 'uploading_nodes' in prompt_metadata[prompt_id] and len(prompt_metadata[prompt_id]['uploading_nodes']) > 0:
|
||||
print("have pending upload ", len(prompt_metadata[prompt_id]['uploading_nodes']))
|
||||
if prompt_id in prompt_metadata and len(prompt_metadata[prompt_id].uploading_nodes) > 0:
|
||||
print("have pending upload ", len(prompt_metadata[prompt_id].uploading_nodes))
|
||||
return True
|
||||
|
||||
print("no pending upload")
|
||||
return False
|
||||
|
||||
def mark_prompt_done(prompt_id):
|
||||
"""
|
||||
Mark the prompt as done in the prompt metadata.
|
||||
|
||||
Args:
|
||||
prompt_id (str): The ID of the prompt to mark as done.
|
||||
"""
|
||||
if prompt_id in prompt_metadata:
|
||||
prompt_metadata[prompt_id]["done"] = True
|
||||
prompt_metadata[prompt_id].done = True
|
||||
print("Prompt done")
|
||||
|
||||
def is_prompt_done(prompt_id):
|
||||
if prompt_id in prompt_metadata and "done" in prompt_metadata[prompt_id]:
|
||||
if prompt_metadata[prompt_id]["done"] == True:
|
||||
def is_prompt_done(prompt_id: str):
|
||||
"""
|
||||
Check if the prompt with the given ID is marked as done.
|
||||
|
||||
Args:
|
||||
prompt_id (str): The ID of the prompt to check.
|
||||
|
||||
Returns:
|
||||
bool: True if the prompt is marked as done, False otherwise.
|
||||
"""
|
||||
if prompt_id in prompt_metadata and prompt_metadata[prompt_id].done:
|
||||
return True
|
||||
|
||||
return False
|
||||
@ -692,17 +730,17 @@ async def handle_error(prompt_id, data, e: Exception):
|
||||
print(f"Error occurred while uploading file: {e}")
|
||||
|
||||
# Mark the current prompt requires upload, and block it from being marked as success
|
||||
async def update_file_status(prompt_id, data, uploading, have_error=False, node_id=None):
|
||||
if 'uploading_nodes' not in prompt_metadata[prompt_id]:
|
||||
prompt_metadata[prompt_id]['uploading_nodes'] = set()
|
||||
async def update_file_status(prompt_id: str, data, uploading, have_error=False, node_id=None):
|
||||
# if 'uploading_nodes' not in prompt_metadata[prompt_id]:
|
||||
# prompt_metadata[prompt_id]['uploading_nodes'] = set()
|
||||
|
||||
if node_id is not None:
|
||||
if uploading:
|
||||
prompt_metadata[prompt_id]['uploading_nodes'].add(node_id)
|
||||
prompt_metadata[prompt_id].uploading_nodes.add(node_id)
|
||||
else:
|
||||
prompt_metadata[prompt_id]['uploading_nodes'].discard(node_id)
|
||||
prompt_metadata[prompt_id].uploading_nodes.discard(node_id)
|
||||
|
||||
print(prompt_metadata[prompt_id]['uploading_nodes'])
|
||||
print(prompt_metadata[prompt_id].uploading_nodes)
|
||||
# Update the remote status
|
||||
|
||||
if have_error:
|
||||
@ -714,7 +752,7 @@ async def update_file_status(prompt_id, data, uploading, have_error=False, node_
|
||||
|
||||
# if there are still nodes that are uploading, then we set the status to uploading
|
||||
if uploading:
|
||||
if prompt_metadata[prompt_id]['status'] != Status.UPLOADING:
|
||||
if prompt_metadata[prompt_id].status != Status.UPLOADING:
|
||||
update_run(prompt_id, Status.UPLOADING)
|
||||
await send("uploading", {
|
||||
"prompt_id": prompt_id,
|
||||
@ -728,7 +766,7 @@ async def update_file_status(prompt_id, data, uploading, have_error=False, node_
|
||||
"prompt_id": prompt_id,
|
||||
})
|
||||
|
||||
async def handle_upload(prompt_id, data, key, content_type_key, default_content_type):
|
||||
async def handle_upload(prompt_id: str, data, key: str, content_type_key: str, default_content_type: str):
|
||||
items = data.get(key, [])
|
||||
for item in items:
|
||||
await upload_file(
|
||||
@ -741,7 +779,7 @@ async def handle_upload(prompt_id, data, key, content_type_key, default_content_
|
||||
|
||||
|
||||
# Upload files in the background
|
||||
async def upload_in_background(prompt_id, data, node_id=None, have_upload=True):
|
||||
async def upload_in_background(prompt_id: str, data, node_id=None, have_upload=True):
|
||||
try:
|
||||
await handle_upload(prompt_id, data, 'images', "content_type", "image/png")
|
||||
await handle_upload(prompt_id, data, 'files', "content_type", "image/png")
|
||||
@ -755,7 +793,7 @@ async def upload_in_background(prompt_id, data, node_id=None, have_upload=True):
|
||||
|
||||
async def update_run_with_output(prompt_id, data, node_id=None):
|
||||
if prompt_id in prompt_metadata:
|
||||
status_endpoint = prompt_metadata[prompt_id]['status_endpoint']
|
||||
status_endpoint = prompt_metadata[prompt_id].status_endpoint
|
||||
|
||||
body = {
|
||||
"run_id": prompt_id,
|
||||
|
Loading…
x
Reference in New Issue
Block a user