feat: add file transfer support (#55)

* feat: add file transfer support

* 1MB buffer
This commit is contained in:
UUBulb 2024-08-20 22:24:03 +08:00 committed by GitHub
parent 093275bc80
commit 73a727d435
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 342 additions and 54 deletions

View File

@ -31,6 +31,7 @@ import (
"google.golang.org/grpc/credentials/insecure"
"github.com/nezhahq/agent/model"
fm "github.com/nezhahq/agent/pkg/fm"
"github.com/nezhahq/agent/pkg/monitor"
"github.com/nezhahq/agent/pkg/processgroup"
"github.com/nezhahq/agent/pkg/pty"
@ -161,7 +162,7 @@ func init() {
func main() {
if err := agentCmd.Execute(); err != nil {
fmt.Println(err)
println(err)
os.Exit(1)
}
}
@ -215,7 +216,11 @@ func run() {
// 下载远程命令执行需要的终端
if !agentCliParam.DisableCommandExecute {
go pty.DownloadDependency()
go func() {
if err := pty.DownloadDependency(); err != nil {
printf("pty 下载依赖失败: %v", err)
}
}()
}
// 上报服务器信息
go reportStateDaemon()
@ -259,7 +264,7 @@ func run() {
}
conn, err = grpc.DialContext(timeOutCtx, agentCliParam.Server, securityOption, grpc.WithPerRPCCredentials(&auth))
if err != nil {
println("与面板建立连接失败:", err)
printf("与面板建立连接失败: %v", err)
cancel()
retry()
continue
@ -270,7 +275,7 @@ func run() {
timeOutCtx, cancel = context.WithTimeout(context.Background(), networkTimeOut)
_, err = client.ReportSystemInfo(timeOutCtx, monitor.GetHost().PB())
if err != nil {
println("上报系统信息失败:", err)
printf("上报系统信息失败: %v", err)
cancel()
retry()
continue
@ -280,12 +285,12 @@ func run() {
// 执行 Task
tasks, err := client.RequestTask(context.Background(), monitor.GetHost().PB())
if err != nil {
println("请求任务失败:", err)
printf("请求任务失败: %v", err)
retry()
continue
}
err = receiveTasks(tasks)
println("receiveTasks exit to main", err)
printf("receiveTasks exit to main: %v", err)
retry()
}
}
@ -293,7 +298,7 @@ func run() {
func runService(action string, flags []string) {
dir, err := os.Getwd()
if err != nil {
println("获取当前工作目录时出错: ", err)
printf("获取当前工作目录时出错: ", err)
return
}
@ -315,7 +320,7 @@ func runService(action string, flags []string) {
}
s, err := service.New(prg, svcConfig)
if err != nil {
log.Printf("创建服务时出错,以普通模式运行: %v", err)
printf("创建服务时出错,以普通模式运行: %v", err)
run()
return
}
@ -324,7 +329,7 @@ func runService(action string, flags []string) {
if agentConfig.Debug {
serviceLogger, err := s.Logger(nil)
if err != nil {
log.Printf("获取 service logger 时出错: %+v", err)
printf("获取 service logger 时出错: %+v", err)
} else {
util.Logger = serviceLogger
}
@ -332,7 +337,7 @@ func runService(action string, flags []string) {
if action == "install" {
initName := s.Platform()
log.Println("Init system is:", initName)
println("Init system is:", initName)
}
if len(action) != 0 {
@ -351,7 +356,7 @@ func runService(action string, flags []string) {
func receiveTasks(tasks pb.NezhaService_RequestTaskClient) error {
var err error
defer println("receiveTasks exit", time.Now(), "=>", err)
defer printf("receiveTasks exit %v => %v", time.Now(), err)
for {
var task *pb.Task
task, err = tasks.Recv()
@ -393,10 +398,13 @@ func doTask(task *pb.Task) {
case model.TaskTypeReportHostInfo:
reportState(time.Time{})
return
case model.TaskTypeFM:
handleFMTask(task)
return
case model.TaskTypeKeepalive:
return
default:
println("不支持的任务:", task)
printf("不支持的任务: %v", task)
return
}
client.ReportTask(context.Background(), &result)
@ -406,7 +414,7 @@ func doTask(task *pb.Task) {
func reportStateDaemon() {
var lastReportHostInfo time.Time
var err error
defer println("reportState exit", time.Now(), "=>", err)
defer printf("reportState exit %v => %v", time.Now(), err)
for {
// 为了更准确的记录时段流量inited 后再上传状态信息
lastReportHostInfo = reportState(lastReportHostInfo)
@ -421,7 +429,7 @@ func reportState(lastReportHostInfo time.Time) time.Time {
_, err := client.ReportSystemState(timeOutCtx, monitor.GetState(agentCliParam.SkipConnectionCount, agentCliParam.SkipProcsCount).PB())
cancel()
if err != nil {
println("reportState error", err)
printf("reportState error: %v", err)
time.Sleep(delayWhenError)
}
// 每10分钟重新获取一次硬件信息
@ -445,7 +453,7 @@ func doSelfUpdate(useLocalVersion bool) {
if useLocalVersion {
v = semver.MustParse(version)
}
println("检查更新:", v)
printf("检查更新: %v", v)
var latest *selfupdate.Release
var err error
if monitor.CachedCountryCode != "cn" && !agentCliParam.UseGiteeToUpgrade {
@ -454,11 +462,11 @@ func doSelfUpdate(useLocalVersion bool) {
latest, err = selfupdate.UpdateSelfGitee(v, "naibahq/agent")
}
if err != nil {
println("更新失败:", err)
printf("更新失败: %v", err)
return
}
if !latest.Version.Equals(v) {
println("已经更新至:", latest.Version, " 正在结束进程")
printf("已经更新至: %v, 正在结束进程", latest.Version)
os.Exit(1)
}
}
@ -668,13 +676,13 @@ func handleTerminalTask(task *pb.Task) {
var terminal model.TerminalTask
err := util.Json.Unmarshal([]byte(task.GetData()), &terminal)
if err != nil {
println("Terminal 任务解析错误:", err)
printf("Terminal 任务解析错误: %v", err)
return
}
remoteIO, err := client.IOStream(context.Background())
if err != nil {
println("Terminal IOStream失败", err)
printf("Terminal IOStream失败: %v", err)
return
}
@ -682,13 +690,13 @@ func handleTerminalTask(task *pb.Task) {
if err := remoteIO.Send(&pb.IOStreamData{Data: append([]byte{
0xff, 0x05, 0xff, 0x05,
}, []byte(terminal.StreamID)...)}); err != nil {
println("Terminal 发送StreamID失败", err)
printf("Terminal 发送StreamID失败: %v", err)
return
}
tty, err := pty.Start()
if err != nil {
println("Terminal pty.Start失败", err)
printf("Terminal pty.Start失败 %v", err)
return
}
@ -739,13 +747,13 @@ func handleNATTask(task *pb.Task) {
var nat model.TaskNAT
err := util.Json.Unmarshal([]byte(task.GetData()), &nat)
if err != nil {
println("NAT 任务解析错误:", err)
printf("NAT 任务解析错误: %v", err)
return
}
remoteIO, err := client.IOStream(context.Background())
if err != nil {
println("NAT IOStream失败", err)
printf("NAT IOStream失败: %v", err)
return
}
@ -753,13 +761,13 @@ func handleNATTask(task *pb.Task) {
if err := remoteIO.Send(&pb.IOStreamData{Data: append([]byte{
0xff, 0x05, 0xff, 0x05,
}, []byte(nat.StreamID)...)}); err != nil {
println("NAT 发送StreamID失败", err)
printf("NAT 发送StreamID失败: %v", err)
return
}
conn, err := net.Dial("tcp", nat.Host)
if err != nil {
println(fmt.Sprintf("NAT Dial %s 失败:%s", nat.Host, err))
printf("NAT Dial %s 失败:%s", nat.Host, err)
return
}
@ -792,10 +800,59 @@ func handleNATTask(task *pb.Task) {
}
}
func handleFMTask(task *pb.Task) {
if agentCliParam.DisableCommandExecute {
println("此 Agent 已禁止命令执行")
return
}
var fmTask model.TaskFM
err := util.Json.Unmarshal([]byte(task.GetData()), &fmTask)
if err != nil {
printf("FM 任务解析错误: %v", err)
return
}
remoteIO, err := client.IOStream(context.Background())
if err != nil {
printf("FM IOStream失败: %v", err)
return
}
// 发送 StreamID
if err := remoteIO.Send(&pb.IOStreamData{Data: append([]byte{
0xff, 0x05, 0xff, 0x05,
}, []byte(fmTask.StreamID)...)}); err != nil {
printf("FM 发送StreamID失败: %v", err)
return
}
defer func() {
errCloseSend := remoteIO.CloseSend()
println("FM exit", fmTask.StreamID, nil, errCloseSend)
}()
println("FM init", fmTask.StreamID)
fmc := fm.NewFMClient(remoteIO, printf)
for {
var remoteData *pb.IOStreamData
if remoteData, err = remoteIO.Recv(); err != nil {
return
}
if remoteData.Data == nil || len(remoteData.Data) == 0 {
return
}
fmc.DoTask(remoteData)
}
}
func println(v ...interface{}) {
util.Println(agentConfig.Debug, v...)
}
func printf(format string, v ...interface{}) {
util.Printf(agentConfig.Debug, format, v...)
}
func generateQueue(start int, size int) []int {
var result []int
for i := start; i < start+size; i++ {

View File

@ -12,6 +12,7 @@ const (
TaskTypeTerminalGRPC
TaskTypeNAT
TaskTypeReportHostInfo
TaskTypeFM
)
type TerminalTask struct {
@ -22,3 +23,7 @@ type TaskNAT struct {
StreamID string
Host string
}
type TaskFM struct {
StreamID string
}

65
pkg/fm/binary.go Normal file
View File

@ -0,0 +1,65 @@
package fm
import (
"bytes"
"encoding/binary"
)
var (
fileIdentifier = []byte{0x4E, 0x5A, 0x54, 0x44} // NZTD
fileNameIdentifier = []byte{0x4E, 0x5A, 0x46, 0x4E} // NZFN
errorIdentifier = []byte{0x4E, 0x45, 0x52, 0x52} // NERR
completeIdentifier = []byte{0x4E, 0x5A, 0x55, 0x50} // NZUP
)
func AppendFileName(bin []byte, data string, isDir bool) []byte {
buffer := bytes.NewBuffer(bin)
appendFileName(buffer, isDir, []byte(data))
return buffer.Bytes()
}
func Create(buffer *bytes.Buffer, path string) []byte {
// Write identifier for TypeFileName (4 bytes)
binary.Write(buffer, binary.BigEndian, fileNameIdentifier)
// Write length of path (4 byte)
binary.Write(buffer, binary.BigEndian, uint32(len(path)))
// Write path string
binary.Write(buffer, binary.BigEndian, []byte(path))
return buffer.Bytes()
}
func CreateFile(buffer *bytes.Buffer, size uint64) []byte {
// Write identifier for TypeFile (4 bytes)
binary.Write(buffer, binary.BigEndian, fileIdentifier)
// Write file size (8 bytes)
binary.Write(buffer, binary.BigEndian, size)
return buffer.Bytes()
}
func CreateErr(err error) []byte {
buffer := new(bytes.Buffer)
binary.Write(buffer, binary.BigEndian, errorIdentifier)
binary.Write(buffer, binary.BigEndian, []byte(err.Error()))
return buffer.Bytes()
}
func appendFileName(buffer *bytes.Buffer, isDir bool, data []byte) {
// Write file type (1 byte)
if isDir {
binary.Write(buffer, binary.BigEndian, byte(1))
} else {
binary.Write(buffer, binary.BigEndian, byte(0))
}
// Write the length of file name (1 byte)
length := byte(len(data))
binary.Write(buffer, binary.BigEndian, length)
// Write file name
buffer.Write(data)
}

158
pkg/fm/tasks.go Normal file
View File

@ -0,0 +1,158 @@
package fm
import (
"bytes"
"encoding/binary"
"errors"
"io"
"io/fs"
"os"
"os/user"
"path/filepath"
pb "github.com/nezhahq/agent/proto"
)
type Task struct {
taskClient pb.NezhaService_IOStreamClient
printf func(string, ...interface{})
remoteData *pb.IOStreamData
}
func NewFMClient(client pb.NezhaService_IOStreamClient, printFunc func(string, ...interface{})) *Task {
return &Task{
taskClient: client,
printf: printFunc,
}
}
func (t *Task) DoTask(data *pb.IOStreamData) {
t.remoteData = data
switch t.remoteData.Data[0] {
case 0:
t.listDir()
case 1:
go t.download()
case 2:
t.upload()
}
}
func (t *Task) listDir() {
dir := string(t.remoteData.Data[1:])
var entries []fs.DirEntry
var err error
for {
entries, err = os.ReadDir(dir)
if err != nil {
usr, err := user.Current()
if err != nil {
t.taskClient.Send(&pb.IOStreamData{Data: CreateErr(err)})
return
}
dir = usr.HomeDir + string(filepath.Separator)
continue
}
break
}
var buffer bytes.Buffer
td := Create(&buffer, dir)
for _, e := range entries {
newBin := AppendFileName(td, e.Name(), e.IsDir())
td = newBin
}
t.taskClient.Send(&pb.IOStreamData{Data: td})
}
func (t *Task) download() {
path := string(t.remoteData.Data[1:])
file, err := os.Open(path)
if err != nil {
println("Error opening file: ", err)
t.taskClient.Send(&pb.IOStreamData{Data: CreateErr(err)})
return
}
defer file.Close()
fileInfo, err := file.Stat()
if err != nil {
println("Error getting file info: ", err)
t.taskClient.Send(&pb.IOStreamData{Data: CreateErr(err)})
return
}
fileSize := fileInfo.Size()
if fileSize <= 0 {
t.taskClient.Send(&pb.IOStreamData{Data: CreateErr(errors.New("requested file is empty"))})
return
}
// Send header (12 bytes)
var header bytes.Buffer
headerData := CreateFile(&header, uint64(fileSize))
if err := t.taskClient.Send(&pb.IOStreamData{Data: headerData}); err != nil {
println("Error sending file header: ", err)
t.taskClient.Send(&pb.IOStreamData{Data: CreateErr(err)})
return
}
buffer := make([]byte, 1048576)
for {
n, err := file.Read(buffer)
if err != nil {
if err == io.EOF {
return
}
println("Error reading file: ", err)
t.taskClient.Send(&pb.IOStreamData{Data: CreateErr(err)})
return
}
if err := t.taskClient.Send(&pb.IOStreamData{Data: buffer[:n]}); err != nil {
println("Error sending file chunk: ", err)
t.taskClient.Send(&pb.IOStreamData{Data: CreateErr(err)})
return
}
}
}
func (t *Task) upload() {
if len(t.remoteData.Data) < 9 {
println("data is invalid")
return
}
fileSize := binary.BigEndian.Uint64(t.remoteData.Data[1:9])
path := string(t.remoteData.Data[9:])
file, err := os.Create(path)
if err != nil {
println("Error creating file: ", err)
t.taskClient.Send(&pb.IOStreamData{Data: CreateErr(err)})
return
}
defer file.Close()
totalReceived := uint64(0)
t.printf("receiving file: %s, size: %d", file.Name(), fileSize)
for totalReceived < fileSize {
if t.remoteData, err = t.taskClient.Recv(); err != nil {
println("Error receiving data: ", err)
t.taskClient.Send(&pb.IOStreamData{Data: CreateErr(err)})
return
}
bytesWritten, err := file.Write(t.remoteData.Data)
if err != nil {
println("Error writing to file: ", err)
t.taskClient.Send(&pb.IOStreamData{Data: CreateErr(err)})
return
}
totalReceived += uint64(bytesWritten)
}
t.printf("received file %s.", file.Name())
t.taskClient.Send(&pb.IOStreamData{Data: completeIdentifier}) // NZUP
}

View File

@ -79,7 +79,7 @@ func GetHost() *model.Host {
var cpuType string
hi, err := host.Info()
if err != nil {
println("host.Info error: ", err)
printf("host.Info error: %v", err)
} else {
if hi.VirtualizationRole == "guest" {
cpuType = "Virtual"
@ -99,7 +99,7 @@ func GetHost() *model.Host {
ci, err := cpu.Info()
if err != nil {
hostDataFetchAttempts["CPU"]++
println("cpu.Info error: ", err, ", attempt: ", hostDataFetchAttempts["CPU"])
printf("cpu.Info error: %v, attempt: %d", err, hostDataFetchAttempts["CPU"])
} else {
hostDataFetchAttempts["CPU"] = 0
for i := 0; i < len(ci); i++ {
@ -120,7 +120,7 @@ func GetHost() *model.Host {
ret.GPU, err = gpu.GetGPUModel()
if err != nil {
hostDataFetchAttempts["GPU"]++
println("gpu.GetGPUModel error: ", err, ", attempt: ", hostDataFetchAttempts["GPU"])
printf("gpu.GetGPUModel error: %v, attempt: %d", err, hostDataFetchAttempts["GPU"])
} else {
hostDataFetchAttempts["GPU"] = 0
}
@ -131,7 +131,7 @@ func GetHost() *model.Host {
mv, err := mem.VirtualMemory()
if err != nil {
println("mem.VirtualMemory error: ", err)
printf("mem.VirtualMemory error: %v", err)
} else {
ret.MemTotal = mv.Total
if runtime.GOOS != "windows" {
@ -142,7 +142,7 @@ func GetHost() *model.Host {
if runtime.GOOS == "windows" {
ms, err := mem.SwapMemory()
if err != nil {
println("mem.SwapMemory error: ", err)
printf("mem.SwapMemory error: %v", err)
} else {
ret.SwapTotal = ms.Total
}
@ -163,7 +163,7 @@ func GetState(skipConnectionCount bool, skipProcsCount bool) *model.HostState {
cp, err := cpu.Percent(0, false)
if err != nil || len(cp) == 0 {
statDataFetchAttempts["CPU"]++
println("cpu.Percent error: ", err, ", attempt: ", statDataFetchAttempts["CPU"])
printf("cpu.Percent error: %v, attempt: %d", err, statDataFetchAttempts["CPU"])
} else {
statDataFetchAttempts["CPU"] = 0
ret.CPU = cp[0]
@ -172,7 +172,7 @@ func GetState(skipConnectionCount bool, skipProcsCount bool) *model.HostState {
vm, err := mem.VirtualMemory()
if err != nil {
println("mem.VirtualMemory error: ", err)
printf("mem.VirtualMemory error: %v", err)
} else {
ret.MemUsed = vm.Total - vm.Available
if runtime.GOOS != "windows" {
@ -183,7 +183,7 @@ func GetState(skipConnectionCount bool, skipProcsCount bool) *model.HostState {
// gopsutil 在 Windows 下不能正确取 swap
ms, err := mem.SwapMemory()
if err != nil {
println("mem.SwapMemory error: ", err)
printf("mem.SwapMemory error: %v", err)
} else {
ret.SwapUsed = ms.Used
}
@ -195,7 +195,7 @@ func GetState(skipConnectionCount bool, skipProcsCount bool) *model.HostState {
loadStat, err := load.Avg()
if err != nil {
statDataFetchAttempts["Load"]++
println("load.Avg error: ", err, ", attempt: ", statDataFetchAttempts["Load"])
printf("load.Avg error: %v, attempt: %d", err, statDataFetchAttempts["Load"])
} else {
statDataFetchAttempts["Load"] = 0
ret.Load1 = loadStat.Load1
@ -208,7 +208,7 @@ func GetState(skipConnectionCount bool, skipProcsCount bool) *model.HostState {
if !skipProcsCount {
procs, err = process.Pids()
if err != nil {
println("process.Pids error: ", err)
printf("process.Pids error: %v", err)
} else {
ret.ProcessCount = uint64(len(procs))
}
@ -360,7 +360,7 @@ func updateGPUStat(gpuStat *uint64) {
gs, err := gpustat.GetGPUStat()
if err != nil {
statDataFetchAttempts["GPU"]++
println("gpustat.GetGPUStat error: ", err, ", attempt: ", statDataFetchAttempts["GPU"])
printf("gpustat.GetGPUStat error: %v, attempt: %d", err, statDataFetchAttempts["GPU"])
atomicStoreFloat64(gpuStat, gs)
} else {
statDataFetchAttempts["GPU"] = 0
@ -379,7 +379,7 @@ func updateTemperatureStat() {
temperatures, err := sensors.SensorsTemperatures()
if err != nil {
statDataFetchAttempts["Temperatures"]++
println("host.SensorsTemperatures error: ", err, ", attempt: ", statDataFetchAttempts["Temperatures"])
printf("host.SensorsTemperatures error: %v, attempt: %d", err, statDataFetchAttempts["Temperatures"])
} else {
statDataFetchAttempts["Temperatures"] = 0
tempStat := []model.SensorTemperature{}
@ -410,6 +410,6 @@ func atomicStoreFloat64(x *uint64, v float64) {
atomic.StoreUint64(x, math.Float64bits(v))
}
func println(v ...interface{}) {
util.Println(agentConfig.Debug, v...)
func printf(format string, v ...interface{}) {
util.Printf(agentConfig.Debug, format, v...)
}

View File

@ -19,7 +19,8 @@ type Pty struct {
cmd *exec.Cmd
}
func DownloadDependency() {
func DownloadDependency() error {
return nil
}
func Start() (IPty, error) {

View File

@ -5,7 +5,6 @@ package pty
import (
"fmt"
"io"
"log"
"net/http"
"os"
"os/exec"
@ -55,12 +54,11 @@ func VersionCheck() bool {
return false
}
func DownloadDependency() {
func DownloadDependency() error {
if !isWin10 {
executablePath, err := getExecutableFilePath()
if err != nil {
fmt.Println("NEZHA>> wintty 获取文件路径失败", err)
return
return fmt.Errorf("winpty 获取文件路径失败: %v", err)
}
winptyAgentExe := filepath.Join(executablePath, "winpty-agent.exe")
@ -69,27 +67,23 @@ func DownloadDependency() {
fe, errFe := os.Stat(winptyAgentExe)
fd, errFd := os.Stat(winptyAgentDll)
if errFe == nil && fe.Size() > 300000 && errFd == nil && fd.Size() > 300000 {
return
return fmt.Errorf("winpty 文件完整性检查失败")
}
resp, err := http.Get("https://github.com/rprichard/winpty/releases/download/0.4.3/winpty-0.4.3-msvc2015.zip")
if err != nil {
log.Println("NEZHA>> wintty 下载失败", err)
return
return fmt.Errorf("winpty 下载失败: %v", err)
}
defer resp.Body.Close()
content, err := io.ReadAll(resp.Body)
if err != nil {
log.Println("NEZHA>> wintty 下载失败", err)
return
return fmt.Errorf("winpty 下载失败: %v", err)
}
if err := os.WriteFile("./wintty.zip", content, os.FileMode(0777)); err != nil {
log.Println("NEZHA>> wintty 写入失败", err)
return
return fmt.Errorf("winpty 写入失败: %v", err)
}
if err := unzip.New("./wintty.zip", "./wintty").Extract(); err != nil {
fmt.Println("NEZHA>> wintty 解压失败", err)
return
return fmt.Errorf("winpty 解压失败: %v", err)
}
arch := "x64"
if runtime.GOARCH != "amd64" {
@ -101,6 +95,7 @@ func DownloadDependency() {
os.RemoveAll("./wintty")
os.RemoveAll("./wintty.zip")
}
return nil
}
func getExecutableFilePath() (string, error) {

View File

@ -14,7 +14,8 @@ type Pty struct {
tty *conpty.ConPty
}
func DownloadDependency() {
func DownloadDependency() error {
return nil
}
func getExecutableFilePath() (string, error) {

View File

@ -23,3 +23,9 @@ func Println(enabled bool, v ...interface{}) {
Logger.Infof("NEZHA@%s>> %v", time.Now().Format("2006-01-02 15:04:05"), fmt.Sprint(v...))
}
}
func Printf(enabled bool, format string, v ...interface{}) {
if enabled {
Logger.Infof("NEZHA@%s>> "+format, append([]interface{}{time.Now().Format("2006-01-02 15:04:05")}, v...)...)
}
}