globallimit.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500
  1. package waf
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
  7. "github.com/go-nunu/nunu-layout-advanced/internal/model"
  8. "github.com/go-nunu/nunu-layout-advanced/internal/repository"
  9. flexCdn2 "github.com/go-nunu/nunu-layout-advanced/internal/repository/api/flexCdn"
  10. "github.com/go-nunu/nunu-layout-advanced/internal/repository/api/waf"
  11. "github.com/go-nunu/nunu-layout-advanced/internal/service"
  12. "github.com/go-nunu/nunu-layout-advanced/internal/service/api/flexCdn"
  13. "github.com/mozillazg/go-pinyin"
  14. "github.com/spf13/viper"
  15. "golang.org/x/sync/errgroup"
  16. "gorm.io/gorm"
  17. "strconv"
  18. "strings"
  19. "time"
  20. )
  21. type GlobalLimitService interface {
  22. GetGlobalLimit(ctx context.Context, id int64) (*model.GlobalLimit, error)
  23. AddGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error
  24. EditGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error
  25. DeleteGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error
  26. GlobalLimitRequire(ctx context.Context, req v1.GlobalLimitRequest) (res v1.GlobalLimitRequireResponse, err error)
  27. }
  28. func NewGlobalLimitService(
  29. service *service.Service,
  30. globalLimitRepository waf.GlobalLimitRepository,
  31. duedate service.DuedateService,
  32. crawler service.CrawlerService,
  33. conf *viper.Viper,
  34. required service.RequiredService,
  35. parser service.ParserService,
  36. host service.HostService,
  37. hostRep repository.HostRepository,
  38. cdnService flexCdn.CdnService,
  39. cdnRep flexCdn2.CdnRepository,
  40. tcpforwardingRep waf.TcpforwardingRepository,
  41. udpForWardingRep waf.UdpForWardingRepository,
  42. webForWardingRep waf.WebForwardingRepository,
  43. allowAndDeny AllowAndDenyIpService,
  44. allowAndDenyRep waf.AllowAndDenyIpRepository,
  45. tcpforwarding TcpforwardingService,
  46. udpForWarding UdpForWardingService,
  47. webForWarding WebForwardingService,
  48. gatewayIpRep waf.GatewayipRepository,
  49. gatywayIp GatewayipService,
  50. bulidAudun BuildAudunService,
  51. zzyBgp ZzybgpService,
  52. ) GlobalLimitService {
  53. return &globalLimitService{
  54. Service: service,
  55. globalLimitRepository: globalLimitRepository,
  56. duedate: duedate,
  57. crawler: crawler,
  58. Url: conf.GetString("crawler.Url"),
  59. required: required,
  60. parser: parser,
  61. host: host,
  62. hostRep: hostRep,
  63. cdnService: cdnService,
  64. cdnRep: cdnRep,
  65. tcpforwardingRep: tcpforwardingRep,
  66. udpForWardingRep: udpForWardingRep,
  67. webForWardingRep: webForWardingRep,
  68. allowAndDeny: allowAndDeny,
  69. allowAndDenyRep: allowAndDenyRep,
  70. tcpforwarding: tcpforwarding,
  71. udpForWarding: udpForWarding,
  72. webForWarding: webForWarding,
  73. gatewayIpRep: gatewayIpRep,
  74. gatewayIp: gatywayIp,
  75. bulidAudun: bulidAudun,
  76. zzyBgp: zzyBgp,
  77. }
  78. }
// globalLimitService is the default GlobalLimitService implementation. It only
// orchestrates the repositories and services below; every field is injected
// by NewGlobalLimitService.
type globalLimitService struct {
	*service.Service
	globalLimitRepository waf.GlobalLimitRepository     // global limit records, user info and host names
	duedate               service.DuedateService        // next due date for a user's host
	crawler               service.CrawlerService        // not referenced by the methods in this section
	Url                   string                        // value of the "crawler.Url" config key; not referenced in this section
	required              service.RequiredService       // not referenced by the methods in this section
	parser                service.ParserService         // not referenced by the methods in this section
	host                  service.HostService           // plan limits (GetGlobalLimitConfig) and expiry checks (CheckExpired)
	hostRep               repository.HostRepository     // domain lookup by host id
	cdnService            flexCdn.CdnService            // CDN user/group creation and group deletion
	cdnRep                flexCdn2.CdnRepository        // CDN user id lookup by username
	tcpforwardingRep      waf.TcpforwardingRepository   // ids of a host's TCP forwarding rules
	udpForWardingRep      waf.UdpForWardingRepository   // ids of a host's UDP forwarding rules
	webForWardingRep      waf.WebForwardingRepository   // ids of a host's web forwarding rules
	allowAndDeny          AllowAndDenyIpService         // deletes allow/deny IP entries
	allowAndDenyRep       waf.AllowAndDenyIpRepository  // ids of a host's allow/deny IP entries
	tcpforwarding         TcpforwardingService          // deletes TCP forwarding rules
	udpForWarding         UdpForWardingService          // deletes UDP forwarding rules
	webForWarding         WebForwardingService          // deletes web forwarding rules
	gatewayIpRep          waf.GatewayipRepository       // gateway IP lookup and release per host
	gatewayIp             GatewayipService              // assigns gateway IPs to a host
	bulidAudun            BuildAudunService             // adds/removes a host's bandwidth limit
	zzyBgp                ZzybgpService                 // sets a host's defense value
}
  104. func (s *globalLimitService) GetCdnUserId(ctx context.Context, uid int64) (int64, error) {
  105. data, err := s.globalLimitRepository.GetGlobalLimitFirst(ctx, uid)
  106. if err != nil {
  107. if !errors.Is(err, gorm.ErrRecordNotFound) {
  108. return 0, err
  109. }
  110. }
  111. if data != nil && data.CdnUid != 0 {
  112. return int64(data.CdnUid), nil
  113. }
  114. userInfo,err := s.globalLimitRepository.GetUserInfo(ctx, uid)
  115. if err != nil {
  116. return 0, err
  117. }
  118. // 中文转拼音
  119. a := pinyin.NewArgs()
  120. a.Style = pinyin.Normal
  121. pinyinSlice := pinyin.LazyPinyin(userInfo.Username, a)
  122. userName := strconv.Itoa(int(uid)) + "_" + strings.Join(pinyinSlice, "_")
  123. // 查询用户是否存在
  124. UserId,err := s.cdnRep.GetUserId(ctx, userName)
  125. if err != nil {
  126. return 0, err
  127. }
  128. if UserId != 0 {
  129. return UserId, nil
  130. }
  131. // 注册用户
  132. userId, err := s.cdnService.AddUser(ctx, v1.User{
  133. Username: userName,
  134. Email: userInfo.Email,
  135. Fullname: userInfo.Username,
  136. Mobile: userInfo.PhoneNumber,
  137. })
  138. if err != nil {
  139. return 0, err
  140. }
  141. return userId, nil
  142. }
  143. func (s *globalLimitService) AddGroupId(ctx context.Context,groupName string) (int64, error) {
  144. groupId, err := s.cdnService.CreateGroup(ctx, v1.Group{
  145. Name: groupName,
  146. })
  147. if err != nil {
  148. return 0, err
  149. }
  150. return groupId, nil
  151. }
  152. func (s *globalLimitService) GlobalLimitRequire(ctx context.Context, req v1.GlobalLimitRequest) (res v1.GlobalLimitRequireResponse, err error) {
  153. res.ExpiredAt, err = s.duedate.NextDueDate(ctx, req.Uid, req.HostId)
  154. if err != nil {
  155. return v1.GlobalLimitRequireResponse{}, err
  156. }
  157. configCount, err := s.host.GetGlobalLimitConfig(ctx, req.HostId)
  158. if err != nil {
  159. return v1.GlobalLimitRequireResponse{}, fmt.Errorf("获取配置限制失败: %w", err)
  160. }
  161. res.MaxBytesMonth = configCount.MaxBytesMonth
  162. res.Operator = configCount.Operator
  163. res.IpCount = configCount.IpCount
  164. res.NodeArea = configCount.NodeArea
  165. res.ConfigMaxProtection = configCount.ConfigMaxProtection
  166. res.IsBanUdp = configCount.IsBanUdp
  167. res.HostId = req.HostId
  168. res.Bps = configCount.Bps
  169. domain, err := s.hostRep.GetDomainById(ctx, req.HostId)
  170. if err != nil {
  171. return v1.GlobalLimitRequireResponse{}, err
  172. }
  173. userInfo,err := s.globalLimitRepository.GetUserInfo(ctx, int64(req.Uid))
  174. if err != nil {
  175. return v1.GlobalLimitRequireResponse{}, err
  176. }
  177. res.GlobalLimitName = strconv.Itoa(req.Uid) + "_" + userInfo.Username + "_" + strconv.Itoa(req.HostId) + "_" + domain
  178. res.HostName, err = s.globalLimitRepository.GetHostName(ctx, int64(req.HostId))
  179. if err != nil {
  180. return v1.GlobalLimitRequireResponse{}, err
  181. }
  182. return res, nil
  183. }
  184. func (s *globalLimitService) GetGlobalLimit(ctx context.Context, id int64) (*model.GlobalLimit, error) {
  185. return s.globalLimitRepository.GetGlobalLimit(ctx, id)
  186. }
  187. func (s *globalLimitService) ConversionTime(ctx context.Context,req string) (string, error) {
  188. // 2. 将字符串解析成 time.Time 对象
  189. // time.Parse 会根据你提供的布局来理解输入的字符串
  190. t, err := time.Parse("2006-01-02 15:04:05", req)
  191. if err != nil {
  192. // 如果输入的字符串格式和布局不匹配,这里会报错
  193. return "", fmt.Errorf("输入的字符串格式和布局不匹配 %w", err)
  194. }
  195. // 3. 定义新的输出格式 "YYYY-MM-DD"
  196. outputLayout := "2006-01-02"
  197. // 4. 将 time.Time 对象格式化为新的字符串
  198. outputTimeStr := t.Format(outputLayout)
  199. return outputTimeStr, nil
  200. }
  201. func (s *globalLimitService) ConversionTimeUnix(ctx context.Context,req string) (int64, error) {
  202. t, err := time.Parse("2006-01-02 15:04:05", req)
  203. if err != nil {
  204. return 0, fmt.Errorf("输入的字符串格式和布局不匹配 %w", err)
  205. }
  206. expiredAt := t.Unix()
  207. return expiredAt, nil
  208. }
  209. func (s *globalLimitService) AddGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error {
  210. isExist, err := s.globalLimitRepository.IsGlobalLimitExistByHostId(ctx, int64(req.HostId))
  211. if err != nil {
  212. return err
  213. }
  214. if isExist {
  215. return fmt.Errorf("配置限制已存在")
  216. }
  217. require, err := s.GlobalLimitRequire(ctx, req)
  218. if err != nil {
  219. return err
  220. }
  221. g, gCtx := errgroup.WithContext(ctx)
  222. var userId int64
  223. var groupId int64
  224. g.Go(func() error {
  225. e := s.gatewayIp.AddIpWhereHostIdNull(gCtx, int64(req.HostId),int64(req.Uid))
  226. if e != nil {
  227. return fmt.Errorf("获取网关组失败: %w", e)
  228. }
  229. return nil
  230. })
  231. g.Go(func() error {
  232. res, e := s.GetCdnUserId(gCtx, int64(req.Uid))
  233. if e != nil {
  234. return fmt.Errorf("获取cdn用户失败: %w", e)
  235. }
  236. if res == 0 {
  237. return fmt.Errorf("获取cdn用户失败")
  238. }
  239. userId = res
  240. return nil
  241. })
  242. g.Go(func() error {
  243. res, e := s.AddGroupId(gCtx, require.GlobalLimitName)
  244. if e != nil {
  245. return fmt.Errorf("创建规则分组失败: %w", e)
  246. }
  247. if res == 0 {
  248. return fmt.Errorf("创建规则分组失败")
  249. }
  250. groupId = res
  251. return nil
  252. })
  253. if err = g.Wait(); err != nil {
  254. return err
  255. }
  256. // 添加防护
  257. err = s.zzyBgp.SetDefense(ctx, int64(req.HostId), 0)
  258. if err != nil {
  259. return err
  260. }
  261. // 添加带宽限制
  262. err = s.bulidAudun.Bandwidth(ctx, int64(req.HostId), "add")
  263. if err != nil {
  264. return err
  265. }
  266. expiredAt, err := s.ConversionTimeUnix(ctx, require.ExpiredAt)
  267. if err != nil {
  268. return err
  269. }
  270. // 如果存在实例,恢复
  271. oldData,err := s.globalLimitRepository.GetGlobalLimitByHostId(ctx, int64(req.HostId))
  272. if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
  273. return err
  274. }
  275. if oldData!= nil && oldData.Id != 0 {
  276. err = s.globalLimitRepository.UpdateGlobalLimitByHostId(ctx, &model.GlobalLimit{
  277. HostId: req.HostId,
  278. Uid: req.Uid,
  279. Name: require.GlobalLimitName,
  280. GroupId: int(groupId),
  281. CdnUid: int(userId),
  282. Comment: req.Comment,
  283. ExpiredAt: expiredAt,
  284. State: true,
  285. })
  286. if err != nil {
  287. return err
  288. }
  289. return nil
  290. }
  291. // 如果不存在实例,创建
  292. err = s.globalLimitRepository.AddGlobalLimit(ctx, &model.GlobalLimit{
  293. HostId: req.HostId,
  294. Uid: req.Uid,
  295. Name: require.GlobalLimitName,
  296. GroupId: int(groupId),
  297. CdnUid: int(userId),
  298. Comment: req.Comment,
  299. State: true,
  300. ExpiredAt: expiredAt,
  301. })
  302. if err != nil {
  303. return err
  304. }
  305. return nil
  306. }
  307. func (s *globalLimitService) EditGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error {
  308. require, err := s.GlobalLimitRequire(ctx, req)
  309. if err != nil {
  310. return err
  311. }
  312. // 如果不存在实例,创建
  313. gatewayIp, err := s.gatewayIpRep.GetGatewayipByHostIdAll(ctx, int64(req.HostId))
  314. if err != nil {
  315. return err
  316. }
  317. if gatewayIp == nil {
  318. err = s.gatewayIp.AddIpWhereHostIdNull(ctx, int64(req.HostId), int64(req.Uid))
  319. if err != nil {
  320. return fmt.Errorf("获取网关组失败: %w", err)
  321. }
  322. }
  323. expiredAt, err := s.ConversionTimeUnix(ctx, require.ExpiredAt)
  324. if err != nil {
  325. return err
  326. }
  327. if err := s.globalLimitRepository.UpdateGlobalLimitByHostId(ctx, &model.GlobalLimit{
  328. HostId: req.HostId,
  329. Comment: req.Comment,
  330. ExpiredAt: expiredAt,
  331. }); err != nil {
  332. return err
  333. }
  334. return nil
  335. }
  336. func (s *globalLimitService) DeleteGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error {
  337. // 检查是否过期
  338. isExpired, err := s.host.CheckExpired(ctx, int64(req.Uid), int64(req.HostId))
  339. if err != nil {
  340. return err
  341. }
  342. if isExpired {
  343. return fmt.Errorf("实例未过期,无法删除")
  344. }
  345. oldData, err := s.globalLimitRepository.GetGlobalLimitByHostId(ctx, int64(req.HostId))
  346. if err != nil {
  347. if errors.Is(err, gorm.ErrRecordNotFound) {
  348. return fmt.Errorf("实例不存在")
  349. }
  350. return err
  351. }
  352. tcpIds, err := s.tcpforwardingRep.GetTcpForwardingAllIdsByID(ctx,req.HostId)
  353. if err != nil {
  354. return err
  355. }
  356. udpIds, err := s.udpForWardingRep.GetUdpForwardingWafUdpAllIds(ctx,req.HostId)
  357. if err != nil {
  358. return err
  359. }
  360. webIds, err := s.webForWardingRep.GetWebForwardingWafWebAllIds(ctx,req.HostId)
  361. if err != nil {
  362. return err
  363. }
  364. // 重置防护
  365. err = s.zzyBgp.SetDefense(ctx, int64(req.HostId), 10)
  366. if err != nil {
  367. return err
  368. }
  369. // 删除带宽限制
  370. err = s.bulidAudun.Bandwidth(ctx, int64(req.HostId), "del")
  371. if err != nil {
  372. return err
  373. }
  374. // 黑白IP
  375. BwIds, err := s.allowAndDenyRep.GetIpCountListId(ctx, int64(req.HostId))
  376. if err != nil {
  377. return err
  378. }
  379. // 删除网站
  380. g, gCtx := errgroup.WithContext(ctx)
  381. g.Go(func() error {
  382. e := s.tcpforwarding.DeleteTcpForwarding(ctx,v1.DeleteTcpForwardingRequest{
  383. Ids: tcpIds,
  384. Uid: req.Uid,
  385. HostId: req.HostId,
  386. })
  387. if e != nil {
  388. return fmt.Errorf("删除TCP转发失败: %w", e)
  389. }
  390. return nil
  391. })
  392. g.Go(func() error {
  393. e := s.udpForWarding.DeleteUdpForwarding(ctx, v1.DeleteUdpForwardingRequest{
  394. Ids: udpIds,
  395. Uid: req.Uid,
  396. HostId: req.HostId,
  397. })
  398. if e != nil {
  399. return fmt.Errorf("删除UDP转发失败: %w", e)
  400. }
  401. return nil
  402. })
  403. g.Go(func() error {
  404. e := s.webForWarding.DeleteWebForwarding(ctx,v1.DeleteWebForwardingRequest{
  405. Ids: webIds,
  406. Uid: req.Uid,
  407. HostId: req.HostId,
  408. })
  409. if e != nil {
  410. return fmt.Errorf("删除WEB转发失败: %w", e)
  411. }
  412. return nil
  413. })
  414. // 删除网站分组
  415. g.Go(func() error {
  416. e := s.cdnService.DelServerGroup(gCtx, int64(oldData.GroupId))
  417. if e != nil {
  418. return fmt.Errorf("删除网站分组失败: %w", e)
  419. }
  420. return nil
  421. })
  422. if err = g.Wait(); err != nil {
  423. return err
  424. }
  425. if err := s.globalLimitRepository.EditHostState(ctx, int64(req.HostId), false); err != nil {
  426. return err
  427. }
  428. if err := s.gatewayIpRep.CleanIPByHostId(ctx, []int64{int64(req.HostId)}); err != nil {
  429. return err
  430. }
  431. // 删除黑白名单
  432. err = s.allowAndDeny.DeleteAllowAndDenyIps(ctx, v1.DelAllowAndDenyIpRequest{
  433. HostId: req.HostId,
  434. Ids: BwIds,
  435. Uid: req.Uid,
  436. })
  437. if err != nil {
  438. return err
  439. }
  440. return nil
  441. }