feat: add new ws realtime event
This commit is contained in:
parent
ed853fc5f1
commit
f9ed8145d2
41
routes.py
41
routes.py
@ -20,6 +20,9 @@ import atexit
|
||||
import logging
|
||||
from enum import Enum
|
||||
|
||||
import aiohttp
|
||||
from aiohttp import web
|
||||
|
||||
api = None
|
||||
api_task = None
|
||||
prompt_metadata = {}
|
||||
@ -101,6 +104,41 @@ async def comfy_deploy_run(request):
|
||||
|
||||
return web.json_response(res, status=status)
|
||||
|
||||
sockets = dict()
|
||||
|
||||
@server.PromptServer.instance.routes.get('/comfy-deploy/ws')
|
||||
async def websocket_handler(request):
|
||||
ws = web.WebSocketResponse()
|
||||
await ws.prepare(request)
|
||||
sid = request.rel_url.query.get('clientId', '')
|
||||
if sid:
|
||||
# Reusing existing session, remove old
|
||||
sockets.pop(sid, None)
|
||||
else:
|
||||
sid = uuid.uuid4().hex
|
||||
|
||||
sockets[sid] = ws
|
||||
|
||||
try:
|
||||
# Send initial state to the new client
|
||||
await send("status", { 'sid': sid }, sid)
|
||||
|
||||
async for msg in ws:
|
||||
if msg.type == aiohttp.WSMsgType.ERROR:
|
||||
print('ws connection closed with exception %s' % ws.exception())
|
||||
finally:
|
||||
sockets.pop(sid, None)
|
||||
return ws
|
||||
|
||||
async def send(event, data, sid=None):
|
||||
if sid:
|
||||
ws = sockets.get(sid)
|
||||
if ws:
|
||||
await ws.send_json({ 'event': event, 'data': data })
|
||||
else:
|
||||
for ws in sockets.values():
|
||||
await ws.send_json({ 'event': event, 'data': data })
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
prompt_server = server.PromptServer.instance
|
||||
@ -111,6 +149,9 @@ async def send_json_override(self, event, data, sid=None):
|
||||
|
||||
prompt_id = data.get('prompt_id')
|
||||
|
||||
# now we send everything
|
||||
await send(event, data)
|
||||
|
||||
if event == 'execution_start':
|
||||
update_run(prompt_id, Status.RUNNING)
|
||||
|
||||
|
BIN
web/bun.lockb
BIN
web/bun.lockb
Binary file not shown.
@ -35,9 +35,11 @@
|
||||
"react": "^18",
|
||||
"react-dom": "^18",
|
||||
"react-hook-form": "^7.48.2",
|
||||
"react-use-websocket": "^4.5.0",
|
||||
"tailwind-merge": "^2.1.0",
|
||||
"tailwindcss-animate": "^1.0.7",
|
||||
"zod": "^3.22.4"
|
||||
"zod": "^3.22.4",
|
||||
"zustand": "^4.4.7"
|
||||
},
|
||||
"devDependencies": {
|
||||
"eslint-config-next": "^14.0.4",
|
||||
|
@ -1,4 +1,6 @@
|
||||
import { RunDisplay } from "../../components/RunDisplay";
|
||||
import { LoadingIcon } from "@/components/LoadingIcon";
|
||||
import { MachinesWSMain } from "@/components/MachinesWS";
|
||||
import {
|
||||
MachineSelect,
|
||||
RunWorkflowButton,
|
||||
@ -16,7 +18,6 @@ import {
|
||||
Table,
|
||||
TableBody,
|
||||
TableCaption,
|
||||
TableCell,
|
||||
TableHead,
|
||||
TableHeader,
|
||||
TableRow,
|
||||
@ -83,6 +84,8 @@ export default async function Page({
|
||||
<MachineSelect machines={machines} />
|
||||
<RunWorkflowButton workflow={workflow} machines={machines} />
|
||||
</div>
|
||||
|
||||
<MachinesWSMain machines={machines} />
|
||||
</CardContent>
|
||||
</Card>
|
||||
|
||||
@ -109,26 +112,20 @@ async function RunsTable(props: { workflow_id: string }) {
|
||||
<TableHead className="w-[100px]">Version</TableHead>
|
||||
<TableHead>Machine</TableHead>
|
||||
<TableHead>Time</TableHead>
|
||||
<TableHead>Live Status</TableHead>
|
||||
<TableHead className="text-right">Status</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{allRuns.map((run) => (
|
||||
<TableRow key={run.id}>
|
||||
<TableCell>{run.version.version}</TableCell>
|
||||
<TableCell className="font-medium">{run.machine.name}</TableCell>
|
||||
<TableCell>{getRelativeTime(run.created_at)}</TableCell>
|
||||
<TableCell className="text-right">
|
||||
<StatusBadge run={run} />
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
<RunDisplay run={run} key={run.id} />
|
||||
))}
|
||||
</TableBody>
|
||||
</Table>
|
||||
);
|
||||
}
|
||||
|
||||
function StatusBadge({
|
||||
export function StatusBadge({
|
||||
run,
|
||||
}: {
|
||||
run: Awaited<ReturnType<typeof findAllRuns>>[0];
|
||||
|
@ -1,9 +1,5 @@
|
||||
import { parseDataSafe } from "../../../lib/parseDataSafe";
|
||||
import { db } from "@/db/db";
|
||||
import { workflowRunsTable } from "@/db/schema";
|
||||
import { eq } from "drizzle-orm";
|
||||
import { revalidatePath } from "next/cache";
|
||||
import { NextResponse } from "next/server";
|
||||
import { createRun } from "../../../server/createRun";
|
||||
import { z } from "zod";
|
||||
|
||||
const Request = z.object({
|
||||
@ -12,7 +8,7 @@ const Request = z.object({
|
||||
machine_id: z.string(),
|
||||
});
|
||||
|
||||
const ComfyAPI_Run = z.object({
|
||||
export const ComfyAPI_Run = z.object({
|
||||
prompt_id: z.string(),
|
||||
number: z.number(),
|
||||
node_errors: z.any(),
|
||||
@ -26,84 +22,5 @@ export async function POST(request: Request) {
|
||||
|
||||
const { workflow_version_id, machine_id } = data;
|
||||
|
||||
const machine = await db.query.machinesTable.findFirst({
|
||||
where: eq(workflowRunsTable.id, machine_id),
|
||||
});
|
||||
|
||||
if (!machine) {
|
||||
return new Response("Machine not found", {
|
||||
status: 404,
|
||||
});
|
||||
}
|
||||
|
||||
const workflow_version_data =
|
||||
// workflow_version_id
|
||||
// ?
|
||||
await db.query.workflowVersionTable.findFirst({
|
||||
where: eq(workflowRunsTable.id, workflow_version_id),
|
||||
});
|
||||
// : workflow_version != undefined
|
||||
// ? await db.query.workflowVersionTable.findFirst({
|
||||
// where: and(
|
||||
// eq(workflowVersionTable.version, workflow_version),
|
||||
// eq(workflowVersionTable.workflow_id)
|
||||
// ),
|
||||
// })
|
||||
// : null;
|
||||
|
||||
if (!workflow_version_data) {
|
||||
return new Response("Workflow version not found", {
|
||||
status: 404,
|
||||
});
|
||||
}
|
||||
|
||||
const comfyui_endpoint = `${machine.endpoint}/comfy-deploy/run`;
|
||||
|
||||
// Sending to comfyui
|
||||
const result = await fetch(comfyui_endpoint, {
|
||||
method: "POST",
|
||||
// headers: {
|
||||
// "Content-Type": "application/json",
|
||||
// },
|
||||
body: JSON.stringify({
|
||||
workflow_api: workflow_version_data.workflow_api,
|
||||
status_endpoint: `${origin}/api/update-run`,
|
||||
}),
|
||||
})
|
||||
.then(async (res) => ComfyAPI_Run.parseAsync(await res.json()))
|
||||
.catch((error) => {
|
||||
console.error(error);
|
||||
return new Response(error.details, {
|
||||
status: 500,
|
||||
});
|
||||
});
|
||||
|
||||
console.log(result);
|
||||
|
||||
// return the error
|
||||
if (result instanceof Response) {
|
||||
return result;
|
||||
}
|
||||
|
||||
// Add to our db
|
||||
const workflow_run = await db
|
||||
.insert(workflowRunsTable)
|
||||
.values({
|
||||
id: result.prompt_id,
|
||||
workflow_id: workflow_version_data.workflow_id,
|
||||
workflow_version_id: workflow_version_data.id,
|
||||
machine_id,
|
||||
})
|
||||
.returning();
|
||||
|
||||
revalidatePath(`./${workflow_version_data.workflow_id}`);
|
||||
|
||||
return NextResponse.json(
|
||||
{
|
||||
workflow_run_id: workflow_run[0].id,
|
||||
},
|
||||
{
|
||||
status: 200,
|
||||
}
|
||||
);
|
||||
return await createRun(origin, workflow_version_id, machine_id);
|
||||
}
|
||||
|
86
web/src/components/MachinesWS.tsx
Normal file
86
web/src/components/MachinesWS.tsx
Normal file
@ -0,0 +1,86 @@
|
||||
"use client";
|
||||
|
||||
import type { getMachines } from "@/server/curdMachine";
|
||||
import React, { useEffect } from "react";
|
||||
import useWebSocket, { ReadyState } from "react-use-websocket";
|
||||
import { create } from "zustand";
|
||||
|
||||
type State = {
|
||||
data: {
|
||||
id: string;
|
||||
json: {
|
||||
event: string;
|
||||
data: any;
|
||||
};
|
||||
}[];
|
||||
addData: (
|
||||
id: string,
|
||||
json: {
|
||||
event: string;
|
||||
data: any;
|
||||
}
|
||||
) => void;
|
||||
};
|
||||
|
||||
export const useStore = create<State>((set) => ({
|
||||
data: [],
|
||||
addData: (id, json) =>
|
||||
set((state) => ({
|
||||
...state,
|
||||
data: [...state.data, { id, json }],
|
||||
})),
|
||||
}));
|
||||
|
||||
export function MachinesWSMain(props: {
|
||||
machines: Awaited<ReturnType<typeof getMachines>>;
|
||||
}) {
|
||||
return (
|
||||
<div className="flex flex-col gap-2 mt-6">
|
||||
Machine Status
|
||||
{props.machines.map((x) => (
|
||||
<MachineWS key={x.id} machine={x} />
|
||||
))}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function MachineWS({
|
||||
machine,
|
||||
}: {
|
||||
machine: Awaited<ReturnType<typeof getMachines>>[0];
|
||||
}) {
|
||||
const { addData } = useStore();
|
||||
const wsEndpoint = machine.endpoint.replace(/^http/, "ws");
|
||||
const { lastMessage, readyState } = useWebSocket(
|
||||
`${wsEndpoint}/comfy-deploy/ws`,
|
||||
{
|
||||
reconnectAttempts: 10,
|
||||
reconnectInterval: 1000,
|
||||
}
|
||||
);
|
||||
|
||||
const connectionStatus = {
|
||||
[ReadyState.CONNECTING]: "Connecting",
|
||||
[ReadyState.OPEN]: "Open",
|
||||
[ReadyState.CLOSING]: "Closing",
|
||||
[ReadyState.CLOSED]: "Closed",
|
||||
[ReadyState.UNINSTANTIATED]: "Uninstantiated",
|
||||
}[readyState];
|
||||
|
||||
useEffect(() => {
|
||||
if (!lastMessage?.data) return;
|
||||
|
||||
const message = JSON.parse(lastMessage.data);
|
||||
console.log(message.event, message);
|
||||
|
||||
if (message.data?.prompt_id) {
|
||||
addData(message.data.prompt_id, message);
|
||||
}
|
||||
}, [lastMessage]);
|
||||
|
||||
return (
|
||||
<div className="text-sm">
|
||||
{machine.name} - {connectionStatus}
|
||||
</div>
|
||||
);
|
||||
}
|
@ -1,6 +1,6 @@
|
||||
"use client";
|
||||
|
||||
import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs";
|
||||
import { Tabs, TabsList, TabsTrigger } from "@/components/ui/tabs";
|
||||
import { usePathname } from "next/navigation";
|
||||
import { useRouter } from "next/navigation";
|
||||
|
||||
|
27
web/src/components/RunDisplay.tsx
Normal file
27
web/src/components/RunDisplay.tsx
Normal file
@ -0,0 +1,27 @@
|
||||
"use client";
|
||||
|
||||
import type { findAllRuns } from "../app/[workflow_id]/page";
|
||||
import { StatusBadge } from "../app/[workflow_id]/page";
|
||||
import { useStore } from "@/components/MachinesWS";
|
||||
import { TableCell, TableRow } from "@/components/ui/table";
|
||||
import { getRelativeTime } from "@/lib/getRelativeTime";
|
||||
|
||||
export function RunDisplay({
|
||||
run,
|
||||
}: {
|
||||
run: Awaited<ReturnType<typeof findAllRuns>>[0];
|
||||
}) {
|
||||
const data = useStore((state) => state.data.find((x) => x.id === run.id));
|
||||
|
||||
return (
|
||||
<TableRow>
|
||||
<TableCell>{run.version.version}</TableCell>
|
||||
<TableCell className="font-medium">{run.machine.name}</TableCell>
|
||||
<TableCell>{getRelativeTime(run.created_at)}</TableCell>
|
||||
<TableCell>{data ? data.json.event : "-"}</TableCell>
|
||||
<TableCell className="text-right">
|
||||
<StatusBadge run={run} />
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
);
|
||||
}
|
@ -12,6 +12,7 @@ import {
|
||||
SelectTrigger,
|
||||
SelectValue,
|
||||
} from "@/components/ui/select";
|
||||
import { createRun } from "@/server/createRun";
|
||||
import type { getMachines } from "@/server/curdMachine";
|
||||
import { Play } from "lucide-react";
|
||||
import { parseAsInteger, useQueryState } from "next-usequerystate";
|
||||
@ -101,17 +102,15 @@ export function RunWorkflowButton({
|
||||
className="gap-2"
|
||||
disabled={isLoading}
|
||||
onClick={async () => {
|
||||
const workflow_version_id = workflow?.versions.find(
|
||||
(x) => x.version === version
|
||||
)?.id;
|
||||
if (!workflow_version_id) return;
|
||||
|
||||
setIsLoading(true);
|
||||
try {
|
||||
await fetch(`/api/create-run`, {
|
||||
method: "POST",
|
||||
body: JSON.stringify({
|
||||
workflow_version_id: workflow?.versions.find(
|
||||
(x) => x.version === version
|
||||
)?.id,
|
||||
machine_id: machine,
|
||||
}),
|
||||
});
|
||||
const origin = window.location.origin;
|
||||
await createRun(origin, workflow_version_id, machine);
|
||||
setIsLoading(false);
|
||||
} catch (error) {
|
||||
setIsLoading(false);
|
||||
|
95
web/src/server/createRun.ts
Normal file
95
web/src/server/createRun.ts
Normal file
@ -0,0 +1,95 @@
|
||||
"use server";
|
||||
|
||||
import { ComfyAPI_Run } from "../app/api/create-run/route";
|
||||
import { db } from "@/db/db";
|
||||
import { workflowRunsTable } from "@/db/schema";
|
||||
import { eq } from "drizzle-orm";
|
||||
import { revalidatePath } from "next/cache";
|
||||
import { NextResponse } from "next/server";
|
||||
import "server-only";
|
||||
|
||||
export async function createRun(
|
||||
origin: string,
|
||||
workflow_version_id: string,
|
||||
machine_id: string
|
||||
) {
|
||||
const machine = await db.query.machinesTable.findFirst({
|
||||
where: eq(workflowRunsTable.id, machine_id),
|
||||
});
|
||||
|
||||
if (!machine) {
|
||||
return new Response("Machine not found", {
|
||||
status: 404,
|
||||
});
|
||||
}
|
||||
|
||||
const workflow_version_data =
|
||||
// workflow_version_id
|
||||
// ?
|
||||
await db.query.workflowVersionTable.findFirst({
|
||||
where: eq(workflowRunsTable.id, workflow_version_id),
|
||||
});
|
||||
// : workflow_version != undefined
|
||||
// ? await db.query.workflowVersionTable.findFirst({
|
||||
// where: and(
|
||||
// eq(workflowVersionTable.version, workflow_version),
|
||||
// eq(workflowVersionTable.workflow_id)
|
||||
// ),
|
||||
// })
|
||||
// : null;
|
||||
if (!workflow_version_data) {
|
||||
return new Response("Workflow version not found", {
|
||||
status: 404,
|
||||
});
|
||||
}
|
||||
|
||||
const comfyui_endpoint = `${machine.endpoint}/comfy-deploy/run`;
|
||||
|
||||
// Sending to comfyui
|
||||
const result = await fetch(comfyui_endpoint, {
|
||||
method: "POST",
|
||||
// headers: {
|
||||
// "Content-Type": "application/json",
|
||||
// },
|
||||
body: JSON.stringify({
|
||||
workflow_api: workflow_version_data.workflow_api,
|
||||
status_endpoint: `${origin}/api/update-run`,
|
||||
}),
|
||||
})
|
||||
.then(async (res) => ComfyAPI_Run.parseAsync(await res.json()))
|
||||
.catch((error) => {
|
||||
console.error(error);
|
||||
return new Response(error.details, {
|
||||
status: 500,
|
||||
});
|
||||
});
|
||||
|
||||
console.log(result);
|
||||
|
||||
// return the error
|
||||
if (result instanceof Response) {
|
||||
return result;
|
||||
}
|
||||
|
||||
// Add to our db
|
||||
const workflow_run = await db
|
||||
.insert(workflowRunsTable)
|
||||
.values({
|
||||
id: result.prompt_id,
|
||||
workflow_id: workflow_version_data.workflow_id,
|
||||
workflow_version_id: workflow_version_data.id,
|
||||
machine_id,
|
||||
})
|
||||
.returning();
|
||||
|
||||
revalidatePath(`/${workflow_version_data.workflow_id}`);
|
||||
|
||||
return NextResponse.json(
|
||||
{
|
||||
workflow_run_id: workflow_run[0].id,
|
||||
},
|
||||
{
|
||||
status: 200,
|
||||
}
|
||||
);
|
||||
}
|
@ -1,7 +1,7 @@
|
||||
"use server";
|
||||
|
||||
import { db } from "@/db/db";
|
||||
import { machinesTable, workflowTable } from "@/db/schema";
|
||||
import { machinesTable } from "@/db/schema";
|
||||
import { auth } from "@clerk/nextjs";
|
||||
import { eq } from "drizzle-orm";
|
||||
import { revalidatePath } from "next/cache";
|
||||
|
Loading…
x
Reference in New Issue
Block a user