From 8e12803ea1d49375c358b6eeafc5cc16a6416197 Mon Sep 17 00:00:00 2001 From: BennyKok Date: Thu, 1 Aug 2024 20:43:21 -0700 Subject: [PATCH] Retry logic when calling api (#57) * fix: retry logic, bypass logfire, clean up log * fix: max_retries and retry_delay_multiplier, do not throw after the final retry fails --- custom_routes.py | 249 ++++++++++++++++++++++++++++++----------------- requirements.txt | 2 +- 2 files changed, 160 insertions(+), 91 deletions(-) diff --git a/custom_routes.py b/custom_routes.py index 915a5fc..99207f8 100644 --- a/custom_routes.py +++ b/custom_routes.py @@ -22,17 +22,97 @@ from typing import Dict, List, Union, Any, Optional from PIL import Image import copy import struct +from aiohttp import ClientError +import atexit + +# Global session +client_session = None + +# def create_client_session(): +# global client_session +# if client_session is None: +# client_session = aiohttp.ClientSession() + +async def ensure_client_session(): + global client_session + if client_session is None: + client_session = aiohttp.ClientSession() + +async def cleanup(): + global client_session + if client_session: + await client_session.close() + +def exit_handler(): + print("Exiting the application. 
Initiating cleanup...") + loop = asyncio.get_event_loop() + loop.run_until_complete(cleanup()) + +atexit.register(exit_handler) + +max_retries = int(os.environ.get('MAX_RETRIES', '3')) +retry_delay_multiplier = float(os.environ.get('RETRY_DELAY_MULTIPLIER', '2')) + +print(f"max_retries: {max_retries}, retry_delay_multiplier: {retry_delay_multiplier}") + +async def async_request_with_retry(method, url, **kwargs): + global client_session + await ensure_client_session() + retry_delay = 1 # Start with 1 second delay + + for attempt in range(max_retries): + try: + async with client_session.request(method, url, **kwargs) as response: + response.raise_for_status() + return response + except ClientError as e: + if attempt == max_retries - 1: + logger.error(f"Request failed after {max_retries} attempts: {e}") + # raise + logger.warning(f"Request failed (attempt {attempt + 1}/{max_retries}): {e}") + await asyncio.sleep(retry_delay) + retry_delay *= retry_delay_multiplier # Exponential backoff from logging import basicConfig, getLogger -import logfire -# if os.environ.get('LOGFIRE_TOKEN', None) is not None: -logfire.configure( - send_to_logfire="if-token-present" -) -# basicConfig(handlers=[logfire.LogfireLoggingHandler()]) -logfire_handler = logfire.LogfireLoggingHandler() -logger = getLogger("comfy-deploy") -logger.addHandler(logfire_handler) + +# Check for an environment variable to enable/disable Logfire +use_logfire = os.environ.get('USE_LOGFIRE', 'false').lower() == 'true' + +if use_logfire: + try: + import logfire + logfire.configure( + send_to_logfire="if-token-present" + ) + logger = logfire + except ImportError: + print("Logfire not installed or disabled. 
Using standard Python logger.") + use_logfire = False + +if not use_logfire: + # Use a standard Python logger when Logfire is disabled or not available + logger = getLogger("comfy-deploy") + basicConfig(level="INFO") # You can adjust the logging level as needed + +def log(level, message, **kwargs): + if use_logfire: + getattr(logger, level)(message, **kwargs) + else: + getattr(logger, level)(f"{message} {kwargs}") + +# For a span, you might need to create a context manager +from contextlib import contextmanager + +@contextmanager +def log_span(name): + if use_logfire: + with logger.span(name): + yield + # else: + # logger.info(f"Start: {name}") + # yield + # logger.info(f"End: {name}") + from globals import StreamingPrompt, Status, sockets, SimplePrompt, streaming_prompt_metadata, prompt_metadata @@ -306,7 +386,7 @@ async def stream_prompt(data): workflow_api=workflow_api ) - logfire.info("Begin prompt", prompt=prompt) + log('info', "Begin prompt", prompt=prompt) try: res = post_prompt(prompt) @@ -359,8 +439,8 @@ async def stream_response(request): prompt_id = data.get("prompt_id") comfy_message_queues[prompt_id] = asyncio.Queue() - with logfire.span('Streaming Run'): - logfire.info('Streaming prompt') + with log_span('Streaming Run'): + log('info', 'Streaming prompt') try: result = await stream_prompt(data=data) @@ -373,7 +453,7 @@ async def stream_response(request): if not comfy_message_queues[prompt_id].empty(): data = await comfy_message_queues[prompt_id].get() - logfire.info(data["event"], data=json.dumps(data)) + log('info', data["event"], data=json.dumps(data)) # logger.info("listener", data) await response.write(f"event: event_update\ndata: {json.dumps(data)}\n\n".encode('utf-8')) await response.drain() # Ensure the buffer is flushed @@ -384,10 +464,10 @@ async def stream_response(request): await asyncio.sleep(0.1) # Adjust the sleep duration as needed except asyncio.CancelledError: - logfire.info("Streaming was cancelled") + log('info', "Streaming was 
cancelled") raise except Exception as e: - logfire.error("Streaming error", error=e) + log('error', "Streaming error", error=e) finally: # event_emitter.off("send_json", task) await response.write_eof() @@ -482,34 +562,33 @@ async def upload_file_endpoint(request): if get_url: try: - async with aiohttp.ClientSession() as session: - headers = {'Authorization': f'Bearer {token}'} - params = {'file_size': file_size, 'type': file_type} - async with session.get(get_url, params=params, headers=headers) as response: - if response.status == 200: - content = await response.json() - upload_url = content["upload_url"] + headers = {'Authorization': f'Bearer {token}'} + params = {'file_size': file_size, 'type': file_type} + response = await async_request_with_retry('GET', get_url, params=params, headers=headers) + if response.status == 200: + content = await response.json() + upload_url = content["upload_url"] - with open(file_path, 'rb') as f: - headers = { - "Content-Type": file_type, - # "x-amz-acl": "public-read", - "Content-Length": str(file_size) - } - async with session.put(upload_url, data=f, headers=headers) as upload_response: - if upload_response.status == 200: - return web.json_response({ - "message": "File uploaded successfully", - "download_url": content["download_url"] - }) - else: - return web.json_response({ - "error": f"Failed to upload file to {upload_url}. Status code: {upload_response.status}" - }, status=upload_response.status) + with open(file_path, 'rb') as f: + headers = { + "Content-Type": file_type, + # "x-amz-acl": "public-read", + "Content-Length": str(file_size) + } + upload_response = await async_request_with_retry('PUT', upload_url, data=f, headers=headers) + if upload_response.status == 200: + return web.json_response({ + "message": "File uploaded successfully", + "download_url": content["download_url"] + }) else: return web.json_response({ - "error": f"Failed to fetch data from {get_url}. 
Status code: {response.status}" - }, status=response.status) + "error": f"Failed to upload file to {upload_url}. Status code: {upload_response.status}" + }, status=upload_response.status) + else: + return web.json_response({ + "error": f"Failed to fetch data from {get_url}. Status code: {response.status}" + }, status=response.status) except Exception as e: return web.json_response({ "error": f"An error occurred while fetching data from {get_url}: {str(e)}" @@ -588,9 +667,7 @@ async def update_realtime_run_status(realtime_id: str, status_endpoint: str, sta if (status_endpoint is None): return # requests.post(status_endpoint, json=body) - async with aiohttp.ClientSession() as session: - async with session.post(status_endpoint, json=body) as response: - pass + await async_request_with_retry('POST', status_endpoint, json=body) @server.PromptServer.instance.routes.get('/comfyui-deploy/ws') async def websocket_handler(request): @@ -611,28 +688,27 @@ async def websocket_handler(request): status_endpoint = request.rel_url.query.get('status_endpoint', None) if auth_token is not None and get_workflow_endpoint_url is not None: - async with aiohttp.ClientSession() as session: - headers = {'Authorization': f'Bearer {auth_token}'} - async with session.get(get_workflow_endpoint_url, headers=headers) as response: - if response.status == 200: - workflow = await response.json() + headers = {'Authorization': f'Bearer {auth_token}'} + response = await async_request_with_retry('GET', get_workflow_endpoint_url, headers=headers) + if response.status == 200: + workflow = await response.json() - logger.info(f"Loaded workflow version ${workflow['version']}") + logger.info(f"Loaded workflow version ${workflow['version']}") - streaming_prompt_metadata[sid] = StreamingPrompt( - workflow_api=workflow["workflow_api"], - auth_token=auth_token, - inputs={}, - status_endpoint=status_endpoint, - file_upload_endpoint=request.rel_url.query.get('file_upload_endpoint', None), - ) + 
streaming_prompt_metadata[sid] = StreamingPrompt( + workflow_api=workflow["workflow_api"], + auth_token=auth_token, + inputs={}, + status_endpoint=status_endpoint, + file_upload_endpoint=request.rel_url.query.get('file_upload_endpoint', None), + ) - await update_realtime_run_status(realtime_id, status_endpoint, Status.RUNNING) - # await send("workflow_api", workflow_api, sid) - else: - error_message = await response.text() - logger.info(f"Failed to fetch workflow endpoint. Status: {response.status}, Error: {error_message}") - # await send("error", {"message": error_message}, sid) + await update_realtime_run_status(realtime_id, status_endpoint, Status.RUNNING) + # await send("workflow_api", workflow_api, sid) + else: + error_message = await response.text() + logger.info(f"Failed to fetch workflow endpoint. Status: {response.status}, Error: {error_message}") + # await send("error", {"message": error_message}, sid) try: # Send initial state to the new client @@ -805,13 +881,14 @@ async def send_json_override(self, event, data, sid=None): prompt_metadata[prompt_id].progress.add(node) calculated_progress = len(prompt_metadata[prompt_id].progress) / len(prompt_metadata[prompt_id].workflow_api) + calculated_progress = round(calculated_progress, 2) # logger.info("calculated_progress", calculated_progress) if prompt_metadata[prompt_id].last_updated_node is not None and prompt_metadata[prompt_id].last_updated_node == node: return prompt_metadata[prompt_id].last_updated_node = node class_type = prompt_metadata[prompt_id].workflow_api[node]['class_type'] - logger.info(f"updating run live status {class_type}") + logger.info(f"At: {calculated_progress * 100}% - {class_type}") await send("live_status", { "prompt_id": prompt_id, "current_node": class_type, @@ -836,15 +913,16 @@ async def send_json_override(self, event, data, sid=None): # await update_run_with_output(prompt_id, data) if event == 'executed' and 'node' in data and 'output' in data: - logger.info(f"executed {data}") 
if prompt_id in prompt_metadata: node = data.get('node') class_type = prompt_metadata[prompt_id].workflow_api[node]['class_type'] - logger.info(f"executed {class_type}") + logger.info(f"Executed {class_type} {data}") if class_type == "PreviewImage": - logger.info("skipping preview image") + logger.info("Skipping preview image") return - + else: + logger.info(f"Executed {data}") + await update_run_with_output(prompt_id, data.get('output'), node_id=data.get('node')) # await update_run_with_output(prompt_id, data.get('output'), node_id=data.get('node')) # update_run_with_output(prompt_id, data.get('output')) @@ -864,7 +942,7 @@ async def update_run_live_status(prompt_id, live_status, calculated_progress: fl if (status_endpoint is None): return - logger.info(f"progress {calculated_progress}") + # logger.info(f"progress {calculated_progress}") body = { "run_id": prompt_id, @@ -883,9 +961,7 @@ async def update_run_live_status(prompt_id, live_status, calculated_progress: fl }) # requests.post(status_endpoint, json=body) - async with aiohttp.ClientSession() as session: - async with session.post(status_endpoint, json=body) as response: - pass + await async_request_with_retry('POST', status_endpoint, json=body) async def update_run(prompt_id: str, status: Status): @@ -916,9 +992,7 @@ async def update_run(prompt_id: str, status: Status): try: # requests.post(status_endpoint, json=body) if (status_endpoint is not None): - async with aiohttp.ClientSession() as session: - async with session.post(status_endpoint, json=body) as response: - pass + await async_request_with_retry('POST', status_endpoint, json=body) if (status_endpoint is not None) and cd_enable_run_log and (status == Status.SUCCESS or status == Status.FAILED): try: @@ -948,9 +1022,7 @@ async def update_run(prompt_id: str, status: Status): ] } - async with aiohttp.ClientSession() as session: - async with session.post(status_endpoint, json=body) as response: - pass + await async_request_with_retry('POST', 
status_endpoint, json=body) # requests.post(status_endpoint, json=body) except Exception as log_error: logger.info(f"Error reading log file: {log_error}") @@ -998,7 +1070,7 @@ async def upload_file(prompt_id, filename, subfolder=None, content_type="image/p filename = os.path.basename(filename) file = os.path.join(output_dir, filename) - logger.info(f"uploading file {file}") + logger.info(f"Uploading file {file}") file_upload_endpoint = prompt_metadata[prompt_id].file_upload_endpoint @@ -1024,18 +1096,17 @@ async def upload_file(prompt_id, filename, subfolder=None, content_type="image/p "Content-Length": str(len(data)), } # response = requests.put(ok.get("url"), headers=headers, data=data) - async with aiohttp.ClientSession() as session: - async with session.put(ok.get("url"), headers=headers, data=data) as response: - logger.info(f"Upload file response status: {response.status}, status text: {response.reason}") - end_time = time.time() # End timing after the request is complete - logger.info("Upload time: {:.2f} seconds".format(end_time - start_time)) + response = await async_request_with_retry('PUT', ok.get("url"), headers=headers, data=data) + logger.info(f"Upload file response status: {response.status}, status text: {response.reason}") + end_time = time.time() # End timing after the request is complete + logger.info("Upload time: {:.2f} seconds".format(end_time - start_time)) def have_pending_upload(prompt_id): if prompt_id in prompt_metadata and len(prompt_metadata[prompt_id].uploading_nodes) > 0: - logger.info(f"have pending upload {len(prompt_metadata[prompt_id].uploading_nodes)}") + logger.info(f"Have pending upload {len(prompt_metadata[prompt_id].uploading_nodes)}") return True - logger.info("no pending upload") + logger.info("No pending upload") return False def mark_prompt_done(prompt_id): @@ -1093,7 +1164,7 @@ async def update_file_status(prompt_id: str, data, uploading, have_error=False, else: prompt_metadata[prompt_id].uploading_nodes.discard(node_id) 
- logger.info(prompt_metadata[prompt_id].uploading_nodes) + logger.info(f"Remaining uploads: {prompt_metadata[prompt_id].uploading_nodes}") # Update the remote status if have_error: @@ -1177,7 +1248,7 @@ async def update_run_with_output(prompt_id, data, node_id=None): if have_upload_media: try: - logger.info(f"\nhave_upload {have_upload_media} {node_id}") + logger.info(f"\nHave_upload {have_upload_media} Node Id: {node_id}") if have_upload_media: await update_file_status(prompt_id, data, True, node_id=node_id) @@ -1190,9 +1261,7 @@ async def update_run_with_output(prompt_id, data, node_id=None): # requests.post(status_endpoint, json=body) if status_endpoint is not None: - async with aiohttp.ClientSession() as session: - async with session.post(status_endpoint, json=body) as response: - pass + await async_request_with_retry('POST', status_endpoint, json=body) await send('outputs_uploaded', { "prompt_id": prompt_id diff --git a/requirements.txt b/requirements.txt index cedfa7c..b5a5491 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,4 +2,4 @@ aiofiles pydantic opencv-python imageio-ffmpeg -logfire \ No newline at end of file +# logfire \ No newline at end of file