Merge branch 'nickkao/checkpoint-volume'

# Conflicts:
#	web/bun.lockb
This commit is contained in:
bennykok 2024-01-25 21:07:27 +08:00
commit 6f0499c657
23 changed files with 2708 additions and 30 deletions

View File

@ -8,6 +8,7 @@ from enum import Enum
import json import json
import subprocess import subprocess
import time import time
from uuid import uuid4
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
import asyncio import asyncio
import threading import threading
@ -19,6 +20,7 @@ from urllib.parse import parse_qs
from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp, Scope, Receive, Send from starlette.types import ASGIApp, Scope, Receive, Send
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
# executor = ThreadPoolExecutor(max_workers=5) # executor = ThreadPoolExecutor(max_workers=5)
@ -45,6 +47,7 @@ machine_id_websocket_dict = {}
machine_id_status = {} machine_id_status = {}
fly_instance_id = os.environ.get('FLY_ALLOC_ID', 'local').split('-')[0] fly_instance_id = os.environ.get('FLY_ALLOC_ID', 'local').split('-')[0]
civitai_api_key = os.environ.get('FLY_ALLOC_ID', 'local').split('-')[0]
class FlyReplayMiddleware(BaseHTTPMiddleware): class FlyReplayMiddleware(BaseHTTPMiddleware):
@ -174,6 +177,7 @@ class Item(BaseModel):
snapshot: Snapshot snapshot: Snapshot
models: List[Model] models: List[Model]
callback_url: str callback_url: str
checkpoint_volume_name: str
gpu: GPUType = Field(default=GPUType.T4) gpu: GPUType = Field(default=GPUType.T4)
@field_validator('gpu') @field_validator('gpu')
@ -223,6 +227,103 @@ async def websocket_endpoint(websocket: WebSocket, machine_id: str):
# return {"Hello": "World"} # return {"Hello": "World"}
class UploadType(str, Enum):
checkpoint = "checkpoint"
class UploadBody(BaseModel):
download_url: str
volume_name: str
volume_id: str
checkpoint_id: str
upload_type: UploadType
callback_url: str
UPLOAD_TYPE_DIR_MAP = {
UploadType.checkpoint: "checkpoints"
}
@app.post("/upload-volume")
async def upload_checkpoint(body: UploadBody):
global last_activity_time
last_activity_time = time.time()
logger.info(f"Extended inactivity time to {global_timeout}")
asyncio.create_task(upload_logic(body))
# check that this
return JSONResponse(status_code=200, content={"message": "Volume uploading", "build_machine_instance_id": fly_instance_id})
async def upload_logic(body: UploadBody):
folder_path = f"/app/builds/{body.volume_id}"
cp_process = await asyncio.subprocess.create_subprocess_exec("cp", "-r", "/app/src/volume-builder", folder_path)
await cp_process.wait()
upload_path = UPLOAD_TYPE_DIR_MAP[body.upload_type]
config = {
"volume_names": {
body.volume_name: {"download_url": body.download_url, "folder_path": upload_path}
},
"volume_paths": {
body.volume_name: f'/volumes/{uuid4()}'
},
"callback_url": body.callback_url,
"callback_body": {
"checkpoint_id": body.checkpoint_id,
"volume_id": body.volume_id,
"folder_path": upload_path,
},
"civitai_api_key": os.environ.get('CIVITAI_API_KEY')
}
with open(f"{folder_path}/config.py", "w") as f:
f.write("config = " + json.dumps(config))
process = await asyncio.subprocess.create_subprocess_shell(
f"modal run app.py",
# stdout=asyncio.subprocess.PIPE,
# stderr=asyncio.subprocess.PIPE,
cwd=folder_path,
env={**os.environ, "COLUMNS": "10000"}
)
# error_logs = []
# async def read_stream(stream):
# while True:
# line = await stream.readline()
# if line:
# l = line.decode('utf-8').strip()
# error_logs.append(l)
# logger.error(l)
# error_logs.append({
# "logs": l,
# "timestamp": time.time()
# })
# else:
# break
# stderr_read_task = asyncio.create_task(read_stream(process.stderr))
#
# await asyncio.wait([stderr_read_task])
# await process.wait()
# if process.returncode != 0:
# error_logs.append({"logs": "Unable to upload volume.", "timestamp": time.time()})
# # Error handling: send POST request to callback URL with error details
# requests.post(body.callback_url, json={
# "volume_id": body.volume_id,
# "checkpoint_id": body.checkpoint_id,
# "folder_path": upload_path,
# "error_logs": json.dumps(error_logs),
# "status": "failed"
# })
#
# requests.post(body.callback_url, json={
# "checkpoint_id": body.checkpoint_id,
# "volume_id": body.volume_id,
# "folder_path": upload_path,
# "status": "success"
# })
@app.post("/create") @app.post("/create")
async def create_machine(item: Item): async def create_machine(item: Item):
@ -312,7 +413,9 @@ async def build_logic(item: Item):
config = { config = {
"name": item.name, "name": item.name,
"deploy_test": os.environ.get("DEPLOY_TEST_FLAG", "False"), "deploy_test": os.environ.get("DEPLOY_TEST_FLAG", "False"),
"gpu": item.gpu "gpu": item.gpu,
"public_checkpoint_volume": "model-store",
"private_checkpoint_volume": item.checkpoint_volume_name
} }
with open(f"{folder_path}/config.py", "w") as f: with open(f"{folder_path}/config.py", "w") as f:
f.write("config = " + json.dumps(config)) f.write("config = " + json.dumps(config))

View File

@ -1,12 +1,13 @@
from config import config from config import config
import modal import modal
from modal import Image, Mount, web_endpoint, Stub, asgi_app from modal import Image, Mount, web_endpoint, Stub, asgi_app
import json import json
import urllib.request import urllib.request
import urllib.parse import urllib.parse
from pydantic import BaseModel from pydantic import BaseModel
from fastapi import FastAPI, Request from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse from fastapi.responses import HTMLResponse
from volume_setup import volumes
# deploy_test = False # deploy_test = False
@ -27,8 +28,8 @@ deploy_test = config["deploy_test"] == "True"
web_app = FastAPI() web_app = FastAPI()
print(config) print(config)
print("deploy_test ", deploy_test) print("deploy_test ", deploy_test)
print('volumes', volumes)
stub = Stub(name=config["name"]) stub = Stub(name=config["name"])
# print(stub.app_id)
if not deploy_test: if not deploy_test:
# dockerfile_image = Image.from_dockerfile(f"{current_directory}/Dockerfile", context_mount=Mount.from_local_dir(f"{current_directory}/data", remote_path="/data")) # dockerfile_image = Image.from_dockerfile(f"{current_directory}/Dockerfile", context_mount=Mount.from_local_dir(f"{current_directory}/data", remote_path="/data"))
@ -52,11 +53,13 @@ if not deploy_test:
"cd /comfyui/custom_nodes/ComfyUI-Manager && pip install -r requirements.txt", "cd /comfyui/custom_nodes/ComfyUI-Manager && pip install -r requirements.txt",
"cd /comfyui/custom_nodes/ComfyUI-Manager && mkdir startup-scripts", "cd /comfyui/custom_nodes/ComfyUI-Manager && mkdir startup-scripts",
) )
.run_commands(f"cat /comfyui/server.py")
.run_commands(f"ls /comfyui/app")
# .run_commands( # .run_commands(
# # Install comfy deploy # # Install comfy deploy
# "cd /comfyui/custom_nodes && git clone https://github.com/BennyKok/comfyui-deploy.git", # "cd /comfyui/custom_nodes && git clone https://github.com/BennyKok/comfyui-deploy.git",
# ) # )
# .copy_local_file(f"{current_directory}/data/extra_model_paths.yaml", "/comfyui") .copy_local_file(f"{current_directory}/data/extra_model_paths.yaml", "/comfyui")
.copy_local_file(f"{current_directory}/data/start.sh", "/start.sh") .copy_local_file(f"{current_directory}/data/start.sh", "/start.sh")
.run_commands("chmod +x /start.sh") .run_commands("chmod +x /start.sh")
@ -153,8 +156,9 @@ image = Image.debian_slim()
target_image = image if deploy_test else dockerfile_image target_image = image if deploy_test else dockerfile_image
@stub.function(image=target_image, gpu=config["gpu"]
@stub.function(image=target_image, gpu=config["gpu"]) ,volumes=volumes
)
def run(input: Input): def run(input: Input):
import subprocess import subprocess
import time import time
@ -163,6 +167,7 @@ def run(input: Input):
command = ["python", "main.py", command = ["python", "main.py",
"--disable-auto-launch", "--disable-metadata"] "--disable-auto-launch", "--disable-metadata"]
server_process = subprocess.Popen(command, cwd="/comfyui") server_process = subprocess.Popen(command, cwd="/comfyui")
check_server( check_server(
@ -235,7 +240,9 @@ async def bar(request_input: RequestInput):
# pass # pass
@stub.function(image=image) @stub.function(image=image
,volumes=volumes
)
@asgi_app() @asgi_app()
def comfyui_api(): def comfyui_api():
return web_app return web_app
@ -285,6 +292,7 @@ def spawn_comfyui_in_background():
# to be on a single container. # to be on a single container.
concurrency_limit=1, concurrency_limit=1,
timeout=10 * 60, timeout=10 * 60,
volumes=volumes,
) )
@asgi_app() @asgi_app()
def comfyui_app(): def comfyui_app():
@ -303,4 +311,4 @@ def comfyui_app():
}, },
)() )()
return make_simple_proxy_app(ProxyContext(config)) return make_simple_proxy_app(ProxyContext(config))

View File

@ -1 +1,7 @@
config = {"name": "my-app", "deploy_test": "True", "gpu": "T4"} config = {
"name": "my-app",
"deploy_test": "True",
"gpu": "T4",
"public_checkpoint_volume": "model-store",
"private_checkpoint_volume": "private-model-store"
}

View File

@ -1,11 +1,15 @@
comfyui: public:
base_path: /runpod-volume/ComfyUI/ base_path: /public_models/
checkpoints: models/checkpoints/ checkpoints: checkpoints
clip: models/clip/ clip: clip
clip_vision: models/clip_vision/ clip_vision: clip_vision
configs: models/configs/ configs: configs
controlnet: models/controlnet/ controlnet: controlnet
embeddings: models/embeddings/ embeddings: embeddings
loras: models/loras/ loras: loras
upscale_models: models/upscale_models/ upscale_models: upscale_models
vae: models/vae/ vae: vae
private:
base_path: /private_models/
checkpoints: checkpoints

View File

@ -54,4 +54,4 @@ for model in models:
# Close the server # Close the server
server_process.terminate() server_process.terminate()
print("Finished installing dependencies.") print("Finished installing dependencies.")

View File

@ -0,0 +1,9 @@
import modal
from config import config
public_model_volume = modal.Volume.persisted(config["public_checkpoint_volume"])
private_volume = modal.Volume.persisted(config["private_checkpoint_volume"])
PUBLIC_BASEMODEL_DIR = "/public_models"
PRIVATE_BASEMODEL_DIR = "/private_models"
volumes = {PUBLIC_BASEMODEL_DIR: public_model_volume, PRIVATE_BASEMODEL_DIR: private_volume}

View File

@ -0,0 +1,74 @@
import modal
from config import config
import os
import subprocess
from pprint import pprint
stub = modal.Stub()
# Volume names may only contain alphanumeric characters, dashes, periods, and underscores, and must be less than 64 characters in length.
def is_valid_name(name: str) -> bool:
allowed_characters = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-._")
return 0 < len(name) <= 64 and all(char in allowed_characters for char in name)
def create_volumes(volume_names, paths):
path_to_vol = {}
for volume_name in volume_names.keys():
if not is_valid_name(volume_name):
pass
modal_volume = modal.Volume.persisted(volume_name)
path_to_vol[paths[volume_name]] = modal_volume
return path_to_vol
vol_name_to_links = config["volume_names"]
vol_name_to_path = config["volume_paths"]
callback_url = config["callback_url"]
callback_body = config["callback_body"]
civitai_key = config["civitai_api_key"]
volumes = create_volumes(vol_name_to_links, vol_name_to_path)
image = (
modal.Image.debian_slim().apt_install("wget").pip_install("requests")
)
# download config { "download_url": "", "folder_path": ""}
timeout=5000
@stub.function(volumes=volumes, image=image, timeout=timeout, gpu=None)
def download_model(volume_name, download_config):
import requests
download_url = download_config["download_url"]
folder_path = download_config["folder_path"]
volume_base_path = vol_name_to_path[volume_name]
model_store_path = os.path.join(volume_base_path, folder_path)
modified_download_url = download_url + ("&" if "?" in download_url else "?") + "token=" + civitai_key
print('downloading', modified_download_url)
subprocess.run(["wget", modified_download_url , "--content-disposition", "-P", model_store_path])
subprocess.run(["ls", "-la", volume_base_path])
subprocess.run(["ls", "-la", model_store_path])
volumes[volume_base_path].commit()
status = {"status": "success"}
requests.post(callback_url, json={**status, **callback_body})
print(f"finished! sending to {callback_url}")
pprint({**status, **callback_body})
@stub.local_entrypoint()
def simple_download():
import requests
try:
list(download_model.starmap([(vol_name, link) for vol_name,link in vol_name_to_links.items()]))
except modal.exception.FunctionTimeoutError as e:
status = {"status": "failed", "error_logs": f"{str(e)}", "timeout": timeout}
requests.post(callback_url, json={**status, **callback_body})
print(f"finished! sending to {callback_url}")
pprint({**status, **callback_body})
except Exception as e:
status = {"status": "failed", "error_logs": str(e)}
requests.post(callback_url, json={**status, **callback_body})
print(f"finished! sending to {callback_url}")
pprint({**status, **callback_body})

View File

@ -0,0 +1,18 @@
config = {
"volume_names": {
"test": {
"download_url": "https://pub-6230db03dc3a4861a9c3e55145ceda44.r2.dev/openpose-pose (1).png",
"folder_path": "images"
}
},
"volume_paths": {
"test": "/volumes/something"
},
"callback_url": "",
"callback_body": {
"checkpoint_id": "",
"volume_id": "",
"folder_path": "images",
},
"civitai_api_key": "",
}

View File

@ -0,0 +1,64 @@
DO $$ BEGIN
CREATE TYPE "model_upload_type" AS ENUM('civitai', 'huggingface', 'other');
EXCEPTION
WHEN duplicate_object THEN null;
END $$;
--> statement-breakpoint
DO $$ BEGIN
CREATE TYPE "resource_upload" AS ENUM('started', 'success', 'failed');
EXCEPTION
WHEN duplicate_object THEN null;
END $$;
--> statement-breakpoint
CREATE TABLE IF NOT EXISTS "comfyui_deploy"."checkpoints" (
"id" uuid PRIMARY KEY DEFAULT gen_random_uuid() NOT NULL,
"user_id" text,
"org_id" text,
"description" text,
"checkpoint_volume_id" uuid NOT NULL,
"model_name" text,
"folder_path" text,
"civitai_id" text,
"civitai_version_id" text,
"civitai_url" text,
"civitai_download_url" text,
"civitai_model_response" jsonb,
"hf_url" text,
"s3_url" text,
"client_url" text,
"is_public" boolean DEFAULT false NOT NULL,
"status" "resource_upload" DEFAULT 'started' NOT NULL,
"upload_machine_id" text,
"upload_type" "model_upload_type" NOT NULL,
"error_log" text,
"created_at" timestamp DEFAULT now() NOT NULL,
"updated_at" timestamp DEFAULT now() NOT NULL
);
--> statement-breakpoint
CREATE TABLE IF NOT EXISTS "comfyui_deploy"."checkpoint_volume" (
"id" uuid PRIMARY KEY DEFAULT gen_random_uuid() NOT NULL,
"user_id" text,
"org_id" text,
"volume_name" text NOT NULL,
"created_at" timestamp DEFAULT now() NOT NULL,
"updated_at" timestamp DEFAULT now() NOT NULL,
"disabled" boolean DEFAULT false NOT NULL
);
--> statement-breakpoint
DO $$ BEGIN
ALTER TABLE "comfyui_deploy"."checkpoints" ADD CONSTRAINT "checkpoints_user_id_users_id_fk" FOREIGN KEY ("user_id") REFERENCES "comfyui_deploy"."users"("id") ON DELETE no action ON UPDATE no action;
EXCEPTION
WHEN duplicate_object THEN null;
END $$;
--> statement-breakpoint
DO $$ BEGIN
ALTER TABLE "comfyui_deploy"."checkpoints" ADD CONSTRAINT "checkpoints_checkpoint_volume_id_checkpoint_volume_id_fk" FOREIGN KEY ("checkpoint_volume_id") REFERENCES "comfyui_deploy"."checkpoint_volume"("id") ON DELETE cascade ON UPDATE no action;
EXCEPTION
WHEN duplicate_object THEN null;
END $$;
--> statement-breakpoint
DO $$ BEGIN
ALTER TABLE "comfyui_deploy"."checkpoint_volume" ADD CONSTRAINT "checkpoint_volume_user_id_users_id_fk" FOREIGN KEY ("user_id") REFERENCES "comfyui_deploy"."users"("id") ON DELETE no action ON UPDATE no action;
EXCEPTION
WHEN duplicate_object THEN null;
END $$;

File diff suppressed because it is too large Load Diff

View File

@ -295,6 +295,13 @@
"when": 1706111421524, "when": 1706111421524,
"tag": "0041_thick_norrin_radd", "tag": "0041_thick_norrin_radd",
"breakpoints": true "breakpoints": true
},
{
"idx": 42,
"version": "5",
"when": 1706164614659,
"tag": "0042_windy_madelyne_pryor",
"breakpoints": true
} }
] ]
} }

View File

@ -0,0 +1,55 @@
import { parseDataSafe } from "../../../../lib/parseDataSafe";
import { db } from "@/db/db";
import { checkpointTable, machinesTable } from "@/db/schema";
import { eq } from "drizzle-orm";
import { NextResponse } from "next/server";
import { z } from "zod";
const Request = z.object({
volume_id: z.string(),
checkpoint_id: z.string(),
folder_path: z.string().optional(),
status: z.enum(['success', 'failed']),
error_log: z.string().optional(),
timeout: z.number().optional(),
});
export async function POST(request: Request) {
const [data, error] = await parseDataSafe(Request, request);
if (!data || error) return error;
const { checkpoint_id, error_log, status, folder_path } = data;
console.log( checkpoint_id, error_log, status, folder_path )
if (status === "success") {
await db
.update(checkpointTable)
.set({
status: "success",
folder_path,
updated_at: new Date(),
// build_log: build_log,
})
.where(eq(checkpointTable.id, checkpoint_id));
} else {
await db
.update(checkpointTable)
.set({
status: "failed",
error_log,
updated_at: new Date(),
// status: "error",
// build_log: build_log,
})
.where(eq(checkpointTable.id, checkpoint_id));
}
return NextResponse.json(
{
message: "success",
},
{
status: 200,
}
);
}

View File

@ -0,0 +1,9 @@
"use client";
import { LoadingPageWrapper } from "@/components/LoadingWrapper";
import { usePathname } from "next/navigation";
export default function Loading() {
const pathName = usePathname();
return <LoadingPageWrapper className="h-full" tag={pathName.toLowerCase()} />;
}

View File

@ -0,0 +1,35 @@
import { setInitialUserData } from "../../../lib/setInitialUserData";
import { auth } from "@clerk/nextjs";
import { clerkClient } from "@clerk/nextjs/server";
import { CheckpointList } from "@/components/CheckpointList";
import { getAllUserCheckpoints } from "@/server/getAllUserCheckpoints";
export default function Page() {
return <CheckpointListServer />;
}
async function CheckpointListServer() {
const { userId } = auth();
if (!userId) {
return <div>No auth</div>;
}
const user = await clerkClient.users.getUser(userId);
if (!user) {
await setInitialUserData(userId);
}
const checkpoints = await getAllUserCheckpoints();
if (!checkpoints) {
return <div>No checkpoints found</div>;
}
return (
<div className="w-full">
<CheckpointList data={checkpoints} />
</div>
);
}

View File

@ -0,0 +1,373 @@
"use client";
import { getRelativeTime } from "../lib/getRelativeTime";
import { Badge } from "@/components/ui/badge";
import { Button } from "@/components/ui/button";
import { Checkbox } from "@/components/ui/checkbox";
import { InsertModal, UpdateModal } from "./InsertModal";
import { Input } from "@/components/ui/input";
import { ScrollArea } from "@/components/ui/scroll-area";
import {
Table,
TableBody,
TableCell,
TableHead,
TableHeader,
TableRow,
} from "@/components/ui/table";
import type { getAllUserCheckpoints } from "@/server/getAllUserCheckpoints";
import type {
ColumnDef,
ColumnFiltersState,
SortingState,
VisibilityState,
} from "@tanstack/react-table";
import {
flexRender,
getCoreRowModel,
getFilteredRowModel,
getPaginationRowModel,
getSortedRowModel,
useReactTable,
} from "@tanstack/react-table";
import { ArrowUpDown } from "lucide-react";
import * as React from "react";
import { addCivitaiCheckpoint } from "@/server/curdCheckpoint";
import { addCivitaiCheckpointSchema } from "@/server/addCheckpointSchema";
export type CheckpointItemList = NonNullable<
Awaited<ReturnType<typeof getAllUserCheckpoints>>
>[0];
export const columns: ColumnDef<CheckpointItemList>[] = [
{
accessorKey: "id",
id: "select",
header: ({ table }) => (
<Checkbox
checked={
table.getIsAllPageRowsSelected() ||
(table.getIsSomePageRowsSelected() && "indeterminate")
}
onCheckedChange={(value) => table.toggleAllPageRowsSelected(!!value)}
aria-label="Select all"
/>
),
cell: ({ row }) => (
<Checkbox
checked={row.getIsSelected()}
onCheckedChange={(value) => row.toggleSelected(!!value)}
aria-label="Select row"
/>
),
enableSorting: false,
enableHiding: false,
},
{
accessorKey: "model_name",
header: ({ column }) => {
return (
<button
className="flex items-center hover:underline"
onClick={() => column.toggleSorting(column.getIsSorted() === "asc")}
>
Model Name
<ArrowUpDown className="ml-2 h-4 w-4" />
</button>
);
},
cell: ({ row }) => {
const checkpoint = row.original;
return (
<a
className="hover:underline flex gap-2"
href={`/storage/${checkpoint.id}`} // TODO
>
<span className="truncate max-w-[200px]">
{row.original.model_name}
</span>
{checkpoint.is_public ? (
<Badge variant="green">Public</Badge>
) : (
<Badge variant="orange">Private</Badge>
)}
</a>
);
},
},
{
accessorKey: "status",
header: ({ column }) => {
return (
<button
className="flex items-center hover:underline"
onClick={() => column.toggleSorting(column.getIsSorted() === "asc")}
>
Status
<ArrowUpDown className="ml-2 h-4 w-4" />
</button>
);
},
cell: ({ row }) => {
return (
<Badge variant={row.original.status === "failed" ? "red" : (row.original.status === "started" ? "yellow" : "green")}>
{row.original.status}
</Badge>
);
// NOTE: retry downloads on failures
// const oneHourAgo = new Date(new Date().getTime() - (60 * 60 * 1000));
// const lastUpdated = new Date(row.original.updated_at);
// const canRefresh = row.original.status === "failed" && lastUpdated < oneHourAgo;
// const canRefresh = row.original.status === "failed" && lastUpdated < oneHourAgo;
// cell: ({ row }) => {
// // const oneHourAgo = new Date(new Date().getTime() - (60 * 60 * 1000));
// // const lastUpdated = new Date(row.original.updated_at);
// // const canRefresh = row.original.status === "failed" && lastUpdated < oneHourAgo;
// const canReDownload = true;
//
// return (
// <div className="flex items-center space-x-2">
// <Badge
// variant={row.original.status === "failed"
// ? "red"
// : row.original.status === "started"
// ? "yellow"
// : "green"}
// >
// {row.original.status}
// </Badge>
// {canReDownload && (
// <RefreshCcw
// onClick={() => {
// redownloadCheckpoint(row.original);
// }}
// className="h-4 w-4 cursor-pointer" // Adjust the size with h-x and w-x classes
// />
// )}
// </div>
// );
// },
},
},
{
accessorKey: "upload_type",
header: ({ column }) => {
return (
<button
className="flex items-center hover:underline"
onClick={() => column.toggleSorting(column.getIsSorted() === "asc")}
>
Source
<ArrowUpDown className="ml-2 h-4 w-4" />
</button>
);
},
cell: ({ row }) => {
return <Badge variant="cyan">{row.original.upload_type}</Badge>;
},
},
{
accessorKey: "date",
sortingFn: "datetime",
enableSorting: true,
header: ({ column }) => {
return (
<button
className="w-full flex items-center justify-end hover:underline truncate"
// variant="ghost"
onClick={() => column.toggleSorting(column.getIsSorted() === "asc")}
>
Update Date
<ArrowUpDown className="ml-2 h-4 w-4" />
</button>
);
},
cell: ({ row }) => (
<div className="w-full capitalize text-right truncate">
{getRelativeTime(row.original.updated_at)}
</div>
),
},
// TODO: deletion and editing for future sprint
// {
// id: "actions",
// enableHiding: false,
// cell: ({ row }) => {
// const checkpoint = row.original;
//
// return (
// <DropdownMenu>
// <DropdownMenuTrigger asChild>
// <Button variant="ghost" className="h-8 w-8 p-0">
// <span className="sr-only">Open menu</span>
// <MoreHorizontal className="h-4 w-4" />
// </Button>
// </DropdownMenuTrigger>
// <DropdownMenuContent align="end">
// <DropdownMenuLabel>Actions</DropdownMenuLabel>
// <DropdownMenuItem
// className="text-destructive"
// onClick={() => {
// deleteWorkflow(checkpoint.id);
// }}
// >
// Delete Workflow
// </DropdownMenuItem>
// </DropdownMenuContent>
// </DropdownMenu>
// );
// },
// },
];
export function CheckpointList({ data }: { data: CheckpointItemList[] }) {
const [sorting, setSorting] = React.useState<SortingState>([]);
const [columnFilters, setColumnFilters] = React.useState<ColumnFiltersState>(
[]
);
const [columnVisibility, setColumnVisibility] =
React.useState<VisibilityState>({});
const [rowSelection, setRowSelection] = React.useState({});
const table = useReactTable({
data,
columns,
onSortingChange: setSorting,
onColumnFiltersChange: setColumnFilters,
getCoreRowModel: getCoreRowModel(),
getPaginationRowModel: getPaginationRowModel(),
getSortedRowModel: getSortedRowModel(),
getFilteredRowModel: getFilteredRowModel(),
onColumnVisibilityChange: setColumnVisibility,
onRowSelectionChange: setRowSelection,
state: {
sorting,
columnFilters,
columnVisibility,
rowSelection,
},
});
return (
<div className="grid grid-rows-[auto,1fr,auto] h-full">
<div className="flex flex-row w-full items-center py-4">
<Input
placeholder="Filter workflows..."
value={(table.getColumn("name")?.getFilterValue() as string) ?? ""}
onChange={(event) =>
table.getColumn("name")?.setFilterValue(event.target.value)
}
className="max-w-sm"
/>
<div className="ml-auto flex gap-2">
<InsertModal
dialogClassName="sm:max-w-[600px]"
disabled={
false
// TODO: limitations based on plan
}
tooltip={"Add models using their civitai url!"}
title="Civitai Checkpoint"
description="Pick a model from civitai"
serverAction={addCivitaiCheckpoint}
formSchema={addCivitaiCheckpointSchema}
fieldConfig={{
civitai_url: {
fieldType: "fallback",
inputProps: { required: true },
description: (
<>
Pick a checkpoint from{" "}
<a
href="https://www.civitai.com/models"
target="_blank"
className="underline text-blue-600 hover:text-blue-800 visited:text-purple-600"
>
civitai.com
</a>{" "}
and place it's url here
</>
),
},
}}
/>
</div>
</div>
<ScrollArea className="h-full w-full rounded-md border">
<Table>
<TableHeader className="bg-background top-0 sticky">
{table.getHeaderGroups().map((headerGroup) => (
<TableRow key={headerGroup.id}>
{headerGroup.headers.map((header) => {
return (
<TableHead key={header.id}>
{header.isPlaceholder
? null
: flexRender(
header.column.columnDef.header,
header.getContext()
)}
</TableHead>
);
})}
</TableRow>
))}
</TableHeader>
<TableBody>
{table.getRowModel().rows?.length ? (
table.getRowModel().rows.map((row) => (
<TableRow
key={row.id}
data-state={row.getIsSelected() && "selected"}
>
{row.getVisibleCells().map((cell) => (
<TableCell key={cell.id}>
{flexRender(
cell.column.columnDef.cell,
cell.getContext()
)}
</TableCell>
))}
</TableRow>
))
) : (
<TableRow>
<TableCell
colSpan={columns.length}
className="h-24 text-center"
>
No results.
</TableCell>
</TableRow>
)}
</TableBody>
</Table>
</ScrollArea>
<div className="flex flex-row items-center justify-end space-x-2 py-4">
<div className="flex-1 text-sm text-muted-foreground">
{table.getFilteredSelectedRowModel().rows.length} of{" "}
{table.getFilteredRowModel().rows.length} row(s) selected.
</div>
<div className="space-x-2">
<Button
variant="outline"
size="sm"
onClick={() => table.previousPage()}
disabled={!table.getCanPreviousPage()}
>
Previous
</Button>
<Button
variant="outline"
size="sm"
onClick={() => table.nextPage()}
disabled={!table.getCanNextPage()}
>
Next
</Button>
</div>
</div>
</div>
);
}

View File

@ -34,6 +34,10 @@ export function NavbarMenu({ className }: { className?: string }) {
name: "API Keys", name: "API Keys",
path: "/api-keys", path: "/api-keys",
}, },
{
name: "Storage",
path: "/storage",
},
]; ];
return ( return (
@ -42,9 +46,9 @@ export function NavbarMenu({ className }: { className?: string }) {
{isDesktop && ( {isDesktop && (
<Tabs <Tabs
defaultValue={pathname} defaultValue={pathname}
className="w-[300px] flex pointer-events-auto" className="w-[400px] flex pointer-events-auto"
> >
<TabsList className="grid w-full grid-cols-3"> <TabsList className="grid w-full grid-cols-4">
{pages.map((page) => ( {pages.map((page) => (
<TabsTrigger <TabsTrigger
key={page.name} key={page.name}

View File

@ -0,0 +1,86 @@
// NOTE: this is WIP for doing client side validation for civitai model downloading
import type { AutoFormInputComponentProps } from "../ui/auto-form/types";
import { FormControl, FormItem, FormLabel } from "../ui/form";
import { LoadingIcon } from "@/components/LoadingIcon";
import * as React from "react";
import AutoFormInput from "../ui/auto-form/fields/input";
import { useDebouncedCallback } from "use-debounce";
import { CivitaiModelResponse } from "@/types/civitai";
import { z } from "zod";
import { insertCivitaiCheckpointSchema } from "@/db/schema";
function getUrl(civitai_url: string) {
// expect to be a URL to be https://civitai.com/models/36520
// possiblity with slugged name and query-param modelVersionId
const baseUrl = "https://civitai.com/api/v1/models/";
const url = new URL(civitai_url);
const pathSegments = url.pathname.split("/");
const modelId = pathSegments[pathSegments.indexOf("models") + 1];
const modelVersionId = url.searchParams.get("modelVersionId");
return { url: baseUrl + modelId, modelVersionId };
}
export default function AutoFormCheckpointInput(
props: AutoFormInputComponentProps
) {
const [loading, setLoading] = React.useState(false);
const [modelRes, setModelRes] =
React.useState<z.infer<typeof CivitaiModelResponse>>();
const [modelVersionid, setModelVersionId] = React.useState<string | null>();
const { label, isRequired, fieldProps, zodItem, fieldConfigItem } = props;
const handleSearch = useDebouncedCallback((search) => {
const validationResult =
insertCivitaiCheckpointSchema.shape.civitai_url.safeParse(search);
if (!validationResult.success) {
console.error(validationResult.error);
// Optionally set an error state here
return;
}
setLoading(true);
const controller = new AbortController();
const { url, modelVersionId: versionId } = getUrl(search);
setModelVersionId(versionId);
fetch(url, {
signal: controller.signal,
})
.then((x) => x.json())
.then((a) => {
const res = CivitaiModelResponse.parse(a);
console.log(a);
console.log(res);
setModelRes(res);
setLoading(false);
});
return () => {
controller.abort();
setLoading(false);
};
}, 300);
const modifiedField = {
...fieldProps,
// onChange: (event: React.ChangeEvent<HTMLInputElement>) => {
// handleSearch(event.target.value);
// },
};
return (
<FormItem>
{fieldConfigItem.inputProps?.showLabel && (
<FormLabel>
{label}
{isRequired && <span className="text-destructive">*</span>}
</FormLabel>
)}
<FormControl>
<AutoFormInput {...props} fieldProps={modifiedField} />
</FormControl>
</FormItem>
);
}

View File

@ -1,3 +1,4 @@
import { CivitaiModelResponse } from "@/types/civitai";
import { type InferSelectModel, relations } from "drizzle-orm"; import { type InferSelectModel, relations } from "drizzle-orm";
import { import {
boolean, boolean,
@ -92,7 +93,7 @@ export const workflowVersionRelations = relations(
fields: [workflowVersionTable.workflow_id], fields: [workflowVersionTable.workflow_id],
references: [workflowTable.id], references: [workflowTable.id],
}), }),
}), })
); );
export const workflowRunStatus = pgEnum("workflow_run_status", [ export const workflowRunStatus = pgEnum("workflow_run_status", [
@ -141,7 +142,7 @@ export const workflowRunsTable = dbSchema.table("workflow_runs", {
() => workflowVersionTable.id, () => workflowVersionTable.id,
{ {
onDelete: "set null", onDelete: "set null",
}, }
), ),
workflow_inputs: workflow_inputs:
jsonb("workflow_inputs").$type<Record<string, string | number>>(), jsonb("workflow_inputs").$type<Record<string, string | number>>(),
@ -181,7 +182,7 @@ export const workflowRunRelations = relations(
fields: [workflowRunsTable.workflow_id], fields: [workflowRunsTable.workflow_id],
references: [workflowTable.id], references: [workflowTable.id],
}), }),
}), })
); );
// We still want to keep the workflow run record. // We still want to keep the workflow run record.
@ -205,7 +206,7 @@ export const workflowOutputRelations = relations(
fields: [workflowRunOutputs.run_id], fields: [workflowRunOutputs.run_id],
references: [workflowRunsTable.id], references: [workflowRunsTable.id],
}), }),
}), })
); );
// when user delete, also delete all the workflow versions // when user delete, also delete all the workflow versions
@ -238,7 +239,7 @@ export const snapshotType = z.object({
z.object({ z.object({
hash: z.string(), hash: z.string(),
disabled: z.boolean(), disabled: z.boolean(),
}), })
), ),
file_custom_nodes: z.array(z.any()), file_custom_nodes: z.array(z.any()),
}); });
@ -253,7 +254,7 @@ export const showcaseMedia = z.array(
z.object({ z.object({
url: z.string(), url: z.string(),
isCover: z.boolean().default(false), isCover: z.boolean().default(false),
}), })
); );
export const showcaseMediaNullable = z export const showcaseMediaNullable = z
@ -261,7 +262,7 @@ export const showcaseMediaNullable = z
z.object({ z.object({
url: z.string(), url: z.string(),
isCover: z.boolean().default(false), isCover: z.boolean().default(false),
}), })
) )
.nullable(); .nullable();
@ -363,6 +364,89 @@ export const authRequestsTable = dbSchema.table("auth_requests", {
updated_at: timestamp("updated_at").defaultNow().notNull(), updated_at: timestamp("updated_at").defaultNow().notNull(),
}); });
export const resourceUpload = pgEnum("resource_upload", [
"started",
"success",
"failed",
]);
export const modelUploadType = pgEnum("model_upload_type", [
"civitai",
"huggingface",
"other",
]);
export const checkpointTable = dbSchema.table("checkpoints", {
id: uuid("id").primaryKey().defaultRandom().notNull(),
user_id: text("user_id").references(() => usersTable.id, {}), // perhaps a "special" user_id for global checkpoints
org_id: text("org_id"),
description: text("description"),
checkpoint_volume_id: uuid("checkpoint_volume_id")
.notNull()
.references(() => checkpointVolumeTable.id, {
onDelete: "cascade",
})
.notNull(),
model_name: text("model_name"),
folder_path: text("folder_path"), // in volume
civitai_id: text("civitai_id"),
civitai_version_id: text("civitai_version_id"),
civitai_url: text("civitai_url"),
civitai_download_url: text("civitai_download_url"),
civitai_model_response: jsonb("civitai_model_response").$type<
z.infer<typeof CivitaiModelResponse>
>(),
hf_url: text("hf_url"),
s3_url: text("s3_url"),
user_url: text("client_url"),
is_public: boolean("is_public").notNull().default(false),
status: resourceUpload("status").notNull().default("started"),
upload_machine_id: text("upload_machine_id"), // TODO: review if actually needed
upload_type: modelUploadType("upload_type").notNull(),
error_log: text("error_log"),
created_at: timestamp("created_at").defaultNow().notNull(),
updated_at: timestamp("updated_at").defaultNow().notNull(),
});
export const checkpointVolumeTable = dbSchema.table("checkpoint_volume", {
id: uuid("id").primaryKey().defaultRandom().notNull(),
user_id: text("user_id").references(() => usersTable.id, {
// onDelete: "cascade",
}),
org_id: text("org_id"),
volume_name: text("volume_name").notNull(),
created_at: timestamp("created_at").defaultNow().notNull(),
updated_at: timestamp("updated_at").defaultNow().notNull(),
disabled: boolean("disabled").default(false).notNull(),
});
export const checkpointRelations = relations(checkpointTable, ({ one }) => ({
user: one(usersTable, {
fields: [checkpointTable.user_id],
references: [usersTable.id],
}),
volume: one(checkpointVolumeTable, {
fields: [checkpointTable.checkpoint_volume_id],
references: [checkpointVolumeTable.id],
}),
}));
export const checkpointVolumeRelations = relations(
checkpointVolumeTable,
({ many, one }) => ({
checkpoint: many(checkpointTable),
user: one(usersTable, {
fields: [checkpointVolumeTable.user_id],
references: [usersTable.id],
}),
})
);
export const subscriptionPlan = pgEnum("subscription_plan", [ export const subscriptionPlan = pgEnum("subscription_plan", [
"basic", "basic",
"pro", "pro",
@ -389,9 +473,26 @@ export const subscriptionStatusTable = dbSchema.table("subscription_status", {
updated_at: timestamp("updated_at").defaultNow().notNull(), updated_at: timestamp("updated_at").defaultNow().notNull(),
}); });
export const insertCivitaiCheckpointSchema = createInsertSchema(
checkpointTable,
{
civitai_url: (schema) =>
schema.civitai_url
.trim()
.url({ message: "URL required" })
.includes("civitai.com/models", {
message: "civitai.com/models link required",
}),
}
);
export type UserType = InferSelectModel<typeof usersTable>; export type UserType = InferSelectModel<typeof usersTable>;
export type WorkflowType = InferSelectModel<typeof workflowTable>; export type WorkflowType = InferSelectModel<typeof workflowTable>;
export type MachineType = InferSelectModel<typeof machinesTable>; export type MachineType = InferSelectModel<typeof machinesTable>;
export type WorkflowVersionType = InferSelectModel<typeof workflowVersionTable>; export type WorkflowVersionType = InferSelectModel<typeof workflowVersionTable>;
export type DeploymentType = InferSelectModel<typeof deploymentsTable>; export type DeploymentType = InferSelectModel<typeof deploymentsTable>;
export type CheckpointType = InferSelectModel<typeof checkpointTable>;
export type CheckpointVolumeType = InferSelectModel<
typeof checkpointVolumeTable
>;
export type UserUsageType = InferSelectModel<typeof userUsageTable>; export type UserUsageType = InferSelectModel<typeof userUsageTable>;

View File

@ -0,0 +1,5 @@
import { insertCivitaiCheckpointSchema } from "@/db/schema";
export const addCivitaiCheckpointSchema = insertCivitaiCheckpointSchema.pick({
civitai_url: true,
});

View File

@ -0,0 +1,271 @@
"use server";
import { auth } from "@clerk/nextjs";
import {
checkpointTable,
CheckpointType,
checkpointVolumeTable,
CheckpointVolumeType,
} from "@/db/schema";
import { withServerPromise } from "./withServerPromise";
import { db } from "@/db/db";
import type { z } from "zod";
import { headers } from "next/headers";
import { addCivitaiCheckpointSchema } from "./addCheckpointSchema";
import { and, eq, isNull } from "drizzle-orm";
import { CivitaiModelResponse } from "@/types/civitai";
export async function getCheckpoints() {
const { userId, orgId } = auth();
if (!userId) throw new Error("No user id");
const checkpoints = await db
.select()
.from(checkpointTable)
.where(
orgId
? eq(checkpointTable.org_id, orgId)
// make sure org_id is null
: and(
eq(checkpointTable.user_id, userId),
isNull(checkpointTable.org_id),
),
);
return checkpoints;
}
export async function getCheckpointById(id: string) {
const { userId, orgId } = auth();
if (!userId) throw new Error("No user id");
const checkpoint = await db
.select()
.from(checkpointTable)
.where(
and(
orgId ? eq(checkpointTable.org_id, orgId) : and(
eq(checkpointTable.user_id, userId),
isNull(checkpointTable.org_id),
),
eq(checkpointTable.id, id),
),
);
return checkpoint[0];
}
export async function getCheckpointVolumes() {
const { userId, orgId } = auth();
if (!userId) throw new Error("No user id");
const volume = await db
.select()
.from(checkpointVolumeTable)
.where(
and(
orgId
? eq(checkpointVolumeTable.org_id, orgId)
// make sure org_id is null
: and(
eq(checkpointVolumeTable.user_id, userId),
isNull(checkpointVolumeTable.org_id),
),
eq(checkpointVolumeTable.disabled, false),
),
);
return volume;
}
export async function retrieveCheckpointVolumes() {
let volumes = await getCheckpointVolumes();
if (volumes.length === 0) {
// create volume if not already created
volumes = await addCheckpointVolume();
}
return volumes;
}
export async function addCheckpointVolume() {
const { userId, orgId } = auth();
if (!userId) throw new Error("No user id");
// Insert the new checkpointVolume into the checkpointVolumeTable
const insertedVolume = await db
.insert(checkpointVolumeTable)
.values({
user_id: userId,
org_id: orgId,
volume_name: `checkpoints_${userId}`,
// created_at and updated_at will be set to current timestamp by default
disabled: false, // Default value
})
.returning(); // Returns the inserted row
return insertedVolume;
}
function getUrl(civitai_url: string) {
// expect to be a URL to be https://civitai.com/models/36520
// possiblity with slugged name and query-param modelVersionId
const baseUrl = "https://civitai.com/api/v1/models/";
const url = new URL(civitai_url);
const pathSegments = url.pathname.split("/");
const modelId = pathSegments[pathSegments.indexOf("models") + 1];
const modelVersionId = url.searchParams.get("modelVersionId");
return { url: baseUrl + modelId, modelVersionId };
}
export const addCivitaiCheckpoint = withServerPromise(
async (data: z.infer<typeof addCivitaiCheckpointSchema>) => {
const { userId, orgId } = auth();
if (!data.civitai_url) return { error: "no civitai_url" };
if (!userId) return { error: "No user id" };
const { url, modelVersionId } = getUrl(data?.civitai_url);
const civitaiModelRes = await fetch(url)
.then((x) => x.json())
.then((a) => {
console.log(a);
return CivitaiModelResponse.parse(a);
});
if (civitaiModelRes?.modelVersions?.length === 0) {
return; // no versions to download
}
let selectedModelVersion;
let selectedModelVersionId: string | null = modelVersionId;
if (!selectedModelVersionId) {
selectedModelVersion = civitaiModelRes.modelVersions[0];
selectedModelVersionId = civitaiModelRes.modelVersions[0].id.toString();
} else {
selectedModelVersion = civitaiModelRes.modelVersions.find((version) =>
version.id.toString() === selectedModelVersionId
);
if (!selectedModelVersion) {
return; // version id is wrong
}
selectedModelVersionId = selectedModelVersion?.id.toString();
}
const checkpointVolumes = await getCheckpointVolumes();
let cVolume;
if (checkpointVolumes.length === 0) {
const volume = await addCheckpointVolume();
cVolume = volume[0];
} else {
cVolume = checkpointVolumes[0];
}
const a = await db
.insert(checkpointTable)
.values({
user_id: userId,
org_id: orgId,
upload_type: "civitai",
model_name: selectedModelVersion.files[0].name,
civitai_id: civitaiModelRes.id.toString(),
civitai_version_id: selectedModelVersionId,
civitai_url: data.civitai_url,
civitai_download_url: selectedModelVersion.files[0].downloadUrl,
civitai_model_response: civitaiModelRes,
checkpoint_volume_id: cVolume.id,
updated_at: new Date(),
})
.returning();
const b = a[0];
await uploadCheckpoint(data, b, cVolume);
// redirect(`/checkpoints/${b.id}`);
},
);
// export const redownloadCheckpoint = withServerPromise(
// async (data: CheckpointItemList) => {
// const { userId } = auth();
// if (!userId) return { error: "No user id" };
//
// const checkpointVolumes = await getCheckpointVolumes();
// let cVolume;
// if (checkpointVolumes.length === 0) {
// const volume = await addCheckpointVolume();
// cVolume = volume[0];
// } else {
// cVolume = checkpointVolumes[0];
// }
//
// console.log("data");
// console.log(data);
//
// const a = await db
// .update(checkpointTable)
// .set({
// // status: "started",
// // updated_at: new Date(),
// })
// .returning();
//
// const b = a[0];
//
// console.log("b");
// console.log(b);
//
// await uploadCheckpoint(data, b, cVolume);
// // redirect(`/checkpoints/${b.id}`);
// },
// );
async function uploadCheckpoint(
data: z.infer<typeof addCivitaiCheckpointSchema>,
c: CheckpointType,
v: CheckpointVolumeType,
) {
const headersList = headers();
const domain = headersList.get("x-forwarded-host") || "";
const protocol = headersList.get("x-forwarded-proto") || "";
if (domain === "") {
throw new Error("No domain");
}
// Call remote builder
const result = await fetch(
`${process.env.MODAL_BUILDER_URL!}/upload-volume`,
{
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
download_url: c.civitai_download_url,
volume_name: v.volume_name,
volume_id: v.id,
checkpoint_id: c.id,
callback_url: `${protocol}://${domain}/api/volume-upload`,
upload_type: "checkpoint"
}),
},
);
if (!result.ok) {
const error_log = await result.text();
await db
.update(checkpointTable)
.set({
...data,
status: "failed",
error_log: error_log,
})
.where(eq(checkpointTable.id, c.id));
throw new Error(`Error: ${result.statusText} ${error_log}`);
} else {
// setting the build machine id
const json = await result.json();
await db
.update(checkpointTable)
.set({
...data,
upload_machine_id: json.build_machine_instance_id,
})
.where(eq(checkpointTable.id, c.id));
}
}

View File

@ -15,6 +15,7 @@ import { headers } from "next/headers";
import { redirect } from "next/navigation"; import { redirect } from "next/navigation";
import "server-only"; import "server-only";
import type { z } from "zod"; import type { z } from "zod";
import { retrieveCheckpointVolumes } from "./curdCheckpoint";
export async function getMachines() { export async function getMachines() {
const { userId, orgId } = auth(); const { userId, orgId } = auth();
@ -189,6 +190,7 @@ async function _buildMachine(
throw new Error("No domain"); throw new Error("No domain");
} }
const volumes = await retrieveCheckpointVolumes();
// Call remote builder // Call remote builder
const result = await fetch(`${process.env.MODAL_BUILDER_URL!}/create`, { const result = await fetch(`${process.env.MODAL_BUILDER_URL!}/create`, {
method: "POST", method: "POST",
@ -202,6 +204,7 @@ async function _buildMachine(
callback_url: `${protocol}://${domain}/api/machine-built`, callback_url: `${protocol}://${domain}/api/machine-built`,
models: data.models, //JSON.parse(data.models as string), models: data.models, //JSON.parse(data.models as string),
gpu: data.gpu && data.gpu.length > 0 ? data.gpu : "T4", gpu: data.gpu && data.gpu.length > 0 ? data.gpu : "T4",
checkpoint_volume_name: volumes[0].volume_name,
}), }),
}); });

View File

@ -0,0 +1,41 @@
import { db } from "@/db/db";
import {
checkpointTable,
} from "@/db/schema";
import { auth } from "@clerk/nextjs";
import { and, desc, eq, isNull } from "drizzle-orm";
export async function getAllUserCheckpoints() {
const { userId, orgId } = await auth();
if (!userId) {
return null;
}
const checkpoints = await db.query.checkpointTable.findMany({
with: {
user: {
columns: {
name: true,
},
},
},
columns: {
id: true,
updated_at: true,
model_name: true,
civitai_url: true,
civitai_model_response: true,
is_public: true,
upload_type: true,
status: true,
},
orderBy: desc(checkpointTable.updated_at),
where:
orgId != undefined
? eq(checkpointTable.org_id, orgId)
: and(eq(checkpointTable.user_id, userId), isNull(checkpointTable.org_id)),
});
return checkpoints;
}

129
web/src/types/civitai.ts Normal file
View File

@ -0,0 +1,129 @@
import { z } from "zod";
// from chatgpt https://chat.openai.com/share/4985d20b-30b1-4a28-87f6-6ebf84a1040e
export const creatorSchema = z.object({
username: z.string().nullish(),
image: z.string().url().nullish(),
});
export const fileMetadataSchema = z.object({
fp: z.string().nullish(),
size: z.string().nullish(),
format: z.string().nullish(),
});
export const fileSchema = z.object({
id: z.number(),
sizeKB: z.number().nullish(),
name: z.string(),
type: z.string().nullish(),
metadata: fileMetadataSchema.nullish(),
pickleScanResult: z.string().nullish(),
pickleScanMessage: z.string().nullable(),
virusScanResult: z.string().nullish(),
virusScanMessage: z.string().nullable(),
scannedAt: z.string().nullish(),
hashes: z.record(z.string()).nullish(),
downloadUrl: z.string().url(),
primary: z.boolean().nullish(),
});
export const imageMetadataSchema = z.object({
hash: z.string(),
width: z.number(),
height: z.number(),
});
export const imageMetaSchema = z.object({
ENSD: z.string().nullish(),
Size: z.string().nullish(),
seed: z.number().nullish(),
Model: z.string().nullish(),
steps: z.number().nullish(),
hashes: z.record(z.string()).nullish(),
prompt: z.string().nullish(),
sampler: z.string().nullish(),
cfgScale: z.number().nullish(),
ClipSkip: z.number().nullish(),
resources: z.array(
z.object({
hash: z.string().nullish(),
name: z.string(),
type: z.string(),
weight: z.number().nullish(),
}),
).nullish(),
ModelHash: z.string().nullish(),
HiresSteps: z.string().nullish(),
HiresUpscale: z.string().nullish(),
HiresUpscaler: z.string().nullish(),
negativePrompt: z.string(),
DenoisingStrength: z.number().nullish(),
});
// NOTE: this definition is all over the place
// export const imageSchema = z.object({
// url: z.string().url().nullish(),
// nsfw: z.enum(["None", "Soft", "Mature"]).nullish(),
// width: z.number().nullish(),
// height: z.number().nullish(),
// hash: z.string().nullish(),
// type: z.string().nullish(),
// metadata: imageMetadataSchema.nullish(),
// meta: imageMetaSchema.nullish(),
// });
export const modelVersionSchema = z.object({
id: z.number(),
modelId: z.number(),
name: z.string(),
createdAt: z.string().nullish(),
updatedAt: z.string().nullish(),
// status: z.enum(["Published", "Unpublished"]).nullish(),
status: z.string().nullish(),
publishedAt: z.string().nullish(),
trainedWords: z.array(z.string()).nullable(),
trainingStatus: z.string().nullable(),
trainingDetails: z.string().nullable(),
baseModel: z.string().nullish(),
baseModelType: z.string().nullish(),
earlyAccessTimeFrame: z.number().nullish(),
description: z.string().nullable(),
vaeId: z.number().nullable(),
stats: z.object({
downloadCount: z.number(),
ratingCount: z.number(),
rating: z.number(),
}).nullish(),
files: z.array(fileSchema),
images: z.array(z.any()).nullish(),
downloadUrl: z.string().url(),
});
export const statsSchema = z.object({
downloadCount: z.number(),
favoriteCount: z.number(),
commentCount: z.number(),
ratingCount: z.number(),
rating: z.number(),
tippedAmountCount: z.number(),
});
export const CivitaiModelResponse = z.object({
id: z.number(),
name: z.string().nullish(),
description: z.string().nullish(),
// type: z.enum(["Checkpoint", "Lora"]), // TODO: this will be important to know
type: z.string(),
poi: z.boolean().nullish(),
nsfw: z.boolean().nullish(),
allowNoCredit: z.boolean().nullish(),
allowCommercialUse: z.string().nullish(),
allowDerivatives: z.boolean().nullish(),
allowDifferentLicense: z.boolean().nullish(),
stats: statsSchema.nullish(),
creator: creatorSchema.nullish(),
tags: z.array(z.string()).nullish(),
modelVersions: z.array(modelVersionSchema),
});