Browse Source

feat(web-forwarding): 添加获取所有转发规则接口

- 新增 /webForward/getList 接口,用于获取所有转发规则
- 实现 GetWebForwardingList 方法,通过 errgroup 并发处理每个规则
- 添加 GetWebForwardingWafWebAllIds 方法,获取所有转发规则的 ID
-优化 GetWebForwardingWafWebAllIps 方法,处理并发和错误
fusu 1 tháng trước cách đây
mục cha
commit
e2fb5feb2e

+ 2 - 0
api/v1/webForwarding.go

@@ -89,3 +89,5 @@ type BackendList struct {
 	Timeout  string `json:"timeout,omitempty" form:"timeout" default:"30s"`
 	ProxyV1  bool `json:"proxy_v1,omitempty" form:"proxy_v1" default:"false"`
 }
+
+

+ 16 - 0
internal/handler/webforwarding.go

@@ -85,3 +85,19 @@ func (h *WebForwardingHandler) DeleteWebForwarding(ctx *gin.Context) {
 	}
 	v1.HandleSuccess(ctx, nil)
 }
+
+func (h *WebForwardingHandler) GetWebForwardingList(ctx *gin.Context) {
+	req := new(v1.GetForwardingRequest)
+	if err := ctx.ShouldBind(req); err != nil {
+		v1.HandleError(ctx, http.StatusBadRequest, v1.ErrBadRequest, err.Error())
+		return
+	}
+	defaults.SetDefaults(req)
+	res, err := h.webForwardingService.GetWebForwardingWafWebAllIps(ctx, *req)
+	if err != nil {
+		v1.HandleError(ctx, http.StatusInternalServerError, err, err.Error())
+		return
+	}
+	v1.HandleSuccess(ctx, res)
+
+}

+ 10 - 0
internal/repository/webforwarding.go

@@ -20,6 +20,7 @@ type WebForwardingRepository interface {
 	GetWebForwardingWafWebIdById(ctx context.Context, id int) (int, error)
 	GetWebForwardingPortCountByHostId(ctx context.Context, hostId int) (int64, error)
 	GetWebForwardingDomainCountByHostId(ctx context.Context, hostId int) (int64, []string, error)
+	GetWebForwardingWafWebAllIds(ctx context.Context, hostId int) ([]int, error)
 	AddWebForwardingIps(ctx context.Context, req model.WebForwardingRule) (primitive.ObjectID, error)
 	EditWebForwardingIps(ctx context.Context, req model.WebForwardingRule) error
 	GetWebForwardingIpsByID(ctx context.Context, webId int) (*model.WebForwardingRule, error)
@@ -99,6 +100,15 @@ func (r *webForwardingRepository) GetWebForwardingDomainCountByHostId(ctx contex
 	return count, distinctDomains, nil
 }
 
+func (r *webForwardingRepository) GetWebForwardingWafWebAllIds(ctx context.Context, hostId int) ([]int, error) {
+	var ids []int
+	if err := r.db.Model(&model.WebForwarding{}).WithContext(ctx).Where("host_id = ?", hostId).Select("id").Find(&ids).Error; err != nil {
+		return nil, err
+	}
+
+	return ids, nil
+}
+
 
 // mongodb 插入
 func (r *webForwardingRepository) AddWebForwardingIps(ctx context.Context, req model.WebForwardingRule) (primitive.ObjectID, error) {

+ 1 - 0
internal/server/http.go

@@ -92,6 +92,7 @@ func NewHTTPServer(
 			noAuthRouter.POST("/gameShield/getOnline", ipAllowlistMiddleware, gameShieldHandler.GetGameShieldOnlineList)
 			noAuthRouter.POST("/gameShield/IsExistKey", gameShieldHandler.IsExistGameShieldKey)
 			noAuthRouter.POST("/webForward/get", ipAllowlistMiddleware, webForwardingHandler.GetWebForwarding)
+			noAuthRouter.POST("/webForward/getList", ipAllowlistMiddleware, webForwardingHandler.GetWebForwardingList)
 			noAuthRouter.POST("/webForward/add", ipAllowlistMiddleware, webForwardingHandler.AddWebForwarding)
 			noAuthRouter.POST("/webForward/edit", ipAllowlistMiddleware, webForwardingHandler.EditWebForwarding)
 			noAuthRouter.POST("/webForward/delete", ipAllowlistMiddleware, webForwardingHandler.DeleteWebForwarding)

+ 138 - 0
internal/service/webforwarding.go

@@ -14,6 +14,7 @@ import (
 
 type WebForwardingService interface {
 	GetWebForwarding(ctx context.Context, req v1.GetForwardingRequest) (v1.WebForwardingDataRequest, error)
+	GetWebForwardingWafWebAllIps(ctx context.Context, req v1.GetForwardingRequest) ([]v1.WebForwardingDataRequest, error)
 	AddWebForwarding(ctx context.Context, req *v1.WebForwardingRequest) error
 	EditWebForwarding(ctx context.Context, req *v1.WebForwardingRequest) error
 	DeleteWebForwarding(ctx context.Context, Ids []int) error
@@ -357,3 +358,140 @@ func (s *webForwardingService) DeleteWebForwarding(ctx context.Context, Ids []in
 
 	return nil
 }
+
+func (s *webForwardingService) GetWebForwardingWafWebAllIps(ctx context.Context, req v1.GetForwardingRequest) ([]v1.WebForwardingDataRequest, error) {
+	type CombinedResult struct {
+		Id          int
+		Forwarding  *model.WebForwarding
+		BackendRule *model.WebForwardingRule
+		Err         error // 如果此ID的处理出错,则携带错误
+	}
+	g, gCtx := errgroup.WithContext(ctx)
+	ids, err := s.webForwardingRepository.GetWebForwardingWafWebAllIds(gCtx, req.HostId)
+	if err != nil {
+		return nil, fmt.Errorf("GetWebForwardingWafWebAllIds failed: %w", err)
+	}
+
+	if len(ids) == 0 {
+		return nil, nil // 没有ID,直接返回空切片
+	}
+
+	// 创建一个通道来接收每个ID的处理结果
+	// 通道缓冲区大小设为ID数量,这样发送者不会因为接收者慢而阻塞(在所有goroutine都启动后)
+	resultsChan := make(chan CombinedResult, len(ids))
+
+	for _, idVal := range ids {
+		currentID := idVal // 捕获循环变量
+		g.Go(func() error {
+			var wf *model.WebForwarding
+			var bk *model.WebForwardingRule
+			var localErr error
+
+			// 1. 获取 WebForwarding 信息
+			wf, localErr = s.webForwardingRepository.GetWebForwarding(gCtx, int64(currentID))
+			if localErr != nil {
+				// 发送错误到通道,并由 errgroup 捕获
+				// errgroup 会处理第一个非nil错误,并取消其他 goroutine
+				resultsChan <- CombinedResult{Id: currentID, Err: fmt.Errorf("GetWebForwarding for id %d failed: %w", currentID, localErr)}
+				return localErr // 返回错误给 errgroup
+			}
+			if wf == nil { // 正常情况下,如果没错误,wf不应为nil,但防御性检查
+				localErr = fmt.Errorf("GetWebForwarding for id %d returned nil data without error", currentID)
+				resultsChan <- CombinedResult{Id: currentID, Err: localErr}
+				return localErr
+			}
+
+
+			// 2. 获取 Backend IP 信息
+			// 注意:这里我们允许 GetWebForwardingIpsByID 可能返回 nil 数据(例如没有规则)而不是错误
+			// 如果它也可能返回错误,则处理方式与上面类似
+			bk, localErr = s.webForwardingRepository.GetWebForwardingIpsByID(gCtx, currentID)
+			if localErr != nil {
+				// 如果获取IP信息失败是一个致命错误,则也应返回错误
+				// 如果允许部分成功(比如有WebForwarding但没有IP信息),则可以不将此视为errgroup的错误
+				// 这里假设它也是一个需要errgroup捕获的错误
+				resultsChan <- CombinedResult{Id: currentID, Forwarding: wf, Err: fmt.Errorf("GetWebForwardingIpsByID for id %d failed: %w", currentID, localErr)}
+				return localErr // 返回错误给 errgroup
+			}
+			// bk 可能是 nil 如果没有错误且没有规则,这取决于业务逻辑
+
+			// 发送成功的结果到通道
+			resultsChan <- CombinedResult{Id: currentID, Forwarding: wf, BackendRule: bk}
+			return nil // 此goroutine成功
+		})
+	}
+
+	// 等待所有goroutine完成
+	groupErr := g.Wait()
+
+	// 关闭通道,表示所有发送者都已完成
+	// 这一步很重要,这样下面的 range 循环才能正常结束
+	close(resultsChan)
+
+	// 如果 errgroup 捕获到任何错误,优先返回该错误
+	if groupErr != nil {
+		// 虽然errgroup已经出错了,但通道中可能已经有一些结果(来自出错前成功或出错的goroutine)
+		// 我们需要排空通道以避免goroutine泄漏(如果它们在发送时阻塞)
+		// 但由于我们优先返回groupErr,这些结果将被丢弃。
+		// 在这种设计下,通常任何一个子任务失败都会导致整个操作失败。
+		return nil, groupErr
+	}
+
+	// 如果没有错误,收集所有成功的结果
+	finalResults := make([]v1.WebForwardingDataRequest, 0, len(ids))
+	for res := range resultsChan {
+		// 再次检查通道中的错误,尽管 errgroup 应该已经捕获了
+		// 但这是一种更细致的错误处理,以防万一有goroutine在errgroup.Wait()前发送了错误但未被errgroup捕获
+		// (理论上,如果goroutine返回了错误,errgroup会处理)
+		// 主要目的是处理 res.forwarding 为 nil 的情况 (如果上面允许不返回错误)
+		if res.Err != nil {
+			// 如果到这里还有错误,说明逻辑可能有问题,或者我们决定忽略某些类型的子错误
+			// 在此示例中,因为 g.Wait() 没有错误,所以这里的 res.err 应该是nil
+			// 如果不是,那么可能是goroutine在return nil前发送了带有错误的res。
+			// 严格来说,如果errgroup没有错误,这里res.err也应该是nil
+			// 但以防万一,我们可以记录日志
+			return nil, fmt.Errorf("received error from goroutine for ID %d: %w", res.Id, res.Err)
+		}
+		if res.Forwarding == nil {
+			return nil, fmt.Errorf("received nil forwarding from goroutine for ID %d", res.Id)
+		}
+
+
+		dataReq := v1.WebForwardingDataRequest{
+			Id:                  res.Forwarding.Id,
+			WafWebId:            res.Forwarding.WafWebId,
+			Tag:                 res.Forwarding.Tag,
+			Port:                res.Forwarding.Port,
+			Domain:              res.Forwarding.Domain,
+			CustomHost:          res.Forwarding.CustomHost,
+			WafWebLimitId:       res.Forwarding.WebLimitRuleId,
+			WafGatewayGroupId:   res.Forwarding.WafGatewayGroupId,
+			CcCount:             res.Forwarding.CcCount,
+			CcDuration:          res.Forwarding.CcDuration,
+			CcBlockCount:        res.Forwarding.CcBlockCount,
+			CcBlockDuration:     res.Forwarding.CcBlockDuration,
+			Cc4xxCount:          res.Forwarding.Cc4xxCount,
+			Cc4xxDuration:       res.Forwarding.Cc4xxDuration,
+			Cc4xxBlockCount:     res.Forwarding.Cc4xxBlockCount,
+			Cc4xxBlockDuration:  res.Forwarding.Cc4xxBlockDuration,
+			Cc5xxCount:          res.Forwarding.Cc5xxCount,
+			Cc5xxDuration:       res.Forwarding.Cc5xxDuration,
+			Cc5xxBlockCount:     res.Forwarding.Cc5xxBlockCount,
+			Cc5xxBlockDuration:  res.Forwarding.Cc5xxBlockDuration,
+			IsHttps:             res.Forwarding.IsHttps,
+			Comment:             res.Forwarding.Comment,
+			HttpsKey:            res.Forwarding.HttpsKey,
+			HttpsCert:           res.Forwarding.HttpsCert,
+		}
+
+		if res.BackendRule != nil { // 只有当 BackendRule 存在时才填充相关字段
+			dataReq.BackendList = res.BackendRule.BackendList
+			dataReq.AllowIpList = res.BackendRule.AllowIpList
+			dataReq.DenyIpList = res.BackendRule.DenyIpList
+			dataReq.AccessRule = res.BackendRule.AccessRule
+		}
+		finalResults = append(finalResults, dataReq)
+	}
+
+	return finalResults, nil
+}