package waf import ( "context" "fmt" v1 "github.com/go-nunu/nunu-layout-advanced/api/v1" "github.com/go-nunu/nunu-layout-advanced/internal/model" "github.com/go-nunu/nunu-layout-advanced/internal/repository" "github.com/go-nunu/nunu-layout-advanced/internal/repository/api/waf" "github.com/go-nunu/nunu-layout-advanced/internal/service" "github.com/go-nunu/nunu-layout-advanced/internal/service/api/flexCdn" "golang.org/x/sync/errgroup" "sort" ) type UdpForWardingService interface { GetUdpForWarding(ctx context.Context,req v1.GetForwardingRequest) (v1.UdpForwardingDataRequest, error) AddUdpForwarding(ctx context.Context, req *v1.UdpForwardingRequest) (int, error) EditUdpForwarding(ctx context.Context, req *v1.UdpForwardingRequest) error DeleteUdpForwarding(ctx context.Context, req v1.DeleteUdpForwardingRequest) error GetUdpForwardingWafUdpAllIps(ctx context.Context, req v1.GetForwardingRequest) ([]v1.UdpForwardingDataRequest, error) } func NewUdpForWardingService( service *service.Service, udpForWardingRepository waf.UdpForWardingRepository, required service.RequiredService, parser service.ParserService, crawler service.CrawlerService, globalRep waf.GlobalLimitRepository, hostRep repository.HostRepository, wafformatter WafFormatterService, cdn flexCdn.CdnService, proxy flexCdn.ProxyService, aidedUdp AidedUdpService, ) UdpForWardingService { return &udpForWardingService{ Service: service, udpForWardingRepository: udpForWardingRepository, required: required, parser: parser, crawler: crawler, globalRep: globalRep, hostRep: hostRep, wafformatter: wafformatter, cdn: cdn, proxy: proxy, aidedUdp: aidedUdp, } } type udpForWardingService struct { *service.Service udpForWardingRepository waf.UdpForWardingRepository required service.RequiredService parser service.ParserService crawler service.CrawlerService globalRep waf.GlobalLimitRepository hostRep repository.HostRepository wafformatter WafFormatterService cdn flexCdn.CdnService proxy flexCdn.ProxyService aidedUdp AidedUdpService } // GetUdpForWarding 获取单个UDP转发配置详情 // 该函数根据ID同时查询主记录和规则记录,并合并返回完整的配置信息 func (s *udpForWardingService) GetUdpForWarding(ctx context.Context, req v1.GetForwardingRequest) (v1.UdpForwardingDataRequest, error) { // 参数验证 if req.Id <= 0 { return v1.UdpForwardingDataRequest{}, fmt.Errorf("非法的ID参数: %d", req.Id) } var udpForWarding model.UdpForWarding var backend model.UdpForwardingRule var err error // 并发查询主记录和规则记录以提高性能 g, gCtx := errgroup.WithContext(ctx) g.Go(func() error { res, e := s.udpForWardingRepository.GetUdpForWarding(gCtx, int64(req.Id)) if e != nil { return fmt.Errorf("查询UDP转发主记录失败 ID:%d, %w", req.Id, e) } if res != nil { udpForWarding = *res } return nil }) g.Go(func() error { res, e := s.udpForWardingRepository.GetUdpForwardingIpsByID(gCtx, req.Id) if e != nil { return fmt.Errorf("查询UDP转发规则记录失败 ID:%d, %w", req.Id, e) } if res != nil { backend = *res } return nil }) if err = g.Wait(); err != nil { return v1.UdpForwardingDataRequest{}, err } // 检查是否找到主记录 if udpForWarding.Id == 0 { return v1.UdpForwardingDataRequest{}, fmt.Errorf("UDP转发配置不存在 ID:%d", req.Id) } return v1.UdpForwardingDataRequest{ Id: udpForWarding.Id, Port: udpForWarding.Port, BackendList: backend.BackendList, Comment: udpForWarding.Comment, Proxy: udpForWarding.Proxy, }, nil } // AddUdpForwarding 添加 UDP 转发配置 // 该函数完成 UDP 转发的完整创建流程:验证、创建 CDN、添加源站、配置代理、保存数据、处理异步任务 func (s *udpForWardingService) AddUdpForwarding(ctx context.Context, req *v1.UdpForwardingRequest) (int, error) { // 1. 数据准备和验证 require, formData, err := s.aidedUdp.PrepareWafData(ctx, req) if err != nil { return 0, err } if err := s.aidedUdp.ValidateAddRequest(ctx, req, require); err != nil { return 0, err } // 2. 创建CDN网站 udpId, err := s.aidedUdp.CreateCdnWebsite(ctx, formData) if err != nil { return 0, err } // 3. 添加源站 cdnOriginIds, err := s.aidedUdp.AddOriginsToWebsite(ctx, req, udpId) if err != nil { return 0, err } // 4. 配置代理协议 if err := s.aidedUdp.ConfigureProxyProtocol(ctx, req, udpId); err != nil { return 0, err } // 5. 保存到数据库 id, err := s.aidedUdp.SaveToDatabase(ctx, req, require, udpId, cdnOriginIds) if err != nil { return 0, err } // 6. 处理异步任务 s.aidedUdp.ProcessAsyncTasks(req) return id, nil } // EditUdpForwarding 编辑 UDP 转发配置 // 该函数完成 UDP 转发的完整编辑流程:验证、更新 CDN、处理IP白名单、更新源站、更新数据库 func (s *udpForWardingService) EditUdpForwarding(ctx context.Context, req *v1.UdpForwardingRequest) error { // 1. 数据准备和验证 require, formData, err := s.aidedUdp.PrepareWafData(ctx, req) if err != nil { return err } oldData, err := s.udpForWardingRepository.GetUdpForWarding(ctx, int64(req.UdpForwardingData.Id)) if err != nil { return fmt.Errorf("获取原始数据失败: %w", err) } if err := s.aidedUdp.ValidateEditRequest(ctx, req, require, oldData); err != nil { return err } // 2. 更新CDN配置 if err := s.aidedUdp.UpdateCdnConfiguration(ctx, req, oldData, require, formData); err != nil { return err } // 3. 获取IP数据并处理白名单 ipData, err := s.udpForWardingRepository.GetUdpForwardingIpsByID(ctx, req.UdpForwardingData.Id) if err != nil { return fmt.Errorf("获取IP数据失败: %w", err) } if err := s.aidedUdp.ProcessIpWhitelistChanges(ctx, req, ipData); err != nil { return err } // 4. 更新源站配置 if err := s.aidedUdp.UpdateOriginServers(ctx, req, oldData, ipData); err != nil { return err } // 5. 更新数据库记录 if err := s.aidedUdp.UpdateDatabaseRecords(ctx, req, oldData, require, ipData); err != nil { return err } return nil } // DeleteUdpForwarding 批量删除 UDP 转发配置 // 该函数支持批量删除多个 UDP 转发配置,对每个配置都执行完整的删除流程 func (s *udpForWardingService) DeleteUdpForwarding(ctx context.Context, req v1.DeleteUdpForwardingRequest) error { // 批量删除处理 for _, id := range req.Ids { if err := s.deleteSingleUdpForwarding(ctx, id, req.HostId); err != nil { return fmt.Errorf("删除UDP转发配置失败 ID:%d, %w", id, err) } } return nil } // deleteSingleUdpForwarding 删除单个 UDP 转发配置 // 该函数完成单个配置的完整删除流程:权限验证、删除 CDN、清理IP白名单、清理数据库 func (s *udpForWardingService) deleteSingleUdpForwarding(ctx context.Context, id int, hostId int) error { // 1. 获取原始数据并验证权限 oldData, err := s.udpForWardingRepository.GetUdpForWarding(ctx, int64(id)) if err != nil { return fmt.Errorf("获取UDP转发数据失败: %w", err) } if err := s.aidedUdp.ValidateDeletePermission(oldData, hostId); err != nil { return err } // 2. 删除CDN服务器 if err := s.aidedUdp.DeleteCdnServer(ctx, oldData.CdnWebId); err != nil { return err } // 3. 处理IP白名单清理 if err := s.aidedUdp.ProcessDeleteIpWhitelist(ctx, id); err != nil { return err } // 4. 清理数据库记录 if err := s.aidedUdp.CleanupDatabaseRecords(ctx, id); err != nil { return err } return nil } // GetUdpForwardingWafUdpAllIps 获取指定主机的所有 UDP 转发配置列表 // 该函数使用并发查询优化性能,同时获取多个配置的详细信息并按ID降序排列 func (s *udpForWardingService) GetUdpForwardingWafUdpAllIps(ctx context.Context, req v1.GetForwardingRequest) ([]v1.UdpForwardingDataRequest, error) { type CombinedResult struct { Id int Forwarding *model.UdpForWarding BackendRule *model.UdpForwardingRule Err error } g, gCtx := errgroup.WithContext(ctx) ids, err := s.udpForWardingRepository.GetUdpForwardingWafUdpAllIds(gCtx, req.HostId) if err != nil { return nil, fmt.Errorf("获取UDP转发ID列表失败: %w", err) } if len(ids) == 0 { return nil, nil } resChan := make(chan CombinedResult, len(ids)) g.Go(func() error { for _, idVal := range ids { currentID := idVal g.Go(func() error { var wf *model.UdpForWarding var bk *model.UdpForwardingRule var localErr error wf, localErr = s.udpForWardingRepository.GetUdpForWarding(gCtx, int64(currentID)) if localErr != nil { resChan <- CombinedResult{Id: currentID, Err: localErr} return localErr } bk, localErr = s.udpForWardingRepository.GetUdpForwardingIpsByID(gCtx, currentID) if localErr != nil { resChan <- CombinedResult{Id: currentID, Err: localErr} return localErr } resChan <- CombinedResult{Id: currentID, Forwarding: wf, BackendRule: bk} return nil }) } return nil }) groupErr := g.Wait() close(resChan) if groupErr != nil { return nil, groupErr } res := make([]v1.UdpForwardingDataRequest, 0, len(ids)) for r := range resChan { if r.Err != nil { return nil, fmt.Errorf("处理ID %d 时出错: %w", r.Id, r.Err) } if r.Forwarding == nil { return nil, fmt.Errorf("ID %d 对应的转发配置为空", r.Id) } dataReq := v1.UdpForwardingDataRequest{ Id: r.Forwarding.Id, Port: r.Forwarding.Port, Comment: r.Forwarding.Comment, Proxy: r.Forwarding.Proxy, } if r.BackendRule != nil { dataReq.BackendList = r.BackendRule.BackendList } res = append(res, dataReq) } sort.Slice(res, func(i, j int) bool { return res[i].Id > res[j].Id }) return res, nil }