limiter.go 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. package limiter
  2. import (
  3. "github.com/spf13/viper"
  4. "sync"
  5. "time"
  6. )
  7. // RateLimitConfig 限流配置
  8. type RateLimitConfig struct {
  9. Capacity int // 桶容量
  10. FillRate int // 每秒填充速率
  11. }
  12. // NewRateLimitConfig 创建新的限流配置
  13. func NewRateLimitConfig(capacity, fillRate int) *RateLimitConfig {
  14. return &RateLimitConfig{
  15. Capacity: capacity,
  16. FillRate: fillRate,
  17. }
  18. }
  19. // TokenBucket 令牌桶实现
  20. type TokenBucket struct {
  21. capacity int
  22. tokens int
  23. fillInterval time.Duration
  24. mutex sync.Mutex
  25. }
  26. // NewTokenBucket 创建新的令牌桶
  27. func NewTokenBucket(capacity int, fillRate int) *TokenBucket {
  28. tb := &TokenBucket{
  29. capacity: capacity,
  30. tokens: capacity,
  31. fillInterval: time.Second / time.Duration(fillRate),
  32. }
  33. go func() {
  34. ticker := time.NewTicker(tb.fillInterval)
  35. for range ticker.C {
  36. tb.mutex.Lock()
  37. if tb.tokens < tb.capacity {
  38. tb.tokens++
  39. }
  40. tb.mutex.Unlock()
  41. }
  42. }()
  43. return tb
  44. }
  45. // Allow 消耗令牌,返回是否允许请求通过
  46. func (tb *TokenBucket) Allow() bool {
  47. tb.mutex.Lock()
  48. defer tb.mutex.Unlock()
  49. if tb.tokens > 0 {
  50. tb.tokens--
  51. return true
  52. }
  53. return false
  54. }
  55. // Limiter 限流器实例
  56. type Limiter struct {
  57. Config *RateLimitConfig
  58. conf *viper.Viper // 保存配置对象便于查询特定API的限流设置
  59. }
  60. // NewLimiter 创建限流器实例,用于Wire依赖注入
  61. func NewLimiter(conf *viper.Viper) *Limiter {
  62. // 从配置文件读取限流设置,如果没有则使用默认值
  63. capacity := conf.GetInt("limiter.capacity")
  64. fillRate := conf.GetInt("limiter.fillRate")
  65. if capacity <= 0 {
  66. capacity = 100 // 默认容量
  67. }
  68. if fillRate <= 0 {
  69. fillRate = 10 // 默认每秒填充速率
  70. }
  71. return &Limiter{
  72. Config: NewRateLimitConfig(capacity, fillRate),
  73. conf: conf,
  74. }
  75. }
  76. // GetAPIConfig 获取特定API的限流配置
  77. // 如果配置文件中不存在特定API的配置,则返回默认的全局配置
  78. func (l *Limiter) GetAPIConfig(apiName string) *RateLimitConfig {
  79. // 尝试从配置文件中读取特定API的配置
  80. capacity := l.conf.GetInt("limiter.api." + apiName + ".capacity")
  81. fillRate := l.conf.GetInt("limiter.api." + apiName + ".fillRate")
  82. // 如果配置有效,返回特定API的配置
  83. if capacity > 0 && fillRate > 0 {
  84. return NewRateLimitConfig(capacity, fillRate)
  85. }
  86. // 如果没有特定配置,返回全局配置
  87. return l.Config
  88. }