Selaa lähdekoodia

refactor(webforwarding): 重构网站转发功能

- 修改了获取网关组 IP 的逻辑,优化了 IP 获取的准确性和效率- 重构了添加和删除网站转发的流程,提高了代码的可读性和可维护性
- 新增了查找两个列表差异的辅助函数,简化了对后端列表变化的处理
- 调整了与 CDN 相关的操作,确保源站添加和删除的正确性
- 优化了域名白名单和 IP 白名单的处理逻辑,提高了安全性
fusu 1 kuukausi sitten
vanhempi
sitoutus
b99fe12f78

+ 3 - 3
internal/repository/gatewaygroup.go

@@ -17,7 +17,7 @@ type GatewayGroupRepository interface {
 	EditGatewayGroup(ctx context.Context, req *model.GatewayGroup) error
 	DeleteGatewayGroup(ctx context.Context, id int) error
 	GetGatewayGroupWhereHostIdNull(ctx context.Context,operator int, count int) (int, error)
-	GetGatewayGroupByHostId(ctx context.Context, hostId int64) (*[]model.GatewayGroup, error)
+	GetGatewayGroupByHostId(ctx context.Context, hostId int64) (*model.GatewayGroup, error)
 	GetGatewayGroupByRuleId(ctx context.Context, ruleId int64) (*model.GatewayGroup, error)
 	GetGatewayGroupList(ctx context.Context,req v1.SearchGatewayGroupParams) (*v1.PaginatedResponse[model.GatewayGroup], error)
 	EditGatewayGroupById(ctx context.Context, req *model.GatewayGroup) error
@@ -83,8 +83,8 @@ func (r *gatewayGroupRepository) GetGatewayGroupWhereHostIdNull(ctx context.Cont
 	return id, nil
 }
 
-func (r *gatewayGroupRepository) GetGatewayGroupByHostId(ctx context.Context, hostId int64) (*[]model.GatewayGroup, error) {
-	res := []model.GatewayGroup{}
+func (r *gatewayGroupRepository) GetGatewayGroupByHostId(ctx context.Context, hostId int64) (*model.GatewayGroup, error) {
+	res := model.GatewayGroup{}
 	if err := r.DB(ctx).Where("host_id = ?", hostId).Find(&res).Error; err != nil {
 		return nil, err
 	}

+ 1 - 1
internal/repository/gatewaygroupip.go

@@ -74,7 +74,7 @@ func (r *gateWayGroupIpRepository) GetGateWayGroupIpByGatewayGroupId(ctx context
 
 func (r *gateWayGroupIpRepository) GetGateWayGroupFirstIpByGatewayGroupId(ctx context.Context, gatewayGroupId int) (string, error) {
 	var res string
-	if err := r.DB(ctx).Model(&model.GateWayGroupIp{}).Where("gateway_group_id = ?", gatewayGroupId).Select("ip").First(&res).Error; err != nil {
+	if err := r.DB(ctx).Model(&model.GateWayGroupIp{}).Where("gateway_group_id = ?", gatewayGroupId).Select("ip").Order("id asc").First(&res).Error; err != nil {
 		return "", err
 	}
 	return res, nil

+ 1 - 0
internal/repository/webforwarding.go

@@ -151,6 +151,7 @@ func (r *webForwardingRepository) EditWebForwardingIps(ctx context.Context, req
 
 	updateData["deny_ip_list"] = req.DenyIpList
 
+	updateData["cdn_origin_ids"] = req.CdnOriginIds
 
 	// 始终更新更新时间
 	updateData["updated_at"] = time.Now()

+ 3 - 3
internal/service/gatewaygroup.go

@@ -16,7 +16,7 @@ type GatewayGroupService interface {
 	AddGatewayGroup(ctx context.Context, req v1.AddGateWayGroupRequest) (int, error)
 	EditGatewayGroup(ctx context.Context, req v1.AddGateWayGroupAdminRequest) error
 	DeleteGatewayGroup(ctx context.Context, id int) error
-	GetGatewayGroupByHostId(ctx context.Context, hostId int) ([]model.GatewayGroup, error)
+	GetGatewayGroupByHostId(ctx context.Context, hostId int) (model.GatewayGroup, error)
 	GetGatewayGroupList(ctx context.Context,req v1.SearchGatewayGroupParams) (*v1.PaginatedResponse[model.GatewayGroup]	, error)
 	AddGatewayGroupAdmin(ctx context.Context,req v1.AddGateWayGroupAdminRequest) error
 	EditGatewayGroupAdmin(ctx context.Context, req v1.AddGateWayGroupAdminRequest) error
@@ -76,10 +76,10 @@ func (s *gatewayGroupService) AddGatewayGroup(ctx context.Context, req v1.AddGat
 	return gateWayGroupId, nil
 }
 
-func (s *gatewayGroupService) GetGatewayGroupByHostId(ctx context.Context, hostId int) ([]model.GatewayGroup, error) {
+func (s *gatewayGroupService) GetGatewayGroupByHostId(ctx context.Context, hostId int) (model.GatewayGroup, error) {
 	res, err := s.gatewayGroupRepository.GetGatewayGroupByHostId(ctx, int64(hostId))
 	if err != nil {
-		return nil, err
+		return model.GatewayGroup{}, err
 	}
 	return *res, nil
 }

+ 0 - 15
internal/service/wafformatter.go

@@ -28,7 +28,6 @@ type WafFormatterService interface {
 	findIpDifferences(oldIps, newIps []string) ([]string, []string)
 	WashDeleteWafIp(ctx context.Context, backendList []string,allowIpList []string) ([]string, error)
 	WashEditWafIp(ctx context.Context, newBackendList []string,newAllowIpList []string,oldBackendList []string,oldAllowIpList []string) ([]string, []string, []string,  []string, error)
-	GetIp(ctx context.Context, gatewayGroupId int) ([]string,string, error)
 	//cdn添加网站
 	AddOrigin(ctx context.Context, req v1.WebJson) (int64, error)
 }
@@ -391,20 +390,6 @@ func (s *wafFormatterService) WashEditWafIp(ctx context.Context, newBackendList
 	return addedIps, removedIps ,addedAllowIps, removedAllowIps, nil
 }
 
-func (s *wafFormatterService) GetIp(ctx context.Context, gatewayGroupId int) ([]string,string, error) {
-	WafGatewayGroupRuleId, err := s.gatewayGroupRep.GetGatewayGroupByRuleId(ctx, int64(gatewayGroupId))
-	if err != nil {
-		return nil, "", err
-	}
-	ips, err := s.gatewayGroupIpRep.GetGateWayGroupAllIpByGatewayGroupId(ctx, WafGatewayGroupRuleId.Id)
-	if err != nil {
-		return nil, "", err
-	}
-	if len(ips) == 0 {
-		return nil, "", fmt.Errorf("请联系客服分配网关IP")
-	}
-	return ips,ips[0], nil
-}
 
 func (s *wafFormatterService) AddOrigin(ctx context.Context, req v1.WebJson) (int64, error) {
 	ip, port, err := net.SplitHostPort(req.BackendList)

+ 168 - 70
internal/service/webforwarding.go

@@ -9,9 +9,9 @@ import (
 	"github.com/go-nunu/nunu-layout-advanced/internal/repository"
 	"github.com/go-nunu/nunu-layout-advanced/pkg/rabbitmq"
 	"golang.org/x/sync/errgroup"
+	"maps"
 	"net"
 	"sort"
-	"strconv"
 )
 
 type WebForwardingService interface {
@@ -204,10 +204,17 @@ func (s *webForwardingService) prepareWafData(ctx context.Context, req *v1.WebFo
 		return RequireResponse{}, v1.Website{}, err // 错误信息在辅助函数中已经包装好了
 	}
 
-	var serverName []string
+	type serverNames struct {
+		ServerNames string `json:"name" form:"name"`
+		Type string `json:"type" form:"type"`
+	}
+	var serverName []serverNames
 	var serverJson []byte
 	if req.WebForwardingData.Domain != "" {
-		serverName = append(serverName, req.WebForwardingData.Domain)
+		serverName = append(serverName, serverNames{
+			ServerNames: req.WebForwardingData.Domain,
+			Type: "full",
+		})
 		serverJson, err = json.Marshal(serverName)
 		if err != nil {
 			return RequireResponse{}, v1.Website{}, err
@@ -292,6 +299,35 @@ func (s *webForwardingService) buildProxyJSONConfig(ctx context.Context, req *v1
 	return apiType, byteData, nil
 }
 
+// 查找两个列表的差异
+func (s webForwardingService) FindDifferenceList (oldList, newList []v1.BackendList) (added, removed []v1.BackendList) {
+	diff := make(map[v1.BackendList]int)
+
+	// 1. 遍历旧列表,为每个元素计数 +1
+	for _, item := range oldList {
+		diff[item]++
+	}
+
+	// 2. 遍历新列表,为每个元素计数 -1
+	for _, item := range newList {
+		diff[item]--
+	}
+
+	// 3. 遍历 diff map 来找出差异
+	for item, count := range diff {
+		if count > 0 {
+			// 如果 count > 0,说明这个元素在 oldList 中但不在 newList 中
+			removed = append(removed, item)
+		} else if count < 0 {
+			// 如果 count < 0,说明这个元素在 newList 中但不在 oldList 中
+			added = append(added, item)
+		}
+		// 如果 count == 0,说明元素在两个列表中都存在,不做任何操作
+	}
+
+	return added, removed
+}
+
 
 
 func (s *webForwardingService) AddWebForwarding(ctx context.Context, req *v1.WebForwardingRequest) error {
@@ -315,11 +351,10 @@ func (s *webForwardingService) AddWebForwarding(ctx context.Context, req *v1.Web
 	// 添加源站
 	cdnOriginIds := make(map[string]int64)
 	for _, v := range req.WebForwardingData.BackendList {
-		var apiType string
+		apiType := protocolHttp
+		// 如果条件满足,则覆盖为 HTTPS
 		if v.IsHttps == isHttps {
 			apiType = protocolHttps
-		}else {
-			apiType = protocolHttp
 		}
 		id, err := s.wafformatter.AddOrigin(ctx, v1.WebJson{
 			ApiType:  apiType,
@@ -364,7 +399,11 @@ func (s *webForwardingService) AddWebForwarding(ctx context.Context, req *v1.Web
 		if len(require.GatewayIps) == 0 {
 			return fmt.Errorf("网关组不存在")
 		}
-		go s.wafformatter.PublishDomainWhitelistTask(doMain, require.GatewayIps[0], "add")
+		firstIp,err :=  s.GetGatewayFirstIp(ctx, require.HostId)
+		if err != nil {
+			return err
+		}
+		go s.wafformatter.PublishDomainWhitelistTask(doMain, firstIp, "add")
 
 	}
 
@@ -449,7 +488,10 @@ func (s *webForwardingService) EditWebForwarding(ctx context.Context, req *v1.We
 
 	// 异步任务:将域名添加到白名单
 	if webData.Domain != req.WebForwardingData.Domain {
-
+		firstIp,err :=  s.GetGatewayFirstIp(ctx, req.HostId)
+		if err != nil {
+			return err
+		}
 		doMain, err := s.wafformatter.ConvertToWildcardDomain(ctx, req.WebForwardingData.Domain)
 		if err != nil {
 			return err
@@ -461,8 +503,8 @@ func (s *webForwardingService) EditWebForwarding(ctx context.Context, req *v1.We
 		if len(require.GatewayIps) == 0 {
 			return fmt.Errorf("网关组不存在")
 		}
-		go s.wafformatter.PublishDomainWhitelistTask(oldDomain, require.GatewayIps[0], "del")
-		go s.wafformatter.PublishDomainWhitelistTask(doMain, require.GatewayIps[0], "add")
+		go s.wafformatter.PublishDomainWhitelistTask(oldDomain, firstIp, "del")
+		go s.wafformatter.PublishDomainWhitelistTask(doMain, firstIp, "add")
 	}
 
 	// IP过白
@@ -523,16 +565,55 @@ func (s *webForwardingService) EditWebForwarding(ctx context.Context, req *v1.We
 
 
 
+	//修改源站
+	addOrigins, delOrigins := s.FindDifferenceList(ipData.BackendList, req.WebForwardingData.BackendList)
+	addedIds := make(map[string]int64)
+	for _, v := range addOrigins {
+		var apiType string
+		if v.IsHttps == isHttps {
+			apiType = protocolHttps
+		}else {
+			apiType = protocolHttp
+		}
+		id, err := s.wafformatter.AddOrigin(ctx,v1.WebJson{
+			ApiType: apiType,
+			BackendList: v.Addr,
+			Host: v.CustomHost,
+			Comment: req.WebForwardingData.Comment,
+		})
+		if err != nil {
+			return err
+		}
+		addedIds[v.Addr] = id
+	}
+	for _, v := range addedIds {
+		err = s.cdn.AddServerOrigin(ctx, int64(oldData.CdnWebId), v)
+		if err != nil {
+			return err
+		}
+	}
 
+	maps.Copy(ipData.CdnOriginIds, addedIds)
+	for k, v := range ipData.CdnOriginIds {
+		for _, ip := range delOrigins {
+			if k == ip.Addr {
+				err = s.cdn.DelServerOrigin(ctx, int64(oldData.CdnWebId), v)
+				if err != nil {
+					return err
+				}
+				delete(ipData.CdnOriginIds, k)
+			}
+		}
+	}
 
 
 
-	webModel := s.buildWebForwardingModel(&req.WebForwardingData, req.WebForwardingData.WafWebId, require)
+	webModel := s.buildWebForwardingModel(&req.WebForwardingData, req.WebForwardingData.CdnWebId, require)
 	webModel.Id = req.WebForwardingData.Id
 	if err = s.webForwardingRepository.EditWebForwarding(ctx, webModel); err != nil {
 		return err
 	}
-	webRuleModel := s.buildWebRuleModel(&req.WebForwardingData, require, req.WebForwardingData.Id)
+	webRuleModel := s.buildWebRuleModel(&req.WebForwardingData, require, req.WebForwardingData.Id,ipData.CdnOriginIds)
 	if err = s.webForwardingRepository.EditWebForwardingIps(ctx, *webRuleModel); err != nil {
 		return err
 	}
@@ -540,63 +621,64 @@ func (s *webForwardingService) EditWebForwarding(ctx context.Context, req *v1.We
 }
 
 func (s *webForwardingService) DeleteWebForwarding(ctx context.Context, Ids []int) error {
-	//for _, Id := range Ids {
-	//	wafWebId, err := s.webForwardingRepository.GetWebForwardingWafWebIdById(ctx, Id)
-	//	if err != nil {
-	//		return err
-	//	}
-	//	_, err = s.crawler.DeleteRule(ctx, wafWebId, "admin/delete/waf_web?page=1&__pageSize=10&__sort=waf_web_id&__sort_type=desc")
-	//	if err != nil {
-	//		return err
-	//	}
-	//	webData, err := s.webForwardingRepository.GetWebForwarding(ctx, int64(Id))
-	//	if err != nil {
-	//		return err
-	//	}
-	//
-	//	_, firstIp, err := s.wafformatter.GetIp(ctx, webData.WafGatewayGroupId)
-	//	if err != nil {
-	//		return err
-	//	}
-	//	// 异步任务:将域名添加到白名单
-	//	if webData.Domain != "" {
-	//
-	//		doMain, err := s.wafformatter.ConvertToWildcardDomain(ctx, webData.Domain)
-	//		if err != nil {
-	//			return err
-	//		}
-	//		go s.wafformatter.PublishDomainWhitelistTask(doMain,firstIp, "del")
-	//	}
-	//	// IP过白
-	//	ipData, err := s.webForwardingRepository.GetWebForwardingIpsByID(ctx, Id)
-	//	if err != nil {
-	//		return err
-	//	}
-	//	var ips []string
-	//	if len(ipData.BackendList) > 0 {
-	//		for _, v := range ipData.BackendList {
-	//			ip, _, err := net.SplitHostPort(v.Addr)
-	//			if err != nil {
-	//				return err
-	//			}
-	//			ips = append(ips, ip)
-	//		}
-	//	}
-	//	if len(ipData.AllowIpList) > 0 {
-	//		ips = append(ips, ipData.AllowIpList...)
-	//	}
-	//	if len(ips) > 0 {
-	//		go s.wafformatter.PublishIpWhitelistTask(ips, "del","")
-	//	}
-	//
-	//
-	//	if err = s.webForwardingRepository.DeleteWebForwarding(ctx, int64(Id)); err != nil {
-	//		return err
-	//	}
-	//	if err = s.webForwardingRepository.DeleteWebForwardingIpsById(ctx, Id); err != nil {
-	//		return err
-	//	}
-	//}
+	for _, Id := range Ids {
+		oldData, err := s.webForwardingRepository.GetWebForwarding(ctx, int64(Id))
+		if err != nil {
+			return err
+		}
+
+		err = s.cdn.DelServer(ctx, int64(oldData.CdnWebId))
+		if err != nil {
+			return err
+		}
+
+
+
+
+
+		// 异步任务:将域名添加到白名单
+		firstIp,err :=  s.GetGatewayFirstIp(ctx, oldData.HostId)
+		if err != nil {
+			return err
+		}
+		if oldData.Domain != "" {
+
+			doMain, err := s.wafformatter.ConvertToWildcardDomain(ctx, oldData.Domain)
+			if err != nil {
+				return err
+			}
+			go s.wafformatter.PublishDomainWhitelistTask(doMain,firstIp, "del")
+		}
+		// IP过白
+		ipData, err := s.webForwardingRepository.GetWebForwardingIpsByID(ctx, Id)
+		if err != nil {
+			return err
+		}
+		var ips []string
+		if len(ipData.BackendList) > 0 {
+			for _, v := range ipData.BackendList {
+				ip, _, err := net.SplitHostPort(v.Addr)
+				if err != nil {
+					return err
+				}
+				ips = append(ips, ip)
+			}
+		}
+		if len(ipData.AllowIpList) > 0 {
+			ips = append(ips, ipData.AllowIpList...)
+		}
+		if len(ips) > 0 {
+			go s.wafformatter.PublishIpWhitelistTask(ips, "del","")
+		}
+
+
+		if err = s.webForwardingRepository.DeleteWebForwarding(ctx, int64(Id)); err != nil {
+			return err
+		}
+		if err = s.webForwardingRepository.DeleteWebForwardingIpsById(ctx, Id); err != nil {
+			return err
+		}
+	}
 
 	return nil
 }
@@ -724,4 +806,20 @@ func (s *webForwardingService) GetWebForwardingWafWebAllIps(ctx context.Context,
 	return finalResults, nil
 }
 
-
+func (s *webForwardingService) GetGatewayFirstIp(ctx context.Context, hostId int) (string, error) {
+	gateWayGroup, err := s.gatewayGroupRep.GetGatewayGroupByHostId(ctx, int64(hostId))
+	if err != nil {
+		return "", err
+	}
+	if gateWayGroup == nil {
+		return  "",fmt.Errorf("网关组不存在")
+	}
+	gateWayIps, err := s.gatewayGroupIpRep.GetGateWayGroupFirstIpByGatewayGroupId(ctx, gateWayGroup.Id)
+	if err != nil {
+		return "", err
+	}
+	if len(gateWayIps) == 0 {
+		return  "",fmt.Errorf("网关组IP为空")
+	}
+	return gateWayIps, nil
+}