diff --git a/routes.py b/routes.py
index c68cb0c..5e6fe9d 100644
--- a/routes.py
+++ b/routes.py
@@ -15,13 +15,14 @@ import execution
import random
import uuid
-import websockets
import asyncio
import atexit
import logging
+from enum import Enum
api = None
api_task = None
+prompt_metadata = {}
load_dotenv()
@@ -73,20 +74,15 @@ def randomSeed(num_digits=15):
@server.PromptServer.instance.routes.post("/comfy-deploy/run")
async def comfy_deploy_run(request):
+ print("hi")
prompt_server = server.PromptServer.instance
data = await request.json()
- for key in data:
- if 'inputs' in data[key] and 'seed' in data[key]['inputs']:
- data[key]['inputs']['seed'] = randomSeed()
-
- if api is None:
- connect_to_websocket()
- while api.client_id is None:
- await asyncio.sleep(0.1)
-
workflow_api = data.get("workflow_api")
- # print(workflow_api)
+
+ for key in workflow_api:
+ if 'inputs' in workflow_api[key] and 'seed' in workflow_api[key]['inputs']:
+ workflow_api[key]['inputs']['seed'] = randomSeed()
prompt = {
"prompt": workflow_api,
@@ -95,7 +91,9 @@ async def comfy_deploy_run(request):
res = post_prompt(prompt)
- # print(prompt)
+ prompt_metadata[res['prompt_id']] = {
+ 'status_endpoint': data.get('status_endpoint'),
+ }
status = 200
if "error" in res:
@@ -105,69 +103,41 @@ async def comfy_deploy_run(request):
logging.basicConfig(level=logging.INFO)
-class ComfyApi:
- def __init__(self):
- self.websocket = None
- self.client_id = None
-
- async def connect(self, uri):
- self.websocket = await websockets.connect(uri)
-
- # Event listeners
- await self.on_open()
- await self.on_message()
- await self.on_close()
-
- async def close(self):
- await self.websocket.close()
-
- async def on_open(self):
- print("Connection opened")
-
- async def on_message(self):
- async for message in self.websocket:
- if isinstance(message, bytes):
- print("Received binary message, skipping...")
- continue # skip to the next message
- logging.info(f"Received message: {message}")
-
- try:
- message_data = json.loads(message)
-
- msg_type = message_data["type"]
-
- if msg_type == "status" and message_data["data"]["sid"] is not None:
- self.client_id = message_data["data"]["sid"]
- logging.info(f"Received client_id: {self.client_id}")
-
- except json.JSONDecodeError:
- logging.info(f"Failed to parse message as JSON: {message}")
-
- async def on_close(self):
- print("Connection closed")
-
- async def run(self, uri):
- await self.connect(uri)
-
-def connect_to_websocket():
- global api, api_task
- api = ComfyApi()
- api_task = asyncio.create_task(api.run('ws://localhost:8188/ws'))
-
prompt_server = server.PromptServer.instance
send_json = prompt_server.send_json
async def send_json_override(self, event, data, sid=None):
+ print("INTERNAL:", event, data, sid)
+
+ prompt_id = data.get('prompt_id')
+
+ if event == 'execution_start':
+ update_run(prompt_id, Status.RUNNING)
+
+ # if event == 'executing':
+ # update_run(prompt_id, Status.RUNNING)
+
+ if event == 'executed':
+ update_run(prompt_id, Status.SUCCESS)
+
await self.send_json_original(event, data, sid)
- print("Sending event:", sid, event, data)
+
+
+class Status(Enum):
+ NOT_STARTED = "not-started"
+ RUNNING = "running"
+ SUCCESS = "success"
+ FAILED = "failed"
+
+def update_run(prompt_id, status: Status):
+ if prompt_id in prompt_metadata and ('status' not in prompt_metadata[prompt_id] or prompt_metadata[prompt_id]['status'] != status):
+ status_endpoint = prompt_metadata[prompt_id]['status_endpoint']
+ body = {
+ "run_id": prompt_id,
+ "status": status.value,
+ }
+ prompt_metadata[prompt_id]['status'] = status
+ requests.post(status_endpoint, json=body)
prompt_server.send_json_original = prompt_server.send_json
-prompt_server.send_json = send_json_override.__get__(prompt_server, server.PromptServer)
-
-@atexit.register
-def close_websocket():
- print("Got close_websocket")
-
- global api, api_task
- if api_task:
- api_task.cancel()
\ No newline at end of file
+prompt_server.send_json = send_json_override.__get__(prompt_server, server.PromptServer)
\ No newline at end of file
diff --git a/web/next.config.js b/web/next.config.js
index 767719f..954fac0 100644
--- a/web/next.config.js
+++ b/web/next.config.js
@@ -1,4 +1,8 @@
/** @type {import('next').NextConfig} */
-const nextConfig = {}
+const nextConfig = {
+ eslint: {
+ ignoreDuringBuilds: true,
+ },
+};
-module.exports = nextConfig
+module.exports = nextConfig;
diff --git a/web/src/app/[workflow_id]/page.tsx b/web/src/app/[workflow_id]/page.tsx
index 3da52d1..79ec293 100644
--- a/web/src/app/[workflow_id]/page.tsx
+++ b/web/src/app/[workflow_id]/page.tsx
@@ -1,5 +1,8 @@
-import { MachineSelect, VersionSelect } from "@/components/VersionSelect";
-import { Button } from "@/components/ui/button";
+import {
+ MachineSelect,
+ RunWorkflowButton,
+ VersionSelect,
+} from "@/components/VersionSelect";
import {
Card,
CardContent,
@@ -7,12 +10,24 @@ import {
CardHeader,
CardTitle,
} from "@/components/ui/card";
+import {
+ Table,
+ TableBody,
+ TableCaption,
+ TableCell,
+ TableHead,
+ TableHeader,
+ TableRow,
+} from "@/components/ui/table";
import { db } from "@/db/db";
-import { workflowTable, workflowVersionTable } from "@/db/schema";
+import {
+ workflowRunsTable,
+ workflowTable,
+ workflowVersionTable,
+} from "@/db/schema";
import { getRelativeTime } from "@/lib/getRelativeTime";
import { getMachines } from "@/server/curdMachine";
import { desc, eq } from "drizzle-orm";
-import { Play } from "lucide-react";
export async function findFirstTableWithVersion(workflow_id: string) {
return await db.query.workflowTable.findFirst({
@@ -21,6 +36,32 @@ export async function findFirstTableWithVersion(workflow_id: string) {
});
}
+export async function findAllRuns(workflow_id: string) {
+ const workflowVersion = await db.query.workflowVersionTable.findFirst({
+ where: eq(workflowVersionTable.workflow_id, workflow_id),
+ });
+
+ if (!workflowVersion) {
+ return [];
+ }
+
+ return await db.query.workflowRunsTable.findMany({
+ where: eq(workflowRunsTable.workflow_version_id, workflowVersion?.id),
+ with: {
+ machine: {
+ columns: {
+ name: true,
+ },
+ },
+ version: {
+ columns: {
+ version: true,
+ },
+ },
+ },
+ });
+}
+
export default async function Page({
params,
}: {
@@ -45,9 +86,7 @@ export default async function Page({
@@ -57,8 +96,37 @@ export default async function Page({
Run
-
+
+
+
);
}
+
+async function RunsTable(props: { workflow_id: string }) {
+ const allRuns = await findAllRuns(props.workflow_id);
+ return (
+
+ A list of your recent runs.
+
+
+ Version
+ Machine
+ Time
+ Status
+
+
+
+ {allRuns.map((run) => (
+
+ {run.version.version}
+ {run.machine.name}
+ {getRelativeTime(run.created_at)}
+ {run.status}
+
+ ))}
+
+
+ );
+}
diff --git a/web/src/app/api/create-run/route.ts b/web/src/app/api/create-run/route.ts
index 0ea8c3b..9e44df1 100644
--- a/web/src/app/api/create-run/route.ts
+++ b/web/src/app/api/create-run/route.ts
@@ -1,56 +1,106 @@
import { parseDataSafe } from "../../../lib/parseDataSafe";
import { db } from "@/db/db";
-import {
- workflowRunStatus,
- workflowRunsTable,
- workflowTable,
- workflowVersionTable,
-} from "@/db/schema";
-import { eq, sql } from "drizzle-orm";
+import { workflowRunsTable } from "@/db/schema";
+import { eq } from "drizzle-orm";
+import { revalidatePath } from "next/cache";
import { NextResponse } from "next/server";
-import { ZodFormattedError, z } from "zod";
+import { z } from "zod";
const Request = z.object({
workflow_version_id: z.string(),
+ // workflow_version: z.number().optional(),
machine_id: z.string(),
});
-export async function OPTIONS(request: Request) {
- return new Response(null, {
- status: 204,
- headers: {
- "Access-Control-Allow-Origin": "*",
- "Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
- "Access-Control-Allow-Headers": "Content-Type, Authorization",
- },
- });
-}
+const ComfyAPI_Run = z.object({
+ prompt_id: z.string(),
+ number: z.number(),
+ node_errors: z.any(),
+});
export async function POST(request: Request) {
const [data, error] = await parseDataSafe(Request, request);
if (!data || error) return error;
- let { workflow_version_id, machine_id } = data;
+ const origin = new URL(request.url).origin;
+ const { workflow_version_id, machine_id } = data;
+
+ const machine = await db.query.machinesTable.findFirst({
+ where: eq(workflowRunsTable.id, machine_id),
+ });
+
+ if (!machine) {
+ return new Response("Machine not found", {
+ status: 404,
+ });
+ }
+
+ const workflow_version_data =
+ // workflow_version_id
+ // ?
+ await db.query.workflowVersionTable.findFirst({
+ where: eq(workflowRunsTable.id, workflow_version_id),
+ });
+ // : workflow_version != undefined
+ // ? await db.query.workflowVersionTable.findFirst({
+ // where: and(
+ // eq(workflowVersionTable.version, workflow_version),
+ // eq(workflowVersionTable.workflow_id)
+ // ),
+ // })
+ // : null;
+
+ if (!workflow_version_data) {
+ return new Response("Workflow version not found", {
+ status: 404,
+ });
+ }
+
+ const comfyui_endpoint = `${machine.endpoint}/comfy-deploy/run`;
+
+ // Sending to comfyui
+ const result = await fetch(comfyui_endpoint, {
+ method: "POST",
+ // headers: {
+ // "Content-Type": "application/json",
+ // },
+ body: JSON.stringify({
+ workflow_api: workflow_version_data.workflow_api,
+ status_endpoint: `${origin}/api/update-run`,
+ }),
+ })
+ .then(async (res) => ComfyAPI_Run.parseAsync(await res.json()))
+ .catch((error) => {
+ console.error(error);
+ return new Response(error.details, {
+ status: 500,
+ });
+ });
+
+ // return the error
+ if (result instanceof Response) {
+ return result;
+ }
+
+ // Add to our db
const workflow_run = await db
.insert(workflowRunsTable)
.values({
- workflow_version_id,
+ id: result.prompt_id,
+ workflow_version_id: workflow_version_data.id,
machine_id,
})
.returning();
+ revalidatePath(`./${workflow_version_data.workflow_id}`);
+
return NextResponse.json(
{
workflow_run_id: workflow_run[0].id,
},
{
status: 200,
- headers: {
- "Access-Control-Allow-Origin": "*",
- "Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
- "Access-Control-Allow-Headers": "Content-Type, Authorization",
- },
- },
+ }
);
}
diff --git a/web/src/app/api/update-run/route.ts b/web/src/app/api/update-run/route.ts
index 9a5de79..3418680 100644
--- a/web/src/app/api/update-run/route.ts
+++ b/web/src/app/api/update-run/route.ts
@@ -2,6 +2,7 @@ import { parseDataSafe } from "../../../lib/parseDataSafe";
import { db } from "@/db/db";
import { workflowRunsTable } from "@/db/schema";
import { eq } from "drizzle-orm";
+import { revalidatePath } from "next/cache";
import { NextResponse } from "next/server";
import { z } from "zod";
@@ -10,17 +11,6 @@ const Request = z.object({
status: z.enum(["not-started", "running", "success", "failed"]),
});
-export async function OPTIONS(request: Request) {
- return new Response(null, {
- status: 204,
- headers: {
- "Access-Control-Allow-Origin": "*",
- "Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
- "Access-Control-Allow-Headers": "Content-Type, Authorization",
- },
- });
-}
-
export async function POST(request: Request) {
const [data, error] = await parseDataSafe(Request, request);
if (!data || error) return error;
@@ -32,7 +22,14 @@ export async function POST(request: Request) {
.set({
status: status,
})
- .where(eq(workflowRunsTable.id, run_id));
+ .where(eq(workflowRunsTable.id, run_id))
+ .returning();
+
+ const workflow_version = await db.query.workflowVersionTable.findFirst({
+ where: eq(workflowRunsTable.id, workflow_run[0].workflow_version_id),
+ });
+
+ revalidatePath(`./${workflow_version?.workflow_id}`);
return NextResponse.json(
{
@@ -40,11 +37,6 @@ export async function POST(request: Request) {
},
{
status: 200,
- headers: {
- "Access-Control-Allow-Origin": "*",
- "Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
- "Access-Control-Allow-Headers": "Content-Type, Authorization",
- },
}
);
}
diff --git a/web/src/app/page.tsx b/web/src/app/page.tsx
index 1ceb9a4..24e5ca7 100644
--- a/web/src/app/page.tsx
+++ b/web/src/app/page.tsx
@@ -2,7 +2,7 @@ import { WorkflowList } from "@/components/WorkflowList";
import { db } from "@/db/db";
import { usersTable, workflowTable, workflowVersionTable } from "@/db/schema";
import { auth, clerkClient } from "@clerk/nextjs";
-import { desc, eq, sql } from "drizzle-orm";
+import { desc, eq } from "drizzle-orm";
export default function Home() {
return ;
diff --git a/web/src/components/LoadingIcon.tsx b/web/src/components/LoadingIcon.tsx
new file mode 100644
index 0000000..65f0078
--- /dev/null
+++ b/web/src/components/LoadingIcon.tsx
@@ -0,0 +1,7 @@
+"use client";
+import { LoaderIcon } from "lucide-react";
+import * as React from "react";
+
+export function LoadingIcon() {
+ return ;
+}
diff --git a/web/src/components/MachineList.tsx b/web/src/components/MachineList.tsx
index 67b3556..b7caff4 100644
--- a/web/src/components/MachineList.tsx
+++ b/web/src/components/MachineList.tsx
@@ -1,9 +1,9 @@
"use client";
import { getRelativeTime } from "../lib/getRelativeTime";
+import { LoadingIcon } from "./LoadingIcon";
import {
FormControl,
- FormDescription,
FormField,
FormItem,
FormLabel,
@@ -23,15 +23,12 @@ import {
} from "@/components/ui/dialog";
import {
DropdownMenu,
- DropdownMenuCheckboxItem,
DropdownMenuContent,
DropdownMenuItem,
DropdownMenuLabel,
- DropdownMenuSeparator,
DropdownMenuTrigger,
} from "@/components/ui/dropdown-menu";
import { Input } from "@/components/ui/input";
-import { Label } from "@/components/ui/label";
import {
Table,
TableBody,
@@ -41,7 +38,6 @@ import {
TableRow,
} from "@/components/ui/table";
import { addMachine, deleteMachine } from "@/server/curdMachine";
-import { deleteWorkflow } from "@/server/deleteWorkflow";
import { zodResolver } from "@hookform/resolvers/zod";
import type {
ColumnDef,
@@ -57,14 +53,8 @@ import {
getSortedRowModel,
useReactTable,
} from "@tanstack/react-table";
-import {
- ArrowUpDown,
- ChevronDown,
- LoaderIcon,
- MoreHorizontal,
-} from "lucide-react";
+import { ArrowUpDown, MoreHorizontal } from "lucide-react";
import * as React from "react";
-import { useFormStatus } from "react-dom";
import { useForm } from "react-hook-form";
import { z } from "zod";
@@ -145,7 +135,9 @@ export const columns: ColumnDef[] = [
);
},
cell: ({ row }) => (
- {getRelativeTime(row.original.date)}
+
+ {getRelativeTime(row.original.date)}
+
),
},
@@ -187,7 +179,7 @@ export const columns: ColumnDef[] = [
export function MachineList({ data }: { data: Machine[] }) {
const [sorting, setSorting] = React.useState([]);
const [columnFilters, setColumnFilters] = React.useState(
- [],
+ []
);
const [columnVisibility, setColumnVisibility] =
React.useState({});
@@ -265,7 +257,7 @@ export function MachineList({ data }: { data: Machine[] }) {
? null
: flexRender(
header.column.columnDef.header,
- header.getContext(),
+ header.getContext()
)}
);
@@ -284,7 +276,7 @@ export function MachineList({ data }: { data: Machine[] }) {
{flexRender(
cell.column.columnDef.cell,
- cell.getContext(),
+ cell.getContext()
)}
))}
@@ -418,8 +410,7 @@ function AddWorkflowButton({ pending }: { pending: boolean }) {
// const { pending } = useFormStatus();
return (
);
}
diff --git a/web/src/components/VersionSelect.tsx b/web/src/components/VersionSelect.tsx
index 4684e30..2f901ed 100644
--- a/web/src/components/VersionSelect.tsx
+++ b/web/src/components/VersionSelect.tsx
@@ -1,6 +1,8 @@
"use client";
-import { findFirstTableWithVersion } from "@/app/[workflow_id]/page";
+import type { findFirstTableWithVersion } from "@/app/[workflow_id]/page";
+import { LoadingIcon } from "@/components/LoadingIcon";
+import { Button } from "@/components/ui/button";
import {
Select,
SelectContent,
@@ -10,15 +12,26 @@ import {
SelectTrigger,
SelectValue,
} from "@/components/ui/select";
-import { getMachines } from "@/server/curdMachine";
+import type { getMachines } from "@/server/curdMachine";
+import { Play } from "lucide-react";
+import { parseAsInteger, useQueryState } from "next-usequerystate";
+import { useState } from "react";
export function VersionSelect({
workflow,
}: {
workflow: Awaited>;
}) {
+ const [version, setVersion] = useQueryState("version", {
+ defaultValue: workflow?.versions[0].version?.toString() ?? "",
+ });
return (
-