blah
This commit is contained in:
		
							parent
							
								
									6437de4def
								
							
						
					
					
						commit
						b1e9bcc4e6
					
				@ -8,6 +8,7 @@ from enum import Enum
 | 
				
			|||||||
import json
 | 
					import json
 | 
				
			||||||
import subprocess
 | 
					import subprocess
 | 
				
			||||||
import time
 | 
					import time
 | 
				
			||||||
 | 
					from uuid import uuid4
 | 
				
			||||||
from contextlib import asynccontextmanager
 | 
					from contextlib import asynccontextmanager
 | 
				
			||||||
import asyncio
 | 
					import asyncio
 | 
				
			||||||
import threading
 | 
					import threading
 | 
				
			||||||
@ -19,6 +20,7 @@ from urllib.parse import parse_qs
 | 
				
			|||||||
from starlette.middleware.base import BaseHTTPMiddleware
 | 
					from starlette.middleware.base import BaseHTTPMiddleware
 | 
				
			||||||
from starlette.types import ASGIApp, Scope, Receive, Send
 | 
					from starlette.types import ASGIApp, Scope, Receive, Send
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from concurrent.futures import ThreadPoolExecutor
 | 
					from concurrent.futures import ThreadPoolExecutor
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# executor = ThreadPoolExecutor(max_workers=5)
 | 
					# executor = ThreadPoolExecutor(max_workers=5)
 | 
				
			||||||
@ -174,6 +176,7 @@ class Item(BaseModel):
 | 
				
			|||||||
    snapshot: Snapshot
 | 
					    snapshot: Snapshot
 | 
				
			||||||
    models: List[Model]
 | 
					    models: List[Model]
 | 
				
			||||||
    callback_url: str
 | 
					    callback_url: str
 | 
				
			||||||
 | 
					    checkpoint_volume_name: str
 | 
				
			||||||
    gpu: GPUType = Field(default=GPUType.T4)
 | 
					    gpu: GPUType = Field(default=GPUType.T4)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @field_validator('gpu')
 | 
					    @field_validator('gpu')
 | 
				
			||||||
@ -223,6 +226,102 @@ async def websocket_endpoint(websocket: WebSocket, machine_id: str):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
#     return {"Hello": "World"}
 | 
					#     return {"Hello": "World"}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class UploadType(str, Enum):
 | 
				
			||||||
 | 
					    checkpoint = "checkpoint"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class UploadBody(BaseModel):
 | 
				
			||||||
 | 
					    download_url: str
 | 
				
			||||||
 | 
					    volume_name: str
 | 
				
			||||||
 | 
					    volume_id: str
 | 
				
			||||||
 | 
					    checkpoint_id: str
 | 
				
			||||||
 | 
					    upload_type: UploadType
 | 
				
			||||||
 | 
					    callback_url: str
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					UPLOAD_TYPE_DIR_MAP = {
 | 
				
			||||||
 | 
					    UploadType.checkpoint: "checkpoints"
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@app.post("/upload-volume")
 | 
				
			||||||
 | 
					async def upload_checkpoint(body: UploadBody):
 | 
				
			||||||
 | 
					    global last_activity_time
 | 
				
			||||||
 | 
					    last_activity_time = time.time()
 | 
				
			||||||
 | 
					    logger.info(f"Extended inactivity time to {global_timeout}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    asyncio.create_task(upload_logic(body))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # check that this
 | 
				
			||||||
 | 
					    return JSONResponse(status_code=200, content={"message": "Volume uploading", "build_machine_instance_id": fly_instance_id})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					async def upload_logic(body: UploadBody):
 | 
				
			||||||
 | 
					    folder_path = f"/app/builds/{body.volume_id}"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    cp_process = await asyncio.subprocess.create_subprocess_exec("cp", "-r", "/app/src/volume-builder", folder_path)
 | 
				
			||||||
 | 
					    await cp_process.wait()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    upload_path = UPLOAD_TYPE_DIR_MAP[body.upload_type]
 | 
				
			||||||
 | 
					    config = {
 | 
				
			||||||
 | 
					        "volume_names": {
 | 
				
			||||||
 | 
					            body.volume_name: {"download_url": body.download_url, "folder_path": upload_path}
 | 
				
			||||||
 | 
					        }, 
 | 
				
			||||||
 | 
					        "volume_paths": {
 | 
				
			||||||
 | 
					            body.volume_name: f'/volumes/{uuid4()}'
 | 
				
			||||||
 | 
					        }, 
 | 
				
			||||||
 | 
					        "callback_url": body.callback_url,
 | 
				
			||||||
 | 
					        "callback_body": {
 | 
				
			||||||
 | 
					            "checkpoint_id": body.checkpoint_id,
 | 
				
			||||||
 | 
					            "volume_id": body.volume_id,
 | 
				
			||||||
 | 
					            "folder_path": upload_path,
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    with open(f"{folder_path}/config.py", "w") as f:
 | 
				
			||||||
 | 
					        f.write("config = " + json.dumps(config))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    process = await asyncio.subprocess.create_subprocess_shell(
 | 
				
			||||||
 | 
					        f"modal run app.py",
 | 
				
			||||||
 | 
					        # stdout=asyncio.subprocess.PIPE,
 | 
				
			||||||
 | 
					        # stderr=asyncio.subprocess.PIPE,
 | 
				
			||||||
 | 
					        cwd=folder_path,
 | 
				
			||||||
 | 
					        env={**os.environ, "COLUMNS": "10000"}
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					    # error_logs = []
 | 
				
			||||||
 | 
					    # async def read_stream(stream):
 | 
				
			||||||
 | 
					    #     while True:
 | 
				
			||||||
 | 
					    #         line = await stream.readline()
 | 
				
			||||||
 | 
					    #         if line:
 | 
				
			||||||
 | 
					    #             l = line.decode('utf-8').strip()
 | 
				
			||||||
 | 
					    #             error_logs.append(l)
 | 
				
			||||||
 | 
					    #             logger.error(l)
 | 
				
			||||||
 | 
					    #             error_logs.append({
 | 
				
			||||||
 | 
					    #                 "logs": l,
 | 
				
			||||||
 | 
					    #                 "timestamp": time.time()
 | 
				
			||||||
 | 
					    #             })
 | 
				
			||||||
 | 
					    #         else:
 | 
				
			||||||
 | 
					    #             break
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # stderr_read_task = asyncio.create_task(read_stream(process.stderr))
 | 
				
			||||||
 | 
					    #
 | 
				
			||||||
 | 
					    # await asyncio.wait([stderr_read_task])
 | 
				
			||||||
 | 
					    # await process.wait()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # if process.returncode != 0:
 | 
				
			||||||
 | 
					    #     error_logs.append({"logs": "Unable to upload volume.", "timestamp": time.time()})
 | 
				
			||||||
 | 
					    #     # Error handling: send POST request to callback URL with error details
 | 
				
			||||||
 | 
					    #     requests.post(body.callback_url, json={
 | 
				
			||||||
 | 
					    #         "volume_id": body.volume_id, 
 | 
				
			||||||
 | 
					    #         "checkpoint_id": body.checkpoint_id,
 | 
				
			||||||
 | 
					    #         "folder_path": upload_path,
 | 
				
			||||||
 | 
					    #         "error_logs": json.dumps(error_logs),
 | 
				
			||||||
 | 
					    #         "status": "failed"
 | 
				
			||||||
 | 
					    #     })
 | 
				
			||||||
 | 
					    #
 | 
				
			||||||
 | 
					    # requests.post(body.callback_url, json={
 | 
				
			||||||
 | 
					    #     "checkpoint_id": body.checkpoint_id,
 | 
				
			||||||
 | 
					    #     "volume_id": body.volume_id,
 | 
				
			||||||
 | 
					    #     "folder_path": upload_path,
 | 
				
			||||||
 | 
					    #     "status": "success"
 | 
				
			||||||
 | 
					    # })
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@app.post("/create")
 | 
					@app.post("/create")
 | 
				
			||||||
async def create_machine(item: Item):
 | 
					async def create_machine(item: Item):
 | 
				
			||||||
@ -312,7 +411,9 @@ async def build_logic(item: Item):
 | 
				
			|||||||
    config = {
 | 
					    config = {
 | 
				
			||||||
        "name": item.name,
 | 
					        "name": item.name,
 | 
				
			||||||
        "deploy_test": os.environ.get("DEPLOY_TEST_FLAG", "False"),
 | 
					        "deploy_test": os.environ.get("DEPLOY_TEST_FLAG", "False"),
 | 
				
			||||||
        "gpu": item.gpu
 | 
					        "gpu": item.gpu,
 | 
				
			||||||
 | 
					        "public_checkpoint_volume": "model-store",
 | 
				
			||||||
 | 
					        "private_checkpoint_volume": item.checkpoint_volume_name
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    with open(f"{folder_path}/config.py", "w") as f:
 | 
					    with open(f"{folder_path}/config.py", "w") as f:
 | 
				
			||||||
        f.write("config = " + json.dumps(config))
 | 
					        f.write("config = " + json.dumps(config))
 | 
				
			||||||
 | 
				
			|||||||
@ -7,6 +7,7 @@ import urllib.parse
 | 
				
			|||||||
from pydantic import BaseModel
 | 
					from pydantic import BaseModel
 | 
				
			||||||
from fastapi import FastAPI, Request
 | 
					from fastapi import FastAPI, Request
 | 
				
			||||||
from fastapi.responses import HTMLResponse
 | 
					from fastapi.responses import HTMLResponse
 | 
				
			||||||
 | 
					from volume_setup import volumes
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# deploy_test = False
 | 
					# deploy_test = False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -27,6 +28,7 @@ deploy_test = config["deploy_test"] == "True"
 | 
				
			|||||||
web_app = FastAPI()
 | 
					web_app = FastAPI()
 | 
				
			||||||
print(config)
 | 
					print(config)
 | 
				
			||||||
print("deploy_test ", deploy_test)
 | 
					print("deploy_test ", deploy_test)
 | 
				
			||||||
 | 
					print('volumes', volumes)
 | 
				
			||||||
stub = Stub(name=config["name"])
 | 
					stub = Stub(name=config["name"])
 | 
				
			||||||
# print(stub.app_id)
 | 
					# print(stub.app_id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -56,7 +58,7 @@ if not deploy_test:
 | 
				
			|||||||
        #     # Install comfy deploy
 | 
					        #     # Install comfy deploy
 | 
				
			||||||
        #     "cd /comfyui/custom_nodes && git clone https://github.com/BennyKok/comfyui-deploy.git",
 | 
					        #     "cd /comfyui/custom_nodes && git clone https://github.com/BennyKok/comfyui-deploy.git",
 | 
				
			||||||
        # )
 | 
					        # )
 | 
				
			||||||
        # .copy_local_file(f"{current_directory}/data/extra_model_paths.yaml", "/comfyui")
 | 
					        .copy_local_file(f"{current_directory}/data/extra_model_paths.yaml", "/comfyui")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        .copy_local_file(f"{current_directory}/data/start.sh", "/start.sh")
 | 
					        .copy_local_file(f"{current_directory}/data/start.sh", "/start.sh")
 | 
				
			||||||
        .run_commands("chmod +x /start.sh")
 | 
					        .run_commands("chmod +x /start.sh")
 | 
				
			||||||
@ -154,7 +156,7 @@ image = Image.debian_slim()
 | 
				
			|||||||
target_image = image if deploy_test else dockerfile_image
 | 
					target_image = image if deploy_test else dockerfile_image
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@stub.function(image=target_image, gpu=config["gpu"])
 | 
					@stub.function(image=target_image, gpu=config["gpu"], volumes=volumes)
 | 
				
			||||||
def run(input: Input):
 | 
					def run(input: Input):
 | 
				
			||||||
    import subprocess
 | 
					    import subprocess
 | 
				
			||||||
    import time
 | 
					    import time
 | 
				
			||||||
@ -235,7 +237,7 @@ async def bar(request_input: RequestInput):
 | 
				
			|||||||
    # pass
 | 
					    # pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@stub.function(image=image)
 | 
					@stub.function(image=image, volumes=volumes)
 | 
				
			||||||
@asgi_app()
 | 
					@asgi_app()
 | 
				
			||||||
def comfyui_api():
 | 
					def comfyui_api():
 | 
				
			||||||
    return web_app
 | 
					    return web_app
 | 
				
			||||||
@ -284,6 +286,7 @@ def spawn_comfyui_in_background():
 | 
				
			|||||||
    # Restrict to 1 container because we want to our ComfyUI session state
 | 
					    # Restrict to 1 container because we want to our ComfyUI session state
 | 
				
			||||||
    # to be on a single container.
 | 
					    # to be on a single container.
 | 
				
			||||||
    concurrency_limit=1,
 | 
					    concurrency_limit=1,
 | 
				
			||||||
 | 
					    volumes=volumes,
 | 
				
			||||||
    timeout=10 * 60,
 | 
					    timeout=10 * 60,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
@asgi_app()
 | 
					@asgi_app()
 | 
				
			||||||
 | 
				
			|||||||
@ -1 +1,7 @@
 | 
				
			|||||||
config = {"name": "my-app", "deploy_test": "True", "gpu": "T4"}
 | 
					config = {
 | 
				
			||||||
 | 
					    "name": "my-app",
 | 
				
			||||||
 | 
					    "deploy_test": "True",
 | 
				
			||||||
 | 
					    "gpu": "T4", 
 | 
				
			||||||
 | 
					    "public_checkpoint_volume": "model-store",
 | 
				
			||||||
 | 
					    "private_checkpoint_volume": "private-model-store"
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -1,11 +1,15 @@
 | 
				
			|||||||
comfyui:
 | 
					public:
 | 
				
			||||||
    base_path: /runpod-volume/ComfyUI/
 | 
					  base_path: /public_models/
 | 
				
			||||||
    checkpoints: models/checkpoints/
 | 
					  checkpoints: checkpoints
 | 
				
			||||||
    clip: models/clip/
 | 
					  clip: clip
 | 
				
			||||||
    clip_vision: models/clip_vision/
 | 
					  clip_vision: clip_vision
 | 
				
			||||||
    configs: models/configs/
 | 
					  configs: configs
 | 
				
			||||||
    controlnet: models/controlnet/
 | 
					  controlnet: controlnet
 | 
				
			||||||
    embeddings: models/embeddings/
 | 
					  embeddings: embeddings
 | 
				
			||||||
    loras: models/loras/
 | 
					  loras: loras
 | 
				
			||||||
    upscale_models: models/upscale_models/
 | 
					  upscale_models: upscale_models
 | 
				
			||||||
    vae: models/vae/
 | 
					  vae: vae
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					private:
 | 
				
			||||||
 | 
					  base_path: /private_models/
 | 
				
			||||||
 | 
					  checkpoints: checkpoints
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										101
									
								
								builder/modal-builder/src/template/data/insert_models.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										101
									
								
								builder/modal-builder/src/template/data/insert_models.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,101 @@
 | 
				
			|||||||
 | 
					"""
 | 
				
			||||||
 | 
					This is a standalone script to download models into a modal Volume using civitai
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Example Usage
 | 
				
			||||||
 | 
					`modal run insert_models::insert_model --civitai-url https://civitai.com/models/36520/ghostmix`
 | 
				
			||||||
 | 
					This inserts an individual model from a civitai url 
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					`modal run insert_models::insert_models_civitai_api` 
 | 
				
			||||||
 | 
					This inserts a bunch of models based on the models retrieved by civitai
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					civitai's API reference https://github.com/civitai/civitai/wiki/REST-API-Reference
 | 
				
			||||||
 | 
					"""
 | 
				
			||||||
 | 
					import modal
 | 
				
			||||||
 | 
					import subprocess
 | 
				
			||||||
 | 
					import requests
 | 
				
			||||||
 | 
					import json
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					stub = modal.Stub()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# NOTE: volume name can be variable
 | 
				
			||||||
 | 
					volume = modal.Volume.persisted("rah")
 | 
				
			||||||
 | 
					model_store_path = "/vol/models"
 | 
				
			||||||
 | 
					MODEL_ROUTE = "models"
 | 
				
			||||||
 | 
					image = (
 | 
				
			||||||
 | 
					    modal.Image.debian_slim().apt_install("wget").pip_install("requests")
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@stub.function(volumes={model_store_path: volume}, image=image, timeout=50000, gpu=None)
 | 
				
			||||||
 | 
					def download_model(download_url):
 | 
				
			||||||
 | 
					    print(download_url)
 | 
				
			||||||
 | 
					    subprocess.run(["wget", download_url, "--content-disposition", "-P", model_store_path])
 | 
				
			||||||
 | 
					    subprocess.run(["ls", "-la", model_store_path])
 | 
				
			||||||
 | 
					    volume.commit()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# file is raw output from Civitai API https://github.com/civitai/civitai/wiki/REST-API-Reference
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@stub.function()
 | 
				
			||||||
 | 
					def get_civitai_models(model_type: str, sort: str = "Highest Rated", page: int = 1):
 | 
				
			||||||
 | 
					    """Fetch models from CivitAI API based on type."""
 | 
				
			||||||
 | 
					    try:
 | 
				
			||||||
 | 
					        response = requests.get(f"https://civitai.com/api/v1/models", params={"types": model_type, "page": page, "sort": sort})
 | 
				
			||||||
 | 
					        response.raise_for_status()
 | 
				
			||||||
 | 
					        return response.json()
 | 
				
			||||||
 | 
					    except requests.RequestException as e:
 | 
				
			||||||
 | 
					        print(f"Error fetching models: {e}")
 | 
				
			||||||
 | 
					        return None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@stub.function()
 | 
				
			||||||
 | 
					def get_civitai_model_url(civitai_url: str):
 | 
				
			||||||
 | 
					    # Validate the URL
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if civitai_url.startswith("https://civitai.com/api/"):
 | 
				
			||||||
 | 
					        api_url = civitai_url
 | 
				
			||||||
 | 
					    elif civitai_url.startswith("https://civitai.com/models/"):  
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            model_id = civitai_url.split("/")[4]
 | 
				
			||||||
 | 
					            int(model_id) 
 | 
				
			||||||
 | 
					        except (IndexError, ValueError):
 | 
				
			||||||
 | 
					            return None 
 | 
				
			||||||
 | 
					        api_url = f"https://civitai.com/api/v1/models/{model_id}"
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        return "Error: URL must be from civitai.com and contain /models/"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    response = requests.get(api_url)
 | 
				
			||||||
 | 
					    # Check for successful response
 | 
				
			||||||
 | 
					    if response.status_code != 200:
 | 
				
			||||||
 | 
					        return f"Error: Unable to fetch data from {api_url}"
 | 
				
			||||||
 | 
					    # Return the response data
 | 
				
			||||||
 | 
					    return response.json()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@stub.local_entrypoint()
 | 
				
			||||||
 | 
					def insert_models_civitai_api(type: str = "Checkpoint", sort = "Highest Rated", page: int = 1):
 | 
				
			||||||
 | 
					    civitai_models = get_civitai_models.local(type, sort, page)
 | 
				
			||||||
 | 
					    if civitai_models:
 | 
				
			||||||
 | 
					        for _ in download_model.map(map(lambda model: model['modelVersions'][0]['downloadUrl'], civitai_models['items'])):
 | 
				
			||||||
 | 
					            pass
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        print("Failed to retrieve models.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@stub.local_entrypoint()
 | 
				
			||||||
 | 
					def insert_model(civitai_url: str):
 | 
				
			||||||
 | 
					    if civitai_url.startswith("'https://civitai.com/api/download/models/"):
 | 
				
			||||||
 | 
					        download_url = civitai_url
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        civitai_model = get_civitai_model_url.local(civitai_url)
 | 
				
			||||||
 | 
					        if civitai_model:
 | 
				
			||||||
 | 
					            download_url = civitai_model['modelVersions'][0]['downloadUrl']
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            return "invalid URL"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    download_model.remote(download_url)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@stub.local_entrypoint()
 | 
				
			||||||
 | 
					def simple_download():
 | 
				
			||||||
 | 
					    download_urls = ['https://civitai.com/api/download/models/119057', 'https://civitai.com/api/download/models/130090', 'https://civitai.com/api/download/models/31859', 'https://civitai.com/api/download/models/128713', 'https://civitai.com/api/download/models/179657', 'https://civitai.com/api/download/models/143906', 'https://civitai.com/api/download/models/9208', 'https://civitai.com/api/download/models/136078', 'https://civitai.com/api/download/models/134065', 'https://civitai.com/api/download/models/288775', 'https://civitai.com/api/download/models/95263', 'https://civitai.com/api/download/models/288982', 'https://civitai.com/api/download/models/87153', 'https://civitai.com/api/download/models/10638', 'https://civitai.com/api/download/models/263809', 'https://civitai.com/api/download/models/130072', 'https://civitai.com/api/download/models/117019', 'https://civitai.com/api/download/models/95256', 'https://civitai.com/api/download/models/197181', 'https://civitai.com/api/download/models/256915', 'https://civitai.com/api/download/models/118945', 'https://civitai.com/api/download/models/125843', 'https://civitai.com/api/download/models/179015', 'https://civitai.com/api/download/models/245598', 'https://civitai.com/api/download/models/223670', 'https://civitai.com/api/download/models/90072', 'https://civitai.com/api/download/models/290817', 'https://civitai.com/api/download/models/154097', 'https://civitai.com/api/download/models/143497', 'https://civitai.com/api/download/models/5637']
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    for _ in download_model.map(download_urls):
 | 
				
			||||||
 | 
					        pass
 | 
				
			||||||
@ -45,12 +45,12 @@ for package in packages:
 | 
				
			|||||||
    response = requests.request("POST", f"{root_url}/customnode/install", json=package, headers=headers)
 | 
					    response = requests.request("POST", f"{root_url}/customnode/install", json=package, headers=headers)
 | 
				
			||||||
    print(response.text)
 | 
					    print(response.text)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
with open('models.json') as f:
 | 
					# with open('models.json') as f:
 | 
				
			||||||
    models = json.load(f)
 | 
					#     models = json.load(f)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
for model in models:
 | 
					# for model in models:
 | 
				
			||||||
    response = requests.request("POST", f"{root_url}/model/install", json=model, headers=headers)
 | 
					#     response = requests.request("POST", f"{root_url}/model/install", json=model, headers=headers)
 | 
				
			||||||
    print(response.text)
 | 
					#     print(response.text)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Close the server
 | 
					# Close the server
 | 
				
			||||||
server_process.terminate()
 | 
					server_process.terminate()
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										9
									
								
								builder/modal-builder/src/template/volume_setup.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										9
									
								
								builder/modal-builder/src/template/volume_setup.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,9 @@
 | 
				
			|||||||
 | 
					import modal
 | 
				
			||||||
 | 
					from config import config
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					public_model_volume = modal.Volume.persisted(config["public_checkpoint_volume"])
 | 
				
			||||||
 | 
					private_volume = modal.Volume.persisted(config["private_checkpoint_volume"])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					PUBLIC_BASEMODEL_DIR = "/public_models"
 | 
				
			||||||
 | 
					PRIVATE_BASEMODEL_DIR = "/private_models"
 | 
				
			||||||
 | 
					volumes = {PUBLIC_BASEMODEL_DIR: public_model_volume, PRIVATE_BASEMODEL_DIR: private_volume}
 | 
				
			||||||
							
								
								
									
										69
									
								
								builder/modal-builder/src/volume-builder/app.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										69
									
								
								builder/modal-builder/src/volume-builder/app.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,69 @@
 | 
				
			|||||||
 | 
					import modal
 | 
				
			||||||
 | 
					from config import config
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
 | 
					import subprocess
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					stub = modal.Stub()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Volume names may only contain alphanumeric characters, dashes, periods, and underscores, and must be less than 64 characters in length.
 | 
				
			||||||
 | 
					def is_valid_name(name: str) -> bool:
 | 
				
			||||||
 | 
					    allowed_characters = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-._")
 | 
				
			||||||
 | 
					    return 0 < len(name) <= 64 and all(char in allowed_characters for char in name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def create_volumes(volume_names, paths):
 | 
				
			||||||
 | 
					    path_to_vol = {}
 | 
				
			||||||
 | 
					    for volume_name in volume_names.keys():
 | 
				
			||||||
 | 
					        if not is_valid_name(volume_name):
 | 
				
			||||||
 | 
					            pass
 | 
				
			||||||
 | 
					        modal_volume = modal.Volume.persisted(volume_name)
 | 
				
			||||||
 | 
					        path_to_vol[paths[volume_name]] = modal_volume
 | 
				
			||||||
 | 
					 
 | 
				
			||||||
 | 
					    return path_to_vol
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					vol_name_to_links = config["volume_names"]
 | 
				
			||||||
 | 
					vol_name_to_path = config["volume_paths"]
 | 
				
			||||||
 | 
					callback_url = config["callback_url"]
 | 
				
			||||||
 | 
					callback_body = config["callback_body"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					volumes = create_volumes(vol_name_to_links, vol_name_to_path)
 | 
				
			||||||
 | 
					image = ( 
 | 
				
			||||||
 | 
					   modal.Image.debian_slim().apt_install("wget").pip_install("requests")
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					print(vol_name_to_links)
 | 
				
			||||||
 | 
					print(vol_name_to_path)
 | 
				
			||||||
 | 
					print(volumes)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# download config { "download_url": "", "folder_path": ""}
 | 
				
			||||||
 | 
					timeout=5000
 | 
				
			||||||
 | 
					@stub.function(volumes=volumes, image=image, timeout=timeout, gpu=None)
 | 
				
			||||||
 | 
					def download_model(volume_name, download_config):
 | 
				
			||||||
 | 
					    import requests
 | 
				
			||||||
 | 
					    download_url = download_config["download_url"]
 | 
				
			||||||
 | 
					    folder_path = download_config["folder_path"]
 | 
				
			||||||
 | 
					    volume_base_path = vol_name_to_path[volume_name]
 | 
				
			||||||
 | 
					    model_store_path = os.path.join(volume_base_path, folder_path)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    subprocess.run(["wget", download_url, "--content-disposition", "-P", model_store_path])
 | 
				
			||||||
 | 
					    subprocess.run(["ls", "-la", volume_base_path])
 | 
				
			||||||
 | 
					    subprocess.run(["ls", "-la", model_store_path])
 | 
				
			||||||
 | 
					    volumes[volume_base_path].commit()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    status =  {"status": "success"}
 | 
				
			||||||
 | 
					    requests.post(callback_url, json={**status, **callback_body})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@stub.local_entrypoint()
 | 
				
			||||||
 | 
					def simple_download():
 | 
				
			||||||
 | 
					    import requests
 | 
				
			||||||
 | 
					    print(vol_name_to_links)
 | 
				
			||||||
 | 
					    print([(vol_name, link) for vol_name,link in vol_name_to_links.items()])
 | 
				
			||||||
 | 
					    try:
 | 
				
			||||||
 | 
					        list(download_model.starmap([(vol_name, link) for vol_name,link in vol_name_to_links.items()]))
 | 
				
			||||||
 | 
					    except modal.exception.FunctionTimeoutError as e:
 | 
				
			||||||
 | 
					        status =  {"status": "failed", "error_logs": f"{str(e)}", "timeout": timeout}
 | 
				
			||||||
 | 
					        requests.post(callback_url, json={**status, **callback_body})
 | 
				
			||||||
 | 
					    except Exception as e:
 | 
				
			||||||
 | 
					        status =  {"status": "failed", "error_logs": str(e)}
 | 
				
			||||||
 | 
					        requests.post(callback_url, json={**status, **callback_body})
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
							
								
								
									
										17
									
								
								builder/modal-builder/src/volume-builder/config.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										17
									
								
								builder/modal-builder/src/volume-builder/config.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,17 @@
 | 
				
			|||||||
 | 
					config = {
 | 
				
			||||||
 | 
					    "volume_names": {
 | 
				
			||||||
 | 
					        "test": {
 | 
				
			||||||
 | 
					            "download_url": "https://pub-6230db03dc3a4861a9c3e55145ceda44.r2.dev/openpose-pose (1).png",
 | 
				
			||||||
 | 
					            "folder_path": "images"
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    }, 
 | 
				
			||||||
 | 
					    "volume_paths": {
 | 
				
			||||||
 | 
					        "test": "/volumes/something"
 | 
				
			||||||
 | 
					    },
 | 
				
			||||||
 | 
					    "callback_url": "",
 | 
				
			||||||
 | 
					    "callback_body": {
 | 
				
			||||||
 | 
					        "checkpoint_id": "",
 | 
				
			||||||
 | 
					        "volume_id": "",
 | 
				
			||||||
 | 
					        "folder_path": "images",
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										
											BIN
										
									
								
								web/bun.lockb
									
									
									
									
									
								
							
							
						
						
									
										
											BIN
										
									
								
								web/bun.lockb
									
									
									
									
									
								
							
										
											Binary file not shown.
										
									
								
							
							
								
								
									
										64
									
								
								web/drizzle/0035_known_skin.sql
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										64
									
								
								web/drizzle/0035_known_skin.sql
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,64 @@
 | 
				
			|||||||
 | 
					DO $$ BEGIN
 | 
				
			||||||
 | 
					 CREATE TYPE "model_upload_type" AS ENUM('civitai', 'huggingface', 'other');
 | 
				
			||||||
 | 
					EXCEPTION
 | 
				
			||||||
 | 
					 WHEN duplicate_object THEN null;
 | 
				
			||||||
 | 
					END $$;
 | 
				
			||||||
 | 
					--> statement-breakpoint
 | 
				
			||||||
 | 
					DO $$ BEGIN
 | 
				
			||||||
 | 
					 CREATE TYPE "resource_upload" AS ENUM('started', 'success', 'failed');
 | 
				
			||||||
 | 
					EXCEPTION
 | 
				
			||||||
 | 
					 WHEN duplicate_object THEN null;
 | 
				
			||||||
 | 
					END $$;
 | 
				
			||||||
 | 
					--> statement-breakpoint
 | 
				
			||||||
 | 
					CREATE TABLE IF NOT EXISTS "comfyui_deploy"."checkpoints" (
 | 
				
			||||||
 | 
						"id" uuid PRIMARY KEY DEFAULT gen_random_uuid() NOT NULL,
 | 
				
			||||||
 | 
						"user_id" text,
 | 
				
			||||||
 | 
						"org_id" text,
 | 
				
			||||||
 | 
						"description" text,
 | 
				
			||||||
 | 
						"checkpoint_volume_id" uuid NOT NULL,
 | 
				
			||||||
 | 
						"model_name" text,
 | 
				
			||||||
 | 
						"folder_path" text,
 | 
				
			||||||
 | 
						"civitai_id" text,
 | 
				
			||||||
 | 
						"civitai_version_id" text,
 | 
				
			||||||
 | 
						"civitai_url" text,
 | 
				
			||||||
 | 
						"civitai_download_url" text,
 | 
				
			||||||
 | 
						"civitai_model_response" jsonb,
 | 
				
			||||||
 | 
						"hf_url" text,
 | 
				
			||||||
 | 
						"s3_url" text,
 | 
				
			||||||
 | 
						"client_url" text,
 | 
				
			||||||
 | 
						"is_public" boolean DEFAULT false NOT NULL,
 | 
				
			||||||
 | 
						"status" "resource_upload" DEFAULT 'started' NOT NULL,
 | 
				
			||||||
 | 
						"upload_machine_id" text,
 | 
				
			||||||
 | 
						"upload_type" "model_upload_type" NOT NULL,
 | 
				
			||||||
 | 
						"error_log" text,
 | 
				
			||||||
 | 
						"created_at" timestamp DEFAULT now() NOT NULL,
 | 
				
			||||||
 | 
						"updated_at" timestamp DEFAULT now() NOT NULL
 | 
				
			||||||
 | 
					);
 | 
				
			||||||
 | 
					--> statement-breakpoint
 | 
				
			||||||
 | 
					CREATE TABLE IF NOT EXISTS "comfyui_deploy"."checkpoint_volume" (
 | 
				
			||||||
 | 
						"id" uuid PRIMARY KEY DEFAULT gen_random_uuid() NOT NULL,
 | 
				
			||||||
 | 
						"user_id" text,
 | 
				
			||||||
 | 
						"org_id" text,
 | 
				
			||||||
 | 
						"volume_name" text NOT NULL,
 | 
				
			||||||
 | 
						"created_at" timestamp DEFAULT now() NOT NULL,
 | 
				
			||||||
 | 
						"updated_at" timestamp DEFAULT now() NOT NULL,
 | 
				
			||||||
 | 
						"disabled" boolean DEFAULT false NOT NULL
 | 
				
			||||||
 | 
					);
 | 
				
			||||||
 | 
					--> statement-breakpoint
 | 
				
			||||||
 | 
					DO $$ BEGIN
 | 
				
			||||||
 | 
					 ALTER TABLE "comfyui_deploy"."checkpoints" ADD CONSTRAINT "checkpoints_user_id_users_id_fk" FOREIGN KEY ("user_id") REFERENCES "comfyui_deploy"."users"("id") ON DELETE no action ON UPDATE no action;
 | 
				
			||||||
 | 
					EXCEPTION
 | 
				
			||||||
 | 
					 WHEN duplicate_object THEN null;
 | 
				
			||||||
 | 
					END $$;
 | 
				
			||||||
 | 
					--> statement-breakpoint
 | 
				
			||||||
 | 
					DO $$ BEGIN
 | 
				
			||||||
 | 
					 ALTER TABLE "comfyui_deploy"."checkpoints" ADD CONSTRAINT "checkpoints_checkpoint_volume_id_checkpoint_volume_id_fk" FOREIGN KEY ("checkpoint_volume_id") REFERENCES "comfyui_deploy"."checkpoint_volume"("id") ON DELETE cascade ON UPDATE no action;
 | 
				
			||||||
 | 
					EXCEPTION
 | 
				
			||||||
 | 
					 WHEN duplicate_object THEN null;
 | 
				
			||||||
 | 
					END $$;
 | 
				
			||||||
 | 
					--> statement-breakpoint
 | 
				
			||||||
 | 
					DO $$ BEGIN
 | 
				
			||||||
 | 
					 ALTER TABLE "comfyui_deploy"."checkpoint_volume" ADD CONSTRAINT "checkpoint_volume_user_id_users_id_fk" FOREIGN KEY ("user_id") REFERENCES "comfyui_deploy"."users"("id") ON DELETE no action ON UPDATE no action;
 | 
				
			||||||
 | 
					EXCEPTION
 | 
				
			||||||
 | 
					 WHEN duplicate_object THEN null;
 | 
				
			||||||
 | 
					END $$;
 | 
				
			||||||
							
								
								
									
										1090
									
								
								web/drizzle/meta/0035_snapshot.json
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1090
									
								
								web/drizzle/meta/0035_snapshot.json
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							@ -246,6 +246,13 @@
 | 
				
			|||||||
      "when": 1705902960991,
 | 
					      "when": 1705902960991,
 | 
				
			||||||
      "tag": "0034_even_lady_ursula",
 | 
					      "tag": "0034_even_lady_ursula",
 | 
				
			||||||
      "breakpoints": true
 | 
					      "breakpoints": true
 | 
				
			||||||
 | 
					    },
 | 
				
			||||||
 | 
					    {
 | 
				
			||||||
 | 
					      "idx": 35,
 | 
				
			||||||
 | 
					      "version": "5",
 | 
				
			||||||
 | 
					      "when": 1706085876992,
 | 
				
			||||||
 | 
					      "tag": "0035_known_skin",
 | 
				
			||||||
 | 
					      "breakpoints": true
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
  ]
 | 
					  ]
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
							
								
								
									
										53
									
								
								web/src/app/(app)/api/volume-upload/route.ts
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										53
									
								
								web/src/app/(app)/api/volume-upload/route.ts
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,53 @@
 | 
				
			|||||||
 | 
					import { parseDataSafe } from "../../../../lib/parseDataSafe";
 | 
				
			||||||
 | 
					import { db } from "@/db/db";
 | 
				
			||||||
 | 
					import { checkpointTable, machinesTable } from "@/db/schema";
 | 
				
			||||||
 | 
					import { eq } from "drizzle-orm";
 | 
				
			||||||
 | 
					import { NextResponse } from "next/server";
 | 
				
			||||||
 | 
					import { z } from "zod";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const Request = z.object({
 | 
				
			||||||
 | 
					  volume_id: z.string(),
 | 
				
			||||||
 | 
					  checkpoint_id: z.string(),
 | 
				
			||||||
 | 
					  folder_path: z.string().optional(),
 | 
				
			||||||
 | 
					  status: z.enum(['success', 'failed']),
 | 
				
			||||||
 | 
					  error_log: z.string().optional(),
 | 
				
			||||||
 | 
					  timeout: z.number().optional(),
 | 
				
			||||||
 | 
					});
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					export async function POST(request: Request) {
 | 
				
			||||||
 | 
					  const [data, error] = await parseDataSafe(Request, request);
 | 
				
			||||||
 | 
					  if (!data || error) return error;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  const { checkpoint_id, error_log, status, folder_path } = data;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  if (status === "success") {
 | 
				
			||||||
 | 
					    await db
 | 
				
			||||||
 | 
					      .update(checkpointTable)
 | 
				
			||||||
 | 
					      .set({
 | 
				
			||||||
 | 
					        status: "success",
 | 
				
			||||||
 | 
					        folder_path 
 | 
				
			||||||
 | 
					        // build_log: build_log,
 | 
				
			||||||
 | 
					      })
 | 
				
			||||||
 | 
					      .where(eq(checkpointTable.id, checkpoint_id));
 | 
				
			||||||
 | 
					  } else {
 | 
				
			||||||
 | 
					    // console.log(data);
 | 
				
			||||||
 | 
					    await db
 | 
				
			||||||
 | 
					      .update(checkpointTable)
 | 
				
			||||||
 | 
					      .set({
 | 
				
			||||||
 | 
					        status: "failed",
 | 
				
			||||||
 | 
					        error_log, 
 | 
				
			||||||
 | 
					        // status: "error",
 | 
				
			||||||
 | 
					        // build_log: build_log,
 | 
				
			||||||
 | 
					      })
 | 
				
			||||||
 | 
					      .where(eq(checkpointTable.id, checkpoint_id));
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  return NextResponse.json(
 | 
				
			||||||
 | 
					    {
 | 
				
			||||||
 | 
					      message: "success",
 | 
				
			||||||
 | 
					    },
 | 
				
			||||||
 | 
					    {
 | 
				
			||||||
 | 
					      status: 200,
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  );
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										9
									
								
								web/src/app/(app)/storage/loading.tsx
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										9
									
								
								web/src/app/(app)/storage/loading.tsx
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,9 @@
 | 
				
			|||||||
 | 
					"use client";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import { LoadingPageWrapper } from "@/components/LoadingWrapper";
 | 
				
			||||||
 | 
					import { usePathname } from "next/navigation";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					export default function Loading() {
 | 
				
			||||||
 | 
					  const pathName = usePathname();
 | 
				
			||||||
 | 
					  return <LoadingPageWrapper className="h-full" tag={pathName.toLowerCase()} />;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										35
									
								
								web/src/app/(app)/storage/page.tsx
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										35
									
								
								web/src/app/(app)/storage/page.tsx
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,35 @@
 | 
				
			|||||||
 | 
					import { setInitialUserData } from "../../../lib/setInitialUserData";
 | 
				
			||||||
 | 
					import { auth } from "@clerk/nextjs";
 | 
				
			||||||
 | 
					import { clerkClient } from "@clerk/nextjs/server";
 | 
				
			||||||
 | 
					import { CheckpointList } from "@/components/CheckpointList";
 | 
				
			||||||
 | 
					import { getAllUserCheckpoints } from "@/server/getAllUserCheckpoints";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					export default function Page() {
 | 
				
			||||||
 | 
					  return <CheckpointListServer />;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					async function CheckpointListServer() {
 | 
				
			||||||
 | 
					  const { userId } = auth();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  if (!userId) {
 | 
				
			||||||
 | 
					    return <div>No auth</div>;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  const user = await clerkClient.users.getUser(userId);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  if (!user) {
 | 
				
			||||||
 | 
					    await setInitialUserData(userId);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  const checkpoints = await getAllUserCheckpoints();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  if (!checkpoints) {
 | 
				
			||||||
 | 
					    return <div>No checkpoints found</div>;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  return (
 | 
				
			||||||
 | 
					    <div className="w-full">
 | 
				
			||||||
 | 
					      <CheckpointList data={checkpoints} />
 | 
				
			||||||
 | 
					    </div>
 | 
				
			||||||
 | 
					  );
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										341
									
								
								web/src/components/CheckpointList.tsx
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										341
									
								
								web/src/components/CheckpointList.tsx
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,341 @@
 | 
				
			|||||||
 | 
					"use client";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import { getRelativeTime } from "../lib/getRelativeTime";
 | 
				
			||||||
 | 
					import { Badge } from "@/components/ui/badge";
 | 
				
			||||||
 | 
					import { Button } from "@/components/ui/button";
 | 
				
			||||||
 | 
					import { Checkbox } from "@/components/ui/checkbox";
 | 
				
			||||||
 | 
					import { InsertModal, UpdateModal } from "./InsertModal";
 | 
				
			||||||
 | 
					import { Input } from "@/components/ui/input";
 | 
				
			||||||
 | 
					import { ScrollArea } from "@/components/ui/scroll-area";
 | 
				
			||||||
 | 
					import {
 | 
				
			||||||
 | 
					  Table,
 | 
				
			||||||
 | 
					  TableBody,
 | 
				
			||||||
 | 
					  TableCell,
 | 
				
			||||||
 | 
					  TableHead,
 | 
				
			||||||
 | 
					  TableHeader,
 | 
				
			||||||
 | 
					  TableRow,
 | 
				
			||||||
 | 
					} from "@/components/ui/table";
 | 
				
			||||||
 | 
					import type { getAllUserCheckpoints } from "@/server/getAllUserCheckpoints";
 | 
				
			||||||
 | 
					import type {
 | 
				
			||||||
 | 
					  ColumnDef,
 | 
				
			||||||
 | 
					  ColumnFiltersState,
 | 
				
			||||||
 | 
					  SortingState,
 | 
				
			||||||
 | 
					  VisibilityState,
 | 
				
			||||||
 | 
					} from "@tanstack/react-table";
 | 
				
			||||||
 | 
					import {
 | 
				
			||||||
 | 
					  flexRender,
 | 
				
			||||||
 | 
					  getCoreRowModel,
 | 
				
			||||||
 | 
					  getFilteredRowModel,
 | 
				
			||||||
 | 
					  getPaginationRowModel,
 | 
				
			||||||
 | 
					  getSortedRowModel,
 | 
				
			||||||
 | 
					  useReactTable,
 | 
				
			||||||
 | 
					} from "@tanstack/react-table";
 | 
				
			||||||
 | 
					import { ArrowUpDown } from "lucide-react";
 | 
				
			||||||
 | 
					import * as React from "react";
 | 
				
			||||||
 | 
					import { addCivitaiCheckpoint } from "@/server/curdCheckpoint";
 | 
				
			||||||
 | 
					import { addCivitaiCheckpointSchema } from "@/server/addCheckpointSchema";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					export type CheckpointItemList = NonNullable<
 | 
				
			||||||
 | 
					  Awaited<ReturnType<typeof getAllUserCheckpoints>>
 | 
				
			||||||
 | 
					>[0];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					export const columns: ColumnDef<CheckpointItemList>[] = [
 | 
				
			||||||
 | 
					  {
 | 
				
			||||||
 | 
					    accessorKey: "id",
 | 
				
			||||||
 | 
					    id: "select",
 | 
				
			||||||
 | 
					    header: ({ table }) => (
 | 
				
			||||||
 | 
					      <Checkbox
 | 
				
			||||||
 | 
					        checked={
 | 
				
			||||||
 | 
					          table.getIsAllPageRowsSelected() ||
 | 
				
			||||||
 | 
					          (table.getIsSomePageRowsSelected() && "indeterminate")
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        onCheckedChange={(value) => table.toggleAllPageRowsSelected(!!value)}
 | 
				
			||||||
 | 
					        aria-label="Select all"
 | 
				
			||||||
 | 
					      />
 | 
				
			||||||
 | 
					    ),
 | 
				
			||||||
 | 
					    cell: ({ row }) => (
 | 
				
			||||||
 | 
					      <Checkbox
 | 
				
			||||||
 | 
					        checked={row.getIsSelected()}
 | 
				
			||||||
 | 
					        onCheckedChange={(value) => row.toggleSelected(!!value)}
 | 
				
			||||||
 | 
					        aria-label="Select row"
 | 
				
			||||||
 | 
					      />
 | 
				
			||||||
 | 
					    ),
 | 
				
			||||||
 | 
					    enableSorting: false,
 | 
				
			||||||
 | 
					    enableHiding: false,
 | 
				
			||||||
 | 
					  },
 | 
				
			||||||
 | 
					  {
 | 
				
			||||||
 | 
					    accessorKey: "model_name",
 | 
				
			||||||
 | 
					    header: ({ column }) => {
 | 
				
			||||||
 | 
					      return (
 | 
				
			||||||
 | 
					        <button
 | 
				
			||||||
 | 
					          className="flex items-center hover:underline"
 | 
				
			||||||
 | 
					          onClick={() => column.toggleSorting(column.getIsSorted() === "asc")}
 | 
				
			||||||
 | 
					        >
 | 
				
			||||||
 | 
					          Model Name
 | 
				
			||||||
 | 
					          <ArrowUpDown className="ml-2 h-4 w-4" />
 | 
				
			||||||
 | 
					        </button>
 | 
				
			||||||
 | 
					      );
 | 
				
			||||||
 | 
					    },
 | 
				
			||||||
 | 
					    cell: ({ row }) => {
 | 
				
			||||||
 | 
					      const checkpoint = row.original;
 | 
				
			||||||
 | 
					      console.log(checkpoint);
 | 
				
			||||||
 | 
					      return (
 | 
				
			||||||
 | 
					        <a
 | 
				
			||||||
 | 
					          className="hover:underline flex gap-2"
 | 
				
			||||||
 | 
					          href={`/storage/${checkpoint.id}`} // TODO
 | 
				
			||||||
 | 
					        >
 | 
				
			||||||
 | 
					          <span className="truncate max-w-[200px]">
 | 
				
			||||||
 | 
					            {row.original.model_name}
 | 
				
			||||||
 | 
					          </span>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					          {checkpoint.is_public ? (
 | 
				
			||||||
 | 
					            <Badge variant="orange">Public</Badge>
 | 
				
			||||||
 | 
					          ) : (
 | 
				
			||||||
 | 
					            <Badge variant="orange">Private</Badge>
 | 
				
			||||||
 | 
					          )}
 | 
				
			||||||
 | 
					        </a>
 | 
				
			||||||
 | 
					      );
 | 
				
			||||||
 | 
					    },
 | 
				
			||||||
 | 
					  },
 | 
				
			||||||
 | 
					  {
 | 
				
			||||||
 | 
					    accessorKey: "status",
 | 
				
			||||||
 | 
					    header: ({ column }) => {
 | 
				
			||||||
 | 
					      return (
 | 
				
			||||||
 | 
					        <button
 | 
				
			||||||
 | 
					          className="flex items-center hover:underline"
 | 
				
			||||||
 | 
					          onClick={() => column.toggleSorting(column.getIsSorted() === "asc")}
 | 
				
			||||||
 | 
					        >
 | 
				
			||||||
 | 
					          Status
 | 
				
			||||||
 | 
					          <ArrowUpDown className="ml-2 h-4 w-4" />
 | 
				
			||||||
 | 
					        </button>
 | 
				
			||||||
 | 
					      );
 | 
				
			||||||
 | 
					    },
 | 
				
			||||||
 | 
					    cell: ({ row }) => {
 | 
				
			||||||
 | 
					      return (
 | 
				
			||||||
 | 
					        <Badge variant={row.original.status === "failed" ? "red" : "green"}>
 | 
				
			||||||
 | 
					          {row.original.status}
 | 
				
			||||||
 | 
					        </Badge>
 | 
				
			||||||
 | 
					      );
 | 
				
			||||||
 | 
					    },
 | 
				
			||||||
 | 
					  },
 | 
				
			||||||
 | 
					  {
 | 
				
			||||||
 | 
					    accessorKey: "upload_type",
 | 
				
			||||||
 | 
					    header: ({ column }) => {
 | 
				
			||||||
 | 
					      return (
 | 
				
			||||||
 | 
					        <button
 | 
				
			||||||
 | 
					          className="flex items-center hover:underline"
 | 
				
			||||||
 | 
					          onClick={() => column.toggleSorting(column.getIsSorted() === "asc")}
 | 
				
			||||||
 | 
					        >
 | 
				
			||||||
 | 
					          Source
 | 
				
			||||||
 | 
					          <ArrowUpDown className="ml-2 h-4 w-4" />
 | 
				
			||||||
 | 
					        </button>
 | 
				
			||||||
 | 
					      );
 | 
				
			||||||
 | 
					    },
 | 
				
			||||||
 | 
					    cell: ({ row }) => {
 | 
				
			||||||
 | 
					      return <Badge variant="cyan">{row.original.upload_type}</Badge>;
 | 
				
			||||||
 | 
					    },
 | 
				
			||||||
 | 
					  },
 | 
				
			||||||
 | 
					  {
 | 
				
			||||||
 | 
					    accessorKey: "date",
 | 
				
			||||||
 | 
					    sortingFn: "datetime",
 | 
				
			||||||
 | 
					    enableSorting: true,
 | 
				
			||||||
 | 
					    header: ({ column }) => {
 | 
				
			||||||
 | 
					      return (
 | 
				
			||||||
 | 
					        <button
 | 
				
			||||||
 | 
					          className="w-full flex items-center justify-end hover:underline truncate"
 | 
				
			||||||
 | 
					          // variant="ghost"
 | 
				
			||||||
 | 
					          onClick={() => column.toggleSorting(column.getIsSorted() === "asc")}
 | 
				
			||||||
 | 
					        >
 | 
				
			||||||
 | 
					          Update Date
 | 
				
			||||||
 | 
					          <ArrowUpDown className="ml-2 h-4 w-4" />
 | 
				
			||||||
 | 
					        </button>
 | 
				
			||||||
 | 
					      );
 | 
				
			||||||
 | 
					    },
 | 
				
			||||||
 | 
					    cell: ({ row }) => (
 | 
				
			||||||
 | 
					      <div className="w-full capitalize text-right truncate">
 | 
				
			||||||
 | 
					        {getRelativeTime(row.original.updated_at)}
 | 
				
			||||||
 | 
					      </div>
 | 
				
			||||||
 | 
					    ),
 | 
				
			||||||
 | 
					  },
 | 
				
			||||||
 | 
					  // {
 | 
				
			||||||
 | 
					  //   id: "actions",
 | 
				
			||||||
 | 
					  //   enableHiding: false,
 | 
				
			||||||
 | 
					  //   cell: ({ row }) => {
 | 
				
			||||||
 | 
					  //     const checkpoint = row.original;
 | 
				
			||||||
 | 
					  //
 | 
				
			||||||
 | 
					  //     return (
 | 
				
			||||||
 | 
					  //       <DropdownMenu>
 | 
				
			||||||
 | 
					  //         <DropdownMenuTrigger asChild>
 | 
				
			||||||
 | 
					  //           <Button variant="ghost" className="h-8 w-8 p-0">
 | 
				
			||||||
 | 
					  //             <span className="sr-only">Open menu</span>
 | 
				
			||||||
 | 
					  //             <MoreHorizontal className="h-4 w-4" />
 | 
				
			||||||
 | 
					  //           </Button>
 | 
				
			||||||
 | 
					  //         </DropdownMenuTrigger>
 | 
				
			||||||
 | 
					  //         <DropdownMenuContent align="end">
 | 
				
			||||||
 | 
					  //           <DropdownMenuLabel>Actions</DropdownMenuLabel>
 | 
				
			||||||
 | 
					  //           <DropdownMenuItem
 | 
				
			||||||
 | 
					  //             className="text-destructive"
 | 
				
			||||||
 | 
					  //             onClick={() => {
 | 
				
			||||||
 | 
					  //               deleteWorkflow(checkpoint.id);
 | 
				
			||||||
 | 
					  //             }}
 | 
				
			||||||
 | 
					  //           >
 | 
				
			||||||
 | 
					  //             Delete Workflow
 | 
				
			||||||
 | 
					  //           </DropdownMenuItem>
 | 
				
			||||||
 | 
					  //         </DropdownMenuContent>
 | 
				
			||||||
 | 
					  //       </DropdownMenu>
 | 
				
			||||||
 | 
					  //     );
 | 
				
			||||||
 | 
					  //   },
 | 
				
			||||||
 | 
					  // },
 | 
				
			||||||
 | 
					];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					export function CheckpointList({ data }: { data: CheckpointItemList[] }) {
 | 
				
			||||||
 | 
					  const [sorting, setSorting] = React.useState<SortingState>([]);
 | 
				
			||||||
 | 
					  const [columnFilters, setColumnFilters] = React.useState<ColumnFiltersState>(
 | 
				
			||||||
 | 
					    []
 | 
				
			||||||
 | 
					  );
 | 
				
			||||||
 | 
					  const [columnVisibility, setColumnVisibility] =
 | 
				
			||||||
 | 
					    React.useState<VisibilityState>({});
 | 
				
			||||||
 | 
					  const [rowSelection, setRowSelection] = React.useState({});
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  const table = useReactTable({
 | 
				
			||||||
 | 
					    data,
 | 
				
			||||||
 | 
					    columns,
 | 
				
			||||||
 | 
					    onSortingChange: setSorting,
 | 
				
			||||||
 | 
					    onColumnFiltersChange: setColumnFilters,
 | 
				
			||||||
 | 
					    getCoreRowModel: getCoreRowModel(),
 | 
				
			||||||
 | 
					    getPaginationRowModel: getPaginationRowModel(),
 | 
				
			||||||
 | 
					    getSortedRowModel: getSortedRowModel(),
 | 
				
			||||||
 | 
					    getFilteredRowModel: getFilteredRowModel(),
 | 
				
			||||||
 | 
					    onColumnVisibilityChange: setColumnVisibility,
 | 
				
			||||||
 | 
					    onRowSelectionChange: setRowSelection,
 | 
				
			||||||
 | 
					    state: {
 | 
				
			||||||
 | 
					      sorting,
 | 
				
			||||||
 | 
					      columnFilters,
 | 
				
			||||||
 | 
					      columnVisibility,
 | 
				
			||||||
 | 
					      rowSelection,
 | 
				
			||||||
 | 
					    },
 | 
				
			||||||
 | 
					  });
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  return (
 | 
				
			||||||
 | 
					    <div className="grid grid-rows-[auto,1fr,auto] h-full">
 | 
				
			||||||
 | 
					      <div className="flex flex-row w-full items-center py-4">
 | 
				
			||||||
 | 
					        <Input
 | 
				
			||||||
 | 
					          placeholder="Filter workflows..."
 | 
				
			||||||
 | 
					          value={(table.getColumn("name")?.getFilterValue() as string) ?? ""}
 | 
				
			||||||
 | 
					          onChange={(event) =>
 | 
				
			||||||
 | 
					            table.getColumn("name")?.setFilterValue(event.target.value)
 | 
				
			||||||
 | 
					          }
 | 
				
			||||||
 | 
					          className="max-w-sm"
 | 
				
			||||||
 | 
					        />
 | 
				
			||||||
 | 
					        <div className="ml-auto flex gap-2">
 | 
				
			||||||
 | 
					          <InsertModal
 | 
				
			||||||
 | 
					            dialogClassName="sm:max-w-[600px]"
 | 
				
			||||||
 | 
					            disabled={
 | 
				
			||||||
 | 
					              false
 | 
				
			||||||
 | 
					              // TODO: limitations based on plan
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					            tooltip={"Add models using their civitai url!"}
 | 
				
			||||||
 | 
					            title="Civitai Checkpoint"
 | 
				
			||||||
 | 
					            description="Pick a model from civitai"
 | 
				
			||||||
 | 
					            serverAction={addCivitaiCheckpoint}
 | 
				
			||||||
 | 
					            formSchema={addCivitaiCheckpointSchema}
 | 
				
			||||||
 | 
					            fieldConfig={{
 | 
				
			||||||
 | 
					              civitai_url: {
 | 
				
			||||||
 | 
					                fieldType: "fallback",
 | 
				
			||||||
 | 
					                // fieldType: "fallback",
 | 
				
			||||||
 | 
					                inputProps: { required: true },
 | 
				
			||||||
 | 
					                description: (
 | 
				
			||||||
 | 
					                  <>
 | 
				
			||||||
 | 
					                    Pick a checkpoint from{" "}
 | 
				
			||||||
 | 
					                    <a
 | 
				
			||||||
 | 
					                      href="https://www.civitai.com/models"
 | 
				
			||||||
 | 
					                      target="_blank"
 | 
				
			||||||
 | 
					                      className="underline text-blue-600 hover:text-blue-800 visited:text-purple-600"
 | 
				
			||||||
 | 
					                    >
 | 
				
			||||||
 | 
					                      civitai.com
 | 
				
			||||||
 | 
					                    </a>{" "}
 | 
				
			||||||
 | 
					                    and place it's url here
 | 
				
			||||||
 | 
					                  </>
 | 
				
			||||||
 | 
					                ),
 | 
				
			||||||
 | 
					              },
 | 
				
			||||||
 | 
					            }}
 | 
				
			||||||
 | 
					          />
 | 
				
			||||||
 | 
					        </div>
 | 
				
			||||||
 | 
					      </div>
 | 
				
			||||||
 | 
					      <ScrollArea className="h-full w-full rounded-md border">
 | 
				
			||||||
 | 
					        <Table>
 | 
				
			||||||
 | 
					          <TableHeader className="bg-background top-0 sticky">
 | 
				
			||||||
 | 
					            {table.getHeaderGroups().map((headerGroup) => (
 | 
				
			||||||
 | 
					              <TableRow key={headerGroup.id}>
 | 
				
			||||||
 | 
					                {headerGroup.headers.map((header) => {
 | 
				
			||||||
 | 
					                  return (
 | 
				
			||||||
 | 
					                    <TableHead key={header.id}>
 | 
				
			||||||
 | 
					                      {header.isPlaceholder
 | 
				
			||||||
 | 
					                        ? null
 | 
				
			||||||
 | 
					                        : flexRender(
 | 
				
			||||||
 | 
					                            header.column.columnDef.header,
 | 
				
			||||||
 | 
					                            header.getContext()
 | 
				
			||||||
 | 
					                          )}
 | 
				
			||||||
 | 
					                    </TableHead>
 | 
				
			||||||
 | 
					                  );
 | 
				
			||||||
 | 
					                })}
 | 
				
			||||||
 | 
					              </TableRow>
 | 
				
			||||||
 | 
					            ))}
 | 
				
			||||||
 | 
					          </TableHeader>
 | 
				
			||||||
 | 
					          <TableBody>
 | 
				
			||||||
 | 
					            {table.getRowModel().rows?.length ? (
 | 
				
			||||||
 | 
					              table.getRowModel().rows.map((row) => (
 | 
				
			||||||
 | 
					                <TableRow
 | 
				
			||||||
 | 
					                  key={row.id}
 | 
				
			||||||
 | 
					                  data-state={row.getIsSelected() && "selected"}
 | 
				
			||||||
 | 
					                >
 | 
				
			||||||
 | 
					                  {row.getVisibleCells().map((cell) => (
 | 
				
			||||||
 | 
					                    <TableCell key={cell.id}>
 | 
				
			||||||
 | 
					                      {flexRender(
 | 
				
			||||||
 | 
					                        cell.column.columnDef.cell,
 | 
				
			||||||
 | 
					                        cell.getContext()
 | 
				
			||||||
 | 
					                      )}
 | 
				
			||||||
 | 
					                    </TableCell>
 | 
				
			||||||
 | 
					                  ))}
 | 
				
			||||||
 | 
					                </TableRow>
 | 
				
			||||||
 | 
					              ))
 | 
				
			||||||
 | 
					            ) : (
 | 
				
			||||||
 | 
					              <TableRow>
 | 
				
			||||||
 | 
					                <TableCell
 | 
				
			||||||
 | 
					                  colSpan={columns.length}
 | 
				
			||||||
 | 
					                  className="h-24 text-center"
 | 
				
			||||||
 | 
					                >
 | 
				
			||||||
 | 
					                  No results.
 | 
				
			||||||
 | 
					                </TableCell>
 | 
				
			||||||
 | 
					              </TableRow>
 | 
				
			||||||
 | 
					            )}
 | 
				
			||||||
 | 
					          </TableBody>
 | 
				
			||||||
 | 
					        </Table>
 | 
				
			||||||
 | 
					      </ScrollArea>
 | 
				
			||||||
 | 
					      <div className="flex flex-row items-center justify-end space-x-2 py-4">
 | 
				
			||||||
 | 
					        <div className="flex-1 text-sm text-muted-foreground">
 | 
				
			||||||
 | 
					          {table.getFilteredSelectedRowModel().rows.length} of{" "}
 | 
				
			||||||
 | 
					          {table.getFilteredRowModel().rows.length} row(s) selected.
 | 
				
			||||||
 | 
					        </div>
 | 
				
			||||||
 | 
					        <div className="space-x-2">
 | 
				
			||||||
 | 
					          <Button
 | 
				
			||||||
 | 
					            variant="outline"
 | 
				
			||||||
 | 
					            size="sm"
 | 
				
			||||||
 | 
					            onClick={() => table.previousPage()}
 | 
				
			||||||
 | 
					            disabled={!table.getCanPreviousPage()}
 | 
				
			||||||
 | 
					          >
 | 
				
			||||||
 | 
					            Previous
 | 
				
			||||||
 | 
					          </Button>
 | 
				
			||||||
 | 
					          <Button
 | 
				
			||||||
 | 
					            variant="outline"
 | 
				
			||||||
 | 
					            size="sm"
 | 
				
			||||||
 | 
					            onClick={() => table.nextPage()}
 | 
				
			||||||
 | 
					            disabled={!table.getCanNextPage()}
 | 
				
			||||||
 | 
					          >
 | 
				
			||||||
 | 
					            Next
 | 
				
			||||||
 | 
					          </Button>
 | 
				
			||||||
 | 
					        </div>
 | 
				
			||||||
 | 
					      </div>
 | 
				
			||||||
 | 
					    </div>
 | 
				
			||||||
 | 
					  );
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@ -34,6 +34,10 @@ export function NavbarMenu({ className }: { className?: string }) {
 | 
				
			|||||||
      name: "API Keys",
 | 
					      name: "API Keys",
 | 
				
			||||||
      path: "/api-keys",
 | 
					      path: "/api-keys",
 | 
				
			||||||
    },
 | 
					    },
 | 
				
			||||||
 | 
					    {
 | 
				
			||||||
 | 
					      name: "Storage",
 | 
				
			||||||
 | 
					      path: "/storage",
 | 
				
			||||||
 | 
					    },
 | 
				
			||||||
  ];
 | 
					  ];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  return (
 | 
					  return (
 | 
				
			||||||
@ -42,9 +46,9 @@ export function NavbarMenu({ className }: { className?: string }) {
 | 
				
			|||||||
      {isDesktop && (
 | 
					      {isDesktop && (
 | 
				
			||||||
        <Tabs
 | 
					        <Tabs
 | 
				
			||||||
          defaultValue={pathname}
 | 
					          defaultValue={pathname}
 | 
				
			||||||
          className="w-[300px] flex pointer-events-auto"
 | 
					          className="w-[400px] flex pointer-events-auto"
 | 
				
			||||||
        >
 | 
					        >
 | 
				
			||||||
          <TabsList className="grid w-full grid-cols-3">
 | 
					          <TabsList className="grid w-full grid-cols-4">
 | 
				
			||||||
            {pages.map((page) => (
 | 
					            {pages.map((page) => (
 | 
				
			||||||
              <TabsTrigger
 | 
					              <TabsTrigger
 | 
				
			||||||
                key={page.name}
 | 
					                key={page.name}
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										86
									
								
								web/src/components/custom-form/checkpoint-input.tsx
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										86
									
								
								web/src/components/custom-form/checkpoint-input.tsx
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,86 @@
 | 
				
			|||||||
 | 
					// NOTE: this is work in progress
 | 
				
			||||||
 | 
					import type { AutoFormInputComponentProps } from "../ui/auto-form/types";
 | 
				
			||||||
 | 
					import { FormControl, FormItem, FormLabel } from "../ui/form";
 | 
				
			||||||
 | 
					import { LoadingIcon } from "@/components/LoadingIcon";
 | 
				
			||||||
 | 
					import * as React from "react";
 | 
				
			||||||
 | 
					import AutoFormInput from "../ui/auto-form/fields/input";
 | 
				
			||||||
 | 
					import { useDebouncedCallback } from "use-debounce";
 | 
				
			||||||
 | 
					import { CivitaiModelResponse } from "@/types/civitai";
 | 
				
			||||||
 | 
					import { z } from "zod";
 | 
				
			||||||
 | 
					import { insertCivitaiCheckpointSchema } from "@/db/schema";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					function getUrl(civitai_url: string) {
 | 
				
			||||||
 | 
					  // expect to be a URL to be https://civitai.com/models/36520
 | 
				
			||||||
 | 
					  // possiblity with slugged name and query-param modelVersionId
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  const baseUrl = "https://civitai.com/api/v1/models/";
 | 
				
			||||||
 | 
					  const url = new URL(civitai_url);
 | 
				
			||||||
 | 
					  const pathSegments = url.pathname.split("/");
 | 
				
			||||||
 | 
					  const modelId = pathSegments[pathSegments.indexOf("models") + 1];
 | 
				
			||||||
 | 
					  const modelVersionId = url.searchParams.get("modelVersionId");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  return { url: baseUrl + modelId, modelVersionId };
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					export default function AutoFormCheckpointInput(
 | 
				
			||||||
 | 
					  props: AutoFormInputComponentProps
 | 
				
			||||||
 | 
					) {
 | 
				
			||||||
 | 
					  const [loading, setLoading] = React.useState(false);
 | 
				
			||||||
 | 
					  const [modelRes, setModelRes] =
 | 
				
			||||||
 | 
					    React.useState<z.infer<typeof CivitaiModelResponse>>();
 | 
				
			||||||
 | 
					  const [modelVersionid, setModelVersionId] = React.useState<string | null>();
 | 
				
			||||||
 | 
					  const { label, isRequired, fieldProps, zodItem, fieldConfigItem } = props;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  const handleSearch = useDebouncedCallback((search) => {
 | 
				
			||||||
 | 
					    const validationResult =
 | 
				
			||||||
 | 
					      insertCivitaiCheckpointSchema.shape.civitai_url.safeParse(search);
 | 
				
			||||||
 | 
					    if (!validationResult.success) {
 | 
				
			||||||
 | 
					      console.error(validationResult.error);
 | 
				
			||||||
 | 
					      // Optionally set an error state here
 | 
				
			||||||
 | 
					      return;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    setLoading(true);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    const controller = new AbortController();
 | 
				
			||||||
 | 
					    const { url, modelVersionId: versionId } = getUrl(search);
 | 
				
			||||||
 | 
					    setModelVersionId(versionId);
 | 
				
			||||||
 | 
					    fetch(url, {
 | 
				
			||||||
 | 
					      signal: controller.signal,
 | 
				
			||||||
 | 
					    })
 | 
				
			||||||
 | 
					      .then((x) => x.json())
 | 
				
			||||||
 | 
					      .then((a) => {
 | 
				
			||||||
 | 
					        const res = CivitaiModelResponse.parse(a);
 | 
				
			||||||
 | 
					        console.log(a);
 | 
				
			||||||
 | 
					        console.log(res);
 | 
				
			||||||
 | 
					        setModelRes(res);
 | 
				
			||||||
 | 
					        setLoading(false);
 | 
				
			||||||
 | 
					      });
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return () => {
 | 
				
			||||||
 | 
					      controller.abort();
 | 
				
			||||||
 | 
					      setLoading(false);
 | 
				
			||||||
 | 
					    };
 | 
				
			||||||
 | 
					  }, 300);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  const modifiedField = {
 | 
				
			||||||
 | 
					    ...fieldProps,
 | 
				
			||||||
 | 
					    // onChange: (event: React.ChangeEvent<HTMLInputElement>) => {
 | 
				
			||||||
 | 
					    //   handleSearch(event.target.value);
 | 
				
			||||||
 | 
					    // },
 | 
				
			||||||
 | 
					  };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  return (
 | 
				
			||||||
 | 
					    <FormItem>
 | 
				
			||||||
 | 
					      {fieldConfigItem.inputProps?.showLabel && (
 | 
				
			||||||
 | 
					        <FormLabel>
 | 
				
			||||||
 | 
					          {label}
 | 
				
			||||||
 | 
					          {isRequired && <span className="text-destructive">*</span>}
 | 
				
			||||||
 | 
					        </FormLabel>
 | 
				
			||||||
 | 
					      )}
 | 
				
			||||||
 | 
					      <FormControl>
 | 
				
			||||||
 | 
					        <AutoFormInput {...props} fieldProps={modifiedField} />
 | 
				
			||||||
 | 
					      </FormControl>
 | 
				
			||||||
 | 
					    </FormItem>
 | 
				
			||||||
 | 
					  );
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@ -1,3 +1,4 @@
 | 
				
			|||||||
 | 
					import { CivitaiModelResponse } from "@/types/civitai";
 | 
				
			||||||
import { type InferSelectModel, relations } from "drizzle-orm";
 | 
					import { type InferSelectModel, relations } from "drizzle-orm";
 | 
				
			||||||
import {
 | 
					import {
 | 
				
			||||||
  boolean,
 | 
					  boolean,
 | 
				
			||||||
@ -90,7 +91,7 @@ export const workflowVersionRelations = relations(
 | 
				
			|||||||
      fields: [workflowVersionTable.workflow_id],
 | 
					      fields: [workflowVersionTable.workflow_id],
 | 
				
			||||||
      references: [workflowTable.id],
 | 
					      references: [workflowTable.id],
 | 
				
			||||||
    }),
 | 
					    }),
 | 
				
			||||||
  }),
 | 
					  })
 | 
				
			||||||
);
 | 
					);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
export const workflowRunStatus = pgEnum("workflow_run_status", [
 | 
					export const workflowRunStatus = pgEnum("workflow_run_status", [
 | 
				
			||||||
@ -139,7 +140,7 @@ export const workflowRunsTable = dbSchema.table("workflow_runs", {
 | 
				
			|||||||
    () => workflowVersionTable.id,
 | 
					    () => workflowVersionTable.id,
 | 
				
			||||||
    {
 | 
					    {
 | 
				
			||||||
      onDelete: "set null",
 | 
					      onDelete: "set null",
 | 
				
			||||||
    },
 | 
					    }
 | 
				
			||||||
  ),
 | 
					  ),
 | 
				
			||||||
  workflow_inputs:
 | 
					  workflow_inputs:
 | 
				
			||||||
    jsonb("workflow_inputs").$type<Record<string, string | number>>(),
 | 
					    jsonb("workflow_inputs").$type<Record<string, string | number>>(),
 | 
				
			||||||
@ -175,7 +176,7 @@ export const workflowRunRelations = relations(
 | 
				
			|||||||
      fields: [workflowRunsTable.workflow_id],
 | 
					      fields: [workflowRunsTable.workflow_id],
 | 
				
			||||||
      references: [workflowTable.id],
 | 
					      references: [workflowTable.id],
 | 
				
			||||||
    }),
 | 
					    }),
 | 
				
			||||||
  }),
 | 
					  })
 | 
				
			||||||
);
 | 
					);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// We still want to keep the workflow run record.
 | 
					// We still want to keep the workflow run record.
 | 
				
			||||||
@ -199,7 +200,7 @@ export const workflowOutputRelations = relations(
 | 
				
			|||||||
      fields: [workflowRunOutputs.run_id],
 | 
					      fields: [workflowRunOutputs.run_id],
 | 
				
			||||||
      references: [workflowRunsTable.id],
 | 
					      references: [workflowRunsTable.id],
 | 
				
			||||||
    }),
 | 
					    }),
 | 
				
			||||||
  }),
 | 
					  })
 | 
				
			||||||
);
 | 
					);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// when user delete, also delete all the workflow versions
 | 
					// when user delete, also delete all the workflow versions
 | 
				
			||||||
@ -232,7 +233,7 @@ export const snapshotType = z.object({
 | 
				
			|||||||
    z.object({
 | 
					    z.object({
 | 
				
			||||||
      hash: z.string(),
 | 
					      hash: z.string(),
 | 
				
			||||||
      disabled: z.boolean(),
 | 
					      disabled: z.boolean(),
 | 
				
			||||||
    }),
 | 
					    })
 | 
				
			||||||
  ),
 | 
					  ),
 | 
				
			||||||
  file_custom_nodes: z.array(z.any()),
 | 
					  file_custom_nodes: z.array(z.any()),
 | 
				
			||||||
});
 | 
					});
 | 
				
			||||||
@ -247,7 +248,7 @@ export const showcaseMedia = z.array(
 | 
				
			|||||||
  z.object({
 | 
					  z.object({
 | 
				
			||||||
    url: z.string(),
 | 
					    url: z.string(),
 | 
				
			||||||
    isCover: z.boolean().default(false),
 | 
					    isCover: z.boolean().default(false),
 | 
				
			||||||
  }),
 | 
					  })
 | 
				
			||||||
);
 | 
					);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
export const showcaseMediaNullable = z
 | 
					export const showcaseMediaNullable = z
 | 
				
			||||||
@ -255,7 +256,7 @@ export const showcaseMediaNullable = z
 | 
				
			|||||||
    z.object({
 | 
					    z.object({
 | 
				
			||||||
      url: z.string(),
 | 
					      url: z.string(),
 | 
				
			||||||
      isCover: z.boolean().default(false),
 | 
					      isCover: z.boolean().default(false),
 | 
				
			||||||
    }),
 | 
					    })
 | 
				
			||||||
  )
 | 
					  )
 | 
				
			||||||
  .nullable();
 | 
					  .nullable();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -344,8 +345,109 @@ export const authRequestsTable = dbSchema.table("auth_requests", {
 | 
				
			|||||||
  updated_at: timestamp("updated_at").defaultNow().notNull(),
 | 
					  updated_at: timestamp("updated_at").defaultNow().notNull(),
 | 
				
			||||||
});
 | 
					});
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					export const resourceUpload = pgEnum("resource_upload", [
 | 
				
			||||||
 | 
					  "started",
 | 
				
			||||||
 | 
					  "success",
 | 
				
			||||||
 | 
					  "failed",
 | 
				
			||||||
 | 
					]);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					export const modelUploadType = pgEnum("model_upload_type", [
 | 
				
			||||||
 | 
					  "civitai",
 | 
				
			||||||
 | 
					  "huggingface",
 | 
				
			||||||
 | 
					  "other",
 | 
				
			||||||
 | 
					]);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					export const checkpointTable = dbSchema.table("checkpoints", {
 | 
				
			||||||
 | 
					  id: uuid("id").primaryKey().defaultRandom().notNull(),
 | 
				
			||||||
 | 
					  user_id: text("user_id").references(() => usersTable.id, {}), // perhaps a "special" user_id for global checkpoints
 | 
				
			||||||
 | 
					  org_id: text("org_id"),
 | 
				
			||||||
 | 
					  description: text("description"),
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  checkpoint_volume_id: uuid("checkpoint_volume_id")
 | 
				
			||||||
 | 
					    .notNull()
 | 
				
			||||||
 | 
					    .references(() => checkpointVolumeTable.id, {
 | 
				
			||||||
 | 
					      onDelete: "cascade",
 | 
				
			||||||
 | 
					    })
 | 
				
			||||||
 | 
					    .notNull(),
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  model_name: text("model_name"),
 | 
				
			||||||
 | 
					  folder_path: text("folder_path"), // in volume
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  civitai_id: text("civitai_id"),
 | 
				
			||||||
 | 
					  civitai_version_id: text("civitai_version_id"),
 | 
				
			||||||
 | 
					  civitai_url: text("civitai_url"),
 | 
				
			||||||
 | 
					  civitai_download_url: text("civitai_download_url"),
 | 
				
			||||||
 | 
					  civitai_model_response: jsonb("civitai_model_response").$type<
 | 
				
			||||||
 | 
					    z.infer<typeof CivitaiModelResponse>
 | 
				
			||||||
 | 
					  >(),
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  hf_url: text("hf_url"),
 | 
				
			||||||
 | 
					  s3_url: text("s3_url"),
 | 
				
			||||||
 | 
					  user_url: text("client_url"),
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  is_public: boolean("is_public").notNull().default(false),
 | 
				
			||||||
 | 
					  status: resourceUpload("status").notNull().default("started"),
 | 
				
			||||||
 | 
					  upload_machine_id: text("upload_machine_id"),
 | 
				
			||||||
 | 
					  upload_type: modelUploadType("upload_type").notNull(),
 | 
				
			||||||
 | 
					  error_log: text("error_log"),
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  created_at: timestamp("created_at").defaultNow().notNull(),
 | 
				
			||||||
 | 
					  updated_at: timestamp("updated_at").defaultNow().notNull(),
 | 
				
			||||||
 | 
					});
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					export const insertCivitaiCheckpointSchema = createInsertSchema(
 | 
				
			||||||
 | 
					  checkpointTable,
 | 
				
			||||||
 | 
					  {
 | 
				
			||||||
 | 
					    civitai_url: (schema) =>
 | 
				
			||||||
 | 
					      schema.civitai_url
 | 
				
			||||||
 | 
					        .trim()
 | 
				
			||||||
 | 
					        .url({ message: "URL required" })
 | 
				
			||||||
 | 
					        .includes("civitai.com/models", {
 | 
				
			||||||
 | 
					          message: "civitai.com/models link required",
 | 
				
			||||||
 | 
					        }),
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					export const checkpointVolumeTable = dbSchema.table("checkpoint_volume", {
 | 
				
			||||||
 | 
					  id: uuid("id").primaryKey().defaultRandom().notNull(),
 | 
				
			||||||
 | 
					  user_id: text("user_id").references(() => usersTable.id, {
 | 
				
			||||||
 | 
					    // onDelete: "cascade",
 | 
				
			||||||
 | 
					  }),
 | 
				
			||||||
 | 
					  org_id: text("org_id"),
 | 
				
			||||||
 | 
					  volume_name: text("volume_name").notNull(),
 | 
				
			||||||
 | 
					  created_at: timestamp("created_at").defaultNow().notNull(),
 | 
				
			||||||
 | 
					  updated_at: timestamp("updated_at").defaultNow().notNull(),
 | 
				
			||||||
 | 
					  disabled: boolean("disabled").default(false).notNull(),
 | 
				
			||||||
 | 
					});
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					export const checkpointRelations = relations(checkpointTable, ({ one }) => ({
 | 
				
			||||||
 | 
					  user: one(usersTable, {
 | 
				
			||||||
 | 
					    fields: [checkpointTable.user_id],
 | 
				
			||||||
 | 
					    references: [usersTable.id],
 | 
				
			||||||
 | 
					  }),
 | 
				
			||||||
 | 
					  volume: one(checkpointVolumeTable, {
 | 
				
			||||||
 | 
					    fields: [checkpointTable.checkpoint_volume_id],
 | 
				
			||||||
 | 
					    references: [checkpointVolumeTable.id],
 | 
				
			||||||
 | 
					  }),
 | 
				
			||||||
 | 
					}));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					export const checkpointVolumeRelations = relations(
 | 
				
			||||||
 | 
					  checkpointVolumeTable,
 | 
				
			||||||
 | 
					  ({ many, one }) => ({
 | 
				
			||||||
 | 
					    checkpoint: many(checkpointTable),
 | 
				
			||||||
 | 
					    user: one(usersTable, {
 | 
				
			||||||
 | 
					      fields: [checkpointVolumeTable.user_id],
 | 
				
			||||||
 | 
					      references: [usersTable.id],
 | 
				
			||||||
 | 
					    }),
 | 
				
			||||||
 | 
					  })
 | 
				
			||||||
 | 
					);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
export type UserType = InferSelectModel<typeof usersTable>;
 | 
					export type UserType = InferSelectModel<typeof usersTable>;
 | 
				
			||||||
export type WorkflowType = InferSelectModel<typeof workflowTable>;
 | 
					export type WorkflowType = InferSelectModel<typeof workflowTable>;
 | 
				
			||||||
export type MachineType = InferSelectModel<typeof machinesTable>;
 | 
					export type MachineType = InferSelectModel<typeof machinesTable>;
 | 
				
			||||||
export type WorkflowVersionType = InferSelectModel<typeof workflowVersionTable>;
 | 
					export type WorkflowVersionType = InferSelectModel<typeof workflowVersionTable>;
 | 
				
			||||||
export type DeploymentType = InferSelectModel<typeof deploymentsTable>;
 | 
					export type DeploymentType = InferSelectModel<typeof deploymentsTable>;
 | 
				
			||||||
 | 
					export type CheckpointType = InferSelectModel<typeof checkpointTable>;
 | 
				
			||||||
 | 
					export type CheckpointVolumeType = InferSelectModel<
 | 
				
			||||||
 | 
					  typeof checkpointVolumeTable
 | 
				
			||||||
 | 
					>;
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										5
									
								
								web/src/server/addCheckpointSchema.tsx
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								web/src/server/addCheckpointSchema.tsx
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,5 @@
 | 
				
			|||||||
 | 
					import { insertCivitaiCheckpointSchema } from "@/db/schema";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					export const addCivitaiCheckpointSchema = insertCivitaiCheckpointSchema.pick({
 | 
				
			||||||
 | 
					  civitai_url: true,
 | 
				
			||||||
 | 
					});
 | 
				
			||||||
							
								
								
									
										252
									
								
								web/src/server/curdCheckpoint.ts
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										252
									
								
								web/src/server/curdCheckpoint.ts
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,252 @@
 | 
				
			|||||||
 | 
					"use server";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import { auth } from "@clerk/nextjs";
 | 
				
			||||||
 | 
					import {
 | 
				
			||||||
 | 
					  checkpointTable,
 | 
				
			||||||
 | 
					  CheckpointType,
 | 
				
			||||||
 | 
					  checkpointVolumeTable,
 | 
				
			||||||
 | 
					  CheckpointVolumeType,
 | 
				
			||||||
 | 
					} from "@/db/schema";
 | 
				
			||||||
 | 
					import { withServerPromise } from "./withServerPromise";
 | 
				
			||||||
 | 
					import { redirect } from "next/navigation";
 | 
				
			||||||
 | 
					import { db } from "@/db/db";
 | 
				
			||||||
 | 
					import type { z } from "zod";
 | 
				
			||||||
 | 
					import { headers } from "next/headers";
 | 
				
			||||||
 | 
					import { addCivitaiCheckpointSchema } from "./addCheckpointSchema";
 | 
				
			||||||
 | 
					import { and, eq, isNull } from "drizzle-orm";
 | 
				
			||||||
 | 
					import { CivitaiModelResponse } from "@/types/civitai";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					export async function getCheckpoints() {
 | 
				
			||||||
 | 
					  const { userId, orgId } = auth();
 | 
				
			||||||
 | 
					  if (!userId) throw new Error("No user id");
 | 
				
			||||||
 | 
					  const checkpoints = await db
 | 
				
			||||||
 | 
					    .select()
 | 
				
			||||||
 | 
					    .from(checkpointTable)
 | 
				
			||||||
 | 
					    .where(
 | 
				
			||||||
 | 
					      orgId
 | 
				
			||||||
 | 
					        ? eq(checkpointTable.org_id, orgId)
 | 
				
			||||||
 | 
					        // make sure org_id is null
 | 
				
			||||||
 | 
					        : and(
 | 
				
			||||||
 | 
					          eq(checkpointTable.user_id, userId),
 | 
				
			||||||
 | 
					          isNull(checkpointTable.org_id),
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					    );
 | 
				
			||||||
 | 
					  return checkpoints;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					export async function getCheckpointById(id: string) {
 | 
				
			||||||
 | 
					  const { userId, orgId } = auth();
 | 
				
			||||||
 | 
					  if (!userId) throw new Error("No user id");
 | 
				
			||||||
 | 
					  const checkpoint = await db
 | 
				
			||||||
 | 
					    .select()
 | 
				
			||||||
 | 
					    .from(checkpointTable)
 | 
				
			||||||
 | 
					    .where(
 | 
				
			||||||
 | 
					      and(
 | 
				
			||||||
 | 
					        orgId ? eq(checkpointTable.org_id, orgId) : and(
 | 
				
			||||||
 | 
					          eq(checkpointTable.user_id, userId),
 | 
				
			||||||
 | 
					          isNull(checkpointTable.org_id),
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					        eq(checkpointTable.id, id),
 | 
				
			||||||
 | 
					      ),
 | 
				
			||||||
 | 
					    );
 | 
				
			||||||
 | 
					  return checkpoint[0];
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					export async function getCheckpointVolumes() {
 | 
				
			||||||
 | 
					  const { userId, orgId } = auth();
 | 
				
			||||||
 | 
					  if (!userId) throw new Error("No user id");
 | 
				
			||||||
 | 
					  const volume = await db
 | 
				
			||||||
 | 
					    .select()
 | 
				
			||||||
 | 
					    .from(checkpointVolumeTable)
 | 
				
			||||||
 | 
					    .where(
 | 
				
			||||||
 | 
					      and(
 | 
				
			||||||
 | 
					        orgId
 | 
				
			||||||
 | 
					          ? eq(checkpointVolumeTable.org_id, orgId)
 | 
				
			||||||
 | 
					          // make sure org_id is null
 | 
				
			||||||
 | 
					          : and(
 | 
				
			||||||
 | 
					            eq(checkpointVolumeTable.user_id, userId),
 | 
				
			||||||
 | 
					            isNull(checkpointVolumeTable.org_id),
 | 
				
			||||||
 | 
					          ),
 | 
				
			||||||
 | 
					        eq(checkpointVolumeTable.disabled, false),
 | 
				
			||||||
 | 
					      ),
 | 
				
			||||||
 | 
					    );
 | 
				
			||||||
 | 
					  return volume;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					export async function retrieveCheckpointVolumes() {
 | 
				
			||||||
 | 
					  let volumes = await getCheckpointVolumes()
 | 
				
			||||||
 | 
					  if (volumes.length === 0) {
 | 
				
			||||||
 | 
					    // create volume if not already created
 | 
				
			||||||
 | 
					    volumes = await addCheckpointVolume()
 | 
				
			||||||
 | 
					  } 
 | 
				
			||||||
 | 
					  return volumes
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					export async function addCheckpointVolume() {
 | 
				
			||||||
 | 
					  const { userId, orgId } = auth();
 | 
				
			||||||
 | 
					  if (!userId) throw new Error("No user id");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Insert the new checkpointVolume into the checkpointVolumeTable
 | 
				
			||||||
 | 
					  const insertedVolume = await db
 | 
				
			||||||
 | 
					    .insert(checkpointVolumeTable)
 | 
				
			||||||
 | 
					    .values({
 | 
				
			||||||
 | 
					      user_id: userId,
 | 
				
			||||||
 | 
					      org_id: orgId,
 | 
				
			||||||
 | 
					      volume_name: `checkpoints_${userId}`,
 | 
				
			||||||
 | 
					      // created_at and updated_at will be set to current timestamp by default
 | 
				
			||||||
 | 
					      disabled: false, // Default value
 | 
				
			||||||
 | 
					    })
 | 
				
			||||||
 | 
					    .returning(); // Returns the inserted row
 | 
				
			||||||
 | 
					return insertedVolume;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					function getUrl(civitai_url: string) {
 | 
				
			||||||
 | 
					  // expect to be a URL to be https://civitai.com/models/36520
 | 
				
			||||||
 | 
					  // possiblity with slugged name and query-param modelVersionId
 | 
				
			||||||
 | 
					  const baseUrl = "https://civitai.com/api/v1/models/";
 | 
				
			||||||
 | 
					  const url = new URL(civitai_url);
 | 
				
			||||||
 | 
					  const pathSegments = url.pathname.split("/");
 | 
				
			||||||
 | 
					  const modelId = pathSegments[pathSegments.indexOf("models") + 1];
 | 
				
			||||||
 | 
					  const modelVersionId = url.searchParams.get("modelVersionId");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  return { url: baseUrl + modelId, modelVersionId };
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					export const addCivitaiCheckpoint = withServerPromise(
 | 
				
			||||||
 | 
					  async (data: z.infer<typeof addCivitaiCheckpointSchema>) => {
 | 
				
			||||||
 | 
					    const { userId, orgId } = auth();
 | 
				
			||||||
 | 
					    console.log("1");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if (!data.civitai_url) return { error: "no civitai_url" };
 | 
				
			||||||
 | 
					    console.log("2");
 | 
				
			||||||
 | 
					    if (!userId) return { error: "No user id" };
 | 
				
			||||||
 | 
					    console.log("3");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    const { url, modelVersionId } = getUrl(data?.civitai_url);
 | 
				
			||||||
 | 
					    console.log("4", url, modelVersionId);
 | 
				
			||||||
 | 
					    const civitaiModelRes = await fetch(url)
 | 
				
			||||||
 | 
					      .then((x) => x.json())
 | 
				
			||||||
 | 
					      .then((a) => {
 | 
				
			||||||
 | 
					        console.log(a)
 | 
				
			||||||
 | 
					        return CivitaiModelResponse.parse(a);
 | 
				
			||||||
 | 
					      });
 | 
				
			||||||
 | 
					    console.log("5");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if (civitaiModelRes?.modelVersions?.length === 0) {
 | 
				
			||||||
 | 
					      console.log("6");
 | 
				
			||||||
 | 
					      return; // no versions to download
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    console.log("7");
 | 
				
			||||||
 | 
					    let selectedModelVersion;
 | 
				
			||||||
 | 
					    let selectedModelVersionId: string | null = modelVersionId;
 | 
				
			||||||
 | 
					    if (!selectedModelVersionId) {
 | 
				
			||||||
 | 
					      selectedModelVersion = civitaiModelRes.modelVersions[0];
 | 
				
			||||||
 | 
					      selectedModelVersionId = civitaiModelRes.modelVersions[0].id.toString();
 | 
				
			||||||
 | 
					    } else {
 | 
				
			||||||
 | 
					      selectedModelVersion = civitaiModelRes.modelVersions.find((version) =>
 | 
				
			||||||
 | 
					        version.id.toString() === selectedModelVersionId
 | 
				
			||||||
 | 
					      );
 | 
				
			||||||
 | 
					      if (!selectedModelVersion) {
 | 
				
			||||||
 | 
					        return; // version id is wrong
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					      selectedModelVersionId = selectedModelVersion?.id.toString();
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    console.log("8");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    const checkpointVolumes = await getCheckpointVolumes();
 | 
				
			||||||
 | 
					    console.log("9");
 | 
				
			||||||
 | 
					    let cVolume;
 | 
				
			||||||
 | 
					    if (checkpointVolumes.length === 0) {
 | 
				
			||||||
 | 
					      console.log("10");
 | 
				
			||||||
 | 
					      const volume = await addCheckpointVolume();
 | 
				
			||||||
 | 
					      console.log("11");
 | 
				
			||||||
 | 
					      cVolume = volume[0];
 | 
				
			||||||
 | 
					    } else {
 | 
				
			||||||
 | 
					      console.log("12");
 | 
				
			||||||
 | 
					      cVolume = checkpointVolumes[0];
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    console.log("13");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    const a = await db
 | 
				
			||||||
 | 
					      .insert(checkpointTable)
 | 
				
			||||||
 | 
					      .values({
 | 
				
			||||||
 | 
					        user_id: userId,
 | 
				
			||||||
 | 
					        org_id: orgId,
 | 
				
			||||||
 | 
					        upload_type: "civitai",
 | 
				
			||||||
 | 
					        model_name: selectedModelVersion.files[0].name,
 | 
				
			||||||
 | 
					        civitai_id: civitaiModelRes.id.toString(),
 | 
				
			||||||
 | 
					        civitai_version_id: selectedModelVersionId,
 | 
				
			||||||
 | 
					        civitai_url: data.civitai_url,
 | 
				
			||||||
 | 
					        civitai_download_url: selectedModelVersion.files[0].downloadUrl,
 | 
				
			||||||
 | 
					        civitai_model_response: civitaiModelRes,
 | 
				
			||||||
 | 
					        checkpoint_volume_id: cVolume.id,
 | 
				
			||||||
 | 
					      })
 | 
				
			||||||
 | 
					      .returning();
 | 
				
			||||||
 | 
					    console.log("14");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    const b = a[0];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    await uploadCheckpoint(data, b, cVolume);
 | 
				
			||||||
 | 
					    console.log("15");
 | 
				
			||||||
 | 
					    // redirect(`/checkpoints/${b.id}`);
 | 
				
			||||||
 | 
					  },
 | 
				
			||||||
 | 
					);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					async function uploadCheckpoint(
 | 
				
			||||||
 | 
					  data: z.infer<typeof addCivitaiCheckpointSchema>,
 | 
				
			||||||
 | 
					  c: CheckpointType,
 | 
				
			||||||
 | 
					  v: CheckpointVolumeType,
 | 
				
			||||||
 | 
					) {
 | 
				
			||||||
 | 
					  const headersList = headers();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  const domain = headersList.get("x-forwarded-host") || "";
 | 
				
			||||||
 | 
					  const protocol = headersList.get("x-forwarded-proto") || "";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  if (domain === "") {
 | 
				
			||||||
 | 
					    throw new Error("No domain");
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Call remote builder
 | 
				
			||||||
 | 
					  const result = await fetch(
 | 
				
			||||||
 | 
					    `${process.env.MODAL_BUILDER_URL!}/upload-volume`,
 | 
				
			||||||
 | 
					    {
 | 
				
			||||||
 | 
					      method: "POST",
 | 
				
			||||||
 | 
					      headers: {
 | 
				
			||||||
 | 
					        "Content-Type": "application/json",
 | 
				
			||||||
 | 
					      },
 | 
				
			||||||
 | 
					      body: JSON.stringify({
 | 
				
			||||||
 | 
					        download_url: c.civitai_download_url,
 | 
				
			||||||
 | 
					        volume_name: v.volume_name,
 | 
				
			||||||
 | 
					        volume_id: v.id,
 | 
				
			||||||
 | 
					        checkpoint_id: c.id,
 | 
				
			||||||
 | 
					        callback_url: `${protocol}://${domain}/api/volume-updated`,
 | 
				
			||||||
 | 
					        upload_type: "checkpoint"
 | 
				
			||||||
 | 
					      }),
 | 
				
			||||||
 | 
					    },
 | 
				
			||||||
 | 
					  );
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  if (!result.ok) {
 | 
				
			||||||
 | 
					    const error_log = await result.text();
 | 
				
			||||||
 | 
					    await db
 | 
				
			||||||
 | 
					      .update(checkpointTable)
 | 
				
			||||||
 | 
					      .set({
 | 
				
			||||||
 | 
					        ...data,
 | 
				
			||||||
 | 
					        status: "failed",
 | 
				
			||||||
 | 
					        error_log: error_log,
 | 
				
			||||||
 | 
					      })
 | 
				
			||||||
 | 
					      .where(eq(checkpointTable.id, c.id));
 | 
				
			||||||
 | 
					    throw new Error(`Error: ${result.statusText} ${error_log}`);
 | 
				
			||||||
 | 
					  } else {
 | 
				
			||||||
 | 
					    // setting the build machine id
 | 
				
			||||||
 | 
					    const json = await result.json();
 | 
				
			||||||
 | 
					    await db
 | 
				
			||||||
 | 
					      .update(checkpointTable)
 | 
				
			||||||
 | 
					      .set({
 | 
				
			||||||
 | 
					        ...data,
 | 
				
			||||||
 | 
					        upload_machine_id: json.build_machine_instance_id,
 | 
				
			||||||
 | 
					      })
 | 
				
			||||||
 | 
					      .where(eq(checkpointTable.id, c.id));
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@ -15,6 +15,7 @@ import { headers } from "next/headers";
 | 
				
			|||||||
import { redirect } from "next/navigation";
 | 
					import { redirect } from "next/navigation";
 | 
				
			||||||
import "server-only";
 | 
					import "server-only";
 | 
				
			||||||
import type { z } from "zod";
 | 
					import type { z } from "zod";
 | 
				
			||||||
 | 
					import { retrieveCheckpointVolumes } from "./curdCheckpoint";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
export async function getMachines() {
 | 
					export async function getMachines() {
 | 
				
			||||||
  const { userId, orgId } = auth();
 | 
					  const { userId, orgId } = auth();
 | 
				
			||||||
@ -189,6 +190,7 @@ async function _buildMachine(
 | 
				
			|||||||
    throw new Error("No domain");
 | 
					    throw new Error("No domain");
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  const volumes = await retrieveCheckpointVolumes();
 | 
				
			||||||
  // Call remote builder
 | 
					  // Call remote builder
 | 
				
			||||||
  const result = await fetch(`${process.env.MODAL_BUILDER_URL!}/create`, {
 | 
					  const result = await fetch(`${process.env.MODAL_BUILDER_URL!}/create`, {
 | 
				
			||||||
    method: "POST",
 | 
					    method: "POST",
 | 
				
			||||||
@ -202,6 +204,7 @@ async function _buildMachine(
 | 
				
			|||||||
      callback_url: `${protocol}://${domain}/api/machine-built`,
 | 
					      callback_url: `${protocol}://${domain}/api/machine-built`,
 | 
				
			||||||
      models: data.models, //JSON.parse(data.models as string),
 | 
					      models: data.models, //JSON.parse(data.models as string),
 | 
				
			||||||
      gpu: data.gpu && data.gpu.length > 0 ? data.gpu : "T4",
 | 
					      gpu: data.gpu && data.gpu.length > 0 ? data.gpu : "T4",
 | 
				
			||||||
 | 
					      checkpoint_volume_name: volumes[0].volume_name,
 | 
				
			||||||
    }),
 | 
					    }),
 | 
				
			||||||
  });
 | 
					  });
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										41
									
								
								web/src/server/getAllUserCheckpoints.tsx
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										41
									
								
								web/src/server/getAllUserCheckpoints.tsx
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,41 @@
 | 
				
			|||||||
 | 
					import { db } from "@/db/db";
 | 
				
			||||||
 | 
					import {
 | 
				
			||||||
 | 
					  checkpointTable,
 | 
				
			||||||
 | 
					} from "@/db/schema";
 | 
				
			||||||
 | 
					import { auth } from "@clerk/nextjs";
 | 
				
			||||||
 | 
					import { and, desc, eq, isNull } from "drizzle-orm";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					export async function getAllUserCheckpoints() {
 | 
				
			||||||
 | 
					  const { userId, orgId } = await auth();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  if (!userId) {
 | 
				
			||||||
 | 
					    return null;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  const checkpoints = await db.query.checkpointTable.findMany({
 | 
				
			||||||
 | 
					    with: {
 | 
				
			||||||
 | 
					      user: {
 | 
				
			||||||
 | 
					        columns: {
 | 
				
			||||||
 | 
					          name: true,
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					      },
 | 
				
			||||||
 | 
					    },
 | 
				
			||||||
 | 
					    columns: {
 | 
				
			||||||
 | 
					      id: true,
 | 
				
			||||||
 | 
					      updated_at: true,
 | 
				
			||||||
 | 
					      model_name: true,
 | 
				
			||||||
 | 
					      civitai_url: true,
 | 
				
			||||||
 | 
					      civitai_model_response: true,
 | 
				
			||||||
 | 
					      is_public: true,
 | 
				
			||||||
 | 
					      upload_type: true,
 | 
				
			||||||
 | 
					      status: true,
 | 
				
			||||||
 | 
					    },
 | 
				
			||||||
 | 
					    orderBy: desc(checkpointTable.updated_at),
 | 
				
			||||||
 | 
					    where: 
 | 
				
			||||||
 | 
					      orgId != undefined
 | 
				
			||||||
 | 
					        ? eq(checkpointTable.org_id, orgId)
 | 
				
			||||||
 | 
					        : and(eq(checkpointTable.user_id, userId), isNull(checkpointTable.org_id)),
 | 
				
			||||||
 | 
					  });
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  return checkpoints;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										129
									
								
								web/src/types/civitai.ts
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										129
									
								
								web/src/types/civitai.ts
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,129 @@
 | 
				
			|||||||
 | 
					import { z } from "zod";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// from chatgpt https://chat.openai.com/share/4985d20b-30b1-4a28-87f6-6ebf84a1040e
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					export const creatorSchema = z.object({
 | 
				
			||||||
 | 
					  username: z.string().nullish(),
 | 
				
			||||||
 | 
					  image: z.string().url().nullish(),
 | 
				
			||||||
 | 
					});
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					export const fileMetadataSchema = z.object({
 | 
				
			||||||
 | 
					  fp: z.string().nullish(),
 | 
				
			||||||
 | 
					  size: z.string().nullish(),
 | 
				
			||||||
 | 
					  format: z.string().nullish(),
 | 
				
			||||||
 | 
					});
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					export const fileSchema = z.object({
 | 
				
			||||||
 | 
					  id: z.number(),
 | 
				
			||||||
 | 
					  sizeKB: z.number().nullish(),
 | 
				
			||||||
 | 
					  name: z.string(),
 | 
				
			||||||
 | 
					  type: z.string().nullish(),
 | 
				
			||||||
 | 
					  metadata: fileMetadataSchema.nullish(),
 | 
				
			||||||
 | 
					  pickleScanResult: z.string().nullish(),
 | 
				
			||||||
 | 
					  pickleScanMessage: z.string().nullable(),
 | 
				
			||||||
 | 
					  virusScanResult: z.string().nullish(),
 | 
				
			||||||
 | 
					  virusScanMessage: z.string().nullable(),
 | 
				
			||||||
 | 
					  scannedAt: z.string().nullish(),
 | 
				
			||||||
 | 
					  hashes: z.record(z.string()).nullish(),
 | 
				
			||||||
 | 
					  downloadUrl: z.string().url(),
 | 
				
			||||||
 | 
					  primary: z.boolean().nullish(),
 | 
				
			||||||
 | 
					});
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					export const imageMetadataSchema = z.object({
 | 
				
			||||||
 | 
					  hash: z.string(),
 | 
				
			||||||
 | 
					  width: z.number(),
 | 
				
			||||||
 | 
					  height: z.number(),
 | 
				
			||||||
 | 
					});
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					export const imageMetaSchema = z.object({
 | 
				
			||||||
 | 
					  ENSD: z.string().nullish(),
 | 
				
			||||||
 | 
					  Size: z.string().nullish(),
 | 
				
			||||||
 | 
					  seed: z.number().nullish(),
 | 
				
			||||||
 | 
					  Model: z.string().nullish(),
 | 
				
			||||||
 | 
					  steps: z.number().nullish(),
 | 
				
			||||||
 | 
					  hashes: z.record(z.string()).nullish(),
 | 
				
			||||||
 | 
					  prompt: z.string().nullish(),
 | 
				
			||||||
 | 
					  sampler: z.string().nullish(),
 | 
				
			||||||
 | 
					  cfgScale: z.number().nullish(),
 | 
				
			||||||
 | 
					  ClipSkip: z.number().nullish(),
 | 
				
			||||||
 | 
					  resources: z.array(
 | 
				
			||||||
 | 
					    z.object({
 | 
				
			||||||
 | 
					      hash: z.string().nullish(),
 | 
				
			||||||
 | 
					      name: z.string(),
 | 
				
			||||||
 | 
					      type: z.string(),
 | 
				
			||||||
 | 
					      weight: z.number().nullish(),
 | 
				
			||||||
 | 
					    }),
 | 
				
			||||||
 | 
					  ).nullish(),
 | 
				
			||||||
 | 
					  ModelHash: z.string().nullish(),
 | 
				
			||||||
 | 
					  HiresSteps: z.string().nullish(),
 | 
				
			||||||
 | 
					  HiresUpscale: z.string().nullish(),
 | 
				
			||||||
 | 
					  HiresUpscaler: z.string().nullish(),
 | 
				
			||||||
 | 
					  negativePrompt: z.string(),
 | 
				
			||||||
 | 
					  DenoisingStrength: z.number().nullish(),
 | 
				
			||||||
 | 
					});
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// all over the place
 | 
				
			||||||
 | 
					// export const imageSchema = z.object({
 | 
				
			||||||
 | 
					//   url: z.string().url().nullish(),
 | 
				
			||||||
 | 
					//   nsfw: z.enum(["None", "Soft", "Mature"]).nullish(),
 | 
				
			||||||
 | 
					//   width: z.number().nullish(),
 | 
				
			||||||
 | 
					//   height: z.number().nullish(),
 | 
				
			||||||
 | 
					//   hash: z.string().nullish(),
 | 
				
			||||||
 | 
					//   type: z.string().nullish(),
 | 
				
			||||||
 | 
					//   metadata: imageMetadataSchema.nullish(),
 | 
				
			||||||
 | 
					//   meta: imageMetaSchema.nullish(),
 | 
				
			||||||
 | 
					// });
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					export const modelVersionSchema = z.object({
 | 
				
			||||||
 | 
					  id: z.number(),
 | 
				
			||||||
 | 
					  modelId: z.number(),
 | 
				
			||||||
 | 
					  name: z.string(),
 | 
				
			||||||
 | 
					  createdAt: z.string().nullish(),
 | 
				
			||||||
 | 
					  updatedAt: z.string().nullish(),
 | 
				
			||||||
 | 
					  // status: z.enum(["Published", "Unpublished"]).nullish(),
 | 
				
			||||||
 | 
					  status: z.string().nullish(),
 | 
				
			||||||
 | 
					  publishedAt: z.string().nullish(),
 | 
				
			||||||
 | 
					  trainedWords: z.array(z.string()).nullable(),
 | 
				
			||||||
 | 
					  trainingStatus: z.string().nullable(),
 | 
				
			||||||
 | 
					  trainingDetails: z.string().nullable(),
 | 
				
			||||||
 | 
					  baseModel: z.string().nullish(),
 | 
				
			||||||
 | 
					  baseModelType: z.string().nullish(),
 | 
				
			||||||
 | 
					  earlyAccessTimeFrame: z.number().nullish(),
 | 
				
			||||||
 | 
					  description: z.string().nullable(),
 | 
				
			||||||
 | 
					  vaeId: z.number().nullable(),
 | 
				
			||||||
 | 
					  stats: z.object({
 | 
				
			||||||
 | 
					    downloadCount: z.number(),
 | 
				
			||||||
 | 
					    ratingCount: z.number(),
 | 
				
			||||||
 | 
					    rating: z.number(),
 | 
				
			||||||
 | 
					  }).nullish(),
 | 
				
			||||||
 | 
					  files: z.array(fileSchema),
 | 
				
			||||||
 | 
					  images: z.array(z.any()).nullish(),
 | 
				
			||||||
 | 
					  downloadUrl: z.string().url(),
 | 
				
			||||||
 | 
					});
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					export const statsSchema = z.object({
 | 
				
			||||||
 | 
					  downloadCount: z.number(),
 | 
				
			||||||
 | 
					  favoriteCount: z.number(),
 | 
				
			||||||
 | 
					  commentCount: z.number(),
 | 
				
			||||||
 | 
					  ratingCount: z.number(),
 | 
				
			||||||
 | 
					  rating: z.number(),
 | 
				
			||||||
 | 
					  tippedAmountCount: z.number(),
 | 
				
			||||||
 | 
					});
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					export const CivitaiModelResponse = z.object({
 | 
				
			||||||
 | 
					  id: z.number(),
 | 
				
			||||||
 | 
					  name: z.string().nullish(),
 | 
				
			||||||
 | 
					  description: z.string().nullish(),
 | 
				
			||||||
 | 
					  // type: z.enum(["Checkpoint", "Lora"]), // TODO: this will be important to know
 | 
				
			||||||
 | 
					  type: z.string(),
 | 
				
			||||||
 | 
					  poi: z.boolean().nullish(),
 | 
				
			||||||
 | 
					  nsfw: z.boolean().nullish(),
 | 
				
			||||||
 | 
					  allowNoCredit: z.boolean().nullish(),
 | 
				
			||||||
 | 
					  allowCommercialUse: z.string().nullish(),
 | 
				
			||||||
 | 
					  allowDerivatives: z.boolean().nullish(),
 | 
				
			||||||
 | 
					  allowDifferentLicense: z.boolean().nullish(),
 | 
				
			||||||
 | 
					  stats: statsSchema.nullish(),
 | 
				
			||||||
 | 
					  creator: creatorSchema.nullish(),
 | 
				
			||||||
 | 
					  tags: z.array(z.string()).nullish(),
 | 
				
			||||||
 | 
					  modelVersions: z.array(modelVersionSchema),
 | 
				
			||||||
 | 
					});
 | 
				
			||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user