package middleware import ( "github.com/gin-gonic/gin" "github.com/go-nunu/nunu-layout-advanced/pkg/limiter" "net" "net/http" "strings" "sync" ) // IPPathRateLimitMiddleware 根据 IP + 路径创建限流中间件 func IPPathRateLimitMiddleware(cfg *limiter.RateLimitConfig) gin.HandlerFunc { buckets := sync.Map{} return func(c *gin.Context) { ip := getClientIP(c.Request) path := c.FullPath() if path == "" { path = c.Request.URL.Path // 如果 FullPath 为空(未使用路由组),则使用原始路径 } key := ip + ":" + path val, _ := buckets.LoadOrStore(key, limiter.NewTokenBucket(cfg.Capacity, cfg.FillRate)) tb := val.(*limiter.TokenBucket) if tb.Allow() { c.Next() } else { c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{ "code": 429, "msg": "请求太频繁,请稍后再试", }) } } } // IPRateLimitMiddleware 仅根据 IP 创建限流中间件 func IPRateLimitMiddleware(cfg *limiter.RateLimitConfig) gin.HandlerFunc { buckets := sync.Map{} return func(c *gin.Context) { ip := getClientIP(c.Request) key := ip val, _ := buckets.LoadOrStore(key, limiter.NewTokenBucket(cfg.Capacity, cfg.FillRate)) tb := val.(*limiter.TokenBucket) if tb.Allow() { c.Next() } else { c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{ "code": 429, "msg": "请求太频繁,请稍后再试", }) } } } // PathRateLimitMiddleware 仅根据路径创建限流中间件 func PathRateLimitMiddleware(cfg *limiter.RateLimitConfig) gin.HandlerFunc { buckets := sync.Map{} return func(c *gin.Context) { path := c.FullPath() if path == "" { path = c.Request.URL.Path // 如果 FullPath 为空(未使用路由组),则使用原始路径 } key := path val, _ := buckets.LoadOrStore(key, limiter.NewTokenBucket(cfg.Capacity, cfg.FillRate)) tb := val.(*limiter.TokenBucket) if tb.Allow() { c.Next() } else { c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{ "code": 429, "msg": "请求太频繁,请稍后再试", }) } } } // getClientIP 获取客户端 IP 地址 func getClientIP(r *http.Request) string { ip := r.Header.Get("X-Forwarded-For") if ip != "" { parts := strings.Split(ip, ",") return strings.TrimSpace(parts[0]) } ip, _, _ = net.SplitHostPort(r.RemoteAddr) return ip } // 提供全局默认限流中间件,用于Wire依赖注入 
func NewRateLimitMiddleware(limiterInstance *limiter.Limiter) gin.HandlerFunc {
	// Default policy: one token bucket per (client IP, route) pair, sized by
	// the injected limiter's configuration.
	cfg := limiterInstance.Config
	return IPPathRateLimitMiddleware(cfg)
}