Merge pull request 'add state validation to oauth flow' (#8) from state into dev

Reviewed-on: #8
This commit is contained in:
g 2025-07-31 10:30:40 +00:00
commit 04284acc13
2 changed files with 115 additions and 34 deletions

View file

@ -1,26 +1,29 @@
/* /*
Copyright (C) 2025 hexlocation (hex@iwakura.rip) & grngxd (grng@iwakura.rip) Copyright (C) 2025 hexlocation (hex@iwakura.rip) & grngxd (grng@iwakura.rip)
This program is free software: you can redistribute it and/or modify This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or the Free Software Foundation, either version 3 of the License, or
(at your option) any later version. (at your option) any later version.
This program is distributed in the hope that it will be useful, This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details. GNU General Public License for more details.
You should have received a copy of the GNU General Public License You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>. along with this program. If not, see <https://www.gnu.org/licenses/>.
*/ */
package routes package routes
import ( import (
"errors" "crypto/rand"
"encoding/base64"
"fmt" "fmt"
"net/http" "net/http"
"net/url"
"sync"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@ -31,42 +34,89 @@ import (
"stereo.cat/backend/internal/types" "stereo.cat/backend/internal/types"
) )
var states []string
var statesMutex sync.Mutex
func generateState(length int) (string, error) {
b := make([]byte, length)
_, err := rand.Read(b)
if err != nil {
return "", err
}
return base64.URLEncoding.EncodeToString(b), nil
}
func RegisterAuthRoutes(cfg *types.StereoConfig, api *gin.RouterGroup) { func RegisterAuthRoutes(cfg *types.StereoConfig, api *gin.RouterGroup) {
api.GET("/auth/login", func(c *gin.Context) {
state, err := generateState(32)
if err != nil {
c.AbortWithStatus(http.StatusInternalServerError)
return
}
statesMutex.Lock()
states = append(states, state)
statesMutex.Unlock()
discordURL := fmt.Sprintf(
"https://discord.com/oauth2/authorize?client_id=%s&response_type=code&redirect_uri=%s&scope=identify%%20email&state=%s",
cfg.Client.ClientId,
url.QueryEscape(cfg.Client.RedirectUri),
state,
)
c.Redirect(http.StatusTemporaryRedirect, discordURL)
})
api.GET("/auth/logout", session.SessionMiddleware(cfg.JWTSecret), func(c *gin.Context) {
c.SetCookie("jwt", "", -1, "", cfg.Domain, true, true)
c.Redirect(http.StatusTemporaryRedirect, cfg.FrontendUri)
})
api.GET("/auth/callback", func(c *gin.Context) { api.GET("/auth/callback", func(c *gin.Context) {
code := c.Query("code") code := c.Query("code")
state := c.Query("state")
statesMutex.Lock()
found := false
for i, s := range states {
if s == state {
states = append(states[:i], states[i+1:]...)
found = true
break
}
}
statesMutex.Unlock()
if !found {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "Invalid state"})
return
}
t, err := cfg.Client.ExchangeCode(code) t, err := cfg.Client.ExchangeCode(code)
if err != nil { if err != nil {
panic(err) panic(err)
} }
user, err := cfg.Client.GetUser(t) user, err := cfg.Client.GetUser(t)
if err != nil { if err != nil {
panic(err) panic(err)
} }
jwt, err := session.GenerateSessionJWT(cfg.JWTSecret, user, uint64(time.Now().Add(time.Second*time.Duration(t.ExpiresIn)).Unix())) jwt, err := session.GenerateSessionJWT(cfg.JWTSecret, user, uint64(time.Now().Add(time.Second*time.Duration(t.ExpiresIn)).Unix()))
if err != nil { if err != nil {
panic(err) panic(err)
} }
res := cfg.Database.FirstOrCreate(&user) res := cfg.Database.FirstOrCreate(&user)
if res.Error != nil { if res.Error != nil {
panic(res.Error) panic(res.Error)
} }
// TODO: redirect to dashboard
/*c.JSON(http.StatusOK, gin.H{
"jwt": jwt,
"known": res.RowsAffected == 0,
})
*/
c.SetCookie("jwt", jwt, int(t.ExpiresIn), "", cfg.Domain, true, true) c.SetCookie("jwt", jwt, int(t.ExpiresIn), "", cfg.Domain, true, true)
c.Redirect(http.StatusTemporaryRedirect, cfg.FrontendUri+"?jwt_set=true") c.Redirect(http.StatusTemporaryRedirect, cfg.FrontendUri+"/dashboard")
}) })
api.GET("/auth/me", session.SessionMiddleware(cfg.JWTSecret), func(c *gin.Context) { api.GET("/auth/me", session.SessionMiddleware(cfg.JWTSecret), func(c *gin.Context) {
@ -80,7 +130,7 @@ func RegisterAuthRoutes(cfg *types.StereoConfig, api *gin.RouterGroup) {
user, ok := claims["user"].(auth.User) user, ok := claims["user"].(auth.User)
if !ok { if !ok {
types.ErrorUserNotFound.Throw(c, errors.New(fmt.Sprintf("got data with type %T but wanted claims.User", claims["user"]))) types.ErrorUserNotFound.Throw(c, fmt.Errorf("got data with type %T but wanted claims.User", claims["user"]))
return return
} }

View file

@ -1,18 +1,18 @@
/* /*
Copyright (C) 2025 hexlocation (hex@iwakura.rip) & grngxd (grng@iwakura.rip) Copyright (C) 2025 hexlocation (hex@iwakura.rip) & grngxd (grng@iwakura.rip)
This program is free software: you can redistribute it and/or modify This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or the Free Software Foundation, either version 3 of the License, or
(at your option) any later version. (at your option) any later version.
This program is distributed in the hope that it will be useful, This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details. GNU General Public License for more details.
You should have received a copy of the GNU General Public License You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>. along with this program. If not, see <https://www.gnu.org/licenses/>.
*/ */
package routes package routes
@ -20,6 +20,7 @@ package routes
import ( import (
"bytes" "bytes"
"io" "io"
"strconv"
"strings" "strings"
"time" "time"
@ -179,8 +180,38 @@ func RegisterFileRoutes(cfg *types.StereoConfig, api *gin.RouterGroup) {
claims := c.MustGet("claims").(jwt.MapClaims) claims := c.MustGet("claims").(jwt.MapClaims)
user := claims["user"].(auth.User) user := claims["user"].(auth.User)
if c.Query("page") == "" || c.Query("size") == "" {
var files []types.File
if err := cfg.Database.Where("owner = ?", user.ID).Find(&files).Error; err != nil {
types.ErrorDatabase.Throw(c, err)
return
}
c.JSON(200, files)
return
}
page := c.Query("page")
size := c.Query("size")
pageNum, err := strconv.Atoi(page)
if err != nil || pageNum < 0 {
types.ErrorInvalidParams.Throw(c, err)
return
}
sizeNum, err := strconv.Atoi(size)
if err != nil || sizeNum <= 0 {
types.ErrorInvalidParams.Throw(c, err)
return
}
var files []types.File var files []types.File
if err := cfg.Database.Where("owner = ?", user.ID).Find(&files).Error; err != nil { offset := (pageNum - 1) * sizeNum
if offset < 0 {
offset = 0
}
if err := cfg.Database.Where("owner = ?", user.ID).Offset(offset).Limit(sizeNum).Find(&files).Error; err != nil {
types.ErrorDatabase.Throw(c, err) types.ErrorDatabase.Throw(c, err)
return return
} }