globallimit.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516
  1. package waf
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
  7. "github.com/go-nunu/nunu-layout-advanced/internal/model"
  8. "github.com/go-nunu/nunu-layout-advanced/internal/repository"
  9. flexCdn2 "github.com/go-nunu/nunu-layout-advanced/internal/repository/api/flexCdn"
  10. "github.com/go-nunu/nunu-layout-advanced/internal/repository/api/waf"
  11. "github.com/go-nunu/nunu-layout-advanced/internal/service"
  12. "github.com/go-nunu/nunu-layout-advanced/internal/service/api/flexCdn"
  13. "github.com/hashicorp/go-multierror"
  14. "github.com/mozillazg/go-pinyin"
  15. "github.com/spf13/viper"
  16. "golang.org/x/sync/errgroup"
  17. "gorm.io/gorm"
  18. "strconv"
  19. "strings"
  20. "sync"
  21. "time"
  22. )
  23. type GlobalLimitService interface {
  24. GetGlobalLimit(ctx context.Context, id int64) (*model.GlobalLimit, error)
  25. AddGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error
  26. EditGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error
  27. DeleteGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error
  28. GlobalLimitRequire(ctx context.Context, req v1.GlobalLimitRequest) (res v1.GlobalLimitRequireResponse, err error)
  29. }
  30. func NewGlobalLimitService(
  31. service *service.Service,
  32. globalLimitRepository waf.GlobalLimitRepository,
  33. duedate service.DuedateService,
  34. crawler service.CrawlerService,
  35. conf *viper.Viper,
  36. required service.RequiredService,
  37. parser service.ParserService,
  38. host service.HostService,
  39. hostRep repository.HostRepository,
  40. cdnService flexCdn.CdnService,
  41. cdnRep flexCdn2.CdnRepository,
  42. tcpforwardingRep waf.TcpforwardingRepository,
  43. udpForWardingRep waf.UdpForWardingRepository,
  44. webForWardingRep waf.WebForwardingRepository,
  45. allowAndDeny AllowAndDenyIpService,
  46. allowAndDenyRep waf.AllowAndDenyIpRepository,
  47. tcpforwarding TcpforwardingService,
  48. udpForWarding UdpForWardingService,
  49. webForWarding WebForwardingService,
  50. gatewayIpRep waf.GatewayipRepository,
  51. gatywayIp GatewayipService,
  52. bulidAudun BuildAudunService,
  53. ) GlobalLimitService {
  54. return &globalLimitService{
  55. Service: service,
  56. globalLimitRepository: globalLimitRepository,
  57. duedate: duedate,
  58. crawler: crawler,
  59. Url: conf.GetString("crawler.Url"),
  60. required: required,
  61. parser: parser,
  62. host: host,
  63. hostRep: hostRep,
  64. cdnService: cdnService,
  65. cdnRep: cdnRep,
  66. tcpforwardingRep: tcpforwardingRep,
  67. udpForWardingRep: udpForWardingRep,
  68. webForWardingRep: webForWardingRep,
  69. allowAndDeny: allowAndDeny,
  70. allowAndDenyRep: allowAndDenyRep,
  71. tcpforwarding: tcpforwarding,
  72. udpForWarding: udpForWarding,
  73. webForWarding: webForWarding,
  74. gatewayIpRep: gatewayIpRep,
  75. gatewayIp: gatywayIp,
  76. bulidAudun: bulidAudun,
  77. }
  78. }
  79. type globalLimitService struct {
  80. *service.Service
  81. globalLimitRepository waf.GlobalLimitRepository
  82. duedate service.DuedateService
  83. crawler service.CrawlerService
  84. Url string
  85. required service.RequiredService
  86. parser service.ParserService
  87. host service.HostService
  88. hostRep repository.HostRepository
  89. cdnService flexCdn.CdnService
  90. cdnRep flexCdn2.CdnRepository
  91. tcpforwardingRep waf.TcpforwardingRepository
  92. udpForWardingRep waf.UdpForWardingRepository
  93. webForWardingRep waf.WebForwardingRepository
  94. allowAndDeny AllowAndDenyIpService
  95. allowAndDenyRep waf.AllowAndDenyIpRepository
  96. tcpforwarding TcpforwardingService
  97. udpForWarding UdpForWardingService
  98. webForWarding WebForwardingService
  99. gatewayIpRep waf.GatewayipRepository
  100. gatewayIp GatewayipService
  101. bulidAudun BuildAudunService
  102. }
  103. func (s *globalLimitService) GetCdnUserId(ctx context.Context, uid int64) (int64, error) {
  104. data, err := s.globalLimitRepository.GetGlobalLimitFirst(ctx, uid)
  105. if err != nil {
  106. if !errors.Is(err, gorm.ErrRecordNotFound) {
  107. return 0, err
  108. }
  109. }
  110. if data != nil && data.CdnUid != 0 {
  111. return int64(data.CdnUid), nil
  112. }
  113. userInfo,err := s.globalLimitRepository.GetUserInfo(ctx, uid)
  114. if err != nil {
  115. return 0, err
  116. }
  117. // 中文转拼音
  118. a := pinyin.NewArgs()
  119. a.Style = pinyin.Normal
  120. pinyinSlice := pinyin.LazyPinyin(userInfo.Username, a)
  121. userName := strconv.Itoa(int(uid)) + "_" + strings.Join(pinyinSlice, "_")
  122. // 查询用户是否存在
  123. UserId,err := s.cdnRep.GetUserId(ctx, userName)
  124. if err != nil {
  125. return 0, err
  126. }
  127. if UserId != 0 {
  128. return UserId, nil
  129. }
  130. // 注册用户
  131. userId, err := s.cdnService.AddUser(ctx, v1.User{
  132. Username: userName,
  133. Email: userInfo.Email,
  134. Fullname: userInfo.Username,
  135. Mobile: userInfo.PhoneNumber,
  136. })
  137. if err != nil {
  138. return 0, err
  139. }
  140. return userId, nil
  141. }
  142. func (s *globalLimitService) AddGroupId(ctx context.Context,groupName string) (int64, error) {
  143. groupId, err := s.cdnService.CreateGroup(ctx, v1.Group{
  144. Name: groupName,
  145. })
  146. if err != nil {
  147. return 0, err
  148. }
  149. return groupId, nil
  150. }
  151. func (s *globalLimitService) GlobalLimitRequire(ctx context.Context, req v1.GlobalLimitRequest) (res v1.GlobalLimitRequireResponse, err error) {
  152. res.ExpiredAt, err = s.duedate.NextDueDate(ctx, req.Uid, req.HostId)
  153. if err != nil {
  154. return v1.GlobalLimitRequireResponse{}, err
  155. }
  156. configCount, err := s.host.GetGlobalLimitConfig(ctx, req.HostId)
  157. if err != nil {
  158. return v1.GlobalLimitRequireResponse{}, fmt.Errorf("获取配置限制失败: %w", err)
  159. }
  160. res.MaxBytesMonth = configCount.MaxBytesMonth
  161. res.Operator = configCount.Operator
  162. res.IpCount = configCount.IpCount
  163. res.NodeArea = configCount.NodeArea
  164. res.ConfigMaxProtection = configCount.ConfigMaxProtection
  165. res.IsBanUdp = configCount.IsBanUdp
  166. res.HostId = req.HostId
  167. res.Bps = configCount.Bps
  168. domain, err := s.hostRep.GetDomainById(ctx, req.HostId)
  169. if err != nil {
  170. return v1.GlobalLimitRequireResponse{}, err
  171. }
  172. userInfo,err := s.globalLimitRepository.GetUserInfo(ctx, int64(req.Uid))
  173. if err != nil {
  174. return v1.GlobalLimitRequireResponse{}, err
  175. }
  176. res.GlobalLimitName = strconv.Itoa(req.Uid) + "_" + userInfo.Username + "_" + strconv.Itoa(req.HostId) + "_" + domain
  177. res.HostName, err = s.globalLimitRepository.GetHostName(ctx, int64(req.HostId))
  178. if err != nil {
  179. return v1.GlobalLimitRequireResponse{}, err
  180. }
  181. return res, nil
  182. }
  183. func (s *globalLimitService) GetGlobalLimit(ctx context.Context, id int64) (*model.GlobalLimit, error) {
  184. return s.globalLimitRepository.GetGlobalLimit(ctx, id)
  185. }
  186. func (s *globalLimitService) ConversionTime(ctx context.Context,req string) (string, error) {
  187. // 2. 将字符串解析成 time.Time 对象
  188. // time.Parse 会根据你提供的布局来理解输入的字符串
  189. t, err := time.Parse("2006-01-02 15:04:05", req)
  190. if err != nil {
  191. // 如果输入的字符串格式和布局不匹配,这里会报错
  192. return "", fmt.Errorf("输入的字符串格式和布局不匹配 %w", err)
  193. }
  194. // 3. 定义新的输出格式 "YYYY-MM-DD"
  195. outputLayout := "2006-01-02"
  196. // 4. 将 time.Time 对象格式化为新的字符串
  197. outputTimeStr := t.Format(outputLayout)
  198. return outputTimeStr, nil
  199. }
  200. func (s *globalLimitService) ConversionTimeUnix(ctx context.Context,req string) (int64, error) {
  201. t, err := time.Parse("2006-01-02 15:04:05", req)
  202. if err != nil {
  203. return 0, fmt.Errorf("输入的字符串格式和布局不匹配 %w", err)
  204. }
  205. expiredAt := t.Unix()
  206. return expiredAt, nil
  207. }
  208. func (s *globalLimitService) AddGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error {
  209. isExist, err := s.globalLimitRepository.IsGlobalLimitExistByHostId(ctx, int64(req.HostId))
  210. if err != nil {
  211. return err
  212. }
  213. if isExist {
  214. return fmt.Errorf("配置限制已存在")
  215. }
  216. require, err := s.GlobalLimitRequire(ctx, req)
  217. if err != nil {
  218. return err
  219. }
  220. g, gCtx := errgroup.WithContext(ctx)
  221. var userId int64
  222. var groupId int64
  223. g.Go(func() error {
  224. e := s.gatewayIp.AddIpWhereHostIdNull(gCtx, int64(req.HostId),int64(req.Uid))
  225. if e != nil {
  226. return fmt.Errorf("获取网关组失败: %w", e)
  227. }
  228. return nil
  229. })
  230. g.Go(func() error {
  231. res, e := s.GetCdnUserId(gCtx, int64(req.Uid))
  232. if e != nil {
  233. return fmt.Errorf("获取cdn用户失败: %w", e)
  234. }
  235. if res == 0 {
  236. return fmt.Errorf("获取cdn用户失败")
  237. }
  238. userId = res
  239. return nil
  240. })
  241. g.Go(func() error {
  242. res, e := s.AddGroupId(gCtx, require.GlobalLimitName)
  243. if e != nil {
  244. return fmt.Errorf("创建规则分组失败: %w", e)
  245. }
  246. if res == 0 {
  247. return fmt.Errorf("创建规则分组失败")
  248. }
  249. groupId = res
  250. return nil
  251. })
  252. if err = g.Wait(); err != nil {
  253. return err
  254. }
  255. // 添加带宽限制
  256. ip, err := s.gatewayIpRep.GetGatewayipOnlyIpByHostIdAll(ctx, int64(req.HostId))
  257. if err != nil {
  258. return err
  259. }
  260. bpsInt, err := strconv.Atoi(require.Bps)
  261. if err != nil {
  262. return err
  263. }
  264. var wg sync.WaitGroup
  265. wg.Add(len(ip))
  266. var errChan = make(chan error, len(ip))
  267. if ip != nil {
  268. for _, v := range ip {
  269. go func(v string) {
  270. defer wg.Done()
  271. err := s.bulidAudun.AddBandwidth(ctx, v1.Bandwidth{
  272. Name: require.Bps,
  273. ServerIPStart: v,
  274. SpeedlimitOut: int64(bpsInt),
  275. })
  276. if err != nil {
  277. errChan <- err
  278. }
  279. }(v)
  280. }
  281. wg.Wait()
  282. close(errChan)
  283. var allErrors error
  284. for err := range errChan {
  285. allErrors = multierror.Append(allErrors, err)
  286. }
  287. if allErrors != nil {
  288. return allErrors
  289. }
  290. }
  291. expiredAt, err := s.ConversionTimeUnix(ctx, require.ExpiredAt)
  292. if err != nil {
  293. return err
  294. }
  295. // 如果存在实例,恢复
  296. oldData,err := s.globalLimitRepository.GetGlobalLimitByHostId(ctx, int64(req.HostId))
  297. if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
  298. return err
  299. }
  300. if oldData!= nil && oldData.Id != 0 {
  301. err = s.globalLimitRepository.UpdateGlobalLimitByHostId(ctx, &model.GlobalLimit{
  302. HostId: req.HostId,
  303. Uid: req.Uid,
  304. Name: require.GlobalLimitName,
  305. GroupId: int(groupId),
  306. CdnUid: int(userId),
  307. Comment: req.Comment,
  308. ExpiredAt: expiredAt,
  309. State: true,
  310. })
  311. if err != nil {
  312. return err
  313. }
  314. return nil
  315. }
  316. // 如果不存在实例,创建
  317. err = s.globalLimitRepository.AddGlobalLimit(ctx, &model.GlobalLimit{
  318. HostId: req.HostId,
  319. Uid: req.Uid,
  320. Name: require.GlobalLimitName,
  321. GroupId: int(groupId),
  322. CdnUid: int(userId),
  323. Comment: req.Comment,
  324. State: true,
  325. ExpiredAt: expiredAt,
  326. })
  327. if err != nil {
  328. return err
  329. }
  330. return nil
  331. }
  332. func (s *globalLimitService) EditGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error {
  333. require, err := s.GlobalLimitRequire(ctx, req)
  334. if err != nil {
  335. return err
  336. }
  337. // 如果不存在实例,创建
  338. gatewayIp, err := s.gatewayIpRep.GetGatewayipByHostIdAll(ctx, int64(req.HostId))
  339. if err != nil {
  340. return err
  341. }
  342. if gatewayIp == nil {
  343. err = s.gatewayIp.AddIpWhereHostIdNull(ctx, int64(req.HostId), int64(req.Uid))
  344. if err != nil {
  345. return fmt.Errorf("获取网关组失败: %w", err)
  346. }
  347. }
  348. expiredAt, err := s.ConversionTimeUnix(ctx, require.ExpiredAt)
  349. if err != nil {
  350. return err
  351. }
  352. if err := s.globalLimitRepository.UpdateGlobalLimitByHostId(ctx, &model.GlobalLimit{
  353. HostId: req.HostId,
  354. Comment: req.Comment,
  355. ExpiredAt: expiredAt,
  356. }); err != nil {
  357. return err
  358. }
  359. return nil
  360. }
  361. func (s *globalLimitService) DeleteGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error {
  362. // 检查是否过期
  363. isExpired, err := s.host.CheckExpired(ctx, int64(req.Uid), int64(req.HostId))
  364. if err != nil {
  365. return err
  366. }
  367. if isExpired {
  368. return fmt.Errorf("实例未过期,无法删除")
  369. }
  370. oldData, err := s.globalLimitRepository.GetGlobalLimitByHostId(ctx, int64(req.HostId))
  371. if err != nil {
  372. if errors.Is(err, gorm.ErrRecordNotFound) {
  373. return fmt.Errorf("实例不存在")
  374. }
  375. return err
  376. }
  377. tcpIds, err := s.tcpforwardingRep.GetTcpForwardingAllIdsByID(ctx,req.HostId)
  378. if err != nil {
  379. return err
  380. }
  381. udpIds, err := s.udpForWardingRep.GetUdpForwardingWafUdpAllIds(ctx,req.HostId)
  382. if err != nil {
  383. return err
  384. }
  385. webIds, err := s.webForWardingRep.GetWebForwardingWafWebAllIds(ctx,req.HostId)
  386. if err != nil {
  387. return err
  388. }
  389. // 黑白IP
  390. BwIds, err := s.allowAndDenyRep.GetIpCountListId(ctx, int64(req.HostId))
  391. if err != nil {
  392. return err
  393. }
  394. // 删除网站
  395. g, gCtx := errgroup.WithContext(ctx)
  396. g.Go(func() error {
  397. e := s.tcpforwarding.DeleteTcpForwarding(ctx,v1.DeleteTcpForwardingRequest{
  398. Ids: tcpIds,
  399. Uid: req.Uid,
  400. HostId: req.HostId,
  401. })
  402. if e != nil {
  403. return fmt.Errorf("删除TCP转发失败: %w", e)
  404. }
  405. return nil
  406. })
  407. g.Go(func() error {
  408. e := s.udpForWarding.DeleteUdpForwarding(ctx, v1.DeleteUdpForwardingRequest{
  409. Ids: udpIds,
  410. Uid: req.Uid,
  411. HostId: req.HostId,
  412. })
  413. if e != nil {
  414. return fmt.Errorf("删除UDP转发失败: %w", e)
  415. }
  416. return nil
  417. })
  418. g.Go(func() error {
  419. e := s.webForWarding.DeleteWebForwarding(ctx,v1.DeleteWebForwardingRequest{
  420. Ids: webIds,
  421. Uid: req.Uid,
  422. HostId: req.HostId,
  423. })
  424. if e != nil {
  425. return fmt.Errorf("删除WEB转发失败: %w", e)
  426. }
  427. return nil
  428. })
  429. // 删除网站分组
  430. g.Go(func() error {
  431. e := s.cdnService.DelServerGroup(gCtx, int64(oldData.GroupId))
  432. if e != nil {
  433. return fmt.Errorf("删除网站分组失败: %w", e)
  434. }
  435. return nil
  436. })
  437. if err = g.Wait(); err != nil {
  438. return err
  439. }
  440. if err := s.globalLimitRepository.EditHostState(ctx, int64(req.HostId), false); err != nil {
  441. return err
  442. }
  443. if err := s.gatewayIpRep.CleanIPByHostId(ctx, []int64{int64(req.HostId)}); err != nil {
  444. return err
  445. }
  446. // 删除黑白名单
  447. err = s.allowAndDeny.DeleteAllowAndDenyIps(ctx, v1.DelAllowAndDenyIpRequest{
  448. HostId: req.HostId,
  449. Ids: BwIds,
  450. Uid: req.Uid,
  451. })
  452. if err != nil {
  453. return err
  454. }
  455. return nil
  456. }