|
@@ -494,65 +494,117 @@ func (s *wafFormatterService) ConvertToPunycodeIfIDN(ctx context.Context, domain
|
|
|
|
|
|
// 验证端口重复
|
|
|
func (s *wafFormatterService) VerifyPort(ctx context.Context,protocol string, id int64, port string,hostId int64,domain string) error {
|
|
|
- errPortInUse := fmt.Errorf("端口 %s 已经被使用,无法添加", port)
|
|
|
switch protocol {
|
|
|
- case "http", "https":
|
|
|
- domains, err := s.webForwardingRep.GetDomainByHostIdPort(ctx, hostId, port)
|
|
|
- if err != nil {
|
|
|
- return err
|
|
|
- }
|
|
|
- tcpCount, err := s.tcpforwardingRep.GetPortCount(ctx, hostId, port)
|
|
|
- if err != nil {
|
|
|
- return err
|
|
|
- }
|
|
|
+ case "http", "https":
|
|
|
+ return s.verifyWebForwardingPort(ctx, protocol, id, port, hostId, domain)
|
|
|
+ case "tcp":
|
|
|
+ return s.verifyTCPPort(ctx, hostId, port)
|
|
|
+ case "udp":
|
|
|
+ return s.verifyUDPPort(ctx, hostId, port)
|
|
|
+ default:
|
|
|
+ return fmt.Errorf("不支持的协议类型:%s", protocol)
|
|
|
+ }
|
|
|
+}
|
|
|
|
|
|
- if tcpCount > 0 {
|
|
|
- return errPortInUse
|
|
|
- }
|
|
|
|
|
|
- for _, v := range domains {
|
|
|
- // 防住空域名修改为非空域名报错
|
|
|
- if v.Domain == "" && int64(v.Id) != id {
|
|
|
- return errPortInUse
|
|
|
- }
|
|
|
- if net.ParseIP(v.Domain) != nil {
|
|
|
- return errPortInUse
|
|
|
- }
|
|
|
- }
|
|
|
+// verifyWebForwardingPort 专门处理 HTTP 和 HTTPS 的端口校验逻辑。
|
|
|
+func (s *wafFormatterService) verifyWebForwardingPort(ctx context.Context, protocol string, id int64, port string, hostId int64, domain string) error {
|
|
|
+ errPortInUse := fmt.Errorf("端口 %s 已经被使用,无法添加", port)
|
|
|
|
|
|
- // 确保添加新规则时,没有已有域名的规则
|
|
|
- if net.ParseIP(domain) != nil || domain == "" {
|
|
|
- if len(domains) > 0 {
|
|
|
- return errPortInUse
|
|
|
- }
|
|
|
- }
|
|
|
+ // 1. 检查是否存在 TCP 转发规则占用该端口
|
|
|
+ tcpCount, err := s.tcpforwardingRep.GetPortCount(ctx, hostId, port)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ if tcpCount > 0 {
|
|
|
+ return errPortInUse
|
|
|
+ }
|
|
|
|
|
|
- return nil
|
|
|
+ // 2. 获取该主机和端口上所有已存在的 Web 转发规则
|
|
|
+ existingRules, err := s.webForwardingRep.GetDomainByHostIdPort(ctx, hostId, port)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
|
|
|
+ // 如果没有任何规则,则该端口可用,直接返回
|
|
|
+ if len(existingRules) == 0 {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
|
|
|
- case "tcp":
|
|
|
- count, err := s.tcpforwardingRep.GetPortCount(ctx, hostId, port)
|
|
|
- if err != nil {
|
|
|
- return err
|
|
|
- }
|
|
|
- webCount, err := s.webForwardingRep.GetDomainByHostIdPort(ctx, hostId, port)
|
|
|
- if err != nil {
|
|
|
- return err
|
|
|
- }
|
|
|
- if count + int64(len(webCount)) > 0 {
|
|
|
- return errPortInUse
|
|
|
- }
|
|
|
- return nil
|
|
|
- case "udp":
|
|
|
- count, err := s.udpForWardingRep.GetPortCount(ctx, hostId, port)
|
|
|
- if err != nil {
|
|
|
- return err
|
|
|
- }
|
|
|
- if count > 0 {
|
|
|
- return errPortInUse
|
|
|
- }
|
|
|
- return nil
|
|
|
- default:
|
|
|
- return fmt.Errorf("不支持的协议类型:%s", protocol)
|
|
|
+ // 3. 核心逻辑:检查协议冲突和域名冲突
|
|
|
+ isNewRuleHTTPS := 0
|
|
|
+ if protocol == "https" {
|
|
|
+ isNewRuleHTTPS = 1
|
|
|
+ }
|
|
|
+
|
|
|
+ for _, rule := range existingRules {
|
|
|
+ // 关键检查:HTTP 和 HTTPS 不能在同一个端口上共存。
|
|
|
+ if rule.IsHttps != isNewRuleHTTPS {
|
|
|
+ return errPortInUse
|
|
|
+ }
|
|
|
+
|
|
|
+ // 如果现有规则是“全匹配”规则(空域名或IP),并且不是我们正在编辑的规则,则冲突。
|
|
|
+ isExistingRuleCatchAll := rule.Domain == "" || net.ParseIP(rule.Domain) != nil
|
|
|
+ if isExistingRuleCatchAll && int64(rule.Id) != id {
|
|
|
+ return errPortInUse
|
|
|
+ }
|
|
|
}
|
|
|
+
|
|
|
+ // 4. 反向检查:如果要添加/修改的规则是“全匹配”规则,则该端口上不能有其他规则。
|
|
|
+ isNewRuleCatchAll := domain == "" || net.ParseIP(domain) != nil
|
|
|
+ if isNewRuleCatchAll {
|
|
|
+ // 如果已存在规则数大于1,则必然冲突。
|
|
|
+ if len(existingRules) > 1 {
|
|
|
+ return errPortInUse
|
|
|
+ }
|
|
|
+ // 如果只存在1条规则,但其ID和当前要修改的ID不同,也冲突。
|
|
|
+ // (此场景意味着你在为一个已有其他规则的端口添加一条新的“全匹配”规则)
|
|
|
+ if len(existingRules) == 1 && int64(existingRules[0].Id) != id {
|
|
|
+ return errPortInUse
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+
|
|
|
+// verifyTCPPort 专门处理 TCP 的端口校验逻辑。
|
|
|
+func (s *wafFormatterService) verifyTCPPort(ctx context.Context, hostId int64, port string) error {
|
|
|
+ errPortInUse := fmt.Errorf("端口 %s 已经被使用,无法添加", port)
|
|
|
+
|
|
|
+ // TCP 规则不能与已有的 TCP 规则共存
|
|
|
+ tcpCount, err := s.tcpforwardingRep.GetPortCount(ctx, hostId, port)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ if tcpCount > 0 {
|
|
|
+ return errPortInUse
|
|
|
+ }
|
|
|
+
|
|
|
+ // TCP 规则也不能与已有的 Web 转发(HTTP/HTTPS)规则共存
|
|
|
+ webRules, err := s.webForwardingRep.GetDomainByHostIdPort(ctx, hostId, port)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ if len(webRules) > 0 {
|
|
|
+ return errPortInUse
|
|
|
+ }
|
|
|
+
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+// verifyUDPPort 专门处理 UDP 的端口校验逻辑。
|
|
|
+func (s *wafFormatterService) verifyUDPPort(ctx context.Context, hostId int64, port string) error {
|
|
|
+ errPortInUse := fmt.Errorf("端口 %s 已经被使用,无法添加", port)
|
|
|
+
|
|
|
+ // UDP 规则不能与已有的 UDP 规则共存
|
|
|
+ count, err := s.udpForWardingRep.GetPortCount(ctx, hostId, port)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ if count > 0 {
|
|
|
+ return errPortInUse
|
|
|
+ }
|
|
|
+
|
|
|
+ return nil
|
|
|
}
|