Sfoglia il codice sorgente

fix(webforwarding): 删除转发规则时一并删除 SSL 证书- 新增 SSLPolicy 和 SslCertsJSON 结构体用于处理 SSL 策略
- 添加 DelSSLCert 方法到 CDN服务接口
- 修改 GetSSLPolicy 方法返回值类型
-移除未使用的 GetWebForwardingWafWebIdById 方法
- 在删除转发规则时,如果存在 SSL 证书,则先删除关联的 SSL 证书

huangjl 1 mese fa
parent
commit
ffe818b75b

+ 16 - 0
api/v1/cdn.go

@@ -139,3 +139,19 @@ type AddSSLPolicy struct {
 	CipherSuitesIsOn bool     `json:"cipherSuitesIsOn" form:"cipherSuitesIsOn"` //可选项,是否启用自定义加密套件
 	OcspIsOn         bool     `json:"ocspIsOn" form:"ocspIsOn"`                 //可选项,是否开启OCSP
 }
+type SSLPolicy struct {
+	Http2Enabled     bool           `json:"http2Enabled" form:"http2Enabled"`         //是否支持HTTP/2
+	Http3Enabled     bool           `json:"http3Enabled" form:"http3Enabled"`         //是否支持Http3Enabled
+	MinVersion       string         `json:"minVersion" form:"minVersion"`             //最小TLS版本
+	SslCertsJSON     []SslCertsJSON `json:"sslCertsJSON" form:"sslCertsJSON"`         //SslCertsJSON
+	HstsJSON         []byte         `json:"hstsJSON" form:"hstsJSON"`                 //HstsJSON
+	ClientAuthType   int32          `json:"clientAuthType" form:"clientAuthType"`     //可选项,客户端校验类型:0 无需证书,1 需要客户端证书,2 需要任一客户端证书,3 如果客户端上传了证书才校验,4 需要客户端证书而且需要校验
+	CipherSuites     []string       `json:"cipherSuites" form:"cipherSuites"`         //可选项,支持的TLS加密套件
+	CipherSuitesIsOn bool           `json:"cipherSuitesIsOn" form:"cipherSuitesIsOn"` //可选项,是否启用自定义加密套件
+	OcspIsOn         bool           `json:"ocspIsOn" form:"ocspIsOn"`                 //可选项,是否开启OCSP
+}
+
+type SslCertsJSON struct {
+	IsOn   bool  `json:"isOn" form:"isOn"`
+	CertId int64 `json:"certId" form:"certId"`
+}

+ 9 - 22
internal/repository/webforwarding.go

@@ -18,7 +18,6 @@ type WebForwardingRepository interface {
 	AddWebForwarding(ctx context.Context, req *model.WebForwarding) (int, error)
 	EditWebForwarding(ctx context.Context, req *model.WebForwarding) error
 	DeleteWebForwarding(ctx context.Context, id int64) error
-	GetWebForwardingWafWebIdById(ctx context.Context, id int) (int, error)
 	GetWebForwardingPortCountByHostId(ctx context.Context, hostId int) (int64, error)
 	GetWebForwardingDomainCountByHostId(ctx context.Context, hostId int) (int64, []string, error)
 	GetWebForwardingWafWebAllIds(ctx context.Context, hostId int) ([]int, error)
@@ -27,7 +26,7 @@ type WebForwardingRepository interface {
 	GetWebForwardingIpsByID(ctx context.Context, webId int) (*model.WebForwardingRule, error)
 	DeleteWebForwardingIpsById(ctx context.Context, webId int) error
 	// 获取域名数量
-	GetDomainCount(ctx context.Context, hostId int,domain string) (int, error)
+	GetDomainCount(ctx context.Context, hostId int, domain string) (int, error)
 	// 获取IP数量等于1的IP
 	GetIpCountByIp(ctx context.Context, ips []string) ([]v1.IpCountResult, error)
 }
@@ -92,15 +91,6 @@ func (r *webForwardingRepository) DeleteWebForwarding(ctx context.Context, id in
 	return nil
 }
 
-func (r *webForwardingRepository) GetWebForwardingWafWebIdById(ctx context.Context, id int) (int, error) {
-	var WafWebId int
-
-	if err := r.db.Model(&model.WebForwarding{}).WithContext(ctx).Where("id = ?", id).Select("waf_web_id").Find(&WafWebId).Error; err != nil {
-		return 0, err
-	}
-	return WafWebId, nil
-}
-
 func (r *webForwardingRepository) GetWebForwardingPortCountByHostId(ctx context.Context, hostId int) (int64, error) {
 	var count int64
 	if err := r.db.Model(&model.WebForwarding{}).WithContext(ctx).Where("host_id = ?", hostId).Count(&count).Error; err != nil {
@@ -112,7 +102,7 @@ func (r *webForwardingRepository) GetWebForwardingPortCountByHostId(ctx context.
 func (r *webForwardingRepository) GetWebForwardingDomainCountByHostId(ctx context.Context, hostId int) (int64, []string, error) {
 	var distinctDomains []string
 	err := r.db.Model(&model.WebForwarding{}).WithContext(ctx).
-		Distinct(). // 确保我们只获取唯一的 domain 值
+		Distinct().                                                           // 确保我们只获取唯一的 domain 值
 		Where("host_id = ? AND domain IS NOT NULL AND domain != ''", hostId). // 额外添加 domain != '' 以排除空字符串
 		Pluck("domain", &distinctDomains).Error
 
@@ -133,7 +123,6 @@ func (r *webForwardingRepository) GetWebForwardingWafWebAllIds(ctx context.Conte
 	return ids, nil
 }
 
-
 // mongodb 插入
 func (r *webForwardingRepository) AddWebForwardingIps(ctx context.Context, req model.WebForwardingRule) (primitive.ObjectID, error) {
 	collection := r.mongoDB.Collection("web_forwarding_rules")
@@ -163,12 +152,10 @@ func (r *webForwardingRepository) EditWebForwardingIps(ctx context.Context, req
 		updateData["web_id"] = req.WebId
 	}
 
-
 	if len(req.BackendList) > 0 {
 		updateData["backend_list"] = req.BackendList
 	}
 
-
 	updateData["cdn_origin_ids"] = req.CdnOriginIds
 
 	// 始终更新更新时间
@@ -181,7 +168,7 @@ func (r *webForwardingRepository) EditWebForwardingIps(ctx context.Context, req
 
 	// 执行更新
 	update := bson.M{"$set": updateData}
-	 err := collection.UpdateOne(ctx, bson.M{"web_id": req.WebId}, update)
+	err := collection.UpdateOne(ctx, bson.M{"web_id": req.WebId}, update)
 	if err != nil {
 		return fmt.Errorf("更新MongoDB文档失败: %w", err)
 	}
@@ -217,7 +204,7 @@ func (r *webForwardingRepository) DeleteWebForwardingIpsById(ctx context.Context
 
 	collection := r.mongoDB.Collection("web_forwarding_rules")
 
-	 err := collection.Remove(ctx, bson.M{"web_id": webId})
+	err := collection.Remove(ctx, bson.M{"web_id": webId})
 
 	if err != nil {
 		if errors.Is(err, mongo.ErrNoDocuments) {
@@ -229,9 +216,9 @@ func (r *webForwardingRepository) DeleteWebForwardingIpsById(ctx context.Context
 }
 
 // 获取域名数量
-func (r *webForwardingRepository) GetDomainCount(ctx context.Context, hostId int,domain string) (int, error) {
+func (r *webForwardingRepository) GetDomainCount(ctx context.Context, hostId int, domain string) (int, error) {
 	var count int64
-	if err := r.db.Model(&model.WebForwarding{}).WithContext(ctx).Where("host_id = ? AND domain = ?", hostId,domain).Count(&count).Error; err != nil {
+	if err := r.db.Model(&model.WebForwarding{}).WithContext(ctx).Where("host_id = ? AND domain = ?", hostId, domain).Count(&count).Error; err != nil {
 		return 0, err
 	}
 	return int(count), nil
@@ -256,9 +243,9 @@ func (r *webForwardingRepository) GetIpCountByIp(ctx context.Context, ips []stri
 		},
 		{
 			"$project": bson.M{
-				"_id":   0,       // 不输出默认的_id
-				"ip":    "$_id",  // 将分组的_id字段重命名为ip
-				"count": 1,       // 保留count字段
+				"_id":   0,      // 不输出默认的_id
+				"ip":    "$_id", // 将分组的_id字段重命名为ip
+				"count": 1,      // 保留count字段
 			},
 		},
 	}

+ 7 - 5
internal/service/cdn.go

@@ -37,6 +37,8 @@ type CdnService interface {
 	EditServerName(ctx context.Context, req v1.EditServerNames) error
 	// 添加ssl策略
 	AddSSLPolicy(ctx context.Context, req v1.AddSSLPolicy) (int64, error)
+	DelSSLCert(ctx context.Context, sslCertId int64) error
+	GetSSLPolicy(ctx context.Context, sslPolicyId int64) (v1.SSLPolicy, error)
 }
 
 func NewCdnService(
@@ -681,7 +683,7 @@ func (s *cdnService) DelSSLCert(ctx context.Context, sslCertId int64) error {
 
 }
 
-func (s *cdnService) GetSSLPolicy(ctx context.Context, sslPolicyId int64) (v1.AddSSLPolicy, error) {
+func (s *cdnService) GetSSLPolicy(ctx context.Context, sslPolicyId int64) (v1.SSLPolicy, error) {
 	formData := map[string]interface{}{
 		"sslPolicyId": sslPolicyId,
 		"ignoreData":  true,
@@ -689,14 +691,14 @@ func (s *cdnService) GetSSLPolicy(ctx context.Context, sslPolicyId int64) (v1.Ad
 	apiUrl := s.Url + "SSLPolicyService/findEnabledSSLPolicyConfig"
 	resBody, err := s.sendDataWithTokenRetry(ctx, formData, apiUrl)
 	if err != nil {
-		return v1.AddSSLPolicy{}, err
+		return v1.SSLPolicy{}, err
 	}
-	var res v1.GeneralResponse[v1.AddSSLPolicy]
+	var res v1.GeneralResponse[v1.SSLPolicy]
 	if err := json.Unmarshal(resBody, &res); err != nil {
-		return v1.AddSSLPolicy{}, fmt.Errorf("反序列化响应 JSON 失败 (内容: %s): %w", string(resBody), err)
+		return v1.SSLPolicy{}, fmt.Errorf("反序列化响应 JSON 失败 (内容: %s): %w", string(resBody), err)
 	}
 	if res.Code != 200 {
-		return v1.AddSSLPolicy{}, fmt.Errorf("API 错误: code %d, msg '%s'", res.Code, res.Message)
+		return v1.SSLPolicy{}, fmt.Errorf("API 错误: code %d, msg '%s'", res.Code, res.Message)
 	}
 	return res.Data, nil
 }

+ 22 - 2
internal/service/webforwarding.go

@@ -238,7 +238,7 @@ func (s *webForwardingService) prepareWafData(ctx context.Context, req *v1.WebFo
 // 辅助函数:buildProxyJSONConfig
 // 职责:专门负责处理 HTTP/HTTPS 的差异,并生成对应的 JSON 配置。
 // =================================================================
-func (s *webForwardingService) buildProxyJSONConfig(ctx context.Context, req *v1.WebForwardingRequest, require RequireResponse) ([]byte,int64, error) {
+func (s *webForwardingService) buildProxyJSONConfig(ctx context.Context, req *v1.WebForwardingRequest, require RequireResponse) ([]byte, int64, error) {
 	var (
 		jsonData v1.TypeJSON
 		apiType  string
@@ -322,7 +322,7 @@ func (s *webForwardingService) buildProxyJSONConfig(ctx context.Context, req *v1
 		return nil, 0, fmt.Errorf("序列化WAF配置失败: %w", err)
 	}
 
-	return byteData,sslPolicyId, nil
+	return byteData, sslPolicyId, nil
 }
 
 // 查找两个列表的差异
@@ -707,6 +707,26 @@ func (s *webForwardingService) DeleteWebForwarding(ctx context.Context, Ids []in
 			}
 		}
 
+		// 删除ssl
+		data, err := s.webForwardingRepository.GetWebForwarding(ctx, int64(Id))
+		if err != nil {
+			return err
+		}
+		if data.SslCertId != 0 {
+			sslPolicyData, err := s.cdn.GetSSLPolicy(ctx, int64(data.SslCertId))
+			if err != nil {
+				return err
+			}
+			if sslPolicyData.SslCertsJSON != nil {
+				for _, v := range sslPolicyData.SslCertsJSON {
+					err := s.cdn.DelSSLCert(ctx, v.CertId)
+					if err != nil {
+						return err
+					}
+				}
+			}
+		}
+
 		if err = s.webForwardingRepository.DeleteWebForwarding(ctx, int64(Id)); err != nil {
 			return err
 		}