From 1e33435ae53fea4ed8c58918a81ee58f51ed90b9 Mon Sep 17 00:00:00 2001
From: bennykok
Date: Mon, 9 Dec 2024 00:11:51 +0800
Subject: [PATCH] feat: add perf counter

---
 custom_routes.py |   9 ++-
 perf_counter.py  | 149 +++++++++++++++++++++++++++++++++++++++++++++++
 requirements.txt |   1 +
 3 files changed, 156 insertions(+), 3 deletions(-)
 create mode 100644 perf_counter.py

diff --git a/custom_routes.py b/custom_routes.py
index 68edcd3..62c3408 100644
--- a/custom_routes.py
+++ b/custom_routes.py
@@ -26,6 +26,7 @@ import copy
 import struct
 from aiohttp import web, ClientSession, ClientError, ClientTimeout, ClientResponseError
 import atexit
+import perf_counter
 
 # Global session
 client_session = None
@@ -1227,7 +1228,8 @@ async def send_json_override(self, event, data, sid=None):
                     "node_class": class_type,
                 }
                 if class_type == "PreviewImage":
-                    logger.info("Skipping preview image")
+                    pass
+                    # logger.info("Skipping preview image")
                 else:
                     await update_run_with_output(
                         prompt_id,
@@ -1239,9 +1241,10 @@ async def send_json_override(self, event, data, sid=None):
                     comfy_message_queues[prompt_id].put_nowait(
                         {"event": "output_ready", "data": data}
                     )
-                    logger.info(f"Executed {class_type} {data}")
+                    # logger.info(f"Executed {class_type} {data}")
                 else:
-                    logger.info(f"Executed {data}")
+                    pass
+                    # logger.info(f"Executed {data}")
 
 
 # Global variable to keep track of the last read line number
diff --git a/perf_counter.py b/perf_counter.py
new file mode 100644
index 0000000..642bb7e
--- /dev/null
+++ b/perf_counter.py
@@ -0,0 +1,149 @@
+# Reference from https://github.com/ty0x2333/ComfyUI-Dev-Utils/blob/main/nodes/execution_time.py

import server
import torch
import time
import execution
from tabulate import tabulate
from model_management import get_torch_device
import psutil

prompt_server = server.PromptServer.instance

NODE_EXECUTION_TIMES = {}  # New dictionary to store node execution times

def get_peak_memory():
    device = get_torch_device()
    if device.type == 'cuda':
        return torch.cuda.max_memory_allocated(device)
    elif device.type == 'mps':
        # Return system memory usage for MPS devices
        return psutil.Process().memory_info().rss
    return 0


def reset_peak_memory_record():
    device = get_torch_device()
    if device.type == 'cuda':
        torch.cuda.reset_max_memory_allocated(device)
    # MPS doesn't need reset as we're not tracking its memory


def handle_execute(class_type, last_node_id, prompt_id, server, unique_id):
    if not CURRENT_START_EXECUTION_DATA:
        return
    start_time = CURRENT_START_EXECUTION_DATA["nodes_start_perf_time"].get(unique_id)
    start_vram = CURRENT_START_EXECUTION_DATA["nodes_start_vram"].get(unique_id)
    if start_time:
        end_time = time.perf_counter()
        execution_time = end_time - start_time

        end_vram = get_peak_memory()
        vram_used = end_vram - start_vram
        # print(f"end_vram - start_vram: {end_vram} - {start_vram} = {vram_used}")
        NODE_EXECUTION_TIMES[unique_id] = {
            "time": execution_time,
            "class_type": class_type,
            "vram_used": vram_used
        }
        # print(f"#{unique_id} [{class_type}]: {execution_time:.2f}s - vram {vram_used}b")


try:
    origin_execute = execution.execute

    def swizzle_execute(
        server,
        dynprompt,
        caches,
        current_item,
        extra_data,
        executed,
        prompt_id,
        execution_list,
        pending_subgraph_results,
    ):
        unique_id = current_item
        class_type = dynprompt.get_node(unique_id)["class_type"]
        last_node_id = server.last_node_id
        result = origin_execute(
            server,
            dynprompt,
            caches,
            current_item,
            extra_data,
            executed,
            prompt_id,
            execution_list,
            pending_subgraph_results,
        )
        handle_execute(class_type, last_node_id, prompt_id, server, unique_id)
        return result

    execution.execute = swizzle_execute
except Exception as e:
    pass

CURRENT_START_EXECUTION_DATA = None
origin_func = server.PromptServer.send_sync


def swizzle_send_sync(self, event, data, sid=None):
    # print(f"swizzle_send_sync, event: {event}, data: {data}")
    global CURRENT_START_EXECUTION_DATA
    if event == "execution_start":
        global NODE_EXECUTION_TIMES
        NODE_EXECUTION_TIMES = {}  # Reset execution times at start
        CURRENT_START_EXECUTION_DATA = dict(
            start_perf_time=time.perf_counter(),
            nodes_start_perf_time={},
            nodes_start_vram={},
        )

    origin_func(self, event=event, data=data, sid=sid)

    if event == "executing" and data and CURRENT_START_EXECUTION_DATA:
        if data.get("node") is None:
            start_perf_time = CURRENT_START_EXECUTION_DATA.get("start_perf_time")
            new_data = data.copy()
            if start_perf_time is not None:
                execution_time = time.perf_counter() - start_perf_time
                new_data["execution_time"] = int(execution_time * 1000)

            # Replace the print statements with tabulate
            headers = ["Node ID", "Type", "Time (s)", "VRAM (GB)"]
            table_data = []
            for node_id, node_data in NODE_EXECUTION_TIMES.items():
                vram_gb = node_data['vram_used'] / (1024**3)  # Convert bytes to GB
                table_data.append([
                    f"#{node_id}",
                    node_data['class_type'],
                    f"{node_data['time']:.2f}",
                    f"{vram_gb:.2f}"
                ])

            # Add total execution time as the last row
            table_data.append([
                "TOTAL",
                "-",
                f"{execution_time:.2f}",
                "-"
            ])

            # print("\n=== Node Execution Times ===")
            print(tabulate(table_data, headers=headers, tablefmt="grid"))
            # print("========================\n")

        else:
            node_id = data.get("node")
            CURRENT_START_EXECUTION_DATA["nodes_start_perf_time"][node_id] = (
                time.perf_counter()
            )
            reset_peak_memory_record()
            CURRENT_START_EXECUTION_DATA["nodes_start_vram"][node_id] = (
                get_peak_memory()
            )


server.PromptServer.send_sync = swizzle_send_sync
diff --git a/requirements.txt b/requirements.txt
index f1d3a61..c14ab49 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,4 +3,5 @@ pydantic
 opencv-python
 imageio-ffmpeg
 brotli
+tabulate
 # logfire
\ No newline at end of file
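
How the patch works: it adds no new nodes. perf_counter.py monkey-patches two ComfyUI entry points when it is imported. execution.execute is wrapped by swizzle_execute, which records each node's class type, wall-clock time, and peak-VRAM delta into NODE_EXECUTION_TIMES via handle_execute. PromptServer.send_sync is wrapped by swizzle_send_sync, which resets the bookkeeping on the "execution_start" event, starts a per-node timer and VRAM baseline on each "executing" event, and prints the tabulate grid once the final "executing" event (node is None) signals that the prompt has finished. The change to custom_routes.py only imports the new module so the patching takes effect and silences a few per-node log lines. The sketch below shows the same wrap-and-record pattern in isolation; fake_execute, CALL_TIMES, and the node labels are hypothetical stand-ins for illustration, not part of the patch or of ComfyUI's API.

# Minimal standalone sketch of the wrapping ("swizzle") pattern used above,
# under assumed names: keep a reference to the original callable, install a
# wrapper that records per-call timing, then summarize with tabulate.
import time
from tabulate import tabulate

CALL_TIMES = {}  # unique_id -> {"class_type": ..., "time": ...}, like NODE_EXECUTION_TIMES


def fake_execute(unique_id, class_type):
    """Hypothetical stand-in for ComfyUI's execution.execute."""
    time.sleep(0.01)  # pretend to do node work
    return f"output of {class_type}"


origin_execute = fake_execute  # keep the original, as the patch does


def swizzle_execute(unique_id, class_type):
    start = time.perf_counter()
    result = origin_execute(unique_id, class_type)
    CALL_TIMES[unique_id] = {
        "class_type": class_type,
        "time": time.perf_counter() - start,
    }
    return result


fake_execute = swizzle_execute  # analogous to execution.execute = swizzle_execute

for uid, node in enumerate(["CheckpointLoader", "KSampler", "VAEDecode"]):
    fake_execute(uid, node)

rows = [
    [f"#{uid}", entry["class_type"], f"{entry['time']:.2f}"]
    for uid, entry in CALL_TIMES.items()
]
print(tabulate(rows, headers=["Node ID", "Type", "Time (s)"], tablefmt="grid"))

In the real patch the wrapper additionally snapshots get_peak_memory() before and after each node, which is what feeds the VRAM (GB) column in the printed table.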