globallimit.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553
  1. package waf
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
  7. "github.com/go-nunu/nunu-layout-advanced/internal/model"
  8. "github.com/go-nunu/nunu-layout-advanced/internal/repository"
  9. flexCdn2 "github.com/go-nunu/nunu-layout-advanced/internal/repository/api/flexCdn"
  10. "github.com/go-nunu/nunu-layout-advanced/internal/repository/api/waf"
  11. "github.com/go-nunu/nunu-layout-advanced/internal/service"
  12. "github.com/go-nunu/nunu-layout-advanced/internal/service/api/flexCdn"
  13. "github.com/mozillazg/go-pinyin"
  14. "github.com/spf13/viper"
  15. "go.uber.org/zap"
  16. "golang.org/x/sync/errgroup"
  17. "gorm.io/gorm"
  18. "strconv"
  19. "strings"
  20. "time"
  21. )
  22. type GlobalLimitService interface {
  23. GetGlobalLimit(ctx context.Context, id int64) (*model.GlobalLimit, error)
  24. AddGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error
  25. EditGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error
  26. DeleteGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error
  27. GlobalLimitRequire(ctx context.Context, req v1.GlobalLimitRequest) (res v1.GlobalLimitRequireResponse, err error)
  28. }
  29. func NewGlobalLimitService(
  30. service *service.Service,
  31. globalLimitRepository waf.GlobalLimitRepository,
  32. duedate service.DuedateService,
  33. crawler service.CrawlerService,
  34. conf *viper.Viper,
  35. required service.RequiredService,
  36. parser service.ParserService,
  37. host service.HostService,
  38. hostRep repository.HostRepository,
  39. cdnService flexCdn.CdnService,
  40. cdnRep flexCdn2.CdnRepository,
  41. tcpforwardingRep waf.TcpforwardingRepository,
  42. udpForWardingRep waf.UdpForWardingRepository,
  43. webForWardingRep waf.WebForwardingRepository,
  44. allowAndDeny AllowAndDenyIpService,
  45. allowAndDenyRep waf.AllowAndDenyIpRepository,
  46. tcpforwarding TcpforwardingService,
  47. udpForWarding UdpForWardingService,
  48. webForWarding WebForwardingService,
  49. gatewayIpRep waf.GatewayipRepository,
  50. gatywayIp GatewayipService,
  51. ) GlobalLimitService {
  52. return &globalLimitService{
  53. Service: service,
  54. globalLimitRepository: globalLimitRepository,
  55. duedate: duedate,
  56. crawler: crawler,
  57. Url: conf.GetString("crawler.Url"),
  58. required: required,
  59. parser: parser,
  60. host: host,
  61. hostRep: hostRep,
  62. cdnService: cdnService,
  63. cdnRep: cdnRep,
  64. tcpforwardingRep: tcpforwardingRep,
  65. udpForWardingRep: udpForWardingRep,
  66. webForWardingRep: webForWardingRep,
  67. allowAndDeny: allowAndDeny,
  68. allowAndDenyRep: allowAndDenyRep,
  69. tcpforwarding: tcpforwarding,
  70. udpForWarding: udpForWarding,
  71. webForWarding: webForWarding,
  72. gatewayIpRep: gatewayIpRep,
  73. gatewayIp: gatywayIp,
  74. }
  75. }
  76. type globalLimitService struct {
  77. *service.Service
  78. globalLimitRepository waf.GlobalLimitRepository
  79. duedate service.DuedateService
  80. crawler service.CrawlerService
  81. Url string
  82. required service.RequiredService
  83. parser service.ParserService
  84. host service.HostService
  85. hostRep repository.HostRepository
  86. cdnService flexCdn.CdnService
  87. cdnRep flexCdn2.CdnRepository
  88. tcpforwardingRep waf.TcpforwardingRepository
  89. udpForWardingRep waf.UdpForWardingRepository
  90. webForWardingRep waf.WebForwardingRepository
  91. allowAndDeny AllowAndDenyIpService
  92. allowAndDenyRep waf.AllowAndDenyIpRepository
  93. tcpforwarding TcpforwardingService
  94. udpForWarding UdpForWardingService
  95. webForWarding WebForwardingService
  96. gatewayIpRep waf.GatewayipRepository
  97. gatewayIp GatewayipService
  98. }
  99. func (s *globalLimitService) GetCdnUserId(ctx context.Context, uid int64) (int64, error) {
  100. data, err := s.globalLimitRepository.GetGlobalLimitFirst(ctx, uid)
  101. if err != nil {
  102. if !errors.Is(err, gorm.ErrRecordNotFound) {
  103. return 0, err
  104. }
  105. }
  106. if data != nil && data.CdnUid != 0 {
  107. return int64(data.CdnUid), nil
  108. }
  109. userInfo,err := s.globalLimitRepository.GetUserInfo(ctx, uid)
  110. if err != nil {
  111. return 0, err
  112. }
  113. // 中文转拼音
  114. a := pinyin.NewArgs()
  115. a.Style = pinyin.Normal
  116. pinyinSlice := pinyin.LazyPinyin(userInfo.Username, a)
  117. userName := strconv.Itoa(int(uid)) + "_" + strings.Join(pinyinSlice, "_")
  118. // 查询用户是否存在
  119. UserId,err := s.cdnRep.GetUserId(ctx, userName)
  120. if err != nil {
  121. return 0, err
  122. }
  123. if UserId != 0 {
  124. return UserId, nil
  125. }
  126. // 注册用户
  127. userId, err := s.cdnService.AddUser(ctx, v1.User{
  128. Username: userName,
  129. Email: userInfo.Email,
  130. Fullname: userInfo.Username,
  131. Mobile: userInfo.PhoneNumber,
  132. })
  133. if err != nil {
  134. return 0, err
  135. }
  136. return userId, nil
  137. }
  138. func (s *globalLimitService) AddGroupId(ctx context.Context,groupName string) (int64, error) {
  139. groupId, err := s.cdnService.CreateGroup(ctx, v1.Group{
  140. Name: groupName,
  141. })
  142. if err != nil {
  143. return 0, err
  144. }
  145. return groupId, nil
  146. }
  147. func (s *globalLimitService) GlobalLimitRequire(ctx context.Context, req v1.GlobalLimitRequest) (res v1.GlobalLimitRequireResponse, err error) {
  148. res.ExpiredAt, err = s.duedate.NextDueDate(ctx, req.Uid, req.HostId)
  149. if err != nil {
  150. return v1.GlobalLimitRequireResponse{}, err
  151. }
  152. configCount, err := s.host.GetGlobalLimitConfig(ctx, req.HostId)
  153. if err != nil {
  154. return v1.GlobalLimitRequireResponse{}, fmt.Errorf("获取配置限制失败: %w", err)
  155. }
  156. bpsInt, err := strconv.Atoi(configCount.Bps)
  157. if err != nil {
  158. return v1.GlobalLimitRequireResponse{}, err
  159. }
  160. resultFloat := float64(bpsInt) / 2.0 / 8.0
  161. res.Bps = strconv.FormatFloat( resultFloat, 'f', -1, 64) + "M"
  162. res.MaxBytesMonth = configCount.MaxBytesMonth
  163. res.Operator = configCount.Operator
  164. res.IpCount = configCount.IpCount
  165. res.NodeArea = configCount.NodeArea
  166. res.ConfigMaxProtection = configCount.ConfigMaxProtection
  167. res.IsBanUdp = configCount.IsBanUdp
  168. res.HostId = req.HostId
  169. domain, err := s.hostRep.GetDomainById(ctx, req.HostId)
  170. if err != nil {
  171. return v1.GlobalLimitRequireResponse{}, err
  172. }
  173. userInfo,err := s.globalLimitRepository.GetUserInfo(ctx, int64(req.Uid))
  174. if err != nil {
  175. return v1.GlobalLimitRequireResponse{}, err
  176. }
  177. res.GlobalLimitName = strconv.Itoa(req.Uid) + "_" + userInfo.Username + "_" + strconv.Itoa(req.HostId) + "_" + domain
  178. res.HostName, err = s.globalLimitRepository.GetHostName(ctx, int64(req.HostId))
  179. if err != nil {
  180. return v1.GlobalLimitRequireResponse{}, err
  181. }
  182. return res, nil
  183. }
  184. func (s *globalLimitService) GetGlobalLimit(ctx context.Context, id int64) (*model.GlobalLimit, error) {
  185. return s.globalLimitRepository.GetGlobalLimit(ctx, id)
  186. }
  187. func (s *globalLimitService) ConversionTime(ctx context.Context,req string) (string, error) {
  188. // 2. 将字符串解析成 time.Time 对象
  189. // time.Parse 会根据你提供的布局来理解输入的字符串
  190. t, err := time.Parse("2006-01-02 15:04:05", req)
  191. if err != nil {
  192. // 如果输入的字符串格式和布局不匹配,这里会报错
  193. return "", fmt.Errorf("输入的字符串格式和布局不匹配 %w", err)
  194. }
  195. // 3. 定义新的输出格式 "YYYY-MM-DD"
  196. outputLayout := "2006-01-02"
  197. // 4. 将 time.Time 对象格式化为新的字符串
  198. outputTimeStr := t.Format(outputLayout)
  199. return outputTimeStr, nil
  200. }
  201. func (s *globalLimitService) ConversionTimeUnix(ctx context.Context,req string) (int64, error) {
  202. t, err := time.Parse("2006-01-02 15:04:05", req)
  203. if err != nil {
  204. return 0, fmt.Errorf("输入的字符串格式和布局不匹配 %w", err)
  205. }
  206. expiredAt := t.Unix()
  207. return expiredAt, nil
  208. }
  209. func (s *globalLimitService) AddGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error {
  210. isExist, err := s.globalLimitRepository.IsGlobalLimitExistByHostId(ctx, int64(req.HostId))
  211. if err != nil {
  212. return err
  213. }
  214. if isExist {
  215. return fmt.Errorf("配置限制已存在")
  216. }
  217. require, err := s.GlobalLimitRequire(ctx, req)
  218. if err != nil {
  219. return err
  220. }
  221. g, gCtx := errgroup.WithContext(ctx)
  222. var userId int64
  223. var groupId int64
  224. g.Go(func() error {
  225. e := s.gatewayIp.AddIpWhereHostIdNull(gCtx, int64(req.HostId),int64(req.Uid))
  226. if e != nil {
  227. return fmt.Errorf("获取网关组失败: %w", e)
  228. }
  229. return nil
  230. })
  231. g.Go(func() error {
  232. res, e := s.GetCdnUserId(gCtx, int64(req.Uid))
  233. if e != nil {
  234. return fmt.Errorf("获取cdn用户失败: %w", e)
  235. }
  236. if res == 0 {
  237. return fmt.Errorf("获取cdn用户失败")
  238. }
  239. userId = res
  240. return nil
  241. })
  242. g.Go(func() error {
  243. res, e := s.AddGroupId(gCtx, require.GlobalLimitName)
  244. if e != nil {
  245. return fmt.Errorf("创建规则分组失败: %w", e)
  246. }
  247. if res == 0 {
  248. return fmt.Errorf("创建规则分组失败")
  249. }
  250. groupId = res
  251. return nil
  252. })
  253. if err = g.Wait(); err != nil {
  254. return err
  255. }
  256. outputTimeStr, err := s.ConversionTime(ctx, require.ExpiredAt)
  257. if err != nil {
  258. return err
  259. }
  260. // 获取套餐ID
  261. maxProtection := strings.TrimSuffix(require.ConfigMaxProtection, "G")
  262. if maxProtection == "" {
  263. return fmt.Errorf("无效的配置 ConfigMaxProtection: '%s',数字部分为空", require.ConfigMaxProtection)
  264. }
  265. maxProtectionInt, err := strconv.Atoi(maxProtection)
  266. if err != nil {
  267. return fmt.Errorf("无效的配置 ConfigMaxProtection: '%s',无法转换为数字", require.ConfigMaxProtection)
  268. }
  269. var planId int64
  270. maxProtectionNum := 1
  271. if maxProtectionInt >= 2000 {
  272. maxProtectionNum = maxProtectionInt / 1000
  273. }
  274. NodeAreaName := fmt.Sprintf("%s-%dT",require.NodeArea, maxProtectionNum)
  275. planId, err = s.globalLimitRepository.GetNodeArea(ctx, NodeAreaName)
  276. if err != nil {
  277. if errors.Is(err, gorm.ErrRecordNotFound) {
  278. planId = 0
  279. }else {
  280. return err
  281. }
  282. }
  283. if planId == 0 {
  284. // 安全冗余套餐
  285. planId = 6
  286. s.Logger.Warn("获取套餐Id失败", zap.String("节点区域", NodeAreaName), zap.String("防御阈值", require.ConfigMaxProtection),zap.Int64("套餐Id", int64(req.Uid)),zap.Int64("魔方套餐Id", int64(req.HostId)))
  287. }
  288. ruleId, err := s.cdnService.BindPlan(ctx, v1.Plan{
  289. UserId: userId,
  290. PlanId: planId,
  291. DayTo: outputTimeStr,
  292. Name: require.GlobalLimitName,
  293. IsFree: true,
  294. Period: "monthly",
  295. CountPeriod: 1,
  296. PeriodDayTo: outputTimeStr,
  297. })
  298. if err != nil {
  299. return err
  300. }
  301. if ruleId == 0 {
  302. return fmt.Errorf("分配套餐失败")
  303. }
  304. expiredAt, err := s.ConversionTimeUnix(ctx, require.ExpiredAt)
  305. if err != nil {
  306. return err
  307. }
  308. // 如果存在实例,恢复
  309. oldData,err := s.globalLimitRepository.GetGlobalLimitByHostId(ctx, int64(req.HostId))
  310. if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
  311. return err
  312. }
  313. if oldData!= nil && oldData.Id != 0 {
  314. err = s.globalLimitRepository.UpdateGlobalLimitByHostId(ctx, &model.GlobalLimit{
  315. HostId: req.HostId,
  316. Uid: req.Uid,
  317. Name: require.GlobalLimitName,
  318. RuleId: int(ruleId),
  319. GroupId: int(groupId),
  320. CdnUid: int(userId),
  321. Comment: req.Comment,
  322. ExpiredAt: expiredAt,
  323. State: true,
  324. })
  325. if err != nil {
  326. return err
  327. }
  328. return nil
  329. }
  330. // 如果不存在实例,创建
  331. err = s.globalLimitRepository.AddGlobalLimit(ctx, &model.GlobalLimit{
  332. HostId: req.HostId,
  333. Uid: req.Uid,
  334. Name: require.GlobalLimitName,
  335. RuleId: int(ruleId),
  336. GroupId: int(groupId),
  337. CdnUid: int(userId),
  338. Comment: req.Comment,
  339. State: true,
  340. ExpiredAt: expiredAt,
  341. })
  342. if err != nil {
  343. return err
  344. }
  345. return nil
  346. }
  347. func (s *globalLimitService) EditGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error {
  348. require, err := s.GlobalLimitRequire(ctx, req)
  349. if err != nil {
  350. return err
  351. }
  352. data, err := s.globalLimitRepository.GetGlobalLimitByHostId(ctx, int64(req.HostId))
  353. if err != nil {
  354. return err
  355. }
  356. // 如果不存在实例,创建
  357. gatewayIp, err := s.gatewayIpRep.GetGatewayipByHostIdAll(ctx, int64(req.HostId))
  358. if err != nil {
  359. return err
  360. }
  361. if gatewayIp == nil {
  362. err = s.gatewayIp.AddIpWhereHostIdNull(ctx, int64(req.HostId), int64(req.Uid))
  363. if err != nil {
  364. return fmt.Errorf("获取网关组失败: %w", err)
  365. }
  366. }
  367. outputTimeStr, err := s.ConversionTime(ctx, require.ExpiredAt)
  368. if err != nil {
  369. return err
  370. }
  371. err = s.cdnService.RenewPlan(ctx, v1.RenewalPlan{
  372. UserPlanId: int64(data.RuleId),
  373. DayTo: outputTimeStr,
  374. Period: "monthly",
  375. CountPeriod: 1,
  376. IsFree: true,
  377. PeriodDayTo: outputTimeStr,
  378. })
  379. if err != nil {
  380. return err
  381. }
  382. expiredAt, err := s.ConversionTimeUnix(ctx, require.ExpiredAt)
  383. if err != nil {
  384. return err
  385. }
  386. if err := s.globalLimitRepository.UpdateGlobalLimitByHostId(ctx, &model.GlobalLimit{
  387. HostId: req.HostId,
  388. Comment: req.Comment,
  389. ExpiredAt: expiredAt,
  390. }); err != nil {
  391. return err
  392. }
  393. return nil
  394. }
  395. func (s *globalLimitService) DeleteGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error {
  396. // 检查是否过期
  397. isExpired, err := s.host.CheckExpired(ctx, int64(req.Uid), int64(req.HostId))
  398. if err != nil {
  399. return err
  400. }
  401. if isExpired {
  402. return fmt.Errorf("实例未过期,无法删除")
  403. }
  404. oldData, err := s.globalLimitRepository.GetGlobalLimitByHostId(ctx, int64(req.HostId))
  405. if err != nil {
  406. if errors.Is(err, gorm.ErrRecordNotFound) {
  407. return fmt.Errorf("实例不存在")
  408. }
  409. return err
  410. }
  411. tcpIds, err := s.tcpforwardingRep.GetTcpForwardingAllIdsByID(ctx,req.HostId)
  412. if err != nil {
  413. return err
  414. }
  415. udpIds, err := s.udpForWardingRep.GetUdpForwardingWafUdpAllIds(ctx,req.HostId)
  416. if err != nil {
  417. return err
  418. }
  419. webIds, err := s.webForWardingRep.GetWebForwardingWafWebAllIds(ctx,req.HostId)
  420. if err != nil {
  421. return err
  422. }
  423. // 黑白IP
  424. BwIds, err := s.allowAndDenyRep.GetIpCountListId(ctx, int64(req.HostId))
  425. if err != nil {
  426. return err
  427. }
  428. // 删除网站
  429. g, gCtx := errgroup.WithContext(ctx)
  430. g.Go(func() error {
  431. e := s.tcpforwarding.DeleteTcpForwarding(ctx,v1.DeleteTcpForwardingRequest{
  432. Ids: tcpIds,
  433. Uid: req.Uid,
  434. HostId: req.HostId,
  435. })
  436. if e != nil {
  437. return fmt.Errorf("删除TCP转发失败: %w", e)
  438. }
  439. return nil
  440. })
  441. g.Go(func() error {
  442. e := s.udpForWarding.DeleteUdpForwarding(ctx, v1.DeleteUdpForwardingRequest{
  443. Ids: udpIds,
  444. Uid: req.Uid,
  445. HostId: req.HostId,
  446. })
  447. if e != nil {
  448. return fmt.Errorf("删除UDP转发失败: %w", e)
  449. }
  450. return nil
  451. })
  452. g.Go(func() error {
  453. e := s.webForWarding.DeleteWebForwarding(ctx,v1.DeleteWebForwardingRequest{
  454. Ids: webIds,
  455. Uid: req.Uid,
  456. HostId: req.HostId,
  457. })
  458. if e != nil {
  459. return fmt.Errorf("删除WEB转发失败: %w", e)
  460. }
  461. return nil
  462. })
  463. // 删除套餐
  464. g.Go(func() error {
  465. e := s.cdnService.DelUserPlan(gCtx, int64(oldData.RuleId))
  466. if e != nil {
  467. return fmt.Errorf("删除套餐失败: %w", e)
  468. }
  469. return nil
  470. })
  471. // 删除网站分组
  472. g.Go(func() error {
  473. e := s.cdnService.DelServerGroup(gCtx, int64(oldData.GroupId))
  474. if e != nil {
  475. return fmt.Errorf("删除网站分组失败: %w", e)
  476. }
  477. return nil
  478. })
  479. if err = g.Wait(); err != nil {
  480. return err
  481. }
  482. if err := s.globalLimitRepository.EditHostState(ctx, int64(req.HostId), false); err != nil {
  483. return err
  484. }
  485. if err := s.gatewayIpRep.CleanIPByHostId(ctx, []int64{int64(req.HostId)}); err != nil {
  486. return err
  487. }
  488. // 删除黑白名单
  489. err = s.allowAndDeny.DeleteAllowAndDenyIps(ctx, v1.DelAllowAndDenyIpRequest{
  490. HostId: req.HostId,
  491. Ids: BwIds,
  492. Uid: req.Uid,
  493. })
  494. if err != nil {
  495. return err
  496. }
  497. return nil
  498. }