diff --git a/custom_routes.py b/custom_routes.py
index 6af88a3..b440f2e 100644
--- a/custom_routes.py
+++ b/custom_routes.py
@@ -34,32 +34,39 @@
 client_session = None
 
 # global client_session
 # if client_session is None:
 #     client_session = aiohttp.ClientSession()
-
+
+
 async def ensure_client_session():
     global client_session
     if client_session is None:
         client_session = aiohttp.ClientSession()
 
+
 async def cleanup():
     global client_session
     if client_session:
         await client_session.close()
-
+
+
 def exit_handler():
     print("Exiting the application. Initiating cleanup...")
     loop = asyncio.get_event_loop()
     loop.run_until_complete(cleanup())
 
+
 atexit.register(exit_handler)
 
-max_retries = int(os.environ.get('MAX_RETRIES', '5'))
-retry_delay_multiplier = float(os.environ.get('RETRY_DELAY_MULTIPLIER', '2'))
+max_retries = int(os.environ.get("MAX_RETRIES", "5"))
+retry_delay_multiplier = float(os.environ.get("RETRY_DELAY_MULTIPLIER", "2"))
 
 print(f"max_retries: {max_retries}, retry_delay_multiplier: {retry_delay_multiplier}")
 
 import time
 
-async def async_request_with_retry(method, url, disable_timeout=False, token=None, **kwargs):
+
+async def async_request_with_retry(
+    method, url, disable_timeout=False, token=None, **kwargs
+):
     global client_session
     await ensure_client_session()
     retry_delay = 1  # Start with 1 second delay
@@ -70,64 +77,76 @@ async def async_request_with_retry(method, url, disable_timeout=False, token=Non
         try:
             if not disable_timeout:
                 timeout = ClientTimeout(total=None, connect=initial_timeout)
-                kwargs['timeout'] = timeout
+                kwargs["timeout"] = timeout
 
             if token is not None:
-                if 'headers' not in kwargs:
-                    kwargs['headers'] = {}
-                kwargs['headers']['Authorization'] = f"Bearer {token}"
+                if "headers" not in kwargs:
+                    kwargs["headers"] = {}
+                kwargs["headers"]["Authorization"] = f"Bearer {token}"
 
             request_start = time.time()
             async with client_session.request(method, url, **kwargs) as response:
                 request_end = time.time()
-                logger.info(f"Request attempt {attempt + 1} took {request_end - request_start:.2f} seconds")
-
+                logger.info(
+                    f"Request attempt {attempt + 1} took {request_end - request_start:.2f} seconds"
+                )
+
                 if response.status != 200:
                     error_body = await response.text()
-                    logger.error(f"Request failed with status {response.status} and body {error_body}")
+                    logger.error(
+                        f"Request failed with status {response.status} and body {error_body}"
+                    )
                     # raise Exception(f"Request failed with status {response.status}")
-
+
                 response.raise_for_status()
 
-                if method.upper() == 'GET':
+                if method.upper() == "GET":
                     await response.read()
-
+
                 total_time = time.time() - start_time
-                logger.info(f"Request succeeded after {total_time:.2f} seconds (attempt {attempt + 1}/{max_retries})")
+                logger.info(
+                    f"Request succeeded after {total_time:.2f} seconds (attempt {attempt + 1}/{max_retries})"
+                )
 
                 return response
         except asyncio.TimeoutError:
-            logger.warning(f"Request timed out after {initial_timeout} seconds (attempt {attempt + 1}/{max_retries})")
+            logger.warning(
+                f"Request timed out after {initial_timeout} seconds (attempt {attempt + 1}/{max_retries})"
+            )
         except ClientError as e:
             end_time = time.time()
             logger.error(f"Request failed (attempt {attempt + 1}/{max_retries}): {e}")
-            logger.error(f"Time taken for failed attempt: {end_time - request_start:.2f} seconds")
+            logger.error(
+                f"Time taken for failed attempt: {end_time - request_start:.2f} seconds"
+            )
             logger.error(f"Total time elapsed: {end_time - start_time:.2f} seconds")
-
+
             # Log the response body for ClientError as well
-            if hasattr(e, 'response') and e.response is not None:
+            if hasattr(e, "response") and e.response is not None:
                 error_body = await e.response.text()
                 logger.error(f"Error response body: {error_body}")
-
+
             if attempt == max_retries - 1:
                 logger.error(f"Request failed after {max_retries} attempts: {e}")
                 raise
-
+
             await asyncio.sleep(retry_delay)
             retry_delay *= retry_delay_multiplier
 
     total_time = time.time() - start_time
-    raise Exception(f"Request failed after {max_retries} attempts and {total_time:.2f} seconds")
+    raise Exception(
+        f"Request failed after {max_retries} attempts and {total_time:.2f} seconds"
+    )
+
 
 from logging import basicConfig, getLogger
 
 # Check for an environment variable to enable/disable Logfire
-use_logfire = os.environ.get('USE_LOGFIRE', 'false').lower() == 'true'
+use_logfire = os.environ.get("USE_LOGFIRE", "false").lower() == "true"
 
 if use_logfire:
     try:
         import logfire
-        logfire.configure(
-            send_to_logfire="if-token-present"
-        )
+
+        logfire.configure(send_to_logfire="if-token-present")
         logger = logfire
     except ImportError:
         print("Logfire not installed or disabled. Using standard Python logger.")
@@ -138,15 +157,18 @@
 if not use_logfire:
     logger = getLogger("comfy-deploy")
     basicConfig(level="INFO")  # You can adjust the logging level as needed
 
+
 def log(level, message, **kwargs):
     if use_logfire:
         getattr(logger, level)(message, **kwargs)
     else:
         getattr(logger, level)(f"{message} {kwargs}")
-
+
+
 # For a span, you might need to create a context manager
 from contextlib import contextmanager
 
+
 @contextmanager
 def log_span(name):
     if use_logfire:
@@ -159,7 +181,15 @@ def log_span(name):
         # logger.info(f"End: {name}")
 
 
-from globals import StreamingPrompt, Status, sockets, SimplePrompt, streaming_prompt_metadata, prompt_metadata
+from globals import (
+    StreamingPrompt,
+    Status,
+    sockets,
+    SimplePrompt,
+    streaming_prompt_metadata,
+    prompt_metadata,
+)
+
 
 class EventEmitter:
     def __init__(self):
@@ -181,31 +211,37 @@ class EventEmitter:
             for listener in self.listeners[event]:
                 listener(*args, **kwargs)
 
+
 # Create a global event emitter instance
 event_emitter = EventEmitter()
 
 api = None
 api_task = None
 
-cd_enable_log = os.environ.get('CD_ENABLE_LOG', 'false').lower() == 'true'
-cd_enable_run_log = os.environ.get('CD_ENABLE_RUN_LOG', 'false').lower() == 'true'
-bypass_upload = os.environ.get('CD_BYPASS_UPLOAD', 'false').lower() == 'true'
+cd_enable_log = os.environ.get("CD_ENABLE_LOG", "false").lower() == "true"
+cd_enable_run_log = os.environ.get("CD_ENABLE_RUN_LOG", "false").lower() == "true"
+bypass_upload = os.environ.get("CD_BYPASS_UPLOAD", "false").lower() == "true"
 
 logger.info(f"CD_BYPASS_UPLOAD {bypass_upload}")
 
 
 def clear_current_prompt(sid):
     prompt_server = server.PromptServer.instance
-    to_delete = list(streaming_prompt_metadata[sid].running_prompt_ids)  # Convert set to list
+    to_delete = list(
+        streaming_prompt_metadata[sid].running_prompt_ids
+    )  # Convert set to list
 
     logger.info(f"clearing out prompt: {to_delete}")
     for id_to_delete in to_delete:
         delete_func = lambda a: a[1] == id_to_delete
         prompt_server.prompt_queue.delete_queue_item(delete_func)
-        logger.info(f"deleted prompt: {id_to_delete}, remaining tasks: {prompt_server.prompt_queue.get_tasks_remaining()}")
+        logger.info(
+            f"deleted prompt: {id_to_delete}, remaining tasks: {prompt_server.prompt_queue.get_tasks_remaining()}"
+        )
 
     streaming_prompt_metadata[sid].running_prompt_ids.clear()
 
+
 def post_prompt(json_data):
     prompt_server = server.PromptServer.instance
     json_data = prompt_server.trigger_on_prompt(json_data)
@@ -248,11 +284,13 @@ def post_prompt(json_data):
     else:
         return {"error": "no prompt", "node_errors": []}
 
+
 def randomSeed(num_digits=15):
     range_start = 10 ** (num_digits - 1)
     range_end = (10**num_digits) - 1
     return random.randint(range_start, range_end)
 
+
 def apply_random_seed_to_workflow(workflow_api):
     """
     Applies a random seed to each element in the workflow_api that has a 'seed' input.
@@ -261,62 +299,76 @@ def apply_random_seed_to_workflow(workflow_api):
         workflow_api (dict): The workflow API dictionary to modify.
     """
     for key in workflow_api:
-        if 'inputs' in workflow_api[key]:
-            if 'seed' in workflow_api[key]['inputs']:
-                if isinstance(workflow_api[key]['inputs']['seed'], list):
+        if "inputs" in workflow_api[key]:
+            if "seed" in workflow_api[key]["inputs"]:
+                if isinstance(workflow_api[key]["inputs"]["seed"], list):
                     continue
-                if workflow_api[key]['class_type'] == "PromptExpansion":
-                    workflow_api[key]['inputs']['seed'] = randomSeed(8)
-                    logger.info(f"Applied random seed {workflow_api[key]['inputs']['seed']} to PromptExpansion")
+                if workflow_api[key]["class_type"] == "PromptExpansion":
+                    workflow_api[key]["inputs"]["seed"] = randomSeed(8)
+                    logger.info(
+                        f"Applied random seed {workflow_api[key]['inputs']['seed']} to PromptExpansion"
+                    )
                     continue
-                workflow_api[key]['inputs']['seed'] = randomSeed()
-                logger.info(f"Applied random seed {workflow_api[key]['inputs']['seed']} to {workflow_api[key]['class_type']}")
+                workflow_api[key]["inputs"]["seed"] = randomSeed()
+                logger.info(
+                    f"Applied random seed {workflow_api[key]['inputs']['seed']} to {workflow_api[key]['class_type']}"
+                )
 
-            if 'noise_seed' in workflow_api[key]['inputs']:
-                if workflow_api[key]['class_type'] == "RandomNoise":
-                    workflow_api[key]['inputs']['noise_seed'] = randomSeed()
-                    logger.info(f"Applied random noise_seed {workflow_api[key]['inputs']['noise_seed']} to RandomNoise")
+            if "noise_seed" in workflow_api[key]["inputs"]:
+                if workflow_api[key]["class_type"] == "RandomNoise":
+                    workflow_api[key]["inputs"]["noise_seed"] = randomSeed()
+                    logger.info(
+                        f"Applied random noise_seed {workflow_api[key]['inputs']['noise_seed']} to RandomNoise"
+                    )
                     continue
-                if workflow_api[key]['class_type'] == "KSamplerAdvanced":
-                    workflow_api[key]['inputs']['noise_seed'] = randomSeed()
-                    logger.info(f"Applied random noise_seed {workflow_api[key]['inputs']['noise_seed']} to KSamplerAdvanced")
+                if workflow_api[key]["class_type"] == "KSamplerAdvanced":
+                    workflow_api[key]["inputs"]["noise_seed"] = randomSeed()
+                    logger.info(
+                        f"Applied random noise_seed {workflow_api[key]['inputs']['noise_seed']} to KSamplerAdvanced"
+                    )
                     continue
-                if workflow_api[key]['class_type'] == "SamplerCustom":
-                    workflow_api[key]['inputs']['noise_seed'] = randomSeed()
-                    logger.info(f"Applied random noise_seed {workflow_api[key]['inputs']['noise_seed']} to SamplerCustom")
+                if workflow_api[key]["class_type"] == "SamplerCustom":
+                    workflow_api[key]["inputs"]["noise_seed"] = randomSeed()
+                    logger.info(
+                        f"Applied random noise_seed {workflow_api[key]['inputs']['noise_seed']} to SamplerCustom"
+                    )
                     continue
 
+
 def apply_inputs_to_workflow(workflow_api: Any, inputs: Any, sid: str = None):
     # Loop through each of the inputs and replace them
     for key, value in workflow_api.items():
-        if 'inputs' in value:
-
+        if "inputs" in value:
             # Support websocket
             if sid is not None:
-                if (value["class_type"] == "ComfyDeployWebscoketImageOutput"):
-                    value['inputs']["client_id"] = sid
-                if (value["class_type"] == "ComfyDeployWebscoketImageInput"):
-                    value['inputs']["client_id"] = sid
+                if value["class_type"] == "ComfyDeployWebscoketImageOutput":
value["class_type"] == "ComfyDeployWebscoketImageOutput": + value["inputs"]["client_id"] = sid + if value["class_type"] == "ComfyDeployWebscoketImageInput": + value["inputs"]["client_id"] = sid - if "input_id" in value['inputs'] and inputs is not None and value['inputs']['input_id'] in inputs: - new_value = inputs[value['inputs']['input_id']] + if ( + "input_id" in value["inputs"] + and inputs is not None + and value["inputs"]["input_id"] in inputs + ): + new_value = inputs[value["inputs"]["input_id"]] # Lets skip it if its an image if isinstance(new_value, Image.Image): continue # Backward compactibility - value['inputs']["input_id"] = new_value + value["inputs"]["input_id"] = new_value # Fix for external text default value - if (value["class_type"] == "ComfyUIDeployExternalText"): - value['inputs']["default_value"] = new_value + if value["class_type"] == "ComfyUIDeployExternalText": + value["inputs"]["default_value"] = new_value - if (value["class_type"] == "ComfyUIDeployExternalCheckpoint"): - value['inputs']["default_value"] = new_value + if value["class_type"] == "ComfyUIDeployExternalCheckpoint": + value["inputs"]["default_value"] = new_value - if (value["class_type"] == "ComfyUIDeployExternalImageBatch"): - value['inputs']["images"] = new_value + if value["class_type"] == "ComfyUIDeployExternalImageBatch": + value["inputs"]["images"] = new_value if value["class_type"] == "ComfyUIDeployExternalLora": value["inputs"]["lora_url"] = new_value @@ -327,6 +379,7 @@ def apply_inputs_to_workflow(workflow_api: Any, inputs: Any, sid: str = None): if value["class_type"] == "ComfyUIDeployExternalBoolean": value["inputs"]["default_value"] = new_value + def send_prompt(sid: str, inputs: StreamingPrompt): # workflow_api = inputs.workflow_api workflow_api = copy.deepcopy(inputs.workflow_api) @@ -334,7 +387,7 @@ def send_prompt(sid: str, inputs: StreamingPrompt): # Random seed apply_random_seed_to_workflow(workflow_api) - logger.info("getting inputs" , inputs.inputs) + logger.info("getting inputs", inputs.inputs) apply_inputs_to_workflow(workflow_api, inputs.inputs, sid=sid) @@ -344,8 +397,8 @@ def send_prompt(sid: str, inputs: StreamingPrompt): prompt = { "prompt": workflow_api, - "client_id": sid, #"comfy_deploy_instance", #api.client_id - "prompt_id": prompt_id + "client_id": sid, # "comfy_deploy_instance", #api.client_id + "prompt_id": prompt_id, } try: @@ -355,23 +408,24 @@ def send_prompt(sid: str, inputs: StreamingPrompt): status_endpoint=inputs.status_endpoint, file_upload_endpoint=inputs.file_upload_endpoint, workflow_api=workflow_api, - is_realtime=True + is_realtime=True, ) except Exception as e: error_type = type(e).__name__ - stack_trace_short = traceback.format_exc().strip().split('\n')[-2] + stack_trace_short = traceback.format_exc().strip().split("\n")[-2] stack_trace = traceback.format_exc().strip() logger.info(f"error: {error_type}, {e}") logger.info(f"stack trace: {stack_trace_short}") + @server.PromptServer.instance.routes.post("/comfyui-deploy/run") async def comfy_deploy_run(request): # Extract the bearer token from the Authorization header - auth_header = request.headers.get('Authorization') + auth_header = request.headers.get("Authorization") token = None if auth_header: parts = auth_header.split() - if len(parts) == 2 and parts[0].lower() == 'bearer': + if len(parts) == 2 and parts[0].lower() == "bearer": token = parts[1] data = await request.json() @@ -381,6 +435,7 @@ async def comfy_deploy_run(request): # The prompt id generated from comfy deploy, can be None prompt_id = 
data.get("prompt_id") inputs = data.get("inputs") + gpu_event_id = data.get("gpu_event_id", None) # Now it handles directly in here apply_random_seed_to_workflow(workflow_api) @@ -388,45 +443,49 @@ async def comfy_deploy_run(request): prompt = { "prompt": workflow_api, - "client_id": "comfy_deploy_instance", #api.client_id + "client_id": "comfy_deploy_instance", # api.client_id "prompt_id": prompt_id, } prompt_metadata[prompt_id] = SimplePrompt( - status_endpoint=data.get('status_endpoint'), - file_upload_endpoint=data.get('file_upload_endpoint'), + status_endpoint=data.get("status_endpoint"), + file_upload_endpoint=data.get("file_upload_endpoint"), workflow_api=workflow_api, - token=token + token=token, + gpu_event_id=gpu_event_id, ) try: res = post_prompt(prompt) except Exception as e: error_type = type(e).__name__ - stack_trace_short = traceback.format_exc().strip().split('\n')[-2] + stack_trace_short = traceback.format_exc().strip().split("\n")[-2] stack_trace = traceback.format_exc().strip() logger.info(f"error: {error_type}, {e}") logger.info(f"stack trace: {stack_trace_short}") - await update_run_with_output(prompt_id, { - "error": { - "error_type": error_type, - "stack_trace": stack_trace - } - }) - # When there are critical errors, the prompt is actually not run + await update_run_with_output( + prompt_id, + {"error": {"error_type": error_type, "stack_trace": stack_trace}}, + gpu_event_id=gpu_event_id, + ) + # When there are critical errors, the prompt is actually not run await update_run(prompt_id, Status.FAILED) - return web.Response(status=500, reason=f"{error_type}: {e}, {stack_trace_short}") + return web.Response( + status=500, reason=f"{error_type}: {e}, {stack_trace_short}" + ) status = 200 - if "node_errors" in res and res["node_errors"] is not None and len(res["node_errors"]) > 0: + if ( + "node_errors" in res + and res["node_errors"] is not None + and len(res["node_errors"]) > 0 + ): # Even tho there are node_errors it can still be run status = 400 - await update_run_with_output(prompt_id, { - "error": { - **res - } - }) + await update_run_with_output( + prompt_id, {"error": {**res}}, gpu_event_id=gpu_event_id + ) # When there are critical errors, the prompt is actually not run if "error" in res: @@ -434,6 +493,7 @@ async def comfy_deploy_run(request): return web.json_response(res, status=status) + async def stream_prompt(data, token): # In older version, we use workflow_api, but this has inputs already swapped in nextjs frontend, which is tricky workflow_api = data.get("workflow_api_raw") @@ -447,15 +507,15 @@ async def stream_prompt(data, token): prompt = { "prompt": workflow_api, - "client_id": "comfy_deploy_instance", #api.client_id - "prompt_id": prompt_id + "client_id": "comfy_deploy_instance", # api.client_id + "prompt_id": prompt_id, } prompt_metadata[prompt_id] = SimplePrompt( - status_endpoint=data.get('status_endpoint'), - file_upload_endpoint=data.get('file_upload_endpoint'), + status_endpoint=data.get("status_endpoint"), + file_upload_endpoint=data.get("file_upload_endpoint"), workflow_api=workflow_api, - token=token + token=token, ) # log('info', "Begin prompt", prompt=prompt) @@ -464,31 +524,28 @@ async def stream_prompt(data, token): res = post_prompt(prompt) except Exception as e: error_type = type(e).__name__ - stack_trace_short = traceback.format_exc().strip().split('\n')[-2] + stack_trace_short = traceback.format_exc().strip().split("\n")[-2] stack_trace = traceback.format_exc().strip() logger.info(f"error: {error_type}, {e}") logger.info(f"stack 
trace: {stack_trace_short}") - await update_run_with_output(prompt_id, { - "error": { - "error_type": error_type, - "stack_trace": stack_trace - } - }) - # When there are critical errors, the prompt is actually not run + await update_run_with_output( + prompt_id, {"error": {"error_type": error_type, "stack_trace": stack_trace}} + ) + # When there are critical errors, the prompt is actually not run await update_run(prompt_id, Status.FAILED) # return web.Response(status=500, reason=f"{error_type}: {e}, {stack_trace_short}") # raise Exception("Prompt failed") status = 200 - if "node_errors" in res and res["node_errors"] is not None and len(res["node_errors"]) > 0: + if ( + "node_errors" in res + and res["node_errors"] is not None + and len(res["node_errors"]) > 0 + ): # Even tho there are node_errors it can still be run status = 400 - await update_run_with_output(prompt_id, { - "error": { - **res - } - }) + await update_run_with_output(prompt_id, {"error": {**res}}) # When there are critical errors, the prompt is actually not run if "error" in res: @@ -498,19 +555,23 @@ async def stream_prompt(data, token): return res # return web.json_response(res, status=status) + comfy_message_queues: Dict[str, asyncio.Queue] = {} -@server.PromptServer.instance.routes.post('/comfyui-deploy/run/streaming') + +@server.PromptServer.instance.routes.post("/comfyui-deploy/run/streaming") async def stream_response(request): - response = web.StreamResponse(status=200, reason='OK', headers={'Content-Type': 'text/event-stream'}) + response = web.StreamResponse( + status=200, reason="OK", headers={"Content-Type": "text/event-stream"} + ) await response.prepare(request) - + # Extract the bearer token from the Authorization header - auth_header = request.headers.get('Authorization') + auth_header = request.headers.get("Authorization") token = None if auth_header: parts = auth_header.split() - if len(parts) == 2 and parts[0].lower() == 'bearer': + if len(parts) == 2 and parts[0].lower() == "bearer": token = parts[1] pending = True @@ -519,12 +580,14 @@ async def stream_response(request): prompt_id = data.get("prompt_id") comfy_message_queues[prompt_id] = asyncio.Queue() - with log_span('Streaming Run'): - log('info', 'Streaming prompt') + with log_span("Streaming Run"): + log("info", "Streaming prompt") try: result = await stream_prompt(data=data, token=token) - await response.write(f"event: event_update\ndata: {json.dumps(result)}\n\n".encode('utf-8')) + await response.write( + f"event: event_update\ndata: {json.dumps(result)}\n\n".encode("utf-8") + ) # await response.write(.encode('utf-8')) await response.drain() # Ensure the buffer is flushed @@ -535,39 +598,52 @@ async def stream_response(request): # log('info', data["event"], data=json.dumps(data)) # logger.info("listener", data) - await response.write(f"event: event_update\ndata: {json.dumps(data)}\n\n".encode('utf-8')) + await response.write( + f"event: event_update\ndata: {json.dumps(data)}\n\n".encode( + "utf-8" + ) + ) await response.drain() # Ensure the buffer is flushed if data["event"] == "status": - if data["data"]["status"] in (Status.FAILED.value, Status.SUCCESS.value): + if data["data"]["status"] in ( + Status.FAILED.value, + Status.SUCCESS.value, + ): pending = False await asyncio.sleep(0.1) # Adjust the sleep duration as needed except asyncio.CancelledError: - log('info', "Streaming was cancelled") + log("info", "Streaming was cancelled") raise except Exception as e: - log('error', "Streaming error", error=e) + log("error", "Streaming error", error=e) 
finally: # event_emitter.off("send_json", task) await response.write_eof() comfy_message_queues.pop(prompt_id, None) return response + def get_comfyui_path_from_file_path(file_path): file_path_parts = file_path.split("\\") if file_path_parts[0] == "input": logger.info("matching input") - file_path = os.path.join(folder_paths.get_directory_by_type("input"), *file_path_parts[1:]) + file_path = os.path.join( + folder_paths.get_directory_by_type("input"), *file_path_parts[1:] + ) elif file_path_parts[0] == "models": logger.info("matching models") - file_path = folder_paths.get_full_path(file_path_parts[1], os.path.join(*file_path_parts[2:])) + file_path = folder_paths.get_full_path( + file_path_parts[1], os.path.join(*file_path_parts[2:]) + ) logger.info(file_path) return file_path + # Form ComfyUI Manager async def compute_sha256_checksum(filepath): logger.info("computing sha256 checksum") @@ -575,7 +651,7 @@ async def compute_sha256_checksum(filepath): filepath = get_comfyui_path_from_file_path(filepath) """Compute the SHA256 checksum of a file, in chunks, asynchronously""" sha256 = hashlib.sha256() - async with aiofiles.open(filepath, 'rb') as f: + async with aiofiles.open(filepath, "rb") as f: while True: chunk = await f.read(chunk_size) if not chunk: @@ -583,7 +659,8 @@ async def compute_sha256_checksum(filepath): sha256.update(chunk) return sha256.hexdigest() -@server.PromptServer.instance.routes.get('/comfyui-deploy/models') + +@server.PromptServer.instance.routes.get("/comfyui-deploy/models") async def get_installed_models(request): # Directly return the list of paths as JSON new_dict = {} @@ -596,8 +673,9 @@ async def get_installed_models(request): # logger.info(new_dict) return web.json_response(new_dict) + # This is start uploading the files to Comfy Deploy -@server.PromptServer.instance.routes.post('/comfyui-deploy/upload-file') +@server.PromptServer.instance.routes.post("/comfyui-deploy/upload-file") async def upload_file_endpoint(request): data = await request.json() @@ -622,96 +700,111 @@ async def upload_file_endpoint(request): file_size = os.path.getsize(file_path) file_extension = os.path.splitext(file_path)[1] - if file_extension in ['.jpg', '.jpeg']: - file_type = 'image/jpeg' - elif file_extension == '.png': - file_type = 'image/png' - elif file_extension == '.webp': - file_type = 'image/webp' + if file_extension in [".jpg", ".jpeg"]: + file_type = "image/jpeg" + elif file_extension == ".png": + file_type = "image/png" + elif file_extension == ".webp": + file_type = "image/webp" else: - file_type = 'application/octet-stream' # Default to binary file type if unknown + file_type = ( + "application/octet-stream" # Default to binary file type if unknown + ) else: - return web.json_response({ - "error": f"File not found: {file_path}" - }, status=404) + return web.json_response( + {"error": f"File not found: {file_path}"}, status=404 + ) except Exception as e: - return web.json_response({ - "error": str(e) - }, status=500) + return web.json_response({"error": str(e)}, status=500) if get_url: try: - headers = {'Authorization': f'Bearer {token}'} - params = {'file_size': file_size, 'type': file_type} - response = await async_request_with_retry('GET', get_url, params=params, headers=headers) + headers = {"Authorization": f"Bearer {token}"} + params = {"file_size": file_size, "type": file_type} + response = await async_request_with_retry( + "GET", get_url, params=params, headers=headers + ) if response.status == 200: content = await response.json() upload_url = content["upload_url"] 
-                with open(file_path, 'rb') as f:
+                with open(file_path, "rb") as f:
                     headers = {
                         "Content-Type": file_type,
                         # "Content-Length": str(file_size)
                     }
-                    if content.get('include_acl') is True:
+                    if content.get("include_acl") is True:
                         headers["x-amz-acl"] = "public-read"
-                    upload_response = await async_request_with_retry('PUT', upload_url, data=f, headers=headers)
+                    upload_response = await async_request_with_retry(
+                        "PUT", upload_url, data=f, headers=headers
+                    )
                     if upload_response.status == 200:
-                        return web.json_response({
-                            "message": "File uploaded successfully",
-                            "download_url": content["download_url"]
-                        })
+                        return web.json_response(
+                            {
+                                "message": "File uploaded successfully",
+                                "download_url": content["download_url"],
+                            }
+                        )
                     else:
-                        return web.json_response({
-                            "error": f"Failed to upload file to {upload_url}. Status code: {upload_response.status}"
-                        }, status=upload_response.status)
+                        return web.json_response(
+                            {
+                                "error": f"Failed to upload file to {upload_url}. Status code: {upload_response.status}"
+                            },
+                            status=upload_response.status,
+                        )
             else:
-                return web.json_response({
-                    "error": f"Failed to fetch data from {get_url}. Status code: {response.status}"
-                }, status=response.status)
+                return web.json_response(
+                    {
+                        "error": f"Failed to fetch data from {get_url}. Status code: {response.status}"
+                    },
+                    status=response.status,
+                )
         except Exception as e:
-            return web.json_response({
-                "error": f"An error occurred while fetching data from {get_url}: {str(e)}"
-            }, status=500)
+            return web.json_response(
+                {
+                    "error": f"An error occurred while fetching data from {get_url}: {str(e)}"
+                },
+                status=500,
+            )
 
-    return web.json_response({
-        "error": f"File not uploaded"
-    }, status=500)
+    return web.json_response({"error": f"File not uploaded"}, status=500)
 
 
 script_dir = os.path.dirname(os.path.abspath(__file__))
 
 # Assuming the cache file is stored in the same directory as this script
-CACHE_FILE_PATH = script_dir + '/file-hash-cache.json'
+CACHE_FILE_PATH = script_dir + "/file-hash-cache.json"
 
 # Global in-memory cache
 file_hash_cache = {}
 
+
 # Load cache from disk at startup
 def load_cache():
     global file_hash_cache
     try:
-        with open(CACHE_FILE_PATH, 'r') as cache_file:
+        with open(CACHE_FILE_PATH, "r") as cache_file:
            file_hash_cache = json.load(cache_file)
    except (FileNotFoundError, json.JSONDecodeError):
        file_hash_cache = {}
 
+
 # Save cache to disk
 def save_cache():
-    with open(CACHE_FILE_PATH, 'w') as cache_file:
+    with open(CACHE_FILE_PATH, "w") as cache_file:
        json.dump(file_hash_cache, cache_file)
 
+
 # Initialize cache on application start
 load_cache()
 
-@server.PromptServer.instance.routes.get('/comfyui-deploy/get-file-hash')
+
+@server.PromptServer.instance.routes.get("/comfyui-deploy/get-file-hash")
 async def get_file_hash(request):
-    file_path = request.rel_url.query.get('file_path', '')
+    file_path = request.rel_url.query.get("file_path", "")
 
     if not file_path:
-        return web.json_response({
-            "error": "file_path is required"
-        }, status=400)
+        return web.json_response({"error": "file_path is required"}, status=400)
 
     try:
         base = folder_paths.base_path
@@ -732,29 +825,33 @@ async def get_file_hash(request):
 
             save_cache()
 
-        return web.json_response({
-            "file_hash": file_hash
-        })
+        return web.json_response({"file_hash": file_hash})
    except Exception as e:
-        return web.json_response({
-            "error": str(e)
-        }, status=500)
+        return web.json_response({"error": str(e)}, status=500)
 
-async def update_realtime_run_status(realtime_id: str, status_endpoint: str, status: Status):
+
+async def update_realtime_run_status(
+    realtime_id: str,
+    status_endpoint: str,
+    status: Status,
+    gpu_event_id: str | None = None,
+):
     body = {
         "run_id": realtime_id,
         "status": status.value,
+        "gpu_event_id": gpu_event_id,
     }
-    if (status_endpoint is None):
+    if status_endpoint is None:
         return
     # requests.post(status_endpoint, json=body)
-    await async_request_with_retry('POST', status_endpoint, json=body)
+    await async_request_with_retry("POST", status_endpoint, json=body)
 
-@server.PromptServer.instance.routes.get('/comfyui-deploy/ws')
+
+@server.PromptServer.instance.routes.get("/comfyui-deploy/ws")
 async def websocket_handler(request):
     ws = web.WebSocketResponse()
     await ws.prepare(request)
-    sid = request.rel_url.query.get('clientId', '')
+    sid = request.rel_url.query.get("clientId", "")
     if sid:
         # Reusing existing session, remove old
         sockets.pop(sid, None)
@@ -763,14 +860,16 @@
     sockets[sid] = ws
 
-    auth_token = request.rel_url.query.get('token', None)
-    get_workflow_endpoint_url = request.rel_url.query.get('workflow_endpoint', None)
-    realtime_id = request.rel_url.query.get('realtime_id', None)
-    status_endpoint = request.rel_url.query.get('status_endpoint', None)
+    auth_token = request.rel_url.query.get("token", None)
+    get_workflow_endpoint_url = request.rel_url.query.get("workflow_endpoint", None)
+    realtime_id = request.rel_url.query.get("realtime_id", None)
+    status_endpoint = request.rel_url.query.get("status_endpoint", None)
 
     if auth_token is not None and get_workflow_endpoint_url is not None:
-        headers = {'Authorization': f'Bearer {auth_token}'}
-        response = await async_request_with_retry('GET', get_workflow_endpoint_url, headers=headers)
+        headers = {"Authorization": f"Bearer {auth_token}"}
+        response = await async_request_with_retry(
+            "GET", get_workflow_endpoint_url, headers=headers
+        )
         if response.status == 200:
             workflow = await response.json()
 
@@ -781,19 +880,25 @@ async def websocket_handler(request):
                 auth_token=auth_token,
                 inputs={},
                 status_endpoint=status_endpoint,
-                file_upload_endpoint=request.rel_url.query.get('file_upload_endpoint', None),
+                file_upload_endpoint=request.rel_url.query.get(
+                    "file_upload_endpoint", None
+                ),
             )
 
-            await update_realtime_run_status(realtime_id, status_endpoint, Status.RUNNING)
+            await update_realtime_run_status(
+                realtime_id, status_endpoint, Status.RUNNING
+            )
             # await send("workflow_api", workflow_api, sid)
         else:
             error_message = await response.text()
-            logger.info(f"Failed to fetch workflow endpoint. Status: {response.status}, Error: {error_message}")
+            logger.info(
+                f"Failed to fetch workflow endpoint. Status: {response.status}, Error: {error_message}"
Status: {response.status}, Error: {error_message}" + ) # await send("error", {"message": error_message}, sid) try: # Send initial state to the new client - await send("status", { 'sid': sid }, sid) + await send("status", {"sid": sid}, sid) # Make sure when its connected via client, the full log is not being sent if cd_enable_log and get_workflow_endpoint_url is None: @@ -804,27 +909,31 @@ async def websocket_handler(request): try: data = json.loads(msg.data) logger.info(data) - event_type = data.get('event') - if event_type == 'input': + event_type = data.get("event") + if event_type == "input": logger.info(f"Got input: ${data.get('inputs')}") - input = data.get('inputs') + input = data.get("inputs") streaming_prompt_metadata[sid].inputs.update(input) - elif event_type == 'queue_prompt': + elif event_type == "queue_prompt": clear_current_prompt(sid) send_prompt(sid, streaming_prompt_metadata[sid]) else: # Handle other event types pass except json.JSONDecodeError: - logger.info('Failed to decode JSON from message') + logger.info("Failed to decode JSON from message") if msg.type == aiohttp.WSMsgType.BINARY: data = msg.data - event_type, = struct.unpack(" 0: - logger.info(f"Have pending upload {len(prompt_metadata[prompt_id].uploading_nodes)}") + if ( + prompt_id in prompt_metadata + and len(prompt_metadata[prompt_id].uploading_nodes) > 0 + ): + logger.info( + f"Have pending upload {len(prompt_metadata[prompt_id].uploading_nodes)}" + ) return True logger.info("No pending upload") return False + def mark_prompt_done(prompt_id): """ Mark the prompt as done in the prompt metadata. @@ -1266,6 +1449,7 @@ def mark_prompt_done(prompt_id): prompt_metadata[prompt_id].done = True logger.info("Prompt done") + def is_prompt_done(prompt_id: str): """ Check if the prompt with the given ID is marked as done. 
@@ -1281,6 +1465,7 @@ def is_prompt_done(prompt_id: str):
 
     return False
 
+
 # Use to handle upload error and send back to ComfyDeploy
 async def handle_error(prompt_id, data, e: Exception):
     error_type = type(e).__name__
@@ -1288,19 +1473,18 @@ async def handle_error(prompt_id, data, e: Exception):
     body = {
         "run_id": prompt_id,
         "output_data": {
-            "error": {
-                "type": error_type,
-                "message": str(e),
-                "stack_trace": stack_trace
-            }
-        }
+            "error": {"type": error_type, "message": str(e), "stack_trace": stack_trace}
+        },
     }
     await update_file_status(prompt_id, data, False, have_error=True)
     logger.info(body)
     logger.info(f"Error occurred while uploading file: {e}")
 
+
 # Mark the current prompt requires upload, and block it from being marked as success
-async def update_file_status(prompt_id: str, data, uploading, have_error=False, node_id=None):
+async def update_file_status(
+    prompt_id: str, data, uploading, have_error=False, node_id=None
+):
     # if 'uploading_nodes' not in prompt_metadata[prompt_id]:
     #     prompt_metadata[prompt_id]['uploading_nodes'] = set()
@@ -1315,28 +1499,44 @@ async def update_file_status(prompt_id: str, data, uploading, have_error=False,
     if have_error:
         await update_run(prompt_id, Status.FAILED)
-        await send("failed", {
-            "prompt_id": prompt_id,
-        })
+        await send(
+            "failed",
+            {
+                "prompt_id": prompt_id,
+            },
+        )
         return
 
     # if there are still nodes that are uploading, then we set the status to uploading
     if uploading:
         if prompt_metadata[prompt_id].status != Status.UPLOADING:
             await update_run(prompt_id, Status.UPLOADING)
-            await send("uploading", {
-                "prompt_id": prompt_id,
-            })
+            await send(
+                "uploading",
+                {
+                    "prompt_id": prompt_id,
+                },
+            )
 
     # if there are no nodes that are uploading, then we set the status to success
-    elif not uploading and not have_pending_upload(prompt_id) and is_prompt_done(prompt_id=prompt_id):
+    elif (
+        not uploading
+        and not have_pending_upload(prompt_id)
+        and is_prompt_done(prompt_id=prompt_id)
+    ):
         await update_run(prompt_id, Status.SUCCESS)
         # logger.info("Status: SUCCUSS")
-        await send("success", {
-            "prompt_id": prompt_id,
-        })
+        await send(
+            "success",
+            {
+                "prompt_id": prompt_id,
+            },
+        )
 
-async def handle_upload(prompt_id: str, data, key: str, content_type_key: str, default_content_type: str):
+
+async def handle_upload(
+    prompt_id: str, data, key: str, content_type_key: str, default_content_type: str
+):
     items = data.get(key, [])
     upload_tasks = []
 
@@ -1347,53 +1547,66 @@ async def handle_upload(prompt_id: str, data, key: str, content_type_key: str, d
         file_type = item.get(content_type_key, default_content_type)
         file_extension = os.path.splitext(item.get("filename"))[1]
-        if file_extension in ['.jpg', '.jpeg']:
-            file_type = 'image/jpeg'
-        elif file_extension == '.png':
-            file_type = 'image/png'
-        elif file_extension == '.webp':
-            file_type = 'image/webp'
+        if file_extension in [".jpg", ".jpeg"]:
+            file_type = "image/jpeg"
+        elif file_extension == ".png":
+            file_type = "image/png"
+        elif file_extension == ".webp":
+            file_type = "image/webp"
 
-        upload_tasks.append(upload_file(
-            prompt_id,
-            item.get("filename"),
-            subfolder=item.get("subfolder"),
-            type=item.get("type"),
-            content_type=file_type,
-            item=item
-        ))
+        upload_tasks.append(
+            upload_file(
+                prompt_id,
+                item.get("filename"),
+                subfolder=item.get("subfolder"),
+                type=item.get("type"),
+                content_type=file_type,
+                item=item,
+            )
+        )
 
     # Execute all upload tasks concurrently
     await asyncio.gather(*upload_tasks)
 
-# Upload files in the background
-async def upload_in_background(prompt_id: str, data, node_id=None, have_upload=True, node_meta=None):
+
+async def upload_in_background(
+    prompt_id: str, data, node_id=None, have_upload=True, node_meta=None
+):
     try:
         upload_tasks = [
-            handle_upload(prompt_id, data, 'images', "content_type", "image/png"),
-            handle_upload(prompt_id, data, 'files', "content_type", "image/png"),
-            handle_upload(prompt_id, data, 'gifs', "format", "image/gif"),
-            handle_upload(prompt_id, data, 'mesh', "format", "application/octet-stream")
+            handle_upload(prompt_id, data, "images", "content_type", "image/png"),
+            handle_upload(prompt_id, data, "files", "content_type", "image/png"),
+            handle_upload(prompt_id, data, "gifs", "format", "image/gif"),
+            handle_upload(
+                prompt_id, data, "mesh", "format", "application/octet-stream"
+            ),
         ]
 
         await asyncio.gather(*upload_tasks)
 
         status_endpoint = prompt_metadata[prompt_id].status_endpoint
         token = prompt_metadata[prompt_id].token
+        gpu_event_id = prompt_metadata[prompt_id].gpu_event_id or None
 
         if have_upload:
             if status_endpoint is not None:
                 body = {
                     "run_id": prompt_id,
                     "output_data": data,
                     "node_meta": node_meta,
+                    "gpu_event_id": gpu_event_id,
                 }
                 # pprint(body)
-                await async_request_with_retry('POST', status_endpoint, token=token, json=body)
+                await async_request_with_retry(
+                    "POST", status_endpoint, token=token, json=body
+                )
 
             await update_file_status(prompt_id, data, False, node_id=node_id)
     except Exception as e:
         await handle_error(prompt_id, data, e)
 
-async def update_run_with_output(prompt_id, data, node_id=None, node_meta=None):
+
+async def update_run_with_output(
+    prompt_id, data, node_id=None, node_meta=None, gpu_event_id=None
+):
     if prompt_id not in prompt_metadata:
         return
@@ -1406,10 +1619,15 @@ async def update_run_with_output(prompt_id, data, node_id=None, node_meta=None):
         "run_id": prompt_id,
         "output_data": data,
         "node_meta": node_meta,
+        "gpu_event_id": gpu_event_id,
     }
-    have_upload_media = 'images' in data or 'files' in data or 'gifs' in data or 'mesh' in data
+    have_upload_media = (
+        "images" in data or "files" in data or "gifs" in data or "mesh" in data
+    )
     if bypass_upload and have_upload_media:
-        print("CD_BYPASS_UPLOAD is enabled, skipping the upload of the output:", node_id)
+        print(
+            "CD_BYPASS_UPLOAD is enabled, skipping the upload of the output:", node_id
+        )
         return
 
     if have_upload_media:
@@ -1419,7 +1637,15 @@ async def update_run_with_output(prompt_id, data, node_id=None, node_meta=None):
             if have_upload_media:
                 await update_file_status(prompt_id, data, True, node_id=node_id)
 
-            asyncio.create_task(upload_in_background(prompt_id, data, node_id=node_id, have_upload=have_upload_media, node_meta=node_meta))
+            asyncio.create_task(
+                upload_in_background(
+                    prompt_id,
+                    data,
+                    node_id=node_id,
+                    have_upload=have_upload_media,
+                    node_meta=node_meta,
+                )
+            )
             # await upload_in_background(prompt_id, data, node_id=node_id, have_upload=have_upload)
         except Exception as e:
@@ -1427,22 +1653,22 @@ async def update_run_with_output(prompt_id, data, node_id=None, node_meta=None):
         # requests.post(status_endpoint, json=body)
     elif status_endpoint is not None:
         token = prompt_metadata[prompt_id].token
-        await async_request_with_retry('POST', status_endpoint, token=token, json=body)
+        await async_request_with_retry("POST", status_endpoint, token=token, json=body)
+
+    await send("outputs_uploaded", {"prompt_id": prompt_id})
 
-    await send('outputs_uploaded', {
-        "prompt_id": prompt_id
-    })
 
 prompt_server.send_json_original = prompt_server.send_json
 prompt_server.send_json = send_json_override.__get__(prompt_server, server.PromptServer)
 
 root_path = os.path.dirname(os.path.abspath(__file__))
 two_dirs_up = os.path.dirname(os.path.dirname(root_path))
-log_file_path = os.path.join(two_dirs_up, 'comfy-deploy.log')
-comfyui_file_path = os.path.join(two_dirs_up, 'comfyui.log')
+log_file_path = os.path.join(two_dirs_up, "comfy-deploy.log")
+comfyui_file_path = os.path.join(two_dirs_up, "comfyui.log")
 
 last_read_line = 0
 
+
 async def watch_file_changes(file_path, callback):
     global last_read_line
     last_modified_time = os.stat(file_path).st_mtime
@@ -1451,39 +1677,45 @@ async def watch_file_changes(file_path, callback):
         modified_time = os.stat(file_path).st_mtime
         if modified_time != last_modified_time:
             last_modified_time = modified_time
-            with open(file_path, 'r') as file:
+            with open(file_path, "r") as file:
                 lines = file.readlines()
             if last_read_line > len(lines):
                 last_read_line = 0  # Reset if log file has been rotated
             new_lines = lines[last_read_line:]
             last_read_line = len(lines)
             if new_lines:
-                await callback(''.join(new_lines))
+                await callback("".join(new_lines))
 
 
 async def send_first_time_log(sid):
-    with open(log_file_path, 'r') as file:
+    with open(log_file_path, "r") as file:
         lines = file.readlines()
-    await send("LOGS", ''.join(lines), sid)
+    await send("LOGS", "".join(lines), sid)
 
+
 async def send_logs_to_websocket(logs):
     await send("LOGS", logs)
 
+
 def start_loop(loop):
     asyncio.set_event_loop(loop)
     loop.run_forever()
 
+
 def run_in_new_thread(coroutine):
     new_loop = asyncio.new_event_loop()
     t = threading.Thread(target=start_loop, args=(new_loop,), daemon=True)
     t.start()
     asyncio.run_coroutine_threadsafe(coroutine, new_loop)
 
+
 if cd_enable_log:
     run_in_new_thread(watch_file_changes(log_file_path, send_logs_to_websocket))
 
+
 # use after calling GET /object_info (it populates the `filename_list_cache` variable)
 @server.PromptServer.instance.routes.get("/comfyui-deploy/filename_list_cache")
 async def get_filename_list_cache(_):
     from folder_paths import filename_list_cache
-    return web.json_response({'filename_list': filename_list_cache})
\ No newline at end of file
+
+    return web.json_response({"filename_list": filename_list_cache})