Преглед изворни кода

fix(ip): 优化 IP 查询逻辑和计数功能

- 修改了 TCP、UDP 和 Web 转发规则中的 IP 查询逻辑
- 增加了对已存在 IP 数量的检查和计数
- 优化了聚合查询的管道步骤,提高了查询效率
- 调整了任务调度的时间格式,增加了灵活性
fusu пре 2 недеља
родитељ
комит
f7606b84b7

+ 1 - 1
api/v1/allowAndDenyIp.go

@@ -15,6 +15,6 @@ type DelAllowAndDenyIpRequest struct {
 }
 
 type IpCountResult struct {
-	Ip    string `bson:"_id"`   // MongoDB $group 的结果会放在 _id 字段
+	Ip    string `bson:"ip"`   // MongoDB $group 的结果会放在 _id 字段
 	Count int    `bson:"count"`
 }

+ 11 - 0
internal/repository/gatewayip.go

@@ -78,6 +78,17 @@ func (r *gatewayipRepository) GetIpWhereHostIdNull(ctx context.Context,req v1.Gl
 		return fmt.Errorf("主机ID错误, 请联系客服")
 	}
 
+	var count int64
+	err := r.DB(ctx).Model(&model.Gatewayip{}).Where("host_id = ?", req.HostId).Count(&count).Error
+	if err != nil {
+		return err
+	}
+	if count >= int64(req.IpCount) {
+		return nil
+	}
+
+	req.IpCount = int(int64(req.IpCount) - count)
+
 	// 使用事务保证操作的原子性
 	return r.DB(ctx).Transaction(func(tx *gorm.DB) error {
 		var idsToAssign []uint // 只需一个切片来接收ID

+ 27 - 8
internal/repository/tcpforwarding.go

@@ -195,36 +195,55 @@ func (r *tcpforwardingRepository) DeleteTcpForwardingIpsById(ctx context.Context
 
 
 // 获取IP数量等于1的IP
-func (r *tcpforwardingRepository) GetIpCountByIp(ctx context.Context,ips []string) ([]v1.IpCountResult, error) {
+func (r *tcpforwardingRepository) GetIpCountByIp(ctx context.Context, ips []string) ([]v1.IpCountResult, error) {
 	if len(ips) == 0 {
 		return []v1.IpCountResult{}, nil
 	}
+
 	pipeline := []bson.M{
+		// 1. 展开 backend_list 数组。此时 backend_list 字段会变成 "ip:port" 字符串。
+		{
+			"$unwind": "$backend_list",
+		},
+		// 2. 添加新字段 extracted_ip,存放从 "ip:port" 中解析出的 IP。
+		{
+			"$addFields": bson.M{
+				"extracted_ip": bson.M{
+					"$arrayElemAt": []interface{}{
+						// 直接在 backend_list 字符串上分割
+						bson.M{"$split": []string{"$backend_list", ":"}},
+						0,
+					},
+				},
+			},
+		},
+		// 3. 匹配我们关心的 IP
 		{
 			"$match": bson.M{
-				"ip": bson.M{"$in": ips},
+				"extracted_ip": bson.M{"$in": ips},
 			},
 		},
+		// 4. 按解析出的 IP 地址进行分组和计数
 		{
 			"$group": bson.M{
-				"_id":   "$ip",
+				"_id":   "$extracted_ip",
 				"count": bson.M{"$sum": 1},
 			},
 		},
+		// 5. 格式化最终输出
 		{
 			"$project": bson.M{
-				"_id":   0,       // 不输出默认的_id
-				"ip":    "$_id",  // 将分组的_id字段重命名为ip
-				"count": 1,       // 保留count字段
+				"_id":   0,
+				"ip":    "$_id",
+				"count": 1,
 			},
 		},
 	}
 
 	var results []v1.IpCountResult
-	// 使用 qmgo 执行聚合查询
 	err := r.mongoDB.Collection("tcp_forwarding_rules").Aggregate(ctx, pipeline).All(&results)
 	if err != nil {
-		return nil, err
+		return nil, fmt.Errorf("聚合查询 tcp_forwarding_rules 失败: %w", err)
 	}
 	return results, nil
 }

+ 21 - 7
internal/repository/udpforwarding.go

@@ -193,32 +193,46 @@ func (r *udpForWardingRepository) GetIpCountByIp(ctx context.Context, ips []stri
 	if len(ips) == 0 {
 		return []v1.IpCountResult{}, nil
 	}
+
+	// 管道逻辑与 TCP 版本完全相同
 	pipeline := []bson.M{
+		{
+			"$unwind": "$backend_list",
+		},
+		{
+			"$addFields": bson.M{
+				"extracted_ip": bson.M{
+					"$arrayElemAt": []interface{}{
+						bson.M{"$split": []string{"$backend_list", ":"}},
+						0,
+					},
+				},
+			},
+		},
 		{
 			"$match": bson.M{
-				"ip": bson.M{"$in": ips},
+				"extracted_ip": bson.M{"$in": ips},
 			},
 		},
 		{
 			"$group": bson.M{
-				"_id":   "$ip",
+				"_id":   "$extracted_ip",
 				"count": bson.M{"$sum": 1},
 			},
 		},
 		{
 			"$project": bson.M{
-				"_id":   0,       // 不输出默认的_id
-				"ip":    "$_id",  // 将分组的_id字段重命名为ip
-				"count": 1,       // 保留count字段
+				"_id":   0,
+				"ip":    "$_id",
+				"count": 1,
 			},
 		},
 	}
 
 	var results []v1.IpCountResult
-	// 使用 qmgo 执行聚合查询
 	err := r.mongoDB.Collection("udp_forwarding_rules").Aggregate(ctx, pipeline).All(&results)
 	if err != nil {
-		return nil, err
+		return nil, fmt.Errorf("聚合查询 udp_forwarding_rules 失败: %w", err)
 	}
 	return results, nil
 }

+ 36 - 6
internal/repository/webforwarding.go

@@ -238,27 +238,56 @@ func (r *webForwardingRepository) GetDomainCount(ctx context.Context, hostId int
 }
 
 // 获取IP数量等于1的IP
+
 func (r *webForwardingRepository) GetIpCountByIp(ctx context.Context, ips []string) ([]v1.IpCountResult, error) {
 	if len(ips) == 0 {
 		return []v1.IpCountResult{}, nil
 	}
+
 	pipeline := []bson.M{
+		// 第 1 步: $unwind - 展开 backend_list 数组
+		// 将包含多个 backend 对象的文档拆分成多条,每条只包含一个 backend 对象。
+		{
+			"$unwind": "$backend_list",
+		},
+
+		// 第 2 步: $addFields - 添加一个新字段用于存放解析出的 IP
+		// 我们需要从 "ip:port" 格式的 addr 字段中把 ip 提取出来。
+		// 使用 $split 操作符按 ":" 分割字符串,然后用 $arrayElemAt 取第一个元素。
+		{
+			"$addFields": bson.M{
+				"extracted_ip": bson.M{
+					"$arrayElemAt": []interface{}{
+						bson.M{"$split": []string{"$backend_list.addr", ":"}},
+						0,
+					},
+				},
+			},
+		},
+
+		// 第 3 步: $match - 匹配我们关心的 IP
+		// 在上一步创建的 extracted_ip 字段上进行匹配。
 		{
 			"$match": bson.M{
-				"ip": bson.M{"$in": ips},
+				"extracted_ip": bson.M{"$in": ips},
 			},
 		},
+
+		// 第 4 步: $group - 按解析出的 IP 地址进行分组和计数
 		{
 			"$group": bson.M{
-				"_id":   "$ip",
+				"_id":   "$extracted_ip", // 使用我们新创建的 extracted_ip 字段作为分组依据
 				"count": bson.M{"$sum": 1},
 			},
 		},
+
+		// 第 5 步: $project - 格式化最终输出
+		// 这个阶段和之前一样,只是为了让输出结果更清晰,并匹配 Go 结构体。
 		{
 			"$project": bson.M{
-				"_id":   0,      // 不输出默认的_id
-				"ip":    "$_id", // 将分组的_id字段重命名为ip
-				"count": 1,      // 保留count字段
+				"_id":   0,
+				"ip":    "$_id",
+				"count": 1,
 			},
 		},
 	}
@@ -267,7 +296,8 @@ func (r *webForwardingRepository) GetIpCountByIp(ctx context.Context, ips []stri
 	// 使用 qmgo 执行聚合查询
 	err := r.mongoDB.Collection("web_forwarding_rules").Aggregate(ctx, pipeline).All(&results)
 	if err != nil {
-		return nil, fmt.Errorf("聚合查询失败: %w", err)
+		// 加上错误包装,方便调试
+		return nil, fmt.Errorf("聚合查询 web_forwarding_rules 失败: %w", err)
 	}
 
 	return results, nil

+ 4 - 4
internal/server/task.go

@@ -76,7 +76,7 @@ func (t *TaskServer) Start(ctx context.Context) error {
 
 
 
-	_, err := t.scheduler.Cron("1 * * * *").Do(func() {
+	_, err := t.scheduler.Cron("* 1 * * *").Do(func() {
 		err := t.wafTask.SynchronizationTime(ctx)
 		if err != nil {
 			t.log.Error("同步到期时间失败", zap.Error(err))
@@ -86,7 +86,7 @@ func (t *TaskServer) Start(ctx context.Context) error {
 		t.log.Error("同步到期时间注册任务失败", zap.Error(err))
 	}
 
-	_, err = t.scheduler.Cron("1 * * * *").Do(func() {
+	_, err = t.scheduler.Cron("* 1 * * *").Do(func() {
 		err := t.wafTask.StopPlan(ctx)
 		if err != nil {
 			t.log.Error("停止套餐失败", zap.Error(err))
@@ -97,7 +97,7 @@ func (t *TaskServer) Start(ctx context.Context) error {
 	}
 
 
-	_, err = t.scheduler.Cron("1 * * * *").Do(func() {
+	_, err = t.scheduler.Cron("* 1 * * *").Do(func() {
 		err := t.wafTask.RecoverRecentPlan(ctx)
 		if err != nil {
 			t.log.Error("续费失败", zap.Error(err))
@@ -108,7 +108,7 @@ func (t *TaskServer) Start(ctx context.Context) error {
 	}
 
 
-	_, err = t.scheduler.Cron("1 * * * *").Do(func() {
+	_, err = t.scheduler.Cron("* 1 * * *").Do(func() {
 		err := t.wafTask.CleanUpStaleRecords(ctx)
 		if err != nil {
 			t.log.Error("续费失败", zap.Error(err))