whitelist.go 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263
  1. package job
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "github.com/go-nunu/nunu-layout-advanced/internal/service"
  7. "github.com/google/uuid"
  8. "github.com/rabbitmq/amqp091-go"
  9. "go.uber.org/zap"
  10. "strconv"
  11. "strings"
  12. "sync"
  13. )
  14. // taskHandler 定义了处理单个消息的函数签名
  15. // 它负责业务逻辑的执行,并返回一个 error 来告知调用者处理是否成功。
  16. type taskHandler func(ctx context.Context, logger *zap.Logger, delivery amqp091.Delivery) error
  17. // WhitelistJob 定义了处理白名单相关任务的接口
  18. type WhitelistJob interface {
  19. // DomainConsumer 启动消费者,处理域名白名单任务
  20. DomainConsumer(ctx context.Context)
  21. // IpConsumer 启动消费者,处理 IP 白名单任务
  22. IpConsumer(ctx context.Context)
  23. }
  24. // NewWhitelistJob 创建一个新的 WhitelistJob
  25. func NewWhitelistJob(job *Job,
  26. aoDunService service.AoDunService,
  27. wafForMatter service.WafFormatterService,
  28. ) WhitelistJob {
  29. return &whitelistJob{
  30. Job: job,
  31. aoDunService: aoDunService,
  32. wafForMatter: wafForMatter,
  33. }
  34. }
  35. type whitelistJob struct {
  36. *Job
  37. aoDunService service.AoDunService
  38. wafForMatter service.WafFormatterService
  39. }
  40. // DomainConsumer 启动域名白名单消费者
  41. func (j *whitelistJob) DomainConsumer(ctx context.Context) {
  42. j.consume(ctx, "domain_whitelist", "domain_whitelist_consumer", j.handleDomainMessage)
  43. }
  44. // IpConsumer 启动IP白名单消费者
  45. func (j *whitelistJob) IpConsumer(ctx context.Context) {
  46. j.consume(ctx, "ip_white", "ip_white_consumer", j.handleIpMessage)
  47. }
  48. // consume 是一个通用的 RabbitMQ 消费者方法,封装了重复的逻辑
  49. func (j *whitelistJob) consume(ctx context.Context, taskName, consumerName string, handler taskHandler) {
  50. taskCfg, ok := j.Rabbitmq.GetTaskConfig(taskName)
  51. if !ok {
  52. j.logger.Error(fmt.Sprintf("未找到任务 '%s' 的配置", taskName))
  53. return
  54. }
  55. j.logger.Info("正在启动消费者...",
  56. zap.String("task", taskName),
  57. zap.String("queue", taskCfg.Queue),
  58. zap.String("consumer", consumerName),
  59. )
  60. msgs, err := j.Rabbitmq.Consume(taskCfg.Queue, consumerName, taskCfg.PrefetchCount)
  61. if err != nil {
  62. j.logger.Error("启动消费者失败", zap.String("task", taskName), zap.Error(err))
  63. return
  64. }
  65. for {
  66. select {
  67. case <-ctx.Done():
  68. j.logger.Info("消费者正在关闭...", zap.String("task", taskName))
  69. return
  70. case d, ok := <-msgs:
  71. if !ok {
  72. j.logger.Warn("消息通道已关闭,消费者退出。", zap.String("task", taskName))
  73. return
  74. }
  75. // 尝试从消息头获取 trace_id,如果不存在则生成一个新的
  76. traceID, ok := d.Headers["trace_id"].(string)
  77. if !ok || traceID == "" {
  78. traceID = uuid.New().String()
  79. }
  80. // 创建一个带有 trace_id 的 logger,用于本次任务的所有日志记录
  81. scopedLogger := j.logger.With(zap.String("trace_id", traceID))
  82. // 创建一个带有 trace_id 的 context,用于传递给下游服务
  83. ctxWithTrace := context.WithValue(ctx, "trace_id", traceID)
  84. // 调用具体的业务处理器
  85. processingErr := handler(ctxWithTrace, scopedLogger, d)
  86. // 根据处理结果统一进行 Ack/Nack
  87. if processingErr != nil {
  88. // 业务失败,拒绝消息并不重新入队
  89. if err := d.Nack(false, false); err != nil {
  90. scopedLogger.Error("消息 Nack 失败", zap.Error(err), zap.String("task", taskName))
  91. }
  92. } else {
  93. // 业务处理成功,手动发送确认
  94. if err := d.Ack(false); err != nil {
  95. scopedLogger.Error("消息 Ack 失败", zap.Error(err), zap.String("task", taskName))
  96. }
  97. }
  98. }
  99. }
  100. }
  101. // handleDomainMessage 是域名白名单任务的具体处理器
  102. func (j *whitelistJob) handleDomainMessage(ctx context.Context, logger *zap.Logger, d amqp091.Delivery) error {
  103. type domainTaskPayload struct {
  104. Domain string `json:"domain"`
  105. Ip string `json:"ip"`
  106. Action string `json:"action"` // "add" or "del"
  107. }
  108. var payload domainTaskPayload
  109. if err := json.Unmarshal(d.Body, &payload); err != nil {
  110. logger.Error("解析域名白名单消息失败", zap.Error(err), zap.ByteString("body", d.Body))
  111. return nil // 返回 nil 以避免消息重入队列,因为这是一个格式错误
  112. }
  113. logger.Info("收到域名白名单任务",
  114. zap.String("action", payload.Action),
  115. zap.String("domain", payload.Domain),
  116. zap.String("ip", payload.Ip),
  117. zap.String("routing_key", d.RoutingKey),
  118. )
  119. var processingErr error
  120. switch payload.Action {
  121. case "add", "del":
  122. processingErr = j.aoDunService.DomainWhiteList(ctx, payload.Domain, payload.Ip, payload.Action)
  123. default:
  124. processingErr = fmt.Errorf("unknown action: %s", payload.Action)
  125. logger.Warn("在域名白名单任务中收到未知操作", zap.String("action", payload.Action), zap.String("domain", payload.Domain))
  126. }
  127. if processingErr != nil {
  128. logger.Error("处理域名白名单任务失败", zap.Error(processingErr), zap.String("domain", payload.Domain))
  129. } else {
  130. logger.Info("已成功处理域名白名单任务", zap.String("action", payload.Action), zap.String("domain", payload.Domain))
  131. }
  132. return processingErr
  133. }
  134. // handleIpMessage 是 IP 白名单任务的具体处理器
  135. func (j *whitelistJob) handleIpMessage(ctx context.Context, logger *zap.Logger, d amqp091.Delivery) error {
  136. type ipTaskPayload struct {
  137. Ips []string `json:"ips"`
  138. Action string `json:"action"`
  139. Color string `json:"color"`
  140. ReturnSourceIp string `json:"return_source_ip"`
  141. }
  142. var payload ipTaskPayload
  143. if err := json.Unmarshal(d.Body, &payload); err != nil {
  144. logger.Error("解析IP白名单消息失败", zap.Error(err), zap.ByteString("body", d.Body), zap.String("routing_key", d.RoutingKey))
  145. return nil // 消息格式错误,不应重试
  146. }
  147. logger.Info("收到IP白名单任务",
  148. zap.String("action", payload.Action),
  149. zap.Any("ips", payload.Ips),
  150. zap.String("color", payload.Color),
  151. zap.String("routing_key", d.RoutingKey),
  152. )
  153. var processingErr error
  154. switch payload.Action {
  155. case "add":
  156. ips, err := j.wafForMatter.AppendWafIp(ctx, payload.Ips, payload.ReturnSourceIp)
  157. if err != nil {
  158. // 如果附加IP失败,记录错误并终止
  159. processingErr = fmt.Errorf("为WAF准备IP列表失败: %w", err)
  160. } else {
  161. var wg sync.WaitGroup
  162. errChan := make(chan error, 2)
  163. wg.Add(2)
  164. go func() {
  165. defer wg.Done()
  166. if err := j.aoDunService.AddWhiteStaticList(ctx, false, ips, payload.Color); err != nil {
  167. errChan <- err
  168. }
  169. }()
  170. go func() {
  171. defer wg.Done()
  172. if err := j.aoDunService.AddWhiteStaticList(ctx, true, ips,payload.Color); err != nil {
  173. errChan <- err
  174. }
  175. }()
  176. wg.Wait()
  177. close(errChan)
  178. var errs []string
  179. for err := range errChan {
  180. errs = append(errs, err.Error())
  181. }
  182. if len(errs) > 0 {
  183. processingErr = fmt.Errorf("添加IP到白名单时发生错误: %s", strings.Join(errs, "; "))
  184. }
  185. }
  186. case "del":
  187. var wg sync.WaitGroup
  188. errChan := make(chan error, len(payload.Ips)*2)
  189. deleteFromWall := func(isSmall bool, ip string) {
  190. defer wg.Done()
  191. id, err := j.aoDunService.GetWhiteStaticList(ctx, isSmall, ip, payload.ReturnSourceIp,payload.Color)
  192. if err != nil {
  193. errChan <- fmt.Errorf("获取IP '%s' (isSmall: %t) ID失败: %w , color: %s", ip, isSmall, err, payload.Color)
  194. return
  195. }
  196. if err := j.aoDunService.DelWhiteStaticList(ctx, isSmall, strconv.Itoa(id), payload.Color); err != nil {
  197. errChan <- fmt.Errorf("删除IP '%s' (isSmall: %t, id: %d) 失败: %w , color: %s", ip, isSmall, id, err , payload.Color)
  198. }
  199. }
  200. for _, ip := range payload.Ips {
  201. wg.Add(2)
  202. go deleteFromWall(false, ip)
  203. go deleteFromWall(true, ip)
  204. }
  205. wg.Wait()
  206. close(errChan)
  207. var errs []string
  208. for err := range errChan {
  209. logger.Error("删除IP白名单过程中发生错误", zap.Error(err), zap.String("color", payload.Color))
  210. errs = append(errs, err.Error())
  211. }
  212. if len(errs) > 0 {
  213. processingErr = fmt.Errorf("删除IP任务中发生错误: %s", strings.Join(errs, "; ") + ", color: " + payload.Color)
  214. }
  215. default:
  216. processingErr = fmt.Errorf("unknown action: %s", payload.Action)
  217. logger.Warn("在IP白名单任务中收到未知操作", zap.String("action", payload.Action), zap.Any("ips", payload.Ips), zap.String("color", payload.Color))
  218. }
  219. if processingErr != nil {
  220. logger.Error("处理IP白名单任务失败", zap.Error(processingErr), zap.Any("ips", payload.Ips), zap.String("color", payload.Color))
  221. } else {
  222. logger.Info("已成功处理IP白名单任务", zap.String("action", payload.Action), zap.Any("ips", payload.Ips), zap.String("color", payload.Color))
  223. }
  224. return processingErr
  225. }