diff --git a/builder/modal-builder/src/main.py b/builder/modal-builder/src/main.py index a3d3c25..caa2459 100644 --- a/builder/modal-builder/src/main.py +++ b/builder/modal-builder/src/main.py @@ -46,6 +46,7 @@ machine_id_status = {} fly_instance_id = os.environ.get('FLY_ALLOC_ID', 'local').split('-')[0] + class FlyReplayMiddleware(BaseHTTPMiddleware): """ If the wrong instance was picked by the fly.io load balancer we use the fly-replay header @@ -53,13 +54,16 @@ class FlyReplayMiddleware(BaseHTTPMiddleware): This only works if the right instance is provided as a query_string parameter. """ + def __init__(self, app: ASGIApp) -> None: self.app = app async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: query_string = scope.get('query_string', b'').decode() query_params = parse_qs(query_string) - target_instance = query_params.get('fly_instance_id', [fly_instance_id])[0] + target_instance = query_params.get( + 'fly_instance_id', [fly_instance_id])[0] + async def send_wrapper(message): if target_instance != fly_instance_id: if message['type'] == 'websocket.close' and 'Invalid session' in message['reason']: @@ -67,7 +71,8 @@ class FlyReplayMiddleware(BaseHTTPMiddleware): message = {'type': 'websocket.accept'} if 'headers' not in message: message['headers'] = [] - message['headers'].append([b'fly-replay', f'instance={target_instance}'.encode()]) + message['headers'].append( + [b'fly-replay', f'instance={target_instance}'.encode()]) await send(message) await self.app(scope, receive, send_wrapper) @@ -216,7 +221,7 @@ async def websocket_endpoint(websocket: WebSocket, machine_id: str): @app.post("/create") -async def create_item(item: Item): +async def create_machine(item: Item): global last_activity_time last_activity_time = time.time() logger.info(f"Extended inactivity time to {global_timeout}") @@ -231,6 +236,54 @@ async def create_item(item: Item): return JSONResponse(status_code=200, content={"message": "Build Queued", "build_machine_instance_id": fly_instance_id}) +class StopAppItem(BaseModel): + machine_id: str + + +def find_app_id(app_list, app_name): + for app in app_list: + if app['Name'] == app_name: + return app['App ID'] + return None + +@app.post("/stop-app") +async def stop_app(item: StopAppItem): + # cmd = f"modal app list | grep {item.machine_id} | awk -F '│' '{{print $2}}'" + cmd = f"modal app list --json" + + env = os.environ.copy() + env["COLUMNS"] = "10000" # Set the width to a large value + find_id_process = await asyncio.subprocess.create_subprocess_shell(cmd, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + env=env) + await find_id_process.wait() + + stdout, stderr = await find_id_process.communicate() + if stdout: + app_id = stdout.decode().strip() + app_list = json.loads(app_id) + app_id = find_app_id(app_list, item.machine_id) + logger.info(f"cp_process stdout: {app_id}") + if stderr: + logger.info(f"cp_process stderr: {stderr.decode()}") + + cp_process = await asyncio.subprocess.create_subprocess_exec("modal", "app", "stop", app_id, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE,) + await cp_process.wait() + logger.info(f"Stopping app {item.machine_id}") + stdout, stderr = await cp_process.communicate() + if stdout: + logger.info(f"cp_process stdout: {stdout.decode()}") + if stderr: + logger.info(f"cp_process stderr: {stderr.decode()}") + + if cp_process.returncode == 0: + return JSONResponse(status_code=200, content={"status": "success"}) + else: + return JSONResponse(status_code=500, content={"status": "error", "error": stderr.decode()}) + # Initialize the logs cache machine_logs_cache = {} diff --git a/web/src/server/curdMachine.ts b/web/src/server/curdMachine.ts index 5f193bc..0961815 100644 --- a/web/src/server/curdMachine.ts +++ b/web/src/server/curdMachine.ts @@ -219,6 +219,28 @@ export const updateMachine = withServerPromise( export const deleteMachine = withServerPromise( async (machine_id: string): Promise<{ message: string }> => { + const machine = await db.query.machinesTable.findFirst({ + where: eq(machinesTable.id, machine_id), + }); + + if (machine?.type === "comfy-deploy-serverless") { + // Call remote builder to stop the app on modal + const result = await fetch(`${process.env.MODAL_BUILDER_URL!}/stop-app`, { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ + machine_id: machine_id, + }), + }); + + if (!result.ok) { + const error_log = await result.text(); + throw new Error(`Error: ${result.statusText} ${error_log}`); + } + } + await db.delete(machinesTable).where(eq(machinesTable.id, machine_id)); revalidatePath("/machines"); return { message: "Machine Deleted" };