Squashed commit of the following:
commit c36b0ec0b374dd8ccbee3a6044ee7e3f1fefe368 Author: Nicholas Koben Kao <kobenkao@gmail.com> Date: Thu Jan 25 17:54:54 2024 -0800 nits on wording and removing link to broken storage/:id page commit 0777fdcf7b0002244bc713199d3d64eea6b6061e Author: Nicholas Koben Kao <kobenkao@gmail.com> Date: Thu Jan 25 17:23:55 2024 -0800 builder update config and such commit 958b795bb2b6ac27ce33c5729ef265b068420e1a Author: Nicholas Koben Kao <kobenkao@gmail.com> Date: Thu Jan 25 17:23:43 2024 -0800 rename all from checkponit to model commit 7a9c5636e73bd005499b141a4dd382db5672c962 Author: Nicholas Koben Kao <kobenkao@gmail.com> Date: Thu Jan 25 16:51:59 2024 -0800 rename for consistency commit 48bebbafab9a95388817df97c15f8ea97e0fea75 Author: Nicholas Koben Kao <kobenkao@gmail.com> Date: Thu Jan 25 16:18:36 2024 -0800 bulider commit 81dacd9af457886f2f027994d225a7748c738abb Author: Nicholas Koben Kao <kobenkao@gmail.com> Date: Thu Jan 25 16:17:56 2024 -0800 different types of models
This commit is contained in:
parent
62a69dba06
commit
85477aba9d
@ -3,4 +3,5 @@ MODAL_TOKEN_SECRET=
|
|||||||
CIVITAI_API_KEY=
|
CIVITAI_API_KEY=
|
||||||
|
|
||||||
# On production set to False
|
# On production set to False
|
||||||
DEPLOY_TEST_FLAG=True
|
DEPLOY_TEST_FLAG=True
|
||||||
|
CIVITAI_API_KEY=
|
||||||
|
@ -177,7 +177,7 @@ class Item(BaseModel):
|
|||||||
snapshot: Snapshot
|
snapshot: Snapshot
|
||||||
models: List[Model]
|
models: List[Model]
|
||||||
callback_url: str
|
callback_url: str
|
||||||
checkpoint_volume_name: str
|
model_volume_name: str
|
||||||
gpu: GPUType = Field(default=GPUType.T4)
|
gpu: GPUType = Field(default=GPUType.T4)
|
||||||
|
|
||||||
@field_validator('gpu')
|
@field_validator('gpu')
|
||||||
@ -227,24 +227,31 @@ async def websocket_endpoint(websocket: WebSocket, machine_id: str):
|
|||||||
|
|
||||||
# return {"Hello": "World"}
|
# return {"Hello": "World"}
|
||||||
|
|
||||||
|
# definition based on web schema
|
||||||
class UploadType(str, Enum):
|
class UploadType(str, Enum):
|
||||||
checkpoint = "checkpoint"
|
checkpoint = "checkpoint"
|
||||||
|
lora = "lora"
|
||||||
|
embedding = "embedding"
|
||||||
|
|
||||||
class UploadBody(BaseModel):
|
class UploadBody(BaseModel):
|
||||||
download_url: str
|
download_url: str
|
||||||
volume_name: str
|
volume_name: str
|
||||||
volume_id: str
|
volume_id: str
|
||||||
checkpoint_id: str
|
model_id: str
|
||||||
upload_type: UploadType
|
upload_type: UploadType
|
||||||
callback_url: str
|
callback_url: str
|
||||||
|
|
||||||
|
|
||||||
|
# based on ComfyUI's model dir, and our mappings in ./src/template/data/extra_model_paths.yaml
|
||||||
UPLOAD_TYPE_DIR_MAP = {
|
UPLOAD_TYPE_DIR_MAP = {
|
||||||
UploadType.checkpoint: "checkpoints"
|
UploadType.checkpoint: "checkpoints",
|
||||||
|
UploadType.lora: "loras",
|
||||||
|
UploadType.embedding: "embeddings",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@app.post("/upload-volume")
|
@app.post("/upload-volume")
|
||||||
async def upload_checkpoint(body: UploadBody):
|
async def upload_model(body: UploadBody):
|
||||||
global last_activity_time
|
global last_activity_time
|
||||||
last_activity_time = time.time()
|
last_activity_time = time.time()
|
||||||
logger.info(f"Extended inactivity time to {global_timeout}")
|
logger.info(f"Extended inactivity time to {global_timeout}")
|
||||||
@ -254,6 +261,7 @@ async def upload_checkpoint(body: UploadBody):
|
|||||||
# check that this
|
# check that this
|
||||||
return JSONResponse(status_code=200, content={"message": "Volume uploading", "build_machine_instance_id": fly_instance_id})
|
return JSONResponse(status_code=200, content={"message": "Volume uploading", "build_machine_instance_id": fly_instance_id})
|
||||||
|
|
||||||
|
|
||||||
async def upload_logic(body: UploadBody):
|
async def upload_logic(body: UploadBody):
|
||||||
folder_path = f"/app/builds/{body.volume_id}"
|
folder_path = f"/app/builds/{body.volume_id}"
|
||||||
|
|
||||||
@ -270,7 +278,7 @@ async def upload_logic(body: UploadBody):
|
|||||||
},
|
},
|
||||||
"callback_url": body.callback_url,
|
"callback_url": body.callback_url,
|
||||||
"callback_body": {
|
"callback_body": {
|
||||||
"checkpoint_id": body.checkpoint_id,
|
"model_id": body.model_id,
|
||||||
"volume_id": body.volume_id,
|
"volume_id": body.volume_id,
|
||||||
"folder_path": upload_path,
|
"folder_path": upload_path,
|
||||||
},
|
},
|
||||||
@ -279,51 +287,11 @@ async def upload_logic(body: UploadBody):
|
|||||||
with open(f"{folder_path}/config.py", "w") as f:
|
with open(f"{folder_path}/config.py", "w") as f:
|
||||||
f.write("config = " + json.dumps(config))
|
f.write("config = " + json.dumps(config))
|
||||||
|
|
||||||
process = await asyncio.subprocess.create_subprocess_shell(
|
await asyncio.subprocess.create_subprocess_shell(
|
||||||
f"modal run app.py",
|
f"modal run app.py",
|
||||||
# stdout=asyncio.subprocess.PIPE,
|
|
||||||
# stderr=asyncio.subprocess.PIPE,
|
|
||||||
cwd=folder_path,
|
cwd=folder_path,
|
||||||
env={**os.environ, "COLUMNS": "10000"}
|
env={**os.environ, "COLUMNS": "10000"}
|
||||||
)
|
)
|
||||||
|
|
||||||
# error_logs = []
|
|
||||||
# async def read_stream(stream):
|
|
||||||
# while True:
|
|
||||||
# line = await stream.readline()
|
|
||||||
# if line:
|
|
||||||
# l = line.decode('utf-8').strip()
|
|
||||||
# error_logs.append(l)
|
|
||||||
# logger.error(l)
|
|
||||||
# error_logs.append({
|
|
||||||
# "logs": l,
|
|
||||||
# "timestamp": time.time()
|
|
||||||
# })
|
|
||||||
# else:
|
|
||||||
# break
|
|
||||||
|
|
||||||
# stderr_read_task = asyncio.create_task(read_stream(process.stderr))
|
|
||||||
#
|
|
||||||
# await asyncio.wait([stderr_read_task])
|
|
||||||
# await process.wait()
|
|
||||||
|
|
||||||
# if process.returncode != 0:
|
|
||||||
# error_logs.append({"logs": "Unable to upload volume.", "timestamp": time.time()})
|
|
||||||
# # Error handling: send POST request to callback URL with error details
|
|
||||||
# requests.post(body.callback_url, json={
|
|
||||||
# "volume_id": body.volume_id,
|
|
||||||
# "checkpoint_id": body.checkpoint_id,
|
|
||||||
# "folder_path": upload_path,
|
|
||||||
# "error_logs": json.dumps(error_logs),
|
|
||||||
# "status": "failed"
|
|
||||||
# })
|
|
||||||
#
|
|
||||||
# requests.post(body.callback_url, json={
|
|
||||||
# "checkpoint_id": body.checkpoint_id,
|
|
||||||
# "volume_id": body.volume_id,
|
|
||||||
# "folder_path": upload_path,
|
|
||||||
# "status": "success"
|
|
||||||
# })
|
|
||||||
|
|
||||||
@app.post("/create")
|
@app.post("/create")
|
||||||
async def create_machine(item: Item):
|
async def create_machine(item: Item):
|
||||||
@ -414,8 +382,8 @@ async def build_logic(item: Item):
|
|||||||
"name": item.name,
|
"name": item.name,
|
||||||
"deploy_test": os.environ.get("DEPLOY_TEST_FLAG", "False"),
|
"deploy_test": os.environ.get("DEPLOY_TEST_FLAG", "False"),
|
||||||
"gpu": item.gpu,
|
"gpu": item.gpu,
|
||||||
"public_checkpoint_volume": "model-store",
|
"public_model_volume": "model-store",
|
||||||
"private_checkpoint_volume": item.checkpoint_volume_name
|
"private_model_volume": item.model_volume_name
|
||||||
}
|
}
|
||||||
with open(f"{folder_path}/config.py", "w") as f:
|
with open(f"{folder_path}/config.py", "w") as f:
|
||||||
f.write("config = " + json.dumps(config))
|
f.write("config = " + json.dumps(config))
|
||||||
|
@ -2,6 +2,6 @@ config = {
|
|||||||
"name": "my-app",
|
"name": "my-app",
|
||||||
"deploy_test": "True",
|
"deploy_test": "True",
|
||||||
"gpu": "T4",
|
"gpu": "T4",
|
||||||
"public_checkpoint_volume": "model-store",
|
"public_model_volume": "model-store",
|
||||||
"private_checkpoint_volume": "private-model-store"
|
"private_model_volume": "private-model-store"
|
||||||
}
|
}
|
||||||
|
@ -13,3 +13,11 @@ public:
|
|||||||
private:
|
private:
|
||||||
base_path: /private_models/
|
base_path: /private_models/
|
||||||
checkpoints: checkpoints
|
checkpoints: checkpoints
|
||||||
|
clip: clip
|
||||||
|
clip_vision: clip_vision
|
||||||
|
configs: configs
|
||||||
|
controlnet: controlnet
|
||||||
|
embeddings: embeddings
|
||||||
|
loras: loras
|
||||||
|
upscale_models: upscale_models
|
||||||
|
vae: vae
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
import modal
|
import modal
|
||||||
from config import config
|
from config import config
|
||||||
|
|
||||||
public_model_volume = modal.Volume.persisted(config["public_checkpoint_volume"])
|
public_model_volume = modal.Volume.persisted(config["public_model_volume"])
|
||||||
private_volume = modal.Volume.persisted(config["private_checkpoint_volume"])
|
private_volume = modal.Volume.persisted(config["private_model_volume"])
|
||||||
|
|
||||||
PUBLIC_BASEMODEL_DIR = "/public_models"
|
PUBLIC_BASEMODEL_DIR = "/public_models"
|
||||||
PRIVATE_BASEMODEL_DIR = "/private_models"
|
PRIVATE_BASEMODEL_DIR = "/private_models"
|
||||||
|
@ -1,18 +1,18 @@
|
|||||||
config = {
|
config = {
|
||||||
"volume_names": {
|
"volume_names": {
|
||||||
"test": {
|
"user4": {
|
||||||
"download_url": "https://pub-6230db03dc3a4861a9c3e55145ceda44.r2.dev/openpose-pose (1).png",
|
"download_url": "https://civitai.com/api/download/models/11745",
|
||||||
"folder_path": "images"
|
"folder_path": "checkpoints"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"volume_paths": {
|
"volume_paths": {
|
||||||
"test": "/volumes/something"
|
"user4": "/volumes/something",
|
||||||
},
|
},
|
||||||
"callback_url": "",
|
"callback_url": "",
|
||||||
"callback_body": {
|
"callback_body": {
|
||||||
"checkpoint_id": "",
|
"model_id": "",
|
||||||
"volume_id": "",
|
"volume_id": "",
|
||||||
"folder_path": "images",
|
"folder_path": "checkpoints",
|
||||||
},
|
},
|
||||||
"civitai_api_key": "",
|
"civitai_api_key": "",
|
||||||
}
|
}
|
||||||
|
7
web/drizzle/0043_dapper_santa_claus.sql
Normal file
7
web/drizzle/0043_dapper_santa_claus.sql
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
DO $$ BEGIN
|
||||||
|
CREATE TYPE "model_type" AS ENUM('checkpoint', 'lora', 'embedding', 'vae');
|
||||||
|
EXCEPTION
|
||||||
|
WHEN duplicate_object THEN null;
|
||||||
|
END $$;
|
||||||
|
--> statement-breakpoint
|
||||||
|
ALTER TABLE "comfyui_deploy"."checkpoints" ADD COLUMN "model_type" "model_type" NOT NULL;
|
26
web/drizzle/0044_married_malcolm_colcord.sql
Normal file
26
web/drizzle/0044_married_malcolm_colcord.sql
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
ALTER TABLE "comfyui_deploy"."checkpoints" RENAME TO "models";--> statement-breakpoint
|
||||||
|
ALTER TABLE "comfyui_deploy"."checkpoint_volume" RENAME TO "user_volume";--> statement-breakpoint
|
||||||
|
ALTER TABLE "comfyui_deploy"."models" RENAME COLUMN "checkpoint_volume_id" TO "user_volume_id";--> statement-breakpoint
|
||||||
|
ALTER TABLE "comfyui_deploy"."models" DROP CONSTRAINT "checkpoints_user_id_users_id_fk";
|
||||||
|
--> statement-breakpoint
|
||||||
|
ALTER TABLE "comfyui_deploy"."models" DROP CONSTRAINT "checkpoints_checkpoint_volume_id_checkpoint_volume_id_fk";
|
||||||
|
--> statement-breakpoint
|
||||||
|
ALTER TABLE "comfyui_deploy"."user_volume" DROP CONSTRAINT "checkpoint_volume_user_id_users_id_fk";
|
||||||
|
--> statement-breakpoint
|
||||||
|
DO $$ BEGIN
|
||||||
|
ALTER TABLE "comfyui_deploy"."models" ADD CONSTRAINT "models_user_id_users_id_fk" FOREIGN KEY ("user_id") REFERENCES "comfyui_deploy"."users"("id") ON DELETE no action ON UPDATE no action;
|
||||||
|
EXCEPTION
|
||||||
|
WHEN duplicate_object THEN null;
|
||||||
|
END $$;
|
||||||
|
--> statement-breakpoint
|
||||||
|
DO $$ BEGIN
|
||||||
|
ALTER TABLE "comfyui_deploy"."models" ADD CONSTRAINT "models_user_volume_id_user_volume_id_fk" FOREIGN KEY ("user_volume_id") REFERENCES "comfyui_deploy"."user_volume"("id") ON DELETE cascade ON UPDATE no action;
|
||||||
|
EXCEPTION
|
||||||
|
WHEN duplicate_object THEN null;
|
||||||
|
END $$;
|
||||||
|
--> statement-breakpoint
|
||||||
|
DO $$ BEGIN
|
||||||
|
ALTER TABLE "comfyui_deploy"."user_volume" ADD CONSTRAINT "user_volume_user_id_users_id_fk" FOREIGN KEY ("user_id") REFERENCES "comfyui_deploy"."users"("id") ON DELETE no action ON UPDATE no action;
|
||||||
|
EXCEPTION
|
||||||
|
WHEN duplicate_object THEN null;
|
||||||
|
END $$;
|
1288
web/drizzle/meta/0043_snapshot.json
Normal file
1288
web/drizzle/meta/0043_snapshot.json
Normal file
File diff suppressed because it is too large
Load Diff
1293
web/drizzle/meta/0044_snapshot.json
Normal file
1293
web/drizzle/meta/0044_snapshot.json
Normal file
File diff suppressed because it is too large
Load Diff
@ -302,6 +302,20 @@
|
|||||||
"when": 1706164614659,
|
"when": 1706164614659,
|
||||||
"tag": "0042_windy_madelyne_pryor",
|
"tag": "0042_windy_madelyne_pryor",
|
||||||
"breakpoints": true
|
"breakpoints": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"idx": 43,
|
||||||
|
"version": "5",
|
||||||
|
"when": 1706225960550,
|
||||||
|
"tag": "0043_dapper_santa_claus",
|
||||||
|
"breakpoints": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"idx": 44,
|
||||||
|
"version": "5",
|
||||||
|
"when": 1706230304140,
|
||||||
|
"tag": "0044_married_malcolm_colcord",
|
||||||
|
"breakpoints": true
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
@ -1,15 +1,15 @@
|
|||||||
import { parseDataSafe } from "../../../../lib/parseDataSafe";
|
import { parseDataSafe } from "../../../../lib/parseDataSafe";
|
||||||
import { db } from "@/db/db";
|
import { db } from "@/db/db";
|
||||||
import { checkpointTable, machinesTable } from "@/db/schema";
|
import { modelTable } from "@/db/schema";
|
||||||
import { eq } from "drizzle-orm";
|
import { eq } from "drizzle-orm";
|
||||||
import { NextResponse } from "next/server";
|
import { NextResponse } from "next/server";
|
||||||
import { z } from "zod";
|
import { z } from "zod";
|
||||||
|
|
||||||
const Request = z.object({
|
const Request = z.object({
|
||||||
volume_id: z.string(),
|
volume_id: z.string(),
|
||||||
checkpoint_id: z.string(),
|
model_id: z.string(),
|
||||||
folder_path: z.string().optional(),
|
folder_path: z.string().optional(),
|
||||||
status: z.enum(['success', 'failed']),
|
status: z.enum(["success", "failed"]),
|
||||||
error_log: z.string().optional(),
|
error_log: z.string().optional(),
|
||||||
timeout: z.number().optional(),
|
timeout: z.number().optional(),
|
||||||
});
|
});
|
||||||
@ -18,30 +18,30 @@ export async function POST(request: Request) {
|
|||||||
const [data, error] = await parseDataSafe(Request, request);
|
const [data, error] = await parseDataSafe(Request, request);
|
||||||
if (!data || error) return error;
|
if (!data || error) return error;
|
||||||
|
|
||||||
const { checkpoint_id, error_log, status, folder_path } = data;
|
const { model_id, error_log, status, folder_path } = data;
|
||||||
console.log( checkpoint_id, error_log, status, folder_path )
|
console.log(model_id, error_log, status, folder_path);
|
||||||
|
|
||||||
if (status === "success") {
|
if (status === "success") {
|
||||||
await db
|
await db
|
||||||
.update(checkpointTable)
|
.update(modelTable)
|
||||||
.set({
|
.set({
|
||||||
status: "success",
|
status: "success",
|
||||||
folder_path,
|
folder_path,
|
||||||
updated_at: new Date(),
|
updated_at: new Date(),
|
||||||
// build_log: build_log,
|
// build_log: build_log,
|
||||||
})
|
})
|
||||||
.where(eq(checkpointTable.id, checkpoint_id));
|
.where(eq(modelTable.id, model_id));
|
||||||
} else {
|
} else {
|
||||||
await db
|
await db
|
||||||
.update(checkpointTable)
|
.update(modelTable)
|
||||||
.set({
|
.set({
|
||||||
status: "failed",
|
status: "failed",
|
||||||
error_log,
|
error_log,
|
||||||
updated_at: new Date(),
|
updated_at: new Date(),
|
||||||
// status: "error",
|
// status: "error",
|
||||||
// build_log: build_log,
|
// build_log: build_log,
|
||||||
})
|
})
|
||||||
.where(eq(checkpointTable.id, checkpoint_id));
|
.where(eq(modelTable.id, model_id));
|
||||||
}
|
}
|
||||||
|
|
||||||
return NextResponse.json(
|
return NextResponse.json(
|
||||||
@ -50,6 +50,6 @@ export async function POST(request: Request) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
status: 200,
|
status: 200,
|
||||||
}
|
},
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
@ -1,14 +1,14 @@
|
|||||||
import { setInitialUserData } from "../../../lib/setInitialUserData";
|
import { setInitialUserData } from "../../../lib/setInitialUserData";
|
||||||
import { auth } from "@clerk/nextjs";
|
import { auth } from "@clerk/nextjs";
|
||||||
import { clerkClient } from "@clerk/nextjs/server";
|
import { clerkClient } from "@clerk/nextjs/server";
|
||||||
import { CheckpointList } from "@/components/CheckpointList";
|
import { ModelList } from "@/components/ModelList";
|
||||||
import { getAllUserCheckpoints } from "@/server/getAllUserCheckpoints";
|
import { getAllUserModels } from "@/server/getAllUserModel";
|
||||||
|
|
||||||
export default function Page() {
|
export default function Page() {
|
||||||
return <CheckpointListServer />;
|
return <ModelListServer />;
|
||||||
}
|
}
|
||||||
|
|
||||||
async function CheckpointListServer() {
|
async function ModelListServer() {
|
||||||
const { userId } = auth();
|
const { userId } = auth();
|
||||||
|
|
||||||
if (!userId) {
|
if (!userId) {
|
||||||
@ -21,15 +21,15 @@ async function CheckpointListServer() {
|
|||||||
await setInitialUserData(userId);
|
await setInitialUserData(userId);
|
||||||
}
|
}
|
||||||
|
|
||||||
const checkpoints = await getAllUserCheckpoints();
|
const models = await getAllUserModels();
|
||||||
|
|
||||||
if (!checkpoints) {
|
if (!models) {
|
||||||
return <div>No checkpoints found</div>;
|
return <div>No models found</div>;
|
||||||
}
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="w-full">
|
<div className="w-full">
|
||||||
<CheckpointList data={checkpoints} />
|
<ModelList data={models} />
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
@ -4,7 +4,7 @@ import { getRelativeTime } from "../lib/getRelativeTime";
|
|||||||
import { Badge } from "@/components/ui/badge";
|
import { Badge } from "@/components/ui/badge";
|
||||||
import { Button } from "@/components/ui/button";
|
import { Button } from "@/components/ui/button";
|
||||||
import { Checkbox } from "@/components/ui/checkbox";
|
import { Checkbox } from "@/components/ui/checkbox";
|
||||||
import { InsertModal, UpdateModal } from "./InsertModal";
|
import { InsertModal } from "./InsertModal";
|
||||||
import { Input } from "@/components/ui/input";
|
import { Input } from "@/components/ui/input";
|
||||||
import { ScrollArea } from "@/components/ui/scroll-area";
|
import { ScrollArea } from "@/components/ui/scroll-area";
|
||||||
import {
|
import {
|
||||||
@ -15,7 +15,7 @@ import {
|
|||||||
TableHeader,
|
TableHeader,
|
||||||
TableRow,
|
TableRow,
|
||||||
} from "@/components/ui/table";
|
} from "@/components/ui/table";
|
||||||
import type { getAllUserCheckpoints } from "@/server/getAllUserCheckpoints";
|
import type { getAllUserModels as getAllUserModels } from "@/server/getAllUserModel";
|
||||||
import type {
|
import type {
|
||||||
ColumnDef,
|
ColumnDef,
|
||||||
ColumnFiltersState,
|
ColumnFiltersState,
|
||||||
@ -32,23 +32,22 @@ import {
|
|||||||
} from "@tanstack/react-table";
|
} from "@tanstack/react-table";
|
||||||
import { ArrowUpDown } from "lucide-react";
|
import { ArrowUpDown } from "lucide-react";
|
||||||
import * as React from "react";
|
import * as React from "react";
|
||||||
import { addCivitaiCheckpoint } from "@/server/curdCheckpoint";
|
import { addCivitaiModel } from "@/server/curdModel";
|
||||||
import { addCivitaiCheckpointSchema } from "@/server/addCheckpointSchema";
|
import { addCivitaiModelSchema } from "@/server/addCivitaiModelSchema";
|
||||||
|
import { modelEnumType } from "@/db/schema";
|
||||||
|
|
||||||
export type CheckpointItemList = NonNullable<
|
export type ModelItemList = NonNullable<
|
||||||
Awaited<ReturnType<typeof getAllUserCheckpoints>>
|
Awaited<ReturnType<typeof getAllUserModels>>
|
||||||
>[0];
|
>[0];
|
||||||
|
|
||||||
export const columns: ColumnDef<CheckpointItemList>[] = [
|
export const columns: ColumnDef<ModelItemList>[] = [
|
||||||
{
|
{
|
||||||
accessorKey: "id",
|
accessorKey: "id",
|
||||||
id: "select",
|
id: "select",
|
||||||
header: ({ table }) => (
|
header: ({ table }) => (
|
||||||
<Checkbox
|
<Checkbox
|
||||||
checked={
|
checked={table.getIsAllPageRowsSelected() ||
|
||||||
table.getIsAllPageRowsSelected() ||
|
(table.getIsSomePageRowsSelected() && "indeterminate")}
|
||||||
(table.getIsSomePageRowsSelected() && "indeterminate")
|
|
||||||
}
|
|
||||||
onCheckedChange={(value) => table.toggleAllPageRowsSelected(!!value)}
|
onCheckedChange={(value) => table.toggleAllPageRowsSelected(!!value)}
|
||||||
aria-label="Select all"
|
aria-label="Select all"
|
||||||
/>
|
/>
|
||||||
@ -77,22 +76,23 @@ export const columns: ColumnDef<CheckpointItemList>[] = [
|
|||||||
);
|
);
|
||||||
},
|
},
|
||||||
cell: ({ row }) => {
|
cell: ({ row }) => {
|
||||||
const checkpoint = row.original;
|
const model = row.original;
|
||||||
return (
|
return (
|
||||||
<a
|
<>
|
||||||
|
{
|
||||||
|
/*<a
|
||||||
className="hover:underline flex gap-2"
|
className="hover:underline flex gap-2"
|
||||||
href={`/storage/${checkpoint.id}`} // TODO
|
href={`/storage/${model.id}`} // TODO
|
||||||
>
|
>*/
|
||||||
|
}
|
||||||
<span className="truncate max-w-[200px]">
|
<span className="truncate max-w-[200px]">
|
||||||
{row.original.model_name}
|
{row.original.model_name}
|
||||||
</span>
|
</span>
|
||||||
|
|
||||||
{checkpoint.is_public ? (
|
{model.is_public
|
||||||
<Badge variant="green">Public</Badge>
|
? <Badge variant="green">Public</Badge>
|
||||||
) : (
|
: <Badge variant="orange">Private</Badge>}
|
||||||
<Badge variant="orange">Private</Badge>
|
</>
|
||||||
)}
|
|
||||||
</a>
|
|
||||||
);
|
);
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@ -111,7 +111,11 @@ export const columns: ColumnDef<CheckpointItemList>[] = [
|
|||||||
},
|
},
|
||||||
cell: ({ row }) => {
|
cell: ({ row }) => {
|
||||||
return (
|
return (
|
||||||
<Badge variant={row.original.status === "failed" ? "red" : (row.original.status === "started" ? "yellow" : "green")}>
|
<Badge
|
||||||
|
variant={row.original.status === "failed"
|
||||||
|
? "red"
|
||||||
|
: (row.original.status === "started" ? "yellow" : "green")}
|
||||||
|
>
|
||||||
{row.original.status}
|
{row.original.status}
|
||||||
</Badge>
|
</Badge>
|
||||||
);
|
);
|
||||||
@ -167,6 +171,35 @@ export const columns: ColumnDef<CheckpointItemList>[] = [
|
|||||||
return <Badge variant="cyan">{row.original.upload_type}</Badge>;
|
return <Badge variant="cyan">{row.original.upload_type}</Badge>;
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
accessorKey: "model_type",
|
||||||
|
header: ({ column }) => {
|
||||||
|
return (
|
||||||
|
<button
|
||||||
|
className="flex items-center hover:underline"
|
||||||
|
onClick={() => column.toggleSorting(column.getIsSorted() === "asc")}
|
||||||
|
>
|
||||||
|
Model Type
|
||||||
|
<ArrowUpDown className="ml-2 h-4 w-4" />
|
||||||
|
</button>
|
||||||
|
);
|
||||||
|
},
|
||||||
|
cell: ({ row }) => {
|
||||||
|
const model_type_map: Record<modelEnumType, any> = {
|
||||||
|
"checkpoint": "amber",
|
||||||
|
"lora": "green",
|
||||||
|
"embedding": "violet",
|
||||||
|
"vae": "teal",
|
||||||
|
};
|
||||||
|
|
||||||
|
function getBadgeColor(modelType: modelEnumType) {
|
||||||
|
return model_type_map[modelType] || "default";
|
||||||
|
}
|
||||||
|
|
||||||
|
const color = getBadgeColor(row.original.model_type);
|
||||||
|
return <Badge variant={color}>{row.original.model_type}</Badge>;
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
accessorKey: "date",
|
accessorKey: "date",
|
||||||
sortingFn: "datetime",
|
sortingFn: "datetime",
|
||||||
@ -221,13 +254,14 @@ export const columns: ColumnDef<CheckpointItemList>[] = [
|
|||||||
// },
|
// },
|
||||||
];
|
];
|
||||||
|
|
||||||
export function CheckpointList({ data }: { data: CheckpointItemList[] }) {
|
export function ModelList({ data }: { data: ModelItemList[] }) {
|
||||||
const [sorting, setSorting] = React.useState<SortingState>([]);
|
const [sorting, setSorting] = React.useState<SortingState>([]);
|
||||||
const [columnFilters, setColumnFilters] = React.useState<ColumnFiltersState>(
|
const [columnFilters, setColumnFilters] = React.useState<ColumnFiltersState>(
|
||||||
[]
|
[],
|
||||||
);
|
);
|
||||||
const [columnVisibility, setColumnVisibility] =
|
const [columnVisibility, setColumnVisibility] = React.useState<
|
||||||
React.useState<VisibilityState>({});
|
VisibilityState
|
||||||
|
>({});
|
||||||
const [rowSelection, setRowSelection] = React.useState({});
|
const [rowSelection, setRowSelection] = React.useState({});
|
||||||
|
|
||||||
const table = useReactTable({
|
const table = useReactTable({
|
||||||
@ -254,10 +288,10 @@ export function CheckpointList({ data }: { data: CheckpointItemList[] }) {
|
|||||||
<div className="flex flex-row w-full items-center py-4">
|
<div className="flex flex-row w-full items-center py-4">
|
||||||
<Input
|
<Input
|
||||||
placeholder="Filter workflows..."
|
placeholder="Filter workflows..."
|
||||||
value={(table.getColumn("name")?.getFilterValue() as string) ?? ""}
|
value={(table.getColumn("model_name")?.getFilterValue() as string) ??
|
||||||
|
""}
|
||||||
onChange={(event) =>
|
onChange={(event) =>
|
||||||
table.getColumn("name")?.setFilterValue(event.target.value)
|
table.getColumn("model_name")?.setFilterValue(event.target.value)}
|
||||||
}
|
|
||||||
className="max-w-sm"
|
className="max-w-sm"
|
||||||
/>
|
/>
|
||||||
<div className="ml-auto flex gap-2">
|
<div className="ml-auto flex gap-2">
|
||||||
@ -268,17 +302,17 @@ export function CheckpointList({ data }: { data: CheckpointItemList[] }) {
|
|||||||
// TODO: limitations based on plan
|
// TODO: limitations based on plan
|
||||||
}
|
}
|
||||||
tooltip={"Add models using their civitai url!"}
|
tooltip={"Add models using their civitai url!"}
|
||||||
title="Civitai Checkpoint"
|
title="Add a Civitai Model"
|
||||||
description="Pick a model from civitai"
|
description="Pick a model from civitai"
|
||||||
serverAction={addCivitaiCheckpoint}
|
serverAction={addCivitaiModel}
|
||||||
formSchema={addCivitaiCheckpointSchema}
|
formSchema={addCivitaiModelSchema}
|
||||||
fieldConfig={{
|
fieldConfig={{
|
||||||
civitai_url: {
|
civitai_url: {
|
||||||
fieldType: "fallback",
|
fieldType: "fallback",
|
||||||
inputProps: { required: true },
|
inputProps: { required: true },
|
||||||
description: (
|
description: (
|
||||||
<>
|
<>
|
||||||
Pick a checkpoint from{" "}
|
Pick a model from{" "}
|
||||||
<a
|
<a
|
||||||
href="https://www.civitai.com/models"
|
href="https://www.civitai.com/models"
|
||||||
target="_blank"
|
target="_blank"
|
||||||
@ -302,12 +336,10 @@ export function CheckpointList({ data }: { data: CheckpointItemList[] }) {
|
|||||||
{headerGroup.headers.map((header) => {
|
{headerGroup.headers.map((header) => {
|
||||||
return (
|
return (
|
||||||
<TableHead key={header.id}>
|
<TableHead key={header.id}>
|
||||||
{header.isPlaceholder
|
{header.isPlaceholder ? null : flexRender(
|
||||||
? null
|
header.column.columnDef.header,
|
||||||
: flexRender(
|
header.getContext(),
|
||||||
header.column.columnDef.header,
|
)}
|
||||||
header.getContext()
|
|
||||||
)}
|
|
||||||
</TableHead>
|
</TableHead>
|
||||||
);
|
);
|
||||||
})}
|
})}
|
||||||
@ -315,32 +347,34 @@ export function CheckpointList({ data }: { data: CheckpointItemList[] }) {
|
|||||||
))}
|
))}
|
||||||
</TableHeader>
|
</TableHeader>
|
||||||
<TableBody>
|
<TableBody>
|
||||||
{table.getRowModel().rows?.length ? (
|
{table.getRowModel().rows?.length
|
||||||
table.getRowModel().rows.map((row) => (
|
? (
|
||||||
<TableRow
|
table.getRowModel().rows.map((row) => (
|
||||||
key={row.id}
|
<TableRow
|
||||||
data-state={row.getIsSelected() && "selected"}
|
key={row.id}
|
||||||
>
|
data-state={row.getIsSelected() && "selected"}
|
||||||
{row.getVisibleCells().map((cell) => (
|
>
|
||||||
<TableCell key={cell.id}>
|
{row.getVisibleCells().map((cell) => (
|
||||||
{flexRender(
|
<TableCell key={cell.id}>
|
||||||
cell.column.columnDef.cell,
|
{flexRender(
|
||||||
cell.getContext()
|
cell.column.columnDef.cell,
|
||||||
)}
|
cell.getContext(),
|
||||||
</TableCell>
|
)}
|
||||||
))}
|
</TableCell>
|
||||||
|
))}
|
||||||
|
</TableRow>
|
||||||
|
))
|
||||||
|
)
|
||||||
|
: (
|
||||||
|
<TableRow>
|
||||||
|
<TableCell
|
||||||
|
colSpan={columns.length}
|
||||||
|
className="h-24 text-center"
|
||||||
|
>
|
||||||
|
No results.
|
||||||
|
</TableCell>
|
||||||
</TableRow>
|
</TableRow>
|
||||||
))
|
)}
|
||||||
) : (
|
|
||||||
<TableRow>
|
|
||||||
<TableCell
|
|
||||||
colSpan={columns.length}
|
|
||||||
className="h-24 text-center"
|
|
||||||
>
|
|
||||||
No results.
|
|
||||||
</TableCell>
|
|
||||||
</TableRow>
|
|
||||||
)}
|
|
||||||
</TableBody>
|
</TableBody>
|
||||||
</Table>
|
</Table>
|
||||||
</ScrollArea>
|
</ScrollArea>
|
@ -7,7 +7,7 @@ import AutoFormInput from "../ui/auto-form/fields/input";
|
|||||||
import { useDebouncedCallback } from "use-debounce";
|
import { useDebouncedCallback } from "use-debounce";
|
||||||
import { CivitaiModelResponse } from "@/types/civitai";
|
import { CivitaiModelResponse } from "@/types/civitai";
|
||||||
import { z } from "zod";
|
import { z } from "zod";
|
||||||
import { insertCivitaiCheckpointSchema } from "@/db/schema";
|
import { insertCivitaiModelSchema } from "@/db/schema";
|
||||||
|
|
||||||
function getUrl(civitai_url: string) {
|
function getUrl(civitai_url: string) {
|
||||||
// expect to be a URL to be https://civitai.com/models/36520
|
// expect to be a URL to be https://civitai.com/models/36520
|
||||||
@ -33,7 +33,7 @@ export default function AutoFormCheckpointInput(
|
|||||||
|
|
||||||
const handleSearch = useDebouncedCallback((search) => {
|
const handleSearch = useDebouncedCallback((search) => {
|
||||||
const validationResult =
|
const validationResult =
|
||||||
insertCivitaiCheckpointSchema.shape.civitai_url.safeParse(search);
|
insertCivitaiModelSchema.shape.civitai_url.safeParse(search);
|
||||||
if (!validationResult.success) {
|
if (!validationResult.success) {
|
||||||
console.error(validationResult.error);
|
console.error(validationResult.error);
|
||||||
// Optionally set an error state here
|
// Optionally set an error state here
|
@ -12,7 +12,7 @@ import {
|
|||||||
real,
|
real,
|
||||||
} from "drizzle-orm/pg-core";
|
} from "drizzle-orm/pg-core";
|
||||||
import { createInsertSchema, createSelectSchema } from "drizzle-zod";
|
import { createInsertSchema, createSelectSchema } from "drizzle-zod";
|
||||||
import { z } from "zod";
|
import { TypeOf, z } from "zod";
|
||||||
|
|
||||||
export const dbSchema = pgSchema("comfyui_deploy");
|
export const dbSchema = pgSchema("comfyui_deploy");
|
||||||
|
|
||||||
@ -376,15 +376,25 @@ export const modelUploadType = pgEnum("model_upload_type", [
|
|||||||
"other",
|
"other",
|
||||||
]);
|
]);
|
||||||
|
|
||||||
export const checkpointTable = dbSchema.table("checkpoints", {
|
// https://www.answeroverflow.com/m/1125106227387584552
|
||||||
|
const modelTypes = [
|
||||||
|
"checkpoint",
|
||||||
|
"lora",
|
||||||
|
"embedding",
|
||||||
|
"vae",
|
||||||
|
] as const
|
||||||
|
export const modelType = pgEnum("model_type", modelTypes);
|
||||||
|
export type modelEnumType = typeof modelTypes[number]
|
||||||
|
|
||||||
|
export const modelTable = dbSchema.table("models", {
|
||||||
id: uuid("id").primaryKey().defaultRandom().notNull(),
|
id: uuid("id").primaryKey().defaultRandom().notNull(),
|
||||||
user_id: text("user_id").references(() => usersTable.id, {}), // perhaps a "special" user_id for global checkpoints
|
user_id: text("user_id").references(() => usersTable.id, {}), // perhaps a "special" user_id for global models
|
||||||
org_id: text("org_id"),
|
org_id: text("org_id"),
|
||||||
description: text("description"),
|
description: text("description"),
|
||||||
|
|
||||||
checkpoint_volume_id: uuid("checkpoint_volume_id")
|
user_volume_id: uuid("user_volume_id")
|
||||||
.notNull()
|
.notNull()
|
||||||
.references(() => checkpointVolumeTable.id, {
|
.references(() => userVolume.id, {
|
||||||
onDelete: "cascade",
|
onDelete: "cascade",
|
||||||
})
|
})
|
||||||
.notNull(),
|
.notNull(),
|
||||||
@ -408,12 +418,13 @@ export const checkpointTable = dbSchema.table("checkpoints", {
|
|||||||
status: resourceUpload("status").notNull().default("started"),
|
status: resourceUpload("status").notNull().default("started"),
|
||||||
upload_machine_id: text("upload_machine_id"), // TODO: review if actually needed
|
upload_machine_id: text("upload_machine_id"), // TODO: review if actually needed
|
||||||
upload_type: modelUploadType("upload_type").notNull(),
|
upload_type: modelUploadType("upload_type").notNull(),
|
||||||
|
model_type: modelType("model_type").notNull(),
|
||||||
error_log: text("error_log"),
|
error_log: text("error_log"),
|
||||||
created_at: timestamp("created_at").defaultNow().notNull(),
|
created_at: timestamp("created_at").defaultNow().notNull(),
|
||||||
updated_at: timestamp("updated_at").defaultNow().notNull(),
|
updated_at: timestamp("updated_at").defaultNow().notNull(),
|
||||||
});
|
});
|
||||||
|
|
||||||
export const checkpointVolumeTable = dbSchema.table("checkpoint_volume", {
|
export const userVolume = dbSchema.table("user_volume", {
|
||||||
id: uuid("id").primaryKey().defaultRandom().notNull(),
|
id: uuid("id").primaryKey().defaultRandom().notNull(),
|
||||||
user_id: text("user_id").references(() => usersTable.id, {
|
user_id: text("user_id").references(() => usersTable.id, {
|
||||||
// onDelete: "cascade",
|
// onDelete: "cascade",
|
||||||
@ -425,23 +436,23 @@ export const checkpointVolumeTable = dbSchema.table("checkpoint_volume", {
|
|||||||
disabled: boolean("disabled").default(false).notNull(),
|
disabled: boolean("disabled").default(false).notNull(),
|
||||||
});
|
});
|
||||||
|
|
||||||
export const checkpointRelations = relations(checkpointTable, ({ one }) => ({
|
export const modelRelations = relations(modelTable, ({ one }) => ({
|
||||||
user: one(usersTable, {
|
user: one(usersTable, {
|
||||||
fields: [checkpointTable.user_id],
|
fields: [modelTable.user_id],
|
||||||
references: [usersTable.id],
|
references: [usersTable.id],
|
||||||
}),
|
}),
|
||||||
volume: one(checkpointVolumeTable, {
|
volume: one(userVolume, {
|
||||||
fields: [checkpointTable.checkpoint_volume_id],
|
fields: [modelTable.user_volume_id],
|
||||||
references: [checkpointVolumeTable.id],
|
references: [userVolume.id],
|
||||||
}),
|
}),
|
||||||
}));
|
}));
|
||||||
|
|
||||||
export const checkpointVolumeRelations = relations(
|
export const modalVolumeRelations = relations(
|
||||||
checkpointVolumeTable,
|
userVolume,
|
||||||
({ many, one }) => ({
|
({ many, one }) => ({
|
||||||
checkpoint: many(checkpointTable),
|
model: many(modelTable),
|
||||||
user: one(usersTable, {
|
user: one(usersTable, {
|
||||||
fields: [checkpointVolumeTable.user_id],
|
fields: [userVolume.user_id],
|
||||||
references: [usersTable.id],
|
references: [usersTable.id],
|
||||||
}),
|
}),
|
||||||
})
|
})
|
||||||
@ -473,8 +484,8 @@ export const subscriptionStatusTable = dbSchema.table("subscription_status", {
|
|||||||
updated_at: timestamp("updated_at").defaultNow().notNull(),
|
updated_at: timestamp("updated_at").defaultNow().notNull(),
|
||||||
});
|
});
|
||||||
|
|
||||||
export const insertCivitaiCheckpointSchema = createInsertSchema(
|
export const insertCivitaiModelSchema = createInsertSchema(
|
||||||
checkpointTable,
|
modelTable,
|
||||||
{
|
{
|
||||||
civitai_url: (schema) =>
|
civitai_url: (schema) =>
|
||||||
schema.civitai_url
|
schema.civitai_url
|
||||||
@ -491,8 +502,8 @@ export type WorkflowType = InferSelectModel<typeof workflowTable>;
|
|||||||
export type MachineType = InferSelectModel<typeof machinesTable>;
|
export type MachineType = InferSelectModel<typeof machinesTable>;
|
||||||
export type WorkflowVersionType = InferSelectModel<typeof workflowVersionTable>;
|
export type WorkflowVersionType = InferSelectModel<typeof workflowVersionTable>;
|
||||||
export type DeploymentType = InferSelectModel<typeof deploymentsTable>;
|
export type DeploymentType = InferSelectModel<typeof deploymentsTable>;
|
||||||
export type CheckpointType = InferSelectModel<typeof checkpointTable>;
|
export type ModelType = InferSelectModel<typeof modelTable>;
|
||||||
export type CheckpointVolumeType = InferSelectModel<
|
export type UserVolumeType = InferSelectModel<
|
||||||
typeof checkpointVolumeTable
|
typeof userVolume
|
||||||
>;
|
>;
|
||||||
export type UserUsageType = InferSelectModel<typeof userUsageTable>;
|
export type UserUsageType = InferSelectModel<typeof userUsageTable>;
|
||||||
|
@ -1,5 +0,0 @@
|
|||||||
import { insertCivitaiCheckpointSchema } from "@/db/schema";
|
|
||||||
|
|
||||||
export const addCivitaiCheckpointSchema = insertCivitaiCheckpointSchema.pick({
|
|
||||||
civitai_url: true,
|
|
||||||
});
|
|
5
web/src/server/addCivitaiModelSchema.tsx
Normal file
5
web/src/server/addCivitaiModelSchema.tsx
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
import { insertCivitaiModelSchema } from "@/db/schema";
|
||||||
|
|
||||||
|
export const addCivitaiModelSchema = insertCivitaiModelSchema.pick({
|
||||||
|
civitai_url: true,
|
||||||
|
});
|
@ -15,7 +15,7 @@ import { headers } from "next/headers";
|
|||||||
import { redirect } from "next/navigation";
|
import { redirect } from "next/navigation";
|
||||||
import "server-only";
|
import "server-only";
|
||||||
import type { z } from "zod";
|
import type { z } from "zod";
|
||||||
import { retrieveCheckpointVolumes } from "./curdCheckpoint";
|
import { retrieveModelVolumes } from "./curdModel";
|
||||||
|
|
||||||
export async function getMachines() {
|
export async function getMachines() {
|
||||||
const { userId, orgId } = auth();
|
const { userId, orgId } = auth();
|
||||||
@ -190,7 +190,7 @@ async function _buildMachine(
|
|||||||
throw new Error("No domain");
|
throw new Error("No domain");
|
||||||
}
|
}
|
||||||
|
|
||||||
const volumes = await retrieveCheckpointVolumes();
|
const volumes = await retrieveModelVolumes();
|
||||||
// Call remote builder
|
// Call remote builder
|
||||||
const result = await fetch(`${process.env.MODAL_BUILDER_URL!}/create`, {
|
const result = await fetch(`${process.env.MODAL_BUILDER_URL!}/create`, {
|
||||||
method: "POST",
|
method: "POST",
|
||||||
@ -204,7 +204,7 @@ async function _buildMachine(
|
|||||||
callback_url: `${protocol}://${domain}/api/machine-built`,
|
callback_url: `${protocol}://${domain}/api/machine-built`,
|
||||||
models: data.models, //JSON.parse(data.models as string),
|
models: data.models, //JSON.parse(data.models as string),
|
||||||
gpu: data.gpu && data.gpu.length > 0 ? data.gpu : "T4",
|
gpu: data.gpu && data.gpu.length > 0 ? data.gpu : "T4",
|
||||||
checkpoint_volume_name: volumes[0].volume_name,
|
model_volume_name: volumes[0].volume_name,
|
||||||
}),
|
}),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -2,92 +2,91 @@
|
|||||||
|
|
||||||
import { auth } from "@clerk/nextjs";
|
import { auth } from "@clerk/nextjs";
|
||||||
import {
|
import {
|
||||||
checkpointTable,
|
modelTable,
|
||||||
CheckpointType,
|
ModelType,
|
||||||
checkpointVolumeTable,
|
userVolume,
|
||||||
CheckpointVolumeType,
|
UserVolumeType,
|
||||||
} from "@/db/schema";
|
} from "@/db/schema";
|
||||||
import { withServerPromise } from "./withServerPromise";
|
import { withServerPromise } from "./withServerPromise";
|
||||||
import { db } from "@/db/db";
|
import { db } from "@/db/db";
|
||||||
import type { z } from "zod";
|
import type { z } from "zod";
|
||||||
import { headers } from "next/headers";
|
import { headers } from "next/headers";
|
||||||
import { addCivitaiCheckpointSchema } from "./addCheckpointSchema";
|
import { addCivitaiModelSchema } from "./addCivitaiModelSchema";
|
||||||
import { and, eq, isNull } from "drizzle-orm";
|
import { and, eq, isNull } from "drizzle-orm";
|
||||||
import { CivitaiModelResponse } from "@/types/civitai";
|
import { CivitaiModelResponse, getModelTypeDetails } from "@/types/civitai";
|
||||||
|
|
||||||
export async function getCheckpoints() {
|
export async function getModel() {
|
||||||
const { userId, orgId } = auth();
|
const { userId, orgId } = auth();
|
||||||
if (!userId) throw new Error("No user id");
|
if (!userId) throw new Error("No user id");
|
||||||
const checkpoints = await db
|
const models = await db
|
||||||
.select()
|
.select()
|
||||||
.from(checkpointTable)
|
.from(modelTable)
|
||||||
.where(
|
.where(
|
||||||
orgId
|
orgId
|
||||||
? eq(checkpointTable.org_id, orgId)
|
? eq(modelTable.org_id, orgId)
|
||||||
// make sure org_id is null
|
// make sure org_id is null
|
||||||
: and(
|
: and(
|
||||||
eq(checkpointTable.user_id, userId),
|
eq(modelTable.user_id, userId),
|
||||||
isNull(checkpointTable.org_id),
|
isNull(modelTable.org_id),
|
||||||
),
|
),
|
||||||
);
|
);
|
||||||
return checkpoints;
|
return models;
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function getCheckpointById(id: string) {
|
export async function getModelById(id: string) {
|
||||||
const { userId, orgId } = auth();
|
const { userId, orgId } = auth();
|
||||||
if (!userId) throw new Error("No user id");
|
if (!userId) throw new Error("No user id");
|
||||||
const checkpoint = await db
|
const model = await db
|
||||||
.select()
|
.select()
|
||||||
.from(checkpointTable)
|
.from(modelTable)
|
||||||
.where(
|
.where(
|
||||||
and(
|
and(
|
||||||
orgId ? eq(checkpointTable.org_id, orgId) : and(
|
orgId ? eq(modelTable.org_id, orgId) : and(
|
||||||
eq(checkpointTable.user_id, userId),
|
eq(modelTable.user_id, userId),
|
||||||
isNull(checkpointTable.org_id),
|
isNull(modelTable.org_id),
|
||||||
),
|
),
|
||||||
eq(checkpointTable.id, id),
|
eq(modelTable.id, id),
|
||||||
),
|
),
|
||||||
);
|
);
|
||||||
return checkpoint[0];
|
return model[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function getCheckpointVolumes() {
|
export async function getModelVolumes() {
|
||||||
const { userId, orgId } = auth();
|
const { userId, orgId } = auth();
|
||||||
if (!userId) throw new Error("No user id");
|
if (!userId) throw new Error("No user id");
|
||||||
const volume = await db
|
const volume = await db
|
||||||
.select()
|
.select()
|
||||||
.from(checkpointVolumeTable)
|
.from(userVolume)
|
||||||
.where(
|
.where(
|
||||||
and(
|
and(
|
||||||
orgId
|
orgId
|
||||||
? eq(checkpointVolumeTable.org_id, orgId)
|
? eq(userVolume.org_id, orgId)
|
||||||
// make sure org_id is null
|
// make sure org_id is null
|
||||||
: and(
|
: and(
|
||||||
eq(checkpointVolumeTable.user_id, userId),
|
eq(userVolume.user_id, userId),
|
||||||
isNull(checkpointVolumeTable.org_id),
|
isNull(userVolume.org_id),
|
||||||
),
|
),
|
||||||
eq(checkpointVolumeTable.disabled, false),
|
eq(userVolume.disabled, false),
|
||||||
),
|
),
|
||||||
);
|
);
|
||||||
return volume;
|
return volume;
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function retrieveCheckpointVolumes() {
|
export async function retrieveModelVolumes() {
|
||||||
let volumes = await getCheckpointVolumes();
|
let volumes = await getModelVolumes();
|
||||||
if (volumes.length === 0) {
|
if (volumes.length === 0) {
|
||||||
// create volume if not already created
|
// create volume if not already created
|
||||||
volumes = await addCheckpointVolume();
|
volumes = await addModelVolume();
|
||||||
}
|
}
|
||||||
return volumes;
|
return volumes;
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function addCheckpointVolume() {
|
export async function addModelVolume() {
|
||||||
const { userId, orgId } = auth();
|
const { userId, orgId } = auth();
|
||||||
if (!userId) throw new Error("No user id");
|
if (!userId) throw new Error("No user id");
|
||||||
|
|
||||||
// Insert the new checkpointVolume into the checkpointVolumeTable
|
|
||||||
const insertedVolume = await db
|
const insertedVolume = await db
|
||||||
.insert(checkpointVolumeTable)
|
.insert(userVolume)
|
||||||
.values({
|
.values({
|
||||||
user_id: userId,
|
user_id: userId,
|
||||||
org_id: orgId,
|
org_id: orgId,
|
||||||
@ -111,8 +110,8 @@ function getUrl(civitai_url: string) {
|
|||||||
return { url: baseUrl + modelId, modelVersionId };
|
return { url: baseUrl + modelId, modelVersionId };
|
||||||
}
|
}
|
||||||
|
|
||||||
export const addCivitaiCheckpoint = withServerPromise(
|
export const addCivitaiModel = withServerPromise(
|
||||||
async (data: z.infer<typeof addCivitaiCheckpointSchema>) => {
|
async (data: z.infer<typeof addCivitaiModelSchema>) => {
|
||||||
const { userId, orgId } = auth();
|
const { userId, orgId } = auth();
|
||||||
|
|
||||||
if (!data.civitai_url) return { error: "no civitai_url" };
|
if (!data.civitai_url) return { error: "no civitai_url" };
|
||||||
@ -145,17 +144,22 @@ export const addCivitaiCheckpoint = withServerPromise(
|
|||||||
selectedModelVersionId = selectedModelVersion?.id.toString();
|
selectedModelVersionId = selectedModelVersion?.id.toString();
|
||||||
}
|
}
|
||||||
|
|
||||||
const checkpointVolumes = await getCheckpointVolumes();
|
const userVolume = await getModelVolumes();
|
||||||
let cVolume;
|
let cVolume;
|
||||||
if (checkpointVolumes.length === 0) {
|
if (userVolume.length === 0) {
|
||||||
const volume = await addCheckpointVolume();
|
const volume = await addModelVolume();
|
||||||
cVolume = volume[0];
|
cVolume = volume[0];
|
||||||
} else {
|
} else {
|
||||||
cVolume = checkpointVolumes[0];
|
cVolume = userVolume[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
const model_type = getModelTypeDetails(civitaiModelRes.type);
|
||||||
|
if (!model_type) {
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
const a = await db
|
const a = await db
|
||||||
.insert(checkpointTable)
|
.insert(modelTable)
|
||||||
.values({
|
.values({
|
||||||
user_id: userId,
|
user_id: userId,
|
||||||
org_id: orgId,
|
org_id: orgId,
|
||||||
@ -166,15 +170,15 @@ export const addCivitaiCheckpoint = withServerPromise(
|
|||||||
civitai_url: data.civitai_url,
|
civitai_url: data.civitai_url,
|
||||||
civitai_download_url: selectedModelVersion.files[0].downloadUrl,
|
civitai_download_url: selectedModelVersion.files[0].downloadUrl,
|
||||||
civitai_model_response: civitaiModelRes,
|
civitai_model_response: civitaiModelRes,
|
||||||
checkpoint_volume_id: cVolume.id,
|
user_volume_id: cVolume.id,
|
||||||
|
model_type,
|
||||||
updated_at: new Date(),
|
updated_at: new Date(),
|
||||||
})
|
})
|
||||||
.returning();
|
.returning();
|
||||||
|
|
||||||
const b = a[0];
|
const b = a[0];
|
||||||
|
|
||||||
await uploadCheckpoint(data, b, cVolume);
|
await uploadModel(data, b, cVolume);
|
||||||
// redirect(`/checkpoints/${b.id}`);
|
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
|
||||||
@ -213,10 +217,10 @@ export const addCivitaiCheckpoint = withServerPromise(
|
|||||||
// },
|
// },
|
||||||
// );
|
// );
|
||||||
|
|
||||||
async function uploadCheckpoint(
|
async function uploadModel(
|
||||||
data: z.infer<typeof addCivitaiCheckpointSchema>,
|
data: z.infer<typeof addCivitaiModelSchema>,
|
||||||
c: CheckpointType,
|
c: ModelType,
|
||||||
v: CheckpointVolumeType,
|
v: UserVolumeType,
|
||||||
) {
|
) {
|
||||||
const headersList = headers();
|
const headersList = headers();
|
||||||
|
|
||||||
@ -239,9 +243,9 @@ async function uploadCheckpoint(
|
|||||||
download_url: c.civitai_download_url,
|
download_url: c.civitai_download_url,
|
||||||
volume_name: v.volume_name,
|
volume_name: v.volume_name,
|
||||||
volume_id: v.id,
|
volume_id: v.id,
|
||||||
checkpoint_id: c.id,
|
model_id: c.id,
|
||||||
callback_url: `${protocol}://${domain}/api/volume-upload`,
|
callback_url: `${protocol}://${domain}/api/volume-upload`,
|
||||||
upload_type: "checkpoint"
|
upload_type: c.model_type,
|
||||||
}),
|
}),
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
@ -249,23 +253,23 @@ async function uploadCheckpoint(
|
|||||||
if (!result.ok) {
|
if (!result.ok) {
|
||||||
const error_log = await result.text();
|
const error_log = await result.text();
|
||||||
await db
|
await db
|
||||||
.update(checkpointTable)
|
.update(modelTable)
|
||||||
.set({
|
.set({
|
||||||
...data,
|
...data,
|
||||||
status: "failed",
|
status: "failed",
|
||||||
error_log: error_log,
|
error_log: error_log,
|
||||||
})
|
})
|
||||||
.where(eq(checkpointTable.id, c.id));
|
.where(eq(modelTable.id, c.id));
|
||||||
throw new Error(`Error: ${result.statusText} ${error_log}`);
|
throw new Error(`Error: ${result.statusText} ${error_log}`);
|
||||||
} else {
|
} else {
|
||||||
// setting the build machine id
|
// setting the build machine id
|
||||||
const json = await result.json();
|
const json = await result.json();
|
||||||
await db
|
await db
|
||||||
.update(checkpointTable)
|
.update(modelTable)
|
||||||
.set({
|
.set({
|
||||||
...data,
|
...data,
|
||||||
upload_machine_id: json.build_machine_instance_id,
|
upload_machine_id: json.build_machine_instance_id,
|
||||||
})
|
})
|
||||||
.where(eq(checkpointTable.id, c.id));
|
.where(eq(modelTable.id, c.id));
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -1,18 +1,18 @@
|
|||||||
import { db } from "@/db/db";
|
import { db } from "@/db/db";
|
||||||
import {
|
import {
|
||||||
checkpointTable,
|
modelTable,
|
||||||
} from "@/db/schema";
|
} from "@/db/schema";
|
||||||
import { auth } from "@clerk/nextjs";
|
import { auth } from "@clerk/nextjs";
|
||||||
import { and, desc, eq, isNull } from "drizzle-orm";
|
import { and, desc, eq, isNull } from "drizzle-orm";
|
||||||
|
|
||||||
export async function getAllUserCheckpoints() {
|
export async function getAllUserModels() {
|
||||||
const { userId, orgId } = await auth();
|
const { userId, orgId } = await auth();
|
||||||
|
|
||||||
if (!userId) {
|
if (!userId) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
const checkpoints = await db.query.checkpointTable.findMany({
|
const models = await db.query.modelTable.findMany({
|
||||||
with: {
|
with: {
|
||||||
user: {
|
user: {
|
||||||
columns: {
|
columns: {
|
||||||
@ -28,14 +28,15 @@ export async function getAllUserCheckpoints() {
|
|||||||
civitai_model_response: true,
|
civitai_model_response: true,
|
||||||
is_public: true,
|
is_public: true,
|
||||||
upload_type: true,
|
upload_type: true,
|
||||||
|
model_type: true,
|
||||||
status: true,
|
status: true,
|
||||||
},
|
},
|
||||||
orderBy: desc(checkpointTable.updated_at),
|
orderBy: desc(modelTable.updated_at),
|
||||||
where:
|
where:
|
||||||
orgId != undefined
|
orgId != undefined
|
||||||
? eq(checkpointTable.org_id, orgId)
|
? eq(modelTable.org_id, orgId)
|
||||||
: and(eq(checkpointTable.user_id, userId), isNull(checkpointTable.org_id)),
|
: and(eq(modelTable.user_id, userId), isNull(modelTable.org_id)),
|
||||||
});
|
});
|
||||||
|
|
||||||
return checkpoints;
|
return models;
|
||||||
}
|
}
|
@ -1,4 +1,5 @@
|
|||||||
import { z } from "zod";
|
import { TypeOf, z } from "zod";
|
||||||
|
import { modelEnumType } from "@/db/schema";
|
||||||
|
|
||||||
// from chatgpt https://chat.openai.com/share/4985d20b-30b1-4a28-87f6-6ebf84a1040e
|
// from chatgpt https://chat.openai.com/share/4985d20b-30b1-4a28-87f6-6ebf84a1040e
|
||||||
|
|
||||||
@ -110,12 +111,21 @@ export const statsSchema = z.object({
|
|||||||
tippedAmountCount: z.number(),
|
tippedAmountCount: z.number(),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
const civitaiModelType = z.enum([
|
||||||
|
"Checkpoint",
|
||||||
|
"TextualInversion",
|
||||||
|
"Hypernetwork",
|
||||||
|
"AestheticGradient",
|
||||||
|
"LORA",
|
||||||
|
"Controlnet",
|
||||||
|
"Poses",
|
||||||
|
]);
|
||||||
|
|
||||||
export const CivitaiModelResponse = z.object({
|
export const CivitaiModelResponse = z.object({
|
||||||
id: z.number(),
|
id: z.number(),
|
||||||
name: z.string().nullish(),
|
name: z.string().nullish(),
|
||||||
description: z.string().nullish(),
|
description: z.string().nullish(),
|
||||||
// type: z.enum(["Checkpoint", "Lora"]), // TODO: this will be important to know
|
type: civitaiModelType,
|
||||||
type: z.string(),
|
|
||||||
poi: z.boolean().nullish(),
|
poi: z.boolean().nullish(),
|
||||||
nsfw: z.boolean().nullish(),
|
nsfw: z.boolean().nullish(),
|
||||||
allowNoCredit: z.boolean().nullish(),
|
allowNoCredit: z.boolean().nullish(),
|
||||||
@ -127,3 +137,22 @@ export const CivitaiModelResponse = z.object({
|
|||||||
tags: z.array(z.string()).nullish(),
|
tags: z.array(z.string()).nullish(),
|
||||||
modelVersions: z.array(modelVersionSchema),
|
modelVersions: z.array(modelVersionSchema),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
export function getModelTypeDetails(
|
||||||
|
modelType: typeof civitaiModelType["_type"],
|
||||||
|
): modelEnumType | undefined {
|
||||||
|
switch (modelType) {
|
||||||
|
case "Checkpoint":
|
||||||
|
return "checkpoint"
|
||||||
|
case "TextualInversion":
|
||||||
|
return "embedding"
|
||||||
|
case "LORA":
|
||||||
|
return "lora"
|
||||||
|
case "AestheticGradient":
|
||||||
|
case "Hypernetwork":
|
||||||
|
case "Controlnet":
|
||||||
|
case "Poses":
|
||||||
|
default:
|
||||||
|
return undefined;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user