mirror of
https://github.com/Dvorinka/Trackeep.git
synced 2026-06-03 20:12:58 +00:00
f3a835caa2
- Add Redis client initialization with DragonflyDB connection - Update session middleware to use DragonflyDB with fallback to memory - Update cache middleware to use DragonflyDB for persistent caching - Add proper error handling and connection timeouts - Implement session storage in DragonflyDB with 24-hour expiration - Add cache invalidation middleware for DragonflyDB - Maintain backward compatibility with in-memory fallbacks
318 lines
8.3 KiB
Go
318 lines
8.3 KiB
Go
package middleware
|
|
|
|
import (
	"context"
	"crypto/rand"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"os"
	"strings"
	"sync"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/go-redis/redis/v8"
	"github.com/trackeep/backend/models"
)
|
|
|
|
// SessionData is the per-session record kept by a SessionStore. It is
// encoded as JSON (snake_case keys) when written to Redis.
type SessionData struct {
	// Identity of the user that owns the session.
	UserID   uint   `json:"user_id"`
	Email    string `json:"email"`
	Username string `json:"username"`
	Role     string `json:"role"`

	// SessionID is the identifier clients send back in the X-Session-ID header.
	SessionID string `json:"session_id"`

	// Client details captured when the session was created.
	IPAddress string `json:"ip_address"`
	UserAgent string `json:"user_agent"`

	// Bookkeeping timestamps maintained by the store.
	CreatedAt  time.Time `json:"created_at"`
	LastActive time.Time `json:"last_active"`
}
|
|
|
|
// SessionStore interface for session storage
|
|
type SessionStore interface {
|
|
CreateSession(sessionData *SessionData) error
|
|
GetSession(sessionID string) (*SessionData, error)
|
|
UpdateSession(sessionID string, sessionData *SessionData) error
|
|
DeleteSession(sessionID string) error
|
|
CleanupExpiredSessions() error
|
|
}
|
|
|
|
// RedisSessionStore implements SessionStore using Redis (or fallback to memory)
|
|
type RedisSessionStore struct {
|
|
redisClient *redis.Client
|
|
sessions map[string]*SessionData // Fallback in-memory store
|
|
}
|
|
|
|
// NewSessionStore creates a new session store
|
|
func NewSessionStore(redisClient *redis.Client) SessionStore {
|
|
return &RedisSessionStore{
|
|
redisClient: redisClient,
|
|
sessions: make(map[string]*SessionData),
|
|
}
|
|
}
|
|
|
|
// CreateSession creates a new session
|
|
func (r *RedisSessionStore) CreateSession(sessionData *SessionData) error {
|
|
sessionData.CreatedAt = time.Now()
|
|
sessionData.LastActive = time.Now()
|
|
|
|
// Try Redis first
|
|
if r.redisClient != nil {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
|
defer cancel()
|
|
|
|
sessionJSON, err := json.Marshal(sessionData)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to marshal session data: %w", err)
|
|
}
|
|
|
|
// Store in Redis with 24 hour expiration
|
|
err = r.redisClient.Set(ctx, "session:"+sessionData.SessionID, sessionJSON, 24*time.Hour).Err()
|
|
if err == nil {
|
|
return nil
|
|
}
|
|
// Fall back to memory if Redis fails
|
|
}
|
|
|
|
// Fallback to in-memory storage
|
|
r.sessions[sessionData.SessionID] = sessionData
|
|
return nil
|
|
}
|
|
|
|
// GetSession retrieves a session by ID
|
|
func (r *RedisSessionStore) GetSession(sessionID string) (*SessionData, error) {
|
|
// Try Redis first
|
|
if r.redisClient != nil {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
|
defer cancel()
|
|
|
|
sessionJSON, err := r.redisClient.Get(ctx, "session:"+sessionID).Result()
|
|
if err == nil {
|
|
var sessionData SessionData
|
|
if err := json.Unmarshal([]byte(sessionJSON), &sessionData); err == nil {
|
|
// Update last active time
|
|
sessionData.LastActive = time.Now()
|
|
// Update in Redis
|
|
updatedJSON, _ := json.Marshal(sessionData)
|
|
r.redisClient.Set(ctx, "session:"+sessionID, updatedJSON, 24*time.Hour)
|
|
return &sessionData, nil
|
|
}
|
|
}
|
|
// Fall back to memory if Redis fails
|
|
}
|
|
|
|
// Fallback to in-memory storage
|
|
if session, exists := r.sessions[sessionID]; exists {
|
|
// Update last active time
|
|
session.LastActive = time.Now()
|
|
return session, nil
|
|
}
|
|
|
|
return nil, fmt.Errorf("session not found")
|
|
}
|
|
|
|
// UpdateSession updates an existing session
|
|
func (r *RedisSessionStore) UpdateSession(sessionID string, sessionData *SessionData) error {
|
|
sessionData.LastActive = time.Now()
|
|
|
|
// Try Redis first
|
|
if r.redisClient != nil {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
|
defer cancel()
|
|
|
|
sessionJSON, err := json.Marshal(sessionData)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to marshal session data: %w", err)
|
|
}
|
|
|
|
err = r.redisClient.Set(ctx, "session:"+sessionID, sessionJSON, 24*time.Hour).Err()
|
|
if err == nil {
|
|
return nil
|
|
}
|
|
// Fall back to memory if Redis fails
|
|
}
|
|
|
|
// Fallback to in-memory storage
|
|
if _, exists := r.sessions[sessionID]; exists {
|
|
r.sessions[sessionID] = sessionData
|
|
return nil
|
|
}
|
|
|
|
return fmt.Errorf("session not found")
|
|
}
|
|
|
|
// DeleteSession removes a session
|
|
func (r *RedisSessionStore) DeleteSession(sessionID string) error {
|
|
// Try Redis first
|
|
if r.redisClient != nil {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
|
defer cancel()
|
|
|
|
err := r.redisClient.Del(ctx, "session:"+sessionID).Err()
|
|
if err == nil {
|
|
// Also remove from memory fallback
|
|
delete(r.sessions, sessionID)
|
|
return nil
|
|
}
|
|
// Fall back to memory if Redis fails
|
|
}
|
|
|
|
// Fallback to in-memory storage
|
|
delete(r.sessions, sessionID)
|
|
return nil
|
|
}
|
|
|
|
// CleanupExpiredSessions removes sessions older than 24 hours
|
|
func (r *RedisSessionStore) CleanupExpiredSessions() error {
|
|
now := time.Now()
|
|
for sessionID, session := range r.sessions {
|
|
if now.Sub(session.LastActive) > 24*time.Hour {
|
|
delete(r.sessions, sessionID)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Global session store instance
|
|
var sessionStore SessionStore
|
|
|
|
// InitSessionStore initializes the session store
|
|
func InitSessionStore(redisClient *redis.Client) {
|
|
sessionStore = NewSessionStore(redisClient)
|
|
|
|
// Start cleanup goroutine
|
|
go func() {
|
|
ticker := time.NewTicker(1 * time.Hour)
|
|
defer ticker.Stop()
|
|
|
|
for range ticker.C {
|
|
if sessionStore != nil {
|
|
sessionStore.CleanupExpiredSessions()
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
|
|
// SessionMiddleware creates and manages user sessions
|
|
func SessionMiddleware() gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
// Skip session management for health checks and static assets
|
|
path := c.Request.URL.Path
|
|
if path == "/health" || path == "/metrics" || strings.HasPrefix(path, "/static") {
|
|
c.Next()
|
|
return
|
|
}
|
|
|
|
// Get session ID from header or create new one
|
|
sessionID := c.GetHeader("X-Session-ID")
|
|
if sessionID == "" {
|
|
sessionID = generateSessionID()
|
|
c.Header("X-Session-ID", sessionID)
|
|
}
|
|
|
|
// Try to get existing session
|
|
session, err := sessionStore.GetSession(sessionID)
|
|
if err != nil {
|
|
// No existing session, check if user is authenticated via JWT
|
|
if user, exists := c.Get("user"); exists {
|
|
// Create session from authenticated user
|
|
if userModel, ok := user.(models.User); ok {
|
|
session = &SessionData{
|
|
SessionID: sessionID,
|
|
UserID: userModel.ID,
|
|
Email: userModel.Email,
|
|
Username: userModel.Username,
|
|
Role: userModel.Role,
|
|
IPAddress: c.ClientIP(),
|
|
UserAgent: c.GetHeader("User-Agent"),
|
|
}
|
|
sessionStore.CreateSession(session)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Set session data in context
|
|
if session != nil {
|
|
c.Set("session_id", session.SessionID)
|
|
c.Set("session_user_id", session.UserID)
|
|
c.Set("session_email", session.Email)
|
|
c.Set("session_username", session.Username)
|
|
c.Set("session_role", session.Role)
|
|
}
|
|
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
// GetSessionFromContext retrieves session data from Gin context
|
|
func GetSessionFromContext(c *gin.Context) (*SessionData, error) {
|
|
if sessionID, exists := c.Get("session_id"); exists {
|
|
return sessionStore.GetSession(sessionID.(string))
|
|
}
|
|
return nil, fmt.Errorf("no session in context")
|
|
}
|
|
|
|
// GetUserIDFromSession safely gets user ID from session or context
|
|
func GetUserIDFromSession(c *gin.Context) uint {
|
|
// First try session
|
|
if session, err := GetSessionFromContext(c); err == nil {
|
|
return session.UserID
|
|
}
|
|
|
|
// Fallback to context (for demo mode or JWT)
|
|
if userID, exists := c.Get("user_id"); exists {
|
|
if id, ok := userID.(uint); ok {
|
|
return id
|
|
}
|
|
}
|
|
if userID, exists := c.Get("userID"); exists {
|
|
if id, ok := userID.(uint); ok {
|
|
return id
|
|
}
|
|
}
|
|
|
|
// Final fallback for demo mode
|
|
if os.Getenv("VITE_DEMO_MODE") == "true" {
|
|
return 1
|
|
}
|
|
|
|
return 0
|
|
}
|
|
|
|
// GetUserEmailFromSession safely gets user email from session or context
|
|
func GetUserEmailFromSession(c *gin.Context) string {
|
|
// First try session
|
|
if session, err := GetSessionFromContext(c); err == nil {
|
|
return session.Email
|
|
}
|
|
|
|
// Fallback to context
|
|
if email, exists := c.Get("user_email"); exists {
|
|
if e, ok := email.(string); ok {
|
|
return e
|
|
}
|
|
}
|
|
|
|
// Fallback for demo mode
|
|
if os.Getenv("VITE_DEMO_MODE") == "true" {
|
|
return "demo@trackeep.com"
|
|
}
|
|
|
|
return "unknown"
|
|
}
|
|
|
|
// generateSessionID returns a new, unguessable session identifier: "sess_"
// followed by 64 lowercase hex characters (256 bits from crypto/rand).
//
// Session IDs act as bearer credentials, because SessionMiddleware trusts
// whatever ID arrives in X-Session-ID. They therefore must not be derivable
// from the clock, as the previous timestamp-based format was.
func generateSessionID() string {
	var raw [32]byte
	if _, err := rand.Read(raw[:]); err != nil {
		// Without a working entropy source any ID produced here would be
		// predictable; refusing to continue is safer than issuing one.
		panic(fmt.Sprintf("generating session id: %v", err))
	}
	return "sess_" + hex.EncodeToString(raw[:])
}
|
|
|
|
// GetSessionStore returns the global session store instance
|
|
func GetSessionStore() SessionStore {
|
|
return sessionStore
|
|
}
|
|
|
|
// CleanupSessionsOnShutdown gracefully cleans up sessions
|
|
func CleanupSessionsOnShutdown() {
|
|
if sessionStore != nil {
|
|
sessionStore.CleanupExpiredSessions()
|
|
}
|
|
}
|