mirror of
https://github.com/Dvorinka/Trackeep.git
synced 2026-06-03 20:12:58 +00:00
275 lines
6.1 KiB
Go
275 lines
6.1 KiB
Go
package handlers
|
|
|
|
import (
	"errors"
	"fmt"
	"net/http"
	"net/url"
	"os"
	"strings"
	"time"

	"golang.org/x/crypto/bcrypt"
	"gorm.io/gorm"

	"github.com/trackeep/backend/config"
	"github.com/trackeep/backend/models"
)
|
|
|
|
// centralizedOAuthUser is the JSON shape of a user record exchanged with the
// centralized OAuth/control service. Field order and JSON tags are part of
// the wire contract and must not change.
type centralizedOAuthUser struct {
	// ID is the identifier carried in the "id" JSON field.
	ID int `json:"id"`
	// GitHubID is the GitHub account identifier; upsertCentralizedOAuthUser
	// treats a zero value as "not a GitHub-backed account".
	GitHubID int `json:"github_id"`
	// Username is the login name supplied by the remote service.
	Username string `json:"username"`
	// Email is the address supplied by the remote service; it may be empty.
	Email string `json:"email"`
	// Name is the display name supplied by the remote service.
	Name string `json:"name"`
	// AvatarURL is the avatar image location supplied by the remote service.
	AvatarURL string `json:"avatar_url"`
}
|
|
|
|
func getOAuthServiceURL() string {
|
|
return config.ControlServiceURL
|
|
}
|
|
|
|
// headerValue returns the first non-blank entry of a comma-separated header
// value, with surrounding whitespace removed. It returns "" when the header
// is missing or contains only blanks and commas.
func headerValue(headers http.Header, key string) string {
	// Splitting an empty value yields a single empty segment, so a missing
	// header falls through to the final return without a separate check.
	for _, segment := range strings.Split(headers.Get(key), ",") {
		if trimmed := strings.TrimSpace(segment); trimmed != "" {
			return trimmed
		}
	}
	return ""
}
|
|
|
|
func backendPublicBaseURL(r *http.Request) string {
|
|
if baseURL := strings.TrimSpace(os.Getenv("PUBLIC_API_URL")); baseURL != "" {
|
|
return strings.TrimRight(baseURL, "/")
|
|
}
|
|
if baseURL := strings.TrimSpace(os.Getenv("PUBLIC_BASE_URL")); baseURL != "" {
|
|
return strings.TrimRight(baseURL, "/")
|
|
}
|
|
|
|
scheme := "http"
|
|
if forwardedProto := headerValue(r.Header, "X-Forwarded-Proto"); forwardedProto != "" {
|
|
scheme = forwardedProto
|
|
} else if r.TLS != nil {
|
|
scheme = "https"
|
|
}
|
|
|
|
host := headerValue(r.Header, "X-Forwarded-Host")
|
|
if host == "" {
|
|
host = strings.TrimSpace(r.Host)
|
|
}
|
|
if host == "" {
|
|
return ""
|
|
}
|
|
|
|
return fmt.Sprintf("%s://%s", scheme, host)
|
|
}
|
|
|
|
// normalizeFrontendRedirectURL validates a frontend redirect target and
// returns its canonical string form.
//
// The input is trimmed and parsed; anything that fails to parse or lacks a
// scheme or host (including the empty string) yields "". When the URL has no
// path, or only "/", the path is set to "/auth/callback". Query and fragment
// are preserved.
func normalizeFrontendRedirectURL(raw string) string {
	target, err := url.Parse(strings.TrimSpace(raw))
	switch {
	case err != nil:
		return ""
	case target.Scheme == "", target.Host == "":
		// Covers blank input as well: an empty string parses to an empty URL.
		return ""
	}

	switch target.Path {
	case "", "/":
		target.Path = "/auth/callback"
	}

	return target.String()
}
|
|
|
|
func resolveFrontendRedirectURL(r *http.Request) string {
|
|
if value := normalizeFrontendRedirectURL(r.URL.Query().Get("frontend_redirect")); value != "" {
|
|
return value
|
|
}
|
|
|
|
if value := normalizeFrontendRedirectURL(os.Getenv("FRONTEND_URL")); value != "" {
|
|
return value
|
|
}
|
|
|
|
if origin := normalizeFrontendRedirectURL(r.Header.Get("Origin")); origin != "" {
|
|
return origin
|
|
}
|
|
|
|
referer := strings.TrimSpace(r.Header.Get("Referer"))
|
|
if referer != "" {
|
|
if parsed, err := url.Parse(referer); err == nil && parsed.Scheme != "" && parsed.Host != "" {
|
|
return normalizeFrontendRedirectURL((&url.URL{
|
|
Scheme: parsed.Scheme,
|
|
Host: parsed.Host,
|
|
Path: "/auth/callback",
|
|
}).String())
|
|
}
|
|
}
|
|
|
|
return ""
|
|
}
|
|
|
|
func buildGitHubUserCallbackURL(r *http.Request) string {
|
|
baseURL := backendPublicBaseURL(r)
|
|
if baseURL == "" {
|
|
return ""
|
|
}
|
|
|
|
callbackURL, err := url.Parse(baseURL + "/api/v1/auth/github/callback")
|
|
if err != nil {
|
|
return ""
|
|
}
|
|
|
|
return callbackURL.String()
|
|
}
|
|
|
|
func buildFrontendCallbackRedirectURL(frontendRedirect, token string) string {
|
|
redirectTarget := normalizeFrontendRedirectURL(frontendRedirect)
|
|
if redirectTarget == "" {
|
|
redirectTarget = normalizeFrontendRedirectURL(os.Getenv("FRONTEND_URL"))
|
|
}
|
|
if redirectTarget == "" {
|
|
return ""
|
|
}
|
|
|
|
parsed, err := url.Parse(redirectTarget)
|
|
if err != nil {
|
|
return ""
|
|
}
|
|
|
|
query := parsed.Query()
|
|
query.Set("token", token)
|
|
parsed.RawQuery = query.Encode()
|
|
|
|
return parsed.String()
|
|
}
|
|
// firstNonEmpty returns the first argument that is not blank, with leading
// and trailing whitespace removed. It returns "" when every argument is blank
// or no arguments are given.
func firstNonEmpty(values ...string) string {
	for i := range values {
		if cleaned := strings.TrimSpace(values[i]); cleaned != "" {
			return cleaned
		}
	}
	return ""
}
|
|
|
|
func uniqueUsername(base string, db *gorm.DB, excludeUserID uint) string {
|
|
candidate := strings.TrimSpace(base)
|
|
if candidate == "" {
|
|
candidate = "user"
|
|
}
|
|
|
|
for suffix := 0; ; suffix++ {
|
|
username := candidate
|
|
if suffix > 0 {
|
|
username = fmt.Sprintf("%s-%d", candidate, suffix+1)
|
|
}
|
|
|
|
var existing models.User
|
|
err := db.Where("username = ?", username).First(&existing).Error
|
|
if err == nil {
|
|
if excludeUserID != 0 && existing.ID == excludeUserID {
|
|
return username
|
|
}
|
|
continue
|
|
}
|
|
if err == gorm.ErrRecordNotFound {
|
|
return username
|
|
}
|
|
return username
|
|
}
|
|
}
|
|
|
|
func upsertCentralizedOAuthUser(db *gorm.DB, controllerUser centralizedOAuthUser) (*models.User, error) {
|
|
var user models.User
|
|
var err error
|
|
|
|
normalizedEmail := strings.TrimSpace(controllerUser.Email)
|
|
normalizedUsername := firstNonEmpty(controllerUser.Username, strings.Split(normalizedEmail, "@")[0], "user")
|
|
fullName := firstNonEmpty(controllerUser.Name, controllerUser.Username, normalizedEmail)
|
|
provider := "email"
|
|
if controllerUser.GitHubID != 0 {
|
|
provider = "github"
|
|
err = db.Where("github_id = ?", controllerUser.GitHubID).First(&user).Error
|
|
} else {
|
|
err = gorm.ErrRecordNotFound
|
|
}
|
|
|
|
if err != nil && normalizedEmail != "" {
|
|
err = db.Where("email = ?", normalizedEmail).First(&user).Error
|
|
}
|
|
|
|
if err == nil {
|
|
updates := map[string]interface{}{
|
|
"email": normalizedEmail,
|
|
"username": uniqueUsername(normalizedUsername, db, user.ID),
|
|
"full_name": fullName,
|
|
"avatar_url": controllerUser.AvatarURL,
|
|
"provider": provider,
|
|
}
|
|
if controllerUser.GitHubID != 0 {
|
|
updates["github_id"] = controllerUser.GitHubID
|
|
}
|
|
|
|
now := time.Now()
|
|
updates["last_login_at"] = &now
|
|
|
|
if err := db.Model(&user).Updates(updates).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
if err := db.First(&user, user.ID).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &user, nil
|
|
}
|
|
|
|
if err != gorm.ErrRecordNotFound {
|
|
return nil, err
|
|
}
|
|
|
|
var userCount int64
|
|
if err := db.Model(&models.User{}).Count(&userCount).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
randomPassword := generateRandomString(32)
|
|
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(randomPassword), bcrypt.DefaultCost)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
role := "user"
|
|
if userCount == 0 {
|
|
role = "admin"
|
|
}
|
|
|
|
now := time.Now()
|
|
user = models.User{
|
|
Email: normalizedEmail,
|
|
Username: uniqueUsername(normalizedUsername, db, 0),
|
|
Password: string(hashedPassword),
|
|
FullName: fullName,
|
|
Role: role,
|
|
Theme: "dark",
|
|
GitHubID: controllerUser.GitHubID,
|
|
AvatarURL: controllerUser.AvatarURL,
|
|
Provider: provider,
|
|
LastLoginAt: &now,
|
|
}
|
|
|
|
if err := db.Create(&user).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
_ = ensureMessagingDefaults(db, user.ID)
|
|
|
|
return &user, nil
|
|
}
|