globallimit.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550
  1. package service
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
  7. "github.com/go-nunu/nunu-layout-advanced/internal/model"
  8. "github.com/go-nunu/nunu-layout-advanced/internal/repository"
  9. "github.com/mozillazg/go-pinyin"
  10. "github.com/spf13/viper"
  11. "go.uber.org/zap"
  12. "golang.org/x/sync/errgroup"
  13. "gorm.io/gorm"
  14. "strconv"
  15. "strings"
  16. "time"
  17. )
  18. type GlobalLimitService interface {
  19. GetGlobalLimit(ctx context.Context, id int64) (*model.GlobalLimit, error)
  20. AddGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error
  21. EditGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error
  22. DeleteGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error
  23. GlobalLimitRequire(ctx context.Context, req v1.GlobalLimitRequest) (res v1.GlobalLimitRequireResponse, err error)
  24. }
  25. func NewGlobalLimitService(
  26. service *Service,
  27. globalLimitRepository repository.GlobalLimitRepository,
  28. duedate DuedateService,
  29. crawler CrawlerService,
  30. conf *viper.Viper,
  31. required RequiredService,
  32. parser ParserService,
  33. host HostService,
  34. hostRep repository.HostRepository,
  35. cdnService CdnService,
  36. cdnRep repository.CdnRepository,
  37. tcpforwardingRep repository.TcpforwardingRepository,
  38. udpForWardingRep repository.UdpForWardingRepository,
  39. webForWardingRep repository.WebForwardingRepository,
  40. allowAndDeny AllowAndDenyIpService,
  41. allowAndDenyRep repository.AllowAndDenyIpRepository,
  42. tcpforwarding TcpforwardingService,
  43. udpForWarding UdpForWardingService,
  44. webForWarding WebForwardingService,
  45. gatewayIpRep repository.GatewayipRepository,
  46. gatywayIp GatewayipService,
  47. ) GlobalLimitService {
  48. return &globalLimitService{
  49. Service: service,
  50. globalLimitRepository: globalLimitRepository,
  51. duedate: duedate,
  52. crawler: crawler,
  53. Url: conf.GetString("crawler.Url"),
  54. required: required,
  55. parser: parser,
  56. host: host,
  57. hostRep: hostRep,
  58. cdnService: cdnService,
  59. cdnRep: cdnRep,
  60. tcpforwardingRep: tcpforwardingRep,
  61. udpForWardingRep: udpForWardingRep,
  62. webForWardingRep: webForWardingRep,
  63. allowAndDeny: allowAndDeny,
  64. allowAndDenyRep: allowAndDenyRep,
  65. tcpforwarding: tcpforwarding,
  66. udpForWarding: udpForWarding,
  67. webForWarding: webForWarding,
  68. gatewayIpRep: gatewayIpRep,
  69. gatewayIp: gatywayIp,
  70. }
  71. }
  72. type globalLimitService struct {
  73. *Service
  74. globalLimitRepository repository.GlobalLimitRepository
  75. duedate DuedateService
  76. crawler CrawlerService
  77. Url string
  78. required RequiredService
  79. parser ParserService
  80. host HostService
  81. hostRep repository.HostRepository
  82. cdnService CdnService
  83. cdnRep repository.CdnRepository
  84. tcpforwardingRep repository.TcpforwardingRepository
  85. udpForWardingRep repository.UdpForWardingRepository
  86. webForWardingRep repository.WebForwardingRepository
  87. allowAndDeny AllowAndDenyIpService
  88. allowAndDenyRep repository.AllowAndDenyIpRepository
  89. tcpforwarding TcpforwardingService
  90. udpForWarding UdpForWardingService
  91. webForWarding WebForwardingService
  92. gatewayIpRep repository.GatewayipRepository
  93. gatewayIp GatewayipService
  94. }
  95. func (s *globalLimitService) GetCdnUserId(ctx context.Context, uid int64) (int64, error) {
  96. data, err := s.globalLimitRepository.GetGlobalLimitFirst(ctx, uid)
  97. if err != nil {
  98. if !errors.Is(err, gorm.ErrRecordNotFound) {
  99. return 0, err
  100. }
  101. }
  102. if data != nil && data.CdnUid != 0 {
  103. return int64(data.CdnUid), nil
  104. }
  105. userInfo,err := s.globalLimitRepository.GetUserInfo(ctx, uid)
  106. if err != nil {
  107. return 0, err
  108. }
  109. // 中文转拼音
  110. a := pinyin.NewArgs()
  111. a.Style = pinyin.Normal
  112. pinyinSlice := pinyin.LazyPinyin(userInfo.Username, a)
  113. userName := strconv.Itoa(int(uid)) + "_" + strings.Join(pinyinSlice, "_")
  114. // 查询用户是否存在
  115. UserId,err := s.cdnRep.GetUserId(ctx, userName)
  116. if err != nil {
  117. return 0, err
  118. }
  119. if UserId != 0 {
  120. return UserId, nil
  121. }
  122. // 注册用户
  123. userId, err := s.cdnService.AddUser(ctx, v1.User{
  124. Username: userName,
  125. Email: userInfo.Email,
  126. Fullname: userInfo.Username,
  127. Mobile: userInfo.PhoneNumber,
  128. })
  129. if err != nil {
  130. return 0, err
  131. }
  132. return userId, nil
  133. }
  134. func (s *globalLimitService) AddGroupId(ctx context.Context,groupName string) (int64, error) {
  135. groupId, err := s.cdnService.CreateGroup(ctx, v1.Group{
  136. Name: groupName,
  137. })
  138. if err != nil {
  139. return 0, err
  140. }
  141. return groupId, nil
  142. }
  143. func (s *globalLimitService) GlobalLimitRequire(ctx context.Context, req v1.GlobalLimitRequest) (res v1.GlobalLimitRequireResponse, err error) {
  144. res.ExpiredAt, err = s.duedate.NextDueDate(ctx, req.Uid, req.HostId)
  145. if err != nil {
  146. return v1.GlobalLimitRequireResponse{}, err
  147. }
  148. configCount, err := s.host.GetGlobalLimitConfig(ctx, req.HostId)
  149. if err != nil {
  150. return v1.GlobalLimitRequireResponse{}, fmt.Errorf("获取配置限制失败: %w", err)
  151. }
  152. bpsInt, err := strconv.Atoi(configCount.Bps)
  153. if err != nil {
  154. return v1.GlobalLimitRequireResponse{}, err
  155. }
  156. resultFloat := float64(bpsInt) / 2.0 / 8.0
  157. res.Bps = strconv.FormatFloat( resultFloat, 'f', -1, 64) + "M"
  158. res.MaxBytesMonth = configCount.MaxBytesMonth
  159. res.Operator = configCount.Operator
  160. res.IpCount = configCount.IpCount
  161. res.NodeArea = configCount.NodeArea
  162. res.ConfigMaxProtection = configCount.ConfigMaxProtection
  163. res.IsBanUdp = configCount.IsBanUdp
  164. res.HostId = req.HostId
  165. domain, err := s.hostRep.GetDomainById(ctx, req.HostId)
  166. if err != nil {
  167. return v1.GlobalLimitRequireResponse{}, err
  168. }
  169. userInfo,err := s.globalLimitRepository.GetUserInfo(ctx, int64(req.Uid))
  170. if err != nil {
  171. return v1.GlobalLimitRequireResponse{}, err
  172. }
  173. res.GlobalLimitName = strconv.Itoa(req.Uid) + "_" + userInfo.Username + "_" + strconv.Itoa(req.HostId) + "_" + domain
  174. res.HostName, err = s.globalLimitRepository.GetHostName(ctx, int64(req.HostId))
  175. if err != nil {
  176. return v1.GlobalLimitRequireResponse{}, err
  177. }
  178. return res, nil
  179. }
  180. func (s *globalLimitService) GetGlobalLimit(ctx context.Context, id int64) (*model.GlobalLimit, error) {
  181. return s.globalLimitRepository.GetGlobalLimit(ctx, id)
  182. }
  183. func (s *globalLimitService) ConversionTime(ctx context.Context,req string) (string, error) {
  184. // 2. 将字符串解析成 time.Time 对象
  185. // time.Parse 会根据你提供的布局来理解输入的字符串
  186. t, err := time.Parse("2006-01-02 15:04:05", req)
  187. if err != nil {
  188. // 如果输入的字符串格式和布局不匹配,这里会报错
  189. return "", fmt.Errorf("输入的字符串格式和布局不匹配 %w", err)
  190. }
  191. // 3. 定义新的输出格式 "YYYY-MM-DD"
  192. outputLayout := "2006-01-02"
  193. // 4. 将 time.Time 对象格式化为新的字符串
  194. outputTimeStr := t.Format(outputLayout)
  195. return outputTimeStr, nil
  196. }
  197. func (s *globalLimitService) ConversionTimeUnix(ctx context.Context,req string) (int64, error) {
  198. t, err := time.Parse("2006-01-02 15:04:05", req)
  199. if err != nil {
  200. return 0, fmt.Errorf("输入的字符串格式和布局不匹配 %w", err)
  201. }
  202. expiredAt := t.Unix()
  203. return expiredAt, nil
  204. }
  205. func (s *globalLimitService) AddGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error {
  206. isExist, err := s.globalLimitRepository.IsGlobalLimitExistByHostId(ctx, int64(req.HostId))
  207. if err != nil {
  208. return err
  209. }
  210. if isExist {
  211. return fmt.Errorf("配置限制已存在")
  212. }
  213. require, err := s.GlobalLimitRequire(ctx, req)
  214. if err != nil {
  215. return err
  216. }
  217. g, gCtx := errgroup.WithContext(ctx)
  218. var userId int64
  219. var groupId int64
  220. g.Go(func() error {
  221. e := s.gatewayIp.AddIpWhereHostIdNull(gCtx, int64(req.HostId),int64(req.Uid))
  222. if e != nil {
  223. return fmt.Errorf("获取网关组失败: %w", e)
  224. }
  225. return nil
  226. })
  227. g.Go(func() error {
  228. res, e := s.GetCdnUserId(gCtx, int64(req.Uid))
  229. if e != nil {
  230. return fmt.Errorf("获取cdn用户失败: %w", e)
  231. }
  232. if res == 0 {
  233. return fmt.Errorf("获取cdn用户失败")
  234. }
  235. userId = res
  236. return nil
  237. })
  238. g.Go(func() error {
  239. res, e := s.AddGroupId(gCtx, require.GlobalLimitName)
  240. if e != nil {
  241. return fmt.Errorf("创建规则分组失败: %w", e)
  242. }
  243. if res == 0 {
  244. return fmt.Errorf("创建规则分组失败")
  245. }
  246. groupId = res
  247. return nil
  248. })
  249. if err = g.Wait(); err != nil {
  250. return err
  251. }
  252. outputTimeStr, err := s.ConversionTime(ctx, require.ExpiredAt)
  253. if err != nil {
  254. return err
  255. }
  256. // 获取套餐ID
  257. maxProtection := strings.TrimSuffix(require.ConfigMaxProtection, "G")
  258. if maxProtection == "" {
  259. return fmt.Errorf("无效的配置 ConfigMaxProtection: '%s',数字部分为空", require.ConfigMaxProtection)
  260. }
  261. maxProtectionInt, err := strconv.Atoi(maxProtection)
  262. if err != nil {
  263. return fmt.Errorf("无效的配置 ConfigMaxProtection: '%s',无法转换为数字", require.ConfigMaxProtection)
  264. }
  265. var planId int64
  266. maxProtectionNum := 1
  267. if maxProtectionInt >= 2000 {
  268. maxProtectionNum = maxProtectionInt / 1000
  269. }
  270. NodeAreaName := fmt.Sprintf("%s-%dT",require.NodeArea, maxProtectionNum)
  271. planId, err = s.globalLimitRepository.GetNodeArea(ctx, NodeAreaName)
  272. if err != nil {
  273. if errors.Is(err, gorm.ErrRecordNotFound) {
  274. planId = 0
  275. }else {
  276. return err
  277. }
  278. }
  279. if planId == 0 {
  280. // 安全冗余套餐
  281. planId = 6
  282. s.logger.Warn("获取套餐Id失败", zap.String("节点区域", NodeAreaName), zap.String("防御阈值", require.ConfigMaxProtection),zap.Int64("套餐Id", int64(req.Uid)),zap.Int64("魔方套餐Id", int64(req.HostId)))
  283. }
  284. ruleId, err := s.cdnService.BindPlan(ctx, v1.Plan{
  285. UserId: userId,
  286. PlanId: planId,
  287. DayTo: outputTimeStr,
  288. Name: require.GlobalLimitName,
  289. IsFree: true,
  290. Period: "monthly",
  291. CountPeriod: 1,
  292. PeriodDayTo: outputTimeStr,
  293. })
  294. if err != nil {
  295. return err
  296. }
  297. if ruleId == 0 {
  298. return fmt.Errorf("分配套餐失败")
  299. }
  300. expiredAt, err := s.ConversionTimeUnix(ctx, require.ExpiredAt)
  301. if err != nil {
  302. return err
  303. }
  304. // 如果存在实例,恢复
  305. oldData,err := s.globalLimitRepository.GetGlobalLimitByHostId(ctx, int64(req.HostId))
  306. if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
  307. return err
  308. }
  309. if oldData!= nil && oldData.Id != 0 {
  310. err = s.globalLimitRepository.UpdateGlobalLimitByHostId(ctx, &model.GlobalLimit{
  311. HostId: req.HostId,
  312. Uid: req.Uid,
  313. Name: require.GlobalLimitName,
  314. RuleId: int(ruleId),
  315. GroupId: int(groupId),
  316. CdnUid: int(userId),
  317. Comment: req.Comment,
  318. ExpiredAt: expiredAt,
  319. State: true,
  320. })
  321. if err != nil {
  322. return err
  323. }
  324. return nil
  325. }
  326. // 如果不存在实例,创建
  327. err = s.globalLimitRepository.AddGlobalLimit(ctx, &model.GlobalLimit{
  328. HostId: req.HostId,
  329. Uid: req.Uid,
  330. Name: require.GlobalLimitName,
  331. RuleId: int(ruleId),
  332. GroupId: int(groupId),
  333. CdnUid: int(userId),
  334. Comment: req.Comment,
  335. State: true,
  336. ExpiredAt: expiredAt,
  337. })
  338. if err != nil {
  339. return err
  340. }
  341. return nil
  342. }
  343. func (s *globalLimitService) EditGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error {
  344. require, err := s.GlobalLimitRequire(ctx, req)
  345. if err != nil {
  346. return err
  347. }
  348. data, err := s.globalLimitRepository.GetGlobalLimitByHostId(ctx, int64(req.HostId))
  349. if err != nil {
  350. return err
  351. }
  352. // 如果不存在实例,创建
  353. gatewayIp, err := s.gatewayIpRep.GetGatewayipByHostIdAll(ctx, int64(req.HostId))
  354. if err != nil {
  355. return err
  356. }
  357. if gatewayIp != nil {
  358. err = s.gatewayIp.AddIpWhereHostIdNull(ctx, int64(req.HostId), int64(req.Uid))
  359. if err != nil {
  360. return fmt.Errorf("获取网关组失败: %w", err)
  361. }
  362. return nil
  363. }
  364. outputTimeStr, err := s.ConversionTime(ctx, require.ExpiredAt)
  365. if err != nil {
  366. return err
  367. }
  368. err = s.cdnService.RenewPlan(ctx, v1.RenewalPlan{
  369. UserPlanId: int64(data.RuleId),
  370. DayTo: outputTimeStr,
  371. Period: "monthly",
  372. CountPeriod: 1,
  373. IsFree: true,
  374. PeriodDayTo: outputTimeStr,
  375. })
  376. if err != nil {
  377. return err
  378. }
  379. expiredAt, err := s.ConversionTimeUnix(ctx, require.ExpiredAt)
  380. if err != nil {
  381. return err
  382. }
  383. if err := s.globalLimitRepository.UpdateGlobalLimitByHostId(ctx, &model.GlobalLimit{
  384. HostId: req.HostId,
  385. Comment: req.Comment,
  386. ExpiredAt: expiredAt,
  387. }); err != nil {
  388. return err
  389. }
  390. return nil
  391. }
  392. func (s *globalLimitService) DeleteGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error {
  393. // 检查是否过期
  394. isExpired, err := s.host.CheckExpired(ctx, int64(req.Uid), int64(req.HostId))
  395. if err != nil {
  396. return err
  397. }
  398. if isExpired {
  399. return fmt.Errorf("实例未过期,无法删除")
  400. }
  401. oldData, err := s.globalLimitRepository.GetGlobalLimitByHostId(ctx, int64(req.HostId))
  402. if err != nil {
  403. if errors.Is(err, gorm.ErrRecordNotFound) {
  404. return fmt.Errorf("实例不存在")
  405. }
  406. return err
  407. }
  408. tcpIds, err := s.tcpforwardingRep.GetTcpForwardingAllIdsByID(ctx,req.HostId)
  409. if err != nil {
  410. return err
  411. }
  412. udpIds, err := s.udpForWardingRep.GetUdpForwardingWafUdpAllIds(ctx,req.HostId)
  413. if err != nil {
  414. return err
  415. }
  416. webIds, err := s.webForWardingRep.GetWebForwardingWafWebAllIds(ctx,req.HostId)
  417. if err != nil {
  418. return err
  419. }
  420. // 黑白IP
  421. BwIds, err := s.allowAndDenyRep.GetIpCountListId(ctx, int64(req.HostId))
  422. if err != nil {
  423. return err
  424. }
  425. // 删除网站
  426. g, gCtx := errgroup.WithContext(ctx)
  427. g.Go(func() error {
  428. e := s.tcpforwarding.DeleteTcpForwarding(ctx,v1.DeleteTcpForwardingRequest{
  429. Ids: tcpIds,
  430. Uid: req.Uid,
  431. HostId: req.HostId,
  432. })
  433. if e != nil {
  434. return fmt.Errorf("删除TCP转发失败: %w", e)
  435. }
  436. return nil
  437. })
  438. g.Go(func() error {
  439. e := s.udpForWarding.DeleteUdpForwarding(ctx, v1.DeleteUdpForwardingRequest{
  440. Ids: udpIds,
  441. Uid: req.Uid,
  442. HostId: req.HostId,
  443. })
  444. if e != nil {
  445. return fmt.Errorf("删除UDP转发失败: %w", e)
  446. }
  447. return nil
  448. })
  449. g.Go(func() error {
  450. e := s.webForWarding.DeleteWebForwarding(ctx,v1.DeleteWebForwardingRequest{
  451. Ids: webIds,
  452. Uid: req.Uid,
  453. HostId: req.HostId,
  454. })
  455. if e != nil {
  456. return fmt.Errorf("删除WEB转发失败: %w", e)
  457. }
  458. return nil
  459. })
  460. // 删除套餐
  461. g.Go(func() error {
  462. e := s.cdnService.DelUserPlan(gCtx, int64(oldData.RuleId))
  463. if e != nil {
  464. return fmt.Errorf("删除套餐失败: %w", e)
  465. }
  466. return nil
  467. })
  468. // 删除网站分组
  469. g.Go(func() error {
  470. e := s.cdnService.DelServerGroup(gCtx, int64(oldData.GroupId))
  471. if e != nil {
  472. return fmt.Errorf("删除网站分组失败: %w", e)
  473. }
  474. return nil
  475. })
  476. if err = g.Wait(); err != nil {
  477. return err
  478. }
  479. if err := s.globalLimitRepository.EditHostState(ctx, int64(req.HostId), false); err != nil {
  480. return err
  481. }
  482. if err := s.gatewayIpRep.CleanIPByHostId(ctx, []int64{int64(req.HostId)}); err != nil {
  483. return err
  484. }
  485. // 删除黑白名单
  486. err = s.allowAndDeny.DeleteAllowAndDenyIps(ctx, v1.DelAllowAndDenyIpRequest{
  487. HostId: req.HostId,
  488. Ids: BwIds,
  489. Uid: req.Uid,
  490. })
  491. if err != nil {
  492. return err
  493. }
  494. return nil
  495. }