ip_ratelimit.go 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. package middleware
  2. import (
  3. "net/http"
  4. "sync"
  5. "time"
  6. "github.com/gin-gonic/gin"
  7. "github.com/spf13/viper"
  8. "go.uber.org/zap"
  9. "golang.org/x/time/rate"
  10. "projectName/pkg/log"
  11. )
  12. // visitor 包含一个速率限制器和最后一次被看到的时间
  13. type visitor struct {
  14. limiter *rate.Limiter
  15. lastSeen time.Time
  16. }
  17. // IPRateLimiter 包含速率限制器的配置和访问者列表
  18. type IPRateLimiter struct {
  19. visitors map[string]*visitor
  20. enabled bool
  21. rate rate.Limit
  22. burst int
  23. mu sync.Mutex
  24. logger *log.Logger
  25. }
  26. // NewIPRateLimiter 创建一个新的IP速率限制器实例
  27. func NewIPRateLimiter(conf *viper.Viper, logger *log.Logger) *IPRateLimiter {
  28. enabled := conf.GetBool("rate_limit.enabled")
  29. r := conf.GetFloat64("rate_limit.rate")
  30. b := conf.GetInt("rate_limit.burst")
  31. limiter := &IPRateLimiter{
  32. visitors: make(map[string]*visitor),
  33. enabled: enabled,
  34. rate: rate.Limit(r),
  35. burst: b,
  36. logger: logger,
  37. }
  38. // 启动一个后台goroutine来清理旧的条目
  39. go limiter.cleanupVisitors()
  40. return limiter
  41. }
  42. // getVisitor 返回给定IP的速率限制器,如果不存在则创建一个新的
  43. func (l *IPRateLimiter) getVisitor(ip string) *rate.Limiter {
  44. l.mu.Lock()
  45. defer l.mu.Unlock()
  46. v, exists := l.visitors[ip]
  47. if !exists {
  48. limiter := rate.NewLimiter(l.rate, l.burst)
  49. v = &visitor{limiter, time.Now()}
  50. l.visitors[ip] = v
  51. }
  52. // 更新最后一次看到的时间
  53. v.lastSeen = time.Now()
  54. return v.limiter
  55. }
  56. // cleanupVisitors 定期清理长时间未活动的访问者条目
  57. func (l *IPRateLimiter) cleanupVisitors() {
  58. for {
  59. time.Sleep(1 * time.Minute)
  60. l.mu.Lock()
  61. for ip, v := range l.visitors {
  62. if time.Since(v.lastSeen) > 3*time.Minute {
  63. delete(l.visitors, ip)
  64. }
  65. }
  66. l.mu.Unlock()
  67. }
  68. }
  69. // RateLimitMiddleware 创建一个IP速率限制中间件
  70. func (l *IPRateLimiter) RateLimitMiddleware() gin.HandlerFunc {
  71. return func(c *gin.Context) {
  72. if !l.enabled {
  73. c.Next()
  74. return
  75. }
  76. ip := getClientIP(c.Request)
  77. limiter := l.getVisitor(ip)
  78. if !limiter.Allow() {
  79. l.logger.WithContext(c).Warn("IP请求过于频繁,已限流", zap.String("ip", ip))
  80. c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
  81. "code": 429,
  82. "msg": "Too Many Requests",
  83. })
  84. return
  85. }
  86. c.Next()
  87. }
  88. }