From 40ec37e58ff4a30c556703e18ffc01de1add5d44 Mon Sep 17 00:00:00 2001 From: bennykok Date: Mon, 9 Dec 2024 16:48:11 +0800 Subject: [PATCH] fix --- custom_routes.py | 47 ++++++++++++++++++++++++++--------------------- 1 file changed, 26 insertions(+), 21 deletions(-) diff --git a/custom_routes.py b/custom_routes.py index a3918a8..4913151 100644 --- a/custom_routes.py +++ b/custom_routes.py @@ -1224,11 +1224,36 @@ def format_table(headers, data): result.append(separator) return '\n'.join(result) +origin_func = server.PromptServer.send_sync +def swizzle_send_sync(self, event, data, sid=None): + # print(f"swizzle_send_sync, event: {event}, data: {data}") + global CURRENT_START_EXECUTION_DATA + if event == "execution_start": + global NODE_EXECUTION_TIMES + NODE_EXECUTION_TIMES = {} # Reset execution times at start + CURRENT_START_EXECUTION_DATA = dict( + start_perf_time=time.perf_counter(), + nodes_start_perf_time={}, + nodes_start_vram={}, + ) + origin_func(self, event=event, data=data, sid=sid) + + if event == "executing" and data and CURRENT_START_EXECUTION_DATA: + if data.get("node") is not None: + node_id = data.get("node") + CURRENT_START_EXECUTION_DATA["nodes_start_perf_time"][node_id] = ( + time.perf_counter() + ) + reset_peak_memory_record() + CURRENT_START_EXECUTION_DATA["nodes_start_vram"][node_id] = ( + get_peak_memory() + ) + +server.PromptServer.send_sync = swizzle_send_sync send_json = prompt_server.send_json - async def send_json_override(self, event, data, sid=None): # logger.info("INTERNAL:", event, data, sid) prompt_id = data.get("prompt_id") @@ -1254,15 +1279,6 @@ async def send_json_override(self, event, data, sid=None): if prompt_id in prompt_metadata: prompt_metadata[prompt_id].start_time = time.perf_counter() - global CURRENT_START_EXECUTION_DATA - global NODE_EXECUTION_TIMES - NODE_EXECUTION_TIMES = {} # Reset execution times at start - CURRENT_START_EXECUTION_DATA = dict( - start_perf_time=time.perf_counter(), - nodes_start_perf_time={}, - nodes_start_vram={}, - ) - await update_run(prompt_id, Status.RUNNING) @@ -1304,17 +1320,6 @@ async def send_json_override(self, event, data, sid=None): logger.info("Printing Node Execution Times") logger.info(format_table(headers, table_data)) # print("========================\n") - - - else: - node_id = data.get("node") - CURRENT_START_EXECUTION_DATA["nodes_start_perf_time"][node_id] = ( - time.perf_counter() - ) - reset_peak_memory_record() - CURRENT_START_EXECUTION_DATA["nodes_start_vram"][node_id] = ( - get_peak_memory() - ) # the last executing event is none, then the workflow is finished if event == "executing" and data.get("node") is None: