working
This commit is contained in:
parent
911cc8d16b
commit
db04d02d34
@ -47,6 +47,7 @@ machine_id_websocket_dict = {}
|
|||||||
machine_id_status = {}
|
machine_id_status = {}
|
||||||
|
|
||||||
fly_instance_id = os.environ.get('FLY_ALLOC_ID', 'local').split('-')[0]
|
fly_instance_id = os.environ.get('FLY_ALLOC_ID', 'local').split('-')[0]
|
||||||
|
civitai_api_key = os.environ.get('FLY_ALLOC_ID', 'local').split('-')[0]
|
||||||
|
|
||||||
|
|
||||||
class FlyReplayMiddleware(BaseHTTPMiddleware):
|
class FlyReplayMiddleware(BaseHTTPMiddleware):
|
||||||
@ -272,7 +273,8 @@ async def upload_logic(body: UploadBody):
|
|||||||
"checkpoint_id": body.checkpoint_id,
|
"checkpoint_id": body.checkpoint_id,
|
||||||
"volume_id": body.volume_id,
|
"volume_id": body.volume_id,
|
||||||
"folder_path": upload_path,
|
"folder_path": upload_path,
|
||||||
}
|
},
|
||||||
|
"civitai_api_key": os.environ.get('CIVITAI_API_KEY')
|
||||||
}
|
}
|
||||||
with open(f"{folder_path}/config.py", "w") as f:
|
with open(f"{folder_path}/config.py", "w") as f:
|
||||||
f.write("config = " + json.dumps(config))
|
f.write("config = " + json.dumps(config))
|
||||||
|
@ -25,16 +25,13 @@ vol_name_to_links = config["volume_names"]
|
|||||||
vol_name_to_path = config["volume_paths"]
|
vol_name_to_path = config["volume_paths"]
|
||||||
callback_url = config["callback_url"]
|
callback_url = config["callback_url"]
|
||||||
callback_body = config["callback_body"]
|
callback_body = config["callback_body"]
|
||||||
|
civitai_key = config["civitai_api_key"]
|
||||||
|
|
||||||
volumes = create_volumes(vol_name_to_links, vol_name_to_path)
|
volumes = create_volumes(vol_name_to_links, vol_name_to_path)
|
||||||
image = (
|
image = (
|
||||||
modal.Image.debian_slim().apt_install("wget").pip_install("requests")
|
modal.Image.debian_slim().apt_install("wget").pip_install("requests")
|
||||||
)
|
)
|
||||||
|
|
||||||
print(vol_name_to_links)
|
|
||||||
print(vol_name_to_path)
|
|
||||||
print(volumes)
|
|
||||||
|
|
||||||
# download config { "download_url": "", "folder_path": ""}
|
# download config { "download_url": "", "folder_path": ""}
|
||||||
timeout=5000
|
timeout=5000
|
||||||
@stub.function(volumes=volumes, image=image, timeout=timeout, gpu=None)
|
@stub.function(volumes=volumes, image=image, timeout=timeout, gpu=None)
|
||||||
@ -42,20 +39,23 @@ def download_model(volume_name, download_config):
|
|||||||
import requests
|
import requests
|
||||||
download_url = download_config["download_url"]
|
download_url = download_config["download_url"]
|
||||||
folder_path = download_config["folder_path"]
|
folder_path = download_config["folder_path"]
|
||||||
|
|
||||||
volume_base_path = vol_name_to_path[volume_name]
|
volume_base_path = vol_name_to_path[volume_name]
|
||||||
model_store_path = os.path.join(volume_base_path, folder_path)
|
model_store_path = os.path.join(volume_base_path, folder_path)
|
||||||
|
modified_download_url = download_url + ("&" if "?" in download_url else "?") + "token=" + civitai_key
|
||||||
|
print('downlodaing', modified_download_url)
|
||||||
|
|
||||||
subprocess.run(["wget", download_url, "--content-disposition", "-P", model_store_path])
|
subprocess.run(["wget", modified_download_url , "--content-disposition", "-P", model_store_path])
|
||||||
subprocess.run(["ls", "-la", volume_base_path])
|
subprocess.run(["ls", "-la", volume_base_path])
|
||||||
subprocess.run(["ls", "-la", model_store_path])
|
subprocess.run(["ls", "-la", model_store_path])
|
||||||
volumes[volume_base_path].commit()
|
volumes[volume_base_path].commit()
|
||||||
|
|
||||||
|
|
||||||
status = {"status": "success"}
|
status = {"status": "success"}
|
||||||
requests.post(callback_url, json={**status, **callback_body})
|
requests.post(callback_url, json={**status, **callback_body})
|
||||||
print(f"finished! sending to {callback_url}")
|
print(f"finished! sending to {callback_url}")
|
||||||
pprint({**status, **callback_body})
|
pprint({**status, **callback_body})
|
||||||
|
|
||||||
|
|
||||||
@stub.local_entrypoint()
|
@stub.local_entrypoint()
|
||||||
def simple_download():
|
def simple_download():
|
||||||
import requests
|
import requests
|
||||||
|
@ -13,5 +13,6 @@ config = {
|
|||||||
"checkpoint_id": "",
|
"checkpoint_id": "",
|
||||||
"volume_id": "",
|
"volume_id": "",
|
||||||
"folder_path": "images",
|
"folder_path": "images",
|
||||||
}
|
},
|
||||||
|
"civitai_api_key": "",
|
||||||
}
|
}
|
||||||
|
@ -26,7 +26,8 @@ export async function POST(request: Request) {
|
|||||||
.update(checkpointTable)
|
.update(checkpointTable)
|
||||||
.set({
|
.set({
|
||||||
status: "success",
|
status: "success",
|
||||||
folder_path
|
folder_path,
|
||||||
|
updated_at: new Date(),
|
||||||
// build_log: build_log,
|
// build_log: build_log,
|
||||||
})
|
})
|
||||||
.where(eq(checkpointTable.id, checkpoint_id));
|
.where(eq(checkpointTable.id, checkpoint_id));
|
||||||
@ -37,6 +38,7 @@ export async function POST(request: Request) {
|
|||||||
.set({
|
.set({
|
||||||
status: "failed",
|
status: "failed",
|
||||||
error_log,
|
error_log,
|
||||||
|
updated_at: new Date(),
|
||||||
// status: "error",
|
// status: "error",
|
||||||
// build_log: build_log,
|
// build_log: build_log,
|
||||||
})
|
})
|
||||||
|
@ -116,6 +116,38 @@ export const columns: ColumnDef<CheckpointItemList>[] = [
|
|||||||
{row.original.status}
|
{row.original.status}
|
||||||
</Badge>
|
</Badge>
|
||||||
);
|
);
|
||||||
|
// const oneHourAgo = new Date(new Date().getTime() - (60 * 60 * 1000));
|
||||||
|
// const lastUpdated = new Date(row.original.updated_at);
|
||||||
|
// const canRefresh = row.original.status === "failed" && lastUpdated < oneHourAgo;
|
||||||
|
// const canRefresh = row.original.status === "failed" && lastUpdated < oneHourAgo;
|
||||||
|
// cell: ({ row }) => {
|
||||||
|
// // const oneHourAgo = new Date(new Date().getTime() - (60 * 60 * 1000));
|
||||||
|
// // const lastUpdated = new Date(row.original.updated_at);
|
||||||
|
// // const canRefresh = row.original.status === "failed" && lastUpdated < oneHourAgo;
|
||||||
|
// const canReDownload = true;
|
||||||
|
//
|
||||||
|
// return (
|
||||||
|
// <div className="flex items-center space-x-2">
|
||||||
|
// <Badge
|
||||||
|
// variant={row.original.status === "failed"
|
||||||
|
// ? "red"
|
||||||
|
// : row.original.status === "started"
|
||||||
|
// ? "yellow"
|
||||||
|
// : "green"}
|
||||||
|
// >
|
||||||
|
// {row.original.status}
|
||||||
|
// </Badge>
|
||||||
|
// {canReDownload && (
|
||||||
|
// <RefreshCcw
|
||||||
|
// onClick={() => {
|
||||||
|
// redownloadCheckpoint(row.original);
|
||||||
|
// }}
|
||||||
|
// className="h-4 w-4 cursor-pointer" // Adjust the size with h-x and w-x classes
|
||||||
|
// />
|
||||||
|
// )}
|
||||||
|
// </div>
|
||||||
|
// );
|
||||||
|
// },
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -406,13 +406,47 @@ export const checkpointTable = dbSchema.table("checkpoints", {
|
|||||||
|
|
||||||
is_public: boolean("is_public").notNull().default(false),
|
is_public: boolean("is_public").notNull().default(false),
|
||||||
status: resourceUpload("status").notNull().default("started"),
|
status: resourceUpload("status").notNull().default("started"),
|
||||||
upload_machine_id: text("upload_machine_id"),
|
upload_machine_id: text("upload_machine_id"), // TODO: review if actually needed
|
||||||
upload_type: modelUploadType("upload_type").notNull(),
|
upload_type: modelUploadType("upload_type").notNull(),
|
||||||
error_log: text("error_log"),
|
error_log: text("error_log"),
|
||||||
created_at: timestamp("created_at").defaultNow().notNull(),
|
created_at: timestamp("created_at").defaultNow().notNull(),
|
||||||
updated_at: timestamp("updated_at").defaultNow().notNull(),
|
updated_at: timestamp("updated_at").defaultNow().notNull(),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
export const checkpointVolumeTable = dbSchema.table("checkpoint_volume", {
|
||||||
|
id: uuid("id").primaryKey().defaultRandom().notNull(),
|
||||||
|
user_id: text("user_id").references(() => usersTable.id, {
|
||||||
|
// onDelete: "cascade",
|
||||||
|
}),
|
||||||
|
org_id: text("org_id"),
|
||||||
|
volume_name: text("volume_name").notNull(),
|
||||||
|
created_at: timestamp("created_at").defaultNow().notNull(),
|
||||||
|
updated_at: timestamp("updated_at").defaultNow().notNull(),
|
||||||
|
disabled: boolean("disabled").default(false).notNull(),
|
||||||
|
});
|
||||||
|
|
||||||
|
export const checkpointRelations = relations(checkpointTable, ({ one }) => ({
|
||||||
|
user: one(usersTable, {
|
||||||
|
fields: [checkpointTable.user_id],
|
||||||
|
references: [usersTable.id],
|
||||||
|
}),
|
||||||
|
volume: one(checkpointVolumeTable, {
|
||||||
|
fields: [checkpointTable.checkpoint_volume_id],
|
||||||
|
references: [checkpointVolumeTable.id],
|
||||||
|
}),
|
||||||
|
}));
|
||||||
|
|
||||||
|
export const checkpointVolumeRelations = relations(
|
||||||
|
checkpointVolumeTable,
|
||||||
|
({ many, one }) => ({
|
||||||
|
checkpoint: many(checkpointTable),
|
||||||
|
user: one(usersTable, {
|
||||||
|
fields: [checkpointVolumeTable.user_id],
|
||||||
|
references: [usersTable.id],
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
);
|
||||||
|
|
||||||
export const subscriptionPlan = pgEnum("subscription_plan", [
|
export const subscriptionPlan = pgEnum("subscription_plan", [
|
||||||
"basic",
|
"basic",
|
||||||
"pro",
|
"pro",
|
||||||
@ -452,40 +486,6 @@ export const insertCivitaiCheckpointSchema = createInsertSchema(
|
|||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
|
||||||
export const checkpointVolumeTable = dbSchema.table("checkpoint_volume", {
|
|
||||||
id: uuid("id").primaryKey().defaultRandom().notNull(),
|
|
||||||
user_id: text("user_id").references(() => usersTable.id, {
|
|
||||||
// onDelete: "cascade",
|
|
||||||
}),
|
|
||||||
org_id: text("org_id"),
|
|
||||||
volume_name: text("volume_name").notNull(),
|
|
||||||
created_at: timestamp("created_at").defaultNow().notNull(),
|
|
||||||
updated_at: timestamp("updated_at").defaultNow().notNull(),
|
|
||||||
disabled: boolean("disabled").default(false).notNull(),
|
|
||||||
});
|
|
||||||
|
|
||||||
export const checkpointRelations = relations(checkpointTable, ({ one }) => ({
|
|
||||||
user: one(usersTable, {
|
|
||||||
fields: [checkpointTable.user_id],
|
|
||||||
references: [usersTable.id],
|
|
||||||
}),
|
|
||||||
volume: one(checkpointVolumeTable, {
|
|
||||||
fields: [checkpointTable.checkpoint_volume_id],
|
|
||||||
references: [checkpointVolumeTable.id],
|
|
||||||
}),
|
|
||||||
}));
|
|
||||||
|
|
||||||
export const checkpointVolumeRelations = relations(
|
|
||||||
checkpointVolumeTable,
|
|
||||||
({ many, one }) => ({
|
|
||||||
checkpoint: many(checkpointTable),
|
|
||||||
user: one(usersTable, {
|
|
||||||
fields: [checkpointVolumeTable.user_id],
|
|
||||||
references: [usersTable.id],
|
|
||||||
}),
|
|
||||||
})
|
|
||||||
);
|
|
||||||
|
|
||||||
export type UserType = InferSelectModel<typeof usersTable>;
|
export type UserType = InferSelectModel<typeof usersTable>;
|
||||||
export type WorkflowType = InferSelectModel<typeof workflowTable>;
|
export type WorkflowType = InferSelectModel<typeof workflowTable>;
|
||||||
export type MachineType = InferSelectModel<typeof machinesTable>;
|
export type MachineType = InferSelectModel<typeof machinesTable>;
|
||||||
|
@ -15,6 +15,7 @@ import { headers } from "next/headers";
|
|||||||
import { addCivitaiCheckpointSchema } from "./addCheckpointSchema";
|
import { addCivitaiCheckpointSchema } from "./addCheckpointSchema";
|
||||||
import { and, eq, isNull } from "drizzle-orm";
|
import { and, eq, isNull } from "drizzle-orm";
|
||||||
import { CivitaiModelResponse } from "@/types/civitai";
|
import { CivitaiModelResponse } from "@/types/civitai";
|
||||||
|
import { CheckpointItemList } from "@/components/CheckpointList";
|
||||||
|
|
||||||
export async function getCheckpoints() {
|
export async function getCheckpoints() {
|
||||||
const { userId, orgId } = auth();
|
const { userId, orgId } = auth();
|
||||||
@ -74,12 +75,12 @@ export async function getCheckpointVolumes() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export async function retrieveCheckpointVolumes() {
|
export async function retrieveCheckpointVolumes() {
|
||||||
let volumes = await getCheckpointVolumes()
|
let volumes = await getCheckpointVolumes();
|
||||||
if (volumes.length === 0) {
|
if (volumes.length === 0) {
|
||||||
// create volume if not already created
|
// create volume if not already created
|
||||||
volumes = await addCheckpointVolume()
|
volumes = await addCheckpointVolume();
|
||||||
}
|
}
|
||||||
return volumes
|
return volumes;
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function addCheckpointVolume() {
|
export async function addCheckpointVolume() {
|
||||||
@ -97,7 +98,7 @@ export async function addCheckpointVolume() {
|
|||||||
disabled: false, // Default value
|
disabled: false, // Default value
|
||||||
})
|
})
|
||||||
.returning(); // Returns the inserted row
|
.returning(); // Returns the inserted row
|
||||||
return insertedVolume;
|
return insertedVolume;
|
||||||
}
|
}
|
||||||
|
|
||||||
function getUrl(civitai_url: string) {
|
function getUrl(civitai_url: string) {
|
||||||
@ -115,30 +116,24 @@ function getUrl(civitai_url: string) {
|
|||||||
export const addCivitaiCheckpoint = withServerPromise(
|
export const addCivitaiCheckpoint = withServerPromise(
|
||||||
async (data: z.infer<typeof addCivitaiCheckpointSchema>) => {
|
async (data: z.infer<typeof addCivitaiCheckpointSchema>) => {
|
||||||
const { userId, orgId } = auth();
|
const { userId, orgId } = auth();
|
||||||
console.log("START")
|
|
||||||
console.log("1");
|
|
||||||
|
|
||||||
if (!data.civitai_url) return { error: "no civitai_url" };
|
if (!data.civitai_url) return { error: "no civitai_url" };
|
||||||
console.log("2");
|
|
||||||
if (!userId) return { error: "No user id" };
|
if (!userId) return { error: "No user id" };
|
||||||
console.log("3");
|
|
||||||
|
|
||||||
const { url, modelVersionId } = getUrl(data?.civitai_url);
|
const { url, modelVersionId } = getUrl(data?.civitai_url);
|
||||||
console.log("4", url, modelVersionId);
|
console.log("4", url, modelVersionId);
|
||||||
const civitaiModelRes = await fetch(url)
|
const civitaiModelRes = await fetch(url)
|
||||||
.then((x) => x.json())
|
.then((x) => x.json())
|
||||||
.then((a) => {
|
.then((a) => {
|
||||||
console.log(a)
|
console.log(a);
|
||||||
return CivitaiModelResponse.parse(a);
|
return CivitaiModelResponse.parse(a);
|
||||||
});
|
});
|
||||||
console.log("5");
|
console.log("5");
|
||||||
|
|
||||||
if (civitaiModelRes?.modelVersions?.length === 0) {
|
if (civitaiModelRes?.modelVersions?.length === 0) {
|
||||||
console.log("6");
|
|
||||||
return; // no versions to download
|
return; // no versions to download
|
||||||
}
|
}
|
||||||
|
|
||||||
console.log("7");
|
|
||||||
let selectedModelVersion;
|
let selectedModelVersion;
|
||||||
let selectedModelVersionId: string | null = modelVersionId;
|
let selectedModelVersionId: string | null = modelVersionId;
|
||||||
if (!selectedModelVersionId) {
|
if (!selectedModelVersionId) {
|
||||||
@ -153,23 +148,16 @@ export const addCivitaiCheckpoint = withServerPromise(
|
|||||||
}
|
}
|
||||||
selectedModelVersionId = selectedModelVersion?.id.toString();
|
selectedModelVersionId = selectedModelVersion?.id.toString();
|
||||||
}
|
}
|
||||||
console.log("8");
|
|
||||||
|
|
||||||
const checkpointVolumes = await getCheckpointVolumes();
|
const checkpointVolumes = await getCheckpointVolumes();
|
||||||
console.log("9");
|
|
||||||
let cVolume;
|
let cVolume;
|
||||||
if (checkpointVolumes.length === 0) {
|
if (checkpointVolumes.length === 0) {
|
||||||
console.log("10");
|
|
||||||
const volume = await addCheckpointVolume();
|
const volume = await addCheckpointVolume();
|
||||||
console.log("11");
|
|
||||||
cVolume = volume[0];
|
cVolume = volume[0];
|
||||||
} else {
|
} else {
|
||||||
console.log("12");
|
|
||||||
cVolume = checkpointVolumes[0];
|
cVolume = checkpointVolumes[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
console.log("13");
|
|
||||||
|
|
||||||
const a = await db
|
const a = await db
|
||||||
.insert(checkpointTable)
|
.insert(checkpointTable)
|
||||||
.values({
|
.values({
|
||||||
@ -183,18 +171,52 @@ export const addCivitaiCheckpoint = withServerPromise(
|
|||||||
civitai_download_url: selectedModelVersion.files[0].downloadUrl,
|
civitai_download_url: selectedModelVersion.files[0].downloadUrl,
|
||||||
civitai_model_response: civitaiModelRes,
|
civitai_model_response: civitaiModelRes,
|
||||||
checkpoint_volume_id: cVolume.id,
|
checkpoint_volume_id: cVolume.id,
|
||||||
|
updated_at: new Date(),
|
||||||
})
|
})
|
||||||
.returning();
|
.returning();
|
||||||
console.log("14");
|
|
||||||
|
|
||||||
const b = a[0];
|
const b = a[0];
|
||||||
|
|
||||||
await uploadCheckpoint(data, b, cVolume);
|
await uploadCheckpoint(data, b, cVolume);
|
||||||
console.log("15");
|
|
||||||
// redirect(`/checkpoints/${b.id}`);
|
// redirect(`/checkpoints/${b.id}`);
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// export const redownloadCheckpoint = withServerPromise(
|
||||||
|
// async (data: CheckpointItemList) => {
|
||||||
|
// const { userId } = auth();
|
||||||
|
// if (!userId) return { error: "No user id" };
|
||||||
|
//
|
||||||
|
// const checkpointVolumes = await getCheckpointVolumes();
|
||||||
|
// let cVolume;
|
||||||
|
// if (checkpointVolumes.length === 0) {
|
||||||
|
// const volume = await addCheckpointVolume();
|
||||||
|
// cVolume = volume[0];
|
||||||
|
// } else {
|
||||||
|
// cVolume = checkpointVolumes[0];
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// console.log("data");
|
||||||
|
// console.log(data);
|
||||||
|
//
|
||||||
|
// const a = await db
|
||||||
|
// .update(checkpointTable)
|
||||||
|
// .set({
|
||||||
|
// // status: "started",
|
||||||
|
// // updated_at: new Date(),
|
||||||
|
// })
|
||||||
|
// .returning();
|
||||||
|
//
|
||||||
|
// const b = a[0];
|
||||||
|
//
|
||||||
|
// console.log("b");
|
||||||
|
// console.log(b);
|
||||||
|
//
|
||||||
|
// await uploadCheckpoint(data, b, cVolume);
|
||||||
|
// // redirect(`/checkpoints/${b.id}`);
|
||||||
|
// },
|
||||||
|
// );
|
||||||
|
|
||||||
async function uploadCheckpoint(
|
async function uploadCheckpoint(
|
||||||
data: z.infer<typeof addCivitaiCheckpointSchema>,
|
data: z.infer<typeof addCivitaiCheckpointSchema>,
|
||||||
c: CheckpointType,
|
c: CheckpointType,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user