Nickkao/volume improvemnts v3 (#4)
* fix: attempt fixing timeout * be validation work * arbitrary model input, BE validation, error record creation with error logs during potential failure points * remove unused type --------- Co-authored-by: bennykok <itechbenny@gmail.com>
This commit is contained in:
parent
852d889397
commit
42aaf1acb9
@ -42,7 +42,7 @@ def download_model(volume_name, download_config):
|
||||
|
||||
volume_base_path = vol_name_to_path[volume_name]
|
||||
model_store_path = os.path.join(volume_base_path, folder_path)
|
||||
modified_download_url = download_url + ("&" if "?" in download_url else "?") + "token=" + civitai_key
|
||||
modified_download_url = download_url + ("&" if "?" in download_url else "?") + "token=" + civitai_key # civitai requires auth
|
||||
print('downloading', modified_download_url)
|
||||
|
||||
subprocess.run(["wget", modified_download_url , "--content-disposition", "-P", model_store_path])
|
||||
|
1
web/drizzle/0046_complex_mentallo.sql
Normal file
1
web/drizzle/0046_complex_mentallo.sql
Normal file
@ -0,0 +1 @@
|
||||
ALTER TYPE "model_upload_type" ADD VALUE 'download_url';
|
1
web/drizzle/0047_gifted_starbolt.sql
Normal file
1
web/drizzle/0047_gifted_starbolt.sql
Normal file
@ -0,0 +1 @@
|
||||
ALTER TYPE "model_upload_type" ADD VALUE 'download-url';
|
1298
web/drizzle/meta/0046_snapshot.json
Normal file
1298
web/drizzle/meta/0046_snapshot.json
Normal file
File diff suppressed because it is too large
Load Diff
1298
web/drizzle/meta/0047_snapshot.json
Normal file
1298
web/drizzle/meta/0047_snapshot.json
Normal file
File diff suppressed because it is too large
Load Diff
@ -323,6 +323,20 @@
|
||||
"when": 1706336448134,
|
||||
"tag": "0045_careful_cerise",
|
||||
"breakpoints": true
|
||||
},
|
||||
{
|
||||
"idx": 46,
|
||||
"version": "5",
|
||||
"when": 1706383154642,
|
||||
"tag": "0046_complex_mentallo",
|
||||
"breakpoints": true
|
||||
},
|
||||
{
|
||||
"idx": 47,
|
||||
"version": "5",
|
||||
"when": 1706384528895,
|
||||
"tag": "0047_gifted_starbolt",
|
||||
"breakpoints": true
|
||||
}
|
||||
]
|
||||
}
|
@ -32,8 +32,8 @@ import {
|
||||
} from "@tanstack/react-table";
|
||||
import { ArrowUpDown } from "lucide-react";
|
||||
import * as React from "react";
|
||||
import { addCivitaiModel } from "@/server/curdModel";
|
||||
import { addCivitaiModelSchema } from "@/server/addCivitaiModelSchema";
|
||||
import { addModel } from "@/server/curdModel";
|
||||
import { downloadUrlModelSchema } from "@/server/addCivitaiModelSchema";
|
||||
import { modelEnumType } from "@/db/schema";
|
||||
|
||||
export type ModelItemList = NonNullable<
|
||||
@ -89,9 +89,7 @@ export const columns: ColumnDef<ModelItemList>[] = [
|
||||
{row.original.model_name}
|
||||
</span>
|
||||
|
||||
{model.is_public
|
||||
? <></>
|
||||
: <Badge variant="orange">Private</Badge>}
|
||||
{model.is_public ? <></> : <Badge variant="orange">Private</Badge>}
|
||||
</>
|
||||
);
|
||||
},
|
||||
@ -298,16 +296,14 @@ export function ModelList({ data }: { data: ModelItemList[] }) {
|
||||
<InsertModal
|
||||
dialogClassName="sm:max-w-[600px]"
|
||||
disabled={
|
||||
false
|
||||
// TODO: limitations based on plan
|
||||
false // TODO: limitations based on plan
|
||||
}
|
||||
tooltip={"Add models using their civitai url!"}
|
||||
title="Add a Civitai Model"
|
||||
description="Pick a model from civitai"
|
||||
serverAction={addCivitaiModel}
|
||||
formSchema={addCivitaiModelSchema}
|
||||
title="Add a Model"
|
||||
description="using a link to a model"
|
||||
serverAction={addModel}
|
||||
formSchema={downloadUrlModelSchema}
|
||||
fieldConfig={{
|
||||
civitai_url: {
|
||||
url: {
|
||||
fieldType: "fallback",
|
||||
inputProps: { required: true },
|
||||
description: (
|
||||
@ -320,10 +316,19 @@ export function ModelList({ data }: { data: ModelItemList[] }) {
|
||||
>
|
||||
civitai.com
|
||||
</a>{" "}
|
||||
and place it's url here
|
||||
or a url we can download a model from
|
||||
</>
|
||||
),
|
||||
},
|
||||
model_type: {
|
||||
fieldType: "select",
|
||||
inputProps: { required: true },
|
||||
description: (
|
||||
<>
|
||||
We'll figure this out if you pick a civitai model
|
||||
</>
|
||||
),
|
||||
},
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
|
@ -1,86 +0,0 @@
|
||||
// NOTE: this is WIP for doing client side validation for civitai model downloading
|
||||
import type { AutoFormInputComponentProps } from "../ui/auto-form/types";
|
||||
import { FormControl, FormItem, FormLabel } from "../ui/form";
|
||||
import { LoadingIcon } from "@/components/LoadingIcon";
|
||||
import * as React from "react";
|
||||
import AutoFormInput from "../ui/auto-form/fields/input";
|
||||
import { useDebouncedCallback } from "use-debounce";
|
||||
import { CivitaiModelResponse } from "@/types/civitai";
|
||||
import { z } from "zod";
|
||||
import { insertCivitaiModelSchema } from "@/db/schema";
|
||||
|
||||
function getUrl(civitai_url: string) {
|
||||
// expect to be a URL to be https://civitai.com/models/36520
|
||||
// possiblity with slugged name and query-param modelVersionId
|
||||
|
||||
const baseUrl = "https://civitai.com/api/v1/models/";
|
||||
const url = new URL(civitai_url);
|
||||
const pathSegments = url.pathname.split("/");
|
||||
const modelId = pathSegments[pathSegments.indexOf("models") + 1];
|
||||
const modelVersionId = url.searchParams.get("modelVersionId");
|
||||
|
||||
return { url: baseUrl + modelId, modelVersionId };
|
||||
}
|
||||
|
||||
export default function AutoFormCheckpointInput(
|
||||
props: AutoFormInputComponentProps
|
||||
) {
|
||||
const [loading, setLoading] = React.useState(false);
|
||||
const [modelRes, setModelRes] =
|
||||
React.useState<z.infer<typeof CivitaiModelResponse>>();
|
||||
const [modelVersionid, setModelVersionId] = React.useState<string | null>();
|
||||
const { label, isRequired, fieldProps, zodItem, fieldConfigItem } = props;
|
||||
|
||||
const handleSearch = useDebouncedCallback((search) => {
|
||||
const validationResult =
|
||||
insertCivitaiModelSchema.shape.civitai_url.safeParse(search);
|
||||
if (!validationResult.success) {
|
||||
console.error(validationResult.error);
|
||||
// Optionally set an error state here
|
||||
return;
|
||||
}
|
||||
|
||||
setLoading(true);
|
||||
|
||||
const controller = new AbortController();
|
||||
const { url, modelVersionId: versionId } = getUrl(search);
|
||||
setModelVersionId(versionId);
|
||||
fetch(url, {
|
||||
signal: controller.signal,
|
||||
})
|
||||
.then((x) => x.json())
|
||||
.then((a) => {
|
||||
const res = CivitaiModelResponse.parse(a);
|
||||
console.log(a);
|
||||
console.log(res);
|
||||
setModelRes(res);
|
||||
setLoading(false);
|
||||
});
|
||||
|
||||
return () => {
|
||||
controller.abort();
|
||||
setLoading(false);
|
||||
};
|
||||
}, 300);
|
||||
|
||||
const modifiedField = {
|
||||
...fieldProps,
|
||||
// onChange: (event: React.ChangeEvent<HTMLInputElement>) => {
|
||||
// handleSearch(event.target.value);
|
||||
// },
|
||||
};
|
||||
|
||||
return (
|
||||
<FormItem>
|
||||
{fieldConfigItem.inputProps?.showLabel && (
|
||||
<FormLabel>
|
||||
{label}
|
||||
{isRequired && <span className="text-destructive">*</span>}
|
||||
</FormLabel>
|
||||
)}
|
||||
<FormControl>
|
||||
<AutoFormInput {...props} fieldProps={modifiedField} />
|
||||
</FormControl>
|
||||
</FormItem>
|
||||
);
|
||||
}
|
@ -380,12 +380,18 @@ export const resourceUpload = pgEnum("resource_upload", [
|
||||
|
||||
export const modelUploadType = pgEnum("model_upload_type", [
|
||||
"civitai",
|
||||
"huggingface",
|
||||
"download-url",
|
||||
"huggingface", // remove?
|
||||
"other",
|
||||
]);
|
||||
|
||||
// https://www.answeroverflow.com/m/1125106227387584552
|
||||
const modelTypes = ["checkpoint", "lora", "embedding", "vae"] as const;
|
||||
// https://www.answeroverflow.com/m/1125106227387584552
|
||||
export const modelTypes = [
|
||||
"checkpoint",
|
||||
"lora",
|
||||
"embedding",
|
||||
"vae",
|
||||
] as const
|
||||
export const modelType = pgEnum("model_type", modelTypes);
|
||||
export type modelEnumType = (typeof modelTypes)[number];
|
||||
|
||||
@ -399,8 +405,7 @@ export const modelTable = dbSchema.table("models", {
|
||||
.notNull()
|
||||
.references(() => userVolume.id, {
|
||||
onDelete: "cascade",
|
||||
})
|
||||
.notNull(),
|
||||
}),
|
||||
|
||||
model_name: text("model_name"),
|
||||
folder_path: text("folder_path"), // in volume
|
||||
@ -413,8 +418,10 @@ export const modelTable = dbSchema.table("models", {
|
||||
z.infer<typeof CivitaiModelResponse>
|
||||
>(),
|
||||
|
||||
// for our own storage
|
||||
hf_url: text("hf_url"),
|
||||
s3_url: text("s3_url"),
|
||||
|
||||
user_url: text("client_url"),
|
||||
|
||||
is_public: boolean("is_public").notNull().default(true),
|
||||
@ -484,16 +491,6 @@ export const subscriptionStatusTable = dbSchema.table("subscription_status", {
|
||||
updated_at: timestamp("updated_at").defaultNow().notNull(),
|
||||
});
|
||||
|
||||
export const insertCivitaiModelSchema = createInsertSchema(modelTable, {
|
||||
civitai_url: (schema) =>
|
||||
schema.civitai_url
|
||||
.trim()
|
||||
.url({ message: "URL required" })
|
||||
.includes("civitai.com/models", {
|
||||
message: "civitai.com/models link required",
|
||||
}),
|
||||
});
|
||||
|
||||
export type UserType = InferSelectModel<typeof usersTable>;
|
||||
export type WorkflowType = InferSelectModel<typeof workflowTable>;
|
||||
export type MachineType = InferSelectModel<typeof machinesTable>;
|
||||
|
@ -1,5 +1,8 @@
|
||||
import { insertCivitaiModelSchema } from "@/db/schema";
|
||||
import { z } from "zod";
|
||||
import { modelTypes } from "@/db/schema";
|
||||
|
||||
export const addCivitaiModelSchema = insertCivitaiModelSchema.pick({
|
||||
civitai_url: true,
|
||||
export const downloadUrlModelSchema = z.object({
|
||||
url: z.string().url(),
|
||||
model_type: z.enum(modelTypes).default("checkpoint")
|
||||
});
|
||||
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
import { auth } from "@clerk/nextjs";
|
||||
import {
|
||||
modelEnumType,
|
||||
modelTable,
|
||||
ModelType,
|
||||
userVolume,
|
||||
@ -11,7 +12,7 @@ import { withServerPromise } from "./withServerPromise";
|
||||
import { db } from "@/db/db";
|
||||
import type { z } from "zod";
|
||||
import { headers } from "next/headers";
|
||||
import { addCivitaiModelSchema } from "./addCivitaiModelSchema";
|
||||
import { downloadUrlModelSchema } from "./addCivitaiModelSchema";
|
||||
import { and, eq, isNull } from "drizzle-orm";
|
||||
import { CivitaiModelResponse, getModelTypeDetails } from "@/types/civitai";
|
||||
|
||||
@ -90,10 +91,10 @@ export async function addModelVolume() {
|
||||
.values({
|
||||
user_id: userId,
|
||||
org_id: orgId,
|
||||
volume_name: `models_${orgId ? orgId: userId}`, // if orgid is avalible use as part of the volume name
|
||||
disabled: false,
|
||||
volume_name: `models_${orgId ? orgId : userId}`, // if orgid is avalible use as part of the volume name
|
||||
disabled: false,
|
||||
})
|
||||
.returning();
|
||||
.returning();
|
||||
return insertedVolume;
|
||||
}
|
||||
|
||||
@ -109,14 +110,150 @@ function getUrl(civitai_url: string) {
|
||||
return { url: baseUrl + modelId, modelVersionId };
|
||||
}
|
||||
|
||||
// Helper function to make a HEAD request and follow redirects
|
||||
async function fetchFinalUrl(
|
||||
url: string,
|
||||
): Promise<{ finalUrl: string; dispositionFilename?: string }> {
|
||||
console.log("fetching");
|
||||
const response = await fetch(url, { method: "HEAD", redirect: "follow" });
|
||||
if (!response.ok) {
|
||||
console.log("response not ok");
|
||||
throw new Error(`Request failed with status ${response.status}`);
|
||||
}
|
||||
const contentDisposition = response.headers.get("content-disposition");
|
||||
let filename;
|
||||
if (contentDisposition) {
|
||||
const matches = contentDisposition.match(
|
||||
/filename\*?=['"]?(?:UTF-8'')?([^;'"\n]*)['"]?;?/i,
|
||||
);
|
||||
filename = matches && matches[1]
|
||||
? decodeURIComponent(matches[1])
|
||||
: undefined;
|
||||
}
|
||||
return { finalUrl: response.url, dispositionFilename: filename };
|
||||
}
|
||||
|
||||
// The main function for validation
|
||||
export const addModel = withServerPromise(
|
||||
async (data: z.infer<typeof downloadUrlModelSchema>) => {
|
||||
const { url } = data;
|
||||
|
||||
if (url.includes("civitai.com/models/")) {
|
||||
// Make a HEAD request to check for 200 OK
|
||||
const response = await fetch(url, { method: "HEAD" });
|
||||
if (!response.ok) {
|
||||
createModelErrorRecord(
|
||||
url,
|
||||
`civitai gave non-ok response`,
|
||||
"civitai",
|
||||
data.model_type,
|
||||
);
|
||||
}
|
||||
addCivitaiModel(data);
|
||||
} else {
|
||||
const { finalUrl, dispositionFilename } = await fetchFinalUrl(url);
|
||||
console.log("finished fetching");
|
||||
console.log(finalUrl, dispositionFilename);
|
||||
|
||||
if (!dispositionFilename) {
|
||||
console.log("no file name");
|
||||
createModelErrorRecord(
|
||||
url,
|
||||
`Could not find a filename from resolved Url: ${finalUrl}`,
|
||||
"download-url",
|
||||
data.model_type,
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
const validExtensions = [".ckpt", ".pt", ".bin", ".pth", ".safetensors"];
|
||||
const extension = dispositionFilename.slice(
|
||||
dispositionFilename.lastIndexOf("."),
|
||||
);
|
||||
if (!validExtensions.includes(extension)) {
|
||||
console.log("invalid extension");
|
||||
createModelErrorRecord(
|
||||
url,
|
||||
`file ext ${extension} is invalid. Valid extensions: ${validExtensions}`,
|
||||
"download-url",
|
||||
data.model_type,
|
||||
);
|
||||
}
|
||||
addModelDownloadUrl(data, dispositionFilename);
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
export const addModelDownloadUrl = withServerPromise(
|
||||
async (data: z.infer<typeof downloadUrlModelSchema>, filename: string) => {
|
||||
console.log("adding model download");
|
||||
const { userId, orgId } = auth();
|
||||
if (!userId) return { error: "No user id" };
|
||||
const volumes = await retrieveModelVolumes();
|
||||
|
||||
const a = await db
|
||||
.insert(modelTable)
|
||||
.values({
|
||||
user_id: userId,
|
||||
org_id: orgId,
|
||||
upload_type: "download-url",
|
||||
model_name: filename,
|
||||
user_url: data.url,
|
||||
user_volume_id: volumes[0].id,
|
||||
model_type: data.model_type,
|
||||
})
|
||||
.returning();
|
||||
|
||||
const b = a[0];
|
||||
console.log("download url about to upload");
|
||||
await uploadModel(data, b, volumes[0]);
|
||||
},
|
||||
);
|
||||
|
||||
export const getCivitaiModelRes = async (civitaiUrl: string) => {
|
||||
const { url, modelVersionId } = getUrl(civitaiUrl);
|
||||
const civitaiModelRes = await fetch(url)
|
||||
.then((x) => x.json())
|
||||
.then((a) => {
|
||||
return CivitaiModelResponse.parse(a);
|
||||
});
|
||||
return { civitaiModelRes, url, modelVersionId };
|
||||
};
|
||||
|
||||
const createModelErrorRecord = async (
|
||||
url: string,
|
||||
errorMessage: string,
|
||||
upload_type: "civitai" | "download-url",
|
||||
model_type: modelEnumType,
|
||||
) => {
|
||||
const { userId, orgId } = auth();
|
||||
if (!userId) return { error: "No user id" };
|
||||
const volumes = await retrieveModelVolumes();
|
||||
|
||||
const a = await db
|
||||
.insert(modelTable)
|
||||
.values({
|
||||
user_id: userId,
|
||||
org_id: orgId,
|
||||
user_volume_id: volumes[0].id,
|
||||
upload_type: "civitai",
|
||||
model_type,
|
||||
civitai_url: upload_type === "civitai" ? url : undefined,
|
||||
user_url: upload_type === "download-url" ? url : undefined,
|
||||
error_log: errorMessage,
|
||||
status: "failed",
|
||||
})
|
||||
.returning();
|
||||
return a;
|
||||
};
|
||||
|
||||
export const addCivitaiModel = withServerPromise(
|
||||
async (data: z.infer<typeof addCivitaiModelSchema>) => {
|
||||
async (data: z.infer<typeof downloadUrlModelSchema>) => {
|
||||
const { userId, orgId } = auth();
|
||||
|
||||
if (!data.civitai_url) return { error: "no civitai_url" };
|
||||
if (!userId) return { error: "No user id" };
|
||||
|
||||
const { url, modelVersionId } = getUrl(data?.civitai_url);
|
||||
const { url, modelVersionId } = getUrl(data.url);
|
||||
const civitaiModelRes = await fetch(url)
|
||||
.then((x) => x.json())
|
||||
.then((a) => {
|
||||
@ -142,18 +279,17 @@ export const addCivitaiModel = withServerPromise(
|
||||
selectedModelVersionId = selectedModelVersion?.id.toString();
|
||||
}
|
||||
|
||||
const userVolume = await getModelVolumes();
|
||||
let cVolume;
|
||||
if (userVolume.length === 0) {
|
||||
const volume = await addModelVolume();
|
||||
cVolume = volume[0];
|
||||
} else {
|
||||
cVolume = userVolume[0];
|
||||
}
|
||||
const volumes = await retrieveModelVolumes();
|
||||
|
||||
const model_type = getModelTypeDetails(civitaiModelRes.type);
|
||||
if (!model_type) {
|
||||
return
|
||||
createModelErrorRecord(
|
||||
url,
|
||||
`Civitai model type ${civitaiModelRes.type} is not currently supported`,
|
||||
"civitai",
|
||||
data.model_type,
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
const a = await db
|
||||
@ -165,18 +301,17 @@ export const addCivitaiModel = withServerPromise(
|
||||
model_name: selectedModelVersion.files[0].name,
|
||||
civitai_id: civitaiModelRes.id.toString(),
|
||||
civitai_version_id: selectedModelVersionId,
|
||||
civitai_url: data.civitai_url,
|
||||
civitai_url: data.url, // TODO: need to confirm
|
||||
civitai_download_url: selectedModelVersion.files[0].downloadUrl,
|
||||
civitai_model_response: civitaiModelRes,
|
||||
user_volume_id: cVolume.id,
|
||||
model_type,
|
||||
updated_at: new Date(),
|
||||
user_volume_id: volumes[0].id,
|
||||
model_type,
|
||||
})
|
||||
.returning();
|
||||
|
||||
const b = a[0];
|
||||
|
||||
await uploadModel(data, b, cVolume);
|
||||
await uploadModel(data, b, volumes[0]);
|
||||
},
|
||||
);
|
||||
|
||||
@ -216,7 +351,7 @@ export const addCivitaiModel = withServerPromise(
|
||||
// );
|
||||
|
||||
async function uploadModel(
|
||||
data: z.infer<typeof addCivitaiModelSchema>,
|
||||
data: z.infer<typeof downloadUrlModelSchema>,
|
||||
c: ModelType,
|
||||
v: UserVolumeType,
|
||||
) {
|
||||
@ -238,7 +373,9 @@ async function uploadModel(
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
download_url: c.civitai_download_url,
|
||||
download_url: c.upload_type === "civitai"
|
||||
? c.civitai_download_url
|
||||
: c.user_url,
|
||||
volume_name: v.volume_name,
|
||||
volume_id: v.id,
|
||||
model_id: c.id,
|
||||
@ -253,7 +390,6 @@ async function uploadModel(
|
||||
await db
|
||||
.update(modelTable)
|
||||
.set({
|
||||
...data,
|
||||
status: "failed",
|
||||
error_log: error_log,
|
||||
})
|
||||
|
Loading…
x
Reference in New Issue
Block a user