feat: add timeout and run_log

This commit is contained in:
bennykok 2024-01-30 21:42:30 +08:00
parent 2193dd287d
commit b0d1bcc303
11 changed files with 1506 additions and 78 deletions

View File

@ -180,6 +180,8 @@ class Item(BaseModel):
models: List[Model] models: List[Model]
callback_url: str callback_url: str
model_volume_name: str model_volume_name: str
run_timeout: Optional[int] = Field(default=60 * 5)
idle_timeout: Optional[int] = Field(default=60)
gpu: GPUType = Field(default=GPUType.T4) gpu: GPUType = Field(default=GPUType.T4)
@field_validator('gpu') @field_validator('gpu')
@ -391,7 +393,9 @@ async def build_logic(item: Item):
"gpu": item.gpu, "gpu": item.gpu,
"public_model_volume": public_model_volume_name, "public_model_volume": public_model_volume_name,
"private_model_volume": item.model_volume_name, "private_model_volume": item.model_volume_name,
"pip": list(pip_modules) "pip": list(pip_modules),
"run_timeout": item.run_timeout,
"idle_timeout": item.idle_timeout,
} }
with open(f"{folder_path}/config.py", "w") as f: with open(f"{folder_path}/config.py", "w") as f:
f.write("config = " + json.dumps(config)) f.write("config = " + json.dumps(config))

View File

@ -112,7 +112,7 @@ def check_server(url, retries=50, delay=500):
# If the response status code is 200, the server is up and running # If the response status code is 200, the server is up and running
if response.status_code == 200: if response.status_code == 200:
print(f"runpod-worker-comfy - API is reachable") print(f"comfy-modal - API is reachable")
return True return True
except requests.RequestException as e: except requests.RequestException as e:
# If an exception occurs, the server may not be ready # If an exception occurs, the server may not be ready
@ -124,7 +124,7 @@ def check_server(url, retries=50, delay=500):
time.sleep(delay / 1000) time.sleep(delay / 1000)
print( print(
f"runpod-worker-comfy - Failed to connect to server at {url} after {retries} attempts." f"comfy-modal - Failed to connect to server at {url} after {retries} attempts."
) )
return False return False
@ -158,20 +158,68 @@ image = Image.debian_slim()
target_image = image if deploy_test else dockerfile_image target_image = image if deploy_test else dockerfile_image
@stub.cls(image=target_image, gpu=config["gpu"] ,volumes=volumes, timeout=60 * 10, container_idle_timeout=60) run_timeout = config["run_timeout"]
idle_timeout = config["idle_timeout"]
import asyncio
@stub.cls(image=target_image, gpu=config["gpu"] ,volumes=volumes, timeout=60 * 10, container_idle_timeout=idle_timeout)
class ComfyDeployRunner: class ComfyDeployRunner:
machine_logs = []
async def read_stream(self, stream, isStderr):
    """Drain ``stream`` line by line, echoing and recording each line.

    Every non-blank line is printed and appended to ``self.machine_logs``
    as ``{"logs": <text>, "timestamp": <time.time()>}``. The coroutine
    returns once ``stream.readline()`` yields an empty value.

    Args:
        stream: Object exposing an awaitable ``readline()`` that returns
            bytes (the callers pass the subprocess stdout/stderr pipes).
        isStderr: Whether ``stream`` is the stderr pipe. Both pipes are
            handled identically, so the flag does not alter behaviour;
            it is retained so existing callers keep working.
    """
    import time

    while True:
        raw_line = await stream.readline()
        if not raw_line:
            # An empty read means the stream is exhausted.
            break

        # Decode leniently: a single invalid byte in the subprocess output
        # must not raise and kill this reader task, which would silently
        # stop log collection for the rest of the run.
        text = raw_line.decode("utf-8", errors="replace").strip()
        if not text:
            continue

        print(text, flush=True)
        self.machine_logs.append({
            "logs": text,
            "timestamp": time.time()
        })
@enter() @enter()
def setup(self): async def setup(self):
import subprocess import subprocess
import time import time
# Make sure that the ComfyUI API is available # Make sure that the ComfyUI API is available
print(f"comfy-modal - check server") print(f"comfy-modal - check server")
command = ["python", "main.py", # command = ["python", "main.py",
"--disable-auto-launch", "--disable-metadata"] # "--disable-auto-launch", "--disable-metadata"]
self.server_process = subprocess.Popen(command, cwd="/comfyui") # self.server_process = subprocess.Popen(command, cwd="/comfyui")
self.server_process = await asyncio.subprocess.create_subprocess_shell(
f"python main.py --disable-auto-launch --disable-metadata",
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
cwd="/comfyui",
# env={**os.environ, "COLUMNS": "10000"}
)
self.stdout_task = asyncio.create_task(
self.read_stream(self.server_process.stdout, False))
self.stderr_task = asyncio.create_task(
self.read_stream(self.server_process.stderr, True))
check_server( check_server(
f"http://{COMFY_HOST}", f"http://{COMFY_HOST}",
@ -180,11 +228,45 @@ class ComfyDeployRunner:
) )
@exit() @exit()
def cleanup(self, exc_type, exc_value, traceback): async def cleanup(self, exc_type, exc_value, traceback):
self.server_process.terminate() print(f"comfy-modal - cleanup", exc_type, exc_value, traceback)
self.stderr_task.cancel()
self.stdout_task.cancel()
# self.server_process.kill()
@method() @method()
def run(self, input: Input): async def run(self, input: Input):
import signal
import time
# import asyncio
self.stderr_task.cancel()
self.stdout_task.cancel()
self.stdout_task = asyncio.create_task(
self.read_stream(self.server_process.stdout, False))
self.stderr_task = asyncio.create_task(
self.read_stream(self.server_process.stderr, True))
class TimeoutError(Exception):
pass
def timeout_handler(signum, frame):
data = json.dumps({
"run_id": input.prompt_id,
"status": "timeout",
"time": datetime.now().isoformat()
}).encode('utf-8')
req = urllib.request.Request(input.status_endpoint, data=data, method='POST')
urllib.request.urlopen(req)
raise TimeoutError("Operation timed out")
signal.signal(signal.SIGALRM, timeout_handler)
try:
# Arm SIGALRM so timeout_handler fires if the run exceeds run_timeout seconds
signal.alarm(run_timeout) # timeout comes from config["run_timeout"], not a fixed 5 seconds
data = json.dumps({ data = json.dumps({
"run_id": input.prompt_id, "run_id": input.prompt_id,
"status": "started", "status": "started",
@ -212,13 +294,6 @@ class ComfyDeployRunner:
print("getting request") print("getting request")
while retries < COMFY_POLLING_MAX_RETRIES: while retries < COMFY_POLLING_MAX_RETRIES:
status_result = check_status(prompt_id=prompt_id) status_result = check_status(prompt_id=prompt_id)
# history = get_history(prompt_id)
# Exit the loop if we have found the history
# if prompt_id in history and history[prompt_id].get("outputs"):
# break
# Exit the loop if we have found the status both success or failed
if 'status' in status_result and (status_result['status'] == 'success' or status_result['status'] == 'failed'): if 'status' in status_result and (status_result['status'] == 'success' or status_result['status'] == 'failed'):
status = status_result['status'] status = status_result['status']
print(status) print(status)
@ -236,9 +311,27 @@ class ComfyDeployRunner:
result = {"status": status} result = {"status": status}
except TimeoutError:
print("Operation timed out")
return {"status": "failed"}
print("uploading log_data")
data = json.dumps({
"run_id": input.prompt_id,
"time": datetime.now().isoformat(),
"log_data": json.dumps(self.machine_logs)
}).encode('utf-8')
print("my logs", len(self.machine_logs))
# Clear logs
self.machine_logs = []
req = urllib.request.Request(input.status_endpoint, data=data, method='POST')
urllib.request.urlopen(req)
return result return result
@web_app.post("/run") @web_app.post("/run")
async def post_run(request_input: RequestInput): async def post_run(request_input: RequestInput):
if not deploy_test: if not deploy_test:
@ -252,10 +345,12 @@ async def post_run(request_input: RequestInput):
urllib.request.urlopen(req) urllib.request.urlopen(req)
model = ComfyDeployRunner() model = ComfyDeployRunner()
call = model.run.spawn(request_input.input) call = await model.run.spawn.aio(request_input.input)
print("call", call)
# call = run.spawn() # call = run.spawn()
return {"call_id": call.object_id} return {"call_id": None}
return {"call_id": None} return {"call_id": None}

View File

@ -4,5 +4,7 @@ config = {
"gpu": "T4", "gpu": "T4",
"public_model_volume": "model-store", "public_model_volume": "model-store",
"private_model_volume": "private-model-store", "private_model_volume": "private-model-store",
"pip": [] "pip": [],
"run_timeout": 60 * 5,
"idle_timeout": 60
} }

View File

@ -0,0 +1,2 @@
ALTER TYPE "workflow_run_status" ADD VALUE 'timeout';--> statement-breakpoint
ALTER TABLE "comfyui_deploy"."workflow_runs" ADD COLUMN "run_log" text;

File diff suppressed because it is too large Load Diff

View File

@ -337,6 +337,13 @@
"when": 1706384528895, "when": 1706384528895,
"tag": "0047_gifted_starbolt", "tag": "0047_gifted_starbolt",
"breakpoints": true "breakpoints": true
},
{
"idx": 48,
"version": "5",
"when": 1706600255919,
"tag": "0048_dear_korath",
"breakpoints": true
} }
] ]
} }

View File

@ -2,10 +2,8 @@ import { parseDataSafe } from "../../../../lib/parseDataSafe";
import { db } from "@/db/db"; import { db } from "@/db/db";
import { import {
WorkflowRunStatusSchema, WorkflowRunStatusSchema,
userUsageTable,
workflowRunOutputs, workflowRunOutputs,
workflowRunsTable, workflowRunsTable,
workflowTable,
} from "@/db/schema"; } from "@/db/schema";
import { getCurrentPlan } from "@/server/getCurrentPlan"; import { getCurrentPlan } from "@/server/getCurrentPlan";
import { stripe } from "@/server/stripe"; import { stripe } from "@/server/stripe";
@ -18,6 +16,7 @@ const Request = z.object({
status: WorkflowRunStatusSchema.optional(), status: WorkflowRunStatusSchema.optional(),
time: z.coerce.date().optional(), time: z.coerce.date().optional(),
output_data: z.any().optional(), output_data: z.any().optional(),
log_data: z.string().optional(),
}); });
export async function POST(request: Request) { export async function POST(request: Request) {
@ -26,7 +25,17 @@ export async function POST(request: Request) {
if (!data || error) return error; if (!data || error) return error;
const { run_id, status, time, output_data } = data; const { run_id, status, time, output_data, log_data } = data;
if (log_data) {
// Logs were uploaded for this run, store them in run_log
await db
.update(workflowRunsTable)
.set({
run_log: log_data,
})
.where(eq(workflowRunsTable.id, run_id));
}
if (status == "started" && time != undefined) { if (status == "started" && time != undefined) {
// It successfully started, update the started_at time // It successfully started, update the started_at time
@ -48,6 +57,9 @@ export async function POST(request: Request) {
.where(eq(workflowRunsTable.id, run_id)); .where(eq(workflowRunsTable.id, run_id));
} }
const ended =
status === "success" || status === "failed" || status === "timeout";
if (output_data) { if (output_data) {
const workflow_run_output = await db.insert(workflowRunOutputs).values({ const workflow_run_output = await db.insert(workflowRunOutputs).values({
run_id: run_id, run_id: run_id,
@ -58,8 +70,7 @@ export async function POST(request: Request) {
.update(workflowRunsTable) .update(workflowRunsTable)
.set({ .set({
status: status, status: status,
ended_at: ended_at: ended ? new Date() : null,
status === "success" || status === "failed" ? new Date() : null,
}) })
.where(eq(workflowRunsTable.id, run_id)) .where(eq(workflowRunsTable.id, run_id))
.returning(); .returning();
@ -67,10 +78,7 @@ export async function POST(request: Request) {
// Need to filter out only comfy deploy serverless // Need to filter out only comfy deploy serverless
// Also multiply with the gpu selection // Also multiply with the gpu selection
if (workflow_run.machine_type == "comfy-deploy-serverless") { if (workflow_run.machine_type == "comfy-deploy-serverless") {
if ( if (ended && workflow_run.user_id) {
(status === "success" || status === "failed") &&
workflow_run.user_id
) {
const sub = await getCurrentPlan({ const sub = await getCurrentPlan({
user_id: workflow_run.user_id, user_id: workflow_run.user_id,
org_id: workflow_run.org_id, org_id: workflow_run.org_id,

View File

@ -16,7 +16,7 @@ export function LiveStatus({
(state) => (state) =>
state.data state.data
.filter((x) => x.id === run.id) .filter((x) => x.id === run.id)
.sort((a, b) => b.timestamp - a.timestamp)?.[0] .sort((a, b) => b.timestamp - a.timestamp)?.[0],
); );
let status = run.status; let status = run.status;
@ -51,7 +51,9 @@ export function LiveStatus({
<> <>
<TableCell> <TableCell>
{data && status != "success" {data && status != "success"
? `${data.json.event} - ${data.json.data.node}` ? `${data.json.event}${
data.json.data.node ? " - " + data.json.data.node : ""
}`
: "-"} : "-"}
</TableCell> </TableCell>
<TableCell className="truncate text-right"> <TableCell className="truncate text-right">

View File

@ -19,6 +19,7 @@ import { getDuration, getRelativeTime } from "@/lib/getRelativeTime";
import { type findAllRuns } from "@/server/findAllRuns"; import { type findAllRuns } from "@/server/findAllRuns";
import { Suspense } from "react"; import { Suspense } from "react";
import { LiveStatus } from "./LiveStatus"; import { LiveStatus } from "./LiveStatus";
import { LogsType, LogsViewer } from "@/components/LogsViewer";
export async function RunDisplay({ export async function RunDisplay({
run, run,
@ -75,6 +76,9 @@ export async function RunDisplay({
<RunInputs run={run} /> <RunInputs run={run} />
<Suspense> <Suspense>
<RunOutputs run_id={run.id} /> <RunOutputs run_id={run.id} />
{run.run_log && (
<LogsViewer logs={JSON.parse(run.run_log) as LogsType} />
)}
</Suspense> </Suspense>
</div> </div>
{/* <div className="max-h-96 overflow-y-scroll">{view}</div> */} {/* <div className="max-h-96 overflow-y-scroll">{view}</div> */}

View File

@ -17,6 +17,8 @@ export function StatusBadge({
); );
case "success": case "success":
return <Badge variant="success">{status}</Badge>; return <Badge variant="success">{status}</Badge>;
case "timeout":
return <Badge variant="amber">{status}</Badge>;
case "failed": case "failed":
return <Badge variant="destructive">{status}</Badge>; return <Badge variant="destructive">{status}</Badge>;
} }

View File

@ -104,6 +104,7 @@ export const workflowRunStatus = pgEnum("workflow_run_status", [
"failed", "failed",
"started", "started",
"queued", "queued",
"timeout",
]); ]);
export const deploymentEnvironment = pgEnum("deployment_environment", [ export const deploymentEnvironment = pgEnum("deployment_environment", [
@ -172,6 +173,7 @@ export const workflowRunsTable = dbSchema.table("workflow_runs", {
machine_type: machinesType("machine_type"), machine_type: machinesType("machine_type"),
user_id: text("user_id"), user_id: text("user_id"),
org_id: text("org_id"), org_id: text("org_id"),
run_log: text("run_log"),
}); });
export const workflowRunRelations = relations( export const workflowRunRelations = relations(
@ -386,12 +388,7 @@ export const modelUploadType = pgEnum("model_upload_type", [
]); ]);
// https://www.answeroverflow.com/m/1125106227387584552 // https://www.answeroverflow.com/m/1125106227387584552
export const modelTypes = [ export const modelTypes = ["checkpoint", "lora", "embedding", "vae"] as const;
"checkpoint",
"lora",
"embedding",
"vae",
] as const
export const modelType = pgEnum("model_type", modelTypes); export const modelType = pgEnum("model_type", modelTypes);
export type modelEnumType = (typeof modelTypes)[number]; export type modelEnumType = (typeof modelTypes)[number];