Compare commits
	
		
			10 Commits
		
	
	
		
			main
			...
			benny/uplo
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 
						 | 
					2cfdd0b3f5 | ||
| 
						 | 
					9c06cc666c | ||
| 
						 | 
					710917d507 | ||
| 
						 | 
					16f4312c9e | ||
| 
						 | 
					20801f4d3f | ||
| 
						 | 
					ba2f942b29 | ||
| 
						 | 
					0dfa83f486 | ||
| 
						 | 
					5954092f25 | ||
| 
						 | 
					a2990ca833 | ||
| 
						 | 
					d673c4a00b | 
							
								
								
									
										100
									
								
								custom_routes.py
									
									
									
									
									
								
							
							
						
						
									
										100
									
								
								custom_routes.py
									
									
									
									
									
								
							@ -1,4 +1,5 @@
 | 
				
			|||||||
from io import BytesIO
 | 
					from io import BytesIO
 | 
				
			||||||
 | 
					from pprint import pprint
 | 
				
			||||||
from aiohttp import web
 | 
					from aiohttp import web
 | 
				
			||||||
import os
 | 
					import os
 | 
				
			||||||
import requests
 | 
					import requests
 | 
				
			||||||
@ -56,39 +57,60 @@ retry_delay_multiplier = float(os.environ.get('RETRY_DELAY_MULTIPLIER', '2'))
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
print(f"max_retries: {max_retries}, retry_delay_multiplier: {retry_delay_multiplier}")
 | 
					print(f"max_retries: {max_retries}, retry_delay_multiplier: {retry_delay_multiplier}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import time
 | 
				
			||||||
 | 
					
 | 
				
			||||||
async def async_request_with_retry(method, url, disable_timeout=False, **kwargs):
 | 
					async def async_request_with_retry(method, url, disable_timeout=False, **kwargs):
 | 
				
			||||||
    global client_session
 | 
					    global client_session
 | 
				
			||||||
    await ensure_client_session()
 | 
					    await ensure_client_session()
 | 
				
			||||||
    # async with aiohttp.ClientSession() as client_session:
 | 
					 | 
				
			||||||
    retry_delay = 1  # Start with 1 second delay
 | 
					    retry_delay = 1  # Start with 1 second delay
 | 
				
			||||||
    initial_timeout = 5  # 5 seconds timeout for the initial connection
 | 
					    initial_timeout = 5  # 5 seconds timeout for the initial connection
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    start_time = time.time()
 | 
				
			||||||
    for attempt in range(max_retries):
 | 
					    for attempt in range(max_retries):
 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
            # Set a timeout for the initial connection
 | 
					 | 
				
			||||||
            if not disable_timeout:
 | 
					            if not disable_timeout:
 | 
				
			||||||
                timeout = ClientTimeout(total=None, connect=initial_timeout)
 | 
					                timeout = ClientTimeout(total=None, connect=initial_timeout)
 | 
				
			||||||
                kwargs['timeout'] = timeout
 | 
					                kwargs['timeout'] = timeout
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            request_start = time.time()
 | 
				
			||||||
            async with client_session.request(method, url, **kwargs) as response:
 | 
					            async with client_session.request(method, url, **kwargs) as response:
 | 
				
			||||||
 | 
					                request_end = time.time()
 | 
				
			||||||
 | 
					                logger.info(f"Request attempt {attempt + 1} took {request_end - request_start:.2f} seconds")
 | 
				
			||||||
 | 
					                
 | 
				
			||||||
 | 
					                if response.status != 200:
 | 
				
			||||||
 | 
					                    error_body = await response.text()
 | 
				
			||||||
 | 
					                    logger.error(f"Request failed with status {response.status} and body {error_body}")
 | 
				
			||||||
 | 
					                    # raise Exception(f"Request failed with status {response.status}")
 | 
				
			||||||
 | 
					                
 | 
				
			||||||
                response.raise_for_status()
 | 
					                response.raise_for_status()
 | 
				
			||||||
                if method.upper() == 'GET':
 | 
					                if method.upper() == 'GET':
 | 
				
			||||||
                    await response.read()
 | 
					                    await response.read()
 | 
				
			||||||
 | 
					                
 | 
				
			||||||
 | 
					                total_time = time.time() - start_time
 | 
				
			||||||
 | 
					                logger.info(f"Request succeeded after {total_time:.2f} seconds (attempt {attempt + 1}/{max_retries})")
 | 
				
			||||||
                return response
 | 
					                return response
 | 
				
			||||||
        except asyncio.TimeoutError:
 | 
					        except asyncio.TimeoutError:
 | 
				
			||||||
            logger.warning(f"Request timed out after {initial_timeout} seconds (attempt {attempt + 1}/{max_retries})")
 | 
					            logger.warning(f"Request timed out after {initial_timeout} seconds (attempt {attempt + 1}/{max_retries})")
 | 
				
			||||||
        except ClientError as e:
 | 
					        except ClientError as e:
 | 
				
			||||||
 | 
					            end_time = time.time()
 | 
				
			||||||
 | 
					            logger.error(f"Request failed (attempt {attempt + 1}/{max_retries}): {e}")
 | 
				
			||||||
 | 
					            logger.error(f"Time taken for failed attempt: {end_time - request_start:.2f} seconds")
 | 
				
			||||||
 | 
					            logger.error(f"Total time elapsed: {end_time - start_time:.2f} seconds")
 | 
				
			||||||
 | 
					            
 | 
				
			||||||
 | 
					            # Log the response body for ClientError as well
 | 
				
			||||||
 | 
					            if hasattr(e, 'response') and e.response is not None:
 | 
				
			||||||
 | 
					                error_body = await e.response.text()
 | 
				
			||||||
 | 
					                logger.error(f"Error response body: {error_body}")
 | 
				
			||||||
 | 
					            
 | 
				
			||||||
            if attempt == max_retries - 1:
 | 
					            if attempt == max_retries - 1:
 | 
				
			||||||
                logger.error(f"Request failed after {max_retries} attempts: {e}")
 | 
					                logger.error(f"Request failed after {max_retries} attempts: {e}")
 | 
				
			||||||
                # raise
 | 
					                raise
 | 
				
			||||||
            logger.warning(f"Request failed (attempt {attempt + 1}/{max_retries}): {e}")
 | 
					 | 
				
			||||||
        
 | 
					        
 | 
				
			||||||
        # Wait before retrying
 | 
					 | 
				
			||||||
        await asyncio.sleep(retry_delay)
 | 
					        await asyncio.sleep(retry_delay)
 | 
				
			||||||
        retry_delay *= retry_delay_multiplier  # Exponential backoff
 | 
					        retry_delay *= retry_delay_multiplier
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # If all retries fail, raise an exception
 | 
					    total_time = time.time() - start_time
 | 
				
			||||||
    raise Exception(f"Request failed after {max_retries} attempts")
 | 
					    raise Exception(f"Request failed after {max_retries} attempts and {total_time:.2f} seconds")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from logging import basicConfig, getLogger
 | 
					from logging import basicConfig, getLogger
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -607,9 +629,10 @@ async def upload_file_endpoint(request):
 | 
				
			|||||||
                with open(file_path, 'rb') as f:
 | 
					                with open(file_path, 'rb') as f:
 | 
				
			||||||
                    headers = {
 | 
					                    headers = {
 | 
				
			||||||
                        "Content-Type": file_type,
 | 
					                        "Content-Type": file_type,
 | 
				
			||||||
                        # "x-amz-acl": "public-read",
 | 
					                        # "Content-Length": str(file_size)
 | 
				
			||||||
                        "Content-Length": str(file_size)
 | 
					 | 
				
			||||||
                    }
 | 
					                    }
 | 
				
			||||||
 | 
					                    if content.get('include_acl') is True:
 | 
				
			||||||
 | 
					                        headers["x-amz-acl"] = "public-read"
 | 
				
			||||||
                    upload_response = await async_request_with_retry('PUT', upload_url, data=f, headers=headers)
 | 
					                    upload_response = await async_request_with_retry('PUT', upload_url, data=f, headers=headers)
 | 
				
			||||||
                    if upload_response.status == 200:
 | 
					                    if upload_response.status == 200:
 | 
				
			||||||
                        return web.json_response({
 | 
					                        return web.json_response({
 | 
				
			||||||
@ -1127,7 +1150,7 @@ async def update_run(prompt_id: str, status: Status):
 | 
				
			|||||||
                })
 | 
					                })
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
async def upload_file(prompt_id, filename, subfolder=None, content_type="image/png", type="output"):
 | 
					async def upload_file(prompt_id, filename, subfolder=None, content_type="image/png", type="output", item=None):
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    Uploads file to S3 bucket using S3 client object
 | 
					    Uploads file to S3 bucket using S3 client object
 | 
				
			||||||
    :return: None
 | 
					    :return: None
 | 
				
			||||||
@ -1162,29 +1185,41 @@ async def upload_file(prompt_id, filename, subfolder=None, content_type="image/p
 | 
				
			|||||||
    prompt_id = quote(prompt_id)
 | 
					    prompt_id = quote(prompt_id)
 | 
				
			||||||
    content_type = quote(content_type)
 | 
					    content_type = quote(content_type)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    target_url = f"{file_upload_endpoint}?file_name={filename}&run_id={prompt_id}&type={content_type}&version=v2"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    start_time = time.time()  # Start timing here
 | 
					 | 
				
			||||||
    result = await async_request_with_retry("GET", target_url, disable_timeout=True)
 | 
					 | 
				
			||||||
    end_time = time.time()  # End timing after the request is complete
 | 
					 | 
				
			||||||
    logger.info("Time taken for getting file upload endpoint: {:.2f} seconds".format(end_time - start_time))
 | 
					 | 
				
			||||||
    ok = await result.json()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    start_time = time.time()  # Start timing here
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    async with aiofiles.open(file, 'rb') as f:
 | 
					    async with aiofiles.open(file, 'rb') as f:
 | 
				
			||||||
        data = await f.read()
 | 
					        data = await f.read()
 | 
				
			||||||
 | 
					        size = str(len(data))
 | 
				
			||||||
 | 
					        target_url = f"{file_upload_endpoint}?file_name={filename}&run_id={prompt_id}&type={content_type}&version=v2"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        start_time = time.time()  # Start timing here
 | 
				
			||||||
 | 
					        logger.info(f"Target URL: {target_url}")
 | 
				
			||||||
 | 
					        result = await async_request_with_retry("GET", target_url, disable_timeout=True)
 | 
				
			||||||
 | 
					        end_time = time.time()  # End timing after the request is complete
 | 
				
			||||||
 | 
					        logger.info("Time taken for getting file upload endpoint: {:.2f} seconds".format(end_time - start_time))
 | 
				
			||||||
 | 
					        ok = await result.json()
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        logger.info(f"Result: {ok}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        start_time = time.time()  # Start timing here
 | 
				
			||||||
        headers = {
 | 
					        headers = {
 | 
				
			||||||
            # "x-amz-acl": "public-read",
 | 
					 | 
				
			||||||
            "Content-Type": content_type,
 | 
					            "Content-Type": content_type,
 | 
				
			||||||
            "Content-Length": str(len(data)),
 | 
					            # "Content-Length": size,
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        if ok.get('include_acl') is True:
 | 
				
			||||||
 | 
					            headers["x-amz-acl"] = "public-read"
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
        # response = requests.put(ok.get("url"), headers=headers, data=data)
 | 
					        # response = requests.put(ok.get("url"), headers=headers, data=data)
 | 
				
			||||||
        response = await async_request_with_retry('PUT', ok.get("url"), headers=headers, data=data)
 | 
					        response = await async_request_with_retry('PUT', ok.get("url"), headers=headers, data=data)
 | 
				
			||||||
        logger.info(f"Upload file response status: {response.status}, status text: {response.reason}")
 | 
					        logger.info(f"Upload file response status: {response.status}, status text: {response.reason}")
 | 
				
			||||||
        end_time = time.time()  # End timing after the request is complete
 | 
					        end_time = time.time()  # End timing after the request is complete
 | 
				
			||||||
        logger.info("Upload time: {:.2f} seconds".format(end_time - start_time))
 | 
					        logger.info("Upload time: {:.2f} seconds".format(end_time - start_time))
 | 
				
			||||||
        
 | 
					        
 | 
				
			||||||
 | 
					        if item is not None:
 | 
				
			||||||
 | 
					            file_download_url = ok.get("download_url")
 | 
				
			||||||
 | 
					            if file_download_url is not None:
 | 
				
			||||||
 | 
					                item["url"] = file_download_url
 | 
				
			||||||
 | 
					            item["upload_duration"] = end_time - start_time
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def have_pending_upload(prompt_id):
 | 
					def have_pending_upload(prompt_id):
 | 
				
			||||||
    if prompt_id in prompt_metadata and len(prompt_metadata[prompt_id].uploading_nodes) > 0:
 | 
					    if prompt_id in prompt_metadata and len(prompt_metadata[prompt_id].uploading_nodes) > 0:
 | 
				
			||||||
        logger.info(f"Have pending upload {len(prompt_metadata[prompt_id].uploading_nodes)}")
 | 
					        logger.info(f"Have pending upload {len(prompt_metadata[prompt_id].uploading_nodes)}")
 | 
				
			||||||
@ -1297,14 +1332,15 @@ async def handle_upload(prompt_id: str, data, key: str, content_type_key: str, d
 | 
				
			|||||||
            item.get("filename"),
 | 
					            item.get("filename"),
 | 
				
			||||||
            subfolder=item.get("subfolder"),
 | 
					            subfolder=item.get("subfolder"),
 | 
				
			||||||
            type=item.get("type"),
 | 
					            type=item.get("type"),
 | 
				
			||||||
            content_type=file_type
 | 
					            content_type=file_type,
 | 
				
			||||||
 | 
					            item=item
 | 
				
			||||||
        ))
 | 
					        ))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Execute all upload tasks concurrently
 | 
					    # Execute all upload tasks concurrently
 | 
				
			||||||
    await asyncio.gather(*upload_tasks)
 | 
					    await asyncio.gather(*upload_tasks)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Upload files in the background
 | 
					# Upload files in the background
 | 
				
			||||||
async def upload_in_background(prompt_id: str, data, node_id=None, have_upload=True):
 | 
					async def upload_in_background(prompt_id: str, data, node_id=None, have_upload=True, node_meta=None):
 | 
				
			||||||
    try:
 | 
					    try:
 | 
				
			||||||
        upload_tasks = [
 | 
					        upload_tasks = [
 | 
				
			||||||
            handle_upload(prompt_id, data, 'images', "content_type", "image/png"),
 | 
					            handle_upload(prompt_id, data, 'images', "content_type", "image/png"),
 | 
				
			||||||
@ -1315,7 +1351,16 @@ async def upload_in_background(prompt_id: str, data, node_id=None, have_upload=T
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        await asyncio.gather(*upload_tasks)
 | 
					        await asyncio.gather(*upload_tasks)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        status_endpoint = prompt_metadata[prompt_id].status_endpoint
 | 
				
			||||||
        if have_upload:
 | 
					        if have_upload:
 | 
				
			||||||
 | 
					            if status_endpoint is not None:
 | 
				
			||||||
 | 
					                body = {
 | 
				
			||||||
 | 
					                    "run_id": prompt_id,
 | 
				
			||||||
 | 
					                    "output_data": data,
 | 
				
			||||||
 | 
					                    "node_meta": node_meta,
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					                # pprint(body)
 | 
				
			||||||
 | 
					                await async_request_with_retry('POST', status_endpoint, json=body)
 | 
				
			||||||
            await update_file_status(prompt_id, data, False, node_id=node_id)
 | 
					            await update_file_status(prompt_id, data, False, node_id=node_id)
 | 
				
			||||||
    except Exception as e:
 | 
					    except Exception as e:
 | 
				
			||||||
        await handle_error(prompt_id, data, e)
 | 
					        await handle_error(prompt_id, data, e)
 | 
				
			||||||
@ -1346,14 +1391,13 @@ async def update_run_with_output(prompt_id, data, node_id=None, node_meta=None):
 | 
				
			|||||||
            if have_upload_media:
 | 
					            if have_upload_media:
 | 
				
			||||||
                await update_file_status(prompt_id, data, True, node_id=node_id)
 | 
					                await update_file_status(prompt_id, data, True, node_id=node_id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            asyncio.create_task(upload_in_background(prompt_id, data, node_id=node_id, have_upload=have_upload_media))
 | 
					            asyncio.create_task(upload_in_background(prompt_id, data, node_id=node_id, have_upload=have_upload_media, node_meta=node_meta))
 | 
				
			||||||
            # await upload_in_background(prompt_id, data, node_id=node_id, have_upload=have_upload)
 | 
					            # await upload_in_background(prompt_id, data, node_id=node_id, have_upload=have_upload)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        except Exception as e:
 | 
					        except Exception as e:
 | 
				
			||||||
            await handle_error(prompt_id, data, e)
 | 
					            await handle_error(prompt_id, data, e)
 | 
				
			||||||
 | 
					 | 
				
			||||||
    # requests.post(status_endpoint, json=body)
 | 
					    # requests.post(status_endpoint, json=body)
 | 
				
			||||||
    if status_endpoint is not None:
 | 
					    elif status_endpoint is not None:
 | 
				
			||||||
        await async_request_with_retry('POST', status_endpoint, json=body)
 | 
					        await async_request_with_retry('POST', status_endpoint, json=body)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    await send('outputs_uploaded', {
 | 
					    await send('outputs_uploaded', {
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user