瀏覽代碼

refactor(waf): 重构端口验证逻辑

- 将原有的单个函数拆分为多个专门处理不同协议的函数
- 优化了 HTTP/HTTPS 端口的验证逻辑,增加了对协议冲突和域名冲突的检查
- 分离 TCP 和 UDP 端口的验证逻辑,使其更加清晰- 在 Web 转发规则中增加了 IsHttps 字段,用于区分 HTTP 和 HTTPS 规则
fusu 3 周之前
父節點
當前提交
bf44e3f1d3
共有 3 個文件被更改,包括 107 次插入54 次删除
  1. 1 0
      api/v1/webForwarding.go
  2. 1 1
      internal/repository/webforwarding.go
  3. 105 53
      internal/service/wafformatter.go

+ 1 - 0
api/v1/webForwarding.go

@@ -48,5 +48,6 @@ type CcConfigRequest struct {
 type Domain struct {
 	Id     int    `gorm:"column:id"`
 	Domain string `gorm:"column:domain"`
+	IsHttps int    `gorm:"column:is_https"`
 }
 

+ 1 - 1
internal/repository/webforwarding.go

@@ -320,7 +320,7 @@ func (r *webForwardingRepository) GetWebConfigId(ctx context.Context, id int64)
 
 func (r *webForwardingRepository) GetDomainByHostIdPort(ctx context.Context, hostId int64, port string) ([]v1.Domain, error) {
 	var domains []v1.Domain
-	if err := r.db.WithContext(ctx).Model(&model.WebForwarding{}).Where("host_id = ? AND port = ?", hostId, port).Select("domain,id").Scan(&domains).Error; err != nil {
+	if err := r.db.WithContext(ctx).Model(&model.WebForwarding{}).Where("host_id = ? AND port = ?", hostId, port).Select("domain,id,is_https").Scan(&domains).Error; err != nil {
 		return nil, err
 	}
 	return domains, nil

+ 105 - 53
internal/service/wafformatter.go

@@ -494,65 +494,117 @@ func (s *wafFormatterService) ConvertToPunycodeIfIDN(ctx context.Context, domain
 
 // 验证端口重复
 func (s *wafFormatterService) VerifyPort(ctx context.Context,protocol string, id int64, port string,hostId int64,domain string) error {
-	errPortInUse := fmt.Errorf("端口 %s 已经被使用,无法添加", port)
 	switch protocol {
-		case "http", "https":
-			domains, err := s.webForwardingRep.GetDomainByHostIdPort(ctx, hostId, port)
-			if err != nil {
-				return err
-			}
-			tcpCount, err := s.tcpforwardingRep.GetPortCount(ctx, hostId, port)
-			if err != nil {
-				return err
-			}
+	case "http", "https":
+		return s.verifyWebForwardingPort(ctx, protocol, id, port, hostId, domain)
+	case "tcp":
+		return s.verifyTCPPort(ctx, hostId, port)
+	case "udp":
+		return s.verifyUDPPort(ctx, hostId, port)
+	default:
+		return fmt.Errorf("不支持的协议类型:%s", protocol)
+	}
+}
 
-			if tcpCount > 0 {
-				return errPortInUse
-			}
 
-			for _, v := range domains {
-				// 防住空域名修改为非空域名报错
-				if v.Domain == "" && int64(v.Id) != id {
-					return errPortInUse
-				}
-				if net.ParseIP(v.Domain) != nil {
-					return errPortInUse
-				}
-			}
+// verifyWebForwardingPort 专门处理 HTTP 和 HTTPS 的端口校验逻辑。
+func (s *wafFormatterService) verifyWebForwardingPort(ctx context.Context, protocol string, id int64, port string, hostId int64, domain string) error {
+	errPortInUse := fmt.Errorf("端口 %s 已经被使用,无法添加", port)
 
-			// 确保添加新规则时,没有已有域名的规则
-			if net.ParseIP(domain) != nil || domain == "" {
-				if len(domains) > 0 {
-					return errPortInUse
-				}
-			}
+	// 1. 检查是否存在 TCP 转发规则占用该端口
+	tcpCount, err := s.tcpforwardingRep.GetPortCount(ctx, hostId, port)
+	if err != nil {
+		return err
+	}
+	if tcpCount > 0 {
+		return errPortInUse
+	}
 
-			return nil
+	// 2. 获取该主机和端口上所有已存在的 Web 转发规则
+	existingRules, err := s.webForwardingRep.GetDomainByHostIdPort(ctx, hostId, port)
+	if err != nil {
+		return err
+	}
 
+	// 如果没有任何规则,则该端口可用,直接返回
+	if len(existingRules) == 0 {
+		return nil
+	}
 
-		case "tcp":
-			count, err := s.tcpforwardingRep.GetPortCount(ctx, hostId, port)
-			if err != nil {
-				return err
-			}
-			webCount, err := s.webForwardingRep.GetDomainByHostIdPort(ctx, hostId, port)
-			if err != nil {
-				return err
-			}
-			if count + int64(len(webCount)) > 0 {
-				return errPortInUse
-			}
-			return nil
-		case "udp":
-			count, err := s.udpForWardingRep.GetPortCount(ctx, hostId, port)
-			if err != nil {
-				return err
-			}
-			if count > 0 {
-				return errPortInUse
-			}
-			return nil
-		default:
-			return fmt.Errorf("不支持的协议类型:%s", protocol)
+	// 3. 核心逻辑:检查协议冲突和域名冲突
+	isNewRuleHTTPS := 0
+	if protocol == "https" {
+		isNewRuleHTTPS = 1
+	}
+
+	for _, rule := range existingRules {
+		// 关键检查:HTTP 和 HTTPS 不能在同一个端口上共存。
+		if rule.IsHttps != isNewRuleHTTPS {
+			return errPortInUse
+		}
+
+		// 如果现有规则是“全匹配”规则(空域名或IP),并且不是我们正在编辑的规则,则冲突。
+		isExistingRuleCatchAll := rule.Domain == "" || net.ParseIP(rule.Domain) != nil
+		if isExistingRuleCatchAll && int64(rule.Id) != id {
+			return errPortInUse
+		}
 	}
+
+	// 4. 反向检查:如果要添加/修改的规则是“全匹配”规则,则该端口上不能有其他规则。
+	isNewRuleCatchAll := domain == "" || net.ParseIP(domain) != nil
+	if isNewRuleCatchAll {
+		// 如果已存在规则数大于1,则必然冲突。
+		if len(existingRules) > 1 {
+			return errPortInUse
+		}
+		// 如果只存在1条规则,但其ID和当前要修改的ID不同,也冲突。
+		// (此场景意味着你在为一个已有其他规则的端口添加一条新的“全匹配”规则)
+		if len(existingRules) == 1 && int64(existingRules[0].Id) != id {
+			return errPortInUse
+		}
+	}
+
+	return nil
+}
+
+
+// verifyTCPPort 专门处理 TCP 的端口校验逻辑。
+func (s *wafFormatterService) verifyTCPPort(ctx context.Context, hostId int64, port string) error {
+	errPortInUse := fmt.Errorf("端口 %s 已经被使用,无法添加", port)
+
+	// TCP 规则不能与已有的 TCP 规则共存
+	tcpCount, err := s.tcpforwardingRep.GetPortCount(ctx, hostId, port)
+	if err != nil {
+		return err
+	}
+	if tcpCount > 0 {
+		return errPortInUse
+	}
+
+	// TCP 规则也不能与已有的 Web 转发(HTTP/HTTPS)规则共存
+	webRules, err := s.webForwardingRep.GetDomainByHostIdPort(ctx, hostId, port)
+	if err != nil {
+		return err
+	}
+	if len(webRules) > 0 {
+		return errPortInUse
+	}
+
+	return nil
+}
+
+// verifyUDPPort 专门处理 UDP 的端口校验逻辑。
+func (s *wafFormatterService) verifyUDPPort(ctx context.Context, hostId int64, port string) error {
+	errPortInUse := fmt.Errorf("端口 %s 已经被使用,无法添加", port)
+
+	// UDP 规则不能与已有的 UDP 规则共存
+	count, err := s.udpForWardingRep.GetPortCount(ctx, hostId, port)
+	if err != nil {
+		return err
+	}
+	if count > 0 {
+		return errPortInUse
+	}
+
+	return nil
 }