diff --git a/go.mod b/go.mod index 97c77aa..c5b1d5e 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/bytedance/sonic v1.13.2 // indirect github.com/bytedance/sonic/loader v0.2.4 // indirect github.com/cloudwego/base64x v0.1.5 // indirect + github.com/cristalhq/base64 v0.1.2 // indirect github.com/dustin/go-humanize v1.0.1 // indirect github.com/gabriel-vasile/mimetype v1.4.9 // indirect github.com/gin-contrib/sse v1.1.0 // indirect diff --git a/go.sum b/go.sum index e3ca063..59c9bb3 100644 --- a/go.sum +++ b/go.sum @@ -7,6 +7,8 @@ github.com/cloudwego/base64x v0.1.5 h1:XPciSp1xaq2VCSt6lF0phncD4koWyULpl5bUxbfCy github.com/cloudwego/base64x v0.1.5/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w= github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/cristalhq/base64 v0.1.2 h1:edsefYyYDiac7Ytdh2xdaiiSSJzcI2f0yIkdGEf1qY0= +github.com/cristalhq/base64 v0.1.2/go.mod h1:sy4+2Hale2KbtSqkzpdMeYTP/IrB+HCvxVHWsh2VSYk= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/internal/api/routes/auth.go b/internal/api/routes/auth.go index fdc79bb..2c98b6c 100644 --- a/internal/api/routes/auth.go +++ b/internal/api/routes/auth.go @@ -18,11 +18,16 @@ package routes import ( + "errors" + "fmt" "net/http" "time" "github.com/gin-gonic/gin" - "stereo.cat/backend/internal/auth/token" + "github.com/golang-jwt/jwt/v5" + "stereo.cat/backend/internal/auth" + "stereo.cat/backend/internal/auth/session" + "stereo.cat/backend/internal/auth/ukey" "stereo.cat/backend/internal/types" ) @@ -42,7 +47,7 @@ func RegisterAuthRoutes(cfg *types.StereoConfig, api *gin.RouterGroup) { panic(err) } - jwt, err := token.GenerateJWT(cfg.JWTSecret, user, uint64(time.Now().Add(time.Second*time.Duration(t.ExpiresIn)).Unix())) + jwt, err := session.GenerateSessionJWT(cfg.JWTSecret, user, uint64(time.Now().Add(time.Second*time.Duration(t.ExpiresIn)).Unix())) if err != nil { panic(err) @@ -64,8 +69,30 @@ func RegisterAuthRoutes(cfg *types.StereoConfig, api *gin.RouterGroup) { c.Redirect(http.StatusTemporaryRedirect, cfg.FrontendUri+"?jwt_set=true") }) - api.GET("/auth/me", token.JwtMiddleware(cfg.JWTSecret), func(c *gin.Context) { - claims, _ := c.Get("claims") + api.GET("/auth/me", session.SessionMiddleware(cfg.JWTSecret), func(c *gin.Context) { + claims := c.MustGet("claims") c.JSON(http.StatusOK, claims) }) + + // Generate an API key (automatically revokes previous api key too since a user can only have one api key bound to their db entry at a given time) + api.GET("/auth/key", session.SessionMiddleware(cfg.JWTSecret), func(c *gin.Context) { + claims := c.MustGet("claims").(jwt.MapClaims) + + user, ok := claims["user"].(auth.User) + if !ok { + types.ErrorUserNotFound.Throw(c, errors.New(fmt.Sprintf("got data with type %T but wanted claims.User", claims["user"]))) + return + } + + key := ukey.GenerateUploadKey(cfg, &user, c) + + if key == nil { + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "key": key, + }) + }) } diff --git a/internal/api/routes/files.go b/internal/api/routes/files.go index 98145ef..2aed6fe 100644 --- a/internal/api/routes/files.go +++ b/internal/api/routes/files.go @@ -29,7 +29,7 @@ import ( "github.com/h2non/filetype" "github.com/minio/minio-go/v7" "stereo.cat/backend/internal/auth" - "stereo.cat/backend/internal/auth/token" + "stereo.cat/backend/internal/auth/session" "stereo.cat/backend/internal/types" ) @@ -38,7 +38,7 @@ func intoReader(buf []byte) io.Reader { } func RegisterFileRoutes(cfg *types.StereoConfig, api *gin.RouterGroup) { - api.POST("/upload", token.JwtMiddleware(cfg.JWTSecret), func(c *gin.Context) { + api.POST("/upload", session.SessionMiddleware(cfg.JWTSecret), func(c *gin.Context) { claims := c.MustGet("claims").(jwt.MapClaims) user := claims["user"].(auth.User) @@ -106,7 +106,7 @@ func RegisterFileRoutes(cfg *types.StereoConfig, api *gin.RouterGroup) { c.JSON(200, gin.H{"message": "file uploaded successfully", "id": fileMeta.ID.String()}) }) - api.DELETE("/:id", token.JwtMiddleware(cfg.JWTSecret), func(c *gin.Context) { + api.DELETE("/:id", session.SessionMiddleware(cfg.JWTSecret), func(c *gin.Context) { claims := c.MustGet("claims").(jwt.MapClaims) user := claims["user"].(auth.User) @@ -175,7 +175,7 @@ func RegisterFileRoutes(cfg *types.StereoConfig, api *gin.RouterGroup) { c.DataFromReader(200, file.Size, file.Mime, object, nil) }) - api.GET("/list", token.JwtMiddleware(cfg.JWTSecret), func(c *gin.Context) { + api.GET("/list", session.SessionMiddleware(cfg.JWTSecret), func(c *gin.Context) { claims := c.MustGet("claims").(jwt.MapClaims) user := claims["user"].(auth.User) diff --git a/internal/auth/token/jwt.go b/internal/auth/session/session.go similarity index 84% rename from internal/auth/token/jwt.go rename to internal/auth/session/session.go index 628cb92..25b65ca 100644 --- a/internal/auth/token/jwt.go +++ b/internal/auth/session/session.go @@ -15,7 +15,7 @@ along with this program. If not, see . */ -package token +package session import ( "encoding/json" @@ -28,8 +28,8 @@ import ( "stereo.cat/backend/internal/types" ) -func GenerateJWT(key string, user auth.User, expiryTimestamp uint64) (string, error) { - claims := auth.Claims{ +func GenerateSessionJWT(key string, user auth.User, expiryTimestamp uint64) (string, error) { + claims := auth.SessionClaims{ User: user, Exp: expiryTimestamp, } @@ -39,7 +39,7 @@ func GenerateJWT(key string, user auth.User, expiryTimestamp uint64) (string, er } -func JwtMiddleware(secret string) gin.HandlerFunc { +func SessionMiddleware(secret string) gin.HandlerFunc { return func(c *gin.Context) { jwt, err := c.Cookie("jwt") if err != nil { @@ -54,7 +54,7 @@ func JwtMiddleware(secret string) gin.HandlerFunc { jwt = jwtSplit[1] } - claims, err := ValidateJWT(jwt, secret) + claims, err := ValidateSession(jwt, secret) if err != nil { types.ErrorUnauthorized.Throw(c, err) return @@ -82,7 +82,7 @@ func JwtMiddleware(secret string) gin.HandlerFunc { } } -func ValidateJWT(jwtString, key string) (jwt.MapClaims, error) { +func ValidateSession(jwtString, key string) (jwt.MapClaims, error) { token, err := jwt.Parse(jwtString, func(token *jwt.Token) (any, error) { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { return nil, fmt.Errorf("Invalid signing method!") diff --git a/internal/auth/types.go b/internal/auth/types.go index 8b42642..3f2051c 100644 --- a/internal/auth/types.go +++ b/internal/auth/types.go @@ -37,6 +37,7 @@ type User struct { Blacklisted bool `json:"blacklisted"` Email string `json:"email"` CreatedAt time.Time `json:"created_at"` + HashedApiKey string `json:"hashed_api_key"` } type AvatarDecorationData struct { @@ -50,7 +51,7 @@ type ExchangeCodeRequest struct { RedirectUri string `json:"redirect_uri"` } -type Claims struct { +type SessionClaims struct { User User `json:"user"` Exp uint64 `json:"exp"` jwt.RegisteredClaims diff --git a/internal/auth/ukey/ukey.go b/internal/auth/ukey/ukey.go new file mode 100644 index 0000000..f250ec7 --- /dev/null +++ b/internal/auth/ukey/ukey.go @@ -0,0 +1,52 @@ +package ukey + +import ( + "crypto/rand" + "math/big" + + "github.com/cristalhq/base64" + "github.com/gin-gonic/gin" + "golang.org/x/crypto/blake2b" + "stereo.cat/backend/internal/auth" + "stereo.cat/backend/internal/types" +) + +func GenerateUploadKey(cfg *types.StereoConfig, user *auth.User, c *gin.Context) []byte { + length := 32 + chars := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ123456789@#!&*%~?" + + key := make([]byte, length) + for i := range length { + num, err := rand.Int(rand.Reader, big.NewInt(int64(len(chars)))) + if err != nil { + types.ErrorInvalidParams.Throw(c, err) + return nil + } + + key[i] = chars[num.Int64()] + } + + hasher, err := blake2b.New512(nil) + if err != nil { + types.ErrorInvalidParams.Throw(c, err) + return nil + } + + _, err = hasher.Write(key) + if err != nil { + types.ErrorInvalidParams.Throw(c, err) + return nil + } + + hashed := base64.RawStdEncoding.EncodeToString(hasher.Sum(nil)) + + user.HashedApiKey = hashed + + err = cfg.Database.Updates(user).Error + if err != nil { + types.ErrorDatabase.Throw(c, err) + return nil + } + + return key +}