fix(plugin): remove file upload + status update from is_realtime prompt

This commit is contained in:
bennykok 2024-02-25 17:25:02 -08:00
parent cc31840d41
commit a643fa0999

View File

@ -54,6 +54,7 @@ class SimplePrompt(BaseModel):
last_updated_node: Optional[str] = None, last_updated_node: Optional[str] = None,
uploading_nodes: set = set() uploading_nodes: set = set()
done: bool = False done: bool = False
is_realtime: bool = False,
streaming_prompt_metadata: dict[str, StreamingPrompt] = {} streaming_prompt_metadata: dict[str, StreamingPrompt] = {}
@ -126,14 +127,14 @@ def send_prompt(sid: str, inputs: StreamingPrompt):
if 'inputs' in value: if 'inputs' in value:
if "input_id" in value['inputs'] and value['inputs']['input_id'] in inputs.inputs: if "input_id" in value['inputs'] and value['inputs']['input_id'] in inputs.inputs:
new_value = inputs.inputs[value['inputs']['input_id']] new_value = inputs.inputs[value['inputs']['input_id']]
value['inputs']["input_id"] = new_value; value['inputs']["input_id"] = new_value
# Fix for external text default value # Fix for external text default value
if (value["class_type"] == "ComfyUIDeployExternalText"): if (value["class_type"] == "ComfyUIDeployExternalText"):
value['inputs']["default_value"] = new_value; value['inputs']["default_value"] = new_value
if (value["class_type"] == "ComfyDeployWebscoketImageOutput"): if (value["class_type"] == "ComfyDeployWebscoketImageOutput"):
value['inputs']["client_id"] = sid; value['inputs']["client_id"] = sid
print(workflow_api) print(workflow_api)
@ -151,7 +152,8 @@ def send_prompt(sid: str, inputs: StreamingPrompt):
prompt_metadata[prompt_id] = SimplePrompt( prompt_metadata[prompt_id] = SimplePrompt(
status_endpoint=inputs.status_endpoint, status_endpoint=inputs.status_endpoint,
file_upload_endpoint=inputs.file_upload_endpoint, file_upload_endpoint=inputs.file_upload_endpoint,
workflow_api=workflow_api workflow_api=workflow_api,
is_realtime=True
) )
except Exception as e: except Exception as e:
error_type = type(e).__name__ error_type = type(e).__name__
@ -413,7 +415,8 @@ async def websocket_handler(request):
# Send initial state to the new client # Send initial state to the new client
await send("status", { 'sid': sid }, sid) await send("status", { 'sid': sid }, sid)
if cd_enable_log: # Make sure when its connected via client, the full log is not being sent
if cd_enable_log and get_workflow_endpoint_url is None:
await send_first_time_log(sid) await send_first_time_log(sid)
async for msg in ws: async for msg in ws:
@ -490,7 +493,7 @@ async def send_json_override(self, event, data, sid=None):
# now we send everything # now we send everything
await asyncio.wait([ await asyncio.wait([
asyncio.create_task(send(event, data)), asyncio.create_task(send(event, data, sid=sid)),
asyncio.create_task(self.send_json_original(event, data, sid)) asyncio.create_task(self.send_json_original(event, data, sid))
]) ])
@ -549,6 +552,9 @@ async def update_run_live_status(prompt_id, live_status, calculated_progress: fl
if prompt_id not in prompt_metadata: if prompt_id not in prompt_metadata:
return return
if prompt_metadata[prompt_id].is_realtime:
return
print("progress", calculated_progress) print("progress", calculated_progress)
status_endpoint = prompt_metadata[prompt_id].status_endpoint status_endpoint = prompt_metadata[prompt_id].status_endpoint
@ -568,6 +574,10 @@ def update_run(prompt_id: str, status: Status):
if prompt_id not in prompt_metadata: if prompt_id not in prompt_metadata:
return return
# if its realtime prompt we need to skip that.
if prompt_metadata[prompt_id].is_realtime:
return
if (prompt_metadata[prompt_id].status != status): if (prompt_metadata[prompt_id].status != status):
@ -792,32 +802,37 @@ async def upload_in_background(prompt_id: str, data, node_id=None, have_upload=T
await handle_error(prompt_id, data, e) await handle_error(prompt_id, data, e)
async def update_run_with_output(prompt_id, data, node_id=None): async def update_run_with_output(prompt_id, data, node_id=None):
if prompt_id in prompt_metadata: if prompt_id not in prompt_metadata:
status_endpoint = prompt_metadata[prompt_id].status_endpoint return
if prompt_metadata[prompt_id].is_realtime:
return
status_endpoint = prompt_metadata[prompt_id].status_endpoint
body = { body = {
"run_id": prompt_id, "run_id": prompt_id,
"output_data": data "output_data": data
} }
try: try:
have_upload = 'images' in data or 'files' in data or 'gifs' in data have_upload = 'images' in data or 'files' in data or 'gifs' in data
print("\nhave_upload", have_upload, node_id) print("\nhave_upload", have_upload, node_id)
if have_upload: if have_upload:
await update_file_status(prompt_id, data, True, node_id=node_id) await update_file_status(prompt_id, data, True, node_id=node_id)
asyncio.create_task(upload_in_background(prompt_id, data, node_id=node_id, have_upload=have_upload)) asyncio.create_task(upload_in_background(prompt_id, data, node_id=node_id, have_upload=have_upload))
except Exception as e: except Exception as e:
await handle_error(prompt_id, data, e) await handle_error(prompt_id, data, e)
requests.post(status_endpoint, json=body) requests.post(status_endpoint, json=body)
await send('outputs_uploaded', { await send('outputs_uploaded', {
"prompt_id": prompt_id "prompt_id": prompt_id
}) })
prompt_server.send_json_original = prompt_server.send_json prompt_server.send_json_original = prompt_server.send_json
prompt_server.send_json = send_json_override.__get__(prompt_server, server.PromptServer) prompt_server.send_json = send_json_override.__get__(prompt_server, server.PromptServer)