package web import ( "context" "encoding/json" "fmt" v1 "github.com/go-nunu/nunu-layout-advanced/api/v1" "github.com/go-nunu/nunu-layout-advanced/internal/model" "github.com/go-nunu/nunu-layout-advanced/internal/repository/api/waf" "github.com/go-nunu/nunu-layout-advanced/internal/service" "github.com/go-nunu/nunu-layout-advanced/internal/service/api/flexCdn" waf2 "github.com/go-nunu/nunu-layout-advanced/internal/service/api/waf" "github.com/go-nunu/nunu-layout-advanced/internal/service/api/waf/common" ) // AidedWebService Web转发辅助服务接口 type AidedWebServiceInterface interface { // 数据准备辅助函数 PrepareWafData(ctx context.Context, req *v1.WebForwardingRequest) (common.RequireResponse, v1.Website, error) BuildProxyConfig(ctx context.Context, req *v1.WebForwardingRequest, gatewayIps []string) (v1.TypeJSON, error) BulidFormData(ctx context.Context, formData v1.Website) (v1.WebsiteSend, error) // 协议判断辅助函数 GetProtocolType(isHttps int) string IsHttpsProtocol(isHttps int) bool // 模型构建辅助函数 BuildWebForwardingModel(req *v1.WebForwardingDataRequest, ruleId int, require common.RequireResponse) *model.WebForwarding BuildWebRuleModel(reqData *v1.WebForwardingDataRequest, require common.RequireResponse, localDbId int, cdnOriginIds map[string]int64) *model.WebForwardingRule } func NewAidedWebService( service *service.Service, webForwardingRepository waf.WebForwardingRepository, wafformatter common.WafFormatterService, sslCert flexCdn.SslCertService, cdn flexCdn.CdnService, proxy flexCdn.ProxyService, websocket flexCdn.WebsocketService, cc waf2.CcService, ccIpList waf2.CcIpListService, gatewayIp common.GatewayipService, globalLimitRep waf.GlobalLimitRepository, ) *AidedWebService { return &AidedWebService{ Service: service, webForwardingRepository: webForwardingRepository, wafformatter: wafformatter, sslCert: sslCert, cdn: cdn, proxy: proxy, websocket: websocket, cc: cc, ccIpList: ccIpList, gatewayIp: gatewayIp, globalLimitRep: globalLimitRep, } } type AidedWebService struct { *service.Service webForwardingRepository waf.WebForwardingRepository wafformatter common.WafFormatterService sslCert flexCdn.SslCertService cdn flexCdn.CdnService proxy flexCdn.ProxyService websocket flexCdn.WebsocketService cc waf2.CcService ccIpList waf2.CcIpListService gatewayIp common.GatewayipService globalLimitRep waf.GlobalLimitRepository } const ( // 协议类型常量 isHttps = 1 isHttp = 0 protocolHttps = "https" protocolHttp = "http" // 默认配置常量 defaultNodeClusterId = 2 proxyProtocolVersion = 1 ) // BuildWebForwardingModel 辅助函数,用于构建通用的 WebForwarding 模型 // ruleId 是从 WAF 系统获取的 ID func (s *AidedWebService) BuildWebForwardingModel(req *v1.WebForwardingDataRequest, ruleId int, require common.RequireResponse) *model.WebForwarding { return &model.WebForwarding{ HostId: require.HostId, CdnWebId: ruleId, Port: req.Port, Domain: req.Domain, IsHttps: req.IsHttps, Comment: req.Comment, HttpsCert: req.HttpsCert, HttpsKey: req.HttpsKey, SslCertId: int(req.SslCertId), SslPolicyId: int(req.SslPolicyId), Cc: req.CcConfig.IsOn, ThresholdMethod: req.CcConfig.ThresholdMethod, Level: req.CcConfig.Level, Limit5s: req.CcConfig.Limit5s, Limit60s: req.CcConfig.Limit60s, Limit300s: req.CcConfig.Limit300s, Proxy: req.Proxy, } } // BuildWebRuleModel 构建WebForwardingRule模型 func (s *AidedWebService) BuildWebRuleModel(reqData *v1.WebForwardingDataRequest, require common.RequireResponse, localDbId int, cdnOriginIds map[string]int64) *model.WebForwardingRule { return &model.WebForwardingRule{ Uid: require.Uid, HostId: require.HostId, WebId: localDbId, CdnOriginIds: cdnOriginIds, BackendList: reqData.BackendList, } } // getRequire 获取前置配置 func (s *AidedWebService) getRequire (ctx context.Context, req *v1.WebForwardingRequest) (common.RequireResponse, error) { // 1. 获取基础配置 require, err := s.wafformatter.Require(ctx, v1.GlobalRequire{ HostId: req.HostId, Uid: req.Uid, Comment: req.WebForwardingData.Comment, }) if err != nil { return common.RequireResponse{}, fmt.Errorf("获取WAF前置配置失败: %w", err) } if require.Uid == 0 { return common.RequireResponse{}, fmt.Errorf("请先配置实例") } return require, nil } // PrepareWafData 准备WAF数据 // 职责:协调整个流程,负责获取前置配置和组装最终的 formData。 func (s *AidedWebService) PrepareWafData(ctx context.Context, req *v1.WebForwardingRequest) (common.RequireResponse, v1.Website, error) { // 1. 获取前置配置 require, err := s.getRequire(ctx, req) if err != nil { return common.RequireResponse{}, v1.Website{}, err } // 2. 调用辅助函数,构建核心的代理配置 (将复杂逻辑封装起来) byteData, err := s.BuildProxyConfig(ctx, req, require.GatewayIps) if err != nil { return common.RequireResponse{}, v1.Website{}, err // 错误信息在辅助函数中已经包装好了 } type serverNames struct { ServerNames string `json:"name" form:"name"` Type string `json:"type" form:"type"` } var serverName []serverNames var serverJson []byte if req.WebForwardingData.Domain != "" { serverName = append(serverName, serverNames{ ServerNames: req.WebForwardingData.Domain, Type: "full", }) serverJson, err = json.Marshal(serverName) if err != nil { return common.RequireResponse{}, v1.Website{}, err } } // 3. 组装最终的 WAF 表单数据 formData := v1.Website{ UserId: int64(require.CdnUid), Type: "httpProxy", Name: require.Tag, ServerNamesJSON: serverJson, Description: req.WebForwardingData.Comment, ServerGroupIds: []int64{int64(require.GroupId)}, NodeClusterId: defaultNodeClusterId, } // 4. 根据协议类型,填充 HttpJSON 和 HttpsJSON 字段 if req.WebForwardingData.IsHttps == isHttps { formData.HttpJSON = v1.TypeJSON{IsOn: false} formData.HttpsJSON = byteData } else { formData.HttpJSON = byteData formData.HttpsJSON = v1.TypeJSON{IsOn: false} } return require, formData, nil } func (s *AidedWebService) buildSslPolicy(ctx context.Context, data *v1.WebForwardingDataRequest) (v1.SslPolicyRef, error) { // 如果不是 HTTPS,直接返回关闭状态的 SSL 策略 if data.IsHttps != isHttps { return v1.SslPolicyRef{ IsOn: false, SslPolicyId: data.SslPolicyId, }, nil } // --- 以下是 HTTPS 的逻辑 --- sslPolicyID := data.SslPolicyId // 如果请求中没有提供 SSL 策略 ID,则为其创建一个新的 if sslPolicyID == 0 { var err error sslPolicyID, err = s.sslCert.AddSslPolicy(ctx, nil) if err != nil { // 如果创建失败,返回零值和错误 return v1.SslPolicyRef{}, err } } // 返回开启状态的 HTTPS 策略 return v1.SslPolicyRef{ IsOn: true, SslPolicyId: sslPolicyID, }, nil } // BuildProxyConfig 构建代理配置 // 职责:专门负责处理 HTTP/HTTPS 的差异,并生成对应的 JSON 配置。 func (s *AidedWebService) BuildProxyConfig(ctx context.Context, req *v1.WebForwardingRequest, gatewayIps []string) (v1.TypeJSON, error) { // 第一步:构建 SSL 策略。所有复杂的 if/else 都被封装在辅助函数中 sslPolicy, err := s.buildSslPolicy(ctx, &req.WebForwardingData) if err != nil { return v1.TypeJSON{}, err } // 更新请求中的 SSL 策略 req.WebForwardingData.SslPolicyId = sslPolicy.SslPolicyId // 第二步:根据协议类型确定 apiType apiType := protocolHttp if req.WebForwardingData.IsHttps == isHttps { apiType = protocolHttps } // 第三步:构建通用的 Listen 配置 listenConfigs := make([]v1.Listen, 0, len(gatewayIps)) for _, ip := range gatewayIps { listenConfigs = append(listenConfigs, v1.Listen{ Protocol: apiType, Host: ip, Port: req.WebForwardingData.Port, }) } // 第四步:组装并返回最终结果 jsonData := v1.TypeJSON{ IsOn: true, SslPolicyRef: sslPolicy, Listen: listenConfigs, } return jsonData, nil } // BulidFormData 构建表单数据 func (s *AidedWebService) BulidFormData(ctx context.Context, formData v1.Website) (v1.WebsiteSend, error) { httpJSON, err := json.Marshal(formData.HttpJSON) if err != nil { return v1.WebsiteSend{}, err } httpsJSON, err := json.Marshal(formData.HttpsJSON) if err != nil { return v1.WebsiteSend{}, err } formDataSend := v1.WebsiteSend{ UserId: formData.UserId, AdminId: formData.AdminId, Type: formData.Type, Name: formData.Name, Description: formData.Description, ServerNamesJSON: formData.ServerNamesJSON, HttpJSON: httpJSON, HttpsJSON: httpsJSON, TcpJSON: formData.TcpJSON, TlsJSON: formData.TlsJSON, UdpJSON: formData.UdpJSON, WebId: formData.WebId, ReverseProxyJSON: formData.ReverseProxyJSON, ServerGroupIds: formData.ServerGroupIds, UserPlanId: formData.UserPlanId, NodeClusterId: formData.NodeClusterId, IncludeNodesJSON: formData.IncludeNodesJSON, ExcludeNodesJSON: formData.ExcludeNodesJSON, } return formDataSend, nil } // GetProtocolType 获取协议类型字符串 func (s *AidedWebService) GetProtocolType(isHttps int) string { if s.IsHttpsProtocol(isHttps) { return protocolHttps } return protocolHttp } // IsHttpsProtocol 判断是否为HTTPS协议 func (s *AidedWebService) IsHttpsProtocol(httpsFlag int) bool { return httpsFlag == isHttps } // updateWebsiteProtocolAndCert 更新网站协议和证书 func (s *AidedWebService) updateWebsiteProtocolAndCert(ctx context.Context, isHttps int, cdnWebId int64, formData v1.Website) error { // 切换协议 var typeConfig, closeConfig v1.TypeJSON var apiType, closeType string if s.IsHttpsProtocol(isHttps) { typeConfig = formData.HttpsJSON closeConfig = formData.HttpJSON apiType = s.GetProtocolType(isHttps) closeType = s.GetProtocolType(0) // HTTP } else { typeConfig = formData.HttpJSON closeConfig = formData.HttpsJSON apiType = s.GetProtocolType(isHttps) closeType = s.GetProtocolType(1) // HTTPS } typeJson, err := json.Marshal(typeConfig) if err != nil { return fmt.Errorf("序列化协议配置失败: %w", err) } closeJson, err := json.Marshal(closeConfig) if err != nil { return fmt.Errorf("序列化关闭协议配置失败: %w", err) } // 切换协议 if err := s.cdn.EditServerType(ctx, v1.EditWebsite{ Id: cdnWebId, TypeJSON: typeJson, }, apiType); err != nil { return fmt.Errorf("切换到%s协议失败: %w", apiType, err) } if err := s.cdn.EditServerType(ctx, v1.EditWebsite{ Id: cdnWebId, TypeJSON: closeJson, }, closeType); err != nil { return fmt.Errorf("关闭%s协议失败: %w", closeType, err) } return nil } // updateWebsiteDomain 更新网站域名 func (s *AidedWebService) updateWebsiteDomain(ctx context.Context, domain string, cdnWebId int64) error { type serverName struct { Name string `json:"name" form:"name"` Type string `json:"type" form:"type"` } var serverData []serverName serverData = append(serverData, serverName{ Name: domain, Type: "full", }) serverJson, err := json.Marshal(serverData) if err != nil { return fmt.Errorf("序列化服务器名称失败: %w", err) } if err := s.cdn.EditServerName(ctx, v1.EditServerNames{ ServerId: cdnWebId, ServerNamesJSON: serverJson, }); err != nil { return fmt.Errorf("更新服务器名称失败: %w", err) } return nil } // updateWebsiteBasicInfo 更新网站基本信息 func (s *AidedWebService) updateWebsiteBasicInfo(ctx context.Context, cdnWebId int64, tag string) error { // 通过globalLimitRep获取节点ID,这是项目中现有的方法 nodeId, err := s.globalLimitRep.GetNodeId(ctx, int(cdnWebId)) if err != nil { return fmt.Errorf("获取节点ID失败: %w", err) } if err := s.cdn.EditServerBasic(ctx, cdnWebId, tag, nodeId); err != nil { return fmt.Errorf("更新服务器基本信息失败: %w", err) } return nil }