feat: add civital model install

This commit is contained in:
BennyKok 2024-01-13 23:15:19 +08:00
parent 1f91d4d357
commit 1d25aadd74
6 changed files with 266 additions and 34 deletions

Binary file not shown.

View File

@ -95,6 +95,7 @@
"tailwindcss-animate": "^1.0.7", "tailwindcss-animate": "^1.0.7",
"unist-util-filter": "^5.0.1", "unist-util-filter": "^5.0.1",
"unist-util-visit": "^5.0.0", "unist-util-visit": "^5.0.0",
"use-debounce": "^10.0.0",
"uuid": "^9.0.1", "uuid": "^9.0.1",
"zod": "^3.22.4", "zod": "^3.22.4",
"zustand": "^4.4.7" "zustand": "^4.4.7"

View File

@ -1,6 +1,7 @@
"use client"; "use client";
import type { AutoFormInputComponentProps } from "../ui/auto-form/types"; import type { AutoFormInputComponentProps } from "../ui/auto-form/types";
import { LoadingIcon } from "@/components/LoadingIcon";
import { Button } from "@/components/ui/button"; import { Button } from "@/components/ui/button";
import { import {
Command, Command,
@ -21,6 +22,7 @@ import { cn } from "@/lib/utils";
import { Check, ChevronsUpDown } from "lucide-react"; import { Check, ChevronsUpDown } from "lucide-react";
import * as React from "react"; import * as React from "react";
import { useRef } from "react"; import { useRef } from "react";
import { useDebouncedCallback } from "use-debounce";
import { z } from "zod"; import { z } from "zod";
const Model = z.object({ const Model = z.object({
@ -34,6 +36,114 @@ const Model = z.object({
url: z.string(), url: z.string(),
}); });
export const CivitalModelSchema = z.object({
items: z.array(
z.object({
id: z.number(),
name: z.string(),
description: z.string(),
type: z.string(),
// poi: z.boolean(),
// nsfw: z.boolean(),
// allowNoCredit: z.boolean(),
// allowCommercialUse: z.string(),
// allowDerivatives: z.boolean(),
// allowDifferentLicense: z.boolean(),
// stats: z.object({
// downloadCount: z.number(),
// favoriteCount: z.number(),
// commentCount: z.number(),
// ratingCount: z.number(),
// rating: z.number(),
// tippedAmountCount: z.number(),
// }),
creator: z
.object({
username: z.string().nullable(),
image: z.string().nullable().default(null),
})
.nullable(),
tags: z.array(z.string()),
modelVersions: z.array(
z.object({
id: z.number(),
modelId: z.number(),
name: z.string(),
createdAt: z.string(),
updatedAt: z.string(),
status: z.string(),
publishedAt: z.string(),
trainedWords: z.array(z.unknown()),
trainingStatus: z.string().nullable(),
trainingDetails: z.string().nullable(),
baseModel: z.string(),
baseModelType: z.string().nullable(),
earlyAccessTimeFrame: z.number(),
description: z.string().nullable(),
vaeId: z.number().nullable(),
stats: z.object({
downloadCount: z.number(),
ratingCount: z.number(),
rating: z.number(),
}),
files: z.array(
z.object({
id: z.number(),
sizeKB: z.number(),
name: z.string(),
type: z.string(),
// metadata: z.object({
// fp: z.string().nullable().optional(),
// size: z.string().nullable().optional(),
// format: z.string().nullable().optional(),
// }),
// pickleScanResult: z.string(),
// pickleScanMessage: z.string(),
// virusScanResult: z.string(),
// virusScanMessage: z.string().nullable(),
// scannedAt: z.string(),
// hashes: z.object({
// AutoV1: z.string().nullable().optional(),
// AutoV2: z.string().nullable().optional(),
// SHA256: z.string().nullable().optional(),
// CRC32: z.string().nullable().optional(),
// BLAKE3: z.string().nullable().optional(),
// }),
downloadUrl: z.string(),
// primary: z.boolean().default(false),
})
),
images: z.array(
z.object({
id: z.number(),
url: z.string(),
nsfw: z.string(),
width: z.number(),
height: z.number(),
hash: z.string(),
type: z.string(),
metadata: z.object({
hash: z.string(),
width: z.number(),
height: z.number(),
}),
meta: z.any(),
})
),
downloadUrl: z.string(),
})
),
})
),
metadata: z.object({
totalItems: z.number(),
currentPage: z.number(),
pageSize: z.number(),
totalPages: z.number(),
nextPage: z.string().optional(),
}),
});
const ModelList = z.array(Model); const ModelList = z.array(Model);
export const ModelListWrapper = z.object({ export const ModelListWrapper = z.object({
@ -43,16 +153,132 @@ export const ModelListWrapper = z.object({
export function ModelPickerView({ export function ModelPickerView({
field, field,
}: Pick<AutoFormInputComponentProps, "field">) { }: Pick<AutoFormInputComponentProps, "field">) {
const value = (field.value as z.infer<typeof ModelList>) ?? []; return (
<div className="flex gap-2 flex-col">
<ComfyUIManagerModelRegistry field={field} />
<CivitaiModelRegistry field={field} />
{/* <span>{field.value.length} selected</span> */}
{field.value && (
<ScrollArea className="w-full bg-gray-100 mx-auto rounded-lg mt-2">
<Textarea
className="min-h-[150px] max-h-[300px] p-2 rounded-lg text-xs w-full"
value={JSON.stringify(field.value, null, 2)}
onChange={(e) => {
field.onChange(JSON.parse(e.target.value));
}}
/>
</ScrollArea>
)}
</div>
);
}
const [open, setOpen] = React.useState(false); function mapModelsList(
models: z.infer<typeof CivitalModelSchema>
): z.infer<typeof ModelListWrapper> {
return {
models: models.items.map((item) => {
const v = item.modelVersions[0];
return {
name: `${item.name} ${v.name} (${v.files[0].name})`,
type: v.files[0].type.toLowerCase(),
base: v.baseModel,
save_path: "default",
description: item.description,
reference: "",
filename: v.files[0].name,
url: v.files[0].downloadUrl,
} as z.infer<typeof Model>;
}),
};
}
function getUrl(search?: string) {
const baseUrl = "https://civitai.com/api/v1/models";
const searchParams = {
limit: 5,
} as any;
searchParams["sort"] = "Most Downloaded";
if (search) {
searchParams["query"] = search;
} else {
// sort: "Highest Rated",
}
const url = new URL(baseUrl);
Object.keys(searchParams).forEach((key) =>
url.searchParams.append(key, searchParams[key])
);
return url;
}
export function CivitaiModelRegistry({
field,
}: Pick<AutoFormInputComponentProps, "field">) {
const [modelList, setModelList] = const [modelList, setModelList] =
React.useState<z.infer<typeof ModelListWrapper>>(); React.useState<z.infer<typeof ModelListWrapper>>();
// const [selectedModels, setSelectedModels] = React.useState< const [loading, setLoading] = React.useState(false);
// z.infer<typeof ModelList>
// >(field.value ?? []); const handleSearch = useDebouncedCallback((search) => {
console.log(`Searching... ${search}`);
setLoading(true);
const controller = new AbortController();
fetch(getUrl(search), {
signal: controller.signal,
})
.then((x) => x.json())
.then((a) => {
const list = CivitalModelSchema.parse(a);
console.log(a);
setModelList(mapModelsList(list));
setLoading(false);
});
return () => {
controller.abort();
setLoading(false);
};
}, 300);
React.useEffect(() => {
const controller = new AbortController();
fetch(getUrl(), {
signal: controller.signal,
})
.then((x) => x.json())
.then((a) => {
const list = CivitalModelSchema.parse(a);
setModelList(mapModelsList(list));
});
return () => {
controller.abort();
};
}, []);
return (
<ModelSelector
field={field}
modelList={modelList}
label="Civitai"
onSearch={handleSearch}
shouldFilter={false}
isLoading={loading}
/>
);
}
export function ComfyUIManagerModelRegistry({
field,
}: Pick<AutoFormInputComponentProps, "field">) {
const [modelList, setModelList] =
React.useState<z.infer<typeof ModelListWrapper>>();
React.useEffect(() => { React.useEffect(() => {
const controller = new AbortController(); const controller = new AbortController();
@ -72,6 +298,26 @@ export function ModelPickerView({
}; };
}, []); }, []);
return <ModelSelector field={field} modelList={modelList} label="common" />;
}
export function ModelSelector({
field,
modelList,
label,
onSearch,
shouldFilter = true,
isLoading,
}: Pick<AutoFormInputComponentProps, "field"> & {
modelList?: z.infer<typeof ModelListWrapper>;
label: string;
onSearch?: (search: string) => void;
shouldFilter?: boolean;
isLoading?: boolean;
}) {
const value = (field.value as z.infer<typeof ModelList>) ?? [];
const [open, setOpen] = React.useState(false);
function toggleModel(model: z.infer<typeof Model>) { function toggleModel(model: z.infer<typeof Model>) {
const prevSelectedModels = value; const prevSelectedModels = value;
if ( if (
@ -91,10 +337,6 @@ export function ModelPickerView({
} }
} }
// React.useEffect(() => {
// field.onChange(selectedModels);
// }, [selectedModels]);
const containerRef = useRef<HTMLDivElement>(null); const containerRef = useRef<HTMLDivElement>(null);
return ( return (
@ -107,13 +349,19 @@ export function ModelPickerView({
aria-expanded={open} aria-expanded={open}
className="w-full justify-between flex" className="w-full justify-between flex"
> >
Select models... ({value.length} selected) Select {label}
<ChevronsUpDown className="ml-2 h-4 w-4 shrink-0 opacity-50" /> <ChevronsUpDown className="ml-2 h-4 w-4 shrink-0 opacity-50" />
</Button> </Button>
</PopoverTrigger> </PopoverTrigger>
<PopoverContent className="w-[375px] p-0" side="top"> <PopoverContent className="w-[375px] p-0" side="bottom">
<Command> <Command shouldFilter={shouldFilter}>
<CommandInput placeholder="Search models..." className="h-9" /> <CommandInput
placeholder="Search models..."
className="h-9"
onValueChange={onSearch}
>
{isLoading && <LoadingIcon />}
</CommandInput>
<CommandEmpty>No framework found.</CommandEmpty> <CommandEmpty>No framework found.</CommandEmpty>
<CommandList className="pointer-events-auto"> <CommandList className="pointer-events-auto">
<CommandGroup> <CommandGroup>
@ -123,7 +371,6 @@ export function ModelPickerView({
value={model.url} value={model.url}
onSelect={() => { onSelect={() => {
toggleModel(model); toggleModel(model);
// Update field.onChange to pass the array of selected models
}} }}
> >
{model.name} {model.name}
@ -144,23 +391,6 @@ export function ModelPickerView({
</Command> </Command>
</PopoverContent> </PopoverContent>
</Popover> </Popover>
{field.value && (
<ScrollArea className="w-full bg-gray-100 mx-auto rounded-lg mt-2">
{/* <div className="max-h-[200px]">
<pre className="p-2 rounded-md text-xs ">
{JSON.stringify(field.value, null, 2)}
</pre>
</div> */}
<Textarea
className="min-h-[150px] max-h-[300px] p-2 rounded-md text-xs w-full"
value={JSON.stringify(field.value, null, 2)}
onChange={(e) => {
// Update field.onChange to pass the array of selected models
field.onChange(JSON.parse(e.target.value));
}}
/>
</ScrollArea>
)}
</div> </div>
); );
} }

View File

@ -37,7 +37,7 @@ export default function AutoFormObject<
const { shape } = getBaseSchema<SchemaType>(schema); const { shape } = getBaseSchema<SchemaType>(schema);
return ( return (
<Accordion type="multiple" className="space-y-5"> <Accordion type="multiple" className="space-y-5 py-1">
{Object.keys(shape).map((name) => { {Object.keys(shape).map((name) => {
const item = shape[name] as z.ZodAny; const item = shape[name] as z.ZodAny;
const zodBaseType = getBaseType(item); const zodBaseType = getBaseType(item);

View File

@ -70,7 +70,7 @@ function AutoForm<SchemaType extends ZodObjectOrWrapped>({
className={cn("space-y-5", className)} className={cn("space-y-5", className)}
> >
<ScrollArea> <ScrollArea>
<div className="max-h-[400px] px-1 py-1 w-full"> <div className="max-h-[400px] px-1 w-full">
<AutoFormObject <AutoFormObject
schema={objectFormSchema} schema={objectFormSchema}
form={form} form={form}

View File

@ -39,7 +39,7 @@ const CommandDialog = ({ children, ...props }: CommandDialogProps) => {
const CommandInput = React.forwardRef< const CommandInput = React.forwardRef<
React.ElementRef<typeof CommandPrimitive.Input>, React.ElementRef<typeof CommandPrimitive.Input>,
React.ComponentPropsWithoutRef<typeof CommandPrimitive.Input> React.ComponentPropsWithoutRef<typeof CommandPrimitive.Input>
>(({ className, ...props }, ref) => ( >(({ className, children, ...props }, ref) => (
<div className="flex items-center border-b px-3" cmdk-input-wrapper=""> <div className="flex items-center border-b px-3" cmdk-input-wrapper="">
<Search className="mr-2 h-4 w-4 shrink-0 opacity-50" /> <Search className="mr-2 h-4 w-4 shrink-0 opacity-50" />
<CommandPrimitive.Input <CommandPrimitive.Input
@ -50,6 +50,7 @@ const CommandInput = React.forwardRef<
)} )}
{...props} {...props}
/> />
{children}
</div> </div>
)); ));