feat: tcp/icmp ping 提前取 IP

This commit is contained in:
naiba 2024-07-28 15:05:30 +08:00
parent 7ad3094cee
commit 80ef923a38
6 changed files with 61 additions and 22 deletions

View File

@ -57,11 +57,11 @@ type AgentCliParam struct {
}
var (
version string
arch string
client pb.NezhaServiceClient
inited bool
geoip *pb.GeoIP
version string
arch string
client pb.NezhaServiceClient
inited bool
resolver = &net.Resolver{PreferGo: true}
)
var agentCmd = &cobra.Command{
@ -427,7 +427,13 @@ func reportState() {
if lastReportHostInfo.Before(time.Now().Add(-10 * time.Minute)) {
lastReportHostInfo = time.Now()
client.ReportSystemInfo(context.Background(), monitor.GetHost().PB())
geoip, _ = client.LookupGeoIP(context.Background(), &pb.GeoIP{Ip: monitor.QueryIP})
if monitor.GeoQueryIPChanged {
geoip, err := client.LookupGeoIP(context.Background(), &pb.GeoIP{Ip: monitor.GeoQueryIP})
if err == nil {
monitor.GeoQueryIPChanged = false
monitor.CachedCountryCode = geoip.GetCountryCode()
}
}
}
}
time.Sleep(time.Second * time.Duration(agentCliParam.ReportDelay))
@ -436,8 +442,7 @@ func reportState() {
// doSelfUpdate 执行更新检查 如果更新成功则会结束进程
func doSelfUpdate(useLocalVersion bool) {
code := geoip.GetCountryCode()
if code == "" {
if monitor.CachedCountryCode == "" {
return
}
v := semver.MustParse("0.1.0")
@ -447,7 +452,7 @@ func doSelfUpdate(useLocalVersion bool) {
println("检查更新:", v)
var latest *selfupdate.Release
var err error
if code != "cn" && !agentCliParam.UseGiteeToUpgrade {
if monitor.CachedCountryCode != "cn" && !agentCliParam.UseGiteeToUpgrade {
latest, err = selfupdate.UpdateSelf(v, "nezhahq/agent")
} else {
latest, err = selfupdate.UpdateSelfGitee(v, "naibahq/agent")
@ -470,8 +475,13 @@ func handleUpgradeTask(*pb.Task, *pb.TaskResult) {
}
func handleTcpPingTask(task *pb.Task, result *pb.TaskResult) {
ipAddr, err := lookupIP(task.GetData())
if err != nil {
result.Data = err.Error()
return
}
start := time.Now()
conn, err := net.DialTimeout("tcp", task.GetData(), time.Second*10)
conn, err := net.DialTimeout("tcp", ipAddr, time.Second*10)
if err == nil {
conn.Write([]byte("ping\n"))
conn.Close()
@ -483,7 +493,12 @@ func handleTcpPingTask(task *pb.Task, result *pb.TaskResult) {
}
func handleIcmpPingTask(task *pb.Task, result *pb.TaskResult) {
pinger, err := ping.NewPinger(task.GetData())
ipAddr, err := lookupIP(task.GetData())
if err != nil {
result.Data = err.Error()
return
}
pinger, err := ping.NewPinger(ipAddr)
if err == nil {
pinger.SetPrivileged(true)
pinger.Count = 5
@ -788,3 +803,17 @@ func generateQueue(start int, size int) []int {
}
return result
}
func lookupIP(hostOrIp string) (string, error) {
if net.ParseIP(hostOrIp) == nil {
ips, err := resolver.LookupIPAddr(context.Background(), hostOrIp)
if err != nil {
return "", err
}
if len(ips) == 0 {
return "", fmt.Errorf("无法解析 %s", hostOrIp)
}
return ips[0].IP.String(), nil
}
return hostOrIp, nil
}

View File

@ -1,6 +1,7 @@
package main
import (
"fmt"
"reflect"
"testing"
)
@ -23,3 +24,11 @@ func Test(t *testing.T) {
}
}
}
func TestLookupIP(t *testing.T) {
ip, err := lookupIP("www.google.com")
fmt.Printf("ip: %v, err: %v\n", ip, err)
if err != nil {
t.Errorf("lookupIP failed: %v", err)
}
}

2
go.mod
View File

@ -21,6 +21,7 @@ require (
github.com/shirou/gopsutil/v4 v4.24.6
github.com/spf13/cobra v1.8.1
github.com/spf13/viper v1.19.0
golang.org/x/sys v0.22.0
google.golang.org/grpc v1.64.1
google.golang.org/protobuf v1.34.2
sigs.k8s.io/yaml v1.4.0
@ -83,7 +84,6 @@ require (
golang.org/x/net v0.26.0 // indirect
golang.org/x/oauth2 v0.20.0 // indirect
golang.org/x/sync v0.7.0 // indirect
golang.org/x/sys v0.21.0 // indirect
golang.org/x/term v0.21.0 // indirect
golang.org/x/text v0.16.0 // indirect
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect

4
go.sum
View File

@ -228,8 +228,8 @@ golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws=
golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI=
golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=

View File

@ -153,9 +153,6 @@ func GetHost() *model.Host {
ret.IP = CachedIP
ret.Version = Version
// stub
ret.CountryCode = ""
return &ret
}

View File

@ -18,9 +18,10 @@ var (
"https://dash.cloudflare.com/cdn-cgi/trace",
"https://cf-ns.com/cdn-cgi/trace", // 有国内节点
}
CachedIP, QueryIP string
httpClientV4 = util.NewSingleStackHTTPClient(time.Second*20, time.Second*5, time.Second*10, false)
httpClientV6 = util.NewSingleStackHTTPClient(time.Second*20, time.Second*5, time.Second*10, true)
CachedIP, GeoQueryIP, CachedCountryCode string
GeoQueryIPChanged bool
httpClientV4 = util.NewSingleStackHTTPClient(time.Second*20, time.Second*5, time.Second*10, false)
httpClientV6 = util.NewSingleStackHTTPClient(time.Second*20, time.Second*5, time.Second*10, true)
)
// UpdateIP 按设置时间间隔更新IP地址的缓存
@ -44,11 +45,14 @@ func UpdateIP(useIPv6CountryCode bool, period uint32) {
CachedIP = fmt.Sprintf("%s/%s", ipv4, ipv6)
}
var newIP string
if !useIPv6CountryCode {
QueryIP = ipv4
newIP = ipv4
} else {
QueryIP = ipv6
newIP = ipv6
}
GeoQueryIPChanged = newIP != GeoQueryIP
GeoQueryIP = newIP
time.Sleep(time.Second * time.Duration(period))
}