ip_whitelist.go 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. package middleware
  2. import (
  3. "github.com/gin-gonic/gin"
  4. "github.com/go-nunu/nunu-layout-advanced/pkg/log"
  5. "github.com/spf13/viper"
  6. "go.uber.org/zap"
  7. "net/http"
  8. "sync"
  9. )
  10. // IPAllowlist 保存允许访问的IP列表
  11. type IPAllowlist struct {
  12. allowedIPs map[string]bool
  13. enabled bool
  14. mu sync.RWMutex
  15. logger *log.Logger
  16. }
  17. // NewIPAllowlist 创建一个新的IP白名单实例
  18. func NewIPAllowlist(conf *viper.Viper, logger *log.Logger) *IPAllowlist {
  19. allowedIPs := conf.GetStringSlice("ip_allowlist.ips")
  20. enabled := conf.GetBool("ip_allowlist.enabled")
  21. allowlist := &IPAllowlist{
  22. allowedIPs: make(map[string]bool),
  23. enabled: enabled,
  24. logger: logger,
  25. }
  26. // 将配置文件中的IP添加到白名单
  27. for _, ip := range allowedIPs {
  28. allowlist.allowedIPs[ip] = true
  29. }
  30. return allowlist
  31. }
  32. // IPAllowlistMiddleware 创建一个IP白名单中间件
  33. func (a *IPAllowlist) IPAllowlistMiddleware() gin.HandlerFunc {
  34. return func(c *gin.Context) {
  35. // 如果白名单未启用,直接放行
  36. if !a.enabled {
  37. c.Next()
  38. return
  39. }
  40. clientIP := getClientIP(c.Request)
  41. a.mu.RLock()
  42. _, allowed := a.allowedIPs[clientIP]
  43. a.mu.RUnlock()
  44. if allowed {
  45. c.Next()
  46. } else {
  47. a.logger.WithContext(c).Warn("拒绝未授权的IP访问: %s", zap.String("ip", clientIP))
  48. c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
  49. "code": 403,
  50. "msg": "IP访问受限,您的IP没有权限访问此资源",
  51. })
  52. }
  53. }
  54. }
  55. // AddIP 添加IP到白名单
  56. func (a *IPAllowlist) AddIP(ip string) {
  57. a.mu.Lock()
  58. defer a.mu.Unlock()
  59. a.allowedIPs[ip] = true
  60. }
  61. // RemoveIP 从白名单中移除IP
  62. func (a *IPAllowlist) RemoveIP(ip string) {
  63. a.mu.Lock()
  64. defer a.mu.Unlock()
  65. delete(a.allowedIPs, ip)
  66. }
  67. // IsIPAllowed 检查IP是否在白名单中
  68. func (a *IPAllowlist) IsIPAllowed(ip string) bool {
  69. a.mu.RLock()
  70. defer a.mu.RUnlock()
  71. return a.allowedIPs[ip]
  72. }
  73. // EnableAllowlist 启用IP白名单
  74. func (a *IPAllowlist) EnableAllowlist() {
  75. a.mu.Lock()
  76. defer a.mu.Unlock()
  77. a.enabled = true
  78. }
  79. // DisableAllowlist 禁用IP白名单
  80. func (a *IPAllowlist) DisableAllowlist() {
  81. a.mu.Lock()
  82. defer a.mu.Unlock()
  83. a.enabled = false
  84. }
  85. // IsEnabled 检查白名单是否启用
  86. func (a *IPAllowlist) IsEnabled() bool {
  87. a.mu.RLock()
  88. defer a.mu.RUnlock()
  89. return a.enabled
  90. }