diff --git a/README.md b/README.md index fb342cc..b58e450 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,8 @@ # stereo.cat backend written in Go, uses Gin. + +## database shit + +Instead of using Discord oAuth as a database, we instead use it as a login source, only using it to source a username/id, avatar data and a secure login/registration flow. +We store these attributes alongside stereo.cat specific attributes in our own database. There is a trade-off however: this means that avatar & username data is not updated in real-time, only when the oauth flow is executed. diff --git a/internal/api/routes/auth.go b/internal/api/routes/auth.go index a6a634f..53082dc 100644 --- a/internal/api/routes/auth.go +++ b/internal/api/routes/auth.go @@ -31,11 +31,21 @@ func RegisterAuthRoutes(cfg *types.StereoConfig, api *gin.RouterGroup) { panic(err) } - c.String(http.StatusOK, jwt) + res := cfg.Database.FirstOrCreate(&user) + + if res.Error != nil { + panic(res.Error) + } + + // TODO: redirect to dashboard + c.JSON(http.StatusOK, gin.H{ + "jwt": jwt, + "known": res.RowsAffected == 0, + }) }) api.GET("/auth/me", auth.JwtMiddleware(cfg.JWTSecret), func(c *gin.Context) { - claims, _ := c.Get("claims") + claims, _ := c.Get("claims") c.JSON(http.StatusOK, claims) }) } diff --git a/internal/auth/client/client.go b/internal/auth/client/client.go index 9a415f5..3a922bc 100644 --- a/internal/auth/client/client.go +++ b/internal/auth/client/client.go @@ -8,6 +8,8 @@ import ( "net/http" "net/url" "strings" + "time" + "stereo.cat/backend/internal/auth" ) @@ -28,9 +30,10 @@ func New(redirectUri, clientId, clientSecret string) Client { } func (c Client) GetUser(t auth.TokenResponse) (auth.User, error) { - user := auth.User { - Blacklisted: false, - } + user := auth.User{ + Blacklisted: false, + CreatedAt: time.Now(), + } req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("%s/%s", api, "users/@me"), nil) diff --git a/internal/auth/jwt.go b/internal/auth/jwt.go index 6a01712..9e7aec2 100644 --- a/internal/auth/jwt.go +++ b/internal/auth/jwt.go @@ -1,6 +1,7 @@ package auth import ( + "encoding/json" "fmt" "net/http" "strings" @@ -20,30 +21,45 @@ func GenerateJWT(key string, user User, expiryTimestamp uint64) (string, error) } func invalidAuth(c *gin.Context) { - c.String(http.StatusUnauthorized, "Unauthorized.") - c.Abort() + c.String(http.StatusUnauthorized, "Unauthorized.") + c.Abort() } func JwtMiddleware(secret string) gin.HandlerFunc { - return func(c *gin.Context) { - jwtSplit := strings.Split(c.GetHeader("Authorization"), " ") + return func(c *gin.Context) { + jwtSplit := strings.Split(c.GetHeader("Authorization"), " ") - if jwtSplit[0] != "Bearer" { - invalidAuth(c) - return - } + if len(jwtSplit) < 2 || jwtSplit[0] != "Bearer" { + invalidAuth(c) + return + } - claims, err := ValidateJWT(jwtSplit[1], secret) + claims, err := ValidateJWT(jwtSplit[1], secret) + if err != nil { + invalidAuth(c) + return + } - if err != nil { - invalidAuth(c) - return - } + if userClaims, ok := claims["user"].(map[string]interface{}); ok { + userJSON, err := json.Marshal(userClaims) // Convert map to JSON + if err != nil { + invalidAuth(c) + return + } - c.Set("claims", claims) + var user User + err = json.Unmarshal(userJSON, &user) + if err != nil { + invalidAuth(c) + return + } - c.Next() - } + claims["user"] = user + } + + c.Set("claims", claims) + c.Next() + } } func ValidateJWT(jwtString, key string) (jwt.MapClaims, error) {