diff --git a/custom_routes.py b/custom_routes.py index 9d54be7..33ae803 100644 --- a/custom_routes.py +++ b/custom_routes.py @@ -361,13 +361,20 @@ async def send_json_override(self, event, data, sid=None): if event == 'executing' and data.get('node') is not None: node = data.get('node') - if 'prompt_id' in prompt_metadata: + if prompt_id in prompt_metadata: + if 'progress' not in prompt_metadata[prompt_id]: + prompt_metadata[prompt_id]["progress"] = set() + + prompt_metadata[prompt_id]["progress"].add(node) + calculated_progress = len(prompt_metadata[prompt_id]["progress"]) / len(prompt_metadata[prompt_id]['workflow_api']) + # print("calculated_progress", calculated_progress) + if 'last_updated_node' in prompt_metadata[prompt_id] and prompt_metadata[prompt_id]['last_updated_node'] == node: return prompt_metadata[prompt_id]['last_updated_node'] = node class_type = prompt_metadata[prompt_id]['workflow_api'][node]['class_type'] print("updating run live status", class_type) - await update_run_live_status(prompt_id, "Executing " + class_type) + await update_run_live_status(prompt_id, "Executing " + class_type, calculated_progress) if event == 'execution_error': # Careful this might not be fully awaited. @@ -391,18 +398,22 @@ class Status(Enum): # Global variable to keep track of the last read line number last_read_line_number = 0 -async def update_run_live_status(prompt_id, live_status): +async def update_run_live_status(prompt_id, live_status, calculated_progress: float): if prompt_id not in prompt_metadata: return + print("progress", calculated_progress) + status_endpoint = prompt_metadata[prompt_id]['status_endpoint'] body = { "run_id": prompt_id, "live_status": live_status, + "progress": calculated_progress } # requests.post(status_endpoint, json=body) async with aiohttp.ClientSession() as session: - await session.post(status_endpoint, json=body) + async with session.post(status_endpoint, json=body) as response: + pass def update_run(prompt_id, status: Status):