add state validation to oauth flow

This commit is contained in:
grngxd 2025-07-30 11:12:22 +01:00
parent 470f8fa047
commit 6a08afbf52

View file

@ -1,26 +1,28 @@
/* /*
Copyright (C) 2025 hexlocation (hex@iwakura.rip) & grngxd (grng@iwakura.rip) Copyright (C) 2025 hexlocation (hex@iwakura.rip) & grngxd (grng@iwakura.rip)
This program is free software: you can redistribute it and/or modify This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or the Free Software Foundation, either version 3 of the License, or
(at your option) any later version. (at your option) any later version.
This program is distributed in the hope that it will be useful, This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details. GNU General Public License for more details.
You should have received a copy of the GNU General Public License You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>. along with this program. If not, see <https://www.gnu.org/licenses/>.
*/ */
package routes package routes
import ( import (
"errors" "crypto/rand"
"encoding/base64"
"fmt" "fmt"
"net/http" "net/http"
"net/url"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@ -31,30 +33,62 @@ import (
"stereo.cat/backend/internal/types" "stereo.cat/backend/internal/types"
) )
func generateState(length int) (string, error) {
b := make([]byte, length)
_, err := rand.Read(b)
if err != nil {
return "", err
}
return base64.URLEncoding.EncodeToString(b), nil
}
func RegisterAuthRoutes(cfg *types.StereoConfig, api *gin.RouterGroup) { func RegisterAuthRoutes(cfg *types.StereoConfig, api *gin.RouterGroup) {
api.GET("/auth/login", func(c *gin.Context) {
state, err := generateState(32)
if err != nil {
c.AbortWithStatus(http.StatusInternalServerError)
return
}
c.SetCookie("oauth_state", state, 300, "", cfg.Domain, true, true)
discordURL := fmt.Sprintf(
"https://discord.com/oauth2/authorize?client_id=%s&response_type=code&redirect_uri=%s&scope=identify%%20email&state=%s",
cfg.Client.ClientId,
url.QueryEscape(cfg.Client.RedirectUri),
state,
)
c.Redirect(http.StatusTemporaryRedirect, discordURL)
})
api.GET("/auth/callback", func(c *gin.Context) { api.GET("/auth/callback", func(c *gin.Context) {
code := c.Query("code") code := c.Query("code")
state := c.Query("state")
cookieState, err := c.Cookie("oauth_state")
if err != nil || state != cookieState {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "Invalid state"})
return
}
c.SetCookie("oauth_state", "", -1, "", cfg.Domain, true, true)
t, err := cfg.Client.ExchangeCode(code) t, err := cfg.Client.ExchangeCode(code)
if err != nil { if err != nil {
panic(err) panic(err)
} }
user, err := cfg.Client.GetUser(t) user, err := cfg.Client.GetUser(t)
if err != nil { if err != nil {
panic(err) panic(err)
} }
jwt, err := session.GenerateSessionJWT(cfg.JWTSecret, user, uint64(time.Now().Add(time.Second*time.Duration(t.ExpiresIn)).Unix())) jwt, err := session.GenerateSessionJWT(cfg.JWTSecret, user, uint64(time.Now().Add(time.Second*time.Duration(t.ExpiresIn)).Unix()))
if err != nil { if err != nil {
panic(err) panic(err)
} }
res := cfg.Database.FirstOrCreate(&user) res := cfg.Database.FirstOrCreate(&user)
if res.Error != nil { if res.Error != nil {
panic(res.Error) panic(res.Error)
} }
@ -66,7 +100,7 @@ func RegisterAuthRoutes(cfg *types.StereoConfig, api *gin.RouterGroup) {
}) })
*/ */
c.SetCookie("jwt", jwt, int(t.ExpiresIn), "", cfg.Domain, true, true) c.SetCookie("jwt", jwt, int(t.ExpiresIn), "", cfg.Domain, true, true)
c.Redirect(http.StatusTemporaryRedirect, cfg.FrontendUri+"?jwt_set=true") c.Redirect(http.StatusTemporaryRedirect, cfg.FrontendUri+"/dashboard?jwt_set=true")
}) })
api.GET("/auth/me", session.SessionMiddleware(cfg.JWTSecret), func(c *gin.Context) { api.GET("/auth/me", session.SessionMiddleware(cfg.JWTSecret), func(c *gin.Context) {
@ -80,7 +114,7 @@ func RegisterAuthRoutes(cfg *types.StereoConfig, api *gin.RouterGroup) {
user, ok := claims["user"].(auth.User) user, ok := claims["user"].(auth.User)
if !ok { if !ok {
types.ErrorUserNotFound.Throw(c, errors.New(fmt.Sprintf("got data with type %T but wanted claims.User", claims["user"]))) types.ErrorUserNotFound.Throw(c, fmt.Errorf("got data with type %T but wanted claims.User", claims["user"]))
return return
} }