nezhahq-agent/pkg/monitor/gpu/gpu_linux.go
UUBulb af41e4d843
modularize monitor, reduce init usage (#81)
* kill process tree using syscall on windows & cleanup (#80)

* kill process tree using syscall on windows & cleanup

* use job api

* add error check for cmd.Start

* modularize monitor, reduce init usage

* replace slices with sort

* update gopsutil & other dependencies
2024-11-03 21:53:09 +08:00

127 lines
1.9 KiB
Go

//go:build linux
package gpu
import (
"context"
"errors"
"github.com/nezhahq/agent/pkg/monitor/gpu/vendor"
)
// GPU vendor identifiers compared against vendorType. Numbering starts at 1,
// so the zero value of the vendor type matches neither constant.
const (
	vendorAMD    = 1 // dispatches to the ROCm SMI helpers
	vendorNVIDIA = 2 // dispatches to the NVIDIA SMI helpers
)
// vendorType is the vendor chosen by getVendor. As a package-level
// initializer it is evaluated once, when the package is initialized.
var vendorType = getVendor()

// getVendor decides which vendor backend to use by probing the NVIDIA
// tooling: it returns vendorNVIDIA when getNvidiaStat succeeds and falls
// back to vendorAMD on any error. The AMD tooling itself is not probed here.
func getVendor() uint8 {
	if _, probeErr := getNvidiaStat(); probeErr == nil {
		return vendorNVIDIA
	}
	return vendorAMD
}
// getNvidiaStat starts an NvidiaSMI helper configured with
// /usr/bin/nvidia-smi and returns the values reported by its GatherUsage
// call. An error from either Start or GatherUsage is returned unchanged with
// a nil slice.
func getNvidiaStat() ([]float64, error) {
	nv := &vendor.NvidiaSMI{BinPath: "/usr/bin/nvidia-smi"}
	if err := nv.Start(); err != nil {
		return nil, err
	}

	usage, err := nv.GatherUsage()
	if err != nil {
		return nil, err
	}
	return usage, nil
}
// getAMDStat starts a ROCmSMI helper configured with /opt/rocm/bin/rocm-smi
// and returns the values reported by its GatherUsage call. An error from
// either Start or GatherUsage is returned unchanged with a nil slice.
func getAMDStat() ([]float64, error) {
	rocm := &vendor.ROCmSMI{BinPath: "/opt/rocm/bin/rocm-smi"}
	if err := rocm.Start(); err != nil {
		return nil, err
	}

	usage, err := rocm.GatherUsage()
	if err != nil {
		return nil, err
	}
	return usage, nil
}
// getNvidiaHost starts an NvidiaSMI helper configured with
// /usr/bin/nvidia-smi and returns the strings reported by its GatherModel
// call. An error from either Start or GatherModel is returned unchanged with
// a nil slice.
func getNvidiaHost() ([]string, error) {
	nv := &vendor.NvidiaSMI{BinPath: "/usr/bin/nvidia-smi"}
	if err := nv.Start(); err != nil {
		return nil, err
	}

	models, err := nv.GatherModel()
	if err != nil {
		return nil, err
	}
	return models, nil
}
// getAMDHost starts a ROCmSMI helper configured with /opt/rocm/bin/rocm-smi
// and returns the strings reported by its GatherModel call. An error from
// either Start or GatherModel is returned unchanged with a nil slice.
func getAMDHost() ([]string, error) {
	rocm := &vendor.ROCmSMI{BinPath: "/opt/rocm/bin/rocm-smi"}
	if err := rocm.Start(); err != nil {
		return nil, err
	}

	models, err := rocm.GatherModel()
	if err != nil {
		return nil, err
	}
	return models, nil
}
// GetHost returns the GPU host information for the vendor recorded in
// vendorType, delegating to getAMDHost or getNvidiaHost. The context argument
// is accepted but not used. It returns an "invalid vendor" error when
// vendorType matches neither known vendor, and otherwise propagates any error
// from the vendor-specific query with a nil slice.
func GetHost(_ context.Context) ([]string, error) {
	// Pick the vendor-specific query first, then run it in one place.
	var query func() ([]string, error)
	switch vendorType {
	case vendorAMD:
		query = getAMDHost
	case vendorNVIDIA:
		query = getNvidiaHost
	default:
		return nil, errors.New("invalid vendor")
	}

	info, err := query()
	if err != nil {
		return nil, err
	}
	return info, nil
}
// GetState returns the GPU usage values for the vendor recorded in
// vendorType, delegating to getAMDStat or getNvidiaStat. The context argument
// is accepted but not used. It returns an "invalid vendor" error when
// vendorType matches neither known vendor, and otherwise propagates any error
// from the vendor-specific query with a nil slice.
func GetState(_ context.Context) ([]float64, error) {
	// Pick the vendor-specific query first, then run it in one place.
	var query func() ([]float64, error)
	switch vendorType {
	case vendorAMD:
		query = getAMDStat
	case vendorNVIDIA:
		query = getNvidiaStat
	default:
		return nil, errors.New("invalid vendor")
	}

	usage, err := query()
	if err != nil {
		return nil, err
	}
	return usage, nil
}