bennykok 2024-12-09 16:48:11 +08:00
parent 1d63b21643
commit 40ec37e58f


@@ -1224,11 +1224,36 @@ def format_table(headers, data):
result.append(separator)
return '\n'.join(result)
origin_func = server.PromptServer.send_sync
def swizzle_send_sync(self, event, data, sid=None):
# print(f"swizzle_send_sync, event: {event}, data: {data}")
global CURRENT_START_EXECUTION_DATA
if event == "execution_start":
global NODE_EXECUTION_TIMES
NODE_EXECUTION_TIMES = {} # Reset execution times at start
CURRENT_START_EXECUTION_DATA = dict(
start_perf_time=time.perf_counter(),
nodes_start_perf_time={},
nodes_start_vram={},
)
origin_func(self, event=event, data=data, sid=sid)
if event == "executing" and data and CURRENT_START_EXECUTION_DATA:
if data.get("node") is not None:
node_id = data.get("node")
CURRENT_START_EXECUTION_DATA["nodes_start_perf_time"][node_id] = (
time.perf_counter()
)
reset_peak_memory_record()
CURRENT_START_EXECUTION_DATA["nodes_start_vram"][node_id] = (
get_peak_memory()
)
server.PromptServer.send_sync = swizzle_send_sync
send_json = prompt_server.send_json
async def send_json_override(self, event, data, sid=None):
# logger.info("INTERNAL:", event, data, sid)
prompt_id = data.get("prompt_id")
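For reference, a minimal sketch of the swizzle (monkey-patch) pattern this hunk introduces for server.PromptServer.send_sync: keep a reference to the original function, record timing state around the "execution_start" and "executing" events, then delegate. The _FakePromptServer class and _execution_state dict below are illustrative stand-ins, not part of the commit.

import time

class _FakePromptServer:
    # Stand-in for server.PromptServer, used only to illustrate the pattern.
    def send_sync(self, event, data, sid=None):
        print(f"send_sync: {event} {data}")

_original_send_sync = _FakePromptServer.send_sync
_execution_state = {"nodes_start_perf_time": {}}

def _swizzled_send_sync(self, event, data, sid=None):
    # Reset the timing context when a new prompt starts executing.
    if event == "execution_start":
        _execution_state["start_perf_time"] = time.perf_counter()
        _execution_state["nodes_start_perf_time"] = {}
    # Always forward to the original implementation.
    _original_send_sync(self, event=event, data=data, sid=sid)
    # Stamp the start time of each node as it begins executing.
    if event == "executing" and data and data.get("node") is not None:
        _execution_state["nodes_start_perf_time"][data["node"]] = time.perf_counter()

_FakePromptServer.send_sync = _swizzled_send_sync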
@@ -1254,15 +1279,6 @@ async def send_json_override(self, event, data, sid=None):
if prompt_id in prompt_metadata:
prompt_metadata[prompt_id].start_time = time.perf_counter()
global CURRENT_START_EXECUTION_DATA
global NODE_EXECUTION_TIMES
NODE_EXECUTION_TIMES = {} # Reset execution times at start
CURRENT_START_EXECUTION_DATA = dict(
start_perf_time=time.perf_counter(),
nodes_start_perf_time={},
nodes_start_vram={},
)
await update_run(prompt_id, Status.RUNNING)
@@ -1305,17 +1321,6 @@ async def send_json_override(self, event, data, sid=None):
logger.info(format_table(headers, table_data))
# print("========================\n")
else:
node_id = data.get("node")
CURRENT_START_EXECUTION_DATA["nodes_start_perf_time"][node_id] = (
time.perf_counter()
)
reset_peak_memory_record()
CURRENT_START_EXECUTION_DATA["nodes_start_vram"][node_id] = (
get_peak_memory()
)
# when the final "executing" event carries node=None, the workflow is finished
if event == "executing" and data.get("node") is None:
mark_prompt_done(prompt_id=prompt_id)
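To show how the recorded start values might be consumed, here is a hedged sketch that turns a node's stamped start time and VRAM reading into an entry of NODE_EXECUTION_TIMES. The record_node_finished helper and its delta arithmetic are assumptions for illustration; only NODE_EXECUTION_TIMES, CURRENT_START_EXECUTION_DATA, and the idea of pairing perf_counter with a peak-VRAM reading come from the diff.

import time

NODE_EXECUTION_TIMES = {}  # node_id -> {"time": seconds, "vram_used": bytes}
CURRENT_START_EXECUTION_DATA = {
    "start_perf_time": time.perf_counter(),
    "nodes_start_perf_time": {},
    "nodes_start_vram": {},
}

def record_node_finished(node_id, current_peak_vram):
    # Hypothetical helper: compute elapsed time and VRAM delta for a node
    # whose start values were stamped by the swizzled send_sync above.
    start = CURRENT_START_EXECUTION_DATA["nodes_start_perf_time"].get(node_id)
    start_vram = CURRENT_START_EXECUTION_DATA["nodes_start_vram"].get(node_id, 0)
    if start is None:
        return
    NODE_EXECUTION_TIMES[node_id] = {
        "time": time.perf_counter() - start,
        "vram_used": current_peak_vram - start_vram,
    }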