package auth

import (
	"encoding/json"
	"fmt"
	"net/http"
	"strings"

	"github.com/gin-gonic/gin"
	"github.com/golang-jwt/jwt/v5"
)

// GenerateJWT builds a token whose claims carry user and expiryTimestamp and
// signs it with HMAC-SHA256 (HS256), using key as the shared secret. It
// returns the compact serialized token, or the error reported by signing.
//
// NOTE(review): expiryTimestamp is presumably a Unix time in seconds (the JWT
// "exp" NumericDate convention) — confirm against the Claims definition.
func GenerateJWT(key string, user User, expiryTimestamp uint64) (string, error) {
	secret := []byte(key)
	unsigned := jwt.NewWithClaims(jwt.SigningMethodHS256, Claims{
		Exp:  expiryTimestamp,
		User: user,
	})
	return unsigned.SignedString(secret)
}

// invalidAuth rejects the current request: it writes a 401 response with the
// plain-text body "Unauthorized." and aborts the gin handler chain so that no
// later handlers run for this request.
func invalidAuth(c *gin.Context) {
	const body = "Unauthorized."
	status := http.StatusUnauthorized
	c.String(status, body)
	c.Abort()
}

// JwtMiddleware returns a gin handler that authenticates requests carrying an
// "Authorization: Bearer <token>" header.
//
// The token is verified with ValidateJWT using secret. If the claims contain a
// "user" entry, it must be a JSON object; it is decoded into a User value and
// stored back into the claims map under "user". On success the claims are
// stored in the gin context under the key "claims" and the chain continues.
// Any failure results in a 401 via invalidAuth. Failures include:
//   - a missing or malformed header
//   - a token rejected by ValidateJWT
//   - a "user" claim that is not an object or cannot be decoded
func JwtMiddleware(secret string) gin.HandlerFunc {
	return func(c *gin.Context) {
		rawToken, ok := parseBearerToken(c.GetHeader("Authorization"))
		if !ok {
			invalidAuth(c)
			return
		}

		claims, err := ValidateJWT(rawToken, secret)
		if err != nil {
			invalidAuth(c)
			return
		}

		if rawUser, present := claims["user"]; present {
			userObj, isObject := rawUser.(map[string]any)
			if !isObject {
				// A "user" claim that is not a JSON object cannot be a User;
				// reject instead of handing an unexpected type downstream.
				invalidAuth(c)
				return
			}

			user, decodeErr := decodeUserClaim(userObj)
			if decodeErr != nil {
				invalidAuth(c)
				return
			}
			claims["user"] = user
		}

		c.Set("claims", claims)
		c.Next()
	}
}

// parseBearerToken extracts the credentials from an Authorization header of
// the form "Bearer <token>". The scheme is matched case-insensitively, since
// RFC 7235 §2.1 defines auth schemes as case-insensitive. Whitespace around
// the token is trimmed. It reports false when the scheme is not Bearer or the
// token is empty.
func parseBearerToken(header string) (string, bool) {
	scheme, token, found := strings.Cut(header, " ")
	if !found || !strings.EqualFold(scheme, "Bearer") {
		return "", false
	}
	token = strings.TrimSpace(token)
	if token == "" {
		return "", false
	}
	return token, true
}

// decodeUserClaim converts the generic JSON object stored under the "user"
// claim into a typed User by round-tripping it through encoding/json. On error
// it returns the zero User.
func decodeUserClaim(obj map[string]any) (User, error) {
	data, err := json.Marshal(obj)
	if err != nil {
		return User{}, err
	}
	var user User
	if err := json.Unmarshal(data, &user); err != nil {
		return User{}, err
	}
	return user, nil
}

// ValidateJWT parses jwtString, verifies its signature with key, and returns
// the token's claims as a jwt.MapClaims.
//
// Only HMAC signing methods are accepted. Any other "alg" header is rejected
// before the key is handed to the verifier, which guards against
// algorithm-confusion attacks. An empty key is rejected outright: with an
// empty HMAC secret, anyone could mint tokens that verify.
//
// Errors from jwt.Parse (malformed token, bad signature, or failures from
// the library's default claim validation) are wrapped with %w, so callers can
// still match them with errors.Is.
func ValidateJWT(jwtString, key string) (jwt.MapClaims, error) {
	if key == "" {
		return nil, fmt.Errorf("empty verification key")
	}

	token, err := jwt.Parse(jwtString, func(token *jwt.Token) (any, error) {
		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, fmt.Errorf("unexpected signing method %v", token.Header["alg"])
		}
		return []byte(key), nil
	})
	if err != nil {
		return nil, fmt.Errorf("parsing token: %w", err)
	}

	claims, ok := token.Claims.(jwt.MapClaims)
	if !ok || !token.Valid {
		return nil, fmt.Errorf("invalid token")
	}
	return claims, nil
}