work
This commit is contained in:
		
							parent
							
								
									90cec6b778
								
							
						
					
					
						commit
						fed7b380b6
					
				@ -224,6 +224,20 @@ async def websocket_endpoint(websocket: WebSocket, machine_id: str):
 | 
			
		||||
#     return {"Hello": "World"}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class UploadBody(BaseModel):
 | 
			
		||||
    download_url: str
 | 
			
		||||
    volume_name: str
 | 
			
		||||
    # callback_url: str
 | 
			
		||||
 | 
			
		||||
@app.post("/upload_volume")
 | 
			
		||||
async def upload_checkpoint(body: UploadBody):
 | 
			
		||||
    download_url = body.download_url
 | 
			
		||||
    volume_name = body.download_url
 | 
			
		||||
    # callback_url = body.callback_url
 | 
			
		||||
    # check that thi
 | 
			
		||||
    return
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@app.post("/create")
 | 
			
		||||
async def create_machine(item: Item):
 | 
			
		||||
    global last_activity_time
 | 
			
		||||
 | 
			
		||||
@ -3,9 +3,9 @@ This is a standalone script to download models into a modal Volume using civitai
 | 
			
		||||
 | 
			
		||||
Example Usage
 | 
			
		||||
`modal run insert_models::insert_model --civitai-url https://civitai.com/models/36520/ghostmix`
 | 
			
		||||
This inserts an individual model from a civitai url (public not API url)
 | 
			
		||||
This inserts an individual model from a civitai url 
 | 
			
		||||
 | 
			
		||||
`modal run insert_models::insert_models` 
 | 
			
		||||
`modal run insert_models::insert_models_civitai_api` 
 | 
			
		||||
This inserts a bunch of models based on the models retrieved by civitai
 | 
			
		||||
 | 
			
		||||
civitai's API reference https://github.com/civitai/civitai/wiki/REST-API-Reference
 | 
			
		||||
@ -13,24 +13,21 @@ civitai's API reference https://github.com/civitai/civitai/wiki/REST-API-Referen
 | 
			
		||||
import modal
 | 
			
		||||
import subprocess
 | 
			
		||||
import requests
 | 
			
		||||
import json
 | 
			
		||||
 | 
			
		||||
stub = modal.Stub()
 | 
			
		||||
 | 
			
		||||
# NOTE: volume name can be variable
 | 
			
		||||
volume = modal.Volume.persisted("private-model-store")
 | 
			
		||||
volume = modal.Volume.persisted("rah")
 | 
			
		||||
model_store_path = "/vol/models"
 | 
			
		||||
MODEL_ROUTE = "models"
 | 
			
		||||
image = (
 | 
			
		||||
    modal.Image.debian_slim().apt_install("wget").pip_install("requests")
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@stub.function(volumes={model_store_path: volume}, gpu="any", image=image, timeout=600)
 | 
			
		||||
def download_model(model):
 | 
			
		||||
    # wget https://civitai.com/api/download/models/{modelVersionId} --content-disposition
 | 
			
		||||
    # model_id = model['modelVersions'][0]['id']
 | 
			
		||||
    # download_url = f"https://civitai.com/api/download/models/{model_id}"
 | 
			
		||||
 | 
			
		||||
    download_url = model['modelVersions'][0]['downloadUrl'] 
 | 
			
		||||
@stub.function(volumes={model_store_path: volume}, image=image, timeout=50000, gpu=None)
 | 
			
		||||
def download_model(download_url):
 | 
			
		||||
    print(download_url)
 | 
			
		||||
    subprocess.run(["wget", download_url, "--content-disposition", "-P", model_store_path])
 | 
			
		||||
    subprocess.run(["ls", "-la", model_store_path])
 | 
			
		||||
    volume.commit()
 | 
			
		||||
@ -52,40 +49,53 @@ def get_civitai_models(model_type: str, sort: str = "Highest Rated", page: int =
 | 
			
		||||
@stub.function()
 | 
			
		||||
def get_civitai_model_url(civitai_url: str):
 | 
			
		||||
    # Validate the URL
 | 
			
		||||
    if not civitai_url.startswith("https://civitai.com/models/"):
 | 
			
		||||
        return "Error: URL must be from civitai.com and contain /models/"
 | 
			
		||||
 | 
			
		||||
    # Extract the model ID
 | 
			
		||||
    if civitai_url.startswith("https://civitai.com/api/"):
 | 
			
		||||
        api_url = civitai_url
 | 
			
		||||
    elif civitai_url.startswith("https://civitai.com/models/"):  
 | 
			
		||||
        try:
 | 
			
		||||
            model_id = civitai_url.split("/")[4]
 | 
			
		||||
        int(model_id)  # Check if the ID is an integer
 | 
			
		||||
            int(model_id) 
 | 
			
		||||
        except (IndexError, ValueError):
 | 
			
		||||
        return None #Error: Invalid model ID in URL
 | 
			
		||||
 | 
			
		||||
    # Make the API request
 | 
			
		||||
            return None 
 | 
			
		||||
        api_url = f"https://civitai.com/api/v1/models/{model_id}"
 | 
			
		||||
    response = requests.get(api_url)
 | 
			
		||||
    else:
 | 
			
		||||
        return "Error: URL must be from civitai.com and contain /models/"
 | 
			
		||||
 | 
			
		||||
    response = requests.get(api_url)
 | 
			
		||||
    # Check for successful response
 | 
			
		||||
    if response.status_code != 200:
 | 
			
		||||
        return f"Error: Unable to fetch data from {api_url}"
 | 
			
		||||
 | 
			
		||||
    # Return the response data
 | 
			
		||||
    return response.json()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@stub.local_entrypoint()
 | 
			
		||||
def insert_models(type: str = "Checkpoint", sort = "Highest Rated", page: int = 1):
 | 
			
		||||
def insert_models_civitai_api(type: str = "Checkpoint", sort = "Highest Rated", page: int = 1):
 | 
			
		||||
    civitai_models = get_civitai_models.local(type, sort, page)
 | 
			
		||||
    if civitai_models:
 | 
			
		||||
        for _ in download_model.map(civitai_models['items'][1:]):
 | 
			
		||||
        for _ in download_model.map(map(lambda model: model['modelVersions'][0]['downloadUrl'], civitai_models['items'])):
 | 
			
		||||
            pass
 | 
			
		||||
    else:
 | 
			
		||||
        print("Failed to retrieve models.")
 | 
			
		||||
 | 
			
		||||
@stub.local_entrypoint()
 | 
			
		||||
def insert_model(civitai_url: str):
 | 
			
		||||
    if civitai_url.startswith("'https://civitai.com/api/download/models/"):
 | 
			
		||||
        download_url = civitai_url
 | 
			
		||||
    else:
 | 
			
		||||
        civitai_model = get_civitai_model_url.local(civitai_url)
 | 
			
		||||
        if civitai_model:
 | 
			
		||||
        download_model.remote(civitai_model)
 | 
			
		||||
            download_url = civitai_model['modelVersions'][0]['downloadUrl']
 | 
			
		||||
        else:
 | 
			
		||||
            return "invalid URL"
 | 
			
		||||
 | 
			
		||||
    download_model.remote(download_url)
 | 
			
		||||
 | 
			
		||||
@stub.local_entrypoint()
 | 
			
		||||
def simple_download():
 | 
			
		||||
    download_urls = ['https://civitai.com/api/download/models/119057', 'https://civitai.com/api/download/models/130090', 'https://civitai.com/api/download/models/31859', 'https://civitai.com/api/download/models/128713', 'https://civitai.com/api/download/models/179657', 'https://civitai.com/api/download/models/143906', 'https://civitai.com/api/download/models/9208', 'https://civitai.com/api/download/models/136078', 'https://civitai.com/api/download/models/134065', 'https://civitai.com/api/download/models/288775', 'https://civitai.com/api/download/models/95263', 'https://civitai.com/api/download/models/288982', 'https://civitai.com/api/download/models/87153', 'https://civitai.com/api/download/models/10638', 'https://civitai.com/api/download/models/263809', 'https://civitai.com/api/download/models/130072', 'https://civitai.com/api/download/models/117019', 'https://civitai.com/api/download/models/95256', 'https://civitai.com/api/download/models/197181', 'https://civitai.com/api/download/models/256915', 'https://civitai.com/api/download/models/118945', 'https://civitai.com/api/download/models/125843', 'https://civitai.com/api/download/models/179015', 'https://civitai.com/api/download/models/245598', 'https://civitai.com/api/download/models/223670', 'https://civitai.com/api/download/models/90072', 'https://civitai.com/api/download/models/290817', 'https://civitai.com/api/download/models/154097', 'https://civitai.com/api/download/models/143497', 'https://civitai.com/api/download/models/5637']
 | 
			
		||||
 | 
			
		||||
    for _ in download_model.map(download_urls):
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										54
									
								
								builder/modal-builder/src/volume-builder/app.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										54
									
								
								builder/modal-builder/src/volume-builder/app.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,54 @@
 | 
			
		||||
import modal
 | 
			
		||||
from config import config
 | 
			
		||||
import os
 | 
			
		||||
import uuid
 | 
			
		||||
import subprocess
 | 
			
		||||
 | 
			
		||||
stub = modal.Stub()
 | 
			
		||||
 | 
			
		||||
base_path = "/volumes"
 | 
			
		||||
 | 
			
		||||
# Volume names may only contain alphanumeric characters, dashes, periods, and underscores, and must be less than 64 characters in length.
 | 
			
		||||
def is_valid_name(name: str) -> bool:
 | 
			
		||||
    allowed_characters = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-._")
 | 
			
		||||
    return 0 < len(name) <= 64 and all(char in allowed_characters for char in name)
 | 
			
		||||
 | 
			
		||||
def create_volumes(volume_names):
 | 
			
		||||
    path_to_vol = {}
 | 
			
		||||
    vol_to_path = {}
 | 
			
		||||
    for volume_name in volume_names.keys():
 | 
			
		||||
        if not is_valid_name(volume_name):
 | 
			
		||||
            pass
 | 
			
		||||
        modal_volume = modal.Volume.persisted(volume_name)
 | 
			
		||||
        volume_path = create_volume_path(base_path)
 | 
			
		||||
        path_to_vol[volume_path] = modal_volume
 | 
			
		||||
        vol_to_path[volume_name] = volume_path
 | 
			
		||||
 
 | 
			
		||||
    return (path_to_vol, vol_to_path)
 | 
			
		||||
 | 
			
		||||
def create_volume_path(base_path: str):
 | 
			
		||||
    random_path = str(uuid.uuid4())
 | 
			
		||||
    return os.path.join(base_path, random_path)
 | 
			
		||||
 | 
			
		||||
vol_name_to_links = config["volume_names"]
 | 
			
		||||
(path_to_vol, vol_name_to_path) = create_volumes(vol_name_to_links)
 | 
			
		||||
image = ( 
 | 
			
		||||
   modal.Image.debian_slim().apt_install("wget").pip_install("requests")
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
print(vol_name_to_links)
 | 
			
		||||
print(path_to_vol)
 | 
			
		||||
print(vol_name_to_path)
 | 
			
		||||
 | 
			
		||||
@stub.function(volumes=path_to_vol, image=image, timeout=5000, gpu=None)
 | 
			
		||||
def download_model(volume_name, download_url):
 | 
			
		||||
    model_store_path = vol_name_to_path[volume_name]
 | 
			
		||||
    subprocess.run(["wget", download_url, "--content-disposition", "-P", model_store_path])
 | 
			
		||||
    subprocess.run(["ls", "-la", model_store_path])
 | 
			
		||||
    path_to_vol[model_store_path].commit()
 | 
			
		||||
 | 
			
		||||
@stub.local_entrypoint()
 | 
			
		||||
def simple_download():
 | 
			
		||||
    print(vol_name_to_links)
 | 
			
		||||
    print([(vol_name, link) for vol_name,link in vol_name_to_links.items()])
 | 
			
		||||
    list(download_model.starmap([(vol_name, link) for vol_name,link in vol_name_to_links.items()]))
 | 
			
		||||
							
								
								
									
										5
									
								
								builder/modal-builder/src/volume-builder/config.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								builder/modal-builder/src/volume-builder/config.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,5 @@
 | 
			
		||||
config = {
 | 
			
		||||
    "volume_names": {
 | 
			
		||||
        "eg1": "https://pub-6230db03dc3a4861a9c3e55145ceda44.r2.dev/openpose-pose (1).png"
 | 
			
		||||
    }, 
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										62
									
								
								web/drizzle/0031_common_deathbird.sql
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										62
									
								
								web/drizzle/0031_common_deathbird.sql
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,62 @@
 | 
			
		||||
DO $$ BEGIN
 | 
			
		||||
 CREATE TYPE "model_upload_type" AS ENUM('civitai', 'huggingface', 'other');
 | 
			
		||||
EXCEPTION
 | 
			
		||||
 WHEN duplicate_object THEN null;
 | 
			
		||||
END $$;
 | 
			
		||||
--> statement-breakpoint
 | 
			
		||||
DO $$ BEGIN
 | 
			
		||||
 CREATE TYPE "resource_upload" AS ENUM('started', 'failed', 'succeded');
 | 
			
		||||
EXCEPTION
 | 
			
		||||
 WHEN duplicate_object THEN null;
 | 
			
		||||
END $$;
 | 
			
		||||
--> statement-breakpoint
 | 
			
		||||
CREATE TABLE IF NOT EXISTS "comfyui_deploy"."checkpoints" (
 | 
			
		||||
	"id" uuid PRIMARY KEY DEFAULT gen_random_uuid() NOT NULL,
 | 
			
		||||
	"user_id" text,
 | 
			
		||||
	"org_id" text,
 | 
			
		||||
	"description" text,
 | 
			
		||||
	"checkpoint_volume_id" uuid NOT NULL,
 | 
			
		||||
	"model_name" text,
 | 
			
		||||
	"civitai_id" text,
 | 
			
		||||
	"civitai_version_id" text,
 | 
			
		||||
	"civitai_url" text,
 | 
			
		||||
	"civitai_download_url" text,
 | 
			
		||||
	"civitai_model_response" jsonb,
 | 
			
		||||
	"hf_url" text,
 | 
			
		||||
	"s3_url" text,
 | 
			
		||||
	"client_url" text,
 | 
			
		||||
	"is_public" boolean DEFAULT false NOT NULL,
 | 
			
		||||
	"status" "resource_upload" DEFAULT 'started' NOT NULL,
 | 
			
		||||
	"upload_machine_id" text,
 | 
			
		||||
	"upload_type" "model_upload_type" NOT NULL,
 | 
			
		||||
	"created_at" timestamp DEFAULT now() NOT NULL,
 | 
			
		||||
	"updated_at" timestamp DEFAULT now() NOT NULL
 | 
			
		||||
);
 | 
			
		||||
--> statement-breakpoint
 | 
			
		||||
CREATE TABLE IF NOT EXISTS "comfyui_deploy"."checkpointVolumeTable" (
 | 
			
		||||
	"id" uuid PRIMARY KEY DEFAULT gen_random_uuid() NOT NULL,
 | 
			
		||||
	"user_id" text,
 | 
			
		||||
	"org_id" text,
 | 
			
		||||
	"volume_name" text NOT NULL,
 | 
			
		||||
	"created_at" timestamp DEFAULT now() NOT NULL,
 | 
			
		||||
	"updated_at" timestamp DEFAULT now() NOT NULL,
 | 
			
		||||
	"disabled" boolean DEFAULT false NOT NULL
 | 
			
		||||
);
 | 
			
		||||
--> statement-breakpoint
 | 
			
		||||
DO $$ BEGIN
 | 
			
		||||
 ALTER TABLE "comfyui_deploy"."checkpoints" ADD CONSTRAINT "checkpoints_user_id_users_id_fk" FOREIGN KEY ("user_id") REFERENCES "comfyui_deploy"."users"("id") ON DELETE no action ON UPDATE no action;
 | 
			
		||||
EXCEPTION
 | 
			
		||||
 WHEN duplicate_object THEN null;
 | 
			
		||||
END $$;
 | 
			
		||||
--> statement-breakpoint
 | 
			
		||||
DO $$ BEGIN
 | 
			
		||||
 ALTER TABLE "comfyui_deploy"."checkpoints" ADD CONSTRAINT "checkpoints_checkpoint_volume_id_workflow_runs_id_fk" FOREIGN KEY ("checkpoint_volume_id") REFERENCES "comfyui_deploy"."workflow_runs"("id") ON DELETE cascade ON UPDATE no action;
 | 
			
		||||
EXCEPTION
 | 
			
		||||
 WHEN duplicate_object THEN null;
 | 
			
		||||
END $$;
 | 
			
		||||
--> statement-breakpoint
 | 
			
		||||
DO $$ BEGIN
 | 
			
		||||
 ALTER TABLE "comfyui_deploy"."checkpointVolumeTable" ADD CONSTRAINT "checkpointVolumeTable_user_id_users_id_fk" FOREIGN KEY ("user_id") REFERENCES "comfyui_deploy"."users"("id") ON DELETE no action ON UPDATE no action;
 | 
			
		||||
EXCEPTION
 | 
			
		||||
 WHEN duplicate_object THEN null;
 | 
			
		||||
END $$;
 | 
			
		||||
							
								
								
									
										1004
									
								
								web/drizzle/meta/0031_snapshot.json
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1004
									
								
								web/drizzle/meta/0031_snapshot.json
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							@ -218,6 +218,13 @@
 | 
			
		||||
      "when": 1705716303820,
 | 
			
		||||
      "tag": "0030_kind_doorman",
 | 
			
		||||
      "breakpoints": true
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
      "idx": 31,
 | 
			
		||||
      "version": "5",
 | 
			
		||||
      "when": 1705963548821,
 | 
			
		||||
      "tag": "0031_common_deathbird",
 | 
			
		||||
      "breakpoints": true
 | 
			
		||||
    }
 | 
			
		||||
  ]
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										50
									
								
								web/src/app/(app)/api/volume-updated/route.ts
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										50
									
								
								web/src/app/(app)/api/volume-updated/route.ts
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,50 @@
 | 
			
		||||
import { parseDataSafe } from "../../../../lib/parseDataSafe";
 | 
			
		||||
import { db } from "@/db/db";
 | 
			
		||||
import { checkpointTable, machinesTable } from "@/db/schema";
 | 
			
		||||
import { eq } from "drizzle-orm";
 | 
			
		||||
import { NextResponse } from "next/server";
 | 
			
		||||
import { z } from "zod";
 | 
			
		||||
 | 
			
		||||
const Request = z.object({
 | 
			
		||||
  machine_id: z.string(),
 | 
			
		||||
  endpoint: z.string().optional(),
 | 
			
		||||
  build_log: z.string().optional(),
 | 
			
		||||
});
 | 
			
		||||
 | 
			
		||||
export async function POST(request: Request) {
 | 
			
		||||
  const [data, error] = await parseDataSafe(Request, request);
 | 
			
		||||
  if (!data || error) return error;
 | 
			
		||||
 | 
			
		||||
  // console.log(data);
 | 
			
		||||
 | 
			
		||||
  const { machine_id, endpoint, build_log } = data;
 | 
			
		||||
 | 
			
		||||
  if (endpoint) {
 | 
			
		||||
    await db
 | 
			
		||||
      .update(checkpointTable)
 | 
			
		||||
      .set({
 | 
			
		||||
        // status: "ready",
 | 
			
		||||
        // endpoint: endpoint,
 | 
			
		||||
        // build_log: build_log,
 | 
			
		||||
      })
 | 
			
		||||
      .where(eq(machinesTable.id, machine_id));
 | 
			
		||||
  } else {
 | 
			
		||||
    // console.log(data);
 | 
			
		||||
    await db
 | 
			
		||||
      .update(machinesTable)
 | 
			
		||||
      .set({
 | 
			
		||||
        // status: "error",
 | 
			
		||||
        // build_log: build_log,
 | 
			
		||||
      })
 | 
			
		||||
      .where(eq(machinesTable.id, machine_id));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  return NextResponse.json(
 | 
			
		||||
    {
 | 
			
		||||
      message: "success",
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
      status: 200,
 | 
			
		||||
    }
 | 
			
		||||
  );
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										9
									
								
								web/src/app/(app)/storage/loading.tsx
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										9
									
								
								web/src/app/(app)/storage/loading.tsx
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,9 @@
 | 
			
		||||
"use client";
 | 
			
		||||
 | 
			
		||||
import { LoadingPageWrapper } from "@/components/LoadingWrapper";
 | 
			
		||||
import { usePathname } from "next/navigation";
 | 
			
		||||
 | 
			
		||||
export default function Loading() {
 | 
			
		||||
  const pathName = usePathname();
 | 
			
		||||
  return <LoadingPageWrapper className="h-full" tag={pathName.toLowerCase()} />;
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										35
									
								
								web/src/app/(app)/storage/page.tsx
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										35
									
								
								web/src/app/(app)/storage/page.tsx
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,35 @@
 | 
			
		||||
import { setInitialUserData } from "../../../lib/setInitialUserData";
 | 
			
		||||
import { auth } from "@clerk/nextjs";
 | 
			
		||||
import { clerkClient } from "@clerk/nextjs/server";
 | 
			
		||||
import { CheckpointList } from "@/components/CheckpointList"
 | 
			
		||||
import { getAllUserCheckpoints } from "@/server/getAllUserCheckpoints";
 | 
			
		||||
 | 
			
		||||
export default function Page() {
 | 
			
		||||
  return <CheckpointListServer />;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async function CheckpointListServer() {
 | 
			
		||||
  const { userId } = auth();
 | 
			
		||||
 | 
			
		||||
  if (!userId) {
 | 
			
		||||
    return <div>No auth</div>;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  const user = await clerkClient.users.getUser(userId);
 | 
			
		||||
 | 
			
		||||
  if (!user) {
 | 
			
		||||
    await setInitialUserData(userId);
 | 
			
		||||
  }
 | 
			
		||||
  
 | 
			
		||||
  const checkpoints  = await getAllUserCheckpoints()
 | 
			
		||||
 | 
			
		||||
  if (!checkpoints) {
 | 
			
		||||
    return <div>No checkpoints found</div>;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  return (
 | 
			
		||||
    <div className="w-full">
 | 
			
		||||
      <CheckpointList data={checkpoints}/>
 | 
			
		||||
    </div>
 | 
			
		||||
  );
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										315
									
								
								web/src/components/CheckpointList.tsx
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										315
									
								
								web/src/components/CheckpointList.tsx
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,315 @@
 | 
			
		||||
"use client";
 | 
			
		||||
 | 
			
		||||
import { getRelativeTime } from "../lib/getRelativeTime";
 | 
			
		||||
import { Badge } from "@/components/ui/badge";
 | 
			
		||||
import { Button } from "@/components/ui/button";
 | 
			
		||||
import { Checkbox } from "@/components/ui/checkbox";
 | 
			
		||||
import { InsertModal, UpdateModal } from "./InsertModal";
 | 
			
		||||
import { Input } from "@/components/ui/input";
 | 
			
		||||
import { ScrollArea } from "@/components/ui/scroll-area";
 | 
			
		||||
import {
 | 
			
		||||
  Table,
 | 
			
		||||
  TableBody,
 | 
			
		||||
  TableCell,
 | 
			
		||||
  TableHead,
 | 
			
		||||
  TableHeader,
 | 
			
		||||
  TableRow,
 | 
			
		||||
} from "@/components/ui/table";
 | 
			
		||||
import type { getAllUserCheckpoints } from "@/server/getAllUserCheckpoints";
 | 
			
		||||
import type {
 | 
			
		||||
  ColumnDef,
 | 
			
		||||
  ColumnFiltersState,
 | 
			
		||||
  SortingState,
 | 
			
		||||
  VisibilityState,
 | 
			
		||||
} from "@tanstack/react-table";
 | 
			
		||||
import {
 | 
			
		||||
  flexRender,
 | 
			
		||||
  getCoreRowModel,
 | 
			
		||||
  getFilteredRowModel,
 | 
			
		||||
  getPaginationRowModel,
 | 
			
		||||
  getSortedRowModel,
 | 
			
		||||
  useReactTable,
 | 
			
		||||
} from "@tanstack/react-table";
 | 
			
		||||
import { ArrowUpDown, MoreHorizontal } from "lucide-react";
 | 
			
		||||
import * as React from "react";
 | 
			
		||||
import { insertCivitaiCheckpointSchema } from "@/db/schema";
 | 
			
		||||
import { addCivitaiCheckpoint } from "@/server/curdCheckpoint";
 | 
			
		||||
import { addCivitaiCheckpointSchema } from "@/server/addCheckpointSchema";
 | 
			
		||||
 | 
			
		||||
export type CheckpointItemList = NonNullable<
 | 
			
		||||
  Awaited<ReturnType<typeof getAllUserCheckpoints>>
 | 
			
		||||
>[0];
 | 
			
		||||
 | 
			
		||||
export const columns: ColumnDef<CheckpointItemList>[] = [
 | 
			
		||||
  {
 | 
			
		||||
    accessorKey: "id",
 | 
			
		||||
    id: "select",
 | 
			
		||||
    header: ({ table }) => (
 | 
			
		||||
      <Checkbox
 | 
			
		||||
        checked={table.getIsAllPageRowsSelected() ||
 | 
			
		||||
          (table.getIsSomePageRowsSelected() && "indeterminate")}
 | 
			
		||||
        onCheckedChange={(value) => table.toggleAllPageRowsSelected(!!value)}
 | 
			
		||||
        aria-label="Select all"
 | 
			
		||||
      />
 | 
			
		||||
    ),
 | 
			
		||||
    cell: ({ row }) => (
 | 
			
		||||
      <Checkbox
 | 
			
		||||
        checked={row.getIsSelected()}
 | 
			
		||||
        onCheckedChange={(value) => row.toggleSelected(!!value)}
 | 
			
		||||
        aria-label="Select row"
 | 
			
		||||
      />
 | 
			
		||||
    ),
 | 
			
		||||
    enableSorting: false,
 | 
			
		||||
    enableHiding: false,
 | 
			
		||||
  },
 | 
			
		||||
  {
 | 
			
		||||
    accessorKey: "name",
 | 
			
		||||
    header: ({ column }) => {
 | 
			
		||||
      return (
 | 
			
		||||
        <button
 | 
			
		||||
          className="flex items-center hover:underline"
 | 
			
		||||
          onClick={() => column.toggleSorting(column.getIsSorted() === "asc")}
 | 
			
		||||
        >
 | 
			
		||||
          Name
 | 
			
		||||
          <ArrowUpDown className="ml-2 h-4 w-4" />
 | 
			
		||||
        </button>
 | 
			
		||||
      );
 | 
			
		||||
    },
 | 
			
		||||
    cell: ({ row }) => {
 | 
			
		||||
      const checkpoint = row.original;
 | 
			
		||||
      return (
 | 
			
		||||
        <a
 | 
			
		||||
          className="hover:underline flex gap-2"
 | 
			
		||||
          href={`/storage/${checkpoint.id}`} // TODO
 | 
			
		||||
        >
 | 
			
		||||
          <span className="truncate max-w-[200px]">{row.original.model_name}</span>
 | 
			
		||||
 | 
			
		||||
          <Badge variant="default">{}</Badge>
 | 
			
		||||
          {checkpoint.is_public
 | 
			
		||||
            ? <Badge variant="success">Public</Badge>
 | 
			
		||||
            : <Badge variant="teal">Private</Badge>}
 | 
			
		||||
        </a>
 | 
			
		||||
      );
 | 
			
		||||
    },
 | 
			
		||||
  },
 | 
			
		||||
  {
 | 
			
		||||
    accessorKey: "creator",
 | 
			
		||||
    header: ({ column }) => {
 | 
			
		||||
      return (
 | 
			
		||||
        <button
 | 
			
		||||
          className="flex items-center hover:underline"
 | 
			
		||||
          onClick={() => column.toggleSorting(column.getIsSorted() === "asc")}
 | 
			
		||||
        >
 | 
			
		||||
          Creator
 | 
			
		||||
          <ArrowUpDown className="ml-2 h-4 w-4" />
 | 
			
		||||
        </button>
 | 
			
		||||
      );
 | 
			
		||||
    },
 | 
			
		||||
    cell: ({ row }) => {
 | 
			
		||||
      // return <Badge variant="cyan">{row?.original?.user?.name ? row.original.user.name : "Public"}</Badge>;
 | 
			
		||||
    },
 | 
			
		||||
  },
 | 
			
		||||
  {
 | 
			
		||||
    accessorKey: "date",
 | 
			
		||||
    sortingFn: "datetime",
 | 
			
		||||
    enableSorting: true,
 | 
			
		||||
    header: ({ column }) => {
 | 
			
		||||
      return (
 | 
			
		||||
        <button
 | 
			
		||||
          className="w-full flex items-center justify-end hover:underline truncate"
 | 
			
		||||
          // variant="ghost"
 | 
			
		||||
          onClick={() => column.toggleSorting(column.getIsSorted() === "asc")}
 | 
			
		||||
        >
 | 
			
		||||
          Update Date
 | 
			
		||||
          <ArrowUpDown className="ml-2 h-4 w-4" />
 | 
			
		||||
        </button>
 | 
			
		||||
      );
 | 
			
		||||
    },
 | 
			
		||||
    cell: ({ row }) => (
 | 
			
		||||
      <div className="w-full capitalize text-right truncate">
 | 
			
		||||
        {getRelativeTime(row.original.updated_at)}
 | 
			
		||||
      </div>
 | 
			
		||||
    ),
 | 
			
		||||
  },
 | 
			
		||||
  // {
 | 
			
		||||
  //   id: "actions",
 | 
			
		||||
  //   enableHiding: false,
 | 
			
		||||
  //   cell: ({ row }) => {
 | 
			
		||||
  //     const checkpoint = row.original;
 | 
			
		||||
  //
 | 
			
		||||
  //     return (
 | 
			
		||||
  //       <DropdownMenu>
 | 
			
		||||
  //         <DropdownMenuTrigger asChild>
 | 
			
		||||
  //           <Button variant="ghost" className="h-8 w-8 p-0">
 | 
			
		||||
  //             <span className="sr-only">Open menu</span>
 | 
			
		||||
  //             <MoreHorizontal className="h-4 w-4" />
 | 
			
		||||
  //           </Button>
 | 
			
		||||
  //         </DropdownMenuTrigger>
 | 
			
		||||
  //         <DropdownMenuContent align="end">
 | 
			
		||||
  //           <DropdownMenuLabel>Actions</DropdownMenuLabel>
 | 
			
		||||
  //           <DropdownMenuItem
 | 
			
		||||
  //             className="text-destructive"
 | 
			
		||||
  //             onClick={() => {
 | 
			
		||||
  //               deleteWorkflow(checkpoint.id);
 | 
			
		||||
  //             }}
 | 
			
		||||
  //           >
 | 
			
		||||
  //             Delete Workflow
 | 
			
		||||
  //           </DropdownMenuItem>
 | 
			
		||||
  //         </DropdownMenuContent>
 | 
			
		||||
  //       </DropdownMenu>
 | 
			
		||||
  //     );
 | 
			
		||||
  //   },
 | 
			
		||||
  // },
 | 
			
		||||
];
 | 
			
		||||
 | 
			
		||||
export function CheckpointList({ data }: { data: CheckpointItemList[] }) {
 | 
			
		||||
  const [sorting, setSorting] = React.useState<SortingState>([]);
 | 
			
		||||
  const [columnFilters, setColumnFilters] = React.useState<ColumnFiltersState>(
 | 
			
		||||
    [],
 | 
			
		||||
  );
 | 
			
		||||
  const [columnVisibility, setColumnVisibility] = React.useState<
 | 
			
		||||
    VisibilityState
 | 
			
		||||
  >({});
 | 
			
		||||
  const [rowSelection, setRowSelection] = React.useState({});
 | 
			
		||||
 | 
			
		||||
  const table = useReactTable({
 | 
			
		||||
    data,
 | 
			
		||||
    columns,
 | 
			
		||||
    onSortingChange: setSorting,
 | 
			
		||||
    onColumnFiltersChange: setColumnFilters,
 | 
			
		||||
    getCoreRowModel: getCoreRowModel(),
 | 
			
		||||
    getPaginationRowModel: getPaginationRowModel(),
 | 
			
		||||
    getSortedRowModel: getSortedRowModel(),
 | 
			
		||||
    getFilteredRowModel: getFilteredRowModel(),
 | 
			
		||||
    onColumnVisibilityChange: setColumnVisibility,
 | 
			
		||||
    onRowSelectionChange: setRowSelection,
 | 
			
		||||
    state: {
 | 
			
		||||
      sorting,
 | 
			
		||||
      columnFilters,
 | 
			
		||||
      columnVisibility,
 | 
			
		||||
      rowSelection,
 | 
			
		||||
    },
 | 
			
		||||
  });
 | 
			
		||||
 | 
			
		||||
  return (
 | 
			
		||||
    <div className="grid grid-rows-[auto,1fr,auto] h-full">
 | 
			
		||||
      <div className="flex flex-row w-full items-center py-4">
 | 
			
		||||
        <Input
 | 
			
		||||
          placeholder="Filter workflows..."
 | 
			
		||||
          value={(table.getColumn("name")?.getFilterValue() as string) ?? ""}
 | 
			
		||||
          onChange={(event) =>
 | 
			
		||||
            table.getColumn("name")?.setFilterValue(event.target.value)}
 | 
			
		||||
          className="max-w-sm"
 | 
			
		||||
        />
 | 
			
		||||
        <div className="ml-auto flex gap-2">
 | 
			
		||||
          <InsertModal
 | 
			
		||||
            dialogClassName="sm:max-w-[600px]"
 | 
			
		||||
            disabled={
 | 
			
		||||
              false
 | 
			
		||||
              // TODO: limitations based on plan
 | 
			
		||||
            }
 | 
			
		||||
            tooltip={"Add models using their civitai url!"}
 | 
			
		||||
            title="Civitai Checkpoint"
 | 
			
		||||
            description="Pick a model from civitai"
 | 
			
		||||
            serverAction={addCivitaiCheckpoint}
 | 
			
		||||
            formSchema={addCivitaiCheckpointSchema}
 | 
			
		||||
            fieldConfig={{
 | 
			
		||||
              civitai_url: {
 | 
			
		||||
                fieldType: "fallback",
 | 
			
		||||
                // fieldType: "fallback",
 | 
			
		||||
                inputProps: { required: true },
 | 
			
		||||
                description: (
 | 
			
		||||
                  <>
 | 
			
		||||
                    Pick a checkpoint from{" "}
 | 
			
		||||
                    <a
 | 
			
		||||
                      href="https://www.civitai.com/models"
 | 
			
		||||
                      target="_blank"
 | 
			
		||||
                      className="underline text-blue-600 hover:text-blue-800 visited:text-purple-600"
 | 
			
		||||
                    >
 | 
			
		||||
                      civitai.com
 | 
			
		||||
                    </a>{" "}
 | 
			
		||||
                    and place it's url here
 | 
			
		||||
                  </>
 | 
			
		||||
                ),
 | 
			
		||||
              },
 | 
			
		||||
            }}
 | 
			
		||||
          />
 | 
			
		||||
        </div>
 | 
			
		||||
      </div>
 | 
			
		||||
      <ScrollArea className="h-full w-full rounded-md border">
 | 
			
		||||
        <Table>
 | 
			
		||||
          <TableHeader className="bg-background top-0 sticky">
 | 
			
		||||
            {table.getHeaderGroups().map((headerGroup) => (
 | 
			
		||||
              <TableRow key={headerGroup.id}>
 | 
			
		||||
                {headerGroup.headers.map((header) => {
 | 
			
		||||
                  return (
 | 
			
		||||
                    <TableHead key={header.id}>
 | 
			
		||||
                      {header.isPlaceholder ? null : flexRender(
 | 
			
		||||
                        header.column.columnDef.header,
 | 
			
		||||
                        header.getContext(),
 | 
			
		||||
                      )}
 | 
			
		||||
                    </TableHead>
 | 
			
		||||
                  );
 | 
			
		||||
                })}
 | 
			
		||||
              </TableRow>
 | 
			
		||||
            ))}
 | 
			
		||||
          </TableHeader>
 | 
			
		||||
          <TableBody>
 | 
			
		||||
            {table.getRowModel().rows?.length
 | 
			
		||||
              ? (
 | 
			
		||||
                table.getRowModel().rows.map((row) => (
 | 
			
		||||
                  <TableRow
 | 
			
		||||
                    key={row.id}
 | 
			
		||||
                    data-state={row.getIsSelected() && "selected"}
 | 
			
		||||
                  >
 | 
			
		||||
                    {row.getVisibleCells().map((cell) => (
 | 
			
		||||
                      <TableCell key={cell.id}>
 | 
			
		||||
                        {flexRender(
 | 
			
		||||
                          cell.column.columnDef.cell,
 | 
			
		||||
                          cell.getContext(),
 | 
			
		||||
                        )}
 | 
			
		||||
                      </TableCell>
 | 
			
		||||
                    ))}
 | 
			
		||||
                  </TableRow>
 | 
			
		||||
                ))
 | 
			
		||||
              )
 | 
			
		||||
              : (
 | 
			
		||||
                <TableRow>
 | 
			
		||||
                  <TableCell
 | 
			
		||||
                    colSpan={columns.length}
 | 
			
		||||
                    className="h-24 text-center"
 | 
			
		||||
                  >
 | 
			
		||||
                    No results.
 | 
			
		||||
                  </TableCell>
 | 
			
		||||
                </TableRow>
 | 
			
		||||
              )}
 | 
			
		||||
          </TableBody>
 | 
			
		||||
        </Table>
 | 
			
		||||
      </ScrollArea>
 | 
			
		||||
      <div className="flex flex-row items-center justify-end space-x-2 py-4">
 | 
			
		||||
        <div className="flex-1 text-sm text-muted-foreground">
 | 
			
		||||
          {table.getFilteredSelectedRowModel().rows.length} of{" "}
 | 
			
		||||
          {table.getFilteredRowModel().rows.length} row(s) selected.
 | 
			
		||||
        </div>
 | 
			
		||||
        <div className="space-x-2">
 | 
			
		||||
          <Button
 | 
			
		||||
            variant="outline"
 | 
			
		||||
            size="sm"
 | 
			
		||||
            onClick={() => table.previousPage()}
 | 
			
		||||
            disabled={!table.getCanPreviousPage()}
 | 
			
		||||
          >
 | 
			
		||||
            Previous
 | 
			
		||||
          </Button>
 | 
			
		||||
          <Button
 | 
			
		||||
            variant="outline"
 | 
			
		||||
            size="sm"
 | 
			
		||||
            onClick={() => table.nextPage()}
 | 
			
		||||
            disabled={!table.getCanNextPage()}
 | 
			
		||||
          >
 | 
			
		||||
            Next
 | 
			
		||||
          </Button>
 | 
			
		||||
        </div>
 | 
			
		||||
      </div>
 | 
			
		||||
    </div>
 | 
			
		||||
  );
 | 
			
		||||
}
 | 
			
		||||
@ -34,6 +34,10 @@ export function NavbarMenu({ className }: { className?: string }) {
 | 
			
		||||
      name: "API Keys",
 | 
			
		||||
      path: "/api-keys",
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
      name: "Storage",
 | 
			
		||||
      path: "/storage",
 | 
			
		||||
    },
 | 
			
		||||
  ];
 | 
			
		||||
 | 
			
		||||
  return (
 | 
			
		||||
@ -42,9 +46,9 @@ export function NavbarMenu({ className }: { className?: string }) {
 | 
			
		||||
      {isDesktop && (
 | 
			
		||||
        <Tabs
 | 
			
		||||
          defaultValue={pathname}
 | 
			
		||||
          className="w-[300px] flex pointer-events-auto"
 | 
			
		||||
          className="w-[400px] flex pointer-events-auto"
 | 
			
		||||
        >
 | 
			
		||||
          <TabsList className="grid w-full grid-cols-3">
 | 
			
		||||
          <TabsList className="grid w-full grid-cols-4">
 | 
			
		||||
            {pages.map((page) => (
 | 
			
		||||
              <TabsTrigger
 | 
			
		||||
                key={page.name}
 | 
			
		||||
 | 
			
		||||
@ -42,9 +42,7 @@ const Model = z.object({
 | 
			
		||||
  url: z.string(),
 | 
			
		||||
});
 | 
			
		||||
 | 
			
		||||
export const CivitalModelSchema = z.object({
 | 
			
		||||
  items: z.array(
 | 
			
		||||
    z.object({
 | 
			
		||||
export const CivitaiModel = z.object({
 | 
			
		||||
  id: z.number(),
 | 
			
		||||
  name: z.string(),
 | 
			
		||||
  description: z.string(),
 | 
			
		||||
@ -117,7 +115,7 @@ export const CivitalModelSchema = z.object({
 | 
			
		||||
          // }),
 | 
			
		||||
          downloadUrl: z.string(),
 | 
			
		||||
          // primary: z.boolean().default(false),
 | 
			
		||||
            })
 | 
			
		||||
        }),
 | 
			
		||||
      ),
 | 
			
		||||
      images: z.array(
 | 
			
		||||
        z.object({
 | 
			
		||||
@ -134,13 +132,15 @@ export const CivitalModelSchema = z.object({
 | 
			
		||||
            height: z.number(),
 | 
			
		||||
          }),
 | 
			
		||||
          meta: z.any(),
 | 
			
		||||
            })
 | 
			
		||||
        }),
 | 
			
		||||
      ),
 | 
			
		||||
      downloadUrl: z.string(),
 | 
			
		||||
        })
 | 
			
		||||
      ),
 | 
			
		||||
    })
 | 
			
		||||
    }),
 | 
			
		||||
  ),
 | 
			
		||||
});
 | 
			
		||||
 | 
			
		||||
export const CivitalModelSchema = z.object({
 | 
			
		||||
  items: z.array(CivitaiModel),
 | 
			
		||||
  metadata: z.object({
 | 
			
		||||
    totalItems: z.number(),
 | 
			
		||||
    currentPage: z.number(),
 | 
			
		||||
@ -197,7 +197,7 @@ function mapType(type: string) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
function mapModelsList(
 | 
			
		||||
  models: z.infer<typeof CivitalModelSchema>
 | 
			
		||||
  models: z.infer<typeof CivitalModelSchema>,
 | 
			
		||||
): z.infer<typeof ModelListWrapper> {
 | 
			
		||||
  return {
 | 
			
		||||
    models: models.items.flatMap((item) => {
 | 
			
		||||
@ -241,8 +241,9 @@ function getUrl(search?: string) {
 | 
			
		||||
export function CivitaiModelRegistry({
 | 
			
		||||
  field,
 | 
			
		||||
}: Pick<AutoFormInputComponentProps, "field">) {
 | 
			
		||||
  const [modelList, setModelList] =
 | 
			
		||||
    React.useState<z.infer<typeof ModelListWrapper>>();
 | 
			
		||||
  const [modelList, setModelList] = React.useState<
 | 
			
		||||
    z.infer<typeof ModelListWrapper>
 | 
			
		||||
  >();
 | 
			
		||||
 | 
			
		||||
  const [loading, setLoading] = React.useState(false);
 | 
			
		||||
 | 
			
		||||
@ -301,8 +302,9 @@ export function CivitaiModelRegistry({
 | 
			
		||||
export function ComfyUIManagerModelRegistry({
 | 
			
		||||
  field,
 | 
			
		||||
}: Pick<AutoFormInputComponentProps, "field">) {
 | 
			
		||||
  const [modelList, setModelList] =
 | 
			
		||||
    React.useState<z.infer<typeof ModelListWrapper>>();
 | 
			
		||||
  const [modelList, setModelList] = React.useState<
 | 
			
		||||
    z.infer<typeof ModelListWrapper>
 | 
			
		||||
  >();
 | 
			
		||||
 | 
			
		||||
  React.useEffect(() => {
 | 
			
		||||
    const controller = new AbortController();
 | 
			
		||||
@ -310,7 +312,7 @@ export function ComfyUIManagerModelRegistry({
 | 
			
		||||
      "https://raw.githubusercontent.com/ltdrdata/ComfyUI-Manager/main/model-list.json",
 | 
			
		||||
      {
 | 
			
		||||
        signal: controller.signal,
 | 
			
		||||
      }
 | 
			
		||||
      },
 | 
			
		||||
    )
 | 
			
		||||
      .then((x) => x.json())
 | 
			
		||||
      .then((a) => {
 | 
			
		||||
@ -353,14 +355,14 @@ export function ModelSelector({
 | 
			
		||||
    if (
 | 
			
		||||
      prevSelectedModels.some(
 | 
			
		||||
        (selectedModel) =>
 | 
			
		||||
          selectedModel.url + selectedModel.name === model.url + model.name
 | 
			
		||||
          selectedModel.url + selectedModel.name === model.url + model.name,
 | 
			
		||||
      )
 | 
			
		||||
    ) {
 | 
			
		||||
      field.onChange(
 | 
			
		||||
        prevSelectedModels.filter(
 | 
			
		||||
          (selectedModel) =>
 | 
			
		||||
            selectedModel.url + selectedModel.name !== model.url + model.name
 | 
			
		||||
        )
 | 
			
		||||
            selectedModel.url + selectedModel.name !== model.url + model.name,
 | 
			
		||||
        ),
 | 
			
		||||
      );
 | 
			
		||||
    } else {
 | 
			
		||||
      field.onChange([...prevSelectedModels, model]);
 | 
			
		||||
@ -408,10 +410,10 @@ export function ModelSelector({
 | 
			
		||||
                      className={cn(
 | 
			
		||||
                        "ml-auto h-4 w-4",
 | 
			
		||||
                        value.some(
 | 
			
		||||
                          (selectedModel) => selectedModel.url === model.url
 | 
			
		||||
                            (selectedModel) => selectedModel.url === model.url,
 | 
			
		||||
                          )
 | 
			
		||||
                          ? "opacity-100"
 | 
			
		||||
                          : "opacity-0"
 | 
			
		||||
                          : "opacity-0",
 | 
			
		||||
                      )}
 | 
			
		||||
                    />
 | 
			
		||||
                  </CommandItem>
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										89
									
								
								web/src/components/custom-form/checkpoint-input.tsx
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										89
									
								
								web/src/components/custom-form/checkpoint-input.tsx
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,89 @@
 | 
			
		||||
import type { AutoFormInputComponentProps } from "../ui/auto-form/types";
 | 
			
		||||
import { FormControl, FormItem, FormLabel } from "../ui/form";
 | 
			
		||||
import { LoadingIcon } from "@/components/LoadingIcon";
 | 
			
		||||
import * as React from "react";
 | 
			
		||||
import AutoFormInput from "../ui/auto-form/fields/input";
 | 
			
		||||
import { useDebouncedCallback } from "use-debounce";
 | 
			
		||||
import { CivitaiModel } from "./ModelPickerView";
 | 
			
		||||
import { z } from "zod";
 | 
			
		||||
import { insertCivitaiCheckpointSchema } from "@/db/schema";
 | 
			
		||||
 | 
			
		||||
function getUrl(civitai_url: string) {
 | 
			
		||||
  // expect to be a URL to be https://civitai.com/models/36520
 | 
			
		||||
  // possiblity with slugged name and query-param modelVersionId
 | 
			
		||||
 | 
			
		||||
  const baseUrl = "https://civitai.com/api/v1/models/";
 | 
			
		||||
  const url = new URL(civitai_url);
 | 
			
		||||
  const pathSegments = url.pathname.split("/");
 | 
			
		||||
  const modelId = pathSegments[pathSegments.indexOf("models") + 1];
 | 
			
		||||
  const modelVersionId = url.searchParams.get("modelVersionId");
 | 
			
		||||
 | 
			
		||||
  return { url: baseUrl + modelId, modelVersionId };
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
export default function AutoFormCheckpointInput(
 | 
			
		||||
  props: AutoFormInputComponentProps,
 | 
			
		||||
) {
 | 
			
		||||
  const [loading, setLoading] = React.useState(false);
 | 
			
		||||
  const [modelRes, setModelRes] = React.useState<
 | 
			
		||||
    z.infer<typeof CivitaiModel>
 | 
			
		||||
  >();
 | 
			
		||||
  const [modelVersionid, setModelVersionId] = React.useState<string | null>();
 | 
			
		||||
  const { label, isRequired, fieldProps, zodItem, fieldConfigItem } = props;
 | 
			
		||||
 | 
			
		||||
  const handleSearch = useDebouncedCallback((search) => {
 | 
			
		||||
    const validationResult = insertCivitaiCheckpointSchema.shape.civitai_url
 | 
			
		||||
      .safeParse(search);
 | 
			
		||||
    if (!validationResult.success) {
 | 
			
		||||
      console.error(validationResult.error);
 | 
			
		||||
      // Optionally set an error state here
 | 
			
		||||
      return;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    setLoading(true);
 | 
			
		||||
 | 
			
		||||
    const controller = new AbortController();
 | 
			
		||||
    const { url, modelVersionId: versionId } = getUrl(search);
 | 
			
		||||
    setModelVersionId(versionId);
 | 
			
		||||
    fetch(url, {
 | 
			
		||||
      signal: controller.signal,
 | 
			
		||||
    })
 | 
			
		||||
      .then((x) => x.json())
 | 
			
		||||
      .then((a) => {
 | 
			
		||||
        const res = CivitaiModel.parse(a);
 | 
			
		||||
        console.log(a);
 | 
			
		||||
        console.log(res);
 | 
			
		||||
        setModelRes(res);
 | 
			
		||||
        setLoading(false);
 | 
			
		||||
      });
 | 
			
		||||
 | 
			
		||||
    return () => {
 | 
			
		||||
      controller.abort();
 | 
			
		||||
      setLoading(false);
 | 
			
		||||
    };
 | 
			
		||||
  }, 300);
 | 
			
		||||
 | 
			
		||||
  const modifiedField = {
 | 
			
		||||
    ...fieldProps,
 | 
			
		||||
    // onChange: (event: React.ChangeEvent<HTMLInputElement>) => {
 | 
			
		||||
    //   handleSearch(event.target.value);
 | 
			
		||||
    // },
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  return (
 | 
			
		||||
    <FormItem>
 | 
			
		||||
      {fieldConfigItem.inputProps?.showLabel && (
 | 
			
		||||
        <FormLabel>
 | 
			
		||||
          {label}
 | 
			
		||||
          {isRequired && <span className="text-destructive">*</span>}
 | 
			
		||||
        </FormLabel>
 | 
			
		||||
      )}
 | 
			
		||||
      <FormControl>
 | 
			
		||||
        <AutoFormInput
 | 
			
		||||
          {...props}
 | 
			
		||||
          fieldProps={modifiedField}
 | 
			
		||||
        />
 | 
			
		||||
      </FormControl>
 | 
			
		||||
    </FormItem>
 | 
			
		||||
  );
 | 
			
		||||
}
 | 
			
		||||
@ -8,6 +8,7 @@ import AutoFormSwitch from "./fields/switch";
 | 
			
		||||
import AutoFormTextarea from "./fields/textarea";
 | 
			
		||||
import AutoFormModelsPicker from "@/components/custom-form/model-picker";
 | 
			
		||||
import AutoFormSnapshotPicker from "@/components/custom-form/snapshot-picker";
 | 
			
		||||
import AutoFormCheckpointInput from "@/components/custom-form/checkpoint-input";
 | 
			
		||||
 | 
			
		||||
export const INPUT_COMPONENTS = {
 | 
			
		||||
  checkbox: AutoFormCheckbox,
 | 
			
		||||
@ -22,6 +23,7 @@ export const INPUT_COMPONENTS = {
 | 
			
		||||
  // Customs
 | 
			
		||||
  snapshot: AutoFormSnapshotPicker,
 | 
			
		||||
  models: AutoFormModelsPicker,
 | 
			
		||||
  checkpoints: AutoFormCheckpointInput,
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 | 
			
		||||
@ -1,13 +1,13 @@
 | 
			
		||||
import { relations, type InferSelectModel } from "drizzle-orm";
 | 
			
		||||
import { type InferSelectModel, relations } from "drizzle-orm";
 | 
			
		||||
import {
 | 
			
		||||
  text,
 | 
			
		||||
  pgSchema,
 | 
			
		||||
  uuid,
 | 
			
		||||
  boolean,
 | 
			
		||||
  integer,
 | 
			
		||||
  timestamp,
 | 
			
		||||
  jsonb,
 | 
			
		||||
  pgEnum,
 | 
			
		||||
  boolean,
 | 
			
		||||
  pgSchema,
 | 
			
		||||
  text,
 | 
			
		||||
  timestamp,
 | 
			
		||||
  uuid,
 | 
			
		||||
} from "drizzle-orm/pg-core";
 | 
			
		||||
import { createInsertSchema } from "drizzle-zod";
 | 
			
		||||
import { z } from "zod";
 | 
			
		||||
@ -87,7 +87,7 @@ export const workflowVersionRelations = relations(
 | 
			
		||||
      fields: [workflowVersionTable.workflow_id],
 | 
			
		||||
      references: [workflowTable.id],
 | 
			
		||||
    }),
 | 
			
		||||
  })
 | 
			
		||||
  }),
 | 
			
		||||
);
 | 
			
		||||
 | 
			
		||||
export const workflowRunStatus = pgEnum("workflow_run_status", [
 | 
			
		||||
@ -136,10 +136,11 @@ export const workflowRunsTable = dbSchema.table("workflow_runs", {
 | 
			
		||||
    () => workflowVersionTable.id,
 | 
			
		||||
    {
 | 
			
		||||
      onDelete: "set null",
 | 
			
		||||
    }
 | 
			
		||||
    },
 | 
			
		||||
  ),
 | 
			
		||||
  workflow_inputs:
 | 
			
		||||
    jsonb("workflow_inputs").$type<Record<string, string | number>>(),
 | 
			
		||||
  workflow_inputs: jsonb("workflow_inputs").$type<
 | 
			
		||||
    Record<string, string | number>
 | 
			
		||||
  >(),
 | 
			
		||||
  workflow_id: uuid("workflow_id")
 | 
			
		||||
    .notNull()
 | 
			
		||||
    .references(() => workflowTable.id, {
 | 
			
		||||
@ -171,7 +172,7 @@ export const workflowRunRelations = relations(
 | 
			
		||||
      fields: [workflowRunsTable.workflow_id],
 | 
			
		||||
      references: [workflowTable.id],
 | 
			
		||||
    }),
 | 
			
		||||
  })
 | 
			
		||||
  }),
 | 
			
		||||
);
 | 
			
		||||
 | 
			
		||||
// We still want to keep the workflow run record.
 | 
			
		||||
@ -195,7 +196,7 @@ export const workflowOutputRelations = relations(
 | 
			
		||||
      fields: [workflowRunOutputs.run_id],
 | 
			
		||||
      references: [workflowRunsTable.id],
 | 
			
		||||
    }),
 | 
			
		||||
  })
 | 
			
		||||
  }),
 | 
			
		||||
);
 | 
			
		||||
 | 
			
		||||
// when user delete, also delete all the workflow versions
 | 
			
		||||
@ -228,7 +229,7 @@ export const snapshotType = z.object({
 | 
			
		||||
    z.object({
 | 
			
		||||
      hash: z.string(),
 | 
			
		||||
      disabled: z.boolean(),
 | 
			
		||||
    })
 | 
			
		||||
    }),
 | 
			
		||||
  ),
 | 
			
		||||
  file_custom_nodes: z.array(z.any()),
 | 
			
		||||
});
 | 
			
		||||
@ -243,7 +244,7 @@ export const showcaseMedia = z.array(
 | 
			
		||||
  z.object({
 | 
			
		||||
    url: z.string(),
 | 
			
		||||
    isCover: z.boolean().default(false),
 | 
			
		||||
  })
 | 
			
		||||
  }),
 | 
			
		||||
);
 | 
			
		||||
 | 
			
		||||
export const showcaseMediaNullable = z
 | 
			
		||||
@ -251,7 +252,7 @@ export const showcaseMediaNullable = z
 | 
			
		||||
    z.object({
 | 
			
		||||
      url: z.string(),
 | 
			
		||||
      isCover: z.boolean().default(false),
 | 
			
		||||
    })
 | 
			
		||||
    }),
 | 
			
		||||
  )
 | 
			
		||||
  .nullable();
 | 
			
		||||
 | 
			
		||||
@ -275,8 +276,9 @@ export const deploymentsTable = dbSchema.table("deployments", {
 | 
			
		||||
    .notNull()
 | 
			
		||||
    .references(() => machinesTable.id),
 | 
			
		||||
  description: text("description"),
 | 
			
		||||
  showcase_media:
 | 
			
		||||
    jsonb("showcase_media").$type<z.infer<typeof showcaseMedia>>(),
 | 
			
		||||
  showcase_media: jsonb("showcase_media").$type<
 | 
			
		||||
    z.infer<typeof showcaseMedia>
 | 
			
		||||
  >(),
 | 
			
		||||
  environment: deploymentEnvironment("environment").notNull(),
 | 
			
		||||
  created_at: timestamp("created_at").defaultNow().notNull(),
 | 
			
		||||
  updated_at: timestamp("updated_at").defaultNow().notNull(),
 | 
			
		||||
@ -329,7 +331,126 @@ export const apiKeyTable = dbSchema.table("api_keys", {
 | 
			
		||||
  updated_at: timestamp("updated_at").defaultNow().notNull(),
 | 
			
		||||
});
 | 
			
		||||
 | 
			
		||||
const civitaiModelVersion = z.object({
 | 
			
		||||
// const civitaiModelVersion = z.object({
 | 
			
		||||
//   id: z.number(),
 | 
			
		||||
//   modelId: z.number(),
 | 
			
		||||
//   name: z.string(),
 | 
			
		||||
//   createdAt: z.string(),
 | 
			
		||||
//   updatedAt: z.string(),
 | 
			
		||||
//   status: z.string(),
 | 
			
		||||
//   publishedAt: z.string(),
 | 
			
		||||
//   trainedWords: z.array(z.string()).optional(),
 | 
			
		||||
//   trainingStatus: z.string().optional(),
 | 
			
		||||
//   trainingDetails: z.string().optional(),
 | 
			
		||||
//   baseModel: z.string(),
 | 
			
		||||
//   baseModelType: z.string(),
 | 
			
		||||
//   earlyAccessTimeFrame: z.number(),
 | 
			
		||||
//   description: z.string().optional(),
 | 
			
		||||
//   vaeId: z.string().optional(),
 | 
			
		||||
//   stats: z.object({
 | 
			
		||||
//     downloadCount: z.number(),
 | 
			
		||||
//     ratingCount: z.number(),
 | 
			
		||||
//     rating: z.number(),
 | 
			
		||||
//   }),
 | 
			
		||||
//   files: z.array(z.object({
 | 
			
		||||
//     id: z.number(),
 | 
			
		||||
//     sizeKB: z.number(),
 | 
			
		||||
//     name: z.string(),
 | 
			
		||||
//     type: z.string(),
 | 
			
		||||
//     metadata: z.object({
 | 
			
		||||
//       fp: z.string(),
 | 
			
		||||
//       size: z.string(),
 | 
			
		||||
//       format: z.string(),
 | 
			
		||||
//     }),
 | 
			
		||||
//     pickleScanResult: z.string(),
 | 
			
		||||
//     pickleScanMessage: z.string().optional(),
 | 
			
		||||
//     virusScanResult: z.string(),
 | 
			
		||||
//     virusScanMessage: z.string().optional(),
 | 
			
		||||
//     scannedAt: z.string(),
 | 
			
		||||
//     hashes: z.object({
 | 
			
		||||
//       AutoV1: z.string(),
 | 
			
		||||
//       AutoV2: z.string(),
 | 
			
		||||
//       SHA256: z.string(),
 | 
			
		||||
//       CRC32: z.string(),
 | 
			
		||||
//       BLAKE3: z.string(),
 | 
			
		||||
//       AutoV3: z.string(),
 | 
			
		||||
//     }),
 | 
			
		||||
//     downloadUrl: z.string(),
 | 
			
		||||
//     primary: z.boolean(),
 | 
			
		||||
//   })),
 | 
			
		||||
//   images: z.array(z.object({
 | 
			
		||||
//     url: z.string(),
 | 
			
		||||
//     nsfw: z.string(),
 | 
			
		||||
//     width: z.number(),
 | 
			
		||||
//     height: z.number(),
 | 
			
		||||
//     hash: z.string(),
 | 
			
		||||
//     type: z.string(),
 | 
			
		||||
//     metadata: z.object({
 | 
			
		||||
//       hash: z.string(),
 | 
			
		||||
//       size: z.number(),
 | 
			
		||||
//       width: z.number(),
 | 
			
		||||
//       height: z.number(),
 | 
			
		||||
//     }),
 | 
			
		||||
//     meta: z.any(),
 | 
			
		||||
//   })),
 | 
			
		||||
//   downloadUrl: z.string(),
 | 
			
		||||
// });
 | 
			
		||||
//
 | 
			
		||||
// const civitaiModelResponseType = z.object({
 | 
			
		||||
//   id: z.number(),
 | 
			
		||||
//   name: z.string(),
 | 
			
		||||
//   description: z.string().optional(),
 | 
			
		||||
//   type: z.string(),
 | 
			
		||||
//   poi: z.boolean(),
 | 
			
		||||
//   nsfw: z.boolean(),
 | 
			
		||||
//   allowNoCredit: z.boolean(),
 | 
			
		||||
//   allowCommercialUse: z.string(),
 | 
			
		||||
//   allowDerivatives: z.boolean(),
 | 
			
		||||
//   allowDifferentLicense: z.boolean(),
 | 
			
		||||
//   stats: z.object({
 | 
			
		||||
//     downloadCount: z.number(),
 | 
			
		||||
//     favoriteCount: z.number(),
 | 
			
		||||
//     commentCount: z.number(),
 | 
			
		||||
//     ratingCount: z.number(),
 | 
			
		||||
//     rating: z.number(),
 | 
			
		||||
//     tippedAmountCount: z.number(),
 | 
			
		||||
//   }),
 | 
			
		||||
//   creator: z.object({
 | 
			
		||||
//     username: z.string(),
 | 
			
		||||
//     image: z.string(),
 | 
			
		||||
//   }),
 | 
			
		||||
//   tags: z.array(z.string()),
 | 
			
		||||
//   modelVersions: z.array(civitaiModelVersion),
 | 
			
		||||
// });
 | 
			
		||||
 | 
			
		||||
export const CivitaiModel = z.object({
 | 
			
		||||
  id: z.number(),
 | 
			
		||||
  name: z.string(),
 | 
			
		||||
  description: z.string(),
 | 
			
		||||
  type: z.string(),
 | 
			
		||||
  // poi: z.boolean(),
 | 
			
		||||
  // nsfw: z.boolean(),
 | 
			
		||||
  // allowNoCredit: z.boolean(),
 | 
			
		||||
  // allowCommercialUse: z.string(),
 | 
			
		||||
  // allowDerivatives: z.boolean(),
 | 
			
		||||
  // allowDifferentLicense: z.boolean(),
 | 
			
		||||
  // stats: z.object({
 | 
			
		||||
  //   downloadCount: z.number(),
 | 
			
		||||
  //   favoriteCount: z.number(),
 | 
			
		||||
  //   commentCount: z.number(),
 | 
			
		||||
  //   ratingCount: z.number(),
 | 
			
		||||
  //   rating: z.number(),
 | 
			
		||||
  //   tippedAmountCount: z.number(),
 | 
			
		||||
  // }),
 | 
			
		||||
  creator: z
 | 
			
		||||
    .object({
 | 
			
		||||
      username: z.string().nullable(),
 | 
			
		||||
      image: z.string().nullable().default(null),
 | 
			
		||||
    })
 | 
			
		||||
    .nullable(),
 | 
			
		||||
  tags: z.array(z.string()),
 | 
			
		||||
  modelVersions: z.array(
 | 
			
		||||
    z.object({
 | 
			
		||||
      id: z.number(),
 | 
			
		||||
      modelId: z.number(),
 | 
			
		||||
      name: z.string(),
 | 
			
		||||
@ -337,46 +458,49 @@ const civitaiModelVersion = z.object({
 | 
			
		||||
      updatedAt: z.string(),
 | 
			
		||||
      status: z.string(),
 | 
			
		||||
      publishedAt: z.string(),
 | 
			
		||||
  trainedWords: z.array(z.string()).optional(),
 | 
			
		||||
  trainingStatus: z.string().optional(),
 | 
			
		||||
  trainingDetails: z.string().optional(),
 | 
			
		||||
      trainedWords: z.array(z.unknown()),
 | 
			
		||||
      trainingStatus: z.string().nullable(),
 | 
			
		||||
      trainingDetails: z.string().nullable(),
 | 
			
		||||
      baseModel: z.string(),
 | 
			
		||||
  baseModelType: z.string(),
 | 
			
		||||
      baseModelType: z.string().nullable(),
 | 
			
		||||
      earlyAccessTimeFrame: z.number(),
 | 
			
		||||
  description: z.string().optional(),
 | 
			
		||||
  vaeId: z.string().optional(),
 | 
			
		||||
      description: z.string().nullable(),
 | 
			
		||||
      vaeId: z.number().nullable(),
 | 
			
		||||
      stats: z.object({
 | 
			
		||||
        downloadCount: z.number(),
 | 
			
		||||
        ratingCount: z.number(),
 | 
			
		||||
    rating: z.number()
 | 
			
		||||
        rating: z.number(),
 | 
			
		||||
      }),
 | 
			
		||||
  files: z.array(z.object({
 | 
			
		||||
      files: z.array(
 | 
			
		||||
        z.object({
 | 
			
		||||
          id: z.number(),
 | 
			
		||||
          sizeKB: z.number(),
 | 
			
		||||
          name: z.string(),
 | 
			
		||||
          type: z.string(),
 | 
			
		||||
    metadata: z.object({
 | 
			
		||||
      fp: z.string(),
 | 
			
		||||
      size: z.string(),
 | 
			
		||||
      format: z.string()
 | 
			
		||||
    }),
 | 
			
		||||
    pickleScanResult: z.string(),
 | 
			
		||||
    pickleScanMessage: z.string().optional(),
 | 
			
		||||
    virusScanResult: z.string(),
 | 
			
		||||
    virusScanMessage: z.string().optional(),
 | 
			
		||||
    scannedAt: z.string(),
 | 
			
		||||
    hashes: z.object({
 | 
			
		||||
      AutoV1: z.string(),
 | 
			
		||||
      AutoV2: z.string(),
 | 
			
		||||
      SHA256: z.string(),
 | 
			
		||||
      CRC32: z.string(),
 | 
			
		||||
      BLAKE3: z.string(),
 | 
			
		||||
      AutoV3: z.string()
 | 
			
		||||
    }),
 | 
			
		||||
          // metadata: z.object({
 | 
			
		||||
          //   fp: z.string().nullable().optional(),
 | 
			
		||||
          //   size: z.string().nullable().optional(),
 | 
			
		||||
          //   format: z.string().nullable().optional(),
 | 
			
		||||
          // }),
 | 
			
		||||
          // pickleScanResult: z.string(),
 | 
			
		||||
          // pickleScanMessage: z.string(),
 | 
			
		||||
          // virusScanResult: z.string(),
 | 
			
		||||
          // virusScanMessage: z.string().nullable(),
 | 
			
		||||
          // scannedAt: z.string(),
 | 
			
		||||
          // hashes: z.object({
 | 
			
		||||
          //   AutoV1: z.string().nullable().optional(),
 | 
			
		||||
          //   AutoV2: z.string().nullable().optional(),
 | 
			
		||||
          //   SHA256: z.string().nullable().optional(),
 | 
			
		||||
          //   CRC32: z.string().nullable().optional(),
 | 
			
		||||
          //   BLAKE3: z.string().nullable().optional(),
 | 
			
		||||
          // }),
 | 
			
		||||
          downloadUrl: z.string(),
 | 
			
		||||
    primary: z.boolean()
 | 
			
		||||
  })),
 | 
			
		||||
  images: z.array(z.object({
 | 
			
		||||
          // primary: z.boolean().default(false),
 | 
			
		||||
        }),
 | 
			
		||||
      ),
 | 
			
		||||
      images: z.array(
 | 
			
		||||
        z.object({
 | 
			
		||||
          id: z.number(),
 | 
			
		||||
          url: z.string(),
 | 
			
		||||
          nsfw: z.string(),
 | 
			
		||||
          width: z.number(),
 | 
			
		||||
@ -385,65 +509,113 @@ const civitaiModelVersion = z.object({
 | 
			
		||||
          type: z.string(),
 | 
			
		||||
          metadata: z.object({
 | 
			
		||||
            hash: z.string(),
 | 
			
		||||
      size: z.number(),
 | 
			
		||||
            width: z.number(),
 | 
			
		||||
      height: z.number()
 | 
			
		||||
            height: z.number(),
 | 
			
		||||
          }),
 | 
			
		||||
    meta: z.any()
 | 
			
		||||
  })),
 | 
			
		||||
  downloadUrl: z.string()
 | 
			
		||||
          meta: z.any(),
 | 
			
		||||
        }),
 | 
			
		||||
      ),
 | 
			
		||||
      downloadUrl: z.string(),
 | 
			
		||||
    }),
 | 
			
		||||
  ),
 | 
			
		||||
});
 | 
			
		||||
 | 
			
		||||
const civitaiModelResponseType = z.object({
 | 
			
		||||
  id: z.number(),
 | 
			
		||||
  name: z.string(),
 | 
			
		||||
  description: z.string().optional(),
 | 
			
		||||
  type: z.string(),
 | 
			
		||||
  poi: z.boolean(),
 | 
			
		||||
  nsfw: z.boolean(),
 | 
			
		||||
  allowNoCredit: z.boolean(),
 | 
			
		||||
  allowCommercialUse: z.string(),
 | 
			
		||||
  allowDerivatives: z.boolean(),
 | 
			
		||||
  allowDifferentLicense: z.boolean(),
 | 
			
		||||
  stats: z.object({
 | 
			
		||||
    downloadCount: z.number(),
 | 
			
		||||
    favoriteCount: z.number(),
 | 
			
		||||
    commentCount: z.number(),
 | 
			
		||||
    ratingCount: z.number(),
 | 
			
		||||
    rating: z.number(),
 | 
			
		||||
    tippedAmountCount: z.number()
 | 
			
		||||
  }),
 | 
			
		||||
  creator: z.object({
 | 
			
		||||
    username: z.string(),
 | 
			
		||||
    image: z.string()
 | 
			
		||||
  }),
 | 
			
		||||
  tags: z.array(z.string()),
 | 
			
		||||
  modelVersions: z.array(civitaiModelVersion)
 | 
			
		||||
});
 | 
			
		||||
export const resourceUpload = pgEnum("resource_upload", [
 | 
			
		||||
  "started",
 | 
			
		||||
  "failed",
 | 
			
		||||
  "succeded",
 | 
			
		||||
]);
 | 
			
		||||
 | 
			
		||||
export const modelUploadType = pgEnum("model_upload_type", [
 | 
			
		||||
  "civitai",
 | 
			
		||||
  "huggingface",
 | 
			
		||||
  "other",
 | 
			
		||||
]);
 | 
			
		||||
 | 
			
		||||
export const checkpoints = dbSchema.table("checkpoints", {
 | 
			
		||||
export const checkpointTable = dbSchema.table("checkpoints", {
 | 
			
		||||
  id: uuid("id").primaryKey().defaultRandom().notNull(),
 | 
			
		||||
  user_id: text("user_id")
 | 
			
		||||
    .references(() => usersTable.id, {
 | 
			
		||||
      // onDelete: "cascade",
 | 
			
		||||
    }), // if null it's global?
 | 
			
		||||
    .references(() => usersTable.id, {}), // perhaps a "special" user_id for global checkpoints
 | 
			
		||||
  org_id: text("org_id"),
 | 
			
		||||
  description: text("description"),
 | 
			
		||||
 | 
			
		||||
  civitai_id : text('civitai_id'),
 | 
			
		||||
  civitai_url : text('civitai_url'),
 | 
			
		||||
  civitai_details: jsonb("civitai_model_response").$type<z.infer<typeof civitaiModelResponseType >>(),
 | 
			
		||||
  checkpoint_volume_id: uuid("checkpoint_volume_id")
 | 
			
		||||
    .notNull()
 | 
			
		||||
    .references(() => workflowRunsTable.id, {
 | 
			
		||||
      onDelete: "cascade",
 | 
			
		||||
    }).notNull(),
 | 
			
		||||
 | 
			
		||||
  hf_url: text('hf_url'),
 | 
			
		||||
  s3_url: text('s3_url'),
 | 
			
		||||
  model_name: text("model_name"),
 | 
			
		||||
 | 
			
		||||
  civitai_id: text("civitai_id"),
 | 
			
		||||
  civitai_version_id: text("civitai_version_id"),
 | 
			
		||||
  civitai_url: text("civitai_url"),
 | 
			
		||||
  civitai_download_url: text("civitai_download_url"),
 | 
			
		||||
  civitai_model_response: jsonb("civitai_model_response").$type<
 | 
			
		||||
    z.infer<typeof CivitaiModel>
 | 
			
		||||
  >(),
 | 
			
		||||
 | 
			
		||||
  hf_url: text("hf_url"),
 | 
			
		||||
  s3_url: text("s3_url"),
 | 
			
		||||
  user_url: text("client_url"),
 | 
			
		||||
 | 
			
		||||
  is_public: boolean("is_public").notNull().default(false),
 | 
			
		||||
  status: resourceUpload("status").notNull().default("started"),
 | 
			
		||||
  upload_machine_id: text("upload_machine_id"),
 | 
			
		||||
  upload_type: modelUploadType("upload_type").notNull(),
 | 
			
		||||
 | 
			
		||||
  created_at: timestamp("created_at").defaultNow().notNull(),
 | 
			
		||||
  updated_at: timestamp("updated_at").defaultNow().notNull(),
 | 
			
		||||
});
 | 
			
		||||
 | 
			
		||||
export const insertCivitaiCheckpointSchema = createInsertSchema(
 | 
			
		||||
  checkpointTable,
 | 
			
		||||
  {
 | 
			
		||||
    civitai_url: (schema) =>
 | 
			
		||||
      schema.civitai_url.trim().url({ message: "URL required" }).includes(
 | 
			
		||||
        "civitai.com/models",
 | 
			
		||||
        { message: "civitai.com/models link required" },
 | 
			
		||||
      ),
 | 
			
		||||
  },
 | 
			
		||||
);
 | 
			
		||||
 | 
			
		||||
export const checkpointVolumeTable = dbSchema.table("checkpointVolumeTable", {
 | 
			
		||||
  id: uuid("id").primaryKey().defaultRandom().notNull(),
 | 
			
		||||
  user_id: text("user_id")
 | 
			
		||||
    .references(() => usersTable.id, {
 | 
			
		||||
      // onDelete: "cascade",
 | 
			
		||||
    }),
 | 
			
		||||
  org_id: text("org_id"),
 | 
			
		||||
  volume_name: text("volume_name").notNull(),
 | 
			
		||||
  created_at: timestamp("created_at").defaultNow().notNull(),
 | 
			
		||||
  updated_at: timestamp("updated_at").defaultNow().notNull(),
 | 
			
		||||
  disabled: boolean("disabled").default(false).notNull(),
 | 
			
		||||
});
 | 
			
		||||
 | 
			
		||||
export const checkpointRelations = relations(checkpointTable, ({ one }) => ({
 | 
			
		||||
  user: one(usersTable, {
 | 
			
		||||
    fields: [checkpointTable.user_id],
 | 
			
		||||
    references: [usersTable.id],
 | 
			
		||||
  }),
 | 
			
		||||
  volume: one(checkpointVolumeTable, {
 | 
			
		||||
    fields: [checkpointTable.checkpoint_volume_id],
 | 
			
		||||
    references: [checkpointVolumeTable.id],
 | 
			
		||||
  }),
 | 
			
		||||
}));
 | 
			
		||||
 | 
			
		||||
export const checkpointVolumeRelations = relations(
 | 
			
		||||
  checkpointVolumeTable,
 | 
			
		||||
  ({ many }) => ({
 | 
			
		||||
    checkpoint: many(checkpointTable),
 | 
			
		||||
  }),
 | 
			
		||||
);
 | 
			
		||||
 | 
			
		||||
export type UserType = InferSelectModel<typeof usersTable>;
 | 
			
		||||
export type WorkflowType = InferSelectModel<typeof workflowTable>;
 | 
			
		||||
export type MachineType = InferSelectModel<typeof machinesTable>;
 | 
			
		||||
export type WorkflowVersionType = InferSelectModel<typeof workflowVersionTable>;
 | 
			
		||||
export type DeploymentType = InferSelectModel<typeof deploymentsTable>;
 | 
			
		||||
export type CheckpointType = InferSelectModel<typeof checkpointTable>;
 | 
			
		||||
export type CheckpointVolumeType = InferSelectModel<
 | 
			
		||||
  typeof checkpointVolumeTable
 | 
			
		||||
>;
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										5
									
								
								web/src/server/addCheckpointSchema.tsx
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								web/src/server/addCheckpointSchema.tsx
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,5 @@
 | 
			
		||||
import { insertCivitaiCheckpointSchema } from "@/db/schema";
 | 
			
		||||
 | 
			
		||||
export const addCivitaiCheckpointSchema = insertCivitaiCheckpointSchema.pick({
 | 
			
		||||
  civitai_url: true,
 | 
			
		||||
});
 | 
			
		||||
							
								
								
									
										222
									
								
								web/src/server/curdCheckpoint.ts
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										222
									
								
								web/src/server/curdCheckpoint.ts
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,222 @@
 | 
			
		||||
"use server";
 | 
			
		||||
 | 
			
		||||
import { auth } from "@clerk/nextjs";
 | 
			
		||||
import {
 | 
			
		||||
  checkpointTable,
 | 
			
		||||
  CheckpointType,
 | 
			
		||||
  checkpointVolumeTable,
 | 
			
		||||
  CheckpointVolumeType,
 | 
			
		||||
  CivitaiModel,
 | 
			
		||||
} from "@/db/schema";
 | 
			
		||||
import { withServerPromise } from "./withServerPromise";
 | 
			
		||||
import { redirect } from "next/navigation";
 | 
			
		||||
import { db } from "@/db/db";
 | 
			
		||||
import type { z } from "zod";
 | 
			
		||||
import { headers } from "next/headers";
 | 
			
		||||
import { addCivitaiCheckpointSchema } from "./addCheckpointSchema";
 | 
			
		||||
import { and, eq, isNull } from "drizzle-orm";
 | 
			
		||||
 | 
			
		||||
export async function getCheckpoints() {
 | 
			
		||||
  const { userId, orgId } = auth();
 | 
			
		||||
  if (!userId) throw new Error("No user id");
 | 
			
		||||
  const checkpoints = await db
 | 
			
		||||
    .select()
 | 
			
		||||
    .from(checkpointTable)
 | 
			
		||||
    .where(
 | 
			
		||||
      orgId
 | 
			
		||||
        ? eq(checkpointTable.org_id, orgId)
 | 
			
		||||
        // make sure org_id is null
 | 
			
		||||
        : and(
 | 
			
		||||
          eq(checkpointTable.user_id, userId),
 | 
			
		||||
          isNull(checkpointTable.org_id),
 | 
			
		||||
        ),
 | 
			
		||||
    );
 | 
			
		||||
  return checkpoints;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
export async function getCheckpointById(id: string) {
 | 
			
		||||
  const { userId, orgId } = auth();
 | 
			
		||||
  if (!userId) throw new Error("No user id");
 | 
			
		||||
  const checkpoint = await db
 | 
			
		||||
    .select()
 | 
			
		||||
    .from(checkpointTable)
 | 
			
		||||
    .where(
 | 
			
		||||
      and(
 | 
			
		||||
        orgId ? eq(checkpointTable.org_id, orgId) : and(
 | 
			
		||||
          eq(checkpointTable.user_id, userId),
 | 
			
		||||
          isNull(checkpointTable.org_id),
 | 
			
		||||
        ),
 | 
			
		||||
        eq(checkpointTable.id, id),
 | 
			
		||||
      ),
 | 
			
		||||
    );
 | 
			
		||||
  return checkpoint[0];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
export async function getCheckpointVolumes() {
 | 
			
		||||
  const { userId, orgId } = auth();
 | 
			
		||||
  if (!userId) throw new Error("No user id");
 | 
			
		||||
  const checkpointVolume = await db
 | 
			
		||||
    .select()
 | 
			
		||||
    .from(checkpointVolumeTable)
 | 
			
		||||
    .where(
 | 
			
		||||
      and(
 | 
			
		||||
        orgId
 | 
			
		||||
          ? eq(checkpointVolumeTable.org_id, orgId)
 | 
			
		||||
          // make sure org_id is null
 | 
			
		||||
          : and(
 | 
			
		||||
            eq(checkpointVolumeTable.user_id, userId),
 | 
			
		||||
            isNull(checkpointVolumeTable.org_id),
 | 
			
		||||
          ),
 | 
			
		||||
        eq(checkpointVolumeTable.disabled, false),
 | 
			
		||||
      ),
 | 
			
		||||
    );
 | 
			
		||||
  return checkpointVolume;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
export async function addCheckpointVolume() {
 | 
			
		||||
  const { userId, orgId } = auth();
 | 
			
		||||
  if (!userId) throw new Error("No user id");
 | 
			
		||||
 | 
			
		||||
  // Insert the new volume into the checkpointVolumeTable
 | 
			
		||||
  const insertedVolume = await db
 | 
			
		||||
    .insert(checkpointVolumeTable)
 | 
			
		||||
    .values({
 | 
			
		||||
      user_id: userId,
 | 
			
		||||
      org_id: orgId,
 | 
			
		||||
      volume_name: `checkpoints_${userId}`,
 | 
			
		||||
      // created_at and updated_at will be set to current timestamp by default
 | 
			
		||||
      disabled: false, // Default value
 | 
			
		||||
    })
 | 
			
		||||
    .returning(); // Returns the inserted row
 | 
			
		||||
 | 
			
		||||
  return insertedVolume;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
function getUrl(civitai_url: string) {
 | 
			
		||||
  // expect to be a URL to be https://civitai.com/models/36520
 | 
			
		||||
  // possiblity with slugged name and query-param modelVersionId
 | 
			
		||||
  const baseUrl = "https://civitai.com/api/v1/models/";
 | 
			
		||||
  const url = new URL(civitai_url);
 | 
			
		||||
  const pathSegments = url.pathname.split("/");
 | 
			
		||||
  const modelId = pathSegments[pathSegments.indexOf("models") + 1];
 | 
			
		||||
  const modelVersionId = url.searchParams.get("modelVersionId");
 | 
			
		||||
 | 
			
		||||
  return { url: baseUrl + modelId, modelVersionId };
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
export const addCivitaiCheckpoint = withServerPromise(
 | 
			
		||||
  async (data: z.infer<typeof addCivitaiCheckpointSchema>) => {
 | 
			
		||||
    const { userId, orgId } = auth();
 | 
			
		||||
 | 
			
		||||
    if (!data.civitai_url) return { error: "no civitai_url" };
 | 
			
		||||
    if (!userId) return { error: "No user id" };
 | 
			
		||||
 | 
			
		||||
    const { url, modelVersionId } = getUrl(data?.civitai_url);
 | 
			
		||||
    const civitaiModelRes = await fetch(url)
 | 
			
		||||
      .then((x) => x.json())
 | 
			
		||||
      .then((a) => {
 | 
			
		||||
        console.log(a)
 | 
			
		||||
        return CivitaiModel.parse(a);
 | 
			
		||||
      });
 | 
			
		||||
 | 
			
		||||
    if (civitaiModelRes.modelVersions?.length === 0) {
 | 
			
		||||
      return; // no versions to download
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    let selectedModelVersion;
 | 
			
		||||
    let selectedModelVersionId: string | null = modelVersionId;
 | 
			
		||||
    if (!selectedModelVersionId) {
 | 
			
		||||
      selectedModelVersion = civitaiModelRes.modelVersions[0];
 | 
			
		||||
      selectedModelVersionId = civitaiModelRes.modelVersions[0].id.toString();
 | 
			
		||||
    } else {
 | 
			
		||||
      selectedModelVersion = civitaiModelRes.modelVersions.find((version) =>
 | 
			
		||||
        version.id.toString() === selectedModelVersionId
 | 
			
		||||
      );
 | 
			
		||||
      if (!selectedModelVersion) {
 | 
			
		||||
        return; // version id is wrong
 | 
			
		||||
      }
 | 
			
		||||
      selectedModelVersionId = selectedModelVersion?.id.toString();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    const checkpointVolumes = await getCheckpointVolumes();
 | 
			
		||||
    let cVolume;
 | 
			
		||||
    if (checkpointVolumes.length === 0) {
 | 
			
		||||
      const volume = await addCheckpointVolume();
 | 
			
		||||
      cVolume = volume[0];
 | 
			
		||||
    } else {
 | 
			
		||||
      cVolume = checkpointVolumes[0];
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    const a = await db
 | 
			
		||||
      .insert(checkpointTable)
 | 
			
		||||
      .values({
 | 
			
		||||
        user_id: userId,
 | 
			
		||||
        org_id: orgId,
 | 
			
		||||
        upload_type: "civitai",
 | 
			
		||||
        civitai_id: civitaiModelRes.id.toString(),
 | 
			
		||||
        civitai_version_id: selectedModelVersionId,
 | 
			
		||||
        civitai_model_response: civitaiModelRes,
 | 
			
		||||
        checkpoint_volume_id: cVolume.id,
 | 
			
		||||
      })
 | 
			
		||||
      .returning();
 | 
			
		||||
 | 
			
		||||
    const b = a[0];
 | 
			
		||||
 | 
			
		||||
    await uploadCheckpoint(data, b, cVolume);
 | 
			
		||||
    redirect(`/checkpoints/${b.id}`);
 | 
			
		||||
  },
 | 
			
		||||
);
 | 
			
		||||
 | 
			
		||||
async function uploadCheckpoint(
 | 
			
		||||
  data: z.infer<typeof addCivitaiCheckpointSchema>,
 | 
			
		||||
  b: CheckpointType,
 | 
			
		||||
  v: CheckpointVolumeType,
 | 
			
		||||
) {
 | 
			
		||||
  const headersList = headers();
 | 
			
		||||
 | 
			
		||||
  const domain = headersList.get("x-forwarded-host") || "";
 | 
			
		||||
  const protocol = headersList.get("x-forwarded-proto") || "";
 | 
			
		||||
 | 
			
		||||
  if (domain === "") {
 | 
			
		||||
    throw new Error("No domain");
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Call remote builder
 | 
			
		||||
  const result = await fetch(
 | 
			
		||||
    `${process.env.MODAL_BUILDER_URL!}/upload_volume`,
 | 
			
		||||
    {
 | 
			
		||||
      method: "POST",
 | 
			
		||||
      headers: {
 | 
			
		||||
        "Content-Type": "application/json",
 | 
			
		||||
      },
 | 
			
		||||
      body: JSON.stringify({
 | 
			
		||||
        download_url: data.civitai_url,
 | 
			
		||||
        volume_name: v.volume_name,
 | 
			
		||||
        callback_url: `${protocol}://${domain}/api/volume-updated`,
 | 
			
		||||
      }),
 | 
			
		||||
    },
 | 
			
		||||
  );
 | 
			
		||||
 | 
			
		||||
  if (!result.ok) {
 | 
			
		||||
    // const error_log = await result.text();
 | 
			
		||||
    // await db
 | 
			
		||||
    //   .update(checkpointTable)
 | 
			
		||||
    //   .set({
 | 
			
		||||
    //     ...data,
 | 
			
		||||
    //     status: "error",
 | 
			
		||||
    //     build_log: error_log,
 | 
			
		||||
    //   })
 | 
			
		||||
    //   .where(eq(machinesTable.id, b.id));
 | 
			
		||||
    // throw new Error(`Error: ${result.statusText} ${error_log}`);
 | 
			
		||||
  } else {
 | 
			
		||||
    // setting the build machine id
 | 
			
		||||
    const json = await result.json();
 | 
			
		||||
    await db
 | 
			
		||||
      .update(checkpointTable)
 | 
			
		||||
      .set({
 | 
			
		||||
        ...data,
 | 
			
		||||
        // build_machine_instance_id: json.build_machine_instance_id,
 | 
			
		||||
      })
 | 
			
		||||
      .where(eq(machinesTable.id, b.id));
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										39
									
								
								web/src/server/getAllUserCheckpoints.tsx
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										39
									
								
								web/src/server/getAllUserCheckpoints.tsx
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,39 @@
 | 
			
		||||
import { db } from "@/db/db";
 | 
			
		||||
import {
 | 
			
		||||
  checkpointTable,
 | 
			
		||||
} from "@/db/schema";
 | 
			
		||||
import { auth } from "@clerk/nextjs";
 | 
			
		||||
import { and, desc, eq, isNull } from "drizzle-orm";
 | 
			
		||||
 | 
			
		||||
export async function getAllUserCheckpoints() {
 | 
			
		||||
  const { userId, orgId } = await auth();
 | 
			
		||||
 | 
			
		||||
  if (!userId) {
 | 
			
		||||
    return null;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  const checkpoints = await db.query.checkpointTable.findMany({
 | 
			
		||||
    with: {
 | 
			
		||||
      user: {
 | 
			
		||||
        columns: {
 | 
			
		||||
          name: true,
 | 
			
		||||
        },
 | 
			
		||||
      },
 | 
			
		||||
    },
 | 
			
		||||
    columns: {
 | 
			
		||||
      id: true,
 | 
			
		||||
      updated_at: true,
 | 
			
		||||
      name: true,
 | 
			
		||||
      civitai_url: true,
 | 
			
		||||
      civitai_model_response: true,
 | 
			
		||||
      is_public: true,
 | 
			
		||||
    },
 | 
			
		||||
    orderBy: desc(checkpointTable.updated_at),
 | 
			
		||||
    where: 
 | 
			
		||||
      orgId != undefined
 | 
			
		||||
        ? eq(checkpointTable.org_id, orgId)
 | 
			
		||||
        : and(eq(checkpointTable.user_id, userId), isNull(checkpointTable.org_id)),
 | 
			
		||||
  });
 | 
			
		||||
 | 
			
		||||
  return checkpoints;
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										55
									
								
								web/src/types/civitai.ts
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										55
									
								
								web/src/types/civitai.ts
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,55 @@
 | 
			
		||||
import { z } from 'zod';
 | 
			
		||||
 | 
			
		||||
export const creatorSchema = z.object({
 | 
			
		||||
  username: z.string(),
 | 
			
		||||
  image: z.string(),
 | 
			
		||||
});
 | 
			
		||||
 | 
			
		||||
export const statsSchema = z.object({
 | 
			
		||||
  downloadCount: z.number(),
 | 
			
		||||
  favoriteCount: z.number(),
 | 
			
		||||
  commentCount: z.number(),
 | 
			
		||||
  ratingCount: z.number(),
 | 
			
		||||
  rating: z.number(),
 | 
			
		||||
  tippedAmountCount: z.number(),
 | 
			
		||||
});
 | 
			
		||||
 | 
			
		||||
export const modelVersionSchema = z.object({
 | 
			
		||||
  id: z.number(),
 | 
			
		||||
  modelId: z.number(),
 | 
			
		||||
  name: z.string(),
 | 
			
		||||
  createdAt: z.string(),
 | 
			
		||||
  updatedAt: z.string(),
 | 
			
		||||
  status: z.string(),
 | 
			
		||||
  publishedAt: z.string(),
 | 
			
		||||
  trainedWords: z.array(z.any()), // Replace with more specific type if known
 | 
			
		||||
  trainingStatus: z.any().optional(),
 | 
			
		||||
  trainingDetails: z.any().optional(),
 | 
			
		||||
  baseModel: z.string(),
 | 
			
		||||
  baseModelType: z.string(),
 | 
			
		||||
  earlyAccessTimeFrame: z.number(),
 | 
			
		||||
  description: z.string().optional(),
 | 
			
		||||
  vaeId: z.any().optional(), // Replace with more specific type if known
 | 
			
		||||
  stats: statsSchema.optional(), // If stats structure is known, replace with specific type
 | 
			
		||||
  files: z.array(z.any()), // Replace with more specific type if known
 | 
			
		||||
  images: z.array(z.any()), // Replace with more specific type if known
 | 
			
		||||
  downloadUrl: z.string(),
 | 
			
		||||
});
 | 
			
		||||
 | 
			
		||||
export const CivitaiModel = z.object({
 | 
			
		||||
  id: z.number(),
 | 
			
		||||
  name: z.string(),
 | 
			
		||||
  description: z.string(),
 | 
			
		||||
  type: z.string(),
 | 
			
		||||
  poi: z.boolean(),
 | 
			
		||||
  nsfw: z.boolean(),
 | 
			
		||||
  allowNoCredit: z.boolean(),
 | 
			
		||||
  allowCommercialUse: z.string(),
 | 
			
		||||
  allowDerivatives: z.boolean(),
 | 
			
		||||
  allowDifferentLicense: z.boolean(),
 | 
			
		||||
  stats: statsSchema,
 | 
			
		||||
  creator: creatorSchema,
 | 
			
		||||
  tags: z.array(z.string()),
 | 
			
		||||
  modelVersions: z.array(modelVersionSchema),
 | 
			
		||||
});
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user