This commit is contained in:
Nicholas Koben Kao 2024-01-24 22:15:12 -08:00
parent 911cc8d16b
commit db04d02d34
7 changed files with 124 additions and 65 deletions

View File

@ -47,6 +47,7 @@ machine_id_websocket_dict = {}
machine_id_status = {}
fly_instance_id = os.environ.get('FLY_ALLOC_ID', 'local').split('-')[0]
civitai_api_key = os.environ.get('FLY_ALLOC_ID', 'local').split('-')[0]
class FlyReplayMiddleware(BaseHTTPMiddleware):
@ -272,7 +273,8 @@ async def upload_logic(body: UploadBody):
"checkpoint_id": body.checkpoint_id,
"volume_id": body.volume_id,
"folder_path": upload_path,
}
},
"civitai_api_key": os.environ.get('CIVITAI_API_KEY')
}
with open(f"{folder_path}/config.py", "w") as f:
f.write("config = " + json.dumps(config))

View File

@ -25,16 +25,13 @@ vol_name_to_links = config["volume_names"]
vol_name_to_path = config["volume_paths"]
callback_url = config["callback_url"]
callback_body = config["callback_body"]
civitai_key = config["civitai_api_key"]
volumes = create_volumes(vol_name_to_links, vol_name_to_path)
image = (
modal.Image.debian_slim().apt_install("wget").pip_install("requests")
)
print(vol_name_to_links)
print(vol_name_to_path)
print(volumes)
# download config { "download_url": "", "folder_path": ""}
timeout=5000
@stub.function(volumes=volumes, image=image, timeout=timeout, gpu=None)
@ -42,20 +39,23 @@ def download_model(volume_name, download_config):
import requests
download_url = download_config["download_url"]
folder_path = download_config["folder_path"]
volume_base_path = vol_name_to_path[volume_name]
model_store_path = os.path.join(volume_base_path, folder_path)
modified_download_url = download_url + ("&" if "?" in download_url else "?") + "token=" + civitai_key
print('downlodaing', modified_download_url)
subprocess.run(["wget", download_url, "--content-disposition", "-P", model_store_path])
subprocess.run(["wget", modified_download_url , "--content-disposition", "-P", model_store_path])
subprocess.run(["ls", "-la", volume_base_path])
subprocess.run(["ls", "-la", model_store_path])
volumes[volume_base_path].commit()
status = {"status": "success"}
requests.post(callback_url, json={**status, **callback_body})
print(f"finished! sending to {callback_url}")
pprint({**status, **callback_body})
@stub.local_entrypoint()
def simple_download():
import requests

View File

@ -13,5 +13,6 @@ config = {
"checkpoint_id": "",
"volume_id": "",
"folder_path": "images",
}
},
"civitai_api_key": "",
}

View File

@ -26,7 +26,8 @@ export async function POST(request: Request) {
.update(checkpointTable)
.set({
status: "success",
folder_path
folder_path,
updated_at: new Date(),
// build_log: build_log,
})
.where(eq(checkpointTable.id, checkpoint_id));
@ -37,6 +38,7 @@ export async function POST(request: Request) {
.set({
status: "failed",
error_log,
updated_at: new Date(),
// status: "error",
// build_log: build_log,
})

View File

@ -116,6 +116,38 @@ export const columns: ColumnDef<CheckpointItemList>[] = [
{row.original.status}
</Badge>
);
// const oneHourAgo = new Date(new Date().getTime() - (60 * 60 * 1000));
// const lastUpdated = new Date(row.original.updated_at);
// const canRefresh = row.original.status === "failed" && lastUpdated < oneHourAgo;
// const canRefresh = row.original.status === "failed" && lastUpdated < oneHourAgo;
// cell: ({ row }) => {
// // const oneHourAgo = new Date(new Date().getTime() - (60 * 60 * 1000));
// // const lastUpdated = new Date(row.original.updated_at);
// // const canRefresh = row.original.status === "failed" && lastUpdated < oneHourAgo;
// const canReDownload = true;
//
// return (
// <div className="flex items-center space-x-2">
// <Badge
// variant={row.original.status === "failed"
// ? "red"
// : row.original.status === "started"
// ? "yellow"
// : "green"}
// >
// {row.original.status}
// </Badge>
// {canReDownload && (
// <RefreshCcw
// onClick={() => {
// redownloadCheckpoint(row.original);
// }}
// className="h-4 w-4 cursor-pointer" // Adjust the size with h-x and w-x classes
// />
// )}
// </div>
// );
// },
},
},
{

View File

@ -406,13 +406,47 @@ export const checkpointTable = dbSchema.table("checkpoints", {
is_public: boolean("is_public").notNull().default(false),
status: resourceUpload("status").notNull().default("started"),
upload_machine_id: text("upload_machine_id"),
upload_machine_id: text("upload_machine_id"), // TODO: review if actually needed
upload_type: modelUploadType("upload_type").notNull(),
error_log: text("error_log"),
created_at: timestamp("created_at").defaultNow().notNull(),
updated_at: timestamp("updated_at").defaultNow().notNull(),
});
export const checkpointVolumeTable = dbSchema.table("checkpoint_volume", {
id: uuid("id").primaryKey().defaultRandom().notNull(),
user_id: text("user_id").references(() => usersTable.id, {
// onDelete: "cascade",
}),
org_id: text("org_id"),
volume_name: text("volume_name").notNull(),
created_at: timestamp("created_at").defaultNow().notNull(),
updated_at: timestamp("updated_at").defaultNow().notNull(),
disabled: boolean("disabled").default(false).notNull(),
});
export const checkpointRelations = relations(checkpointTable, ({ one }) => ({
user: one(usersTable, {
fields: [checkpointTable.user_id],
references: [usersTable.id],
}),
volume: one(checkpointVolumeTable, {
fields: [checkpointTable.checkpoint_volume_id],
references: [checkpointVolumeTable.id],
}),
}));
export const checkpointVolumeRelations = relations(
checkpointVolumeTable,
({ many, one }) => ({
checkpoint: many(checkpointTable),
user: one(usersTable, {
fields: [checkpointVolumeTable.user_id],
references: [usersTable.id],
}),
})
);
export const subscriptionPlan = pgEnum("subscription_plan", [
"basic",
"pro",
@ -452,40 +486,6 @@ export const insertCivitaiCheckpointSchema = createInsertSchema(
}
);
export const checkpointVolumeTable = dbSchema.table("checkpoint_volume", {
id: uuid("id").primaryKey().defaultRandom().notNull(),
user_id: text("user_id").references(() => usersTable.id, {
// onDelete: "cascade",
}),
org_id: text("org_id"),
volume_name: text("volume_name").notNull(),
created_at: timestamp("created_at").defaultNow().notNull(),
updated_at: timestamp("updated_at").defaultNow().notNull(),
disabled: boolean("disabled").default(false).notNull(),
});
export const checkpointRelations = relations(checkpointTable, ({ one }) => ({
user: one(usersTable, {
fields: [checkpointTable.user_id],
references: [usersTable.id],
}),
volume: one(checkpointVolumeTable, {
fields: [checkpointTable.checkpoint_volume_id],
references: [checkpointVolumeTable.id],
}),
}));
export const checkpointVolumeRelations = relations(
checkpointVolumeTable,
({ many, one }) => ({
checkpoint: many(checkpointTable),
user: one(usersTable, {
fields: [checkpointVolumeTable.user_id],
references: [usersTable.id],
}),
})
);
export type UserType = InferSelectModel<typeof usersTable>;
export type WorkflowType = InferSelectModel<typeof workflowTable>;
export type MachineType = InferSelectModel<typeof machinesTable>;

View File

@ -15,6 +15,7 @@ import { headers } from "next/headers";
import { addCivitaiCheckpointSchema } from "./addCheckpointSchema";
import { and, eq, isNull } from "drizzle-orm";
import { CivitaiModelResponse } from "@/types/civitai";
import { CheckpointItemList } from "@/components/CheckpointList";
export async function getCheckpoints() {
const { userId, orgId } = auth();
@ -74,12 +75,12 @@ export async function getCheckpointVolumes() {
}
export async function retrieveCheckpointVolumes() {
let volumes = await getCheckpointVolumes()
let volumes = await getCheckpointVolumes();
if (volumes.length === 0) {
// create volume if not already created
volumes = await addCheckpointVolume()
}
return volumes
volumes = await addCheckpointVolume();
}
return volumes;
}
export async function addCheckpointVolume() {
@ -97,7 +98,7 @@ export async function addCheckpointVolume() {
disabled: false, // Default value
})
.returning(); // Returns the inserted row
return insertedVolume;
return insertedVolume;
}
function getUrl(civitai_url: string) {
@ -115,30 +116,24 @@ function getUrl(civitai_url: string) {
export const addCivitaiCheckpoint = withServerPromise(
async (data: z.infer<typeof addCivitaiCheckpointSchema>) => {
const { userId, orgId } = auth();
console.log("START")
console.log("1");
if (!data.civitai_url) return { error: "no civitai_url" };
console.log("2");
if (!userId) return { error: "No user id" };
console.log("3");
const { url, modelVersionId } = getUrl(data?.civitai_url);
console.log("4", url, modelVersionId);
const civitaiModelRes = await fetch(url)
.then((x) => x.json())
.then((a) => {
console.log(a)
console.log(a);
return CivitaiModelResponse.parse(a);
});
console.log("5");
if (civitaiModelRes?.modelVersions?.length === 0) {
console.log("6");
return; // no versions to download
}
console.log("7");
let selectedModelVersion;
let selectedModelVersionId: string | null = modelVersionId;
if (!selectedModelVersionId) {
@ -153,23 +148,16 @@ export const addCivitaiCheckpoint = withServerPromise(
}
selectedModelVersionId = selectedModelVersion?.id.toString();
}
console.log("8");
const checkpointVolumes = await getCheckpointVolumes();
console.log("9");
let cVolume;
if (checkpointVolumes.length === 0) {
console.log("10");
const volume = await addCheckpointVolume();
console.log("11");
cVolume = volume[0];
} else {
console.log("12");
cVolume = checkpointVolumes[0];
}
console.log("13");
const a = await db
.insert(checkpointTable)
.values({
@ -183,18 +171,52 @@ export const addCivitaiCheckpoint = withServerPromise(
civitai_download_url: selectedModelVersion.files[0].downloadUrl,
civitai_model_response: civitaiModelRes,
checkpoint_volume_id: cVolume.id,
updated_at: new Date(),
})
.returning();
console.log("14");
const b = a[0];
await uploadCheckpoint(data, b, cVolume);
console.log("15");
// redirect(`/checkpoints/${b.id}`);
},
);
// export const redownloadCheckpoint = withServerPromise(
// async (data: CheckpointItemList) => {
// const { userId } = auth();
// if (!userId) return { error: "No user id" };
//
// const checkpointVolumes = await getCheckpointVolumes();
// let cVolume;
// if (checkpointVolumes.length === 0) {
// const volume = await addCheckpointVolume();
// cVolume = volume[0];
// } else {
// cVolume = checkpointVolumes[0];
// }
//
// console.log("data");
// console.log(data);
//
// const a = await db
// .update(checkpointTable)
// .set({
// // status: "started",
// // updated_at: new Date(),
// })
// .returning();
//
// const b = a[0];
//
// console.log("b");
// console.log(b);
//
// await uploadCheckpoint(data, b, cVolume);
// // redirect(`/checkpoints/${b.id}`);
// },
// );
async function uploadCheckpoint(
data: z.infer<typeof addCivitaiCheckpointSchema>,
c: CheckpointType,