first commit

This commit is contained in:
Tomas Dvorak
2026-04-10 12:04:09 +02:00
commit 3cb40adb23
203 changed files with 40226 additions and 0 deletions
+148
View File
@@ -0,0 +1,148 @@
package app
import (
"context"
"errors"
"fmt"
"net/http"
"os"
"strings"
"time"
"go.uber.org/zap"
"productier/apps/backend/internal/authsession"
"productier/apps/backend/internal/filestorage"
"productier/apps/backend/internal/httpapi"
"productier/apps/backend/internal/mailruntime"
"productier/apps/backend/internal/store"
)
type App struct {
server *httpapi.Server
port string
shutdownTimeout time.Duration
stopMailRuntime context.CancelFunc
}
func New(logger *zap.Logger) (*App, error) {
runtimeConfig, err := loadRuntimeConfig()
if err != nil {
return nil, err
}
var dataStore store.Store
databaseURL := os.Getenv("DATABASE_URL")
if databaseURL != "" {
persistentStore, err := store.NewPostgresStore(databaseURL, runtimeConfig.mode)
if err != nil {
return nil, err
}
dataStore = persistentStore
} else {
if err := validateStoreRuntimeMode(runtimeConfig.mode, inMemoryStoreAllowed()); err != nil {
return nil, err
}
dataStore = store.NewSeededState(runtimeConfig.mode)
}
mailService, err := mailruntime.New(dataStore, logger, runtimeConfig.mailSecret)
if err != nil {
return nil, err
}
files, err := filestorage.NewFromEnv()
if err != nil {
return nil, err
}
probeCtx, cancelProbe := context.WithTimeout(context.Background(), 3*time.Second)
defer cancelProbe()
if err := files.Probe(probeCtx); err != nil {
return nil, fmt.Errorf("file storage startup probe failed: %w", err)
}
mailRuntimeCtx, stopMailRuntime := context.WithCancel(context.Background())
mailService.Start(mailRuntimeCtx)
return &App{
server: httpapi.NewServer(
dataStore,
authsession.NewClient(runtimeConfig.authServiceURL),
mailService,
files,
runtimeConfig.mode,
runtimeConfig.corsAllowOrigins,
runtimeConfig.metricsAuthToken,
logger,
),
port: runtimeConfig.apiPort,
shutdownTimeout: runtimeConfig.shutdownTimeout,
stopMailRuntime: stopMailRuntime,
}, nil
}
func (a *App) Run() error {
return a.RunContext(context.Background())
}
func (a *App) RunContext(ctx context.Context) error {
httpServer := &http.Server{
Addr: fmt.Sprintf(":%s", a.port),
Handler: a.server.Engine(),
ReadHeaderTimeout: 10 * time.Second,
ReadTimeout: 30 * time.Second,
WriteTimeout: 30 * time.Second,
IdleTimeout: 60 * time.Second,
}
serverErr := make(chan error, 1)
go func() {
err := httpServer.ListenAndServe()
if err != nil && !errors.Is(err, http.ErrServerClosed) {
serverErr <- err
}
close(serverErr)
}()
select {
case <-ctx.Done():
if a.stopMailRuntime != nil {
a.stopMailRuntime()
}
shutdownCtx, cancel := context.WithTimeout(context.Background(), a.shutdownTimeout)
defer cancel()
if err := httpServer.Shutdown(shutdownCtx); err != nil && !errors.Is(err, http.ErrServerClosed) {
return fmt.Errorf("shutdown api server: %w", err)
}
if err, ok := <-serverErr; ok && err != nil {
return fmt.Errorf("run api server: %w", err)
}
return nil
case err, ok := <-serverErr:
if a.stopMailRuntime != nil {
a.stopMailRuntime()
}
if !ok || err == nil {
return nil
}
return fmt.Errorf("run api server: %w", err)
}
}
func validateStoreRuntimeMode(mode string, allowInMemory bool) error {
if mode == "development" || allowInMemory {
return nil
}
return fmt.Errorf("DATABASE_URL is required when APP_ENV=%q (set ALLOW_INMEMORY_STORE=true only for temporary non-production testing)", mode)
}
func inMemoryStoreAllowed() bool {
raw := strings.TrimSpace(strings.ToLower(os.Getenv("ALLOW_INMEMORY_STORE")))
switch raw {
case "1", "true", "yes", "on":
return true
default:
return false
}
}
+72
View File
@@ -0,0 +1,72 @@
package app
import "testing"
func TestValidateStoreRuntimeMode(t *testing.T) {
t.Parallel()
tests := []struct {
name string
mode string
allowInMemory bool
expectError bool
}{
{
name: "development allows in-memory store",
mode: "development",
allowInMemory: false,
expectError: false,
},
{
name: "production rejects in-memory store by default",
mode: "production",
allowInMemory: false,
expectError: true,
},
{
name: "non-development can be explicitly overridden",
mode: "staging",
allowInMemory: true,
expectError: false,
},
}
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
t.Parallel()
err := validateStoreRuntimeMode(test.mode, test.allowInMemory)
if test.expectError && err == nil {
t.Fatalf("expected error for mode=%q allowInMemory=%v", test.mode, test.allowInMemory)
}
if !test.expectError && err != nil {
t.Fatalf("did not expect error for mode=%q allowInMemory=%v: %v", test.mode, test.allowInMemory, err)
}
})
}
}
func TestInMemoryStoreAllowed(t *testing.T) {
tests := []struct {
name string
value string
allowed bool
}{
{name: "empty", value: "", allowed: false},
{name: "true", value: "true", allowed: true},
{name: "uppercase true", value: "TRUE", allowed: true},
{name: "one", value: "1", allowed: true},
{name: "yes", value: "yes", allowed: true},
{name: "off", value: "off", allowed: false},
}
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
t.Setenv("ALLOW_INMEMORY_STORE", test.value)
if got := inMemoryStoreAllowed(); got != test.allowed {
t.Fatalf("inMemoryStoreAllowed() = %v, want %v for %q", got, test.allowed, test.value)
}
})
}
}
+256
View File
@@ -0,0 +1,256 @@
package app
import (
"errors"
"fmt"
"net/url"
"os"
"strconv"
"strings"
"time"
)
const (
defaultAppMode = "development"
defaultAPIPort = "8080"
defaultShutdownTimeout = 10 * time.Second
)
var (
defaultLocalCORSOrigins = []string{
"http://localhost:3000",
"http://127.0.0.1:3000",
"http://localhost:3001",
"http://127.0.0.1:3001",
}
insecureSecretPlaceholders = map[string]struct{}{
"": {},
"replace-me-with-a-long-random-secret": {},
"replace-me-with-a-dedicated-mail-secret": {},
"productier-local-mail-key": {},
"changeme": {},
"change-me": {},
"replace-me": {},
}
)
type runtimeConfig struct {
mode string
apiPort string
authServiceURL string
shutdownTimeout time.Duration
corsAllowOrigins []string
mailSecret string
metricsAuthToken string
}
func loadRuntimeConfig() (runtimeConfig, error) {
mode, err := parseAppMode(os.Getenv("APP_ENV"))
if err != nil {
return runtimeConfig{}, err
}
apiPort, err := parsePort(os.Getenv("API_PORT"), defaultAPIPort, "API_PORT")
if err != nil {
return runtimeConfig{}, err
}
authServiceURL, err := parseAbsoluteHTTPURL(
valueOrDefault(strings.TrimSpace(os.Getenv("AUTH_SERVICE_URL")), "http://localhost:3001"),
"AUTH_SERVICE_URL",
)
if err != nil {
return runtimeConfig{}, err
}
shutdownTimeout, err := parseDuration(
os.Getenv("API_SHUTDOWN_TIMEOUT"),
defaultShutdownTimeout,
"API_SHUTDOWN_TIMEOUT",
)
if err != nil {
return runtimeConfig{}, err
}
corsAllowOrigins, err := parseCORSAllowOrigins(mode, os.Getenv("CORS_ALLOW_ORIGINS"))
if err != nil {
return runtimeConfig{}, err
}
mailSecret, err := resolveMailSecret(mode)
if err != nil {
return runtimeConfig{}, err
}
metricsAuthToken, err := resolveMetricsAuthToken(mode)
if err != nil {
return runtimeConfig{}, err
}
return runtimeConfig{
mode: mode,
apiPort: apiPort,
authServiceURL: authServiceURL,
shutdownTimeout: shutdownTimeout,
corsAllowOrigins: corsAllowOrigins,
mailSecret: mailSecret,
metricsAuthToken: metricsAuthToken,
}, nil
}
func parseAppMode(raw string) (string, error) {
mode := strings.TrimSpace(strings.ToLower(raw))
if mode == "" {
return defaultAppMode, nil
}
switch mode {
case "development", "test", "staging", "production":
return mode, nil
default:
return "", fmt.Errorf("unsupported APP_ENV %q (allowed: development, test, staging, production)", mode)
}
}
func parsePort(raw string, fallback string, envName string) (string, error) {
port := strings.TrimSpace(raw)
if port == "" {
port = fallback
}
numeric, err := strconv.Atoi(port)
if err != nil || numeric < 1 || numeric > 65535 {
return "", fmt.Errorf("%s must be a valid TCP port (1-65535)", envName)
}
return strconv.Itoa(numeric), nil
}
func parseDuration(raw string, fallback time.Duration, envName string) (time.Duration, error) {
value := strings.TrimSpace(raw)
if value == "" {
return fallback, nil
}
duration, err := time.ParseDuration(value)
if err != nil {
return 0, fmt.Errorf("%s must be a valid duration (example: 10s): %w", envName, err)
}
if duration <= 0 {
return 0, fmt.Errorf("%s must be greater than zero", envName)
}
return duration, nil
}
func parseAbsoluteHTTPURL(raw string, envName string) (string, error) {
value := strings.TrimSpace(raw)
if value == "" {
return "", fmt.Errorf("%s is required", envName)
}
parsed, err := url.Parse(value)
if err != nil {
return "", fmt.Errorf("%s must be a valid URL: %w", envName, err)
}
if parsed.Scheme != "http" && parsed.Scheme != "https" {
return "", fmt.Errorf("%s must use http or https", envName)
}
if parsed.Host == "" {
return "", fmt.Errorf("%s must include a host", envName)
}
return strings.TrimRight(parsed.String(), "/"), nil
}
func parseCORSAllowOrigins(mode string, raw string) ([]string, error) {
value := strings.TrimSpace(raw)
if value == "" {
if mode == "staging" || mode == "production" {
return nil, errors.New("CORS_ALLOW_ORIGINS is required in staging/production (comma-separated origins)")
}
return append([]string(nil), defaultLocalCORSOrigins...), nil
}
parts := strings.Split(value, ",")
origins := make([]string, 0, len(parts))
seen := make(map[string]struct{}, len(parts))
for _, part := range parts {
origin := strings.TrimSpace(part)
if origin == "" {
continue
}
if origin == "*" {
return nil, errors.New("CORS_ALLOW_ORIGINS cannot include '*' when credentials are enabled")
}
validated, err := parseOrigin(origin, "CORS_ALLOW_ORIGINS")
if err != nil {
return nil, err
}
if _, exists := seen[validated]; exists {
continue
}
seen[validated] = struct{}{}
origins = append(origins, validated)
}
if len(origins) == 0 {
return nil, errors.New("CORS_ALLOW_ORIGINS must include at least one valid origin")
}
return origins, nil
}
func parseOrigin(raw string, envName string) (string, error) {
parsed, err := url.Parse(strings.TrimSpace(raw))
if err != nil {
return "", fmt.Errorf("%s must contain valid origins: %w", envName, err)
}
if parsed.Scheme != "http" && parsed.Scheme != "https" {
return "", fmt.Errorf("%s origins must use http or https", envName)
}
if parsed.Host == "" {
return "", fmt.Errorf("%s origins must include a host", envName)
}
if parsed.Path != "" && parsed.Path != "/" {
return "", fmt.Errorf("%s origins cannot include URL paths", envName)
}
if parsed.RawQuery != "" || parsed.Fragment != "" {
return "", fmt.Errorf("%s origins cannot include query or fragment components", envName)
}
return parsed.Scheme + "://" + parsed.Host, nil
}
func resolveMailSecret(mode string) (string, error) {
mailSecret := strings.TrimSpace(os.Getenv("MAIL_ENCRYPTION_KEY"))
if mailSecret == "" {
mailSecret = strings.TrimSpace(os.Getenv("BETTER_AUTH_SECRET"))
}
if mode != "staging" && mode != "production" {
return mailSecret, nil
}
if isInsecureSecret(mailSecret) {
return "", errors.New("set a strong MAIL_ENCRYPTION_KEY (or BETTER_AUTH_SECRET fallback) for staging/production")
}
return mailSecret, nil
}
func resolveMetricsAuthToken(mode string) (string, error) {
token := strings.TrimSpace(os.Getenv("METRICS_AUTH_TOKEN"))
if token == "" {
return "", nil
}
if mode == "production" && isInsecureSecret(token) {
return "", errors.New("METRICS_AUTH_TOKEN must be a strong non-placeholder secret when set in production")
}
return token, nil
}
func isInsecureSecret(secret string) bool {
normalized := strings.TrimSpace(strings.ToLower(secret))
if _, exists := insecureSecretPlaceholders[normalized]; exists {
return true
}
return len(strings.TrimSpace(secret)) < 16
}
func valueOrDefault(value string, fallback string) string {
if strings.TrimSpace(value) == "" {
return fallback
}
return value
}
+146
View File
@@ -0,0 +1,146 @@
package app
import (
"testing"
"time"
)
func TestParseAppMode(t *testing.T) {
t.Parallel()
tests := []struct {
name string
value string
want string
expectErr bool
}{
{name: "default", value: "", want: "development"},
{name: "production", value: "production", want: "production"},
{name: "normalized", value: " StAgInG ", want: "staging"},
{name: "invalid", value: "prod", expectErr: true},
}
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
t.Parallel()
got, err := parseAppMode(test.value)
if test.expectErr {
if err == nil {
t.Fatalf("expected error for %q", test.value)
}
return
}
if err != nil {
t.Fatalf("did not expect error: %v", err)
}
if got != test.want {
t.Fatalf("parseAppMode(%q) = %q, want %q", test.value, got, test.want)
}
})
}
}
func TestParseDuration(t *testing.T) {
t.Parallel()
got, err := parseDuration("", 10*time.Second, "API_SHUTDOWN_TIMEOUT")
if err != nil {
t.Fatalf("did not expect error: %v", err)
}
if got != 10*time.Second {
t.Fatalf("parseDuration default = %s, want 10s", got)
}
got, err = parseDuration("15s", 10*time.Second, "API_SHUTDOWN_TIMEOUT")
if err != nil {
t.Fatalf("did not expect error: %v", err)
}
if got != 15*time.Second {
t.Fatalf("parseDuration explicit = %s, want 15s", got)
}
if _, err := parseDuration("0s", 10*time.Second, "API_SHUTDOWN_TIMEOUT"); err == nil {
t.Fatal("expected zero duration to fail")
}
if _, err := parseDuration("nope", 10*time.Second, "API_SHUTDOWN_TIMEOUT"); err == nil {
t.Fatal("expected invalid duration to fail")
}
}
func TestParseCORSAllowOrigins(t *testing.T) {
t.Parallel()
devOrigins, err := parseCORSAllowOrigins("development", "")
if err != nil {
t.Fatalf("did not expect error: %v", err)
}
if len(devOrigins) == 0 {
t.Fatal("expected default development origins")
}
if _, err := parseCORSAllowOrigins("production", ""); err == nil {
t.Fatal("expected production with empty CORS_ALLOW_ORIGINS to fail")
}
origins, err := parseCORSAllowOrigins("production", "https://app.example.com, https://app.example.com ,https://admin.example.com")
if err != nil {
t.Fatalf("did not expect error: %v", err)
}
if len(origins) != 2 {
t.Fatalf("expected 2 deduplicated origins, got %d", len(origins))
}
if _, err := parseCORSAllowOrigins("production", "*"); err == nil {
t.Fatal("expected wildcard origin to fail")
}
if _, err := parseCORSAllowOrigins("production", "https://app.example.com/path"); err == nil {
t.Fatal("expected origin with path to fail")
}
}
func TestIsInsecureSecret(t *testing.T) {
t.Parallel()
if !isInsecureSecret("replace-me-with-a-long-random-secret") {
t.Fatal("expected placeholder secret to be insecure")
}
if !isInsecureSecret("short") {
t.Fatal("expected short secret to be insecure")
}
if isInsecureSecret("this-is-a-strong-enough-secret-12345") {
t.Fatal("expected long random secret to pass")
}
}
func TestResolveMetricsAuthToken(t *testing.T) {
t.Run("empty token is allowed", func(t *testing.T) {
t.Setenv("METRICS_AUTH_TOKEN", "")
token, err := resolveMetricsAuthToken("production")
if err != nil {
t.Fatalf("did not expect error: %v", err)
}
if token != "" {
t.Fatalf("token = %q, want empty", token)
}
})
t.Run("rejects weak token in production", func(t *testing.T) {
t.Setenv("METRICS_AUTH_TOKEN", "short")
if _, err := resolveMetricsAuthToken("production"); err == nil {
t.Fatal("expected weak production token to fail")
}
})
t.Run("allows token in production", func(t *testing.T) {
t.Setenv("METRICS_AUTH_TOKEN", "this-is-a-strong-enough-secret-98765")
token, err := resolveMetricsAuthToken("production")
if err != nil {
t.Fatalf("did not expect error: %v", err)
}
if token == "" {
t.Fatal("expected resolved token")
}
})
}