diff --git a/custom_routes.py b/custom_routes.py index 9af7b80..4f4ff7c 100644 --- a/custom_routes.py +++ b/custom_routes.py @@ -59,7 +59,7 @@ print(f"max_retries: {max_retries}, retry_delay_multiplier: {retry_delay_multipl import time -async def async_request_with_retry(method, url, disable_timeout=False, **kwargs): +async def async_request_with_retry(method, url, disable_timeout=False, token=None, **kwargs): global client_session await ensure_client_session() retry_delay = 1 # Start with 1 second delay @@ -72,6 +72,11 @@ async def async_request_with_retry(method, url, disable_timeout=False, **kwargs) timeout = ClientTimeout(total=None, connect=initial_timeout) kwargs['timeout'] = timeout + if token is not None: + if 'headers' not in kwargs: + kwargs['headers'] = {} + kwargs['headers']['Authorization'] = f"Bearer {token}" + request_start = time.time() async with client_session.request(method, url, **kwargs) as response: request_end = time.time() @@ -361,6 +366,14 @@ def send_prompt(sid: str, inputs: StreamingPrompt): @server.PromptServer.instance.routes.post("/comfyui-deploy/run") async def comfy_deploy_run(request): + # Extract the bearer token from the Authorization header + auth_header = request.headers.get('Authorization') + token = None + if auth_header: + parts = auth_header.split() + if len(parts) == 2 and parts[0].lower() == 'bearer': + token = parts[1] + data = await request.json() # In older version, we use workflow_api, but this has inputs already swapped in nextjs frontend, which is tricky @@ -376,13 +389,14 @@ async def comfy_deploy_run(request): prompt = { "prompt": workflow_api, "client_id": "comfy_deploy_instance", #api.client_id - "prompt_id": prompt_id + "prompt_id": prompt_id, } prompt_metadata[prompt_id] = SimplePrompt( status_endpoint=data.get('status_endpoint'), file_upload_endpoint=data.get('file_upload_endpoint'), - workflow_api=workflow_api + workflow_api=workflow_api, + token=token ) try: @@ -420,7 +434,7 @@ async def comfy_deploy_run(request): return web.json_response(res, status=status) -async def stream_prompt(data): +async def stream_prompt(data, token): # In older version, we use workflow_api, but this has inputs already swapped in nextjs frontend, which is tricky workflow_api = data.get("workflow_api_raw") # The prompt id generated from comfy deploy, can be None @@ -440,7 +454,8 @@ async def stream_prompt(data): prompt_metadata[prompt_id] = SimplePrompt( status_endpoint=data.get('status_endpoint'), file_upload_endpoint=data.get('file_upload_endpoint'), - workflow_api=workflow_api + workflow_api=workflow_api, + token=token ) # log('info', "Begin prompt", prompt=prompt) @@ -489,6 +504,14 @@ comfy_message_queues: Dict[str, asyncio.Queue] = {} async def stream_response(request): response = web.StreamResponse(status=200, reason='OK', headers={'Content-Type': 'text/event-stream'}) await response.prepare(request) + + # Extract the bearer token from the Authorization header + auth_header = request.headers.get('Authorization') + token = None + if auth_header: + parts = auth_header.split() + if len(parts) == 2 and parts[0].lower() == 'bearer': + token = parts[1] pending = True data = await request.json() @@ -500,7 +523,7 @@ async def stream_response(request): log('info', 'Streaming prompt') try: - result = await stream_prompt(data=data) + result = await stream_prompt(data=data, token=token) await response.write(f"event: event_update\ndata: {json.dumps(result)}\n\n".encode('utf-8')) # await response.write(.encode('utf-8')) await response.drain() # Ensure the buffer is flushed @@ -1045,6 +1068,7 @@ async def update_run_live_status(prompt_id, live_status, calculated_progress: fl return status_endpoint = prompt_metadata[prompt_id].status_endpoint + token = prompt_metadata[prompt_id].token if (status_endpoint is None): return @@ -1068,7 +1092,7 @@ async def update_run_live_status(prompt_id, live_status, calculated_progress: fl }) # requests.post(status_endpoint, json=body) - await async_request_with_retry('POST', status_endpoint, json=body) + await async_request_with_retry('POST', status_endpoint, token=token, json=body) async def update_run(prompt_id: str, status: Status): @@ -1099,7 +1123,8 @@ async def update_run(prompt_id: str, status: Status): try: # requests.post(status_endpoint, json=body) if (status_endpoint is not None): - await async_request_with_retry('POST', status_endpoint, json=body) + token = prompt_metadata[prompt_id].token + await async_request_with_retry('POST', status_endpoint, token=token, json=body) if (status_endpoint is not None) and cd_enable_run_log and (status == Status.SUCCESS or status == Status.FAILED): try: @@ -1129,7 +1154,7 @@ async def update_run(prompt_id: str, status: Status): ] } - await async_request_with_retry('POST', status_endpoint, json=body) + await async_request_with_retry('POST', status_endpoint, token=token, json=body) # requests.post(status_endpoint, json=body) except Exception as log_error: logger.info(f"Error reading log file: {log_error}") @@ -1180,7 +1205,7 @@ async def upload_file(prompt_id, filename, subfolder=None, content_type="image/p logger.info(f"Uploading file {file}") file_upload_endpoint = prompt_metadata[prompt_id].file_upload_endpoint - + token = prompt_metadata[prompt_id].token filename = quote(filename) prompt_id = quote(prompt_id) content_type = quote(content_type) @@ -1192,7 +1217,7 @@ async def upload_file(prompt_id, filename, subfolder=None, content_type="image/p start_time = time.time() # Start timing here logger.info(f"Target URL: {target_url}") - result = await async_request_with_retry("GET", target_url, disable_timeout=True) + result = await async_request_with_retry("GET", target_url, disable_timeout=True, token=token) end_time = time.time() # End timing after the request is complete logger.info("Time taken for getting file upload endpoint: {:.2f} seconds".format(end_time - start_time)) ok = await result.json() @@ -1352,6 +1377,7 @@ async def upload_in_background(prompt_id: str, data, node_id=None, have_upload=T await asyncio.gather(*upload_tasks) status_endpoint = prompt_metadata[prompt_id].status_endpoint + token = prompt_metadata[prompt_id].token if have_upload: if status_endpoint is not None: body = { @@ -1360,7 +1386,7 @@ async def upload_in_background(prompt_id: str, data, node_id=None, have_upload=T "node_meta": node_meta, } # pprint(body) - await async_request_with_retry('POST', status_endpoint, json=body) + await async_request_with_retry('POST', status_endpoint, token=token, json=body) await update_file_status(prompt_id, data, False, node_id=node_id) except Exception as e: await handle_error(prompt_id, data, e) @@ -1398,7 +1424,8 @@ async def update_run_with_output(prompt_id, data, node_id=None, node_meta=None): await handle_error(prompt_id, data, e) # requests.post(status_endpoint, json=body) elif status_endpoint is not None: - await async_request_with_retry('POST', status_endpoint, json=body) + token = prompt_metadata[prompt_id].token + await async_request_with_retry('POST', status_endpoint, token=token, json=body) await send('outputs_uploaded', { "prompt_id": prompt_id diff --git a/globals.py b/globals.py index fd700a1..3ee2658 100644 --- a/globals.py +++ b/globals.py @@ -29,6 +29,8 @@ class SimplePrompt(BaseModel): status_endpoint: Optional[str] file_upload_endpoint: Optional[str] + token: Optional[str] + workflow_api: dict status: Status = Status.NOT_STARTED progress: set = set()