add state validation to oauth flow #8

Merged
grng merged 4 commits from state into dev 2025-07-31 10:30:40 +00:00
Showing only changes of commit 6a08afbf52 - Show all commits

View file

@ -1,26 +1,28 @@
/* /*
Copyright (C) 2025 hexlocation (hex@iwakura.rip) & grngxd (grng@iwakura.rip) Copyright (C) 2025 hexlocation (hex@iwakura.rip) & grngxd (grng@iwakura.rip)
This program is free software: you can redistribute it and/or modify This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or the Free Software Foundation, either version 3 of the License, or
(at your option) any later version. (at your option) any later version.
This program is distributed in the hope that it will be useful, This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details. GNU General Public License for more details.
You should have received a copy of the GNU General Public License You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>. along with this program. If not, see <https://www.gnu.org/licenses/>.
*/ */
package routes package routes
import ( import (
"errors" "crypto/rand"
"encoding/base64"
"fmt" "fmt"
"net/http" "net/http"
"net/url"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@ -31,30 +33,62 @@ import (
"stereo.cat/backend/internal/types" "stereo.cat/backend/internal/types"
) )
func generateState(length int) (string, error) {
b := make([]byte, length)
_, err := rand.Read(b)
if err != nil {
return "", err
}
return base64.URLEncoding.EncodeToString(b), nil
}
func RegisterAuthRoutes(cfg *types.StereoConfig, api *gin.RouterGroup) { func RegisterAuthRoutes(cfg *types.StereoConfig, api *gin.RouterGroup) {
api.GET("/auth/login", func(c *gin.Context) {
state, err := generateState(32)
if err != nil {
c.AbortWithStatus(http.StatusInternalServerError)
return
}
c.SetCookie("oauth_state", state, 300, "", cfg.Domain, true, true)
discordURL := fmt.Sprintf(
"https://discord.com/oauth2/authorize?client_id=%s&response_type=code&redirect_uri=%s&scope=identify%%20email&state=%s",
cfg.Client.ClientId,
url.QueryEscape(cfg.Client.RedirectUri),
state,
)
c.Redirect(http.StatusTemporaryRedirect, discordURL)
})
api.GET("/auth/callback", func(c *gin.Context) { api.GET("/auth/callback", func(c *gin.Context) {
code := c.Query("code") code := c.Query("code")
state := c.Query("state")
cookieState, err := c.Cookie("oauth_state")
if err != nil || state != cookieState {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "Invalid state"})
return
}
c.SetCookie("oauth_state", "", -1, "", cfg.Domain, true, true)
t, err := cfg.Client.ExchangeCode(code) t, err := cfg.Client.ExchangeCode(code)
if err != nil { if err != nil {
panic(err) panic(err)
} }
user, err := cfg.Client.GetUser(t) user, err := cfg.Client.GetUser(t)
if err != nil { if err != nil {
panic(err) panic(err)
} }
jwt, err := session.GenerateSessionJWT(cfg.JWTSecret, user, uint64(time.Now().Add(time.Second*time.Duration(t.ExpiresIn)).Unix())) jwt, err := session.GenerateSessionJWT(cfg.JWTSecret, user, uint64(time.Now().Add(time.Second*time.Duration(t.ExpiresIn)).Unix()))
if err != nil { if err != nil {
panic(err) panic(err)
} }
res := cfg.Database.FirstOrCreate(&user) res := cfg.Database.FirstOrCreate(&user)
if res.Error != nil { if res.Error != nil {
panic(res.Error) panic(res.Error)
} }
@ -66,7 +100,7 @@ func RegisterAuthRoutes(cfg *types.StereoConfig, api *gin.RouterGroup) {
}) })
*/ */
c.SetCookie("jwt", jwt, int(t.ExpiresIn), "", cfg.Domain, true, true) c.SetCookie("jwt", jwt, int(t.ExpiresIn), "", cfg.Domain, true, true)
c.Redirect(http.StatusTemporaryRedirect, cfg.FrontendUri+"?jwt_set=true") c.Redirect(http.StatusTemporaryRedirect, cfg.FrontendUri+"/dashboard?jwt_set=true")
}) })
api.GET("/auth/me", session.SessionMiddleware(cfg.JWTSecret), func(c *gin.Context) { api.GET("/auth/me", session.SessionMiddleware(cfg.JWTSecret), func(c *gin.Context) {
@ -80,7 +114,7 @@ func RegisterAuthRoutes(cfg *types.StereoConfig, api *gin.RouterGroup) {
user, ok := claims["user"].(auth.User) user, ok := claims["user"].(auth.User)
if !ok { if !ok {
types.ErrorUserNotFound.Throw(c, errors.New(fmt.Sprintf("got data with type %T but wanted claims.User", claims["user"]))) types.ErrorUserNotFound.Throw(c, fmt.Errorf("got data with type %T but wanted claims.User", claims["user"]))
return return
} }