globallimit.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487
  1. package waf
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
  7. "github.com/go-nunu/nunu-layout-advanced/internal/model"
  8. "github.com/go-nunu/nunu-layout-advanced/internal/repository"
  9. flexCdn2 "github.com/go-nunu/nunu-layout-advanced/internal/repository/api/flexCdn"
  10. "github.com/go-nunu/nunu-layout-advanced/internal/repository/api/waf"
  11. "github.com/go-nunu/nunu-layout-advanced/internal/service"
  12. "github.com/go-nunu/nunu-layout-advanced/internal/service/api/flexCdn"
  13. "github.com/mozillazg/go-pinyin"
  14. "github.com/spf13/viper"
  15. "golang.org/x/sync/errgroup"
  16. "gorm.io/gorm"
  17. "strconv"
  18. "strings"
  19. "time"
  20. )
  21. type GlobalLimitService interface {
  22. GetGlobalLimit(ctx context.Context, id int64) (*model.GlobalLimit, error)
  23. AddGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error
  24. EditGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error
  25. DeleteGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error
  26. GlobalLimitRequire(ctx context.Context, req v1.GlobalLimitRequest) (res v1.GlobalLimitRequireResponse, err error)
  27. }
  28. func NewGlobalLimitService(
  29. service *service.Service,
  30. globalLimitRepository waf.GlobalLimitRepository,
  31. duedate service.DuedateService,
  32. crawler service.CrawlerService,
  33. conf *viper.Viper,
  34. required service.RequiredService,
  35. parser service.ParserService,
  36. host service.HostService,
  37. hostRep repository.HostRepository,
  38. cdnService flexCdn.CdnService,
  39. cdnRep flexCdn2.CdnRepository,
  40. tcpforwardingRep waf.TcpforwardingRepository,
  41. udpForWardingRep waf.UdpForWardingRepository,
  42. webForWardingRep waf.WebForwardingRepository,
  43. allowAndDeny AllowAndDenyIpService,
  44. allowAndDenyRep waf.AllowAndDenyIpRepository,
  45. tcpforwarding TcpforwardingService,
  46. udpForWarding UdpForWardingService,
  47. webForWarding WebForwardingService,
  48. gatewayIpRep waf.GatewayipRepository,
  49. gatywayIp GatewayipService,
  50. bulidAudun BuildAudunService,
  51. ) GlobalLimitService {
  52. return &globalLimitService{
  53. Service: service,
  54. globalLimitRepository: globalLimitRepository,
  55. duedate: duedate,
  56. crawler: crawler,
  57. Url: conf.GetString("crawler.Url"),
  58. required: required,
  59. parser: parser,
  60. host: host,
  61. hostRep: hostRep,
  62. cdnService: cdnService,
  63. cdnRep: cdnRep,
  64. tcpforwardingRep: tcpforwardingRep,
  65. udpForWardingRep: udpForWardingRep,
  66. webForWardingRep: webForWardingRep,
  67. allowAndDeny: allowAndDeny,
  68. allowAndDenyRep: allowAndDenyRep,
  69. tcpforwarding: tcpforwarding,
  70. udpForWarding: udpForWarding,
  71. webForWarding: webForWarding,
  72. gatewayIpRep: gatewayIpRep,
  73. gatewayIp: gatywayIp,
  74. bulidAudun: bulidAudun,
  75. }
  76. }
  77. type globalLimitService struct {
  78. *service.Service
  79. globalLimitRepository waf.GlobalLimitRepository
  80. duedate service.DuedateService
  81. crawler service.CrawlerService
  82. Url string
  83. required service.RequiredService
  84. parser service.ParserService
  85. host service.HostService
  86. hostRep repository.HostRepository
  87. cdnService flexCdn.CdnService
  88. cdnRep flexCdn2.CdnRepository
  89. tcpforwardingRep waf.TcpforwardingRepository
  90. udpForWardingRep waf.UdpForWardingRepository
  91. webForWardingRep waf.WebForwardingRepository
  92. allowAndDeny AllowAndDenyIpService
  93. allowAndDenyRep waf.AllowAndDenyIpRepository
  94. tcpforwarding TcpforwardingService
  95. udpForWarding UdpForWardingService
  96. webForWarding WebForwardingService
  97. gatewayIpRep waf.GatewayipRepository
  98. gatewayIp GatewayipService
  99. bulidAudun BuildAudunService
  100. }
  101. func (s *globalLimitService) GetCdnUserId(ctx context.Context, uid int64) (int64, error) {
  102. data, err := s.globalLimitRepository.GetGlobalLimitFirst(ctx, uid)
  103. if err != nil {
  104. if !errors.Is(err, gorm.ErrRecordNotFound) {
  105. return 0, err
  106. }
  107. }
  108. if data != nil && data.CdnUid != 0 {
  109. return int64(data.CdnUid), nil
  110. }
  111. userInfo,err := s.globalLimitRepository.GetUserInfo(ctx, uid)
  112. if err != nil {
  113. return 0, err
  114. }
  115. // 中文转拼音
  116. a := pinyin.NewArgs()
  117. a.Style = pinyin.Normal
  118. pinyinSlice := pinyin.LazyPinyin(userInfo.Username, a)
  119. userName := strconv.Itoa(int(uid)) + "_" + strings.Join(pinyinSlice, "_")
  120. // 查询用户是否存在
  121. UserId,err := s.cdnRep.GetUserId(ctx, userName)
  122. if err != nil {
  123. return 0, err
  124. }
  125. if UserId != 0 {
  126. return UserId, nil
  127. }
  128. // 注册用户
  129. userId, err := s.cdnService.AddUser(ctx, v1.User{
  130. Username: userName,
  131. Email: userInfo.Email,
  132. Fullname: userInfo.Username,
  133. Mobile: userInfo.PhoneNumber,
  134. })
  135. if err != nil {
  136. return 0, err
  137. }
  138. return userId, nil
  139. }
  140. func (s *globalLimitService) AddGroupId(ctx context.Context,groupName string) (int64, error) {
  141. groupId, err := s.cdnService.CreateGroup(ctx, v1.Group{
  142. Name: groupName,
  143. })
  144. if err != nil {
  145. return 0, err
  146. }
  147. return groupId, nil
  148. }
  149. func (s *globalLimitService) GlobalLimitRequire(ctx context.Context, req v1.GlobalLimitRequest) (res v1.GlobalLimitRequireResponse, err error) {
  150. res.ExpiredAt, err = s.duedate.NextDueDate(ctx, req.Uid, req.HostId)
  151. if err != nil {
  152. return v1.GlobalLimitRequireResponse{}, err
  153. }
  154. configCount, err := s.host.GetGlobalLimitConfig(ctx, req.HostId)
  155. if err != nil {
  156. return v1.GlobalLimitRequireResponse{}, fmt.Errorf("获取配置限制失败: %w", err)
  157. }
  158. res.MaxBytesMonth = configCount.MaxBytesMonth
  159. res.Operator = configCount.Operator
  160. res.IpCount = configCount.IpCount
  161. res.NodeArea = configCount.NodeArea
  162. res.ConfigMaxProtection = configCount.ConfigMaxProtection
  163. res.IsBanUdp = configCount.IsBanUdp
  164. res.HostId = req.HostId
  165. res.Bps = configCount.Bps
  166. domain, err := s.hostRep.GetDomainById(ctx, req.HostId)
  167. if err != nil {
  168. return v1.GlobalLimitRequireResponse{}, err
  169. }
  170. userInfo,err := s.globalLimitRepository.GetUserInfo(ctx, int64(req.Uid))
  171. if err != nil {
  172. return v1.GlobalLimitRequireResponse{}, err
  173. }
  174. res.GlobalLimitName = strconv.Itoa(req.Uid) + "_" + userInfo.Username + "_" + strconv.Itoa(req.HostId) + "_" + domain
  175. res.HostName, err = s.globalLimitRepository.GetHostName(ctx, int64(req.HostId))
  176. if err != nil {
  177. return v1.GlobalLimitRequireResponse{}, err
  178. }
  179. return res, nil
  180. }
  181. func (s *globalLimitService) GetGlobalLimit(ctx context.Context, id int64) (*model.GlobalLimit, error) {
  182. return s.globalLimitRepository.GetGlobalLimit(ctx, id)
  183. }
  184. func (s *globalLimitService) ConversionTime(ctx context.Context,req string) (string, error) {
  185. // 2. 将字符串解析成 time.Time 对象
  186. // time.Parse 会根据你提供的布局来理解输入的字符串
  187. t, err := time.Parse("2006-01-02 15:04:05", req)
  188. if err != nil {
  189. // 如果输入的字符串格式和布局不匹配,这里会报错
  190. return "", fmt.Errorf("输入的字符串格式和布局不匹配 %w", err)
  191. }
  192. // 3. 定义新的输出格式 "YYYY-MM-DD"
  193. outputLayout := "2006-01-02"
  194. // 4. 将 time.Time 对象格式化为新的字符串
  195. outputTimeStr := t.Format(outputLayout)
  196. return outputTimeStr, nil
  197. }
  198. func (s *globalLimitService) ConversionTimeUnix(ctx context.Context,req string) (int64, error) {
  199. t, err := time.Parse("2006-01-02 15:04:05", req)
  200. if err != nil {
  201. return 0, fmt.Errorf("输入的字符串格式和布局不匹配 %w", err)
  202. }
  203. expiredAt := t.Unix()
  204. return expiredAt, nil
  205. }
  206. func (s *globalLimitService) AddGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error {
  207. isExist, err := s.globalLimitRepository.IsGlobalLimitExistByHostId(ctx, int64(req.HostId))
  208. if err != nil {
  209. return err
  210. }
  211. if isExist {
  212. return fmt.Errorf("配置限制已存在")
  213. }
  214. require, err := s.GlobalLimitRequire(ctx, req)
  215. if err != nil {
  216. return err
  217. }
  218. g, gCtx := errgroup.WithContext(ctx)
  219. var userId int64
  220. var groupId int64
  221. g.Go(func() error {
  222. e := s.gatewayIp.AddIpWhereHostIdNull(gCtx, int64(req.HostId),int64(req.Uid))
  223. if e != nil {
  224. return fmt.Errorf("获取网关组失败: %w", e)
  225. }
  226. return nil
  227. })
  228. g.Go(func() error {
  229. res, e := s.GetCdnUserId(gCtx, int64(req.Uid))
  230. if e != nil {
  231. return fmt.Errorf("获取cdn用户失败: %w", e)
  232. }
  233. if res == 0 {
  234. return fmt.Errorf("获取cdn用户失败")
  235. }
  236. userId = res
  237. return nil
  238. })
  239. g.Go(func() error {
  240. res, e := s.AddGroupId(gCtx, require.GlobalLimitName)
  241. if e != nil {
  242. return fmt.Errorf("创建规则分组失败: %w", e)
  243. }
  244. if res == 0 {
  245. return fmt.Errorf("创建规则分组失败")
  246. }
  247. groupId = res
  248. return nil
  249. })
  250. if err = g.Wait(); err != nil {
  251. return err
  252. }
  253. // 添加带宽限制
  254. err = s.bulidAudun.Bandwidth(ctx, int64(req.HostId), "add")
  255. if err != nil {
  256. return err
  257. }
  258. expiredAt, err := s.ConversionTimeUnix(ctx, require.ExpiredAt)
  259. if err != nil {
  260. return err
  261. }
  262. // 如果存在实例,恢复
  263. oldData,err := s.globalLimitRepository.GetGlobalLimitByHostId(ctx, int64(req.HostId))
  264. if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
  265. return err
  266. }
  267. if oldData!= nil && oldData.Id != 0 {
  268. err = s.globalLimitRepository.UpdateGlobalLimitByHostId(ctx, &model.GlobalLimit{
  269. HostId: req.HostId,
  270. Uid: req.Uid,
  271. Name: require.GlobalLimitName,
  272. GroupId: int(groupId),
  273. CdnUid: int(userId),
  274. Comment: req.Comment,
  275. ExpiredAt: expiredAt,
  276. State: true,
  277. })
  278. if err != nil {
  279. return err
  280. }
  281. return nil
  282. }
  283. // 如果不存在实例,创建
  284. err = s.globalLimitRepository.AddGlobalLimit(ctx, &model.GlobalLimit{
  285. HostId: req.HostId,
  286. Uid: req.Uid,
  287. Name: require.GlobalLimitName,
  288. GroupId: int(groupId),
  289. CdnUid: int(userId),
  290. Comment: req.Comment,
  291. State: true,
  292. ExpiredAt: expiredAt,
  293. })
  294. if err != nil {
  295. return err
  296. }
  297. return nil
  298. }
  299. func (s *globalLimitService) EditGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error {
  300. require, err := s.GlobalLimitRequire(ctx, req)
  301. if err != nil {
  302. return err
  303. }
  304. // 如果不存在实例,创建
  305. gatewayIp, err := s.gatewayIpRep.GetGatewayipByHostIdAll(ctx, int64(req.HostId))
  306. if err != nil {
  307. return err
  308. }
  309. if gatewayIp == nil {
  310. err = s.gatewayIp.AddIpWhereHostIdNull(ctx, int64(req.HostId), int64(req.Uid))
  311. if err != nil {
  312. return fmt.Errorf("获取网关组失败: %w", err)
  313. }
  314. }
  315. expiredAt, err := s.ConversionTimeUnix(ctx, require.ExpiredAt)
  316. if err != nil {
  317. return err
  318. }
  319. if err := s.globalLimitRepository.UpdateGlobalLimitByHostId(ctx, &model.GlobalLimit{
  320. HostId: req.HostId,
  321. Comment: req.Comment,
  322. ExpiredAt: expiredAt,
  323. }); err != nil {
  324. return err
  325. }
  326. return nil
  327. }
  328. func (s *globalLimitService) DeleteGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error {
  329. // 检查是否过期
  330. isExpired, err := s.host.CheckExpired(ctx, int64(req.Uid), int64(req.HostId))
  331. if err != nil {
  332. return err
  333. }
  334. if isExpired {
  335. return fmt.Errorf("实例未过期,无法删除")
  336. }
  337. oldData, err := s.globalLimitRepository.GetGlobalLimitByHostId(ctx, int64(req.HostId))
  338. if err != nil {
  339. if errors.Is(err, gorm.ErrRecordNotFound) {
  340. return fmt.Errorf("实例不存在")
  341. }
  342. return err
  343. }
  344. tcpIds, err := s.tcpforwardingRep.GetTcpForwardingAllIdsByID(ctx,req.HostId)
  345. if err != nil {
  346. return err
  347. }
  348. udpIds, err := s.udpForWardingRep.GetUdpForwardingWafUdpAllIds(ctx,req.HostId)
  349. if err != nil {
  350. return err
  351. }
  352. webIds, err := s.webForWardingRep.GetWebForwardingWafWebAllIds(ctx,req.HostId)
  353. if err != nil {
  354. return err
  355. }
  356. // 删除带宽限制
  357. err = s.bulidAudun.Bandwidth(ctx, int64(req.HostId), "del")
  358. if err != nil {
  359. return err
  360. }
  361. // 黑白IP
  362. BwIds, err := s.allowAndDenyRep.GetIpCountListId(ctx, int64(req.HostId))
  363. if err != nil {
  364. return err
  365. }
  366. // 删除网站
  367. g, gCtx := errgroup.WithContext(ctx)
  368. g.Go(func() error {
  369. e := s.tcpforwarding.DeleteTcpForwarding(ctx,v1.DeleteTcpForwardingRequest{
  370. Ids: tcpIds,
  371. Uid: req.Uid,
  372. HostId: req.HostId,
  373. })
  374. if e != nil {
  375. return fmt.Errorf("删除TCP转发失败: %w", e)
  376. }
  377. return nil
  378. })
  379. g.Go(func() error {
  380. e := s.udpForWarding.DeleteUdpForwarding(ctx, v1.DeleteUdpForwardingRequest{
  381. Ids: udpIds,
  382. Uid: req.Uid,
  383. HostId: req.HostId,
  384. })
  385. if e != nil {
  386. return fmt.Errorf("删除UDP转发失败: %w", e)
  387. }
  388. return nil
  389. })
  390. g.Go(func() error {
  391. e := s.webForWarding.DeleteWebForwarding(ctx,v1.DeleteWebForwardingRequest{
  392. Ids: webIds,
  393. Uid: req.Uid,
  394. HostId: req.HostId,
  395. })
  396. if e != nil {
  397. return fmt.Errorf("删除WEB转发失败: %w", e)
  398. }
  399. return nil
  400. })
  401. // 删除网站分组
  402. g.Go(func() error {
  403. e := s.cdnService.DelServerGroup(gCtx, int64(oldData.GroupId))
  404. if e != nil {
  405. return fmt.Errorf("删除网站分组失败: %w", e)
  406. }
  407. return nil
  408. })
  409. if err = g.Wait(); err != nil {
  410. return err
  411. }
  412. if err := s.globalLimitRepository.EditHostState(ctx, int64(req.HostId), false); err != nil {
  413. return err
  414. }
  415. if err := s.gatewayIpRep.CleanIPByHostId(ctx, []int64{int64(req.HostId)}); err != nil {
  416. return err
  417. }
  418. // 删除黑白名单
  419. err = s.allowAndDeny.DeleteAllowAndDenyIps(ctx, v1.DelAllowAndDenyIpRequest{
  420. HostId: req.HostId,
  421. Ids: BwIds,
  422. Uid: req.Uid,
  423. })
  424. if err != nil {
  425. return err
  426. }
  427. return nil
  428. }