limiter.go 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. package middleware
  2. import (
  3. "github.com/gin-gonic/gin"
  4. "github.com/go-nunu/nunu-layout-advanced/pkg/limiter"
  5. "net"
  6. "net/http"
  7. "strings"
  8. "sync"
  9. )
  10. // IPPathRateLimitMiddleware 根据 IP + 路径创建限流中间件
  11. func IPPathRateLimitMiddleware(cfg *limiter.RateLimitConfig) gin.HandlerFunc {
  12. buckets := sync.Map{}
  13. return func(c *gin.Context) {
  14. ip := getClientIP(c.Request)
  15. path := c.FullPath()
  16. if path == "" {
  17. path = c.Request.URL.Path // 如果 FullPath 为空(未使用路由组),则使用原始路径
  18. }
  19. key := ip + ":" + path
  20. val, _ := buckets.LoadOrStore(key, limiter.NewTokenBucket(cfg.Capacity, cfg.FillRate))
  21. tb := val.(*limiter.TokenBucket)
  22. if tb.Allow() {
  23. c.Next()
  24. } else {
  25. c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
  26. "code": 429,
  27. "msg": "请求太频繁,请稍后再试",
  28. })
  29. }
  30. }
  31. }
  32. // IPRateLimitMiddleware 仅根据 IP 创建限流中间件
  33. func IPRateLimitMiddleware(cfg *limiter.RateLimitConfig) gin.HandlerFunc {
  34. buckets := sync.Map{}
  35. return func(c *gin.Context) {
  36. ip := getClientIP(c.Request)
  37. key := ip
  38. val, _ := buckets.LoadOrStore(key, limiter.NewTokenBucket(cfg.Capacity, cfg.FillRate))
  39. tb := val.(*limiter.TokenBucket)
  40. if tb.Allow() {
  41. c.Next()
  42. } else {
  43. c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
  44. "code": 429,
  45. "msg": "请求太频繁,请稍后再试",
  46. })
  47. }
  48. }
  49. }
  50. // PathRateLimitMiddleware 仅根据路径创建限流中间件
  51. func PathRateLimitMiddleware(cfg *limiter.RateLimitConfig) gin.HandlerFunc {
  52. buckets := sync.Map{}
  53. return func(c *gin.Context) {
  54. path := c.FullPath()
  55. if path == "" {
  56. path = c.Request.URL.Path // 如果 FullPath 为空(未使用路由组),则使用原始路径
  57. }
  58. key := path
  59. val, _ := buckets.LoadOrStore(key, limiter.NewTokenBucket(cfg.Capacity, cfg.FillRate))
  60. tb := val.(*limiter.TokenBucket)
  61. if tb.Allow() {
  62. c.Next()
  63. } else {
  64. c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
  65. "code": 429,
  66. "msg": "请求太频繁,请稍后再试",
  67. })
  68. }
  69. }
  70. }
  // getClientIP returns the client IP address used as a rate-limit key.
//
// It prefers the first (left-most) entry of the X-Forwarded-For header, but
// only when that entry parses as an IP address. Otherwise it falls back to
// the host part of r.RemoteAddr, or RemoteAddr verbatim if it has no port.
// Parseable addresses are returned in net.IP's canonical string form, so
// different spellings of one address (e.g. "::ffff:1.2.3.4" vs "1.2.3.4")
// map to the same key.
//
// NOTE(review): X-Forwarded-For is client-supplied unless a trusted proxy
// overwrites it. A caller can therefore still pick its own key by sending a
// forged but well-formed address. If the service runs behind known proxies,
// consider gin's Context.ClientIP together with Engine.SetTrustedProxies.
func getClientIP(r *http.Request) string {
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		first := xff
		if i := strings.IndexByte(xff, ','); i >= 0 {
			first = xff[:i]
		}
		// Ignore malformed or empty entries rather than using arbitrary
		// header text as a bucket key.
		if ip := net.ParseIP(strings.TrimSpace(first)); ip != nil {
			return ip.String()
		}
	}

	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		// RemoteAddr without a port: use it as-is instead of returning "",
		// which would put every such client into one shared bucket.
		host = r.RemoteAddr
	}
	if ip := net.ParseIP(host); ip != nil {
		return ip.String()
	}
	return host
}
  81. // 提供全局默认限流中间件,用于Wire依赖注入
  82. func NewRateLimitMiddleware(limiterInstance *limiter.Limiter) gin.HandlerFunc {
  83. return IPPathRateLimitMiddleware(limiterInstance.Config)
  84. }