This commit is contained in:
Nicholas Koben Kao 2024-01-22 19:05:18 -08:00
parent fed7b380b6
commit f6a1b88dda
9 changed files with 201 additions and 283 deletions

View File

@ -8,6 +8,7 @@ from enum import Enum
import json import json
import subprocess import subprocess
import time import time
from uuid import uuid4
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
import asyncio import asyncio
import threading import threading
@ -19,6 +20,7 @@ from urllib.parse import parse_qs
from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp, Scope, Receive, Send from starlette.types import ASGIApp, Scope, Receive, Send
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
# executor = ThreadPoolExecutor(max_workers=5) # executor = ThreadPoolExecutor(max_workers=5)
@ -227,15 +229,47 @@ async def websocket_endpoint(websocket: WebSocket, machine_id: str):
class UploadBody(BaseModel): class UploadBody(BaseModel):
download_url: str download_url: str
volume_name: str volume_name: str
volume_id: str
# callback_url: str # callback_url: str
@app.post("/upload_volume") @app.post("/upload_volume")
async def upload_checkpoint(body: UploadBody): async def upload_checkpoint(body: UploadBody):
global last_activity_time
last_activity_time = time.time()
logger.info(f"Extended inactivity time to {global_timeout}")
download_url = body.download_url download_url = body.download_url
volume_name = body.download_url volume_name = body.volume_name
# callback_url = body.callback_url # callback_url = body.callback_url
folder_path = f"/app/builds/{body.volume_id}"
cp_process = await asyncio.subprocess.create_subprocess_exec("cp", "-r", "/app/src/volume-builder", folder_path)
await cp_process.wait()
# Write the config file
config = {
"volume_names": {
volume_name: download_url
},
"paths": {
volume_name: f'/volumes/{uuid4()}'
},
}
await asyncio.subprocess.create_subprocess_shell(
f"modal run app.py",
# stdout=asyncio.subprocess.PIPE,
# stderr=asyncio.subprocess.PIPE,
cwd=folder_path,
env={**os.environ, "COLUMNS": "10000"}
)
with open(f"{folder_path}/config.py", "w") as f:
f.write("config = " + json.dumps(config))
# check that thi # check that thi
return return JSONResponse(status_code=200, content={"message": "Volume uploading", "build_machine_instance_id": fly_instance_id})
@app.post("/create") @app.post("/create")

View File

@ -1,51 +1,42 @@
import modal import modal
from config import config from config import config
import os import os
import uuid
import subprocess import subprocess
stub = modal.Stub() stub = modal.Stub()
base_path = "/volumes"
# Volume names may only contain alphanumeric characters, dashes, periods, and underscores, and must be less than 64 characters in length. # Volume names may only contain alphanumeric characters, dashes, periods, and underscores, and must be less than 64 characters in length.
def is_valid_name(name: str) -> bool: def is_valid_name(name: str) -> bool:
allowed_characters = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-._") allowed_characters = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-._")
return 0 < len(name) <= 64 and all(char in allowed_characters for char in name) return 0 < len(name) <= 64 and all(char in allowed_characters for char in name)
def create_volumes(volume_names): def create_volumes(volume_names, paths):
path_to_vol = {} path_to_vol = {}
vol_to_path = {}
for volume_name in volume_names.keys(): for volume_name in volume_names.keys():
if not is_valid_name(volume_name): if not is_valid_name(volume_name):
pass pass
modal_volume = modal.Volume.persisted(volume_name) modal_volume = modal.Volume.persisted(volume_name)
volume_path = create_volume_path(base_path) path_to_vol[paths[volume_name]] = modal_volume
path_to_vol[volume_path] = modal_volume
vol_to_path[volume_name] = volume_path
return (path_to_vol, vol_to_path) return path_to_vol
def create_volume_path(base_path: str):
random_path = str(uuid.uuid4())
return os.path.join(base_path, random_path)
vol_name_to_links = config["volume_names"] vol_name_to_links = config["volume_names"]
(path_to_vol, vol_name_to_path) = create_volumes(vol_name_to_links) vol_name_to_path = config["paths"]
volumes = create_volumes(vol_name_to_links, vol_name_to_path)
image = ( image = (
modal.Image.debian_slim().apt_install("wget").pip_install("requests") modal.Image.debian_slim().apt_install("wget").pip_install("requests")
) )
print(vol_name_to_links) print(vol_name_to_links)
print(path_to_vol)
print(vol_name_to_path) print(vol_name_to_path)
print(volumes)
@stub.function(volumes=path_to_vol, image=image, timeout=5000, gpu=None) @stub.function(volumes=volumes, image=image, timeout=5000, gpu=None)
def download_model(volume_name, download_url): def download_model(volume_name, download_url):
model_store_path = vol_name_to_path[volume_name] model_store_path = vol_name_to_path[volume_name]
subprocess.run(["wget", download_url, "--content-disposition", "-P", model_store_path]) subprocess.run(["wget", download_url, "--content-disposition", "-P", model_store_path])
subprocess.run(["ls", "-la", model_store_path]) subprocess.run(["ls", "-la", model_store_path])
path_to_vol[model_store_path].commit() volumes[model_store_path].commit()
@stub.local_entrypoint() @stub.local_entrypoint()
def simple_download(): def simple_download():

View File

@ -1,5 +1,8 @@
config = { config = {
"volume_names": { "volume_names": {
"eg1": "https://pub-6230db03dc3a4861a9c3e55145ceda44.r2.dev/openpose-pose (1).png" "test": "https://pub-6230db03dc3a4861a9c3e55145ceda44.r2.dev/openpose-pose (1).png"
}, },
"paths": {
"test": "/volumes/something"
}
} }

View File

@ -33,7 +33,7 @@ CREATE TABLE IF NOT EXISTS "comfyui_deploy"."checkpoints" (
"updated_at" timestamp DEFAULT now() NOT NULL "updated_at" timestamp DEFAULT now() NOT NULL
); );
--> statement-breakpoint --> statement-breakpoint
CREATE TABLE IF NOT EXISTS "comfyui_deploy"."checkpointVolumeTable" ( CREATE TABLE IF NOT EXISTS "comfyui_deploy"."checkpoint_volume" (
"id" uuid PRIMARY KEY DEFAULT gen_random_uuid() NOT NULL, "id" uuid PRIMARY KEY DEFAULT gen_random_uuid() NOT NULL,
"user_id" text, "user_id" text,
"org_id" text, "org_id" text,
@ -56,7 +56,7 @@ EXCEPTION
END $$; END $$;
--> statement-breakpoint --> statement-breakpoint
DO $$ BEGIN DO $$ BEGIN
ALTER TABLE "comfyui_deploy"."checkpointVolumeTable" ADD CONSTRAINT "checkpointVolumeTable_user_id_users_id_fk" FOREIGN KEY ("user_id") REFERENCES "comfyui_deploy"."users"("id") ON DELETE no action ON UPDATE no action; ALTER TABLE "comfyui_deploy"."checkpoint_volume" ADD CONSTRAINT "checkpoint_volume_user_id_users_id_fk" FOREIGN KEY ("user_id") REFERENCES "comfyui_deploy"."users"("id") ON DELETE no action ON UPDATE no action;
EXCEPTION EXCEPTION
WHEN duplicate_object THEN null; WHEN duplicate_object THEN null;
END $$; END $$;

View File

@ -1,5 +1,5 @@
{ {
"id": "66dbc84a-6cd8-4692-9d24-fdcac227b23c", "id": "4d5b29d0-848f-4c2e-a2cd-2932f1fa38c6",
"prevId": "db06ea66-92c2-4ebe-93c1-6cb8a90ccd8b", "prevId": "db06ea66-92c2-4ebe-93c1-6cb8a90ccd8b",
"version": "5", "version": "5",
"dialect": "pg", "dialect": "pg",
@ -250,8 +250,8 @@
"compositePrimaryKeys": {}, "compositePrimaryKeys": {},
"uniqueConstraints": {} "uniqueConstraints": {}
}, },
"checkpointVolumeTable": { "checkpoint_volume": {
"name": "checkpointVolumeTable", "name": "checkpoint_volume",
"schema": "comfyui_deploy", "schema": "comfyui_deploy",
"columns": { "columns": {
"id": { "id": {
@ -303,9 +303,9 @@
}, },
"indexes": {}, "indexes": {},
"foreignKeys": { "foreignKeys": {
"checkpointVolumeTable_user_id_users_id_fk": { "checkpoint_volume_user_id_users_id_fk": {
"name": "checkpointVolumeTable_user_id_users_id_fk", "name": "checkpoint_volume_user_id_users_id_fk",
"tableFrom": "checkpointVolumeTable", "tableFrom": "checkpoint_volume",
"tableTo": "users", "tableTo": "users",
"columnsFrom": [ "columnsFrom": [
"user_id" "user_id"

View File

@ -222,8 +222,8 @@
{ {
"idx": 31, "idx": 31,
"version": "5", "version": "5",
"when": 1705963548821, "when": 1705975916818,
"tag": "0031_common_deathbird", "tag": "0031_safe_multiple_man",
"breakpoints": true "breakpoints": true
} }
] ]

View File

@ -1,3 +1,4 @@
import { CivitaiModelResponse } from "@/types/civitai";
import { type InferSelectModel, relations } from "drizzle-orm"; import { type InferSelectModel, relations } from "drizzle-orm";
import { import {
boolean, boolean,
@ -331,198 +332,9 @@ export const apiKeyTable = dbSchema.table("api_keys", {
updated_at: timestamp("updated_at").defaultNow().notNull(), updated_at: timestamp("updated_at").defaultNow().notNull(),
}); });
// const civitaiModelVersion = z.object({
// id: z.number(),
// modelId: z.number(),
// name: z.string(),
// createdAt: z.string(),
// updatedAt: z.string(),
// status: z.string(),
// publishedAt: z.string(),
// trainedWords: z.array(z.string()).optional(),
// trainingStatus: z.string().optional(),
// trainingDetails: z.string().optional(),
// baseModel: z.string(),
// baseModelType: z.string(),
// earlyAccessTimeFrame: z.number(),
// description: z.string().optional(),
// vaeId: z.string().optional(),
// stats: z.object({
// downloadCount: z.number(),
// ratingCount: z.number(),
// rating: z.number(),
// }),
// files: z.array(z.object({
// id: z.number(),
// sizeKB: z.number(),
// name: z.string(),
// type: z.string(),
// metadata: z.object({
// fp: z.string(),
// size: z.string(),
// format: z.string(),
// }),
// pickleScanResult: z.string(),
// pickleScanMessage: z.string().optional(),
// virusScanResult: z.string(),
// virusScanMessage: z.string().optional(),
// scannedAt: z.string(),
// hashes: z.object({
// AutoV1: z.string(),
// AutoV2: z.string(),
// SHA256: z.string(),
// CRC32: z.string(),
// BLAKE3: z.string(),
// AutoV3: z.string(),
// }),
// downloadUrl: z.string(),
// primary: z.boolean(),
// })),
// images: z.array(z.object({
// url: z.string(),
// nsfw: z.string(),
// width: z.number(),
// height: z.number(),
// hash: z.string(),
// type: z.string(),
// metadata: z.object({
// hash: z.string(),
// size: z.number(),
// width: z.number(),
// height: z.number(),
// }),
// meta: z.any(),
// })),
// downloadUrl: z.string(),
// });
//
// const civitaiModelResponseType = z.object({
// id: z.number(),
// name: z.string(),
// description: z.string().optional(),
// type: z.string(),
// poi: z.boolean(),
// nsfw: z.boolean(),
// allowNoCredit: z.boolean(),
// allowCommercialUse: z.string(),
// allowDerivatives: z.boolean(),
// allowDifferentLicense: z.boolean(),
// stats: z.object({
// downloadCount: z.number(),
// favoriteCount: z.number(),
// commentCount: z.number(),
// ratingCount: z.number(),
// rating: z.number(),
// tippedAmountCount: z.number(),
// }),
// creator: z.object({
// username: z.string(),
// image: z.string(),
// }),
// tags: z.array(z.string()),
// modelVersions: z.array(civitaiModelVersion),
// });
export const CivitaiModel = z.object({
id: z.number(),
name: z.string(),
description: z.string(),
type: z.string(),
// poi: z.boolean(),
// nsfw: z.boolean(),
// allowNoCredit: z.boolean(),
// allowCommercialUse: z.string(),
// allowDerivatives: z.boolean(),
// allowDifferentLicense: z.boolean(),
// stats: z.object({
// downloadCount: z.number(),
// favoriteCount: z.number(),
// commentCount: z.number(),
// ratingCount: z.number(),
// rating: z.number(),
// tippedAmountCount: z.number(),
// }),
creator: z
.object({
username: z.string().nullable(),
image: z.string().nullable().default(null),
})
.nullable(),
tags: z.array(z.string()),
modelVersions: z.array(
z.object({
id: z.number(),
modelId: z.number(),
name: z.string(),
createdAt: z.string(),
updatedAt: z.string(),
status: z.string(),
publishedAt: z.string(),
trainedWords: z.array(z.unknown()),
trainingStatus: z.string().nullable(),
trainingDetails: z.string().nullable(),
baseModel: z.string(),
baseModelType: z.string().nullable(),
earlyAccessTimeFrame: z.number(),
description: z.string().nullable(),
vaeId: z.number().nullable(),
stats: z.object({
downloadCount: z.number(),
ratingCount: z.number(),
rating: z.number(),
}),
files: z.array(
z.object({
id: z.number(),
sizeKB: z.number(),
name: z.string(),
type: z.string(),
// metadata: z.object({
// fp: z.string().nullable().optional(),
// size: z.string().nullable().optional(),
// format: z.string().nullable().optional(),
// }),
// pickleScanResult: z.string(),
// pickleScanMessage: z.string(),
// virusScanResult: z.string(),
// virusScanMessage: z.string().nullable(),
// scannedAt: z.string(),
// hashes: z.object({
// AutoV1: z.string().nullable().optional(),
// AutoV2: z.string().nullable().optional(),
// SHA256: z.string().nullable().optional(),
// CRC32: z.string().nullable().optional(),
// BLAKE3: z.string().nullable().optional(),
// }),
downloadUrl: z.string(),
// primary: z.boolean().default(false),
}),
),
images: z.array(
z.object({
id: z.number(),
url: z.string(),
nsfw: z.string(),
width: z.number(),
height: z.number(),
hash: z.string(),
type: z.string(),
metadata: z.object({
hash: z.string(),
width: z.number(),
height: z.number(),
}),
meta: z.any(),
}),
),
downloadUrl: z.string(),
}),
),
});
export const resourceUpload = pgEnum("resource_upload", [ export const resourceUpload = pgEnum("resource_upload", [
"started", "started",
"failed", "error",
"succeded", "succeded",
]); ]);
@ -552,7 +364,7 @@ export const checkpointTable = dbSchema.table("checkpoints", {
civitai_url: text("civitai_url"), civitai_url: text("civitai_url"),
civitai_download_url: text("civitai_download_url"), civitai_download_url: text("civitai_download_url"),
civitai_model_response: jsonb("civitai_model_response").$type< civitai_model_response: jsonb("civitai_model_response").$type<
z.infer<typeof CivitaiModel> z.infer<typeof CivitaiModelResponse>
>(), >(),
hf_url: text("hf_url"), hf_url: text("hf_url"),
@ -563,6 +375,7 @@ export const checkpointTable = dbSchema.table("checkpoints", {
status: resourceUpload("status").notNull().default("started"), status: resourceUpload("status").notNull().default("started"),
upload_machine_id: text("upload_machine_id"), upload_machine_id: text("upload_machine_id"),
upload_type: modelUploadType("upload_type").notNull(), upload_type: modelUploadType("upload_type").notNull(),
build_log: text("build_log"),
created_at: timestamp("created_at").defaultNow().notNull(), created_at: timestamp("created_at").defaultNow().notNull(),
updated_at: timestamp("updated_at").defaultNow().notNull(), updated_at: timestamp("updated_at").defaultNow().notNull(),
@ -579,7 +392,7 @@ export const insertCivitaiCheckpointSchema = createInsertSchema(
}, },
); );
export const checkpointVolumeTable = dbSchema.table("checkpointVolumeTable", { export const checkpointVolumeTable = dbSchema.table("checkpoint_volume", {
id: uuid("id").primaryKey().defaultRandom().notNull(), id: uuid("id").primaryKey().defaultRandom().notNull(),
user_id: text("user_id") user_id: text("user_id")
.references(() => usersTable.id, { .references(() => usersTable.id, {
@ -605,8 +418,12 @@ export const checkpointRelations = relations(checkpointTable, ({ one }) => ({
export const checkpointVolumeRelations = relations( export const checkpointVolumeRelations = relations(
checkpointVolumeTable, checkpointVolumeTable,
({ many }) => ({ ({ many, one }) => ({
checkpoint: many(checkpointTable), checkpoint: many(checkpointTable),
user: one(usersTable, {
fields: [checkpointVolumeTable.user_id],
references: [usersTable.id],
}),
}), }),
); );

View File

@ -4,9 +4,8 @@ import { auth } from "@clerk/nextjs";
import { import {
checkpointTable, checkpointTable,
CheckpointType, CheckpointType,
checkpointVolumeTable, volumeTable,
CheckpointVolumeType, CheckpointVolumeType,
CivitaiModel,
} from "@/db/schema"; } from "@/db/schema";
import { withServerPromise } from "./withServerPromise"; import { withServerPromise } from "./withServerPromise";
import { redirect } from "next/navigation"; import { redirect } from "next/navigation";
@ -15,6 +14,7 @@ import type { z } from "zod";
import { headers } from "next/headers"; import { headers } from "next/headers";
import { addCivitaiCheckpointSchema } from "./addCheckpointSchema"; import { addCivitaiCheckpointSchema } from "./addCheckpointSchema";
import { and, eq, isNull } from "drizzle-orm"; import { and, eq, isNull } from "drizzle-orm";
import { CivitaiModelResponse } from "@/types/civitai";
export async function getCheckpoints() { export async function getCheckpoints() {
const { userId, orgId } = auth(); const { userId, orgId } = auth();
@ -57,17 +57,17 @@ export async function getCheckpointVolumes() {
if (!userId) throw new Error("No user id"); if (!userId) throw new Error("No user id");
const checkpointVolume = await db const checkpointVolume = await db
.select() .select()
.from(checkpointVolumeTable) .from(volumeTable)
.where( .where(
and( and(
orgId orgId
? eq(checkpointVolumeTable.org_id, orgId) ? eq(volumeTable.org_id, orgId)
// make sure org_id is null // make sure org_id is null
: and( : and(
eq(checkpointVolumeTable.user_id, userId), eq(volumeTable.user_id, userId),
isNull(checkpointVolumeTable.org_id), isNull(volumeTable.org_id),
), ),
eq(checkpointVolumeTable.disabled, false), eq(volumeTable.disabled, false),
), ),
); );
return checkpointVolume; return checkpointVolume;
@ -79,7 +79,7 @@ export async function addCheckpointVolume() {
// Insert the new volume into the checkpointVolumeTable // Insert the new volume into the checkpointVolumeTable
const insertedVolume = await db const insertedVolume = await db
.insert(checkpointVolumeTable) .insert(volumeTable)
.values({ .values({
user_id: userId, user_id: userId,
org_id: orgId, org_id: orgId,
@ -115,11 +115,10 @@ export const addCivitaiCheckpoint = withServerPromise(
const civitaiModelRes = await fetch(url) const civitaiModelRes = await fetch(url)
.then((x) => x.json()) .then((x) => x.json())
.then((a) => { .then((a) => {
console.log(a) return CivitaiModelResponse.parse(a);
return CivitaiModel.parse(a);
}); });
if (civitaiModelRes.modelVersions?.length === 0) { if (civitaiModelRes?.modelVersions?.length === 0) {
return; // no versions to download return; // no versions to download
} }
@ -155,6 +154,8 @@ export const addCivitaiCheckpoint = withServerPromise(
upload_type: "civitai", upload_type: "civitai",
civitai_id: civitaiModelRes.id.toString(), civitai_id: civitaiModelRes.id.toString(),
civitai_version_id: selectedModelVersionId, civitai_version_id: selectedModelVersionId,
civitai_url: data.civitai_url,
civitai_download_url: selectedModelVersion.downloadUrl,
civitai_model_response: civitaiModelRes, civitai_model_response: civitaiModelRes,
checkpoint_volume_id: cVolume.id, checkpoint_volume_id: cVolume.id,
}) })
@ -192,22 +193,23 @@ async function uploadCheckpoint(
body: JSON.stringify({ body: JSON.stringify({
download_url: data.civitai_url, download_url: data.civitai_url,
volume_name: v.volume_name, volume_name: v.volume_name,
volume_id: v.id,
callback_url: `${protocol}://${domain}/api/volume-updated`, callback_url: `${protocol}://${domain}/api/volume-updated`,
}), }),
}, },
); );
if (!result.ok) { if (!result.ok) {
// const error_log = await result.text(); const error_log = await result.text();
// await db await db
// .update(checkpointTable) .update(checkpointTable)
// .set({ .set({
// ...data, ...data,
// status: "error", status: "error",
// build_log: error_log, build_log: error_log,
// }) })
// .where(eq(machinesTable.id, b.id)); .where(eq(checkpointTable.id, b.id));
// throw new Error(`Error: ${result.statusText} ${error_log}`); throw new Error(`Error: ${result.statusText} ${error_log}`);
} else { } else {
// setting the build machine id // setting the build machine id
const json = await result.json(); const json = await result.json();
@ -215,8 +217,8 @@ async function uploadCheckpoint(
.update(checkpointTable) .update(checkpointTable)
.set({ .set({
...data, ...data,
// build_machine_instance_id: json.build_machine_instance_id, upload_machine_id: json.build_machine_instance_id,
}) })
.where(eq(machinesTable.id, b.id)); .where(eq(checkpointTable.id, b.id));
} }
} }

View File

@ -1,8 +1,102 @@
import { z } from 'zod'; import { z } from 'zod';
// from chatgpt https://chat.openai.com/share/4985d20b-30b1-4a28-87f6-6ebf84a1040e
export const creatorSchema = z.object({ export const creatorSchema = z.object({
username: z.string(), username: z.string().optional(),
image: z.string(), image: z.string().url().optional(),
});
export const fileMetadataSchema = z.object({
fp: z.string().optional(),
size: z.string().optional(),
format: z.string().optional(),
});
export const fileSchema = z.object({
id: z.number(),
sizeKB: z.number().optional(),
name: z.string(),
type: z.string().optional(),
metadata: fileMetadataSchema.optional(),
pickleScanResult: z.string().optional(),
pickleScanMessage: z.string().nullable(),
virusScanResult: z.string().optional(),
virusScanMessage: z.string().nullable(),
scannedAt: z.string().optional(),
hashes: z.record(z.string()).optional(),
downloadUrl: z.string().url(),
primary: z.boolean().optional().optional(),
});
export const imageMetadataSchema = z.object({
hash: z.string(),
width: z.number(),
height: z.number(),
});
export const imageMetaSchema = z.object({
ENSD: z.string().optional(),
Size: z.string().optional(),
seed: z.number().optional(),
Model: z.string().optional(),
steps: z.number().optional(),
hashes: z.record(z.string()).optional(),
prompt: z.string().optional(),
sampler: z.string().optional(),
cfgScale: z.number().optional(),
ClipSkip: z.number().optional(),
resources: z.array(
z.object({
hash: z.string().optional(),
name: z.string(),
type: z.string(),
weight: z.number().optional(),
})
).optional(),
ModelHash: z.string().optional(),
HiresSteps: z.string().optional(),
HiresUpscale: z.string().optional(),
HiresUpscaler: z.string().optional(),
negativePrompt: z.string(),
DenoisingStrength: z.number().optional(),
});
export const imageSchema = z.object({
url: z.string().url().optional(),
nsfw: z.enum(["None", "Soft", "Mature"]).optional(),
width: z.number().optional(),
height: z.number().optional(),
hash: z.string().optional(),
type: z.string().optional(),
metadata: imageMetadataSchema.optional(),
meta: imageMetaSchema.optional(),
});
export const modelVersionSchema = z.object({
id: z.number(),
modelId: z.number(),
name: z.string(),
createdAt: z.string().optional(),
updatedAt: z.string().optional(),
status: z.enum(["Published", "Unpublished"]).optional(),
publishedAt: z.string().optional(),
trainedWords: z.array(z.string()).nullable(),
trainingStatus: z.string().nullable(),
trainingDetails: z.string().nullable(),
baseModel: z.string().optional(),
baseModelType: z.string().optional(),
earlyAccessTimeFrame: z.number().optional(),
description: z.string().nullable(),
vaeId: z.string().nullable(),
stats: z.object({
downloadCount: z.number(),
ratingCount: z.number(),
rating: z.number(),
}).optional(),
files: z.array(fileSchema),
images: z.array(imageSchema),
downloadUrl: z.string().url(),
}); });
export const statsSchema = z.object({ export const statsSchema = z.object({
@ -14,42 +108,19 @@ export const statsSchema = z.object({
tippedAmountCount: z.number(), tippedAmountCount: z.number(),
}); });
export const modelVersionSchema = z.object({ export const CivitaiModelResponse = z.object({
id: z.number(), id: z.number(),
modelId: z.number(), name: z.string().optional(),
name: z.string(),
createdAt: z.string(),
updatedAt: z.string(),
status: z.string(),
publishedAt: z.string(),
trainedWords: z.array(z.any()), // Replace with more specific type if known
trainingStatus: z.any().optional(),
trainingDetails: z.any().optional(),
baseModel: z.string(),
baseModelType: z.string(),
earlyAccessTimeFrame: z.number(),
description: z.string().optional(), description: z.string().optional(),
vaeId: z.any().optional(), // Replace with more specific type if known type: z.enum(["Checkpoint", "Lora"]),
stats: statsSchema.optional(), // If stats structure is known, replace with specific type poi: z.boolean().optional(),
files: z.array(z.any()), // Replace with more specific type if known nsfw: z.boolean().optional(),
images: z.array(z.any()), // Replace with more specific type if known allowNoCredit: z.boolean().optional(),
downloadUrl: z.string(), allowCommercialUse: z.enum(["Rent"]).optional(),
}); allowDerivatives: z.boolean().optional(),
allowDifferentLicense: z.boolean().optional(),
export const CivitaiModel = z.object({ stats: statsSchema.optional(),
id: z.number(), creator: creatorSchema.optional(),
name: z.string(), tags: z.array(z.string()).optional(),
description: z.string(),
type: z.string(),
poi: z.boolean(),
nsfw: z.boolean(),
allowNoCredit: z.boolean(),
allowCommercialUse: z.string(),
allowDerivatives: z.boolean(),
allowDifferentLicense: z.boolean(),
stats: statsSchema,
creator: creatorSchema,
tags: z.array(z.string()),
modelVersions: z.array(modelVersionSchema), modelVersions: z.array(modelVersionSchema),
}); });