Browse Source

refactor(internal/service): 重构后端数据格式化和验证逻辑- 移除了未使用的导入- 新增了 ValidateBackendData 方法用于验证后端数据
-重构了 TidyFormatBackendData 方法,简化了逻辑并提高了可读性
- 优化了 EditGameShieldBackend 方法,增加了数据模拟和预验证步骤

fusu 2 months ago
parent
commit
33b7866f6f
2 changed files with 144 additions and 71 deletions
  1. 87 67
      internal/service/formatter.go
  2. 57 4
      internal/service/gameshieldbackend.go

+ 87 - 67
internal/service/formatter.go

@@ -8,7 +8,6 @@ import (
 	"github.com/go-nunu/nunu-layout-advanced/internal/model"
 	"github.com/go-nunu/nunu-layout-advanced/internal/repository"
 	"maps"
-	"slices"
 	"sort"
 	"strconv"
 	"strings"
@@ -22,6 +21,7 @@ type FormatterService interface {
 	OldFormat(ctx context.Context, req *[]model.GameShieldBackend) (map[string]v1.SendGameShieldBackend, error)
 	TidyFormatBackendData(ctx context.Context, req *v1.GameShieldBackendArrayRequest, keyCounter int) (map[string]v1.SendGameShieldBackend, error)
 	Sort(ctx context.Context, mapData map[string]v1.SendGameShieldBackend) (map[string]v1.SendGameShieldBackend, error)
+	ValidateBackendData(ctx context.Context, mapData map[string]v1.SendGameShieldBackend, hostId int) error
 }
 
 func NewFormatterService(
@@ -56,6 +56,12 @@ func (service *formatterService) FormatBackendData(ctx context.Context, req *v1.
 	}
 	maps.Copy(formData, oldFormat)
 
+	// 验证
+	err = service.ValidateBackendData(ctx, formData, req.HostId)
+	if err != nil {
+		return "", err
+	}
+
 	sortedOutput, err := service.Sort(ctx, formData)
 	if err != nil {
 		return "", err
@@ -125,120 +131,82 @@ func (service *formatterService) OldFormat(ctx context.Context, req *[]model.Gam
 }
 
 func (service *formatterService) TidyFormatBackendData(ctx context.Context, req *v1.GameShieldBackendArrayRequest, keyCounter int) (map[string]v1.SendGameShieldBackend, error) {
-	// 初始化输出映射
 	output := make(map[string]v1.SendGameShieldBackend)
-
-	// 获取所需基础数据
 	userIp, err := service.gameShieldPublicIpService.GetUserIp(ctx, req.Uid)
 	if err != nil {
 		return nil, err
 	}
-
-	oldCount, err := service.gameShieldBackendRepository.GetGameShieldBackendConfigCountByHostId(ctx, req.HostId)
-	if err != nil {
-		return nil, err
-	}
-
-	oldMachineIp, err := service.gameShieldBackendRepository.GetGameShieldBackendSourceMachineIpByHostId(ctx, req.HostId)
-	if err != nil {
-		return nil, err
-	}
-
-	configCount, err := service.hostService.GetGameShieldConfig(ctx, req.HostId)
-	if err != nil {
-		return nil, err
-	}
-
-	// 遍历请求中的所有项目
 	for _, item := range req.Items {
-		// 检查并验证源机器IP
-		sourceIP := item.SourceMachineIP
+		// 提取必要字段
+		sourceIP := item.SourceMachineIP // 假设结构体中有这个字段
 		if sourceIP == "" {
-			return nil, fmt.Errorf("没有有效源IP的配置")
-		}
-		// 检查源机器IP是否为新增
-		if !slices.Contains(oldMachineIp, sourceIP) {
-			oldCount.SourceMachinesCount++
-			if oldCount.SourceMachinesCount > configCount.SourceMachinesCount {
-				return nil, fmt.Errorf("超出最大源机数量")
-			}
+			return nil, fmt.Errorf("没有有效源IP的配置") // 跳过没有有效源IP的配置
 		}
 
-		// 验证协议
-		protocol := item.Protocol
+		protocol := item.Protocol // 假设结构体中有这个字段
 		if protocol == "" {
-			return nil, fmt.Errorf("没有有效协议的配置")
+			return nil, fmt.Errorf("没有有效协议的配置") // 跳过没有有效协议的配置
 		}
-
-		// 获取并验证端口配置
+		// 获取端口数组
 		conPorts := service.FormatPort(ctx, item.ConnectPort)
 		sdkPorts := service.FormatPort(ctx, item.SdkPort)
 
-		// 验证端口数量匹配
+		// 验证端口数量
 		if len(sdkPorts) > 0 && len(conPorts) != len(sdkPorts) {
 			return nil, fmt.Errorf("端口数量不匹配")
 		}
 
-		// 验证规则条目数量
-		oldCount.RuleEntriesCount += int64(len(conPorts))
-		if oldCount.RuleEntriesCount > configCount.RuleEntriesCount {
-			return nil, fmt.Errorf("超出最大规则数量")
-		}
-
 		// 处理每一对端口
 		for i := 0; i < len(conPorts); i++ {
 			keyCounter++
 			key := fmt.Sprintf("key%d", keyCounter)
 
-			// 构建基本的后端配置项
+			// 使用数组中的具体端口
+			addr := fmt.Sprintf("%s:%d", sourceIP, conPorts[i])
+
 			itemMap := v1.SendGameShieldBackend{
-				Addr:     []string{fmt.Sprintf("%s:%d", sourceIP, conPorts[i])},
+				Addr:     []string{addr},
 				Protocol: protocol,
 				Type:     item.Type,
 			}
 
-			// 根据协议类型设置属性
+			//// 设置主机名(如果存在)
+			//if item.Host != "" {
+			//	itemMap["host"] = item.Host
+			//}
+
+			// 根据协议设置不同属性
 			if protocol != "udp" {
-				// 非UDP协议的配置
 				if item.Checked == "agent" {
 					itemMap.AgentAddr = fmt.Sprintf("%s:%s", sourceIP, "23350")
 				}
 				itemMap.ProxyAddr = userIp + ":32353"
+
 			} else {
-				// UDP协议的配置
 				itemMap.ProxyAddr = ""
 				itemMap.UdpSessionTimeout = "300s"
 			}
-
-			// 根据设备类型设置SDK IP
-			if item.Type == "pc" {
-				itemMap.SdkIp = item.SdkIp
-			} else {
+			if item.Type != "pc" {
 				itemMap.SdkIp = ""
+			} else {
+				itemMap.SdkIp = item.SdkIp
 			}
 
-			// 处理最大带宽设置
 			if item.MaxBandwidth == 1 {
-				oldCount.MaxBandwidthCount++
-				if oldCount.MaxBandwidthCount > configCount.MaxBandwidthCount {
-					return nil, fmt.Errorf("超出最大带宽数量")
-				}
 				itemMap.MaxBandwidth = "50m"
 			} else {
 				itemMap.MaxBandwidth = ""
 			}
-
-			// 设置SDK端口(如果存在)
-			if len(sdkPorts) > 0 {
-				sdkPort := sdkPorts[i]
-				// 检查移动端的SSH端口限制
-				if sdkPort <= 1024 && item.Type == "mobile" {
-					return nil, fmt.Errorf("移动端不支持SSH端口")
+			// 设置SDK端口 - 使用数组中的具体端口
+			if len(sdkPorts) != 0 {
+				if sdkPorts[i] <= 1024 {
+					if item.Type == "mobile" {
+						return nil, fmt.Errorf("移动端不支持SSH端口")
+					}
 				}
-				itemMap.SdkPort = sdkPort
+				itemMap.SdkPort = sdkPorts[i]
 			}
 
-			// 将配置项添加到输出映射
 			output[key] = itemMap
 		}
 	}
@@ -266,3 +234,55 @@ func (service *formatterService) Sort(ctx context.Context, mapData map[string]v1
 	}
 	return sortedOutput, nil
 }
+
+// 验证后端数据
+func (service *formatterService) ValidateBackendData(ctx context.Context, data map[string]v1.SendGameShieldBackend, hostId int) error {
+	// 获取配置限制
+	configCount, err := service.hostService.GetGameShieldConfig(ctx, hostId)
+	if err != nil {
+		return fmt.Errorf("获取配置限制失败: %w", err)
+	}
+
+	// 提取源机IP
+	sourceIPs := make(map[string]bool)
+	ruleEntriesCount := int64(0)
+	maxBandwidthCount := int64(0)
+
+	for _, item := range data {
+		// 计算规则条目数
+		ruleEntriesCount += int64(len(item.Addr))
+
+		// 计算源机IP数
+		for _, addr := range item.Addr {
+			parts := strings.Split(addr, ":")
+			if len(parts) > 0 {
+				sourceIPs[parts[0]] = true
+			}
+		}
+
+		// 计算最大带宽设置数
+		if item.MaxBandwidth != "" {
+			maxBandwidthCount++
+		}
+	}
+
+	// 验证源机数量
+	if int64(len(sourceIPs)) > configCount.SourceMachinesCount {
+		return fmt.Errorf("超出最大源机数量,当前配置允许%d个源机,合并后有%d个源机",
+			configCount.SourceMachinesCount, len(sourceIPs))
+	}
+
+	// 验证规则条目数
+	if ruleEntriesCount > configCount.RuleEntriesCount {
+		return fmt.Errorf("超出最大规则数量,当前配置允许%d个规则,合并后有%d个规则",
+			configCount.RuleEntriesCount, ruleEntriesCount)
+	}
+
+	// 验证最大带宽设置数
+	if maxBandwidthCount > configCount.MaxBandwidthCount {
+		return fmt.Errorf("超出最大带宽数量,当前配置允许%d个带宽设置,合并后有%d个",
+			configCount.MaxBandwidthCount, maxBandwidthCount)
+	}
+
+	return nil
+}

+ 57 - 4
internal/service/gameshieldbackend.go

@@ -172,22 +172,75 @@ func (s *gameShieldBackendService) AddGameShieldBackend(ctx context.Context, req
 	return res, nil
 }
 func (s *gameShieldBackendService) EditGameShieldBackend(ctx context.Context, req *v1.GameShieldBackendArrayRequest) (string, error) {
+	// 1. 获取当前所有数据库记录
+	currentData, err := s.gameShieldBackendRepository.GetGameShieldBackendByHostId(ctx, req.HostId)
+	if err != nil {
+		return "", fmt.Errorf("获取当前配置失败: %w", err)
+	}
+
+	// 2. 创建当前记录的副本用于模拟修改
+	simulatedData := make([]model.GameShieldBackend, len(currentData))
+	copy(simulatedData, currentData)
 
+	// 3. 创建ID到索引的映射,方便查找和修改
+	idToIndex := make(map[int64]int)
+	for i, item := range simulatedData {
+		idToIndex[int64(item.Id)] = i
+	}
+
+	// 4. 在副本上应用修改
 	for _, v := range req.Items {
 		if v.Id == 0 {
 			return "", fmt.Errorf("id 不能为空")
 		}
+
+		// 查找对应记录
+		idx, exists := idToIndex[int64(v.Id)]
+		if !exists {
+			return "", fmt.Errorf("ID为%d的记录不存在", v.Id)
+		}
+
+		// 更新记录(只更新需要修改的字段)
+		simulatedData[idx].SourceMachineIP = v.SourceMachineIP
+		simulatedData[idx].ConnectPort = v.ConnectPort
+		simulatedData[idx].Protocol = v.Protocol
+		simulatedData[idx].SdkPort = v.SdkPort
+		simulatedData[idx].SdkIp = v.SdkIp
+		simulatedData[idx].Type = v.Type
+		simulatedData[idx].MaxBandwidth = v.MaxBandwidth
+	}
+	// 5. 使用模拟修改后的数据进行验证
+	// 转换数据格式
+	simulatedBackend, err := s.formatter.OldFormat(ctx, &simulatedData)
+	if err != nil {
+		return "", fmt.Errorf("格式化模拟数据失败: %w", err)
+	}
+
+	// 验证修改后的配置
+	err = s.formatter.ValidateBackendData(ctx, simulatedBackend, req.HostId)
+	if err != nil {
+		return "", fmt.Errorf("验证失败: %w", err)
+	}
+
+	// 6. 验证通过,执行实际的数据库修改
+	for _, v := range req.Items {
 		if err := s.gameShieldBackendRepository.EditGameShieldBackend(ctx, &v); err != nil {
-			return "", err
+			return "", fmt.Errorf("修改数据失败(ID:%d): %w", v.Id, err)
 		}
 	}
-	res, _, err := s.GameShieldBackend(ctx, req)
+
+	// 7. 更新远程配置
+	res, _, err := s.GameShieldBackend(ctx, &v1.GameShieldBackendArrayRequest{
+		HostId: req.HostId,
+		Uid:    req.Uid,
+		Items:  nil,
+	})
 	if err != nil {
-		return "", err
+		return "", fmt.Errorf("更新配置失败: %w", err)
 	}
+
 	return res, nil
 }
-
 func (s *gameShieldBackendService) DeleteGameShieldBackend(ctx context.Context, req *v1.GameShieldBackendArrayRequest) (string, error) {
 	for _, v := range req.Items {
 		if err := s.gameShieldBackendRepository.DeleteGameShieldBackend(ctx, int64(v.Id)); err != nil {