Procházet zdrojové kódy

refactor(waf): 优化 WAF 服务到期处理逻辑

- 重构了 WafTask 接口,新增了 SynchronizationTime、StopPlan 和 RecoverStopPlan 方法- 修改了 CcRepository 接口,将 EditCcState 方法重命名为 GetCcId,以更准确地反映其功能
- 更新了 HostRepository 接口,新增了 GetExpireTimeRange 和 GetExpireTimeByHostId 方法
- 优化了全局到期时间处理逻辑,提高了数据一致性和错误处理能力
- 新增了到期时间同步任务,用于定期同步 WAF 和 Host 之间的到期时间差异
- 实现了到期套餐关闭和恢复功能,以自动处理到期和续费情况
fusu před 3 týdny
rodič
revize
2bebcdb5c6

+ 7 - 6
internal/repository/cc.go

@@ -7,7 +7,7 @@ import (
 
 type CcRepository interface {
 	GetCcList(ctx context.Context, serviceId int64) ([]v1.CCList, error)
-	EditCcState(ctx context.Context, serviceId int64, ip string) error
+	GetCcId(ctx context.Context, serviceId int64, ip string) (int64, error)
 }
 
 func NewCcRepository(
@@ -30,9 +30,10 @@ func (r *ccRepository) GetCcList(ctx context.Context, serviceId int64) ([]v1.CCL
 	return req, nil
 }
 
-func (r *ccRepository) EditCcState(ctx context.Context, serviceId int64, ip string) error {
-	if err := r.DBWithName(ctx,"cdn").Table("cloud_ip_items").Where("sourceServerId = ? AND value = ?", serviceId, ip).Update("state", 0).Error; err != nil {
-		return err
+func (r *ccRepository) GetCcId(ctx context.Context, serviceId int64, ip string) (int64, error) {
+	var req int64
+	if err := r.DBWithName(ctx,"cdn").Table("cloud_ip_items").Where("sourceServerId = ? AND value = ? AND state = 1", serviceId, ip).Select("id").Scan(&req).Error; err != nil {
+		return 0, err
 	}
-	return nil
-}
+	return req, nil
+}

+ 1 - 1
internal/repository/globallimit.go

@@ -177,7 +177,7 @@ func (r *globalLimitRepository) GetGlobalLimitAlmostExpired(ctx context.Context,
 	var res []model.GlobalLimit
 	expiredTime := time.Now().Unix() + addTime
 	if err := r.DB(ctx).
-		Where("nextduedate < ?", expiredTime).
+		Where("expired_At < ?", expiredTime).
 		Find(&res).Error; err != nil {
 		return nil, err
 	}

+ 28 - 1
internal/repository/host.go

@@ -13,10 +13,14 @@ type HostRepository interface {
 	GetProductConfigOption(ctx context.Context, id []int) ([]v1.ProductConfigOption, error)
 	GetProductConfigOptionSub(ctx context.Context, id []int) ([]v1.ProductConfigOptionSub, error)
 	GetDomainById(ctx context.Context, id int) (string, error)
-	// 获取到期时间
+	// 获取指定用户指定套餐的到期时间
 	GetExpireTime(ctx context.Context, uid int64, hostId int64) (string, error)
 	// 获取指定到期时间
 	GetAlmostExpired(ctx context.Context, hostId []int,addTime int64) ([]v1.GetAlmostExpireHostResponse, error)
+	// 获取到期时间区间
+	GetExpireTimeRange(ctx context.Context,startTime int64,endTime int64) (int64, error)
+	// 获取指定套餐的到期时间
+	GetExpireTimeByHostId(ctx context.Context, hostIds []int) ([]v1.GetAlmostExpireHostResponse, error)
 }
 
 func NewHostRepository(
@@ -100,3 +104,26 @@ func (r *hostRepository) GetAlmostExpired(ctx context.Context, hostId []int,addT
 	return res, nil
 }
 
+
+// 获取到期时间区间
+func (r *hostRepository) GetExpireTimeRange(ctx context.Context,startTime int64,endTime int64) (int64, error) {
+	var res int64
+	if err := r.DB(ctx).Table("shd_host").
+		Where("nextduedate > ?", startTime).
+		Where("nextduedate < ?", endTime).
+		Count(&res).Error; err != nil {
+		return 0, err
+	}
+	return res, nil
+}
+
+// 获取指定套餐的到期时间
+func (r *hostRepository) GetExpireTimeByHostId(ctx context.Context, hostIds []int) ([]v1.GetAlmostExpireHostResponse, error) {
+	var res []v1.GetAlmostExpireHostResponse
+	if err := r.DB(ctx).Table("shd_host").
+		Where("id IN ?", hostIds).
+		Find(&res).Error; err != nil {
+		return nil, err
+	}
+	return res, nil
+}

+ 11 - 1
internal/service/cc.go

@@ -16,11 +16,13 @@ func NewCcService(
     service *Service,
     ccRepository repository.CcRepository,
 	webForwardingRep repository.WebForwardingRepository,
+	cdn cdnService,
 ) CcService {
 	return &ccService{
 		Service:        service,
 		ccRepository: ccRepository,
 		webForwardingRep: webForwardingRep,
+		cdn: cdn,
 	}
 }
 
@@ -28,6 +30,7 @@ type ccService struct {
 	*Service
 	ccRepository repository.CcRepository
 	webForwardingRep repository.WebForwardingRepository
+	cdn cdnService
 }
 
 func (s *ccService) GetCcList(ctx context.Context, req v1.CCListRequest) ([]v1.CCListResponse, error) {
@@ -66,7 +69,14 @@ func (s *ccService) EditCcState(ctx context.Context, req v1.CCStateRequest) erro
 		if webData.CdnWebId == 0 {
 			return fmt.Errorf("网站不存在")
 		}
-		err = s.ccRepository.EditCcState(ctx, int64(webData.CdnWebId), v)
+		ccId, err := s.ccRepository.GetCcId(ctx, int64(webData.CdnWebId), v)
+		if err != nil {
+			return err
+		}
+		if ccId == 0 {
+			return fmt.Errorf("IP不存在")
+		}
+		err = s.cdn.DelIpItem(ctx, ccId, v, "", "", 2000000000)
 		if err != nil {
 			return err
 		}

+ 24 - 0
internal/service/cdn.go

@@ -996,4 +996,28 @@ func (s *cdnService) DelServerGroup(ctx context.Context,serverId int64) error {
 		return fmt.Errorf("API 错误: code %d, msg '%s'", res.Code, res.Message)
 	}
 	return nil
+}
+
+// 删除IP
+func (s *cdnService) DelIpItem(ctx context.Context,ipitemId int64,value string,ipFrom string,ipTo string,ipListId int64) error {
+	formData := map[string]interface{}{
+		"ipitemId": ipitemId,
+		"value":    value,
+		"ipFrom":   ipFrom,
+		"ipTo":     ipTo,
+		"ipListId": ipListId,
+	}
+	apiUrl := s.Url + "IPItemService/deleteIPItem"
+	resBody, err := s.sendDataWithTokenRetry(ctx, formData, apiUrl)
+	if err != nil {
+		return err
+	}
+	var res v1.GeneralResponse[any]
+	if err := json.Unmarshal(resBody, &res); err != nil {
+		return fmt.Errorf("反序列化响应 JSON 失败 (内容: %s): %w", string(resBody), err)
+	}
+	if res.Code != 200 {
+		return fmt.Errorf("API 错误: code %d, msg '%s'", res.Code, res.Message)
+	}
+	return nil
 }

+ 285 - 84
internal/task/waf.go

@@ -2,16 +2,20 @@ package task
 
 import (
 	"context"
+	"fmt"
 	v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
 	"github.com/go-nunu/nunu-layout-advanced/internal/model"
 	"github.com/go-nunu/nunu-layout-advanced/internal/repository"
 	"github.com/go-nunu/nunu-layout-advanced/internal/service"
 	"github.com/hashicorp/go-multierror"
+	"go.uber.org/zap"
 	"sync"
 	"time"
 )
 
 type WafTask interface {
+	//获取到期时间小于3天的同步时间
+	SynchronizationTime(ctx context.Context) error
 }
 
 func NewWafTask (
@@ -43,11 +47,13 @@ type wafTask struct {
 	globalLimitRep repository.GlobalLimitRepository
 }
 
-func (t wafTask) CheckExpiredTask(ctx context.Context) error {
-	return nil
-
-}
 
+const (
+	// 3天后秒数
+	OneDaysInSeconds = 3 * 24 * 60 * 60
+	// 7天前秒数
+	SevenDaysInSeconds = 7 * 24 * 60 * 60 * -1
+)
 // 获取cdn web id
 func (t wafTask) GetCdnWebId(ctx context.Context,hostId int) ([]int, error) {
 	tcpIds, err := t.tcpforwardingRep.GetTcpForwardingAllIdsByID(ctx, hostId)
@@ -119,7 +125,7 @@ func (t wafTask) GetAlmostExpiring(ctx context.Context,hostIds []int,addTime int
 }
 
 
-// 获取全局到期时间
+// 获取waf全局到期时间
 func (t wafTask) GetGlobalAlmostExpiring(ctx context.Context,addTime int64) ([]model.GlobalLimit,error) {
 	res, err := t.globalLimitRep.GetGlobalLimitAlmostExpired(ctx, addTime)
 	if err != nil {
@@ -128,118 +134,313 @@ func (t wafTask) GetGlobalAlmostExpiring(ctx context.Context,addTime int64) ([]m
 	return res, nil
 }
 
-// 获取cdn web id
-
-func (t wafTask) GetGlobalAllHostId(ctx context.Context,addTime int64) (map[int]int64, error) {
-	globalData, err := t.GetGlobalAlmostExpiring(ctx,addTime)
-	if err != nil {
-		return nil, err
-	}
-
-	var hostIds []int
-	for _, v := range globalData {
-		hostIds = append(hostIds, v.HostId)
-	}
-
-	globalDataMap := make(map[int]int64, len(globalData))
-	planMap := make(map[int]int64, len(globalData))
-
-	for _, v := range globalData {
-		globalDataMap[v.HostId] = v.ExpiredAt
-		planMap[v.HostId] = int64(v.RuleId)
-	}
-
-	hostData,err := t.GetAlmostExpiring(ctx,hostIds,addTime)
-	if err != nil {
-		return nil, err
-	}
-
-	hostDataMap := make(map[int]int64, len(hostData))
-	for _, v := range hostData {
-		hostDataMap[v.HostId] = v.ExpiredAt
-	}
-
-	editMap := make(map[int]int64)
-
-	for k, v := range globalDataMap {
-		if hostDataMap[k] != v {
-			editMap[k] = v
-		}
-	}
-
-	planExpireMap := make(map[int]int64)
-	for k, v := range planMap {
-		if _, ok := editMap[k]; ok {
-			planExpireMap[k] = v
-		}
-	}
-
-	return editMap, nil
-}
 
 
 // 修改全局续费
-func (t wafTask) EditGlobalExpired(ctx context.Context,req []struct{
+func (t wafTask) EditGlobalExpired(ctx context.Context, req []struct{
 	hostId int
 	expiredAt int64
-},state bool) error {
+}, state bool) error {
+	var result *multierror.Error // 使用 multierror
+
 	for _, v := range req {
 		err := t.globalLimitRep.UpdateGlobalLimitByHostId(ctx, &model.GlobalLimit{
-			HostId: v.hostId,
+			HostId:    v.hostId,
 			ExpiredAt: v.expiredAt,
-			State: state,
+			State:     state,
 		})
 		if err != nil {
-			return err
+			// 收集错误,而不是直接返回
+			result = multierror.Append(result, err)
 		}
 	}
-	return nil
+
+	// 返回所有收集到的错误
+	return result.ErrorOrNil()
 }
 
 
+
 // 续费套餐
-func (t wafTask) EnablePlan(ctx context.Context,req []struct{
+func (t wafTask) EnablePlan(ctx context.Context, req []struct{
 	planId int
 	expiredAt int64
 }) error {
+	var result *multierror.Error
+
 	for _, v := range req {
 		err := t.cdn.RenewPlan(ctx, v1.RenewalPlan{
-			UserPlanId: int64(v.planId),
-			IsFree: true,
-			DayTo: time.Unix(v.expiredAt,0).Format("2006-01-02"),
-			Period:     "monthly",
+			UserPlanId:  int64(v.planId),
+			IsFree:      true,
+			DayTo:       time.Unix(v.expiredAt, 0).Format("2006-01-02"),
+			Period:      "monthly",
 			CountPeriod: 1,
-			PeriodDayTo: time.Unix(v.expiredAt,0).Format("2006-01-02"),
+			PeriodDayTo: time.Unix(v.expiredAt, 0).Format("2006-01-02"),
 		})
 		if err != nil {
-			return err
+			result = multierror.Append(result, err)
 		}
 	}
-	return nil
+
+	return result.ErrorOrNil()
 }
 
+
+
 // 续费操作
-func (t wafTask) EditExpired(ctx context.Context,req []struct {
-	hostId int
-	expiredAt int64
-	planId int
-}) error {
+type RenewalRequest struct {
+	HostId    int
+	PlanId    int
+	ExpiredAt int64
+}
 
-	var sendData []struct {
-		hostId int
+// 续费操作
+func (t wafTask) EditExpired(ctx context.Context, reqs []RenewalRequest) error {
+	// 如果请求为空,直接返回
+	if len(reqs) == 0 {
+		return nil
+	}
+
+	// 1. 准备用于更新 GlobalLimit 的数据
+	var globalLimitUpdates []struct {
+		hostId    int
 		expiredAt int64
 	}
-	for _, v := range req {
-		sendData = append(sendData, struct {
-			hostId int
+	for _, req := range reqs {
+		globalLimitUpdates = append(globalLimitUpdates, struct {
+			hostId    int
 			expiredAt int64
-		}{
-			hostId: v.hostId,
-			expiredAt: v.expiredAt,
-		})
+		}{req.HostId, req.ExpiredAt})
+	}
+
+	// 2. 准备用于续费套餐的数据
+	var planRenewals []struct {
+		planId    int
+		expiredAt int64
+	}
+	for _, req := range reqs {
+		planRenewals = append(planRenewals, struct {
+			planId    int
+			expiredAt int64
+		}{req.PlanId, req.ExpiredAt})
+	}
+
+	var result *multierror.Error
+
+	// 3. 执行更新,并收集错误
+	if err := t.EditGlobalExpired(ctx, globalLimitUpdates, true); err != nil {
+		result = multierror.Append(result, err)
+	}
+
+	if err := t.EnablePlan(ctx, planRenewals); err != nil {
+		result = multierror.Append(result, err)
+	}
+
+	return result.ErrorOrNil()
+}
+
+
+
+// findMismatchedExpirations 检查 WAF 和 Host 的到期时间差异,并返回需要同步的请求。
+func (t *wafTask) findMismatchedExpirations(ctx context.Context, wafLimits []model.GlobalLimit) ([]RenewalRequest, error) {
+	if len(wafLimits) == 0 {
+		return nil, nil
+	}
+
+	// 2. 将 WAF 数据组织成 Map
+	wafExpiredMap := make(map[int]int64, len(wafLimits))
+	wafPlanMap := make(map[int]int, len(wafLimits))
+	var hostIds []int
+	for _, limit := range wafLimits {
+		hostIds = append(hostIds, limit.HostId)
+		wafExpiredMap[limit.HostId] = limit.ExpiredAt
+		wafPlanMap[limit.HostId] = limit.RuleId
+	}
+
+	// 3. 获取对应 Host 的到期时间
+	hostExpirations, err := t.hostRep.GetExpireTimeByHostId(ctx, hostIds)
+	if err != nil {
+		return nil, fmt.Errorf("获取主机到期时间失败: %w", err)
+	}
+	hostExpiredMap := make(map[int]int64, len(hostExpirations))
+	for _, h := range hostExpirations {
+		hostExpiredMap[h.HostId] = h.ExpiredAt
+	}
+
+	// 4. 找出时间不一致的记录
+	var renewalRequests []RenewalRequest
+	for hostId, wafExpiredTime := range wafExpiredMap {
+		hostTime, ok := hostExpiredMap[hostId]
+
+		// 如果 Host 时间与 WAF 时间不一致,则需要同步
+		if !ok || hostTime != wafExpiredTime {
+			planId, planOk := wafPlanMap[hostId]
+			if !planOk {
+				t.logger.Warn("数据不一致:在waf_limits中找不到hostId对应的套餐ID", zap.Int("hostId", hostId))
+				continue
+			}
+			renewalRequests = append(renewalRequests, RenewalRequest{
+				HostId:    hostId,
+				ExpiredAt: hostTime, // 以 WAF 表的时间为准
+				PlanId:    planId,
+			})
+		}
+	}
+
+	return renewalRequests, nil
+}
+
+
+//获取到期时间小于3天的同步时间
+
+func (t *wafTask) SynchronizationTime(ctx context.Context) error {
+	// 1. 获取 WAF 全局配置中即将到期(小于3天)的数据
+	wafLimits, err := t.GetGlobalAlmostExpiring(ctx, OneDaysInSeconds)
+	if err != nil {
+		return fmt.Errorf("获取全局到期配置失败: %w", err)
 	}
-	if err := t.EditGlobalExpired(ctx,sendData,true); err != nil {
-		return err
+
+	// 2. 找出需要同步的数据
+	renewalRequests, err := t.findMismatchedExpirations(ctx, wafLimits)
+	if err != nil {
+		return err // 错误已在辅助函数中包装
+	}
+
+	// 3. 如果有需要同步的数据,执行续费操作
+	if len(renewalRequests) > 0 {
+		t.logger.Info("发现记录需要同步到期时间。", zap.Int("数量", len(renewalRequests)))
+		return t.EditExpired(ctx, renewalRequests)
 	}
+
 	return nil
-}
+}
+
+
+
+//获取到期的进行关闭套餐操作
+// 获取到期的进行关闭套餐操作
+func (t *wafTask) StopPlan(ctx context.Context) error {
+	// 1. 获取 WAF 全局配置中已经到期的数据
+	// 使用 time.Now().Unix() 表示获取所有 expired_at <= 当前时间的记录
+	wafLimits, err := t.globalLimitRep.GetGlobalLimitAlmostExpired(ctx, time.Now().Unix())
+	if err != nil {
+		return fmt.Errorf("获取全局到期配置失败: %w", err)
+	}
+	if len(wafLimits) == 0 {
+		return nil // 没有到期的,任务完成
+	}
+
+	// 2. (可选,但推荐)先同步任何时间不一致的数据,确保状态准确
+	renewalRequests, err := t.findMismatchedExpirations(ctx, wafLimits)
+	if err != nil {
+		t.logger.Error("在关闭套餐前,同步时间失败", zap.Error(err))
+		// 根据业务决定是否要继续,这里我们选择继续,但记录错误
+	}
+	if len(renewalRequests) > 0 {
+		t.logger.Info("关闭套餐前,发现并同步不一致的时间记录", zap.Int("数量", len(renewalRequests)))
+		if err := t.EditExpired(ctx, renewalRequests); err != nil {
+			t.logger.Error("同步不一致的时间记录失败", zap.Error(err))
+		}
+	}
+
+	// 3. 关闭所有已经到期的套餐
+	t.logger.Info("开始关闭已到期的WAF服务", zap.Int("数量", len(wafLimits)))
+	var allErrors *multierror.Error
+
+	for _, limit := range wafLimits {
+
+		webIds, err := t.GetCdnWebId(ctx, limit.HostId)
+		if err != nil {
+			allErrors = multierror.Append(allErrors, fmt.Errorf("获取hostId %d 的webId失败: %w", limit.HostId, err))
+			continue // 继续处理下一个
+		}
+
+		if err := t.BanServer(ctx, webIds, false); err != nil {
+			allErrors = multierror.Append(allErrors, fmt.Errorf("关闭hostId %d 的服务失败: %w", limit.HostId, err))
+		}
+
+	}
+
+	return allErrors.ErrorOrNil()
+}
+//对于到期7天内续费的产品需要进行恢复操作
+
+// RecoverStopPlan 对于到期7天内续费的产品进行恢复操作
+func (t *wafTask) RecoverStopPlan(ctx context.Context) error {
+	// 1. 查找在过去7天内到期,并且当前状态为“已关闭”的 WAF 记录
+	// 这可能需要一个新的 repository 方法,例如: GetRecentlyClosedLimits
+	// 我们先假设有这样一个方法,它返回 state=false 且 expired_at 在 (now-7天, now] 之间的记录
+	since := time.Now().Add(-7 * 24 * time.Hour).Unix()
+
+	// 假设你有一个方法 `GetClosedLimitsSince(ctx, sinceTime)`
+	// closedLimits, err := t.globalLimitRep.GetClosedLimitsSince(ctx, since)
+	// 为简化,我们先获取所有7天内到期的,再在逻辑里判断
+
+	// 简单的实现:获取7天内到期的所有记录
+	wafLimits, err := t.globalLimitRep.GetLimitsExpiredSince(ctx, since) // 假设有这个方法
+	if err != nil {
+		return fmt.Errorf("获取近期到期配置失败: %w", err)
+	}
+	if len(wafLimits) == 0 {
+		return nil
+	}
+
+	// 提取 hostIds 并过滤出已关闭的记录
+	var hostIds []int
+	closedLimitsMap := make(map[int]model.GlobalLimit)
+	for _, limit := range wafLimits {
+		if !limit.State { // 只处理状态为“已关闭”的
+			hostIds = append(hostIds, limit.HostId)
+			closedLimitsMap[limit.HostId] = limit
+		}
+	}
+	if len(hostIds) == 0 {
+		return nil // 没有已关闭的记录需要检查
+	}
+
+	// 2. 获取这些 host 的当前到期时间
+	hostExpirations, err := t.hostRep.GetExpireTimeByHostId(ctx, hostIds)
+	if err != nil {
+		return fmt.Errorf("获取主机当前到期时间失败: %w", err)
+	}
+	hostExpiredMap := make(map[int]int64)
+	for _, h := range hostExpirations {
+		hostExpiredMap[h.HostId] = h.ExpiredAt
+	}
+
+	var allErrors *multierror.Error
+	// 3. 比较时间,找出已续费的 host,并恢复服务
+	for hostId, closedLimit := range closedLimitsMap {
+		currentHostExpiry, ok := hostExpiredMap[hostId]
+		if !ok {
+			continue // host 不存在了,跳过
+		}
+
+		// 如果 host 表的到期时间 > global_limit 表的到期时间,说明已续费
+		if currentHostExpiry > closedLimit.ExpiredAt {
+			t.logger.Info("发现已续费并关闭的WAF服务,准备恢复", zap.Int("hostId", hostId))
+
+			// 3a. 恢复网站服务
+			webIds, err := t.GetCdnWebId(ctx, hostId)
+			if err != nil {
+				allErrors = multierror.Append(allErrors, fmt.Errorf("恢复hostId %d 时获取webId失败: %w", hostId, err))
+				continue
+			}
+			if err := t.BanServer(ctx, webIds, true); err != nil { // true 表示启用
+				allErrors = multierror.Append(allErrors, fmt.Errorf("恢复hostId %d 服务失败: %w", hostId, err))
+				continue
+			}
+
+			// 3b. 更新 global_limit 表的时间和状态
+			var singleUpdate []struct{hostId int; expiredAt int64}
+			singleUpdate = append(singleUpdate, struct{hostId int; expiredAt int64}{hostId: hostId, expiredAt: currentHostExpiry})
+			if err := t.EditGlobalExpired(ctx, singleUpdate, true); err != nil { // true 表示启用
+				allErrors = multierror.Append(allErrors, fmt.Errorf("更新hostId %d 状态为已恢复失败: %w", hostId, err))
+			}
+		}
+	}
+
+	return allErrors.ErrorOrNil()
+}
+
+//对于大于7天的药进行数据情侣操作