From 08d631d1eb83b2dc9cdc483cdec0975c6cfca9b6 Mon Sep 17 00:00:00 2001 From: bennykok Date: Sun, 18 Aug 2024 19:09:30 -0700 Subject: [PATCH] feat: async file upload for the same node --- custom_routes.py | 44 ++++++++++++++++++++++++++++---------------- 1 file changed, 28 insertions(+), 16 deletions(-) diff --git a/custom_routes.py b/custom_routes.py index d726774..42f5a25 100644 --- a/custom_routes.py +++ b/custom_routes.py @@ -56,20 +56,24 @@ retry_delay_multiplier = float(os.environ.get('RETRY_DELAY_MULTIPLIER', '2')) print(f"max_retries: {max_retries}, retry_delay_multiplier: {retry_delay_multiplier}") -async def async_request_with_retry(method, url, **kwargs): +async def async_request_with_retry(method, url, disable_timeout=False, **kwargs): global client_session await ensure_client_session() + # async with aiohttp.ClientSession() as client_session: retry_delay = 1 # Start with 1 second delay initial_timeout = 5 # 5 seconds timeout for the initial connection for attempt in range(max_retries): try: # Set a timeout for the initial connection - timeout = ClientTimeout(total=None, connect=initial_timeout) - kwargs['timeout'] = timeout + if not disable_timeout: + timeout = ClientTimeout(total=None, connect=initial_timeout) + kwargs['timeout'] = timeout async with client_session.request(method, url, **kwargs) as response: response.raise_for_status() + if method.upper() == 'GET': + await response.read() return response except asyncio.TimeoutError: logger.warning(f"Request timed out after {initial_timeout} seconds (attempt {attempt + 1}/{max_retries})") @@ -1150,18 +1154,18 @@ async def upload_file(prompt_id, filename, subfolder=None, content_type="image/p prompt_id = quote(prompt_id) content_type = quote(content_type) - target_url = f"{file_upload_endpoint}?file_name={filename}&run_id={prompt_id}&type={content_type}" + target_url = f"{file_upload_endpoint}?file_name={filename}&run_id={prompt_id}&type={content_type}&version=v2" start_time = time.time() # Start timing here - result = requests.get(target_url) + result = await async_request_with_retry("GET", target_url, disable_timeout=True) end_time = time.time() # End timing after the request is complete logger.info("Time taken for getting file upload endpoint: {:.2f} seconds".format(end_time - start_time)) - ok = result.json() + ok = await result.json() start_time = time.time() # Start timing here - with open(file, 'rb') as f: - data = f.read() + async with aiofiles.open(file, 'rb') as f: + data = await f.read() headers = { # "x-amz-acl": "public-read", "Content-Type": content_type, @@ -1264,8 +1268,10 @@ async def update_file_status(prompt_id: str, data, uploading, have_error=False, async def handle_upload(prompt_id: str, data, key: str, content_type_key: str, default_content_type: str): items = data.get(key, []) + upload_tasks = [] + for item in items: - # # Skipping temp files + # Skipping temp files if item.get("type") == "temp": continue @@ -1278,22 +1284,28 @@ async def handle_upload(prompt_id: str, data, key: str, content_type_key: str, d elif file_extension == '.webp': file_type = 'image/webp' - await upload_file( + upload_tasks.append(upload_file( prompt_id, item.get("filename"), subfolder=item.get("subfolder"), type=item.get("type"), content_type=file_type - ) + )) + + # Execute all upload tasks concurrently + await asyncio.gather(*upload_tasks) # Upload files in the background async def upload_in_background(prompt_id: str, data, node_id=None, have_upload=True): try: - await handle_upload(prompt_id, data, 'images', "content_type", "image/png") - await handle_upload(prompt_id, data, 'files', "content_type", "image/png") - # This will also be mp4 - await handle_upload(prompt_id, data, 'gifs', "format", "image/gif") - await handle_upload(prompt_id, data, 'mesh', "format", "application/octet-stream") + upload_tasks = [ + handle_upload(prompt_id, data, 'images', "content_type", "image/png"), + handle_upload(prompt_id, data, 'files', "content_type", "image/png"), + handle_upload(prompt_id, data, 'gifs', "format", "image/gif"), + handle_upload(prompt_id, data, 'mesh', "format", "application/octet-stream") + ] + + await asyncio.gather(*upload_tasks) if have_upload: await update_file_status(prompt_id, data, False, node_id=node_id)