feat: revamp cold start time counter

This commit is contained in:
bennykok 2024-01-27 14:21:59 +08:00
parent 72325a4217
commit c7727fc1be
8 changed files with 1456 additions and 153 deletions

View File

@ -0,0 +1,3 @@
ALTER TYPE "workflow_run_status" ADD VALUE 'started';--> statement-breakpoint
ALTER TYPE "workflow_run_status" ADD VALUE 'queued';--> statement-breakpoint
ALTER TABLE "comfyui_deploy"."workflow_runs" ADD COLUMN "queued_at" timestamp;

File diff suppressed because it is too large Load Diff

View File

@ -316,6 +316,13 @@
"when": 1706317908300,
"tag": "0044_panoramic_mister_fear",
"breakpoints": true
},
{
"idx": 45,
"version": "5",
"when": 1706336448134,
"tag": "0045_careful_cerise",
"breakpoints": true
}
]
}

View File

@ -1,6 +1,7 @@
import { parseDataSafe } from "../../../../lib/parseDataSafe";
import { db } from "@/db/db";
import {
WorkflowRunStatusSchema,
userUsageTable,
workflowRunOutputs,
workflowRunsTable,
@ -14,9 +15,8 @@ import { z } from "zod";
const Request = z.object({
run_id: z.string(),
status: z
.enum(["not-started", "running", "uploading", "success", "failed"])
.optional(),
status: WorkflowRunStatusSchema.optional(),
time: z.date().optional(),
output_data: z.any().optional(),
});
@ -24,9 +24,27 @@ export async function POST(request: Request) {
const [data, error] = await parseDataSafe(Request, request);
if (!data || error) return error;
const { run_id, status, output_data } = data;
const { run_id, status, time, output_data } = data;
// console.log(run_id, status, output_data);
if (status == "started" && time != undefined) {
// It successfully started, update the started_at time
await db
.update(workflowRunsTable)
.set({
started_at: time,
})
.where(eq(workflowRunsTable.id, run_id));
}
if (status == "queued" && time != undefined) {
// It successfully started, update the started_at time
await db
.update(workflowRunsTable)
.set({
queued_at: time,
})
.where(eq(workflowRunsTable.id, run_id));
}
if (output_data) {
const workflow_run_output = await db.insert(workflowRunOutputs).values({
@ -82,12 +100,6 @@ export async function POST(request: Request) {
}
}
// const workflow_version = await db.query.workflowVersionTable.findFirst({
// where: eq(workflowRunsTable.id, workflow_run[0].workflow_version_id),
// });
// revalidatePath(`./${workflow_version?.workflow_id}`);
return NextResponse.json(
{
message: "success",

View File

@ -2,18 +2,18 @@ import { RunInputs } from "@/components/RunInputs";
import { RunOutputs } from "@/components/RunOutputs";
import { Badge } from "@/components/ui/badge";
import {
Dialog,
DialogContent,
DialogDescription,
DialogHeader,
DialogTitle,
DialogTrigger,
Dialog,
DialogContent,
DialogDescription,
DialogHeader,
DialogTitle,
DialogTrigger,
} from "@/components/ui/dialog";
import { TableCell, TableRow } from "@/components/ui/table";
import {
Tooltip,
TooltipContent,
TooltipTrigger,
Tooltip,
TooltipContent,
TooltipTrigger,
} from "@/components/ui/tooltip";
import { getDuration, getRelativeTime } from "@/lib/getRelativeTime";
import { type findAllRuns } from "@/server/findAllRuns";
@ -21,54 +21,54 @@ import { Suspense } from "react";
import { LiveStatus } from "./LiveStatus";
export async function RunDisplay({
run,
run,
}: {
run: Awaited<ReturnType<typeof findAllRuns>>[0];
run: Awaited<ReturnType<typeof findAllRuns>>[0];
}) {
return (
<Dialog>
<DialogTrigger asChild className="appearance-none hover:cursor-pointer">
<TableRow>
<TableCell>{run.number}</TableCell>
<TableCell className="font-medium truncate">
{run.machine?.name}
</TableCell>
<TableCell className="truncate">
{getRelativeTime(run.created_at)}
</TableCell>
<TableCell>{run.version?.version}</TableCell>
<TableCell>
<Badge variant="outline" className="truncate">
{run.origin}
</Badge>
</TableCell>
<TableCell className="truncate">
<Tooltip>
<TooltipTrigger>{getDuration(run.duration)}</TooltipTrigger>
<TooltipContent>
<div>Cold start: {getDuration(run.cold_start_duration)}</div>
<div>Run duration: {getDuration(run.run_duration)}</div>
</TooltipContent>
</Tooltip>
</TableCell>
<LiveStatus run={run} />
</TableRow>
</DialogTrigger>
<DialogContent className="max-w-3xl">
<DialogHeader>
<DialogTitle>Run outputs</DialogTitle>
<DialogDescription>
You can view your run&apos;s outputs here
</DialogDescription>
</DialogHeader>
<div className="max-h-96 overflow-y-scroll">
<RunInputs run={run} />
<Suspense>
<RunOutputs run_id={run.id} />
</Suspense>
</div>
{/* <div className="max-h-96 overflow-y-scroll">{view}</div> */}
</DialogContent>
</Dialog>
);
return (
<Dialog>
<DialogTrigger asChild className="appearance-none hover:cursor-pointer">
<TableRow>
<TableCell>{run.number}</TableCell>
<TableCell className="font-medium truncate">
{run.machine?.name}
</TableCell>
<TableCell className="truncate">
{getRelativeTime(run.created_at)}
</TableCell>
<TableCell>{run.version?.version}</TableCell>
<TableCell>
<Badge variant="outline" className="truncate">
{run.origin}
</Badge>
</TableCell>
<TableCell className="truncate">
<Tooltip>
<TooltipTrigger>{getDuration(run.duration)}</TooltipTrigger>
<TooltipContent>
<div>Cold start: {getDuration(run.cold_start_duration)}</div>
<div>Run duration: {getDuration(run.run_duration)}</div>
</TooltipContent>
</Tooltip>
</TableCell>
<LiveStatus run={run} />
</TableRow>
</DialogTrigger>
<DialogContent className="max-w-3xl">
<DialogHeader>
<DialogTitle>Run outputs</DialogTitle>
<DialogDescription>
You can view your run&apos;s outputs here
</DialogDescription>
</DialogHeader>
<div className="max-h-96 overflow-y-scroll">
<RunInputs run={run} />
<Suspense>
<RunOutputs run_id={run.id} />
</Suspense>
</div>
{/* <div className="max-h-96 overflow-y-scroll">{view}</div> */}
</DialogContent>
</Dialog>
);
}

View File

@ -93,7 +93,7 @@ export const workflowVersionRelations = relations(
fields: [workflowVersionTable.workflow_id],
references: [workflowTable.id],
}),
})
}),
);
export const workflowRunStatus = pgEnum("workflow_run_status", [
@ -102,6 +102,8 @@ export const workflowRunStatus = pgEnum("workflow_run_status", [
"uploading",
"success",
"failed",
"started",
"queued",
]);
export const deploymentEnvironment = pgEnum("deployment_environment", [
@ -116,6 +118,8 @@ export const workflowRunOrigin = pgEnum("workflow_run_origin", [
"public-share",
]);
export const WorkflowRunStatusSchema = z.enum(workflowRunStatus.enumValues);
export const WorkflowRunOriginSchema = z.enum(workflowRunOrigin.enumValues);
export type WorkflowRunOriginType = z.infer<typeof WorkflowRunOriginSchema>;
@ -142,7 +146,7 @@ export const workflowRunsTable = dbSchema.table("workflow_runs", {
() => workflowVersionTable.id,
{
onDelete: "set null",
}
},
),
workflow_inputs:
jsonb("workflow_inputs").$type<Record<string, string | number>>(),
@ -158,7 +162,11 @@ export const workflowRunsTable = dbSchema.table("workflow_runs", {
origin: workflowRunOrigin("origin").notNull().default("api"),
status: workflowRunStatus("status").notNull().default("not-started"),
ended_at: timestamp("ended_at"),
// comfy deploy run created time
created_at: timestamp("created_at").defaultNow().notNull(),
// modal gpu cold start begin
queued_at: timestamp("queued_at"),
// modal gpu function actual start time
started_at: timestamp("started_at"),
gpu: machineGPUOptions("gpu"),
machine_type: machinesType("machine_type"),
@ -182,7 +190,7 @@ export const workflowRunRelations = relations(
fields: [workflowRunsTable.workflow_id],
references: [workflowTable.id],
}),
})
}),
);
// We still want to keep the workflow run record.
@ -206,7 +214,7 @@ export const workflowOutputRelations = relations(
fields: [workflowRunOutputs.run_id],
references: [workflowRunsTable.id],
}),
})
}),
);
// when user delete, also delete all the workflow versions
@ -239,7 +247,7 @@ export const snapshotType = z.object({
z.object({
hash: z.string(),
disabled: z.boolean(),
})
}),
),
file_custom_nodes: z.array(z.any()),
});
@ -254,7 +262,7 @@ export const showcaseMedia = z.array(
z.object({
url: z.string(),
isCover: z.boolean().default(false),
})
}),
);
export const showcaseMediaNullable = z
@ -262,7 +270,7 @@ export const showcaseMediaNullable = z
z.object({
url: z.string(),
isCover: z.boolean().default(false),
})
}),
)
.nullable();
@ -376,15 +384,10 @@ export const modelUploadType = pgEnum("model_upload_type", [
"other",
]);
// https://www.answeroverflow.com/m/1125106227387584552
const modelTypes = [
"checkpoint",
"lora",
"embedding",
"vae",
] as const
// https://www.answeroverflow.com/m/1125106227387584552
const modelTypes = ["checkpoint", "lora", "embedding", "vae"] as const;
export const modelType = pgEnum("model_type", modelTypes);
export type modelEnumType = typeof modelTypes[number]
export type modelEnumType = (typeof modelTypes)[number];
export const modelTable = dbSchema.table("models", {
id: uuid("id").primaryKey().defaultRandom().notNull(),
@ -447,16 +450,13 @@ export const modelRelations = relations(modelTable, ({ one }) => ({
}),
}));
export const modalVolumeRelations = relations(
userVolume,
({ many, one }) => ({
model: many(modelTable),
user: one(usersTable, {
fields: [userVolume.user_id],
references: [usersTable.id],
}),
})
);
export const modalVolumeRelations = relations(userVolume, ({ many, one }) => ({
model: many(modelTable),
user: one(usersTable, {
fields: [userVolume.user_id],
references: [usersTable.id],
}),
}));
export const subscriptionPlan = pgEnum("subscription_plan", [
"basic",
@ -484,18 +484,15 @@ export const subscriptionStatusTable = dbSchema.table("subscription_status", {
updated_at: timestamp("updated_at").defaultNow().notNull(),
});
export const insertCivitaiModelSchema = createInsertSchema(
modelTable,
{
civitai_url: (schema) =>
schema.civitai_url
.trim()
.url({ message: "URL required" })
.includes("civitai.com/models", {
message: "civitai.com/models link required",
}),
}
);
export const insertCivitaiModelSchema = createInsertSchema(modelTable, {
civitai_url: (schema) =>
schema.civitai_url
.trim()
.url({ message: "URL required" })
.includes("civitai.com/models", {
message: "civitai.com/models link required",
}),
});
export type UserType = InferSelectModel<typeof usersTable>;
export type WorkflowType = InferSelectModel<typeof workflowTable>;
@ -503,7 +500,5 @@ export type MachineType = InferSelectModel<typeof machinesTable>;
export type WorkflowVersionType = InferSelectModel<typeof workflowVersionTable>;
export type DeploymentType = InferSelectModel<typeof deploymentsTable>;
export type ModelType = InferSelectModel<typeof modelTable>;
export type UserVolumeType = InferSelectModel<
typeof userVolume
>;
export type UserVolumeType = InferSelectModel<typeof userVolume>;
export type UserUsageType = InferSelectModel<typeof userUsageTable>;

View File

@ -229,15 +229,6 @@ export const createRun = withServerPromise(
throw e;
}
// It successfully started, update the started_at time
await db
.update(workflowRunsTable)
.set({
started_at: new Date(),
})
.where(eq(workflowRunsTable.id, workflow_run[0].id));
return {
workflow_run_id: workflow_run[0].id,
message: "Successful workflow run",

View File

@ -16,42 +16,40 @@ export async function findAllRuns({
offset = 0,
}: RunsSearchTypes) {
return await db.query.workflowRunsTable.findMany({
where: eq(workflowRunsTable.workflow_id, workflow_id),
orderBy: desc(workflowRunsTable.created_at),
offset: offset,
limit: limit,
extras: {
number: sql<number>`row_number() over (order by created_at)`.as(
"number",
),
total: sql<number>`count(*) over ()`.as("total"),
duration:
sql<number>`(extract(epoch from ended_at) - extract(epoch from created_at))`.as(
"duration",
),
cold_start_duration:
sql<number>`(extract(epoch from started_at) - extract(epoch from created_at))`.as(
"cold_start_duration",
),
run_duration:
sql<number>`(extract(epoch from ended_at) - extract(epoch from started_at))`.as(
"run_duration",
),
},
with: {
machine: {
columns: {
name: true,
endpoint: true,
},
},
version: {
columns: {
version: true,
},
},
},
});
where: eq(workflowRunsTable.workflow_id, workflow_id),
orderBy: desc(workflowRunsTable.created_at),
offset: offset,
limit: limit,
extras: {
number: sql<number>`row_number() over (order by created_at)`.as("number"),
total: sql<number>`count(*) over ()`.as("total"),
duration:
sql<number>`(extract(epoch from ended_at) - extract(epoch from created_at))`.as(
"duration",
),
cold_start_duration:
sql<number>`(extract(epoch from started_at) - extract(epoch from queued_at))`.as(
"cold_start_duration",
),
run_duration:
sql<number>`(extract(epoch from ended_at) - extract(epoch from started_at))`.as(
"run_duration",
),
},
with: {
machine: {
columns: {
name: true,
endpoint: true,
},
},
version: {
columns: {
version: true,
},
},
},
});
}
export async function findAllRunsWithCounts(props: RunsSearchTypes) {