2024-08-21 22:22:06 +08:00

161 lines
3.6 KiB
Go

package fm
import (
"bytes"
"encoding/binary"
"errors"
"io"
"io/fs"
"os"
"os/user"
"path/filepath"
pb "github.com/nezhahq/agent/proto"
)
// Task serves file-manager requests (directory listing, download, upload)
// for a single Nezha IOStream.
type Task struct {
// taskClient is the stream replies are sent on; upload also receives the
// file contents from it via Recv.
taskClient pb.NezhaService_IOStreamClient
// printf is the caller-supplied logging callback.
printf func(string, ...interface{})
// remoteData is the request currently being handled. upload overwrites it
// with each chunk it receives. NOTE(review): download runs in a goroutine and
// reads this field, so it must not be mutated concurrently — confirm callers.
remoteData *pb.IOStreamData
}
// NewFMClient returns a file-manager Task that replies on client and logs
// through printFunc.
func NewFMClient(client pb.NezhaService_IOStreamClient, printFunc func(string, ...interface{})) *Task {
	fm := new(Task)
	fm.taskClient = client
	fm.printf = printFunc
	return fm
}
// DoTask dispatches a single file-manager request. The first byte of
// data.Data selects the operation:
//
//	0 — list a directory (listDir)
//	1 — download a file to the peer (download, run in its own goroutine)
//	2 — upload a file from the peer (upload)
//
// Unknown opcodes are ignored. An empty or nil request is answered with a
// "data is invalid" error instead of panicking.
func (t *Task) DoTask(data *pb.IOStreamData) {
	if data == nil || len(data.Data) == 0 {
		const msg = "data is invalid"
		t.printf(msg)
		t.taskClient.Send(&pb.IOStreamData{Data: CreateErr(errors.New(msg))})
		return
	}
	t.remoteData = data
	switch data.Data[0] {
	case 0:
		t.listDir()
	case 1:
		// download runs concurrently with later DoTask/upload calls, which
		// overwrite t.remoteData. Give the goroutine its own copy of the Task
		// so the request it reads cannot change underneath it.
		task := *t
		go task.download()
	case 2:
		t.upload()
	}
}
// listDir handles opcode 0. It reads the directory named by the request
// payload (t.remoteData.Data[1:]) and sends its listing, built with Create
// and AppendFileName, back over the stream.
//
// If the requested directory cannot be read, the current user's home
// directory (with a trailing separator) is tried instead, exactly once. If
// that also fails, or the current user cannot be determined, the error is
// sent to the peer via CreateErr.
func (t *Task) listDir() {
	dir := string(t.remoteData.Data[1:])
	var entries []fs.DirEntry
	var err error
	if entries, err = os.ReadDir(dir); err != nil {
		usr, uerr := user.Current()
		if uerr != nil {
			t.taskClient.Send(&pb.IOStreamData{Data: CreateErr(uerr)})
			return
		}
		// Single fallback attempt: an unreadable home directory must produce
		// an error reply, not an endless retry loop.
		dir = usr.HomeDir + string(filepath.Separator)
		if entries, err = os.ReadDir(dir); err != nil {
			t.taskClient.Send(&pb.IOStreamData{Data: CreateErr(err)})
			return
		}
	}
	var buffer bytes.Buffer
	td := Create(&buffer, dir)
	for _, e := range entries {
		td = AppendFileName(td, e.Name(), e.IsDir())
	}
	if err := t.taskClient.Send(&pb.IOStreamData{Data: td}); err != nil {
		t.printf("Error sending directory listing: %s", err)
	}
}
// download handles opcode 1. It streams the file named by the request payload
// (t.remoteData.Data[1:]) to the peer. First it sends a header built by
// CreateFile that carries the file size. Then it sends the raw contents in
// chunks of up to 1 MiB.
//
// Failures are logged and reported to the peer with CreateErr. Directories
// and empty files are rejected before the header is sent.
//
// NOTE(review): DoTask runs this in its own goroutine, so these Sends can
// overlap with Sends from other requests on the same stream. gRPC does not
// allow concurrent SendMsg calls on one stream — confirm requests are
// serialized by the caller.
func (t *Task) download() {
	path := string(t.remoteData.Data[1:])
	file, err := os.Open(path)
	if err != nil {
		t.printf("Error opening file: %s", err)
		t.taskClient.Send(&pb.IOStreamData{Data: CreateErr(err)})
		return
	}
	defer file.Close()
	fileInfo, err := file.Stat()
	if err != nil {
		t.printf("Error getting file info: %s", err)
		t.taskClient.Send(&pb.IOStreamData{Data: CreateErr(err)})
		return
	}
	// os.Open succeeds on directories, but reading one fails only after the
	// size header has already gone out. Reject it up front so the peer gets a
	// clean error instead of a truncated transfer.
	if fileInfo.IsDir() {
		t.taskClient.Send(&pb.IOStreamData{Data: CreateErr(errors.New("requested path is a directory"))})
		return
	}
	fileSize := fileInfo.Size()
	if fileSize <= 0 {
		t.taskClient.Send(&pb.IOStreamData{Data: CreateErr(errors.New("requested file is empty"))})
		return
	}
	// Send header (12 bytes)
	var header bytes.Buffer
	headerData := CreateFile(&header, uint64(fileSize))
	if err := t.taskClient.Send(&pb.IOStreamData{Data: headerData}); err != nil {
		t.printf("Error sending file header: %s", err)
		t.taskClient.Send(&pb.IOStreamData{Data: CreateErr(err)})
		return
	}
	// Never send more than the size announced in the header, even if the
	// file grows while it is being read.
	reader := io.LimitReader(file, fileSize)
	buffer := make([]byte, 1048576)
	for {
		n, err := reader.Read(buffer)
		// Forward any bytes read before looking at err, per the io.Reader
		// contract (a reader may return n > 0 together with an error).
		if n > 0 {
			if sendErr := t.taskClient.Send(&pb.IOStreamData{Data: buffer[:n]}); sendErr != nil {
				t.printf("Error sending file chunk: %s", sendErr)
				t.taskClient.Send(&pb.IOStreamData{Data: CreateErr(sendErr)})
				return
			}
		}
		if err == io.EOF {
			return
		}
		if err != nil {
			t.printf("Error reading file: %s", err)
			t.taskClient.Send(&pb.IOStreamData{Data: CreateErr(err)})
			return
		}
	}
}
// upload handles opcode 2. It writes a file sent by the peer.
//
// The request payload is laid out as:
//
//	Data[0]    opcode
//	Data[1:9]  file size, big-endian uint64
//	Data[9:]   destination path (created or truncated)
//
// File contents then arrive as further stream messages, read via Recv, until
// the announced size has been written. Bytes past that size are discarded.
// On success the file is closed and completeIdentifier is sent. Any error is
// logged and reported to the peer with CreateErr.
func (t *Task) upload() {
	if len(t.remoteData.Data) < 9 {
		const err string = "data is invalid"
		t.printf(err)
		t.taskClient.Send(&pb.IOStreamData{Data: CreateErr(errors.New(err))})
		return
	}
	fileSize := binary.BigEndian.Uint64(t.remoteData.Data[1:9])
	path := string(t.remoteData.Data[9:])
	file, err := os.Create(path)
	if err != nil {
		t.printf("Error creating file: %s", err)
		t.taskClient.Send(&pb.IOStreamData{Data: CreateErr(err)})
		return
	}
	// Covers the early-return error paths. On success the file is closed
	// explicitly below so a Close failure can be reported; the second Close
	// performed here then just returns an ignored error.
	defer file.Close()
	totalReceived := uint64(0)
	t.printf("receiving file: %s, size: %d", file.Name(), fileSize)
	for totalReceived < fileSize {
		if t.remoteData, err = t.taskClient.Recv(); err != nil {
			t.printf("Error receiving data: %s", err)
			t.taskClient.Send(&pb.IOStreamData{Data: CreateErr(err)})
			return
		}
		chunk := t.remoteData.Data
		// Do not let an overshooting final chunk grow the file beyond the
		// size announced in the request header.
		if remaining := fileSize - totalReceived; uint64(len(chunk)) > remaining {
			chunk = chunk[:remaining]
		}
		bytesWritten, err := file.Write(chunk)
		if err != nil {
			t.printf("Error writing to file: %s", err)
			t.taskClient.Send(&pb.IOStreamData{Data: CreateErr(err)})
			return
		}
		totalReceived += uint64(bytesWritten)
	}
	// A failed Close on a written file can mean lost data; report it
	// instead of signalling completion.
	if err := file.Close(); err != nil {
		t.printf("Error closing file: %s", err)
		t.taskClient.Send(&pb.IOStreamData{Data: CreateErr(err)})
		return
	}
	t.printf("received file %s.", file.Name())
	t.taskClient.Send(&pb.IOStreamData{Data: completeIdentifier}) // NZUP
}