Browse Source

refactor(webforwarding): 重构网站转发功能

- 修改了获取网关组 IP 的逻辑,优化了 IP 获取的准确性和效率- 重构了添加和删除网站转发的流程,提高了代码的可读性和可维护性
- 新增了查找两个列表差异的辅助函数,简化了对后端列表变化的处理
- 调整了与 CDN 相关的操作,确保源站添加和删除的正确性
- 优化了域名白名单和 IP 白名单的处理逻辑,提高了安全性
fusu 1 month ago
parent
commit
b99fe12f78

+ 3 - 3
internal/repository/gatewaygroup.go

@@ -17,7 +17,7 @@ type GatewayGroupRepository interface {
 	EditGatewayGroup(ctx context.Context, req *model.GatewayGroup) error
 	EditGatewayGroup(ctx context.Context, req *model.GatewayGroup) error
 	DeleteGatewayGroup(ctx context.Context, id int) error
 	DeleteGatewayGroup(ctx context.Context, id int) error
 	GetGatewayGroupWhereHostIdNull(ctx context.Context,operator int, count int) (int, error)
 	GetGatewayGroupWhereHostIdNull(ctx context.Context,operator int, count int) (int, error)
-	GetGatewayGroupByHostId(ctx context.Context, hostId int64) (*[]model.GatewayGroup, error)
+	GetGatewayGroupByHostId(ctx context.Context, hostId int64) (*model.GatewayGroup, error)
 	GetGatewayGroupByRuleId(ctx context.Context, ruleId int64) (*model.GatewayGroup, error)
 	GetGatewayGroupByRuleId(ctx context.Context, ruleId int64) (*model.GatewayGroup, error)
 	GetGatewayGroupList(ctx context.Context,req v1.SearchGatewayGroupParams) (*v1.PaginatedResponse[model.GatewayGroup], error)
 	GetGatewayGroupList(ctx context.Context,req v1.SearchGatewayGroupParams) (*v1.PaginatedResponse[model.GatewayGroup], error)
 	EditGatewayGroupById(ctx context.Context, req *model.GatewayGroup) error
 	EditGatewayGroupById(ctx context.Context, req *model.GatewayGroup) error
@@ -83,8 +83,8 @@ func (r *gatewayGroupRepository) GetGatewayGroupWhereHostIdNull(ctx context.Cont
 	return id, nil
 	return id, nil
 }
 }
 
 
-func (r *gatewayGroupRepository) GetGatewayGroupByHostId(ctx context.Context, hostId int64) (*[]model.GatewayGroup, error) {
-	res := []model.GatewayGroup{}
+func (r *gatewayGroupRepository) GetGatewayGroupByHostId(ctx context.Context, hostId int64) (*model.GatewayGroup, error) {
+	res := model.GatewayGroup{}
 	if err := r.DB(ctx).Where("host_id = ?", hostId).Find(&res).Error; err != nil {
 	if err := r.DB(ctx).Where("host_id = ?", hostId).Find(&res).Error; err != nil {
 		return nil, err
 		return nil, err
 	}
 	}

+ 1 - 1
internal/repository/gatewaygroupip.go

@@ -74,7 +74,7 @@ func (r *gateWayGroupIpRepository) GetGateWayGroupIpByGatewayGroupId(ctx context
 
 
 func (r *gateWayGroupIpRepository) GetGateWayGroupFirstIpByGatewayGroupId(ctx context.Context, gatewayGroupId int) (string, error) {
 func (r *gateWayGroupIpRepository) GetGateWayGroupFirstIpByGatewayGroupId(ctx context.Context, gatewayGroupId int) (string, error) {
 	var res string
 	var res string
-	if err := r.DB(ctx).Model(&model.GateWayGroupIp{}).Where("gateway_group_id = ?", gatewayGroupId).Select("ip").First(&res).Error; err != nil {
+	if err := r.DB(ctx).Model(&model.GateWayGroupIp{}).Where("gateway_group_id = ?", gatewayGroupId).Select("ip").Order("id asc").First(&res).Error; err != nil {
 		return "", err
 		return "", err
 	}
 	}
 	return res, nil
 	return res, nil

+ 1 - 0
internal/repository/webforwarding.go

@@ -151,6 +151,7 @@ func (r *webForwardingRepository) EditWebForwardingIps(ctx context.Context, req
 
 
 	updateData["deny_ip_list"] = req.DenyIpList
 	updateData["deny_ip_list"] = req.DenyIpList
 
 
+	updateData["cdn_origin_ids"] = req.CdnOriginIds
 
 
 	// 始终更新更新时间
 	// 始终更新更新时间
 	updateData["updated_at"] = time.Now()
 	updateData["updated_at"] = time.Now()

+ 3 - 3
internal/service/gatewaygroup.go

@@ -16,7 +16,7 @@ type GatewayGroupService interface {
 	AddGatewayGroup(ctx context.Context, req v1.AddGateWayGroupRequest) (int, error)
 	AddGatewayGroup(ctx context.Context, req v1.AddGateWayGroupRequest) (int, error)
 	EditGatewayGroup(ctx context.Context, req v1.AddGateWayGroupAdminRequest) error
 	EditGatewayGroup(ctx context.Context, req v1.AddGateWayGroupAdminRequest) error
 	DeleteGatewayGroup(ctx context.Context, id int) error
 	DeleteGatewayGroup(ctx context.Context, id int) error
-	GetGatewayGroupByHostId(ctx context.Context, hostId int) ([]model.GatewayGroup, error)
+	GetGatewayGroupByHostId(ctx context.Context, hostId int) (model.GatewayGroup, error)
 	GetGatewayGroupList(ctx context.Context,req v1.SearchGatewayGroupParams) (*v1.PaginatedResponse[model.GatewayGroup]	, error)
 	GetGatewayGroupList(ctx context.Context,req v1.SearchGatewayGroupParams) (*v1.PaginatedResponse[model.GatewayGroup]	, error)
 	AddGatewayGroupAdmin(ctx context.Context,req v1.AddGateWayGroupAdminRequest) error
 	AddGatewayGroupAdmin(ctx context.Context,req v1.AddGateWayGroupAdminRequest) error
 	EditGatewayGroupAdmin(ctx context.Context, req v1.AddGateWayGroupAdminRequest) error
 	EditGatewayGroupAdmin(ctx context.Context, req v1.AddGateWayGroupAdminRequest) error
@@ -76,10 +76,10 @@ func (s *gatewayGroupService) AddGatewayGroup(ctx context.Context, req v1.AddGat
 	return gateWayGroupId, nil
 	return gateWayGroupId, nil
 }
 }
 
 
-func (s *gatewayGroupService) GetGatewayGroupByHostId(ctx context.Context, hostId int) ([]model.GatewayGroup, error) {
+func (s *gatewayGroupService) GetGatewayGroupByHostId(ctx context.Context, hostId int) (model.GatewayGroup, error) {
 	res, err := s.gatewayGroupRepository.GetGatewayGroupByHostId(ctx, int64(hostId))
 	res, err := s.gatewayGroupRepository.GetGatewayGroupByHostId(ctx, int64(hostId))
 	if err != nil {
 	if err != nil {
-		return nil, err
+		return model.GatewayGroup{}, err
 	}
 	}
 	return *res, nil
 	return *res, nil
 }
 }

+ 0 - 15
internal/service/wafformatter.go

@@ -28,7 +28,6 @@ type WafFormatterService interface {
 	findIpDifferences(oldIps, newIps []string) ([]string, []string)
 	findIpDifferences(oldIps, newIps []string) ([]string, []string)
 	WashDeleteWafIp(ctx context.Context, backendList []string,allowIpList []string) ([]string, error)
 	WashDeleteWafIp(ctx context.Context, backendList []string,allowIpList []string) ([]string, error)
 	WashEditWafIp(ctx context.Context, newBackendList []string,newAllowIpList []string,oldBackendList []string,oldAllowIpList []string) ([]string, []string, []string,  []string, error)
 	WashEditWafIp(ctx context.Context, newBackendList []string,newAllowIpList []string,oldBackendList []string,oldAllowIpList []string) ([]string, []string, []string,  []string, error)
-	GetIp(ctx context.Context, gatewayGroupId int) ([]string,string, error)
 	//cdn添加网站
 	//cdn添加网站
 	AddOrigin(ctx context.Context, req v1.WebJson) (int64, error)
 	AddOrigin(ctx context.Context, req v1.WebJson) (int64, error)
 }
 }
@@ -391,20 +390,6 @@ func (s *wafFormatterService) WashEditWafIp(ctx context.Context, newBackendList
 	return addedIps, removedIps ,addedAllowIps, removedAllowIps, nil
 	return addedIps, removedIps ,addedAllowIps, removedAllowIps, nil
 }
 }
 
 
-func (s *wafFormatterService) GetIp(ctx context.Context, gatewayGroupId int) ([]string,string, error) {
-	WafGatewayGroupRuleId, err := s.gatewayGroupRep.GetGatewayGroupByRuleId(ctx, int64(gatewayGroupId))
-	if err != nil {
-		return nil, "", err
-	}
-	ips, err := s.gatewayGroupIpRep.GetGateWayGroupAllIpByGatewayGroupId(ctx, WafGatewayGroupRuleId.Id)
-	if err != nil {
-		return nil, "", err
-	}
-	if len(ips) == 0 {
-		return nil, "", fmt.Errorf("请联系客服分配网关IP")
-	}
-	return ips,ips[0], nil
-}
 
 
 func (s *wafFormatterService) AddOrigin(ctx context.Context, req v1.WebJson) (int64, error) {
 func (s *wafFormatterService) AddOrigin(ctx context.Context, req v1.WebJson) (int64, error) {
 	ip, port, err := net.SplitHostPort(req.BackendList)
 	ip, port, err := net.SplitHostPort(req.BackendList)

+ 168 - 70
internal/service/webforwarding.go

@@ -9,9 +9,9 @@ import (
 	"github.com/go-nunu/nunu-layout-advanced/internal/repository"
 	"github.com/go-nunu/nunu-layout-advanced/internal/repository"
 	"github.com/go-nunu/nunu-layout-advanced/pkg/rabbitmq"
 	"github.com/go-nunu/nunu-layout-advanced/pkg/rabbitmq"
 	"golang.org/x/sync/errgroup"
 	"golang.org/x/sync/errgroup"
+	"maps"
 	"net"
 	"net"
 	"sort"
 	"sort"
-	"strconv"
 )
 )
 
 
 type WebForwardingService interface {
 type WebForwardingService interface {
@@ -204,10 +204,17 @@ func (s *webForwardingService) prepareWafData(ctx context.Context, req *v1.WebFo
 		return RequireResponse{}, v1.Website{}, err // 错误信息在辅助函数中已经包装好了
 		return RequireResponse{}, v1.Website{}, err // 错误信息在辅助函数中已经包装好了
 	}
 	}
 
 
-	var serverName []string
+	type serverNames struct {
+		ServerNames string `json:"name" form:"name"`
+		Type string `json:"type" form:"type"`
+	}
+	var serverName []serverNames
 	var serverJson []byte
 	var serverJson []byte
 	if req.WebForwardingData.Domain != "" {
 	if req.WebForwardingData.Domain != "" {
-		serverName = append(serverName, req.WebForwardingData.Domain)
+		serverName = append(serverName, serverNames{
+			ServerNames: req.WebForwardingData.Domain,
+			Type: "full",
+		})
 		serverJson, err = json.Marshal(serverName)
 		serverJson, err = json.Marshal(serverName)
 		if err != nil {
 		if err != nil {
 			return RequireResponse{}, v1.Website{}, err
 			return RequireResponse{}, v1.Website{}, err
@@ -292,6 +299,35 @@ func (s *webForwardingService) buildProxyJSONConfig(ctx context.Context, req *v1
 	return apiType, byteData, nil
 	return apiType, byteData, nil
 }
 }
 
 
+// 查找两个列表的差异
+func (s webForwardingService) FindDifferenceList (oldList, newList []v1.BackendList) (added, removed []v1.BackendList) {
+	diff := make(map[v1.BackendList]int)
+
+	// 1. 遍历旧列表,为每个元素计数 +1
+	for _, item := range oldList {
+		diff[item]++
+	}
+
+	// 2. 遍历新列表,为每个元素计数 -1
+	for _, item := range newList {
+		diff[item]--
+	}
+
+	// 3. 遍历 diff map 来找出差异
+	for item, count := range diff {
+		if count > 0 {
+			// 如果 count > 0,说明这个元素在 oldList 中但不在 newList 中
+			removed = append(removed, item)
+		} else if count < 0 {
+			// 如果 count < 0,说明这个元素在 newList 中但不在 oldList 中
+			added = append(added, item)
+		}
+		// 如果 count == 0,说明元素在两个列表中都存在,不做任何操作
+	}
+
+	return added, removed
+}
+
 
 
 
 
 func (s *webForwardingService) AddWebForwarding(ctx context.Context, req *v1.WebForwardingRequest) error {
 func (s *webForwardingService) AddWebForwarding(ctx context.Context, req *v1.WebForwardingRequest) error {
@@ -315,11 +351,10 @@ func (s *webForwardingService) AddWebForwarding(ctx context.Context, req *v1.Web
 	// 添加源站
 	// 添加源站
 	cdnOriginIds := make(map[string]int64)
 	cdnOriginIds := make(map[string]int64)
 	for _, v := range req.WebForwardingData.BackendList {
 	for _, v := range req.WebForwardingData.BackendList {
-		var apiType string
+		apiType := protocolHttp
+		// 如果条件满足,则覆盖为 HTTPS
 		if v.IsHttps == isHttps {
 		if v.IsHttps == isHttps {
 			apiType = protocolHttps
 			apiType = protocolHttps
-		}else {
-			apiType = protocolHttp
 		}
 		}
 		id, err := s.wafformatter.AddOrigin(ctx, v1.WebJson{
 		id, err := s.wafformatter.AddOrigin(ctx, v1.WebJson{
 			ApiType:  apiType,
 			ApiType:  apiType,
@@ -364,7 +399,11 @@ func (s *webForwardingService) AddWebForwarding(ctx context.Context, req *v1.Web
 		if len(require.GatewayIps) == 0 {
 		if len(require.GatewayIps) == 0 {
 			return fmt.Errorf("网关组不存在")
 			return fmt.Errorf("网关组不存在")
 		}
 		}
-		go s.wafformatter.PublishDomainWhitelistTask(doMain, require.GatewayIps[0], "add")
+		firstIp,err :=  s.GetGatewayFirstIp(ctx, require.HostId)
+		if err != nil {
+			return err
+		}
+		go s.wafformatter.PublishDomainWhitelistTask(doMain, firstIp, "add")
 
 
 	}
 	}
 
 
@@ -449,7 +488,10 @@ func (s *webForwardingService) EditWebForwarding(ctx context.Context, req *v1.We
 
 
 	// 异步任务:将域名添加到白名单
 	// 异步任务:将域名添加到白名单
 	if webData.Domain != req.WebForwardingData.Domain {
 	if webData.Domain != req.WebForwardingData.Domain {
-
+		firstIp,err :=  s.GetGatewayFirstIp(ctx, req.HostId)
+		if err != nil {
+			return err
+		}
 		doMain, err := s.wafformatter.ConvertToWildcardDomain(ctx, req.WebForwardingData.Domain)
 		doMain, err := s.wafformatter.ConvertToWildcardDomain(ctx, req.WebForwardingData.Domain)
 		if err != nil {
 		if err != nil {
 			return err
 			return err
@@ -461,8 +503,8 @@ func (s *webForwardingService) EditWebForwarding(ctx context.Context, req *v1.We
 		if len(require.GatewayIps) == 0 {
 		if len(require.GatewayIps) == 0 {
 			return fmt.Errorf("网关组不存在")
 			return fmt.Errorf("网关组不存在")
 		}
 		}
-		go s.wafformatter.PublishDomainWhitelistTask(oldDomain, require.GatewayIps[0], "del")
-		go s.wafformatter.PublishDomainWhitelistTask(doMain, require.GatewayIps[0], "add")
+		go s.wafformatter.PublishDomainWhitelistTask(oldDomain, firstIp, "del")
+		go s.wafformatter.PublishDomainWhitelistTask(doMain, firstIp, "add")
 	}
 	}
 
 
 	// IP过白
 	// IP过白
@@ -523,16 +565,55 @@ func (s *webForwardingService) EditWebForwarding(ctx context.Context, req *v1.We
 
 
 
 
 
 
+	//修改源站
+	addOrigins, delOrigins := s.FindDifferenceList(ipData.BackendList, req.WebForwardingData.BackendList)
+	addedIds := make(map[string]int64)
+	for _, v := range addOrigins {
+		var apiType string
+		if v.IsHttps == isHttps {
+			apiType = protocolHttps
+		}else {
+			apiType = protocolHttp
+		}
+		id, err := s.wafformatter.AddOrigin(ctx,v1.WebJson{
+			ApiType: apiType,
+			BackendList: v.Addr,
+			Host: v.CustomHost,
+			Comment: req.WebForwardingData.Comment,
+		})
+		if err != nil {
+			return err
+		}
+		addedIds[v.Addr] = id
+	}
+	for _, v := range addedIds {
+		err = s.cdn.AddServerOrigin(ctx, int64(oldData.CdnWebId), v)
+		if err != nil {
+			return err
+		}
+	}
 
 
+	maps.Copy(ipData.CdnOriginIds, addedIds)
+	for k, v := range ipData.CdnOriginIds {
+		for _, ip := range delOrigins {
+			if k == ip.Addr {
+				err = s.cdn.DelServerOrigin(ctx, int64(oldData.CdnWebId), v)
+				if err != nil {
+					return err
+				}
+				delete(ipData.CdnOriginIds, k)
+			}
+		}
+	}
 
 
 
 
 
 
-	webModel := s.buildWebForwardingModel(&req.WebForwardingData, req.WebForwardingData.WafWebId, require)
+	webModel := s.buildWebForwardingModel(&req.WebForwardingData, req.WebForwardingData.CdnWebId, require)
 	webModel.Id = req.WebForwardingData.Id
 	webModel.Id = req.WebForwardingData.Id
 	if err = s.webForwardingRepository.EditWebForwarding(ctx, webModel); err != nil {
 	if err = s.webForwardingRepository.EditWebForwarding(ctx, webModel); err != nil {
 		return err
 		return err
 	}
 	}
-	webRuleModel := s.buildWebRuleModel(&req.WebForwardingData, require, req.WebForwardingData.Id)
+	webRuleModel := s.buildWebRuleModel(&req.WebForwardingData, require, req.WebForwardingData.Id,ipData.CdnOriginIds)
 	if err = s.webForwardingRepository.EditWebForwardingIps(ctx, *webRuleModel); err != nil {
 	if err = s.webForwardingRepository.EditWebForwardingIps(ctx, *webRuleModel); err != nil {
 		return err
 		return err
 	}
 	}
@@ -540,63 +621,64 @@ func (s *webForwardingService) EditWebForwarding(ctx context.Context, req *v1.We
 }
 }
 
 
 func (s *webForwardingService) DeleteWebForwarding(ctx context.Context, Ids []int) error {
 func (s *webForwardingService) DeleteWebForwarding(ctx context.Context, Ids []int) error {
-	//for _, Id := range Ids {
-	//	wafWebId, err := s.webForwardingRepository.GetWebForwardingWafWebIdById(ctx, Id)
-	//	if err != nil {
-	//		return err
-	//	}
-	//	_, err = s.crawler.DeleteRule(ctx, wafWebId, "admin/delete/waf_web?page=1&__pageSize=10&__sort=waf_web_id&__sort_type=desc")
-	//	if err != nil {
-	//		return err
-	//	}
-	//	webData, err := s.webForwardingRepository.GetWebForwarding(ctx, int64(Id))
-	//	if err != nil {
-	//		return err
-	//	}
-	//
-	//	_, firstIp, err := s.wafformatter.GetIp(ctx, webData.WafGatewayGroupId)
-	//	if err != nil {
-	//		return err
-	//	}
-	//	// 异步任务:将域名添加到白名单
-	//	if webData.Domain != "" {
-	//
-	//		doMain, err := s.wafformatter.ConvertToWildcardDomain(ctx, webData.Domain)
-	//		if err != nil {
-	//			return err
-	//		}
-	//		go s.wafformatter.PublishDomainWhitelistTask(doMain,firstIp, "del")
-	//	}
-	//	// IP过白
-	//	ipData, err := s.webForwardingRepository.GetWebForwardingIpsByID(ctx, Id)
-	//	if err != nil {
-	//		return err
-	//	}
-	//	var ips []string
-	//	if len(ipData.BackendList) > 0 {
-	//		for _, v := range ipData.BackendList {
-	//			ip, _, err := net.SplitHostPort(v.Addr)
-	//			if err != nil {
-	//				return err
-	//			}
-	//			ips = append(ips, ip)
-	//		}
-	//	}
-	//	if len(ipData.AllowIpList) > 0 {
-	//		ips = append(ips, ipData.AllowIpList...)
-	//	}
-	//	if len(ips) > 0 {
-	//		go s.wafformatter.PublishIpWhitelistTask(ips, "del","")
-	//	}
-	//
-	//
-	//	if err = s.webForwardingRepository.DeleteWebForwarding(ctx, int64(Id)); err != nil {
-	//		return err
-	//	}
-	//	if err = s.webForwardingRepository.DeleteWebForwardingIpsById(ctx, Id); err != nil {
-	//		return err
-	//	}
-	//}
+	for _, Id := range Ids {
+		oldData, err := s.webForwardingRepository.GetWebForwarding(ctx, int64(Id))
+		if err != nil {
+			return err
+		}
+
+		err = s.cdn.DelServer(ctx, int64(oldData.CdnWebId))
+		if err != nil {
+			return err
+		}
+
+
+
+
+
+		// 异步任务:将域名添加到白名单
+		firstIp,err :=  s.GetGatewayFirstIp(ctx, oldData.HostId)
+		if err != nil {
+			return err
+		}
+		if oldData.Domain != "" {
+
+			doMain, err := s.wafformatter.ConvertToWildcardDomain(ctx, oldData.Domain)
+			if err != nil {
+				return err
+			}
+			go s.wafformatter.PublishDomainWhitelistTask(doMain,firstIp, "del")
+		}
+		// IP过白
+		ipData, err := s.webForwardingRepository.GetWebForwardingIpsByID(ctx, Id)
+		if err != nil {
+			return err
+		}
+		var ips []string
+		if len(ipData.BackendList) > 0 {
+			for _, v := range ipData.BackendList {
+				ip, _, err := net.SplitHostPort(v.Addr)
+				if err != nil {
+					return err
+				}
+				ips = append(ips, ip)
+			}
+		}
+		if len(ipData.AllowIpList) > 0 {
+			ips = append(ips, ipData.AllowIpList...)
+		}
+		if len(ips) > 0 {
+			go s.wafformatter.PublishIpWhitelistTask(ips, "del","")
+		}
+
+
+		if err = s.webForwardingRepository.DeleteWebForwarding(ctx, int64(Id)); err != nil {
+			return err
+		}
+		if err = s.webForwardingRepository.DeleteWebForwardingIpsById(ctx, Id); err != nil {
+			return err
+		}
+	}
 
 
 	return nil
 	return nil
 }
 }
@@ -724,4 +806,20 @@ func (s *webForwardingService) GetWebForwardingWafWebAllIps(ctx context.Context,
 	return finalResults, nil
 	return finalResults, nil
 }
 }
 
 
-
+func (s *webForwardingService) GetGatewayFirstIp(ctx context.Context, hostId int) (string, error) {
+	gateWayGroup, err := s.gatewayGroupRep.GetGatewayGroupByHostId(ctx, int64(hostId))
+	if err != nil {
+		return "", err
+	}
+	if gateWayGroup == nil {
+		return  "",fmt.Errorf("网关组不存在")
+	}
+	gateWayIps, err := s.gatewayGroupIpRep.GetGateWayGroupFirstIpByGatewayGroupId(ctx, gateWayGroup.Id)
+	if err != nil {
+		return "", err
+	}
+	if len(gateWayIps) == 0 {
+		return  "",fmt.Errorf("网关组IP为空")
+	}
+	return gateWayIps, nil
+}