feat: add timeout and run_log

This commit is contained in:
bennykok 2024-01-30 21:42:30 +08:00
parent 2193dd287d
commit b0d1bcc303
11 changed files with 1506 additions and 78 deletions

View File

@ -180,6 +180,8 @@ class Item(BaseModel):
models: List[Model]
callback_url: str
model_volume_name: str
run_timeout: Optional[int] = Field(default=60 * 5)
idle_timeout: Optional[int] = Field(default=60)
gpu: GPUType = Field(default=GPUType.T4)
@field_validator('gpu')
@ -391,7 +393,9 @@ async def build_logic(item: Item):
"gpu": item.gpu,
"public_model_volume": public_model_volume_name,
"private_model_volume": item.model_volume_name,
"pip": list(pip_modules)
"pip": list(pip_modules),
"run_timeout": item.run_timeout,
"idle_timeout": item.idle_timeout,
}
with open(f"{folder_path}/config.py", "w") as f:
f.write("config = " + json.dumps(config))

View File

@ -112,7 +112,7 @@ def check_server(url, retries=50, delay=500):
# If the response status code is 200, the server is up and running
if response.status_code == 200:
print(f"runpod-worker-comfy - API is reachable")
print(f"comfy-modal - API is reachable")
return True
except requests.RequestException as e:
# If an exception occurs, the server may not be ready
@ -124,7 +124,7 @@ def check_server(url, retries=50, delay=500):
time.sleep(delay / 1000)
print(
f"runpod-worker-comfy - Failed to connect to server at {url} after {retries} attempts."
f"comfy-modal - Failed to connect to server at {url} after {retries} attempts."
)
return False
@ -158,20 +158,68 @@ image = Image.debian_slim()
target_image = image if deploy_test else dockerfile_image
@stub.cls(image=target_image, gpu=config["gpu"] ,volumes=volumes, timeout=60 * 10, container_idle_timeout=60)
run_timeout = config["run_timeout"]
idle_timeout = config["idle_timeout"]
import asyncio
@stub.cls(image=target_image, gpu=config["gpu"] ,volumes=volumes, timeout=60 * 10, container_idle_timeout=idle_timeout)
class ComfyDeployRunner:
machine_logs = []
async def read_stream(self, stream, isStderr):
import time
while True:
line = await stream.readline()
if line:
l = line.decode('utf-8').strip()
if l == "":
continue
if not isStderr:
print(l, flush=True)
self.machine_logs.append({
"logs": l,
"timestamp": time.time()
})
else:
# is error
# logger.error(l)
print(l, flush=True)
self.machine_logs.append({
"logs": l,
"timestamp": time.time()
})
else:
break
@enter()
def setup(self):
async def setup(self):
import subprocess
import time
# Make sure that the ComfyUI API is available
print(f"comfy-modal - check server")
command = ["python", "main.py",
"--disable-auto-launch", "--disable-metadata"]
# command = ["python", "main.py",
# "--disable-auto-launch", "--disable-metadata"]
self.server_process = subprocess.Popen(command, cwd="/comfyui")
# self.server_process = subprocess.Popen(command, cwd="/comfyui")
self.server_process = await asyncio.subprocess.create_subprocess_shell(
f"python main.py --disable-auto-launch --disable-metadata",
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
cwd="/comfyui",
# env={**os.environ, "COLUMNS": "10000"}
)
self.stdout_task = asyncio.create_task(
self.read_stream(self.server_process.stdout, False))
self.stderr_task = asyncio.create_task(
self.read_stream(self.server_process.stderr, True))
check_server(
f"http://{COMFY_HOST}",
@ -180,65 +228,110 @@ class ComfyDeployRunner:
)
@exit()
def cleanup(self, exc_type, exc_value, traceback):
self.server_process.terminate()
async def cleanup(self, exc_type, exc_value, traceback):
print(f"comfy-modal - cleanup", exc_type, exc_value, traceback)
self.stderr_task.cancel()
self.stdout_task.cancel()
# self.server_process.kill()
@method()
def run(self, input: Input):
async def run(self, input: Input):
import signal
import time
# import asyncio
self.stderr_task.cancel()
self.stdout_task.cancel()
self.stdout_task = asyncio.create_task(
self.read_stream(self.server_process.stdout, False))
self.stderr_task = asyncio.create_task(
self.read_stream(self.server_process.stderr, True))
class TimeoutError(Exception):
pass
def timeout_handler(signum, frame):
data = json.dumps({
"run_id": input.prompt_id,
"status": "timeout",
"time": datetime.now().isoformat()
}).encode('utf-8')
req = urllib.request.Request(input.status_endpoint, data=data, method='POST')
urllib.request.urlopen(req)
raise TimeoutError("Operation timed out")
signal.signal(signal.SIGALRM, timeout_handler)
try:
# Arm a SIGALRM watchdog for run_timeout seconds (from config; default 300)
signal.alarm(run_timeout)  # timeout in seconds, configurable via config["run_timeout"]
data = json.dumps({
"run_id": input.prompt_id,
"status": "started",
"time": datetime.now().isoformat()
}).encode('utf-8')
req = urllib.request.Request(input.status_endpoint, data=data, method='POST')
urllib.request.urlopen(req)
job_input = input
try:
queued_workflow = queue_workflow_comfy_deploy(job_input) # queue_workflow(workflow)
prompt_id = queued_workflow["prompt_id"]
print(f"comfy-modal - queued workflow with ID {prompt_id}")
except Exception as e:
import traceback
print(traceback.format_exc())
return {"error": f"Error queuing workflow: {str(e)}"}
# Poll for completion
print(f"comfy-modal - wait until image generation is complete")
retries = 0
status = ""
try:
print("getting request")
while retries < COMFY_POLLING_MAX_RETRIES:
status_result = check_status(prompt_id=prompt_id)
if 'status' in status_result and (status_result['status'] == 'success' or status_result['status'] == 'failed'):
status = status_result['status']
print(status)
break
else:
# Wait before trying again
time.sleep(COMFY_POLLING_INTERVAL_MS / 1000)
retries += 1
else:
return {"error": "Max retries reached while waiting for image generation"}
except Exception as e:
return {"error": f"Error waiting for image generation: {str(e)}"}
print(f"comfy-modal - Finished, turning off")
result = {"status": status}
except TimeoutError:
print("Operation timed out")
return {"status": "failed"}
print("uploading log_data")
data = json.dumps({
"run_id": input.prompt_id,
"status": "started",
"time": datetime.now().isoformat()
"time": datetime.now().isoformat(),
"log_data": json.dumps(self.machine_logs)
}).encode('utf-8')
print("my logs", len(self.machine_logs))
# Clear logs
self.machine_logs = []
req = urllib.request.Request(input.status_endpoint, data=data, method='POST')
urllib.request.urlopen(req)
job_input = input
try:
queued_workflow = queue_workflow_comfy_deploy(job_input) # queue_workflow(workflow)
prompt_id = queued_workflow["prompt_id"]
print(f"comfy-modal - queued workflow with ID {prompt_id}")
except Exception as e:
import traceback
print(traceback.format_exc())
return {"error": f"Error queuing workflow: {str(e)}"}
# Poll for completion
print(f"comfy-modal - wait until image generation is complete")
retries = 0
status = ""
try:
print("getting request")
while retries < COMFY_POLLING_MAX_RETRIES:
status_result = check_status(prompt_id=prompt_id)
# history = get_history(prompt_id)
# Exit the loop if we have found the history
# if prompt_id in history and history[prompt_id].get("outputs"):
# break
# Exit the loop if we have found the status both success or failed
if 'status' in status_result and (status_result['status'] == 'success' or status_result['status'] == 'failed'):
status = status_result['status']
print(status)
break
else:
# Wait before trying again
time.sleep(COMFY_POLLING_INTERVAL_MS / 1000)
retries += 1
else:
return {"error": "Max retries reached while waiting for image generation"}
except Exception as e:
return {"error": f"Error waiting for image generation: {str(e)}"}
print(f"comfy-modal - Finished, turning off")
result = {"status": status}
return result
@web_app.post("/run")
async def post_run(request_input: RequestInput):
if not deploy_test:
@ -252,10 +345,12 @@ async def post_run(request_input: RequestInput):
urllib.request.urlopen(req)
model = ComfyDeployRunner()
call = model.run.spawn(request_input.input)
call = await model.run.spawn.aio(request_input.input)
print("call", call)
# call = run.spawn()
return {"call_id": call.object_id}
return {"call_id": None}
return {"call_id": None}

View File

@ -4,5 +4,7 @@ config = {
"gpu": "T4",
"public_model_volume": "model-store",
"private_model_volume": "private-model-store",
"pip": []
"pip": [],
"run_timeout": 60 * 5,
"idle_timeout": 60
}

View File

@ -0,0 +1,2 @@
ALTER TYPE "workflow_run_status" ADD VALUE 'timeout';--> statement-breakpoint
ALTER TABLE "comfyui_deploy"."workflow_runs" ADD COLUMN "run_log" text;

File diff suppressed because it is too large Load Diff

View File

@ -337,6 +337,13 @@
"when": 1706384528895,
"tag": "0047_gifted_starbolt",
"breakpoints": true
},
{
"idx": 48,
"version": "5",
"when": 1706600255919,
"tag": "0048_dear_korath",
"breakpoints": true
}
]
}

View File

@ -2,10 +2,8 @@ import { parseDataSafe } from "../../../../lib/parseDataSafe";
import { db } from "@/db/db";
import {
WorkflowRunStatusSchema,
userUsageTable,
workflowRunOutputs,
workflowRunsTable,
workflowTable,
} from "@/db/schema";
import { getCurrentPlan } from "@/server/getCurrentPlan";
import { stripe } from "@/server/stripe";
@ -18,6 +16,7 @@ const Request = z.object({
status: WorkflowRunStatusSchema.optional(),
time: z.coerce.date().optional(),
output_data: z.any().optional(),
log_data: z.string().optional(),
});
export async function POST(request: Request) {
@ -26,7 +25,17 @@ export async function POST(request: Request) {
if (!data || error) return error;
const { run_id, status, time, output_data } = data;
const { run_id, status, time, output_data, log_data } = data;
if (log_data) {
// Log data was uploaded by the worker — persist it on the run record
await db
.update(workflowRunsTable)
.set({
run_log: log_data,
})
.where(eq(workflowRunsTable.id, run_id));
}
if (status == "started" && time != undefined) {
// It successfully started, update the started_at time
@ -48,6 +57,9 @@ export async function POST(request: Request) {
.where(eq(workflowRunsTable.id, run_id));
}
const ended =
status === "success" || status === "failed" || status === "timeout";
if (output_data) {
const workflow_run_output = await db.insert(workflowRunOutputs).values({
run_id: run_id,
@ -58,8 +70,7 @@ export async function POST(request: Request) {
.update(workflowRunsTable)
.set({
status: status,
ended_at:
status === "success" || status === "failed" ? new Date() : null,
ended_at: ended ? new Date() : null,
})
.where(eq(workflowRunsTable.id, run_id))
.returning();
@ -67,10 +78,7 @@ export async function POST(request: Request) {
// Need to filter out only comfy deploy serverless
// Also multiply with the gpu selection
if (workflow_run.machine_type == "comfy-deploy-serverless") {
if (
(status === "success" || status === "failed") &&
workflow_run.user_id
) {
if (ended && workflow_run.user_id) {
const sub = await getCurrentPlan({
user_id: workflow_run.user_id,
org_id: workflow_run.org_id,

View File

@ -16,7 +16,7 @@ export function LiveStatus({
(state) =>
state.data
.filter((x) => x.id === run.id)
.sort((a, b) => b.timestamp - a.timestamp)?.[0]
.sort((a, b) => b.timestamp - a.timestamp)?.[0],
);
let status = run.status;
@ -51,7 +51,9 @@ export function LiveStatus({
<>
<TableCell>
{data && status != "success"
? `${data.json.event} - ${data.json.data.node}`
? `${data.json.event}${
data.json.data.node ? " - " + data.json.data.node : ""
}`
: "-"}
</TableCell>
<TableCell className="truncate text-right">

View File

@ -19,6 +19,7 @@ import { getDuration, getRelativeTime } from "@/lib/getRelativeTime";
import { type findAllRuns } from "@/server/findAllRuns";
import { Suspense } from "react";
import { LiveStatus } from "./LiveStatus";
import { LogsType, LogsViewer } from "@/components/LogsViewer";
export async function RunDisplay({
run,
@ -75,6 +76,9 @@ export async function RunDisplay({
<RunInputs run={run} />
<Suspense>
<RunOutputs run_id={run.id} />
{run.run_log && (
<LogsViewer logs={JSON.parse(run.run_log) as LogsType} />
)}
</Suspense>
</div>
{/* <div className="max-h-96 overflow-y-scroll">{view}</div> */}

View File

@ -17,6 +17,8 @@ export function StatusBadge({
);
case "success":
return <Badge variant="success">{status}</Badge>;
case "timeout":
return <Badge variant="amber">{status}</Badge>;
case "failed":
return <Badge variant="destructive">{status}</Badge>;
}

View File

@ -104,6 +104,7 @@ export const workflowRunStatus = pgEnum("workflow_run_status", [
"failed",
"started",
"queued",
"timeout",
]);
export const deploymentEnvironment = pgEnum("deployment_environment", [
@ -172,6 +173,7 @@ export const workflowRunsTable = dbSchema.table("workflow_runs", {
machine_type: machinesType("machine_type"),
user_id: text("user_id"),
org_id: text("org_id"),
run_log: text("run_log"),
});
export const workflowRunRelations = relations(
@ -386,12 +388,7 @@ export const modelUploadType = pgEnum("model_upload_type", [
]);
// https://www.answeroverflow.com/m/1125106227387584552
export const modelTypes = [
"checkpoint",
"lora",
"embedding",
"vae",
] as const
export const modelTypes = ["checkpoint", "lora", "embedding", "vae"] as const;
export const modelType = pgEnum("model_type", modelTypes);
export type modelEnumType = (typeof modelTypes)[number];