feat: waf 🤡
This commit is contained in:
		
							parent
							
								
									d699d0ee87
								
							
						
					
					
						commit
						17b02640a9
					
				@ -5,7 +5,6 @@ import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"log"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/netip"
 | 
			
		||||
	"os"
 | 
			
		||||
	"path/filepath"
 | 
			
		||||
	"strings"
 | 
			
		||||
@ -16,6 +15,7 @@ import (
 | 
			
		||||
	swaggerfiles "github.com/swaggo/files"
 | 
			
		||||
	ginSwagger "github.com/swaggo/gin-swagger"
 | 
			
		||||
 | 
			
		||||
	"github.com/naiba/nezha/cmd/dashboard/controller/waf"
 | 
			
		||||
	docs "github.com/naiba/nezha/cmd/dashboard/docs"
 | 
			
		||||
	"github.com/naiba/nezha/model"
 | 
			
		||||
	"github.com/naiba/nezha/service/singleton"
 | 
			
		||||
@ -34,39 +34,14 @@ func ServeWeb() http.Handler {
 | 
			
		||||
		r.GET("/swagger/*any", ginSwagger.WrapHandler(swaggerfiles.Handler))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	r.Use(realIp)
 | 
			
		||||
	r.Use(waf.RealIp)
 | 
			
		||||
	r.Use(waf.Waf)
 | 
			
		||||
	r.Use(recordPath)
 | 
			
		||||
	routers(r)
 | 
			
		||||
 | 
			
		||||
	return r
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func realIp(c *gin.Context) {
 | 
			
		||||
	if singleton.Conf.RealIPHeader == "" {
 | 
			
		||||
		c.Next()
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if singleton.Conf.RealIPHeader == model.ConfigUsePeerIP {
 | 
			
		||||
		c.Set(model.CtxKeyRealIPStr, c.RemoteIP())
 | 
			
		||||
		c.Next()
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	vals := c.Request.Header.Get(singleton.Conf.RealIPHeader)
 | 
			
		||||
	if vals == "" {
 | 
			
		||||
		c.AbortWithStatusJSON(http.StatusOK, model.CommonResponse[any]{Success: false, Error: "real ip header not found"})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	ip, err := netip.ParseAddr(vals)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.AbortWithStatusJSON(http.StatusOK, model.CommonResponse[any]{Success: false, Error: err.Error()})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	c.Set(model.CtxKeyRealIPStr, ip.String())
 | 
			
		||||
	c.Next()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func routers(r *gin.Engine) {
 | 
			
		||||
	authMiddleware, err := jwt.New(initParams())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
@ -154,7 +129,6 @@ func routers(r *gin.Engine) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func recordPath(c *gin.Context) {
 | 
			
		||||
	log.Printf("bingo web real ip: %s", c.GetString(model.CtxKeyRealIPStr))
 | 
			
		||||
	url := c.Request.URL.String()
 | 
			
		||||
	for _, p := range c.Params {
 | 
			
		||||
		url = strings.Replace(url, p.Value, ":"+p.Key, 1)
 | 
			
		||||
 | 
			
		||||
@ -9,6 +9,7 @@ import (
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"golang.org/x/crypto/bcrypt"
 | 
			
		||||
 | 
			
		||||
	"github.com/naiba/nezha/cmd/dashboard/controller/waf"
 | 
			
		||||
	"github.com/naiba/nezha/model"
 | 
			
		||||
	"github.com/naiba/nezha/pkg/utils"
 | 
			
		||||
	"github.com/naiba/nezha/service/singleton"
 | 
			
		||||
@ -87,10 +88,12 @@ func authenticator() func(c *gin.Context) (interface{}, error) {
 | 
			
		||||
 | 
			
		||||
		var user model.User
 | 
			
		||||
		if err := singleton.DB.Select("id", "password").Where("username = ?", loginVals.Username).First(&user).Error; err != nil {
 | 
			
		||||
			model.BlockIP(singleton.DB, c.GetString(model.CtxKeyRealIPStr), model.WAFBlockReasonTypeLoginFail)
 | 
			
		||||
			return nil, jwt.ErrFailedAuthentication
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(loginVals.Password)); err != nil {
 | 
			
		||||
			model.BlockIP(singleton.DB, c.GetString(model.CtxKeyRealIPStr), model.WAFBlockReasonTypeLoginFail)
 | 
			
		||||
			return nil, jwt.ErrFailedAuthentication
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
@ -163,6 +166,10 @@ func optionalAuthMiddleware(mw *jwt.GinJWTMiddleware) func(c *gin.Context) {
 | 
			
		||||
		identity := mw.IdentityHandler(c)
 | 
			
		||||
 | 
			
		||||
		if identity != nil {
 | 
			
		||||
			if err := model.BlockIP(singleton.DB, c.GetString(model.CtxKeyRealIPStr), model.WAFBlockReasonTypeBruteForceToken); err != nil {
 | 
			
		||||
				waf.ShowBlockPage(c, err)
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			c.Set(mw.IdentityKey, identity)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										90
									
								
								cmd/dashboard/controller/waf/waf.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										90
									
								
								cmd/dashboard/controller/waf/waf.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,90 @@
 | 
			
		||||
package waf
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	_ "embed"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"log"
 | 
			
		||||
	"math/big"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/netip"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/naiba/nezha/model"
 | 
			
		||||
	"github.com/naiba/nezha/service/singleton"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
//go:embed waf.html
 | 
			
		||||
var errorPageTemplate string
 | 
			
		||||
 | 
			
		||||
func RealIp(c *gin.Context) {
 | 
			
		||||
	if singleton.Conf.RealIPHeader == "" {
 | 
			
		||||
		c.Next()
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if singleton.Conf.RealIPHeader == model.ConfigUsePeerIP {
 | 
			
		||||
		c.Set(model.CtxKeyRealIPStr, c.RemoteIP())
 | 
			
		||||
		c.Next()
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	vals := c.Request.Header.Get(singleton.Conf.RealIPHeader)
 | 
			
		||||
	if vals == "" {
 | 
			
		||||
		c.AbortWithStatusJSON(http.StatusOK, model.CommonResponse[any]{Success: false, Error: "real ip header not found"})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	ip, err := netip.ParseAddr(vals)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.AbortWithStatusJSON(http.StatusOK, model.CommonResponse[any]{Success: false, Error: err.Error()})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	c.Set(model.CtxKeyRealIPStr, ip.String())
 | 
			
		||||
	c.Next()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func Waf(c *gin.Context) {
 | 
			
		||||
	if singleton.Conf.RealIPHeader == "" {
 | 
			
		||||
		c.Next()
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	realipAddr := c.GetString(model.CtxKeyRealIPStr)
 | 
			
		||||
	if realipAddr == "" {
 | 
			
		||||
		c.Next()
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	var w model.WAF
 | 
			
		||||
	if err := singleton.DB.First(&w, "ip = ?", realipAddr).Error; err != nil {
 | 
			
		||||
		if err != gorm.ErrRecordNotFound {
 | 
			
		||||
			ShowBlockPage(c, err)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	now := time.Now().Unix()
 | 
			
		||||
	if w.LastBlockTimestamp+pow(w.Count, 4) > uint64(now) {
 | 
			
		||||
		log.Println(w.Count, w.LastBlockTimestamp+pow(w.Count, 4)-uint64(now))
 | 
			
		||||
		ShowBlockPage(c, errors.New("you are blocked by nezha WAF"))
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	c.Next()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func pow(x, y uint64) uint64 {
 | 
			
		||||
	base := big.NewInt(0).SetUint64(x)
 | 
			
		||||
	exp := big.NewInt(0).SetUint64(y)
 | 
			
		||||
	result := big.NewInt(1)
 | 
			
		||||
	result.Exp(base, exp, nil)
 | 
			
		||||
	if !result.IsUint64() {
 | 
			
		||||
		return ^uint64(0) // return max uint64 value on overflow
 | 
			
		||||
	}
 | 
			
		||||
	return result.Uint64()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ShowBlockPage(c *gin.Context, err error) {
 | 
			
		||||
	c.Writer.WriteHeader(http.StatusForbidden)
 | 
			
		||||
	c.Header("Content-Type", "text/html; charset=utf-8")
 | 
			
		||||
	c.Writer.WriteString(strings.Replace(errorPageTemplate, "{error}", err.Error(), 1))
 | 
			
		||||
	c.Abort()
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										39
									
								
								cmd/dashboard/controller/waf/waf.html
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										39
									
								
								cmd/dashboard/controller/waf/waf.html
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,39 @@
 | 
			
		||||
<!DOCTYPE html>
 | 
			
		||||
<html lang="en">
 | 
			
		||||
 | 
			
		||||
<head>
 | 
			
		||||
    <meta charset="UTF-8">
 | 
			
		||||
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
 | 
			
		||||
    <title>Blocked</title>
 | 
			
		||||
    <style>
 | 
			
		||||
        body {
 | 
			
		||||
            display: flex;
 | 
			
		||||
            justify-content: center;
 | 
			
		||||
            align-items: center;
 | 
			
		||||
            height: 90vh;
 | 
			
		||||
            font-weight: bolder;
 | 
			
		||||
            font-family: 'Courier New', Courier, monospace;
 | 
			
		||||
        }
 | 
			
		||||
        main {
 | 
			
		||||
            text-align: center;
 | 
			
		||||
        }
 | 
			
		||||
        .emoji {
 | 
			
		||||
            font-size: 200px;
 | 
			
		||||
        }
 | 
			
		||||
        p.secondary {
 | 
			
		||||
            font-size: 12px;
 | 
			
		||||
            color: #888;
 | 
			
		||||
        }
 | 
			
		||||
    </style>
 | 
			
		||||
</head>
 | 
			
		||||
 | 
			
		||||
<body>
 | 
			
		||||
    <main>
 | 
			
		||||
        <div class="emoji">🤡</div>
 | 
			
		||||
        <h1>Blocked</h1>
 | 
			
		||||
        <p>{error}</p>
 | 
			
		||||
        <p class="secondary">nezha WAF</p>
 | 
			
		||||
    </main>
 | 
			
		||||
</body>
 | 
			
		||||
 | 
			
		||||
</html>
 | 
			
		||||
							
								
								
									
										29
									
								
								cmd/dashboard/controller/waf/waf_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										29
									
								
								cmd/dashboard/controller/waf/waf_test.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,29 @@
 | 
			
		||||
package waf
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"math"
 | 
			
		||||
	"testing"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestPow(t *testing.T) {
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		x,
 | 
			
		||||
		y,
 | 
			
		||||
		expect uint64
 | 
			
		||||
	}{
 | 
			
		||||
		{2, 64, math.MaxUint64},                 // 2 的 64 次方,超过 uint64 最大值
 | 
			
		||||
		{uint64(1 << 63), 2, math.MaxUint64},    // 大数平方,可能溢出
 | 
			
		||||
		{uint64(^uint64(0)), 2, math.MaxUint64}, // uint64 最大值的平方,溢出
 | 
			
		||||
		{2, 3, 8},
 | 
			
		||||
		{5, 0, 1},
 | 
			
		||||
		{3, 1, 3},
 | 
			
		||||
		{0, 5, 0},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, tt := range tests {
 | 
			
		||||
		result := pow(tt.x, tt.y)
 | 
			
		||||
		if result != tt.expect {
 | 
			
		||||
			t.Errorf("pow(%d, %d) = %d; expect %d", tt.x, tt.y, result, tt.expect)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@ -5,6 +5,7 @@ import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/netip"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"google.golang.org/grpc"
 | 
			
		||||
@ -48,7 +49,9 @@ func getRealIp(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo,
 | 
			
		||||
	if len(vals) == 0 {
 | 
			
		||||
		return nil, fmt.Errorf("real ip header not found")
 | 
			
		||||
	}
 | 
			
		||||
	ip, err := netip.ParseAddr(vals[0])
 | 
			
		||||
	a := strings.Split(vals[0], ",")
 | 
			
		||||
	h := strings.TrimSpace(a[len(a)-1])
 | 
			
		||||
	ip, err := netip.ParseAddr(h)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										40
									
								
								model/waf.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										40
									
								
								model/waf.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,40 @@
 | 
			
		||||
package model
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	_ uint8 = iota
 | 
			
		||||
	WAFBlockReasonTypeLoginFail
 | 
			
		||||
	WAFBlockReasonTypeBruteForceToken
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type WAF struct {
 | 
			
		||||
	IP                 string `gorm:"type:binary(16);primaryKey" json:"ip,omitempty"`
 | 
			
		||||
	Count              uint64 `json:"count,omitempty"`
 | 
			
		||||
	LastBlockReason    uint8  `json:"last_block_reason,omitempty"`
 | 
			
		||||
	LastBlockTimestamp uint64 `json:"last_block_timestamp,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (w *WAF) TableName() string {
 | 
			
		||||
	return "waf"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func BlockIP(db *gorm.DB, ip string, reason uint8) error {
 | 
			
		||||
	if ip == "" {
 | 
			
		||||
		return errors.New("empty ip")
 | 
			
		||||
	}
 | 
			
		||||
	var w WAF
 | 
			
		||||
	w.LastBlockReason = reason
 | 
			
		||||
	w.LastBlockTimestamp = uint64(time.Now().Unix())
 | 
			
		||||
	return db.Transaction(func(tx *gorm.DB) error {
 | 
			
		||||
		if err := tx.FirstOrCreate(&w, WAF{IP: ip}).Error; err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		return tx.Exec("UPDATE waf SET count = count + 1 WHERE ip = ?", ip).Error
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
@ -2,7 +2,6 @@ package rpc
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"log"
 | 
			
		||||
	"sync"
 | 
			
		||||
 | 
			
		||||
	"google.golang.org/grpc/codes"
 | 
			
		||||
@ -25,9 +24,6 @@ func (a *authHandler) Check(ctx context.Context) (uint64, error) {
 | 
			
		||||
		return 0, status.Errorf(codes.Unauthenticated, "获取 metaData 失败")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	realIp := ctx.Value(model.CtxKeyRealIP{})
 | 
			
		||||
	log.Printf("bingo rpc realIp: %s, metadata: %v", realIp, md)
 | 
			
		||||
 | 
			
		||||
	var clientSecret string
 | 
			
		||||
	if value, ok := md["client_secret"]; ok {
 | 
			
		||||
		clientSecret = value[0]
 | 
			
		||||
 | 
			
		||||
@ -65,7 +65,8 @@ func InitDBFromPath(path string) {
 | 
			
		||||
	err = DB.AutoMigrate(model.Server{}, model.User{}, model.ServerGroup{}, model.NotificationGroup{},
 | 
			
		||||
		model.Notification{}, model.AlertRule{}, model.Service{}, model.NotificationGroupNotification{},
 | 
			
		||||
		model.ServiceHistory{}, model.Cron{}, model.Transfer{}, model.ServerGroupServer{}, model.UserGroup{},
 | 
			
		||||
		model.UserGroupUser{}, model.NAT{}, model.DDNSProfile{}, model.NotificationGroupNotification{})
 | 
			
		||||
		model.UserGroupUser{}, model.NAT{}, model.DDNSProfile{}, model.NotificationGroupNotification{},
 | 
			
		||||
		model.WAF{})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		panic(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user