globallimit.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546
  1. package service
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
  7. "github.com/go-nunu/nunu-layout-advanced/internal/model"
  8. "github.com/go-nunu/nunu-layout-advanced/internal/repository"
  9. "github.com/mozillazg/go-pinyin"
  10. "github.com/spf13/viper"
  11. "go.uber.org/zap"
  12. "golang.org/x/sync/errgroup"
  13. "gorm.io/gorm"
  14. "strconv"
  15. "strings"
  16. "time"
  17. )
  18. type GlobalLimitService interface {
  19. GetGlobalLimit(ctx context.Context, id int64) (*model.GlobalLimit, error)
  20. AddGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error
  21. EditGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error
  22. DeleteGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error
  23. }
  24. func NewGlobalLimitService(
  25. service *Service,
  26. globalLimitRepository repository.GlobalLimitRepository,
  27. duedate DuedateService,
  28. crawler CrawlerService,
  29. conf *viper.Viper,
  30. required RequiredService,
  31. parser ParserService,
  32. host HostService,
  33. hostRep repository.HostRepository,
  34. cdnService CdnService,
  35. cdnRep repository.CdnRepository,
  36. tcpforwardingRep repository.TcpforwardingRepository,
  37. udpForWardingRep repository.UdpForWardingRepository,
  38. webForWardingRep repository.WebForwardingRepository,
  39. allowAndDeny AllowAndDenyIpService,
  40. allowAndDenyRep repository.AllowAndDenyIpRepository,
  41. tcpforwarding TcpforwardingService,
  42. udpForWarding UdpForWardingService,
  43. webForWarding WebForwardingService,
  44. gatewayIpRep repository.GatewayipRepository,
  45. ) GlobalLimitService {
  46. return &globalLimitService{
  47. Service: service,
  48. globalLimitRepository: globalLimitRepository,
  49. duedate: duedate,
  50. crawler: crawler,
  51. Url: conf.GetString("crawler.Url"),
  52. required: required,
  53. parser: parser,
  54. host: host,
  55. hostRep: hostRep,
  56. cdnService: cdnService,
  57. cdnRep: cdnRep,
  58. tcpforwardingRep: tcpforwardingRep,
  59. udpForWardingRep: udpForWardingRep,
  60. webForWardingRep: webForWardingRep,
  61. allowAndDeny: allowAndDeny,
  62. allowAndDenyRep: allowAndDenyRep,
  63. tcpforwarding: tcpforwarding,
  64. udpForWarding: udpForWarding,
  65. webForWarding: webForWarding,
  66. gatewayIpRep: gatewayIpRep,
  67. }
  68. }
  69. type globalLimitService struct {
  70. *Service
  71. globalLimitRepository repository.GlobalLimitRepository
  72. duedate DuedateService
  73. crawler CrawlerService
  74. Url string
  75. required RequiredService
  76. parser ParserService
  77. host HostService
  78. hostRep repository.HostRepository
  79. cdnService CdnService
  80. cdnRep repository.CdnRepository
  81. tcpforwardingRep repository.TcpforwardingRepository
  82. udpForWardingRep repository.UdpForWardingRepository
  83. webForWardingRep repository.WebForwardingRepository
  84. allowAndDeny AllowAndDenyIpService
  85. allowAndDenyRep repository.AllowAndDenyIpRepository
  86. tcpforwarding TcpforwardingService
  87. udpForWarding UdpForWardingService
  88. webForWarding WebForwardingService
  89. gatewayIpRep repository.GatewayipRepository
  90. }
  91. func (s *globalLimitService) GetCdnUserId(ctx context.Context, uid int64) (int64, error) {
  92. data, err := s.globalLimitRepository.GetGlobalLimitFirst(ctx, uid)
  93. if err != nil {
  94. if !errors.Is(err, gorm.ErrRecordNotFound) {
  95. return 0, err
  96. }
  97. }
  98. if data != nil && data.CdnUid != 0 {
  99. return int64(data.CdnUid), nil
  100. }
  101. userInfo,err := s.globalLimitRepository.GetUserInfo(ctx, uid)
  102. if err != nil {
  103. return 0, err
  104. }
  105. // 中文转拼音
  106. a := pinyin.NewArgs()
  107. a.Style = pinyin.Normal
  108. pinyinSlice := pinyin.LazyPinyin(userInfo.Username, a)
  109. userName := strconv.Itoa(int(uid)) + "_" + strings.Join(pinyinSlice, "_")
  110. // 查询用户是否存在
  111. UserId,err := s.cdnRep.GetUserId(ctx, userName)
  112. if err != nil {
  113. return 0, err
  114. }
  115. if UserId != 0 {
  116. return UserId, nil
  117. }
  118. // 注册用户
  119. userId, err := s.cdnService.AddUser(ctx, v1.User{
  120. Username: userName,
  121. Email: userInfo.Email,
  122. Fullname: userInfo.Username,
  123. Mobile: userInfo.PhoneNumber,
  124. })
  125. if err != nil {
  126. return 0, err
  127. }
  128. return userId, nil
  129. }
  130. func (s *globalLimitService) AddGroupId(ctx context.Context,groupName string) (int64, error) {
  131. groupId, err := s.cdnService.CreateGroup(ctx, v1.Group{
  132. Name: groupName,
  133. })
  134. if err != nil {
  135. return 0, err
  136. }
  137. return groupId, nil
  138. }
  139. func (s *globalLimitService) GlobalLimitRequire(ctx context.Context, req v1.GlobalLimitRequest) (res v1.GlobalLimitRequireResponse, err error) {
  140. res.ExpiredAt, err = s.duedate.NextDueDate(ctx, req.Uid, req.HostId)
  141. if err != nil {
  142. return v1.GlobalLimitRequireResponse{}, err
  143. }
  144. configCount, err := s.host.GetGlobalLimitConfig(ctx, req.HostId)
  145. if err != nil {
  146. return v1.GlobalLimitRequireResponse{}, fmt.Errorf("获取配置限制失败: %w", err)
  147. }
  148. bpsInt, err := strconv.Atoi(configCount.Bps)
  149. if err != nil {
  150. return v1.GlobalLimitRequireResponse{}, err
  151. }
  152. resultFloat := float64(bpsInt) / 2.0 / 8.0
  153. res.Bps = strconv.FormatFloat( resultFloat, 'f', -1, 64) + "M"
  154. res.MaxBytesMonth = configCount.MaxBytesMonth
  155. res.Operator = configCount.Operator
  156. res.IpCount = configCount.IpCount
  157. res.NodeArea = configCount.NodeArea
  158. res.ConfigMaxProtection = configCount.ConfigMaxProtection
  159. res.IsBanUdp = configCount.IsBanUdp
  160. res.HostId = req.HostId
  161. domain, err := s.hostRep.GetDomainById(ctx, req.HostId)
  162. if err != nil {
  163. return v1.GlobalLimitRequireResponse{}, err
  164. }
  165. userInfo,err := s.globalLimitRepository.GetUserInfo(ctx, int64(req.Uid))
  166. if err != nil {
  167. return v1.GlobalLimitRequireResponse{}, err
  168. }
  169. res.GlobalLimitName = strconv.Itoa(req.Uid) + "_" + userInfo.Username + "_" + strconv.Itoa(req.HostId) + "_" + domain
  170. res.HostName, err = s.globalLimitRepository.GetHostName(ctx, int64(req.HostId))
  171. if err != nil {
  172. return v1.GlobalLimitRequireResponse{}, err
  173. }
  174. return res, nil
  175. }
  176. func (s *globalLimitService) GetGlobalLimit(ctx context.Context, id int64) (*model.GlobalLimit, error) {
  177. return s.globalLimitRepository.GetGlobalLimit(ctx, id)
  178. }
  179. func (s *globalLimitService) ConversionTime(ctx context.Context,req string) (string, error) {
  180. // 2. 将字符串解析成 time.Time 对象
  181. // time.Parse 会根据你提供的布局来理解输入的字符串
  182. t, err := time.Parse("2006-01-02 15:04:05", req)
  183. if err != nil {
  184. // 如果输入的字符串格式和布局不匹配,这里会报错
  185. return "", fmt.Errorf("输入的字符串格式和布局不匹配 %w", err)
  186. }
  187. // 3. 定义新的输出格式 "YYYY-MM-DD"
  188. outputLayout := "2006-01-02"
  189. // 4. 将 time.Time 对象格式化为新的字符串
  190. outputTimeStr := t.Format(outputLayout)
  191. return outputTimeStr, nil
  192. }
  193. func (s *globalLimitService) ConversionTimeUnix(ctx context.Context,req string) (int64, error) {
  194. t, err := time.Parse("2006-01-02 15:04:05", req)
  195. if err != nil {
  196. return 0, fmt.Errorf("输入的字符串格式和布局不匹配 %w", err)
  197. }
  198. expiredAt := t.Unix()
  199. return expiredAt, nil
  200. }
  201. func (s *globalLimitService) AddGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error {
  202. isExist, err := s.globalLimitRepository.IsGlobalLimitExistByHostId(ctx, int64(req.HostId))
  203. if err != nil {
  204. return err
  205. }
  206. if isExist {
  207. return fmt.Errorf("配置限制已存在")
  208. }
  209. require, err := s.GlobalLimitRequire(ctx, req)
  210. if err != nil {
  211. return err
  212. }
  213. g, gCtx := errgroup.WithContext(ctx)
  214. var userId int64
  215. var groupId int64
  216. g.Go(func() error {
  217. e := s.gatewayIpRep.GetIpWhereHostIdNull(gCtx, require)
  218. if e != nil {
  219. return fmt.Errorf("获取网关组失败: %w", e)
  220. }
  221. return nil
  222. })
  223. g.Go(func() error {
  224. res, e := s.GetCdnUserId(gCtx, int64(req.Uid))
  225. if e != nil {
  226. return fmt.Errorf("获取cdn用户失败: %w", e)
  227. }
  228. if res == 0 {
  229. return fmt.Errorf("获取cdn用户失败")
  230. }
  231. userId = res
  232. return nil
  233. })
  234. g.Go(func() error {
  235. res, e := s.AddGroupId(gCtx, require.GlobalLimitName)
  236. if e != nil {
  237. return fmt.Errorf("创建规则分组失败: %w", e)
  238. }
  239. if res == 0 {
  240. return fmt.Errorf("创建规则分组失败")
  241. }
  242. groupId = res
  243. return nil
  244. })
  245. if err = g.Wait(); err != nil {
  246. return err
  247. }
  248. outputTimeStr, err := s.ConversionTime(ctx, require.ExpiredAt)
  249. if err != nil {
  250. return err
  251. }
  252. // 获取套餐ID
  253. maxProtection := strings.TrimSuffix(require.ConfigMaxProtection, "G")
  254. if maxProtection == "" {
  255. return fmt.Errorf("无效的配置 ConfigMaxProtection: '%s',数字部分为空", require.ConfigMaxProtection)
  256. }
  257. maxProtectionInt, err := strconv.Atoi(maxProtection)
  258. if err != nil {
  259. return fmt.Errorf("无效的配置 ConfigMaxProtection: '%s',无法转换为数字", require.ConfigMaxProtection)
  260. }
  261. var planId int64
  262. maxProtectionNum := 1
  263. if maxProtectionInt >= 2000 {
  264. maxProtectionNum = maxProtectionInt / 1000
  265. }
  266. NodeAreaName := fmt.Sprintf("%s-%dT",require.NodeArea, maxProtectionNum)
  267. planId, err = s.globalLimitRepository.GetNodeArea(ctx, NodeAreaName)
  268. if err != nil {
  269. if errors.Is(err, gorm.ErrRecordNotFound) {
  270. planId = 0
  271. }else {
  272. return err
  273. }
  274. }
  275. if planId == 0 {
  276. // 安全冗余套餐
  277. planId = 6
  278. s.logger.Warn("获取套餐Id失败", zap.String("节点区域", NodeAreaName), zap.String("防御阈值", require.ConfigMaxProtection),zap.Int64("套餐Id", int64(req.Uid)),zap.Int64("魔方套餐Id", int64(req.HostId)))
  279. }
  280. ruleId, err := s.cdnService.BindPlan(ctx, v1.Plan{
  281. UserId: userId,
  282. PlanId: planId,
  283. DayTo: outputTimeStr,
  284. Name: require.GlobalLimitName,
  285. IsFree: true,
  286. Period: "monthly",
  287. CountPeriod: 1,
  288. PeriodDayTo: outputTimeStr,
  289. })
  290. if err != nil {
  291. return err
  292. }
  293. if ruleId == 0 {
  294. return fmt.Errorf("分配套餐失败")
  295. }
  296. expiredAt, err := s.ConversionTimeUnix(ctx, require.ExpiredAt)
  297. if err != nil {
  298. return err
  299. }
  300. // 如果存在实例,恢复
  301. oldData,err := s.globalLimitRepository.GetGlobalLimitByHostId(ctx, int64(req.HostId))
  302. if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
  303. return err
  304. }
  305. if oldData!= nil && oldData.Id != 0 {
  306. err = s.globalLimitRepository.UpdateGlobalLimitByHostId(ctx, &model.GlobalLimit{
  307. HostId: req.HostId,
  308. Uid: req.Uid,
  309. Name: require.GlobalLimitName,
  310. RuleId: int(ruleId),
  311. GroupId: int(groupId),
  312. CdnUid: int(userId),
  313. Comment: req.Comment,
  314. ExpiredAt: expiredAt,
  315. State: true,
  316. })
  317. if err != nil {
  318. return err
  319. }
  320. return nil
  321. }
  322. // 如果不存在实例,创建
  323. err = s.globalLimitRepository.AddGlobalLimit(ctx, &model.GlobalLimit{
  324. HostId: req.HostId,
  325. Uid: req.Uid,
  326. Name: require.GlobalLimitName,
  327. RuleId: int(ruleId),
  328. GroupId: int(groupId),
  329. CdnUid: int(userId),
  330. Comment: req.Comment,
  331. State: true,
  332. ExpiredAt: expiredAt,
  333. })
  334. if err != nil {
  335. return err
  336. }
  337. return nil
  338. }
  339. func (s *globalLimitService) EditGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error {
  340. require, err := s.GlobalLimitRequire(ctx, req)
  341. if err != nil {
  342. return err
  343. }
  344. data, err := s.globalLimitRepository.GetGlobalLimitByHostId(ctx, int64(req.HostId))
  345. if err != nil {
  346. return err
  347. }
  348. // 如果不存在实例,创建
  349. gatewayIp, err := s.gatewayIpRep.GetGatewayipByHostIdAll(ctx, int64(req.HostId))
  350. if err != nil {
  351. return err
  352. }
  353. if gatewayIp != nil {
  354. err = s.gatewayIpRep.GetIpWhereHostIdNull(ctx, require)
  355. if err != nil {
  356. return fmt.Errorf("获取网关组失败: %w", err)
  357. }
  358. return nil
  359. }
  360. outputTimeStr, err := s.ConversionTime(ctx, require.ExpiredAt)
  361. if err != nil {
  362. return err
  363. }
  364. err = s.cdnService.RenewPlan(ctx, v1.RenewalPlan{
  365. UserPlanId: int64(data.RuleId),
  366. DayTo: outputTimeStr,
  367. Period: "monthly",
  368. CountPeriod: 1,
  369. IsFree: true,
  370. PeriodDayTo: outputTimeStr,
  371. })
  372. if err != nil {
  373. return err
  374. }
  375. expiredAt, err := s.ConversionTimeUnix(ctx, require.ExpiredAt)
  376. if err != nil {
  377. return err
  378. }
  379. if err := s.globalLimitRepository.UpdateGlobalLimitByHostId(ctx, &model.GlobalLimit{
  380. HostId: req.HostId,
  381. Comment: req.Comment,
  382. ExpiredAt: expiredAt,
  383. }); err != nil {
  384. return err
  385. }
  386. return nil
  387. }
  388. func (s *globalLimitService) DeleteGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error {
  389. // 检查是否过期
  390. isExpired, err := s.host.CheckExpired(ctx, int64(req.Uid), int64(req.HostId))
  391. if err != nil {
  392. return err
  393. }
  394. if isExpired {
  395. return fmt.Errorf("实例未过期,无法删除")
  396. }
  397. oldData, err := s.globalLimitRepository.GetGlobalLimitByHostId(ctx, int64(req.HostId))
  398. if err != nil {
  399. if errors.Is(err, gorm.ErrRecordNotFound) {
  400. return fmt.Errorf("实例不存在")
  401. }
  402. return err
  403. }
  404. tcpIds, err := s.tcpforwardingRep.GetTcpForwardingAllIdsByID(ctx,req.HostId)
  405. if err != nil {
  406. return err
  407. }
  408. udpIds, err := s.udpForWardingRep.GetUdpForwardingWafUdpAllIds(ctx,req.HostId)
  409. if err != nil {
  410. return err
  411. }
  412. webIds, err := s.webForWardingRep.GetWebForwardingWafWebAllIds(ctx,req.HostId)
  413. if err != nil {
  414. return err
  415. }
  416. // 黑白IP
  417. BwIds, err := s.allowAndDenyRep.GetIpCountListId(ctx, int64(req.HostId))
  418. if err != nil {
  419. return err
  420. }
  421. // 删除网站
  422. g, gCtx := errgroup.WithContext(ctx)
  423. g.Go(func() error {
  424. e := s.tcpforwarding.DeleteTcpForwarding(ctx,v1.DeleteTcpForwardingRequest{
  425. Ids: tcpIds,
  426. Uid: req.Uid,
  427. HostId: req.HostId,
  428. })
  429. if e != nil {
  430. return fmt.Errorf("删除TCP转发失败: %w", e)
  431. }
  432. return nil
  433. })
  434. g.Go(func() error {
  435. e := s.udpForWarding.DeleteUdpForwarding(ctx, v1.DeleteUdpForwardingRequest{
  436. Ids: udpIds,
  437. Uid: req.Uid,
  438. HostId: req.HostId,
  439. })
  440. if e != nil {
  441. return fmt.Errorf("删除UDP转发失败: %w", e)
  442. }
  443. return nil
  444. })
  445. g.Go(func() error {
  446. e := s.webForWarding.DeleteWebForwarding(ctx,v1.DeleteWebForwardingRequest{
  447. Ids: webIds,
  448. Uid: req.Uid,
  449. HostId: req.HostId,
  450. })
  451. if e != nil {
  452. return fmt.Errorf("删除WEB转发失败: %w", e)
  453. }
  454. return nil
  455. })
  456. // 删除套餐
  457. g.Go(func() error {
  458. e := s.cdnService.DelUserPlan(gCtx, int64(oldData.RuleId))
  459. if e != nil {
  460. return fmt.Errorf("删除套餐失败: %w", e)
  461. }
  462. return nil
  463. })
  464. // 删除网站分组
  465. g.Go(func() error {
  466. e := s.cdnService.DelServerGroup(gCtx, int64(oldData.GroupId))
  467. if e != nil {
  468. return fmt.Errorf("删除网站分组失败: %w", e)
  469. }
  470. return nil
  471. })
  472. if err = g.Wait(); err != nil {
  473. return err
  474. }
  475. if err := s.globalLimitRepository.EditHostState(ctx, int64(req.HostId), false); err != nil {
  476. return err
  477. }
  478. if err := s.gatewayIpRep.CleanIPByHostId(ctx, []int64{int64(req.HostId)}); err != nil {
  479. return err
  480. }
  481. // 删除黑白名单
  482. err = s.allowAndDeny.DeleteAllowAndDenyIps(ctx, v1.DelAllowAndDenyIpRequest{
  483. HostId: req.HostId,
  484. Ids: BwIds,
  485. Uid: req.Uid,
  486. })
  487. if err != nil {
  488. return err
  489. }
  490. return nil
  491. }