Browse Source

feat(admin): 添加 WAF恢复和同步执行续费操作功能

- 新增 RecoverWaf 和 SyncExecuteRenewalActions 方法
- 实现批量获取全局限制和主机到期时间的功能
-优化 WAF 操作服务,支持恢复计划和执行续费操作- 更新 WAF 管理处理器,添加新的 API 路由
fusu 6 ngày trước cách đây
mục cha
commit
232363f162

+ 6 - 0
api/v1/admin/wagManage.go

@@ -24,3 +24,9 @@ type WafManageListRes struct {
 	ExpiredAt int64 `json:"expiredAt" form:"expiredAt" gorm:"column:expired_at;"`
 	NextDueDate int64 `json:"nextDueDate" form:"nextDueDate" gorm:"column:nextduedate"`
 }
+
+type RecoverWafRequest struct {
+	HostIds []int64 `json:"hostIds" form:"hostIds"`
+	Uid int64 `json:"uid" form:"uid"`
+
+}

+ 1 - 1
cmd/task/wire/wire.go

@@ -105,7 +105,7 @@ var serviceSet = wire.NewSet(
 	waf.NewZzybgpService,
 	admin2.NewWafLogService,
 	admin2.NewWafLogDataCleanService,
-	waf.NewWafOperationsService,
+	admin2.NewWafOperationsService,
 )
 
 // build App

+ 2 - 2
cmd/task/wire/wire_gen.go

@@ -91,7 +91,7 @@ func NewWire(viperViper *viper.Viper, logger *log.Logger) (*app.App, func(), err
 	webForwardingService := waf2.NewWebForwardingService(serviceService, requiredService, webForwardingRepository, crawlerService, parserService, wafFormatterService, aoDunService, rabbitMQ, gatewayipService, globalLimitRepository, cdnService, proxyService, sslCertService, websocketService, ccService, ccIpListService)
 	buildAudunService := waf2.NewBuildAudunService(serviceService, aoDunService, gatewayipRepository, hostService)
 	zzybgpService := waf2.NewZzybgpService(serviceService, gatewayipRepository, hostService, aoDunService)
-	wafOperationsService := waf2.NewWafOperationsService(serviceService, webForwardingRepository, tcpforwardingRepository, udpForWardingRepository, cdnService, hostRepository, globalLimitRepository, expiredRepository, gatewayipRepository, tcpforwardingService, udpForWardingService, webForwardingService, buildAudunService, zzybgpService)
+	wafOperationsService := admin2.NewWafOperationsService(serviceService, webForwardingRepository, tcpforwardingRepository, udpForWardingRepository, cdnService, hostRepository, globalLimitRepository, expiredRepository, gatewayipRepository, tcpforwardingService, udpForWardingService, webForwardingService, buildAudunService, zzybgpService)
 	wafTask := task.NewWafTask(webForwardingRepository, tcpforwardingRepository, udpForWardingRepository, cdnService, hostRepository, globalLimitRepository, expiredRepository, taskTask, gatewayipRepository, tcpforwardingService, udpForWardingService, webForwardingService, buildAudunService, zzybgpService, wafOperationsService)
 	taskServer := server.NewTaskServer(logger, userTask, gameShieldTask, wafTask)
 	jobJob := job.NewJob(transaction, logger, sidSid, rabbitMQ)
@@ -115,7 +115,7 @@ var jobSet = wire.NewSet(job.NewJob, job.NewUserJob, job.NewWhitelistJob, job.Ne
 
 var serverSet = wire.NewSet(server.NewTaskServer, server.NewJobServer)
 
-var serviceSet = wire.NewSet(service.NewService, service.NewAoDunService, gameShield.NewGameShieldService, service.NewCrawlerService, service.NewGameShieldPublicIpService, service.NewDuedateService, service.NewFormatterService, service.NewParserService, service.NewRequiredService, service.NewHostService, gameShield.NewGameShieldBackendService, service.NewGameShieldSdkIpService, service.NewGameShieldUserIpService, waf2.NewWafFormatterService, flexCdn2.NewCdnService, service.NewRequestService, waf2.NewTcpforwardingService, waf2.NewUdpForWardingService, waf2.NewWebForwardingService, flexCdn2.NewProxyService, flexCdn2.NewSslCertService, flexCdn2.NewWebsocketService, waf2.NewCcService, waf2.NewGatewayipService, service.NewLogService, waf2.NewCcIpListService, waf2.NewBuildAudunService, waf2.NewZzybgpService, admin2.NewWafLogService, admin2.NewWafLogDataCleanService, waf2.NewWafOperationsService)
+var serviceSet = wire.NewSet(service.NewService, service.NewAoDunService, gameShield.NewGameShieldService, service.NewCrawlerService, service.NewGameShieldPublicIpService, service.NewDuedateService, service.NewFormatterService, service.NewParserService, service.NewRequiredService, service.NewHostService, gameShield.NewGameShieldBackendService, service.NewGameShieldSdkIpService, service.NewGameShieldUserIpService, waf2.NewWafFormatterService, flexCdn2.NewCdnService, service.NewRequestService, waf2.NewTcpforwardingService, waf2.NewUdpForWardingService, waf2.NewWebForwardingService, flexCdn2.NewProxyService, flexCdn2.NewSslCertService, flexCdn2.NewWebsocketService, waf2.NewCcService, waf2.NewGatewayipService, service.NewLogService, waf2.NewCcIpListService, waf2.NewBuildAudunService, waf2.NewZzybgpService, admin2.NewWafLogService, admin2.NewWafLogDataCleanService, admin2.NewWafOperationsService)
 
 // build App
 func newApp(task2 *server.TaskServer,

+ 30 - 0
internal/handler/admin/wafmanage.go

@@ -39,3 +39,33 @@ func (h *WafManageHandler) GetWafManageList(ctx *gin.Context) {
 	}
 	v1.HandleSuccess(ctx, res)
 }
+
+func (h *WafManageHandler) RecoverWaf(ctx *gin.Context) {
+	var req adminApi.RecoverWafRequest
+	if err := ctx.ShouldBind(&req); err != nil {
+		v1.HandleError(ctx, http.StatusBadRequest, v1.ErrBadRequest, err.Error())
+		return
+	}
+	defaults.SetDefaults(&req)
+	err := h.wafManageService.RecoverWaf(ctx,req)
+	if err != nil {
+		v1.HandleError(ctx, http.StatusInternalServerError, err, err.Error())
+		return
+	}
+	v1.HandleSuccess(ctx, nil)
+}
+
+func (h *WafManageHandler) SyncExecuteRenewalActions(ctx *gin.Context) {
+	var req adminApi.RecoverWafRequest
+	if err := ctx.ShouldBind(&req); err != nil {
+		v1.HandleError(ctx, http.StatusBadRequest, v1.ErrBadRequest, err.Error())
+		return
+	}
+	defaults.SetDefaults(&req)
+	err := h.wafManageService.SyncExecuteRenewalActions(ctx,req)
+	if err != nil {
+		v1.HandleError(ctx, http.StatusInternalServerError, err, err.Error())
+		return
+	}
+	v1.HandleSuccess(ctx, nil)
+}

+ 6 - 1
internal/repository/api/waf/globallimit.go

@@ -17,8 +17,8 @@ type GlobalLimitRepository interface {
 	UpdateGlobalLimitByHostId(ctx context.Context, req *model.GlobalLimit) error
 	IsGlobalLimitExistByHostId(ctx context.Context, hostId int64) (bool, error)
 	GetGlobalLimitByHostId(ctx context.Context, hostId int64) (*model.GlobalLimit, error)
+	GetGlobalLimitsByHostIds(ctx context.Context, hostIds []int64) (*[]model.GlobalLimit, error)
 	GetGlobalLimitAllExpired(ctx context.Context,ids []int) ([]v1.GlobalLimitExpiredByHost, error)
-	GetGlobalLimitAllHostId(ctx context.Context) ([]v1.GlobalLimitExpired, error)
 	GetGlobalLimitFirst(ctx context.Context,uid int64) (*model.GlobalLimit, error)
 	GetUserInfo(ctx context.Context, uid int64) (v1.UserInfo, error)
 	GetHostName(ctx context.Context,hostId int64) (string, error)
@@ -86,6 +86,11 @@ func (r *globalLimitRepository) GetGlobalLimitByHostId(ctx context.Context, host
 
 }
 
+func (r *globalLimitRepository) GetGlobalLimitsByHostIds(ctx context.Context, hostIds []int64) (*[]model.GlobalLimit, error) {
+	var res []model.GlobalLimit
+	return &res, r.DB(ctx).Where("host_id IN ?", hostIds).Find(&res).Error
+}
+
 
 func (r *globalLimitRepository) GetGlobalLimitAllExpired(ctx context.Context,ids []int) ([]v1.GlobalLimitExpiredByHost, error) {
 	var res []v1.GlobalLimitExpiredByHost

+ 2 - 2
internal/repository/host.go

@@ -20,7 +20,7 @@ type HostRepository interface {
 	// 获取到期时间区间
 	GetExpireTimeRange(ctx context.Context,startTime int64,endTime int64) (int64, error)
 	// 获取指定套餐的到期时间
-	GetExpireTimeByHostId(ctx context.Context, hostIds []int) ([]v1.GetAlmostExpireHostResponse, error)
+	GetExpireTimeByHostId(ctx context.Context, hostIds []int64) ([]v1.GetAlmostExpireHostResponse, error)
 }
 
 func NewHostRepository(
@@ -119,7 +119,7 @@ func (r *hostRepository) GetExpireTimeRange(ctx context.Context,startTime int64,
 }
 
 // 获取指定套餐的到期时间
-func (r *hostRepository) GetExpireTimeByHostId(ctx context.Context, hostIds []int) ([]v1.GetAlmostExpireHostResponse, error) {
+func (r *hostRepository) GetExpireTimeByHostId(ctx context.Context, hostIds []int64) ([]v1.GetAlmostExpireHostResponse, error) {
 	var res []v1.GetAlmostExpireHostResponse
 	if err := r.DB(ctx).Table("shd_host").
 		Where("id IN ?", hostIds).

+ 44 - 0
internal/service/admin/wafmanage.go

@@ -4,28 +4,72 @@ import (
 	"context"
 	v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
 	adminApi "github.com/go-nunu/nunu-layout-advanced/api/v1/admin"
+	"github.com/go-nunu/nunu-layout-advanced/internal/repository"
 	"github.com/go-nunu/nunu-layout-advanced/internal/repository/admin"
+	wafRep "github.com/go-nunu/nunu-layout-advanced/internal/repository/api/waf"
 	"github.com/go-nunu/nunu-layout-advanced/internal/service"
 )
 
 type WafManageService interface {
 	GetWafManageList(ctx context.Context,req adminApi.WafManageList) (*v1.PaginatedResponse[adminApi.WafManageListRes], error)
+	RecoverWaf(ctx context.Context,req adminApi.RecoverWafRequest) error
+	SyncExecuteRenewalActions(ctx context.Context,req adminApi.RecoverWafRequest) error
 }
 func NewWafManageService(
     service *service.Service,
     wafManageRepository admin.WafManageRepository,
+	globalLimitRep wafRep.GlobalLimitRepository,
+	wafOperations WafOperationsService,
+	hostRep repository.HostRepository,
 ) WafManageService {
 	return &wafManageService{
 		Service:        service,
 		wafManageRepository: wafManageRepository,
+		globalLimitRep: globalLimitRep,
+		wafOperations: wafOperations,
+		hostRep: hostRep,
 	}
 }
 
 type wafManageService struct {
 	*service.Service
 	wafManageRepository admin.WafManageRepository
+	globalLimitRep wafRep.GlobalLimitRepository
+	wafOperations  WafOperationsService
+	hostRep        repository.HostRepository
 }
 
 func (s *wafManageService) GetWafManageList(ctx context.Context,req adminApi.WafManageList) (*v1.PaginatedResponse[adminApi.WafManageListRes], error) {
 	return s.wafManageRepository.GetWafManageList(ctx, req)
 }
+
+func (s *wafManageService) RecoverWaf(ctx context.Context,req adminApi.RecoverWafRequest) error {
+	wafModels, err := s.globalLimitRep.GetGlobalLimitsByHostIds(ctx, req.HostIds)
+	if err != nil {
+		return err
+	}
+	err = s.wafOperations.RecoverPlans(ctx, *wafModels, "closed")
+	if err != nil {
+		return err
+	}
+	return nil
+}
+
+func (s *wafManageService) SyncExecuteRenewalActions(ctx context.Context,req adminApi.RecoverWafRequest) error {
+	wafData, err := s.hostRep.GetExpireTimeByHostId(ctx, req.HostIds)
+	if err != nil {
+		return err
+	}
+	var renewalRequest []RenewalRequest
+	for i := range wafData {
+		renewalRequest = append(renewalRequest, RenewalRequest{
+			HostId:    wafData[i].HostId,
+			ExpiredAt: wafData[i].ExpiredAt,
+		})
+	}
+	err = s.wafOperations.ExecuteRenewalActions(ctx, renewalRequest)
+	if err != nil {
+		return err
+	}
+	return nil
+}

+ 15 - 14
internal/service/api/waf/wafoperations.go → internal/service/admin/wafoperations.go

@@ -1,4 +1,4 @@
-package waf
+package admin
 
 import (
 	"context"
@@ -9,6 +9,7 @@ import (
 	waf2 "github.com/go-nunu/nunu-layout-advanced/internal/repository/api/waf"
 	"github.com/go-nunu/nunu-layout-advanced/internal/service"
 	"github.com/go-nunu/nunu-layout-advanced/internal/service/api/flexCdn"
+	"github.com/go-nunu/nunu-layout-advanced/internal/service/api/waf"
 	"github.com/hashicorp/go-multierror"
 	"go.uber.org/zap"
 
@@ -45,11 +46,11 @@ func NewWafOperationsService(
 	globalLimitRep waf2.GlobalLimitRepository,
 	expiredRep repository.ExpiredRepository,
 	gatewayIpRep waf2.GatewayipRepository,
-	tcp TcpforwardingService,
-	udp UdpForWardingService,
-	web WebForwardingService,
-	buildAoDun BuildAudunService,
-	zzyBgp ZzybgpService,
+	tcp waf.TcpforwardingService,
+	udp waf.UdpForWardingService,
+	web waf.WebForwardingService,
+	buildAoDun waf.BuildAudunService,
+	zzyBgp waf.ZzybgpService,
 ) WafOperationsService {
 	return &wafOperationsService{
 		Service:          service,
@@ -78,12 +79,12 @@ type wafOperationsService struct {
 	hostRep          repository.HostRepository
 	globalLimitRep   waf2.GlobalLimitRepository
 	expiredRep       repository.ExpiredRepository
-	gatewayIpRep     waf2.GatewayipRepository
-	tcp              TcpforwardingService
-	udp              UdpForWardingService
-	web              WebForwardingService
-	buildAoDun       BuildAudunService
-	zzyBgp           ZzybgpService
+	gatewayIpRep waf2.GatewayipRepository
+	tcp          waf.TcpforwardingService
+	udp          waf.UdpForWardingService
+	web          waf.WebForwardingService
+	buildAoDun   waf.BuildAudunService
+	zzyBgp       waf.ZzybgpService
 }
 
 // GetForwardingRuleIds 获取主机关联的所有转发规则ID
@@ -402,9 +403,9 @@ func (s *wafOperationsService) RecoverPlans(ctx context.Context, limits []model.
 	}
 
 	// 1. 检查哪些套餐需要恢复(已续费且未过期)
-	var hostIdsToCheck []int
+	var hostIdsToCheck []int64
 	for _, limit := range limits {
-		hostIdsToCheck = append(hostIdsToCheck, limit.HostId)
+		hostIdsToCheck = append(hostIdsToCheck, int64(limit.HostId))
 	}
 
 	// 2. 获取最新的主机到期时间

+ 9 - 8
internal/task/waf.go

@@ -6,6 +6,7 @@ import (
 	"github.com/go-nunu/nunu-layout-advanced/internal/model"
 	"github.com/go-nunu/nunu-layout-advanced/internal/repository"
 	waf2 "github.com/go-nunu/nunu-layout-advanced/internal/repository/api/waf"
+	"github.com/go-nunu/nunu-layout-advanced/internal/service/admin"
 	"github.com/go-nunu/nunu-layout-advanced/internal/service/api/flexCdn"
 	"github.com/go-nunu/nunu-layout-advanced/internal/service/api/waf"
 	"github.com/hashicorp/go-multierror"
@@ -47,7 +48,7 @@ func NewWafTask(
 	web waf.WebForwardingService,
 	buildAoDun waf.BuildAudunService,
 	zzyBgp waf.ZzybgpService,
-	wafOps waf.WafOperationsService,
+	wafOps admin.WafOperationsService,
 ) WafTask {
 	return &wafTask{
 		Task:              task,
@@ -82,8 +83,8 @@ type wafTask struct {
 	udp              waf.UdpForWardingService
 	web              waf.WebForwardingService
 	buildAoDun       waf.BuildAudunService
-	zzyBgp           waf.ZzybgpService
-	wafOps           waf.WafOperationsService
+	zzyBgp waf.ZzybgpService
+	wafOps admin.WafOperationsService
 }
 
 const (
@@ -92,7 +93,7 @@ const (
 )
 
 // RenewalRequest 现在使用service层的定义
-type RenewalRequest = waf.RenewalRequest
+type RenewalRequest = admin.RenewalRequest
 
 // =================================================================
 // =================== 核心辅助函数 (Core Helpers) =================
@@ -132,9 +133,9 @@ func (t *wafTask) findPlansNeedingSync(ctx context.Context, wafLimits []model.Gl
 	}
 	wafExpiredMap := make(map[int]int64, len(wafLimits))
 
-	var hostIds []int
+	var hostIds []int64
 	for _, limit := range wafLimits {
-		hostIds = append(hostIds, limit.HostId)
+		hostIds = append(hostIds, int64(limit.HostId))
 		wafExpiredMap[limit.HostId] = limit.ExpiredAt
 
 	}
@@ -329,9 +330,9 @@ func (t *wafTask) CleanUpStaleRecords(ctx context.Context) error {
 	}
 
 	// 3. [性能优化] 批量获取未清理记录的真实到期时间
-	uncleanedHostIds := make([]int, len(uncleanedLimits))
+	uncleanedHostIds := make([]int64, len(uncleanedLimits))
 	for i, limit := range uncleanedLimits {
-		uncleanedHostIds[i] = limit.HostId
+		uncleanedHostIds[i] = int64(limit.HostId)
 	}
 	hostExpirations, err := t.hostRep.GetExpireTimeByHostId(ctx, uncleanedHostIds)
 	if err != nil {