123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899 |
- package middleware
- import (
- "github.com/gin-gonic/gin"
- "github.com/go-nunu/nunu-layout-advanced/pkg/limiter"
- "net"
- "net/http"
- "strings"
- "sync"
- )
- // IPPathRateLimitMiddleware 根据 IP + 路径创建限流中间件
- func IPPathRateLimitMiddleware(cfg *limiter.RateLimitConfig) gin.HandlerFunc {
- buckets := sync.Map{}
- return func(c *gin.Context) {
- ip := getClientIP(c.Request)
- path := c.FullPath()
- if path == "" {
- path = c.Request.URL.Path // 如果 FullPath 为空(未使用路由组),则使用原始路径
- }
- key := ip + ":" + path
- val, _ := buckets.LoadOrStore(key, limiter.NewTokenBucket(cfg.Capacity, cfg.FillRate))
- tb := val.(*limiter.TokenBucket)
- if tb.Allow() {
- c.Next()
- } else {
- c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
- "code": 429,
- "msg": "请求太频繁,请稍后再试",
- })
- }
- }
- }
- // IPRateLimitMiddleware 仅根据 IP 创建限流中间件
- func IPRateLimitMiddleware(cfg *limiter.RateLimitConfig) gin.HandlerFunc {
- buckets := sync.Map{}
- return func(c *gin.Context) {
- ip := getClientIP(c.Request)
- key := ip
- val, _ := buckets.LoadOrStore(key, limiter.NewTokenBucket(cfg.Capacity, cfg.FillRate))
- tb := val.(*limiter.TokenBucket)
- if tb.Allow() {
- c.Next()
- } else {
- c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
- "code": 429,
- "msg": "请求太频繁,请稍后再试",
- })
- }
- }
- }
- // PathRateLimitMiddleware 仅根据路径创建限流中间件
- func PathRateLimitMiddleware(cfg *limiter.RateLimitConfig) gin.HandlerFunc {
- buckets := sync.Map{}
- return func(c *gin.Context) {
- path := c.FullPath()
- if path == "" {
- path = c.Request.URL.Path // 如果 FullPath 为空(未使用路由组),则使用原始路径
- }
- key := path
- val, _ := buckets.LoadOrStore(key, limiter.NewTokenBucket(cfg.Capacity, cfg.FillRate))
- tb := val.(*limiter.TokenBucket)
- if tb.Allow() {
- c.Next()
- } else {
- c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
- "code": 429,
- "msg": "请求太频繁,请稍后再试",
- })
- }
- }
- }
// getClientIP returns the best-effort client address for r.
//
// The first non-empty entry of the X-Forwarded-For header wins. Otherwise the
// host part of r.RemoteAddr is returned. If RemoteAddr cannot be split into
// host and port (for example it carries no port), RemoteAddr is returned
// unchanged, so such clients are not all collapsed onto one empty key.
//
// NOTE(review): X-Forwarded-For is supplied by the client unless a trusted
// proxy overwrites it, so callers using the result for rate limiting should
// confirm the service only receives traffic through such a proxy.
func getClientIP(r *http.Request) string {
	if forwarded := r.Header.Get("X-Forwarded-For"); forwarded != "" {
		// The header is a comma-separated list; only the first entry is used.
		first := forwarded
		if i := strings.IndexByte(forwarded, ','); i >= 0 {
			first = forwarded[:i]
		}
		// A blank first entry (e.g. " , 1.2.3.4") is ignored rather than
		// returned as an empty address.
		if first = strings.TrimSpace(first); first != "" {
			return first
		}
	}

	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		// RemoteAddr is not in "host:port" form; use it as-is.
		return r.RemoteAddr
	}
	return host
}
- // 提供全局默认限流中间件,用于Wire依赖注入
- func NewRateLimitMiddleware(limiterInstance *limiter.Limiter) gin.HandlerFunc {
- return IPPathRateLimitMiddleware(limiterInstance.Config)
- }
|