123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105 |
- package middleware
- import (
- "net/http"
- "sync"
- "time"
- "github.com/gin-gonic/gin"
- "github.com/spf13/viper"
- "go.uber.org/zap"
- "golang.org/x/time/rate"
- "projectName/pkg/log"
- )
- // visitor 包含一个速率限制器和最后一次被看到的时间
- type visitor struct {
- limiter *rate.Limiter
- lastSeen time.Time
- }
- // IPRateLimiter 包含速率限制器的配置和访问者列表
- type IPRateLimiter struct {
- visitors map[string]*visitor
- enabled bool
- rate rate.Limit
- burst int
- mu sync.Mutex
- logger *log.Logger
- }
- // NewIPRateLimiter 创建一个新的IP速率限制器实例
- func NewIPRateLimiter(conf *viper.Viper, logger *log.Logger) *IPRateLimiter {
- enabled := conf.GetBool("rate_limit.enabled")
- r := conf.GetFloat64("rate_limit.rate")
- b := conf.GetInt("rate_limit.burst")
- limiter := &IPRateLimiter{
- visitors: make(map[string]*visitor),
- enabled: enabled,
- rate: rate.Limit(r),
- burst: b,
- logger: logger,
- }
- // 启动一个后台goroutine来清理旧的条目
- go limiter.cleanupVisitors()
- return limiter
- }
- // getVisitor 返回给定IP的速率限制器,如果不存在则创建一个新的
- func (l *IPRateLimiter) getVisitor(ip string) *rate.Limiter {
- l.mu.Lock()
- defer l.mu.Unlock()
- v, exists := l.visitors[ip]
- if !exists {
- limiter := rate.NewLimiter(l.rate, l.burst)
- v = &visitor{limiter, time.Now()}
- l.visitors[ip] = v
- }
- // 更新最后一次看到的时间
- v.lastSeen = time.Now()
- return v.limiter
- }
- // cleanupVisitors 定期清理长时间未活动的访问者条目
- func (l *IPRateLimiter) cleanupVisitors() {
- for {
- time.Sleep(1 * time.Minute)
- l.mu.Lock()
- for ip, v := range l.visitors {
- if time.Since(v.lastSeen) > 3*time.Minute {
- delete(l.visitors, ip)
- }
- }
- l.mu.Unlock()
- }
- }
- // RateLimitMiddleware 创建一个IP速率限制中间件
- func (l *IPRateLimiter) RateLimitMiddleware() gin.HandlerFunc {
- return func(c *gin.Context) {
- if !l.enabled {
- c.Next()
- return
- }
- ip := getClientIP(c.Request)
- limiter := l.getVisitor(ip)
- if !limiter.Allow() {
- l.logger.WithContext(c).Warn("IP请求过于频繁,已限流", zap.String("ip", ip))
- c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
- "code": 429,
- "msg": "Too Many Requests",
- })
- return
- }
- c.Next()
- }
- }
|