浏览代码

feat(global-limit): 增加 UDP 协议和禁用海外选项

- 在 GlobalLimitRequireResponse 和 HostResponse 中添加 IsBanUdp 和 IsBanOverseas 字段
- 更新 GetGatewayGroupWhereHostIdNull 方法以支持新选项
- 在 host 服务中添加 UDP 协议配置解析
- 优化 web forwarding 服务中的旧 IP 地址处理逻辑
fusu 3 周之前
父节点
当前提交
c6d0a3baa3

+ 2 - 0
api/v1/globalLimit.go

@@ -23,6 +23,8 @@ type GlobalLimitRequireResponse struct {
 	Operator int
 	NodeArea      string
 	ConfigMaxProtection string
+	IsBanUdp       int
+	IsBanOverseas int
 }
 
 type GeneralLimitRequireRequest struct {

+ 2 - 0
api/v1/host.go

@@ -32,4 +32,6 @@ type GlobalLimitConfigResponse struct {
 	IpCount       int
 	NodeArea      string
 	ConfigMaxProtection string
+	IsBanUdp      int
+	IsBanOverseas int
 }

+ 6 - 4
internal/repository/gatewaygroup.go

@@ -16,7 +16,7 @@ type GatewayGroupRepository interface {
 	AddGatewayGroup(ctx context.Context, req *model.GatewayGroup) error
 	EditGatewayGroup(ctx context.Context, req *model.GatewayGroup) error
 	DeleteGatewayGroup(ctx context.Context, id int) error
-	GetGatewayGroupWhereHostIdNull(ctx context.Context,operator int, count int) (int, error)
+	GetGatewayGroupWhereHostIdNull(ctx context.Context, req v1.GlobalLimitRequireResponse) (int, error)
 	GetGatewayGroupByHostId(ctx context.Context, hostId int64) (*model.GatewayGroup, error)
 	GetGatewayGroupList(ctx context.Context,req v1.SearchGatewayGroupParams) (*v1.PaginatedResponse[model.GatewayGroup], error)
 	EditGatewayGroupById(ctx context.Context, req *model.GatewayGroup) error
@@ -61,15 +61,17 @@ func (r *gatewayGroupRepository) DeleteGatewayGroup(ctx context.Context, id int)
 	return nil
 }
 
-func (r *gatewayGroupRepository) GetGatewayGroupWhereHostIdNull(ctx context.Context,operator int, count int) (int, error) {
+func (r *gatewayGroupRepository) GetGatewayGroupWhereHostIdNull(ctx context.Context,req v1.GlobalLimitRequireResponse) (int, error) {
 	var id int
 	subQuery := r.DB(ctx).Model(&model.GateWayGroupIp{}).
 		Select("gateway_group_id").
 		Group("gateway_group_id").
-		Having("COUNT(*) = ?", count)
+		Having("COUNT(*) = ?", req.IpCount)
 
 	err := r.DB(ctx).Model(&model.GatewayGroup{}).
-		Where("operator = ?", operator).
+		Where("operator = ?", req.Operator).
+		Where("ban_udp", req.IsBanUdp).
+		Where("ban_overseas", req.IsBanOverseas).
 		Where("id IN (?)", subQuery).
 		Where("host_id = ?", 0).
 		Select("id").First(&id).Error

+ 2 - 1
internal/service/globallimit.go

@@ -168,6 +168,7 @@ func (s *globalLimitService) GlobalLimitRequire(ctx context.Context, req v1.Glob
 	res.IpCount = configCount.IpCount
 	res.NodeArea = configCount.NodeArea
 	res.ConfigMaxProtection = configCount.ConfigMaxProtection
+	res.IsBanUdp = configCount.IsBanUdp
 	domain, err := s.hostRep.GetDomainById(ctx, req.HostId)
 	if err != nil {
 		return v1.GlobalLimitRequireResponse{}, err
@@ -229,7 +230,7 @@ func (s *globalLimitService) AddGlobalLimit(ctx context.Context, req v1.GlobalLi
 	var userId int64
 	var groupId int64
 	g.Go(func() error {
-		res, e := s.gateWayGroupRep.GetGatewayGroupWhereHostIdNull(gCtx, require.Operator, require.IpCount)
+		res, e := s.gateWayGroupRep.GetGatewayGroupWhereHostIdNull(gCtx, require)
 		if e != nil {
 			return fmt.Errorf("获取网关组失败: %w", e)
 		}

+ 12 - 0
internal/service/host.go

@@ -46,6 +46,8 @@ const (
 	ConfigOperator             = "高防线路"
 	ConfigIpCount              = "高防节点IP"
 	NodeArea                   = "节点区域"
+	IsBanUdp                   = "UDP协议"
+	IsBanOverseas              = "禁用海外"
 )
 
 // unitSuffixMap 存储需要去除的单位后缀
@@ -216,6 +218,7 @@ func (s *hostService) GetGlobalLimitConfig(ctx context.Context, hostId int) (v1.
 		Operator:      0,
 		NodeArea:      "",
 		ConfigMaxProtection: "",
+		IsBanUdp:      0,
 	}
 	if val, ok := configsMap[ConfigBps]; ok {
 		data.Bps = val
@@ -257,6 +260,15 @@ func (s *hostService) GetGlobalLimitConfig(ctx context.Context, hostId int) (v1.
 	if val, ok := configsMap[ConfigMaxProtection]; ok {
 		data.ConfigMaxProtection = val
 	}
+
+	if val, ok := configsMap[IsBanUdp]; ok {
+		if val == "开通" {
+			data.IsBanUdp = 1
+		}
+		if val == "关闭" {
+			data.IsBanUdp = 0
+		}
+	}
 	return data, nil
 }
 

+ 0 - 1
internal/service/webforwarding.go

@@ -745,7 +745,6 @@ func (s *webForwardingService) EditWebForwarding(ctx context.Context, req *v1.We
 			return err
 		}
 		oldIps = append(oldIps, ip)
-
 	}
 	for _, v := range req.WebForwardingData.BackendList {
 		ip, _, err := net.SplitHostPort(v.Addr)