jwt.go 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. package middleware
  2. import (
  3. "net/http"
  4. "regexp"
  5. "time"
  6. "github.com/gin-gonic/gin"
  7. "github.com/go-nunu/nunu-layout-advanced/pkg/helper/resp"
  8. "github.com/go-nunu/nunu-layout-advanced/pkg/log"
  9. "github.com/golang-jwt/jwt/v5"
  10. "github.com/spf13/viper"
  11. "go.uber.org/zap"
  12. )
  13. type JWT struct {
  14. key []byte
  15. }
  16. type MyCustomClaims struct {
  17. UserId string
  18. jwt.RegisteredClaims
  19. }
  20. func NewJwt(conf *viper.Viper) *JWT {
  21. return &JWT{key: []byte(conf.GetString("security.jwt.key"))}
  22. }
  23. func (j *JWT) GenToken(userId string, expiresAt time.Time) (string, error) {
  24. token := jwt.NewWithClaims(jwt.SigningMethodHS256, MyCustomClaims{
  25. UserId: userId,
  26. RegisteredClaims: jwt.RegisteredClaims{
  27. ExpiresAt: jwt.NewNumericDate(expiresAt),
  28. IssuedAt: jwt.NewNumericDate(time.Now()),
  29. NotBefore: jwt.NewNumericDate(time.Now()),
  30. Issuer: "",
  31. Subject: "",
  32. ID: "",
  33. Audience: []string{},
  34. },
  35. })
  36. // Sign and get the complete encoded token as a string using the key
  37. tokenString, err := token.SignedString(j.key)
  38. if err != nil {
  39. return "", err
  40. }
  41. return tokenString, nil
  42. }
  43. func (j *JWT) ParseToken(tokenString string) (*MyCustomClaims, error) {
  44. re := regexp.MustCompile(`(?i)Bearer `)
  45. tokenString = re.ReplaceAllString(tokenString, "")
  46. token, err := jwt.ParseWithClaims(tokenString, &MyCustomClaims{}, func(token *jwt.Token) (interface{}, error) {
  47. return j.key, nil
  48. })
  49. if claims, ok := token.Claims.(*MyCustomClaims); ok && token.Valid {
  50. return claims, nil
  51. } else {
  52. return nil, err
  53. }
  54. }
  55. func StrictAuth(j *JWT, logger *log.Logger) gin.HandlerFunc {
  56. return func(ctx *gin.Context) {
  57. tokenString := ctx.Request.Header.Get("Authorization")
  58. if tokenString == "" {
  59. logger.WithContext(ctx).Warn("请求未携带token,无权限访问", zap.Any("data", map[string]interface{}{
  60. "url": ctx.Request.URL,
  61. "params": ctx.Params,
  62. }))
  63. resp.HandleError(ctx, http.StatusUnauthorized, 1, "no token", nil)
  64. ctx.Abort()
  65. return
  66. }
  67. claims, err := j.ParseToken(tokenString)
  68. if err != nil {
  69. logger.WithContext(ctx).Error("token error", zap.Any("data", map[string]interface{}{
  70. "url": ctx.Request.URL,
  71. "params": ctx.Params,
  72. }))
  73. resp.HandleError(ctx, http.StatusUnauthorized, 1, err.Error(), nil)
  74. ctx.Abort()
  75. return
  76. }
  77. ctx.Set("claims", claims)
  78. recoveryLoggerFunc(ctx, logger)
  79. ctx.Next()
  80. }
  81. }
  82. func NoStrictAuth(j *JWT, logger *log.Logger) gin.HandlerFunc {
  83. return func(ctx *gin.Context) {
  84. tokenString := ctx.Request.Header.Get("Authorization")
  85. if tokenString == "" {
  86. tokenString, _ = ctx.Cookie("accessToken")
  87. }
  88. if tokenString == "" {
  89. tokenString = ctx.Query("accessToken")
  90. }
  91. if tokenString == "" {
  92. ctx.Next()
  93. return
  94. }
  95. claims, err := j.ParseToken(tokenString)
  96. if err != nil {
  97. ctx.Next()
  98. return
  99. }
  100. ctx.Set("claims", claims)
  101. recoveryLoggerFunc(ctx, logger)
  102. ctx.Next()
  103. }
  104. }
  105. func recoveryLoggerFunc(ctx *gin.Context, logger *log.Logger) {
  106. userInfo := ctx.MustGet("claims").(*MyCustomClaims)
  107. logger.NewContext(ctx, zap.String("UserId", userInfo.UserId))
  108. }