package service import ( "context" "crypto/tls" "crypto/x509" "encoding/json" "fmt" v1 "github.com/go-nunu/nunu-layout-advanced/api/v1" "github.com/go-nunu/nunu-layout-advanced/internal/model" "github.com/go-nunu/nunu-layout-advanced/internal/repository" "github.com/go-nunu/nunu-layout-advanced/pkg/rabbitmq" amqp "github.com/rabbitmq/amqp091-go" "go.uber.org/zap" "golang.org/x/net/idna" "golang.org/x/net/publicsuffix" "golang.org/x/sync/errgroup" "net" "slices" "strconv" "strings" ) type WafFormatterService interface { Require(ctx context.Context, req v1.GlobalRequire) (RequireResponse, error) validateWafPortCount(ctx context.Context, hostId int) error validateWafDomainCount(ctx context.Context, req v1.GlobalRequire) error ConvertToWildcardDomain(ctx context.Context, domain string) (string, error) AppendWafIp(ctx context.Context, req []string, returnSourceIp string) ([]v1.IpInfo, error) WashIps(ctx context.Context, req []string) ([]string, error) PublishIpWhitelistTask(ips []string, action string, returnSourceIp string, color string) PublishDomainWhitelistTask(domain, ip, action string) findIpDifferences(oldIps, newIps []string) ([]string, []string) WashDeleteWafIp(ctx context.Context, backendList []string) ([]string, error) WashEditWafIp(ctx context.Context, newBackendList []string, oldBackendList []string) ([]string, []string, error) //cdn添加网站 AddOrigin(ctx context.Context, req v1.WebJson) (int64, error) // 获取ip数量等于1的源站过白ip WashDelIps(ctx context.Context, ips []string) ([]string, error) // 判断域名是否是IDN,如果是,转换为 Punycode ConvertToPunycodeIfIDN(ctx context.Context, domain string) (isIDN bool, punycodeDomain string, err error) // 解析证书 ParseCert(ctx context.Context, httpsCert string, httpKey string) (serverName string, commonName []string, DNSNames []string, before int64, after int64, isSelfSigned bool, err error) AddSSLPolicy(ctx context.Context, req v1.SSL) (sslPolicyId int64, sslCertId int64, err error) EditSSL(ctx context.Context, req v1.SSL) error } func NewWafFormatterService( service *Service, globalRep repository.GlobalLimitRepository, hostRep repository.HostRepository, required RequiredService, parser ParserService, tcpforwardingRep repository.TcpforwardingRepository, udpForWardingRep repository.UdpForWardingRepository, webForwardingRep repository.WebForwardingRepository, mq *rabbitmq.RabbitMQ, host HostService, gatewayGroupRep repository.GatewayGroupRepository, gatewayGroupIpRep repository.GateWayGroupIpRepository, cdn CdnService, ) WafFormatterService { return &wafFormatterService{ Service: service, globalRep: globalRep, hostRep: hostRep, required: required, parser: parser, tcpforwardingRep: tcpforwardingRep, udpForWardingRep: udpForWardingRep, webForwardingRep: webForwardingRep, host: host, mq: mq, gatewayGroupRep: gatewayGroupRep, gatewayGroupIpRep: gatewayGroupIpRep, cdn: cdn, } } type wafFormatterService struct { *Service globalRep repository.GlobalLimitRepository hostRep repository.HostRepository required RequiredService parser ParserService tcpforwardingRep repository.TcpforwardingRepository udpForWardingRep repository.UdpForWardingRepository webForwardingRep repository.WebForwardingRepository host HostService mq *rabbitmq.RabbitMQ gatewayGroupRep repository.GatewayGroupRepository gatewayGroupIpRep repository.GateWayGroupIpRepository cdn CdnService } type RequireResponse struct { model.GlobalLimit `json:"globalLimit" form:"globalLimit"` GatewayIps []string `json:"ips" form:"ips"` Tag string `json:"tag" form:"tag"` SslPolicyId int64 `json:"sslPolicyId" form:"sslPolicyId"` } func (s *wafFormatterService) Require(ctx context.Context, req v1.GlobalRequire) (RequireResponse, error) { var res RequireResponse // 获取全局配置信息 globalLimit, err := s.globalRep.GetGlobalLimitByHostId(ctx, int64(req.HostId)) if err != nil { return RequireResponse{}, err } if globalLimit != nil { res.GlobalLimit = *globalLimit } // 获取主机名 domain, err := s.hostRep.GetDomainById(ctx, req.HostId) if err != nil { return RequireResponse{}, err } res.Tag = strconv.Itoa(req.Uid) + "_" + strconv.Itoa(req.HostId) + "_" + domain + "_" + req.Comment res.GatewayIps, err = s.gatewayGroupIpRep.GetGateWayGroupAllIpByGatewayGroupId(ctx, res.GatewayGroupId) if err != nil { return RequireResponse{}, err } return res, nil } func (s *wafFormatterService) validateWafPortCount(ctx context.Context, hostId int) error { congfig, err := s.host.GetGlobalLimitConfig(ctx, hostId) if err != nil { return err } tcpCount, err := s.tcpforwardingRep.GetTcpForwardingPortCountByHostId(ctx, hostId) if err != nil { return err } udpCount, err := s.udpForWardingRep.GetUdpForwardingPortCountByHostId(ctx, hostId) if err != nil { return err } webCount, err := s.webForwardingRep.GetWebForwardingPortCountByHostId(ctx, hostId) if err != nil { return err } if int64(congfig.PortCount) > tcpCount+udpCount+webCount { return nil } return fmt.Errorf("端口数量超出套餐限制,已配置%d个端口,套餐限制为%d个端口", tcpCount+udpCount+webCount, congfig.PortCount) } func (s *wafFormatterService) validateWafDomainCount(ctx context.Context, req v1.GlobalRequire) error { congfig, err := s.host.GetGlobalLimitConfig(ctx, req.HostId) if err != nil { return err } domainCount, domainSlice, err := s.webForwardingRep.GetWebForwardingDomainCountByHostId(ctx, req.HostId) if err != nil { return err } if req.Domain != "" { if !slices.Contains(domainSlice, req.Domain) { domainCount += 1 if domainCount > int64(congfig.DomainCount) { return fmt.Errorf("域名数量已达到上限,已配置%d个域名,套餐限制为%d个域名", domainCount, congfig.DomainCount) } } } return nil } func (s *wafFormatterService) ConvertToWildcardDomain(ctx context.Context, domain string) (string, error) { // 1. 使用 EffectiveTLDPlusOne 获取可注册域名部分。 // 例如,对于 "www.google.com",这将返回 "google.com"。 // 对于 "a.b.c.tokyo.jp",这将返回 "c.tokyo.jp"。 if domain == "" { return "", nil } registrableDomain, err := publicsuffix.EffectiveTLDPlusOne(domain) if err != nil { s.logger.Error("无效的域名", zap.String("domain", domain), zap.Error(err)) // 如果域名无效(如 IP 地址、localhost),则返回错误。 return "", nil } // 2. 比较原始域名和可注册域名。 // 如果它们不相等,说明原始域名包含子域名。 if domain != registrableDomain { // 3. 如果存在子域名,则用 "*." 加上可注册域名来构造通配符域名。 return registrableDomain, nil } // 4. 如果原始域名和可注册域名相同(例如,输入就是 "google.com"), // 则说明没有子域名可替换,直接返回原始域名。 return domain, nil } func (s *wafFormatterService) AppendWafIp(ctx context.Context, req []string, returnSourceIp string) ([]v1.IpInfo, error) { var ips []v1.IpInfo for _, v := range req { ips = append(ips, v1.IpInfo{ FType: "0", FStartIp: v, FEndIp: v, FRemark: "宁波高防IP过白", FServerIp: returnSourceIp, }) } return ips, nil } func (s *wafFormatterService) AppendWafIpByRemovePort(ctx context.Context, req []string) ([]v1.IpInfo, error) { var ips []v1.IpInfo for _, v := range req { ip, _, err := net.SplitHostPort(v) if err != nil { return nil, err } ips = append(ips, v1.IpInfo{ FType: "0", FStartIp: ip, FEndIp: ip, FRemark: "宁波高防IP过白", FServerIp: "", }) } return ips, nil } func (s *wafFormatterService) WashIps(ctx context.Context, req []string) ([]string, error) { var res []string for _, v := range req { res = append(res, v) } return res, nil } // publishDomainWhitelistTask is a helper function to publish domain whitelist tasks to RabbitMQ. // It can handle different actions like "add" or "del". func (s *wafFormatterService) PublishDomainWhitelistTask(domain, ip, action string) { // Define message payload, including the action type domainTaskPayload struct { Domain string `json:"domain"` Ip string `json:"ip"` Action string `json:"action"` } payload := domainTaskPayload{ Domain: domain, Ip: ip, Action: action, } // Serialize the message msgBody, err := json.Marshal(payload) if err != nil { s.logger.Error("Failed to serialize domain whitelist task message", zap.Error(err), zap.String("domain", domain), zap.String("ip", ip), zap.String("action", action)) return } // Get task configuration taskCfg, ok := s.mq.GetTaskConfig("domain_whitelist") if !ok { s.logger.Error("Failed to get 'domain_whitelist' task configuration") return } // Construct the routing key dynamically based on the action routingKey := fmt.Sprintf("whitelist.domain.%s", action) // Construct the amqp.Publishing message publishingMsg := amqp.Publishing{ ContentType: "application/json", Body: msgBody, DeliveryMode: amqp.Persistent, // Persistent message } // Publish the message err = s.mq.PublishWithCh(taskCfg.Exchange, routingKey, publishingMsg) if err != nil { s.logger.Error("发布 域名 白名单任务到 MQ 失败", zap.Error(err), zap.String("domain", domain), zap.String("action", action)) } else { s.logger.Info("成功将 域名 白名单任务发布到 MQ", zap.String("domain", domain), zap.String("action", action)) } } func (s *wafFormatterService) PublishIpWhitelistTask(ips []string, action string, returnSourceIp string, color string) { // Define message payload, including the action type ipTaskPayload struct { Ips []string `json:"ips"` Action string `json:"action"` ReturnSourceIp string `json:"return_source_ip"` Color string `json:"color"` } payload := ipTaskPayload{ Ips: ips, Action: action, ReturnSourceIp: returnSourceIp, Color: color, } // Serialize the message msgBody, err := json.Marshal(payload) if err != nil { s.logger.Error("序列化 IP 白名单任务消息失败", zap.Error(err), zap.Any("IPs", ips), zap.String("action", action), zap.String("color", color)) return } // Get task configuration taskCfg, ok := s.mq.GetTaskConfig("ip_white") if !ok { s.logger.Error("无法获取“ip_white”任务配置") return } // Construct the routing key dynamically based on the action routingKey := fmt.Sprintf("task.ip_white.%s", action) // Construct the amqp.Publishing message publishingMsg := amqp.Publishing{ ContentType: "application/json", Body: msgBody, DeliveryMode: amqp.Persistent, // Persistent message } // Publish the message err = s.mq.PublishWithCh(taskCfg.Exchange, routingKey, publishingMsg) if err != nil { s.logger.Error("发布 IP 白名单任务到 MQ 失败", zap.Error(err), zap.String("action", action), zap.String("color", color)) } else { s.logger.Info("成功将 IP 白名单任务发布到 MQ", zap.String("action", action), zap.String("color", color)) } } func (s *wafFormatterService) findIpDifferences(oldIps, newIps []string) ([]string, []string) { // 使用 map 实现 set,用于快速查找 oldIpsSet := make(map[string]struct{}, len(oldIps)) for _, ip := range oldIps { oldIpsSet[ip] = struct{}{} } newIpsSet := make(map[string]struct{}, len(newIps)) for _, ip := range newIps { newIpsSet[ip] = struct{}{} } var addedIps []string // 查找新增的 IP:存在于 newIpsSet 但不存在于 oldIpsSet for ip := range newIpsSet { if _, found := oldIpsSet[ip]; !found { addedIps = append(addedIps, ip) } } var removedIps []string // 查找移除的 IP:存在于 oldIpsSet 但不存在于 newIpsSet for ip := range oldIpsSet { if _, found := newIpsSet[ip]; !found { removedIps = append(removedIps, ip) } } return addedIps, removedIps } func (s *wafFormatterService) WashDeleteWafIp(ctx context.Context, backendList []string) ([]string, error) { var res []string for _, v := range backendList { ip, _, err := net.SplitHostPort(v) if err != nil { return nil, err } res = append(res, ip) } return res, nil } func (s *wafFormatterService) WashEditWafIp(ctx context.Context, newBackendList []string, oldBackendList []string) ([]string, []string, error) { var oldIps []string var newIps []string for _, v := range oldBackendList { ip, _, err := net.SplitHostPort(v) if err != nil { return nil, nil, err } oldIps = append(oldIps, ip) } if newBackendList != nil { for _, v := range newBackendList { ip, _, err := net.SplitHostPort(v) if err != nil { return nil, nil, err } newIps = append(newIps, ip) } } addedIps, removedIps := s.findIpDifferences(oldIps, newIps) return addedIps, removedIps, nil } func (s *wafFormatterService) AddOrigin(ctx context.Context, req v1.WebJson) (int64, error) { ip, port, err := net.SplitHostPort(req.BackendList) if err != nil { return 0, fmt.Errorf("无效的后端地址: %s", err) } addr := v1.Addr{ Protocol: req.ApiType, Host: ip, Port: port, } id, err := s.cdn.CreateOrigin(ctx, v1.Origin{ Addr: addr, Weight: 10, Description: req.Comment, Host: req.Host, IsOn: true, TlsSecurityVerifyMode: "auto", }) if err != nil { return 0, err } return id, nil } // 获取ip数量等于1的源站过白ip func (s *wafFormatterService) WashDelIps(ctx context.Context, ips []string) ([]string, error) { var udpIpCounts, tcpIpCounts, webIpCounts []v1.IpCountResult g, gCtx := errgroup.WithContext(ctx) // 1. 查询 IP 的数量 g.Go(func() error { var err error udpIpCounts, err = s.udpForWardingRep.GetIpCountByIp(gCtx, ips) if err != nil { return fmt.Errorf("in udp repository: %w", err) } return nil }) g.Go(func() error { var err error tcpIpCounts, err = s.tcpforwardingRep.GetIpCountByIp(gCtx, ips) if err != nil { return fmt.Errorf("in tcp repository: %w", err) } return nil }) g.Go(func() error { var err error webIpCounts, err = s.webForwardingRep.GetIpCountByIp(gCtx, ips) if err != nil { return fmt.Errorf("in web repository: %w", err) } return nil }) if err := g.Wait(); err != nil { return nil, err } // 2. 汇总所有计数结果 totalCountMap := make(map[string]int) // 将多个 for 循环合并到一个函数中,可以显得更整洁(可选) accumulateCounts := func(counts []v1.IpCountResult) { for _, result := range counts { totalCountMap[result.Ip] += result.Count } } accumulateCounts(udpIpCounts) accumulateCounts(tcpIpCounts) accumulateCounts(webIpCounts) // 3. 筛选出总引用数小于 2 的 IP var ipsToDelist []string for _, ip := range ips { if totalCountMap[ip] < 2 { ipsToDelist = append(ipsToDelist, ip) } } return ipsToDelist, nil } // 判断域名是否为 中文域名,如果是,转换为 Punycode func (s *wafFormatterService) ConvertToPunycodeIfIDN(ctx context.Context, domain string) (isIDN bool, punycodeDomain string, err error) { // 使用 idna.ToASCII 将域名转换为 Punycode。 // 这个函数同时会根据 IDNA 规范验证域名的合法性。 punycodeDomain, err = idna.ToASCII(domain) if err != nil { // 如果转换出错,说明域名格式不符合 IDNA 标准。 return false, "", fmt.Errorf("域名 '%s' 格式无效: %v", domain, err) } // 判断是否为 IDN 的关键: // 比较转换后的 Punycode 域名和原始域名(忽略大小写)。 // 如果不相等,说明原始域名包含非 ASCII 字符,即为 IDN。 isIDN = !strings.EqualFold(domain, punycodeDomain) return isIDN, punycodeDomain, nil } func (s *wafFormatterService) ParseCert(ctx context.Context, httpsCert string, httpKey string) (serverName string, commonName []string, DNSNames []string, before int64, after int64, isSelfSigned bool, err error) { cert, err := tls.X509KeyPair([]byte(httpsCert), []byte(httpKey)) if err != nil { return "", nil, nil, 0, 0, false, fmt.Errorf("无法从字符串加载密钥对: %v", err) } if len(cert.Certificate) == 0 { return "", nil, nil, 0, 0, false, fmt.Errorf("提供的证书数据中没有找到证书。") } // 解析第一个证书(通常是叶子证书) x509Cert, err := x509.ParseCertificate(cert.Certificate[0]) if err != nil { return "", nil, nil, 0, 0, false, fmt.Errorf("无法解析证书: %v", err) } // 1. 获取 Common Name (通用名称) // Common Name 位于 Subject 字段内. [1] serverName = x509Cert.Subject.CommonName // 2. 获取 DNS Names (备用主题名称中的DNS条目) // DNS Names 直接是证书结构体的一个字段. [1] DNSNames = x509Cert.DNSNames // 检查证书是否为自签名 // 判断条件:颁发者(Issuer)和主题(Subject)相同,并且证书的签名可以由其自身的公钥验证 if err := x509Cert.CheckSignatureFrom(x509Cert); err == nil { isSelfSigned = true } // 将CommonName放入一个切片,以匹配[]string的类型要求 var commonNames []string if x509Cert.Subject.CommonName != "" { commonNames = []string{x509Cert.Subject.CommonName} } return serverName, commonNames, DNSNames, x509Cert.NotBefore.Unix(), x509Cert.NotAfter.Unix(), isSelfSigned, nil } // HandleSSLPolicy 负责处理SSL证书的完整生命周期:解析、上传到CDN并创建或更新SSL策略。 // 它封装了与CDN服务交互的复杂性,并返回一个可用的SSL策略ID。 func (s *wafFormatterService) AddSSLPolicy(ctx context.Context, req v1.SSL) (sslPolicyId int64, sslCertId int64, err error) { // 1. 解析证书文件,提取元数据 serverName, commonNames, DNSNames, before, after, isSelfSigned, err := s.ParseCert(ctx, req.CertData, req.KeyData) if err != nil { return 0, 0, fmt.Errorf("解析证书失败: %w", err) } // 2. 将证书添加到CDN提供商 // 这是获取可以在策略中引用的 `sslCertId` 的前提 newSslCertId, err := s.cdn.AddSSLCert(ctx, v1.SSlCert{ IsOn: true, UserId: int64(req.CdnUserId), Name: req.Domain, // 使用域名作为证书名称 ServerName: serverName, Description: req.Description, CertData: []byte(req.CertData), KeyData: []byte(req.KeyData), TimeBeginAt: before, TimeEndAt: after, DnsNames: DNSNames, CommonNames: commonNames, IsSelfSigned: isSelfSigned, }) if err != nil { return 0, 0, fmt.Errorf("添加SSL证书到CDN失败: %w", err) } // 3. 基于获取到的证书ID,创建SSL策略 if newSslCertId != 0 { // 构造策略中引用的证书列表 type sslCerts struct { IsOn bool `json:"isOn" form:"isOn"` CertId int64 `json:"certId" form:"certId"` } var sslCertsSlice []sslCerts sslCertsSlice = append(sslCertsSlice, sslCerts{ IsOn: true, CertId: newSslCertId, }) sslCertsJson, err := json.Marshal(sslCertsSlice) if err != nil { return 0, 0, fmt.Errorf("序列化SSL证书引用失败: %w", err) } // 调用CDN服务创建策略 newSslPolicyId, err := s.cdn.AddSSLPolicy(ctx, v1.AddSSLPolicy{ Http2Enabled: true, SslCertsJSON: sslCertsJson, MinVersion: "TLS 1.1", // 可根据安全要求调整 }) if err != nil { // 如果策略创建失败,需要考虑回滚或记录错误,这里直接返回错误 return 0, 0, fmt.Errorf("通过CDN添加SSL策略失败: %w", err) } return newSslPolicyId, newSslCertId, nil } return 0, 0, fmt.Errorf("未能创建有效的SSL证书ID,无法继续创建策略") } func (s *wafFormatterService) EditSSL(ctx context.Context, req v1.SSL) error { oldData, err := s.webForwardingRep.GetWebForwarding(ctx, req.WebId) if err != nil { return err } if oldData.HttpsKey != req.KeyData || oldData.HttpsCert != req.CertData { serverName, commonNames, DNSNames, before, after, isSelfSigned, err := s.ParseCert(ctx, req.CertData, req.KeyData) if err != nil { return fmt.Errorf("解析证书失败: %w", err) } sslCert, err := s.webForwardingRep.GetSslCertId(ctx, oldData.SslCertId) if err != nil { return fmt.Errorf("获取SSL证书失败: %w", err) } for _, v := range sslCert { err = s.cdn.EditSSLCert(ctx, v1.SSlCert{ SslCertId: v.CertId, IsOn: v.IsOn, UserId: int64(req.CdnUserId), Name: req.Domain, // 使用域名作为证书名称 ServerName: serverName, Description: req.Description, CertData: []byte(req.CertData), KeyData: []byte(req.KeyData), TimeBeginAt: before, TimeEndAt: after, DnsNames: DNSNames, CommonNames: commonNames, IsSelfSigned: isSelfSigned, }) if err != nil { return fmt.Errorf("更新SSL证书失败: %w", err) } } return nil } return nil }