Browse Source

add login demo

chris 2 years ago
parent
commit
16781d8eed

+ 2 - 2
cmd/server/main.go

@@ -13,12 +13,12 @@ func main() {
 	conf := config.NewConfig()
 	logger := log.NewLog(conf)
 
-	logger.Info("server start", zap.String("host", "http://127.0.0.1:"+conf.GetString("http.port")))
-
 	app, cleanup, err := wire.NewApp(conf, logger)
 	if err != nil {
 		panic(err)
 	}
+	logger.Info("server start", zap.String("host", "http://127.0.0.1:"+conf.GetString("http.port")))
+
 	http.Run(app, fmt.Sprintf(":%d", conf.GetInt("http.port")))
 	defer cleanup()
 

+ 1 - 1
cmd/server/wire/wire_gen.go

@@ -25,7 +25,7 @@ func NewApp(viperViper *viper.Viper, logger *log.Logger) (*gin.Engine, func(), e
 	jwt := middleware.NewJwt(viperViper)
 	sonyflakeSonyflake := sonyflake.NewSonyflake()
 	handlerHandler := handler.NewHandler(logger, sonyflakeSonyflake)
-	serviceService := service.NewService(logger)
+	serviceService := service.NewService(logger, sonyflakeSonyflake, jwt)
 	db := dao.NewDB(viperViper)
 	client := dao.NewRedis(viperViper)
 	daoDao := dao.NewDao(db, client, logger)

+ 1 - 1
config/local.yml

@@ -6,7 +6,7 @@ security:
     app_key: 123456
     app_security: 123456
   jwt:
-    key: 1234
+    key: QQYnRFerJTSEcrfB89fw8prOaObmrch8
 data:
   mysql:
     user: root:123456@tcp(127.0.0.1:3380)/user?charset=utf8mb4&parseTime=True&loc=Local

+ 1 - 1
config/prod.yml

@@ -6,7 +6,7 @@ security:
     app_key: 123456
     app_security: 123456
   jwt:
-    key: 1234
+    key: QQYnRFerJTSEcrfB89fw8prOaObmrch8
 data:
   mysql:
     user: root:123456@tcp(127.0.0.1:3380)/user?charset=utf8mb4&parseTime=True&loc=Local

+ 4 - 2
go.mod

@@ -4,15 +4,17 @@ go 1.16
 
 require (
 	github.com/gin-gonic/gin v1.9.1
-	github.com/golang-jwt/jwt/v5 v5.0.0 // indirect
+	github.com/golang-jwt/jwt/v5 v5.0.0
 	github.com/google/wire v0.5.0
-	github.com/redis/go-redis/v9 v9.0.5 // indirect
+	github.com/pkg/errors v0.9.1
+	github.com/redis/go-redis/v9 v9.0.5
 	github.com/robfig/cron v1.2.0
 	github.com/satori/go.uuid v1.2.0
 	github.com/sony/sonyflake v1.1.0
 	github.com/spf13/viper v1.16.0
 	github.com/stretchr/testify v1.8.4
 	go.uber.org/zap v1.24.0
+	golang.org/x/crypto v0.9.0
 	gopkg.in/natefinch/lumberjack.v2 v2.2.1
 	gorm.io/driver/mysql v1.5.1
 	gorm.io/gorm v1.25.1

+ 2 - 0
go.sum

@@ -626,7 +626,9 @@ github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6r
 github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs=
 github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
 github.com/boombuler/barcode v1.0.1/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
+github.com/bsm/ginkgo/v2 v2.7.0 h1:ItPMPH90RbmZJt5GtkcNvIRuGEdwlBItdNVoyzaNQao=
 github.com/bsm/ginkgo/v2 v2.7.0/go.mod h1:AiKlXPm7ItEHNc/2+OkrNG4E0ITzojb9/xWzvQ9XZ9w=
+github.com/bsm/gomega v1.26.0 h1:LhQm+AFcgV2M0WyKroMASzAzCAJVpAxQXv4SaI9a69Y=
 github.com/bsm/gomega v1.26.0/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
 github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
 github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=

+ 34 - 8
internal/dao/user.go

@@ -2,6 +2,8 @@ package dao
 
 import (
 	"github.com/go-nunu/nunu-layout-advanced/internal/model"
+	"github.com/pkg/errors"
+	"gorm.io/gorm"
 )
 
 type UserDao struct {
@@ -10,21 +12,45 @@ type UserDao struct {
 
 func NewUserDao(dao *Dao) *UserDao {
 	return &UserDao{
-		Dao: dao,
+		dao,
 	}
 }
+func (d *UserDao) CreateUser(user *model.User) error {
+	if err := d.db.Create(user).Error; err != nil {
+		return errors.Wrap(err, "failed to create user")
+	}
+
+	return nil
+}
 
-func (r *UserDao) FirstById(id int64) (*model.User, error) {
+func (d *UserDao) GetUserById(userId string) (*model.User, error) {
 	var user model.User
-	if err := r.db.Where("id = ?", id).First(&user).Error; err != nil {
-		return nil, err
+	if err := d.db.Where("user_id = ?", userId).First(&user).Error; err != nil {
+		if err == gorm.ErrRecordNotFound {
+			return nil, nil
+		}
+		return nil, errors.Wrap(err, "failed to get user by ID")
 	}
+
 	return &user, nil
 }
 
-func (r *UserDao) CreateUser(user *model.User) (*model.User, error) {
-	if err := r.db.Create(user).Error; err != nil {
-		return nil, err
+func (d *UserDao) GetUserByUsername(username string) (*model.User, error) {
+	var user model.User
+	if err := d.db.Where("username = ?", username).First(&user).Error; err != nil {
+		if err == gorm.ErrRecordNotFound {
+			return nil, nil
+		}
+		return nil, errors.Wrap(err, "failed to get user by username")
 	}
-	return user, nil
+
+	return &user, nil
+}
+
+func (d *UserDao) UpdateUser(user *model.User) error {
+	if err := d.db.Save(user).Error; err != nil {
+		return errors.Wrap(err, "failed to update user")
+	}
+
+	return nil
 }

+ 9 - 0
internal/handler/handler.go

@@ -1,6 +1,8 @@
 package handler
 
 import (
+	"github.com/gin-gonic/gin"
+	"github.com/go-nunu/nunu-layout-advanced/internal/middleware"
 	"github.com/go-nunu/nunu-layout-advanced/pkg/log"
 	"github.com/sony/sonyflake"
 )
@@ -14,3 +16,10 @@ func NewHandler(logger *log.Logger, sf *sonyflake.Sonyflake) *Handler {
 		logger: logger,
 	}
 }
+func GetUserIdFromCtx(ctx *gin.Context) string {
+	v, exists := ctx.Get("claims")
+	if !exists {
+		return ""
+	}
+	return v.(*middleware.MyCustomClaims).UserId
+}

+ 46 - 25
internal/handler/user.go

@@ -2,10 +2,9 @@ package handler
 
 import (
 	"github.com/gin-gonic/gin"
-	"github.com/go-nunu/nunu-layout-advanced/internal/model"
 	"github.com/go-nunu/nunu-layout-advanced/internal/service"
 	"github.com/go-nunu/nunu-layout-advanced/pkg/helper/resp"
-	"go.uber.org/zap"
+	"github.com/pkg/errors"
 	"net/http"
 )
 
@@ -21,46 +20,68 @@ func NewUserHandler(handler *Handler, userService *service.UserService) *UserHan
 	}
 }
 
-func (c *UserHandler) CreateUser(ctx *gin.Context) {
+func (h *UserHandler) Register(ctx *gin.Context) {
+	req := new(service.RegisterRequest)
+	if err := ctx.ShouldBindJSON(req); err != nil {
+		resp.HandleError(ctx, http.StatusBadRequest, 1, errors.Wrap(err, "invalid request").Error(), nil)
+		return
+	}
 
-	var params struct {
-		Username string `json:"username" binding:"required,min=2,max=20"`
-		Email    string `json:"email" binding:"required,email"`
+	if err := h.userService.Register(req); err != nil {
+		resp.HandleError(ctx, http.StatusBadRequest, 1, errors.Wrap(err, "invalid request").Error(), nil)
+		return
 	}
-	if err := ctx.ShouldBind(&params); err != nil {
-		resp.HandleError(ctx, http.StatusBadRequest, 1, err.Error(), nil)
+
+	resp.HandleSuccess(ctx, nil)
+}
+
+func (h *UserHandler) Login(ctx *gin.Context) {
+	var req service.LoginRequest
+	if err := ctx.ShouldBindJSON(&req); err != nil {
+		resp.HandleError(ctx, http.StatusBadRequest, 1, errors.Wrap(err, "invalid request").Error(), nil)
 		return
 	}
 
-	user, err := c.userService.CreateUser(&model.User{
-		Username: params.Username,
-		Email:    params.Email,
-	})
-	c.logger.Info("CreateUser", zap.Any("user", user))
+	token, err := h.userService.Login(&req)
 	if err != nil {
-		resp.HandleError(ctx, http.StatusInternalServerError, 1, err.Error(), nil)
+		resp.HandleError(ctx, http.StatusUnauthorized, 1, err.Error(), nil)
 		return
 	}
-	resp.HandleSuccess(ctx, user)
+
+	resp.HandleSuccess(ctx, gin.H{
+		"accessToken": token,
+	})
 }
-func (c *UserHandler) GetUserById(ctx *gin.Context) {
 
-	var params struct {
-		Id int64 `form:"id" binding:"required"`
-	}
-	if err := ctx.ShouldBind(&params); err != nil {
-		resp.HandleError(ctx, http.StatusBadRequest, 1, err.Error(), nil)
+func (h *UserHandler) GetProfile(ctx *gin.Context) {
+	userId := GetUserIdFromCtx(ctx)
+	if userId == "" {
+		resp.HandleError(ctx, http.StatusUnauthorized, 1, "unauthorized", nil)
 		return
 	}
 
-	user, err := c.userService.GetUserById(params.Id)
-	c.logger.Info("GetUserByID", zap.Any("user", user))
+	user, err := h.userService.GetProfile(userId)
 	if err != nil {
-		resp.HandleError(ctx, http.StatusInternalServerError, 1, err.Error(), nil)
+		resp.HandleError(ctx, http.StatusBadRequest, 1, err.Error(), nil)
 		return
 	}
+
 	resp.HandleSuccess(ctx, user)
 }
-func (c *UserHandler) UpdateUser(ctx *gin.Context) {
+
+func (h *UserHandler) UpdateProfile(ctx *gin.Context) {
+	userId := GetUserIdFromCtx(ctx)
+
+	var req service.UpdateProfileRequest
+	if err := ctx.ShouldBindJSON(&req); err != nil {
+		resp.HandleError(ctx, http.StatusBadRequest, 1, errors.Wrap(err, "invalid request").Error(), nil)
+		return
+	}
+
+	if err := h.userService.UpdateProfile(userId, &req); err != nil {
+		resp.HandleError(ctx, http.StatusBadRequest, 1, err.Error(), nil)
+		return
+	}
+
 	resp.HandleSuccess(ctx, nil)
 }

+ 21 - 11
internal/middleware/jwt.go

@@ -1,7 +1,6 @@
 package middleware
 
 import (
-	"fmt"
 	"github.com/gin-gonic/gin"
 	"github.com/go-nunu/nunu-layout-advanced/pkg/helper/resp"
 	"github.com/go-nunu/nunu-layout-advanced/pkg/log"
@@ -10,29 +9,40 @@ import (
 	"go.uber.org/zap"
 	"net/http"
 	"regexp"
+	"time"
 )
 
 type JWT struct {
-	key string
+	key []byte
 }
 type MyCustomClaims struct {
-	UserId int64
+	UserId string
 	jwt.RegisteredClaims
 }
 
 // NewJwt https://pkg.go.dev/github.com/golang-jwt/jwt/v5
 func NewJwt(conf *viper.Viper) *JWT {
-	return &JWT{key: conf.GetString("security.jwt.key")}
+	return &JWT{key: []byte(conf.GetString("security.jwt.key"))}
 }
-func (j *JWT) GenToken() string {
-	token := jwt.NewWithClaims(jwt.SigningMethodHS512, MyCustomClaims{
-		UserId: 1,
+func (j *JWT) GenToken(userId string, expiresAt time.Time) string {
+	token := jwt.NewWithClaims(jwt.SigningMethodHS256, MyCustomClaims{
+		UserId: userId,
+		RegisteredClaims: jwt.RegisteredClaims{
+			ExpiresAt: jwt.NewNumericDate(expiresAt),
+			IssuedAt:  jwt.NewNumericDate(time.Now()),
+			NotBefore: jwt.NewNumericDate(time.Now()),
+			Issuer:    "",
+			Subject:   "",
+			ID:        "",
+			Audience:  []string{},
+		},
 	})
 
 	// Sign and get the complete encoded token as a string using the key
 	tokenString, err := token.SignedString(j.key)
-
-	fmt.Println(tokenString, err)
+	if err != nil {
+		return ""
+	}
 	return tokenString
 
 }
@@ -40,7 +50,7 @@ func (j *JWT) ParseToken(tokenString string) (*MyCustomClaims, error) {
 	re, _ := regexp.Compile(`(?i)Bearer `)
 	tokenString = re.ReplaceAllString(tokenString, "")
 	token, err := jwt.ParseWithClaims(tokenString, &MyCustomClaims{}, func(token *jwt.Token) (interface{}, error) {
-		return []byte("AllYourBase"), nil
+		return j.key, nil
 	})
 
 	if claims, ok := token.Claims.(*MyCustomClaims); ok && token.Valid {
@@ -113,5 +123,5 @@ func NoStrictAuth(j *JWT, logger *log.Logger) gin.HandlerFunc {
 
 func recoveryLoggerFunc(ctx *gin.Context, logger *log.Logger) {
 	userInfo := ctx.MustGet("claims").(*MyCustomClaims)
-	logger.NewContext(ctx, zap.Int64("UserId", userInfo.UserId))
+	logger.NewContext(ctx, zap.String("UserId", userInfo.UserId))
 }

+ 13 - 4
internal/model/user.go

@@ -1,11 +1,20 @@
 package model
 
-import "gorm.io/gorm"
+import (
+	"gorm.io/gorm"
+	"time"
+)
 
 type User struct {
-	gorm.Model
-	Username string `gorm:"not null"`
-	Email    string `gorm:"unique;not null"`
+	Id        uint   `gorm:"primarykey"`
+	UserId    string `gorm:"unique;not null"`
+	Username  string `gorm:"unique;not null"`
+	Nickname  string `gorm:"not null"`
+	Password  string `gorm:"not null"`
+	Email     string `gorm:"not null"`
+	CreatedAt time.Time
+	UpdatedAt time.Time
+	DeletedAt gorm.DeletedAt `gorm:"index"`
 }
 
 func (u *User) TableName() string {

+ 6 - 3
internal/server/http.go

@@ -25,24 +25,27 @@ func NewServerHTTP(
 	// 无权限路由
 	noAuthRouter := r.Group("/").Use(middleware.RequestLogMiddleware(logger))
 	{
-		noAuthRouter.GET("/user", userHandler.GetUserById)
+
 		noAuthRouter.GET("/", func(ctx *gin.Context) {
 			logger.WithContext(ctx).Info("hello")
 			resp.HandleSuccess(ctx, map[string]interface{}{
 				"say": "Hi Nunu!",
 			})
 		})
+
+		noAuthRouter.POST("/user/register", userHandler.Register)
+		noAuthRouter.POST("/user/login", userHandler.Login)
 	}
 	// 非严格权限路由
 	noStrictAuthRouter := r.Group("/").Use(middleware.NoStrictAuth(jwt, logger), middleware.RequestLogMiddleware(logger))
 	{
-		noStrictAuthRouter.POST("/user", userHandler.CreateUser)
+		noStrictAuthRouter.GET("/user", userHandler.GetProfile)
 	}
 
 	// 严格权限路由
 	strictAuthRouter := r.Group("/").Use(middleware.StrictAuth(jwt, logger), middleware.RequestLogMiddleware(logger))
 	{
-		strictAuthRouter.PUT("/user", userHandler.UpdateUser)
+		strictAuthRouter.PUT("/user", userHandler.UpdateProfile)
 	}
 
 	return r

+ 12 - 4
internal/service/service.go

@@ -1,13 +1,21 @@
 package service
 
-import "github.com/go-nunu/nunu-layout-advanced/pkg/log"
+import (
+	"github.com/go-nunu/nunu-layout-advanced/internal/middleware"
+	"github.com/go-nunu/nunu-layout-advanced/pkg/log"
+	"github.com/sony/sonyflake"
+)
 
 type Service struct {
-	logger *log.Logger
+	logger    *log.Logger
+	sonyflake *sonyflake.Sonyflake
+	jwt       *middleware.JWT
 }
 
-func NewService(logger *log.Logger) *Service {
+func NewService(logger *log.Logger, sonyflake *sonyflake.Sonyflake, jwt *middleware.JWT) *Service {
 	return &Service{
-		logger: logger,
+		logger:    logger,
+		sonyflake: sonyflake,
+		jwt:       jwt,
 	}
 }

+ 126 - 6
internal/service/user.go

@@ -3,23 +3,143 @@ package service
 import (
 	"github.com/go-nunu/nunu-layout-advanced/internal/dao"
 	"github.com/go-nunu/nunu-layout-advanced/internal/model"
+	"github.com/go-nunu/nunu-layout-advanced/pkg/helper/convert"
+	"github.com/golang-jwt/jwt/v5"
+	"github.com/pkg/errors"
+	"golang.org/x/crypto/bcrypt"
+	"time"
 )
 
+type RegisterRequest struct {
+	Username string `json:"username" binding:"required"`
+	Password string `json:"password" binding:"required"`
+	Email    string `json:"email" binding:"required,email"`
+}
+
+type LoginRequest struct {
+	Username string `json:"username" binding:"required"`
+	Password string `json:"password" binding:"required"`
+}
+
+type UpdateProfileRequest struct {
+	Nickname string `json:"nickname"`
+	Email    string `json:"email" binding:"required,email"`
+	Avatar   string `json:"avatar"`
+}
+
+type ChangePasswordRequest struct {
+	OldPassword string `json:"oldPassword" binding:"required"`
+	NewPassword string `json:"newPassword" binding:"required"`
+}
+
 type UserService struct {
-	*Service
 	userDao *dao.UserDao
+	*Service
 }
 
 func NewUserService(service *Service, userDao *dao.UserDao) *UserService {
 	return &UserService{
-		Service: service,
 		userDao: userDao,
+		Service: service,
+	}
+}
+
+func (s *UserService) Register(req *RegisterRequest) error {
+	// 生成用户ID
+	userId, err := s.generateUserId()
+	if err != nil {
+		return errors.Wrap(err, "failed to generate user ID")
+	}
+
+	// 检查用户名是否已存在
+	if user, err := s.userDao.GetUserByUsername(req.Username); err == nil && user != nil {
+		return errors.New("username already exists")
+	}
+
+	hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost)
+	if err != nil {
+		return errors.Wrap(err, "failed to hash password")
+	}
+
+	// 创建用户
+	user := &model.User{
+		UserId:   userId,
+		Username: req.Username,
+		Password: string(hashedPassword),
+		Email:    req.Email,
+	}
+	if err = s.userDao.CreateUser(user); err != nil {
+		return errors.Wrap(err, "failed to create user")
+	}
+
+	return nil
+}
+
+func (s *UserService) Login(req *LoginRequest) (string, error) {
+	user, err := s.userDao.GetUserByUsername(req.Username)
+	if err != nil || user == nil {
+		return "", errors.Wrap(err, "failed to get user by username")
+	}
+
+	err = bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(req.Password))
+	if err != nil {
+		return "", errors.Wrap(err, "failed to hash password")
+	}
+	// 生成JWT token
+	token, err := s.generateToken(user.UserId)
+	if err != nil {
+		return "", errors.Wrap(err, "failed to generate JWT token")
 	}
+
+	return token, nil
 }
 
-func (s *UserService) GetUserById(id int64) (*model.User, error) {
-	return s.userDao.FirstById(id)
+func (s *UserService) GetProfile(userId string) (*model.User, error) {
+	user, err := s.userDao.GetUserById(userId)
+	if err != nil {
+		return nil, errors.Wrap(err, "failed to get user by ID")
+	}
+
+	return user, nil
 }
-func (s *UserService) CreateUser(user *model.User) (*model.User, error) {
-	return s.userDao.CreateUser(user)
+
+func (s *UserService) UpdateProfile(userId string, req *UpdateProfileRequest) error {
+	user, err := s.userDao.GetUserById(userId)
+	if err != nil {
+		return errors.Wrap(err, "failed to get user by ID")
+	}
+
+	user.Email = req.Email
+	user.Nickname = req.Nickname
+
+	if err = s.userDao.UpdateUser(user); err != nil {
+		return errors.Wrap(err, "failed to update user")
+	}
+
+	return nil
+}
+
+func (s *UserService) generateUserId() (string, error) {
+	// 生成分布式ID
+	id, err := s.sonyflake.NextID()
+	if err != nil {
+		return "", errors.Wrap(err, "failed to generate sonyflake ID")
+	}
+
+	// 将ID转换为字符串
+	return convert.IntToBase62(int(id)), nil
+}
+
+func (s *UserService) generateToken(userId string) (string, error) {
+	// 生成JWT token
+	s.jwt.GenToken(userId, time.Now().Add(time.Hour*24*90))
+	token, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
+		"userId": userId,
+		"exp":    time.Now().Add(time.Hour * 24).Unix(),
+	}).SignedString([]byte("secret"))
+	if err != nil {
+		return "", errors.Wrap(err, "failed to generate JWT token")
+	}
+
+	return token, nil
 }

+ 24 - 0
pkg/helper/convert/convert.go

@@ -0,0 +1,24 @@
+package convert
+
+const (
+	base62 = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
+)
+
+func IntToBase62(n int) string {
+	if n == 0 {
+		return string(base62[0])
+	}
+
+	var result []byte
+	for n > 0 {
+		result = append(result, base62[n%62])
+		n /= 62
+	}
+
+	// 反转字符串
+	for i, j := 0, len(result)-1; i < j; i, j = i+1, j-1 {
+		result[i], result[j] = result[j], result[i]
+	}
+
+	return string(result)
+}

+ 7 - 6
test/server/handler/user_test.go

@@ -17,7 +17,7 @@ import (
 )
 
 var headers = map[string]string{
-	"Authorization": "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VySW5mbyI6eyJ1c2VyU2lkIjoiOHpsdGxQRzhXSCIsIm5pY2tuYW1lIjoi55CD55CDIiwidXNlcklkIjowfSwiZXhwIjoxNjg3NzcwMzYzLCJqdGkiOiI4emx0bFBHOFdIIiwiaXNzIjoiaHR0cHM6Ly90ZWh1Yi5jb20vYXBpIiwibmJmIjoxNjcyMjE3NzYzLCJzdWIiOiI4emx0bFBHOFdIIn0.G0sSUzj3GBANqj6dU7rSMsr44SARgYwH1ERwKUCaxsM",
+	"Authorization": "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJVc2VySWQiOiJ5aHM2SGVzZmdGIiwiZXhwIjoxNjkzOTE0ODgwLCJuYmYiOjE2ODYxMzg4ODAsImlhdCI6MTY4NjEzODg4MH0.NnFrZFgc_333a9PXqaoongmIDksNvQoHzgM_IhJM4MQ",
 }
 
 func TestMain(m *testing.M) {
@@ -65,9 +65,9 @@ func NewRequest(method, path string, header map[string]string, body io.Reader) (
 	return response, nil
 }
 
-func TestGetUserById(t *testing.T) {
+func TestGetProfile(t *testing.T) {
 	response, err := NewRequest("GET",
-		fmt.Sprintf("/user?id=%s", "-1"),
+		fmt.Sprintf("/user"),
 		headers,
 		nil,
 	)
@@ -76,13 +76,14 @@ func TestGetUserById(t *testing.T) {
 	assert.Nil(t, err)
 	assert.Equal(t, 1, response.Code)
 }
-func TestCreateUser(t *testing.T) {
+func TestUpdateProfile(t *testing.T) {
 	params, err := json.Marshal(map[string]interface{}{
 		"email":    "5303221@gmail.com",
-		"username": "test",
+		"username": "user1",
+		"nickname": "8888",
 	})
 	assert.Nil(t, err)
-	response, err := NewRequest("POST",
+	response, err := NewRequest("PUT",
 		"/user",
 		headers,
 		bytes.NewBuffer(params),

+ 21 - 13
test/server/service/user_test.go

@@ -3,12 +3,12 @@ package service
 import (
 	"fmt"
 	"github.com/go-nunu/nunu-layout-advanced/internal/dao"
-	"github.com/go-nunu/nunu-layout-advanced/internal/model"
+	"github.com/go-nunu/nunu-layout-advanced/internal/middleware"
 	"github.com/go-nunu/nunu-layout-advanced/internal/service"
 	"github.com/go-nunu/nunu-layout-advanced/pkg/config"
+	"github.com/go-nunu/nunu-layout-advanced/pkg/helper/sonyflake"
 	"github.com/go-nunu/nunu-layout-advanced/pkg/log"
 	"github.com/stretchr/testify/assert"
-	"gorm.io/gorm"
 	"os"
 	"testing"
 )
@@ -25,8 +25,9 @@ func TestMain(m *testing.M) {
 	logger := log.NewLog(conf)
 	db := dao.NewDB(conf)
 	rdb := dao.NewRedis(conf)
-
-	srv := service.NewService(logger)
+	jwt := middleware.NewJwt(conf)
+	sf := sonyflake.NewSonyflake()
+	srv := service.NewService(logger, sf, jwt)
 	repo := dao.NewDao(db, rdb, logger)
 	userDao := dao.NewUserDao(repo)
 	userService = service.NewUserService(srv, userDao)
@@ -37,15 +38,22 @@ func TestMain(m *testing.M) {
 	os.Exit(code)
 
 }
-func TestGetUserByEmail(t *testing.T) {
-	_, err := userService.GetUserById(0)
-	assert.Equal(t, err, gorm.ErrRecordNotFound, "they should be equal")
+func TestRegister(t *testing.T) {
+	req := service.RegisterRequest{
+		Username: "user1",
+		Password: "123456",
+		Email:    "user1@mail.com",
+	}
+	err := userService.Register(&req)
+	assert.Equal(t, err, nil, "they should be equal")
 }
 
-func TestCreateUser(t *testing.T) {
-	_, err := userService.CreateUser(&model.User{
-		Username: "test",
-		Email:    "nunu@mail.com",
-	})
-	assert.NotEqual(t, err, nil, "they should be equal")
+func TestLogin(t *testing.T) {
+	req := service.LoginRequest{
+		Username: "user1",
+		Password: "123456",
+	}
+	token, err := userService.Login(&req)
+	assert.Equal(t, err, nil, "they should be equal")
+	t.Log("token", token)
 }