浏览代码

fix(service): 添加端口重复验证功能

- 在 TCP、UDP 和 Web 转发服务中添加端口重复验证
- 新增 GetPortCount 和 GetDomainByHostIdPort 方法用于查询端口使用情况
- VerifyPort 方法实现端口重复验证逻辑
- 在创建和更新转发规则时进行端口重复检查
fusu 4 周之前
父节点
当前提交
51511f741f

+ 10 - 0
internal/repository/tcpforwarding.go

@@ -26,6 +26,8 @@ type TcpforwardingRepository interface {
 	DeleteTcpForwardingIpsById(ctx context.Context, tcpId int) error
 	DeleteTcpForwardingIpsById(ctx context.Context, tcpId int) error
 	// 获取IP数量等于1的IP
 	// 获取IP数量等于1的IP
 	GetIpCountByIp(ctx context.Context,ips []string) ([]v1.IpCountResult, error)
 	GetIpCountByIp(ctx context.Context,ips []string) ([]v1.IpCountResult, error)
+	// 获取端口数量
+	GetPortCount(ctx context.Context,hostId int64, port string) (int64, error)
 }
 }
 
 
 func NewTcpforwardingRepository(
 func NewTcpforwardingRepository(
@@ -223,4 +225,12 @@ func (r *tcpforwardingRepository) GetIpCountByIp(ctx context.Context,ips []strin
 		return nil, err
 		return nil, err
 	}
 	}
 	return results, nil
 	return results, nil
+}
+
+func (r *tcpforwardingRepository) GetPortCount(ctx context.Context,hostId int64, port string) (int64, error) {
+	var count int64
+	if err := r.db.WithContext(ctx).Model(&model.Tcpforwarding{}).Where("host_id = ? AND port = ?", hostId, port).Count(&count).Error; err != nil {
+		return 0, err
+	}
+	return count, nil
 }
 }

+ 10 - 0
internal/repository/udpforwarding.go

@@ -26,6 +26,8 @@ type UdpForWardingRepository interface {
 	DeleteUdpForwardingIpsById(ctx context.Context, udpId int) error
 	DeleteUdpForwardingIpsById(ctx context.Context, udpId int) error
 	// 获取ip数量等于1的ip
 	// 获取ip数量等于1的ip
 	GetIpCountByIp(ctx context.Context, ips []string) ([]v1.IpCountResult, error)
 	GetIpCountByIp(ctx context.Context, ips []string) ([]v1.IpCountResult, error)
+	// 获取端口数量
+	GetPortCount(ctx context.Context, hostId int64, port string) (int64, error)
 }
 }
 
 
 func NewUdpForWardingRepository(
 func NewUdpForWardingRepository(
@@ -218,4 +220,12 @@ func (r *udpForWardingRepository) GetIpCountByIp(ctx context.Context, ips []stri
 		return nil, err
 		return nil, err
 	}
 	}
 	return results, nil
 	return results, nil
+}
+
+func (r *udpForWardingRepository) GetPortCount(ctx context.Context, hostId int64, port string) (int64, error) {
+	var count int64
+	if err := r.db.WithContext(ctx).Model(&model.UdpForWarding{}).Where("host_id = ? AND port = ?", hostId, port).Count(&count).Error; err != nil {
+		return 0, err
+	}
+	return count, nil
 }
 }

+ 12 - 0
internal/repository/webforwarding.go

@@ -34,6 +34,8 @@ type WebForwardingRepository interface {
 	GetSslCertId (ctx context.Context, sslPocyID int) ([]v1.SslCertsJSON, error)
 	GetSslCertId (ctx context.Context, sslPocyID int) ([]v1.SslCertsJSON, error)
 	// 获取CDN的web配置的id
 	// 获取CDN的web配置的id
 	GetWebConfigId(ctx context.Context, id int64) (int64, error)
 	GetWebConfigId(ctx context.Context, id int64) (int64, error)
+	// 获取域名
+	GetDomainByHostIdPort(ctx context.Context, hostId int64, port string) ([]string, error)
 }
 }
 
 
 func NewWebForwardingRepository(
 func NewWebForwardingRepository(
@@ -312,4 +314,14 @@ func (r *webForwardingRepository) GetWebConfigId(ctx context.Context, id int64)
 		return 0, err
 		return 0, err
 	}
 	}
 	return webConfigId, nil
 	return webConfigId, nil
+}
+
+
+func (r *webForwardingRepository) GetDomainByHostIdPort(ctx context.Context, hostId int64, port string) ([]string, error) {
+	var domains []string
+	if err := r.db.WithContext(ctx).Model(&model.WebForwarding{}).Where("host_id = ? AND port = ?", hostId, port).Pluck("domain", &domains).Error; err != nil {
+		return nil, err
+	}
+	return domains, nil
+
 }
 }

+ 16 - 0
internal/service/tcpforwarding.go

@@ -173,6 +173,12 @@ func (s *tcpforwardingService) AddTcpForwarding(ctx context.Context, req *v1.Tcp
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
+	// 验证端口重复
+	err = s.wafformatter.VerifyPort(ctx, "tcp", req.TcpForwardingData.Port, int64(require.HostId), "")
+	if err != nil {
+		return err
+	}
+
 
 
 	tcpId, err := s.cdn.CreateWebsite(ctx, formData)
 	tcpId, err := s.cdn.CreateWebsite(ctx, formData)
 	if err != nil {
 	if err != nil {
@@ -252,6 +258,16 @@ func (s *tcpforwardingService) EditTcpForwarding(ctx context.Context, req *v1.Tc
 		return err
 		return err
 	}
 	}
 
 
+	// 验证端口重复
+	if oldData.Port != req.TcpForwardingData.Port {
+		err = s.wafformatter.VerifyPort(ctx, "tcp", req.TcpForwardingData.Port, int64(require.HostId), "")
+		if err != nil {
+			return err
+		}
+	}
+
+
+
 	//修改网站端口
 	//修改网站端口
 	if oldData.Port != req.TcpForwardingData.Port {
 	if oldData.Port != req.TcpForwardingData.Port {
 		err = s.cdn.EditServerType(ctx, v1.EditWebsite{
 		err = s.cdn.EditServerType(ctx, v1.EditWebsite{

+ 15 - 0
internal/service/udpforwarding.go

@@ -172,6 +172,13 @@ func (s *udpForWardingService) AddUdpForwarding(ctx context.Context, req *v1.Udp
 		return err
 		return err
 	}
 	}
 
 
+
+	// 验证端口重复
+	err = s.wafformatter.VerifyPort(ctx, "udp", req.UdpForwardingData.Port, int64(require.HostId), "")
+	if err != nil {
+		return err
+	}
+
 	udpId, err := s.cdn.CreateWebsite(ctx, formData)
 	udpId, err := s.cdn.CreateWebsite(ctx, formData)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
@@ -250,6 +257,14 @@ func (s *udpForWardingService) EditUdpForwarding(ctx context.Context, req *v1.Ud
 		return err
 		return err
 	}
 	}
 
 
+	// 验证端口重复
+	if oldData.Port != req.UdpForwardingData.Port {
+		err = s.wafformatter.VerifyPort(ctx, "udp", req.UdpForwardingData.Port, int64(require.HostId), "")
+		if err != nil {
+			return err
+		}
+	}
+
 	//修改网站端口
 	//修改网站端口
 	if oldData.Port != req.UdpForwardingData.Port {
 	if oldData.Port != req.UdpForwardingData.Port {
 		err = s.cdn.EditServerType(ctx, v1.EditWebsite{
 		err = s.cdn.EditServerType(ctx, v1.EditWebsite{

+ 65 - 0
internal/service/wafformatter.go

@@ -43,6 +43,8 @@ type WafFormatterService interface {
 	ParseCert(ctx context.Context, httpsCert string, httpKey string) (serverName string, commonName []string, DNSNames []string, before int64, after int64, isSelfSigned bool, err error)
 	ParseCert(ctx context.Context, httpsCert string, httpKey string) (serverName string, commonName []string, DNSNames []string, before int64, after int64, isSelfSigned bool, err error)
 	AddSSLPolicy(ctx context.Context, req v1.SSL) (sslPolicyId int64, sslCertId int64, err error)
 	AddSSLPolicy(ctx context.Context, req v1.SSL) (sslPolicyId int64, sslCertId int64, err error)
 	EditSSL(ctx context.Context, req v1.SSL) error
 	EditSSL(ctx context.Context, req v1.SSL) error
+	// 验证端口重复
+	VerifyPort(ctx context.Context,protocol string, port string,hostId int64,domain string) error
 }
 }
 
 
 func NewWafFormatterService(
 func NewWafFormatterService(
@@ -638,3 +640,66 @@ func (s *wafFormatterService) EditSSL(ctx context.Context, req v1.SSL) error {
 	}
 	}
 	return nil
 	return nil
 }
 }
+
+// 验证端口重复
+func (s *wafFormatterService) VerifyPort(ctx context.Context,protocol string, port string,hostId int64,domain string) error {
+	errPortInUse := fmt.Errorf("端口 %s 已经被使用,无法添加", port)
+	switch protocol {
+		case "http", "https":
+			domains, err := s.webForwardingRep.GetDomainByHostIdPort(ctx, hostId, port)
+			if err != nil {
+				return err
+			}
+			tcpCount, err := s.tcpforwardingRep.GetPortCount(ctx, hostId, port)
+			if err != nil {
+				return err
+			}
+
+			if tcpCount > 0 {
+				return errPortInUse
+			}
+
+			for _, v := range domains {
+				if v == "" {
+					return errPortInUse
+				}
+				if net.ParseIP(v) != nil {
+					return errPortInUse
+				}
+			}
+
+			if net.ParseIP(domain) != nil || domain == "" {
+				if len(domains) > 0 {
+					return errPortInUse
+				}
+			}
+
+			return nil
+
+
+		case "tcp":
+			count, err := s.tcpforwardingRep.GetPortCount(ctx, hostId, port)
+			if err != nil {
+				return err
+			}
+			webCount, err := s.webForwardingRep.GetDomainByHostIdPort(ctx, hostId, port)
+			if err != nil {
+				return err
+			}
+			if count + int64(len(webCount)) > 0 {
+				return errPortInUse
+			}
+			return nil
+		case "udp":
+			count, err := s.udpForWardingRep.GetPortCount(ctx, hostId, port)
+			if err != nil {
+				return err
+			}
+			if count > 0 {
+				return errPortInUse
+			}
+			return nil
+		default:
+			return fmt.Errorf("不支持的协议类型:%s", protocol)
+	}
+}

+ 14 - 0
internal/service/webforwarding.go

@@ -367,6 +367,12 @@ func (s *webForwardingService) AddWebForwarding(ctx context.Context, req *v1.Web
 		return err
 		return err
 	}
 	}
 
 
+	// 验证端口重复
+	err = s.wafformatter.VerifyPort(ctx, "http", req.WebForwardingData.Port, int64(require.HostId), req.WebForwardingData.Domain)
+	if err != nil {
+		return err
+	}
+
 	webId, err := s.cdn.CreateWebsite(ctx, formData)
 	webId, err := s.cdn.CreateWebsite(ctx, formData)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
@@ -504,6 +510,14 @@ func (s *webForwardingService) EditWebForwarding(ctx context.Context, req *v1.We
 		return err
 		return err
 	}
 	}
 
 
+	// 验证端口重复
+	if oldData.Port != req.WebForwardingData.Port {
+		err = s.wafformatter.VerifyPort(ctx, "http", req.WebForwardingData.Port, int64(require.HostId), "")
+		if err != nil {
+			return err
+		}
+	}
+
 	//修改网站端口
 	//修改网站端口
 	if oldData.Port != req.WebForwardingData.Port || oldData.IsHttps != req.WebForwardingData.IsHttps || oldData.HttpsCert != req.WebForwardingData.HttpsCert || oldData.HttpsKey != req.WebForwardingData.HttpsKey {
 	if oldData.Port != req.WebForwardingData.Port || oldData.IsHttps != req.WebForwardingData.IsHttps || oldData.HttpsCert != req.WebForwardingData.HttpsCert || oldData.HttpsKey != req.WebForwardingData.HttpsKey {
 		var typeJson []byte
 		var typeJson []byte