|
@@ -2,17 +2,22 @@ package service
|
|
|
|
|
|
import (
|
|
|
"context"
|
|
|
+ "encoding/json"
|
|
|
+ "fmt"
|
|
|
v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
|
|
|
"github.com/go-nunu/nunu-layout-advanced/internal/model"
|
|
|
"github.com/go-nunu/nunu-layout-advanced/internal/repository"
|
|
|
+ "github.com/spf13/cast"
|
|
|
+ "golang.org/x/sync/errgroup"
|
|
|
"strconv"
|
|
|
+ "strings"
|
|
|
)
|
|
|
|
|
|
type WebForwardingService interface {
|
|
|
- GetWebForwarding(ctx context.Context, id int64) (*model.WebForwarding, error)
|
|
|
- AddWebForwarding(ctx context.Context, req *v1.WebForwardingRequest) (string, error)
|
|
|
- EditWebForwarding(ctx context.Context, req *v1.WebForwardingRequest) (string, error)
|
|
|
- DeleteWebForwarding(ctx context.Context, wafWebId int) (string, error)
|
|
|
+ GetWebForwarding(ctx context.Context, req v1.GetForwardingRequest) (v1.WebForwardingDataRequest, error)
|
|
|
+ AddWebForwarding(ctx context.Context, req *v1.WebForwardingRequest) error
|
|
|
+ EditWebForwarding(ctx context.Context, req *v1.WebForwardingRequest) error
|
|
|
+ DeleteWebForwarding(ctx context.Context, wafWebId int) error
|
|
|
}
|
|
|
|
|
|
func NewWebForwardingService(
|
|
@@ -43,21 +48,114 @@ type webForwardingService struct {
|
|
|
}
|
|
|
|
|
|
func (s *webForwardingService) require(ctx context.Context,req v1.GlobalRequire) (v1.GlobalRequire, error) {
|
|
|
- res, err := s.wafformatter.require(ctx, req, "web")
|
|
|
- if err != nil {
|
|
|
+ var err error
|
|
|
+ var res v1.GlobalRequire
|
|
|
+ g, gCtx := errgroup.WithContext(ctx)
|
|
|
+
|
|
|
+ g.Go(func() error {
|
|
|
+ result, e := s.wafformatter.require(gCtx, req, "web")
|
|
|
+ if e != nil {
|
|
|
+ return e
|
|
|
+ }
|
|
|
+ res = result
|
|
|
+ return nil
|
|
|
+ })
|
|
|
+
|
|
|
+ g.Go(func() error {
|
|
|
+ e := s.wafformatter.validateWafDomainCount(gCtx, req)
|
|
|
+ if e != nil {
|
|
|
+ return e
|
|
|
+ }
|
|
|
+ return nil
|
|
|
+ })
|
|
|
+ if err = g.Wait(); err != nil {
|
|
|
return v1.GlobalRequire{}, err
|
|
|
}
|
|
|
return res, nil
|
|
|
}
|
|
|
|
|
|
-func (s *webForwardingService) GetWebForwarding(ctx context.Context, id int64) (*model.WebForwarding, error) {
|
|
|
- return s.webForwardingRepository.GetWebForwarding(ctx, id)
|
|
|
+func (s *webForwardingService) GetWebForwarding(ctx context.Context, req v1.GetForwardingRequest) (v1.WebForwardingDataRequest, error) {
|
|
|
+ var webForwarding model.WebForwarding
|
|
|
+ var backend model.WebForwardingRule
|
|
|
+ var err error
|
|
|
+ g, gCtx := errgroup.WithContext(ctx)
|
|
|
+ g.Go(func() error {
|
|
|
+ res, e := s.webForwardingRepository.GetWebForwarding(gCtx, int64(req.Id))
|
|
|
+ if e != nil {
|
|
|
+ // 直接返回错误,errgroup 会捕获它
|
|
|
+ return fmt.Errorf("GetWebForwarding failed: %w", e)
|
|
|
+ }
|
|
|
+ if res != nil {
|
|
|
+ webForwarding = *res
|
|
|
+ }
|
|
|
+ return nil
|
|
|
+ })
|
|
|
+
|
|
|
+ g.Go(func() error {
|
|
|
+ res, e := s.webForwardingRepository.GetWebForwardingByID(ctx, req.Id)
|
|
|
+ if e != nil {
|
|
|
+ return fmt.Errorf("GetWebForwardingByID failed: %w", e)
|
|
|
+ }
|
|
|
+ if res != nil {
|
|
|
+ backend = *res
|
|
|
+ }
|
|
|
+ return nil
|
|
|
+ })
|
|
|
+
|
|
|
+ if err := g.Wait(); err != nil {
|
|
|
+ return v1.WebForwardingDataRequest{}, err
|
|
|
+ }
|
|
|
+
|
|
|
+ portInt, err := cast.ToIntE(webForwarding.Port)
|
|
|
+ if err != nil {
|
|
|
+ return v1.WebForwardingDataRequest{}, err
|
|
|
+ }
|
|
|
+ return v1.WebForwardingDataRequest{
|
|
|
+ Id: webForwarding.Id,
|
|
|
+ WafWebId: webForwarding.WafWebId,
|
|
|
+ Tag: webForwarding.Tag,
|
|
|
+ Port: portInt,
|
|
|
+ Domain: webForwarding.Domain,
|
|
|
+ CustomHost: webForwarding.CustomHost,
|
|
|
+ WafWebLimitId: webForwarding.WebLimitRuleId,
|
|
|
+ WafGatewayGroupId: webForwarding.WafGatewayGroupId,
|
|
|
+ CcCount: webForwarding.CcCount,
|
|
|
+ CcDuration: webForwarding.CcDuration,
|
|
|
+ CcBlockCount: webForwarding.CcBlockCount,
|
|
|
+ CcBlockDuration: webForwarding.CcBlockDuration,
|
|
|
+ Cc4xxCount: webForwarding.Cc4xxCount,
|
|
|
+ Cc4xxDuration: webForwarding.Cc4xxDuration,
|
|
|
+ Cc4xxBlockCount: webForwarding.Cc4xxBlockCount,
|
|
|
+ Cc4xxBlockDuration: webForwarding.Cc4xxBlockDuration,
|
|
|
+ Cc5xxCount: webForwarding.Cc5xxCount,
|
|
|
+ Cc5xxDuration: webForwarding.Cc5xxDuration,
|
|
|
+ Cc5xxBlockCount: webForwarding.Cc5xxBlockCount,
|
|
|
+ Cc5xxBlockDuration: webForwarding.Cc5xxBlockDuration,
|
|
|
+ IsHttps: webForwarding.IsHttps,
|
|
|
+ Comment: webForwarding.Comment,
|
|
|
+ BackendList: backend.BackendList,
|
|
|
+ AllowIpList: backend.AllowIpList,
|
|
|
+ DenyIpList: backend.DenyIpList,
|
|
|
+ AccessRule: backend.AccessRule,
|
|
|
+ }, nil
|
|
|
}
|
|
|
|
|
|
// buildWafFormData 辅助函数,用于构建通用的 formData
|
|
|
-func (s *webForwardingService) buildWafFormData(req *v1.WebForwardingData, require v1.GlobalRequire) map[string]interface{} {
|
|
|
+func (s *webForwardingService) buildWafFormData(req *v1.WebForwardingDataSend, require v1.GlobalRequire) map[string]interface{} {
|
|
|
+ // 将BackendList序列化为JSON字符串
|
|
|
+ backendJSON, err := json.MarshalIndent(req.BackendList, "", " ")
|
|
|
+ var backendStr interface{}
|
|
|
+ if err != nil {
|
|
|
+ // 如果序列化失败,使用空数组
|
|
|
+ backendStr = "[]"
|
|
|
+ } else {
|
|
|
+ // 成功序列化后,使用JSON字符串
|
|
|
+ backendStr = string(backendJSON)
|
|
|
+ }
|
|
|
+
|
|
|
return map[string]interface{}{
|
|
|
- "tag": req.Tag,
|
|
|
+ "waf_web_id": req.WafWebId,
|
|
|
+ "tag": require.Tag,
|
|
|
"port": req.Port,
|
|
|
"domain": req.Domain,
|
|
|
"custom_host": req.CustomHost,
|
|
@@ -75,7 +173,7 @@ func (s *webForwardingService) buildWafFormData(req *v1.WebForwardingData, requi
|
|
|
"cc_5xx_duration": req.Cc5xxDuration,
|
|
|
"cc_5xx_block_count": req.Cc5xxBlockCount,
|
|
|
"cc_5xx_block_duration": req.Cc5xxBlockDuration,
|
|
|
- "backend_list": req.BackendList,
|
|
|
+ "backend": backendStr,
|
|
|
"allow_ip_list": req.AllowIpList,
|
|
|
"deny_ip_list": req.DenyIpList,
|
|
|
"access_rule": req.AccessRule,
|
|
@@ -86,11 +184,11 @@ func (s *webForwardingService) buildWafFormData(req *v1.WebForwardingData, requi
|
|
|
|
|
|
// buildWebForwardingModel 辅助函数,用于构建通用的 WebForwarding 模型
|
|
|
// ruleId 是从 WAF 系统获取的 ID
|
|
|
-func (s *webForwardingService) buildWebForwardingModel(req *v1.WebForwardingData,ruleId int, require v1.GlobalRequire) *model.WebForwarding {
|
|
|
+func (s *webForwardingService) buildWebForwardingModel(req *v1.WebForwardingDataRequest,ruleId int, require v1.GlobalRequire) *model.WebForwarding {
|
|
|
return &model.WebForwarding{
|
|
|
HostId: require.HostId,
|
|
|
- RuleId: ruleId,
|
|
|
- Tag: req.Tag,
|
|
|
+ WafWebId: ruleId,
|
|
|
+ Tag: require.Tag,
|
|
|
Port: strconv.Itoa(req.Port),
|
|
|
Domain: req.Domain,
|
|
|
CustomHost: req.CustomHost,
|
|
@@ -113,58 +211,130 @@ func (s *webForwardingService) buildWebForwardingModel(req *v1.WebForwardingData
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-func (s *webForwardingService) AddWebForwarding(ctx context.Context, req *v1.WebForwardingRequest) (string, error) {
|
|
|
+func (s *webForwardingService) buildWebRuleModel(reqData *v1.WebForwardingDataRequest, require v1.GlobalRequire, localDbId int) *model.WebForwardingRule {
|
|
|
+ return &model.WebForwardingRule{
|
|
|
+ Uid: require.Uid,
|
|
|
+ HostId: require.HostId,
|
|
|
+ WebId: localDbId, // 关联到本地数据库的主记录 ID
|
|
|
+ BackendList: reqData.BackendList,
|
|
|
+ AllowIpList: reqData.AllowIpList,
|
|
|
+ DenyIpList: reqData.DenyIpList,
|
|
|
+ AccessRule: reqData.AccessRule,
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func (s *webForwardingService) prepareWafData(ctx context.Context, req *v1.WebForwardingRequest) (v1.GlobalRequire, map[string]interface{}, error) {
|
|
|
+ // 1. 获取必要的全局信息
|
|
|
require, err := s.require(ctx, v1.GlobalRequire{
|
|
|
- HostId: req.HostId,
|
|
|
- Uid: req.Uid,
|
|
|
+ HostId: req.HostId,
|
|
|
+ Uid: req.Uid,
|
|
|
Comment: req.WebForwardingData.Comment,
|
|
|
+ Domain: req.WebForwardingData.Domain,
|
|
|
})
|
|
|
if err != nil {
|
|
|
- return "", err
|
|
|
+ return v1.GlobalRequire{}, nil, err
|
|
|
+ }
|
|
|
+ if require.WafGatewayGroupId == 0 || require.LimitRuleId == 0 {
|
|
|
+ return v1.GlobalRequire{}, nil, fmt.Errorf("请先配置实例")
|
|
|
+ }
|
|
|
+
|
|
|
+ // 2. 将字符串切片拼接成字符串,用于 WAF API
|
|
|
+ allowIpListStr := strings.Join(req.WebForwardingData.AllowIpList, "\n")
|
|
|
+ denyIpListStr := strings.Join(req.WebForwardingData.DenyIpList, "\n")
|
|
|
+
|
|
|
+ // 3. 创建用于构建 WAF 表单的数据结构
|
|
|
+ formDataBase := v1.WebForwardingDataSend{
|
|
|
+ Tag: require.Tag,
|
|
|
+ WafWebId: req.WebForwardingData.WafWebId,
|
|
|
+ WafGatewayGroupId: require.WafGatewayGroupId,
|
|
|
+ WafWebLimitId: require.LimitRuleId,
|
|
|
+ Port: req.WebForwardingData.Port,
|
|
|
+ Domain: req.WebForwardingData.Domain,
|
|
|
+ CustomHost: req.WebForwardingData.CustomHost,
|
|
|
+ CcCount: req.WebForwardingData.CcCount,
|
|
|
+ CcDuration: req.WebForwardingData.CcDuration,
|
|
|
+ CcBlockCount: req.WebForwardingData.CcBlockCount,
|
|
|
+ CcBlockDuration: req.WebForwardingData.CcBlockDuration,
|
|
|
+ Cc4xxCount: req.WebForwardingData.Cc4xxCount,
|
|
|
+ Cc4xxDuration: req.WebForwardingData.Cc4xxDuration,
|
|
|
+ Cc4xxBlockCount: req.WebForwardingData.Cc4xxBlockCount,
|
|
|
+ Cc4xxBlockDuration: req.WebForwardingData.Cc4xxBlockDuration,
|
|
|
+ Cc5xxCount: req.WebForwardingData.Cc5xxCount,
|
|
|
+ Cc5xxDuration: req.WebForwardingData.Cc5xxDuration,
|
|
|
+ Cc5xxBlockCount: req.WebForwardingData.Cc5xxBlockCount,
|
|
|
+ Cc5xxBlockDuration: req.WebForwardingData.Cc5xxBlockDuration,
|
|
|
+ IsHttps: req.WebForwardingData.IsHttps,
|
|
|
+ BackendList: req.WebForwardingData.BackendList,
|
|
|
+ AllowIpList: allowIpListStr,
|
|
|
+ DenyIpList: denyIpListStr,
|
|
|
+ AccessRule: req.WebForwardingData.AccessRule,
|
|
|
+ Comment: req.WebForwardingData.Comment,
|
|
|
+ }
|
|
|
+
|
|
|
+ // 4. 构建 WAF 表单数据映射
|
|
|
+ formData := s.buildWafFormData(&formDataBase, require)
|
|
|
+
|
|
|
+ return require, formData, nil
|
|
|
+}
|
|
|
+
|
|
|
+func (s *webForwardingService) AddWebForwarding(ctx context.Context, req *v1.WebForwardingRequest) error {
|
|
|
+ require, formData, err := s.prepareWafData(ctx, req)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ err = s.wafformatter.validateWafPortCount(ctx, require.HostId)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
}
|
|
|
- formData := s.buildWafFormData(&req.WebForwardingData, require)
|
|
|
wafWebId, err := s.wafformatter.sendFormData(ctx, "admin/info/waf_web/new", "admin/new/waf_web", formData)
|
|
|
if err != nil {
|
|
|
- return "", err
|
|
|
+ return err
|
|
|
}
|
|
|
+
|
|
|
webModel := s.buildWebForwardingModel(&req.WebForwardingData, wafWebId, require)
|
|
|
|
|
|
- if err := s.webForwardingRepository.AddWebForwarding(ctx, webModel); err != nil {
|
|
|
- return "", err
|
|
|
+ id, err := s.webForwardingRepository.AddWebForwarding(ctx, webModel)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ webRuleModel := s.buildWebRuleModel(&req.WebForwardingData, require, id)
|
|
|
+ if _, err = s.webForwardingRepository.AddWebForwardingIps(ctx, *webRuleModel); err != nil {
|
|
|
+ return err
|
|
|
}
|
|
|
- return "", nil
|
|
|
+ return nil
|
|
|
}
|
|
|
|
|
|
-func (s *webForwardingService) EditWebForwarding(ctx context.Context, req *v1.WebForwardingRequest) (string, error) {
|
|
|
- require, err := s.require(ctx, v1.GlobalRequire{
|
|
|
- HostId: req.HostId,
|
|
|
- Uid: req.Uid,
|
|
|
- Comment: req.WebForwardingData.Comment,
|
|
|
- })
|
|
|
+func (s *webForwardingService) EditWebForwarding(ctx context.Context, req *v1.WebForwardingRequest) error {
|
|
|
+ WafWebId, err := s.webForwardingRepository.GetWebForwardingWafWebIdById(ctx, req.Id)
|
|
|
if err != nil {
|
|
|
- return "", err
|
|
|
+ return err
|
|
|
}
|
|
|
- formData := s.buildWafFormData(&req.WebForwardingData, require)
|
|
|
- _, err = s.wafformatter.sendFormData(ctx, "admin/info/waf_web/edit?&__goadmin_edit_pk="+strconv.Itoa(req.WebForwardingData.WafWebId), "admin/edit/waf_web", formData)
|
|
|
+ req.WebForwardingData.WafWebId = WafWebId
|
|
|
+ require, formData, err := s.prepareWafData(ctx, req)
|
|
|
if err != nil {
|
|
|
- return "", err
|
|
|
+ return err
|
|
|
}
|
|
|
|
|
|
+ _, err = s.wafformatter.sendFormData(ctx, "admin/info/waf_web/edit?&__goadmin_edit_pk="+strconv.Itoa(req.WebForwardingData.WafWebId), "admin/edit/waf_web", formData)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
webModel := s.buildWebForwardingModel(&req.WebForwardingData, req.WebForwardingData.WafWebId, require)
|
|
|
webModel.Id = req.Id
|
|
|
-
|
|
|
-
|
|
|
- if err := s.webForwardingRepository.AddWebForwarding(ctx, webModel); err != nil {
|
|
|
- return "", err
|
|
|
+ if err = s.webForwardingRepository.EditWebForwarding(ctx, webModel); err != nil {
|
|
|
+ return err
|
|
|
}
|
|
|
-
|
|
|
- return "", nil
|
|
|
+ webRuleModel := s.buildWebRuleModel(&req.WebForwardingData, require, req.Id)
|
|
|
+ if err = s.webForwardingRepository.EditWebForwardingIps(ctx, *webRuleModel); err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ return nil
|
|
|
}
|
|
|
|
|
|
-func (s *webForwardingService) DeleteWebForwarding(ctx context.Context, wafWebId int) (string, error) {
|
|
|
- res, err := s.crawler.DeleteRule(ctx, wafWebId, "admin/delete/waf_web?page=1&__pageSize=10&__sort=waf_web_id&__sort_type=desc")
|
|
|
+func (s *webForwardingService) DeleteWebForwarding(ctx context.Context, wafWebId int) error {
|
|
|
+ _, err := s.crawler.DeleteRule(ctx, wafWebId, "admin/delete/waf_web?page=1&__pageSize=10&__sort=waf_web_id&__sort_type=desc")
|
|
|
if err != nil {
|
|
|
- return "", err
|
|
|
+ return err
|
|
|
}
|
|
|
- return res, nil
|
|
|
+ return nil
|
|
|
}
|