ip_whitelist.go 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. package middleware
  2. import (
  3. "github.com/gin-gonic/gin"
  4. "github.com/spf13/viper"
  5. "go.uber.org/zap"
  6. "net"
  7. "net/http"
  8. "projectName/pkg/log"
  9. "strings"
  10. "sync"
  11. )
  12. // IPAllowlist 保存允许访问的IP列表
  13. type IPAllowlist struct {
  14. allowedIPs map[string]bool
  15. enabled bool
  16. mu sync.RWMutex
  17. logger *log.Logger
  18. }
  // getClientIP returns the client IP address for r.
  //
  // It prefers the left-most entry of the X-Forwarded-For header and otherwise
  // falls back to the host part of r.RemoteAddr. X-Forwarded-For can be set by
  // the client itself unless a trusted proxy overwrites it, so the result must
  // not be used on its own for access-control decisions.
  //
  // A blank left-most X-Forwarded-For entry is ignored rather than returned.
  // If r.RemoteAddr carries no port, it is returned verbatim instead of "".
  func getClientIP(r *http.Request) string {
  	if forwarded := r.Header.Get("X-Forwarded-For"); forwarded != "" {
  		first := strings.TrimSpace(strings.SplitN(forwarded, ",", 2)[0])
  		if first != "" {
  			return first
  		}
  	}
  	host, _, err := net.SplitHostPort(r.RemoteAddr)
  	if err != nil {
  		// RemoteAddr without a port (e.g. "10.0.0.1" or "::1"): SplitHostPort
  		// rejects it, but the whole value is already the address.
  		return r.RemoteAddr
  	}
  	return host
  }
  28. // NewIPAllowlist 创建一个新的IP白名单实例
  29. func NewIPAllowlist(conf *viper.Viper, logger *log.Logger) *IPAllowlist {
  30. allowedIPs := conf.GetStringSlice("ip_allowlist.ips")
  31. enabled := conf.GetBool("ip_allowlist.enabled")
  32. allowlist := &IPAllowlist{
  33. allowedIPs: make(map[string]bool),
  34. enabled: enabled,
  35. logger: logger,
  36. }
  37. // 将配置文件中的IP添加到白名单
  38. for _, ip := range allowedIPs {
  39. allowlist.allowedIPs[ip] = true
  40. }
  41. return allowlist
  42. }
  43. // IPAllowlistMiddleware 创建一个IP白名单中间件
  44. func (a *IPAllowlist) IPAllowlistMiddleware() gin.HandlerFunc {
  45. return func(c *gin.Context) {
  46. // 如果白名单未启用,直接放行
  47. if !a.enabled {
  48. c.Next()
  49. return
  50. }
  51. clientIP := getClientIP(c.Request)
  52. a.mu.RLock()
  53. _, allowed := a.allowedIPs[clientIP]
  54. a.mu.RUnlock()
  55. if allowed {
  56. c.Next()
  57. } else {
  58. a.logger.WithContext(c).Warn("拒绝未授权的IP访问: %s", zap.String("ip", clientIP))
  59. c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
  60. "code": 403,
  61. "msg": "IP访问受限,您的IP没有权限访问此资源",
  62. })
  63. }
  64. }
  65. }
  66. // AddIP 添加IP到白名单
  67. func (a *IPAllowlist) AddIP(ip string) {
  68. a.mu.Lock()
  69. defer a.mu.Unlock()
  70. a.allowedIPs[ip] = true
  71. }
  72. // RemoveIP 从白名单中移除IP
  73. func (a *IPAllowlist) RemoveIP(ip string) {
  74. a.mu.Lock()
  75. defer a.mu.Unlock()
  76. delete(a.allowedIPs, ip)
  77. }
  78. // IsIPAllowed 检查IP是否在白名单中
  79. func (a *IPAllowlist) IsIPAllowed(ip string) bool {
  80. a.mu.RLock()
  81. defer a.mu.RUnlock()
  82. return a.allowedIPs[ip]
  83. }
  84. // EnableAllowlist 启用IP白名单
  85. func (a *IPAllowlist) EnableAllowlist() {
  86. a.mu.Lock()
  87. defer a.mu.Unlock()
  88. a.enabled = true
  89. }
  90. // DisableAllowlist 禁用IP白名单
  91. func (a *IPAllowlist) DisableAllowlist() {
  92. a.mu.Lock()
  93. defer a.mu.Unlock()
  94. a.enabled = false
  95. }
  96. // IsEnabled 检查白名单是否启用
  97. func (a *IPAllowlist) IsEnabled() bool {
  98. a.mu.RLock()
  99. defer a.mu.RUnlock()
  100. return a.enabled
  101. }