Compare commits
6 Commits
main
...
nickkao/vo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ed659ebdb8 | ||
|
|
f6a1b88dda | ||
|
|
fed7b380b6 | ||
|
|
90cec6b778 | ||
|
|
7b61fea849 | ||
|
|
60471a8e01 |
@ -8,6 +8,7 @@ from enum import Enum
|
||||
import json
|
||||
import subprocess
|
||||
import time
|
||||
from uuid import uuid4
|
||||
from contextlib import asynccontextmanager
|
||||
import asyncio
|
||||
import threading
|
||||
@ -19,6 +20,7 @@ from urllib.parse import parse_qs
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.types import ASGIApp, Scope, Receive, Send
|
||||
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
# executor = ThreadPoolExecutor(max_workers=5)
|
||||
@ -224,6 +226,52 @@ async def websocket_endpoint(websocket: WebSocket, machine_id: str):
|
||||
# return {"Hello": "World"}
|
||||
|
||||
|
||||
class UploadBody(BaseModel):
|
||||
download_url: str
|
||||
volume_name: str
|
||||
volume_id: str
|
||||
# callback_url: str
|
||||
|
||||
@app.post("/upload_volume")
|
||||
async def upload_checkpoint(body: UploadBody):
|
||||
global last_activity_time
|
||||
last_activity_time = time.time()
|
||||
logger.info(f"Extended inactivity time to {global_timeout}")
|
||||
|
||||
download_url = body.download_url
|
||||
volume_name = body.volume_name
|
||||
# callback_url = body.callback_url
|
||||
|
||||
folder_path = f"/app/builds/{body.volume_id}"
|
||||
|
||||
cp_process = await asyncio.subprocess.create_subprocess_exec("cp", "-r", "/app/src/volume-builder", folder_path)
|
||||
await cp_process.wait()
|
||||
|
||||
# Write the config file
|
||||
config = {
|
||||
"volume_names": {
|
||||
volume_name: download_url
|
||||
},
|
||||
"paths": {
|
||||
volume_name: f'/volumes/{uuid4()}'
|
||||
},
|
||||
}
|
||||
|
||||
await asyncio.subprocess.create_subprocess_shell(
|
||||
f"modal run app.py",
|
||||
# stdout=asyncio.subprocess.PIPE,
|
||||
# stderr=asyncio.subprocess.PIPE,
|
||||
cwd=folder_path,
|
||||
env={**os.environ, "COLUMNS": "10000"}
|
||||
)
|
||||
|
||||
with open(f"{folder_path}/config.py", "w") as f:
|
||||
f.write("config = " + json.dumps(config))
|
||||
|
||||
# check that thi
|
||||
return JSONResponse(status_code=200, content={"message": "Volume uploading", "build_machine_instance_id": fly_instance_id})
|
||||
|
||||
|
||||
@app.post("/create")
|
||||
async def create_machine(item: Item):
|
||||
global last_activity_time
|
||||
@ -312,7 +360,9 @@ async def build_logic(item: Item):
|
||||
config = {
|
||||
"name": item.name,
|
||||
"deploy_test": os.environ.get("DEPLOY_TEST_FLAG", "False"),
|
||||
"gpu": item.gpu
|
||||
"gpu": item.gpu,
|
||||
"public_checkpoint_volume": "model-store",
|
||||
"private_checkpoint_volume": "private-model-store"
|
||||
}
|
||||
with open(f"{folder_path}/config.py", "w") as f:
|
||||
f.write("config = " + json.dumps(config))
|
||||
|
||||
@ -1,12 +1,13 @@
|
||||
from config import config
|
||||
import modal
|
||||
from modal import Image, Mount, web_endpoint, Stub, asgi_app
|
||||
from modal import Image, Mount, web_endpoint, Stub, asgi_app, Volume
|
||||
import json
|
||||
import urllib.request
|
||||
import urllib.parse
|
||||
from pydantic import BaseModel
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import HTMLResponse
|
||||
from volume import volumes
|
||||
|
||||
# deploy_test = False
|
||||
|
||||
@ -28,7 +29,6 @@ web_app = FastAPI()
|
||||
print(config)
|
||||
print("deploy_test ", deploy_test)
|
||||
stub = Stub(name=config["name"])
|
||||
# print(stub.app_id)
|
||||
|
||||
if not deploy_test:
|
||||
# dockerfile_image = Image.from_dockerfile(f"{current_directory}/Dockerfile", context_mount=Mount.from_local_dir(f"{current_directory}/data", remote_path="/data"))
|
||||
@ -56,7 +56,7 @@ if not deploy_test:
|
||||
# # Install comfy deploy
|
||||
# "cd /comfyui/custom_nodes && git clone https://github.com/BennyKok/comfyui-deploy.git",
|
||||
# )
|
||||
# .copy_local_file(f"{current_directory}/data/extra_model_paths.yaml", "/comfyui")
|
||||
.copy_local_file(f"{current_directory}/data/extra_model_paths.yaml", "/comfyui")
|
||||
|
||||
.copy_local_file(f"{current_directory}/data/start.sh", "/start.sh")
|
||||
.run_commands("chmod +x /start.sh")
|
||||
@ -153,8 +153,9 @@ image = Image.debian_slim()
|
||||
|
||||
target_image = image if deploy_test else dockerfile_image
|
||||
|
||||
|
||||
@stub.function(image=target_image, gpu=config["gpu"])
|
||||
@stub.function(image=target_image, gpu=config["gpu"]
|
||||
,volumes=volumes
|
||||
)
|
||||
def run(input: Input):
|
||||
import subprocess
|
||||
import time
|
||||
@ -163,6 +164,7 @@ def run(input: Input):
|
||||
|
||||
command = ["python", "main.py",
|
||||
"--disable-auto-launch", "--disable-metadata"]
|
||||
|
||||
server_process = subprocess.Popen(command, cwd="/comfyui")
|
||||
|
||||
check_server(
|
||||
@ -235,7 +237,9 @@ async def bar(request_input: RequestInput):
|
||||
# pass
|
||||
|
||||
|
||||
@stub.function(image=image)
|
||||
@stub.function(image=image
|
||||
,volumes=volumes
|
||||
)
|
||||
@asgi_app()
|
||||
def comfyui_api():
|
||||
return web_app
|
||||
@ -285,6 +289,7 @@ def spawn_comfyui_in_background():
|
||||
# to be on a single container.
|
||||
concurrency_limit=1,
|
||||
timeout=10 * 60,
|
||||
volumes=volumes,
|
||||
)
|
||||
@asgi_app()
|
||||
def comfyui_app():
|
||||
@ -303,4 +308,4 @@ def comfyui_app():
|
||||
},
|
||||
)()
|
||||
|
||||
return make_simple_proxy_app(ProxyContext(config))
|
||||
return make_simple_proxy_app(ProxyContext(config))
|
||||
|
||||
@ -1 +1,7 @@
|
||||
config = {"name": "my-app", "deploy_test": "True", "gpu": "T4"}
|
||||
config = {
|
||||
"name": "my-app",
|
||||
"deploy_test": "True",
|
||||
"gpu": "T4",
|
||||
"public_checkpoint_volume": "model-store",
|
||||
"private_checkpoint_volume": "private-model-store"
|
||||
}
|
||||
|
||||
@ -1,11 +1,30 @@
|
||||
comfyui:
|
||||
base_path: /runpod-volume/ComfyUI/
|
||||
checkpoints: models/checkpoints/
|
||||
clip: models/clip/
|
||||
clip_vision: models/clip_vision/
|
||||
configs: models/configs/
|
||||
controlnet: models/controlnet/
|
||||
embeddings: models/embeddings/
|
||||
loras: models/loras/
|
||||
upscale_models: models/upscale_models/
|
||||
vae: models/vae/
|
||||
base_path: /extra_models/
|
||||
checkpoints: |
|
||||
checkpoints
|
||||
private_checkpoints
|
||||
clip: |
|
||||
clip
|
||||
private_clip
|
||||
clip_vision: |
|
||||
clip_vision
|
||||
private_clip_vision
|
||||
configs: |
|
||||
configs
|
||||
private_configs
|
||||
controlnet: |
|
||||
controlnet
|
||||
private_controlnet
|
||||
embeddings: |
|
||||
embeddings
|
||||
private_embeddings
|
||||
loras: |
|
||||
loras
|
||||
private_loras
|
||||
upscale_models: |
|
||||
upscale_models
|
||||
private_upscale_models
|
||||
vae: |
|
||||
vae
|
||||
private_vae
|
||||
|
||||
|
||||
101
builder/modal-builder/src/template/data/insert_models.py
Normal file
101
builder/modal-builder/src/template/data/insert_models.py
Normal file
@ -0,0 +1,101 @@
|
||||
"""
|
||||
This is a standalone script to download models into a modal Volume using civitai
|
||||
|
||||
Example Usage
|
||||
`modal run insert_models::insert_model --civitai-url https://civitai.com/models/36520/ghostmix`
|
||||
This inserts an individual model from a civitai url
|
||||
|
||||
`modal run insert_models::insert_models_civitai_api`
|
||||
This inserts a bunch of models based on the models retrieved by civitai
|
||||
|
||||
civitai's API reference https://github.com/civitai/civitai/wiki/REST-API-Reference
|
||||
"""
|
||||
import modal
|
||||
import subprocess
|
||||
import requests
|
||||
import json
|
||||
|
||||
stub = modal.Stub()
|
||||
|
||||
# NOTE: volume name can be variable
|
||||
volume = modal.Volume.persisted("rah")
|
||||
model_store_path = "/vol/models"
|
||||
MODEL_ROUTE = "models"
|
||||
image = (
|
||||
modal.Image.debian_slim().apt_install("wget").pip_install("requests")
|
||||
)
|
||||
|
||||
@stub.function(volumes={model_store_path: volume}, image=image, timeout=50000, gpu=None)
|
||||
def download_model(download_url):
|
||||
print(download_url)
|
||||
subprocess.run(["wget", download_url, "--content-disposition", "-P", model_store_path])
|
||||
subprocess.run(["ls", "-la", model_store_path])
|
||||
volume.commit()
|
||||
|
||||
# file is raw output from Civitai API https://github.com/civitai/civitai/wiki/REST-API-Reference
|
||||
|
||||
@stub.function()
|
||||
def get_civitai_models(model_type: str, sort: str = "Highest Rated", page: int = 1):
|
||||
"""Fetch models from CivitAI API based on type."""
|
||||
try:
|
||||
response = requests.get(f"https://civitai.com/api/v1/models", params={"types": model_type, "page": page, "sort": sort})
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except requests.RequestException as e:
|
||||
print(f"Error fetching models: {e}")
|
||||
return None
|
||||
|
||||
|
||||
@stub.function()
|
||||
def get_civitai_model_url(civitai_url: str):
|
||||
# Validate the URL
|
||||
|
||||
if civitai_url.startswith("https://civitai.com/api/"):
|
||||
api_url = civitai_url
|
||||
elif civitai_url.startswith("https://civitai.com/models/"):
|
||||
try:
|
||||
model_id = civitai_url.split("/")[4]
|
||||
int(model_id)
|
||||
except (IndexError, ValueError):
|
||||
return None
|
||||
api_url = f"https://civitai.com/api/v1/models/{model_id}"
|
||||
else:
|
||||
return "Error: URL must be from civitai.com and contain /models/"
|
||||
|
||||
response = requests.get(api_url)
|
||||
# Check for successful response
|
||||
if response.status_code != 200:
|
||||
return f"Error: Unable to fetch data from {api_url}"
|
||||
# Return the response data
|
||||
return response.json()
|
||||
|
||||
|
||||
|
||||
@stub.local_entrypoint()
|
||||
def insert_models_civitai_api(type: str = "Checkpoint", sort = "Highest Rated", page: int = 1):
|
||||
civitai_models = get_civitai_models.local(type, sort, page)
|
||||
if civitai_models:
|
||||
for _ in download_model.map(map(lambda model: model['modelVersions'][0]['downloadUrl'], civitai_models['items'])):
|
||||
pass
|
||||
else:
|
||||
print("Failed to retrieve models.")
|
||||
|
||||
@stub.local_entrypoint()
|
||||
def insert_model(civitai_url: str):
|
||||
if civitai_url.startswith("'https://civitai.com/api/download/models/"):
|
||||
download_url = civitai_url
|
||||
else:
|
||||
civitai_model = get_civitai_model_url.local(civitai_url)
|
||||
if civitai_model:
|
||||
download_url = civitai_model['modelVersions'][0]['downloadUrl']
|
||||
else:
|
||||
return "invalid URL"
|
||||
|
||||
download_model.remote(download_url)
|
||||
|
||||
@stub.local_entrypoint()
|
||||
def simple_download():
|
||||
download_urls = ['https://civitai.com/api/download/models/119057', 'https://civitai.com/api/download/models/130090', 'https://civitai.com/api/download/models/31859', 'https://civitai.com/api/download/models/128713', 'https://civitai.com/api/download/models/179657', 'https://civitai.com/api/download/models/143906', 'https://civitai.com/api/download/models/9208', 'https://civitai.com/api/download/models/136078', 'https://civitai.com/api/download/models/134065', 'https://civitai.com/api/download/models/288775', 'https://civitai.com/api/download/models/95263', 'https://civitai.com/api/download/models/288982', 'https://civitai.com/api/download/models/87153', 'https://civitai.com/api/download/models/10638', 'https://civitai.com/api/download/models/263809', 'https://civitai.com/api/download/models/130072', 'https://civitai.com/api/download/models/117019', 'https://civitai.com/api/download/models/95256', 'https://civitai.com/api/download/models/197181', 'https://civitai.com/api/download/models/256915', 'https://civitai.com/api/download/models/118945', 'https://civitai.com/api/download/models/125843', 'https://civitai.com/api/download/models/179015', 'https://civitai.com/api/download/models/245598', 'https://civitai.com/api/download/models/223670', 'https://civitai.com/api/download/models/90072', 'https://civitai.com/api/download/models/290817', 'https://civitai.com/api/download/models/154097', 'https://civitai.com/api/download/models/143497', 'https://civitai.com/api/download/models/5637']
|
||||
|
||||
for _ in download_model.map(download_urls):
|
||||
pass
|
||||
@ -45,13 +45,13 @@ for package in packages:
|
||||
response = requests.request("POST", f"{root_url}/customnode/install", json=package, headers=headers)
|
||||
print(response.text)
|
||||
|
||||
with open('models.json') as f:
|
||||
models = json.load(f)
|
||||
|
||||
for model in models:
|
||||
response = requests.request("POST", f"{root_url}/model/install", json=model, headers=headers)
|
||||
print(response.text)
|
||||
# with open('models.json') as f:
|
||||
# models = json.load(f)
|
||||
#
|
||||
# for model in models:
|
||||
# response = requests.request("POST", f"{root_url}/model/install", json=model, headers=headers)
|
||||
# print(response.text)
|
||||
|
||||
# Close the server
|
||||
server_process.terminate()
|
||||
print("Finished installing dependencies.")
|
||||
print("Finished installing dependencies.")
|
||||
|
||||
10
builder/modal-builder/src/template/volume.py
Normal file
10
builder/modal-builder/src/template/volume.py
Normal file
@ -0,0 +1,10 @@
|
||||
import modal
|
||||
from config import config
|
||||
|
||||
public_model_volume = modal.Volume.persisted(config["public_checkpoint_volume"])
|
||||
private_volume = modal.Volume.persisted(config["private_checkpoint_volume"])
|
||||
|
||||
BASEMODEL_DIR = "/extra_models/"
|
||||
MODEL_DIR = BASEMODEL_DIR + "checkpoints"
|
||||
PRIVATE_MODEL_DIR = BASEMODEL_DIR + "private_checkpoints"
|
||||
volumes = {MODEL_DIR: public_model_volume, PRIVATE_MODEL_DIR: private_volume}
|
||||
45
builder/modal-builder/src/volume-builder/app.py
Normal file
45
builder/modal-builder/src/volume-builder/app.py
Normal file
@ -0,0 +1,45 @@
|
||||
import modal
|
||||
from config import config
|
||||
import os
|
||||
import subprocess
|
||||
|
||||
stub = modal.Stub()
|
||||
|
||||
# Volume names may only contain alphanumeric characters, dashes, periods, and underscores, and must be less than 64 characters in length.
|
||||
def is_valid_name(name: str) -> bool:
|
||||
allowed_characters = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-._")
|
||||
return 0 < len(name) <= 64 and all(char in allowed_characters for char in name)
|
||||
|
||||
def create_volumes(volume_names, paths):
|
||||
path_to_vol = {}
|
||||
for volume_name in volume_names.keys():
|
||||
if not is_valid_name(volume_name):
|
||||
pass
|
||||
modal_volume = modal.Volume.persisted(volume_name)
|
||||
path_to_vol[paths[volume_name]] = modal_volume
|
||||
|
||||
return path_to_vol
|
||||
|
||||
vol_name_to_links = config["volume_names"]
|
||||
vol_name_to_path = config["paths"]
|
||||
volumes = create_volumes(vol_name_to_links, vol_name_to_path)
|
||||
image = (
|
||||
modal.Image.debian_slim().apt_install("wget").pip_install("requests")
|
||||
)
|
||||
|
||||
print(vol_name_to_links)
|
||||
print(vol_name_to_path)
|
||||
print(volumes)
|
||||
|
||||
@stub.function(volumes=volumes, image=image, timeout=5000, gpu=None)
|
||||
def download_model(volume_name, download_url):
|
||||
model_store_path = vol_name_to_path[volume_name]
|
||||
subprocess.run(["wget", download_url, "--content-disposition", "-P", model_store_path])
|
||||
subprocess.run(["ls", "-la", model_store_path])
|
||||
volumes[model_store_path].commit()
|
||||
|
||||
@stub.local_entrypoint()
|
||||
def simple_download():
|
||||
print(vol_name_to_links)
|
||||
print([(vol_name, link) for vol_name,link in vol_name_to_links.items()])
|
||||
list(download_model.starmap([(vol_name, link) for vol_name,link in vol_name_to_links.items()]))
|
||||
8
builder/modal-builder/src/volume-builder/config.py
Normal file
8
builder/modal-builder/src/volume-builder/config.py
Normal file
@ -0,0 +1,8 @@
|
||||
config = {
|
||||
"volume_names": {
|
||||
"test": "https://pub-6230db03dc3a4861a9c3e55145ceda44.r2.dev/openpose-pose (1).png"
|
||||
},
|
||||
"paths": {
|
||||
"test": "/volumes/something"
|
||||
}
|
||||
}
|
||||
62
web/drizzle/0031_safe_multiple_man.sql
Normal file
62
web/drizzle/0031_safe_multiple_man.sql
Normal file
@ -0,0 +1,62 @@
|
||||
DO $$ BEGIN
|
||||
CREATE TYPE "model_upload_type" AS ENUM('civitai', 'huggingface', 'other');
|
||||
EXCEPTION
|
||||
WHEN duplicate_object THEN null;
|
||||
END $$;
|
||||
--> statement-breakpoint
|
||||
DO $$ BEGIN
|
||||
CREATE TYPE "resource_upload" AS ENUM('started', 'failed', 'succeded');
|
||||
EXCEPTION
|
||||
WHEN duplicate_object THEN null;
|
||||
END $$;
|
||||
--> statement-breakpoint
|
||||
CREATE TABLE IF NOT EXISTS "comfyui_deploy"."checkpoints" (
|
||||
"id" uuid PRIMARY KEY DEFAULT gen_random_uuid() NOT NULL,
|
||||
"user_id" text,
|
||||
"org_id" text,
|
||||
"description" text,
|
||||
"checkpoint_volume_id" uuid NOT NULL,
|
||||
"model_name" text,
|
||||
"civitai_id" text,
|
||||
"civitai_version_id" text,
|
||||
"civitai_url" text,
|
||||
"civitai_download_url" text,
|
||||
"civitai_model_response" jsonb,
|
||||
"hf_url" text,
|
||||
"s3_url" text,
|
||||
"client_url" text,
|
||||
"is_public" boolean DEFAULT false NOT NULL,
|
||||
"status" "resource_upload" DEFAULT 'started' NOT NULL,
|
||||
"upload_machine_id" text,
|
||||
"upload_type" "model_upload_type" NOT NULL,
|
||||
"created_at" timestamp DEFAULT now() NOT NULL,
|
||||
"updated_at" timestamp DEFAULT now() NOT NULL
|
||||
);
|
||||
--> statement-breakpoint
|
||||
CREATE TABLE IF NOT EXISTS "comfyui_deploy"."checkpoint_volume" (
|
||||
"id" uuid PRIMARY KEY DEFAULT gen_random_uuid() NOT NULL,
|
||||
"user_id" text,
|
||||
"org_id" text,
|
||||
"volume_name" text NOT NULL,
|
||||
"created_at" timestamp DEFAULT now() NOT NULL,
|
||||
"updated_at" timestamp DEFAULT now() NOT NULL,
|
||||
"disabled" boolean DEFAULT false NOT NULL
|
||||
);
|
||||
--> statement-breakpoint
|
||||
DO $$ BEGIN
|
||||
ALTER TABLE "comfyui_deploy"."checkpoints" ADD CONSTRAINT "checkpoints_user_id_users_id_fk" FOREIGN KEY ("user_id") REFERENCES "comfyui_deploy"."users"("id") ON DELETE no action ON UPDATE no action;
|
||||
EXCEPTION
|
||||
WHEN duplicate_object THEN null;
|
||||
END $$;
|
||||
--> statement-breakpoint
|
||||
DO $$ BEGIN
|
||||
ALTER TABLE "comfyui_deploy"."checkpoints" ADD CONSTRAINT "checkpoints_checkpoint_volume_id_workflow_runs_id_fk" FOREIGN KEY ("checkpoint_volume_id") REFERENCES "comfyui_deploy"."workflow_runs"("id") ON DELETE cascade ON UPDATE no action;
|
||||
EXCEPTION
|
||||
WHEN duplicate_object THEN null;
|
||||
END $$;
|
||||
--> statement-breakpoint
|
||||
DO $$ BEGIN
|
||||
ALTER TABLE "comfyui_deploy"."checkpoint_volume" ADD CONSTRAINT "checkpoint_volume_user_id_users_id_fk" FOREIGN KEY ("user_id") REFERENCES "comfyui_deploy"."users"("id") ON DELETE no action ON UPDATE no action;
|
||||
EXCEPTION
|
||||
WHEN duplicate_object THEN null;
|
||||
END $$;
|
||||
2
web/drizzle/0032_material_wallflower.sql
Normal file
2
web/drizzle/0032_material_wallflower.sql
Normal file
@ -0,0 +1,2 @@
|
||||
ALTER TYPE "resource_upload" ADD VALUE 'error';--> statement-breakpoint
|
||||
ALTER TABLE "comfyui_deploy"."checkpoints" ADD COLUMN "build_log" text;
|
||||
1004
web/drizzle/meta/0031_snapshot.json
Normal file
1004
web/drizzle/meta/0031_snapshot.json
Normal file
File diff suppressed because it is too large
Load Diff
1010
web/drizzle/meta/0032_snapshot.json
Normal file
1010
web/drizzle/meta/0032_snapshot.json
Normal file
File diff suppressed because it is too large
Load Diff
@ -218,6 +218,20 @@
|
||||
"when": 1705716303820,
|
||||
"tag": "0030_kind_doorman",
|
||||
"breakpoints": true
|
||||
},
|
||||
{
|
||||
"idx": 31,
|
||||
"version": "5",
|
||||
"when": 1705975916818,
|
||||
"tag": "0031_safe_multiple_man",
|
||||
"breakpoints": true
|
||||
},
|
||||
{
|
||||
"idx": 32,
|
||||
"version": "5",
|
||||
"when": 1705979098372,
|
||||
"tag": "0032_material_wallflower",
|
||||
"breakpoints": true
|
||||
}
|
||||
]
|
||||
}
|
||||
50
web/src/app/(app)/api/volume-updated/route.ts
Normal file
50
web/src/app/(app)/api/volume-updated/route.ts
Normal file
@ -0,0 +1,50 @@
|
||||
import { parseDataSafe } from "../../../../lib/parseDataSafe";
|
||||
import { db } from "@/db/db";
|
||||
import { checkpointTable, machinesTable } from "@/db/schema";
|
||||
import { eq } from "drizzle-orm";
|
||||
import { NextResponse } from "next/server";
|
||||
import { z } from "zod";
|
||||
|
||||
const Request = z.object({
|
||||
machine_id: z.string(),
|
||||
endpoint: z.string().optional(),
|
||||
build_log: z.string().optional(),
|
||||
});
|
||||
|
||||
export async function POST(request: Request) {
|
||||
const [data, error] = await parseDataSafe(Request, request);
|
||||
if (!data || error) return error;
|
||||
|
||||
// console.log(data);
|
||||
|
||||
const { machine_id, endpoint, build_log } = data;
|
||||
|
||||
if (endpoint) {
|
||||
await db
|
||||
.update(checkpointTable)
|
||||
.set({
|
||||
// status: "ready",
|
||||
// endpoint: endpoint,
|
||||
// build_log: build_log,
|
||||
})
|
||||
.where(eq(machinesTable.id, machine_id));
|
||||
} else {
|
||||
// console.log(data);
|
||||
await db
|
||||
.update(machinesTable)
|
||||
.set({
|
||||
// status: "error",
|
||||
// build_log: build_log,
|
||||
})
|
||||
.where(eq(machinesTable.id, machine_id));
|
||||
}
|
||||
|
||||
return NextResponse.json(
|
||||
{
|
||||
message: "success",
|
||||
},
|
||||
{
|
||||
status: 200,
|
||||
}
|
||||
);
|
||||
}
|
||||
9
web/src/app/(app)/storage/loading.tsx
Normal file
9
web/src/app/(app)/storage/loading.tsx
Normal file
@ -0,0 +1,9 @@
|
||||
"use client";
|
||||
|
||||
import { LoadingPageWrapper } from "@/components/LoadingWrapper";
|
||||
import { usePathname } from "next/navigation";
|
||||
|
||||
export default function Loading() {
|
||||
const pathName = usePathname();
|
||||
return <LoadingPageWrapper className="h-full" tag={pathName.toLowerCase()} />;
|
||||
}
|
||||
35
web/src/app/(app)/storage/page.tsx
Normal file
35
web/src/app/(app)/storage/page.tsx
Normal file
@ -0,0 +1,35 @@
|
||||
import { setInitialUserData } from "../../../lib/setInitialUserData";
|
||||
import { auth } from "@clerk/nextjs";
|
||||
import { clerkClient } from "@clerk/nextjs/server";
|
||||
import { CheckpointList } from "@/components/CheckpointList"
|
||||
import { getAllUserCheckpoints } from "@/server/getAllUserCheckpoints";
|
||||
|
||||
export default function Page() {
|
||||
return <CheckpointListServer />;
|
||||
}
|
||||
|
||||
async function CheckpointListServer() {
|
||||
const { userId } = auth();
|
||||
|
||||
if (!userId) {
|
||||
return <div>No auth</div>;
|
||||
}
|
||||
|
||||
const user = await clerkClient.users.getUser(userId);
|
||||
|
||||
if (!user) {
|
||||
await setInitialUserData(userId);
|
||||
}
|
||||
|
||||
const checkpoints = await getAllUserCheckpoints()
|
||||
|
||||
if (!checkpoints) {
|
||||
return <div>No checkpoints found</div>;
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="w-full">
|
||||
<CheckpointList data={checkpoints}/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
315
web/src/components/CheckpointList.tsx
Normal file
315
web/src/components/CheckpointList.tsx
Normal file
@ -0,0 +1,315 @@
|
||||
"use client";
|
||||
|
||||
import { getRelativeTime } from "../lib/getRelativeTime";
|
||||
import { Badge } from "@/components/ui/badge";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Checkbox } from "@/components/ui/checkbox";
|
||||
import { InsertModal, UpdateModal } from "./InsertModal";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { ScrollArea } from "@/components/ui/scroll-area";
|
||||
import {
|
||||
Table,
|
||||
TableBody,
|
||||
TableCell,
|
||||
TableHead,
|
||||
TableHeader,
|
||||
TableRow,
|
||||
} from "@/components/ui/table";
|
||||
import type { getAllUserCheckpoints } from "@/server/getAllUserCheckpoints";
|
||||
import type {
|
||||
ColumnDef,
|
||||
ColumnFiltersState,
|
||||
SortingState,
|
||||
VisibilityState,
|
||||
} from "@tanstack/react-table";
|
||||
import {
|
||||
flexRender,
|
||||
getCoreRowModel,
|
||||
getFilteredRowModel,
|
||||
getPaginationRowModel,
|
||||
getSortedRowModel,
|
||||
useReactTable,
|
||||
} from "@tanstack/react-table";
|
||||
import { ArrowUpDown, MoreHorizontal } from "lucide-react";
|
||||
import * as React from "react";
|
||||
import { insertCivitaiCheckpointSchema } from "@/db/schema";
|
||||
import { addCivitaiCheckpoint } from "@/server/curdCheckpoint";
|
||||
import { addCivitaiCheckpointSchema } from "@/server/addCheckpointSchema";
|
||||
|
||||
export type CheckpointItemList = NonNullable<
|
||||
Awaited<ReturnType<typeof getAllUserCheckpoints>>
|
||||
>[0];
|
||||
|
||||
export const columns: ColumnDef<CheckpointItemList>[] = [
|
||||
{
|
||||
accessorKey: "id",
|
||||
id: "select",
|
||||
header: ({ table }) => (
|
||||
<Checkbox
|
||||
checked={table.getIsAllPageRowsSelected() ||
|
||||
(table.getIsSomePageRowsSelected() && "indeterminate")}
|
||||
onCheckedChange={(value) => table.toggleAllPageRowsSelected(!!value)}
|
||||
aria-label="Select all"
|
||||
/>
|
||||
),
|
||||
cell: ({ row }) => (
|
||||
<Checkbox
|
||||
checked={row.getIsSelected()}
|
||||
onCheckedChange={(value) => row.toggleSelected(!!value)}
|
||||
aria-label="Select row"
|
||||
/>
|
||||
),
|
||||
enableSorting: false,
|
||||
enableHiding: false,
|
||||
},
|
||||
{
|
||||
accessorKey: "name",
|
||||
header: ({ column }) => {
|
||||
return (
|
||||
<button
|
||||
className="flex items-center hover:underline"
|
||||
onClick={() => column.toggleSorting(column.getIsSorted() === "asc")}
|
||||
>
|
||||
Name
|
||||
<ArrowUpDown className="ml-2 h-4 w-4" />
|
||||
</button>
|
||||
);
|
||||
},
|
||||
cell: ({ row }) => {
|
||||
const checkpoint = row.original;
|
||||
return (
|
||||
<a
|
||||
className="hover:underline flex gap-2"
|
||||
href={`/storage/${checkpoint.id}`} // TODO
|
||||
>
|
||||
<span className="truncate max-w-[200px]">{row.original.model_name}</span>
|
||||
|
||||
<Badge variant="default">{}</Badge>
|
||||
{checkpoint.is_public
|
||||
? <Badge variant="success">Public</Badge>
|
||||
: <Badge variant="teal">Private</Badge>}
|
||||
</a>
|
||||
);
|
||||
},
|
||||
},
|
||||
{
|
||||
accessorKey: "creator",
|
||||
header: ({ column }) => {
|
||||
return (
|
||||
<button
|
||||
className="flex items-center hover:underline"
|
||||
onClick={() => column.toggleSorting(column.getIsSorted() === "asc")}
|
||||
>
|
||||
Creator
|
||||
<ArrowUpDown className="ml-2 h-4 w-4" />
|
||||
</button>
|
||||
);
|
||||
},
|
||||
cell: ({ row }) => {
|
||||
// return <Badge variant="cyan">{row?.original?.user?.name ? row.original.user.name : "Public"}</Badge>;
|
||||
},
|
||||
},
|
||||
{
|
||||
accessorKey: "date",
|
||||
sortingFn: "datetime",
|
||||
enableSorting: true,
|
||||
header: ({ column }) => {
|
||||
return (
|
||||
<button
|
||||
className="w-full flex items-center justify-end hover:underline truncate"
|
||||
// variant="ghost"
|
||||
onClick={() => column.toggleSorting(column.getIsSorted() === "asc")}
|
||||
>
|
||||
Update Date
|
||||
<ArrowUpDown className="ml-2 h-4 w-4" />
|
||||
</button>
|
||||
);
|
||||
},
|
||||
cell: ({ row }) => (
|
||||
<div className="w-full capitalize text-right truncate">
|
||||
{getRelativeTime(row.original.updated_at)}
|
||||
</div>
|
||||
),
|
||||
},
|
||||
// {
|
||||
// id: "actions",
|
||||
// enableHiding: false,
|
||||
// cell: ({ row }) => {
|
||||
// const checkpoint = row.original;
|
||||
//
|
||||
// return (
|
||||
// <DropdownMenu>
|
||||
// <DropdownMenuTrigger asChild>
|
||||
// <Button variant="ghost" className="h-8 w-8 p-0">
|
||||
// <span className="sr-only">Open menu</span>
|
||||
// <MoreHorizontal className="h-4 w-4" />
|
||||
// </Button>
|
||||
// </DropdownMenuTrigger>
|
||||
// <DropdownMenuContent align="end">
|
||||
// <DropdownMenuLabel>Actions</DropdownMenuLabel>
|
||||
// <DropdownMenuItem
|
||||
// className="text-destructive"
|
||||
// onClick={() => {
|
||||
// deleteWorkflow(checkpoint.id);
|
||||
// }}
|
||||
// >
|
||||
// Delete Workflow
|
||||
// </DropdownMenuItem>
|
||||
// </DropdownMenuContent>
|
||||
// </DropdownMenu>
|
||||
// );
|
||||
// },
|
||||
// },
|
||||
];
|
||||
|
||||
export function CheckpointList({ data }: { data: CheckpointItemList[] }) {
|
||||
const [sorting, setSorting] = React.useState<SortingState>([]);
|
||||
const [columnFilters, setColumnFilters] = React.useState<ColumnFiltersState>(
|
||||
[],
|
||||
);
|
||||
const [columnVisibility, setColumnVisibility] = React.useState<
|
||||
VisibilityState
|
||||
>({});
|
||||
const [rowSelection, setRowSelection] = React.useState({});
|
||||
|
||||
const table = useReactTable({
|
||||
data,
|
||||
columns,
|
||||
onSortingChange: setSorting,
|
||||
onColumnFiltersChange: setColumnFilters,
|
||||
getCoreRowModel: getCoreRowModel(),
|
||||
getPaginationRowModel: getPaginationRowModel(),
|
||||
getSortedRowModel: getSortedRowModel(),
|
||||
getFilteredRowModel: getFilteredRowModel(),
|
||||
onColumnVisibilityChange: setColumnVisibility,
|
||||
onRowSelectionChange: setRowSelection,
|
||||
state: {
|
||||
sorting,
|
||||
columnFilters,
|
||||
columnVisibility,
|
||||
rowSelection,
|
||||
},
|
||||
});
|
||||
|
||||
return (
|
||||
<div className="grid grid-rows-[auto,1fr,auto] h-full">
|
||||
<div className="flex flex-row w-full items-center py-4">
|
||||
<Input
|
||||
placeholder="Filter workflows..."
|
||||
value={(table.getColumn("name")?.getFilterValue() as string) ?? ""}
|
||||
onChange={(event) =>
|
||||
table.getColumn("name")?.setFilterValue(event.target.value)}
|
||||
className="max-w-sm"
|
||||
/>
|
||||
<div className="ml-auto flex gap-2">
|
||||
<InsertModal
|
||||
dialogClassName="sm:max-w-[600px]"
|
||||
disabled={
|
||||
false
|
||||
// TODO: limitations based on plan
|
||||
}
|
||||
tooltip={"Add models using their civitai url!"}
|
||||
title="Civitai Checkpoint"
|
||||
description="Pick a model from civitai"
|
||||
serverAction={addCivitaiCheckpoint}
|
||||
formSchema={addCivitaiCheckpointSchema}
|
||||
fieldConfig={{
|
||||
civitai_url: {
|
||||
fieldType: "fallback",
|
||||
// fieldType: "fallback",
|
||||
inputProps: { required: true },
|
||||
description: (
|
||||
<>
|
||||
Pick a checkpoint from{" "}
|
||||
<a
|
||||
href="https://www.civitai.com/models"
|
||||
target="_blank"
|
||||
className="underline text-blue-600 hover:text-blue-800 visited:text-purple-600"
|
||||
>
|
||||
civitai.com
|
||||
</a>{" "}
|
||||
and place it's url here
|
||||
</>
|
||||
),
|
||||
},
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<ScrollArea className="h-full w-full rounded-md border">
|
||||
<Table>
|
||||
<TableHeader className="bg-background top-0 sticky">
|
||||
{table.getHeaderGroups().map((headerGroup) => (
|
||||
<TableRow key={headerGroup.id}>
|
||||
{headerGroup.headers.map((header) => {
|
||||
return (
|
||||
<TableHead key={header.id}>
|
||||
{header.isPlaceholder ? null : flexRender(
|
||||
header.column.columnDef.header,
|
||||
header.getContext(),
|
||||
)}
|
||||
</TableHead>
|
||||
);
|
||||
})}
|
||||
</TableRow>
|
||||
))}
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{table.getRowModel().rows?.length
|
||||
? (
|
||||
table.getRowModel().rows.map((row) => (
|
||||
<TableRow
|
||||
key={row.id}
|
||||
data-state={row.getIsSelected() && "selected"}
|
||||
>
|
||||
{row.getVisibleCells().map((cell) => (
|
||||
<TableCell key={cell.id}>
|
||||
{flexRender(
|
||||
cell.column.columnDef.cell,
|
||||
cell.getContext(),
|
||||
)}
|
||||
</TableCell>
|
||||
))}
|
||||
</TableRow>
|
||||
))
|
||||
)
|
||||
: (
|
||||
<TableRow>
|
||||
<TableCell
|
||||
colSpan={columns.length}
|
||||
className="h-24 text-center"
|
||||
>
|
||||
No results.
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
)}
|
||||
</TableBody>
|
||||
</Table>
|
||||
</ScrollArea>
|
||||
<div className="flex flex-row items-center justify-end space-x-2 py-4">
|
||||
<div className="flex-1 text-sm text-muted-foreground">
|
||||
{table.getFilteredSelectedRowModel().rows.length} of{" "}
|
||||
{table.getFilteredRowModel().rows.length} row(s) selected.
|
||||
</div>
|
||||
<div className="space-x-2">
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={() => table.previousPage()}
|
||||
disabled={!table.getCanPreviousPage()}
|
||||
>
|
||||
Previous
|
||||
</Button>
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={() => table.nextPage()}
|
||||
disabled={!table.getCanNextPage()}
|
||||
>
|
||||
Next
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@ -34,6 +34,10 @@ export function NavbarMenu({ className }: { className?: string }) {
|
||||
name: "API Keys",
|
||||
path: "/api-keys",
|
||||
},
|
||||
{
|
||||
name: "Storage",
|
||||
path: "/storage",
|
||||
},
|
||||
];
|
||||
|
||||
return (
|
||||
@ -42,9 +46,9 @@ export function NavbarMenu({ className }: { className?: string }) {
|
||||
{isDesktop && (
|
||||
<Tabs
|
||||
defaultValue={pathname}
|
||||
className="w-[300px] flex pointer-events-auto"
|
||||
className="w-[400px] flex pointer-events-auto"
|
||||
>
|
||||
<TabsList className="grid w-full grid-cols-3">
|
||||
<TabsList className="grid w-full grid-cols-4">
|
||||
{pages.map((page) => (
|
||||
<TabsTrigger
|
||||
key={page.name}
|
||||
|
||||
@ -42,105 +42,105 @@ const Model = z.object({
|
||||
url: z.string(),
|
||||
});
|
||||
|
||||
export const CivitalModelSchema = z.object({
|
||||
items: z.array(
|
||||
export const CivitaiModel = z.object({
|
||||
id: z.number(),
|
||||
name: z.string(),
|
||||
description: z.string(),
|
||||
type: z.string(),
|
||||
// poi: z.boolean(),
|
||||
// nsfw: z.boolean(),
|
||||
// allowNoCredit: z.boolean(),
|
||||
// allowCommercialUse: z.string(),
|
||||
// allowDerivatives: z.boolean(),
|
||||
// allowDifferentLicense: z.boolean(),
|
||||
// stats: z.object({
|
||||
// downloadCount: z.number(),
|
||||
// favoriteCount: z.number(),
|
||||
// commentCount: z.number(),
|
||||
// ratingCount: z.number(),
|
||||
// rating: z.number(),
|
||||
// tippedAmountCount: z.number(),
|
||||
// }),
|
||||
creator: z
|
||||
.object({
|
||||
username: z.string().nullable(),
|
||||
image: z.string().nullable().default(null),
|
||||
})
|
||||
.nullable(),
|
||||
tags: z.array(z.string()),
|
||||
modelVersions: z.array(
|
||||
z.object({
|
||||
id: z.number(),
|
||||
modelId: z.number(),
|
||||
name: z.string(),
|
||||
description: z.string(),
|
||||
type: z.string(),
|
||||
// poi: z.boolean(),
|
||||
// nsfw: z.boolean(),
|
||||
// allowNoCredit: z.boolean(),
|
||||
// allowCommercialUse: z.string(),
|
||||
// allowDerivatives: z.boolean(),
|
||||
// allowDifferentLicense: z.boolean(),
|
||||
// stats: z.object({
|
||||
// downloadCount: z.number(),
|
||||
// favoriteCount: z.number(),
|
||||
// commentCount: z.number(),
|
||||
// ratingCount: z.number(),
|
||||
// rating: z.number(),
|
||||
// tippedAmountCount: z.number(),
|
||||
// }),
|
||||
creator: z
|
||||
.object({
|
||||
username: z.string().nullable(),
|
||||
image: z.string().nullable().default(null),
|
||||
})
|
||||
.nullable(),
|
||||
tags: z.array(z.string()),
|
||||
modelVersions: z.array(
|
||||
createdAt: z.string(),
|
||||
updatedAt: z.string(),
|
||||
status: z.string(),
|
||||
publishedAt: z.string(),
|
||||
trainedWords: z.array(z.unknown()),
|
||||
trainingStatus: z.string().nullable(),
|
||||
trainingDetails: z.string().nullable(),
|
||||
baseModel: z.string(),
|
||||
baseModelType: z.string().nullable(),
|
||||
earlyAccessTimeFrame: z.number(),
|
||||
description: z.string().nullable(),
|
||||
vaeId: z.number().nullable(),
|
||||
stats: z.object({
|
||||
downloadCount: z.number(),
|
||||
ratingCount: z.number(),
|
||||
rating: z.number(),
|
||||
}),
|
||||
files: z.array(
|
||||
z.object({
|
||||
id: z.number(),
|
||||
modelId: z.number(),
|
||||
sizeKB: z.number(),
|
||||
name: z.string(),
|
||||
createdAt: z.string(),
|
||||
updatedAt: z.string(),
|
||||
status: z.string(),
|
||||
publishedAt: z.string(),
|
||||
trainedWords: z.array(z.unknown()),
|
||||
trainingStatus: z.string().nullable(),
|
||||
trainingDetails: z.string().nullable(),
|
||||
baseModel: z.string(),
|
||||
baseModelType: z.string().nullable(),
|
||||
earlyAccessTimeFrame: z.number(),
|
||||
description: z.string().nullable(),
|
||||
vaeId: z.number().nullable(),
|
||||
stats: z.object({
|
||||
downloadCount: z.number(),
|
||||
ratingCount: z.number(),
|
||||
rating: z.number(),
|
||||
}),
|
||||
files: z.array(
|
||||
z.object({
|
||||
id: z.number(),
|
||||
sizeKB: z.number(),
|
||||
name: z.string(),
|
||||
type: z.string(),
|
||||
// metadata: z.object({
|
||||
// fp: z.string().nullable().optional(),
|
||||
// size: z.string().nullable().optional(),
|
||||
// format: z.string().nullable().optional(),
|
||||
// }),
|
||||
// pickleScanResult: z.string(),
|
||||
// pickleScanMessage: z.string(),
|
||||
// virusScanResult: z.string(),
|
||||
// virusScanMessage: z.string().nullable(),
|
||||
// scannedAt: z.string(),
|
||||
// hashes: z.object({
|
||||
// AutoV1: z.string().nullable().optional(),
|
||||
// AutoV2: z.string().nullable().optional(),
|
||||
// SHA256: z.string().nullable().optional(),
|
||||
// CRC32: z.string().nullable().optional(),
|
||||
// BLAKE3: z.string().nullable().optional(),
|
||||
// }),
|
||||
downloadUrl: z.string(),
|
||||
// primary: z.boolean().default(false),
|
||||
})
|
||||
),
|
||||
images: z.array(
|
||||
z.object({
|
||||
id: z.number(),
|
||||
url: z.string(),
|
||||
nsfw: z.string(),
|
||||
width: z.number(),
|
||||
height: z.number(),
|
||||
hash: z.string(),
|
||||
type: z.string(),
|
||||
metadata: z.object({
|
||||
hash: z.string(),
|
||||
width: z.number(),
|
||||
height: z.number(),
|
||||
}),
|
||||
meta: z.any(),
|
||||
})
|
||||
),
|
||||
type: z.string(),
|
||||
// metadata: z.object({
|
||||
// fp: z.string().nullable().optional(),
|
||||
// size: z.string().nullable().optional(),
|
||||
// format: z.string().nullable().optional(),
|
||||
// }),
|
||||
// pickleScanResult: z.string(),
|
||||
// pickleScanMessage: z.string(),
|
||||
// virusScanResult: z.string(),
|
||||
// virusScanMessage: z.string().nullable(),
|
||||
// scannedAt: z.string(),
|
||||
// hashes: z.object({
|
||||
// AutoV1: z.string().nullable().optional(),
|
||||
// AutoV2: z.string().nullable().optional(),
|
||||
// SHA256: z.string().nullable().optional(),
|
||||
// CRC32: z.string().nullable().optional(),
|
||||
// BLAKE3: z.string().nullable().optional(),
|
||||
// }),
|
||||
downloadUrl: z.string(),
|
||||
})
|
||||
// primary: z.boolean().default(false),
|
||||
}),
|
||||
),
|
||||
})
|
||||
images: z.array(
|
||||
z.object({
|
||||
id: z.number(),
|
||||
url: z.string(),
|
||||
nsfw: z.string(),
|
||||
width: z.number(),
|
||||
height: z.number(),
|
||||
hash: z.string(),
|
||||
type: z.string(),
|
||||
metadata: z.object({
|
||||
hash: z.string(),
|
||||
width: z.number(),
|
||||
height: z.number(),
|
||||
}),
|
||||
meta: z.any(),
|
||||
}),
|
||||
),
|
||||
downloadUrl: z.string(),
|
||||
}),
|
||||
),
|
||||
});
|
||||
|
||||
export const CivitalModelSchema = z.object({
|
||||
items: z.array(CivitaiModel),
|
||||
metadata: z.object({
|
||||
totalItems: z.number(),
|
||||
currentPage: z.number(),
|
||||
@ -197,7 +197,7 @@ function mapType(type: string) {
|
||||
}
|
||||
|
||||
function mapModelsList(
|
||||
models: z.infer<typeof CivitalModelSchema>
|
||||
models: z.infer<typeof CivitalModelSchema>,
|
||||
): z.infer<typeof ModelListWrapper> {
|
||||
return {
|
||||
models: models.items.flatMap((item) => {
|
||||
@ -241,8 +241,9 @@ function getUrl(search?: string) {
|
||||
export function CivitaiModelRegistry({
|
||||
field,
|
||||
}: Pick<AutoFormInputComponentProps, "field">) {
|
||||
const [modelList, setModelList] =
|
||||
React.useState<z.infer<typeof ModelListWrapper>>();
|
||||
const [modelList, setModelList] = React.useState<
|
||||
z.infer<typeof ModelListWrapper>
|
||||
>();
|
||||
|
||||
const [loading, setLoading] = React.useState(false);
|
||||
|
||||
@ -301,8 +302,9 @@ export function CivitaiModelRegistry({
|
||||
export function ComfyUIManagerModelRegistry({
|
||||
field,
|
||||
}: Pick<AutoFormInputComponentProps, "field">) {
|
||||
const [modelList, setModelList] =
|
||||
React.useState<z.infer<typeof ModelListWrapper>>();
|
||||
const [modelList, setModelList] = React.useState<
|
||||
z.infer<typeof ModelListWrapper>
|
||||
>();
|
||||
|
||||
React.useEffect(() => {
|
||||
const controller = new AbortController();
|
||||
@ -310,7 +312,7 @@ export function ComfyUIManagerModelRegistry({
|
||||
"https://raw.githubusercontent.com/ltdrdata/ComfyUI-Manager/main/model-list.json",
|
||||
{
|
||||
signal: controller.signal,
|
||||
}
|
||||
},
|
||||
)
|
||||
.then((x) => x.json())
|
||||
.then((a) => {
|
||||
@ -353,14 +355,14 @@ export function ModelSelector({
|
||||
if (
|
||||
prevSelectedModels.some(
|
||||
(selectedModel) =>
|
||||
selectedModel.url + selectedModel.name === model.url + model.name
|
||||
selectedModel.url + selectedModel.name === model.url + model.name,
|
||||
)
|
||||
) {
|
||||
field.onChange(
|
||||
prevSelectedModels.filter(
|
||||
(selectedModel) =>
|
||||
selectedModel.url + selectedModel.name !== model.url + model.name
|
||||
)
|
||||
selectedModel.url + selectedModel.name !== model.url + model.name,
|
||||
),
|
||||
);
|
||||
} else {
|
||||
field.onChange([...prevSelectedModels, model]);
|
||||
@ -408,10 +410,10 @@ export function ModelSelector({
|
||||
className={cn(
|
||||
"ml-auto h-4 w-4",
|
||||
value.some(
|
||||
(selectedModel) => selectedModel.url === model.url
|
||||
)
|
||||
(selectedModel) => selectedModel.url === model.url,
|
||||
)
|
||||
? "opacity-100"
|
||||
: "opacity-0"
|
||||
: "opacity-0",
|
||||
)}
|
||||
/>
|
||||
</CommandItem>
|
||||
|
||||
89
web/src/components/custom-form/checkpoint-input.tsx
Normal file
89
web/src/components/custom-form/checkpoint-input.tsx
Normal file
@ -0,0 +1,89 @@
|
||||
import type { AutoFormInputComponentProps } from "../ui/auto-form/types";
|
||||
import { FormControl, FormItem, FormLabel } from "../ui/form";
|
||||
import { LoadingIcon } from "@/components/LoadingIcon";
|
||||
import * as React from "react";
|
||||
import AutoFormInput from "../ui/auto-form/fields/input";
|
||||
import { useDebouncedCallback } from "use-debounce";
|
||||
import { CivitaiModel } from "./ModelPickerView";
|
||||
import { z } from "zod";
|
||||
import { insertCivitaiCheckpointSchema } from "@/db/schema";
|
||||
|
||||
function getUrl(civitai_url: string) {
|
||||
// expect to be a URL to be https://civitai.com/models/36520
|
||||
// possiblity with slugged name and query-param modelVersionId
|
||||
|
||||
const baseUrl = "https://civitai.com/api/v1/models/";
|
||||
const url = new URL(civitai_url);
|
||||
const pathSegments = url.pathname.split("/");
|
||||
const modelId = pathSegments[pathSegments.indexOf("models") + 1];
|
||||
const modelVersionId = url.searchParams.get("modelVersionId");
|
||||
|
||||
return { url: baseUrl + modelId, modelVersionId };
|
||||
}
|
||||
|
||||
export default function AutoFormCheckpointInput(
|
||||
props: AutoFormInputComponentProps,
|
||||
) {
|
||||
const [loading, setLoading] = React.useState(false);
|
||||
const [modelRes, setModelRes] = React.useState<
|
||||
z.infer<typeof CivitaiModel>
|
||||
>();
|
||||
const [modelVersionid, setModelVersionId] = React.useState<string | null>();
|
||||
const { label, isRequired, fieldProps, zodItem, fieldConfigItem } = props;
|
||||
|
||||
const handleSearch = useDebouncedCallback((search) => {
|
||||
const validationResult = insertCivitaiCheckpointSchema.shape.civitai_url
|
||||
.safeParse(search);
|
||||
if (!validationResult.success) {
|
||||
console.error(validationResult.error);
|
||||
// Optionally set an error state here
|
||||
return;
|
||||
}
|
||||
|
||||
setLoading(true);
|
||||
|
||||
const controller = new AbortController();
|
||||
const { url, modelVersionId: versionId } = getUrl(search);
|
||||
setModelVersionId(versionId);
|
||||
fetch(url, {
|
||||
signal: controller.signal,
|
||||
})
|
||||
.then((x) => x.json())
|
||||
.then((a) => {
|
||||
const res = CivitaiModel.parse(a);
|
||||
console.log(a);
|
||||
console.log(res);
|
||||
setModelRes(res);
|
||||
setLoading(false);
|
||||
});
|
||||
|
||||
return () => {
|
||||
controller.abort();
|
||||
setLoading(false);
|
||||
};
|
||||
}, 300);
|
||||
|
||||
const modifiedField = {
|
||||
...fieldProps,
|
||||
// onChange: (event: React.ChangeEvent<HTMLInputElement>) => {
|
||||
// handleSearch(event.target.value);
|
||||
// },
|
||||
};
|
||||
|
||||
return (
|
||||
<FormItem>
|
||||
{fieldConfigItem.inputProps?.showLabel && (
|
||||
<FormLabel>
|
||||
{label}
|
||||
{isRequired && <span className="text-destructive">*</span>}
|
||||
</FormLabel>
|
||||
)}
|
||||
<FormControl>
|
||||
<AutoFormInput
|
||||
{...props}
|
||||
fieldProps={modifiedField}
|
||||
/>
|
||||
</FormControl>
|
||||
</FormItem>
|
||||
);
|
||||
}
|
||||
@ -8,6 +8,7 @@ import AutoFormSwitch from "./fields/switch";
|
||||
import AutoFormTextarea from "./fields/textarea";
|
||||
import AutoFormModelsPicker from "@/components/custom-form/model-picker";
|
||||
import AutoFormSnapshotPicker from "@/components/custom-form/snapshot-picker";
|
||||
import AutoFormCheckpointInput from "@/components/custom-form/checkpoint-input";
|
||||
|
||||
export const INPUT_COMPONENTS = {
|
||||
checkbox: AutoFormCheckbox,
|
||||
@ -22,6 +23,7 @@ export const INPUT_COMPONENTS = {
|
||||
// Customs
|
||||
snapshot: AutoFormSnapshotPicker,
|
||||
models: AutoFormModelsPicker,
|
||||
checkpoints: AutoFormCheckpointInput,
|
||||
};
|
||||
|
||||
/**
|
||||
|
||||
@ -1,13 +1,14 @@
|
||||
import { relations, type InferSelectModel } from "drizzle-orm";
|
||||
import { CivitaiModelResponse } from "@/types/civitai";
|
||||
import { type InferSelectModel, relations } from "drizzle-orm";
|
||||
import {
|
||||
text,
|
||||
pgSchema,
|
||||
uuid,
|
||||
boolean,
|
||||
integer,
|
||||
timestamp,
|
||||
jsonb,
|
||||
pgEnum,
|
||||
boolean,
|
||||
pgSchema,
|
||||
text,
|
||||
timestamp,
|
||||
uuid,
|
||||
} from "drizzle-orm/pg-core";
|
||||
import { createInsertSchema } from "drizzle-zod";
|
||||
import { z } from "zod";
|
||||
@ -87,7 +88,7 @@ export const workflowVersionRelations = relations(
|
||||
fields: [workflowVersionTable.workflow_id],
|
||||
references: [workflowTable.id],
|
||||
}),
|
||||
})
|
||||
}),
|
||||
);
|
||||
|
||||
export const workflowRunStatus = pgEnum("workflow_run_status", [
|
||||
@ -136,10 +137,11 @@ export const workflowRunsTable = dbSchema.table("workflow_runs", {
|
||||
() => workflowVersionTable.id,
|
||||
{
|
||||
onDelete: "set null",
|
||||
}
|
||||
},
|
||||
),
|
||||
workflow_inputs:
|
||||
jsonb("workflow_inputs").$type<Record<string, string | number>>(),
|
||||
workflow_inputs: jsonb("workflow_inputs").$type<
|
||||
Record<string, string | number>
|
||||
>(),
|
||||
workflow_id: uuid("workflow_id")
|
||||
.notNull()
|
||||
.references(() => workflowTable.id, {
|
||||
@ -171,7 +173,7 @@ export const workflowRunRelations = relations(
|
||||
fields: [workflowRunsTable.workflow_id],
|
||||
references: [workflowTable.id],
|
||||
}),
|
||||
})
|
||||
}),
|
||||
);
|
||||
|
||||
// We still want to keep the workflow run record.
|
||||
@ -195,7 +197,7 @@ export const workflowOutputRelations = relations(
|
||||
fields: [workflowRunOutputs.run_id],
|
||||
references: [workflowRunsTable.id],
|
||||
}),
|
||||
})
|
||||
}),
|
||||
);
|
||||
|
||||
// when user delete, also delete all the workflow versions
|
||||
@ -228,7 +230,7 @@ export const snapshotType = z.object({
|
||||
z.object({
|
||||
hash: z.string(),
|
||||
disabled: z.boolean(),
|
||||
})
|
||||
}),
|
||||
),
|
||||
file_custom_nodes: z.array(z.any()),
|
||||
});
|
||||
@ -243,7 +245,7 @@ export const showcaseMedia = z.array(
|
||||
z.object({
|
||||
url: z.string(),
|
||||
isCover: z.boolean().default(false),
|
||||
})
|
||||
}),
|
||||
);
|
||||
|
||||
export const showcaseMediaNullable = z
|
||||
@ -251,7 +253,7 @@ export const showcaseMediaNullable = z
|
||||
z.object({
|
||||
url: z.string(),
|
||||
isCover: z.boolean().default(false),
|
||||
})
|
||||
}),
|
||||
)
|
||||
.nullable();
|
||||
|
||||
@ -275,8 +277,9 @@ export const deploymentsTable = dbSchema.table("deployments", {
|
||||
.notNull()
|
||||
.references(() => machinesTable.id),
|
||||
description: text("description"),
|
||||
showcase_media:
|
||||
jsonb("showcase_media").$type<z.infer<typeof showcaseMedia>>(),
|
||||
showcase_media: jsonb("showcase_media").$type<
|
||||
z.infer<typeof showcaseMedia>
|
||||
>(),
|
||||
environment: deploymentEnvironment("environment").notNull(),
|
||||
created_at: timestamp("created_at").defaultNow().notNull(),
|
||||
updated_at: timestamp("updated_at").defaultNow().notNull(),
|
||||
@ -329,8 +332,107 @@ export const apiKeyTable = dbSchema.table("api_keys", {
|
||||
updated_at: timestamp("updated_at").defaultNow().notNull(),
|
||||
});
|
||||
|
||||
export const resourceUpload = pgEnum("resource_upload", [
|
||||
"started",
|
||||
"error",
|
||||
"succeded",
|
||||
]);
|
||||
|
||||
export const modelUploadType = pgEnum("model_upload_type", [
|
||||
"civitai",
|
||||
"huggingface",
|
||||
"other",
|
||||
]);
|
||||
|
||||
export const checkpointTable = dbSchema.table("checkpoints", {
|
||||
id: uuid("id").primaryKey().defaultRandom().notNull(),
|
||||
user_id: text("user_id")
|
||||
.references(() => usersTable.id, {}), // perhaps a "special" user_id for global checkpoints
|
||||
org_id: text("org_id"),
|
||||
description: text("description"),
|
||||
|
||||
checkpoint_volume_id: uuid("checkpoint_volume_id")
|
||||
.notNull()
|
||||
.references(() => workflowRunsTable.id, {
|
||||
onDelete: "cascade",
|
||||
}).notNull(),
|
||||
|
||||
model_name: text("model_name"),
|
||||
|
||||
civitai_id: text("civitai_id"),
|
||||
civitai_version_id: text("civitai_version_id"),
|
||||
civitai_url: text("civitai_url"),
|
||||
civitai_download_url: text("civitai_download_url"),
|
||||
civitai_model_response: jsonb("civitai_model_response").$type<
|
||||
z.infer<typeof CivitaiModelResponse>
|
||||
>(),
|
||||
|
||||
hf_url: text("hf_url"),
|
||||
s3_url: text("s3_url"),
|
||||
user_url: text("client_url"),
|
||||
|
||||
is_public: boolean("is_public").notNull().default(false),
|
||||
status: resourceUpload("status").notNull().default("started"),
|
||||
upload_machine_id: text("upload_machine_id"),
|
||||
upload_type: modelUploadType("upload_type").notNull(),
|
||||
build_log: text("build_log"),
|
||||
|
||||
created_at: timestamp("created_at").defaultNow().notNull(),
|
||||
updated_at: timestamp("updated_at").defaultNow().notNull(),
|
||||
});
|
||||
|
||||
export const insertCivitaiCheckpointSchema = createInsertSchema(
|
||||
checkpointTable,
|
||||
{
|
||||
civitai_url: (schema) =>
|
||||
schema.civitai_url.trim().url({ message: "URL required" }).includes(
|
||||
"civitai.com/models",
|
||||
{ message: "civitai.com/models link required" },
|
||||
),
|
||||
},
|
||||
);
|
||||
|
||||
export const checkpointVolumeTable = dbSchema.table("checkpoint_volume", {
|
||||
id: uuid("id").primaryKey().defaultRandom().notNull(),
|
||||
user_id: text("user_id")
|
||||
.references(() => usersTable.id, {
|
||||
// onDelete: "cascade",
|
||||
}),
|
||||
org_id: text("org_id"),
|
||||
volume_name: text("volume_name").notNull(),
|
||||
created_at: timestamp("created_at").defaultNow().notNull(),
|
||||
updated_at: timestamp("updated_at").defaultNow().notNull(),
|
||||
disabled: boolean("disabled").default(false).notNull(),
|
||||
});
|
||||
|
||||
export const checkpointRelations = relations(checkpointTable, ({ one }) => ({
|
||||
user: one(usersTable, {
|
||||
fields: [checkpointTable.user_id],
|
||||
references: [usersTable.id],
|
||||
}),
|
||||
volume: one(checkpointVolumeTable, {
|
||||
fields: [checkpointTable.checkpoint_volume_id],
|
||||
references: [checkpointVolumeTable.id],
|
||||
}),
|
||||
}));
|
||||
|
||||
export const checkpointVolumeRelations = relations(
|
||||
checkpointVolumeTable,
|
||||
({ many, one }) => ({
|
||||
checkpoint: many(checkpointTable),
|
||||
user: one(usersTable, {
|
||||
fields: [checkpointVolumeTable.user_id],
|
||||
references: [usersTable.id],
|
||||
}),
|
||||
}),
|
||||
);
|
||||
|
||||
export type UserType = InferSelectModel<typeof usersTable>;
|
||||
export type WorkflowType = InferSelectModel<typeof workflowTable>;
|
||||
export type MachineType = InferSelectModel<typeof machinesTable>;
|
||||
export type WorkflowVersionType = InferSelectModel<typeof workflowVersionTable>;
|
||||
export type DeploymentType = InferSelectModel<typeof deploymentsTable>;
|
||||
export type CheckpointType = InferSelectModel<typeof checkpointTable>;
|
||||
export type CheckpointVolumeType = InferSelectModel<
|
||||
typeof checkpointVolumeTable
|
||||
>;
|
||||
|
||||
5
web/src/server/addCheckpointSchema.tsx
Normal file
5
web/src/server/addCheckpointSchema.tsx
Normal file
@ -0,0 +1,5 @@
|
||||
import { insertCivitaiCheckpointSchema } from "@/db/schema";
|
||||
|
||||
export const addCivitaiCheckpointSchema = insertCivitaiCheckpointSchema.pick({
|
||||
civitai_url: true,
|
||||
});
|
||||
224
web/src/server/curdCheckpoint.ts
Normal file
224
web/src/server/curdCheckpoint.ts
Normal file
@ -0,0 +1,224 @@
|
||||
"use server";
|
||||
|
||||
import { auth } from "@clerk/nextjs";
|
||||
import {
|
||||
checkpointTable,
|
||||
CheckpointType,
|
||||
volumeTable,
|
||||
CheckpointVolumeType,
|
||||
} from "@/db/schema";
|
||||
import { withServerPromise } from "./withServerPromise";
|
||||
import { redirect } from "next/navigation";
|
||||
import { db } from "@/db/db";
|
||||
import type { z } from "zod";
|
||||
import { headers } from "next/headers";
|
||||
import { addCivitaiCheckpointSchema } from "./addCheckpointSchema";
|
||||
import { and, eq, isNull } from "drizzle-orm";
|
||||
import { CivitaiModelResponse } from "@/types/civitai";
|
||||
|
||||
export async function getCheckpoints() {
|
||||
const { userId, orgId } = auth();
|
||||
if (!userId) throw new Error("No user id");
|
||||
const checkpoints = await db
|
||||
.select()
|
||||
.from(checkpointTable)
|
||||
.where(
|
||||
orgId
|
||||
? eq(checkpointTable.org_id, orgId)
|
||||
// make sure org_id is null
|
||||
: and(
|
||||
eq(checkpointTable.user_id, userId),
|
||||
isNull(checkpointTable.org_id),
|
||||
),
|
||||
);
|
||||
return checkpoints;
|
||||
}
|
||||
|
||||
export async function getCheckpointById(id: string) {
|
||||
const { userId, orgId } = auth();
|
||||
if (!userId) throw new Error("No user id");
|
||||
const checkpoint = await db
|
||||
.select()
|
||||
.from(checkpointTable)
|
||||
.where(
|
||||
and(
|
||||
orgId ? eq(checkpointTable.org_id, orgId) : and(
|
||||
eq(checkpointTable.user_id, userId),
|
||||
isNull(checkpointTable.org_id),
|
||||
),
|
||||
eq(checkpointTable.id, id),
|
||||
),
|
||||
);
|
||||
return checkpoint[0];
|
||||
}
|
||||
|
||||
export async function getCheckpointVolumes() {
|
||||
const { userId, orgId } = auth();
|
||||
if (!userId) throw new Error("No user id");
|
||||
const checkpointVolume = await db
|
||||
.select()
|
||||
.from(volumeTable)
|
||||
.where(
|
||||
and(
|
||||
orgId
|
||||
? eq(volumeTable.org_id, orgId)
|
||||
// make sure org_id is null
|
||||
: and(
|
||||
eq(volumeTable.user_id, userId),
|
||||
isNull(volumeTable.org_id),
|
||||
),
|
||||
eq(volumeTable.disabled, false),
|
||||
),
|
||||
);
|
||||
return checkpointVolume;
|
||||
}
|
||||
|
||||
export async function addCheckpointVolume() {
|
||||
const { userId, orgId } = auth();
|
||||
if (!userId) throw new Error("No user id");
|
||||
|
||||
// Insert the new volume into the checkpointVolumeTable
|
||||
const insertedVolume = await db
|
||||
.insert(volumeTable)
|
||||
.values({
|
||||
user_id: userId,
|
||||
org_id: orgId,
|
||||
volume_name: `checkpoints_${userId}`,
|
||||
// created_at and updated_at will be set to current timestamp by default
|
||||
disabled: false, // Default value
|
||||
})
|
||||
.returning(); // Returns the inserted row
|
||||
|
||||
return insertedVolume;
|
||||
}
|
||||
|
||||
function getUrl(civitai_url: string) {
|
||||
// expect to be a URL to be https://civitai.com/models/36520
|
||||
// possiblity with slugged name and query-param modelVersionId
|
||||
const baseUrl = "https://civitai.com/api/v1/models/";
|
||||
const url = new URL(civitai_url);
|
||||
const pathSegments = url.pathname.split("/");
|
||||
const modelId = pathSegments[pathSegments.indexOf("models") + 1];
|
||||
const modelVersionId = url.searchParams.get("modelVersionId");
|
||||
|
||||
return { url: baseUrl + modelId, modelVersionId };
|
||||
}
|
||||
|
||||
export const addCivitaiCheckpoint = withServerPromise(
|
||||
async (data: z.infer<typeof addCivitaiCheckpointSchema>) => {
|
||||
const { userId, orgId } = auth();
|
||||
|
||||
if (!data.civitai_url) return { error: "no civitai_url" };
|
||||
if (!userId) return { error: "No user id" };
|
||||
|
||||
const { url, modelVersionId } = getUrl(data?.civitai_url);
|
||||
const civitaiModelRes = await fetch(url)
|
||||
.then((x) => x.json())
|
||||
.then((a) => {
|
||||
return CivitaiModelResponse.parse(a);
|
||||
});
|
||||
|
||||
if (civitaiModelRes?.modelVersions?.length === 0) {
|
||||
return; // no versions to download
|
||||
}
|
||||
|
||||
let selectedModelVersion;
|
||||
let selectedModelVersionId: string | null = modelVersionId;
|
||||
if (!selectedModelVersionId) {
|
||||
selectedModelVersion = civitaiModelRes.modelVersions[0];
|
||||
selectedModelVersionId = civitaiModelRes.modelVersions[0].id.toString();
|
||||
} else {
|
||||
selectedModelVersion = civitaiModelRes.modelVersions.find((version) =>
|
||||
version.id.toString() === selectedModelVersionId
|
||||
);
|
||||
if (!selectedModelVersion) {
|
||||
return; // version id is wrong
|
||||
}
|
||||
selectedModelVersionId = selectedModelVersion?.id.toString();
|
||||
}
|
||||
|
||||
const checkpointVolumes = await getCheckpointVolumes();
|
||||
let cVolume;
|
||||
if (checkpointVolumes.length === 0) {
|
||||
const volume = await addCheckpointVolume();
|
||||
cVolume = volume[0];
|
||||
} else {
|
||||
cVolume = checkpointVolumes[0];
|
||||
}
|
||||
|
||||
const a = await db
|
||||
.insert(checkpointTable)
|
||||
.values({
|
||||
user_id: userId,
|
||||
org_id: orgId,
|
||||
upload_type: "civitai",
|
||||
civitai_id: civitaiModelRes.id.toString(),
|
||||
civitai_version_id: selectedModelVersionId,
|
||||
civitai_url: data.civitai_url,
|
||||
civitai_download_url: selectedModelVersion.downloadUrl,
|
||||
civitai_model_response: civitaiModelRes,
|
||||
checkpoint_volume_id: cVolume.id,
|
||||
})
|
||||
.returning();
|
||||
|
||||
const b = a[0];
|
||||
|
||||
await uploadCheckpoint(data, b, cVolume);
|
||||
redirect(`/checkpoints/${b.id}`);
|
||||
},
|
||||
);
|
||||
|
||||
async function uploadCheckpoint(
|
||||
data: z.infer<typeof addCivitaiCheckpointSchema>,
|
||||
b: CheckpointType,
|
||||
v: CheckpointVolumeType,
|
||||
) {
|
||||
const headersList = headers();
|
||||
|
||||
const domain = headersList.get("x-forwarded-host") || "";
|
||||
const protocol = headersList.get("x-forwarded-proto") || "";
|
||||
|
||||
if (domain === "") {
|
||||
throw new Error("No domain");
|
||||
}
|
||||
|
||||
// Call remote builder
|
||||
const result = await fetch(
|
||||
`${process.env.MODAL_BUILDER_URL!}/upload_volume`,
|
||||
{
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
download_url: data.civitai_url,
|
||||
volume_name: v.volume_name,
|
||||
volume_id: v.id,
|
||||
callback_url: `${protocol}://${domain}/api/volume-updated`,
|
||||
}),
|
||||
},
|
||||
);
|
||||
|
||||
if (!result.ok) {
|
||||
const error_log = await result.text();
|
||||
await db
|
||||
.update(checkpointTable)
|
||||
.set({
|
||||
...data,
|
||||
status: "error",
|
||||
build_log: error_log,
|
||||
})
|
||||
.where(eq(checkpointTable.id, b.id));
|
||||
throw new Error(`Error: ${result.statusText} ${error_log}`);
|
||||
} else {
|
||||
// setting the build machine id
|
||||
const json = await result.json();
|
||||
await db
|
||||
.update(checkpointTable)
|
||||
.set({
|
||||
...data,
|
||||
upload_machine_id: json.build_machine_instance_id,
|
||||
})
|
||||
.where(eq(checkpointTable.id, b.id));
|
||||
}
|
||||
}
|
||||
39
web/src/server/getAllUserCheckpoints.tsx
Normal file
39
web/src/server/getAllUserCheckpoints.tsx
Normal file
@ -0,0 +1,39 @@
|
||||
import { db } from "@/db/db";
|
||||
import {
|
||||
checkpointTable,
|
||||
} from "@/db/schema";
|
||||
import { auth } from "@clerk/nextjs";
|
||||
import { and, desc, eq, isNull } from "drizzle-orm";
|
||||
|
||||
export async function getAllUserCheckpoints() {
|
||||
const { userId, orgId } = await auth();
|
||||
|
||||
if (!userId) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const checkpoints = await db.query.checkpointTable.findMany({
|
||||
with: {
|
||||
user: {
|
||||
columns: {
|
||||
name: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
columns: {
|
||||
id: true,
|
||||
updated_at: true,
|
||||
name: true,
|
||||
civitai_url: true,
|
||||
civitai_model_response: true,
|
||||
is_public: true,
|
||||
},
|
||||
orderBy: desc(checkpointTable.updated_at),
|
||||
where:
|
||||
orgId != undefined
|
||||
? eq(checkpointTable.org_id, orgId)
|
||||
: and(eq(checkpointTable.user_id, userId), isNull(checkpointTable.org_id)),
|
||||
});
|
||||
|
||||
return checkpoints;
|
||||
}
|
||||
126
web/src/types/civitai.ts
Normal file
126
web/src/types/civitai.ts
Normal file
@ -0,0 +1,126 @@
|
||||
import { z } from 'zod';
|
||||
|
||||
// from chatgpt https://chat.openai.com/share/4985d20b-30b1-4a28-87f6-6ebf84a1040e
|
||||
|
||||
export const creatorSchema = z.object({
|
||||
username: z.string().optional(),
|
||||
image: z.string().url().optional(),
|
||||
});
|
||||
|
||||
export const fileMetadataSchema = z.object({
|
||||
fp: z.string().optional(),
|
||||
size: z.string().optional(),
|
||||
format: z.string().optional(),
|
||||
});
|
||||
|
||||
export const fileSchema = z.object({
|
||||
id: z.number(),
|
||||
sizeKB: z.number().optional(),
|
||||
name: z.string(),
|
||||
type: z.string().optional(),
|
||||
metadata: fileMetadataSchema.optional(),
|
||||
pickleScanResult: z.string().optional(),
|
||||
pickleScanMessage: z.string().nullable(),
|
||||
virusScanResult: z.string().optional(),
|
||||
virusScanMessage: z.string().nullable(),
|
||||
scannedAt: z.string().optional(),
|
||||
hashes: z.record(z.string()).optional(),
|
||||
downloadUrl: z.string().url(),
|
||||
primary: z.boolean().optional().optional(),
|
||||
});
|
||||
|
||||
export const imageMetadataSchema = z.object({
|
||||
hash: z.string(),
|
||||
width: z.number(),
|
||||
height: z.number(),
|
||||
});
|
||||
|
||||
export const imageMetaSchema = z.object({
|
||||
ENSD: z.string().optional(),
|
||||
Size: z.string().optional(),
|
||||
seed: z.number().optional(),
|
||||
Model: z.string().optional(),
|
||||
steps: z.number().optional(),
|
||||
hashes: z.record(z.string()).optional(),
|
||||
prompt: z.string().optional(),
|
||||
sampler: z.string().optional(),
|
||||
cfgScale: z.number().optional(),
|
||||
ClipSkip: z.number().optional(),
|
||||
resources: z.array(
|
||||
z.object({
|
||||
hash: z.string().optional(),
|
||||
name: z.string(),
|
||||
type: z.string(),
|
||||
weight: z.number().optional(),
|
||||
})
|
||||
).optional(),
|
||||
ModelHash: z.string().optional(),
|
||||
HiresSteps: z.string().optional(),
|
||||
HiresUpscale: z.string().optional(),
|
||||
HiresUpscaler: z.string().optional(),
|
||||
negativePrompt: z.string(),
|
||||
DenoisingStrength: z.number().optional(),
|
||||
});
|
||||
|
||||
export const imageSchema = z.object({
|
||||
url: z.string().url().optional(),
|
||||
nsfw: z.enum(["None", "Soft", "Mature"]).optional(),
|
||||
width: z.number().optional(),
|
||||
height: z.number().optional(),
|
||||
hash: z.string().optional(),
|
||||
type: z.string().optional(),
|
||||
metadata: imageMetadataSchema.optional(),
|
||||
meta: imageMetaSchema.optional(),
|
||||
});
|
||||
|
||||
export const modelVersionSchema = z.object({
|
||||
id: z.number(),
|
||||
modelId: z.number(),
|
||||
name: z.string(),
|
||||
createdAt: z.string().optional(),
|
||||
updatedAt: z.string().optional(),
|
||||
status: z.enum(["Published", "Unpublished"]).optional(),
|
||||
publishedAt: z.string().optional(),
|
||||
trainedWords: z.array(z.string()).nullable(),
|
||||
trainingStatus: z.string().nullable(),
|
||||
trainingDetails: z.string().nullable(),
|
||||
baseModel: z.string().optional(),
|
||||
baseModelType: z.string().optional(),
|
||||
earlyAccessTimeFrame: z.number().optional(),
|
||||
description: z.string().nullable(),
|
||||
vaeId: z.string().nullable(),
|
||||
stats: z.object({
|
||||
downloadCount: z.number(),
|
||||
ratingCount: z.number(),
|
||||
rating: z.number(),
|
||||
}).optional(),
|
||||
files: z.array(fileSchema),
|
||||
images: z.array(imageSchema),
|
||||
downloadUrl: z.string().url(),
|
||||
});
|
||||
|
||||
export const statsSchema = z.object({
|
||||
downloadCount: z.number(),
|
||||
favoriteCount: z.number(),
|
||||
commentCount: z.number(),
|
||||
ratingCount: z.number(),
|
||||
rating: z.number(),
|
||||
tippedAmountCount: z.number(),
|
||||
});
|
||||
|
||||
export const CivitaiModelResponse = z.object({
|
||||
id: z.number(),
|
||||
name: z.string().optional(),
|
||||
description: z.string().optional(),
|
||||
type: z.enum(["Checkpoint", "Lora"]),
|
||||
poi: z.boolean().optional(),
|
||||
nsfw: z.boolean().optional(),
|
||||
allowNoCredit: z.boolean().optional(),
|
||||
allowCommercialUse: z.enum(["Rent"]).optional(),
|
||||
allowDerivatives: z.boolean().optional(),
|
||||
allowDifferentLicense: z.boolean().optional(),
|
||||
stats: statsSchema.optional(),
|
||||
creator: creatorSchema.optional(),
|
||||
tags: z.array(z.string()).optional(),
|
||||
modelVersions: z.array(modelVersionSchema),
|
||||
});
|
||||
Loading…
x
Reference in New Issue
Block a user