Squashed commit of the following:

commit 33c0ad7d14a85f22c57f943dab58610c13d2ac07
Author: Nicholas Koben Kao <kobenkao@gmail.com>
Date:   Tue Jan 30 21:56:00 2024 -0800

    revert custom form change

commit d2905ad045ad7856156e3647a81d642999352de7
Merge: 654423d e3a1d24
Author: Nicholas Koben Kao <kobenkao@gmail.com>
Date:   Tue Jan 30 20:50:06 2024 -0800

    merge schema

commit 654423d597e019a5ebf1ab6568c9942fcb9181c5
Author: Nicholas Koben Kao <kobenkao@gmail.com>
Date:   Tue Jan 30 20:49:34 2024 -0800

    merge confl.ict

commit 641724c11346319674fbb329e8e29b362117c242
Author: Nicholas Koben Kao <kobenkao@gmail.com>
Date:   Tue Jan 30 20:47:34 2024 -0800

    model reload on create

commit eb4dfe8e3f39a0a98eab0fcf1affe7096c12f33b
Author: Nicholas Koben Kao <kobenkao@gmail.com>
Date:   Tue Jan 30 17:00:03 2024 -0800

    delete models

commit 0bea9583fada102396c4e08fe6da971c94d404df
Author: Nicholas Koben Kao <kobenkao@gmail.com>
Date:   Tue Jan 30 14:35:15 2024 -0800

    deploy volume uploader to have timeouts only be modal related
This commit is contained in:
bennykok 2024-01-31 14:29:36 +08:00
parent e3a1d24304
commit 8eb2ce3e10
10 changed files with 202 additions and 59 deletions

View File

@ -13,4 +13,4 @@ RUN mkdir builds
# CMD ["uvicorn", "src.main:app", "--host", "0.0.0.0", "--port", "80", "--lifespan", "on"]
CMD ["python", "src/main.py"]
# If running behind a proxy like Nginx or Traefik add --proxy-headers
# CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "80", "--proxy-headers"]
# CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "80", "--proxy-headers"]

View File

@ -19,7 +19,7 @@ import requests
from urllib.parse import parse_qs
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp, Scope, Receive, Send
import modal
from concurrent.futures import ThreadPoolExecutor
@ -236,6 +236,14 @@ class UploadType(str, Enum):
checkpoint = "checkpoint"
lora = "lora"
embedding = "embedding"
clip = "clip"
clip_vision = "clip_vision"
configs = "configs"
controlnet = "controlnet"
upscale_models = "upscale_models"
vae = "vae"
ipadapter = "ipadapter"
other = "other"
class UploadBody(BaseModel):
download_url: str
@ -251,8 +259,46 @@ UPLOAD_TYPE_DIR_MAP = {
UploadType.checkpoint: "checkpoints",
UploadType.lora: "loras",
UploadType.embedding: "embeddings",
UploadType.clip: "clip",
UploadType.clip_vision: "clip_vision",
UploadType.configs: "configs",
UploadType.controlnet: "controlnet",
UploadType.upscale_models: "upscale_models",
UploadType.vae: "vae",
UploadType.ipadapter: "ipadapter",
UploadType.other: "",
}
class DeleteBody(BaseModel):
volume_name: str
path: str
file_name: str
@app.post("/delete-volume-model")
async def delete_model(body: DeleteBody):
global last_activity_time
last_activity_time = time.time()
logger.info(f"Extended inactivity time to {global_timeout}")
full_path = f"{body.path.rstrip('/')}/{body.file_name}"
rm_process = await asyncio.subprocess.create_subprocess_exec("modal", "volume", "rm", body.volume_name, full_path,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,)
await rm_process.wait()
logger.info(f"Successfully deleted: {full_path} from volume: {body.volume_name}")
stdout, stderr = await rm_process.communicate()
if stdout:
logger.info(f"cp_process stdout: {stdout.decode()}")
if stderr:
logger.info(f"cp_process stderr: {stderr.decode()}")
if rm_process.returncode == 0:
return JSONResponse(status_code=200, content={"status":f"Successfully deleted {full_path} from volume {body.volume_name}"})
else:
return JSONResponse(status_code=500, content={"status": "error", "error": stderr.decode()})
@app.post("/upload-volume")
async def upload_model(body: UploadBody):
@ -267,12 +313,16 @@ async def upload_model(body: UploadBody):
async def upload_logic(body: UploadBody):
folder_path = f"/app/builds/{body.volume_id}"
folder_path = f"/app/builds/{body.volume_id}-{uuid4()}"
cp_process = await asyncio.subprocess.create_subprocess_exec("cp", "-r", "/app/src/volume-builder", folder_path)
cp_process = await asyncio.subprocess.create_subprocess_exec("cp", "-r", "/app/src/volume_builder", folder_path)
await cp_process.wait()
upload_path = UPLOAD_TYPE_DIR_MAP[body.upload_type]
if upload_path == "":
# TODO: deal with custom paths
pass
config = {
"volume_names": {
body.volume_name: {"download_url": body.download_url, "folder_path": upload_path}
@ -286,16 +336,22 @@ async def upload_logic(body: UploadBody):
"volume_id": body.volume_id,
"folder_path": upload_path,
},
"civitai_api_key": os.environ.get('CIVITAI_API_KEY')
"civitai_api_key": os.environ.get('CIVITAI_API_KEY'),
"app_name": f"vol_name_{uuid4()}"
}
with open(f"{folder_path}/config.py", "w") as f:
f.write("config = " + json.dumps(config))
await asyncio.subprocess.create_subprocess_shell(
f"modal run app.py",
process = await asyncio.subprocess.create_subprocess_shell(
f"python runner.py",
cwd=folder_path,
env={**os.environ, "COLUMNS": "10000"}
)
await process.wait()
# import modal
# modal.deploy_stub(stub)
# stub["download_model"].web_url
@app.post("/create")
async def create_machine(item: Item):

View File

@ -9,6 +9,8 @@ public:
loras: loras
upscale_models: upscale_models
vae: vae
ipadapter: ipadapter
private:
base_path: /private_models/
@ -21,3 +23,4 @@ private:
loras: loras
upscale_models: upscale_models
vae: vae
ipadapter: ipadapter

View File

@ -1,10 +1,18 @@
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse
import modal
from config import config
import os
import subprocess
from pprint import pprint
stub = modal.Stub()
stub = modal.Stub(config["app_name"])
vol_name_to_links = config["volume_names"]
vol_name_to_path = config["volume_paths"]
callback_url = config["callback_url"]
callback_body = config["callback_body"]
civitai_key = config["civitai_api_key"]
web_app = FastAPI()
# Volume names may only contain alphanumeric characters, dashes, periods, and underscores, and must be less than 64 characters in length.
def is_valid_name(name: str) -> bool:
@ -21,12 +29,6 @@ def create_volumes(volume_names, paths):
return path_to_vol
vol_name_to_links = config["volume_names"]
vol_name_to_path = config["volume_paths"]
callback_url = config["callback_url"]
callback_body = config["callback_body"]
civitai_key = config["civitai_api_key"]
volumes = create_volumes(vol_name_to_links, vol_name_to_path)
image = (
modal.Image.debian_slim().apt_install("wget").pip_install("requests")
@ -45,7 +47,7 @@ def download_model(volume_name, download_config):
modified_download_url = download_url + ("&" if "?" in download_url else "?") + "token=" + civitai_key # civitai requires auth
print('downloading', modified_download_url)
subprocess.run(["wget", modified_download_url , "--content-disposition", "-P", model_store_path])
subprocess.run(["wget", modified_download_url , "--content-disposition", "-P", model_store_path, "-nv"])
subprocess.run(["ls", "-la", volume_base_path])
subprocess.run(["ls", "-la", model_store_path])
volumes[volume_base_path].commit()
@ -56,11 +58,12 @@ def download_model(volume_name, download_config):
print(f"finished! sending to {callback_url}")
pprint({**status, **callback_body})
@stub.local_entrypoint()
@stub.function(image=image)
# @modal.asgi_app()
def simple_download():
import requests
try:
list(download_model.starmap([(vol_name, link) for vol_name,link in vol_name_to_links.items()]))
list(download_model.starmap([(vol_name, download_conf) for vol_name,download_conf in vol_name_to_links.items()]))
except modal.exception.FunctionTimeoutError as e:
status = {"status": "failed", "error_logs": f"{str(e)}", "timeout": timeout}
requests.post(callback_url, json={**status, **callback_body})
@ -71,4 +74,3 @@ def simple_download():
requests.post(callback_url, json={**status, **callback_body})
print(f"finished! sending to {callback_url}")
pprint({**status, **callback_body})

View File

@ -15,4 +15,5 @@ config = {
"folder_path": "checkpoints",
},
"civitai_api_key": "",
"app_name": "",
}

View File

@ -0,0 +1,12 @@
import modal
import requests
from app import stub
from config import config
modal.runner.deploy_stub(stub)
print("deployed stub")
# web_url = stub["simple_download"].web_url
f = modal.Function.lookup(config['app_name'], "simple_download")
f.spawn()
# print(f"web_url: {web_url}")
# requests.post(web_url)

View File

@ -353,4 +353,4 @@
"breakpoints": true
}
]
}
}

View File

@ -7,6 +7,13 @@ import { Checkbox } from "@/components/ui/checkbox";
import { InsertModal } from "./InsertModal";
import { Input } from "@/components/ui/input";
import { ScrollArea } from "@/components/ui/scroll-area";
import {
DropdownMenu,
DropdownMenuContent,
DropdownMenuItem,
DropdownMenuLabel,
DropdownMenuTrigger,
} from "@/components/ui/dropdown-menu";
import {
Table,
TableBody,
@ -30,9 +37,9 @@ import {
getSortedRowModel,
useReactTable,
} from "@tanstack/react-table";
import { ArrowUpDown } from "lucide-react";
import { ArrowUpDown, MoreHorizontal } from "lucide-react";
import * as React from "react";
import { addModel } from "@/server/curdModel";
import { addModel, deleteModel } from "@/server/curdModel";
import { downloadUrlModelSchema } from "@/server/addCivitaiModelSchema";
import { modelEnumType } from "@/db/schema";
@ -192,10 +199,16 @@ export const columns: ColumnDef<ModelItemList>[] = [
lora: "green",
embedding: "violet",
vae: "teal",
clip: "default",
clip_vision: "default",
configs: "default",
controlnet: "default",
upscale_models: "default",
ipadapter: "default",
};
function getBadgeColor(modelType: modelEnumType) {
return model_type_map[modelType] || "default";
return model_type_map[modelType]
}
const color = getBadgeColor(row.original.model_type);
@ -225,35 +238,35 @@ export const columns: ColumnDef<ModelItemList>[] = [
),
},
// TODO: deletion and editing for future sprint
// {
// id: "actions",
// enableHiding: false,
// cell: ({ row }) => {
// const checkpoint = row.original;
//
// return (
// <DropdownMenu>
// <DropdownMenuTrigger asChild>
// <Button variant="ghost" className="h-8 w-8 p-0">
// <span className="sr-only">Open menu</span>
// <MoreHorizontal className="h-4 w-4" />
// </Button>
// </DropdownMenuTrigger>
// <DropdownMenuContent align="end">
// <DropdownMenuLabel>Actions</DropdownMenuLabel>
// <DropdownMenuItem
// className="text-destructive"
// onClick={() => {
// deleteWorkflow(checkpoint.id);
// }}
// >
// Delete Workflow
// </DropdownMenuItem>
// </DropdownMenuContent>
// </DropdownMenu>
// );
// },
// },
{
id: "actions",
enableHiding: false,
cell: ({ row }) => {
const model = row.original;
return (
<DropdownMenu>
<DropdownMenuTrigger asChild>
<Button variant="ghost" className="h-8 w-8 p-0">
<span className="sr-only">Open menu</span>
<MoreHorizontal className="h-4 w-4" />
</Button>
</DropdownMenuTrigger>
<DropdownMenuContent align="end">
<DropdownMenuLabel>Actions</DropdownMenuLabel>
<DropdownMenuItem
className="text-destructive"
onClick={() => {
deleteModel(model.id);
}}
>
Delete Model
</DropdownMenuItem>
</DropdownMenuContent>
</DropdownMenu>
);
},
},
];
export function ModelList({ data }: { data: ModelItemList[] }) {

View File

@ -7,10 +7,10 @@ import {
jsonb,
pgEnum,
pgSchema,
real,
text,
timestamp,
uuid,
real,
} from "drizzle-orm/pg-core";
import { createInsertSchema, createSelectSchema } from "drizzle-zod";
import { TypeOf, z } from "zod";
@ -150,8 +150,9 @@ export const workflowRunsTable = dbSchema.table("workflow_runs", {
onDelete: "set null",
},
),
workflow_inputs:
jsonb("workflow_inputs").$type<Record<string, string | number>>(),
workflow_inputs: jsonb("workflow_inputs").$type<
Record<string, string | number>
>(),
workflow_id: uuid("workflow_id")
.notNull()
.references(() => workflowTable.id, {
@ -298,8 +299,9 @@ export const deploymentsTable = dbSchema.table("deployments", {
.references(() => machinesTable.id),
share_slug: text("share_slug").unique(),
description: text("description"),
showcase_media:
jsonb("showcase_media").$type<z.infer<typeof showcaseMedia>>(),
showcase_media: jsonb("showcase_media").$type<
z.infer<typeof showcaseMedia>
>(),
environment: deploymentEnvironment("environment").notNull(),
created_at: timestamp("created_at").defaultNow().notNull(),
updated_at: timestamp("updated_at").defaultNow().notNull(),
@ -389,7 +391,18 @@ export const modelUploadType = pgEnum("model_upload_type", [
]);
// https://www.answeroverflow.com/m/1125106227387584552
export const modelTypes = ["checkpoint", "lora", "embedding", "vae"] as const;
export const modelTypes = [
"checkpoint",
"lora",
"embedding",
"vae",
"clip",
"clip_vision",
"configs",
"controlnet",
"upscale_models",
"ipadapter",
] as const;
export const modelType = pgEnum("model_type", modelTypes);
export type modelEnumType = (typeof modelTypes)[number];

View File

@ -11,6 +11,7 @@ import {
import { withServerPromise } from "./withServerPromise";
import { db } from "@/db/db";
import type { z } from "zod";
import { revalidatePath } from "next/cache";
import { headers } from "next/headers";
import { downloadUrlModelSchema } from "./addCivitaiModelSchema";
import { and, eq, isNull } from "drizzle-orm";
@ -210,6 +211,47 @@ export const addModelDownloadUrl = withServerPromise(
},
);
export const deleteModel = withServerPromise(
async (modelId: string) => {
const model = await db.query.modelTable.findFirst({
where: eq(modelTable.id, modelId),
});
// If the model does not exist, throw an error or return a message
if (!model) {
throw new Error("Model not found");
// Or return { error: "Model not found" }; if you prefer to handle it without throwing
}
const volumes = await retrieveModelVolumes();
if (
model.status === "success" && !!model.folder_path && !!model.model_name
) {
const result = await fetch(
`${process.env.MODAL_BUILDER_URL!}/delete-volume-model`,
{
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
volume_name: volumes[0].volume_name,
path: model.folder_path,
file_name: model.model_name,
}),
},
);
if (!result.ok) {
const error_log = await result.text();
throw new Error(`Error: ${result.statusText} ${error_log}`);
}
}
await db.delete(modelTable).where(eq(modelTable.id, modelId));
revalidatePath("/storage");
return { message: "Model Deleted" };
},
);
export const getCivitaiModelRes = async (civitaiUrl: string) => {
const { url, modelVersionId } = getUrl(civitaiUrl);
const civitaiModelRes = await fetch(url)
@ -301,8 +343,8 @@ export const addCivitaiModel = withServerPromise(
model_name: selectedModelVersion.files[0].name,
civitai_id: civitaiModelRes.id.toString(),
civitai_version_id: selectedModelVersionId,
civitai_url: data.url, // TODO: need to confirm
civitai_download_url: selectedModelVersion.files[0].downloadUrl,
civitai_url: data.url,
civitai_download_url: selectedModelVersion.files[0].downloadUrl, // there is an issue when a model hoster might put multiple different types of files i.e. their training data.
civitai_model_response: civitaiModelRes,
user_volume_id: volumes[0].id,
model_type,
@ -312,6 +354,7 @@ export const addCivitaiModel = withServerPromise(
const b = a[0];
await uploadModel(data, b, volumes[0]);
revalidatePath("/storage");
},
);