feat: add perf counter
This commit is contained in:
parent
32d574475c
commit
1e33435ae5
@ -26,6 +26,7 @@ import copy
|
||||
import struct
|
||||
from aiohttp import web, ClientSession, ClientError, ClientTimeout, ClientResponseError
|
||||
import atexit
|
||||
import perf_counter
|
||||
|
||||
# Global session
|
||||
client_session = None
|
||||
@ -1227,7 +1228,8 @@ async def send_json_override(self, event, data, sid=None):
|
||||
"node_class": class_type,
|
||||
}
|
||||
if class_type == "PreviewImage":
|
||||
logger.info("Skipping preview image")
|
||||
pass
|
||||
# logger.info("Skipping preview image")
|
||||
else:
|
||||
await update_run_with_output(
|
||||
prompt_id,
|
||||
@ -1239,9 +1241,10 @@ async def send_json_override(self, event, data, sid=None):
|
||||
comfy_message_queues[prompt_id].put_nowait(
|
||||
{"event": "output_ready", "data": data}
|
||||
)
|
||||
logger.info(f"Executed {class_type} {data}")
|
||||
# logger.info(f"Executed {class_type} {data}")
|
||||
else:
|
||||
logger.info(f"Executed {data}")
|
||||
pass
|
||||
# logger.info(f"Executed {data}")
|
||||
|
||||
|
||||
# Global variable to keep track of the last read line number
|
||||
|
149
perf_counter.py
Normal file
149
perf_counter.py
Normal file
@ -0,0 +1,149 @@
|
||||
# Reference from https://github.com/ty0x2333/ComfyUI-Dev-Utils/blob/main/nodes/execution_time.py
|
||||
|
||||
import server
|
||||
import torch
|
||||
import time
|
||||
import execution
|
||||
from tabulate import tabulate
|
||||
from model_management import get_torch_device
|
||||
import psutil
|
||||
|
||||
# Shared PromptServer singleton created by the host ComfyUI process at import time.
prompt_server = server.PromptServer.instance

# Maps node unique_id -> {"time": seconds, "class_type": str, "vram_used": bytes}.
# Reset at the start of every prompt run (see the "execution_start" hook below).
NODE_EXECUTION_TIMES = {} # New dictionary to store node execution times
||||
def get_peak_memory():
    """Return peak memory usage in bytes for the active torch device.

    CUDA exposes the allocator's high-water mark directly; MPS has no
    equivalent counter, so the process resident set size is used as a
    proxy. Any other device reports 0.
    """
    dev = get_torch_device()
    dev_type = dev.type
    if dev_type == 'cuda':
        return torch.cuda.max_memory_allocated(dev)
    if dev_type == 'mps':
        # Apple MPS: fall back to this process's RSS as a memory proxy.
        return psutil.Process().memory_info().rss
    return 0
||||
|
||||
|
||||
def reset_peak_memory_record():
    """Zero the CUDA peak-memory counter so the next reading starts fresh.

    No-op on non-CUDA devices — MPS memory is not tracked per node, so
    there is nothing to reset there.
    """
    dev = get_torch_device()
    if dev.type != 'cuda':
        return
    torch.cuda.reset_max_memory_allocated(dev)
||||
|
||||
|
||||
def handle_execute(class_type, last_node_id, prompt_id, server, unique_id):
    """Record execution time and peak VRAM usage for a just-finished node.

    Reads the per-node start timestamp and start-VRAM snapshot captured in
    CURRENT_START_EXECUTION_DATA (populated by swizzle_send_sync) and stores
    the measured deltas in NODE_EXECUTION_TIMES keyed by the node's unique_id.
    Silently does nothing when no run is in progress or the node was never
    stamped with a start record.
    """
    if not CURRENT_START_EXECUTION_DATA:
        return
    start_time = CURRENT_START_EXECUTION_DATA["nodes_start_perf_time"].get(unique_id)
    start_vram = CURRENT_START_EXECUTION_DATA["nodes_start_vram"].get(unique_id)
    # `is not None` rather than truthiness: 0.0 is a valid perf_counter value
    # and 0 is a valid VRAM baseline. Previously only start_time was checked,
    # so a recorded start_time with a missing start_vram raised TypeError on
    # the subtraction below.
    if start_time is None or start_vram is None:
        return
    end_time = time.perf_counter()
    execution_time = end_time - start_time

    end_vram = get_peak_memory()
    vram_used = end_vram - start_vram
    # print(f"end_vram - start_vram: {end_vram} - {start_vram} = {vram_used}")
    NODE_EXECUTION_TIMES[unique_id] = {
        "time": execution_time,
        "class_type": class_type,
        "vram_used": vram_used,
    }
    # print(f"#{unique_id} [{class_type}]: {execution_time:.2f}s - vram {vram_used}b")
|
||||
|
||||
try:
    # Wrap ComfyUI's execution.execute so every node execution is timed.
    # The original callable is kept so behaviour is otherwise unchanged.
    origin_execute = execution.execute

    def swizzle_execute(
        server,
        dynprompt,
        caches,
        current_item,
        extra_data,
        executed,
        prompt_id,
        execution_list,
        pending_subgraph_results,
    ):
        """Delegate to the original execute(), then record perf stats for the node."""
        unique_id = current_item
        class_type = dynprompt.get_node(unique_id)["class_type"]
        last_node_id = server.last_node_id
        result = origin_execute(
            server,
            dynprompt,
            caches,
            current_item,
            extra_data,
            executed,
            prompt_id,
            execution_list,
            pending_subgraph_results,
        )
        handle_execute(class_type, last_node_id, prompt_id, server, unique_id)
        return result

    execution.execute = swizzle_execute
except Exception as e:
    # Best-effort: profiling must never break execution. But the previous
    # bare `pass` hid the failure entirely (e.g. execution.execute signature
    # drift across ComfyUI versions) — at least report why instrumentation
    # was skipped.
    print(f"perf_counter: failed to instrument execution.execute: {e}")
||||
|
||||
# Per-run timing state: {"start_perf_time", "nodes_start_perf_time", "nodes_start_vram"}.
# Populated on the "execution_start" event; read by handle_execute. None when idle.
CURRENT_START_EXECUTION_DATA = None
# Original send_sync, preserved so the swizzled hook below can delegate to it.
origin_func = server.PromptServer.send_sync
|
||||
|
||||
|
||||
def swizzle_send_sync(self, event, data, sid=None):
    """Instrumented replacement for PromptServer.send_sync.

    Hooks ComfyUI's event stream to:
      * "execution_start": reset NODE_EXECUTION_TIMES and stamp the run start.
      * "executing" with a node id: record that node's start time and start
        VRAM (after resetting the CUDA peak counter).
      * "executing" with node=None (end of prompt): print a per-node timing
        table and compute the total execution time in milliseconds.

    Always delegates to the original send_sync so normal behaviour is kept.
    """
    # print(f"swizzle_send_sync, event: {event}, data: {data}")
    global CURRENT_START_EXECUTION_DATA
    if event == "execution_start":
        global NODE_EXECUTION_TIMES
        NODE_EXECUTION_TIMES = {}  # Reset execution times at start
        CURRENT_START_EXECUTION_DATA = dict(
            start_perf_time=time.perf_counter(),
            nodes_start_perf_time={},
            nodes_start_vram={},
        )

    origin_func(self, event=event, data=data, sid=sid)

    if event == "executing" and data and CURRENT_START_EXECUTION_DATA:
        if data.get("node") is None:
            # node=None marks the end of the whole prompt.
            start_perf_time = CURRENT_START_EXECUTION_DATA.get("start_perf_time")
            # NOTE(review): new_data is built but never forwarded anywhere —
            # confirm whether it was meant to be sent to the client.
            new_data = data.copy()
            if start_perf_time is not None:
                execution_time = time.perf_counter() - start_perf_time
                new_data["execution_time"] = int(execution_time * 1000)

                # Summary table kept INSIDE the start_perf_time guard: the
                # TOTAL row uses execution_time, which is only bound in this
                # branch — printing it unconditionally risked a NameError
                # when the run's start timestamp was missing.
                headers = ["Node ID", "Type", "Time (s)", "VRAM (GB)"]
                table_data = []
                for node_id, node_data in NODE_EXECUTION_TIMES.items():
                    vram_gb = node_data['vram_used'] / (1024**3)  # Convert bytes to GB
                    table_data.append([
                        f"#{node_id}",
                        node_data['class_type'],
                        f"{node_data['time']:.2f}",
                        f"{vram_gb:.2f}"
                    ])

                # Add total execution time as the last row
                table_data.append([
                    "TOTAL",
                    "-",
                    f"{execution_time:.2f}",
                    "-"
                ])

                print(tabulate(table_data, headers=headers, tablefmt="grid"))
        else:
            node_id = data.get("node")
            CURRENT_START_EXECUTION_DATA["nodes_start_perf_time"][node_id] = (
                time.perf_counter()
            )
            # Reset the CUDA peak counter so this node's VRAM delta is measured
            # from its own start, not the run's previous high-water mark.
            reset_peak_memory_record()
            CURRENT_START_EXECUTION_DATA["nodes_start_vram"][node_id] = (
                get_peak_memory()
            )
|
||||
# Install the instrumented hook globally; origin_func above still holds the
# original implementation for delegation.
server.PromptServer.send_sync = swizzle_send_sync
|
@ -3,4 +3,5 @@ pydantic
|
||||
opencv-python
|
||||
imageio-ffmpeg
|
||||
brotli
|
||||
tabulate
|
||||
# logfire
|
Loading…
x
Reference in New Issue
Block a user