globallimit.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551
  1. package waf
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
  7. "github.com/go-nunu/nunu-layout-advanced/internal/model"
  8. "github.com/go-nunu/nunu-layout-advanced/internal/repository"
  9. "github.com/go-nunu/nunu-layout-advanced/internal/repository/api/waf"
  10. "github.com/go-nunu/nunu-layout-advanced/internal/service"
  11. "github.com/mozillazg/go-pinyin"
  12. "github.com/spf13/viper"
  13. "go.uber.org/zap"
  14. "golang.org/x/sync/errgroup"
  15. "gorm.io/gorm"
  16. "strconv"
  17. "strings"
  18. "time"
  19. )
  20. type GlobalLimitService interface {
  21. GetGlobalLimit(ctx context.Context, id int64) (*model.GlobalLimit, error)
  22. AddGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error
  23. EditGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error
  24. DeleteGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error
  25. GlobalLimitRequire(ctx context.Context, req v1.GlobalLimitRequest) (res v1.GlobalLimitRequireResponse, err error)
  26. }
  27. func NewGlobalLimitService(
  28. service *service.Service,
  29. globalLimitRepository waf.GlobalLimitRepository,
  30. duedate service.DuedateService,
  31. crawler service.CrawlerService,
  32. conf *viper.Viper,
  33. required service.RequiredService,
  34. parser service.ParserService,
  35. host service.HostService,
  36. hostRep repository.HostRepository,
  37. cdnService service.CdnService,
  38. cdnRep repository.CdnRepository,
  39. tcpforwardingRep waf.TcpforwardingRepository,
  40. udpForWardingRep waf.UdpForWardingRepository,
  41. webForWardingRep waf.WebForwardingRepository,
  42. allowAndDeny AllowAndDenyIpService,
  43. allowAndDenyRep waf.AllowAndDenyIpRepository,
  44. tcpforwarding TcpforwardingService,
  45. udpForWarding UdpForWardingService,
  46. webForWarding WebForwardingService,
  47. gatewayIpRep waf.GatewayipRepository,
  48. gatywayIp GatewayipService,
  49. ) GlobalLimitService {
  50. return &globalLimitService{
  51. Service: service,
  52. globalLimitRepository: globalLimitRepository,
  53. duedate: duedate,
  54. crawler: crawler,
  55. Url: conf.GetString("crawler.Url"),
  56. required: required,
  57. parser: parser,
  58. host: host,
  59. hostRep: hostRep,
  60. cdnService: cdnService,
  61. cdnRep: cdnRep,
  62. tcpforwardingRep: tcpforwardingRep,
  63. udpForWardingRep: udpForWardingRep,
  64. webForWardingRep: webForWardingRep,
  65. allowAndDeny: allowAndDeny,
  66. allowAndDenyRep: allowAndDenyRep,
  67. tcpforwarding: tcpforwarding,
  68. udpForWarding: udpForWarding,
  69. webForWarding: webForWarding,
  70. gatewayIpRep: gatewayIpRep,
  71. gatewayIp: gatywayIp,
  72. }
  73. }
  74. type globalLimitService struct {
  75. *service.Service
  76. globalLimitRepository waf.GlobalLimitRepository
  77. duedate service.DuedateService
  78. crawler service.CrawlerService
  79. Url string
  80. required service.RequiredService
  81. parser service.ParserService
  82. host service.HostService
  83. hostRep repository.HostRepository
  84. cdnService service.CdnService
  85. cdnRep repository.CdnRepository
  86. tcpforwardingRep waf.TcpforwardingRepository
  87. udpForWardingRep waf.UdpForWardingRepository
  88. webForWardingRep waf.WebForwardingRepository
  89. allowAndDeny AllowAndDenyIpService
  90. allowAndDenyRep waf.AllowAndDenyIpRepository
  91. tcpforwarding TcpforwardingService
  92. udpForWarding UdpForWardingService
  93. webForWarding WebForwardingService
  94. gatewayIpRep waf.GatewayipRepository
  95. gatewayIp GatewayipService
  96. }
  97. func (s *globalLimitService) GetCdnUserId(ctx context.Context, uid int64) (int64, error) {
  98. data, err := s.globalLimitRepository.GetGlobalLimitFirst(ctx, uid)
  99. if err != nil {
  100. if !errors.Is(err, gorm.ErrRecordNotFound) {
  101. return 0, err
  102. }
  103. }
  104. if data != nil && data.CdnUid != 0 {
  105. return int64(data.CdnUid), nil
  106. }
  107. userInfo,err := s.globalLimitRepository.GetUserInfo(ctx, uid)
  108. if err != nil {
  109. return 0, err
  110. }
  111. // 中文转拼音
  112. a := pinyin.NewArgs()
  113. a.Style = pinyin.Normal
  114. pinyinSlice := pinyin.LazyPinyin(userInfo.Username, a)
  115. userName := strconv.Itoa(int(uid)) + "_" + strings.Join(pinyinSlice, "_")
  116. // 查询用户是否存在
  117. UserId,err := s.cdnRep.GetUserId(ctx, userName)
  118. if err != nil {
  119. return 0, err
  120. }
  121. if UserId != 0 {
  122. return UserId, nil
  123. }
  124. // 注册用户
  125. userId, err := s.cdnService.AddUser(ctx, v1.User{
  126. Username: userName,
  127. Email: userInfo.Email,
  128. Fullname: userInfo.Username,
  129. Mobile: userInfo.PhoneNumber,
  130. })
  131. if err != nil {
  132. return 0, err
  133. }
  134. return userId, nil
  135. }
  136. func (s *globalLimitService) AddGroupId(ctx context.Context,groupName string) (int64, error) {
  137. groupId, err := s.cdnService.CreateGroup(ctx, v1.Group{
  138. Name: groupName,
  139. })
  140. if err != nil {
  141. return 0, err
  142. }
  143. return groupId, nil
  144. }
  145. func (s *globalLimitService) GlobalLimitRequire(ctx context.Context, req v1.GlobalLimitRequest) (res v1.GlobalLimitRequireResponse, err error) {
  146. res.ExpiredAt, err = s.duedate.NextDueDate(ctx, req.Uid, req.HostId)
  147. if err != nil {
  148. return v1.GlobalLimitRequireResponse{}, err
  149. }
  150. configCount, err := s.host.GetGlobalLimitConfig(ctx, req.HostId)
  151. if err != nil {
  152. return v1.GlobalLimitRequireResponse{}, fmt.Errorf("获取配置限制失败: %w", err)
  153. }
  154. bpsInt, err := strconv.Atoi(configCount.Bps)
  155. if err != nil {
  156. return v1.GlobalLimitRequireResponse{}, err
  157. }
  158. resultFloat := float64(bpsInt) / 2.0 / 8.0
  159. res.Bps = strconv.FormatFloat( resultFloat, 'f', -1, 64) + "M"
  160. res.MaxBytesMonth = configCount.MaxBytesMonth
  161. res.Operator = configCount.Operator
  162. res.IpCount = configCount.IpCount
  163. res.NodeArea = configCount.NodeArea
  164. res.ConfigMaxProtection = configCount.ConfigMaxProtection
  165. res.IsBanUdp = configCount.IsBanUdp
  166. res.HostId = req.HostId
  167. domain, err := s.hostRep.GetDomainById(ctx, req.HostId)
  168. if err != nil {
  169. return v1.GlobalLimitRequireResponse{}, err
  170. }
  171. userInfo,err := s.globalLimitRepository.GetUserInfo(ctx, int64(req.Uid))
  172. if err != nil {
  173. return v1.GlobalLimitRequireResponse{}, err
  174. }
  175. res.GlobalLimitName = strconv.Itoa(req.Uid) + "_" + userInfo.Username + "_" + strconv.Itoa(req.HostId) + "_" + domain
  176. res.HostName, err = s.globalLimitRepository.GetHostName(ctx, int64(req.HostId))
  177. if err != nil {
  178. return v1.GlobalLimitRequireResponse{}, err
  179. }
  180. return res, nil
  181. }
  182. func (s *globalLimitService) GetGlobalLimit(ctx context.Context, id int64) (*model.GlobalLimit, error) {
  183. return s.globalLimitRepository.GetGlobalLimit(ctx, id)
  184. }
  185. func (s *globalLimitService) ConversionTime(ctx context.Context,req string) (string, error) {
  186. // 2. 将字符串解析成 time.Time 对象
  187. // time.Parse 会根据你提供的布局来理解输入的字符串
  188. t, err := time.Parse("2006-01-02 15:04:05", req)
  189. if err != nil {
  190. // 如果输入的字符串格式和布局不匹配,这里会报错
  191. return "", fmt.Errorf("输入的字符串格式和布局不匹配 %w", err)
  192. }
  193. // 3. 定义新的输出格式 "YYYY-MM-DD"
  194. outputLayout := "2006-01-02"
  195. // 4. 将 time.Time 对象格式化为新的字符串
  196. outputTimeStr := t.Format(outputLayout)
  197. return outputTimeStr, nil
  198. }
  199. func (s *globalLimitService) ConversionTimeUnix(ctx context.Context,req string) (int64, error) {
  200. t, err := time.Parse("2006-01-02 15:04:05", req)
  201. if err != nil {
  202. return 0, fmt.Errorf("输入的字符串格式和布局不匹配 %w", err)
  203. }
  204. expiredAt := t.Unix()
  205. return expiredAt, nil
  206. }
  207. func (s *globalLimitService) AddGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error {
  208. isExist, err := s.globalLimitRepository.IsGlobalLimitExistByHostId(ctx, int64(req.HostId))
  209. if err != nil {
  210. return err
  211. }
  212. if isExist {
  213. return fmt.Errorf("配置限制已存在")
  214. }
  215. require, err := s.GlobalLimitRequire(ctx, req)
  216. if err != nil {
  217. return err
  218. }
  219. g, gCtx := errgroup.WithContext(ctx)
  220. var userId int64
  221. var groupId int64
  222. g.Go(func() error {
  223. e := s.gatewayIp.AddIpWhereHostIdNull(gCtx, int64(req.HostId),int64(req.Uid))
  224. if e != nil {
  225. return fmt.Errorf("获取网关组失败: %w", e)
  226. }
  227. return nil
  228. })
  229. g.Go(func() error {
  230. res, e := s.GetCdnUserId(gCtx, int64(req.Uid))
  231. if e != nil {
  232. return fmt.Errorf("获取cdn用户失败: %w", e)
  233. }
  234. if res == 0 {
  235. return fmt.Errorf("获取cdn用户失败")
  236. }
  237. userId = res
  238. return nil
  239. })
  240. g.Go(func() error {
  241. res, e := s.AddGroupId(gCtx, require.GlobalLimitName)
  242. if e != nil {
  243. return fmt.Errorf("创建规则分组失败: %w", e)
  244. }
  245. if res == 0 {
  246. return fmt.Errorf("创建规则分组失败")
  247. }
  248. groupId = res
  249. return nil
  250. })
  251. if err = g.Wait(); err != nil {
  252. return err
  253. }
  254. outputTimeStr, err := s.ConversionTime(ctx, require.ExpiredAt)
  255. if err != nil {
  256. return err
  257. }
  258. // 获取套餐ID
  259. maxProtection := strings.TrimSuffix(require.ConfigMaxProtection, "G")
  260. if maxProtection == "" {
  261. return fmt.Errorf("无效的配置 ConfigMaxProtection: '%s',数字部分为空", require.ConfigMaxProtection)
  262. }
  263. maxProtectionInt, err := strconv.Atoi(maxProtection)
  264. if err != nil {
  265. return fmt.Errorf("无效的配置 ConfigMaxProtection: '%s',无法转换为数字", require.ConfigMaxProtection)
  266. }
  267. var planId int64
  268. maxProtectionNum := 1
  269. if maxProtectionInt >= 2000 {
  270. maxProtectionNum = maxProtectionInt / 1000
  271. }
  272. NodeAreaName := fmt.Sprintf("%s-%dT",require.NodeArea, maxProtectionNum)
  273. planId, err = s.globalLimitRepository.GetNodeArea(ctx, NodeAreaName)
  274. if err != nil {
  275. if errors.Is(err, gorm.ErrRecordNotFound) {
  276. planId = 0
  277. }else {
  278. return err
  279. }
  280. }
  281. if planId == 0 {
  282. // 安全冗余套餐
  283. planId = 6
  284. s.Logger.Warn("获取套餐Id失败", zap.String("节点区域", NodeAreaName), zap.String("防御阈值", require.ConfigMaxProtection),zap.Int64("套餐Id", int64(req.Uid)),zap.Int64("魔方套餐Id", int64(req.HostId)))
  285. }
  286. ruleId, err := s.cdnService.BindPlan(ctx, v1.Plan{
  287. UserId: userId,
  288. PlanId: planId,
  289. DayTo: outputTimeStr,
  290. Name: require.GlobalLimitName,
  291. IsFree: true,
  292. Period: "monthly",
  293. CountPeriod: 1,
  294. PeriodDayTo: outputTimeStr,
  295. })
  296. if err != nil {
  297. return err
  298. }
  299. if ruleId == 0 {
  300. return fmt.Errorf("分配套餐失败")
  301. }
  302. expiredAt, err := s.ConversionTimeUnix(ctx, require.ExpiredAt)
  303. if err != nil {
  304. return err
  305. }
  306. // 如果存在实例,恢复
  307. oldData,err := s.globalLimitRepository.GetGlobalLimitByHostId(ctx, int64(req.HostId))
  308. if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
  309. return err
  310. }
  311. if oldData!= nil && oldData.Id != 0 {
  312. err = s.globalLimitRepository.UpdateGlobalLimitByHostId(ctx, &model.GlobalLimit{
  313. HostId: req.HostId,
  314. Uid: req.Uid,
  315. Name: require.GlobalLimitName,
  316. RuleId: int(ruleId),
  317. GroupId: int(groupId),
  318. CdnUid: int(userId),
  319. Comment: req.Comment,
  320. ExpiredAt: expiredAt,
  321. State: true,
  322. })
  323. if err != nil {
  324. return err
  325. }
  326. return nil
  327. }
  328. // 如果不存在实例,创建
  329. err = s.globalLimitRepository.AddGlobalLimit(ctx, &model.GlobalLimit{
  330. HostId: req.HostId,
  331. Uid: req.Uid,
  332. Name: require.GlobalLimitName,
  333. RuleId: int(ruleId),
  334. GroupId: int(groupId),
  335. CdnUid: int(userId),
  336. Comment: req.Comment,
  337. State: true,
  338. ExpiredAt: expiredAt,
  339. })
  340. if err != nil {
  341. return err
  342. }
  343. return nil
  344. }
  345. func (s *globalLimitService) EditGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error {
  346. require, err := s.GlobalLimitRequire(ctx, req)
  347. if err != nil {
  348. return err
  349. }
  350. data, err := s.globalLimitRepository.GetGlobalLimitByHostId(ctx, int64(req.HostId))
  351. if err != nil {
  352. return err
  353. }
  354. // 如果不存在实例,创建
  355. gatewayIp, err := s.gatewayIpRep.GetGatewayipByHostIdAll(ctx, int64(req.HostId))
  356. if err != nil {
  357. return err
  358. }
  359. if gatewayIp == nil {
  360. err = s.gatewayIp.AddIpWhereHostIdNull(ctx, int64(req.HostId), int64(req.Uid))
  361. if err != nil {
  362. return fmt.Errorf("获取网关组失败: %w", err)
  363. }
  364. }
  365. outputTimeStr, err := s.ConversionTime(ctx, require.ExpiredAt)
  366. if err != nil {
  367. return err
  368. }
  369. err = s.cdnService.RenewPlan(ctx, v1.RenewalPlan{
  370. UserPlanId: int64(data.RuleId),
  371. DayTo: outputTimeStr,
  372. Period: "monthly",
  373. CountPeriod: 1,
  374. IsFree: true,
  375. PeriodDayTo: outputTimeStr,
  376. })
  377. if err != nil {
  378. return err
  379. }
  380. expiredAt, err := s.ConversionTimeUnix(ctx, require.ExpiredAt)
  381. if err != nil {
  382. return err
  383. }
  384. if err := s.globalLimitRepository.UpdateGlobalLimitByHostId(ctx, &model.GlobalLimit{
  385. HostId: req.HostId,
  386. Comment: req.Comment,
  387. ExpiredAt: expiredAt,
  388. }); err != nil {
  389. return err
  390. }
  391. return nil
  392. }
  393. func (s *globalLimitService) DeleteGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error {
  394. // 检查是否过期
  395. isExpired, err := s.host.CheckExpired(ctx, int64(req.Uid), int64(req.HostId))
  396. if err != nil {
  397. return err
  398. }
  399. if isExpired {
  400. return fmt.Errorf("实例未过期,无法删除")
  401. }
  402. oldData, err := s.globalLimitRepository.GetGlobalLimitByHostId(ctx, int64(req.HostId))
  403. if err != nil {
  404. if errors.Is(err, gorm.ErrRecordNotFound) {
  405. return fmt.Errorf("实例不存在")
  406. }
  407. return err
  408. }
  409. tcpIds, err := s.tcpforwardingRep.GetTcpForwardingAllIdsByID(ctx,req.HostId)
  410. if err != nil {
  411. return err
  412. }
  413. udpIds, err := s.udpForWardingRep.GetUdpForwardingWafUdpAllIds(ctx,req.HostId)
  414. if err != nil {
  415. return err
  416. }
  417. webIds, err := s.webForWardingRep.GetWebForwardingWafWebAllIds(ctx,req.HostId)
  418. if err != nil {
  419. return err
  420. }
  421. // 黑白IP
  422. BwIds, err := s.allowAndDenyRep.GetIpCountListId(ctx, int64(req.HostId))
  423. if err != nil {
  424. return err
  425. }
  426. // 删除网站
  427. g, gCtx := errgroup.WithContext(ctx)
  428. g.Go(func() error {
  429. e := s.tcpforwarding.DeleteTcpForwarding(ctx,v1.DeleteTcpForwardingRequest{
  430. Ids: tcpIds,
  431. Uid: req.Uid,
  432. HostId: req.HostId,
  433. })
  434. if e != nil {
  435. return fmt.Errorf("删除TCP转发失败: %w", e)
  436. }
  437. return nil
  438. })
  439. g.Go(func() error {
  440. e := s.udpForWarding.DeleteUdpForwarding(ctx, v1.DeleteUdpForwardingRequest{
  441. Ids: udpIds,
  442. Uid: req.Uid,
  443. HostId: req.HostId,
  444. })
  445. if e != nil {
  446. return fmt.Errorf("删除UDP转发失败: %w", e)
  447. }
  448. return nil
  449. })
  450. g.Go(func() error {
  451. e := s.webForWarding.DeleteWebForwarding(ctx,v1.DeleteWebForwardingRequest{
  452. Ids: webIds,
  453. Uid: req.Uid,
  454. HostId: req.HostId,
  455. })
  456. if e != nil {
  457. return fmt.Errorf("删除WEB转发失败: %w", e)
  458. }
  459. return nil
  460. })
  461. // 删除套餐
  462. g.Go(func() error {
  463. e := s.cdnService.DelUserPlan(gCtx, int64(oldData.RuleId))
  464. if e != nil {
  465. return fmt.Errorf("删除套餐失败: %w", e)
  466. }
  467. return nil
  468. })
  469. // 删除网站分组
  470. g.Go(func() error {
  471. e := s.cdnService.DelServerGroup(gCtx, int64(oldData.GroupId))
  472. if e != nil {
  473. return fmt.Errorf("删除网站分组失败: %w", e)
  474. }
  475. return nil
  476. })
  477. if err = g.Wait(); err != nil {
  478. return err
  479. }
  480. if err := s.globalLimitRepository.EditHostState(ctx, int64(req.HostId), false); err != nil {
  481. return err
  482. }
  483. if err := s.gatewayIpRep.CleanIPByHostId(ctx, []int64{int64(req.HostId)}); err != nil {
  484. return err
  485. }
  486. // 删除黑白名单
  487. err = s.allowAndDeny.DeleteAllowAndDenyIps(ctx, v1.DelAllowAndDenyIpRequest{
  488. HostId: req.HostId,
  489. Ids: BwIds,
  490. Uid: req.Uid,
  491. })
  492. if err != nil {
  493. return err
  494. }
  495. return nil
  496. }