mirror of
https://github.com/TracksApp/tracks.git
synced 2025-12-18 08:10:13 +01:00
182 lines
4.2 KiB
Go
182 lines
4.2 KiB
Go
|
|
package middleware
|
||
|
|
|
||
|
|
import (
|
||
|
|
"fmt"
|
||
|
|
"net/http"
|
||
|
|
"strings"
|
||
|
|
|
||
|
|
"github.com/TracksApp/tracks/internal/database"
|
||
|
|
"github.com/TracksApp/tracks/internal/models"
|
||
|
|
"github.com/gin-gonic/gin"
|
||
|
|
"github.com/golang-jwt/jwt/v5"
|
||
|
|
)
|
||
|
|
|
||
|
|
// Claims is the JWT payload that AuthMiddleware and OptionalAuthMiddleware
// parse and validate.
//
// UserID and Login are application-specific claims, serialized as "user_id"
// and "login". The embedded jwt.RegisteredClaims supplies the standard
// registered claims and the methods through which *Claims satisfies the
// jwt.Claims interface required by jwt.ParseWithClaims.
//
// Field order is significant for the JSON encoding of issued tokens; do not
// reorder casually.
type Claims struct {
	// UserID is the primary key the middleware uses to load the models.User.
	UserID uint `json:"user_id"`
	// Login carries the user's login name; the middleware in this file does
	// not read it.
	Login string `json:"login"`
	jwt.RegisteredClaims
}
|
||
|
|
|
||
|
|
// AuthMiddleware validates JWT tokens and sets the current user
|
||
|
|
func AuthMiddleware(jwtSecret string) gin.HandlerFunc {
|
||
|
|
return func(c *gin.Context) {
|
||
|
|
// Try to get token from Authorization header
|
||
|
|
authHeader := c.GetHeader("Authorization")
|
||
|
|
var tokenString string
|
||
|
|
|
||
|
|
if authHeader != "" {
|
||
|
|
// Bearer token
|
||
|
|
parts := strings.Split(authHeader, " ")
|
||
|
|
if len(parts) == 2 && parts[0] == "Bearer" {
|
||
|
|
tokenString = parts[1]
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// If no Bearer token, try cookie
|
||
|
|
if tokenString == "" {
|
||
|
|
cookie, err := c.Cookie("tracks_token")
|
||
|
|
if err == nil {
|
||
|
|
tokenString = cookie
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// If still no token, try query parameter (for feed tokens)
|
||
|
|
if tokenString == "" {
|
||
|
|
tokenString = c.Query("token")
|
||
|
|
}
|
||
|
|
|
||
|
|
if tokenString == "" {
|
||
|
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "No authentication token provided"})
|
||
|
|
c.Abort()
|
||
|
|
return
|
||
|
|
}
|
||
|
|
|
||
|
|
// Parse and validate token
|
||
|
|
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
|
||
|
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||
|
|
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||
|
|
}
|
||
|
|
return []byte(jwtSecret), nil
|
||
|
|
})
|
||
|
|
|
||
|
|
if err != nil || !token.Valid {
|
||
|
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid or expired token"})
|
||
|
|
c.Abort()
|
||
|
|
return
|
||
|
|
}
|
||
|
|
|
||
|
|
claims, ok := token.Claims.(*Claims)
|
||
|
|
if !ok {
|
||
|
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token claims"})
|
||
|
|
c.Abort()
|
||
|
|
return
|
||
|
|
}
|
||
|
|
|
||
|
|
// Load user from database
|
||
|
|
var user models.User
|
||
|
|
if err := database.DB.First(&user, claims.UserID).Error; err != nil {
|
||
|
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "User not found"})
|
||
|
|
c.Abort()
|
||
|
|
return
|
||
|
|
}
|
||
|
|
|
||
|
|
// Set user in context
|
||
|
|
c.Set("user", &user)
|
||
|
|
c.Set("user_id", user.ID)
|
||
|
|
|
||
|
|
c.Next()
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// OptionalAuthMiddleware attempts to authenticate but doesn't fail if no token
|
||
|
|
func OptionalAuthMiddleware(jwtSecret string) gin.HandlerFunc {
|
||
|
|
return func(c *gin.Context) {
|
||
|
|
authHeader := c.GetHeader("Authorization")
|
||
|
|
var tokenString string
|
||
|
|
|
||
|
|
if authHeader != "" {
|
||
|
|
parts := strings.Split(authHeader, " ")
|
||
|
|
if len(parts) == 2 && parts[0] == "Bearer" {
|
||
|
|
tokenString = parts[1]
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
if tokenString == "" {
|
||
|
|
cookie, err := c.Cookie("tracks_token")
|
||
|
|
if err == nil {
|
||
|
|
tokenString = cookie
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
if tokenString != "" {
|
||
|
|
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
|
||
|
|
return []byte(jwtSecret), nil
|
||
|
|
})
|
||
|
|
|
||
|
|
if err == nil && token.Valid {
|
||
|
|
if claims, ok := token.Claims.(*Claims); ok {
|
||
|
|
var user models.User
|
||
|
|
if err := database.DB.First(&user, claims.UserID).Error; err == nil {
|
||
|
|
c.Set("user", &user)
|
||
|
|
c.Set("user_id", user.ID)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
c.Next()
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// AdminMiddleware ensures the user is an admin
|
||
|
|
func AdminMiddleware() gin.HandlerFunc {
|
||
|
|
return func(c *gin.Context) {
|
||
|
|
userInterface, exists := c.Get("user")
|
||
|
|
if !exists {
|
||
|
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||
|
|
c.Abort()
|
||
|
|
return
|
||
|
|
}
|
||
|
|
|
||
|
|
user, ok := userInterface.(*models.User)
|
||
|
|
if !ok || !user.IsAdmin {
|
||
|
|
c.JSON(http.StatusForbidden, gin.H{"error": "Admin access required"})
|
||
|
|
c.Abort()
|
||
|
|
return
|
||
|
|
}
|
||
|
|
|
||
|
|
c.Next()
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// GetCurrentUser retrieves the current user from the context
|
||
|
|
func GetCurrentUser(c *gin.Context) (*models.User, error) {
|
||
|
|
userInterface, exists := c.Get("user")
|
||
|
|
if !exists {
|
||
|
|
return nil, fmt.Errorf("user not found in context")
|
||
|
|
}
|
||
|
|
|
||
|
|
user, ok := userInterface.(*models.User)
|
||
|
|
if !ok {
|
||
|
|
return nil, fmt.Errorf("invalid user type in context")
|
||
|
|
}
|
||
|
|
|
||
|
|
return user, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// GetCurrentUserID retrieves the current user ID from the context
|
||
|
|
func GetCurrentUserID(c *gin.Context) (uint, error) {
|
||
|
|
userIDInterface, exists := c.Get("user_id")
|
||
|
|
if !exists {
|
||
|
|
return 0, fmt.Errorf("user ID not found in context")
|
||
|
|
}
|
||
|
|
|
||
|
|
userID, ok := userIDInterface.(uint)
|
||
|
|
if !ok {
|
||
|
|
return 0, fmt.Errorf("invalid user ID type in context")
|
||
|
|
}
|
||
|
|
|
||
|
|
return userID, nil
|
||
|
|
}
|