diff --git a/custom_routes.py b/custom_routes.py index a119b7c..e402c91 100644 --- a/custom_routes.py +++ b/custom_routes.py @@ -366,6 +366,29 @@ async def websocket_handler(request): sid = uuid.uuid4().hex sockets[sid] = ws + + auth_token = request.rel_url.query.get('token', '') + get_workflow_endpoint_url = request.rel_url.query.get('workflow_endpoint', '') + + async with aiohttp.ClientSession() as session: + headers = {'Authorization': f'Bearer {auth_token}'} + async with session.get(get_workflow_endpoint_url, headers=headers) as response: + if response.status == 200: + workflow = await response.json() + + print("Loaded workflow version ",workflow["version"]) + + streaming_prompt_metadata[sid] = StreamingPrompt( + workflow_api=workflow["workflow_api"], + auth_token=auth_token, + inputs={} + ) + + # await send("workflow_api", workflow_api, sid) + else: + error_message = await response.text() + print(f"Failed to fetch workflow endpoint. Status: {response.status}, Error: {error_message}") + # await send("error", {"message": error_message}, sid) try: # Send initial state to the new client @@ -380,33 +403,7 @@ async def websocket_handler(request): data = json.loads(msg.data) print(data) event_type = data.get('event') - if event_type == 'workflow_endpoint': - _data = data.get('data') - get_workflow_endpoint_url = _data.get('get_workflow_endpoint_url') - - auth_token = _data.get('auth_token') - - async with aiohttp.ClientSession() as session: - headers = {'Authorization': f'Bearer {auth_token}'} - async with session.get(get_workflow_endpoint_url, headers=headers) as response: - if response.status == 200: - workflow = await response.json() - - print(workflow["version"]) - - streaming_prompt_metadata[sid] = StreamingPrompt( - workflow_api=workflow["workflow_api"], - auth_token=auth_token, - inputs={} - ) - - # await send("workflow_api", workflow_api, sid) - else: - error_message = await response.text() - print(f"Failed to fetch workflow endpoint. Status: {response.status}, Error: {error_message}") - # await send("error", {"message": error_message}, sid) - pass - elif event_type == 'input': + if event_type == 'input': print("Got input: ", data.get("inputs")) input = data.get('inputs') streaming_prompt_metadata[sid].inputs.update(input)