globallimit.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539
  1. package service
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
  7. "github.com/go-nunu/nunu-layout-advanced/internal/model"
  8. "github.com/go-nunu/nunu-layout-advanced/internal/repository"
  9. "github.com/mozillazg/go-pinyin"
  10. "github.com/spf13/viper"
  11. "go.uber.org/zap"
  12. "golang.org/x/sync/errgroup"
  13. "gorm.io/gorm"
  14. "strconv"
  15. "strings"
  16. "time"
  17. )
// GlobalLimitService is the service-layer API for managing the global-limit
// instance attached to a host.
type GlobalLimitService interface {
	// GetGlobalLimit returns the global-limit record with the given id.
	GetGlobalLimit(ctx context.Context, id int64) (*model.GlobalLimit, error)
	// AddGlobalLimit provisions a global-limit instance for req.HostId. It
	// fails when an instance already exists for that host.
	AddGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error
	// EditGlobalLimit renews the CDN plan of the instance bound to req.HostId
	// and updates its comment and expiry time.
	EditGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error
	// DeleteGlobalLimit tears down the instance bound to req.HostId: its
	// forwarding rules, CDN plan, server group, gateway-group binding and
	// allow/deny IP entries.
	DeleteGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error
}
  24. func NewGlobalLimitService(
  25. service *Service,
  26. globalLimitRepository repository.GlobalLimitRepository,
  27. duedate DuedateService,
  28. crawler CrawlerService,
  29. conf *viper.Viper,
  30. required RequiredService,
  31. parser ParserService,
  32. host HostService,
  33. gateWayGroup GatewayGroupService,
  34. hostRep repository.HostRepository,
  35. gateWayGroupRep repository.GatewayGroupRepository,
  36. cdnService CdnService,
  37. cdnRep repository.CdnRepository,
  38. tcpforwardingRep repository.TcpforwardingRepository,
  39. udpForWardingRep repository.UdpForWardingRepository,
  40. webForWardingRep repository.WebForwardingRepository,
  41. allowAndDeny AllowAndDenyIpService,
  42. allowAndDenyRep repository.AllowAndDenyIpRepository,
  43. tcpforwarding TcpforwardingService,
  44. udpForWarding UdpForWardingService,
  45. webForWarding WebForwardingService,
  46. ) GlobalLimitService {
  47. return &globalLimitService{
  48. Service: service,
  49. globalLimitRepository: globalLimitRepository,
  50. duedate: duedate,
  51. crawler: crawler,
  52. Url: conf.GetString("crawler.Url"),
  53. required: required,
  54. parser: parser,
  55. host: host,
  56. gateWayGroup: gateWayGroup,
  57. hostRep: hostRep,
  58. gateWayGroupRep: gateWayGroupRep,
  59. cdnService: cdnService,
  60. cdnRep: cdnRep,
  61. tcpforwardingRep: tcpforwardingRep,
  62. udpForWardingRep: udpForWardingRep,
  63. webForWardingRep: webForWardingRep,
  64. allowAndDeny: allowAndDeny,
  65. allowAndDenyRep: allowAndDenyRep,
  66. tcpforwarding: tcpforwarding,
  67. udpForWarding: udpForWarding,
  68. webForWarding: webForWarding,
  69. }
  70. }
// globalLimitService is the concrete GlobalLimitService. It coordinates the
// local repositories with the CDN service to create, renew and remove
// global-limit instances.
type globalLimitService struct {
	// Embedded base service; its logger is used for warnings in AddGlobalLimit.
	*Service
	// Persistence for global-limit records plus the user-info, host-name and
	// node-area lookups used in this file.
	globalLimitRepository repository.GlobalLimitRepository
	// Supplies the next due date of a host (NextDueDate).
	duedate DuedateService
	// Not referenced by the methods visible in this part of the file.
	crawler CrawlerService
	// Value of the "crawler.Url" configuration key; not referenced by the
	// methods visible in this part of the file.
	Url string
	// Not referenced by the methods visible in this part of the file.
	required RequiredService
	// Not referenced by the methods visible in this part of the file.
	parser ParserService
	// Supplies the per-host limit configuration (GetGlobalLimitConfig) and the
	// expiry check (CheckExpired).
	host HostService
	// Not referenced by the methods visible in this part of the file.
	gateWayGroup GatewayGroupService
	// Resolves a host's domain (GetDomainById).
	hostRep repository.HostRepository
	// Finds an unassigned gateway group and binds/unbinds it to a host.
	gateWayGroupRep repository.GatewayGroupRepository
	// CDN-side operations: users, groups, plan binding/renewal/removal.
	cdnService CdnService
	// Looks up an existing CDN user ID by user name (GetUserId).
	cdnRep repository.CdnRepository
	// Lists the TCP forwarding IDs of a host.
	tcpforwardingRep repository.TcpforwardingRepository
	// Lists the UDP forwarding IDs of a host.
	udpForWardingRep repository.UdpForWardingRepository
	// Lists the web forwarding IDs of a host.
	webForWardingRep repository.WebForwardingRepository
	// Deletes allow/deny IP entries (DeleteAllowAndDenyIps).
	allowAndDeny AllowAndDenyIpService
	// Lists the allow/deny IP entry IDs of a host (GetIpCountListId).
	allowAndDenyRep repository.AllowAndDenyIpRepository
	// Deletes TCP forwarding rules.
	tcpforwarding TcpforwardingService
	// Deletes UDP forwarding rules.
	udpForWarding UdpForWardingService
	// Deletes web forwarding rules.
	webForWarding WebForwardingService
}
  94. func (s *globalLimitService) GetCdnUserId(ctx context.Context, uid int64) (int64, error) {
  95. data, err := s.globalLimitRepository.GetGlobalLimitFirst(ctx, uid)
  96. if err != nil {
  97. if !errors.Is(err, gorm.ErrRecordNotFound) {
  98. return 0, err
  99. }
  100. }
  101. if data != nil && data.CdnUid != 0 {
  102. return int64(data.CdnUid), nil
  103. }
  104. userInfo,err := s.globalLimitRepository.GetUserInfo(ctx, uid)
  105. if err != nil {
  106. return 0, err
  107. }
  108. // 中文转拼音
  109. a := pinyin.NewArgs()
  110. a.Style = pinyin.Normal
  111. pinyinSlice := pinyin.LazyPinyin(userInfo.Username, a)
  112. userName := strconv.Itoa(int(uid)) + "_" + strings.Join(pinyinSlice, "_")
  113. // 查询用户是否存在
  114. UserId,err := s.cdnRep.GetUserId(ctx, userName)
  115. if err != nil {
  116. return 0, err
  117. }
  118. if UserId != 0 {
  119. return UserId, nil
  120. }
  121. // 注册用户
  122. userId, err := s.cdnService.AddUser(ctx, v1.User{
  123. Username: userName,
  124. Email: userInfo.Email,
  125. Fullname: userInfo.Username,
  126. Mobile: userInfo.PhoneNumber,
  127. })
  128. if err != nil {
  129. return 0, err
  130. }
  131. return userId, nil
  132. }
  133. func (s *globalLimitService) AddGroupId(ctx context.Context,groupName string) (int64, error) {
  134. groupId, err := s.cdnService.CreateGroup(ctx, v1.Group{
  135. Name: groupName,
  136. })
  137. if err != nil {
  138. return 0, err
  139. }
  140. return groupId, nil
  141. }
  142. func (s *globalLimitService) GlobalLimitRequire(ctx context.Context, req v1.GlobalLimitRequest) (res v1.GlobalLimitRequireResponse, err error) {
  143. res.ExpiredAt, err = s.duedate.NextDueDate(ctx, req.Uid, req.HostId)
  144. if err != nil {
  145. return v1.GlobalLimitRequireResponse{}, err
  146. }
  147. configCount, err := s.host.GetGlobalLimitConfig(ctx, req.HostId)
  148. if err != nil {
  149. return v1.GlobalLimitRequireResponse{}, fmt.Errorf("获取配置限制失败: %w", err)
  150. }
  151. bpsInt, err := strconv.Atoi(configCount.Bps)
  152. if err != nil {
  153. return v1.GlobalLimitRequireResponse{}, err
  154. }
  155. resultFloat := float64(bpsInt) / 2.0 / 8.0
  156. res.Bps = strconv.FormatFloat( resultFloat, 'f', -1, 64) + "M"
  157. res.MaxBytesMonth = configCount.MaxBytesMonth
  158. res.Operator = configCount.Operator
  159. res.IpCount = configCount.IpCount
  160. res.NodeArea = configCount.NodeArea
  161. res.ConfigMaxProtection = configCount.ConfigMaxProtection
  162. res.IsBanUdp = configCount.IsBanUdp
  163. domain, err := s.hostRep.GetDomainById(ctx, req.HostId)
  164. if err != nil {
  165. return v1.GlobalLimitRequireResponse{}, err
  166. }
  167. userInfo,err := s.globalLimitRepository.GetUserInfo(ctx, int64(req.Uid))
  168. if err != nil {
  169. return v1.GlobalLimitRequireResponse{}, err
  170. }
  171. res.GlobalLimitName = strconv.Itoa(req.Uid) + "_" + userInfo.Username + "_" + strconv.Itoa(req.HostId) + "_" + domain
  172. res.HostName, err = s.globalLimitRepository.GetHostName(ctx, int64(req.HostId))
  173. if err != nil {
  174. return v1.GlobalLimitRequireResponse{}, err
  175. }
  176. return res, nil
  177. }
  178. func (s *globalLimitService) GetGlobalLimit(ctx context.Context, id int64) (*model.GlobalLimit, error) {
  179. return s.globalLimitRepository.GetGlobalLimit(ctx, id)
  180. }
  181. func (s *globalLimitService) ConversionTime(ctx context.Context,req string) (string, error) {
  182. // 2. 将字符串解析成 time.Time 对象
  183. // time.Parse 会根据你提供的布局来理解输入的字符串
  184. t, err := time.Parse("2006-01-02 15:04:05", req)
  185. if err != nil {
  186. // 如果输入的字符串格式和布局不匹配,这里会报错
  187. return "", fmt.Errorf("输入的字符串格式和布局不匹配 %w", err)
  188. }
  189. // 3. 定义新的输出格式 "YYYY-MM-DD"
  190. outputLayout := "2006-01-02"
  191. // 4. 将 time.Time 对象格式化为新的字符串
  192. outputTimeStr := t.Format(outputLayout)
  193. return outputTimeStr, nil
  194. }
  195. func (s *globalLimitService) ConversionTimeUnix(ctx context.Context,req string) (int64, error) {
  196. t, err := time.Parse("2006-01-02 15:04:05", req)
  197. if err != nil {
  198. return 0, fmt.Errorf("输入的字符串格式和布局不匹配 %w", err)
  199. }
  200. expiredAt := t.Unix()
  201. return expiredAt, nil
  202. }
  203. func (s *globalLimitService) AddGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error {
  204. isExist, err := s.globalLimitRepository.IsGlobalLimitExistByHostId(ctx, int64(req.HostId))
  205. if err != nil {
  206. return err
  207. }
  208. if isExist {
  209. return fmt.Errorf("配置限制已存在")
  210. }
  211. require, err := s.GlobalLimitRequire(ctx, req)
  212. if err != nil {
  213. return err
  214. }
  215. g, gCtx := errgroup.WithContext(ctx)
  216. var gatewayGroupId int
  217. var userId int64
  218. var groupId int64
  219. g.Go(func() error {
  220. res, e := s.gateWayGroupRep.GetGatewayGroupWhereHostIdNull(gCtx, require)
  221. if e != nil {
  222. return fmt.Errorf("获取网关组失败: %w", e)
  223. }
  224. if res == 0 {
  225. return fmt.Errorf("获取网关组失败")
  226. }
  227. gatewayGroupId = res
  228. return nil
  229. })
  230. g.Go(func() error {
  231. res, e := s.GetCdnUserId(gCtx, int64(req.Uid))
  232. if e != nil {
  233. return fmt.Errorf("获取cdn用户失败: %w", e)
  234. }
  235. if res == 0 {
  236. return fmt.Errorf("获取cdn用户失败")
  237. }
  238. userId = res
  239. return nil
  240. })
  241. g.Go(func() error {
  242. res, e := s.AddGroupId(gCtx, require.GlobalLimitName)
  243. if e != nil {
  244. return fmt.Errorf("创建规则分组失败: %w", e)
  245. }
  246. if res == 0 {
  247. return fmt.Errorf("创建规则分组失败")
  248. }
  249. groupId = res
  250. return nil
  251. })
  252. if err = g.Wait(); err != nil {
  253. return err
  254. }
  255. outputTimeStr, err := s.ConversionTime(ctx, require.ExpiredAt)
  256. if err != nil {
  257. return err
  258. }
  259. // 获取套餐ID
  260. maxProtection := strings.TrimSuffix(require.ConfigMaxProtection, "G")
  261. if maxProtection == "" {
  262. return fmt.Errorf("无效的配置 ConfigMaxProtection: '%s',数字部分为空", require.ConfigMaxProtection)
  263. }
  264. maxProtectionInt, err := strconv.Atoi(maxProtection)
  265. if err != nil {
  266. return fmt.Errorf("无效的配置 ConfigMaxProtection: '%s',无法转换为数字", require.ConfigMaxProtection)
  267. }
  268. var planId int64
  269. maxProtectionNum := 1
  270. if maxProtectionInt >= 2000 {
  271. maxProtectionNum = maxProtectionInt / 1000
  272. }
  273. NodeAreaName := fmt.Sprintf("%s-%dT",require.NodeArea, maxProtectionNum)
  274. planId, err = s.globalLimitRepository.GetNodeArea(ctx, NodeAreaName)
  275. if err != nil {
  276. if errors.Is(err, gorm.ErrRecordNotFound) {
  277. planId = 0
  278. }else {
  279. return err
  280. }
  281. }
  282. if planId == 0 {
  283. // 安全冗余套餐
  284. planId = 6
  285. s.logger.Warn("获取套餐Id失败", zap.String("节点区域", NodeAreaName), zap.String("防御阈值", require.ConfigMaxProtection),zap.Int64("套餐Id", int64(req.Uid)),zap.Int64("魔方套餐Id", int64(req.HostId)))
  286. }
  287. ruleId, err := s.cdnService.BindPlan(ctx, v1.Plan{
  288. UserId: userId,
  289. PlanId: planId,
  290. DayTo: outputTimeStr,
  291. Name: require.GlobalLimitName,
  292. IsFree: true,
  293. Period: "monthly",
  294. CountPeriod: 1,
  295. PeriodDayTo: outputTimeStr,
  296. })
  297. if err != nil {
  298. return err
  299. }
  300. if ruleId == 0 {
  301. return fmt.Errorf("分配套餐失败")
  302. }
  303. err = s.gateWayGroupRep.EditGatewayGroup(ctx, &model.GatewayGroup{
  304. Id: gatewayGroupId,
  305. HostId: req.HostId,
  306. })
  307. if err != nil {
  308. return err
  309. }
  310. expiredAt, err := s.ConversionTimeUnix(ctx, require.ExpiredAt)
  311. if err != nil {
  312. return err
  313. }
  314. // 如果存在实例,恢复
  315. oldData,err := s.globalLimitRepository.GetGlobalLimitByHostId(ctx, int64(req.HostId))
  316. if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
  317. return err
  318. }
  319. if oldData!= nil && oldData.Id != 0 {
  320. err = s.globalLimitRepository.UpdateGlobalLimitByHostId(ctx, &model.GlobalLimit{
  321. HostId: req.HostId,
  322. Uid: req.Uid,
  323. Name: require.GlobalLimitName,
  324. RuleId: int(ruleId),
  325. GroupId: int(groupId),
  326. GatewayGroupId: gatewayGroupId,
  327. CdnUid: int(userId),
  328. Comment: req.Comment,
  329. ExpiredAt: expiredAt,
  330. State: true,
  331. })
  332. if err != nil {
  333. return err
  334. }
  335. return nil
  336. }
  337. // 如果不存在实例,创建
  338. err = s.globalLimitRepository.AddGlobalLimit(ctx, &model.GlobalLimit{
  339. HostId: req.HostId,
  340. Uid: req.Uid,
  341. Name: require.GlobalLimitName,
  342. RuleId: int(ruleId),
  343. GroupId: int(groupId),
  344. GatewayGroupId: gatewayGroupId,
  345. CdnUid: int(userId),
  346. Comment: req.Comment,
  347. State: true,
  348. ExpiredAt: expiredAt,
  349. })
  350. if err != nil {
  351. return err
  352. }
  353. return nil
  354. }
  355. func (s *globalLimitService) EditGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error {
  356. require, err := s.GlobalLimitRequire(ctx, req)
  357. if err != nil {
  358. return err
  359. }
  360. data, err := s.globalLimitRepository.GetGlobalLimitByHostId(ctx, int64(req.HostId))
  361. if err != nil {
  362. return err
  363. }
  364. outputTimeStr, err := s.ConversionTime(ctx, require.ExpiredAt)
  365. if err != nil {
  366. return err
  367. }
  368. err = s.cdnService.RenewPlan(ctx, v1.RenewalPlan{
  369. UserPlanId: int64(data.RuleId),
  370. DayTo: outputTimeStr,
  371. Period: "monthly",
  372. CountPeriod: 1,
  373. IsFree: true,
  374. PeriodDayTo: outputTimeStr,
  375. })
  376. if err != nil {
  377. return err
  378. }
  379. expiredAt, err := s.ConversionTimeUnix(ctx, require.ExpiredAt)
  380. if err != nil {
  381. return err
  382. }
  383. if err := s.globalLimitRepository.UpdateGlobalLimitByHostId(ctx, &model.GlobalLimit{
  384. HostId: req.HostId,
  385. Comment: req.Comment,
  386. ExpiredAt: expiredAt,
  387. }); err != nil {
  388. return err
  389. }
  390. return nil
  391. }
  392. func (s *globalLimitService) DeleteGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error {
  393. // 检查是否过期
  394. isExpired, err := s.host.CheckExpired(ctx, int64(req.Uid), int64(req.HostId))
  395. if err != nil {
  396. return err
  397. }
  398. if isExpired {
  399. return fmt.Errorf("实例未过期,无法删除")
  400. }
  401. oldData, err := s.globalLimitRepository.GetGlobalLimitByHostId(ctx, int64(req.HostId))
  402. if err != nil {
  403. if errors.Is(err, gorm.ErrRecordNotFound) {
  404. return fmt.Errorf("实例不存在")
  405. }
  406. return err
  407. }
  408. tcpIds, err := s.tcpforwardingRep.GetTcpForwardingAllIdsByID(ctx,req.HostId)
  409. if err != nil {
  410. return err
  411. }
  412. udpIds, err := s.udpForWardingRep.GetUdpForwardingWafUdpAllIds(ctx,req.HostId)
  413. if err != nil {
  414. return err
  415. }
  416. webIds, err := s.webForWardingRep.GetWebForwardingWafWebAllIds(ctx,req.HostId)
  417. if err != nil {
  418. return err
  419. }
  420. BwIds, err := s.allowAndDenyRep.GetIpCountListId(ctx, int64(req.HostId))
  421. if err != nil {
  422. return err
  423. }
  424. // 删除网站
  425. g, gCtx := errgroup.WithContext(ctx)
  426. g.Go(func() error {
  427. e := s.tcpforwarding.DeleteTcpForwarding(ctx,v1.DeleteTcpForwardingRequest{
  428. Ids: tcpIds,
  429. Uid: req.Uid,
  430. HostId: req.HostId,
  431. })
  432. if e != nil {
  433. return fmt.Errorf("删除TCP转发失败: %w", e)
  434. }
  435. return nil
  436. })
  437. g.Go(func() error {
  438. e := s.udpForWarding.DeleteUdpForwarding(ctx,udpIds)
  439. if e != nil {
  440. return fmt.Errorf("删除UDP转发失败: %w", e)
  441. }
  442. return nil
  443. })
  444. g.Go(func() error {
  445. e := s.webForWarding.DeleteWebForwarding(ctx,webIds)
  446. if e != nil {
  447. return fmt.Errorf("删除WEB转发失败: %w", e)
  448. }
  449. return nil
  450. })
  451. // 删除套餐
  452. g.Go(func() error {
  453. e := s.cdnService.DelUserPlan(gCtx, int64(oldData.RuleId))
  454. if e != nil {
  455. return fmt.Errorf("删除套餐失败: %w", e)
  456. }
  457. return nil
  458. })
  459. // 删除网站分组
  460. g.Go(func() error {
  461. e := s.cdnService.DelServerGroup(gCtx, int64(oldData.GroupId))
  462. if e != nil {
  463. return fmt.Errorf("删除网站分组失败: %w", e)
  464. }
  465. return nil
  466. })
  467. if err = g.Wait(); err != nil {
  468. return err
  469. }
  470. if err := s.globalLimitRepository.EditHostState(ctx, int64(req.HostId), false); err != nil {
  471. return err
  472. }
  473. if err := s.gateWayGroupRep.EditGatewayGroup(ctx,&model.GatewayGroup{
  474. Id: oldData.GatewayGroupId,
  475. HostId: 0,
  476. }); err != nil {
  477. return err
  478. }
  479. // 删除黑白名单
  480. err = s.allowAndDeny.DeleteAllowAndDenyIps(ctx, v1.DelAllowAndDenyIpRequest{
  481. HostId: req.HostId,
  482. Ids: BwIds,
  483. Uid: req.Uid,
  484. })
  485. if err != nil {
  486. return err
  487. }
  488. return nil
  489. }