From 9c8a518c465c26bb1acba8400b8eba42bacc1b16 Mon Sep 17 00:00:00 2001 From: BennyKok Date: Sun, 10 Dec 2023 16:06:20 +0800 Subject: [PATCH] feat: display runs and create run and update run endpoint --- routes.py | 112 ++++++++++--------------- web/next.config.js | 8 +- web/src/app/[workflow_id]/page.tsx | 84 +++++++++++++++++-- web/src/app/api/create-run/route.ts | 102 +++++++++++++++++------ web/src/app/api/update-run/route.ts | 26 +++--- web/src/app/page.tsx | 2 +- web/src/components/LoadingIcon.tsx | 7 ++ web/src/components/MachineList.tsx | 27 +++---- web/src/components/VersionSelect.tsx | 117 ++++++++++++++++++++------- web/src/db/schema.ts | 62 ++++---------- web/src/lib/parseDataSafe.ts | 7 +- 11 files changed, 335 insertions(+), 219 deletions(-) create mode 100644 web/src/components/LoadingIcon.tsx diff --git a/routes.py b/routes.py index c68cb0c..5e6fe9d 100644 --- a/routes.py +++ b/routes.py @@ -15,13 +15,14 @@ import execution import random import uuid -import websockets import asyncio import atexit import logging +from enum import Enum api = None api_task = None +prompt_metadata = {} load_dotenv() @@ -73,20 +74,15 @@ def randomSeed(num_digits=15): @server.PromptServer.instance.routes.post("/comfy-deploy/run") async def comfy_deploy_run(request): + print("hi") prompt_server = server.PromptServer.instance data = await request.json() - for key in data: - if 'inputs' in data[key] and 'seed' in data[key]['inputs']: - data[key]['inputs']['seed'] = randomSeed() - - if api is None: - connect_to_websocket() - while api.client_id is None: - await asyncio.sleep(0.1) - workflow_api = data.get("workflow_api") - # print(workflow_api) + + for key in workflow_api: + if 'inputs' in workflow_api[key] and 'seed' in workflow_api[key]['inputs']: + workflow_api[key]['inputs']['seed'] = randomSeed() prompt = { "prompt": workflow_api, @@ -95,7 +91,9 @@ async def comfy_deploy_run(request): res = post_prompt(prompt) - # print(prompt) + prompt_metadata[res['prompt_id']] = { + 'status_endpoint': data.get('status_endpoint'), + } status = 200 if "error" in res: @@ -105,69 +103,41 @@ async def comfy_deploy_run(request): logging.basicConfig(level=logging.INFO) -class ComfyApi: - def __init__(self): - self.websocket = None - self.client_id = None - - async def connect(self, uri): - self.websocket = await websockets.connect(uri) - - # Event listeners - await self.on_open() - await self.on_message() - await self.on_close() - - async def close(self): - await self.websocket.close() - - async def on_open(self): - print("Connection opened") - - async def on_message(self): - async for message in self.websocket: - if isinstance(message, bytes): - print("Received binary message, skipping...") - continue # skip to the next message - logging.info(f"Received message: {message}") - - try: - message_data = json.loads(message) - - msg_type = message_data["type"] - - if msg_type == "status" and message_data["data"]["sid"] is not None: - self.client_id = message_data["data"]["sid"] - logging.info(f"Received client_id: {self.client_id}") - - except json.JSONDecodeError: - logging.info(f"Failed to parse message as JSON: {message}") - - async def on_close(self): - print("Connection closed") - - async def run(self, uri): - await self.connect(uri) - -def connect_to_websocket(): - global api, api_task - api = ComfyApi() - api_task = asyncio.create_task(api.run('ws://localhost:8188/ws')) - prompt_server = server.PromptServer.instance send_json = prompt_server.send_json async def send_json_override(self, event, data, sid=None): + print("INTERNAL:", event, data, sid) + + prompt_id = data.get('prompt_id') + + if event == 'execution_start': + update_run(prompt_id, Status.RUNNING) + + # if event == 'executing': + # update_run(prompt_id, Status.RUNNING) + + if event == 'executed': + update_run(prompt_id, Status.SUCCESS) + await self.send_json_original(event, data, sid) - print("Sending event:", sid, event, data) + + +class Status(Enum): + NOT_STARTED = "not-started" + RUNNING = "running" + SUCCESS = "success" + FAILED = "failed" + +def update_run(prompt_id, status: Status): + if prompt_id in prompt_metadata and ('status' not in prompt_metadata[prompt_id] or prompt_metadata[prompt_id]['status'] != status): + status_endpoint = prompt_metadata[prompt_id]['status_endpoint'] + body = { + "run_id": prompt_id, + "status": status.value, + } + prompt_metadata[prompt_id]['status'] = status + requests.post(status_endpoint, json=body) prompt_server.send_json_original = prompt_server.send_json -prompt_server.send_json = send_json_override.__get__(prompt_server, server.PromptServer) - -@atexit.register -def close_websocket(): - print("Got close_websocket") - - global api, api_task - if api_task: - api_task.cancel() \ No newline at end of file +prompt_server.send_json = send_json_override.__get__(prompt_server, server.PromptServer) \ No newline at end of file diff --git a/web/next.config.js b/web/next.config.js index 767719f..954fac0 100644 --- a/web/next.config.js +++ b/web/next.config.js @@ -1,4 +1,8 @@ /** @type {import('next').NextConfig} */ -const nextConfig = {} +const nextConfig = { + eslint: { + ignoreDuringBuilds: true, + }, +}; -module.exports = nextConfig +module.exports = nextConfig; diff --git a/web/src/app/[workflow_id]/page.tsx b/web/src/app/[workflow_id]/page.tsx index 3da52d1..79ec293 100644 --- a/web/src/app/[workflow_id]/page.tsx +++ b/web/src/app/[workflow_id]/page.tsx @@ -1,5 +1,8 @@ -import { MachineSelect, VersionSelect } from "@/components/VersionSelect"; -import { Button } from "@/components/ui/button"; +import { + MachineSelect, + RunWorkflowButton, + VersionSelect, +} from "@/components/VersionSelect"; import { Card, CardContent, @@ -7,12 +10,24 @@ import { CardHeader, CardTitle, } from "@/components/ui/card"; +import { + Table, + TableBody, + TableCaption, + TableCell, + TableHead, + TableHeader, + TableRow, +} from "@/components/ui/table"; import { db } from "@/db/db"; -import { workflowTable, workflowVersionTable } from "@/db/schema"; +import { + workflowRunsTable, + workflowTable, + workflowVersionTable, +} from "@/db/schema"; import { getRelativeTime } from "@/lib/getRelativeTime"; import { getMachines } from "@/server/curdMachine"; import { desc, eq } from "drizzle-orm"; -import { Play } from "lucide-react"; export async function findFirstTableWithVersion(workflow_id: string) { return await db.query.workflowTable.findFirst({ @@ -21,6 +36,32 @@ export async function findFirstTableWithVersion(workflow_id: string) { }); } +export async function findAllRuns(workflow_id: string) { + const workflowVersion = await db.query.workflowVersionTable.findFirst({ + where: eq(workflowVersionTable.workflow_id, workflow_id), + }); + + if (!workflowVersion) { + return []; + } + + return await db.query.workflowRunsTable.findMany({ + where: eq(workflowRunsTable.workflow_version_id, workflowVersion?.id), + with: { + machine: { + columns: { + name: true, + }, + }, + version: { + columns: { + version: true, + }, + }, + }, + }); +} + export default async function Page({ params, }: { @@ -45,9 +86,7 @@ export default async function Page({
- +
@@ -57,8 +96,37 @@ export default async function Page({ Run - + + + ); } + +async function RunsTable(props: { workflow_id: string }) { + const allRuns = await findAllRuns(props.workflow_id); + return ( + + A list of your recent runs. + + + Version + Machine + Time + Status + + + + {allRuns.map((run) => ( + + {run.version.version} + {run.machine.name} + {getRelativeTime(run.created_at)} + {run.status} + + ))} + +
+ ); +} diff --git a/web/src/app/api/create-run/route.ts b/web/src/app/api/create-run/route.ts index 0ea8c3b..9e44df1 100644 --- a/web/src/app/api/create-run/route.ts +++ b/web/src/app/api/create-run/route.ts @@ -1,56 +1,106 @@ import { parseDataSafe } from "../../../lib/parseDataSafe"; import { db } from "@/db/db"; -import { - workflowRunStatus, - workflowRunsTable, - workflowTable, - workflowVersionTable, -} from "@/db/schema"; -import { eq, sql } from "drizzle-orm"; +import { workflowRunsTable } from "@/db/schema"; +import { eq } from "drizzle-orm"; +import { revalidatePath } from "next/cache"; import { NextResponse } from "next/server"; -import { ZodFormattedError, z } from "zod"; +import { z } from "zod"; const Request = z.object({ workflow_version_id: z.string(), + // workflow_version: z.number().optional(), machine_id: z.string(), }); -export async function OPTIONS(request: Request) { - return new Response(null, { - status: 204, - headers: { - "Access-Control-Allow-Origin": "*", - "Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS", - "Access-Control-Allow-Headers": "Content-Type, Authorization", - }, - }); -} +const ComfyAPI_Run = z.object({ + prompt_id: z.string(), + number: z.number(), + node_errors: z.any(), +}); export async function POST(request: Request) { const [data, error] = await parseDataSafe(Request, request); if (!data || error) return error; - let { workflow_version_id, machine_id } = data; + const origin = new URL(request.url).origin; + const { workflow_version_id, machine_id } = data; + + const machine = await db.query.machinesTable.findFirst({ + where: eq(workflowRunsTable.id, machine_id), + }); + + if (!machine) { + return new Response("Machine not found", { + status: 404, + }); + } + + const workflow_version_data = + // workflow_version_id + // ? + await db.query.workflowVersionTable.findFirst({ + where: eq(workflowRunsTable.id, workflow_version_id), + }); + // : workflow_version != undefined + // ? await db.query.workflowVersionTable.findFirst({ + // where: and( + // eq(workflowVersionTable.version, workflow_version), + // eq(workflowVersionTable.workflow_id) + // ), + // }) + // : null; + + if (!workflow_version_data) { + return new Response("Workflow version not found", { + status: 404, + }); + } + + const comfyui_endpoint = `${machine.endpoint}/comfy-deploy/run`; + + // Sending to comfyui + const result = await fetch(comfyui_endpoint, { + method: "POST", + // headers: { + // "Content-Type": "application/json", + // }, + body: JSON.stringify({ + workflow_api: workflow_version_data.workflow_api, + status_endpoint: `${origin}/api/update-run`, + }), + }) + .then(async (res) => ComfyAPI_Run.parseAsync(await res.json())) + .catch((error) => { + console.error(error); + return new Response(error.details, { + status: 500, + }); + }); + + // return the error + if (result instanceof Response) { + return result; + } + + // Add to our db const workflow_run = await db .insert(workflowRunsTable) .values({ - workflow_version_id, + id: result.prompt_id, + workflow_version_id: workflow_version_data.id, machine_id, }) .returning(); + revalidatePath(`./${workflow_version_data.workflow_id}`); + return NextResponse.json( { workflow_run_id: workflow_run[0].id, }, { status: 200, - headers: { - "Access-Control-Allow-Origin": "*", - "Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS", - "Access-Control-Allow-Headers": "Content-Type, Authorization", - }, - }, + } ); } diff --git a/web/src/app/api/update-run/route.ts b/web/src/app/api/update-run/route.ts index 9a5de79..3418680 100644 --- a/web/src/app/api/update-run/route.ts +++ b/web/src/app/api/update-run/route.ts @@ -2,6 +2,7 @@ import { parseDataSafe } from "../../../lib/parseDataSafe"; import { db } from "@/db/db"; import { workflowRunsTable } from "@/db/schema"; import { eq } from "drizzle-orm"; +import { revalidatePath } from "next/cache"; import { NextResponse } from "next/server"; import { z } from "zod"; @@ -10,17 +11,6 @@ const Request = z.object({ status: z.enum(["not-started", "running", "success", "failed"]), }); -export async function OPTIONS(request: Request) { - return new Response(null, { - status: 204, - headers: { - "Access-Control-Allow-Origin": "*", - "Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS", - "Access-Control-Allow-Headers": "Content-Type, Authorization", - }, - }); -} - export async function POST(request: Request) { const [data, error] = await parseDataSafe(Request, request); if (!data || error) return error; @@ -32,7 +22,14 @@ export async function POST(request: Request) { .set({ status: status, }) - .where(eq(workflowRunsTable.id, run_id)); + .where(eq(workflowRunsTable.id, run_id)) + .returning(); + + const workflow_version = await db.query.workflowVersionTable.findFirst({ + where: eq(workflowRunsTable.id, workflow_run[0].workflow_version_id), + }); + + revalidatePath(`./${workflow_version?.workflow_id}`); return NextResponse.json( { @@ -40,11 +37,6 @@ export async function POST(request: Request) { }, { status: 200, - headers: { - "Access-Control-Allow-Origin": "*", - "Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS", - "Access-Control-Allow-Headers": "Content-Type, Authorization", - }, } ); } diff --git a/web/src/app/page.tsx b/web/src/app/page.tsx index 1ceb9a4..24e5ca7 100644 --- a/web/src/app/page.tsx +++ b/web/src/app/page.tsx @@ -2,7 +2,7 @@ import { WorkflowList } from "@/components/WorkflowList"; import { db } from "@/db/db"; import { usersTable, workflowTable, workflowVersionTable } from "@/db/schema"; import { auth, clerkClient } from "@clerk/nextjs"; -import { desc, eq, sql } from "drizzle-orm"; +import { desc, eq } from "drizzle-orm"; export default function Home() { return ; diff --git a/web/src/components/LoadingIcon.tsx b/web/src/components/LoadingIcon.tsx new file mode 100644 index 0000000..65f0078 --- /dev/null +++ b/web/src/components/LoadingIcon.tsx @@ -0,0 +1,7 @@ +"use client"; +import { LoaderIcon } from "lucide-react"; +import * as React from "react"; + +export function LoadingIcon() { + return ; +} diff --git a/web/src/components/MachineList.tsx b/web/src/components/MachineList.tsx index 67b3556..b7caff4 100644 --- a/web/src/components/MachineList.tsx +++ b/web/src/components/MachineList.tsx @@ -1,9 +1,9 @@ "use client"; import { getRelativeTime } from "../lib/getRelativeTime"; +import { LoadingIcon } from "./LoadingIcon"; import { FormControl, - FormDescription, FormField, FormItem, FormLabel, @@ -23,15 +23,12 @@ import { } from "@/components/ui/dialog"; import { DropdownMenu, - DropdownMenuCheckboxItem, DropdownMenuContent, DropdownMenuItem, DropdownMenuLabel, - DropdownMenuSeparator, DropdownMenuTrigger, } from "@/components/ui/dropdown-menu"; import { Input } from "@/components/ui/input"; -import { Label } from "@/components/ui/label"; import { Table, TableBody, @@ -41,7 +38,6 @@ import { TableRow, } from "@/components/ui/table"; import { addMachine, deleteMachine } from "@/server/curdMachine"; -import { deleteWorkflow } from "@/server/deleteWorkflow"; import { zodResolver } from "@hookform/resolvers/zod"; import type { ColumnDef, @@ -57,14 +53,8 @@ import { getSortedRowModel, useReactTable, } from "@tanstack/react-table"; -import { - ArrowUpDown, - ChevronDown, - LoaderIcon, - MoreHorizontal, -} from "lucide-react"; +import { ArrowUpDown, MoreHorizontal } from "lucide-react"; import * as React from "react"; -import { useFormStatus } from "react-dom"; import { useForm } from "react-hook-form"; import { z } from "zod"; @@ -145,7 +135,9 @@ export const columns: ColumnDef[] = [ ); }, cell: ({ row }) => ( -
{getRelativeTime(row.original.date)}
+
+ {getRelativeTime(row.original.date)} +
), }, @@ -187,7 +179,7 @@ export const columns: ColumnDef[] = [ export function MachineList({ data }: { data: Machine[] }) { const [sorting, setSorting] = React.useState([]); const [columnFilters, setColumnFilters] = React.useState( - [], + [] ); const [columnVisibility, setColumnVisibility] = React.useState({}); @@ -265,7 +257,7 @@ export function MachineList({ data }: { data: Machine[] }) { ? null : flexRender( header.column.columnDef.header, - header.getContext(), + header.getContext() )} ); @@ -284,7 +276,7 @@ export function MachineList({ data }: { data: Machine[] }) { {flexRender( cell.column.columnDef.cell, - cell.getContext(), + cell.getContext() )} ))} @@ -418,8 +410,7 @@ function AddWorkflowButton({ pending }: { pending: boolean }) { // const { pending } = useFormStatus(); return ( ); } diff --git a/web/src/components/VersionSelect.tsx b/web/src/components/VersionSelect.tsx index 4684e30..2f901ed 100644 --- a/web/src/components/VersionSelect.tsx +++ b/web/src/components/VersionSelect.tsx @@ -1,6 +1,8 @@ "use client"; -import { findFirstTableWithVersion } from "@/app/[workflow_id]/page"; +import type { findFirstTableWithVersion } from "@/app/[workflow_id]/page"; +import { LoadingIcon } from "@/components/LoadingIcon"; +import { Button } from "@/components/ui/button"; import { Select, SelectContent, @@ -10,15 +12,26 @@ import { SelectTrigger, SelectValue, } from "@/components/ui/select"; -import { getMachines } from "@/server/curdMachine"; +import type { getMachines } from "@/server/curdMachine"; +import { Play } from "lucide-react"; +import { parseAsInteger, useQueryState } from "next-usequerystate"; +import { useState } from "react"; export function VersionSelect({ workflow, }: { workflow: Awaited>; }) { + const [version, setVersion] = useQueryState("version", { + defaultValue: workflow?.versions[0].version?.toString() ?? "", + }); return ( - { + setVersion(v); + }} + > @@ -26,7 +39,7 @@ export function VersionSelect({ Versions {workflow?.versions.map((x) => ( - + {x.version} ))} @@ -36,28 +49,76 @@ export function VersionSelect({ ); } - export function MachineSelect({ - machines, - }: { - machines: Awaited>; - }) { - return ( - - ); - } - \ No newline at end of file + machines, +}: { + machines: Awaited>; +}) { + const [machine, setMachine] = useQueryState("machine", { + defaultValue: machines[0].id ?? "", + }); + return ( + + ); +} + +export function RunWorkflowButton({ + workflow, + machines, +}: { + workflow: Awaited>; + machines: Awaited>; +}) { + const [version] = useQueryState("version", { + defaultValue: workflow?.versions[0].version ?? 1, + ...parseAsInteger, + }); + const [machine] = useQueryState("machine", { + defaultValue: machines[0].id ?? "", + }); + const [isLoading, setIsLoading] = useState(false); + return ( + + ); +} diff --git a/web/src/db/schema.ts b/web/src/db/schema.ts index cf73d1f..f2e65f1 100644 --- a/web/src/db/schema.ts +++ b/web/src/db/schema.ts @@ -17,21 +17,8 @@ export const usersTable = dbSchema.table("users", { name: text("name").notNull(), created_at: timestamp("created_at").defaultNow(), updated_at: timestamp("updated_at").defaultNow(), - // primary_avatar_id: uuid("primary_avatar_id").references( - // () => chatAvatarTable.id, - // ), - // twitter_initial_json: jsonb("twitter_initial_json").$type< - // Omit - // >(), - // initial_prompt: text("initial_prompt"), - // payment_status: text("payment_status"), - // early_access: boolean("early_access").default(false), }); -// export const usersRelations = relations(userTable, ({ many }) => ({ -// chat_avatars: many(chatAvatarTable), -// })); - export const workflowTable = dbSchema.table("workflows", { id: uuid("id").primaryKey().defaultRandom().notNull(), user_id: text("user_id") @@ -70,7 +57,7 @@ export const workflowVersionRelations = relations( fields: [workflowVersionTable.workflow_id], references: [workflowTable.id], }), - }), + }) ); export const workflowRunStatus = pgEnum("workflow_run_status", [ @@ -98,45 +85,30 @@ export const workflowRunsTable = dbSchema.table("workflow_runs", { created_at: timestamp("created_at").defaultNow().notNull(), }); +export const workflowRunRelations = relations(workflowRunsTable, ({ one }) => ({ + machine: one(machinesTable, { + fields: [workflowRunsTable.machine_id], + references: [machinesTable.id], + }), + version: one(workflowVersionTable, { + fields: [workflowRunsTable.workflow_version_id], + references: [workflowVersionTable.id], + }), +})); + // when user delete, also delete all the workflow versions export const machinesTable = dbSchema.table("machines", { id: uuid("id").primaryKey().defaultRandom().notNull(), - user_id: text("user_id").references(() => usersTable.id, { - onDelete: "no action", - }).notNull(), + user_id: text("user_id") + .references(() => usersTable.id, { + onDelete: "no action", + }) + .notNull(), name: text("name").notNull(), endpoint: text("endpoint").notNull(), created_at: timestamp("created_at").defaultNow().notNull(), updated_at: timestamp("updated_at").defaultNow().notNull(), }); -// export const chatAvatarRelations = relations(chatAvatarTable, ({ one }) => ({ -// author: one(userTable, { -// fields: [chatAvatarTable.user_id], -// references: [userTable.id], -// }), -// })); - -// export const subscriptionTable = dbSchema.table("subscription", { -// id: text("id").primaryKey().notNull(), -// email: text("email"), -// user_id: text("user_id"), -// status: text("status"), -// created_at: timestamp("created_at").defaultNow(), -// updated_at: timestamp("updated_at").defaultNow(), -// }); - -// export const subscriptionRelations = relations( -// subscriptionTable, -// ({ one }) => ({ -// user: one(userTable, { -// fields: [subscriptionTable.user_id], -// references: [userTable.id], -// }), -// }), -// ); - export type UserType = InferSelectModel; export type WorkflowType = InferSelectModel; -// export type ChatAvatarType = InferSelectModel; -// export type SubscriptionType = InferSelectModel; diff --git a/web/src/lib/parseDataSafe.ts b/web/src/lib/parseDataSafe.ts index 35d7a8e..6f1b1cf 100644 --- a/web/src/lib/parseDataSafe.ts +++ b/web/src/lib/parseDataSafe.ts @@ -1,10 +1,11 @@ import { NextResponse } from "next/server"; -import { ZodError, ZodType, z } from "zod"; +import type { ZodType, z } from "zod"; +import { ZodError } from "zod"; export async function parseDataSafe>( schema: T, request: Request, - headers?: HeadersInit, + headers?: HeadersInit ): Promise<[z.infer | undefined, NextResponse | undefined]> { let data: z.infer | undefined = undefined; try { @@ -30,7 +31,7 @@ export async function parseDataSafe>( { message: "Invalid request", }, - { status: 500, statusText: "Invalid request", headers: headers }, + { status: 500, statusText: "Invalid request", headers: headers } ), ];