diff --git a/cmd/agent/main.go b/cmd/agent/main.go index f6d310c..5a0c841 100644 --- a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -1,6 +1,7 @@ package main import ( + "bytes" "context" "crypto/tls" "errors" @@ -11,7 +12,6 @@ import ( "net/http" "net/url" "os" - "os/exec" "path/filepath" "runtime" "strings" @@ -622,8 +622,7 @@ func handleCommandTask(task *pb.Task, result *pb.TaskResult) { return } startedAt := time.Now() - var cmd *exec.Cmd - var endCh = make(chan struct{}) + endCh := make(chan struct{}) pg, err := processgroup.NewProcessExitGroup() if err != nil { // 进程组创建失败,直接退出 @@ -631,12 +630,14 @@ func handleCommandTask(task *pb.Task, result *pb.TaskResult) { return } timeout := time.NewTimer(time.Hour * 2) - if util.IsWindows() { - cmd = exec.Command("cmd", "/c", task.GetData()) // #nosec - } else { - cmd = exec.Command("sh", "-c", task.GetData()) // #nosec - } + cmd := processgroup.NewCommand(task.GetData()) + var b bytes.Buffer + cmd.Stdout = &b cmd.Env = os.Environ() + if err = cmd.Start(); err != nil { + result.Data = err.Error() + return + } pg.AddProcess(cmd) go func() { select { @@ -648,12 +649,11 @@ func handleCommandTask(task *pb.Task, result *pb.TaskResult) { timeout.Stop() } }() - output, err := cmd.Output() - if err != nil { - result.Data += fmt.Sprintf("%s\n%s", string(output), err.Error()) + if err = cmd.Wait(); err != nil { + result.Data += fmt.Sprintf("%s\n%s", b.String(), err.Error()) } else { close(endCh) - result.Data = string(output) + result.Data = b.String() result.Successful = true } pg.Dispose() diff --git a/pkg/processgroup/process_group.go b/pkg/processgroup/process_group.go index 2bda9bd..49ad65f 100644 --- a/pkg/processgroup/process_group.go +++ b/pkg/processgroup/process_group.go @@ -1,5 +1,4 @@ //go:build !windows -// +build !windows package processgroup @@ -17,7 +16,33 @@ func NewProcessExitGroup() (ProcessExitGroup, error) { return ProcessExitGroup{}, nil } -func (g *ProcessExitGroup) killChildProcess(c *exec.Cmd) error { +func NewCommand(arg string) *exec.Cmd { + cmd := exec.Command("sh", "-c", arg) + cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} + return cmd +} + +func (g *ProcessExitGroup) Dispose() error { + var wg sync.WaitGroup + wg.Add(len(g.cmds)) + + for _, c := range g.cmds { + go func(c *exec.Cmd) { + defer wg.Done() + killChildProcess(c) + }(c) + } + + wg.Wait() + return nil +} + +func (g *ProcessExitGroup) AddProcess(cmd *exec.Cmd) error { + g.cmds = append(g.cmds, cmd) + return nil +} + +func killChildProcess(c *exec.Cmd) { pgid, err := syscall.Getpgid(c.Process.Pid) if err != nil { // Fall-back on error. Kill the main process only. @@ -25,30 +50,5 @@ func (g *ProcessExitGroup) killChildProcess(c *exec.Cmd) error { } // Kill the whole process group. syscall.Kill(-pgid, syscall.SIGTERM) - return c.Wait() -} - -func (g *ProcessExitGroup) Dispose() []error { - var errors []error - mutex := new(sync.Mutex) - wg := new(sync.WaitGroup) - wg.Add(len(g.cmds)) - for _, c := range g.cmds { - go func(c *exec.Cmd) { - defer wg.Done() - if err := g.killChildProcess(c); err != nil { - mutex.Lock() - defer mutex.Unlock() - errors = append(errors, err) - } - }(c) - } - wg.Wait() - return errors -} - -func (g *ProcessExitGroup) AddProcess(cmd *exec.Cmd) error { - cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} - g.cmds = append(g.cmds, cmd) - return nil + c.Wait() } diff --git a/pkg/processgroup/process_group_windows.go b/pkg/processgroup/process_group_windows.go index 74f57b1..fb3c4c9 100644 --- a/pkg/processgroup/process_group_windows.go +++ b/pkg/processgroup/process_group_windows.go @@ -5,26 +5,80 @@ package processgroup import ( "fmt" "os/exec" + "unsafe" + + "golang.org/x/sys/windows" ) type ProcessExitGroup struct { - cmds []*exec.Cmd + cmds []*exec.Cmd + jobHandle windows.Handle + procs []windows.Handle } -func NewProcessExitGroup() (ProcessExitGroup, error) { - return ProcessExitGroup{}, nil -} - -func (g *ProcessExitGroup) Dispose() error { - for _, c := range g.cmds { - if err := exec.Command("taskkill", "/F", "/T", "/PID", fmt.Sprint(c.Process.Pid)).Run(); err != nil { - return err - } +func NewProcessExitGroup() (*ProcessExitGroup, error) { + job, err := windows.CreateJobObject(nil, nil) + if err != nil { + return nil, err } - return nil + + info := windows.JOBOBJECT_EXTENDED_LIMIT_INFORMATION{ + BasicLimitInformation: windows.JOBOBJECT_BASIC_LIMIT_INFORMATION{ + LimitFlags: windows.JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE, + }, + } + + _, err = windows.SetInformationJobObject( + job, + windows.JobObjectExtendedLimitInformation, + uintptr(unsafe.Pointer(&info)), + uint32(unsafe.Sizeof(info))) + + return &ProcessExitGroup{jobHandle: job}, nil +} + +func NewCommand(args string) *exec.Cmd { + cmd := exec.Command("cmd") + cmd.SysProcAttr = &windows.SysProcAttr{ + CmdLine: fmt.Sprintf("/c %s", args), + CreationFlags: windows.CREATE_NEW_PROCESS_GROUP, + } + return cmd } func (g *ProcessExitGroup) AddProcess(cmd *exec.Cmd) error { + proc, err := windows.OpenProcess(windows.PROCESS_TERMINATE|windows.PROCESS_SET_QUOTA|windows.PROCESS_SET_INFORMATION, false, uint32(cmd.Process.Pid)) + if err != nil { + return err + } + + g.procs = append(g.procs, proc) g.cmds = append(g.cmds, cmd) + + return windows.AssignProcessToJobObject(g.jobHandle, proc) +} + +func (g *ProcessExitGroup) Dispose() error { + defer func() { + windows.CloseHandle(g.jobHandle) + for _, proc := range g.procs { + windows.CloseHandle(proc) + } + }() + + if err := windows.TerminateJobObject(g.jobHandle, 1); err != nil { + // Fall-back on error. Kill the main process only. + for _, cmd := range g.cmds { + cmd.Process.Kill() + } + return err + } + + // wait for job to be terminated + status, err := windows.WaitForSingleObject(g.jobHandle, windows.INFINITE) + if status != windows.WAIT_OBJECT_0 { + return err + } + return nil } diff --git a/pkg/pty/pty.go b/pkg/pty/pty.go index 949a532..853b8a4 100644 --- a/pkg/pty/pty.go +++ b/pkg/pty/pty.go @@ -1,5 +1,4 @@ //go:build !windows -// +build !windows package pty