Merge branch 'benny/async-upload-file' into public-main
This commit is contained in:
		
						commit
						894d8e1503
					
				@ -56,20 +56,24 @@ retry_delay_multiplier = float(os.environ.get('RETRY_DELAY_MULTIPLIER', '2'))
 | 
			
		||||
 | 
			
		||||
print(f"max_retries: {max_retries}, retry_delay_multiplier: {retry_delay_multiplier}")
 | 
			
		||||
 | 
			
		||||
async def async_request_with_retry(method, url, **kwargs):
 | 
			
		||||
async def async_request_with_retry(method, url, disable_timeout=False, **kwargs):
 | 
			
		||||
    global client_session
 | 
			
		||||
    await ensure_client_session()
 | 
			
		||||
    # async with aiohttp.ClientSession() as client_session:
 | 
			
		||||
    retry_delay = 1  # Start with 1 second delay
 | 
			
		||||
    initial_timeout = 5  # 5 seconds timeout for the initial connection
 | 
			
		||||
 | 
			
		||||
    for attempt in range(max_retries):
 | 
			
		||||
        try:
 | 
			
		||||
            # Set a timeout for the initial connection
 | 
			
		||||
            timeout = ClientTimeout(total=None, connect=initial_timeout)
 | 
			
		||||
            kwargs['timeout'] = timeout
 | 
			
		||||
            if not disable_timeout:
 | 
			
		||||
                timeout = ClientTimeout(total=None, connect=initial_timeout)
 | 
			
		||||
                kwargs['timeout'] = timeout
 | 
			
		||||
 | 
			
		||||
            async with client_session.request(method, url, **kwargs) as response:
 | 
			
		||||
                response.raise_for_status()
 | 
			
		||||
                if method.upper() == 'GET':
 | 
			
		||||
                    await response.read()
 | 
			
		||||
                return response
 | 
			
		||||
        except asyncio.TimeoutError:
 | 
			
		||||
            logger.warning(f"Request timed out after {initial_timeout} seconds (attempt {attempt + 1}/{max_retries})")
 | 
			
		||||
@ -1150,18 +1154,18 @@ async def upload_file(prompt_id, filename, subfolder=None, content_type="image/p
 | 
			
		||||
    prompt_id = quote(prompt_id)
 | 
			
		||||
    content_type = quote(content_type)
 | 
			
		||||
 | 
			
		||||
    target_url = f"{file_upload_endpoint}?file_name={filename}&run_id={prompt_id}&type={content_type}"
 | 
			
		||||
    target_url = f"{file_upload_endpoint}?file_name={filename}&run_id={prompt_id}&type={content_type}&version=v2"
 | 
			
		||||
 | 
			
		||||
    start_time = time.time()  # Start timing here
 | 
			
		||||
    result = requests.get(target_url)
 | 
			
		||||
    result = await async_request_with_retry("GET", target_url, disable_timeout=True)
 | 
			
		||||
    end_time = time.time()  # End timing after the request is complete
 | 
			
		||||
    logger.info("Time taken for getting file upload endpoint: {:.2f} seconds".format(end_time - start_time))
 | 
			
		||||
    ok = result.json()
 | 
			
		||||
    ok = await result.json()
 | 
			
		||||
 | 
			
		||||
    start_time = time.time()  # Start timing here
 | 
			
		||||
 | 
			
		||||
    with open(file, 'rb') as f:
 | 
			
		||||
        data = f.read()
 | 
			
		||||
    async with aiofiles.open(file, 'rb') as f:
 | 
			
		||||
        data = await f.read()
 | 
			
		||||
        headers = {
 | 
			
		||||
            # "x-amz-acl": "public-read",
 | 
			
		||||
            "Content-Type": content_type,
 | 
			
		||||
@ -1264,8 +1268,10 @@ async def update_file_status(prompt_id: str, data, uploading, have_error=False,
 | 
			
		||||
 | 
			
		||||
async def handle_upload(prompt_id: str, data, key: str, content_type_key: str, default_content_type: str):
 | 
			
		||||
    items = data.get(key, [])
 | 
			
		||||
    upload_tasks = []
 | 
			
		||||
 | 
			
		||||
    for item in items:
 | 
			
		||||
        # # Skipping temp files
 | 
			
		||||
        # Skipping temp files
 | 
			
		||||
        if item.get("type") == "temp":
 | 
			
		||||
            continue
 | 
			
		||||
 | 
			
		||||
@ -1278,22 +1284,28 @@ async def handle_upload(prompt_id: str, data, key: str, content_type_key: str, d
 | 
			
		||||
        elif file_extension == '.webp':
 | 
			
		||||
            file_type = 'image/webp'
 | 
			
		||||
 | 
			
		||||
        await upload_file(
 | 
			
		||||
        upload_tasks.append(upload_file(
 | 
			
		||||
            prompt_id,
 | 
			
		||||
            item.get("filename"),
 | 
			
		||||
            subfolder=item.get("subfolder"),
 | 
			
		||||
            type=item.get("type"),
 | 
			
		||||
            content_type=file_type
 | 
			
		||||
        )
 | 
			
		||||
        ))
 | 
			
		||||
 | 
			
		||||
    # Execute all upload tasks concurrently
 | 
			
		||||
    await asyncio.gather(*upload_tasks)
 | 
			
		||||
 | 
			
		||||
# Upload files in the background
 | 
			
		||||
async def upload_in_background(prompt_id: str, data, node_id=None, have_upload=True):
 | 
			
		||||
    try:
 | 
			
		||||
        await handle_upload(prompt_id, data, 'images', "content_type", "image/png")
 | 
			
		||||
        await handle_upload(prompt_id, data, 'files', "content_type", "image/png")
 | 
			
		||||
        # This will also be mp4
 | 
			
		||||
        await handle_upload(prompt_id, data, 'gifs', "format", "image/gif")
 | 
			
		||||
        await handle_upload(prompt_id, data, 'mesh', "format", "application/octet-stream")
 | 
			
		||||
        upload_tasks = [
 | 
			
		||||
            handle_upload(prompt_id, data, 'images', "content_type", "image/png"),
 | 
			
		||||
            handle_upload(prompt_id, data, 'files', "content_type", "image/png"),
 | 
			
		||||
            handle_upload(prompt_id, data, 'gifs', "format", "image/gif"),
 | 
			
		||||
            handle_upload(prompt_id, data, 'mesh', "format", "application/octet-stream")
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
        await asyncio.gather(*upload_tasks)
 | 
			
		||||
 | 
			
		||||
        if have_upload:
 | 
			
		||||
            await update_file_status(prompt_id, data, False, node_id=node_id)
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user