Explorar el Código

feat: Added database transactions and removed handler interface definitions

chris hace 1 año
padre
commit
55b0cbdc6a

+ 1 - 0
Makefile

@@ -14,6 +14,7 @@ bootstrap:
 mock:
 	mockgen -source=internal/service/user.go -destination test/mocks/service/user.go
 	mockgen -source=internal/repository/user.go -destination test/mocks/repository/user.go
+	mockgen -source=internal/repository/repository.go -destination test/mocks/repository/repository.go
 
 .PHONY: test
 test:

+ 1 - 1
api/v1/errors.go

@@ -9,5 +9,5 @@ var (
 	ErrInternalServerError = newError(500, "Internal Server Error")
 
 	// more biz errors
-	ErrUsernameAlreadyUse = newError(1001, "The username is already in use.")
+	ErrEmailAlreadyUse = newError(1001, "The email is already in use.")
 )

+ 3 - 6
api/v1/user.go

@@ -1,13 +1,12 @@
 package v1
 
 type RegisterRequest struct {
-	Username string `json:"username" binding:"required" example:"alan"`
-	Password string `json:"password" binding:"required" example:"123456"`
 	Email    string `json:"email" binding:"required,email" example:"1234@gmail.com"`
+	Password string `json:"password" binding:"required" example:"123456"`
 }
 
 type LoginRequest struct {
-	Username string `json:"username" binding:"required" example:"alan"`
+	Email    string `json:"email" binding:"required,email" example:"1234@gmail.com"`
 	Password string `json:"password" binding:"required" example:"123456"`
 }
 type LoginResponseData struct {
@@ -21,12 +20,10 @@ type LoginResponse struct {
 type UpdateProfileRequest struct {
 	Nickname string `json:"nickname" example:"alan"`
 	Email    string `json:"email" binding:"required,email" example:"1234@gmail.com"`
-	Avatar   string `json:"avatar" example:"xxxx"`
 }
 type GetProfileResponseData struct {
 	UserId   string `json:"userId"`
-	Nickname string `json:"nickname"`
-	Username string `json:"username"`
+	Nickname string `json:"nickname" example:"alan"`
 }
 type GetProfileResponse struct {
 	Response

+ 10 - 8
cmd/server/wire/wire.go

@@ -17,9 +17,12 @@ import (
 	"github.com/spf13/viper"
 )
 
-var handlerSet = wire.NewSet(
-	handler.NewHandler,
-	handler.NewUserHandler,
+var repositorySet = wire.NewSet(
+	repository.NewDB,
+	repository.NewRedis,
+	repository.NewRepository,
+	repository.NewTransaction,
+	repository.NewUserRepository,
 )
 
 var serviceSet = wire.NewSet(
@@ -27,12 +30,11 @@ var serviceSet = wire.NewSet(
 	service.NewUserService,
 )
 
-var repositorySet = wire.NewSet(
-	repository.NewDB,
-	repository.NewRedis,
-	repository.NewRepository,
-	repository.NewUserRepository,
+var handlerSet = wire.NewSet(
+	handler.NewHandler,
+	handler.NewUserHandler,
 )
+
 var serverSet = wire.NewSet(
 	server.NewHTTPServer,
 	server.NewJob,

+ 5 - 4
cmd/server/wire/wire_gen.go

@@ -25,11 +25,12 @@ import (
 func NewWire(viperViper *viper.Viper, logger *log.Logger) (*app.App, func(), error) {
 	jwtJWT := jwt.NewJwt(viperViper)
 	handlerHandler := handler.NewHandler(logger)
-	sidSid := sid.NewSid()
-	serviceService := service.NewService(logger, sidSid, jwtJWT)
 	db := repository.NewDB(viperViper, logger)
 	client := repository.NewRedis(viperViper)
 	repositoryRepository := repository.NewRepository(db, client, logger)
+	transaction := repository.NewTransaction(repositoryRepository)
+	sidSid := sid.NewSid()
+	serviceService := service.NewService(transaction, logger, sidSid, jwtJWT)
 	userRepository := repository.NewUserRepository(repositoryRepository)
 	userService := service.NewUserService(serviceService, userRepository)
 	userHandler := handler.NewUserHandler(handlerHandler, userService)
@@ -42,11 +43,11 @@ func NewWire(viperViper *viper.Viper, logger *log.Logger) (*app.App, func(), err
 
 // wire.go:
 
-var handlerSet = wire.NewSet(handler.NewHandler, handler.NewUserHandler)
+var repositorySet = wire.NewSet(repository.NewDB, repository.NewRedis, repository.NewRepository, repository.NewTransaction, repository.NewUserRepository)
 
 var serviceSet = wire.NewSet(service.NewService, service.NewUserService)
 
-var repositorySet = wire.NewSet(repository.NewDB, repository.NewRedis, repository.NewRepository, repository.NewUserRepository)
+var handlerSet = wire.NewSet(handler.NewHandler, handler.NewUserHandler)
 
 var serverSet = wire.NewSet(server.NewHTTPServer, server.NewJob, server.NewTask)
 

+ 9 - 13
docs/docs.go

@@ -138,7 +138,8 @@ const docTemplate = `{
             "type": "object",
             "properties": {
                 "nickname": {
-                    "type": "string"
+                    "type": "string",
+                    "example": "alan"
                 },
                 "userId": {
                     "type": "string"
@@ -148,17 +149,17 @@ const docTemplate = `{
         "github_com_go-nunu_nunu-layout-advanced_api_v1.LoginRequest": {
             "type": "object",
             "required": [
-                "password",
-                "username"
+                "email",
+                "password"
             ],
             "properties": {
-                "password": {
+                "email": {
                     "type": "string",
-                    "example": "123456"
+                    "example": "1234@gmail.com"
                 },
-                "username": {
+                "password": {
                     "type": "string",
-                    "example": "alan"
+                    "example": "123456"
                 }
             }
         },
@@ -188,8 +189,7 @@ const docTemplate = `{
             "type": "object",
             "required": [
                 "email",
-                "password",
-                "username"
+                "password"
             ],
             "properties": {
                 "email": {
@@ -199,10 +199,6 @@ const docTemplate = `{
                 "password": {
                     "type": "string",
                     "example": "123456"
-                },
-                "username": {
-                    "type": "string",
-                    "example": "alan"
                 }
             }
         },

+ 9 - 13
docs/swagger.json

@@ -131,7 +131,8 @@
             "type": "object",
             "properties": {
                 "nickname": {
-                    "type": "string"
+                    "type": "string",
+                    "example": "alan"
                 },
                 "userId": {
                     "type": "string"
@@ -141,17 +142,17 @@
         "github_com_go-nunu_nunu-layout-advanced_api_v1.LoginRequest": {
             "type": "object",
             "required": [
-                "password",
-                "username"
+                "email",
+                "password"
             ],
             "properties": {
-                "password": {
+                "email": {
                     "type": "string",
-                    "example": "123456"
+                    "example": "1234@gmail.com"
                 },
-                "username": {
+                "password": {
                     "type": "string",
-                    "example": "alan"
+                    "example": "123456"
                 }
             }
         },
@@ -181,8 +182,7 @@
             "type": "object",
             "required": [
                 "email",
-                "password",
-                "username"
+                "password"
             ],
             "properties": {
                 "email": {
@@ -192,10 +192,6 @@
                 "password": {
                     "type": "string",
                     "example": "123456"
-                },
-                "username": {
-                    "type": "string",
-                    "example": "alan"
                 }
             }
         },

+ 5 - 8
docs/swagger.yaml

@@ -11,21 +11,22 @@ definitions:
   github_com_go-nunu_nunu-layout-advanced_api_v1.GetProfileResponseData:
     properties:
       nickname:
+        example: alan
         type: string
       userId:
         type: string
     type: object
   github_com_go-nunu_nunu-layout-advanced_api_v1.LoginRequest:
     properties:
+      email:
+        example: 1234@gmail.com
+        type: string
       password:
         example: "123456"
         type: string
-      username:
-        example: alan
-        type: string
     required:
+    - email
     - password
-    - username
     type: object
   github_com_go-nunu_nunu-layout-advanced_api_v1.LoginResponse:
     properties:
@@ -49,13 +50,9 @@ definitions:
       password:
         example: "123456"
         type: string
-      username:
-        example: alan
-        type: string
     required:
     - email
     - password
-    - username
     type: object
   github_com_go-nunu_nunu-layout-advanced_api_v1.Response:
     properties:

+ 10 - 17
internal/handler/user.go

@@ -2,31 +2,24 @@ package handler
 
 import (
 	"github.com/gin-gonic/gin"
-	v1 "github.com/go-nunu/nunu-layout-advanced/api/v1"
+	"github.com/go-nunu/nunu-layout-advanced/api/v1"
 	"github.com/go-nunu/nunu-layout-advanced/internal/service"
 	"go.uber.org/zap"
 	"net/http"
 )
 
-type UserHandler interface {
-	Register(ctx *gin.Context)
-	Login(ctx *gin.Context)
-	GetProfile(ctx *gin.Context)
-	UpdateProfile(ctx *gin.Context)
+type UserHandler struct {
+	*Handler
+	userService service.UserService
 }
 
-func NewUserHandler(handler *Handler, userService service.UserService) UserHandler {
-	return &userHandler{
+func NewUserHandler(handler *Handler, userService service.UserService) *UserHandler {
+	return &UserHandler{
 		Handler:     handler,
 		userService: userService,
 	}
 }
 
-type userHandler struct {
-	*Handler
-	userService service.UserService
-}
-
 // Register godoc
 // @Summary 用户注册
 // @Schemes
@@ -37,7 +30,7 @@ type userHandler struct {
 // @Param request body v1.RegisterRequest true "params"
 // @Success 200 {object} v1.Response
 // @Router /register [post]
-func (h *userHandler) Register(ctx *gin.Context) {
+func (h *UserHandler) Register(ctx *gin.Context) {
 	req := new(v1.RegisterRequest)
 	if err := ctx.ShouldBindJSON(req); err != nil {
 		v1.HandleError(ctx, http.StatusBadRequest, v1.ErrBadRequest, nil)
@@ -63,7 +56,7 @@ func (h *userHandler) Register(ctx *gin.Context) {
 // @Param request body v1.LoginRequest true "params"
 // @Success 200 {object} v1.LoginResponse
 // @Router /login [post]
-func (h *userHandler) Login(ctx *gin.Context) {
+func (h *UserHandler) Login(ctx *gin.Context) {
 	var req v1.LoginRequest
 	if err := ctx.ShouldBindJSON(&req); err != nil {
 		v1.HandleError(ctx, http.StatusBadRequest, v1.ErrBadRequest, nil)
@@ -90,7 +83,7 @@ func (h *userHandler) Login(ctx *gin.Context) {
 // @Security Bearer
 // @Success 200 {object} v1.GetProfileResponse
 // @Router /user [get]
-func (h *userHandler) GetProfile(ctx *gin.Context) {
+func (h *UserHandler) GetProfile(ctx *gin.Context) {
 	userId := GetUserIdFromCtx(ctx)
 	if userId == "" {
 		v1.HandleError(ctx, http.StatusUnauthorized, v1.ErrUnauthorized, nil)
@@ -106,7 +99,7 @@ func (h *userHandler) GetProfile(ctx *gin.Context) {
 	v1.HandleSuccess(ctx, user)
 }
 
-func (h *userHandler) UpdateProfile(ctx *gin.Context) {
+func (h *UserHandler) UpdateProfile(ctx *gin.Context) {
 	userId := GetUserIdFromCtx(ctx)
 
 	var req v1.UpdateProfileRequest

+ 3 - 2
internal/middleware/jwt.go

@@ -66,6 +66,7 @@ func NoStrictAuth(j *jwt.JWT, logger *log.Logger) gin.HandlerFunc {
 }
 
 func recoveryLoggerFunc(ctx *gin.Context, logger *log.Logger) {
-	userInfo := ctx.MustGet("claims").(*jwt.MyCustomClaims)
-	logger.NewContext(ctx, zap.String("UserId", userInfo.UserId))
+	if userInfo, ok := ctx.MustGet("claims").(*jwt.MyCustomClaims); ok {
+		logger.WithValue(ctx, zap.String("UserId", userInfo.UserId))
+	}
 }

+ 5 - 5
internal/middleware/log.go

@@ -15,14 +15,14 @@ func RequestLogMiddleware(logger *log.Logger) gin.HandlerFunc {
 	return func(ctx *gin.Context) {
 		// The configuration is initialized once per request
 		trace := md5.Md5(uuid.GenUUID())
-		logger.NewContext(ctx, zap.String("trace", trace))
-		logger.NewContext(ctx, zap.String("request_method", ctx.Request.Method))
-		logger.NewContext(ctx, zap.Any("request_headers", ctx.Request.Header))
-		logger.NewContext(ctx, zap.String("request_url", ctx.Request.URL.String()))
+		logger.WithValue(ctx, zap.String("trace", trace))
+		logger.WithValue(ctx, zap.String("request_method", ctx.Request.Method))
+		logger.WithValue(ctx, zap.Any("request_headers", ctx.Request.Header))
+		logger.WithValue(ctx, zap.String("request_url", ctx.Request.URL.String()))
 		if ctx.Request.Body != nil {
 			bodyBytes, _ := ctx.GetRawData()
 			ctx.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) // 关键点
-			logger.NewContext(ctx, zap.String("request_params", string(bodyBytes)))
+			logger.WithValue(ctx, zap.String("request_params", string(bodyBytes)))
 		}
 		logger.WithContext(ctx).Info("Request")
 		ctx.Next()

+ 0 - 1
internal/model/user.go

@@ -8,7 +8,6 @@ import (
 type User struct {
 	Id        uint   `gorm:"primarykey"`
 	UserId    string `gorm:"unique;not null"`
-	Username  string `gorm:"unique;not null"`
 	Nickname  string `gorm:"not null"`
 	Password  string `gorm:"not null"`
 	Email     string `gorm:"not null"`

+ 29 - 0
internal/repository/repository.go

@@ -12,6 +12,8 @@ import (
 	"time"
 )
 
+const ctxTxKey = "TxKey"
+
 type Repository struct {
 	db     *gorm.DB
 	rdb    *redis.Client
@@ -26,6 +28,33 @@ func NewRepository(db *gorm.DB, rdb *redis.Client, logger *log.Logger) *Reposito
 	}
 }
 
+type Transaction interface {
+	Transaction(ctx context.Context, fn func(ctx context.Context) error) error
+}
+
+func NewTransaction(r *Repository) Transaction {
+	return r
+}
+
+// DB return tx
+// If you need to create a Transaction, you must call DB(ctx) and Transaction(ctx,fn)
+func (r *Repository) DB(ctx context.Context) *gorm.DB {
+	v := ctx.Value(ctxTxKey)
+	if v != nil {
+		if tx, ok := v.(*gorm.DB); ok {
+			return tx
+		}
+	}
+	return r.db.WithContext(ctx)
+}
+
+func (r *Repository) Transaction(ctx context.Context, fn func(ctx context.Context) error) error {
+	return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
+		ctx = context.WithValue(ctx, ctxTxKey, tx)
+		return fn(ctx)
+	})
+}
+
 func NewDB(conf *viper.Viper, l *log.Logger) *gorm.DB {
 	logger := zapgorm2.New(l.Logger)
 	logger.SetAsDefault()

+ 6 - 9
internal/repository/user.go

@@ -12,7 +12,7 @@ type UserRepository interface {
 	Create(ctx context.Context, user *model.User) error
 	Update(ctx context.Context, user *model.User) error
 	GetByID(ctx context.Context, id string) (*model.User, error)
-	GetByUsername(ctx context.Context, username string) (*model.User, error)
+	GetByEmail(ctx context.Context, email string) (*model.User, error)
 }
 
 func NewUserRepository(r *Repository) UserRepository {
@@ -26,40 +26,37 @@ type userRepository struct {
 }
 
 func (r *userRepository) Create(ctx context.Context, user *model.User) error {
-	if err := r.db.Create(user).Error; err != nil {
+	if err := r.DB(ctx).Create(user).Error; err != nil {
 		return err
 	}
 	return nil
 }
 
 func (r *userRepository) Update(ctx context.Context, user *model.User) error {
-	if err := r.db.Save(user).Error; err != nil {
+	if err := r.DB(ctx).Save(user).Error; err != nil {
 		return err
 	}
-
 	return nil
 }
 
 func (r *userRepository) GetByID(ctx context.Context, userId string) (*model.User, error) {
 	var user model.User
-	if err := r.db.Where("user_id = ?", userId).First(&user).Error; err != nil {
+	if err := r.DB(ctx).Where("user_id = ?", userId).First(&user).Error; err != nil {
 		if errors.Is(err, gorm.ErrRecordNotFound) {
 			return nil, v1.ErrNotFound
 		}
 		return nil, err
 	}
-
 	return &user, nil
 }
 
-func (r *userRepository) GetByUsername(ctx context.Context, username string) (*model.User, error) {
+func (r *userRepository) GetByEmail(ctx context.Context, email string) (*model.User, error) {
 	var user model.User
-	if err := r.db.Where("username = ?", username).First(&user).Error; err != nil {
+	if err := r.DB(ctx).Where("email = ?", email).First(&user).Error; err != nil {
 		if errors.Is(err, gorm.ErrRecordNotFound) {
 			return nil, nil
 		}
 		return nil, err
 	}
-
 	return &user, nil
 }

+ 1 - 1
internal/server/http.go

@@ -18,7 +18,7 @@ func NewHTTPServer(
 	logger *log.Logger,
 	conf *viper.Viper,
 	jwt *jwt.JWT,
-	userHandler handler.UserHandler,
+	userHandler *handler.UserHandler,
 ) *http.Server {
 	gin.SetMode(gin.DebugMode)
 	s := http.NewServer(

+ 4 - 1
internal/service/service.go

@@ -1,6 +1,7 @@
 package service
 
 import (
+	"github.com/go-nunu/nunu-layout-advanced/internal/repository"
 	"github.com/go-nunu/nunu-layout-advanced/pkg/helper/sid"
 	"github.com/go-nunu/nunu-layout-advanced/pkg/jwt"
 	"github.com/go-nunu/nunu-layout-advanced/pkg/log"
@@ -10,12 +11,14 @@ type Service struct {
 	logger *log.Logger
 	sid    *sid.Sid
 	jwt    *jwt.JWT
+	tm     repository.Transaction
 }
 
-func NewService(logger *log.Logger, sid *sid.Sid, jwt *jwt.JWT) *Service {
+func NewService(tm repository.Transaction, logger *log.Logger, sid *sid.Sid, jwt *jwt.JWT) *Service {
 	return &Service{
 		logger: logger,
 		sid:    sid,
 		jwt:    jwt,
+		tm:     tm,
 	}
 }

+ 14 - 13
internal/service/user.go

@@ -30,8 +30,8 @@ type userService struct {
 
 func (s *userService) Register(ctx context.Context, req *v1.RegisterRequest) error {
 	// check username
-	if user, err := s.userRepo.GetByUsername(ctx, req.Username); err == nil && user != nil {
-		return v1.ErrUsernameAlreadyUse
+	if user, err := s.userRepo.GetByEmail(ctx, req.Email); err == nil && user != nil {
+		return v1.ErrEmailAlreadyUse
 	}
 
 	hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost)
@@ -43,23 +43,25 @@ func (s *userService) Register(ctx context.Context, req *v1.RegisterRequest) err
 	if err != nil {
 		return err
 	}
-	// Create a user
 	user := &model.User{
 		UserId:   userId,
-		Username: req.Username,
-		Nickname: req.Username,
-		Password: string(hashedPassword),
 		Email:    req.Email,
+		Password: string(hashedPassword),
 	}
-	if err = s.userRepo.Create(ctx, user); err != nil {
-		return err
-	}
-
-	return nil
+	// Transaction demo
+	err = s.tm.Transaction(ctx, func(ctx context.Context) error {
+		// Create a user
+		if err = s.userRepo.Create(ctx, user); err != nil {
+			return err
+		}
+		// TODO: other repo
+		return nil
+	})
+	return err
 }
 
 func (s *userService) Login(ctx context.Context, req *v1.LoginRequest) (string, error) {
-	user, err := s.userRepo.GetByUsername(ctx, req.Username)
+	user, err := s.userRepo.GetByEmail(ctx, req.Email)
 	if err != nil || user == nil {
 		return "", v1.ErrUnauthorized
 	}
@@ -85,7 +87,6 @@ func (s *userService) GetProfile(ctx context.Context, userId string) (*v1.GetPro
 	return &v1.GetProfileResponseData{
 		UserId:   user.UserId,
 		Nickname: user.Nickname,
-		Username: user.Username,
 	}, nil
 }
 

+ 14 - 13
pkg/log/log.go

@@ -1,6 +1,7 @@
 package log
 
 import (
+	"context"
 	"github.com/gin-gonic/gin"
 	"github.com/spf13/viper"
 	"go.uber.org/zap"
@@ -10,17 +11,13 @@ import (
 	"time"
 )
 
-const LOGGER_KEY = "zapLogger"
+const ctxLoggerKey = "zapLogger"
 
 type Logger struct {
 	*zap.Logger
 }
 
 func NewLog(conf *viper.Viper) *Logger {
-	return initZap(conf)
-}
-
-func initZap(conf *viper.Viper) *Logger {
 	// log address "out.log" User-defined
 	lp := conf.GetString("log.log_file_name")
 	lv := conf.GetString("log.log_level")
@@ -86,7 +83,6 @@ func initZap(conf *viper.Viper) *Logger {
 		return &Logger{zap.New(core, zap.Development(), zap.AddCaller(), zap.AddStacktrace(zap.ErrorLevel))}
 	}
 	return &Logger{zap.New(core, zap.AddCaller(), zap.AddStacktrace(zap.ErrorLevel))}
-
 }
 
 func timeEncoder(t time.Time, enc zapcore.PrimitiveArrayEncoder) {
@@ -94,17 +90,22 @@ func timeEncoder(t time.Time, enc zapcore.PrimitiveArrayEncoder) {
 	enc.AppendString(t.Format("2006-01-02 15:04:05.000000000"))
 }
 
-// NewContext Adds a field to the specified context
-func (l *Logger) NewContext(ctx *gin.Context, fields ...zapcore.Field) {
-	ctx.Set(LOGGER_KEY, l.WithContext(ctx).With(fields...))
+// WithValue Adds a field to the specified context
+func (l *Logger) WithValue(ctx context.Context, fields ...zapcore.Field) context.Context {
+	if c, ok := ctx.(*gin.Context); ok {
+		ctx = c.Request.Context()
+		c.Request = c.Request.WithContext(context.WithValue(ctx, ctxLoggerKey, l.WithContext(ctx).With(fields...)))
+		return c
+	}
+	return context.WithValue(ctx, ctxLoggerKey, l.WithContext(ctx).With(fields...))
 }
 
 // WithContext Returns a zap instance from the specified context
-func (l *Logger) WithContext(ctx *gin.Context) *Logger {
-	if ctx == nil {
-		return l
+func (l *Logger) WithContext(ctx context.Context) *Logger {
+	if c, ok := ctx.(*gin.Context); ok {
+		ctx = c.Request.Context()
 	}
-	zl, _ := ctx.Get(LOGGER_KEY)
+	zl := ctx.Value(ctxLoggerKey)
 	ctxLogger, ok := zl.(*zap.Logger)
 	if ok {
 		return &Logger{ctxLogger}

+ 49 - 0
test/mocks/repository/repository.go

@@ -0,0 +1,49 @@
+// Code generated by MockGen. DO NOT EDIT.
+// Source: internal/repository/repository.go
+
+// Package mock_repository is a generated GoMock package.
+package mock_repository
+
+import (
+	context "context"
+	reflect "reflect"
+
+	gomock "github.com/golang/mock/gomock"
+)
+
+// MockTransaction is a mock of Transaction interface.
+type MockTransaction struct {
+	ctrl     *gomock.Controller
+	recorder *MockTransactionMockRecorder
+}
+
+// MockTransactionMockRecorder is the mock recorder for MockTransaction.
+type MockTransactionMockRecorder struct {
+	mock *MockTransaction
+}
+
+// NewMockTransaction creates a new mock instance.
+func NewMockTransaction(ctrl *gomock.Controller) *MockTransaction {
+	mock := &MockTransaction{ctrl: ctrl}
+	mock.recorder = &MockTransactionMockRecorder{mock}
+	return mock
+}
+
+// EXPECT returns an object that allows the caller to indicate expected use.
+func (m *MockTransaction) EXPECT() *MockTransactionMockRecorder {
+	return m.recorder
+}
+
+// Transaction mocks base method.
+func (m *MockTransaction) Transaction(ctx context.Context, fn func(context.Context) error) error {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "Transaction", ctx, fn)
+	ret0, _ := ret[0].(error)
+	return ret0
+}
+
+// Transaction indicates an expected call of Transaction.
+func (mr *MockTransactionMockRecorder) Transaction(ctx, fn interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Transaction", reflect.TypeOf((*MockTransaction)(nil).Transaction), ctx, fn)
+}

+ 12 - 12
test/mocks/repository/user.go

@@ -49,34 +49,34 @@ func (mr *MockUserRepositoryMockRecorder) Create(ctx, user interface{}) *gomock.
 	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockUserRepository)(nil).Create), ctx, user)
 }
 
-// GetByID mocks base method.
-func (m *MockUserRepository) GetByID(ctx context.Context, id string) (*model.User, error) {
+// GetByEmail mocks base method.
+func (m *MockUserRepository) GetByEmail(ctx context.Context, email string) (*model.User, error) {
 	m.ctrl.T.Helper()
-	ret := m.ctrl.Call(m, "GetByID", ctx, id)
+	ret := m.ctrl.Call(m, "GetByEmail", ctx, email)
 	ret0, _ := ret[0].(*model.User)
 	ret1, _ := ret[1].(error)
 	return ret0, ret1
 }
 
-// GetByID indicates an expected call of GetByID.
-func (mr *MockUserRepositoryMockRecorder) GetByID(ctx, id interface{}) *gomock.Call {
+// GetByEmail indicates an expected call of GetByEmail.
+func (mr *MockUserRepositoryMockRecorder) GetByEmail(ctx, email interface{}) *gomock.Call {
 	mr.mock.ctrl.T.Helper()
-	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetByID", reflect.TypeOf((*MockUserRepository)(nil).GetByID), ctx, id)
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetByEmail", reflect.TypeOf((*MockUserRepository)(nil).GetByEmail), ctx, email)
 }
 
-// GetByUsername mocks base method.
-func (m *MockUserRepository) GetByUsername(ctx context.Context, username string) (*model.User, error) {
+// GetByID mocks base method.
+func (m *MockUserRepository) GetByID(ctx context.Context, id string) (*model.User, error) {
 	m.ctrl.T.Helper()
-	ret := m.ctrl.Call(m, "GetByUsername", ctx, username)
+	ret := m.ctrl.Call(m, "GetByID", ctx, id)
 	ret0, _ := ret[0].(*model.User)
 	ret1, _ := ret[1].(error)
 	return ret0, ret1
 }
 
-// GetByUsername indicates an expected call of GetByUsername.
-func (mr *MockUserRepositoryMockRecorder) GetByUsername(ctx, username interface{}) *gomock.Call {
+// GetByID indicates an expected call of GetByID.
+func (mr *MockUserRepositoryMockRecorder) GetByID(ctx, id interface{}) *gomock.Call {
 	mr.mock.ctrl.T.Helper()
-	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetByUsername", reflect.TypeOf((*MockUserRepository)(nil).GetByUsername), ctx, username)
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetByID", reflect.TypeOf((*MockUserRepository)(nil).GetByID), ctx, id)
 }
 
 // Update mocks base method.

+ 1 - 4
test/server/handler/user_test.go

@@ -66,7 +66,6 @@ func TestUserHandler_Register(t *testing.T) {
 	defer ctrl.Finish()
 
 	params := v1.RegisterRequest{
-		Username: "xxx",
 		Password: "123456",
 		Email:    "xxx@gmail.com",
 	}
@@ -90,7 +89,7 @@ func TestUserHandler_Login(t *testing.T) {
 	defer ctrl.Finish()
 
 	params := v1.LoginRequest{
-		Username: "xxx",
+		Email:    "xxx@gmail.com",
 		Password: "123456",
 	}
 
@@ -114,7 +113,6 @@ func TestUserHandler_GetProfile(t *testing.T) {
 	mockUserService := mock_service.NewMockUserService(ctrl)
 	mockUserService.EXPECT().GetProfile(gomock.Any(), userId).Return(&v1.GetProfileResponseData{
 		UserId:   userId,
-		Username: "xxxxx",
 		Nickname: "xxxxx",
 	}, nil)
 
@@ -138,7 +136,6 @@ func TestUserHandler_UpdateProfile(t *testing.T) {
 	params := v1.UpdateProfileRequest{
 		Nickname: "alan",
 		Email:    "alan@gmail.com",
-		Avatar:   "xxx",
 	}
 
 	mockUserService := mock_service.NewMockUserService(ctrl)

+ 4 - 6
test/server/repository/user_test.go

@@ -43,7 +43,6 @@ func TestUserRepository_Create(t *testing.T) {
 	user := &model.User{
 		Id:        1,
 		UserId:    "123",
-		Username:  "test",
 		Nickname:  "Test",
 		Password:  "password",
 		Email:     "test@example.com",
@@ -53,7 +52,7 @@ func TestUserRepository_Create(t *testing.T) {
 
 	mock.ExpectBegin()
 	mock.ExpectExec("INSERT INTO `users`").
-		WithArgs(user.UserId, user.Username, user.Nickname, user.Password, user.Email, user.CreatedAt, user.UpdatedAt, user.DeletedAt, user.Id).
+		WithArgs(user.UserId, user.Nickname, user.Password, user.Email, user.CreatedAt, user.UpdatedAt, user.DeletedAt, user.Id).
 		WillReturnResult(sqlmock.NewResult(1, 1))
 	mock.ExpectCommit()
 
@@ -70,7 +69,6 @@ func TestUserRepository_Update(t *testing.T) {
 	user := &model.User{
 		Id:        1,
 		UserId:    "123",
-		Username:  "test",
 		Nickname:  "Test",
 		Password:  "password",
 		Email:     "test@example.com",
@@ -110,16 +108,16 @@ func TestUserRepository_GetByUsername(t *testing.T) {
 	userRepo, mock := setupRepository(t)
 
 	ctx := context.Background()
-	username := "test"
+	email := "test@example.com"
 
 	rows := sqlmock.NewRows([]string{"id", "user_id", "username", "nickname", "password", "email", "created_at", "updated_at"}).
 		AddRow(1, "123", "test", "Test", "password", "test@example.com", time.Now(), time.Now())
 	mock.ExpectQuery("SELECT \\* FROM `users`").WillReturnRows(rows)
 
-	user, err := userRepo.GetByUsername(ctx, username)
+	user, err := userRepo.GetByEmail(ctx, email)
 	assert.NoError(t, err)
 	assert.NotNil(t, user)
-	assert.Equal(t, "test", user.Username)
+	assert.Equal(t, "test@example.com", user.Email)
 
 	assert.NoError(t, mock.ExpectationsWereMet())
 }

+ 31 - 27
test/server/service/user_test.go

@@ -22,7 +22,9 @@ import (
 )
 
 var (
-	srv *service.Service
+	logger *log.Logger
+	j      *jwt.JWT
+	sf     *sid.Sid
 )
 
 func TestMain(m *testing.M) {
@@ -37,10 +39,9 @@ func TestMain(m *testing.M) {
 	flag.Parse()
 	conf := config.NewConfig(*envConf)
 
-	logger := log.NewLog(conf)
-	jwt := jwt.NewJwt(conf)
-	sf := sid.NewSid()
-	srv = service.NewService(logger, sf, jwt)
+	logger = log.NewLog(conf)
+	j = jwt.NewJwt(conf)
+	sf = sid.NewSid()
 
 	code := m.Run()
 	fmt.Println("test end")
@@ -53,18 +54,19 @@ func TestUserService_Register(t *testing.T) {
 	defer ctrl.Finish()
 
 	mockUserRepo := mock_repository.NewMockUserRepository(ctrl)
+	mockTm := mock_repository.NewMockTransaction(ctrl)
+	srv := service.NewService(mockTm, logger, sf, j)
 
 	userService := service.NewUserService(srv, mockUserRepo)
 
 	ctx := context.Background()
 	req := &v1.RegisterRequest{
-		Username: "testuser",
 		Password: "password",
 		Email:    "test@example.com",
 	}
 
-	mockUserRepo.EXPECT().GetByUsername(ctx, req.Username).Return(nil, nil)
-	mockUserRepo.EXPECT().Create(ctx, gomock.Any()).Return(nil)
+	mockUserRepo.EXPECT().GetByEmail(ctx, req.Email).Return(nil, nil)
+	mockTm.EXPECT().Transaction(ctx, gomock.Any()).Return(nil)
 
 	err := userService.Register(ctx, req)
 
@@ -76,17 +78,17 @@ func TestUserService_Register_UsernameExists(t *testing.T) {
 	defer ctrl.Finish()
 
 	mockUserRepo := mock_repository.NewMockUserRepository(ctrl)
-
+	mockTm := mock_repository.NewMockTransaction(ctrl)
+	srv := service.NewService(mockTm, logger, sf, j)
 	userService := service.NewUserService(srv, mockUserRepo)
 
 	ctx := context.Background()
 	req := &v1.RegisterRequest{
-		Username: "testuser",
 		Password: "password",
 		Email:    "test@example.com",
 	}
 
-	mockUserRepo.EXPECT().GetByUsername(ctx, req.Username).Return(&model.User{}, nil)
+	mockUserRepo.EXPECT().GetByEmail(ctx, req.Email).Return(&model.User{}, nil)
 
 	err := userService.Register(ctx, req)
 
@@ -98,12 +100,13 @@ func TestUserService_Login(t *testing.T) {
 	defer ctrl.Finish()
 
 	mockUserRepo := mock_repository.NewMockUserRepository(ctrl)
-
+	mockTm := mock_repository.NewMockTransaction(ctrl)
+	srv := service.NewService(mockTm, logger, sf, j)
 	userService := service.NewUserService(srv, mockUserRepo)
 
 	ctx := context.Background()
 	req := &v1.LoginRequest{
-		Username: "testuser",
+		Email:    "xxx@gmail.com",
 		Password: "password",
 	}
 	hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost)
@@ -111,7 +114,7 @@ func TestUserService_Login(t *testing.T) {
 		t.Error("failed to hash password")
 	}
 
-	mockUserRepo.EXPECT().GetByUsername(ctx, req.Username).Return(&model.User{
+	mockUserRepo.EXPECT().GetByEmail(ctx, req.Email).Return(&model.User{
 		Password: string(hashedPassword),
 	}, nil)
 
@@ -126,16 +129,17 @@ func TestUserService_Login_UserNotFound(t *testing.T) {
 	defer ctrl.Finish()
 
 	mockUserRepo := mock_repository.NewMockUserRepository(ctrl)
-
+	mockTm := mock_repository.NewMockTransaction(ctrl)
+	srv := service.NewService(mockTm, logger, sf, j)
 	userService := service.NewUserService(srv, mockUserRepo)
 
 	ctx := context.Background()
 	req := &v1.LoginRequest{
-		Username: "testuser",
+		Email:    "xxx@gmail.com",
 		Password: "password",
 	}
 
-	mockUserRepo.EXPECT().GetByUsername(ctx, req.Username).Return(nil, errors.New("user not found"))
+	mockUserRepo.EXPECT().GetByEmail(ctx, req.Email).Return(nil, errors.New("user not found"))
 
 	_, err := userService.Login(ctx, req)
 
@@ -147,23 +151,22 @@ func TestUserService_GetProfile(t *testing.T) {
 	defer ctrl.Finish()
 
 	mockUserRepo := mock_repository.NewMockUserRepository(ctrl)
-
+	mockTm := mock_repository.NewMockTransaction(ctrl)
+	srv := service.NewService(mockTm, logger, sf, j)
 	userService := service.NewUserService(srv, mockUserRepo)
 
 	ctx := context.Background()
 	userId := "123"
 
 	mockUserRepo.EXPECT().GetByID(ctx, userId).Return(&model.User{
-		UserId:   userId,
-		Username: "testuser",
-		Email:    "test@example.com",
+		UserId: userId,
+		Email:  "test@example.com",
 	}, nil)
 
 	user, err := userService.GetProfile(ctx, userId)
 
 	assert.NoError(t, err)
 	assert.Equal(t, userId, user.UserId)
-	assert.Equal(t, "testuser", user.Username)
 }
 
 func TestUserService_UpdateProfile(t *testing.T) {
@@ -171,7 +174,8 @@ func TestUserService_UpdateProfile(t *testing.T) {
 	defer ctrl.Finish()
 
 	mockUserRepo := mock_repository.NewMockUserRepository(ctrl)
-
+	mockTm := mock_repository.NewMockTransaction(ctrl)
+	srv := service.NewService(mockTm, logger, sf, j)
 	userService := service.NewUserService(srv, mockUserRepo)
 
 	ctx := context.Background()
@@ -182,9 +186,8 @@ func TestUserService_UpdateProfile(t *testing.T) {
 	}
 
 	mockUserRepo.EXPECT().GetByID(ctx, userId).Return(&model.User{
-		UserId:   userId,
-		Username: "testuser",
-		Email:    "old@example.com",
+		UserId: userId,
+		Email:  "old@example.com",
 	}, nil)
 	mockUserRepo.EXPECT().Update(ctx, gomock.Any()).Return(nil)
 
@@ -198,7 +201,8 @@ func TestUserService_UpdateProfile_UserNotFound(t *testing.T) {
 	defer ctrl.Finish()
 
 	mockUserRepo := mock_repository.NewMockUserRepository(ctrl)
-
+	mockTm := mock_repository.NewMockTransaction(ctrl)
+	srv := service.NewService(mockTm, logger, sf, j)
 	userService := service.NewUserService(srv, mockUserRepo)
 
 	ctx := context.Background()