refactor(plugin): add prompt_metadata types and refactor from dict to data model

This commit is contained in:
bennykok 2024-02-24 23:57:01 -08:00
parent 9d0ded7ecc
commit 25e62af24c

View File

@ -1,34 +1,24 @@
from aiohttp import web
import os
import requests
import folder_paths
import json
import numpy as np
import server
import re
import base64
from PIL import Image
import io
import time
import execution
import random
import traceback
import uuid
import asyncio
import atexit
import logging
import sys
from logging.handlers import RotatingFileHandler
from enum import Enum
from urllib.parse import quote
import threading
import hashlib
import aiohttp
import aiofiles
import concurrent.futures
from typing import List, Union, Any
from typing import List, Union, Any, Optional
from PIL import Image
import copy
@ -39,17 +29,37 @@ from pydantic import BaseModel as PydanticBaseModel
class BaseModel(PydanticBaseModel):
class Config:
arbitrary_types_allowed = True
class Status(Enum):
NOT_STARTED = "not-started"
RUNNING = "running"
SUCCESS = "success"
FAILED = "failed"
UPLOADING = "uploading"
class StreamingPrompt(BaseModel):
workflow_api: Any
auth_token: str
inputs: dict[str, Union[str, bytes, Image.Image]]
running_prompt_ids: set[str] = set()
status_endpoint: str
file_upload_endpoint: str
class SimplePrompt(BaseModel):
status_endpoint: str
file_upload_endpoint: str
workflow_api: dict
status: Status = Status.NOT_STARTED
progress: set = set()
last_updated_node: Optional[str] = None,
uploading_nodes: set = set()
done: bool = False
streaming_prompt_metadata: dict[str, StreamingPrompt] = {}
api = None
api_task = None
prompt_metadata = {}
prompt_metadata: dict[str, SimplePrompt] = {}
cd_enable_log = os.environ.get('CD_ENABLE_LOG', 'false').lower() == 'true'
cd_enable_run_log = os.environ.get('CD_ENABLE_RUN_LOG', 'false').lower() == 'true'
@ -109,7 +119,7 @@ def send_prompt(sid: str, inputs: StreamingPrompt):
if 'inputs' in workflow_api[key] and 'seed' in workflow_api[key]['inputs']:
workflow_api[key]['inputs']['seed'] = randomSeed()
print("getting inputs" ,inputs.inputs)
print("getting inputs" , inputs.inputs)
# Loop through each of the inputs and replace them
for key, value in workflow_api.items():
@ -137,6 +147,12 @@ def send_prompt(sid: str, inputs: StreamingPrompt):
try:
res = post_prompt(prompt)
inputs.running_prompt_ids.add(prompt_id)
prompt_metadata[prompt_id] = SimplePrompt(
status_endpoint=inputs.status_endpoint,
file_upload_endpoint=inputs.file_upload_endpoint,
workflow_api=workflow_api
)
except Exception as e:
error_type = type(e).__name__
stack_trace_short = traceback.format_exc().strip().split('\n')[-2]
@ -164,11 +180,11 @@ async def comfy_deploy_run(request):
"prompt_id": prompt_id
}
prompt_metadata[prompt_id] = {
'status_endpoint': data.get('status_endpoint'),
'file_upload_endpoint': data.get('file_upload_endpoint'),
'workflow_api': workflow_api
}
prompt_metadata[prompt_id] = SimplePrompt(
status_endpoint=data.get('status_endpoint'),
file_upload_endpoint=data.get('file_upload_endpoint'),
workflow_api=workflow_api
)
try:
res = post_prompt(prompt)
@ -244,7 +260,7 @@ async def compute_sha256_checksum(filepath):
# This is start uploading the files to Comfy Deploy
@server.PromptServer.instance.routes.post('/comfyui-deploy/upload-file')
async def upload_file(request):
async def upload_file_endpoint(request):
data = await request.json()
file_path = data.get("file_path")
@ -367,28 +383,31 @@ async def websocket_handler(request):
sockets[sid] = ws
auth_token = request.rel_url.query.get('token', '')
get_workflow_endpoint_url = request.rel_url.query.get('workflow_endpoint', '')
auth_token = request.rel_url.query.get('token', None)
get_workflow_endpoint_url = request.rel_url.query.get('workflow_endpoint', None)
async with aiohttp.ClientSession() as session:
headers = {'Authorization': f'Bearer {auth_token}'}
async with session.get(get_workflow_endpoint_url, headers=headers) as response:
if response.status == 200:
workflow = await response.json()
print("Loaded workflow version ",workflow["version"])
streaming_prompt_metadata[sid] = StreamingPrompt(
workflow_api=workflow["workflow_api"],
auth_token=auth_token,
inputs={}
)
# await send("workflow_api", workflow_api, sid)
else:
error_message = await response.text()
print(f"Failed to fetch workflow endpoint. Status: {response.status}, Error: {error_message}")
# await send("error", {"message": error_message}, sid)
if auth_token is not None and get_workflow_endpoint_url is not None:
async with aiohttp.ClientSession() as session:
headers = {'Authorization': f'Bearer {auth_token}'}
async with session.get(get_workflow_endpoint_url, headers=headers) as response:
if response.status == 200:
workflow = await response.json()
print("Loaded workflow version ",workflow["version"])
streaming_prompt_metadata[sid] = StreamingPrompt(
workflow_api=workflow["workflow_api"],
auth_token=auth_token,
inputs={},
status_endpoint=request.rel_url.query.get('status_endpoint', None),
file_upload_endpoint=request.rel_url.query.get('file_upload_endpoint', None),
)
# await send("workflow_api", workflow_api, sid)
else:
error_message = await response.text()
print(f"Failed to fetch workflow endpoint. Status: {response.status}, Error: {error_message}")
# await send("error", {"message": error_message}, sid)
try:
# Send initial state to the new client
@ -422,16 +441,28 @@ async def websocket_handler(request):
@server.PromptServer.instance.routes.get('/comfyui-deploy/check-status')
async def comfy_deploy_check_status(request):
prompt_server = server.PromptServer.instance
prompt_id = request.rel_url.query.get('prompt_id', None)
if prompt_id in prompt_metadata and 'status' in prompt_metadata[prompt_id]:
if prompt_id in prompt_metadata:
return web.json_response({
"status": prompt_metadata[prompt_id]['status'].value
"status": prompt_metadata[prompt_id].status.value
})
else:
return web.json_response({
"message": "prompt_id not found"
})
@server.PromptServer.instance.routes.get('/comfyui-deploy/check-ws-status')
async def comfy_deploy_check_ws_status(request):
client_id = request.rel_url.query.get('client_id', None)
if client_id in streaming_prompt_metadata:
remaining_queue = 0 # Initialize remaining queue count
for prompt_id in streaming_prompt_metadata[client_id].running_prompt_ids[client_id]:
prompt_status = prompt_metadata[prompt_id].status
if prompt_status not in [Status.FAILED, Status.SUCCESS]:
remaining_queue += 1 # Increment for each prompt still running
return web.json_response({"remaining_queue": remaining_queue})
else:
return web.json_response({"message": "client_id not found"}, status=404)
async def send(event, data, sid=None):
try:
@ -451,8 +482,8 @@ async def send(event, data, sid=None):
logging.basicConfig(level=logging.INFO)
prompt_server = server.PromptServer.instance
send_json = prompt_server.send_json
async def send_json_override(self, event, data, sid=None):
# print("INTERNAL:", event, data, sid)
prompt_id = data.get('prompt_id')
@ -476,28 +507,28 @@ async def send_json_override(self, event, data, sid=None):
node = data.get('node')
if prompt_id in prompt_metadata:
if 'progress' not in prompt_metadata[prompt_id]:
prompt_metadata[prompt_id]["progress"] = set()
prompt_metadata[prompt_id]["progress"].add(node)
calculated_progress = len(prompt_metadata[prompt_id]["progress"]) / len(prompt_metadata[prompt_id]['workflow_api'])
# if 'progress' not in prompt_metadata[prompt_id]:
# prompt_metadata[prompt_id]["progress"] = set()
prompt_metadata[prompt_id].progress.add(node)
calculated_progress = len(prompt_metadata[prompt_id].progress) / len(prompt_metadata[prompt_id].workflow_api)
# print("calculated_progress", calculated_progress)
if 'last_updated_node' in prompt_metadata[prompt_id] and prompt_metadata[prompt_id]['last_updated_node'] == node:
if prompt_metadata[prompt_id].last_updated_node is not None and prompt_metadata[prompt_id].last_updated_node == node:
return
prompt_metadata[prompt_id]['last_updated_node'] = node
class_type = prompt_metadata[prompt_id]['workflow_api'][node]['class_type']
prompt_metadata[prompt_id].last_updated_node = node
class_type = prompt_metadata[prompt_id].workflow_api[node]['class_type']
print("updating run live status", class_type)
await update_run_live_status(prompt_id, "Executing " + class_type, calculated_progress)
if event == 'execution_cached' and data.get('nodes') is not None:
if prompt_id in prompt_metadata:
if 'progress' not in prompt_metadata[prompt_id]:
prompt_metadata[prompt_id]["progress"] = set()
# if 'progress' not in prompt_metadata[prompt_id]:
# prompt_metadata[prompt_id].progress = set()
if 'nodes' in data:
for node in data.get('nodes', []):
prompt_metadata[prompt_id]["progress"].add(node)
prompt_metadata[prompt_id].progress.add(node)
# prompt_metadata[prompt_id]["progress"].update(data.get('nodes'))
if event == 'execution_error':
@ -511,13 +542,6 @@ async def send_json_override(self, event, data, sid=None):
# await update_run_with_output(prompt_id, data.get('output'), node_id=data.get('node'))
# update_run_with_output(prompt_id, data.get('output'))
class Status(Enum):
NOT_STARTED = "not-started"
RUNNING = "running"
SUCCESS = "success"
FAILED = "failed"
UPLOADING = "uploading"
# Global variable to keep track of the last read line number
last_read_line_number = 0
@ -527,7 +551,7 @@ async def update_run_live_status(prompt_id, live_status, calculated_progress: fl
print("progress", calculated_progress)
status_endpoint = prompt_metadata[prompt_id]['status_endpoint']
status_endpoint = prompt_metadata[prompt_id].status_endpoint
body = {
"run_id": prompt_id,
"live_status": live_status,
@ -539,19 +563,19 @@ async def update_run_live_status(prompt_id, live_status, calculated_progress: fl
pass
def update_run(prompt_id, status: Status):
def update_run(prompt_id: str, status: Status):
global last_read_line_number
if prompt_id not in prompt_metadata:
return
if ('status' not in prompt_metadata[prompt_id] or prompt_metadata[prompt_id]['status'] != status):
if (prompt_metadata[prompt_id].status != status):
# when the status is already failed, we don't want to update it to success
if ('status' in prompt_metadata[prompt_id] and prompt_metadata[prompt_id]['status'] == Status.FAILED):
if ('status' in prompt_metadata[prompt_id] and prompt_metadata[prompt_id].status == Status.FAILED):
return
status_endpoint = prompt_metadata[prompt_id]['status_endpoint']
status_endpoint = prompt_metadata[prompt_id].status_endpoint
body = {
"run_id": prompt_id,
"status": status.value,
@ -598,7 +622,7 @@ def update_run(prompt_id, status: Status):
stack_trace = traceback.format_exc().strip()
print(f"Error occurred while updating run: {e} {stack_trace}")
finally:
prompt_metadata[prompt_id]['status'] = status
prompt_metadata[prompt_id].status = status
async def upload_file(prompt_id, filename, subfolder=None, content_type="image/png", type="output"):
@ -630,7 +654,7 @@ async def upload_file(prompt_id, filename, subfolder=None, content_type="image/p
print("uploading file", file)
file_upload_endpoint = prompt_metadata[prompt_id]['file_upload_endpoint']
file_upload_endpoint = prompt_metadata[prompt_id].file_upload_endpoint
filename = quote(filename)
prompt_id = quote(prompt_id)
@ -654,23 +678,37 @@ async def upload_file(prompt_id, filename, subfolder=None, content_type="image/p
print("upload file response", response.status)
def have_pending_upload(prompt_id):
if 'prompt_id' in prompt_metadata and 'uploading_nodes' in prompt_metadata[prompt_id] and len(prompt_metadata[prompt_id]['uploading_nodes']) > 0:
print("have pending upload ", len(prompt_metadata[prompt_id]['uploading_nodes']))
if prompt_id in prompt_metadata and len(prompt_metadata[prompt_id].uploading_nodes) > 0:
print("have pending upload ", len(prompt_metadata[prompt_id].uploading_nodes))
return True
print("no pending upload")
return False
def mark_prompt_done(prompt_id):
"""
Mark the prompt as done in the prompt metadata.
Args:
prompt_id (str): The ID of the prompt to mark as done.
"""
if prompt_id in prompt_metadata:
prompt_metadata[prompt_id]["done"] = True
prompt_metadata[prompt_id].done = True
print("Prompt done")
def is_prompt_done(prompt_id):
if prompt_id in prompt_metadata and "done" in prompt_metadata[prompt_id]:
if prompt_metadata[prompt_id]["done"] == True:
return True
def is_prompt_done(prompt_id: str):
"""
Check if the prompt with the given ID is marked as done.
Args:
prompt_id (str): The ID of the prompt to check.
Returns:
bool: True if the prompt is marked as done, False otherwise.
"""
if prompt_id in prompt_metadata and prompt_metadata[prompt_id].done:
return True
return False
# Use to handle upload error and send back to ComfyDeploy
@ -692,17 +730,17 @@ async def handle_error(prompt_id, data, e: Exception):
print(f"Error occurred while uploading file: {e}")
# Mark the current prompt requires upload, and block it from being marked as success
async def update_file_status(prompt_id, data, uploading, have_error=False, node_id=None):
if 'uploading_nodes' not in prompt_metadata[prompt_id]:
prompt_metadata[prompt_id]['uploading_nodes'] = set()
async def update_file_status(prompt_id: str, data, uploading, have_error=False, node_id=None):
# if 'uploading_nodes' not in prompt_metadata[prompt_id]:
# prompt_metadata[prompt_id]['uploading_nodes'] = set()
if node_id is not None:
if uploading:
prompt_metadata[prompt_id]['uploading_nodes'].add(node_id)
prompt_metadata[prompt_id].uploading_nodes.add(node_id)
else:
prompt_metadata[prompt_id]['uploading_nodes'].discard(node_id)
prompt_metadata[prompt_id].uploading_nodes.discard(node_id)
print(prompt_metadata[prompt_id]['uploading_nodes'])
print(prompt_metadata[prompt_id].uploading_nodes)
# Update the remote status
if have_error:
@ -714,12 +752,12 @@ async def update_file_status(prompt_id, data, uploading, have_error=False, node_
# if there are still nodes that are uploading, then we set the status to uploading
if uploading:
if prompt_metadata[prompt_id]['status'] != Status.UPLOADING:
if prompt_metadata[prompt_id].status != Status.UPLOADING:
update_run(prompt_id, Status.UPLOADING)
await send("uploading", {
"prompt_id": prompt_id,
})
# if there are no nodes that are uploading, then we set the status to success
elif not uploading and not have_pending_upload(prompt_id) and is_prompt_done(prompt_id=prompt_id):
update_run(prompt_id, Status.SUCCESS)
@ -728,20 +766,20 @@ async def update_file_status(prompt_id, data, uploading, have_error=False, node_
"prompt_id": prompt_id,
})
async def handle_upload(prompt_id, data, key, content_type_key, default_content_type):
async def handle_upload(prompt_id: str, data, key: str, content_type_key: str, default_content_type: str):
items = data.get(key, [])
for item in items:
await upload_file(
prompt_id,
item.get("filename"),
subfolder=item.get("subfolder"),
type=item.get("type"),
prompt_id,
item.get("filename"),
subfolder=item.get("subfolder"),
type=item.get("type"),
content_type=item.get(content_type_key, default_content_type)
)
# Upload files in the background
async def upload_in_background(prompt_id, data, node_id=None, have_upload=True):
async def upload_in_background(prompt_id: str, data, node_id=None, have_upload=True):
try:
await handle_upload(prompt_id, data, 'images', "content_type", "image/png")
await handle_upload(prompt_id, data, 'files', "content_type", "image/png")
@ -755,7 +793,7 @@ async def upload_in_background(prompt_id, data, node_id=None, have_upload=True):
async def update_run_with_output(prompt_id, data, node_id=None):
if prompt_id in prompt_metadata:
status_endpoint = prompt_metadata[prompt_id]['status_endpoint']
status_endpoint = prompt_metadata[prompt_id].status_endpoint
body = {
"run_id": prompt_id,