fix
This commit is contained in:
		
							parent
							
								
									1d63b21643
								
							
						
					
					
						commit
						40ec37e58f
					
				@ -1224,11 +1224,36 @@ def format_table(headers, data):
 | 
				
			|||||||
    result.append(separator)
 | 
					    result.append(separator)
 | 
				
			||||||
    return '\n'.join(result)
 | 
					    return '\n'.join(result)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					origin_func = server.PromptServer.send_sync
 | 
				
			||||||
 | 
					def swizzle_send_sync(self, event, data, sid=None):
 | 
				
			||||||
 | 
					    # print(f"swizzle_send_sync, event: {event}, data: {data}")
 | 
				
			||||||
 | 
					    global CURRENT_START_EXECUTION_DATA
 | 
				
			||||||
 | 
					    if event == "execution_start":
 | 
				
			||||||
 | 
					        global NODE_EXECUTION_TIMES
 | 
				
			||||||
 | 
					        NODE_EXECUTION_TIMES = {}  # Reset execution times at start
 | 
				
			||||||
 | 
					        CURRENT_START_EXECUTION_DATA = dict(
 | 
				
			||||||
 | 
					            start_perf_time=time.perf_counter(),
 | 
				
			||||||
 | 
					            nodes_start_perf_time={},
 | 
				
			||||||
 | 
					            nodes_start_vram={},
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    origin_func(self, event=event, data=data, sid=sid)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if event == "executing" and data and CURRENT_START_EXECUTION_DATA:
 | 
				
			||||||
 | 
					        if data.get("node") is not None:
 | 
				
			||||||
 | 
					            node_id = data.get("node")
 | 
				
			||||||
 | 
					            CURRENT_START_EXECUTION_DATA["nodes_start_perf_time"][node_id] = (
 | 
				
			||||||
 | 
					                time.perf_counter()
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            reset_peak_memory_record()
 | 
				
			||||||
 | 
					            CURRENT_START_EXECUTION_DATA["nodes_start_vram"][node_id] = (
 | 
				
			||||||
 | 
					                get_peak_memory()
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					server.PromptServer.send_sync = swizzle_send_sync
 | 
				
			||||||
 | 
					
 | 
				
			||||||
send_json = prompt_server.send_json
 | 
					send_json = prompt_server.send_json
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
async def send_json_override(self, event, data, sid=None):
 | 
					async def send_json_override(self, event, data, sid=None):
 | 
				
			||||||
    # logger.info("INTERNAL:", event, data, sid)
 | 
					    # logger.info("INTERNAL:", event, data, sid)
 | 
				
			||||||
    prompt_id = data.get("prompt_id")
 | 
					    prompt_id = data.get("prompt_id")
 | 
				
			||||||
@ -1254,15 +1279,6 @@ async def send_json_override(self, event, data, sid=None):
 | 
				
			|||||||
        if prompt_id in prompt_metadata:
 | 
					        if prompt_id in prompt_metadata:
 | 
				
			||||||
            prompt_metadata[prompt_id].start_time = time.perf_counter()
 | 
					            prompt_metadata[prompt_id].start_time = time.perf_counter()
 | 
				
			||||||
            
 | 
					            
 | 
				
			||||||
        global CURRENT_START_EXECUTION_DATA
 | 
					 | 
				
			||||||
        global NODE_EXECUTION_TIMES
 | 
					 | 
				
			||||||
        NODE_EXECUTION_TIMES = {}  # Reset execution times at start
 | 
					 | 
				
			||||||
        CURRENT_START_EXECUTION_DATA = dict(
 | 
					 | 
				
			||||||
            start_perf_time=time.perf_counter(),
 | 
					 | 
				
			||||||
            nodes_start_perf_time={},
 | 
					 | 
				
			||||||
            nodes_start_vram={},
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        
 | 
					 | 
				
			||||||
        await update_run(prompt_id, Status.RUNNING)
 | 
					        await update_run(prompt_id, Status.RUNNING)
 | 
				
			||||||
        
 | 
					        
 | 
				
			||||||
        
 | 
					        
 | 
				
			||||||
@ -1305,17 +1321,6 @@ async def send_json_override(self, event, data, sid=None):
 | 
				
			|||||||
            logger.info(format_table(headers, table_data))
 | 
					            logger.info(format_table(headers, table_data))
 | 
				
			||||||
            # print("========================\n")
 | 
					            # print("========================\n")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            node_id = data.get("node")
 | 
					 | 
				
			||||||
            CURRENT_START_EXECUTION_DATA["nodes_start_perf_time"][node_id] = (
 | 
					 | 
				
			||||||
                time.perf_counter()
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
            reset_peak_memory_record()
 | 
					 | 
				
			||||||
            CURRENT_START_EXECUTION_DATA["nodes_start_vram"][node_id] = (
 | 
					 | 
				
			||||||
                get_peak_memory()
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # the last executing event is none, then the workflow is finished
 | 
					    # the last executing event is none, then the workflow is finished
 | 
				
			||||||
    if event == "executing" and data.get("node") is None:
 | 
					    if event == "executing" and data.get("node") is None:
 | 
				
			||||||
        mark_prompt_done(prompt_id=prompt_id)
 | 
					        mark_prompt_done(prompt_id=prompt_id)
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user