feat(plugin): load workflow from ws url params

This commit is contained in:
bennykok 2024-02-24 13:28:56 -08:00
parent 45d37879c2
commit ec620dbc53

View File

@ -366,6 +366,29 @@ async def websocket_handler(request):
sid = uuid.uuid4().hex sid = uuid.uuid4().hex
sockets[sid] = ws sockets[sid] = ws
auth_token = request.rel_url.query.get('token', '')
get_workflow_endpoint_url = request.rel_url.query.get('workflow_endpoint', '')
async with aiohttp.ClientSession() as session:
headers = {'Authorization': f'Bearer {auth_token}'}
async with session.get(get_workflow_endpoint_url, headers=headers) as response:
if response.status == 200:
workflow = await response.json()
print("Loaded workflow version ",workflow["version"])
streaming_prompt_metadata[sid] = StreamingPrompt(
workflow_api=workflow["workflow_api"],
auth_token=auth_token,
inputs={}
)
# await send("workflow_api", workflow_api, sid)
else:
error_message = await response.text()
print(f"Failed to fetch workflow endpoint. Status: {response.status}, Error: {error_message}")
# await send("error", {"message": error_message}, sid)
try: try:
# Send initial state to the new client # Send initial state to the new client
@ -380,33 +403,7 @@ async def websocket_handler(request):
data = json.loads(msg.data) data = json.loads(msg.data)
print(data) print(data)
event_type = data.get('event') event_type = data.get('event')
if event_type == 'workflow_endpoint': if event_type == 'input':
_data = data.get('data')
get_workflow_endpoint_url = _data.get('get_workflow_endpoint_url')
auth_token = _data.get('auth_token')
async with aiohttp.ClientSession() as session:
headers = {'Authorization': f'Bearer {auth_token}'}
async with session.get(get_workflow_endpoint_url, headers=headers) as response:
if response.status == 200:
workflow = await response.json()
print(workflow["version"])
streaming_prompt_metadata[sid] = StreamingPrompt(
workflow_api=workflow["workflow_api"],
auth_token=auth_token,
inputs={}
)
# await send("workflow_api", workflow_api, sid)
else:
error_message = await response.text()
print(f"Failed to fetch workflow endpoint. Status: {response.status}, Error: {error_message}")
# await send("error", {"message": error_message}, sid)
pass
elif event_type == 'input':
print("Got input: ", data.get("inputs")) print("Got input: ", data.get("inputs"))
input = data.get('inputs') input = data.get('inputs')
streaming_prompt_metadata[sid].inputs.update(input) streaming_prompt_metadata[sid].inputs.update(input)