瀏覽代碼

fix(service): 添加端口重复验证功能

- 在 TCP、UDP 和 Web 转发服务中添加端口重复验证
- 新增 GetPortCount 和 GetDomainByHostIdPort 方法用于查询端口使用情况
- VerifyPort 方法实现端口重复验证逻辑
- 在创建和更新转发规则时进行端口重复检查
fusu 4 周之前
父節點
當前提交
51511f741f

+ 10 - 0
internal/repository/tcpforwarding.go

@@ -26,6 +26,8 @@ type TcpforwardingRepository interface {
 	DeleteTcpForwardingIpsById(ctx context.Context, tcpId int) error
 	// 获取IP数量等于1的IP
 	GetIpCountByIp(ctx context.Context,ips []string) ([]v1.IpCountResult, error)
+	// 获取端口数量
+	GetPortCount(ctx context.Context,hostId int64, port string) (int64, error)
 }
 
 func NewTcpforwardingRepository(
@@ -223,4 +225,12 @@ func (r *tcpforwardingRepository) GetIpCountByIp(ctx context.Context,ips []strin
 		return nil, err
 	}
 	return results, nil
+}
+
+func (r *tcpforwardingRepository) GetPortCount(ctx context.Context,hostId int64, port string) (int64, error) {
+	var count int64
+	if err := r.db.WithContext(ctx).Model(&model.Tcpforwarding{}).Where("host_id = ? AND port = ?", hostId, port).Count(&count).Error; err != nil {
+		return 0, err
+	}
+	return count, nil
 }

+ 10 - 0
internal/repository/udpforwarding.go

@@ -26,6 +26,8 @@ type UdpForWardingRepository interface {
 	DeleteUdpForwardingIpsById(ctx context.Context, udpId int) error
 	// 获取ip数量等于1的ip
 	GetIpCountByIp(ctx context.Context, ips []string) ([]v1.IpCountResult, error)
+	// 获取端口数量
+	GetPortCount(ctx context.Context, hostId int64, port string) (int64, error)
 }
 
 func NewUdpForWardingRepository(
@@ -218,4 +220,12 @@ func (r *udpForWardingRepository) GetIpCountByIp(ctx context.Context, ips []stri
 		return nil, err
 	}
 	return results, nil
+}
+
+func (r *udpForWardingRepository) GetPortCount(ctx context.Context, hostId int64, port string) (int64, error) {
+	var count int64
+	if err := r.db.WithContext(ctx).Model(&model.UdpForWarding{}).Where("host_id = ? AND port = ?", hostId, port).Count(&count).Error; err != nil {
+		return 0, err
+	}
+	return count, nil
 }

+ 12 - 0
internal/repository/webforwarding.go

@@ -34,6 +34,8 @@ type WebForwardingRepository interface {
 	GetSslCertId (ctx context.Context, sslPocyID int) ([]v1.SslCertsJSON, error)
 	// 获取CDN的web配置的id
 	GetWebConfigId(ctx context.Context, id int64) (int64, error)
+	// 获取域名
+	GetDomainByHostIdPort(ctx context.Context, hostId int64, port string) ([]string, error)
 }
 
 func NewWebForwardingRepository(
@@ -312,4 +314,14 @@ func (r *webForwardingRepository) GetWebConfigId(ctx context.Context, id int64)
 		return 0, err
 	}
 	return webConfigId, nil
+}
+
+
+func (r *webForwardingRepository) GetDomainByHostIdPort(ctx context.Context, hostId int64, port string) ([]string, error) {
+	var domains []string
+	if err := r.db.WithContext(ctx).Model(&model.WebForwarding{}).Where("host_id = ? AND port = ?", hostId, port).Pluck("domain", &domains).Error; err != nil {
+		return nil, err
+	}
+	return domains, nil
+
 }

+ 16 - 0
internal/service/tcpforwarding.go

@@ -173,6 +173,12 @@ func (s *tcpforwardingService) AddTcpForwarding(ctx context.Context, req *v1.Tcp
 	if err != nil {
 		return err
 	}
+	// 验证端口重复
+	err = s.wafformatter.VerifyPort(ctx, "tcp", req.TcpForwardingData.Port, int64(require.HostId), "")
+	if err != nil {
+		return err
+	}
+
 
 	tcpId, err := s.cdn.CreateWebsite(ctx, formData)
 	if err != nil {
@@ -252,6 +258,16 @@ func (s *tcpforwardingService) EditTcpForwarding(ctx context.Context, req *v1.Tc
 		return err
 	}
 
+	// 验证端口重复
+	if oldData.Port != req.TcpForwardingData.Port {
+		err = s.wafformatter.VerifyPort(ctx, "tcp", req.TcpForwardingData.Port, int64(require.HostId), "")
+		if err != nil {
+			return err
+		}
+	}
+
+
+
 	//修改网站端口
 	if oldData.Port != req.TcpForwardingData.Port {
 		err = s.cdn.EditServerType(ctx, v1.EditWebsite{

+ 15 - 0
internal/service/udpforwarding.go

@@ -172,6 +172,13 @@ func (s *udpForWardingService) AddUdpForwarding(ctx context.Context, req *v1.Udp
 		return err
 	}
 
+
+	// 验证端口重复
+	err = s.wafformatter.VerifyPort(ctx, "udp", req.UdpForwardingData.Port, int64(require.HostId), "")
+	if err != nil {
+		return err
+	}
+
 	udpId, err := s.cdn.CreateWebsite(ctx, formData)
 	if err != nil {
 		return err
@@ -250,6 +257,14 @@ func (s *udpForWardingService) EditUdpForwarding(ctx context.Context, req *v1.Ud
 		return err
 	}
 
+	// 验证端口重复
+	if oldData.Port != req.UdpForwardingData.Port {
+		err = s.wafformatter.VerifyPort(ctx, "udp", req.UdpForwardingData.Port, int64(require.HostId), "")
+		if err != nil {
+			return err
+		}
+	}
+
 	//修改网站端口
 	if oldData.Port != req.UdpForwardingData.Port {
 		err = s.cdn.EditServerType(ctx, v1.EditWebsite{

+ 65 - 0
internal/service/wafformatter.go

@@ -43,6 +43,8 @@ type WafFormatterService interface {
 	ParseCert(ctx context.Context, httpsCert string, httpKey string) (serverName string, commonName []string, DNSNames []string, before int64, after int64, isSelfSigned bool, err error)
 	AddSSLPolicy(ctx context.Context, req v1.SSL) (sslPolicyId int64, sslCertId int64, err error)
 	EditSSL(ctx context.Context, req v1.SSL) error
+	// 验证端口重复
+	VerifyPort(ctx context.Context,protocol string, port string,hostId int64,domain string) error
 }
 
 func NewWafFormatterService(
@@ -638,3 +640,66 @@ func (s *wafFormatterService) EditSSL(ctx context.Context, req v1.SSL) error {
 	}
 	return nil
 }
+
+// 验证端口重复
+func (s *wafFormatterService) VerifyPort(ctx context.Context,protocol string, port string,hostId int64,domain string) error {
+	errPortInUse := fmt.Errorf("端口 %s 已经被使用,无法添加", port)
+	switch protocol {
+		case "http", "https":
+			domains, err := s.webForwardingRep.GetDomainByHostIdPort(ctx, hostId, port)
+			if err != nil {
+				return err
+			}
+			tcpCount, err := s.tcpforwardingRep.GetPortCount(ctx, hostId, port)
+			if err != nil {
+				return err
+			}
+
+			if tcpCount > 0 {
+				return errPortInUse
+			}
+
+			for _, v := range domains {
+				if v == "" {
+					return errPortInUse
+				}
+				if net.ParseIP(v) != nil {
+					return errPortInUse
+				}
+			}
+
+			if net.ParseIP(domain) != nil || domain == "" {
+				if len(domains) > 0 {
+					return errPortInUse
+				}
+			}
+
+			return nil
+
+
+		case "tcp":
+			count, err := s.tcpforwardingRep.GetPortCount(ctx, hostId, port)
+			if err != nil {
+				return err
+			}
+			webCount, err := s.webForwardingRep.GetDomainByHostIdPort(ctx, hostId, port)
+			if err != nil {
+				return err
+			}
+			if count + int64(len(webCount)) > 0 {
+				return errPortInUse
+			}
+			return nil
+		case "udp":
+			count, err := s.udpForWardingRep.GetPortCount(ctx, hostId, port)
+			if err != nil {
+				return err
+			}
+			if count > 0 {
+				return errPortInUse
+			}
+			return nil
+		default:
+			return fmt.Errorf("不支持的协议类型:%s", protocol)
+	}
+}

+ 14 - 0
internal/service/webforwarding.go

@@ -367,6 +367,12 @@ func (s *webForwardingService) AddWebForwarding(ctx context.Context, req *v1.Web
 		return err
 	}
 
+	// 验证端口重复
+	err = s.wafformatter.VerifyPort(ctx, "http", req.WebForwardingData.Port, int64(require.HostId), req.WebForwardingData.Domain)
+	if err != nil {
+		return err
+	}
+
 	webId, err := s.cdn.CreateWebsite(ctx, formData)
 	if err != nil {
 		return err
@@ -504,6 +510,14 @@ func (s *webForwardingService) EditWebForwarding(ctx context.Context, req *v1.We
 		return err
 	}
 
+	// 验证端口重复
+	if oldData.Port != req.WebForwardingData.Port {
+		err = s.wafformatter.VerifyPort(ctx, "http", req.WebForwardingData.Port, int64(require.HostId), "")
+		if err != nil {
+			return err
+		}
+	}
+
 	//修改网站端口
 	if oldData.Port != req.WebForwardingData.Port || oldData.IsHttps != req.WebForwardingData.IsHttps || oldData.HttpsCert != req.WebForwardingData.HttpsCert || oldData.HttpsKey != req.WebForwardingData.HttpsKey {
 		var typeJson []byte