package limiter import ( "github.com/spf13/viper" "sync" "time" ) // RateLimitConfig 限流配置 type RateLimitConfig struct { Capacity int // 桶容量 FillRate int // 每秒填充速率 } // NewRateLimitConfig 创建新的限流配置 func NewRateLimitConfig(capacity, fillRate int) *RateLimitConfig { return &RateLimitConfig{ Capacity: capacity, FillRate: fillRate, } } // TokenBucket 令牌桶实现 type TokenBucket struct { capacity int tokens int fillInterval time.Duration mutex sync.Mutex } // NewTokenBucket 创建新的令牌桶 func NewTokenBucket(capacity int, fillRate int) *TokenBucket { tb := &TokenBucket{ capacity: capacity, tokens: capacity, fillInterval: time.Second / time.Duration(fillRate), } go func() { ticker := time.NewTicker(tb.fillInterval) for range ticker.C { tb.mutex.Lock() if tb.tokens < tb.capacity { tb.tokens++ } tb.mutex.Unlock() } }() return tb } // Allow 消耗令牌,返回是否允许请求通过 func (tb *TokenBucket) Allow() bool { tb.mutex.Lock() defer tb.mutex.Unlock() if tb.tokens > 0 { tb.tokens-- return true } return false } // Limiter 限流器实例 type Limiter struct { Config *RateLimitConfig conf *viper.Viper // 保存配置对象便于查询特定API的限流设置 } // NewLimiter 创建限流器实例,用于Wire依赖注入 func NewLimiter(conf *viper.Viper) *Limiter { // 从配置文件读取限流设置,如果没有则使用默认值 capacity := conf.GetInt("limiter.capacity") fillRate := conf.GetInt("limiter.fillRate") if capacity <= 0 { capacity = 100 // 默认容量 } if fillRate <= 0 { fillRate = 10 // 默认每秒填充速率 } return &Limiter{ Config: NewRateLimitConfig(capacity, fillRate), conf: conf, } } // GetAPIConfig 获取特定API的限流配置 // 如果配置文件中不存在特定API的配置,则返回默认的全局配置 func (l *Limiter) GetAPIConfig(apiName string) *RateLimitConfig { // 尝试从配置文件中读取特定API的配置 capacity := l.conf.GetInt("limiter.api." + apiName + ".capacity") fillRate := l.conf.GetInt("limiter.api." + apiName + ".fillRate") // 如果配置有效,返回特定API的配置 if capacity > 0 && fillRate > 0 { return NewRateLimitConfig(capacity, fillRate) } // 如果没有特定配置,返回全局配置 return l.Config }