Squashed commit of the following:

commit c36b0ec0b374dd8ccbee3a6044ee7e3f1fefe368
Author: Nicholas Koben Kao <kobenkao@gmail.com>
Date:   Thu Jan 25 17:54:54 2024 -0800

    nits on wording and removing link to broken storage/:id page

commit 0777fdcf7b0002244bc713199d3d64eea6b6061e
Author: Nicholas Koben Kao <kobenkao@gmail.com>
Date:   Thu Jan 25 17:23:55 2024 -0800

    builder update config and such

commit 958b795bb2b6ac27ce33c5729ef265b068420e1a
Author: Nicholas Koben Kao <kobenkao@gmail.com>
Date:   Thu Jan 25 17:23:43 2024 -0800

    rename all from checkponit to model

commit 7a9c5636e73bd005499b141a4dd382db5672c962
Author: Nicholas Koben Kao <kobenkao@gmail.com>
Date:   Thu Jan 25 16:51:59 2024 -0800

    rename for consistency

commit 48bebbafab9a95388817df97c15f8ea97e0fea75
Author: Nicholas Koben Kao <kobenkao@gmail.com>
Date:   Thu Jan 25 16:18:36 2024 -0800

    bulider

commit 81dacd9af457886f2f027994d225a7748c738abb
Author: Nicholas Koben Kao <kobenkao@gmail.com>
Date:   Thu Jan 25 16:17:56 2024 -0800

    different types of models
This commit is contained in:
bennykok 2024-01-26 10:08:37 +08:00
parent 62a69dba06
commit 85477aba9d
22 changed files with 2919 additions and 235 deletions

View File

@ -3,4 +3,5 @@ MODAL_TOKEN_SECRET=
CIVITAI_API_KEY= CIVITAI_API_KEY=
# On production set to False # On production set to False
DEPLOY_TEST_FLAG=True DEPLOY_TEST_FLAG=True
CIVITAI_API_KEY=

View File

@ -177,7 +177,7 @@ class Item(BaseModel):
snapshot: Snapshot snapshot: Snapshot
models: List[Model] models: List[Model]
callback_url: str callback_url: str
checkpoint_volume_name: str model_volume_name: str
gpu: GPUType = Field(default=GPUType.T4) gpu: GPUType = Field(default=GPUType.T4)
@field_validator('gpu') @field_validator('gpu')
@ -227,24 +227,31 @@ async def websocket_endpoint(websocket: WebSocket, machine_id: str):
# return {"Hello": "World"} # return {"Hello": "World"}
# definition based on web schema
class UploadType(str, Enum): class UploadType(str, Enum):
checkpoint = "checkpoint" checkpoint = "checkpoint"
lora = "lora"
embedding = "embedding"
class UploadBody(BaseModel): class UploadBody(BaseModel):
download_url: str download_url: str
volume_name: str volume_name: str
volume_id: str volume_id: str
checkpoint_id: str model_id: str
upload_type: UploadType upload_type: UploadType
callback_url: str callback_url: str
# based on ComfyUI's model dir, and our mappings in ./src/template/data/extra_model_paths.yaml
UPLOAD_TYPE_DIR_MAP = { UPLOAD_TYPE_DIR_MAP = {
UploadType.checkpoint: "checkpoints" UploadType.checkpoint: "checkpoints",
UploadType.lora: "loras",
UploadType.embedding: "embeddings",
} }
@app.post("/upload-volume") @app.post("/upload-volume")
async def upload_checkpoint(body: UploadBody): async def upload_model(body: UploadBody):
global last_activity_time global last_activity_time
last_activity_time = time.time() last_activity_time = time.time()
logger.info(f"Extended inactivity time to {global_timeout}") logger.info(f"Extended inactivity time to {global_timeout}")
@ -254,6 +261,7 @@ async def upload_checkpoint(body: UploadBody):
# check that this # check that this
return JSONResponse(status_code=200, content={"message": "Volume uploading", "build_machine_instance_id": fly_instance_id}) return JSONResponse(status_code=200, content={"message": "Volume uploading", "build_machine_instance_id": fly_instance_id})
async def upload_logic(body: UploadBody): async def upload_logic(body: UploadBody):
folder_path = f"/app/builds/{body.volume_id}" folder_path = f"/app/builds/{body.volume_id}"
@ -270,7 +278,7 @@ async def upload_logic(body: UploadBody):
}, },
"callback_url": body.callback_url, "callback_url": body.callback_url,
"callback_body": { "callback_body": {
"checkpoint_id": body.checkpoint_id, "model_id": body.model_id,
"volume_id": body.volume_id, "volume_id": body.volume_id,
"folder_path": upload_path, "folder_path": upload_path,
}, },
@ -279,51 +287,11 @@ async def upload_logic(body: UploadBody):
with open(f"{folder_path}/config.py", "w") as f: with open(f"{folder_path}/config.py", "w") as f:
f.write("config = " + json.dumps(config)) f.write("config = " + json.dumps(config))
process = await asyncio.subprocess.create_subprocess_shell( await asyncio.subprocess.create_subprocess_shell(
f"modal run app.py", f"modal run app.py",
# stdout=asyncio.subprocess.PIPE,
# stderr=asyncio.subprocess.PIPE,
cwd=folder_path, cwd=folder_path,
env={**os.environ, "COLUMNS": "10000"} env={**os.environ, "COLUMNS": "10000"}
) )
# error_logs = []
# async def read_stream(stream):
# while True:
# line = await stream.readline()
# if line:
# l = line.decode('utf-8').strip()
# error_logs.append(l)
# logger.error(l)
# error_logs.append({
# "logs": l,
# "timestamp": time.time()
# })
# else:
# break
# stderr_read_task = asyncio.create_task(read_stream(process.stderr))
#
# await asyncio.wait([stderr_read_task])
# await process.wait()
# if process.returncode != 0:
# error_logs.append({"logs": "Unable to upload volume.", "timestamp": time.time()})
# # Error handling: send POST request to callback URL with error details
# requests.post(body.callback_url, json={
# "volume_id": body.volume_id,
# "checkpoint_id": body.checkpoint_id,
# "folder_path": upload_path,
# "error_logs": json.dumps(error_logs),
# "status": "failed"
# })
#
# requests.post(body.callback_url, json={
# "checkpoint_id": body.checkpoint_id,
# "volume_id": body.volume_id,
# "folder_path": upload_path,
# "status": "success"
# })
@app.post("/create") @app.post("/create")
async def create_machine(item: Item): async def create_machine(item: Item):
@ -414,8 +382,8 @@ async def build_logic(item: Item):
"name": item.name, "name": item.name,
"deploy_test": os.environ.get("DEPLOY_TEST_FLAG", "False"), "deploy_test": os.environ.get("DEPLOY_TEST_FLAG", "False"),
"gpu": item.gpu, "gpu": item.gpu,
"public_checkpoint_volume": "model-store", "public_model_volume": "model-store",
"private_checkpoint_volume": item.checkpoint_volume_name "private_model_volume": item.model_volume_name
} }
with open(f"{folder_path}/config.py", "w") as f: with open(f"{folder_path}/config.py", "w") as f:
f.write("config = " + json.dumps(config)) f.write("config = " + json.dumps(config))

View File

@ -2,6 +2,6 @@ config = {
"name": "my-app", "name": "my-app",
"deploy_test": "True", "deploy_test": "True",
"gpu": "T4", "gpu": "T4",
"public_checkpoint_volume": "model-store", "public_model_volume": "model-store",
"private_checkpoint_volume": "private-model-store" "private_model_volume": "private-model-store"
} }

View File

@ -13,3 +13,11 @@ public:
private: private:
base_path: /private_models/ base_path: /private_models/
checkpoints: checkpoints checkpoints: checkpoints
clip: clip
clip_vision: clip_vision
configs: configs
controlnet: controlnet
embeddings: embeddings
loras: loras
upscale_models: upscale_models
vae: vae

View File

@ -1,8 +1,8 @@
import modal import modal
from config import config from config import config
public_model_volume = modal.Volume.persisted(config["public_checkpoint_volume"]) public_model_volume = modal.Volume.persisted(config["public_model_volume"])
private_volume = modal.Volume.persisted(config["private_checkpoint_volume"]) private_volume = modal.Volume.persisted(config["private_model_volume"])
PUBLIC_BASEMODEL_DIR = "/public_models" PUBLIC_BASEMODEL_DIR = "/public_models"
PRIVATE_BASEMODEL_DIR = "/private_models" PRIVATE_BASEMODEL_DIR = "/private_models"

View File

@ -1,18 +1,18 @@
config = { config = {
"volume_names": { "volume_names": {
"test": { "user4": {
"download_url": "https://pub-6230db03dc3a4861a9c3e55145ceda44.r2.dev/openpose-pose (1).png", "download_url": "https://civitai.com/api/download/models/11745",
"folder_path": "images" "folder_path": "checkpoints"
} }
}, },
"volume_paths": { "volume_paths": {
"test": "/volumes/something" "user4": "/volumes/something",
}, },
"callback_url": "", "callback_url": "",
"callback_body": { "callback_body": {
"checkpoint_id": "", "model_id": "",
"volume_id": "", "volume_id": "",
"folder_path": "images", "folder_path": "checkpoints",
}, },
"civitai_api_key": "", "civitai_api_key": "",
} }

View File

@ -0,0 +1,7 @@
DO $$ BEGIN
CREATE TYPE "model_type" AS ENUM('checkpoint', 'lora', 'embedding', 'vae');
EXCEPTION
WHEN duplicate_object THEN null;
END $$;
--> statement-breakpoint
ALTER TABLE "comfyui_deploy"."checkpoints" ADD COLUMN "model_type" "model_type" NOT NULL;

View File

@ -0,0 +1,26 @@
ALTER TABLE "comfyui_deploy"."checkpoints" RENAME TO "models";--> statement-breakpoint
ALTER TABLE "comfyui_deploy"."checkpoint_volume" RENAME TO "user_volume";--> statement-breakpoint
ALTER TABLE "comfyui_deploy"."models" RENAME COLUMN "checkpoint_volume_id" TO "user_volume_id";--> statement-breakpoint
ALTER TABLE "comfyui_deploy"."models" DROP CONSTRAINT "checkpoints_user_id_users_id_fk";
--> statement-breakpoint
ALTER TABLE "comfyui_deploy"."models" DROP CONSTRAINT "checkpoints_checkpoint_volume_id_checkpoint_volume_id_fk";
--> statement-breakpoint
ALTER TABLE "comfyui_deploy"."user_volume" DROP CONSTRAINT "checkpoint_volume_user_id_users_id_fk";
--> statement-breakpoint
DO $$ BEGIN
ALTER TABLE "comfyui_deploy"."models" ADD CONSTRAINT "models_user_id_users_id_fk" FOREIGN KEY ("user_id") REFERENCES "comfyui_deploy"."users"("id") ON DELETE no action ON UPDATE no action;
EXCEPTION
WHEN duplicate_object THEN null;
END $$;
--> statement-breakpoint
DO $$ BEGIN
ALTER TABLE "comfyui_deploy"."models" ADD CONSTRAINT "models_user_volume_id_user_volume_id_fk" FOREIGN KEY ("user_volume_id") REFERENCES "comfyui_deploy"."user_volume"("id") ON DELETE cascade ON UPDATE no action;
EXCEPTION
WHEN duplicate_object THEN null;
END $$;
--> statement-breakpoint
DO $$ BEGIN
ALTER TABLE "comfyui_deploy"."user_volume" ADD CONSTRAINT "user_volume_user_id_users_id_fk" FOREIGN KEY ("user_id") REFERENCES "comfyui_deploy"."users"("id") ON DELETE no action ON UPDATE no action;
EXCEPTION
WHEN duplicate_object THEN null;
END $$;

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -302,6 +302,20 @@
"when": 1706164614659, "when": 1706164614659,
"tag": "0042_windy_madelyne_pryor", "tag": "0042_windy_madelyne_pryor",
"breakpoints": true "breakpoints": true
},
{
"idx": 43,
"version": "5",
"when": 1706225960550,
"tag": "0043_dapper_santa_claus",
"breakpoints": true
},
{
"idx": 44,
"version": "5",
"when": 1706230304140,
"tag": "0044_married_malcolm_colcord",
"breakpoints": true
} }
] ]
} }

View File

@ -1,15 +1,15 @@
import { parseDataSafe } from "../../../../lib/parseDataSafe"; import { parseDataSafe } from "../../../../lib/parseDataSafe";
import { db } from "@/db/db"; import { db } from "@/db/db";
import { checkpointTable, machinesTable } from "@/db/schema"; import { modelTable } from "@/db/schema";
import { eq } from "drizzle-orm"; import { eq } from "drizzle-orm";
import { NextResponse } from "next/server"; import { NextResponse } from "next/server";
import { z } from "zod"; import { z } from "zod";
const Request = z.object({ const Request = z.object({
volume_id: z.string(), volume_id: z.string(),
checkpoint_id: z.string(), model_id: z.string(),
folder_path: z.string().optional(), folder_path: z.string().optional(),
status: z.enum(['success', 'failed']), status: z.enum(["success", "failed"]),
error_log: z.string().optional(), error_log: z.string().optional(),
timeout: z.number().optional(), timeout: z.number().optional(),
}); });
@ -18,30 +18,30 @@ export async function POST(request: Request) {
const [data, error] = await parseDataSafe(Request, request); const [data, error] = await parseDataSafe(Request, request);
if (!data || error) return error; if (!data || error) return error;
const { checkpoint_id, error_log, status, folder_path } = data; const { model_id, error_log, status, folder_path } = data;
console.log( checkpoint_id, error_log, status, folder_path ) console.log(model_id, error_log, status, folder_path);
if (status === "success") { if (status === "success") {
await db await db
.update(checkpointTable) .update(modelTable)
.set({ .set({
status: "success", status: "success",
folder_path, folder_path,
updated_at: new Date(), updated_at: new Date(),
// build_log: build_log, // build_log: build_log,
}) })
.where(eq(checkpointTable.id, checkpoint_id)); .where(eq(modelTable.id, model_id));
} else { } else {
await db await db
.update(checkpointTable) .update(modelTable)
.set({ .set({
status: "failed", status: "failed",
error_log, error_log,
updated_at: new Date(), updated_at: new Date(),
// status: "error", // status: "error",
// build_log: build_log, // build_log: build_log,
}) })
.where(eq(checkpointTable.id, checkpoint_id)); .where(eq(modelTable.id, model_id));
} }
return NextResponse.json( return NextResponse.json(
@ -50,6 +50,6 @@ export async function POST(request: Request) {
}, },
{ {
status: 200, status: 200,
} },
); );
} }

View File

@ -1,14 +1,14 @@
import { setInitialUserData } from "../../../lib/setInitialUserData"; import { setInitialUserData } from "../../../lib/setInitialUserData";
import { auth } from "@clerk/nextjs"; import { auth } from "@clerk/nextjs";
import { clerkClient } from "@clerk/nextjs/server"; import { clerkClient } from "@clerk/nextjs/server";
import { CheckpointList } from "@/components/CheckpointList"; import { ModelList } from "@/components/ModelList";
import { getAllUserCheckpoints } from "@/server/getAllUserCheckpoints"; import { getAllUserModels } from "@/server/getAllUserModel";
export default function Page() { export default function Page() {
return <CheckpointListServer />; return <ModelListServer />;
} }
async function CheckpointListServer() { async function ModelListServer() {
const { userId } = auth(); const { userId } = auth();
if (!userId) { if (!userId) {
@ -21,15 +21,15 @@ async function CheckpointListServer() {
await setInitialUserData(userId); await setInitialUserData(userId);
} }
const checkpoints = await getAllUserCheckpoints(); const models = await getAllUserModels();
if (!checkpoints) { if (!models) {
return <div>No checkpoints found</div>; return <div>No models found</div>;
} }
return ( return (
<div className="w-full"> <div className="w-full">
<CheckpointList data={checkpoints} /> <ModelList data={models} />
</div> </div>
); );
} }

View File

@ -4,7 +4,7 @@ import { getRelativeTime } from "../lib/getRelativeTime";
import { Badge } from "@/components/ui/badge"; import { Badge } from "@/components/ui/badge";
import { Button } from "@/components/ui/button"; import { Button } from "@/components/ui/button";
import { Checkbox } from "@/components/ui/checkbox"; import { Checkbox } from "@/components/ui/checkbox";
import { InsertModal, UpdateModal } from "./InsertModal"; import { InsertModal } from "./InsertModal";
import { Input } from "@/components/ui/input"; import { Input } from "@/components/ui/input";
import { ScrollArea } from "@/components/ui/scroll-area"; import { ScrollArea } from "@/components/ui/scroll-area";
import { import {
@ -15,7 +15,7 @@ import {
TableHeader, TableHeader,
TableRow, TableRow,
} from "@/components/ui/table"; } from "@/components/ui/table";
import type { getAllUserCheckpoints } from "@/server/getAllUserCheckpoints"; import type { getAllUserModels as getAllUserModels } from "@/server/getAllUserModel";
import type { import type {
ColumnDef, ColumnDef,
ColumnFiltersState, ColumnFiltersState,
@ -32,23 +32,22 @@ import {
} from "@tanstack/react-table"; } from "@tanstack/react-table";
import { ArrowUpDown } from "lucide-react"; import { ArrowUpDown } from "lucide-react";
import * as React from "react"; import * as React from "react";
import { addCivitaiCheckpoint } from "@/server/curdCheckpoint"; import { addCivitaiModel } from "@/server/curdModel";
import { addCivitaiCheckpointSchema } from "@/server/addCheckpointSchema"; import { addCivitaiModelSchema } from "@/server/addCivitaiModelSchema";
import { modelEnumType } from "@/db/schema";
export type CheckpointItemList = NonNullable< export type ModelItemList = NonNullable<
Awaited<ReturnType<typeof getAllUserCheckpoints>> Awaited<ReturnType<typeof getAllUserModels>>
>[0]; >[0];
export const columns: ColumnDef<CheckpointItemList>[] = [ export const columns: ColumnDef<ModelItemList>[] = [
{ {
accessorKey: "id", accessorKey: "id",
id: "select", id: "select",
header: ({ table }) => ( header: ({ table }) => (
<Checkbox <Checkbox
checked={ checked={table.getIsAllPageRowsSelected() ||
table.getIsAllPageRowsSelected() || (table.getIsSomePageRowsSelected() && "indeterminate")}
(table.getIsSomePageRowsSelected() && "indeterminate")
}
onCheckedChange={(value) => table.toggleAllPageRowsSelected(!!value)} onCheckedChange={(value) => table.toggleAllPageRowsSelected(!!value)}
aria-label="Select all" aria-label="Select all"
/> />
@ -77,22 +76,23 @@ export const columns: ColumnDef<CheckpointItemList>[] = [
); );
}, },
cell: ({ row }) => { cell: ({ row }) => {
const checkpoint = row.original; const model = row.original;
return ( return (
<a <>
{
/*<a
className="hover:underline flex gap-2" className="hover:underline flex gap-2"
href={`/storage/${checkpoint.id}`} // TODO href={`/storage/${model.id}`} // TODO
> >*/
}
<span className="truncate max-w-[200px]"> <span className="truncate max-w-[200px]">
{row.original.model_name} {row.original.model_name}
</span> </span>
{checkpoint.is_public ? ( {model.is_public
<Badge variant="green">Public</Badge> ? <Badge variant="green">Public</Badge>
) : ( : <Badge variant="orange">Private</Badge>}
<Badge variant="orange">Private</Badge> </>
)}
</a>
); );
}, },
}, },
@ -111,7 +111,11 @@ export const columns: ColumnDef<CheckpointItemList>[] = [
}, },
cell: ({ row }) => { cell: ({ row }) => {
return ( return (
<Badge variant={row.original.status === "failed" ? "red" : (row.original.status === "started" ? "yellow" : "green")}> <Badge
variant={row.original.status === "failed"
? "red"
: (row.original.status === "started" ? "yellow" : "green")}
>
{row.original.status} {row.original.status}
</Badge> </Badge>
); );
@ -167,6 +171,35 @@ export const columns: ColumnDef<CheckpointItemList>[] = [
return <Badge variant="cyan">{row.original.upload_type}</Badge>; return <Badge variant="cyan">{row.original.upload_type}</Badge>;
}, },
}, },
{
accessorKey: "model_type",
header: ({ column }) => {
return (
<button
className="flex items-center hover:underline"
onClick={() => column.toggleSorting(column.getIsSorted() === "asc")}
>
Model Type
<ArrowUpDown className="ml-2 h-4 w-4" />
</button>
);
},
cell: ({ row }) => {
const model_type_map: Record<modelEnumType, any> = {
"checkpoint": "amber",
"lora": "green",
"embedding": "violet",
"vae": "teal",
};
function getBadgeColor(modelType: modelEnumType) {
return model_type_map[modelType] || "default";
}
const color = getBadgeColor(row.original.model_type);
return <Badge variant={color}>{row.original.model_type}</Badge>;
},
},
{ {
accessorKey: "date", accessorKey: "date",
sortingFn: "datetime", sortingFn: "datetime",
@ -221,13 +254,14 @@ export const columns: ColumnDef<CheckpointItemList>[] = [
// }, // },
]; ];
export function CheckpointList({ data }: { data: CheckpointItemList[] }) { export function ModelList({ data }: { data: ModelItemList[] }) {
const [sorting, setSorting] = React.useState<SortingState>([]); const [sorting, setSorting] = React.useState<SortingState>([]);
const [columnFilters, setColumnFilters] = React.useState<ColumnFiltersState>( const [columnFilters, setColumnFilters] = React.useState<ColumnFiltersState>(
[] [],
); );
const [columnVisibility, setColumnVisibility] = const [columnVisibility, setColumnVisibility] = React.useState<
React.useState<VisibilityState>({}); VisibilityState
>({});
const [rowSelection, setRowSelection] = React.useState({}); const [rowSelection, setRowSelection] = React.useState({});
const table = useReactTable({ const table = useReactTable({
@ -254,10 +288,10 @@ export function CheckpointList({ data }: { data: CheckpointItemList[] }) {
<div className="flex flex-row w-full items-center py-4"> <div className="flex flex-row w-full items-center py-4">
<Input <Input
placeholder="Filter workflows..." placeholder="Filter workflows..."
value={(table.getColumn("name")?.getFilterValue() as string) ?? ""} value={(table.getColumn("model_name")?.getFilterValue() as string) ??
""}
onChange={(event) => onChange={(event) =>
table.getColumn("name")?.setFilterValue(event.target.value) table.getColumn("model_name")?.setFilterValue(event.target.value)}
}
className="max-w-sm" className="max-w-sm"
/> />
<div className="ml-auto flex gap-2"> <div className="ml-auto flex gap-2">
@ -268,17 +302,17 @@ export function CheckpointList({ data }: { data: CheckpointItemList[] }) {
// TODO: limitations based on plan // TODO: limitations based on plan
} }
tooltip={"Add models using their civitai url!"} tooltip={"Add models using their civitai url!"}
title="Civitai Checkpoint" title="Add a Civitai Model"
description="Pick a model from civitai" description="Pick a model from civitai"
serverAction={addCivitaiCheckpoint} serverAction={addCivitaiModel}
formSchema={addCivitaiCheckpointSchema} formSchema={addCivitaiModelSchema}
fieldConfig={{ fieldConfig={{
civitai_url: { civitai_url: {
fieldType: "fallback", fieldType: "fallback",
inputProps: { required: true }, inputProps: { required: true },
description: ( description: (
<> <>
Pick a checkpoint from{" "} Pick a model from{" "}
<a <a
href="https://www.civitai.com/models" href="https://www.civitai.com/models"
target="_blank" target="_blank"
@ -302,12 +336,10 @@ export function CheckpointList({ data }: { data: CheckpointItemList[] }) {
{headerGroup.headers.map((header) => { {headerGroup.headers.map((header) => {
return ( return (
<TableHead key={header.id}> <TableHead key={header.id}>
{header.isPlaceholder {header.isPlaceholder ? null : flexRender(
? null header.column.columnDef.header,
: flexRender( header.getContext(),
header.column.columnDef.header, )}
header.getContext()
)}
</TableHead> </TableHead>
); );
})} })}
@ -315,32 +347,34 @@ export function CheckpointList({ data }: { data: CheckpointItemList[] }) {
))} ))}
</TableHeader> </TableHeader>
<TableBody> <TableBody>
{table.getRowModel().rows?.length ? ( {table.getRowModel().rows?.length
table.getRowModel().rows.map((row) => ( ? (
<TableRow table.getRowModel().rows.map((row) => (
key={row.id} <TableRow
data-state={row.getIsSelected() && "selected"} key={row.id}
> data-state={row.getIsSelected() && "selected"}
{row.getVisibleCells().map((cell) => ( >
<TableCell key={cell.id}> {row.getVisibleCells().map((cell) => (
{flexRender( <TableCell key={cell.id}>
cell.column.columnDef.cell, {flexRender(
cell.getContext() cell.column.columnDef.cell,
)} cell.getContext(),
</TableCell> )}
))} </TableCell>
))}
</TableRow>
))
)
: (
<TableRow>
<TableCell
colSpan={columns.length}
className="h-24 text-center"
>
No results.
</TableCell>
</TableRow> </TableRow>
)) )}
) : (
<TableRow>
<TableCell
colSpan={columns.length}
className="h-24 text-center"
>
No results.
</TableCell>
</TableRow>
)}
</TableBody> </TableBody>
</Table> </Table>
</ScrollArea> </ScrollArea>

View File

@ -7,7 +7,7 @@ import AutoFormInput from "../ui/auto-form/fields/input";
import { useDebouncedCallback } from "use-debounce"; import { useDebouncedCallback } from "use-debounce";
import { CivitaiModelResponse } from "@/types/civitai"; import { CivitaiModelResponse } from "@/types/civitai";
import { z } from "zod"; import { z } from "zod";
import { insertCivitaiCheckpointSchema } from "@/db/schema"; import { insertCivitaiModelSchema } from "@/db/schema";
function getUrl(civitai_url: string) { function getUrl(civitai_url: string) {
// expect to be a URL to be https://civitai.com/models/36520 // expect to be a URL to be https://civitai.com/models/36520
@ -33,7 +33,7 @@ export default function AutoFormCheckpointInput(
const handleSearch = useDebouncedCallback((search) => { const handleSearch = useDebouncedCallback((search) => {
const validationResult = const validationResult =
insertCivitaiCheckpointSchema.shape.civitai_url.safeParse(search); insertCivitaiModelSchema.shape.civitai_url.safeParse(search);
if (!validationResult.success) { if (!validationResult.success) {
console.error(validationResult.error); console.error(validationResult.error);
// Optionally set an error state here // Optionally set an error state here

View File

@ -12,7 +12,7 @@ import {
real, real,
} from "drizzle-orm/pg-core"; } from "drizzle-orm/pg-core";
import { createInsertSchema, createSelectSchema } from "drizzle-zod"; import { createInsertSchema, createSelectSchema } from "drizzle-zod";
import { z } from "zod"; import { TypeOf, z } from "zod";
export const dbSchema = pgSchema("comfyui_deploy"); export const dbSchema = pgSchema("comfyui_deploy");
@ -376,15 +376,25 @@ export const modelUploadType = pgEnum("model_upload_type", [
"other", "other",
]); ]);
export const checkpointTable = dbSchema.table("checkpoints", { // https://www.answeroverflow.com/m/1125106227387584552
const modelTypes = [
"checkpoint",
"lora",
"embedding",
"vae",
] as const
export const modelType = pgEnum("model_type", modelTypes);
export type modelEnumType = typeof modelTypes[number]
export const modelTable = dbSchema.table("models", {
id: uuid("id").primaryKey().defaultRandom().notNull(), id: uuid("id").primaryKey().defaultRandom().notNull(),
user_id: text("user_id").references(() => usersTable.id, {}), // perhaps a "special" user_id for global checkpoints user_id: text("user_id").references(() => usersTable.id, {}), // perhaps a "special" user_id for global models
org_id: text("org_id"), org_id: text("org_id"),
description: text("description"), description: text("description"),
checkpoint_volume_id: uuid("checkpoint_volume_id") user_volume_id: uuid("user_volume_id")
.notNull() .notNull()
.references(() => checkpointVolumeTable.id, { .references(() => userVolume.id, {
onDelete: "cascade", onDelete: "cascade",
}) })
.notNull(), .notNull(),
@ -408,12 +418,13 @@ export const checkpointTable = dbSchema.table("checkpoints", {
status: resourceUpload("status").notNull().default("started"), status: resourceUpload("status").notNull().default("started"),
upload_machine_id: text("upload_machine_id"), // TODO: review if actually needed upload_machine_id: text("upload_machine_id"), // TODO: review if actually needed
upload_type: modelUploadType("upload_type").notNull(), upload_type: modelUploadType("upload_type").notNull(),
model_type: modelType("model_type").notNull(),
error_log: text("error_log"), error_log: text("error_log"),
created_at: timestamp("created_at").defaultNow().notNull(), created_at: timestamp("created_at").defaultNow().notNull(),
updated_at: timestamp("updated_at").defaultNow().notNull(), updated_at: timestamp("updated_at").defaultNow().notNull(),
}); });
export const checkpointVolumeTable = dbSchema.table("checkpoint_volume", { export const userVolume = dbSchema.table("user_volume", {
id: uuid("id").primaryKey().defaultRandom().notNull(), id: uuid("id").primaryKey().defaultRandom().notNull(),
user_id: text("user_id").references(() => usersTable.id, { user_id: text("user_id").references(() => usersTable.id, {
// onDelete: "cascade", // onDelete: "cascade",
@ -425,23 +436,23 @@ export const checkpointVolumeTable = dbSchema.table("checkpoint_volume", {
disabled: boolean("disabled").default(false).notNull(), disabled: boolean("disabled").default(false).notNull(),
}); });
export const checkpointRelations = relations(checkpointTable, ({ one }) => ({ export const modelRelations = relations(modelTable, ({ one }) => ({
user: one(usersTable, { user: one(usersTable, {
fields: [checkpointTable.user_id], fields: [modelTable.user_id],
references: [usersTable.id], references: [usersTable.id],
}), }),
volume: one(checkpointVolumeTable, { volume: one(userVolume, {
fields: [checkpointTable.checkpoint_volume_id], fields: [modelTable.user_volume_id],
references: [checkpointVolumeTable.id], references: [userVolume.id],
}), }),
})); }));
export const checkpointVolumeRelations = relations( export const modalVolumeRelations = relations(
checkpointVolumeTable, userVolume,
({ many, one }) => ({ ({ many, one }) => ({
checkpoint: many(checkpointTable), model: many(modelTable),
user: one(usersTable, { user: one(usersTable, {
fields: [checkpointVolumeTable.user_id], fields: [userVolume.user_id],
references: [usersTable.id], references: [usersTable.id],
}), }),
}) })
@ -473,8 +484,8 @@ export const subscriptionStatusTable = dbSchema.table("subscription_status", {
updated_at: timestamp("updated_at").defaultNow().notNull(), updated_at: timestamp("updated_at").defaultNow().notNull(),
}); });
export const insertCivitaiCheckpointSchema = createInsertSchema( export const insertCivitaiModelSchema = createInsertSchema(
checkpointTable, modelTable,
{ {
civitai_url: (schema) => civitai_url: (schema) =>
schema.civitai_url schema.civitai_url
@ -491,8 +502,8 @@ export type WorkflowType = InferSelectModel<typeof workflowTable>;
export type MachineType = InferSelectModel<typeof machinesTable>; export type MachineType = InferSelectModel<typeof machinesTable>;
export type WorkflowVersionType = InferSelectModel<typeof workflowVersionTable>; export type WorkflowVersionType = InferSelectModel<typeof workflowVersionTable>;
export type DeploymentType = InferSelectModel<typeof deploymentsTable>; export type DeploymentType = InferSelectModel<typeof deploymentsTable>;
export type CheckpointType = InferSelectModel<typeof checkpointTable>; export type ModelType = InferSelectModel<typeof modelTable>;
export type CheckpointVolumeType = InferSelectModel< export type UserVolumeType = InferSelectModel<
typeof checkpointVolumeTable typeof userVolume
>; >;
export type UserUsageType = InferSelectModel<typeof userUsageTable>; export type UserUsageType = InferSelectModel<typeof userUsageTable>;

View File

@ -1,5 +0,0 @@
import { insertCivitaiCheckpointSchema } from "@/db/schema";
export const addCivitaiCheckpointSchema = insertCivitaiCheckpointSchema.pick({
civitai_url: true,
});

View File

@ -0,0 +1,5 @@
import { insertCivitaiModelSchema } from "@/db/schema";
export const addCivitaiModelSchema = insertCivitaiModelSchema.pick({
civitai_url: true,
});

View File

@ -15,7 +15,7 @@ import { headers } from "next/headers";
import { redirect } from "next/navigation"; import { redirect } from "next/navigation";
import "server-only"; import "server-only";
import type { z } from "zod"; import type { z } from "zod";
import { retrieveCheckpointVolumes } from "./curdCheckpoint"; import { retrieveModelVolumes } from "./curdModel";
export async function getMachines() { export async function getMachines() {
const { userId, orgId } = auth(); const { userId, orgId } = auth();
@ -190,7 +190,7 @@ async function _buildMachine(
throw new Error("No domain"); throw new Error("No domain");
} }
const volumes = await retrieveCheckpointVolumes(); const volumes = await retrieveModelVolumes();
// Call remote builder // Call remote builder
const result = await fetch(`${process.env.MODAL_BUILDER_URL!}/create`, { const result = await fetch(`${process.env.MODAL_BUILDER_URL!}/create`, {
method: "POST", method: "POST",
@ -204,7 +204,7 @@ async function _buildMachine(
callback_url: `${protocol}://${domain}/api/machine-built`, callback_url: `${protocol}://${domain}/api/machine-built`,
models: data.models, //JSON.parse(data.models as string), models: data.models, //JSON.parse(data.models as string),
gpu: data.gpu && data.gpu.length > 0 ? data.gpu : "T4", gpu: data.gpu && data.gpu.length > 0 ? data.gpu : "T4",
checkpoint_volume_name: volumes[0].volume_name, model_volume_name: volumes[0].volume_name,
}), }),
}); });

View File

@ -2,92 +2,91 @@
import { auth } from "@clerk/nextjs"; import { auth } from "@clerk/nextjs";
import { import {
checkpointTable, modelTable,
CheckpointType, ModelType,
checkpointVolumeTable, userVolume,
CheckpointVolumeType, UserVolumeType,
} from "@/db/schema"; } from "@/db/schema";
import { withServerPromise } from "./withServerPromise"; import { withServerPromise } from "./withServerPromise";
import { db } from "@/db/db"; import { db } from "@/db/db";
import type { z } from "zod"; import type { z } from "zod";
import { headers } from "next/headers"; import { headers } from "next/headers";
import { addCivitaiCheckpointSchema } from "./addCheckpointSchema"; import { addCivitaiModelSchema } from "./addCivitaiModelSchema";
import { and, eq, isNull } from "drizzle-orm"; import { and, eq, isNull } from "drizzle-orm";
import { CivitaiModelResponse } from "@/types/civitai"; import { CivitaiModelResponse, getModelTypeDetails } from "@/types/civitai";
export async function getCheckpoints() { export async function getModel() {
const { userId, orgId } = auth(); const { userId, orgId } = auth();
if (!userId) throw new Error("No user id"); if (!userId) throw new Error("No user id");
const checkpoints = await db const models = await db
.select() .select()
.from(checkpointTable) .from(modelTable)
.where( .where(
orgId orgId
? eq(checkpointTable.org_id, orgId) ? eq(modelTable.org_id, orgId)
// make sure org_id is null // make sure org_id is null
: and( : and(
eq(checkpointTable.user_id, userId), eq(modelTable.user_id, userId),
isNull(checkpointTable.org_id), isNull(modelTable.org_id),
), ),
); );
return checkpoints; return models;
} }
export async function getCheckpointById(id: string) { export async function getModelById(id: string) {
const { userId, orgId } = auth(); const { userId, orgId } = auth();
if (!userId) throw new Error("No user id"); if (!userId) throw new Error("No user id");
const checkpoint = await db const model = await db
.select() .select()
.from(checkpointTable) .from(modelTable)
.where( .where(
and( and(
orgId ? eq(checkpointTable.org_id, orgId) : and( orgId ? eq(modelTable.org_id, orgId) : and(
eq(checkpointTable.user_id, userId), eq(modelTable.user_id, userId),
isNull(checkpointTable.org_id), isNull(modelTable.org_id),
), ),
eq(checkpointTable.id, id), eq(modelTable.id, id),
), ),
); );
return checkpoint[0]; return model[0];
} }
export async function getCheckpointVolumes() { export async function getModelVolumes() {
const { userId, orgId } = auth(); const { userId, orgId } = auth();
if (!userId) throw new Error("No user id"); if (!userId) throw new Error("No user id");
const volume = await db const volume = await db
.select() .select()
.from(checkpointVolumeTable) .from(userVolume)
.where( .where(
and( and(
orgId orgId
? eq(checkpointVolumeTable.org_id, orgId) ? eq(userVolume.org_id, orgId)
// make sure org_id is null // make sure org_id is null
: and( : and(
eq(checkpointVolumeTable.user_id, userId), eq(userVolume.user_id, userId),
isNull(checkpointVolumeTable.org_id), isNull(userVolume.org_id),
), ),
eq(checkpointVolumeTable.disabled, false), eq(userVolume.disabled, false),
), ),
); );
return volume; return volume;
} }
export async function retrieveCheckpointVolumes() { export async function retrieveModelVolumes() {
let volumes = await getCheckpointVolumes(); let volumes = await getModelVolumes();
if (volumes.length === 0) { if (volumes.length === 0) {
// create volume if not already created // create volume if not already created
volumes = await addCheckpointVolume(); volumes = await addModelVolume();
} }
return volumes; return volumes;
} }
export async function addCheckpointVolume() { export async function addModelVolume() {
const { userId, orgId } = auth(); const { userId, orgId } = auth();
if (!userId) throw new Error("No user id"); if (!userId) throw new Error("No user id");
// Insert the new checkpointVolume into the checkpointVolumeTable
const insertedVolume = await db const insertedVolume = await db
.insert(checkpointVolumeTable) .insert(userVolume)
.values({ .values({
user_id: userId, user_id: userId,
org_id: orgId, org_id: orgId,
@ -111,8 +110,8 @@ function getUrl(civitai_url: string) {
return { url: baseUrl + modelId, modelVersionId }; return { url: baseUrl + modelId, modelVersionId };
} }
export const addCivitaiCheckpoint = withServerPromise( export const addCivitaiModel = withServerPromise(
async (data: z.infer<typeof addCivitaiCheckpointSchema>) => { async (data: z.infer<typeof addCivitaiModelSchema>) => {
const { userId, orgId } = auth(); const { userId, orgId } = auth();
if (!data.civitai_url) return { error: "no civitai_url" }; if (!data.civitai_url) return { error: "no civitai_url" };
@ -145,17 +144,22 @@ export const addCivitaiCheckpoint = withServerPromise(
selectedModelVersionId = selectedModelVersion?.id.toString(); selectedModelVersionId = selectedModelVersion?.id.toString();
} }
const checkpointVolumes = await getCheckpointVolumes(); const userVolume = await getModelVolumes();
let cVolume; let cVolume;
if (checkpointVolumes.length === 0) { if (userVolume.length === 0) {
const volume = await addCheckpointVolume(); const volume = await addModelVolume();
cVolume = volume[0]; cVolume = volume[0];
} else { } else {
cVolume = checkpointVolumes[0]; cVolume = userVolume[0];
}
const model_type = getModelTypeDetails(civitaiModelRes.type);
if (!model_type) {
return
} }
const a = await db const a = await db
.insert(checkpointTable) .insert(modelTable)
.values({ .values({
user_id: userId, user_id: userId,
org_id: orgId, org_id: orgId,
@ -166,15 +170,15 @@ export const addCivitaiCheckpoint = withServerPromise(
civitai_url: data.civitai_url, civitai_url: data.civitai_url,
civitai_download_url: selectedModelVersion.files[0].downloadUrl, civitai_download_url: selectedModelVersion.files[0].downloadUrl,
civitai_model_response: civitaiModelRes, civitai_model_response: civitaiModelRes,
checkpoint_volume_id: cVolume.id, user_volume_id: cVolume.id,
model_type,
updated_at: new Date(), updated_at: new Date(),
}) })
.returning(); .returning();
const b = a[0]; const b = a[0];
await uploadCheckpoint(data, b, cVolume); await uploadModel(data, b, cVolume);
// redirect(`/checkpoints/${b.id}`);
}, },
); );
@ -213,10 +217,10 @@ export const addCivitaiCheckpoint = withServerPromise(
// }, // },
// ); // );
async function uploadCheckpoint( async function uploadModel(
data: z.infer<typeof addCivitaiCheckpointSchema>, data: z.infer<typeof addCivitaiModelSchema>,
c: CheckpointType, c: ModelType,
v: CheckpointVolumeType, v: UserVolumeType,
) { ) {
const headersList = headers(); const headersList = headers();
@ -239,9 +243,9 @@ async function uploadCheckpoint(
download_url: c.civitai_download_url, download_url: c.civitai_download_url,
volume_name: v.volume_name, volume_name: v.volume_name,
volume_id: v.id, volume_id: v.id,
checkpoint_id: c.id, model_id: c.id,
callback_url: `${protocol}://${domain}/api/volume-upload`, callback_url: `${protocol}://${domain}/api/volume-upload`,
upload_type: "checkpoint" upload_type: c.model_type,
}), }),
}, },
); );
@ -249,23 +253,23 @@ async function uploadCheckpoint(
if (!result.ok) { if (!result.ok) {
const error_log = await result.text(); const error_log = await result.text();
await db await db
.update(checkpointTable) .update(modelTable)
.set({ .set({
...data, ...data,
status: "failed", status: "failed",
error_log: error_log, error_log: error_log,
}) })
.where(eq(checkpointTable.id, c.id)); .where(eq(modelTable.id, c.id));
throw new Error(`Error: ${result.statusText} ${error_log}`); throw new Error(`Error: ${result.statusText} ${error_log}`);
} else { } else {
// setting the build machine id // setting the build machine id
const json = await result.json(); const json = await result.json();
await db await db
.update(checkpointTable) .update(modelTable)
.set({ .set({
...data, ...data,
upload_machine_id: json.build_machine_instance_id, upload_machine_id: json.build_machine_instance_id,
}) })
.where(eq(checkpointTable.id, c.id)); .where(eq(modelTable.id, c.id));
} }
} }

View File

@ -1,18 +1,18 @@
import { db } from "@/db/db"; import { db } from "@/db/db";
import { import {
checkpointTable, modelTable,
} from "@/db/schema"; } from "@/db/schema";
import { auth } from "@clerk/nextjs"; import { auth } from "@clerk/nextjs";
import { and, desc, eq, isNull } from "drizzle-orm"; import { and, desc, eq, isNull } from "drizzle-orm";
export async function getAllUserCheckpoints() { export async function getAllUserModels() {
const { userId, orgId } = await auth(); const { userId, orgId } = await auth();
if (!userId) { if (!userId) {
return null; return null;
} }
const checkpoints = await db.query.checkpointTable.findMany({ const models = await db.query.modelTable.findMany({
with: { with: {
user: { user: {
columns: { columns: {
@ -28,14 +28,15 @@ export async function getAllUserCheckpoints() {
civitai_model_response: true, civitai_model_response: true,
is_public: true, is_public: true,
upload_type: true, upload_type: true,
model_type: true,
status: true, status: true,
}, },
orderBy: desc(checkpointTable.updated_at), orderBy: desc(modelTable.updated_at),
where: where:
orgId != undefined orgId != undefined
? eq(checkpointTable.org_id, orgId) ? eq(modelTable.org_id, orgId)
: and(eq(checkpointTable.user_id, userId), isNull(checkpointTable.org_id)), : and(eq(modelTable.user_id, userId), isNull(modelTable.org_id)),
}); });
return checkpoints; return models;
} }

View File

@ -1,4 +1,5 @@
import { z } from "zod"; import { TypeOf, z } from "zod";
import { modelEnumType } from "@/db/schema";
// from chatgpt https://chat.openai.com/share/4985d20b-30b1-4a28-87f6-6ebf84a1040e // from chatgpt https://chat.openai.com/share/4985d20b-30b1-4a28-87f6-6ebf84a1040e
@ -110,12 +111,21 @@ export const statsSchema = z.object({
tippedAmountCount: z.number(), tippedAmountCount: z.number(),
}); });
const civitaiModelType = z.enum([
"Checkpoint",
"TextualInversion",
"Hypernetwork",
"AestheticGradient",
"LORA",
"Controlnet",
"Poses",
]);
export const CivitaiModelResponse = z.object({ export const CivitaiModelResponse = z.object({
id: z.number(), id: z.number(),
name: z.string().nullish(), name: z.string().nullish(),
description: z.string().nullish(), description: z.string().nullish(),
// type: z.enum(["Checkpoint", "Lora"]), // TODO: this will be important to know type: civitaiModelType,
type: z.string(),
poi: z.boolean().nullish(), poi: z.boolean().nullish(),
nsfw: z.boolean().nullish(), nsfw: z.boolean().nullish(),
allowNoCredit: z.boolean().nullish(), allowNoCredit: z.boolean().nullish(),
@ -127,3 +137,22 @@ export const CivitaiModelResponse = z.object({
tags: z.array(z.string()).nullish(), tags: z.array(z.string()).nullish(),
modelVersions: z.array(modelVersionSchema), modelVersions: z.array(modelVersionSchema),
}); });
export function getModelTypeDetails(
modelType: typeof civitaiModelType["_type"],
): modelEnumType | undefined {
switch (modelType) {
case "Checkpoint":
return "checkpoint"
case "TextualInversion":
return "embedding"
case "LORA":
return "lora"
case "AestheticGradient":
case "Hypernetwork":
case "Controlnet":
case "Poses":
default:
return undefined;
}
}