diff --git a/custom_routes.py b/custom_routes.py index 9c4cefb..c090488 100644 --- a/custom_routes.py +++ b/custom_routes.py @@ -54,6 +54,7 @@ class SimplePrompt(BaseModel): last_updated_node: Optional[str] = None, uploading_nodes: set = set() done: bool = False + is_realtime: bool = False, streaming_prompt_metadata: dict[str, StreamingPrompt] = {} @@ -126,14 +127,14 @@ def send_prompt(sid: str, inputs: StreamingPrompt): if 'inputs' in value: if "input_id" in value['inputs'] and value['inputs']['input_id'] in inputs.inputs: new_value = inputs.inputs[value['inputs']['input_id']] - value['inputs']["input_id"] = new_value; + value['inputs']["input_id"] = new_value # Fix for external text default value if (value["class_type"] == "ComfyUIDeployExternalText"): - value['inputs']["default_value"] = new_value; + value['inputs']["default_value"] = new_value if (value["class_type"] == "ComfyDeployWebscoketImageOutput"): - value['inputs']["client_id"] = sid; + value['inputs']["client_id"] = sid print(workflow_api) @@ -151,7 +152,8 @@ def send_prompt(sid: str, inputs: StreamingPrompt): prompt_metadata[prompt_id] = SimplePrompt( status_endpoint=inputs.status_endpoint, file_upload_endpoint=inputs.file_upload_endpoint, - workflow_api=workflow_api + workflow_api=workflow_api, + is_realtime=True ) except Exception as e: error_type = type(e).__name__ @@ -413,7 +415,8 @@ async def websocket_handler(request): # Send initial state to the new client await send("status", { 'sid': sid }, sid) - if cd_enable_log: + # Make sure when its connected via client, the full log is not being sent + if cd_enable_log and get_workflow_endpoint_url is None: await send_first_time_log(sid) async for msg in ws: @@ -490,7 +493,7 @@ async def send_json_override(self, event, data, sid=None): # now we send everything await asyncio.wait([ - asyncio.create_task(send(event, data)), + asyncio.create_task(send(event, data, sid=sid)), asyncio.create_task(self.send_json_original(event, data, sid)) ]) @@ -549,6 +552,9 @@ async def update_run_live_status(prompt_id, live_status, calculated_progress: fl if prompt_id not in prompt_metadata: return + if prompt_metadata[prompt_id].is_realtime: + return + print("progress", calculated_progress) status_endpoint = prompt_metadata[prompt_id].status_endpoint @@ -568,6 +574,10 @@ def update_run(prompt_id: str, status: Status): if prompt_id not in prompt_metadata: return + + # if its realtime prompt we need to skip that. + if prompt_metadata[prompt_id].is_realtime: + return if (prompt_metadata[prompt_id].status != status): @@ -792,32 +802,37 @@ async def upload_in_background(prompt_id: str, data, node_id=None, have_upload=T await handle_error(prompt_id, data, e) async def update_run_with_output(prompt_id, data, node_id=None): - if prompt_id in prompt_metadata: - status_endpoint = prompt_metadata[prompt_id].status_endpoint + if prompt_id not in prompt_metadata: + return + + if prompt_metadata[prompt_id].is_realtime: + return + + status_endpoint = prompt_metadata[prompt_id].status_endpoint - body = { - "run_id": prompt_id, - "output_data": data - } + body = { + "run_id": prompt_id, + "output_data": data + } - try: - have_upload = 'images' in data or 'files' in data or 'gifs' in data - print("\nhave_upload", have_upload, node_id) + try: + have_upload = 'images' in data or 'files' in data or 'gifs' in data + print("\nhave_upload", have_upload, node_id) - if have_upload: - await update_file_status(prompt_id, data, True, node_id=node_id) + if have_upload: + await update_file_status(prompt_id, data, True, node_id=node_id) - asyncio.create_task(upload_in_background(prompt_id, data, node_id=node_id, have_upload=have_upload)) + asyncio.create_task(upload_in_background(prompt_id, data, node_id=node_id, have_upload=have_upload)) - except Exception as e: - await handle_error(prompt_id, data, e) - + except Exception as e: + await handle_error(prompt_id, data, e) + - requests.post(status_endpoint, json=body) + requests.post(status_endpoint, json=body) - await send('outputs_uploaded', { - "prompt_id": prompt_id - }) + await send('outputs_uploaded', { + "prompt_id": prompt_id + }) prompt_server.send_json_original = prompt_server.send_json prompt_server.send_json = send_json_override.__get__(prompt_server, server.PromptServer)