Merge branch 'nickkao/checkpoint-volume'
# Conflicts: # web/bun.lockb
This commit is contained in:
		
						commit
						6f0499c657
					
				@ -8,6 +8,7 @@ from enum import Enum
 | 
			
		||||
import json
 | 
			
		||||
import subprocess
 | 
			
		||||
import time
 | 
			
		||||
from uuid import uuid4
 | 
			
		||||
from contextlib import asynccontextmanager
 | 
			
		||||
import asyncio
 | 
			
		||||
import threading
 | 
			
		||||
@ -19,6 +20,7 @@ from urllib.parse import parse_qs
 | 
			
		||||
from starlette.middleware.base import BaseHTTPMiddleware
 | 
			
		||||
from starlette.types import ASGIApp, Scope, Receive, Send
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
from concurrent.futures import ThreadPoolExecutor
 | 
			
		||||
 | 
			
		||||
# executor = ThreadPoolExecutor(max_workers=5)
 | 
			
		||||
@ -45,6 +47,7 @@ machine_id_websocket_dict = {}
 | 
			
		||||
machine_id_status = {}
 | 
			
		||||
 | 
			
		||||
fly_instance_id = os.environ.get('FLY_ALLOC_ID', 'local').split('-')[0]
 | 
			
		||||
civitai_api_key = os.environ.get('FLY_ALLOC_ID', 'local').split('-')[0]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class FlyReplayMiddleware(BaseHTTPMiddleware):
 | 
			
		||||
@ -174,6 +177,7 @@ class Item(BaseModel):
 | 
			
		||||
    snapshot: Snapshot
 | 
			
		||||
    models: List[Model]
 | 
			
		||||
    callback_url: str
 | 
			
		||||
    checkpoint_volume_name: str
 | 
			
		||||
    gpu: GPUType = Field(default=GPUType.T4)
 | 
			
		||||
 | 
			
		||||
    @field_validator('gpu')
 | 
			
		||||
@ -223,6 +227,103 @@ async def websocket_endpoint(websocket: WebSocket, machine_id: str):
 | 
			
		||||
 | 
			
		||||
#     return {"Hello": "World"}
 | 
			
		||||
 | 
			
		||||
class UploadType(str, Enum):
 | 
			
		||||
    checkpoint = "checkpoint"
 | 
			
		||||
 | 
			
		||||
class UploadBody(BaseModel):
 | 
			
		||||
    download_url: str
 | 
			
		||||
    volume_name: str
 | 
			
		||||
    volume_id: str
 | 
			
		||||
    checkpoint_id: str
 | 
			
		||||
    upload_type: UploadType
 | 
			
		||||
    callback_url: str
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
UPLOAD_TYPE_DIR_MAP = {
 | 
			
		||||
    UploadType.checkpoint: "checkpoints"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@app.post("/upload-volume")
 | 
			
		||||
async def upload_checkpoint(body: UploadBody):
 | 
			
		||||
    global last_activity_time
 | 
			
		||||
    last_activity_time = time.time()
 | 
			
		||||
    logger.info(f"Extended inactivity time to {global_timeout}")
 | 
			
		||||
 | 
			
		||||
    asyncio.create_task(upload_logic(body))
 | 
			
		||||
 | 
			
		||||
    # check that this
 | 
			
		||||
    return JSONResponse(status_code=200, content={"message": "Volume uploading", "build_machine_instance_id": fly_instance_id})
 | 
			
		||||
 | 
			
		||||
async def upload_logic(body: UploadBody):
 | 
			
		||||
    folder_path = f"/app/builds/{body.volume_id}"
 | 
			
		||||
 | 
			
		||||
    cp_process = await asyncio.subprocess.create_subprocess_exec("cp", "-r", "/app/src/volume-builder", folder_path)
 | 
			
		||||
    await cp_process.wait()
 | 
			
		||||
 | 
			
		||||
    upload_path = UPLOAD_TYPE_DIR_MAP[body.upload_type]
 | 
			
		||||
    config = {
 | 
			
		||||
        "volume_names": {
 | 
			
		||||
            body.volume_name: {"download_url": body.download_url, "folder_path": upload_path}
 | 
			
		||||
        }, 
 | 
			
		||||
        "volume_paths": {
 | 
			
		||||
            body.volume_name: f'/volumes/{uuid4()}'
 | 
			
		||||
        }, 
 | 
			
		||||
        "callback_url": body.callback_url,
 | 
			
		||||
        "callback_body": {
 | 
			
		||||
            "checkpoint_id": body.checkpoint_id,
 | 
			
		||||
            "volume_id": body.volume_id,
 | 
			
		||||
            "folder_path": upload_path,
 | 
			
		||||
        },
 | 
			
		||||
        "civitai_api_key": os.environ.get('CIVITAI_API_KEY')
 | 
			
		||||
    }
 | 
			
		||||
    with open(f"{folder_path}/config.py", "w") as f:
 | 
			
		||||
        f.write("config = " + json.dumps(config))
 | 
			
		||||
 | 
			
		||||
    process = await asyncio.subprocess.create_subprocess_shell(
 | 
			
		||||
        f"modal run app.py",
 | 
			
		||||
        # stdout=asyncio.subprocess.PIPE,
 | 
			
		||||
        # stderr=asyncio.subprocess.PIPE,
 | 
			
		||||
        cwd=folder_path,
 | 
			
		||||
        env={**os.environ, "COLUMNS": "10000"}
 | 
			
		||||
    )
 | 
			
		||||
    
 | 
			
		||||
    # error_logs = []
 | 
			
		||||
    # async def read_stream(stream):
 | 
			
		||||
    #     while True:
 | 
			
		||||
    #         line = await stream.readline()
 | 
			
		||||
    #         if line:
 | 
			
		||||
    #             l = line.decode('utf-8').strip()
 | 
			
		||||
    #             error_logs.append(l)
 | 
			
		||||
    #             logger.error(l)
 | 
			
		||||
    #             error_logs.append({
 | 
			
		||||
    #                 "logs": l,
 | 
			
		||||
    #                 "timestamp": time.time()
 | 
			
		||||
    #             })
 | 
			
		||||
    #         else:
 | 
			
		||||
    #             break
 | 
			
		||||
 | 
			
		||||
    # stderr_read_task = asyncio.create_task(read_stream(process.stderr))
 | 
			
		||||
    #
 | 
			
		||||
    # await asyncio.wait([stderr_read_task])
 | 
			
		||||
    # await process.wait()
 | 
			
		||||
 | 
			
		||||
    # if process.returncode != 0:
 | 
			
		||||
    #     error_logs.append({"logs": "Unable to upload volume.", "timestamp": time.time()})
 | 
			
		||||
    #     # Error handling: send POST request to callback URL with error details
 | 
			
		||||
    #     requests.post(body.callback_url, json={
 | 
			
		||||
    #         "volume_id": body.volume_id, 
 | 
			
		||||
    #         "checkpoint_id": body.checkpoint_id,
 | 
			
		||||
    #         "folder_path": upload_path,
 | 
			
		||||
    #         "error_logs": json.dumps(error_logs),
 | 
			
		||||
    #         "status": "failed"
 | 
			
		||||
    #     })
 | 
			
		||||
    #
 | 
			
		||||
    # requests.post(body.callback_url, json={
 | 
			
		||||
    #     "checkpoint_id": body.checkpoint_id,
 | 
			
		||||
    #     "volume_id": body.volume_id,
 | 
			
		||||
    #     "folder_path": upload_path,
 | 
			
		||||
    #     "status": "success"
 | 
			
		||||
    # })
 | 
			
		||||
 | 
			
		||||
@app.post("/create")
 | 
			
		||||
async def create_machine(item: Item):
 | 
			
		||||
@ -312,7 +413,9 @@ async def build_logic(item: Item):
 | 
			
		||||
    config = {
 | 
			
		||||
        "name": item.name,
 | 
			
		||||
        "deploy_test": os.environ.get("DEPLOY_TEST_FLAG", "False"),
 | 
			
		||||
        "gpu": item.gpu
 | 
			
		||||
        "gpu": item.gpu,
 | 
			
		||||
        "public_checkpoint_volume": "model-store",
 | 
			
		||||
        "private_checkpoint_volume": item.checkpoint_volume_name
 | 
			
		||||
    }
 | 
			
		||||
    with open(f"{folder_path}/config.py", "w") as f:
 | 
			
		||||
        f.write("config = " + json.dumps(config))
 | 
			
		||||
 | 
			
		||||
@ -1,12 +1,13 @@
 | 
			
		||||
from config import config
 | 
			
		||||
import modal
 | 
			
		||||
from modal import Image, Mount, web_endpoint, Stub, asgi_app
 | 
			
		||||
from modal import Image, Mount, web_endpoint, Stub, asgi_app 
 | 
			
		||||
import json
 | 
			
		||||
import urllib.request
 | 
			
		||||
import urllib.parse
 | 
			
		||||
from pydantic import BaseModel
 | 
			
		||||
from fastapi import FastAPI, Request
 | 
			
		||||
from fastapi.responses import HTMLResponse
 | 
			
		||||
from volume_setup import volumes
 | 
			
		||||
 | 
			
		||||
# deploy_test = False
 | 
			
		||||
 | 
			
		||||
@ -27,8 +28,8 @@ deploy_test = config["deploy_test"] == "True"
 | 
			
		||||
web_app = FastAPI()
 | 
			
		||||
print(config)
 | 
			
		||||
print("deploy_test ", deploy_test)
 | 
			
		||||
print('volumes', volumes)
 | 
			
		||||
stub = Stub(name=config["name"])
 | 
			
		||||
# print(stub.app_id)
 | 
			
		||||
 | 
			
		||||
if not deploy_test:
 | 
			
		||||
    # dockerfile_image = Image.from_dockerfile(f"{current_directory}/Dockerfile", context_mount=Mount.from_local_dir(f"{current_directory}/data", remote_path="/data"))
 | 
			
		||||
@ -52,11 +53,13 @@ if not deploy_test:
 | 
			
		||||
            "cd /comfyui/custom_nodes/ComfyUI-Manager && pip install -r requirements.txt",
 | 
			
		||||
            "cd /comfyui/custom_nodes/ComfyUI-Manager && mkdir startup-scripts",
 | 
			
		||||
        )
 | 
			
		||||
        .run_commands(f"cat /comfyui/server.py")
 | 
			
		||||
        .run_commands(f"ls /comfyui/app")
 | 
			
		||||
        # .run_commands(
 | 
			
		||||
        #     # Install comfy deploy
 | 
			
		||||
        #     "cd /comfyui/custom_nodes && git clone https://github.com/BennyKok/comfyui-deploy.git",
 | 
			
		||||
        # )
 | 
			
		||||
        # .copy_local_file(f"{current_directory}/data/extra_model_paths.yaml", "/comfyui")
 | 
			
		||||
        .copy_local_file(f"{current_directory}/data/extra_model_paths.yaml", "/comfyui")
 | 
			
		||||
 | 
			
		||||
        .copy_local_file(f"{current_directory}/data/start.sh", "/start.sh")
 | 
			
		||||
        .run_commands("chmod +x /start.sh")
 | 
			
		||||
@ -153,8 +156,9 @@ image = Image.debian_slim()
 | 
			
		||||
 | 
			
		||||
target_image = image if deploy_test else dockerfile_image
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@stub.function(image=target_image, gpu=config["gpu"])
 | 
			
		||||
@stub.function(image=target_image, gpu=config["gpu"]
 | 
			
		||||
   ,volumes=volumes 
 | 
			
		||||
)
 | 
			
		||||
def run(input: Input):
 | 
			
		||||
    import subprocess
 | 
			
		||||
    import time
 | 
			
		||||
@ -163,6 +167,7 @@ def run(input: Input):
 | 
			
		||||
 | 
			
		||||
    command = ["python", "main.py",
 | 
			
		||||
               "--disable-auto-launch", "--disable-metadata"]
 | 
			
		||||
 | 
			
		||||
    server_process = subprocess.Popen(command, cwd="/comfyui")
 | 
			
		||||
 | 
			
		||||
    check_server(
 | 
			
		||||
@ -235,7 +240,9 @@ async def bar(request_input: RequestInput):
 | 
			
		||||
    # pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@stub.function(image=image)
 | 
			
		||||
@stub.function(image=image
 | 
			
		||||
   ,volumes=volumes
 | 
			
		||||
)
 | 
			
		||||
@asgi_app()
 | 
			
		||||
def comfyui_api():
 | 
			
		||||
    return web_app
 | 
			
		||||
@ -285,6 +292,7 @@ def spawn_comfyui_in_background():
 | 
			
		||||
    # to be on a single container.
 | 
			
		||||
    concurrency_limit=1,
 | 
			
		||||
    timeout=10 * 60,
 | 
			
		||||
    volumes=volumes,
 | 
			
		||||
)
 | 
			
		||||
@asgi_app()
 | 
			
		||||
def comfyui_app():
 | 
			
		||||
@ -303,4 +311,4 @@ def comfyui_app():
 | 
			
		||||
        },
 | 
			
		||||
    )()
 | 
			
		||||
 | 
			
		||||
    return make_simple_proxy_app(ProxyContext(config))
 | 
			
		||||
    return make_simple_proxy_app(ProxyContext(config))
 | 
			
		||||
 | 
			
		||||
@ -1 +1,7 @@
 | 
			
		||||
config = {"name": "my-app", "deploy_test": "True", "gpu": "T4"}
 | 
			
		||||
config = {
 | 
			
		||||
    "name": "my-app",
 | 
			
		||||
    "deploy_test": "True",
 | 
			
		||||
    "gpu": "T4", 
 | 
			
		||||
    "public_checkpoint_volume": "model-store",
 | 
			
		||||
    "private_checkpoint_volume": "private-model-store"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -1,11 +1,15 @@
 | 
			
		||||
comfyui:
 | 
			
		||||
    base_path: /runpod-volume/ComfyUI/
 | 
			
		||||
    checkpoints: models/checkpoints/
 | 
			
		||||
    clip: models/clip/
 | 
			
		||||
    clip_vision: models/clip_vision/
 | 
			
		||||
    configs: models/configs/
 | 
			
		||||
    controlnet: models/controlnet/
 | 
			
		||||
    embeddings: models/embeddings/
 | 
			
		||||
    loras: models/loras/
 | 
			
		||||
    upscale_models: models/upscale_models/
 | 
			
		||||
    vae: models/vae/
 | 
			
		||||
public:
 | 
			
		||||
  base_path: /public_models/
 | 
			
		||||
  checkpoints: checkpoints
 | 
			
		||||
  clip: clip
 | 
			
		||||
  clip_vision: clip_vision
 | 
			
		||||
  configs: configs
 | 
			
		||||
  controlnet: controlnet
 | 
			
		||||
  embeddings: embeddings
 | 
			
		||||
  loras: loras
 | 
			
		||||
  upscale_models: upscale_models
 | 
			
		||||
  vae: vae
 | 
			
		||||
 | 
			
		||||
private:
 | 
			
		||||
  base_path: /private_models/
 | 
			
		||||
  checkpoints: checkpoints
 | 
			
		||||
 | 
			
		||||
@ -54,4 +54,4 @@ for model in models:
 | 
			
		||||
 | 
			
		||||
# Close the server
 | 
			
		||||
server_process.terminate()
 | 
			
		||||
print("Finished installing dependencies.")
 | 
			
		||||
print("Finished installing dependencies.")
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										9
									
								
								builder/modal-builder/src/template/volume_setup.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										9
									
								
								builder/modal-builder/src/template/volume_setup.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,9 @@
 | 
			
		||||
import modal
 | 
			
		||||
from config import config
 | 
			
		||||
 | 
			
		||||
public_model_volume = modal.Volume.persisted(config["public_checkpoint_volume"])
 | 
			
		||||
private_volume = modal.Volume.persisted(config["private_checkpoint_volume"])
 | 
			
		||||
 | 
			
		||||
PUBLIC_BASEMODEL_DIR = "/public_models"
 | 
			
		||||
PRIVATE_BASEMODEL_DIR = "/private_models"
 | 
			
		||||
volumes = {PUBLIC_BASEMODEL_DIR: public_model_volume, PRIVATE_BASEMODEL_DIR: private_volume}
 | 
			
		||||
							
								
								
									
										74
									
								
								builder/modal-builder/src/volume-builder/app.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										74
									
								
								builder/modal-builder/src/volume-builder/app.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,74 @@
 | 
			
		||||
import modal
 | 
			
		||||
from config import config
 | 
			
		||||
import os
 | 
			
		||||
import subprocess
 | 
			
		||||
from pprint import pprint
 | 
			
		||||
 | 
			
		||||
stub = modal.Stub()
 | 
			
		||||
 | 
			
		||||
# Volume names may only contain alphanumeric characters, dashes, periods, and underscores, and must be less than 64 characters in length.
 | 
			
		||||
def is_valid_name(name: str) -> bool:
 | 
			
		||||
    allowed_characters = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-._")
 | 
			
		||||
    return 0 < len(name) <= 64 and all(char in allowed_characters for char in name)
 | 
			
		||||
 | 
			
		||||
def create_volumes(volume_names, paths):
 | 
			
		||||
    path_to_vol = {}
 | 
			
		||||
    for volume_name in volume_names.keys():
 | 
			
		||||
        if not is_valid_name(volume_name):
 | 
			
		||||
            pass
 | 
			
		||||
        modal_volume = modal.Volume.persisted(volume_name)
 | 
			
		||||
        path_to_vol[paths[volume_name]] = modal_volume
 | 
			
		||||
 
 | 
			
		||||
    return path_to_vol
 | 
			
		||||
 | 
			
		||||
vol_name_to_links = config["volume_names"]
 | 
			
		||||
vol_name_to_path = config["volume_paths"]
 | 
			
		||||
callback_url = config["callback_url"]
 | 
			
		||||
callback_body = config["callback_body"]
 | 
			
		||||
civitai_key = config["civitai_api_key"]
 | 
			
		||||
 | 
			
		||||
volumes = create_volumes(vol_name_to_links, vol_name_to_path)
 | 
			
		||||
image = ( 
 | 
			
		||||
   modal.Image.debian_slim().apt_install("wget").pip_install("requests")
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
# download config { "download_url": "", "folder_path": ""}
 | 
			
		||||
timeout=5000
 | 
			
		||||
@stub.function(volumes=volumes, image=image, timeout=timeout, gpu=None)
 | 
			
		||||
def download_model(volume_name, download_config):
 | 
			
		||||
    import requests
 | 
			
		||||
    download_url = download_config["download_url"]
 | 
			
		||||
    folder_path = download_config["folder_path"]
 | 
			
		||||
 | 
			
		||||
    volume_base_path = vol_name_to_path[volume_name]
 | 
			
		||||
    model_store_path = os.path.join(volume_base_path, folder_path)
 | 
			
		||||
    modified_download_url = download_url + ("&" if "?" in download_url else "?") + "token=" + civitai_key
 | 
			
		||||
    print('downloading', modified_download_url)
 | 
			
		||||
 | 
			
		||||
    subprocess.run(["wget", modified_download_url , "--content-disposition", "-P", model_store_path])
 | 
			
		||||
    subprocess.run(["ls", "-la", volume_base_path])
 | 
			
		||||
    subprocess.run(["ls", "-la", model_store_path])
 | 
			
		||||
    volumes[volume_base_path].commit()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    status =  {"status": "success"}
 | 
			
		||||
    requests.post(callback_url, json={**status, **callback_body})
 | 
			
		||||
    print(f"finished! sending to {callback_url}")
 | 
			
		||||
    pprint({**status, **callback_body})
 | 
			
		||||
 | 
			
		||||
@stub.local_entrypoint()
 | 
			
		||||
def simple_download():
 | 
			
		||||
    import requests
 | 
			
		||||
    try:
 | 
			
		||||
        list(download_model.starmap([(vol_name, link) for vol_name,link in vol_name_to_links.items()]))
 | 
			
		||||
    except modal.exception.FunctionTimeoutError as e:
 | 
			
		||||
        status =  {"status": "failed", "error_logs": f"{str(e)}", "timeout": timeout}
 | 
			
		||||
        requests.post(callback_url, json={**status, **callback_body})
 | 
			
		||||
        print(f"finished! sending to {callback_url}")
 | 
			
		||||
        pprint({**status, **callback_body})
 | 
			
		||||
    except Exception as e:
 | 
			
		||||
        status =  {"status": "failed", "error_logs": str(e)}
 | 
			
		||||
        requests.post(callback_url, json={**status, **callback_body})
 | 
			
		||||
        print(f"finished! sending to {callback_url}")
 | 
			
		||||
        pprint({**status, **callback_body})
 | 
			
		||||
        
 | 
			
		||||
							
								
								
									
										18
									
								
								builder/modal-builder/src/volume-builder/config.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										18
									
								
								builder/modal-builder/src/volume-builder/config.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,18 @@
 | 
			
		||||
config = {
 | 
			
		||||
    "volume_names": {
 | 
			
		||||
        "test": {
 | 
			
		||||
            "download_url": "https://pub-6230db03dc3a4861a9c3e55145ceda44.r2.dev/openpose-pose (1).png",
 | 
			
		||||
            "folder_path": "images"
 | 
			
		||||
        }
 | 
			
		||||
    }, 
 | 
			
		||||
    "volume_paths": {
 | 
			
		||||
        "test": "/volumes/something"
 | 
			
		||||
    },
 | 
			
		||||
    "callback_url": "",
 | 
			
		||||
    "callback_body": {
 | 
			
		||||
        "checkpoint_id": "",
 | 
			
		||||
        "volume_id": "",
 | 
			
		||||
        "folder_path": "images",
 | 
			
		||||
    }, 
 | 
			
		||||
    "civitai_api_key": "",
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										64
									
								
								web/drizzle/0042_windy_madelyne_pryor.sql
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										64
									
								
								web/drizzle/0042_windy_madelyne_pryor.sql
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,64 @@
 | 
			
		||||
DO $$ BEGIN
 | 
			
		||||
 CREATE TYPE "model_upload_type" AS ENUM('civitai', 'huggingface', 'other');
 | 
			
		||||
EXCEPTION
 | 
			
		||||
 WHEN duplicate_object THEN null;
 | 
			
		||||
END $$;
 | 
			
		||||
--> statement-breakpoint
 | 
			
		||||
DO $$ BEGIN
 | 
			
		||||
 CREATE TYPE "resource_upload" AS ENUM('started', 'success', 'failed');
 | 
			
		||||
EXCEPTION
 | 
			
		||||
 WHEN duplicate_object THEN null;
 | 
			
		||||
END $$;
 | 
			
		||||
--> statement-breakpoint
 | 
			
		||||
CREATE TABLE IF NOT EXISTS "comfyui_deploy"."checkpoints" (
 | 
			
		||||
	"id" uuid PRIMARY KEY DEFAULT gen_random_uuid() NOT NULL,
 | 
			
		||||
	"user_id" text,
 | 
			
		||||
	"org_id" text,
 | 
			
		||||
	"description" text,
 | 
			
		||||
	"checkpoint_volume_id" uuid NOT NULL,
 | 
			
		||||
	"model_name" text,
 | 
			
		||||
	"folder_path" text,
 | 
			
		||||
	"civitai_id" text,
 | 
			
		||||
	"civitai_version_id" text,
 | 
			
		||||
	"civitai_url" text,
 | 
			
		||||
	"civitai_download_url" text,
 | 
			
		||||
	"civitai_model_response" jsonb,
 | 
			
		||||
	"hf_url" text,
 | 
			
		||||
	"s3_url" text,
 | 
			
		||||
	"client_url" text,
 | 
			
		||||
	"is_public" boolean DEFAULT false NOT NULL,
 | 
			
		||||
	"status" "resource_upload" DEFAULT 'started' NOT NULL,
 | 
			
		||||
	"upload_machine_id" text,
 | 
			
		||||
	"upload_type" "model_upload_type" NOT NULL,
 | 
			
		||||
	"error_log" text,
 | 
			
		||||
	"created_at" timestamp DEFAULT now() NOT NULL,
 | 
			
		||||
	"updated_at" timestamp DEFAULT now() NOT NULL
 | 
			
		||||
);
 | 
			
		||||
--> statement-breakpoint
 | 
			
		||||
CREATE TABLE IF NOT EXISTS "comfyui_deploy"."checkpoint_volume" (
 | 
			
		||||
	"id" uuid PRIMARY KEY DEFAULT gen_random_uuid() NOT NULL,
 | 
			
		||||
	"user_id" text,
 | 
			
		||||
	"org_id" text,
 | 
			
		||||
	"volume_name" text NOT NULL,
 | 
			
		||||
	"created_at" timestamp DEFAULT now() NOT NULL,
 | 
			
		||||
	"updated_at" timestamp DEFAULT now() NOT NULL,
 | 
			
		||||
	"disabled" boolean DEFAULT false NOT NULL
 | 
			
		||||
);
 | 
			
		||||
--> statement-breakpoint
 | 
			
		||||
DO $$ BEGIN
 | 
			
		||||
 ALTER TABLE "comfyui_deploy"."checkpoints" ADD CONSTRAINT "checkpoints_user_id_users_id_fk" FOREIGN KEY ("user_id") REFERENCES "comfyui_deploy"."users"("id") ON DELETE no action ON UPDATE no action;
 | 
			
		||||
EXCEPTION
 | 
			
		||||
 WHEN duplicate_object THEN null;
 | 
			
		||||
END $$;
 | 
			
		||||
--> statement-breakpoint
 | 
			
		||||
DO $$ BEGIN
 | 
			
		||||
 ALTER TABLE "comfyui_deploy"."checkpoints" ADD CONSTRAINT "checkpoints_checkpoint_volume_id_checkpoint_volume_id_fk" FOREIGN KEY ("checkpoint_volume_id") REFERENCES "comfyui_deploy"."checkpoint_volume"("id") ON DELETE cascade ON UPDATE no action;
 | 
			
		||||
EXCEPTION
 | 
			
		||||
 WHEN duplicate_object THEN null;
 | 
			
		||||
END $$;
 | 
			
		||||
--> statement-breakpoint
 | 
			
		||||
DO $$ BEGIN
 | 
			
		||||
 ALTER TABLE "comfyui_deploy"."checkpoint_volume" ADD CONSTRAINT "checkpoint_volume_user_id_users_id_fk" FOREIGN KEY ("user_id") REFERENCES "comfyui_deploy"."users"("id") ON DELETE no action ON UPDATE no action;
 | 
			
		||||
EXCEPTION
 | 
			
		||||
 WHEN duplicate_object THEN null;
 | 
			
		||||
END $$;
 | 
			
		||||
							
								
								
									
										1273
									
								
								web/drizzle/meta/0042_snapshot.json
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1273
									
								
								web/drizzle/meta/0042_snapshot.json
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							@ -295,6 +295,13 @@
 | 
			
		||||
      "when": 1706111421524,
 | 
			
		||||
      "tag": "0041_thick_norrin_radd",
 | 
			
		||||
      "breakpoints": true
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
      "idx": 42,
 | 
			
		||||
      "version": "5",
 | 
			
		||||
      "when": 1706164614659,
 | 
			
		||||
      "tag": "0042_windy_madelyne_pryor",
 | 
			
		||||
      "breakpoints": true
 | 
			
		||||
    }
 | 
			
		||||
  ]
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										55
									
								
								web/src/app/(app)/api/volume-upload/route.ts
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										55
									
								
								web/src/app/(app)/api/volume-upload/route.ts
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,55 @@
 | 
			
		||||
import { parseDataSafe } from "../../../../lib/parseDataSafe";
 | 
			
		||||
import { db } from "@/db/db";
 | 
			
		||||
import { checkpointTable, machinesTable } from "@/db/schema";
 | 
			
		||||
import { eq } from "drizzle-orm";
 | 
			
		||||
import { NextResponse } from "next/server";
 | 
			
		||||
import { z } from "zod";
 | 
			
		||||
 | 
			
		||||
const Request = z.object({
 | 
			
		||||
  volume_id: z.string(),
 | 
			
		||||
  checkpoint_id: z.string(),
 | 
			
		||||
  folder_path: z.string().optional(),
 | 
			
		||||
  status: z.enum(['success', 'failed']),
 | 
			
		||||
  error_log: z.string().optional(),
 | 
			
		||||
  timeout: z.number().optional(),
 | 
			
		||||
});
 | 
			
		||||
 | 
			
		||||
export async function POST(request: Request) {
 | 
			
		||||
  const [data, error] = await parseDataSafe(Request, request);
 | 
			
		||||
  if (!data || error) return error;
 | 
			
		||||
 | 
			
		||||
  const { checkpoint_id, error_log, status, folder_path } = data;
 | 
			
		||||
  console.log( checkpoint_id, error_log, status, folder_path )
 | 
			
		||||
 | 
			
		||||
  if (status === "success") {
 | 
			
		||||
    await db
 | 
			
		||||
      .update(checkpointTable)
 | 
			
		||||
      .set({
 | 
			
		||||
        status: "success",
 | 
			
		||||
        folder_path,
 | 
			
		||||
        updated_at: new Date(),
 | 
			
		||||
        // build_log: build_log,
 | 
			
		||||
      })
 | 
			
		||||
      .where(eq(checkpointTable.id, checkpoint_id));
 | 
			
		||||
  } else {
 | 
			
		||||
    await db
 | 
			
		||||
      .update(checkpointTable)
 | 
			
		||||
      .set({
 | 
			
		||||
        status: "failed",
 | 
			
		||||
        error_log, 
 | 
			
		||||
        updated_at: new Date(),
 | 
			
		||||
        // status: "error",
 | 
			
		||||
        // build_log: build_log,
 | 
			
		||||
      })
 | 
			
		||||
      .where(eq(checkpointTable.id, checkpoint_id));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  return NextResponse.json(
 | 
			
		||||
    {
 | 
			
		||||
      message: "success",
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
      status: 200,
 | 
			
		||||
    }
 | 
			
		||||
  );
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										9
									
								
								web/src/app/(app)/storage/loading.tsx
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										9
									
								
								web/src/app/(app)/storage/loading.tsx
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,9 @@
 | 
			
		||||
"use client";
 | 
			
		||||
 | 
			
		||||
import { LoadingPageWrapper } from "@/components/LoadingWrapper";
 | 
			
		||||
import { usePathname } from "next/navigation";
 | 
			
		||||
 | 
			
		||||
export default function Loading() {
 | 
			
		||||
  const pathName = usePathname();
 | 
			
		||||
  return <LoadingPageWrapper className="h-full" tag={pathName.toLowerCase()} />;
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										35
									
								
								web/src/app/(app)/storage/page.tsx
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										35
									
								
								web/src/app/(app)/storage/page.tsx
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,35 @@
 | 
			
		||||
import { setInitialUserData } from "../../../lib/setInitialUserData";
 | 
			
		||||
import { auth } from "@clerk/nextjs";
 | 
			
		||||
import { clerkClient } from "@clerk/nextjs/server";
 | 
			
		||||
import { CheckpointList } from "@/components/CheckpointList";
 | 
			
		||||
import { getAllUserCheckpoints } from "@/server/getAllUserCheckpoints";
 | 
			
		||||
 | 
			
		||||
export default function Page() {
 | 
			
		||||
  return <CheckpointListServer />;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async function CheckpointListServer() {
 | 
			
		||||
  const { userId } = auth();
 | 
			
		||||
 | 
			
		||||
  if (!userId) {
 | 
			
		||||
    return <div>No auth</div>;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  const user = await clerkClient.users.getUser(userId);
 | 
			
		||||
 | 
			
		||||
  if (!user) {
 | 
			
		||||
    await setInitialUserData(userId);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  const checkpoints = await getAllUserCheckpoints();
 | 
			
		||||
 | 
			
		||||
  if (!checkpoints) {
 | 
			
		||||
    return <div>No checkpoints found</div>;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  return (
 | 
			
		||||
    <div className="w-full">
 | 
			
		||||
      <CheckpointList data={checkpoints} />
 | 
			
		||||
    </div>
 | 
			
		||||
  );
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										373
									
								
								web/src/components/CheckpointList.tsx
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										373
									
								
								web/src/components/CheckpointList.tsx
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,373 @@
 | 
			
		||||
"use client";
 | 
			
		||||
 | 
			
		||||
import { getRelativeTime } from "../lib/getRelativeTime";
 | 
			
		||||
import { Badge } from "@/components/ui/badge";
 | 
			
		||||
import { Button } from "@/components/ui/button";
 | 
			
		||||
import { Checkbox } from "@/components/ui/checkbox";
 | 
			
		||||
import { InsertModal, UpdateModal } from "./InsertModal";
 | 
			
		||||
import { Input } from "@/components/ui/input";
 | 
			
		||||
import { ScrollArea } from "@/components/ui/scroll-area";
 | 
			
		||||
import {
 | 
			
		||||
  Table,
 | 
			
		||||
  TableBody,
 | 
			
		||||
  TableCell,
 | 
			
		||||
  TableHead,
 | 
			
		||||
  TableHeader,
 | 
			
		||||
  TableRow,
 | 
			
		||||
} from "@/components/ui/table";
 | 
			
		||||
import type { getAllUserCheckpoints } from "@/server/getAllUserCheckpoints";
 | 
			
		||||
import type {
 | 
			
		||||
  ColumnDef,
 | 
			
		||||
  ColumnFiltersState,
 | 
			
		||||
  SortingState,
 | 
			
		||||
  VisibilityState,
 | 
			
		||||
} from "@tanstack/react-table";
 | 
			
		||||
import {
 | 
			
		||||
  flexRender,
 | 
			
		||||
  getCoreRowModel,
 | 
			
		||||
  getFilteredRowModel,
 | 
			
		||||
  getPaginationRowModel,
 | 
			
		||||
  getSortedRowModel,
 | 
			
		||||
  useReactTable,
 | 
			
		||||
} from "@tanstack/react-table";
 | 
			
		||||
import { ArrowUpDown } from "lucide-react";
 | 
			
		||||
import * as React from "react";
 | 
			
		||||
import { addCivitaiCheckpoint } from "@/server/curdCheckpoint";
 | 
			
		||||
import { addCivitaiCheckpointSchema } from "@/server/addCheckpointSchema";
 | 
			
		||||
 | 
			
		||||
export type CheckpointItemList = NonNullable<
 | 
			
		||||
  Awaited<ReturnType<typeof getAllUserCheckpoints>>
 | 
			
		||||
>[0];
 | 
			
		||||
 | 
			
		||||
export const columns: ColumnDef<CheckpointItemList>[] = [
 | 
			
		||||
  {
 | 
			
		||||
    accessorKey: "id",
 | 
			
		||||
    id: "select",
 | 
			
		||||
    header: ({ table }) => (
 | 
			
		||||
      <Checkbox
 | 
			
		||||
        checked={
 | 
			
		||||
          table.getIsAllPageRowsSelected() ||
 | 
			
		||||
          (table.getIsSomePageRowsSelected() && "indeterminate")
 | 
			
		||||
        }
 | 
			
		||||
        onCheckedChange={(value) => table.toggleAllPageRowsSelected(!!value)}
 | 
			
		||||
        aria-label="Select all"
 | 
			
		||||
      />
 | 
			
		||||
    ),
 | 
			
		||||
    cell: ({ row }) => (
 | 
			
		||||
      <Checkbox
 | 
			
		||||
        checked={row.getIsSelected()}
 | 
			
		||||
        onCheckedChange={(value) => row.toggleSelected(!!value)}
 | 
			
		||||
        aria-label="Select row"
 | 
			
		||||
      />
 | 
			
		||||
    ),
 | 
			
		||||
    enableSorting: false,
 | 
			
		||||
    enableHiding: false,
 | 
			
		||||
  },
 | 
			
		||||
  {
 | 
			
		||||
    accessorKey: "model_name",
 | 
			
		||||
    header: ({ column }) => {
 | 
			
		||||
      return (
 | 
			
		||||
        <button
 | 
			
		||||
          className="flex items-center hover:underline"
 | 
			
		||||
          onClick={() => column.toggleSorting(column.getIsSorted() === "asc")}
 | 
			
		||||
        >
 | 
			
		||||
          Model Name
 | 
			
		||||
          <ArrowUpDown className="ml-2 h-4 w-4" />
 | 
			
		||||
        </button>
 | 
			
		||||
      );
 | 
			
		||||
    },
 | 
			
		||||
    cell: ({ row }) => {
 | 
			
		||||
      const checkpoint = row.original;
 | 
			
		||||
      return (
 | 
			
		||||
        <a
 | 
			
		||||
          className="hover:underline flex gap-2"
 | 
			
		||||
          href={`/storage/${checkpoint.id}`} // TODO
 | 
			
		||||
        >
 | 
			
		||||
          <span className="truncate max-w-[200px]">
 | 
			
		||||
            {row.original.model_name}
 | 
			
		||||
          </span>
 | 
			
		||||
 | 
			
		||||
          {checkpoint.is_public ? (
 | 
			
		||||
            <Badge variant="green">Public</Badge>
 | 
			
		||||
          ) : (
 | 
			
		||||
            <Badge variant="orange">Private</Badge>
 | 
			
		||||
          )}
 | 
			
		||||
        </a>
 | 
			
		||||
      );
 | 
			
		||||
    },
 | 
			
		||||
  },
 | 
			
		||||
  {
 | 
			
		||||
    accessorKey: "status",
 | 
			
		||||
    header: ({ column }) => {
 | 
			
		||||
      return (
 | 
			
		||||
        <button
 | 
			
		||||
          className="flex items-center hover:underline"
 | 
			
		||||
          onClick={() => column.toggleSorting(column.getIsSorted() === "asc")}
 | 
			
		||||
        >
 | 
			
		||||
          Status
 | 
			
		||||
          <ArrowUpDown className="ml-2 h-4 w-4" />
 | 
			
		||||
        </button>
 | 
			
		||||
      );
 | 
			
		||||
    },
 | 
			
		||||
    cell: ({ row }) => {
 | 
			
		||||
      return (
 | 
			
		||||
        <Badge variant={row.original.status === "failed" ? "red" : (row.original.status === "started" ? "yellow" : "green")}>
 | 
			
		||||
          {row.original.status}
 | 
			
		||||
        </Badge>
 | 
			
		||||
      );
 | 
			
		||||
      // NOTE: retry downloads on failures
 | 
			
		||||
      // const oneHourAgo = new Date(new Date().getTime() - (60 * 60 * 1000));
 | 
			
		||||
      // const lastUpdated = new Date(row.original.updated_at);
 | 
			
		||||
      // const canRefresh = row.original.status === "failed" && lastUpdated < oneHourAgo;
 | 
			
		||||
      // const canRefresh = row.original.status === "failed" && lastUpdated < oneHourAgo;
 | 
			
		||||
      // cell: ({ row }) => {
 | 
			
		||||
      //   // const oneHourAgo = new Date(new Date().getTime() - (60 * 60 * 1000));
 | 
			
		||||
      //   // const lastUpdated = new Date(row.original.updated_at);
 | 
			
		||||
      //   // const canRefresh = row.original.status === "failed" && lastUpdated < oneHourAgo;
 | 
			
		||||
      //   const canReDownload = true;
 | 
			
		||||
      //
 | 
			
		||||
      //   return (
 | 
			
		||||
      //     <div className="flex items-center space-x-2">
 | 
			
		||||
      //       <Badge
 | 
			
		||||
      //         variant={row.original.status === "failed"
 | 
			
		||||
      //           ? "red"
 | 
			
		||||
      //           : row.original.status === "started"
 | 
			
		||||
      //           ? "yellow"
 | 
			
		||||
      //           : "green"}
 | 
			
		||||
      //       >
 | 
			
		||||
      //         {row.original.status}
 | 
			
		||||
      //       </Badge>
 | 
			
		||||
      //       {canReDownload && (
 | 
			
		||||
      //         <RefreshCcw
 | 
			
		||||
      //           onClick={() => {
 | 
			
		||||
      //             redownloadCheckpoint(row.original);
 | 
			
		||||
      //           }}
 | 
			
		||||
      //           className="h-4 w-4 cursor-pointer" // Adjust the size with h-x and w-x classes
 | 
			
		||||
      //         />
 | 
			
		||||
      //       )}
 | 
			
		||||
      //     </div>
 | 
			
		||||
      //   );
 | 
			
		||||
      // },
 | 
			
		||||
    },
 | 
			
		||||
  },
 | 
			
		||||
  {
 | 
			
		||||
    accessorKey: "upload_type",
 | 
			
		||||
    header: ({ column }) => {
 | 
			
		||||
      return (
 | 
			
		||||
        <button
 | 
			
		||||
          className="flex items-center hover:underline"
 | 
			
		||||
          onClick={() => column.toggleSorting(column.getIsSorted() === "asc")}
 | 
			
		||||
        >
 | 
			
		||||
          Source
 | 
			
		||||
          <ArrowUpDown className="ml-2 h-4 w-4" />
 | 
			
		||||
        </button>
 | 
			
		||||
      );
 | 
			
		||||
    },
 | 
			
		||||
    cell: ({ row }) => {
 | 
			
		||||
      return <Badge variant="cyan">{row.original.upload_type}</Badge>;
 | 
			
		||||
    },
 | 
			
		||||
  },
 | 
			
		||||
  {
 | 
			
		||||
    accessorKey: "date",
 | 
			
		||||
    sortingFn: "datetime",
 | 
			
		||||
    enableSorting: true,
 | 
			
		||||
    header: ({ column }) => {
 | 
			
		||||
      return (
 | 
			
		||||
        <button
 | 
			
		||||
          className="w-full flex items-center justify-end hover:underline truncate"
 | 
			
		||||
          // variant="ghost"
 | 
			
		||||
          onClick={() => column.toggleSorting(column.getIsSorted() === "asc")}
 | 
			
		||||
        >
 | 
			
		||||
          Update Date
 | 
			
		||||
          <ArrowUpDown className="ml-2 h-4 w-4" />
 | 
			
		||||
        </button>
 | 
			
		||||
      );
 | 
			
		||||
    },
 | 
			
		||||
    cell: ({ row }) => (
 | 
			
		||||
      <div className="w-full capitalize text-right truncate">
 | 
			
		||||
        {getRelativeTime(row.original.updated_at)}
 | 
			
		||||
      </div>
 | 
			
		||||
    ),
 | 
			
		||||
  },
 | 
			
		||||
  // TODO: deletion and editing for future sprint
 | 
			
		||||
  // {
 | 
			
		||||
  //   id: "actions",
 | 
			
		||||
  //   enableHiding: false,
 | 
			
		||||
  //   cell: ({ row }) => {
 | 
			
		||||
  //     const checkpoint = row.original;
 | 
			
		||||
  //
 | 
			
		||||
  //     return (
 | 
			
		||||
  //       <DropdownMenu>
 | 
			
		||||
  //         <DropdownMenuTrigger asChild>
 | 
			
		||||
  //           <Button variant="ghost" className="h-8 w-8 p-0">
 | 
			
		||||
  //             <span className="sr-only">Open menu</span>
 | 
			
		||||
  //             <MoreHorizontal className="h-4 w-4" />
 | 
			
		||||
  //           </Button>
 | 
			
		||||
  //         </DropdownMenuTrigger>
 | 
			
		||||
  //         <DropdownMenuContent align="end">
 | 
			
		||||
  //           <DropdownMenuLabel>Actions</DropdownMenuLabel>
 | 
			
		||||
  //           <DropdownMenuItem
 | 
			
		||||
  //             className="text-destructive"
 | 
			
		||||
  //             onClick={() => {
 | 
			
		||||
  //               deleteWorkflow(checkpoint.id);
 | 
			
		||||
  //             }}
 | 
			
		||||
  //           >
 | 
			
		||||
  //             Delete Workflow
 | 
			
		||||
  //           </DropdownMenuItem>
 | 
			
		||||
  //         </DropdownMenuContent>
 | 
			
		||||
  //       </DropdownMenu>
 | 
			
		||||
  //     );
 | 
			
		||||
  //   },
 | 
			
		||||
  // },
 | 
			
		||||
];
 | 
			
		||||
 | 
			
		||||
export function CheckpointList({ data }: { data: CheckpointItemList[] }) {
 | 
			
		||||
  const [sorting, setSorting] = React.useState<SortingState>([]);
 | 
			
		||||
  const [columnFilters, setColumnFilters] = React.useState<ColumnFiltersState>(
 | 
			
		||||
    []
 | 
			
		||||
  );
 | 
			
		||||
  const [columnVisibility, setColumnVisibility] =
 | 
			
		||||
    React.useState<VisibilityState>({});
 | 
			
		||||
  const [rowSelection, setRowSelection] = React.useState({});
 | 
			
		||||
 | 
			
		||||
  const table = useReactTable({
 | 
			
		||||
    data,
 | 
			
		||||
    columns,
 | 
			
		||||
    onSortingChange: setSorting,
 | 
			
		||||
    onColumnFiltersChange: setColumnFilters,
 | 
			
		||||
    getCoreRowModel: getCoreRowModel(),
 | 
			
		||||
    getPaginationRowModel: getPaginationRowModel(),
 | 
			
		||||
    getSortedRowModel: getSortedRowModel(),
 | 
			
		||||
    getFilteredRowModel: getFilteredRowModel(),
 | 
			
		||||
    onColumnVisibilityChange: setColumnVisibility,
 | 
			
		||||
    onRowSelectionChange: setRowSelection,
 | 
			
		||||
    state: {
 | 
			
		||||
      sorting,
 | 
			
		||||
      columnFilters,
 | 
			
		||||
      columnVisibility,
 | 
			
		||||
      rowSelection,
 | 
			
		||||
    },
 | 
			
		||||
  });
 | 
			
		||||
 | 
			
		||||
  return (
 | 
			
		||||
    <div className="grid grid-rows-[auto,1fr,auto] h-full">
 | 
			
		||||
      <div className="flex flex-row w-full items-center py-4">
 | 
			
		||||
        <Input
 | 
			
		||||
          placeholder="Filter workflows..."
 | 
			
		||||
          value={(table.getColumn("name")?.getFilterValue() as string) ?? ""}
 | 
			
		||||
          onChange={(event) =>
 | 
			
		||||
            table.getColumn("name")?.setFilterValue(event.target.value)
 | 
			
		||||
          }
 | 
			
		||||
          className="max-w-sm"
 | 
			
		||||
        />
 | 
			
		||||
        <div className="ml-auto flex gap-2">
 | 
			
		||||
          <InsertModal
 | 
			
		||||
            dialogClassName="sm:max-w-[600px]"
 | 
			
		||||
            disabled={
 | 
			
		||||
              false
 | 
			
		||||
              // TODO: limitations based on plan
 | 
			
		||||
            }
 | 
			
		||||
            tooltip={"Add models using their civitai url!"}
 | 
			
		||||
            title="Civitai Checkpoint"
 | 
			
		||||
            description="Pick a model from civitai"
 | 
			
		||||
            serverAction={addCivitaiCheckpoint}
 | 
			
		||||
            formSchema={addCivitaiCheckpointSchema}
 | 
			
		||||
            fieldConfig={{
 | 
			
		||||
              civitai_url: {
 | 
			
		||||
                fieldType: "fallback",
 | 
			
		||||
                inputProps: { required: true },
 | 
			
		||||
                description: (
 | 
			
		||||
                  <>
 | 
			
		||||
                    Pick a checkpoint from{" "}
 | 
			
		||||
                    <a
 | 
			
		||||
                      href="https://www.civitai.com/models"
 | 
			
		||||
                      target="_blank"
 | 
			
		||||
                      className="underline text-blue-600 hover:text-blue-800 visited:text-purple-600"
 | 
			
		||||
                    >
 | 
			
		||||
                      civitai.com
 | 
			
		||||
                    </a>{" "}
 | 
			
		||||
                    and place it's url here
 | 
			
		||||
                  </>
 | 
			
		||||
                ),
 | 
			
		||||
              },
 | 
			
		||||
            }}
 | 
			
		||||
          />
 | 
			
		||||
        </div>
 | 
			
		||||
      </div>
 | 
			
		||||
      <ScrollArea className="h-full w-full rounded-md border">
 | 
			
		||||
        <Table>
 | 
			
		||||
          <TableHeader className="bg-background top-0 sticky">
 | 
			
		||||
            {table.getHeaderGroups().map((headerGroup) => (
 | 
			
		||||
              <TableRow key={headerGroup.id}>
 | 
			
		||||
                {headerGroup.headers.map((header) => {
 | 
			
		||||
                  return (
 | 
			
		||||
                    <TableHead key={header.id}>
 | 
			
		||||
                      {header.isPlaceholder
 | 
			
		||||
                        ? null
 | 
			
		||||
                        : flexRender(
 | 
			
		||||
                            header.column.columnDef.header,
 | 
			
		||||
                            header.getContext()
 | 
			
		||||
                          )}
 | 
			
		||||
                    </TableHead>
 | 
			
		||||
                  );
 | 
			
		||||
                })}
 | 
			
		||||
              </TableRow>
 | 
			
		||||
            ))}
 | 
			
		||||
          </TableHeader>
 | 
			
		||||
          <TableBody>
 | 
			
		||||
            {table.getRowModel().rows?.length ? (
 | 
			
		||||
              table.getRowModel().rows.map((row) => (
 | 
			
		||||
                <TableRow
 | 
			
		||||
                  key={row.id}
 | 
			
		||||
                  data-state={row.getIsSelected() && "selected"}
 | 
			
		||||
                >
 | 
			
		||||
                  {row.getVisibleCells().map((cell) => (
 | 
			
		||||
                    <TableCell key={cell.id}>
 | 
			
		||||
                      {flexRender(
 | 
			
		||||
                        cell.column.columnDef.cell,
 | 
			
		||||
                        cell.getContext()
 | 
			
		||||
                      )}
 | 
			
		||||
                    </TableCell>
 | 
			
		||||
                  ))}
 | 
			
		||||
                </TableRow>
 | 
			
		||||
              ))
 | 
			
		||||
            ) : (
 | 
			
		||||
              <TableRow>
 | 
			
		||||
                <TableCell
 | 
			
		||||
                  colSpan={columns.length}
 | 
			
		||||
                  className="h-24 text-center"
 | 
			
		||||
                >
 | 
			
		||||
                  No results.
 | 
			
		||||
                </TableCell>
 | 
			
		||||
              </TableRow>
 | 
			
		||||
            )}
 | 
			
		||||
          </TableBody>
 | 
			
		||||
        </Table>
 | 
			
		||||
      </ScrollArea>
 | 
			
		||||
      <div className="flex flex-row items-center justify-end space-x-2 py-4">
 | 
			
		||||
        <div className="flex-1 text-sm text-muted-foreground">
 | 
			
		||||
          {table.getFilteredSelectedRowModel().rows.length} of{" "}
 | 
			
		||||
          {table.getFilteredRowModel().rows.length} row(s) selected.
 | 
			
		||||
        </div>
 | 
			
		||||
        <div className="space-x-2">
 | 
			
		||||
          <Button
 | 
			
		||||
            variant="outline"
 | 
			
		||||
            size="sm"
 | 
			
		||||
            onClick={() => table.previousPage()}
 | 
			
		||||
            disabled={!table.getCanPreviousPage()}
 | 
			
		||||
          >
 | 
			
		||||
            Previous
 | 
			
		||||
          </Button>
 | 
			
		||||
          <Button
 | 
			
		||||
            variant="outline"
 | 
			
		||||
            size="sm"
 | 
			
		||||
            onClick={() => table.nextPage()}
 | 
			
		||||
            disabled={!table.getCanNextPage()}
 | 
			
		||||
          >
 | 
			
		||||
            Next
 | 
			
		||||
          </Button>
 | 
			
		||||
        </div>
 | 
			
		||||
      </div>
 | 
			
		||||
    </div>
 | 
			
		||||
  );
 | 
			
		||||
}
 | 
			
		||||
@ -34,6 +34,10 @@ export function NavbarMenu({ className }: { className?: string }) {
 | 
			
		||||
      name: "API Keys",
 | 
			
		||||
      path: "/api-keys",
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
      name: "Storage",
 | 
			
		||||
      path: "/storage",
 | 
			
		||||
    },
 | 
			
		||||
  ];
 | 
			
		||||
 | 
			
		||||
  return (
 | 
			
		||||
@ -42,9 +46,9 @@ export function NavbarMenu({ className }: { className?: string }) {
 | 
			
		||||
      {isDesktop && (
 | 
			
		||||
        <Tabs
 | 
			
		||||
          defaultValue={pathname}
 | 
			
		||||
          className="w-[300px] flex pointer-events-auto"
 | 
			
		||||
          className="w-[400px] flex pointer-events-auto"
 | 
			
		||||
        >
 | 
			
		||||
          <TabsList className="grid w-full grid-cols-3">
 | 
			
		||||
          <TabsList className="grid w-full grid-cols-4">
 | 
			
		||||
            {pages.map((page) => (
 | 
			
		||||
              <TabsTrigger
 | 
			
		||||
                key={page.name}
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										86
									
								
								web/src/components/custom-form/checkpoint-input.tsx
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										86
									
								
								web/src/components/custom-form/checkpoint-input.tsx
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,86 @@
 | 
			
		||||
// NOTE: this is WIP for doing client side validation for civitai model downloading
 | 
			
		||||
import type { AutoFormInputComponentProps } from "../ui/auto-form/types";
 | 
			
		||||
import { FormControl, FormItem, FormLabel } from "../ui/form";
 | 
			
		||||
import { LoadingIcon } from "@/components/LoadingIcon";
 | 
			
		||||
import * as React from "react";
 | 
			
		||||
import AutoFormInput from "../ui/auto-form/fields/input";
 | 
			
		||||
import { useDebouncedCallback } from "use-debounce";
 | 
			
		||||
import { CivitaiModelResponse } from "@/types/civitai";
 | 
			
		||||
import { z } from "zod";
 | 
			
		||||
import { insertCivitaiCheckpointSchema } from "@/db/schema";
 | 
			
		||||
 | 
			
		||||
function getUrl(civitai_url: string) {
 | 
			
		||||
  // expect to be a URL to be https://civitai.com/models/36520
 | 
			
		||||
  // possiblity with slugged name and query-param modelVersionId
 | 
			
		||||
 | 
			
		||||
  const baseUrl = "https://civitai.com/api/v1/models/";
 | 
			
		||||
  const url = new URL(civitai_url);
 | 
			
		||||
  const pathSegments = url.pathname.split("/");
 | 
			
		||||
  const modelId = pathSegments[pathSegments.indexOf("models") + 1];
 | 
			
		||||
  const modelVersionId = url.searchParams.get("modelVersionId");
 | 
			
		||||
 | 
			
		||||
  return { url: baseUrl + modelId, modelVersionId };
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
export default function AutoFormCheckpointInput(
 | 
			
		||||
  props: AutoFormInputComponentProps
 | 
			
		||||
) {
 | 
			
		||||
  const [loading, setLoading] = React.useState(false);
 | 
			
		||||
  const [modelRes, setModelRes] =
 | 
			
		||||
    React.useState<z.infer<typeof CivitaiModelResponse>>();
 | 
			
		||||
  const [modelVersionid, setModelVersionId] = React.useState<string | null>();
 | 
			
		||||
  const { label, isRequired, fieldProps, zodItem, fieldConfigItem } = props;
 | 
			
		||||
 | 
			
		||||
  const handleSearch = useDebouncedCallback((search) => {
 | 
			
		||||
    const validationResult =
 | 
			
		||||
      insertCivitaiCheckpointSchema.shape.civitai_url.safeParse(search);
 | 
			
		||||
    if (!validationResult.success) {
 | 
			
		||||
      console.error(validationResult.error);
 | 
			
		||||
      // Optionally set an error state here
 | 
			
		||||
      return;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    setLoading(true);
 | 
			
		||||
 | 
			
		||||
    const controller = new AbortController();
 | 
			
		||||
    const { url, modelVersionId: versionId } = getUrl(search);
 | 
			
		||||
    setModelVersionId(versionId);
 | 
			
		||||
    fetch(url, {
 | 
			
		||||
      signal: controller.signal,
 | 
			
		||||
    })
 | 
			
		||||
      .then((x) => x.json())
 | 
			
		||||
      .then((a) => {
 | 
			
		||||
        const res = CivitaiModelResponse.parse(a);
 | 
			
		||||
        console.log(a);
 | 
			
		||||
        console.log(res);
 | 
			
		||||
        setModelRes(res);
 | 
			
		||||
        setLoading(false);
 | 
			
		||||
      });
 | 
			
		||||
 | 
			
		||||
    return () => {
 | 
			
		||||
      controller.abort();
 | 
			
		||||
      setLoading(false);
 | 
			
		||||
    };
 | 
			
		||||
  }, 300);
 | 
			
		||||
 | 
			
		||||
  const modifiedField = {
 | 
			
		||||
    ...fieldProps,
 | 
			
		||||
    // onChange: (event: React.ChangeEvent<HTMLInputElement>) => {
 | 
			
		||||
    //   handleSearch(event.target.value);
 | 
			
		||||
    // },
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  return (
 | 
			
		||||
    <FormItem>
 | 
			
		||||
      {fieldConfigItem.inputProps?.showLabel && (
 | 
			
		||||
        <FormLabel>
 | 
			
		||||
          {label}
 | 
			
		||||
          {isRequired && <span className="text-destructive">*</span>}
 | 
			
		||||
        </FormLabel>
 | 
			
		||||
      )}
 | 
			
		||||
      <FormControl>
 | 
			
		||||
        <AutoFormInput {...props} fieldProps={modifiedField} />
 | 
			
		||||
      </FormControl>
 | 
			
		||||
    </FormItem>
 | 
			
		||||
  );
 | 
			
		||||
}
 | 
			
		||||
@ -1,3 +1,4 @@
 | 
			
		||||
import { CivitaiModelResponse } from "@/types/civitai";
 | 
			
		||||
import { type InferSelectModel, relations } from "drizzle-orm";
 | 
			
		||||
import {
 | 
			
		||||
  boolean,
 | 
			
		||||
@ -92,7 +93,7 @@ export const workflowVersionRelations = relations(
 | 
			
		||||
      fields: [workflowVersionTable.workflow_id],
 | 
			
		||||
      references: [workflowTable.id],
 | 
			
		||||
    }),
 | 
			
		||||
  }),
 | 
			
		||||
  })
 | 
			
		||||
);
 | 
			
		||||
 | 
			
		||||
export const workflowRunStatus = pgEnum("workflow_run_status", [
 | 
			
		||||
@ -141,7 +142,7 @@ export const workflowRunsTable = dbSchema.table("workflow_runs", {
 | 
			
		||||
    () => workflowVersionTable.id,
 | 
			
		||||
    {
 | 
			
		||||
      onDelete: "set null",
 | 
			
		||||
    },
 | 
			
		||||
    }
 | 
			
		||||
  ),
 | 
			
		||||
  workflow_inputs:
 | 
			
		||||
    jsonb("workflow_inputs").$type<Record<string, string | number>>(),
 | 
			
		||||
@ -181,7 +182,7 @@ export const workflowRunRelations = relations(
 | 
			
		||||
      fields: [workflowRunsTable.workflow_id],
 | 
			
		||||
      references: [workflowTable.id],
 | 
			
		||||
    }),
 | 
			
		||||
  }),
 | 
			
		||||
  })
 | 
			
		||||
);
 | 
			
		||||
 | 
			
		||||
// We still want to keep the workflow run record.
 | 
			
		||||
@ -205,7 +206,7 @@ export const workflowOutputRelations = relations(
 | 
			
		||||
      fields: [workflowRunOutputs.run_id],
 | 
			
		||||
      references: [workflowRunsTable.id],
 | 
			
		||||
    }),
 | 
			
		||||
  }),
 | 
			
		||||
  })
 | 
			
		||||
);
 | 
			
		||||
 | 
			
		||||
// when user delete, also delete all the workflow versions
 | 
			
		||||
@ -238,7 +239,7 @@ export const snapshotType = z.object({
 | 
			
		||||
    z.object({
 | 
			
		||||
      hash: z.string(),
 | 
			
		||||
      disabled: z.boolean(),
 | 
			
		||||
    }),
 | 
			
		||||
    })
 | 
			
		||||
  ),
 | 
			
		||||
  file_custom_nodes: z.array(z.any()),
 | 
			
		||||
});
 | 
			
		||||
@ -253,7 +254,7 @@ export const showcaseMedia = z.array(
 | 
			
		||||
  z.object({
 | 
			
		||||
    url: z.string(),
 | 
			
		||||
    isCover: z.boolean().default(false),
 | 
			
		||||
  }),
 | 
			
		||||
  })
 | 
			
		||||
);
 | 
			
		||||
 | 
			
		||||
export const showcaseMediaNullable = z
 | 
			
		||||
@ -261,7 +262,7 @@ export const showcaseMediaNullable = z
 | 
			
		||||
    z.object({
 | 
			
		||||
      url: z.string(),
 | 
			
		||||
      isCover: z.boolean().default(false),
 | 
			
		||||
    }),
 | 
			
		||||
    })
 | 
			
		||||
  )
 | 
			
		||||
  .nullable();
 | 
			
		||||
 | 
			
		||||
@ -363,6 +364,89 @@ export const authRequestsTable = dbSchema.table("auth_requests", {
 | 
			
		||||
  updated_at: timestamp("updated_at").defaultNow().notNull(),
 | 
			
		||||
});
 | 
			
		||||
 | 
			
		||||
export const resourceUpload = pgEnum("resource_upload", [
 | 
			
		||||
  "started",
 | 
			
		||||
  "success",
 | 
			
		||||
  "failed",
 | 
			
		||||
]);
 | 
			
		||||
 | 
			
		||||
export const modelUploadType = pgEnum("model_upload_type", [
 | 
			
		||||
  "civitai",
 | 
			
		||||
  "huggingface",
 | 
			
		||||
  "other",
 | 
			
		||||
]);
 | 
			
		||||
 | 
			
		||||
export const checkpointTable = dbSchema.table("checkpoints", {
 | 
			
		||||
  id: uuid("id").primaryKey().defaultRandom().notNull(),
 | 
			
		||||
  user_id: text("user_id").references(() => usersTable.id, {}), // perhaps a "special" user_id for global checkpoints
 | 
			
		||||
  org_id: text("org_id"),
 | 
			
		||||
  description: text("description"),
 | 
			
		||||
 | 
			
		||||
  checkpoint_volume_id: uuid("checkpoint_volume_id")
 | 
			
		||||
    .notNull()
 | 
			
		||||
    .references(() => checkpointVolumeTable.id, {
 | 
			
		||||
      onDelete: "cascade",
 | 
			
		||||
    })
 | 
			
		||||
    .notNull(),
 | 
			
		||||
 | 
			
		||||
  model_name: text("model_name"),
 | 
			
		||||
  folder_path: text("folder_path"), // in volume
 | 
			
		||||
 | 
			
		||||
  civitai_id: text("civitai_id"),
 | 
			
		||||
  civitai_version_id: text("civitai_version_id"),
 | 
			
		||||
  civitai_url: text("civitai_url"),
 | 
			
		||||
  civitai_download_url: text("civitai_download_url"),
 | 
			
		||||
  civitai_model_response: jsonb("civitai_model_response").$type<
 | 
			
		||||
    z.infer<typeof CivitaiModelResponse>
 | 
			
		||||
  >(),
 | 
			
		||||
 | 
			
		||||
  hf_url: text("hf_url"),
 | 
			
		||||
  s3_url: text("s3_url"),
 | 
			
		||||
  user_url: text("client_url"),
 | 
			
		||||
 | 
			
		||||
  is_public: boolean("is_public").notNull().default(false),
 | 
			
		||||
  status: resourceUpload("status").notNull().default("started"),
 | 
			
		||||
  upload_machine_id: text("upload_machine_id"), // TODO: review if actually needed
 | 
			
		||||
  upload_type: modelUploadType("upload_type").notNull(),
 | 
			
		||||
  error_log: text("error_log"),
 | 
			
		||||
  created_at: timestamp("created_at").defaultNow().notNull(),
 | 
			
		||||
  updated_at: timestamp("updated_at").defaultNow().notNull(),
 | 
			
		||||
});
 | 
			
		||||
 | 
			
		||||
export const checkpointVolumeTable = dbSchema.table("checkpoint_volume", {
 | 
			
		||||
  id: uuid("id").primaryKey().defaultRandom().notNull(),
 | 
			
		||||
  user_id: text("user_id").references(() => usersTable.id, {
 | 
			
		||||
    // onDelete: "cascade",
 | 
			
		||||
  }),
 | 
			
		||||
  org_id: text("org_id"),
 | 
			
		||||
  volume_name: text("volume_name").notNull(),
 | 
			
		||||
  created_at: timestamp("created_at").defaultNow().notNull(),
 | 
			
		||||
  updated_at: timestamp("updated_at").defaultNow().notNull(),
 | 
			
		||||
  disabled: boolean("disabled").default(false).notNull(),
 | 
			
		||||
});
 | 
			
		||||
 | 
			
		||||
export const checkpointRelations = relations(checkpointTable, ({ one }) => ({
 | 
			
		||||
  user: one(usersTable, {
 | 
			
		||||
    fields: [checkpointTable.user_id],
 | 
			
		||||
    references: [usersTable.id],
 | 
			
		||||
  }),
 | 
			
		||||
  volume: one(checkpointVolumeTable, {
 | 
			
		||||
    fields: [checkpointTable.checkpoint_volume_id],
 | 
			
		||||
    references: [checkpointVolumeTable.id],
 | 
			
		||||
  }),
 | 
			
		||||
}));
 | 
			
		||||
 | 
			
		||||
export const checkpointVolumeRelations = relations(
 | 
			
		||||
  checkpointVolumeTable,
 | 
			
		||||
  ({ many, one }) => ({
 | 
			
		||||
    checkpoint: many(checkpointTable),
 | 
			
		||||
    user: one(usersTable, {
 | 
			
		||||
      fields: [checkpointVolumeTable.user_id],
 | 
			
		||||
      references: [usersTable.id],
 | 
			
		||||
    }),
 | 
			
		||||
  })
 | 
			
		||||
);
 | 
			
		||||
 | 
			
		||||
export const subscriptionPlan = pgEnum("subscription_plan", [
 | 
			
		||||
  "basic",
 | 
			
		||||
  "pro",
 | 
			
		||||
@ -389,9 +473,26 @@ export const subscriptionStatusTable = dbSchema.table("subscription_status", {
 | 
			
		||||
  updated_at: timestamp("updated_at").defaultNow().notNull(),
 | 
			
		||||
});
 | 
			
		||||
 | 
			
		||||
export const insertCivitaiCheckpointSchema = createInsertSchema(
 | 
			
		||||
  checkpointTable,
 | 
			
		||||
  {
 | 
			
		||||
    civitai_url: (schema) =>
 | 
			
		||||
      schema.civitai_url
 | 
			
		||||
        .trim()
 | 
			
		||||
        .url({ message: "URL required" })
 | 
			
		||||
        .includes("civitai.com/models", {
 | 
			
		||||
          message: "civitai.com/models link required",
 | 
			
		||||
        }),
 | 
			
		||||
  }
 | 
			
		||||
);
 | 
			
		||||
 | 
			
		||||
export type UserType = InferSelectModel<typeof usersTable>;
 | 
			
		||||
export type WorkflowType = InferSelectModel<typeof workflowTable>;
 | 
			
		||||
export type MachineType = InferSelectModel<typeof machinesTable>;
 | 
			
		||||
export type WorkflowVersionType = InferSelectModel<typeof workflowVersionTable>;
 | 
			
		||||
export type DeploymentType = InferSelectModel<typeof deploymentsTable>;
 | 
			
		||||
export type CheckpointType = InferSelectModel<typeof checkpointTable>;
 | 
			
		||||
export type CheckpointVolumeType = InferSelectModel<
 | 
			
		||||
  typeof checkpointVolumeTable
 | 
			
		||||
>;
 | 
			
		||||
export type UserUsageType = InferSelectModel<typeof userUsageTable>;
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										5
									
								
								web/src/server/addCheckpointSchema.tsx
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								web/src/server/addCheckpointSchema.tsx
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,5 @@
 | 
			
		||||
import { insertCivitaiCheckpointSchema } from "@/db/schema";
 | 
			
		||||
 | 
			
		||||
export const addCivitaiCheckpointSchema = insertCivitaiCheckpointSchema.pick({
 | 
			
		||||
  civitai_url: true,
 | 
			
		||||
});
 | 
			
		||||
							
								
								
									
										271
									
								
								web/src/server/curdCheckpoint.ts
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										271
									
								
								web/src/server/curdCheckpoint.ts
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,271 @@
 | 
			
		||||
"use server";
 | 
			
		||||
 | 
			
		||||
import { auth } from "@clerk/nextjs";
 | 
			
		||||
import {
 | 
			
		||||
  checkpointTable,
 | 
			
		||||
  CheckpointType,
 | 
			
		||||
  checkpointVolumeTable,
 | 
			
		||||
  CheckpointVolumeType,
 | 
			
		||||
} from "@/db/schema";
 | 
			
		||||
import { withServerPromise } from "./withServerPromise";
 | 
			
		||||
import { db } from "@/db/db";
 | 
			
		||||
import type { z } from "zod";
 | 
			
		||||
import { headers } from "next/headers";
 | 
			
		||||
import { addCivitaiCheckpointSchema } from "./addCheckpointSchema";
 | 
			
		||||
import { and, eq, isNull } from "drizzle-orm";
 | 
			
		||||
import { CivitaiModelResponse } from "@/types/civitai";
 | 
			
		||||
 | 
			
		||||
export async function getCheckpoints() {
 | 
			
		||||
  const { userId, orgId } = auth();
 | 
			
		||||
  if (!userId) throw new Error("No user id");
 | 
			
		||||
  const checkpoints = await db
 | 
			
		||||
    .select()
 | 
			
		||||
    .from(checkpointTable)
 | 
			
		||||
    .where(
 | 
			
		||||
      orgId
 | 
			
		||||
        ? eq(checkpointTable.org_id, orgId)
 | 
			
		||||
        // make sure org_id is null
 | 
			
		||||
        : and(
 | 
			
		||||
          eq(checkpointTable.user_id, userId),
 | 
			
		||||
          isNull(checkpointTable.org_id),
 | 
			
		||||
        ),
 | 
			
		||||
    );
 | 
			
		||||
  return checkpoints;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
export async function getCheckpointById(id: string) {
 | 
			
		||||
  const { userId, orgId } = auth();
 | 
			
		||||
  if (!userId) throw new Error("No user id");
 | 
			
		||||
  const checkpoint = await db
 | 
			
		||||
    .select()
 | 
			
		||||
    .from(checkpointTable)
 | 
			
		||||
    .where(
 | 
			
		||||
      and(
 | 
			
		||||
        orgId ? eq(checkpointTable.org_id, orgId) : and(
 | 
			
		||||
          eq(checkpointTable.user_id, userId),
 | 
			
		||||
          isNull(checkpointTable.org_id),
 | 
			
		||||
        ),
 | 
			
		||||
        eq(checkpointTable.id, id),
 | 
			
		||||
      ),
 | 
			
		||||
    );
 | 
			
		||||
  return checkpoint[0];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
export async function getCheckpointVolumes() {
 | 
			
		||||
  const { userId, orgId } = auth();
 | 
			
		||||
  if (!userId) throw new Error("No user id");
 | 
			
		||||
  const volume = await db
 | 
			
		||||
    .select()
 | 
			
		||||
    .from(checkpointVolumeTable)
 | 
			
		||||
    .where(
 | 
			
		||||
      and(
 | 
			
		||||
        orgId
 | 
			
		||||
          ? eq(checkpointVolumeTable.org_id, orgId)
 | 
			
		||||
          // make sure org_id is null
 | 
			
		||||
          : and(
 | 
			
		||||
            eq(checkpointVolumeTable.user_id, userId),
 | 
			
		||||
            isNull(checkpointVolumeTable.org_id),
 | 
			
		||||
          ),
 | 
			
		||||
        eq(checkpointVolumeTable.disabled, false),
 | 
			
		||||
      ),
 | 
			
		||||
    );
 | 
			
		||||
  return volume;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
export async function retrieveCheckpointVolumes() {
 | 
			
		||||
  let volumes = await getCheckpointVolumes();
 | 
			
		||||
  if (volumes.length === 0) {
 | 
			
		||||
    // create volume if not already created
 | 
			
		||||
    volumes = await addCheckpointVolume();
 | 
			
		||||
  }
 | 
			
		||||
  return volumes;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
export async function addCheckpointVolume() {
 | 
			
		||||
  const { userId, orgId } = auth();
 | 
			
		||||
  if (!userId) throw new Error("No user id");
 | 
			
		||||
 | 
			
		||||
  // Insert the new checkpointVolume into the checkpointVolumeTable
 | 
			
		||||
  const insertedVolume = await db
 | 
			
		||||
    .insert(checkpointVolumeTable)
 | 
			
		||||
    .values({
 | 
			
		||||
      user_id: userId,
 | 
			
		||||
      org_id: orgId,
 | 
			
		||||
      volume_name: `checkpoints_${userId}`,
 | 
			
		||||
      // created_at and updated_at will be set to current timestamp by default
 | 
			
		||||
      disabled: false, // Default value
 | 
			
		||||
    })
 | 
			
		||||
    .returning(); // Returns the inserted row
 | 
			
		||||
  return insertedVolume;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
function getUrl(civitai_url: string) {
 | 
			
		||||
  // expect to be a URL to be https://civitai.com/models/36520
 | 
			
		||||
  // possiblity with slugged name and query-param modelVersionId
 | 
			
		||||
  const baseUrl = "https://civitai.com/api/v1/models/";
 | 
			
		||||
  const url = new URL(civitai_url);
 | 
			
		||||
  const pathSegments = url.pathname.split("/");
 | 
			
		||||
  const modelId = pathSegments[pathSegments.indexOf("models") + 1];
 | 
			
		||||
  const modelVersionId = url.searchParams.get("modelVersionId");
 | 
			
		||||
 | 
			
		||||
  return { url: baseUrl + modelId, modelVersionId };
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
export const addCivitaiCheckpoint = withServerPromise(
 | 
			
		||||
  async (data: z.infer<typeof addCivitaiCheckpointSchema>) => {
 | 
			
		||||
    const { userId, orgId } = auth();
 | 
			
		||||
 | 
			
		||||
    if (!data.civitai_url) return { error: "no civitai_url" };
 | 
			
		||||
    if (!userId) return { error: "No user id" };
 | 
			
		||||
 | 
			
		||||
    const { url, modelVersionId } = getUrl(data?.civitai_url);
 | 
			
		||||
    const civitaiModelRes = await fetch(url)
 | 
			
		||||
      .then((x) => x.json())
 | 
			
		||||
      .then((a) => {
 | 
			
		||||
        console.log(a);
 | 
			
		||||
        return CivitaiModelResponse.parse(a);
 | 
			
		||||
      });
 | 
			
		||||
 | 
			
		||||
    if (civitaiModelRes?.modelVersions?.length === 0) {
 | 
			
		||||
      return; // no versions to download
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    let selectedModelVersion;
 | 
			
		||||
    let selectedModelVersionId: string | null = modelVersionId;
 | 
			
		||||
    if (!selectedModelVersionId) {
 | 
			
		||||
      selectedModelVersion = civitaiModelRes.modelVersions[0];
 | 
			
		||||
      selectedModelVersionId = civitaiModelRes.modelVersions[0].id.toString();
 | 
			
		||||
    } else {
 | 
			
		||||
      selectedModelVersion = civitaiModelRes.modelVersions.find((version) =>
 | 
			
		||||
        version.id.toString() === selectedModelVersionId
 | 
			
		||||
      );
 | 
			
		||||
      if (!selectedModelVersion) {
 | 
			
		||||
        return; // version id is wrong
 | 
			
		||||
      }
 | 
			
		||||
      selectedModelVersionId = selectedModelVersion?.id.toString();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    const checkpointVolumes = await getCheckpointVolumes();
 | 
			
		||||
    let cVolume;
 | 
			
		||||
    if (checkpointVolumes.length === 0) {
 | 
			
		||||
      const volume = await addCheckpointVolume();
 | 
			
		||||
      cVolume = volume[0];
 | 
			
		||||
    } else {
 | 
			
		||||
      cVolume = checkpointVolumes[0];
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    const a = await db
 | 
			
		||||
      .insert(checkpointTable)
 | 
			
		||||
      .values({
 | 
			
		||||
        user_id: userId,
 | 
			
		||||
        org_id: orgId,
 | 
			
		||||
        upload_type: "civitai",
 | 
			
		||||
        model_name: selectedModelVersion.files[0].name,
 | 
			
		||||
        civitai_id: civitaiModelRes.id.toString(),
 | 
			
		||||
        civitai_version_id: selectedModelVersionId,
 | 
			
		||||
        civitai_url: data.civitai_url,
 | 
			
		||||
        civitai_download_url: selectedModelVersion.files[0].downloadUrl,
 | 
			
		||||
        civitai_model_response: civitaiModelRes,
 | 
			
		||||
        checkpoint_volume_id: cVolume.id,
 | 
			
		||||
        updated_at: new Date(),
 | 
			
		||||
      })
 | 
			
		||||
      .returning();
 | 
			
		||||
 | 
			
		||||
    const b = a[0];
 | 
			
		||||
 | 
			
		||||
    await uploadCheckpoint(data, b, cVolume);
 | 
			
		||||
    // redirect(`/checkpoints/${b.id}`);
 | 
			
		||||
  },
 | 
			
		||||
);
 | 
			
		||||
 | 
			
		||||
// export const redownloadCheckpoint = withServerPromise(
 | 
			
		||||
//   async (data: CheckpointItemList) => {
 | 
			
		||||
//     const { userId } = auth();
 | 
			
		||||
//     if (!userId) return { error: "No user id" };
 | 
			
		||||
//
 | 
			
		||||
//     const checkpointVolumes = await getCheckpointVolumes();
 | 
			
		||||
//     let cVolume;
 | 
			
		||||
//     if (checkpointVolumes.length === 0) {
 | 
			
		||||
//       const volume = await addCheckpointVolume();
 | 
			
		||||
//       cVolume = volume[0];
 | 
			
		||||
//     } else {
 | 
			
		||||
//       cVolume = checkpointVolumes[0];
 | 
			
		||||
//     }
 | 
			
		||||
//
 | 
			
		||||
//     console.log("data");
 | 
			
		||||
//     console.log(data);
 | 
			
		||||
//
 | 
			
		||||
//     const a = await db
 | 
			
		||||
//       .update(checkpointTable)
 | 
			
		||||
//       .set({
 | 
			
		||||
//         // status: "started",
 | 
			
		||||
//         // updated_at: new Date(),
 | 
			
		||||
//       })
 | 
			
		||||
//       .returning();
 | 
			
		||||
//
 | 
			
		||||
//     const b = a[0];
 | 
			
		||||
//
 | 
			
		||||
//     console.log("b");
 | 
			
		||||
//     console.log(b);
 | 
			
		||||
//
 | 
			
		||||
//     await uploadCheckpoint(data, b, cVolume);
 | 
			
		||||
//     // redirect(`/checkpoints/${b.id}`);
 | 
			
		||||
//   },
 | 
			
		||||
// );
 | 
			
		||||
 | 
			
		||||
async function uploadCheckpoint(
 | 
			
		||||
  data: z.infer<typeof addCivitaiCheckpointSchema>,
 | 
			
		||||
  c: CheckpointType,
 | 
			
		||||
  v: CheckpointVolumeType,
 | 
			
		||||
) {
 | 
			
		||||
  const headersList = headers();
 | 
			
		||||
 | 
			
		||||
  const domain = headersList.get("x-forwarded-host") || "";
 | 
			
		||||
  const protocol = headersList.get("x-forwarded-proto") || "";
 | 
			
		||||
 | 
			
		||||
  if (domain === "") {
 | 
			
		||||
    throw new Error("No domain");
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Call remote builder
 | 
			
		||||
  const result = await fetch(
 | 
			
		||||
    `${process.env.MODAL_BUILDER_URL!}/upload-volume`,
 | 
			
		||||
    {
 | 
			
		||||
      method: "POST",
 | 
			
		||||
      headers: {
 | 
			
		||||
        "Content-Type": "application/json",
 | 
			
		||||
      },
 | 
			
		||||
      body: JSON.stringify({
 | 
			
		||||
        download_url: c.civitai_download_url,
 | 
			
		||||
        volume_name: v.volume_name,
 | 
			
		||||
        volume_id: v.id,
 | 
			
		||||
        checkpoint_id: c.id,
 | 
			
		||||
        callback_url: `${protocol}://${domain}/api/volume-upload`,
 | 
			
		||||
        upload_type: "checkpoint"
 | 
			
		||||
      }),
 | 
			
		||||
    },
 | 
			
		||||
  );
 | 
			
		||||
 | 
			
		||||
  if (!result.ok) {
 | 
			
		||||
    const error_log = await result.text();
 | 
			
		||||
    await db
 | 
			
		||||
      .update(checkpointTable)
 | 
			
		||||
      .set({
 | 
			
		||||
        ...data,
 | 
			
		||||
        status: "failed",
 | 
			
		||||
        error_log: error_log,
 | 
			
		||||
      })
 | 
			
		||||
      .where(eq(checkpointTable.id, c.id));
 | 
			
		||||
    throw new Error(`Error: ${result.statusText} ${error_log}`);
 | 
			
		||||
  } else {
 | 
			
		||||
    // setting the build machine id
 | 
			
		||||
    const json = await result.json();
 | 
			
		||||
    await db
 | 
			
		||||
      .update(checkpointTable)
 | 
			
		||||
      .set({
 | 
			
		||||
        ...data,
 | 
			
		||||
        upload_machine_id: json.build_machine_instance_id,
 | 
			
		||||
      })
 | 
			
		||||
      .where(eq(checkpointTable.id, c.id));
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
@ -15,6 +15,7 @@ import { headers } from "next/headers";
 | 
			
		||||
import { redirect } from "next/navigation";
 | 
			
		||||
import "server-only";
 | 
			
		||||
import type { z } from "zod";
 | 
			
		||||
import { retrieveCheckpointVolumes } from "./curdCheckpoint";
 | 
			
		||||
 | 
			
		||||
export async function getMachines() {
 | 
			
		||||
  const { userId, orgId } = auth();
 | 
			
		||||
@ -189,6 +190,7 @@ async function _buildMachine(
 | 
			
		||||
    throw new Error("No domain");
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  const volumes = await retrieveCheckpointVolumes();
 | 
			
		||||
  // Call remote builder
 | 
			
		||||
  const result = await fetch(`${process.env.MODAL_BUILDER_URL!}/create`, {
 | 
			
		||||
    method: "POST",
 | 
			
		||||
@ -202,6 +204,7 @@ async function _buildMachine(
 | 
			
		||||
      callback_url: `${protocol}://${domain}/api/machine-built`,
 | 
			
		||||
      models: data.models, //JSON.parse(data.models as string),
 | 
			
		||||
      gpu: data.gpu && data.gpu.length > 0 ? data.gpu : "T4",
 | 
			
		||||
      checkpoint_volume_name: volumes[0].volume_name,
 | 
			
		||||
    }),
 | 
			
		||||
  });
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										41
									
								
								web/src/server/getAllUserCheckpoints.tsx
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										41
									
								
								web/src/server/getAllUserCheckpoints.tsx
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,41 @@
 | 
			
		||||
import { db } from "@/db/db";
 | 
			
		||||
import {
 | 
			
		||||
  checkpointTable,
 | 
			
		||||
} from "@/db/schema";
 | 
			
		||||
import { auth } from "@clerk/nextjs";
 | 
			
		||||
import { and, desc, eq, isNull } from "drizzle-orm";
 | 
			
		||||
 | 
			
		||||
export async function getAllUserCheckpoints() {
 | 
			
		||||
  const { userId, orgId } = await auth();
 | 
			
		||||
 | 
			
		||||
  if (!userId) {
 | 
			
		||||
    return null;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  const checkpoints = await db.query.checkpointTable.findMany({
 | 
			
		||||
    with: {
 | 
			
		||||
      user: {
 | 
			
		||||
        columns: {
 | 
			
		||||
          name: true,
 | 
			
		||||
        },
 | 
			
		||||
      },
 | 
			
		||||
    },
 | 
			
		||||
    columns: {
 | 
			
		||||
      id: true,
 | 
			
		||||
      updated_at: true,
 | 
			
		||||
      model_name: true,
 | 
			
		||||
      civitai_url: true,
 | 
			
		||||
      civitai_model_response: true,
 | 
			
		||||
      is_public: true,
 | 
			
		||||
      upload_type: true,
 | 
			
		||||
      status: true,
 | 
			
		||||
    },
 | 
			
		||||
    orderBy: desc(checkpointTable.updated_at),
 | 
			
		||||
    where: 
 | 
			
		||||
      orgId != undefined
 | 
			
		||||
        ? eq(checkpointTable.org_id, orgId)
 | 
			
		||||
        : and(eq(checkpointTable.user_id, userId), isNull(checkpointTable.org_id)),
 | 
			
		||||
  });
 | 
			
		||||
 | 
			
		||||
  return checkpoints;
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										129
									
								
								web/src/types/civitai.ts
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										129
									
								
								web/src/types/civitai.ts
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,129 @@
 | 
			
		||||
import { z } from "zod";
 | 
			
		||||
 | 
			
		||||
// from chatgpt https://chat.openai.com/share/4985d20b-30b1-4a28-87f6-6ebf84a1040e
 | 
			
		||||
 | 
			
		||||
export const creatorSchema = z.object({
 | 
			
		||||
  username: z.string().nullish(),
 | 
			
		||||
  image: z.string().url().nullish(),
 | 
			
		||||
});
 | 
			
		||||
 | 
			
		||||
export const fileMetadataSchema = z.object({
 | 
			
		||||
  fp: z.string().nullish(),
 | 
			
		||||
  size: z.string().nullish(),
 | 
			
		||||
  format: z.string().nullish(),
 | 
			
		||||
});
 | 
			
		||||
 | 
			
		||||
export const fileSchema = z.object({
 | 
			
		||||
  id: z.number(),
 | 
			
		||||
  sizeKB: z.number().nullish(),
 | 
			
		||||
  name: z.string(),
 | 
			
		||||
  type: z.string().nullish(),
 | 
			
		||||
  metadata: fileMetadataSchema.nullish(),
 | 
			
		||||
  pickleScanResult: z.string().nullish(),
 | 
			
		||||
  pickleScanMessage: z.string().nullable(),
 | 
			
		||||
  virusScanResult: z.string().nullish(),
 | 
			
		||||
  virusScanMessage: z.string().nullable(),
 | 
			
		||||
  scannedAt: z.string().nullish(),
 | 
			
		||||
  hashes: z.record(z.string()).nullish(),
 | 
			
		||||
  downloadUrl: z.string().url(),
 | 
			
		||||
  primary: z.boolean().nullish(),
 | 
			
		||||
});
 | 
			
		||||
 | 
			
		||||
export const imageMetadataSchema = z.object({
 | 
			
		||||
  hash: z.string(),
 | 
			
		||||
  width: z.number(),
 | 
			
		||||
  height: z.number(),
 | 
			
		||||
});
 | 
			
		||||
 | 
			
		||||
export const imageMetaSchema = z.object({
 | 
			
		||||
  ENSD: z.string().nullish(),
 | 
			
		||||
  Size: z.string().nullish(),
 | 
			
		||||
  seed: z.number().nullish(),
 | 
			
		||||
  Model: z.string().nullish(),
 | 
			
		||||
  steps: z.number().nullish(),
 | 
			
		||||
  hashes: z.record(z.string()).nullish(),
 | 
			
		||||
  prompt: z.string().nullish(),
 | 
			
		||||
  sampler: z.string().nullish(),
 | 
			
		||||
  cfgScale: z.number().nullish(),
 | 
			
		||||
  ClipSkip: z.number().nullish(),
 | 
			
		||||
  resources: z.array(
 | 
			
		||||
    z.object({
 | 
			
		||||
      hash: z.string().nullish(),
 | 
			
		||||
      name: z.string(),
 | 
			
		||||
      type: z.string(),
 | 
			
		||||
      weight: z.number().nullish(),
 | 
			
		||||
    }),
 | 
			
		||||
  ).nullish(),
 | 
			
		||||
  ModelHash: z.string().nullish(),
 | 
			
		||||
  HiresSteps: z.string().nullish(),
 | 
			
		||||
  HiresUpscale: z.string().nullish(),
 | 
			
		||||
  HiresUpscaler: z.string().nullish(),
 | 
			
		||||
  negativePrompt: z.string(),
 | 
			
		||||
  DenoisingStrength: z.number().nullish(),
 | 
			
		||||
});
 | 
			
		||||
 | 
			
		||||
// NOTE: this definition is all over the place
 | 
			
		||||
// export const imageSchema = z.object({
 | 
			
		||||
//   url: z.string().url().nullish(),
 | 
			
		||||
//   nsfw: z.enum(["None", "Soft", "Mature"]).nullish(),
 | 
			
		||||
//   width: z.number().nullish(),
 | 
			
		||||
//   height: z.number().nullish(),
 | 
			
		||||
//   hash: z.string().nullish(),
 | 
			
		||||
//   type: z.string().nullish(),
 | 
			
		||||
//   metadata: imageMetadataSchema.nullish(),
 | 
			
		||||
//   meta: imageMetaSchema.nullish(),
 | 
			
		||||
// });
 | 
			
		||||
 | 
			
		||||
export const modelVersionSchema = z.object({
 | 
			
		||||
  id: z.number(),
 | 
			
		||||
  modelId: z.number(),
 | 
			
		||||
  name: z.string(),
 | 
			
		||||
  createdAt: z.string().nullish(),
 | 
			
		||||
  updatedAt: z.string().nullish(),
 | 
			
		||||
  // status: z.enum(["Published", "Unpublished"]).nullish(),
 | 
			
		||||
  status: z.string().nullish(),
 | 
			
		||||
  publishedAt: z.string().nullish(),
 | 
			
		||||
  trainedWords: z.array(z.string()).nullable(),
 | 
			
		||||
  trainingStatus: z.string().nullable(),
 | 
			
		||||
  trainingDetails: z.string().nullable(),
 | 
			
		||||
  baseModel: z.string().nullish(),
 | 
			
		||||
  baseModelType: z.string().nullish(),
 | 
			
		||||
  earlyAccessTimeFrame: z.number().nullish(),
 | 
			
		||||
  description: z.string().nullable(),
 | 
			
		||||
  vaeId: z.number().nullable(),
 | 
			
		||||
  stats: z.object({
 | 
			
		||||
    downloadCount: z.number(),
 | 
			
		||||
    ratingCount: z.number(),
 | 
			
		||||
    rating: z.number(),
 | 
			
		||||
  }).nullish(),
 | 
			
		||||
  files: z.array(fileSchema),
 | 
			
		||||
  images: z.array(z.any()).nullish(),
 | 
			
		||||
  downloadUrl: z.string().url(),
 | 
			
		||||
});
 | 
			
		||||
 | 
			
		||||
export const statsSchema = z.object({
 | 
			
		||||
  downloadCount: z.number(),
 | 
			
		||||
  favoriteCount: z.number(),
 | 
			
		||||
  commentCount: z.number(),
 | 
			
		||||
  ratingCount: z.number(),
 | 
			
		||||
  rating: z.number(),
 | 
			
		||||
  tippedAmountCount: z.number(),
 | 
			
		||||
});
 | 
			
		||||
 | 
			
		||||
export const CivitaiModelResponse = z.object({
 | 
			
		||||
  id: z.number(),
 | 
			
		||||
  name: z.string().nullish(),
 | 
			
		||||
  description: z.string().nullish(),
 | 
			
		||||
  // type: z.enum(["Checkpoint", "Lora"]), // TODO: this will be important to know
 | 
			
		||||
  type: z.string(),
 | 
			
		||||
  poi: z.boolean().nullish(),
 | 
			
		||||
  nsfw: z.boolean().nullish(),
 | 
			
		||||
  allowNoCredit: z.boolean().nullish(),
 | 
			
		||||
  allowCommercialUse: z.string().nullish(),
 | 
			
		||||
  allowDerivatives: z.boolean().nullish(),
 | 
			
		||||
  allowDifferentLicense: z.boolean().nullish(),
 | 
			
		||||
  stats: statsSchema.nullish(),
 | 
			
		||||
  creator: creatorSchema.nullish(),
 | 
			
		||||
  tags: z.array(z.string()).nullish(),
 | 
			
		||||
  modelVersions: z.array(modelVersionSchema),
 | 
			
		||||
});
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user