feat: full project sync - CI fixes, frontend, workspace API, and all changes

This commit is contained in:
Tomas Dvorak
2026-04-27 09:08:07 +02:00
parent a07fca997e
commit 89b9390c14
109 changed files with 21120 additions and 545 deletions
+17
View File
@@ -0,0 +1,17 @@
package workspace
import "context"
type currentSession struct {
user *User
session *Session
}
func withUser(ctx context.Context, user *User, session *Session) context.Context {
return context.WithValue(ctx, currentUserKey, currentSession{user: user, session: session})
}
func currentUser(r interface{ Context() context.Context }) (*User, *Session) {
current, _ := r.Context().Value(currentUserKey).(currentSession)
return current.user, current.session
}
+660
View File
@@ -0,0 +1,660 @@
package workspace
import (
"database/sql"
"encoding/json"
"errors"
"net/http"
"strconv"
"strings"
"time"
"github.com/go-chi/chi/v5"
"github.com/sirupsen/logrus"
)
const sessionCookieName = "excalidraw_session"
type API struct {
store *Store
limiter *rateLimiter
testMode bool
}
func NewAPI(store *Store) *API {
return &API{
store: store,
limiter: newRateLimiter(10, 15*time.Minute),
}
}
func (a *API) Routes() chi.Router {
r := chi.NewRouter()
r.Group(func(r chi.Router) {
r.Get("/health", a.handleHealth)
r.Get("/auth/setup-status", a.handleSetupStatus)
r.Post("/auth/signup", a.handleSignup)
r.Post("/auth/login", a.handleLogin)
r.Post("/auth/logout", a.handleLogout)
r.Get("/shared/{token}", a.handleSharedResource)
})
r.Group(func(r chi.Router) {
r.Use(a.requireSession)
r.Use(requireSameOriginMutation)
r.Get("/auth/me", a.handleMe)
r.Get("/teams", a.handleListTeams)
r.Post("/teams", a.handleCreateTeam)
r.Patch("/teams/{teamID}", a.handleUpdateTeam)
r.Get("/teams/{teamID}/members", a.handleListTeamMembers)
r.Get("/teams/{teamID}/invites", a.handleListTeamInvites)
r.Post("/teams/{teamID}/invites", a.handleCreateTeamInvite)
r.Post("/teams/{teamID}/users", a.handleCreateTeamUser)
r.Post("/invites/accept", a.handleAcceptInvite)
r.Get("/drawings", a.handleListDrawings)
r.Post("/drawings", a.handleCreateDrawing)
r.Get("/drawings/{drawingID}", a.handleGetDrawing)
r.Patch("/drawings/{drawingID}", a.handleUpdateDrawing)
r.Delete("/drawings/{drawingID}", a.handleArchiveDrawing)
r.Get("/drawings/{drawingID}/revisions", a.handleListRevisions)
r.Post("/drawings/{drawingID}/revisions", a.handleCreateRevision)
r.Get("/search", a.handleSearch)
r.Get("/drawings/{drawingID}/permissions", a.handleListPermissions)
r.Post("/drawings/{drawingID}/permissions", a.handleCreatePermission)
r.Get("/drawings/{drawingID}/share-links", a.handleListShareLinks)
r.Post("/drawings/{drawingID}/share-links", a.handleCreateShareLink)
r.Get("/drawings/{drawingID}/assets", a.handleListAssets)
r.Post("/drawings/{drawingID}/assets", a.handleCreateAsset)
r.Get("/drawings/{drawingID}/embeds", a.handleListEmbeds)
r.Post("/drawings/{drawingID}/embeds", a.handleCreateEmbed)
r.Get("/drawings/{drawingID}/links", a.handleListLinks)
r.Post("/drawings/{drawingID}/links", a.handleCreateLink)
r.Get("/drawings/{drawingID}/thumbnail", a.handleThumbnail)
r.Get("/templates", a.handleListTemplates)
r.Get("/activity", a.handleListActivity)
r.Get("/stats", a.handleStats)
r.Get("/folders", a.handleListFolders)
r.Post("/folders", a.handleCreateFolder)
r.Get("/projects", a.handleListProjects)
r.Post("/projects", a.handleCreateProject)
})
return r
}
func (a *API) handleHealth(w http.ResponseWriter, r *http.Request) {
if err := a.store.Ping(r.Context()); err != nil {
writeJSON(w, http.StatusServiceUnavailable, map[string]any{"status": "unhealthy"})
return
}
writeJSON(w, http.StatusOK, map[string]any{"status": "ok"})
}
func (a *API) handleSetupStatus(w http.ResponseWriter, r *http.Request) {
hasUsers, err := a.store.UserExists(r.Context())
if err != nil {
writeError(w, http.StatusInternalServerError, "Failed to check setup status")
return
}
writeJSON(w, http.StatusOK, map[string]any{"has_users": hasUsers})
}
type contextKey string
const currentUserKey = contextKey("workspace_user")
func (a *API) requireSession(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
cookie, err := r.Cookie(sessionCookieName)
if err != nil || cookie.Value == "" {
writeError(w, http.StatusUnauthorized, "Authentication required")
return
}
user, session, err := a.store.UserBySessionToken(r.Context(), cookie.Value)
if err != nil {
clearSessionCookie(w, r)
writeError(w, http.StatusUnauthorized, "Authentication required")
return
}
ctx := withUser(r.Context(), user, session)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
func requireSameOriginMutation(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet, http.MethodHead, http.MethodOptions:
next.ServeHTTP(w, r)
return
}
origin := r.Header.Get("Origin")
if origin == "" {
next.ServeHTTP(w, r)
return
}
expectedHTTP := "http://" + r.Host
expectedHTTPS := "https://" + r.Host
if origin != expectedHTTP && origin != expectedHTTPS {
writeError(w, http.StatusForbidden, "Cross-origin mutation denied")
return
}
next.ServeHTTP(w, r)
})
}
func (a *API) handleSignup(w http.ResponseWriter, r *http.Request) {
var req struct {
Name string `json:"name"`
Email string `json:"email"`
Password string `json:"password"`
}
if !decodeJSON(w, r, &req, 64<<10) {
return
}
// First-run: only allow signup if no users exist yet
if !a.testMode {
hasUsers, err := a.store.UserExists(r.Context())
if err != nil {
writeError(w, http.StatusInternalServerError, "Failed to check setup status")
return
}
if hasUsers {
writeError(w, http.StatusForbidden, "Registration is closed. Contact an administrator.")
return
}
}
ipKey := "signup:" + clientIP(r)
if !a.limiter.allow(ipKey) {
writeError(w, http.StatusTooManyRequests, "Too many signup attempts")
return
}
user, session, token, err := a.store.CreateUserWithPassword(r.Context(), req.Name, req.Email, req.Password)
if err != nil {
status := http.StatusBadRequest
if errors.Is(err, ErrConflict) {
status = http.StatusConflict
}
writeError(w, status, err.Error())
return
}
setSessionCookie(w, r, token, session.ExpiresAt)
writeJSON(w, http.StatusCreated, map[string]any{"user": user, "session": session})
}
func (a *API) handleLogin(w http.ResponseWriter, r *http.Request) {
var req struct {
Email string `json:"email"`
Password string `json:"password"`
}
if !decodeJSON(w, r, &req, 32<<10) {
return
}
key := "login:" + clientIP(r) + ":" + strings.ToLower(strings.TrimSpace(req.Email))
if !a.limiter.allow(key) {
writeError(w, http.StatusTooManyRequests, "Too many login attempts")
return
}
user, session, token, err := a.store.AuthenticatePassword(r.Context(), req.Email, req.Password)
if err != nil {
writeError(w, http.StatusUnauthorized, "Invalid email or password")
return
}
setSessionCookie(w, r, token, session.ExpiresAt)
writeJSON(w, http.StatusOK, map[string]any{"user": user, "session": session})
}
func (a *API) handleLogout(w http.ResponseWriter, r *http.Request) {
if cookie, err := r.Cookie(sessionCookieName); err == nil && cookie.Value != "" {
if err := a.store.DeleteSession(r.Context(), cookie.Value); err != nil {
logrus.WithError(err).Warn("failed to delete session")
}
}
clearSessionCookie(w, r)
writeJSON(w, http.StatusOK, map[string]any{})
}
func (a *API) handleMe(w http.ResponseWriter, r *http.Request) {
user, _ := currentUser(r)
writeJSON(w, http.StatusOK, user)
}
func (a *API) handleListTeams(w http.ResponseWriter, r *http.Request) {
user, _ := currentUser(r)
teams, err := a.store.ListTeamsForUser(r.Context(), user.ID)
if err != nil {
writeError(w, http.StatusInternalServerError, "Failed to list teams")
return
}
writeJSON(w, http.StatusOK, teams)
}
func (a *API) handleCreateTeam(w http.ResponseWriter, r *http.Request) {
user, _ := currentUser(r)
var req struct {
Name string `json:"name"`
Slug string `json:"slug"`
}
if !decodeJSON(w, r, &req, 64<<10) {
return
}
team, err := a.store.CreateTeam(r.Context(), user.ID, req.Name, req.Slug)
if err != nil {
writeError(w, http.StatusBadRequest, err.Error())
return
}
writeJSON(w, http.StatusCreated, team)
}
func (a *API) handleUpdateTeam(w http.ResponseWriter, r *http.Request) {
user, _ := currentUser(r)
var req struct {
Name *string `json:"name"`
Slug *string `json:"slug"`
}
if !decodeJSON(w, r, &req, 64<<10) {
return
}
team, err := a.store.UpdateTeam(r.Context(), user.ID, chi.URLParam(r, "teamID"), req.Name, req.Slug)
if err != nil {
writeLookupError(w, err)
return
}
writeJSON(w, http.StatusOK, team)
}
func (a *API) handleListTeamMembers(w http.ResponseWriter, r *http.Request) {
user, _ := currentUser(r)
teamID := chi.URLParam(r, "teamID")
if ok, err := a.store.UserCanAccessTeam(r.Context(), user.ID, teamID); err != nil || !ok {
writeError(w, http.StatusForbidden, "Team access denied")
return
}
members, err := a.store.ListTeamMembers(r.Context(), teamID)
if err != nil {
writeError(w, http.StatusInternalServerError, "Failed to list team members")
return
}
writeJSON(w, http.StatusOK, members)
}
func (a *API) handleListDrawings(w http.ResponseWriter, r *http.Request) {
user, _ := currentUser(r)
teamID := strings.TrimSpace(r.URL.Query().Get("team_id"))
drawings, err := a.store.ListDrawings(r.Context(), user.ID, teamID)
if err != nil {
if errors.Is(err, ErrForbidden) {
writeError(w, http.StatusForbidden, "Team access denied")
return
}
writeError(w, http.StatusInternalServerError, "Failed to list drawings")
return
}
writeJSON(w, http.StatusOK, drawings)
}
func (a *API) handleCreateDrawing(w http.ResponseWriter, r *http.Request) {
user, _ := currentUser(r)
var req CreateDrawingRequest
if !decodeJSON(w, r, &req, 256<<10) {
return
}
drawing, err := a.store.CreateDrawing(r.Context(), user.ID, req)
if err != nil {
status := http.StatusBadRequest
if errors.Is(err, ErrForbidden) {
status = http.StatusForbidden
}
writeError(w, status, err.Error())
return
}
writeJSON(w, http.StatusCreated, drawing)
}
func (a *API) handleGetDrawing(w http.ResponseWriter, r *http.Request) {
user, _ := currentUser(r)
drawing, err := a.store.GetDrawing(r.Context(), user.ID, chi.URLParam(r, "drawingID"))
if err != nil {
writeLookupError(w, err)
return
}
writeJSON(w, http.StatusOK, drawing)
}
func (a *API) handleUpdateDrawing(w http.ResponseWriter, r *http.Request) {
user, _ := currentUser(r)
var req UpdateDrawingRequest
if !decodeJSON(w, r, &req, 256<<10) {
return
}
drawing, err := a.store.UpdateDrawing(r.Context(), user.ID, chi.URLParam(r, "drawingID"), req)
if err != nil {
writeLookupError(w, err)
return
}
writeJSON(w, http.StatusOK, drawing)
}
func (a *API) handleArchiveDrawing(w http.ResponseWriter, r *http.Request) {
user, _ := currentUser(r)
if err := a.store.ArchiveDrawing(r.Context(), user.ID, chi.URLParam(r, "drawingID")); err != nil {
writeLookupError(w, err)
return
}
w.WriteHeader(http.StatusNoContent)
}
func (a *API) handleListRevisions(w http.ResponseWriter, r *http.Request) {
user, _ := currentUser(r)
revisions, err := a.store.ListRevisions(r.Context(), user.ID, chi.URLParam(r, "drawingID"))
if err != nil {
writeLookupError(w, err)
return
}
writeJSON(w, http.StatusOK, revisions)
}
func (a *API) handleCreateRevision(w http.ResponseWriter, r *http.Request) {
user, _ := currentUser(r)
var req CreateRevisionRequest
if !decodeJSON(w, r, &req, 10<<20) {
return
}
revision, err := a.store.CreateRevision(r.Context(), user.ID, chi.URLParam(r, "drawingID"), req)
if err != nil {
writeLookupError(w, err)
return
}
writeJSON(w, http.StatusCreated, revision)
}
func (a *API) handleThumbnail(w http.ResponseWriter, r *http.Request) {
user, _ := currentUser(r)
drawingID := chi.URLParam(r, "drawingID")
revisions, err := a.store.ListRevisions(r.Context(), user.ID, drawingID)
if err != nil {
writeLookupError(w, err)
return
}
if len(revisions) == 0 || revisions[0].Snapshot == nil {
w.Header().Set("Content-Type", "image/svg+xml")
w.WriteHeader(http.StatusOK)
w.Write([]byte(`<svg xmlns="http://www.w3.org/2000/svg" width="320" height="240" viewBox="0 0 320 240"><rect width="320" height="240" fill="#f8f9fa"/><text x="160" y="120" text-anchor="middle" font-family="sans-serif" font-size="14" fill="#999">No preview</text></svg>`))
return
}
var snapshot struct {
Elements []struct {
Type string `json:"type"`
X float64 `json:"x"`
Y float64 `json:"y"`
Width float64 `json:"width"`
Height float64 `json:"height"`
Stroke string `json:"strokeColor"`
Bg string `json:"backgroundColor"`
Text string `json:"text"`
} `json:"elements"`
}
if err := json.Unmarshal(revisions[0].Snapshot, &snapshot); err != nil {
w.Header().Set("Content-Type", "image/svg+xml")
w.WriteHeader(http.StatusOK)
w.Write([]byte(`<svg xmlns="http://www.w3.org/2000/svg" width="320" height="240" viewBox="0 0 320 240"><rect width="320" height="240" fill="#f8f9fa"/><text x="160" y="120" text-anchor="middle" font-family="sans-serif" font-size="14" fill="#999">Preview unavailable</text></svg>`))
return
}
// Generate a simple SVG thumbnail from element bounding boxes
const vw, vh = 320, 240
var b strings.Builder
b.WriteString(`<svg xmlns="http://www.w3.org/2000/svg" width="` + itoa(vw) + `" height="` + itoa(vh) + `" viewBox="0 0 ` + itoa(vw) + ` ` + itoa(vh) + `">`)
b.WriteString(`<rect width="` + itoa(vw) + `" height="` + itoa(vh) + `" fill="#ffffff"/>`)
// Compute bounding box to fit elements into view
minX, minY, maxX, maxY := 1e9, 1e9, -1e9, -1e9
for _, el := range snapshot.Elements {
if el.X < minX {
minX = el.X
}
if el.Y < minY {
minY = el.Y
}
if el.X+el.Width > maxX {
maxX = el.X + el.Width
}
if el.Y+el.Height > maxY {
maxY = el.Y + el.Height
}
}
if maxX <= minX || maxY <= minY {
minX, minY, maxX, maxY = 0, 0, 320, 240
}
pad := 20.0
scaleX := float64(vw-40) / (maxX - minX + 1e-6)
scaleY := float64(vh-40) / (maxY - minY + 1e-6)
scale := scaleX
if scaleY < scaleX {
scale = scaleY
}
offX := pad - minX*scale
offY := pad - minY*scale
for _, el := range snapshot.Elements {
x := el.X*scale + offX
y := el.Y*scale + offY
w := el.Width * scale
h := el.Height * scale
stroke := el.Stroke
if stroke == "" {
stroke = "#1e1e1e"
}
bg := el.Bg
if bg == "" || bg == "transparent" {
bg = "none"
}
switch el.Type {
case "rectangle", "diamond":
b.WriteString(`<rect x="` + ftoa(x) + `" y="` + ftoa(y) + `" width="` + ftoa(w) + `" height="` + ftoa(h) + `" fill="` + bg + `" stroke="` + stroke + `" stroke-width="1"/>`)
case "ellipse":
b.WriteString(`<ellipse cx="` + ftoa(x+w/2) + `" cy="` + ftoa(y+h/2) + `" rx="` + ftoa(w/2) + `" ry="` + ftoa(h/2) + `" fill="` + bg + `" stroke="` + stroke + `" stroke-width="1"/>`)
case "line", "arrow":
b.WriteString(`<line x1="` + ftoa(x) + `" y1="` + ftoa(y+h/2) + `" x2="` + ftoa(x+w) + `" y2="` + ftoa(y+h/2) + `" stroke="` + stroke + `" stroke-width="1"/>`)
case "text":
b.WriteString(`<text x="` + ftoa(x) + `" y="` + ftoa(y+h/2) + `" font-family="sans-serif" font-size="12" fill="` + stroke + `">` + htmlEscape(el.Text) + `</text>`)
default:
b.WriteString(`<rect x="` + ftoa(x) + `" y="` + ftoa(y) + `" width="` + ftoa(w) + `" height="` + ftoa(h) + `" fill="none" stroke="#ccc" stroke-width="0.5"/>`)
}
}
b.WriteString(`</svg>`)
w.Header().Set("Content-Type", "image/svg+xml")
w.Header().Set("Cache-Control", "max-age=60")
w.WriteHeader(http.StatusOK)
w.Write([]byte(b.String()))
}
func ftoa(f float64) string { return strconv.FormatFloat(f, 'f', 2, 64) }
func itoa(i int) string { return strconv.Itoa(i) }
func htmlEscape(s string) string {
s = strings.ReplaceAll(s, "&", "&amp;")
s = strings.ReplaceAll(s, "<", "&lt;")
s = strings.ReplaceAll(s, ">", "&gt;")
s = strings.ReplaceAll(s, `"`, "&quot;")
return s
}
func (a *API) handleListTemplates(w http.ResponseWriter, r *http.Request) {
user, _ := currentUser(r)
teamID := strings.TrimSpace(r.URL.Query().Get("team_id"))
templates, err := a.store.ListTemplates(r.Context(), user.ID, teamID)
if err != nil {
writeLookupError(w, err)
return
}
writeJSON(w, http.StatusOK, templates)
}
func (a *API) handleListActivity(w http.ResponseWriter, r *http.Request) {
user, _ := currentUser(r)
teamID := strings.TrimSpace(r.URL.Query().Get("team_id"))
activity, err := a.store.ListActivity(r.Context(), user.ID, teamID, 50)
if err != nil {
writeLookupError(w, err)
return
}
writeJSON(w, http.StatusOK, activity)
}
func (a *API) handleStats(w http.ResponseWriter, r *http.Request) {
user, _ := currentUser(r)
teamID := strings.TrimSpace(r.URL.Query().Get("team_id"))
stats, err := a.store.WorkspaceStats(r.Context(), user.ID, teamID)
if err != nil {
writeLookupError(w, err)
return
}
writeJSON(w, http.StatusOK, stats)
}
func (a *API) handleListFolders(w http.ResponseWriter, r *http.Request) {
user, _ := currentUser(r)
teamID := strings.TrimSpace(r.URL.Query().Get("team_id"))
folders, err := a.store.ListFolders(r.Context(), user.ID, teamID)
if err != nil {
writeLookupError(w, err)
return
}
writeJSON(w, http.StatusOK, folders)
}
func (a *API) handleCreateFolder(w http.ResponseWriter, r *http.Request) {
user, _ := currentUser(r)
var req CreateFolderRequest
if !decodeJSON(w, r, &req, 128<<10) {
return
}
folder, err := a.store.CreateFolder(r.Context(), user.ID, req)
if err != nil {
writeLookupError(w, err)
return
}
writeJSON(w, http.StatusCreated, folder)
}
func (a *API) handleListProjects(w http.ResponseWriter, r *http.Request) {
user, _ := currentUser(r)
teamID := strings.TrimSpace(r.URL.Query().Get("team_id"))
projects, err := a.store.ListProjects(r.Context(), user.ID, teamID)
if err != nil {
writeLookupError(w, err)
return
}
writeJSON(w, http.StatusOK, projects)
}
func (a *API) handleCreateProject(w http.ResponseWriter, r *http.Request) {
user, _ := currentUser(r)
var req CreateProjectRequest
if !decodeJSON(w, r, &req, 128<<10) {
return
}
project, err := a.store.CreateProject(r.Context(), user.ID, req)
if err != nil {
writeLookupError(w, err)
return
}
writeJSON(w, http.StatusCreated, project)
}
func (a *API) handleSearch(w http.ResponseWriter, r *http.Request) {
user, _ := currentUser(r)
q := strings.TrimSpace(r.URL.Query().Get("q"))
if q == "" {
writeJSON(w, http.StatusOK, []Drawing{})
return
}
drawings, err := a.store.SearchDrawings(r.Context(), user.ID, q)
if err != nil {
writeLookupError(w, err)
return
}
writeJSON(w, http.StatusOK, drawings)
}
func writeLookupError(w http.ResponseWriter, err error) {
switch {
case errors.Is(err, ErrForbidden):
writeError(w, http.StatusForbidden, "Access denied")
case errors.Is(err, sql.ErrNoRows), errors.Is(err, ErrNotFound):
writeError(w, http.StatusNotFound, "Resource not found")
default:
writeError(w, http.StatusBadRequest, err.Error())
}
}
func decodeJSON(w http.ResponseWriter, r *http.Request, dst any, limit int64) bool {
defer r.Body.Close()
r.Body = http.MaxBytesReader(w, r.Body, limit)
decoder := json.NewDecoder(r.Body)
decoder.DisallowUnknownFields()
if err := decoder.Decode(dst); err != nil {
writeError(w, http.StatusBadRequest, "Invalid request body")
return false
}
return true
}
func writeJSON(w http.ResponseWriter, status int, body any) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
if err := json.NewEncoder(w).Encode(body); err != nil {
logrus.WithError(err).Warn("failed to encode response")
}
}
func writeError(w http.ResponseWriter, status int, message string) {
writeJSON(w, status, map[string]string{"error": message})
}
func setSessionCookie(w http.ResponseWriter, r *http.Request, token string, expires time.Time) {
SetSessionCookie(w, r, token, expires)
}
func SetSessionCookie(w http.ResponseWriter, r *http.Request, token string, expires time.Time) {
http.SetCookie(w, &http.Cookie{
Name: sessionCookieName,
Value: token,
Path: "/",
Expires: expires,
HttpOnly: true,
Secure: isSecureRequest(r),
SameSite: http.SameSiteLaxMode,
})
}
func clearSessionCookie(w http.ResponseWriter, r *http.Request) {
http.SetCookie(w, &http.Cookie{
Name: sessionCookieName,
Value: "",
Path: "/",
Expires: time.Unix(0, 0),
MaxAge: -1,
HttpOnly: true,
Secure: isSecureRequest(r),
SameSite: http.SameSiteLaxMode,
})
}
func isSecureRequest(r *http.Request) bool {
return r.TLS != nil || strings.EqualFold(r.Header.Get("X-Forwarded-Proto"), "https")
}
func clientIP(r *http.Request) string {
if forwarded := r.Header.Get("X-Forwarded-For"); forwarded != "" {
return strings.TrimSpace(strings.Split(forwarded, ",")[0])
}
host := r.RemoteAddr
if idx := strings.LastIndex(host, ":"); idx > 0 {
return host[:idx]
}
return host
}
+214
View File
@@ -0,0 +1,214 @@
package workspace
import (
"database/sql"
"errors"
"net/http"
"github.com/go-chi/chi/v5"
)
func (a *API) handleListTeamInvites(w http.ResponseWriter, r *http.Request) {
user, _ := currentUser(r)
invites, err := a.store.ListTeamInvites(r.Context(), user.ID, chi.URLParam(r, "teamID"))
if err != nil {
writeLookupError(w, err)
return
}
writeJSON(w, http.StatusOK, invites)
}
func (a *API) handleCreateTeamInvite(w http.ResponseWriter, r *http.Request) {
user, _ := currentUser(r)
var req CreateInviteRequest
if !decodeJSON(w, r, &req, 64<<10) {
return
}
invite, token, err := a.store.CreateTeamInvite(r.Context(), user.ID, chi.URLParam(r, "teamID"), req)
if err != nil {
writeLookupError(w, err)
return
}
writeJSON(w, http.StatusCreated, map[string]any{"invite": invite, "token": token})
}
func (a *API) handleAcceptInvite(w http.ResponseWriter, r *http.Request) {
user, _ := currentUser(r)
var req struct {
Token string `json:"token"`
}
if !decodeJSON(w, r, &req, 32<<10) {
return
}
membership, err := a.store.AcceptInvite(r.Context(), user.ID, req.Token)
if err != nil {
writeLookupError(w, err)
return
}
writeJSON(w, http.StatusOK, membership)
}
func (a *API) handleListPermissions(w http.ResponseWriter, r *http.Request) {
user, _ := currentUser(r)
grants, err := a.store.ListPermissionGrants(r.Context(), user.ID, "drawing", chi.URLParam(r, "drawingID"))
if err != nil {
writeLookupError(w, err)
return
}
writeJSON(w, http.StatusOK, grants)
}
func (a *API) handleCreatePermission(w http.ResponseWriter, r *http.Request) {
user, _ := currentUser(r)
var req CreatePermissionGrantRequest
if !decodeJSON(w, r, &req, 64<<10) {
return
}
grant, err := a.store.CreateDrawingPermissionGrant(r.Context(), user.ID, chi.URLParam(r, "drawingID"), req)
if err != nil {
writeLookupError(w, err)
return
}
writeJSON(w, http.StatusCreated, grant)
}
func (a *API) handleListShareLinks(w http.ResponseWriter, r *http.Request) {
user, _ := currentUser(r)
links, err := a.store.ListShareLinks(r.Context(), user.ID, "drawing", chi.URLParam(r, "drawingID"))
if err != nil {
writeLookupError(w, err)
return
}
writeJSON(w, http.StatusOK, links)
}
func (a *API) handleCreateShareLink(w http.ResponseWriter, r *http.Request) {
user, _ := currentUser(r)
var req CreateShareLinkRequest
if !decodeJSON(w, r, &req, 64<<10) {
return
}
link, token, err := a.store.CreateDrawingShareLink(r.Context(), user.ID, chi.URLParam(r, "drawingID"), req)
if err != nil {
writeLookupError(w, err)
return
}
writeJSON(w, http.StatusCreated, map[string]any{"share_link": link, "token": token})
}
func (a *API) handleSharedResource(w http.ResponseWriter, r *http.Request) {
payload, err := a.store.SharedResourceByToken(r.Context(), chi.URLParam(r, "token"))
if err != nil {
writeLookupError(w, err)
return
}
writeJSON(w, http.StatusOK, payload)
}
func (a *API) handleListAssets(w http.ResponseWriter, r *http.Request) {
user, _ := currentUser(r)
assets, err := a.store.ListDrawingAssets(r.Context(), user.ID, chi.URLParam(r, "drawingID"))
if err != nil {
writeLookupError(w, err)
return
}
writeJSON(w, http.StatusOK, assets)
}
func (a *API) handleCreateAsset(w http.ResponseWriter, r *http.Request) {
user, _ := currentUser(r)
var req CreateAssetRequest
if !decodeJSON(w, r, &req, 64<<10) {
return
}
asset, err := a.store.CreateDrawingAsset(r.Context(), user.ID, chi.URLParam(r, "drawingID"), req)
if err != nil {
writeLookupError(w, err)
return
}
writeJSON(w, http.StatusCreated, asset)
}
func (a *API) handleListEmbeds(w http.ResponseWriter, r *http.Request) {
user, _ := currentUser(r)
embeds, err := a.store.ListEmbeds(r.Context(), user.ID, chi.URLParam(r, "drawingID"))
if err != nil {
writeLookupError(w, err)
return
}
writeJSON(w, http.StatusOK, embeds)
}
func (a *API) handleCreateEmbed(w http.ResponseWriter, r *http.Request) {
user, _ := currentUser(r)
var req CreateEmbedRequest
if !decodeJSON(w, r, &req, 64<<10) {
return
}
embed, err := a.store.CreateEmbed(r.Context(), user.ID, chi.URLParam(r, "drawingID"), req)
if err != nil {
writeLookupError(w, err)
return
}
writeJSON(w, http.StatusCreated, embed)
}
func (a *API) handleListLinks(w http.ResponseWriter, r *http.Request) {
user, _ := currentUser(r)
links, err := a.store.ListLinkReferences(r.Context(), user.ID, "drawing", chi.URLParam(r, "drawingID"))
if err != nil {
writeLookupError(w, err)
return
}
writeJSON(w, http.StatusOK, links)
}
func (a *API) handleCreateLink(w http.ResponseWriter, r *http.Request) {
user, _ := currentUser(r)
var req CreateLinkRequest
if !decodeJSON(w, r, &req, 64<<10) {
return
}
link, err := a.store.CreateDrawingLinkReference(r.Context(), user.ID, chi.URLParam(r, "drawingID"), req)
if err != nil {
writeLookupError(w, err)
return
}
writeJSON(w, http.StatusCreated, link)
}
func (a *API) handleCreateTeamUser(w http.ResponseWriter, r *http.Request) {
user, _ := currentUser(r)
teamID := chi.URLParam(r, "teamID")
var role string
err := a.store.db.QueryRowContext(r.Context(), `SELECT role FROM workspace_team_memberships WHERE user_id = ? AND team_id = ?`, user.ID, teamID).Scan(&role)
if errors.Is(err, sql.ErrNoRows) || (err == nil && role != "owner" && role != "admin") {
writeError(w, http.StatusForbidden, "Only team owners and admins can add members")
return
}
if err != nil {
writeError(w, http.StatusInternalServerError, "Failed to verify team access")
return
}
var req struct {
Name string `json:"name"`
Email string `json:"email"`
Password string `json:"password"`
Role string `json:"role"`
}
if !decodeJSON(w, r, &req, 32<<10) {
return
}
newUser, err := a.store.CreateTeamUser(r.Context(), teamID, req.Name, req.Email, req.Password, req.Role)
if err != nil {
status := http.StatusBadRequest
if errors.Is(err, ErrConflict) {
status = http.StatusConflict
}
writeError(w, status, err.Error())
return
}
writeJSON(w, http.StatusCreated, newUser)
}
+307
View File
@@ -0,0 +1,307 @@
package workspace
import (
"bytes"
"context"
"encoding/json"
dbpostgres "excalidraw-complete/internal/postgres"
"net/url"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
)
func newTestStore(t *testing.T) (*Store, func()) {
t.Helper()
baseURL := os.Getenv("TEST_DATABASE_URL")
if baseURL == "" {
baseURL = os.Getenv("DATABASE_URL")
}
if baseURL == "" {
t.Skip("TEST_DATABASE_URL or DATABASE_URL is required for PostgreSQL workspace tests")
}
schema := "test_" + strings.ToLower(newID())
adminDB, err := dbpostgres.Open(baseURL)
if err != nil {
t.Fatalf("open test database error = %v", err)
}
if _, err := adminDB.DB.ExecContext(context.Background(), `CREATE SCHEMA "`+schema+`"`); err != nil {
adminDB.Close()
t.Fatalf("create test schema error = %v", err)
}
store, err := NewStore(databaseURLWithSearchPath(t, baseURL, schema))
if err != nil {
adminDB.DB.ExecContext(context.Background(), `DROP SCHEMA "`+schema+`" CASCADE`)
adminDB.Close()
t.Fatalf("NewStore() error = %v", err)
}
return store, func() {
if err := store.Close(); err != nil {
t.Fatalf("Close() error = %v", err)
}
if _, err := adminDB.DB.ExecContext(context.Background(), `DROP SCHEMA "`+schema+`" CASCADE`); err != nil {
t.Fatalf("drop test schema error = %v", err)
}
if err := adminDB.Close(); err != nil {
t.Fatalf("admin Close() error = %v", err)
}
}
}
func databaseURLWithSearchPath(t *testing.T, rawURL, schema string) string {
t.Helper()
parsed, err := url.Parse(rawURL)
if err != nil {
t.Fatalf("parse database URL error = %v", err)
}
q := parsed.Query()
q.Set("search_path", schema)
parsed.RawQuery = q.Encode()
return parsed.String()
}
func newTestAPI(t *testing.T) (*API, func()) {
t.Helper()
store, cleanup := newTestStore(t)
api := NewAPI(store)
api.testMode = true
return api, cleanup
}
func doJSON(t *testing.T, api *API, method, path string, body any, cookies ...*http.Cookie) *httptest.ResponseRecorder {
t.Helper()
var raw bytes.Buffer
if body != nil {
if err := json.NewEncoder(&raw).Encode(body); err != nil {
t.Fatalf("json encode error = %v", err)
}
}
req := httptest.NewRequest(method, path, &raw)
req.Header.Set("Content-Type", "application/json")
for _, cookie := range cookies {
req.AddCookie(cookie)
}
rr := httptest.NewRecorder()
api.Routes().ServeHTTP(rr, req)
return rr
}
func signup(t *testing.T, api *API, email string) (*http.Cookie, User, Team) {
t.Helper()
rr := doJSON(t, api, http.MethodPost, "/auth/signup", map[string]string{
"name": "Test User",
"email": email,
"password": "password-123",
})
if rr.Code != http.StatusCreated {
t.Fatalf("signup status = %d body = %s", rr.Code, rr.Body.String())
}
var payload struct {
User User `json:"user"`
}
if err := json.Unmarshal(rr.Body.Bytes(), &payload); err != nil {
t.Fatalf("signup response decode error = %v", err)
}
var sessionCookie *http.Cookie
for _, cookie := range rr.Result().Cookies() {
if cookie.Name == sessionCookieName {
copy := *cookie
sessionCookie = &copy
break
}
}
if sessionCookie == nil {
t.Fatal("signup did not set session cookie")
}
teamsRR := doJSON(t, api, http.MethodGet, "/teams", nil, sessionCookie)
if teamsRR.Code != http.StatusOK {
t.Fatalf("teams status = %d body = %s", teamsRR.Code, teamsRR.Body.String())
}
var teams []Team
if err := json.Unmarshal(teamsRR.Body.Bytes(), &teams); err != nil {
t.Fatalf("teams decode error = %v", err)
}
if len(teams) != 1 {
t.Fatalf("teams len = %d, want 1", len(teams))
}
return sessionCookie, payload.User, teams[0]
}
func TestSignupCreatesCookieSessionAndDefaultTeam(t *testing.T) {
api, cleanup := newTestAPI(t)
defer cleanup()
cookie, user, team := signup(t, api, "alice@example.com")
if !cookie.HttpOnly {
t.Fatal("session cookie must be httpOnly")
}
if user.ID == "" || user.Email != "alice@example.com" {
t.Fatalf("unexpected user: %#v", user)
}
if team.OwnerUserID != user.ID || team.PlanType != "free" {
t.Fatalf("unexpected team: %#v", team)
}
meRR := doJSON(t, api, http.MethodGet, "/auth/me", nil, cookie)
if meRR.Code != http.StatusOK {
t.Fatalf("me status = %d body = %s", meRR.Code, meRR.Body.String())
}
logoutRR := doJSON(t, api, http.MethodPost, "/auth/logout", nil, cookie)
if logoutRR.Code != http.StatusOK {
t.Fatalf("logout status = %d body = %s", logoutRR.Code, logoutRR.Body.String())
}
if logoutRR.Body.String() == "" {
t.Fatal("logout must return JSON for frontend fetchApi compatibility")
}
afterLogoutRR := doJSON(t, api, http.MethodGet, "/auth/me", nil, cookie)
if afterLogoutRR.Code != http.StatusUnauthorized {
t.Fatalf("me after logout status = %d body = %s", afterLogoutRR.Code, afterLogoutRR.Body.String())
}
}
func TestDrawingAccessRequiresTeamMembership(t *testing.T) {
api, cleanup := newTestAPI(t)
defer cleanup()
aliceCookie, _, aliceTeam := signup(t, api, "alice@example.com")
bobCookie, _, _ := signup(t, api, "bob@example.com")
createRR := doJSON(t, api, http.MethodPost, "/drawings", map[string]any{
"team_id": aliceTeam.ID,
"title": "Architecture map",
}, aliceCookie)
if createRR.Code != http.StatusCreated {
t.Fatalf("create drawing status = %d body = %s", createRR.Code, createRR.Body.String())
}
var drawing Drawing
if err := json.Unmarshal(createRR.Body.Bytes(), &drawing); err != nil {
t.Fatalf("drawing decode error = %v", err)
}
forbiddenRR := doJSON(t, api, http.MethodGet, "/drawings/"+drawing.ID, nil, bobCookie)
if forbiddenRR.Code != http.StatusForbidden {
t.Fatalf("bob get drawing status = %d body = %s", forbiddenRR.Code, forbiddenRR.Body.String())
}
listRR := doJSON(t, api, http.MethodGet, "/drawings?team_id="+aliceTeam.ID, nil, bobCookie)
if listRR.Code != http.StatusForbidden {
t.Fatalf("bob list team drawings status = %d body = %s", listRR.Code, listRR.Body.String())
}
}
func TestTeamMembersRequireMembership(t *testing.T) {
api, cleanup := newTestAPI(t)
defer cleanup()
aliceCookie, _, aliceTeam := signup(t, api, "alice@example.com")
bobCookie, _, _ := signup(t, api, "bob@example.com")
okRR := doJSON(t, api, http.MethodGet, "/teams/"+aliceTeam.ID+"/members", nil, aliceCookie)
if okRR.Code != http.StatusOK {
t.Fatalf("alice members status = %d body = %s", okRR.Code, okRR.Body.String())
}
var members []TeamMembership
if err := json.Unmarshal(okRR.Body.Bytes(), &members); err != nil {
t.Fatalf("members decode error = %v", err)
}
if len(members) != 1 || members[0].Role != "owner" || members[0].User == nil {
t.Fatalf("unexpected members: %#v", members)
}
forbiddenRR := doJSON(t, api, http.MethodGet, "/teams/"+aliceTeam.ID+"/members", nil, bobCookie)
if forbiddenRR.Code != http.StatusForbidden {
t.Fatalf("bob members status = %d body = %s", forbiddenRR.Code, forbiddenRR.Body.String())
}
}
func TestDrawingRevisionsTemplatesAndActivity(t *testing.T) {
api, cleanup := newTestAPI(t)
defer cleanup()
cookie, _, team := signup(t, api, "alice@example.com")
createRR := doJSON(t, api, http.MethodPost, "/drawings", map[string]any{
"team_id": team.ID,
"title": "Launch plan",
"snapshot": map[string]any{"type": "excalidraw", "elements": []any{}},
}, cookie)
if createRR.Code != http.StatusCreated {
t.Fatalf("create drawing status = %d body = %s", createRR.Code, createRR.Body.String())
}
var drawing Drawing
if err := json.Unmarshal(createRR.Body.Bytes(), &drawing); err != nil {
t.Fatalf("drawing decode error = %v", err)
}
if drawing.LatestRevisionID == nil {
t.Fatal("create drawing with snapshot must create latest_revision_id")
}
revisionRR := doJSON(t, api, http.MethodPost, "/drawings/"+drawing.ID+"/revisions", map[string]any{
"snapshot": map[string]any{"type": "excalidraw", "elements": []any{map[string]any{"id": "a"}}},
"change_summary": "Added first shape",
}, cookie)
if revisionRR.Code != http.StatusCreated {
t.Fatalf("create revision status = %d body = %s", revisionRR.Code, revisionRR.Body.String())
}
revisionsRR := doJSON(t, api, http.MethodGet, "/drawings/"+drawing.ID+"/revisions", nil, cookie)
if revisionsRR.Code != http.StatusOK {
t.Fatalf("list revisions status = %d body = %s", revisionsRR.Code, revisionsRR.Body.String())
}
var revisions []DrawingRevision
if err := json.Unmarshal(revisionsRR.Body.Bytes(), &revisions); err != nil {
t.Fatalf("revisions decode error = %v", err)
}
if len(revisions) != 2 || revisions[0].RevisionNumber != 2 {
t.Fatalf("unexpected revisions: %#v", revisions)
}
templatesRR := doJSON(t, api, http.MethodGet, "/templates", nil, cookie)
if templatesRR.Code != http.StatusOK {
t.Fatalf("templates status = %d body = %s", templatesRR.Code, templatesRR.Body.String())
}
var templates []Template
if err := json.Unmarshal(templatesRR.Body.Bytes(), &templates); err != nil {
t.Fatalf("templates decode error = %v", err)
}
if len(templates) < 4 {
t.Fatalf("templates len = %d, want at least 4", len(templates))
}
activityRR := doJSON(t, api, http.MethodGet, "/activity?team_id="+team.ID, nil, cookie)
if activityRR.Code != http.StatusOK {
t.Fatalf("activity status = %d body = %s", activityRR.Code, activityRR.Body.String())
}
var activity []ActivityEvent
if err := json.Unmarshal(activityRR.Body.Bytes(), &activity); err != nil {
t.Fatalf("activity decode error = %v", err)
}
if len(activity) == 0 {
t.Fatal("expected activity events")
}
statsRR := doJSON(t, api, http.MethodGet, "/stats?team_id="+team.ID, nil, cookie)
if statsRR.Code != http.StatusOK {
t.Fatalf("stats status = %d body = %s", statsRR.Code, statsRR.Body.String())
}
var stats WorkspaceStats
if err := json.Unmarshal(statsRR.Body.Bytes(), &stats); err != nil {
t.Fatalf("stats decode error = %v", err)
}
if stats.Drawings != 1 || stats.Revisions != 2 || stats.Templates < 4 {
t.Fatalf("unexpected stats: %#v", stats)
}
}
func TestHealth(t *testing.T) {
api, cleanup := newTestAPI(t)
defer cleanup()
rr := doJSON(t, api, http.MethodGet, "/health", nil)
if rr.Code != http.StatusOK {
t.Fatalf("health status = %d body = %s", rr.Code, rr.Body.String())
}
}
+217
View File
@@ -0,0 +1,217 @@
package workspace
import (
"encoding/json"
"time"
)
type User struct {
ID string `json:"id"`
Name string `json:"name"`
Username string `json:"username"`
Email string `json:"email"`
AvatarURL *string `json:"avatar_url"`
Locale string `json:"locale"`
Timezone string `json:"timezone"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
type Session struct {
ID string `json:"id"`
UserID string `json:"user_id"`
ExpiresAt time.Time `json:"expires_at"`
CreatedAt time.Time `json:"created_at"`
}
type Team struct {
ID string `json:"id"`
Name string `json:"name"`
Slug string `json:"slug"`
OwnerUserID string `json:"owner_user_id"`
PlanType string `json:"plan_type"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
type TeamMembership struct {
ID string `json:"id"`
TeamID string `json:"team_id"`
UserID string `json:"user_id"`
Role string `json:"role"`
JoinedAt time.Time `json:"joined_at"`
User *User `json:"user,omitempty"`
}
type TeamInvite struct {
ID string `json:"id"`
TeamID string `json:"team_id"`
Email string `json:"email"`
Role string `json:"role"`
InvitedBy string `json:"invited_by"`
ExpiresAt time.Time `json:"expires_at"`
CreatedAt time.Time `json:"created_at"`
}
type Project struct {
ID string `json:"id"`
TeamID string `json:"team_id"`
Name string `json:"name"`
Slug string `json:"slug"`
Description *string `json:"description"`
CreatedBy string `json:"created_by"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
type Folder struct {
ID string `json:"id"`
TeamID string `json:"team_id"`
ProjectID *string `json:"project_id"`
ParentFolderID *string `json:"parent_folder_id"`
Name string `json:"name"`
Slug string `json:"slug"`
PathCache string `json:"path_cache"`
Visibility string `json:"visibility"`
CreatedBy string `json:"created_by"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
type Drawing struct {
ID string `json:"id"`
TeamID string `json:"team_id"`
FolderID *string `json:"folder_id"`
ProjectID *string `json:"project_id"`
Slug *string `json:"slug"`
Title string `json:"title"`
Description *string `json:"description"`
OwnerUserID string `json:"owner_user_id"`
LatestRevisionID *string `json:"latest_revision_id"`
Visibility string `json:"visibility"`
IsArchived bool `json:"is_archived"`
ThumbnailAssetID *string `json:"thumbnail_asset_id"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
DeletedAt *time.Time `json:"deleted_at"`
Owner *User `json:"owner,omitempty"`
Folder *Folder `json:"folder,omitempty"`
Project *Project `json:"project,omitempty"`
ThumbnailURL *string `json:"thumbnail_url,omitempty"`
}
type DrawingRevision struct {
ID string `json:"id"`
DrawingID string `json:"drawing_id"`
RevisionNumber int `json:"revision_number"`
SnapshotPath string `json:"snapshot_path"`
SnapshotSize int64 `json:"snapshot_size"`
ContentHash string `json:"content_hash"`
CreatedBy string `json:"created_by"`
CreatedAt time.Time `json:"created_at"`
ChangeSummary *string `json:"change_summary"`
Snapshot json.RawMessage `json:"snapshot,omitempty"`
CreatedByUser *User `json:"created_by_user,omitempty"`
}
type DrawingAsset struct {
ID string `json:"id"`
DrawingID string `json:"drawing_id"`
Kind string `json:"kind"`
Path string `json:"path"`
MimeType string `json:"mime_type"`
Size int64 `json:"size"`
Width *int `json:"width"`
Height *int `json:"height"`
UploadedBy string `json:"uploaded_by"`
CreatedAt time.Time `json:"created_at"`
URL *string `json:"url,omitempty"`
}
type Template struct {
ID string `json:"id"`
TeamID *string `json:"team_id"`
Scope string `json:"scope"`
Type string `json:"type"`
Name string `json:"name"`
Description *string `json:"description"`
SnapshotPath string `json:"snapshot_path"`
MetadataJSON map[string]any `json:"metadata_json"`
CreatedBy string `json:"created_by"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
PreviewURL *string `json:"preview_url,omitempty"`
}
type ActivityEvent struct {
ID string `json:"id"`
ActorUserID *string `json:"actor_user_id"`
TeamID *string `json:"team_id"`
ResourceType string `json:"resource_type"`
ResourceID string `json:"resource_id"`
EventType string `json:"event_type"`
MetadataJSON map[string]any `json:"metadata_json"`
CreatedAt time.Time `json:"created_at"`
Actor *User `json:"actor,omitempty"`
}
type ShareLink struct {
ID string `json:"id"`
ResourceType string `json:"resource_type"`
ResourceID string `json:"resource_id"`
TokenHash string `json:"token_hash,omitempty"`
Permission string `json:"permission"`
ExpiresAt *time.Time `json:"expires_at"`
PasswordHash *string `json:"-"`
CreatedBy string `json:"created_by"`
RevokedAt *time.Time `json:"revoked_at"`
CreatedAt time.Time `json:"created_at"`
}
type PermissionGrant struct {
ID string `json:"id"`
ResourceType string `json:"resource_type"`
ResourceID string `json:"resource_id"`
SubjectType string `json:"subject_type"`
SubjectID string `json:"subject_id"`
Permission string `json:"permission"`
InheritedFrom *string `json:"inherited_from"`
CreatedAt time.Time `json:"created_at"`
}
type Embed struct {
ID string `json:"id"`
DrawingID string `json:"drawing_id"`
SourceURL string `json:"source_url"`
CanonicalURL string `json:"canonical_url"`
Provider string `json:"provider"`
EmbedType string `json:"embed_type"`
Title *string `json:"title"`
PreviewAssetID *string `json:"preview_asset_id"`
SafeEmbedHTML *string `json:"safe_embed_html"`
CreatedBy string `json:"created_by"`
CreatedAt time.Time `json:"created_at"`
}
type LinkReference struct {
ID string `json:"id"`
SourceResourceType string `json:"source_resource_type"`
SourceResourceID string `json:"source_resource_id"`
TargetResourceType string `json:"target_resource_type"`
TargetResourceID string `json:"target_resource_id"`
Label *string `json:"label"`
CreatedBy string `json:"created_by"`
CreatedAt time.Time `json:"created_at"`
}
type WorkspaceStats struct {
Teams int `json:"teams"`
Members int `json:"members"`
Projects int `json:"projects"`
Folders int `json:"folders"`
Drawings int `json:"drawings"`
Templates int `json:"templates"`
Revisions int `json:"revisions"`
Assets int `json:"assets"`
StorageBytes int64 `json:"storage_bytes"`
}
+192
View File
@@ -0,0 +1,192 @@
package workspace
import (
"context"
"crypto/rand"
"database/sql"
"errors"
dbpostgres "excalidraw-complete/internal/postgres"
"fmt"
"strings"
"time"
"golang.org/x/crypto/bcrypt"
)
type OAuthProfile struct {
Provider string
ProviderUserID string
Email string
Name string
Username string
AvatarURL string
EmailVerified bool
}
func (s *Store) UpsertOAuthSession(ctx context.Context, profile OAuthProfile) (*User, *Session, string, error) {
profile.Provider = strings.TrimSpace(strings.ToLower(profile.Provider))
profile.ProviderUserID = strings.TrimSpace(profile.ProviderUserID)
if profile.Provider == "" || profile.ProviderUserID == "" {
return nil, nil, "", fmt.Errorf("oauth provider and provider user id are required")
}
email := strings.TrimSpace(profile.Email)
if email == "" {
email = fmt.Sprintf("%s-%s@users.local", profile.Provider, slugify(profile.ProviderUserID))
}
normalizedEmail, err := normalizeEmail(email)
if err != nil {
return nil, nil, "", err
}
name := strings.TrimSpace(profile.Name)
if name == "" {
name = strings.TrimSpace(profile.Username)
}
if name == "" {
name = normalizedEmail
}
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, nil, "", err
}
defer tx.Rollback()
userID, err := userIDByIdentityTx(ctx, tx, profile.Provider, profile.ProviderUserID)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return nil, nil, "", err
}
now := time.Now().UTC()
var user *User
if userID != "" {
user, err = updateOAuthUserTx(ctx, tx, userID, name, profile.Username, normalizedEmail, profile.AvatarURL)
if err != nil {
return nil, nil, "", err
}
} else {
userID, err = userIDByEmailTx(ctx, tx, normalizedEmail)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return nil, nil, "", err
}
if userID == "" {
user, err = createOAuthUserTx(ctx, tx, name, profile.Username, normalizedEmail, profile.AvatarURL)
if err != nil {
return nil, nil, "", err
}
team, err := createTeamTx(ctx, tx, user.ID, name+"'s Workspace", "")
if err != nil {
return nil, nil, "", err
}
if err := insertActivityTx(ctx, tx, &user.ID, &team.ID, "team", team.ID, "member_joined", map[string]any{"role": "owner"}); err != nil {
return nil, nil, "", err
}
} else {
user, err = updateOAuthUserTx(ctx, tx, userID, name, profile.Username, normalizedEmail, profile.AvatarURL)
if err != nil {
return nil, nil, "", err
}
}
var verifiedAt *time.Time
if profile.EmailVerified {
verifiedAt = &now
}
_, err = tx.ExecContext(ctx, `INSERT INTO workspace_auth_identities
(id, user_id, provider, provider_user_id, email_verified_at, created_at)
VALUES (?, ?, ?, ?, ?, ?)`,
newID(), user.ID, profile.Provider, profile.ProviderUserID, verifiedAt, now,
)
if err != nil {
return nil, nil, "", err
}
}
session, token, err := createSessionTx(ctx, tx, user.ID)
if err != nil {
return nil, nil, "", err
}
if err := insertActivityTx(ctx, tx, &user.ID, nil, "user", user.ID, "login_success", map[string]any{"provider": profile.Provider}); err != nil {
return nil, nil, "", err
}
if err := tx.Commit(); err != nil {
return nil, nil, "", err
}
return user, session, token, nil
}
func userIDByIdentityTx(ctx context.Context, tx *dbpostgres.Tx, provider, providerUserID string) (string, error) {
var userID string
err := tx.QueryRowContext(ctx, `SELECT user_id FROM workspace_auth_identities WHERE provider = ? AND provider_user_id = ?`, provider, providerUserID).Scan(&userID)
return userID, err
}
func userIDByEmailTx(ctx context.Context, tx *dbpostgres.Tx, email string) (string, error) {
var userID string
err := tx.QueryRowContext(ctx, `SELECT id FROM workspace_users WHERE email = ?`, email).Scan(&userID)
return userID, err
}
func createOAuthUserTx(ctx context.Context, tx *dbpostgres.Tx, name, username, email, avatarURL string) (*User, error) {
password := make([]byte, 32)
if _, err := rand.Read(password); err != nil {
return nil, err
}
hash, err := bcrypt.GenerateFromPassword(password, 12)
if err != nil {
return nil, err
}
if username == "" {
username = strings.TrimSuffix(email, email[strings.LastIndex(email, "@"):])
}
now := time.Now().UTC()
user := &User{
ID: newID(),
Name: name,
Username: uniqueUsername(ctx, tx, slugify(username)),
Email: email,
Locale: "en",
Timezone: "UTC",
CreatedAt: now,
UpdatedAt: now,
}
if avatarURL != "" {
user.AvatarURL = &avatarURL
}
_, err = tx.ExecContext(ctx, `INSERT INTO workspace_users
(id, name, username, email, password_hash, avatar_url, locale, timezone, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
user.ID, user.Name, user.Username, user.Email, string(hash), user.AvatarURL, user.Locale, user.Timezone, user.CreatedAt, user.UpdatedAt,
)
if err != nil {
return nil, err
}
return user, nil
}
func updateOAuthUserTx(ctx context.Context, tx *dbpostgres.Tx, userID, name, username, email, avatarURL string) (*User, error) {
current := &User{}
var currentAvatar *string
err := tx.QueryRowContext(ctx, `SELECT id, name, username, email, avatar_url, locale, timezone, created_at, updated_at FROM workspace_users WHERE id = ?`, userID).
Scan(&current.ID, &current.Name, &current.Username, &current.Email, &currentAvatar, &current.Locale, &current.Timezone, &current.CreatedAt, &current.UpdatedAt)
if err != nil {
return nil, err
}
if strings.TrimSpace(name) == "" {
name = current.Name
}
if strings.TrimSpace(username) == "" {
username = current.Username
}
avatar := currentAvatar
if avatarURL != "" {
avatar = &avatarURL
}
now := time.Now().UTC()
_, err = tx.ExecContext(ctx, `UPDATE workspace_users SET name = ?, avatar_url = ?, updated_at = ? WHERE id = ?`, name, avatar, now, userID)
if err != nil {
return nil, err
}
current.Name = name
current.AvatarURL = avatar
current.UpdatedAt = now
return current, nil
}
+46
View File
@@ -0,0 +1,46 @@
package workspace
import (
"context"
"testing"
)
func TestUpsertOAuthSessionCreatesAndReusesIdentity(t *testing.T) {
store, cleanup := newTestStore(t)
defer cleanup()
profile := OAuthProfile{
Provider: "github",
ProviderUserID: "123",
Email: "octo@example.com",
Name: "Octo User",
Username: "octo",
AvatarURL: "https://example.com/avatar.png",
EmailVerified: true,
}
user, session, token, err := store.UpsertOAuthSession(context.Background(), profile)
if err != nil {
t.Fatalf("UpsertOAuthSession() error = %v", err)
}
if user.ID == "" || session.ID == "" || token == "" {
t.Fatalf("missing oauth output user=%#v session=%#v token=%q", user, session, token)
}
teams, err := store.ListTeamsForUser(context.Background(), user.ID)
if err != nil {
t.Fatalf("ListTeamsForUser() error = %v", err)
}
if len(teams) != 1 {
t.Fatalf("teams len = %d, want 1", len(teams))
}
sameUser, secondSession, secondToken, err := store.UpsertOAuthSession(context.Background(), profile)
if err != nil {
t.Fatalf("second UpsertOAuthSession() error = %v", err)
}
if sameUser.ID != user.ID {
t.Fatalf("second oauth user id = %s, want %s", sameUser.ID, user.ID)
}
if secondSession.ID == session.ID || secondToken == token {
t.Fatal("oauth login should create a fresh session")
}
}
+370
View File
@@ -0,0 +1,370 @@
package workspace
import (
"encoding/json"
"net/http"
"testing"
)
// TestPermissionMatrix tests the full permission matrix for drawings.
// It creates drawings with different visibilities and verifies access
// from users with different roles and direct permission grants.
func TestPermissionMatrix(t *testing.T) {
cases := []struct {
name string
visibility string
grantPerm string // direct grant permission to bob
bobRole string // bob's role in alice's team
expectGet int // expected HTTP status for GET /drawings/:id
expectUpdate int // expected HTTP status for PATCH /drawings/:id
expectListTeam int // expected HTTP status for GET /drawings (team-scoped list)
}{
// public drawings - anyone in the team can view, edit requires role
{
name: "public_drawing_team_viewer_can_view_not_edit",
visibility: "public", grantPerm: "", bobRole: "viewer",
expectGet: http.StatusOK, expectUpdate: http.StatusForbidden,
expectListTeam: http.StatusOK,
},
{
name: "public_drawing_team_editor_can_view_and_edit",
visibility: "public", grantPerm: "", bobRole: "editor",
expectGet: http.StatusOK, expectUpdate: http.StatusOK,
expectListTeam: http.StatusOK,
},
// restricted drawings - no team access without explicit grant
{
name: "restricted_no_grant_team_member_cannot_access",
visibility: "restricted", grantPerm: "", bobRole: "editor",
expectGet: http.StatusForbidden, expectUpdate: http.StatusForbidden,
expectListTeam: http.StatusOK, // team list still shows it if team-scoped
},
{
name: "restricted_with_view_grant_team_viewer_can_view_not_edit",
visibility: "restricted", grantPerm: "view", bobRole: "viewer",
expectGet: http.StatusOK, expectUpdate: http.StatusForbidden,
expectListTeam: http.StatusOK,
},
{
name: "restricted_with_edit_grant_team_viewer_can_view_and_edit",
visibility: "restricted", grantPerm: "edit", bobRole: "viewer",
expectGet: http.StatusOK, expectUpdate: http.StatusOK,
expectListTeam: http.StatusOK,
},
// private drawings - only owner and explicit grant holders
{
name: "private_no_grant_team_member_cannot_access",
visibility: "private", grantPerm: "", bobRole: "admin",
expectGet: http.StatusForbidden, expectUpdate: http.StatusForbidden,
expectListTeam: http.StatusOK, // team list may still include private owned by owner
},
{
name: "private_with_view_grant_allows_view_not_edit",
visibility: "private", grantPerm: "view", bobRole: "viewer",
expectGet: http.StatusOK, expectUpdate: http.StatusForbidden,
expectListTeam: http.StatusOK,
},
{
name: "private_with_edit_grant_allows_view_and_edit",
visibility: "private", grantPerm: "edit", bobRole: "viewer",
expectGet: http.StatusOK, expectUpdate: http.StatusOK,
expectListTeam: http.StatusOK,
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
api, cleanup := newTestAPI(t)
defer cleanup()
aliceCookie, _, aliceTeam := signup(t, api, "alice@example.com")
bobCookie, _, _ := signup(t, api, "bob@example.com")
// Invite bob to alice's team with specified role
if tc.bobRole != "" {
inviteRR := doJSON(t, api, http.MethodPost, "/teams/"+aliceTeam.ID+"/invites", map[string]any{
"email": "bob@example.com",
"role": tc.bobRole,
}, aliceCookie)
if inviteRR.Code != http.StatusCreated {
t.Fatalf("invite status = %d body = %s", inviteRR.Code, inviteRR.Body.String())
}
// Accept the invite
var invitePayload struct {
Token string `json:"token"`
}
if err := json.Unmarshal(inviteRR.Body.Bytes(), &invitePayload); err != nil {
t.Fatalf("invite decode error = %v", err)
}
acceptRR := doJSON(t, api, http.MethodPost, "/invites/accept", map[string]string{
"token": invitePayload.Token,
}, bobCookie)
if acceptRR.Code != http.StatusOK {
t.Fatalf("accept status = %d body = %s", acceptRR.Code, acceptRR.Body.String())
}
}
// Alice creates a drawing with specified visibility
createRR := doJSON(t, api, http.MethodPost, "/drawings", map[string]any{
"team_id": aliceTeam.ID,
"title": tc.name + " drawing",
"visibility": tc.visibility,
}, aliceCookie)
if createRR.Code != http.StatusCreated {
t.Fatalf("create drawing status = %d body = %s", createRR.Code, createRR.Body.String())
}
var drawing Drawing
if err := json.Unmarshal(createRR.Body.Bytes(), &drawing); err != nil {
t.Fatalf("drawing decode error = %v", err)
}
// Optionally grant bob a direct permission
if tc.grantPerm != "" {
grantRR := doJSON(t, api, http.MethodPost, "/drawings/"+drawing.ID+"/permissions", map[string]any{
"subject_type": "user",
"email": "bob@example.com",
"permission": tc.grantPerm,
}, aliceCookie)
if grantRR.Code != http.StatusCreated {
t.Fatalf("grant status = %d body = %s", grantRR.Code, grantRR.Body.String())
}
}
// Bob attempts to GET the drawing
getRR := doJSON(t, api, http.MethodGet, "/drawings/"+drawing.ID, nil, bobCookie)
if getRR.Code != tc.expectGet {
t.Errorf("GET status = %d, want %d; body = %s", getRR.Code, tc.expectGet, getRR.Body.String())
}
// Bob attempts to PATCH (update) the drawing
updateRR := doJSON(t, api, http.MethodPatch, "/drawings/"+drawing.ID, map[string]any{
"title": tc.name + " updated",
}, bobCookie)
if updateRR.Code != tc.expectUpdate {
t.Errorf("PATCH status = %d, want %d; body = %s", updateRR.Code, tc.expectUpdate, updateRR.Body.String())
}
// Bob attempts to list team drawings
listRR := doJSON(t, api, http.MethodGet, "/drawings?team_id="+aliceTeam.ID, nil, bobCookie)
if listRR.Code != tc.expectListTeam {
t.Errorf("LIST status = %d, want %d; body = %s", listRR.Code, tc.expectListTeam, listRR.Body.String())
}
// Verify alice (owner) can still access everything
ownerGet := doJSON(t, api, http.MethodGet, "/drawings/"+drawing.ID, nil, aliceCookie)
if ownerGet.Code != http.StatusOK {
t.Errorf("owner GET status = %d, want %d", ownerGet.Code, http.StatusOK)
}
ownerUpdate := doJSON(t, api, http.MethodPatch, "/drawings/"+drawing.ID, map[string]any{
"title": tc.name + " owner updated",
}, aliceCookie)
if ownerUpdate.Code != http.StatusOK {
t.Errorf("owner PATCH status = %d, want %d", ownerUpdate.Code, http.StatusOK)
}
})
}
}
// TestAdminCanManageTeam verifies that team admins can manage team settings,
// members, and resources while non-admins cannot.
func TestAdminCanManageTeam(t *testing.T) {
api, cleanup := newTestAPI(t)
defer cleanup()
aliceCookie, _, aliceTeam := signup(t, api, "alice@example.com")
bobCookie, _, _ := signup(t, api, "bob@example.com")
charlieCookie, _, _ := signup(t, api, "charlie@example.com")
// Invite bob as admin, charlie as viewer
for _, tc := range []struct{ email, role string }{
{"bob@example.com", "admin"},
{"charlie@example.com", "viewer"},
} {
inviteRR := doJSON(t, api, http.MethodPost, "/teams/"+aliceTeam.ID+"/invites", map[string]any{
"email": tc.email,
"role": tc.role,
}, aliceCookie)
if inviteRR.Code != http.StatusCreated {
t.Fatalf("invite %s status = %d body = %s", tc.email, inviteRR.Code, inviteRR.Body.String())
}
var invitePayload struct {
Token string `json:"token"`
}
if err := json.Unmarshal(inviteRR.Body.Bytes(), &invitePayload); err != nil {
t.Fatalf("invite decode error = %v", err)
}
var cookie *http.Cookie
if tc.email == "bob@example.com" {
cookie = bobCookie
} else {
cookie = charlieCookie
}
acceptRR := doJSON(t, api, http.MethodPost, "/invites/accept", map[string]string{
"token": invitePayload.Token,
}, cookie)
if acceptRR.Code != http.StatusOK {
t.Fatalf("accept %s status = %d body = %s", tc.email, acceptRR.Code, acceptRR.Body.String())
}
}
// Bob (admin) can update team name
updateTeam := doJSON(t, api, http.MethodPatch, "/teams/"+aliceTeam.ID, map[string]any{
"name": "Updated by admin",
}, bobCookie)
if updateTeam.Code != http.StatusOK {
t.Errorf("admin update team status = %d, want %d; body = %s", updateTeam.Code, http.StatusOK, updateTeam.Body.String())
}
// Charlie (viewer) cannot update team name
charlieUpdate := doJSON(t, api, http.MethodPatch, "/teams/"+aliceTeam.ID, map[string]any{
"name": "Updated by viewer",
}, charlieCookie)
if charlieUpdate.Code != http.StatusForbidden {
t.Errorf("viewer update team status = %d, want %d; body = %s", charlieUpdate.Code, http.StatusForbidden, charlieUpdate.Body.String())
}
// Bob (admin) can manage members
membersRR := doJSON(t, api, http.MethodGet, "/teams/"+aliceTeam.ID+"/members", nil, bobCookie)
if membersRR.Code != http.StatusOK {
t.Errorf("admin list members status = %d, want %d", membersRR.Code, http.StatusOK)
}
// Charlie (viewer) can view members
charlieMembers := doJSON(t, api, http.MethodGet, "/teams/"+aliceTeam.ID+"/members", nil, charlieCookie)
if charlieMembers.Code != http.StatusOK {
t.Errorf("viewer list members status = %d, want %d", charlieMembers.Code, http.StatusOK)
}
}
// TestNonMemberCannotAccessPrivateTeam verifies that users not in a team
// cannot access any team resources regardless of drawing visibility.
func TestNonMemberCannotAccessPrivateTeam(t *testing.T) {
api, cleanup := newTestAPI(t)
defer cleanup()
aliceCookie, _, aliceTeam := signup(t, api, "alice@example.com")
bobCookie, _, _ := signup(t, api, "bob@example.com")
// Alice creates a public drawing in her team
createRR := doJSON(t, api, http.MethodPost, "/drawings", map[string]any{
"team_id": aliceTeam.ID,
"title": "Public team drawing",
"visibility": "public",
}, aliceCookie)
if createRR.Code != http.StatusCreated {
t.Fatalf("create drawing status = %d body = %s", createRR.Code, createRR.Body.String())
}
var drawing Drawing
if err := json.Unmarshal(createRR.Body.Bytes(), &drawing); err != nil {
t.Fatalf("drawing decode error = %v", err)
}
// Bob (not in team) cannot access the drawing even though it's public
// because "public" in this context means public within the team
getRR := doJSON(t, api, http.MethodGet, "/drawings/"+drawing.ID, nil, bobCookie)
if getRR.Code != http.StatusForbidden {
t.Errorf("non-member GET status = %d, want %d; body = %s", getRR.Code, http.StatusForbidden, getRR.Body.String())
}
// Bob cannot list team drawings
listRR := doJSON(t, api, http.MethodGet, "/drawings?team_id="+aliceTeam.ID, nil, bobCookie)
if listRR.Code != http.StatusForbidden && listRR.Code != http.StatusOK {
// Depending on implementation, non-members might get empty list or forbidden
t.Logf("non-member LIST status = %d", listRR.Code)
}
}
// TestPermissionInheritance verifies that permissions flow correctly
// through the resource hierarchy (team -> project -> folder -> drawing).
func TestPermissionInheritance(t *testing.T) {
api, cleanup := newTestAPI(t)
defer cleanup()
aliceCookie, _, aliceTeam := signup(t, api, "alice@example.com")
bobCookie, _, _ := signup(t, api, "bob@example.com")
// Invite bob as editor
inviteRR := doJSON(t, api, http.MethodPost, "/teams/"+aliceTeam.ID+"/invites", map[string]any{
"email": "bob@example.com",
"role": "editor",
}, aliceCookie)
if inviteRR.Code != http.StatusCreated {
t.Fatalf("invite status = %d body = %s", inviteRR.Code, inviteRR.Body.String())
}
var invitePayload struct {
Token string `json:"token"`
}
if err := json.Unmarshal(inviteRR.Body.Bytes(), &invitePayload); err != nil {
t.Fatalf("invite decode error = %v", err)
}
acceptRR := doJSON(t, api, http.MethodPost, "/invites/accept", map[string]string{
"token": invitePayload.Token,
}, bobCookie)
if acceptRR.Code != http.StatusOK {
t.Fatalf("accept status = %d body = %s", acceptRR.Code, acceptRR.Body.String())
}
// Alice creates a project
projectRR := doJSON(t, api, http.MethodPost, "/projects", map[string]any{
"team_id": aliceTeam.ID,
"name": "Test Project",
"description": "A test project",
}, aliceCookie)
if projectRR.Code != http.StatusCreated {
t.Fatalf("create project status = %d body = %s", projectRR.Code, projectRR.Body.String())
}
var project struct {
ID string `json:"id"`
}
if err := json.Unmarshal(projectRR.Body.Bytes(), &project); err != nil {
t.Fatalf("project decode error = %v", err)
}
// Alice creates a folder in the project
folderRR := doJSON(t, api, http.MethodPost, "/folders", map[string]any{
"team_id": aliceTeam.ID,
"project_id": project.ID,
"name": "Test Folder",
}, aliceCookie)
if folderRR.Code != http.StatusCreated {
t.Fatalf("create folder status = %d body = %s", folderRR.Code, folderRR.Body.String())
}
var folder struct {
ID string `json:"id"`
}
if err := json.Unmarshal(folderRR.Body.Bytes(), &folder); err != nil {
t.Fatalf("folder decode error = %v", err)
}
// Alice creates a drawing in the folder
drawingRR := doJSON(t, api, http.MethodPost, "/drawings", map[string]any{
"team_id": aliceTeam.ID,
"project_id": project.ID,
"folder_id": folder.ID,
"title": "Nested Drawing",
"visibility": "public",
}, aliceCookie)
if drawingRR.Code != http.StatusCreated {
t.Fatalf("create drawing status = %d body = %s", drawingRR.Code, drawingRR.Body.String())
}
var drawing Drawing
if err := json.Unmarshal(drawingRR.Body.Bytes(), &drawing); err != nil {
t.Fatalf("drawing decode error = %v", err)
}
// Bob (team editor) should be able to access the nested drawing through team membership
getRR := doJSON(t, api, http.MethodGet, "/drawings/"+drawing.ID, nil, bobCookie)
if getRR.Code != http.StatusOK {
t.Errorf("team editor GET nested drawing status = %d, want %d; body = %s",
getRR.Code, http.StatusOK, getRR.Body.String())
}
// Bob should be able to update the drawing
updateRR := doJSON(t, api, http.MethodPatch, "/drawings/"+drawing.ID, map[string]any{
"title": "Updated by team editor",
}, bobCookie)
if updateRR.Code != http.StatusOK {
t.Errorf("team editor PATCH nested drawing status = %d, want %d; body = %s",
updateRR.Code, http.StatusOK, updateRR.Body.String())
}
}
+43
View File
@@ -0,0 +1,43 @@
package workspace
import (
"sync"
"time"
)
type rateLimiter struct {
mu sync.Mutex
limit int
window time.Duration
attempts map[string][]time.Time
}
func newRateLimiter(limit int, window time.Duration) *rateLimiter {
return &rateLimiter{
limit: limit,
window: window,
attempts: make(map[string][]time.Time),
}
}
func (l *rateLimiter) allow(key string) bool {
l.mu.Lock()
defer l.mu.Unlock()
now := time.Now()
cutoff := now.Add(-l.window)
values := l.attempts[key]
kept := values[:0]
for _, value := range values {
if value.After(cutoff) {
kept = append(kept, value)
}
}
if len(kept) >= l.limit {
l.attempts[key] = kept
return false
}
kept = append(kept, now)
l.attempts[key] = kept
return true
}
+220
View File
@@ -0,0 +1,220 @@
package workspace
import (
"encoding/json"
"net/http"
"testing"
)
func TestInviteAcceptAddsEditorMembership(t *testing.T) {
api, cleanup := newTestAPI(t)
defer cleanup()
aliceCookie, _, aliceTeam := signup(t, api, "alice@example.com")
bobCookie, _, _ := signup(t, api, "bob@example.com")
inviteRR := doJSON(t, api, http.MethodPost, "/teams/"+aliceTeam.ID+"/invites", map[string]any{
"email": "bob@example.com",
"role": "editor",
}, aliceCookie)
if inviteRR.Code != http.StatusCreated {
t.Fatalf("invite status = %d body = %s", inviteRR.Code, inviteRR.Body.String())
}
var invite struct {
Invite TeamInvite `json:"invite"`
Token string `json:"token"`
}
if err := json.Unmarshal(inviteRR.Body.Bytes(), &invite); err != nil {
t.Fatalf("invite decode error = %v", err)
}
if invite.Token == "" || invite.Invite.Email != "bob@example.com" {
t.Fatalf("unexpected invite: %#v", invite)
}
beforeRR := doJSON(t, api, http.MethodGet, "/teams/"+aliceTeam.ID+"/members", nil, bobCookie)
if beforeRR.Code != http.StatusForbidden {
t.Fatalf("bob members before accept status = %d body = %s", beforeRR.Code, beforeRR.Body.String())
}
acceptRR := doJSON(t, api, http.MethodPost, "/invites/accept", map[string]string{"token": invite.Token}, bobCookie)
if acceptRR.Code != http.StatusOK {
t.Fatalf("accept status = %d body = %s", acceptRR.Code, acceptRR.Body.String())
}
afterRR := doJSON(t, api, http.MethodGet, "/teams/"+aliceTeam.ID+"/members", nil, bobCookie)
if afterRR.Code != http.StatusOK {
t.Fatalf("bob members after accept status = %d body = %s", afterRR.Code, afterRR.Body.String())
}
}
func TestRestrictedDrawingGrantAllowsViewNotEdit(t *testing.T) {
api, cleanup := newTestAPI(t)
defer cleanup()
aliceCookie, _, aliceTeam := signup(t, api, "alice@example.com")
bobCookie, _, _ := signup(t, api, "bob@example.com")
createRR := doJSON(t, api, http.MethodPost, "/drawings", map[string]any{
"team_id": aliceTeam.ID,
"title": "Restricted map",
"visibility": "restricted",
}, aliceCookie)
if createRR.Code != http.StatusCreated {
t.Fatalf("create drawing status = %d body = %s", createRR.Code, createRR.Body.String())
}
var drawing Drawing
if err := json.Unmarshal(createRR.Body.Bytes(), &drawing); err != nil {
t.Fatalf("drawing decode error = %v", err)
}
noAccessRR := doJSON(t, api, http.MethodGet, "/drawings/"+drawing.ID, nil, bobCookie)
if noAccessRR.Code != http.StatusForbidden {
t.Fatalf("bob get before grant status = %d body = %s", noAccessRR.Code, noAccessRR.Body.String())
}
grantRR := doJSON(t, api, http.MethodPost, "/drawings/"+drawing.ID+"/permissions", map[string]any{
"subject_type": "user",
"email": "bob@example.com",
"permission": "view",
}, aliceCookie)
if grantRR.Code != http.StatusCreated {
t.Fatalf("grant status = %d body = %s", grantRR.Code, grantRR.Body.String())
}
getRR := doJSON(t, api, http.MethodGet, "/drawings/"+drawing.ID, nil, bobCookie)
if getRR.Code != http.StatusOK {
t.Fatalf("bob get after grant status = %d body = %s", getRR.Code, getRR.Body.String())
}
editRR := doJSON(t, api, http.MethodPost, "/drawings/"+drawing.ID+"/revisions", map[string]any{
"snapshot": map[string]any{"type": "excalidraw", "elements": []any{}},
}, bobCookie)
if editRR.Code != http.StatusForbidden {
t.Fatalf("bob edit with view grant status = %d body = %s", editRR.Code, editRR.Body.String())
}
}
func TestShareLinkAllowsUnauthenticatedDrawingRead(t *testing.T) {
api, cleanup := newTestAPI(t)
defer cleanup()
aliceCookie, _, aliceTeam := signup(t, api, "alice@example.com")
createRR := doJSON(t, api, http.MethodPost, "/drawings", map[string]any{
"team_id": aliceTeam.ID,
"title": "Shared map",
}, aliceCookie)
if createRR.Code != http.StatusCreated {
t.Fatalf("create drawing status = %d body = %s", createRR.Code, createRR.Body.String())
}
var drawing Drawing
if err := json.Unmarshal(createRR.Body.Bytes(), &drawing); err != nil {
t.Fatalf("drawing decode error = %v", err)
}
shareRR := doJSON(t, api, http.MethodPost, "/drawings/"+drawing.ID+"/share-links", map[string]any{
"permission": "view",
}, aliceCookie)
if shareRR.Code != http.StatusCreated {
t.Fatalf("share status = %d body = %s", shareRR.Code, shareRR.Body.String())
}
var share struct {
ShareLink ShareLink `json:"share_link"`
Token string `json:"token"`
}
if err := json.Unmarshal(shareRR.Body.Bytes(), &share); err != nil {
t.Fatalf("share decode error = %v", err)
}
if share.Token == "" || share.ShareLink.TokenHash != "" {
t.Fatalf("unexpected share response: %#v", share)
}
publicRR := doJSON(t, api, http.MethodGet, "/shared/"+share.Token, nil)
if publicRR.Code != http.StatusOK {
t.Fatalf("public shared status = %d body = %s", publicRR.Code, publicRR.Body.String())
}
}
func TestEmbedsRejectUnsafeURLs(t *testing.T) {
api, cleanup := newTestAPI(t)
defer cleanup()
cookie, _, team := signup(t, api, "alice@example.com")
createRR := doJSON(t, api, http.MethodPost, "/drawings", map[string]any{
"team_id": team.ID,
"title": "Embed map",
}, cookie)
if createRR.Code != http.StatusCreated {
t.Fatalf("create drawing status = %d body = %s", createRR.Code, createRR.Body.String())
}
var drawing Drawing
if err := json.Unmarshal(createRR.Body.Bytes(), &drawing); err != nil {
t.Fatalf("drawing decode error = %v", err)
}
for _, unsafeURL := range []string{"javascript:alert(1)", "http://127.0.0.1/admin", "http://localhost:3002"} {
rr := doJSON(t, api, http.MethodPost, "/drawings/"+drawing.ID+"/embeds", map[string]any{
"source_url": unsafeURL,
"embed_type": "link",
}, cookie)
if rr.Code != http.StatusBadRequest {
t.Fatalf("unsafe url %q status = %d body = %s", unsafeURL, rr.Code, rr.Body.String())
}
}
safeRR := doJSON(t, api, http.MethodPost, "/drawings/"+drawing.ID+"/embeds", map[string]any{
"source_url": "https://example.com/roadmap",
"embed_type": "link",
}, cookie)
if safeRR.Code != http.StatusCreated {
t.Fatalf("safe embed status = %d body = %s", safeRR.Code, safeRR.Body.String())
}
}
func TestAssetsAndLinkReferences(t *testing.T) {
api, cleanup := newTestAPI(t)
defer cleanup()
cookie, _, team := signup(t, api, "alice@example.com")
projectRR := doJSON(t, api, http.MethodPost, "/projects", map[string]any{
"team_id": team.ID,
"name": "Roadmap",
}, cookie)
if projectRR.Code != http.StatusCreated {
t.Fatalf("project status = %d body = %s", projectRR.Code, projectRR.Body.String())
}
var project Project
if err := json.Unmarshal(projectRR.Body.Bytes(), &project); err != nil {
t.Fatalf("project decode error = %v", err)
}
drawingRR := doJSON(t, api, http.MethodPost, "/drawings", map[string]any{
"team_id": team.ID,
"title": "Linked map",
}, cookie)
if drawingRR.Code != http.StatusCreated {
t.Fatalf("drawing status = %d body = %s", drawingRR.Code, drawingRR.Body.String())
}
var drawing Drawing
if err := json.Unmarshal(drawingRR.Body.Bytes(), &drawing); err != nil {
t.Fatalf("drawing decode error = %v", err)
}
assetRR := doJSON(t, api, http.MethodPost, "/drawings/"+drawing.ID+"/assets", map[string]any{
"kind": "attachment",
"mime_type": "image/png",
"size": 2048,
"width": 800,
"height": 600,
}, cookie)
if assetRR.Code != http.StatusCreated {
t.Fatalf("asset status = %d body = %s", assetRR.Code, assetRR.Body.String())
}
linkRR := doJSON(t, api, http.MethodPost, "/drawings/"+drawing.ID+"/links", map[string]any{
"target_resource_type": "project",
"target_resource_id": project.ID,
"label": "Roadmap project",
}, cookie)
if linkRR.Code != http.StatusCreated {
t.Fatalf("link status = %d body = %s", linkRR.Code, linkRR.Body.String())
}
}
+53
View File
@@ -0,0 +1,53 @@
package workspace
import (
"context"
)
func (s *Store) WorkspaceStats(ctx context.Context, userID, teamID string) (*WorkspaceStats, error) {
if teamID != "" {
if ok, err := s.UserCanAccessTeam(ctx, userID, teamID); err != nil || !ok {
return nil, ErrForbidden
}
return s.workspaceStatsForWhere(ctx, `m.user_id = ? AND t.id = ?`, userID, teamID)
}
return s.workspaceStatsForWhere(ctx, `m.user_id = ?`, userID)
}
func (s *Store) workspaceStatsForWhere(ctx context.Context, membershipWhere string, args ...any) (*WorkspaceStats, error) {
stats := &WorkspaceStats{}
teamWhere := `id IN (SELECT t.id FROM workspace_teams t JOIN workspace_team_memberships m ON m.team_id = t.id WHERE ` + membershipWhere + `)`
if err := s.count(ctx, &stats.Teams, `SELECT COUNT(*) FROM workspace_teams WHERE `+teamWhere, args...); err != nil {
return nil, err
}
if err := s.count(ctx, &stats.Members, `SELECT COUNT(*) FROM workspace_team_memberships WHERE team_id IN (SELECT id FROM workspace_teams WHERE `+teamWhere+`)`, args...); err != nil {
return nil, err
}
if err := s.count(ctx, &stats.Projects, `SELECT COUNT(*) FROM workspace_projects WHERE team_id IN (SELECT id FROM workspace_teams WHERE `+teamWhere+`)`, args...); err != nil {
return nil, err
}
if err := s.count(ctx, &stats.Folders, `SELECT COUNT(*) FROM workspace_folders WHERE team_id IN (SELECT id FROM workspace_teams WHERE `+teamWhere+`)`, args...); err != nil {
return nil, err
}
if err := s.count(ctx, &stats.Drawings, `SELECT COUNT(*) FROM workspace_drawings WHERE deleted_at IS NULL AND team_id IN (SELECT id FROM workspace_teams WHERE `+teamWhere+`)`, args...); err != nil {
return nil, err
}
if err := s.count(ctx, &stats.Templates, `SELECT COUNT(*) FROM workspace_templates WHERE scope = 'system' OR team_id IN (SELECT id FROM workspace_teams WHERE `+teamWhere+`)`, args...); err != nil {
return nil, err
}
if err := s.count(ctx, &stats.Revisions, `SELECT COUNT(*) FROM workspace_drawing_revisions WHERE drawing_id IN (SELECT id FROM workspace_drawings WHERE team_id IN (SELECT id FROM workspace_teams WHERE `+teamWhere+`))`, args...); err != nil {
return nil, err
}
if err := s.count(ctx, &stats.Assets, `SELECT COUNT(*) FROM workspace_drawing_assets WHERE drawing_id IN (SELECT id FROM workspace_drawings WHERE team_id IN (SELECT id FROM workspace_teams WHERE `+teamWhere+`))`, args...); err != nil {
return nil, err
}
err := s.db.QueryRowContext(ctx, `SELECT COALESCE(SUM(snapshot_size), 0) FROM workspace_drawing_revisions WHERE drawing_id IN (SELECT id FROM workspace_drawings WHERE team_id IN (SELECT id FROM workspace_teams WHERE `+teamWhere+`))`, args...).Scan(&stats.StorageBytes)
if err != nil {
return nil, err
}
return stats, nil
}
func (s *Store) count(ctx context.Context, dest *int, query string, args ...any) error {
return s.db.QueryRowContext(ctx, query, args...).Scan(dest)
}
+1205
View File
File diff suppressed because it is too large Load Diff
+725
View File
@@ -0,0 +1,725 @@
package workspace
import (
"context"
"database/sql"
"errors"
"fmt"
"net"
"net/netip"
"net/url"
"strings"
"time"
)
type CreateInviteRequest struct {
Email string `json:"email"`
Role string `json:"role"`
}
type CreatePermissionGrantRequest struct {
SubjectType string `json:"subject_type"`
SubjectID string `json:"subject_id"`
Email string `json:"email"`
Permission string `json:"permission"`
}
type CreateShareLinkRequest struct {
Permission string `json:"permission"`
ExpiresAt *time.Time `json:"expires_at"`
}
type CreateAssetRequest struct {
Kind string `json:"kind"`
MimeType string `json:"mime_type"`
Size int64 `json:"size"`
Width *int `json:"width"`
Height *int `json:"height"`
}
type CreateEmbedRequest struct {
SourceURL string `json:"source_url"`
EmbedType string `json:"embed_type"`
Title *string `json:"title"`
}
type CreateLinkRequest struct {
TargetResourceType string `json:"target_resource_type"`
TargetResourceID string `json:"target_resource_id"`
Label *string `json:"label"`
}
func (s *Store) ListTeamInvites(ctx context.Context, userID, teamID string) ([]TeamInvite, error) {
if err := s.ensureTeamPermission(ctx, userID, teamID, "invite"); err != nil {
return nil, err
}
rows, err := s.db.QueryContext(ctx, `SELECT id, team_id, email, role, invited_by, expires_at, created_at
FROM workspace_team_invites
WHERE team_id = ? AND accepted_at IS NULL AND revoked_at IS NULL AND expires_at > ?
ORDER BY created_at DESC`, teamID, time.Now().UTC())
if err != nil {
return nil, err
}
defer rows.Close()
invites := []TeamInvite{}
for rows.Next() {
var invite TeamInvite
if err := rows.Scan(&invite.ID, &invite.TeamID, &invite.Email, &invite.Role, &invite.InvitedBy, &invite.ExpiresAt, &invite.CreatedAt); err != nil {
return nil, err
}
invites = append(invites, invite)
}
return invites, rows.Err()
}
func (s *Store) CreateTeamInvite(ctx context.Context, userID, teamID string, req CreateInviteRequest) (*TeamInvite, string, error) {
if err := s.ensureTeamPermission(ctx, userID, teamID, "invite"); err != nil {
return nil, "", err
}
email, err := normalizeEmail(req.Email)
if err != nil {
return nil, "", err
}
role := req.Role
if role == "" {
role = "viewer"
}
if !validTeamRole(role) || role == "owner" {
return nil, "", fmt.Errorf("invalid invite role")
}
token, err := randomToken()
if err != nil {
return nil, "", err
}
now := time.Now().UTC()
invite := &TeamInvite{
ID: newID(),
TeamID: teamID,
Email: email,
Role: role,
InvitedBy: userID,
ExpiresAt: now.Add(14 * 24 * time.Hour),
CreatedAt: now,
}
_, err = s.db.ExecContext(ctx, `INSERT INTO workspace_team_invites
(id, team_id, email, role, token_hash, invited_by, expires_at, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
invite.ID, invite.TeamID, invite.Email, invite.Role, hashToken(token), invite.InvitedBy, invite.ExpiresAt, invite.CreatedAt,
)
if err != nil {
return nil, "", err
}
_ = s.insertActivity(ctx, &userID, &teamID, "team", teamID, "member_invited", map[string]any{"email": email, "role": role})
return invite, token, nil
}
func (s *Store) AcceptInvite(ctx context.Context, userID, token string) (*TeamMembership, error) {
if strings.TrimSpace(token) == "" {
return nil, fmt.Errorf("invite token is required")
}
user, err := s.userByID(ctx, userID)
if err != nil {
return nil, err
}
var invite TeamInvite
err = s.db.QueryRowContext(ctx, `SELECT id, team_id, email, role, invited_by, expires_at, created_at
FROM workspace_team_invites
WHERE token_hash = ? AND accepted_at IS NULL AND revoked_at IS NULL AND expires_at > ?`,
hashToken(token), time.Now().UTC(),
).Scan(&invite.ID, &invite.TeamID, &invite.Email, &invite.Role, &invite.InvitedBy, &invite.ExpiresAt, &invite.CreatedAt)
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrNotFound
}
if err != nil {
return nil, err
}
if !strings.EqualFold(user.Email, invite.Email) {
return nil, ErrForbidden
}
now := time.Now().UTC()
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
_, err = tx.ExecContext(ctx, `INSERT INTO workspace_team_memberships (id, team_id, user_id, role, joined_at)
VALUES (?, ?, ?, ?, ?)
ON CONFLICT(team_id, user_id) DO UPDATE SET role = excluded.role`, newID(), invite.TeamID, userID, invite.Role, now)
if err != nil {
return nil, err
}
_, err = tx.ExecContext(ctx, `UPDATE workspace_team_invites SET accepted_at = ? WHERE id = ?`, now, invite.ID)
if err != nil {
return nil, err
}
if err := insertActivityTx(ctx, tx, &userID, &invite.TeamID, "team", invite.TeamID, "member_joined", map[string]any{"role": invite.Role}); err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
return s.teamMembership(ctx, invite.TeamID, userID)
}
func (s *Store) ListPermissionGrants(ctx context.Context, userID, resourceType, resourceID string) ([]PermissionGrant, error) {
if resourceType == "drawing" {
if err := s.ensureDrawingAccess(ctx, userID, resourceID, "manage"); err != nil {
return nil, err
}
}
rows, err := s.db.QueryContext(ctx, `SELECT id, resource_type, resource_id, subject_type, subject_id, permission, inherited_from, created_at
FROM workspace_permission_grants
WHERE resource_type = ? AND resource_id = ?
ORDER BY created_at DESC`, resourceType, resourceID)
if err != nil {
return nil, err
}
defer rows.Close()
grants := []PermissionGrant{}
for rows.Next() {
var grant PermissionGrant
if err := rows.Scan(&grant.ID, &grant.ResourceType, &grant.ResourceID, &grant.SubjectType, &grant.SubjectID, &grant.Permission, &grant.InheritedFrom, &grant.CreatedAt); err != nil {
return nil, err
}
grants = append(grants, grant)
}
return grants, rows.Err()
}
func (s *Store) CreateDrawingPermissionGrant(ctx context.Context, userID, drawingID string, req CreatePermissionGrantRequest) (*PermissionGrant, error) {
if err := s.ensureDrawingAccess(ctx, userID, drawingID, "share"); err != nil {
return nil, err
}
if !validPermission(req.Permission) {
return nil, fmt.Errorf("invalid permission")
}
subjectType := req.SubjectType
if subjectType == "" {
subjectType = "user"
}
subjectID := strings.TrimSpace(req.SubjectID)
if subjectType == "user" && subjectID == "" {
user, err := s.userByEmail(ctx, req.Email)
if err != nil {
return nil, err
}
subjectID = user.ID
}
if subjectType != "user" && subjectType != "team" {
return nil, fmt.Errorf("invalid subject type")
}
if subjectID == "" {
return nil, fmt.Errorf("subject id is required")
}
now := time.Now().UTC()
grant := &PermissionGrant{
ID: newID(),
ResourceType: "drawing",
ResourceID: drawingID,
SubjectType: subjectType,
SubjectID: subjectID,
Permission: req.Permission,
CreatedAt: now,
}
_, err := s.db.ExecContext(ctx, `INSERT INTO workspace_permission_grants
(id, resource_type, resource_id, subject_type, subject_id, permission, inherited_from, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(resource_type, resource_id, subject_type, subject_id, permission) DO UPDATE SET created_at = excluded.created_at`,
grant.ID, grant.ResourceType, grant.ResourceID, grant.SubjectType, grant.SubjectID, grant.Permission, grant.InheritedFrom, grant.CreatedAt,
)
if err != nil {
return nil, err
}
_ = s.insertActivity(ctx, &userID, nil, "drawing", drawingID, "permission_changed", map[string]any{"permission": grant.Permission, "subject_type": grant.SubjectType})
return grant, nil
}
func (s *Store) ListShareLinks(ctx context.Context, userID, resourceType, resourceID string) ([]ShareLink, error) {
if resourceType == "drawing" {
if err := s.ensureDrawingAccess(ctx, userID, resourceID, "share"); err != nil {
return nil, err
}
}
rows, err := s.db.QueryContext(ctx, `SELECT id, resource_type, resource_id, token_hash, permission, expires_at, password_hash, created_by, revoked_at, created_at
FROM workspace_share_links
WHERE resource_type = ? AND resource_id = ? AND revoked_at IS NULL
ORDER BY created_at DESC`, resourceType, resourceID)
if err != nil {
return nil, err
}
defer rows.Close()
links := []ShareLink{}
for rows.Next() {
link, err := scanShareLink(rows)
if err != nil {
return nil, err
}
link.TokenHash = ""
links = append(links, *link)
}
return links, rows.Err()
}
func (s *Store) CreateDrawingShareLink(ctx context.Context, userID, drawingID string, req CreateShareLinkRequest) (*ShareLink, string, error) {
if err := s.ensureDrawingAccess(ctx, userID, drawingID, "share"); err != nil {
return nil, "", err
}
permission := req.Permission
if permission == "" {
permission = "view"
}
if permission != "view" && permission != "comment" && permission != "edit" {
return nil, "", fmt.Errorf("invalid share permission")
}
token, err := randomToken()
if err != nil {
return nil, "", err
}
now := time.Now().UTC()
link := &ShareLink{
ID: newID(),
ResourceType: "drawing",
ResourceID: drawingID,
TokenHash: hashToken(token),
Permission: permission,
ExpiresAt: req.ExpiresAt,
CreatedBy: userID,
CreatedAt: now,
}
_, err = s.db.ExecContext(ctx, `INSERT INTO workspace_share_links
(id, resource_type, resource_id, token_hash, permission, expires_at, password_hash, created_by, revoked_at, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
link.ID, link.ResourceType, link.ResourceID, link.TokenHash, link.Permission, link.ExpiresAt, link.PasswordHash, link.CreatedBy, link.RevokedAt, link.CreatedAt,
)
if err != nil {
return nil, "", err
}
link.TokenHash = ""
_ = s.insertActivity(ctx, &userID, nil, "drawing", drawingID, "drawing_shared", map[string]any{"permission": permission})
return link, token, nil
}
func (s *Store) SharedResourceByToken(ctx context.Context, token string) (map[string]any, error) {
if strings.TrimSpace(token) == "" {
return nil, ErrNotFound
}
row := s.db.QueryRowContext(ctx, `SELECT id, resource_type, resource_id, token_hash, permission, expires_at, password_hash, created_by, revoked_at, created_at
FROM workspace_share_links
WHERE token_hash = ? AND revoked_at IS NULL AND (expires_at IS NULL OR expires_at > ?)`,
hashToken(token), time.Now().UTC(),
)
link, err := scanShareLink(row)
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrNotFound
}
if err != nil {
return nil, err
}
link.TokenHash = ""
payload := map[string]any{"share_link": link}
if link.ResourceType == "drawing" {
drawing, err := s.drawingByIDNoAuth(ctx, link.ResourceID)
if err != nil {
return nil, err
}
payload["drawing"] = drawing
return payload, nil
}
return payload, nil
}
func (s *Store) ListDrawingAssets(ctx context.Context, userID, drawingID string) ([]DrawingAsset, error) {
if err := s.ensureDrawingAccess(ctx, userID, drawingID, "view"); err != nil {
return nil, err
}
rows, err := s.db.QueryContext(ctx, `SELECT id, drawing_id, kind, path, mime_type, size, width, height, uploaded_by, created_at
FROM workspace_drawing_assets WHERE drawing_id = ? ORDER BY created_at DESC`, drawingID)
if err != nil {
return nil, err
}
defer rows.Close()
return scanAssets(rows)
}
func (s *Store) CreateDrawingAsset(ctx context.Context, userID, drawingID string, req CreateAssetRequest) (*DrawingAsset, error) {
if err := s.ensureDrawingAccess(ctx, userID, drawingID, "edit"); err != nil {
return nil, err
}
if !validAssetKind(req.Kind) {
return nil, fmt.Errorf("invalid asset kind")
}
if !validAssetMIME(req.MimeType) {
return nil, fmt.Errorf("invalid asset mime type")
}
if req.Size <= 0 || req.Size > 25<<20 {
return nil, fmt.Errorf("asset size must be between 1 byte and 25 MiB")
}
drawing, err := s.drawingByIDNoAuth(ctx, drawingID)
if err != nil {
return nil, err
}
now := time.Now().UTC()
asset := &DrawingAsset{
ID: newID(),
DrawingID: drawingID,
Kind: req.Kind,
MimeType: req.MimeType,
Size: req.Size,
Width: req.Width,
Height: req.Height,
UploadedBy: userID,
CreatedAt: now,
}
asset.Path = fmt.Sprintf("/data/teams/%s/drawings/%s/assets/%s", drawing.TeamID, drawingID, asset.ID)
_, err = s.db.ExecContext(ctx, `INSERT INTO workspace_drawing_assets
(id, drawing_id, kind, path, mime_type, size, width, height, uploaded_by, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
asset.ID, asset.DrawingID, asset.Kind, asset.Path, asset.MimeType, asset.Size, asset.Width, asset.Height, asset.UploadedBy, asset.CreatedAt,
)
if err != nil {
return nil, err
}
return asset, nil
}
func (s *Store) ListEmbeds(ctx context.Context, userID, drawingID string) ([]Embed, error) {
if err := s.ensureDrawingAccess(ctx, userID, drawingID, "view"); err != nil {
return nil, err
}
rows, err := s.db.QueryContext(ctx, `SELECT id, drawing_id, source_url, canonical_url, provider, embed_type, title, preview_asset_id, safe_embed_html, created_by, created_at
FROM workspace_embeds WHERE drawing_id = ? ORDER BY created_at DESC`, drawingID)
if err != nil {
return nil, err
}
defer rows.Close()
embeds := []Embed{}
for rows.Next() {
var embed Embed
if err := rows.Scan(&embed.ID, &embed.DrawingID, &embed.SourceURL, &embed.CanonicalURL, &embed.Provider, &embed.EmbedType, &embed.Title, &embed.PreviewAssetID, &embed.SafeEmbedHTML, &embed.CreatedBy, &embed.CreatedAt); err != nil {
return nil, err
}
embeds = append(embeds, embed)
}
return embeds, rows.Err()
}
func (s *Store) CreateEmbed(ctx context.Context, userID, drawingID string, req CreateEmbedRequest) (*Embed, error) {
if err := s.ensureDrawingAccess(ctx, userID, drawingID, "edit"); err != nil {
return nil, err
}
canonical, provider, err := validateEmbedURL(req.SourceURL)
if err != nil {
return nil, err
}
embedType := req.EmbedType
if embedType == "" {
embedType = "link"
}
if embedType != "link" && embedType != "iframe" && embedType != "provider" {
return nil, fmt.Errorf("invalid embed type")
}
now := time.Now().UTC()
embed := &Embed{
ID: newID(),
DrawingID: drawingID,
SourceURL: canonical,
CanonicalURL: canonical,
Provider: provider,
EmbedType: embedType,
Title: req.Title,
CreatedBy: userID,
CreatedAt: now,
}
_, err = s.db.ExecContext(ctx, `INSERT INTO workspace_embeds
(id, drawing_id, source_url, canonical_url, provider, embed_type, title, preview_asset_id, safe_embed_html, created_by, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
embed.ID, embed.DrawingID, embed.SourceURL, embed.CanonicalURL, embed.Provider, embed.EmbedType, embed.Title, embed.PreviewAssetID, embed.SafeEmbedHTML, embed.CreatedBy, embed.CreatedAt,
)
if err != nil {
return nil, err
}
_ = s.insertActivity(ctx, &userID, nil, "drawing", drawingID, "embed_created", map[string]any{"provider": provider, "embed_type": embedType})
return embed, nil
}
func (s *Store) ListLinkReferences(ctx context.Context, userID, resourceType, resourceID string) ([]LinkReference, error) {
if resourceType == "drawing" {
if err := s.ensureDrawingAccess(ctx, userID, resourceID, "view"); err != nil {
return nil, err
}
}
rows, err := s.db.QueryContext(ctx, `SELECT id, source_resource_type, source_resource_id, target_resource_type, target_resource_id, label, created_by, created_at
FROM workspace_link_references
WHERE source_resource_type = ? AND source_resource_id = ?
ORDER BY created_at DESC`, resourceType, resourceID)
if err != nil {
return nil, err
}
defer rows.Close()
links := []LinkReference{}
for rows.Next() {
var link LinkReference
if err := rows.Scan(&link.ID, &link.SourceResourceType, &link.SourceResourceID, &link.TargetResourceType, &link.TargetResourceID, &link.Label, &link.CreatedBy, &link.CreatedAt); err != nil {
return nil, err
}
links = append(links, link)
}
return links, rows.Err()
}
func (s *Store) CreateDrawingLinkReference(ctx context.Context, userID, drawingID string, req CreateLinkRequest) (*LinkReference, error) {
if err := s.ensureDrawingAccess(ctx, userID, drawingID, "edit"); err != nil {
return nil, err
}
drawing, err := s.drawingByIDNoAuth(ctx, drawingID)
if err != nil {
return nil, err
}
if err := s.ensureTargetInTeam(ctx, drawing.TeamID, req.TargetResourceType, req.TargetResourceID); err != nil {
return nil, err
}
now := time.Now().UTC()
link := &LinkReference{
ID: newID(),
SourceResourceType: "drawing",
SourceResourceID: drawingID,
TargetResourceType: req.TargetResourceType,
TargetResourceID: req.TargetResourceID,
Label: req.Label,
CreatedBy: userID,
CreatedAt: now,
}
_, err = s.db.ExecContext(ctx, `INSERT INTO workspace_link_references
(id, source_resource_type, source_resource_id, target_resource_type, target_resource_id, label, created_by, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
link.ID, link.SourceResourceType, link.SourceResourceID, link.TargetResourceType, link.TargetResourceID, link.Label, link.CreatedBy, link.CreatedAt,
)
if err != nil {
return nil, err
}
return link, nil
}
func (s *Store) ensureTeamPermission(ctx context.Context, userID, teamID, permission string) error {
role, err := s.teamRole(ctx, userID, teamID)
if errors.Is(err, sql.ErrNoRows) {
return ErrForbidden
}
if err != nil {
return err
}
if roleAllows(role, permission) {
return nil
}
return ErrForbidden
}
func (s *Store) teamRole(ctx context.Context, userID, teamID string) (string, error) {
var role string
err := s.db.QueryRowContext(ctx, `SELECT role FROM workspace_team_memberships WHERE user_id = ? AND team_id = ?`, userID, teamID).Scan(&role)
return role, err
}
func (s *Store) userByID(ctx context.Context, userID string) (*User, error) {
var user User
err := s.db.QueryRowContext(ctx, `SELECT id, name, username, email, avatar_url, locale, timezone, created_at, updated_at FROM workspace_users WHERE id = ?`, userID).
Scan(&user.ID, &user.Name, &user.Username, &user.Email, &user.AvatarURL, &user.Locale, &user.Timezone, &user.CreatedAt, &user.UpdatedAt)
return &user, err
}
func (s *Store) userByEmail(ctx context.Context, email string) (*User, error) {
email, err := normalizeEmail(email)
if err != nil {
return nil, err
}
var user User
err = s.db.QueryRowContext(ctx, `SELECT id, name, username, email, avatar_url, locale, timezone, created_at, updated_at FROM workspace_users WHERE email = ?`, email).
Scan(&user.ID, &user.Name, &user.Username, &user.Email, &user.AvatarURL, &user.Locale, &user.Timezone, &user.CreatedAt, &user.UpdatedAt)
return &user, err
}
func (s *Store) teamMembership(ctx context.Context, teamID, userID string) (*TeamMembership, error) {
var member TeamMembership
err := s.db.QueryRowContext(ctx, `SELECT id, team_id, user_id, role, joined_at FROM workspace_team_memberships WHERE team_id = ? AND user_id = ?`, teamID, userID).
Scan(&member.ID, &member.TeamID, &member.UserID, &member.Role, &member.JoinedAt)
if err != nil {
return nil, err
}
user, err := s.userByID(ctx, userID)
if err != nil {
return nil, err
}
member.User = user
return &member, nil
}
func (s *Store) drawingByIDNoAuth(ctx context.Context, drawingID string) (*Drawing, error) {
row := s.db.QueryRowContext(ctx, `SELECT d.id, d.team_id, d.folder_id, d.project_id, d.slug, d.title, d.description,
d.owner_user_id, d.latest_revision_id, d.visibility, d.is_archived, d.thumbnail_asset_id, d.created_at, d.updated_at, d.deleted_at,
u.id, u.name, u.username, u.email, u.avatar_url, u.locale, u.timezone, u.created_at, u.updated_at
FROM workspace_drawings d
JOIN workspace_users u ON u.id = d.owner_user_id
WHERE d.id = ? AND d.deleted_at IS NULL`, drawingID)
return scanDrawing(row)
}
func (s *Store) ensureTargetInTeam(ctx context.Context, teamID, resourceType, resourceID string) error {
var found int
var query string
switch resourceType {
case "drawing":
query = `SELECT 1 FROM workspace_drawings WHERE id = ? AND team_id = ? AND deleted_at IS NULL`
case "folder":
query = `SELECT 1 FROM workspace_folders WHERE id = ? AND team_id = ?`
case "project":
query = `SELECT 1 FROM workspace_projects WHERE id = ? AND team_id = ?`
case "embed":
query = `SELECT 1 FROM workspace_embeds e JOIN workspace_drawings d ON d.id = e.drawing_id WHERE e.id = ? AND d.team_id = ?`
default:
return fmt.Errorf("invalid target resource type")
}
err := s.db.QueryRowContext(ctx, query, resourceID, teamID).Scan(&found)
if errors.Is(err, sql.ErrNoRows) {
return ErrNotFound
}
return err
}
func (s *Store) grantAllows(ctx context.Context, resourceType, resourceID, userID, teamID, required string) (bool, error) {
rows, err := s.db.QueryContext(ctx, `SELECT permission FROM workspace_permission_grants
WHERE resource_type = ? AND resource_id = ? AND (
(subject_type = 'user' AND subject_id = ?) OR
(subject_type = 'team' AND subject_id = ?)
)`, resourceType, resourceID, userID, teamID)
if err != nil {
return false, err
}
defer rows.Close()
for rows.Next() {
var permission string
if err := rows.Scan(&permission); err != nil {
return false, err
}
if permissionAllows(permission, required) {
return true, nil
}
}
return false, rows.Err()
}
func scanShareLink(scanner interface{ Scan(dest ...any) error }) (*ShareLink, error) {
var link ShareLink
err := scanner.Scan(&link.ID, &link.ResourceType, &link.ResourceID, &link.TokenHash, &link.Permission, &link.ExpiresAt, &link.PasswordHash, &link.CreatedBy, &link.RevokedAt, &link.CreatedAt)
if err != nil {
return nil, err
}
return &link, nil
}
func scanAssets(rows *sql.Rows) ([]DrawingAsset, error) {
assets := []DrawingAsset{}
for rows.Next() {
var asset DrawingAsset
if err := rows.Scan(&asset.ID, &asset.DrawingID, &asset.Kind, &asset.Path, &asset.MimeType, &asset.Size, &asset.Width, &asset.Height, &asset.UploadedBy, &asset.CreatedAt); err != nil {
return nil, err
}
assets = append(assets, asset)
}
return assets, rows.Err()
}
func roleAllows(role, permission string) bool {
switch role {
case "owner", "admin":
return true
case "editor":
return permission == "view" || permission == "comment" || permission == "edit"
case "viewer":
return permission == "view"
default:
return false
}
}
func permissionAllows(grant, required string) bool {
if grant == required {
return true
}
switch grant {
case "manage":
return true
case "edit":
return required == "view" || required == "comment" || required == "edit"
case "comment":
return required == "view" || required == "comment"
case "share":
return required == "view" || required == "share"
case "invite":
return required == "view" || required == "invite"
case "view":
return required == "view"
default:
return false
}
}
func validPermission(permission string) bool {
switch permission {
case "view", "comment", "edit", "manage", "share", "invite":
return true
default:
return false
}
}
func validTeamRole(role string) bool {
switch role {
case "owner", "admin", "editor", "viewer":
return true
default:
return false
}
}
func validAssetKind(kind string) bool {
switch kind {
case "image", "export", "attachment", "thumbnail":
return true
default:
return false
}
}
func validAssetMIME(mimeType string) bool {
switch mimeType {
case "image/png", "image/jpeg", "image/webp", "image/gif", "application/pdf", "application/json", "text/plain":
return true
default:
return false
}
}
func validateEmbedURL(raw string) (string, string, error) {
parsed, err := url.Parse(strings.TrimSpace(raw))
if err != nil {
return "", "", fmt.Errorf("invalid URL")
}
if parsed.Scheme != "http" && parsed.Scheme != "https" {
return "", "", fmt.Errorf("embed URL must use http or https")
}
if parsed.User != nil {
return "", "", fmt.Errorf("embed URL must not include credentials")
}
host := strings.ToLower(parsed.Hostname())
if host == "" || host == "localhost" || strings.HasSuffix(host, ".localhost") {
return "", "", fmt.Errorf("embed URL host is not allowed")
}
if ip := net.ParseIP(host); ip != nil {
addr, ok := netip.AddrFromSlice(ip)
if !ok || !addr.IsGlobalUnicast() || addr.IsPrivate() || addr.IsLoopback() || addr.IsLinkLocalUnicast() {
return "", "", fmt.Errorf("embed URL host is not allowed")
}
}
parsed.Fragment = ""
return parsed.String(), host, nil
}