first test

This commit is contained in:
Tomas Dvorak
2026-02-08 14:14:55 +01:00
parent 18aa702174
commit d27cf14110
372 changed files with 98089 additions and 2585 deletions
+414
View File
@@ -0,0 +1,414 @@
package middleware
import (
"bytes"
"encoding/json"
"fmt"
"io"
"os"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/trackeep/backend/config"
"github.com/trackeep/backend/models"
)
// AuditMiddleware creates audit logs for HTTP requests
func AuditMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
startTime := time.Now()
// Read request body for logging (only for POST/PUT/PATCH)
var requestBody []byte
if c.Request.Method != "GET" && c.Request.Body != nil {
requestBody, _ = io.ReadAll(c.Request.Body)
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
}
// Process the request
c.Next()
// Skip audit logging for certain endpoints
if shouldSkipAudit(c.Request.URL.Path) {
return
}
// Create audit log entry with proper user data from session
userIDValue := GetUserIDFromSession(c)
userEmail := GetUserEmailFromSession(c)
// Ensure we have valid user data before creating audit log
if userIDValue == 0 && userEmail == "unknown" {
// Skip audit logging for unauthenticated requests
return
}
auditLog := &models.AuditLog{
UserID: userIDValue,
UserEmail: userEmail,
UserIP: c.ClientIP(),
UserAgent: c.Request.UserAgent(),
Action: getActionFromMethodAndPath(c.Request.Method, c.Request.URL.Path),
Resource: getResourceFromPath(c.Request.URL.Path),
ResourceID: getResourceIDFromPath(c.Request.URL.Path),
Description: generateDescription(c, startTime),
Details: generateDetails(c, requestBody, startTime),
Success: c.Writer.Status() < 400,
SessionID: getSessionID(c),
Country: getCountryFromIP(c.ClientIP()),
Device: getDeviceFromUserAgent(c.Request.UserAgent()),
Platform: getPlatformFromUserAgent(c.Request.UserAgent()),
Browser: getBrowserFromUserAgent(c.Request.UserAgent()),
RiskLevel: assessRisk(c, startTime),
}
// Set failure reason if request failed
if !auditLog.Success {
auditLog.FailureReason = getFailureReason(c.Writer.Status())
}
// Save audit log asynchronously
go saveAuditLog(auditLog)
}
}
// LogSecurityEvent logs security-related events
func LogSecurityEvent(userID uint, userEmail, action, description, failureReason string, details map[string]interface{}) {
auditLog := &models.AuditLog{
UserID: userID,
UserEmail: userEmail,
Action: models.AuditAction(action),
Resource: models.AuditResourceSecurity,
Description: description,
Details: details,
Success: failureReason == "",
FailureReason: failureReason,
RiskLevel: assessSecurityRisk(action, failureReason),
Suspicious: isSuspiciousActivity(action, failureReason),
CreatedAt: time.Now(),
}
go saveAuditLog(auditLog)
}
// LogUserAction logs user-specific actions
func LogUserAction(user models.User, action models.AuditAction, resource models.AuditResource, resourceID *uint, description string, oldValues, newValues map[string]interface{}) {
auditLog := &models.AuditLog{
UserID: user.ID,
UserEmail: user.Email,
Action: action,
Resource: resource,
ResourceID: resourceID,
Description: description,
OldValues: oldValues,
NewValues: newValues,
Success: true,
RiskLevel: assessActionRisk(action, resource),
CreatedAt: time.Now(),
}
go saveAuditLog(auditLog)
}
// Helper functions
func shouldSkipAudit(path string) bool {
skipPaths := []string{
"/health",
"/metrics",
"/api/demo/status",
"/favicon.ico",
"/assets/",
}
for _, skipPath := range skipPaths {
if strings.HasPrefix(path, skipPath) {
return true
}
}
return false
}
func getUintFromInterface(value interface{}) uint {
if v, ok := value.(uint); ok {
return v
}
return 0
}
func getUserEmail(user interface{}) string {
if u, ok := user.(models.User); ok {
return u.Email
}
return "unknown"
}
func getActionFromMethodAndPath(method, path string) models.AuditAction {
switch method {
case "GET":
return models.AuditActionRead
case "POST":
if strings.Contains(path, "/login") {
return models.AuditActionLogin
} else if strings.Contains(path, "/logout") {
return models.AuditActionLogout
} else if strings.Contains(path, "/upload") {
return models.AuditActionUpload
} else if strings.Contains(path, "/export") {
return models.AuditActionExport
} else if strings.Contains(path, "/import") {
return models.AuditActionImport
}
return models.AuditActionCreate
case "PUT", "PATCH":
return models.AuditActionUpdate
case "DELETE":
return models.AuditActionDelete
default:
return models.AuditActionAccess
}
}
func getResourceFromPath(path string) models.AuditResource {
if strings.Contains(path, "/users") {
return models.AuditResourceUser
} else if strings.Contains(path, "/notes") {
return models.AuditResourceNote
} else if strings.Contains(path, "/files") {
return models.AuditResourceFile
} else if strings.Contains(path, "/bookmarks") {
return models.AuditResourceBookmark
} else if strings.Contains(path, "/tasks") {
return models.AuditResourceTask
} else if strings.Contains(path, "/time-entries") {
return models.AuditResourceTimeEntry
} else if strings.Contains(path, "/integrations") {
return models.AuditResourceIntegration
} else if strings.Contains(path, "/teams") {
return models.AuditResourceTeam
} else if strings.Contains(path, "/goals") || strings.Contains(path, "/habits") {
return models.AuditResourceGoal
} else if strings.Contains(path, "/calendar") {
return models.AuditResourceCalendar
} else if strings.Contains(path, "/search") {
return models.AuditResourceSearch
} else if strings.Contains(path, "/ai") {
return models.AuditResourceAI
} else if strings.Contains(path, "/analytics") {
return models.AuditResourceAnalytics
} else if strings.Contains(path, "/auth") {
return models.AuditResourceSecurity
}
return models.AuditResourceSystem
}
func getResourceIDFromPath(path string) *uint {
parts := strings.Split(path, "/")
for i, part := range parts {
if part == "" {
continue
}
// Check if this part looks like a numeric ID
if i > 0 && len(part) >= 1 && part[0] >= '0' && part[0] <= '9' {
var id uint
if _, err := fmt.Sscanf(part, "%d", &id); err == nil {
return &id
}
}
}
return nil
}
func generateDescription(c *gin.Context, startTime time.Time) string {
duration := time.Since(startTime)
method := c.Request.Method
path := c.Request.URL.Path
status := c.Writer.Status()
return fmt.Sprintf("%s %s - %d (%v)", method, path, status, duration.Round(time.Millisecond))
}
func generateDetails(c *gin.Context, requestBody []byte, startTime time.Time) map[string]interface{} {
details := make(map[string]interface{})
details["method"] = c.Request.Method
details["path"] = c.Request.URL.Path
details["query"] = c.Request.URL.RawQuery
details["status_code"] = c.Writer.Status()
details["duration_ms"] = time.Since(startTime).Milliseconds()
details["response_size"] = c.Writer.Size()
if len(requestBody) > 0 && len(requestBody) < 1024 { // Only log small request bodies
var jsonBody map[string]interface{}
if err := json.Unmarshal(requestBody, &jsonBody); err == nil {
// Remove sensitive fields
sanitizeJSON(jsonBody)
details["request_body"] = jsonBody
}
}
return details
}
func sanitizeJSON(data map[string]interface{}) {
sensitiveFields := []string{"password", "token", "secret", "key", "authorization"}
for key, value := range data {
keyLower := strings.ToLower(key)
for _, sensitive := range sensitiveFields {
if strings.Contains(keyLower, sensitive) {
data[key] = "[REDACTED]"
break
}
}
// Recursively sanitize nested objects
if nested, ok := value.(map[string]interface{}); ok {
sanitizeJSON(nested)
}
}
}
func getSessionID(c *gin.Context) string {
// Try to get session ID from various sources
if sessionID := c.GetHeader("X-Session-ID"); sessionID != "" {
return sessionID
}
// You could also get from JWT claims or cookie
return ""
}
func getCountryFromIP(ip string) string {
// This is a placeholder - in production, you'd use a GeoIP service
return "unknown"
}
func getDeviceFromUserAgent(userAgent string) string {
if strings.Contains(userAgent, "Mobile") {
return "mobile"
} else if strings.Contains(userAgent, "Tablet") {
return "tablet"
}
return "desktop"
}
func getPlatformFromUserAgent(userAgent string) string {
if strings.Contains(userAgent, "Windows") {
return "windows"
} else if strings.Contains(userAgent, "Mac") {
return "macos"
} else if strings.Contains(userAgent, "Linux") {
return "linux"
} else if strings.Contains(userAgent, "Android") {
return "android"
} else if strings.Contains(userAgent, "iOS") {
return "ios"
}
return "unknown"
}
func getBrowserFromUserAgent(userAgent string) string {
if strings.Contains(userAgent, "Chrome") {
return "chrome"
} else if strings.Contains(userAgent, "Firefox") {
return "firefox"
} else if strings.Contains(userAgent, "Safari") {
return "safari"
} else if strings.Contains(userAgent, "Edge") {
return "edge"
}
return "unknown"
}
func assessRisk(c *gin.Context, startTime time.Time) string {
path := c.Request.URL.Path
method := c.Request.Method
status := c.Writer.Status()
duration := time.Since(startTime)
// High risk indicators
if strings.Contains(path, "/admin") {
return "high"
}
if strings.Contains(path, "/auth") && status >= 400 {
return "medium"
}
if method == "DELETE" {
return "medium"
}
if duration > 5*time.Second {
return "medium"
}
return "low"
}
func assessSecurityRisk(action, failureReason string) string {
if failureReason != "" {
return "high"
}
switch action {
case "login_failed":
return "medium"
case "disable", "delete":
return "high"
default:
return "low"
}
}
func assessActionRisk(action models.AuditAction, resource models.AuditResource) string {
if action == models.AuditActionDelete {
return "medium"
}
if resource == models.AuditResourceSecurity {
return "medium"
}
return "low"
}
func isSuspiciousActivity(action, failureReason string) bool {
// Define suspicious activity patterns
if action == "login_failed" && failureReason == "too_many_attempts" {
return true
}
if strings.Contains(failureReason, "suspicious") {
return true
}
return false
}
func getFailureReason(statusCode int) string {
switch statusCode {
case 400:
return "bad_request"
case 401:
return "unauthorized"
case 403:
return "forbidden"
case 404:
return "not_found"
case 429:
return "rate_limited"
case 500:
return "server_error"
default:
return "unknown"
}
}
func saveAuditLog(auditLog *models.AuditLog) {
// Skip audit logging in demo mode
if os.Getenv("VITE_DEMO_MODE") == "true" {
return
}
db := config.GetDB()
if db != nil {
db.Create(auditLog)
}
}
+156
View File
@@ -0,0 +1,156 @@
package middleware
import (
"crypto/md5"
"fmt"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
"golang.org/x/net/context"
)
// CacheConfig holds cache configuration
type CacheConfig struct {
Duration time.Duration
KeyPrefix string
Enabled bool
RedisClient *redis.Client
}
// DefaultCacheConfig returns default cache configuration
func DefaultCacheConfig() CacheConfig {
return CacheConfig{
Duration: 5 * time.Minute,
KeyPrefix: "trackeep:",
Enabled: true,
}
}
// CacheMiddleware creates a cache middleware
func CacheMiddleware(config CacheConfig) gin.HandlerFunc {
if !config.Enabled || config.RedisClient == nil {
return func(c *gin.Context) {
c.Next()
}
}
return func(c *gin.Context) {
// Only cache GET requests
if c.Request.Method != http.MethodGet {
c.Next()
return
}
// Generate cache key
cacheKey := generateCacheKey(c, config.KeyPrefix)
// Try to get from cache
cached, err := config.RedisClient.Get(context.Background(), cacheKey).Result()
if err == nil && cached != "" {
// Cache hit
c.Header("X-Cache", "HIT")
c.Header("Content-Type", "application/json")
c.String(http.StatusOK, cached)
c.Abort()
return
}
// Cache miss, continue with request
c.Header("X-Cache", "MISS")
// Capture response
writer := &cachedResponseWriter{
ResponseWriter: c.Writer,
buffer: make([]byte, 0),
}
c.Writer = writer
c.Next()
// Cache the response if successful
if c.Writer.Status() == http.StatusOK && len(writer.buffer) > 0 {
config.RedisClient.Set(
context.Background(),
cacheKey,
string(writer.buffer),
config.Duration,
)
}
}
}
// generateCacheKey creates a unique cache key for the request
func generateCacheKey(c *gin.Context, prefix string) string {
// Include path, query params, and user ID if available
keyParts := []string{
prefix,
c.Request.URL.Path,
c.Request.URL.RawQuery,
}
// Add user ID for personalized caching
if userID := c.GetString("userID"); userID != "" {
keyParts = append(keyParts, "user:"+userID)
}
// Create hash of the key to avoid long keys
key := strings.Join(keyParts, ":")
hash := md5.Sum([]byte(key))
return fmt.Sprintf("%x", hash)
}
// cachedResponseWriter captures response data for caching
type cachedResponseWriter struct {
gin.ResponseWriter
buffer []byte
}
func (w *cachedResponseWriter) Write(data []byte) (int, error) {
w.buffer = append(w.buffer, data...)
return w.ResponseWriter.Write(data)
}
// InvalidateCache invalidates cache entries matching a pattern
func InvalidateCache(redisClient *redis.Client, pattern string) error {
if redisClient == nil {
return nil
}
keys, err := redisClient.Keys(context.Background(), pattern).Result()
if err != nil {
return err
}
if len(keys) > 0 {
return redisClient.Del(context.Background(), keys...).Err()
}
return nil
}
// CacheInvalidationMiddleware invalidates cache on write operations
func CacheInvalidationMiddleware(redisClient *redis.Client) gin.HandlerFunc {
return func(c *gin.Context) {
c.Next()
// Invalidate cache on successful write operations
if c.Writer.Status() >= 200 && c.Writer.Status() < 300 &&
(c.Request.Method == http.MethodPost ||
c.Request.Method == http.MethodPut ||
c.Request.Method == http.MethodDelete) {
// Invalidate user-specific cache
if userID := c.GetString("userID"); userID != "" {
pattern := fmt.Sprintf("trackeep:*user:%s*", userID)
InvalidateCache(redisClient, pattern)
}
// Invalidate general cache for the affected resource
resourcePattern := fmt.Sprintf("trackeep:*%s*", c.Request.URL.Path)
InvalidateCache(redisClient, resourcePattern)
}
}
}
+57
View File
@@ -0,0 +1,57 @@
package middleware
import (
"net/http"
"os"
"strings"
"github.com/gin-gonic/gin"
"github.com/trackeep/backend/models"
)
// DemoModeMiddleware prevents write operations when in demo mode
func DemoModeMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
// Check if demo mode is enabled
if os.Getenv("VITE_DEMO_MODE") == "true" {
// Allow GET requests (read operations)
if c.Request.Method == "GET" || c.Request.Method == "OPTIONS" {
c.Next()
return
}
// Allow specific write operations in demo mode
path := c.Request.URL.Path
if (strings.Contains(path, "/learning-paths") && c.Request.Method == "POST") ||
(strings.Contains(path, "/bookmarks/content") && c.Request.Method == "POST") ||
(strings.Contains(path, "/bookmarks/metadata") && c.Request.Method == "POST") {
// Set demo user for these operations
c.Set("user", models.User{
ID: 1,
Username: "demo",
Email: "demo@trackeep.com",
})
c.Set("user_id", uint(1))
c.Set("userID", uint(1)) // Add this for compatibility with handlers
c.Next()
return
}
// Block other write operations (POST, PUT, DELETE, PATCH)
c.JSON(http.StatusForbidden, gin.H{
"error": "Write operations are disabled in demo mode",
"message": "This is a demo instance. Database modifications are not allowed.",
})
c.Abort()
return
}
// If not in demo mode, allow all operations
c.Next()
}
}
// IsDemoMode returns true if demo mode is enabled
func IsDemoMode() bool {
return os.Getenv("VITE_DEMO_MODE") == "true"
}
+379
View File
@@ -0,0 +1,379 @@
package middleware
import (
"net/http"
"regexp"
"strconv"
"strings"
"unicode"
"github.com/gin-gonic/gin"
)
// InputValidationMiddleware provides comprehensive input validation
func InputValidationMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
// Validate query parameters
for key, values := range c.Request.URL.Query() {
for i, value := range values {
if containsMaliciousContent(value) {
c.JSON(http.StatusBadRequest, gin.H{
"error": "Invalid input detected",
"message": "Query parameter contains potentially malicious content",
"parameter": key,
})
c.Abort()
return
}
// Sanitize query parameters
values[i] = sanitizeInput(value)
}
}
// For POST/PUT requests, we'll validate the body in the handler
// since we need to know the expected structure
c.Next()
}
}
// ValidateRequestBody validates JSON request bodies against common attack patterns
func ValidateRequestBody() gin.HandlerFunc {
return func(c *gin.Context) {
if c.Request.Method == "GET" || c.Request.Method == "DELETE" {
c.Next()
return
}
// Read and validate the body
bodyBytes, err := c.GetRawData()
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"})
c.Abort()
return
}
bodyString := string(bodyBytes)
// Check for common injection patterns
if containsMaliciousContent(bodyString) {
c.JSON(http.StatusBadRequest, gin.H{
"error": "Invalid input detected",
"message": "Request body contains potentially malicious content",
})
c.Abort()
return
}
// Restore the body for subsequent handlers
c.Request.Body = &requestBody{body: bodyBytes}
c.Next()
}
}
// requestBody is a custom type to restore the request body
type requestBody struct {
body []byte
pos int
}
func (rb *requestBody) Read(p []byte) (int, error) {
if rb.pos >= len(rb.body) {
return 0, nil
}
n := copy(p, rb.body[rb.pos:])
rb.pos += n
return n, nil
}
func (rb *requestBody) Close() error {
return nil
}
// containsMaliciousContent checks for common attack patterns
func containsMaliciousContent(input string) bool {
// Convert to lowercase for case-insensitive matching
lowerInput := strings.ToLower(input)
// SQL injection patterns
sqlPatterns := []string{
"union select",
"drop table",
"insert into",
"delete from",
"update set",
"exec(",
"execute(",
"sp_executesql",
"xp_cmdshell",
"'--",
"/*",
"*/",
"char(",
"ascii(",
"concat(",
"substring(",
"waitfor delay",
"benchmark(",
"sleep(",
"pg_sleep(",
}
for _, pattern := range sqlPatterns {
if strings.Contains(lowerInput, pattern) {
return true
}
}
// XSS patterns
xssPatterns := []string{
"<script",
"</script>",
"javascript:",
"vbscript:",
"onload=",
"onerror=",
"onclick=",
"onmouseover=",
"onfocus=",
"onblur=",
"onchange=",
"onsubmit=",
"<iframe",
"<object",
"<embed",
"<form",
"<input",
"<link",
"<meta",
"<style",
"eval(",
"alert(",
"confirm(",
"prompt(",
"document.cookie",
"document.write",
"window.location",
}
for _, pattern := range xssPatterns {
if strings.Contains(lowerInput, pattern) {
return true
}
}
// Command injection patterns
cmdPatterns := []string{
"; rm",
"; cat",
"; ls",
"; ps",
"; kill",
"; chmod",
"; chown",
"; wget",
"; curl",
"; nc",
"; netcat",
"| rm",
"| cat",
"| ls",
"| ps",
"& rm",
"& cat",
"& ls",
"&& rm",
"&& cat",
"&& ls",
"`rm",
"`cat",
"`ls",
"$(rm",
"$(cat",
"$(ls",
}
for _, pattern := range cmdPatterns {
if strings.Contains(lowerInput, pattern) {
return true
}
}
// Path traversal patterns
pathPatterns := []string{
"../",
"..\\",
"%2e%2e%2f",
"%2e%2e\\",
"..%2f",
"..%5c",
"%2e%2e/",
"%2e%2e\\",
"/etc/passwd",
"/etc/shadow",
"/etc/hosts",
"windows/system32",
"\\windows\\system32",
}
for _, pattern := range pathPatterns {
if strings.Contains(lowerInput, pattern) {
return true
}
}
// LDAP injection patterns
ldapPatterns := []string{
"*)(",
"*}",
"*)",
"*(|",
"*(|",
"*)(",
}
for _, pattern := range ldapPatterns {
if strings.Contains(lowerInput, pattern) {
return true
}
}
// NoSQL injection patterns
nosqlPatterns := []string{
"$where",
"$ne",
"$gt",
"$lt",
"$regex",
"$expr",
"$json",
"$or",
"$and",
"$not",
}
for _, pattern := range nosqlPatterns {
if strings.Contains(lowerInput, pattern) {
return true
}
}
return false
}
// sanitizeInput cleans input by removing potentially dangerous characters
func sanitizeInput(input string) string {
// Remove null bytes
input = strings.ReplaceAll(input, "\x00", "")
// Remove control characters except newline, tab, and carriage return
var result []rune
for _, r := range input {
if unicode.IsControl(r) && r != '\n' && r != '\t' && r != '\r' {
continue
}
result = append(result, r)
}
// Trim whitespace
input = strings.TrimSpace(string(result))
return input
}
// ValidateEmail validates email format
func ValidateEmail(email string) bool {
emailRegex := regexp.MustCompile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`)
return emailRegex.MatchString(email)
}
// ValidatePassword validates password strength
func ValidatePassword(password string) error {
if len(password) < 8 {
return gin.Error{
Err: http.ErrBodyNotAllowed,
Type: gin.ErrorTypeBind,
Meta: "Password must be at least 8 characters long",
}
}
var hasUpper, hasLower, hasDigit, hasSpecial bool
for _, char := range password {
switch {
case unicode.IsUpper(char):
hasUpper = true
case unicode.IsLower(char):
hasLower = true
case unicode.IsDigit(char):
hasDigit = true
case unicode.IsPunct(char) || unicode.IsSymbol(char):
hasSpecial = true
}
}
if !hasUpper || !hasLower || !hasDigit || !hasSpecial {
return gin.Error{
Err: http.ErrBodyNotAllowed,
Type: gin.ErrorTypeBind,
Meta: "Password must contain at least one uppercase letter, one lowercase letter, one digit, and one special character",
}
}
return nil
}
// ValidateUsername validates username format
func ValidateUsername(username string) error {
if len(username) < 3 || len(username) > 30 {
return gin.Error{
Err: http.ErrBodyNotAllowed,
Type: gin.ErrorTypeBind,
Meta: "Username must be between 3 and 30 characters long",
}
}
usernameRegex := regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
if !usernameRegex.MatchString(username) {
return gin.Error{
Err: http.ErrBodyNotAllowed,
Type: gin.ErrorTypeBind,
Meta: "Username can only contain letters, numbers, underscores, and hyphens",
}
}
return nil
}
// ValidateID validates that ID is a positive integer
func ValidateID(id string) error {
idRegex := regexp.MustCompile(`^[1-9]\d*$`)
if !idRegex.MatchString(id) {
return gin.Error{
Err: http.ErrBodyNotAllowed,
Type: gin.ErrorTypeBind,
Meta: "Invalid ID format",
}
}
return nil
}
// ValidatePagination validates pagination parameters
func ValidatePagination(page, limit string) (int, int, error) {
pageInt := 1
limitInt := 20
if page != "" {
if p, err := strconv.Atoi(page); err == nil && p > 0 {
pageInt = p
}
}
if limit != "" {
if l, err := strconv.Atoi(limit); err == nil && l > 0 && l <= 100 {
limitInt = l
}
}
return pageInt, limitInt, nil
}
-6
View File
@@ -205,12 +205,6 @@ func logJSON(data map[string]interface{}) {
log.Println(string(jsonData))
}
// generateRequestID generates a unique request ID
func generateRequestID() string {
return time.Now().Format("20060102150405") + "-" +
string(rune(time.Now().UnixNano()%1000))
}
// SecurityLogger logs security-related events
func SecurityLogger() gin.HandlerFunc {
return func(c *gin.Context) {
+223
View File
@@ -0,0 +1,223 @@
package middleware
import (
"crypto/md5"
"fmt"
"net/http"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
)
// MemoryCacheItem represents an item in memory cache
type MemoryCacheItem struct {
Data []byte
ExpiresAt time.Time
}
// MemoryCache represents an in-memory cache
type MemoryCache struct {
items map[string]MemoryCacheItem
mutex sync.RWMutex
}
// NewMemoryCache creates a new memory cache
func NewMemoryCache() *MemoryCache {
cache := &MemoryCache{
items: make(map[string]MemoryCacheItem),
}
// Start cleanup goroutine
go cache.cleanup()
return cache
}
// Get retrieves an item from cache
func (c *MemoryCache) Get(key string) ([]byte, bool) {
c.mutex.RLock()
defer c.mutex.RUnlock()
item, exists := c.items[key]
if !exists || time.Now().After(item.ExpiresAt) {
return nil, false
}
return item.Data, true
}
// Set stores an item in cache
func (c *MemoryCache) Set(key string, data []byte, duration time.Duration) {
c.mutex.Lock()
defer c.mutex.Unlock()
c.items[key] = MemoryCacheItem{
Data: data,
ExpiresAt: time.Now().Add(duration),
}
}
// Delete removes an item from cache
func (c *MemoryCache) Delete(key string) {
c.mutex.Lock()
defer c.mutex.Unlock()
delete(c.items, key)
}
// DeletePattern removes items matching a pattern
func (c *MemoryCache) DeletePattern(pattern string) {
c.mutex.Lock()
defer c.mutex.Unlock()
for key := range c.items {
if strings.Contains(key, pattern) {
delete(c.items, key)
}
}
}
// cleanup removes expired items
func (c *MemoryCache) cleanup() {
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
for range ticker.C {
c.mutex.Lock()
now := time.Now()
for key, item := range c.items {
if now.After(item.ExpiresAt) {
delete(c.items, key)
}
}
c.mutex.Unlock()
}
}
// Global memory cache instance
var globalMemoryCache = NewMemoryCache()
// MemoryCacheConfig holds memory cache configuration
type MemoryCacheConfig struct {
Duration time.Duration
KeyPrefix string
Enabled bool
}
// DefaultMemoryCacheConfig returns default memory cache configuration
func DefaultMemoryCacheConfig() MemoryCacheConfig {
return MemoryCacheConfig{
Duration: 5 * time.Minute,
KeyPrefix: "trackeep:",
Enabled: true,
}
}
// MemoryCacheMiddleware creates a memory cache middleware
func MemoryCacheMiddleware(config MemoryCacheConfig) gin.HandlerFunc {
if !config.Enabled {
return func(c *gin.Context) {
c.Next()
}
}
return func(c *gin.Context) {
// Only cache GET requests
if c.Request.Method != http.MethodGet {
c.Next()
return
}
// Generate cache key
cacheKey := generateMemoryCacheKey(c, config.KeyPrefix)
// Try to get from cache
if data, hit := globalMemoryCache.Get(cacheKey); hit {
// Cache hit
c.Header("X-Cache", "HIT")
c.Header("Content-Type", "application/json")
c.Data(http.StatusOK, "application/json", data)
c.Abort()
return
}
// Cache miss, continue with request
c.Header("X-Cache", "MISS")
// Capture response
writer := &memoryCachedResponseWriter{
ResponseWriter: c.Writer,
buffer: make([]byte, 0),
}
c.Writer = writer
c.Next()
// Cache the response if successful
if c.Writer.Status() == http.StatusOK && len(writer.buffer) > 0 {
globalMemoryCache.Set(cacheKey, writer.buffer, config.Duration)
}
}
}
// generateMemoryCacheKey creates a unique cache key for the request
func generateMemoryCacheKey(c *gin.Context, prefix string) string {
// Include path, query params, and user ID if available
keyParts := []string{
prefix,
c.Request.URL.Path,
c.Request.URL.RawQuery,
}
// Add user ID for personalized caching
if userID := c.GetString("userID"); userID != "" {
keyParts = append(keyParts, "user:"+userID)
}
// Create hash of the key to avoid long keys
key := strings.Join(keyParts, ":")
hash := md5.Sum([]byte(key))
return fmt.Sprintf("%x", hash)
}
// memoryCachedResponseWriter captures response data for caching
type memoryCachedResponseWriter struct {
gin.ResponseWriter
buffer []byte
}
func (w *memoryCachedResponseWriter) Write(data []byte) (int, error) {
w.buffer = append(w.buffer, data...)
return w.ResponseWriter.Write(data)
}
// InvalidateMemoryCache invalidates cache entries matching a pattern
func InvalidateMemoryCache(pattern string) {
globalMemoryCache.DeletePattern(pattern)
}
// MemoryCacheInvalidationMiddleware invalidates cache on write operations
func MemoryCacheInvalidationMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
c.Next()
// Invalidate cache on successful write operations
if c.Writer.Status() >= 200 && c.Writer.Status() < 300 &&
(c.Request.Method == http.MethodPost ||
c.Request.Method == http.MethodPut ||
c.Request.Method == http.MethodDelete) {
// Invalidate user-specific cache
if userID := c.GetString("userID"); userID != "" {
pattern := fmt.Sprintf("user:%s", userID)
InvalidateMemoryCache(pattern)
}
// Invalidate general cache for the affected resource
resourcePattern := fmt.Sprintf("%s", c.Request.URL.Path)
InvalidateMemoryCache(resourcePattern)
}
}
}
+58
View File
@@ -0,0 +1,58 @@
package middleware
import (
"fmt"
"log"
"time"
"github.com/gin-gonic/gin"
)
// PerformanceMiddleware adds performance monitoring to requests
func PerformanceMiddleware() gin.HandlerFunc {
return gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
// Custom log format with performance metrics
return fmt.Sprintf(
"[%s] %s %s %d %s %s \"%s\" %s\n",
param.TimeStamp.Format("02/Jan/2006:15:04:05 -0700"),
param.Method,
param.Path,
param.StatusCode,
param.Latency,
param.Request.UserAgent(),
param.ErrorMessage,
param.Request.Header.Get("X-Request-ID"),
)
})
}
// RequestIDMiddleware adds a unique request ID to each request
func RequestIDMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
requestID := c.GetHeader("X-Request-ID")
if requestID == "" {
requestID = generateRequestID()
}
c.Set("RequestID", requestID)
c.Header("X-Request-ID", requestID)
c.Next()
}
}
// generateRequestID generates a unique request ID
func generateRequestID() string {
return fmt.Sprintf("%d", time.Now().UnixNano())
}
// SlowQueryMiddleware logs slow database queries
func SlowQueryMiddleware(threshold time.Duration) gin.HandlerFunc {
return func(c *gin.Context) {
start := time.Now()
c.Next()
duration := time.Since(start)
if duration > threshold {
log.Printf("SLOW REQUEST: %s %s took %v", c.Request.Method, c.Request.URL.Path, duration)
}
}
}
+139
View File
@@ -0,0 +1,139 @@
package middleware
import (
"fmt"
"net/http"
"sync"
"time"
"github.com/gin-gonic/gin"
)
// RateLimiter implements a simple in-memory rate limiter
type RateLimiter struct {
clients map[string]*ClientInfo
mutex sync.RWMutex
limit int
window time.Duration
}
type ClientInfo struct {
requests int
resetTime time.Time
}
// NewRateLimiter creates a new rate limiter
func NewRateLimiter(limit int, window time.Duration) *RateLimiter {
limiter := &RateLimiter{
clients: make(map[string]*ClientInfo),
limit: limit,
window: window,
}
// Start cleanup goroutine
go limiter.cleanup()
return limiter
}
// Middleware returns the Gin middleware function
func (rl *RateLimiter) Middleware() gin.HandlerFunc {
return func(c *gin.Context) {
clientIP := c.ClientIP()
rl.mutex.Lock()
client, exists := rl.clients[clientIP]
if !exists {
client = &ClientInfo{
requests: 0,
resetTime: time.Now().Add(rl.window),
}
rl.clients[clientIP] = client
}
// Reset window if expired
if time.Now().After(client.resetTime) {
client.requests = 0
client.resetTime = time.Now().Add(rl.window)
}
client.requests++
// Check if limit exceeded
if client.requests > rl.limit {
rl.mutex.Unlock()
c.Header("X-RateLimit-Limit", fmt.Sprintf("%d", rl.limit))
c.Header("X-RateLimit-Remaining", "0")
c.Header("X-RateLimit-Reset", fmt.Sprintf("%d", client.resetTime.Unix()))
c.JSON(http.StatusTooManyRequests, gin.H{
"error": "Rate limit exceeded",
"message": fmt.Sprintf("Too many requests. Limit is %d per %v", rl.limit, rl.window),
"retry_after": time.Until(client.resetTime).Seconds(),
})
c.Abort()
return
}
remaining := rl.limit - client.requests
rl.mutex.Unlock()
// Set rate limit headers
c.Header("X-RateLimit-Limit", fmt.Sprintf("%d", rl.limit))
c.Header("X-RateLimit-Remaining", fmt.Sprintf("%d", remaining))
c.Header("X-RateLimit-Reset", fmt.Sprintf("%d", client.resetTime.Unix()))
c.Next()
}
}
// cleanup removes expired client entries
func (rl *RateLimiter) cleanup() {
ticker := time.NewTicker(time.Minute)
defer ticker.Stop()
for range ticker.C {
rl.mutex.Lock()
now := time.Now()
for ip, client := range rl.clients {
if now.After(client.resetTime.Add(rl.window)) {
delete(rl.clients, ip)
}
}
rl.mutex.Unlock()
}
}
// RateLimitConfig holds configuration for different endpoint types
type RateLimitConfig struct {
AuthRequests int // requests per window for auth endpoints
GeneralRequests int // requests per window for general endpoints
Window time.Duration // time window for rate limiting
}
// DefaultRateLimitConfig returns sensible defaults
func DefaultRateLimitConfig() RateLimitConfig {
return RateLimitConfig{
AuthRequests: 5, // 5 login attempts per minute
GeneralRequests: 100, // 100 requests per minute
Window: time.Minute,
}
}
// RateLimit creates rate limiters for different endpoint types
func RateLimit(config RateLimitConfig) map[string]*RateLimiter {
return map[string]*RateLimiter{
"auth": NewRateLimiter(config.AuthRequests, config.Window),
"general": NewRateLimiter(config.GeneralRequests, config.Window),
}
}
// AuthRateLimit applies stricter rate limiting to authentication endpoints
func AuthRateLimit(limiter *RateLimiter) gin.HandlerFunc {
return limiter.Middleware()
}
// GeneralRateLimit applies standard rate limiting to general endpoints
func GeneralRateLimit(limiter *RateLimiter) gin.HandlerFunc {
return limiter.Middleware()
}
+235
View File
@@ -0,0 +1,235 @@
package middleware
import (
"fmt"
"os"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/trackeep/backend/models"
)
// SessionData represents the structure of session data stored in Redis
type SessionData struct {
UserID uint `json:"user_id"`
Email string `json:"email"`
Username string `json:"username"`
Role string `json:"role"`
SessionID string `json:"session_id"`
IPAddress string `json:"ip_address"`
UserAgent string `json:"user_agent"`
CreatedAt time.Time `json:"created_at"`
LastActive time.Time `json:"last_active"`
}
// SessionStore interface for session storage
type SessionStore interface {
CreateSession(sessionData *SessionData) error
GetSession(sessionID string) (*SessionData, error)
UpdateSession(sessionID string, sessionData *SessionData) error
DeleteSession(sessionID string) error
CleanupExpiredSessions() error
}
// RedisSessionStore implements SessionStore using Redis (or fallback to memory)
type RedisSessionStore struct {
sessions map[string]*SessionData // Fallback in-memory store
}
// NewSessionStore creates a new session store
func NewSessionStore() SessionStore {
return &RedisSessionStore{
sessions: make(map[string]*SessionData),
}
}
// CreateSession creates a new session
func (r *RedisSessionStore) CreateSession(sessionData *SessionData) error {
sessionData.CreatedAt = time.Now()
sessionData.LastActive = time.Now()
r.sessions[sessionData.SessionID] = sessionData
return nil
}
// GetSession retrieves a session by ID
func (r *RedisSessionStore) GetSession(sessionID string) (*SessionData, error) {
if session, exists := r.sessions[sessionID]; exists {
// Update last active time
session.LastActive = time.Now()
return session, nil
}
return nil, fmt.Errorf("session not found")
}
// UpdateSession updates an existing session
func (r *RedisSessionStore) UpdateSession(sessionID string, sessionData *SessionData) error {
if _, exists := r.sessions[sessionID]; exists {
sessionData.LastActive = time.Now()
r.sessions[sessionID] = sessionData
return nil
}
return fmt.Errorf("session not found")
}
// DeleteSession removes a session
func (r *RedisSessionStore) DeleteSession(sessionID string) error {
delete(r.sessions, sessionID)
return nil
}
// CleanupExpiredSessions removes sessions older than 24 hours
func (r *RedisSessionStore) CleanupExpiredSessions() error {
now := time.Now()
for sessionID, session := range r.sessions {
if now.Sub(session.LastActive) > 24*time.Hour {
delete(r.sessions, sessionID)
}
}
return nil
}
// Global session store instance
var sessionStore SessionStore
// InitSessionStore initializes the session store
func InitSessionStore() {
sessionStore = NewSessionStore()
// Start cleanup goroutine
go func() {
ticker := time.NewTicker(1 * time.Hour)
defer ticker.Stop()
for range ticker.C {
if sessionStore != nil {
sessionStore.CleanupExpiredSessions()
}
}
}()
}
// SessionMiddleware creates and manages user sessions
func SessionMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
// Skip session management for health checks and static assets
path := c.Request.URL.Path
if path == "/health" || path == "/metrics" || strings.HasPrefix(path, "/static") {
c.Next()
return
}
// Get session ID from header or create new one
sessionID := c.GetHeader("X-Session-ID")
if sessionID == "" {
sessionID = generateSessionID()
c.Header("X-Session-ID", sessionID)
}
// Try to get existing session
session, err := sessionStore.GetSession(sessionID)
if err != nil {
// No existing session, check if user is authenticated via JWT
if user, exists := c.Get("user"); exists {
// Create session from authenticated user
if userModel, ok := user.(models.User); ok {
session = &SessionData{
SessionID: sessionID,
UserID: userModel.ID,
Email: userModel.Email,
Username: userModel.Username,
Role: userModel.Role,
IPAddress: c.ClientIP(),
UserAgent: c.GetHeader("User-Agent"),
}
sessionStore.CreateSession(session)
}
}
}
// Set session data in context
if session != nil {
c.Set("session_id", session.SessionID)
c.Set("session_user_id", session.UserID)
c.Set("session_email", session.Email)
c.Set("session_username", session.Username)
c.Set("session_role", session.Role)
}
c.Next()
}
}
// GetSessionFromContext retrieves session data from Gin context
func GetSessionFromContext(c *gin.Context) (*SessionData, error) {
if sessionID, exists := c.Get("session_id"); exists {
return sessionStore.GetSession(sessionID.(string))
}
return nil, fmt.Errorf("no session in context")
}
// GetUserIDFromSession safely gets user ID from session or context
func GetUserIDFromSession(c *gin.Context) uint {
// First try session
if session, err := GetSessionFromContext(c); err == nil {
return session.UserID
}
// Fallback to context (for demo mode or JWT)
if userID, exists := c.Get("user_id"); exists {
if id, ok := userID.(uint); ok {
return id
}
}
if userID, exists := c.Get("userID"); exists {
if id, ok := userID.(uint); ok {
return id
}
}
// Final fallback for demo mode
if os.Getenv("VITE_DEMO_MODE") == "true" {
return 1
}
return 0
}
// GetUserEmailFromSession safely gets user email from session or context
func GetUserEmailFromSession(c *gin.Context) string {
// First try session
if session, err := GetSessionFromContext(c); err == nil {
return session.Email
}
// Fallback to context
if email, exists := c.Get("user_email"); exists {
if e, ok := email.(string); ok {
return e
}
}
// Fallback for demo mode
if os.Getenv("VITE_DEMO_MODE") == "true" {
return "demo@trackeep.com"
}
return "unknown"
}
// generateSessionID generates a unique session ID
func generateSessionID() string {
return fmt.Sprintf("sess_%d_%s", time.Now().UnixNano(), "trackeep")
}
// GetSessionStore returns the global session store instance
func GetSessionStore() SessionStore {
return sessionStore
}
// CleanupSessionsOnShutdown gracefully cleans up sessions
func CleanupSessionsOnShutdown() {
if sessionStore != nil {
sessionStore.CleanupExpiredSessions()
}
}