globallimit.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547
  1. package service
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
  7. "github.com/go-nunu/nunu-layout-advanced/internal/model"
  8. "github.com/go-nunu/nunu-layout-advanced/internal/repository"
  9. "github.com/mozillazg/go-pinyin"
  10. "github.com/spf13/viper"
  11. "go.uber.org/zap"
  12. "golang.org/x/sync/errgroup"
  13. "gorm.io/gorm"
  14. "strconv"
  15. "strings"
  16. "time"
  17. )
  18. type GlobalLimitService interface {
  19. GetGlobalLimit(ctx context.Context, id int64) (*model.GlobalLimit, error)
  20. AddGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error
  21. EditGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error
  22. DeleteGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error
  23. GlobalLimitRequire(ctx context.Context, req v1.GlobalLimitRequest) (res v1.GlobalLimitRequireResponse, err error)
  24. }
  25. func NewGlobalLimitService(
  26. service *Service,
  27. globalLimitRepository repository.GlobalLimitRepository,
  28. duedate DuedateService,
  29. crawler CrawlerService,
  30. conf *viper.Viper,
  31. required RequiredService,
  32. parser ParserService,
  33. host HostService,
  34. hostRep repository.HostRepository,
  35. cdnService CdnService,
  36. cdnRep repository.CdnRepository,
  37. tcpforwardingRep repository.TcpforwardingRepository,
  38. udpForWardingRep repository.UdpForWardingRepository,
  39. webForWardingRep repository.WebForwardingRepository,
  40. allowAndDeny AllowAndDenyIpService,
  41. allowAndDenyRep repository.AllowAndDenyIpRepository,
  42. tcpforwarding TcpforwardingService,
  43. udpForWarding UdpForWardingService,
  44. webForWarding WebForwardingService,
  45. gatewayIpRep repository.GatewayipRepository,
  46. ) GlobalLimitService {
  47. return &globalLimitService{
  48. Service: service,
  49. globalLimitRepository: globalLimitRepository,
  50. duedate: duedate,
  51. crawler: crawler,
  52. Url: conf.GetString("crawler.Url"),
  53. required: required,
  54. parser: parser,
  55. host: host,
  56. hostRep: hostRep,
  57. cdnService: cdnService,
  58. cdnRep: cdnRep,
  59. tcpforwardingRep: tcpforwardingRep,
  60. udpForWardingRep: udpForWardingRep,
  61. webForWardingRep: webForWardingRep,
  62. allowAndDeny: allowAndDeny,
  63. allowAndDenyRep: allowAndDenyRep,
  64. tcpforwarding: tcpforwarding,
  65. udpForWarding: udpForWarding,
  66. webForWarding: webForWarding,
  67. gatewayIpRep: gatewayIpRep,
  68. }
  69. }
  70. type globalLimitService struct {
  71. *Service
  72. globalLimitRepository repository.GlobalLimitRepository
  73. duedate DuedateService
  74. crawler CrawlerService
  75. Url string
  76. required RequiredService
  77. parser ParserService
  78. host HostService
  79. hostRep repository.HostRepository
  80. cdnService CdnService
  81. cdnRep repository.CdnRepository
  82. tcpforwardingRep repository.TcpforwardingRepository
  83. udpForWardingRep repository.UdpForWardingRepository
  84. webForWardingRep repository.WebForwardingRepository
  85. allowAndDeny AllowAndDenyIpService
  86. allowAndDenyRep repository.AllowAndDenyIpRepository
  87. tcpforwarding TcpforwardingService
  88. udpForWarding UdpForWardingService
  89. webForWarding WebForwardingService
  90. gatewayIpRep repository.GatewayipRepository
  91. }
  92. func (s *globalLimitService) GetCdnUserId(ctx context.Context, uid int64) (int64, error) {
  93. data, err := s.globalLimitRepository.GetGlobalLimitFirst(ctx, uid)
  94. if err != nil {
  95. if !errors.Is(err, gorm.ErrRecordNotFound) {
  96. return 0, err
  97. }
  98. }
  99. if data != nil && data.CdnUid != 0 {
  100. return int64(data.CdnUid), nil
  101. }
  102. userInfo,err := s.globalLimitRepository.GetUserInfo(ctx, uid)
  103. if err != nil {
  104. return 0, err
  105. }
  106. // 中文转拼音
  107. a := pinyin.NewArgs()
  108. a.Style = pinyin.Normal
  109. pinyinSlice := pinyin.LazyPinyin(userInfo.Username, a)
  110. userName := strconv.Itoa(int(uid)) + "_" + strings.Join(pinyinSlice, "_")
  111. // 查询用户是否存在
  112. UserId,err := s.cdnRep.GetUserId(ctx, userName)
  113. if err != nil {
  114. return 0, err
  115. }
  116. if UserId != 0 {
  117. return UserId, nil
  118. }
  119. // 注册用户
  120. userId, err := s.cdnService.AddUser(ctx, v1.User{
  121. Username: userName,
  122. Email: userInfo.Email,
  123. Fullname: userInfo.Username,
  124. Mobile: userInfo.PhoneNumber,
  125. })
  126. if err != nil {
  127. return 0, err
  128. }
  129. return userId, nil
  130. }
  131. func (s *globalLimitService) AddGroupId(ctx context.Context,groupName string) (int64, error) {
  132. groupId, err := s.cdnService.CreateGroup(ctx, v1.Group{
  133. Name: groupName,
  134. })
  135. if err != nil {
  136. return 0, err
  137. }
  138. return groupId, nil
  139. }
  140. func (s *globalLimitService) GlobalLimitRequire(ctx context.Context, req v1.GlobalLimitRequest) (res v1.GlobalLimitRequireResponse, err error) {
  141. res.ExpiredAt, err = s.duedate.NextDueDate(ctx, req.Uid, req.HostId)
  142. if err != nil {
  143. return v1.GlobalLimitRequireResponse{}, err
  144. }
  145. configCount, err := s.host.GetGlobalLimitConfig(ctx, req.HostId)
  146. if err != nil {
  147. return v1.GlobalLimitRequireResponse{}, fmt.Errorf("获取配置限制失败: %w", err)
  148. }
  149. bpsInt, err := strconv.Atoi(configCount.Bps)
  150. if err != nil {
  151. return v1.GlobalLimitRequireResponse{}, err
  152. }
  153. resultFloat := float64(bpsInt) / 2.0 / 8.0
  154. res.Bps = strconv.FormatFloat( resultFloat, 'f', -1, 64) + "M"
  155. res.MaxBytesMonth = configCount.MaxBytesMonth
  156. res.Operator = configCount.Operator
  157. res.IpCount = configCount.IpCount
  158. res.NodeArea = configCount.NodeArea
  159. res.ConfigMaxProtection = configCount.ConfigMaxProtection
  160. res.IsBanUdp = configCount.IsBanUdp
  161. res.HostId = req.HostId
  162. domain, err := s.hostRep.GetDomainById(ctx, req.HostId)
  163. if err != nil {
  164. return v1.GlobalLimitRequireResponse{}, err
  165. }
  166. userInfo,err := s.globalLimitRepository.GetUserInfo(ctx, int64(req.Uid))
  167. if err != nil {
  168. return v1.GlobalLimitRequireResponse{}, err
  169. }
  170. res.GlobalLimitName = strconv.Itoa(req.Uid) + "_" + userInfo.Username + "_" + strconv.Itoa(req.HostId) + "_" + domain
  171. res.HostName, err = s.globalLimitRepository.GetHostName(ctx, int64(req.HostId))
  172. if err != nil {
  173. return v1.GlobalLimitRequireResponse{}, err
  174. }
  175. return res, nil
  176. }
  177. func (s *globalLimitService) GetGlobalLimit(ctx context.Context, id int64) (*model.GlobalLimit, error) {
  178. return s.globalLimitRepository.GetGlobalLimit(ctx, id)
  179. }
  180. func (s *globalLimitService) ConversionTime(ctx context.Context,req string) (string, error) {
  181. // 2. 将字符串解析成 time.Time 对象
  182. // time.Parse 会根据你提供的布局来理解输入的字符串
  183. t, err := time.Parse("2006-01-02 15:04:05", req)
  184. if err != nil {
  185. // 如果输入的字符串格式和布局不匹配,这里会报错
  186. return "", fmt.Errorf("输入的字符串格式和布局不匹配 %w", err)
  187. }
  188. // 3. 定义新的输出格式 "YYYY-MM-DD"
  189. outputLayout := "2006-01-02"
  190. // 4. 将 time.Time 对象格式化为新的字符串
  191. outputTimeStr := t.Format(outputLayout)
  192. return outputTimeStr, nil
  193. }
  194. func (s *globalLimitService) ConversionTimeUnix(ctx context.Context,req string) (int64, error) {
  195. t, err := time.Parse("2006-01-02 15:04:05", req)
  196. if err != nil {
  197. return 0, fmt.Errorf("输入的字符串格式和布局不匹配 %w", err)
  198. }
  199. expiredAt := t.Unix()
  200. return expiredAt, nil
  201. }
  202. func (s *globalLimitService) AddGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error {
  203. isExist, err := s.globalLimitRepository.IsGlobalLimitExistByHostId(ctx, int64(req.HostId))
  204. if err != nil {
  205. return err
  206. }
  207. if isExist {
  208. return fmt.Errorf("配置限制已存在")
  209. }
  210. require, err := s.GlobalLimitRequire(ctx, req)
  211. if err != nil {
  212. return err
  213. }
  214. g, gCtx := errgroup.WithContext(ctx)
  215. var userId int64
  216. var groupId int64
  217. g.Go(func() error {
  218. e := s.gatewayIpRep.GetIpWhereHostIdNull(gCtx, require)
  219. if e != nil {
  220. return fmt.Errorf("获取网关组失败: %w", e)
  221. }
  222. return nil
  223. })
  224. g.Go(func() error {
  225. res, e := s.GetCdnUserId(gCtx, int64(req.Uid))
  226. if e != nil {
  227. return fmt.Errorf("获取cdn用户失败: %w", e)
  228. }
  229. if res == 0 {
  230. return fmt.Errorf("获取cdn用户失败")
  231. }
  232. userId = res
  233. return nil
  234. })
  235. g.Go(func() error {
  236. res, e := s.AddGroupId(gCtx, require.GlobalLimitName)
  237. if e != nil {
  238. return fmt.Errorf("创建规则分组失败: %w", e)
  239. }
  240. if res == 0 {
  241. return fmt.Errorf("创建规则分组失败")
  242. }
  243. groupId = res
  244. return nil
  245. })
  246. if err = g.Wait(); err != nil {
  247. return err
  248. }
  249. outputTimeStr, err := s.ConversionTime(ctx, require.ExpiredAt)
  250. if err != nil {
  251. return err
  252. }
  253. // 获取套餐ID
  254. maxProtection := strings.TrimSuffix(require.ConfigMaxProtection, "G")
  255. if maxProtection == "" {
  256. return fmt.Errorf("无效的配置 ConfigMaxProtection: '%s',数字部分为空", require.ConfigMaxProtection)
  257. }
  258. maxProtectionInt, err := strconv.Atoi(maxProtection)
  259. if err != nil {
  260. return fmt.Errorf("无效的配置 ConfigMaxProtection: '%s',无法转换为数字", require.ConfigMaxProtection)
  261. }
  262. var planId int64
  263. maxProtectionNum := 1
  264. if maxProtectionInt >= 2000 {
  265. maxProtectionNum = maxProtectionInt / 1000
  266. }
  267. NodeAreaName := fmt.Sprintf("%s-%dT",require.NodeArea, maxProtectionNum)
  268. planId, err = s.globalLimitRepository.GetNodeArea(ctx, NodeAreaName)
  269. if err != nil {
  270. if errors.Is(err, gorm.ErrRecordNotFound) {
  271. planId = 0
  272. }else {
  273. return err
  274. }
  275. }
  276. if planId == 0 {
  277. // 安全冗余套餐
  278. planId = 6
  279. s.logger.Warn("获取套餐Id失败", zap.String("节点区域", NodeAreaName), zap.String("防御阈值", require.ConfigMaxProtection),zap.Int64("套餐Id", int64(req.Uid)),zap.Int64("魔方套餐Id", int64(req.HostId)))
  280. }
  281. ruleId, err := s.cdnService.BindPlan(ctx, v1.Plan{
  282. UserId: userId,
  283. PlanId: planId,
  284. DayTo: outputTimeStr,
  285. Name: require.GlobalLimitName,
  286. IsFree: true,
  287. Period: "monthly",
  288. CountPeriod: 1,
  289. PeriodDayTo: outputTimeStr,
  290. })
  291. if err != nil {
  292. return err
  293. }
  294. if ruleId == 0 {
  295. return fmt.Errorf("分配套餐失败")
  296. }
  297. expiredAt, err := s.ConversionTimeUnix(ctx, require.ExpiredAt)
  298. if err != nil {
  299. return err
  300. }
  301. // 如果存在实例,恢复
  302. oldData,err := s.globalLimitRepository.GetGlobalLimitByHostId(ctx, int64(req.HostId))
  303. if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
  304. return err
  305. }
  306. if oldData!= nil && oldData.Id != 0 {
  307. err = s.globalLimitRepository.UpdateGlobalLimitByHostId(ctx, &model.GlobalLimit{
  308. HostId: req.HostId,
  309. Uid: req.Uid,
  310. Name: require.GlobalLimitName,
  311. RuleId: int(ruleId),
  312. GroupId: int(groupId),
  313. CdnUid: int(userId),
  314. Comment: req.Comment,
  315. ExpiredAt: expiredAt,
  316. State: true,
  317. })
  318. if err != nil {
  319. return err
  320. }
  321. return nil
  322. }
  323. // 如果不存在实例,创建
  324. err = s.globalLimitRepository.AddGlobalLimit(ctx, &model.GlobalLimit{
  325. HostId: req.HostId,
  326. Uid: req.Uid,
  327. Name: require.GlobalLimitName,
  328. RuleId: int(ruleId),
  329. GroupId: int(groupId),
  330. CdnUid: int(userId),
  331. Comment: req.Comment,
  332. State: true,
  333. ExpiredAt: expiredAt,
  334. })
  335. if err != nil {
  336. return err
  337. }
  338. return nil
  339. }
  340. func (s *globalLimitService) EditGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error {
  341. require, err := s.GlobalLimitRequire(ctx, req)
  342. if err != nil {
  343. return err
  344. }
  345. data, err := s.globalLimitRepository.GetGlobalLimitByHostId(ctx, int64(req.HostId))
  346. if err != nil {
  347. return err
  348. }
  349. // 如果不存在实例,创建
  350. gatewayIp, err := s.gatewayIpRep.GetGatewayipByHostIdAll(ctx, int64(req.HostId))
  351. if err != nil {
  352. return err
  353. }
  354. if gatewayIp != nil {
  355. err = s.gatewayIpRep.GetIpWhereHostIdNull(ctx, require)
  356. if err != nil {
  357. return fmt.Errorf("获取网关组失败: %w", err)
  358. }
  359. return nil
  360. }
  361. outputTimeStr, err := s.ConversionTime(ctx, require.ExpiredAt)
  362. if err != nil {
  363. return err
  364. }
  365. err = s.cdnService.RenewPlan(ctx, v1.RenewalPlan{
  366. UserPlanId: int64(data.RuleId),
  367. DayTo: outputTimeStr,
  368. Period: "monthly",
  369. CountPeriod: 1,
  370. IsFree: true,
  371. PeriodDayTo: outputTimeStr,
  372. })
  373. if err != nil {
  374. return err
  375. }
  376. expiredAt, err := s.ConversionTimeUnix(ctx, require.ExpiredAt)
  377. if err != nil {
  378. return err
  379. }
  380. if err := s.globalLimitRepository.UpdateGlobalLimitByHostId(ctx, &model.GlobalLimit{
  381. HostId: req.HostId,
  382. Comment: req.Comment,
  383. ExpiredAt: expiredAt,
  384. }); err != nil {
  385. return err
  386. }
  387. return nil
  388. }
  389. func (s *globalLimitService) DeleteGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error {
  390. // 检查是否过期
  391. isExpired, err := s.host.CheckExpired(ctx, int64(req.Uid), int64(req.HostId))
  392. if err != nil {
  393. return err
  394. }
  395. if isExpired {
  396. return fmt.Errorf("实例未过期,无法删除")
  397. }
  398. oldData, err := s.globalLimitRepository.GetGlobalLimitByHostId(ctx, int64(req.HostId))
  399. if err != nil {
  400. if errors.Is(err, gorm.ErrRecordNotFound) {
  401. return fmt.Errorf("实例不存在")
  402. }
  403. return err
  404. }
  405. tcpIds, err := s.tcpforwardingRep.GetTcpForwardingAllIdsByID(ctx,req.HostId)
  406. if err != nil {
  407. return err
  408. }
  409. udpIds, err := s.udpForWardingRep.GetUdpForwardingWafUdpAllIds(ctx,req.HostId)
  410. if err != nil {
  411. return err
  412. }
  413. webIds, err := s.webForWardingRep.GetWebForwardingWafWebAllIds(ctx,req.HostId)
  414. if err != nil {
  415. return err
  416. }
  417. // 黑白IP
  418. BwIds, err := s.allowAndDenyRep.GetIpCountListId(ctx, int64(req.HostId))
  419. if err != nil {
  420. return err
  421. }
  422. // 删除网站
  423. g, gCtx := errgroup.WithContext(ctx)
  424. g.Go(func() error {
  425. e := s.tcpforwarding.DeleteTcpForwarding(ctx,v1.DeleteTcpForwardingRequest{
  426. Ids: tcpIds,
  427. Uid: req.Uid,
  428. HostId: req.HostId,
  429. })
  430. if e != nil {
  431. return fmt.Errorf("删除TCP转发失败: %w", e)
  432. }
  433. return nil
  434. })
  435. g.Go(func() error {
  436. e := s.udpForWarding.DeleteUdpForwarding(ctx, v1.DeleteUdpForwardingRequest{
  437. Ids: udpIds,
  438. Uid: req.Uid,
  439. HostId: req.HostId,
  440. })
  441. if e != nil {
  442. return fmt.Errorf("删除UDP转发失败: %w", e)
  443. }
  444. return nil
  445. })
  446. g.Go(func() error {
  447. e := s.webForWarding.DeleteWebForwarding(ctx,v1.DeleteWebForwardingRequest{
  448. Ids: webIds,
  449. Uid: req.Uid,
  450. HostId: req.HostId,
  451. })
  452. if e != nil {
  453. return fmt.Errorf("删除WEB转发失败: %w", e)
  454. }
  455. return nil
  456. })
  457. // 删除套餐
  458. g.Go(func() error {
  459. e := s.cdnService.DelUserPlan(gCtx, int64(oldData.RuleId))
  460. if e != nil {
  461. return fmt.Errorf("删除套餐失败: %w", e)
  462. }
  463. return nil
  464. })
  465. // 删除网站分组
  466. g.Go(func() error {
  467. e := s.cdnService.DelServerGroup(gCtx, int64(oldData.GroupId))
  468. if e != nil {
  469. return fmt.Errorf("删除网站分组失败: %w", e)
  470. }
  471. return nil
  472. })
  473. if err = g.Wait(); err != nil {
  474. return err
  475. }
  476. if err := s.globalLimitRepository.EditHostState(ctx, int64(req.HostId), false); err != nil {
  477. return err
  478. }
  479. if err := s.gatewayIpRep.CleanIPByHostId(ctx, []int64{int64(req.HostId)}); err != nil {
  480. return err
  481. }
  482. // 删除黑白名单
  483. err = s.allowAndDeny.DeleteAllowAndDenyIps(ctx, v1.DelAllowAndDenyIpRequest{
  484. HostId: req.HostId,
  485. Ids: BwIds,
  486. Uid: req.Uid,
  487. })
  488. if err != nil {
  489. return err
  490. }
  491. return nil
  492. }