mirror of
https://github.com/Dvorinka/Bookra.git
synced 2026-06-04 12:33:00 +00:00
336 lines
10 KiB
Go
336 lines
10 KiB
Go
package billing
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"slices"
|
|
"time"
|
|
|
|
"bookra/apps/backend/internal/config"
|
|
"bookra/apps/backend/internal/db"
|
|
"bookra/apps/backend/internal/domain"
|
|
|
|
"github.com/jackc/pgx/v5"
|
|
"github.com/stripe/stripe-go/v83"
|
|
"github.com/stripe/stripe-go/v83/checkout/session"
|
|
"github.com/stripe/stripe-go/v83/customer"
|
|
"github.com/stripe/stripe-go/v83/subscription"
|
|
"github.com/stripe/stripe-go/v83/webhook"
|
|
)
|
|
|
|
// Sentinel errors returned by Service; match them with errors.Is.
var (
	// ErrBillingMembership is returned when the principal has no tenant
	// membership (the membership lookup found no rows).
	ErrBillingMembership = errors.New("billing membership not found")

	// ErrBillingPlanUnsupported is returned when the requested plan code has
	// no Stripe price ID in the configuration.
	ErrBillingPlanUnsupported = errors.New("billing plan is not configured")

	// ErrStripeSignatureMissing is returned by HandleWebhook when it is
	// called with an empty signature.
	ErrStripeSignatureMissing = errors.New("stripe signature is missing")
)
|
|
|
|
// allowedWebhookEvents is the set of Stripe event types HandleWebhook acts on.
// Deliveries of any other type are acknowledged (nil error) and ignored.
var allowedWebhookEvents = []string{
	// Checkout
	"checkout.session.completed",

	// Subscription lifecycle
	"customer.subscription.created",
	"customer.subscription.updated",
	"customer.subscription.deleted",
	"customer.subscription.paused",
	"customer.subscription.resumed",

	// Invoices
	"invoice.paid",
	"invoice.payment_failed",

	// Payment intents
	"payment_intent.succeeded",
	"payment_intent.payment_failed",
}
|
|
|
|
type Service struct {
|
|
cfg config.Config
|
|
repo db.Repository
|
|
}
|
|
|
|
func NewService(cfg config.Config, repo db.Repository) *Service {
|
|
return &Service{cfg: cfg, repo: repo}
|
|
}
|
|
|
|
func (s *Service) GetSubscription(ctx context.Context, principal domain.Principal) (domain.SubscriptionSnapshot, error) {
|
|
membership, err := s.repo.GetTenantMembershipByUserID(ctx, principal.Subject)
|
|
if err != nil {
|
|
if errors.Is(err, pgx.ErrNoRows) {
|
|
return domain.SubscriptionSnapshot{}, ErrBillingMembership
|
|
}
|
|
return domain.SubscriptionSnapshot{}, err
|
|
}
|
|
record, err := s.repo.GetSubscriptionSnapshot(ctx, membership.Tenant.ID)
|
|
if err != nil {
|
|
if errors.Is(err, pgx.ErrNoRows) {
|
|
return toSnapshot(membership.Tenant, db.BillingSnapshotRecord{
|
|
TenantID: membership.Tenant.ID,
|
|
StripeCustomerID: derefString(membership.Tenant.StripeCustomerID),
|
|
Status: membership.Tenant.SubscriptionStatus,
|
|
PlanCode: membership.Tenant.PlanCode,
|
|
}, s.cfg), nil
|
|
}
|
|
return domain.SubscriptionSnapshot{}, err
|
|
}
|
|
return toSnapshot(membership.Tenant, record, s.cfg), nil
|
|
}
|
|
|
|
func (s *Service) CreateCheckoutSession(ctx context.Context, principal domain.Principal, planCode string) (domain.CheckoutSessionResponse, error) {
|
|
membership, err := s.repo.GetTenantMembershipByUserID(ctx, principal.Subject)
|
|
if err != nil {
|
|
if errors.Is(err, pgx.ErrNoRows) {
|
|
return domain.CheckoutSessionResponse{}, ErrBillingMembership
|
|
}
|
|
return domain.CheckoutSessionResponse{}, err
|
|
}
|
|
|
|
priceID := s.cfg.StripePriceIDs[planCode]
|
|
if priceID == "" {
|
|
return domain.CheckoutSessionResponse{}, ErrBillingPlanUnsupported
|
|
}
|
|
|
|
if s.cfg.StripeSecretKey == "" {
|
|
mockURL := fmt.Sprintf("%s/dashboard?billing=mock-checkout&plan=%s", s.cfg.FrontendURL, planCode)
|
|
return domain.CheckoutSessionResponse{URL: mockURL}, nil
|
|
}
|
|
|
|
stripe.Key = s.cfg.StripeSecretKey
|
|
customerID := derefString(membership.Tenant.StripeCustomerID)
|
|
if customerID == "" {
|
|
params := &stripe.CustomerParams{
|
|
Name: stripe.String(membership.Tenant.Name),
|
|
Email: stripe.String(principal.Email),
|
|
Metadata: map[string]string{"tenant_id": membership.Tenant.ID, "tenant_slug": membership.Tenant.Slug},
|
|
}
|
|
createdCustomer, err := customer.New(params)
|
|
if err != nil {
|
|
return domain.CheckoutSessionResponse{}, err
|
|
}
|
|
customerID = createdCustomer.ID
|
|
if err := s.repo.UpdateTenantStripeCustomerID(ctx, membership.Tenant.ID, customerID); err != nil {
|
|
return domain.CheckoutSessionResponse{}, err
|
|
}
|
|
}
|
|
|
|
params := &stripe.CheckoutSessionParams{
|
|
Mode: stripe.String(string(stripe.CheckoutSessionModeSubscription)),
|
|
Customer: stripe.String(customerID),
|
|
SuccessURL: stripe.String(fmt.Sprintf("%s/dashboard?billing=success", s.cfg.FrontendURL)),
|
|
CancelURL: stripe.String(fmt.Sprintf("%s/dashboard?billing=cancelled", s.cfg.FrontendURL)),
|
|
LineItems: []*stripe.CheckoutSessionLineItemParams{
|
|
{
|
|
Price: stripe.String(priceID),
|
|
Quantity: stripe.Int64(1),
|
|
},
|
|
},
|
|
Metadata: map[string]string{
|
|
"tenant_id": membership.Tenant.ID,
|
|
"plan_code": planCode,
|
|
},
|
|
}
|
|
|
|
checkoutSession, err := session.New(params)
|
|
if err != nil {
|
|
return domain.CheckoutSessionResponse{}, err
|
|
}
|
|
return domain.CheckoutSessionResponse{URL: checkoutSession.URL}, nil
|
|
}
|
|
|
|
func (s *Service) Refresh(ctx context.Context, principal domain.Principal) (domain.SubscriptionSnapshot, error) {
|
|
membership, err := s.repo.GetTenantMembershipByUserID(ctx, principal.Subject)
|
|
if err != nil {
|
|
if errors.Is(err, pgx.ErrNoRows) {
|
|
return domain.SubscriptionSnapshot{}, ErrBillingMembership
|
|
}
|
|
return domain.SubscriptionSnapshot{}, err
|
|
}
|
|
customerID := derefString(membership.Tenant.StripeCustomerID)
|
|
if customerID == "" {
|
|
return toSnapshot(membership.Tenant, db.BillingSnapshotRecord{
|
|
TenantID: membership.Tenant.ID,
|
|
StripeCustomerID: "",
|
|
Status: "none",
|
|
PlanCode: membership.Tenant.PlanCode,
|
|
}, s.cfg), nil
|
|
}
|
|
record, err := s.syncStripeData(ctx, membership.Tenant, customerID)
|
|
if err != nil {
|
|
return domain.SubscriptionSnapshot{}, err
|
|
}
|
|
return toSnapshot(membership.Tenant, record, s.cfg), nil
|
|
}
|
|
|
|
func (s *Service) HandleWebhook(ctx context.Context, signature string, payload []byte) error {
|
|
if s.cfg.StripeSecretKey == "" {
|
|
return nil
|
|
}
|
|
if signature == "" {
|
|
return ErrStripeSignatureMissing
|
|
}
|
|
|
|
event, err := webhook.ConstructEvent(payload, signature, s.cfg.StripeWebhookKey)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if !slices.Contains(allowedWebhookEvents, string(event.Type)) {
|
|
return nil
|
|
}
|
|
|
|
customerID := extractCustomerID(event)
|
|
if customerID == "" {
|
|
return nil
|
|
}
|
|
|
|
tenant, err := s.repo.GetTenantByStripeCustomerID(ctx, customerID)
|
|
if err != nil {
|
|
if errors.Is(err, pgx.ErrNoRows) {
|
|
return nil
|
|
}
|
|
return err
|
|
}
|
|
|
|
inserted, err := s.repo.RecordStripeEvent(ctx, tenant.ID, event.ID, string(event.Type), payload)
|
|
if err != nil || !inserted {
|
|
return err
|
|
}
|
|
|
|
_, err = s.syncStripeData(ctx, tenant, customerID)
|
|
return err
|
|
}
|
|
|
|
func (s *Service) syncStripeData(ctx context.Context, tenant db.TenantRecord, customerID string) (db.BillingSnapshotRecord, error) {
|
|
if s.cfg.StripeSecretKey == "" {
|
|
now := time.Now().UTC()
|
|
record := db.BillingSnapshotRecord{
|
|
TenantID: tenant.ID,
|
|
StripeCustomerID: customerID,
|
|
StripeSubscriptionID: "",
|
|
Status: tenant.SubscriptionStatus,
|
|
PlanCode: tenant.PlanCode,
|
|
PriceID: s.cfg.StripePriceIDs[tenant.PlanCode],
|
|
LastSyncedAt: &now,
|
|
}
|
|
if err := s.repo.UpsertSubscriptionSnapshot(ctx, record); err != nil {
|
|
return db.BillingSnapshotRecord{}, err
|
|
}
|
|
return record, nil
|
|
}
|
|
|
|
stripe.Key = s.cfg.StripeSecretKey
|
|
params := &stripe.SubscriptionListParams{Customer: stripe.String(customerID)}
|
|
params.Status = stripe.String("all")
|
|
params.AddExpand("data.default_payment_method")
|
|
params.AddExpand("data.items.data.price")
|
|
|
|
iter := subscription.List(params)
|
|
if iter.Err() != nil {
|
|
return db.BillingSnapshotRecord{}, iter.Err()
|
|
}
|
|
|
|
now := time.Now().UTC()
|
|
record := db.BillingSnapshotRecord{
|
|
TenantID: tenant.ID,
|
|
StripeCustomerID: customerID,
|
|
StripeSubscriptionID: "",
|
|
Status: "none",
|
|
PlanCode: tenant.PlanCode,
|
|
PriceID: "",
|
|
LastSyncedAt: &now,
|
|
}
|
|
|
|
if iter.Next() {
|
|
subscriptionRecord := iter.Subscription()
|
|
record.StripeSubscriptionID = subscriptionRecord.ID
|
|
record.Status = string(subscriptionRecord.Status)
|
|
record.CancelAtPeriodEnd = subscriptionRecord.CancelAtPeriodEnd
|
|
if len(subscriptionRecord.Items.Data) > 0 {
|
|
record.PriceID = subscriptionRecord.Items.Data[0].Price.ID
|
|
record.PlanCode = s.planCodeForPrice(record.PriceID, tenant.PlanCode)
|
|
record.CurrentPeriodStart = toTimePtr(subscriptionRecord.Items.Data[0].CurrentPeriodStart)
|
|
record.CurrentPeriodEnd = toTimePtr(subscriptionRecord.Items.Data[0].CurrentPeriodEnd)
|
|
}
|
|
if subscriptionRecord.DefaultPaymentMethod != nil && subscriptionRecord.DefaultPaymentMethod.Card != nil {
|
|
record.PaymentMethodBrand = string(subscriptionRecord.DefaultPaymentMethod.Card.Brand)
|
|
record.PaymentMethodLast4 = subscriptionRecord.DefaultPaymentMethod.Card.Last4
|
|
}
|
|
}
|
|
|
|
if err := s.repo.UpsertSubscriptionSnapshot(ctx, record); err != nil {
|
|
return db.BillingSnapshotRecord{}, err
|
|
}
|
|
if err := s.repo.UpdateTenantBillingState(ctx, tenant.ID, record.PlanCode, record.Status, record.StripeSubscriptionID); err != nil {
|
|
return db.BillingSnapshotRecord{}, err
|
|
}
|
|
return record, nil
|
|
}
|
|
|
|
func toSnapshot(tenant db.TenantRecord, record db.BillingSnapshotRecord, cfg config.Config) domain.SubscriptionSnapshot {
|
|
if record.PlanCode == "" {
|
|
record.PlanCode = tenant.PlanCode
|
|
}
|
|
if record.Status == "" {
|
|
record.Status = tenant.SubscriptionStatus
|
|
}
|
|
return domain.SubscriptionSnapshot{
|
|
TenantID: tenant.ID,
|
|
CustomerID: record.StripeCustomerID,
|
|
SubscriptionID: record.StripeSubscriptionID,
|
|
Status: record.Status,
|
|
PlanCode: record.PlanCode,
|
|
PriceID: record.PriceID,
|
|
CancelAtPeriodEnd: record.CancelAtPeriodEnd,
|
|
CurrentPeriodStart: record.CurrentPeriodStart,
|
|
CurrentPeriodEnd: record.CurrentPeriodEnd,
|
|
PaymentMethodBrand: record.PaymentMethodBrand,
|
|
PaymentMethodLast4: record.PaymentMethodLast4,
|
|
Entitlements: entitlementsForPlan(record.PlanCode),
|
|
LastSyncedAt: record.LastSyncedAt,
|
|
CheckoutURLAvailable: cfg.StripePriceIDs[record.PlanCode] != "",
|
|
}
|
|
}
|
|
|
|
func entitlementsForPlan(planCode string) domain.PlanEntitlements {
|
|
switch planCode {
|
|
case "starter":
|
|
return domain.PlanEntitlements{MaxLocations: 1, MaxStaff: 3, SMSAddonAvailable: false, AdvancedReporting: false}
|
|
case "multi-location":
|
|
return domain.PlanEntitlements{MaxLocations: 10, MaxStaff: 30, SMSAddonAvailable: true, AdvancedReporting: true}
|
|
default:
|
|
return domain.PlanEntitlements{MaxLocations: 3, MaxStaff: 10, SMSAddonAvailable: true, AdvancedReporting: true}
|
|
}
|
|
}
|
|
|
|
func (s *Service) planCodeForPrice(priceID string, fallback string) string {
|
|
for code, configuredPriceID := range s.cfg.StripePriceIDs {
|
|
if configuredPriceID != "" && configuredPriceID == priceID {
|
|
return code
|
|
}
|
|
}
|
|
return fallback
|
|
}
|
|
|
|
// derefString returns the string value points to, or "" when value is nil.
func derefString(value *string) string {
	var out string
	if value != nil {
		out = *value
	}
	return out
}
|
|
|
|
// toTimePtr converts a Unix timestamp in seconds (the form Stripe reports
// period boundaries in) into a UTC *time.Time. Zero is treated as "not set"
// and yields nil.
func toTimePtr(value int64) *time.Time {
	if value == 0 {
		return nil
	}
	timestamp := time.Unix(value, 0).UTC()
	// Was a mis-encoded "×tamp" (HTML entity corruption of "&timestamp").
	return &timestamp
}
|
|
|
|
func extractCustomerID(event stripe.Event) string {
|
|
var payload map[string]any
|
|
if err := json.Unmarshal(event.Data.Raw, &payload); err != nil {
|
|
return ""
|
|
}
|
|
|
|
value, ok := payload["customer"]
|
|
if !ok {
|
|
return ""
|
|
}
|
|
customerID, _ := value.(string)
|
|
return customerID
|
|
}
|