Files
Bookra/apps/auth-service/internal/billing/service.go
T
Tomas Dvorak 48c3e15a38 cleanup
2026-05-05 09:48:07 +02:00

465 lines
13 KiB
Go

package billing
import (
"context"
"encoding/json"
"errors"
"fmt"
"slices"
"strings"
"time"
"bookra/apps/auth-service/internal/config"
"bookra/apps/auth-service/internal/db"
"github.com/stripe/stripe-go/v83"
"github.com/stripe/stripe-go/v83/checkout/session"
"github.com/stripe/stripe-go/v83/customer"
"github.com/stripe/stripe-go/v83/subscription"
"github.com/stripe/stripe-go/v83/webhook"
)
// Sentinel errors returned by the billing service. They are returned
// unwrapped, so callers may compare them directly or with errors.Is.
var (
	// ErrStripeNotConfigured is returned when no Stripe secret key is set.
	ErrStripeNotConfigured = errors.New("stripe is not configured")
	// ErrStripeWebhookMissing is returned by HandleWebhook when Stripe is
	// enabled but no webhook signing secret is configured.
	ErrStripeWebhookMissing = errors.New("stripe webhook secret is not configured")
	// ErrStripeSignatureMissing is returned by HandleWebhook when the request
	// carries no Stripe-Signature value.
	ErrStripeSignatureMissing = errors.New("stripe signature is missing")
	// ErrPlanNotConfigured is returned when no Stripe price ID can be resolved
	// for the requested (or default) plan.
	ErrPlanNotConfigured = errors.New("stripe plan is not configured")
	// ErrCustomerMappingNotFound is not returned by any code in this file;
	// presumably used by callers — confirm before relying on it.
	ErrCustomerMappingNotFound = errors.New("stripe customer mapping not found")
)
// allowedWebhookEvents lists the Stripe event types HandleWebhook acts on.
// Any other verified event is acknowledged (nil error) and ignored.
var allowedWebhookEvents = []string{
	// Checkout completion.
	"checkout.session.completed",
	// Subscription lifecycle.
	"customer.subscription.created",
	"customer.subscription.updated",
	"customer.subscription.deleted",
	"customer.subscription.paused",
	"customer.subscription.resumed",
	// Invoices.
	"invoice.paid",
	"invoice.payment_failed",
	// Payment intents.
	"payment_intent.succeeded",
	"payment_intent.payment_failed",
}
// Service provides Stripe-backed billing for auth-service users. It creates
// Checkout sessions, mirrors each Stripe customer's subscription state into
// the KV store, and refreshes that mirror on demand or from webhooks.
type Service struct {
	cfg *config.Config // Stripe keys, price IDs, webhook secret and frontend URL
	db  *db.DB         // KV store holding user→customer mappings and snapshots
}
// CheckoutSession is returned by CreateCheckoutSession; URL is the Stripe
// Checkout session URL the client should be sent to.
type CheckoutSession struct {
	URL string `json:"url"`
}

// SubscriptionSnapshot is the cached view of one Stripe customer's
// subscription. It is stored in KV under customerSnapshotKey and returned to
// API clients. Status holds the Stripe subscription status, or "none" when the
// customer has no subscription. PlanCode is derived from PriceID via the
// configured price map. CheckoutURLAvailable and SyncAvailable are derived
// from configuration rather than from Stripe data.
type SubscriptionSnapshot struct {
	CustomerID           string         `json:"customerId,omitempty"`
	SubscriptionID       string         `json:"subscriptionId,omitempty"`
	Status               string         `json:"status"`
	PlanCode             string         `json:"planCode,omitempty"`
	Currency             string         `json:"currency,omitempty"`
	PriceID              string         `json:"priceId,omitempty"`
	CancelAtPeriodEnd    bool           `json:"cancelAtPeriodEnd"`
	CurrentPeriodStart   *time.Time     `json:"currentPeriodStart,omitempty"`
	CurrentPeriodEnd     *time.Time     `json:"currentPeriodEnd,omitempty"`
	PaymentMethod        *PaymentMethod `json:"paymentMethod,omitempty"`
	LastSyncedAt         *time.Time     `json:"lastSyncedAt,omitempty"`
	CheckoutURLAvailable bool           `json:"checkoutUrlAvailable"`
	SyncAvailable        bool           `json:"syncAvailable"`
}

// PaymentMethod summarizes the card attached as a subscription's default
// payment method.
type PaymentMethod struct {
	Brand string `json:"brand"`
	Last4 string `json:"last4"`
}

// UserIdentity carries the auth-service user fields needed to create a Stripe
// customer and tag Checkout sessions.
type UserIdentity struct {
	ID    string
	Email string
	Name  string
}

// userCustomerMapping is the KV record (key: userCustomerKey) that links an
// auth-service user to a Stripe customer ID.
type userCustomerMapping struct {
	CustomerID string    `json:"customerId"`
	UpdatedAt  time.Time `json:"updatedAt"`
}
// NewService returns a billing Service that reads Stripe settings from cfg
// and persists billing state in database. Neither argument is validated.
func NewService(cfg *config.Config, database *db.DB) *Service {
	svc := new(Service)
	svc.cfg = cfg
	svc.db = database
	return svc
}
// GetSubscription returns the cached subscription snapshot for userID, reading
// only the KV store (no Stripe call).
//
// A user without a customer mapping gets noneSnapshot(). A mapped customer
// whose snapshot was never stored gets a "none" snapshot carrying the customer
// ID. CheckoutURLAvailable and SyncAvailable are always recomputed from the
// current configuration rather than taken from the stored value.
func (s *Service) GetSubscription(ctx context.Context, userID string) (SubscriptionSnapshot, error) {
	mapping, found, err := s.getCustomerMapping(ctx, userID)
	switch {
	case err != nil:
		return SubscriptionSnapshot{}, err
	case !found:
		return s.noneSnapshot(), nil
	}

	cached, cachedFound, err := s.getCustomerSnapshot(ctx, mapping.CustomerID)
	if err != nil {
		return SubscriptionSnapshot{}, err
	}

	result := SubscriptionSnapshot{CustomerID: mapping.CustomerID, Status: "none"}
	if cachedFound {
		result = cached
	}
	result.CheckoutURLAvailable = s.checkoutAvailableForPlan(result.PlanCode)
	result.SyncAvailable = s.cfg.StripeSecretConfigured()
	return result, nil
}
// CreateCheckoutSession creates a subscription-mode Stripe Checkout session
// for user and returns the session URL.
//
// It resolves planCode and currency through priceForPlan. A blank planCode
// selects the default plan, and currencies other than USD resolve to CZK. It
// then looks up or creates the user's Stripe customer via ensureCustomer. The
// session redirects to <FrontendURL>/dashboard with billing=success or
// billing=cancelled. Both the session and the resulting subscription are
// tagged with user_id, plan_code and currency metadata.
//
// It returns ErrPlanNotConfigured when no price matches, ErrStripeNotConfigured
// when no secret key is set, and otherwise any KV-store or Stripe API error.
// The caller's ctx bounds the Stripe request.
func (s *Service) CreateCheckoutSession(ctx context.Context, user UserIdentity, planCode string, currency string) (CheckoutSession, error) {
	priceID, resolvedPlanCode, resolvedCurrency, err := s.priceForPlan(planCode, currency)
	if err != nil {
		return CheckoutSession{}, err
	}
	if s.cfg.StripeSecretKey == "" {
		return CheckoutSession{}, ErrStripeNotConfigured
	}
	customerID, err := s.ensureCustomer(ctx, user)
	if err != nil {
		return CheckoutSession{}, err
	}

	frontendBase := strings.TrimRight(s.cfg.FrontendURL, "/")
	// newMetadata builds a fresh map per consumer so the session and the
	// subscription parameters never alias the same map.
	newMetadata := func() map[string]string {
		return map[string]string{
			"user_id":   user.ID,
			"plan_code": resolvedPlanCode,
			"currency":  resolvedCurrency,
		}
	}

	// NOTE(review): stripe.Key is process-global; assigning it per request
	// races with concurrent requests. Consider setting it once at startup or
	// moving to a per-service Stripe client.
	stripe.Key = s.cfg.StripeSecretKey
	params := &stripe.CheckoutSessionParams{
		Mode:              stripe.String(string(stripe.CheckoutSessionModeSubscription)),
		Customer:          stripe.String(customerID),
		ClientReferenceID: stripe.String(user.ID),
		SuccessURL:        stripe.String(fmt.Sprintf("%s/dashboard?billing=success", frontendBase)),
		CancelURL:         stripe.String(fmt.Sprintf("%s/dashboard?billing=cancelled", frontendBase)),
		LineItems: []*stripe.CheckoutSessionLineItemParams{
			{
				Price:    stripe.String(priceID),
				Quantity: stripe.Int64(1),
			},
		},
		Metadata: newMetadata(),
		SubscriptionData: &stripe.CheckoutSessionSubscriptionDataParams{
			// NOTE(review): every checkout grants a 30-day trial, even for
			// customers who already had (or have) a subscription — confirm.
			TrialPeriodDays: stripe.Int64(30),
			Metadata:        newMetadata(),
		},
	}
	// Forward the caller's deadline/cancellation to the HTTP request; the
	// package-level session API takes its context from Params.
	params.Context = ctx

	checkoutSession, err := session.New(params)
	if err != nil {
		return CheckoutSession{}, err
	}
	return CheckoutSession{URL: checkoutSession.URL}, nil
}
// Refresh re-reads the user's subscription state from Stripe, stores it in KV
// and returns it.
//
// A user without a customer mapping gets noneSnapshot() and no Stripe call is
// made. Otherwise ErrStripeNotConfigured is returned when no secret key is
// set.
func (s *Service) Refresh(ctx context.Context, userID string) (SubscriptionSnapshot, error) {
	mapping, hasMapping, err := s.getCustomerMapping(ctx, userID)
	if err != nil {
		return SubscriptionSnapshot{}, err
	}
	switch {
	case !hasMapping:
		return s.noneSnapshot(), nil
	case s.cfg.StripeSecretKey == "":
		return SubscriptionSnapshot{}, ErrStripeNotConfigured
	default:
		return s.syncStripeDataToKV(ctx, mapping.CustomerID)
	}
}
// HandleWebhook verifies a Stripe webhook delivery. For event types listed in
// allowedWebhookEvents, it re-syncs the referenced customer's snapshot into
// KV.
//
// It returns nil without further work when no Stripe secret key is
// configured, when the event type is not allowed, or when the event object
// has no customer ID. It returns ErrStripeWebhookMissing when no signing
// secret is configured and ErrStripeSignatureMissing when signature is empty.
// Otherwise it returns the verification error from the webhook package or the
// sync error.
func (s *Service) HandleWebhook(ctx context.Context, signature string, payload []byte) error {
	if s.cfg.StripeSecretKey == "" {
		return nil
	}
	if s.cfg.StripeWebhookSecret == "" {
		return ErrStripeWebhookMissing
	}
	if signature == "" {
		return ErrStripeSignatureMissing
	}
	// The signature is verified with the same default timestamp tolerance
	// ConstructEvent uses. The API-version check is skipped because this
	// handler only reads data.object.customer from the raw JSON. With plain
	// ConstructEvent, an endpoint configured with a different API version than
	// the one pinned by stripe-go would have every delivery rejected.
	event, err := webhook.ConstructEventWithOptions(payload, signature, s.cfg.StripeWebhookSecret, webhook.ConstructEventOptions{
		Tolerance:                webhook.DefaultTolerance,
		IgnoreAPIVersionMismatch: true,
	})
	if err != nil {
		return err
	}
	if !slices.Contains(allowedWebhookEvents, string(event.Type)) {
		return nil
	}
	customerID := extractCustomerID(event)
	if customerID == "" {
		return nil
	}
	_, err = s.syncStripeDataToKV(ctx, customerID)
	return err
}
// ensureCustomer returns the Stripe customer ID mapped to user.ID. When no
// mapping exists, it creates a Stripe customer (email, trimmed name if
// non-blank, user_id metadata) and stores the mapping before returning the
// new ID. The caller's ctx bounds the Stripe request.
//
// NOTE(review): two concurrent calls for the same user can both miss the
// mapping and create duplicate Stripe customers; the later mapping write wins.
func (s *Service) ensureCustomer(ctx context.Context, user UserIdentity) (string, error) {
	mapping, ok, err := s.getCustomerMapping(ctx, user.ID)
	if err != nil {
		return "", err
	}
	if ok && mapping.CustomerID != "" {
		return mapping.CustomerID, nil
	}

	// NOTE(review): process-global key write; see CreateCheckoutSession.
	stripe.Key = s.cfg.StripeSecretKey
	params := &stripe.CustomerParams{
		Email: stripe.String(user.Email),
		Metadata: map[string]string{
			"user_id": user.ID,
		},
	}
	// Forward cancellation/deadlines to the Stripe HTTP request.
	params.Context = ctx
	if name := strings.TrimSpace(user.Name); name != "" {
		params.Name = stripe.String(name)
	}

	createdCustomer, err := customer.New(params)
	if err != nil {
		return "", err
	}
	if err := s.storeCustomerMapping(ctx, user.ID, createdCustomer.ID); err != nil {
		return "", err
	}
	return createdCustomer.ID, nil
}
// syncStripeDataToKV lists every subscription (status "all") for customerID
// from Stripe and picks the most relevant one by subscriptionRank. The
// earliest-listed subscription wins ties. The resulting SubscriptionSnapshot
// is stored in KV under customerSnapshotKey(customerID) and also returned.
//
// A customer with no subscriptions is recorded with Status "none". Only the
// first subscription item is inspected for price, plan, currency and billing
// period. The caller's ctx bounds the Stripe requests.
func (s *Service) syncStripeDataToKV(ctx context.Context, customerID string) (SubscriptionSnapshot, error) {
	// NOTE(review): process-global key write; see CreateCheckoutSession.
	stripe.Key = s.cfg.StripeSecretKey
	params := &stripe.SubscriptionListParams{Customer: stripe.String(customerID)}
	params.Context = ctx
	params.Status = stripe.String("all")
	params.AddExpand("data.default_payment_method")
	params.AddExpand("data.items.data.price")

	var selected *stripe.Subscription
	iter := subscription.List(params)
	for iter.Next() {
		current := iter.Subscription()
		if selected == nil || subscriptionRank(current) > subscriptionRank(selected) {
			selected = current
		}
	}
	if err := iter.Err(); err != nil {
		return SubscriptionSnapshot{}, err
	}

	now := time.Now().UTC()
	snapshot := SubscriptionSnapshot{
		CustomerID:    customerID,
		Status:        "none",
		LastSyncedAt:  &now,
		SyncAvailable: s.cfg.StripeSecretConfigured(),
	}
	if selected != nil {
		snapshot.SubscriptionID = selected.ID
		snapshot.Status = string(selected.Status)
		snapshot.CancelAtPeriodEnd = selected.CancelAtPeriodEnd
		if len(selected.Items.Data) > 0 {
			item := selected.Items.Data[0]
			if item.Price != nil {
				snapshot.PriceID = item.Price.ID
				snapshot.PlanCode = s.planCodeForPrice(snapshot.PriceID)
				snapshot.Currency = normalizeCurrency(string(item.Price.Currency))
			}
			snapshot.CurrentPeriodStart = unixPtr(item.CurrentPeriodStart)
			snapshot.CurrentPeriodEnd = unixPtr(item.CurrentPeriodEnd)
		}
		if pm := selected.DefaultPaymentMethod; pm != nil && pm.Card != nil {
			snapshot.PaymentMethod = &PaymentMethod{
				Brand: string(pm.Card.Brand),
				Last4: pm.Card.Last4,
			}
		}
	}
	// Use the same availability rule GetSubscription applies on read, so the
	// snapshot returned by Refresh agrees with the one GetSubscription serves.
	snapshot.CheckoutURLAvailable = s.checkoutAvailableForPlan(snapshot.PlanCode)

	if err := s.db.PutKV(ctx, customerSnapshotKey(customerID), snapshot); err != nil {
		return SubscriptionSnapshot{}, err
	}
	return snapshot, nil
}
// storeCustomerMapping writes the user→Stripe-customer link to KV under
// userCustomerKey(userID), stamped with the current UTC time.
func (s *Service) storeCustomerMapping(ctx context.Context, userID string, customerID string) error {
	record := userCustomerMapping{CustomerID: customerID, UpdatedAt: time.Now().UTC()}
	key := userCustomerKey(userID)
	return s.db.PutKV(ctx, key, record)
}
// getCustomerMapping loads the user→Stripe-customer link from KV. The bool is
// true only when a record exists and its CustomerID is non-empty; otherwise a
// zero mapping is returned. Store errors are passed through.
func (s *Service) getCustomerMapping(ctx context.Context, userID string) (userCustomerMapping, bool, error) {
	var stored userCustomerMapping
	found, err := s.db.GetKV(ctx, userCustomerKey(userID), &stored)
	switch {
	case err != nil:
		return userCustomerMapping{}, false, err
	case found && stored.CustomerID != "":
		return stored, true, nil
	default:
		return userCustomerMapping{}, false, nil
	}
}
// getCustomerSnapshot loads the cached SubscriptionSnapshot for customerID
// from KV, reporting whether the store had an entry. On error it returns a
// zero snapshot and false.
func (s *Service) getCustomerSnapshot(ctx context.Context, customerID string) (SubscriptionSnapshot, bool, error) {
	snapshot := SubscriptionSnapshot{}
	found, err := s.db.GetKV(ctx, customerSnapshotKey(customerID), &snapshot)
	if err == nil {
		return snapshot, found, nil
	}
	return SubscriptionSnapshot{}, false, err
}
// noneSnapshot is the snapshot served to users with no Stripe customer:
// Status "none", with availability flags taken from configuration.
func (s *Service) noneSnapshot() SubscriptionSnapshot {
	var snap SubscriptionSnapshot
	snap.Status = "none"
	snap.CheckoutURLAvailable = s.cfg.StripeCheckoutReady()
	snap.SyncAvailable = s.cfg.StripeSecretConfigured()
	return snap
}
// priceForPlan resolves a Stripe price ID for planCode/currency from
// cfg.StripePriceIDs. It returns (priceID, normalized plan code, resolved
// currency, error).
//
// A blank planCode falls back to defaultPlanCode. Keys are tried in this
// order:
//  1. "<plan>:<currency>"
//  2. "<plan>:czk" (for USD requests; the resolved currency then becomes czk)
//  3. the bare "<plan>"
//  4. the legacy name ("growth" for pro, "multi-location" for business)
//
// ErrPlanNotConfigured is returned when nothing matches.
func (s *Service) priceForPlan(planCode string, currency string) (string, string, string, error) {
	lookup := func(key string) string {
		return strings.TrimSpace(s.cfg.StripePriceIDs[key])
	}

	plan := normalizePlanCode(strings.TrimSpace(planCode))
	if plan == "" {
		plan = s.defaultPlanCode()
		if plan == "" {
			return "", "", "", ErrPlanNotConfigured
		}
	}

	cur := normalizeCurrency(currency)
	priceID := lookup(plan + ":" + cur)
	if priceID == "" && cur != "czk" {
		if czkPrice := lookup(plan + ":czk"); czkPrice != "" {
			priceID, cur = czkPrice, "czk"
		}
	}
	if priceID == "" {
		priceID = lookup(plan)
	}
	if priceID == "" {
		switch plan {
		case "pro":
			priceID = lookup("growth")
		case "business":
			priceID = lookup("multi-location")
		}
	}
	if priceID == "" {
		return "", "", "", ErrPlanNotConfigured
	}
	return priceID, plan, cur, nil
}
// defaultPlanCode picks the first plan, in fixed priority order, that has
// either a bare price ID or a "<normalized plan>:czk" price ID configured. It
// returns the normalized plan code, or "" when none is configured. Legacy
// names (growth, multi-location) normalize to pro and business.
func (s *Service) defaultPlanCode() string {
	candidates := [...]string{"pro", "monthly", "growth", "starter", "business", "multi-location"}
	for _, candidate := range candidates {
		normalized := normalizePlanCode(candidate)
		bare := strings.TrimSpace(s.cfg.StripePriceIDs[candidate])
		czk := strings.TrimSpace(s.cfg.StripePriceIDs[normalized+":czk"])
		if bare != "" || czk != "" {
			return normalized
		}
	}
	return ""
}
// planCodeForPrice maps a Stripe price ID back to its normalized plan code by
// scanning cfg.StripePriceIDs. The plan is the key's part before any ":"
// currency suffix. It returns "" when priceID is blank or unknown.
//
// The blank-ID guard matters because the price map may hold empty or
// whitespace-only values. Without it, an empty priceID would "match" such an
// entry and yield a bogus plan code.
//
// NOTE(review): map iteration order is random; if the same price ID is
// configured under two different plans, the result is nondeterministic.
func (s *Service) planCodeForPrice(priceID string) string {
	priceID = strings.TrimSpace(priceID)
	if priceID == "" {
		return ""
	}
	for key, configured := range s.cfg.StripePriceIDs {
		if strings.TrimSpace(configured) != priceID {
			continue
		}
		plan, _, _ := strings.Cut(key, ":")
		return normalizePlanCode(plan)
	}
	return ""
}
// hasConfiguredPrices reports whether any plan can serve as the default, i.e.
// defaultPlanCode finds at least one configured price.
func (s *Service) hasConfiguredPrices() bool {
	return len(s.defaultPlanCode()) > 0
}
// checkoutAvailableForPlan reports whether a Checkout session could be started
// for planCode. It requires the Stripe secret to be configured. For a blank
// plan, some default plan must have a price. Otherwise priceForPlan must
// resolve planCode for CZK.
func (s *Service) checkoutAvailableForPlan(planCode string) bool {
	switch {
	case !s.cfg.StripeSecretConfigured():
		return false
	case strings.TrimSpace(planCode) == "":
		return s.hasConfiguredPrices()
	}
	_, _, _, err := s.priceForPlan(planCode, "czk")
	return err == nil
}
// normalizePlanCode maps legacy plan names to their current equivalents
// ("growth" → "pro", "multi-location" → "business"). Any other value,
// including differently-cased names, is returned unchanged.
func normalizePlanCode(planCode string) string {
	if planCode == "growth" {
		return "pro"
	}
	if planCode == "multi-location" {
		return "business"
	}
	return planCode
}
// normalizeCurrency returns "usd" when currency is USD (ignoring case and
// surrounding whitespace) and "czk" for every other input, including blank or
// unsupported currencies.
func normalizeCurrency(currency string) string {
	if code := strings.ToLower(strings.TrimSpace(currency)); code == "usd" {
		return code
	}
	return "czk"
}
// userCustomerKey is the KV key holding a user's userCustomerMapping.
func userCustomerKey(userID string) string {
	const prefix = "stripe:user:"
	return prefix + userID
}

// customerSnapshotKey is the KV key holding a Stripe customer's cached
// SubscriptionSnapshot.
func customerSnapshotKey(customerID string) string {
	const prefix = "stripe:customer:"
	return prefix + customerID
}
// unixPtr converts Unix seconds to a UTC time pointer. 0 (Stripe's "unset")
// maps to nil so omitempty drops the field from JSON.
func unixPtr(value int64) *time.Time {
	if value != 0 {
		ts := time.Unix(value, 0).UTC()
		return &ts
	}
	return nil
}
// subscriptionStatusRank orders subscription statuses by relevance when
// picking one subscription per customer. Higher is preferred; unlisted
// statuses rank 0.
var subscriptionStatusRank = map[stripe.SubscriptionStatus]int{
	stripe.SubscriptionStatusActive:     100,
	stripe.SubscriptionStatusTrialing:   90,
	stripe.SubscriptionStatusPastDue:    80,
	stripe.SubscriptionStatusUnpaid:     70,
	stripe.SubscriptionStatusIncomplete: 60,
	stripe.SubscriptionStatusPaused:     50,
	stripe.SubscriptionStatusCanceled:   10,
}

// subscriptionRank scores sub by status so syncStripeDataToKV can pick the
// most relevant subscription: active > trialing > past_due > unpaid >
// incomplete > paused > canceled > any other status (0). sub must be non-nil.
func subscriptionRank(sub *stripe.Subscription) int {
	return subscriptionStatusRank[sub.Status]
}
// extractCustomerID returns the Stripe customer ID referenced by the event's
// data.object.customer field. It returns "" when the payload cannot be
// decoded, has no customer field, or the field is null.
//
// Stripe's customer field is expandable, so both the plain ID form
// ("cus_123") and an expanded object ({"id": "cus_123", ...}) are accepted.
// Only that one field is decoded, rather than unmarshalling the whole object
// into a map.
func extractCustomerID(event stripe.Event) string {
	var object struct {
		Customer json.RawMessage `json:"customer"`
	}
	if err := json.Unmarshal(event.Data.Raw, &object); err != nil || len(object.Customer) == 0 {
		return ""
	}

	// Common case: an unexpanded ID string. JSON null leaves id empty.
	var id string
	if err := json.Unmarshal(object.Customer, &id); err == nil {
		return id
	}

	// Expanded customer object.
	var expanded struct {
		ID string `json:"id"`
	}
	if err := json.Unmarshal(object.Customer, &expanded); err == nil {
		return expanded.ID
	}
	return ""
}