package waf

import (
	"context"
	"fmt"
	"sort"

	v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
	"github.com/go-nunu/nunu-layout-advanced/internal/model"
	"github.com/go-nunu/nunu-layout-advanced/internal/repository"
	"github.com/go-nunu/nunu-layout-advanced/internal/repository/api/waf"
	"github.com/go-nunu/nunu-layout-advanced/internal/service"
	"github.com/go-nunu/nunu-layout-advanced/internal/service/api/flexCdn"
	"golang.org/x/sync/errgroup"
)

// TcpforwardingService manages TCP forwarding configurations: single lookup,
// creation, editing, batch deletion and per-host listing.
type TcpforwardingService interface {
	// GetTcpforwarding returns one TCP forwarding configuration (main record
	// merged with its backend rule) by ID.
	GetTcpforwarding(ctx context.Context, req v1.GetForwardingRequest) (v1.TcpForwardingDataRequest, error)
	// AddTcpForwarding creates a TCP forwarding configuration and returns the ID
	// produced by the database save step.
	AddTcpForwarding(ctx context.Context, req *v1.TcpForwardingRequest) (int, error)
	// EditTcpForwarding updates an existing TCP forwarding configuration.
	EditTcpForwarding(ctx context.Context, req *v1.TcpForwardingRequest) error
	// DeleteTcpForwarding deletes every configuration listed in req.Ids.
	DeleteTcpForwarding(ctx context.Context, req v1.DeleteTcpForwardingRequest) error
	// GetTcpForwardingAllIpsByHostId lists all configurations of a host,
	// ordered by ID descending.
	GetTcpForwardingAllIpsByHostId(ctx context.Context, req v1.GetForwardingRequest) ([]v1.TcpForwardingDataRequest, error)
}

// NewTcpforwardingService wires a TcpforwardingService with its dependencies.
func NewTcpforwardingService(
	service *service.Service,
	tcpforwardingRepository waf.TcpforwardingRepository,
	parser service.ParserService,
	required service.RequiredService,
	crawler service.CrawlerService,
	globalRep waf.GlobalLimitRepository,
	hostRep repository.HostRepository,
	wafformatter WafFormatterService,
	cdn flexCdn.CdnService,
	proxy flexCdn.ProxyService,
	aidedTcp AidedTcpService,
) TcpforwardingService {
	return &tcpforwardingService{
		Service:                 service,
		tcpforwardingRepository: tcpforwardingRepository,
		parser:                  parser,
		required:                required,
		crawler:                 crawler,
		globalRep:               globalRep,
		hostRep:                 hostRep,
		wafformatter:            wafformatter,
		cdn:                     cdn,
		proxy:                   proxy,
		aidedTcp:                aidedTcp,
	}
}

// tcpforwardingService is the default TcpforwardingService implementation.
// The orchestration methods below delegate each individual step to aidedTcp
// and read records through tcpforwardingRepository.
type tcpforwardingService struct {
	*service.Service
	tcpforwardingRepository waf.TcpforwardingRepository
	parser                  service.ParserService
	required                service.RequiredService
	crawler                 service.CrawlerService
	globalRep               waf.GlobalLimitRepository
	hostRep                 repository.HostRepository
	wafformatter            WafFormatterService
	cdn                     flexCdn.CdnService
	proxy                   flexCdn.ProxyService
	aidedTcp                AidedTcpService
}

// GetTcpforwarding returns the details of a single TCP forwarding configuration.
// The main record and its rule record are queried concurrently and merged into
// one response. It fails for a non-positive ID, when either query fails, or when
// no main record is found.
func (s *tcpforwardingService) GetTcpforwarding(ctx context.Context, req v1.GetForwardingRequest) (v1.TcpForwardingDataRequest, error) {
	if req.Id <= 0 {
		return v1.TcpForwardingDataRequest{}, fmt.Errorf("非法的ID参数: %d", req.Id)
	}

	var (
		tcpForwarding model.Tcpforwarding
		backend       model.TcpForwardingRule
	)

	// The two lookups are independent, so run them concurrently. Each goroutine
	// writes only its own variable, and g.Wait() orders those writes before the reads below.
	g, gCtx := errgroup.WithContext(ctx)
	g.Go(func() error {
		res, err := s.tcpforwardingRepository.GetTcpforwarding(gCtx, int64(req.Id))
		if err != nil {
			return fmt.Errorf("查询TCP转发主记录失败 ID:%d, %w", req.Id, err)
		}
		if res != nil {
			tcpForwarding = *res
		}
		return nil
	})
	g.Go(func() error {
		res, err := s.tcpforwardingRepository.GetTcpForwardingIpsByID(gCtx, req.Id)
		if err != nil {
			return fmt.Errorf("查询TCP转发规则记录失败 ID:%d, %w", req.Id, err)
		}
		if res != nil {
			backend = *res
		}
		return nil
	})
	if err := g.Wait(); err != nil {
		return v1.TcpForwardingDataRequest{}, err
	}

	// A zero Id means no main record was returned.
	if tcpForwarding.Id == 0 {
		return v1.TcpForwardingDataRequest{}, fmt.Errorf("TCP转发配置不存在 ID:%d", req.Id)
	}

	return v1.TcpForwardingDataRequest{
		Id:          tcpForwarding.Id,
		Port:        tcpForwarding.Port,
		Comment:     tcpForwarding.Comment,
		Proxy:       tcpForwarding.Proxy,
		BackendList: backend.BackendList,
	}, nil
}

// AddTcpForwarding creates a TCP forwarding configuration.
// Steps, in order: prepare and validate the request, create the CDN website,
// add the origins, configure the proxy protocol, save to the database, then
// trigger asynchronous tasks. It returns the ID produced by the save step.
//
// NOTE(review): each step returns on error with no rollback in this function,
// so a failure after CreateCdnWebsite leaves earlier resources in place unless
// the helpers compensate internally — confirm.
func (s *tcpforwardingService) AddTcpForwarding(ctx context.Context, req *v1.TcpForwardingRequest) (int, error) {
	// 1. Prepare and validate the request.
	require, formData, err := s.aidedTcp.PrepareWafData(ctx, req)
	if err != nil {
		return 0, err
	}
	if err := s.aidedTcp.ValidateAddRequest(ctx, req, require); err != nil {
		return 0, err
	}

	// 2. Create the CDN website.
	tcpId, err := s.aidedTcp.CreateCdnWebsite(ctx, formData)
	if err != nil {
		return 0, err
	}

	// 3. Attach the origin servers to the website.
	cdnOriginIds, err := s.aidedTcp.AddOriginsToWebsite(ctx, req, tcpId)
	if err != nil {
		return 0, err
	}

	// 4. Configure the proxy protocol.
	if err := s.aidedTcp.ConfigureProxyProtocol(ctx, req, tcpId); err != nil {
		return 0, err
	}

	// 5. Persist to the database.
	id, err := s.aidedTcp.SaveToDatabase(ctx, req, require, tcpId, cdnOriginIds)
	if err != nil {
		return 0, err
	}

	// 6. Kick off follow-up asynchronous tasks; no result from it is checked here.
	s.aidedTcp.ProcessAsyncTasks(req)

	return id, nil
}

// EditTcpForwarding updates an existing TCP forwarding configuration.
// Steps, in order: prepare and validate the request against the stored record,
// update the CDN configuration, apply IP whitelist changes, update the origin
// servers, and update the database records.
func (s *tcpforwardingService) EditTcpForwarding(ctx context.Context, req *v1.TcpForwardingRequest) error {
	// 1. Prepare and validate the request.
	require, formData, err := s.aidedTcp.PrepareWafData(ctx, req)
	if err != nil {
		return err
	}
	oldData, err := s.tcpforwardingRepository.GetTcpforwarding(ctx, int64(req.TcpForwardingData.Id))
	if err != nil {
		return fmt.Errorf("获取原始数据失败: %w", err)
	}
	// The repository may return a nil record (GetTcpforwarding above guards for
	// it too); stop here rather than hand nil to the helpers below.
	if oldData == nil {
		return fmt.Errorf("TCP转发配置不存在 ID:%d", req.TcpForwardingData.Id)
	}
	if err := s.aidedTcp.ValidateEditRequest(ctx, req, require, oldData); err != nil {
		return err
	}

	// 2. Update the CDN configuration.
	if err := s.aidedTcp.UpdateCdnConfiguration(ctx, req, oldData, require, formData); err != nil {
		return err
	}

	// 3. Load the stored IP data and apply whitelist changes.
	ipData, err := s.tcpforwardingRepository.GetTcpForwardingIpsByID(ctx, req.TcpForwardingData.Id)
	if err != nil {
		return fmt.Errorf("获取IP数据失败: %w", err)
	}
	if err := s.aidedTcp.ProcessIpWhitelistChanges(ctx, req, ipData); err != nil {
		return err
	}

	// 4. Update the origin servers.
	if err := s.aidedTcp.UpdateOriginServers(ctx, req, oldData, ipData); err != nil {
		return err
	}

	// 5. Update the database records.
	return s.aidedTcp.UpdateDatabaseRecords(ctx, req, oldData, require, ipData)
}

// DeleteTcpForwarding deletes every configuration in req.Ids, one after another.
// It stops at the first failure and wraps that error with the failing ID. Deletions
// already completed are not undone.
func (s *tcpforwardingService) DeleteTcpForwarding(ctx context.Context, req v1.DeleteTcpForwardingRequest) error {
	for _, id := range req.Ids {
		if err := s.deleteSingleTcpForwarding(ctx, id, req.HostId); err != nil {
			return fmt.Errorf("删除TCP转发配置失败 ID:%d, %w", id, err)
		}
	}
	return nil
}

// deleteSingleTcpForwarding deletes one configuration.
// Steps, in order: load the record and check the caller's permission via hostId,
// delete the CDN server, clean up the IP whitelist, then clean up database records.
func (s *tcpforwardingService) deleteSingleTcpForwarding(ctx context.Context, id int, hostId int) error {
	// 1. Load the stored record and verify permission.
	oldData, err := s.tcpforwardingRepository.GetTcpforwarding(ctx, int64(id))
	if err != nil {
		return fmt.Errorf("获取TCP转发数据失败: %w", err)
	}
	// Guard against a nil record; oldData.CdnWebId is dereferenced below.
	if oldData == nil {
		return fmt.Errorf("TCP转发配置不存在 ID:%d", id)
	}
	if err := s.aidedTcp.ValidateDeletePermission(oldData, hostId); err != nil {
		return err
	}

	// 2. Delete the CDN server.
	if err := s.aidedTcp.DeleteCdnServer(ctx, oldData.CdnWebId); err != nil {
		return err
	}

	// 3. Clean up the IP whitelist.
	if err := s.aidedTcp.ProcessDeleteIpWhitelist(ctx, id); err != nil {
		return err
	}

	// 4. Clean up the database records.
	return s.aidedTcp.CleanupDatabaseRecords(ctx, id)
}

// GetTcpForwardingAllIpsByHostId lists every TCP forwarding configuration that
// belongs to req.HostId, sorted by ID in descending order.
// Per-ID lookups run concurrently with a bounded number of goroutines. The first
// failure (with its ID attached) cancels the remaining lookups and is returned.
// It returns a nil slice and nil error when the host has no configurations.
func (s *tcpforwardingService) GetTcpForwardingAllIpsByHostId(ctx context.Context, req v1.GetForwardingRequest) ([]v1.TcpForwardingDataRequest, error) {
	// maxConcurrentLookups caps in-flight per-ID lookups (two repository queries
	// each). Without it, a host with many rules would open one goroutine per rule.
	const maxConcurrentLookups = 10

	ids, err := s.tcpforwardingRepository.GetTcpForwardingAllIdsByID(ctx, req.HostId)
	if err != nil {
		return nil, fmt.Errorf("GetTcpForwardingAllIds failed: %w", err)
	}
	if len(ids) == 0 {
		return nil, nil
	}

	// Each goroutine writes only res[i] for its own i, so no lock is needed;
	// g.Wait() makes those writes visible before res is read below.
	res := make([]v1.TcpForwardingDataRequest, len(ids))

	g, gCtx := errgroup.WithContext(ctx)
	g.SetLimit(maxConcurrentLookups)
	for i, id := range ids {
		i, id := i, id // per-iteration copies for Go versions before 1.22
		g.Go(func() error {
			// Skip the queries once another lookup failed or ctx was cancelled.
			if err := gCtx.Err(); err != nil {
				return err
			}

			forwarding, err := s.tcpforwardingRepository.GetTcpforwarding(gCtx, int64(id))
			if err != nil {
				return fmt.Errorf("get TCP forwarding ID %d: %w", id, err)
			}
			if forwarding == nil {
				return fmt.Errorf("received nil forwarding for ID %d", id)
			}

			rule, err := s.tcpforwardingRepository.GetTcpForwardingIpsByID(gCtx, id)
			if err != nil {
				return fmt.Errorf("get TCP forwarding rule ID %d: %w", id, err)
			}

			item := v1.TcpForwardingDataRequest{
				Id:      forwarding.Id,
				Port:    forwarding.Port,
				Comment: forwarding.Comment,
				Proxy:   forwarding.Proxy,
			}
			// A missing rule record leaves BackendList at its zero value.
			if rule != nil {
				item.BackendList = rule.BackendList
			}
			res[i] = item
			return nil
		})
	}
	if err := g.Wait(); err != nil {
		return nil, err
	}

	sort.Slice(res, func(i, j int) bool {
		return res[i].Id > res[j].Id
	})
	return res, nil
}