globallimit.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387
  1. package service
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
  7. "github.com/go-nunu/nunu-layout-advanced/internal/model"
  8. "github.com/go-nunu/nunu-layout-advanced/internal/repository"
  9. "github.com/mozillazg/go-pinyin"
  10. "github.com/spf13/viper"
  11. "go.uber.org/zap"
  12. "golang.org/x/sync/errgroup"
  13. "gorm.io/gorm"
  14. "strconv"
  15. "strings"
  16. "time"
  17. )
  18. type GlobalLimitService interface {
  19. GetGlobalLimit(ctx context.Context, id int64) (*model.GlobalLimit, error)
  20. AddGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error
  21. EditGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error
  22. DeleteGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error
  23. }
  24. func NewGlobalLimitService(
  25. service *Service,
  26. globalLimitRepository repository.GlobalLimitRepository,
  27. duedate DuedateService,
  28. crawler CrawlerService,
  29. conf *viper.Viper,
  30. required RequiredService,
  31. parser ParserService,
  32. host HostService,
  33. tcpLimit TcpLimitService,
  34. udpLimit UdpLimitService,
  35. webLimit WebLimitService,
  36. gateWayGroup GatewayGroupService,
  37. hostRep repository.HostRepository,
  38. gateWayGroupRep repository.GatewayGroupRepository,
  39. cdnService CdnService,
  40. cdnRep repository.CdnRepository,
  41. ) GlobalLimitService {
  42. return &globalLimitService{
  43. Service: service,
  44. globalLimitRepository: globalLimitRepository,
  45. duedate: duedate,
  46. crawler: crawler,
  47. Url: conf.GetString("crawler.Url"),
  48. required: required,
  49. parser: parser,
  50. host: host,
  51. tcpLimit: tcpLimit,
  52. udpLimit: udpLimit,
  53. webLimit: webLimit,
  54. gateWayGroup: gateWayGroup,
  55. hostRep: hostRep,
  56. gateWayGroupRep: gateWayGroupRep,
  57. cdnService: cdnService,
  58. cdnRep: cdnRep,
  59. }
  60. }
  61. type globalLimitService struct {
  62. *Service
  63. globalLimitRepository repository.GlobalLimitRepository
  64. duedate DuedateService
  65. crawler CrawlerService
  66. Url string
  67. required RequiredService
  68. parser ParserService
  69. host HostService
  70. tcpLimit TcpLimitService
  71. udpLimit UdpLimitService
  72. webLimit WebLimitService
  73. gateWayGroup GatewayGroupService
  74. hostRep repository.HostRepository
  75. gateWayGroupRep repository.GatewayGroupRepository
  76. cdnService CdnService
  77. cdnRep repository.CdnRepository
  78. }
  79. func (s *globalLimitService) GetCdnUserId(ctx context.Context, uid int64) (int64, error) {
  80. data, err := s.globalLimitRepository.GetGlobalLimitFirst(ctx, uid)
  81. if err != nil {
  82. if !errors.Is(err, gorm.ErrRecordNotFound) {
  83. return 0, err
  84. }
  85. }
  86. if data != nil && data.CdnUid != 0 {
  87. return int64(data.CdnUid), nil
  88. }
  89. userInfo,err := s.globalLimitRepository.GetUserInfo(ctx, uid)
  90. if err != nil {
  91. return 0, err
  92. }
  93. // 中文转拼音
  94. a := pinyin.NewArgs()
  95. a.Style = pinyin.Normal
  96. pinyinSlice := pinyin.LazyPinyin(userInfo.Username, a)
  97. userName := strconv.Itoa(int(uid)) + "_" + strings.Join(pinyinSlice, "_")
  98. // 查询用户是否存在
  99. UserId,err := s.cdnRep.GetUserId(ctx, userName)
  100. if err != nil {
  101. return 0, err
  102. }
  103. if UserId != 0 {
  104. return UserId, nil
  105. }
  106. // 注册用户
  107. userId, err := s.cdnService.AddUser(ctx, v1.User{
  108. Username: userName,
  109. Email: userInfo.Email,
  110. Fullname: userInfo.Username,
  111. Mobile: userInfo.PhoneNumber,
  112. })
  113. if err != nil {
  114. return 0, err
  115. }
  116. return userId, nil
  117. }
  118. func (s *globalLimitService) AddGroupId(ctx context.Context,groupName string) (int64, error) {
  119. groupId, err := s.cdnService.CreateGroup(ctx, v1.Group{
  120. Name: groupName,
  121. })
  122. if err != nil {
  123. return 0, err
  124. }
  125. return groupId, nil
  126. }
  127. func (s *globalLimitService) GlobalLimitRequire(ctx context.Context, req v1.GlobalLimitRequest) (res v1.GlobalLimitRequireResponse, err error) {
  128. res.ExpiredAt, err = s.duedate.NextDueDate(ctx, req.Uid, req.HostId)
  129. if err != nil {
  130. return v1.GlobalLimitRequireResponse{}, err
  131. }
  132. configCount, err := s.host.GetGlobalLimitConfig(ctx, req.HostId)
  133. if err != nil {
  134. return v1.GlobalLimitRequireResponse{}, fmt.Errorf("获取配置限制失败: %w", err)
  135. }
  136. bpsInt, err := strconv.Atoi(configCount.Bps)
  137. if err != nil {
  138. return v1.GlobalLimitRequireResponse{}, err
  139. }
  140. resultFloat := float64(bpsInt) / 2.0 / 8.0
  141. res.Bps = strconv.FormatFloat( resultFloat, 'f', -1, 64) + "M"
  142. res.MaxBytesMonth = configCount.MaxBytesMonth
  143. res.Operator = configCount.Operator
  144. res.IpCount = configCount.IpCount
  145. res.NodeArea = configCount.NodeArea
  146. res.ConfigMaxProtection = configCount.ConfigMaxProtection
  147. domain, err := s.hostRep.GetDomainById(ctx, req.HostId)
  148. if err != nil {
  149. return v1.GlobalLimitRequireResponse{}, err
  150. }
  151. userInfo,err := s.globalLimitRepository.GetUserInfo(ctx, int64(req.Uid))
  152. if err != nil {
  153. return v1.GlobalLimitRequireResponse{}, err
  154. }
  155. res.GlobalLimitName = strconv.Itoa(req.Uid) + "_" + userInfo.Username + "_" + strconv.Itoa(req.HostId) + "_" + domain
  156. res.HostName, err = s.globalLimitRepository.GetHostName(ctx, int64(req.HostId))
  157. if err != nil {
  158. return v1.GlobalLimitRequireResponse{}, err
  159. }
  160. return res, nil
  161. }
  162. func (s *globalLimitService) GetGlobalLimit(ctx context.Context, id int64) (*model.GlobalLimit, error) {
  163. return s.globalLimitRepository.GetGlobalLimit(ctx, id)
  164. }
  165. func (s *globalLimitService) ConversionTime(ctx context.Context,req string) (string, error) {
  166. // 2. 将字符串解析成 time.Time 对象
  167. // time.Parse 会根据你提供的布局来理解输入的字符串
  168. t, err := time.Parse("2006-01-02 15:04:05", req)
  169. if err != nil {
  170. // 如果输入的字符串格式和布局不匹配,这里会报错
  171. return "", fmt.Errorf("输入的字符串格式和布局不匹配 %w", err)
  172. }
  173. // 3. 定义新的输出格式 "YYYY-MM-DD"
  174. outputLayout := "2006-01-02"
  175. // 4. 将 time.Time 对象格式化为新的字符串
  176. outputTimeStr := t.Format(outputLayout)
  177. return outputTimeStr, nil
  178. }
  179. func (s *globalLimitService) ConversionTimeUnix(ctx context.Context,req string) (int64, error) {
  180. t, err := time.Parse("2006-01-02 15:04:05", req)
  181. if err != nil {
  182. return 0, fmt.Errorf("输入的字符串格式和布局不匹配 %w", err)
  183. }
  184. expiredAt := t.Unix()
  185. return expiredAt, nil
  186. }
  187. func (s *globalLimitService) AddGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error {
  188. isExist, err := s.globalLimitRepository.IsGlobalLimitExistByHostId(ctx, int64(req.HostId))
  189. if err != nil {
  190. return err
  191. }
  192. if isExist {
  193. return fmt.Errorf("配置限制已存在")
  194. }
  195. require, err := s.GlobalLimitRequire(ctx, req)
  196. if err != nil {
  197. return err
  198. }
  199. g, gCtx := errgroup.WithContext(ctx)
  200. var gatewayGroupId int
  201. var userId int64
  202. var groupId int64
  203. g.Go(func() error {
  204. res, e := s.gateWayGroupRep.GetGatewayGroupWhereHostIdNull(gCtx, require.Operator, require.IpCount)
  205. if e != nil {
  206. return fmt.Errorf("获取网关组失败: %w", e)
  207. }
  208. if res == 0 {
  209. return fmt.Errorf("获取网关组失败")
  210. }
  211. gatewayGroupId = res
  212. return nil
  213. })
  214. g.Go(func() error {
  215. res, e := s.GetCdnUserId(gCtx, int64(req.Uid))
  216. if e != nil {
  217. return fmt.Errorf("获取cdn用户失败: %w", e)
  218. }
  219. if res == 0 {
  220. return fmt.Errorf("获取cdn用户失败")
  221. }
  222. userId = res
  223. return nil
  224. })
  225. g.Go(func() error {
  226. res, e := s.AddGroupId(gCtx, require.GlobalLimitName)
  227. if e != nil {
  228. return fmt.Errorf("创建规则分组失败: %w", e)
  229. }
  230. if res == 0 {
  231. return fmt.Errorf("创建规则分组失败")
  232. }
  233. groupId = res
  234. return nil
  235. })
  236. if err = g.Wait(); err != nil {
  237. return err
  238. }
  239. outputTimeStr, err := s.ConversionTime(ctx, require.ExpiredAt)
  240. if err != nil {
  241. return err
  242. }
  243. // 获取套餐ID
  244. maxProtection := strings.TrimSuffix(require.ConfigMaxProtection, "G")
  245. if maxProtection == "" {
  246. return fmt.Errorf("无效的配置 ConfigMaxProtection: '%s',数字部分为空", require.ConfigMaxProtection)
  247. }
  248. maxProtectionInt, err := strconv.Atoi(maxProtection)
  249. if err != nil {
  250. return fmt.Errorf("无效的配置 ConfigMaxProtection: '%s',无法转换为数字", require.ConfigMaxProtection)
  251. }
  252. var planId int64
  253. maxProtectionNum := 1
  254. if maxProtectionInt >= 2000 {
  255. maxProtectionNum = maxProtectionInt / 1000
  256. }
  257. NodeAreaName := fmt.Sprintf("%s-%dT",require.NodeArea, maxProtectionNum)
  258. planId, err = s.globalLimitRepository.GetNodeArea(ctx, NodeAreaName)
  259. if err != nil {
  260. if errors.Is(err, gorm.ErrRecordNotFound) {
  261. planId = 0
  262. }else {
  263. return err
  264. }
  265. }
  266. if planId == 0 {
  267. // 安全冗余套餐
  268. planId = 6
  269. s.logger.Warn("获取套餐Id失败", zap.String("节点区域", NodeAreaName), zap.String("防御阈值", require.ConfigMaxProtection),zap.Int64("套餐Id", int64(req.Uid)),zap.Int64("魔方套餐Id", int64(req.HostId)))
  270. }
  271. ruleId, err := s.cdnService.BindPlan(ctx, v1.Plan{
  272. UserId: userId,
  273. PlanId: planId,
  274. DayTo: outputTimeStr,
  275. Name: require.GlobalLimitName,
  276. IsFree: true,
  277. Period: "monthly",
  278. CountPeriod: 1,
  279. PeriodDayTo: outputTimeStr,
  280. })
  281. if err != nil {
  282. return err
  283. }
  284. if ruleId == 0 {
  285. return fmt.Errorf("分配套餐失败")
  286. }
  287. err = s.gateWayGroupRep.EditGatewayGroup(ctx, &model.GatewayGroup{
  288. Id: gatewayGroupId,
  289. HostId: req.HostId,
  290. })
  291. if err != nil {
  292. return err
  293. }
  294. expiredAt, err := s.ConversionTimeUnix(ctx, require.ExpiredAt)
  295. if err != nil {
  296. return err
  297. }
  298. err = s.globalLimitRepository.AddGlobalLimit(ctx, &model.GlobalLimit{
  299. HostId: req.HostId,
  300. Uid: req.Uid,
  301. Name: require.GlobalLimitName,
  302. RuleId: int(ruleId),
  303. GroupId: int(groupId),
  304. GatewayGroupId: gatewayGroupId,
  305. CdnUid: int(userId),
  306. Comment: req.Comment,
  307. ExpiredAt: expiredAt,
  308. })
  309. if err != nil {
  310. return err
  311. }
  312. return nil
  313. }
  314. func (s *globalLimitService) EditGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error {
  315. require, err := s.GlobalLimitRequire(ctx, req)
  316. if err != nil {
  317. return err
  318. }
  319. data, err := s.globalLimitRepository.GetGlobalLimitByHostId(ctx, int64(req.HostId))
  320. if err != nil {
  321. return err
  322. }
  323. outputTimeStr, err := s.ConversionTime(ctx, require.ExpiredAt)
  324. if err != nil {
  325. return err
  326. }
  327. err = s.cdnService.RenewPlan(ctx, v1.RenewalPlan{
  328. UserPlanId: int64(data.RuleId),
  329. DayTo: outputTimeStr,
  330. Period: "monthly",
  331. CountPeriod: 1,
  332. IsFree: true,
  333. PeriodDayTo: outputTimeStr,
  334. })
  335. if err != nil {
  336. return err
  337. }
  338. expiredAt, err := s.ConversionTimeUnix(ctx, require.ExpiredAt)
  339. if err != nil {
  340. return err
  341. }
  342. if err := s.globalLimitRepository.UpdateGlobalLimitByHostId(ctx, &model.GlobalLimit{
  343. HostId: req.HostId,
  344. Comment: req.Comment,
  345. ExpiredAt: expiredAt,
  346. }); err != nil {
  347. return err
  348. }
  349. return nil
  350. }
  351. func (s *globalLimitService) DeleteGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error {
  352. if err := s.globalLimitRepository.DeleteGlobalLimitByHostId(ctx, int64(req.HostId)); err != nil {
  353. return err
  354. }
  355. return nil
  356. }