fix
This commit is contained in:
parent
1d63b21643
commit
40ec37e58f
@ -1224,11 +1224,36 @@ def format_table(headers, data):
|
||||
result.append(separator)
|
||||
return '\n'.join(result)
|
||||
|
||||
origin_func = server.PromptServer.send_sync
|
||||
def swizzle_send_sync(self, event, data, sid=None):
|
||||
# print(f"swizzle_send_sync, event: {event}, data: {data}")
|
||||
global CURRENT_START_EXECUTION_DATA
|
||||
if event == "execution_start":
|
||||
global NODE_EXECUTION_TIMES
|
||||
NODE_EXECUTION_TIMES = {} # Reset execution times at start
|
||||
CURRENT_START_EXECUTION_DATA = dict(
|
||||
start_perf_time=time.perf_counter(),
|
||||
nodes_start_perf_time={},
|
||||
nodes_start_vram={},
|
||||
)
|
||||
|
||||
origin_func(self, event=event, data=data, sid=sid)
|
||||
|
||||
if event == "executing" and data and CURRENT_START_EXECUTION_DATA:
|
||||
if data.get("node") is not None:
|
||||
node_id = data.get("node")
|
||||
CURRENT_START_EXECUTION_DATA["nodes_start_perf_time"][node_id] = (
|
||||
time.perf_counter()
|
||||
)
|
||||
reset_peak_memory_record()
|
||||
CURRENT_START_EXECUTION_DATA["nodes_start_vram"][node_id] = (
|
||||
get_peak_memory()
|
||||
)
|
||||
|
||||
server.PromptServer.send_sync = swizzle_send_sync
|
||||
|
||||
send_json = prompt_server.send_json
|
||||
|
||||
|
||||
async def send_json_override(self, event, data, sid=None):
|
||||
# logger.info("INTERNAL:", event, data, sid)
|
||||
prompt_id = data.get("prompt_id")
|
||||
@ -1254,15 +1279,6 @@ async def send_json_override(self, event, data, sid=None):
|
||||
if prompt_id in prompt_metadata:
|
||||
prompt_metadata[prompt_id].start_time = time.perf_counter()
|
||||
|
||||
global CURRENT_START_EXECUTION_DATA
|
||||
global NODE_EXECUTION_TIMES
|
||||
NODE_EXECUTION_TIMES = {} # Reset execution times at start
|
||||
CURRENT_START_EXECUTION_DATA = dict(
|
||||
start_perf_time=time.perf_counter(),
|
||||
nodes_start_perf_time={},
|
||||
nodes_start_vram={},
|
||||
)
|
||||
|
||||
await update_run(prompt_id, Status.RUNNING)
|
||||
|
||||
|
||||
@ -1304,17 +1320,6 @@ async def send_json_override(self, event, data, sid=None):
|
||||
logger.info("Printing Node Execution Times")
|
||||
logger.info(format_table(headers, table_data))
|
||||
# print("========================\n")
|
||||
|
||||
|
||||
else:
|
||||
node_id = data.get("node")
|
||||
CURRENT_START_EXECUTION_DATA["nodes_start_perf_time"][node_id] = (
|
||||
time.perf_counter()
|
||||
)
|
||||
reset_peak_memory_record()
|
||||
CURRENT_START_EXECUTION_DATA["nodes_start_vram"][node_id] = (
|
||||
get_peak_memory()
|
||||
)
|
||||
|
||||
# the last executing event is none, then the workflow is finished
|
||||
if event == "executing" and data.get("node") is None:
|
||||
|
Loading…
x
Reference in New Issue
Block a user