jwt.go 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. package middleware
  2. import (
  3. "github.com/gin-gonic/gin"
  4. "github.com/go-nunu/nunu-layout-advanced/pkg/helper/resp"
  5. "github.com/go-nunu/nunu-layout-advanced/pkg/log"
  6. "github.com/golang-jwt/jwt/v5"
  7. "github.com/spf13/viper"
  8. "go.uber.org/zap"
  9. "net/http"
  10. "regexp"
  11. "time"
  12. )
  13. type JWT struct {
  14. key []byte
  15. }
  16. type MyCustomClaims struct {
  17. UserId string
  18. jwt.RegisteredClaims
  19. }
  20. // NewJwt https://pkg.go.dev/github.com/golang-jwt/jwt/v5
  21. func NewJwt(conf *viper.Viper) *JWT {
  22. return &JWT{key: []byte(conf.GetString("security.jwt.key"))}
  23. }
  24. func (j *JWT) GenToken(userId string, expiresAt time.Time) string {
  25. token := jwt.NewWithClaims(jwt.SigningMethodHS256, MyCustomClaims{
  26. UserId: userId,
  27. RegisteredClaims: jwt.RegisteredClaims{
  28. ExpiresAt: jwt.NewNumericDate(expiresAt),
  29. IssuedAt: jwt.NewNumericDate(time.Now()),
  30. NotBefore: jwt.NewNumericDate(time.Now()),
  31. Issuer: "",
  32. Subject: "",
  33. ID: "",
  34. Audience: []string{},
  35. },
  36. })
  37. // Sign and get the complete encoded token as a string using the key
  38. tokenString, err := token.SignedString(j.key)
  39. if err != nil {
  40. return ""
  41. }
  42. return tokenString
  43. }
  44. func (j *JWT) ParseToken(tokenString string) (*MyCustomClaims, error) {
  45. re, _ := regexp.Compile(`(?i)Bearer `)
  46. tokenString = re.ReplaceAllString(tokenString, "")
  47. token, err := jwt.ParseWithClaims(tokenString, &MyCustomClaims{}, func(token *jwt.Token) (interface{}, error) {
  48. return j.key, nil
  49. })
  50. if claims, ok := token.Claims.(*MyCustomClaims); ok && token.Valid {
  51. return claims, nil
  52. } else {
  53. return nil, err
  54. }
  55. }
  56. // StrictAuth 严格权限
  57. func StrictAuth(j *JWT, logger *log.Logger) gin.HandlerFunc {
  58. return func(ctx *gin.Context) {
  59. tokenString := ctx.Request.Header.Get("Authorization")
  60. if tokenString == "" {
  61. logger.WithContext(ctx).Warn("请求未携带token,无权限访问", zap.Any("data", map[string]interface{}{
  62. "url": ctx.Request.URL,
  63. "params": ctx.Params,
  64. }))
  65. resp.HandleError(ctx, http.StatusUnauthorized, 1, "no token", nil)
  66. ctx.Abort()
  67. return
  68. }
  69. // parseToken 解析token包含的信息
  70. claims, err := j.ParseToken(tokenString)
  71. if err != nil {
  72. logger.WithContext(ctx).Error("token error", zap.Any("data", map[string]interface{}{
  73. "url": ctx.Request.URL,
  74. "params": ctx.Params,
  75. }))
  76. resp.HandleError(ctx, http.StatusUnauthorized, 1, err.Error(), nil)
  77. ctx.Abort()
  78. return
  79. }
  80. // 继续交由下一个路由处理,并将解析出的信息传递下去
  81. ctx.Set("claims", claims)
  82. recoveryLoggerFunc(ctx, logger)
  83. ctx.Next()
  84. }
  85. }
  86. func NoStrictAuth(j *JWT, logger *log.Logger) gin.HandlerFunc {
  87. return func(ctx *gin.Context) {
  88. tokenString := ctx.Request.Header.Get("Authorization")
  89. if tokenString == "" {
  90. tokenString, _ = ctx.Cookie("accessToken")
  91. }
  92. if tokenString == "" {
  93. tokenString = ctx.Query("accessToken")
  94. }
  95. if tokenString == "" {
  96. ctx.Next()
  97. return
  98. }
  99. // parseToken 解析token包含的信息
  100. claims, err := j.ParseToken(tokenString)
  101. if err != nil {
  102. ctx.Next()
  103. return
  104. }
  105. // 继续交由下一个路由处理,并将解析出的信息传递下去
  106. ctx.Set("claims", claims)
  107. recoveryLoggerFunc(ctx, logger)
  108. ctx.Next()
  109. }
  110. }
  111. func recoveryLoggerFunc(ctx *gin.Context, logger *log.Logger) {
  112. userInfo := ctx.MustGet("claims").(*MyCustomClaims)
  113. logger.NewContext(ctx, zap.String("UserId", userInfo.UserId))
  114. }