package waf import ( "context" "encoding/json" "fmt" v1 "github.com/go-nunu/nunu-layout-advanced/api/v1" "github.com/go-nunu/nunu-layout-advanced/internal/model" "github.com/go-nunu/nunu-layout-advanced/internal/repository/api/waf" "github.com/go-nunu/nunu-layout-advanced/internal/service" "strconv" ) type GatewayipService interface { GetGatewayip(ctx context.Context, id int64) (*model.Gatewayip, error) GetGatewayipOnlyIpByHostIdAll(ctx context.Context, hostId int64,uid int64) ([]string, error) GetGatewayipByHostIdFirst(ctx context.Context, hostId int64,uid int64) (string, error) AddIpWhereHostIdNull(ctx context.Context, hostId int64,uid int64) error } func NewGatewayipService( service *service.Service, gatewayipRepository waf.GatewayipRepository, host service.HostService, log service.LogService, ) GatewayipService { return &gatewayipService{ Service: service, gatewayipRepository: gatewayipRepository, host : host, log : log, } } type gatewayipService struct { *service.Service gatewayipRepository waf.GatewayipRepository host service.HostService log service.LogService } func (s *gatewayipService) GetGatewayip(ctx context.Context, id int64) (*model.Gatewayip, error) { return s.gatewayipRepository.GetGatewayip(ctx, id) } func (s *gatewayipService) AddIpWhereHostIdNull(ctx context.Context, hostId int64,uid int64) error { config, err := s.host.GetGlobalLimitConfig(ctx, int(hostId)) if err != nil { return err } ips, err := s.gatewayipRepository.GetIpWhereHostIdNull(ctx, v1.GlobalLimitRequireResponse{ HostId: int(hostId), Bps: config.Bps, MaxBytesMonth: config.MaxBytesMonth, IpCount: config.IpCount, Operator: config.Operator, NodeArea: config.NodeArea, ConfigMaxProtection: config.ConfigMaxProtection, IsBanUdp: config.IsBanUdp, IsBanOverseas: config.IsBanOverseas, }); if err != nil { return err } ipsJson, err := json.Marshal(ips) if err != nil { return err } configJson, err := json.Marshal(config) if err != nil { return err } if err = s.log.AddLog(ctx, &model.Log{ Uid: uid, Api: "AddIpWhereHostIdNull,分配网关组IP", Message: string(configJson) + "," + "hostId:" + strconv.FormatInt(hostId, 10), ExtraData: ipsJson, }); err != nil { return err } return nil } func (s *gatewayipService) GetGatewayipOnlyIpByHostIdAll(ctx context.Context, hostId int64,uid int64) ([]string, error) { gatewayIps, err := s.gatewayipRepository.GetGatewayipOnlyIpByHostIdAll(ctx, hostId) if err != nil { return nil, err } if len(gatewayIps) == 0 { expire, err := s.host.CheckExpired(ctx, uid, hostId) if err != nil { return nil, err } if !expire { return nil, fmt.Errorf("产品已过期,请及时续费") } err = s.AddIpWhereHostIdNull(ctx, hostId,uid) if err != nil { return nil, err } gatewayIps, err = s.gatewayipRepository.GetGatewayipOnlyIpByHostIdAll(ctx, hostId) if err != nil { return nil, err } } return gatewayIps, nil } func (s *gatewayipService) GetGatewayipByHostIdFirst(ctx context.Context, hostId int64,uid int64) (string, error) { gatewayIps, err := s.gatewayipRepository.GetGatewayipByHostIdFirst(ctx, hostId) if err != nil { return "", err } if len(gatewayIps) == 0 { err = s.AddIpWhereHostIdNull(ctx, hostId,uid) if err != nil { return "", err } gatewayIps, err = s.gatewayipRepository.GetGatewayipByHostIdFirst(ctx, hostId) if err != nil { return "", err } } return gatewayIps, nil }