Streaming support (#52)

* feat: add streaming endpoint

* fix: run issues

* feat(plugin): add dispatchAPIEventData

* fix(plugin): event

* fix: streaming event format

* fix: prompt error

* fix: node_error proxy

* chore(plugin): add log

* custom route

---------

Co-authored-by: nick <kobenkao@gmail.com>
BennyKok 2024-07-11 20:03:41 -07:00 committed by GitHub
parent 716790e344
commit a2ac1adf01
4 changed files with 413 additions and 151 deletions

View File

@@ -18,13 +18,47 @@ import threading
import hashlib
import aiohttp
import aiofiles
from typing import List, Union, Any, Optional
from typing import Dict, List, Union, Any, Optional
from PIL import Image
import copy
import struct
from logging import basicConfig, getLogger
import logfire
# if os.environ.get('LOGFIRE_TOKEN', None) is not None:
logfire.configure(
send_to_logfire="if-token-present"
)
# basicConfig(handlers=[logfire.LogfireLoggingHandler()])
logfire_handler = logfire.LogfireLoggingHandler()
logger = getLogger("comfy-deploy")
logger.addHandler(logfire_handler)
from globals import StreamingPrompt, Status, sockets, SimplePrompt, streaming_prompt_metadata, prompt_metadata
class EventEmitter:
def __init__(self):
self.listeners = {}
def on(self, event, listener):
if event not in self.listeners:
self.listeners[event] = []
self.listeners[event].append(listener)
def off(self, event, listener):
if event in self.listeners:
self.listeners[event].remove(listener)
if not self.listeners[event]:
del self.listeners[event]
def emit(self, event, *args, **kwargs):
if event in self.listeners:
for listener in self.listeners[event]:
listener(*args, **kwargs)
# Create a global event emitter instance
event_emitter = EventEmitter()
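The EventEmitter above is a minimal pub/sub helper; a hypothetical exercise of its on/emit/off API (the event name and handler are illustrative only, not part of this commit):

def log_listener(payload):
    print("received:", payload)

event_emitter.on("send_json", log_listener)
event_emitter.emit("send_json", {"event": "status", "data": {}})
event_emitter.off("send_json", log_listener)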
api = None
api_task = None
@@ -32,18 +66,18 @@ cd_enable_log = os.environ.get('CD_ENABLE_LOG', 'false').lower() == 'true'
cd_enable_run_log = os.environ.get('CD_ENABLE_RUN_LOG', 'false').lower() == 'true'
bypass_upload = os.environ.get('CD_BYPASS_UPLOAD', 'false').lower() == 'true'
print("CD_BYPASS_UPLOAD", bypass_upload)
logger.info(f"CD_BYPASS_UPLOAD {bypass_upload}")
def clear_current_prompt(sid):
prompt_server = server.PromptServer.instance
to_delete = list(streaming_prompt_metadata[sid].running_prompt_ids) # Convert set to list
print("clearning out prompt: ", to_delete)
logger.info("clearning out prompt: ", to_delete)
for id_to_delete in to_delete:
delete_func = lambda a: a[1] == id_to_delete
prompt_server.prompt_queue.delete_queue_item(delete_func)
print("deleted prompt: ", id_to_delete, prompt_server.prompt_queue.get_tasks_remaining())
logger.info("deleted prompt: ", id_to_delete, prompt_server.prompt_queue.get_tasks_remaining())
streaming_prompt_metadata[sid].running_prompt_ids.clear()
@@ -84,7 +118,7 @@ def post_prompt(json_data):
}
return response
else:
print("invalid prompt:", valid[1])
logger.info("invalid prompt:", valid[1])
return {"error": valid[1], "node_errors": valid[3]}
else:
return {"error": "no prompt", "node_errors": []}
@@ -158,11 +192,11 @@ def send_prompt(sid: str, inputs: StreamingPrompt):
# Random seed
apply_random_seed_to_workflow(workflow_api)
print("getting inputs" , inputs.inputs)
logger.info("getting inputs" , inputs.inputs)
apply_inputs_to_workflow(workflow_api, inputs.inputs, sid=sid)
print(workflow_api)
logger.info(workflow_api)
prompt_id = str(uuid.uuid4())
@@ -185,12 +219,11 @@ def send_prompt(sid: str, inputs: StreamingPrompt):
error_type = type(e).__name__
stack_trace_short = traceback.format_exc().strip().split('\n')[-2]
stack_trace = traceback.format_exc().strip()
print(f"error: {error_type}, {e}")
print(f"stack trace: {stack_trace_short}")
logger.info(f"error: {error_type}, {e}")
logger.info(f"stack trace: {stack_trace_short}")
@server.PromptServer.instance.routes.post("/comfyui-deploy/run")
async def comfy_deploy_run(request):
prompt_server = server.PromptServer.instance
data = await request.json()
# In older versions we used workflow_api, but that already has inputs swapped in by the Next.js frontend, which is tricky
@@ -221,8 +254,8 @@ async def comfy_deploy_run(request):
error_type = type(e).__name__
stack_trace_short = traceback.format_exc().strip().split('\n')[-2]
stack_trace = traceback.format_exc().strip()
print(f"error: {error_type}, {e}")
print(f"stack trace: {stack_trace_short}")
logger.info(f"error: {error_type}, {e}")
logger.info(f"stack trace: {stack_trace_short}")
await update_run_with_output(prompt_id, {
"error": {
"error_type": error_type,
@@ -234,13 +267,6 @@ async def comfy_deploy_run(request):
return web.Response(status=500, reason=f"{error_type}: {e}, {stack_trace_short}")
status = 200
# if "error" in res:
# status = 400
# await update_run_with_output(prompt_id, {
# "error": {
# **res
# }
# })
if "node_errors" in res and res["node_errors"]:
# Even though there are node_errors, it can still run
@@ -257,24 +283,134 @@ async def comfy_deploy_run(request):
return web.json_response(res, status=status)
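For context, a minimal client sketch for queueing a run against this endpoint; the payload keys mirror what stream_prompt() reads below, and the host, port, and None endpoints are assumptions for illustration:

import asyncio
import uuid
import aiohttp

async def queue_run():
    payload = {
        "workflow_api_raw": {},        # exported workflow API graph (assumed shape)
        "prompt_id": str(uuid.uuid4()),
        "inputs": {},                  # external input overrides
        "status_endpoint": None,       # optional status callback URL
        "file_upload_endpoint": None,  # optional upload URL
    }
    async with aiohttp.ClientSession() as session:
        async with session.post("http://127.0.0.1:8188/comfyui-deploy/run", json=payload) as resp:
            print(resp.status, await resp.json())

asyncio.run(queue_run())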
async def stream_prompt(data):
# In older versions we used workflow_api, but that already has inputs swapped in by the Next.js frontend, which is tricky
workflow_api = data.get("workflow_api_raw")
# The prompt id generated from Comfy Deploy; can be None
prompt_id = data.get("prompt_id")
inputs = data.get("inputs")
# Now it is handled directly here
apply_random_seed_to_workflow(workflow_api)
apply_inputs_to_workflow(workflow_api, inputs)
prompt = {
"prompt": workflow_api,
"client_id": "comfy_deploy_instance", #api.client_id
"prompt_id": prompt_id
}
prompt_metadata[prompt_id] = SimplePrompt(
status_endpoint=data.get('status_endpoint'),
file_upload_endpoint=data.get('file_upload_endpoint'),
workflow_api=workflow_api
)
logfire.info("Begin prompt", prompt=prompt)
try:
res = post_prompt(prompt)
except Exception as e:
error_type = type(e).__name__
stack_trace_short = traceback.format_exc().strip().split('\n')[-2]
stack_trace = traceback.format_exc().strip()
logger.info(f"error: {error_type}, {e}")
logger.info(f"stack trace: {stack_trace_short}")
await update_run_with_output(prompt_id, {
"error": {
"error_type": error_type,
"stack_trace": stack_trace
}
})
# When there are critical errors, the prompt is actually not run
await update_run(prompt_id, Status.FAILED)
# return web.Response(status=500, reason=f"{error_type}: {e}, {stack_trace_short}")
# raise Exception("Prompt failed")
status = 200
if "node_errors" in res and res["node_errors"]:
# Even though there are node_errors, it can still run
status = 400
await update_run_with_output(prompt_id, {
"error": {
**res
}
})
# When there are critical errors, the prompt is actually not run
if "error" in res:
await update_run(prompt_id, Status.FAILED)
# raise Exception("Prompt failed")
return res
# return web.json_response(res, status=status)
comfy_message_queues: Dict[str, asyncio.Queue] = {}
@server.PromptServer.instance.routes.post('/comfyui-deploy/run/streaming')
async def stream_response(request):
response = web.StreamResponse(status=200, reason='OK', headers={'Content-Type': 'text/event-stream'})
await response.prepare(request)
pending = True
data = await request.json()
prompt_id = data.get("prompt_id")
comfy_message_queues[prompt_id] = asyncio.Queue()
with logfire.span('Streaming Run'):
logfire.info('Streaming prompt')
try:
result = await stream_prompt(data=data)
await response.write(f"event: event_update\ndata: {json.dumps(result)}\n\n".encode('utf-8'))
# await response.write(.encode('utf-8'))
await response.drain() # Ensure the buffer is flushed
while pending:
if prompt_id in comfy_message_queues:
if not comfy_message_queues[prompt_id].empty():
data = await comfy_message_queues[prompt_id].get()
logfire.info(data["event"], data=json.dumps(data))
# logger.info("listener", data)
await response.write(f"event: event_update\ndata: {json.dumps(data)}\n\n".encode('utf-8'))
await response.drain() # Ensure the buffer is flushed
if data["event"] == "status":
if data["data"]["status"] in (Status.FAILED.value, Status.SUCCESS.value):
pending = False
await asyncio.sleep(0.1) # Adjust the sleep duration as needed
except asyncio.CancelledError:
logfire.info("Streaming was cancelled")
raise
except Exception as e:
logfire.error("Streaming error", error=e)
finally:
# event_emitter.off("send_json", task)
await response.write_eof()
comfy_message_queues.pop(prompt_id, None)
return response
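The streaming route above writes "event: event_update" SSE frames until it sees a status event with SUCCESS or FAILED; a sketch of a consumer (the URL and payload shape are assumptions, matching the non-streaming example above; the payload must include a prompt_id so the server can key its message queue):

import aiohttp

async def stream_run(payload: dict):
    async with aiohttp.ClientSession() as session:
        async with session.post(
            "http://127.0.0.1:8188/comfyui-deploy/run/streaming", json=payload
        ) as resp:
            # Frames arrive as "event: event_update\ndata: {...}\n\n".
            async for raw in resp.content:
                line = raw.decode("utf-8").strip()
                if line.startswith("data: "):
                    print(line[len("data: "):])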
def get_comfyui_path_from_file_path(file_path):
file_path_parts = file_path.split("\\")
if file_path_parts[0] == "input":
print("matching input")
logger.info("matching input")
file_path = os.path.join(folder_paths.get_directory_by_type("input"), *file_path_parts[1:])
elif file_path_parts[0] == "models":
print("matching models")
logger.info("matching models")
file_path = folder_paths.get_full_path(file_path_parts[1], os.path.join(*file_path_parts[2:]))
print(file_path)
logger.info(file_path)
return file_path
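Illustrative call: a backslash-delimited relative path such as "input\\example.png" (filename hypothetical) resolves into the corresponding ComfyUI directory:

resolved = get_comfyui_path_from_file_path("input\\example.png")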
# From ComfyUI Manager
async def compute_sha256_checksum(filepath):
print("computing sha256 checksum")
logger.info("computing sha256 checksum")
chunk_size = 1024 * 256 # Example: 256KB
filepath = get_comfyui_path_from_file_path(filepath)
"""Compute the SHA256 checksum of a file, in chunks, asynchronously"""
@@ -297,7 +433,7 @@ async def get_installed_models(request):
file_list = folder_paths.get_filename_list(key)
value_json_compatible = (value[0], list(value[1]), file_list)
new_dict[key] = value_json_compatible
# print(new_dict)
# logger.info(new_dict)
return web.json_response(new_dict)
# This starts uploading the files to Comfy Deploy
@@ -307,7 +443,7 @@ async def upload_file_endpoint(request):
file_path = data.get("file_path")
print("Original file path", file_path)
logger.info("Original file path", file_path)
file_path = get_comfyui_path_from_file_path(file_path)
@@ -429,7 +565,7 @@ async def get_file_hash(request):
file_hash = await compute_sha256_checksum(full_file_path)
end_time = time.time()
elapsed_time = end_time - start_time
print(f"Cache miss -> Execution time: {elapsed_time} seconds")
logger.info(f"Cache miss -> Execution time: {elapsed_time} seconds")
# Update the in-memory cache
file_hash_cache[full_file_path] = file_hash
@@ -449,6 +585,8 @@ async def update_realtime_run_status(realtime_id: str, status_endpoint: str, sta
"run_id": realtime_id,
"status": status.value,
}
if (status_endpoint is None):
return
# requests.post(status_endpoint, json=body)
async with aiohttp.ClientSession() as session:
async with session.post(status_endpoint, json=body) as response:
@@ -479,7 +617,7 @@ async def websocket_handler(request):
if response.status == 200:
workflow = await response.json()
print("Loaded workflow version ",workflow["version"])
logger.info(f"Loaded workflow version ${workflow['version']}")
streaming_prompt_metadata[sid] = StreamingPrompt(
workflow_api=workflow["workflow_api"],
@@ -493,7 +631,7 @@ async def websocket_handler(request):
# await send("workflow_api", workflow_api, sid)
else:
error_message = await response.text()
print(f"Failed to fetch workflow endpoint. Status: {response.status}, Error: {error_message}")
logger.info(f"Failed to fetch workflow endpoint. Status: {response.status}, Error: {error_message}")
# await send("error", {"message": error_message}, sid)
try:
@@ -508,10 +646,10 @@ async def websocket_handler(request):
if msg.type == aiohttp.WSMsgType.TEXT:
try:
data = json.loads(msg.data)
print(data)
logger.info(data)
event_type = data.get('event')
if event_type == 'input':
print("Got input: ", data.get("inputs"))
logger.info(f"Got input: ${data.get('inputs')}")
input = data.get('inputs')
streaming_prompt_metadata[sid].inputs.update(input)
elif event_type == 'queue_prompt':
@@ -521,7 +659,7 @@ async def websocket_handler(request):
# Handle other event types
pass
except json.JSONDecodeError:
print('Failed to decode JSON from message')
logger.info('Failed to decode JSON from message')
if msg.type == aiohttp.WSMsgType.BINARY:
data = msg.data
@@ -530,9 +668,9 @@ async def websocket_handler(request):
image_type_code, = struct.unpack("<I", data[4:8])
input_id_bytes = data[8:32] # Extract the next 24 bytes for the input ID
input_id = input_id_bytes.decode('ascii').strip() # Decode the input ID from ASCII
print(event_type)
print(image_type_code)
print(input_id)
logger.info(event_type)
logger.info(image_type_code)
logger.info(input_id)
image_data = data[32:] # The rest is the image data
if image_type_code == 1:
image_type = "JPEG"
@@ -541,7 +679,7 @@ async def websocket_handler(request):
elif image_type_code == 3:
image_type = "WEBP"
else:
print("Unknown image type code:", image_type_code)
logger.info(f"Unknown image type code: ${image_type_code}")
return
image = Image.open(BytesIO(image_data))
# Check if the input ID already exists and replace the input with the new one
@@ -552,14 +690,14 @@ async def websocket_handler(request):
if hasattr(existing_image, 'close'):
existing_image.close()
except Exception as e:
print(f"Error closing previous image for input ID {input_id}: {e}")
logger.info(f"Error closing previous image for input ID {input_id}: {e}")
streaming_prompt_metadata[sid].inputs[input_id] = image
# clear_current_prompt(sid)
# send_prompt(sid, streaming_prompt_metadata[sid])
print(f"Received {image_type} image of size {image.size} with input ID {input_id}")
logger.info(f"Received {image_type} image of size {image.size} with input ID {input_id}")
if msg.type == aiohttp.WSMsgType.ERROR:
print('ws connection closed with exception %s' % ws.exception())
logger.info('ws connection closed with exception %s' % ws.exception())
finally:
sockets.pop(sid, None)
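For reference, the binary frame parsed in the handler above is: a little-endian uint32 event type, a uint32 image type code (1=JPEG and 3=WEBP per the code; 2 is elided by the hunk and assumed to be PNG), a 24-byte ASCII input ID, then raw image bytes. A hypothetical client-side packer (the event type value is an assumption):

import struct

def pack_input_image(input_id: str, image_bytes: bytes,
                     image_type_code: int = 2, event_type: int = 0) -> bytes:
    header = struct.pack("<II", event_type, image_type_code)  # bytes 0-7
    id_field = input_id.encode("ascii").ljust(24)[:24]        # bytes 8-31, space-padded
    return header + id_field + image_bytes                    # image data from byte 32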
@@ -604,16 +742,16 @@ async def send(event, data, sid=None):
if not ws.closed: # Check if the WebSocket connection is open and not closing
await ws.send_json({ 'event': event, 'data': data })
except Exception as e:
print(f"Exception: {e}")
logger.info(f"Exception: {e}")
traceback.print_exc()
logging.basicConfig(level=logging.INFO)
prompt_server = server.PromptServer.instance
send_json = prompt_server.send_json
async def send_json_override(self, event, data, sid=None):
# print("INTERNAL:", event, data, sid)
# logger.info("INTERNAL:", event, data, sid)
prompt_id = data.get('prompt_id')
target_sid = sid
@@ -626,8 +764,19 @@ async def send_json_override(self, event, data, sid=None):
asyncio.create_task(self.send_json_original(event, data, sid))
])
if prompt_id in comfy_message_queues:
comfy_message_queues[prompt_id].put_nowait({
"event": event,
"data": data
})
# event_emitter.emit("send_json", {
# "event": event,
# "data": data
# })
if event == 'execution_start':
update_run(prompt_id, Status.RUNNING)
await update_run(prompt_id, Status.RUNNING)
if prompt_id in prompt_metadata:
prompt_metadata[prompt_id].start_time = time.perf_counter()
@@ -636,12 +785,12 @@ async def send_json_override(self, event, data, sid=None):
if event == 'executing' and data.get('node') is None:
mark_prompt_done(prompt_id=prompt_id)
if not have_pending_upload(prompt_id):
update_run(prompt_id, Status.SUCCESS)
await update_run(prompt_id, Status.SUCCESS)
if prompt_id in prompt_metadata:
current_time = time.perf_counter()
if prompt_metadata[prompt_id].start_time is not None:
elapsed_time = current_time - prompt_metadata[prompt_id].start_time
print(f"Elapsed time: {elapsed_time} seconds")
logger.info(f"Elapsed time: {elapsed_time} seconds")
await send("elapsed_time", {
"prompt_id": prompt_id,
"elapsed_time": elapsed_time
@@ -656,13 +805,13 @@ async def send_json_override(self, event, data, sid=None):
prompt_metadata[prompt_id].progress.add(node)
calculated_progress = len(prompt_metadata[prompt_id].progress) / len(prompt_metadata[prompt_id].workflow_api)
# print("calculated_progress", calculated_progress)
# logger.info("calculated_progress", calculated_progress)
if prompt_metadata[prompt_id].last_updated_node is not None and prompt_metadata[prompt_id].last_updated_node == node:
return
prompt_metadata[prompt_id].last_updated_node = node
class_type = prompt_metadata[prompt_id].workflow_api[node]['class_type']
print("updating run live status", class_type)
logger.info(f"updating run live status {class_type}")
await send("live_status", {
"prompt_id": prompt_id,
"current_node": class_type,
@@ -683,17 +832,17 @@ async def send_json_override(self, event, data, sid=None):
if event == 'execution_error':
# Careful this might not be fully awaited.
await update_run_with_output(prompt_id, data)
update_run(prompt_id, Status.FAILED)
await update_run(prompt_id, Status.FAILED)
# await update_run_with_output(prompt_id, data)
if event == 'executed' and 'node' in data and 'output' in data:
print("executed", data)
logger.info(f"executed {data}")
if prompt_id in prompt_metadata:
node = data.get('node')
class_type = prompt_metadata[prompt_id].workflow_api[node]['class_type']
print("executed", class_type)
logger.info(f"executed {class_type}")
if class_type == "PreviewImage":
print("skipping preview image")
logger.info("skipping preview image")
return
await update_run_with_output(prompt_id, data.get('output'), node_id=data.get('node'))
@@ -710,21 +859,36 @@ async def update_run_live_status(prompt_id, live_status, calculated_progress: fl
if prompt_metadata[prompt_id].is_realtime is True:
return
print("progress", calculated_progress)
status_endpoint = prompt_metadata[prompt_id].status_endpoint
if (status_endpoint is None):
return
logger.info(f"progress {calculated_progress}")
body = {
"run_id": prompt_id,
"live_status": live_status,
"progress": calculated_progress
}
if prompt_id in comfy_message_queues:
comfy_message_queues[prompt_id].put_nowait({
"event": "live_status",
"data": {
"prompt_id": prompt_id,
"live_status": live_status,
"progress": calculated_progress
}
})
# requests.post(status_endpoint, json=body)
async with aiohttp.ClientSession() as session:
async with session.post(status_endpoint, json=body) as response:
pass
def update_run(prompt_id: str, status: Status):
async def update_run(prompt_id: str, status: Status):
global last_read_line_number
if prompt_id not in prompt_metadata:
@@ -747,18 +911,22 @@ def update_run(prompt_id: str, status: Status):
"run_id": prompt_id,
"status": status.value,
}
print(f"Status: {status.value}")
logger.info(f"Status: {status.value}")
try:
requests.post(status_endpoint, json=body)
# requests.post(status_endpoint, json=body)
if (status_endpoint is not None):
async with aiohttp.ClientSession() as session:
async with session.post(status_endpoint, json=body) as response:
pass
if cd_enable_run_log and (status == Status.SUCCESS or status == Status.FAILED):
if (status_endpoint is not None) and cd_enable_run_log and (status == Status.SUCCESS or status == Status.FAILED):
try:
with open(comfyui_file_path, 'r') as log_file:
# log_data = log_file.read()
# Move to the last read line
all_log_data = log_file.read() # Read all log data
print("All log data before skipping:", all_log_data) # Log all data before skipping
# logger.info("All log data before skipping: ") # Log all data before skipping
log_file.seek(0) # Reset file pointer to the beginning
for _ in range(last_read_line_number):
@@ -766,9 +934,9 @@ def update_run(prompt_id: str, status: Status):
log_data = log_file.read()
# Update the last read line number
last_read_line_number += log_data.count('\n')
print("last_read_line_number", last_read_line_number)
print("log_data", log_data)
print("log_data.count(n)", log_data.count('\n'))
# logger.info("last_read_line_number", last_read_line_number)
# logger.info("log_data", log_data)
# logger.info("log_data.count(n)", log_data.count('\n'))
body = {
"run_id": prompt_id,
@@ -779,16 +947,28 @@ def update_run(prompt_id: str, status: Status):
}
]
}
requests.post(status_endpoint, json=body)
async with aiohttp.ClientSession() as session:
async with session.post(status_endpoint, json=body) as response:
pass
# requests.post(status_endpoint, json=body)
except Exception as log_error:
print(f"Error reading log file: {log_error}")
logger.info(f"Error reading log file: {log_error}")
except Exception as e:
error_type = type(e).__name__
stack_trace = traceback.format_exc().strip()
print(f"Error occurred while updating run: {e} {stack_trace}")
logger.info(f"Error occurred while updating run: {e} {stack_trace}")
finally:
prompt_metadata[prompt_id].status = status
if prompt_id in comfy_message_queues:
comfy_message_queues[prompt_id].put_nowait({
"event": "status",
"data": {
"prompt_id": prompt_id,
"status": status.value,
}
})
async def upload_file(prompt_id, filename, subfolder=None, content_type="image/png", type="output"):
@@ -806,7 +986,7 @@ async def upload_file(prompt_id, filename, subfolder=None, content_type="image/p
output_dir = folder_paths.get_directory_by_type(type)
if output_dir is None:
print(filename, "Upload failed: output_dir is None")
logger.info(f"{filename} Upload failed: output_dir is None")
return
if subfolder != None:
@@ -818,7 +998,7 @@ async def upload_file(prompt_id, filename, subfolder=None, content_type="image/p
filename = os.path.basename(filename)
file = os.path.join(output_dir, filename)
print("uploading file", file)
logger.info(f"uploading file {file}")
file_upload_endpoint = prompt_metadata[prompt_id].file_upload_endpoint
@@ -831,7 +1011,7 @@ async def upload_file(prompt_id, filename, subfolder=None, content_type="image/p
start_time = time.time() # Start timing here
result = requests.get(target_url)
end_time = time.time() # End timing after the request is complete
print("Time taken for getting file upload endpoint: {:.2f} seconds".format(end_time - start_time))
logger.info("Time taken for getting file upload endpoint: {:.2f} seconds".format(end_time - start_time))
ok = result.json()
start_time = time.time() # Start timing here
@@ -846,16 +1026,16 @@ async def upload_file(prompt_id, filename, subfolder=None, content_type="image/p
# response = requests.put(ok.get("url"), headers=headers, data=data)
async with aiohttp.ClientSession() as session:
async with session.put(ok.get("url"), headers=headers, data=data) as response:
print("Upload file response", response.status)
logger.info(f"Upload file response status: {response.status}, status text: {response.reason}")
end_time = time.time() # End timing after the request is complete
print("Upload time: {:.2f} seconds".format(end_time - start_time))
logger.info("Upload time: {:.2f} seconds".format(end_time - start_time))
def have_pending_upload(prompt_id):
if prompt_id in prompt_metadata and len(prompt_metadata[prompt_id].uploading_nodes) > 0:
print("have pending upload ", len(prompt_metadata[prompt_id].uploading_nodes))
logger.info(f"have pending upload {len(prompt_metadata[prompt_id].uploading_nodes)}")
return True
print("no pending upload")
logger.info("no pending upload")
return False
def mark_prompt_done(prompt_id):
@@ -867,7 +1047,7 @@ def mark_prompt_done(prompt_id):
"""
if prompt_id in prompt_metadata:
prompt_metadata[prompt_id].done = True
print("Prompt done")
logger.info("Prompt done")
def is_prompt_done(prompt_id: str):
"""
@@ -899,8 +1079,8 @@ async def handle_error(prompt_id, data, e: Exception):
}
}
await update_file_status(prompt_id, data, False, have_error=True)
print(body)
print(f"Error occurred while uploading file: {e}")
logger.info(body)
logger.info(f"Error occurred while uploading file: {e}")
# Mark that the current prompt requires upload, and block it from being marked as success
async def update_file_status(prompt_id: str, data, uploading, have_error=False, node_id=None):
@@ -913,11 +1093,11 @@ async def update_file_status(prompt_id: str, data, uploading, have_error=False,
else:
prompt_metadata[prompt_id].uploading_nodes.discard(node_id)
print(prompt_metadata[prompt_id].uploading_nodes)
logger.info(prompt_metadata[prompt_id].uploading_nodes)
# Update the remote status
if have_error:
update_run(prompt_id, Status.FAILED)
await update_run(prompt_id, Status.FAILED)
await send("failed", {
"prompt_id": prompt_id,
})
@@ -926,15 +1106,15 @@ async def update_file_status(prompt_id: str, data, uploading, have_error=False,
# if there are still nodes that are uploading, then we set the status to uploading
if uploading:
if prompt_metadata[prompt_id].status != Status.UPLOADING:
update_run(prompt_id, Status.UPLOADING)
await update_run(prompt_id, Status.UPLOADING)
await send("uploading", {
"prompt_id": prompt_id,
})
# if there are no nodes that are uploading, then we set the status to success
elif not uploading and not have_pending_upload(prompt_id) and is_prompt_done(prompt_id=prompt_id):
update_run(prompt_id, Status.SUCCESS)
# print("Status: SUCCUSS")
await update_run(prompt_id, Status.SUCCESS)
# logger.info("Status: SUCCUSS")
await send("success", {
"prompt_id": prompt_id,
})
@@ -997,7 +1177,7 @@ async def update_run_with_output(prompt_id, data, node_id=None):
if have_upload_media:
try:
print("\nhave_upload", have_upload_media, node_id)
logger.info(f"\nhave_upload {have_upload} {node_id}")
if have_upload_media:
await update_file_status(prompt_id, data, True, node_id=node_id)
@@ -1008,7 +1188,11 @@ async def update_run_with_output(prompt_id, data, node_id=None):
except Exception as e:
await handle_error(prompt_id, data, e)
requests.post(status_endpoint, json=body)
# requests.post(status_endpoint, json=body)
if status_endpoint is not None:
async with aiohttp.ClientSession() as session:
async with session.post(status_endpoint, json=body) as response:
pass
await send('outputs_uploaded', {
"prompt_id": prompt_id

View File

@@ -22,12 +22,13 @@ class StreamingPrompt(BaseModel):
auth_token: str
inputs: dict[str, Union[str, bytes, Image.Image]]
running_prompt_ids: set[str] = set()
status_endpoint: str
file_upload_endpoint: str
status_endpoint: Optional[str]
file_upload_endpoint: Optional[str]
class SimplePrompt(BaseModel):
status_endpoint: str
file_upload_endpoint: str
status_endpoint: Optional[str]
file_upload_endpoint: Optional[str]
workflow_api: dict
status: Status = Status.NOT_STARTED
progress: set = set()

View File

@@ -2,3 +2,4 @@ aiofiles
pydantic
opencv-python
imageio-ffmpeg
logfire

View File

@@ -13,6 +13,78 @@ function sendEventToCD(event, data) {
window.parent.postMessage(JSON.stringify(message), "*");
}
function dispatchAPIEventData(data) {
const msg = JSON.parse(data);
// Custom parse error
if (msg.error) {
let message = msg.error.message;
if (msg.error.details)
message += ": " + msg.error.details;
for (const [nodeID, nodeError] of Object.entries(
msg.node_errors,
)) {
message += "\n" + nodeError.class_type + ":";
for (const errorReason of nodeError.errors) {
message +=
"\n - " + errorReason.message + ": " + errorReason.details;
}
}
app.ui.dialog.show(message);
if (msg.node_errors) {
app.lastNodeErrors = msg.node_errors;
app.canvas.draw(true, true);
}
}
switch (msg.event) {
case "error":
break;
case "status":
if (msg.data.sid) {
// this.clientId = msg.data.sid;
// window.name = this.clientId; // use window name so it isn't reused when duplicating tabs
// sessionStorage.setItem("clientId", this.clientId); // store in session storage so duplicate tab can load correct workflow
}
api.dispatchEvent(new CustomEvent("status", { detail: msg.data.status }));
break;
case "progress":
api.dispatchEvent(new CustomEvent("progress", { detail: msg.data }));
break;
case "executing":
api.dispatchEvent(
new CustomEvent("executing", { detail: msg.data.node }),
);
break;
case "executed":
api.dispatchEvent(new CustomEvent("executed", { detail: msg.data }));
break;
case "execution_start":
api.dispatchEvent(
new CustomEvent("execution_start", { detail: msg.data }),
);
break;
case "execution_error":
api.dispatchEvent(
new CustomEvent("execution_error", { detail: msg.data }),
);
break;
case "execution_cached":
api.dispatchEvent(
new CustomEvent("execution_cached", { detail: msg.data }),
);
break;
default:
api.dispatchEvent(new CustomEvent(msg.type, { detail: msg.data }));
// default:
// if (this.#registered.has(msg.type)) {
// } else {
// throw new Error(`Unknown message type ${msg.type}`);
// }
}
}
/** @typedef {import('../../../web/types/comfy.js').ComfyExtension} ComfyExtension*/
/** @type {ComfyExtension} */
const ext = {
@@ -33,8 +105,7 @@ const ext = {
sendEventToCD("cd_plugin_onInit");
app.queuePrompt = ((originalFunction) =>
async () => {
app.queuePrompt = ((originalFunction) => async () => {
// const prompt = await app.graphToPrompt();
sendEventToCD("cd_plugin_onQueuePromptTrigger");
})(app.queuePrompt);
@@ -208,7 +279,12 @@
} else if (message.type === "queue_prompt") {
const prompt = await app.graphToPrompt();
sendEventToCD("cd_plugin_onQueuePrompt", prompt);
} else if (message.type === "event") {
dispatchAPIEventData(message.data);
}
// else if (message.type === "refresh") {
// sendEventToCD("cd_plugin_onRefresh");
// }
} catch (error) {
// console.error("Error processing message:", error);
}