feat: display runs and create run and update run endpoint
This commit is contained in:
parent
6a1c1d0ff5
commit
9c8a518c46
112
routes.py
112
routes.py
@ -15,13 +15,14 @@ import execution
|
||||
import random
|
||||
|
||||
import uuid
|
||||
import websockets
|
||||
import asyncio
|
||||
import atexit
|
||||
import logging
|
||||
from enum import Enum
|
||||
|
||||
api = None
|
||||
api_task = None
|
||||
prompt_metadata = {}
|
||||
|
||||
load_dotenv()
|
||||
|
||||
@ -73,20 +74,15 @@ def randomSeed(num_digits=15):
|
||||
|
||||
@server.PromptServer.instance.routes.post("/comfy-deploy/run")
|
||||
async def comfy_deploy_run(request):
|
||||
print("hi")
|
||||
prompt_server = server.PromptServer.instance
|
||||
data = await request.json()
|
||||
|
||||
for key in data:
|
||||
if 'inputs' in data[key] and 'seed' in data[key]['inputs']:
|
||||
data[key]['inputs']['seed'] = randomSeed()
|
||||
|
||||
if api is None:
|
||||
connect_to_websocket()
|
||||
while api.client_id is None:
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
workflow_api = data.get("workflow_api")
|
||||
# print(workflow_api)
|
||||
|
||||
for key in workflow_api:
|
||||
if 'inputs' in workflow_api[key] and 'seed' in workflow_api[key]['inputs']:
|
||||
workflow_api[key]['inputs']['seed'] = randomSeed()
|
||||
|
||||
prompt = {
|
||||
"prompt": workflow_api,
|
||||
@ -95,7 +91,9 @@ async def comfy_deploy_run(request):
|
||||
|
||||
res = post_prompt(prompt)
|
||||
|
||||
# print(prompt)
|
||||
prompt_metadata[res['prompt_id']] = {
|
||||
'status_endpoint': data.get('status_endpoint'),
|
||||
}
|
||||
|
||||
status = 200
|
||||
if "error" in res:
|
||||
@ -105,69 +103,41 @@ async def comfy_deploy_run(request):
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
class ComfyApi:
|
||||
def __init__(self):
|
||||
self.websocket = None
|
||||
self.client_id = None
|
||||
|
||||
async def connect(self, uri):
|
||||
self.websocket = await websockets.connect(uri)
|
||||
|
||||
# Event listeners
|
||||
await self.on_open()
|
||||
await self.on_message()
|
||||
await self.on_close()
|
||||
|
||||
async def close(self):
|
||||
await self.websocket.close()
|
||||
|
||||
async def on_open(self):
|
||||
print("Connection opened")
|
||||
|
||||
async def on_message(self):
|
||||
async for message in self.websocket:
|
||||
if isinstance(message, bytes):
|
||||
print("Received binary message, skipping...")
|
||||
continue # skip to the next message
|
||||
logging.info(f"Received message: {message}")
|
||||
|
||||
try:
|
||||
message_data = json.loads(message)
|
||||
|
||||
msg_type = message_data["type"]
|
||||
|
||||
if msg_type == "status" and message_data["data"]["sid"] is not None:
|
||||
self.client_id = message_data["data"]["sid"]
|
||||
logging.info(f"Received client_id: {self.client_id}")
|
||||
|
||||
except json.JSONDecodeError:
|
||||
logging.info(f"Failed to parse message as JSON: {message}")
|
||||
|
||||
async def on_close(self):
|
||||
print("Connection closed")
|
||||
|
||||
async def run(self, uri):
|
||||
await self.connect(uri)
|
||||
|
||||
def connect_to_websocket():
|
||||
global api, api_task
|
||||
api = ComfyApi()
|
||||
api_task = asyncio.create_task(api.run('ws://localhost:8188/ws'))
|
||||
|
||||
prompt_server = server.PromptServer.instance
|
||||
|
||||
send_json = prompt_server.send_json
|
||||
async def send_json_override(self, event, data, sid=None):
|
||||
print("INTERNAL:", event, data, sid)
|
||||
|
||||
prompt_id = data.get('prompt_id')
|
||||
|
||||
if event == 'execution_start':
|
||||
update_run(prompt_id, Status.RUNNING)
|
||||
|
||||
# if event == 'executing':
|
||||
# update_run(prompt_id, Status.RUNNING)
|
||||
|
||||
if event == 'executed':
|
||||
update_run(prompt_id, Status.SUCCESS)
|
||||
|
||||
await self.send_json_original(event, data, sid)
|
||||
print("Sending event:", sid, event, data)
|
||||
|
||||
|
||||
class Status(Enum):
|
||||
NOT_STARTED = "not-started"
|
||||
RUNNING = "running"
|
||||
SUCCESS = "success"
|
||||
FAILED = "failed"
|
||||
|
||||
def update_run(prompt_id, status: Status):
|
||||
if prompt_id in prompt_metadata and ('status' not in prompt_metadata[prompt_id] or prompt_metadata[prompt_id]['status'] != status):
|
||||
status_endpoint = prompt_metadata[prompt_id]['status_endpoint']
|
||||
body = {
|
||||
"run_id": prompt_id,
|
||||
"status": status.value,
|
||||
}
|
||||
prompt_metadata[prompt_id]['status'] = status
|
||||
requests.post(status_endpoint, json=body)
|
||||
|
||||
prompt_server.send_json_original = prompt_server.send_json
|
||||
prompt_server.send_json = send_json_override.__get__(prompt_server, server.PromptServer)
|
||||
|
||||
@atexit.register
|
||||
def close_websocket():
|
||||
print("Got close_websocket")
|
||||
|
||||
global api, api_task
|
||||
if api_task:
|
||||
api_task.cancel()
|
||||
prompt_server.send_json = send_json_override.__get__(prompt_server, server.PromptServer)
|
@ -1,4 +1,8 @@
|
||||
/** @type {import('next').NextConfig} */
|
||||
const nextConfig = {}
|
||||
const nextConfig = {
|
||||
eslint: {
|
||||
ignoreDuringBuilds: true,
|
||||
},
|
||||
};
|
||||
|
||||
module.exports = nextConfig
|
||||
module.exports = nextConfig;
|
||||
|
@ -1,5 +1,8 @@
|
||||
import { MachineSelect, VersionSelect } from "@/components/VersionSelect";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import {
|
||||
MachineSelect,
|
||||
RunWorkflowButton,
|
||||
VersionSelect,
|
||||
} from "@/components/VersionSelect";
|
||||
import {
|
||||
Card,
|
||||
CardContent,
|
||||
@ -7,12 +10,24 @@ import {
|
||||
CardHeader,
|
||||
CardTitle,
|
||||
} from "@/components/ui/card";
|
||||
import {
|
||||
Table,
|
||||
TableBody,
|
||||
TableCaption,
|
||||
TableCell,
|
||||
TableHead,
|
||||
TableHeader,
|
||||
TableRow,
|
||||
} from "@/components/ui/table";
|
||||
import { db } from "@/db/db";
|
||||
import { workflowTable, workflowVersionTable } from "@/db/schema";
|
||||
import {
|
||||
workflowRunsTable,
|
||||
workflowTable,
|
||||
workflowVersionTable,
|
||||
} from "@/db/schema";
|
||||
import { getRelativeTime } from "@/lib/getRelativeTime";
|
||||
import { getMachines } from "@/server/curdMachine";
|
||||
import { desc, eq } from "drizzle-orm";
|
||||
import { Play } from "lucide-react";
|
||||
|
||||
export async function findFirstTableWithVersion(workflow_id: string) {
|
||||
return await db.query.workflowTable.findFirst({
|
||||
@ -21,6 +36,32 @@ export async function findFirstTableWithVersion(workflow_id: string) {
|
||||
});
|
||||
}
|
||||
|
||||
export async function findAllRuns(workflow_id: string) {
|
||||
const workflowVersion = await db.query.workflowVersionTable.findFirst({
|
||||
where: eq(workflowVersionTable.workflow_id, workflow_id),
|
||||
});
|
||||
|
||||
if (!workflowVersion) {
|
||||
return [];
|
||||
}
|
||||
|
||||
return await db.query.workflowRunsTable.findMany({
|
||||
where: eq(workflowRunsTable.workflow_version_id, workflowVersion?.id),
|
||||
with: {
|
||||
machine: {
|
||||
columns: {
|
||||
name: true,
|
||||
},
|
||||
},
|
||||
version: {
|
||||
columns: {
|
||||
version: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
export default async function Page({
|
||||
params,
|
||||
}: {
|
||||
@ -45,9 +86,7 @@ export default async function Page({
|
||||
<div className="flex gap-2 ">
|
||||
<VersionSelect workflow={workflow} />
|
||||
<MachineSelect machines={machines} />
|
||||
<Button className="gap-2">
|
||||
Run <Play size={14} />
|
||||
</Button>
|
||||
<RunWorkflowButton workflow={workflow} machines={machines} />
|
||||
</div>
|
||||
</CardContent>
|
||||
</Card>
|
||||
@ -57,8 +96,37 @@ export default async function Page({
|
||||
<CardTitle>Run</CardTitle>
|
||||
</CardHeader>
|
||||
|
||||
<CardContent />
|
||||
<CardContent>
|
||||
<RunsTable workflow_id={workflow_id} />
|
||||
</CardContent>
|
||||
</Card>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
async function RunsTable(props: { workflow_id: string }) {
|
||||
const allRuns = await findAllRuns(props.workflow_id);
|
||||
return (
|
||||
<Table>
|
||||
<TableCaption>A list of your recent runs.</TableCaption>
|
||||
<TableHeader>
|
||||
<TableRow>
|
||||
<TableHead className="w-[100px]">Version</TableHead>
|
||||
<TableHead>Machine</TableHead>
|
||||
<TableHead>Time</TableHead>
|
||||
<TableHead className="text-right">Status</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{allRuns.map((run) => (
|
||||
<TableRow key={run.id}>
|
||||
<TableCell>{run.version.version}</TableCell>
|
||||
<TableCell className="font-medium">{run.machine.name}</TableCell>
|
||||
<TableCell>{getRelativeTime(run.created_at)}</TableCell>
|
||||
<TableCell className="text-right">{run.status}</TableCell>
|
||||
</TableRow>
|
||||
))}
|
||||
</TableBody>
|
||||
</Table>
|
||||
);
|
||||
}
|
||||
|
@ -1,56 +1,106 @@
|
||||
import { parseDataSafe } from "../../../lib/parseDataSafe";
|
||||
import { db } from "@/db/db";
|
||||
import {
|
||||
workflowRunStatus,
|
||||
workflowRunsTable,
|
||||
workflowTable,
|
||||
workflowVersionTable,
|
||||
} from "@/db/schema";
|
||||
import { eq, sql } from "drizzle-orm";
|
||||
import { workflowRunsTable } from "@/db/schema";
|
||||
import { eq } from "drizzle-orm";
|
||||
import { revalidatePath } from "next/cache";
|
||||
import { NextResponse } from "next/server";
|
||||
import { ZodFormattedError, z } from "zod";
|
||||
import { z } from "zod";
|
||||
|
||||
const Request = z.object({
|
||||
workflow_version_id: z.string(),
|
||||
// workflow_version: z.number().optional(),
|
||||
machine_id: z.string(),
|
||||
});
|
||||
|
||||
export async function OPTIONS(request: Request) {
|
||||
return new Response(null, {
|
||||
status: 204,
|
||||
headers: {
|
||||
"Access-Control-Allow-Origin": "*",
|
||||
"Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
|
||||
"Access-Control-Allow-Headers": "Content-Type, Authorization",
|
||||
},
|
||||
});
|
||||
}
|
||||
const ComfyAPI_Run = z.object({
|
||||
prompt_id: z.string(),
|
||||
number: z.number(),
|
||||
node_errors: z.any(),
|
||||
});
|
||||
|
||||
export async function POST(request: Request) {
|
||||
const [data, error] = await parseDataSafe(Request, request);
|
||||
if (!data || error) return error;
|
||||
|
||||
let { workflow_version_id, machine_id } = data;
|
||||
const origin = new URL(request.url).origin;
|
||||
|
||||
const { workflow_version_id, machine_id } = data;
|
||||
|
||||
const machine = await db.query.machinesTable.findFirst({
|
||||
where: eq(workflowRunsTable.id, machine_id),
|
||||
});
|
||||
|
||||
if (!machine) {
|
||||
return new Response("Machine not found", {
|
||||
status: 404,
|
||||
});
|
||||
}
|
||||
|
||||
const workflow_version_data =
|
||||
// workflow_version_id
|
||||
// ?
|
||||
await db.query.workflowVersionTable.findFirst({
|
||||
where: eq(workflowRunsTable.id, workflow_version_id),
|
||||
});
|
||||
// : workflow_version != undefined
|
||||
// ? await db.query.workflowVersionTable.findFirst({
|
||||
// where: and(
|
||||
// eq(workflowVersionTable.version, workflow_version),
|
||||
// eq(workflowVersionTable.workflow_id)
|
||||
// ),
|
||||
// })
|
||||
// : null;
|
||||
|
||||
if (!workflow_version_data) {
|
||||
return new Response("Workflow version not found", {
|
||||
status: 404,
|
||||
});
|
||||
}
|
||||
|
||||
const comfyui_endpoint = `${machine.endpoint}/comfy-deploy/run`;
|
||||
|
||||
// Sending to comfyui
|
||||
const result = await fetch(comfyui_endpoint, {
|
||||
method: "POST",
|
||||
// headers: {
|
||||
// "Content-Type": "application/json",
|
||||
// },
|
||||
body: JSON.stringify({
|
||||
workflow_api: workflow_version_data.workflow_api,
|
||||
status_endpoint: `${origin}/api/update-run`,
|
||||
}),
|
||||
})
|
||||
.then(async (res) => ComfyAPI_Run.parseAsync(await res.json()))
|
||||
.catch((error) => {
|
||||
console.error(error);
|
||||
return new Response(error.details, {
|
||||
status: 500,
|
||||
});
|
||||
});
|
||||
|
||||
// return the error
|
||||
if (result instanceof Response) {
|
||||
return result;
|
||||
}
|
||||
|
||||
// Add to our db
|
||||
const workflow_run = await db
|
||||
.insert(workflowRunsTable)
|
||||
.values({
|
||||
workflow_version_id,
|
||||
id: result.prompt_id,
|
||||
workflow_version_id: workflow_version_data.id,
|
||||
machine_id,
|
||||
})
|
||||
.returning();
|
||||
|
||||
revalidatePath(`./${workflow_version_data.workflow_id}`);
|
||||
|
||||
return NextResponse.json(
|
||||
{
|
||||
workflow_run_id: workflow_run[0].id,
|
||||
},
|
||||
{
|
||||
status: 200,
|
||||
headers: {
|
||||
"Access-Control-Allow-Origin": "*",
|
||||
"Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
|
||||
"Access-Control-Allow-Headers": "Content-Type, Authorization",
|
||||
},
|
||||
},
|
||||
}
|
||||
);
|
||||
}
|
||||
|
@ -2,6 +2,7 @@ import { parseDataSafe } from "../../../lib/parseDataSafe";
|
||||
import { db } from "@/db/db";
|
||||
import { workflowRunsTable } from "@/db/schema";
|
||||
import { eq } from "drizzle-orm";
|
||||
import { revalidatePath } from "next/cache";
|
||||
import { NextResponse } from "next/server";
|
||||
import { z } from "zod";
|
||||
|
||||
@ -10,17 +11,6 @@ const Request = z.object({
|
||||
status: z.enum(["not-started", "running", "success", "failed"]),
|
||||
});
|
||||
|
||||
export async function OPTIONS(request: Request) {
|
||||
return new Response(null, {
|
||||
status: 204,
|
||||
headers: {
|
||||
"Access-Control-Allow-Origin": "*",
|
||||
"Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
|
||||
"Access-Control-Allow-Headers": "Content-Type, Authorization",
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
export async function POST(request: Request) {
|
||||
const [data, error] = await parseDataSafe(Request, request);
|
||||
if (!data || error) return error;
|
||||
@ -32,7 +22,14 @@ export async function POST(request: Request) {
|
||||
.set({
|
||||
status: status,
|
||||
})
|
||||
.where(eq(workflowRunsTable.id, run_id));
|
||||
.where(eq(workflowRunsTable.id, run_id))
|
||||
.returning();
|
||||
|
||||
const workflow_version = await db.query.workflowVersionTable.findFirst({
|
||||
where: eq(workflowRunsTable.id, workflow_run[0].workflow_version_id),
|
||||
});
|
||||
|
||||
revalidatePath(`./${workflow_version?.workflow_id}`);
|
||||
|
||||
return NextResponse.json(
|
||||
{
|
||||
@ -40,11 +37,6 @@ export async function POST(request: Request) {
|
||||
},
|
||||
{
|
||||
status: 200,
|
||||
headers: {
|
||||
"Access-Control-Allow-Origin": "*",
|
||||
"Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
|
||||
"Access-Control-Allow-Headers": "Content-Type, Authorization",
|
||||
},
|
||||
}
|
||||
);
|
||||
}
|
||||
|
@ -2,7 +2,7 @@ import { WorkflowList } from "@/components/WorkflowList";
|
||||
import { db } from "@/db/db";
|
||||
import { usersTable, workflowTable, workflowVersionTable } from "@/db/schema";
|
||||
import { auth, clerkClient } from "@clerk/nextjs";
|
||||
import { desc, eq, sql } from "drizzle-orm";
|
||||
import { desc, eq } from "drizzle-orm";
|
||||
|
||||
export default function Home() {
|
||||
return <WorkflowServer />;
|
||||
|
7
web/src/components/LoadingIcon.tsx
Normal file
7
web/src/components/LoadingIcon.tsx
Normal file
@ -0,0 +1,7 @@
|
||||
"use client";
|
||||
import { LoaderIcon } from "lucide-react";
|
||||
import * as React from "react";
|
||||
|
||||
export function LoadingIcon() {
|
||||
return <LoaderIcon size={14} className="ml-2 animate-spin" />;
|
||||
}
|
@ -1,9 +1,9 @@
|
||||
"use client";
|
||||
|
||||
import { getRelativeTime } from "../lib/getRelativeTime";
|
||||
import { LoadingIcon } from "./LoadingIcon";
|
||||
import {
|
||||
FormControl,
|
||||
FormDescription,
|
||||
FormField,
|
||||
FormItem,
|
||||
FormLabel,
|
||||
@ -23,15 +23,12 @@ import {
|
||||
} from "@/components/ui/dialog";
|
||||
import {
|
||||
DropdownMenu,
|
||||
DropdownMenuCheckboxItem,
|
||||
DropdownMenuContent,
|
||||
DropdownMenuItem,
|
||||
DropdownMenuLabel,
|
||||
DropdownMenuSeparator,
|
||||
DropdownMenuTrigger,
|
||||
} from "@/components/ui/dropdown-menu";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { Label } from "@/components/ui/label";
|
||||
import {
|
||||
Table,
|
||||
TableBody,
|
||||
@ -41,7 +38,6 @@ import {
|
||||
TableRow,
|
||||
} from "@/components/ui/table";
|
||||
import { addMachine, deleteMachine } from "@/server/curdMachine";
|
||||
import { deleteWorkflow } from "@/server/deleteWorkflow";
|
||||
import { zodResolver } from "@hookform/resolvers/zod";
|
||||
import type {
|
||||
ColumnDef,
|
||||
@ -57,14 +53,8 @@ import {
|
||||
getSortedRowModel,
|
||||
useReactTable,
|
||||
} from "@tanstack/react-table";
|
||||
import {
|
||||
ArrowUpDown,
|
||||
ChevronDown,
|
||||
LoaderIcon,
|
||||
MoreHorizontal,
|
||||
} from "lucide-react";
|
||||
import { ArrowUpDown, MoreHorizontal } from "lucide-react";
|
||||
import * as React from "react";
|
||||
import { useFormStatus } from "react-dom";
|
||||
import { useForm } from "react-hook-form";
|
||||
import { z } from "zod";
|
||||
|
||||
@ -145,7 +135,9 @@ export const columns: ColumnDef<Machine>[] = [
|
||||
);
|
||||
},
|
||||
cell: ({ row }) => (
|
||||
<div className="capitalize text-right">{getRelativeTime(row.original.date)}</div>
|
||||
<div className="capitalize text-right">
|
||||
{getRelativeTime(row.original.date)}
|
||||
</div>
|
||||
),
|
||||
},
|
||||
|
||||
@ -187,7 +179,7 @@ export const columns: ColumnDef<Machine>[] = [
|
||||
export function MachineList({ data }: { data: Machine[] }) {
|
||||
const [sorting, setSorting] = React.useState<SortingState>([]);
|
||||
const [columnFilters, setColumnFilters] = React.useState<ColumnFiltersState>(
|
||||
[],
|
||||
[]
|
||||
);
|
||||
const [columnVisibility, setColumnVisibility] =
|
||||
React.useState<VisibilityState>({});
|
||||
@ -265,7 +257,7 @@ export function MachineList({ data }: { data: Machine[] }) {
|
||||
? null
|
||||
: flexRender(
|
||||
header.column.columnDef.header,
|
||||
header.getContext(),
|
||||
header.getContext()
|
||||
)}
|
||||
</TableHead>
|
||||
);
|
||||
@ -284,7 +276,7 @@ export function MachineList({ data }: { data: Machine[] }) {
|
||||
<TableCell key={cell.id}>
|
||||
{flexRender(
|
||||
cell.column.columnDef.cell,
|
||||
cell.getContext(),
|
||||
cell.getContext()
|
||||
)}
|
||||
</TableCell>
|
||||
))}
|
||||
@ -418,8 +410,7 @@ function AddWorkflowButton({ pending }: { pending: boolean }) {
|
||||
// const { pending } = useFormStatus();
|
||||
return (
|
||||
<Button type="submit" disabled={pending}>
|
||||
Save changes{" "}
|
||||
{pending && <LoaderIcon size={14} className="ml-2 animate-spin" />}
|
||||
Save changes {pending && <LoadingIcon />}
|
||||
</Button>
|
||||
);
|
||||
}
|
||||
|
@ -1,6 +1,8 @@
|
||||
"use client";
|
||||
|
||||
import { findFirstTableWithVersion } from "@/app/[workflow_id]/page";
|
||||
import type { findFirstTableWithVersion } from "@/app/[workflow_id]/page";
|
||||
import { LoadingIcon } from "@/components/LoadingIcon";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import {
|
||||
Select,
|
||||
SelectContent,
|
||||
@ -10,15 +12,26 @@ import {
|
||||
SelectTrigger,
|
||||
SelectValue,
|
||||
} from "@/components/ui/select";
|
||||
import { getMachines } from "@/server/curdMachine";
|
||||
import type { getMachines } from "@/server/curdMachine";
|
||||
import { Play } from "lucide-react";
|
||||
import { parseAsInteger, useQueryState } from "next-usequerystate";
|
||||
import { useState } from "react";
|
||||
|
||||
export function VersionSelect({
|
||||
workflow,
|
||||
}: {
|
||||
workflow: Awaited<ReturnType<typeof findFirstTableWithVersion>>;
|
||||
}) {
|
||||
const [version, setVersion] = useQueryState("version", {
|
||||
defaultValue: workflow?.versions[0].version?.toString() ?? "",
|
||||
});
|
||||
return (
|
||||
<Select defaultValue={workflow?.versions[0].version?.toString()}>
|
||||
<Select
|
||||
value={version}
|
||||
onValueChange={(v) => {
|
||||
setVersion(v);
|
||||
}}
|
||||
>
|
||||
<SelectTrigger className="w-[180px]">
|
||||
<SelectValue placeholder="Select a version" />
|
||||
</SelectTrigger>
|
||||
@ -26,7 +39,7 @@ export function VersionSelect({
|
||||
<SelectGroup>
|
||||
<SelectLabel>Versions</SelectLabel>
|
||||
{workflow?.versions.map((x) => (
|
||||
<SelectItem value={x.version?.toString() ?? ""}>
|
||||
<SelectItem key={x.id} value={x.version?.toString() ?? ""}>
|
||||
{x.version}
|
||||
</SelectItem>
|
||||
))}
|
||||
@ -36,28 +49,76 @@ export function VersionSelect({
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
export function MachineSelect({
|
||||
machines,
|
||||
}: {
|
||||
machines: Awaited<ReturnType<typeof getMachines>>;
|
||||
}) {
|
||||
return (
|
||||
<Select defaultValue={machines[0].id}>
|
||||
<SelectTrigger className="w-[180px]">
|
||||
<SelectValue placeholder="Select a version" />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
<SelectGroup>
|
||||
<SelectLabel>Versions</SelectLabel>
|
||||
{machines?.map((x) => (
|
||||
<SelectItem value={x.id ?? ""}>
|
||||
{x.name}
|
||||
</SelectItem>
|
||||
))}
|
||||
</SelectGroup>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
);
|
||||
}
|
||||
|
||||
machines,
|
||||
}: {
|
||||
machines: Awaited<ReturnType<typeof getMachines>>;
|
||||
}) {
|
||||
const [machine, setMachine] = useQueryState("machine", {
|
||||
defaultValue: machines[0].id ?? "",
|
||||
});
|
||||
return (
|
||||
<Select
|
||||
value={machine}
|
||||
onValueChange={(v) => {
|
||||
setMachine(v);
|
||||
}}
|
||||
>
|
||||
<SelectTrigger className="w-[180px]">
|
||||
<SelectValue placeholder="Select a version" />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
<SelectGroup>
|
||||
<SelectLabel>Versions</SelectLabel>
|
||||
{machines?.map((x) => (
|
||||
<SelectItem key={x.id} value={x.id ?? ""}>
|
||||
{x.name}
|
||||
</SelectItem>
|
||||
))}
|
||||
</SelectGroup>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
);
|
||||
}
|
||||
|
||||
export function RunWorkflowButton({
|
||||
workflow,
|
||||
machines,
|
||||
}: {
|
||||
workflow: Awaited<ReturnType<typeof findFirstTableWithVersion>>;
|
||||
machines: Awaited<ReturnType<typeof getMachines>>;
|
||||
}) {
|
||||
const [version] = useQueryState("version", {
|
||||
defaultValue: workflow?.versions[0].version ?? 1,
|
||||
...parseAsInteger,
|
||||
});
|
||||
const [machine] = useQueryState("machine", {
|
||||
defaultValue: machines[0].id ?? "",
|
||||
});
|
||||
const [isLoading, setIsLoading] = useState(false);
|
||||
return (
|
||||
<Button
|
||||
className="gap-2"
|
||||
disabled={isLoading}
|
||||
onClick={async () => {
|
||||
setIsLoading(true);
|
||||
try {
|
||||
await fetch(`/api/create-run`, {
|
||||
method: "POST",
|
||||
body: JSON.stringify({
|
||||
workflow_version_id: workflow?.versions.find(
|
||||
(x) => x.version === version
|
||||
)?.id,
|
||||
machine_id: machine,
|
||||
}),
|
||||
});
|
||||
setIsLoading(false);
|
||||
} catch (error) {
|
||||
setIsLoading(false);
|
||||
}
|
||||
}}
|
||||
>
|
||||
Run <Play size={14} /> {isLoading && <LoadingIcon />}
|
||||
</Button>
|
||||
);
|
||||
}
|
||||
|
@ -17,21 +17,8 @@ export const usersTable = dbSchema.table("users", {
|
||||
name: text("name").notNull(),
|
||||
created_at: timestamp("created_at").defaultNow(),
|
||||
updated_at: timestamp("updated_at").defaultNow(),
|
||||
// primary_avatar_id: uuid("primary_avatar_id").references(
|
||||
// () => chatAvatarTable.id,
|
||||
// ),
|
||||
// twitter_initial_json: jsonb("twitter_initial_json").$type<
|
||||
// Omit<UserV2Result, "errors">
|
||||
// >(),
|
||||
// initial_prompt: text("initial_prompt"),
|
||||
// payment_status: text("payment_status"),
|
||||
// early_access: boolean("early_access").default(false),
|
||||
});
|
||||
|
||||
// export const usersRelations = relations(userTable, ({ many }) => ({
|
||||
// chat_avatars: many(chatAvatarTable),
|
||||
// }));
|
||||
|
||||
export const workflowTable = dbSchema.table("workflows", {
|
||||
id: uuid("id").primaryKey().defaultRandom().notNull(),
|
||||
user_id: text("user_id")
|
||||
@ -70,7 +57,7 @@ export const workflowVersionRelations = relations(
|
||||
fields: [workflowVersionTable.workflow_id],
|
||||
references: [workflowTable.id],
|
||||
}),
|
||||
}),
|
||||
})
|
||||
);
|
||||
|
||||
export const workflowRunStatus = pgEnum("workflow_run_status", [
|
||||
@ -98,45 +85,30 @@ export const workflowRunsTable = dbSchema.table("workflow_runs", {
|
||||
created_at: timestamp("created_at").defaultNow().notNull(),
|
||||
});
|
||||
|
||||
export const workflowRunRelations = relations(workflowRunsTable, ({ one }) => ({
|
||||
machine: one(machinesTable, {
|
||||
fields: [workflowRunsTable.machine_id],
|
||||
references: [machinesTable.id],
|
||||
}),
|
||||
version: one(workflowVersionTable, {
|
||||
fields: [workflowRunsTable.workflow_version_id],
|
||||
references: [workflowVersionTable.id],
|
||||
}),
|
||||
}));
|
||||
|
||||
// when user delete, also delete all the workflow versions
|
||||
export const machinesTable = dbSchema.table("machines", {
|
||||
id: uuid("id").primaryKey().defaultRandom().notNull(),
|
||||
user_id: text("user_id").references(() => usersTable.id, {
|
||||
onDelete: "no action",
|
||||
}).notNull(),
|
||||
user_id: text("user_id")
|
||||
.references(() => usersTable.id, {
|
||||
onDelete: "no action",
|
||||
})
|
||||
.notNull(),
|
||||
name: text("name").notNull(),
|
||||
endpoint: text("endpoint").notNull(),
|
||||
created_at: timestamp("created_at").defaultNow().notNull(),
|
||||
updated_at: timestamp("updated_at").defaultNow().notNull(),
|
||||
});
|
||||
|
||||
// export const chatAvatarRelations = relations(chatAvatarTable, ({ one }) => ({
|
||||
// author: one(userTable, {
|
||||
// fields: [chatAvatarTable.user_id],
|
||||
// references: [userTable.id],
|
||||
// }),
|
||||
// }));
|
||||
|
||||
// export const subscriptionTable = dbSchema.table("subscription", {
|
||||
// id: text("id").primaryKey().notNull(),
|
||||
// email: text("email"),
|
||||
// user_id: text("user_id"),
|
||||
// status: text("status"),
|
||||
// created_at: timestamp("created_at").defaultNow(),
|
||||
// updated_at: timestamp("updated_at").defaultNow(),
|
||||
// });
|
||||
|
||||
// export const subscriptionRelations = relations(
|
||||
// subscriptionTable,
|
||||
// ({ one }) => ({
|
||||
// user: one(userTable, {
|
||||
// fields: [subscriptionTable.user_id],
|
||||
// references: [userTable.id],
|
||||
// }),
|
||||
// }),
|
||||
// );
|
||||
|
||||
export type UserType = InferSelectModel<typeof usersTable>;
|
||||
export type WorkflowType = InferSelectModel<typeof workflowTable>;
|
||||
// export type ChatAvatarType = InferSelectModel<typeof chatAvatarTable>;
|
||||
// export type SubscriptionType = InferSelectModel<typeof subscriptionTable>;
|
||||
|
@ -1,10 +1,11 @@
|
||||
import { NextResponse } from "next/server";
|
||||
import { ZodError, ZodType, z } from "zod";
|
||||
import type { ZodType, z } from "zod";
|
||||
import { ZodError } from "zod";
|
||||
|
||||
export async function parseDataSafe<T extends ZodType<any, any, any>>(
|
||||
schema: T,
|
||||
request: Request,
|
||||
headers?: HeadersInit,
|
||||
headers?: HeadersInit
|
||||
): Promise<[z.infer<T> | undefined, NextResponse | undefined]> {
|
||||
let data: z.infer<T> | undefined = undefined;
|
||||
try {
|
||||
@ -30,7 +31,7 @@ export async function parseDataSafe<T extends ZodType<any, any, any>>(
|
||||
{
|
||||
message: "Invalid request",
|
||||
},
|
||||
{ status: 500, statusText: "Invalid request", headers: headers },
|
||||
{ status: 500, statusText: "Invalid request", headers: headers }
|
||||
),
|
||||
];
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user