package billing import ( "context" "encoding/json" "errors" "fmt" "slices" "strings" "time" "bookra/apps/auth-service/internal/config" "bookra/apps/auth-service/internal/db" "github.com/stripe/stripe-go/v83" "github.com/stripe/stripe-go/v83/checkout/session" "github.com/stripe/stripe-go/v83/customer" "github.com/stripe/stripe-go/v83/subscription" "github.com/stripe/stripe-go/v83/webhook" ) var ( ErrStripeNotConfigured = errors.New("stripe is not configured") ErrStripeWebhookMissing = errors.New("stripe webhook secret is not configured") ErrStripeSignatureMissing = errors.New("stripe signature is missing") ErrPlanNotConfigured = errors.New("stripe plan is not configured") ErrCustomerMappingNotFound = errors.New("stripe customer mapping not found") ) var allowedWebhookEvents = []string{ "checkout.session.completed", "customer.subscription.created", "customer.subscription.updated", "customer.subscription.deleted", "customer.subscription.paused", "customer.subscription.resumed", "invoice.paid", "invoice.payment_failed", "payment_intent.succeeded", "payment_intent.payment_failed", } type Service struct { cfg *config.Config db *db.DB } type CheckoutSession struct { URL string `json:"url"` } type SubscriptionSnapshot struct { CustomerID string `json:"customerId,omitempty"` SubscriptionID string `json:"subscriptionId,omitempty"` Status string `json:"status"` PlanCode string `json:"planCode,omitempty"` Currency string `json:"currency,omitempty"` PriceID string `json:"priceId,omitempty"` CancelAtPeriodEnd bool `json:"cancelAtPeriodEnd"` CurrentPeriodStart *time.Time `json:"currentPeriodStart,omitempty"` CurrentPeriodEnd *time.Time `json:"currentPeriodEnd,omitempty"` PaymentMethod *PaymentMethod `json:"paymentMethod,omitempty"` LastSyncedAt *time.Time `json:"lastSyncedAt,omitempty"` CheckoutURLAvailable bool `json:"checkoutUrlAvailable"` SyncAvailable bool `json:"syncAvailable"` } type PaymentMethod struct { Brand string `json:"brand"` Last4 string `json:"last4"` } type UserIdentity struct { ID string Email string Name string } type userCustomerMapping struct { CustomerID string `json:"customerId"` UpdatedAt time.Time `json:"updatedAt"` } func NewService(cfg *config.Config, database *db.DB) *Service { return &Service{cfg: cfg, db: database} } func (s *Service) GetSubscription(ctx context.Context, userID string) (SubscriptionSnapshot, error) { mapping, ok, err := s.getCustomerMapping(ctx, userID) if err != nil { return SubscriptionSnapshot{}, err } if !ok { return s.noneSnapshot(), nil } snapshot, ok, err := s.getCustomerSnapshot(ctx, mapping.CustomerID) if err != nil { return SubscriptionSnapshot{}, err } if !ok { snapshot = SubscriptionSnapshot{ CustomerID: mapping.CustomerID, Status: "none", } } snapshot.CheckoutURLAvailable = s.checkoutAvailableForPlan(snapshot.PlanCode) snapshot.SyncAvailable = s.cfg.StripeSecretConfigured() return snapshot, nil } func (s *Service) CreateCheckoutSession(ctx context.Context, user UserIdentity, planCode string, currency string) (CheckoutSession, error) { priceID, resolvedPlanCode, resolvedCurrency, err := s.priceForPlan(planCode, currency) if err != nil { return CheckoutSession{}, err } if s.cfg.StripeSecretKey == "" { return CheckoutSession{}, ErrStripeNotConfigured } customerID, err := s.ensureCustomer(ctx, user) if err != nil { return CheckoutSession{}, err } stripe.Key = s.cfg.StripeSecretKey params := &stripe.CheckoutSessionParams{ Mode: stripe.String(string(stripe.CheckoutSessionModeSubscription)), Customer: stripe.String(customerID), ClientReferenceID: stripe.String(user.ID), SuccessURL: stripe.String(fmt.Sprintf("%s/dashboard?billing=success", strings.TrimRight(s.cfg.FrontendURL, "/"))), CancelURL: stripe.String(fmt.Sprintf("%s/dashboard?billing=cancelled", strings.TrimRight(s.cfg.FrontendURL, "/"))), LineItems: []*stripe.CheckoutSessionLineItemParams{ { Price: stripe.String(priceID), Quantity: stripe.Int64(1), }, }, Metadata: map[string]string{ "user_id": user.ID, "plan_code": resolvedPlanCode, "currency": resolvedCurrency, }, SubscriptionData: &stripe.CheckoutSessionSubscriptionDataParams{ TrialPeriodDays: stripe.Int64(30), Metadata: map[string]string{ "user_id": user.ID, "plan_code": resolvedPlanCode, "currency": resolvedCurrency, }, }, } checkoutSession, err := session.New(params) if err != nil { return CheckoutSession{}, err } return CheckoutSession{URL: checkoutSession.URL}, nil } func (s *Service) Refresh(ctx context.Context, userID string) (SubscriptionSnapshot, error) { mapping, ok, err := s.getCustomerMapping(ctx, userID) if err != nil { return SubscriptionSnapshot{}, err } if !ok { return s.noneSnapshot(), nil } if s.cfg.StripeSecretKey == "" { return SubscriptionSnapshot{}, ErrStripeNotConfigured } return s.syncStripeDataToKV(ctx, mapping.CustomerID) } func (s *Service) HandleWebhook(ctx context.Context, signature string, payload []byte) error { if s.cfg.StripeSecretKey == "" { return nil } if s.cfg.StripeWebhookSecret == "" { return ErrStripeWebhookMissing } if signature == "" { return ErrStripeSignatureMissing } event, err := webhook.ConstructEvent(payload, signature, s.cfg.StripeWebhookSecret) if err != nil { return err } if !slices.Contains(allowedWebhookEvents, string(event.Type)) { return nil } customerID := extractCustomerID(event) if customerID == "" { return nil } _, err = s.syncStripeDataToKV(ctx, customerID) return err } func (s *Service) ensureCustomer(ctx context.Context, user UserIdentity) (string, error) { mapping, ok, err := s.getCustomerMapping(ctx, user.ID) if err != nil { return "", err } if ok && mapping.CustomerID != "" { return mapping.CustomerID, nil } stripe.Key = s.cfg.StripeSecretKey params := &stripe.CustomerParams{ Email: stripe.String(user.Email), Metadata: map[string]string{ "user_id": user.ID, }, } if strings.TrimSpace(user.Name) != "" { params.Name = stripe.String(strings.TrimSpace(user.Name)) } createdCustomer, err := customer.New(params) if err != nil { return "", err } if err := s.storeCustomerMapping(ctx, user.ID, createdCustomer.ID); err != nil { return "", err } return createdCustomer.ID, nil } func (s *Service) syncStripeDataToKV(ctx context.Context, customerID string) (SubscriptionSnapshot, error) { stripe.Key = s.cfg.StripeSecretKey params := &stripe.SubscriptionListParams{Customer: stripe.String(customerID)} params.Status = stripe.String("all") params.AddExpand("data.default_payment_method") params.AddExpand("data.items.data.price") iter := subscription.List(params) selected := (*stripe.Subscription)(nil) for iter.Next() { current := iter.Subscription() if selected == nil || subscriptionRank(current) > subscriptionRank(selected) { selected = current } } if iter.Err() != nil { return SubscriptionSnapshot{}, iter.Err() } now := time.Now().UTC() snapshot := SubscriptionSnapshot{ CustomerID: customerID, Status: "none", LastSyncedAt: &now, CheckoutURLAvailable: s.cfg.StripeCheckoutReady(), SyncAvailable: s.cfg.StripeSecretConfigured(), } if selected != nil { snapshot.SubscriptionID = selected.ID snapshot.Status = string(selected.Status) snapshot.CancelAtPeriodEnd = selected.CancelAtPeriodEnd if len(selected.Items.Data) > 0 { item := selected.Items.Data[0] if item.Price != nil { snapshot.PriceID = item.Price.ID snapshot.PlanCode = s.planCodeForPrice(snapshot.PriceID) snapshot.Currency = normalizeCurrency(string(item.Price.Currency)) } snapshot.CurrentPeriodStart = unixPtr(item.CurrentPeriodStart) snapshot.CurrentPeriodEnd = unixPtr(item.CurrentPeriodEnd) } if selected.DefaultPaymentMethod != nil && selected.DefaultPaymentMethod.Card != nil { snapshot.PaymentMethod = &PaymentMethod{ Brand: string(selected.DefaultPaymentMethod.Card.Brand), Last4: selected.DefaultPaymentMethod.Card.Last4, } } } if err := s.db.PutKV(ctx, customerSnapshotKey(customerID), snapshot); err != nil { return SubscriptionSnapshot{}, err } return snapshot, nil } func (s *Service) storeCustomerMapping(ctx context.Context, userID string, customerID string) error { mapping := userCustomerMapping{ CustomerID: customerID, UpdatedAt: time.Now().UTC(), } return s.db.PutKV(ctx, userCustomerKey(userID), mapping) } func (s *Service) getCustomerMapping(ctx context.Context, userID string) (userCustomerMapping, bool, error) { var mapping userCustomerMapping ok, err := s.db.GetKV(ctx, userCustomerKey(userID), &mapping) if err != nil { return userCustomerMapping{}, false, err } if !ok || mapping.CustomerID == "" { return userCustomerMapping{}, false, nil } return mapping, true, nil } func (s *Service) getCustomerSnapshot(ctx context.Context, customerID string) (SubscriptionSnapshot, bool, error) { var snapshot SubscriptionSnapshot ok, err := s.db.GetKV(ctx, customerSnapshotKey(customerID), &snapshot) if err != nil { return SubscriptionSnapshot{}, false, err } return snapshot, ok, nil } func (s *Service) noneSnapshot() SubscriptionSnapshot { return SubscriptionSnapshot{ Status: "none", CheckoutURLAvailable: s.cfg.StripeCheckoutReady(), SyncAvailable: s.cfg.StripeSecretConfigured(), } } func (s *Service) priceForPlan(planCode string, currency string) (string, string, string, error) { planCode = normalizePlanCode(strings.TrimSpace(planCode)) if planCode == "" { planCode = s.defaultPlanCode() } if planCode == "" { return "", "", "", ErrPlanNotConfigured } resolvedCurrency := normalizeCurrency(currency) priceID := strings.TrimSpace(s.cfg.StripePriceIDs[planCode+":"+resolvedCurrency]) if priceID == "" && resolvedCurrency != "czk" { priceID = strings.TrimSpace(s.cfg.StripePriceIDs[planCode+":czk"]) if priceID != "" { resolvedCurrency = "czk" } } if priceID == "" { priceID = strings.TrimSpace(s.cfg.StripePriceIDs[planCode]) } if priceID == "" { switch planCode { case "pro": priceID = strings.TrimSpace(s.cfg.StripePriceIDs["growth"]) case "business": priceID = strings.TrimSpace(s.cfg.StripePriceIDs["multi-location"]) } } if priceID == "" { return "", "", "", ErrPlanNotConfigured } return priceID, planCode, resolvedCurrency, nil } func (s *Service) defaultPlanCode() string { for _, planCode := range []string{"pro", "monthly", "growth", "starter", "business", "multi-location"} { if strings.TrimSpace(s.cfg.StripePriceIDs[planCode]) != "" { return normalizePlanCode(planCode) } if strings.TrimSpace(s.cfg.StripePriceIDs[normalizePlanCode(planCode)+":czk"]) != "" { return normalizePlanCode(planCode) } } return "" } func (s *Service) planCodeForPrice(priceID string) string { for planCode, configuredPriceID := range s.cfg.StripePriceIDs { if strings.TrimSpace(configuredPriceID) == priceID { return normalizePlanCode(strings.Split(planCode, ":")[0]) } } return "" } func (s *Service) hasConfiguredPrices() bool { return s.defaultPlanCode() != "" } func (s *Service) checkoutAvailableForPlan(planCode string) bool { if !s.cfg.StripeSecretConfigured() { return false } if strings.TrimSpace(planCode) == "" { return s.hasConfiguredPrices() } _, _, _, err := s.priceForPlan(planCode, "czk") return err == nil } func normalizePlanCode(planCode string) string { switch planCode { case "growth": return "pro" case "multi-location": return "business" default: return planCode } } func normalizeCurrency(currency string) string { switch strings.ToLower(strings.TrimSpace(currency)) { case "usd": return "usd" default: return "czk" } } func userCustomerKey(userID string) string { return "stripe:user:" + userID } func customerSnapshotKey(customerID string) string { return "stripe:customer:" + customerID } func unixPtr(value int64) *time.Time { if value == 0 { return nil } t := time.Unix(value, 0).UTC() return &t } func subscriptionRank(subscription *stripe.Subscription) int { switch subscription.Status { case stripe.SubscriptionStatusActive: return 100 case stripe.SubscriptionStatusTrialing: return 90 case stripe.SubscriptionStatusPastDue: return 80 case stripe.SubscriptionStatusUnpaid: return 70 case stripe.SubscriptionStatusIncomplete: return 60 case stripe.SubscriptionStatusPaused: return 50 case stripe.SubscriptionStatusCanceled: return 10 default: return 0 } } func extractCustomerID(event stripe.Event) string { var payload map[string]any if err := json.Unmarshal(event.Data.Raw, &payload); err != nil { return "" } value, ok := payload["customer"] if !ok { return "" } customerID, _ := value.(string) return customerID }