Nickkao/volume improvemnts v3 (#4)

* fix: attempt fixing timeout

* be validation work

* arbitrary model input, BE validation, error record creation with error logs during potential failure points

* remove unused type

---------

Co-authored-by: bennykok <itechbenny@gmail.com>
This commit is contained in:
Nick Kao 2024-01-27 19:09:00 -08:00 committed by GitHub
parent 852d889397
commit 42aaf1acb9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 2810 additions and 143 deletions

View File

@ -42,7 +42,7 @@ def download_model(volume_name, download_config):
volume_base_path = vol_name_to_path[volume_name] volume_base_path = vol_name_to_path[volume_name]
model_store_path = os.path.join(volume_base_path, folder_path) model_store_path = os.path.join(volume_base_path, folder_path)
modified_download_url = download_url + ("&" if "?" in download_url else "?") + "token=" + civitai_key modified_download_url = download_url + ("&" if "?" in download_url else "?") + "token=" + civitai_key # civitai requires auth
print('downloading', modified_download_url) print('downloading', modified_download_url)
subprocess.run(["wget", modified_download_url , "--content-disposition", "-P", model_store_path]) subprocess.run(["wget", modified_download_url , "--content-disposition", "-P", model_store_path])

View File

@ -0,0 +1 @@
ALTER TYPE "model_upload_type" ADD VALUE 'download_url';

View File

@ -0,0 +1 @@
ALTER TYPE "model_upload_type" ADD VALUE 'download-url';

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -323,6 +323,20 @@
"when": 1706336448134, "when": 1706336448134,
"tag": "0045_careful_cerise", "tag": "0045_careful_cerise",
"breakpoints": true "breakpoints": true
},
{
"idx": 46,
"version": "5",
"when": 1706383154642,
"tag": "0046_complex_mentallo",
"breakpoints": true
},
{
"idx": 47,
"version": "5",
"when": 1706384528895,
"tag": "0047_gifted_starbolt",
"breakpoints": true
} }
] ]
} }

View File

@ -32,8 +32,8 @@ import {
} from "@tanstack/react-table"; } from "@tanstack/react-table";
import { ArrowUpDown } from "lucide-react"; import { ArrowUpDown } from "lucide-react";
import * as React from "react"; import * as React from "react";
import { addCivitaiModel } from "@/server/curdModel"; import { addModel } from "@/server/curdModel";
import { addCivitaiModelSchema } from "@/server/addCivitaiModelSchema"; import { downloadUrlModelSchema } from "@/server/addCivitaiModelSchema";
import { modelEnumType } from "@/db/schema"; import { modelEnumType } from "@/db/schema";
export type ModelItemList = NonNullable< export type ModelItemList = NonNullable<
@ -89,9 +89,7 @@ export const columns: ColumnDef<ModelItemList>[] = [
{row.original.model_name} {row.original.model_name}
</span> </span>
{model.is_public {model.is_public ? <></> : <Badge variant="orange">Private</Badge>}
? <></>
: <Badge variant="orange">Private</Badge>}
</> </>
); );
}, },
@ -298,16 +296,14 @@ export function ModelList({ data }: { data: ModelItemList[] }) {
<InsertModal <InsertModal
dialogClassName="sm:max-w-[600px]" dialogClassName="sm:max-w-[600px]"
disabled={ disabled={
false false // TODO: limitations based on plan
// TODO: limitations based on plan
} }
tooltip={"Add models using their civitai url!"} title="Add a Model"
title="Add a Civitai Model" description="using a link to a model"
description="Pick a model from civitai" serverAction={addModel}
serverAction={addCivitaiModel} formSchema={downloadUrlModelSchema}
formSchema={addCivitaiModelSchema}
fieldConfig={{ fieldConfig={{
civitai_url: { url: {
fieldType: "fallback", fieldType: "fallback",
inputProps: { required: true }, inputProps: { required: true },
description: ( description: (
@ -320,7 +316,16 @@ export function ModelList({ data }: { data: ModelItemList[] }) {
> >
civitai.com civitai.com
</a>{" "} </a>{" "}
and place it's url here or a url we can download a model from
</>
),
},
model_type: {
fieldType: "select",
inputProps: { required: true },
description: (
<>
We'll figure this out if you pick a civitai model
</> </>
), ),
}, },

View File

@ -1,86 +0,0 @@
// NOTE: this is WIP for doing client side validation for civitai model downloading
import type { AutoFormInputComponentProps } from "../ui/auto-form/types";
import { FormControl, FormItem, FormLabel } from "../ui/form";
import { LoadingIcon } from "@/components/LoadingIcon";
import * as React from "react";
import AutoFormInput from "../ui/auto-form/fields/input";
import { useDebouncedCallback } from "use-debounce";
import { CivitaiModelResponse } from "@/types/civitai";
import { z } from "zod";
import { insertCivitaiModelSchema } from "@/db/schema";
function getUrl(civitai_url: string) {
// expect to be a URL to be https://civitai.com/models/36520
// possiblity with slugged name and query-param modelVersionId
const baseUrl = "https://civitai.com/api/v1/models/";
const url = new URL(civitai_url);
const pathSegments = url.pathname.split("/");
const modelId = pathSegments[pathSegments.indexOf("models") + 1];
const modelVersionId = url.searchParams.get("modelVersionId");
return { url: baseUrl + modelId, modelVersionId };
}
export default function AutoFormCheckpointInput(
props: AutoFormInputComponentProps
) {
const [loading, setLoading] = React.useState(false);
const [modelRes, setModelRes] =
React.useState<z.infer<typeof CivitaiModelResponse>>();
const [modelVersionid, setModelVersionId] = React.useState<string | null>();
const { label, isRequired, fieldProps, zodItem, fieldConfigItem } = props;
const handleSearch = useDebouncedCallback((search) => {
const validationResult =
insertCivitaiModelSchema.shape.civitai_url.safeParse(search);
if (!validationResult.success) {
console.error(validationResult.error);
// Optionally set an error state here
return;
}
setLoading(true);
const controller = new AbortController();
const { url, modelVersionId: versionId } = getUrl(search);
setModelVersionId(versionId);
fetch(url, {
signal: controller.signal,
})
.then((x) => x.json())
.then((a) => {
const res = CivitaiModelResponse.parse(a);
console.log(a);
console.log(res);
setModelRes(res);
setLoading(false);
});
return () => {
controller.abort();
setLoading(false);
};
}, 300);
const modifiedField = {
...fieldProps,
// onChange: (event: React.ChangeEvent<HTMLInputElement>) => {
// handleSearch(event.target.value);
// },
};
return (
<FormItem>
{fieldConfigItem.inputProps?.showLabel && (
<FormLabel>
{label}
{isRequired && <span className="text-destructive">*</span>}
</FormLabel>
)}
<FormControl>
<AutoFormInput {...props} fieldProps={modifiedField} />
</FormControl>
</FormItem>
);
}

View File

@ -380,12 +380,18 @@ export const resourceUpload = pgEnum("resource_upload", [
export const modelUploadType = pgEnum("model_upload_type", [ export const modelUploadType = pgEnum("model_upload_type", [
"civitai", "civitai",
"huggingface", "download-url",
"huggingface", // remove?
"other", "other",
]); ]);
// https://www.answeroverflow.com/m/1125106227387584552 // https://www.answeroverflow.com/m/1125106227387584552
const modelTypes = ["checkpoint", "lora", "embedding", "vae"] as const; export const modelTypes = [
"checkpoint",
"lora",
"embedding",
"vae",
] as const
export const modelType = pgEnum("model_type", modelTypes); export const modelType = pgEnum("model_type", modelTypes);
export type modelEnumType = (typeof modelTypes)[number]; export type modelEnumType = (typeof modelTypes)[number];
@ -399,8 +405,7 @@ export const modelTable = dbSchema.table("models", {
.notNull() .notNull()
.references(() => userVolume.id, { .references(() => userVolume.id, {
onDelete: "cascade", onDelete: "cascade",
}) }),
.notNull(),
model_name: text("model_name"), model_name: text("model_name"),
folder_path: text("folder_path"), // in volume folder_path: text("folder_path"), // in volume
@ -413,8 +418,10 @@ export const modelTable = dbSchema.table("models", {
z.infer<typeof CivitaiModelResponse> z.infer<typeof CivitaiModelResponse>
>(), >(),
// for our own storage
hf_url: text("hf_url"), hf_url: text("hf_url"),
s3_url: text("s3_url"), s3_url: text("s3_url"),
user_url: text("client_url"), user_url: text("client_url"),
is_public: boolean("is_public").notNull().default(true), is_public: boolean("is_public").notNull().default(true),
@ -484,16 +491,6 @@ export const subscriptionStatusTable = dbSchema.table("subscription_status", {
updated_at: timestamp("updated_at").defaultNow().notNull(), updated_at: timestamp("updated_at").defaultNow().notNull(),
}); });
export const insertCivitaiModelSchema = createInsertSchema(modelTable, {
civitai_url: (schema) =>
schema.civitai_url
.trim()
.url({ message: "URL required" })
.includes("civitai.com/models", {
message: "civitai.com/models link required",
}),
});
export type UserType = InferSelectModel<typeof usersTable>; export type UserType = InferSelectModel<typeof usersTable>;
export type WorkflowType = InferSelectModel<typeof workflowTable>; export type WorkflowType = InferSelectModel<typeof workflowTable>;
export type MachineType = InferSelectModel<typeof machinesTable>; export type MachineType = InferSelectModel<typeof machinesTable>;

View File

@ -1,5 +1,8 @@
import { insertCivitaiModelSchema } from "@/db/schema"; import { z } from "zod";
import { modelTypes } from "@/db/schema";
export const addCivitaiModelSchema = insertCivitaiModelSchema.pick({ export const downloadUrlModelSchema = z.object({
civitai_url: true, url: z.string().url(),
model_type: z.enum(modelTypes).default("checkpoint")
}); });

View File

@ -2,6 +2,7 @@
import { auth } from "@clerk/nextjs"; import { auth } from "@clerk/nextjs";
import { import {
modelEnumType,
modelTable, modelTable,
ModelType, ModelType,
userVolume, userVolume,
@ -11,7 +12,7 @@ import { withServerPromise } from "./withServerPromise";
import { db } from "@/db/db"; import { db } from "@/db/db";
import type { z } from "zod"; import type { z } from "zod";
import { headers } from "next/headers"; import { headers } from "next/headers";
import { addCivitaiModelSchema } from "./addCivitaiModelSchema"; import { downloadUrlModelSchema } from "./addCivitaiModelSchema";
import { and, eq, isNull } from "drizzle-orm"; import { and, eq, isNull } from "drizzle-orm";
import { CivitaiModelResponse, getModelTypeDetails } from "@/types/civitai"; import { CivitaiModelResponse, getModelTypeDetails } from "@/types/civitai";
@ -90,7 +91,7 @@ export async function addModelVolume() {
.values({ .values({
user_id: userId, user_id: userId,
org_id: orgId, org_id: orgId,
volume_name: `models_${orgId ? orgId: userId}`, // if orgid is avalible use as part of the volume name volume_name: `models_${orgId ? orgId : userId}`, // if orgid is avalible use as part of the volume name
disabled: false, disabled: false,
}) })
.returning(); .returning();
@ -109,14 +110,150 @@ function getUrl(civitai_url: string) {
return { url: baseUrl + modelId, modelVersionId }; return { url: baseUrl + modelId, modelVersionId };
} }
// Helper function to make a HEAD request and follow redirects
async function fetchFinalUrl(
url: string,
): Promise<{ finalUrl: string; dispositionFilename?: string }> {
console.log("fetching");
const response = await fetch(url, { method: "HEAD", redirect: "follow" });
if (!response.ok) {
console.log("response not ok");
throw new Error(`Request failed with status ${response.status}`);
}
const contentDisposition = response.headers.get("content-disposition");
let filename;
if (contentDisposition) {
const matches = contentDisposition.match(
/filename\*?=['"]?(?:UTF-8'')?([^;'"\n]*)['"]?;?/i,
);
filename = matches && matches[1]
? decodeURIComponent(matches[1])
: undefined;
}
return { finalUrl: response.url, dispositionFilename: filename };
}
// The main function for validation
export const addModel = withServerPromise(
async (data: z.infer<typeof downloadUrlModelSchema>) => {
const { url } = data;
if (url.includes("civitai.com/models/")) {
// Make a HEAD request to check for 200 OK
const response = await fetch(url, { method: "HEAD" });
if (!response.ok) {
createModelErrorRecord(
url,
`civitai gave non-ok response`,
"civitai",
data.model_type,
);
}
addCivitaiModel(data);
} else {
const { finalUrl, dispositionFilename } = await fetchFinalUrl(url);
console.log("finished fetching");
console.log(finalUrl, dispositionFilename);
if (!dispositionFilename) {
console.log("no file name");
createModelErrorRecord(
url,
`Could not find a filename from resolved Url: ${finalUrl}`,
"download-url",
data.model_type,
);
return;
}
const validExtensions = [".ckpt", ".pt", ".bin", ".pth", ".safetensors"];
const extension = dispositionFilename.slice(
dispositionFilename.lastIndexOf("."),
);
if (!validExtensions.includes(extension)) {
console.log("invalid extension");
createModelErrorRecord(
url,
`file ext ${extension} is invalid. Valid extensions: ${validExtensions}`,
"download-url",
data.model_type,
);
}
addModelDownloadUrl(data, dispositionFilename);
}
},
);
export const addModelDownloadUrl = withServerPromise(
async (data: z.infer<typeof downloadUrlModelSchema>, filename: string) => {
console.log("adding model download");
const { userId, orgId } = auth();
if (!userId) return { error: "No user id" };
const volumes = await retrieveModelVolumes();
const a = await db
.insert(modelTable)
.values({
user_id: userId,
org_id: orgId,
upload_type: "download-url",
model_name: filename,
user_url: data.url,
user_volume_id: volumes[0].id,
model_type: data.model_type,
})
.returning();
const b = a[0];
console.log("download url about to upload");
await uploadModel(data, b, volumes[0]);
},
);
export const getCivitaiModelRes = async (civitaiUrl: string) => {
const { url, modelVersionId } = getUrl(civitaiUrl);
const civitaiModelRes = await fetch(url)
.then((x) => x.json())
.then((a) => {
return CivitaiModelResponse.parse(a);
});
return { civitaiModelRes, url, modelVersionId };
};
const createModelErrorRecord = async (
url: string,
errorMessage: string,
upload_type: "civitai" | "download-url",
model_type: modelEnumType,
) => {
const { userId, orgId } = auth();
if (!userId) return { error: "No user id" };
const volumes = await retrieveModelVolumes();
const a = await db
.insert(modelTable)
.values({
user_id: userId,
org_id: orgId,
user_volume_id: volumes[0].id,
upload_type: "civitai",
model_type,
civitai_url: upload_type === "civitai" ? url : undefined,
user_url: upload_type === "download-url" ? url : undefined,
error_log: errorMessage,
status: "failed",
})
.returning();
return a;
};
export const addCivitaiModel = withServerPromise( export const addCivitaiModel = withServerPromise(
async (data: z.infer<typeof addCivitaiModelSchema>) => { async (data: z.infer<typeof downloadUrlModelSchema>) => {
const { userId, orgId } = auth(); const { userId, orgId } = auth();
if (!data.civitai_url) return { error: "no civitai_url" };
if (!userId) return { error: "No user id" }; if (!userId) return { error: "No user id" };
const { url, modelVersionId } = getUrl(data?.civitai_url); const { url, modelVersionId } = getUrl(data.url);
const civitaiModelRes = await fetch(url) const civitaiModelRes = await fetch(url)
.then((x) => x.json()) .then((x) => x.json())
.then((a) => { .then((a) => {
@ -142,18 +279,17 @@ export const addCivitaiModel = withServerPromise(
selectedModelVersionId = selectedModelVersion?.id.toString(); selectedModelVersionId = selectedModelVersion?.id.toString();
} }
const userVolume = await getModelVolumes(); const volumes = await retrieveModelVolumes();
let cVolume;
if (userVolume.length === 0) {
const volume = await addModelVolume();
cVolume = volume[0];
} else {
cVolume = userVolume[0];
}
const model_type = getModelTypeDetails(civitaiModelRes.type); const model_type = getModelTypeDetails(civitaiModelRes.type);
if (!model_type) { if (!model_type) {
return createModelErrorRecord(
url,
`Civitai model type ${civitaiModelRes.type} is not currently supported`,
"civitai",
data.model_type,
);
return;
} }
const a = await db const a = await db
@ -165,18 +301,17 @@ export const addCivitaiModel = withServerPromise(
model_name: selectedModelVersion.files[0].name, model_name: selectedModelVersion.files[0].name,
civitai_id: civitaiModelRes.id.toString(), civitai_id: civitaiModelRes.id.toString(),
civitai_version_id: selectedModelVersionId, civitai_version_id: selectedModelVersionId,
civitai_url: data.civitai_url, civitai_url: data.url, // TODO: need to confirm
civitai_download_url: selectedModelVersion.files[0].downloadUrl, civitai_download_url: selectedModelVersion.files[0].downloadUrl,
civitai_model_response: civitaiModelRes, civitai_model_response: civitaiModelRes,
user_volume_id: cVolume.id, user_volume_id: volumes[0].id,
model_type, model_type,
updated_at: new Date(),
}) })
.returning(); .returning();
const b = a[0]; const b = a[0];
await uploadModel(data, b, cVolume); await uploadModel(data, b, volumes[0]);
}, },
); );
@ -216,7 +351,7 @@ export const addCivitaiModel = withServerPromise(
// ); // );
async function uploadModel( async function uploadModel(
data: z.infer<typeof addCivitaiModelSchema>, data: z.infer<typeof downloadUrlModelSchema>,
c: ModelType, c: ModelType,
v: UserVolumeType, v: UserVolumeType,
) { ) {
@ -238,7 +373,9 @@ async function uploadModel(
"Content-Type": "application/json", "Content-Type": "application/json",
}, },
body: JSON.stringify({ body: JSON.stringify({
download_url: c.civitai_download_url, download_url: c.upload_type === "civitai"
? c.civitai_download_url
: c.user_url,
volume_name: v.volume_name, volume_name: v.volume_name,
volume_id: v.id, volume_id: v.id,
model_id: c.id, model_id: c.id,
@ -253,7 +390,6 @@ async function uploadModel(
await db await db
.update(modelTable) .update(modelTable)
.set({ .set({
...data,
status: "failed", status: "failed",
error_log: error_log, error_log: error_log,
}) })