globallimit.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487
  1. package waf
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
  7. "github.com/go-nunu/nunu-layout-advanced/internal/model"
  8. "github.com/go-nunu/nunu-layout-advanced/internal/repository"
  9. flexCdn2 "github.com/go-nunu/nunu-layout-advanced/internal/repository/api/flexCdn"
  10. "github.com/go-nunu/nunu-layout-advanced/internal/repository/api/waf"
  11. "github.com/go-nunu/nunu-layout-advanced/internal/service"
  12. "github.com/go-nunu/nunu-layout-advanced/internal/service/api/flexCdn"
  13. "github.com/mozillazg/go-pinyin"
  14. "github.com/spf13/viper"
  15. "golang.org/x/sync/errgroup"
  16. "gorm.io/gorm"
  17. "strconv"
  18. "strings"
  19. "time"
  20. )
  21. type GlobalLimitService interface {
  22. GetGlobalLimit(ctx context.Context, id int64) (*model.GlobalLimit, error)
  23. AddGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error
  24. EditGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error
  25. DeleteGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error
  26. GlobalLimitRequire(ctx context.Context, req v1.GlobalLimitRequest) (res v1.GlobalLimitRequireResponse, err error)
  27. }
  28. func NewGlobalLimitService(
  29. service *service.Service,
  30. globalLimitRepository waf.GlobalLimitRepository,
  31. duedate service.DuedateService,
  32. crawler service.CrawlerService,
  33. conf *viper.Viper,
  34. required service.RequiredService,
  35. parser service.ParserService,
  36. host service.HostService,
  37. hostRep repository.HostRepository,
  38. cdnService flexCdn.CdnService,
  39. cdnRep flexCdn2.CdnRepository,
  40. tcpforwardingRep waf.TcpforwardingRepository,
  41. udpForWardingRep waf.UdpForWardingRepository,
  42. webForWardingRep waf.WebForwardingRepository,
  43. allowAndDeny AllowAndDenyIpService,
  44. allowAndDenyRep waf.AllowAndDenyIpRepository,
  45. tcpforwarding TcpforwardingService,
  46. udpForWarding UdpForWardingService,
  47. webForWarding WebForwardingService,
  48. gatewayIpRep waf.GatewayipRepository,
  49. gatywayIp GatewayipService,
  50. ) GlobalLimitService {
  51. return &globalLimitService{
  52. Service: service,
  53. globalLimitRepository: globalLimitRepository,
  54. duedate: duedate,
  55. crawler: crawler,
  56. Url: conf.GetString("crawler.Url"),
  57. required: required,
  58. parser: parser,
  59. host: host,
  60. hostRep: hostRep,
  61. cdnService: cdnService,
  62. cdnRep: cdnRep,
  63. tcpforwardingRep: tcpforwardingRep,
  64. udpForWardingRep: udpForWardingRep,
  65. webForWardingRep: webForWardingRep,
  66. allowAndDeny: allowAndDeny,
  67. allowAndDenyRep: allowAndDenyRep,
  68. tcpforwarding: tcpforwarding,
  69. udpForWarding: udpForWarding,
  70. webForWarding: webForWarding,
  71. gatewayIpRep: gatewayIpRep,
  72. gatewayIp: gatywayIp,
  73. }
  74. }
  75. type globalLimitService struct {
  76. *service.Service
  77. globalLimitRepository waf.GlobalLimitRepository
  78. duedate service.DuedateService
  79. crawler service.CrawlerService
  80. Url string
  81. required service.RequiredService
  82. parser service.ParserService
  83. host service.HostService
  84. hostRep repository.HostRepository
  85. cdnService flexCdn.CdnService
  86. cdnRep flexCdn2.CdnRepository
  87. tcpforwardingRep waf.TcpforwardingRepository
  88. udpForWardingRep waf.UdpForWardingRepository
  89. webForWardingRep waf.WebForwardingRepository
  90. allowAndDeny AllowAndDenyIpService
  91. allowAndDenyRep waf.AllowAndDenyIpRepository
  92. tcpforwarding TcpforwardingService
  93. udpForWarding UdpForWardingService
  94. webForWarding WebForwardingService
  95. gatewayIpRep waf.GatewayipRepository
  96. gatewayIp GatewayipService
  97. }
  98. func (s *globalLimitService) GetCdnUserId(ctx context.Context, uid int64) (int64, error) {
  99. data, err := s.globalLimitRepository.GetGlobalLimitFirst(ctx, uid)
  100. if err != nil {
  101. if !errors.Is(err, gorm.ErrRecordNotFound) {
  102. return 0, err
  103. }
  104. }
  105. if data != nil && data.CdnUid != 0 {
  106. return int64(data.CdnUid), nil
  107. }
  108. userInfo,err := s.globalLimitRepository.GetUserInfo(ctx, uid)
  109. if err != nil {
  110. return 0, err
  111. }
  112. // 中文转拼音
  113. a := pinyin.NewArgs()
  114. a.Style = pinyin.Normal
  115. pinyinSlice := pinyin.LazyPinyin(userInfo.Username, a)
  116. userName := strconv.Itoa(int(uid)) + "_" + strings.Join(pinyinSlice, "_")
  117. // 查询用户是否存在
  118. UserId,err := s.cdnRep.GetUserId(ctx, userName)
  119. if err != nil {
  120. return 0, err
  121. }
  122. if UserId != 0 {
  123. return UserId, nil
  124. }
  125. // 注册用户
  126. userId, err := s.cdnService.AddUser(ctx, v1.User{
  127. Username: userName,
  128. Email: userInfo.Email,
  129. Fullname: userInfo.Username,
  130. Mobile: userInfo.PhoneNumber,
  131. })
  132. if err != nil {
  133. return 0, err
  134. }
  135. return userId, nil
  136. }
  137. func (s *globalLimitService) AddGroupId(ctx context.Context,groupName string) (int64, error) {
  138. groupId, err := s.cdnService.CreateGroup(ctx, v1.Group{
  139. Name: groupName,
  140. })
  141. if err != nil {
  142. return 0, err
  143. }
  144. return groupId, nil
  145. }
  146. func (s *globalLimitService) GlobalLimitRequire(ctx context.Context, req v1.GlobalLimitRequest) (res v1.GlobalLimitRequireResponse, err error) {
  147. res.ExpiredAt, err = s.duedate.NextDueDate(ctx, req.Uid, req.HostId)
  148. if err != nil {
  149. return v1.GlobalLimitRequireResponse{}, err
  150. }
  151. configCount, err := s.host.GetGlobalLimitConfig(ctx, req.HostId)
  152. if err != nil {
  153. return v1.GlobalLimitRequireResponse{}, fmt.Errorf("获取配置限制失败: %w", err)
  154. }
  155. bpsInt, err := strconv.Atoi(configCount.Bps)
  156. if err != nil {
  157. return v1.GlobalLimitRequireResponse{}, err
  158. }
  159. resultFloat := float64(bpsInt) / 2.0 / 8.0
  160. res.Bps = strconv.FormatFloat( resultFloat, 'f', -1, 64) + "M"
  161. res.MaxBytesMonth = configCount.MaxBytesMonth
  162. res.Operator = configCount.Operator
  163. res.IpCount = configCount.IpCount
  164. res.NodeArea = configCount.NodeArea
  165. res.ConfigMaxProtection = configCount.ConfigMaxProtection
  166. res.IsBanUdp = configCount.IsBanUdp
  167. res.HostId = req.HostId
  168. domain, err := s.hostRep.GetDomainById(ctx, req.HostId)
  169. if err != nil {
  170. return v1.GlobalLimitRequireResponse{}, err
  171. }
  172. userInfo,err := s.globalLimitRepository.GetUserInfo(ctx, int64(req.Uid))
  173. if err != nil {
  174. return v1.GlobalLimitRequireResponse{}, err
  175. }
  176. res.GlobalLimitName = strconv.Itoa(req.Uid) + "_" + userInfo.Username + "_" + strconv.Itoa(req.HostId) + "_" + domain
  177. res.HostName, err = s.globalLimitRepository.GetHostName(ctx, int64(req.HostId))
  178. if err != nil {
  179. return v1.GlobalLimitRequireResponse{}, err
  180. }
  181. return res, nil
  182. }
  183. func (s *globalLimitService) GetGlobalLimit(ctx context.Context, id int64) (*model.GlobalLimit, error) {
  184. return s.globalLimitRepository.GetGlobalLimit(ctx, id)
  185. }
  186. func (s *globalLimitService) ConversionTime(ctx context.Context,req string) (string, error) {
  187. // 2. 将字符串解析成 time.Time 对象
  188. // time.Parse 会根据你提供的布局来理解输入的字符串
  189. t, err := time.Parse("2006-01-02 15:04:05", req)
  190. if err != nil {
  191. // 如果输入的字符串格式和布局不匹配,这里会报错
  192. return "", fmt.Errorf("输入的字符串格式和布局不匹配 %w", err)
  193. }
  194. // 3. 定义新的输出格式 "YYYY-MM-DD"
  195. outputLayout := "2006-01-02"
  196. // 4. 将 time.Time 对象格式化为新的字符串
  197. outputTimeStr := t.Format(outputLayout)
  198. return outputTimeStr, nil
  199. }
  200. func (s *globalLimitService) ConversionTimeUnix(ctx context.Context,req string) (int64, error) {
  201. t, err := time.Parse("2006-01-02 15:04:05", req)
  202. if err != nil {
  203. return 0, fmt.Errorf("输入的字符串格式和布局不匹配 %w", err)
  204. }
  205. expiredAt := t.Unix()
  206. return expiredAt, nil
  207. }
  208. func (s *globalLimitService) AddGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error {
  209. isExist, err := s.globalLimitRepository.IsGlobalLimitExistByHostId(ctx, int64(req.HostId))
  210. if err != nil {
  211. return err
  212. }
  213. if isExist {
  214. return fmt.Errorf("配置限制已存在")
  215. }
  216. require, err := s.GlobalLimitRequire(ctx, req)
  217. if err != nil {
  218. return err
  219. }
  220. g, gCtx := errgroup.WithContext(ctx)
  221. var userId int64
  222. var groupId int64
  223. g.Go(func() error {
  224. e := s.gatewayIp.AddIpWhereHostIdNull(gCtx, int64(req.HostId),int64(req.Uid))
  225. if e != nil {
  226. return fmt.Errorf("获取网关组失败: %w", e)
  227. }
  228. return nil
  229. })
  230. g.Go(func() error {
  231. res, e := s.GetCdnUserId(gCtx, int64(req.Uid))
  232. if e != nil {
  233. return fmt.Errorf("获取cdn用户失败: %w", e)
  234. }
  235. if res == 0 {
  236. return fmt.Errorf("获取cdn用户失败")
  237. }
  238. userId = res
  239. return nil
  240. })
  241. g.Go(func() error {
  242. res, e := s.AddGroupId(gCtx, require.GlobalLimitName)
  243. if e != nil {
  244. return fmt.Errorf("创建规则分组失败: %w", e)
  245. }
  246. if res == 0 {
  247. return fmt.Errorf("创建规则分组失败")
  248. }
  249. groupId = res
  250. return nil
  251. })
  252. if err = g.Wait(); err != nil {
  253. return err
  254. }
  255. // 获取套餐ID
  256. //maxProtection := strings.TrimSuffix(require.ConfigMaxProtection, "G")
  257. //if maxProtection == "" {
  258. // return fmt.Errorf("无效的配置 ConfigMaxProtection: '%s',数字部分为空", require.ConfigMaxProtection)
  259. //}
  260. //maxProtectionInt, err := strconv.Atoi(maxProtection)
  261. //if err != nil {
  262. // return fmt.Errorf("无效的配置 ConfigMaxProtection: '%s',无法转换为数字", require.ConfigMaxProtection)
  263. //}
  264. expiredAt, err := s.ConversionTimeUnix(ctx, require.ExpiredAt)
  265. if err != nil {
  266. return err
  267. }
  268. // 如果存在实例,恢复
  269. oldData,err := s.globalLimitRepository.GetGlobalLimitByHostId(ctx, int64(req.HostId))
  270. if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
  271. return err
  272. }
  273. if oldData!= nil && oldData.Id != 0 {
  274. err = s.globalLimitRepository.UpdateGlobalLimitByHostId(ctx, &model.GlobalLimit{
  275. HostId: req.HostId,
  276. Uid: req.Uid,
  277. Name: require.GlobalLimitName,
  278. GroupId: int(groupId),
  279. CdnUid: int(userId),
  280. Comment: req.Comment,
  281. ExpiredAt: expiredAt,
  282. State: true,
  283. })
  284. if err != nil {
  285. return err
  286. }
  287. return nil
  288. }
  289. // 如果不存在实例,创建
  290. err = s.globalLimitRepository.AddGlobalLimit(ctx, &model.GlobalLimit{
  291. HostId: req.HostId,
  292. Uid: req.Uid,
  293. Name: require.GlobalLimitName,
  294. GroupId: int(groupId),
  295. CdnUid: int(userId),
  296. Comment: req.Comment,
  297. State: true,
  298. ExpiredAt: expiredAt,
  299. })
  300. if err != nil {
  301. return err
  302. }
  303. return nil
  304. }
  305. func (s *globalLimitService) EditGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error {
  306. require, err := s.GlobalLimitRequire(ctx, req)
  307. if err != nil {
  308. return err
  309. }
  310. // 如果不存在实例,创建
  311. gatewayIp, err := s.gatewayIpRep.GetGatewayipByHostIdAll(ctx, int64(req.HostId))
  312. if err != nil {
  313. return err
  314. }
  315. if gatewayIp == nil {
  316. err = s.gatewayIp.AddIpWhereHostIdNull(ctx, int64(req.HostId), int64(req.Uid))
  317. if err != nil {
  318. return fmt.Errorf("获取网关组失败: %w", err)
  319. }
  320. }
  321. expiredAt, err := s.ConversionTimeUnix(ctx, require.ExpiredAt)
  322. if err != nil {
  323. return err
  324. }
  325. if err := s.globalLimitRepository.UpdateGlobalLimitByHostId(ctx, &model.GlobalLimit{
  326. HostId: req.HostId,
  327. Comment: req.Comment,
  328. ExpiredAt: expiredAt,
  329. }); err != nil {
  330. return err
  331. }
  332. return nil
  333. }
  334. func (s *globalLimitService) DeleteGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error {
  335. // 检查是否过期
  336. isExpired, err := s.host.CheckExpired(ctx, int64(req.Uid), int64(req.HostId))
  337. if err != nil {
  338. return err
  339. }
  340. if isExpired {
  341. return fmt.Errorf("实例未过期,无法删除")
  342. }
  343. oldData, err := s.globalLimitRepository.GetGlobalLimitByHostId(ctx, int64(req.HostId))
  344. if err != nil {
  345. if errors.Is(err, gorm.ErrRecordNotFound) {
  346. return fmt.Errorf("实例不存在")
  347. }
  348. return err
  349. }
  350. tcpIds, err := s.tcpforwardingRep.GetTcpForwardingAllIdsByID(ctx,req.HostId)
  351. if err != nil {
  352. return err
  353. }
  354. udpIds, err := s.udpForWardingRep.GetUdpForwardingWafUdpAllIds(ctx,req.HostId)
  355. if err != nil {
  356. return err
  357. }
  358. webIds, err := s.webForWardingRep.GetWebForwardingWafWebAllIds(ctx,req.HostId)
  359. if err != nil {
  360. return err
  361. }
  362. // 黑白IP
  363. BwIds, err := s.allowAndDenyRep.GetIpCountListId(ctx, int64(req.HostId))
  364. if err != nil {
  365. return err
  366. }
  367. // 删除网站
  368. g, gCtx := errgroup.WithContext(ctx)
  369. g.Go(func() error {
  370. e := s.tcpforwarding.DeleteTcpForwarding(ctx,v1.DeleteTcpForwardingRequest{
  371. Ids: tcpIds,
  372. Uid: req.Uid,
  373. HostId: req.HostId,
  374. })
  375. if e != nil {
  376. return fmt.Errorf("删除TCP转发失败: %w", e)
  377. }
  378. return nil
  379. })
  380. g.Go(func() error {
  381. e := s.udpForWarding.DeleteUdpForwarding(ctx, v1.DeleteUdpForwardingRequest{
  382. Ids: udpIds,
  383. Uid: req.Uid,
  384. HostId: req.HostId,
  385. })
  386. if e != nil {
  387. return fmt.Errorf("删除UDP转发失败: %w", e)
  388. }
  389. return nil
  390. })
  391. g.Go(func() error {
  392. e := s.webForWarding.DeleteWebForwarding(ctx,v1.DeleteWebForwardingRequest{
  393. Ids: webIds,
  394. Uid: req.Uid,
  395. HostId: req.HostId,
  396. })
  397. if e != nil {
  398. return fmt.Errorf("删除WEB转发失败: %w", e)
  399. }
  400. return nil
  401. })
  402. // 删除网站分组
  403. g.Go(func() error {
  404. e := s.cdnService.DelServerGroup(gCtx, int64(oldData.GroupId))
  405. if e != nil {
  406. return fmt.Errorf("删除网站分组失败: %w", e)
  407. }
  408. return nil
  409. })
  410. if err = g.Wait(); err != nil {
  411. return err
  412. }
  413. if err := s.globalLimitRepository.EditHostState(ctx, int64(req.HostId), false); err != nil {
  414. return err
  415. }
  416. if err := s.gatewayIpRep.CleanIPByHostId(ctx, []int64{int64(req.HostId)}); err != nil {
  417. return err
  418. }
  419. // 删除黑白名单
  420. err = s.allowAndDeny.DeleteAllowAndDenyIps(ctx, v1.DelAllowAndDenyIpRequest{
  421. HostId: req.HostId,
  422. Ids: BwIds,
  423. Uid: req.Uid,
  424. })
  425. if err != nil {
  426. return err
  427. }
  428. return nil
  429. }