diff --git a/builder/modal-builder/src/main.py b/builder/modal-builder/src/main.py index f717d7c..1f58d8f 100644 --- a/builder/modal-builder/src/main.py +++ b/builder/modal-builder/src/main.py @@ -15,6 +15,9 @@ import signal import logging from fastapi.logger import logger as fastapi_logger import requests +from urllib.parse import parse_qs +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.types import ASGIApp, Scope, Receive, Send from concurrent.futures import ThreadPoolExecutor @@ -41,6 +44,34 @@ global_timeout = 60 * 4 machine_id_websocket_dict = {} machine_id_status = {} +fly_instance_id = os.environ.get('FLY_ALLOC_ID', 'local').split('-')[0] + +class FlyReplayMiddleware(BaseHTTPMiddleware): + """ + If the wrong instance was picked by the fly.io load balancer we use the fly-replay header + to repeat the request again on the right instance. + + This only works if the right instance is provided as a query_string parameter. + """ + def __init__(self, app: ASGIApp) -> None: + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + query_string = scope.get('query_string', b'').decode() + query_params = parse_qs(query_string) + target_instance = query_params.get('fly_instance_id', [fly_instance_id])[0] + async def send_wrapper(message): + if target_instance != fly_instance_id: + if message['type'] == 'websocket.close' and 'Invalid session' in message['reason']: + # fly.io only seems to look at the fly-replay header if websocket is accepted + message = {'type': 'websocket.accept'} + if 'headers' not in message: + message['headers'] = [] + message['headers'].append([b'fly-replay', f'instance={target_instance}'.encode()]) + await send(message) + await self.app(scope, receive, send_wrapper) + + async def check_inactivity(): global last_activity_time while True: @@ -49,7 +80,8 @@ async def check_inactivity(): if len(machine_id_status) == 0: # The application has been inactive for more than 60 seconds. # Scale it down to zero here. - logger.info(f"No activity for {global_timeout} seconds, exiting...") + logger.info( + f"No activity for {global_timeout} seconds, exiting...") # os._exit(0) os.kill(os.getpid(), signal.SIGINT) break @@ -66,11 +98,12 @@ async def lifespan(app: FastAPI): yield logger.info("Cancelling") -# +# app = FastAPI(lifespan=lifespan) - +app.add_middleware(FlyReplayMiddleware) # MODAL_ORG = os.environ.get("MODAL_ORG") + @app.get("/") def read_root(): global last_activity_time @@ -97,14 +130,17 @@ def read_root(): # } # } + class GitCustomNodes(BaseModel): hash: str disabled: bool + class Snapshot(BaseModel): comfyui: str git_custom_nodes: Dict[str, GitCustomNodes] + class Model(BaseModel): name: str type: str @@ -115,12 +151,14 @@ class Model(BaseModel): filename: str url: str + class GPUType(str, Enum): T4 = "T4" A10G = "A10G" A100 = "A100" L4 = "L4" + class Item(BaseModel): machine_id: str name: str @@ -133,7 +171,8 @@ class Item(BaseModel): @classmethod def check_gpu(cls, value): if value not in GPUType.__members__: - raise ValueError(f"Invalid GPU option. Choose from: {', '.join(GPUType.__members__.keys())}") + raise ValueError( + f"Invalid GPU option. Choose from: {', '.join(GPUType.__members__.keys())}") return GPUType(value) @@ -143,9 +182,11 @@ async def websocket_endpoint(websocket: WebSocket, machine_id: str): machine_id_websocket_dict[machine_id] = websocket # Send existing logs if machine_id in machine_logs_cache: + combined_logs = "\n".join( + log_entry['logs'] for log_entry in machine_logs_cache[machine_id]) await websocket.send_text(json.dumps({"event": "LOGS", "data": { "machine_id": machine_id, - "logs": json.dumps(machine_logs_cache[machine_id]) , + "logs": combined_logs, "timestamp": time.time() }})) try: @@ -173,6 +214,7 @@ async def websocket_endpoint(websocket: WebSocket, machine_id: str): # return {"Hello": "World"} + @app.post("/create") async def create_item(item: Item): global last_activity_time @@ -185,13 +227,14 @@ async def create_item(item: Item): # Run the building logic in a separate thread # future = executor.submit(build_logic, item) task = asyncio.create_task(build_logic(item)) - - return JSONResponse(status_code=200, content={"message": "Build Queued"}) + + return JSONResponse(status_code=200, content={"message": "Build Queued", "build_machine_instance_id": fly_instance_id}) # Initialize the logs cache machine_logs_cache = {} + async def build_logic(item: Item): # Deploy to modal folder_path = f"/app/builds/{item.machine_id}" @@ -239,16 +282,18 @@ async def build_logic(item: Item): if item.machine_id not in machine_logs_cache: machine_logs_cache[item.machine_id] = [] - + machine_logs = machine_logs_cache[item.machine_id] - async def read_stream(stream, isStderr): - while True: - line = await stream.readline() - if line: + url_queue = asyncio.Queue() + + async def read_stream(stream, isStderr, url_queue: asyncio.Queue): + while True: + line = await stream.readline() + if line: l = line.decode('utf-8').strip() - if l == "": + if l == "": continue if not isStderr: @@ -265,12 +310,12 @@ async def build_logic(item: Item): "timestamp": time.time() }})) - - if "Created comfyui_app =>" in l or (l.startswith("https://") and l.endswith(".modal.run")): - if "Created comfyui_app =>" in l: + if "Created comfyui_api =>" in l or (l.startswith("https://") and l.endswith(".modal.run")): + if "Created comfyui_api =>" in l: url = l.split("=>")[1].strip() - else: - # Some case it only prints the url on a blank line + # making sure it is a url + elif "comfyui_api" in l: + # Some case it only prints the url on a blank line url = l if url: @@ -279,6 +324,8 @@ async def build_logic(item: Item): "timestamp": time.time() }) + await url_queue.put(url) + if item.machine_id in machine_id_websocket_dict: await machine_id_websocket_dict[item.machine_id].send_text(json.dumps({"event": "LOGS", "data": { "machine_id": item.machine_id, @@ -288,7 +335,7 @@ async def build_logic(item: Item): await machine_id_websocket_dict[item.machine_id].send_text(json.dumps({"event": "FINISHED", "data": { "status": "succuss", }})) - + else: # is error logger.error(l) @@ -306,11 +353,15 @@ async def build_logic(item: Item): await machine_id_websocket_dict[item.machine_id].send_text(json.dumps({"event": "FINISHED", "data": { "status": "failed", }})) - else: - break + else: + break - stdout_task = asyncio.create_task(read_stream(process.stdout, False)) - stderr_task = asyncio.create_task(read_stream(process.stderr, True)) + stdout_task = asyncio.create_task( + read_stream(process.stdout, False, url_queue)) + stderr_task = asyncio.create_task( + read_stream(process.stderr, True, url_queue)) + + url = await url_queue.get() await asyncio.wait([stdout_task, stderr_task]) @@ -334,8 +385,9 @@ async def build_logic(item: Item): "logs": "Unable to build the app image.", "timestamp": time.time() }) - requests.post(item.callback_url, json={"machine_id": item.machine_id, "build_log": json.dumps(machine_logs)}) - + requests.post(item.callback_url, json={ + "machine_id": item.machine_id, "build_log": json.dumps(machine_logs)}) + if item.machine_id in machine_logs_cache: del machine_logs_cache[item.machine_id] @@ -349,7 +401,8 @@ async def build_logic(item: Item): "logs": "App image built, but url is None, unable to parse the url.", "timestamp": time.time() }) - requests.post(item.callback_url, json={"machine_id": item.machine_id, "build_log": json.dumps(machine_logs)}) + requests.post(item.callback_url, json={ + "machine_id": item.machine_id, "build_log": json.dumps(machine_logs)}) if item.machine_id in machine_logs_cache: del machine_logs_cache[item.machine_id] @@ -359,17 +412,20 @@ async def build_logic(item: Item): # example https://bennykok--my-app-comfyui-app.modal.run/ # my_url = f"https://{MODAL_ORG}--{item.container_id}-{app_suffix}.modal.run" - requests.post(item.callback_url, json={"machine_id": item.machine_id, "endpoint": url, "build_log": json.dumps(machine_logs)}) + requests.post(item.callback_url, json={ + "machine_id": item.machine_id, "endpoint": url, "build_log": json.dumps(machine_logs)}) if item.machine_id in machine_logs_cache: - del machine_logs_cache[item.machine_id] - + del machine_logs_cache[item.machine_id] + logger.info("done") logger.info(url) + def start_loop(loop): asyncio.set_event_loop(loop) loop.run_forever() + def run_in_new_thread(coroutine): new_loop = asyncio.new_event_loop() t = threading.Thread(target=start_loop, args=(new_loop,), daemon=True) @@ -377,6 +433,7 @@ def run_in_new_thread(coroutine): asyncio.run_coroutine_threadsafe(coroutine, new_loop) return t + if __name__ == "__main__": import uvicorn # , log_level="debug" diff --git a/builder/modal-builder/src/template/app.py b/builder/modal-builder/src/template/app.py index 0e0fd04..96a7a70 100644 --- a/builder/modal-builder/src/template/app.py +++ b/builder/modal-builder/src/template/app.py @@ -1,3 +1,4 @@ +from config import config import modal from modal import Image, Mount, web_endpoint, Stub, asgi_app import json @@ -12,7 +13,6 @@ from fastapi.responses import HTMLResponse import os current_directory = os.path.dirname(os.path.realpath(__file__)) -from config import config deploy_test = config["deploy_test"] == "True" # MODAL_IMAGE_ID = os.environ.get('MODAL_IMAGE_ID', None) @@ -30,8 +30,41 @@ print("deploy_test ", deploy_test) stub = Stub(name=config["name"]) if not deploy_test: - dockerfile_image = Image.from_dockerfile(f"{current_directory}/Dockerfile", context_mount=Mount.from_local_dir(f"{current_directory}/data", remote_path="/data")) + # dockerfile_image = Image.from_dockerfile(f"{current_directory}/Dockerfile", context_mount=Mount.from_local_dir(f"{current_directory}/data", remote_path="/data")) + # dockerfile_image = Image.from_dockerfile(f"{current_directory}/Dockerfile", context_mount=Mount.from_local_dir(f"{current_directory}/data", remote_path="/data")) + dockerfile_image = ( + modal.Image.debian_slim() + .apt_install("git", "wget") + .run_commands( + # Basic comfyui setup + "git clone https://github.com/comfyanonymous/ComfyUI.git /comfyui", + "cd /comfyui && pip install xformers!=0.0.18 -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu121", + + # Install comfyui manager + "cd /comfyui/custom_nodes && git clone --depth 1 https://github.com/ltdrdata/ComfyUI-Manager.git", + "cd /comfyui/custom_nodes/ComfyUI-Manager && pip install -r requirements.txt", + "cd /comfyui/custom_nodes/ComfyUI-Manager && mkdir startup-scripts", + + # Install comfy deploy + "cd /comfyui/custom_nodes && git clone https://github.com/BennyKok/comfyui-deploy.git", + ) + .copy_local_file(f"{current_directory}/data/extra_model_paths.yaml", "/comfyui") + .copy_local_file(f"{current_directory}/data/snapshot.json", "/comfyui/custom_nodes/ComfyUI-Manager/startup-scripts/restore-snapshot.json") + + .copy_local_file(f"{current_directory}/data/start.sh", "/start.sh") + .run_commands("chmod +x /start.sh") + + .copy_local_file(f"{current_directory}/data/install_deps.py", "/") + .copy_local_file(f"{current_directory}/data/models.json", "/") + .copy_local_file(f"{current_directory}/data/deps.json", "/") + + .run_commands("python install_deps.py") + + .pip_install( + "git+https://github.com/modal-labs/asgiproxy.git", "httpx", "tqdm" + ) + ) # Time to wait between API check attempts in milliseconds COMFY_API_AVAILABLE_INTERVAL_MS = 50 @@ -44,6 +77,7 @@ COMFY_POLLING_MAX_RETRIES = 500 # Host where ComfyUI is running COMFY_HOST = "127.0.0.1:8188" + def check_server(url, retries=50, delay=500): import requests import time @@ -71,7 +105,6 @@ def check_server(url, retries=50, delay=500): # If an exception occurs, the server may not be ready pass - # print(f"runpod-worker-comfy - trying") # Wait for the specified delay before retrying @@ -82,29 +115,37 @@ def check_server(url, retries=50, delay=500): ) return False + def check_status(prompt_id): - req = urllib.request.Request(f"http://{COMFY_HOST}/comfyui-deploy/check-status?prompt_id={prompt_id}") + req = urllib.request.Request( + f"http://{COMFY_HOST}/comfyui-deploy/check-status?prompt_id={prompt_id}") return json.loads(urllib.request.urlopen(req).read()) + class Input(BaseModel): prompt_id: str workflow_api: dict status_endpoint: str file_upload_endpoint: str + def queue_workflow_comfy_deploy(data: Input): data_str = data.json() - data_bytes = data_str.encode('utf-8') - req = urllib.request.Request(f"http://{COMFY_HOST}/comfyui-deploy/run", data=data_bytes) + data_bytes = data_str.encode('utf-8') + req = urllib.request.Request( + f"http://{COMFY_HOST}/comfyui-deploy/run", data=data_bytes) return json.loads(urllib.request.urlopen(req).read()) + class RequestInput(BaseModel): input: Input + image = Image.debian_slim() target_image = image if deploy_test else dockerfile_image + @stub.function(image=target_image, gpu=config["gpu"]) def run(input: Input): import subprocess @@ -112,8 +153,9 @@ def run(input: Input): # Make sure that the ComfyUI API is available print(f"comfy-modal - check server") - command = ["python3", "/comfyui/main.py", "--disable-auto-launch", "--disable-metadata"] - server_process = subprocess.Popen(command) + command = ["python", "main.py", + "--disable-auto-launch", "--disable-metadata"] + server_process = subprocess.Popen(command, cwd="/comfyui") check_server( f"http://{COMFY_HOST}", @@ -128,7 +170,8 @@ def run(input: Input): # Queue the workflow try: # job_input is the json input - queued_workflow = queue_workflow_comfy_deploy(job_input) # queue_workflow(workflow) + queued_workflow = queue_workflow_comfy_deploy( + job_input) # queue_workflow(workflow) prompt_id = queued_workflow["prompt_id"] print(f"comfy-modal - queued workflow with ID {prompt_id}") except Exception as e: @@ -170,11 +213,12 @@ def run(input: Input): # Get the generated image and return it as URL in an AWS bucket or as base64 # images_result = process_output_images(history[prompt_id].get("outputs"), job["id"]) # result = {**images_result, "refresh_worker": REFRESH_WORKER} - result = { "status": status } + result = {"status": status} return result print("Running remotely on Modal!") + @web_app.post("/run") async def bar(request_input: RequestInput): # print(request_input) @@ -182,7 +226,73 @@ async def bar(request_input: RequestInput): return run.remote(request_input.input) # pass + @stub.function(image=image) @asgi_app() +def comfyui_api(): + return web_app + + +HOST = "127.0.0.1" +PORT = "8188" + + +def spawn_comfyui_in_background(): + import socket + import subprocess + + process = subprocess.Popen( + [ + "python", + "main.py", + "--dont-print-server", + "--port", + PORT, + ], + cwd="/comfyui", + ) + + # Poll until webserver accepts connections before running inputs. + while True: + try: + socket.create_connection((HOST, int(PORT)), timeout=1).close() + print("ComfyUI webserver ready!") + break + except (socket.timeout, ConnectionRefusedError): + # Check if launcher webserving process has exited. + # If so, a connection can never be made. + retcode = process.poll() + if retcode is not None: + raise RuntimeError( + f"comfyui main.py exited unexpectedly with code {retcode}" + ) + + +@stub.function( + image=target_image, + gpu=config["gpu"], + # Allows 100 concurrent requests per container. + allow_concurrent_inputs=100, + # Restrict to 1 container because we want to our ComfyUI session state + # to be on a single container. + concurrency_limit=1, + timeout=10 * 60, +) +@asgi_app() def comfyui_app(): - return web_app \ No newline at end of file + from asgiproxy.config import BaseURLProxyConfigMixin, ProxyConfig + from asgiproxy.context import ProxyContext + from asgiproxy.simple_proxy import make_simple_proxy_app + + spawn_comfyui_in_background() + + config = type( + "Config", + (BaseURLProxyConfigMixin, ProxyConfig), + { + "upstream_base_url": f"http://{HOST}:{PORT}", + "rewrite_host_header": f"{HOST}:{PORT}", + }, + )() + + return make_simple_proxy_app(ProxyContext(config)) diff --git a/builder/modal-builder/src/template/data/install_deps.py b/builder/modal-builder/src/template/data/install_deps.py index 568afe6..b90a8a5 100644 --- a/builder/modal-builder/src/template/data/install_deps.py +++ b/builder/modal-builder/src/template/data/install_deps.py @@ -3,9 +3,9 @@ import requests import time import subprocess -command = ["python3", "/comfyui/main.py", "--disable-auto-launch", "--disable-metadata", "--cpu"] +command = ["python", "main.py", "--disable-auto-launch", "--disable-metadata", "--cpu"] # Start the server -server_process = subprocess.Popen(command) +server_process = subprocess.Popen(command, cwd="/comfyui") def check_server(url, retries=50, delay=500): for i in range(retries):