// Command main runs the Excalidraw backend: it serves the embedded frontend,
// the Firebase-compatible and workspace HTTP APIs, the auth endpoints and the
// socket.io collaboration server from a single process.
package main

import (
	"embed"
	_ "embed"
	"excalidraw-complete/handlers/api/firebase"
	"excalidraw-complete/handlers/api/kv"
	"excalidraw-complete/handlers/auth"
	authMiddleware "excalidraw-complete/middleware"
	"excalidraw-complete/stores"
	"excalidraw-complete/workspace"
	"flag"
	"fmt"
	"io"
	"io/fs"
	"net/http"
	"os"
	"os/signal"
	"strings"
	"syscall"

	"github.com/go-chi/chi/v5"
	"github.com/go-chi/chi/v5/middleware"
	"github.com/go-chi/cors"
	"github.com/joho/godotenv"
	"github.com/sirupsen/logrus"
	"github.com/zishang520/engine.io/v2/types"
	"github.com/zishang520/engine.io/v2/utils"
	socketio "github.com/zishang520/socket.io/v2/socket"
)

type (
	// UserToFollow identifies a collaborator by socket id and display name.
	UserToFollow struct {
		SocketId string `json:"socketId"`
		Username string `json:"username"`
	}

	// OnUserFollowedPayload is the JSON shape of a follow/unfollow event.
	// It is not yet consumed: the "user-follow" handler below is a stub.
	OnUserFollowedPayload struct {
		UserToFollow UserToFollow `json:"userToFollow"`
		Action       string       `json:"action"` // "FOLLOW" | "UNFOLLOW"
	}
)

//go:embed all:frontend
var assets embed.FS

// handleUI returns a handler that serves the embedded frontend.
//
// Resolution rules:
//   - "/" and "" map to index.html.
//   - A missing path without a "." in it is treated as a client-side route and
//     answered with index.html.
//   - Any other open failure is a 404.
//
// Every served file is passed through rewriteUIAsset before being written.
// It panics at construction time if the embedded "frontend" directory cannot
// be opened, since the server is unusable without it.
func handleUI() http.HandlerFunc {
	sub, err := fs.Sub(assets, "frontend")
	if err != nil {
		panic(err)
	}

	return func(w http.ResponseWriter, r *http.Request) {
		path := r.URL.Path
		if path == "/" || path == "" {
			path = "/index.html"
		}

		f, err := sub.Open(strings.TrimPrefix(path, "/"))
		if err != nil {
			// Only extension-less, non-existent paths are assumed to be
			// client-side routes; a missing ".js"/".css"/... is a real 404.
			if !os.IsNotExist(err) || strings.Contains(path, ".") {
				http.NotFound(w, r)
				return
			}
			path = "/index.html"
			f, err = sub.Open("index.html")
			if err != nil {
				http.Error(w, "File not found", http.StatusNotFound)
				return
			}
		}
		defer f.Close()

		fileContent, err := io.ReadAll(f)
		if err != nil {
			http.Error(w, "Error reading file", http.StatusInternalServerError)
			return
		}

		// Point the frontend at the host of the current request so that it
		// also works behind a reverse proxy or under a different domain.
		// NOTE(review): r.Host is client-controlled; set
		// EXCALIDRAW_BACKEND_HOST when responses may be cached or shared.
		backendHost := os.Getenv("EXCALIDRAW_BACKEND_HOST")
		if backendHost == "" {
			backendHost = r.Host
		}
		body := rewriteUIAsset(string(fileContent), backendHost)

		w.Header().Set("Content-Type", uiContentType(path, body))
		if _, err := io.WriteString(w, body); err != nil {
			// The response has already started, so the status can no longer
			// be changed; the usual cause is a client that went away.
			logrus.WithError(err).WithField("path", path).Debug("error serving file")
		}
	}
}

// rewriteUIAsset redirects the bundled Firestore endpoint to backendHost and
// turns the minified `ssl=!0` / `ssl:!0` flags into `ssl=0` / `ssl:0`.
func rewriteUIAsset(content, backendHost string) string {
	content = strings.ReplaceAll(content, "firestore.googleapis.com", backendHost)
	content = strings.ReplaceAll(content, "ssl=!0", "ssl=0")
	content = strings.ReplaceAll(content, "ssl:!0", "ssl:0")
	return content
}

// uiContentType picks the Content-Type for an asset from its file extension,
// falling back to content sniffing for unknown extensions. Explicit types
// matter because securityHeaders sends "X-Content-Type-Options: nosniff".
func uiContentType(path, content string) string {
	switch {
	case strings.HasSuffix(path, ".js"):
		return "application/javascript"
	case strings.HasSuffix(path, ".html"):
		return "text/html"
	case strings.HasSuffix(path, ".css"):
		return "text/css"
	case strings.HasSuffix(path, ".wasm"):
		return "application/wasm"
	case strings.HasSuffix(path, ".tsx"):
		return "text/typescript"
	case strings.HasSuffix(path, ".png"):
		return "image/png"
	case strings.HasSuffix(path, ".woff2"):
		return "font/woff2"
	case strings.HasSuffix(path, ".svg"):
		return "image/svg+xml"
	case strings.HasSuffix(path, ".json"):
		return "application/json"
	}

	// DetectContentType inspects at most the first 512 bytes, so avoid
	// copying the whole asset just to sniff it.
	sniff := content
	if len(sniff) > 512 {
		sniff = sniff[:512]
	}
	return http.DetectContentType([]byte(sniff))
}

// setupRouter builds the HTTP router: request logging, security headers and
// CORS middleware, the Firebase-compatible document endpoints, the workspace
// API (mounted under /api when workspaceAPI is non-nil), the JWT-protected
// /api/v2/kv canvas routes and the /auth login/callback routes.
func setupRouter(store stores.Store, workspaceAPI *workspace.API) *chi.Mux {
	r := chi.NewRouter()
	r.Use(middleware.Logger)
	r.Use(securityHeaders)
	r.Use(cors.Handler(cors.Options{
		AllowedOrigins:   allowedOrigins(),
		AllowedMethods:   []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
		AllowedHeaders:   []string{"Accept", "Authorization", "Content-Type", "Content-Length", "X-CSRF-Token", "Token", "session", "Origin", "Host", "Connection", "Accept-Encoding", "Accept-Language", "X-Requested-With"},
		AllowCredentials: true,
		MaxAge:           300, // Maximum value not ignored by any of major browsers
	}))

	r.Route("/v1/projects/{project_id}/databases/{database_id}", func(r chi.Router) {
		r.Post("/documents:commit", firebase.HandleBatchCommit())
		r.Post("/documents:batchGet", firebase.HandleBatchGet())
	})

	if workspaceAPI != nil {
		r.Mount("/api", workspaceAPI.Routes())
	}

	r.Route("/api/v2", func(r chi.Router) {
		// Route for canvases, protected by JWT auth
		r.Group(func(r chi.Router) {
			r.Use(authMiddleware.AuthJWT)
			r.Route("/kv", func(r chi.Router) {
				r.Get("/", kv.HandleListCanvases(store))
				r.Route("/{key}", func(r chi.Router) {
					r.Get("/", kv.HandleGetCanvas(store))
					r.Put("/", kv.HandleSaveCanvas(store))
					r.Delete("/", kv.HandleDeleteCanvas(store))
				})
			})
		})
		// Legacy anonymous document routes removed per project.md Phase 1.
		// All persistence now goes through the workspace API with auth.
	})

	r.Route("/auth", func(r chi.Router) {
		r.Get("/login", auth.HandleLogin)
		r.Get("/callback", auth.HandleCallback)
	})

	return r
}

// allowedOrigins returns the CORS origin allow-list. It is read from the
// comma-separated ALLOWED_ORIGINS environment variable (entries are trimmed
// and empty entries dropped); when the variable is unset or blank a fixed set
// of localhost development origins is returned instead.
func allowedOrigins() []string {
	raw := strings.TrimSpace(os.Getenv("ALLOWED_ORIGINS"))
	if raw == "" {
		return []string{
			"http://localhost:3000",
			"http://localhost:3002",
			"http://localhost:5173",
			"http://127.0.0.1:3000",
			"http://127.0.0.1:3002",
			"http://127.0.0.1:5173",
		}
	}

	parts := strings.Split(raw, ",")
	origins := make([]string, 0, len(parts))
	for _, part := range parts {
		if origin := strings.TrimSpace(part); origin != "" {
			origins = append(origins, origin)
		}
	}
	return origins
}

// securityHeaders is middleware that sets the hardening response headers
// (nosniff, frame options, referrer policy, permissions policy and CSP) on
// every response, and HSTS when the request arrived over TLS or was forwarded
// with "X-Forwarded-Proto: https".
func securityHeaders(next http.Handler) http.Handler {
	const csp = "default-src 'self'; " +
		"connect-src 'self' https://libraries.excalidraw.com ws: wss:; " +
		"img-src 'self' data: blob: https://libraries.excalidraw.com; " +
		"font-src 'self' data: https://fonts.gstatic.com https://unpkg.com; " +
		"script-src 'self' 'wasm-unsafe-eval'; " +
		"style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; " +
		"frame-ancestors 'self'; " +
		"base-uri 'self'; " +
		"form-action 'self'"

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		h := w.Header()
		h.Set("X-Content-Type-Options", "nosniff")
		h.Set("X-Frame-Options", "SAMEORIGIN")
		h.Set("Referrer-Policy", "strict-origin-when-cross-origin")
		h.Set("Permissions-Policy", "camera=(), microphone=(), geolocation=()")
		h.Set("Content-Security-Policy", csp)
		if r.TLS != nil || strings.EqualFold(r.Header.Get("X-Forwarded-Proto"), "https") {
			h.Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
		}
		next.ServeHTTP(w, r)
	})
}

// roomFromArgs extracts the room id from the first argument of a socket.io
// event. It reports false when fewer than minArgs arguments were sent or the
// first one is not a string, so that handlers can drop malformed client
// payloads instead of panicking on a failed assertion or index.
func roomFromArgs(datas []any, minArgs int) (socketio.Room, bool) {
	if len(datas) < minArgs || len(datas) == 0 {
		return "", false
	}
	roomID, ok := datas[0].(string)
	if !ok {
		return "", false
	}
	return socketio.Room(roomID), true
}

// setupSocketIO creates the socket.io server used for live collaboration and
// registers the per-connection handlers:
//
//   - "join-room": joins the room and announces the updated member list.
//   - "server-broadcast" / "server-volatile-broadcast": relay the second and
//     third event arguments to the other sockets of the room as
//     "client-broadcast".
//   - "disconnecting": announces the remaining members of each room.
//
// Event arguments come from clients and are validated before use.
func setupSocketIO() *socketio.Server {
	opts := socketio.DefaultServerOptions()
	opts.SetMaxHttpBufferSize(5000000)
	opts.SetPath("/socket.io")
	opts.SetAllowEIO3(true)

	// Mirror HTTP CORS origin policy — wildcard + credentials is rejected by browsers.
	// NOTE(review): Origin is passed as one comma-joined string; confirm that
	// the engine.io CORS implementation accepts a list in this form.
	socketOrigins := strings.Join(allowedOrigins(), ",")
	if socketOrigins == "" {
		socketOrigins = "http://localhost:3000,http://localhost:3002,http://localhost:5173"
	}
	opts.SetCors(&types.Cors{
		Origin:      socketOrigins,
		Credentials: true,
	})

	ioo := socketio.NewServer(nil, opts)
	ioo.On("connection", func(clients ...any) {
		if len(clients) == 0 {
			return
		}
		socket, ok := clients[0].(*socketio.Socket)
		if !ok {
			return
		}

		me := socket.Id()
		myRoom := socketio.Room(me)
		ioo.To(myRoom).Emit("init-room")
		utils.Log().Printf("init room %v\n", myRoom)

		socket.On("join-room", func(datas ...any) {
			room, ok := roomFromArgs(datas, 1)
			if !ok {
				utils.Log().Printf("Socket %v sent a malformed join-room payload\n", me)
				return
			}
			utils.Log().Printf("Socket %v has joined %v\n", me, room)
			socket.Join(room)

			ioo.In(room).FetchSockets()(func(usersInRoom []*socketio.RemoteSocket, err error) {
				if err != nil {
					utils.Log().Printf("fetching sockets of room %v: %v\n", room, err)
				}

				if len(usersInRoom) <= 1 {
					ioo.To(myRoom).Emit("first-in-room")
				} else {
					utils.Log().Printf("emit new user %v in room %v\n", me, room)
					socket.Broadcast().To(room).Emit("new-user", me)
				}

				// Inform all clients about the new member list. The slice
				// is kept non-nil so it is emitted as a list, not null.
				newRoomUsers := make([]socketio.SocketId, 0, len(usersInRoom))
				for _, user := range usersInRoom {
					newRoomUsers = append(newRoomUsers, user.Id())
				}
				utils.Log().Printf("room %v has users %v\n", room, newRoomUsers)
				ioo.In(room).Emit(
					"room-user-change",
					newRoomUsers,
				)
			})
		})

		socket.On("server-broadcast", func(datas ...any) {
			room, ok := roomFromArgs(datas, 3)
			if !ok {
				utils.Log().Printf("Socket %v sent a malformed server-broadcast payload\n", me)
				return
			}
			utils.Log().Printf(" user %v sends update to room %v\n", me, room)
			socket.Broadcast().To(room).Emit("client-broadcast", datas[1], datas[2])
		})

		socket.On("server-volatile-broadcast", func(datas ...any) {
			room, ok := roomFromArgs(datas, 3)
			if !ok {
				utils.Log().Printf("Socket %v sent a malformed server-volatile-broadcast payload\n", me)
				return
			}
			utils.Log().Printf(" user %v sends volatile update to room %v\n", me, room)
			socket.Volatile().Broadcast().To(room).Emit("client-broadcast", datas[1], datas[2])
		})

		socket.On("user-follow", func(datas ...any) {
			// TODO()
		})

		socket.On("disconnecting", func(datas ...any) {
			for _, currentRoom := range socket.Rooms().Keys() {
				// Per-iteration copy: the callback below may outlive the
				// iteration, and before Go 1.22 the loop variable is shared.
				currentRoom := currentRoom
				ioo.In(currentRoom).FetchSockets()(func(usersInRoom []*socketio.RemoteSocket, err error) {
					if err != nil {
						utils.Log().Printf("fetching sockets of room %v: %v\n", currentRoom, err)
					}

					utils.Log().Printf("disconnecting %v from room %v\n", me, currentRoom)
					otherClients := make([]socketio.SocketId, 0, len(usersInRoom))
					for _, userInRoom := range usersInRoom {
						if userInRoom.Id() != me {
							otherClients = append(otherClients, userInRoom.Id())
						}
					}

					if len(otherClients) > 0 {
						utils.Log().Printf("leaving user, room %v has users %v\n", currentRoom, otherClients)
						ioo.In(currentRoom).Emit(
							"room-user-change",
							otherClients,
						)
					}
				})
			}
		})

		socket.On("disconnect", func(datas ...any) {
			socket.RemoveAllListeners("")
			socket.Disconnect(true)
		})
	})

	return ioo
}

// waitForShutdown blocks until the process receives SIGINT, SIGHUP, SIGTERM
// or SIGQUIT, then announces the shutdown and closes the socket.io server.
//
// It returns instead of calling os.Exit so that the caller's deferred
// cleanup still runs; os.Exit terminates the process without running
// deferred functions.
func waitForShutdown(ioo *socketio.Server) {
	signalC := make(chan os.Signal, 1)
	signal.Notify(signalC, os.Interrupt, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT)
	<-signalC
	signal.Stop(signalC)

	fmt.Println("Shutting down...")
	ioo.Close(nil)
	// TODO(patwie): Close other resources
}

// validateEnv aborts startup (via logrus.Fatal) when a critical environment
// variable is missing or still carries a placeholder value.
func validateEnv() {
	switch os.Getenv("JWT_SECRET") {
	case "", "YOUR_SUPER_SECRET_RANDOM_STRING", "YOUR_SUPER_SECRET_RANDOM_STRING_MIN_32_CHARS":
		logrus.Fatal("JWT_SECRET must be set to a secure random string (min 32 chars). Generate with: openssl rand -base64 32")
	}

	storageType := os.Getenv("STORAGE_TYPE")
	switch storageType {
	case "", "postgres":
		// An unset STORAGE_TYPE is validated like "postgres".
		if os.Getenv("DATABASE_URL") == "" {
			logrus.Fatal("DATABASE_URL must be set for PostgreSQL storage")
		}
	case "memory", "filesystem", "kv", "s3":
	default:
		logrus.Fatalf("STORAGE_TYPE must be one of: postgres, memory, filesystem, kv, s3. Got: %s", storageType)
	}
}

// logAuthProviderStatus warns about half-configured GitHub OAuth / OIDC
// settings and notes when no external auth provider is configured at all.
func logAuthProviderStatus() {
	hasGitHubClient := os.Getenv("GITHUB_CLIENT_ID") != "" && os.Getenv("GITHUB_CLIENT_SECRET") != ""
	hasOIDC := os.Getenv("OIDC_ISSUER_URL") != "" &&
		os.Getenv("OIDC_CLIENT_ID") != "" &&
		os.Getenv("OIDC_CLIENT_SECRET") != ""

	if os.Getenv("GITHUB_CLIENT_ID") != "" && !hasGitHubClient {
		logrus.Warn("GITHUB_CLIENT_ID is set but GITHUB_CLIENT_SECRET is missing — GitHub OAuth will not work")
	}
	if os.Getenv("OIDC_ISSUER_URL") != "" && !hasOIDC {
		logrus.Warn("OIDC configuration is incomplete — OIDC SSO will not work")
	}
	if !hasGitHubClient && !hasOIDC {
		logrus.Info("No external auth provider configured. Only password authentication is available.")
	}
}

// main loads configuration (.env, flags, environment), initializes auth and
// storage, wires the HTTP router and socket.io server together, starts
// listening in the background and then blocks until a shutdown signal.
func main() {
	// Load .env file
	if err := godotenv.Load(); err != nil {
		logrus.Info("No .env file found")
	}

	listenAddress := flag.String("listen", ":3002", "The address to listen on.")
	logLevel := flag.String("loglevel", "info", "The log level (debug, info, warn, error).")
	flag.Parse()

	level, err := logrus.ParseLevel(*logLevel)
	if err != nil {
		logrus.Fatalf("Invalid log level: %v", err)
	}
	logrus.SetLevel(level)
	logrus.SetFormatter(&logrus.TextFormatter{
		FullTimestamp: true,
	})

	// Validate critical environment variables
	validateEnv()
	// Warn about incomplete OAuth/OIDC configuration
	logAuthProviderStatus()

	auth.InitAuth()
	store := stores.GetStore()

	workspaceStore, err := workspace.NewStore(os.Getenv("DATABASE_URL"))
	if err != nil {
		logrus.WithError(err).Fatal("failed to initialize workspace backend")
	}
	// Runs when main returns after waitForShutdown.
	defer workspaceStore.Close()
	auth.SetWorkspaceStore(workspaceStore)
	workspaceAPI := workspace.NewAPI(workspaceStore)

	r := setupRouter(store, workspaceAPI)
	ioo := setupSocketIO()
	r.Mount("/socket.io/", ioo.ServeHandler(nil))
	// Everything that no API route claims is served from the embedded UI.
	r.NotFound(handleUI())

	logrus.WithField("addr", *listenAddress).Info("starting server")
	go func() {
		if err := http.ListenAndServe(*listenAddress, r); err != nil {
			logrus.WithField("event", "start server").Fatal(err)
		}
	}()
	logrus.Debug("Server is running in the background")

	waitForShutdown(ioo)
}