// Mirror of https://github.com/Dvorinka/Containr.git
// Synced 2026-06-03 20:12:58 +00:00 · 368 lines · 9.0 KiB · Go

package middleware
|
|
|
|
import (
|
|
"containr/internal/database"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"log"
|
|
"net/http"
|
|
"os"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/golang-jwt/jwt/v5"
|
|
"github.com/google/uuid"
|
|
"golang.org/x/crypto/bcrypt"
|
|
)
|
|
|
|
// Logger middleware
|
|
func Logger() gin.HandlerFunc {
|
|
return gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
|
|
return fmt.Sprintf("%s - [%s] \"%s %s %s %d %s \"%s\" %s\"\n",
|
|
param.ClientIP,
|
|
param.TimeStamp.Format(time.RFC1123),
|
|
param.Method,
|
|
param.Path,
|
|
param.Request.Proto,
|
|
param.StatusCode,
|
|
param.Latency,
|
|
param.Request.UserAgent(),
|
|
param.ErrorMessage,
|
|
)
|
|
})
|
|
}
|
|
|
|
// Recovery middleware
|
|
func Recovery() gin.HandlerFunc {
|
|
return gin.Recovery()
|
|
}
|
|
|
|
// SecurityHeaders adds secure default HTTP response headers.
|
|
func SecurityHeaders() gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
c.Header("X-Content-Type-Options", "nosniff")
|
|
c.Header("X-Frame-Options", "DENY")
|
|
c.Header("Referrer-Policy", "strict-origin-when-cross-origin")
|
|
c.Header("X-XSS-Protection", "1; mode=block")
|
|
|
|
if strings.EqualFold(c.Request.Header.Get("X-Forwarded-Proto"), "https") || c.Request.TLS != nil {
|
|
c.Header("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
|
|
}
|
|
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
// RequestID middleware adds a unique request ID to each request
|
|
func RequestID() gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
requestID := c.GetHeader("X-Request-ID")
|
|
if requestID == "" {
|
|
requestID = uuid.New().String()
|
|
}
|
|
c.Set("request_id", requestID)
|
|
c.Header("X-Request-ID", requestID)
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
// Auth middleware for JWT authentication
|
|
func Auth(jwtSecret string) gin.HandlerFunc {
|
|
sessionVerifier := newBetterAuthSessionVerifier()
|
|
|
|
return func(c *gin.Context) {
|
|
tokenString, tokenErr, hasToken := extractJWTToken(c)
|
|
if tokenString != "" {
|
|
if claims, valid := validateJWTClaims(tokenString, jwtSecret); valid {
|
|
userIDClaim, exists := claims["user_id"]
|
|
if exists {
|
|
userID := strings.TrimSpace(fmt.Sprint(userIDClaim))
|
|
if _, err := uuid.Parse(userID); err == nil {
|
|
email := ""
|
|
if emailClaim, ok := claims["email"]; ok && emailClaim != nil {
|
|
email = strings.TrimSpace(fmt.Sprint(emailClaim))
|
|
}
|
|
c.Set("user_id", userID)
|
|
c.Set("email", email)
|
|
c.Next()
|
|
return
|
|
}
|
|
}
|
|
tokenErr = "Invalid token claims"
|
|
} else if tokenErr == "" {
|
|
tokenErr = "Invalid token"
|
|
}
|
|
}
|
|
|
|
if sessionVerifier != nil {
|
|
if userID, email, ok := sessionVerifier.resolveUser(c); ok {
|
|
c.Set("user_id", userID)
|
|
c.Set("email", email)
|
|
c.Next()
|
|
return
|
|
}
|
|
}
|
|
|
|
if tokenErr != "" {
|
|
c.JSON(http.StatusUnauthorized, gin.H{"error": tokenErr})
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
if hasToken {
|
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token"})
|
|
} else {
|
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "Authentication required"})
|
|
}
|
|
c.Abort()
|
|
}
|
|
}
|
|
|
|
type (
	// betterAuthSessionVerifier holds what is needed to ask the Better Auth
	// internal endpoint about the session attached to a request.
	betterAuthSessionVerifier struct {
		// internalURL is the endpoint queried with a GET request.
		internalURL string
		// internalToken is sent in the X-Containr-Auth-Internal header.
		internalToken string
		// client performs the outbound HTTP call.
		client *http.Client
	}

	// betterAuthSessionUser is the user object inside a session response.
	betterAuthSessionUser struct {
		ID    string `json:"id"`
		Email string `json:"email"`
		Name  string `json:"name"`
		Image string `json:"image"`
	}

	// betterAuthSessionResponse is the JSON document returned by the session
	// endpoint.
	betterAuthSessionResponse struct {
		Authenticated bool                  `json:"authenticated"`
		User          betterAuthSessionUser `json:"user"`
	}
)
|
|
|
|
func newBetterAuthSessionVerifier() *betterAuthSessionVerifier {
|
|
internalURL := strings.TrimSpace(os.Getenv("BETTER_AUTH_INTERNAL_URL"))
|
|
if internalURL == "" {
|
|
internalURL = "http://127.0.0.1:3001/internal/session"
|
|
}
|
|
internalToken := strings.TrimSpace(os.Getenv("BETTER_AUTH_INTERNAL_TOKEN"))
|
|
|
|
if internalToken == "" {
|
|
return nil
|
|
}
|
|
|
|
return &betterAuthSessionVerifier{
|
|
internalURL: internalURL,
|
|
internalToken: internalToken,
|
|
client: &http.Client{
|
|
Timeout: 3 * time.Second,
|
|
},
|
|
}
|
|
}
|
|
|
|
func (v *betterAuthSessionVerifier) resolveUser(c *gin.Context) (string, string, bool) {
|
|
request, err := http.NewRequestWithContext(c.Request.Context(), http.MethodGet, v.internalURL, nil)
|
|
if err != nil {
|
|
return "", "", false
|
|
}
|
|
|
|
request.Header.Set("X-Containr-Auth-Internal", v.internalToken)
|
|
copyHeaderIfPresent(c.Request, request, "Cookie")
|
|
copyHeaderIfPresent(c.Request, request, "User-Agent")
|
|
copyHeaderIfPresent(c.Request, request, "X-Forwarded-For")
|
|
copyHeaderIfPresent(c.Request, request, "X-Forwarded-Proto")
|
|
|
|
response, err := v.client.Do(request)
|
|
if err != nil {
|
|
return "", "", false
|
|
}
|
|
defer response.Body.Close()
|
|
|
|
if response.StatusCode != http.StatusOK {
|
|
return "", "", false
|
|
}
|
|
|
|
var payload betterAuthSessionResponse
|
|
if err := json.NewDecoder(response.Body).Decode(&payload); err != nil {
|
|
return "", "", false
|
|
}
|
|
|
|
if !payload.Authenticated {
|
|
return "", "", false
|
|
}
|
|
|
|
email := strings.ToLower(strings.TrimSpace(payload.User.Email))
|
|
if email == "" {
|
|
return "", "", false
|
|
}
|
|
|
|
localUserID, err := ensureLocalUserRecord(c, payload.User)
|
|
if err != nil {
|
|
log.Printf("Failed to map Better Auth user to local user: %v", err)
|
|
return "", "", false
|
|
}
|
|
|
|
return localUserID, email, true
|
|
}
|
|
|
|
func ensureLocalUserRecord(c *gin.Context, user betterAuthSessionUser) (string, error) {
|
|
dbValue, ok := c.Get("db")
|
|
if !ok || dbValue == nil {
|
|
return "", fmt.Errorf("database context missing")
|
|
}
|
|
|
|
db, ok := dbValue.(*database.DB)
|
|
if !ok || db == nil {
|
|
return "", fmt.Errorf("invalid database context")
|
|
}
|
|
|
|
email := strings.ToLower(strings.TrimSpace(user.Email))
|
|
if email == "" {
|
|
return "", fmt.Errorf("email is required")
|
|
}
|
|
|
|
name := strings.TrimSpace(user.Name)
|
|
if name == "" {
|
|
name = "Containr User"
|
|
}
|
|
|
|
avatarURL := strings.TrimSpace(user.Image)
|
|
|
|
var localUserID string
|
|
err := db.QueryRow(`SELECT id FROM users WHERE email = $1`, email).Scan(&localUserID)
|
|
switch {
|
|
case err == nil:
|
|
_, _ = db.Exec(`
|
|
UPDATE users
|
|
SET name = $1,
|
|
avatar_url = CASE WHEN $2 = '' THEN avatar_url ELSE $2 END,
|
|
updated_at = NOW()
|
|
WHERE id = $3
|
|
`, name, avatarURL, localUserID)
|
|
return localUserID, nil
|
|
case !errors.Is(err, sql.ErrNoRows):
|
|
return "", err
|
|
}
|
|
|
|
hashedPassword, hashErr := bcrypt.GenerateFromPassword([]byte(uuid.NewString()), bcrypt.DefaultCost)
|
|
if hashErr != nil {
|
|
return "", hashErr
|
|
}
|
|
|
|
if err := db.QueryRow(`
|
|
INSERT INTO users (email, password_hash, name, avatar_url)
|
|
VALUES ($1, $2, $3, NULLIF($4, ''))
|
|
RETURNING id
|
|
`, email, string(hashedPassword), name, avatarURL).Scan(&localUserID); err != nil {
|
|
return "", err
|
|
}
|
|
|
|
return localUserID, nil
|
|
}
|
|
|
|
func validateJWTClaims(tokenString, jwtSecret string) (jwt.MapClaims, bool) {
|
|
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
|
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
|
}
|
|
return []byte(jwtSecret), nil
|
|
})
|
|
|
|
if err != nil {
|
|
return nil, false
|
|
}
|
|
|
|
claims, ok := token.Claims.(jwt.MapClaims)
|
|
if !ok || !token.Valid {
|
|
return nil, false
|
|
}
|
|
|
|
return claims, true
|
|
}
|
|
|
|
func extractJWTToken(c *gin.Context) (string, string, bool) {
|
|
authHeader := strings.TrimSpace(c.GetHeader("Authorization"))
|
|
if authHeader != "" {
|
|
parts := strings.Fields(authHeader)
|
|
if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") {
|
|
return "", "Invalid authorization header format", true
|
|
}
|
|
return strings.TrimSpace(parts[1]), "", true
|
|
}
|
|
|
|
if strings.EqualFold(c.GetHeader("Upgrade"), "websocket") {
|
|
token := strings.TrimSpace(c.Query("token"))
|
|
if token == "" {
|
|
return "", "", false
|
|
}
|
|
return token, "", true
|
|
}
|
|
|
|
return "", "", false
|
|
}
|
|
|
|
// copyHeaderIfPresent copies the header named key from src to dst.
//
// Every value of the header is forwarded, so a header that arrived on several
// lines (for example X-Forwarded-For) is not truncated to its first entry.
// Values are trimmed and blank ones skipped. When at least one value remains
// it replaces whatever dst already held under key; otherwise dst is left
// untouched.
func copyHeaderIfPresent(src *http.Request, dst *http.Request, key string) {
	var values []string
	for _, raw := range src.Header.Values(key) {
		if value := strings.TrimSpace(raw); value != "" {
			values = append(values, value)
		}
	}
	if len(values) == 0 {
		return
	}

	dst.Header.Del(key)
	for _, value := range values {
		dst.Header.Add(key, value)
	}
}
|
|
|
|
// ErrorHandler middleware for consistent error handling
|
|
func ErrorHandler() gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
c.Next()
|
|
|
|
// Check if there are any errors
|
|
if len(c.Errors) > 0 {
|
|
err := c.Errors.Last()
|
|
log.Printf("Request error: %v", err)
|
|
|
|
// Return JSON error response
|
|
c.JSON(http.StatusInternalServerError, gin.H{
|
|
"error": "Internal server error",
|
|
"code": "INTERNAL_ERROR",
|
|
})
|
|
}
|
|
}
|
|
}
|
|
|
|
// CORSMiddleware for CORS handling
|
|
func CORSMiddleware() gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
c.Header("Access-Control-Allow-Origin", "*")
|
|
c.Header("Access-Control-Allow-Credentials", "true")
|
|
c.Header("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With")
|
|
c.Header("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE")
|
|
|
|
if c.Request.Method == "OPTIONS" {
|
|
c.AbortWithStatus(204)
|
|
return
|
|
}
|
|
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
// RequestBodyLimit enforces a maximum HTTP request body size.
|
|
func RequestBodyLimit(maxBytes int64) gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
if maxBytes <= 0 {
|
|
c.Next()
|
|
return
|
|
}
|
|
|
|
if c.Request.ContentLength > maxBytes {
|
|
c.AbortWithStatusJSON(http.StatusRequestEntityTooLarge, gin.H{
|
|
"error": "Request body too large",
|
|
})
|
|
return
|
|
}
|
|
|
|
if c.Request.Body != nil {
|
|
c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, maxBytes)
|
|
}
|
|
|
|
c.Next()
|
|
}
|
|
}
|