feat: add perf counter

This commit is contained in:
bennykok 2024-12-09 00:11:51 +08:00
parent 32d574475c
commit 1e33435ae5
3 changed files with 156 additions and 3 deletions

View File

@ -26,6 +26,7 @@ import copy
import struct
from aiohttp import web, ClientSession, ClientError, ClientTimeout, ClientResponseError
import atexit
import perf_counter
# Global session
client_session = None
@ -1227,7 +1228,8 @@ async def send_json_override(self, event, data, sid=None):
"node_class": class_type,
}
if class_type == "PreviewImage":
logger.info("Skipping preview image")
pass
# logger.info("Skipping preview image")
else:
await update_run_with_output(
prompt_id,
@ -1239,9 +1241,10 @@ async def send_json_override(self, event, data, sid=None):
comfy_message_queues[prompt_id].put_nowait(
{"event": "output_ready", "data": data}
)
logger.info(f"Executed {class_type} {data}")
# logger.info(f"Executed {class_type} {data}")
else:
logger.info(f"Executed {data}")
pass
# logger.info(f"Executed {data}")
# Global variable to keep track of the last read line number

149
perf_counter.py Normal file
View File

@ -0,0 +1,149 @@
# Reference from https://github.com/ty0x2333/ComfyUI-Dev-Utils/blob/main/nodes/execution_time.py
import server
import torch
import time
import execution
from tabulate import tabulate
from model_management import get_torch_device
import psutil
prompt_server = server.PromptServer.instance
NODE_EXECUTION_TIMES = {} # New dictionary to store node execution times
def get_peak_memory():
device = get_torch_device()
if device.type == 'cuda':
return torch.cuda.max_memory_allocated(device)
elif device.type == 'mps':
# Return system memory usage for MPS devices
return psutil.Process().memory_info().rss
return 0
def reset_peak_memory_record():
device = get_torch_device()
if device.type == 'cuda':
torch.cuda.reset_max_memory_allocated(device)
# MPS doesn't need reset as we're not tracking its memory
def handle_execute(class_type, last_node_id, prompt_id, server, unique_id):
if not CURRENT_START_EXECUTION_DATA:
return
start_time = CURRENT_START_EXECUTION_DATA["nodes_start_perf_time"].get(unique_id)
start_vram = CURRENT_START_EXECUTION_DATA["nodes_start_vram"].get(unique_id)
if start_time:
end_time = time.perf_counter()
execution_time = end_time - start_time
end_vram = get_peak_memory()
vram_used = end_vram - start_vram
# print(f"end_vram - start_vram: {end_vram} - {start_vram} = {vram_used}")
NODE_EXECUTION_TIMES[unique_id] = {
"time": execution_time,
"class_type": class_type,
"vram_used": vram_used
}
# print(f"#{unique_id} [{class_type}]: {execution_time:.2f}s - vram {vram_used}b")
try:
origin_execute = execution.execute
def swizzle_execute(
server,
dynprompt,
caches,
current_item,
extra_data,
executed,
prompt_id,
execution_list,
pending_subgraph_results,
):
unique_id = current_item
class_type = dynprompt.get_node(unique_id)["class_type"]
last_node_id = server.last_node_id
result = origin_execute(
server,
dynprompt,
caches,
current_item,
extra_data,
executed,
prompt_id,
execution_list,
pending_subgraph_results,
)
handle_execute(class_type, last_node_id, prompt_id, server, unique_id)
return result
execution.execute = swizzle_execute
except Exception as e:
pass
CURRENT_START_EXECUTION_DATA = None
origin_func = server.PromptServer.send_sync
def swizzle_send_sync(self, event, data, sid=None):
# print(f"swizzle_send_sync, event: {event}, data: {data}")
global CURRENT_START_EXECUTION_DATA
if event == "execution_start":
global NODE_EXECUTION_TIMES
NODE_EXECUTION_TIMES = {} # Reset execution times at start
CURRENT_START_EXECUTION_DATA = dict(
start_perf_time=time.perf_counter(),
nodes_start_perf_time={},
nodes_start_vram={},
)
origin_func(self, event=event, data=data, sid=sid)
if event == "executing" and data and CURRENT_START_EXECUTION_DATA:
if data.get("node") is None:
start_perf_time = CURRENT_START_EXECUTION_DATA.get("start_perf_time")
new_data = data.copy()
if start_perf_time is not None:
execution_time = time.perf_counter() - start_perf_time
new_data["execution_time"] = int(execution_time * 1000)
# Replace the print statements with tabulate
headers = ["Node ID", "Type", "Time (s)", "VRAM (GB)"]
table_data = []
for node_id, node_data in NODE_EXECUTION_TIMES.items():
vram_gb = node_data['vram_used'] / (1024**3) # Convert bytes to GB
table_data.append([
f"#{node_id}",
node_data['class_type'],
f"{node_data['time']:.2f}",
f"{vram_gb:.2f}"
])
# Add total execution time as the last row
table_data.append([
"TOTAL",
"-",
f"{execution_time:.2f}",
"-"
])
# print("\n=== Node Execution Times ===")
print(tabulate(table_data, headers=headers, tablefmt="grid"))
# print("========================\n")
else:
node_id = data.get("node")
CURRENT_START_EXECUTION_DATA["nodes_start_perf_time"][node_id] = (
time.perf_counter()
)
reset_peak_memory_record()
CURRENT_START_EXECUTION_DATA["nodes_start_vram"][node_id] = (
get_peak_memory()
)
server.PromptServer.send_sync = swizzle_send_sync

View File

@ -3,4 +3,5 @@ pydantic
opencv-python
imageio-ffmpeg
brotli
tabulate
# logfire