diff --git a/web/src/app/(app)/machines/page.tsx b/web/src/app/(app)/machines/page.tsx
index c4c3de2..bd5f5c7 100644
--- a/web/src/app/(app)/machines/page.tsx
+++ b/web/src/app/(app)/machines/page.tsx
@@ -1,5 +1,6 @@
import { AccessType } from "../../../lib/AccessType";
import { MachineList } from "@/components/MachineList";
+import { SubscriptionProvider } from "@/components/useCurrentPlan";
import { db } from "@/db/db";
import { machinesTable } from "@/db/schema";
import { getCurrentPlanWithAuth } from "@/server/getCurrentPlan";
@@ -32,12 +33,12 @@ async function MachineListServer() {
return (
- {/*
Machines
*/}
-
+
+
+
);
}
diff --git a/web/src/components/CurrentPlanContextType.tsx b/web/src/components/CurrentPlanContextType.tsx
new file mode 100644
index 0000000..1e513a2
--- /dev/null
+++ b/web/src/components/CurrentPlanContextType.tsx
@@ -0,0 +1,24 @@
+"use client";
+import { getCurrentPlanWithAuth } from "@/server/getCurrentPlan";
+import * as React from "react";
+import { createContext } from "react";
+
+export type CurrentPlanContextType = Awaited<
+ ReturnType
+>;
+export const CurrentPlanContext = createContext<
+ CurrentPlanContextType | undefined
+>(undefined);
+export function SubscriptionProvider({
+ sub,
+ children,
+}: {
+ sub: CurrentPlanContextType;
+ children: React.ReactNode;
+}) {
+ return (
+
+ {children}
+
+ );
+}
diff --git a/web/src/components/MachineList.tsx b/web/src/components/MachineList.tsx
index e6dde79..4e8ca1d 100644
--- a/web/src/components/MachineList.tsx
+++ b/web/src/components/MachineList.tsx
@@ -183,6 +183,7 @@ export const columns: ColumnDef[] = [
cell: ({ row }) => {
const machine = row.original;
const [open, setOpen] = useState(false);
+ const sub = useCurrentPlan();
return (
@@ -291,7 +292,10 @@ export const columns: ColumnDef[] = [
fieldType: "models",
},
gpu: {
- inputProps: {},
+ fieldType: "gpuPicker",
+ inputProps: {
+ sub: sub,
+ },
},
}}
/>
@@ -319,14 +323,14 @@ export const columns: ColumnDef[] = [
},
];
+import { useCurrentPlan } from "./useCurrentPlan";
+
export function MachineList({
data,
userMetadata,
- sub,
}: {
data: Machine[];
userMetadata: z.infer;
- sub: Awaited>;
}) {
const [sorting, setSorting] = React.useState([]);
const [columnFilters, setColumnFilters] = React.useState(
@@ -336,6 +340,8 @@ export function MachineList({
React.useState({});
const [rowSelection, setRowSelection] = React.useState({});
+ const sub = useCurrentPlan();
+
const table = useReactTable({
data,
columns,
@@ -420,11 +426,9 @@ export function MachineList({
},
},
gpu: {
- fieldType: !userMetadata.betaFeaturesAccess
- ? "fallback"
- : "select",
+ fieldType: "gpuPicker",
inputProps: {
- disabled: !userMetadata.betaFeaturesAccess,
+ sub: sub,
},
},
}}
diff --git a/web/src/components/custom-form/gpu-picker.tsx b/web/src/components/custom-form/gpu-picker.tsx
new file mode 100644
index 0000000..9218b3e
--- /dev/null
+++ b/web/src/components/custom-form/gpu-picker.tsx
@@ -0,0 +1,90 @@
+import { AutoFormInputComponentProps } from "@/components/ui/auto-form/types";
+import { getBaseSchema } from "@/components/ui/auto-form/utils";
+import {
+ FormItem,
+ FormLabel,
+ FormControl,
+ FormDescription,
+ FormMessage,
+} from "@/components/ui/form";
+import {
+ Select,
+ SelectContent,
+ SelectItem,
+ SelectTrigger,
+ SelectValue,
+} from "@/components/ui/select";
+import { Lock } from "lucide-react";
+import * as z from "zod";
+
+export default function AutoFormGPUPicker({
+ label,
+ isRequired,
+ field,
+ fieldConfigItem,
+ zodItem,
+}: AutoFormInputComponentProps) {
+ const baseValues = (getBaseSchema(zodItem) as unknown as z.ZodEnum)._def
+ .values;
+
+ let values: [string, string][] = [];
+ if (!Array.isArray(baseValues)) {
+ values = Object.entries(baseValues);
+ } else {
+ values = baseValues.map((value) => [value, value]);
+ }
+
+ function findItem(value: any) {
+ return values.find((item) => item[0] === value);
+ }
+
+ const plan = fieldConfigItem.inputProps?.sub?.plan;
+ const enabledGPU = ["T4"];
+
+ if (plan == "pro") {
+ enabledGPU.push("A10G");
+ } else if (plan == "enterprise") {
+ enabledGPU.push("A10G");
+ enabledGPU.push("A100");
+ }
+
+ return (
+
+
+ {label}
+ {isRequired && *}
+
+
+
+
+ {fieldConfigItem.description && (
+ {fieldConfigItem.description}
+ )}
+
+
+ );
+}
diff --git a/web/src/components/ui/auto-form/config.ts b/web/src/components/ui/auto-form/config.ts
index 4fdb9ea..05afc22 100644
--- a/web/src/components/ui/auto-form/config.ts
+++ b/web/src/components/ui/auto-form/config.ts
@@ -1,3 +1,4 @@
+import AutoFormGPUPicker from "@/components/custom-form/gpu-picker";
import AutoFormCheckbox from "./fields/checkbox";
import AutoFormDate from "./fields/date";
import AutoFormEnum from "./fields/enum";
@@ -22,6 +23,7 @@ export const INPUT_COMPONENTS = {
// Customs
snapshot: AutoFormSnapshotPicker,
models: AutoFormModelsPicker,
+ gpuPicker: AutoFormGPUPicker,
};
/**
diff --git a/web/src/components/ui/auto-form/types.ts b/web/src/components/ui/auto-form/types.ts
index 50d55bc..3514fda 100644
--- a/web/src/components/ui/auto-form/types.ts
+++ b/web/src/components/ui/auto-form/types.ts
@@ -1,3 +1,4 @@
+import type { getCurrentPlanWithAuth } from "@/server/getCurrentPlan";
import type { INPUT_COMPONENTS } from "./config";
import type { ControllerRenderProps, FieldValues } from "react-hook-form";
import type * as z from "zod";
@@ -6,6 +7,7 @@ export type FieldConfigItem = {
description?: React.ReactNode;
inputProps?: React.InputHTMLAttributes & {
showLabel?: boolean;
+ sub?: Awaited>;
};
fieldType?:
| keyof typeof INPUT_COMPONENTS
diff --git a/web/src/components/useCurrentPlan.tsx b/web/src/components/useCurrentPlan.tsx
new file mode 100644
index 0000000..8cda7d0
--- /dev/null
+++ b/web/src/components/useCurrentPlan.tsx
@@ -0,0 +1,36 @@
+"use client";
+
+import { getCurrentPlanWithAuth } from "@/server/getCurrentPlan";
+import * as React from "react";
+import { createContext, useContext } from "react";
+
+type CurrentPlanContextType = Awaited<
+ ReturnType
+>;
+const CurrentPlanContext = createContext(
+ undefined,
+);
+
+export function SubscriptionProvider({
+ sub,
+ children,
+}: {
+ sub: CurrentPlanContextType;
+ children: React.ReactNode;
+}) {
+ return (
+
+ {children}
+
+ );
+}
+
+export const useCurrentPlan = (): CurrentPlanContextType => {
+ const context = useContext(CurrentPlanContext);
+
+ // if (context === undefined) {
+ // throw new Error("useCurrentPlan must be used within a CurrentPlanProvider");
+ // }
+
+ return context;
+};
diff --git a/web/src/server/getCurrentPlan.tsx b/web/src/server/getCurrentPlan.tsx
index 738c91f..06a309c 100644
--- a/web/src/server/getCurrentPlan.tsx
+++ b/web/src/server/getCurrentPlan.tsx
@@ -3,6 +3,7 @@ import { and, desc, eq, isNull, or } from "drizzle-orm";
import { subscriptionStatusTable } from "@/db/schema";
import { APIKeyUserType } from "@/server/APIKeyBodyRequest";
import { auth } from "@clerk/nextjs";
+import "server-only";
export async function getCurrentPlanWithAuth() {
const { userId, orgId } = auth();
@@ -23,7 +24,10 @@ export async function getCurrentPlan({ user_id, org_id }: APIKeyUserType) {
eq(subscriptionStatusTable.user_id, user_id),
org_id
? eq(subscriptionStatusTable.org_id, org_id)
- : or(isNull(subscriptionStatusTable.org_id), eq(subscriptionStatusTable.org_id, "")),
+ : or(
+ isNull(subscriptionStatusTable.org_id),
+ eq(subscriptionStatusTable.org_id, ""),
+ ),
),
orderBy: desc(subscriptionStatusTable.created_at),
});