This commit is contained in:
Nicholas Koben Kao 2024-01-22 19:05:18 -08:00
parent fed7b380b6
commit f6a1b88dda
9 changed files with 201 additions and 283 deletions

View File

@ -8,6 +8,7 @@ from enum import Enum
import json
import subprocess
import time
from uuid import uuid4
from contextlib import asynccontextmanager
import asyncio
import threading
@ -19,6 +20,7 @@ from urllib.parse import parse_qs
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp, Scope, Receive, Send
from concurrent.futures import ThreadPoolExecutor
# executor = ThreadPoolExecutor(max_workers=5)
@ -227,15 +229,47 @@ async def websocket_endpoint(websocket: WebSocket, machine_id: str):
class UploadBody(BaseModel):
download_url: str
volume_name: str
volume_id: str
# callback_url: str
@app.post("/upload_volume")
async def upload_checkpoint(body: UploadBody):
global last_activity_time
last_activity_time = time.time()
logger.info(f"Extended inactivity time to {global_timeout}")
download_url = body.download_url
volume_name = body.download_url
volume_name = body.volume_name
# callback_url = body.callback_url
folder_path = f"/app/builds/{body.volume_id}"
cp_process = await asyncio.subprocess.create_subprocess_exec("cp", "-r", "/app/src/volume-builder", folder_path)
await cp_process.wait()
# Write the config file
config = {
"volume_names": {
volume_name: download_url
},
"paths": {
volume_name: f'/volumes/{uuid4()}'
},
}
await asyncio.subprocess.create_subprocess_shell(
f"modal run app.py",
# stdout=asyncio.subprocess.PIPE,
# stderr=asyncio.subprocess.PIPE,
cwd=folder_path,
env={**os.environ, "COLUMNS": "10000"}
)
with open(f"{folder_path}/config.py", "w") as f:
f.write("config = " + json.dumps(config))
# check that thi
return
return JSONResponse(status_code=200, content={"message": "Volume uploading", "build_machine_instance_id": fly_instance_id})
@app.post("/create")

View File

@ -1,51 +1,42 @@
import modal
from config import config
import os
import uuid
import subprocess
stub = modal.Stub()
base_path = "/volumes"
# Volume names may only contain alphanumeric characters, dashes, periods, and underscores, and must be less than 64 characters in length.
def is_valid_name(name: str) -> bool:
allowed_characters = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-._")
return 0 < len(name) <= 64 and all(char in allowed_characters for char in name)
def create_volumes(volume_names):
def create_volumes(volume_names, paths):
path_to_vol = {}
vol_to_path = {}
for volume_name in volume_names.keys():
if not is_valid_name(volume_name):
pass
modal_volume = modal.Volume.persisted(volume_name)
volume_path = create_volume_path(base_path)
path_to_vol[volume_path] = modal_volume
vol_to_path[volume_name] = volume_path
path_to_vol[paths[volume_name]] = modal_volume
return (path_to_vol, vol_to_path)
def create_volume_path(base_path: str):
random_path = str(uuid.uuid4())
return os.path.join(base_path, random_path)
return path_to_vol
vol_name_to_links = config["volume_names"]
(path_to_vol, vol_name_to_path) = create_volumes(vol_name_to_links)
vol_name_to_path = config["paths"]
volumes = create_volumes(vol_name_to_links, vol_name_to_path)
image = (
modal.Image.debian_slim().apt_install("wget").pip_install("requests")
)
print(vol_name_to_links)
print(path_to_vol)
print(vol_name_to_path)
print(volumes)
@stub.function(volumes=path_to_vol, image=image, timeout=5000, gpu=None)
@stub.function(volumes=volumes, image=image, timeout=5000, gpu=None)
def download_model(volume_name, download_url):
model_store_path = vol_name_to_path[volume_name]
subprocess.run(["wget", download_url, "--content-disposition", "-P", model_store_path])
subprocess.run(["ls", "-la", model_store_path])
path_to_vol[model_store_path].commit()
volumes[model_store_path].commit()
@stub.local_entrypoint()
def simple_download():

View File

@ -1,5 +1,8 @@
config = {
"volume_names": {
"eg1": "https://pub-6230db03dc3a4861a9c3e55145ceda44.r2.dev/openpose-pose (1).png"
"test": "https://pub-6230db03dc3a4861a9c3e55145ceda44.r2.dev/openpose-pose (1).png"
},
"paths": {
"test": "/volumes/something"
}
}

View File

@ -33,7 +33,7 @@ CREATE TABLE IF NOT EXISTS "comfyui_deploy"."checkpoints" (
"updated_at" timestamp DEFAULT now() NOT NULL
);
--> statement-breakpoint
CREATE TABLE IF NOT EXISTS "comfyui_deploy"."checkpointVolumeTable" (
CREATE TABLE IF NOT EXISTS "comfyui_deploy"."checkpoint_volume" (
"id" uuid PRIMARY KEY DEFAULT gen_random_uuid() NOT NULL,
"user_id" text,
"org_id" text,
@ -56,7 +56,7 @@ EXCEPTION
END $$;
--> statement-breakpoint
DO $$ BEGIN
ALTER TABLE "comfyui_deploy"."checkpointVolumeTable" ADD CONSTRAINT "checkpointVolumeTable_user_id_users_id_fk" FOREIGN KEY ("user_id") REFERENCES "comfyui_deploy"."users"("id") ON DELETE no action ON UPDATE no action;
ALTER TABLE "comfyui_deploy"."checkpoint_volume" ADD CONSTRAINT "checkpoint_volume_user_id_users_id_fk" FOREIGN KEY ("user_id") REFERENCES "comfyui_deploy"."users"("id") ON DELETE no action ON UPDATE no action;
EXCEPTION
WHEN duplicate_object THEN null;
END $$;

View File

@ -1,5 +1,5 @@
{
"id": "66dbc84a-6cd8-4692-9d24-fdcac227b23c",
"id": "4d5b29d0-848f-4c2e-a2cd-2932f1fa38c6",
"prevId": "db06ea66-92c2-4ebe-93c1-6cb8a90ccd8b",
"version": "5",
"dialect": "pg",
@ -250,8 +250,8 @@
"compositePrimaryKeys": {},
"uniqueConstraints": {}
},
"checkpointVolumeTable": {
"name": "checkpointVolumeTable",
"checkpoint_volume": {
"name": "checkpoint_volume",
"schema": "comfyui_deploy",
"columns": {
"id": {
@ -303,9 +303,9 @@
},
"indexes": {},
"foreignKeys": {
"checkpointVolumeTable_user_id_users_id_fk": {
"name": "checkpointVolumeTable_user_id_users_id_fk",
"tableFrom": "checkpointVolumeTable",
"checkpoint_volume_user_id_users_id_fk": {
"name": "checkpoint_volume_user_id_users_id_fk",
"tableFrom": "checkpoint_volume",
"tableTo": "users",
"columnsFrom": [
"user_id"

View File

@ -222,8 +222,8 @@
{
"idx": 31,
"version": "5",
"when": 1705963548821,
"tag": "0031_common_deathbird",
"when": 1705975916818,
"tag": "0031_safe_multiple_man",
"breakpoints": true
}
]

View File

@ -1,3 +1,4 @@
import { CivitaiModelResponse } from "@/types/civitai";
import { type InferSelectModel, relations } from "drizzle-orm";
import {
boolean,
@ -331,198 +332,9 @@ export const apiKeyTable = dbSchema.table("api_keys", {
updated_at: timestamp("updated_at").defaultNow().notNull(),
});
// const civitaiModelVersion = z.object({
// id: z.number(),
// modelId: z.number(),
// name: z.string(),
// createdAt: z.string(),
// updatedAt: z.string(),
// status: z.string(),
// publishedAt: z.string(),
// trainedWords: z.array(z.string()).optional(),
// trainingStatus: z.string().optional(),
// trainingDetails: z.string().optional(),
// baseModel: z.string(),
// baseModelType: z.string(),
// earlyAccessTimeFrame: z.number(),
// description: z.string().optional(),
// vaeId: z.string().optional(),
// stats: z.object({
// downloadCount: z.number(),
// ratingCount: z.number(),
// rating: z.number(),
// }),
// files: z.array(z.object({
// id: z.number(),
// sizeKB: z.number(),
// name: z.string(),
// type: z.string(),
// metadata: z.object({
// fp: z.string(),
// size: z.string(),
// format: z.string(),
// }),
// pickleScanResult: z.string(),
// pickleScanMessage: z.string().optional(),
// virusScanResult: z.string(),
// virusScanMessage: z.string().optional(),
// scannedAt: z.string(),
// hashes: z.object({
// AutoV1: z.string(),
// AutoV2: z.string(),
// SHA256: z.string(),
// CRC32: z.string(),
// BLAKE3: z.string(),
// AutoV3: z.string(),
// }),
// downloadUrl: z.string(),
// primary: z.boolean(),
// })),
// images: z.array(z.object({
// url: z.string(),
// nsfw: z.string(),
// width: z.number(),
// height: z.number(),
// hash: z.string(),
// type: z.string(),
// metadata: z.object({
// hash: z.string(),
// size: z.number(),
// width: z.number(),
// height: z.number(),
// }),
// meta: z.any(),
// })),
// downloadUrl: z.string(),
// });
//
// const civitaiModelResponseType = z.object({
// id: z.number(),
// name: z.string(),
// description: z.string().optional(),
// type: z.string(),
// poi: z.boolean(),
// nsfw: z.boolean(),
// allowNoCredit: z.boolean(),
// allowCommercialUse: z.string(),
// allowDerivatives: z.boolean(),
// allowDifferentLicense: z.boolean(),
// stats: z.object({
// downloadCount: z.number(),
// favoriteCount: z.number(),
// commentCount: z.number(),
// ratingCount: z.number(),
// rating: z.number(),
// tippedAmountCount: z.number(),
// }),
// creator: z.object({
// username: z.string(),
// image: z.string(),
// }),
// tags: z.array(z.string()),
// modelVersions: z.array(civitaiModelVersion),
// });
export const CivitaiModel = z.object({
id: z.number(),
name: z.string(),
description: z.string(),
type: z.string(),
// poi: z.boolean(),
// nsfw: z.boolean(),
// allowNoCredit: z.boolean(),
// allowCommercialUse: z.string(),
// allowDerivatives: z.boolean(),
// allowDifferentLicense: z.boolean(),
// stats: z.object({
// downloadCount: z.number(),
// favoriteCount: z.number(),
// commentCount: z.number(),
// ratingCount: z.number(),
// rating: z.number(),
// tippedAmountCount: z.number(),
// }),
creator: z
.object({
username: z.string().nullable(),
image: z.string().nullable().default(null),
})
.nullable(),
tags: z.array(z.string()),
modelVersions: z.array(
z.object({
id: z.number(),
modelId: z.number(),
name: z.string(),
createdAt: z.string(),
updatedAt: z.string(),
status: z.string(),
publishedAt: z.string(),
trainedWords: z.array(z.unknown()),
trainingStatus: z.string().nullable(),
trainingDetails: z.string().nullable(),
baseModel: z.string(),
baseModelType: z.string().nullable(),
earlyAccessTimeFrame: z.number(),
description: z.string().nullable(),
vaeId: z.number().nullable(),
stats: z.object({
downloadCount: z.number(),
ratingCount: z.number(),
rating: z.number(),
}),
files: z.array(
z.object({
id: z.number(),
sizeKB: z.number(),
name: z.string(),
type: z.string(),
// metadata: z.object({
// fp: z.string().nullable().optional(),
// size: z.string().nullable().optional(),
// format: z.string().nullable().optional(),
// }),
// pickleScanResult: z.string(),
// pickleScanMessage: z.string(),
// virusScanResult: z.string(),
// virusScanMessage: z.string().nullable(),
// scannedAt: z.string(),
// hashes: z.object({
// AutoV1: z.string().nullable().optional(),
// AutoV2: z.string().nullable().optional(),
// SHA256: z.string().nullable().optional(),
// CRC32: z.string().nullable().optional(),
// BLAKE3: z.string().nullable().optional(),
// }),
downloadUrl: z.string(),
// primary: z.boolean().default(false),
}),
),
images: z.array(
z.object({
id: z.number(),
url: z.string(),
nsfw: z.string(),
width: z.number(),
height: z.number(),
hash: z.string(),
type: z.string(),
metadata: z.object({
hash: z.string(),
width: z.number(),
height: z.number(),
}),
meta: z.any(),
}),
),
downloadUrl: z.string(),
}),
),
});
export const resourceUpload = pgEnum("resource_upload", [
"started",
"failed",
"error",
"succeded",
]);
@ -552,7 +364,7 @@ export const checkpointTable = dbSchema.table("checkpoints", {
civitai_url: text("civitai_url"),
civitai_download_url: text("civitai_download_url"),
civitai_model_response: jsonb("civitai_model_response").$type<
z.infer<typeof CivitaiModel>
z.infer<typeof CivitaiModelResponse>
>(),
hf_url: text("hf_url"),
@ -563,6 +375,7 @@ export const checkpointTable = dbSchema.table("checkpoints", {
status: resourceUpload("status").notNull().default("started"),
upload_machine_id: text("upload_machine_id"),
upload_type: modelUploadType("upload_type").notNull(),
build_log: text("build_log"),
created_at: timestamp("created_at").defaultNow().notNull(),
updated_at: timestamp("updated_at").defaultNow().notNull(),
@ -579,7 +392,7 @@ export const insertCivitaiCheckpointSchema = createInsertSchema(
},
);
export const checkpointVolumeTable = dbSchema.table("checkpointVolumeTable", {
export const checkpointVolumeTable = dbSchema.table("checkpoint_volume", {
id: uuid("id").primaryKey().defaultRandom().notNull(),
user_id: text("user_id")
.references(() => usersTable.id, {
@ -605,8 +418,12 @@ export const checkpointRelations = relations(checkpointTable, ({ one }) => ({
export const checkpointVolumeRelations = relations(
checkpointVolumeTable,
({ many }) => ({
({ many, one }) => ({
checkpoint: many(checkpointTable),
user: one(usersTable, {
fields: [checkpointVolumeTable.user_id],
references: [usersTable.id],
}),
}),
);

View File

@ -4,9 +4,8 @@ import { auth } from "@clerk/nextjs";
import {
checkpointTable,
CheckpointType,
checkpointVolumeTable,
volumeTable,
CheckpointVolumeType,
CivitaiModel,
} from "@/db/schema";
import { withServerPromise } from "./withServerPromise";
import { redirect } from "next/navigation";
@ -15,6 +14,7 @@ import type { z } from "zod";
import { headers } from "next/headers";
import { addCivitaiCheckpointSchema } from "./addCheckpointSchema";
import { and, eq, isNull } from "drizzle-orm";
import { CivitaiModelResponse } from "@/types/civitai";
export async function getCheckpoints() {
const { userId, orgId } = auth();
@ -57,17 +57,17 @@ export async function getCheckpointVolumes() {
if (!userId) throw new Error("No user id");
const checkpointVolume = await db
.select()
.from(checkpointVolumeTable)
.from(volumeTable)
.where(
and(
orgId
? eq(checkpointVolumeTable.org_id, orgId)
? eq(volumeTable.org_id, orgId)
// make sure org_id is null
: and(
eq(checkpointVolumeTable.user_id, userId),
isNull(checkpointVolumeTable.org_id),
eq(volumeTable.user_id, userId),
isNull(volumeTable.org_id),
),
eq(checkpointVolumeTable.disabled, false),
eq(volumeTable.disabled, false),
),
);
return checkpointVolume;
@ -79,7 +79,7 @@ export async function addCheckpointVolume() {
// Insert the new volume into the checkpointVolumeTable
const insertedVolume = await db
.insert(checkpointVolumeTable)
.insert(volumeTable)
.values({
user_id: userId,
org_id: orgId,
@ -115,11 +115,10 @@ export const addCivitaiCheckpoint = withServerPromise(
const civitaiModelRes = await fetch(url)
.then((x) => x.json())
.then((a) => {
console.log(a)
return CivitaiModel.parse(a);
return CivitaiModelResponse.parse(a);
});
if (civitaiModelRes.modelVersions?.length === 0) {
if (civitaiModelRes?.modelVersions?.length === 0) {
return; // no versions to download
}
@ -155,6 +154,8 @@ export const addCivitaiCheckpoint = withServerPromise(
upload_type: "civitai",
civitai_id: civitaiModelRes.id.toString(),
civitai_version_id: selectedModelVersionId,
civitai_url: data.civitai_url,
civitai_download_url: selectedModelVersion.downloadUrl,
civitai_model_response: civitaiModelRes,
checkpoint_volume_id: cVolume.id,
})
@ -192,22 +193,23 @@ async function uploadCheckpoint(
body: JSON.stringify({
download_url: data.civitai_url,
volume_name: v.volume_name,
volume_id: v.id,
callback_url: `${protocol}://${domain}/api/volume-updated`,
}),
},
);
if (!result.ok) {
// const error_log = await result.text();
// await db
// .update(checkpointTable)
// .set({
// ...data,
// status: "error",
// build_log: error_log,
// })
// .where(eq(machinesTable.id, b.id));
// throw new Error(`Error: ${result.statusText} ${error_log}`);
const error_log = await result.text();
await db
.update(checkpointTable)
.set({
...data,
status: "error",
build_log: error_log,
})
.where(eq(checkpointTable.id, b.id));
throw new Error(`Error: ${result.statusText} ${error_log}`);
} else {
// setting the build machine id
const json = await result.json();
@ -215,8 +217,8 @@ async function uploadCheckpoint(
.update(checkpointTable)
.set({
...data,
// build_machine_instance_id: json.build_machine_instance_id,
upload_machine_id: json.build_machine_instance_id,
})
.where(eq(machinesTable.id, b.id));
.where(eq(checkpointTable.id, b.id));
}
}

View File

@ -1,8 +1,102 @@
import { z } from 'zod';
// from chatgpt https://chat.openai.com/share/4985d20b-30b1-4a28-87f6-6ebf84a1040e
export const creatorSchema = z.object({
username: z.string(),
image: z.string(),
username: z.string().optional(),
image: z.string().url().optional(),
});
export const fileMetadataSchema = z.object({
fp: z.string().optional(),
size: z.string().optional(),
format: z.string().optional(),
});
export const fileSchema = z.object({
id: z.number(),
sizeKB: z.number().optional(),
name: z.string(),
type: z.string().optional(),
metadata: fileMetadataSchema.optional(),
pickleScanResult: z.string().optional(),
pickleScanMessage: z.string().nullable(),
virusScanResult: z.string().optional(),
virusScanMessage: z.string().nullable(),
scannedAt: z.string().optional(),
hashes: z.record(z.string()).optional(),
downloadUrl: z.string().url(),
primary: z.boolean().optional().optional(),
});
export const imageMetadataSchema = z.object({
hash: z.string(),
width: z.number(),
height: z.number(),
});
export const imageMetaSchema = z.object({
ENSD: z.string().optional(),
Size: z.string().optional(),
seed: z.number().optional(),
Model: z.string().optional(),
steps: z.number().optional(),
hashes: z.record(z.string()).optional(),
prompt: z.string().optional(),
sampler: z.string().optional(),
cfgScale: z.number().optional(),
ClipSkip: z.number().optional(),
resources: z.array(
z.object({
hash: z.string().optional(),
name: z.string(),
type: z.string(),
weight: z.number().optional(),
})
).optional(),
ModelHash: z.string().optional(),
HiresSteps: z.string().optional(),
HiresUpscale: z.string().optional(),
HiresUpscaler: z.string().optional(),
negativePrompt: z.string(),
DenoisingStrength: z.number().optional(),
});
export const imageSchema = z.object({
url: z.string().url().optional(),
nsfw: z.enum(["None", "Soft", "Mature"]).optional(),
width: z.number().optional(),
height: z.number().optional(),
hash: z.string().optional(),
type: z.string().optional(),
metadata: imageMetadataSchema.optional(),
meta: imageMetaSchema.optional(),
});
export const modelVersionSchema = z.object({
id: z.number(),
modelId: z.number(),
name: z.string(),
createdAt: z.string().optional(),
updatedAt: z.string().optional(),
status: z.enum(["Published", "Unpublished"]).optional(),
publishedAt: z.string().optional(),
trainedWords: z.array(z.string()).nullable(),
trainingStatus: z.string().nullable(),
trainingDetails: z.string().nullable(),
baseModel: z.string().optional(),
baseModelType: z.string().optional(),
earlyAccessTimeFrame: z.number().optional(),
description: z.string().nullable(),
vaeId: z.string().nullable(),
stats: z.object({
downloadCount: z.number(),
ratingCount: z.number(),
rating: z.number(),
}).optional(),
files: z.array(fileSchema),
images: z.array(imageSchema),
downloadUrl: z.string().url(),
});
export const statsSchema = z.object({
@ -14,42 +108,19 @@ export const statsSchema = z.object({
tippedAmountCount: z.number(),
});
export const modelVersionSchema = z.object({
export const CivitaiModelResponse = z.object({
id: z.number(),
modelId: z.number(),
name: z.string(),
createdAt: z.string(),
updatedAt: z.string(),
status: z.string(),
publishedAt: z.string(),
trainedWords: z.array(z.any()), // Replace with more specific type if known
trainingStatus: z.any().optional(),
trainingDetails: z.any().optional(),
baseModel: z.string(),
baseModelType: z.string(),
earlyAccessTimeFrame: z.number(),
name: z.string().optional(),
description: z.string().optional(),
vaeId: z.any().optional(), // Replace with more specific type if known
stats: statsSchema.optional(), // If stats structure is known, replace with specific type
files: z.array(z.any()), // Replace with more specific type if known
images: z.array(z.any()), // Replace with more specific type if known
downloadUrl: z.string(),
});
export const CivitaiModel = z.object({
id: z.number(),
name: z.string(),
description: z.string(),
type: z.string(),
poi: z.boolean(),
nsfw: z.boolean(),
allowNoCredit: z.boolean(),
allowCommercialUse: z.string(),
allowDerivatives: z.boolean(),
allowDifferentLicense: z.boolean(),
stats: statsSchema,
creator: creatorSchema,
tags: z.array(z.string()),
type: z.enum(["Checkpoint", "Lora"]),
poi: z.boolean().optional(),
nsfw: z.boolean().optional(),
allowNoCredit: z.boolean().optional(),
allowCommercialUse: z.enum(["Rent"]).optional(),
allowDerivatives: z.boolean().optional(),
allowDifferentLicense: z.boolean().optional(),
stats: statsSchema.optional(),
creator: creatorSchema.optional(),
tags: z.array(z.string()).optional(),
modelVersions: z.array(modelVersionSchema),
});