feat: add new ws realtime event

This commit is contained in:
BennyKok 2023-12-11 10:49:37 +08:00
parent ed853fc5f1
commit f9ed8145d2
11 changed files with 272 additions and 108 deletions

View File

@ -20,6 +20,9 @@ import atexit
import logging
from enum import Enum
import aiohttp
from aiohttp import web
api = None
api_task = None
prompt_metadata = {}
@ -101,6 +104,41 @@ async def comfy_deploy_run(request):
return web.json_response(res, status=status)
sockets = dict()
@server.PromptServer.instance.routes.get('/comfy-deploy/ws')
async def websocket_handler(request):
ws = web.WebSocketResponse()
await ws.prepare(request)
sid = request.rel_url.query.get('clientId', '')
if sid:
# Reusing existing session, remove old
sockets.pop(sid, None)
else:
sid = uuid.uuid4().hex
sockets[sid] = ws
try:
# Send initial state to the new client
await send("status", { 'sid': sid }, sid)
async for msg in ws:
if msg.type == aiohttp.WSMsgType.ERROR:
print('ws connection closed with exception %s' % ws.exception())
finally:
sockets.pop(sid, None)
return ws
async def send(event, data, sid=None):
if sid:
ws = sockets.get(sid)
if ws:
await ws.send_json({ 'event': event, 'data': data })
else:
for ws in sockets.values():
await ws.send_json({ 'event': event, 'data': data })
logging.basicConfig(level=logging.INFO)
prompt_server = server.PromptServer.instance
@ -111,6 +149,9 @@ async def send_json_override(self, event, data, sid=None):
prompt_id = data.get('prompt_id')
# now we send everything
await send(event, data)
if event == 'execution_start':
update_run(prompt_id, Status.RUNNING)

Binary file not shown.

View File

@ -35,9 +35,11 @@
"react": "^18",
"react-dom": "^18",
"react-hook-form": "^7.48.2",
"react-use-websocket": "^4.5.0",
"tailwind-merge": "^2.1.0",
"tailwindcss-animate": "^1.0.7",
"zod": "^3.22.4"
"zod": "^3.22.4",
"zustand": "^4.4.7"
},
"devDependencies": {
"eslint-config-next": "^14.0.4",

View File

@ -1,4 +1,6 @@
import { RunDisplay } from "../../components/RunDisplay";
import { LoadingIcon } from "@/components/LoadingIcon";
import { MachinesWSMain } from "@/components/MachinesWS";
import {
MachineSelect,
RunWorkflowButton,
@ -16,7 +18,6 @@ import {
Table,
TableBody,
TableCaption,
TableCell,
TableHead,
TableHeader,
TableRow,
@ -83,6 +84,8 @@ export default async function Page({
<MachineSelect machines={machines} />
<RunWorkflowButton workflow={workflow} machines={machines} />
</div>
<MachinesWSMain machines={machines} />
</CardContent>
</Card>
@ -109,26 +112,20 @@ async function RunsTable(props: { workflow_id: string }) {
<TableHead className="w-[100px]">Version</TableHead>
<TableHead>Machine</TableHead>
<TableHead>Time</TableHead>
<TableHead>Live Status</TableHead>
<TableHead className="text-right">Status</TableHead>
</TableRow>
</TableHeader>
<TableBody>
{allRuns.map((run) => (
<TableRow key={run.id}>
<TableCell>{run.version.version}</TableCell>
<TableCell className="font-medium">{run.machine.name}</TableCell>
<TableCell>{getRelativeTime(run.created_at)}</TableCell>
<TableCell className="text-right">
<StatusBadge run={run} />
</TableCell>
</TableRow>
<RunDisplay run={run} key={run.id} />
))}
</TableBody>
</Table>
);
}
function StatusBadge({
export function StatusBadge({
run,
}: {
run: Awaited<ReturnType<typeof findAllRuns>>[0];

View File

@ -1,9 +1,5 @@
import { parseDataSafe } from "../../../lib/parseDataSafe";
import { db } from "@/db/db";
import { workflowRunsTable } from "@/db/schema";
import { eq } from "drizzle-orm";
import { revalidatePath } from "next/cache";
import { NextResponse } from "next/server";
import { createRun } from "../../../server/createRun";
import { z } from "zod";
const Request = z.object({
@ -12,7 +8,7 @@ const Request = z.object({
machine_id: z.string(),
});
const ComfyAPI_Run = z.object({
export const ComfyAPI_Run = z.object({
prompt_id: z.string(),
number: z.number(),
node_errors: z.any(),
@ -26,84 +22,5 @@ export async function POST(request: Request) {
const { workflow_version_id, machine_id } = data;
const machine = await db.query.machinesTable.findFirst({
where: eq(workflowRunsTable.id, machine_id),
});
if (!machine) {
return new Response("Machine not found", {
status: 404,
});
}
const workflow_version_data =
// workflow_version_id
// ?
await db.query.workflowVersionTable.findFirst({
where: eq(workflowRunsTable.id, workflow_version_id),
});
// : workflow_version != undefined
// ? await db.query.workflowVersionTable.findFirst({
// where: and(
// eq(workflowVersionTable.version, workflow_version),
// eq(workflowVersionTable.workflow_id)
// ),
// })
// : null;
if (!workflow_version_data) {
return new Response("Workflow version not found", {
status: 404,
});
}
const comfyui_endpoint = `${machine.endpoint}/comfy-deploy/run`;
// Sending to comfyui
const result = await fetch(comfyui_endpoint, {
method: "POST",
// headers: {
// "Content-Type": "application/json",
// },
body: JSON.stringify({
workflow_api: workflow_version_data.workflow_api,
status_endpoint: `${origin}/api/update-run`,
}),
})
.then(async (res) => ComfyAPI_Run.parseAsync(await res.json()))
.catch((error) => {
console.error(error);
return new Response(error.details, {
status: 500,
});
});
console.log(result);
// return the error
if (result instanceof Response) {
return result;
}
// Add to our db
const workflow_run = await db
.insert(workflowRunsTable)
.values({
id: result.prompt_id,
workflow_id: workflow_version_data.workflow_id,
workflow_version_id: workflow_version_data.id,
machine_id,
})
.returning();
revalidatePath(`./${workflow_version_data.workflow_id}`);
return NextResponse.json(
{
workflow_run_id: workflow_run[0].id,
},
{
status: 200,
}
);
return await createRun(origin, workflow_version_id, machine_id);
}

View File

@ -0,0 +1,86 @@
"use client";
import type { getMachines } from "@/server/curdMachine";
import React, { useEffect } from "react";
import useWebSocket, { ReadyState } from "react-use-websocket";
import { create } from "zustand";
type State = {
data: {
id: string;
json: {
event: string;
data: any;
};
}[];
addData: (
id: string,
json: {
event: string;
data: any;
}
) => void;
};
export const useStore = create<State>((set) => ({
data: [],
addData: (id, json) =>
set((state) => ({
...state,
data: [...state.data, { id, json }],
})),
}));
export function MachinesWSMain(props: {
machines: Awaited<ReturnType<typeof getMachines>>;
}) {
return (
<div className="flex flex-col gap-2 mt-6">
Machine Status
{props.machines.map((x) => (
<MachineWS key={x.id} machine={x} />
))}
</div>
);
}
function MachineWS({
machine,
}: {
machine: Awaited<ReturnType<typeof getMachines>>[0];
}) {
const { addData } = useStore();
const wsEndpoint = machine.endpoint.replace(/^http/, "ws");
const { lastMessage, readyState } = useWebSocket(
`${wsEndpoint}/comfy-deploy/ws`,
{
reconnectAttempts: 10,
reconnectInterval: 1000,
}
);
const connectionStatus = {
[ReadyState.CONNECTING]: "Connecting",
[ReadyState.OPEN]: "Open",
[ReadyState.CLOSING]: "Closing",
[ReadyState.CLOSED]: "Closed",
[ReadyState.UNINSTANTIATED]: "Uninstantiated",
}[readyState];
useEffect(() => {
if (!lastMessage?.data) return;
const message = JSON.parse(lastMessage.data);
console.log(message.event, message);
if (message.data?.prompt_id) {
addData(message.data.prompt_id, message);
}
}, [lastMessage]);
return (
<div className="text-sm">
{machine.name} - {connectionStatus}
</div>
);
}

View File

@ -1,6 +1,6 @@
"use client";
import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs";
import { Tabs, TabsList, TabsTrigger } from "@/components/ui/tabs";
import { usePathname } from "next/navigation";
import { useRouter } from "next/navigation";

View File

@ -0,0 +1,27 @@
"use client";
import type { findAllRuns } from "../app/[workflow_id]/page";
import { StatusBadge } from "../app/[workflow_id]/page";
import { useStore } from "@/components/MachinesWS";
import { TableCell, TableRow } from "@/components/ui/table";
import { getRelativeTime } from "@/lib/getRelativeTime";
export function RunDisplay({
run,
}: {
run: Awaited<ReturnType<typeof findAllRuns>>[0];
}) {
const data = useStore((state) => state.data.find((x) => x.id === run.id));
return (
<TableRow>
<TableCell>{run.version.version}</TableCell>
<TableCell className="font-medium">{run.machine.name}</TableCell>
<TableCell>{getRelativeTime(run.created_at)}</TableCell>
<TableCell>{data ? data.json.event : "-"}</TableCell>
<TableCell className="text-right">
<StatusBadge run={run} />
</TableCell>
</TableRow>
);
}

View File

@ -12,6 +12,7 @@ import {
SelectTrigger,
SelectValue,
} from "@/components/ui/select";
import { createRun } from "@/server/createRun";
import type { getMachines } from "@/server/curdMachine";
import { Play } from "lucide-react";
import { parseAsInteger, useQueryState } from "next-usequerystate";
@ -101,17 +102,15 @@ export function RunWorkflowButton({
className="gap-2"
disabled={isLoading}
onClick={async () => {
const workflow_version_id = workflow?.versions.find(
(x) => x.version === version
)?.id;
if (!workflow_version_id) return;
setIsLoading(true);
try {
await fetch(`/api/create-run`, {
method: "POST",
body: JSON.stringify({
workflow_version_id: workflow?.versions.find(
(x) => x.version === version
)?.id,
machine_id: machine,
}),
});
const origin = window.location.origin;
await createRun(origin, workflow_version_id, machine);
setIsLoading(false);
} catch (error) {
setIsLoading(false);

View File

@ -0,0 +1,95 @@
"use server";
import { ComfyAPI_Run } from "../app/api/create-run/route";
import { db } from "@/db/db";
import { workflowRunsTable } from "@/db/schema";
import { eq } from "drizzle-orm";
import { revalidatePath } from "next/cache";
import { NextResponse } from "next/server";
import "server-only";
export async function createRun(
origin: string,
workflow_version_id: string,
machine_id: string
) {
const machine = await db.query.machinesTable.findFirst({
where: eq(workflowRunsTable.id, machine_id),
});
if (!machine) {
return new Response("Machine not found", {
status: 404,
});
}
const workflow_version_data =
// workflow_version_id
// ?
await db.query.workflowVersionTable.findFirst({
where: eq(workflowRunsTable.id, workflow_version_id),
});
// : workflow_version != undefined
// ? await db.query.workflowVersionTable.findFirst({
// where: and(
// eq(workflowVersionTable.version, workflow_version),
// eq(workflowVersionTable.workflow_id)
// ),
// })
// : null;
if (!workflow_version_data) {
return new Response("Workflow version not found", {
status: 404,
});
}
const comfyui_endpoint = `${machine.endpoint}/comfy-deploy/run`;
// Sending to comfyui
const result = await fetch(comfyui_endpoint, {
method: "POST",
// headers: {
// "Content-Type": "application/json",
// },
body: JSON.stringify({
workflow_api: workflow_version_data.workflow_api,
status_endpoint: `${origin}/api/update-run`,
}),
})
.then(async (res) => ComfyAPI_Run.parseAsync(await res.json()))
.catch((error) => {
console.error(error);
return new Response(error.details, {
status: 500,
});
});
console.log(result);
// return the error
if (result instanceof Response) {
return result;
}
// Add to our db
const workflow_run = await db
.insert(workflowRunsTable)
.values({
id: result.prompt_id,
workflow_id: workflow_version_data.workflow_id,
workflow_version_id: workflow_version_data.id,
machine_id,
})
.returning();
revalidatePath(`/${workflow_version_data.workflow_id}`);
return NextResponse.json(
{
workflow_run_id: workflow_run[0].id,
},
{
status: 200,
}
);
}

View File

@ -1,7 +1,7 @@
"use server";
import { db } from "@/db/db";
import { machinesTable, workflowTable } from "@/db/schema";
import { machinesTable } from "@/db/schema";
import { auth } from "@clerk/nextjs";
import { eq } from "drizzle-orm";
import { revalidatePath } from "next/cache";