initiall commit

This commit is contained in:
Tomas Dvorak
2026-04-10 12:03:31 +02:00
commit 7ddfb1f52b
276 changed files with 37629 additions and 0 deletions
@@ -0,0 +1,51 @@
package middleware
import (
"compress/gzip"
"io"
"strings"
"sync"
"github.com/gin-gonic/gin"
)
var gzipPool = sync.Pool{
New: func() any {
w, _ := gzip.NewWriterLevel(io.Discard, gzip.DefaultCompression)
return w
},
}
type gzipWriter struct {
gin.ResponseWriter
writer *gzip.Writer
}
func (g *gzipWriter) Write(data []byte) (int, error) {
return g.writer.Write(data)
}
func Compression() gin.HandlerFunc {
return func(c *gin.Context) {
if !strings.Contains(c.Request.Header.Get("Accept-Encoding"), "gzip") {
c.Next()
return
}
// Skip compression for small responses or streaming
if c.Request.Method == "HEAD" || c.Request.URL.Path == "/api/v1/health/liveness" {
c.Next()
return
}
gz := gzipPool.Get().(*gzip.Writer)
defer gzipPool.Put(gz)
gz.Reset(c.Writer)
defer gz.Close()
c.Header("Content-Encoding", "gzip")
c.Header("Vary", "Accept-Encoding")
c.Writer = &gzipWriter{ResponseWriter: c.Writer, writer: gz}
c.Next()
}
}
+264
View File
@@ -0,0 +1,264 @@
package middleware
import (
"context"
"crypto/sha256"
"crypto/subtle"
"encoding/hex"
"errors"
"fmt"
"log/slog"
"net/http"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/redis/go-redis/v9"
"github.com/tdvorak/primora/apps/backend/internal/auth"
db "github.com/tdvorak/primora/apps/backend/internal/database/db"
"github.com/tdvorak/primora/apps/backend/internal/models"
"github.com/tdvorak/primora/apps/backend/internal/repositories"
apperrors "github.com/tdvorak/primora/apps/backend/internal/response"
)
const (
requestIDKey = "request_id"
actorKey = "actor"
)
type AuthMiddleware struct {
Queries *repositories.CoreRepository
Logger *slog.Logger
Redis *redis.Client
Verifier *auth.Verifier
RateLimits RateLimitConfig
}
type RateLimitConfig struct {
APIKeyPerMinute int
UserPerMinute int
}
func RequestID() gin.HandlerFunc {
return func(c *gin.Context) {
requestID := c.Request.Header.Get("X-Request-ID")
if requestID == "" {
requestID = uuid.NewString()
}
c.Set(requestIDKey, requestID)
c.Writer.Header().Set("X-Request-ID", requestID)
c.Next()
}
}
func Logger(logger *slog.Logger) gin.HandlerFunc {
return func(c *gin.Context) {
start := time.Now()
c.Next()
logger.Info("request_complete",
"method", c.Request.Method,
"path", c.Request.URL.Path,
"status", c.Writer.Status(),
"duration_ms", time.Since(start).Milliseconds(),
"request_id", RequestIDFromContext(c),
"client_ip", c.ClientIP(),
)
}
}
func (m AuthMiddleware) ResolveActor() gin.HandlerFunc {
return func(c *gin.Context) {
apiKey := strings.TrimSpace(c.GetHeader("X-API-Key"))
authz := strings.TrimSpace(c.GetHeader("Authorization"))
if apiKey == "" && strings.HasPrefix(strings.ToLower(authz), "bearer ") {
token := strings.TrimSpace(strings.TrimPrefix(authz, "Bearer"))
if strings.HasPrefix(authz, "Bearer ") {
token = strings.TrimSpace(strings.TrimPrefix(authz, "Bearer "))
}
if strings.HasPrefix(strings.ToLower(authz), "bearer ") {
apiKey = ""
actor, err := m.resolveJWTActor(c.Request.Context(), token)
if err != nil {
apperrors.Abort(c, http.StatusUnauthorized, "invalid_token", err.Error())
return
}
userIdentity := actor.AuthSubject
if actor.UserID != nil {
userIdentity = actor.UserID.String()
}
if !m.enforceRateLimit(c, "user", userIdentity, m.RateLimits.UserPerMinute, "User rate limit exceeded") {
return
}
c.Set(actorKey, actor)
c.Next()
return
}
}
if apiKey == "" {
apiKey = strings.TrimSpace(strings.TrimPrefix(authz, "ApiKey "))
}
if apiKey != "" {
actor, err := m.resolveAPIKeyActor(c.Request.Context(), apiKey)
if err != nil {
apperrors.Abort(c, http.StatusUnauthorized, "invalid_api_key", err.Error())
return
}
if !m.enforceRateLimit(c, "apikey", actor.APIKeyPrefix, m.RateLimits.APIKeyPerMinute, "API key rate limit exceeded") {
return
}
c.Set(actorKey, actor)
}
c.Next()
}
}
func (m AuthMiddleware) enforceRateLimit(c *gin.Context, scope, identity string, limit int, exceededMessage string) bool {
if m.Redis == nil || limit <= 0 || identity == "" {
return true
}
key := "ratelimit:" + scope + ":" + identity + ":" + time.Now().UTC().Format("200601021504")
count, err := m.Redis.Incr(c.Request.Context(), key).Result()
if err != nil {
if m.Logger != nil {
m.Logger.Warn("rate limit increment failed", "scope", scope, "error", err)
}
return true
}
if count == 1 {
if err := m.Redis.Expire(c.Request.Context(), key, time.Minute).Err(); err != nil && m.Logger != nil {
m.Logger.Warn("rate limit expiry update failed", "scope", scope, "error", err)
}
}
ttl, err := m.Redis.TTL(c.Request.Context(), key).Result()
if err != nil {
if m.Logger != nil {
m.Logger.Warn("rate limit ttl lookup failed", "scope", scope, "error", err)
}
ttl = time.Minute
}
if ttl <= 0 {
ttl = time.Minute
}
resetSeconds := int((ttl + time.Second - 1) / time.Second)
if resetSeconds < 1 {
resetSeconds = 1
}
remaining := limit - int(count)
if remaining < 0 {
remaining = 0
}
c.Header("X-RateLimit-Limit", strconv.Itoa(limit))
c.Header("X-RateLimit-Remaining", strconv.Itoa(remaining))
c.Header("X-RateLimit-Reset", strconv.Itoa(resetSeconds))
if count > int64(limit) {
c.Header("Retry-After", strconv.Itoa(resetSeconds))
apperrors.Abort(c, http.StatusTooManyRequests, "rate_limited", exceededMessage)
return false
}
return true
}
func RequireActor(c *gin.Context) (*models.Actor, bool) {
actor, ok := ActorFromContext(c)
if !ok {
apperrors.Abort(c, http.StatusUnauthorized, "authentication_required", "authentication is required")
return nil, false
}
return actor, true
}
func ActorFromContext(c *gin.Context) (*models.Actor, bool) {
value, ok := c.Get(actorKey)
if !ok {
return nil, false
}
actor, ok := value.(*models.Actor)
return actor, ok
}
func RequestIDFromContext(c *gin.Context) string {
value, ok := c.Get(requestIDKey)
if !ok {
return ""
}
requestID, _ := value.(string)
return requestID
}
func (m AuthMiddleware) resolveJWTActor(ctx context.Context, token string) (*models.Actor, error) {
if m.Verifier == nil {
return nil, errors.New("jwt verifier not configured")
}
claims, err := m.Verifier.ParseToken(token)
if err != nil {
return nil, err
}
user, err := m.Queries.UpsertUser(ctx, db.UpsertUserParams{
AuthSubject: claims.Subject,
Email: strings.ToLower(claims.Email),
Name: claims.Name,
EmailVerified: claims.EmailVerified,
})
if err != nil {
return nil, fmt.Errorf("upsert user: %w", err)
}
return &models.Actor{
Type: models.ActorTypeUser,
UserID: &user.ID,
AuthSubject: claims.Subject,
Email: strings.ToLower(claims.Email),
EmailVerified: claims.EmailVerified,
Name: claims.Name,
SessionID: claims.SessionID,
}, nil
}
func (m AuthMiddleware) resolveAPIKeyActor(ctx context.Context, rawKey string) (*models.Actor, error) {
parts := strings.Split(rawKey, "_")
if len(parts) < 3 {
return nil, errors.New("malformed api key")
}
prefix := strings.Join(parts[:2], "_")
row, err := m.Queries.GetAPIKeyByPrefix(ctx, prefix)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return nil, errors.New("unknown api key")
}
return nil, err
}
if row.RevokedAt.Valid {
return nil, errors.New("api key revoked")
}
sum := sha256.Sum256([]byte(rawKey))
if subtle.ConstantTimeCompare(row.SecretHash, sum[:]) != 1 {
return nil, errors.New("invalid api key secret")
}
if err := m.Queries.TouchAPIKey(ctx, row.ID); err != nil {
m.Logger.Warn("failed to touch api key", "error", err)
}
projectID := row.ProjectID
orgID := row.OrganizationID
apiKeyID := row.ID
return &models.Actor{
Type: models.ActorTypeAPIKey,
ProjectID: &projectID,
OrganizationID: &orgID,
APIKeyID: &apiKeyID,
APIKeyPrefix: prefix,
}, nil
}
func PgUUID(id uuid.UUID) pgtype.UUID {
return pgtype.UUID{Bytes: id, Valid: true}
}
func HexDigest(input string) string {
sum := sha256.Sum256([]byte(input))
return hex.EncodeToString(sum[:])
}
+58
View File
@@ -0,0 +1,58 @@
package middleware
import (
"strings"
"github.com/gin-gonic/gin"
)
type CORSConfig struct {
AllowedOrigins []string
AllowedMethods []string
AllowedHeaders []string
MaxAge int
}
func CORS(config CORSConfig) gin.HandlerFunc {
if len(config.AllowedMethods) == 0 {
config.AllowedMethods = []string{"GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"}
}
if len(config.AllowedHeaders) == 0 {
config.AllowedHeaders = []string{"Origin", "Content-Type", "Accept", "Authorization", "X-API-Key", "X-Request-ID"}
}
if config.MaxAge == 0 {
config.MaxAge = 86400
}
return func(c *gin.Context) {
origin := c.Request.Header.Get("Origin")
// Check if origin is allowed
allowed := false
for _, allowedOrigin := range config.AllowedOrigins {
if allowedOrigin == "*" || allowedOrigin == origin {
allowed = true
break
}
}
if allowed {
if origin != "" {
c.Writer.Header().Set("Access-Control-Allow-Origin", origin)
} else if len(config.AllowedOrigins) == 1 {
c.Writer.Header().Set("Access-Control-Allow-Origin", config.AllowedOrigins[0])
}
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
c.Writer.Header().Set("Access-Control-Allow-Methods", strings.Join(config.AllowedMethods, ", "))
c.Writer.Header().Set("Access-Control-Allow-Headers", strings.Join(config.AllowedHeaders, ", "))
c.Writer.Header().Set("Access-Control-Max-Age", string(rune(config.MaxAge)))
}
if c.Request.Method == "OPTIONS" {
c.AbortWithStatus(204)
return
}
c.Next()
}
}
@@ -0,0 +1,58 @@
package middleware
import (
"time"
"github.com/gin-gonic/gin"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
)
// StructuredLogger returns a middleware that logs HTTP requests with structured logging
func StructuredLogger() gin.HandlerFunc {
return func(c *gin.Context) {
start := time.Now()
req := c.Request
// Create a logger with request context
logger := log.With().
Str("method", req.Method).
Str("uri", req.RequestURI).
Str("remote_ip", c.ClientIP()).
Str("user_agent", req.UserAgent()).
Str("request_id", req.Header.Get("X-Request-ID")).
Logger()
// Add logger to context
c.Set("logger", &logger)
// Process request
c.Next()
// Calculate latency
latency := time.Since(start)
// Log the request
logEvent := logger.Info()
if len(c.Errors) > 0 {
logEvent = logger.Error().Err(c.Errors.Last())
}
logEvent.
Int("status", c.Writer.Status()).
Int64("bytes_out", int64(c.Writer.Size())).
Dur("latency_ms", latency).
Msg("HTTP request")
}
}
// GetLogger retrieves the logger from the Gin context
func GetLogger(c *gin.Context) *zerolog.Logger {
if logger, exists := c.Get("logger"); exists {
if l, ok := logger.(*zerolog.Logger); ok {
return l
}
}
return &log.Logger
}
@@ -0,0 +1,22 @@
package middleware
import (
"time"
"github.com/gin-gonic/gin"
"github.com/tdvorak/primora/apps/backend/internal/observability"
)
func Metrics(metrics *observability.Metrics) gin.HandlerFunc {
return func(c *gin.Context) {
start := time.Now()
metrics.IncrementActive()
defer metrics.DecrementActive()
c.Next()
duration := time.Since(start)
isError := c.Writer.Status() >= 400
metrics.RecordRequest(duration, isError)
}
}
@@ -0,0 +1,152 @@
package middleware
import (
"net/http"
"sync"
"time"
"github.com/gin-gonic/gin"
)
// RateLimiter implements a simple token bucket rate limiter
type RateLimiter struct {
visitors map[string]*visitor
mu sync.RWMutex
rate int // requests per window
window time.Duration // time window
}
type visitor struct {
tokens int
lastSeen time.Time
mu sync.Mutex
}
// NewRateLimiter creates a new rate limiter
func NewRateLimiter(rate int, window time.Duration) *RateLimiter {
rl := &RateLimiter{
visitors: make(map[string]*visitor),
rate: rate,
window: window,
}
// Cleanup old visitors every minute
go rl.cleanupVisitors()
return rl
}
// Middleware returns a Gin middleware function
func (rl *RateLimiter) Middleware() gin.HandlerFunc {
return func(c *gin.Context) {
ip := c.ClientIP()
if !rl.allow(ip) {
c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{"error": "Rate limit exceeded"})
return
}
c.Next()
}
}
// allow checks if a request from the given IP is allowed
func (rl *RateLimiter) allow(ip string) bool {
rl.mu.Lock()
v, exists := rl.visitors[ip]
if !exists {
v = &visitor{
tokens: rl.rate,
lastSeen: time.Now(),
}
rl.visitors[ip] = v
}
rl.mu.Unlock()
v.mu.Lock()
defer v.mu.Unlock()
// Refill tokens based on time passed
now := time.Now()
elapsed := now.Sub(v.lastSeen)
v.lastSeen = now
// Add tokens based on elapsed time
tokensToAdd := int(elapsed / rl.window * time.Duration(rl.rate))
v.tokens += tokensToAdd
if v.tokens > rl.rate {
v.tokens = rl.rate
}
// Check if we have tokens available
if v.tokens > 0 {
v.tokens--
return true
}
return false
}
// cleanupVisitors removes old visitors periodically
func (rl *RateLimiter) cleanupVisitors() {
ticker := time.NewTicker(time.Minute)
defer ticker.Stop()
for range ticker.C {
rl.mu.Lock()
for ip, v := range rl.visitors {
v.mu.Lock()
if time.Since(v.lastSeen) > rl.window*2 {
delete(rl.visitors, ip)
}
v.mu.Unlock()
}
rl.mu.Unlock()
}
}
// RateLimitByKey implements rate limiting by custom key (e.g., user ID, API key)
type KeyRateLimiter struct {
limiters map[string]*RateLimiter
mu sync.RWMutex
rate int
window time.Duration
}
// NewKeyRateLimiter creates a new key-based rate limiter
func NewKeyRateLimiter(rate int, window time.Duration) *KeyRateLimiter {
return &KeyRateLimiter{
limiters: make(map[string]*RateLimiter),
rate: rate,
window: window,
}
}
// Middleware returns a Gin middleware function that rate limits by a custom key
func (krl *KeyRateLimiter) Middleware(keyFunc func(*gin.Context) string) gin.HandlerFunc {
return func(c *gin.Context) {
key := keyFunc(c)
if key == "" {
c.Next()
return
}
krl.mu.RLock()
limiter, exists := krl.limiters[key]
krl.mu.RUnlock()
if !exists {
krl.mu.Lock()
limiter = NewRateLimiter(krl.rate, krl.window)
krl.limiters[key] = limiter
krl.mu.Unlock()
}
if !limiter.allow(key) {
c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{"error": "Rate limit exceeded for this resource"})
return
}
c.Next()
}
}
@@ -0,0 +1,38 @@
package middleware
import (
"fmt"
"net/http"
"runtime/debug"
"github.com/gin-gonic/gin"
"github.com/rs/zerolog/log"
)
// Recovery returns a middleware that recovers from panics
func Recovery() gin.HandlerFunc {
return func(c *gin.Context) {
defer func() {
if r := recover(); r != nil {
err, ok := r.(error)
if !ok {
err = fmt.Errorf("%v", r)
}
// Log the panic with stack trace
log.Error().
Err(err).
Str("stack", string(debug.Stack())).
Str("method", c.Request.Method).
Str("uri", c.Request.RequestURI).
Str("remote_ip", c.ClientIP()).
Msg("Panic recovered")
// Return internal server error
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "Internal server error"})
}
}()
c.Next()
}
}