diff --git a/web/src/app/(app)/machines/page.tsx b/web/src/app/(app)/machines/page.tsx index c4c3de2..bd5f5c7 100644 --- a/web/src/app/(app)/machines/page.tsx +++ b/web/src/app/(app)/machines/page.tsx @@ -1,5 +1,6 @@ import { AccessType } from "../../../lib/AccessType"; import { MachineList } from "@/components/MachineList"; +import { SubscriptionProvider } from "@/components/useCurrentPlan"; import { db } from "@/db/db"; import { machinesTable } from "@/db/schema"; import { getCurrentPlanWithAuth } from "@/server/getCurrentPlan"; @@ -32,12 +33,12 @@ async function MachineListServer() { return (
- {/*
Machines
*/} - + + +
); } diff --git a/web/src/components/CurrentPlanContextType.tsx b/web/src/components/CurrentPlanContextType.tsx new file mode 100644 index 0000000..1e513a2 --- /dev/null +++ b/web/src/components/CurrentPlanContextType.tsx @@ -0,0 +1,24 @@ +"use client"; +import { getCurrentPlanWithAuth } from "@/server/getCurrentPlan"; +import * as React from "react"; +import { createContext } from "react"; + +export type CurrentPlanContextType = Awaited< + ReturnType +>; +export const CurrentPlanContext = createContext< + CurrentPlanContextType | undefined +>(undefined); +export function SubscriptionProvider({ + sub, + children, +}: { + sub: CurrentPlanContextType; + children: React.ReactNode; +}) { + return ( + + {children} + + ); +} diff --git a/web/src/components/MachineList.tsx b/web/src/components/MachineList.tsx index e6dde79..4e8ca1d 100644 --- a/web/src/components/MachineList.tsx +++ b/web/src/components/MachineList.tsx @@ -183,6 +183,7 @@ export const columns: ColumnDef[] = [ cell: ({ row }) => { const machine = row.original; const [open, setOpen] = useState(false); + const sub = useCurrentPlan(); return ( @@ -291,7 +292,10 @@ export const columns: ColumnDef[] = [ fieldType: "models", }, gpu: { - inputProps: {}, + fieldType: "gpuPicker", + inputProps: { + sub: sub, + }, }, }} /> @@ -319,14 +323,14 @@ export const columns: ColumnDef[] = [ }, ]; +import { useCurrentPlan } from "./useCurrentPlan"; + export function MachineList({ data, userMetadata, - sub, }: { data: Machine[]; userMetadata: z.infer; - sub: Awaited>; }) { const [sorting, setSorting] = React.useState([]); const [columnFilters, setColumnFilters] = React.useState( @@ -336,6 +340,8 @@ export function MachineList({ React.useState({}); const [rowSelection, setRowSelection] = React.useState({}); + const sub = useCurrentPlan(); + const table = useReactTable({ data, columns, @@ -420,11 +426,9 @@ export function MachineList({ }, }, gpu: { - fieldType: !userMetadata.betaFeaturesAccess - ? "fallback" - : "select", + fieldType: "gpuPicker", inputProps: { - disabled: !userMetadata.betaFeaturesAccess, + sub: sub, }, }, }} diff --git a/web/src/components/custom-form/gpu-picker.tsx b/web/src/components/custom-form/gpu-picker.tsx new file mode 100644 index 0000000..9218b3e --- /dev/null +++ b/web/src/components/custom-form/gpu-picker.tsx @@ -0,0 +1,90 @@ +import { AutoFormInputComponentProps } from "@/components/ui/auto-form/types"; +import { getBaseSchema } from "@/components/ui/auto-form/utils"; +import { + FormItem, + FormLabel, + FormControl, + FormDescription, + FormMessage, +} from "@/components/ui/form"; +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from "@/components/ui/select"; +import { Lock } from "lucide-react"; +import * as z from "zod"; + +export default function AutoFormGPUPicker({ + label, + isRequired, + field, + fieldConfigItem, + zodItem, +}: AutoFormInputComponentProps) { + const baseValues = (getBaseSchema(zodItem) as unknown as z.ZodEnum)._def + .values; + + let values: [string, string][] = []; + if (!Array.isArray(baseValues)) { + values = Object.entries(baseValues); + } else { + values = baseValues.map((value) => [value, value]); + } + + function findItem(value: any) { + return values.find((item) => item[0] === value); + } + + const plan = fieldConfigItem.inputProps?.sub?.plan; + const enabledGPU = ["T4"]; + + if (plan == "pro") { + enabledGPU.push("A10G"); + } else if (plan == "enterprise") { + enabledGPU.push("A10G"); + enabledGPU.push("A100"); + } + + return ( + + + {label} + {isRequired && *} + + + + + {fieldConfigItem.description && ( + {fieldConfigItem.description} + )} + + + ); +} diff --git a/web/src/components/ui/auto-form/config.ts b/web/src/components/ui/auto-form/config.ts index 4fdb9ea..05afc22 100644 --- a/web/src/components/ui/auto-form/config.ts +++ b/web/src/components/ui/auto-form/config.ts @@ -1,3 +1,4 @@ +import AutoFormGPUPicker from "@/components/custom-form/gpu-picker"; import AutoFormCheckbox from "./fields/checkbox"; import AutoFormDate from "./fields/date"; import AutoFormEnum from "./fields/enum"; @@ -22,6 +23,7 @@ export const INPUT_COMPONENTS = { // Customs snapshot: AutoFormSnapshotPicker, models: AutoFormModelsPicker, + gpuPicker: AutoFormGPUPicker, }; /** diff --git a/web/src/components/ui/auto-form/types.ts b/web/src/components/ui/auto-form/types.ts index 50d55bc..3514fda 100644 --- a/web/src/components/ui/auto-form/types.ts +++ b/web/src/components/ui/auto-form/types.ts @@ -1,3 +1,4 @@ +import type { getCurrentPlanWithAuth } from "@/server/getCurrentPlan"; import type { INPUT_COMPONENTS } from "./config"; import type { ControllerRenderProps, FieldValues } from "react-hook-form"; import type * as z from "zod"; @@ -6,6 +7,7 @@ export type FieldConfigItem = { description?: React.ReactNode; inputProps?: React.InputHTMLAttributes & { showLabel?: boolean; + sub?: Awaited>; }; fieldType?: | keyof typeof INPUT_COMPONENTS diff --git a/web/src/components/useCurrentPlan.tsx b/web/src/components/useCurrentPlan.tsx new file mode 100644 index 0000000..8cda7d0 --- /dev/null +++ b/web/src/components/useCurrentPlan.tsx @@ -0,0 +1,36 @@ +"use client"; + +import { getCurrentPlanWithAuth } from "@/server/getCurrentPlan"; +import * as React from "react"; +import { createContext, useContext } from "react"; + +type CurrentPlanContextType = Awaited< + ReturnType +>; +const CurrentPlanContext = createContext( + undefined, +); + +export function SubscriptionProvider({ + sub, + children, +}: { + sub: CurrentPlanContextType; + children: React.ReactNode; +}) { + return ( + + {children} + + ); +} + +export const useCurrentPlan = (): CurrentPlanContextType => { + const context = useContext(CurrentPlanContext); + + // if (context === undefined) { + // throw new Error("useCurrentPlan must be used within a CurrentPlanProvider"); + // } + + return context; +}; diff --git a/web/src/server/getCurrentPlan.tsx b/web/src/server/getCurrentPlan.tsx index 738c91f..06a309c 100644 --- a/web/src/server/getCurrentPlan.tsx +++ b/web/src/server/getCurrentPlan.tsx @@ -3,6 +3,7 @@ import { and, desc, eq, isNull, or } from "drizzle-orm"; import { subscriptionStatusTable } from "@/db/schema"; import { APIKeyUserType } from "@/server/APIKeyBodyRequest"; import { auth } from "@clerk/nextjs"; +import "server-only"; export async function getCurrentPlanWithAuth() { const { userId, orgId } = auth(); @@ -23,7 +24,10 @@ export async function getCurrentPlan({ user_id, org_id }: APIKeyUserType) { eq(subscriptionStatusTable.user_id, user_id), org_id ? eq(subscriptionStatusTable.org_id, org_id) - : or(isNull(subscriptionStatusTable.org_id), eq(subscriptionStatusTable.org_id, "")), + : or( + isNull(subscriptionStatusTable.org_id), + eq(subscriptionStatusTable.org_id, ""), + ), ), orderBy: desc(subscriptionStatusTable.created_at), });