diff --git a/routes.py b/routes.py index 5e6fe9d..6db1603 100644 --- a/routes.py +++ b/routes.py @@ -20,6 +20,9 @@ import atexit import logging from enum import Enum +import aiohttp +from aiohttp import web + api = None api_task = None prompt_metadata = {} @@ -101,6 +104,41 @@ async def comfy_deploy_run(request): return web.json_response(res, status=status) +sockets = dict() + +@server.PromptServer.instance.routes.get('/comfy-deploy/ws') +async def websocket_handler(request): + ws = web.WebSocketResponse() + await ws.prepare(request) + sid = request.rel_url.query.get('clientId', '') + if sid: + # Reusing existing session, remove old + sockets.pop(sid, None) + else: + sid = uuid.uuid4().hex + + sockets[sid] = ws + + try: + # Send initial state to the new client + await send("status", { 'sid': sid }, sid) + + async for msg in ws: + if msg.type == aiohttp.WSMsgType.ERROR: + print('ws connection closed with exception %s' % ws.exception()) + finally: + sockets.pop(sid, None) + return ws + +async def send(event, data, sid=None): + if sid: + ws = sockets.get(sid) + if ws: + await ws.send_json({ 'event': event, 'data': data }) + else: + for ws in sockets.values(): + await ws.send_json({ 'event': event, 'data': data }) + logging.basicConfig(level=logging.INFO) prompt_server = server.PromptServer.instance @@ -111,6 +149,9 @@ async def send_json_override(self, event, data, sid=None): prompt_id = data.get('prompt_id') + # now we send everything + await send(event, data) + if event == 'execution_start': update_run(prompt_id, Status.RUNNING) diff --git a/web/bun.lockb b/web/bun.lockb index 1fa660a..5897671 100755 Binary files a/web/bun.lockb and b/web/bun.lockb differ diff --git a/web/package.json b/web/package.json index c0e7265..9b8fb55 100644 --- a/web/package.json +++ b/web/package.json @@ -35,9 +35,11 @@ "react": "^18", "react-dom": "^18", "react-hook-form": "^7.48.2", + "react-use-websocket": "^4.5.0", "tailwind-merge": "^2.1.0", "tailwindcss-animate": "^1.0.7", - "zod": "^3.22.4" + "zod": "^3.22.4", + "zustand": "^4.4.7" }, "devDependencies": { "eslint-config-next": "^14.0.4", diff --git a/web/src/app/[workflow_id]/page.tsx b/web/src/app/[workflow_id]/page.tsx index a10ea02..b4aea3e 100644 --- a/web/src/app/[workflow_id]/page.tsx +++ b/web/src/app/[workflow_id]/page.tsx @@ -1,4 +1,6 @@ +import { RunDisplay } from "../../components/RunDisplay"; import { LoadingIcon } from "@/components/LoadingIcon"; +import { MachinesWSMain } from "@/components/MachinesWS"; import { MachineSelect, RunWorkflowButton, @@ -16,7 +18,6 @@ import { Table, TableBody, TableCaption, - TableCell, TableHead, TableHeader, TableRow, @@ -83,6 +84,8 @@ export default async function Page({ + + @@ -109,26 +112,20 @@ async function RunsTable(props: { workflow_id: string }) { Version Machine Time + Live Status Status {allRuns.map((run) => ( - - {run.version.version} - {run.machine.name} - {getRelativeTime(run.created_at)} - - - - + ))} ); } -function StatusBadge({ +export function StatusBadge({ run, }: { run: Awaited>[0]; diff --git a/web/src/app/api/create-run/route.ts b/web/src/app/api/create-run/route.ts index b563491..aab477e 100644 --- a/web/src/app/api/create-run/route.ts +++ b/web/src/app/api/create-run/route.ts @@ -1,9 +1,5 @@ import { parseDataSafe } from "../../../lib/parseDataSafe"; -import { db } from "@/db/db"; -import { workflowRunsTable } from "@/db/schema"; -import { eq } from "drizzle-orm"; -import { revalidatePath } from "next/cache"; -import { NextResponse } from "next/server"; +import { createRun } from "../../../server/createRun"; import { z } from "zod"; const Request = z.object({ @@ -12,7 +8,7 @@ const Request = z.object({ machine_id: z.string(), }); -const ComfyAPI_Run = z.object({ +export const ComfyAPI_Run = z.object({ prompt_id: z.string(), number: z.number(), node_errors: z.any(), @@ -26,84 +22,5 @@ export async function POST(request: Request) { const { workflow_version_id, machine_id } = data; - const machine = await db.query.machinesTable.findFirst({ - where: eq(workflowRunsTable.id, machine_id), - }); - - if (!machine) { - return new Response("Machine not found", { - status: 404, - }); - } - - const workflow_version_data = - // workflow_version_id - // ? - await db.query.workflowVersionTable.findFirst({ - where: eq(workflowRunsTable.id, workflow_version_id), - }); - // : workflow_version != undefined - // ? await db.query.workflowVersionTable.findFirst({ - // where: and( - // eq(workflowVersionTable.version, workflow_version), - // eq(workflowVersionTable.workflow_id) - // ), - // }) - // : null; - - if (!workflow_version_data) { - return new Response("Workflow version not found", { - status: 404, - }); - } - - const comfyui_endpoint = `${machine.endpoint}/comfy-deploy/run`; - - // Sending to comfyui - const result = await fetch(comfyui_endpoint, { - method: "POST", - // headers: { - // "Content-Type": "application/json", - // }, - body: JSON.stringify({ - workflow_api: workflow_version_data.workflow_api, - status_endpoint: `${origin}/api/update-run`, - }), - }) - .then(async (res) => ComfyAPI_Run.parseAsync(await res.json())) - .catch((error) => { - console.error(error); - return new Response(error.details, { - status: 500, - }); - }); - - console.log(result); - - // return the error - if (result instanceof Response) { - return result; - } - - // Add to our db - const workflow_run = await db - .insert(workflowRunsTable) - .values({ - id: result.prompt_id, - workflow_id: workflow_version_data.workflow_id, - workflow_version_id: workflow_version_data.id, - machine_id, - }) - .returning(); - - revalidatePath(`./${workflow_version_data.workflow_id}`); - - return NextResponse.json( - { - workflow_run_id: workflow_run[0].id, - }, - { - status: 200, - } - ); + return await createRun(origin, workflow_version_id, machine_id); } diff --git a/web/src/components/MachinesWS.tsx b/web/src/components/MachinesWS.tsx new file mode 100644 index 0000000..9d2ea04 --- /dev/null +++ b/web/src/components/MachinesWS.tsx @@ -0,0 +1,86 @@ +"use client"; + +import type { getMachines } from "@/server/curdMachine"; +import React, { useEffect } from "react"; +import useWebSocket, { ReadyState } from "react-use-websocket"; +import { create } from "zustand"; + +type State = { + data: { + id: string; + json: { + event: string; + data: any; + }; + }[]; + addData: ( + id: string, + json: { + event: string; + data: any; + } + ) => void; +}; + +export const useStore = create((set) => ({ + data: [], + addData: (id, json) => + set((state) => ({ + ...state, + data: [...state.data, { id, json }], + })), +})); + +export function MachinesWSMain(props: { + machines: Awaited>; +}) { + return ( +
+ Machine Status + {props.machines.map((x) => ( + + ))} +
+ ); +} + +function MachineWS({ + machine, +}: { + machine: Awaited>[0]; +}) { + const { addData } = useStore(); + const wsEndpoint = machine.endpoint.replace(/^http/, "ws"); + const { lastMessage, readyState } = useWebSocket( + `${wsEndpoint}/comfy-deploy/ws`, + { + reconnectAttempts: 10, + reconnectInterval: 1000, + } + ); + + const connectionStatus = { + [ReadyState.CONNECTING]: "Connecting", + [ReadyState.OPEN]: "Open", + [ReadyState.CLOSING]: "Closing", + [ReadyState.CLOSED]: "Closed", + [ReadyState.UNINSTANTIATED]: "Uninstantiated", + }[readyState]; + + useEffect(() => { + if (!lastMessage?.data) return; + + const message = JSON.parse(lastMessage.data); + console.log(message.event, message); + + if (message.data?.prompt_id) { + addData(message.data.prompt_id, message); + } + }, [lastMessage]); + + return ( +
+ {machine.name} - {connectionStatus} +
+ ); +} diff --git a/web/src/components/NavbarRight.tsx b/web/src/components/NavbarRight.tsx index 91127e9..14364ca 100644 --- a/web/src/components/NavbarRight.tsx +++ b/web/src/components/NavbarRight.tsx @@ -1,6 +1,6 @@ "use client"; -import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs"; +import { Tabs, TabsList, TabsTrigger } from "@/components/ui/tabs"; import { usePathname } from "next/navigation"; import { useRouter } from "next/navigation"; diff --git a/web/src/components/RunDisplay.tsx b/web/src/components/RunDisplay.tsx new file mode 100644 index 0000000..4f47465 --- /dev/null +++ b/web/src/components/RunDisplay.tsx @@ -0,0 +1,27 @@ +"use client"; + +import type { findAllRuns } from "../app/[workflow_id]/page"; +import { StatusBadge } from "../app/[workflow_id]/page"; +import { useStore } from "@/components/MachinesWS"; +import { TableCell, TableRow } from "@/components/ui/table"; +import { getRelativeTime } from "@/lib/getRelativeTime"; + +export function RunDisplay({ + run, +}: { + run: Awaited>[0]; +}) { + const data = useStore((state) => state.data.find((x) => x.id === run.id)); + + return ( + + {run.version.version} + {run.machine.name} + {getRelativeTime(run.created_at)} + {data ? data.json.event : "-"} + + + + + ); +} diff --git a/web/src/components/VersionSelect.tsx b/web/src/components/VersionSelect.tsx index 2f901ed..f2cb7d2 100644 --- a/web/src/components/VersionSelect.tsx +++ b/web/src/components/VersionSelect.tsx @@ -12,6 +12,7 @@ import { SelectTrigger, SelectValue, } from "@/components/ui/select"; +import { createRun } from "@/server/createRun"; import type { getMachines } from "@/server/curdMachine"; import { Play } from "lucide-react"; import { parseAsInteger, useQueryState } from "next-usequerystate"; @@ -101,17 +102,15 @@ export function RunWorkflowButton({ className="gap-2" disabled={isLoading} onClick={async () => { + const workflow_version_id = workflow?.versions.find( + (x) => x.version === version + )?.id; + if (!workflow_version_id) return; + setIsLoading(true); try { - await fetch(`/api/create-run`, { - method: "POST", - body: JSON.stringify({ - workflow_version_id: workflow?.versions.find( - (x) => x.version === version - )?.id, - machine_id: machine, - }), - }); + const origin = window.location.origin; + await createRun(origin, workflow_version_id, machine); setIsLoading(false); } catch (error) { setIsLoading(false); diff --git a/web/src/server/createRun.ts b/web/src/server/createRun.ts new file mode 100644 index 0000000..bfd922f --- /dev/null +++ b/web/src/server/createRun.ts @@ -0,0 +1,95 @@ +"use server"; + +import { ComfyAPI_Run } from "../app/api/create-run/route"; +import { db } from "@/db/db"; +import { workflowRunsTable } from "@/db/schema"; +import { eq } from "drizzle-orm"; +import { revalidatePath } from "next/cache"; +import { NextResponse } from "next/server"; +import "server-only"; + +export async function createRun( + origin: string, + workflow_version_id: string, + machine_id: string +) { + const machine = await db.query.machinesTable.findFirst({ + where: eq(workflowRunsTable.id, machine_id), + }); + + if (!machine) { + return new Response("Machine not found", { + status: 404, + }); + } + + const workflow_version_data = + // workflow_version_id + // ? + await db.query.workflowVersionTable.findFirst({ + where: eq(workflowRunsTable.id, workflow_version_id), + }); + // : workflow_version != undefined + // ? await db.query.workflowVersionTable.findFirst({ + // where: and( + // eq(workflowVersionTable.version, workflow_version), + // eq(workflowVersionTable.workflow_id) + // ), + // }) + // : null; + if (!workflow_version_data) { + return new Response("Workflow version not found", { + status: 404, + }); + } + + const comfyui_endpoint = `${machine.endpoint}/comfy-deploy/run`; + + // Sending to comfyui + const result = await fetch(comfyui_endpoint, { + method: "POST", + // headers: { + // "Content-Type": "application/json", + // }, + body: JSON.stringify({ + workflow_api: workflow_version_data.workflow_api, + status_endpoint: `${origin}/api/update-run`, + }), + }) + .then(async (res) => ComfyAPI_Run.parseAsync(await res.json())) + .catch((error) => { + console.error(error); + return new Response(error.details, { + status: 500, + }); + }); + + console.log(result); + + // return the error + if (result instanceof Response) { + return result; + } + + // Add to our db + const workflow_run = await db + .insert(workflowRunsTable) + .values({ + id: result.prompt_id, + workflow_id: workflow_version_data.workflow_id, + workflow_version_id: workflow_version_data.id, + machine_id, + }) + .returning(); + + revalidatePath(`/${workflow_version_data.workflow_id}`); + + return NextResponse.json( + { + workflow_run_id: workflow_run[0].id, + }, + { + status: 200, + } + ); +} diff --git a/web/src/server/curdMachine.ts b/web/src/server/curdMachine.ts index 10a1aec..53917ab 100644 --- a/web/src/server/curdMachine.ts +++ b/web/src/server/curdMachine.ts @@ -1,7 +1,7 @@ "use server"; import { db } from "@/db/db"; -import { machinesTable, workflowTable } from "@/db/schema"; +import { machinesTable } from "@/db/schema"; import { auth } from "@clerk/nextjs"; import { eq } from "drizzle-orm"; import { revalidatePath } from "next/cache";