69 lines
1.3 KiB
Go
Raw Normal View History

package middleware
import (
"strings"
"github.com/gofiber/fiber/v2"
"rul.sh/vaulterm/db"
"rul.sh/vaulterm/models"
)
func Auth(c *fiber.Ctx) error {
authHeader := c.Get("Authorization")
var sessionId string
if authHeader != "" {
sessionId = strings.Split(authHeader, " ")[1]
}
if strings.HasPrefix(c.Path(), "/ws") && sessionId == "" {
sessionId = c.Query("sid")
}
session, _ := GetUserSession(sessionId)
2024-11-12 17:17:10 +00:00
if session != nil && session.ID != "" {
c.Locals("user", session)
c.Locals("sessionId", sessionId)
}
return c.Next()
}
2024-11-12 17:17:10 +00:00
// AuthUser is the user record resolved for an authenticated request. Auth
// stores a *AuthUser under the "user" request local.
//
// It embeds models.User and adds the id of the user_sessions row that matched
// the request. GetUserSession populates it from the column selected as
// "user_sessions.id AS session_id".
type AuthUser struct {
models.User
// SessionID maps to the session_id column selected by GetUserSession and is
// serialized to JSON as "sessionId".
SessionID string `json:"sessionId" gorm:"column:session_id"`
}
func GetUserSession(sessionId string) (*AuthUser, error) {
var session AuthUser
2024-11-12 19:15:13 +07:00
res := db.Get().
2024-11-12 17:17:10 +00:00
Model(&models.User{}).
Joins("JOIN user_sessions ON user_sessions.user_id = users.id").
Preload("Teams.Team").
Select("users.*, user_sessions.id AS session_id").
2024-11-12 19:15:13 +07:00
Where("user_sessions.id = ?", sessionId).
First(&session)
2024-11-12 17:17:10 +00:00
if res.Error != nil || session.User.ID == "" {
return nil, res.Error
}
return &session, nil
}
// Protected returns a Fiber handler that only lets the request through when
// an earlier middleware (Auth) has stored a user under the "user" local.
// Otherwise it fails the request with a 401 "Unauthorized" *fiber.Error.
func Protected() func(c *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
		if c.Locals("user") != nil {
			return c.Next()
		}
		return fiber.NewError(fiber.StatusUnauthorized, "Unauthorized")
	}
}