Merge branch 'benny/auth_token' into public-main

This commit is contained in:
bennykok 2024-09-12 14:14:16 -07:00
commit 5554c95f44
2 changed files with 44 additions and 13 deletions

View File

@ -59,7 +59,7 @@ print(f"max_retries: {max_retries}, retry_delay_multiplier: {retry_delay_multipl
import time
async def async_request_with_retry(method, url, disable_timeout=False, **kwargs):
async def async_request_with_retry(method, url, disable_timeout=False, token=None, **kwargs):
global client_session
await ensure_client_session()
retry_delay = 1 # Start with 1 second delay
@ -72,6 +72,11 @@ async def async_request_with_retry(method, url, disable_timeout=False, **kwargs)
timeout = ClientTimeout(total=None, connect=initial_timeout)
kwargs['timeout'] = timeout
if token is not None:
if 'headers' not in kwargs:
kwargs['headers'] = {}
kwargs['headers']['Authorization'] = f"Bearer {token}"
request_start = time.time()
async with client_session.request(method, url, **kwargs) as response:
request_end = time.time()
@ -361,6 +366,14 @@ def send_prompt(sid: str, inputs: StreamingPrompt):
@server.PromptServer.instance.routes.post("/comfyui-deploy/run")
async def comfy_deploy_run(request):
# Extract the bearer token from the Authorization header
auth_header = request.headers.get('Authorization')
token = None
if auth_header:
parts = auth_header.split()
if len(parts) == 2 and parts[0].lower() == 'bearer':
token = parts[1]
data = await request.json()
# In older version, we use workflow_api, but this has inputs already swapped in nextjs frontend, which is tricky
@ -376,13 +389,14 @@ async def comfy_deploy_run(request):
prompt = {
"prompt": workflow_api,
"client_id": "comfy_deploy_instance", #api.client_id
"prompt_id": prompt_id
"prompt_id": prompt_id,
}
prompt_metadata[prompt_id] = SimplePrompt(
status_endpoint=data.get('status_endpoint'),
file_upload_endpoint=data.get('file_upload_endpoint'),
workflow_api=workflow_api
workflow_api=workflow_api,
token=token
)
try:
@ -420,7 +434,7 @@ async def comfy_deploy_run(request):
return web.json_response(res, status=status)
async def stream_prompt(data):
async def stream_prompt(data, token):
# In older version, we use workflow_api, but this has inputs already swapped in nextjs frontend, which is tricky
workflow_api = data.get("workflow_api_raw")
# The prompt id generated from comfy deploy, can be None
@ -440,7 +454,8 @@ async def stream_prompt(data):
prompt_metadata[prompt_id] = SimplePrompt(
status_endpoint=data.get('status_endpoint'),
file_upload_endpoint=data.get('file_upload_endpoint'),
workflow_api=workflow_api
workflow_api=workflow_api,
token=token
)
# log('info', "Begin prompt", prompt=prompt)
@ -489,6 +504,14 @@ comfy_message_queues: Dict[str, asyncio.Queue] = {}
async def stream_response(request):
response = web.StreamResponse(status=200, reason='OK', headers={'Content-Type': 'text/event-stream'})
await response.prepare(request)
# Extract the bearer token from the Authorization header
auth_header = request.headers.get('Authorization')
token = None
if auth_header:
parts = auth_header.split()
if len(parts) == 2 and parts[0].lower() == 'bearer':
token = parts[1]
pending = True
data = await request.json()
@ -500,7 +523,7 @@ async def stream_response(request):
log('info', 'Streaming prompt')
try:
result = await stream_prompt(data=data)
result = await stream_prompt(data=data, token=token)
await response.write(f"event: event_update\ndata: {json.dumps(result)}\n\n".encode('utf-8'))
# await response.write(.encode('utf-8'))
await response.drain() # Ensure the buffer is flushed
@ -1045,6 +1068,7 @@ async def update_run_live_status(prompt_id, live_status, calculated_progress: fl
return
status_endpoint = prompt_metadata[prompt_id].status_endpoint
token = prompt_metadata[prompt_id].token
if (status_endpoint is None):
return
@ -1068,7 +1092,7 @@ async def update_run_live_status(prompt_id, live_status, calculated_progress: fl
})
# requests.post(status_endpoint, json=body)
await async_request_with_retry('POST', status_endpoint, json=body)
await async_request_with_retry('POST', status_endpoint, token=token, json=body)
async def update_run(prompt_id: str, status: Status):
@ -1099,7 +1123,8 @@ async def update_run(prompt_id: str, status: Status):
try:
# requests.post(status_endpoint, json=body)
if (status_endpoint is not None):
await async_request_with_retry('POST', status_endpoint, json=body)
token = prompt_metadata[prompt_id].token
await async_request_with_retry('POST', status_endpoint, token=token, json=body)
if (status_endpoint is not None) and cd_enable_run_log and (status == Status.SUCCESS or status == Status.FAILED):
try:
@ -1129,7 +1154,7 @@ async def update_run(prompt_id: str, status: Status):
]
}
await async_request_with_retry('POST', status_endpoint, json=body)
await async_request_with_retry('POST', status_endpoint, token=token, json=body)
# requests.post(status_endpoint, json=body)
except Exception as log_error:
logger.info(f"Error reading log file: {log_error}")
@ -1180,7 +1205,7 @@ async def upload_file(prompt_id, filename, subfolder=None, content_type="image/p
logger.info(f"Uploading file {file}")
file_upload_endpoint = prompt_metadata[prompt_id].file_upload_endpoint
token = prompt_metadata[prompt_id].token
filename = quote(filename)
prompt_id = quote(prompt_id)
content_type = quote(content_type)
@ -1192,7 +1217,7 @@ async def upload_file(prompt_id, filename, subfolder=None, content_type="image/p
start_time = time.time() # Start timing here
logger.info(f"Target URL: {target_url}")
result = await async_request_with_retry("GET", target_url, disable_timeout=True)
result = await async_request_with_retry("GET", target_url, disable_timeout=True, token=token)
end_time = time.time() # End timing after the request is complete
logger.info("Time taken for getting file upload endpoint: {:.2f} seconds".format(end_time - start_time))
ok = await result.json()
@ -1219,6 +1244,8 @@ async def upload_file(prompt_id, filename, subfolder=None, content_type="image/p
if file_download_url is not None:
item["url"] = file_download_url
item["upload_duration"] = end_time - start_time
if ok.get("is_public") is not None:
item["is_public"] = ok.get("is_public")
def have_pending_upload(prompt_id):
if prompt_id in prompt_metadata and len(prompt_metadata[prompt_id].uploading_nodes) > 0:
@ -1352,6 +1379,7 @@ async def upload_in_background(prompt_id: str, data, node_id=None, have_upload=T
await asyncio.gather(*upload_tasks)
status_endpoint = prompt_metadata[prompt_id].status_endpoint
token = prompt_metadata[prompt_id].token
if have_upload:
if status_endpoint is not None:
body = {
@ -1360,7 +1388,7 @@ async def upload_in_background(prompt_id: str, data, node_id=None, have_upload=T
"node_meta": node_meta,
}
# pprint(body)
await async_request_with_retry('POST', status_endpoint, json=body)
await async_request_with_retry('POST', status_endpoint, token=token, json=body)
await update_file_status(prompt_id, data, False, node_id=node_id)
except Exception as e:
await handle_error(prompt_id, data, e)
@ -1398,7 +1426,8 @@ async def update_run_with_output(prompt_id, data, node_id=None, node_meta=None):
await handle_error(prompt_id, data, e)
# requests.post(status_endpoint, json=body)
elif status_endpoint is not None:
await async_request_with_retry('POST', status_endpoint, json=body)
token = prompt_metadata[prompt_id].token
await async_request_with_retry('POST', status_endpoint, token=token, json=body)
await send('outputs_uploaded', {
"prompt_id": prompt_id

View File

@ -29,6 +29,8 @@ class SimplePrompt(BaseModel):
status_endpoint: Optional[str]
file_upload_endpoint: Optional[str]
token: Optional[str]
workflow_api: dict
status: Status = Status.NOT_STARTED
progress: set = set()