mirror of
https://github.com/Dvorinka/Trackeep.git
synced 2026-06-03 20:12:58 +00:00
first test
This commit is contained in:
@@ -0,0 +1,414 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/trackeep/backend/config"
|
||||
"github.com/trackeep/backend/models"
|
||||
)
|
||||
|
||||
// AuditMiddleware creates audit logs for HTTP requests
|
||||
func AuditMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
startTime := time.Now()
|
||||
|
||||
// Read request body for logging (only for POST/PUT/PATCH)
|
||||
var requestBody []byte
|
||||
if c.Request.Method != "GET" && c.Request.Body != nil {
|
||||
requestBody, _ = io.ReadAll(c.Request.Body)
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||
}
|
||||
|
||||
// Process the request
|
||||
c.Next()
|
||||
|
||||
// Skip audit logging for certain endpoints
|
||||
if shouldSkipAudit(c.Request.URL.Path) {
|
||||
return
|
||||
}
|
||||
|
||||
// Create audit log entry with proper user data from session
|
||||
userIDValue := GetUserIDFromSession(c)
|
||||
userEmail := GetUserEmailFromSession(c)
|
||||
|
||||
// Ensure we have valid user data before creating audit log
|
||||
if userIDValue == 0 && userEmail == "unknown" {
|
||||
// Skip audit logging for unauthenticated requests
|
||||
return
|
||||
}
|
||||
|
||||
auditLog := &models.AuditLog{
|
||||
UserID: userIDValue,
|
||||
UserEmail: userEmail,
|
||||
UserIP: c.ClientIP(),
|
||||
UserAgent: c.Request.UserAgent(),
|
||||
Action: getActionFromMethodAndPath(c.Request.Method, c.Request.URL.Path),
|
||||
Resource: getResourceFromPath(c.Request.URL.Path),
|
||||
ResourceID: getResourceIDFromPath(c.Request.URL.Path),
|
||||
Description: generateDescription(c, startTime),
|
||||
Details: generateDetails(c, requestBody, startTime),
|
||||
Success: c.Writer.Status() < 400,
|
||||
SessionID: getSessionID(c),
|
||||
Country: getCountryFromIP(c.ClientIP()),
|
||||
Device: getDeviceFromUserAgent(c.Request.UserAgent()),
|
||||
Platform: getPlatformFromUserAgent(c.Request.UserAgent()),
|
||||
Browser: getBrowserFromUserAgent(c.Request.UserAgent()),
|
||||
RiskLevel: assessRisk(c, startTime),
|
||||
}
|
||||
|
||||
// Set failure reason if request failed
|
||||
if !auditLog.Success {
|
||||
auditLog.FailureReason = getFailureReason(c.Writer.Status())
|
||||
}
|
||||
|
||||
// Save audit log asynchronously
|
||||
go saveAuditLog(auditLog)
|
||||
}
|
||||
}
|
||||
|
||||
// LogSecurityEvent logs security-related events
|
||||
func LogSecurityEvent(userID uint, userEmail, action, description, failureReason string, details map[string]interface{}) {
|
||||
auditLog := &models.AuditLog{
|
||||
UserID: userID,
|
||||
UserEmail: userEmail,
|
||||
Action: models.AuditAction(action),
|
||||
Resource: models.AuditResourceSecurity,
|
||||
Description: description,
|
||||
Details: details,
|
||||
Success: failureReason == "",
|
||||
FailureReason: failureReason,
|
||||
RiskLevel: assessSecurityRisk(action, failureReason),
|
||||
Suspicious: isSuspiciousActivity(action, failureReason),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
go saveAuditLog(auditLog)
|
||||
}
|
||||
|
||||
// LogUserAction logs user-specific actions
|
||||
func LogUserAction(user models.User, action models.AuditAction, resource models.AuditResource, resourceID *uint, description string, oldValues, newValues map[string]interface{}) {
|
||||
auditLog := &models.AuditLog{
|
||||
UserID: user.ID,
|
||||
UserEmail: user.Email,
|
||||
Action: action,
|
||||
Resource: resource,
|
||||
ResourceID: resourceID,
|
||||
Description: description,
|
||||
OldValues: oldValues,
|
||||
NewValues: newValues,
|
||||
Success: true,
|
||||
RiskLevel: assessActionRisk(action, resource),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
go saveAuditLog(auditLog)
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
func shouldSkipAudit(path string) bool {
|
||||
skipPaths := []string{
|
||||
"/health",
|
||||
"/metrics",
|
||||
"/api/demo/status",
|
||||
"/favicon.ico",
|
||||
"/assets/",
|
||||
}
|
||||
|
||||
for _, skipPath := range skipPaths {
|
||||
if strings.HasPrefix(path, skipPath) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func getUintFromInterface(value interface{}) uint {
|
||||
if v, ok := value.(uint); ok {
|
||||
return v
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func getUserEmail(user interface{}) string {
|
||||
if u, ok := user.(models.User); ok {
|
||||
return u.Email
|
||||
}
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
func getActionFromMethodAndPath(method, path string) models.AuditAction {
|
||||
switch method {
|
||||
case "GET":
|
||||
return models.AuditActionRead
|
||||
case "POST":
|
||||
if strings.Contains(path, "/login") {
|
||||
return models.AuditActionLogin
|
||||
} else if strings.Contains(path, "/logout") {
|
||||
return models.AuditActionLogout
|
||||
} else if strings.Contains(path, "/upload") {
|
||||
return models.AuditActionUpload
|
||||
} else if strings.Contains(path, "/export") {
|
||||
return models.AuditActionExport
|
||||
} else if strings.Contains(path, "/import") {
|
||||
return models.AuditActionImport
|
||||
}
|
||||
return models.AuditActionCreate
|
||||
case "PUT", "PATCH":
|
||||
return models.AuditActionUpdate
|
||||
case "DELETE":
|
||||
return models.AuditActionDelete
|
||||
default:
|
||||
return models.AuditActionAccess
|
||||
}
|
||||
}
|
||||
|
||||
func getResourceFromPath(path string) models.AuditResource {
|
||||
if strings.Contains(path, "/users") {
|
||||
return models.AuditResourceUser
|
||||
} else if strings.Contains(path, "/notes") {
|
||||
return models.AuditResourceNote
|
||||
} else if strings.Contains(path, "/files") {
|
||||
return models.AuditResourceFile
|
||||
} else if strings.Contains(path, "/bookmarks") {
|
||||
return models.AuditResourceBookmark
|
||||
} else if strings.Contains(path, "/tasks") {
|
||||
return models.AuditResourceTask
|
||||
} else if strings.Contains(path, "/time-entries") {
|
||||
return models.AuditResourceTimeEntry
|
||||
} else if strings.Contains(path, "/integrations") {
|
||||
return models.AuditResourceIntegration
|
||||
} else if strings.Contains(path, "/teams") {
|
||||
return models.AuditResourceTeam
|
||||
} else if strings.Contains(path, "/goals") || strings.Contains(path, "/habits") {
|
||||
return models.AuditResourceGoal
|
||||
} else if strings.Contains(path, "/calendar") {
|
||||
return models.AuditResourceCalendar
|
||||
} else if strings.Contains(path, "/search") {
|
||||
return models.AuditResourceSearch
|
||||
} else if strings.Contains(path, "/ai") {
|
||||
return models.AuditResourceAI
|
||||
} else if strings.Contains(path, "/analytics") {
|
||||
return models.AuditResourceAnalytics
|
||||
} else if strings.Contains(path, "/auth") {
|
||||
return models.AuditResourceSecurity
|
||||
}
|
||||
return models.AuditResourceSystem
|
||||
}
|
||||
|
||||
func getResourceIDFromPath(path string) *uint {
|
||||
parts := strings.Split(path, "/")
|
||||
for i, part := range parts {
|
||||
if part == "" {
|
||||
continue
|
||||
}
|
||||
// Check if this part looks like a numeric ID
|
||||
if i > 0 && len(part) >= 1 && part[0] >= '0' && part[0] <= '9' {
|
||||
var id uint
|
||||
if _, err := fmt.Sscanf(part, "%d", &id); err == nil {
|
||||
return &id
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func generateDescription(c *gin.Context, startTime time.Time) string {
|
||||
duration := time.Since(startTime)
|
||||
method := c.Request.Method
|
||||
path := c.Request.URL.Path
|
||||
status := c.Writer.Status()
|
||||
|
||||
return fmt.Sprintf("%s %s - %d (%v)", method, path, status, duration.Round(time.Millisecond))
|
||||
}
|
||||
|
||||
func generateDetails(c *gin.Context, requestBody []byte, startTime time.Time) map[string]interface{} {
|
||||
details := make(map[string]interface{})
|
||||
|
||||
details["method"] = c.Request.Method
|
||||
details["path"] = c.Request.URL.Path
|
||||
details["query"] = c.Request.URL.RawQuery
|
||||
details["status_code"] = c.Writer.Status()
|
||||
details["duration_ms"] = time.Since(startTime).Milliseconds()
|
||||
details["response_size"] = c.Writer.Size()
|
||||
|
||||
if len(requestBody) > 0 && len(requestBody) < 1024 { // Only log small request bodies
|
||||
var jsonBody map[string]interface{}
|
||||
if err := json.Unmarshal(requestBody, &jsonBody); err == nil {
|
||||
// Remove sensitive fields
|
||||
sanitizeJSON(jsonBody)
|
||||
details["request_body"] = jsonBody
|
||||
}
|
||||
}
|
||||
|
||||
return details
|
||||
}
|
||||
|
||||
func sanitizeJSON(data map[string]interface{}) {
|
||||
sensitiveFields := []string{"password", "token", "secret", "key", "authorization"}
|
||||
|
||||
for key, value := range data {
|
||||
keyLower := strings.ToLower(key)
|
||||
for _, sensitive := range sensitiveFields {
|
||||
if strings.Contains(keyLower, sensitive) {
|
||||
data[key] = "[REDACTED]"
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Recursively sanitize nested objects
|
||||
if nested, ok := value.(map[string]interface{}); ok {
|
||||
sanitizeJSON(nested)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func getSessionID(c *gin.Context) string {
|
||||
// Try to get session ID from various sources
|
||||
if sessionID := c.GetHeader("X-Session-ID"); sessionID != "" {
|
||||
return sessionID
|
||||
}
|
||||
|
||||
// You could also get from JWT claims or cookie
|
||||
return ""
|
||||
}
|
||||
|
||||
func getCountryFromIP(ip string) string {
|
||||
// This is a placeholder - in production, you'd use a GeoIP service
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
func getDeviceFromUserAgent(userAgent string) string {
|
||||
if strings.Contains(userAgent, "Mobile") {
|
||||
return "mobile"
|
||||
} else if strings.Contains(userAgent, "Tablet") {
|
||||
return "tablet"
|
||||
}
|
||||
return "desktop"
|
||||
}
|
||||
|
||||
func getPlatformFromUserAgent(userAgent string) string {
|
||||
if strings.Contains(userAgent, "Windows") {
|
||||
return "windows"
|
||||
} else if strings.Contains(userAgent, "Mac") {
|
||||
return "macos"
|
||||
} else if strings.Contains(userAgent, "Linux") {
|
||||
return "linux"
|
||||
} else if strings.Contains(userAgent, "Android") {
|
||||
return "android"
|
||||
} else if strings.Contains(userAgent, "iOS") {
|
||||
return "ios"
|
||||
}
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
func getBrowserFromUserAgent(userAgent string) string {
|
||||
if strings.Contains(userAgent, "Chrome") {
|
||||
return "chrome"
|
||||
} else if strings.Contains(userAgent, "Firefox") {
|
||||
return "firefox"
|
||||
} else if strings.Contains(userAgent, "Safari") {
|
||||
return "safari"
|
||||
} else if strings.Contains(userAgent, "Edge") {
|
||||
return "edge"
|
||||
}
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
func assessRisk(c *gin.Context, startTime time.Time) string {
|
||||
path := c.Request.URL.Path
|
||||
method := c.Request.Method
|
||||
status := c.Writer.Status()
|
||||
duration := time.Since(startTime)
|
||||
|
||||
// High risk indicators
|
||||
if strings.Contains(path, "/admin") {
|
||||
return "high"
|
||||
}
|
||||
if strings.Contains(path, "/auth") && status >= 400 {
|
||||
return "medium"
|
||||
}
|
||||
if method == "DELETE" {
|
||||
return "medium"
|
||||
}
|
||||
if duration > 5*time.Second {
|
||||
return "medium"
|
||||
}
|
||||
|
||||
return "low"
|
||||
}
|
||||
|
||||
func assessSecurityRisk(action, failureReason string) string {
|
||||
if failureReason != "" {
|
||||
return "high"
|
||||
}
|
||||
|
||||
switch action {
|
||||
case "login_failed":
|
||||
return "medium"
|
||||
case "disable", "delete":
|
||||
return "high"
|
||||
default:
|
||||
return "low"
|
||||
}
|
||||
}
|
||||
|
||||
func assessActionRisk(action models.AuditAction, resource models.AuditResource) string {
|
||||
if action == models.AuditActionDelete {
|
||||
return "medium"
|
||||
}
|
||||
if resource == models.AuditResourceSecurity {
|
||||
return "medium"
|
||||
}
|
||||
return "low"
|
||||
}
|
||||
|
||||
func isSuspiciousActivity(action, failureReason string) bool {
|
||||
// Define suspicious activity patterns
|
||||
if action == "login_failed" && failureReason == "too_many_attempts" {
|
||||
return true
|
||||
}
|
||||
if strings.Contains(failureReason, "suspicious") {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func getFailureReason(statusCode int) string {
|
||||
switch statusCode {
|
||||
case 400:
|
||||
return "bad_request"
|
||||
case 401:
|
||||
return "unauthorized"
|
||||
case 403:
|
||||
return "forbidden"
|
||||
case 404:
|
||||
return "not_found"
|
||||
case 429:
|
||||
return "rate_limited"
|
||||
case 500:
|
||||
return "server_error"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
func saveAuditLog(auditLog *models.AuditLog) {
|
||||
// Skip audit logging in demo mode
|
||||
if os.Getenv("VITE_DEMO_MODE") == "true" {
|
||||
return
|
||||
}
|
||||
|
||||
db := config.GetDB()
|
||||
if db != nil {
|
||||
db.Create(auditLog)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,156 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
// CacheConfig holds cache configuration
|
||||
type CacheConfig struct {
|
||||
Duration time.Duration
|
||||
KeyPrefix string
|
||||
Enabled bool
|
||||
RedisClient *redis.Client
|
||||
}
|
||||
|
||||
// DefaultCacheConfig returns default cache configuration
|
||||
func DefaultCacheConfig() CacheConfig {
|
||||
return CacheConfig{
|
||||
Duration: 5 * time.Minute,
|
||||
KeyPrefix: "trackeep:",
|
||||
Enabled: true,
|
||||
}
|
||||
}
|
||||
|
||||
// CacheMiddleware creates a cache middleware
|
||||
func CacheMiddleware(config CacheConfig) gin.HandlerFunc {
|
||||
if !config.Enabled || config.RedisClient == nil {
|
||||
return func(c *gin.Context) {
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
return func(c *gin.Context) {
|
||||
// Only cache GET requests
|
||||
if c.Request.Method != http.MethodGet {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// Generate cache key
|
||||
cacheKey := generateCacheKey(c, config.KeyPrefix)
|
||||
|
||||
// Try to get from cache
|
||||
cached, err := config.RedisClient.Get(context.Background(), cacheKey).Result()
|
||||
if err == nil && cached != "" {
|
||||
// Cache hit
|
||||
c.Header("X-Cache", "HIT")
|
||||
c.Header("Content-Type", "application/json")
|
||||
c.String(http.StatusOK, cached)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// Cache miss, continue with request
|
||||
c.Header("X-Cache", "MISS")
|
||||
|
||||
// Capture response
|
||||
writer := &cachedResponseWriter{
|
||||
ResponseWriter: c.Writer,
|
||||
buffer: make([]byte, 0),
|
||||
}
|
||||
c.Writer = writer
|
||||
|
||||
c.Next()
|
||||
|
||||
// Cache the response if successful
|
||||
if c.Writer.Status() == http.StatusOK && len(writer.buffer) > 0 {
|
||||
config.RedisClient.Set(
|
||||
context.Background(),
|
||||
cacheKey,
|
||||
string(writer.buffer),
|
||||
config.Duration,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// generateCacheKey creates a unique cache key for the request
|
||||
func generateCacheKey(c *gin.Context, prefix string) string {
|
||||
// Include path, query params, and user ID if available
|
||||
keyParts := []string{
|
||||
prefix,
|
||||
c.Request.URL.Path,
|
||||
c.Request.URL.RawQuery,
|
||||
}
|
||||
|
||||
// Add user ID for personalized caching
|
||||
if userID := c.GetString("userID"); userID != "" {
|
||||
keyParts = append(keyParts, "user:"+userID)
|
||||
}
|
||||
|
||||
// Create hash of the key to avoid long keys
|
||||
key := strings.Join(keyParts, ":")
|
||||
hash := md5.Sum([]byte(key))
|
||||
return fmt.Sprintf("%x", hash)
|
||||
}
|
||||
|
||||
// cachedResponseWriter captures response data for caching
|
||||
type cachedResponseWriter struct {
|
||||
gin.ResponseWriter
|
||||
buffer []byte
|
||||
}
|
||||
|
||||
func (w *cachedResponseWriter) Write(data []byte) (int, error) {
|
||||
w.buffer = append(w.buffer, data...)
|
||||
return w.ResponseWriter.Write(data)
|
||||
}
|
||||
|
||||
// InvalidateCache invalidates cache entries matching a pattern
|
||||
func InvalidateCache(redisClient *redis.Client, pattern string) error {
|
||||
if redisClient == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
keys, err := redisClient.Keys(context.Background(), pattern).Result()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(keys) > 0 {
|
||||
return redisClient.Del(context.Background(), keys...).Err()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CacheInvalidationMiddleware invalidates cache on write operations
|
||||
func CacheInvalidationMiddleware(redisClient *redis.Client) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.Next()
|
||||
|
||||
// Invalidate cache on successful write operations
|
||||
if c.Writer.Status() >= 200 && c.Writer.Status() < 300 &&
|
||||
(c.Request.Method == http.MethodPost ||
|
||||
c.Request.Method == http.MethodPut ||
|
||||
c.Request.Method == http.MethodDelete) {
|
||||
|
||||
// Invalidate user-specific cache
|
||||
if userID := c.GetString("userID"); userID != "" {
|
||||
pattern := fmt.Sprintf("trackeep:*user:%s*", userID)
|
||||
InvalidateCache(redisClient, pattern)
|
||||
}
|
||||
|
||||
// Invalidate general cache for the affected resource
|
||||
resourcePattern := fmt.Sprintf("trackeep:*%s*", c.Request.URL.Path)
|
||||
InvalidateCache(redisClient, resourcePattern)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,57 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/trackeep/backend/models"
|
||||
)
|
||||
|
||||
// DemoModeMiddleware prevents write operations when in demo mode
|
||||
func DemoModeMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// Check if demo mode is enabled
|
||||
if os.Getenv("VITE_DEMO_MODE") == "true" {
|
||||
// Allow GET requests (read operations)
|
||||
if c.Request.Method == "GET" || c.Request.Method == "OPTIONS" {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// Allow specific write operations in demo mode
|
||||
path := c.Request.URL.Path
|
||||
if (strings.Contains(path, "/learning-paths") && c.Request.Method == "POST") ||
|
||||
(strings.Contains(path, "/bookmarks/content") && c.Request.Method == "POST") ||
|
||||
(strings.Contains(path, "/bookmarks/metadata") && c.Request.Method == "POST") {
|
||||
// Set demo user for these operations
|
||||
c.Set("user", models.User{
|
||||
ID: 1,
|
||||
Username: "demo",
|
||||
Email: "demo@trackeep.com",
|
||||
})
|
||||
c.Set("user_id", uint(1))
|
||||
c.Set("userID", uint(1)) // Add this for compatibility with handlers
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// Block other write operations (POST, PUT, DELETE, PATCH)
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"error": "Write operations are disabled in demo mode",
|
||||
"message": "This is a demo instance. Database modifications are not allowed.",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// If not in demo mode, allow all operations
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// IsDemoMode returns true if demo mode is enabled
|
||||
func IsDemoMode() bool {
|
||||
return os.Getenv("VITE_DEMO_MODE") == "true"
|
||||
}
|
||||
@@ -0,0 +1,379 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// InputValidationMiddleware provides comprehensive input validation
|
||||
func InputValidationMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// Validate query parameters
|
||||
for key, values := range c.Request.URL.Query() {
|
||||
for i, value := range values {
|
||||
if containsMaliciousContent(value) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "Invalid input detected",
|
||||
"message": "Query parameter contains potentially malicious content",
|
||||
"parameter": key,
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
// Sanitize query parameters
|
||||
values[i] = sanitizeInput(value)
|
||||
}
|
||||
}
|
||||
|
||||
// For POST/PUT requests, we'll validate the body in the handler
|
||||
// since we need to know the expected structure
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateRequestBody validates JSON request bodies against common attack patterns
|
||||
func ValidateRequestBody() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if c.Request.Method == "GET" || c.Request.Method == "DELETE" {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// Read and validate the body
|
||||
bodyBytes, err := c.GetRawData()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
bodyString := string(bodyBytes)
|
||||
|
||||
// Check for common injection patterns
|
||||
if containsMaliciousContent(bodyString) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "Invalid input detected",
|
||||
"message": "Request body contains potentially malicious content",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// Restore the body for subsequent handlers
|
||||
c.Request.Body = &requestBody{body: bodyBytes}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// requestBody is a custom type to restore the request body
|
||||
type requestBody struct {
|
||||
body []byte
|
||||
pos int
|
||||
}
|
||||
|
||||
func (rb *requestBody) Read(p []byte) (int, error) {
|
||||
if rb.pos >= len(rb.body) {
|
||||
return 0, nil
|
||||
}
|
||||
n := copy(p, rb.body[rb.pos:])
|
||||
rb.pos += n
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (rb *requestBody) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// containsMaliciousContent checks for common attack patterns
|
||||
func containsMaliciousContent(input string) bool {
|
||||
// Convert to lowercase for case-insensitive matching
|
||||
lowerInput := strings.ToLower(input)
|
||||
|
||||
// SQL injection patterns
|
||||
sqlPatterns := []string{
|
||||
"union select",
|
||||
"drop table",
|
||||
"insert into",
|
||||
"delete from",
|
||||
"update set",
|
||||
"exec(",
|
||||
"execute(",
|
||||
"sp_executesql",
|
||||
"xp_cmdshell",
|
||||
"'--",
|
||||
"/*",
|
||||
"*/",
|
||||
"char(",
|
||||
"ascii(",
|
||||
"concat(",
|
||||
"substring(",
|
||||
"waitfor delay",
|
||||
"benchmark(",
|
||||
"sleep(",
|
||||
"pg_sleep(",
|
||||
}
|
||||
|
||||
for _, pattern := range sqlPatterns {
|
||||
if strings.Contains(lowerInput, pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// XSS patterns
|
||||
xssPatterns := []string{
|
||||
"<script",
|
||||
"</script>",
|
||||
"javascript:",
|
||||
"vbscript:",
|
||||
"onload=",
|
||||
"onerror=",
|
||||
"onclick=",
|
||||
"onmouseover=",
|
||||
"onfocus=",
|
||||
"onblur=",
|
||||
"onchange=",
|
||||
"onsubmit=",
|
||||
"<iframe",
|
||||
"<object",
|
||||
"<embed",
|
||||
"<form",
|
||||
"<input",
|
||||
"<link",
|
||||
"<meta",
|
||||
"<style",
|
||||
"eval(",
|
||||
"alert(",
|
||||
"confirm(",
|
||||
"prompt(",
|
||||
"document.cookie",
|
||||
"document.write",
|
||||
"window.location",
|
||||
}
|
||||
|
||||
for _, pattern := range xssPatterns {
|
||||
if strings.Contains(lowerInput, pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Command injection patterns
|
||||
cmdPatterns := []string{
|
||||
"; rm",
|
||||
"; cat",
|
||||
"; ls",
|
||||
"; ps",
|
||||
"; kill",
|
||||
"; chmod",
|
||||
"; chown",
|
||||
"; wget",
|
||||
"; curl",
|
||||
"; nc",
|
||||
"; netcat",
|
||||
"| rm",
|
||||
"| cat",
|
||||
"| ls",
|
||||
"| ps",
|
||||
"& rm",
|
||||
"& cat",
|
||||
"& ls",
|
||||
"&& rm",
|
||||
"&& cat",
|
||||
"&& ls",
|
||||
"`rm",
|
||||
"`cat",
|
||||
"`ls",
|
||||
"$(rm",
|
||||
"$(cat",
|
||||
"$(ls",
|
||||
}
|
||||
|
||||
for _, pattern := range cmdPatterns {
|
||||
if strings.Contains(lowerInput, pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Path traversal patterns
|
||||
pathPatterns := []string{
|
||||
"../",
|
||||
"..\\",
|
||||
"%2e%2e%2f",
|
||||
"%2e%2e\\",
|
||||
"..%2f",
|
||||
"..%5c",
|
||||
"%2e%2e/",
|
||||
"%2e%2e\\",
|
||||
"/etc/passwd",
|
||||
"/etc/shadow",
|
||||
"/etc/hosts",
|
||||
"windows/system32",
|
||||
"\\windows\\system32",
|
||||
}
|
||||
|
||||
for _, pattern := range pathPatterns {
|
||||
if strings.Contains(lowerInput, pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// LDAP injection patterns
|
||||
ldapPatterns := []string{
|
||||
"*)(",
|
||||
"*}",
|
||||
"*)",
|
||||
"*(|",
|
||||
"*(|",
|
||||
"*)(",
|
||||
}
|
||||
|
||||
for _, pattern := range ldapPatterns {
|
||||
if strings.Contains(lowerInput, pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// NoSQL injection patterns
|
||||
nosqlPatterns := []string{
|
||||
"$where",
|
||||
"$ne",
|
||||
"$gt",
|
||||
"$lt",
|
||||
"$regex",
|
||||
"$expr",
|
||||
"$json",
|
||||
"$or",
|
||||
"$and",
|
||||
"$not",
|
||||
}
|
||||
|
||||
for _, pattern := range nosqlPatterns {
|
||||
if strings.Contains(lowerInput, pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// sanitizeInput cleans input by removing potentially dangerous characters
|
||||
func sanitizeInput(input string) string {
|
||||
// Remove null bytes
|
||||
input = strings.ReplaceAll(input, "\x00", "")
|
||||
|
||||
// Remove control characters except newline, tab, and carriage return
|
||||
var result []rune
|
||||
for _, r := range input {
|
||||
if unicode.IsControl(r) && r != '\n' && r != '\t' && r != '\r' {
|
||||
continue
|
||||
}
|
||||
result = append(result, r)
|
||||
}
|
||||
|
||||
// Trim whitespace
|
||||
input = strings.TrimSpace(string(result))
|
||||
|
||||
return input
|
||||
}
|
||||
|
||||
// ValidateEmail validates email format
|
||||
func ValidateEmail(email string) bool {
|
||||
emailRegex := regexp.MustCompile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`)
|
||||
return emailRegex.MatchString(email)
|
||||
}
|
||||
|
||||
// ValidatePassword validates password strength
|
||||
func ValidatePassword(password string) error {
|
||||
if len(password) < 8 {
|
||||
return gin.Error{
|
||||
Err: http.ErrBodyNotAllowed,
|
||||
Type: gin.ErrorTypeBind,
|
||||
Meta: "Password must be at least 8 characters long",
|
||||
}
|
||||
}
|
||||
|
||||
var hasUpper, hasLower, hasDigit, hasSpecial bool
|
||||
for _, char := range password {
|
||||
switch {
|
||||
case unicode.IsUpper(char):
|
||||
hasUpper = true
|
||||
case unicode.IsLower(char):
|
||||
hasLower = true
|
||||
case unicode.IsDigit(char):
|
||||
hasDigit = true
|
||||
case unicode.IsPunct(char) || unicode.IsSymbol(char):
|
||||
hasSpecial = true
|
||||
}
|
||||
}
|
||||
|
||||
if !hasUpper || !hasLower || !hasDigit || !hasSpecial {
|
||||
return gin.Error{
|
||||
Err: http.ErrBodyNotAllowed,
|
||||
Type: gin.ErrorTypeBind,
|
||||
Meta: "Password must contain at least one uppercase letter, one lowercase letter, one digit, and one special character",
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateUsername validates username format
|
||||
func ValidateUsername(username string) error {
|
||||
if len(username) < 3 || len(username) > 30 {
|
||||
return gin.Error{
|
||||
Err: http.ErrBodyNotAllowed,
|
||||
Type: gin.ErrorTypeBind,
|
||||
Meta: "Username must be between 3 and 30 characters long",
|
||||
}
|
||||
}
|
||||
|
||||
usernameRegex := regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
|
||||
if !usernameRegex.MatchString(username) {
|
||||
return gin.Error{
|
||||
Err: http.ErrBodyNotAllowed,
|
||||
Type: gin.ErrorTypeBind,
|
||||
Meta: "Username can only contain letters, numbers, underscores, and hyphens",
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateID validates that ID is a positive integer
|
||||
func ValidateID(id string) error {
|
||||
idRegex := regexp.MustCompile(`^[1-9]\d*$`)
|
||||
if !idRegex.MatchString(id) {
|
||||
return gin.Error{
|
||||
Err: http.ErrBodyNotAllowed,
|
||||
Type: gin.ErrorTypeBind,
|
||||
Meta: "Invalid ID format",
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidatePagination validates pagination parameters
|
||||
func ValidatePagination(page, limit string) (int, int, error) {
|
||||
pageInt := 1
|
||||
limitInt := 20
|
||||
|
||||
if page != "" {
|
||||
if p, err := strconv.Atoi(page); err == nil && p > 0 {
|
||||
pageInt = p
|
||||
}
|
||||
}
|
||||
|
||||
if limit != "" {
|
||||
if l, err := strconv.Atoi(limit); err == nil && l > 0 && l <= 100 {
|
||||
limitInt = l
|
||||
}
|
||||
}
|
||||
|
||||
return pageInt, limitInt, nil
|
||||
}
|
||||
@@ -205,12 +205,6 @@ func logJSON(data map[string]interface{}) {
|
||||
log.Println(string(jsonData))
|
||||
}
|
||||
|
||||
// generateRequestID generates a unique request ID
|
||||
func generateRequestID() string {
|
||||
return time.Now().Format("20060102150405") + "-" +
|
||||
string(rune(time.Now().UnixNano()%1000))
|
||||
}
|
||||
|
||||
// SecurityLogger logs security-related events
|
||||
func SecurityLogger() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
|
||||
@@ -0,0 +1,223 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// MemoryCacheItem represents an item in memory cache
|
||||
type MemoryCacheItem struct {
|
||||
Data []byte
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
// MemoryCache represents an in-memory cache
|
||||
type MemoryCache struct {
|
||||
items map[string]MemoryCacheItem
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// NewMemoryCache creates a new memory cache
|
||||
func NewMemoryCache() *MemoryCache {
|
||||
cache := &MemoryCache{
|
||||
items: make(map[string]MemoryCacheItem),
|
||||
}
|
||||
|
||||
// Start cleanup goroutine
|
||||
go cache.cleanup()
|
||||
|
||||
return cache
|
||||
}
|
||||
|
||||
// Get retrieves an item from cache
|
||||
func (c *MemoryCache) Get(key string) ([]byte, bool) {
|
||||
c.mutex.RLock()
|
||||
defer c.mutex.RUnlock()
|
||||
|
||||
item, exists := c.items[key]
|
||||
if !exists || time.Now().After(item.ExpiresAt) {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return item.Data, true
|
||||
}
|
||||
|
||||
// Set stores an item in cache
|
||||
func (c *MemoryCache) Set(key string, data []byte, duration time.Duration) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
c.items[key] = MemoryCacheItem{
|
||||
Data: data,
|
||||
ExpiresAt: time.Now().Add(duration),
|
||||
}
|
||||
}
|
||||
|
||||
// Delete removes an item from cache
|
||||
func (c *MemoryCache) Delete(key string) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
delete(c.items, key)
|
||||
}
|
||||
|
||||
// DeletePattern removes items matching a pattern
|
||||
func (c *MemoryCache) DeletePattern(pattern string) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
for key := range c.items {
|
||||
if strings.Contains(key, pattern) {
|
||||
delete(c.items, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cleanup removes expired items
|
||||
func (c *MemoryCache) cleanup() {
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
c.mutex.Lock()
|
||||
now := time.Now()
|
||||
for key, item := range c.items {
|
||||
if now.After(item.ExpiresAt) {
|
||||
delete(c.items, key)
|
||||
}
|
||||
}
|
||||
c.mutex.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// Global memory cache instance
|
||||
var globalMemoryCache = NewMemoryCache()
|
||||
|
||||
// MemoryCacheConfig holds memory cache configuration
|
||||
type MemoryCacheConfig struct {
|
||||
Duration time.Duration
|
||||
KeyPrefix string
|
||||
Enabled bool
|
||||
}
|
||||
|
||||
// DefaultMemoryCacheConfig returns default memory cache configuration
|
||||
func DefaultMemoryCacheConfig() MemoryCacheConfig {
|
||||
return MemoryCacheConfig{
|
||||
Duration: 5 * time.Minute,
|
||||
KeyPrefix: "trackeep:",
|
||||
Enabled: true,
|
||||
}
|
||||
}
|
||||
|
||||
// MemoryCacheMiddleware creates a memory cache middleware
|
||||
func MemoryCacheMiddleware(config MemoryCacheConfig) gin.HandlerFunc {
|
||||
if !config.Enabled {
|
||||
return func(c *gin.Context) {
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
return func(c *gin.Context) {
|
||||
// Only cache GET requests
|
||||
if c.Request.Method != http.MethodGet {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// Generate cache key
|
||||
cacheKey := generateMemoryCacheKey(c, config.KeyPrefix)
|
||||
|
||||
// Try to get from cache
|
||||
if data, hit := globalMemoryCache.Get(cacheKey); hit {
|
||||
// Cache hit
|
||||
c.Header("X-Cache", "HIT")
|
||||
c.Header("Content-Type", "application/json")
|
||||
c.Data(http.StatusOK, "application/json", data)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// Cache miss, continue with request
|
||||
c.Header("X-Cache", "MISS")
|
||||
|
||||
// Capture response
|
||||
writer := &memoryCachedResponseWriter{
|
||||
ResponseWriter: c.Writer,
|
||||
buffer: make([]byte, 0),
|
||||
}
|
||||
c.Writer = writer
|
||||
|
||||
c.Next()
|
||||
|
||||
// Cache the response if successful
|
||||
if c.Writer.Status() == http.StatusOK && len(writer.buffer) > 0 {
|
||||
globalMemoryCache.Set(cacheKey, writer.buffer, config.Duration)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// generateMemoryCacheKey creates a unique cache key for the request
|
||||
func generateMemoryCacheKey(c *gin.Context, prefix string) string {
|
||||
// Include path, query params, and user ID if available
|
||||
keyParts := []string{
|
||||
prefix,
|
||||
c.Request.URL.Path,
|
||||
c.Request.URL.RawQuery,
|
||||
}
|
||||
|
||||
// Add user ID for personalized caching
|
||||
if userID := c.GetString("userID"); userID != "" {
|
||||
keyParts = append(keyParts, "user:"+userID)
|
||||
}
|
||||
|
||||
// Create hash of the key to avoid long keys
|
||||
key := strings.Join(keyParts, ":")
|
||||
hash := md5.Sum([]byte(key))
|
||||
return fmt.Sprintf("%x", hash)
|
||||
}
|
||||
|
||||
// memoryCachedResponseWriter captures response data for caching
|
||||
type memoryCachedResponseWriter struct {
|
||||
gin.ResponseWriter
|
||||
buffer []byte
|
||||
}
|
||||
|
||||
func (w *memoryCachedResponseWriter) Write(data []byte) (int, error) {
|
||||
w.buffer = append(w.buffer, data...)
|
||||
return w.ResponseWriter.Write(data)
|
||||
}
|
||||
|
||||
// InvalidateMemoryCache invalidates cache entries matching a pattern
|
||||
func InvalidateMemoryCache(pattern string) {
|
||||
globalMemoryCache.DeletePattern(pattern)
|
||||
}
|
||||
|
||||
// MemoryCacheInvalidationMiddleware invalidates cache on write operations
|
||||
func MemoryCacheInvalidationMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.Next()
|
||||
|
||||
// Invalidate cache on successful write operations
|
||||
if c.Writer.Status() >= 200 && c.Writer.Status() < 300 &&
|
||||
(c.Request.Method == http.MethodPost ||
|
||||
c.Request.Method == http.MethodPut ||
|
||||
c.Request.Method == http.MethodDelete) {
|
||||
|
||||
// Invalidate user-specific cache
|
||||
if userID := c.GetString("userID"); userID != "" {
|
||||
pattern := fmt.Sprintf("user:%s", userID)
|
||||
InvalidateMemoryCache(pattern)
|
||||
}
|
||||
|
||||
// Invalidate general cache for the affected resource
|
||||
resourcePattern := fmt.Sprintf("%s", c.Request.URL.Path)
|
||||
InvalidateMemoryCache(resourcePattern)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,58 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// PerformanceMiddleware adds performance monitoring to requests
|
||||
func PerformanceMiddleware() gin.HandlerFunc {
|
||||
return gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
|
||||
// Custom log format with performance metrics
|
||||
return fmt.Sprintf(
|
||||
"[%s] %s %s %d %s %s \"%s\" %s\n",
|
||||
param.TimeStamp.Format("02/Jan/2006:15:04:05 -0700"),
|
||||
param.Method,
|
||||
param.Path,
|
||||
param.StatusCode,
|
||||
param.Latency,
|
||||
param.Request.UserAgent(),
|
||||
param.ErrorMessage,
|
||||
param.Request.Header.Get("X-Request-ID"),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
// RequestIDMiddleware adds a unique request ID to each request
|
||||
func RequestIDMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
requestID := c.GetHeader("X-Request-ID")
|
||||
if requestID == "" {
|
||||
requestID = generateRequestID()
|
||||
}
|
||||
c.Set("RequestID", requestID)
|
||||
c.Header("X-Request-ID", requestID)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// generateRequestID generates a unique request ID
|
||||
func generateRequestID() string {
|
||||
return fmt.Sprintf("%d", time.Now().UnixNano())
|
||||
}
|
||||
|
||||
// SlowQueryMiddleware logs slow database queries
|
||||
func SlowQueryMiddleware(threshold time.Duration) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
start := time.Now()
|
||||
c.Next()
|
||||
duration := time.Since(start)
|
||||
|
||||
if duration > threshold {
|
||||
log.Printf("SLOW REQUEST: %s %s took %v", c.Request.Method, c.Request.URL.Path, duration)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,139 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// RateLimiter implements a simple in-memory rate limiter
|
||||
type RateLimiter struct {
|
||||
clients map[string]*ClientInfo
|
||||
mutex sync.RWMutex
|
||||
limit int
|
||||
window time.Duration
|
||||
}
|
||||
|
||||
type ClientInfo struct {
|
||||
requests int
|
||||
resetTime time.Time
|
||||
}
|
||||
|
||||
// NewRateLimiter creates a new rate limiter
|
||||
func NewRateLimiter(limit int, window time.Duration) *RateLimiter {
|
||||
limiter := &RateLimiter{
|
||||
clients: make(map[string]*ClientInfo),
|
||||
limit: limit,
|
||||
window: window,
|
||||
}
|
||||
|
||||
// Start cleanup goroutine
|
||||
go limiter.cleanup()
|
||||
|
||||
return limiter
|
||||
}
|
||||
|
||||
// Middleware returns the Gin middleware function
|
||||
func (rl *RateLimiter) Middleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
clientIP := c.ClientIP()
|
||||
|
||||
rl.mutex.Lock()
|
||||
client, exists := rl.clients[clientIP]
|
||||
|
||||
if !exists {
|
||||
client = &ClientInfo{
|
||||
requests: 0,
|
||||
resetTime: time.Now().Add(rl.window),
|
||||
}
|
||||
rl.clients[clientIP] = client
|
||||
}
|
||||
|
||||
// Reset window if expired
|
||||
if time.Now().After(client.resetTime) {
|
||||
client.requests = 0
|
||||
client.resetTime = time.Now().Add(rl.window)
|
||||
}
|
||||
|
||||
client.requests++
|
||||
|
||||
// Check if limit exceeded
|
||||
if client.requests > rl.limit {
|
||||
rl.mutex.Unlock()
|
||||
c.Header("X-RateLimit-Limit", fmt.Sprintf("%d", rl.limit))
|
||||
c.Header("X-RateLimit-Remaining", "0")
|
||||
c.Header("X-RateLimit-Reset", fmt.Sprintf("%d", client.resetTime.Unix()))
|
||||
c.JSON(http.StatusTooManyRequests, gin.H{
|
||||
"error": "Rate limit exceeded",
|
||||
"message": fmt.Sprintf("Too many requests. Limit is %d per %v", rl.limit, rl.window),
|
||||
"retry_after": time.Until(client.resetTime).Seconds(),
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
remaining := rl.limit - client.requests
|
||||
rl.mutex.Unlock()
|
||||
|
||||
// Set rate limit headers
|
||||
c.Header("X-RateLimit-Limit", fmt.Sprintf("%d", rl.limit))
|
||||
c.Header("X-RateLimit-Remaining", fmt.Sprintf("%d", remaining))
|
||||
c.Header("X-RateLimit-Reset", fmt.Sprintf("%d", client.resetTime.Unix()))
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// cleanup removes expired client entries
|
||||
func (rl *RateLimiter) cleanup() {
|
||||
ticker := time.NewTicker(time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
rl.mutex.Lock()
|
||||
now := time.Now()
|
||||
for ip, client := range rl.clients {
|
||||
if now.After(client.resetTime.Add(rl.window)) {
|
||||
delete(rl.clients, ip)
|
||||
}
|
||||
}
|
||||
rl.mutex.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// RateLimitConfig holds configuration for different endpoint types
|
||||
type RateLimitConfig struct {
|
||||
AuthRequests int // requests per window for auth endpoints
|
||||
GeneralRequests int // requests per window for general endpoints
|
||||
Window time.Duration // time window for rate limiting
|
||||
}
|
||||
|
||||
// DefaultRateLimitConfig returns sensible defaults
|
||||
func DefaultRateLimitConfig() RateLimitConfig {
|
||||
return RateLimitConfig{
|
||||
AuthRequests: 5, // 5 login attempts per minute
|
||||
GeneralRequests: 100, // 100 requests per minute
|
||||
Window: time.Minute,
|
||||
}
|
||||
}
|
||||
|
||||
// RateLimit creates rate limiters for different endpoint types
|
||||
func RateLimit(config RateLimitConfig) map[string]*RateLimiter {
|
||||
return map[string]*RateLimiter{
|
||||
"auth": NewRateLimiter(config.AuthRequests, config.Window),
|
||||
"general": NewRateLimiter(config.GeneralRequests, config.Window),
|
||||
}
|
||||
}
|
||||
|
||||
// AuthRateLimit applies stricter rate limiting to authentication endpoints
|
||||
func AuthRateLimit(limiter *RateLimiter) gin.HandlerFunc {
|
||||
return limiter.Middleware()
|
||||
}
|
||||
|
||||
// GeneralRateLimit applies standard rate limiting to general endpoints
|
||||
func GeneralRateLimit(limiter *RateLimiter) gin.HandlerFunc {
|
||||
return limiter.Middleware()
|
||||
}
|
||||
@@ -0,0 +1,235 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/trackeep/backend/models"
|
||||
)
|
||||
|
||||
// SessionData represents the structure of session data stored in Redis
|
||||
type SessionData struct {
|
||||
UserID uint `json:"user_id"`
|
||||
Email string `json:"email"`
|
||||
Username string `json:"username"`
|
||||
Role string `json:"role"`
|
||||
SessionID string `json:"session_id"`
|
||||
IPAddress string `json:"ip_address"`
|
||||
UserAgent string `json:"user_agent"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
LastActive time.Time `json:"last_active"`
|
||||
}
|
||||
|
||||
// SessionStore interface for session storage
|
||||
type SessionStore interface {
|
||||
CreateSession(sessionData *SessionData) error
|
||||
GetSession(sessionID string) (*SessionData, error)
|
||||
UpdateSession(sessionID string, sessionData *SessionData) error
|
||||
DeleteSession(sessionID string) error
|
||||
CleanupExpiredSessions() error
|
||||
}
|
||||
|
||||
// RedisSessionStore implements SessionStore using Redis (or fallback to memory)
|
||||
type RedisSessionStore struct {
|
||||
sessions map[string]*SessionData // Fallback in-memory store
|
||||
}
|
||||
|
||||
// NewSessionStore creates a new session store
|
||||
func NewSessionStore() SessionStore {
|
||||
return &RedisSessionStore{
|
||||
sessions: make(map[string]*SessionData),
|
||||
}
|
||||
}
|
||||
|
||||
// CreateSession creates a new session
|
||||
func (r *RedisSessionStore) CreateSession(sessionData *SessionData) error {
|
||||
sessionData.CreatedAt = time.Now()
|
||||
sessionData.LastActive = time.Now()
|
||||
r.sessions[sessionData.SessionID] = sessionData
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetSession retrieves a session by ID
|
||||
func (r *RedisSessionStore) GetSession(sessionID string) (*SessionData, error) {
|
||||
if session, exists := r.sessions[sessionID]; exists {
|
||||
// Update last active time
|
||||
session.LastActive = time.Now()
|
||||
return session, nil
|
||||
}
|
||||
return nil, fmt.Errorf("session not found")
|
||||
}
|
||||
|
||||
// UpdateSession updates an existing session
|
||||
func (r *RedisSessionStore) UpdateSession(sessionID string, sessionData *SessionData) error {
|
||||
if _, exists := r.sessions[sessionID]; exists {
|
||||
sessionData.LastActive = time.Now()
|
||||
r.sessions[sessionID] = sessionData
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("session not found")
|
||||
}
|
||||
|
||||
// DeleteSession removes a session
|
||||
func (r *RedisSessionStore) DeleteSession(sessionID string) error {
|
||||
delete(r.sessions, sessionID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// CleanupExpiredSessions removes sessions older than 24 hours
|
||||
func (r *RedisSessionStore) CleanupExpiredSessions() error {
|
||||
now := time.Now()
|
||||
for sessionID, session := range r.sessions {
|
||||
if now.Sub(session.LastActive) > 24*time.Hour {
|
||||
delete(r.sessions, sessionID)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Global session store instance
|
||||
var sessionStore SessionStore
|
||||
|
||||
// InitSessionStore initializes the session store
|
||||
func InitSessionStore() {
|
||||
sessionStore = NewSessionStore()
|
||||
|
||||
// Start cleanup goroutine
|
||||
go func() {
|
||||
ticker := time.NewTicker(1 * time.Hour)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
if sessionStore != nil {
|
||||
sessionStore.CleanupExpiredSessions()
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// SessionMiddleware creates and manages user sessions
|
||||
func SessionMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// Skip session management for health checks and static assets
|
||||
path := c.Request.URL.Path
|
||||
if path == "/health" || path == "/metrics" || strings.HasPrefix(path, "/static") {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// Get session ID from header or create new one
|
||||
sessionID := c.GetHeader("X-Session-ID")
|
||||
if sessionID == "" {
|
||||
sessionID = generateSessionID()
|
||||
c.Header("X-Session-ID", sessionID)
|
||||
}
|
||||
|
||||
// Try to get existing session
|
||||
session, err := sessionStore.GetSession(sessionID)
|
||||
if err != nil {
|
||||
// No existing session, check if user is authenticated via JWT
|
||||
if user, exists := c.Get("user"); exists {
|
||||
// Create session from authenticated user
|
||||
if userModel, ok := user.(models.User); ok {
|
||||
session = &SessionData{
|
||||
SessionID: sessionID,
|
||||
UserID: userModel.ID,
|
||||
Email: userModel.Email,
|
||||
Username: userModel.Username,
|
||||
Role: userModel.Role,
|
||||
IPAddress: c.ClientIP(),
|
||||
UserAgent: c.GetHeader("User-Agent"),
|
||||
}
|
||||
sessionStore.CreateSession(session)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Set session data in context
|
||||
if session != nil {
|
||||
c.Set("session_id", session.SessionID)
|
||||
c.Set("session_user_id", session.UserID)
|
||||
c.Set("session_email", session.Email)
|
||||
c.Set("session_username", session.Username)
|
||||
c.Set("session_role", session.Role)
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// GetSessionFromContext retrieves session data from Gin context
|
||||
func GetSessionFromContext(c *gin.Context) (*SessionData, error) {
|
||||
if sessionID, exists := c.Get("session_id"); exists {
|
||||
return sessionStore.GetSession(sessionID.(string))
|
||||
}
|
||||
return nil, fmt.Errorf("no session in context")
|
||||
}
|
||||
|
||||
// GetUserIDFromSession safely gets user ID from session or context
|
||||
func GetUserIDFromSession(c *gin.Context) uint {
|
||||
// First try session
|
||||
if session, err := GetSessionFromContext(c); err == nil {
|
||||
return session.UserID
|
||||
}
|
||||
|
||||
// Fallback to context (for demo mode or JWT)
|
||||
if userID, exists := c.Get("user_id"); exists {
|
||||
if id, ok := userID.(uint); ok {
|
||||
return id
|
||||
}
|
||||
}
|
||||
if userID, exists := c.Get("userID"); exists {
|
||||
if id, ok := userID.(uint); ok {
|
||||
return id
|
||||
}
|
||||
}
|
||||
|
||||
// Final fallback for demo mode
|
||||
if os.Getenv("VITE_DEMO_MODE") == "true" {
|
||||
return 1
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
// GetUserEmailFromSession safely gets user email from session or context
|
||||
func GetUserEmailFromSession(c *gin.Context) string {
|
||||
// First try session
|
||||
if session, err := GetSessionFromContext(c); err == nil {
|
||||
return session.Email
|
||||
}
|
||||
|
||||
// Fallback to context
|
||||
if email, exists := c.Get("user_email"); exists {
|
||||
if e, ok := email.(string); ok {
|
||||
return e
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback for demo mode
|
||||
if os.Getenv("VITE_DEMO_MODE") == "true" {
|
||||
return "demo@trackeep.com"
|
||||
}
|
||||
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
// generateSessionID generates a unique session ID
|
||||
func generateSessionID() string {
|
||||
return fmt.Sprintf("sess_%d_%s", time.Now().UnixNano(), "trackeep")
|
||||
}
|
||||
|
||||
// GetSessionStore returns the global session store instance
|
||||
func GetSessionStore() SessionStore {
|
||||
return sessionStore
|
||||
}
|
||||
|
||||
// CleanupSessionsOnShutdown gracefully cleans up sessions
|
||||
func CleanupSessionsOnShutdown() {
|
||||
if sessionStore != nil {
|
||||
sessionStore.CleanupExpiredSessions()
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user