Explorar el Código

fix(service): 优化端口重复验证逻辑

- 在 TCP、UDP 和 Web 转发服务中更新了端口验证逻辑
- 增加了对已有域名规则的检查,确保新规则不会冲突
- 修复了修改现有规则时可能出现的端口重复错误
- 优化了数据库查询,返回更多必要的域信息
fusu hace 3 semanas
padre
commit
13dcbced2a

+ 5 - 0
api/v1/webForwarding.go

@@ -45,3 +45,8 @@ type CcConfigRequest struct {
 	Level string `form:"level" json:"level" default:"low"` //拦截强度
 }
 
+type Domain struct {
+	Id     int    `gorm:"column:id"`
+	Domain string `gorm:"column:domain"`
+}
+

+ 4 - 4
internal/repository/webforwarding.go

@@ -35,7 +35,7 @@ type WebForwardingRepository interface {
 	// 获取CDN的web配置的id
 	GetWebConfigId(ctx context.Context, id int64) (int64, error)
 	// 获取域名
-	GetDomainByHostIdPort(ctx context.Context, hostId int64, port string) ([]string, error)
+	GetDomainByHostIdPort(ctx context.Context, hostId int64, port string) ([]v1.Domain, error)
 }
 
 func NewWebForwardingRepository(
@@ -318,9 +318,9 @@ func (r *webForwardingRepository) GetWebConfigId(ctx context.Context, id int64)
 }
 
 
-func (r *webForwardingRepository) GetDomainByHostIdPort(ctx context.Context, hostId int64, port string) ([]string, error) {
-	var domains []string
-	if err := r.db.WithContext(ctx).Model(&model.WebForwarding{}).Where("host_id = ? AND port = ?", hostId, port).Pluck("domain", &domains).Error; err != nil {
+func (r *webForwardingRepository) GetDomainByHostIdPort(ctx context.Context, hostId int64, port string) ([]v1.Domain, error) {
+	var domains []v1.Domain
+	if err := r.db.WithContext(ctx).Model(&model.WebForwarding{}).Where("host_id = ? AND port = ?", hostId, port).Select("domain,id").Scan(&domains).Error; err != nil {
 		return nil, err
 	}
 	return domains, nil

+ 2 - 2
internal/service/tcpforwarding.go

@@ -174,7 +174,7 @@ func (s *tcpforwardingService) AddTcpForwarding(ctx context.Context, req *v1.Tcp
 		return err
 	}
 	// 验证端口重复
-	err = s.wafformatter.VerifyPort(ctx, "tcp", req.TcpForwardingData.Port, int64(require.HostId), "")
+	err = s.wafformatter.VerifyPort(ctx, "tcp", int64(req.TcpForwardingData.Id),req.TcpForwardingData.Port, int64(require.HostId), "")
 	if err != nil {
 		return err
 	}
@@ -260,7 +260,7 @@ func (s *tcpforwardingService) EditTcpForwarding(ctx context.Context, req *v1.Tc
 
 	// 验证端口重复
 	if oldData.Port != req.TcpForwardingData.Port {
-		err = s.wafformatter.VerifyPort(ctx, "tcp", req.TcpForwardingData.Port, int64(require.HostId), "")
+		err = s.wafformatter.VerifyPort(ctx, "tcp", int64(req.TcpForwardingData.Id), req.TcpForwardingData.Port, int64(require.HostId), "")
 		if err != nil {
 			return err
 		}

+ 2 - 2
internal/service/udpforwarding.go

@@ -174,7 +174,7 @@ func (s *udpForWardingService) AddUdpForwarding(ctx context.Context, req *v1.Udp
 
 
 	// 验证端口重复
-	err = s.wafformatter.VerifyPort(ctx, "udp", req.UdpForwardingData.Port, int64(require.HostId), "")
+	err = s.wafformatter.VerifyPort(ctx, "udp", int64(req.UdpForwardingData.Id), req.UdpForwardingData.Port, int64(require.HostId), "")
 	if err != nil {
 		return err
 	}
@@ -259,7 +259,7 @@ func (s *udpForWardingService) EditUdpForwarding(ctx context.Context, req *v1.Ud
 
 	// 验证端口重复
 	if oldData.Port != req.UdpForwardingData.Port {
-		err = s.wafformatter.VerifyPort(ctx, "udp", req.UdpForwardingData.Port, int64(require.HostId), "")
+		err = s.wafformatter.VerifyPort(ctx, "udp", int64(req.UdpForwardingData.Id), req.UdpForwardingData.Port, int64(require.HostId), "")
 		if err != nil {
 			return err
 		}

+ 6 - 4
internal/service/wafformatter.go

@@ -38,7 +38,7 @@ type WafFormatterService interface {
 	// 判断域名是否是IDN,如果是,转换为 Punycode
 	ConvertToPunycodeIfIDN(ctx context.Context, domain string) (isIDN bool, punycodeDomain string, err error)
 	// 验证端口重复
-	VerifyPort(ctx context.Context,protocol string, port string,hostId int64,domain string) error
+	VerifyPort(ctx context.Context,protocol string, id int64, port string,hostId int64,domain string) error
 }
 
 func NewWafFormatterService(
@@ -493,7 +493,7 @@ func (s *wafFormatterService) ConvertToPunycodeIfIDN(ctx context.Context, domain
 
 
 // 验证端口重复
-func (s *wafFormatterService) VerifyPort(ctx context.Context,protocol string, port string,hostId int64,domain string) error {
+func (s *wafFormatterService) VerifyPort(ctx context.Context,protocol string, id int64, port string,hostId int64,domain string) error {
 	errPortInUse := fmt.Errorf("端口 %s 已经被使用,无法添加", port)
 	switch protocol {
 		case "http", "https":
@@ -511,14 +511,16 @@ func (s *wafFormatterService) VerifyPort(ctx context.Context,protocol string, po
 			}
 
 			for _, v := range domains {
-				if v == "" {
+				// 防住空域名修改为非空域名报错
+				if v.Domain == "" && int64(v.Id) != id {
 					return errPortInUse
 				}
-				if net.ParseIP(v) != nil {
+				if net.ParseIP(v.Domain) != nil {
 					return errPortInUse
 				}
 			}
 
+			// 确保添加新规则时,没有已有域名的规则
 			if net.ParseIP(domain) != nil || domain == "" {
 				if len(domains) > 0 {
 					return errPortInUse

+ 5 - 4
internal/service/webforwarding.go

@@ -269,13 +269,13 @@ func (s *webForwardingService) buildProxyConfig(ctx context.Context, req *v1.Web
 
 	)
 
-
 	jsonData.IsOn = true
 	apiType = protocolHttps
+	jsonData.SslPolicyRef.SslPolicyId = req.WebForwardingData.SslPolicyId
 	// 判断协议类型,并处理 HTTPS 的特殊逻辑(证书)
 	if req.WebForwardingData.IsHttps == isHttps {
 		// 处理证书信息
-		if req.WebForwardingData.SslPolicyId == 0 {
+		if jsonData.SslPolicyRef.SslPolicyId == 0 {
 			sslPolicyId, err := s.sslCert.AddSslPolicy(ctx, nil)
 			if err != nil {
 				return v1.TypeJSON{}, err
@@ -344,7 +344,7 @@ func (s *webForwardingService) AddWebForwarding(ctx context.Context, req *v1.Web
 	}
 
 	// 验证端口重复
-	err = s.wafformatter.VerifyPort(ctx, "http", req.WebForwardingData.Port, int64(require.HostId), req.WebForwardingData.Domain)
+	err = s.wafformatter.VerifyPort(ctx,"http", int64(req.WebForwardingData.Id), req.WebForwardingData.Port, int64(require.HostId), req.WebForwardingData.Domain)
 	if err != nil {
 		return err
 	}
@@ -509,6 +509,7 @@ func (s *webForwardingService) EditWebForwarding(ctx context.Context, req *v1.We
 		return err
 	}
 	req.WebForwardingData.SslCertId = int64(oldData.SslCertId)
+	req.WebForwardingData.SslPolicyId = int64(oldData.SslPolicyId)
 	require, formData, err := s.prepareWafData(ctx, req)
 	if err != nil {
 		return err
@@ -516,7 +517,7 @@ func (s *webForwardingService) EditWebForwarding(ctx context.Context, req *v1.We
 
 	// 验证端口重复
 	if oldData.Port != req.WebForwardingData.Port {
-		err = s.wafformatter.VerifyPort(ctx, "http", req.WebForwardingData.Port, int64(require.HostId), "")
+		err = s.wafformatter.VerifyPort(ctx, "http", int64(req.WebForwardingData.Id), req.WebForwardingData.Port, int64(require.HostId), "")
 		if err != nil {
 			return err
 		}