Squashed commit of the following:
commit c36b0ec0b374dd8ccbee3a6044ee7e3f1fefe368
Author: Nicholas Koben Kao <kobenkao@gmail.com>
Date:   Thu Jan 25 17:54:54 2024 -0800
    nits on wording and removing link to broken storage/:id page
commit 0777fdcf7b0002244bc713199d3d64eea6b6061e
Author: Nicholas Koben Kao <kobenkao@gmail.com>
Date:   Thu Jan 25 17:23:55 2024 -0800
    builder update config and such
commit 958b795bb2b6ac27ce33c5729ef265b068420e1a
Author: Nicholas Koben Kao <kobenkao@gmail.com>
Date:   Thu Jan 25 17:23:43 2024 -0800
    rename all from checkponit to model
commit 7a9c5636e73bd005499b141a4dd382db5672c962
Author: Nicholas Koben Kao <kobenkao@gmail.com>
Date:   Thu Jan 25 16:51:59 2024 -0800
    rename for consistency
commit 48bebbafab9a95388817df97c15f8ea97e0fea75
Author: Nicholas Koben Kao <kobenkao@gmail.com>
Date:   Thu Jan 25 16:18:36 2024 -0800
    bulider
commit 81dacd9af457886f2f027994d225a7748c738abb
Author: Nicholas Koben Kao <kobenkao@gmail.com>
Date:   Thu Jan 25 16:17:56 2024 -0800
    different types of models
			
			
This commit is contained in:
		
							parent
							
								
									62a69dba06
								
							
						
					
					
						commit
						85477aba9d
					
				@ -3,4 +3,5 @@ MODAL_TOKEN_SECRET=
 | 
			
		||||
CIVITAI_API_KEY=
 | 
			
		||||
 | 
			
		||||
# On production set to False
 | 
			
		||||
DEPLOY_TEST_FLAG=True
 | 
			
		||||
DEPLOY_TEST_FLAG=True
 | 
			
		||||
CIVITAI_API_KEY=
 | 
			
		||||
 | 
			
		||||
@ -177,7 +177,7 @@ class Item(BaseModel):
 | 
			
		||||
    snapshot: Snapshot
 | 
			
		||||
    models: List[Model]
 | 
			
		||||
    callback_url: str
 | 
			
		||||
    checkpoint_volume_name: str
 | 
			
		||||
    model_volume_name: str
 | 
			
		||||
    gpu: GPUType = Field(default=GPUType.T4)
 | 
			
		||||
 | 
			
		||||
    @field_validator('gpu')
 | 
			
		||||
@ -227,24 +227,31 @@ async def websocket_endpoint(websocket: WebSocket, machine_id: str):
 | 
			
		||||
 | 
			
		||||
#     return {"Hello": "World"}
 | 
			
		||||
 | 
			
		||||
# definition based on web schema
 | 
			
		||||
class UploadType(str, Enum):
 | 
			
		||||
    checkpoint = "checkpoint"
 | 
			
		||||
    lora = "lora"
 | 
			
		||||
    embedding = "embedding"
 | 
			
		||||
 | 
			
		||||
class UploadBody(BaseModel):
 | 
			
		||||
    download_url: str
 | 
			
		||||
    volume_name: str
 | 
			
		||||
    volume_id: str
 | 
			
		||||
    checkpoint_id: str
 | 
			
		||||
    model_id: str
 | 
			
		||||
    upload_type: UploadType
 | 
			
		||||
    callback_url: str
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# based on ComfyUI's model dir, and our mappings in ./src/template/data/extra_model_paths.yaml
 | 
			
		||||
UPLOAD_TYPE_DIR_MAP = {
 | 
			
		||||
    UploadType.checkpoint: "checkpoints"
 | 
			
		||||
    UploadType.checkpoint: "checkpoints",
 | 
			
		||||
    UploadType.lora: "loras",
 | 
			
		||||
    UploadType.embedding: "embeddings",
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@app.post("/upload-volume")
 | 
			
		||||
async def upload_checkpoint(body: UploadBody):
 | 
			
		||||
async def upload_model(body: UploadBody):
 | 
			
		||||
    global last_activity_time
 | 
			
		||||
    last_activity_time = time.time()
 | 
			
		||||
    logger.info(f"Extended inactivity time to {global_timeout}")
 | 
			
		||||
@ -254,6 +261,7 @@ async def upload_checkpoint(body: UploadBody):
 | 
			
		||||
    # check that this
 | 
			
		||||
    return JSONResponse(status_code=200, content={"message": "Volume uploading", "build_machine_instance_id": fly_instance_id})
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def upload_logic(body: UploadBody):
 | 
			
		||||
    folder_path = f"/app/builds/{body.volume_id}"
 | 
			
		||||
 | 
			
		||||
@ -270,7 +278,7 @@ async def upload_logic(body: UploadBody):
 | 
			
		||||
        }, 
 | 
			
		||||
        "callback_url": body.callback_url,
 | 
			
		||||
        "callback_body": {
 | 
			
		||||
            "checkpoint_id": body.checkpoint_id,
 | 
			
		||||
            "model_id": body.model_id,
 | 
			
		||||
            "volume_id": body.volume_id,
 | 
			
		||||
            "folder_path": upload_path,
 | 
			
		||||
        },
 | 
			
		||||
@ -279,51 +287,11 @@ async def upload_logic(body: UploadBody):
 | 
			
		||||
    with open(f"{folder_path}/config.py", "w") as f:
 | 
			
		||||
        f.write("config = " + json.dumps(config))
 | 
			
		||||
 | 
			
		||||
    process = await asyncio.subprocess.create_subprocess_shell(
 | 
			
		||||
    await asyncio.subprocess.create_subprocess_shell(
 | 
			
		||||
        f"modal run app.py",
 | 
			
		||||
        # stdout=asyncio.subprocess.PIPE,
 | 
			
		||||
        # stderr=asyncio.subprocess.PIPE,
 | 
			
		||||
        cwd=folder_path,
 | 
			
		||||
        env={**os.environ, "COLUMNS": "10000"}
 | 
			
		||||
    )
 | 
			
		||||
    
 | 
			
		||||
    # error_logs = []
 | 
			
		||||
    # async def read_stream(stream):
 | 
			
		||||
    #     while True:
 | 
			
		||||
    #         line = await stream.readline()
 | 
			
		||||
    #         if line:
 | 
			
		||||
    #             l = line.decode('utf-8').strip()
 | 
			
		||||
    #             error_logs.append(l)
 | 
			
		||||
    #             logger.error(l)
 | 
			
		||||
    #             error_logs.append({
 | 
			
		||||
    #                 "logs": l,
 | 
			
		||||
    #                 "timestamp": time.time()
 | 
			
		||||
    #             })
 | 
			
		||||
    #         else:
 | 
			
		||||
    #             break
 | 
			
		||||
 | 
			
		||||
    # stderr_read_task = asyncio.create_task(read_stream(process.stderr))
 | 
			
		||||
    #
 | 
			
		||||
    # await asyncio.wait([stderr_read_task])
 | 
			
		||||
    # await process.wait()
 | 
			
		||||
 | 
			
		||||
    # if process.returncode != 0:
 | 
			
		||||
    #     error_logs.append({"logs": "Unable to upload volume.", "timestamp": time.time()})
 | 
			
		||||
    #     # Error handling: send POST request to callback URL with error details
 | 
			
		||||
    #     requests.post(body.callback_url, json={
 | 
			
		||||
    #         "volume_id": body.volume_id, 
 | 
			
		||||
    #         "checkpoint_id": body.checkpoint_id,
 | 
			
		||||
    #         "folder_path": upload_path,
 | 
			
		||||
    #         "error_logs": json.dumps(error_logs),
 | 
			
		||||
    #         "status": "failed"
 | 
			
		||||
    #     })
 | 
			
		||||
    #
 | 
			
		||||
    # requests.post(body.callback_url, json={
 | 
			
		||||
    #     "checkpoint_id": body.checkpoint_id,
 | 
			
		||||
    #     "volume_id": body.volume_id,
 | 
			
		||||
    #     "folder_path": upload_path,
 | 
			
		||||
    #     "status": "success"
 | 
			
		||||
    # })
 | 
			
		||||
 | 
			
		||||
@app.post("/create")
 | 
			
		||||
async def create_machine(item: Item):
 | 
			
		||||
@ -414,8 +382,8 @@ async def build_logic(item: Item):
 | 
			
		||||
        "name": item.name,
 | 
			
		||||
        "deploy_test": os.environ.get("DEPLOY_TEST_FLAG", "False"),
 | 
			
		||||
        "gpu": item.gpu,
 | 
			
		||||
        "public_checkpoint_volume": "model-store",
 | 
			
		||||
        "private_checkpoint_volume": item.checkpoint_volume_name
 | 
			
		||||
        "public_model_volume": "model-store",
 | 
			
		||||
        "private_model_volume": item.model_volume_name
 | 
			
		||||
    }
 | 
			
		||||
    with open(f"{folder_path}/config.py", "w") as f:
 | 
			
		||||
        f.write("config = " + json.dumps(config))
 | 
			
		||||
 | 
			
		||||
@ -2,6 +2,6 @@ config = {
 | 
			
		||||
    "name": "my-app",
 | 
			
		||||
    "deploy_test": "True",
 | 
			
		||||
    "gpu": "T4", 
 | 
			
		||||
    "public_checkpoint_volume": "model-store",
 | 
			
		||||
    "private_checkpoint_volume": "private-model-store"
 | 
			
		||||
    "public_model_volume": "model-store",
 | 
			
		||||
    "private_model_volume": "private-model-store"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -13,3 +13,11 @@ public:
 | 
			
		||||
private:
 | 
			
		||||
  base_path: /private_models/
 | 
			
		||||
  checkpoints: checkpoints
 | 
			
		||||
  clip: clip
 | 
			
		||||
  clip_vision: clip_vision
 | 
			
		||||
  configs: configs
 | 
			
		||||
  controlnet: controlnet
 | 
			
		||||
  embeddings: embeddings
 | 
			
		||||
  loras: loras
 | 
			
		||||
  upscale_models: upscale_models
 | 
			
		||||
  vae: vae
 | 
			
		||||
 | 
			
		||||
@ -1,8 +1,8 @@
 | 
			
		||||
import modal
 | 
			
		||||
from config import config
 | 
			
		||||
 | 
			
		||||
public_model_volume = modal.Volume.persisted(config["public_checkpoint_volume"])
 | 
			
		||||
private_volume = modal.Volume.persisted(config["private_checkpoint_volume"])
 | 
			
		||||
public_model_volume = modal.Volume.persisted(config["public_model_volume"])
 | 
			
		||||
private_volume = modal.Volume.persisted(config["private_model_volume"])
 | 
			
		||||
 | 
			
		||||
PUBLIC_BASEMODEL_DIR = "/public_models"
 | 
			
		||||
PRIVATE_BASEMODEL_DIR = "/private_models"
 | 
			
		||||
 | 
			
		||||
@ -1,18 +1,18 @@
 | 
			
		||||
config = {
 | 
			
		||||
    "volume_names": {
 | 
			
		||||
        "test": {
 | 
			
		||||
            "download_url": "https://pub-6230db03dc3a4861a9c3e55145ceda44.r2.dev/openpose-pose (1).png",
 | 
			
		||||
            "folder_path": "images"
 | 
			
		||||
        "user4": {
 | 
			
		||||
            "download_url": "https://civitai.com/api/download/models/11745",
 | 
			
		||||
            "folder_path": "checkpoints"
 | 
			
		||||
        }
 | 
			
		||||
    }, 
 | 
			
		||||
    "volume_paths": {
 | 
			
		||||
        "test": "/volumes/something"
 | 
			
		||||
        "user4": "/volumes/something",
 | 
			
		||||
    },
 | 
			
		||||
    "callback_url": "",
 | 
			
		||||
    "callback_body": {
 | 
			
		||||
        "checkpoint_id": "",
 | 
			
		||||
        "model_id": "",
 | 
			
		||||
        "volume_id": "",
 | 
			
		||||
        "folder_path": "images",
 | 
			
		||||
        "folder_path": "checkpoints",
 | 
			
		||||
    }, 
 | 
			
		||||
    "civitai_api_key": "",
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										7
									
								
								web/drizzle/0043_dapper_santa_claus.sql
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										7
									
								
								web/drizzle/0043_dapper_santa_claus.sql
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,7 @@
 | 
			
		||||
DO $$ BEGIN
 | 
			
		||||
 CREATE TYPE "model_type" AS ENUM('checkpoint', 'lora', 'embedding', 'vae');
 | 
			
		||||
EXCEPTION
 | 
			
		||||
 WHEN duplicate_object THEN null;
 | 
			
		||||
END $$;
 | 
			
		||||
--> statement-breakpoint
 | 
			
		||||
ALTER TABLE "comfyui_deploy"."checkpoints" ADD COLUMN "model_type" "model_type" NOT NULL;
 | 
			
		||||
							
								
								
									
										26
									
								
								web/drizzle/0044_married_malcolm_colcord.sql
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										26
									
								
								web/drizzle/0044_married_malcolm_colcord.sql
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,26 @@
 | 
			
		||||
ALTER TABLE "comfyui_deploy"."checkpoints" RENAME TO "models";--> statement-breakpoint
 | 
			
		||||
ALTER TABLE "comfyui_deploy"."checkpoint_volume" RENAME TO "user_volume";--> statement-breakpoint
 | 
			
		||||
ALTER TABLE "comfyui_deploy"."models" RENAME COLUMN "checkpoint_volume_id" TO "user_volume_id";--> statement-breakpoint
 | 
			
		||||
ALTER TABLE "comfyui_deploy"."models" DROP CONSTRAINT "checkpoints_user_id_users_id_fk";
 | 
			
		||||
--> statement-breakpoint
 | 
			
		||||
ALTER TABLE "comfyui_deploy"."models" DROP CONSTRAINT "checkpoints_checkpoint_volume_id_checkpoint_volume_id_fk";
 | 
			
		||||
--> statement-breakpoint
 | 
			
		||||
ALTER TABLE "comfyui_deploy"."user_volume" DROP CONSTRAINT "checkpoint_volume_user_id_users_id_fk";
 | 
			
		||||
--> statement-breakpoint
 | 
			
		||||
DO $$ BEGIN
 | 
			
		||||
 ALTER TABLE "comfyui_deploy"."models" ADD CONSTRAINT "models_user_id_users_id_fk" FOREIGN KEY ("user_id") REFERENCES "comfyui_deploy"."users"("id") ON DELETE no action ON UPDATE no action;
 | 
			
		||||
EXCEPTION
 | 
			
		||||
 WHEN duplicate_object THEN null;
 | 
			
		||||
END $$;
 | 
			
		||||
--> statement-breakpoint
 | 
			
		||||
DO $$ BEGIN
 | 
			
		||||
 ALTER TABLE "comfyui_deploy"."models" ADD CONSTRAINT "models_user_volume_id_user_volume_id_fk" FOREIGN KEY ("user_volume_id") REFERENCES "comfyui_deploy"."user_volume"("id") ON DELETE cascade ON UPDATE no action;
 | 
			
		||||
EXCEPTION
 | 
			
		||||
 WHEN duplicate_object THEN null;
 | 
			
		||||
END $$;
 | 
			
		||||
--> statement-breakpoint
 | 
			
		||||
DO $$ BEGIN
 | 
			
		||||
 ALTER TABLE "comfyui_deploy"."user_volume" ADD CONSTRAINT "user_volume_user_id_users_id_fk" FOREIGN KEY ("user_id") REFERENCES "comfyui_deploy"."users"("id") ON DELETE no action ON UPDATE no action;
 | 
			
		||||
EXCEPTION
 | 
			
		||||
 WHEN duplicate_object THEN null;
 | 
			
		||||
END $$;
 | 
			
		||||
							
								
								
									
										1288
									
								
								web/drizzle/meta/0043_snapshot.json
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1288
									
								
								web/drizzle/meta/0043_snapshot.json
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										1293
									
								
								web/drizzle/meta/0044_snapshot.json
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1293
									
								
								web/drizzle/meta/0044_snapshot.json
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							@ -302,6 +302,20 @@
 | 
			
		||||
      "when": 1706164614659,
 | 
			
		||||
      "tag": "0042_windy_madelyne_pryor",
 | 
			
		||||
      "breakpoints": true
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
      "idx": 43,
 | 
			
		||||
      "version": "5",
 | 
			
		||||
      "when": 1706225960550,
 | 
			
		||||
      "tag": "0043_dapper_santa_claus",
 | 
			
		||||
      "breakpoints": true
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
      "idx": 44,
 | 
			
		||||
      "version": "5",
 | 
			
		||||
      "when": 1706230304140,
 | 
			
		||||
      "tag": "0044_married_malcolm_colcord",
 | 
			
		||||
      "breakpoints": true
 | 
			
		||||
    }
 | 
			
		||||
  ]
 | 
			
		||||
}
 | 
			
		||||
@ -1,15 +1,15 @@
 | 
			
		||||
import { parseDataSafe } from "../../../../lib/parseDataSafe";
 | 
			
		||||
import { db } from "@/db/db";
 | 
			
		||||
import { checkpointTable, machinesTable } from "@/db/schema";
 | 
			
		||||
import { modelTable } from "@/db/schema";
 | 
			
		||||
import { eq } from "drizzle-orm";
 | 
			
		||||
import { NextResponse } from "next/server";
 | 
			
		||||
import { z } from "zod";
 | 
			
		||||
 | 
			
		||||
const Request = z.object({
 | 
			
		||||
  volume_id: z.string(),
 | 
			
		||||
  checkpoint_id: z.string(),
 | 
			
		||||
  model_id: z.string(),
 | 
			
		||||
  folder_path: z.string().optional(),
 | 
			
		||||
  status: z.enum(['success', 'failed']),
 | 
			
		||||
  status: z.enum(["success", "failed"]),
 | 
			
		||||
  error_log: z.string().optional(),
 | 
			
		||||
  timeout: z.number().optional(),
 | 
			
		||||
});
 | 
			
		||||
@ -18,30 +18,30 @@ export async function POST(request: Request) {
 | 
			
		||||
  const [data, error] = await parseDataSafe(Request, request);
 | 
			
		||||
  if (!data || error) return error;
 | 
			
		||||
 | 
			
		||||
  const { checkpoint_id, error_log, status, folder_path } = data;
 | 
			
		||||
  console.log( checkpoint_id, error_log, status, folder_path )
 | 
			
		||||
  const { model_id, error_log, status, folder_path } = data;
 | 
			
		||||
  console.log(model_id, error_log, status, folder_path);
 | 
			
		||||
 | 
			
		||||
  if (status === "success") {
 | 
			
		||||
    await db
 | 
			
		||||
      .update(checkpointTable)
 | 
			
		||||
      .update(modelTable)
 | 
			
		||||
      .set({
 | 
			
		||||
        status: "success",
 | 
			
		||||
        folder_path,
 | 
			
		||||
        updated_at: new Date(),
 | 
			
		||||
        // build_log: build_log,
 | 
			
		||||
      })
 | 
			
		||||
      .where(eq(checkpointTable.id, checkpoint_id));
 | 
			
		||||
      .where(eq(modelTable.id, model_id));
 | 
			
		||||
  } else {
 | 
			
		||||
    await db
 | 
			
		||||
      .update(checkpointTable)
 | 
			
		||||
      .update(modelTable)
 | 
			
		||||
      .set({
 | 
			
		||||
        status: "failed",
 | 
			
		||||
        error_log, 
 | 
			
		||||
        error_log,
 | 
			
		||||
        updated_at: new Date(),
 | 
			
		||||
        // status: "error",
 | 
			
		||||
        // build_log: build_log,
 | 
			
		||||
      })
 | 
			
		||||
      .where(eq(checkpointTable.id, checkpoint_id));
 | 
			
		||||
      .where(eq(modelTable.id, model_id));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  return NextResponse.json(
 | 
			
		||||
@ -50,6 +50,6 @@ export async function POST(request: Request) {
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
      status: 200,
 | 
			
		||||
    }
 | 
			
		||||
    },
 | 
			
		||||
  );
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -1,14 +1,14 @@
 | 
			
		||||
import { setInitialUserData } from "../../../lib/setInitialUserData";
 | 
			
		||||
import { auth } from "@clerk/nextjs";
 | 
			
		||||
import { clerkClient } from "@clerk/nextjs/server";
 | 
			
		||||
import { CheckpointList } from "@/components/CheckpointList";
 | 
			
		||||
import { getAllUserCheckpoints } from "@/server/getAllUserCheckpoints";
 | 
			
		||||
import { ModelList } from "@/components/ModelList";
 | 
			
		||||
import { getAllUserModels } from "@/server/getAllUserModel";
 | 
			
		||||
 | 
			
		||||
export default function Page() {
 | 
			
		||||
  return <CheckpointListServer />;
 | 
			
		||||
  return <ModelListServer />;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async function CheckpointListServer() {
 | 
			
		||||
async function ModelListServer() {
 | 
			
		||||
  const { userId } = auth();
 | 
			
		||||
 | 
			
		||||
  if (!userId) {
 | 
			
		||||
@ -21,15 +21,15 @@ async function CheckpointListServer() {
 | 
			
		||||
    await setInitialUserData(userId);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  const checkpoints = await getAllUserCheckpoints();
 | 
			
		||||
  const models = await getAllUserModels();
 | 
			
		||||
 | 
			
		||||
  if (!checkpoints) {
 | 
			
		||||
    return <div>No checkpoints found</div>;
 | 
			
		||||
  if (!models) {
 | 
			
		||||
    return <div>No models found</div>;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  return (
 | 
			
		||||
    <div className="w-full">
 | 
			
		||||
      <CheckpointList data={checkpoints} />
 | 
			
		||||
      <ModelList data={models} />
 | 
			
		||||
    </div>
 | 
			
		||||
  );
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -4,7 +4,7 @@ import { getRelativeTime } from "../lib/getRelativeTime";
 | 
			
		||||
import { Badge } from "@/components/ui/badge";
 | 
			
		||||
import { Button } from "@/components/ui/button";
 | 
			
		||||
import { Checkbox } from "@/components/ui/checkbox";
 | 
			
		||||
import { InsertModal, UpdateModal } from "./InsertModal";
 | 
			
		||||
import { InsertModal } from "./InsertModal";
 | 
			
		||||
import { Input } from "@/components/ui/input";
 | 
			
		||||
import { ScrollArea } from "@/components/ui/scroll-area";
 | 
			
		||||
import {
 | 
			
		||||
@ -15,7 +15,7 @@ import {
 | 
			
		||||
  TableHeader,
 | 
			
		||||
  TableRow,
 | 
			
		||||
} from "@/components/ui/table";
 | 
			
		||||
import type { getAllUserCheckpoints } from "@/server/getAllUserCheckpoints";
 | 
			
		||||
import type { getAllUserModels as getAllUserModels } from "@/server/getAllUserModel";
 | 
			
		||||
import type {
 | 
			
		||||
  ColumnDef,
 | 
			
		||||
  ColumnFiltersState,
 | 
			
		||||
@ -32,23 +32,22 @@ import {
 | 
			
		||||
} from "@tanstack/react-table";
 | 
			
		||||
import { ArrowUpDown } from "lucide-react";
 | 
			
		||||
import * as React from "react";
 | 
			
		||||
import { addCivitaiCheckpoint } from "@/server/curdCheckpoint";
 | 
			
		||||
import { addCivitaiCheckpointSchema } from "@/server/addCheckpointSchema";
 | 
			
		||||
import { addCivitaiModel } from "@/server/curdModel";
 | 
			
		||||
import { addCivitaiModelSchema } from "@/server/addCivitaiModelSchema";
 | 
			
		||||
import { modelEnumType } from "@/db/schema";
 | 
			
		||||
 | 
			
		||||
export type CheckpointItemList = NonNullable<
 | 
			
		||||
  Awaited<ReturnType<typeof getAllUserCheckpoints>>
 | 
			
		||||
export type ModelItemList = NonNullable<
 | 
			
		||||
  Awaited<ReturnType<typeof getAllUserModels>>
 | 
			
		||||
>[0];
 | 
			
		||||
 | 
			
		||||
export const columns: ColumnDef<CheckpointItemList>[] = [
 | 
			
		||||
export const columns: ColumnDef<ModelItemList>[] = [
 | 
			
		||||
  {
 | 
			
		||||
    accessorKey: "id",
 | 
			
		||||
    id: "select",
 | 
			
		||||
    header: ({ table }) => (
 | 
			
		||||
      <Checkbox
 | 
			
		||||
        checked={
 | 
			
		||||
          table.getIsAllPageRowsSelected() ||
 | 
			
		||||
          (table.getIsSomePageRowsSelected() && "indeterminate")
 | 
			
		||||
        }
 | 
			
		||||
        checked={table.getIsAllPageRowsSelected() ||
 | 
			
		||||
          (table.getIsSomePageRowsSelected() && "indeterminate")}
 | 
			
		||||
        onCheckedChange={(value) => table.toggleAllPageRowsSelected(!!value)}
 | 
			
		||||
        aria-label="Select all"
 | 
			
		||||
      />
 | 
			
		||||
@ -77,22 +76,23 @@ export const columns: ColumnDef<CheckpointItemList>[] = [
 | 
			
		||||
      );
 | 
			
		||||
    },
 | 
			
		||||
    cell: ({ row }) => {
 | 
			
		||||
      const checkpoint = row.original;
 | 
			
		||||
      const model = row.original;
 | 
			
		||||
      return (
 | 
			
		||||
        <a
 | 
			
		||||
        <>
 | 
			
		||||
          {
 | 
			
		||||
            /*<a
 | 
			
		||||
          className="hover:underline flex gap-2"
 | 
			
		||||
          href={`/storage/${checkpoint.id}`} // TODO
 | 
			
		||||
        >
 | 
			
		||||
          href={`/storage/${model.id}`} // TODO
 | 
			
		||||
        >*/
 | 
			
		||||
          }
 | 
			
		||||
          <span className="truncate max-w-[200px]">
 | 
			
		||||
            {row.original.model_name}
 | 
			
		||||
          </span>
 | 
			
		||||
 | 
			
		||||
          {checkpoint.is_public ? (
 | 
			
		||||
            <Badge variant="green">Public</Badge>
 | 
			
		||||
          ) : (
 | 
			
		||||
            <Badge variant="orange">Private</Badge>
 | 
			
		||||
          )}
 | 
			
		||||
        </a>
 | 
			
		||||
          {model.is_public
 | 
			
		||||
            ? <Badge variant="green">Public</Badge>
 | 
			
		||||
            : <Badge variant="orange">Private</Badge>}
 | 
			
		||||
        </>
 | 
			
		||||
      );
 | 
			
		||||
    },
 | 
			
		||||
  },
 | 
			
		||||
@ -111,7 +111,11 @@ export const columns: ColumnDef<CheckpointItemList>[] = [
 | 
			
		||||
    },
 | 
			
		||||
    cell: ({ row }) => {
 | 
			
		||||
      return (
 | 
			
		||||
        <Badge variant={row.original.status === "failed" ? "red" : (row.original.status === "started" ? "yellow" : "green")}>
 | 
			
		||||
        <Badge
 | 
			
		||||
          variant={row.original.status === "failed"
 | 
			
		||||
            ? "red"
 | 
			
		||||
            : (row.original.status === "started" ? "yellow" : "green")}
 | 
			
		||||
        >
 | 
			
		||||
          {row.original.status}
 | 
			
		||||
        </Badge>
 | 
			
		||||
      );
 | 
			
		||||
@ -167,6 +171,35 @@ export const columns: ColumnDef<CheckpointItemList>[] = [
 | 
			
		||||
      return <Badge variant="cyan">{row.original.upload_type}</Badge>;
 | 
			
		||||
    },
 | 
			
		||||
  },
 | 
			
		||||
  {
 | 
			
		||||
    accessorKey: "model_type",
 | 
			
		||||
    header: ({ column }) => {
 | 
			
		||||
      return (
 | 
			
		||||
        <button
 | 
			
		||||
          className="flex items-center hover:underline"
 | 
			
		||||
          onClick={() => column.toggleSorting(column.getIsSorted() === "asc")}
 | 
			
		||||
        >
 | 
			
		||||
          Model Type
 | 
			
		||||
          <ArrowUpDown className="ml-2 h-4 w-4" />
 | 
			
		||||
        </button>
 | 
			
		||||
      );
 | 
			
		||||
    },
 | 
			
		||||
    cell: ({ row }) => {
 | 
			
		||||
      const model_type_map: Record<modelEnumType, any> = {
 | 
			
		||||
        "checkpoint": "amber",
 | 
			
		||||
        "lora": "green",
 | 
			
		||||
        "embedding": "violet",
 | 
			
		||||
        "vae": "teal",
 | 
			
		||||
      };
 | 
			
		||||
 | 
			
		||||
      function getBadgeColor(modelType: modelEnumType) {
 | 
			
		||||
        return model_type_map[modelType] || "default";
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      const color = getBadgeColor(row.original.model_type);
 | 
			
		||||
      return <Badge variant={color}>{row.original.model_type}</Badge>;
 | 
			
		||||
    },
 | 
			
		||||
  },
 | 
			
		||||
  {
 | 
			
		||||
    accessorKey: "date",
 | 
			
		||||
    sortingFn: "datetime",
 | 
			
		||||
@ -221,13 +254,14 @@ export const columns: ColumnDef<CheckpointItemList>[] = [
 | 
			
		||||
  // },
 | 
			
		||||
];
 | 
			
		||||
 | 
			
		||||
export function CheckpointList({ data }: { data: CheckpointItemList[] }) {
 | 
			
		||||
export function ModelList({ data }: { data: ModelItemList[] }) {
 | 
			
		||||
  const [sorting, setSorting] = React.useState<SortingState>([]);
 | 
			
		||||
  const [columnFilters, setColumnFilters] = React.useState<ColumnFiltersState>(
 | 
			
		||||
    []
 | 
			
		||||
    [],
 | 
			
		||||
  );
 | 
			
		||||
  const [columnVisibility, setColumnVisibility] =
 | 
			
		||||
    React.useState<VisibilityState>({});
 | 
			
		||||
  const [columnVisibility, setColumnVisibility] = React.useState<
 | 
			
		||||
    VisibilityState
 | 
			
		||||
  >({});
 | 
			
		||||
  const [rowSelection, setRowSelection] = React.useState({});
 | 
			
		||||
 | 
			
		||||
  const table = useReactTable({
 | 
			
		||||
@ -254,10 +288,10 @@ export function CheckpointList({ data }: { data: CheckpointItemList[] }) {
 | 
			
		||||
      <div className="flex flex-row w-full items-center py-4">
 | 
			
		||||
        <Input
 | 
			
		||||
          placeholder="Filter workflows..."
 | 
			
		||||
          value={(table.getColumn("name")?.getFilterValue() as string) ?? ""}
 | 
			
		||||
          value={(table.getColumn("model_name")?.getFilterValue() as string) ??
 | 
			
		||||
            ""}
 | 
			
		||||
          onChange={(event) =>
 | 
			
		||||
            table.getColumn("name")?.setFilterValue(event.target.value)
 | 
			
		||||
          }
 | 
			
		||||
            table.getColumn("model_name")?.setFilterValue(event.target.value)}
 | 
			
		||||
          className="max-w-sm"
 | 
			
		||||
        />
 | 
			
		||||
        <div className="ml-auto flex gap-2">
 | 
			
		||||
@ -268,17 +302,17 @@ export function CheckpointList({ data }: { data: CheckpointItemList[] }) {
 | 
			
		||||
              // TODO: limitations based on plan
 | 
			
		||||
            }
 | 
			
		||||
            tooltip={"Add models using their civitai url!"}
 | 
			
		||||
            title="Civitai Checkpoint"
 | 
			
		||||
            title="Add a Civitai Model"
 | 
			
		||||
            description="Pick a model from civitai"
 | 
			
		||||
            serverAction={addCivitaiCheckpoint}
 | 
			
		||||
            formSchema={addCivitaiCheckpointSchema}
 | 
			
		||||
            serverAction={addCivitaiModel}
 | 
			
		||||
            formSchema={addCivitaiModelSchema}
 | 
			
		||||
            fieldConfig={{
 | 
			
		||||
              civitai_url: {
 | 
			
		||||
                fieldType: "fallback",
 | 
			
		||||
                inputProps: { required: true },
 | 
			
		||||
                description: (
 | 
			
		||||
                  <>
 | 
			
		||||
                    Pick a checkpoint from{" "}
 | 
			
		||||
                    Pick a model from{" "}
 | 
			
		||||
                    <a
 | 
			
		||||
                      href="https://www.civitai.com/models"
 | 
			
		||||
                      target="_blank"
 | 
			
		||||
@ -302,12 +336,10 @@ export function CheckpointList({ data }: { data: CheckpointItemList[] }) {
 | 
			
		||||
                {headerGroup.headers.map((header) => {
 | 
			
		||||
                  return (
 | 
			
		||||
                    <TableHead key={header.id}>
 | 
			
		||||
                      {header.isPlaceholder
 | 
			
		||||
                        ? null
 | 
			
		||||
                        : flexRender(
 | 
			
		||||
                            header.column.columnDef.header,
 | 
			
		||||
                            header.getContext()
 | 
			
		||||
                          )}
 | 
			
		||||
                      {header.isPlaceholder ? null : flexRender(
 | 
			
		||||
                        header.column.columnDef.header,
 | 
			
		||||
                        header.getContext(),
 | 
			
		||||
                      )}
 | 
			
		||||
                    </TableHead>
 | 
			
		||||
                  );
 | 
			
		||||
                })}
 | 
			
		||||
@ -315,32 +347,34 @@ export function CheckpointList({ data }: { data: CheckpointItemList[] }) {
 | 
			
		||||
            ))}
 | 
			
		||||
          </TableHeader>
 | 
			
		||||
          <TableBody>
 | 
			
		||||
            {table.getRowModel().rows?.length ? (
 | 
			
		||||
              table.getRowModel().rows.map((row) => (
 | 
			
		||||
                <TableRow
 | 
			
		||||
                  key={row.id}
 | 
			
		||||
                  data-state={row.getIsSelected() && "selected"}
 | 
			
		||||
                >
 | 
			
		||||
                  {row.getVisibleCells().map((cell) => (
 | 
			
		||||
                    <TableCell key={cell.id}>
 | 
			
		||||
                      {flexRender(
 | 
			
		||||
                        cell.column.columnDef.cell,
 | 
			
		||||
                        cell.getContext()
 | 
			
		||||
                      )}
 | 
			
		||||
                    </TableCell>
 | 
			
		||||
                  ))}
 | 
			
		||||
            {table.getRowModel().rows?.length
 | 
			
		||||
              ? (
 | 
			
		||||
                table.getRowModel().rows.map((row) => (
 | 
			
		||||
                  <TableRow
 | 
			
		||||
                    key={row.id}
 | 
			
		||||
                    data-state={row.getIsSelected() && "selected"}
 | 
			
		||||
                  >
 | 
			
		||||
                    {row.getVisibleCells().map((cell) => (
 | 
			
		||||
                      <TableCell key={cell.id}>
 | 
			
		||||
                        {flexRender(
 | 
			
		||||
                          cell.column.columnDef.cell,
 | 
			
		||||
                          cell.getContext(),
 | 
			
		||||
                        )}
 | 
			
		||||
                      </TableCell>
 | 
			
		||||
                    ))}
 | 
			
		||||
                  </TableRow>
 | 
			
		||||
                ))
 | 
			
		||||
              )
 | 
			
		||||
              : (
 | 
			
		||||
                <TableRow>
 | 
			
		||||
                  <TableCell
 | 
			
		||||
                    colSpan={columns.length}
 | 
			
		||||
                    className="h-24 text-center"
 | 
			
		||||
                  >
 | 
			
		||||
                    No results.
 | 
			
		||||
                  </TableCell>
 | 
			
		||||
                </TableRow>
 | 
			
		||||
              ))
 | 
			
		||||
            ) : (
 | 
			
		||||
              <TableRow>
 | 
			
		||||
                <TableCell
 | 
			
		||||
                  colSpan={columns.length}
 | 
			
		||||
                  className="h-24 text-center"
 | 
			
		||||
                >
 | 
			
		||||
                  No results.
 | 
			
		||||
                </TableCell>
 | 
			
		||||
              </TableRow>
 | 
			
		||||
            )}
 | 
			
		||||
              )}
 | 
			
		||||
          </TableBody>
 | 
			
		||||
        </Table>
 | 
			
		||||
      </ScrollArea>
 | 
			
		||||
@ -7,7 +7,7 @@ import AutoFormInput from "../ui/auto-form/fields/input";
 | 
			
		||||
import { useDebouncedCallback } from "use-debounce";
 | 
			
		||||
import { CivitaiModelResponse } from "@/types/civitai";
 | 
			
		||||
import { z } from "zod";
 | 
			
		||||
import { insertCivitaiCheckpointSchema } from "@/db/schema";
 | 
			
		||||
import { insertCivitaiModelSchema } from "@/db/schema";
 | 
			
		||||
 | 
			
		||||
function getUrl(civitai_url: string) {
 | 
			
		||||
  // expect to be a URL to be https://civitai.com/models/36520
 | 
			
		||||
@ -33,7 +33,7 @@ export default function AutoFormCheckpointInput(
 | 
			
		||||
 | 
			
		||||
  const handleSearch = useDebouncedCallback((search) => {
 | 
			
		||||
    const validationResult =
 | 
			
		||||
      insertCivitaiCheckpointSchema.shape.civitai_url.safeParse(search);
 | 
			
		||||
      insertCivitaiModelSchema.shape.civitai_url.safeParse(search);
 | 
			
		||||
    if (!validationResult.success) {
 | 
			
		||||
      console.error(validationResult.error);
 | 
			
		||||
      // Optionally set an error state here
 | 
			
		||||
@ -12,7 +12,7 @@ import {
 | 
			
		||||
  real,
 | 
			
		||||
} from "drizzle-orm/pg-core";
 | 
			
		||||
import { createInsertSchema, createSelectSchema } from "drizzle-zod";
 | 
			
		||||
import { z } from "zod";
 | 
			
		||||
import { TypeOf, z } from "zod";
 | 
			
		||||
 | 
			
		||||
export const dbSchema = pgSchema("comfyui_deploy");
 | 
			
		||||
 | 
			
		||||
@ -376,15 +376,25 @@ export const modelUploadType = pgEnum("model_upload_type", [
 | 
			
		||||
  "other",
 | 
			
		||||
]);
 | 
			
		||||
 | 
			
		||||
export const checkpointTable = dbSchema.table("checkpoints", {
 | 
			
		||||
// https://www.answeroverflow.com/m/1125106227387584552 
 | 
			
		||||
const modelTypes  = [
 | 
			
		||||
  "checkpoint",
 | 
			
		||||
  "lora",
 | 
			
		||||
  "embedding",
 | 
			
		||||
  "vae",
 | 
			
		||||
] as const
 | 
			
		||||
export const modelType = pgEnum("model_type", modelTypes);
 | 
			
		||||
export type modelEnumType = typeof modelTypes[number]
 | 
			
		||||
 | 
			
		||||
export const modelTable = dbSchema.table("models", {
 | 
			
		||||
  id: uuid("id").primaryKey().defaultRandom().notNull(),
 | 
			
		||||
  user_id: text("user_id").references(() => usersTable.id, {}), // perhaps a "special" user_id for global checkpoints
 | 
			
		||||
  user_id: text("user_id").references(() => usersTable.id, {}), // perhaps a "special" user_id for global models
 | 
			
		||||
  org_id: text("org_id"),
 | 
			
		||||
  description: text("description"),
 | 
			
		||||
 | 
			
		||||
  checkpoint_volume_id: uuid("checkpoint_volume_id")
 | 
			
		||||
  user_volume_id: uuid("user_volume_id")
 | 
			
		||||
    .notNull()
 | 
			
		||||
    .references(() => checkpointVolumeTable.id, {
 | 
			
		||||
    .references(() => userVolume.id, {
 | 
			
		||||
      onDelete: "cascade",
 | 
			
		||||
    })
 | 
			
		||||
    .notNull(),
 | 
			
		||||
@ -408,12 +418,13 @@ export const checkpointTable = dbSchema.table("checkpoints", {
 | 
			
		||||
  status: resourceUpload("status").notNull().default("started"),
 | 
			
		||||
  upload_machine_id: text("upload_machine_id"), // TODO: review if actually needed
 | 
			
		||||
  upload_type: modelUploadType("upload_type").notNull(),
 | 
			
		||||
  model_type: modelType("model_type").notNull(),
 | 
			
		||||
  error_log: text("error_log"),
 | 
			
		||||
  created_at: timestamp("created_at").defaultNow().notNull(),
 | 
			
		||||
  updated_at: timestamp("updated_at").defaultNow().notNull(),
 | 
			
		||||
});
 | 
			
		||||
 | 
			
		||||
export const checkpointVolumeTable = dbSchema.table("checkpoint_volume", {
 | 
			
		||||
export const userVolume = dbSchema.table("user_volume", {
 | 
			
		||||
  id: uuid("id").primaryKey().defaultRandom().notNull(),
 | 
			
		||||
  user_id: text("user_id").references(() => usersTable.id, {
 | 
			
		||||
    // onDelete: "cascade",
 | 
			
		||||
@ -425,23 +436,23 @@ export const checkpointVolumeTable = dbSchema.table("checkpoint_volume", {
 | 
			
		||||
  disabled: boolean("disabled").default(false).notNull(),
 | 
			
		||||
});
 | 
			
		||||
 | 
			
		||||
export const checkpointRelations = relations(checkpointTable, ({ one }) => ({
 | 
			
		||||
export const modelRelations = relations(modelTable, ({ one }) => ({
 | 
			
		||||
  user: one(usersTable, {
 | 
			
		||||
    fields: [checkpointTable.user_id],
 | 
			
		||||
    fields: [modelTable.user_id],
 | 
			
		||||
    references: [usersTable.id],
 | 
			
		||||
  }),
 | 
			
		||||
  volume: one(checkpointVolumeTable, {
 | 
			
		||||
    fields: [checkpointTable.checkpoint_volume_id],
 | 
			
		||||
    references: [checkpointVolumeTable.id],
 | 
			
		||||
  volume: one(userVolume, {
 | 
			
		||||
    fields: [modelTable.user_volume_id],
 | 
			
		||||
    references: [userVolume.id],
 | 
			
		||||
  }),
 | 
			
		||||
}));
 | 
			
		||||
 | 
			
		||||
export const checkpointVolumeRelations = relations(
 | 
			
		||||
  checkpointVolumeTable,
 | 
			
		||||
export const modalVolumeRelations = relations(
 | 
			
		||||
  userVolume,
 | 
			
		||||
  ({ many, one }) => ({
 | 
			
		||||
    checkpoint: many(checkpointTable),
 | 
			
		||||
    model: many(modelTable),
 | 
			
		||||
    user: one(usersTable, {
 | 
			
		||||
      fields: [checkpointVolumeTable.user_id],
 | 
			
		||||
      fields: [userVolume.user_id],
 | 
			
		||||
      references: [usersTable.id],
 | 
			
		||||
    }),
 | 
			
		||||
  })
 | 
			
		||||
@ -473,8 +484,8 @@ export const subscriptionStatusTable = dbSchema.table("subscription_status", {
 | 
			
		||||
  updated_at: timestamp("updated_at").defaultNow().notNull(),
 | 
			
		||||
});
 | 
			
		||||
 | 
			
		||||
export const insertCivitaiCheckpointSchema = createInsertSchema(
 | 
			
		||||
  checkpointTable,
 | 
			
		||||
export const insertCivitaiModelSchema = createInsertSchema(
 | 
			
		||||
  modelTable,
 | 
			
		||||
  {
 | 
			
		||||
    civitai_url: (schema) =>
 | 
			
		||||
      schema.civitai_url
 | 
			
		||||
@ -491,8 +502,8 @@ export type WorkflowType = InferSelectModel<typeof workflowTable>;
 | 
			
		||||
export type MachineType = InferSelectModel<typeof machinesTable>;
 | 
			
		||||
export type WorkflowVersionType = InferSelectModel<typeof workflowVersionTable>;
 | 
			
		||||
export type DeploymentType = InferSelectModel<typeof deploymentsTable>;
 | 
			
		||||
export type CheckpointType = InferSelectModel<typeof checkpointTable>;
 | 
			
		||||
export type CheckpointVolumeType = InferSelectModel<
 | 
			
		||||
  typeof checkpointVolumeTable
 | 
			
		||||
export type ModelType = InferSelectModel<typeof modelTable>;
 | 
			
		||||
export type UserVolumeType = InferSelectModel<
 | 
			
		||||
  typeof userVolume
 | 
			
		||||
>;
 | 
			
		||||
export type UserUsageType = InferSelectModel<typeof userUsageTable>;
 | 
			
		||||
 | 
			
		||||
@ -1,5 +0,0 @@
 | 
			
		||||
import { insertCivitaiCheckpointSchema } from "@/db/schema";
 | 
			
		||||
 | 
			
		||||
export const addCivitaiCheckpointSchema = insertCivitaiCheckpointSchema.pick({
 | 
			
		||||
  civitai_url: true,
 | 
			
		||||
});
 | 
			
		||||
							
								
								
									
										5
									
								
								web/src/server/addCivitaiModelSchema.tsx
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								web/src/server/addCivitaiModelSchema.tsx
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,5 @@
 | 
			
		||||
import { insertCivitaiModelSchema } from "@/db/schema";
 | 
			
		||||
 | 
			
		||||
export const addCivitaiModelSchema = insertCivitaiModelSchema.pick({
 | 
			
		||||
  civitai_url: true,
 | 
			
		||||
});
 | 
			
		||||
@ -15,7 +15,7 @@ import { headers } from "next/headers";
 | 
			
		||||
import { redirect } from "next/navigation";
 | 
			
		||||
import "server-only";
 | 
			
		||||
import type { z } from "zod";
 | 
			
		||||
import { retrieveCheckpointVolumes } from "./curdCheckpoint";
 | 
			
		||||
import { retrieveModelVolumes } from "./curdModel";
 | 
			
		||||
 | 
			
		||||
export async function getMachines() {
 | 
			
		||||
  const { userId, orgId } = auth();
 | 
			
		||||
@ -190,7 +190,7 @@ async function _buildMachine(
 | 
			
		||||
    throw new Error("No domain");
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  const volumes = await retrieveCheckpointVolumes();
 | 
			
		||||
  const volumes = await retrieveModelVolumes();
 | 
			
		||||
  // Call remote builder
 | 
			
		||||
  const result = await fetch(`${process.env.MODAL_BUILDER_URL!}/create`, {
 | 
			
		||||
    method: "POST",
 | 
			
		||||
@ -204,7 +204,7 @@ async function _buildMachine(
 | 
			
		||||
      callback_url: `${protocol}://${domain}/api/machine-built`,
 | 
			
		||||
      models: data.models, //JSON.parse(data.models as string),
 | 
			
		||||
      gpu: data.gpu && data.gpu.length > 0 ? data.gpu : "T4",
 | 
			
		||||
      checkpoint_volume_name: volumes[0].volume_name,
 | 
			
		||||
      model_volume_name: volumes[0].volume_name,
 | 
			
		||||
    }),
 | 
			
		||||
  });
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -2,92 +2,91 @@
 | 
			
		||||
 | 
			
		||||
import { auth } from "@clerk/nextjs";
 | 
			
		||||
import {
 | 
			
		||||
  checkpointTable,
 | 
			
		||||
  CheckpointType,
 | 
			
		||||
  checkpointVolumeTable,
 | 
			
		||||
  CheckpointVolumeType,
 | 
			
		||||
  modelTable,
 | 
			
		||||
  ModelType,
 | 
			
		||||
  userVolume,
 | 
			
		||||
  UserVolumeType,
 | 
			
		||||
} from "@/db/schema";
 | 
			
		||||
import { withServerPromise } from "./withServerPromise";
 | 
			
		||||
import { db } from "@/db/db";
 | 
			
		||||
import type { z } from "zod";
 | 
			
		||||
import { headers } from "next/headers";
 | 
			
		||||
import { addCivitaiCheckpointSchema } from "./addCheckpointSchema";
 | 
			
		||||
import { addCivitaiModelSchema } from "./addCivitaiModelSchema";
 | 
			
		||||
import { and, eq, isNull } from "drizzle-orm";
 | 
			
		||||
import { CivitaiModelResponse } from "@/types/civitai";
 | 
			
		||||
import { CivitaiModelResponse, getModelTypeDetails } from "@/types/civitai";
 | 
			
		||||
 | 
			
		||||
export async function getCheckpoints() {
 | 
			
		||||
export async function getModel() {
 | 
			
		||||
  const { userId, orgId } = auth();
 | 
			
		||||
  if (!userId) throw new Error("No user id");
 | 
			
		||||
  const checkpoints = await db
 | 
			
		||||
  const models = await db
 | 
			
		||||
    .select()
 | 
			
		||||
    .from(checkpointTable)
 | 
			
		||||
    .from(modelTable)
 | 
			
		||||
    .where(
 | 
			
		||||
      orgId
 | 
			
		||||
        ? eq(checkpointTable.org_id, orgId)
 | 
			
		||||
        ? eq(modelTable.org_id, orgId)
 | 
			
		||||
        // make sure org_id is null
 | 
			
		||||
        : and(
 | 
			
		||||
          eq(checkpointTable.user_id, userId),
 | 
			
		||||
          isNull(checkpointTable.org_id),
 | 
			
		||||
          eq(modelTable.user_id, userId),
 | 
			
		||||
          isNull(modelTable.org_id),
 | 
			
		||||
        ),
 | 
			
		||||
    );
 | 
			
		||||
  return checkpoints;
 | 
			
		||||
  return models;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
export async function getCheckpointById(id: string) {
 | 
			
		||||
export async function getModelById(id: string) {
 | 
			
		||||
  const { userId, orgId } = auth();
 | 
			
		||||
  if (!userId) throw new Error("No user id");
 | 
			
		||||
  const checkpoint = await db
 | 
			
		||||
  const model = await db
 | 
			
		||||
    .select()
 | 
			
		||||
    .from(checkpointTable)
 | 
			
		||||
    .from(modelTable)
 | 
			
		||||
    .where(
 | 
			
		||||
      and(
 | 
			
		||||
        orgId ? eq(checkpointTable.org_id, orgId) : and(
 | 
			
		||||
          eq(checkpointTable.user_id, userId),
 | 
			
		||||
          isNull(checkpointTable.org_id),
 | 
			
		||||
        orgId ? eq(modelTable.org_id, orgId) : and(
 | 
			
		||||
          eq(modelTable.user_id, userId),
 | 
			
		||||
          isNull(modelTable.org_id),
 | 
			
		||||
        ),
 | 
			
		||||
        eq(checkpointTable.id, id),
 | 
			
		||||
        eq(modelTable.id, id),
 | 
			
		||||
      ),
 | 
			
		||||
    );
 | 
			
		||||
  return checkpoint[0];
 | 
			
		||||
  return model[0];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
export async function getCheckpointVolumes() {
 | 
			
		||||
export async function getModelVolumes() {
 | 
			
		||||
  const { userId, orgId } = auth();
 | 
			
		||||
  if (!userId) throw new Error("No user id");
 | 
			
		||||
  const volume = await db
 | 
			
		||||
    .select()
 | 
			
		||||
    .from(checkpointVolumeTable)
 | 
			
		||||
    .from(userVolume)
 | 
			
		||||
    .where(
 | 
			
		||||
      and(
 | 
			
		||||
        orgId
 | 
			
		||||
          ? eq(checkpointVolumeTable.org_id, orgId)
 | 
			
		||||
          ? eq(userVolume.org_id, orgId)
 | 
			
		||||
          // make sure org_id is null
 | 
			
		||||
          : and(
 | 
			
		||||
            eq(checkpointVolumeTable.user_id, userId),
 | 
			
		||||
            isNull(checkpointVolumeTable.org_id),
 | 
			
		||||
            eq(userVolume.user_id, userId),
 | 
			
		||||
            isNull(userVolume.org_id),
 | 
			
		||||
          ),
 | 
			
		||||
        eq(checkpointVolumeTable.disabled, false),
 | 
			
		||||
        eq(userVolume.disabled, false),
 | 
			
		||||
      ),
 | 
			
		||||
    );
 | 
			
		||||
  return volume;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
export async function retrieveCheckpointVolumes() {
 | 
			
		||||
  let volumes = await getCheckpointVolumes();
 | 
			
		||||
export async function retrieveModelVolumes() {
 | 
			
		||||
  let volumes = await getModelVolumes();
 | 
			
		||||
  if (volumes.length === 0) {
 | 
			
		||||
    // create volume if not already created
 | 
			
		||||
    volumes = await addCheckpointVolume();
 | 
			
		||||
    volumes = await addModelVolume();
 | 
			
		||||
  }
 | 
			
		||||
  return volumes;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
export async function addCheckpointVolume() {
 | 
			
		||||
export async function addModelVolume() {
 | 
			
		||||
  const { userId, orgId } = auth();
 | 
			
		||||
  if (!userId) throw new Error("No user id");
 | 
			
		||||
 | 
			
		||||
  // Insert the new checkpointVolume into the checkpointVolumeTable
 | 
			
		||||
  const insertedVolume = await db
 | 
			
		||||
    .insert(checkpointVolumeTable)
 | 
			
		||||
    .insert(userVolume)
 | 
			
		||||
    .values({
 | 
			
		||||
      user_id: userId,
 | 
			
		||||
      org_id: orgId,
 | 
			
		||||
@ -111,8 +110,8 @@ function getUrl(civitai_url: string) {
 | 
			
		||||
  return { url: baseUrl + modelId, modelVersionId };
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
export const addCivitaiCheckpoint = withServerPromise(
 | 
			
		||||
  async (data: z.infer<typeof addCivitaiCheckpointSchema>) => {
 | 
			
		||||
export const addCivitaiModel = withServerPromise(
 | 
			
		||||
  async (data: z.infer<typeof addCivitaiModelSchema>) => {
 | 
			
		||||
    const { userId, orgId } = auth();
 | 
			
		||||
 | 
			
		||||
    if (!data.civitai_url) return { error: "no civitai_url" };
 | 
			
		||||
@ -145,17 +144,22 @@ export const addCivitaiCheckpoint = withServerPromise(
 | 
			
		||||
      selectedModelVersionId = selectedModelVersion?.id.toString();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    const checkpointVolumes = await getCheckpointVolumes();
 | 
			
		||||
    const userVolume = await getModelVolumes();
 | 
			
		||||
    let cVolume;
 | 
			
		||||
    if (checkpointVolumes.length === 0) {
 | 
			
		||||
      const volume = await addCheckpointVolume();
 | 
			
		||||
    if (userVolume.length === 0) {
 | 
			
		||||
      const volume = await addModelVolume();
 | 
			
		||||
      cVolume = volume[0];
 | 
			
		||||
    } else {
 | 
			
		||||
      cVolume = checkpointVolumes[0];
 | 
			
		||||
      cVolume = userVolume[0];
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    const model_type = getModelTypeDetails(civitaiModelRes.type);
 | 
			
		||||
    if (!model_type) {
 | 
			
		||||
      return 
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    const a = await db
 | 
			
		||||
      .insert(checkpointTable)
 | 
			
		||||
      .insert(modelTable)
 | 
			
		||||
      .values({
 | 
			
		||||
        user_id: userId,
 | 
			
		||||
        org_id: orgId,
 | 
			
		||||
@ -166,15 +170,15 @@ export const addCivitaiCheckpoint = withServerPromise(
 | 
			
		||||
        civitai_url: data.civitai_url,
 | 
			
		||||
        civitai_download_url: selectedModelVersion.files[0].downloadUrl,
 | 
			
		||||
        civitai_model_response: civitaiModelRes,
 | 
			
		||||
        checkpoint_volume_id: cVolume.id,
 | 
			
		||||
        user_volume_id: cVolume.id,
 | 
			
		||||
        model_type, 
 | 
			
		||||
        updated_at: new Date(),
 | 
			
		||||
      })
 | 
			
		||||
      .returning();
 | 
			
		||||
 | 
			
		||||
    const b = a[0];
 | 
			
		||||
 | 
			
		||||
    await uploadCheckpoint(data, b, cVolume);
 | 
			
		||||
    // redirect(`/checkpoints/${b.id}`);
 | 
			
		||||
    await uploadModel(data, b, cVolume);
 | 
			
		||||
  },
 | 
			
		||||
);
 | 
			
		||||
 | 
			
		||||
@ -213,10 +217,10 @@ export const addCivitaiCheckpoint = withServerPromise(
 | 
			
		||||
//   },
 | 
			
		||||
// );
 | 
			
		||||
 | 
			
		||||
async function uploadCheckpoint(
 | 
			
		||||
  data: z.infer<typeof addCivitaiCheckpointSchema>,
 | 
			
		||||
  c: CheckpointType,
 | 
			
		||||
  v: CheckpointVolumeType,
 | 
			
		||||
async function uploadModel(
 | 
			
		||||
  data: z.infer<typeof addCivitaiModelSchema>,
 | 
			
		||||
  c: ModelType,
 | 
			
		||||
  v: UserVolumeType,
 | 
			
		||||
) {
 | 
			
		||||
  const headersList = headers();
 | 
			
		||||
 | 
			
		||||
@ -239,9 +243,9 @@ async function uploadCheckpoint(
 | 
			
		||||
        download_url: c.civitai_download_url,
 | 
			
		||||
        volume_name: v.volume_name,
 | 
			
		||||
        volume_id: v.id,
 | 
			
		||||
        checkpoint_id: c.id,
 | 
			
		||||
        model_id: c.id,
 | 
			
		||||
        callback_url: `${protocol}://${domain}/api/volume-upload`,
 | 
			
		||||
        upload_type: "checkpoint"
 | 
			
		||||
        upload_type: c.model_type,
 | 
			
		||||
      }),
 | 
			
		||||
    },
 | 
			
		||||
  );
 | 
			
		||||
@ -249,23 +253,23 @@ async function uploadCheckpoint(
 | 
			
		||||
  if (!result.ok) {
 | 
			
		||||
    const error_log = await result.text();
 | 
			
		||||
    await db
 | 
			
		||||
      .update(checkpointTable)
 | 
			
		||||
      .update(modelTable)
 | 
			
		||||
      .set({
 | 
			
		||||
        ...data,
 | 
			
		||||
        status: "failed",
 | 
			
		||||
        error_log: error_log,
 | 
			
		||||
      })
 | 
			
		||||
      .where(eq(checkpointTable.id, c.id));
 | 
			
		||||
      .where(eq(modelTable.id, c.id));
 | 
			
		||||
    throw new Error(`Error: ${result.statusText} ${error_log}`);
 | 
			
		||||
  } else {
 | 
			
		||||
    // setting the build machine id
 | 
			
		||||
    const json = await result.json();
 | 
			
		||||
    await db
 | 
			
		||||
      .update(checkpointTable)
 | 
			
		||||
      .update(modelTable)
 | 
			
		||||
      .set({
 | 
			
		||||
        ...data,
 | 
			
		||||
        upload_machine_id: json.build_machine_instance_id,
 | 
			
		||||
      })
 | 
			
		||||
      .where(eq(checkpointTable.id, c.id));
 | 
			
		||||
      .where(eq(modelTable.id, c.id));
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
@ -1,18 +1,18 @@
 | 
			
		||||
import { db } from "@/db/db";
 | 
			
		||||
import {
 | 
			
		||||
  checkpointTable,
 | 
			
		||||
  modelTable,
 | 
			
		||||
} from "@/db/schema";
 | 
			
		||||
import { auth } from "@clerk/nextjs";
 | 
			
		||||
import { and, desc, eq, isNull } from "drizzle-orm";
 | 
			
		||||
 | 
			
		||||
export async function getAllUserCheckpoints() {
 | 
			
		||||
export async function getAllUserModels() {
 | 
			
		||||
  const { userId, orgId } = await auth();
 | 
			
		||||
 | 
			
		||||
  if (!userId) {
 | 
			
		||||
    return null;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  const checkpoints = await db.query.checkpointTable.findMany({
 | 
			
		||||
  const models = await db.query.modelTable.findMany({
 | 
			
		||||
    with: {
 | 
			
		||||
      user: {
 | 
			
		||||
        columns: {
 | 
			
		||||
@ -28,14 +28,15 @@ export async function getAllUserCheckpoints() {
 | 
			
		||||
      civitai_model_response: true,
 | 
			
		||||
      is_public: true,
 | 
			
		||||
      upload_type: true,
 | 
			
		||||
      model_type: true,
 | 
			
		||||
      status: true,
 | 
			
		||||
    },
 | 
			
		||||
    orderBy: desc(checkpointTable.updated_at),
 | 
			
		||||
    orderBy: desc(modelTable.updated_at),
 | 
			
		||||
    where: 
 | 
			
		||||
      orgId != undefined
 | 
			
		||||
        ? eq(checkpointTable.org_id, orgId)
 | 
			
		||||
        : and(eq(checkpointTable.user_id, userId), isNull(checkpointTable.org_id)),
 | 
			
		||||
        ? eq(modelTable.org_id, orgId)
 | 
			
		||||
        : and(eq(modelTable.user_id, userId), isNull(modelTable.org_id)),
 | 
			
		||||
  });
 | 
			
		||||
 | 
			
		||||
  return checkpoints;
 | 
			
		||||
  return models;
 | 
			
		||||
}
 | 
			
		||||
@ -1,4 +1,5 @@
 | 
			
		||||
import { z } from "zod";
 | 
			
		||||
import { TypeOf, z } from "zod";
 | 
			
		||||
import { modelEnumType } from "@/db/schema";
 | 
			
		||||
 | 
			
		||||
// from chatgpt https://chat.openai.com/share/4985d20b-30b1-4a28-87f6-6ebf84a1040e
 | 
			
		||||
 | 
			
		||||
@ -110,12 +111,21 @@ export const statsSchema = z.object({
 | 
			
		||||
  tippedAmountCount: z.number(),
 | 
			
		||||
});
 | 
			
		||||
 | 
			
		||||
const civitaiModelType = z.enum([
 | 
			
		||||
  "Checkpoint",
 | 
			
		||||
  "TextualInversion",
 | 
			
		||||
  "Hypernetwork",
 | 
			
		||||
  "AestheticGradient",
 | 
			
		||||
  "LORA",
 | 
			
		||||
  "Controlnet",
 | 
			
		||||
  "Poses",
 | 
			
		||||
]);
 | 
			
		||||
 | 
			
		||||
export const CivitaiModelResponse = z.object({
 | 
			
		||||
  id: z.number(),
 | 
			
		||||
  name: z.string().nullish(),
 | 
			
		||||
  description: z.string().nullish(),
 | 
			
		||||
  // type: z.enum(["Checkpoint", "Lora"]), // TODO: this will be important to know
 | 
			
		||||
  type: z.string(),
 | 
			
		||||
  type: civitaiModelType,
 | 
			
		||||
  poi: z.boolean().nullish(),
 | 
			
		||||
  nsfw: z.boolean().nullish(),
 | 
			
		||||
  allowNoCredit: z.boolean().nullish(),
 | 
			
		||||
@ -127,3 +137,22 @@ export const CivitaiModelResponse = z.object({
 | 
			
		||||
  tags: z.array(z.string()).nullish(),
 | 
			
		||||
  modelVersions: z.array(modelVersionSchema),
 | 
			
		||||
});
 | 
			
		||||
 | 
			
		||||
export function getModelTypeDetails(
 | 
			
		||||
  modelType: typeof civitaiModelType["_type"],
 | 
			
		||||
): modelEnumType | undefined {
 | 
			
		||||
  switch (modelType) {
 | 
			
		||||
    case "Checkpoint":
 | 
			
		||||
      return "checkpoint"
 | 
			
		||||
    case "TextualInversion":
 | 
			
		||||
      return "embedding"
 | 
			
		||||
    case "LORA":
 | 
			
		||||
      return "lora"
 | 
			
		||||
    case "AestheticGradient":
 | 
			
		||||
    case "Hypernetwork":
 | 
			
		||||
    case "Controlnet":
 | 
			
		||||
    case "Poses":
 | 
			
		||||
    default:
 | 
			
		||||
      return undefined;
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user