globallimit.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538
  1. package service
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
  7. "github.com/go-nunu/nunu-layout-advanced/internal/model"
  8. "github.com/go-nunu/nunu-layout-advanced/internal/repository"
  9. "github.com/mozillazg/go-pinyin"
  10. "github.com/spf13/viper"
  11. "go.uber.org/zap"
  12. "golang.org/x/sync/errgroup"
  13. "gorm.io/gorm"
  14. "strconv"
  15. "strings"
  16. "time"
  17. )
  18. type GlobalLimitService interface {
  19. GetGlobalLimit(ctx context.Context, id int64) (*model.GlobalLimit, error)
  20. AddGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error
  21. EditGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error
  22. DeleteGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error
  23. }
  24. func NewGlobalLimitService(
  25. service *Service,
  26. globalLimitRepository repository.GlobalLimitRepository,
  27. duedate DuedateService,
  28. crawler CrawlerService,
  29. conf *viper.Viper,
  30. required RequiredService,
  31. parser ParserService,
  32. host HostService,
  33. gateWayGroup GatewayGroupService,
  34. hostRep repository.HostRepository,
  35. gateWayGroupRep repository.GatewayGroupRepository,
  36. cdnService CdnService,
  37. cdnRep repository.CdnRepository,
  38. tcpforwardingRep repository.TcpforwardingRepository,
  39. udpForWardingRep repository.UdpForWardingRepository,
  40. webForWardingRep repository.WebForwardingRepository,
  41. allowAndDeny AllowAndDenyIpService,
  42. allowAndDenyRep repository.AllowAndDenyIpRepository,
  43. tcpforwarding TcpforwardingService,
  44. udpForWarding UdpForWardingService,
  45. webForWarding WebForwardingService,
  46. ) GlobalLimitService {
  47. return &globalLimitService{
  48. Service: service,
  49. globalLimitRepository: globalLimitRepository,
  50. duedate: duedate,
  51. crawler: crawler,
  52. Url: conf.GetString("crawler.Url"),
  53. required: required,
  54. parser: parser,
  55. host: host,
  56. gateWayGroup: gateWayGroup,
  57. hostRep: hostRep,
  58. gateWayGroupRep: gateWayGroupRep,
  59. cdnService: cdnService,
  60. cdnRep: cdnRep,
  61. tcpforwardingRep: tcpforwardingRep,
  62. udpForWardingRep: udpForWardingRep,
  63. webForWardingRep: webForWardingRep,
  64. allowAndDeny: allowAndDeny,
  65. allowAndDenyRep: allowAndDenyRep,
  66. tcpforwarding: tcpforwarding,
  67. udpForWarding: udpForWarding,
  68. webForWarding: webForWarding,
  69. }
  70. }
  71. type globalLimitService struct {
  72. *Service
  73. globalLimitRepository repository.GlobalLimitRepository
  74. duedate DuedateService
  75. crawler CrawlerService
  76. Url string
  77. required RequiredService
  78. parser ParserService
  79. host HostService
  80. gateWayGroup GatewayGroupService
  81. hostRep repository.HostRepository
  82. gateWayGroupRep repository.GatewayGroupRepository
  83. cdnService CdnService
  84. cdnRep repository.CdnRepository
  85. tcpforwardingRep repository.TcpforwardingRepository
  86. udpForWardingRep repository.UdpForWardingRepository
  87. webForWardingRep repository.WebForwardingRepository
  88. allowAndDeny AllowAndDenyIpService
  89. allowAndDenyRep repository.AllowAndDenyIpRepository
  90. tcpforwarding TcpforwardingService
  91. udpForWarding UdpForWardingService
  92. webForWarding WebForwardingService
  93. }
  94. func (s *globalLimitService) GetCdnUserId(ctx context.Context, uid int64) (int64, error) {
  95. data, err := s.globalLimitRepository.GetGlobalLimitFirst(ctx, uid)
  96. if err != nil {
  97. if !errors.Is(err, gorm.ErrRecordNotFound) {
  98. return 0, err
  99. }
  100. }
  101. if data != nil && data.CdnUid != 0 {
  102. return int64(data.CdnUid), nil
  103. }
  104. userInfo,err := s.globalLimitRepository.GetUserInfo(ctx, uid)
  105. if err != nil {
  106. return 0, err
  107. }
  108. // 中文转拼音
  109. a := pinyin.NewArgs()
  110. a.Style = pinyin.Normal
  111. pinyinSlice := pinyin.LazyPinyin(userInfo.Username, a)
  112. userName := strconv.Itoa(int(uid)) + "_" + strings.Join(pinyinSlice, "_")
  113. // 查询用户是否存在
  114. UserId,err := s.cdnRep.GetUserId(ctx, userName)
  115. if err != nil {
  116. return 0, err
  117. }
  118. if UserId != 0 {
  119. return UserId, nil
  120. }
  121. // 注册用户
  122. userId, err := s.cdnService.AddUser(ctx, v1.User{
  123. Username: userName,
  124. Email: userInfo.Email,
  125. Fullname: userInfo.Username,
  126. Mobile: userInfo.PhoneNumber,
  127. })
  128. if err != nil {
  129. return 0, err
  130. }
  131. return userId, nil
  132. }
  133. func (s *globalLimitService) AddGroupId(ctx context.Context,groupName string) (int64, error) {
  134. groupId, err := s.cdnService.CreateGroup(ctx, v1.Group{
  135. Name: groupName,
  136. })
  137. if err != nil {
  138. return 0, err
  139. }
  140. return groupId, nil
  141. }
  142. func (s *globalLimitService) GlobalLimitRequire(ctx context.Context, req v1.GlobalLimitRequest) (res v1.GlobalLimitRequireResponse, err error) {
  143. res.ExpiredAt, err = s.duedate.NextDueDate(ctx, req.Uid, req.HostId)
  144. if err != nil {
  145. return v1.GlobalLimitRequireResponse{}, err
  146. }
  147. configCount, err := s.host.GetGlobalLimitConfig(ctx, req.HostId)
  148. if err != nil {
  149. return v1.GlobalLimitRequireResponse{}, fmt.Errorf("获取配置限制失败: %w", err)
  150. }
  151. bpsInt, err := strconv.Atoi(configCount.Bps)
  152. if err != nil {
  153. return v1.GlobalLimitRequireResponse{}, err
  154. }
  155. resultFloat := float64(bpsInt) / 2.0 / 8.0
  156. res.Bps = strconv.FormatFloat( resultFloat, 'f', -1, 64) + "M"
  157. res.MaxBytesMonth = configCount.MaxBytesMonth
  158. res.Operator = configCount.Operator
  159. res.IpCount = configCount.IpCount
  160. res.NodeArea = configCount.NodeArea
  161. res.ConfigMaxProtection = configCount.ConfigMaxProtection
  162. domain, err := s.hostRep.GetDomainById(ctx, req.HostId)
  163. if err != nil {
  164. return v1.GlobalLimitRequireResponse{}, err
  165. }
  166. userInfo,err := s.globalLimitRepository.GetUserInfo(ctx, int64(req.Uid))
  167. if err != nil {
  168. return v1.GlobalLimitRequireResponse{}, err
  169. }
  170. res.GlobalLimitName = strconv.Itoa(req.Uid) + "_" + userInfo.Username + "_" + strconv.Itoa(req.HostId) + "_" + domain
  171. res.HostName, err = s.globalLimitRepository.GetHostName(ctx, int64(req.HostId))
  172. if err != nil {
  173. return v1.GlobalLimitRequireResponse{}, err
  174. }
  175. return res, nil
  176. }
  177. func (s *globalLimitService) GetGlobalLimit(ctx context.Context, id int64) (*model.GlobalLimit, error) {
  178. return s.globalLimitRepository.GetGlobalLimit(ctx, id)
  179. }
  180. func (s *globalLimitService) ConversionTime(ctx context.Context,req string) (string, error) {
  181. // 2. 将字符串解析成 time.Time 对象
  182. // time.Parse 会根据你提供的布局来理解输入的字符串
  183. t, err := time.Parse("2006-01-02 15:04:05", req)
  184. if err != nil {
  185. // 如果输入的字符串格式和布局不匹配,这里会报错
  186. return "", fmt.Errorf("输入的字符串格式和布局不匹配 %w", err)
  187. }
  188. // 3. 定义新的输出格式 "YYYY-MM-DD"
  189. outputLayout := "2006-01-02"
  190. // 4. 将 time.Time 对象格式化为新的字符串
  191. outputTimeStr := t.Format(outputLayout)
  192. return outputTimeStr, nil
  193. }
  194. func (s *globalLimitService) ConversionTimeUnix(ctx context.Context,req string) (int64, error) {
  195. t, err := time.Parse("2006-01-02 15:04:05", req)
  196. if err != nil {
  197. return 0, fmt.Errorf("输入的字符串格式和布局不匹配 %w", err)
  198. }
  199. expiredAt := t.Unix()
  200. return expiredAt, nil
  201. }
  202. func (s *globalLimitService) AddGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error {
  203. isExist, err := s.globalLimitRepository.IsGlobalLimitExistByHostId(ctx, int64(req.HostId))
  204. if err != nil {
  205. return err
  206. }
  207. if isExist {
  208. return fmt.Errorf("配置限制已存在")
  209. }
  210. require, err := s.GlobalLimitRequire(ctx, req)
  211. if err != nil {
  212. return err
  213. }
  214. g, gCtx := errgroup.WithContext(ctx)
  215. var gatewayGroupId int
  216. var userId int64
  217. var groupId int64
  218. g.Go(func() error {
  219. res, e := s.gateWayGroupRep.GetGatewayGroupWhereHostIdNull(gCtx, require.Operator, require.IpCount)
  220. if e != nil {
  221. return fmt.Errorf("获取网关组失败: %w", e)
  222. }
  223. if res == 0 {
  224. return fmt.Errorf("获取网关组失败")
  225. }
  226. gatewayGroupId = res
  227. return nil
  228. })
  229. g.Go(func() error {
  230. res, e := s.GetCdnUserId(gCtx, int64(req.Uid))
  231. if e != nil {
  232. return fmt.Errorf("获取cdn用户失败: %w", e)
  233. }
  234. if res == 0 {
  235. return fmt.Errorf("获取cdn用户失败")
  236. }
  237. userId = res
  238. return nil
  239. })
  240. g.Go(func() error {
  241. res, e := s.AddGroupId(gCtx, require.GlobalLimitName)
  242. if e != nil {
  243. return fmt.Errorf("创建规则分组失败: %w", e)
  244. }
  245. if res == 0 {
  246. return fmt.Errorf("创建规则分组失败")
  247. }
  248. groupId = res
  249. return nil
  250. })
  251. if err = g.Wait(); err != nil {
  252. return err
  253. }
  254. outputTimeStr, err := s.ConversionTime(ctx, require.ExpiredAt)
  255. if err != nil {
  256. return err
  257. }
  258. // 获取套餐ID
  259. maxProtection := strings.TrimSuffix(require.ConfigMaxProtection, "G")
  260. if maxProtection == "" {
  261. return fmt.Errorf("无效的配置 ConfigMaxProtection: '%s',数字部分为空", require.ConfigMaxProtection)
  262. }
  263. maxProtectionInt, err := strconv.Atoi(maxProtection)
  264. if err != nil {
  265. return fmt.Errorf("无效的配置 ConfigMaxProtection: '%s',无法转换为数字", require.ConfigMaxProtection)
  266. }
  267. var planId int64
  268. maxProtectionNum := 1
  269. if maxProtectionInt >= 2000 {
  270. maxProtectionNum = maxProtectionInt / 1000
  271. }
  272. NodeAreaName := fmt.Sprintf("%s-%dT",require.NodeArea, maxProtectionNum)
  273. planId, err = s.globalLimitRepository.GetNodeArea(ctx, NodeAreaName)
  274. if err != nil {
  275. if errors.Is(err, gorm.ErrRecordNotFound) {
  276. planId = 0
  277. }else {
  278. return err
  279. }
  280. }
  281. if planId == 0 {
  282. // 安全冗余套餐
  283. planId = 6
  284. s.logger.Warn("获取套餐Id失败", zap.String("节点区域", NodeAreaName), zap.String("防御阈值", require.ConfigMaxProtection),zap.Int64("套餐Id", int64(req.Uid)),zap.Int64("魔方套餐Id", int64(req.HostId)))
  285. }
  286. ruleId, err := s.cdnService.BindPlan(ctx, v1.Plan{
  287. UserId: userId,
  288. PlanId: planId,
  289. DayTo: outputTimeStr,
  290. Name: require.GlobalLimitName,
  291. IsFree: true,
  292. Period: "monthly",
  293. CountPeriod: 1,
  294. PeriodDayTo: outputTimeStr,
  295. })
  296. if err != nil {
  297. return err
  298. }
  299. if ruleId == 0 {
  300. return fmt.Errorf("分配套餐失败")
  301. }
  302. err = s.gateWayGroupRep.EditGatewayGroup(ctx, &model.GatewayGroup{
  303. Id: gatewayGroupId,
  304. HostId: req.HostId,
  305. })
  306. if err != nil {
  307. return err
  308. }
  309. expiredAt, err := s.ConversionTimeUnix(ctx, require.ExpiredAt)
  310. if err != nil {
  311. return err
  312. }
  313. // 如果存在实例,恢复
  314. oldData,err := s.globalLimitRepository.GetGlobalLimitByHostId(ctx, int64(req.HostId))
  315. if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
  316. return err
  317. }
  318. if oldData!= nil && oldData.Id != 0 {
  319. err = s.globalLimitRepository.UpdateGlobalLimitByHostId(ctx, &model.GlobalLimit{
  320. HostId: req.HostId,
  321. Uid: req.Uid,
  322. Name: require.GlobalLimitName,
  323. RuleId: int(ruleId),
  324. GroupId: int(groupId),
  325. GatewayGroupId: gatewayGroupId,
  326. CdnUid: int(userId),
  327. Comment: req.Comment,
  328. ExpiredAt: expiredAt,
  329. State: 1,
  330. })
  331. if err != nil {
  332. return err
  333. }
  334. return nil
  335. }
  336. // 如果不存在实例,创建
  337. err = s.globalLimitRepository.AddGlobalLimit(ctx, &model.GlobalLimit{
  338. HostId: req.HostId,
  339. Uid: req.Uid,
  340. Name: require.GlobalLimitName,
  341. RuleId: int(ruleId),
  342. GroupId: int(groupId),
  343. GatewayGroupId: gatewayGroupId,
  344. CdnUid: int(userId),
  345. Comment: req.Comment,
  346. State: 1,
  347. ExpiredAt: expiredAt,
  348. })
  349. if err != nil {
  350. return err
  351. }
  352. return nil
  353. }
  354. func (s *globalLimitService) EditGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error {
  355. require, err := s.GlobalLimitRequire(ctx, req)
  356. if err != nil {
  357. return err
  358. }
  359. data, err := s.globalLimitRepository.GetGlobalLimitByHostId(ctx, int64(req.HostId))
  360. if err != nil {
  361. return err
  362. }
  363. outputTimeStr, err := s.ConversionTime(ctx, require.ExpiredAt)
  364. if err != nil {
  365. return err
  366. }
  367. err = s.cdnService.RenewPlan(ctx, v1.RenewalPlan{
  368. UserPlanId: int64(data.RuleId),
  369. DayTo: outputTimeStr,
  370. Period: "monthly",
  371. CountPeriod: 1,
  372. IsFree: true,
  373. PeriodDayTo: outputTimeStr,
  374. })
  375. if err != nil {
  376. return err
  377. }
  378. expiredAt, err := s.ConversionTimeUnix(ctx, require.ExpiredAt)
  379. if err != nil {
  380. return err
  381. }
  382. if err := s.globalLimitRepository.UpdateGlobalLimitByHostId(ctx, &model.GlobalLimit{
  383. HostId: req.HostId,
  384. Comment: req.Comment,
  385. ExpiredAt: expiredAt,
  386. }); err != nil {
  387. return err
  388. }
  389. return nil
  390. }
  391. func (s *globalLimitService) DeleteGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error {
  392. // 检查是否过期
  393. isExpired, err := s.host.CheckExpired(ctx, int64(req.Uid), int64(req.HostId))
  394. if err != nil {
  395. return err
  396. }
  397. if isExpired {
  398. return fmt.Errorf("实例未过期,无法删除")
  399. }
  400. oldData, err := s.globalLimitRepository.GetGlobalLimitByHostId(ctx, int64(req.HostId))
  401. if err != nil {
  402. if errors.Is(err, gorm.ErrRecordNotFound) {
  403. return fmt.Errorf("实例不存在")
  404. }
  405. return err
  406. }
  407. tcpIds, err := s.tcpforwardingRep.GetTcpForwardingAllIdsByID(ctx,req.HostId)
  408. if err != nil {
  409. return err
  410. }
  411. udpIds, err := s.udpForWardingRep.GetUdpForwardingWafUdpAllIds(ctx,req.HostId)
  412. if err != nil {
  413. return err
  414. }
  415. webIds, err := s.webForWardingRep.GetWebForwardingWafWebAllIds(ctx,req.HostId)
  416. if err != nil {
  417. return err
  418. }
  419. BwIds, err := s.allowAndDenyRep.GetIpCountListId(ctx, int64(req.HostId))
  420. if err != nil {
  421. return err
  422. }
  423. // 删除网站
  424. g, gCtx := errgroup.WithContext(ctx)
  425. g.Go(func() error {
  426. e := s.tcpforwarding.DeleteTcpForwarding(ctx,v1.DeleteTcpForwardingRequest{
  427. Ids: tcpIds,
  428. Uid: req.Uid,
  429. HostId: req.HostId,
  430. })
  431. if e != nil {
  432. return fmt.Errorf("删除TCP转发失败: %w", e)
  433. }
  434. return nil
  435. })
  436. g.Go(func() error {
  437. e := s.udpForWarding.DeleteUdpForwarding(ctx,udpIds)
  438. if e != nil {
  439. return fmt.Errorf("删除UDP转发失败: %w", e)
  440. }
  441. return nil
  442. })
  443. g.Go(func() error {
  444. e := s.webForWarding.DeleteWebForwarding(ctx,webIds)
  445. if e != nil {
  446. return fmt.Errorf("删除WEB转发失败: %w", e)
  447. }
  448. return nil
  449. })
  450. // 删除套餐
  451. g.Go(func() error {
  452. e := s.cdnService.DelUserPlan(gCtx, int64(oldData.RuleId))
  453. if e != nil {
  454. return fmt.Errorf("删除套餐失败: %w", e)
  455. }
  456. return nil
  457. })
  458. // 删除网站分组
  459. g.Go(func() error {
  460. e := s.cdnService.DelServerGroup(gCtx, int64(oldData.GroupId))
  461. if e != nil {
  462. return fmt.Errorf("删除网站分组失败: %w", e)
  463. }
  464. return nil
  465. })
  466. if err = g.Wait(); err != nil {
  467. return err
  468. }
  469. if err := s.globalLimitRepository.EditHostState(ctx, int64(req.HostId), 0); err != nil {
  470. return err
  471. }
  472. if err := s.gateWayGroupRep.EditGatewayGroup(ctx,&model.GatewayGroup{
  473. Id: oldData.GatewayGroupId,
  474. HostId: 0,
  475. }); err != nil {
  476. return err
  477. }
  478. // 删除黑白名单
  479. err = s.allowAndDeny.DeleteAllowAndDenyIps(ctx, v1.DelAllowAndDenyIpRequest{
  480. HostId: req.HostId,
  481. Ids: BwIds,
  482. Uid: req.Uid,
  483. })
  484. if err != nil {
  485. return err
  486. }
  487. return nil
  488. }