Files
Trackeep/backend/middleware/audit.go
T
Tomas Dvorak d27cf14110 first test
2026-02-08 14:14:55 +01:00

415 lines
11 KiB
Go

package middleware
import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"os"
	"strconv"
	"strings"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/trackeep/backend/config"
	"github.com/trackeep/backend/models"
)
// AuditMiddleware returns a Gin handler that records an audit log entry for
// each request after the downstream handlers have run.
//
// No entry is written when:
//   - shouldSkipAudit matches the request path, or
//   - GetUserIDFromSession returns 0 and GetUserEmailFromSession returns
//     "unknown" (treated as an unauthenticated request).
//
// Entries are persisted asynchronously via saveAuditLog.
func AuditMiddleware() gin.HandlerFunc {
	// generateDetails only records request bodies shorter than 1024 bytes, so
	// there is no reason to buffer more than that in memory. Capturing exactly
	// this many bytes keeps the "too large to log" decision unchanged: a body of
	// 1024 bytes or more yields a 1024-byte capture, which is not logged.
	const bodyCaptureLimit = 1024

	return func(c *gin.Context) {
		startTime := time.Now()

		// Decide up front so excluded endpoints (health probes, static assets)
		// do not pay for body capture.
		if shouldSkipAudit(c.Request.URL.Path) {
			c.Next()
			return
		}

		// Capture a bounded prefix of the body for every non-GET request; the
		// handler still receives the complete body.
		var requestBody []byte
		if c.Request.Method != "GET" && c.Request.Body != nil {
			requestBody = captureBodyPrefix(c, bodyCaptureLimit)
		}

		c.Next()

		userID := GetUserIDFromSession(c)
		userEmail := GetUserEmailFromSession(c)
		if userID == 0 && userEmail == "unknown" {
			// Unauthenticated request: nothing to attribute the entry to.
			return
		}

		clientIP := c.ClientIP()
		userAgent := c.Request.UserAgent()
		path := c.Request.URL.Path
		status := c.Writer.Status()

		auditLog := &models.AuditLog{
			UserID:      userID,
			UserEmail:   userEmail,
			UserIP:      clientIP,
			UserAgent:   userAgent,
			Action:      getActionFromMethodAndPath(c.Request.Method, path),
			Resource:    getResourceFromPath(path),
			ResourceID:  getResourceIDFromPath(path),
			Description: generateDescription(c, startTime),
			Details:     generateDetails(c, requestBody, startTime),
			Success:     status < 400,
			SessionID:   getSessionID(c),
			Country:     getCountryFromIP(clientIP),
			Device:      getDeviceFromUserAgent(userAgent),
			Platform:    getPlatformFromUserAgent(userAgent),
			Browser:     getBrowserFromUserAgent(userAgent),
			RiskLevel:   assessRisk(c, startTime),
		}
		if !auditLog.Success {
			auditLog.FailureReason = getFailureReason(status)
		}

		go saveAuditLog(auditLog)
	}
}

// captureBodyPrefix reads at most limit bytes from the request body and then
// replaces the body with a reader that replays those bytes followed by the
// unread remainder. Downstream handlers therefore see the complete, unmodified
// stream, and closing the new body closes the original one.
//
// A read error is deliberately not surfaced here: any bytes read before the
// error are still replayed, and the handler hits the same error when it reads
// the remainder of the original body.
func captureBodyPrefix(c *gin.Context, limit int64) []byte {
	original := c.Request.Body
	prefix, _ := io.ReadAll(io.LimitReader(original, limit))
	c.Request.Body = struct {
		io.Reader
		io.Closer
	}{
		Reader: io.MultiReader(bytes.NewReader(prefix), original),
		Closer: original,
	}
	return prefix
}
// LogSecurityEvent logs security-related events
func LogSecurityEvent(userID uint, userEmail, action, description, failureReason string, details map[string]interface{}) {
auditLog := &models.AuditLog{
UserID: userID,
UserEmail: userEmail,
Action: models.AuditAction(action),
Resource: models.AuditResourceSecurity,
Description: description,
Details: details,
Success: failureReason == "",
FailureReason: failureReason,
RiskLevel: assessSecurityRisk(action, failureReason),
Suspicious: isSuspiciousActivity(action, failureReason),
CreatedAt: time.Now(),
}
go saveAuditLog(auditLog)
}
// LogUserAction records a successful action that user performed on resource.
//
// resourceID may be nil when no specific record is involved. oldValues and
// newValues capture the state before and after the change. The risk level
// comes from assessActionRisk, and the entry is saved asynchronously via
// saveAuditLog.
func LogUserAction(user models.User, action models.AuditAction, resource models.AuditResource, resourceID *uint, description string, oldValues, newValues map[string]interface{}) {
	risk := assessActionRisk(action, resource)

	entry := models.AuditLog{
		CreatedAt:   time.Now(),
		UserID:      user.ID,
		UserEmail:   user.Email,
		Resource:    resource,
		ResourceID:  resourceID,
		Action:      action,
		Description: description,
		OldValues:   oldValues,
		NewValues:   newValues,
		Success:     true,
		RiskLevel:   risk,
	}

	go saveAuditLog(&entry)
}
// Helper functions
// shouldSkipAudit reports whether requests to path are excluded from audit
// logging. The excluded paths are health and metrics probes, the demo status
// endpoint, the favicon and static assets.
//
// Entries that end in "/" (such as "/assets/") match any path beneath them.
// Other entries match the exact path or any sub-path, but not a sibling path
// that merely shares the same characters. For example, "/health/live" is
// skipped but "/healthcare" is not.
func shouldSkipAudit(path string) bool {
	skipPaths := []string{
		"/health",
		"/metrics",
		"/api/demo/status",
		"/favicon.ico",
		"/assets/",
	}

	for _, skipPath := range skipPaths {
		if strings.HasSuffix(skipPath, "/") {
			if strings.HasPrefix(path, skipPath) {
				return true
			}
			continue
		}
		// Require a segment boundary so unrelated routes are still audited.
		if path == skipPath || strings.HasPrefix(path, skipPath+"/") {
			return true
		}
	}
	return false
}
// getUintFromInterface returns value as a uint when its dynamic type is
// exactly uint. For anything else, including other integer types and nil, it
// returns 0.
func getUintFromInterface(value interface{}) uint {
	// A failed comma-ok assertion yields the zero value, which is the fallback.
	n, _ := value.(uint)
	return n
}
// getUserEmail returns the Email of user when it holds a models.User value.
// Otherwise it returns "unknown"; note that a *models.User does not match.
func getUserEmail(user interface{}) string {
	u, ok := user.(models.User)
	if !ok {
		return "unknown"
	}
	return u.Email
}
// getActionFromMethodAndPath maps an HTTP method to an audit action.
//
//   - GET is a read.
//   - PUT and PATCH are updates.
//   - DELETE is a delete.
//   - POST is refined by path keywords (login, logout, upload, export, import,
//     checked in that order) and otherwise counts as a create.
//   - Any other method is reported as a generic access.
func getActionFromMethodAndPath(method, path string) models.AuditAction {
	switch method {
	case "GET":
		return models.AuditActionRead
	case "PUT", "PATCH":
		return models.AuditActionUpdate
	case "DELETE":
		return models.AuditActionDelete
	case "POST":
		switch {
		case strings.Contains(path, "/login"):
			return models.AuditActionLogin
		case strings.Contains(path, "/logout"):
			return models.AuditActionLogout
		case strings.Contains(path, "/upload"):
			return models.AuditActionUpload
		case strings.Contains(path, "/export"):
			return models.AuditActionExport
		case strings.Contains(path, "/import"):
			return models.AuditActionImport
		default:
			return models.AuditActionCreate
		}
	}
	return models.AuditActionAccess
}
// getResourceFromPath classifies a request path into an audit resource by
// looking for known path fragments.
//
// Fragments are checked in the priority order listed below, and the first one
// found anywhere in path wins. Paths matching none of them are attributed to
// the system resource.
func getResourceFromPath(path string) models.AuditResource {
	rules := []struct {
		fragment string
		resource models.AuditResource
	}{
		{"/users", models.AuditResourceUser},
		{"/notes", models.AuditResourceNote},
		{"/files", models.AuditResourceFile},
		{"/bookmarks", models.AuditResourceBookmark},
		{"/tasks", models.AuditResourceTask},
		{"/time-entries", models.AuditResourceTimeEntry},
		{"/integrations", models.AuditResourceIntegration},
		{"/teams", models.AuditResourceTeam},
		{"/goals", models.AuditResourceGoal},
		{"/habits", models.AuditResourceGoal}, // habits are filed under goals
		{"/calendar", models.AuditResourceCalendar},
		{"/search", models.AuditResourceSearch},
		{"/ai", models.AuditResourceAI},
		{"/analytics", models.AuditResourceAnalytics},
		{"/auth", models.AuditResourceSecurity},
	}

	for _, rule := range rules {
		if strings.Contains(path, rule.fragment) {
			return rule.resource
		}
	}
	return models.AuditResourceSystem
}
// getResourceIDFromPath returns the first path segment, after the leading
// one, that consists entirely of a base-10 unsigned integer. For example, it
// returns 42 for "/api/v1/notes/42".
//
// Segments that merely start with a digit, such as "2fa", are not treated as
// IDs, and neither are values that overflow uint. It returns nil when no
// segment qualifies.
func getResourceIDFromPath(path string) *uint {
	for i, segment := range strings.Split(path, "/") {
		// Index 0 is never considered. For absolute paths it is the empty
		// string before the first slash.
		if i == 0 || segment == "" {
			continue
		}
		// ParseUint (base 10, uint-sized) rejects signs, trailing non-digits
		// and overflow, so only whole numeric segments are accepted.
		// fmt.Sscanf("%d") would accept the leading "2" of "2fa".
		if n, err := strconv.ParseUint(segment, 10, 0); err == nil {
			id := uint(n)
			return &id
		}
	}
	return nil
}
// generateDescription returns a one-line summary of the request in the form
// "METHOD /path - STATUS (elapsed)". The elapsed time since startTime is
// rounded to the nearest millisecond.
func generateDescription(c *gin.Context, startTime time.Time) string {
	elapsed := time.Since(startTime).Round(time.Millisecond)
	req := c.Request
	return fmt.Sprintf("%s %s - %d (%v)", req.Method, req.URL.Path, c.Writer.Status(), elapsed)
}
// generateDetails builds the structured Details payload for an HTTP audit
// entry. It contains:
//   - the method and path;
//   - the query string, with sensitive parameter values redacted;
//   - the status code, duration in milliseconds and response size;
//   - the request body, but only when it is a non-empty JSON object shorter
//     than 1024 bytes, with sensitive fields redacted by sanitizeJSON.
func generateDetails(c *gin.Context, requestBody []byte, startTime time.Time) map[string]interface{} {
	details := map[string]interface{}{
		"method": c.Request.Method,
		"path":   c.Request.URL.Path,
		// Query strings routinely carry tokens and keys (e.g. ?access_token=),
		// so they get the same redaction as request bodies.
		"query":         redactQueryString(c.Request.URL.RawQuery),
		"status_code":   c.Writer.Status(),
		"duration_ms":   time.Since(startTime).Milliseconds(),
		"response_size": c.Writer.Size(),
	}

	// Only small request bodies are worth storing in the audit log.
	if n := len(requestBody); n > 0 && n < 1024 {
		var jsonBody map[string]interface{}
		if err := json.Unmarshal(requestBody, &jsonBody); err == nil {
			sanitizeJSON(jsonBody)
			details["request_body"] = jsonBody
		}
	}
	return details
}

// redactQueryString returns rawQuery with the value of each sensitive
// parameter replaced by "[REDACTED]". A parameter is sensitive when its name
// contains, case-insensitively, "password", "token", "secret", "key" or
// "authorization".
//
// Parameter order and separators are otherwise preserved. The substring list
// mirrors the one sanitizeJSON applies to request bodies; keep them in sync.
//
// NOTE(review): names are compared as they appear in the raw query, so a
// percent-encoded parameter name is not decoded before the check.
func redactQueryString(rawQuery string) string {
	if rawQuery == "" {
		return rawQuery
	}
	sensitive := []string{"password", "token", "secret", "key", "authorization"}
	pairs := strings.Split(rawQuery, "&")
	for i, pair := range pairs {
		name, _, hasValue := strings.Cut(pair, "=")
		if !hasValue {
			continue
		}
		lowerName := strings.ToLower(name)
		for _, s := range sensitive {
			if strings.Contains(lowerName, s) {
				pairs[i] = name + "=[REDACTED]"
				break
			}
		}
	}
	return strings.Join(pairs, "&")
}
// sanitizeJSON redacts sensitive values in a decoded JSON object, in place.
//
// A value is replaced with "[REDACTED]" when its key contains,
// case-insensitively, "password", "token", "secret", "key" or
// "authorization". Objects nested at any depth are sanitized too, including
// objects inside arrays. The original implementation skipped arrays, so a
// payload like {"users":[{"password":"x"}]} leaked the password.
func sanitizeJSON(data map[string]interface{}) {
	sensitiveFields := []string{"password", "token", "secret", "key", "authorization"}
	for key, value := range data {
		keyLower := strings.ToLower(key)
		redacted := false
		for _, sensitive := range sensitiveFields {
			if strings.Contains(keyLower, sensitive) {
				// Overwriting an existing key while ranging is permitted in Go.
				data[key] = "[REDACTED]"
				redacted = true
				break
			}
		}
		if !redacted {
			sanitizeJSONValue(value)
		}
	}
}

// sanitizeJSONValue walks one decoded JSON value and sanitizes every object
// it contains, descending through arrays. Scalars are left untouched.
func sanitizeJSONValue(value interface{}) {
	switch v := value.(type) {
	case map[string]interface{}:
		sanitizeJSON(v)
	case []interface{}:
		for _, elem := range v {
			sanitizeJSONValue(elem)
		}
	}
}
// getSessionID returns the value of the X-Session-ID request header, or ""
// when the header is absent.
func getSessionID(c *gin.Context) string {
	// Only the header is consulted today. JWT claims or a session cookie
	// would be natural additional sources.
	return c.GetHeader("X-Session-ID")
}
// getCountryFromIP is a placeholder for GeoIP resolution. It ignores ip and
// always reports "unknown".
func getCountryFromIP(ip string) string {
	// TODO: resolve ip through a GeoIP service or database in production.
	const unresolved = "unknown"
	return unresolved
}
// getDeviceFromUserAgent classifies a User-Agent string as "tablet", "mobile"
// or "desktop" (the fallback).
//
// Tablet markers are checked first because iPad user agents also contain the
// "Mobile" token. With the original order, iPads were reported as mobile.
func getDeviceFromUserAgent(userAgent string) string {
	switch {
	case strings.Contains(userAgent, "Tablet"), strings.Contains(userAgent, "iPad"):
		return "tablet"
	case strings.Contains(userAgent, "Mobile"):
		return "mobile"
	default:
		return "desktop"
	}
}
// getPlatformFromUserAgent classifies a User-Agent string as "windows",
// "android", "ios", "macos", "linux" or "unknown".
//
// Order matters:
//   - Android user agents contain "Linux", so Android is checked before Linux.
//   - iOS user agents contain "like Mac OS X" (and rarely the literal "iOS"),
//     so the iPhone/iPad/iPod tokens are checked before Mac.
//
// With the original order, the android and ios results were unreachable for
// real browsers.
func getPlatformFromUserAgent(userAgent string) string {
	has := func(tokens ...string) bool {
		for _, t := range tokens {
			if strings.Contains(userAgent, t) {
				return true
			}
		}
		return false
	}

	switch {
	case has("Windows"):
		return "windows"
	case has("Android"):
		return "android"
	case has("iPhone", "iPad", "iPod", "iOS"):
		return "ios"
	case has("Mac"):
		return "macos"
	case has("Linux"):
		return "linux"
	default:
		return "unknown"
	}
}
// getBrowserFromUserAgent classifies a User-Agent string as "edge",
// "firefox", "chrome", "safari" or "unknown".
//
// The most specific tokens are checked first:
//   - Edge user agents also contain "Chrome" and "Safari".
//   - Chrome on iOS (CriOS) and Firefox on iOS (FxiOS) contain "Safari" but
//     not "Chrome" or "Firefox".
//
// With the original order, "edge" was never returned.
func getBrowserFromUserAgent(userAgent string) string {
	has := func(tokens ...string) bool {
		for _, t := range tokens {
			if strings.Contains(userAgent, t) {
				return true
			}
		}
		return false
	}

	switch {
	case has("Edg/", "Edge/", "EdgA/", "EdgiOS/"):
		return "edge"
	case has("Firefox", "FxiOS"):
		return "firefox"
	case has("Chrome", "CriOS"):
		return "chrome"
	case has("Safari"):
		return "safari"
	default:
		return "unknown"
	}
}
// assessRisk rates an HTTP request's risk level after it has been handled.
//
//   - "high": any path containing "/admin".
//   - "medium": a failed (status >= 400) request to an "/auth" path, a DELETE,
//     or a request that took longer than five seconds.
//   - "low": everything else.
func assessRisk(c *gin.Context, startTime time.Time) string {
	elapsed := time.Since(startTime)
	reqPath := c.Request.URL.Path

	if strings.Contains(reqPath, "/admin") {
		return "high"
	}

	failedAuth := strings.Contains(reqPath, "/auth") && c.Writer.Status() >= 400
	isDelete := c.Request.Method == "DELETE"
	slow := elapsed > 5*time.Second
	if failedAuth || isDelete || slow {
		return "medium"
	}
	return "low"
}
// assessSecurityRisk rates a security event.
//
//   - "high": any event with a failure reason, and the "disable" and "delete"
//     actions.
//   - "medium": a "login_failed" action with no failure reason.
//   - "low": everything else.
func assessSecurityRisk(action, failureReason string) string {
	switch {
	case failureReason != "":
		return "high"
	case action == "disable", action == "delete":
		return "high"
	case action == "login_failed":
		return "medium"
	}
	return "low"
}
// assessActionRisk rates a user action as "medium" when it is a delete or
// touches the security resource, and "low" otherwise.
func assessActionRisk(action models.AuditAction, resource models.AuditResource) string {
	if action == models.AuditActionDelete || resource == models.AuditResourceSecurity {
		return "medium"
	}
	return "low"
}
// isSuspiciousActivity flags an event as suspicious in two cases:
//   - a "login_failed" action whose failure reason is "too_many_attempts";
//   - any event whose failure reason mentions "suspicious".
func isSuspiciousActivity(action, failureReason string) bool {
	repeatedLoginFailure := action == "login_failed" && failureReason == "too_many_attempts"
	return repeatedLoginFailure || strings.Contains(failureReason, "suspicious")
}
// getFailureReason maps an HTTP status code to a short failure reason.
//
//   - 400 "bad_request", 401 "unauthorized", 403 "forbidden",
//     404 "not_found" and 429 "rate_limited".
//   - Any 5xx code is "server_error". Previously only 500 was, so 502, 503
//     and 504 were reported as "unknown".
//   - Every other code is "unknown".
func getFailureReason(statusCode int) string {
	switch statusCode {
	case 400:
		return "bad_request"
	case 401:
		return "unauthorized"
	case 403:
		return "forbidden"
	case 404:
		return "not_found"
	case 429:
		return "rate_limited"
	}
	if statusCode >= 500 && statusCode <= 599 {
		return "server_error"
	}
	return "unknown"
}
// saveAuditLog persists auditLog to the database. It does nothing when the
// VITE_DEMO_MODE environment variable is "true" or when config.GetDB returns
// nil.
//
// Every caller in this file runs it in its own goroutine, so there is nobody
// to return an error to. Write failures are therefore reported on stderr
// instead of being silently discarded.
func saveAuditLog(auditLog *models.AuditLog) {
	if os.Getenv("VITE_DEMO_MODE") == "true" {
		return
	}

	db := config.GetDB()
	if db == nil {
		return
	}

	// NOTE(review): assumes config.GetDB returns a *gorm.DB, whose Create
	// reports failures through the Error field of the returned handle.
	if err := db.Create(auditLog).Error; err != nil {
		fmt.Fprintf(os.Stderr, "audit: saving %v/%v entry for user %d: %v\n",
			auditLog.Action, auditLog.Resource, auditLog.UserID, err)
	}
}