Compare commits

...

5 commits

Author SHA1 Message Date
g
04284acc13 Merge pull request 'add state validation to oauth flow' (#8) from state into dev
Reviewed-on: #8
2025-07-31 10:30:40 +00:00
grngxd
8ca089ecfb use array instead of map 2025-07-31 11:07:46 +01:00
grngxd
b906736af8 logging out & fix state (?) 2025-07-31 10:58:25 +01:00
grngxd
96320c3cc4 clean up callback to dashboard url 😭 2025-07-30 11:17:23 +01:00
grngxd
6a08afbf52 add state validation to oauth flow 2025-07-30 11:12:22 +01:00
2 changed files with 115 additions and 34 deletions

View file

@ -1,26 +1,29 @@
/*
Copyright (C) 2025 hexlocation (hex@iwakura.rip) & grngxd (grng@iwakura.rip)
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package routes
import (
"errors"
"crypto/rand"
"encoding/base64"
"fmt"
"net/http"
"net/url"
"sync"
"time"
"github.com/gin-gonic/gin"
@ -31,42 +34,89 @@ import (
"stereo.cat/backend/internal/types"
)
var states []string
var statesMutex sync.Mutex
func generateState(length int) (string, error) {
b := make([]byte, length)
_, err := rand.Read(b)
if err != nil {
return "", err
}
return base64.URLEncoding.EncodeToString(b), nil
}
func RegisterAuthRoutes(cfg *types.StereoConfig, api *gin.RouterGroup) {
api.GET("/auth/login", func(c *gin.Context) {
state, err := generateState(32)
if err != nil {
c.AbortWithStatus(http.StatusInternalServerError)
return
}
statesMutex.Lock()
states = append(states, state)
statesMutex.Unlock()
discordURL := fmt.Sprintf(
"https://discord.com/oauth2/authorize?client_id=%s&response_type=code&redirect_uri=%s&scope=identify%%20email&state=%s",
cfg.Client.ClientId,
url.QueryEscape(cfg.Client.RedirectUri),
state,
)
c.Redirect(http.StatusTemporaryRedirect, discordURL)
})
api.GET("/auth/logout", session.SessionMiddleware(cfg.JWTSecret), func(c *gin.Context) {
c.SetCookie("jwt", "", -1, "", cfg.Domain, true, true)
c.Redirect(http.StatusTemporaryRedirect, cfg.FrontendUri)
})
api.GET("/auth/callback", func(c *gin.Context) {
code := c.Query("code")
state := c.Query("state")
statesMutex.Lock()
found := false
for i, s := range states {
if s == state {
states = append(states[:i], states[i+1:]...)
found = true
break
}
}
statesMutex.Unlock()
if !found {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "Invalid state"})
return
}
t, err := cfg.Client.ExchangeCode(code)
if err != nil {
panic(err)
}
user, err := cfg.Client.GetUser(t)
if err != nil {
panic(err)
}
jwt, err := session.GenerateSessionJWT(cfg.JWTSecret, user, uint64(time.Now().Add(time.Second*time.Duration(t.ExpiresIn)).Unix()))
if err != nil {
panic(err)
}
res := cfg.Database.FirstOrCreate(&user)
if res.Error != nil {
panic(res.Error)
}
// TODO: redirect to dashboard
/*c.JSON(http.StatusOK, gin.H{
"jwt": jwt,
"known": res.RowsAffected == 0,
})
*/
c.SetCookie("jwt", jwt, int(t.ExpiresIn), "", cfg.Domain, true, true)
c.Redirect(http.StatusTemporaryRedirect, cfg.FrontendUri+"?jwt_set=true")
c.Redirect(http.StatusTemporaryRedirect, cfg.FrontendUri+"/dashboard")
})
api.GET("/auth/me", session.SessionMiddleware(cfg.JWTSecret), func(c *gin.Context) {
@ -80,7 +130,7 @@ func RegisterAuthRoutes(cfg *types.StereoConfig, api *gin.RouterGroup) {
user, ok := claims["user"].(auth.User)
if !ok {
types.ErrorUserNotFound.Throw(c, errors.New(fmt.Sprintf("got data with type %T but wanted claims.User", claims["user"])))
types.ErrorUserNotFound.Throw(c, fmt.Errorf("got data with type %T but wanted claims.User", claims["user"]))
return
}

View file

@ -1,18 +1,18 @@
/*
Copyright (C) 2025 hexlocation (hex@iwakura.rip) & grngxd (grng@iwakura.rip)
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package routes
@ -20,6 +20,7 @@ package routes
import (
"bytes"
"io"
"strconv"
"strings"
"time"
@ -179,8 +180,38 @@ func RegisterFileRoutes(cfg *types.StereoConfig, api *gin.RouterGroup) {
claims := c.MustGet("claims").(jwt.MapClaims)
user := claims["user"].(auth.User)
if c.Query("page") == "" || c.Query("size") == "" {
var files []types.File
if err := cfg.Database.Where("owner = ?", user.ID).Find(&files).Error; err != nil {
types.ErrorDatabase.Throw(c, err)
return
}
c.JSON(200, files)
return
}
page := c.Query("page")
size := c.Query("size")
pageNum, err := strconv.Atoi(page)
if err != nil || pageNum < 0 {
types.ErrorInvalidParams.Throw(c, err)
return
}
sizeNum, err := strconv.Atoi(size)
if err != nil || sizeNum <= 0 {
types.ErrorInvalidParams.Throw(c, err)
return
}
var files []types.File
if err := cfg.Database.Where("owner = ?", user.ID).Find(&files).Error; err != nil {
offset := (pageNum - 1) * sizeNum
if offset < 0 {
offset = 0
}
if err := cfg.Database.Where("owner = ?", user.ID).Offset(offset).Limit(sizeNum).Find(&files).Error; err != nil {
types.ErrorDatabase.Throw(c, err)
return
}