Retry logic when calling api (#57)
* fix: retry logic, bypass logfire, clean up log * fix: max_retries and retry_delay_multiplier, do not throw when pass the retry failed
This commit is contained in:
		
							parent
							
								
									7585d5049a
								
							
						
					
					
						commit
						8e12803ea1
					
				
							
								
								
									
										247
									
								
								custom_routes.py
									
									
									
									
									
								
							
							
						
						
									
										247
									
								
								custom_routes.py
									
									
									
									
									
								
							@ -22,17 +22,97 @@ from typing import Dict, List, Union, Any, Optional
 | 
				
			|||||||
from PIL import Image
 | 
					from PIL import Image
 | 
				
			||||||
import copy
 | 
					import copy
 | 
				
			||||||
import struct
 | 
					import struct
 | 
				
			||||||
 | 
					from aiohttp import ClientError
 | 
				
			||||||
 | 
					import atexit
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Global session
 | 
				
			||||||
 | 
					client_session = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# def create_client_session():
 | 
				
			||||||
 | 
					#     global client_session
 | 
				
			||||||
 | 
					#     if client_session is None:
 | 
				
			||||||
 | 
					#         client_session = aiohttp.ClientSession()
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					async def ensure_client_session():
 | 
				
			||||||
 | 
					    global client_session
 | 
				
			||||||
 | 
					    if client_session is None:
 | 
				
			||||||
 | 
					        client_session = aiohttp.ClientSession()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					async def cleanup():
 | 
				
			||||||
 | 
					    global client_session
 | 
				
			||||||
 | 
					    if client_session:
 | 
				
			||||||
 | 
					        await client_session.close()
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					def exit_handler():
 | 
				
			||||||
 | 
					    print("Exiting the application. Initiating cleanup...")
 | 
				
			||||||
 | 
					    loop = asyncio.get_event_loop()
 | 
				
			||||||
 | 
					    loop.run_until_complete(cleanup())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					atexit.register(exit_handler)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					max_retries = int(os.environ.get('MAX_RETRIES', '3'))
 | 
				
			||||||
 | 
					retry_delay_multiplier = float(os.environ.get('RETRY_DELAY_MULTIPLIER', '2'))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					print(f"max_retries: {max_retries}, retry_delay_multiplier: {retry_delay_multiplier}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					async def async_request_with_retry(method, url, **kwargs):
 | 
				
			||||||
 | 
					    global client_session
 | 
				
			||||||
 | 
					    await ensure_client_session()
 | 
				
			||||||
 | 
					    retry_delay = 1  # Start with 1 second delay
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    for attempt in range(max_retries):
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            async with client_session.request(method, url, **kwargs) as response:
 | 
				
			||||||
 | 
					                response.raise_for_status()
 | 
				
			||||||
 | 
					                return response
 | 
				
			||||||
 | 
					        except ClientError as e:
 | 
				
			||||||
 | 
					            if attempt == max_retries - 1:
 | 
				
			||||||
 | 
					                logger.error(f"Request failed after {max_retries} attempts: {e}")
 | 
				
			||||||
 | 
					                # raise
 | 
				
			||||||
 | 
					            logger.warning(f"Request failed (attempt {attempt + 1}/{max_retries}): {e}")
 | 
				
			||||||
 | 
					            await asyncio.sleep(retry_delay)
 | 
				
			||||||
 | 
					            retry_delay *= retry_delay_multiplier  # Exponential backoff
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from logging import basicConfig, getLogger
 | 
					from logging import basicConfig, getLogger
 | 
				
			||||||
import logfire
 | 
					
 | 
				
			||||||
# if os.environ.get('LOGFIRE_TOKEN', None) is not None:
 | 
					# Check for an environment variable to enable/disable Logfire
 | 
				
			||||||
logfire.configure(
 | 
					use_logfire = os.environ.get('USE_LOGFIRE', 'false').lower() == 'true'
 | 
				
			||||||
    send_to_logfire="if-token-present"
 | 
					
 | 
				
			||||||
)
 | 
					if use_logfire:
 | 
				
			||||||
# basicConfig(handlers=[logfire.LogfireLoggingHandler()])
 | 
					    try:
 | 
				
			||||||
logfire_handler = logfire.LogfireLoggingHandler()
 | 
					        import logfire
 | 
				
			||||||
logger = getLogger("comfy-deploy")
 | 
					        logfire.configure(
 | 
				
			||||||
logger.addHandler(logfire_handler)
 | 
					            send_to_logfire="if-token-present"
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        logger = logfire
 | 
				
			||||||
 | 
					    except ImportError:
 | 
				
			||||||
 | 
					        print("Logfire not installed or disabled. Using standard Python logger.")
 | 
				
			||||||
 | 
					        use_logfire = False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if not use_logfire:
 | 
				
			||||||
 | 
					    # Use a standard Python logger when Logfire is disabled or not available
 | 
				
			||||||
 | 
					    logger = getLogger("comfy-deploy")
 | 
				
			||||||
 | 
					    basicConfig(level="INFO")  # You can adjust the logging level as needed
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def log(level, message, **kwargs):
 | 
				
			||||||
 | 
					    if use_logfire:
 | 
				
			||||||
 | 
					        getattr(logger, level)(message, **kwargs)
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        getattr(logger, level)(f"{message} {kwargs}")
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					# For a span, you might need to create a context manager
 | 
				
			||||||
 | 
					from contextlib import contextmanager
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@contextmanager
 | 
				
			||||||
 | 
					def log_span(name):
 | 
				
			||||||
 | 
					    if use_logfire:
 | 
				
			||||||
 | 
					        with logger.span(name):
 | 
				
			||||||
 | 
					            yield
 | 
				
			||||||
 | 
					    # else:
 | 
				
			||||||
 | 
					    #     logger.info(f"Start: {name}")
 | 
				
			||||||
 | 
					    #     yield
 | 
				
			||||||
 | 
					    #     logger.info(f"End: {name}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from globals import StreamingPrompt, Status, sockets, SimplePrompt, streaming_prompt_metadata, prompt_metadata
 | 
					from globals import StreamingPrompt, Status, sockets, SimplePrompt, streaming_prompt_metadata, prompt_metadata
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -306,7 +386,7 @@ async def stream_prompt(data):
 | 
				
			|||||||
        workflow_api=workflow_api
 | 
					        workflow_api=workflow_api
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    logfire.info("Begin prompt", prompt=prompt)
 | 
					    log('info', "Begin prompt", prompt=prompt)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    try:
 | 
					    try:
 | 
				
			||||||
        res = post_prompt(prompt)
 | 
					        res = post_prompt(prompt)
 | 
				
			||||||
@ -359,8 +439,8 @@ async def stream_response(request):
 | 
				
			|||||||
    prompt_id = data.get("prompt_id")
 | 
					    prompt_id = data.get("prompt_id")
 | 
				
			||||||
    comfy_message_queues[prompt_id] = asyncio.Queue()
 | 
					    comfy_message_queues[prompt_id] = asyncio.Queue()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    with logfire.span('Streaming Run'):
 | 
					    with log_span('Streaming Run'):
 | 
				
			||||||
        logfire.info('Streaming prompt')
 | 
					        log('info', 'Streaming prompt')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
            result = await stream_prompt(data=data)
 | 
					            result = await stream_prompt(data=data)
 | 
				
			||||||
@ -373,7 +453,7 @@ async def stream_response(request):
 | 
				
			|||||||
                    if not comfy_message_queues[prompt_id].empty():
 | 
					                    if not comfy_message_queues[prompt_id].empty():
 | 
				
			||||||
                        data = await comfy_message_queues[prompt_id].get()
 | 
					                        data = await comfy_message_queues[prompt_id].get()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                        logfire.info(data["event"], data=json.dumps(data))
 | 
					                        log('info', data["event"], data=json.dumps(data))
 | 
				
			||||||
                        # logger.info("listener", data)
 | 
					                        # logger.info("listener", data)
 | 
				
			||||||
                        await response.write(f"event: event_update\ndata: {json.dumps(data)}\n\n".encode('utf-8'))
 | 
					                        await response.write(f"event: event_update\ndata: {json.dumps(data)}\n\n".encode('utf-8'))
 | 
				
			||||||
                        await response.drain()  # Ensure the buffer is flushed
 | 
					                        await response.drain()  # Ensure the buffer is flushed
 | 
				
			||||||
@ -384,10 +464,10 @@ async def stream_response(request):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
                await asyncio.sleep(0.1)  # Adjust the sleep duration as needed
 | 
					                await asyncio.sleep(0.1)  # Adjust the sleep duration as needed
 | 
				
			||||||
        except asyncio.CancelledError:
 | 
					        except asyncio.CancelledError:
 | 
				
			||||||
            logfire.info("Streaming was cancelled")
 | 
					            log('info', "Streaming was cancelled")
 | 
				
			||||||
            raise
 | 
					            raise
 | 
				
			||||||
        except Exception as e:
 | 
					        except Exception as e:
 | 
				
			||||||
            logfire.error("Streaming error", error=e)
 | 
					            log('error', "Streaming error", error=e)
 | 
				
			||||||
        finally:
 | 
					        finally:
 | 
				
			||||||
            # event_emitter.off("send_json", task)
 | 
					            # event_emitter.off("send_json", task)
 | 
				
			||||||
            await response.write_eof()
 | 
					            await response.write_eof()
 | 
				
			||||||
@ -482,34 +562,33 @@ async def upload_file_endpoint(request):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    if get_url:
 | 
					    if get_url:
 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
            async with aiohttp.ClientSession() as session:
 | 
					            headers = {'Authorization': f'Bearer {token}'}
 | 
				
			||||||
                headers = {'Authorization': f'Bearer {token}'}
 | 
					            params = {'file_size': file_size, 'type': file_type}
 | 
				
			||||||
                params = {'file_size': file_size, 'type': file_type}
 | 
					            response = await async_request_with_retry('GET', get_url, params=params, headers=headers)
 | 
				
			||||||
                async with session.get(get_url, params=params, headers=headers) as response:
 | 
					            if response.status == 200:
 | 
				
			||||||
                    if response.status == 200:
 | 
					                content = await response.json()
 | 
				
			||||||
                        content = await response.json()
 | 
					                upload_url = content["upload_url"]
 | 
				
			||||||
                        upload_url = content["upload_url"]
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
                        with open(file_path, 'rb') as f:
 | 
					                with open(file_path, 'rb') as f:
 | 
				
			||||||
                            headers = {
 | 
					                    headers = {
 | 
				
			||||||
                                "Content-Type": file_type,
 | 
					                        "Content-Type": file_type,
 | 
				
			||||||
                                # "x-amz-acl": "public-read",
 | 
					                        # "x-amz-acl": "public-read",
 | 
				
			||||||
                                "Content-Length": str(file_size)
 | 
					                        "Content-Length": str(file_size)
 | 
				
			||||||
                            }
 | 
					                    }
 | 
				
			||||||
                            async with session.put(upload_url, data=f, headers=headers) as upload_response:
 | 
					                    upload_response = await async_request_with_retry('PUT', upload_url, data=f, headers=headers)
 | 
				
			||||||
                                if upload_response.status == 200:
 | 
					                    if upload_response.status == 200:
 | 
				
			||||||
                                    return web.json_response({
 | 
					                        return web.json_response({
 | 
				
			||||||
                                        "message": "File uploaded successfully",
 | 
					                            "message": "File uploaded successfully",
 | 
				
			||||||
                                        "download_url": content["download_url"]
 | 
					                            "download_url": content["download_url"]
 | 
				
			||||||
                                    })
 | 
					                        })
 | 
				
			||||||
                                else:
 | 
					 | 
				
			||||||
                                    return web.json_response({
 | 
					 | 
				
			||||||
                                        "error": f"Failed to upload file to {upload_url}. Status code: {upload_response.status}"
 | 
					 | 
				
			||||||
                                    }, status=upload_response.status)
 | 
					 | 
				
			||||||
                    else:
 | 
					                    else:
 | 
				
			||||||
                        return web.json_response({
 | 
					                        return web.json_response({
 | 
				
			||||||
                            "error": f"Failed to fetch data from {get_url}. Status code: {response.status}"
 | 
					                            "error": f"Failed to upload file to {upload_url}. Status code: {upload_response.status}"
 | 
				
			||||||
                        }, status=response.status)
 | 
					                        }, status=upload_response.status)
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                return web.json_response({
 | 
				
			||||||
 | 
					                    "error": f"Failed to fetch data from {get_url}. Status code: {response.status}"
 | 
				
			||||||
 | 
					                }, status=response.status)
 | 
				
			||||||
        except Exception as e:
 | 
					        except Exception as e:
 | 
				
			||||||
            return web.json_response({
 | 
					            return web.json_response({
 | 
				
			||||||
                "error": f"An error occurred while fetching data from {get_url}: {str(e)}"
 | 
					                "error": f"An error occurred while fetching data from {get_url}: {str(e)}"
 | 
				
			||||||
@ -588,9 +667,7 @@ async def update_realtime_run_status(realtime_id: str, status_endpoint: str, sta
 | 
				
			|||||||
    if (status_endpoint is None):
 | 
					    if (status_endpoint is None):
 | 
				
			||||||
        return
 | 
					        return
 | 
				
			||||||
    # requests.post(status_endpoint, json=body)
 | 
					    # requests.post(status_endpoint, json=body)
 | 
				
			||||||
    async with aiohttp.ClientSession() as session:
 | 
					    await async_request_with_retry('POST', status_endpoint, json=body)
 | 
				
			||||||
        async with session.post(status_endpoint, json=body) as response:
 | 
					 | 
				
			||||||
            pass
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
@server.PromptServer.instance.routes.get('/comfyui-deploy/ws')
 | 
					@server.PromptServer.instance.routes.get('/comfyui-deploy/ws')
 | 
				
			||||||
async def websocket_handler(request):
 | 
					async def websocket_handler(request):
 | 
				
			||||||
@ -611,28 +688,27 @@ async def websocket_handler(request):
 | 
				
			|||||||
    status_endpoint = request.rel_url.query.get('status_endpoint', None)
 | 
					    status_endpoint = request.rel_url.query.get('status_endpoint', None)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if auth_token is not None and get_workflow_endpoint_url is not None:
 | 
					    if auth_token is not None and get_workflow_endpoint_url is not None:
 | 
				
			||||||
        async with aiohttp.ClientSession() as session:
 | 
					        headers = {'Authorization': f'Bearer {auth_token}'}
 | 
				
			||||||
            headers = {'Authorization': f'Bearer {auth_token}'}
 | 
					        response = await async_request_with_retry('GET', get_workflow_endpoint_url, headers=headers)
 | 
				
			||||||
            async with session.get(get_workflow_endpoint_url, headers=headers) as response:
 | 
					        if response.status == 200:
 | 
				
			||||||
                if response.status == 200:
 | 
					            workflow = await response.json()
 | 
				
			||||||
                    workflow = await response.json()
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
                    logger.info(f"Loaded workflow version ${workflow['version']}")
 | 
					            logger.info(f"Loaded workflow version ${workflow['version']}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                    streaming_prompt_metadata[sid] = StreamingPrompt(
 | 
					            streaming_prompt_metadata[sid] = StreamingPrompt(
 | 
				
			||||||
                        workflow_api=workflow["workflow_api"],
 | 
					                workflow_api=workflow["workflow_api"],
 | 
				
			||||||
                        auth_token=auth_token,
 | 
					                auth_token=auth_token,
 | 
				
			||||||
                        inputs={},
 | 
					                inputs={},
 | 
				
			||||||
                        status_endpoint=status_endpoint,
 | 
					                status_endpoint=status_endpoint,
 | 
				
			||||||
                        file_upload_endpoint=request.rel_url.query.get('file_upload_endpoint', None),
 | 
					                file_upload_endpoint=request.rel_url.query.get('file_upload_endpoint', None),
 | 
				
			||||||
                    )
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                    await update_realtime_run_status(realtime_id, status_endpoint, Status.RUNNING)
 | 
					            await update_realtime_run_status(realtime_id, status_endpoint, Status.RUNNING)
 | 
				
			||||||
                    # await send("workflow_api", workflow_api, sid)
 | 
					            # await send("workflow_api", workflow_api, sid)
 | 
				
			||||||
                else:
 | 
					        else:
 | 
				
			||||||
                    error_message = await response.text()
 | 
					            error_message = await response.text()
 | 
				
			||||||
                    logger.info(f"Failed to fetch workflow endpoint. Status: {response.status}, Error: {error_message}")
 | 
					            logger.info(f"Failed to fetch workflow endpoint. Status: {response.status}, Error: {error_message}")
 | 
				
			||||||
                    # await send("error", {"message": error_message}, sid)
 | 
					            # await send("error", {"message": error_message}, sid)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    try:
 | 
					    try:
 | 
				
			||||||
        # Send initial state to the new client
 | 
					        # Send initial state to the new client
 | 
				
			||||||
@ -805,13 +881,14 @@ async def send_json_override(self, event, data, sid=None):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
            prompt_metadata[prompt_id].progress.add(node)
 | 
					            prompt_metadata[prompt_id].progress.add(node)
 | 
				
			||||||
            calculated_progress = len(prompt_metadata[prompt_id].progress) / len(prompt_metadata[prompt_id].workflow_api)
 | 
					            calculated_progress = len(prompt_metadata[prompt_id].progress) / len(prompt_metadata[prompt_id].workflow_api)
 | 
				
			||||||
 | 
					            calculated_progress = round(calculated_progress, 2)
 | 
				
			||||||
            # logger.info("calculated_progress", calculated_progress)
 | 
					            # logger.info("calculated_progress", calculated_progress)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            if prompt_metadata[prompt_id].last_updated_node is not None and prompt_metadata[prompt_id].last_updated_node == node:
 | 
					            if prompt_metadata[prompt_id].last_updated_node is not None and prompt_metadata[prompt_id].last_updated_node == node:
 | 
				
			||||||
                return
 | 
					                return
 | 
				
			||||||
            prompt_metadata[prompt_id].last_updated_node = node
 | 
					            prompt_metadata[prompt_id].last_updated_node = node
 | 
				
			||||||
            class_type = prompt_metadata[prompt_id].workflow_api[node]['class_type']
 | 
					            class_type = prompt_metadata[prompt_id].workflow_api[node]['class_type']
 | 
				
			||||||
            logger.info(f"updating run live status {class_type}")
 | 
					            logger.info(f"At: {calculated_progress * 100}% - {class_type}")
 | 
				
			||||||
            await send("live_status", {
 | 
					            await send("live_status", {
 | 
				
			||||||
                "prompt_id": prompt_id,
 | 
					                "prompt_id": prompt_id,
 | 
				
			||||||
                "current_node": class_type,
 | 
					                "current_node": class_type,
 | 
				
			||||||
@ -836,14 +913,15 @@ async def send_json_override(self, event, data, sid=None):
 | 
				
			|||||||
        # await update_run_with_output(prompt_id, data)
 | 
					        # await update_run_with_output(prompt_id, data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if event == 'executed' and 'node' in data and 'output' in data:
 | 
					    if event == 'executed' and 'node' in data and 'output' in data:
 | 
				
			||||||
        logger.info(f"executed {data}")
 | 
					 | 
				
			||||||
        if prompt_id in prompt_metadata:
 | 
					        if prompt_id in prompt_metadata:
 | 
				
			||||||
            node = data.get('node')
 | 
					            node = data.get('node')
 | 
				
			||||||
            class_type = prompt_metadata[prompt_id].workflow_api[node]['class_type']
 | 
					            class_type = prompt_metadata[prompt_id].workflow_api[node]['class_type']
 | 
				
			||||||
            logger.info(f"executed {class_type}")
 | 
					            logger.info(f"Executed {class_type} {data}")
 | 
				
			||||||
            if class_type == "PreviewImage":
 | 
					            if class_type == "PreviewImage":
 | 
				
			||||||
                logger.info("skipping preview image")
 | 
					                logger.info("Skipping preview image")
 | 
				
			||||||
                return
 | 
					                return
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            logger.info(f"Executed {data}")
 | 
				
			||||||
            
 | 
					            
 | 
				
			||||||
        await update_run_with_output(prompt_id, data.get('output'), node_id=data.get('node'))
 | 
					        await update_run_with_output(prompt_id, data.get('output'), node_id=data.get('node'))
 | 
				
			||||||
        # await update_run_with_output(prompt_id, data.get('output'), node_id=data.get('node'))
 | 
					        # await update_run_with_output(prompt_id, data.get('output'), node_id=data.get('node'))
 | 
				
			||||||
@ -864,7 +942,7 @@ async def update_run_live_status(prompt_id, live_status, calculated_progress: fl
 | 
				
			|||||||
    if (status_endpoint is None):
 | 
					    if (status_endpoint is None):
 | 
				
			||||||
        return
 | 
					        return
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    logger.info(f"progress {calculated_progress}")
 | 
					    # logger.info(f"progress {calculated_progress}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    body = {
 | 
					    body = {
 | 
				
			||||||
        "run_id": prompt_id,
 | 
					        "run_id": prompt_id,
 | 
				
			||||||
@ -883,9 +961,7 @@ async def update_run_live_status(prompt_id, live_status, calculated_progress: fl
 | 
				
			|||||||
        })
 | 
					        })
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # requests.post(status_endpoint, json=body)
 | 
					    # requests.post(status_endpoint, json=body)
 | 
				
			||||||
    async with aiohttp.ClientSession() as session:
 | 
					    await async_request_with_retry('POST', status_endpoint, json=body)
 | 
				
			||||||
        async with session.post(status_endpoint, json=body) as response:
 | 
					 | 
				
			||||||
            pass
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
async def update_run(prompt_id: str, status: Status):
 | 
					async def update_run(prompt_id: str, status: Status):
 | 
				
			||||||
@ -916,9 +992,7 @@ async def update_run(prompt_id: str, status: Status):
 | 
				
			|||||||
        try:
 | 
					        try:
 | 
				
			||||||
            # requests.post(status_endpoint, json=body)
 | 
					            # requests.post(status_endpoint, json=body)
 | 
				
			||||||
            if (status_endpoint is not None):
 | 
					            if (status_endpoint is not None):
 | 
				
			||||||
                async with aiohttp.ClientSession() as session:
 | 
					                await async_request_with_retry('POST', status_endpoint, json=body)
 | 
				
			||||||
                    async with session.post(status_endpoint, json=body) as response:
 | 
					 | 
				
			||||||
                        pass
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
            if (status_endpoint is not None) and cd_enable_run_log and (status == Status.SUCCESS or status == Status.FAILED):
 | 
					            if (status_endpoint is not None) and cd_enable_run_log and (status == Status.SUCCESS or status == Status.FAILED):
 | 
				
			||||||
                try:
 | 
					                try:
 | 
				
			||||||
@ -948,9 +1022,7 @@ async def update_run(prompt_id: str, status: Status):
 | 
				
			|||||||
                            ]
 | 
					                            ]
 | 
				
			||||||
                        }
 | 
					                        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                        async with aiohttp.ClientSession() as session:
 | 
					                        await async_request_with_retry('POST', status_endpoint, json=body)
 | 
				
			||||||
                            async with session.post(status_endpoint, json=body) as response:
 | 
					 | 
				
			||||||
                                pass
 | 
					 | 
				
			||||||
                        # requests.post(status_endpoint, json=body)
 | 
					                        # requests.post(status_endpoint, json=body)
 | 
				
			||||||
                except Exception as log_error:
 | 
					                except Exception as log_error:
 | 
				
			||||||
                    logger.info(f"Error reading log file: {log_error}")
 | 
					                    logger.info(f"Error reading log file: {log_error}")
 | 
				
			||||||
@ -998,7 +1070,7 @@ async def upload_file(prompt_id, filename, subfolder=None, content_type="image/p
 | 
				
			|||||||
    filename = os.path.basename(filename)
 | 
					    filename = os.path.basename(filename)
 | 
				
			||||||
    file = os.path.join(output_dir, filename)
 | 
					    file = os.path.join(output_dir, filename)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    logger.info(f"uploading file {file}")
 | 
					    logger.info(f"Uploading file {file}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    file_upload_endpoint = prompt_metadata[prompt_id].file_upload_endpoint
 | 
					    file_upload_endpoint = prompt_metadata[prompt_id].file_upload_endpoint
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -1024,18 +1096,17 @@ async def upload_file(prompt_id, filename, subfolder=None, content_type="image/p
 | 
				
			|||||||
            "Content-Length": str(len(data)),
 | 
					            "Content-Length": str(len(data)),
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
        # response = requests.put(ok.get("url"), headers=headers, data=data)
 | 
					        # response = requests.put(ok.get("url"), headers=headers, data=data)
 | 
				
			||||||
        async with aiohttp.ClientSession() as session:
 | 
					        response = await async_request_with_retry('PUT', ok.get("url"), headers=headers, data=data)
 | 
				
			||||||
            async with session.put(ok.get("url"), headers=headers, data=data) as response:
 | 
					        logger.info(f"Upload file response status: {response.status}, status text: {response.reason}")
 | 
				
			||||||
                logger.info(f"Upload file response status: {response.status}, status text: {response.reason}")
 | 
					        end_time = time.time()  # End timing after the request is complete
 | 
				
			||||||
                end_time = time.time()  # End timing after the request is complete
 | 
					        logger.info("Upload time: {:.2f} seconds".format(end_time - start_time))
 | 
				
			||||||
                logger.info("Upload time: {:.2f} seconds".format(end_time - start_time))
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
def have_pending_upload(prompt_id):
 | 
					def have_pending_upload(prompt_id):
 | 
				
			||||||
    if prompt_id in prompt_metadata and len(prompt_metadata[prompt_id].uploading_nodes) > 0:
 | 
					    if prompt_id in prompt_metadata and len(prompt_metadata[prompt_id].uploading_nodes) > 0:
 | 
				
			||||||
        logger.info(f"have pending upload {len(prompt_metadata[prompt_id].uploading_nodes)}")
 | 
					        logger.info(f"Have pending upload {len(prompt_metadata[prompt_id].uploading_nodes)}")
 | 
				
			||||||
        return True
 | 
					        return True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    logger.info("no pending upload")
 | 
					    logger.info("No pending upload")
 | 
				
			||||||
    return False
 | 
					    return False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def mark_prompt_done(prompt_id):
 | 
					def mark_prompt_done(prompt_id):
 | 
				
			||||||
@ -1093,7 +1164,7 @@ async def update_file_status(prompt_id: str, data, uploading, have_error=False,
 | 
				
			|||||||
        else:
 | 
					        else:
 | 
				
			||||||
            prompt_metadata[prompt_id].uploading_nodes.discard(node_id)
 | 
					            prompt_metadata[prompt_id].uploading_nodes.discard(node_id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    logger.info(prompt_metadata[prompt_id].uploading_nodes)
 | 
					    logger.info(f"Remaining uploads: {prompt_metadata[prompt_id].uploading_nodes}")
 | 
				
			||||||
    # Update the remote status
 | 
					    # Update the remote status
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if have_error:
 | 
					    if have_error:
 | 
				
			||||||
@ -1177,7 +1248,7 @@ async def update_run_with_output(prompt_id, data, node_id=None):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    if have_upload_media:
 | 
					    if have_upload_media:
 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
            logger.info(f"\nhave_upload {have_upload_media} {node_id}")
 | 
					            logger.info(f"\nHave_upload {have_upload_media} Node Id: {node_id}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            if have_upload_media:
 | 
					            if have_upload_media:
 | 
				
			||||||
                await update_file_status(prompt_id, data, True, node_id=node_id)
 | 
					                await update_file_status(prompt_id, data, True, node_id=node_id)
 | 
				
			||||||
@ -1190,9 +1261,7 @@ async def update_run_with_output(prompt_id, data, node_id=None):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    # requests.post(status_endpoint, json=body)
 | 
					    # requests.post(status_endpoint, json=body)
 | 
				
			||||||
    if status_endpoint is not None:
 | 
					    if status_endpoint is not None:
 | 
				
			||||||
        async with aiohttp.ClientSession() as session:
 | 
					        await async_request_with_retry('POST', status_endpoint, json=body)
 | 
				
			||||||
            async with session.post(status_endpoint, json=body) as response:
 | 
					 | 
				
			||||||
                pass
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    await send('outputs_uploaded', {
 | 
					    await send('outputs_uploaded', {
 | 
				
			||||||
        "prompt_id": prompt_id
 | 
					        "prompt_id": prompt_id
 | 
				
			||||||
 | 
				
			|||||||
@ -2,4 +2,4 @@ aiofiles
 | 
				
			|||||||
pydantic
 | 
					pydantic
 | 
				
			||||||
opencv-python
 | 
					opencv-python
 | 
				
			||||||
imageio-ffmpeg
 | 
					imageio-ffmpeg
 | 
				
			||||||
logfire
 | 
					# logfire
 | 
				
			||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user