Parcourir la source

feat(limiter): 实现令牌桶限流功能并应用于登录和注册接口

- 新增 limiter包,实现令牌桶算法
- 新增中间件包,提供 IP、路径等不同维度的限流中间件
- 在 http 服务器中集成限流功能,应用于登录和注册接口
- 更新配置文件,添加全局和特定 API 的限流配置
- 通过 Wire依赖注入,将限流器和限流中间件集成到应用中
fusu il y a 3 mois
Parent
commit
cd216c5fbb

+ 10 - 0
cmd/server/wire/wire.go

@@ -6,11 +6,13 @@ package wire
 import (
 	"github.com/go-nunu/nunu-layout-advanced/internal/handler"
 	"github.com/go-nunu/nunu-layout-advanced/internal/job"
+	"github.com/go-nunu/nunu-layout-advanced/internal/middleware"
 	"github.com/go-nunu/nunu-layout-advanced/internal/repository"
 	"github.com/go-nunu/nunu-layout-advanced/internal/server"
 	"github.com/go-nunu/nunu-layout-advanced/internal/service"
 	"github.com/go-nunu/nunu-layout-advanced/pkg/app"
 	"github.com/go-nunu/nunu-layout-advanced/pkg/jwt"
+	"github.com/go-nunu/nunu-layout-advanced/pkg/limiter"
 	"github.com/go-nunu/nunu-layout-advanced/pkg/log"
 	"github.com/go-nunu/nunu-layout-advanced/pkg/server/http"
 	"github.com/go-nunu/nunu-layout-advanced/pkg/sid"
@@ -72,6 +74,13 @@ var jobSet = wire.NewSet(
 	job.NewJob,
 	job.NewUserJob,
 )
+
+// 限流器依赖集
+var limiterSet = wire.NewSet(
+	limiter.NewLimiter,
+	middleware.NewRateLimitMiddleware,
+)
+
 var serverSet = wire.NewSet(
 	server.NewHTTPServer,
 	server.NewJobServer,
@@ -96,6 +105,7 @@ func NewWire(*viper.Viper, *log.Logger) (*app.App, func(), error) {
 		handlerSet,
 		jobSet,
 		serverSet,
+		limiterSet,
 		sid.NewSid,
 		jwt.NewJwt,
 		newApp,

+ 8 - 1
cmd/server/wire/wire_gen.go

@@ -9,11 +9,13 @@ package wire
 import (
 	"github.com/go-nunu/nunu-layout-advanced/internal/handler"
 	"github.com/go-nunu/nunu-layout-advanced/internal/job"
+	"github.com/go-nunu/nunu-layout-advanced/internal/middleware"
 	"github.com/go-nunu/nunu-layout-advanced/internal/repository"
 	"github.com/go-nunu/nunu-layout-advanced/internal/server"
 	"github.com/go-nunu/nunu-layout-advanced/internal/service"
 	"github.com/go-nunu/nunu-layout-advanced/pkg/app"
 	"github.com/go-nunu/nunu-layout-advanced/pkg/jwt"
+	"github.com/go-nunu/nunu-layout-advanced/pkg/limiter"
 	"github.com/go-nunu/nunu-layout-advanced/pkg/log"
 	"github.com/go-nunu/nunu-layout-advanced/pkg/server/http"
 	"github.com/go-nunu/nunu-layout-advanced/pkg/sid"
@@ -25,6 +27,8 @@ import (
 
 func NewWire(viperViper *viper.Viper, logger *log.Logger) (*app.App, func(), error) {
 	jwtJWT := jwt.NewJwt(viperViper)
+	limiterLimiter := limiter.NewLimiter(viperViper)
+	handlerFunc := middleware.NewRateLimitMiddleware(limiterLimiter)
 	handlerHandler := handler.NewHandler(logger)
 	db := repository.NewDB(viperViper, logger)
 	repositoryRepository := repository.NewRepository(logger, db)
@@ -63,7 +67,7 @@ func NewWire(viperViper *viper.Viper, logger *log.Logger) (*app.App, func(), err
 	udpLimitRepository := repository.NewUdpLimitRepository(repositoryRepository)
 	udpLimitService := service.NewUdpLimitService(serviceService, udpLimitRepository, requiredService, crawlerService, parserService)
 	udpLimitHandler := handler.NewUdpLimitHandler(handlerHandler, udpLimitService)
-	httpServer := server.NewHTTPServer(logger, viperViper, jwtJWT, userHandler, gameShieldHandler, webForwardingHandler, webLimitHandler, tcpforwardingHandler, udpForWardingHandler, tcpLimitHandler, udpLimitHandler)
+	httpServer := server.NewHTTPServer(logger, viperViper, jwtJWT, limiterLimiter, handlerFunc, userHandler, gameShieldHandler, webForwardingHandler, webLimitHandler, tcpforwardingHandler, udpForWardingHandler, tcpLimitHandler, udpLimitHandler)
 	jobJob := job.NewJob(transaction, logger, sidSid)
 	userJob := job.NewUserJob(jobJob, userRepository)
 	jobServer := server.NewJobServer(logger, userJob)
@@ -82,6 +86,9 @@ var handlerSet = wire.NewSet(handler.NewHandler, handler.NewUserHandler, handler
 
 var jobSet = wire.NewSet(job.NewJob, job.NewUserJob)
 
+// 限流器依赖集
+var limiterSet = wire.NewSet(limiter.NewLimiter, middleware.NewRateLimitMiddleware)
+
 var serverSet = wire.NewSet(server.NewHTTPServer, server.NewJobServer)
 
 // build App

+ 14 - 0
config/local.yml

@@ -41,3 +41,17 @@ crawler:
   password: "mr7c6r61jIRLGhcnT5j9"
   Url: "http://api.hongxingdun.net:8700/"
   keyUrl: "http://api.hongxingdun.net:13350/sdk/key?app_name="
+
+# 令牌桶限流配置
+limiter:
+  # 全局限流配置
+  capacity: 20    # 令牌桶容量(允许的突发请求数)
+  fillRate: 5    # 每秒填充速率(QPS)
+  # 特定API限流配置
+  api:
+    login:         # 登录接口限流
+      capacity: 20
+      fillRate: 2
+    register:      # 注册接口限流
+      capacity: 50
+      fillRate: 5

+ 99 - 0
internal/middleware/limiter.go

@@ -0,0 +1,99 @@
+package middleware
+
+import (
+	"github.com/gin-gonic/gin"
+	"github.com/go-nunu/nunu-layout-advanced/pkg/limiter"
+	"net"
+	"net/http"
+	"strings"
+	"sync"
+)
+
+// IPPathRateLimitMiddleware 根据 IP + 路径创建限流中间件
+func IPPathRateLimitMiddleware(cfg *limiter.RateLimitConfig) gin.HandlerFunc {
+	buckets := sync.Map{}
+
+	return func(c *gin.Context) {
+		ip := getClientIP(c.Request)
+		path := c.FullPath()
+		if path == "" {
+			path = c.Request.URL.Path // 如果 FullPath 为空(未使用路由组),则使用原始路径
+		}
+		key := ip + ":" + path
+
+		val, _ := buckets.LoadOrStore(key, limiter.NewTokenBucket(cfg.Capacity, cfg.FillRate))
+		tb := val.(*limiter.TokenBucket)
+
+		if tb.Allow() {
+			c.Next()
+		} else {
+			c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
+				"code": 429,
+				"msg":  "请求太频繁,请稍后再试",
+			})
+		}
+	}
+}
+
+// IPRateLimitMiddleware 仅根据 IP 创建限流中间件
+func IPRateLimitMiddleware(cfg *limiter.RateLimitConfig) gin.HandlerFunc {
+	buckets := sync.Map{}
+
+	return func(c *gin.Context) {
+		ip := getClientIP(c.Request)
+		key := ip
+
+		val, _ := buckets.LoadOrStore(key, limiter.NewTokenBucket(cfg.Capacity, cfg.FillRate))
+		tb := val.(*limiter.TokenBucket)
+
+		if tb.Allow() {
+			c.Next()
+		} else {
+			c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
+				"code": 429,
+				"msg":  "请求太频繁,请稍后再试",
+			})
+		}
+	}
+}
+
+// PathRateLimitMiddleware 仅根据路径创建限流中间件
+func PathRateLimitMiddleware(cfg *limiter.RateLimitConfig) gin.HandlerFunc {
+	buckets := sync.Map{}
+
+	return func(c *gin.Context) {
+		path := c.FullPath()
+		if path == "" {
+			path = c.Request.URL.Path // 如果 FullPath 为空(未使用路由组),则使用原始路径
+		}
+		key := path
+
+		val, _ := buckets.LoadOrStore(key, limiter.NewTokenBucket(cfg.Capacity, cfg.FillRate))
+		tb := val.(*limiter.TokenBucket)
+
+		if tb.Allow() {
+			c.Next()
+		} else {
+			c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
+				"code": 429,
+				"msg":  "请求太频繁,请稍后再试",
+			})
+		}
+	}
+}
+
+// getClientIP 获取客户端 IP 地址
+func getClientIP(r *http.Request) string {
+	ip := r.Header.Get("X-Forwarded-For")
+	if ip != "" {
+		parts := strings.Split(ip, ",")
+		return strings.TrimSpace(parts[0])
+	}
+	ip, _, _ = net.SplitHostPort(r.RemoteAddr)
+	return ip
+}
+
+// 提供全局默认限流中间件,用于Wire依赖注入
+func NewRateLimitMiddleware(limiterInstance *limiter.Limiter) gin.HandlerFunc {
+	return IPPathRateLimitMiddleware(limiterInstance.Config)
+}

+ 13 - 2
internal/server/http.go

@@ -7,6 +7,7 @@ import (
 	"github.com/go-nunu/nunu-layout-advanced/internal/handler"
 	"github.com/go-nunu/nunu-layout-advanced/internal/middleware"
 	"github.com/go-nunu/nunu-layout-advanced/pkg/jwt"
+	"github.com/go-nunu/nunu-layout-advanced/pkg/limiter"
 	"github.com/go-nunu/nunu-layout-advanced/pkg/log"
 	"github.com/go-nunu/nunu-layout-advanced/pkg/server/http"
 	"github.com/spf13/viper"
@@ -18,6 +19,8 @@ func NewHTTPServer(
 	logger *log.Logger,
 	conf *viper.Viper,
 	jwt *jwt.JWT,
+	limiterInstance *limiter.Limiter,
+	rateLimitMiddleware gin.HandlerFunc,
 	userHandler *handler.UserHandler,
 	gameShieldHandler *handler.GameShieldHandler,
 	webForwardingHandler *handler.WebForwardingHandler,
@@ -49,6 +52,7 @@ func NewHTTPServer(
 		middleware.ResponseLogMiddleware(logger),
 		middleware.RequestLogMiddleware(logger),
 		//middleware.SignMiddleware(log),
+		rateLimitMiddleware,
 	)
 	s.GET("/", func(ctx *gin.Context) {
 		logger.WithContext(ctx).Info("hello")
@@ -62,8 +66,15 @@ func NewHTTPServer(
 		// No route group has permission
 		noAuthRouter := v1.Group("/")
 		{
-			noAuthRouter.POST("/register", userHandler.Register)
-			noAuthRouter.POST("/login", userHandler.Login)
+			// 使用增强的Limiter.GetAPIConfig方法获取特定API的限流配置
+
+			// 登录API限流
+			loginConfig := limiterInstance.GetAPIConfig("login")
+			noAuthRouter.POST("/login", middleware.IPRateLimitMiddleware(loginConfig), userHandler.Login)
+
+			// 注册API限流
+			registerConfig := limiterInstance.GetAPIConfig("register")
+			noAuthRouter.POST("/register", middleware.IPRateLimitMiddleware(registerConfig), userHandler.Register)
 			noAuthRouter.POST("/gameShield/add", gameShieldHandler.SubmitGameShield)
 			noAuthRouter.POST("/gameShield/getField", gameShieldHandler.GetGameShieldField)
 			noAuthRouter.POST("/gameShield/getKey", gameShieldHandler.GetGameShieldKey)

+ 1 - 1
internal/service/gameshield.go

@@ -142,7 +142,7 @@ func (service *gameShieldService) EditGameShield(ctx context.Context, req *v1.Ga
 	}
 	sendUrl := service.Url + "admin/edit/rule"
 
-	dunName := strconv.Itoa(req.Uid) + "_" + strconv.FormatInt(time.Now().Unix(), 10)
+	dunName := strconv.Itoa(req.Uid) + "_" + strconv.FormatInt(time.Now().Unix(), 10) + "_" + req.AppName
 	formData := map[string]interface{}{
 		"app_name":             dunName,
 		"gateway_group_id":     4,

+ 103 - 0
pkg/limiter/limiter.go

@@ -0,0 +1,103 @@
+package limiter
+
+import (
+	"github.com/spf13/viper"
+	"sync"
+	"time"
+)
+
+// RateLimitConfig 限流配置
+type RateLimitConfig struct {
+	Capacity int // 桶容量
+	FillRate int // 每秒填充速率
+}
+
+// NewRateLimitConfig 创建新的限流配置
+func NewRateLimitConfig(capacity, fillRate int) *RateLimitConfig {
+	return &RateLimitConfig{
+		Capacity: capacity,
+		FillRate: fillRate,
+	}
+}
+
+// TokenBucket 令牌桶实现
+type TokenBucket struct {
+	capacity     int
+	tokens       int
+	fillInterval time.Duration
+	mutex        sync.Mutex
+}
+
+// NewTokenBucket 创建新的令牌桶
+func NewTokenBucket(capacity int, fillRate int) *TokenBucket {
+	tb := &TokenBucket{
+		capacity:     capacity,
+		tokens:       capacity,
+		fillInterval: time.Second / time.Duration(fillRate),
+	}
+
+	go func() {
+		ticker := time.NewTicker(tb.fillInterval)
+		for range ticker.C {
+			tb.mutex.Lock()
+			if tb.tokens < tb.capacity {
+				tb.tokens++
+			}
+			tb.mutex.Unlock()
+		}
+	}()
+	return tb
+}
+
+// Allow 消耗令牌,返回是否允许请求通过
+func (tb *TokenBucket) Allow() bool {
+	tb.mutex.Lock()
+	defer tb.mutex.Unlock()
+
+	if tb.tokens > 0 {
+		tb.tokens--
+		return true
+	}
+	return false
+}
+
+// Limiter 限流器实例
+type Limiter struct {
+	Config *RateLimitConfig
+	conf   *viper.Viper // 保存配置对象便于查询特定API的限流设置
+}
+
+// NewLimiter 创建限流器实例,用于Wire依赖注入
+func NewLimiter(conf *viper.Viper) *Limiter {
+	// 从配置文件读取限流设置,如果没有则使用默认值
+	capacity := conf.GetInt("limiter.capacity")
+	fillRate := conf.GetInt("limiter.fillRate")
+
+	if capacity <= 0 {
+		capacity = 100 // 默认容量
+	}
+	if fillRate <= 0 {
+		fillRate = 10 // 默认每秒填充速率
+	}
+
+	return &Limiter{
+		Config: NewRateLimitConfig(capacity, fillRate),
+		conf:   conf,
+	}
+}
+
+// GetAPIConfig 获取特定API的限流配置
+// 如果配置文件中不存在特定API的配置,则返回默认的全局配置
+func (l *Limiter) GetAPIConfig(apiName string) *RateLimitConfig {
+	// 尝试从配置文件中读取特定API的配置
+	capacity := l.conf.GetInt("limiter.api." + apiName + ".capacity")
+	fillRate := l.conf.GetInt("limiter.api." + apiName + ".fillRate")
+
+	// 如果配置有效,返回特定API的配置
+	if capacity > 0 && fillRate > 0 {
+		return NewRateLimitConfig(capacity, fillRate)
+	}
+
+	// 如果没有特定配置,返回全局配置
+	return l.Config
+}