package billing

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"slices"
	"time"

	"bookra/apps/backend/internal/config"
	"bookra/apps/backend/internal/db"
	"bookra/apps/backend/internal/domain"

	"github.com/jackc/pgx/v5"
	"github.com/stripe/stripe-go/v83"
	"github.com/stripe/stripe-go/v83/checkout/session"
	"github.com/stripe/stripe-go/v83/customer"
	"github.com/stripe/stripe-go/v83/subscription"
	"github.com/stripe/stripe-go/v83/webhook"
)

var (
	// ErrBillingMembership is returned when the calling principal has no tenant membership.
	ErrBillingMembership = errors.New("billing membership not found")
	// ErrBillingPlanUnsupported is returned when no Stripe price ID is configured for the requested plan code.
	ErrBillingPlanUnsupported = errors.New("billing plan is not configured")
	// ErrStripeSignatureMissing is returned by HandleWebhook when the request carries no Stripe signature.
	ErrStripeSignatureMissing = errors.New("stripe signature is missing")
)

// allowedWebhookEvents lists the Stripe event types HandleWebhook acts on.
// Events of any other type are verified and then ignored (nil is returned).
var allowedWebhookEvents = []string{
	"checkout.session.completed",
	"customer.subscription.created",
	"customer.subscription.updated",
	"customer.subscription.deleted",
	"customer.subscription.paused",
	"customer.subscription.resumed",
	"invoice.paid",
	"invoice.payment_failed",
	"payment_intent.succeeded",
	"payment_intent.payment_failed",
}

// Service exposes tenant billing operations backed by Stripe and persisted
// through the repository. When cfg.StripeSecretKey is empty it runs in a
// mock mode that never calls Stripe.
type Service struct {
	cfg  config.Config
	repo db.Repository
}

// NewService returns a Service that uses cfg for Stripe settings and repo for persistence.
func NewService(cfg config.Config, repo db.Repository) *Service {
	return &Service{cfg: cfg, repo: repo}
}

// tenantForPrincipal loads the tenant the principal is a member of.
// A missing membership row is reported as ErrBillingMembership; any other
// repository error is returned unchanged.
func (s *Service) tenantForPrincipal(ctx context.Context, principal domain.Principal) (db.TenantRecord, error) {
	membership, err := s.repo.GetTenantMembershipByUserID(ctx, principal.Subject)
	if err != nil {
		if errors.Is(err, pgx.ErrNoRows) {
			return db.TenantRecord{}, ErrBillingMembership
		}
		return db.TenantRecord{}, err
	}
	return membership.Tenant, nil
}

// GetSubscription returns the stored billing snapshot for the principal's tenant.
// If no snapshot row exists yet, it builds one from the tenant's own billing
// fields (customer ID, subscription status, plan code) instead of failing.
func (s *Service) GetSubscription(ctx context.Context, principal domain.Principal) (domain.SubscriptionSnapshot, error) {
	tenant, err := s.tenantForPrincipal(ctx, principal)
	if err != nil {
		return domain.SubscriptionSnapshot{}, err
	}

	record, err := s.repo.GetSubscriptionSnapshot(ctx, tenant.ID)
	if err != nil {
		if errors.Is(err, pgx.ErrNoRows) {
			return toSnapshot(tenant, db.BillingSnapshotRecord{
				TenantID:         tenant.ID,
				StripeCustomerID: derefString(tenant.StripeCustomerID),
				Status:           tenant.SubscriptionStatus,
				PlanCode:         tenant.PlanCode,
			}, s.cfg), nil
		}
		return domain.SubscriptionSnapshot{}, err
	}
	return toSnapshot(tenant, record, s.cfg), nil
}

// CreateCheckoutSession starts a Stripe Checkout session (subscription mode,
// quantity 1) for planCode and returns its URL.
//
// It returns ErrBillingMembership if the principal has no tenant and
// ErrBillingPlanUnsupported if planCode has no configured price. Without a
// Stripe secret key it returns a mock dashboard URL instead of calling Stripe.
// If the tenant has no Stripe customer yet, one is created and its ID is
// stored on the tenant before the session is created.
func (s *Service) CreateCheckoutSession(ctx context.Context, principal domain.Principal, planCode string) (domain.CheckoutSessionResponse, error) {
	tenant, err := s.tenantForPrincipal(ctx, principal)
	if err != nil {
		return domain.CheckoutSessionResponse{}, err
	}

	priceID := s.cfg.StripePriceIDs[planCode]
	if priceID == "" {
		return domain.CheckoutSessionResponse{}, ErrBillingPlanUnsupported
	}

	if s.cfg.StripeSecretKey == "" {
		mockURL := fmt.Sprintf("%s/dashboard?billing=mock-checkout&plan=%s", s.cfg.FrontendURL, planCode)
		return domain.CheckoutSessionResponse{URL: mockURL}, nil
	}

	// NOTE(review): stripe.Key is a package-global in stripe-go and is written on
	// every call here, unsynchronized with concurrent requests. Consider setting it
	// once at startup or switching to a per-service Stripe client.
	stripe.Key = s.cfg.StripeSecretKey

	customerID := derefString(tenant.StripeCustomerID)
	if customerID == "" {
		createdCustomer, err := customer.New(&stripe.CustomerParams{
			Name:     stripe.String(tenant.Name),
			Email:    stripe.String(principal.Email),
			Metadata: map[string]string{"tenant_id": tenant.ID, "tenant_slug": tenant.Slug},
		})
		if err != nil {
			return domain.CheckoutSessionResponse{}, err
		}
		customerID = createdCustomer.ID
		if err := s.repo.UpdateTenantStripeCustomerID(ctx, tenant.ID, customerID); err != nil {
			return domain.CheckoutSessionResponse{}, err
		}
	}

	params := &stripe.CheckoutSessionParams{
		Mode:       stripe.String(string(stripe.CheckoutSessionModeSubscription)),
		Customer:   stripe.String(customerID),
		SuccessURL: stripe.String(fmt.Sprintf("%s/dashboard?billing=success", s.cfg.FrontendURL)),
		CancelURL:  stripe.String(fmt.Sprintf("%s/dashboard?billing=cancelled", s.cfg.FrontendURL)),
		LineItems: []*stripe.CheckoutSessionLineItemParams{
			{
				Price:    stripe.String(priceID),
				Quantity: stripe.Int64(1),
			},
		},
		Metadata: map[string]string{
			"tenant_id": tenant.ID,
			"plan_code": planCode,
		},
	}
	checkoutSession, err := session.New(params)
	if err != nil {
		return domain.CheckoutSessionResponse{}, err
	}
	return domain.CheckoutSessionResponse{URL: checkoutSession.URL}, nil
}

// Refresh re-syncs the principal's tenant billing state and returns the new
// snapshot. A tenant with no Stripe customer gets a "none" status snapshot
// without any sync or persistence.
func (s *Service) Refresh(ctx context.Context, principal domain.Principal) (domain.SubscriptionSnapshot, error) {
	tenant, err := s.tenantForPrincipal(ctx, principal)
	if err != nil {
		return domain.SubscriptionSnapshot{}, err
	}

	customerID := derefString(tenant.StripeCustomerID)
	if customerID == "" {
		return toSnapshot(tenant, db.BillingSnapshotRecord{
			TenantID:         tenant.ID,
			StripeCustomerID: "",
			Status:           "none",
			PlanCode:         tenant.PlanCode,
		}, s.cfg), nil
	}

	record, err := s.syncStripeData(ctx, tenant, customerID)
	if err != nil {
		return domain.SubscriptionSnapshot{}, err
	}
	return toSnapshot(tenant, record, s.cfg), nil
}

// HandleWebhook verifies a Stripe webhook payload against cfg.StripeWebhookKey.
// For allowed event types that reference a customer belonging to a known
// tenant, it re-syncs that tenant's billing state and records the event.
//
// It returns nil without doing anything when Stripe is not configured, and
// returns nil for ignored event types, events without a customer, and
// customers with no tenant. It returns ErrStripeSignatureMissing for an empty
// signature and the verification error for an invalid one.
func (s *Service) HandleWebhook(ctx context.Context, signature string, payload []byte) error {
	if s.cfg.StripeSecretKey == "" {
		return nil
	}
	if signature == "" {
		return ErrStripeSignatureMissing
	}

	event, err := webhook.ConstructEvent(payload, signature, s.cfg.StripeWebhookKey)
	if err != nil {
		return err
	}
	if !slices.Contains(allowedWebhookEvents, string(event.Type)) {
		return nil
	}

	customerID := extractCustomerID(event)
	if customerID == "" {
		return nil
	}

	tenant, err := s.repo.GetTenantByStripeCustomerID(ctx, customerID)
	if err != nil {
		if errors.Is(err, pgx.ErrNoRows) {
			return nil
		}
		return err
	}

	// Sync before recording the event. If the sync fails we return an error
	// without recording, so a redelivery of this event is processed again
	// rather than skipped as a duplicate. Re-syncing on a duplicate delivery is
	// harmless: syncStripeData reloads the full subscription state from Stripe
	// instead of applying this event's delta.
	if _, err := s.syncStripeData(ctx, tenant, customerID); err != nil {
		return err
	}
	if _, err := s.repo.RecordStripeEvent(ctx, tenant.ID, event.ID, string(event.Type), payload); err != nil {
		return err
	}
	return nil
}

// syncStripeData rebuilds the tenant's billing snapshot and upserts it.
//
// In mock mode (no Stripe secret key) the snapshot is derived from the
// tenant's stored status and plan. Otherwise it is built from the customer's
// current Stripe subscription (see currentSubscription), and the tenant's
// plan code, status and subscription ID are updated to match. A customer with
// no subscriptions yields status "none".
func (s *Service) syncStripeData(ctx context.Context, tenant db.TenantRecord, customerID string) (db.BillingSnapshotRecord, error) {
	if s.cfg.StripeSecretKey == "" {
		now := time.Now().UTC()
		record := db.BillingSnapshotRecord{
			TenantID:             tenant.ID,
			StripeCustomerID:     customerID,
			StripeSubscriptionID: "",
			Status:               tenant.SubscriptionStatus,
			PlanCode:             tenant.PlanCode,
			PriceID:              s.cfg.StripePriceIDs[tenant.PlanCode],
			LastSyncedAt:         &now,
		}
		if err := s.repo.UpsertSubscriptionSnapshot(ctx, record); err != nil {
			return db.BillingSnapshotRecord{}, err
		}
		return record, nil
	}

	// NOTE(review): unsynchronized write to the stripe-go package-global key; see
	// the matching note in CreateCheckoutSession.
	stripe.Key = s.cfg.StripeSecretKey

	sub, err := currentSubscription(customerID)
	if err != nil {
		return db.BillingSnapshotRecord{}, err
	}

	now := time.Now().UTC()
	record := db.BillingSnapshotRecord{
		TenantID:             tenant.ID,
		StripeCustomerID:     customerID,
		StripeSubscriptionID: "",
		Status:               "none",
		PlanCode:             tenant.PlanCode,
		PriceID:              "",
		LastSyncedAt:         &now,
	}
	if sub != nil {
		s.applySubscription(&record, sub, tenant.PlanCode)
	}

	if err := s.repo.UpsertSubscriptionSnapshot(ctx, record); err != nil {
		return db.BillingSnapshotRecord{}, err
	}
	if err := s.repo.UpdateTenantBillingState(ctx, tenant.ID, record.PlanCode, record.Status, record.StripeSubscriptionID); err != nil {
		return db.BillingSnapshotRecord{}, err
	}
	return record, nil
}

// currentSubscription lists the customer's subscriptions in every status,
// expanding the default payment method and line-item prices.
//
// It returns the first listed subscription whose status is still current (see
// isCurrentSubscriptionStatus). This keeps a newer abandoned or canceled
// subscription from masking an older live one. If no current subscription
// exists it falls back to the first one listed, and it returns nil if the
// customer has none. Stripe must already be configured via stripe.Key.
func currentSubscription(customerID string) (*stripe.Subscription, error) {
	params := &stripe.SubscriptionListParams{Customer: stripe.String(customerID)}
	params.Status = stripe.String("all")
	params.AddExpand("data.default_payment_method")
	params.AddExpand("data.items.data.price")

	iter := subscription.List(params)
	var first *stripe.Subscription
	for iter.Next() {
		sub := iter.Subscription()
		if first == nil {
			first = sub
		}
		if isCurrentSubscriptionStatus(string(sub.Status)) {
			return sub, nil
		}
	}
	// Checked after iterating so errors from any fetched page are surfaced.
	if err := iter.Err(); err != nil {
		return nil, err
	}
	return first, nil
}

// isCurrentSubscriptionStatus reports whether a subscription in the given
// Stripe status still governs the tenant's billing. Statuses such as
// "incomplete", "incomplete_expired" and "canceled" do not.
func isCurrentSubscriptionStatus(status string) bool {
	switch status {
	case "active", "trialing", "past_due", "unpaid", "paused":
		return true
	default:
		return false
	}
}

// applySubscription copies the subscription's ID, status and cancel-at-period-end
// flag into record. From the first line item it copies the price, the plan code
// mapped from that price (fallbackPlan when unmapped) and the current period.
// It also copies the card brand and last four digits of the default payment
// method when present.
func (s *Service) applySubscription(record *db.BillingSnapshotRecord, sub *stripe.Subscription, fallbackPlan string) {
	record.StripeSubscriptionID = sub.ID
	record.Status = string(sub.Status)
	record.CancelAtPeriodEnd = sub.CancelAtPeriodEnd

	if sub.Items != nil && len(sub.Items.Data) > 0 && sub.Items.Data[0] != nil {
		item := sub.Items.Data[0]
		if item.Price != nil {
			record.PriceID = item.Price.ID
		}
		record.PlanCode = s.planCodeForPrice(record.PriceID, fallbackPlan)
		record.CurrentPeriodStart = toTimePtr(item.CurrentPeriodStart)
		record.CurrentPeriodEnd = toTimePtr(item.CurrentPeriodEnd)
	}

	if pm := sub.DefaultPaymentMethod; pm != nil && pm.Card != nil {
		record.PaymentMethodBrand = string(pm.Card.Brand)
		record.PaymentMethodLast4 = pm.Card.Last4
	}
}

// toSnapshot converts a stored billing record into the API snapshot. An empty
// plan code or status in record falls back to the tenant's values.
// Entitlements come from the resolved plan code. CheckoutURLAvailable is true
// when that plan has a configured Stripe price.
func toSnapshot(tenant db.TenantRecord, record db.BillingSnapshotRecord, cfg config.Config) domain.SubscriptionSnapshot {
	if record.PlanCode == "" {
		record.PlanCode = tenant.PlanCode
	}
	if record.Status == "" {
		record.Status = tenant.SubscriptionStatus
	}
	return domain.SubscriptionSnapshot{
		TenantID:             tenant.ID,
		CustomerID:           record.StripeCustomerID,
		SubscriptionID:       record.StripeSubscriptionID,
		Status:               record.Status,
		PlanCode:             record.PlanCode,
		PriceID:              record.PriceID,
		CancelAtPeriodEnd:    record.CancelAtPeriodEnd,
		CurrentPeriodStart:   record.CurrentPeriodStart,
		CurrentPeriodEnd:     record.CurrentPeriodEnd,
		PaymentMethodBrand:   record.PaymentMethodBrand,
		PaymentMethodLast4:   record.PaymentMethodLast4,
		Entitlements:         entitlementsForPlan(record.PlanCode),
		LastSyncedAt:         record.LastSyncedAt,
		CheckoutURLAvailable: cfg.StripePriceIDs[record.PlanCode] != "",
	}
}

// entitlementsForPlan returns the limits for planCode: "starter" and
// "multi-location" have explicit limits, and every other code (including "")
// gets the middle tier.
func entitlementsForPlan(planCode string) domain.PlanEntitlements {
	switch planCode {
	case "starter":
		return domain.PlanEntitlements{MaxLocations: 1, MaxStaff: 3, SMSAddonAvailable: false, AdvancedReporting: false}
	case "multi-location":
		return domain.PlanEntitlements{MaxLocations: 10, MaxStaff: 30, SMSAddonAvailable: true, AdvancedReporting: true}
	default:
		return domain.PlanEntitlements{MaxLocations: 3, MaxStaff: 10, SMSAddonAvailable: true, AdvancedReporting: true}
	}
}

// planCodeForPrice returns the configured plan code whose Stripe price ID equals
// priceID, or fallback if none does. Empty configured IDs never match.
// NOTE(review): map iteration order is random, so if two plans share a price
// ID the result is nondeterministic.
func (s *Service) planCodeForPrice(priceID string, fallback string) string {
	for code, configuredPriceID := range s.cfg.StripePriceIDs {
		if configuredPriceID != "" && configuredPriceID == priceID {
			return code
		}
	}
	return fallback
}

// derefString returns *value, or "" when value is nil.
func derefString(value *string) string {
	if value == nil {
		return ""
	}
	return *value
}

// toTimePtr converts Unix seconds to a UTC time pointer; 0 maps to nil.
func toTimePtr(value int64) *time.Time {
	if value == 0 {
		return nil
	}
	timestamp := time.Unix(value, 0).UTC()
	return &timestamp
}

// extractCustomerID returns the "customer" field of the event's data object.
// It accepts either a customer ID string or an expanded object with an "id"
// field. It returns "" when the field is absent, null, or cannot be decoded.
func extractCustomerID(event stripe.Event) string {
	var payload struct {
		Customer json.RawMessage `json:"customer"`
	}
	if err := json.Unmarshal(event.Data.Raw, &payload); err != nil || len(payload.Customer) == 0 {
		return ""
	}

	var customerID string
	if err := json.Unmarshal(payload.Customer, &customerID); err == nil {
		return customerID
	}

	var expanded struct {
		ID string `json:"id"`
	}
	if err := json.Unmarshal(payload.Customer, &expanded); err == nil {
		return expanded.ID
	}
	return ""
}