globallimit.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505
  1. package service
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
  7. "github.com/go-nunu/nunu-layout-advanced/internal/model"
  8. "github.com/go-nunu/nunu-layout-advanced/internal/repository"
  9. "github.com/mozillazg/go-pinyin"
  10. "github.com/spf13/viper"
  11. "go.uber.org/zap"
  12. "golang.org/x/sync/errgroup"
  13. "gorm.io/gorm"
  14. "strconv"
  15. "strings"
  16. "time"
  17. )
  18. type GlobalLimitService interface {
  19. GetGlobalLimit(ctx context.Context, id int64) (*model.GlobalLimit, error)
  20. AddGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error
  21. EditGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error
  22. DeleteGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error
  23. }
  24. func NewGlobalLimitService(
  25. service *Service,
  26. globalLimitRepository repository.GlobalLimitRepository,
  27. duedate DuedateService,
  28. crawler CrawlerService,
  29. conf *viper.Viper,
  30. required RequiredService,
  31. parser ParserService,
  32. host HostService,
  33. gateWayGroup GatewayGroupService,
  34. hostRep repository.HostRepository,
  35. gateWayGroupRep repository.GatewayGroupRepository,
  36. cdnService CdnService,
  37. cdnRep repository.CdnRepository,
  38. tcpforwardingRep repository.TcpforwardingRepository,
  39. udpForWardingRep repository.UdpForWardingRepository,
  40. webForWardingRep repository.WebForwardingRepository,
  41. allowAndDeny AllowAndDenyIpService,
  42. allowAndDenyRep repository.AllowAndDenyIpRepository,
  43. tcpforwarding TcpforwardingService,
  44. udpForWarding UdpForWardingService,
  45. webForWarding WebForwardingService,
  46. ) GlobalLimitService {
  47. return &globalLimitService{
  48. Service: service,
  49. globalLimitRepository: globalLimitRepository,
  50. duedate: duedate,
  51. crawler: crawler,
  52. Url: conf.GetString("crawler.Url"),
  53. required: required,
  54. parser: parser,
  55. host: host,
  56. gateWayGroup: gateWayGroup,
  57. hostRep: hostRep,
  58. gateWayGroupRep: gateWayGroupRep,
  59. cdnService: cdnService,
  60. cdnRep: cdnRep,
  61. tcpforwardingRep: tcpforwardingRep,
  62. udpForWardingRep: udpForWardingRep,
  63. webForWardingRep: webForWardingRep,
  64. allowAndDeny: allowAndDeny,
  65. allowAndDenyRep: allowAndDenyRep,
  66. tcpforwarding: tcpforwarding,
  67. udpForWarding: udpForWarding,
  68. webForWarding: webForWarding,
  69. }
  70. }
  71. type globalLimitService struct {
  72. *Service
  73. globalLimitRepository repository.GlobalLimitRepository
  74. duedate DuedateService
  75. crawler CrawlerService
  76. Url string
  77. required RequiredService
  78. parser ParserService
  79. host HostService
  80. gateWayGroup GatewayGroupService
  81. hostRep repository.HostRepository
  82. gateWayGroupRep repository.GatewayGroupRepository
  83. cdnService CdnService
  84. cdnRep repository.CdnRepository
  85. tcpforwardingRep repository.TcpforwardingRepository
  86. udpForWardingRep repository.UdpForWardingRepository
  87. webForWardingRep repository.WebForwardingRepository
  88. allowAndDeny AllowAndDenyIpService
  89. allowAndDenyRep repository.AllowAndDenyIpRepository
  90. tcpforwarding TcpforwardingService
  91. udpForWarding UdpForWardingService
  92. webForWarding WebForwardingService
  93. }
  // GetCdnUserId returns the CDN user id for platform user uid, creating the CDN
// user if needed. Resolution order:
//  1. the non-zero CdnUid of the record returned by GetGlobalLimitFirst(uid);
//  2. an existing CDN user named "<uid>_<pinyin syllables joined by '_'>";
//  3. a newly registered CDN user with that name.
func (s *globalLimitService) GetCdnUserId(ctx context.Context, uid int64) (int64, error) {
	record, err := s.globalLimitRepository.GetGlobalLimitFirst(ctx, uid)
	// A missing record is expected for first-time users; anything else is fatal.
	if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
		return 0, err
	}
	if record != nil && record.CdnUid != 0 {
		return int64(record.CdnUid), nil
	}

	profile, err := s.globalLimitRepository.GetUserInfo(ctx, uid)
	if err != nil {
		return 0, err
	}

	// Deterministic CDN user name: uid prefix plus the pinyin of the user name.
	// NOTE(review): go-pinyin's LazyPinyin skips characters without pinyin by
	// default, so non-Chinese parts of the name may be dropped — confirm this is
	// intended (the uid prefix still keeps names distinct per uid).
	pyArgs := pinyin.NewArgs()
	pyArgs.Style = pinyin.Normal
	syllables := pinyin.LazyPinyin(profile.Username, pyArgs)
	cdnUserName := strconv.Itoa(int(uid)) + "_" + strings.Join(syllables, "_")

	existingID, err := s.cdnRep.GetUserId(ctx, cdnUserName)
	if err != nil {
		return 0, err
	}
	if existingID != 0 {
		return existingID, nil
	}

	createdID, err := s.cdnService.AddUser(ctx, v1.User{
		Username: cdnUserName,
		Email:    profile.Email,
		Fullname: profile.Username,
		Mobile:   profile.PhoneNumber,
	})
	if err != nil {
		return 0, err
	}
	return createdID, nil
}
  // AddGroupId creates a CDN group named groupName via the CDN service and
// returns its id. On error the id is 0 and the CDN error is returned unchanged.
func (s *globalLimitService) AddGroupId(ctx context.Context, groupName string) (int64, error) {
	id, err := s.cdnService.CreateGroup(ctx, v1.Group{Name: groupName})
	if err != nil {
		return 0, err
	}
	return id, nil
}
  // GlobalLimitRequire collects what is needed to provision a global limit for
// req.HostId:
//   - ExpiredAt: the host's next due date;
//   - Bps: the configured Bps divided by 2 and then by 8, with an "M" suffix;
//   - MaxBytesMonth, Operator, IpCount, NodeArea, ConfigMaxProtection: copied
//     from the host's global-limit configuration;
//   - GlobalLimitName: "<uid>_<username>_<hostId>_<domain>";
//   - HostName: the host's name.
//
// The first failing lookup aborts the call and a zero response is returned.
func (s *globalLimitService) GlobalLimitRequire(ctx context.Context, req v1.GlobalLimitRequest) (res v1.GlobalLimitRequireResponse, err error) {
	dueDate, err := s.duedate.NextDueDate(ctx, req.Uid, req.HostId)
	if err != nil {
		return v1.GlobalLimitRequireResponse{}, err
	}

	cfg, err := s.host.GetGlobalLimitConfig(ctx, req.HostId)
	if err != nil {
		return v1.GlobalLimitRequireResponse{}, fmt.Errorf("获取配置限制失败: %w", err)
	}
	bps, err := strconv.Atoi(cfg.Bps)
	if err != nil {
		return v1.GlobalLimitRequireResponse{}, err
	}
	bandwidth := strconv.FormatFloat(float64(bps)/2.0/8.0, 'f', -1, 64) + "M"

	domain, err := s.hostRep.GetDomainById(ctx, req.HostId)
	if err != nil {
		return v1.GlobalLimitRequireResponse{}, err
	}
	user, err := s.globalLimitRepository.GetUserInfo(ctx, int64(req.Uid))
	if err != nil {
		return v1.GlobalLimitRequireResponse{}, err
	}
	limitName := strconv.Itoa(req.Uid) + "_" + user.Username + "_" + strconv.Itoa(req.HostId) + "_" + domain

	hostName, err := s.globalLimitRepository.GetHostName(ctx, int64(req.HostId))
	if err != nil {
		return v1.GlobalLimitRequireResponse{}, err
	}

	return v1.GlobalLimitRequireResponse{
		ExpiredAt:           dueDate,
		Bps:                 bandwidth,
		MaxBytesMonth:       cfg.MaxBytesMonth,
		Operator:            cfg.Operator,
		IpCount:             cfg.IpCount,
		NodeArea:            cfg.NodeArea,
		ConfigMaxProtection: cfg.ConfigMaxProtection,
		GlobalLimitName:     limitName,
		HostName:            hostName,
	}, nil
}
  // GetGlobalLimit loads a single global-limit record by id straight from the
// repository; results and errors are passed through unchanged.
func (s *globalLimitService) GetGlobalLimit(ctx context.Context, id int64) (*model.GlobalLimit, error) {
	limit, err := s.globalLimitRepository.GetGlobalLimit(ctx, id)
	return limit, err
}
  180. func (s *globalLimitService) ConversionTime(ctx context.Context,req string) (string, error) {
  181. // 2. 将字符串解析成 time.Time 对象
  182. // time.Parse 会根据你提供的布局来理解输入的字符串
  183. t, err := time.Parse("2006-01-02 15:04:05", req)
  184. if err != nil {
  185. // 如果输入的字符串格式和布局不匹配,这里会报错
  186. return "", fmt.Errorf("输入的字符串格式和布局不匹配 %w", err)
  187. }
  188. // 3. 定义新的输出格式 "YYYY-MM-DD"
  189. outputLayout := "2006-01-02"
  190. // 4. 将 time.Time 对象格式化为新的字符串
  191. outputTimeStr := t.Format(outputLayout)
  192. return outputTimeStr, nil
  193. }
  194. func (s *globalLimitService) ConversionTimeUnix(ctx context.Context,req string) (int64, error) {
  195. t, err := time.Parse("2006-01-02 15:04:05", req)
  196. if err != nil {
  197. return 0, fmt.Errorf("输入的字符串格式和布局不匹配 %w", err)
  198. }
  199. expiredAt := t.Unix()
  200. return expiredAt, nil
  201. }
  // AddGlobalLimit provisions a global limit for req.HostId. It
//   - refuses to run if the host already has one;
//   - derives the configuration via GlobalLimitRequire and validates it (due
//     date format, CDN plan lookup) before anything is created;
//   - concurrently picks a gateway group without a host, resolves/registers the
//     user's CDN account and creates a CDN group named after the limit;
//   - binds a free monthly CDN plan valid until the due date, assigns the
//     gateway group to the host and stores the global-limit row.
//
// If a step fails after remote resources were created, the CDN group, CDN plan
// and gateway binding created so far are released on a best-effort basis so the
// call can be retried without leaving duplicates behind. The CDN user is kept
// because GetCdnUserId reuses it by name for the same uid.
func (s *globalLimitService) AddGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error {
	isExist, err := s.globalLimitRepository.IsGlobalLimitExistByHostId(ctx, int64(req.HostId))
	if err != nil {
		return err
	}
	if isExist {
		return fmt.Errorf("配置限制已存在")
	}
	require, err := s.GlobalLimitRequire(ctx, req)
	if err != nil {
		return err
	}

	// Side-effect-free validation first, so invalid input never leaves
	// half-provisioned remote resources behind.
	dayTo, err := s.ConversionTime(ctx, require.ExpiredAt)
	if err != nil {
		return err
	}
	expiredAt, err := s.ConversionTimeUnix(ctx, require.ExpiredAt)
	if err != nil {
		return err
	}
	planId, err := s.resolveGlobalLimitPlanId(ctx, req, require)
	if err != nil {
		return err
	}

	gatewayGroupId, userId, groupId, err := s.reserveGlobalLimitResources(ctx, req, require)
	if err != nil {
		return err
	}

	ruleId, err := s.cdnService.BindPlan(ctx, v1.Plan{
		UserId:      userId,
		PlanId:      planId,
		DayTo:       dayTo,
		Name:        require.GlobalLimitName,
		IsFree:      true,
		Period:      "monthly",
		CountPeriod: 1,
		PeriodDayTo: dayTo,
	})
	if err != nil {
		s.rollbackGlobalLimit(ctx, 0, groupId, 0)
		return err
	}
	if ruleId == 0 {
		s.rollbackGlobalLimit(ctx, 0, groupId, 0)
		return fmt.Errorf("分配套餐失败")
	}

	// Bind the gateway group to the host (it was only looked up so far).
	if err := s.gateWayGroupRep.EditGatewayGroup(ctx, &model.GatewayGroup{
		Id:     gatewayGroupId,
		HostId: req.HostId,
	}); err != nil {
		s.rollbackGlobalLimit(ctx, int64(ruleId), groupId, 0)
		return err
	}

	if err := s.globalLimitRepository.AddGlobalLimit(ctx, &model.GlobalLimit{
		HostId:         req.HostId,
		Uid:            req.Uid,
		Name:           require.GlobalLimitName,
		RuleId:         int(ruleId),
		GroupId:        int(groupId),
		GatewayGroupId: gatewayGroupId,
		CdnUid:         int(userId),
		Comment:        req.Comment,
		ExpiredAt:      expiredAt,
	}); err != nil {
		s.rollbackGlobalLimit(ctx, int64(ruleId), groupId, gatewayGroupId)
		return err
	}
	return nil
}

// resolveGlobalLimitPlanId maps the protection threshold (e.g. "300G") and the
// node area to a CDN plan id, looked up by the name "<NodeArea>-<N>T". N is 1
// for thresholds below 2000, otherwise the threshold divided by 1000 (integer
// division). When no plan matches, the fallback plan 6 is used and a warning is
// logged.
func (s *globalLimitService) resolveGlobalLimitPlanId(ctx context.Context, req v1.GlobalLimitRequest, require v1.GlobalLimitRequireResponse) (int64, error) {
	maxProtection := strings.TrimSuffix(require.ConfigMaxProtection, "G")
	if maxProtection == "" {
		return 0, fmt.Errorf("无效的配置 ConfigMaxProtection: '%s',数字部分为空", require.ConfigMaxProtection)
	}
	maxProtectionInt, err := strconv.Atoi(maxProtection)
	if err != nil {
		return 0, fmt.Errorf("无效的配置 ConfigMaxProtection: '%s',无法转换为数字", require.ConfigMaxProtection)
	}

	maxProtectionNum := 1
	if maxProtectionInt >= 2000 {
		maxProtectionNum = maxProtectionInt / 1000
	}
	nodeAreaName := fmt.Sprintf("%s-%dT", require.NodeArea, maxProtectionNum)

	planId, err := s.globalLimitRepository.GetNodeArea(ctx, nodeAreaName)
	if err != nil {
		if !errors.Is(err, gorm.ErrRecordNotFound) {
			return 0, err
		}
		planId = 0
	}
	if planId == 0 {
		// Safety-net plan used when no plan matches the area/threshold.
		const fallbackPlanId int64 = 6
		planId = fallbackPlanId
		s.logger.Warn("获取套餐Id失败",
			zap.String("节点区域", nodeAreaName),
			zap.String("防御阈值", require.ConfigMaxProtection),
			zap.Int64("用户Id", int64(req.Uid)),
			zap.Int64("魔方套餐Id", int64(req.HostId)),
			zap.Int64("套餐Id", planId),
		)
	}
	return planId, nil
}

// reserveGlobalLimitResources concurrently (1) looks up a gateway group with no
// host for require.Operator/require.IpCount, (2) resolves or registers the CDN
// user for req.Uid and (3) creates a CDN group named require.GlobalLimitName.
// If any of them fails, a CDN group that was already created is deleted again
// (best effort) and the first error is returned.
func (s *globalLimitService) reserveGlobalLimitResources(ctx context.Context, req v1.GlobalLimitRequest, require v1.GlobalLimitRequireResponse) (gatewayGroupId int, userId int64, groupId int64, err error) {
	g, gCtx := errgroup.WithContext(ctx)
	// Each goroutine writes only its own result variable; g.Wait makes the
	// writes visible afterwards.
	g.Go(func() error {
		// NOTE(review): the group is only read here and bound to the host later,
		// so two concurrent AddGlobalLimit calls may pick the same group.
		res, e := s.gateWayGroupRep.GetGatewayGroupWhereHostIdNull(gCtx, require.Operator, require.IpCount)
		if e != nil {
			return fmt.Errorf("获取网关组失败: %w", e)
		}
		if res == 0 {
			return fmt.Errorf("获取网关组失败")
		}
		gatewayGroupId = res
		return nil
	})
	g.Go(func() error {
		res, e := s.GetCdnUserId(gCtx, int64(req.Uid))
		if e != nil {
			return fmt.Errorf("获取cdn用户失败: %w", e)
		}
		if res == 0 {
			return fmt.Errorf("获取cdn用户失败")
		}
		userId = res
		return nil
	})
	g.Go(func() error {
		res, e := s.AddGroupId(gCtx, require.GlobalLimitName)
		if e != nil {
			return fmt.Errorf("创建规则分组失败: %w", e)
		}
		if res == 0 {
			return fmt.Errorf("创建规则分组失败")
		}
		groupId = res
		return nil
	})
	if err = g.Wait(); err != nil {
		// groupId is non-zero only if the group was created before a sibling failed.
		s.rollbackGlobalLimit(ctx, 0, groupId, 0)
		return 0, 0, 0, err
	}
	return gatewayGroupId, userId, groupId, nil
}

// rollbackGlobalLimit releases, best effort, what a failed AddGlobalLimit had
// already created: the CDN plan (ruleId), the CDN group (groupId) and the
// gateway-group binding (gatewayGroupId). Zero ids are skipped. Failures are
// only logged so the caller still sees the error that triggered the rollback.
func (s *globalLimitService) rollbackGlobalLimit(ctx context.Context, ruleId int64, groupId int64, gatewayGroupId int) {
	if ruleId != 0 {
		if err := s.cdnService.DelUserPlan(ctx, ruleId); err != nil {
			s.logger.Warn("回滚删除套餐失败", zap.Int64("套餐规则Id", ruleId), zap.Error(err))
		}
	}
	if groupId != 0 {
		if err := s.cdnService.DelServerGroup(ctx, groupId); err != nil {
			s.logger.Warn("回滚删除网站分组失败", zap.Int64("网站分组Id", groupId), zap.Error(err))
		}
	}
	if gatewayGroupId != 0 {
		// NOTE(review): if EditGatewayGroup uses GORM Updates with a struct, the
		// zero HostId is skipped and this reset is a no-op — verify.
		if err := s.gateWayGroupRep.EditGatewayGroup(ctx, &model.GatewayGroup{
			Id:     gatewayGroupId,
			HostId: 0,
		}); err != nil {
			s.logger.Warn("回滚释放网关组失败", zap.Int("网关组Id", gatewayGroupId), zap.Error(err))
		}
	}
}
  // EditGlobalLimit renews the CDN plan bound to req.HostId's global limit until
// the host's next due date, then stores the new expiry (Unix seconds) and
// req.Comment on the global-limit row.
func (s *globalLimitService) EditGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error {
	// Load the record first: it carries the plan (RuleId) to renew, and failing
	// here avoids any remote calls for hosts without a global limit.
	data, err := s.globalLimitRepository.GetGlobalLimitByHostId(ctx, int64(req.HostId))
	if err != nil {
		return err
	}

	// Only the due date is needed. GlobalLimitRequire would also fetch the host
	// config, domain, user info and host name, none of which is used here and
	// any of which could make the renewal fail for unrelated reasons.
	dueDate, err := s.duedate.NextDueDate(ctx, req.Uid, req.HostId)
	if err != nil {
		return err
	}
	// Parse both representations before the remote call so a malformed date
	// cannot leave the plan renewed but the row stale.
	dayTo, err := s.ConversionTime(ctx, dueDate)
	if err != nil {
		return err
	}
	expiredAt, err := s.ConversionTimeUnix(ctx, dueDate)
	if err != nil {
		return err
	}

	if err := s.cdnService.RenewPlan(ctx, v1.RenewalPlan{
		UserPlanId:  int64(data.RuleId),
		DayTo:       dayTo,
		Period:      "monthly",
		CountPeriod: 1,
		IsFree:      true,
		PeriodDayTo: dayTo,
	}); err != nil {
		return err
	}

	// NOTE(review): if UpdateGlobalLimitByHostId uses GORM Updates with a struct,
	// an empty Comment is skipped, so a comment cannot be cleared — verify.
	return s.globalLimitRepository.UpdateGlobalLimitByHostId(ctx, &model.GlobalLimit{
		HostId:    req.HostId,
		Comment:   req.Comment,
		ExpiredAt: expiredAt,
	})
}
  // DeleteGlobalLimit removes req.HostId's global limit and what was provisioned
// for it: TCP/UDP/web forwarding rules, the CDN plan and CDN group, the
// allow/deny IP entries, the global-limit row and the gateway-group binding.
//
// The global-limit row is deleted only after the forwarding rules, CDN
// resources and allow/deny entries are gone. If any of those steps fails, a
// retry can still find the row and the ids it references.
func (s *globalLimitService) DeleteGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error {
	// NOTE(review): the error below reads "instance not expired, cannot delete"
	// but is returned when CheckExpired reports true. Confirm CheckExpired's
	// contract; if true means "expired", this condition is inverted.
	isExpired, err := s.host.CheckExpired(ctx, int64(req.Uid), int64(req.HostId))
	if err != nil {
		return err
	}
	if isExpired {
		return fmt.Errorf("实例未过期,无法删除")
	}

	oldData, err := s.globalLimitRepository.GetGlobalLimitByHostId(ctx, int64(req.HostId))
	if err != nil {
		return err
	}

	// Collect the ids of everything attached to the host before deleting anything.
	tcpIds, err := s.tcpforwardingRep.GetTcpForwardingAllIdsByID(ctx, req.HostId)
	if err != nil {
		return err
	}
	udpIds, err := s.udpForWardingRep.GetUdpForwardingWafUdpAllIds(ctx, req.HostId)
	if err != nil {
		return err
	}
	webIds, err := s.webForWardingRep.GetWebForwardingWafWebAllIds(ctx, req.HostId)
	if err != nil {
		return err
	}
	bwIds, err := s.allowAndDenyRep.GetIpCountListId(ctx, int64(req.HostId))
	if err != nil {
		return err
	}

	// Independent remote clean-up runs concurrently. Every goroutine uses gCtx so
	// the remaining ones are cancelled as soon as one of them fails.
	g, gCtx := errgroup.WithContext(ctx)
	g.Go(func() error {
		if e := s.tcpforwarding.DeleteTcpForwarding(gCtx, v1.DeleteTcpForwardingRequest{
			Ids:    tcpIds,
			Uid:    req.Uid,
			HostId: req.HostId,
		}); e != nil {
			return fmt.Errorf("删除TCP转发失败: %w", e)
		}
		return nil
	})
	g.Go(func() error {
		if e := s.udpForWarding.DeleteUdpForwarding(gCtx, udpIds); e != nil {
			return fmt.Errorf("删除UDP转发失败: %w", e)
		}
		return nil
	})
	g.Go(func() error {
		if e := s.webForWarding.DeleteWebForwarding(gCtx, webIds); e != nil {
			return fmt.Errorf("删除WEB转发失败: %w", e)
		}
		return nil
	})
	// CDN plan bound to this global limit.
	g.Go(func() error {
		if e := s.cdnService.DelUserPlan(gCtx, int64(oldData.RuleId)); e != nil {
			return fmt.Errorf("删除套餐失败: %w", e)
		}
		return nil
	})
	// CDN group created for this global limit.
	g.Go(func() error {
		if e := s.cdnService.DelServerGroup(gCtx, int64(oldData.GroupId)); e != nil {
			return fmt.Errorf("删除网站分组失败: %w", e)
		}
		return nil
	})
	if err := g.Wait(); err != nil {
		return err
	}

	// Allow/deny entries are removed while the row and gateway binding still
	// exist, so a failure here remains retryable.
	if err := s.allowAndDeny.DeleteAllowAndDenyIps(ctx, v1.DelAllowAndDenyIpRequest{
		HostId: req.HostId,
		Ids:    bwIds,
		Uid:    req.Uid,
	}); err != nil {
		return err
	}

	if err := s.globalLimitRepository.DeleteGlobalLimitByHostId(ctx, int64(req.HostId)); err != nil {
		return err
	}
	// Release the gateway group so it can be picked for another host.
	// NOTE(review): if EditGatewayGroup uses GORM Updates with a struct, the
	// zero HostId is skipped and the group is never released — verify.
	if err := s.gateWayGroupRep.EditGatewayGroup(ctx, &model.GatewayGroup{
		Id:     oldData.GatewayGroupId,
		HostId: 0,
	}); err != nil {
		return err
	}
	return nil
}