Merge branch 'benny/auth_token' into public-main

commit 5554c95f44
@@ -59,7 +59,7 @@ print(f"max_retries: {max_retries}, retry_delay_multiplier: {retry_delay_multipl
 
 import time
 
-async def async_request_with_retry(method, url, disable_timeout=False, **kwargs):
+async def async_request_with_retry(method, url, disable_timeout=False, token=None, **kwargs):
     global client_session
     await ensure_client_session()
     retry_delay = 1  # Start with 1 second delay
@@ -72,6 +72,11 @@ async def async_request_with_retry(method, url, disable_timeout=False, **kwargs)
                 timeout = ClientTimeout(total=None, connect=initial_timeout)
                 kwargs['timeout'] = timeout
 
+            if token is not None:
+                if 'headers' not in kwargs:
+                    kwargs['headers'] = {}
+                kwargs['headers']['Authorization'] = f"Bearer {token}"
+
             request_start = time.time()
             async with client_session.request(method, url, **kwargs) as response:
                 request_end = time.time()
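A minimal, self-contained sketch of the pattern this hunk adds: when a token is supplied, it is merged into the request headers as a standard Bearer Authorization header. The function below is a simplified stand-in, not the real helper (the real async_request_with_retry also handles retries, backoff, a shared session and timeouts); the name request_with_bearer, the example URL and the token value are assumptions.

import asyncio
import aiohttp

async def request_with_bearer(method, url, token=None, **kwargs):
    # Simplified stand-in: merge the bearer token into the headers dict,
    # creating it only if the caller did not pass one.
    if token is not None:
        kwargs.setdefault('headers', {})['Authorization'] = f"Bearer {token}"
    async with aiohttp.ClientSession() as session:
        async with session.request(method, url, **kwargs) as response:
            return response.status, await response.text()

async def main():
    # Hypothetical URL and token, for illustration only.
    status, body = await request_with_bearer('GET', 'https://httpbin.org/bearer', token='example-token')
    print(status, body)

asyncio.run(main())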
@@ -361,6 +366,14 @@ def send_prompt(sid: str, inputs: StreamingPrompt):
 
 @server.PromptServer.instance.routes.post("/comfyui-deploy/run")
 async def comfy_deploy_run(request):
+    # Extract the bearer token from the Authorization header
+    auth_header = request.headers.get('Authorization')
+    token = None
+    if auth_header:
+        parts = auth_header.split()
+        if len(parts) == 2 and parts[0].lower() == 'bearer':
+            token = parts[1]
+
     data = await request.json()
 
     # In older version, we use workflow_api, but this has inputs already swapped in nextjs frontend, which is tricky
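The same Authorization-header parsing is added to both this run route and the streaming route further down. A standalone sketch of that parsing, with the inputs it accepts and rejects (the helper name extract_bearer_token is an assumption; the logic mirrors the diff):

def extract_bearer_token(auth_header):
    # Mirrors the parsing in the diff: exactly two whitespace-separated parts,
    # and the scheme must be "bearer" (case-insensitive); otherwise no token.
    if not auth_header:
        return None
    parts = auth_header.split()
    if len(parts) == 2 and parts[0].lower() == 'bearer':
        return parts[1]
    return None

assert extract_bearer_token("Bearer abc123") == "abc123"
assert extract_bearer_token("bearer abc123") == "abc123"   # scheme check is case-insensitive
assert extract_bearer_token("Token abc123") is None        # wrong scheme
assert extract_bearer_token("Bearer") is None              # missing token part
assert extract_bearer_token(None) is None                  # header absent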
@@ -376,13 +389,14 @@ async def comfy_deploy_run(request):
     prompt = {
         "prompt": workflow_api,
         "client_id": "comfy_deploy_instance", #api.client_id
-        "prompt_id": prompt_id
+        "prompt_id": prompt_id,
     }
 
     prompt_metadata[prompt_id] = SimplePrompt(
         status_endpoint=data.get('status_endpoint'),
         file_upload_endpoint=data.get('file_upload_endpoint'),
-        workflow_api=workflow_api
+        workflow_api=workflow_api,
+        token=token
     )
 
     try:
@@ -420,7 +434,7 @@ async def comfy_deploy_run(request):
 
     return web.json_response(res, status=status)
 
-async def stream_prompt(data):
+async def stream_prompt(data, token):
     # In older version, we use workflow_api, but this has inputs already swapped in nextjs frontend, which is tricky
     workflow_api = data.get("workflow_api_raw")
     # The prompt id generated from comfy deploy, can be None
@@ -440,7 +454,8 @@ async def stream_prompt(data):
     prompt_metadata[prompt_id] = SimplePrompt(
         status_endpoint=data.get('status_endpoint'),
         file_upload_endpoint=data.get('file_upload_endpoint'),
-        workflow_api=workflow_api
+        workflow_api=workflow_api,
+        token=token
     )
 
     # log('info', "Begin prompt", prompt=prompt)
@@ -489,6 +504,14 @@ comfy_message_queues: Dict[str, asyncio.Queue] = {}
 async def stream_response(request):
     response = web.StreamResponse(status=200, reason='OK', headers={'Content-Type': 'text/event-stream'})
     await response.prepare(request)
 
+    # Extract the bearer token from the Authorization header
+    auth_header = request.headers.get('Authorization')
+    token = None
+    if auth_header:
+        parts = auth_header.split()
+        if len(parts) == 2 and parts[0].lower() == 'bearer':
+            token = parts[1]
+
     pending = True
     data = await request.json()
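For callers of the streaming route, the token travels in the same header, and the handler writes Server-Sent Events back (event: event_update with a JSON data: payload, as the next hunk shows). A hedged client sketch; the route path, host, port, token and payload field are placeholders, since the route registration for stream_response is not part of this diff.

import asyncio
import json
import aiohttp

STREAM_URL = "http://127.0.0.1:8188/comfyui-deploy/run/streaming"  # placeholder path

async def consume_stream():
    headers = {"Authorization": "Bearer example-token"}   # placeholder token
    payload = {"workflow_api_raw": {}}                     # assumed request field
    async with aiohttp.ClientSession() as session:
        async with session.post(STREAM_URL, json=payload, headers=headers) as resp:
            # The server emits lines like: "event: event_update" / "data: {...}" / blank
            async for raw in resp.content:
                line = raw.decode("utf-8").strip()
                if line.startswith("data: "):
                    print(json.loads(line[len("data: "):]))

asyncio.run(consume_stream())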
@@ -500,7 +523,7 @@ async def stream_response(request):
         log('info', 'Streaming prompt')
 
         try:
-            result = await stream_prompt(data=data)
+            result = await stream_prompt(data=data, token=token)
             await response.write(f"event: event_update\ndata: {json.dumps(result)}\n\n".encode('utf-8'))
             # await response.write(.encode('utf-8'))
             await response.drain()  # Ensure the buffer is flushed
@@ -1045,6 +1068,7 @@ async def update_run_live_status(prompt_id, live_status, calculated_progress: fl
         return
 
     status_endpoint = prompt_metadata[prompt_id].status_endpoint
+    token = prompt_metadata[prompt_id].token
 
     if (status_endpoint is None):
         return
@@ -1068,7 +1092,7 @@ async def update_run_live_status(prompt_id, live_status, calculated_progress: fl
         })
 
     # requests.post(status_endpoint, json=body)
-    await async_request_with_retry('POST', status_endpoint, json=body)
+    await async_request_with_retry('POST', status_endpoint, token=token, json=body)
 
 
 async def update_run(prompt_id: str, status: Status):
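From here on, the hunks repeat one pattern: read the token saved in prompt_metadata for the prompt and forward it to the status or upload callback. A condensed, hypothetical sketch of that pattern, where prompt_metadata entries are simplified to SimpleNamespace objects and the retry helper is stubbed out to just print what it would send:

import asyncio
from types import SimpleNamespace

# Simplified stand-ins, for illustration only.
prompt_metadata = {
    "run-1": SimpleNamespace(status_endpoint="https://example.com/status", token="example-token"),
}

async def async_request_with_retry(method, url, token=None, **kwargs):
    # Stub: the real helper performs the request with retries; this just shows
    # that the stored token ends up on every callback it makes.
    auth = f"Authorization: Bearer {token}" if token else "(no auth)"
    print(method, url, auth, kwargs)

async def post_status(prompt_id, body):
    meta = prompt_metadata.get(prompt_id)
    if meta is None or meta.status_endpoint is None:
        return
    # The token captured when the run was created is reused for the callback.
    await async_request_with_retry("POST", meta.status_endpoint, token=meta.token, json=body)

asyncio.run(post_status("run-1", {"status": "running"}))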
@@ -1099,7 +1123,8 @@ async def update_run(prompt_id: str, status: Status):
         try:
             # requests.post(status_endpoint, json=body)
             if (status_endpoint is not None):
-                await async_request_with_retry('POST', status_endpoint, json=body)
+                token = prompt_metadata[prompt_id].token
+                await async_request_with_retry('POST', status_endpoint, token=token, json=body)
 
             if (status_endpoint is not None) and cd_enable_run_log and (status == Status.SUCCESS or status == Status.FAILED):
                 try:
@@ -1129,7 +1154,7 @@ async def update_run(prompt_id: str, status: Status):
                             ]
                         }
 
-                        await async_request_with_retry('POST', status_endpoint, json=body)
+                        await async_request_with_retry('POST', status_endpoint, token=token, json=body)
                         # requests.post(status_endpoint, json=body)
                 except Exception as log_error:
                     logger.info(f"Error reading log file: {log_error}")
@@ -1180,7 +1205,7 @@ async def upload_file(prompt_id, filename, subfolder=None, content_type="image/p
     logger.info(f"Uploading file {file}")
 
     file_upload_endpoint = prompt_metadata[prompt_id].file_upload_endpoint
-
+    token = prompt_metadata[prompt_id].token
     filename = quote(filename)
     prompt_id = quote(prompt_id)
     content_type = quote(content_type)
@@ -1192,7 +1217,7 @@ async def upload_file(prompt_id, filename, subfolder=None, content_type="image/p
 
         start_time = time.time()  # Start timing here
         logger.info(f"Target URL: {target_url}")
-        result = await async_request_with_retry("GET", target_url, disable_timeout=True)
+        result = await async_request_with_retry("GET", target_url, disable_timeout=True, token=token)
         end_time = time.time()  # End timing after the request is complete
         logger.info("Time taken for getting file upload endpoint: {:.2f} seconds".format(end_time - start_time))
         ok = await result.json()
@@ -1219,6 +1244,8 @@ async def upload_file(prompt_id, filename, subfolder=None, content_type="image/p
             if file_download_url is not None:
                 item["url"] = file_download_url
             item["upload_duration"] = end_time - start_time
+            if ok.get("is_public") is not None:
+                item["is_public"] = ok.get("is_public")
 
 def have_pending_upload(prompt_id):
     if prompt_id in prompt_metadata and len(prompt_metadata[prompt_id].uploading_nodes) > 0:
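Besides the token plumbing, this hunk also copies an optional is_public flag from the upload-endpoint response onto the output item. A small illustration of the resulting item shape; every field other than url, upload_duration and is_public is an assumption:

# Hypothetical upload-endpoint response and output item, for illustration only.
ok = {
    "download_url": "https://example.com/file.png",
    "is_public": True,
}

item = {"filename": "output.png"}          # assumed base fields
file_download_url = ok.get("download_url")
if file_download_url is not None:
    item["url"] = file_download_url
item["upload_duration"] = 1.23             # measured around the upload elsewhere
if ok.get("is_public") is not None:        # only copied when the endpoint reports it
    item["is_public"] = ok.get("is_public")

print(item)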
@@ -1352,6 +1379,7 @@ async def upload_in_background(prompt_id: str, data, node_id=None, have_upload=T
         await asyncio.gather(*upload_tasks)
 
         status_endpoint = prompt_metadata[prompt_id].status_endpoint
+        token = prompt_metadata[prompt_id].token
         if have_upload:
             if status_endpoint is not None:
                 body = {
@@ -1360,7 +1388,7 @@ async def upload_in_background(prompt_id: str, data, node_id=None, have_upload=T
                     "node_meta": node_meta,
                 }
                 # pprint(body)
-                await async_request_with_retry('POST', status_endpoint, json=body)
+                await async_request_with_retry('POST', status_endpoint, token=token, json=body)
             await update_file_status(prompt_id, data, False, node_id=node_id)
     except Exception as e:
         await handle_error(prompt_id, data, e)
@@ -1398,7 +1426,8 @@ async def update_run_with_output(prompt_id, data, node_id=None, node_meta=None):
             await handle_error(prompt_id, data, e)
     # requests.post(status_endpoint, json=body)
     elif status_endpoint is not None:
-        await async_request_with_retry('POST', status_endpoint, json=body)
+        token = prompt_metadata[prompt_id].token
+        await async_request_with_retry('POST', status_endpoint, token=token, json=body)
 
     await send('outputs_uploaded', {
         "prompt_id": prompt_id
@@ -29,6 +29,8 @@ class SimplePrompt(BaseModel):
     status_endpoint: Optional[str]
     file_upload_endpoint: Optional[str]
 
+    token: Optional[str]
+
     workflow_api: dict
     status: Status = Status.NOT_STARTED
     progress: set = set()
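This final hunk lands in the SimplePrompt model rather than the route module: the captured token is stored per prompt so later callbacks can reuse it. A trimmed-down, hedged sketch of the model after the change; only the fields visible in the diff are included, and the Status enum values are placeholders:

from enum import Enum
from typing import Optional
from pydantic import BaseModel

class Status(Enum):
    # Placeholder values; only the member names appear in the diff.
    NOT_STARTED = "not_started"
    SUCCESS = "success"
    FAILED = "failed"

class SimplePrompt(BaseModel):
    status_endpoint: Optional[str]
    file_upload_endpoint: Optional[str]
    token: Optional[str]                  # new: bearer token captured from the incoming request
    workflow_api: dict
    status: Status = Status.NOT_STARTED
    progress: set = set()

p = SimplePrompt(
    status_endpoint=None,
    file_upload_endpoint=None,
    token="example-token",
    workflow_api={},
)
print(p.token)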