refactor nat
This commit is contained in:
		
							parent
							
								
									4635bcf44f
								
							
						
					
					
						commit
						c9ec634857
					
				@ -6,20 +6,15 @@ import (
 | 
				
			|||||||
	"log"
 | 
						"log"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
	"time"
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	jwt "github.com/appleboy/gin-jwt/v2"
 | 
						jwt "github.com/appleboy/gin-jwt/v2"
 | 
				
			||||||
	"github.com/gin-contrib/pprof"
 | 
						"github.com/gin-contrib/pprof"
 | 
				
			||||||
	"github.com/gin-gonic/gin"
 | 
						"github.com/gin-gonic/gin"
 | 
				
			||||||
	"github.com/hashicorp/go-uuid"
 | 
					 | 
				
			||||||
	swaggerfiles "github.com/swaggo/files"
 | 
						swaggerfiles "github.com/swaggo/files"
 | 
				
			||||||
	ginSwagger "github.com/swaggo/gin-swagger"
 | 
						ginSwagger "github.com/swaggo/gin-swagger"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	docs "github.com/naiba/nezha/cmd/dashboard/docs"
 | 
						docs "github.com/naiba/nezha/cmd/dashboard/docs"
 | 
				
			||||||
	"github.com/naiba/nezha/model"
 | 
						"github.com/naiba/nezha/model"
 | 
				
			||||||
	"github.com/naiba/nezha/pkg/utils"
 | 
					 | 
				
			||||||
	"github.com/naiba/nezha/proto"
 | 
					 | 
				
			||||||
	"github.com/naiba/nezha/service/rpc"
 | 
					 | 
				
			||||||
	"github.com/naiba/nezha/service/singleton"
 | 
						"github.com/naiba/nezha/service/singleton"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -31,7 +26,6 @@ func ServeWeb() http.Handler {
 | 
				
			|||||||
		gin.SetMode(gin.DebugMode)
 | 
							gin.SetMode(gin.DebugMode)
 | 
				
			||||||
		pprof.Register(r)
 | 
							pprof.Register(r)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	r.Use(natGateway)
 | 
					 | 
				
			||||||
	if singleton.Conf.Debug {
 | 
						if singleton.Conf.Debug {
 | 
				
			||||||
		r.GET("/swagger/*any", ginSwagger.WrapHandler(swaggerfiles.Handler))
 | 
							r.GET("/swagger/*any", ginSwagger.WrapHandler(swaggerfiles.Handler))
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@ -83,67 +77,6 @@ func routers(r *gin.Engine) {
 | 
				
			|||||||
	auth.POST("/batch-delete/ddns", commonHandler(batchDeleteDDNS))
 | 
						auth.POST("/batch-delete/ddns", commonHandler(batchDeleteDDNS))
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func natGateway(c *gin.Context) {
 | 
					 | 
				
			||||||
	natConfig := singleton.GetNATConfigByDomain(c.Request.Host)
 | 
					 | 
				
			||||||
	if natConfig == nil {
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	singleton.ServerLock.RLock()
 | 
					 | 
				
			||||||
	server := singleton.ServerList[natConfig.ServerID]
 | 
					 | 
				
			||||||
	singleton.ServerLock.RUnlock()
 | 
					 | 
				
			||||||
	if server == nil || server.TaskStream == nil {
 | 
					 | 
				
			||||||
		c.Writer.WriteString("server not found or not connected")
 | 
					 | 
				
			||||||
		c.Abort()
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	streamId, err := uuid.GenerateUUID()
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		c.Writer.WriteString(fmt.Sprintf("stream id error: %v", err))
 | 
					 | 
				
			||||||
		c.Abort()
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	rpc.NezhaHandlerSingleton.CreateStream(streamId)
 | 
					 | 
				
			||||||
	defer rpc.NezhaHandlerSingleton.CloseStream(streamId)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	taskData, err := utils.Json.Marshal(model.TaskNAT{
 | 
					 | 
				
			||||||
		StreamID: streamId,
 | 
					 | 
				
			||||||
		Host:     natConfig.Host,
 | 
					 | 
				
			||||||
	})
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		c.Writer.WriteString(fmt.Sprintf("task data error: %v", err))
 | 
					 | 
				
			||||||
		c.Abort()
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if err := server.TaskStream.Send(&proto.Task{
 | 
					 | 
				
			||||||
		Type: model.TaskTypeNAT,
 | 
					 | 
				
			||||||
		Data: string(taskData),
 | 
					 | 
				
			||||||
	}); err != nil {
 | 
					 | 
				
			||||||
		c.Writer.WriteString(fmt.Sprintf("send task error: %v", err))
 | 
					 | 
				
			||||||
		c.Abort()
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	w, err := utils.NewRequestWrapper(c.Request, c.Writer)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		c.Writer.WriteString(fmt.Sprintf("request wrapper error: %v", err))
 | 
					 | 
				
			||||||
		c.Abort()
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if err := rpc.NezhaHandlerSingleton.UserConnected(streamId, w); err != nil {
 | 
					 | 
				
			||||||
		c.Writer.WriteString(fmt.Sprintf("user connected error: %v", err))
 | 
					 | 
				
			||||||
		c.Abort()
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	rpc.NezhaHandlerSingleton.StartStream(streamId, time.Second*10)
 | 
					 | 
				
			||||||
	c.Abort()
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func recordPath(c *gin.Context) {
 | 
					func recordPath(c *gin.Context) {
 | 
				
			||||||
	url := c.Request.URL.String()
 | 
						url := c.Request.URL.String()
 | 
				
			||||||
	for _, p := range c.Params {
 | 
						for _, p := range c.Params {
 | 
				
			||||||
 | 
				
			|||||||
@ -159,6 +159,11 @@ func dispatchReportInfoTask() {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
func newHTTPandGRPCMux(httpHandler http.Handler, grpcHandler http.Handler) http.Handler {
 | 
					func newHTTPandGRPCMux(httpHandler http.Handler, grpcHandler http.Handler) http.Handler {
 | 
				
			||||||
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | 
						return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | 
				
			||||||
 | 
							natConfig := singleton.GetNATConfigByDomain(r.Host)
 | 
				
			||||||
 | 
							if natConfig != nil {
 | 
				
			||||||
 | 
								rpc.ServeNAT(w, r, natConfig)
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
		if r.ProtoMajor == 2 && r.Header.Get("Content-Type") == "application/grpc" &&
 | 
							if r.ProtoMajor == 2 && r.Header.Get("Content-Type") == "application/grpc" &&
 | 
				
			||||||
			strings.HasPrefix(r.URL.Path, "/"+proto.NezhaService_ServiceDesc.ServiceName) {
 | 
								strings.HasPrefix(r.URL.Path, "/"+proto.NezhaService_ServiceDesc.ServiceName) {
 | 
				
			||||||
			grpcHandler.ServeHTTP(w, r)
 | 
								grpcHandler.ServeHTTP(w, r)
 | 
				
			||||||
 | 
				
			|||||||
@ -1,10 +1,16 @@
 | 
				
			|||||||
package rpc
 | 
					package rpc
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"google.golang.org/grpc"
 | 
						"google.golang.org/grpc"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/hashicorp/go-uuid"
 | 
				
			||||||
	"github.com/naiba/nezha/model"
 | 
						"github.com/naiba/nezha/model"
 | 
				
			||||||
	pb "github.com/naiba/nezha/proto"
 | 
						"github.com/naiba/nezha/pkg/utils"
 | 
				
			||||||
 | 
						"github.com/naiba/nezha/proto"
 | 
				
			||||||
	rpcService "github.com/naiba/nezha/service/rpc"
 | 
						rpcService "github.com/naiba/nezha/service/rpc"
 | 
				
			||||||
	"github.com/naiba/nezha/service/singleton"
 | 
						"github.com/naiba/nezha/service/singleton"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
@ -12,7 +18,7 @@ import (
 | 
				
			|||||||
func ServeRPC() *grpc.Server {
 | 
					func ServeRPC() *grpc.Server {
 | 
				
			||||||
	server := grpc.NewServer()
 | 
						server := grpc.NewServer()
 | 
				
			||||||
	rpcService.NezhaHandlerSingleton = rpcService.NewNezhaHandler()
 | 
						rpcService.NezhaHandlerSingleton = rpcService.NewNezhaHandler()
 | 
				
			||||||
	pb.RegisterNezhaServiceServer(server, rpcService.NezhaHandlerSingleton)
 | 
						proto.RegisterNezhaServiceServer(server, rpcService.NezhaHandlerSingleton)
 | 
				
			||||||
	return server
 | 
						return server
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -69,7 +75,62 @@ func DispatchKeepalive() {
 | 
				
			|||||||
				continue
 | 
									continue
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			singleton.SortedServerList[i].TaskStream.Send(&pb.Task{Type: model.TaskTypeKeepalive})
 | 
								singleton.SortedServerList[i].TaskStream.Send(&proto.Task{Type: model.TaskTypeKeepalive})
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func ServeNAT(w http.ResponseWriter, r *http.Request, natConfig *model.NAT) {
 | 
				
			||||||
 | 
						singleton.ServerLock.RLock()
 | 
				
			||||||
 | 
						server := singleton.ServerList[natConfig.ServerID]
 | 
				
			||||||
 | 
						singleton.ServerLock.RUnlock()
 | 
				
			||||||
 | 
						if server == nil || server.TaskStream == nil {
 | 
				
			||||||
 | 
							w.WriteHeader(http.StatusServiceUnavailable)
 | 
				
			||||||
 | 
							w.Write([]byte("server not found or not connected"))
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						streamId, err := uuid.GenerateUUID()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							w.WriteHeader(http.StatusServiceUnavailable)
 | 
				
			||||||
 | 
							w.Write([]byte(fmt.Sprintf("stream id error: %v", err)))
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						rpcService.NezhaHandlerSingleton.CreateStream(streamId)
 | 
				
			||||||
 | 
						defer rpcService.NezhaHandlerSingleton.CloseStream(streamId)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						taskData, err := utils.Json.Marshal(model.TaskNAT{
 | 
				
			||||||
 | 
							StreamID: streamId,
 | 
				
			||||||
 | 
							Host:     natConfig.Host,
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							w.WriteHeader(http.StatusServiceUnavailable)
 | 
				
			||||||
 | 
							w.Write([]byte(fmt.Sprintf("task data error: %v", err)))
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if err := server.TaskStream.Send(&proto.Task{
 | 
				
			||||||
 | 
							Type: model.TaskTypeNAT,
 | 
				
			||||||
 | 
							Data: string(taskData),
 | 
				
			||||||
 | 
						}); err != nil {
 | 
				
			||||||
 | 
							w.WriteHeader(http.StatusServiceUnavailable)
 | 
				
			||||||
 | 
							w.Write([]byte(fmt.Sprintf("send task error: %v", err)))
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						wWrapped, err := utils.NewRequestWrapper(r, w)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							w.WriteHeader(http.StatusServiceUnavailable)
 | 
				
			||||||
 | 
							w.Write([]byte(fmt.Sprintf("request wrapper error: %v", err)))
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if err := rpcService.NezhaHandlerSingleton.UserConnected(streamId, wWrapped); err != nil {
 | 
				
			||||||
 | 
							w.WriteHeader(http.StatusServiceUnavailable)
 | 
				
			||||||
 | 
							w.Write([]byte(fmt.Sprintf("user connected error: %v", err)))
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						rpcService.NezhaHandlerSingleton.StartStream(streamId, time.Second*10)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -2,11 +2,10 @@ package utils
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"bytes"
 | 
						"bytes"
 | 
				
			||||||
 | 
						"errors"
 | 
				
			||||||
	"io"
 | 
						"io"
 | 
				
			||||||
	"net"
 | 
						"net"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
 | 
					 | 
				
			||||||
	"github.com/gin-gonic/gin"
 | 
					 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
var _ io.ReadWriteCloser = &RequestWrapper{}
 | 
					var _ io.ReadWriteCloser = &RequestWrapper{}
 | 
				
			||||||
@ -17,8 +16,12 @@ type RequestWrapper struct {
 | 
				
			|||||||
	writer net.Conn
 | 
						writer net.Conn
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func NewRequestWrapper(req *http.Request, writer gin.ResponseWriter) (*RequestWrapper, error) {
 | 
					func NewRequestWrapper(req *http.Request, writer http.ResponseWriter) (*RequestWrapper, error) {
 | 
				
			||||||
	conn, _, err := writer.Hijack()
 | 
						hj, ok := writer.(http.Hijacker)
 | 
				
			||||||
 | 
						if !ok {
 | 
				
			||||||
 | 
							return nil, errors.New("http server does not support hijacking")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						conn, _, err := hj.Hijack()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user