mirror of
https://github.com/Dvorinka/Bookra.git
synced 2026-06-03 20:13:00 +00:00
first commit
This commit is contained in:
@@ -0,0 +1,335 @@
|
||||
package billing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"bookra/apps/backend/internal/config"
|
||||
"bookra/apps/backend/internal/db"
|
||||
"bookra/apps/backend/internal/domain"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/stripe/stripe-go/v83"
|
||||
"github.com/stripe/stripe-go/v83/checkout/session"
|
||||
"github.com/stripe/stripe-go/v83/customer"
|
||||
"github.com/stripe/stripe-go/v83/subscription"
|
||||
"github.com/stripe/stripe-go/v83/webhook"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrBillingMembership = errors.New("billing membership not found")
|
||||
ErrBillingPlanUnsupported = errors.New("billing plan is not configured")
|
||||
ErrStripeSignatureMissing = errors.New("stripe signature is missing")
|
||||
)
|
||||
|
||||
var allowedWebhookEvents = []string{
|
||||
"checkout.session.completed",
|
||||
"customer.subscription.created",
|
||||
"customer.subscription.updated",
|
||||
"customer.subscription.deleted",
|
||||
"customer.subscription.paused",
|
||||
"customer.subscription.resumed",
|
||||
"invoice.paid",
|
||||
"invoice.payment_failed",
|
||||
"payment_intent.succeeded",
|
||||
"payment_intent.payment_failed",
|
||||
}
|
||||
|
||||
type Service struct {
|
||||
cfg config.Config
|
||||
repo db.Repository
|
||||
}
|
||||
|
||||
func NewService(cfg config.Config, repo db.Repository) *Service {
|
||||
return &Service{cfg: cfg, repo: repo}
|
||||
}
|
||||
|
||||
func (s *Service) GetSubscription(ctx context.Context, principal domain.Principal) (domain.SubscriptionSnapshot, error) {
|
||||
membership, err := s.repo.GetTenantMembershipByUserID(ctx, principal.Subject)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return domain.SubscriptionSnapshot{}, ErrBillingMembership
|
||||
}
|
||||
return domain.SubscriptionSnapshot{}, err
|
||||
}
|
||||
record, err := s.repo.GetSubscriptionSnapshot(ctx, membership.Tenant.ID)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return toSnapshot(membership.Tenant, db.BillingSnapshotRecord{
|
||||
TenantID: membership.Tenant.ID,
|
||||
StripeCustomerID: derefString(membership.Tenant.StripeCustomerID),
|
||||
Status: membership.Tenant.SubscriptionStatus,
|
||||
PlanCode: membership.Tenant.PlanCode,
|
||||
}, s.cfg), nil
|
||||
}
|
||||
return domain.SubscriptionSnapshot{}, err
|
||||
}
|
||||
return toSnapshot(membership.Tenant, record, s.cfg), nil
|
||||
}
|
||||
|
||||
func (s *Service) CreateCheckoutSession(ctx context.Context, principal domain.Principal, planCode string) (domain.CheckoutSessionResponse, error) {
|
||||
membership, err := s.repo.GetTenantMembershipByUserID(ctx, principal.Subject)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return domain.CheckoutSessionResponse{}, ErrBillingMembership
|
||||
}
|
||||
return domain.CheckoutSessionResponse{}, err
|
||||
}
|
||||
|
||||
priceID := s.cfg.StripePriceIDs[planCode]
|
||||
if priceID == "" {
|
||||
return domain.CheckoutSessionResponse{}, ErrBillingPlanUnsupported
|
||||
}
|
||||
|
||||
if s.cfg.StripeSecretKey == "" {
|
||||
mockURL := fmt.Sprintf("%s/dashboard?billing=mock-checkout&plan=%s", s.cfg.FrontendURL, planCode)
|
||||
return domain.CheckoutSessionResponse{URL: mockURL}, nil
|
||||
}
|
||||
|
||||
stripe.Key = s.cfg.StripeSecretKey
|
||||
customerID := derefString(membership.Tenant.StripeCustomerID)
|
||||
if customerID == "" {
|
||||
params := &stripe.CustomerParams{
|
||||
Name: stripe.String(membership.Tenant.Name),
|
||||
Email: stripe.String(principal.Email),
|
||||
Metadata: map[string]string{"tenant_id": membership.Tenant.ID, "tenant_slug": membership.Tenant.Slug},
|
||||
}
|
||||
createdCustomer, err := customer.New(params)
|
||||
if err != nil {
|
||||
return domain.CheckoutSessionResponse{}, err
|
||||
}
|
||||
customerID = createdCustomer.ID
|
||||
if err := s.repo.UpdateTenantStripeCustomerID(ctx, membership.Tenant.ID, customerID); err != nil {
|
||||
return domain.CheckoutSessionResponse{}, err
|
||||
}
|
||||
}
|
||||
|
||||
params := &stripe.CheckoutSessionParams{
|
||||
Mode: stripe.String(string(stripe.CheckoutSessionModeSubscription)),
|
||||
Customer: stripe.String(customerID),
|
||||
SuccessURL: stripe.String(fmt.Sprintf("%s/dashboard?billing=success", s.cfg.FrontendURL)),
|
||||
CancelURL: stripe.String(fmt.Sprintf("%s/dashboard?billing=cancelled", s.cfg.FrontendURL)),
|
||||
LineItems: []*stripe.CheckoutSessionLineItemParams{
|
||||
{
|
||||
Price: stripe.String(priceID),
|
||||
Quantity: stripe.Int64(1),
|
||||
},
|
||||
},
|
||||
Metadata: map[string]string{
|
||||
"tenant_id": membership.Tenant.ID,
|
||||
"plan_code": planCode,
|
||||
},
|
||||
}
|
||||
|
||||
checkoutSession, err := session.New(params)
|
||||
if err != nil {
|
||||
return domain.CheckoutSessionResponse{}, err
|
||||
}
|
||||
return domain.CheckoutSessionResponse{URL: checkoutSession.URL}, nil
|
||||
}
|
||||
|
||||
func (s *Service) Refresh(ctx context.Context, principal domain.Principal) (domain.SubscriptionSnapshot, error) {
|
||||
membership, err := s.repo.GetTenantMembershipByUserID(ctx, principal.Subject)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return domain.SubscriptionSnapshot{}, ErrBillingMembership
|
||||
}
|
||||
return domain.SubscriptionSnapshot{}, err
|
||||
}
|
||||
customerID := derefString(membership.Tenant.StripeCustomerID)
|
||||
if customerID == "" {
|
||||
return toSnapshot(membership.Tenant, db.BillingSnapshotRecord{
|
||||
TenantID: membership.Tenant.ID,
|
||||
StripeCustomerID: "",
|
||||
Status: "none",
|
||||
PlanCode: membership.Tenant.PlanCode,
|
||||
}, s.cfg), nil
|
||||
}
|
||||
record, err := s.syncStripeData(ctx, membership.Tenant, customerID)
|
||||
if err != nil {
|
||||
return domain.SubscriptionSnapshot{}, err
|
||||
}
|
||||
return toSnapshot(membership.Tenant, record, s.cfg), nil
|
||||
}
|
||||
|
||||
func (s *Service) HandleWebhook(ctx context.Context, signature string, payload []byte) error {
|
||||
if s.cfg.StripeSecretKey == "" {
|
||||
return nil
|
||||
}
|
||||
if signature == "" {
|
||||
return ErrStripeSignatureMissing
|
||||
}
|
||||
|
||||
event, err := webhook.ConstructEvent(payload, signature, s.cfg.StripeWebhookKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !slices.Contains(allowedWebhookEvents, string(event.Type)) {
|
||||
return nil
|
||||
}
|
||||
|
||||
customerID := extractCustomerID(event)
|
||||
if customerID == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
tenant, err := s.repo.GetTenantByStripeCustomerID(ctx, customerID)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
inserted, err := s.repo.RecordStripeEvent(ctx, tenant.ID, event.ID, string(event.Type), payload)
|
||||
if err != nil || !inserted {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = s.syncStripeData(ctx, tenant, customerID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Service) syncStripeData(ctx context.Context, tenant db.TenantRecord, customerID string) (db.BillingSnapshotRecord, error) {
|
||||
if s.cfg.StripeSecretKey == "" {
|
||||
now := time.Now().UTC()
|
||||
record := db.BillingSnapshotRecord{
|
||||
TenantID: tenant.ID,
|
||||
StripeCustomerID: customerID,
|
||||
StripeSubscriptionID: "",
|
||||
Status: tenant.SubscriptionStatus,
|
||||
PlanCode: tenant.PlanCode,
|
||||
PriceID: s.cfg.StripePriceIDs[tenant.PlanCode],
|
||||
LastSyncedAt: &now,
|
||||
}
|
||||
if err := s.repo.UpsertSubscriptionSnapshot(ctx, record); err != nil {
|
||||
return db.BillingSnapshotRecord{}, err
|
||||
}
|
||||
return record, nil
|
||||
}
|
||||
|
||||
stripe.Key = s.cfg.StripeSecretKey
|
||||
params := &stripe.SubscriptionListParams{Customer: stripe.String(customerID)}
|
||||
params.Status = stripe.String("all")
|
||||
params.AddExpand("data.default_payment_method")
|
||||
params.AddExpand("data.items.data.price")
|
||||
|
||||
iter := subscription.List(params)
|
||||
if iter.Err() != nil {
|
||||
return db.BillingSnapshotRecord{}, iter.Err()
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
record := db.BillingSnapshotRecord{
|
||||
TenantID: tenant.ID,
|
||||
StripeCustomerID: customerID,
|
||||
StripeSubscriptionID: "",
|
||||
Status: "none",
|
||||
PlanCode: tenant.PlanCode,
|
||||
PriceID: "",
|
||||
LastSyncedAt: &now,
|
||||
}
|
||||
|
||||
if iter.Next() {
|
||||
subscriptionRecord := iter.Subscription()
|
||||
record.StripeSubscriptionID = subscriptionRecord.ID
|
||||
record.Status = string(subscriptionRecord.Status)
|
||||
record.CancelAtPeriodEnd = subscriptionRecord.CancelAtPeriodEnd
|
||||
if len(subscriptionRecord.Items.Data) > 0 {
|
||||
record.PriceID = subscriptionRecord.Items.Data[0].Price.ID
|
||||
record.PlanCode = s.planCodeForPrice(record.PriceID, tenant.PlanCode)
|
||||
record.CurrentPeriodStart = toTimePtr(subscriptionRecord.Items.Data[0].CurrentPeriodStart)
|
||||
record.CurrentPeriodEnd = toTimePtr(subscriptionRecord.Items.Data[0].CurrentPeriodEnd)
|
||||
}
|
||||
if subscriptionRecord.DefaultPaymentMethod != nil && subscriptionRecord.DefaultPaymentMethod.Card != nil {
|
||||
record.PaymentMethodBrand = string(subscriptionRecord.DefaultPaymentMethod.Card.Brand)
|
||||
record.PaymentMethodLast4 = subscriptionRecord.DefaultPaymentMethod.Card.Last4
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.repo.UpsertSubscriptionSnapshot(ctx, record); err != nil {
|
||||
return db.BillingSnapshotRecord{}, err
|
||||
}
|
||||
if err := s.repo.UpdateTenantBillingState(ctx, tenant.ID, record.PlanCode, record.Status, record.StripeSubscriptionID); err != nil {
|
||||
return db.BillingSnapshotRecord{}, err
|
||||
}
|
||||
return record, nil
|
||||
}
|
||||
|
||||
func toSnapshot(tenant db.TenantRecord, record db.BillingSnapshotRecord, cfg config.Config) domain.SubscriptionSnapshot {
|
||||
if record.PlanCode == "" {
|
||||
record.PlanCode = tenant.PlanCode
|
||||
}
|
||||
if record.Status == "" {
|
||||
record.Status = tenant.SubscriptionStatus
|
||||
}
|
||||
return domain.SubscriptionSnapshot{
|
||||
TenantID: tenant.ID,
|
||||
CustomerID: record.StripeCustomerID,
|
||||
SubscriptionID: record.StripeSubscriptionID,
|
||||
Status: record.Status,
|
||||
PlanCode: record.PlanCode,
|
||||
PriceID: record.PriceID,
|
||||
CancelAtPeriodEnd: record.CancelAtPeriodEnd,
|
||||
CurrentPeriodStart: record.CurrentPeriodStart,
|
||||
CurrentPeriodEnd: record.CurrentPeriodEnd,
|
||||
PaymentMethodBrand: record.PaymentMethodBrand,
|
||||
PaymentMethodLast4: record.PaymentMethodLast4,
|
||||
Entitlements: entitlementsForPlan(record.PlanCode),
|
||||
LastSyncedAt: record.LastSyncedAt,
|
||||
CheckoutURLAvailable: cfg.StripePriceIDs[record.PlanCode] != "",
|
||||
}
|
||||
}
|
||||
|
||||
func entitlementsForPlan(planCode string) domain.PlanEntitlements {
|
||||
switch planCode {
|
||||
case "starter":
|
||||
return domain.PlanEntitlements{MaxLocations: 1, MaxStaff: 3, SMSAddonAvailable: false, AdvancedReporting: false}
|
||||
case "multi-location":
|
||||
return domain.PlanEntitlements{MaxLocations: 10, MaxStaff: 30, SMSAddonAvailable: true, AdvancedReporting: true}
|
||||
default:
|
||||
return domain.PlanEntitlements{MaxLocations: 3, MaxStaff: 10, SMSAddonAvailable: true, AdvancedReporting: true}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) planCodeForPrice(priceID string, fallback string) string {
|
||||
for code, configuredPriceID := range s.cfg.StripePriceIDs {
|
||||
if configuredPriceID != "" && configuredPriceID == priceID {
|
||||
return code
|
||||
}
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func derefString(value *string) string {
|
||||
if value == nil {
|
||||
return ""
|
||||
}
|
||||
return *value
|
||||
}
|
||||
|
||||
func toTimePtr(value int64) *time.Time {
|
||||
if value == 0 {
|
||||
return nil
|
||||
}
|
||||
timestamp := time.Unix(value, 0).UTC()
|
||||
return ×tamp
|
||||
}
|
||||
|
||||
func extractCustomerID(event stripe.Event) string {
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(event.Data.Raw, &payload); err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
value, ok := payload["customer"]
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
customerID, _ := value.(string)
|
||||
return customerID
|
||||
}
|
||||
Reference in New Issue
Block a user