Files
Trackeep/backend/handlers/oauth_service.go
T
2026-04-10 12:06:01 +02:00

275 lines
6.1 KiB
Go

package handlers
import (
	"errors"
	"fmt"
	"net/http"
	"net/url"
	"os"
	"strings"
	"time"

	"golang.org/x/crypto/bcrypt"
	"gorm.io/gorm"

	"github.com/trackeep/backend/config"
	"github.com/trackeep/backend/models"
)
// centralizedOAuthUser is the user profile consumed by upsertCentralizedOAuthUser.
// The JSON tags suggest it is decoded from the centralized OAuth/control
// service's response (NOTE(review): the decoding site is not visible here — confirm).
type centralizedOAuthUser struct {
	// ID is the service's "id" field; upsertCentralizedOAuthUser does not read it.
	ID int `json:"id"`
	// GitHubID, when non-zero, marks a GitHub identity: it is the first lookup
	// key for the local user and makes the local provider "github".
	GitHubID int `json:"github_id"`
	// Username is the preferred base for the local username and a fallback
	// for the full name.
	Username string `json:"username"`
	// Email is whitespace-trimmed and used as the fallback lookup key; its
	// local part seeds the username when Username is blank.
	Email string `json:"email"`
	// Name is the preferred full name.
	Name string `json:"name"`
	// AvatarURL is stored as the local user's avatar URL.
	AvatarURL string `json:"avatar_url"`
}
// getOAuthServiceURL reports the base URL of the centralized OAuth/control
// service, as configured in config.ControlServiceURL.
func getOAuthServiceURL() string {
	serviceURL := config.ControlServiceURL
	return serviceURL
}
// headerValue returns the first non-blank, whitespace-trimmed element of the
// comma-separated header key (e.g. the client-most hop of X-Forwarded-Host).
// Only the first header line is consulted, per http.Header.Get. It returns ""
// when the header is absent or every element is blank.
func headerValue(headers http.Header, key string) string {
	remaining := headers.Get(key)
	for remaining != "" {
		var element string
		element, remaining, _ = strings.Cut(remaining, ",")
		if trimmed := strings.TrimSpace(element); trimmed != "" {
			return trimmed
		}
	}
	return ""
}
// backendPublicBaseURL returns the externally reachable base URL of this
// backend as "scheme://host" without a trailing slash, or "" when no host can
// be determined.
//
// Resolution order:
//  1. the PUBLIC_API_URL environment variable, then PUBLIC_BASE_URL
//     (whitespace-trimmed and stripped of trailing slashes, otherwise verbatim);
//  2. otherwise the incoming request: the scheme is the first X-Forwarded-Proto
//     value when it is "http" or "https" (case-insensitive), else "https" for
//     TLS connections and "http" otherwise; the host is the first
//     X-Forwarded-Host value, falling back to r.Host.
//
// NOTE(review): X-Forwarded-* headers are only trustworthy when a reverse
// proxy overwrites them; otherwise the client chooses the host used to build
// OAuth callback URLs. Setting PUBLIC_API_URL avoids relying on them.
func backendPublicBaseURL(r *http.Request) string {
	for _, envKey := range []string{"PUBLIC_API_URL", "PUBLIC_BASE_URL"} {
		if baseURL := strings.TrimSpace(os.Getenv(envKey)); baseURL != "" {
			return strings.TrimRight(baseURL, "/")
		}
	}

	scheme := "http"
	if r.TLS != nil {
		scheme = "https"
	}
	// The forwarded value is pasted into the URL, so only accept the two schemes
	// a public API base URL can have. Anything else (wrong case, typos, injected
	// values) is ignored instead of producing "HTTPS://" or "javascript://".
	switch forwardedProto := strings.ToLower(headerValue(r.Header, "X-Forwarded-Proto")); forwardedProto {
	case "http", "https":
		scheme = forwardedProto
	}

	host := headerValue(r.Header, "X-Forwarded-Host")
	if host == "" {
		host = strings.TrimSpace(r.Host)
	}
	if host == "" {
		return ""
	}
	return fmt.Sprintf("%s://%s", scheme, host)
}
// normalizeFrontendRedirectURL validates raw as an absolute URL (it must have
// both a scheme and a host after whitespace trimming) and returns its string
// form. A missing or root path is replaced with "/auth/callback"; any other
// path, query, or fragment is kept. Blank, relative, or unparseable input
// yields "".
func normalizeFrontendRedirectURL(raw string) string {
	trimmed := strings.TrimSpace(raw)
	if trimmed == "" {
		return ""
	}
	target, parseErr := url.Parse(trimmed)
	// Case expressions are evaluated left to right, so target is never
	// dereferenced when parsing failed.
	switch {
	case parseErr != nil, target.Scheme == "", target.Host == "":
		return ""
	}
	switch target.Path {
	case "", "/":
		target.Path = "/auth/callback"
	}
	return target.String()
}
// resolveFrontendRedirectURL picks the frontend URL the OAuth flow should
// return to. Candidates are tried in order, and the first one that
// normalizeFrontendRedirectURL accepts wins:
//  1. the "frontend_redirect" query parameter;
//  2. the FRONTEND_URL environment variable;
//  3. the Origin request header;
//  4. the scheme and host of the Referer header, with the path set to
//     "/auth/callback".
//
// It returns "" when no candidate yields a usable absolute URL.
//
// NOTE(review): candidates 1, 3 and 4 come from the request and are not
// checked against any allow-list. If the result is later handed to
// buildFrontendCallbackRedirectURL (which appends the login token), a crafted
// frontend_redirect could deliver a victim's token to a foreign origin.
// Consider restricting it to known frontend origins.
func resolveFrontendRedirectURL(r *http.Request) string {
	candidates := [...]string{
		r.URL.Query().Get("frontend_redirect"),
		os.Getenv("FRONTEND_URL"),
		r.Header.Get("Origin"),
	}
	for _, candidate := range candidates {
		if target := normalizeFrontendRedirectURL(candidate); target != "" {
			return target
		}
	}

	// Last resort: the page the user came from, keeping only its origin.
	ref, err := url.Parse(strings.TrimSpace(r.Header.Get("Referer")))
	if err != nil || ref.Scheme == "" || ref.Host == "" {
		return ""
	}
	callback := url.URL{Scheme: ref.Scheme, Host: ref.Host, Path: "/auth/callback"}
	return normalizeFrontendRedirectURL(callback.String())
}
// buildGitHubUserCallbackURL returns the absolute URL of this backend's GitHub
// OAuth callback endpoint: "/api/v1/auth/github/callback" appended to
// backendPublicBaseURL(r). It returns "" when the base URL is unknown or the
// combined URL does not parse.
func buildGitHubUserCallbackURL(r *http.Request) string {
	const callbackPath = "/api/v1/auth/github/callback"
	if base := backendPublicBaseURL(r); base != "" {
		if callback, err := url.Parse(base + callbackPath); err == nil {
			return callback.String()
		}
	}
	return ""
}
// buildFrontendCallbackRedirectURL returns the URL the browser is sent to after
// login: frontendRedirect (or, when that is not a usable absolute URL, the
// FRONTEND_URL environment variable), normalized by
// normalizeFrontendRedirectURL, with a "token" query parameter set to token.
// Existing query parameters are kept; url.Values.Encode sorts them by key.
// It returns "" when neither source yields a usable URL.
//
// NOTE(review): the token travels in the query string, where it can end up in
// browser history, access logs and Referer headers. A URL fragment would avoid
// that but needs a matching frontend change.
func buildFrontendCallbackRedirectURL(frontendRedirect, token string) string {
	target := normalizeFrontendRedirectURL(frontendRedirect)
	if target == "" {
		// No usable explicit target: fall back to the configured frontend.
		target = normalizeFrontendRedirectURL(os.Getenv("FRONTEND_URL"))
		if target == "" {
			return ""
		}
	}

	redirect, err := url.Parse(target)
	if err != nil {
		return ""
	}
	params := redirect.Query()
	params.Set("token", token)
	redirect.RawQuery = params.Encode()
	return redirect.String()
}
// firstNonEmpty returns the first argument that is not blank after
// strings.TrimSpace, in its trimmed form, or "" when every argument is blank
// (or none are given).
func firstNonEmpty(values ...string) string {
	for i := range values {
		if candidate := strings.TrimSpace(values[i]); len(candidate) > 0 {
			return candidate
		}
	}
	return ""
}
// uniqueUsername returns a username derived from base that no other user
// holds. The trimmed base (or "user" when blank) is tried first, then
// "<base>-2", "<base>-3", … until a free name is found. A name owned by the
// user whose ID equals excludeUserID (when non-zero) counts as free, so an
// existing user can keep their own username.
//
// Each candidate costs one query. Any lookup error, not only
// gorm.ErrRecordNotFound, is treated as "name is free" and returns that
// candidate.
// NOTE(review): genuine database errors are therefore not reported here; they
// would only surface later, e.g. as a unique-constraint failure if the schema
// has one — confirm.
func uniqueUsername(base string, db *gorm.DB, excludeUserID uint) string {
	root := strings.TrimSpace(base)
	if root == "" {
		root = "user"
	}

	name := root
	for next := 2; ; next++ {
		var holder models.User
		lookupErr := db.Where("username = ?", name).First(&holder).Error
		ownedBySelf := excludeUserID != 0 && holder.ID == excludeUserID
		if lookupErr != nil || ownedBySelf {
			return name
		}
		name = fmt.Sprintf("%s-%d", root, next)
	}
}
// upsertCentralizedOAuthUser finds or creates the local user matching a profile
// returned by the centralized OAuth service and returns the persisted row.
//
// Lookup order: by github_id when controllerUser.GitHubID is non-zero, then by
// (trimmed) email when one is present.
//
//   - A matched user gets a refreshed username (kept unique via uniqueUsername),
//     provider, last_login_at, github_id (when known), and email / full_name /
//     avatar_url. The last three are updated only when the incoming value is
//     non-empty, so an omitted field does not blank stored data. The row is then
//     re-read.
//   - Otherwise a new user is created with a random, never-disclosed,
//     bcrypt-hashed password, theme "dark", and role "admin" if no users exist
//     yet ("user" otherwise).
//
// The provider is "github" when GitHubID is non-zero and "email" otherwise.
// Any database or hashing error is returned with a nil user.
func upsertCentralizedOAuthUser(db *gorm.DB, controllerUser centralizedOAuthUser) (*models.User, error) {
	var user models.User
	normalizedEmail := strings.TrimSpace(controllerUser.Email)
	normalizedUsername := firstNonEmpty(controllerUser.Username, strings.Split(normalizedEmail, "@")[0], "user")
	fullName := firstNonEmpty(controllerUser.Name, controllerUser.Username, normalizedEmail)

	provider := "email"
	err := gorm.ErrRecordNotFound
	if controllerUser.GitHubID != 0 {
		provider = "github"
		err = db.Where("github_id = ?", controllerUser.GitHubID).First(&user).Error
	}
	// Fall back to an email match only when the GitHub lookup found nothing.
	// Any other error is a real database failure. Masking it would let the code
	// below create a duplicate account for an existing GitHub user.
	// NOTE(review): linking a GitHub identity to an existing account by email
	// trusts that the OAuth service verified the address — confirm.
	if errors.Is(err, gorm.ErrRecordNotFound) && normalizedEmail != "" {
		err = db.Where("email = ?", normalizedEmail).First(&user).Error
	}

	if err == nil {
		now := time.Now()
		updates := map[string]interface{}{
			"username":      uniqueUsername(normalizedUsername, db, user.ID),
			"provider":      provider,
			"last_login_at": &now,
		}
		// Only overwrite profile fields with values the service actually sent.
		if normalizedEmail != "" {
			updates["email"] = normalizedEmail
		}
		if fullName != "" {
			updates["full_name"] = fullName
		}
		if controllerUser.AvatarURL != "" {
			updates["avatar_url"] = controllerUser.AvatarURL
		}
		if controllerUser.GitHubID != 0 {
			updates["github_id"] = controllerUser.GitHubID
		}
		if err := db.Model(&user).Updates(updates).Error; err != nil {
			return nil, err
		}
		// Re-read so the returned struct reflects every stored column.
		if err := db.First(&user, user.ID).Error; err != nil {
			return nil, err
		}
		return &user, nil
	}
	// gorm may wrap ErrRecordNotFound, so compare with errors.Is.
	if !errors.Is(err, gorm.ErrRecordNotFound) {
		return nil, err
	}

	var userCount int64
	if err := db.Model(&models.User{}).Count(&userCount).Error; err != nil {
		return nil, err
	}
	// The plaintext password is discarded right after hashing. It only fills
	// the password column for an account that signs in through OAuth.
	randomPassword := generateRandomString(32)
	hashedPassword, err := bcrypt.GenerateFromPassword([]byte(randomPassword), bcrypt.DefaultCost)
	if err != nil {
		return nil, err
	}
	// NOTE(review): the count above and the create below are not atomic, so two
	// simultaneous first sign-ins could both be given the admin role.
	role := "user"
	if userCount == 0 {
		role = "admin"
	}
	now := time.Now()
	user = models.User{
		Email:       normalizedEmail,
		Username:    uniqueUsername(normalizedUsername, db, 0),
		Password:    string(hashedPassword),
		FullName:    fullName,
		Role:        role,
		Theme:       "dark",
		GitHubID:    controllerUser.GitHubID,
		AvatarURL:   controllerUser.AvatarURL,
		Provider:    provider,
		LastLoginAt: &now,
	}
	if err := db.Create(&user).Error; err != nil {
		return nil, err
	}
	// Deliberately best-effort: the user row already exists at this point, so a
	// failure to seed messaging defaults does not fail the sign-in.
	_ = ensureMessagingDefaults(db, user.ID)
	return &user, nil
}