feat(plugin): send live_status and elapsed_time

This commit is contained in:
bennykok 2024-02-25 22:47:54 -08:00
parent e87bb63c6f
commit 97096a9035

View File

@ -55,6 +55,7 @@ class SimplePrompt(BaseModel):
uploading_nodes: set = set() uploading_nodes: set = set()
done: bool = False done: bool = False
is_realtime: bool = False, is_realtime: bool = False,
start_time: Optional[float] = None,
streaming_prompt_metadata: dict[str, StreamingPrompt] = {} streaming_prompt_metadata: dict[str, StreamingPrompt] = {}
@ -500,11 +501,23 @@ async def send_json_override(self, event, data, sid=None):
if event == 'execution_start': if event == 'execution_start':
update_run(prompt_id, Status.RUNNING) update_run(prompt_id, Status.RUNNING)
if prompt_id in prompt_metadata:
prompt_metadata[prompt_id].start_time = time.perf_counter()
# the last executing event is none, then the workflow is finished # the last executing event is none, then the workflow is finished
if event == 'executing' and data.get('node') is None: if event == 'executing' and data.get('node') is None:
mark_prompt_done(prompt_id=prompt_id) mark_prompt_done(prompt_id=prompt_id)
if not have_pending_upload(prompt_id): if not have_pending_upload(prompt_id):
update_run(prompt_id, Status.SUCCESS) update_run(prompt_id, Status.SUCCESS)
if prompt_id in prompt_metadata:
current_time = time.perf_counter()
if prompt_metadata[prompt_id].start_time is not None:
elapsed_time = current_time - prompt_metadata[prompt_id].start_time
print(f"Elapsed time: {elapsed_time} seconds")
await send("elapsed_time", {
"prompt_id": prompt_id,
"elapsed_time": elapsed_time
}, sid=sid)
if event == 'executing' and data.get('node') is not None: if event == 'executing' and data.get('node') is not None:
node = data.get('node') node = data.get('node')
@ -522,6 +535,11 @@ async def send_json_override(self, event, data, sid=None):
prompt_metadata[prompt_id].last_updated_node = node prompt_metadata[prompt_id].last_updated_node = node
class_type = prompt_metadata[prompt_id].workflow_api[node]['class_type'] class_type = prompt_metadata[prompt_id].workflow_api[node]['class_type']
print("updating run live status", class_type) print("updating run live status", class_type)
await send("live_status", {
"prompt_id": prompt_id,
"current_node": class_type,
"progress": calculated_progress,
}, sid=sid)
await update_run_live_status(prompt_id, "Executing " + class_type, calculated_progress) await update_run_live_status(prompt_id, "Executing " + class_type, calculated_progress)
if event == 'execution_cached' and data.get('nodes') is not None: if event == 'execution_cached' and data.get('nodes') is not None: