feat: display runs and create run and update run endpoint

This commit is contained in:
BennyKok 2023-12-10 16:06:20 +08:00
parent 6a1c1d0ff5
commit 9c8a518c46
11 changed files with 335 additions and 219 deletions

112
routes.py
View File

@ -15,13 +15,14 @@ import execution
import random import random
import uuid import uuid
import websockets
import asyncio import asyncio
import atexit import atexit
import logging import logging
from enum import Enum
api = None api = None
api_task = None api_task = None
prompt_metadata = {}
load_dotenv() load_dotenv()
@ -73,20 +74,15 @@ def randomSeed(num_digits=15):
@server.PromptServer.instance.routes.post("/comfy-deploy/run") @server.PromptServer.instance.routes.post("/comfy-deploy/run")
async def comfy_deploy_run(request): async def comfy_deploy_run(request):
print("hi")
prompt_server = server.PromptServer.instance prompt_server = server.PromptServer.instance
data = await request.json() data = await request.json()
for key in data:
if 'inputs' in data[key] and 'seed' in data[key]['inputs']:
data[key]['inputs']['seed'] = randomSeed()
if api is None:
connect_to_websocket()
while api.client_id is None:
await asyncio.sleep(0.1)
workflow_api = data.get("workflow_api") workflow_api = data.get("workflow_api")
# print(workflow_api)
for key in workflow_api:
if 'inputs' in workflow_api[key] and 'seed' in workflow_api[key]['inputs']:
workflow_api[key]['inputs']['seed'] = randomSeed()
prompt = { prompt = {
"prompt": workflow_api, "prompt": workflow_api,
@ -95,7 +91,9 @@ async def comfy_deploy_run(request):
res = post_prompt(prompt) res = post_prompt(prompt)
# print(prompt) prompt_metadata[res['prompt_id']] = {
'status_endpoint': data.get('status_endpoint'),
}
status = 200 status = 200
if "error" in res: if "error" in res:
@ -105,69 +103,41 @@ async def comfy_deploy_run(request):
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
class ComfyApi:
def __init__(self):
self.websocket = None
self.client_id = None
async def connect(self, uri):
self.websocket = await websockets.connect(uri)
# Event listeners
await self.on_open()
await self.on_message()
await self.on_close()
async def close(self):
await self.websocket.close()
async def on_open(self):
print("Connection opened")
async def on_message(self):
async for message in self.websocket:
if isinstance(message, bytes):
print("Received binary message, skipping...")
continue # skip to the next message
logging.info(f"Received message: {message}")
try:
message_data = json.loads(message)
msg_type = message_data["type"]
if msg_type == "status" and message_data["data"]["sid"] is not None:
self.client_id = message_data["data"]["sid"]
logging.info(f"Received client_id: {self.client_id}")
except json.JSONDecodeError:
logging.info(f"Failed to parse message as JSON: {message}")
async def on_close(self):
print("Connection closed")
async def run(self, uri):
await self.connect(uri)
def connect_to_websocket():
global api, api_task
api = ComfyApi()
api_task = asyncio.create_task(api.run('ws://localhost:8188/ws'))
prompt_server = server.PromptServer.instance prompt_server = server.PromptServer.instance
send_json = prompt_server.send_json send_json = prompt_server.send_json
async def send_json_override(self, event, data, sid=None): async def send_json_override(self, event, data, sid=None):
print("INTERNAL:", event, data, sid)
prompt_id = data.get('prompt_id')
if event == 'execution_start':
update_run(prompt_id, Status.RUNNING)
# if event == 'executing':
# update_run(prompt_id, Status.RUNNING)
if event == 'executed':
update_run(prompt_id, Status.SUCCESS)
await self.send_json_original(event, data, sid) await self.send_json_original(event, data, sid)
print("Sending event:", sid, event, data)
class Status(Enum):
NOT_STARTED = "not-started"
RUNNING = "running"
SUCCESS = "success"
FAILED = "failed"
def update_run(prompt_id, status: Status):
if prompt_id in prompt_metadata and ('status' not in prompt_metadata[prompt_id] or prompt_metadata[prompt_id]['status'] != status):
status_endpoint = prompt_metadata[prompt_id]['status_endpoint']
body = {
"run_id": prompt_id,
"status": status.value,
}
prompt_metadata[prompt_id]['status'] = status
requests.post(status_endpoint, json=body)
prompt_server.send_json_original = prompt_server.send_json prompt_server.send_json_original = prompt_server.send_json
prompt_server.send_json = send_json_override.__get__(prompt_server, server.PromptServer) prompt_server.send_json = send_json_override.__get__(prompt_server, server.PromptServer)
@atexit.register
def close_websocket():
print("Got close_websocket")
global api, api_task
if api_task:
api_task.cancel()

View File

@ -1,4 +1,8 @@
/** @type {import('next').NextConfig} */ /** @type {import('next').NextConfig} */
const nextConfig = {} const nextConfig = {
eslint: {
ignoreDuringBuilds: true,
},
};
module.exports = nextConfig module.exports = nextConfig;

View File

@ -1,5 +1,8 @@
import { MachineSelect, VersionSelect } from "@/components/VersionSelect"; import {
import { Button } from "@/components/ui/button"; MachineSelect,
RunWorkflowButton,
VersionSelect,
} from "@/components/VersionSelect";
import { import {
Card, Card,
CardContent, CardContent,
@ -7,12 +10,24 @@ import {
CardHeader, CardHeader,
CardTitle, CardTitle,
} from "@/components/ui/card"; } from "@/components/ui/card";
import {
Table,
TableBody,
TableCaption,
TableCell,
TableHead,
TableHeader,
TableRow,
} from "@/components/ui/table";
import { db } from "@/db/db"; import { db } from "@/db/db";
import { workflowTable, workflowVersionTable } from "@/db/schema"; import {
workflowRunsTable,
workflowTable,
workflowVersionTable,
} from "@/db/schema";
import { getRelativeTime } from "@/lib/getRelativeTime"; import { getRelativeTime } from "@/lib/getRelativeTime";
import { getMachines } from "@/server/curdMachine"; import { getMachines } from "@/server/curdMachine";
import { desc, eq } from "drizzle-orm"; import { desc, eq } from "drizzle-orm";
import { Play } from "lucide-react";
export async function findFirstTableWithVersion(workflow_id: string) { export async function findFirstTableWithVersion(workflow_id: string) {
return await db.query.workflowTable.findFirst({ return await db.query.workflowTable.findFirst({
@ -21,6 +36,32 @@ export async function findFirstTableWithVersion(workflow_id: string) {
}); });
} }
export async function findAllRuns(workflow_id: string) {
const workflowVersion = await db.query.workflowVersionTable.findFirst({
where: eq(workflowVersionTable.workflow_id, workflow_id),
});
if (!workflowVersion) {
return [];
}
return await db.query.workflowRunsTable.findMany({
where: eq(workflowRunsTable.workflow_version_id, workflowVersion?.id),
with: {
machine: {
columns: {
name: true,
},
},
version: {
columns: {
version: true,
},
},
},
});
}
export default async function Page({ export default async function Page({
params, params,
}: { }: {
@ -45,9 +86,7 @@ export default async function Page({
<div className="flex gap-2 "> <div className="flex gap-2 ">
<VersionSelect workflow={workflow} /> <VersionSelect workflow={workflow} />
<MachineSelect machines={machines} /> <MachineSelect machines={machines} />
<Button className="gap-2"> <RunWorkflowButton workflow={workflow} machines={machines} />
Run <Play size={14} />
</Button>
</div> </div>
</CardContent> </CardContent>
</Card> </Card>
@ -57,8 +96,37 @@ export default async function Page({
<CardTitle>Run</CardTitle> <CardTitle>Run</CardTitle>
</CardHeader> </CardHeader>
<CardContent /> <CardContent>
<RunsTable workflow_id={workflow_id} />
</CardContent>
</Card> </Card>
</div> </div>
); );
} }
async function RunsTable(props: { workflow_id: string }) {
const allRuns = await findAllRuns(props.workflow_id);
return (
<Table>
<TableCaption>A list of your recent runs.</TableCaption>
<TableHeader>
<TableRow>
<TableHead className="w-[100px]">Version</TableHead>
<TableHead>Machine</TableHead>
<TableHead>Time</TableHead>
<TableHead className="text-right">Status</TableHead>
</TableRow>
</TableHeader>
<TableBody>
{allRuns.map((run) => (
<TableRow key={run.id}>
<TableCell>{run.version.version}</TableCell>
<TableCell className="font-medium">{run.machine.name}</TableCell>
<TableCell>{getRelativeTime(run.created_at)}</TableCell>
<TableCell className="text-right">{run.status}</TableCell>
</TableRow>
))}
</TableBody>
</Table>
);
}

View File

@ -1,56 +1,106 @@
import { parseDataSafe } from "../../../lib/parseDataSafe"; import { parseDataSafe } from "../../../lib/parseDataSafe";
import { db } from "@/db/db"; import { db } from "@/db/db";
import { import { workflowRunsTable } from "@/db/schema";
workflowRunStatus, import { eq } from "drizzle-orm";
workflowRunsTable, import { revalidatePath } from "next/cache";
workflowTable,
workflowVersionTable,
} from "@/db/schema";
import { eq, sql } from "drizzle-orm";
import { NextResponse } from "next/server"; import { NextResponse } from "next/server";
import { ZodFormattedError, z } from "zod"; import { z } from "zod";
const Request = z.object({ const Request = z.object({
workflow_version_id: z.string(), workflow_version_id: z.string(),
// workflow_version: z.number().optional(),
machine_id: z.string(), machine_id: z.string(),
}); });
export async function OPTIONS(request: Request) { const ComfyAPI_Run = z.object({
return new Response(null, { prompt_id: z.string(),
status: 204, number: z.number(),
headers: { node_errors: z.any(),
"Access-Control-Allow-Origin": "*", });
"Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type, Authorization",
},
});
}
export async function POST(request: Request) { export async function POST(request: Request) {
const [data, error] = await parseDataSafe(Request, request); const [data, error] = await parseDataSafe(Request, request);
if (!data || error) return error; if (!data || error) return error;
let { workflow_version_id, machine_id } = data; const origin = new URL(request.url).origin;
const { workflow_version_id, machine_id } = data;
const machine = await db.query.machinesTable.findFirst({
where: eq(workflowRunsTable.id, machine_id),
});
if (!machine) {
return new Response("Machine not found", {
status: 404,
});
}
const workflow_version_data =
// workflow_version_id
// ?
await db.query.workflowVersionTable.findFirst({
where: eq(workflowRunsTable.id, workflow_version_id),
});
// : workflow_version != undefined
// ? await db.query.workflowVersionTable.findFirst({
// where: and(
// eq(workflowVersionTable.version, workflow_version),
// eq(workflowVersionTable.workflow_id)
// ),
// })
// : null;
if (!workflow_version_data) {
return new Response("Workflow version not found", {
status: 404,
});
}
const comfyui_endpoint = `${machine.endpoint}/comfy-deploy/run`;
// Sending to comfyui
const result = await fetch(comfyui_endpoint, {
method: "POST",
// headers: {
// "Content-Type": "application/json",
// },
body: JSON.stringify({
workflow_api: workflow_version_data.workflow_api,
status_endpoint: `${origin}/api/update-run`,
}),
})
.then(async (res) => ComfyAPI_Run.parseAsync(await res.json()))
.catch((error) => {
console.error(error);
return new Response(error.details, {
status: 500,
});
});
// return the error
if (result instanceof Response) {
return result;
}
// Add to our db
const workflow_run = await db const workflow_run = await db
.insert(workflowRunsTable) .insert(workflowRunsTable)
.values({ .values({
workflow_version_id, id: result.prompt_id,
workflow_version_id: workflow_version_data.id,
machine_id, machine_id,
}) })
.returning(); .returning();
revalidatePath(`./${workflow_version_data.workflow_id}`);
return NextResponse.json( return NextResponse.json(
{ {
workflow_run_id: workflow_run[0].id, workflow_run_id: workflow_run[0].id,
}, },
{ {
status: 200, status: 200,
headers: { }
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type, Authorization",
},
},
); );
} }

View File

@ -2,6 +2,7 @@ import { parseDataSafe } from "../../../lib/parseDataSafe";
import { db } from "@/db/db"; import { db } from "@/db/db";
import { workflowRunsTable } from "@/db/schema"; import { workflowRunsTable } from "@/db/schema";
import { eq } from "drizzle-orm"; import { eq } from "drizzle-orm";
import { revalidatePath } from "next/cache";
import { NextResponse } from "next/server"; import { NextResponse } from "next/server";
import { z } from "zod"; import { z } from "zod";
@ -10,17 +11,6 @@ const Request = z.object({
status: z.enum(["not-started", "running", "success", "failed"]), status: z.enum(["not-started", "running", "success", "failed"]),
}); });
export async function OPTIONS(request: Request) {
return new Response(null, {
status: 204,
headers: {
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type, Authorization",
},
});
}
export async function POST(request: Request) { export async function POST(request: Request) {
const [data, error] = await parseDataSafe(Request, request); const [data, error] = await parseDataSafe(Request, request);
if (!data || error) return error; if (!data || error) return error;
@ -32,7 +22,14 @@ export async function POST(request: Request) {
.set({ .set({
status: status, status: status,
}) })
.where(eq(workflowRunsTable.id, run_id)); .where(eq(workflowRunsTable.id, run_id))
.returning();
const workflow_version = await db.query.workflowVersionTable.findFirst({
where: eq(workflowRunsTable.id, workflow_run[0].workflow_version_id),
});
revalidatePath(`./${workflow_version?.workflow_id}`);
return NextResponse.json( return NextResponse.json(
{ {
@ -40,11 +37,6 @@ export async function POST(request: Request) {
}, },
{ {
status: 200, status: 200,
headers: {
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type, Authorization",
},
} }
); );
} }

View File

@ -2,7 +2,7 @@ import { WorkflowList } from "@/components/WorkflowList";
import { db } from "@/db/db"; import { db } from "@/db/db";
import { usersTable, workflowTable, workflowVersionTable } from "@/db/schema"; import { usersTable, workflowTable, workflowVersionTable } from "@/db/schema";
import { auth, clerkClient } from "@clerk/nextjs"; import { auth, clerkClient } from "@clerk/nextjs";
import { desc, eq, sql } from "drizzle-orm"; import { desc, eq } from "drizzle-orm";
export default function Home() { export default function Home() {
return <WorkflowServer />; return <WorkflowServer />;

View File

@ -0,0 +1,7 @@
"use client";
import { LoaderIcon } from "lucide-react";
import * as React from "react";
export function LoadingIcon() {
return <LoaderIcon size={14} className="ml-2 animate-spin" />;
}

View File

@ -1,9 +1,9 @@
"use client"; "use client";
import { getRelativeTime } from "../lib/getRelativeTime"; import { getRelativeTime } from "../lib/getRelativeTime";
import { LoadingIcon } from "./LoadingIcon";
import { import {
FormControl, FormControl,
FormDescription,
FormField, FormField,
FormItem, FormItem,
FormLabel, FormLabel,
@ -23,15 +23,12 @@ import {
} from "@/components/ui/dialog"; } from "@/components/ui/dialog";
import { import {
DropdownMenu, DropdownMenu,
DropdownMenuCheckboxItem,
DropdownMenuContent, DropdownMenuContent,
DropdownMenuItem, DropdownMenuItem,
DropdownMenuLabel, DropdownMenuLabel,
DropdownMenuSeparator,
DropdownMenuTrigger, DropdownMenuTrigger,
} from "@/components/ui/dropdown-menu"; } from "@/components/ui/dropdown-menu";
import { Input } from "@/components/ui/input"; import { Input } from "@/components/ui/input";
import { Label } from "@/components/ui/label";
import { import {
Table, Table,
TableBody, TableBody,
@ -41,7 +38,6 @@ import {
TableRow, TableRow,
} from "@/components/ui/table"; } from "@/components/ui/table";
import { addMachine, deleteMachine } from "@/server/curdMachine"; import { addMachine, deleteMachine } from "@/server/curdMachine";
import { deleteWorkflow } from "@/server/deleteWorkflow";
import { zodResolver } from "@hookform/resolvers/zod"; import { zodResolver } from "@hookform/resolvers/zod";
import type { import type {
ColumnDef, ColumnDef,
@ -57,14 +53,8 @@ import {
getSortedRowModel, getSortedRowModel,
useReactTable, useReactTable,
} from "@tanstack/react-table"; } from "@tanstack/react-table";
import { import { ArrowUpDown, MoreHorizontal } from "lucide-react";
ArrowUpDown,
ChevronDown,
LoaderIcon,
MoreHorizontal,
} from "lucide-react";
import * as React from "react"; import * as React from "react";
import { useFormStatus } from "react-dom";
import { useForm } from "react-hook-form"; import { useForm } from "react-hook-form";
import { z } from "zod"; import { z } from "zod";
@ -145,7 +135,9 @@ export const columns: ColumnDef<Machine>[] = [
); );
}, },
cell: ({ row }) => ( cell: ({ row }) => (
<div className="capitalize text-right">{getRelativeTime(row.original.date)}</div> <div className="capitalize text-right">
{getRelativeTime(row.original.date)}
</div>
), ),
}, },
@ -187,7 +179,7 @@ export const columns: ColumnDef<Machine>[] = [
export function MachineList({ data }: { data: Machine[] }) { export function MachineList({ data }: { data: Machine[] }) {
const [sorting, setSorting] = React.useState<SortingState>([]); const [sorting, setSorting] = React.useState<SortingState>([]);
const [columnFilters, setColumnFilters] = React.useState<ColumnFiltersState>( const [columnFilters, setColumnFilters] = React.useState<ColumnFiltersState>(
[], []
); );
const [columnVisibility, setColumnVisibility] = const [columnVisibility, setColumnVisibility] =
React.useState<VisibilityState>({}); React.useState<VisibilityState>({});
@ -265,7 +257,7 @@ export function MachineList({ data }: { data: Machine[] }) {
? null ? null
: flexRender( : flexRender(
header.column.columnDef.header, header.column.columnDef.header,
header.getContext(), header.getContext()
)} )}
</TableHead> </TableHead>
); );
@ -284,7 +276,7 @@ export function MachineList({ data }: { data: Machine[] }) {
<TableCell key={cell.id}> <TableCell key={cell.id}>
{flexRender( {flexRender(
cell.column.columnDef.cell, cell.column.columnDef.cell,
cell.getContext(), cell.getContext()
)} )}
</TableCell> </TableCell>
))} ))}
@ -418,8 +410,7 @@ function AddWorkflowButton({ pending }: { pending: boolean }) {
// const { pending } = useFormStatus(); // const { pending } = useFormStatus();
return ( return (
<Button type="submit" disabled={pending}> <Button type="submit" disabled={pending}>
Save changes{" "} Save changes {pending && <LoadingIcon />}
{pending && <LoaderIcon size={14} className="ml-2 animate-spin" />}
</Button> </Button>
); );
} }

View File

@ -1,6 +1,8 @@
"use client"; "use client";
import { findFirstTableWithVersion } from "@/app/[workflow_id]/page"; import type { findFirstTableWithVersion } from "@/app/[workflow_id]/page";
import { LoadingIcon } from "@/components/LoadingIcon";
import { Button } from "@/components/ui/button";
import { import {
Select, Select,
SelectContent, SelectContent,
@ -10,15 +12,26 @@ import {
SelectTrigger, SelectTrigger,
SelectValue, SelectValue,
} from "@/components/ui/select"; } from "@/components/ui/select";
import { getMachines } from "@/server/curdMachine"; import type { getMachines } from "@/server/curdMachine";
import { Play } from "lucide-react";
import { parseAsInteger, useQueryState } from "next-usequerystate";
import { useState } from "react";
export function VersionSelect({ export function VersionSelect({
workflow, workflow,
}: { }: {
workflow: Awaited<ReturnType<typeof findFirstTableWithVersion>>; workflow: Awaited<ReturnType<typeof findFirstTableWithVersion>>;
}) { }) {
const [version, setVersion] = useQueryState("version", {
defaultValue: workflow?.versions[0].version?.toString() ?? "",
});
return ( return (
<Select defaultValue={workflow?.versions[0].version?.toString()}> <Select
value={version}
onValueChange={(v) => {
setVersion(v);
}}
>
<SelectTrigger className="w-[180px]"> <SelectTrigger className="w-[180px]">
<SelectValue placeholder="Select a version" /> <SelectValue placeholder="Select a version" />
</SelectTrigger> </SelectTrigger>
@ -26,7 +39,7 @@ export function VersionSelect({
<SelectGroup> <SelectGroup>
<SelectLabel>Versions</SelectLabel> <SelectLabel>Versions</SelectLabel>
{workflow?.versions.map((x) => ( {workflow?.versions.map((x) => (
<SelectItem value={x.version?.toString() ?? ""}> <SelectItem key={x.id} value={x.version?.toString() ?? ""}>
{x.version} {x.version}
</SelectItem> </SelectItem>
))} ))}
@ -36,28 +49,76 @@ export function VersionSelect({
); );
} }
export function MachineSelect({ export function MachineSelect({
machines, machines,
}: { }: {
machines: Awaited<ReturnType<typeof getMachines>>; machines: Awaited<ReturnType<typeof getMachines>>;
}) { }) {
return ( const [machine, setMachine] = useQueryState("machine", {
<Select defaultValue={machines[0].id}> defaultValue: machines[0].id ?? "",
<SelectTrigger className="w-[180px]"> });
<SelectValue placeholder="Select a version" /> return (
</SelectTrigger> <Select
<SelectContent> value={machine}
<SelectGroup> onValueChange={(v) => {
<SelectLabel>Versions</SelectLabel> setMachine(v);
{machines?.map((x) => ( }}
<SelectItem value={x.id ?? ""}> >
{x.name} <SelectTrigger className="w-[180px]">
</SelectItem> <SelectValue placeholder="Select a version" />
))} </SelectTrigger>
</SelectGroup> <SelectContent>
</SelectContent> <SelectGroup>
</Select> <SelectLabel>Versions</SelectLabel>
); {machines?.map((x) => (
} <SelectItem key={x.id} value={x.id ?? ""}>
{x.name}
</SelectItem>
))}
</SelectGroup>
</SelectContent>
</Select>
);
}
export function RunWorkflowButton({
workflow,
machines,
}: {
workflow: Awaited<ReturnType<typeof findFirstTableWithVersion>>;
machines: Awaited<ReturnType<typeof getMachines>>;
}) {
const [version] = useQueryState("version", {
defaultValue: workflow?.versions[0].version ?? 1,
...parseAsInteger,
});
const [machine] = useQueryState("machine", {
defaultValue: machines[0].id ?? "",
});
const [isLoading, setIsLoading] = useState(false);
return (
<Button
className="gap-2"
disabled={isLoading}
onClick={async () => {
setIsLoading(true);
try {
await fetch(`/api/create-run`, {
method: "POST",
body: JSON.stringify({
workflow_version_id: workflow?.versions.find(
(x) => x.version === version
)?.id,
machine_id: machine,
}),
});
setIsLoading(false);
} catch (error) {
setIsLoading(false);
}
}}
>
Run <Play size={14} /> {isLoading && <LoadingIcon />}
</Button>
);
}

View File

@ -17,21 +17,8 @@ export const usersTable = dbSchema.table("users", {
name: text("name").notNull(), name: text("name").notNull(),
created_at: timestamp("created_at").defaultNow(), created_at: timestamp("created_at").defaultNow(),
updated_at: timestamp("updated_at").defaultNow(), updated_at: timestamp("updated_at").defaultNow(),
// primary_avatar_id: uuid("primary_avatar_id").references(
// () => chatAvatarTable.id,
// ),
// twitter_initial_json: jsonb("twitter_initial_json").$type<
// Omit<UserV2Result, "errors">
// >(),
// initial_prompt: text("initial_prompt"),
// payment_status: text("payment_status"),
// early_access: boolean("early_access").default(false),
}); });
// export const usersRelations = relations(userTable, ({ many }) => ({
// chat_avatars: many(chatAvatarTable),
// }));
export const workflowTable = dbSchema.table("workflows", { export const workflowTable = dbSchema.table("workflows", {
id: uuid("id").primaryKey().defaultRandom().notNull(), id: uuid("id").primaryKey().defaultRandom().notNull(),
user_id: text("user_id") user_id: text("user_id")
@ -70,7 +57,7 @@ export const workflowVersionRelations = relations(
fields: [workflowVersionTable.workflow_id], fields: [workflowVersionTable.workflow_id],
references: [workflowTable.id], references: [workflowTable.id],
}), }),
}), })
); );
export const workflowRunStatus = pgEnum("workflow_run_status", [ export const workflowRunStatus = pgEnum("workflow_run_status", [
@ -98,45 +85,30 @@ export const workflowRunsTable = dbSchema.table("workflow_runs", {
created_at: timestamp("created_at").defaultNow().notNull(), created_at: timestamp("created_at").defaultNow().notNull(),
}); });
export const workflowRunRelations = relations(workflowRunsTable, ({ one }) => ({
machine: one(machinesTable, {
fields: [workflowRunsTable.machine_id],
references: [machinesTable.id],
}),
version: one(workflowVersionTable, {
fields: [workflowRunsTable.workflow_version_id],
references: [workflowVersionTable.id],
}),
}));
// when user delete, also delete all the workflow versions // when user delete, also delete all the workflow versions
export const machinesTable = dbSchema.table("machines", { export const machinesTable = dbSchema.table("machines", {
id: uuid("id").primaryKey().defaultRandom().notNull(), id: uuid("id").primaryKey().defaultRandom().notNull(),
user_id: text("user_id").references(() => usersTable.id, { user_id: text("user_id")
onDelete: "no action", .references(() => usersTable.id, {
}).notNull(), onDelete: "no action",
})
.notNull(),
name: text("name").notNull(), name: text("name").notNull(),
endpoint: text("endpoint").notNull(), endpoint: text("endpoint").notNull(),
created_at: timestamp("created_at").defaultNow().notNull(), created_at: timestamp("created_at").defaultNow().notNull(),
updated_at: timestamp("updated_at").defaultNow().notNull(), updated_at: timestamp("updated_at").defaultNow().notNull(),
}); });
// export const chatAvatarRelations = relations(chatAvatarTable, ({ one }) => ({
// author: one(userTable, {
// fields: [chatAvatarTable.user_id],
// references: [userTable.id],
// }),
// }));
// export const subscriptionTable = dbSchema.table("subscription", {
// id: text("id").primaryKey().notNull(),
// email: text("email"),
// user_id: text("user_id"),
// status: text("status"),
// created_at: timestamp("created_at").defaultNow(),
// updated_at: timestamp("updated_at").defaultNow(),
// });
// export const subscriptionRelations = relations(
// subscriptionTable,
// ({ one }) => ({
// user: one(userTable, {
// fields: [subscriptionTable.user_id],
// references: [userTable.id],
// }),
// }),
// );
export type UserType = InferSelectModel<typeof usersTable>; export type UserType = InferSelectModel<typeof usersTable>;
export type WorkflowType = InferSelectModel<typeof workflowTable>; export type WorkflowType = InferSelectModel<typeof workflowTable>;
// export type ChatAvatarType = InferSelectModel<typeof chatAvatarTable>;
// export type SubscriptionType = InferSelectModel<typeof subscriptionTable>;

View File

@ -1,10 +1,11 @@
import { NextResponse } from "next/server"; import { NextResponse } from "next/server";
import { ZodError, ZodType, z } from "zod"; import type { ZodType, z } from "zod";
import { ZodError } from "zod";
export async function parseDataSafe<T extends ZodType<any, any, any>>( export async function parseDataSafe<T extends ZodType<any, any, any>>(
schema: T, schema: T,
request: Request, request: Request,
headers?: HeadersInit, headers?: HeadersInit
): Promise<[z.infer<T> | undefined, NextResponse | undefined]> { ): Promise<[z.infer<T> | undefined, NextResponse | undefined]> {
let data: z.infer<T> | undefined = undefined; let data: z.infer<T> | undefined = undefined;
try { try {
@ -30,7 +31,7 @@ export async function parseDataSafe<T extends ZodType<any, any, any>>(
{ {
message: "Invalid request", message: "Invalid request",
}, },
{ status: 500, statusText: "Invalid request", headers: headers }, { status: 500, statusText: "Invalid request", headers: headers }
), ),
]; ];