Merge branch 'nickkao/checkpoint-volume'
# Conflicts: # web/bun.lockb
This commit is contained in:
commit
6f0499c657
@ -8,6 +8,7 @@ from enum import Enum
|
|||||||
import json
|
import json
|
||||||
import subprocess
|
import subprocess
|
||||||
import time
|
import time
|
||||||
|
from uuid import uuid4
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
import asyncio
|
import asyncio
|
||||||
import threading
|
import threading
|
||||||
@ -19,6 +20,7 @@ from urllib.parse import parse_qs
|
|||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
from starlette.types import ASGIApp, Scope, Receive, Send
|
from starlette.types import ASGIApp, Scope, Receive, Send
|
||||||
|
|
||||||
|
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
|
||||||
# executor = ThreadPoolExecutor(max_workers=5)
|
# executor = ThreadPoolExecutor(max_workers=5)
|
||||||
@ -45,6 +47,7 @@ machine_id_websocket_dict = {}
|
|||||||
machine_id_status = {}
|
machine_id_status = {}
|
||||||
|
|
||||||
fly_instance_id = os.environ.get('FLY_ALLOC_ID', 'local').split('-')[0]
|
fly_instance_id = os.environ.get('FLY_ALLOC_ID', 'local').split('-')[0]
|
||||||
|
civitai_api_key = os.environ.get('FLY_ALLOC_ID', 'local').split('-')[0]
|
||||||
|
|
||||||
|
|
||||||
class FlyReplayMiddleware(BaseHTTPMiddleware):
|
class FlyReplayMiddleware(BaseHTTPMiddleware):
|
||||||
@ -174,6 +177,7 @@ class Item(BaseModel):
|
|||||||
snapshot: Snapshot
|
snapshot: Snapshot
|
||||||
models: List[Model]
|
models: List[Model]
|
||||||
callback_url: str
|
callback_url: str
|
||||||
|
checkpoint_volume_name: str
|
||||||
gpu: GPUType = Field(default=GPUType.T4)
|
gpu: GPUType = Field(default=GPUType.T4)
|
||||||
|
|
||||||
@field_validator('gpu')
|
@field_validator('gpu')
|
||||||
@ -223,6 +227,103 @@ async def websocket_endpoint(websocket: WebSocket, machine_id: str):
|
|||||||
|
|
||||||
# return {"Hello": "World"}
|
# return {"Hello": "World"}
|
||||||
|
|
||||||
|
class UploadType(str, Enum):
|
||||||
|
checkpoint = "checkpoint"
|
||||||
|
|
||||||
|
class UploadBody(BaseModel):
|
||||||
|
download_url: str
|
||||||
|
volume_name: str
|
||||||
|
volume_id: str
|
||||||
|
checkpoint_id: str
|
||||||
|
upload_type: UploadType
|
||||||
|
callback_url: str
|
||||||
|
|
||||||
|
|
||||||
|
UPLOAD_TYPE_DIR_MAP = {
|
||||||
|
UploadType.checkpoint: "checkpoints"
|
||||||
|
}
|
||||||
|
|
||||||
|
@app.post("/upload-volume")
|
||||||
|
async def upload_checkpoint(body: UploadBody):
|
||||||
|
global last_activity_time
|
||||||
|
last_activity_time = time.time()
|
||||||
|
logger.info(f"Extended inactivity time to {global_timeout}")
|
||||||
|
|
||||||
|
asyncio.create_task(upload_logic(body))
|
||||||
|
|
||||||
|
# check that this
|
||||||
|
return JSONResponse(status_code=200, content={"message": "Volume uploading", "build_machine_instance_id": fly_instance_id})
|
||||||
|
|
||||||
|
async def upload_logic(body: UploadBody):
|
||||||
|
folder_path = f"/app/builds/{body.volume_id}"
|
||||||
|
|
||||||
|
cp_process = await asyncio.subprocess.create_subprocess_exec("cp", "-r", "/app/src/volume-builder", folder_path)
|
||||||
|
await cp_process.wait()
|
||||||
|
|
||||||
|
upload_path = UPLOAD_TYPE_DIR_MAP[body.upload_type]
|
||||||
|
config = {
|
||||||
|
"volume_names": {
|
||||||
|
body.volume_name: {"download_url": body.download_url, "folder_path": upload_path}
|
||||||
|
},
|
||||||
|
"volume_paths": {
|
||||||
|
body.volume_name: f'/volumes/{uuid4()}'
|
||||||
|
},
|
||||||
|
"callback_url": body.callback_url,
|
||||||
|
"callback_body": {
|
||||||
|
"checkpoint_id": body.checkpoint_id,
|
||||||
|
"volume_id": body.volume_id,
|
||||||
|
"folder_path": upload_path,
|
||||||
|
},
|
||||||
|
"civitai_api_key": os.environ.get('CIVITAI_API_KEY')
|
||||||
|
}
|
||||||
|
with open(f"{folder_path}/config.py", "w") as f:
|
||||||
|
f.write("config = " + json.dumps(config))
|
||||||
|
|
||||||
|
process = await asyncio.subprocess.create_subprocess_shell(
|
||||||
|
f"modal run app.py",
|
||||||
|
# stdout=asyncio.subprocess.PIPE,
|
||||||
|
# stderr=asyncio.subprocess.PIPE,
|
||||||
|
cwd=folder_path,
|
||||||
|
env={**os.environ, "COLUMNS": "10000"}
|
||||||
|
)
|
||||||
|
|
||||||
|
# error_logs = []
|
||||||
|
# async def read_stream(stream):
|
||||||
|
# while True:
|
||||||
|
# line = await stream.readline()
|
||||||
|
# if line:
|
||||||
|
# l = line.decode('utf-8').strip()
|
||||||
|
# error_logs.append(l)
|
||||||
|
# logger.error(l)
|
||||||
|
# error_logs.append({
|
||||||
|
# "logs": l,
|
||||||
|
# "timestamp": time.time()
|
||||||
|
# })
|
||||||
|
# else:
|
||||||
|
# break
|
||||||
|
|
||||||
|
# stderr_read_task = asyncio.create_task(read_stream(process.stderr))
|
||||||
|
#
|
||||||
|
# await asyncio.wait([stderr_read_task])
|
||||||
|
# await process.wait()
|
||||||
|
|
||||||
|
# if process.returncode != 0:
|
||||||
|
# error_logs.append({"logs": "Unable to upload volume.", "timestamp": time.time()})
|
||||||
|
# # Error handling: send POST request to callback URL with error details
|
||||||
|
# requests.post(body.callback_url, json={
|
||||||
|
# "volume_id": body.volume_id,
|
||||||
|
# "checkpoint_id": body.checkpoint_id,
|
||||||
|
# "folder_path": upload_path,
|
||||||
|
# "error_logs": json.dumps(error_logs),
|
||||||
|
# "status": "failed"
|
||||||
|
# })
|
||||||
|
#
|
||||||
|
# requests.post(body.callback_url, json={
|
||||||
|
# "checkpoint_id": body.checkpoint_id,
|
||||||
|
# "volume_id": body.volume_id,
|
||||||
|
# "folder_path": upload_path,
|
||||||
|
# "status": "success"
|
||||||
|
# })
|
||||||
|
|
||||||
@app.post("/create")
|
@app.post("/create")
|
||||||
async def create_machine(item: Item):
|
async def create_machine(item: Item):
|
||||||
@ -312,7 +413,9 @@ async def build_logic(item: Item):
|
|||||||
config = {
|
config = {
|
||||||
"name": item.name,
|
"name": item.name,
|
||||||
"deploy_test": os.environ.get("DEPLOY_TEST_FLAG", "False"),
|
"deploy_test": os.environ.get("DEPLOY_TEST_FLAG", "False"),
|
||||||
"gpu": item.gpu
|
"gpu": item.gpu,
|
||||||
|
"public_checkpoint_volume": "model-store",
|
||||||
|
"private_checkpoint_volume": item.checkpoint_volume_name
|
||||||
}
|
}
|
||||||
with open(f"{folder_path}/config.py", "w") as f:
|
with open(f"{folder_path}/config.py", "w") as f:
|
||||||
f.write("config = " + json.dumps(config))
|
f.write("config = " + json.dumps(config))
|
||||||
|
@ -1,12 +1,13 @@
|
|||||||
from config import config
|
from config import config
|
||||||
import modal
|
import modal
|
||||||
from modal import Image, Mount, web_endpoint, Stub, asgi_app
|
from modal import Image, Mount, web_endpoint, Stub, asgi_app
|
||||||
import json
|
import json
|
||||||
import urllib.request
|
import urllib.request
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from fastapi import FastAPI, Request
|
from fastapi import FastAPI, Request
|
||||||
from fastapi.responses import HTMLResponse
|
from fastapi.responses import HTMLResponse
|
||||||
|
from volume_setup import volumes
|
||||||
|
|
||||||
# deploy_test = False
|
# deploy_test = False
|
||||||
|
|
||||||
@ -27,8 +28,8 @@ deploy_test = config["deploy_test"] == "True"
|
|||||||
web_app = FastAPI()
|
web_app = FastAPI()
|
||||||
print(config)
|
print(config)
|
||||||
print("deploy_test ", deploy_test)
|
print("deploy_test ", deploy_test)
|
||||||
|
print('volumes', volumes)
|
||||||
stub = Stub(name=config["name"])
|
stub = Stub(name=config["name"])
|
||||||
# print(stub.app_id)
|
|
||||||
|
|
||||||
if not deploy_test:
|
if not deploy_test:
|
||||||
# dockerfile_image = Image.from_dockerfile(f"{current_directory}/Dockerfile", context_mount=Mount.from_local_dir(f"{current_directory}/data", remote_path="/data"))
|
# dockerfile_image = Image.from_dockerfile(f"{current_directory}/Dockerfile", context_mount=Mount.from_local_dir(f"{current_directory}/data", remote_path="/data"))
|
||||||
@ -52,11 +53,13 @@ if not deploy_test:
|
|||||||
"cd /comfyui/custom_nodes/ComfyUI-Manager && pip install -r requirements.txt",
|
"cd /comfyui/custom_nodes/ComfyUI-Manager && pip install -r requirements.txt",
|
||||||
"cd /comfyui/custom_nodes/ComfyUI-Manager && mkdir startup-scripts",
|
"cd /comfyui/custom_nodes/ComfyUI-Manager && mkdir startup-scripts",
|
||||||
)
|
)
|
||||||
|
.run_commands(f"cat /comfyui/server.py")
|
||||||
|
.run_commands(f"ls /comfyui/app")
|
||||||
# .run_commands(
|
# .run_commands(
|
||||||
# # Install comfy deploy
|
# # Install comfy deploy
|
||||||
# "cd /comfyui/custom_nodes && git clone https://github.com/BennyKok/comfyui-deploy.git",
|
# "cd /comfyui/custom_nodes && git clone https://github.com/BennyKok/comfyui-deploy.git",
|
||||||
# )
|
# )
|
||||||
# .copy_local_file(f"{current_directory}/data/extra_model_paths.yaml", "/comfyui")
|
.copy_local_file(f"{current_directory}/data/extra_model_paths.yaml", "/comfyui")
|
||||||
|
|
||||||
.copy_local_file(f"{current_directory}/data/start.sh", "/start.sh")
|
.copy_local_file(f"{current_directory}/data/start.sh", "/start.sh")
|
||||||
.run_commands("chmod +x /start.sh")
|
.run_commands("chmod +x /start.sh")
|
||||||
@ -153,8 +156,9 @@ image = Image.debian_slim()
|
|||||||
|
|
||||||
target_image = image if deploy_test else dockerfile_image
|
target_image = image if deploy_test else dockerfile_image
|
||||||
|
|
||||||
|
@stub.function(image=target_image, gpu=config["gpu"]
|
||||||
@stub.function(image=target_image, gpu=config["gpu"])
|
,volumes=volumes
|
||||||
|
)
|
||||||
def run(input: Input):
|
def run(input: Input):
|
||||||
import subprocess
|
import subprocess
|
||||||
import time
|
import time
|
||||||
@ -163,6 +167,7 @@ def run(input: Input):
|
|||||||
|
|
||||||
command = ["python", "main.py",
|
command = ["python", "main.py",
|
||||||
"--disable-auto-launch", "--disable-metadata"]
|
"--disable-auto-launch", "--disable-metadata"]
|
||||||
|
|
||||||
server_process = subprocess.Popen(command, cwd="/comfyui")
|
server_process = subprocess.Popen(command, cwd="/comfyui")
|
||||||
|
|
||||||
check_server(
|
check_server(
|
||||||
@ -235,7 +240,9 @@ async def bar(request_input: RequestInput):
|
|||||||
# pass
|
# pass
|
||||||
|
|
||||||
|
|
||||||
@stub.function(image=image)
|
@stub.function(image=image
|
||||||
|
,volumes=volumes
|
||||||
|
)
|
||||||
@asgi_app()
|
@asgi_app()
|
||||||
def comfyui_api():
|
def comfyui_api():
|
||||||
return web_app
|
return web_app
|
||||||
@ -285,6 +292,7 @@ def spawn_comfyui_in_background():
|
|||||||
# to be on a single container.
|
# to be on a single container.
|
||||||
concurrency_limit=1,
|
concurrency_limit=1,
|
||||||
timeout=10 * 60,
|
timeout=10 * 60,
|
||||||
|
volumes=volumes,
|
||||||
)
|
)
|
||||||
@asgi_app()
|
@asgi_app()
|
||||||
def comfyui_app():
|
def comfyui_app():
|
||||||
@ -303,4 +311,4 @@ def comfyui_app():
|
|||||||
},
|
},
|
||||||
)()
|
)()
|
||||||
|
|
||||||
return make_simple_proxy_app(ProxyContext(config))
|
return make_simple_proxy_app(ProxyContext(config))
|
||||||
|
@ -1 +1,7 @@
|
|||||||
config = {"name": "my-app", "deploy_test": "True", "gpu": "T4"}
|
config = {
|
||||||
|
"name": "my-app",
|
||||||
|
"deploy_test": "True",
|
||||||
|
"gpu": "T4",
|
||||||
|
"public_checkpoint_volume": "model-store",
|
||||||
|
"private_checkpoint_volume": "private-model-store"
|
||||||
|
}
|
||||||
|
@ -1,11 +1,15 @@
|
|||||||
comfyui:
|
public:
|
||||||
base_path: /runpod-volume/ComfyUI/
|
base_path: /public_models/
|
||||||
checkpoints: models/checkpoints/
|
checkpoints: checkpoints
|
||||||
clip: models/clip/
|
clip: clip
|
||||||
clip_vision: models/clip_vision/
|
clip_vision: clip_vision
|
||||||
configs: models/configs/
|
configs: configs
|
||||||
controlnet: models/controlnet/
|
controlnet: controlnet
|
||||||
embeddings: models/embeddings/
|
embeddings: embeddings
|
||||||
loras: models/loras/
|
loras: loras
|
||||||
upscale_models: models/upscale_models/
|
upscale_models: upscale_models
|
||||||
vae: models/vae/
|
vae: vae
|
||||||
|
|
||||||
|
private:
|
||||||
|
base_path: /private_models/
|
||||||
|
checkpoints: checkpoints
|
||||||
|
@ -54,4 +54,4 @@ for model in models:
|
|||||||
|
|
||||||
# Close the server
|
# Close the server
|
||||||
server_process.terminate()
|
server_process.terminate()
|
||||||
print("Finished installing dependencies.")
|
print("Finished installing dependencies.")
|
||||||
|
9
builder/modal-builder/src/template/volume_setup.py
Normal file
9
builder/modal-builder/src/template/volume_setup.py
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
import modal
|
||||||
|
from config import config
|
||||||
|
|
||||||
|
public_model_volume = modal.Volume.persisted(config["public_checkpoint_volume"])
|
||||||
|
private_volume = modal.Volume.persisted(config["private_checkpoint_volume"])
|
||||||
|
|
||||||
|
PUBLIC_BASEMODEL_DIR = "/public_models"
|
||||||
|
PRIVATE_BASEMODEL_DIR = "/private_models"
|
||||||
|
volumes = {PUBLIC_BASEMODEL_DIR: public_model_volume, PRIVATE_BASEMODEL_DIR: private_volume}
|
74
builder/modal-builder/src/volume-builder/app.py
Normal file
74
builder/modal-builder/src/volume-builder/app.py
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
import modal
|
||||||
|
from config import config
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
from pprint import pprint
|
||||||
|
|
||||||
|
stub = modal.Stub()
|
||||||
|
|
||||||
|
# Volume names may only contain alphanumeric characters, dashes, periods, and underscores, and must be less than 64 characters in length.
|
||||||
|
def is_valid_name(name: str) -> bool:
|
||||||
|
allowed_characters = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-._")
|
||||||
|
return 0 < len(name) <= 64 and all(char in allowed_characters for char in name)
|
||||||
|
|
||||||
|
def create_volumes(volume_names, paths):
|
||||||
|
path_to_vol = {}
|
||||||
|
for volume_name in volume_names.keys():
|
||||||
|
if not is_valid_name(volume_name):
|
||||||
|
pass
|
||||||
|
modal_volume = modal.Volume.persisted(volume_name)
|
||||||
|
path_to_vol[paths[volume_name]] = modal_volume
|
||||||
|
|
||||||
|
return path_to_vol
|
||||||
|
|
||||||
|
vol_name_to_links = config["volume_names"]
|
||||||
|
vol_name_to_path = config["volume_paths"]
|
||||||
|
callback_url = config["callback_url"]
|
||||||
|
callback_body = config["callback_body"]
|
||||||
|
civitai_key = config["civitai_api_key"]
|
||||||
|
|
||||||
|
volumes = create_volumes(vol_name_to_links, vol_name_to_path)
|
||||||
|
image = (
|
||||||
|
modal.Image.debian_slim().apt_install("wget").pip_install("requests")
|
||||||
|
)
|
||||||
|
|
||||||
|
# download config { "download_url": "", "folder_path": ""}
|
||||||
|
timeout=5000
|
||||||
|
@stub.function(volumes=volumes, image=image, timeout=timeout, gpu=None)
|
||||||
|
def download_model(volume_name, download_config):
|
||||||
|
import requests
|
||||||
|
download_url = download_config["download_url"]
|
||||||
|
folder_path = download_config["folder_path"]
|
||||||
|
|
||||||
|
volume_base_path = vol_name_to_path[volume_name]
|
||||||
|
model_store_path = os.path.join(volume_base_path, folder_path)
|
||||||
|
modified_download_url = download_url + ("&" if "?" in download_url else "?") + "token=" + civitai_key
|
||||||
|
print('downloading', modified_download_url)
|
||||||
|
|
||||||
|
subprocess.run(["wget", modified_download_url , "--content-disposition", "-P", model_store_path])
|
||||||
|
subprocess.run(["ls", "-la", volume_base_path])
|
||||||
|
subprocess.run(["ls", "-la", model_store_path])
|
||||||
|
volumes[volume_base_path].commit()
|
||||||
|
|
||||||
|
|
||||||
|
status = {"status": "success"}
|
||||||
|
requests.post(callback_url, json={**status, **callback_body})
|
||||||
|
print(f"finished! sending to {callback_url}")
|
||||||
|
pprint({**status, **callback_body})
|
||||||
|
|
||||||
|
@stub.local_entrypoint()
|
||||||
|
def simple_download():
|
||||||
|
import requests
|
||||||
|
try:
|
||||||
|
list(download_model.starmap([(vol_name, link) for vol_name,link in vol_name_to_links.items()]))
|
||||||
|
except modal.exception.FunctionTimeoutError as e:
|
||||||
|
status = {"status": "failed", "error_logs": f"{str(e)}", "timeout": timeout}
|
||||||
|
requests.post(callback_url, json={**status, **callback_body})
|
||||||
|
print(f"finished! sending to {callback_url}")
|
||||||
|
pprint({**status, **callback_body})
|
||||||
|
except Exception as e:
|
||||||
|
status = {"status": "failed", "error_logs": str(e)}
|
||||||
|
requests.post(callback_url, json={**status, **callback_body})
|
||||||
|
print(f"finished! sending to {callback_url}")
|
||||||
|
pprint({**status, **callback_body})
|
||||||
|
|
18
builder/modal-builder/src/volume-builder/config.py
Normal file
18
builder/modal-builder/src/volume-builder/config.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
config = {
|
||||||
|
"volume_names": {
|
||||||
|
"test": {
|
||||||
|
"download_url": "https://pub-6230db03dc3a4861a9c3e55145ceda44.r2.dev/openpose-pose (1).png",
|
||||||
|
"folder_path": "images"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"volume_paths": {
|
||||||
|
"test": "/volumes/something"
|
||||||
|
},
|
||||||
|
"callback_url": "",
|
||||||
|
"callback_body": {
|
||||||
|
"checkpoint_id": "",
|
||||||
|
"volume_id": "",
|
||||||
|
"folder_path": "images",
|
||||||
|
},
|
||||||
|
"civitai_api_key": "",
|
||||||
|
}
|
64
web/drizzle/0042_windy_madelyne_pryor.sql
Normal file
64
web/drizzle/0042_windy_madelyne_pryor.sql
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
DO $$ BEGIN
|
||||||
|
CREATE TYPE "model_upload_type" AS ENUM('civitai', 'huggingface', 'other');
|
||||||
|
EXCEPTION
|
||||||
|
WHEN duplicate_object THEN null;
|
||||||
|
END $$;
|
||||||
|
--> statement-breakpoint
|
||||||
|
DO $$ BEGIN
|
||||||
|
CREATE TYPE "resource_upload" AS ENUM('started', 'success', 'failed');
|
||||||
|
EXCEPTION
|
||||||
|
WHEN duplicate_object THEN null;
|
||||||
|
END $$;
|
||||||
|
--> statement-breakpoint
|
||||||
|
CREATE TABLE IF NOT EXISTS "comfyui_deploy"."checkpoints" (
|
||||||
|
"id" uuid PRIMARY KEY DEFAULT gen_random_uuid() NOT NULL,
|
||||||
|
"user_id" text,
|
||||||
|
"org_id" text,
|
||||||
|
"description" text,
|
||||||
|
"checkpoint_volume_id" uuid NOT NULL,
|
||||||
|
"model_name" text,
|
||||||
|
"folder_path" text,
|
||||||
|
"civitai_id" text,
|
||||||
|
"civitai_version_id" text,
|
||||||
|
"civitai_url" text,
|
||||||
|
"civitai_download_url" text,
|
||||||
|
"civitai_model_response" jsonb,
|
||||||
|
"hf_url" text,
|
||||||
|
"s3_url" text,
|
||||||
|
"client_url" text,
|
||||||
|
"is_public" boolean DEFAULT false NOT NULL,
|
||||||
|
"status" "resource_upload" DEFAULT 'started' NOT NULL,
|
||||||
|
"upload_machine_id" text,
|
||||||
|
"upload_type" "model_upload_type" NOT NULL,
|
||||||
|
"error_log" text,
|
||||||
|
"created_at" timestamp DEFAULT now() NOT NULL,
|
||||||
|
"updated_at" timestamp DEFAULT now() NOT NULL
|
||||||
|
);
|
||||||
|
--> statement-breakpoint
|
||||||
|
CREATE TABLE IF NOT EXISTS "comfyui_deploy"."checkpoint_volume" (
|
||||||
|
"id" uuid PRIMARY KEY DEFAULT gen_random_uuid() NOT NULL,
|
||||||
|
"user_id" text,
|
||||||
|
"org_id" text,
|
||||||
|
"volume_name" text NOT NULL,
|
||||||
|
"created_at" timestamp DEFAULT now() NOT NULL,
|
||||||
|
"updated_at" timestamp DEFAULT now() NOT NULL,
|
||||||
|
"disabled" boolean DEFAULT false NOT NULL
|
||||||
|
);
|
||||||
|
--> statement-breakpoint
|
||||||
|
DO $$ BEGIN
|
||||||
|
ALTER TABLE "comfyui_deploy"."checkpoints" ADD CONSTRAINT "checkpoints_user_id_users_id_fk" FOREIGN KEY ("user_id") REFERENCES "comfyui_deploy"."users"("id") ON DELETE no action ON UPDATE no action;
|
||||||
|
EXCEPTION
|
||||||
|
WHEN duplicate_object THEN null;
|
||||||
|
END $$;
|
||||||
|
--> statement-breakpoint
|
||||||
|
DO $$ BEGIN
|
||||||
|
ALTER TABLE "comfyui_deploy"."checkpoints" ADD CONSTRAINT "checkpoints_checkpoint_volume_id_checkpoint_volume_id_fk" FOREIGN KEY ("checkpoint_volume_id") REFERENCES "comfyui_deploy"."checkpoint_volume"("id") ON DELETE cascade ON UPDATE no action;
|
||||||
|
EXCEPTION
|
||||||
|
WHEN duplicate_object THEN null;
|
||||||
|
END $$;
|
||||||
|
--> statement-breakpoint
|
||||||
|
DO $$ BEGIN
|
||||||
|
ALTER TABLE "comfyui_deploy"."checkpoint_volume" ADD CONSTRAINT "checkpoint_volume_user_id_users_id_fk" FOREIGN KEY ("user_id") REFERENCES "comfyui_deploy"."users"("id") ON DELETE no action ON UPDATE no action;
|
||||||
|
EXCEPTION
|
||||||
|
WHEN duplicate_object THEN null;
|
||||||
|
END $$;
|
1273
web/drizzle/meta/0042_snapshot.json
Normal file
1273
web/drizzle/meta/0042_snapshot.json
Normal file
File diff suppressed because it is too large
Load Diff
@ -295,6 +295,13 @@
|
|||||||
"when": 1706111421524,
|
"when": 1706111421524,
|
||||||
"tag": "0041_thick_norrin_radd",
|
"tag": "0041_thick_norrin_radd",
|
||||||
"breakpoints": true
|
"breakpoints": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"idx": 42,
|
||||||
|
"version": "5",
|
||||||
|
"when": 1706164614659,
|
||||||
|
"tag": "0042_windy_madelyne_pryor",
|
||||||
|
"breakpoints": true
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
55
web/src/app/(app)/api/volume-upload/route.ts
Normal file
55
web/src/app/(app)/api/volume-upload/route.ts
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
import { parseDataSafe } from "../../../../lib/parseDataSafe";
|
||||||
|
import { db } from "@/db/db";
|
||||||
|
import { checkpointTable, machinesTable } from "@/db/schema";
|
||||||
|
import { eq } from "drizzle-orm";
|
||||||
|
import { NextResponse } from "next/server";
|
||||||
|
import { z } from "zod";
|
||||||
|
|
||||||
|
const Request = z.object({
|
||||||
|
volume_id: z.string(),
|
||||||
|
checkpoint_id: z.string(),
|
||||||
|
folder_path: z.string().optional(),
|
||||||
|
status: z.enum(['success', 'failed']),
|
||||||
|
error_log: z.string().optional(),
|
||||||
|
timeout: z.number().optional(),
|
||||||
|
});
|
||||||
|
|
||||||
|
export async function POST(request: Request) {
|
||||||
|
const [data, error] = await parseDataSafe(Request, request);
|
||||||
|
if (!data || error) return error;
|
||||||
|
|
||||||
|
const { checkpoint_id, error_log, status, folder_path } = data;
|
||||||
|
console.log( checkpoint_id, error_log, status, folder_path )
|
||||||
|
|
||||||
|
if (status === "success") {
|
||||||
|
await db
|
||||||
|
.update(checkpointTable)
|
||||||
|
.set({
|
||||||
|
status: "success",
|
||||||
|
folder_path,
|
||||||
|
updated_at: new Date(),
|
||||||
|
// build_log: build_log,
|
||||||
|
})
|
||||||
|
.where(eq(checkpointTable.id, checkpoint_id));
|
||||||
|
} else {
|
||||||
|
await db
|
||||||
|
.update(checkpointTable)
|
||||||
|
.set({
|
||||||
|
status: "failed",
|
||||||
|
error_log,
|
||||||
|
updated_at: new Date(),
|
||||||
|
// status: "error",
|
||||||
|
// build_log: build_log,
|
||||||
|
})
|
||||||
|
.where(eq(checkpointTable.id, checkpoint_id));
|
||||||
|
}
|
||||||
|
|
||||||
|
return NextResponse.json(
|
||||||
|
{
|
||||||
|
message: "success",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
status: 200,
|
||||||
|
}
|
||||||
|
);
|
||||||
|
}
|
9
web/src/app/(app)/storage/loading.tsx
Normal file
9
web/src/app/(app)/storage/loading.tsx
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
"use client";
|
||||||
|
|
||||||
|
import { LoadingPageWrapper } from "@/components/LoadingWrapper";
|
||||||
|
import { usePathname } from "next/navigation";
|
||||||
|
|
||||||
|
export default function Loading() {
|
||||||
|
const pathName = usePathname();
|
||||||
|
return <LoadingPageWrapper className="h-full" tag={pathName.toLowerCase()} />;
|
||||||
|
}
|
35
web/src/app/(app)/storage/page.tsx
Normal file
35
web/src/app/(app)/storage/page.tsx
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
import { setInitialUserData } from "../../../lib/setInitialUserData";
|
||||||
|
import { auth } from "@clerk/nextjs";
|
||||||
|
import { clerkClient } from "@clerk/nextjs/server";
|
||||||
|
import { CheckpointList } from "@/components/CheckpointList";
|
||||||
|
import { getAllUserCheckpoints } from "@/server/getAllUserCheckpoints";
|
||||||
|
|
||||||
|
export default function Page() {
|
||||||
|
return <CheckpointListServer />;
|
||||||
|
}
|
||||||
|
|
||||||
|
async function CheckpointListServer() {
|
||||||
|
const { userId } = auth();
|
||||||
|
|
||||||
|
if (!userId) {
|
||||||
|
return <div>No auth</div>;
|
||||||
|
}
|
||||||
|
|
||||||
|
const user = await clerkClient.users.getUser(userId);
|
||||||
|
|
||||||
|
if (!user) {
|
||||||
|
await setInitialUserData(userId);
|
||||||
|
}
|
||||||
|
|
||||||
|
const checkpoints = await getAllUserCheckpoints();
|
||||||
|
|
||||||
|
if (!checkpoints) {
|
||||||
|
return <div>No checkpoints found</div>;
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="w-full">
|
||||||
|
<CheckpointList data={checkpoints} />
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
373
web/src/components/CheckpointList.tsx
Normal file
373
web/src/components/CheckpointList.tsx
Normal file
@ -0,0 +1,373 @@
|
|||||||
|
"use client";
|
||||||
|
|
||||||
|
import { getRelativeTime } from "../lib/getRelativeTime";
|
||||||
|
import { Badge } from "@/components/ui/badge";
|
||||||
|
import { Button } from "@/components/ui/button";
|
||||||
|
import { Checkbox } from "@/components/ui/checkbox";
|
||||||
|
import { InsertModal, UpdateModal } from "./InsertModal";
|
||||||
|
import { Input } from "@/components/ui/input";
|
||||||
|
import { ScrollArea } from "@/components/ui/scroll-area";
|
||||||
|
import {
|
||||||
|
Table,
|
||||||
|
TableBody,
|
||||||
|
TableCell,
|
||||||
|
TableHead,
|
||||||
|
TableHeader,
|
||||||
|
TableRow,
|
||||||
|
} from "@/components/ui/table";
|
||||||
|
import type { getAllUserCheckpoints } from "@/server/getAllUserCheckpoints";
|
||||||
|
import type {
|
||||||
|
ColumnDef,
|
||||||
|
ColumnFiltersState,
|
||||||
|
SortingState,
|
||||||
|
VisibilityState,
|
||||||
|
} from "@tanstack/react-table";
|
||||||
|
import {
|
||||||
|
flexRender,
|
||||||
|
getCoreRowModel,
|
||||||
|
getFilteredRowModel,
|
||||||
|
getPaginationRowModel,
|
||||||
|
getSortedRowModel,
|
||||||
|
useReactTable,
|
||||||
|
} from "@tanstack/react-table";
|
||||||
|
import { ArrowUpDown } from "lucide-react";
|
||||||
|
import * as React from "react";
|
||||||
|
import { addCivitaiCheckpoint } from "@/server/curdCheckpoint";
|
||||||
|
import { addCivitaiCheckpointSchema } from "@/server/addCheckpointSchema";
|
||||||
|
|
||||||
|
export type CheckpointItemList = NonNullable<
|
||||||
|
Awaited<ReturnType<typeof getAllUserCheckpoints>>
|
||||||
|
>[0];
|
||||||
|
|
||||||
|
export const columns: ColumnDef<CheckpointItemList>[] = [
|
||||||
|
{
|
||||||
|
accessorKey: "id",
|
||||||
|
id: "select",
|
||||||
|
header: ({ table }) => (
|
||||||
|
<Checkbox
|
||||||
|
checked={
|
||||||
|
table.getIsAllPageRowsSelected() ||
|
||||||
|
(table.getIsSomePageRowsSelected() && "indeterminate")
|
||||||
|
}
|
||||||
|
onCheckedChange={(value) => table.toggleAllPageRowsSelected(!!value)}
|
||||||
|
aria-label="Select all"
|
||||||
|
/>
|
||||||
|
),
|
||||||
|
cell: ({ row }) => (
|
||||||
|
<Checkbox
|
||||||
|
checked={row.getIsSelected()}
|
||||||
|
onCheckedChange={(value) => row.toggleSelected(!!value)}
|
||||||
|
aria-label="Select row"
|
||||||
|
/>
|
||||||
|
),
|
||||||
|
enableSorting: false,
|
||||||
|
enableHiding: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
accessorKey: "model_name",
|
||||||
|
header: ({ column }) => {
|
||||||
|
return (
|
||||||
|
<button
|
||||||
|
className="flex items-center hover:underline"
|
||||||
|
onClick={() => column.toggleSorting(column.getIsSorted() === "asc")}
|
||||||
|
>
|
||||||
|
Model Name
|
||||||
|
<ArrowUpDown className="ml-2 h-4 w-4" />
|
||||||
|
</button>
|
||||||
|
);
|
||||||
|
},
|
||||||
|
cell: ({ row }) => {
|
||||||
|
const checkpoint = row.original;
|
||||||
|
return (
|
||||||
|
<a
|
||||||
|
className="hover:underline flex gap-2"
|
||||||
|
href={`/storage/${checkpoint.id}`} // TODO
|
||||||
|
>
|
||||||
|
<span className="truncate max-w-[200px]">
|
||||||
|
{row.original.model_name}
|
||||||
|
</span>
|
||||||
|
|
||||||
|
{checkpoint.is_public ? (
|
||||||
|
<Badge variant="green">Public</Badge>
|
||||||
|
) : (
|
||||||
|
<Badge variant="orange">Private</Badge>
|
||||||
|
)}
|
||||||
|
</a>
|
||||||
|
);
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
accessorKey: "status",
|
||||||
|
header: ({ column }) => {
|
||||||
|
return (
|
||||||
|
<button
|
||||||
|
className="flex items-center hover:underline"
|
||||||
|
onClick={() => column.toggleSorting(column.getIsSorted() === "asc")}
|
||||||
|
>
|
||||||
|
Status
|
||||||
|
<ArrowUpDown className="ml-2 h-4 w-4" />
|
||||||
|
</button>
|
||||||
|
);
|
||||||
|
},
|
||||||
|
cell: ({ row }) => {
|
||||||
|
return (
|
||||||
|
<Badge variant={row.original.status === "failed" ? "red" : (row.original.status === "started" ? "yellow" : "green")}>
|
||||||
|
{row.original.status}
|
||||||
|
</Badge>
|
||||||
|
);
|
||||||
|
// NOTE: retry downloads on failures
|
||||||
|
// const oneHourAgo = new Date(new Date().getTime() - (60 * 60 * 1000));
|
||||||
|
// const lastUpdated = new Date(row.original.updated_at);
|
||||||
|
// const canRefresh = row.original.status === "failed" && lastUpdated < oneHourAgo;
|
||||||
|
// const canRefresh = row.original.status === "failed" && lastUpdated < oneHourAgo;
|
||||||
|
// cell: ({ row }) => {
|
||||||
|
// // const oneHourAgo = new Date(new Date().getTime() - (60 * 60 * 1000));
|
||||||
|
// // const lastUpdated = new Date(row.original.updated_at);
|
||||||
|
// // const canRefresh = row.original.status === "failed" && lastUpdated < oneHourAgo;
|
||||||
|
// const canReDownload = true;
|
||||||
|
//
|
||||||
|
// return (
|
||||||
|
// <div className="flex items-center space-x-2">
|
||||||
|
// <Badge
|
||||||
|
// variant={row.original.status === "failed"
|
||||||
|
// ? "red"
|
||||||
|
// : row.original.status === "started"
|
||||||
|
// ? "yellow"
|
||||||
|
// : "green"}
|
||||||
|
// >
|
||||||
|
// {row.original.status}
|
||||||
|
// </Badge>
|
||||||
|
// {canReDownload && (
|
||||||
|
// <RefreshCcw
|
||||||
|
// onClick={() => {
|
||||||
|
// redownloadCheckpoint(row.original);
|
||||||
|
// }}
|
||||||
|
// className="h-4 w-4 cursor-pointer" // Adjust the size with h-x and w-x classes
|
||||||
|
// />
|
||||||
|
// )}
|
||||||
|
// </div>
|
||||||
|
// );
|
||||||
|
// },
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
accessorKey: "upload_type",
|
||||||
|
header: ({ column }) => {
|
||||||
|
return (
|
||||||
|
<button
|
||||||
|
className="flex items-center hover:underline"
|
||||||
|
onClick={() => column.toggleSorting(column.getIsSorted() === "asc")}
|
||||||
|
>
|
||||||
|
Source
|
||||||
|
<ArrowUpDown className="ml-2 h-4 w-4" />
|
||||||
|
</button>
|
||||||
|
);
|
||||||
|
},
|
||||||
|
cell: ({ row }) => {
|
||||||
|
return <Badge variant="cyan">{row.original.upload_type}</Badge>;
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
accessorKey: "date",
|
||||||
|
sortingFn: "datetime",
|
||||||
|
enableSorting: true,
|
||||||
|
header: ({ column }) => {
|
||||||
|
return (
|
||||||
|
<button
|
||||||
|
className="w-full flex items-center justify-end hover:underline truncate"
|
||||||
|
// variant="ghost"
|
||||||
|
onClick={() => column.toggleSorting(column.getIsSorted() === "asc")}
|
||||||
|
>
|
||||||
|
Update Date
|
||||||
|
<ArrowUpDown className="ml-2 h-4 w-4" />
|
||||||
|
</button>
|
||||||
|
);
|
||||||
|
},
|
||||||
|
cell: ({ row }) => (
|
||||||
|
<div className="w-full capitalize text-right truncate">
|
||||||
|
{getRelativeTime(row.original.updated_at)}
|
||||||
|
</div>
|
||||||
|
),
|
||||||
|
},
|
||||||
|
// TODO: deletion and editing for future sprint
|
||||||
|
// {
|
||||||
|
// id: "actions",
|
||||||
|
// enableHiding: false,
|
||||||
|
// cell: ({ row }) => {
|
||||||
|
// const checkpoint = row.original;
|
||||||
|
//
|
||||||
|
// return (
|
||||||
|
// <DropdownMenu>
|
||||||
|
// <DropdownMenuTrigger asChild>
|
||||||
|
// <Button variant="ghost" className="h-8 w-8 p-0">
|
||||||
|
// <span className="sr-only">Open menu</span>
|
||||||
|
// <MoreHorizontal className="h-4 w-4" />
|
||||||
|
// </Button>
|
||||||
|
// </DropdownMenuTrigger>
|
||||||
|
// <DropdownMenuContent align="end">
|
||||||
|
// <DropdownMenuLabel>Actions</DropdownMenuLabel>
|
||||||
|
// <DropdownMenuItem
|
||||||
|
// className="text-destructive"
|
||||||
|
// onClick={() => {
|
||||||
|
// deleteWorkflow(checkpoint.id);
|
||||||
|
// }}
|
||||||
|
// >
|
||||||
|
// Delete Workflow
|
||||||
|
// </DropdownMenuItem>
|
||||||
|
// </DropdownMenuContent>
|
||||||
|
// </DropdownMenu>
|
||||||
|
// );
|
||||||
|
// },
|
||||||
|
// },
|
||||||
|
];
|
||||||
|
|
||||||
|
export function CheckpointList({ data }: { data: CheckpointItemList[] }) {
|
||||||
|
const [sorting, setSorting] = React.useState<SortingState>([]);
|
||||||
|
const [columnFilters, setColumnFilters] = React.useState<ColumnFiltersState>(
|
||||||
|
[]
|
||||||
|
);
|
||||||
|
const [columnVisibility, setColumnVisibility] =
|
||||||
|
React.useState<VisibilityState>({});
|
||||||
|
const [rowSelection, setRowSelection] = React.useState({});
|
||||||
|
|
||||||
|
const table = useReactTable({
|
||||||
|
data,
|
||||||
|
columns,
|
||||||
|
onSortingChange: setSorting,
|
||||||
|
onColumnFiltersChange: setColumnFilters,
|
||||||
|
getCoreRowModel: getCoreRowModel(),
|
||||||
|
getPaginationRowModel: getPaginationRowModel(),
|
||||||
|
getSortedRowModel: getSortedRowModel(),
|
||||||
|
getFilteredRowModel: getFilteredRowModel(),
|
||||||
|
onColumnVisibilityChange: setColumnVisibility,
|
||||||
|
onRowSelectionChange: setRowSelection,
|
||||||
|
state: {
|
||||||
|
sorting,
|
||||||
|
columnFilters,
|
||||||
|
columnVisibility,
|
||||||
|
rowSelection,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="grid grid-rows-[auto,1fr,auto] h-full">
|
||||||
|
<div className="flex flex-row w-full items-center py-4">
|
||||||
|
<Input
|
||||||
|
placeholder="Filter workflows..."
|
||||||
|
value={(table.getColumn("name")?.getFilterValue() as string) ?? ""}
|
||||||
|
onChange={(event) =>
|
||||||
|
table.getColumn("name")?.setFilterValue(event.target.value)
|
||||||
|
}
|
||||||
|
className="max-w-sm"
|
||||||
|
/>
|
||||||
|
<div className="ml-auto flex gap-2">
|
||||||
|
<InsertModal
|
||||||
|
dialogClassName="sm:max-w-[600px]"
|
||||||
|
disabled={
|
||||||
|
false
|
||||||
|
// TODO: limitations based on plan
|
||||||
|
}
|
||||||
|
tooltip={"Add models using their civitai url!"}
|
||||||
|
title="Civitai Checkpoint"
|
||||||
|
description="Pick a model from civitai"
|
||||||
|
serverAction={addCivitaiCheckpoint}
|
||||||
|
formSchema={addCivitaiCheckpointSchema}
|
||||||
|
fieldConfig={{
|
||||||
|
civitai_url: {
|
||||||
|
fieldType: "fallback",
|
||||||
|
inputProps: { required: true },
|
||||||
|
description: (
|
||||||
|
<>
|
||||||
|
Pick a checkpoint from{" "}
|
||||||
|
<a
|
||||||
|
href="https://www.civitai.com/models"
|
||||||
|
target="_blank"
|
||||||
|
className="underline text-blue-600 hover:text-blue-800 visited:text-purple-600"
|
||||||
|
>
|
||||||
|
civitai.com
|
||||||
|
</a>{" "}
|
||||||
|
and place it's url here
|
||||||
|
</>
|
||||||
|
),
|
||||||
|
},
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<ScrollArea className="h-full w-full rounded-md border">
|
||||||
|
<Table>
|
||||||
|
<TableHeader className="bg-background top-0 sticky">
|
||||||
|
{table.getHeaderGroups().map((headerGroup) => (
|
||||||
|
<TableRow key={headerGroup.id}>
|
||||||
|
{headerGroup.headers.map((header) => {
|
||||||
|
return (
|
||||||
|
<TableHead key={header.id}>
|
||||||
|
{header.isPlaceholder
|
||||||
|
? null
|
||||||
|
: flexRender(
|
||||||
|
header.column.columnDef.header,
|
||||||
|
header.getContext()
|
||||||
|
)}
|
||||||
|
</TableHead>
|
||||||
|
);
|
||||||
|
})}
|
||||||
|
</TableRow>
|
||||||
|
))}
|
||||||
|
</TableHeader>
|
||||||
|
<TableBody>
|
||||||
|
{table.getRowModel().rows?.length ? (
|
||||||
|
table.getRowModel().rows.map((row) => (
|
||||||
|
<TableRow
|
||||||
|
key={row.id}
|
||||||
|
data-state={row.getIsSelected() && "selected"}
|
||||||
|
>
|
||||||
|
{row.getVisibleCells().map((cell) => (
|
||||||
|
<TableCell key={cell.id}>
|
||||||
|
{flexRender(
|
||||||
|
cell.column.columnDef.cell,
|
||||||
|
cell.getContext()
|
||||||
|
)}
|
||||||
|
</TableCell>
|
||||||
|
))}
|
||||||
|
</TableRow>
|
||||||
|
))
|
||||||
|
) : (
|
||||||
|
<TableRow>
|
||||||
|
<TableCell
|
||||||
|
colSpan={columns.length}
|
||||||
|
className="h-24 text-center"
|
||||||
|
>
|
||||||
|
No results.
|
||||||
|
</TableCell>
|
||||||
|
</TableRow>
|
||||||
|
)}
|
||||||
|
</TableBody>
|
||||||
|
</Table>
|
||||||
|
</ScrollArea>
|
||||||
|
<div className="flex flex-row items-center justify-end space-x-2 py-4">
|
||||||
|
<div className="flex-1 text-sm text-muted-foreground">
|
||||||
|
{table.getFilteredSelectedRowModel().rows.length} of{" "}
|
||||||
|
{table.getFilteredRowModel().rows.length} row(s) selected.
|
||||||
|
</div>
|
||||||
|
<div className="space-x-2">
|
||||||
|
<Button
|
||||||
|
variant="outline"
|
||||||
|
size="sm"
|
||||||
|
onClick={() => table.previousPage()}
|
||||||
|
disabled={!table.getCanPreviousPage()}
|
||||||
|
>
|
||||||
|
Previous
|
||||||
|
</Button>
|
||||||
|
<Button
|
||||||
|
variant="outline"
|
||||||
|
size="sm"
|
||||||
|
onClick={() => table.nextPage()}
|
||||||
|
disabled={!table.getCanNextPage()}
|
||||||
|
>
|
||||||
|
Next
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
@ -34,6 +34,10 @@ export function NavbarMenu({ className }: { className?: string }) {
|
|||||||
name: "API Keys",
|
name: "API Keys",
|
||||||
path: "/api-keys",
|
path: "/api-keys",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "Storage",
|
||||||
|
path: "/storage",
|
||||||
|
},
|
||||||
];
|
];
|
||||||
|
|
||||||
return (
|
return (
|
||||||
@ -42,9 +46,9 @@ export function NavbarMenu({ className }: { className?: string }) {
|
|||||||
{isDesktop && (
|
{isDesktop && (
|
||||||
<Tabs
|
<Tabs
|
||||||
defaultValue={pathname}
|
defaultValue={pathname}
|
||||||
className="w-[300px] flex pointer-events-auto"
|
className="w-[400px] flex pointer-events-auto"
|
||||||
>
|
>
|
||||||
<TabsList className="grid w-full grid-cols-3">
|
<TabsList className="grid w-full grid-cols-4">
|
||||||
{pages.map((page) => (
|
{pages.map((page) => (
|
||||||
<TabsTrigger
|
<TabsTrigger
|
||||||
key={page.name}
|
key={page.name}
|
||||||
|
86
web/src/components/custom-form/checkpoint-input.tsx
Normal file
86
web/src/components/custom-form/checkpoint-input.tsx
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
// NOTE: this is WIP for doing client side validation for civitai model downloading
|
||||||
|
import type { AutoFormInputComponentProps } from "../ui/auto-form/types";
|
||||||
|
import { FormControl, FormItem, FormLabel } from "../ui/form";
|
||||||
|
import { LoadingIcon } from "@/components/LoadingIcon";
|
||||||
|
import * as React from "react";
|
||||||
|
import AutoFormInput from "../ui/auto-form/fields/input";
|
||||||
|
import { useDebouncedCallback } from "use-debounce";
|
||||||
|
import { CivitaiModelResponse } from "@/types/civitai";
|
||||||
|
import { z } from "zod";
|
||||||
|
import { insertCivitaiCheckpointSchema } from "@/db/schema";
|
||||||
|
|
||||||
|
function getUrl(civitai_url: string) {
|
||||||
|
// expect to be a URL to be https://civitai.com/models/36520
|
||||||
|
// possiblity with slugged name and query-param modelVersionId
|
||||||
|
|
||||||
|
const baseUrl = "https://civitai.com/api/v1/models/";
|
||||||
|
const url = new URL(civitai_url);
|
||||||
|
const pathSegments = url.pathname.split("/");
|
||||||
|
const modelId = pathSegments[pathSegments.indexOf("models") + 1];
|
||||||
|
const modelVersionId = url.searchParams.get("modelVersionId");
|
||||||
|
|
||||||
|
return { url: baseUrl + modelId, modelVersionId };
|
||||||
|
}
|
||||||
|
|
||||||
|
export default function AutoFormCheckpointInput(
|
||||||
|
props: AutoFormInputComponentProps
|
||||||
|
) {
|
||||||
|
const [loading, setLoading] = React.useState(false);
|
||||||
|
const [modelRes, setModelRes] =
|
||||||
|
React.useState<z.infer<typeof CivitaiModelResponse>>();
|
||||||
|
const [modelVersionid, setModelVersionId] = React.useState<string | null>();
|
||||||
|
const { label, isRequired, fieldProps, zodItem, fieldConfigItem } = props;
|
||||||
|
|
||||||
|
const handleSearch = useDebouncedCallback((search) => {
|
||||||
|
const validationResult =
|
||||||
|
insertCivitaiCheckpointSchema.shape.civitai_url.safeParse(search);
|
||||||
|
if (!validationResult.success) {
|
||||||
|
console.error(validationResult.error);
|
||||||
|
// Optionally set an error state here
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
setLoading(true);
|
||||||
|
|
||||||
|
const controller = new AbortController();
|
||||||
|
const { url, modelVersionId: versionId } = getUrl(search);
|
||||||
|
setModelVersionId(versionId);
|
||||||
|
fetch(url, {
|
||||||
|
signal: controller.signal,
|
||||||
|
})
|
||||||
|
.then((x) => x.json())
|
||||||
|
.then((a) => {
|
||||||
|
const res = CivitaiModelResponse.parse(a);
|
||||||
|
console.log(a);
|
||||||
|
console.log(res);
|
||||||
|
setModelRes(res);
|
||||||
|
setLoading(false);
|
||||||
|
});
|
||||||
|
|
||||||
|
return () => {
|
||||||
|
controller.abort();
|
||||||
|
setLoading(false);
|
||||||
|
};
|
||||||
|
}, 300);
|
||||||
|
|
||||||
|
const modifiedField = {
|
||||||
|
...fieldProps,
|
||||||
|
// onChange: (event: React.ChangeEvent<HTMLInputElement>) => {
|
||||||
|
// handleSearch(event.target.value);
|
||||||
|
// },
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<FormItem>
|
||||||
|
{fieldConfigItem.inputProps?.showLabel && (
|
||||||
|
<FormLabel>
|
||||||
|
{label}
|
||||||
|
{isRequired && <span className="text-destructive">*</span>}
|
||||||
|
</FormLabel>
|
||||||
|
)}
|
||||||
|
<FormControl>
|
||||||
|
<AutoFormInput {...props} fieldProps={modifiedField} />
|
||||||
|
</FormControl>
|
||||||
|
</FormItem>
|
||||||
|
);
|
||||||
|
}
|
@ -1,3 +1,4 @@
|
|||||||
|
import { CivitaiModelResponse } from "@/types/civitai";
|
||||||
import { type InferSelectModel, relations } from "drizzle-orm";
|
import { type InferSelectModel, relations } from "drizzle-orm";
|
||||||
import {
|
import {
|
||||||
boolean,
|
boolean,
|
||||||
@ -92,7 +93,7 @@ export const workflowVersionRelations = relations(
|
|||||||
fields: [workflowVersionTable.workflow_id],
|
fields: [workflowVersionTable.workflow_id],
|
||||||
references: [workflowTable.id],
|
references: [workflowTable.id],
|
||||||
}),
|
}),
|
||||||
}),
|
})
|
||||||
);
|
);
|
||||||
|
|
||||||
export const workflowRunStatus = pgEnum("workflow_run_status", [
|
export const workflowRunStatus = pgEnum("workflow_run_status", [
|
||||||
@ -141,7 +142,7 @@ export const workflowRunsTable = dbSchema.table("workflow_runs", {
|
|||||||
() => workflowVersionTable.id,
|
() => workflowVersionTable.id,
|
||||||
{
|
{
|
||||||
onDelete: "set null",
|
onDelete: "set null",
|
||||||
},
|
}
|
||||||
),
|
),
|
||||||
workflow_inputs:
|
workflow_inputs:
|
||||||
jsonb("workflow_inputs").$type<Record<string, string | number>>(),
|
jsonb("workflow_inputs").$type<Record<string, string | number>>(),
|
||||||
@ -181,7 +182,7 @@ export const workflowRunRelations = relations(
|
|||||||
fields: [workflowRunsTable.workflow_id],
|
fields: [workflowRunsTable.workflow_id],
|
||||||
references: [workflowTable.id],
|
references: [workflowTable.id],
|
||||||
}),
|
}),
|
||||||
}),
|
})
|
||||||
);
|
);
|
||||||
|
|
||||||
// We still want to keep the workflow run record.
|
// We still want to keep the workflow run record.
|
||||||
@ -205,7 +206,7 @@ export const workflowOutputRelations = relations(
|
|||||||
fields: [workflowRunOutputs.run_id],
|
fields: [workflowRunOutputs.run_id],
|
||||||
references: [workflowRunsTable.id],
|
references: [workflowRunsTable.id],
|
||||||
}),
|
}),
|
||||||
}),
|
})
|
||||||
);
|
);
|
||||||
|
|
||||||
// when user delete, also delete all the workflow versions
|
// when user delete, also delete all the workflow versions
|
||||||
@ -238,7 +239,7 @@ export const snapshotType = z.object({
|
|||||||
z.object({
|
z.object({
|
||||||
hash: z.string(),
|
hash: z.string(),
|
||||||
disabled: z.boolean(),
|
disabled: z.boolean(),
|
||||||
}),
|
})
|
||||||
),
|
),
|
||||||
file_custom_nodes: z.array(z.any()),
|
file_custom_nodes: z.array(z.any()),
|
||||||
});
|
});
|
||||||
@ -253,7 +254,7 @@ export const showcaseMedia = z.array(
|
|||||||
z.object({
|
z.object({
|
||||||
url: z.string(),
|
url: z.string(),
|
||||||
isCover: z.boolean().default(false),
|
isCover: z.boolean().default(false),
|
||||||
}),
|
})
|
||||||
);
|
);
|
||||||
|
|
||||||
export const showcaseMediaNullable = z
|
export const showcaseMediaNullable = z
|
||||||
@ -261,7 +262,7 @@ export const showcaseMediaNullable = z
|
|||||||
z.object({
|
z.object({
|
||||||
url: z.string(),
|
url: z.string(),
|
||||||
isCover: z.boolean().default(false),
|
isCover: z.boolean().default(false),
|
||||||
}),
|
})
|
||||||
)
|
)
|
||||||
.nullable();
|
.nullable();
|
||||||
|
|
||||||
@ -363,6 +364,89 @@ export const authRequestsTable = dbSchema.table("auth_requests", {
|
|||||||
updated_at: timestamp("updated_at").defaultNow().notNull(),
|
updated_at: timestamp("updated_at").defaultNow().notNull(),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
export const resourceUpload = pgEnum("resource_upload", [
|
||||||
|
"started",
|
||||||
|
"success",
|
||||||
|
"failed",
|
||||||
|
]);
|
||||||
|
|
||||||
|
export const modelUploadType = pgEnum("model_upload_type", [
|
||||||
|
"civitai",
|
||||||
|
"huggingface",
|
||||||
|
"other",
|
||||||
|
]);
|
||||||
|
|
||||||
|
export const checkpointTable = dbSchema.table("checkpoints", {
|
||||||
|
id: uuid("id").primaryKey().defaultRandom().notNull(),
|
||||||
|
user_id: text("user_id").references(() => usersTable.id, {}), // perhaps a "special" user_id for global checkpoints
|
||||||
|
org_id: text("org_id"),
|
||||||
|
description: text("description"),
|
||||||
|
|
||||||
|
checkpoint_volume_id: uuid("checkpoint_volume_id")
|
||||||
|
.notNull()
|
||||||
|
.references(() => checkpointVolumeTable.id, {
|
||||||
|
onDelete: "cascade",
|
||||||
|
})
|
||||||
|
.notNull(),
|
||||||
|
|
||||||
|
model_name: text("model_name"),
|
||||||
|
folder_path: text("folder_path"), // in volume
|
||||||
|
|
||||||
|
civitai_id: text("civitai_id"),
|
||||||
|
civitai_version_id: text("civitai_version_id"),
|
||||||
|
civitai_url: text("civitai_url"),
|
||||||
|
civitai_download_url: text("civitai_download_url"),
|
||||||
|
civitai_model_response: jsonb("civitai_model_response").$type<
|
||||||
|
z.infer<typeof CivitaiModelResponse>
|
||||||
|
>(),
|
||||||
|
|
||||||
|
hf_url: text("hf_url"),
|
||||||
|
s3_url: text("s3_url"),
|
||||||
|
user_url: text("client_url"),
|
||||||
|
|
||||||
|
is_public: boolean("is_public").notNull().default(false),
|
||||||
|
status: resourceUpload("status").notNull().default("started"),
|
||||||
|
upload_machine_id: text("upload_machine_id"), // TODO: review if actually needed
|
||||||
|
upload_type: modelUploadType("upload_type").notNull(),
|
||||||
|
error_log: text("error_log"),
|
||||||
|
created_at: timestamp("created_at").defaultNow().notNull(),
|
||||||
|
updated_at: timestamp("updated_at").defaultNow().notNull(),
|
||||||
|
});
|
||||||
|
|
||||||
|
export const checkpointVolumeTable = dbSchema.table("checkpoint_volume", {
|
||||||
|
id: uuid("id").primaryKey().defaultRandom().notNull(),
|
||||||
|
user_id: text("user_id").references(() => usersTable.id, {
|
||||||
|
// onDelete: "cascade",
|
||||||
|
}),
|
||||||
|
org_id: text("org_id"),
|
||||||
|
volume_name: text("volume_name").notNull(),
|
||||||
|
created_at: timestamp("created_at").defaultNow().notNull(),
|
||||||
|
updated_at: timestamp("updated_at").defaultNow().notNull(),
|
||||||
|
disabled: boolean("disabled").default(false).notNull(),
|
||||||
|
});
|
||||||
|
|
||||||
|
export const checkpointRelations = relations(checkpointTable, ({ one }) => ({
|
||||||
|
user: one(usersTable, {
|
||||||
|
fields: [checkpointTable.user_id],
|
||||||
|
references: [usersTable.id],
|
||||||
|
}),
|
||||||
|
volume: one(checkpointVolumeTable, {
|
||||||
|
fields: [checkpointTable.checkpoint_volume_id],
|
||||||
|
references: [checkpointVolumeTable.id],
|
||||||
|
}),
|
||||||
|
}));
|
||||||
|
|
||||||
|
export const checkpointVolumeRelations = relations(
|
||||||
|
checkpointVolumeTable,
|
||||||
|
({ many, one }) => ({
|
||||||
|
checkpoint: many(checkpointTable),
|
||||||
|
user: one(usersTable, {
|
||||||
|
fields: [checkpointVolumeTable.user_id],
|
||||||
|
references: [usersTable.id],
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
);
|
||||||
|
|
||||||
export const subscriptionPlan = pgEnum("subscription_plan", [
|
export const subscriptionPlan = pgEnum("subscription_plan", [
|
||||||
"basic",
|
"basic",
|
||||||
"pro",
|
"pro",
|
||||||
@ -389,9 +473,26 @@ export const subscriptionStatusTable = dbSchema.table("subscription_status", {
|
|||||||
updated_at: timestamp("updated_at").defaultNow().notNull(),
|
updated_at: timestamp("updated_at").defaultNow().notNull(),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
export const insertCivitaiCheckpointSchema = createInsertSchema(
|
||||||
|
checkpointTable,
|
||||||
|
{
|
||||||
|
civitai_url: (schema) =>
|
||||||
|
schema.civitai_url
|
||||||
|
.trim()
|
||||||
|
.url({ message: "URL required" })
|
||||||
|
.includes("civitai.com/models", {
|
||||||
|
message: "civitai.com/models link required",
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
export type UserType = InferSelectModel<typeof usersTable>;
|
export type UserType = InferSelectModel<typeof usersTable>;
|
||||||
export type WorkflowType = InferSelectModel<typeof workflowTable>;
|
export type WorkflowType = InferSelectModel<typeof workflowTable>;
|
||||||
export type MachineType = InferSelectModel<typeof machinesTable>;
|
export type MachineType = InferSelectModel<typeof machinesTable>;
|
||||||
export type WorkflowVersionType = InferSelectModel<typeof workflowVersionTable>;
|
export type WorkflowVersionType = InferSelectModel<typeof workflowVersionTable>;
|
||||||
export type DeploymentType = InferSelectModel<typeof deploymentsTable>;
|
export type DeploymentType = InferSelectModel<typeof deploymentsTable>;
|
||||||
|
export type CheckpointType = InferSelectModel<typeof checkpointTable>;
|
||||||
|
export type CheckpointVolumeType = InferSelectModel<
|
||||||
|
typeof checkpointVolumeTable
|
||||||
|
>;
|
||||||
export type UserUsageType = InferSelectModel<typeof userUsageTable>;
|
export type UserUsageType = InferSelectModel<typeof userUsageTable>;
|
||||||
|
5
web/src/server/addCheckpointSchema.tsx
Normal file
5
web/src/server/addCheckpointSchema.tsx
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
import { insertCivitaiCheckpointSchema } from "@/db/schema";
|
||||||
|
|
||||||
|
export const addCivitaiCheckpointSchema = insertCivitaiCheckpointSchema.pick({
|
||||||
|
civitai_url: true,
|
||||||
|
});
|
271
web/src/server/curdCheckpoint.ts
Normal file
271
web/src/server/curdCheckpoint.ts
Normal file
@ -0,0 +1,271 @@
|
|||||||
|
"use server";
|
||||||
|
|
||||||
|
import { auth } from "@clerk/nextjs";
|
||||||
|
import {
|
||||||
|
checkpointTable,
|
||||||
|
CheckpointType,
|
||||||
|
checkpointVolumeTable,
|
||||||
|
CheckpointVolumeType,
|
||||||
|
} from "@/db/schema";
|
||||||
|
import { withServerPromise } from "./withServerPromise";
|
||||||
|
import { db } from "@/db/db";
|
||||||
|
import type { z } from "zod";
|
||||||
|
import { headers } from "next/headers";
|
||||||
|
import { addCivitaiCheckpointSchema } from "./addCheckpointSchema";
|
||||||
|
import { and, eq, isNull } from "drizzle-orm";
|
||||||
|
import { CivitaiModelResponse } from "@/types/civitai";
|
||||||
|
|
||||||
|
export async function getCheckpoints() {
|
||||||
|
const { userId, orgId } = auth();
|
||||||
|
if (!userId) throw new Error("No user id");
|
||||||
|
const checkpoints = await db
|
||||||
|
.select()
|
||||||
|
.from(checkpointTable)
|
||||||
|
.where(
|
||||||
|
orgId
|
||||||
|
? eq(checkpointTable.org_id, orgId)
|
||||||
|
// make sure org_id is null
|
||||||
|
: and(
|
||||||
|
eq(checkpointTable.user_id, userId),
|
||||||
|
isNull(checkpointTable.org_id),
|
||||||
|
),
|
||||||
|
);
|
||||||
|
return checkpoints;
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function getCheckpointById(id: string) {
|
||||||
|
const { userId, orgId } = auth();
|
||||||
|
if (!userId) throw new Error("No user id");
|
||||||
|
const checkpoint = await db
|
||||||
|
.select()
|
||||||
|
.from(checkpointTable)
|
||||||
|
.where(
|
||||||
|
and(
|
||||||
|
orgId ? eq(checkpointTable.org_id, orgId) : and(
|
||||||
|
eq(checkpointTable.user_id, userId),
|
||||||
|
isNull(checkpointTable.org_id),
|
||||||
|
),
|
||||||
|
eq(checkpointTable.id, id),
|
||||||
|
),
|
||||||
|
);
|
||||||
|
return checkpoint[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function getCheckpointVolumes() {
|
||||||
|
const { userId, orgId } = auth();
|
||||||
|
if (!userId) throw new Error("No user id");
|
||||||
|
const volume = await db
|
||||||
|
.select()
|
||||||
|
.from(checkpointVolumeTable)
|
||||||
|
.where(
|
||||||
|
and(
|
||||||
|
orgId
|
||||||
|
? eq(checkpointVolumeTable.org_id, orgId)
|
||||||
|
// make sure org_id is null
|
||||||
|
: and(
|
||||||
|
eq(checkpointVolumeTable.user_id, userId),
|
||||||
|
isNull(checkpointVolumeTable.org_id),
|
||||||
|
),
|
||||||
|
eq(checkpointVolumeTable.disabled, false),
|
||||||
|
),
|
||||||
|
);
|
||||||
|
return volume;
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function retrieveCheckpointVolumes() {
|
||||||
|
let volumes = await getCheckpointVolumes();
|
||||||
|
if (volumes.length === 0) {
|
||||||
|
// create volume if not already created
|
||||||
|
volumes = await addCheckpointVolume();
|
||||||
|
}
|
||||||
|
return volumes;
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function addCheckpointVolume() {
|
||||||
|
const { userId, orgId } = auth();
|
||||||
|
if (!userId) throw new Error("No user id");
|
||||||
|
|
||||||
|
// Insert the new checkpointVolume into the checkpointVolumeTable
|
||||||
|
const insertedVolume = await db
|
||||||
|
.insert(checkpointVolumeTable)
|
||||||
|
.values({
|
||||||
|
user_id: userId,
|
||||||
|
org_id: orgId,
|
||||||
|
volume_name: `checkpoints_${userId}`,
|
||||||
|
// created_at and updated_at will be set to current timestamp by default
|
||||||
|
disabled: false, // Default value
|
||||||
|
})
|
||||||
|
.returning(); // Returns the inserted row
|
||||||
|
return insertedVolume;
|
||||||
|
}
|
||||||
|
|
||||||
|
function getUrl(civitai_url: string) {
|
||||||
|
// expect to be a URL to be https://civitai.com/models/36520
|
||||||
|
// possiblity with slugged name and query-param modelVersionId
|
||||||
|
const baseUrl = "https://civitai.com/api/v1/models/";
|
||||||
|
const url = new URL(civitai_url);
|
||||||
|
const pathSegments = url.pathname.split("/");
|
||||||
|
const modelId = pathSegments[pathSegments.indexOf("models") + 1];
|
||||||
|
const modelVersionId = url.searchParams.get("modelVersionId");
|
||||||
|
|
||||||
|
return { url: baseUrl + modelId, modelVersionId };
|
||||||
|
}
|
||||||
|
|
||||||
|
export const addCivitaiCheckpoint = withServerPromise(
|
||||||
|
async (data: z.infer<typeof addCivitaiCheckpointSchema>) => {
|
||||||
|
const { userId, orgId } = auth();
|
||||||
|
|
||||||
|
if (!data.civitai_url) return { error: "no civitai_url" };
|
||||||
|
if (!userId) return { error: "No user id" };
|
||||||
|
|
||||||
|
const { url, modelVersionId } = getUrl(data?.civitai_url);
|
||||||
|
const civitaiModelRes = await fetch(url)
|
||||||
|
.then((x) => x.json())
|
||||||
|
.then((a) => {
|
||||||
|
console.log(a);
|
||||||
|
return CivitaiModelResponse.parse(a);
|
||||||
|
});
|
||||||
|
|
||||||
|
if (civitaiModelRes?.modelVersions?.length === 0) {
|
||||||
|
return; // no versions to download
|
||||||
|
}
|
||||||
|
|
||||||
|
let selectedModelVersion;
|
||||||
|
let selectedModelVersionId: string | null = modelVersionId;
|
||||||
|
if (!selectedModelVersionId) {
|
||||||
|
selectedModelVersion = civitaiModelRes.modelVersions[0];
|
||||||
|
selectedModelVersionId = civitaiModelRes.modelVersions[0].id.toString();
|
||||||
|
} else {
|
||||||
|
selectedModelVersion = civitaiModelRes.modelVersions.find((version) =>
|
||||||
|
version.id.toString() === selectedModelVersionId
|
||||||
|
);
|
||||||
|
if (!selectedModelVersion) {
|
||||||
|
return; // version id is wrong
|
||||||
|
}
|
||||||
|
selectedModelVersionId = selectedModelVersion?.id.toString();
|
||||||
|
}
|
||||||
|
|
||||||
|
const checkpointVolumes = await getCheckpointVolumes();
|
||||||
|
let cVolume;
|
||||||
|
if (checkpointVolumes.length === 0) {
|
||||||
|
const volume = await addCheckpointVolume();
|
||||||
|
cVolume = volume[0];
|
||||||
|
} else {
|
||||||
|
cVolume = checkpointVolumes[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
const a = await db
|
||||||
|
.insert(checkpointTable)
|
||||||
|
.values({
|
||||||
|
user_id: userId,
|
||||||
|
org_id: orgId,
|
||||||
|
upload_type: "civitai",
|
||||||
|
model_name: selectedModelVersion.files[0].name,
|
||||||
|
civitai_id: civitaiModelRes.id.toString(),
|
||||||
|
civitai_version_id: selectedModelVersionId,
|
||||||
|
civitai_url: data.civitai_url,
|
||||||
|
civitai_download_url: selectedModelVersion.files[0].downloadUrl,
|
||||||
|
civitai_model_response: civitaiModelRes,
|
||||||
|
checkpoint_volume_id: cVolume.id,
|
||||||
|
updated_at: new Date(),
|
||||||
|
})
|
||||||
|
.returning();
|
||||||
|
|
||||||
|
const b = a[0];
|
||||||
|
|
||||||
|
await uploadCheckpoint(data, b, cVolume);
|
||||||
|
// redirect(`/checkpoints/${b.id}`);
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
// export const redownloadCheckpoint = withServerPromise(
|
||||||
|
// async (data: CheckpointItemList) => {
|
||||||
|
// const { userId } = auth();
|
||||||
|
// if (!userId) return { error: "No user id" };
|
||||||
|
//
|
||||||
|
// const checkpointVolumes = await getCheckpointVolumes();
|
||||||
|
// let cVolume;
|
||||||
|
// if (checkpointVolumes.length === 0) {
|
||||||
|
// const volume = await addCheckpointVolume();
|
||||||
|
// cVolume = volume[0];
|
||||||
|
// } else {
|
||||||
|
// cVolume = checkpointVolumes[0];
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// console.log("data");
|
||||||
|
// console.log(data);
|
||||||
|
//
|
||||||
|
// const a = await db
|
||||||
|
// .update(checkpointTable)
|
||||||
|
// .set({
|
||||||
|
// // status: "started",
|
||||||
|
// // updated_at: new Date(),
|
||||||
|
// })
|
||||||
|
// .returning();
|
||||||
|
//
|
||||||
|
// const b = a[0];
|
||||||
|
//
|
||||||
|
// console.log("b");
|
||||||
|
// console.log(b);
|
||||||
|
//
|
||||||
|
// await uploadCheckpoint(data, b, cVolume);
|
||||||
|
// // redirect(`/checkpoints/${b.id}`);
|
||||||
|
// },
|
||||||
|
// );
|
||||||
|
|
||||||
|
async function uploadCheckpoint(
|
||||||
|
data: z.infer<typeof addCivitaiCheckpointSchema>,
|
||||||
|
c: CheckpointType,
|
||||||
|
v: CheckpointVolumeType,
|
||||||
|
) {
|
||||||
|
const headersList = headers();
|
||||||
|
|
||||||
|
const domain = headersList.get("x-forwarded-host") || "";
|
||||||
|
const protocol = headersList.get("x-forwarded-proto") || "";
|
||||||
|
|
||||||
|
if (domain === "") {
|
||||||
|
throw new Error("No domain");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call remote builder
|
||||||
|
const result = await fetch(
|
||||||
|
`${process.env.MODAL_BUILDER_URL!}/upload-volume`,
|
||||||
|
{
|
||||||
|
method: "POST",
|
||||||
|
headers: {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
body: JSON.stringify({
|
||||||
|
download_url: c.civitai_download_url,
|
||||||
|
volume_name: v.volume_name,
|
||||||
|
volume_id: v.id,
|
||||||
|
checkpoint_id: c.id,
|
||||||
|
callback_url: `${protocol}://${domain}/api/volume-upload`,
|
||||||
|
upload_type: "checkpoint"
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
if (!result.ok) {
|
||||||
|
const error_log = await result.text();
|
||||||
|
await db
|
||||||
|
.update(checkpointTable)
|
||||||
|
.set({
|
||||||
|
...data,
|
||||||
|
status: "failed",
|
||||||
|
error_log: error_log,
|
||||||
|
})
|
||||||
|
.where(eq(checkpointTable.id, c.id));
|
||||||
|
throw new Error(`Error: ${result.statusText} ${error_log}`);
|
||||||
|
} else {
|
||||||
|
// setting the build machine id
|
||||||
|
const json = await result.json();
|
||||||
|
await db
|
||||||
|
.update(checkpointTable)
|
||||||
|
.set({
|
||||||
|
...data,
|
||||||
|
upload_machine_id: json.build_machine_instance_id,
|
||||||
|
})
|
||||||
|
.where(eq(checkpointTable.id, c.id));
|
||||||
|
}
|
||||||
|
}
|
@ -15,6 +15,7 @@ import { headers } from "next/headers";
|
|||||||
import { redirect } from "next/navigation";
|
import { redirect } from "next/navigation";
|
||||||
import "server-only";
|
import "server-only";
|
||||||
import type { z } from "zod";
|
import type { z } from "zod";
|
||||||
|
import { retrieveCheckpointVolumes } from "./curdCheckpoint";
|
||||||
|
|
||||||
export async function getMachines() {
|
export async function getMachines() {
|
||||||
const { userId, orgId } = auth();
|
const { userId, orgId } = auth();
|
||||||
@ -189,6 +190,7 @@ async function _buildMachine(
|
|||||||
throw new Error("No domain");
|
throw new Error("No domain");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const volumes = await retrieveCheckpointVolumes();
|
||||||
// Call remote builder
|
// Call remote builder
|
||||||
const result = await fetch(`${process.env.MODAL_BUILDER_URL!}/create`, {
|
const result = await fetch(`${process.env.MODAL_BUILDER_URL!}/create`, {
|
||||||
method: "POST",
|
method: "POST",
|
||||||
@ -202,6 +204,7 @@ async function _buildMachine(
|
|||||||
callback_url: `${protocol}://${domain}/api/machine-built`,
|
callback_url: `${protocol}://${domain}/api/machine-built`,
|
||||||
models: data.models, //JSON.parse(data.models as string),
|
models: data.models, //JSON.parse(data.models as string),
|
||||||
gpu: data.gpu && data.gpu.length > 0 ? data.gpu : "T4",
|
gpu: data.gpu && data.gpu.length > 0 ? data.gpu : "T4",
|
||||||
|
checkpoint_volume_name: volumes[0].volume_name,
|
||||||
}),
|
}),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
41
web/src/server/getAllUserCheckpoints.tsx
Normal file
41
web/src/server/getAllUserCheckpoints.tsx
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
import { db } from "@/db/db";
|
||||||
|
import {
|
||||||
|
checkpointTable,
|
||||||
|
} from "@/db/schema";
|
||||||
|
import { auth } from "@clerk/nextjs";
|
||||||
|
import { and, desc, eq, isNull } from "drizzle-orm";
|
||||||
|
|
||||||
|
export async function getAllUserCheckpoints() {
|
||||||
|
const { userId, orgId } = await auth();
|
||||||
|
|
||||||
|
if (!userId) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
const checkpoints = await db.query.checkpointTable.findMany({
|
||||||
|
with: {
|
||||||
|
user: {
|
||||||
|
columns: {
|
||||||
|
name: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
columns: {
|
||||||
|
id: true,
|
||||||
|
updated_at: true,
|
||||||
|
model_name: true,
|
||||||
|
civitai_url: true,
|
||||||
|
civitai_model_response: true,
|
||||||
|
is_public: true,
|
||||||
|
upload_type: true,
|
||||||
|
status: true,
|
||||||
|
},
|
||||||
|
orderBy: desc(checkpointTable.updated_at),
|
||||||
|
where:
|
||||||
|
orgId != undefined
|
||||||
|
? eq(checkpointTable.org_id, orgId)
|
||||||
|
: and(eq(checkpointTable.user_id, userId), isNull(checkpointTable.org_id)),
|
||||||
|
});
|
||||||
|
|
||||||
|
return checkpoints;
|
||||||
|
}
|
129
web/src/types/civitai.ts
Normal file
129
web/src/types/civitai.ts
Normal file
@ -0,0 +1,129 @@
|
|||||||
|
import { z } from "zod";
|
||||||
|
|
||||||
|
// from chatgpt https://chat.openai.com/share/4985d20b-30b1-4a28-87f6-6ebf84a1040e
|
||||||
|
|
||||||
|
export const creatorSchema = z.object({
|
||||||
|
username: z.string().nullish(),
|
||||||
|
image: z.string().url().nullish(),
|
||||||
|
});
|
||||||
|
|
||||||
|
export const fileMetadataSchema = z.object({
|
||||||
|
fp: z.string().nullish(),
|
||||||
|
size: z.string().nullish(),
|
||||||
|
format: z.string().nullish(),
|
||||||
|
});
|
||||||
|
|
||||||
|
export const fileSchema = z.object({
|
||||||
|
id: z.number(),
|
||||||
|
sizeKB: z.number().nullish(),
|
||||||
|
name: z.string(),
|
||||||
|
type: z.string().nullish(),
|
||||||
|
metadata: fileMetadataSchema.nullish(),
|
||||||
|
pickleScanResult: z.string().nullish(),
|
||||||
|
pickleScanMessage: z.string().nullable(),
|
||||||
|
virusScanResult: z.string().nullish(),
|
||||||
|
virusScanMessage: z.string().nullable(),
|
||||||
|
scannedAt: z.string().nullish(),
|
||||||
|
hashes: z.record(z.string()).nullish(),
|
||||||
|
downloadUrl: z.string().url(),
|
||||||
|
primary: z.boolean().nullish(),
|
||||||
|
});
|
||||||
|
|
||||||
|
export const imageMetadataSchema = z.object({
|
||||||
|
hash: z.string(),
|
||||||
|
width: z.number(),
|
||||||
|
height: z.number(),
|
||||||
|
});
|
||||||
|
|
||||||
|
export const imageMetaSchema = z.object({
|
||||||
|
ENSD: z.string().nullish(),
|
||||||
|
Size: z.string().nullish(),
|
||||||
|
seed: z.number().nullish(),
|
||||||
|
Model: z.string().nullish(),
|
||||||
|
steps: z.number().nullish(),
|
||||||
|
hashes: z.record(z.string()).nullish(),
|
||||||
|
prompt: z.string().nullish(),
|
||||||
|
sampler: z.string().nullish(),
|
||||||
|
cfgScale: z.number().nullish(),
|
||||||
|
ClipSkip: z.number().nullish(),
|
||||||
|
resources: z.array(
|
||||||
|
z.object({
|
||||||
|
hash: z.string().nullish(),
|
||||||
|
name: z.string(),
|
||||||
|
type: z.string(),
|
||||||
|
weight: z.number().nullish(),
|
||||||
|
}),
|
||||||
|
).nullish(),
|
||||||
|
ModelHash: z.string().nullish(),
|
||||||
|
HiresSteps: z.string().nullish(),
|
||||||
|
HiresUpscale: z.string().nullish(),
|
||||||
|
HiresUpscaler: z.string().nullish(),
|
||||||
|
negativePrompt: z.string(),
|
||||||
|
DenoisingStrength: z.number().nullish(),
|
||||||
|
});
|
||||||
|
|
||||||
|
// NOTE: this definition is all over the place
|
||||||
|
// export const imageSchema = z.object({
|
||||||
|
// url: z.string().url().nullish(),
|
||||||
|
// nsfw: z.enum(["None", "Soft", "Mature"]).nullish(),
|
||||||
|
// width: z.number().nullish(),
|
||||||
|
// height: z.number().nullish(),
|
||||||
|
// hash: z.string().nullish(),
|
||||||
|
// type: z.string().nullish(),
|
||||||
|
// metadata: imageMetadataSchema.nullish(),
|
||||||
|
// meta: imageMetaSchema.nullish(),
|
||||||
|
// });
|
||||||
|
|
||||||
|
export const modelVersionSchema = z.object({
|
||||||
|
id: z.number(),
|
||||||
|
modelId: z.number(),
|
||||||
|
name: z.string(),
|
||||||
|
createdAt: z.string().nullish(),
|
||||||
|
updatedAt: z.string().nullish(),
|
||||||
|
// status: z.enum(["Published", "Unpublished"]).nullish(),
|
||||||
|
status: z.string().nullish(),
|
||||||
|
publishedAt: z.string().nullish(),
|
||||||
|
trainedWords: z.array(z.string()).nullable(),
|
||||||
|
trainingStatus: z.string().nullable(),
|
||||||
|
trainingDetails: z.string().nullable(),
|
||||||
|
baseModel: z.string().nullish(),
|
||||||
|
baseModelType: z.string().nullish(),
|
||||||
|
earlyAccessTimeFrame: z.number().nullish(),
|
||||||
|
description: z.string().nullable(),
|
||||||
|
vaeId: z.number().nullable(),
|
||||||
|
stats: z.object({
|
||||||
|
downloadCount: z.number(),
|
||||||
|
ratingCount: z.number(),
|
||||||
|
rating: z.number(),
|
||||||
|
}).nullish(),
|
||||||
|
files: z.array(fileSchema),
|
||||||
|
images: z.array(z.any()).nullish(),
|
||||||
|
downloadUrl: z.string().url(),
|
||||||
|
});
|
||||||
|
|
||||||
|
export const statsSchema = z.object({
|
||||||
|
downloadCount: z.number(),
|
||||||
|
favoriteCount: z.number(),
|
||||||
|
commentCount: z.number(),
|
||||||
|
ratingCount: z.number(),
|
||||||
|
rating: z.number(),
|
||||||
|
tippedAmountCount: z.number(),
|
||||||
|
});
|
||||||
|
|
||||||
|
export const CivitaiModelResponse = z.object({
|
||||||
|
id: z.number(),
|
||||||
|
name: z.string().nullish(),
|
||||||
|
description: z.string().nullish(),
|
||||||
|
// type: z.enum(["Checkpoint", "Lora"]), // TODO: this will be important to know
|
||||||
|
type: z.string(),
|
||||||
|
poi: z.boolean().nullish(),
|
||||||
|
nsfw: z.boolean().nullish(),
|
||||||
|
allowNoCredit: z.boolean().nullish(),
|
||||||
|
allowCommercialUse: z.string().nullish(),
|
||||||
|
allowDerivatives: z.boolean().nullish(),
|
||||||
|
allowDifferentLicense: z.boolean().nullish(),
|
||||||
|
stats: statsSchema.nullish(),
|
||||||
|
creator: creatorSchema.nullish(),
|
||||||
|
tags: z.array(z.string()).nullish(),
|
||||||
|
modelVersions: z.array(modelVersionSchema),
|
||||||
|
});
|
Loading…
x
Reference in New Issue
Block a user