first commit

This commit is contained in:
Tomas Dvorak
2026-04-10 12:01:36 +02:00
commit 035ac8ddb5
61 changed files with 6600 additions and 0 deletions
+335
View File
@@ -0,0 +1,335 @@
package billing
import (
"context"
"encoding/json"
"errors"
"fmt"
"slices"
"time"
"bookra/apps/backend/internal/config"
"bookra/apps/backend/internal/db"
"bookra/apps/backend/internal/domain"
"github.com/jackc/pgx/v5"
"github.com/stripe/stripe-go/v83"
"github.com/stripe/stripe-go/v83/checkout/session"
"github.com/stripe/stripe-go/v83/customer"
"github.com/stripe/stripe-go/v83/subscription"
"github.com/stripe/stripe-go/v83/webhook"
)
var (
ErrBillingMembership = errors.New("billing membership not found")
ErrBillingPlanUnsupported = errors.New("billing plan is not configured")
ErrStripeSignatureMissing = errors.New("stripe signature is missing")
)
var allowedWebhookEvents = []string{
"checkout.session.completed",
"customer.subscription.created",
"customer.subscription.updated",
"customer.subscription.deleted",
"customer.subscription.paused",
"customer.subscription.resumed",
"invoice.paid",
"invoice.payment_failed",
"payment_intent.succeeded",
"payment_intent.payment_failed",
}
type Service struct {
cfg config.Config
repo db.Repository
}
func NewService(cfg config.Config, repo db.Repository) *Service {
return &Service{cfg: cfg, repo: repo}
}
func (s *Service) GetSubscription(ctx context.Context, principal domain.Principal) (domain.SubscriptionSnapshot, error) {
membership, err := s.repo.GetTenantMembershipByUserID(ctx, principal.Subject)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return domain.SubscriptionSnapshot{}, ErrBillingMembership
}
return domain.SubscriptionSnapshot{}, err
}
record, err := s.repo.GetSubscriptionSnapshot(ctx, membership.Tenant.ID)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return toSnapshot(membership.Tenant, db.BillingSnapshotRecord{
TenantID: membership.Tenant.ID,
StripeCustomerID: derefString(membership.Tenant.StripeCustomerID),
Status: membership.Tenant.SubscriptionStatus,
PlanCode: membership.Tenant.PlanCode,
}, s.cfg), nil
}
return domain.SubscriptionSnapshot{}, err
}
return toSnapshot(membership.Tenant, record, s.cfg), nil
}
func (s *Service) CreateCheckoutSession(ctx context.Context, principal domain.Principal, planCode string) (domain.CheckoutSessionResponse, error) {
membership, err := s.repo.GetTenantMembershipByUserID(ctx, principal.Subject)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return domain.CheckoutSessionResponse{}, ErrBillingMembership
}
return domain.CheckoutSessionResponse{}, err
}
priceID := s.cfg.StripePriceIDs[planCode]
if priceID == "" {
return domain.CheckoutSessionResponse{}, ErrBillingPlanUnsupported
}
if s.cfg.StripeSecretKey == "" {
mockURL := fmt.Sprintf("%s/dashboard?billing=mock-checkout&plan=%s", s.cfg.FrontendURL, planCode)
return domain.CheckoutSessionResponse{URL: mockURL}, nil
}
stripe.Key = s.cfg.StripeSecretKey
customerID := derefString(membership.Tenant.StripeCustomerID)
if customerID == "" {
params := &stripe.CustomerParams{
Name: stripe.String(membership.Tenant.Name),
Email: stripe.String(principal.Email),
Metadata: map[string]string{"tenant_id": membership.Tenant.ID, "tenant_slug": membership.Tenant.Slug},
}
createdCustomer, err := customer.New(params)
if err != nil {
return domain.CheckoutSessionResponse{}, err
}
customerID = createdCustomer.ID
if err := s.repo.UpdateTenantStripeCustomerID(ctx, membership.Tenant.ID, customerID); err != nil {
return domain.CheckoutSessionResponse{}, err
}
}
params := &stripe.CheckoutSessionParams{
Mode: stripe.String(string(stripe.CheckoutSessionModeSubscription)),
Customer: stripe.String(customerID),
SuccessURL: stripe.String(fmt.Sprintf("%s/dashboard?billing=success", s.cfg.FrontendURL)),
CancelURL: stripe.String(fmt.Sprintf("%s/dashboard?billing=cancelled", s.cfg.FrontendURL)),
LineItems: []*stripe.CheckoutSessionLineItemParams{
{
Price: stripe.String(priceID),
Quantity: stripe.Int64(1),
},
},
Metadata: map[string]string{
"tenant_id": membership.Tenant.ID,
"plan_code": planCode,
},
}
checkoutSession, err := session.New(params)
if err != nil {
return domain.CheckoutSessionResponse{}, err
}
return domain.CheckoutSessionResponse{URL: checkoutSession.URL}, nil
}
func (s *Service) Refresh(ctx context.Context, principal domain.Principal) (domain.SubscriptionSnapshot, error) {
membership, err := s.repo.GetTenantMembershipByUserID(ctx, principal.Subject)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return domain.SubscriptionSnapshot{}, ErrBillingMembership
}
return domain.SubscriptionSnapshot{}, err
}
customerID := derefString(membership.Tenant.StripeCustomerID)
if customerID == "" {
return toSnapshot(membership.Tenant, db.BillingSnapshotRecord{
TenantID: membership.Tenant.ID,
StripeCustomerID: "",
Status: "none",
PlanCode: membership.Tenant.PlanCode,
}, s.cfg), nil
}
record, err := s.syncStripeData(ctx, membership.Tenant, customerID)
if err != nil {
return domain.SubscriptionSnapshot{}, err
}
return toSnapshot(membership.Tenant, record, s.cfg), nil
}
func (s *Service) HandleWebhook(ctx context.Context, signature string, payload []byte) error {
if s.cfg.StripeSecretKey == "" {
return nil
}
if signature == "" {
return ErrStripeSignatureMissing
}
event, err := webhook.ConstructEvent(payload, signature, s.cfg.StripeWebhookKey)
if err != nil {
return err
}
if !slices.Contains(allowedWebhookEvents, string(event.Type)) {
return nil
}
customerID := extractCustomerID(event)
if customerID == "" {
return nil
}
tenant, err := s.repo.GetTenantByStripeCustomerID(ctx, customerID)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return nil
}
return err
}
inserted, err := s.repo.RecordStripeEvent(ctx, tenant.ID, event.ID, string(event.Type), payload)
if err != nil || !inserted {
return err
}
_, err = s.syncStripeData(ctx, tenant, customerID)
return err
}
func (s *Service) syncStripeData(ctx context.Context, tenant db.TenantRecord, customerID string) (db.BillingSnapshotRecord, error) {
if s.cfg.StripeSecretKey == "" {
now := time.Now().UTC()
record := db.BillingSnapshotRecord{
TenantID: tenant.ID,
StripeCustomerID: customerID,
StripeSubscriptionID: "",
Status: tenant.SubscriptionStatus,
PlanCode: tenant.PlanCode,
PriceID: s.cfg.StripePriceIDs[tenant.PlanCode],
LastSyncedAt: &now,
}
if err := s.repo.UpsertSubscriptionSnapshot(ctx, record); err != nil {
return db.BillingSnapshotRecord{}, err
}
return record, nil
}
stripe.Key = s.cfg.StripeSecretKey
params := &stripe.SubscriptionListParams{Customer: stripe.String(customerID)}
params.Status = stripe.String("all")
params.AddExpand("data.default_payment_method")
params.AddExpand("data.items.data.price")
iter := subscription.List(params)
if iter.Err() != nil {
return db.BillingSnapshotRecord{}, iter.Err()
}
now := time.Now().UTC()
record := db.BillingSnapshotRecord{
TenantID: tenant.ID,
StripeCustomerID: customerID,
StripeSubscriptionID: "",
Status: "none",
PlanCode: tenant.PlanCode,
PriceID: "",
LastSyncedAt: &now,
}
if iter.Next() {
subscriptionRecord := iter.Subscription()
record.StripeSubscriptionID = subscriptionRecord.ID
record.Status = string(subscriptionRecord.Status)
record.CancelAtPeriodEnd = subscriptionRecord.CancelAtPeriodEnd
if len(subscriptionRecord.Items.Data) > 0 {
record.PriceID = subscriptionRecord.Items.Data[0].Price.ID
record.PlanCode = s.planCodeForPrice(record.PriceID, tenant.PlanCode)
record.CurrentPeriodStart = toTimePtr(subscriptionRecord.Items.Data[0].CurrentPeriodStart)
record.CurrentPeriodEnd = toTimePtr(subscriptionRecord.Items.Data[0].CurrentPeriodEnd)
}
if subscriptionRecord.DefaultPaymentMethod != nil && subscriptionRecord.DefaultPaymentMethod.Card != nil {
record.PaymentMethodBrand = string(subscriptionRecord.DefaultPaymentMethod.Card.Brand)
record.PaymentMethodLast4 = subscriptionRecord.DefaultPaymentMethod.Card.Last4
}
}
if err := s.repo.UpsertSubscriptionSnapshot(ctx, record); err != nil {
return db.BillingSnapshotRecord{}, err
}
if err := s.repo.UpdateTenantBillingState(ctx, tenant.ID, record.PlanCode, record.Status, record.StripeSubscriptionID); err != nil {
return db.BillingSnapshotRecord{}, err
}
return record, nil
}
func toSnapshot(tenant db.TenantRecord, record db.BillingSnapshotRecord, cfg config.Config) domain.SubscriptionSnapshot {
if record.PlanCode == "" {
record.PlanCode = tenant.PlanCode
}
if record.Status == "" {
record.Status = tenant.SubscriptionStatus
}
return domain.SubscriptionSnapshot{
TenantID: tenant.ID,
CustomerID: record.StripeCustomerID,
SubscriptionID: record.StripeSubscriptionID,
Status: record.Status,
PlanCode: record.PlanCode,
PriceID: record.PriceID,
CancelAtPeriodEnd: record.CancelAtPeriodEnd,
CurrentPeriodStart: record.CurrentPeriodStart,
CurrentPeriodEnd: record.CurrentPeriodEnd,
PaymentMethodBrand: record.PaymentMethodBrand,
PaymentMethodLast4: record.PaymentMethodLast4,
Entitlements: entitlementsForPlan(record.PlanCode),
LastSyncedAt: record.LastSyncedAt,
CheckoutURLAvailable: cfg.StripePriceIDs[record.PlanCode] != "",
}
}
func entitlementsForPlan(planCode string) domain.PlanEntitlements {
switch planCode {
case "starter":
return domain.PlanEntitlements{MaxLocations: 1, MaxStaff: 3, SMSAddonAvailable: false, AdvancedReporting: false}
case "multi-location":
return domain.PlanEntitlements{MaxLocations: 10, MaxStaff: 30, SMSAddonAvailable: true, AdvancedReporting: true}
default:
return domain.PlanEntitlements{MaxLocations: 3, MaxStaff: 10, SMSAddonAvailable: true, AdvancedReporting: true}
}
}
func (s *Service) planCodeForPrice(priceID string, fallback string) string {
for code, configuredPriceID := range s.cfg.StripePriceIDs {
if configuredPriceID != "" && configuredPriceID == priceID {
return code
}
}
return fallback
}
func derefString(value *string) string {
if value == nil {
return ""
}
return *value
}
func toTimePtr(value int64) *time.Time {
if value == 0 {
return nil
}
timestamp := time.Unix(value, 0).UTC()
return &timestamp
}
func extractCustomerID(event stripe.Event) string {
var payload map[string]any
if err := json.Unmarshal(event.Data.Raw, &payload); err != nil {
return ""
}
value, ok := payload["customer"]
if !ok {
return ""
}
customerID, _ := value.(string)
return customerID
}