globallimit.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375
  1. package service
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
  7. "github.com/go-nunu/nunu-layout-advanced/internal/model"
  8. "github.com/go-nunu/nunu-layout-advanced/internal/repository"
  9. "github.com/mozillazg/go-pinyin"
  10. "github.com/spf13/viper"
  11. "golang.org/x/sync/errgroup"
  12. "gorm.io/gorm"
  13. "strconv"
  14. "strings"
  15. "time"
  16. )
  17. type GlobalLimitService interface {
  18. GetGlobalLimit(ctx context.Context, id int64) (*model.GlobalLimit, error)
  19. AddGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error
  20. EditGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error
  21. DeleteGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error
  22. }
  23. func NewGlobalLimitService(
  24. service *Service,
  25. globalLimitRepository repository.GlobalLimitRepository,
  26. duedate DuedateService,
  27. crawler CrawlerService,
  28. conf *viper.Viper,
  29. required RequiredService,
  30. parser ParserService,
  31. host HostService,
  32. tcpLimit TcpLimitService,
  33. udpLimit UdpLimitService,
  34. webLimit WebLimitService,
  35. gateWayGroup GatewayGroupService,
  36. hostRep repository.HostRepository,
  37. gateWayGroupRep repository.GatewayGroupRepository,
  38. cdnService CdnService,
  39. cdnRep repository.CdnRepository,
  40. ) GlobalLimitService {
  41. return &globalLimitService{
  42. Service: service,
  43. globalLimitRepository: globalLimitRepository,
  44. duedate: duedate,
  45. crawler: crawler,
  46. Url: conf.GetString("crawler.Url"),
  47. required: required,
  48. parser: parser,
  49. host: host,
  50. tcpLimit: tcpLimit,
  51. udpLimit: udpLimit,
  52. webLimit: webLimit,
  53. gateWayGroup: gateWayGroup,
  54. hostRep: hostRep,
  55. gateWayGroupRep: gateWayGroupRep,
  56. cdnService: cdnService,
  57. cdnRep: cdnRep,
  58. }
  59. }
  60. type globalLimitService struct {
  61. *Service
  62. globalLimitRepository repository.GlobalLimitRepository
  63. duedate DuedateService
  64. crawler CrawlerService
  65. Url string
  66. required RequiredService
  67. parser ParserService
  68. host HostService
  69. tcpLimit TcpLimitService
  70. udpLimit UdpLimitService
  71. webLimit WebLimitService
  72. gateWayGroup GatewayGroupService
  73. hostRep repository.HostRepository
  74. gateWayGroupRep repository.GatewayGroupRepository
  75. cdnService CdnService
  76. cdnRep repository.CdnRepository
  77. }
  78. func (s *globalLimitService) GetCdnUserId(ctx context.Context, uid int64) (int64, error) {
  79. data, err := s.globalLimitRepository.GetGlobalLimitFirst(ctx, uid)
  80. if err != nil {
  81. if !errors.Is(err, gorm.ErrRecordNotFound) {
  82. return 0, err
  83. }
  84. }
  85. if data != nil && data.CdnUid != 0 {
  86. return int64(data.CdnUid), nil
  87. }
  88. userInfo,err := s.globalLimitRepository.GetUserInfo(ctx, uid)
  89. if err != nil {
  90. return 0, err
  91. }
  92. // 中文转拼音
  93. a := pinyin.NewArgs()
  94. a.Style = pinyin.Normal
  95. pinyinSlice := pinyin.LazyPinyin(userInfo.Username, a)
  96. userName := strconv.Itoa(int(uid)) + "_" + strings.Join(pinyinSlice, "_")
  97. // 查询用户是否存在
  98. UserId,err := s.cdnRep.GetUserId(ctx, userName)
  99. if err != nil {
  100. return 0, err
  101. }
  102. if UserId != 0 {
  103. return UserId, nil
  104. }
  105. // 注册用户
  106. userId, err := s.cdnService.AddUser(ctx, v1.User{
  107. Username: userName,
  108. Email: userInfo.Email,
  109. Fullname: userInfo.Username,
  110. Mobile: userInfo.PhoneNumber,
  111. })
  112. if err != nil {
  113. return 0, err
  114. }
  115. return userId, nil
  116. }
  117. func (s *globalLimitService) AddGroupId(ctx context.Context,groupName string) (int64, error) {
  118. groupId, err := s.cdnService.CreateGroup(ctx, v1.Group{
  119. Name: groupName,
  120. })
  121. if err != nil {
  122. return 0, err
  123. }
  124. return groupId, nil
  125. }
  126. func (s *globalLimitService) GlobalLimitRequire(ctx context.Context, req v1.GlobalLimitRequest) (res v1.GlobalLimitRequireResponse, err error) {
  127. res.ExpiredAt, err = s.duedate.NextDueDate(ctx, req.Uid, req.HostId)
  128. if err != nil {
  129. return v1.GlobalLimitRequireResponse{}, err
  130. }
  131. configCount, err := s.host.GetGlobalLimitConfig(ctx, req.HostId)
  132. if err != nil {
  133. return v1.GlobalLimitRequireResponse{}, fmt.Errorf("获取配置限制失败: %w", err)
  134. }
  135. bpsInt, err := strconv.Atoi(configCount.Bps)
  136. if err != nil {
  137. return v1.GlobalLimitRequireResponse{}, err
  138. }
  139. resultFloat := float64(bpsInt) / 2.0 / 8.0
  140. res.Bps = strconv.FormatFloat( resultFloat, 'f', -1, 64) + "M"
  141. res.MaxBytesMonth = configCount.MaxBytesMonth
  142. res.Operator = configCount.Operator
  143. res.IpCount = configCount.IpCount
  144. res.NodeArea = configCount.NodeArea
  145. res.ConfigMaxProtection = configCount.ConfigMaxProtection
  146. domain, err := s.hostRep.GetDomainById(ctx, req.HostId)
  147. if err != nil {
  148. return v1.GlobalLimitRequireResponse{}, err
  149. }
  150. userInfo,err := s.globalLimitRepository.GetUserInfo(ctx, int64(req.Uid))
  151. if err != nil {
  152. return v1.GlobalLimitRequireResponse{}, err
  153. }
  154. res.GlobalLimitName = strconv.Itoa(req.Uid) + "_" + userInfo.Username + "_" + strconv.Itoa(req.HostId) + "_" + domain
  155. res.HostName, err = s.globalLimitRepository.GetHostName(ctx, int64(req.HostId))
  156. if err != nil {
  157. return v1.GlobalLimitRequireResponse{}, err
  158. }
  159. return res, nil
  160. }
  161. func (s *globalLimitService) GetGlobalLimit(ctx context.Context, id int64) (*model.GlobalLimit, error) {
  162. return s.globalLimitRepository.GetGlobalLimit(ctx, id)
  163. }
  164. func (s *globalLimitService) ConversionTime(ctx context.Context,req string) (string, error) {
  165. // 2. 将字符串解析成 time.Time 对象
  166. // time.Parse 会根据你提供的布局来理解输入的字符串
  167. t, err := time.Parse("2006-01-02 15:04:05", req)
  168. if err != nil {
  169. // 如果输入的字符串格式和布局不匹配,这里会报错
  170. return "", fmt.Errorf("输入的字符串格式和布局不匹配 %w", err)
  171. }
  172. // 3. 定义新的输出格式 "YYYY-MM-DD"
  173. outputLayout := "2006-01-02"
  174. // 4. 将 time.Time 对象格式化为新的字符串
  175. outputTimeStr := t.Format(outputLayout)
  176. return outputTimeStr, nil
  177. }
  178. func (s *globalLimitService) ConversionTimeUnix(ctx context.Context,req string) (int64, error) {
  179. t, err := time.Parse("2006-01-02 15:04:05", req)
  180. if err != nil {
  181. return 0, fmt.Errorf("输入的字符串格式和布局不匹配 %w", err)
  182. }
  183. expiredAt := t.Unix()
  184. return expiredAt, nil
  185. }
  186. func (s *globalLimitService) AddGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error {
  187. isExist, err := s.globalLimitRepository.IsGlobalLimitExistByHostId(ctx, int64(req.HostId))
  188. if err != nil {
  189. return err
  190. }
  191. if isExist {
  192. return fmt.Errorf("配置限制已存在")
  193. }
  194. require, err := s.GlobalLimitRequire(ctx, req)
  195. if err != nil {
  196. return err
  197. }
  198. g, gCtx := errgroup.WithContext(ctx)
  199. var gatewayGroupId int
  200. var userId int64
  201. var groupId int64
  202. g.Go(func() error {
  203. res, e := s.gateWayGroupRep.GetGatewayGroupWhereHostIdNull(gCtx, require.Operator, require.IpCount)
  204. if e != nil {
  205. return fmt.Errorf("获取网关组失败: %w", e)
  206. }
  207. if res == 0 {
  208. return fmt.Errorf("获取网关组失败")
  209. }
  210. gatewayGroupId = res
  211. return nil
  212. })
  213. g.Go(func() error {
  214. res, e := s.GetCdnUserId(gCtx, int64(req.Uid))
  215. if e != nil {
  216. return fmt.Errorf("获取cdn用户失败: %w", e)
  217. }
  218. if res == 0 {
  219. return fmt.Errorf("获取cdn用户失败")
  220. }
  221. userId = res
  222. return nil
  223. })
  224. g.Go(func() error {
  225. res, e := s.AddGroupId(gCtx, require.GlobalLimitName)
  226. if e != nil {
  227. return fmt.Errorf("创建规则分组失败: %w", e)
  228. }
  229. if res == 0 {
  230. return fmt.Errorf("创建规则分组失败")
  231. }
  232. groupId = res
  233. return nil
  234. })
  235. if err = g.Wait(); err != nil {
  236. return err
  237. }
  238. outputTimeStr, err := s.ConversionTime(ctx, require.ExpiredAt)
  239. if err != nil {
  240. return err
  241. }
  242. // 获取套餐ID
  243. maxProtection := strings.Split(require.ConfigMaxProtection, "G")
  244. maxProtectionInt, err := strconv.Atoi(maxProtection[0])
  245. if err != nil {
  246. return err
  247. }
  248. var planId int64
  249. if maxProtectionInt < 1000 {
  250. NodeAreaName := fmt.Sprintf("%s-%dT",require.NodeArea, 1)
  251. planId, err = s.globalLimitRepository.GetNodeArea(ctx, NodeAreaName)
  252. if err != nil {
  253. return err
  254. }
  255. }
  256. if planId == 0 {
  257. planId = 6
  258. }
  259. ruleId, err := s.cdnService.BindPlan(ctx, v1.Plan{
  260. UserId: userId,
  261. PlanId: planId,
  262. DayTo: outputTimeStr,
  263. Name: require.GlobalLimitName,
  264. IsFree: true,
  265. Period: "monthly",
  266. CountPeriod: 1,
  267. PeriodDayTo: outputTimeStr,
  268. })
  269. if err != nil {
  270. return err
  271. }
  272. if ruleId == 0 {
  273. return fmt.Errorf("分配套餐失败")
  274. }
  275. err = s.gateWayGroupRep.EditGatewayGroup(ctx, &model.GatewayGroup{
  276. Id: gatewayGroupId,
  277. HostId: req.HostId,
  278. })
  279. if err != nil {
  280. return err
  281. }
  282. expiredAt, err := s.ConversionTimeUnix(ctx, require.ExpiredAt)
  283. if err != nil {
  284. return err
  285. }
  286. err = s.globalLimitRepository.AddGlobalLimit(ctx, &model.GlobalLimit{
  287. HostId: req.HostId,
  288. Uid: req.Uid,
  289. Name: require.GlobalLimitName,
  290. RuleId: int(ruleId),
  291. GroupId: int(groupId),
  292. GatewayGroupId: gatewayGroupId,
  293. CdnUid: int(userId),
  294. Comment: req.Comment,
  295. ExpiredAt: expiredAt,
  296. })
  297. if err != nil {
  298. return err
  299. }
  300. return nil
  301. }
  302. func (s *globalLimitService) EditGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error {
  303. require, err := s.GlobalLimitRequire(ctx, req)
  304. if err != nil {
  305. return err
  306. }
  307. data, err := s.globalLimitRepository.GetGlobalLimitByHostId(ctx, int64(req.HostId))
  308. if err != nil {
  309. return err
  310. }
  311. outputTimeStr, err := s.ConversionTime(ctx, require.ExpiredAt)
  312. if err != nil {
  313. return err
  314. }
  315. err = s.cdnService.RenewPlan(ctx, v1.RenewalPlan{
  316. UserPlanId: int64(data.RuleId),
  317. DayTo: outputTimeStr,
  318. Period: "monthly",
  319. CountPeriod: 1,
  320. IsFree: true,
  321. PeriodDayTo: outputTimeStr,
  322. })
  323. if err != nil {
  324. return err
  325. }
  326. expiredAt, err := s.ConversionTimeUnix(ctx, require.ExpiredAt)
  327. if err != nil {
  328. return err
  329. }
  330. if err := s.globalLimitRepository.UpdateGlobalLimitByHostId(ctx, &model.GlobalLimit{
  331. HostId: req.HostId,
  332. Comment: req.Comment,
  333. ExpiredAt: expiredAt,
  334. }); err != nil {
  335. return err
  336. }
  337. return nil
  338. }
  339. func (s *globalLimitService) DeleteGlobalLimit(ctx context.Context, req v1.GlobalLimitRequest) error {
  340. if err := s.globalLimitRepository.DeleteGlobalLimitByHostId(ctx, int64(req.HostId)); err != nil {
  341. return err
  342. }
  343. return nil
  344. }