feat(plugin): update run status for ws request

This commit is contained in:
bennykok 2024-02-27 19:44:49 -08:00
parent 542b72bde5
commit 2d59fd2b1b

View File

@ -373,6 +373,16 @@ async def get_file_hash(request):
"error": str(e) "error": str(e)
}, status=500) }, status=500)
async def update_realtime_run_status(realtime_id: str, status_endpoint: str, status: Status):
body = {
"run_id": realtime_id,
"status": status.value,
}
# requests.post(status_endpoint, json=body)
async with aiohttp.ClientSession() as session:
async with session.post(status_endpoint, json=body) as response:
pass
@server.PromptServer.instance.routes.get('/comfyui-deploy/ws') @server.PromptServer.instance.routes.get('/comfyui-deploy/ws')
async def websocket_handler(request): async def websocket_handler(request):
ws = web.WebSocketResponse() ws = web.WebSocketResponse()
@ -388,6 +398,8 @@ async def websocket_handler(request):
auth_token = request.rel_url.query.get('token', None) auth_token = request.rel_url.query.get('token', None)
get_workflow_endpoint_url = request.rel_url.query.get('workflow_endpoint', None) get_workflow_endpoint_url = request.rel_url.query.get('workflow_endpoint', None)
realtime_id = request.rel_url.query.get('realtime_id', None)
status_endpoint = request.rel_url.query.get('status_endpoint', None)
if auth_token is not None and get_workflow_endpoint_url is not None: if auth_token is not None and get_workflow_endpoint_url is not None:
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
@ -402,10 +414,11 @@ async def websocket_handler(request):
workflow_api=workflow["workflow_api"], workflow_api=workflow["workflow_api"],
auth_token=auth_token, auth_token=auth_token,
inputs={}, inputs={},
status_endpoint=request.rel_url.query.get('status_endpoint', None), status_endpoint=status_endpoint,
file_upload_endpoint=request.rel_url.query.get('file_upload_endpoint', None), file_upload_endpoint=request.rel_url.query.get('file_upload_endpoint', None),
) )
await update_realtime_run_status(realtime_id, status_endpoint, Status.RUNNING)
# await send("workflow_api", workflow_api, sid) # await send("workflow_api", workflow_api, sid)
else: else:
error_message = await response.text() error_message = await response.text()
@ -441,6 +454,9 @@ async def websocket_handler(request):
print('ws connection closed with exception %s' % ws.exception()) print('ws connection closed with exception %s' % ws.exception())
finally: finally:
sockets.pop(sid, None) sockets.pop(sid, None)
if realtime_id is not None:
await update_realtime_run_status(realtime_id, status_endpoint, Status.SUCCESS)
return ws return ws
@server.PromptServer.instance.routes.get('/comfyui-deploy/check-status') @server.PromptServer.instance.routes.get('/comfyui-deploy/check-status')