fix: properly handle deleting device

This commit is contained in:
BennyKok 2024-01-10 13:58:51 +08:00
parent 924d85443d
commit c949ea17b2
2 changed files with 78 additions and 3 deletions

View File

@ -46,6 +46,7 @@ machine_id_status = {}
fly_instance_id = os.environ.get('FLY_ALLOC_ID', 'local').split('-')[0] fly_instance_id = os.environ.get('FLY_ALLOC_ID', 'local').split('-')[0]
class FlyReplayMiddleware(BaseHTTPMiddleware): class FlyReplayMiddleware(BaseHTTPMiddleware):
""" """
If the wrong instance was picked by the fly.io load balancer we use the fly-replay header If the wrong instance was picked by the fly.io load balancer we use the fly-replay header
@ -53,13 +54,16 @@ class FlyReplayMiddleware(BaseHTTPMiddleware):
This only works if the right instance is provided as a query_string parameter. This only works if the right instance is provided as a query_string parameter.
""" """
def __init__(self, app: ASGIApp) -> None: def __init__(self, app: ASGIApp) -> None:
self.app = app self.app = app
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
query_string = scope.get('query_string', b'').decode() query_string = scope.get('query_string', b'').decode()
query_params = parse_qs(query_string) query_params = parse_qs(query_string)
target_instance = query_params.get('fly_instance_id', [fly_instance_id])[0] target_instance = query_params.get(
'fly_instance_id', [fly_instance_id])[0]
async def send_wrapper(message): async def send_wrapper(message):
if target_instance != fly_instance_id: if target_instance != fly_instance_id:
if message['type'] == 'websocket.close' and 'Invalid session' in message['reason']: if message['type'] == 'websocket.close' and 'Invalid session' in message['reason']:
@ -67,7 +71,8 @@ class FlyReplayMiddleware(BaseHTTPMiddleware):
message = {'type': 'websocket.accept'} message = {'type': 'websocket.accept'}
if 'headers' not in message: if 'headers' not in message:
message['headers'] = [] message['headers'] = []
message['headers'].append([b'fly-replay', f'instance={target_instance}'.encode()]) message['headers'].append(
[b'fly-replay', f'instance={target_instance}'.encode()])
await send(message) await send(message)
await self.app(scope, receive, send_wrapper) await self.app(scope, receive, send_wrapper)
@ -216,7 +221,7 @@ async def websocket_endpoint(websocket: WebSocket, machine_id: str):
@app.post("/create") @app.post("/create")
async def create_item(item: Item): async def create_machine(item: Item):
global last_activity_time global last_activity_time
last_activity_time = time.time() last_activity_time = time.time()
logger.info(f"Extended inactivity time to {global_timeout}") logger.info(f"Extended inactivity time to {global_timeout}")
@ -231,6 +236,54 @@ async def create_item(item: Item):
return JSONResponse(status_code=200, content={"message": "Build Queued", "build_machine_instance_id": fly_instance_id}) return JSONResponse(status_code=200, content={"message": "Build Queued", "build_machine_instance_id": fly_instance_id})
class StopAppItem(BaseModel):
machine_id: str
def find_app_id(app_list, app_name):
for app in app_list:
if app['Name'] == app_name:
return app['App ID']
return None
@app.post("/stop-app")
async def stop_app(item: StopAppItem):
# cmd = f"modal app list | grep {item.machine_id} | awk -F '│' '{{print $2}}'"
cmd = f"modal app list --json"
env = os.environ.copy()
env["COLUMNS"] = "10000" # Set the width to a large value
find_id_process = await asyncio.subprocess.create_subprocess_shell(cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
env=env)
await find_id_process.wait()
stdout, stderr = await find_id_process.communicate()
if stdout:
app_id = stdout.decode().strip()
app_list = json.loads(app_id)
app_id = find_app_id(app_list, item.machine_id)
logger.info(f"cp_process stdout: {app_id}")
if stderr:
logger.info(f"cp_process stderr: {stderr.decode()}")
cp_process = await asyncio.subprocess.create_subprocess_exec("modal", "app", "stop", app_id,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,)
await cp_process.wait()
logger.info(f"Stopping app {item.machine_id}")
stdout, stderr = await cp_process.communicate()
if stdout:
logger.info(f"cp_process stdout: {stdout.decode()}")
if stderr:
logger.info(f"cp_process stderr: {stderr.decode()}")
if cp_process.returncode == 0:
return JSONResponse(status_code=200, content={"status": "success"})
else:
return JSONResponse(status_code=500, content={"status": "error", "error": stderr.decode()})
# Initialize the logs cache # Initialize the logs cache
machine_logs_cache = {} machine_logs_cache = {}

View File

@ -219,6 +219,28 @@ export const updateMachine = withServerPromise(
export const deleteMachine = withServerPromise( export const deleteMachine = withServerPromise(
async (machine_id: string): Promise<{ message: string }> => { async (machine_id: string): Promise<{ message: string }> => {
const machine = await db.query.machinesTable.findFirst({
where: eq(machinesTable.id, machine_id),
});
if (machine?.type === "comfy-deploy-serverless") {
// Call remote builder to stop the app on modal
const result = await fetch(`${process.env.MODAL_BUILDER_URL!}/stop-app`, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
machine_id: machine_id,
}),
});
if (!result.ok) {
const error_log = await result.text();
throw new Error(`Error: ${result.statusText} ${error_log}`);
}
}
await db.delete(machinesTable).where(eq(machinesTable.id, machine_id)); await db.delete(machinesTable).where(eq(machinesTable.id, machine_id));
revalidatePath("/machines"); revalidatePath("/machines");
return { message: "Machine Deleted" }; return { message: "Machine Deleted" };