nezhahq-agent/pkg/gpu/gpu_linux.go

126 lines
1.9 KiB
Go

//go:build linux
package gpu
import (
"errors"
"github.com/nezhahq/agent/pkg/gpu/vendor"
)
const (
vendorAMD = iota + 1
vendorNVIDIA
)
var vendorType uint8
func init() {
_, err := getNvidiaStat()
if err != nil {
vendorType = vendorAMD
} else {
vendorType = vendorNVIDIA
}
}
func getNvidiaStat() ([]float64, error) {
smi := &vendor.NvidiaSMI{
BinPath: "/usr/bin/nvidia-smi",
}
err1 := smi.Start()
if err1 != nil {
return nil, err1
}
data, err2 := smi.GatherUsage()
if err2 != nil {
return nil, err2
}
return data, nil
}
func getAMDStat() ([]float64, error) {
rsmi := &vendor.ROCmSMI{
BinPath: "/opt/rocm/bin/rocm-smi",
}
err := rsmi.Start()
if err != nil {
return nil, err
}
data, err := rsmi.GatherUsage()
if err != nil {
return nil, err
}
return data, nil
}
func getNvidiaHost() ([]string, error) {
smi := &vendor.NvidiaSMI{
BinPath: "/usr/bin/nvidia-smi",
}
err := smi.Start()
if err != nil {
return nil, err
}
data, err := smi.GatherModel()
if err != nil {
return nil, err
}
return data, nil
}
func getAMDHost() ([]string, error) {
rsmi := &vendor.ROCmSMI{
BinPath: "/opt/rocm/bin/rocm-smi",
}
err := rsmi.Start()
if err != nil {
return nil, err
}
data, err := rsmi.GatherModel()
if err != nil {
return nil, err
}
return data, nil
}
func GetGPUModel() ([]string, error) {
var gi []string
var err error
switch vendorType {
case vendorAMD:
gi, err = getAMDHost()
case vendorNVIDIA:
gi, err = getNvidiaHost()
default:
return nil, errors.New("invalid vendor")
}
if err != nil {
return nil, err
}
return gi, nil
}
func GetGPUStat() ([]float64, error) {
var gs []float64
var err error
switch vendorType {
case vendorAMD:
gs, err = getAMDStat()
case vendorNVIDIA:
gs, err = getNvidiaStat()
default:
return nil, errors.New("invalid vendor")
}
if err != nil {
return nil, err
}
return gs, nil
}