"""FastAPI control plane for building ComfyUI machines on Modal.

Runs on fly.io.  A fly-replay middleware pins requests to the instance that
owns a given build, a WebSocket endpoint streams build logs, and an
inactivity watchdog shuts the process down when nothing is running.
"""

from typing import Union, Optional, Dict, List
from pydantic import BaseModel, Field, field_validator
from fastapi import FastAPI, HTTPException, WebSocket, BackgroundTasks, WebSocketDisconnect
from fastapi.responses import JSONResponse
from fastapi.logger import logger as fastapi_logger
import os
from enum import Enum
import json
import subprocess
import time
from uuid import uuid4
from contextlib import asynccontextmanager
import asyncio
import threading
import signal
import logging
import requests
from urllib.parse import parse_qs
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp, Scope, Receive, Send
import modal
from concurrent.futures import ThreadPoolExecutor

# executor = ThreadPoolExecutor(max_workers=5)

# Route FastAPI / uvicorn logs through gunicorn's handlers when running under
# gunicorn so everything lands in one stream.
gunicorn_error_logger = logging.getLogger("gunicorn.error")
gunicorn_logger = logging.getLogger("gunicorn")
uvicorn_access_logger = logging.getLogger("uvicorn.access")
uvicorn_access_logger.handlers = gunicorn_error_logger.handlers
fastapi_logger.handlers = gunicorn_error_logger.handlers

if __name__ != "__main__":
    fastapi_logger.setLevel(gunicorn_logger.level)
else:
    fastapi_logger.setLevel(logging.DEBUG)

logger = logging.getLogger("uvicorn")
logger.setLevel(logging.INFO)

# Inactivity watchdog state: the process exits once `global_timeout` seconds
# pass with no request activity and no build in progress.
last_activity_time = time.time()
global_timeout = 60 * 4

machine_id_websocket_dict = {}  # machine_id -> log-streaming WebSocket
machine_id_status = {}          # machine_id -> True while a build is running

fly_instance_id = os.environ.get('FLY_ALLOC_ID', 'local').split('-')[0]
# FIX: this previously read FLY_ALLOC_ID (copy-paste bug); the CivitAI key
# lives in CIVITAI_API_KEY, matching upload_logic() below.
civitai_api_key = os.environ.get('CIVITAI_API_KEY')
public_model_volume_name = os.environ.get('PUBLIC_MODEL_VOLUME_NAME', 'local')


class FlyReplayMiddleware(BaseHTTPMiddleware):
    """
    If the wrong instance was picked by the fly.io load balancer we use the
    fly-replay header to repeat the request again on the right instance.

    This only works if the right instance is provided as a query_string
    parameter (``fly_instance_id``).
    """

    def __init__(self, app: ASGIApp) -> None:
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        query_string = scope.get('query_string', b'').decode()
        query_params = parse_qs(query_string)
        target_instance = query_params.get(
            'fly_instance_id', [fly_instance_id])[0]

        async def send_wrapper(message):
            if target_instance != fly_instance_id:
                # fly.io only seems to look at the fly-replay header if the
                # websocket is accepted, so swap a close-with-invalid-session
                # for an accept.  Use .get(): close messages may omit 'reason'.
                if message['type'] == 'websocket.close' and 'Invalid session' in message.get('reason', ''):
                    message = {'type': 'websocket.accept'}
                if 'headers' not in message:
                    message['headers'] = []
                message['headers'].append(
                    [b'fly-replay', f'instance={target_instance}'.encode()])
            await send(message)

        await self.app(scope, receive, send_wrapper)


async def check_inactivity():
    """Poll once a second; exit the process when idle past `global_timeout`.

    Only exits when no build is in progress (machine_id_status is empty).
    Runs on a dedicated event loop in a background thread (see lifespan).
    """
    global last_activity_time
    while True:
        # logger.info("Checking inactivity...")
        if time.time() - last_activity_time > global_timeout:
            if len(machine_id_status) == 0:
                # The application has been inactive for too long; scale it
                # down to zero by signalling ourselves.
                logger.info(
                    f"No activity for {global_timeout} seconds, exiting...")
                # os._exit(0)
                os.kill(os.getpid(), signal.SIGINT)
                break
            else:
                pass
                # logger.info(f"Timeout but still in progress")
        await asyncio.sleep(1)  # Check every second


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Start the inactivity watchdog on its own thread/event loop.
    thread = run_in_new_thread(check_inactivity())
    yield
    logger.info("Cancelling")


app = FastAPI(lifespan=lifespan)
app.add_middleware(FlyReplayMiddleware)
# MODAL_ORG = os.environ.get("MODAL_ORG")


@app.get("/")
def read_root():
    """Health-check; also counts as activity for the watchdog."""
    global last_activity_time
    last_activity_time = time.time()
    logger.info(f"Extended inactivity time to {global_timeout}")
    return {"Hello": "World"}


# create a post route called /create takes in a json of example
# {
#   name: "my first image",
#   deps: {
#     "comfyui": "d0165d819afe76bd4e6bdd710eb5f3e571b6a804",
#     "git_custom_nodes": {
#       "https://github.com/cubiq/ComfyUI_IPAdapter_plus": {
#         "hash": "2ca0c6dd0b2ad64b1c480828638914a564331dcd",
#         "disabled": true
#       },
#       "https://github.com/ltdrdata/ComfyUI-Manager.git": {
#         "hash": "9c86f62b912f4625fe2b929c7fc61deb9d16f6d3",
#         "disabled": false
#       },
#     },
#     "file_custom_nodes": []
#   }
# }


class GitCustomNodes(BaseModel):
    hash: str
    disabled: bool
    pip: Optional[List[str]] = None


class FileCustomNodes(BaseModel):
    filename: str
    disabled: bool


class Snapshot(BaseModel):
    comfyui: str
    git_custom_nodes: Dict[str, GitCustomNodes]
    file_custom_nodes: List[FileCustomNodes]


class Model(BaseModel):
    name: str
    type: str
    base: str
    save_path: str
    description: str
    reference: str
    filename: str
    url: str


class GPUType(str, Enum):
    T4 = "T4"
    A10G = "A10G"
    A100 = "A100"
    L4 = "L4"


class Item(BaseModel):
    machine_id: str
    name: str
    snapshot: Snapshot
    models: List[Model]
    callback_url: str
    model_volume_name: str
    run_timeout: Optional[int] = Field(default=60 * 5)
    idle_timeout: Optional[int] = Field(default=60)
    gpu: GPUType = Field(default=GPUType.T4)

    @field_validator('gpu')
    @classmethod
    def check_gpu(cls, value):
        # GPUType is a str-enum, so membership against __members__ (keyed by
        # name) works for both raw strings and coerced enum values.
        if value not in GPUType.__members__:
            raise ValueError(
                f"Invalid GPU option. Choose from: {', '.join(GPUType.__members__.keys())}")
        return GPUType(value)


@app.websocket("/ws/{machine_id}")
async def websocket_endpoint(websocket: WebSocket, machine_id: str):
    """Stream build logs for `machine_id`; replays cached logs on connect."""
    global last_activity_time
    await websocket.accept()
    machine_id_websocket_dict[machine_id] = websocket
    # Send existing logs
    if machine_id in machine_logs_cache:
        combined_logs = "\n".join(
            log_entry['logs'] for log_entry in machine_logs_cache[machine_id])
        await websocket.send_text(json.dumps({"event": "LOGS", "data": {
            "machine_id": machine_id,
            "logs": combined_logs,
            "timestamp": time.time()
        }}))
    try:
        while True:
            data = await websocket.receive_text()
            last_activity_time = time.time()
            logger.info(f"Extended inactivity time to {global_timeout}")
            # You can handle received messages here if needed
    except WebSocketDisconnect:
        if machine_id in machine_id_websocket_dict:
            machine_id_websocket_dict.pop(machine_id)


# @app.get("/test")
# async def test():
#     machine_id_status["123"] = True
#     global last_activity_time
#     last_activity_time = time.time()
#     logger.info(f"Extended inactivity time to {global_timeout}")
#     await asyncio.sleep(10)
#     machine_id_status["123"] = False
#     machine_id_status.pop("123")
#     return {"Hello": "World"}


# definition based on web schema
class UploadType(str, Enum):
    checkpoint = "checkpoint"
    lora = "lora"
    embedding = "embedding"
    clip = "clip"
    clip_vision = "clip_vision"
    configs = "configs"
    controlnet = "controlnet"
    upscale_models = "upscale_models"
    vae = "vae"
    ipadapter = "ipadapter"
    other = "other"


class UploadBody(BaseModel):
    download_url: str
    volume_name: str
    volume_id: str
    model_id: str
    upload_type: UploadType
    callback_url: str


# based on ComfyUI's model dir, and our mappings in
# ./src/template/data/extra_model_paths.yaml
UPLOAD_TYPE_DIR_MAP = {
    UploadType.checkpoint: "checkpoints",
    UploadType.lora: "loras",
    UploadType.embedding: "embeddings",
    UploadType.clip: "clip",
    UploadType.clip_vision: "clip_vision",
    UploadType.configs: "configs",
    UploadType.controlnet: "controlnet",
    UploadType.upscale_models: "upscale_models",
    UploadType.vae: "vae",
    UploadType.ipadapter: "ipadapter",
    UploadType.other: "",
}


class DeleteBody(BaseModel):
    volume_name: str
    path: str
    file_name: str


@app.post("/delete-volume-model")
async def delete_model(body: DeleteBody):
    """Remove a model file from a Modal volume via `modal volume rm`."""
    global last_activity_time
    last_activity_time = time.time()
    logger.info(f"Extended inactivity time to {global_timeout}")
    full_path = f"{body.path.rstrip('/')}/{body.file_name}"
    rm_process = await asyncio.subprocess.create_subprocess_exec(
        "modal", "volume", "rm", body.volume_name, full_path,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    # FIX: communicate() both drains the pipes and waits for exit; the old
    # wait()-then-communicate() order could deadlock if the child filled a
    # pipe buffer, and success was logged before the return code was known.
    stdout, stderr = await rm_process.communicate()
    if stdout:
        logger.info(f"cp_process stdout: {stdout.decode()}")
    if stderr:
        logger.info(f"cp_process stderr: {stderr.decode()}")
    if rm_process.returncode == 0:
        logger.info(f"Successfully deleted: {full_path} from volume: {body.volume_name}")
        return JSONResponse(status_code=200, content={"status": f"Successfully deleted {full_path} from volume {body.volume_name}"})
    else:
        return JSONResponse(status_code=500, content={"status": "error", "error": stderr.decode()})


@app.post("/upload-volume")
async def upload_model(body: UploadBody):
    """Kick off an async volume upload; returns immediately."""
    global last_activity_time
    last_activity_time = time.time()
    logger.info(f"Extended inactivity time to {global_timeout}")
    asyncio.create_task(upload_logic(body))
    # check that this
    return JSONResponse(status_code=200, content={"message": "Volume uploading", "build_machine_instance_id": fly_instance_id})


async def upload_logic(body: UploadBody):
    """Copy the volume-builder template, write its config, and run it."""
    folder_path = f"/app/builds/{body.volume_id}-{uuid4()}"
    cp_process = await asyncio.subprocess.create_subprocess_exec(
        "cp", "-r", "/app/src/volume_builder", folder_path)
    await cp_process.wait()

    upload_path = UPLOAD_TYPE_DIR_MAP[body.upload_type]
    if upload_path == "":
        # TODO: deal with custom paths
        pass

    config = {
        "volume_names": {
            body.volume_name: {"download_url": body.download_url, "folder_path": upload_path}
        },
        "volume_paths": {
            body.volume_name: f'/volumes/{uuid4()}'
        },
        "callback_url": body.callback_url,
        "callback_body": {
            "model_id": body.model_id,
            "volume_id": body.volume_id,
            "folder_path": upload_path,
        },
        "civitai_api_key": os.environ.get('CIVITAI_API_KEY'),
        "app_name": f"vol_name_{uuid4()}"
    }
    with open(f"{folder_path}/config.py", "w") as f:
        f.write("config = " + json.dumps(config))

    process = await asyncio.subprocess.create_subprocess_shell(
        f"python runner.py",
        cwd=folder_path,
        env={**os.environ, "COLUMNS": "10000"}
    )
    await process.wait()
    # import modal
    # modal.deploy_stub(stub)
    # stub["download_model"].web_url


@app.post("/create")
async def create_machine(item: Item):
    """Queue a machine build; rejects if one is already running for the id."""
    global last_activity_time
    last_activity_time = time.time()
    logger.info(f"Extended inactivity time to {global_timeout}")
    if item.machine_id in machine_id_status and machine_id_status[item.machine_id]:
        return JSONResponse(status_code=400, content={"error": "Build already in progress."})
    # Run the building logic in a separate task
    # future = executor.submit(build_logic, item)
    task = asyncio.create_task(build_logic(item))
    return JSONResponse(status_code=200, content={"message": "Build Queued", "build_machine_instance_id": fly_instance_id})


class StopAppItem(BaseModel):
    machine_id: str


def find_app_id(app_list, app_name):
    """Return the 'App ID' of the app named `app_name`, or None."""
    for app in app_list:
        if app['Name'] == app_name:
            return app['App ID']
    return None


@app.post("/stop-app")
async def stop_app(item: StopAppItem):
    """Resolve the Modal app id for `machine_id` and stop it."""
    # cmd = f"modal app list | grep {item.machine_id} | awk -F '│' '{{print $2}}'"
    cmd = f"modal app list --json"
    env = os.environ.copy()
    env["COLUMNS"] = "10000"  # Set the width to a large value
    find_id_process = await asyncio.subprocess.create_subprocess_shell(
        cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, env=env)
    # FIX: communicate() instead of wait()+communicate() — avoids a pipe
    # deadlock on large `modal app list` output.
    stdout, stderr = await find_id_process.communicate()
    app_id = None
    if stdout:
        app_list = json.loads(stdout.decode().strip())
        app_id = find_app_id(app_list, item.machine_id)
        logger.info(f"cp_process stdout: {app_id}")
    if stderr:
        logger.info(f"cp_process stderr: {stderr.decode()}")
    # FIX: previously a missing app produced app_id=None, which blew up
    # create_subprocess_exec with a TypeError (unhandled 500).
    if app_id is None:
        return JSONResponse(status_code=404, content={"status": "error", "error": f"No Modal app found for {item.machine_id}"})

    cp_process = await asyncio.subprocess.create_subprocess_exec(
        "modal", "app", "stop", app_id,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    logger.info(f"Stopping app {item.machine_id}")
    stdout, stderr = await cp_process.communicate()
    if stdout:
        logger.info(f"cp_process stdout: {stdout.decode()}")
    if stderr:
        logger.info(f"cp_process stderr: {stderr.decode()}")
    if cp_process.returncode == 0:
        return JSONResponse(status_code=200, content={"status": "success"})
    else:
        return JSONResponse(status_code=500, content={"status": "error", "error": stderr.decode()})


# Initialize the logs cache
machine_logs_cache = {}


async def build_logic(item: Item):
    """Deploy a machine to Modal, streaming build logs over the WebSocket.

    Writes the build template + config, runs `modal deploy`, parses the
    resulting *.modal.run URL out of the deploy output, then POSTs the
    result (or failure log) to item.callback_url.
    """
    # Deploy to modal
    folder_path = f"/app/builds/{item.machine_id}"
    machine_id_status[item.machine_id] = True

    # Copy the app template
    cp_process = await asyncio.subprocess.create_subprocess_exec(
        "cp", "-r", "/app/src/template", folder_path)
    await cp_process.wait()

    # Collect pip requirements contributed by the custom nodes.
    pip_modules = set()
    for git_custom_node in item.snapshot.git_custom_nodes.values():
        if git_custom_node.pip:
            pip_modules.update(git_custom_node.pip)

    # Write the config file
    config = {
        "name": item.name,
        "deploy_test": os.environ.get("DEPLOY_TEST_FLAG", "False"),
        "gpu": item.gpu,
        "public_model_volume": public_model_volume_name,
        "private_model_volume": item.model_volume_name,
        "pip": list(pip_modules),
        "run_timeout": item.run_timeout,
        "idle_timeout": item.idle_timeout,
    }
    with open(f"{folder_path}/config.py", "w") as f:
        f.write("config = " + json.dumps(config))
    with open(f"{folder_path}/data/snapshot.json", "w") as f:
        f.write(item.snapshot.json())
    with open(f"{folder_path}/data/models.json", "w") as f:
        models_json_list = [model.dict() for model in item.models]
        models_json_string = json.dumps(models_json_list)
        f.write(models_json_string)

    process = await asyncio.subprocess.create_subprocess_shell(
        f"modal deploy app.py",
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
        cwd=folder_path,
        env={**os.environ, "COLUMNS": "10000"}
    )

    url = None
    if item.machine_id not in machine_logs_cache:
        machine_logs_cache[item.machine_id] = []
    machine_logs = machine_logs_cache[item.machine_id]
    url_queue = asyncio.Queue()

    async def read_stream(stream, isStderr, url_queue: asyncio.Queue):
        while True:
            line = await stream.readline()
            if not line:
                break
            l = line.decode('utf-8').strip()
            if l == "":
                continue
            if not isStderr:
                logger.info(l)
                machine_logs.append({
                    "logs": l,
                    "timestamp": time.time()
                })
                if item.machine_id in machine_id_websocket_dict:
                    await machine_id_websocket_dict[item.machine_id].send_text(json.dumps({"event": "LOGS", "data": {
                        "machine_id": item.machine_id,
                        "logs": l,
                        "timestamp": time.time()
                    }}))
                if "Created comfyui_api =>" in l or ((l.startswith("https://") or l.startswith("│")) and l.endswith(".modal.run")):
                    # FIX: initialize before the branches — previously a line
                    # matching the outer condition but neither inner branch
                    # raised UnboundLocalError at `if found_url:`.
                    found_url = None
                    if "Created comfyui_api =>" in l:
                        found_url = l.split("=>")[1].strip()  # making sure it is a url
                    elif "comfyui-api" in l:
                        # Some case it only prints the url on a blank line
                        if l.startswith("│"):
                            found_url = l.split("│")[1].strip()
                        else:
                            found_url = l
                    if found_url:
                        machine_logs.append({
                            "logs": f"App image built, url: {found_url}",
                            "timestamp": time.time()
                        })
                        await url_queue.put(found_url)
                        if item.machine_id in machine_id_websocket_dict:
                            await machine_id_websocket_dict[item.machine_id].send_text(json.dumps({"event": "LOGS", "data": {
                                "machine_id": item.machine_id,
                                "logs": f"App image built, url: {found_url}",
                                "timestamp": time.time()
                            }}))
                            # NOTE: "succuss" is a typo but is part of the wire
                            # protocol — kept byte-for-byte for consumers.
                            await machine_id_websocket_dict[item.machine_id].send_text(json.dumps({"event": "FINISHED", "data": {
                                "status": "succuss",
                            }}))
            else:
                # is error
                logger.error(l)
                machine_logs.append({
                    "logs": l,
                    "timestamp": time.time()
                })
                if item.machine_id in machine_id_websocket_dict:
                    await machine_id_websocket_dict[item.machine_id].send_text(json.dumps({"event": "LOGS", "data": {
                        "machine_id": item.machine_id,
                        "logs": l,
                        "timestamp": time.time()
                    }}))
                    await machine_id_websocket_dict[item.machine_id].send_text(json.dumps({"event": "FINISHED", "data": {
                        "status": "failed",
                    }}))

    stdout_task = asyncio.create_task(
        read_stream(process.stdout, False, url_queue))
    stderr_task = asyncio.create_task(
        read_stream(process.stderr, True, url_queue))
    await asyncio.wait([stdout_task, stderr_task])

    # Wait for the subprocess to finish
    await process.wait()

    if not url_queue.empty():
        # The queue is not empty, you can get an item
        url = await url_queue.get()

    # Close the ws connection and also pop the item
    if item.machine_id in machine_id_websocket_dict and machine_id_websocket_dict[item.machine_id] is not None:
        await machine_id_websocket_dict[item.machine_id].close()
    if item.machine_id in machine_id_websocket_dict:
        machine_id_websocket_dict.pop(item.machine_id)
    if item.machine_id in machine_id_status:
        machine_id_status[item.machine_id] = False

    # Check for errors
    if process.returncode != 0:
        logger.info("An error occurred.")
        # Send a post request with the json body machine_id to the callback url
        machine_logs.append({
            "logs": "Unable to build the app image.",
            "timestamp": time.time()
        })
        requests.post(item.callback_url, json={
            "machine_id": item.machine_id, "build_log": json.dumps(machine_logs)})
        if item.machine_id in machine_logs_cache:
            del machine_logs_cache[item.machine_id]
        return
        # return JSONResponse(status_code=400, content={"error": "Unable to build the app image."})

    # app_suffix = "comfyui-app"
    if url is None:
        machine_logs.append({
            "logs": "App image built, but url is None, unable to parse the url.",
            "timestamp": time.time()
        })
        requests.post(item.callback_url, json={
            "machine_id": item.machine_id, "build_log": json.dumps(machine_logs)})
        if item.machine_id in machine_logs_cache:
            del machine_logs_cache[item.machine_id]
        return
        # return JSONResponse(status_code=400, content={"error": "App image built, but url is None, unable to parse the url."})

    # example https://bennykok--my-app-comfyui-app.modal.run/
    # my_url = f"https://{MODAL_ORG}--{item.container_id}-{app_suffix}.modal.run"
    requests.post(item.callback_url, json={
        "machine_id": item.machine_id, "endpoint": url, "build_log": json.dumps(machine_logs)})
    if item.machine_id in machine_logs_cache:
        del machine_logs_cache[item.machine_id]
    logger.info("done")
    logger.info(url)


def start_loop(loop):
    """Run `loop` forever on the current (background) thread."""
    asyncio.set_event_loop(loop)
    loop.run_forever()


def run_in_new_thread(coroutine):
    """Schedule `coroutine` on a fresh event loop in a daemon thread."""
    new_loop = asyncio.new_event_loop()
    t = threading.Thread(target=start_loop, args=(new_loop,), daemon=True)
    t.start()
    asyncio.run_coroutine_threadsafe(coroutine, new_loop)
    return t


if __name__ == "__main__":
    import uvicorn
    # , log_level="debug"
    uvicorn.run("main:app", host="0.0.0.0", port=8080, lifespan="on")