package middleware import ( "bytes" "io" "net/http" "strings" "github.com/gin-gonic/gin" ) // RequestSizeLimit limits the size of request bodies func RequestSizeLimit(maxSize int64) gin.HandlerFunc { return func(c *gin.Context) { // Skip for upload endpoints (they have their own limits) if strings.Contains(c.Request.URL.Path, "/upload") { c.Next() return } c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, maxSize) c.Next() } } // SanitizeHeaders removes potentially dangerous headers func SanitizeHeaders() gin.HandlerFunc { return func(c *gin.Context) { // Remove server information leakage c.Writer.Header().Del("Server") c.Writer.Header().Del("X-Powered-By") c.Next() } } // ValidateContentType ensures proper content type for POST/PUT/PATCH func ValidateContentType() gin.HandlerFunc { return func(c *gin.Context) { if c.Request.Method == "POST" || c.Request.Method == "PUT" || c.Request.Method == "PATCH" { contentType := c.GetHeader("Content-Type") // Allow multipart for file uploads if strings.Contains(c.Request.URL.Path, "/upload") { c.Next() return } // Require JSON for API endpoints if !strings.Contains(contentType, "application/json") { c.JSON(http.StatusUnsupportedMediaType, gin.H{ "error": "Content-Type must be application/json", }) c.Abort() return } } c.Next() } } // RequestID adds a unique request ID for tracing func RequestID() gin.HandlerFunc { return func(c *gin.Context) { requestID := c.GetHeader("X-Request-ID") if requestID == "" { requestID = generateRequestID() } c.Set("request_id", requestID) c.Header("X-Request-ID", requestID) c.Next() } } func generateRequestID() string { // Simple request ID generation b := make([]byte, 16) _, _ = io.ReadFull(bytes.NewReader([]byte(strings.Repeat("0123456789abcdef", 2))), b) return string(b) } // SecurityAuditLog logs security-relevant events type SecurityEvent struct { Type string UserID uint IP string Path string Method string RequestID string Details map[string]interface{} } func LogSecurityEvent(c *gin.Context, eventType string, details map[string]interface{}) { event := SecurityEvent{ Type: eventType, IP: c.ClientIP(), Path: c.Request.URL.Path, Method: c.Request.Method, RequestID: c.GetString("request_id"), Details: details, } if userID, exists := c.Get("user_id"); exists { if uid, ok := userID.(uint); ok { event.UserID = uid } } // Log to your logger // logger.Warn("SECURITY_EVENT: type=%s user_id=%d ip=%s path=%s", // event.Type, event.UserID, event.IP, event.Path) }