refactor(plugin): add prompt_metadata types and refactor from dict to data model

This commit is contained in:
bennykok 2024-02-24 23:57:01 -08:00
parent 9d0ded7ecc
commit 25e62af24c

View File

@ -1,34 +1,24 @@
from aiohttp import web from aiohttp import web
import os import os
import requests import requests
import folder_paths import folder_paths
import json import json
import numpy as np
import server import server
import re
import base64
from PIL import Image from PIL import Image
import io
import time import time
import execution import execution
import random import random
import traceback import traceback
import uuid import uuid
import asyncio import asyncio
import atexit
import logging import logging
import sys
from logging.handlers import RotatingFileHandler
from enum import Enum from enum import Enum
from urllib.parse import quote from urllib.parse import quote
import threading import threading
import hashlib import hashlib
import aiohttp import aiohttp
import aiofiles import aiofiles
import concurrent.futures from typing import List, Union, Any, Optional
from typing import List, Union, Any
from PIL import Image from PIL import Image
import copy import copy
@ -40,16 +30,36 @@ class BaseModel(PydanticBaseModel):
class Config: class Config:
arbitrary_types_allowed = True arbitrary_types_allowed = True
class Status(Enum):
NOT_STARTED = "not-started"
RUNNING = "running"
SUCCESS = "success"
FAILED = "failed"
UPLOADING = "uploading"
class StreamingPrompt(BaseModel): class StreamingPrompt(BaseModel):
workflow_api: Any workflow_api: Any
auth_token: str auth_token: str
inputs: dict[str, Union[str, bytes, Image.Image]] inputs: dict[str, Union[str, bytes, Image.Image]]
running_prompt_ids: set[str] = set()
status_endpoint: str
file_upload_endpoint: str
class SimplePrompt(BaseModel):
status_endpoint: str
file_upload_endpoint: str
workflow_api: dict
status: Status = Status.NOT_STARTED
progress: set = set()
last_updated_node: Optional[str] = None,
uploading_nodes: set = set()
done: bool = False
streaming_prompt_metadata: dict[str, StreamingPrompt] = {} streaming_prompt_metadata: dict[str, StreamingPrompt] = {}
api = None api = None
api_task = None api_task = None
prompt_metadata = {} prompt_metadata: dict[str, SimplePrompt] = {}
cd_enable_log = os.environ.get('CD_ENABLE_LOG', 'false').lower() == 'true' cd_enable_log = os.environ.get('CD_ENABLE_LOG', 'false').lower() == 'true'
cd_enable_run_log = os.environ.get('CD_ENABLE_RUN_LOG', 'false').lower() == 'true' cd_enable_run_log = os.environ.get('CD_ENABLE_RUN_LOG', 'false').lower() == 'true'
@ -109,7 +119,7 @@ def send_prompt(sid: str, inputs: StreamingPrompt):
if 'inputs' in workflow_api[key] and 'seed' in workflow_api[key]['inputs']: if 'inputs' in workflow_api[key] and 'seed' in workflow_api[key]['inputs']:
workflow_api[key]['inputs']['seed'] = randomSeed() workflow_api[key]['inputs']['seed'] = randomSeed()
print("getting inputs" ,inputs.inputs) print("getting inputs" , inputs.inputs)
# Loop through each of the inputs and replace them # Loop through each of the inputs and replace them
for key, value in workflow_api.items(): for key, value in workflow_api.items():
@ -137,6 +147,12 @@ def send_prompt(sid: str, inputs: StreamingPrompt):
try: try:
res = post_prompt(prompt) res = post_prompt(prompt)
inputs.running_prompt_ids.add(prompt_id)
prompt_metadata[prompt_id] = SimplePrompt(
status_endpoint=inputs.status_endpoint,
file_upload_endpoint=inputs.file_upload_endpoint,
workflow_api=workflow_api
)
except Exception as e: except Exception as e:
error_type = type(e).__name__ error_type = type(e).__name__
stack_trace_short = traceback.format_exc().strip().split('\n')[-2] stack_trace_short = traceback.format_exc().strip().split('\n')[-2]
@ -164,11 +180,11 @@ async def comfy_deploy_run(request):
"prompt_id": prompt_id "prompt_id": prompt_id
} }
prompt_metadata[prompt_id] = { prompt_metadata[prompt_id] = SimplePrompt(
'status_endpoint': data.get('status_endpoint'), status_endpoint=data.get('status_endpoint'),
'file_upload_endpoint': data.get('file_upload_endpoint'), file_upload_endpoint=data.get('file_upload_endpoint'),
'workflow_api': workflow_api workflow_api=workflow_api
} )
try: try:
res = post_prompt(prompt) res = post_prompt(prompt)
@ -244,7 +260,7 @@ async def compute_sha256_checksum(filepath):
# This is start uploading the files to Comfy Deploy # This is start uploading the files to Comfy Deploy
@server.PromptServer.instance.routes.post('/comfyui-deploy/upload-file') @server.PromptServer.instance.routes.post('/comfyui-deploy/upload-file')
async def upload_file(request): async def upload_file_endpoint(request):
data = await request.json() data = await request.json()
file_path = data.get("file_path") file_path = data.get("file_path")
@ -367,28 +383,31 @@ async def websocket_handler(request):
sockets[sid] = ws sockets[sid] = ws
auth_token = request.rel_url.query.get('token', '') auth_token = request.rel_url.query.get('token', None)
get_workflow_endpoint_url = request.rel_url.query.get('workflow_endpoint', '') get_workflow_endpoint_url = request.rel_url.query.get('workflow_endpoint', None)
async with aiohttp.ClientSession() as session: if auth_token is not None and get_workflow_endpoint_url is not None:
headers = {'Authorization': f'Bearer {auth_token}'} async with aiohttp.ClientSession() as session:
async with session.get(get_workflow_endpoint_url, headers=headers) as response: headers = {'Authorization': f'Bearer {auth_token}'}
if response.status == 200: async with session.get(get_workflow_endpoint_url, headers=headers) as response:
workflow = await response.json() if response.status == 200:
workflow = await response.json()
print("Loaded workflow version ",workflow["version"]) print("Loaded workflow version ",workflow["version"])
streaming_prompt_metadata[sid] = StreamingPrompt( streaming_prompt_metadata[sid] = StreamingPrompt(
workflow_api=workflow["workflow_api"], workflow_api=workflow["workflow_api"],
auth_token=auth_token, auth_token=auth_token,
inputs={} inputs={},
) status_endpoint=request.rel_url.query.get('status_endpoint', None),
file_upload_endpoint=request.rel_url.query.get('file_upload_endpoint', None),
)
# await send("workflow_api", workflow_api, sid) # await send("workflow_api", workflow_api, sid)
else: else:
error_message = await response.text() error_message = await response.text()
print(f"Failed to fetch workflow endpoint. Status: {response.status}, Error: {error_message}") print(f"Failed to fetch workflow endpoint. Status: {response.status}, Error: {error_message}")
# await send("error", {"message": error_message}, sid) # await send("error", {"message": error_message}, sid)
try: try:
# Send initial state to the new client # Send initial state to the new client
@ -422,17 +441,29 @@ async def websocket_handler(request):
@server.PromptServer.instance.routes.get('/comfyui-deploy/check-status') @server.PromptServer.instance.routes.get('/comfyui-deploy/check-status')
async def comfy_deploy_check_status(request): async def comfy_deploy_check_status(request):
prompt_server = server.PromptServer.instance
prompt_id = request.rel_url.query.get('prompt_id', None) prompt_id = request.rel_url.query.get('prompt_id', None)
if prompt_id in prompt_metadata and 'status' in prompt_metadata[prompt_id]: if prompt_id in prompt_metadata:
return web.json_response({ return web.json_response({
"status": prompt_metadata[prompt_id]['status'].value "status": prompt_metadata[prompt_id].status.value
}) })
else: else:
return web.json_response({ return web.json_response({
"message": "prompt_id not found" "message": "prompt_id not found"
}) })
@server.PromptServer.instance.routes.get('/comfyui-deploy/check-ws-status')
async def comfy_deploy_check_ws_status(request):
client_id = request.rel_url.query.get('client_id', None)
if client_id in streaming_prompt_metadata:
remaining_queue = 0 # Initialize remaining queue count
for prompt_id in streaming_prompt_metadata[client_id].running_prompt_ids[client_id]:
prompt_status = prompt_metadata[prompt_id].status
if prompt_status not in [Status.FAILED, Status.SUCCESS]:
remaining_queue += 1 # Increment for each prompt still running
return web.json_response({"remaining_queue": remaining_queue})
else:
return web.json_response({"message": "client_id not found"}, status=404)
async def send(event, data, sid=None): async def send(event, data, sid=None):
try: try:
# message = {"event": event, "data": data} # message = {"event": event, "data": data}
@ -451,8 +482,8 @@ async def send(event, data, sid=None):
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
prompt_server = server.PromptServer.instance prompt_server = server.PromptServer.instance
send_json = prompt_server.send_json send_json = prompt_server.send_json
async def send_json_override(self, event, data, sid=None): async def send_json_override(self, event, data, sid=None):
# print("INTERNAL:", event, data, sid) # print("INTERNAL:", event, data, sid)
prompt_id = data.get('prompt_id') prompt_id = data.get('prompt_id')
@ -476,28 +507,28 @@ async def send_json_override(self, event, data, sid=None):
node = data.get('node') node = data.get('node')
if prompt_id in prompt_metadata: if prompt_id in prompt_metadata:
if 'progress' not in prompt_metadata[prompt_id]: # if 'progress' not in prompt_metadata[prompt_id]:
prompt_metadata[prompt_id]["progress"] = set() # prompt_metadata[prompt_id]["progress"] = set()
prompt_metadata[prompt_id]["progress"].add(node) prompt_metadata[prompt_id].progress.add(node)
calculated_progress = len(prompt_metadata[prompt_id]["progress"]) / len(prompt_metadata[prompt_id]['workflow_api']) calculated_progress = len(prompt_metadata[prompt_id].progress) / len(prompt_metadata[prompt_id].workflow_api)
# print("calculated_progress", calculated_progress) # print("calculated_progress", calculated_progress)
if 'last_updated_node' in prompt_metadata[prompt_id] and prompt_metadata[prompt_id]['last_updated_node'] == node: if prompt_metadata[prompt_id].last_updated_node is not None and prompt_metadata[prompt_id].last_updated_node == node:
return return
prompt_metadata[prompt_id]['last_updated_node'] = node prompt_metadata[prompt_id].last_updated_node = node
class_type = prompt_metadata[prompt_id]['workflow_api'][node]['class_type'] class_type = prompt_metadata[prompt_id].workflow_api[node]['class_type']
print("updating run live status", class_type) print("updating run live status", class_type)
await update_run_live_status(prompt_id, "Executing " + class_type, calculated_progress) await update_run_live_status(prompt_id, "Executing " + class_type, calculated_progress)
if event == 'execution_cached' and data.get('nodes') is not None: if event == 'execution_cached' and data.get('nodes') is not None:
if prompt_id in prompt_metadata: if prompt_id in prompt_metadata:
if 'progress' not in prompt_metadata[prompt_id]: # if 'progress' not in prompt_metadata[prompt_id]:
prompt_metadata[prompt_id]["progress"] = set() # prompt_metadata[prompt_id].progress = set()
if 'nodes' in data: if 'nodes' in data:
for node in data.get('nodes', []): for node in data.get('nodes', []):
prompt_metadata[prompt_id]["progress"].add(node) prompt_metadata[prompt_id].progress.add(node)
# prompt_metadata[prompt_id]["progress"].update(data.get('nodes')) # prompt_metadata[prompt_id]["progress"].update(data.get('nodes'))
if event == 'execution_error': if event == 'execution_error':
@ -511,13 +542,6 @@ async def send_json_override(self, event, data, sid=None):
# await update_run_with_output(prompt_id, data.get('output'), node_id=data.get('node')) # await update_run_with_output(prompt_id, data.get('output'), node_id=data.get('node'))
# update_run_with_output(prompt_id, data.get('output')) # update_run_with_output(prompt_id, data.get('output'))
class Status(Enum):
NOT_STARTED = "not-started"
RUNNING = "running"
SUCCESS = "success"
FAILED = "failed"
UPLOADING = "uploading"
# Global variable to keep track of the last read line number # Global variable to keep track of the last read line number
last_read_line_number = 0 last_read_line_number = 0
@ -527,7 +551,7 @@ async def update_run_live_status(prompt_id, live_status, calculated_progress: fl
print("progress", calculated_progress) print("progress", calculated_progress)
status_endpoint = prompt_metadata[prompt_id]['status_endpoint'] status_endpoint = prompt_metadata[prompt_id].status_endpoint
body = { body = {
"run_id": prompt_id, "run_id": prompt_id,
"live_status": live_status, "live_status": live_status,
@ -539,19 +563,19 @@ async def update_run_live_status(prompt_id, live_status, calculated_progress: fl
pass pass
def update_run(prompt_id, status: Status): def update_run(prompt_id: str, status: Status):
global last_read_line_number global last_read_line_number
if prompt_id not in prompt_metadata: if prompt_id not in prompt_metadata:
return return
if ('status' not in prompt_metadata[prompt_id] or prompt_metadata[prompt_id]['status'] != status): if (prompt_metadata[prompt_id].status != status):
# when the status is already failed, we don't want to update it to success # when the status is already failed, we don't want to update it to success
if ('status' in prompt_metadata[prompt_id] and prompt_metadata[prompt_id]['status'] == Status.FAILED): if ('status' in prompt_metadata[prompt_id] and prompt_metadata[prompt_id].status == Status.FAILED):
return return
status_endpoint = prompt_metadata[prompt_id]['status_endpoint'] status_endpoint = prompt_metadata[prompt_id].status_endpoint
body = { body = {
"run_id": prompt_id, "run_id": prompt_id,
"status": status.value, "status": status.value,
@ -598,7 +622,7 @@ def update_run(prompt_id, status: Status):
stack_trace = traceback.format_exc().strip() stack_trace = traceback.format_exc().strip()
print(f"Error occurred while updating run: {e} {stack_trace}") print(f"Error occurred while updating run: {e} {stack_trace}")
finally: finally:
prompt_metadata[prompt_id]['status'] = status prompt_metadata[prompt_id].status = status
async def upload_file(prompt_id, filename, subfolder=None, content_type="image/png", type="output"): async def upload_file(prompt_id, filename, subfolder=None, content_type="image/png", type="output"):
@ -630,7 +654,7 @@ async def upload_file(prompt_id, filename, subfolder=None, content_type="image/p
print("uploading file", file) print("uploading file", file)
file_upload_endpoint = prompt_metadata[prompt_id]['file_upload_endpoint'] file_upload_endpoint = prompt_metadata[prompt_id].file_upload_endpoint
filename = quote(filename) filename = quote(filename)
prompt_id = quote(prompt_id) prompt_id = quote(prompt_id)
@ -654,22 +678,36 @@ async def upload_file(prompt_id, filename, subfolder=None, content_type="image/p
print("upload file response", response.status) print("upload file response", response.status)
def have_pending_upload(prompt_id): def have_pending_upload(prompt_id):
if 'prompt_id' in prompt_metadata and 'uploading_nodes' in prompt_metadata[prompt_id] and len(prompt_metadata[prompt_id]['uploading_nodes']) > 0: if prompt_id in prompt_metadata and len(prompt_metadata[prompt_id].uploading_nodes) > 0:
print("have pending upload ", len(prompt_metadata[prompt_id]['uploading_nodes'])) print("have pending upload ", len(prompt_metadata[prompt_id].uploading_nodes))
return True return True
print("no pending upload") print("no pending upload")
return False return False
def mark_prompt_done(prompt_id): def mark_prompt_done(prompt_id):
"""
Mark the prompt as done in the prompt metadata.
Args:
prompt_id (str): The ID of the prompt to mark as done.
"""
if prompt_id in prompt_metadata: if prompt_id in prompt_metadata:
prompt_metadata[prompt_id]["done"] = True prompt_metadata[prompt_id].done = True
print("Prompt done") print("Prompt done")
def is_prompt_done(prompt_id): def is_prompt_done(prompt_id: str):
if prompt_id in prompt_metadata and "done" in prompt_metadata[prompt_id]: """
if prompt_metadata[prompt_id]["done"] == True: Check if the prompt with the given ID is marked as done.
return True
Args:
prompt_id (str): The ID of the prompt to check.
Returns:
bool: True if the prompt is marked as done, False otherwise.
"""
if prompt_id in prompt_metadata and prompt_metadata[prompt_id].done:
return True
return False return False
@ -692,17 +730,17 @@ async def handle_error(prompt_id, data, e: Exception):
print(f"Error occurred while uploading file: {e}") print(f"Error occurred while uploading file: {e}")
# Mark the current prompt requires upload, and block it from being marked as success # Mark the current prompt requires upload, and block it from being marked as success
async def update_file_status(prompt_id, data, uploading, have_error=False, node_id=None): async def update_file_status(prompt_id: str, data, uploading, have_error=False, node_id=None):
if 'uploading_nodes' not in prompt_metadata[prompt_id]: # if 'uploading_nodes' not in prompt_metadata[prompt_id]:
prompt_metadata[prompt_id]['uploading_nodes'] = set() # prompt_metadata[prompt_id]['uploading_nodes'] = set()
if node_id is not None: if node_id is not None:
if uploading: if uploading:
prompt_metadata[prompt_id]['uploading_nodes'].add(node_id) prompt_metadata[prompt_id].uploading_nodes.add(node_id)
else: else:
prompt_metadata[prompt_id]['uploading_nodes'].discard(node_id) prompt_metadata[prompt_id].uploading_nodes.discard(node_id)
print(prompt_metadata[prompt_id]['uploading_nodes']) print(prompt_metadata[prompt_id].uploading_nodes)
# Update the remote status # Update the remote status
if have_error: if have_error:
@ -714,7 +752,7 @@ async def update_file_status(prompt_id, data, uploading, have_error=False, node_
# if there are still nodes that are uploading, then we set the status to uploading # if there are still nodes that are uploading, then we set the status to uploading
if uploading: if uploading:
if prompt_metadata[prompt_id]['status'] != Status.UPLOADING: if prompt_metadata[prompt_id].status != Status.UPLOADING:
update_run(prompt_id, Status.UPLOADING) update_run(prompt_id, Status.UPLOADING)
await send("uploading", { await send("uploading", {
"prompt_id": prompt_id, "prompt_id": prompt_id,
@ -728,7 +766,7 @@ async def update_file_status(prompt_id, data, uploading, have_error=False, node_
"prompt_id": prompt_id, "prompt_id": prompt_id,
}) })
async def handle_upload(prompt_id, data, key, content_type_key, default_content_type): async def handle_upload(prompt_id: str, data, key: str, content_type_key: str, default_content_type: str):
items = data.get(key, []) items = data.get(key, [])
for item in items: for item in items:
await upload_file( await upload_file(
@ -741,7 +779,7 @@ async def handle_upload(prompt_id, data, key, content_type_key, default_content_
# Upload files in the background # Upload files in the background
async def upload_in_background(prompt_id, data, node_id=None, have_upload=True): async def upload_in_background(prompt_id: str, data, node_id=None, have_upload=True):
try: try:
await handle_upload(prompt_id, data, 'images', "content_type", "image/png") await handle_upload(prompt_id, data, 'images', "content_type", "image/png")
await handle_upload(prompt_id, data, 'files', "content_type", "image/png") await handle_upload(prompt_id, data, 'files', "content_type", "image/png")
@ -755,7 +793,7 @@ async def upload_in_background(prompt_id, data, node_id=None, have_upload=True):
async def update_run_with_output(prompt_id, data, node_id=None): async def update_run_with_output(prompt_id, data, node_id=None):
if prompt_id in prompt_metadata: if prompt_id in prompt_metadata:
status_endpoint = prompt_metadata[prompt_id]['status_endpoint'] status_endpoint = prompt_metadata[prompt_id].status_endpoint
body = { body = {
"run_id": prompt_id, "run_id": prompt_id,