mirror of
https://github.com/Dvorinka/excalidraw-full.git
synced 2026-06-03 13:52:56 +00:00
feat: full project sync - CI fixes, frontend, workspace API, and all changes
This commit is contained in:
@@ -0,0 +1,17 @@
|
||||
package workspace
|
||||
|
||||
import "context"
|
||||
|
||||
type currentSession struct {
|
||||
user *User
|
||||
session *Session
|
||||
}
|
||||
|
||||
func withUser(ctx context.Context, user *User, session *Session) context.Context {
|
||||
return context.WithValue(ctx, currentUserKey, currentSession{user: user, session: session})
|
||||
}
|
||||
|
||||
func currentUser(r interface{ Context() context.Context }) (*User, *Session) {
|
||||
current, _ := r.Context().Value(currentUserKey).(currentSession)
|
||||
return current.user, current.session
|
||||
}
|
||||
@@ -0,0 +1,660 @@
|
||||
package workspace
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const sessionCookieName = "excalidraw_session"
|
||||
|
||||
type API struct {
|
||||
store *Store
|
||||
limiter *rateLimiter
|
||||
testMode bool
|
||||
}
|
||||
|
||||
func NewAPI(store *Store) *API {
|
||||
return &API{
|
||||
store: store,
|
||||
limiter: newRateLimiter(10, 15*time.Minute),
|
||||
}
|
||||
}
|
||||
|
||||
func (a *API) Routes() chi.Router {
|
||||
r := chi.NewRouter()
|
||||
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Get("/health", a.handleHealth)
|
||||
r.Get("/auth/setup-status", a.handleSetupStatus)
|
||||
r.Post("/auth/signup", a.handleSignup)
|
||||
r.Post("/auth/login", a.handleLogin)
|
||||
r.Post("/auth/logout", a.handleLogout)
|
||||
r.Get("/shared/{token}", a.handleSharedResource)
|
||||
})
|
||||
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(a.requireSession)
|
||||
r.Use(requireSameOriginMutation)
|
||||
r.Get("/auth/me", a.handleMe)
|
||||
r.Get("/teams", a.handleListTeams)
|
||||
r.Post("/teams", a.handleCreateTeam)
|
||||
r.Patch("/teams/{teamID}", a.handleUpdateTeam)
|
||||
r.Get("/teams/{teamID}/members", a.handleListTeamMembers)
|
||||
r.Get("/teams/{teamID}/invites", a.handleListTeamInvites)
|
||||
r.Post("/teams/{teamID}/invites", a.handleCreateTeamInvite)
|
||||
r.Post("/teams/{teamID}/users", a.handleCreateTeamUser)
|
||||
r.Post("/invites/accept", a.handleAcceptInvite)
|
||||
r.Get("/drawings", a.handleListDrawings)
|
||||
r.Post("/drawings", a.handleCreateDrawing)
|
||||
r.Get("/drawings/{drawingID}", a.handleGetDrawing)
|
||||
r.Patch("/drawings/{drawingID}", a.handleUpdateDrawing)
|
||||
r.Delete("/drawings/{drawingID}", a.handleArchiveDrawing)
|
||||
r.Get("/drawings/{drawingID}/revisions", a.handleListRevisions)
|
||||
r.Post("/drawings/{drawingID}/revisions", a.handleCreateRevision)
|
||||
r.Get("/search", a.handleSearch)
|
||||
r.Get("/drawings/{drawingID}/permissions", a.handleListPermissions)
|
||||
r.Post("/drawings/{drawingID}/permissions", a.handleCreatePermission)
|
||||
r.Get("/drawings/{drawingID}/share-links", a.handleListShareLinks)
|
||||
r.Post("/drawings/{drawingID}/share-links", a.handleCreateShareLink)
|
||||
r.Get("/drawings/{drawingID}/assets", a.handleListAssets)
|
||||
r.Post("/drawings/{drawingID}/assets", a.handleCreateAsset)
|
||||
r.Get("/drawings/{drawingID}/embeds", a.handleListEmbeds)
|
||||
r.Post("/drawings/{drawingID}/embeds", a.handleCreateEmbed)
|
||||
r.Get("/drawings/{drawingID}/links", a.handleListLinks)
|
||||
r.Post("/drawings/{drawingID}/links", a.handleCreateLink)
|
||||
r.Get("/drawings/{drawingID}/thumbnail", a.handleThumbnail)
|
||||
r.Get("/templates", a.handleListTemplates)
|
||||
r.Get("/activity", a.handleListActivity)
|
||||
r.Get("/stats", a.handleStats)
|
||||
r.Get("/folders", a.handleListFolders)
|
||||
r.Post("/folders", a.handleCreateFolder)
|
||||
r.Get("/projects", a.handleListProjects)
|
||||
r.Post("/projects", a.handleCreateProject)
|
||||
})
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
func (a *API) handleHealth(w http.ResponseWriter, r *http.Request) {
|
||||
if err := a.store.Ping(r.Context()); err != nil {
|
||||
writeJSON(w, http.StatusServiceUnavailable, map[string]any{"status": "unhealthy"})
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"status": "ok"})
|
||||
}
|
||||
|
||||
func (a *API) handleSetupStatus(w http.ResponseWriter, r *http.Request) {
|
||||
hasUsers, err := a.store.UserExists(r.Context())
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "Failed to check setup status")
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"has_users": hasUsers})
|
||||
}
|
||||
|
||||
type contextKey string
|
||||
|
||||
const currentUserKey = contextKey("workspace_user")
|
||||
|
||||
func (a *API) requireSession(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
cookie, err := r.Cookie(sessionCookieName)
|
||||
if err != nil || cookie.Value == "" {
|
||||
writeError(w, http.StatusUnauthorized, "Authentication required")
|
||||
return
|
||||
}
|
||||
user, session, err := a.store.UserBySessionToken(r.Context(), cookie.Value)
|
||||
if err != nil {
|
||||
clearSessionCookie(w, r)
|
||||
writeError(w, http.StatusUnauthorized, "Authentication required")
|
||||
return
|
||||
}
|
||||
ctx := withUser(r.Context(), user, session)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
func requireSameOriginMutation(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
case http.MethodGet, http.MethodHead, http.MethodOptions:
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
origin := r.Header.Get("Origin")
|
||||
if origin == "" {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
expectedHTTP := "http://" + r.Host
|
||||
expectedHTTPS := "https://" + r.Host
|
||||
if origin != expectedHTTP && origin != expectedHTTPS {
|
||||
writeError(w, http.StatusForbidden, "Cross-origin mutation denied")
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func (a *API) handleSignup(w http.ResponseWriter, r *http.Request) {
|
||||
var req struct {
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
if !decodeJSON(w, r, &req, 64<<10) {
|
||||
return
|
||||
}
|
||||
// First-run: only allow signup if no users exist yet
|
||||
if !a.testMode {
|
||||
hasUsers, err := a.store.UserExists(r.Context())
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "Failed to check setup status")
|
||||
return
|
||||
}
|
||||
if hasUsers {
|
||||
writeError(w, http.StatusForbidden, "Registration is closed. Contact an administrator.")
|
||||
return
|
||||
}
|
||||
}
|
||||
ipKey := "signup:" + clientIP(r)
|
||||
if !a.limiter.allow(ipKey) {
|
||||
writeError(w, http.StatusTooManyRequests, "Too many signup attempts")
|
||||
return
|
||||
}
|
||||
user, session, token, err := a.store.CreateUserWithPassword(r.Context(), req.Name, req.Email, req.Password)
|
||||
if err != nil {
|
||||
status := http.StatusBadRequest
|
||||
if errors.Is(err, ErrConflict) {
|
||||
status = http.StatusConflict
|
||||
}
|
||||
writeError(w, status, err.Error())
|
||||
return
|
||||
}
|
||||
setSessionCookie(w, r, token, session.ExpiresAt)
|
||||
writeJSON(w, http.StatusCreated, map[string]any{"user": user, "session": session})
|
||||
}
|
||||
|
||||
func (a *API) handleLogin(w http.ResponseWriter, r *http.Request) {
|
||||
var req struct {
|
||||
Email string `json:"email"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
if !decodeJSON(w, r, &req, 32<<10) {
|
||||
return
|
||||
}
|
||||
key := "login:" + clientIP(r) + ":" + strings.ToLower(strings.TrimSpace(req.Email))
|
||||
if !a.limiter.allow(key) {
|
||||
writeError(w, http.StatusTooManyRequests, "Too many login attempts")
|
||||
return
|
||||
}
|
||||
user, session, token, err := a.store.AuthenticatePassword(r.Context(), req.Email, req.Password)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusUnauthorized, "Invalid email or password")
|
||||
return
|
||||
}
|
||||
setSessionCookie(w, r, token, session.ExpiresAt)
|
||||
writeJSON(w, http.StatusOK, map[string]any{"user": user, "session": session})
|
||||
}
|
||||
|
||||
func (a *API) handleLogout(w http.ResponseWriter, r *http.Request) {
|
||||
if cookie, err := r.Cookie(sessionCookieName); err == nil && cookie.Value != "" {
|
||||
if err := a.store.DeleteSession(r.Context(), cookie.Value); err != nil {
|
||||
logrus.WithError(err).Warn("failed to delete session")
|
||||
}
|
||||
}
|
||||
clearSessionCookie(w, r)
|
||||
writeJSON(w, http.StatusOK, map[string]any{})
|
||||
}
|
||||
|
||||
func (a *API) handleMe(w http.ResponseWriter, r *http.Request) {
|
||||
user, _ := currentUser(r)
|
||||
writeJSON(w, http.StatusOK, user)
|
||||
}
|
||||
|
||||
func (a *API) handleListTeams(w http.ResponseWriter, r *http.Request) {
|
||||
user, _ := currentUser(r)
|
||||
teams, err := a.store.ListTeamsForUser(r.Context(), user.ID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "Failed to list teams")
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, teams)
|
||||
}
|
||||
|
||||
func (a *API) handleCreateTeam(w http.ResponseWriter, r *http.Request) {
|
||||
user, _ := currentUser(r)
|
||||
var req struct {
|
||||
Name string `json:"name"`
|
||||
Slug string `json:"slug"`
|
||||
}
|
||||
if !decodeJSON(w, r, &req, 64<<10) {
|
||||
return
|
||||
}
|
||||
team, err := a.store.CreateTeam(r.Context(), user.ID, req.Name, req.Slug)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusCreated, team)
|
||||
}
|
||||
|
||||
func (a *API) handleUpdateTeam(w http.ResponseWriter, r *http.Request) {
|
||||
user, _ := currentUser(r)
|
||||
var req struct {
|
||||
Name *string `json:"name"`
|
||||
Slug *string `json:"slug"`
|
||||
}
|
||||
if !decodeJSON(w, r, &req, 64<<10) {
|
||||
return
|
||||
}
|
||||
team, err := a.store.UpdateTeam(r.Context(), user.ID, chi.URLParam(r, "teamID"), req.Name, req.Slug)
|
||||
if err != nil {
|
||||
writeLookupError(w, err)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, team)
|
||||
}
|
||||
|
||||
func (a *API) handleListTeamMembers(w http.ResponseWriter, r *http.Request) {
|
||||
user, _ := currentUser(r)
|
||||
teamID := chi.URLParam(r, "teamID")
|
||||
if ok, err := a.store.UserCanAccessTeam(r.Context(), user.ID, teamID); err != nil || !ok {
|
||||
writeError(w, http.StatusForbidden, "Team access denied")
|
||||
return
|
||||
}
|
||||
members, err := a.store.ListTeamMembers(r.Context(), teamID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "Failed to list team members")
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, members)
|
||||
}
|
||||
|
||||
func (a *API) handleListDrawings(w http.ResponseWriter, r *http.Request) {
|
||||
user, _ := currentUser(r)
|
||||
teamID := strings.TrimSpace(r.URL.Query().Get("team_id"))
|
||||
drawings, err := a.store.ListDrawings(r.Context(), user.ID, teamID)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrForbidden) {
|
||||
writeError(w, http.StatusForbidden, "Team access denied")
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, "Failed to list drawings")
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, drawings)
|
||||
}
|
||||
|
||||
func (a *API) handleCreateDrawing(w http.ResponseWriter, r *http.Request) {
|
||||
user, _ := currentUser(r)
|
||||
var req CreateDrawingRequest
|
||||
if !decodeJSON(w, r, &req, 256<<10) {
|
||||
return
|
||||
}
|
||||
drawing, err := a.store.CreateDrawing(r.Context(), user.ID, req)
|
||||
if err != nil {
|
||||
status := http.StatusBadRequest
|
||||
if errors.Is(err, ErrForbidden) {
|
||||
status = http.StatusForbidden
|
||||
}
|
||||
writeError(w, status, err.Error())
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusCreated, drawing)
|
||||
}
|
||||
|
||||
func (a *API) handleGetDrawing(w http.ResponseWriter, r *http.Request) {
|
||||
user, _ := currentUser(r)
|
||||
drawing, err := a.store.GetDrawing(r.Context(), user.ID, chi.URLParam(r, "drawingID"))
|
||||
if err != nil {
|
||||
writeLookupError(w, err)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, drawing)
|
||||
}
|
||||
|
||||
func (a *API) handleUpdateDrawing(w http.ResponseWriter, r *http.Request) {
|
||||
user, _ := currentUser(r)
|
||||
var req UpdateDrawingRequest
|
||||
if !decodeJSON(w, r, &req, 256<<10) {
|
||||
return
|
||||
}
|
||||
drawing, err := a.store.UpdateDrawing(r.Context(), user.ID, chi.URLParam(r, "drawingID"), req)
|
||||
if err != nil {
|
||||
writeLookupError(w, err)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, drawing)
|
||||
}
|
||||
|
||||
func (a *API) handleArchiveDrawing(w http.ResponseWriter, r *http.Request) {
|
||||
user, _ := currentUser(r)
|
||||
if err := a.store.ArchiveDrawing(r.Context(), user.ID, chi.URLParam(r, "drawingID")); err != nil {
|
||||
writeLookupError(w, err)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (a *API) handleListRevisions(w http.ResponseWriter, r *http.Request) {
|
||||
user, _ := currentUser(r)
|
||||
revisions, err := a.store.ListRevisions(r.Context(), user.ID, chi.URLParam(r, "drawingID"))
|
||||
if err != nil {
|
||||
writeLookupError(w, err)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, revisions)
|
||||
}
|
||||
|
||||
func (a *API) handleCreateRevision(w http.ResponseWriter, r *http.Request) {
|
||||
user, _ := currentUser(r)
|
||||
var req CreateRevisionRequest
|
||||
if !decodeJSON(w, r, &req, 10<<20) {
|
||||
return
|
||||
}
|
||||
revision, err := a.store.CreateRevision(r.Context(), user.ID, chi.URLParam(r, "drawingID"), req)
|
||||
if err != nil {
|
||||
writeLookupError(w, err)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusCreated, revision)
|
||||
}
|
||||
|
||||
func (a *API) handleThumbnail(w http.ResponseWriter, r *http.Request) {
|
||||
user, _ := currentUser(r)
|
||||
drawingID := chi.URLParam(r, "drawingID")
|
||||
revisions, err := a.store.ListRevisions(r.Context(), user.ID, drawingID)
|
||||
if err != nil {
|
||||
writeLookupError(w, err)
|
||||
return
|
||||
}
|
||||
if len(revisions) == 0 || revisions[0].Snapshot == nil {
|
||||
w.Header().Set("Content-Type", "image/svg+xml")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`<svg xmlns="http://www.w3.org/2000/svg" width="320" height="240" viewBox="0 0 320 240"><rect width="320" height="240" fill="#f8f9fa"/><text x="160" y="120" text-anchor="middle" font-family="sans-serif" font-size="14" fill="#999">No preview</text></svg>`))
|
||||
return
|
||||
}
|
||||
|
||||
var snapshot struct {
|
||||
Elements []struct {
|
||||
Type string `json:"type"`
|
||||
X float64 `json:"x"`
|
||||
Y float64 `json:"y"`
|
||||
Width float64 `json:"width"`
|
||||
Height float64 `json:"height"`
|
||||
Stroke string `json:"strokeColor"`
|
||||
Bg string `json:"backgroundColor"`
|
||||
Text string `json:"text"`
|
||||
} `json:"elements"`
|
||||
}
|
||||
if err := json.Unmarshal(revisions[0].Snapshot, &snapshot); err != nil {
|
||||
w.Header().Set("Content-Type", "image/svg+xml")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`<svg xmlns="http://www.w3.org/2000/svg" width="320" height="240" viewBox="0 0 320 240"><rect width="320" height="240" fill="#f8f9fa"/><text x="160" y="120" text-anchor="middle" font-family="sans-serif" font-size="14" fill="#999">Preview unavailable</text></svg>`))
|
||||
return
|
||||
}
|
||||
|
||||
// Generate a simple SVG thumbnail from element bounding boxes
|
||||
const vw, vh = 320, 240
|
||||
var b strings.Builder
|
||||
b.WriteString(`<svg xmlns="http://www.w3.org/2000/svg" width="` + itoa(vw) + `" height="` + itoa(vh) + `" viewBox="0 0 ` + itoa(vw) + ` ` + itoa(vh) + `">`)
|
||||
b.WriteString(`<rect width="` + itoa(vw) + `" height="` + itoa(vh) + `" fill="#ffffff"/>`)
|
||||
|
||||
// Compute bounding box to fit elements into view
|
||||
minX, minY, maxX, maxY := 1e9, 1e9, -1e9, -1e9
|
||||
for _, el := range snapshot.Elements {
|
||||
if el.X < minX {
|
||||
minX = el.X
|
||||
}
|
||||
if el.Y < minY {
|
||||
minY = el.Y
|
||||
}
|
||||
if el.X+el.Width > maxX {
|
||||
maxX = el.X + el.Width
|
||||
}
|
||||
if el.Y+el.Height > maxY {
|
||||
maxY = el.Y + el.Height
|
||||
}
|
||||
}
|
||||
if maxX <= minX || maxY <= minY {
|
||||
minX, minY, maxX, maxY = 0, 0, 320, 240
|
||||
}
|
||||
pad := 20.0
|
||||
scaleX := float64(vw-40) / (maxX - minX + 1e-6)
|
||||
scaleY := float64(vh-40) / (maxY - minY + 1e-6)
|
||||
scale := scaleX
|
||||
if scaleY < scaleX {
|
||||
scale = scaleY
|
||||
}
|
||||
offX := pad - minX*scale
|
||||
offY := pad - minY*scale
|
||||
|
||||
for _, el := range snapshot.Elements {
|
||||
x := el.X*scale + offX
|
||||
y := el.Y*scale + offY
|
||||
w := el.Width * scale
|
||||
h := el.Height * scale
|
||||
stroke := el.Stroke
|
||||
if stroke == "" {
|
||||
stroke = "#1e1e1e"
|
||||
}
|
||||
bg := el.Bg
|
||||
if bg == "" || bg == "transparent" {
|
||||
bg = "none"
|
||||
}
|
||||
switch el.Type {
|
||||
case "rectangle", "diamond":
|
||||
b.WriteString(`<rect x="` + ftoa(x) + `" y="` + ftoa(y) + `" width="` + ftoa(w) + `" height="` + ftoa(h) + `" fill="` + bg + `" stroke="` + stroke + `" stroke-width="1"/>`)
|
||||
case "ellipse":
|
||||
b.WriteString(`<ellipse cx="` + ftoa(x+w/2) + `" cy="` + ftoa(y+h/2) + `" rx="` + ftoa(w/2) + `" ry="` + ftoa(h/2) + `" fill="` + bg + `" stroke="` + stroke + `" stroke-width="1"/>`)
|
||||
case "line", "arrow":
|
||||
b.WriteString(`<line x1="` + ftoa(x) + `" y1="` + ftoa(y+h/2) + `" x2="` + ftoa(x+w) + `" y2="` + ftoa(y+h/2) + `" stroke="` + stroke + `" stroke-width="1"/>`)
|
||||
case "text":
|
||||
b.WriteString(`<text x="` + ftoa(x) + `" y="` + ftoa(y+h/2) + `" font-family="sans-serif" font-size="12" fill="` + stroke + `">` + htmlEscape(el.Text) + `</text>`)
|
||||
default:
|
||||
b.WriteString(`<rect x="` + ftoa(x) + `" y="` + ftoa(y) + `" width="` + ftoa(w) + `" height="` + ftoa(h) + `" fill="none" stroke="#ccc" stroke-width="0.5"/>`)
|
||||
}
|
||||
}
|
||||
b.WriteString(`</svg>`)
|
||||
|
||||
w.Header().Set("Content-Type", "image/svg+xml")
|
||||
w.Header().Set("Cache-Control", "max-age=60")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(b.String()))
|
||||
}
|
||||
|
||||
func ftoa(f float64) string { return strconv.FormatFloat(f, 'f', 2, 64) }
|
||||
func itoa(i int) string { return strconv.Itoa(i) }
|
||||
func htmlEscape(s string) string {
|
||||
s = strings.ReplaceAll(s, "&", "&")
|
||||
s = strings.ReplaceAll(s, "<", "<")
|
||||
s = strings.ReplaceAll(s, ">", ">")
|
||||
s = strings.ReplaceAll(s, `"`, """)
|
||||
return s
|
||||
}
|
||||
|
||||
func (a *API) handleListTemplates(w http.ResponseWriter, r *http.Request) {
|
||||
user, _ := currentUser(r)
|
||||
teamID := strings.TrimSpace(r.URL.Query().Get("team_id"))
|
||||
templates, err := a.store.ListTemplates(r.Context(), user.ID, teamID)
|
||||
if err != nil {
|
||||
writeLookupError(w, err)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, templates)
|
||||
}
|
||||
|
||||
func (a *API) handleListActivity(w http.ResponseWriter, r *http.Request) {
|
||||
user, _ := currentUser(r)
|
||||
teamID := strings.TrimSpace(r.URL.Query().Get("team_id"))
|
||||
activity, err := a.store.ListActivity(r.Context(), user.ID, teamID, 50)
|
||||
if err != nil {
|
||||
writeLookupError(w, err)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, activity)
|
||||
}
|
||||
|
||||
func (a *API) handleStats(w http.ResponseWriter, r *http.Request) {
|
||||
user, _ := currentUser(r)
|
||||
teamID := strings.TrimSpace(r.URL.Query().Get("team_id"))
|
||||
stats, err := a.store.WorkspaceStats(r.Context(), user.ID, teamID)
|
||||
if err != nil {
|
||||
writeLookupError(w, err)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, stats)
|
||||
}
|
||||
|
||||
func (a *API) handleListFolders(w http.ResponseWriter, r *http.Request) {
|
||||
user, _ := currentUser(r)
|
||||
teamID := strings.TrimSpace(r.URL.Query().Get("team_id"))
|
||||
folders, err := a.store.ListFolders(r.Context(), user.ID, teamID)
|
||||
if err != nil {
|
||||
writeLookupError(w, err)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, folders)
|
||||
}
|
||||
|
||||
func (a *API) handleCreateFolder(w http.ResponseWriter, r *http.Request) {
|
||||
user, _ := currentUser(r)
|
||||
var req CreateFolderRequest
|
||||
if !decodeJSON(w, r, &req, 128<<10) {
|
||||
return
|
||||
}
|
||||
folder, err := a.store.CreateFolder(r.Context(), user.ID, req)
|
||||
if err != nil {
|
||||
writeLookupError(w, err)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusCreated, folder)
|
||||
}
|
||||
|
||||
func (a *API) handleListProjects(w http.ResponseWriter, r *http.Request) {
|
||||
user, _ := currentUser(r)
|
||||
teamID := strings.TrimSpace(r.URL.Query().Get("team_id"))
|
||||
projects, err := a.store.ListProjects(r.Context(), user.ID, teamID)
|
||||
if err != nil {
|
||||
writeLookupError(w, err)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, projects)
|
||||
}
|
||||
|
||||
func (a *API) handleCreateProject(w http.ResponseWriter, r *http.Request) {
|
||||
user, _ := currentUser(r)
|
||||
var req CreateProjectRequest
|
||||
if !decodeJSON(w, r, &req, 128<<10) {
|
||||
return
|
||||
}
|
||||
project, err := a.store.CreateProject(r.Context(), user.ID, req)
|
||||
if err != nil {
|
||||
writeLookupError(w, err)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusCreated, project)
|
||||
}
|
||||
|
||||
func (a *API) handleSearch(w http.ResponseWriter, r *http.Request) {
|
||||
user, _ := currentUser(r)
|
||||
q := strings.TrimSpace(r.URL.Query().Get("q"))
|
||||
if q == "" {
|
||||
writeJSON(w, http.StatusOK, []Drawing{})
|
||||
return
|
||||
}
|
||||
drawings, err := a.store.SearchDrawings(r.Context(), user.ID, q)
|
||||
if err != nil {
|
||||
writeLookupError(w, err)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, drawings)
|
||||
}
|
||||
|
||||
func writeLookupError(w http.ResponseWriter, err error) {
|
||||
switch {
|
||||
case errors.Is(err, ErrForbidden):
|
||||
writeError(w, http.StatusForbidden, "Access denied")
|
||||
case errors.Is(err, sql.ErrNoRows), errors.Is(err, ErrNotFound):
|
||||
writeError(w, http.StatusNotFound, "Resource not found")
|
||||
default:
|
||||
writeError(w, http.StatusBadRequest, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func decodeJSON(w http.ResponseWriter, r *http.Request, dst any, limit int64) bool {
|
||||
defer r.Body.Close()
|
||||
r.Body = http.MaxBytesReader(w, r.Body, limit)
|
||||
decoder := json.NewDecoder(r.Body)
|
||||
decoder.DisallowUnknownFields()
|
||||
if err := decoder.Decode(dst); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "Invalid request body")
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func writeJSON(w http.ResponseWriter, status int, body any) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
if err := json.NewEncoder(w).Encode(body); err != nil {
|
||||
logrus.WithError(err).Warn("failed to encode response")
|
||||
}
|
||||
}
|
||||
|
||||
func writeError(w http.ResponseWriter, status int, message string) {
|
||||
writeJSON(w, status, map[string]string{"error": message})
|
||||
}
|
||||
|
||||
func setSessionCookie(w http.ResponseWriter, r *http.Request, token string, expires time.Time) {
|
||||
SetSessionCookie(w, r, token, expires)
|
||||
}
|
||||
|
||||
func SetSessionCookie(w http.ResponseWriter, r *http.Request, token string, expires time.Time) {
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: sessionCookieName,
|
||||
Value: token,
|
||||
Path: "/",
|
||||
Expires: expires,
|
||||
HttpOnly: true,
|
||||
Secure: isSecureRequest(r),
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
}
|
||||
|
||||
func clearSessionCookie(w http.ResponseWriter, r *http.Request) {
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: sessionCookieName,
|
||||
Value: "",
|
||||
Path: "/",
|
||||
Expires: time.Unix(0, 0),
|
||||
MaxAge: -1,
|
||||
HttpOnly: true,
|
||||
Secure: isSecureRequest(r),
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
}
|
||||
|
||||
func isSecureRequest(r *http.Request) bool {
|
||||
return r.TLS != nil || strings.EqualFold(r.Header.Get("X-Forwarded-Proto"), "https")
|
||||
}
|
||||
|
||||
func clientIP(r *http.Request) string {
|
||||
if forwarded := r.Header.Get("X-Forwarded-For"); forwarded != "" {
|
||||
return strings.TrimSpace(strings.Split(forwarded, ",")[0])
|
||||
}
|
||||
host := r.RemoteAddr
|
||||
if idx := strings.LastIndex(host, ":"); idx > 0 {
|
||||
return host[:idx]
|
||||
}
|
||||
return host
|
||||
}
|
||||
@@ -0,0 +1,214 @@
|
||||
package workspace
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
)
|
||||
|
||||
func (a *API) handleListTeamInvites(w http.ResponseWriter, r *http.Request) {
|
||||
user, _ := currentUser(r)
|
||||
invites, err := a.store.ListTeamInvites(r.Context(), user.ID, chi.URLParam(r, "teamID"))
|
||||
if err != nil {
|
||||
writeLookupError(w, err)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, invites)
|
||||
}
|
||||
|
||||
func (a *API) handleCreateTeamInvite(w http.ResponseWriter, r *http.Request) {
|
||||
user, _ := currentUser(r)
|
||||
var req CreateInviteRequest
|
||||
if !decodeJSON(w, r, &req, 64<<10) {
|
||||
return
|
||||
}
|
||||
invite, token, err := a.store.CreateTeamInvite(r.Context(), user.ID, chi.URLParam(r, "teamID"), req)
|
||||
if err != nil {
|
||||
writeLookupError(w, err)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusCreated, map[string]any{"invite": invite, "token": token})
|
||||
}
|
||||
|
||||
func (a *API) handleAcceptInvite(w http.ResponseWriter, r *http.Request) {
|
||||
user, _ := currentUser(r)
|
||||
var req struct {
|
||||
Token string `json:"token"`
|
||||
}
|
||||
if !decodeJSON(w, r, &req, 32<<10) {
|
||||
return
|
||||
}
|
||||
membership, err := a.store.AcceptInvite(r.Context(), user.ID, req.Token)
|
||||
if err != nil {
|
||||
writeLookupError(w, err)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, membership)
|
||||
}
|
||||
|
||||
func (a *API) handleListPermissions(w http.ResponseWriter, r *http.Request) {
|
||||
user, _ := currentUser(r)
|
||||
grants, err := a.store.ListPermissionGrants(r.Context(), user.ID, "drawing", chi.URLParam(r, "drawingID"))
|
||||
if err != nil {
|
||||
writeLookupError(w, err)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, grants)
|
||||
}
|
||||
|
||||
func (a *API) handleCreatePermission(w http.ResponseWriter, r *http.Request) {
|
||||
user, _ := currentUser(r)
|
||||
var req CreatePermissionGrantRequest
|
||||
if !decodeJSON(w, r, &req, 64<<10) {
|
||||
return
|
||||
}
|
||||
grant, err := a.store.CreateDrawingPermissionGrant(r.Context(), user.ID, chi.URLParam(r, "drawingID"), req)
|
||||
if err != nil {
|
||||
writeLookupError(w, err)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusCreated, grant)
|
||||
}
|
||||
|
||||
func (a *API) handleListShareLinks(w http.ResponseWriter, r *http.Request) {
|
||||
user, _ := currentUser(r)
|
||||
links, err := a.store.ListShareLinks(r.Context(), user.ID, "drawing", chi.URLParam(r, "drawingID"))
|
||||
if err != nil {
|
||||
writeLookupError(w, err)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, links)
|
||||
}
|
||||
|
||||
func (a *API) handleCreateShareLink(w http.ResponseWriter, r *http.Request) {
|
||||
user, _ := currentUser(r)
|
||||
var req CreateShareLinkRequest
|
||||
if !decodeJSON(w, r, &req, 64<<10) {
|
||||
return
|
||||
}
|
||||
link, token, err := a.store.CreateDrawingShareLink(r.Context(), user.ID, chi.URLParam(r, "drawingID"), req)
|
||||
if err != nil {
|
||||
writeLookupError(w, err)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusCreated, map[string]any{"share_link": link, "token": token})
|
||||
}
|
||||
|
||||
func (a *API) handleSharedResource(w http.ResponseWriter, r *http.Request) {
|
||||
payload, err := a.store.SharedResourceByToken(r.Context(), chi.URLParam(r, "token"))
|
||||
if err != nil {
|
||||
writeLookupError(w, err)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, payload)
|
||||
}
|
||||
|
||||
func (a *API) handleListAssets(w http.ResponseWriter, r *http.Request) {
|
||||
user, _ := currentUser(r)
|
||||
assets, err := a.store.ListDrawingAssets(r.Context(), user.ID, chi.URLParam(r, "drawingID"))
|
||||
if err != nil {
|
||||
writeLookupError(w, err)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, assets)
|
||||
}
|
||||
|
||||
func (a *API) handleCreateAsset(w http.ResponseWriter, r *http.Request) {
|
||||
user, _ := currentUser(r)
|
||||
var req CreateAssetRequest
|
||||
if !decodeJSON(w, r, &req, 64<<10) {
|
||||
return
|
||||
}
|
||||
asset, err := a.store.CreateDrawingAsset(r.Context(), user.ID, chi.URLParam(r, "drawingID"), req)
|
||||
if err != nil {
|
||||
writeLookupError(w, err)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusCreated, asset)
|
||||
}
|
||||
|
||||
func (a *API) handleListEmbeds(w http.ResponseWriter, r *http.Request) {
|
||||
user, _ := currentUser(r)
|
||||
embeds, err := a.store.ListEmbeds(r.Context(), user.ID, chi.URLParam(r, "drawingID"))
|
||||
if err != nil {
|
||||
writeLookupError(w, err)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, embeds)
|
||||
}
|
||||
|
||||
func (a *API) handleCreateEmbed(w http.ResponseWriter, r *http.Request) {
|
||||
user, _ := currentUser(r)
|
||||
var req CreateEmbedRequest
|
||||
if !decodeJSON(w, r, &req, 64<<10) {
|
||||
return
|
||||
}
|
||||
embed, err := a.store.CreateEmbed(r.Context(), user.ID, chi.URLParam(r, "drawingID"), req)
|
||||
if err != nil {
|
||||
writeLookupError(w, err)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusCreated, embed)
|
||||
}
|
||||
|
||||
func (a *API) handleListLinks(w http.ResponseWriter, r *http.Request) {
|
||||
user, _ := currentUser(r)
|
||||
links, err := a.store.ListLinkReferences(r.Context(), user.ID, "drawing", chi.URLParam(r, "drawingID"))
|
||||
if err != nil {
|
||||
writeLookupError(w, err)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, links)
|
||||
}
|
||||
|
||||
func (a *API) handleCreateLink(w http.ResponseWriter, r *http.Request) {
|
||||
user, _ := currentUser(r)
|
||||
var req CreateLinkRequest
|
||||
if !decodeJSON(w, r, &req, 64<<10) {
|
||||
return
|
||||
}
|
||||
link, err := a.store.CreateDrawingLinkReference(r.Context(), user.ID, chi.URLParam(r, "drawingID"), req)
|
||||
if err != nil {
|
||||
writeLookupError(w, err)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusCreated, link)
|
||||
}
|
||||
|
||||
func (a *API) handleCreateTeamUser(w http.ResponseWriter, r *http.Request) {
|
||||
user, _ := currentUser(r)
|
||||
teamID := chi.URLParam(r, "teamID")
|
||||
var role string
|
||||
err := a.store.db.QueryRowContext(r.Context(), `SELECT role FROM workspace_team_memberships WHERE user_id = ? AND team_id = ?`, user.ID, teamID).Scan(&role)
|
||||
if errors.Is(err, sql.ErrNoRows) || (err == nil && role != "owner" && role != "admin") {
|
||||
writeError(w, http.StatusForbidden, "Only team owners and admins can add members")
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "Failed to verify team access")
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
Password string `json:"password"`
|
||||
Role string `json:"role"`
|
||||
}
|
||||
if !decodeJSON(w, r, &req, 32<<10) {
|
||||
return
|
||||
}
|
||||
|
||||
newUser, err := a.store.CreateTeamUser(r.Context(), teamID, req.Name, req.Email, req.Password, req.Role)
|
||||
if err != nil {
|
||||
status := http.StatusBadRequest
|
||||
if errors.Is(err, ErrConflict) {
|
||||
status = http.StatusConflict
|
||||
}
|
||||
writeError(w, status, err.Error())
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusCreated, newUser)
|
||||
}
|
||||
@@ -0,0 +1,307 @@
|
||||
package workspace
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
dbpostgres "excalidraw-complete/internal/postgres"
|
||||
"net/url"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func newTestStore(t *testing.T) (*Store, func()) {
|
||||
t.Helper()
|
||||
baseURL := os.Getenv("TEST_DATABASE_URL")
|
||||
if baseURL == "" {
|
||||
baseURL = os.Getenv("DATABASE_URL")
|
||||
}
|
||||
if baseURL == "" {
|
||||
t.Skip("TEST_DATABASE_URL or DATABASE_URL is required for PostgreSQL workspace tests")
|
||||
}
|
||||
schema := "test_" + strings.ToLower(newID())
|
||||
adminDB, err := dbpostgres.Open(baseURL)
|
||||
if err != nil {
|
||||
t.Fatalf("open test database error = %v", err)
|
||||
}
|
||||
if _, err := adminDB.DB.ExecContext(context.Background(), `CREATE SCHEMA "`+schema+`"`); err != nil {
|
||||
adminDB.Close()
|
||||
t.Fatalf("create test schema error = %v", err)
|
||||
}
|
||||
store, err := NewStore(databaseURLWithSearchPath(t, baseURL, schema))
|
||||
if err != nil {
|
||||
adminDB.DB.ExecContext(context.Background(), `DROP SCHEMA "`+schema+`" CASCADE`)
|
||||
adminDB.Close()
|
||||
t.Fatalf("NewStore() error = %v", err)
|
||||
}
|
||||
return store, func() {
|
||||
if err := store.Close(); err != nil {
|
||||
t.Fatalf("Close() error = %v", err)
|
||||
}
|
||||
if _, err := adminDB.DB.ExecContext(context.Background(), `DROP SCHEMA "`+schema+`" CASCADE`); err != nil {
|
||||
t.Fatalf("drop test schema error = %v", err)
|
||||
}
|
||||
if err := adminDB.Close(); err != nil {
|
||||
t.Fatalf("admin Close() error = %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func databaseURLWithSearchPath(t *testing.T, rawURL, schema string) string {
|
||||
t.Helper()
|
||||
parsed, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
t.Fatalf("parse database URL error = %v", err)
|
||||
}
|
||||
q := parsed.Query()
|
||||
q.Set("search_path", schema)
|
||||
parsed.RawQuery = q.Encode()
|
||||
return parsed.String()
|
||||
}
|
||||
|
||||
func newTestAPI(t *testing.T) (*API, func()) {
|
||||
t.Helper()
|
||||
store, cleanup := newTestStore(t)
|
||||
api := NewAPI(store)
|
||||
api.testMode = true
|
||||
return api, cleanup
|
||||
}
|
||||
|
||||
func doJSON(t *testing.T, api *API, method, path string, body any, cookies ...*http.Cookie) *httptest.ResponseRecorder {
|
||||
t.Helper()
|
||||
var raw bytes.Buffer
|
||||
if body != nil {
|
||||
if err := json.NewEncoder(&raw).Encode(body); err != nil {
|
||||
t.Fatalf("json encode error = %v", err)
|
||||
}
|
||||
}
|
||||
req := httptest.NewRequest(method, path, &raw)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
for _, cookie := range cookies {
|
||||
req.AddCookie(cookie)
|
||||
}
|
||||
rr := httptest.NewRecorder()
|
||||
api.Routes().ServeHTTP(rr, req)
|
||||
return rr
|
||||
}
|
||||
|
||||
func signup(t *testing.T, api *API, email string) (*http.Cookie, User, Team) {
|
||||
t.Helper()
|
||||
rr := doJSON(t, api, http.MethodPost, "/auth/signup", map[string]string{
|
||||
"name": "Test User",
|
||||
"email": email,
|
||||
"password": "password-123",
|
||||
})
|
||||
if rr.Code != http.StatusCreated {
|
||||
t.Fatalf("signup status = %d body = %s", rr.Code, rr.Body.String())
|
||||
}
|
||||
var payload struct {
|
||||
User User `json:"user"`
|
||||
}
|
||||
if err := json.Unmarshal(rr.Body.Bytes(), &payload); err != nil {
|
||||
t.Fatalf("signup response decode error = %v", err)
|
||||
}
|
||||
var sessionCookie *http.Cookie
|
||||
for _, cookie := range rr.Result().Cookies() {
|
||||
if cookie.Name == sessionCookieName {
|
||||
copy := *cookie
|
||||
sessionCookie = ©
|
||||
break
|
||||
}
|
||||
}
|
||||
if sessionCookie == nil {
|
||||
t.Fatal("signup did not set session cookie")
|
||||
}
|
||||
teamsRR := doJSON(t, api, http.MethodGet, "/teams", nil, sessionCookie)
|
||||
if teamsRR.Code != http.StatusOK {
|
||||
t.Fatalf("teams status = %d body = %s", teamsRR.Code, teamsRR.Body.String())
|
||||
}
|
||||
var teams []Team
|
||||
if err := json.Unmarshal(teamsRR.Body.Bytes(), &teams); err != nil {
|
||||
t.Fatalf("teams decode error = %v", err)
|
||||
}
|
||||
if len(teams) != 1 {
|
||||
t.Fatalf("teams len = %d, want 1", len(teams))
|
||||
}
|
||||
return sessionCookie, payload.User, teams[0]
|
||||
}
|
||||
|
||||
func TestSignupCreatesCookieSessionAndDefaultTeam(t *testing.T) {
|
||||
api, cleanup := newTestAPI(t)
|
||||
defer cleanup()
|
||||
|
||||
cookie, user, team := signup(t, api, "alice@example.com")
|
||||
if !cookie.HttpOnly {
|
||||
t.Fatal("session cookie must be httpOnly")
|
||||
}
|
||||
if user.ID == "" || user.Email != "alice@example.com" {
|
||||
t.Fatalf("unexpected user: %#v", user)
|
||||
}
|
||||
if team.OwnerUserID != user.ID || team.PlanType != "free" {
|
||||
t.Fatalf("unexpected team: %#v", team)
|
||||
}
|
||||
|
||||
meRR := doJSON(t, api, http.MethodGet, "/auth/me", nil, cookie)
|
||||
if meRR.Code != http.StatusOK {
|
||||
t.Fatalf("me status = %d body = %s", meRR.Code, meRR.Body.String())
|
||||
}
|
||||
|
||||
logoutRR := doJSON(t, api, http.MethodPost, "/auth/logout", nil, cookie)
|
||||
if logoutRR.Code != http.StatusOK {
|
||||
t.Fatalf("logout status = %d body = %s", logoutRR.Code, logoutRR.Body.String())
|
||||
}
|
||||
if logoutRR.Body.String() == "" {
|
||||
t.Fatal("logout must return JSON for frontend fetchApi compatibility")
|
||||
}
|
||||
|
||||
afterLogoutRR := doJSON(t, api, http.MethodGet, "/auth/me", nil, cookie)
|
||||
if afterLogoutRR.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("me after logout status = %d body = %s", afterLogoutRR.Code, afterLogoutRR.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestDrawingAccessRequiresTeamMembership(t *testing.T) {
|
||||
api, cleanup := newTestAPI(t)
|
||||
defer cleanup()
|
||||
|
||||
aliceCookie, _, aliceTeam := signup(t, api, "alice@example.com")
|
||||
bobCookie, _, _ := signup(t, api, "bob@example.com")
|
||||
|
||||
createRR := doJSON(t, api, http.MethodPost, "/drawings", map[string]any{
|
||||
"team_id": aliceTeam.ID,
|
||||
"title": "Architecture map",
|
||||
}, aliceCookie)
|
||||
if createRR.Code != http.StatusCreated {
|
||||
t.Fatalf("create drawing status = %d body = %s", createRR.Code, createRR.Body.String())
|
||||
}
|
||||
var drawing Drawing
|
||||
if err := json.Unmarshal(createRR.Body.Bytes(), &drawing); err != nil {
|
||||
t.Fatalf("drawing decode error = %v", err)
|
||||
}
|
||||
|
||||
forbiddenRR := doJSON(t, api, http.MethodGet, "/drawings/"+drawing.ID, nil, bobCookie)
|
||||
if forbiddenRR.Code != http.StatusForbidden {
|
||||
t.Fatalf("bob get drawing status = %d body = %s", forbiddenRR.Code, forbiddenRR.Body.String())
|
||||
}
|
||||
|
||||
listRR := doJSON(t, api, http.MethodGet, "/drawings?team_id="+aliceTeam.ID, nil, bobCookie)
|
||||
if listRR.Code != http.StatusForbidden {
|
||||
t.Fatalf("bob list team drawings status = %d body = %s", listRR.Code, listRR.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestTeamMembersRequireMembership(t *testing.T) {
|
||||
api, cleanup := newTestAPI(t)
|
||||
defer cleanup()
|
||||
|
||||
aliceCookie, _, aliceTeam := signup(t, api, "alice@example.com")
|
||||
bobCookie, _, _ := signup(t, api, "bob@example.com")
|
||||
|
||||
okRR := doJSON(t, api, http.MethodGet, "/teams/"+aliceTeam.ID+"/members", nil, aliceCookie)
|
||||
if okRR.Code != http.StatusOK {
|
||||
t.Fatalf("alice members status = %d body = %s", okRR.Code, okRR.Body.String())
|
||||
}
|
||||
var members []TeamMembership
|
||||
if err := json.Unmarshal(okRR.Body.Bytes(), &members); err != nil {
|
||||
t.Fatalf("members decode error = %v", err)
|
||||
}
|
||||
if len(members) != 1 || members[0].Role != "owner" || members[0].User == nil {
|
||||
t.Fatalf("unexpected members: %#v", members)
|
||||
}
|
||||
|
||||
forbiddenRR := doJSON(t, api, http.MethodGet, "/teams/"+aliceTeam.ID+"/members", nil, bobCookie)
|
||||
if forbiddenRR.Code != http.StatusForbidden {
|
||||
t.Fatalf("bob members status = %d body = %s", forbiddenRR.Code, forbiddenRR.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestDrawingRevisionsTemplatesAndActivity(t *testing.T) {
|
||||
api, cleanup := newTestAPI(t)
|
||||
defer cleanup()
|
||||
|
||||
cookie, _, team := signup(t, api, "alice@example.com")
|
||||
createRR := doJSON(t, api, http.MethodPost, "/drawings", map[string]any{
|
||||
"team_id": team.ID,
|
||||
"title": "Launch plan",
|
||||
"snapshot": map[string]any{"type": "excalidraw", "elements": []any{}},
|
||||
}, cookie)
|
||||
if createRR.Code != http.StatusCreated {
|
||||
t.Fatalf("create drawing status = %d body = %s", createRR.Code, createRR.Body.String())
|
||||
}
|
||||
var drawing Drawing
|
||||
if err := json.Unmarshal(createRR.Body.Bytes(), &drawing); err != nil {
|
||||
t.Fatalf("drawing decode error = %v", err)
|
||||
}
|
||||
if drawing.LatestRevisionID == nil {
|
||||
t.Fatal("create drawing with snapshot must create latest_revision_id")
|
||||
}
|
||||
|
||||
revisionRR := doJSON(t, api, http.MethodPost, "/drawings/"+drawing.ID+"/revisions", map[string]any{
|
||||
"snapshot": map[string]any{"type": "excalidraw", "elements": []any{map[string]any{"id": "a"}}},
|
||||
"change_summary": "Added first shape",
|
||||
}, cookie)
|
||||
if revisionRR.Code != http.StatusCreated {
|
||||
t.Fatalf("create revision status = %d body = %s", revisionRR.Code, revisionRR.Body.String())
|
||||
}
|
||||
|
||||
revisionsRR := doJSON(t, api, http.MethodGet, "/drawings/"+drawing.ID+"/revisions", nil, cookie)
|
||||
if revisionsRR.Code != http.StatusOK {
|
||||
t.Fatalf("list revisions status = %d body = %s", revisionsRR.Code, revisionsRR.Body.String())
|
||||
}
|
||||
var revisions []DrawingRevision
|
||||
if err := json.Unmarshal(revisionsRR.Body.Bytes(), &revisions); err != nil {
|
||||
t.Fatalf("revisions decode error = %v", err)
|
||||
}
|
||||
if len(revisions) != 2 || revisions[0].RevisionNumber != 2 {
|
||||
t.Fatalf("unexpected revisions: %#v", revisions)
|
||||
}
|
||||
|
||||
templatesRR := doJSON(t, api, http.MethodGet, "/templates", nil, cookie)
|
||||
if templatesRR.Code != http.StatusOK {
|
||||
t.Fatalf("templates status = %d body = %s", templatesRR.Code, templatesRR.Body.String())
|
||||
}
|
||||
var templates []Template
|
||||
if err := json.Unmarshal(templatesRR.Body.Bytes(), &templates); err != nil {
|
||||
t.Fatalf("templates decode error = %v", err)
|
||||
}
|
||||
if len(templates) < 4 {
|
||||
t.Fatalf("templates len = %d, want at least 4", len(templates))
|
||||
}
|
||||
|
||||
activityRR := doJSON(t, api, http.MethodGet, "/activity?team_id="+team.ID, nil, cookie)
|
||||
if activityRR.Code != http.StatusOK {
|
||||
t.Fatalf("activity status = %d body = %s", activityRR.Code, activityRR.Body.String())
|
||||
}
|
||||
var activity []ActivityEvent
|
||||
if err := json.Unmarshal(activityRR.Body.Bytes(), &activity); err != nil {
|
||||
t.Fatalf("activity decode error = %v", err)
|
||||
}
|
||||
if len(activity) == 0 {
|
||||
t.Fatal("expected activity events")
|
||||
}
|
||||
|
||||
statsRR := doJSON(t, api, http.MethodGet, "/stats?team_id="+team.ID, nil, cookie)
|
||||
if statsRR.Code != http.StatusOK {
|
||||
t.Fatalf("stats status = %d body = %s", statsRR.Code, statsRR.Body.String())
|
||||
}
|
||||
var stats WorkspaceStats
|
||||
if err := json.Unmarshal(statsRR.Body.Bytes(), &stats); err != nil {
|
||||
t.Fatalf("stats decode error = %v", err)
|
||||
}
|
||||
if stats.Drawings != 1 || stats.Revisions != 2 || stats.Templates < 4 {
|
||||
t.Fatalf("unexpected stats: %#v", stats)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealth(t *testing.T) {
|
||||
api, cleanup := newTestAPI(t)
|
||||
defer cleanup()
|
||||
|
||||
rr := doJSON(t, api, http.MethodGet, "/health", nil)
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Fatalf("health status = %d body = %s", rr.Code, rr.Body.String())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,217 @@
|
||||
package workspace
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Username string `json:"username"`
|
||||
Email string `json:"email"`
|
||||
AvatarURL *string `json:"avatar_url"`
|
||||
Locale string `json:"locale"`
|
||||
Timezone string `json:"timezone"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
type Session struct {
|
||||
ID string `json:"id"`
|
||||
UserID string `json:"user_id"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
type Team struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Slug string `json:"slug"`
|
||||
OwnerUserID string `json:"owner_user_id"`
|
||||
PlanType string `json:"plan_type"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
type TeamMembership struct {
|
||||
ID string `json:"id"`
|
||||
TeamID string `json:"team_id"`
|
||||
UserID string `json:"user_id"`
|
||||
Role string `json:"role"`
|
||||
JoinedAt time.Time `json:"joined_at"`
|
||||
User *User `json:"user,omitempty"`
|
||||
}
|
||||
|
||||
type TeamInvite struct {
|
||||
ID string `json:"id"`
|
||||
TeamID string `json:"team_id"`
|
||||
Email string `json:"email"`
|
||||
Role string `json:"role"`
|
||||
InvitedBy string `json:"invited_by"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
type Project struct {
|
||||
ID string `json:"id"`
|
||||
TeamID string `json:"team_id"`
|
||||
Name string `json:"name"`
|
||||
Slug string `json:"slug"`
|
||||
Description *string `json:"description"`
|
||||
CreatedBy string `json:"created_by"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
type Folder struct {
|
||||
ID string `json:"id"`
|
||||
TeamID string `json:"team_id"`
|
||||
ProjectID *string `json:"project_id"`
|
||||
ParentFolderID *string `json:"parent_folder_id"`
|
||||
Name string `json:"name"`
|
||||
Slug string `json:"slug"`
|
||||
PathCache string `json:"path_cache"`
|
||||
Visibility string `json:"visibility"`
|
||||
CreatedBy string `json:"created_by"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
type Drawing struct {
|
||||
ID string `json:"id"`
|
||||
TeamID string `json:"team_id"`
|
||||
FolderID *string `json:"folder_id"`
|
||||
ProjectID *string `json:"project_id"`
|
||||
Slug *string `json:"slug"`
|
||||
Title string `json:"title"`
|
||||
Description *string `json:"description"`
|
||||
OwnerUserID string `json:"owner_user_id"`
|
||||
LatestRevisionID *string `json:"latest_revision_id"`
|
||||
Visibility string `json:"visibility"`
|
||||
IsArchived bool `json:"is_archived"`
|
||||
ThumbnailAssetID *string `json:"thumbnail_asset_id"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
DeletedAt *time.Time `json:"deleted_at"`
|
||||
Owner *User `json:"owner,omitempty"`
|
||||
Folder *Folder `json:"folder,omitempty"`
|
||||
Project *Project `json:"project,omitempty"`
|
||||
ThumbnailURL *string `json:"thumbnail_url,omitempty"`
|
||||
}
|
||||
|
||||
type DrawingRevision struct {
|
||||
ID string `json:"id"`
|
||||
DrawingID string `json:"drawing_id"`
|
||||
RevisionNumber int `json:"revision_number"`
|
||||
SnapshotPath string `json:"snapshot_path"`
|
||||
SnapshotSize int64 `json:"snapshot_size"`
|
||||
ContentHash string `json:"content_hash"`
|
||||
CreatedBy string `json:"created_by"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
ChangeSummary *string `json:"change_summary"`
|
||||
Snapshot json.RawMessage `json:"snapshot,omitempty"`
|
||||
CreatedByUser *User `json:"created_by_user,omitempty"`
|
||||
}
|
||||
|
||||
type DrawingAsset struct {
|
||||
ID string `json:"id"`
|
||||
DrawingID string `json:"drawing_id"`
|
||||
Kind string `json:"kind"`
|
||||
Path string `json:"path"`
|
||||
MimeType string `json:"mime_type"`
|
||||
Size int64 `json:"size"`
|
||||
Width *int `json:"width"`
|
||||
Height *int `json:"height"`
|
||||
UploadedBy string `json:"uploaded_by"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
URL *string `json:"url,omitempty"`
|
||||
}
|
||||
|
||||
type Template struct {
|
||||
ID string `json:"id"`
|
||||
TeamID *string `json:"team_id"`
|
||||
Scope string `json:"scope"`
|
||||
Type string `json:"type"`
|
||||
Name string `json:"name"`
|
||||
Description *string `json:"description"`
|
||||
SnapshotPath string `json:"snapshot_path"`
|
||||
MetadataJSON map[string]any `json:"metadata_json"`
|
||||
CreatedBy string `json:"created_by"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
PreviewURL *string `json:"preview_url,omitempty"`
|
||||
}
|
||||
|
||||
type ActivityEvent struct {
|
||||
ID string `json:"id"`
|
||||
ActorUserID *string `json:"actor_user_id"`
|
||||
TeamID *string `json:"team_id"`
|
||||
ResourceType string `json:"resource_type"`
|
||||
ResourceID string `json:"resource_id"`
|
||||
EventType string `json:"event_type"`
|
||||
MetadataJSON map[string]any `json:"metadata_json"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
Actor *User `json:"actor,omitempty"`
|
||||
}
|
||||
|
||||
type ShareLink struct {
|
||||
ID string `json:"id"`
|
||||
ResourceType string `json:"resource_type"`
|
||||
ResourceID string `json:"resource_id"`
|
||||
TokenHash string `json:"token_hash,omitempty"`
|
||||
Permission string `json:"permission"`
|
||||
ExpiresAt *time.Time `json:"expires_at"`
|
||||
PasswordHash *string `json:"-"`
|
||||
CreatedBy string `json:"created_by"`
|
||||
RevokedAt *time.Time `json:"revoked_at"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
type PermissionGrant struct {
|
||||
ID string `json:"id"`
|
||||
ResourceType string `json:"resource_type"`
|
||||
ResourceID string `json:"resource_id"`
|
||||
SubjectType string `json:"subject_type"`
|
||||
SubjectID string `json:"subject_id"`
|
||||
Permission string `json:"permission"`
|
||||
InheritedFrom *string `json:"inherited_from"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
type Embed struct {
|
||||
ID string `json:"id"`
|
||||
DrawingID string `json:"drawing_id"`
|
||||
SourceURL string `json:"source_url"`
|
||||
CanonicalURL string `json:"canonical_url"`
|
||||
Provider string `json:"provider"`
|
||||
EmbedType string `json:"embed_type"`
|
||||
Title *string `json:"title"`
|
||||
PreviewAssetID *string `json:"preview_asset_id"`
|
||||
SafeEmbedHTML *string `json:"safe_embed_html"`
|
||||
CreatedBy string `json:"created_by"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
type LinkReference struct {
|
||||
ID string `json:"id"`
|
||||
SourceResourceType string `json:"source_resource_type"`
|
||||
SourceResourceID string `json:"source_resource_id"`
|
||||
TargetResourceType string `json:"target_resource_type"`
|
||||
TargetResourceID string `json:"target_resource_id"`
|
||||
Label *string `json:"label"`
|
||||
CreatedBy string `json:"created_by"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
type WorkspaceStats struct {
|
||||
Teams int `json:"teams"`
|
||||
Members int `json:"members"`
|
||||
Projects int `json:"projects"`
|
||||
Folders int `json:"folders"`
|
||||
Drawings int `json:"drawings"`
|
||||
Templates int `json:"templates"`
|
||||
Revisions int `json:"revisions"`
|
||||
Assets int `json:"assets"`
|
||||
StorageBytes int64 `json:"storage_bytes"`
|
||||
}
|
||||
@@ -0,0 +1,192 @@
|
||||
package workspace
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"database/sql"
|
||||
"errors"
|
||||
dbpostgres "excalidraw-complete/internal/postgres"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
type OAuthProfile struct {
|
||||
Provider string
|
||||
ProviderUserID string
|
||||
Email string
|
||||
Name string
|
||||
Username string
|
||||
AvatarURL string
|
||||
EmailVerified bool
|
||||
}
|
||||
|
||||
func (s *Store) UpsertOAuthSession(ctx context.Context, profile OAuthProfile) (*User, *Session, string, error) {
|
||||
profile.Provider = strings.TrimSpace(strings.ToLower(profile.Provider))
|
||||
profile.ProviderUserID = strings.TrimSpace(profile.ProviderUserID)
|
||||
if profile.Provider == "" || profile.ProviderUserID == "" {
|
||||
return nil, nil, "", fmt.Errorf("oauth provider and provider user id are required")
|
||||
}
|
||||
email := strings.TrimSpace(profile.Email)
|
||||
if email == "" {
|
||||
email = fmt.Sprintf("%s-%s@users.local", profile.Provider, slugify(profile.ProviderUserID))
|
||||
}
|
||||
normalizedEmail, err := normalizeEmail(email)
|
||||
if err != nil {
|
||||
return nil, nil, "", err
|
||||
}
|
||||
name := strings.TrimSpace(profile.Name)
|
||||
if name == "" {
|
||||
name = strings.TrimSpace(profile.Username)
|
||||
}
|
||||
if name == "" {
|
||||
name = normalizedEmail
|
||||
}
|
||||
|
||||
tx, err := s.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, nil, "", err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
userID, err := userIDByIdentityTx(ctx, tx, profile.Provider, profile.ProviderUserID)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, nil, "", err
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
var user *User
|
||||
if userID != "" {
|
||||
user, err = updateOAuthUserTx(ctx, tx, userID, name, profile.Username, normalizedEmail, profile.AvatarURL)
|
||||
if err != nil {
|
||||
return nil, nil, "", err
|
||||
}
|
||||
} else {
|
||||
userID, err = userIDByEmailTx(ctx, tx, normalizedEmail)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, nil, "", err
|
||||
}
|
||||
if userID == "" {
|
||||
user, err = createOAuthUserTx(ctx, tx, name, profile.Username, normalizedEmail, profile.AvatarURL)
|
||||
if err != nil {
|
||||
return nil, nil, "", err
|
||||
}
|
||||
team, err := createTeamTx(ctx, tx, user.ID, name+"'s Workspace", "")
|
||||
if err != nil {
|
||||
return nil, nil, "", err
|
||||
}
|
||||
if err := insertActivityTx(ctx, tx, &user.ID, &team.ID, "team", team.ID, "member_joined", map[string]any{"role": "owner"}); err != nil {
|
||||
return nil, nil, "", err
|
||||
}
|
||||
} else {
|
||||
user, err = updateOAuthUserTx(ctx, tx, userID, name, profile.Username, normalizedEmail, profile.AvatarURL)
|
||||
if err != nil {
|
||||
return nil, nil, "", err
|
||||
}
|
||||
}
|
||||
var verifiedAt *time.Time
|
||||
if profile.EmailVerified {
|
||||
verifiedAt = &now
|
||||
}
|
||||
_, err = tx.ExecContext(ctx, `INSERT INTO workspace_auth_identities
|
||||
(id, user_id, provider, provider_user_id, email_verified_at, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)`,
|
||||
newID(), user.ID, profile.Provider, profile.ProviderUserID, verifiedAt, now,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, nil, "", err
|
||||
}
|
||||
}
|
||||
|
||||
session, token, err := createSessionTx(ctx, tx, user.ID)
|
||||
if err != nil {
|
||||
return nil, nil, "", err
|
||||
}
|
||||
if err := insertActivityTx(ctx, tx, &user.ID, nil, "user", user.ID, "login_success", map[string]any{"provider": profile.Provider}); err != nil {
|
||||
return nil, nil, "", err
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nil, nil, "", err
|
||||
}
|
||||
return user, session, token, nil
|
||||
}
|
||||
|
||||
func userIDByIdentityTx(ctx context.Context, tx *dbpostgres.Tx, provider, providerUserID string) (string, error) {
|
||||
var userID string
|
||||
err := tx.QueryRowContext(ctx, `SELECT user_id FROM workspace_auth_identities WHERE provider = ? AND provider_user_id = ?`, provider, providerUserID).Scan(&userID)
|
||||
return userID, err
|
||||
}
|
||||
|
||||
func userIDByEmailTx(ctx context.Context, tx *dbpostgres.Tx, email string) (string, error) {
|
||||
var userID string
|
||||
err := tx.QueryRowContext(ctx, `SELECT id FROM workspace_users WHERE email = ?`, email).Scan(&userID)
|
||||
return userID, err
|
||||
}
|
||||
|
||||
func createOAuthUserTx(ctx context.Context, tx *dbpostgres.Tx, name, username, email, avatarURL string) (*User, error) {
|
||||
password := make([]byte, 32)
|
||||
if _, err := rand.Read(password); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
hash, err := bcrypt.GenerateFromPassword(password, 12)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if username == "" {
|
||||
username = strings.TrimSuffix(email, email[strings.LastIndex(email, "@"):])
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
user := &User{
|
||||
ID: newID(),
|
||||
Name: name,
|
||||
Username: uniqueUsername(ctx, tx, slugify(username)),
|
||||
Email: email,
|
||||
Locale: "en",
|
||||
Timezone: "UTC",
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
if avatarURL != "" {
|
||||
user.AvatarURL = &avatarURL
|
||||
}
|
||||
_, err = tx.ExecContext(ctx, `INSERT INTO workspace_users
|
||||
(id, name, username, email, password_hash, avatar_url, locale, timezone, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
user.ID, user.Name, user.Username, user.Email, string(hash), user.AvatarURL, user.Locale, user.Timezone, user.CreatedAt, user.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func updateOAuthUserTx(ctx context.Context, tx *dbpostgres.Tx, userID, name, username, email, avatarURL string) (*User, error) {
|
||||
current := &User{}
|
||||
var currentAvatar *string
|
||||
err := tx.QueryRowContext(ctx, `SELECT id, name, username, email, avatar_url, locale, timezone, created_at, updated_at FROM workspace_users WHERE id = ?`, userID).
|
||||
Scan(¤t.ID, ¤t.Name, ¤t.Username, ¤t.Email, ¤tAvatar, ¤t.Locale, ¤t.Timezone, ¤t.CreatedAt, ¤t.UpdatedAt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if strings.TrimSpace(name) == "" {
|
||||
name = current.Name
|
||||
}
|
||||
if strings.TrimSpace(username) == "" {
|
||||
username = current.Username
|
||||
}
|
||||
avatar := currentAvatar
|
||||
if avatarURL != "" {
|
||||
avatar = &avatarURL
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
_, err = tx.ExecContext(ctx, `UPDATE workspace_users SET name = ?, avatar_url = ?, updated_at = ? WHERE id = ?`, name, avatar, now, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
current.Name = name
|
||||
current.AvatarURL = avatar
|
||||
current.UpdatedAt = now
|
||||
return current, nil
|
||||
}
|
||||
@@ -0,0 +1,46 @@
|
||||
package workspace
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestUpsertOAuthSessionCreatesAndReusesIdentity(t *testing.T) {
|
||||
store, cleanup := newTestStore(t)
|
||||
defer cleanup()
|
||||
|
||||
profile := OAuthProfile{
|
||||
Provider: "github",
|
||||
ProviderUserID: "123",
|
||||
Email: "octo@example.com",
|
||||
Name: "Octo User",
|
||||
Username: "octo",
|
||||
AvatarURL: "https://example.com/avatar.png",
|
||||
EmailVerified: true,
|
||||
}
|
||||
user, session, token, err := store.UpsertOAuthSession(context.Background(), profile)
|
||||
if err != nil {
|
||||
t.Fatalf("UpsertOAuthSession() error = %v", err)
|
||||
}
|
||||
if user.ID == "" || session.ID == "" || token == "" {
|
||||
t.Fatalf("missing oauth output user=%#v session=%#v token=%q", user, session, token)
|
||||
}
|
||||
teams, err := store.ListTeamsForUser(context.Background(), user.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("ListTeamsForUser() error = %v", err)
|
||||
}
|
||||
if len(teams) != 1 {
|
||||
t.Fatalf("teams len = %d, want 1", len(teams))
|
||||
}
|
||||
|
||||
sameUser, secondSession, secondToken, err := store.UpsertOAuthSession(context.Background(), profile)
|
||||
if err != nil {
|
||||
t.Fatalf("second UpsertOAuthSession() error = %v", err)
|
||||
}
|
||||
if sameUser.ID != user.ID {
|
||||
t.Fatalf("second oauth user id = %s, want %s", sameUser.ID, user.ID)
|
||||
}
|
||||
if secondSession.ID == session.ID || secondToken == token {
|
||||
t.Fatal("oauth login should create a fresh session")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,370 @@
|
||||
package workspace
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestPermissionMatrix tests the full permission matrix for drawings.
|
||||
// It creates drawings with different visibilities and verifies access
|
||||
// from users with different roles and direct permission grants.
|
||||
func TestPermissionMatrix(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
visibility string
|
||||
grantPerm string // direct grant permission to bob
|
||||
bobRole string // bob's role in alice's team
|
||||
expectGet int // expected HTTP status for GET /drawings/:id
|
||||
expectUpdate int // expected HTTP status for PATCH /drawings/:id
|
||||
expectListTeam int // expected HTTP status for GET /drawings (team-scoped list)
|
||||
}{
|
||||
// public drawings - anyone in the team can view, edit requires role
|
||||
{
|
||||
name: "public_drawing_team_viewer_can_view_not_edit",
|
||||
visibility: "public", grantPerm: "", bobRole: "viewer",
|
||||
expectGet: http.StatusOK, expectUpdate: http.StatusForbidden,
|
||||
expectListTeam: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "public_drawing_team_editor_can_view_and_edit",
|
||||
visibility: "public", grantPerm: "", bobRole: "editor",
|
||||
expectGet: http.StatusOK, expectUpdate: http.StatusOK,
|
||||
expectListTeam: http.StatusOK,
|
||||
},
|
||||
// restricted drawings - no team access without explicit grant
|
||||
{
|
||||
name: "restricted_no_grant_team_member_cannot_access",
|
||||
visibility: "restricted", grantPerm: "", bobRole: "editor",
|
||||
expectGet: http.StatusForbidden, expectUpdate: http.StatusForbidden,
|
||||
expectListTeam: http.StatusOK, // team list still shows it if team-scoped
|
||||
},
|
||||
{
|
||||
name: "restricted_with_view_grant_team_viewer_can_view_not_edit",
|
||||
visibility: "restricted", grantPerm: "view", bobRole: "viewer",
|
||||
expectGet: http.StatusOK, expectUpdate: http.StatusForbidden,
|
||||
expectListTeam: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "restricted_with_edit_grant_team_viewer_can_view_and_edit",
|
||||
visibility: "restricted", grantPerm: "edit", bobRole: "viewer",
|
||||
expectGet: http.StatusOK, expectUpdate: http.StatusOK,
|
||||
expectListTeam: http.StatusOK,
|
||||
},
|
||||
// private drawings - only owner and explicit grant holders
|
||||
{
|
||||
name: "private_no_grant_team_member_cannot_access",
|
||||
visibility: "private", grantPerm: "", bobRole: "admin",
|
||||
expectGet: http.StatusForbidden, expectUpdate: http.StatusForbidden,
|
||||
expectListTeam: http.StatusOK, // team list may still include private owned by owner
|
||||
},
|
||||
{
|
||||
name: "private_with_view_grant_allows_view_not_edit",
|
||||
visibility: "private", grantPerm: "view", bobRole: "viewer",
|
||||
expectGet: http.StatusOK, expectUpdate: http.StatusForbidden,
|
||||
expectListTeam: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "private_with_edit_grant_allows_view_and_edit",
|
||||
visibility: "private", grantPerm: "edit", bobRole: "viewer",
|
||||
expectGet: http.StatusOK, expectUpdate: http.StatusOK,
|
||||
expectListTeam: http.StatusOK,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
api, cleanup := newTestAPI(t)
|
||||
defer cleanup()
|
||||
|
||||
aliceCookie, _, aliceTeam := signup(t, api, "alice@example.com")
|
||||
bobCookie, _, _ := signup(t, api, "bob@example.com")
|
||||
|
||||
// Invite bob to alice's team with specified role
|
||||
if tc.bobRole != "" {
|
||||
inviteRR := doJSON(t, api, http.MethodPost, "/teams/"+aliceTeam.ID+"/invites", map[string]any{
|
||||
"email": "bob@example.com",
|
||||
"role": tc.bobRole,
|
||||
}, aliceCookie)
|
||||
if inviteRR.Code != http.StatusCreated {
|
||||
t.Fatalf("invite status = %d body = %s", inviteRR.Code, inviteRR.Body.String())
|
||||
}
|
||||
// Accept the invite
|
||||
var invitePayload struct {
|
||||
Token string `json:"token"`
|
||||
}
|
||||
if err := json.Unmarshal(inviteRR.Body.Bytes(), &invitePayload); err != nil {
|
||||
t.Fatalf("invite decode error = %v", err)
|
||||
}
|
||||
acceptRR := doJSON(t, api, http.MethodPost, "/invites/accept", map[string]string{
|
||||
"token": invitePayload.Token,
|
||||
}, bobCookie)
|
||||
if acceptRR.Code != http.StatusOK {
|
||||
t.Fatalf("accept status = %d body = %s", acceptRR.Code, acceptRR.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
// Alice creates a drawing with specified visibility
|
||||
createRR := doJSON(t, api, http.MethodPost, "/drawings", map[string]any{
|
||||
"team_id": aliceTeam.ID,
|
||||
"title": tc.name + " drawing",
|
||||
"visibility": tc.visibility,
|
||||
}, aliceCookie)
|
||||
if createRR.Code != http.StatusCreated {
|
||||
t.Fatalf("create drawing status = %d body = %s", createRR.Code, createRR.Body.String())
|
||||
}
|
||||
var drawing Drawing
|
||||
if err := json.Unmarshal(createRR.Body.Bytes(), &drawing); err != nil {
|
||||
t.Fatalf("drawing decode error = %v", err)
|
||||
}
|
||||
|
||||
// Optionally grant bob a direct permission
|
||||
if tc.grantPerm != "" {
|
||||
grantRR := doJSON(t, api, http.MethodPost, "/drawings/"+drawing.ID+"/permissions", map[string]any{
|
||||
"subject_type": "user",
|
||||
"email": "bob@example.com",
|
||||
"permission": tc.grantPerm,
|
||||
}, aliceCookie)
|
||||
if grantRR.Code != http.StatusCreated {
|
||||
t.Fatalf("grant status = %d body = %s", grantRR.Code, grantRR.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
// Bob attempts to GET the drawing
|
||||
getRR := doJSON(t, api, http.MethodGet, "/drawings/"+drawing.ID, nil, bobCookie)
|
||||
if getRR.Code != tc.expectGet {
|
||||
t.Errorf("GET status = %d, want %d; body = %s", getRR.Code, tc.expectGet, getRR.Body.String())
|
||||
}
|
||||
|
||||
// Bob attempts to PATCH (update) the drawing
|
||||
updateRR := doJSON(t, api, http.MethodPatch, "/drawings/"+drawing.ID, map[string]any{
|
||||
"title": tc.name + " updated",
|
||||
}, bobCookie)
|
||||
if updateRR.Code != tc.expectUpdate {
|
||||
t.Errorf("PATCH status = %d, want %d; body = %s", updateRR.Code, tc.expectUpdate, updateRR.Body.String())
|
||||
}
|
||||
|
||||
// Bob attempts to list team drawings
|
||||
listRR := doJSON(t, api, http.MethodGet, "/drawings?team_id="+aliceTeam.ID, nil, bobCookie)
|
||||
if listRR.Code != tc.expectListTeam {
|
||||
t.Errorf("LIST status = %d, want %d; body = %s", listRR.Code, tc.expectListTeam, listRR.Body.String())
|
||||
}
|
||||
|
||||
// Verify alice (owner) can still access everything
|
||||
ownerGet := doJSON(t, api, http.MethodGet, "/drawings/"+drawing.ID, nil, aliceCookie)
|
||||
if ownerGet.Code != http.StatusOK {
|
||||
t.Errorf("owner GET status = %d, want %d", ownerGet.Code, http.StatusOK)
|
||||
}
|
||||
ownerUpdate := doJSON(t, api, http.MethodPatch, "/drawings/"+drawing.ID, map[string]any{
|
||||
"title": tc.name + " owner updated",
|
||||
}, aliceCookie)
|
||||
if ownerUpdate.Code != http.StatusOK {
|
||||
t.Errorf("owner PATCH status = %d, want %d", ownerUpdate.Code, http.StatusOK)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAdminCanManageTeam verifies that team admins can manage team settings,
|
||||
// members, and resources while non-admins cannot.
|
||||
func TestAdminCanManageTeam(t *testing.T) {
|
||||
api, cleanup := newTestAPI(t)
|
||||
defer cleanup()
|
||||
|
||||
aliceCookie, _, aliceTeam := signup(t, api, "alice@example.com")
|
||||
bobCookie, _, _ := signup(t, api, "bob@example.com")
|
||||
charlieCookie, _, _ := signup(t, api, "charlie@example.com")
|
||||
|
||||
// Invite bob as admin, charlie as viewer
|
||||
for _, tc := range []struct{ email, role string }{
|
||||
{"bob@example.com", "admin"},
|
||||
{"charlie@example.com", "viewer"},
|
||||
} {
|
||||
inviteRR := doJSON(t, api, http.MethodPost, "/teams/"+aliceTeam.ID+"/invites", map[string]any{
|
||||
"email": tc.email,
|
||||
"role": tc.role,
|
||||
}, aliceCookie)
|
||||
if inviteRR.Code != http.StatusCreated {
|
||||
t.Fatalf("invite %s status = %d body = %s", tc.email, inviteRR.Code, inviteRR.Body.String())
|
||||
}
|
||||
var invitePayload struct {
|
||||
Token string `json:"token"`
|
||||
}
|
||||
if err := json.Unmarshal(inviteRR.Body.Bytes(), &invitePayload); err != nil {
|
||||
t.Fatalf("invite decode error = %v", err)
|
||||
}
|
||||
var cookie *http.Cookie
|
||||
if tc.email == "bob@example.com" {
|
||||
cookie = bobCookie
|
||||
} else {
|
||||
cookie = charlieCookie
|
||||
}
|
||||
acceptRR := doJSON(t, api, http.MethodPost, "/invites/accept", map[string]string{
|
||||
"token": invitePayload.Token,
|
||||
}, cookie)
|
||||
if acceptRR.Code != http.StatusOK {
|
||||
t.Fatalf("accept %s status = %d body = %s", tc.email, acceptRR.Code, acceptRR.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
// Bob (admin) can update team name
|
||||
updateTeam := doJSON(t, api, http.MethodPatch, "/teams/"+aliceTeam.ID, map[string]any{
|
||||
"name": "Updated by admin",
|
||||
}, bobCookie)
|
||||
if updateTeam.Code != http.StatusOK {
|
||||
t.Errorf("admin update team status = %d, want %d; body = %s", updateTeam.Code, http.StatusOK, updateTeam.Body.String())
|
||||
}
|
||||
|
||||
// Charlie (viewer) cannot update team name
|
||||
charlieUpdate := doJSON(t, api, http.MethodPatch, "/teams/"+aliceTeam.ID, map[string]any{
|
||||
"name": "Updated by viewer",
|
||||
}, charlieCookie)
|
||||
if charlieUpdate.Code != http.StatusForbidden {
|
||||
t.Errorf("viewer update team status = %d, want %d; body = %s", charlieUpdate.Code, http.StatusForbidden, charlieUpdate.Body.String())
|
||||
}
|
||||
|
||||
// Bob (admin) can manage members
|
||||
membersRR := doJSON(t, api, http.MethodGet, "/teams/"+aliceTeam.ID+"/members", nil, bobCookie)
|
||||
if membersRR.Code != http.StatusOK {
|
||||
t.Errorf("admin list members status = %d, want %d", membersRR.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
// Charlie (viewer) can view members
|
||||
charlieMembers := doJSON(t, api, http.MethodGet, "/teams/"+aliceTeam.ID+"/members", nil, charlieCookie)
|
||||
if charlieMembers.Code != http.StatusOK {
|
||||
t.Errorf("viewer list members status = %d, want %d", charlieMembers.Code, http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
// TestNonMemberCannotAccessPrivateTeam verifies that users not in a team
|
||||
// cannot access any team resources regardless of drawing visibility.
|
||||
func TestNonMemberCannotAccessPrivateTeam(t *testing.T) {
|
||||
api, cleanup := newTestAPI(t)
|
||||
defer cleanup()
|
||||
|
||||
aliceCookie, _, aliceTeam := signup(t, api, "alice@example.com")
|
||||
bobCookie, _, _ := signup(t, api, "bob@example.com")
|
||||
|
||||
// Alice creates a public drawing in her team
|
||||
createRR := doJSON(t, api, http.MethodPost, "/drawings", map[string]any{
|
||||
"team_id": aliceTeam.ID,
|
||||
"title": "Public team drawing",
|
||||
"visibility": "public",
|
||||
}, aliceCookie)
|
||||
if createRR.Code != http.StatusCreated {
|
||||
t.Fatalf("create drawing status = %d body = %s", createRR.Code, createRR.Body.String())
|
||||
}
|
||||
var drawing Drawing
|
||||
if err := json.Unmarshal(createRR.Body.Bytes(), &drawing); err != nil {
|
||||
t.Fatalf("drawing decode error = %v", err)
|
||||
}
|
||||
|
||||
// Bob (not in team) cannot access the drawing even though it's public
|
||||
// because "public" in this context means public within the team
|
||||
getRR := doJSON(t, api, http.MethodGet, "/drawings/"+drawing.ID, nil, bobCookie)
|
||||
if getRR.Code != http.StatusForbidden {
|
||||
t.Errorf("non-member GET status = %d, want %d; body = %s", getRR.Code, http.StatusForbidden, getRR.Body.String())
|
||||
}
|
||||
|
||||
// Bob cannot list team drawings
|
||||
listRR := doJSON(t, api, http.MethodGet, "/drawings?team_id="+aliceTeam.ID, nil, bobCookie)
|
||||
if listRR.Code != http.StatusForbidden && listRR.Code != http.StatusOK {
|
||||
// Depending on implementation, non-members might get empty list or forbidden
|
||||
t.Logf("non-member LIST status = %d", listRR.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestPermissionInheritance verifies that permissions flow correctly
|
||||
// through the resource hierarchy (team -> project -> folder -> drawing).
|
||||
func TestPermissionInheritance(t *testing.T) {
|
||||
api, cleanup := newTestAPI(t)
|
||||
defer cleanup()
|
||||
|
||||
aliceCookie, _, aliceTeam := signup(t, api, "alice@example.com")
|
||||
bobCookie, _, _ := signup(t, api, "bob@example.com")
|
||||
|
||||
// Invite bob as editor
|
||||
inviteRR := doJSON(t, api, http.MethodPost, "/teams/"+aliceTeam.ID+"/invites", map[string]any{
|
||||
"email": "bob@example.com",
|
||||
"role": "editor",
|
||||
}, aliceCookie)
|
||||
if inviteRR.Code != http.StatusCreated {
|
||||
t.Fatalf("invite status = %d body = %s", inviteRR.Code, inviteRR.Body.String())
|
||||
}
|
||||
var invitePayload struct {
|
||||
Token string `json:"token"`
|
||||
}
|
||||
if err := json.Unmarshal(inviteRR.Body.Bytes(), &invitePayload); err != nil {
|
||||
t.Fatalf("invite decode error = %v", err)
|
||||
}
|
||||
acceptRR := doJSON(t, api, http.MethodPost, "/invites/accept", map[string]string{
|
||||
"token": invitePayload.Token,
|
||||
}, bobCookie)
|
||||
if acceptRR.Code != http.StatusOK {
|
||||
t.Fatalf("accept status = %d body = %s", acceptRR.Code, acceptRR.Body.String())
|
||||
}
|
||||
|
||||
// Alice creates a project
|
||||
projectRR := doJSON(t, api, http.MethodPost, "/projects", map[string]any{
|
||||
"team_id": aliceTeam.ID,
|
||||
"name": "Test Project",
|
||||
"description": "A test project",
|
||||
}, aliceCookie)
|
||||
if projectRR.Code != http.StatusCreated {
|
||||
t.Fatalf("create project status = %d body = %s", projectRR.Code, projectRR.Body.String())
|
||||
}
|
||||
var project struct {
|
||||
ID string `json:"id"`
|
||||
}
|
||||
if err := json.Unmarshal(projectRR.Body.Bytes(), &project); err != nil {
|
||||
t.Fatalf("project decode error = %v", err)
|
||||
}
|
||||
|
||||
// Alice creates a folder in the project
|
||||
folderRR := doJSON(t, api, http.MethodPost, "/folders", map[string]any{
|
||||
"team_id": aliceTeam.ID,
|
||||
"project_id": project.ID,
|
||||
"name": "Test Folder",
|
||||
}, aliceCookie)
|
||||
if folderRR.Code != http.StatusCreated {
|
||||
t.Fatalf("create folder status = %d body = %s", folderRR.Code, folderRR.Body.String())
|
||||
}
|
||||
var folder struct {
|
||||
ID string `json:"id"`
|
||||
}
|
||||
if err := json.Unmarshal(folderRR.Body.Bytes(), &folder); err != nil {
|
||||
t.Fatalf("folder decode error = %v", err)
|
||||
}
|
||||
|
||||
// Alice creates a drawing in the folder
|
||||
drawingRR := doJSON(t, api, http.MethodPost, "/drawings", map[string]any{
|
||||
"team_id": aliceTeam.ID,
|
||||
"project_id": project.ID,
|
||||
"folder_id": folder.ID,
|
||||
"title": "Nested Drawing",
|
||||
"visibility": "public",
|
||||
}, aliceCookie)
|
||||
if drawingRR.Code != http.StatusCreated {
|
||||
t.Fatalf("create drawing status = %d body = %s", drawingRR.Code, drawingRR.Body.String())
|
||||
}
|
||||
var drawing Drawing
|
||||
if err := json.Unmarshal(drawingRR.Body.Bytes(), &drawing); err != nil {
|
||||
t.Fatalf("drawing decode error = %v", err)
|
||||
}
|
||||
|
||||
// Bob (team editor) should be able to access the nested drawing through team membership
|
||||
getRR := doJSON(t, api, http.MethodGet, "/drawings/"+drawing.ID, nil, bobCookie)
|
||||
if getRR.Code != http.StatusOK {
|
||||
t.Errorf("team editor GET nested drawing status = %d, want %d; body = %s",
|
||||
getRR.Code, http.StatusOK, getRR.Body.String())
|
||||
}
|
||||
|
||||
// Bob should be able to update the drawing
|
||||
updateRR := doJSON(t, api, http.MethodPatch, "/drawings/"+drawing.ID, map[string]any{
|
||||
"title": "Updated by team editor",
|
||||
}, bobCookie)
|
||||
if updateRR.Code != http.StatusOK {
|
||||
t.Errorf("team editor PATCH nested drawing status = %d, want %d; body = %s",
|
||||
updateRR.Code, http.StatusOK, updateRR.Body.String())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,43 @@
|
||||
package workspace
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type rateLimiter struct {
|
||||
mu sync.Mutex
|
||||
limit int
|
||||
window time.Duration
|
||||
attempts map[string][]time.Time
|
||||
}
|
||||
|
||||
func newRateLimiter(limit int, window time.Duration) *rateLimiter {
|
||||
return &rateLimiter{
|
||||
limit: limit,
|
||||
window: window,
|
||||
attempts: make(map[string][]time.Time),
|
||||
}
|
||||
}
|
||||
|
||||
func (l *rateLimiter) allow(key string) bool {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
cutoff := now.Add(-l.window)
|
||||
values := l.attempts[key]
|
||||
kept := values[:0]
|
||||
for _, value := range values {
|
||||
if value.After(cutoff) {
|
||||
kept = append(kept, value)
|
||||
}
|
||||
}
|
||||
if len(kept) >= l.limit {
|
||||
l.attempts[key] = kept
|
||||
return false
|
||||
}
|
||||
kept = append(kept, now)
|
||||
l.attempts[key] = kept
|
||||
return true
|
||||
}
|
||||
@@ -0,0 +1,220 @@
|
||||
package workspace
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestInviteAcceptAddsEditorMembership(t *testing.T) {
|
||||
api, cleanup := newTestAPI(t)
|
||||
defer cleanup()
|
||||
|
||||
aliceCookie, _, aliceTeam := signup(t, api, "alice@example.com")
|
||||
bobCookie, _, _ := signup(t, api, "bob@example.com")
|
||||
|
||||
inviteRR := doJSON(t, api, http.MethodPost, "/teams/"+aliceTeam.ID+"/invites", map[string]any{
|
||||
"email": "bob@example.com",
|
||||
"role": "editor",
|
||||
}, aliceCookie)
|
||||
if inviteRR.Code != http.StatusCreated {
|
||||
t.Fatalf("invite status = %d body = %s", inviteRR.Code, inviteRR.Body.String())
|
||||
}
|
||||
var invite struct {
|
||||
Invite TeamInvite `json:"invite"`
|
||||
Token string `json:"token"`
|
||||
}
|
||||
if err := json.Unmarshal(inviteRR.Body.Bytes(), &invite); err != nil {
|
||||
t.Fatalf("invite decode error = %v", err)
|
||||
}
|
||||
if invite.Token == "" || invite.Invite.Email != "bob@example.com" {
|
||||
t.Fatalf("unexpected invite: %#v", invite)
|
||||
}
|
||||
|
||||
beforeRR := doJSON(t, api, http.MethodGet, "/teams/"+aliceTeam.ID+"/members", nil, bobCookie)
|
||||
if beforeRR.Code != http.StatusForbidden {
|
||||
t.Fatalf("bob members before accept status = %d body = %s", beforeRR.Code, beforeRR.Body.String())
|
||||
}
|
||||
|
||||
acceptRR := doJSON(t, api, http.MethodPost, "/invites/accept", map[string]string{"token": invite.Token}, bobCookie)
|
||||
if acceptRR.Code != http.StatusOK {
|
||||
t.Fatalf("accept status = %d body = %s", acceptRR.Code, acceptRR.Body.String())
|
||||
}
|
||||
|
||||
afterRR := doJSON(t, api, http.MethodGet, "/teams/"+aliceTeam.ID+"/members", nil, bobCookie)
|
||||
if afterRR.Code != http.StatusOK {
|
||||
t.Fatalf("bob members after accept status = %d body = %s", afterRR.Code, afterRR.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRestrictedDrawingGrantAllowsViewNotEdit(t *testing.T) {
|
||||
api, cleanup := newTestAPI(t)
|
||||
defer cleanup()
|
||||
|
||||
aliceCookie, _, aliceTeam := signup(t, api, "alice@example.com")
|
||||
bobCookie, _, _ := signup(t, api, "bob@example.com")
|
||||
|
||||
createRR := doJSON(t, api, http.MethodPost, "/drawings", map[string]any{
|
||||
"team_id": aliceTeam.ID,
|
||||
"title": "Restricted map",
|
||||
"visibility": "restricted",
|
||||
}, aliceCookie)
|
||||
if createRR.Code != http.StatusCreated {
|
||||
t.Fatalf("create drawing status = %d body = %s", createRR.Code, createRR.Body.String())
|
||||
}
|
||||
var drawing Drawing
|
||||
if err := json.Unmarshal(createRR.Body.Bytes(), &drawing); err != nil {
|
||||
t.Fatalf("drawing decode error = %v", err)
|
||||
}
|
||||
|
||||
noAccessRR := doJSON(t, api, http.MethodGet, "/drawings/"+drawing.ID, nil, bobCookie)
|
||||
if noAccessRR.Code != http.StatusForbidden {
|
||||
t.Fatalf("bob get before grant status = %d body = %s", noAccessRR.Code, noAccessRR.Body.String())
|
||||
}
|
||||
|
||||
grantRR := doJSON(t, api, http.MethodPost, "/drawings/"+drawing.ID+"/permissions", map[string]any{
|
||||
"subject_type": "user",
|
||||
"email": "bob@example.com",
|
||||
"permission": "view",
|
||||
}, aliceCookie)
|
||||
if grantRR.Code != http.StatusCreated {
|
||||
t.Fatalf("grant status = %d body = %s", grantRR.Code, grantRR.Body.String())
|
||||
}
|
||||
|
||||
getRR := doJSON(t, api, http.MethodGet, "/drawings/"+drawing.ID, nil, bobCookie)
|
||||
if getRR.Code != http.StatusOK {
|
||||
t.Fatalf("bob get after grant status = %d body = %s", getRR.Code, getRR.Body.String())
|
||||
}
|
||||
|
||||
editRR := doJSON(t, api, http.MethodPost, "/drawings/"+drawing.ID+"/revisions", map[string]any{
|
||||
"snapshot": map[string]any{"type": "excalidraw", "elements": []any{}},
|
||||
}, bobCookie)
|
||||
if editRR.Code != http.StatusForbidden {
|
||||
t.Fatalf("bob edit with view grant status = %d body = %s", editRR.Code, editRR.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestShareLinkAllowsUnauthenticatedDrawingRead(t *testing.T) {
|
||||
api, cleanup := newTestAPI(t)
|
||||
defer cleanup()
|
||||
|
||||
aliceCookie, _, aliceTeam := signup(t, api, "alice@example.com")
|
||||
createRR := doJSON(t, api, http.MethodPost, "/drawings", map[string]any{
|
||||
"team_id": aliceTeam.ID,
|
||||
"title": "Shared map",
|
||||
}, aliceCookie)
|
||||
if createRR.Code != http.StatusCreated {
|
||||
t.Fatalf("create drawing status = %d body = %s", createRR.Code, createRR.Body.String())
|
||||
}
|
||||
var drawing Drawing
|
||||
if err := json.Unmarshal(createRR.Body.Bytes(), &drawing); err != nil {
|
||||
t.Fatalf("drawing decode error = %v", err)
|
||||
}
|
||||
|
||||
shareRR := doJSON(t, api, http.MethodPost, "/drawings/"+drawing.ID+"/share-links", map[string]any{
|
||||
"permission": "view",
|
||||
}, aliceCookie)
|
||||
if shareRR.Code != http.StatusCreated {
|
||||
t.Fatalf("share status = %d body = %s", shareRR.Code, shareRR.Body.String())
|
||||
}
|
||||
var share struct {
|
||||
ShareLink ShareLink `json:"share_link"`
|
||||
Token string `json:"token"`
|
||||
}
|
||||
if err := json.Unmarshal(shareRR.Body.Bytes(), &share); err != nil {
|
||||
t.Fatalf("share decode error = %v", err)
|
||||
}
|
||||
if share.Token == "" || share.ShareLink.TokenHash != "" {
|
||||
t.Fatalf("unexpected share response: %#v", share)
|
||||
}
|
||||
|
||||
publicRR := doJSON(t, api, http.MethodGet, "/shared/"+share.Token, nil)
|
||||
if publicRR.Code != http.StatusOK {
|
||||
t.Fatalf("public shared status = %d body = %s", publicRR.Code, publicRR.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmbedsRejectUnsafeURLs(t *testing.T) {
|
||||
api, cleanup := newTestAPI(t)
|
||||
defer cleanup()
|
||||
|
||||
cookie, _, team := signup(t, api, "alice@example.com")
|
||||
createRR := doJSON(t, api, http.MethodPost, "/drawings", map[string]any{
|
||||
"team_id": team.ID,
|
||||
"title": "Embed map",
|
||||
}, cookie)
|
||||
if createRR.Code != http.StatusCreated {
|
||||
t.Fatalf("create drawing status = %d body = %s", createRR.Code, createRR.Body.String())
|
||||
}
|
||||
var drawing Drawing
|
||||
if err := json.Unmarshal(createRR.Body.Bytes(), &drawing); err != nil {
|
||||
t.Fatalf("drawing decode error = %v", err)
|
||||
}
|
||||
|
||||
for _, unsafeURL := range []string{"javascript:alert(1)", "http://127.0.0.1/admin", "http://localhost:3002"} {
|
||||
rr := doJSON(t, api, http.MethodPost, "/drawings/"+drawing.ID+"/embeds", map[string]any{
|
||||
"source_url": unsafeURL,
|
||||
"embed_type": "link",
|
||||
}, cookie)
|
||||
if rr.Code != http.StatusBadRequest {
|
||||
t.Fatalf("unsafe url %q status = %d body = %s", unsafeURL, rr.Code, rr.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
safeRR := doJSON(t, api, http.MethodPost, "/drawings/"+drawing.ID+"/embeds", map[string]any{
|
||||
"source_url": "https://example.com/roadmap",
|
||||
"embed_type": "link",
|
||||
}, cookie)
|
||||
if safeRR.Code != http.StatusCreated {
|
||||
t.Fatalf("safe embed status = %d body = %s", safeRR.Code, safeRR.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestAssetsAndLinkReferences(t *testing.T) {
|
||||
api, cleanup := newTestAPI(t)
|
||||
defer cleanup()
|
||||
|
||||
cookie, _, team := signup(t, api, "alice@example.com")
|
||||
projectRR := doJSON(t, api, http.MethodPost, "/projects", map[string]any{
|
||||
"team_id": team.ID,
|
||||
"name": "Roadmap",
|
||||
}, cookie)
|
||||
if projectRR.Code != http.StatusCreated {
|
||||
t.Fatalf("project status = %d body = %s", projectRR.Code, projectRR.Body.String())
|
||||
}
|
||||
var project Project
|
||||
if err := json.Unmarshal(projectRR.Body.Bytes(), &project); err != nil {
|
||||
t.Fatalf("project decode error = %v", err)
|
||||
}
|
||||
drawingRR := doJSON(t, api, http.MethodPost, "/drawings", map[string]any{
|
||||
"team_id": team.ID,
|
||||
"title": "Linked map",
|
||||
}, cookie)
|
||||
if drawingRR.Code != http.StatusCreated {
|
||||
t.Fatalf("drawing status = %d body = %s", drawingRR.Code, drawingRR.Body.String())
|
||||
}
|
||||
var drawing Drawing
|
||||
if err := json.Unmarshal(drawingRR.Body.Bytes(), &drawing); err != nil {
|
||||
t.Fatalf("drawing decode error = %v", err)
|
||||
}
|
||||
|
||||
assetRR := doJSON(t, api, http.MethodPost, "/drawings/"+drawing.ID+"/assets", map[string]any{
|
||||
"kind": "attachment",
|
||||
"mime_type": "image/png",
|
||||
"size": 2048,
|
||||
"width": 800,
|
||||
"height": 600,
|
||||
}, cookie)
|
||||
if assetRR.Code != http.StatusCreated {
|
||||
t.Fatalf("asset status = %d body = %s", assetRR.Code, assetRR.Body.String())
|
||||
}
|
||||
|
||||
linkRR := doJSON(t, api, http.MethodPost, "/drawings/"+drawing.ID+"/links", map[string]any{
|
||||
"target_resource_type": "project",
|
||||
"target_resource_id": project.ID,
|
||||
"label": "Roadmap project",
|
||||
}, cookie)
|
||||
if linkRR.Code != http.StatusCreated {
|
||||
t.Fatalf("link status = %d body = %s", linkRR.Code, linkRR.Body.String())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,53 @@
|
||||
package workspace
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
func (s *Store) WorkspaceStats(ctx context.Context, userID, teamID string) (*WorkspaceStats, error) {
|
||||
if teamID != "" {
|
||||
if ok, err := s.UserCanAccessTeam(ctx, userID, teamID); err != nil || !ok {
|
||||
return nil, ErrForbidden
|
||||
}
|
||||
return s.workspaceStatsForWhere(ctx, `m.user_id = ? AND t.id = ?`, userID, teamID)
|
||||
}
|
||||
return s.workspaceStatsForWhere(ctx, `m.user_id = ?`, userID)
|
||||
}
|
||||
|
||||
func (s *Store) workspaceStatsForWhere(ctx context.Context, membershipWhere string, args ...any) (*WorkspaceStats, error) {
|
||||
stats := &WorkspaceStats{}
|
||||
teamWhere := `id IN (SELECT t.id FROM workspace_teams t JOIN workspace_team_memberships m ON m.team_id = t.id WHERE ` + membershipWhere + `)`
|
||||
if err := s.count(ctx, &stats.Teams, `SELECT COUNT(*) FROM workspace_teams WHERE `+teamWhere, args...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := s.count(ctx, &stats.Members, `SELECT COUNT(*) FROM workspace_team_memberships WHERE team_id IN (SELECT id FROM workspace_teams WHERE `+teamWhere+`)`, args...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := s.count(ctx, &stats.Projects, `SELECT COUNT(*) FROM workspace_projects WHERE team_id IN (SELECT id FROM workspace_teams WHERE `+teamWhere+`)`, args...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := s.count(ctx, &stats.Folders, `SELECT COUNT(*) FROM workspace_folders WHERE team_id IN (SELECT id FROM workspace_teams WHERE `+teamWhere+`)`, args...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := s.count(ctx, &stats.Drawings, `SELECT COUNT(*) FROM workspace_drawings WHERE deleted_at IS NULL AND team_id IN (SELECT id FROM workspace_teams WHERE `+teamWhere+`)`, args...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := s.count(ctx, &stats.Templates, `SELECT COUNT(*) FROM workspace_templates WHERE scope = 'system' OR team_id IN (SELECT id FROM workspace_teams WHERE `+teamWhere+`)`, args...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := s.count(ctx, &stats.Revisions, `SELECT COUNT(*) FROM workspace_drawing_revisions WHERE drawing_id IN (SELECT id FROM workspace_drawings WHERE team_id IN (SELECT id FROM workspace_teams WHERE `+teamWhere+`))`, args...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := s.count(ctx, &stats.Assets, `SELECT COUNT(*) FROM workspace_drawing_assets WHERE drawing_id IN (SELECT id FROM workspace_drawings WHERE team_id IN (SELECT id FROM workspace_teams WHERE `+teamWhere+`))`, args...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err := s.db.QueryRowContext(ctx, `SELECT COALESCE(SUM(snapshot_size), 0) FROM workspace_drawing_revisions WHERE drawing_id IN (SELECT id FROM workspace_drawings WHERE team_id IN (SELECT id FROM workspace_teams WHERE `+teamWhere+`))`, args...).Scan(&stats.StorageBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (s *Store) count(ctx context.Context, dest *int, query string, args ...any) error {
|
||||
return s.db.QueryRowContext(ctx, query, args...).Scan(dest)
|
||||
}
|
||||
+1205
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,725 @@
|
||||
package workspace
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type CreateInviteRequest struct {
|
||||
Email string `json:"email"`
|
||||
Role string `json:"role"`
|
||||
}
|
||||
|
||||
type CreatePermissionGrantRequest struct {
|
||||
SubjectType string `json:"subject_type"`
|
||||
SubjectID string `json:"subject_id"`
|
||||
Email string `json:"email"`
|
||||
Permission string `json:"permission"`
|
||||
}
|
||||
|
||||
type CreateShareLinkRequest struct {
|
||||
Permission string `json:"permission"`
|
||||
ExpiresAt *time.Time `json:"expires_at"`
|
||||
}
|
||||
|
||||
type CreateAssetRequest struct {
|
||||
Kind string `json:"kind"`
|
||||
MimeType string `json:"mime_type"`
|
||||
Size int64 `json:"size"`
|
||||
Width *int `json:"width"`
|
||||
Height *int `json:"height"`
|
||||
}
|
||||
|
||||
type CreateEmbedRequest struct {
|
||||
SourceURL string `json:"source_url"`
|
||||
EmbedType string `json:"embed_type"`
|
||||
Title *string `json:"title"`
|
||||
}
|
||||
|
||||
type CreateLinkRequest struct {
|
||||
TargetResourceType string `json:"target_resource_type"`
|
||||
TargetResourceID string `json:"target_resource_id"`
|
||||
Label *string `json:"label"`
|
||||
}
|
||||
|
||||
func (s *Store) ListTeamInvites(ctx context.Context, userID, teamID string) ([]TeamInvite, error) {
|
||||
if err := s.ensureTeamPermission(ctx, userID, teamID, "invite"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rows, err := s.db.QueryContext(ctx, `SELECT id, team_id, email, role, invited_by, expires_at, created_at
|
||||
FROM workspace_team_invites
|
||||
WHERE team_id = ? AND accepted_at IS NULL AND revoked_at IS NULL AND expires_at > ?
|
||||
ORDER BY created_at DESC`, teamID, time.Now().UTC())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
invites := []TeamInvite{}
|
||||
for rows.Next() {
|
||||
var invite TeamInvite
|
||||
if err := rows.Scan(&invite.ID, &invite.TeamID, &invite.Email, &invite.Role, &invite.InvitedBy, &invite.ExpiresAt, &invite.CreatedAt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
invites = append(invites, invite)
|
||||
}
|
||||
return invites, rows.Err()
|
||||
}
|
||||
|
||||
func (s *Store) CreateTeamInvite(ctx context.Context, userID, teamID string, req CreateInviteRequest) (*TeamInvite, string, error) {
|
||||
if err := s.ensureTeamPermission(ctx, userID, teamID, "invite"); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
email, err := normalizeEmail(req.Email)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
role := req.Role
|
||||
if role == "" {
|
||||
role = "viewer"
|
||||
}
|
||||
if !validTeamRole(role) || role == "owner" {
|
||||
return nil, "", fmt.Errorf("invalid invite role")
|
||||
}
|
||||
token, err := randomToken()
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
invite := &TeamInvite{
|
||||
ID: newID(),
|
||||
TeamID: teamID,
|
||||
Email: email,
|
||||
Role: role,
|
||||
InvitedBy: userID,
|
||||
ExpiresAt: now.Add(14 * 24 * time.Hour),
|
||||
CreatedAt: now,
|
||||
}
|
||||
_, err = s.db.ExecContext(ctx, `INSERT INTO workspace_team_invites
|
||||
(id, team_id, email, role, token_hash, invited_by, expires_at, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
invite.ID, invite.TeamID, invite.Email, invite.Role, hashToken(token), invite.InvitedBy, invite.ExpiresAt, invite.CreatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
_ = s.insertActivity(ctx, &userID, &teamID, "team", teamID, "member_invited", map[string]any{"email": email, "role": role})
|
||||
return invite, token, nil
|
||||
}
|
||||
|
||||
func (s *Store) AcceptInvite(ctx context.Context, userID, token string) (*TeamMembership, error) {
|
||||
if strings.TrimSpace(token) == "" {
|
||||
return nil, fmt.Errorf("invite token is required")
|
||||
}
|
||||
user, err := s.userByID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var invite TeamInvite
|
||||
err = s.db.QueryRowContext(ctx, `SELECT id, team_id, email, role, invited_by, expires_at, created_at
|
||||
FROM workspace_team_invites
|
||||
WHERE token_hash = ? AND accepted_at IS NULL AND revoked_at IS NULL AND expires_at > ?`,
|
||||
hashToken(token), time.Now().UTC(),
|
||||
).Scan(&invite.ID, &invite.TeamID, &invite.Email, &invite.Role, &invite.InvitedBy, &invite.ExpiresAt, &invite.CreatedAt)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !strings.EqualFold(user.Email, invite.Email) {
|
||||
return nil, ErrForbidden
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
tx, err := s.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
_, err = tx.ExecContext(ctx, `INSERT INTO workspace_team_memberships (id, team_id, user_id, role, joined_at)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
ON CONFLICT(team_id, user_id) DO UPDATE SET role = excluded.role`, newID(), invite.TeamID, userID, invite.Role, now)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_, err = tx.ExecContext(ctx, `UPDATE workspace_team_invites SET accepted_at = ? WHERE id = ?`, now, invite.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := insertActivityTx(ctx, tx, &userID, &invite.TeamID, "team", invite.TeamID, "member_joined", map[string]any{"role": invite.Role}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.teamMembership(ctx, invite.TeamID, userID)
|
||||
}
|
||||
|
||||
func (s *Store) ListPermissionGrants(ctx context.Context, userID, resourceType, resourceID string) ([]PermissionGrant, error) {
|
||||
if resourceType == "drawing" {
|
||||
if err := s.ensureDrawingAccess(ctx, userID, resourceID, "manage"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
rows, err := s.db.QueryContext(ctx, `SELECT id, resource_type, resource_id, subject_type, subject_id, permission, inherited_from, created_at
|
||||
FROM workspace_permission_grants
|
||||
WHERE resource_type = ? AND resource_id = ?
|
||||
ORDER BY created_at DESC`, resourceType, resourceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
grants := []PermissionGrant{}
|
||||
for rows.Next() {
|
||||
var grant PermissionGrant
|
||||
if err := rows.Scan(&grant.ID, &grant.ResourceType, &grant.ResourceID, &grant.SubjectType, &grant.SubjectID, &grant.Permission, &grant.InheritedFrom, &grant.CreatedAt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
grants = append(grants, grant)
|
||||
}
|
||||
return grants, rows.Err()
|
||||
}
|
||||
|
||||
func (s *Store) CreateDrawingPermissionGrant(ctx context.Context, userID, drawingID string, req CreatePermissionGrantRequest) (*PermissionGrant, error) {
|
||||
if err := s.ensureDrawingAccess(ctx, userID, drawingID, "share"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !validPermission(req.Permission) {
|
||||
return nil, fmt.Errorf("invalid permission")
|
||||
}
|
||||
subjectType := req.SubjectType
|
||||
if subjectType == "" {
|
||||
subjectType = "user"
|
||||
}
|
||||
subjectID := strings.TrimSpace(req.SubjectID)
|
||||
if subjectType == "user" && subjectID == "" {
|
||||
user, err := s.userByEmail(ctx, req.Email)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
subjectID = user.ID
|
||||
}
|
||||
if subjectType != "user" && subjectType != "team" {
|
||||
return nil, fmt.Errorf("invalid subject type")
|
||||
}
|
||||
if subjectID == "" {
|
||||
return nil, fmt.Errorf("subject id is required")
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
grant := &PermissionGrant{
|
||||
ID: newID(),
|
||||
ResourceType: "drawing",
|
||||
ResourceID: drawingID,
|
||||
SubjectType: subjectType,
|
||||
SubjectID: subjectID,
|
||||
Permission: req.Permission,
|
||||
CreatedAt: now,
|
||||
}
|
||||
_, err := s.db.ExecContext(ctx, `INSERT INTO workspace_permission_grants
|
||||
(id, resource_type, resource_id, subject_type, subject_id, permission, inherited_from, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(resource_type, resource_id, subject_type, subject_id, permission) DO UPDATE SET created_at = excluded.created_at`,
|
||||
grant.ID, grant.ResourceType, grant.ResourceID, grant.SubjectType, grant.SubjectID, grant.Permission, grant.InheritedFrom, grant.CreatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_ = s.insertActivity(ctx, &userID, nil, "drawing", drawingID, "permission_changed", map[string]any{"permission": grant.Permission, "subject_type": grant.SubjectType})
|
||||
return grant, nil
|
||||
}
|
||||
|
||||
func (s *Store) ListShareLinks(ctx context.Context, userID, resourceType, resourceID string) ([]ShareLink, error) {
|
||||
if resourceType == "drawing" {
|
||||
if err := s.ensureDrawingAccess(ctx, userID, resourceID, "share"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
rows, err := s.db.QueryContext(ctx, `SELECT id, resource_type, resource_id, token_hash, permission, expires_at, password_hash, created_by, revoked_at, created_at
|
||||
FROM workspace_share_links
|
||||
WHERE resource_type = ? AND resource_id = ? AND revoked_at IS NULL
|
||||
ORDER BY created_at DESC`, resourceType, resourceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
links := []ShareLink{}
|
||||
for rows.Next() {
|
||||
link, err := scanShareLink(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
link.TokenHash = ""
|
||||
links = append(links, *link)
|
||||
}
|
||||
return links, rows.Err()
|
||||
}
|
||||
|
||||
func (s *Store) CreateDrawingShareLink(ctx context.Context, userID, drawingID string, req CreateShareLinkRequest) (*ShareLink, string, error) {
|
||||
if err := s.ensureDrawingAccess(ctx, userID, drawingID, "share"); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
permission := req.Permission
|
||||
if permission == "" {
|
||||
permission = "view"
|
||||
}
|
||||
if permission != "view" && permission != "comment" && permission != "edit" {
|
||||
return nil, "", fmt.Errorf("invalid share permission")
|
||||
}
|
||||
token, err := randomToken()
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
link := &ShareLink{
|
||||
ID: newID(),
|
||||
ResourceType: "drawing",
|
||||
ResourceID: drawingID,
|
||||
TokenHash: hashToken(token),
|
||||
Permission: permission,
|
||||
ExpiresAt: req.ExpiresAt,
|
||||
CreatedBy: userID,
|
||||
CreatedAt: now,
|
||||
}
|
||||
_, err = s.db.ExecContext(ctx, `INSERT INTO workspace_share_links
|
||||
(id, resource_type, resource_id, token_hash, permission, expires_at, password_hash, created_by, revoked_at, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
link.ID, link.ResourceType, link.ResourceID, link.TokenHash, link.Permission, link.ExpiresAt, link.PasswordHash, link.CreatedBy, link.RevokedAt, link.CreatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
link.TokenHash = ""
|
||||
_ = s.insertActivity(ctx, &userID, nil, "drawing", drawingID, "drawing_shared", map[string]any{"permission": permission})
|
||||
return link, token, nil
|
||||
}
|
||||
|
||||
func (s *Store) SharedResourceByToken(ctx context.Context, token string) (map[string]any, error) {
|
||||
if strings.TrimSpace(token) == "" {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
row := s.db.QueryRowContext(ctx, `SELECT id, resource_type, resource_id, token_hash, permission, expires_at, password_hash, created_by, revoked_at, created_at
|
||||
FROM workspace_share_links
|
||||
WHERE token_hash = ? AND revoked_at IS NULL AND (expires_at IS NULL OR expires_at > ?)`,
|
||||
hashToken(token), time.Now().UTC(),
|
||||
)
|
||||
link, err := scanShareLink(row)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
link.TokenHash = ""
|
||||
payload := map[string]any{"share_link": link}
|
||||
if link.ResourceType == "drawing" {
|
||||
drawing, err := s.drawingByIDNoAuth(ctx, link.ResourceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
payload["drawing"] = drawing
|
||||
return payload, nil
|
||||
}
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
func (s *Store) ListDrawingAssets(ctx context.Context, userID, drawingID string) ([]DrawingAsset, error) {
|
||||
if err := s.ensureDrawingAccess(ctx, userID, drawingID, "view"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rows, err := s.db.QueryContext(ctx, `SELECT id, drawing_id, kind, path, mime_type, size, width, height, uploaded_by, created_at
|
||||
FROM workspace_drawing_assets WHERE drawing_id = ? ORDER BY created_at DESC`, drawingID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanAssets(rows)
|
||||
}
|
||||
|
||||
func (s *Store) CreateDrawingAsset(ctx context.Context, userID, drawingID string, req CreateAssetRequest) (*DrawingAsset, error) {
|
||||
if err := s.ensureDrawingAccess(ctx, userID, drawingID, "edit"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !validAssetKind(req.Kind) {
|
||||
return nil, fmt.Errorf("invalid asset kind")
|
||||
}
|
||||
if !validAssetMIME(req.MimeType) {
|
||||
return nil, fmt.Errorf("invalid asset mime type")
|
||||
}
|
||||
if req.Size <= 0 || req.Size > 25<<20 {
|
||||
return nil, fmt.Errorf("asset size must be between 1 byte and 25 MiB")
|
||||
}
|
||||
drawing, err := s.drawingByIDNoAuth(ctx, drawingID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
asset := &DrawingAsset{
|
||||
ID: newID(),
|
||||
DrawingID: drawingID,
|
||||
Kind: req.Kind,
|
||||
MimeType: req.MimeType,
|
||||
Size: req.Size,
|
||||
Width: req.Width,
|
||||
Height: req.Height,
|
||||
UploadedBy: userID,
|
||||
CreatedAt: now,
|
||||
}
|
||||
asset.Path = fmt.Sprintf("/data/teams/%s/drawings/%s/assets/%s", drawing.TeamID, drawingID, asset.ID)
|
||||
_, err = s.db.ExecContext(ctx, `INSERT INTO workspace_drawing_assets
|
||||
(id, drawing_id, kind, path, mime_type, size, width, height, uploaded_by, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
asset.ID, asset.DrawingID, asset.Kind, asset.Path, asset.MimeType, asset.Size, asset.Width, asset.Height, asset.UploadedBy, asset.CreatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return asset, nil
|
||||
}
|
||||
|
||||
func (s *Store) ListEmbeds(ctx context.Context, userID, drawingID string) ([]Embed, error) {
|
||||
if err := s.ensureDrawingAccess(ctx, userID, drawingID, "view"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rows, err := s.db.QueryContext(ctx, `SELECT id, drawing_id, source_url, canonical_url, provider, embed_type, title, preview_asset_id, safe_embed_html, created_by, created_at
|
||||
FROM workspace_embeds WHERE drawing_id = ? ORDER BY created_at DESC`, drawingID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
embeds := []Embed{}
|
||||
for rows.Next() {
|
||||
var embed Embed
|
||||
if err := rows.Scan(&embed.ID, &embed.DrawingID, &embed.SourceURL, &embed.CanonicalURL, &embed.Provider, &embed.EmbedType, &embed.Title, &embed.PreviewAssetID, &embed.SafeEmbedHTML, &embed.CreatedBy, &embed.CreatedAt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
embeds = append(embeds, embed)
|
||||
}
|
||||
return embeds, rows.Err()
|
||||
}
|
||||
|
||||
func (s *Store) CreateEmbed(ctx context.Context, userID, drawingID string, req CreateEmbedRequest) (*Embed, error) {
|
||||
if err := s.ensureDrawingAccess(ctx, userID, drawingID, "edit"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
canonical, provider, err := validateEmbedURL(req.SourceURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
embedType := req.EmbedType
|
||||
if embedType == "" {
|
||||
embedType = "link"
|
||||
}
|
||||
if embedType != "link" && embedType != "iframe" && embedType != "provider" {
|
||||
return nil, fmt.Errorf("invalid embed type")
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
embed := &Embed{
|
||||
ID: newID(),
|
||||
DrawingID: drawingID,
|
||||
SourceURL: canonical,
|
||||
CanonicalURL: canonical,
|
||||
Provider: provider,
|
||||
EmbedType: embedType,
|
||||
Title: req.Title,
|
||||
CreatedBy: userID,
|
||||
CreatedAt: now,
|
||||
}
|
||||
_, err = s.db.ExecContext(ctx, `INSERT INTO workspace_embeds
|
||||
(id, drawing_id, source_url, canonical_url, provider, embed_type, title, preview_asset_id, safe_embed_html, created_by, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
embed.ID, embed.DrawingID, embed.SourceURL, embed.CanonicalURL, embed.Provider, embed.EmbedType, embed.Title, embed.PreviewAssetID, embed.SafeEmbedHTML, embed.CreatedBy, embed.CreatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_ = s.insertActivity(ctx, &userID, nil, "drawing", drawingID, "embed_created", map[string]any{"provider": provider, "embed_type": embedType})
|
||||
return embed, nil
|
||||
}
|
||||
|
||||
func (s *Store) ListLinkReferences(ctx context.Context, userID, resourceType, resourceID string) ([]LinkReference, error) {
|
||||
if resourceType == "drawing" {
|
||||
if err := s.ensureDrawingAccess(ctx, userID, resourceID, "view"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
rows, err := s.db.QueryContext(ctx, `SELECT id, source_resource_type, source_resource_id, target_resource_type, target_resource_id, label, created_by, created_at
|
||||
FROM workspace_link_references
|
||||
WHERE source_resource_type = ? AND source_resource_id = ?
|
||||
ORDER BY created_at DESC`, resourceType, resourceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
links := []LinkReference{}
|
||||
for rows.Next() {
|
||||
var link LinkReference
|
||||
if err := rows.Scan(&link.ID, &link.SourceResourceType, &link.SourceResourceID, &link.TargetResourceType, &link.TargetResourceID, &link.Label, &link.CreatedBy, &link.CreatedAt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
links = append(links, link)
|
||||
}
|
||||
return links, rows.Err()
|
||||
}
|
||||
|
||||
func (s *Store) CreateDrawingLinkReference(ctx context.Context, userID, drawingID string, req CreateLinkRequest) (*LinkReference, error) {
|
||||
if err := s.ensureDrawingAccess(ctx, userID, drawingID, "edit"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
drawing, err := s.drawingByIDNoAuth(ctx, drawingID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := s.ensureTargetInTeam(ctx, drawing.TeamID, req.TargetResourceType, req.TargetResourceID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
link := &LinkReference{
|
||||
ID: newID(),
|
||||
SourceResourceType: "drawing",
|
||||
SourceResourceID: drawingID,
|
||||
TargetResourceType: req.TargetResourceType,
|
||||
TargetResourceID: req.TargetResourceID,
|
||||
Label: req.Label,
|
||||
CreatedBy: userID,
|
||||
CreatedAt: now,
|
||||
}
|
||||
_, err = s.db.ExecContext(ctx, `INSERT INTO workspace_link_references
|
||||
(id, source_resource_type, source_resource_id, target_resource_type, target_resource_id, label, created_by, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
link.ID, link.SourceResourceType, link.SourceResourceID, link.TargetResourceType, link.TargetResourceID, link.Label, link.CreatedBy, link.CreatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return link, nil
|
||||
}
|
||||
|
||||
func (s *Store) ensureTeamPermission(ctx context.Context, userID, teamID, permission string) error {
|
||||
role, err := s.teamRole(ctx, userID, teamID)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return ErrForbidden
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if roleAllows(role, permission) {
|
||||
return nil
|
||||
}
|
||||
return ErrForbidden
|
||||
}
|
||||
|
||||
func (s *Store) teamRole(ctx context.Context, userID, teamID string) (string, error) {
|
||||
var role string
|
||||
err := s.db.QueryRowContext(ctx, `SELECT role FROM workspace_team_memberships WHERE user_id = ? AND team_id = ?`, userID, teamID).Scan(&role)
|
||||
return role, err
|
||||
}
|
||||
|
||||
func (s *Store) userByID(ctx context.Context, userID string) (*User, error) {
|
||||
var user User
|
||||
err := s.db.QueryRowContext(ctx, `SELECT id, name, username, email, avatar_url, locale, timezone, created_at, updated_at FROM workspace_users WHERE id = ?`, userID).
|
||||
Scan(&user.ID, &user.Name, &user.Username, &user.Email, &user.AvatarURL, &user.Locale, &user.Timezone, &user.CreatedAt, &user.UpdatedAt)
|
||||
return &user, err
|
||||
}
|
||||
|
||||
func (s *Store) userByEmail(ctx context.Context, email string) (*User, error) {
|
||||
email, err := normalizeEmail(email)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var user User
|
||||
err = s.db.QueryRowContext(ctx, `SELECT id, name, username, email, avatar_url, locale, timezone, created_at, updated_at FROM workspace_users WHERE email = ?`, email).
|
||||
Scan(&user.ID, &user.Name, &user.Username, &user.Email, &user.AvatarURL, &user.Locale, &user.Timezone, &user.CreatedAt, &user.UpdatedAt)
|
||||
return &user, err
|
||||
}
|
||||
|
||||
func (s *Store) teamMembership(ctx context.Context, teamID, userID string) (*TeamMembership, error) {
|
||||
var member TeamMembership
|
||||
err := s.db.QueryRowContext(ctx, `SELECT id, team_id, user_id, role, joined_at FROM workspace_team_memberships WHERE team_id = ? AND user_id = ?`, teamID, userID).
|
||||
Scan(&member.ID, &member.TeamID, &member.UserID, &member.Role, &member.JoinedAt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
user, err := s.userByID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
member.User = user
|
||||
return &member, nil
|
||||
}
|
||||
|
||||
func (s *Store) drawingByIDNoAuth(ctx context.Context, drawingID string) (*Drawing, error) {
|
||||
row := s.db.QueryRowContext(ctx, `SELECT d.id, d.team_id, d.folder_id, d.project_id, d.slug, d.title, d.description,
|
||||
d.owner_user_id, d.latest_revision_id, d.visibility, d.is_archived, d.thumbnail_asset_id, d.created_at, d.updated_at, d.deleted_at,
|
||||
u.id, u.name, u.username, u.email, u.avatar_url, u.locale, u.timezone, u.created_at, u.updated_at
|
||||
FROM workspace_drawings d
|
||||
JOIN workspace_users u ON u.id = d.owner_user_id
|
||||
WHERE d.id = ? AND d.deleted_at IS NULL`, drawingID)
|
||||
return scanDrawing(row)
|
||||
}
|
||||
|
||||
func (s *Store) ensureTargetInTeam(ctx context.Context, teamID, resourceType, resourceID string) error {
|
||||
var found int
|
||||
var query string
|
||||
switch resourceType {
|
||||
case "drawing":
|
||||
query = `SELECT 1 FROM workspace_drawings WHERE id = ? AND team_id = ? AND deleted_at IS NULL`
|
||||
case "folder":
|
||||
query = `SELECT 1 FROM workspace_folders WHERE id = ? AND team_id = ?`
|
||||
case "project":
|
||||
query = `SELECT 1 FROM workspace_projects WHERE id = ? AND team_id = ?`
|
||||
case "embed":
|
||||
query = `SELECT 1 FROM workspace_embeds e JOIN workspace_drawings d ON d.id = e.drawing_id WHERE e.id = ? AND d.team_id = ?`
|
||||
default:
|
||||
return fmt.Errorf("invalid target resource type")
|
||||
}
|
||||
err := s.db.QueryRowContext(ctx, query, resourceID, teamID).Scan(&found)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return ErrNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Store) grantAllows(ctx context.Context, resourceType, resourceID, userID, teamID, required string) (bool, error) {
|
||||
rows, err := s.db.QueryContext(ctx, `SELECT permission FROM workspace_permission_grants
|
||||
WHERE resource_type = ? AND resource_id = ? AND (
|
||||
(subject_type = 'user' AND subject_id = ?) OR
|
||||
(subject_type = 'team' AND subject_id = ?)
|
||||
)`, resourceType, resourceID, userID, teamID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var permission string
|
||||
if err := rows.Scan(&permission); err != nil {
|
||||
return false, err
|
||||
}
|
||||
if permissionAllows(permission, required) {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, rows.Err()
|
||||
}
|
||||
|
||||
func scanShareLink(scanner interface{ Scan(dest ...any) error }) (*ShareLink, error) {
|
||||
var link ShareLink
|
||||
err := scanner.Scan(&link.ID, &link.ResourceType, &link.ResourceID, &link.TokenHash, &link.Permission, &link.ExpiresAt, &link.PasswordHash, &link.CreatedBy, &link.RevokedAt, &link.CreatedAt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &link, nil
|
||||
}
|
||||
|
||||
func scanAssets(rows *sql.Rows) ([]DrawingAsset, error) {
|
||||
assets := []DrawingAsset{}
|
||||
for rows.Next() {
|
||||
var asset DrawingAsset
|
||||
if err := rows.Scan(&asset.ID, &asset.DrawingID, &asset.Kind, &asset.Path, &asset.MimeType, &asset.Size, &asset.Width, &asset.Height, &asset.UploadedBy, &asset.CreatedAt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
assets = append(assets, asset)
|
||||
}
|
||||
return assets, rows.Err()
|
||||
}
|
||||
|
||||
func roleAllows(role, permission string) bool {
|
||||
switch role {
|
||||
case "owner", "admin":
|
||||
return true
|
||||
case "editor":
|
||||
return permission == "view" || permission == "comment" || permission == "edit"
|
||||
case "viewer":
|
||||
return permission == "view"
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func permissionAllows(grant, required string) bool {
|
||||
if grant == required {
|
||||
return true
|
||||
}
|
||||
switch grant {
|
||||
case "manage":
|
||||
return true
|
||||
case "edit":
|
||||
return required == "view" || required == "comment" || required == "edit"
|
||||
case "comment":
|
||||
return required == "view" || required == "comment"
|
||||
case "share":
|
||||
return required == "view" || required == "share"
|
||||
case "invite":
|
||||
return required == "view" || required == "invite"
|
||||
case "view":
|
||||
return required == "view"
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func validPermission(permission string) bool {
|
||||
switch permission {
|
||||
case "view", "comment", "edit", "manage", "share", "invite":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func validTeamRole(role string) bool {
|
||||
switch role {
|
||||
case "owner", "admin", "editor", "viewer":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func validAssetKind(kind string) bool {
|
||||
switch kind {
|
||||
case "image", "export", "attachment", "thumbnail":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func validAssetMIME(mimeType string) bool {
|
||||
switch mimeType {
|
||||
case "image/png", "image/jpeg", "image/webp", "image/gif", "application/pdf", "application/json", "text/plain":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func validateEmbedURL(raw string) (string, string, error) {
|
||||
parsed, err := url.Parse(strings.TrimSpace(raw))
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("invalid URL")
|
||||
}
|
||||
if parsed.Scheme != "http" && parsed.Scheme != "https" {
|
||||
return "", "", fmt.Errorf("embed URL must use http or https")
|
||||
}
|
||||
if parsed.User != nil {
|
||||
return "", "", fmt.Errorf("embed URL must not include credentials")
|
||||
}
|
||||
host := strings.ToLower(parsed.Hostname())
|
||||
if host == "" || host == "localhost" || strings.HasSuffix(host, ".localhost") {
|
||||
return "", "", fmt.Errorf("embed URL host is not allowed")
|
||||
}
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
addr, ok := netip.AddrFromSlice(ip)
|
||||
if !ok || !addr.IsGlobalUnicast() || addr.IsPrivate() || addr.IsLoopback() || addr.IsLinkLocalUnicast() {
|
||||
return "", "", fmt.Errorf("embed URL host is not allowed")
|
||||
}
|
||||
}
|
||||
parsed.Fragment = ""
|
||||
return parsed.String(), host, nil
|
||||
}
|
||||
Reference in New Issue
Block a user