working
This commit is contained in:
parent
911cc8d16b
commit
db04d02d34
@ -47,6 +47,7 @@ machine_id_websocket_dict = {}
|
||||
machine_id_status = {}
|
||||
|
||||
fly_instance_id = os.environ.get('FLY_ALLOC_ID', 'local').split('-')[0]
|
||||
civitai_api_key = os.environ.get('FLY_ALLOC_ID', 'local').split('-')[0]
|
||||
|
||||
|
||||
class FlyReplayMiddleware(BaseHTTPMiddleware):
|
||||
@ -272,7 +273,8 @@ async def upload_logic(body: UploadBody):
|
||||
"checkpoint_id": body.checkpoint_id,
|
||||
"volume_id": body.volume_id,
|
||||
"folder_path": upload_path,
|
||||
}
|
||||
},
|
||||
"civitai_api_key": os.environ.get('CIVITAI_API_KEY')
|
||||
}
|
||||
with open(f"{folder_path}/config.py", "w") as f:
|
||||
f.write("config = " + json.dumps(config))
|
||||
|
@ -25,16 +25,13 @@ vol_name_to_links = config["volume_names"]
|
||||
vol_name_to_path = config["volume_paths"]
|
||||
callback_url = config["callback_url"]
|
||||
callback_body = config["callback_body"]
|
||||
civitai_key = config["civitai_api_key"]
|
||||
|
||||
volumes = create_volumes(vol_name_to_links, vol_name_to_path)
|
||||
image = (
|
||||
modal.Image.debian_slim().apt_install("wget").pip_install("requests")
|
||||
)
|
||||
|
||||
print(vol_name_to_links)
|
||||
print(vol_name_to_path)
|
||||
print(volumes)
|
||||
|
||||
# download config { "download_url": "", "folder_path": ""}
|
||||
timeout=5000
|
||||
@stub.function(volumes=volumes, image=image, timeout=timeout, gpu=None)
|
||||
@ -42,20 +39,23 @@ def download_model(volume_name, download_config):
|
||||
import requests
|
||||
download_url = download_config["download_url"]
|
||||
folder_path = download_config["folder_path"]
|
||||
|
||||
volume_base_path = vol_name_to_path[volume_name]
|
||||
model_store_path = os.path.join(volume_base_path, folder_path)
|
||||
modified_download_url = download_url + ("&" if "?" in download_url else "?") + "token=" + civitai_key
|
||||
print('downlodaing', modified_download_url)
|
||||
|
||||
subprocess.run(["wget", download_url, "--content-disposition", "-P", model_store_path])
|
||||
subprocess.run(["wget", modified_download_url , "--content-disposition", "-P", model_store_path])
|
||||
subprocess.run(["ls", "-la", volume_base_path])
|
||||
subprocess.run(["ls", "-la", model_store_path])
|
||||
volumes[volume_base_path].commit()
|
||||
|
||||
|
||||
status = {"status": "success"}
|
||||
requests.post(callback_url, json={**status, **callback_body})
|
||||
print(f"finished! sending to {callback_url}")
|
||||
pprint({**status, **callback_body})
|
||||
|
||||
|
||||
@stub.local_entrypoint()
|
||||
def simple_download():
|
||||
import requests
|
||||
|
@ -13,5 +13,6 @@ config = {
|
||||
"checkpoint_id": "",
|
||||
"volume_id": "",
|
||||
"folder_path": "images",
|
||||
}
|
||||
},
|
||||
"civitai_api_key": "",
|
||||
}
|
||||
|
@ -26,7 +26,8 @@ export async function POST(request: Request) {
|
||||
.update(checkpointTable)
|
||||
.set({
|
||||
status: "success",
|
||||
folder_path
|
||||
folder_path,
|
||||
updated_at: new Date(),
|
||||
// build_log: build_log,
|
||||
})
|
||||
.where(eq(checkpointTable.id, checkpoint_id));
|
||||
@ -37,6 +38,7 @@ export async function POST(request: Request) {
|
||||
.set({
|
||||
status: "failed",
|
||||
error_log,
|
||||
updated_at: new Date(),
|
||||
// status: "error",
|
||||
// build_log: build_log,
|
||||
})
|
||||
|
@ -116,6 +116,38 @@ export const columns: ColumnDef<CheckpointItemList>[] = [
|
||||
{row.original.status}
|
||||
</Badge>
|
||||
);
|
||||
// const oneHourAgo = new Date(new Date().getTime() - (60 * 60 * 1000));
|
||||
// const lastUpdated = new Date(row.original.updated_at);
|
||||
// const canRefresh = row.original.status === "failed" && lastUpdated < oneHourAgo;
|
||||
// const canRefresh = row.original.status === "failed" && lastUpdated < oneHourAgo;
|
||||
// cell: ({ row }) => {
|
||||
// // const oneHourAgo = new Date(new Date().getTime() - (60 * 60 * 1000));
|
||||
// // const lastUpdated = new Date(row.original.updated_at);
|
||||
// // const canRefresh = row.original.status === "failed" && lastUpdated < oneHourAgo;
|
||||
// const canReDownload = true;
|
||||
//
|
||||
// return (
|
||||
// <div className="flex items-center space-x-2">
|
||||
// <Badge
|
||||
// variant={row.original.status === "failed"
|
||||
// ? "red"
|
||||
// : row.original.status === "started"
|
||||
// ? "yellow"
|
||||
// : "green"}
|
||||
// >
|
||||
// {row.original.status}
|
||||
// </Badge>
|
||||
// {canReDownload && (
|
||||
// <RefreshCcw
|
||||
// onClick={() => {
|
||||
// redownloadCheckpoint(row.original);
|
||||
// }}
|
||||
// className="h-4 w-4 cursor-pointer" // Adjust the size with h-x and w-x classes
|
||||
// />
|
||||
// )}
|
||||
// </div>
|
||||
// );
|
||||
// },
|
||||
},
|
||||
},
|
||||
{
|
||||
|
@ -406,13 +406,47 @@ export const checkpointTable = dbSchema.table("checkpoints", {
|
||||
|
||||
is_public: boolean("is_public").notNull().default(false),
|
||||
status: resourceUpload("status").notNull().default("started"),
|
||||
upload_machine_id: text("upload_machine_id"),
|
||||
upload_machine_id: text("upload_machine_id"), // TODO: review if actually needed
|
||||
upload_type: modelUploadType("upload_type").notNull(),
|
||||
error_log: text("error_log"),
|
||||
created_at: timestamp("created_at").defaultNow().notNull(),
|
||||
updated_at: timestamp("updated_at").defaultNow().notNull(),
|
||||
});
|
||||
|
||||
export const checkpointVolumeTable = dbSchema.table("checkpoint_volume", {
|
||||
id: uuid("id").primaryKey().defaultRandom().notNull(),
|
||||
user_id: text("user_id").references(() => usersTable.id, {
|
||||
// onDelete: "cascade",
|
||||
}),
|
||||
org_id: text("org_id"),
|
||||
volume_name: text("volume_name").notNull(),
|
||||
created_at: timestamp("created_at").defaultNow().notNull(),
|
||||
updated_at: timestamp("updated_at").defaultNow().notNull(),
|
||||
disabled: boolean("disabled").default(false).notNull(),
|
||||
});
|
||||
|
||||
export const checkpointRelations = relations(checkpointTable, ({ one }) => ({
|
||||
user: one(usersTable, {
|
||||
fields: [checkpointTable.user_id],
|
||||
references: [usersTable.id],
|
||||
}),
|
||||
volume: one(checkpointVolumeTable, {
|
||||
fields: [checkpointTable.checkpoint_volume_id],
|
||||
references: [checkpointVolumeTable.id],
|
||||
}),
|
||||
}));
|
||||
|
||||
export const checkpointVolumeRelations = relations(
|
||||
checkpointVolumeTable,
|
||||
({ many, one }) => ({
|
||||
checkpoint: many(checkpointTable),
|
||||
user: one(usersTable, {
|
||||
fields: [checkpointVolumeTable.user_id],
|
||||
references: [usersTable.id],
|
||||
}),
|
||||
})
|
||||
);
|
||||
|
||||
export const subscriptionPlan = pgEnum("subscription_plan", [
|
||||
"basic",
|
||||
"pro",
|
||||
@ -452,40 +486,6 @@ export const insertCivitaiCheckpointSchema = createInsertSchema(
|
||||
}
|
||||
);
|
||||
|
||||
export const checkpointVolumeTable = dbSchema.table("checkpoint_volume", {
|
||||
id: uuid("id").primaryKey().defaultRandom().notNull(),
|
||||
user_id: text("user_id").references(() => usersTable.id, {
|
||||
// onDelete: "cascade",
|
||||
}),
|
||||
org_id: text("org_id"),
|
||||
volume_name: text("volume_name").notNull(),
|
||||
created_at: timestamp("created_at").defaultNow().notNull(),
|
||||
updated_at: timestamp("updated_at").defaultNow().notNull(),
|
||||
disabled: boolean("disabled").default(false).notNull(),
|
||||
});
|
||||
|
||||
export const checkpointRelations = relations(checkpointTable, ({ one }) => ({
|
||||
user: one(usersTable, {
|
||||
fields: [checkpointTable.user_id],
|
||||
references: [usersTable.id],
|
||||
}),
|
||||
volume: one(checkpointVolumeTable, {
|
||||
fields: [checkpointTable.checkpoint_volume_id],
|
||||
references: [checkpointVolumeTable.id],
|
||||
}),
|
||||
}));
|
||||
|
||||
export const checkpointVolumeRelations = relations(
|
||||
checkpointVolumeTable,
|
||||
({ many, one }) => ({
|
||||
checkpoint: many(checkpointTable),
|
||||
user: one(usersTable, {
|
||||
fields: [checkpointVolumeTable.user_id],
|
||||
references: [usersTable.id],
|
||||
}),
|
||||
})
|
||||
);
|
||||
|
||||
export type UserType = InferSelectModel<typeof usersTable>;
|
||||
export type WorkflowType = InferSelectModel<typeof workflowTable>;
|
||||
export type MachineType = InferSelectModel<typeof machinesTable>;
|
||||
|
@ -15,6 +15,7 @@ import { headers } from "next/headers";
|
||||
import { addCivitaiCheckpointSchema } from "./addCheckpointSchema";
|
||||
import { and, eq, isNull } from "drizzle-orm";
|
||||
import { CivitaiModelResponse } from "@/types/civitai";
|
||||
import { CheckpointItemList } from "@/components/CheckpointList";
|
||||
|
||||
export async function getCheckpoints() {
|
||||
const { userId, orgId } = auth();
|
||||
@ -74,12 +75,12 @@ export async function getCheckpointVolumes() {
|
||||
}
|
||||
|
||||
export async function retrieveCheckpointVolumes() {
|
||||
let volumes = await getCheckpointVolumes()
|
||||
let volumes = await getCheckpointVolumes();
|
||||
if (volumes.length === 0) {
|
||||
// create volume if not already created
|
||||
volumes = await addCheckpointVolume()
|
||||
}
|
||||
return volumes
|
||||
volumes = await addCheckpointVolume();
|
||||
}
|
||||
return volumes;
|
||||
}
|
||||
|
||||
export async function addCheckpointVolume() {
|
||||
@ -97,7 +98,7 @@ export async function addCheckpointVolume() {
|
||||
disabled: false, // Default value
|
||||
})
|
||||
.returning(); // Returns the inserted row
|
||||
return insertedVolume;
|
||||
return insertedVolume;
|
||||
}
|
||||
|
||||
function getUrl(civitai_url: string) {
|
||||
@ -115,30 +116,24 @@ function getUrl(civitai_url: string) {
|
||||
export const addCivitaiCheckpoint = withServerPromise(
|
||||
async (data: z.infer<typeof addCivitaiCheckpointSchema>) => {
|
||||
const { userId, orgId } = auth();
|
||||
console.log("START")
|
||||
console.log("1");
|
||||
|
||||
if (!data.civitai_url) return { error: "no civitai_url" };
|
||||
console.log("2");
|
||||
if (!userId) return { error: "No user id" };
|
||||
console.log("3");
|
||||
|
||||
const { url, modelVersionId } = getUrl(data?.civitai_url);
|
||||
console.log("4", url, modelVersionId);
|
||||
const civitaiModelRes = await fetch(url)
|
||||
.then((x) => x.json())
|
||||
.then((a) => {
|
||||
console.log(a)
|
||||
console.log(a);
|
||||
return CivitaiModelResponse.parse(a);
|
||||
});
|
||||
console.log("5");
|
||||
|
||||
if (civitaiModelRes?.modelVersions?.length === 0) {
|
||||
console.log("6");
|
||||
return; // no versions to download
|
||||
}
|
||||
|
||||
console.log("7");
|
||||
let selectedModelVersion;
|
||||
let selectedModelVersionId: string | null = modelVersionId;
|
||||
if (!selectedModelVersionId) {
|
||||
@ -153,23 +148,16 @@ export const addCivitaiCheckpoint = withServerPromise(
|
||||
}
|
||||
selectedModelVersionId = selectedModelVersion?.id.toString();
|
||||
}
|
||||
console.log("8");
|
||||
|
||||
const checkpointVolumes = await getCheckpointVolumes();
|
||||
console.log("9");
|
||||
let cVolume;
|
||||
if (checkpointVolumes.length === 0) {
|
||||
console.log("10");
|
||||
const volume = await addCheckpointVolume();
|
||||
console.log("11");
|
||||
cVolume = volume[0];
|
||||
} else {
|
||||
console.log("12");
|
||||
cVolume = checkpointVolumes[0];
|
||||
}
|
||||
|
||||
console.log("13");
|
||||
|
||||
const a = await db
|
||||
.insert(checkpointTable)
|
||||
.values({
|
||||
@ -183,18 +171,52 @@ export const addCivitaiCheckpoint = withServerPromise(
|
||||
civitai_download_url: selectedModelVersion.files[0].downloadUrl,
|
||||
civitai_model_response: civitaiModelRes,
|
||||
checkpoint_volume_id: cVolume.id,
|
||||
updated_at: new Date(),
|
||||
})
|
||||
.returning();
|
||||
console.log("14");
|
||||
|
||||
const b = a[0];
|
||||
|
||||
await uploadCheckpoint(data, b, cVolume);
|
||||
console.log("15");
|
||||
// redirect(`/checkpoints/${b.id}`);
|
||||
},
|
||||
);
|
||||
|
||||
// export const redownloadCheckpoint = withServerPromise(
|
||||
// async (data: CheckpointItemList) => {
|
||||
// const { userId } = auth();
|
||||
// if (!userId) return { error: "No user id" };
|
||||
//
|
||||
// const checkpointVolumes = await getCheckpointVolumes();
|
||||
// let cVolume;
|
||||
// if (checkpointVolumes.length === 0) {
|
||||
// const volume = await addCheckpointVolume();
|
||||
// cVolume = volume[0];
|
||||
// } else {
|
||||
// cVolume = checkpointVolumes[0];
|
||||
// }
|
||||
//
|
||||
// console.log("data");
|
||||
// console.log(data);
|
||||
//
|
||||
// const a = await db
|
||||
// .update(checkpointTable)
|
||||
// .set({
|
||||
// // status: "started",
|
||||
// // updated_at: new Date(),
|
||||
// })
|
||||
// .returning();
|
||||
//
|
||||
// const b = a[0];
|
||||
//
|
||||
// console.log("b");
|
||||
// console.log(b);
|
||||
//
|
||||
// await uploadCheckpoint(data, b, cVolume);
|
||||
// // redirect(`/checkpoints/${b.id}`);
|
||||
// },
|
||||
// );
|
||||
|
||||
async function uploadCheckpoint(
|
||||
data: z.infer<typeof addCivitaiCheckpointSchema>,
|
||||
c: CheckpointType,
|
||||
|
Loading…
x
Reference in New Issue
Block a user