mirror of
https://github.com/Dvorinka/Containr.git
synced 2026-06-04 04:22:57 +00:00
small fix, don't worry about it
This commit is contained in:
@@ -0,0 +1,367 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"containr/internal/database"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
// Logger middleware
|
||||
func Logger() gin.HandlerFunc {
|
||||
return gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
|
||||
return fmt.Sprintf("%s - [%s] \"%s %s %s %d %s \"%s\" %s\"\n",
|
||||
param.ClientIP,
|
||||
param.TimeStamp.Format(time.RFC1123),
|
||||
param.Method,
|
||||
param.Path,
|
||||
param.Request.Proto,
|
||||
param.StatusCode,
|
||||
param.Latency,
|
||||
param.Request.UserAgent(),
|
||||
param.ErrorMessage,
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
// Recovery middleware
|
||||
func Recovery() gin.HandlerFunc {
|
||||
return gin.Recovery()
|
||||
}
|
||||
|
||||
// SecurityHeaders adds secure default HTTP response headers.
|
||||
func SecurityHeaders() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.Header("X-Content-Type-Options", "nosniff")
|
||||
c.Header("X-Frame-Options", "DENY")
|
||||
c.Header("Referrer-Policy", "strict-origin-when-cross-origin")
|
||||
c.Header("X-XSS-Protection", "1; mode=block")
|
||||
|
||||
if strings.EqualFold(c.Request.Header.Get("X-Forwarded-Proto"), "https") || c.Request.TLS != nil {
|
||||
c.Header("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// RequestID middleware adds a unique request ID to each request
|
||||
func RequestID() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
requestID := c.GetHeader("X-Request-ID")
|
||||
if requestID == "" {
|
||||
requestID = uuid.New().String()
|
||||
}
|
||||
c.Set("request_id", requestID)
|
||||
c.Header("X-Request-ID", requestID)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// Auth middleware for JWT authentication
|
||||
func Auth(jwtSecret string) gin.HandlerFunc {
|
||||
sessionVerifier := newBetterAuthSessionVerifier()
|
||||
|
||||
return func(c *gin.Context) {
|
||||
tokenString, tokenErr, hasToken := extractJWTToken(c)
|
||||
if tokenString != "" {
|
||||
if claims, valid := validateJWTClaims(tokenString, jwtSecret); valid {
|
||||
userIDClaim, exists := claims["user_id"]
|
||||
if exists {
|
||||
userID := strings.TrimSpace(fmt.Sprint(userIDClaim))
|
||||
if _, err := uuid.Parse(userID); err == nil {
|
||||
email := ""
|
||||
if emailClaim, ok := claims["email"]; ok && emailClaim != nil {
|
||||
email = strings.TrimSpace(fmt.Sprint(emailClaim))
|
||||
}
|
||||
c.Set("user_id", userID)
|
||||
c.Set("email", email)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
}
|
||||
tokenErr = "Invalid token claims"
|
||||
} else if tokenErr == "" {
|
||||
tokenErr = "Invalid token"
|
||||
}
|
||||
}
|
||||
|
||||
if sessionVerifier != nil {
|
||||
if userID, email, ok := sessionVerifier.resolveUser(c); ok {
|
||||
c.Set("user_id", userID)
|
||||
c.Set("email", email)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if tokenErr != "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": tokenErr})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
if hasToken {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token"})
|
||||
} else {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Authentication required"})
|
||||
}
|
||||
c.Abort()
|
||||
}
|
||||
}
|
||||
|
||||
type betterAuthSessionVerifier struct {
|
||||
internalURL string
|
||||
internalToken string
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
type betterAuthSessionUser struct {
|
||||
ID string `json:"id"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
Image string `json:"image"`
|
||||
}
|
||||
|
||||
type betterAuthSessionResponse struct {
|
||||
Authenticated bool `json:"authenticated"`
|
||||
User betterAuthSessionUser `json:"user"`
|
||||
}
|
||||
|
||||
func newBetterAuthSessionVerifier() *betterAuthSessionVerifier {
|
||||
internalURL := strings.TrimSpace(os.Getenv("BETTER_AUTH_INTERNAL_URL"))
|
||||
if internalURL == "" {
|
||||
internalURL = "http://127.0.0.1:3001/internal/session"
|
||||
}
|
||||
internalToken := strings.TrimSpace(os.Getenv("BETTER_AUTH_INTERNAL_TOKEN"))
|
||||
|
||||
if internalToken == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &betterAuthSessionVerifier{
|
||||
internalURL: internalURL,
|
||||
internalToken: internalToken,
|
||||
client: &http.Client{
|
||||
Timeout: 3 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (v *betterAuthSessionVerifier) resolveUser(c *gin.Context) (string, string, bool) {
|
||||
request, err := http.NewRequestWithContext(c.Request.Context(), http.MethodGet, v.internalURL, nil)
|
||||
if err != nil {
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
request.Header.Set("X-Containr-Auth-Internal", v.internalToken)
|
||||
copyHeaderIfPresent(c.Request, request, "Cookie")
|
||||
copyHeaderIfPresent(c.Request, request, "User-Agent")
|
||||
copyHeaderIfPresent(c.Request, request, "X-Forwarded-For")
|
||||
copyHeaderIfPresent(c.Request, request, "X-Forwarded-Proto")
|
||||
|
||||
response, err := v.client.Do(request)
|
||||
if err != nil {
|
||||
return "", "", false
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
if response.StatusCode != http.StatusOK {
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
var payload betterAuthSessionResponse
|
||||
if err := json.NewDecoder(response.Body).Decode(&payload); err != nil {
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
if !payload.Authenticated {
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
email := strings.ToLower(strings.TrimSpace(payload.User.Email))
|
||||
if email == "" {
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
localUserID, err := ensureLocalUserRecord(c, payload.User)
|
||||
if err != nil {
|
||||
log.Printf("Failed to map Better Auth user to local user: %v", err)
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
return localUserID, email, true
|
||||
}
|
||||
|
||||
func ensureLocalUserRecord(c *gin.Context, user betterAuthSessionUser) (string, error) {
|
||||
dbValue, ok := c.Get("db")
|
||||
if !ok || dbValue == nil {
|
||||
return "", fmt.Errorf("database context missing")
|
||||
}
|
||||
|
||||
db, ok := dbValue.(*database.DB)
|
||||
if !ok || db == nil {
|
||||
return "", fmt.Errorf("invalid database context")
|
||||
}
|
||||
|
||||
email := strings.ToLower(strings.TrimSpace(user.Email))
|
||||
if email == "" {
|
||||
return "", fmt.Errorf("email is required")
|
||||
}
|
||||
|
||||
name := strings.TrimSpace(user.Name)
|
||||
if name == "" {
|
||||
name = "Containr User"
|
||||
}
|
||||
|
||||
avatarURL := strings.TrimSpace(user.Image)
|
||||
|
||||
var localUserID string
|
||||
err := db.QueryRow(`SELECT id FROM users WHERE email = $1`, email).Scan(&localUserID)
|
||||
switch {
|
||||
case err == nil:
|
||||
_, _ = db.Exec(`
|
||||
UPDATE users
|
||||
SET name = $1,
|
||||
avatar_url = CASE WHEN $2 = '' THEN avatar_url ELSE $2 END,
|
||||
updated_at = NOW()
|
||||
WHERE id = $3
|
||||
`, name, avatarURL, localUserID)
|
||||
return localUserID, nil
|
||||
case !errors.Is(err, sql.ErrNoRows):
|
||||
return "", err
|
||||
}
|
||||
|
||||
hashedPassword, hashErr := bcrypt.GenerateFromPassword([]byte(uuid.NewString()), bcrypt.DefaultCost)
|
||||
if hashErr != nil {
|
||||
return "", hashErr
|
||||
}
|
||||
|
||||
if err := db.QueryRow(`
|
||||
INSERT INTO users (email, password_hash, name, avatar_url)
|
||||
VALUES ($1, $2, $3, NULLIF($4, ''))
|
||||
RETURNING id
|
||||
`, email, string(hashedPassword), name, avatarURL).Scan(&localUserID); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return localUserID, nil
|
||||
}
|
||||
|
||||
func validateJWTClaims(tokenString, jwtSecret string) (jwt.MapClaims, bool) {
|
||||
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
return []byte(jwtSecret), nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok || !token.Valid {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return claims, true
|
||||
}
|
||||
|
||||
func extractJWTToken(c *gin.Context) (string, string, bool) {
|
||||
authHeader := strings.TrimSpace(c.GetHeader("Authorization"))
|
||||
if authHeader != "" {
|
||||
parts := strings.Fields(authHeader)
|
||||
if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") {
|
||||
return "", "Invalid authorization header format", true
|
||||
}
|
||||
return strings.TrimSpace(parts[1]), "", true
|
||||
}
|
||||
|
||||
if strings.EqualFold(c.GetHeader("Upgrade"), "websocket") {
|
||||
token := strings.TrimSpace(c.Query("token"))
|
||||
if token == "" {
|
||||
return "", "", false
|
||||
}
|
||||
return token, "", true
|
||||
}
|
||||
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
func copyHeaderIfPresent(src *http.Request, dst *http.Request, key string) {
|
||||
value := strings.TrimSpace(src.Header.Get(key))
|
||||
if value != "" {
|
||||
dst.Header.Set(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
// ErrorHandler middleware for consistent error handling
|
||||
func ErrorHandler() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.Next()
|
||||
|
||||
// Check if there are any errors
|
||||
if len(c.Errors) > 0 {
|
||||
err := c.Errors.Last()
|
||||
log.Printf("Request error: %v", err)
|
||||
|
||||
// Return JSON error response
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "Internal server error",
|
||||
"code": "INTERNAL_ERROR",
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CORSMiddleware for CORS handling
|
||||
func CORSMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
c.Header("Access-Control-Allow-Credentials", "true")
|
||||
c.Header("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With")
|
||||
c.Header("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE")
|
||||
|
||||
if c.Request.Method == "OPTIONS" {
|
||||
c.AbortWithStatus(204)
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// RequestBodyLimit enforces a maximum HTTP request body size.
|
||||
func RequestBodyLimit(maxBytes int64) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if maxBytes <= 0 {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
if c.Request.ContentLength > maxBytes {
|
||||
c.AbortWithStatusJSON(http.StatusRequestEntityTooLarge, gin.H{
|
||||
"error": "Request body too large",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if c.Request.Body != nil {
|
||||
c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, maxBytes)
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user