feat: add perf counter

parent 32d574475c
commit 1e33435ae5
@@ -26,6 +26,7 @@ import copy
 import struct
 from aiohttp import web, ClientSession, ClientError, ClientTimeout, ClientResponseError
 import atexit
+import perf_counter
 
 # Global session
 client_session = None
@@ -1227,7 +1228,8 @@ async def send_json_override(self, event, data, sid=None):
             "node_class": class_type,
         }
         if class_type == "PreviewImage":
-            logger.info("Skipping preview image")
+            pass
+            # logger.info("Skipping preview image")
         else:
             await update_run_with_output(
                 prompt_id,
@@ -1239,9 +1241,10 @@ async def send_json_override(self, event, data, sid=None):
             comfy_message_queues[prompt_id].put_nowait(
                 {"event": "output_ready", "data": data}
             )
-            logger.info(f"Executed {class_type} {data}")
+            # logger.info(f"Executed {class_type} {data}")
     else:
-        logger.info(f"Executed {data}")
+        pass
+        # logger.info(f"Executed {data}")
 
 
 # Global variable to keep track of the last read line number
perf_counter.py (new file, 149 lines)
@@ -0,0 +1,149 @@
+# Reference from https://github.com/ty0x2333/ComfyUI-Dev-Utils/blob/main/nodes/execution_time.py
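+#
+# The whole module works by import side effects: it swizzles execution.execute
+# and server.PromptServer.send_sync below, so `import perf_counter` alone is
+# enough to switch the per-node instrumentation on.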
+
+import server
+import torch
+import time
+import execution
+from tabulate import tabulate
+from model_management import get_torch_device
+import psutil
+
+prompt_server = server.PromptServer.instance
+
+NODE_EXECUTION_TIMES = {}  # New dictionary to store node execution times
+
+def get_peak_memory():
+    device = get_torch_device()
+    if device.type == 'cuda':
+        return torch.cuda.max_memory_allocated(device)
+    elif device.type == 'mps':
+        # Return system memory usage for MPS devices
+        return psutil.Process().memory_info().rss
+    return 0
+
+
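+# torch.cuda.reset_max_memory_allocated re-bases the peak counter to the
+# current allocation, so each node's VRAM figure measures only what that
+# node added on top of what was already resident.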
+def reset_peak_memory_record():
+    device = get_torch_device()
+    if device.type == 'cuda':
+        torch.cuda.reset_max_memory_allocated(device)
+    # MPS doesn't need reset as we're not tracking its memory
+
+
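+# Runs after every node: pairs the start snapshot taken in swizzle_send_sync
+# with the current clock and VRAM peak, and stores the deltas per node id.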
+def handle_execute(class_type, last_node_id, prompt_id, server, unique_id):
+    if not CURRENT_START_EXECUTION_DATA:
+        return
+    start_time = CURRENT_START_EXECUTION_DATA["nodes_start_perf_time"].get(unique_id)
+    start_vram = CURRENT_START_EXECUTION_DATA["nodes_start_vram"].get(unique_id)
+    if start_time:
+        end_time = time.perf_counter()
+        execution_time = end_time - start_time
+
+        end_vram = get_peak_memory()
+        vram_used = end_vram - start_vram
+        # print(f"end_vram - start_vram: {end_vram} - {start_vram} = {vram_used}")
+        NODE_EXECUTION_TIMES[unique_id] = {
+            "time": execution_time,
+            "class_type": class_type,
+            "vram_used": vram_used
+        }
+        # print(f"#{unique_id} [{class_type}]: {execution_time:.2f}s - vram {vram_used}b")
+
+
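+# Swizzle ComfyUI's executor: a wrapper with the same signature delegates to
+# the original execution.execute, then records timing for the node just run.
+# The try/except leaves ComfyUI unpatched if its executor API ever changes.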
+try:
+    origin_execute = execution.execute
+
+    def swizzle_execute(
+        server,
+        dynprompt,
+        caches,
+        current_item,
+        extra_data,
+        executed,
+        prompt_id,
+        execution_list,
+        pending_subgraph_results,
+    ):
+        unique_id = current_item
+        class_type = dynprompt.get_node(unique_id)["class_type"]
+        last_node_id = server.last_node_id
+        result = origin_execute(
+            server,
+            dynprompt,
+            caches,
+            current_item,
+            extra_data,
+            executed,
+            prompt_id,
+            execution_list,
+            pending_subgraph_results,
+        )
+        handle_execute(class_type, last_node_id, prompt_id, server, unique_id)
+        return result
+
+    execution.execute = swizzle_execute
+except Exception as e:
+    pass
+
+CURRENT_START_EXECUTION_DATA = None
+origin_func = server.PromptServer.send_sync
+
+
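+# send_sync is the hook ComfyUI uses to push events to connected clients;
+# wrapping it lets us observe "execution_start" and per-node "executing"
+# events without touching core code.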
+def swizzle_send_sync(self, event, data, sid=None):
+    # print(f"swizzle_send_sync, event: {event}, data: {data}")
+    global CURRENT_START_EXECUTION_DATA
+    if event == "execution_start":
+        global NODE_EXECUTION_TIMES
+        NODE_EXECUTION_TIMES = {}  # Reset execution times at start
+        CURRENT_START_EXECUTION_DATA = dict(
+            start_perf_time=time.perf_counter(),
+            nodes_start_perf_time={},
+            nodes_start_vram={},
+        )
+
+    origin_func(self, event=event, data=data, sid=sid)
+
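+    # "executing" fires once per node with its id, then one final time with
+    # node=None when the whole prompt is done -- that last event triggers the
+    # summary table below.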
+    if event == "executing" and data and CURRENT_START_EXECUTION_DATA:
+        if data.get("node") is None:
+            start_perf_time = CURRENT_START_EXECUTION_DATA.get("start_perf_time")
+            new_data = data.copy()
+            execution_time = 0.0  # default so the TOTAL row below never sees an unbound name
+            if start_perf_time is not None:
+                execution_time = time.perf_counter() - start_perf_time
+                new_data["execution_time"] = int(execution_time * 1000)
+
+            # Replace the print statements with tabulate
+            headers = ["Node ID", "Type", "Time (s)", "VRAM (GB)"]
+            table_data = []
+            for node_id, node_data in NODE_EXECUTION_TIMES.items():
+                vram_gb = node_data['vram_used'] / (1024**3)  # Convert bytes to GB
+                table_data.append([
+                    f"#{node_id}",
+                    node_data['class_type'],
+                    f"{node_data['time']:.2f}",
+                    f"{vram_gb:.2f}"
+                ])
+
+            # Add total execution time as the last row
+            table_data.append([
+                "TOTAL",
+                "-",
+                f"{execution_time:.2f}",
+                "-"
+            ])
+
+            # print("\n=== Node Execution Times ===")
+            print(tabulate(table_data, headers=headers, tablefmt="grid"))
+            # print("========================\n")
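+            # Illustrative output (node names and numbers are made up):
+            # +---------+----------+----------+-------------+
+            # | Node ID | Type     | Time (s) | VRAM (GB)   |
+            # +=========+==========+==========+=============+
+            # | #3      | KSampler | 12.34    | 5.67        |
+            # +---------+----------+----------+-------------+
+            # | TOTAL   | -        | 13.02    | -           |
+            # +---------+----------+----------+-------------+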
+
+
+        else:
+            node_id = data.get("node")
+            CURRENT_START_EXECUTION_DATA["nodes_start_perf_time"][node_id] = (
+                time.perf_counter()
+            )
+            reset_peak_memory_record()
+            CURRENT_START_EXECUTION_DATA["nodes_start_vram"][node_id] = (
+                get_peak_memory()
+            )
+
+
+server.PromptServer.send_sync = swizzle_send_sync
@@ -3,4 +3,5 @@ pydantic
 opencv-python
 imageio-ffmpeg
 brotli
+tabulate
 # logfire