package repository import ( "context" "errors" "fmt" v1 "github.com/go-nunu/nunu-layout-advanced/api/v1" "github.com/go-nunu/nunu-layout-advanced/internal/model" "gorm.io/gorm" "time" ) type GlobalLimitRepository interface { GetGlobalLimit(ctx context.Context, id int64) (*model.GlobalLimit, error) AddGlobalLimit(ctx context.Context, req *model.GlobalLimit) error UpdateGlobalLimitByHostId(ctx context.Context, req *model.GlobalLimit) error IsGlobalLimitExistByHostId(ctx context.Context, hostId int64) (bool, error) GetGlobalLimitByHostId(ctx context.Context, hostId int64) (*model.GlobalLimit, error) GetGlobalLimitAllExpired(ctx context.Context,ids []int) ([]v1.GlobalLimitExpiredByHost, error) GetGlobalLimitAllHostId(ctx context.Context) ([]v1.GlobalLimitExpired, error) GetGlobalLimitFirst(ctx context.Context,uid int64) (*model.GlobalLimit, error) GetUserInfo(ctx context.Context, uid int64) (v1.UserInfo, error) GetHostName(ctx context.Context,hostId int64) (string, error) GetNodeId(ctx context.Context, cndWebId int) (int64, error) // 获取套餐Id GetNodeArea(ctx context.Context, nodeAreaName string) (int64, error) // 修改套餐状态 EditHostState(ctx context.Context, hostId int64, state bool) error // 获取指定到期时间 GetGlobalLimitAlmostExpired(ctx context.Context, addTime int64) ([]model.GlobalLimit, error) // GetGlobalLimitsByExpirationRange 获取在指定时间范围内到期的全局限制 GetGlobalLimitsByExpirationRange(ctx context.Context, startTime, endTime int64) ([]model.GlobalLimit, error) } func NewGlobalLimitRepository( repository *Repository, ) GlobalLimitRepository { return &globalLimitRepository{ Repository: repository, } } type globalLimitRepository struct { *Repository } func (r *globalLimitRepository) GetGlobalLimit(ctx context.Context, id int64) (*model.GlobalLimit, error) { var globalLimit model.GlobalLimit return &globalLimit, nil } func (r *globalLimitRepository) AddGlobalLimit(ctx context.Context, req *model.GlobalLimit) error { if err := r.DB(ctx).Create(&req).Error; err != nil { return err } return nil } func (r *globalLimitRepository) UpdateGlobalLimitByHostId(ctx context.Context, req *model.GlobalLimit) error { if err := r.DB(ctx).Where("host_id = ?", req.HostId).Updates(&req).Error; err != nil { return err } return nil } func (r *globalLimitRepository) IsGlobalLimitExistByHostId(ctx context.Context, hostId int64) (bool, error) { var count int64 err := r.DB(ctx).Model(&model.GlobalLimit{}).Where("host_id = ? AND state = false", hostId).Count(&count).Error if err != nil { return false, err } return count > 0, nil } func (r *globalLimitRepository) GetGlobalLimitByHostId(ctx context.Context, hostId int64) (*model.GlobalLimit, error) { var globalLimit model.GlobalLimit if err := r.DB(ctx).Where("host_id = ?", hostId).First(&globalLimit).Error; err != nil { return nil, err } return &globalLimit, nil } func (r *globalLimitRepository) GetGlobalLimitAllExpired(ctx context.Context,ids []int) ([]v1.GlobalLimitExpiredByHost, error) { var res []v1.GlobalLimitExpiredByHost threeDaysDuration := 30 * 24 * time.Hour targetTime := time.Now().Add(threeDaysDuration) targetTimestamp := targetTime.Unix() if err := r.DB(ctx).Table("shd_host"). Where("id IN (?)", ids). Where("nextduedate < ?", targetTimestamp). Select("id", "uid", "nextduedate"). Find(&res). Error; err != nil { return nil, err } return res, nil } func (r *globalLimitRepository) GetGlobalLimitAllHostId(ctx context.Context) ([]v1.GlobalLimitExpired, error) { var res []v1.GlobalLimitExpired if err := r.DB(ctx).Model(&model.GlobalLimit{}). Select("host_id", "rule_id","comment"). Find(&res).Error; err != nil { return nil, err } return res, nil } func (r *globalLimitRepository) GetGlobalLimitFirst(ctx context.Context,uid int64) (*model.GlobalLimit, error) { var req model.GlobalLimit if err := r.DB(ctx).Where("uid = ?", uid).First(&req).Error; err != nil { return nil, err } return &req, nil } func (r *globalLimitRepository) GetUserInfo(ctx context.Context, uid int64) (v1.UserInfo, error) { var res v1.UserInfo if err := r.DB(ctx).Table("shd_clients"). Where("id = ?", uid). Select("username", "email", "phonenumber"). Find(&res).Error; err != nil { return v1.UserInfo{}, err } return res, nil } func (r *globalLimitRepository) GetHostName(ctx context.Context,hostId int64) (string, error) { var projectName string err := r.db.WithContext(ctx).Table("shd_host"). Select("shd_products.name"). Joins("JOIN shd_products ON shd_host.productid = shd_products.id"). Where("shd_host.id = ?", hostId). Find(&projectName).Error if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return "", fmt.Errorf("未找到 hostId 为 %d 的项目名称", hostId) } return "", fmt.Errorf("查询 host 和 project 名称失败: %w", err) } // 如果查询成功,返回项目名称 return projectName, nil } func (r *globalLimitRepository) GetNodeId(ctx context.Context, cndWebId int) (int64, error) { var nodeId int64 if err := r.DBWithName(ctx,"cdn").WithContext(ctx).Table("cloud_servers").Where("id = ?", cndWebId).Select("clusterId").Scan(&nodeId).Error; err != nil { return 0, err } return nodeId, nil } // 获取cdn套餐ID func (r *globalLimitRepository) GetNodeArea(ctx context.Context, nodeAreaName string) (int64, error) { var nodeId int64 if err := r.DBWithName(ctx,"cdn").WithContext(ctx).Table("cloud_plans").Where("name = ?", nodeAreaName).Select("id").Scan(&nodeId).Error; err != nil { return 0, err } return nodeId, nil } func (r *globalLimitRepository) EditHostState(ctx context.Context, hostId int64, state bool) error { if err := r.DB(ctx).Model(&model.GlobalLimit{}).Where("host_id = ?", hostId).Update("state", state).Error; err != nil { return err } return nil } // GetGlobalLimitsByExpirationRange 获取在指定时间范围内到期的全局限制 func (r *globalLimitRepository) GetGlobalLimitsByExpirationRange(ctx context.Context, startTime, endTime int64) ([]model.GlobalLimit, error) { var res []model.GlobalLimit db := r.DB(ctx).Where("state = ?", true) if startTime != 0 { db = db.Where("expired_at >= ?", startTime) } if endTime != 0 { db = db.Where("expired_at < ?", endTime) } if err := db.Find(&res).Error; err != nil { return nil, err } return res, nil } // 获取指定到期时间 func (r *globalLimitRepository) GetGlobalLimitAlmostExpired(ctx context.Context, addTime int64) ([]model.GlobalLimit, error) { var res []model.GlobalLimit expiredTime := time.Now().Unix() + addTime if err := r.DB(ctx). Where("expired_At < ?", expiredTime). Find(&res).Error; err != nil { return nil, err } return res, nil }