Merge branch 'benny/auth_token' into public-main
This commit is contained in:
commit
5554c95f44
@ -59,7 +59,7 @@ print(f"max_retries: {max_retries}, retry_delay_multiplier: {retry_delay_multipl
|
||||
|
||||
import time
|
||||
|
||||
async def async_request_with_retry(method, url, disable_timeout=False, **kwargs):
|
||||
async def async_request_with_retry(method, url, disable_timeout=False, token=None, **kwargs):
|
||||
global client_session
|
||||
await ensure_client_session()
|
||||
retry_delay = 1 # Start with 1 second delay
|
||||
@ -72,6 +72,11 @@ async def async_request_with_retry(method, url, disable_timeout=False, **kwargs)
|
||||
timeout = ClientTimeout(total=None, connect=initial_timeout)
|
||||
kwargs['timeout'] = timeout
|
||||
|
||||
if token is not None:
|
||||
if 'headers' not in kwargs:
|
||||
kwargs['headers'] = {}
|
||||
kwargs['headers']['Authorization'] = f"Bearer {token}"
|
||||
|
||||
request_start = time.time()
|
||||
async with client_session.request(method, url, **kwargs) as response:
|
||||
request_end = time.time()
|
||||
@ -361,6 +366,14 @@ def send_prompt(sid: str, inputs: StreamingPrompt):
|
||||
|
||||
@server.PromptServer.instance.routes.post("/comfyui-deploy/run")
|
||||
async def comfy_deploy_run(request):
|
||||
# Extract the bearer token from the Authorization header
|
||||
auth_header = request.headers.get('Authorization')
|
||||
token = None
|
||||
if auth_header:
|
||||
parts = auth_header.split()
|
||||
if len(parts) == 2 and parts[0].lower() == 'bearer':
|
||||
token = parts[1]
|
||||
|
||||
data = await request.json()
|
||||
|
||||
# In older version, we use workflow_api, but this has inputs already swapped in nextjs frontend, which is tricky
|
||||
@ -376,13 +389,14 @@ async def comfy_deploy_run(request):
|
||||
prompt = {
|
||||
"prompt": workflow_api,
|
||||
"client_id": "comfy_deploy_instance", #api.client_id
|
||||
"prompt_id": prompt_id
|
||||
"prompt_id": prompt_id,
|
||||
}
|
||||
|
||||
prompt_metadata[prompt_id] = SimplePrompt(
|
||||
status_endpoint=data.get('status_endpoint'),
|
||||
file_upload_endpoint=data.get('file_upload_endpoint'),
|
||||
workflow_api=workflow_api
|
||||
workflow_api=workflow_api,
|
||||
token=token
|
||||
)
|
||||
|
||||
try:
|
||||
@ -420,7 +434,7 @@ async def comfy_deploy_run(request):
|
||||
|
||||
return web.json_response(res, status=status)
|
||||
|
||||
async def stream_prompt(data):
|
||||
async def stream_prompt(data, token):
|
||||
# In older version, we use workflow_api, but this has inputs already swapped in nextjs frontend, which is tricky
|
||||
workflow_api = data.get("workflow_api_raw")
|
||||
# The prompt id generated from comfy deploy, can be None
|
||||
@ -440,7 +454,8 @@ async def stream_prompt(data):
|
||||
prompt_metadata[prompt_id] = SimplePrompt(
|
||||
status_endpoint=data.get('status_endpoint'),
|
||||
file_upload_endpoint=data.get('file_upload_endpoint'),
|
||||
workflow_api=workflow_api
|
||||
workflow_api=workflow_api,
|
||||
token=token
|
||||
)
|
||||
|
||||
# log('info', "Begin prompt", prompt=prompt)
|
||||
@ -489,6 +504,14 @@ comfy_message_queues: Dict[str, asyncio.Queue] = {}
|
||||
async def stream_response(request):
|
||||
response = web.StreamResponse(status=200, reason='OK', headers={'Content-Type': 'text/event-stream'})
|
||||
await response.prepare(request)
|
||||
|
||||
# Extract the bearer token from the Authorization header
|
||||
auth_header = request.headers.get('Authorization')
|
||||
token = None
|
||||
if auth_header:
|
||||
parts = auth_header.split()
|
||||
if len(parts) == 2 and parts[0].lower() == 'bearer':
|
||||
token = parts[1]
|
||||
|
||||
pending = True
|
||||
data = await request.json()
|
||||
@ -500,7 +523,7 @@ async def stream_response(request):
|
||||
log('info', 'Streaming prompt')
|
||||
|
||||
try:
|
||||
result = await stream_prompt(data=data)
|
||||
result = await stream_prompt(data=data, token=token)
|
||||
await response.write(f"event: event_update\ndata: {json.dumps(result)}\n\n".encode('utf-8'))
|
||||
# await response.write(.encode('utf-8'))
|
||||
await response.drain() # Ensure the buffer is flushed
|
||||
@ -1045,6 +1068,7 @@ async def update_run_live_status(prompt_id, live_status, calculated_progress: fl
|
||||
return
|
||||
|
||||
status_endpoint = prompt_metadata[prompt_id].status_endpoint
|
||||
token = prompt_metadata[prompt_id].token
|
||||
|
||||
if (status_endpoint is None):
|
||||
return
|
||||
@ -1068,7 +1092,7 @@ async def update_run_live_status(prompt_id, live_status, calculated_progress: fl
|
||||
})
|
||||
|
||||
# requests.post(status_endpoint, json=body)
|
||||
await async_request_with_retry('POST', status_endpoint, json=body)
|
||||
await async_request_with_retry('POST', status_endpoint, token=token, json=body)
|
||||
|
||||
|
||||
async def update_run(prompt_id: str, status: Status):
|
||||
@ -1099,7 +1123,8 @@ async def update_run(prompt_id: str, status: Status):
|
||||
try:
|
||||
# requests.post(status_endpoint, json=body)
|
||||
if (status_endpoint is not None):
|
||||
await async_request_with_retry('POST', status_endpoint, json=body)
|
||||
token = prompt_metadata[prompt_id].token
|
||||
await async_request_with_retry('POST', status_endpoint, token=token, json=body)
|
||||
|
||||
if (status_endpoint is not None) and cd_enable_run_log and (status == Status.SUCCESS or status == Status.FAILED):
|
||||
try:
|
||||
@ -1129,7 +1154,7 @@ async def update_run(prompt_id: str, status: Status):
|
||||
]
|
||||
}
|
||||
|
||||
await async_request_with_retry('POST', status_endpoint, json=body)
|
||||
await async_request_with_retry('POST', status_endpoint, token=token, json=body)
|
||||
# requests.post(status_endpoint, json=body)
|
||||
except Exception as log_error:
|
||||
logger.info(f"Error reading log file: {log_error}")
|
||||
@ -1180,7 +1205,7 @@ async def upload_file(prompt_id, filename, subfolder=None, content_type="image/p
|
||||
logger.info(f"Uploading file {file}")
|
||||
|
||||
file_upload_endpoint = prompt_metadata[prompt_id].file_upload_endpoint
|
||||
|
||||
token = prompt_metadata[prompt_id].token
|
||||
filename = quote(filename)
|
||||
prompt_id = quote(prompt_id)
|
||||
content_type = quote(content_type)
|
||||
@ -1192,7 +1217,7 @@ async def upload_file(prompt_id, filename, subfolder=None, content_type="image/p
|
||||
|
||||
start_time = time.time() # Start timing here
|
||||
logger.info(f"Target URL: {target_url}")
|
||||
result = await async_request_with_retry("GET", target_url, disable_timeout=True)
|
||||
result = await async_request_with_retry("GET", target_url, disable_timeout=True, token=token)
|
||||
end_time = time.time() # End timing after the request is complete
|
||||
logger.info("Time taken for getting file upload endpoint: {:.2f} seconds".format(end_time - start_time))
|
||||
ok = await result.json()
|
||||
@ -1219,6 +1244,8 @@ async def upload_file(prompt_id, filename, subfolder=None, content_type="image/p
|
||||
if file_download_url is not None:
|
||||
item["url"] = file_download_url
|
||||
item["upload_duration"] = end_time - start_time
|
||||
if ok.get("is_public") is not None:
|
||||
item["is_public"] = ok.get("is_public")
|
||||
|
||||
def have_pending_upload(prompt_id):
|
||||
if prompt_id in prompt_metadata and len(prompt_metadata[prompt_id].uploading_nodes) > 0:
|
||||
@ -1352,6 +1379,7 @@ async def upload_in_background(prompt_id: str, data, node_id=None, have_upload=T
|
||||
await asyncio.gather(*upload_tasks)
|
||||
|
||||
status_endpoint = prompt_metadata[prompt_id].status_endpoint
|
||||
token = prompt_metadata[prompt_id].token
|
||||
if have_upload:
|
||||
if status_endpoint is not None:
|
||||
body = {
|
||||
@ -1360,7 +1388,7 @@ async def upload_in_background(prompt_id: str, data, node_id=None, have_upload=T
|
||||
"node_meta": node_meta,
|
||||
}
|
||||
# pprint(body)
|
||||
await async_request_with_retry('POST', status_endpoint, json=body)
|
||||
await async_request_with_retry('POST', status_endpoint, token=token, json=body)
|
||||
await update_file_status(prompt_id, data, False, node_id=node_id)
|
||||
except Exception as e:
|
||||
await handle_error(prompt_id, data, e)
|
||||
@ -1398,7 +1426,8 @@ async def update_run_with_output(prompt_id, data, node_id=None, node_meta=None):
|
||||
await handle_error(prompt_id, data, e)
|
||||
# requests.post(status_endpoint, json=body)
|
||||
elif status_endpoint is not None:
|
||||
await async_request_with_retry('POST', status_endpoint, json=body)
|
||||
token = prompt_metadata[prompt_id].token
|
||||
await async_request_with_retry('POST', status_endpoint, token=token, json=body)
|
||||
|
||||
await send('outputs_uploaded', {
|
||||
"prompt_id": prompt_id
|
||||
|
@ -29,6 +29,8 @@ class SimplePrompt(BaseModel):
|
||||
status_endpoint: Optional[str]
|
||||
file_upload_endpoint: Optional[str]
|
||||
|
||||
token: Optional[str]
|
||||
|
||||
workflow_api: dict
|
||||
status: Status = Status.NOT_STARTED
|
||||
progress: set = set()
|
||||
|
Loading…
x
Reference in New Issue
Block a user