From db04d02d34e081165b5ff00df84b13734b9d4a75 Mon Sep 17 00:00:00 2001 From: Nicholas Koben Kao Date: Wed, 24 Jan 2024 22:15:12 -0800 Subject: [PATCH] working --- builder/modal-builder/src/main.py | 4 +- .../modal-builder/src/volume-builder/app.py | 12 ++-- .../src/volume-builder/config.py | 3 +- web/src/app/(app)/api/volume-upload/route.ts | 4 +- web/src/components/CheckpointList.tsx | 32 +++++++++ web/src/db/schema.ts | 70 +++++++++---------- web/src/server/curdCheckpoint.ts | 64 +++++++++++------ 7 files changed, 124 insertions(+), 65 deletions(-) diff --git a/builder/modal-builder/src/main.py b/builder/modal-builder/src/main.py index b907fc4..46787de 100644 --- a/builder/modal-builder/src/main.py +++ b/builder/modal-builder/src/main.py @@ -47,6 +47,7 @@ machine_id_websocket_dict = {} machine_id_status = {} fly_instance_id = os.environ.get('FLY_ALLOC_ID', 'local').split('-')[0] +civitai_api_key = os.environ.get('FLY_ALLOC_ID', 'local').split('-')[0] class FlyReplayMiddleware(BaseHTTPMiddleware): @@ -272,7 +273,8 @@ async def upload_logic(body: UploadBody): "checkpoint_id": body.checkpoint_id, "volume_id": body.volume_id, "folder_path": upload_path, - } + }, + "civitai_api_key": os.environ.get('CIVITAI_API_KEY') } with open(f"{folder_path}/config.py", "w") as f: f.write("config = " + json.dumps(config)) diff --git a/builder/modal-builder/src/volume-builder/app.py b/builder/modal-builder/src/volume-builder/app.py index 3eee974..06daf86 100644 --- a/builder/modal-builder/src/volume-builder/app.py +++ b/builder/modal-builder/src/volume-builder/app.py @@ -25,16 +25,13 @@ vol_name_to_links = config["volume_names"] vol_name_to_path = config["volume_paths"] callback_url = config["callback_url"] callback_body = config["callback_body"] +civitai_key = config["civitai_api_key"] volumes = create_volumes(vol_name_to_links, vol_name_to_path) image = ( modal.Image.debian_slim().apt_install("wget").pip_install("requests") ) -print(vol_name_to_links) -print(vol_name_to_path) -print(volumes) - # download config { "download_url": "", "folder_path": ""} timeout=5000 @stub.function(volumes=volumes, image=image, timeout=timeout, gpu=None) @@ -42,20 +39,23 @@ def download_model(volume_name, download_config): import requests download_url = download_config["download_url"] folder_path = download_config["folder_path"] + volume_base_path = vol_name_to_path[volume_name] model_store_path = os.path.join(volume_base_path, folder_path) + modified_download_url = download_url + ("&" if "?" in download_url else "?") + "token=" + civitai_key + print('downlodaing', modified_download_url) - subprocess.run(["wget", download_url, "--content-disposition", "-P", model_store_path]) + subprocess.run(["wget", modified_download_url , "--content-disposition", "-P", model_store_path]) subprocess.run(["ls", "-la", volume_base_path]) subprocess.run(["ls", "-la", model_store_path]) volumes[volume_base_path].commit() + status = {"status": "success"} requests.post(callback_url, json={**status, **callback_body}) print(f"finished! sending to {callback_url}") pprint({**status, **callback_body}) - @stub.local_entrypoint() def simple_download(): import requests diff --git a/builder/modal-builder/src/volume-builder/config.py b/builder/modal-builder/src/volume-builder/config.py index 2bf9032..3abfeac 100644 --- a/builder/modal-builder/src/volume-builder/config.py +++ b/builder/modal-builder/src/volume-builder/config.py @@ -13,5 +13,6 @@ config = { "checkpoint_id": "", "volume_id": "", "folder_path": "images", - } + }, + "civitai_api_key": "", } diff --git a/web/src/app/(app)/api/volume-upload/route.ts b/web/src/app/(app)/api/volume-upload/route.ts index 25eb2d4..70edf27 100644 --- a/web/src/app/(app)/api/volume-upload/route.ts +++ b/web/src/app/(app)/api/volume-upload/route.ts @@ -26,7 +26,8 @@ export async function POST(request: Request) { .update(checkpointTable) .set({ status: "success", - folder_path + folder_path, + updated_at: new Date(), // build_log: build_log, }) .where(eq(checkpointTable.id, checkpoint_id)); @@ -37,6 +38,7 @@ export async function POST(request: Request) { .set({ status: "failed", error_log, + updated_at: new Date(), // status: "error", // build_log: build_log, }) diff --git a/web/src/components/CheckpointList.tsx b/web/src/components/CheckpointList.tsx index 01afbf7..faf014b 100644 --- a/web/src/components/CheckpointList.tsx +++ b/web/src/components/CheckpointList.tsx @@ -116,6 +116,38 @@ export const columns: ColumnDef[] = [ {row.original.status} ); + // const oneHourAgo = new Date(new Date().getTime() - (60 * 60 * 1000)); + // const lastUpdated = new Date(row.original.updated_at); + // const canRefresh = row.original.status === "failed" && lastUpdated < oneHourAgo; + // const canRefresh = row.original.status === "failed" && lastUpdated < oneHourAgo; + // cell: ({ row }) => { + // // const oneHourAgo = new Date(new Date().getTime() - (60 * 60 * 1000)); + // // const lastUpdated = new Date(row.original.updated_at); + // // const canRefresh = row.original.status === "failed" && lastUpdated < oneHourAgo; + // const canReDownload = true; + // + // return ( + //
+ // + // {row.original.status} + // + // {canReDownload && ( + // { + // redownloadCheckpoint(row.original); + // }} + // className="h-4 w-4 cursor-pointer" // Adjust the size with h-x and w-x classes + // /> + // )} + //
+ // ); + // }, }, }, { diff --git a/web/src/db/schema.ts b/web/src/db/schema.ts index c6255ba..da3433c 100644 --- a/web/src/db/schema.ts +++ b/web/src/db/schema.ts @@ -406,13 +406,47 @@ export const checkpointTable = dbSchema.table("checkpoints", { is_public: boolean("is_public").notNull().default(false), status: resourceUpload("status").notNull().default("started"), - upload_machine_id: text("upload_machine_id"), + upload_machine_id: text("upload_machine_id"), // TODO: review if actually needed upload_type: modelUploadType("upload_type").notNull(), error_log: text("error_log"), created_at: timestamp("created_at").defaultNow().notNull(), updated_at: timestamp("updated_at").defaultNow().notNull(), }); +export const checkpointVolumeTable = dbSchema.table("checkpoint_volume", { + id: uuid("id").primaryKey().defaultRandom().notNull(), + user_id: text("user_id").references(() => usersTable.id, { + // onDelete: "cascade", + }), + org_id: text("org_id"), + volume_name: text("volume_name").notNull(), + created_at: timestamp("created_at").defaultNow().notNull(), + updated_at: timestamp("updated_at").defaultNow().notNull(), + disabled: boolean("disabled").default(false).notNull(), +}); + +export const checkpointRelations = relations(checkpointTable, ({ one }) => ({ + user: one(usersTable, { + fields: [checkpointTable.user_id], + references: [usersTable.id], + }), + volume: one(checkpointVolumeTable, { + fields: [checkpointTable.checkpoint_volume_id], + references: [checkpointVolumeTable.id], + }), +})); + +export const checkpointVolumeRelations = relations( + checkpointVolumeTable, + ({ many, one }) => ({ + checkpoint: many(checkpointTable), + user: one(usersTable, { + fields: [checkpointVolumeTable.user_id], + references: [usersTable.id], + }), + }) +); + export const subscriptionPlan = pgEnum("subscription_plan", [ "basic", "pro", @@ -452,40 +486,6 @@ export const insertCivitaiCheckpointSchema = createInsertSchema( } ); -export const checkpointVolumeTable = dbSchema.table("checkpoint_volume", { - id: uuid("id").primaryKey().defaultRandom().notNull(), - user_id: text("user_id").references(() => usersTable.id, { - // onDelete: "cascade", - }), - org_id: text("org_id"), - volume_name: text("volume_name").notNull(), - created_at: timestamp("created_at").defaultNow().notNull(), - updated_at: timestamp("updated_at").defaultNow().notNull(), - disabled: boolean("disabled").default(false).notNull(), -}); - -export const checkpointRelations = relations(checkpointTable, ({ one }) => ({ - user: one(usersTable, { - fields: [checkpointTable.user_id], - references: [usersTable.id], - }), - volume: one(checkpointVolumeTable, { - fields: [checkpointTable.checkpoint_volume_id], - references: [checkpointVolumeTable.id], - }), -})); - -export const checkpointVolumeRelations = relations( - checkpointVolumeTable, - ({ many, one }) => ({ - checkpoint: many(checkpointTable), - user: one(usersTable, { - fields: [checkpointVolumeTable.user_id], - references: [usersTable.id], - }), - }) -); - export type UserType = InferSelectModel; export type WorkflowType = InferSelectModel; export type MachineType = InferSelectModel; diff --git a/web/src/server/curdCheckpoint.ts b/web/src/server/curdCheckpoint.ts index acbacc9..7dc1046 100644 --- a/web/src/server/curdCheckpoint.ts +++ b/web/src/server/curdCheckpoint.ts @@ -15,6 +15,7 @@ import { headers } from "next/headers"; import { addCivitaiCheckpointSchema } from "./addCheckpointSchema"; import { and, eq, isNull } from "drizzle-orm"; import { CivitaiModelResponse } from "@/types/civitai"; +import { CheckpointItemList } from "@/components/CheckpointList"; export async function getCheckpoints() { const { userId, orgId } = auth(); @@ -74,12 +75,12 @@ export async function getCheckpointVolumes() { } export async function retrieveCheckpointVolumes() { - let volumes = await getCheckpointVolumes() + let volumes = await getCheckpointVolumes(); if (volumes.length === 0) { // create volume if not already created - volumes = await addCheckpointVolume() - } - return volumes + volumes = await addCheckpointVolume(); + } + return volumes; } export async function addCheckpointVolume() { @@ -97,7 +98,7 @@ export async function addCheckpointVolume() { disabled: false, // Default value }) .returning(); // Returns the inserted row -return insertedVolume; + return insertedVolume; } function getUrl(civitai_url: string) { @@ -115,30 +116,24 @@ function getUrl(civitai_url: string) { export const addCivitaiCheckpoint = withServerPromise( async (data: z.infer) => { const { userId, orgId } = auth(); - console.log("START") - console.log("1"); if (!data.civitai_url) return { error: "no civitai_url" }; - console.log("2"); if (!userId) return { error: "No user id" }; - console.log("3"); const { url, modelVersionId } = getUrl(data?.civitai_url); console.log("4", url, modelVersionId); const civitaiModelRes = await fetch(url) .then((x) => x.json()) .then((a) => { - console.log(a) + console.log(a); return CivitaiModelResponse.parse(a); }); console.log("5"); if (civitaiModelRes?.modelVersions?.length === 0) { - console.log("6"); return; // no versions to download } - console.log("7"); let selectedModelVersion; let selectedModelVersionId: string | null = modelVersionId; if (!selectedModelVersionId) { @@ -153,23 +148,16 @@ export const addCivitaiCheckpoint = withServerPromise( } selectedModelVersionId = selectedModelVersion?.id.toString(); } - console.log("8"); const checkpointVolumes = await getCheckpointVolumes(); - console.log("9"); let cVolume; if (checkpointVolumes.length === 0) { - console.log("10"); const volume = await addCheckpointVolume(); - console.log("11"); cVolume = volume[0]; } else { - console.log("12"); cVolume = checkpointVolumes[0]; } - console.log("13"); - const a = await db .insert(checkpointTable) .values({ @@ -183,18 +171,52 @@ export const addCivitaiCheckpoint = withServerPromise( civitai_download_url: selectedModelVersion.files[0].downloadUrl, civitai_model_response: civitaiModelRes, checkpoint_volume_id: cVolume.id, + updated_at: new Date(), }) .returning(); - console.log("14"); const b = a[0]; await uploadCheckpoint(data, b, cVolume); - console.log("15"); // redirect(`/checkpoints/${b.id}`); }, ); +// export const redownloadCheckpoint = withServerPromise( +// async (data: CheckpointItemList) => { +// const { userId } = auth(); +// if (!userId) return { error: "No user id" }; +// +// const checkpointVolumes = await getCheckpointVolumes(); +// let cVolume; +// if (checkpointVolumes.length === 0) { +// const volume = await addCheckpointVolume(); +// cVolume = volume[0]; +// } else { +// cVolume = checkpointVolumes[0]; +// } +// +// console.log("data"); +// console.log(data); +// +// const a = await db +// .update(checkpointTable) +// .set({ +// // status: "started", +// // updated_at: new Date(), +// }) +// .returning(); +// +// const b = a[0]; +// +// console.log("b"); +// console.log(b); +// +// await uploadCheckpoint(data, b, cVolume); +// // redirect(`/checkpoints/${b.id}`); +// }, +// ); + async function uploadCheckpoint( data: z.infer, c: CheckpointType,