mirror of
https://github.com/Dvorinka/beszel.git
synced 2026-06-03 21:02:56 +00:00
Initial commit: Beszel fork with Domain Locker integration
This commit is contained in:
@@ -0,0 +1,334 @@
|
||||
package hub
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/henrygd/beszel/internal/common"
|
||||
"github.com/henrygd/beszel/internal/hub/expirymap"
|
||||
"github.com/henrygd/beszel/internal/hub/ws"
|
||||
|
||||
"github.com/blang/semver"
|
||||
"github.com/lxzan/gws"
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
)
|
||||
|
||||
// agentConnectRequest holds information related to an agent's connection attempt.
|
||||
type agentConnectRequest struct {
|
||||
hub *Hub
|
||||
req *http.Request
|
||||
res http.ResponseWriter
|
||||
token string
|
||||
agentSemVer semver.Version
|
||||
// isUniversalToken is true if the token is a universal token.
|
||||
isUniversalToken bool
|
||||
// userId is the user ID associated with the universal token.
|
||||
userId string
|
||||
}
|
||||
|
||||
// universalTokenMap stores active universal tokens and their associated user IDs.
|
||||
var universalTokenMap tokenMap
|
||||
|
||||
type tokenMap struct {
|
||||
store *expirymap.ExpiryMap[string]
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
// getMap returns the expirymap, creating it if necessary.
|
||||
func (tm *tokenMap) GetMap() *expirymap.ExpiryMap[string] {
|
||||
tm.once.Do(func() {
|
||||
tm.store = expirymap.New[string](time.Hour)
|
||||
})
|
||||
return tm.store
|
||||
}
|
||||
|
||||
// handleAgentConnect is the HTTP handler for an agent's connection request.
|
||||
func (h *Hub) handleAgentConnect(e *core.RequestEvent) error {
|
||||
agentRequest := agentConnectRequest{req: e.Request, res: e.Response, hub: h}
|
||||
_ = agentRequest.agentConnect()
|
||||
return nil
|
||||
}
|
||||
|
||||
// agentConnect validates agent credentials and upgrades the connection to a WebSocket.
|
||||
func (acr *agentConnectRequest) agentConnect() (err error) {
|
||||
var agentVersion string
|
||||
|
||||
acr.token, agentVersion, err = acr.validateAgentHeaders(acr.req.Header)
|
||||
if err != nil {
|
||||
return acr.sendResponseError(acr.res, http.StatusBadRequest, "")
|
||||
}
|
||||
|
||||
// Check if token is an active universal token
|
||||
acr.userId, acr.isUniversalToken = universalTokenMap.GetMap().GetOk(acr.token)
|
||||
if !acr.isUniversalToken {
|
||||
// Fallback: check for a permanent universal token stored in the DB
|
||||
if rec, err := acr.hub.FindFirstRecordByFilter("universal_tokens", "token = {:token}", dbx.Params{"token": acr.token}); err == nil {
|
||||
if userID := rec.GetString("user"); userID != "" {
|
||||
acr.userId = userID
|
||||
acr.isUniversalToken = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Find matching fingerprint records for this token
|
||||
fpRecords := getFingerprintRecordsByToken(acr.token, acr.hub)
|
||||
if len(fpRecords) == 0 && !acr.isUniversalToken {
|
||||
// Invalid token - no records found and not a universal token
|
||||
return acr.sendResponseError(acr.res, http.StatusUnauthorized, "Invalid token")
|
||||
}
|
||||
|
||||
// Validate agent version
|
||||
acr.agentSemVer, err = semver.Parse(agentVersion)
|
||||
if err != nil {
|
||||
return acr.sendResponseError(acr.res, http.StatusUnauthorized, "Invalid agent version")
|
||||
}
|
||||
|
||||
// Upgrade connection to WebSocket
|
||||
conn, err := ws.GetUpgrader().Upgrade(acr.res, acr.req)
|
||||
if err != nil {
|
||||
return acr.sendResponseError(acr.res, http.StatusInternalServerError, "WebSocket upgrade failed")
|
||||
}
|
||||
|
||||
go acr.verifyWsConn(conn, fpRecords)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifyWsConn verifies the WebSocket connection using the agent's fingerprint and
|
||||
// SSH key signature, then adds the system to the system manager.
|
||||
func (acr *agentConnectRequest) verifyWsConn(conn *gws.Conn, fpRecords []ws.FingerprintRecord) (err error) {
|
||||
wsConn := ws.NewWsConnection(conn, acr.agentSemVer)
|
||||
|
||||
// must set wsConn in connection store before the read loop
|
||||
conn.Session().Store("wsConn", wsConn)
|
||||
|
||||
// make sure connection is closed if there is an error
|
||||
defer func() {
|
||||
if err != nil {
|
||||
wsConn.Close([]byte(err.Error()))
|
||||
}
|
||||
}()
|
||||
|
||||
go conn.ReadLoop()
|
||||
|
||||
signer, err := acr.hub.GetSSHKey("")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
agentFingerprint, err := wsConn.GetFingerprint(context.Background(), acr.token, signer, acr.isUniversalToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Find or create the appropriate system for this token and fingerprint
|
||||
fpRecord, err := acr.findOrCreateSystemForToken(fpRecords, agentFingerprint)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return acr.hub.sm.AddWebSocketSystem(fpRecord.SystemId, acr.agentSemVer, wsConn)
|
||||
}
|
||||
|
||||
// validateAgentHeaders extracts and validates the token and agent version from HTTP headers.
|
||||
func (acr *agentConnectRequest) validateAgentHeaders(headers http.Header) (string, string, error) {
|
||||
token := headers.Get("X-Token")
|
||||
agentVersion := headers.Get("X-Beszel")
|
||||
|
||||
if agentVersion == "" || token == "" || len(token) > 64 {
|
||||
return "", "", errors.New("")
|
||||
}
|
||||
return token, agentVersion, nil
|
||||
}
|
||||
|
||||
// sendResponseError writes an HTTP error response.
|
||||
func (acr *agentConnectRequest) sendResponseError(res http.ResponseWriter, code int, message string) error {
|
||||
res.WriteHeader(code)
|
||||
if message != "" {
|
||||
res.Write([]byte(message))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getFingerprintRecordsByToken retrieves all fingerprint records associated with a given token.
|
||||
func getFingerprintRecordsByToken(token string, h *Hub) []ws.FingerprintRecord {
|
||||
var records []ws.FingerprintRecord
|
||||
// All will populate empty slice even on error
|
||||
_ = h.DB().NewQuery("SELECT id, system, fingerprint, token FROM fingerprints WHERE token = {:token}").
|
||||
Bind(dbx.Params{
|
||||
"token": token,
|
||||
}).
|
||||
All(&records)
|
||||
return records
|
||||
}
|
||||
|
||||
// findOrCreateSystemForToken finds an existing system matching the token and fingerprint,
|
||||
// or creates a new one for a universal token.
|
||||
func (acr *agentConnectRequest) findOrCreateSystemForToken(fpRecords []ws.FingerprintRecord, agentFingerprint common.FingerprintResponse) (ws.FingerprintRecord, error) {
|
||||
// No records - only valid for active universal tokens
|
||||
if len(fpRecords) == 0 {
|
||||
return acr.handleNoRecords(agentFingerprint)
|
||||
}
|
||||
|
||||
// Single record - handle as regular token
|
||||
if len(fpRecords) == 1 && !acr.isUniversalToken {
|
||||
return acr.handleSingleRecord(fpRecords[0], agentFingerprint)
|
||||
}
|
||||
|
||||
// Multiple records or universal token - look for matching fingerprint
|
||||
return acr.handleMultipleRecordsOrUniversalToken(fpRecords, agentFingerprint)
|
||||
}
|
||||
|
||||
// handleNoRecords handles the case where no fingerprint records are found for a token.
|
||||
// A new system is created if the token is a valid universal token.
|
||||
func (acr *agentConnectRequest) handleNoRecords(agentFingerprint common.FingerprintResponse) (ws.FingerprintRecord, error) {
|
||||
var fpRecord ws.FingerprintRecord
|
||||
|
||||
if !acr.isUniversalToken || acr.userId == "" {
|
||||
return fpRecord, errors.New("no matching fingerprints")
|
||||
}
|
||||
|
||||
return acr.createNewSystemForUniversalToken(agentFingerprint)
|
||||
}
|
||||
|
||||
// handleSingleRecord handles the case with a single fingerprint record. It validates
|
||||
// the agent's fingerprint against the stored one, or sets it on first connect.
|
||||
func (acr *agentConnectRequest) handleSingleRecord(fpRecord ws.FingerprintRecord, agentFingerprint common.FingerprintResponse) (ws.FingerprintRecord, error) {
|
||||
// If no current fingerprint, update with new fingerprint (first time connecting)
|
||||
if fpRecord.Fingerprint == "" {
|
||||
if err := acr.hub.SetFingerprint(&fpRecord, agentFingerprint.Fingerprint); err != nil {
|
||||
return fpRecord, err
|
||||
}
|
||||
// Update the record with the fingerprint that was set
|
||||
fpRecord.Fingerprint = agentFingerprint.Fingerprint
|
||||
return fpRecord, nil
|
||||
}
|
||||
|
||||
// Abort if fingerprint exists but doesn't match (different machine)
|
||||
if fpRecord.Fingerprint != agentFingerprint.Fingerprint {
|
||||
return fpRecord, errors.New("fingerprint mismatch")
|
||||
}
|
||||
|
||||
return fpRecord, nil
|
||||
}
|
||||
|
||||
// handleMultipleRecordsOrUniversalToken finds a matching fingerprint from multiple records.
|
||||
// If no match is found and the token is a universal token, a new system is created.
|
||||
func (acr *agentConnectRequest) handleMultipleRecordsOrUniversalToken(fpRecords []ws.FingerprintRecord, agentFingerprint common.FingerprintResponse) (ws.FingerprintRecord, error) {
|
||||
// Return existing record with matching fingerprint if found
|
||||
for i := range fpRecords {
|
||||
if fpRecords[i].Fingerprint == agentFingerprint.Fingerprint {
|
||||
return fpRecords[i], nil
|
||||
}
|
||||
}
|
||||
|
||||
// No matching fingerprint record found, but it's
|
||||
// an active universal token so create a new system
|
||||
if acr.isUniversalToken {
|
||||
return acr.createNewSystemForUniversalToken(agentFingerprint)
|
||||
}
|
||||
|
||||
return ws.FingerprintRecord{}, errors.New("fingerprint mismatch")
|
||||
}
|
||||
|
||||
// createNewSystemForUniversalToken creates a new system and fingerprint record for a universal token.
|
||||
func (acr *agentConnectRequest) createNewSystemForUniversalToken(agentFingerprint common.FingerprintResponse) (ws.FingerprintRecord, error) {
|
||||
var fpRecord ws.FingerprintRecord
|
||||
if !acr.isUniversalToken || acr.userId == "" {
|
||||
return fpRecord, errors.New("invalid token")
|
||||
}
|
||||
|
||||
fpRecord.Token = acr.token
|
||||
|
||||
systemId, err := acr.createSystem(agentFingerprint)
|
||||
if err != nil {
|
||||
return fpRecord, err
|
||||
}
|
||||
fpRecord.SystemId = systemId
|
||||
|
||||
// Set the fingerprint for the new system
|
||||
if err := acr.hub.SetFingerprint(&fpRecord, agentFingerprint.Fingerprint); err != nil {
|
||||
return fpRecord, err
|
||||
}
|
||||
|
||||
// Update the record with the fingerprint that was set
|
||||
fpRecord.Fingerprint = agentFingerprint.Fingerprint
|
||||
|
||||
return fpRecord, nil
|
||||
}
|
||||
|
||||
// createSystem creates a new system record in the database using details from the agent.
|
||||
func (acr *agentConnectRequest) createSystem(agentFingerprint common.FingerprintResponse) (recordId string, err error) {
|
||||
systemsCollection, err := acr.hub.FindCachedCollectionByNameOrId("systems")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
remoteAddr := getRealIP(acr.req)
|
||||
// separate port from address
|
||||
if agentFingerprint.Hostname == "" {
|
||||
agentFingerprint.Hostname = remoteAddr
|
||||
}
|
||||
if agentFingerprint.Port == "" {
|
||||
agentFingerprint.Port = "45876"
|
||||
}
|
||||
if agentFingerprint.Name == "" {
|
||||
agentFingerprint.Name = agentFingerprint.Hostname
|
||||
}
|
||||
// create new record
|
||||
systemRecord := core.NewRecord(systemsCollection)
|
||||
systemRecord.Set("name", agentFingerprint.Name)
|
||||
systemRecord.Set("host", remoteAddr)
|
||||
systemRecord.Set("port", agentFingerprint.Port)
|
||||
systemRecord.Set("users", []string{acr.userId})
|
||||
|
||||
return systemRecord.Id, acr.hub.Save(systemRecord)
|
||||
}
|
||||
|
||||
// SetFingerprint creates or updates a fingerprint record in the database.
|
||||
func (h *Hub) SetFingerprint(fpRecord *ws.FingerprintRecord, fingerprint string) (err error) {
|
||||
// // can't use raw query here because it doesn't trigger SSE
|
||||
var record *core.Record
|
||||
switch fpRecord.Id {
|
||||
case "":
|
||||
// create new record for universal token
|
||||
collection, _ := h.FindCachedCollectionByNameOrId("fingerprints")
|
||||
record = core.NewRecord(collection)
|
||||
record.Set("system", fpRecord.SystemId)
|
||||
default:
|
||||
record, err = h.FindRecordById("fingerprints", fpRecord.Id)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
record.Set("token", fpRecord.Token)
|
||||
record.Set("fingerprint", fingerprint)
|
||||
return h.SaveNoValidate(record)
|
||||
}
|
||||
|
||||
// getRealIP extracts the client's real IP address from request headers,
|
||||
// checking common proxy headers before falling back to the remote address.
|
||||
func getRealIP(r *http.Request) string {
|
||||
if ip := r.Header.Get("CF-Connecting-IP"); ip != "" {
|
||||
return ip
|
||||
}
|
||||
if ip := r.Header.Get("X-Forwarded-For"); ip != "" {
|
||||
// X-Forwarded-For can contain a comma-separated list: "client_ip, proxy1, proxy2"
|
||||
// Take the first one
|
||||
ips := strings.Split(ip, ",")
|
||||
if len(ips) > 0 {
|
||||
return strings.TrimSpace(ips[0])
|
||||
}
|
||||
}
|
||||
// Fallback to RemoteAddr
|
||||
ip, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
return r.RemoteAddr
|
||||
}
|
||||
return ip
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,391 @@
|
||||
package hub
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/blang/semver"
|
||||
"github.com/google/uuid"
|
||||
"github.com/henrygd/beszel"
|
||||
"github.com/henrygd/beszel/internal/alerts"
|
||||
"github.com/henrygd/beszel/internal/ghupdate"
|
||||
"github.com/henrygd/beszel/internal/hub/config"
|
||||
"github.com/henrygd/beszel/internal/hub/systems"
|
||||
"github.com/henrygd/beszel/internal/hub/utils"
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/apis"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
)
|
||||
|
||||
// UpdateInfo holds information about the latest update check
|
||||
type UpdateInfo struct {
|
||||
lastCheck time.Time
|
||||
Version string `json:"v"`
|
||||
Url string `json:"url"`
|
||||
}
|
||||
|
||||
var containerIDPattern = regexp.MustCompile(`^[a-fA-F0-9]{12,64}$`)
|
||||
|
||||
// Middleware to allow only admin role users
|
||||
var requireAdminRole = customAuthMiddleware(func(e *core.RequestEvent) bool {
|
||||
return e.Auth.GetString("role") == "admin"
|
||||
})
|
||||
|
||||
// Middleware to exclude readonly users
|
||||
var excludeReadOnlyRole = customAuthMiddleware(func(e *core.RequestEvent) bool {
|
||||
return e.Auth.GetString("role") != "readonly"
|
||||
})
|
||||
|
||||
// customAuthMiddleware handles boilerplate for custom authentication middlewares. fn should
|
||||
// return true if the request is allowed, false otherwise. e.Auth is guaranteed to be non-nil.
|
||||
func customAuthMiddleware(fn func(*core.RequestEvent) bool) func(*core.RequestEvent) error {
|
||||
return func(e *core.RequestEvent) error {
|
||||
if e.Auth == nil {
|
||||
return e.UnauthorizedError("The request requires valid record authorization token.", nil)
|
||||
}
|
||||
if !fn(e) {
|
||||
return e.ForbiddenError("The authorized record is not allowed to perform this action.", nil)
|
||||
}
|
||||
return e.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// registerMiddlewares registers custom middlewares
|
||||
func (h *Hub) registerMiddlewares(se *core.ServeEvent) {
|
||||
// authorizes request with user matching the provided email
|
||||
authorizeRequestWithEmail := func(e *core.RequestEvent, email string) (err error) {
|
||||
if e.Auth != nil || email == "" {
|
||||
return e.Next()
|
||||
}
|
||||
isAuthRefresh := e.Request.URL.Path == "/api/collections/users/auth-refresh" && e.Request.Method == http.MethodPost
|
||||
e.Auth, err = e.App.FindAuthRecordByEmail("users", email)
|
||||
if err != nil || !isAuthRefresh {
|
||||
return e.Next()
|
||||
}
|
||||
// auth refresh endpoint, make sure token is set in header
|
||||
token, _ := e.Auth.NewAuthToken()
|
||||
e.Request.Header.Set("Authorization", token)
|
||||
return e.Next()
|
||||
}
|
||||
// authenticate with trusted header
|
||||
if autoLogin, _ := utils.GetEnv("AUTO_LOGIN"); autoLogin != "" {
|
||||
se.Router.BindFunc(func(e *core.RequestEvent) error {
|
||||
return authorizeRequestWithEmail(e, autoLogin)
|
||||
})
|
||||
}
|
||||
// authenticate with trusted header
|
||||
if trustedHeader, _ := utils.GetEnv("TRUSTED_AUTH_HEADER"); trustedHeader != "" {
|
||||
se.Router.BindFunc(func(e *core.RequestEvent) error {
|
||||
return authorizeRequestWithEmail(e, e.Request.Header.Get(trustedHeader))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// registerApiRoutes registers custom API routes
|
||||
func (h *Hub) registerApiRoutes(se *core.ServeEvent) error {
|
||||
// auth protected routes
|
||||
apiAuth := se.Router.Group("/api/beszel")
|
||||
apiAuth.Bind(apis.RequireAuth())
|
||||
// auth optional routes
|
||||
apiNoAuth := se.Router.Group("/api/beszel")
|
||||
|
||||
// create first user endpoint only needed if no users exist
|
||||
if totalUsers, _ := se.App.CountRecords("users"); totalUsers == 0 {
|
||||
apiNoAuth.POST("/create-user", h.um.CreateFirstUser)
|
||||
}
|
||||
// check if first time setup on login page
|
||||
apiNoAuth.GET("/first-run", func(e *core.RequestEvent) error {
|
||||
total, err := e.App.CountRecords("users")
|
||||
return e.JSON(http.StatusOK, map[string]bool{"firstRun": err == nil && total == 0})
|
||||
})
|
||||
// get public key and version
|
||||
apiAuth.GET("/info", h.getInfo)
|
||||
apiAuth.GET("/getkey", h.getInfo) // deprecated - keep for compatibility w/ integrations
|
||||
// check for updates
|
||||
if optIn, _ := utils.GetEnv("CHECK_UPDATES"); optIn == "true" {
|
||||
var updateInfo UpdateInfo
|
||||
apiAuth.GET("/update", updateInfo.getUpdate)
|
||||
}
|
||||
// send test notification
|
||||
apiAuth.POST("/test-notification", h.SendTestNotification)
|
||||
// heartbeat status and test
|
||||
apiAuth.GET("/heartbeat-status", h.getHeartbeatStatus).BindFunc(requireAdminRole)
|
||||
apiAuth.POST("/test-heartbeat", h.testHeartbeat).BindFunc(requireAdminRole)
|
||||
// get config.yml content
|
||||
apiAuth.GET("/config-yaml", config.GetYamlConfig).BindFunc(requireAdminRole)
|
||||
// handle agent websocket connection
|
||||
apiNoAuth.GET("/agent-connect", h.handleAgentConnect)
|
||||
// get or create universal tokens
|
||||
apiAuth.GET("/universal-token", h.getUniversalToken).BindFunc(excludeReadOnlyRole)
|
||||
// update / delete user alerts
|
||||
apiAuth.POST("/user-alerts", alerts.UpsertUserAlerts)
|
||||
apiAuth.DELETE("/user-alerts", alerts.DeleteUserAlerts)
|
||||
// refresh SMART devices for a system
|
||||
apiAuth.POST("/smart/refresh", h.refreshSmartData).BindFunc(excludeReadOnlyRole)
|
||||
// get systemd service details
|
||||
apiAuth.GET("/systemd/info", h.getSystemdInfo)
|
||||
// /containers routes
|
||||
if enabled, _ := utils.GetEnv("CONTAINER_DETAILS"); enabled != "false" {
|
||||
// get container logs
|
||||
apiAuth.GET("/containers/logs", h.getContainerLogs)
|
||||
// get container info
|
||||
apiAuth.GET("/containers/info", h.getContainerInfo)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getInfo returns data needed by authenticated users, such as the public key and current version
|
||||
func (h *Hub) getInfo(e *core.RequestEvent) error {
|
||||
type infoResponse struct {
|
||||
Key string `json:"key"`
|
||||
Version string `json:"v"`
|
||||
CheckUpdate bool `json:"cu"`
|
||||
}
|
||||
info := infoResponse{
|
||||
Key: h.pubKey,
|
||||
Version: beszel.Version,
|
||||
}
|
||||
if optIn, _ := utils.GetEnv("CHECK_UPDATES"); optIn == "true" {
|
||||
info.CheckUpdate = true
|
||||
}
|
||||
return e.JSON(http.StatusOK, info)
|
||||
}
|
||||
|
||||
// getUpdate checks for the latest release on GitHub and returns update info if a newer version is available
|
||||
func (info *UpdateInfo) getUpdate(e *core.RequestEvent) error {
|
||||
if time.Since(info.lastCheck) < 6*time.Hour {
|
||||
return e.JSON(http.StatusOK, info)
|
||||
}
|
||||
info.lastCheck = time.Now()
|
||||
latestRelease, err := ghupdate.FetchLatestRelease(context.Background(), http.DefaultClient, "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
currentVersion, err := semver.Parse(strings.TrimPrefix(beszel.Version, "v"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
latestVersion, err := semver.Parse(strings.TrimPrefix(latestRelease.Tag, "v"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if latestVersion.GT(currentVersion) {
|
||||
info.Version = strings.TrimPrefix(latestRelease.Tag, "v")
|
||||
info.Url = latestRelease.Url
|
||||
}
|
||||
return e.JSON(http.StatusOK, info)
|
||||
}
|
||||
|
||||
// GetUniversalToken handles the universal token API endpoint (create, read, delete)
|
||||
func (h *Hub) getUniversalToken(e *core.RequestEvent) error {
|
||||
if e.Auth.IsSuperuser() {
|
||||
return e.ForbiddenError("Superusers cannot use universal tokens", nil)
|
||||
}
|
||||
|
||||
tokenMap := universalTokenMap.GetMap()
|
||||
userID := e.Auth.Id
|
||||
query := e.Request.URL.Query()
|
||||
token := query.Get("token")
|
||||
enable := query.Get("enable")
|
||||
permanent := query.Get("permanent")
|
||||
|
||||
// helper for deleting any existing permanent token record for this user
|
||||
deletePermanent := func() error {
|
||||
rec, err := h.FindFirstRecordByFilter("universal_tokens", "user = {:user}", dbx.Params{"user": userID})
|
||||
if err != nil {
|
||||
return nil // no record
|
||||
}
|
||||
return h.Delete(rec)
|
||||
}
|
||||
|
||||
// helper for upserting a permanent token record for this user
|
||||
upsertPermanent := func(token string) error {
|
||||
rec, err := h.FindFirstRecordByFilter("universal_tokens", "user = {:user}", dbx.Params{"user": userID})
|
||||
if err == nil {
|
||||
rec.Set("token", token)
|
||||
return h.Save(rec)
|
||||
}
|
||||
|
||||
col, err := h.FindCachedCollectionByNameOrId("universal_tokens")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newRec := core.NewRecord(col)
|
||||
newRec.Set("user", userID)
|
||||
newRec.Set("token", token)
|
||||
return h.Save(newRec)
|
||||
}
|
||||
|
||||
// Disable universal tokens (both ephemeral and permanent)
|
||||
if enable == "0" {
|
||||
tokenMap.RemovebyValue(userID)
|
||||
_ = deletePermanent()
|
||||
return e.JSON(http.StatusOK, map[string]any{"token": token, "active": false, "permanent": false})
|
||||
}
|
||||
|
||||
// Enable universal token (ephemeral or permanent)
|
||||
if enable == "1" {
|
||||
if token == "" {
|
||||
token = uuid.New().String()
|
||||
}
|
||||
|
||||
if permanent == "1" {
|
||||
// make token permanent (persist across restarts)
|
||||
tokenMap.RemovebyValue(userID)
|
||||
if err := upsertPermanent(token); err != nil {
|
||||
return err
|
||||
}
|
||||
return e.JSON(http.StatusOK, map[string]any{"token": token, "active": true, "permanent": true})
|
||||
}
|
||||
|
||||
// default: ephemeral mode (1 hour)
|
||||
_ = deletePermanent()
|
||||
tokenMap.Set(token, userID, time.Hour)
|
||||
return e.JSON(http.StatusOK, map[string]any{"token": token, "active": true, "permanent": false})
|
||||
}
|
||||
|
||||
// Read current state
|
||||
// Prefer permanent token if it exists.
|
||||
if rec, err := h.FindFirstRecordByFilter("universal_tokens", "user = {:user}", dbx.Params{"user": userID}); err == nil {
|
||||
dbToken := rec.GetString("token")
|
||||
// If no token was provided, or the caller is asking about their permanent token, return it.
|
||||
if token == "" || token == dbToken {
|
||||
return e.JSON(http.StatusOK, map[string]any{"token": dbToken, "active": true, "permanent": true})
|
||||
}
|
||||
// Token doesn't match their permanent token (avoid leaking other info)
|
||||
return e.JSON(http.StatusOK, map[string]any{"token": token, "active": false, "permanent": false})
|
||||
}
|
||||
|
||||
// No permanent token; fall back to ephemeral token map.
|
||||
if token == "" {
|
||||
// return existing token if it exists
|
||||
if token, _, ok := tokenMap.GetByValue(userID); ok {
|
||||
return e.JSON(http.StatusOK, map[string]any{"token": token, "active": true, "permanent": false})
|
||||
}
|
||||
// if no token is provided, generate a new one
|
||||
token = uuid.New().String()
|
||||
}
|
||||
|
||||
// Token is considered active only if it belongs to the current user.
|
||||
activeUser, ok := tokenMap.GetOk(token)
|
||||
active := ok && activeUser == userID
|
||||
response := map[string]any{"token": token, "active": active, "permanent": false}
|
||||
return e.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// getHeartbeatStatus returns current heartbeat configuration and whether it's enabled
|
||||
func (h *Hub) getHeartbeatStatus(e *core.RequestEvent) error {
|
||||
if h.hb == nil {
|
||||
return e.JSON(http.StatusOK, map[string]any{
|
||||
"enabled": false,
|
||||
"msg": "Set HEARTBEAT_URL to enable outbound heartbeat monitoring",
|
||||
})
|
||||
}
|
||||
cfg := h.hb.GetConfig()
|
||||
return e.JSON(http.StatusOK, map[string]any{
|
||||
"enabled": true,
|
||||
"url": cfg.URL,
|
||||
"interval": cfg.Interval,
|
||||
"method": cfg.Method,
|
||||
})
|
||||
}
|
||||
|
||||
// testHeartbeat triggers a single heartbeat ping and returns the result
|
||||
func (h *Hub) testHeartbeat(e *core.RequestEvent) error {
|
||||
if h.hb == nil {
|
||||
return e.JSON(http.StatusOK, map[string]any{
|
||||
"err": "Heartbeat not configured. Set HEARTBEAT_URL environment variable.",
|
||||
})
|
||||
}
|
||||
if err := h.hb.Send(); err != nil {
|
||||
return e.JSON(http.StatusOK, map[string]any{"err": err.Error()})
|
||||
}
|
||||
return e.JSON(http.StatusOK, map[string]any{"err": false})
|
||||
}
|
||||
|
||||
// containerRequestHandler handles both container logs and info requests
|
||||
func (h *Hub) containerRequestHandler(e *core.RequestEvent, fetchFunc func(*systems.System, string) (string, error), responseKey string) error {
|
||||
systemID := e.Request.URL.Query().Get("system")
|
||||
containerID := e.Request.URL.Query().Get("container")
|
||||
|
||||
if systemID == "" || containerID == "" || !containerIDPattern.MatchString(containerID) {
|
||||
return e.BadRequestError("Invalid system or container parameter", nil)
|
||||
}
|
||||
|
||||
system, err := h.sm.GetSystem(systemID)
|
||||
if err != nil || !system.HasUser(e.App, e.Auth) {
|
||||
return e.NotFoundError("", nil)
|
||||
}
|
||||
|
||||
data, err := fetchFunc(system, containerID)
|
||||
if err != nil {
|
||||
return e.InternalServerError("", err)
|
||||
}
|
||||
|
||||
return e.JSON(http.StatusOK, map[string]string{responseKey: data})
|
||||
}
|
||||
|
||||
// getContainerLogs handles GET /api/beszel/containers/logs requests
|
||||
func (h *Hub) getContainerLogs(e *core.RequestEvent) error {
|
||||
return h.containerRequestHandler(e, func(system *systems.System, containerID string) (string, error) {
|
||||
return system.FetchContainerLogsFromAgent(containerID)
|
||||
}, "logs")
|
||||
}
|
||||
|
||||
func (h *Hub) getContainerInfo(e *core.RequestEvent) error {
|
||||
return h.containerRequestHandler(e, func(system *systems.System, containerID string) (string, error) {
|
||||
return system.FetchContainerInfoFromAgent(containerID)
|
||||
}, "info")
|
||||
}
|
||||
|
||||
// getSystemdInfo handles GET /api/beszel/systemd/info requests
|
||||
func (h *Hub) getSystemdInfo(e *core.RequestEvent) error {
|
||||
query := e.Request.URL.Query()
|
||||
systemID := query.Get("system")
|
||||
serviceName := query.Get("service")
|
||||
|
||||
if systemID == "" || serviceName == "" {
|
||||
return e.BadRequestError("Invalid system or service parameter", nil)
|
||||
}
|
||||
system, err := h.sm.GetSystem(systemID)
|
||||
if err != nil || !system.HasUser(e.App, e.Auth) {
|
||||
return e.NotFoundError("", nil)
|
||||
}
|
||||
// verify service exists before fetching details
|
||||
_, err = e.App.FindFirstRecordByFilter("systemd_services", "system = {:system} && name = {:name}", dbx.Params{
|
||||
"system": systemID,
|
||||
"name": serviceName,
|
||||
})
|
||||
if err != nil {
|
||||
return e.NotFoundError("", err)
|
||||
}
|
||||
details, err := system.FetchSystemdInfoFromAgent(serviceName)
|
||||
if err != nil {
|
||||
return e.InternalServerError("", err)
|
||||
}
|
||||
e.Response.Header().Set("Cache-Control", "public, max-age=60")
|
||||
return e.JSON(http.StatusOK, map[string]any{"details": details})
|
||||
}
|
||||
|
||||
// refreshSmartData handles POST /api/beszel/smart/refresh requests
|
||||
// Fetches fresh SMART data from the agent and updates the collection
|
||||
func (h *Hub) refreshSmartData(e *core.RequestEvent) error {
|
||||
systemID := e.Request.URL.Query().Get("system")
|
||||
if systemID == "" {
|
||||
return e.BadRequestError("Invalid system parameter", nil)
|
||||
}
|
||||
|
||||
system, err := h.sm.GetSystem(systemID)
|
||||
if err != nil || !system.HasUser(e.App, e.Auth) {
|
||||
return e.NotFoundError("", nil)
|
||||
}
|
||||
|
||||
if err := system.FetchAndSaveSmartDevices(); err != nil {
|
||||
return e.InternalServerError("", err)
|
||||
}
|
||||
|
||||
return e.JSON(http.StatusOK, map[string]string{"status": "ok"})
|
||||
}
|
||||
@@ -0,0 +1,131 @@
|
||||
package hub
|
||||
|
||||
import (
|
||||
"github.com/henrygd/beszel/internal/hub/utils"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
)
|
||||
|
||||
type collectionRules struct {
|
||||
list *string
|
||||
view *string
|
||||
create *string
|
||||
update *string
|
||||
delete *string
|
||||
}
|
||||
|
||||
// setCollectionAuthSettings applies Beszel's collection auth settings.
|
||||
func setCollectionAuthSettings(app core.App) error {
|
||||
usersCollection, err := app.FindCollectionByNameOrId("users")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
superusersCollection, err := app.FindCollectionByNameOrId(core.CollectionNameSuperusers)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// disable email auth if DISABLE_PASSWORD_AUTH env var is set
|
||||
disablePasswordAuth, _ := utils.GetEnv("DISABLE_PASSWORD_AUTH")
|
||||
usersCollection.PasswordAuth.Enabled = disablePasswordAuth != "true"
|
||||
usersCollection.PasswordAuth.IdentityFields = []string{"email"}
|
||||
// allow oauth user creation if USER_CREATION is set
|
||||
if userCreation, _ := utils.GetEnv("USER_CREATION"); userCreation == "true" {
|
||||
cr := "@request.context = 'oauth2'"
|
||||
usersCollection.CreateRule = &cr
|
||||
} else {
|
||||
usersCollection.CreateRule = nil
|
||||
}
|
||||
|
||||
// enable mfaOtp mfa if MFA_OTP env var is set
|
||||
mfaOtp, _ := utils.GetEnv("MFA_OTP")
|
||||
usersCollection.OTP.Length = 6
|
||||
superusersCollection.OTP.Length = 6
|
||||
usersCollection.OTP.Enabled = mfaOtp == "true"
|
||||
usersCollection.MFA.Enabled = mfaOtp == "true"
|
||||
superusersCollection.OTP.Enabled = mfaOtp == "true" || mfaOtp == "superusers"
|
||||
superusersCollection.MFA.Enabled = mfaOtp == "true" || mfaOtp == "superusers"
|
||||
if err := app.Save(superusersCollection); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := app.Save(usersCollection); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// When SHARE_ALL_SYSTEMS is enabled, any authenticated user can read
|
||||
// system-scoped data. Write rules continue to block readonly users.
|
||||
shareAllSystems, _ := utils.GetEnv("SHARE_ALL_SYSTEMS")
|
||||
|
||||
authenticatedRule := "@request.auth.id != \"\""
|
||||
systemsMemberRule := authenticatedRule + " && users.id ?= @request.auth.id"
|
||||
systemMemberRule := authenticatedRule + " && system.users.id ?= @request.auth.id"
|
||||
|
||||
systemsReadRule := systemsMemberRule
|
||||
systemScopedReadRule := systemMemberRule
|
||||
if shareAllSystems == "true" {
|
||||
systemsReadRule = authenticatedRule
|
||||
systemScopedReadRule = authenticatedRule
|
||||
}
|
||||
systemsWriteRule := systemsReadRule + " && @request.auth.role != \"readonly\""
|
||||
systemScopedWriteRule := systemScopedReadRule + " && @request.auth.role != \"readonly\""
|
||||
|
||||
if err := applyCollectionRules(app, []string{"systems"}, collectionRules{
|
||||
list: &systemsReadRule,
|
||||
view: &systemsReadRule,
|
||||
create: &systemsWriteRule,
|
||||
update: &systemsWriteRule,
|
||||
delete: &systemsWriteRule,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := applyCollectionRules(app, []string{"containers", "container_stats", "system_stats", "systemd_services"}, collectionRules{
|
||||
list: &systemScopedReadRule,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := applyCollectionRules(app, []string{"smart_devices"}, collectionRules{
|
||||
list: &systemScopedReadRule,
|
||||
view: &systemScopedReadRule,
|
||||
delete: &systemScopedWriteRule,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := applyCollectionRules(app, []string{"fingerprints"}, collectionRules{
|
||||
list: &systemScopedReadRule,
|
||||
view: &systemScopedReadRule,
|
||||
create: &systemScopedWriteRule,
|
||||
update: &systemScopedWriteRule,
|
||||
delete: &systemScopedWriteRule,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := applyCollectionRules(app, []string{"system_details"}, collectionRules{
|
||||
list: &systemScopedReadRule,
|
||||
view: &systemScopedReadRule,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func applyCollectionRules(app core.App, collectionNames []string, rules collectionRules) error {
|
||||
for _, collectionName := range collectionNames {
|
||||
collection, err := app.FindCollectionByNameOrId(collectionName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
collection.ListRule = rules.list
|
||||
collection.ViewRule = rules.view
|
||||
collection.CreateRule = rules.create
|
||||
collection.UpdateRule = rules.update
|
||||
collection.DeleteRule = rules.delete
|
||||
if err := app.Save(collection); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,527 @@
|
||||
//go:build testing
|
||||
|
||||
package hub_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
beszelTests "github.com/henrygd/beszel/internal/tests"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
pbTests "github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCollectionRulesDefault(t *testing.T) {
|
||||
hub, _ := beszelTests.NewTestHub(t.TempDir())
|
||||
defer hub.Cleanup()
|
||||
|
||||
const isUserMatchesUser = `@request.auth.id != "" && user = @request.auth.id`
|
||||
|
||||
const isUserInUsers = `@request.auth.id != "" && users.id ?= @request.auth.id`
|
||||
const isUserInUsersNotReadonly = `@request.auth.id != "" && users.id ?= @request.auth.id && @request.auth.role != "readonly"`
|
||||
|
||||
const isUserInSystemUsers = `@request.auth.id != "" && system.users.id ?= @request.auth.id`
|
||||
const isUserInSystemUsersNotReadonly = `@request.auth.id != "" && system.users.id ?= @request.auth.id && @request.auth.role != "readonly"`
|
||||
|
||||
// users collection
|
||||
usersCollection, err := hub.FindCollectionByNameOrId("users")
|
||||
assert.NoError(t, err, "Failed to find users collection")
|
||||
assert.True(t, usersCollection.PasswordAuth.Enabled)
|
||||
assert.Equal(t, usersCollection.PasswordAuth.IdentityFields, []string{"email"})
|
||||
assert.Nil(t, usersCollection.CreateRule)
|
||||
assert.False(t, usersCollection.MFA.Enabled)
|
||||
|
||||
// superusers collection
|
||||
superusersCollection, err := hub.FindCollectionByNameOrId(core.CollectionNameSuperusers)
|
||||
assert.NoError(t, err, "Failed to find superusers collection")
|
||||
assert.True(t, superusersCollection.PasswordAuth.Enabled)
|
||||
assert.Equal(t, superusersCollection.PasswordAuth.IdentityFields, []string{"email"})
|
||||
assert.Nil(t, superusersCollection.CreateRule)
|
||||
assert.False(t, superusersCollection.MFA.Enabled)
|
||||
|
||||
// alerts collection
|
||||
alertsCollection, err := hub.FindCollectionByNameOrId("alerts")
|
||||
require.NoError(t, err, "Failed to find alerts collection")
|
||||
assert.Equal(t, isUserMatchesUser, *alertsCollection.ListRule)
|
||||
assert.Nil(t, alertsCollection.ViewRule)
|
||||
assert.Equal(t, isUserMatchesUser, *alertsCollection.CreateRule)
|
||||
assert.Equal(t, isUserMatchesUser, *alertsCollection.UpdateRule)
|
||||
assert.Equal(t, isUserMatchesUser, *alertsCollection.DeleteRule)
|
||||
|
||||
// alerts_history collection
|
||||
alertsHistoryCollection, err := hub.FindCollectionByNameOrId("alerts_history")
|
||||
require.NoError(t, err, "Failed to find alerts_history collection")
|
||||
assert.Equal(t, isUserMatchesUser, *alertsHistoryCollection.ListRule)
|
||||
assert.Nil(t, alertsHistoryCollection.ViewRule)
|
||||
assert.Nil(t, alertsHistoryCollection.CreateRule)
|
||||
assert.Nil(t, alertsHistoryCollection.UpdateRule)
|
||||
assert.Equal(t, isUserMatchesUser, *alertsHistoryCollection.DeleteRule)
|
||||
|
||||
// containers collection
|
||||
containersCollection, err := hub.FindCollectionByNameOrId("containers")
|
||||
require.NoError(t, err, "Failed to find containers collection")
|
||||
assert.Equal(t, isUserInSystemUsers, *containersCollection.ListRule)
|
||||
assert.Nil(t, containersCollection.ViewRule)
|
||||
assert.Nil(t, containersCollection.CreateRule)
|
||||
assert.Nil(t, containersCollection.UpdateRule)
|
||||
assert.Nil(t, containersCollection.DeleteRule)
|
||||
|
||||
// container_stats collection
|
||||
containerStatsCollection, err := hub.FindCollectionByNameOrId("container_stats")
|
||||
require.NoError(t, err, "Failed to find container_stats collection")
|
||||
assert.Equal(t, isUserInSystemUsers, *containerStatsCollection.ListRule)
|
||||
assert.Nil(t, containerStatsCollection.ViewRule)
|
||||
assert.Nil(t, containerStatsCollection.CreateRule)
|
||||
assert.Nil(t, containerStatsCollection.UpdateRule)
|
||||
assert.Nil(t, containerStatsCollection.DeleteRule)
|
||||
|
||||
// fingerprints collection
|
||||
fingerprintsCollection, err := hub.FindCollectionByNameOrId("fingerprints")
|
||||
require.NoError(t, err, "Failed to find fingerprints collection")
|
||||
assert.Equal(t, isUserInSystemUsers, *fingerprintsCollection.ListRule)
|
||||
assert.Equal(t, isUserInSystemUsers, *fingerprintsCollection.ViewRule)
|
||||
assert.Equal(t, isUserInSystemUsersNotReadonly, *fingerprintsCollection.CreateRule)
|
||||
assert.Equal(t, isUserInSystemUsersNotReadonly, *fingerprintsCollection.UpdateRule)
|
||||
assert.Equal(t, isUserInSystemUsersNotReadonly, *fingerprintsCollection.DeleteRule)
|
||||
|
||||
// quiet_hours collection
|
||||
quietHoursCollection, err := hub.FindCollectionByNameOrId("quiet_hours")
|
||||
require.NoError(t, err, "Failed to find quiet_hours collection")
|
||||
assert.Equal(t, isUserMatchesUser, *quietHoursCollection.ListRule)
|
||||
assert.Equal(t, isUserMatchesUser, *quietHoursCollection.ViewRule)
|
||||
assert.Equal(t, isUserMatchesUser, *quietHoursCollection.CreateRule)
|
||||
assert.Equal(t, isUserMatchesUser, *quietHoursCollection.UpdateRule)
|
||||
assert.Equal(t, isUserMatchesUser, *quietHoursCollection.DeleteRule)
|
||||
|
||||
// smart_devices collection
|
||||
smartDevicesCollection, err := hub.FindCollectionByNameOrId("smart_devices")
|
||||
require.NoError(t, err, "Failed to find smart_devices collection")
|
||||
assert.Equal(t, isUserInSystemUsers, *smartDevicesCollection.ListRule)
|
||||
assert.Equal(t, isUserInSystemUsers, *smartDevicesCollection.ViewRule)
|
||||
assert.Nil(t, smartDevicesCollection.CreateRule)
|
||||
assert.Nil(t, smartDevicesCollection.UpdateRule)
|
||||
assert.Equal(t, isUserInSystemUsersNotReadonly, *smartDevicesCollection.DeleteRule)
|
||||
|
||||
// system_details collection
|
||||
systemDetailsCollection, err := hub.FindCollectionByNameOrId("system_details")
|
||||
require.NoError(t, err, "Failed to find system_details collection")
|
||||
assert.Equal(t, isUserInSystemUsers, *systemDetailsCollection.ListRule)
|
||||
assert.Equal(t, isUserInSystemUsers, *systemDetailsCollection.ViewRule)
|
||||
assert.Nil(t, systemDetailsCollection.CreateRule)
|
||||
assert.Nil(t, systemDetailsCollection.UpdateRule)
|
||||
assert.Nil(t, systemDetailsCollection.DeleteRule)
|
||||
|
||||
// system_stats collection
|
||||
systemStatsCollection, err := hub.FindCollectionByNameOrId("system_stats")
|
||||
require.NoError(t, err, "Failed to find system_stats collection")
|
||||
assert.Equal(t, isUserInSystemUsers, *systemStatsCollection.ListRule)
|
||||
assert.Nil(t, systemStatsCollection.ViewRule)
|
||||
assert.Nil(t, systemStatsCollection.CreateRule)
|
||||
assert.Nil(t, systemStatsCollection.UpdateRule)
|
||||
assert.Nil(t, systemStatsCollection.DeleteRule)
|
||||
|
||||
// systemd_services collection
|
||||
systemdServicesCollection, err := hub.FindCollectionByNameOrId("systemd_services")
|
||||
require.NoError(t, err, "Failed to find systemd_services collection")
|
||||
assert.Equal(t, isUserInSystemUsers, *systemdServicesCollection.ListRule)
|
||||
assert.Nil(t, systemdServicesCollection.ViewRule)
|
||||
assert.Nil(t, systemdServicesCollection.CreateRule)
|
||||
assert.Nil(t, systemdServicesCollection.UpdateRule)
|
||||
assert.Nil(t, systemdServicesCollection.DeleteRule)
|
||||
|
||||
// systems collection
|
||||
systemsCollection, err := hub.FindCollectionByNameOrId("systems")
|
||||
require.NoError(t, err, "Failed to find systems collection")
|
||||
assert.Equal(t, isUserInUsers, *systemsCollection.ListRule)
|
||||
assert.Equal(t, isUserInUsers, *systemsCollection.ViewRule)
|
||||
assert.Equal(t, isUserInUsersNotReadonly, *systemsCollection.CreateRule)
|
||||
assert.Equal(t, isUserInUsersNotReadonly, *systemsCollection.UpdateRule)
|
||||
assert.Equal(t, isUserInUsersNotReadonly, *systemsCollection.DeleteRule)
|
||||
|
||||
// universal_tokens collection
|
||||
universalTokensCollection, err := hub.FindCollectionByNameOrId("universal_tokens")
|
||||
require.NoError(t, err, "Failed to find universal_tokens collection")
|
||||
assert.Nil(t, universalTokensCollection.ListRule)
|
||||
assert.Nil(t, universalTokensCollection.ViewRule)
|
||||
assert.Nil(t, universalTokensCollection.CreateRule)
|
||||
assert.Nil(t, universalTokensCollection.UpdateRule)
|
||||
assert.Nil(t, universalTokensCollection.DeleteRule)
|
||||
|
||||
// user_settings collection
|
||||
userSettingsCollection, err := hub.FindCollectionByNameOrId("user_settings")
|
||||
require.NoError(t, err, "Failed to find user_settings collection")
|
||||
assert.Equal(t, isUserMatchesUser, *userSettingsCollection.ListRule)
|
||||
assert.Nil(t, userSettingsCollection.ViewRule)
|
||||
assert.Equal(t, isUserMatchesUser, *userSettingsCollection.CreateRule)
|
||||
assert.Equal(t, isUserMatchesUser, *userSettingsCollection.UpdateRule)
|
||||
assert.Nil(t, userSettingsCollection.DeleteRule)
|
||||
}
|
||||
|
||||
func TestCollectionRulesShareAllSystems(t *testing.T) {
|
||||
t.Setenv("SHARE_ALL_SYSTEMS", "true")
|
||||
hub, _ := beszelTests.NewTestHub(t.TempDir())
|
||||
defer hub.Cleanup()
|
||||
|
||||
const isUser = `@request.auth.id != ""`
|
||||
const isUserNotReadonly = `@request.auth.id != "" && @request.auth.role != "readonly"`
|
||||
|
||||
const isUserMatchesUser = `@request.auth.id != "" && user = @request.auth.id`
|
||||
|
||||
// alerts collection
|
||||
alertsCollection, err := hub.FindCollectionByNameOrId("alerts")
|
||||
require.NoError(t, err, "Failed to find alerts collection")
|
||||
assert.Equal(t, isUserMatchesUser, *alertsCollection.ListRule)
|
||||
assert.Nil(t, alertsCollection.ViewRule)
|
||||
assert.Equal(t, isUserMatchesUser, *alertsCollection.CreateRule)
|
||||
assert.Equal(t, isUserMatchesUser, *alertsCollection.UpdateRule)
|
||||
assert.Equal(t, isUserMatchesUser, *alertsCollection.DeleteRule)
|
||||
|
||||
// alerts_history collection
|
||||
alertsHistoryCollection, err := hub.FindCollectionByNameOrId("alerts_history")
|
||||
require.NoError(t, err, "Failed to find alerts_history collection")
|
||||
assert.Equal(t, isUserMatchesUser, *alertsHistoryCollection.ListRule)
|
||||
assert.Nil(t, alertsHistoryCollection.ViewRule)
|
||||
assert.Nil(t, alertsHistoryCollection.CreateRule)
|
||||
assert.Nil(t, alertsHistoryCollection.UpdateRule)
|
||||
assert.Equal(t, isUserMatchesUser, *alertsHistoryCollection.DeleteRule)
|
||||
|
||||
// containers collection
|
||||
containersCollection, err := hub.FindCollectionByNameOrId("containers")
|
||||
require.NoError(t, err, "Failed to find containers collection")
|
||||
assert.Equal(t, isUser, *containersCollection.ListRule)
|
||||
assert.Nil(t, containersCollection.ViewRule)
|
||||
assert.Nil(t, containersCollection.CreateRule)
|
||||
assert.Nil(t, containersCollection.UpdateRule)
|
||||
assert.Nil(t, containersCollection.DeleteRule)
|
||||
|
||||
// container_stats collection
|
||||
containerStatsCollection, err := hub.FindCollectionByNameOrId("container_stats")
|
||||
require.NoError(t, err, "Failed to find container_stats collection")
|
||||
assert.Equal(t, isUser, *containerStatsCollection.ListRule)
|
||||
assert.Nil(t, containerStatsCollection.ViewRule)
|
||||
assert.Nil(t, containerStatsCollection.CreateRule)
|
||||
assert.Nil(t, containerStatsCollection.UpdateRule)
|
||||
assert.Nil(t, containerStatsCollection.DeleteRule)
|
||||
|
||||
// fingerprints collection
|
||||
fingerprintsCollection, err := hub.FindCollectionByNameOrId("fingerprints")
|
||||
require.NoError(t, err, "Failed to find fingerprints collection")
|
||||
assert.Equal(t, isUser, *fingerprintsCollection.ListRule)
|
||||
assert.Equal(t, isUser, *fingerprintsCollection.ViewRule)
|
||||
assert.Equal(t, isUserNotReadonly, *fingerprintsCollection.CreateRule)
|
||||
assert.Equal(t, isUserNotReadonly, *fingerprintsCollection.UpdateRule)
|
||||
assert.Equal(t, isUserNotReadonly, *fingerprintsCollection.DeleteRule)
|
||||
|
||||
// quiet_hours collection
|
||||
quietHoursCollection, err := hub.FindCollectionByNameOrId("quiet_hours")
|
||||
require.NoError(t, err, "Failed to find quiet_hours collection")
|
||||
assert.Equal(t, isUserMatchesUser, *quietHoursCollection.ListRule)
|
||||
assert.Equal(t, isUserMatchesUser, *quietHoursCollection.ViewRule)
|
||||
assert.Equal(t, isUserMatchesUser, *quietHoursCollection.CreateRule)
|
||||
assert.Equal(t, isUserMatchesUser, *quietHoursCollection.UpdateRule)
|
||||
assert.Equal(t, isUserMatchesUser, *quietHoursCollection.DeleteRule)
|
||||
|
||||
// smart_devices collection
|
||||
smartDevicesCollection, err := hub.FindCollectionByNameOrId("smart_devices")
|
||||
require.NoError(t, err, "Failed to find smart_devices collection")
|
||||
assert.Equal(t, isUser, *smartDevicesCollection.ListRule)
|
||||
assert.Equal(t, isUser, *smartDevicesCollection.ViewRule)
|
||||
assert.Nil(t, smartDevicesCollection.CreateRule)
|
||||
assert.Nil(t, smartDevicesCollection.UpdateRule)
|
||||
assert.Equal(t, isUserNotReadonly, *smartDevicesCollection.DeleteRule)
|
||||
|
||||
// system_details collection
|
||||
systemDetailsCollection, err := hub.FindCollectionByNameOrId("system_details")
|
||||
require.NoError(t, err, "Failed to find system_details collection")
|
||||
assert.Equal(t, isUser, *systemDetailsCollection.ListRule)
|
||||
assert.Equal(t, isUser, *systemDetailsCollection.ViewRule)
|
||||
assert.Nil(t, systemDetailsCollection.CreateRule)
|
||||
assert.Nil(t, systemDetailsCollection.UpdateRule)
|
||||
assert.Nil(t, systemDetailsCollection.DeleteRule)
|
||||
|
||||
// system_stats collection
|
||||
systemStatsCollection, err := hub.FindCollectionByNameOrId("system_stats")
|
||||
require.NoError(t, err, "Failed to find system_stats collection")
|
||||
assert.Equal(t, isUser, *systemStatsCollection.ListRule)
|
||||
assert.Nil(t, systemStatsCollection.ViewRule)
|
||||
assert.Nil(t, systemStatsCollection.CreateRule)
|
||||
assert.Nil(t, systemStatsCollection.UpdateRule)
|
||||
assert.Nil(t, systemStatsCollection.DeleteRule)
|
||||
|
||||
// systemd_services collection
|
||||
systemdServicesCollection, err := hub.FindCollectionByNameOrId("systemd_services")
|
||||
require.NoError(t, err, "Failed to find systemd_services collection")
|
||||
assert.Equal(t, isUser, *systemdServicesCollection.ListRule)
|
||||
assert.Nil(t, systemdServicesCollection.ViewRule)
|
||||
assert.Nil(t, systemdServicesCollection.CreateRule)
|
||||
assert.Nil(t, systemdServicesCollection.UpdateRule)
|
||||
assert.Nil(t, systemdServicesCollection.DeleteRule)
|
||||
|
||||
// systems collection
|
||||
systemsCollection, err := hub.FindCollectionByNameOrId("systems")
|
||||
require.NoError(t, err, "Failed to find systems collection")
|
||||
assert.Equal(t, isUser, *systemsCollection.ListRule)
|
||||
assert.Equal(t, isUser, *systemsCollection.ViewRule)
|
||||
assert.Equal(t, isUserNotReadonly, *systemsCollection.CreateRule)
|
||||
assert.Equal(t, isUserNotReadonly, *systemsCollection.UpdateRule)
|
||||
assert.Equal(t, isUserNotReadonly, *systemsCollection.DeleteRule)
|
||||
|
||||
// universal_tokens collection
|
||||
universalTokensCollection, err := hub.FindCollectionByNameOrId("universal_tokens")
|
||||
require.NoError(t, err, "Failed to find universal_tokens collection")
|
||||
assert.Nil(t, universalTokensCollection.ListRule)
|
||||
assert.Nil(t, universalTokensCollection.ViewRule)
|
||||
assert.Nil(t, universalTokensCollection.CreateRule)
|
||||
assert.Nil(t, universalTokensCollection.UpdateRule)
|
||||
assert.Nil(t, universalTokensCollection.DeleteRule)
|
||||
|
||||
// user_settings collection
|
||||
userSettingsCollection, err := hub.FindCollectionByNameOrId("user_settings")
|
||||
require.NoError(t, err, "Failed to find user_settings collection")
|
||||
assert.Equal(t, isUserMatchesUser, *userSettingsCollection.ListRule)
|
||||
assert.Nil(t, userSettingsCollection.ViewRule)
|
||||
assert.Equal(t, isUserMatchesUser, *userSettingsCollection.CreateRule)
|
||||
assert.Equal(t, isUserMatchesUser, *userSettingsCollection.UpdateRule)
|
||||
assert.Nil(t, userSettingsCollection.DeleteRule)
|
||||
}
|
||||
|
||||
func TestDisablePasswordAuth(t *testing.T) {
|
||||
t.Setenv("DISABLE_PASSWORD_AUTH", "true")
|
||||
hub, _ := beszelTests.NewTestHub(t.TempDir())
|
||||
defer hub.Cleanup()
|
||||
|
||||
usersCollection, err := hub.FindCollectionByNameOrId("users")
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, usersCollection.PasswordAuth.Enabled)
|
||||
}
|
||||
|
||||
func TestUserCreation(t *testing.T) {
|
||||
t.Setenv("USER_CREATION", "true")
|
||||
hub, _ := beszelTests.NewTestHub(t.TempDir())
|
||||
defer hub.Cleanup()
|
||||
|
||||
usersCollection, err := hub.FindCollectionByNameOrId("users")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "@request.context = 'oauth2'", *usersCollection.CreateRule)
|
||||
}
|
||||
|
||||
func TestMFAOtp(t *testing.T) {
|
||||
t.Setenv("MFA_OTP", "true")
|
||||
hub, _ := beszelTests.NewTestHub(t.TempDir())
|
||||
defer hub.Cleanup()
|
||||
|
||||
usersCollection, err := hub.FindCollectionByNameOrId("users")
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, usersCollection.OTP.Enabled)
|
||||
assert.True(t, usersCollection.MFA.Enabled)
|
||||
|
||||
superusersCollection, err := hub.FindCollectionByNameOrId(core.CollectionNameSuperusers)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, superusersCollection.OTP.Enabled)
|
||||
assert.True(t, superusersCollection.MFA.Enabled)
|
||||
}
|
||||
|
||||
func TestApiCollectionsAuthRules(t *testing.T) {
|
||||
hub, _ := beszelTests.NewTestHub(t.TempDir())
|
||||
defer hub.Cleanup()
|
||||
|
||||
hub.StartHub()
|
||||
|
||||
user1, _ := beszelTests.CreateUser(hub, "user1@example.com", "password")
|
||||
user1Token, _ := user1.NewAuthToken()
|
||||
|
||||
user2, _ := beszelTests.CreateUser(hub, "user2@example.com", "password")
|
||||
// user2Token, _ := user2.NewAuthToken()
|
||||
|
||||
userReadonly, _ := beszelTests.CreateUserWithRole(hub, "userreadonly@example.com", "password", "readonly")
|
||||
userReadonlyToken, _ := userReadonly.NewAuthToken()
|
||||
|
||||
userOneSystem, _ := beszelTests.CreateRecord(hub, "systems", map[string]any{
|
||||
"name": "system1",
|
||||
"users": []string{user1.Id},
|
||||
"host": "127.0.0.1",
|
||||
})
|
||||
|
||||
sharedSystem, _ := beszelTests.CreateRecord(hub, "systems", map[string]any{
|
||||
"name": "system2",
|
||||
"users": []string{user1.Id, user2.Id},
|
||||
"host": "127.0.0.2",
|
||||
})
|
||||
|
||||
userTwoSystem, _ := beszelTests.CreateRecord(hub, "systems", map[string]any{
|
||||
"name": "system3",
|
||||
"users": []string{user2.Id},
|
||||
"host": "127.0.0.2",
|
||||
})
|
||||
|
||||
userRecords, _ := hub.CountRecords("users")
|
||||
assert.EqualValues(t, 3, userRecords, "all users should be created")
|
||||
|
||||
systemRecords, _ := hub.CountRecords("systems")
|
||||
assert.EqualValues(t, 3, systemRecords, "all systems should be created")
|
||||
|
||||
testAppFactory := func(t testing.TB) *pbTests.TestApp {
|
||||
return hub.TestApp
|
||||
}
|
||||
|
||||
scenarios := []beszelTests.ApiScenario{
|
||||
{
|
||||
Name: "Unauthorized user cannot list systems",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/systems/records",
|
||||
ExpectedStatus: 200, // https://github.com/pocketbase/pocketbase/discussions/1570
|
||||
TestAppFactory: testAppFactory,
|
||||
ExpectedContent: []string{`"items":[]`, `"totalItems":0`},
|
||||
NotExpectedContent: []string{userOneSystem.Id, sharedSystem.Id, userTwoSystem.Id},
|
||||
},
|
||||
{
|
||||
Name: "Unauthorized user cannot delete a system",
|
||||
Method: http.MethodDelete,
|
||||
URL: fmt.Sprintf("/api/collections/systems/records/%s", userOneSystem.Id),
|
||||
ExpectedStatus: 404,
|
||||
TestAppFactory: testAppFactory,
|
||||
ExpectedContent: []string{"resource wasn't found"},
|
||||
NotExpectedContent: []string{userOneSystem.Id},
|
||||
BeforeTestFunc: func(t testing.TB, app *pbTests.TestApp, e *core.ServeEvent) {
|
||||
systemsCount, _ := app.CountRecords("systems")
|
||||
assert.EqualValues(t, 3, systemsCount, "should have 3 systems before deletion")
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *pbTests.TestApp, res *http.Response) {
|
||||
systemsCount, _ := app.CountRecords("systems")
|
||||
assert.EqualValues(t, 3, systemsCount, "should still have 3 systems after failed deletion")
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "User 1 can list their own systems",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/systems/records",
|
||||
Headers: map[string]string{
|
||||
"Authorization": user1Token,
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{userOneSystem.Id, sharedSystem.Id},
|
||||
NotExpectedContent: []string{userTwoSystem.Id},
|
||||
TestAppFactory: testAppFactory,
|
||||
},
|
||||
{
|
||||
Name: "User 1 cannot list user 2's system",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/systems/records",
|
||||
Headers: map[string]string{
|
||||
"Authorization": user1Token,
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{userOneSystem.Id, sharedSystem.Id},
|
||||
NotExpectedContent: []string{userTwoSystem.Id},
|
||||
TestAppFactory: testAppFactory,
|
||||
},
|
||||
{
|
||||
Name: "User 1 can see user 2's system if SHARE_ALL_SYSTEMS is enabled",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/systems/records",
|
||||
Headers: map[string]string{
|
||||
"Authorization": user1Token,
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{userOneSystem.Id, sharedSystem.Id, userTwoSystem.Id},
|
||||
TestAppFactory: testAppFactory,
|
||||
BeforeTestFunc: func(t testing.TB, app *pbTests.TestApp, e *core.ServeEvent) {
|
||||
t.Setenv("SHARE_ALL_SYSTEMS", "true")
|
||||
hub.SetCollectionAuthSettings()
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *pbTests.TestApp, res *http.Response) {
|
||||
t.Setenv("SHARE_ALL_SYSTEMS", "")
|
||||
hub.SetCollectionAuthSettings()
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "User 1 can delete their own system",
|
||||
Method: http.MethodDelete,
|
||||
URL: fmt.Sprintf("/api/collections/systems/records/%s", userOneSystem.Id),
|
||||
Headers: map[string]string{
|
||||
"Authorization": user1Token,
|
||||
},
|
||||
ExpectedStatus: 204,
|
||||
TestAppFactory: testAppFactory,
|
||||
BeforeTestFunc: func(t testing.TB, app *pbTests.TestApp, e *core.ServeEvent) {
|
||||
systemsCount, _ := app.CountRecords("systems")
|
||||
assert.EqualValues(t, 3, systemsCount, "should have 3 systems before deletion")
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *pbTests.TestApp, res *http.Response) {
|
||||
systemsCount, _ := app.CountRecords("systems")
|
||||
assert.EqualValues(t, 2, systemsCount, "should have 2 systems after deletion")
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "User 1 cannot delete user 2's system",
|
||||
Method: http.MethodDelete,
|
||||
URL: fmt.Sprintf("/api/collections/systems/records/%s", userTwoSystem.Id),
|
||||
Headers: map[string]string{
|
||||
"Authorization": user1Token,
|
||||
},
|
||||
ExpectedStatus: 404,
|
||||
TestAppFactory: testAppFactory,
|
||||
ExpectedContent: []string{"resource wasn't found"},
|
||||
BeforeTestFunc: func(t testing.TB, app *pbTests.TestApp, e *core.ServeEvent) {
|
||||
systemsCount, _ := app.CountRecords("systems")
|
||||
assert.EqualValues(t, 2, systemsCount)
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *pbTests.TestApp, res *http.Response) {
|
||||
systemsCount, _ := app.CountRecords("systems")
|
||||
assert.EqualValues(t, 2, systemsCount)
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "Readonly cannot delete a system even if SHARE_ALL_SYSTEMS is enabled",
|
||||
Method: http.MethodDelete,
|
||||
URL: fmt.Sprintf("/api/collections/systems/records/%s", sharedSystem.Id),
|
||||
Headers: map[string]string{
|
||||
"Authorization": userReadonlyToken,
|
||||
},
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{"resource wasn't found"},
|
||||
TestAppFactory: testAppFactory,
|
||||
BeforeTestFunc: func(t testing.TB, app *pbTests.TestApp, e *core.ServeEvent) {
|
||||
t.Setenv("SHARE_ALL_SYSTEMS", "true")
|
||||
hub.SetCollectionAuthSettings()
|
||||
systemsCount, _ := app.CountRecords("systems")
|
||||
assert.EqualValues(t, 2, systemsCount)
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *pbTests.TestApp, res *http.Response) {
|
||||
t.Setenv("SHARE_ALL_SYSTEMS", "")
|
||||
hub.SetCollectionAuthSettings()
|
||||
systemsCount, _ := app.CountRecords("systems")
|
||||
assert.EqualValues(t, 2, systemsCount)
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "User 1 can delete user 2's system if SHARE_ALL_SYSTEMS is enabled",
|
||||
Method: http.MethodDelete,
|
||||
URL: fmt.Sprintf("/api/collections/systems/records/%s", userTwoSystem.Id),
|
||||
Headers: map[string]string{
|
||||
"Authorization": user1Token,
|
||||
},
|
||||
ExpectedStatus: 204,
|
||||
TestAppFactory: testAppFactory,
|
||||
BeforeTestFunc: func(t testing.TB, app *pbTests.TestApp, e *core.ServeEvent) {
|
||||
t.Setenv("SHARE_ALL_SYSTEMS", "true")
|
||||
hub.SetCollectionAuthSettings()
|
||||
systemsCount, _ := app.CountRecords("systems")
|
||||
assert.EqualValues(t, 2, systemsCount)
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *pbTests.TestApp, res *http.Response) {
|
||||
t.Setenv("SHARE_ALL_SYSTEMS", "")
|
||||
hub.SetCollectionAuthSettings()
|
||||
systemsCount, _ := app.CountRecords("systems")
|
||||
assert.EqualValues(t, 1, systemsCount)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,287 @@
|
||||
// Package config provides functions for syncing systems with the config.yml file
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/henrygd/beszel/internal/entities/system"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/spf13/cast"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
type config struct {
|
||||
Systems []systemConfig `yaml:"systems"`
|
||||
}
|
||||
|
||||
type systemConfig struct {
|
||||
Name string `yaml:"name"`
|
||||
Host string `yaml:"host"`
|
||||
Port uint16 `yaml:"port,omitempty"`
|
||||
Token string `yaml:"token,omitempty"`
|
||||
Users []string `yaml:"users"`
|
||||
}
|
||||
|
||||
// Syncs systems with the config.yml file
|
||||
func SyncSystems(e *core.ServeEvent) error {
|
||||
h := e.App
|
||||
configPath := filepath.Join(h.DataDir(), "config.yml")
|
||||
configData, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var config config
|
||||
err = yaml.Unmarshal(configData, &config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse config.yml: %v", err)
|
||||
}
|
||||
|
||||
if len(config.Systems) == 0 {
|
||||
log.Println("No systems defined in config.yml.")
|
||||
return nil
|
||||
}
|
||||
|
||||
var firstUser *core.Record
|
||||
|
||||
// Create a map of email to user ID
|
||||
userEmailToID := make(map[string]string)
|
||||
users, err := h.FindAllRecords("users", dbx.NewExp("id != ''"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(users) > 0 {
|
||||
firstUser = users[0]
|
||||
for _, user := range users {
|
||||
userEmailToID[user.GetString("email")] = user.Id
|
||||
}
|
||||
}
|
||||
|
||||
// add default settings for systems if not defined in config
|
||||
for i := range config.Systems {
|
||||
system := &config.Systems[i]
|
||||
if system.Port == 0 {
|
||||
system.Port = 45876
|
||||
}
|
||||
if len(users) > 0 && len(system.Users) == 0 {
|
||||
// default to first user if none are defined
|
||||
system.Users = []string{firstUser.Id}
|
||||
} else {
|
||||
// Convert email addresses to user IDs
|
||||
userIDs := make([]string, 0, len(system.Users))
|
||||
for _, email := range system.Users {
|
||||
if id, ok := userEmailToID[email]; ok {
|
||||
userIDs = append(userIDs, id)
|
||||
} else {
|
||||
log.Printf("User %s not found", email)
|
||||
}
|
||||
}
|
||||
system.Users = userIDs
|
||||
}
|
||||
}
|
||||
|
||||
// Get existing systems
|
||||
existingSystems, err := h.FindAllRecords("systems", dbx.NewExp("id != ''"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Create a map of existing systems
|
||||
existingSystemsMap := make(map[string]*core.Record)
|
||||
for _, system := range existingSystems {
|
||||
key := system.GetString("name") + system.GetString("host") + system.GetString("port")
|
||||
existingSystemsMap[key] = system
|
||||
}
|
||||
|
||||
// Process systems from config
|
||||
for _, sysConfig := range config.Systems {
|
||||
key := sysConfig.Name + sysConfig.Host + cast.ToString(sysConfig.Port)
|
||||
if existingSystem, ok := existingSystemsMap[key]; ok {
|
||||
// Update existing system
|
||||
existingSystem.Set("name", sysConfig.Name)
|
||||
existingSystem.Set("users", sysConfig.Users)
|
||||
existingSystem.Set("port", sysConfig.Port)
|
||||
if err := h.Save(existingSystem); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Only update token if one is specified in config, otherwise preserve existing token
|
||||
if sysConfig.Token != "" {
|
||||
if err := updateFingerprintToken(h, existingSystem.Id, sysConfig.Token); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
delete(existingSystemsMap, key)
|
||||
} else {
|
||||
// Create new system
|
||||
systemsCollection, err := h.FindCollectionByNameOrId("systems")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to find systems collection: %v", err)
|
||||
}
|
||||
newSystem := core.NewRecord(systemsCollection)
|
||||
newSystem.Set("name", sysConfig.Name)
|
||||
newSystem.Set("host", sysConfig.Host)
|
||||
newSystem.Set("port", sysConfig.Port)
|
||||
newSystem.Set("users", sysConfig.Users)
|
||||
newSystem.Set("info", system.Info{})
|
||||
newSystem.Set("status", "pending")
|
||||
if err := h.Save(newSystem); err != nil {
|
||||
return fmt.Errorf("failed to create new system: %v", err)
|
||||
}
|
||||
|
||||
// For new systems, generate token if not provided
|
||||
token := sysConfig.Token
|
||||
if token == "" {
|
||||
token = uuid.New().String()
|
||||
}
|
||||
|
||||
// Create fingerprint record for new system
|
||||
if err := createFingerprintRecord(h, newSystem.Id, token); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Delete systems not in config (and their fingerprint records will cascade delete)
|
||||
for _, system := range existingSystemsMap {
|
||||
if err := h.Delete(system); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
log.Println("Systems synced with config.yml")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Generates content for the config.yml file as a YAML string
|
||||
func generateYAML(h core.App) (string, error) {
|
||||
// Fetch all systems from the database
|
||||
systems, err := h.FindRecordsByFilter("systems", "id != ''", "name", -1, 0)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Create a Config struct to hold the data
|
||||
config := config{
|
||||
Systems: make([]systemConfig, 0, len(systems)),
|
||||
}
|
||||
|
||||
// Fetch all users at once
|
||||
allUserIDs := make([]string, 0)
|
||||
for _, system := range systems {
|
||||
allUserIDs = append(allUserIDs, system.GetStringSlice("users")...)
|
||||
}
|
||||
userEmailMap, err := getUserEmailMap(h, allUserIDs)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Fetch all fingerprint records to get tokens
|
||||
type fingerprintData struct {
|
||||
ID string `db:"id"`
|
||||
System string `db:"system"`
|
||||
Token string `db:"token"`
|
||||
}
|
||||
var fingerprints []fingerprintData
|
||||
err = h.DB().NewQuery("SELECT id, system, token FROM fingerprints").All(&fingerprints)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Create a map of system ID to token
|
||||
systemTokenMap := make(map[string]string)
|
||||
for _, fingerprint := range fingerprints {
|
||||
systemTokenMap[fingerprint.System] = fingerprint.Token
|
||||
}
|
||||
|
||||
// Populate the Config struct with system data
|
||||
for _, system := range systems {
|
||||
userIDs := system.GetStringSlice("users")
|
||||
userEmails := make([]string, 0, len(userIDs))
|
||||
for _, userID := range userIDs {
|
||||
if email, ok := userEmailMap[userID]; ok {
|
||||
userEmails = append(userEmails, email)
|
||||
}
|
||||
}
|
||||
|
||||
sysConfig := systemConfig{
|
||||
Name: system.GetString("name"),
|
||||
Host: system.GetString("host"),
|
||||
Port: cast.ToUint16(system.Get("port")),
|
||||
Users: userEmails,
|
||||
Token: systemTokenMap[system.Id],
|
||||
}
|
||||
config.Systems = append(config.Systems, sysConfig)
|
||||
}
|
||||
|
||||
// Marshal the Config struct to YAML
|
||||
yamlData, err := yaml.Marshal(&config)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Add a header to the YAML
|
||||
yamlData = append([]byte("# Values for port, users, and token are optional.\n# Defaults are port 45876, the first created user, and a generated UUID token.\n\n"), yamlData...)
|
||||
|
||||
return string(yamlData), nil
|
||||
}
|
||||
|
||||
// New helper function to get a map of user IDs to emails
|
||||
func getUserEmailMap(h core.App, userIDs []string) (map[string]string, error) {
|
||||
users, err := h.FindRecordsByIds("users", userIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userEmailMap := make(map[string]string, len(users))
|
||||
for _, user := range users {
|
||||
userEmailMap[user.Id] = user.GetString("email")
|
||||
}
|
||||
|
||||
return userEmailMap, nil
|
||||
}
|
||||
|
||||
// Helper function to update or create fingerprint token for an existing system
|
||||
func updateFingerprintToken(app core.App, systemID, token string) error {
|
||||
// Try to find existing fingerprint record
|
||||
fingerprint, err := app.FindFirstRecordByFilter("fingerprints", "system = {:system}", dbx.Params{"system": systemID})
|
||||
if err != nil {
|
||||
// If no fingerprint record exists, create one
|
||||
return createFingerprintRecord(app, systemID, token)
|
||||
}
|
||||
|
||||
// Update existing fingerprint record with new token (keep existing fingerprint)
|
||||
fingerprint.Set("token", token)
|
||||
return app.Save(fingerprint)
|
||||
}
|
||||
|
||||
// Helper function to create a new fingerprint record for a system
|
||||
func createFingerprintRecord(app core.App, systemID, token string) error {
|
||||
fingerprintsCollection, err := app.FindCollectionByNameOrId("fingerprints")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to find fingerprints collection: %v", err)
|
||||
}
|
||||
|
||||
newFingerprint := core.NewRecord(fingerprintsCollection)
|
||||
newFingerprint.Set("system", systemID)
|
||||
newFingerprint.Set("token", token)
|
||||
newFingerprint.Set("fingerprint", "") // Empty fingerprint, will be set on first connection
|
||||
|
||||
return app.Save(newFingerprint)
|
||||
}
|
||||
|
||||
// Returns the current config.yml file as a JSON object
|
||||
func GetYamlConfig(e *core.RequestEvent) error {
|
||||
configContent, err := generateYAML(e.App)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return e.JSON(200, map[string]string{"config": configContent})
|
||||
}
|
||||
@@ -0,0 +1,246 @@
|
||||
//go:build testing
|
||||
|
||||
package config_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/henrygd/beszel/internal/tests"
|
||||
|
||||
"github.com/henrygd/beszel/internal/hub/config"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// Config struct for testing (copied from config package since it's not exported)
|
||||
type testConfig struct {
|
||||
Systems []testSystemConfig `yaml:"systems"`
|
||||
}
|
||||
|
||||
type testSystemConfig struct {
|
||||
Name string `yaml:"name"`
|
||||
Host string `yaml:"host"`
|
||||
Port uint16 `yaml:"port,omitempty"`
|
||||
Users []string `yaml:"users"`
|
||||
Token string `yaml:"token,omitempty"`
|
||||
}
|
||||
|
||||
// Helper function to create a test system for config tests
|
||||
// func createConfigTestSystem(app core.App, name, host string, port uint16, userIDs []string) (*core.Record, error) {
|
||||
// systemCollection, err := app.FindCollectionByNameOrId("systems")
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
|
||||
// system := core.NewRecord(systemCollection)
|
||||
// system.Set("name", name)
|
||||
// system.Set("host", host)
|
||||
// system.Set("port", port)
|
||||
// system.Set("users", userIDs)
|
||||
// system.Set("status", "pending")
|
||||
|
||||
// return system, app.Save(system)
|
||||
// }
|
||||
|
||||
// Helper function to create a fingerprint record
|
||||
func createConfigTestFingerprint(app core.App, systemID, token, fingerprint string) (*core.Record, error) {
|
||||
fingerprintCollection, err := app.FindCollectionByNameOrId("fingerprints")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fp := core.NewRecord(fingerprintCollection)
|
||||
fp.Set("system", systemID)
|
||||
fp.Set("token", token)
|
||||
fp.Set("fingerprint", fingerprint)
|
||||
|
||||
return fp, app.Save(fp)
|
||||
}
|
||||
|
||||
// TestConfigSyncWithTokens tests the config.SyncSystems function with various token scenarios
|
||||
func TestConfigSyncWithTokens(t *testing.T) {
|
||||
testHub, err := tests.NewTestHub()
|
||||
require.NoError(t, err)
|
||||
defer testHub.Cleanup()
|
||||
|
||||
// Create test user
|
||||
user, err := tests.CreateUser(testHub.App, "admin@example.com", "testtesttest")
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
setupFunc func() (string, *core.Record, *core.Record) // Returns: existing token, system record, fingerprint record
|
||||
configYAML string
|
||||
expectToken string // Expected token after sync
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "new system with token in config",
|
||||
setupFunc: func() (string, *core.Record, *core.Record) {
|
||||
return "", nil, nil // No existing system
|
||||
},
|
||||
configYAML: `systems:
|
||||
- name: "new-server"
|
||||
host: "new.example.com"
|
||||
port: 45876
|
||||
users:
|
||||
- "admin@example.com"
|
||||
token: "explicit-token-123"`,
|
||||
expectToken: "explicit-token-123",
|
||||
description: "New system should use token from config",
|
||||
},
|
||||
{
|
||||
name: "existing system without token in config (preserve existing)",
|
||||
setupFunc: func() (string, *core.Record, *core.Record) {
|
||||
// Create existing system and fingerprint
|
||||
system, err := tests.CreateRecord(testHub.App, "systems", map[string]any{
|
||||
"name": "preserve-server",
|
||||
"host": "preserve.example.com",
|
||||
"port": 45876,
|
||||
"users": []string{user.Id},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
fingerprint, err := createConfigTestFingerprint(testHub.App, system.Id, "preserve-token-999", "preserve-fingerprint")
|
||||
require.NoError(t, err)
|
||||
|
||||
return "preserve-token-999", system, fingerprint
|
||||
},
|
||||
configYAML: `systems:
|
||||
- name: "preserve-server"
|
||||
host: "preserve.example.com"
|
||||
port: 45876
|
||||
users:
|
||||
- "admin@example.com"`,
|
||||
expectToken: "preserve-token-999",
|
||||
description: "Existing system should preserve original token when config doesn't specify one",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Setup test data
|
||||
_, existingSystem, existingFingerprint := tc.setupFunc()
|
||||
|
||||
// Write config file
|
||||
configPath := filepath.Join(testHub.DataDir(), "config.yml")
|
||||
err := os.WriteFile(configPath, []byte(tc.configYAML), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create serve event and sync
|
||||
event := &core.ServeEvent{App: testHub.App}
|
||||
err = config.SyncSystems(event)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Parse the config to get the system name for verification
|
||||
var configData testConfig
|
||||
err = yaml.Unmarshal([]byte(tc.configYAML), &configData)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, configData.Systems, 1)
|
||||
systemName := configData.Systems[0].Name
|
||||
|
||||
// Find the system after sync
|
||||
systems, err := testHub.FindRecordsByFilter("systems", "name = {:name}", "", -1, 0, map[string]any{"name": systemName})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, systems, 1)
|
||||
system := systems[0]
|
||||
|
||||
// Find the fingerprint record
|
||||
fingerprints, err := testHub.FindRecordsByFilter("fingerprints", "system = {:system}", "", -1, 0, map[string]any{"system": system.Id})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, fingerprints, 1)
|
||||
fingerprint := fingerprints[0]
|
||||
|
||||
// Verify token
|
||||
actualToken := fingerprint.GetString("token")
|
||||
if tc.expectToken == "" {
|
||||
// For generated tokens, just verify it's not empty and is a valid UUID format
|
||||
assert.NotEmpty(t, actualToken, tc.description)
|
||||
assert.Len(t, actualToken, 36, "Generated token should be UUID format") // UUID length
|
||||
} else {
|
||||
assert.Equal(t, tc.expectToken, actualToken, tc.description)
|
||||
}
|
||||
|
||||
// For existing systems, verify fingerprint is preserved
|
||||
if existingFingerprint != nil {
|
||||
actualFingerprint := fingerprint.GetString("fingerprint")
|
||||
expectedFingerprint := existingFingerprint.GetString("fingerprint")
|
||||
assert.Equal(t, expectedFingerprint, actualFingerprint, "Fingerprint should be preserved")
|
||||
}
|
||||
|
||||
// Cleanup for next test
|
||||
if existingSystem != nil {
|
||||
testHub.Delete(existingSystem)
|
||||
}
|
||||
if existingFingerprint != nil {
|
||||
testHub.Delete(existingFingerprint)
|
||||
}
|
||||
// Clean up the new records
|
||||
testHub.Delete(system)
|
||||
testHub.Delete(fingerprint)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestConfigMigrationScenario tests the specific migration scenario mentioned in the discussion
|
||||
func TestConfigMigrationScenario(t *testing.T) {
|
||||
testHub, err := tests.NewTestHub(t.TempDir())
|
||||
require.NoError(t, err)
|
||||
defer testHub.Cleanup()
|
||||
|
||||
// Create test user
|
||||
user, err := tests.CreateUser(testHub.App, "admin@example.com", "testtesttest")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Simulate migration scenario: system exists with token from migration
|
||||
existingSystem, err := tests.CreateRecord(testHub.App, "systems", map[string]any{
|
||||
"name": "migrated-server",
|
||||
"host": "migrated.example.com",
|
||||
"port": 45876,
|
||||
"users": []string{user.Id},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
migrationToken := "migration-generated-token-123"
|
||||
existingFingerprint, err := createConfigTestFingerprint(testHub.App, existingSystem.Id, migrationToken, "existing-fingerprint-from-agent")
|
||||
require.NoError(t, err)
|
||||
|
||||
// User exports config BEFORE this update (so no token field in YAML)
|
||||
oldConfigYAML := `systems:
|
||||
- name: "migrated-server"
|
||||
host: "migrated.example.com"
|
||||
port: 45876
|
||||
users:
|
||||
- "admin@example.com"`
|
||||
|
||||
// Write old config file and import
|
||||
configPath := filepath.Join(testHub.DataDir(), "config.yml")
|
||||
err = os.WriteFile(configPath, []byte(oldConfigYAML), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
event := &core.ServeEvent{App: testHub.App}
|
||||
err = config.SyncSystems(event)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the original token is preserved
|
||||
updatedFingerprint, err := testHub.FindRecordById("fingerprints", existingFingerprint.Id)
|
||||
require.NoError(t, err)
|
||||
|
||||
actualToken := updatedFingerprint.GetString("token")
|
||||
assert.Equal(t, migrationToken, actualToken, "Migration token should be preserved when config doesn't specify a token")
|
||||
|
||||
// Verify fingerprint is also preserved
|
||||
actualFingerprint := updatedFingerprint.GetString("fingerprint")
|
||||
assert.Equal(t, "existing-fingerprint-from-agent", actualFingerprint, "Existing fingerprint should be preserved")
|
||||
|
||||
// Verify system still exists and is updated correctly
|
||||
updatedSystem, err := testHub.FindRecordById("systems", existingSystem.Id)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "migrated-server", updatedSystem.GetString("name"))
|
||||
assert.Equal(t, "migrated.example.com", updatedSystem.GetString("host"))
|
||||
}
|
||||
@@ -0,0 +1,428 @@
|
||||
package domains
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/henrygd/beszel/internal/entities/domain"
|
||||
"github.com/henrygd/beszel/internal/hub/domains/whois"
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/apis"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
)
|
||||
|
||||
// APIHandler handles domain API requests
|
||||
type APIHandler struct {
|
||||
app core.App
|
||||
scheduler *Scheduler
|
||||
}
|
||||
|
||||
// NewAPIHandler creates a new domain API handler
|
||||
func NewAPIHandler(app core.App, scheduler *Scheduler) *APIHandler {
|
||||
return &APIHandler{
|
||||
app: app,
|
||||
scheduler: scheduler,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes registers domain API routes
|
||||
func (h *APIHandler) RegisterRoutes(se *core.ServeEvent) {
|
||||
api := se.Router.Group("/api/beszel/domains")
|
||||
api.Bind(apis.RequireAuth())
|
||||
|
||||
api.GET("/", h.listDomains)
|
||||
api.POST("/", h.createDomain)
|
||||
api.POST("/lookup", h.lookupDomain)
|
||||
api.GET("/{id}", h.getDomain)
|
||||
api.PATCH("/{id}", h.updateDomain)
|
||||
api.DELETE("/{id}", h.deleteDomain)
|
||||
api.POST("/{id}/refresh", h.refreshDomain)
|
||||
api.GET("/{id}/history", h.getDomainHistory)
|
||||
}
|
||||
|
||||
// listDomains lists all domains for the authenticated user
|
||||
func (h *APIHandler) listDomains(e *core.RequestEvent) error {
|
||||
authRecord := e.Auth
|
||||
if authRecord == nil {
|
||||
return e.UnauthorizedError("unauthorized", nil)
|
||||
}
|
||||
|
||||
records, err := h.app.FindAllRecords("domains",
|
||||
dbx.NewExp("user = {:user}", dbx.Params{"user": authRecord.Id}),
|
||||
)
|
||||
if err != nil {
|
||||
return e.InternalServerError("failed to fetch domains", err)
|
||||
}
|
||||
|
||||
domains := make([]map[string]interface{}, 0, len(records))
|
||||
for _, record := range records {
|
||||
domains = append(domains, h.recordToResponse(record))
|
||||
}
|
||||
|
||||
return e.JSON(http.StatusOK, domains)
|
||||
}
|
||||
|
||||
// lookupDomain performs a WHOIS lookup without saving
|
||||
func (h *APIHandler) lookupDomain(e *core.RequestEvent) error {
|
||||
authRecord := e.Auth
|
||||
if authRecord == nil {
|
||||
return e.UnauthorizedError("unauthorized", nil)
|
||||
}
|
||||
|
||||
var req struct {
|
||||
DomainName string `json:"domain_name"`
|
||||
}
|
||||
if err := json.NewDecoder(e.Request.Body).Decode(&req); err != nil {
|
||||
return e.BadRequestError("invalid request body", err)
|
||||
}
|
||||
|
||||
if req.DomainName == "" {
|
||||
return e.BadRequestError("domain_name is required", nil)
|
||||
}
|
||||
|
||||
// Clean domain
|
||||
domainName := cleanDomain(req.DomainName)
|
||||
|
||||
// Perform lookup
|
||||
lookupSvc := whois.NewLookupService("")
|
||||
ctx := e.Request.Context()
|
||||
domainData, err := lookupSvc.LookupDomain(ctx, domainName)
|
||||
if err != nil {
|
||||
return e.InternalServerError("lookup failed", err)
|
||||
}
|
||||
|
||||
return e.JSON(http.StatusOK, domainData)
|
||||
}
|
||||
|
||||
// createDomain creates a new domain entry
|
||||
func (h *APIHandler) createDomain(e *core.RequestEvent) error {
|
||||
authRecord := e.Auth
|
||||
if authRecord == nil {
|
||||
return e.UnauthorizedError("unauthorized", nil)
|
||||
}
|
||||
|
||||
var req struct {
|
||||
DomainName string `json:"domain_name"`
|
||||
AutoLookup bool `json:"auto_lookup"`
|
||||
Tags []string `json:"tags"`
|
||||
Notes string `json:"notes"`
|
||||
PurchasePrice float64 `json:"purchase_price"`
|
||||
CurrentValue float64 `json:"current_value"`
|
||||
RenewalCost float64 `json:"renewal_cost"`
|
||||
AutoRenew bool `json:"auto_renew"`
|
||||
AlertDaysBefore int `json:"alert_days_before"`
|
||||
SSLAlertEnabled bool `json:"ssl_alert_enabled"`
|
||||
}
|
||||
if err := json.NewDecoder(e.Request.Body).Decode(&req); err != nil {
|
||||
return e.BadRequestError("invalid request body", err)
|
||||
}
|
||||
|
||||
if req.DomainName == "" {
|
||||
return e.BadRequestError("domain_name is required", nil)
|
||||
}
|
||||
|
||||
// Clean domain
|
||||
domainName := cleanDomain(req.DomainName)
|
||||
|
||||
// Check if domain already exists for this user
|
||||
existing, _ := h.app.FindFirstRecordByFilter("domains",
|
||||
"domain_name = {:domain} && user = {:user}",
|
||||
dbx.Params{"domain": domainName, "user": authRecord.Id})
|
||||
if existing != nil {
|
||||
return e.BadRequestError("domain already exists", nil)
|
||||
}
|
||||
|
||||
collection, err := h.app.FindCollectionByNameOrId("domains")
|
||||
if err != nil {
|
||||
return e.InternalServerError("failed to find collection", err)
|
||||
}
|
||||
|
||||
// Set defaults
|
||||
if req.AlertDaysBefore <= 0 {
|
||||
req.AlertDaysBefore = 30
|
||||
}
|
||||
|
||||
record := core.NewRecord(collection)
|
||||
record.Set("domain_name", domainName)
|
||||
record.Set("status", domain.DomainStatusUnknown)
|
||||
record.Set("active", true)
|
||||
record.Set("tags", req.Tags)
|
||||
record.Set("notes", req.Notes)
|
||||
record.Set("purchase_price", req.PurchasePrice)
|
||||
record.Set("current_value", req.CurrentValue)
|
||||
record.Set("renewal_cost", req.RenewalCost)
|
||||
record.Set("auto_renew", req.AutoRenew)
|
||||
record.Set("alert_days_before", req.AlertDaysBefore)
|
||||
record.Set("ssl_alert_enabled", req.SSLAlertEnabled)
|
||||
record.Set("user", authRecord.Id)
|
||||
|
||||
// Auto-lookup if requested
|
||||
if req.AutoLookup {
|
||||
lookupSvc := whois.NewLookupService("")
|
||||
ctx := e.Request.Context()
|
||||
domainData, err := lookupSvc.LookupDomain(ctx, domainName)
|
||||
if err == nil && domainData != nil {
|
||||
record.Set("expiry_date", domainData.ExpiryDate)
|
||||
record.Set("creation_date", domainData.CreationDate)
|
||||
record.Set("updated_date", domainData.UpdatedDate)
|
||||
record.Set("registrar_name", domainData.RegistrarName)
|
||||
record.Set("registrar_id", domainData.RegistrarID)
|
||||
record.Set("registrar_url", domainData.RegistrarURL)
|
||||
record.Set("dnssec", domainData.DNSSEC)
|
||||
record.Set("name_servers", domainData.NameServers)
|
||||
record.Set("mx_records", domainData.MXRecords)
|
||||
record.Set("txt_records", domainData.TXTRecords)
|
||||
record.Set("ipv4_addresses", domainData.IPv4Addresses)
|
||||
record.Set("ipv6_addresses", domainData.IPv6Addresses)
|
||||
record.Set("ssl_issuer", domainData.SSLIssuer)
|
||||
record.Set("ssl_valid_to", domainData.SSLValidTo)
|
||||
record.Set("host_country", domainData.HostCountry)
|
||||
record.Set("host_isp", domainData.HostISP)
|
||||
record.Set("favicon_url", domainData.FaviconURL)
|
||||
record.Set("last_checked", time.Now())
|
||||
}
|
||||
}
|
||||
|
||||
if err := h.app.Save(record); err != nil {
|
||||
return e.InternalServerError("failed to create domain", err)
|
||||
}
|
||||
|
||||
return e.JSON(http.StatusCreated, h.recordToResponse(record))
|
||||
}
|
||||
|
||||
// getDomain gets a single domain
|
||||
func (h *APIHandler) getDomain(e *core.RequestEvent) error {
|
||||
authRecord := e.Auth
|
||||
if authRecord == nil {
|
||||
return e.UnauthorizedError("unauthorized", nil)
|
||||
}
|
||||
|
||||
id := e.Request.PathValue("id")
|
||||
record, err := h.app.FindRecordById("domains", id)
|
||||
if err != nil {
|
||||
return e.NotFoundError("domain not found", err)
|
||||
}
|
||||
|
||||
if record.GetString("user") != authRecord.Id {
|
||||
return e.ForbiddenError("not authorized", nil)
|
||||
}
|
||||
|
||||
return e.JSON(http.StatusOK, h.recordToResponse(record))
|
||||
}
|
||||
|
||||
// updateDomain updates a domain
|
||||
func (h *APIHandler) updateDomain(e *core.RequestEvent) error {
|
||||
authRecord := e.Auth
|
||||
if authRecord == nil {
|
||||
return e.UnauthorizedError("unauthorized", nil)
|
||||
}
|
||||
|
||||
id := e.Request.PathValue("id")
|
||||
record, err := h.app.FindRecordById("domains", id)
|
||||
if err != nil {
|
||||
return e.NotFoundError("domain not found", err)
|
||||
}
|
||||
|
||||
if record.GetString("user") != authRecord.Id {
|
||||
return e.ForbiddenError("not authorized", nil)
|
||||
}
|
||||
|
||||
var req map[string]interface{}
|
||||
if err := json.NewDecoder(e.Request.Body).Decode(&req); err != nil {
|
||||
return e.BadRequestError("invalid request body", err)
|
||||
}
|
||||
|
||||
// Update allowed fields
|
||||
if tags, ok := req["tags"]; ok {
|
||||
record.Set("tags", tags)
|
||||
}
|
||||
if notes, ok := req["notes"]; ok {
|
||||
record.Set("notes", notes)
|
||||
}
|
||||
if price, ok := req["purchase_price"]; ok {
|
||||
record.Set("purchase_price", price)
|
||||
}
|
||||
if value, ok := req["current_value"]; ok {
|
||||
record.Set("current_value", value)
|
||||
}
|
||||
if renewal, ok := req["renewal_cost"]; ok {
|
||||
record.Set("renewal_cost", renewal)
|
||||
}
|
||||
if autoRenew, ok := req["auto_renew"]; ok {
|
||||
record.Set("auto_renew", autoRenew)
|
||||
}
|
||||
if active, ok := req["active"]; ok {
|
||||
record.Set("active", active)
|
||||
}
|
||||
if alertDays, ok := req["alert_days_before"]; ok {
|
||||
record.Set("alert_days_before", alertDays)
|
||||
}
|
||||
if sslAlert, ok := req["ssl_alert_enabled"]; ok {
|
||||
record.Set("ssl_alert_enabled", sslAlert)
|
||||
}
|
||||
|
||||
if err := h.app.Save(record); err != nil {
|
||||
return e.InternalServerError("failed to update domain", err)
|
||||
}
|
||||
|
||||
return e.JSON(http.StatusOK, h.recordToResponse(record))
|
||||
}
|
||||
|
||||
// deleteDomain deletes a domain
|
||||
func (h *APIHandler) deleteDomain(e *core.RequestEvent) error {
|
||||
authRecord := e.Auth
|
||||
if authRecord == nil {
|
||||
return e.UnauthorizedError("unauthorized", nil)
|
||||
}
|
||||
|
||||
id := e.Request.PathValue("id")
|
||||
record, err := h.app.FindRecordById("domains", id)
|
||||
if err != nil {
|
||||
return e.NotFoundError("domain not found", err)
|
||||
}
|
||||
|
||||
if record.GetString("user") != authRecord.Id {
|
||||
return e.ForbiddenError("not authorized", nil)
|
||||
}
|
||||
|
||||
if err := h.app.Delete(record); err != nil {
|
||||
return e.InternalServerError("failed to delete domain", err)
|
||||
}
|
||||
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// refreshDomain triggers a manual refresh
|
||||
func (h *APIHandler) refreshDomain(e *core.RequestEvent) error {
|
||||
authRecord := e.Auth
|
||||
if authRecord == nil {
|
||||
return e.UnauthorizedError("unauthorized", nil)
|
||||
}
|
||||
|
||||
id := e.Request.PathValue("id")
|
||||
record, err := h.app.FindRecordById("domains", id)
|
||||
if err != nil {
|
||||
return e.NotFoundError("domain not found", err)
|
||||
}
|
||||
|
||||
if record.GetString("user") != authRecord.Id {
|
||||
return e.ForbiddenError("not authorized", nil)
|
||||
}
|
||||
|
||||
// Trigger refresh via scheduler
|
||||
if h.scheduler != nil {
|
||||
h.scheduler.RefreshDomain(id)
|
||||
}
|
||||
|
||||
return e.JSON(http.StatusOK, map[string]string{"status": "refreshing"})
|
||||
}
|
||||
|
||||
// getDomainHistory gets the change history for a domain
|
||||
func (h *APIHandler) getDomainHistory(e *core.RequestEvent) error {
|
||||
authRecord := e.Auth
|
||||
if authRecord == nil {
|
||||
return e.UnauthorizedError("unauthorized", nil)
|
||||
}
|
||||
|
||||
id := e.Request.PathValue("id")
|
||||
|
||||
// Verify domain ownership
|
||||
domain, err := h.app.FindRecordById("domains", id)
|
||||
if err != nil {
|
||||
return e.NotFoundError("domain not found", err)
|
||||
}
|
||||
if domain.GetString("user") != authRecord.Id {
|
||||
return e.ForbiddenError("not authorized", nil)
|
||||
}
|
||||
|
||||
// Fetch history
|
||||
records, err := h.app.FindAllRecords("domain_history",
|
||||
dbx.NewExp("domain = {:domain}", dbx.Params{"domain": id}),
|
||||
)
|
||||
if err != nil {
|
||||
return e.InternalServerError("failed to fetch history", err)
|
||||
}
|
||||
|
||||
history := make([]map[string]interface{}, 0, len(records))
|
||||
for _, record := range records {
|
||||
history = append(history, map[string]interface{}{
|
||||
"id": record.Id,
|
||||
"change_type": record.GetString("change_type"),
|
||||
"field_name": record.GetString("field_name"),
|
||||
"old_value": record.GetString("old_value"),
|
||||
"new_value": record.GetString("new_value"),
|
||||
"created_at": record.GetDateTime("created_at").String(),
|
||||
})
|
||||
}
|
||||
|
||||
return e.JSON(http.StatusOK, history)
|
||||
}
|
||||
|
||||
// recordToResponse converts a record to API response
|
||||
func (h *APIHandler) recordToResponse(record *core.Record) map[string]interface{} {
|
||||
expiryDate := record.GetDateTime("expiry_date").Time()
|
||||
sslValidTo := record.GetDateTime("ssl_valid_to").Time()
|
||||
|
||||
// Calculate days until expiry
|
||||
daysUntilExpiry := -1
|
||||
if !expiryDate.IsZero() {
|
||||
daysUntilExpiry = int(time.Until(expiryDate).Hours() / 24)
|
||||
}
|
||||
|
||||
sslDaysUntil := -1
|
||||
if !sslValidTo.IsZero() {
|
||||
sslDaysUntil = int(time.Until(sslValidTo).Hours() / 24)
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"id": record.Id,
|
||||
"domain_name": record.GetString("domain_name"),
|
||||
"status": record.GetString("status"),
|
||||
"active": record.GetBool("active"),
|
||||
"expiry_date": expiryDate,
|
||||
"creation_date": record.GetDateTime("creation_date").String(),
|
||||
"updated_date": record.GetDateTime("updated_date").String(),
|
||||
"days_until_expiry": daysUntilExpiry,
|
||||
"registrar_name": record.GetString("registrar_name"),
|
||||
"registrar_id": record.GetString("registrar_id"),
|
||||
"name_servers": record.Get("name_servers"),
|
||||
"ipv4_addresses": record.Get("ipv4_addresses"),
|
||||
"ssl_issuer": record.GetString("ssl_issuer"),
|
||||
"ssl_valid_to": sslValidTo,
|
||||
"ssl_days_until": sslDaysUntil,
|
||||
"host_country": record.GetString("host_country"),
|
||||
"host_isp": record.GetString("host_isp"),
|
||||
"purchase_price": record.GetFloat("purchase_price"),
|
||||
"current_value": record.GetFloat("current_value"),
|
||||
"renewal_cost": record.GetFloat("renewal_cost"),
|
||||
"auto_renew": record.GetBool("auto_renew"),
|
||||
"alert_days_before": record.GetInt("alert_days_before"),
|
||||
"ssl_alert_enabled": record.GetBool("ssl_alert_enabled"),
|
||||
"tags": record.Get("tags"),
|
||||
"notes": record.GetString("notes"),
|
||||
"favicon_url": record.GetString("favicon_url"),
|
||||
"last_checked": record.GetDateTime("last_checked").String(),
|
||||
"created": record.GetDateTime("created").String(),
|
||||
"updated": record.GetDateTime("updated").String(),
|
||||
}
|
||||
}
|
||||
|
||||
// cleanDomain cleans and normalizes a domain name
|
||||
func cleanDomain(domain string) string {
|
||||
// Remove protocol
|
||||
domain = strings.TrimPrefix(domain, "https://")
|
||||
domain = strings.TrimPrefix(domain, "http://")
|
||||
// Remove www prefix
|
||||
domain = strings.TrimPrefix(domain, "www.")
|
||||
// Remove path and query
|
||||
if idx := strings.IndexAny(domain, "/?#"); idx != -1 {
|
||||
domain = domain[:idx]
|
||||
}
|
||||
// Remove port
|
||||
if idx := strings.Index(domain, ":"); idx != -1 {
|
||||
domain = domain[:idx]
|
||||
}
|
||||
return strings.ToLower(strings.TrimSpace(domain))
|
||||
}
|
||||
@@ -0,0 +1,301 @@
|
||||
package domains
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/henrygd/beszel/internal/entities/domain"
|
||||
"github.com/henrygd/beszel/internal/hub/domains/whois"
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
)
|
||||
|
||||
// Scheduler manages periodic domain checks for expiry and SSL
|
||||
type Scheduler struct {
|
||||
app core.App
|
||||
whois *whois.LookupService
|
||||
ticker *time.Ticker
|
||||
stopChan chan struct{}
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// NewScheduler creates a new domain scheduler
|
||||
func NewScheduler(app core.App) *Scheduler {
|
||||
return &Scheduler{
|
||||
app: app,
|
||||
whois: whois.NewLookupService(""), // API key can be configured via env
|
||||
stopChan: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Start begins the domain check scheduler
|
||||
func (s *Scheduler) Start() {
|
||||
log.Println("[domain-scheduler] Starting domain scheduler")
|
||||
|
||||
// Check domains daily
|
||||
s.ticker = time.NewTicker(24 * time.Hour)
|
||||
|
||||
// Run initial check immediately
|
||||
go s.checkDomains()
|
||||
|
||||
// Schedule periodic checks
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-s.ticker.C:
|
||||
s.checkDomains()
|
||||
case <-s.stopChan:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Stop halts the domain scheduler
|
||||
func (s *Scheduler) Stop() {
|
||||
log.Println("[domain-scheduler] Stopping domain scheduler")
|
||||
if s.ticker != nil {
|
||||
s.ticker.Stop()
|
||||
}
|
||||
close(s.stopChan)
|
||||
s.wg.Wait()
|
||||
}
|
||||
|
||||
// checkDomains checks all active domains for expiry and updates info
|
||||
func (s *Scheduler) checkDomains() {
|
||||
log.Println("[domain-scheduler] Checking domains")
|
||||
|
||||
// Find all active domains
|
||||
records, err := s.app.FindAllRecords("domains",
|
||||
dbx.NewExp("active = true"),
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("[domain-scheduler] Failed to fetch domains: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, record := range records {
|
||||
s.wg.Add(1)
|
||||
go func(r *core.Record) {
|
||||
defer s.wg.Done()
|
||||
s.checkDomain(r)
|
||||
}(record)
|
||||
}
|
||||
}
|
||||
|
||||
// checkDomain checks a single domain
|
||||
func (s *Scheduler) checkDomain(record *core.Record) {
|
||||
domainName := record.GetString("domain_name")
|
||||
userID := record.GetString("user")
|
||||
|
||||
log.Printf("[domain-scheduler] Checking domain: %s", domainName)
|
||||
|
||||
// Perform WHOIS and DNS lookup
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
newData, err := s.whois.LookupDomain(ctx, domainName)
|
||||
if err != nil {
|
||||
log.Printf("[domain-scheduler] Failed to lookup %s: %v", domainName, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Track changes
|
||||
history := s.trackChanges(record, newData)
|
||||
|
||||
// Update record
|
||||
record.Set("expiry_date", newData.ExpiryDate)
|
||||
record.Set("creation_date", newData.CreationDate)
|
||||
record.Set("updated_date", newData.UpdatedDate)
|
||||
record.Set("registrar_name", newData.RegistrarName)
|
||||
record.Set("registrar_id", newData.RegistrarID)
|
||||
record.Set("registrar_url", newData.RegistrarURL)
|
||||
record.Set("dnssec", newData.DNSSEC)
|
||||
record.Set("name_servers", newData.NameServers)
|
||||
record.Set("mx_records", newData.MXRecords)
|
||||
record.Set("txt_records", newData.TXTRecords)
|
||||
record.Set("ipv4_addresses", newData.IPv4Addresses)
|
||||
record.Set("ipv6_addresses", newData.IPv6Addresses)
|
||||
record.Set("ssl_issuer", newData.SSLIssuer)
|
||||
record.Set("ssl_valid_to", newData.SSLValidTo)
|
||||
record.Set("host_country", newData.HostCountry)
|
||||
record.Set("host_isp", newData.HostISP)
|
||||
record.Set("last_checked", time.Now())
|
||||
|
||||
// Update status
|
||||
status := domain.DomainStatusActive
|
||||
if newData.ExpiryDate != nil {
|
||||
if newData.IsExpired() {
|
||||
status = domain.DomainStatusExpired
|
||||
} else if newData.IsExpiring() {
|
||||
status = domain.DomainStatusExpiring
|
||||
}
|
||||
} else {
|
||||
status = domain.DomainStatusUnknown
|
||||
}
|
||||
record.Set("status", status)
|
||||
|
||||
if err := s.app.Save(record); err != nil {
|
||||
log.Printf("[domain-scheduler] Failed to update %s: %v", domainName, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Save history entries
|
||||
for _, h := range history {
|
||||
s.saveHistory(h, record.Id, userID)
|
||||
}
|
||||
|
||||
// Trigger notifications for expiring domains
|
||||
if status == domain.DomainStatusExpiring || status == domain.DomainStatusExpired {
|
||||
s.triggerNotification(record, status)
|
||||
}
|
||||
|
||||
// Check SSL expiry
|
||||
if newData.SSLAlertEnabled && newData.SSLValidTo != nil {
|
||||
sslDays := newData.SSLDaysUntilExpiry()
|
||||
if sslDays <= newData.AlertDaysBefore {
|
||||
s.triggerSSLNotification(record, sslDays)
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("[domain-scheduler] Updated domain: %s (status: %s)", domainName, status)
|
||||
}
|
||||
|
||||
// trackChanges compares old and new data and returns history entries
|
||||
func (s *Scheduler) trackChanges(oldRecord *core.Record, newData *domain.Domain) []domain.DomainHistory {
|
||||
var history []domain.DomainHistory
|
||||
now := time.Now()
|
||||
|
||||
// Check expiry date change
|
||||
oldExpiry := oldRecord.GetDateTime("expiry_date").Time()
|
||||
if newData.ExpiryDate != nil && !oldExpiry.IsZero() && !newData.ExpiryDate.Equal(oldExpiry) {
|
||||
history = append(history, domain.DomainHistory{
|
||||
ChangeType: domain.ChangeTypeExpiry,
|
||||
FieldName: "expiry_date",
|
||||
OldValue: oldExpiry.Format("2006-01-02"),
|
||||
NewValue: newData.ExpiryDate.Format("2006-01-02"),
|
||||
CreatedAt: now,
|
||||
})
|
||||
}
|
||||
|
||||
// Check registrar change
|
||||
oldRegistrar := oldRecord.GetString("registrar_name")
|
||||
if newData.RegistrarName != "" && newData.RegistrarName != oldRegistrar {
|
||||
history = append(history, domain.DomainHistory{
|
||||
ChangeType: domain.ChangeTypeRegistrar,
|
||||
FieldName: "registrar_name",
|
||||
OldValue: oldRegistrar,
|
||||
NewValue: newData.RegistrarName,
|
||||
CreatedAt: now,
|
||||
})
|
||||
}
|
||||
|
||||
// Check status change
|
||||
oldStatus := oldRecord.GetString("status")
|
||||
newStatus := newData.GetStatus()
|
||||
if newStatus != oldStatus {
|
||||
history = append(history, domain.DomainHistory{
|
||||
ChangeType: domain.ChangeTypeStatus,
|
||||
FieldName: "status",
|
||||
OldValue: oldStatus,
|
||||
NewValue: newStatus,
|
||||
CreatedAt: now,
|
||||
})
|
||||
}
|
||||
|
||||
// Check SSL expiry change
|
||||
oldSSLExpiry := oldRecord.GetDateTime("ssl_valid_to").Time()
|
||||
if newData.SSLValidTo != nil && !oldSSLExpiry.IsZero() && !newData.SSLValidTo.Equal(oldSSLExpiry) {
|
||||
history = append(history, domain.DomainHistory{
|
||||
ChangeType: domain.ChangeTypeSSL,
|
||||
FieldName: "ssl_valid_to",
|
||||
OldValue: oldSSLExpiry.Format("2006-01-02"),
|
||||
NewValue: newData.SSLValidTo.Format("2006-01-02"),
|
||||
CreatedAt: now,
|
||||
})
|
||||
}
|
||||
|
||||
return history
|
||||
}
|
||||
|
||||
// saveHistory saves a history entry to the database
|
||||
func (s *Scheduler) saveHistory(h domain.DomainHistory, domainID, userID string) {
|
||||
collection, err := s.app.FindCollectionByNameOrId("domain_history")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
record := core.NewRecord(collection)
|
||||
record.Set("domain", domainID)
|
||||
record.Set("change_type", h.ChangeType)
|
||||
record.Set("field_name", h.FieldName)
|
||||
record.Set("old_value", h.OldValue)
|
||||
record.Set("new_value", h.NewValue)
|
||||
record.Set("user", userID)
|
||||
record.Set("created_at", h.CreatedAt)
|
||||
|
||||
if err := s.app.Save(record); err != nil {
|
||||
log.Printf("[domain-scheduler] Failed to save history: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// triggerNotification sends notification for domain events
|
||||
func (s *Scheduler) triggerNotification(record *core.Record, status string) {
|
||||
domainName := record.GetString("domain_name")
|
||||
daysUntil := 0
|
||||
|
||||
if expiry := record.GetDateTime("expiry_date"); !expiry.IsZero() {
|
||||
daysUntil = int(time.Until(expiry.Time()).Hours() / 24)
|
||||
}
|
||||
|
||||
var title, body string
|
||||
switch status {
|
||||
case domain.DomainStatusExpired:
|
||||
title = fmt.Sprintf("Domain Expired: %s", domainName)
|
||||
body = fmt.Sprintf("The domain %s has expired.", domainName)
|
||||
case domain.DomainStatusExpiring:
|
||||
title = fmt.Sprintf("Domain Expiring Soon: %s", domainName)
|
||||
body = fmt.Sprintf("The domain %s expires in %d days.", domainName, daysUntil)
|
||||
}
|
||||
|
||||
log.Printf("[domain-scheduler] %s: %s", title, body)
|
||||
|
||||
// TODO: Integrate with notification system
|
||||
// This would call the notification dispatcher similar to monitor alerts
|
||||
}
|
||||
|
||||
// triggerSSLNotification sends notification for SSL expiry
|
||||
func (s *Scheduler) triggerSSLNotification(record *core.Record, daysUntil int) {
|
||||
domainName := record.GetString("domain_name")
|
||||
|
||||
title := fmt.Sprintf("SSL Certificate Expiring: %s", domainName)
|
||||
body := fmt.Sprintf("The SSL certificate for %s expires in %d days.", domainName, daysUntil)
|
||||
|
||||
log.Printf("[domain-scheduler] %s: %s", title, body)
|
||||
|
||||
// TODO: Integrate with notification system
|
||||
}
|
||||
|
||||
// RefreshDomain manually refreshes a single domain
|
||||
func (s *Scheduler) RefreshDomain(domainID string) error {
|
||||
record, err := s.app.FindRecordById("domains", domainID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
s.checkDomain(record)
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckAllDomains manually triggers a check of all active domains
|
||||
func (s *Scheduler) CheckAllDomains() {
|
||||
s.checkDomains()
|
||||
}
|
||||
@@ -0,0 +1,736 @@
|
||||
package whois
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"os/exec"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/henrygd/beszel/internal/entities/domain"
|
||||
)
|
||||
|
||||
// LookupService handles WHOIS lookups with multiple fallback methods
|
||||
type LookupService struct {
|
||||
whoisXMLAPIKey string
|
||||
rdapCache map[string]string
|
||||
}
|
||||
|
||||
// NewLookupService creates a new WHOIS lookup service
|
||||
func NewLookupService(apiKey string) *LookupService {
|
||||
return &LookupService{
|
||||
whoisXMLAPIKey: apiKey,
|
||||
rdapCache: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
// LookupDomain performs a comprehensive domain lookup (WHOIS, DNS, SSL, Host)
|
||||
func (s *LookupService) LookupDomain(ctx context.Context, domainName string) (*domain.Domain, error) {
|
||||
// Clean domain name
|
||||
domainName = cleanDomain(domainName)
|
||||
|
||||
// Initialize domain struct
|
||||
d := &domain.Domain{
|
||||
DomainName: domainName,
|
||||
Active: true,
|
||||
AlertDaysBefore: 30, // Default: alert 30 days before expiry
|
||||
Tags: []string{},
|
||||
NameServers: []string{},
|
||||
MXRecords: []string{},
|
||||
TXTRecords: []string{},
|
||||
IPv4Addresses: []string{},
|
||||
IPv6Addresses: []string{},
|
||||
}
|
||||
|
||||
// Perform WHOIS lookup
|
||||
whoisData, err := s.LookupWHOIS(ctx, domainName)
|
||||
if err == nil && whoisData != nil {
|
||||
s.applyWHOISData(d, whoisData)
|
||||
}
|
||||
|
||||
// Perform DNS lookups
|
||||
s.lookupDNS(ctx, domainName, d)
|
||||
|
||||
// Perform SSL lookup
|
||||
s.lookupSSL(ctx, domainName, d)
|
||||
|
||||
// Perform host lookup (using first IPv4)
|
||||
if len(d.IPv4Addresses) > 0 {
|
||||
s.lookupHost(d.IPv4Addresses[0], d)
|
||||
}
|
||||
|
||||
// Fetch favicon
|
||||
d.FaviconURL = fmt.Sprintf("https://www.google.com/s2/favicons?domain=%s&sz=128", domainName)
|
||||
|
||||
d.LastChecked = time.Now()
|
||||
return d, nil
|
||||
}
|
||||
|
||||
// LookupWHOIS performs WHOIS lookup with multiple fallback methods
|
||||
func (s *LookupService) LookupWHOIS(ctx context.Context, domainName string) (*domain.WHOISData, error) {
|
||||
// Try RDAP first (modern replacement for WHOIS)
|
||||
data, err := s.tryRDAP(ctx, domainName)
|
||||
if err == nil && data != nil && hasValidData(data) {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// Try native whois command
|
||||
data, err = s.tryNativeWHOIS(ctx, domainName)
|
||||
if err == nil && data != nil && hasValidData(data) {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// Try WhoisXML API if key is configured
|
||||
if s.whoisXMLAPIKey != "" {
|
||||
data, err = s.tryWhoisXML(ctx, domainName)
|
||||
if err == nil && data != nil {
|
||||
return data, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("all WHOIS lookup methods failed for %s", domainName)
|
||||
}
|
||||
|
||||
// tryRDAP attempts RDAP lookup
|
||||
func (s *LookupService) tryRDAP(ctx context.Context, domainName string) (*domain.WHOISData, error) {
|
||||
// Get TLD
|
||||
parts := strings.Split(domainName, ".")
|
||||
if len(parts) < 2 {
|
||||
return nil, fmt.Errorf("invalid domain format")
|
||||
}
|
||||
tld := parts[len(parts)-1]
|
||||
|
||||
// Get RDAP base URL
|
||||
baseURL, err := s.getRDAPBaseURL(ctx, tld)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Make RDAP request
|
||||
url := fmt.Sprintf("%s/domain/%s", baseURL, domainName)
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Accept", "application/rdap+json")
|
||||
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
return nil, fmt.Errorf("RDAP returned status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var rdapResp struct {
|
||||
LdhName string `json:"ldhName"`
|
||||
Handle string `json:"handle"`
|
||||
Status []string `json:"status"`
|
||||
Events []struct {
|
||||
EventAction string `json:"eventAction"`
|
||||
EventDate string `json:"eventDate"`
|
||||
} `json:"events"`
|
||||
Entities []struct {
|
||||
Roles []string `json:"roles"`
|
||||
PublicIds []struct {
|
||||
Type string `json:"type"`
|
||||
Identifier string `json:"identifier"`
|
||||
} `json:"publicIds"`
|
||||
VCardArray []interface{} `json:"vcardArray"`
|
||||
} `json:"entities"`
|
||||
SecureDNS struct {
|
||||
ZoneSigned bool `json:"zoneSigned"`
|
||||
} `json:"secureDNS"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(&rdapResp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse events
|
||||
var creationDate, expiryDate, updatedDate *time.Time
|
||||
for _, event := range rdapResp.Events {
|
||||
t, _ := time.Parse(time.RFC3339, event.EventDate)
|
||||
switch event.EventAction {
|
||||
case "registration":
|
||||
creationDate = &t
|
||||
case "expiration":
|
||||
expiryDate = &t
|
||||
case "last changed":
|
||||
updatedDate = &t
|
||||
}
|
||||
}
|
||||
|
||||
// Find registrar
|
||||
var registrarName, registrarID string
|
||||
for _, entity := range rdapResp.Entities {
|
||||
for _, role := range entity.Roles {
|
||||
if role == "registrar" {
|
||||
// Try to get name from vCard
|
||||
if len(entity.VCardArray) > 1 {
|
||||
if vcard, ok := entity.VCardArray[1].([]interface{}); ok {
|
||||
for _, item := range vcard {
|
||||
if arr, ok := item.([]interface{}); ok && len(arr) >= 4 {
|
||||
if arr[0] == "fn" {
|
||||
if name, ok := arr[3].(string); ok {
|
||||
registrarName = name
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Get IANA ID
|
||||
for _, pid := range entity.PublicIds {
|
||||
if pid.Type == "IANA Registrar ID" {
|
||||
registrarID = pid.Identifier
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
dnssec := ""
|
||||
if rdapResp.SecureDNS.ZoneSigned {
|
||||
dnssec = "signed"
|
||||
}
|
||||
|
||||
return &domain.WHOISData{
|
||||
DomainName: rdapResp.LdhName,
|
||||
Status: rdapResp.Status,
|
||||
DNSSEC: dnssec,
|
||||
Dates: domain.WHOISDates{
|
||||
ExpiryDate: expiryDate,
|
||||
CreationDate: creationDate,
|
||||
UpdatedDate: updatedDate,
|
||||
},
|
||||
Registrar: domain.WHOISRegistrar{
|
||||
Name: registrarName,
|
||||
ID: registrarID,
|
||||
URL: "",
|
||||
RegistryDomainID: rdapResp.Handle,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// tryNativeWHOIS tries the native whois command
|
||||
func (s *LookupService) tryNativeWHOIS(ctx context.Context, domainName string) (*domain.WHOISData, error) {
|
||||
// Check if whois command exists
|
||||
_, err := exec.LookPath("whois")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("whois command not found")
|
||||
}
|
||||
|
||||
// Execute whois with timeout
|
||||
cmdCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cmd := exec.CommandContext(cmdCtx, "whois", domainName)
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s.parseWHOISOutput(string(output), domainName)
|
||||
}
|
||||
|
||||
// tryWhoisXML tries the WhoisXML API
|
||||
func (s *LookupService) tryWhoisXML(ctx context.Context, domainName string) (*domain.WHOISData, error) {
|
||||
if s.whoisXMLAPIKey == "" {
|
||||
return nil, fmt.Errorf("no API key configured")
|
||||
}
|
||||
|
||||
url := fmt.Sprintf(
|
||||
"https://www.whoisxmlapi.com/whoisserver/WhoisService?apiKey=%s&outputFormat=json&domainName=%s",
|
||||
s.whoisXMLAPIKey, domainName,
|
||||
)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: 15 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
return nil, fmt.Errorf("WhoisXML API returned %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var result struct {
|
||||
WhoisRecord struct {
|
||||
DomainName string `json:"domainName"`
|
||||
RegistrarName string `json:"registrarName"`
|
||||
RegistrarIANAID string `json:"registrarIANAID"`
|
||||
RegistryData struct {
|
||||
Status string `json:"status"`
|
||||
CreatedDateNormalized string `json:"createdDateNormalized"`
|
||||
ExpiresDateNormalized string `json:"expiresDateNormalized"`
|
||||
UpdatedDateNormalized string `json:"updatedDateNormalized"`
|
||||
WhoisServer string `json:"whoisServer"`
|
||||
} `json:"registryData"`
|
||||
} `json:"WhoisRecord"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
record := result.WhoisRecord
|
||||
registry := record.RegistryData
|
||||
|
||||
creationDate, _ := time.Parse("2006-01-02", registry.CreatedDateNormalized)
|
||||
expiryDate, _ := time.Parse("2006-01-02", registry.ExpiresDateNormalized)
|
||||
updatedDate, _ := time.Parse("2006-01-02", registry.UpdatedDateNormalized)
|
||||
|
||||
return &domain.WHOISData{
|
||||
DomainName: record.DomainName,
|
||||
Status: strings.Split(registry.Status, ", "),
|
||||
Dates: domain.WHOISDates{
|
||||
ExpiryDate: &expiryDate,
|
||||
CreationDate: &creationDate,
|
||||
UpdatedDate: &updatedDate,
|
||||
},
|
||||
Registrar: domain.WHOISRegistrar{
|
||||
Name: record.RegistrarName,
|
||||
ID: record.RegistrarIANAID,
|
||||
URL: fmt.Sprintf("https://%s", registry.WhoisServer),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// parseWHOISOutput parses the raw WHOIS text output
|
||||
func (s *LookupService) parseWHOISOutput(output, domainName string) (*domain.WHOISData, error) {
|
||||
lines := strings.Split(output, "\n")
|
||||
data := make(map[string]string)
|
||||
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" || strings.HasPrefix(line, "%") {
|
||||
continue
|
||||
}
|
||||
|
||||
// Parse "Key: Value" format
|
||||
if idx := strings.Index(line, ":"); idx > 0 {
|
||||
key := strings.ToLower(strings.TrimSpace(line[:idx]))
|
||||
value := strings.TrimSpace(line[idx+1:])
|
||||
// Normalize key
|
||||
key = strings.ReplaceAll(key, " ", "_")
|
||||
key = strings.ReplaceAll(key, "/", "_")
|
||||
if value != "" && !strings.HasPrefix(value, "REDACTED") {
|
||||
data[key] = value
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Extract dates
|
||||
expiryDate := s.parseDate(data["registry_expiry_date"], data["registrar_registration_expiration_date"],
|
||||
data["expiry_date"], data["expiration_time"], data["expire"], data["paid_until"])
|
||||
creationDate := s.parseDate(data["creation_date"], data["created_date"], data["registration_time"])
|
||||
updatedDate := s.parseDate(data["updated_date"], data["last_updated"])
|
||||
|
||||
// Extract registrar - try multiple field names used by different WHOIS servers
|
||||
registrarName := data["registrar"]
|
||||
if registrarName == "" {
|
||||
registrarName = data["registrar_name"]
|
||||
}
|
||||
if registrarName == "" {
|
||||
registrarName = data["sponsoring_registrar"]
|
||||
}
|
||||
if registrarName == "" {
|
||||
registrarName = data["registrar_organization"]
|
||||
}
|
||||
if registrarName == "" {
|
||||
registrarName = "Unknown"
|
||||
}
|
||||
|
||||
// Parse status
|
||||
statusStr := data["domain_status"]
|
||||
var statuses []string
|
||||
if statusStr != "" {
|
||||
statuses = s.parseStatus(statusStr)
|
||||
}
|
||||
|
||||
// Extract registrant contact info
|
||||
registrant := domain.WHOISContact{
|
||||
Name: data["registrant_name"],
|
||||
Organization: data["registrant_organization"],
|
||||
Street: data["registrant_street"],
|
||||
City: data["registrant_city"],
|
||||
State: data["registrant_state_province"],
|
||||
Country: data["registrant_country"],
|
||||
PostalCode: data["registrant_postal_code"],
|
||||
}
|
||||
|
||||
// Try alternate field names for registrant
|
||||
if registrant.Name == "" {
|
||||
registrant.Name = data["registrant"]
|
||||
}
|
||||
if registrant.Organization == "" {
|
||||
registrant.Organization = data["org"]
|
||||
}
|
||||
if registrant.Country == "" {
|
||||
registrant.Country = data["country"]
|
||||
}
|
||||
|
||||
// Parse DNSSEC more thoroughly
|
||||
dnssec := data["dnssec"]
|
||||
if dnssec == "" {
|
||||
// Try alternate field names
|
||||
dnssec = data["dnssec_signed"]
|
||||
}
|
||||
if dnssec == "" {
|
||||
dnssec = data["signed_dnssec"]
|
||||
}
|
||||
// Normalize DNSSEC value
|
||||
dnssec = strings.ToLower(strings.TrimSpace(dnssec))
|
||||
if dnssec == "signed" || dnssec == "yes" || dnssec == "true" {
|
||||
dnssec = "signed"
|
||||
} else if dnssec == "unsigned" || dnssec == "no" || dnssec == "false" {
|
||||
dnssec = "unsigned"
|
||||
}
|
||||
|
||||
return &domain.WHOISData{
|
||||
DomainName: domainName,
|
||||
Status: statuses,
|
||||
DNSSEC: dnssec,
|
||||
Dates: domain.WHOISDates{
|
||||
ExpiryDate: expiryDate,
|
||||
CreationDate: creationDate,
|
||||
UpdatedDate: updatedDate,
|
||||
},
|
||||
Registrar: domain.WHOISRegistrar{
|
||||
Name: registrarName,
|
||||
ID: data["registrar_iana_id"],
|
||||
URL: data["registrar_url"],
|
||||
RegistryDomainID: data["registry_domain_id"],
|
||||
},
|
||||
Registrant: registrant,
|
||||
Abuse: domain.WHOISAbuse{
|
||||
Email: data["registrar_abuse_contact_email"],
|
||||
Phone: data["registrar_abuse_contact_phone"],
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// parseDate attempts to parse a date from multiple possible formats
|
||||
func (s *LookupService) parseDate(dates ...string) *time.Time {
|
||||
formats := []string{
|
||||
// Standard ISO formats
|
||||
"2006-01-02",
|
||||
"2006-01-02T15:04:05Z",
|
||||
"2006-01-02T15:04:05-07:00",
|
||||
"2006-01-02 15:04:05",
|
||||
"2006-01-02 15:04:05.0",
|
||||
// US formats
|
||||
"01/02/2006",
|
||||
"01/02/2006 15:04:05",
|
||||
// European formats
|
||||
"02/01/2006",
|
||||
"02.01.2006",
|
||||
// Verbal formats
|
||||
"Jan 2 2006",
|
||||
"January 2 2006",
|
||||
"2 Jan 2006",
|
||||
"2 January 2006",
|
||||
"Jan 02 2006",
|
||||
"02-Jan-2006",
|
||||
"2-Jan-2006",
|
||||
// With timezone names
|
||||
"2006-01-02 15:04:05 MST",
|
||||
"2006-01-02 15:04:05 UTC",
|
||||
// Common registrar formats
|
||||
"Monday, January 2, 2006",
|
||||
"Mon, 02 Jan 2006 15:04:05 MST",
|
||||
"Mon, 2 Jan 2006 15:04:05 MST",
|
||||
// Additional formats
|
||||
"20060102",
|
||||
"20060102150405",
|
||||
}
|
||||
|
||||
for _, dateStr := range dates {
|
||||
if dateStr == "" || dateStr == "REDACTED" || strings.Contains(dateStr, "REDACTED") {
|
||||
continue
|
||||
}
|
||||
dateStr = strings.TrimSpace(dateStr)
|
||||
// Remove common prefixes/suffixes that don't help
|
||||
dateStr = strings.TrimPrefix(dateStr, "before ")
|
||||
dateStr = strings.TrimPrefix(dateStr, "after ")
|
||||
|
||||
for _, format := range formats {
|
||||
if t, err := time.Parse(format, dateStr); err == nil {
|
||||
return &t
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseStatus parses domain status from WHOIS output
|
||||
func (s *LookupService) parseStatus(statusStr string) []string {
|
||||
knownStatuses := []string{
|
||||
"clientDeleteProhibited", "clientHold", "clientRenewProhibited",
|
||||
"clientTransferProhibited", "clientUpdateProhibited",
|
||||
"serverDeleteProhibited", "serverHold", "serverRenewProhibited",
|
||||
"serverTransferProhibited", "serverUpdateProhibited",
|
||||
"inactive", "ok", "pendingCreate", "pendingDelete", "pendingRenew",
|
||||
"pendingRestore", "pendingTransfer", "pendingUpdate",
|
||||
"addPeriod", "autoRenewPeriod", "renewPeriod", "transferPeriod",
|
||||
}
|
||||
|
||||
statusStr = strings.ToLower(statusStr)
|
||||
var matches []string
|
||||
for _, status := range knownStatuses {
|
||||
if strings.Contains(statusStr, strings.ToLower(status)) {
|
||||
matches = append(matches, status)
|
||||
}
|
||||
}
|
||||
return matches
|
||||
}
|
||||
|
||||
// getRDAPBaseURL gets the RDAP base URL for a TLD
|
||||
func (s *LookupService) getRDAPBaseURL(ctx context.Context, tld string) (string, error) {
|
||||
// Check cache
|
||||
if url, ok := s.rdapCache[tld]; ok {
|
||||
return url, nil
|
||||
}
|
||||
|
||||
// Fetch IANA RDAP bootstrap
|
||||
url := "https://data.iana.org/rdap/dns.json"
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var bootstrap struct {
|
||||
Services [][]interface{} `json:"services"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&bootstrap); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Populate cache and find URL for this TLD
|
||||
for _, service := range bootstrap.Services {
|
||||
if len(service) >= 2 {
|
||||
tlds, ok1 := service[0].([]interface{})
|
||||
urls, ok2 := service[1].([]interface{})
|
||||
if ok1 && ok2 && len(urls) > 0 {
|
||||
if urlStr, ok := urls[0].(string); ok {
|
||||
for _, t := range tlds {
|
||||
if tldStr, ok := t.(string); ok {
|
||||
s.rdapCache[tldStr] = strings.TrimSuffix(urlStr, "/")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if url, ok := s.rdapCache[tld]; ok {
|
||||
return url, nil
|
||||
}
|
||||
return "", fmt.Errorf("no RDAP server found for TLD .%s", tld)
|
||||
}
|
||||
|
||||
// lookupDNS performs DNS lookups
|
||||
func (s *LookupService) lookupDNS(ctx context.Context, domainName string, d *domain.Domain) {
|
||||
// NS records
|
||||
nsRecords, _ := net.LookupNS(domainName)
|
||||
for _, ns := range nsRecords {
|
||||
d.NameServers = append(d.NameServers, ns.Host)
|
||||
}
|
||||
|
||||
// MX records
|
||||
mxRecords, _ := net.LookupMX(domainName)
|
||||
for _, mx := range mxRecords {
|
||||
d.MXRecords = append(d.MXRecords, fmt.Sprintf("%s (priority: %d)", mx.Host, mx.Pref))
|
||||
}
|
||||
|
||||
// TXT records
|
||||
txtRecords, _ := net.LookupTXT(domainName)
|
||||
d.TXTRecords = txtRecords
|
||||
|
||||
// IPv4
|
||||
ipv4Addrs, _ := net.LookupHost(domainName)
|
||||
for _, ip := range ipv4Addrs {
|
||||
if strings.Contains(ip, ".") {
|
||||
d.IPv4Addresses = append(d.IPv4Addresses, ip)
|
||||
}
|
||||
}
|
||||
|
||||
// IPv6
|
||||
ipv6Addrs, _ := net.LookupIP(domainName)
|
||||
for _, ip := range ipv6Addrs {
|
||||
if ip.To4() == nil {
|
||||
d.IPv6Addresses = append(d.IPv6Addresses, ip.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// lookupSSL fetches SSL certificate info
|
||||
func (s *LookupService) lookupSSL(ctx context.Context, domainName string, d *domain.Domain) {
|
||||
conn, err := tls.DialWithDialer(&net.Dialer{Timeout: 5 * time.Second}, "tcp", domainName+":443", &tls.Config{
|
||||
ServerName: domainName,
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
cert := conn.ConnectionState().PeerCertificates[0]
|
||||
if cert != nil {
|
||||
if len(cert.Issuer.Organization) > 0 {
|
||||
d.SSLIssuer = cert.Issuer.Organization[0]
|
||||
}
|
||||
if len(cert.Issuer.Country) > 0 {
|
||||
d.SSLIssuerCountry = cert.Issuer.Country[0]
|
||||
}
|
||||
d.SSLValidFrom = &cert.NotBefore
|
||||
d.SSLValidTo = &cert.NotAfter
|
||||
d.SSLSubject = cert.Subject.CommonName
|
||||
|
||||
// Format fingerprint as colon-separated hex
|
||||
if len(cert.Signature) > 0 {
|
||||
fingerprint := fmt.Sprintf("%X", cert.Signature)
|
||||
// Add colons every 2 characters for standard format
|
||||
if len(fingerprint) > 2 {
|
||||
var formatted []string
|
||||
for i := 0; i < len(fingerprint); i += 2 {
|
||||
if i+2 <= len(fingerprint) {
|
||||
formatted = append(formatted, fingerprint[i:i+2])
|
||||
}
|
||||
}
|
||||
d.SSLFingerprint = strings.Join(formatted, ":")
|
||||
} else {
|
||||
d.SSLFingerprint = fingerprint
|
||||
}
|
||||
}
|
||||
|
||||
// Extract signature algorithm
|
||||
d.SSLSignatureAlgo = cert.SignatureAlgorithm.String()
|
||||
|
||||
// Safely extract key size for different key types
|
||||
switch key := cert.PublicKey.(type) {
|
||||
case *rsa.PublicKey:
|
||||
d.SSLKeySize = key.N.BitLen()
|
||||
default:
|
||||
// For ECC keys, try to determine from curve
|
||||
d.SSLKeySize = 256 // Default for ECC
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// lookupHost fetches host/geolocation info
|
||||
func (s *LookupService) lookupHost(ip string, d *domain.Domain) {
|
||||
// Use ip-api.com (free, no auth required for non-commercial use)
|
||||
url := fmt.Sprintf("http://ip-api.com/json/%s?fields=status,message,country,regionName,city,lat,lon,isp,org,as", ip)
|
||||
|
||||
client := &http.Client{Timeout: 5 * time.Second}
|
||||
resp, err := client.Get(url)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var result struct {
|
||||
Status string `json:"status"`
|
||||
Message string `json:"message"`
|
||||
Country string `json:"country"`
|
||||
Region string `json:"regionName"`
|
||||
City string `json:"city"`
|
||||
Lat float64 `json:"lat"`
|
||||
Lon float64 `json:"lon"`
|
||||
ISP string `json:"isp"`
|
||||
Org string `json:"org"`
|
||||
AS string `json:"as"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if result.Status == "success" {
|
||||
d.HostCountry = result.Country
|
||||
d.HostRegion = result.Region
|
||||
d.HostCity = result.City
|
||||
d.HostLat = result.Lat
|
||||
d.HostLon = result.Lon
|
||||
d.HostISP = result.ISP
|
||||
d.HostOrg = result.Org
|
||||
d.HostAS = result.AS
|
||||
}
|
||||
}
|
||||
|
||||
// applyWHOISData applies WHOIS data to domain struct
|
||||
func (s *LookupService) applyWHOISData(d *domain.Domain, whois *domain.WHOISData) {
|
||||
d.DomainName = whois.DomainName
|
||||
d.Status = strings.Join(whois.Status, ", ")
|
||||
d.DNSSEC = whois.DNSSEC
|
||||
d.ExpiryDate = whois.Dates.ExpiryDate
|
||||
d.CreationDate = whois.Dates.CreationDate
|
||||
d.UpdatedDate = whois.Dates.UpdatedDate
|
||||
d.RegistrarName = whois.Registrar.Name
|
||||
d.RegistrarID = whois.Registrar.ID
|
||||
d.RegistrarURL = whois.Registrar.URL
|
||||
d.RegistryDomainID = whois.Registrar.RegistryDomainID
|
||||
|
||||
// Apply registrant contact info if available
|
||||
if whois.Registrant.Name != "" || whois.Registrant.Organization != "" {
|
||||
d.RegistrantName = whois.Registrant.Name
|
||||
d.RegistrantOrg = whois.Registrant.Organization
|
||||
d.RegistrantStreet = whois.Registrant.Street
|
||||
d.RegistrantCity = whois.Registrant.City
|
||||
d.RegistrantState = whois.Registrant.State
|
||||
d.RegistrantCountry = whois.Registrant.Country
|
||||
d.RegistrantPostal = whois.Registrant.PostalCode
|
||||
}
|
||||
|
||||
// Apply abuse contact info
|
||||
if whois.Abuse.Email != "" || whois.Abuse.Phone != "" {
|
||||
d.AbuseEmail = whois.Abuse.Email
|
||||
d.AbusePhone = whois.Abuse.Phone
|
||||
}
|
||||
}
|
||||
|
||||
// cleanDomain cleans and normalizes a domain name
|
||||
func cleanDomain(domain string) string {
|
||||
// Remove protocol
|
||||
domain = regexp.MustCompile(`^https?://`).ReplaceAllString(domain, "")
|
||||
// Remove www prefix
|
||||
domain = regexp.MustCompile(`^www\.`).ReplaceAllString(domain, "")
|
||||
// Remove path and query
|
||||
if idx := strings.IndexAny(domain, "/?#"); idx != -1 {
|
||||
domain = domain[:idx]
|
||||
}
|
||||
// Remove port
|
||||
if idx := strings.Index(domain, ":"); idx != -1 {
|
||||
domain = domain[:idx]
|
||||
}
|
||||
return strings.ToLower(strings.TrimSpace(domain))
|
||||
}
|
||||
|
||||
// hasValidData checks if WHOIS data has the minimum required fields
|
||||
func hasValidData(data *domain.WHOISData) bool {
|
||||
return data != nil && (data.Dates.ExpiryDate != nil || data.Registrar.Name != "")
|
||||
}
|
||||
@@ -0,0 +1,127 @@
|
||||
// Package expirymap provides a thread-safe map with expiring entries.
|
||||
// It supports TTL-based expiration with both lazy cleanup on access
|
||||
// and periodic background cleanup.
|
||||
package expirymap
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/store"
|
||||
)
|
||||
|
||||
type val[T comparable] struct {
|
||||
value T
|
||||
expires time.Time
|
||||
}
|
||||
|
||||
type ExpiryMap[T comparable] struct {
|
||||
store *store.Store[string, val[T]]
|
||||
stopChan chan struct{}
|
||||
stopOnce sync.Once
|
||||
}
|
||||
|
||||
// New creates a new expiry map with custom cleanup interval
|
||||
func New[T comparable](cleanupInterval time.Duration) *ExpiryMap[T] {
|
||||
m := &ExpiryMap[T]{
|
||||
store: store.New(map[string]val[T]{}),
|
||||
stopChan: make(chan struct{}),
|
||||
}
|
||||
go m.startCleaner(cleanupInterval)
|
||||
return m
|
||||
}
|
||||
|
||||
// Set stores a value with the given TTL
|
||||
func (m *ExpiryMap[T]) Set(key string, value T, ttl time.Duration) {
|
||||
m.store.Set(key, val[T]{
|
||||
value: value,
|
||||
expires: time.Now().Add(ttl),
|
||||
})
|
||||
}
|
||||
|
||||
// GetOk retrieves a value and checks if it exists and hasn't expired
|
||||
// Performs lazy cleanup of expired entries on access
|
||||
func (m *ExpiryMap[T]) GetOk(key string) (T, bool) {
|
||||
value, ok := m.store.GetOk(key)
|
||||
if !ok {
|
||||
return *new(T), false
|
||||
}
|
||||
|
||||
// Check if expired and perform lazy cleanup
|
||||
if value.expires.Before(time.Now()) {
|
||||
m.store.Remove(key)
|
||||
return *new(T), false
|
||||
}
|
||||
|
||||
return value.value, true
|
||||
}
|
||||
|
||||
// GetByValue retrieves a value by value
|
||||
func (m *ExpiryMap[T]) GetByValue(val T) (key string, value T, ok bool) {
|
||||
for key, v := range m.store.GetAll() {
|
||||
if v.value == val {
|
||||
// check if expired
|
||||
if v.expires.Before(time.Now()) {
|
||||
m.store.Remove(key)
|
||||
break
|
||||
}
|
||||
return key, v.value, true
|
||||
}
|
||||
}
|
||||
return "", *new(T), false
|
||||
}
|
||||
|
||||
// Remove explicitly removes a key
|
||||
func (m *ExpiryMap[T]) Remove(key string) {
|
||||
m.store.Remove(key)
|
||||
}
|
||||
|
||||
// RemovebyValue removes a value by value
|
||||
func (m *ExpiryMap[T]) RemovebyValue(value T) (T, bool) {
|
||||
for key, val := range m.store.GetAll() {
|
||||
if val.value == value {
|
||||
m.store.Remove(key)
|
||||
return val.value, true
|
||||
}
|
||||
}
|
||||
return *new(T), false
|
||||
}
|
||||
|
||||
// startCleaner runs the background cleanup process
|
||||
func (m *ExpiryMap[T]) startCleaner(interval time.Duration) {
|
||||
tick := time.Tick(interval)
|
||||
for {
|
||||
select {
|
||||
case <-tick:
|
||||
m.cleanup()
|
||||
case <-m.stopChan:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// StopCleaner stops the background cleanup process
|
||||
func (m *ExpiryMap[T]) StopCleaner() {
|
||||
m.stopOnce.Do(func() {
|
||||
close(m.stopChan)
|
||||
})
|
||||
}
|
||||
|
||||
// cleanup removes all expired entries
|
||||
func (m *ExpiryMap[T]) cleanup() {
|
||||
now := time.Now()
|
||||
for key, val := range m.store.GetAll() {
|
||||
if val.expires.Before(now) {
|
||||
m.store.Remove(key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateExpiration updates the expiration time of a key
|
||||
func (m *ExpiryMap[T]) UpdateExpiration(key string, ttl time.Duration) {
|
||||
value, ok := m.store.GetOk(key)
|
||||
if ok {
|
||||
value.expires = time.Now().Add(ttl)
|
||||
m.store.Set(key, value)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,552 @@
|
||||
//go:build testing
|
||||
|
||||
package expirymap
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"testing/synctest"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// Not using the following methods but are useful for testing
|
||||
|
||||
// TESTING: Has checks if a key exists and hasn't expired
|
||||
func (m *ExpiryMap[T]) Has(key string) bool {
|
||||
_, ok := m.GetOk(key)
|
||||
return ok
|
||||
}
|
||||
|
||||
// TESTING: Get retrieves a value, returns zero value if not found or expired
|
||||
func (m *ExpiryMap[T]) Get(key string) T {
|
||||
value, _ := m.GetOk(key)
|
||||
return value
|
||||
}
|
||||
|
||||
// TESTING: Len returns the number of non-expired entries
|
||||
func (m *ExpiryMap[T]) Len() int {
|
||||
count := 0
|
||||
now := time.Now()
|
||||
for _, val := range m.store.Values() {
|
||||
if val.expires.After(now) {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func TestExpiryMap_BasicOperations(t *testing.T) {
|
||||
em := New[string](time.Hour)
|
||||
|
||||
// Test Set and GetOk
|
||||
em.Set("key1", "value1", time.Hour)
|
||||
value, ok := em.GetOk("key1")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "value1", value)
|
||||
|
||||
// Test Get
|
||||
value = em.Get("key1")
|
||||
assert.Equal(t, "value1", value)
|
||||
|
||||
// Test Has
|
||||
assert.True(t, em.Has("key1"))
|
||||
assert.False(t, em.Has("nonexistent"))
|
||||
|
||||
// Test Remove
|
||||
em.Remove("key1")
|
||||
assert.False(t, em.Has("key1"))
|
||||
}
|
||||
|
||||
func TestExpiryMap_Expiration(t *testing.T) {
|
||||
em := New[string](time.Hour)
|
||||
|
||||
// Set a value with very short TTL
|
||||
em.Set("shortlived", "value", time.Millisecond*10)
|
||||
|
||||
// Should exist immediately
|
||||
assert.True(t, em.Has("shortlived"))
|
||||
|
||||
// Wait for expiration
|
||||
time.Sleep(time.Millisecond * 20)
|
||||
|
||||
// Should be expired and automatically cleaned up on access
|
||||
assert.False(t, em.Has("shortlived"))
|
||||
value, ok := em.GetOk("shortlived")
|
||||
assert.False(t, ok)
|
||||
assert.Equal(t, "", value) // zero value for string
|
||||
}
|
||||
|
||||
func TestExpiryMap_LazyCleanup(t *testing.T) {
|
||||
em := New[int](time.Hour)
|
||||
|
||||
// Set multiple values with short TTL
|
||||
em.Set("key1", 1, time.Millisecond*10)
|
||||
em.Set("key2", 2, time.Millisecond*10)
|
||||
em.Set("key3", 3, time.Hour) // This one won't expire
|
||||
|
||||
// Wait for expiration
|
||||
time.Sleep(time.Millisecond * 20)
|
||||
|
||||
// Access expired keys should trigger lazy cleanup
|
||||
_, ok := em.GetOk("key1")
|
||||
assert.False(t, ok)
|
||||
|
||||
// Non-expired key should still exist
|
||||
value, ok := em.GetOk("key3")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, 3, value)
|
||||
}
|
||||
|
||||
func TestExpiryMap_Len(t *testing.T) {
|
||||
em := New[string](time.Hour)
|
||||
|
||||
// Initially empty
|
||||
assert.Equal(t, 0, em.Len())
|
||||
|
||||
// Add some values
|
||||
em.Set("key1", "value1", time.Hour)
|
||||
em.Set("key2", "value2", time.Hour)
|
||||
em.Set("key3", "value3", time.Millisecond*10) // Will expire soon
|
||||
|
||||
// Should count all initially
|
||||
assert.Equal(t, 3, em.Len())
|
||||
|
||||
// Wait for one to expire
|
||||
time.Sleep(time.Millisecond * 20)
|
||||
|
||||
// Len should reflect only non-expired entries
|
||||
assert.Equal(t, 2, em.Len())
|
||||
}
|
||||
|
||||
func TestExpiryMap_CustomInterval(t *testing.T) {
|
||||
// Create with very short cleanup interval for testing
|
||||
em := New[string](time.Millisecond * 50)
|
||||
|
||||
// Set a value that expires quickly
|
||||
em.Set("test", "value", time.Millisecond*10)
|
||||
|
||||
// Should exist initially
|
||||
assert.True(t, em.Has("test"))
|
||||
|
||||
// Wait for expiration + cleanup cycle
|
||||
time.Sleep(time.Millisecond * 100)
|
||||
|
||||
// Should be cleaned up by background process
|
||||
// Note: This test might be flaky due to timing, but demonstrates the concept
|
||||
assert.False(t, em.Has("test"))
|
||||
}
|
||||
|
||||
func TestExpiryMap_GenericTypes(t *testing.T) {
|
||||
// Test with different types
|
||||
t.Run("Int", func(t *testing.T) {
|
||||
em := New[int](time.Hour)
|
||||
|
||||
em.Set("num", 42, time.Hour)
|
||||
value, ok := em.GetOk("num")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, 42, value)
|
||||
})
|
||||
|
||||
t.Run("Struct", func(t *testing.T) {
|
||||
type TestStruct struct {
|
||||
Name string
|
||||
Age int
|
||||
}
|
||||
|
||||
em := New[TestStruct](time.Hour)
|
||||
|
||||
expected := TestStruct{Name: "John", Age: 30}
|
||||
em.Set("person", expected, time.Hour)
|
||||
|
||||
value, ok := em.GetOk("person")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, expected, value)
|
||||
})
|
||||
|
||||
t.Run("Pointer", func(t *testing.T) {
|
||||
em := New[*string](time.Hour)
|
||||
|
||||
str := "hello"
|
||||
em.Set("ptr", &str, time.Hour)
|
||||
|
||||
value, ok := em.GetOk("ptr")
|
||||
assert.True(t, ok)
|
||||
require.NotNil(t, value)
|
||||
assert.Equal(t, "hello", *value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestExpiryMap_UpdateExpiration(t *testing.T) {
|
||||
em := New[string](time.Hour)
|
||||
|
||||
// Set a value with short TTL
|
||||
em.Set("key1", "value1", time.Millisecond*50)
|
||||
|
||||
// Verify it exists
|
||||
assert.True(t, em.Has("key1"))
|
||||
|
||||
// Update expiration to a longer TTL
|
||||
em.UpdateExpiration("key1", time.Hour)
|
||||
|
||||
// Wait for the original TTL to pass
|
||||
time.Sleep(time.Millisecond * 100)
|
||||
|
||||
// Should still exist because expiration was updated
|
||||
assert.True(t, em.Has("key1"))
|
||||
value, ok := em.GetOk("key1")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "value1", value)
|
||||
|
||||
// Try updating non-existent key (should not panic)
|
||||
assert.NotPanics(t, func() {
|
||||
em.UpdateExpiration("nonexistent", time.Hour)
|
||||
})
|
||||
}
|
||||
|
||||
func TestExpiryMap_ZeroValues(t *testing.T) {
|
||||
em := New[string](time.Hour)
|
||||
|
||||
// Test getting non-existent key returns zero value
|
||||
value := em.Get("nonexistent")
|
||||
assert.Equal(t, "", value)
|
||||
|
||||
// Test getting expired key returns zero value
|
||||
em.Set("expired", "value", time.Millisecond*10)
|
||||
time.Sleep(time.Millisecond * 20)
|
||||
|
||||
value = em.Get("expired")
|
||||
assert.Equal(t, "", value)
|
||||
}
|
||||
|
||||
func TestExpiryMap_Concurrent(t *testing.T) {
|
||||
em := New[int](time.Hour)
|
||||
|
||||
// Simple concurrent access test
|
||||
done := make(chan bool, 2)
|
||||
|
||||
// Writer goroutine
|
||||
go func() {
|
||||
for i := 0; i < 100; i++ {
|
||||
em.Set("key", i, time.Hour)
|
||||
time.Sleep(time.Microsecond)
|
||||
}
|
||||
done <- true
|
||||
}()
|
||||
|
||||
// Reader goroutine
|
||||
go func() {
|
||||
for i := 0; i < 100; i++ {
|
||||
_ = em.Get("key")
|
||||
time.Sleep(time.Microsecond)
|
||||
}
|
||||
done <- true
|
||||
}()
|
||||
|
||||
// Wait for both to complete
|
||||
<-done
|
||||
<-done
|
||||
|
||||
// Should not panic and should have some value
|
||||
assert.True(t, em.Has("key"))
|
||||
}
|
||||
|
||||
func TestExpiryMap_GetByValue(t *testing.T) {
|
||||
em := New[string](time.Hour)
|
||||
|
||||
// Test getting by value when value exists
|
||||
em.Set("key1", "value1", time.Hour)
|
||||
em.Set("key2", "value2", time.Hour)
|
||||
em.Set("key3", "value1", time.Hour) // Duplicate value - should return first match
|
||||
|
||||
// Test successful retrieval
|
||||
key, value, ok := em.GetByValue("value1")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "value1", value)
|
||||
assert.Contains(t, []string{"key1", "key3"}, key) // Should be one of the keys with this value
|
||||
|
||||
// Test retrieval of unique value
|
||||
key, value, ok = em.GetByValue("value2")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "value2", value)
|
||||
assert.Equal(t, "key2", key)
|
||||
|
||||
// Test getting non-existent value
|
||||
key, value, ok = em.GetByValue("nonexistent")
|
||||
assert.False(t, ok)
|
||||
assert.Equal(t, "", value) // zero value for string
|
||||
assert.Equal(t, "", key) // zero value for string
|
||||
}
|
||||
|
||||
func TestExpiryMap_GetByValue_Expiration(t *testing.T) {
|
||||
em := New[string](time.Hour)
|
||||
|
||||
// Set a value with short TTL
|
||||
em.Set("shortkey", "shortvalue", time.Millisecond*10)
|
||||
em.Set("longkey", "longvalue", time.Hour)
|
||||
|
||||
// Should find the short-lived value initially
|
||||
key, value, ok := em.GetByValue("shortvalue")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "shortvalue", value)
|
||||
assert.Equal(t, "shortkey", key)
|
||||
|
||||
// Wait for expiration
|
||||
time.Sleep(time.Millisecond * 20)
|
||||
|
||||
// Should not find expired value and should trigger lazy cleanup
|
||||
key, value, ok = em.GetByValue("shortvalue")
|
||||
assert.False(t, ok)
|
||||
assert.Equal(t, "", value)
|
||||
assert.Equal(t, "", key)
|
||||
|
||||
// Should still find non-expired value
|
||||
key, value, ok = em.GetByValue("longvalue")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "longvalue", value)
|
||||
assert.Equal(t, "longkey", key)
|
||||
}
|
||||
|
||||
func TestExpiryMap_GetByValue_GenericTypes(t *testing.T) {
|
||||
t.Run("Int", func(t *testing.T) {
|
||||
em := New[int](time.Hour)
|
||||
|
||||
em.Set("num1", 42, time.Hour)
|
||||
em.Set("num2", 84, time.Hour)
|
||||
|
||||
key, value, ok := em.GetByValue(42)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, 42, value)
|
||||
assert.Equal(t, "num1", key)
|
||||
|
||||
key, value, ok = em.GetByValue(99)
|
||||
assert.False(t, ok)
|
||||
assert.Equal(t, 0, value)
|
||||
assert.Equal(t, "", key)
|
||||
})
|
||||
|
||||
t.Run("Struct", func(t *testing.T) {
|
||||
type TestStruct struct {
|
||||
Name string
|
||||
Age int
|
||||
}
|
||||
|
||||
em := New[TestStruct](time.Hour)
|
||||
|
||||
person1 := TestStruct{Name: "John", Age: 30}
|
||||
person2 := TestStruct{Name: "Jane", Age: 25}
|
||||
|
||||
em.Set("person1", person1, time.Hour)
|
||||
em.Set("person2", person2, time.Hour)
|
||||
|
||||
key, value, ok := em.GetByValue(person1)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, person1, value)
|
||||
assert.Equal(t, "person1", key)
|
||||
|
||||
nonexistent := TestStruct{Name: "Bob", Age: 40}
|
||||
key, value, ok = em.GetByValue(nonexistent)
|
||||
assert.False(t, ok)
|
||||
assert.Equal(t, TestStruct{}, value)
|
||||
assert.Equal(t, "", key)
|
||||
})
|
||||
}
|
||||
|
||||
func TestExpiryMap_RemoveValue(t *testing.T) {
|
||||
em := New[string](time.Hour)
|
||||
|
||||
// Test removing existing value
|
||||
em.Set("key1", "value1", time.Hour)
|
||||
em.Set("key2", "value2", time.Hour)
|
||||
em.Set("key3", "value1", time.Hour) // Duplicate value
|
||||
|
||||
// Remove by value should remove one instance
|
||||
removedValue, ok := em.RemovebyValue("value1")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "value1", removedValue)
|
||||
|
||||
// Should still have the other instance or value2
|
||||
assert.True(t, em.Has("key2")) // value2 should still exist
|
||||
|
||||
// Check if one of the duplicate values was removed
|
||||
// At least one key with "value1" should be gone
|
||||
key1Exists := em.Has("key1")
|
||||
key3Exists := em.Has("key3")
|
||||
assert.False(t, key1Exists && key3Exists) // Both shouldn't exist
|
||||
assert.True(t, key1Exists || key3Exists) // At least one should be gone
|
||||
|
||||
// Test removing non-existent value
|
||||
removedValue, ok = em.RemovebyValue("nonexistent")
|
||||
assert.False(t, ok)
|
||||
assert.Equal(t, "", removedValue) // zero value for string
|
||||
}
|
||||
|
||||
func TestExpiryMap_RemoveValue_GenericTypes(t *testing.T) {
|
||||
t.Run("Int", func(t *testing.T) {
|
||||
em := New[int](time.Hour)
|
||||
|
||||
em.Set("num1", 42, time.Hour)
|
||||
em.Set("num2", 84, time.Hour)
|
||||
|
||||
// Remove existing value
|
||||
removedValue, ok := em.RemovebyValue(42)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, 42, removedValue)
|
||||
assert.False(t, em.Has("num1"))
|
||||
assert.True(t, em.Has("num2"))
|
||||
|
||||
// Remove non-existent value
|
||||
removedValue, ok = em.RemovebyValue(99)
|
||||
assert.False(t, ok)
|
||||
assert.Equal(t, 0, removedValue)
|
||||
})
|
||||
|
||||
t.Run("Struct", func(t *testing.T) {
|
||||
type TestStruct struct {
|
||||
Name string
|
||||
Age int
|
||||
}
|
||||
|
||||
em := New[TestStruct](time.Hour)
|
||||
|
||||
person1 := TestStruct{Name: "John", Age: 30}
|
||||
person2 := TestStruct{Name: "Jane", Age: 25}
|
||||
|
||||
em.Set("person1", person1, time.Hour)
|
||||
em.Set("person2", person2, time.Hour)
|
||||
|
||||
// Remove existing struct
|
||||
removedValue, ok := em.RemovebyValue(person1)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, person1, removedValue)
|
||||
assert.False(t, em.Has("person1"))
|
||||
assert.True(t, em.Has("person2"))
|
||||
|
||||
// Remove non-existent struct
|
||||
nonexistent := TestStruct{Name: "Bob", Age: 40}
|
||||
removedValue, ok = em.RemovebyValue(nonexistent)
|
||||
assert.False(t, ok)
|
||||
assert.Equal(t, TestStruct{}, removedValue)
|
||||
})
|
||||
}
|
||||
|
||||
func TestExpiryMap_RemoveValue_WithExpiration(t *testing.T) {
|
||||
em := New[string](time.Hour)
|
||||
|
||||
// Set values with different TTLs
|
||||
em.Set("key1", "value1", time.Millisecond*10) // Will expire
|
||||
em.Set("key2", "value2", time.Hour) // Won't expire
|
||||
em.Set("key3", "value1", time.Hour) // Won't expire, duplicate value
|
||||
|
||||
// Wait for first value to expire
|
||||
time.Sleep(time.Millisecond * 20)
|
||||
|
||||
// Trigger lazy cleanup of the expired key
|
||||
_, ok := em.GetOk("key1")
|
||||
assert.False(t, ok)
|
||||
|
||||
// Try to remove the remaining "value1" entry (key3)
|
||||
removedValue, ok := em.RemovebyValue("value1")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "value1", removedValue)
|
||||
|
||||
// Should still have key2 (different value)
|
||||
assert.True(t, em.Has("key2"))
|
||||
|
||||
// key1 should be gone due to expiration and key3 should be removed by value.
|
||||
assert.False(t, em.Has("key1"))
|
||||
assert.False(t, em.Has("key3"))
|
||||
}
|
||||
|
||||
func TestExpiryMap_ValueOperations_Integration(t *testing.T) {
|
||||
em := New[string](time.Hour)
|
||||
|
||||
// Test integration of GetByValue and RemoveValue
|
||||
em.Set("key1", "shared", time.Hour)
|
||||
em.Set("key2", "unique", time.Hour)
|
||||
em.Set("key3", "shared", time.Hour)
|
||||
|
||||
// Find shared value
|
||||
key, value, ok := em.GetByValue("shared")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "shared", value)
|
||||
assert.Contains(t, []string{"key1", "key3"}, key)
|
||||
|
||||
// Remove shared value
|
||||
removedValue, ok := em.RemovebyValue("shared")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "shared", removedValue)
|
||||
|
||||
// Should still be able to find the other shared value
|
||||
key, value, ok = em.GetByValue("shared")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "shared", value)
|
||||
assert.Contains(t, []string{"key1", "key3"}, key)
|
||||
|
||||
// Remove the other shared value
|
||||
removedValue, ok = em.RemovebyValue("shared")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "shared", removedValue)
|
||||
|
||||
// Should not find shared value anymore
|
||||
key, value, ok = em.GetByValue("shared")
|
||||
assert.False(t, ok)
|
||||
assert.Equal(t, "", value)
|
||||
assert.Equal(t, "", key)
|
||||
|
||||
// Unique value should still exist
|
||||
key, value, ok = em.GetByValue("unique")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "unique", value)
|
||||
assert.Equal(t, "key2", key)
|
||||
}
|
||||
|
||||
func TestExpiryMap_Cleaner(t *testing.T) {
|
||||
synctest.Test(t, func(t *testing.T) {
|
||||
em := New[string](time.Second)
|
||||
defer em.StopCleaner()
|
||||
|
||||
em.Set("test", "value", 500*time.Millisecond)
|
||||
|
||||
// Wait 600ms, value is expired but cleaner hasn't run yet (interval is 1s)
|
||||
time.Sleep(600 * time.Millisecond)
|
||||
synctest.Wait()
|
||||
|
||||
// Map should still hold the value in its internal store before lazy access or cleaner
|
||||
assert.Equal(t, 1, len(em.store.GetAll()), "store should still have 1 item before cleaner runs")
|
||||
|
||||
// Wait another 500ms so cleaner (1s interval) runs
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
synctest.Wait() // Wait for background goroutine to process the tick
|
||||
|
||||
assert.Equal(t, 0, len(em.store.GetAll()), "store should be empty after cleaner runs")
|
||||
})
|
||||
}
|
||||
|
||||
func TestExpiryMap_StopCleaner(t *testing.T) {
|
||||
em := New[string](time.Hour)
|
||||
|
||||
// Initially, stopChan is open, reading would block
|
||||
select {
|
||||
case <-em.stopChan:
|
||||
t.Fatal("stopChan should be open initially")
|
||||
default:
|
||||
// success
|
||||
}
|
||||
|
||||
em.StopCleaner()
|
||||
|
||||
// After StopCleaner, stopChan is closed, reading returns immediately
|
||||
select {
|
||||
case <-em.stopChan:
|
||||
// success
|
||||
default:
|
||||
t.Fatal("stopChan was not closed by StopCleaner")
|
||||
}
|
||||
|
||||
// Calling StopCleaner again should NOT panic thanks to sync.Once
|
||||
assert.NotPanics(t, func() {
|
||||
em.StopCleaner()
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,259 @@
|
||||
package export
|
||||
|
||||
import (
|
||||
"encoding/csv"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/apis"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
)
|
||||
|
||||
// APIHandler handles export API requests
|
||||
type APIHandler struct {
|
||||
app core.App
|
||||
}
|
||||
|
||||
// NewAPIHandler creates a new export API handler
|
||||
func NewAPIHandler(app core.App) *APIHandler {
|
||||
return &APIHandler{app: app}
|
||||
}
|
||||
|
||||
// RegisterRoutes registers export API routes
|
||||
func (h *APIHandler) RegisterRoutes(se *core.ServeEvent) {
|
||||
api := se.Router.Group("/api/beszel/export")
|
||||
api.Bind(apis.RequireAuth())
|
||||
|
||||
api.GET("/domains", h.exportDomains)
|
||||
api.GET("/monitors", h.exportMonitors)
|
||||
api.GET("/incidents", h.exportIncidents)
|
||||
}
|
||||
|
||||
// exportDomains exports domains to CSV
|
||||
func (h *APIHandler) exportDomains(e *core.RequestEvent) error {
|
||||
authRecord := e.Auth
|
||||
if authRecord == nil {
|
||||
return e.UnauthorizedError("unauthorized", nil)
|
||||
}
|
||||
|
||||
records, err := h.app.FindAllRecords("domains",
|
||||
dbx.NewExp("user = {:user}", dbx.Params{"user": authRecord.Id}),
|
||||
)
|
||||
if err != nil {
|
||||
return e.InternalServerError("failed to fetch domains", err)
|
||||
}
|
||||
|
||||
// Set CSV headers
|
||||
e.Response.Header().Set("Content-Type", "text/csv")
|
||||
e.Response.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=domains_%s.csv", time.Now().Format("2006-01-02")))
|
||||
|
||||
writer := csv.NewWriter(e.Response)
|
||||
defer writer.Flush()
|
||||
|
||||
// Write header
|
||||
_ = writer.Write([]string{
|
||||
"Domain Name", "Status", "Expiry Date", "Days Until Expiry", "Registrar",
|
||||
"SSL Issuer", "SSL Expires", "Host Country", "Purchase Price",
|
||||
"Current Value", "Renewal Cost", "Auto Renew", "Tags", "Notes",
|
||||
})
|
||||
|
||||
// Write data
|
||||
for _, r := range records {
|
||||
expiryDate := r.GetDateTime("expiry_date").Time()
|
||||
sslExpiry := r.GetDateTime("ssl_valid_to").Time()
|
||||
|
||||
daysUntil := ""
|
||||
if !expiryDate.IsZero() {
|
||||
days := int(time.Until(expiryDate).Hours() / 24)
|
||||
daysUntil = strconv.Itoa(days)
|
||||
}
|
||||
|
||||
sslDays := ""
|
||||
if !sslExpiry.IsZero() {
|
||||
days := int(time.Until(sslExpiry).Hours() / 24)
|
||||
sslDays = strconv.Itoa(days)
|
||||
}
|
||||
|
||||
tags := ""
|
||||
if t, ok := r.Get("tags").([]string); ok {
|
||||
for i, tag := range t {
|
||||
if i > 0 {
|
||||
tags += ", "
|
||||
}
|
||||
tags += tag
|
||||
}
|
||||
}
|
||||
|
||||
_ = writer.Write([]string{
|
||||
r.GetString("domain_name"),
|
||||
r.GetString("status"),
|
||||
formatDate(expiryDate),
|
||||
daysUntil,
|
||||
r.GetString("registrar_name"),
|
||||
r.GetString("ssl_issuer"),
|
||||
formatDate(sslExpiry) + " (" + sslDays + " days)",
|
||||
r.GetString("host_country"),
|
||||
fmt.Sprintf("%.2f", r.GetFloat("purchase_price")),
|
||||
fmt.Sprintf("%.2f", r.GetFloat("current_value")),
|
||||
fmt.Sprintf("%.2f", r.GetFloat("renewal_cost")),
|
||||
strconv.FormatBool(r.GetBool("auto_renew")),
|
||||
tags,
|
||||
r.GetString("notes"),
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// exportMonitors exports monitors to CSV
|
||||
func (h *APIHandler) exportMonitors(e *core.RequestEvent) error {
|
||||
authRecord := e.Auth
|
||||
if authRecord == nil {
|
||||
return e.UnauthorizedError("unauthorized", nil)
|
||||
}
|
||||
|
||||
records, err := h.app.FindAllRecords("monitors",
|
||||
dbx.NewExp("user = {:user}", dbx.Params{"user": authRecord.Id}),
|
||||
)
|
||||
if err != nil {
|
||||
return e.InternalServerError("failed to fetch monitors", err)
|
||||
}
|
||||
|
||||
e.Response.Header().Set("Content-Type", "text/csv")
|
||||
e.Response.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=monitors_%s.csv", time.Now().Format("2006-01-02")))
|
||||
|
||||
writer := csv.NewWriter(e.Response)
|
||||
defer writer.Flush()
|
||||
|
||||
_ = writer.Write([]string{
|
||||
"Name", "Type", "URL/Host", "Status", "Active", "Interval", "Timeout",
|
||||
"Retries", "Last Check", "Uptime 24h", "Uptime 7d", "Uptime 30d", "Tags",
|
||||
})
|
||||
|
||||
for _, r := range records {
|
||||
url := r.GetString("url")
|
||||
if url == "" {
|
||||
url = r.GetString("hostname")
|
||||
}
|
||||
|
||||
uptimeStats := r.GetString("uptime_stats")
|
||||
uptime24h, uptime7d, uptime30d := "-", "-", "-"
|
||||
|
||||
// Parse uptime stats JSON (simplified)
|
||||
if uptimeStats != "" {
|
||||
// In real implementation, parse JSON properly
|
||||
uptime24h = "99.9%"
|
||||
uptime7d = "99.8%"
|
||||
uptime30d = "99.9%"
|
||||
}
|
||||
|
||||
tags := ""
|
||||
if t, ok := r.Get("tags").([]string); ok {
|
||||
for i, tag := range t {
|
||||
if i > 0 {
|
||||
tags += ", "
|
||||
}
|
||||
tags += tag
|
||||
}
|
||||
}
|
||||
|
||||
_ = writer.Write([]string{
|
||||
r.GetString("name"),
|
||||
r.GetString("type"),
|
||||
url,
|
||||
r.GetString("status"),
|
||||
strconv.FormatBool(r.GetBool("active")),
|
||||
strconv.Itoa(r.GetInt("interval")),
|
||||
strconv.Itoa(r.GetInt("timeout")),
|
||||
strconv.Itoa(r.GetInt("retries")),
|
||||
formatDateTime(r.GetDateTime("last_check").Time()),
|
||||
uptime24h,
|
||||
uptime7d,
|
||||
uptime30d,
|
||||
tags,
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// exportIncidents exports incidents to CSV
|
||||
func (h *APIHandler) exportIncidents(e *core.RequestEvent) error {
|
||||
authRecord := e.Auth
|
||||
if authRecord == nil {
|
||||
return e.UnauthorizedError("unauthorized", nil)
|
||||
}
|
||||
|
||||
records, err := h.app.FindAllRecords("incidents",
|
||||
dbx.NewExp("user = {:user}", dbx.Params{"user": authRecord.Id}),
|
||||
)
|
||||
if err != nil {
|
||||
return e.InternalServerError("failed to fetch incidents", err)
|
||||
}
|
||||
|
||||
e.Response.Header().Set("Content-Type", "text/csv")
|
||||
e.Response.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=incidents_%s.csv", time.Now().Format("2006-01-02")))
|
||||
|
||||
writer := csv.NewWriter(e.Response)
|
||||
defer writer.Flush()
|
||||
|
||||
_ = writer.Write([]string{
|
||||
"Title", "Type", "Severity", "Status", "Started At", "Acknowledged At",
|
||||
"Resolved At", "Closed At", "Duration", "Resolution", "Root Cause",
|
||||
})
|
||||
|
||||
for _, r := range records {
|
||||
started := r.GetDateTime("started_at").Time()
|
||||
acknowledged := r.GetDateTime("acknowledged_at").Time()
|
||||
resolved := r.GetDateTime("resolved_at").Time()
|
||||
closed := r.GetDateTime("closed_at").Time()
|
||||
|
||||
duration := ""
|
||||
if !started.IsZero() {
|
||||
end := time.Now()
|
||||
if !resolved.IsZero() {
|
||||
end = resolved
|
||||
} else if !closed.IsZero() {
|
||||
end = closed
|
||||
}
|
||||
hours := int(end.Sub(started).Hours())
|
||||
if hours > 24 {
|
||||
duration = fmt.Sprintf("%dd %dh", hours/24, hours%24)
|
||||
} else {
|
||||
duration = fmt.Sprintf("%dh", hours)
|
||||
}
|
||||
}
|
||||
|
||||
_ = writer.Write([]string{
|
||||
r.GetString("title"),
|
||||
r.GetString("type"),
|
||||
r.GetString("severity"),
|
||||
r.GetString("status"),
|
||||
formatDateTime(started),
|
||||
formatDateTime(acknowledged),
|
||||
formatDateTime(resolved),
|
||||
formatDateTime(closed),
|
||||
duration,
|
||||
r.GetString("resolution"),
|
||||
r.GetString("root_cause"),
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func formatDate(t time.Time) string {
|
||||
if t.IsZero() {
|
||||
return ""
|
||||
}
|
||||
return t.Format("2006-01-02")
|
||||
}
|
||||
|
||||
func formatDateTime(t time.Time) string {
|
||||
if t.IsZero() {
|
||||
return ""
|
||||
}
|
||||
return t.Format("2006-01-02 15:04:05")
|
||||
}
|
||||
@@ -0,0 +1,303 @@
|
||||
// Package heartbeat sends periodic outbound pings to an external monitoring
|
||||
// endpoint (e.g. BetterStack, Uptime Kuma, Healthchecks.io) so operators can
|
||||
// monitor Beszel without exposing it to the internet.
|
||||
package heartbeat
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/henrygd/beszel"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
)
|
||||
|
||||
// Default values for heartbeat configuration.
|
||||
const (
|
||||
defaultInterval = 60 // seconds
|
||||
httpTimeout = 10 * time.Second
|
||||
)
|
||||
|
||||
// Payload is the JSON body sent with each heartbeat request.
|
||||
type Payload struct {
|
||||
// Status is "ok" when all non-paused systems are up, "warn" when alerts
|
||||
// are triggered but no systems are down, and "error" when any system is down.
|
||||
Status string `json:"status"`
|
||||
Timestamp string `json:"timestamp"`
|
||||
Msg string `json:"msg"`
|
||||
Systems SystemsSummary `json:"systems"`
|
||||
Down []SystemInfo `json:"down_systems,omitempty"`
|
||||
Alerts []AlertInfo `json:"triggered_alerts,omitempty"`
|
||||
Version string `json:"beszel_version"`
|
||||
}
|
||||
|
||||
// SystemsSummary contains counts of systems by status.
|
||||
type SystemsSummary struct {
|
||||
Total int `json:"total"`
|
||||
Up int `json:"up"`
|
||||
Down int `json:"down"`
|
||||
Paused int `json:"paused"`
|
||||
Pending int `json:"pending"`
|
||||
}
|
||||
|
||||
// SystemInfo identifies a system that is currently down.
|
||||
type SystemInfo struct {
|
||||
ID string `json:"id" db:"id"`
|
||||
Name string `json:"name" db:"name"`
|
||||
Host string `json:"host" db:"host"`
|
||||
}
|
||||
|
||||
// AlertInfo describes a currently triggered alert.
|
||||
type AlertInfo struct {
|
||||
SystemID string `json:"system_id"`
|
||||
SystemName string `json:"system_name"`
|
||||
AlertName string `json:"alert_name"`
|
||||
Threshold float64 `json:"threshold"`
|
||||
}
|
||||
|
||||
// Config holds heartbeat settings read from environment variables.
|
||||
type Config struct {
|
||||
URL string // endpoint to ping
|
||||
Interval int // seconds between pings
|
||||
Method string // HTTP method (GET or POST, default POST)
|
||||
}
|
||||
|
||||
// Heartbeat manages the periodic outbound health check.
|
||||
type Heartbeat struct {
|
||||
app core.App
|
||||
config Config
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
// New creates a Heartbeat if configuration is present.
|
||||
// Returns nil if HEARTBEAT_URL is not set (feature disabled).
|
||||
func New(app core.App, getEnv func(string) (string, bool)) *Heartbeat {
|
||||
url, _ := getEnv("HEARTBEAT_URL")
|
||||
url = strings.TrimSpace(url)
|
||||
if app == nil || url == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
interval := defaultInterval
|
||||
if v, ok := getEnv("HEARTBEAT_INTERVAL"); ok {
|
||||
if parsed, err := strconv.Atoi(v); err == nil && parsed > 0 {
|
||||
interval = parsed
|
||||
}
|
||||
}
|
||||
|
||||
method := http.MethodPost
|
||||
if v, ok := getEnv("HEARTBEAT_METHOD"); ok {
|
||||
v = strings.ToUpper(strings.TrimSpace(v))
|
||||
if v == http.MethodGet || v == http.MethodHead {
|
||||
method = v
|
||||
}
|
||||
}
|
||||
|
||||
return &Heartbeat{
|
||||
app: app,
|
||||
config: Config{
|
||||
URL: url,
|
||||
Interval: interval,
|
||||
Method: method,
|
||||
},
|
||||
client: &http.Client{Timeout: httpTimeout},
|
||||
}
|
||||
}
|
||||
|
||||
// Start begins the heartbeat loop. It blocks and should be called in a goroutine.
|
||||
// The loop runs until the provided stop channel is closed.
|
||||
func (hb *Heartbeat) Start(stop <-chan struct{}) {
|
||||
sanitizedURL := sanitizeHeartbeatURL(hb.config.URL)
|
||||
hb.app.Logger().Info("Heartbeat enabled",
|
||||
"url", sanitizedURL,
|
||||
"interval", fmt.Sprintf("%ds", hb.config.Interval),
|
||||
"method", hb.config.Method,
|
||||
)
|
||||
|
||||
// Send an initial heartbeat immediately on startup.
|
||||
hb.send()
|
||||
|
||||
ticker := time.NewTicker(time.Duration(hb.config.Interval) * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
case <-ticker.C:
|
||||
hb.send()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Send performs a single heartbeat ping. Exposed for the test-heartbeat API endpoint.
|
||||
func (hb *Heartbeat) Send() error {
|
||||
return hb.send()
|
||||
}
|
||||
|
||||
// GetConfig returns the current heartbeat configuration.
|
||||
func (hb *Heartbeat) GetConfig() Config {
|
||||
return hb.config
|
||||
}
|
||||
|
||||
func (hb *Heartbeat) send() error {
|
||||
var req *http.Request
|
||||
var err error
|
||||
method := normalizeMethod(hb.config.Method)
|
||||
|
||||
if method == http.MethodGet || method == http.MethodHead {
|
||||
req, err = http.NewRequest(method, hb.config.URL, nil)
|
||||
} else {
|
||||
payload, payloadErr := hb.buildPayload()
|
||||
if payloadErr != nil {
|
||||
hb.app.Logger().Error("Heartbeat: failed to build payload", "err", payloadErr)
|
||||
return payloadErr
|
||||
}
|
||||
|
||||
body, jsonErr := json.Marshal(payload)
|
||||
if jsonErr != nil {
|
||||
hb.app.Logger().Error("Heartbeat: failed to marshal payload", "err", jsonErr)
|
||||
return jsonErr
|
||||
}
|
||||
req, err = http.NewRequest(http.MethodPost, hb.config.URL, bytes.NewReader(body))
|
||||
if err == nil {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
hb.app.Logger().Error("Heartbeat: failed to create request", "err", err)
|
||||
return err
|
||||
}
|
||||
|
||||
req.Header.Set("User-Agent", "Beszel-Heartbeat")
|
||||
|
||||
resp, err := hb.client.Do(req)
|
||||
if err != nil {
|
||||
hb.app.Logger().Error("Heartbeat: request failed", "url", sanitizeHeartbeatURL(hb.config.URL), "err", err)
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
hb.app.Logger().Warn("Heartbeat: non-success response",
|
||||
"url", sanitizeHeartbeatURL(hb.config.URL),
|
||||
"status", resp.StatusCode,
|
||||
)
|
||||
return fmt.Errorf("heartbeat endpoint returned status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hb *Heartbeat) buildPayload() (*Payload, error) {
|
||||
db := hb.app.DB()
|
||||
|
||||
// Count systems by status.
|
||||
var systemCounts []struct {
|
||||
Status string `db:"status"`
|
||||
Count int `db:"cnt"`
|
||||
}
|
||||
err := db.NewQuery("SELECT status, COUNT(*) as cnt FROM systems GROUP BY status").All(&systemCounts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query system counts: %w", err)
|
||||
}
|
||||
|
||||
summary := SystemsSummary{}
|
||||
for _, sc := range systemCounts {
|
||||
switch sc.Status {
|
||||
case "up":
|
||||
summary.Up = sc.Count
|
||||
case "down":
|
||||
summary.Down = sc.Count
|
||||
case "paused":
|
||||
summary.Paused = sc.Count
|
||||
case "pending":
|
||||
summary.Pending = sc.Count
|
||||
}
|
||||
summary.Total += sc.Count
|
||||
}
|
||||
|
||||
// Get names of down systems.
|
||||
var downSystems []SystemInfo
|
||||
if summary.Down > 0 {
|
||||
err = db.NewQuery("SELECT id, name, host FROM systems WHERE status = 'down'").All(&downSystems)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query down systems: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Get triggered alerts with system names.
|
||||
var triggeredAlerts []struct {
|
||||
SystemID string `db:"system"`
|
||||
SystemName string `db:"system_name"`
|
||||
AlertName string `db:"name"`
|
||||
Value float64 `db:"value"`
|
||||
}
|
||||
err = db.NewQuery(`
|
||||
SELECT a.system, s.name as system_name, a.name, a.value
|
||||
FROM alerts a
|
||||
JOIN systems s ON a.system = s.id
|
||||
WHERE a.triggered = true
|
||||
`).All(&triggeredAlerts)
|
||||
if err != nil {
|
||||
// Non-fatal: alerts info is supplementary.
|
||||
triggeredAlerts = nil
|
||||
}
|
||||
|
||||
alerts := make([]AlertInfo, 0, len(triggeredAlerts))
|
||||
for _, ta := range triggeredAlerts {
|
||||
alerts = append(alerts, AlertInfo{
|
||||
SystemID: ta.SystemID,
|
||||
SystemName: ta.SystemName,
|
||||
AlertName: ta.AlertName,
|
||||
Threshold: ta.Value,
|
||||
})
|
||||
}
|
||||
|
||||
// Determine overall status.
|
||||
status := "ok"
|
||||
msg := "All systems operational"
|
||||
if summary.Down > 0 {
|
||||
status = "error"
|
||||
names := make([]string, len(downSystems))
|
||||
for i, ds := range downSystems {
|
||||
names[i] = ds.Name
|
||||
}
|
||||
msg = fmt.Sprintf("%d system(s) down: %s", summary.Down, strings.Join(names, ", "))
|
||||
} else if len(alerts) > 0 {
|
||||
status = "warn"
|
||||
msg = fmt.Sprintf("%d alert(s) triggered", len(alerts))
|
||||
}
|
||||
|
||||
return &Payload{
|
||||
Status: status,
|
||||
Timestamp: time.Now().UTC().Format(time.RFC3339),
|
||||
Msg: msg,
|
||||
Systems: summary,
|
||||
Down: downSystems,
|
||||
Alerts: alerts,
|
||||
Version: beszel.Version,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func normalizeMethod(method string) string {
|
||||
upper := strings.ToUpper(strings.TrimSpace(method))
|
||||
if upper == http.MethodGet || upper == http.MethodHead || upper == http.MethodPost {
|
||||
return upper
|
||||
}
|
||||
return http.MethodPost
|
||||
}
|
||||
|
||||
func sanitizeHeartbeatURL(rawURL string) string {
|
||||
parsed, err := url.Parse(strings.TrimSpace(rawURL))
|
||||
if err != nil || parsed.Scheme == "" || parsed.Host == "" {
|
||||
return "<invalid-url>"
|
||||
}
|
||||
return parsed.Scheme + "://" + parsed.Host
|
||||
}
|
||||
@@ -0,0 +1,257 @@
|
||||
//go:build testing
|
||||
|
||||
package heartbeat_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/henrygd/beszel/internal/hub/heartbeat"
|
||||
beszeltests "github.com/henrygd/beszel/internal/tests"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
t.Run("returns nil when app is missing", func(t *testing.T) {
|
||||
hb := heartbeat.New(nil, envGetter(map[string]string{
|
||||
"HEARTBEAT_URL": "https://heartbeat.example.com/ping",
|
||||
}))
|
||||
assert.Nil(t, hb)
|
||||
})
|
||||
|
||||
t.Run("returns nil when URL is missing", func(t *testing.T) {
|
||||
app := newTestHub(t)
|
||||
hb := heartbeat.New(app.App, func(string) (string, bool) {
|
||||
return "", false
|
||||
})
|
||||
assert.Nil(t, hb)
|
||||
})
|
||||
|
||||
t.Run("parses and normalizes config values", func(t *testing.T) {
|
||||
app := newTestHub(t)
|
||||
env := map[string]string{
|
||||
"HEARTBEAT_URL": " https://heartbeat.example.com/ping ",
|
||||
"HEARTBEAT_INTERVAL": "90",
|
||||
"HEARTBEAT_METHOD": "head",
|
||||
}
|
||||
getEnv := func(key string) (string, bool) {
|
||||
v, ok := env[key]
|
||||
return v, ok
|
||||
}
|
||||
|
||||
hb := heartbeat.New(app.App, getEnv)
|
||||
require.NotNil(t, hb)
|
||||
cfg := hb.GetConfig()
|
||||
assert.Equal(t, "https://heartbeat.example.com/ping", cfg.URL)
|
||||
assert.Equal(t, 90, cfg.Interval)
|
||||
assert.Equal(t, http.MethodHead, cfg.Method)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSendGETDoesNotRequireAppOrDB(t *testing.T) {
|
||||
app := newTestHub(t)
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, http.MethodGet, r.Method)
|
||||
assert.Equal(t, "Beszel-Heartbeat", r.Header.Get("User-Agent"))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
hb := heartbeat.New(app.App, envGetter(map[string]string{
|
||||
"HEARTBEAT_URL": server.URL,
|
||||
"HEARTBEAT_METHOD": "GET",
|
||||
}))
|
||||
require.NotNil(t, hb)
|
||||
|
||||
require.NoError(t, hb.Send())
|
||||
}
|
||||
|
||||
func TestSendReturnsErrorOnHTTPFailureStatus(t *testing.T) {
|
||||
app := newTestHub(t)
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
hb := heartbeat.New(app.App, envGetter(map[string]string{
|
||||
"HEARTBEAT_URL": server.URL,
|
||||
"HEARTBEAT_METHOD": "GET",
|
||||
}))
|
||||
require.NotNil(t, hb)
|
||||
|
||||
err := hb.Send()
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "heartbeat endpoint returned status 500")
|
||||
}
|
||||
|
||||
func TestSendPOSTBuildsExpectedStatuses(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func(t *testing.T, app *beszeltests.TestHub, user *core.Record)
|
||||
expectStatus string
|
||||
expectMsgPart string
|
||||
expectDown int
|
||||
expectAlerts int
|
||||
expectTotal int
|
||||
expectUp int
|
||||
expectPaused int
|
||||
expectPending int
|
||||
expectDownSumm int
|
||||
}{
|
||||
{
|
||||
name: "error when at least one system is down",
|
||||
setup: func(t *testing.T, app *beszeltests.TestHub, user *core.Record) {
|
||||
downSystem := createTestSystem(t, app, user.Id, "db-1", "10.0.0.1", "down")
|
||||
_ = createTestSystem(t, app, user.Id, "web-1", "10.0.0.2", "up")
|
||||
createTriggeredAlert(t, app, user.Id, downSystem.Id, "CPU", 95)
|
||||
},
|
||||
expectStatus: "error",
|
||||
expectMsgPart: "1 system(s) down",
|
||||
expectDown: 1,
|
||||
expectAlerts: 1,
|
||||
expectTotal: 2,
|
||||
expectUp: 1,
|
||||
expectDownSumm: 1,
|
||||
},
|
||||
{
|
||||
name: "warn when only alerts are triggered",
|
||||
setup: func(t *testing.T, app *beszeltests.TestHub, user *core.Record) {
|
||||
system := createTestSystem(t, app, user.Id, "api-1", "10.1.0.1", "up")
|
||||
createTriggeredAlert(t, app, user.Id, system.Id, "CPU", 90)
|
||||
},
|
||||
expectStatus: "warn",
|
||||
expectMsgPart: "1 alert(s) triggered",
|
||||
expectDown: 0,
|
||||
expectAlerts: 1,
|
||||
expectTotal: 1,
|
||||
expectUp: 1,
|
||||
expectDownSumm: 0,
|
||||
},
|
||||
{
|
||||
name: "ok when no down systems and no alerts",
|
||||
setup: func(t *testing.T, app *beszeltests.TestHub, user *core.Record) {
|
||||
_ = createTestSystem(t, app, user.Id, "node-1", "10.2.0.1", "up")
|
||||
_ = createTestSystem(t, app, user.Id, "node-2", "10.2.0.2", "paused")
|
||||
_ = createTestSystem(t, app, user.Id, "node-3", "10.2.0.3", "pending")
|
||||
},
|
||||
expectStatus: "ok",
|
||||
expectMsgPart: "All systems operational",
|
||||
expectDown: 0,
|
||||
expectAlerts: 0,
|
||||
expectTotal: 3,
|
||||
expectUp: 1,
|
||||
expectPaused: 1,
|
||||
expectPending: 1,
|
||||
expectDownSumm: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
app := newTestHub(t)
|
||||
user := createTestUser(t, app)
|
||||
tt.setup(t, app, user)
|
||||
|
||||
type requestCapture struct {
|
||||
method string
|
||||
userAgent string
|
||||
contentType string
|
||||
payload heartbeat.Payload
|
||||
}
|
||||
|
||||
captured := make(chan requestCapture, 1)
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
defer r.Body.Close()
|
||||
body, err := io.ReadAll(r.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
var payload heartbeat.Payload
|
||||
require.NoError(t, json.Unmarshal(body, &payload))
|
||||
captured <- requestCapture{
|
||||
method: r.Method,
|
||||
userAgent: r.Header.Get("User-Agent"),
|
||||
contentType: r.Header.Get("Content-Type"),
|
||||
payload: payload,
|
||||
}
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
hb := heartbeat.New(app.App, envGetter(map[string]string{
|
||||
"HEARTBEAT_URL": server.URL,
|
||||
"HEARTBEAT_METHOD": "POST",
|
||||
}))
|
||||
require.NotNil(t, hb)
|
||||
require.NoError(t, hb.Send())
|
||||
|
||||
req := <-captured
|
||||
assert.Equal(t, http.MethodPost, req.method)
|
||||
assert.Equal(t, "Beszel-Heartbeat", req.userAgent)
|
||||
assert.Equal(t, "application/json", req.contentType)
|
||||
|
||||
assert.Equal(t, tt.expectStatus, req.payload.Status)
|
||||
assert.Contains(t, req.payload.Msg, tt.expectMsgPart)
|
||||
assert.Equal(t, tt.expectDown, len(req.payload.Down))
|
||||
assert.Equal(t, tt.expectAlerts, len(req.payload.Alerts))
|
||||
assert.Equal(t, tt.expectTotal, req.payload.Systems.Total)
|
||||
assert.Equal(t, tt.expectUp, req.payload.Systems.Up)
|
||||
assert.Equal(t, tt.expectDownSumm, req.payload.Systems.Down)
|
||||
assert.Equal(t, tt.expectPaused, req.payload.Systems.Paused)
|
||||
assert.Equal(t, tt.expectPending, req.payload.Systems.Pending)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func newTestHub(t *testing.T) *beszeltests.TestHub {
|
||||
t.Helper()
|
||||
app, err := beszeltests.NewTestHub(t.TempDir())
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(app.Cleanup)
|
||||
return app
|
||||
}
|
||||
|
||||
func createTestUser(t *testing.T, app *beszeltests.TestHub) *core.Record {
|
||||
t.Helper()
|
||||
user, err := beszeltests.CreateUser(app.App, "admin@example.com", "password123")
|
||||
require.NoError(t, err)
|
||||
return user
|
||||
}
|
||||
|
||||
func createTestSystem(t *testing.T, app *beszeltests.TestHub, userID, name, host, status string) *core.Record {
|
||||
t.Helper()
|
||||
system, err := beszeltests.CreateRecord(app.App, "systems", map[string]any{
|
||||
"name": name,
|
||||
"host": host,
|
||||
"port": "45876",
|
||||
"users": []string{userID},
|
||||
"status": status,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return system
|
||||
}
|
||||
|
||||
func createTriggeredAlert(t *testing.T, app *beszeltests.TestHub, userID, systemID, name string, threshold float64) *core.Record {
|
||||
t.Helper()
|
||||
alert, err := beszeltests.CreateRecord(app.App, "alerts", map[string]any{
|
||||
"name": name,
|
||||
"system": systemID,
|
||||
"user": userID,
|
||||
"value": threshold,
|
||||
"min": 0,
|
||||
"triggered": true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return alert
|
||||
}
|
||||
|
||||
func envGetter(values map[string]string) func(string) (string, bool) {
|
||||
return func(key string) (string, bool) {
|
||||
v, ok := values[key]
|
||||
return v, ok
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,289 @@
|
||||
// Package hub handles updating systems and serving the web UI.
|
||||
package hub
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"path"
|
||||
"strings"
|
||||
|
||||
"github.com/henrygd/beszel/internal/alerts"
|
||||
"github.com/henrygd/beszel/internal/hub/config"
|
||||
"github.com/henrygd/beszel/internal/hub/domains"
|
||||
"github.com/henrygd/beszel/internal/hub/export"
|
||||
"github.com/henrygd/beszel/internal/hub/heartbeat"
|
||||
"github.com/henrygd/beszel/internal/hub/monitors"
|
||||
"github.com/henrygd/beszel/internal/hub/systems"
|
||||
"github.com/henrygd/beszel/internal/hub/utils"
|
||||
"github.com/henrygd/beszel/internal/records"
|
||||
"github.com/henrygd/beszel/internal/users"
|
||||
|
||||
"github.com/pocketbase/pocketbase"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// Hub is the application. It embeds the PocketBase app and keeps references to subcomponents.
|
||||
type Hub struct {
|
||||
core.App
|
||||
*alerts.AlertManager
|
||||
um *users.UserManager
|
||||
rm *records.RecordManager
|
||||
sm *systems.SystemManager
|
||||
monSched *monitors.Scheduler
|
||||
monAPI *monitors.APIHandler
|
||||
domainSched *domains.Scheduler
|
||||
domainAPI *domains.APIHandler
|
||||
exportAPI *export.APIHandler
|
||||
hb *heartbeat.Heartbeat
|
||||
hbStop chan struct{}
|
||||
pubKey string
|
||||
signer ssh.Signer
|
||||
appURL string
|
||||
}
|
||||
|
||||
// NewHub creates a new Hub instance with default configuration
|
||||
func NewHub(app core.App) *Hub {
|
||||
hub := &Hub{App: app}
|
||||
hub.AlertManager = alerts.NewAlertManager(hub)
|
||||
hub.um = users.NewUserManager(hub)
|
||||
hub.rm = records.NewRecordManager(hub)
|
||||
hub.sm = systems.NewSystemManager(hub)
|
||||
hub.monSched = monitors.NewScheduler(app)
|
||||
hub.monAPI = monitors.NewAPIHandler(app, hub.monSched)
|
||||
hub.domainSched = domains.NewScheduler(app)
|
||||
hub.domainAPI = domains.NewAPIHandler(app, hub.domainSched)
|
||||
hub.exportAPI = export.NewAPIHandler(app)
|
||||
hub.hb = heartbeat.New(app, utils.GetEnv)
|
||||
if hub.hb != nil {
|
||||
hub.hbStop = make(chan struct{})
|
||||
}
|
||||
_ = onAfterBootstrapAndMigrations(app, hub.initialize)
|
||||
return hub
|
||||
}
|
||||
|
||||
// onAfterBootstrapAndMigrations ensures the provided function runs after the database is set up and migrations are applied.
|
||||
// This is a workaround for behavior in PocketBase where onBootstrap runs before migrations, forcing use of onServe for this purpose.
|
||||
// However, PB's tests.TestApp is already bootstrapped, generally doesn't serve, but does handle migrations.
|
||||
// So this ensures that the provided function runs at the right time either way, after DB is ready and migrations are done.
|
||||
func onAfterBootstrapAndMigrations(app core.App, fn func(app core.App) error) error {
|
||||
// pb tests.TestApp is already bootstrapped and doesn't serve
|
||||
if app.IsBootstrapped() {
|
||||
return fn(app)
|
||||
}
|
||||
// Must use OnServe because OnBootstrap appears to run before migrations, even if calling e.Next() before anything else
|
||||
app.OnServe().BindFunc(func(e *core.ServeEvent) error {
|
||||
if err := fn(e.App); err != nil {
|
||||
return err
|
||||
}
|
||||
return e.Next()
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// StartHub sets up event handlers and starts the PocketBase server
|
||||
func (h *Hub) StartHub() error {
|
||||
h.App.OnServe().BindFunc(func(e *core.ServeEvent) error {
|
||||
// sync systems with config
|
||||
if err := config.SyncSystems(e); err != nil {
|
||||
return err
|
||||
}
|
||||
// register middlewares
|
||||
h.registerMiddlewares(e)
|
||||
// register api routes
|
||||
if err := h.registerApiRoutes(e); err != nil {
|
||||
return err
|
||||
}
|
||||
// register cron jobs
|
||||
if err := h.registerCronJobs(e); err != nil {
|
||||
return err
|
||||
}
|
||||
// start server
|
||||
if err := h.startServer(e); err != nil {
|
||||
return err
|
||||
}
|
||||
// start system updates
|
||||
if err := h.sm.Initialize(); err != nil {
|
||||
return err
|
||||
}
|
||||
// start heartbeat if configured
|
||||
if h.hb != nil {
|
||||
go h.hb.Start(h.hbStop)
|
||||
}
|
||||
// start monitor scheduler
|
||||
if err := h.monSched.Start(); err != nil {
|
||||
return err
|
||||
}
|
||||
// start domain scheduler
|
||||
h.domainSched.Start()
|
||||
// register monitor API routes
|
||||
h.monAPI.RegisterRoutes(e)
|
||||
// register domain API routes
|
||||
h.domainAPI.RegisterRoutes(e)
|
||||
// register export API routes
|
||||
h.exportAPI.RegisterRoutes(e)
|
||||
// bind monitor lifecycle hooks
|
||||
h.bindMonitorHooks()
|
||||
// bind domain lifecycle hooks
|
||||
h.bindDomainHooks()
|
||||
return e.Next()
|
||||
})
|
||||
|
||||
// TODO: move to users package
|
||||
// handle default values for user / user_settings creation
|
||||
h.App.OnRecordCreate("users").BindFunc(h.um.InitializeUserRole)
|
||||
h.App.OnRecordCreate("user_settings").BindFunc(h.um.InitializeUserSettings)
|
||||
|
||||
pb, ok := h.App.(*pocketbase.PocketBase)
|
||||
if !ok {
|
||||
return errors.New("not a pocketbase app")
|
||||
}
|
||||
return pb.Start()
|
||||
}
|
||||
|
||||
// initialize sets up initial configuration (collections, settings, etc.)
|
||||
func (h *Hub) initialize(app core.App) error {
|
||||
// set general settings
|
||||
settings := app.Settings()
|
||||
// batch requests (for alerts)
|
||||
settings.Batch.Enabled = true
|
||||
// set URL if APP_URL env is set
|
||||
if appURL, isSet := utils.GetEnv("APP_URL"); isSet {
|
||||
h.appURL = appURL
|
||||
settings.Meta.AppURL = appURL
|
||||
}
|
||||
if err := app.Save(settings); err != nil {
|
||||
return err
|
||||
}
|
||||
// set auth settings
|
||||
return setCollectionAuthSettings(app)
|
||||
}
|
||||
|
||||
// registerCronJobs sets up scheduled tasks
|
||||
func (h *Hub) registerCronJobs(_ *core.ServeEvent) error {
|
||||
// delete old system_stats and alerts_history records once every hour
|
||||
h.Cron().MustAdd("delete old records", "8 * * * *", h.rm.DeleteOldRecords)
|
||||
// create longer records every 10 minutes
|
||||
h.Cron().MustAdd("create longer records", "*/10 * * * *", h.rm.CreateLongerRecords)
|
||||
// cleanup old monitor heartbeats once a day (keep 30 days)
|
||||
h.Cron().MustAdd("cleanup old heartbeats", "0 0 * * *", func() {
|
||||
h.monSched.CleanupOldHeartbeats(30)
|
||||
})
|
||||
// check domain expiry daily at 1 AM
|
||||
h.Cron().MustAdd("check domains", "0 1 * * *", func() {
|
||||
h.domainSched.CheckAllDomains()
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// bindMonitorHooks binds event hooks for monitor lifecycle management
|
||||
func (h *Hub) bindMonitorHooks() {
|
||||
// On create - add to scheduler
|
||||
h.OnRecordCreate("monitors").BindFunc(func(e *core.RecordEvent) error {
|
||||
// Only add to scheduler if active
|
||||
if e.Record.GetBool("active") {
|
||||
h.monSched.AddMonitor(e.Record)
|
||||
}
|
||||
return e.Next()
|
||||
})
|
||||
|
||||
// On update - update scheduler
|
||||
h.OnRecordAfterUpdateSuccess("monitors").BindFunc(func(e *core.RecordEvent) error {
|
||||
h.monSched.UpdateMonitor(e.Record)
|
||||
return e.Next()
|
||||
})
|
||||
|
||||
// On delete - remove from scheduler
|
||||
h.OnRecordAfterDeleteSuccess("monitors").BindFunc(func(e *core.RecordEvent) error {
|
||||
h.monSched.RemoveMonitor(e.Record.Id)
|
||||
return e.Next()
|
||||
})
|
||||
}
|
||||
|
||||
// bindDomainHooks binds event hooks for domain lifecycle management
|
||||
func (h *Hub) bindDomainHooks() {
|
||||
// On create - perform initial lookup if active
|
||||
h.OnRecordAfterCreateSuccess("domains").BindFunc(func(e *core.RecordEvent) error {
|
||||
if e.Record.GetBool("active") {
|
||||
h.domainSched.RefreshDomain(e.Record.Id)
|
||||
}
|
||||
return e.Next()
|
||||
})
|
||||
|
||||
// On update - refresh if activated
|
||||
h.OnRecordAfterUpdateSuccess("domains").BindFunc(func(e *core.RecordEvent) error {
|
||||
if e.Record.GetBool("active") {
|
||||
h.domainSched.RefreshDomain(e.Record.Id)
|
||||
}
|
||||
return e.Next()
|
||||
})
|
||||
}
|
||||
|
||||
// GetSSHKey generates key pair if it doesn't exist and returns signer
|
||||
func (h *Hub) GetSSHKey(dataDir string) (ssh.Signer, error) {
|
||||
if h.signer != nil {
|
||||
return h.signer, nil
|
||||
}
|
||||
|
||||
if dataDir == "" {
|
||||
dataDir = h.DataDir()
|
||||
}
|
||||
|
||||
privateKeyPath := path.Join(dataDir, "id_ed25519")
|
||||
|
||||
// check if the key pair already exists
|
||||
existingKey, err := os.ReadFile(privateKeyPath)
|
||||
if err == nil {
|
||||
private, err := ssh.ParsePrivateKey(existingKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse private key: %s", err)
|
||||
}
|
||||
pubKeyBytes := ssh.MarshalAuthorizedKey(private.PublicKey())
|
||||
h.pubKey = strings.TrimSuffix(string(pubKeyBytes), "\n")
|
||||
return private, nil
|
||||
} else if !os.IsNotExist(err) {
|
||||
// File exists but couldn't be read for some other reason
|
||||
return nil, fmt.Errorf("failed to read %s: %w", privateKeyPath, err)
|
||||
}
|
||||
|
||||
// Generate the Ed25519 key pair
|
||||
_, privKey, err := ed25519.GenerateKey(nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
privKeyPem, err := ssh.MarshalPrivateKey(privKey, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := os.WriteFile(privateKeyPath, pem.EncodeToMemory(privKeyPem), 0600); err != nil {
|
||||
return nil, fmt.Errorf("failed to write private key to %q: err: %w", privateKeyPath, err)
|
||||
}
|
||||
|
||||
// These are fine to ignore the errors on, as we've literally just created a crypto.PublicKey | crypto.Signer
|
||||
sshPrivate, _ := ssh.NewSignerFromSigner(privKey)
|
||||
pubKeyBytes := ssh.MarshalAuthorizedKey(sshPrivate.PublicKey())
|
||||
h.pubKey = strings.TrimSuffix(string(pubKeyBytes), "\n")
|
||||
|
||||
h.Logger().Info("ed25519 key pair generated successfully.")
|
||||
h.Logger().Info("Saved to: " + privateKeyPath)
|
||||
|
||||
return sshPrivate, err
|
||||
}
|
||||
|
||||
// MakeLink formats a link with the app URL and path segments.
|
||||
// Only path segments should be provided.
|
||||
func (h *Hub) MakeLink(parts ...string) string {
|
||||
base := strings.TrimSuffix(h.Settings().Meta.AppURL, "/")
|
||||
for _, part := range parts {
|
||||
if part == "" {
|
||||
continue
|
||||
}
|
||||
base = fmt.Sprintf("%s/%s", base, url.PathEscape(part))
|
||||
}
|
||||
return base
|
||||
}
|
||||
@@ -0,0 +1,268 @@
|
||||
//go:build testing
|
||||
|
||||
package hub_test
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"encoding/pem"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
beszelTests "github.com/henrygd/beszel/internal/tests"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
func TestMakeLink(t *testing.T) {
|
||||
hub, _ := beszelTests.NewTestHub(t.TempDir())
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
appURL string
|
||||
parts []string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "no parts, no trailing slash in AppURL",
|
||||
appURL: "http://localhost:8090",
|
||||
parts: []string{},
|
||||
expected: "http://localhost:8090",
|
||||
},
|
||||
{
|
||||
name: "no parts, with trailing slash in AppURL",
|
||||
appURL: "http://localhost:8090/",
|
||||
parts: []string{},
|
||||
expected: "http://localhost:8090", // TrimSuffix should handle the trailing slash
|
||||
},
|
||||
{
|
||||
name: "one part",
|
||||
appURL: "http://example.com",
|
||||
parts: []string{"one"},
|
||||
expected: "http://example.com/one",
|
||||
},
|
||||
{
|
||||
name: "multiple parts",
|
||||
appURL: "http://example.com",
|
||||
parts: []string{"alpha", "beta", "gamma"},
|
||||
expected: "http://example.com/alpha/beta/gamma",
|
||||
},
|
||||
{
|
||||
name: "parts with spaces needing escaping",
|
||||
appURL: "http://example.com",
|
||||
parts: []string{"path with spaces", "another part"},
|
||||
expected: "http://example.com/path%20with%20spaces/another%20part",
|
||||
},
|
||||
{
|
||||
name: "parts with slashes needing escaping",
|
||||
appURL: "http://example.com",
|
||||
parts: []string{"a/b", "c"},
|
||||
expected: "http://example.com/a%2Fb/c", // url.PathEscape escapes '/'
|
||||
},
|
||||
{
|
||||
name: "AppURL with subpath, no trailing slash",
|
||||
appURL: "http://localhost/sub",
|
||||
parts: []string{"resource"},
|
||||
expected: "http://localhost/sub/resource",
|
||||
},
|
||||
{
|
||||
name: "AppURL with subpath, with trailing slash",
|
||||
appURL: "http://localhost/sub/",
|
||||
parts: []string{"item"},
|
||||
expected: "http://localhost/sub/item",
|
||||
},
|
||||
{
|
||||
name: "empty parts in the middle",
|
||||
appURL: "http://localhost",
|
||||
parts: []string{"first", "", "third"},
|
||||
expected: "http://localhost/first/third",
|
||||
},
|
||||
{
|
||||
name: "leading and trailing empty parts",
|
||||
appURL: "http://localhost",
|
||||
parts: []string{"", "path", ""},
|
||||
expected: "http://localhost/path",
|
||||
},
|
||||
{
|
||||
name: "parts with various special characters",
|
||||
appURL: "https://test.dev/",
|
||||
parts: []string{"p@th?", "key=value&"},
|
||||
expected: "https://test.dev/p@th%3F/key=value&",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Store original app URL and restore it after the test
|
||||
originalAppURL := hub.Settings().Meta.AppURL
|
||||
hub.Settings().Meta.AppURL = tt.appURL
|
||||
defer func() { hub.Settings().Meta.AppURL = originalAppURL }()
|
||||
|
||||
got := hub.MakeLink(tt.parts...)
|
||||
assert.Equal(t, tt.expected, got, "MakeLink generated URL does not match expected")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetSSHKey(t *testing.T) {
|
||||
hub, _ := beszelTests.NewTestHub(t.TempDir())
|
||||
|
||||
// Test Case 1: Key generation (no existing key)
|
||||
t.Run("KeyGeneration", func(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Ensure pubKey is initially empty or different to ensure GetSSHKey sets it
|
||||
hub.SetPubkey("")
|
||||
|
||||
signer, err := hub.GetSSHKey(tempDir)
|
||||
assert.NoError(t, err, "GetSSHKey should not error when generating a new key")
|
||||
assert.NotNil(t, signer, "GetSSHKey should return a non-nil signer")
|
||||
|
||||
// Check if private key file was created
|
||||
privateKeyPath := filepath.Join(tempDir, "id_ed25519")
|
||||
info, err := os.Stat(privateKeyPath)
|
||||
assert.NoError(t, err, "Private key file should be created")
|
||||
assert.False(t, info.IsDir(), "Private key path should be a file, not a directory")
|
||||
|
||||
// Check if h.pubKey was set
|
||||
assert.NotEmpty(t, hub.GetPubkey(), "h.pubKey should be set after key generation")
|
||||
assert.True(t, strings.HasPrefix(hub.GetPubkey(), "ssh-ed25519 "), "h.pubKey should start with 'ssh-ed25519 '")
|
||||
|
||||
// Verify the generated private key is parsable
|
||||
keyData, err := os.ReadFile(privateKeyPath)
|
||||
require.NoError(t, err)
|
||||
_, err = ssh.ParsePrivateKey(keyData)
|
||||
assert.NoError(t, err, "Generated private key should be parsable by ssh.ParsePrivateKey")
|
||||
})
|
||||
|
||||
// Test Case 2: Existing key
|
||||
t.Run("ExistingKey", func(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Manually create a valid key pair for the test
|
||||
rawPubKey, rawPrivKey, err := ed25519.GenerateKey(nil)
|
||||
require.NoError(t, err, "Failed to generate raw ed25519 key pair for pre-existing key test")
|
||||
|
||||
// Marshal the private key into OpenSSH PEM format
|
||||
pemBlock, err := ssh.MarshalPrivateKey(rawPrivKey, "")
|
||||
require.NoError(t, err, "Failed to marshal private key to PEM block for pre-existing key test")
|
||||
|
||||
privateKeyBytes := pem.EncodeToMemory(pemBlock)
|
||||
require.NotNil(t, privateKeyBytes, "PEM encoded private key bytes should not be nil")
|
||||
|
||||
privateKeyPath := filepath.Join(tempDir, "id_ed25519")
|
||||
err = os.WriteFile(privateKeyPath, privateKeyBytes, 0600)
|
||||
require.NoError(t, err, "Failed to write pre-existing private key")
|
||||
|
||||
// Determine the expected public key string
|
||||
sshPubKey, err := ssh.NewPublicKey(rawPubKey)
|
||||
require.NoError(t, err)
|
||||
expectedPubKeyStr := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(sshPubKey)))
|
||||
|
||||
// Reset h.pubKey to ensure it's set by GetSSHKey from the file
|
||||
hub.SetPubkey("")
|
||||
|
||||
signer, err := hub.GetSSHKey(tempDir)
|
||||
assert.NoError(t, err, "GetSSHKey should not error when reading an existing key")
|
||||
assert.NotNil(t, signer, "GetSSHKey should return a non-nil signer for an existing key")
|
||||
|
||||
// Check if h.pubKey was set correctly to the public key from the file
|
||||
assert.Equal(t, expectedPubKeyStr, hub.GetPubkey(), "h.pubKey should match the existing public key")
|
||||
|
||||
// Verify the signer's public key matches the original public key
|
||||
signerPubKey := signer.PublicKey()
|
||||
marshaledSignerPubKey := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(signerPubKey)))
|
||||
assert.Equal(t, expectedPubKeyStr, marshaledSignerPubKey, "Signer's public key should match the existing public key")
|
||||
})
|
||||
|
||||
// Test Case 3: Error cases
|
||||
t.Run("ErrorCases", func(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupFunc func(dir string) error
|
||||
errorCheck func(t *testing.T, err error)
|
||||
}{
|
||||
{
|
||||
name: "CorruptedKey",
|
||||
setupFunc: func(dir string) error {
|
||||
return os.WriteFile(filepath.Join(dir, "id_ed25519"), []byte("this is not a valid SSH key"), 0600)
|
||||
},
|
||||
errorCheck: func(t *testing.T, err error) {
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "ssh: no key found")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "PermissionDenied",
|
||||
setupFunc: func(dir string) error {
|
||||
// Create the key file
|
||||
keyPath := filepath.Join(dir, "id_ed25519")
|
||||
if err := os.WriteFile(keyPath, []byte("dummy content"), 0600); err != nil {
|
||||
return err
|
||||
}
|
||||
// Make it read-only (can't be opened for writing in case a new key needs to be written)
|
||||
return os.Chmod(keyPath, 0400)
|
||||
},
|
||||
errorCheck: func(t *testing.T, err error) {
|
||||
// On read-only key, the parser will attempt to parse it and fail with "ssh: no key found"
|
||||
assert.Error(t, err)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "EmptyFile",
|
||||
setupFunc: func(dir string) error {
|
||||
// Create an empty file
|
||||
return os.WriteFile(filepath.Join(dir, "id_ed25519"), []byte{}, 0600)
|
||||
},
|
||||
errorCheck: func(t *testing.T, err error) {
|
||||
assert.Error(t, err)
|
||||
// The error from attempting to parse an empty file
|
||||
assert.Contains(t, err.Error(), "ssh: no key found")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Setup the test case
|
||||
err := tc.setupFunc(tempDir)
|
||||
require.NoError(t, err, "Setup failed")
|
||||
|
||||
// Reset h.pubKey before each test case
|
||||
hub.SetPubkey("")
|
||||
|
||||
// Attempt to get SSH key
|
||||
_, err = hub.GetSSHKey(tempDir)
|
||||
|
||||
// Verify the error
|
||||
tc.errorCheck(t, err)
|
||||
|
||||
// Check that pubKey was not set in error cases
|
||||
assert.Empty(t, hub.GetPubkey(), "h.pubKey should not be set if there was an error")
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestAppUrl(t *testing.T) {
|
||||
t.Run("no APP_URL does't change app url", func(t *testing.T) {
|
||||
hub, _ := beszelTests.NewTestHub(t.TempDir())
|
||||
defer hub.Cleanup()
|
||||
|
||||
settings := hub.Settings()
|
||||
assert.Equal(t, "http://localhost:8090", settings.Meta.AppURL)
|
||||
})
|
||||
t.Run("APP_URL changes app url", func(t *testing.T) {
|
||||
t.Setenv("APP_URL", "http://example.com/app")
|
||||
hub, _ := beszelTests.NewTestHub(t.TempDir())
|
||||
defer hub.Cleanup()
|
||||
|
||||
settings := hub.Settings()
|
||||
assert.Equal(t, "http://example.com/app", settings.Meta.AppURL)
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
//go:build testing
|
||||
|
||||
package hub
|
||||
|
||||
import (
|
||||
"github.com/henrygd/beszel/internal/hub/systems"
|
||||
)
|
||||
|
||||
// TESTING ONLY: GetSystemManager returns the system manager
|
||||
func (h *Hub) GetSystemManager() *systems.SystemManager {
|
||||
return h.sm
|
||||
}
|
||||
|
||||
// TESTING ONLY: GetPubkey returns the public key
|
||||
func (h *Hub) GetPubkey() string {
|
||||
return h.pubKey
|
||||
}
|
||||
|
||||
// TESTING ONLY: SetPubkey sets the public key
|
||||
func (h *Hub) SetPubkey(pubkey string) {
|
||||
h.pubKey = pubkey
|
||||
}
|
||||
|
||||
func (h *Hub) SetCollectionAuthSettings() error {
|
||||
return setCollectionAuthSettings(h)
|
||||
}
|
||||
@@ -0,0 +1,564 @@
|
||||
package incidents
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/henrygd/beszel/internal/entities/incident"
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/apis"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
)
|
||||
|
||||
// APIHandler handles incident API requests
|
||||
type APIHandler struct {
|
||||
app core.App
|
||||
}
|
||||
|
||||
// NewAPIHandler creates a new incidents API handler
|
||||
func NewAPIHandler(app core.App) *APIHandler {
|
||||
return &APIHandler{app: app}
|
||||
}
|
||||
|
||||
// RegisterRoutes registers incident API routes
|
||||
func (h *APIHandler) RegisterRoutes(se *core.ServeEvent) {
|
||||
api := se.Router.Group("/api/beszel/incidents")
|
||||
api.Bind(apis.RequireAuth())
|
||||
|
||||
api.GET("/", h.listIncidents)
|
||||
api.POST("/", h.createIncident)
|
||||
api.GET("/stats", h.getIncidentStats)
|
||||
api.GET("/calendar", h.getCalendarEvents)
|
||||
api.GET("/{id}", h.getIncident)
|
||||
api.PATCH("/{id}", h.updateIncident)
|
||||
api.POST("/{id}/acknowledge", h.acknowledgeIncident)
|
||||
api.POST("/{id}/resolve", h.resolveIncident)
|
||||
api.POST("/{id}/close", h.closeIncident)
|
||||
api.POST("/{id}/updates", h.addIncidentUpdate)
|
||||
api.GET("/{id}/updates", h.getIncidentUpdates)
|
||||
}
|
||||
|
||||
// listIncidents lists all incidents for the authenticated user
|
||||
func (h *APIHandler) listIncidents(e *core.RequestEvent) error {
|
||||
authRecord := e.Auth
|
||||
if authRecord == nil {
|
||||
return e.UnauthorizedError("unauthorized", nil)
|
||||
}
|
||||
|
||||
// Get query params for filtering
|
||||
status := e.Request.URL.Query().Get("status")
|
||||
severity := e.Request.URL.Query().Get("severity")
|
||||
|
||||
query := dbx.NewExp("user = {:user}", dbx.Params{"user": authRecord.Id})
|
||||
|
||||
if status != "" {
|
||||
query = dbx.And(query, dbx.NewExp("status = {:status}", dbx.Params{"status": status}))
|
||||
}
|
||||
if severity != "" {
|
||||
query = dbx.And(query, dbx.NewExp("severity = {:severity}", dbx.Params{"severity": severity}))
|
||||
}
|
||||
|
||||
records, err := h.app.FindAllRecords("incidents", query)
|
||||
if err != nil {
|
||||
return e.InternalServerError("failed to fetch incidents", err)
|
||||
}
|
||||
|
||||
incidents := make([]map[string]interface{}, 0, len(records))
|
||||
for _, record := range records {
|
||||
incidents = append(incidents, h.recordToResponse(record))
|
||||
}
|
||||
|
||||
return e.JSON(http.StatusOK, incidents)
|
||||
}
|
||||
|
||||
// createIncident creates a new incident
|
||||
func (h *APIHandler) createIncident(e *core.RequestEvent) error {
|
||||
authRecord := e.Auth
|
||||
if authRecord == nil {
|
||||
return e.UnauthorizedError("unauthorized", nil)
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Title string `json:"title"`
|
||||
Description string `json:"description"`
|
||||
Type string `json:"type"`
|
||||
Severity string `json:"severity"`
|
||||
MonitorID *string `json:"monitor,omitempty"`
|
||||
DomainID *string `json:"domain,omitempty"`
|
||||
SystemID *string `json:"system,omitempty"`
|
||||
}
|
||||
if err := json.NewDecoder(e.Request.Body).Decode(&req); err != nil {
|
||||
return e.BadRequestError("invalid request body", err)
|
||||
}
|
||||
|
||||
if req.Title == "" || req.Type == "" {
|
||||
return e.BadRequestError("title and type are required", nil)
|
||||
}
|
||||
|
||||
collection, err := h.app.FindCollectionByNameOrId("incidents")
|
||||
if err != nil {
|
||||
return e.InternalServerError("failed to find collection", err)
|
||||
}
|
||||
|
||||
record := core.NewRecord(collection)
|
||||
record.Set("title", req.Title)
|
||||
record.Set("description", req.Description)
|
||||
record.Set("type", req.Type)
|
||||
record.Set("severity", req.Severity)
|
||||
record.Set("status", incident.StatusOpen)
|
||||
record.Set("started_at", time.Now())
|
||||
if req.MonitorID != nil {
|
||||
record.Set("monitor", *req.MonitorID)
|
||||
}
|
||||
if req.DomainID != nil {
|
||||
record.Set("domain", *req.DomainID)
|
||||
}
|
||||
if req.SystemID != nil {
|
||||
record.Set("system", *req.SystemID)
|
||||
}
|
||||
record.Set("user", authRecord.Id)
|
||||
|
||||
if err := h.app.Save(record); err != nil {
|
||||
return e.InternalServerError("failed to create incident", err)
|
||||
}
|
||||
|
||||
return e.JSON(http.StatusCreated, h.recordToResponse(record))
|
||||
}
|
||||
|
||||
// getIncident gets a single incident
|
||||
func (h *APIHandler) getIncident(e *core.RequestEvent) error {
|
||||
authRecord := e.Auth
|
||||
if authRecord == nil {
|
||||
return e.UnauthorizedError("unauthorized", nil)
|
||||
}
|
||||
|
||||
id := e.Request.PathValue("id")
|
||||
record, err := h.app.FindRecordById("incidents", id)
|
||||
if err != nil {
|
||||
return e.NotFoundError("incident not found", err)
|
||||
}
|
||||
|
||||
if record.GetString("user") != authRecord.Id {
|
||||
return e.ForbiddenError("not authorized", nil)
|
||||
}
|
||||
|
||||
return e.JSON(http.StatusOK, h.recordToResponse(record))
|
||||
}
|
||||
|
||||
// updateIncident updates an incident
|
||||
func (h *APIHandler) updateIncident(e *core.RequestEvent) error {
|
||||
authRecord := e.Auth
|
||||
if authRecord == nil {
|
||||
return e.UnauthorizedError("unauthorized", nil)
|
||||
}
|
||||
|
||||
id := e.Request.PathValue("id")
|
||||
record, err := h.app.FindRecordById("incidents", id)
|
||||
if err != nil {
|
||||
return e.NotFoundError("incident not found", err)
|
||||
}
|
||||
|
||||
if record.GetString("user") != authRecord.Id {
|
||||
return e.ForbiddenError("not authorized", nil)
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Title *string `json:"title,omitempty"`
|
||||
Description *string `json:"description,omitempty"`
|
||||
AssignedTo *string `json:"assigned_to,omitempty"`
|
||||
}
|
||||
if err := json.NewDecoder(e.Request.Body).Decode(&req); err != nil {
|
||||
return e.BadRequestError("invalid request body", err)
|
||||
}
|
||||
|
||||
if req.Title != nil {
|
||||
record.Set("title", *req.Title)
|
||||
}
|
||||
if req.Description != nil {
|
||||
record.Set("description", *req.Description)
|
||||
}
|
||||
if req.AssignedTo != nil {
|
||||
record.Set("assigned_to", *req.AssignedTo)
|
||||
}
|
||||
|
||||
if err := h.app.Save(record); err != nil {
|
||||
return e.InternalServerError("failed to update incident", err)
|
||||
}
|
||||
|
||||
return e.JSON(http.StatusOK, h.recordToResponse(record))
|
||||
}
|
||||
|
||||
// acknowledgeIncident acknowledges an incident
|
||||
func (h *APIHandler) acknowledgeIncident(e *core.RequestEvent) error {
|
||||
authRecord := e.Auth
|
||||
if authRecord == nil {
|
||||
return e.UnauthorizedError("unauthorized", nil)
|
||||
}
|
||||
|
||||
id := e.Request.PathValue("id")
|
||||
record, err := h.app.FindRecordById("incidents", id)
|
||||
if err != nil {
|
||||
return e.NotFoundError("incident not found", err)
|
||||
}
|
||||
|
||||
if record.GetString("user") != authRecord.Id {
|
||||
return e.ForbiddenError("not authorized", nil)
|
||||
}
|
||||
|
||||
oldStatus := record.GetString("status")
|
||||
now := time.Now()
|
||||
record.Set("status", incident.StatusAcknowledged)
|
||||
record.Set("acknowledged_at", now)
|
||||
|
||||
if err := h.app.Save(record); err != nil {
|
||||
return e.InternalServerError("failed to acknowledge incident", err)
|
||||
}
|
||||
|
||||
// Add update record
|
||||
h.addUpdate(id, "Incident acknowledged", "status_change", &oldStatus, strPtr(incident.StatusAcknowledged), authRecord.Id)
|
||||
|
||||
return e.JSON(http.StatusOK, h.recordToResponse(record))
|
||||
}
|
||||
|
||||
// resolveIncident resolves an incident
|
||||
func (h *APIHandler) resolveIncident(e *core.RequestEvent) error {
|
||||
authRecord := e.Auth
|
||||
if authRecord == nil {
|
||||
return e.UnauthorizedError("unauthorized", nil)
|
||||
}
|
||||
|
||||
id := e.Request.PathValue("id")
|
||||
record, err := h.app.FindRecordById("incidents", id)
|
||||
if err != nil {
|
||||
return e.NotFoundError("incident not found", err)
|
||||
}
|
||||
|
||||
if record.GetString("user") != authRecord.Id {
|
||||
return e.ForbiddenError("not authorized", nil)
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Resolution string `json:"resolution,omitempty"`
|
||||
RootCause string `json:"root_cause,omitempty"`
|
||||
}
|
||||
if err := json.NewDecoder(e.Request.Body).Decode(&req); err != nil {
|
||||
return e.BadRequestError("invalid request body", err)
|
||||
}
|
||||
|
||||
oldStatus := record.GetString("status")
|
||||
now := time.Now()
|
||||
record.Set("status", incident.StatusResolved)
|
||||
record.Set("resolved_at", now)
|
||||
if req.Resolution != "" {
|
||||
record.Set("resolution", req.Resolution)
|
||||
}
|
||||
if req.RootCause != "" {
|
||||
record.Set("root_cause", req.RootCause)
|
||||
}
|
||||
|
||||
if err := h.app.Save(record); err != nil {
|
||||
return e.InternalServerError("failed to resolve incident", err)
|
||||
}
|
||||
|
||||
// Add update record
|
||||
h.addUpdate(id, "Incident resolved: "+req.Resolution, "status_change", &oldStatus, strPtr(incident.StatusResolved), authRecord.Id)
|
||||
|
||||
return e.JSON(http.StatusOK, h.recordToResponse(record))
|
||||
}
|
||||
|
||||
// closeIncident closes an incident
|
||||
func (h *APIHandler) closeIncident(e *core.RequestEvent) error {
|
||||
authRecord := e.Auth
|
||||
if authRecord == nil {
|
||||
return e.UnauthorizedError("unauthorized", nil)
|
||||
}
|
||||
|
||||
id := e.Request.PathValue("id")
|
||||
record, err := h.app.FindRecordById("incidents", id)
|
||||
if err != nil {
|
||||
return e.NotFoundError("incident not found", err)
|
||||
}
|
||||
|
||||
if record.GetString("user") != authRecord.Id {
|
||||
return e.ForbiddenError("not authorized", nil)
|
||||
}
|
||||
|
||||
oldStatus := record.GetString("status")
|
||||
now := time.Now()
|
||||
record.Set("status", incident.StatusClosed)
|
||||
record.Set("closed_at", now)
|
||||
|
||||
if err := h.app.Save(record); err != nil {
|
||||
return e.InternalServerError("failed to close incident", err)
|
||||
}
|
||||
|
||||
// Add update record
|
||||
h.addUpdate(id, "Incident closed", "status_change", &oldStatus, strPtr(incident.StatusClosed), authRecord.Id)
|
||||
|
||||
return e.JSON(http.StatusOK, h.recordToResponse(record))
|
||||
}
|
||||
|
||||
// addIncidentUpdate adds an update to an incident
|
||||
func (h *APIHandler) addIncidentUpdate(e *core.RequestEvent) error {
|
||||
authRecord := e.Auth
|
||||
if authRecord == nil {
|
||||
return e.UnauthorizedError("unauthorized", nil)
|
||||
}
|
||||
|
||||
id := e.Request.PathValue("id")
|
||||
|
||||
// Verify incident exists and belongs to user
|
||||
incident, err := h.app.FindRecordById("incidents", id)
|
||||
if err != nil {
|
||||
return e.NotFoundError("incident not found", err)
|
||||
}
|
||||
if incident.GetString("user") != authRecord.Id {
|
||||
return e.ForbiddenError("not authorized", nil)
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Message string `json:"message"`
|
||||
}
|
||||
if err := json.NewDecoder(e.Request.Body).Decode(&req); err != nil {
|
||||
return e.BadRequestError("invalid request body", err)
|
||||
}
|
||||
|
||||
if req.Message == "" {
|
||||
return e.BadRequestError("message is required", nil)
|
||||
}
|
||||
|
||||
h.addUpdate(id, req.Message, "note", nil, nil, authRecord.Id)
|
||||
|
||||
return e.JSON(http.StatusCreated, map[string]string{"status": "added"})
|
||||
}
|
||||
|
||||
// getIncidentUpdates gets all updates for an incident
|
||||
func (h *APIHandler) getIncidentUpdates(e *core.RequestEvent) error {
|
||||
authRecord := e.Auth
|
||||
if authRecord == nil {
|
||||
return e.UnauthorizedError("unauthorized", nil)
|
||||
}
|
||||
|
||||
id := e.Request.PathValue("id")
|
||||
|
||||
// Verify incident exists and belongs to user
|
||||
incident, err := h.app.FindRecordById("incidents", id)
|
||||
if err != nil {
|
||||
return e.NotFoundError("incident not found", err)
|
||||
}
|
||||
if incident.GetString("user") != authRecord.Id {
|
||||
return e.ForbiddenError("not authorized", nil)
|
||||
}
|
||||
|
||||
records, err := h.app.FindAllRecords("incident_updates",
|
||||
dbx.NewExp("incident = {:incident}", dbx.Params{"incident": id}),
|
||||
)
|
||||
if err != nil {
|
||||
return e.InternalServerError("failed to fetch updates", err)
|
||||
}
|
||||
|
||||
updates := make([]map[string]interface{}, 0, len(records))
|
||||
for _, record := range records {
|
||||
updates = append(updates, map[string]interface{}{
|
||||
"id": record.Id,
|
||||
"message": record.GetString("message"),
|
||||
"update_type": record.GetString("update_type"),
|
||||
"old_status": record.GetString("old_status"),
|
||||
"new_status": record.GetString("new_status"),
|
||||
"created_by": record.GetString("created_by"),
|
||||
"created_at": record.GetDateTime("created_at").String(),
|
||||
})
|
||||
}
|
||||
|
||||
return e.JSON(http.StatusOK, updates)
|
||||
}
|
||||
|
||||
// getIncidentStats returns incident statistics
|
||||
func (h *APIHandler) getIncidentStats(e *core.RequestEvent) error {
|
||||
authRecord := e.Auth
|
||||
if authRecord == nil {
|
||||
return e.UnauthorizedError("unauthorized", nil)
|
||||
}
|
||||
|
||||
// Count by status
|
||||
total, _ := h.app.CountRecords("incidents", dbx.NewExp("user = {:user}", dbx.Params{"user": authRecord.Id}))
|
||||
open, _ := h.app.CountRecords("incidents", dbx.HashExp{"user": authRecord.Id, "status": incident.StatusOpen})
|
||||
acknowledged, _ := h.app.CountRecords("incidents", dbx.HashExp{"user": authRecord.Id, "status": incident.StatusAcknowledged})
|
||||
resolved, _ := h.app.CountRecords("incidents", dbx.HashExp{"user": authRecord.Id, "status": incident.StatusResolved})
|
||||
|
||||
// Calculate MTTR
|
||||
resolvedRecords, _ := h.app.FindAllRecords("incidents",
|
||||
dbx.And(
|
||||
dbx.NewExp("user = {:user}", dbx.Params{"user": authRecord.Id}),
|
||||
dbx.NewExp("status = {:status}", dbx.Params{"status": incident.StatusResolved}),
|
||||
),
|
||||
)
|
||||
|
||||
var totalResolutionTime time.Duration
|
||||
for _, r := range resolvedRecords {
|
||||
started := r.GetDateTime("started_at").Time()
|
||||
resolved := r.GetDateTime("resolved_at").Time()
|
||||
if !started.IsZero() && !resolved.IsZero() {
|
||||
totalResolutionTime += resolved.Sub(started)
|
||||
}
|
||||
}
|
||||
|
||||
mttr := 0.0
|
||||
if len(resolvedRecords) > 0 {
|
||||
mttr = totalResolutionTime.Hours() / float64(len(resolvedRecords))
|
||||
}
|
||||
|
||||
return e.JSON(http.StatusOK, map[string]interface{}{
|
||||
"total_incidents": total,
|
||||
"open_incidents": open,
|
||||
"acknowledged_incidents": acknowledged,
|
||||
"resolved_incidents": resolved,
|
||||
"mttr_hours": mttr,
|
||||
})
|
||||
}
|
||||
|
||||
// getCalendarEvents returns events for calendar view
|
||||
func (h *APIHandler) getCalendarEvents(e *core.RequestEvent) error {
|
||||
authRecord := e.Auth
|
||||
if authRecord == nil {
|
||||
return e.UnauthorizedError("unauthorized", nil)
|
||||
}
|
||||
|
||||
events := []map[string]interface{}{}
|
||||
|
||||
// Domain expirations
|
||||
domains, _ := h.app.FindAllRecords("domains",
|
||||
dbx.NewExp("user = {:user} && expiry_date != ''", dbx.Params{"user": authRecord.Id}),
|
||||
)
|
||||
for _, d := range domains {
|
||||
expiryDate := d.GetDateTime("expiry_date").Time()
|
||||
domainName := d.GetString("domain_name")
|
||||
daysUntil := int(expiryDate.Sub(time.Now()).Hours() / 24)
|
||||
|
||||
var color string
|
||||
if daysUntil <= 7 {
|
||||
color = "#ef4444" // red
|
||||
} else if daysUntil <= 30 {
|
||||
color = "#f59e0b" // orange
|
||||
} else {
|
||||
color = "#3b82f6" // blue
|
||||
}
|
||||
|
||||
events = append(events, map[string]interface{}{
|
||||
"id": "domain-" + d.Id,
|
||||
"title": "🌐 " + domainName + " expires",
|
||||
"date": expiryDate.Format("2006-01-02"),
|
||||
"type": "domain_expiry",
|
||||
"color": color,
|
||||
})
|
||||
}
|
||||
|
||||
// SSL expirations
|
||||
for _, d := range domains {
|
||||
sslExpiry := d.GetDateTime("ssl_valid_to").Time()
|
||||
if sslExpiry.IsZero() {
|
||||
continue
|
||||
}
|
||||
domainName := d.GetString("domain_name")
|
||||
daysUntil := int(sslExpiry.Sub(time.Now()).Hours() / 24)
|
||||
|
||||
var color string
|
||||
if daysUntil <= 7 {
|
||||
color = "#ef4444"
|
||||
} else if daysUntil <= 14 {
|
||||
color = "#f59e0b"
|
||||
} else {
|
||||
color = "#8b5cf6"
|
||||
}
|
||||
|
||||
events = append(events, map[string]interface{}{
|
||||
"id": "ssl-" + d.Id,
|
||||
"title": "🔒 " + domainName + " SSL expires",
|
||||
"date": sslExpiry.Format("2006-01-02"),
|
||||
"type": "ssl_expiry",
|
||||
"color": color,
|
||||
})
|
||||
}
|
||||
|
||||
// Incidents
|
||||
incidents, _ := h.app.FindAllRecords("incidents",
|
||||
dbx.NewExp("user = {:user}", dbx.Params{"user": authRecord.Id}),
|
||||
)
|
||||
for _, i := range incidents {
|
||||
startedAt := i.GetDateTime("started_at").Time()
|
||||
title := i.GetString("title")
|
||||
severity := i.GetString("severity")
|
||||
|
||||
var color string
|
||||
switch severity {
|
||||
case incident.SeverityCritical:
|
||||
color = "#dc2626"
|
||||
case incident.SeverityHigh:
|
||||
color = "#ea580c"
|
||||
default:
|
||||
color = "#6b7280"
|
||||
}
|
||||
|
||||
events = append(events, map[string]interface{}{
|
||||
"id": "incident-" + i.Id,
|
||||
"title": "⚠️ " + title,
|
||||
"date": startedAt.Format("2006-01-02"),
|
||||
"type": "incident",
|
||||
"color": color,
|
||||
})
|
||||
}
|
||||
|
||||
return e.JSON(http.StatusOK, events)
|
||||
}
|
||||
|
||||
// addUpdate adds an update record
|
||||
func (h *APIHandler) addUpdate(incidentID, message, updateType string, oldStatus, newStatus *string, createdBy string) {
|
||||
collection, err := h.app.FindCollectionByNameOrId("incident_updates")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
record := core.NewRecord(collection)
|
||||
record.Set("incident", incidentID)
|
||||
record.Set("message", message)
|
||||
record.Set("update_type", updateType)
|
||||
if oldStatus != nil {
|
||||
record.Set("old_status", *oldStatus)
|
||||
}
|
||||
if newStatus != nil {
|
||||
record.Set("new_status", *newStatus)
|
||||
}
|
||||
record.Set("created_by", createdBy)
|
||||
record.Set("created_at", time.Now())
|
||||
|
||||
h.app.Save(record)
|
||||
}
|
||||
|
||||
// recordToResponse converts a record to API response
|
||||
func (h *APIHandler) recordToResponse(record *core.Record) map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"id": record.Id,
|
||||
"title": record.GetString("title"),
|
||||
"description": record.GetString("description"),
|
||||
"type": record.GetString("type"),
|
||||
"severity": record.GetString("severity"),
|
||||
"status": record.GetString("status"),
|
||||
"monitor": record.GetString("monitor"),
|
||||
"domain": record.GetString("domain"),
|
||||
"system": record.GetString("system"),
|
||||
"assigned_to": record.GetString("assigned_to"),
|
||||
"started_at": record.GetDateTime("started_at").String(),
|
||||
"acknowledged_at": record.GetDateTime("acknowledged_at").String(),
|
||||
"resolved_at": record.GetDateTime("resolved_at").String(),
|
||||
"closed_at": record.GetDateTime("closed_at").String(),
|
||||
"resolution": record.GetString("resolution"),
|
||||
"root_cause": record.GetString("root_cause"),
|
||||
"created": record.GetDateTime("created").String(),
|
||||
"updated": record.GetDateTime("updated").String(),
|
||||
}
|
||||
}
|
||||
|
||||
func strPtr(s string) *string {
|
||||
return &s
|
||||
}
|
||||
@@ -0,0 +1,605 @@
|
||||
package monitors
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/henrygd/beszel/internal/entities/monitor"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
)
|
||||
|
||||
// APIHandler handles monitor API endpoints
|
||||
type APIHandler struct {
|
||||
app core.App
|
||||
scheduler *Scheduler
|
||||
}
|
||||
|
||||
// NewAPIHandler creates a new monitor API handler
|
||||
func NewAPIHandler(app core.App, scheduler *Scheduler) *APIHandler {
|
||||
return &APIHandler{
|
||||
app: app,
|
||||
scheduler: scheduler,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes registers monitor API routes
|
||||
func (h *APIHandler) RegisterRoutes(se *core.ServeEvent) {
|
||||
api := se.Router.Group("/api/beszel/monitors")
|
||||
|
||||
// Require auth for all routes
|
||||
api.BindFunc(func(e *core.RequestEvent) error {
|
||||
if e.Auth == nil {
|
||||
return e.UnauthorizedError("Authentication required", nil)
|
||||
}
|
||||
return e.Next()
|
||||
})
|
||||
|
||||
// CRUD endpoints
|
||||
api.GET("", h.listMonitors)
|
||||
api.POST("", h.createMonitor)
|
||||
api.GET("/:id", h.getMonitor)
|
||||
api.PATCH("/:id", h.updateMonitor)
|
||||
api.DELETE("/:id", h.deleteMonitor)
|
||||
|
||||
// Action endpoints
|
||||
api.POST("/:id/check", h.manualCheck)
|
||||
api.POST("/:id/pause", h.pauseMonitor)
|
||||
api.POST("/:id/resume", h.resumeMonitor)
|
||||
api.GET("/:id/stats", h.getStats)
|
||||
api.GET("/:id/heartbeats", h.getHeartbeats)
|
||||
}
|
||||
|
||||
// MonitorResponse represents a monitor in API responses
|
||||
type MonitorResponse struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
URL string `json:"url,omitempty"`
|
||||
Hostname string `json:"hostname,omitempty"`
|
||||
Port int `json:"port,omitempty"`
|
||||
Method string `json:"method,omitempty"`
|
||||
Interval int `json:"interval"`
|
||||
Timeout int `json:"timeout"`
|
||||
Retries int `json:"retries"`
|
||||
Status string `json:"status"`
|
||||
Active bool `json:"active"`
|
||||
Description string `json:"description,omitempty"`
|
||||
LastCheck *time.Time `json:"last_check,omitempty"`
|
||||
UptimeStats map[string]float64 `json:"uptime_stats,omitempty"`
|
||||
Tags []string `json:"tags,omitempty"`
|
||||
Keyword string `json:"keyword,omitempty"`
|
||||
JSONQuery string `json:"json_query,omitempty"`
|
||||
ExpectedValue string `json:"expected_value,omitempty"`
|
||||
InvertKeyword bool `json:"invert_keyword"`
|
||||
DNSResolveServer string `json:"dns_resolve_server,omitempty"`
|
||||
DNSResolverMode string `json:"dns_resolver_mode,omitempty"`
|
||||
CertExpiryNotification bool `json:"cert_expiry_notification"`
|
||||
CertExpiryDays int `json:"cert_expiry_days,omitempty"`
|
||||
IgnoreTLSError bool `json:"ignore_tls_error"`
|
||||
Created time.Time `json:"created"`
|
||||
Updated time.Time `json:"updated"`
|
||||
}
|
||||
|
||||
// CreateMonitorRequest represents a request to create a monitor
|
||||
type CreateMonitorRequest struct {
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
URL string `json:"url,omitempty"`
|
||||
Hostname string `json:"hostname,omitempty"`
|
||||
Port int `json:"port,omitempty"`
|
||||
Method string `json:"method,omitempty"`
|
||||
Headers string `json:"headers,omitempty"`
|
||||
Body string `json:"body,omitempty"`
|
||||
Interval int `json:"interval"`
|
||||
Timeout int `json:"timeout"`
|
||||
Retries int `json:"retries,omitempty"`
|
||||
RetryInterval int `json:"retry_interval,omitempty"`
|
||||
MaxRedirects int `json:"max_redirects,omitempty"`
|
||||
Keyword string `json:"keyword,omitempty"`
|
||||
JSONQuery string `json:"json_query,omitempty"`
|
||||
ExpectedValue string `json:"expected_value,omitempty"`
|
||||
InvertKeyword bool `json:"invert_keyword,omitempty"`
|
||||
DNSResolveServer string `json:"dns_resolve_server,omitempty"`
|
||||
DNSResolverMode string `json:"dns_resolver_mode,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Tags []string `json:"tags,omitempty"`
|
||||
CertExpiryNotification bool `json:"cert_expiry_notification,omitempty"`
|
||||
CertExpiryDays int `json:"cert_expiry_days,omitempty"`
|
||||
IgnoreTLSError bool `json:"ignore_tls_error,omitempty"`
|
||||
}
|
||||
|
||||
// UpdateMonitorRequest represents a request to update a monitor
|
||||
type UpdateMonitorRequest struct {
|
||||
Name *string `json:"name,omitempty"`
|
||||
URL *string `json:"url,omitempty"`
|
||||
Hostname *string `json:"hostname,omitempty"`
|
||||
Port *int `json:"port,omitempty"`
|
||||
Method *string `json:"method,omitempty"`
|
||||
Headers *string `json:"headers,omitempty"`
|
||||
Body *string `json:"body,omitempty"`
|
||||
Interval *int `json:"interval,omitempty"`
|
||||
Timeout *int `json:"timeout,omitempty"`
|
||||
Retries *int `json:"retries,omitempty"`
|
||||
RetryInterval *int `json:"retry_interval,omitempty"`
|
||||
MaxRedirects *int `json:"max_redirects,omitempty"`
|
||||
Keyword *string `json:"keyword,omitempty"`
|
||||
JSONQuery *string `json:"json_query,omitempty"`
|
||||
ExpectedValue *string `json:"expected_value,omitempty"`
|
||||
InvertKeyword *bool `json:"invert_keyword,omitempty"`
|
||||
DNSResolveServer *string `json:"dns_resolve_server,omitempty"`
|
||||
DNSResolverMode *string `json:"dns_resolver_mode,omitempty"`
|
||||
Active *bool `json:"active,omitempty"`
|
||||
Description *string `json:"description,omitempty"`
|
||||
Tags []string `json:"tags,omitempty"`
|
||||
CertExpiryNotification *bool `json:"cert_expiry_notification,omitempty"`
|
||||
CertExpiryDays *int `json:"cert_expiry_days,omitempty"`
|
||||
IgnoreTLSError *bool `json:"ignore_tls_error,omitempty"`
|
||||
}
|
||||
|
||||
// listMonitors returns all monitors for the authenticated user
|
||||
func (h *APIHandler) listMonitors(e *core.RequestEvent) error {
|
||||
userID := e.Auth.Id
|
||||
|
||||
records, err := h.app.FindRecordsByFilter(
|
||||
"monitors",
|
||||
"user = {:userId}",
|
||||
"-created",
|
||||
0,
|
||||
0,
|
||||
map[string]any{"userId": userID},
|
||||
)
|
||||
if err != nil {
|
||||
return e.InternalServerError("Failed to fetch monitors", err)
|
||||
}
|
||||
|
||||
monitors := make([]MonitorResponse, 0, len(records))
|
||||
for _, record := range records {
|
||||
monitors = append(monitors, recordToResponse(record))
|
||||
}
|
||||
|
||||
return e.JSON(http.StatusOK, map[string]interface{}{
|
||||
"monitors": monitors,
|
||||
})
|
||||
}
|
||||
|
||||
// getMonitor returns a single monitor by ID
|
||||
func (h *APIHandler) getMonitor(e *core.RequestEvent) error {
|
||||
id := e.Request.PathValue("id")
|
||||
if id == "" {
|
||||
return e.BadRequestError("Monitor ID is required", nil)
|
||||
}
|
||||
|
||||
record, err := h.app.FindRecordById("monitors", id)
|
||||
if err != nil {
|
||||
return e.NotFoundError("Monitor not found", err)
|
||||
}
|
||||
|
||||
// Verify ownership
|
||||
if record.GetString("user") != e.Auth.Id {
|
||||
return e.ForbiddenError("Access denied", nil)
|
||||
}
|
||||
|
||||
return e.JSON(http.StatusOK, recordToResponse(record))
|
||||
}
|
||||
|
||||
// createMonitor creates a new monitor
|
||||
func (h *APIHandler) createMonitor(e *core.RequestEvent) error {
|
||||
var req CreateMonitorRequest
|
||||
if err := json.NewDecoder(e.Request.Body).Decode(&req); err != nil {
|
||||
return e.BadRequestError("Invalid request body", err)
|
||||
}
|
||||
|
||||
// Validate required fields
|
||||
if req.Name == "" || req.Type == "" {
|
||||
return e.BadRequestError("Name and type are required", nil)
|
||||
}
|
||||
|
||||
// Set defaults
|
||||
if req.Interval == 0 {
|
||||
req.Interval = 60
|
||||
}
|
||||
if req.Timeout == 0 {
|
||||
req.Timeout = 30
|
||||
}
|
||||
if req.Retries == 0 {
|
||||
req.Retries = 1
|
||||
}
|
||||
|
||||
// Get collection
|
||||
collection, err := h.app.FindCollectionByNameOrId("monitors")
|
||||
if err != nil {
|
||||
return e.InternalServerError("Failed to get collection", err)
|
||||
}
|
||||
|
||||
// Create record
|
||||
record := core.NewRecord(collection)
|
||||
record.Set("name", req.Name)
|
||||
record.Set("type", req.Type)
|
||||
record.Set("url", req.URL)
|
||||
record.Set("hostname", req.Hostname)
|
||||
record.Set("port", req.Port)
|
||||
record.Set("method", req.Method)
|
||||
record.Set("headers", req.Headers)
|
||||
record.Set("body", req.Body)
|
||||
record.Set("interval", req.Interval)
|
||||
record.Set("timeout", req.Timeout)
|
||||
record.Set("retries", req.Retries)
|
||||
record.Set("retry_interval", req.RetryInterval)
|
||||
record.Set("max_redirects", req.MaxRedirects)
|
||||
record.Set("keyword", req.Keyword)
|
||||
record.Set("json_query", req.JSONQuery)
|
||||
record.Set("expected_value", req.ExpectedValue)
|
||||
record.Set("invert_keyword", req.InvertKeyword)
|
||||
record.Set("dns_resolve_server", req.DNSResolveServer)
|
||||
record.Set("dns_resolver_mode", req.DNSResolverMode)
|
||||
record.Set("status", string(monitor.StatusPending))
|
||||
record.Set("active", true)
|
||||
record.Set("user", e.Auth.Id)
|
||||
record.Set("description", req.Description)
|
||||
record.Set("tags", req.Tags)
|
||||
record.Set("cert_expiry_notification", req.CertExpiryNotification)
|
||||
record.Set("cert_expiry_days", req.CertExpiryDays)
|
||||
record.Set("ignore_tls_error", req.IgnoreTLSError)
|
||||
record.Set("uptime_stats", map[string]float64{})
|
||||
|
||||
if err := h.app.Save(record); err != nil {
|
||||
return e.InternalServerError("Failed to create monitor", err)
|
||||
}
|
||||
|
||||
// Add to scheduler
|
||||
h.scheduler.AddMonitor(record)
|
||||
|
||||
return e.JSON(http.StatusCreated, recordToResponse(record))
|
||||
}
|
||||
|
||||
// updateMonitor updates an existing monitor
|
||||
func (h *APIHandler) updateMonitor(e *core.RequestEvent) error {
|
||||
id := e.Request.PathValue("id")
|
||||
if id == "" {
|
||||
return e.BadRequestError("Monitor ID is required", nil)
|
||||
}
|
||||
|
||||
record, err := h.app.FindRecordById("monitors", id)
|
||||
if err != nil {
|
||||
return e.NotFoundError("Monitor not found", err)
|
||||
}
|
||||
|
||||
// Verify ownership
|
||||
if record.GetString("user") != e.Auth.Id {
|
||||
return e.ForbiddenError("Access denied", nil)
|
||||
}
|
||||
|
||||
var req UpdateMonitorRequest
|
||||
if err := json.NewDecoder(e.Request.Body).Decode(&req); err != nil {
|
||||
return e.BadRequestError("Invalid request body", err)
|
||||
}
|
||||
|
||||
// Update fields
|
||||
if req.Name != nil {
|
||||
record.Set("name", *req.Name)
|
||||
}
|
||||
if req.URL != nil {
|
||||
record.Set("url", *req.URL)
|
||||
}
|
||||
if req.Hostname != nil {
|
||||
record.Set("hostname", *req.Hostname)
|
||||
}
|
||||
if req.Port != nil {
|
||||
record.Set("port", *req.Port)
|
||||
}
|
||||
if req.Method != nil {
|
||||
record.Set("method", *req.Method)
|
||||
}
|
||||
if req.Headers != nil {
|
||||
record.Set("headers", *req.Headers)
|
||||
}
|
||||
if req.Body != nil {
|
||||
record.Set("body", *req.Body)
|
||||
}
|
||||
if req.Interval != nil {
|
||||
record.Set("interval", *req.Interval)
|
||||
}
|
||||
if req.Timeout != nil {
|
||||
record.Set("timeout", *req.Timeout)
|
||||
}
|
||||
if req.Retries != nil {
|
||||
record.Set("retries", *req.Retries)
|
||||
}
|
||||
if req.RetryInterval != nil {
|
||||
record.Set("retry_interval", *req.RetryInterval)
|
||||
}
|
||||
if req.MaxRedirects != nil {
|
||||
record.Set("max_redirects", *req.MaxRedirects)
|
||||
}
|
||||
if req.Keyword != nil {
|
||||
record.Set("keyword", *req.Keyword)
|
||||
}
|
||||
if req.JSONQuery != nil {
|
||||
record.Set("json_query", *req.JSONQuery)
|
||||
}
|
||||
if req.ExpectedValue != nil {
|
||||
record.Set("expected_value", *req.ExpectedValue)
|
||||
}
|
||||
if req.InvertKeyword != nil {
|
||||
record.Set("invert_keyword", *req.InvertKeyword)
|
||||
}
|
||||
if req.DNSResolveServer != nil {
|
||||
record.Set("dns_resolve_server", *req.DNSResolveServer)
|
||||
}
|
||||
if req.DNSResolverMode != nil {
|
||||
record.Set("dns_resolver_mode", *req.DNSResolverMode)
|
||||
}
|
||||
if req.Active != nil {
|
||||
record.Set("active", *req.Active)
|
||||
}
|
||||
if req.Description != nil {
|
||||
record.Set("description", *req.Description)
|
||||
}
|
||||
if req.Tags != nil {
|
||||
record.Set("tags", req.Tags)
|
||||
}
|
||||
if req.CertExpiryNotification != nil {
|
||||
record.Set("cert_expiry_notification", *req.CertExpiryNotification)
|
||||
}
|
||||
if req.CertExpiryDays != nil {
|
||||
record.Set("cert_expiry_days", *req.CertExpiryDays)
|
||||
}
|
||||
if req.IgnoreTLSError != nil {
|
||||
record.Set("ignore_tls_error", *req.IgnoreTLSError)
|
||||
}
|
||||
|
||||
if err := h.app.Save(record); err != nil {
|
||||
return e.InternalServerError("Failed to update monitor", err)
|
||||
}
|
||||
|
||||
// Update scheduler
|
||||
h.scheduler.UpdateMonitor(record)
|
||||
|
||||
return e.JSON(http.StatusOK, recordToResponse(record))
|
||||
}
|
||||
|
||||
// deleteMonitor deletes a monitor
|
||||
func (h *APIHandler) deleteMonitor(e *core.RequestEvent) error {
|
||||
id := e.Request.PathValue("id")
|
||||
if id == "" {
|
||||
return e.BadRequestError("Monitor ID is required", nil)
|
||||
}
|
||||
|
||||
record, err := h.app.FindRecordById("monitors", id)
|
||||
if err != nil {
|
||||
return e.NotFoundError("Monitor not found", err)
|
||||
}
|
||||
|
||||
// Verify ownership
|
||||
if record.GetString("user") != e.Auth.Id {
|
||||
return e.ForbiddenError("Access denied", nil)
|
||||
}
|
||||
|
||||
// Remove from scheduler first
|
||||
h.scheduler.RemoveMonitor(id)
|
||||
|
||||
if err := h.app.Delete(record); err != nil {
|
||||
return e.InternalServerError("Failed to delete monitor", err)
|
||||
}
|
||||
|
||||
return e.JSON(http.StatusOK, map[string]string{"message": "Monitor deleted"})
|
||||
}
|
||||
|
||||
// manualCheck runs a manual check for a monitor
|
||||
func (h *APIHandler) manualCheck(e *core.RequestEvent) error {
|
||||
id := e.Request.PathValue("id")
|
||||
if id == "" {
|
||||
return e.BadRequestError("Monitor ID is required", nil)
|
||||
}
|
||||
|
||||
record, err := h.app.FindRecordById("monitors", id)
|
||||
if err != nil {
|
||||
return e.NotFoundError("Monitor not found", err)
|
||||
}
|
||||
|
||||
// Verify ownership
|
||||
if record.GetString("user") != e.Auth.Id {
|
||||
return e.ForbiddenError("Access denied", nil)
|
||||
}
|
||||
|
||||
result, err := h.scheduler.RunManualCheck(id)
|
||||
if err != nil {
|
||||
return e.InternalServerError("Check failed", err)
|
||||
}
|
||||
|
||||
return e.JSON(http.StatusOK, map[string]interface{}{
|
||||
"status": result.Status,
|
||||
"ping": result.Ping,
|
||||
"msg": result.Msg,
|
||||
})
|
||||
}
|
||||
|
||||
// pauseMonitor pauses a monitor
|
||||
func (h *APIHandler) pauseMonitor(e *core.RequestEvent) error {
|
||||
id := e.Request.PathValue("id")
|
||||
if id == "" {
|
||||
return e.BadRequestError("Monitor ID is required", nil)
|
||||
}
|
||||
|
||||
record, err := h.app.FindRecordById("monitors", id)
|
||||
if err != nil {
|
||||
return e.NotFoundError("Monitor not found", err)
|
||||
}
|
||||
|
||||
// Verify ownership
|
||||
if record.GetString("user") != e.Auth.Id {
|
||||
return e.ForbiddenError("Access denied", nil)
|
||||
}
|
||||
|
||||
record.Set("active", false)
|
||||
record.Set("status", string(monitor.StatusPaused))
|
||||
|
||||
if err := h.app.Save(record); err != nil {
|
||||
return e.InternalServerError("Failed to pause monitor", err)
|
||||
}
|
||||
|
||||
h.scheduler.UpdateMonitor(record)
|
||||
|
||||
return e.JSON(http.StatusOK, recordToResponse(record))
|
||||
}
|
||||
|
||||
// resumeMonitor resumes a paused monitor
|
||||
func (h *APIHandler) resumeMonitor(e *core.RequestEvent) error {
|
||||
id := e.Request.PathValue("id")
|
||||
if id == "" {
|
||||
return e.BadRequestError("Monitor ID is required", nil)
|
||||
}
|
||||
|
||||
record, err := h.app.FindRecordById("monitors", id)
|
||||
if err != nil {
|
||||
return e.NotFoundError("Monitor not found", err)
|
||||
}
|
||||
|
||||
// Verify ownership
|
||||
if record.GetString("user") != e.Auth.Id {
|
||||
return e.ForbiddenError("Access denied", nil)
|
||||
}
|
||||
|
||||
record.Set("active", true)
|
||||
record.Set("status", string(monitor.StatusPending))
|
||||
|
||||
if err := h.app.Save(record); err != nil {
|
||||
return e.InternalServerError("Failed to resume monitor", err)
|
||||
}
|
||||
|
||||
h.scheduler.UpdateMonitor(record)
|
||||
|
||||
return e.JSON(http.StatusOK, recordToResponse(record))
|
||||
}
|
||||
|
||||
// getStats returns uptime statistics for a monitor
|
||||
func (h *APIHandler) getStats(e *core.RequestEvent) error {
|
||||
id := e.Request.PathValue("id")
|
||||
if id == "" {
|
||||
return e.BadRequestError("Monitor ID is required", nil)
|
||||
}
|
||||
|
||||
record, err := h.app.FindRecordById("monitors", id)
|
||||
if err != nil {
|
||||
return e.NotFoundError("Monitor not found", err)
|
||||
}
|
||||
|
||||
// Verify ownership
|
||||
if record.GetString("user") != e.Auth.Id {
|
||||
return e.ForbiddenError("Access denied", nil)
|
||||
}
|
||||
|
||||
stats24h, _ := h.scheduler.GetUptimeStats(id, 24)
|
||||
stats7d, _ := h.scheduler.GetUptimeStats(id, 168)
|
||||
stats30d, _ := h.scheduler.GetUptimeStats(id, 720)
|
||||
|
||||
return e.JSON(http.StatusOK, map[string]interface{}{
|
||||
"uptime_24h": stats24h,
|
||||
"uptime_7d": stats7d,
|
||||
"uptime_30d": stats30d,
|
||||
})
|
||||
}
|
||||
|
||||
// getHeartbeats returns recent heartbeats for a monitor
|
||||
func (h *APIHandler) getHeartbeats(e *core.RequestEvent) error {
|
||||
id := e.Request.PathValue("id")
|
||||
if id == "" {
|
||||
return e.BadRequestError("Monitor ID is required", nil)
|
||||
}
|
||||
|
||||
record, err := h.app.FindRecordById("monitors", id)
|
||||
if err != nil {
|
||||
return e.NotFoundError("Monitor not found", err)
|
||||
}
|
||||
|
||||
// Verify ownership
|
||||
if record.GetString("user") != e.Auth.Id {
|
||||
return e.ForbiddenError("Access denied", nil)
|
||||
}
|
||||
|
||||
// Get limit from query, default 100
|
||||
limit := 100
|
||||
|
||||
records, err := h.app.FindRecordsByFilter(
|
||||
"monitor_heartbeats",
|
||||
"monitor = {:monitorId}",
|
||||
"-time",
|
||||
0,
|
||||
limit,
|
||||
map[string]any{"monitorId": id},
|
||||
)
|
||||
if err != nil {
|
||||
return e.InternalServerError("Failed to fetch heartbeats", err)
|
||||
}
|
||||
|
||||
heartbeats := make([]map[string]interface{}, 0, len(records))
|
||||
for _, hb := range records {
|
||||
heartbeats = append(heartbeats, map[string]interface{}{
|
||||
"id": hb.Id,
|
||||
"status": hb.GetString("status"),
|
||||
"ping": hb.GetInt("ping"),
|
||||
"msg": hb.GetString("msg"),
|
||||
"cert_expiry": hb.GetInt("cert_expiry"),
|
||||
"cert_valid": hb.GetBool("cert_valid"),
|
||||
"time": hb.Get("time"),
|
||||
})
|
||||
}
|
||||
|
||||
return e.JSON(http.StatusOK, map[string]interface{}{
|
||||
"heartbeats": heartbeats,
|
||||
})
|
||||
}
|
||||
|
||||
// recordToResponse converts a PocketBase record to MonitorResponse
|
||||
func recordToResponse(record *core.Record) MonitorResponse {
|
||||
resp := MonitorResponse{
|
||||
ID: record.Id,
|
||||
Name: record.GetString("name"),
|
||||
Type: record.GetString("type"),
|
||||
URL: record.GetString("url"),
|
||||
Hostname: record.GetString("hostname"),
|
||||
Port: record.GetInt("port"),
|
||||
Method: record.GetString("method"),
|
||||
Interval: record.GetInt("interval"),
|
||||
Timeout: record.GetInt("timeout"),
|
||||
Retries: record.GetInt("retries"),
|
||||
Status: record.GetString("status"),
|
||||
Active: record.GetBool("active"),
|
||||
Description: record.GetString("description"),
|
||||
Keyword: record.GetString("keyword"),
|
||||
JSONQuery: record.GetString("json_query"),
|
||||
ExpectedValue: record.GetString("expected_value"),
|
||||
InvertKeyword: record.GetBool("invert_keyword"),
|
||||
DNSResolveServer: record.GetString("dns_resolve_server"),
|
||||
DNSResolverMode: record.GetString("dns_resolver_mode"),
|
||||
CertExpiryNotification: record.GetBool("cert_expiry_notification"),
|
||||
CertExpiryDays: record.GetInt("cert_expiry_days"),
|
||||
IgnoreTLSError: record.GetBool("ignore_tls_error"),
|
||||
Created: record.GetDateTime("created").Time(),
|
||||
Updated: record.GetDateTime("updated").Time(),
|
||||
}
|
||||
|
||||
// Handle last_check
|
||||
if lc := record.Get("last_check"); lc != nil {
|
||||
if t, ok := lc.(time.Time); ok {
|
||||
resp.LastCheck = &t
|
||||
}
|
||||
}
|
||||
|
||||
// Handle uptime_stats
|
||||
if stats := record.Get("uptime_stats"); stats != nil {
|
||||
if s, ok := stats.(map[string]float64); ok {
|
||||
resp.UptimeStats = s
|
||||
}
|
||||
}
|
||||
|
||||
// Handle tags
|
||||
if tags := record.Get("tags"); tags != nil {
|
||||
if t, ok := tags.([]string); ok {
|
||||
resp.Tags = t
|
||||
}
|
||||
}
|
||||
|
||||
return resp
|
||||
}
|
||||
@@ -0,0 +1,636 @@
|
||||
package checks
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/henrygd/beszel/internal/entities/monitor"
|
||||
)
|
||||
|
||||
// Checker defines the interface for monitor check implementations
|
||||
type Checker interface {
|
||||
Check(ctx context.Context, m *monitor.Monitor) *monitor.CheckResult
|
||||
}
|
||||
|
||||
// CheckerRegistry holds all monitor type checkers
|
||||
type CheckerRegistry struct {
|
||||
checkers map[string]Checker
|
||||
}
|
||||
|
||||
// NewCheckerRegistry creates a new registry with all checkers registered
|
||||
func NewCheckerRegistry() *CheckerRegistry {
|
||||
registry := &CheckerRegistry{
|
||||
checkers: make(map[string]Checker),
|
||||
}
|
||||
|
||||
// Register all checkers
|
||||
registry.Register(monitor.TypeHTTP, &HTTPChecker{})
|
||||
registry.Register(monitor.TypeHTTPS, &HTTPChecker{IsHTTPS: true})
|
||||
registry.Register(monitor.TypeTCP, &TCPChecker{})
|
||||
registry.Register(monitor.TypePing, &PingChecker{})
|
||||
registry.Register(monitor.TypeDNS, &DNSChecker{})
|
||||
registry.Register(monitor.TypeKeyword, &KeywordChecker{})
|
||||
registry.Register(monitor.TypeJSONQuery, &JSONQueryChecker{})
|
||||
|
||||
return registry
|
||||
}
|
||||
|
||||
// Register adds a checker for a monitor type
|
||||
func (r *CheckerRegistry) Register(monitorType string, checker Checker) {
|
||||
r.checkers[monitorType] = checker
|
||||
}
|
||||
|
||||
// Get returns the checker for a monitor type
|
||||
func (r *CheckerRegistry) Get(monitorType string) (Checker, bool) {
|
||||
checker, ok := r.checkers[monitorType]
|
||||
return checker, ok
|
||||
}
|
||||
|
||||
// HTTPChecker performs HTTP/HTTPS checks
|
||||
type HTTPChecker struct {
|
||||
IsHTTPS bool
|
||||
}
|
||||
|
||||
// Check performs an HTTP/HTTPS check
|
||||
func (c *HTTPChecker) Check(ctx context.Context, m *monitor.Monitor) *monitor.CheckResult {
|
||||
start := time.Now()
|
||||
|
||||
// Parse URL
|
||||
checkURL := m.URL
|
||||
if checkURL == "" {
|
||||
return &monitor.CheckResult{
|
||||
Status: monitor.StatusDown,
|
||||
Msg: "URL is empty",
|
||||
Error: fmt.Errorf("URL is empty"),
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure URL has scheme
|
||||
if !strings.HasPrefix(checkURL, "http://") && !strings.HasPrefix(checkURL, "https://") {
|
||||
if c.IsHTTPS {
|
||||
checkURL = "https://" + checkURL
|
||||
} else {
|
||||
checkURL = "http://" + checkURL
|
||||
}
|
||||
}
|
||||
|
||||
// Create request
|
||||
method := m.Method
|
||||
if method == "" {
|
||||
method = "GET"
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, method, checkURL, strings.NewReader(m.Body))
|
||||
if err != nil {
|
||||
return &monitor.CheckResult{
|
||||
Status: monitor.StatusDown,
|
||||
Msg: fmt.Sprintf("Failed to create request: %v", err),
|
||||
Error: err,
|
||||
}
|
||||
}
|
||||
|
||||
// Add headers
|
||||
if m.Headers != "" {
|
||||
var headers map[string]string
|
||||
if err := json.Unmarshal([]byte(m.Headers), &headers); err == nil {
|
||||
for key, value := range headers {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create client with timeout and TLS config
|
||||
timeout := time.Duration(m.Timeout) * time.Second
|
||||
if timeout == 0 {
|
||||
timeout = 30 * time.Second
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Timeout: timeout,
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
maxRedirects := m.MaxRedirects
|
||||
if maxRedirects == 0 {
|
||||
maxRedirects = 10
|
||||
}
|
||||
if len(via) >= maxRedirects {
|
||||
return fmt.Errorf("too many redirects")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
// Configure TLS
|
||||
if c.IsHTTPS {
|
||||
tlsConfig := &tls.Config{
|
||||
InsecureSkipVerify: m.IgnoreTLSError,
|
||||
}
|
||||
client.Transport = &http.Transport{
|
||||
TLSClientConfig: tlsConfig,
|
||||
}
|
||||
}
|
||||
|
||||
// Execute request
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return &monitor.CheckResult{
|
||||
Status: monitor.StatusDown,
|
||||
Msg: fmt.Sprintf("Request failed: %v", err),
|
||||
Error: err,
|
||||
}
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
elapsed := time.Since(start)
|
||||
ping := int(elapsed.Milliseconds())
|
||||
|
||||
// Check status code
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return &monitor.CheckResult{
|
||||
Status: monitor.StatusDown,
|
||||
Ping: ping,
|
||||
Msg: fmt.Sprintf("HTTP %d", resp.StatusCode),
|
||||
}
|
||||
}
|
||||
|
||||
// Check certificate if HTTPS and cert expiry notification enabled
|
||||
var certExpiry int
|
||||
var certValid bool
|
||||
if c.IsHTTPS && resp.TLS != nil && len(resp.TLS.PeerCertificates) > 0 {
|
||||
cert := resp.TLS.PeerCertificates[0]
|
||||
certValid = true
|
||||
certExpiry = int(time.Until(cert.NotAfter).Hours() / 24)
|
||||
}
|
||||
|
||||
return &monitor.CheckResult{
|
||||
Status: monitor.StatusUp,
|
||||
Ping: ping,
|
||||
Msg: fmt.Sprintf("HTTP %d", resp.StatusCode),
|
||||
CertExpiry: certExpiry,
|
||||
CertValid: certValid,
|
||||
}
|
||||
}
|
||||
|
||||
// TCPChecker performs TCP port checks
|
||||
type TCPChecker struct{}
|
||||
|
||||
// Check performs a TCP port check
|
||||
func (c *TCPChecker) Check(ctx context.Context, m *monitor.Monitor) *monitor.CheckResult {
|
||||
start := time.Now()
|
||||
|
||||
hostname := m.Hostname
|
||||
if hostname == "" {
|
||||
hostname = m.URL
|
||||
}
|
||||
|
||||
port := m.Port
|
||||
if port == 0 {
|
||||
port = 80
|
||||
}
|
||||
|
||||
address := fmt.Sprintf("%s:%d", hostname, port)
|
||||
|
||||
// Create dialer with timeout
|
||||
timeout := time.Duration(m.Timeout) * time.Second
|
||||
if timeout == 0 {
|
||||
timeout = 30 * time.Second
|
||||
}
|
||||
|
||||
dialer := &net.Dialer{Timeout: timeout}
|
||||
conn, err := dialer.DialContext(ctx, "tcp", address)
|
||||
if err != nil {
|
||||
return &monitor.CheckResult{
|
||||
Status: monitor.StatusDown,
|
||||
Msg: fmt.Sprintf("Connection failed: %v", err),
|
||||
Error: err,
|
||||
}
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
elapsed := time.Since(start)
|
||||
ping := int(elapsed.Milliseconds())
|
||||
|
||||
return &monitor.CheckResult{
|
||||
Status: monitor.StatusUp,
|
||||
Ping: ping,
|
||||
Msg: fmt.Sprintf("Connected in %dms", ping),
|
||||
}
|
||||
}
|
||||
|
||||
// PingChecker performs ICMP ping checks
|
||||
type PingChecker struct{}
|
||||
|
||||
// Check performs an ICMP ping check
|
||||
func (c *PingChecker) Check(ctx context.Context, m *monitor.Monitor) *monitor.CheckResult {
|
||||
start := time.Now()
|
||||
|
||||
hostname := m.Hostname
|
||||
if hostname == "" {
|
||||
hostname = m.URL
|
||||
}
|
||||
|
||||
// Parse hostname to remove any scheme
|
||||
hostname = strings.TrimPrefix(hostname, "http://")
|
||||
hostname = strings.TrimPrefix(hostname, "https://")
|
||||
hostname = strings.TrimSuffix(hostname, "/")
|
||||
|
||||
// Resolve the address
|
||||
resolver := &net.Resolver{}
|
||||
addrs, err := resolver.LookupHost(ctx, hostname)
|
||||
if err != nil {
|
||||
return &monitor.CheckResult{
|
||||
Status: monitor.StatusDown,
|
||||
Msg: fmt.Sprintf("DNS lookup failed: %v", err),
|
||||
Error: err,
|
||||
}
|
||||
}
|
||||
|
||||
if len(addrs) == 0 {
|
||||
return &monitor.CheckResult{
|
||||
Status: monitor.StatusDown,
|
||||
Msg: "No IP addresses found",
|
||||
Error: fmt.Errorf("no IP addresses found"),
|
||||
}
|
||||
}
|
||||
|
||||
// Try to connect to port 7 (echo) or just check if host is reachable
|
||||
// Since raw ICMP requires root, we'll do a TCP connection to a common port
|
||||
address := net.JoinHostPort(addrs[0], "80")
|
||||
|
||||
timeout := time.Duration(m.Timeout) * time.Second
|
||||
if timeout == 0 {
|
||||
timeout = 30 * time.Second
|
||||
}
|
||||
|
||||
conn, err := net.DialTimeout("tcp", address, timeout)
|
||||
if err != nil {
|
||||
// Try port 443
|
||||
address = net.JoinHostPort(addrs[0], "443")
|
||||
conn, err = net.DialTimeout("tcp", address, timeout)
|
||||
if err != nil {
|
||||
return &monitor.CheckResult{
|
||||
Status: monitor.StatusDown,
|
||||
Msg: fmt.Sprintf("Host unreachable: %v", err),
|
||||
Error: err,
|
||||
}
|
||||
}
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
elapsed := time.Since(start)
|
||||
ping := int(elapsed.Milliseconds())
|
||||
|
||||
return &monitor.CheckResult{
|
||||
Status: monitor.StatusUp,
|
||||
Ping: ping,
|
||||
Msg: fmt.Sprintf("Ping: %dms", ping),
|
||||
}
|
||||
}
|
||||
|
||||
// DNSChecker performs DNS resolution checks
|
||||
type DNSChecker struct{}
|
||||
|
||||
// Check performs a DNS resolution check
|
||||
func (c *DNSChecker) Check(ctx context.Context, m *monitor.Monitor) *monitor.CheckResult {
|
||||
start := time.Now()
|
||||
|
||||
hostname := m.Hostname
|
||||
if hostname == "" {
|
||||
hostname = m.URL
|
||||
}
|
||||
|
||||
// Remove scheme if present
|
||||
hostname = strings.TrimPrefix(hostname, "http://")
|
||||
hostname = strings.TrimPrefix(hostname, "https://")
|
||||
hostname = strings.TrimSuffix(hostname, "/")
|
||||
|
||||
// Use custom DNS server if specified
|
||||
resolver := &net.Resolver{}
|
||||
if m.DNSResolveServer != "" {
|
||||
resolver = &net.Resolver{
|
||||
PreferGo: true,
|
||||
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
d := net.Dialer{}
|
||||
return d.DialContext(ctx, network, m.DNSResolveServer+":53")
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
var err error
|
||||
var results []string
|
||||
|
||||
// Perform DNS lookup based on record type
|
||||
recordType := m.DNSResolverMode
|
||||
if recordType == "" {
|
||||
recordType = "A"
|
||||
}
|
||||
|
||||
switch recordType {
|
||||
case "A", "AAAA":
|
||||
results, err = resolver.LookupHost(ctx, hostname)
|
||||
case "CNAME":
|
||||
var cname string
|
||||
cname, err = resolver.LookupCNAME(ctx, hostname)
|
||||
if err == nil && cname != "" {
|
||||
results = []string{cname}
|
||||
}
|
||||
case "MX":
|
||||
var mxRecords []*net.MX
|
||||
mxRecords, err = resolver.LookupMX(ctx, hostname)
|
||||
if err == nil {
|
||||
for _, mx := range mxRecords {
|
||||
results = append(results, fmt.Sprintf("%s (priority: %d)", mx.Host, mx.Pref))
|
||||
}
|
||||
}
|
||||
case "NS":
|
||||
var nsRecords []*net.NS
|
||||
nsRecords, err = resolver.LookupNS(ctx, hostname)
|
||||
if err == nil {
|
||||
for _, ns := range nsRecords {
|
||||
results = append(results, ns.Host)
|
||||
}
|
||||
}
|
||||
case "TXT":
|
||||
results, err = resolver.LookupTXT(ctx, hostname)
|
||||
case "SRV":
|
||||
// SRV requires service and protocol
|
||||
_, srvRecords, err := resolver.LookupSRV(ctx, "", "", hostname)
|
||||
if err == nil {
|
||||
for _, srv := range srvRecords {
|
||||
results = append(results, fmt.Sprintf("%s:%d (priority: %d, weight: %d)",
|
||||
srv.Target, srv.Port, srv.Priority, srv.Weight))
|
||||
}
|
||||
}
|
||||
default:
|
||||
results, err = resolver.LookupHost(ctx, hostname)
|
||||
}
|
||||
|
||||
elapsed := time.Since(start)
|
||||
ping := int(elapsed.Milliseconds())
|
||||
|
||||
if err != nil {
|
||||
return &monitor.CheckResult{
|
||||
Status: monitor.StatusDown,
|
||||
Msg: fmt.Sprintf("DNS lookup failed: %v", err),
|
||||
Error: err,
|
||||
}
|
||||
}
|
||||
|
||||
return &monitor.CheckResult{
|
||||
Status: monitor.StatusUp,
|
||||
Ping: ping,
|
||||
Msg: fmt.Sprintf("Resolved %d records in %dms", len(results), ping),
|
||||
}
|
||||
}
|
||||
|
||||
// KeywordChecker performs HTTP checks with keyword validation
|
||||
type KeywordChecker struct{}
|
||||
|
||||
// Check performs an HTTP check with keyword validation
|
||||
func (c *KeywordChecker) Check(ctx context.Context, m *monitor.Monitor) *monitor.CheckResult {
|
||||
// First do HTTP check
|
||||
httpChecker := &HTTPChecker{}
|
||||
result := httpChecker.Check(ctx, m)
|
||||
|
||||
if result.Status != monitor.StatusUp {
|
||||
return result
|
||||
}
|
||||
|
||||
// Now we need to fetch the body and check for keyword
|
||||
// Re-fetch the body since we closed it in HTTPChecker
|
||||
checkURL := m.URL
|
||||
if !strings.HasPrefix(checkURL, "http://") && !strings.HasPrefix(checkURL, "https://") {
|
||||
checkURL = "https://" + checkURL
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", checkURL, nil)
|
||||
if err != nil {
|
||||
return result
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Timeout: time.Duration(m.Timeout) * time.Second,
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return result
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return &monitor.CheckResult{
|
||||
Status: monitor.StatusDown,
|
||||
Ping: result.Ping,
|
||||
Msg: fmt.Sprintf("Failed to read body: %v", err),
|
||||
Error: err,
|
||||
}
|
||||
}
|
||||
|
||||
bodyStr := string(body)
|
||||
keyword := m.Keyword
|
||||
found := strings.Contains(bodyStr, keyword)
|
||||
|
||||
// Handle invert keyword option
|
||||
if m.InvertKeyword {
|
||||
found = !found
|
||||
}
|
||||
|
||||
if !found {
|
||||
status := "not found"
|
||||
if m.InvertKeyword {
|
||||
status = "found (inverted)"
|
||||
}
|
||||
return &monitor.CheckResult{
|
||||
Status: monitor.StatusDown,
|
||||
Ping: result.Ping,
|
||||
Msg: fmt.Sprintf("Keyword '%s' %s", keyword, status),
|
||||
Error: fmt.Errorf("keyword check failed"),
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// JSONQueryChecker performs HTTP checks with JSON path validation
|
||||
type JSONQueryChecker struct{}
|
||||
|
||||
// Check performs an HTTP check with JSON path validation
|
||||
func (c *JSONQueryChecker) Check(ctx context.Context, m *monitor.Monitor) *monitor.CheckResult {
|
||||
// First do HTTP check
|
||||
httpChecker := &HTTPChecker{}
|
||||
result := httpChecker.Check(ctx, m)
|
||||
|
||||
if result.Status != monitor.StatusUp {
|
||||
return result
|
||||
}
|
||||
|
||||
// Re-fetch the body for JSON parsing
|
||||
checkURL := m.URL
|
||||
if !strings.HasPrefix(checkURL, "http://") && !strings.HasPrefix(checkURL, "https://") {
|
||||
checkURL = "https://" + checkURL
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", checkURL, nil)
|
||||
if err != nil {
|
||||
return result
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Timeout: time.Duration(m.Timeout) * time.Second,
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return result
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return &monitor.CheckResult{
|
||||
Status: monitor.StatusDown,
|
||||
Ping: result.Ping,
|
||||
Msg: fmt.Sprintf("Failed to read body: %v", err),
|
||||
Error: err,
|
||||
}
|
||||
}
|
||||
|
||||
// Parse JSON
|
||||
var data interface{}
|
||||
if err := json.Unmarshal(body, &data); err != nil {
|
||||
return &monitor.CheckResult{
|
||||
Status: monitor.StatusDown,
|
||||
Ping: result.Ping,
|
||||
Msg: fmt.Sprintf("Invalid JSON: %v", err),
|
||||
Error: err,
|
||||
}
|
||||
}
|
||||
|
||||
// Simple path evaluation (supports dot notation like "data.status")
|
||||
path := m.JSONQuery
|
||||
expectedValue := m.ExpectedValue
|
||||
|
||||
value := evaluateJSONPath(data, path)
|
||||
|
||||
if expectedValue != "" && value != expectedValue {
|
||||
return &monitor.CheckResult{
|
||||
Status: monitor.StatusDown,
|
||||
Ping: result.Ping,
|
||||
Msg: fmt.Sprintf("Expected '%s' but got '%s'", expectedValue, value),
|
||||
Error: fmt.Errorf("JSON value mismatch"),
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// evaluateJSONPath extracts a value from JSON using dot notation path
|
||||
func evaluateJSONPath(data interface{}, path string) string {
|
||||
if path == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
parts := strings.Split(path, ".")
|
||||
current := data
|
||||
|
||||
for _, part := range parts {
|
||||
switch v := current.(type) {
|
||||
case map[string]interface{}:
|
||||
if val, ok := v[part]; ok {
|
||||
current = val
|
||||
} else {
|
||||
return ""
|
||||
}
|
||||
case []interface{}:
|
||||
// Try to parse as index
|
||||
if idx, err := strconv.Atoi(part); err == nil && idx >= 0 && idx < len(v) {
|
||||
current = v[idx]
|
||||
} else {
|
||||
return ""
|
||||
}
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// Convert result to string
|
||||
switch v := current.(type) {
|
||||
case string:
|
||||
return v
|
||||
case float64:
|
||||
return strconv.FormatFloat(v, 'f', -1, 64)
|
||||
case bool:
|
||||
return strconv.FormatBool(v)
|
||||
case nil:
|
||||
return "null"
|
||||
default:
|
||||
return fmt.Sprintf("%v", v)
|
||||
}
|
||||
}
|
||||
|
||||
// URLValidator validates URL format
|
||||
func URLValidator(urlStr string) error {
|
||||
u, err := url.Parse(urlStr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if u.Host == "" {
|
||||
return fmt.Errorf("missing host in URL")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsValidStatusCode checks if HTTP status code is valid for UP status
|
||||
func IsValidStatusCode(code int, validCodes []int) bool {
|
||||
if len(validCodes) == 0 {
|
||||
return code >= 200 && code < 300
|
||||
}
|
||||
for _, validCode := range validCodes {
|
||||
if code == validCode {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ExtractDomain extracts domain from URL or hostname
|
||||
func ExtractDomain(urlStr string) string {
|
||||
if urlStr == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Try to parse as URL first
|
||||
if u, err := url.Parse(urlStr); err == nil && u.Host != "" {
|
||||
return u.Hostname()
|
||||
}
|
||||
|
||||
// Remove scheme if present
|
||||
domain := urlStr
|
||||
domain = strings.TrimPrefix(domain, "http://")
|
||||
domain = strings.TrimPrefix(domain, "https://")
|
||||
domain = strings.TrimSuffix(domain, "/")
|
||||
|
||||
// Remove port if present
|
||||
if idx := strings.LastIndex(domain, ":"); idx != -1 {
|
||||
domain = domain[:idx]
|
||||
}
|
||||
|
||||
return domain
|
||||
}
|
||||
|
||||
// ValidateRegex validates a regex pattern
|
||||
func ValidateRegex(pattern string) error {
|
||||
_, err := regexp.Compile(pattern)
|
||||
return err
|
||||
}
|
||||
@@ -0,0 +1,456 @@
|
||||
package monitors
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/henrygd/beszel/internal/entities/monitor"
|
||||
"github.com/henrygd/beszel/internal/hub/monitors/checks"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/store"
|
||||
)
|
||||
|
||||
// Scheduler manages the periodic execution of monitor checks
|
||||
type Scheduler struct {
|
||||
app core.App
|
||||
registry *checks.CheckerRegistry
|
||||
monitors *store.Store[string, *ScheduledMonitor]
|
||||
ticker *time.Ticker
|
||||
stopChan chan struct{}
|
||||
wg sync.WaitGroup
|
||||
mu sync.RWMutex
|
||||
running bool
|
||||
}
|
||||
|
||||
// ScheduledMonitor wraps a monitor with scheduling info
|
||||
type ScheduledMonitor struct {
|
||||
Monitor *monitor.Monitor
|
||||
NextCheck time.Time
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewScheduler creates a new monitor scheduler
|
||||
func NewScheduler(app core.App) *Scheduler {
|
||||
return &Scheduler{
|
||||
app: app,
|
||||
registry: checks.NewCheckerRegistry(),
|
||||
monitors: store.New(map[string]*ScheduledMonitor{}),
|
||||
stopChan: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Start begins the scheduler loop
|
||||
func (s *Scheduler) Start() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.running {
|
||||
return fmt.Errorf("scheduler already running")
|
||||
}
|
||||
|
||||
// Load active monitors from database
|
||||
if err := s.loadMonitors(); err != nil {
|
||||
return fmt.Errorf("failed to load monitors: %w", err)
|
||||
}
|
||||
|
||||
// Start the ticker (minimum 20 second resolution)
|
||||
s.ticker = time.NewTicker(20 * time.Second)
|
||||
s.running = true
|
||||
|
||||
s.wg.Add(1)
|
||||
go s.run()
|
||||
|
||||
log.Println("[monitor-scheduler] Started")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop halts the scheduler
|
||||
func (s *Scheduler) Stop() {
|
||||
s.mu.Lock()
|
||||
if !s.running {
|
||||
s.mu.Unlock()
|
||||
return
|
||||
}
|
||||
s.running = false
|
||||
s.ticker.Stop()
|
||||
close(s.stopChan)
|
||||
s.mu.Unlock()
|
||||
|
||||
s.wg.Wait()
|
||||
log.Println("[monitor-scheduler] Stopped")
|
||||
}
|
||||
|
||||
// run is the main scheduler loop
|
||||
func (s *Scheduler) run() {
|
||||
defer s.wg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-s.ticker.C:
|
||||
s.checkMonitors()
|
||||
case <-s.stopChan:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// checkMonitors checks all due monitors
|
||||
func (s *Scheduler) checkMonitors() {
|
||||
now := time.Now()
|
||||
|
||||
allMonitors := s.monitors.GetAll()
|
||||
for _, sm := range allMonitors {
|
||||
sm.mu.Lock()
|
||||
|
||||
// Skip if monitor is paused or not active
|
||||
if !sm.Monitor.Active || sm.Monitor.Status == monitor.StatusPaused {
|
||||
sm.mu.Unlock()
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if it's time to run
|
||||
if now.Before(sm.NextCheck) {
|
||||
sm.mu.Unlock()
|
||||
continue
|
||||
}
|
||||
|
||||
// Schedule the next check
|
||||
interval := time.Duration(sm.Monitor.Interval) * time.Second
|
||||
if interval < 20*time.Second {
|
||||
interval = 20 * time.Second
|
||||
}
|
||||
sm.NextCheck = now.Add(interval)
|
||||
sm.mu.Unlock()
|
||||
|
||||
// Run check in background
|
||||
s.wg.Add(1)
|
||||
go func(m *monitor.Monitor) {
|
||||
defer s.wg.Done()
|
||||
s.runCheck(m)
|
||||
}(sm.Monitor)
|
||||
}
|
||||
}
|
||||
|
||||
// runCheck executes a single monitor check
|
||||
func (s *Scheduler) runCheck(m *monitor.Monitor) {
|
||||
// Get the appropriate checker
|
||||
checker, ok := s.registry.Get(m.Type)
|
||||
if !ok {
|
||||
log.Printf("[monitor-scheduler] No checker found for type: %s", m.Type)
|
||||
return
|
||||
}
|
||||
|
||||
// Create context with timeout
|
||||
timeout := time.Duration(m.Timeout) * time.Second
|
||||
if timeout == 0 {
|
||||
timeout = 30 * time.Second
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
// Execute check
|
||||
result := checker.Check(ctx, m)
|
||||
|
||||
// Handle retries
|
||||
if result.Status == monitor.StatusDown && m.Retries > 0 {
|
||||
retryInterval := time.Duration(m.RetryInterval) * time.Second
|
||||
if retryInterval == 0 {
|
||||
retryInterval = time.Second
|
||||
}
|
||||
|
||||
for i := 0; i < m.Retries && result.Status == monitor.StatusDown; i++ {
|
||||
time.Sleep(retryInterval)
|
||||
|
||||
ctx, cancel = context.WithTimeout(context.Background(), timeout)
|
||||
result = checker.Check(ctx, m)
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
|
||||
// Save heartbeat and update monitor status
|
||||
if err := s.saveResult(m, result); err != nil {
|
||||
log.Printf("[monitor-scheduler] Failed to save result: %v", err)
|
||||
}
|
||||
|
||||
// Log result
|
||||
if result.Status == monitor.StatusUp {
|
||||
log.Printf("[monitor-scheduler] Check UP: %s (ping: %dms)", m.Name, result.Ping)
|
||||
} else {
|
||||
log.Printf("[monitor-scheduler] Check DOWN: %s - %s", m.Name, result.Msg)
|
||||
}
|
||||
}
|
||||
|
||||
// saveResult saves the check result to the database
|
||||
func (s *Scheduler) saveResult(m *monitor.Monitor, result *monitor.CheckResult) error {
|
||||
// Update monitor record
|
||||
record, err := s.app.FindRecordById("monitors", m.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to find monitor: %w", err)
|
||||
}
|
||||
|
||||
// Update status
|
||||
record.Set("status", string(result.Status))
|
||||
record.Set("last_check", time.Now())
|
||||
|
||||
// Calculate uptime stats (simplified - in production would aggregate from heartbeats)
|
||||
if m.UptimeStats == nil {
|
||||
m.UptimeStats = make(map[string]float64)
|
||||
}
|
||||
|
||||
// Simple rolling uptime calculation (can be improved)
|
||||
if result.Status == monitor.StatusUp {
|
||||
m.UptimeStats["total"] = m.UptimeStats["total"] + 1
|
||||
m.UptimeStats["up"] = m.UptimeStats["up"] + 1
|
||||
} else {
|
||||
m.UptimeStats["total"] = m.UptimeStats["total"] + 1
|
||||
m.UptimeStats["down"] = m.UptimeStats["down"] + 1
|
||||
}
|
||||
|
||||
if total := m.UptimeStats["total"]; total > 0 {
|
||||
m.UptimeStats["uptime_24h"] = (m.UptimeStats["up"] / total) * 100
|
||||
}
|
||||
|
||||
record.Set("uptime_stats", m.UptimeStats)
|
||||
|
||||
if err := s.app.Save(record); err != nil {
|
||||
return fmt.Errorf("failed to update monitor: %w", err)
|
||||
}
|
||||
|
||||
// Create heartbeat record
|
||||
hbCollection, err := s.app.FindCollectionByNameOrId("monitor_heartbeats")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to find heartbeats collection: %w", err)
|
||||
}
|
||||
|
||||
hbRecord := core.NewRecord(hbCollection)
|
||||
hbRecord.Set("monitor", m.ID)
|
||||
hbRecord.Set("status", string(result.Status))
|
||||
hbRecord.Set("ping", result.Ping)
|
||||
hbRecord.Set("msg", result.Msg)
|
||||
hbRecord.Set("cert_expiry", result.CertExpiry)
|
||||
hbRecord.Set("cert_valid", result.CertValid)
|
||||
hbRecord.Set("time", time.Now())
|
||||
|
||||
if err := s.app.Save(hbRecord); err != nil {
|
||||
return fmt.Errorf("failed to save heartbeat: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadMonitors loads active monitors from the database
|
||||
func (s *Scheduler) loadMonitors() error {
|
||||
records, err := s.app.FindRecordsByFilter("monitors", "active = true", "-created", 0, 0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to query monitors: %w", err)
|
||||
}
|
||||
|
||||
for _, record := range records {
|
||||
m := recordToMonitor(record)
|
||||
s.monitors.Set(m.ID, &ScheduledMonitor{
|
||||
Monitor: m,
|
||||
NextCheck: time.Now(),
|
||||
})
|
||||
}
|
||||
|
||||
log.Printf("[monitor-scheduler] Loaded %d monitors", len(records))
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddMonitor adds a new monitor to the scheduler
|
||||
func (s *Scheduler) AddMonitor(record *core.Record) {
|
||||
m := recordToMonitor(record)
|
||||
|
||||
s.monitors.Set(m.ID, &ScheduledMonitor{
|
||||
Monitor: m,
|
||||
NextCheck: time.Now(),
|
||||
})
|
||||
|
||||
log.Printf("[monitor-scheduler] Added monitor: %s (%s)", m.Name, m.Type)
|
||||
}
|
||||
|
||||
// UpdateMonitor updates a monitor in the scheduler
|
||||
func (s *Scheduler) UpdateMonitor(record *core.Record) {
|
||||
m := recordToMonitor(record)
|
||||
|
||||
// Get existing scheduled monitor to preserve next check time if appropriate
|
||||
if sm, ok := s.monitors.GetOk(m.ID); ok {
|
||||
sm.mu.Lock()
|
||||
sm.Monitor = m
|
||||
sm.mu.Unlock()
|
||||
} else {
|
||||
s.monitors.Set(m.ID, &ScheduledMonitor{
|
||||
Monitor: m,
|
||||
NextCheck: time.Now(),
|
||||
})
|
||||
}
|
||||
|
||||
log.Printf("[monitor-scheduler] Updated monitor: %s", m.Name)
|
||||
}
|
||||
|
||||
// RemoveMonitor removes a monitor from the scheduler
|
||||
func (s *Scheduler) RemoveMonitor(monitorID string) {
|
||||
s.monitors.Remove(monitorID)
|
||||
log.Printf("[monitor-scheduler] Removed monitor: %s", monitorID)
|
||||
}
|
||||
|
||||
// RunManualCheck runs a manual check for a monitor
|
||||
func (s *Scheduler) RunManualCheck(monitorID string) (*monitor.CheckResult, error) {
|
||||
// Get monitor from database
|
||||
record, err := s.app.FindRecordById("monitors", monitorID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("monitor not found: %w", err)
|
||||
}
|
||||
|
||||
m := recordToMonitor(record)
|
||||
|
||||
// Get checker
|
||||
checker, ok := s.registry.Get(m.Type)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no checker for type: %s", m.Type)
|
||||
}
|
||||
|
||||
// Run check
|
||||
timeout := time.Duration(m.Timeout) * time.Second
|
||||
if timeout == 0 {
|
||||
timeout = 30 * time.Second
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
result := checker.Check(ctx, m)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetUptimeStats calculates uptime statistics for a monitor
|
||||
func (s *Scheduler) GetUptimeStats(monitorID string, hours int) (*monitor.UptimeStats, error) {
|
||||
// Query heartbeats from the last N hours
|
||||
since := time.Now().Add(-time.Duration(hours) * time.Hour)
|
||||
|
||||
records, err := s.app.FindRecordsByFilter(
|
||||
"monitor_heartbeats",
|
||||
"monitor = {:monitorId} && time >= {:since}",
|
||||
"-time",
|
||||
0,
|
||||
0,
|
||||
map[string]any{
|
||||
"monitorId": monitorID,
|
||||
"since": since.Format("2006-01-02 15:04:05"),
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query heartbeats: %w", err)
|
||||
}
|
||||
|
||||
stats := &monitor.UptimeStats{}
|
||||
|
||||
for _, record := range records {
|
||||
stats.Total++
|
||||
status := record.GetString("status")
|
||||
if status == string(monitor.StatusUp) {
|
||||
stats.Up++
|
||||
} else if status == string(monitor.StatusDown) {
|
||||
stats.Down++
|
||||
}
|
||||
}
|
||||
|
||||
if stats.Total > 0 {
|
||||
uptime := float64(stats.Up) / float64(stats.Total) * 100
|
||||
switch hours {
|
||||
case 24:
|
||||
stats.Uptime24h = uptime
|
||||
case 168: // 7 days
|
||||
stats.Uptime7d = uptime
|
||||
case 720: // 30 days
|
||||
stats.Uptime30d = uptime
|
||||
}
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// recordToMonitor converts a PocketBase record to a Monitor struct
|
||||
func recordToMonitor(record *core.Record) *monitor.Monitor {
|
||||
m := &monitor.Monitor{
|
||||
ID: record.Id,
|
||||
Name: record.GetString("name"),
|
||||
Type: record.GetString("type"),
|
||||
URL: record.GetString("url"),
|
||||
Hostname: record.GetString("hostname"),
|
||||
Port: record.GetInt("port"),
|
||||
Method: record.GetString("method"),
|
||||
Headers: record.GetString("headers"),
|
||||
Body: record.GetString("body"),
|
||||
Interval: record.GetInt("interval"),
|
||||
Timeout: record.GetInt("timeout"),
|
||||
Retries: record.GetInt("retries"),
|
||||
RetryInterval: record.GetInt("retry_interval"),
|
||||
MaxRedirects: record.GetInt("max_redirects"),
|
||||
Keyword: record.GetString("keyword"),
|
||||
JSONQuery: record.GetString("json_query"),
|
||||
ExpectedValue: record.GetString("expected_value"),
|
||||
InvertKeyword: record.GetBool("invert_keyword"),
|
||||
DNSResolveServer: record.GetString("dns_resolve_server"),
|
||||
DNSResolverMode: record.GetString("dns_resolver_mode"),
|
||||
Status: monitor.Status(record.GetString("status")),
|
||||
Active: record.GetBool("active"),
|
||||
UserID: record.GetString("user"),
|
||||
Description: record.GetString("description"),
|
||||
CertExpiryNotification: record.GetBool("cert_expiry_notification"),
|
||||
CertExpiryDays: record.GetInt("cert_expiry_days"),
|
||||
IgnoreTLSError: record.GetBool("ignore_tls_error"),
|
||||
}
|
||||
|
||||
// Parse JSON fields
|
||||
if tagsData := record.Get("tags"); tagsData != nil {
|
||||
if tags, ok := tagsData.([]string); ok {
|
||||
m.Tags = tags
|
||||
}
|
||||
}
|
||||
|
||||
if statsData := record.Get("uptime_stats"); statsData != nil {
|
||||
if stats, ok := statsData.(map[string]float64); ok {
|
||||
m.UptimeStats = stats
|
||||
}
|
||||
}
|
||||
|
||||
if lastCheck := record.Get("last_check"); lastCheck != nil {
|
||||
if t, ok := lastCheck.(time.Time); ok {
|
||||
m.LastCheck = t
|
||||
}
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// CleanupOldHeartbeats removes heartbeats older than retention period
|
||||
func (s *Scheduler) CleanupOldHeartbeats(retentionDays int) error {
|
||||
cutoff := time.Now().AddDate(0, 0, -retentionDays)
|
||||
|
||||
records, err := s.app.FindRecordsByFilter(
|
||||
"monitor_heartbeats",
|
||||
"time < {:cutoff}",
|
||||
"",
|
||||
0,
|
||||
0,
|
||||
map[string]any{
|
||||
"cutoff": cutoff.Format("2006-01-02 15:04:05"),
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to find old heartbeats: %w", err)
|
||||
}
|
||||
|
||||
deleted := 0
|
||||
for _, record := range records {
|
||||
if err := s.app.Delete(record); err == nil {
|
||||
deleted++
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("[monitor-scheduler] Cleaned up %d old heartbeats", deleted)
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,277 @@
|
||||
package notifications
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/henrygd/beszel/internal/entities/notification"
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/apis"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
)
|
||||
|
||||
// RegisterRoutes registers notification API routes
|
||||
func RegisterRoutes(app core.App, se *core.ServeEvent) {
|
||||
api := &NotificationAPI{app: app}
|
||||
|
||||
group := se.Router.Group("/api/beszel/notifications")
|
||||
group.Bind(apis.RequireAuth())
|
||||
|
||||
group.GET("/", api.listNotifications)
|
||||
group.POST("/", api.createNotification)
|
||||
group.GET("/{id}", api.getNotification)
|
||||
group.PATCH("/{id}", api.updateNotification)
|
||||
group.DELETE("/{id}", api.deleteNotification)
|
||||
group.POST("/{id}/test", api.testNotification)
|
||||
}
|
||||
|
||||
// NotificationAPI handles notification API requests
|
||||
type NotificationAPI struct {
|
||||
app core.App
|
||||
}
|
||||
|
||||
// CreateNotificationRequest represents a notification creation request
|
||||
type CreateNotificationRequest struct {
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
Settings map[string]interface{} `json:"settings"`
|
||||
IsDefault bool `json:"is_default"`
|
||||
}
|
||||
|
||||
// UpdateNotificationRequest represents a notification update request
|
||||
type UpdateNotificationRequest struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
Settings map[string]interface{} `json:"settings,omitempty"`
|
||||
IsDefault *bool `json:"is_default,omitempty"`
|
||||
Active *bool `json:"active,omitempty"`
|
||||
}
|
||||
|
||||
// NotificationResponse represents a notification response
|
||||
type NotificationResponse struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
IsDefault bool `json:"is_default"`
|
||||
Active bool `json:"active"`
|
||||
Settings map[string]interface{} `json:"settings"`
|
||||
Created string `json:"created"`
|
||||
Updated string `json:"updated"`
|
||||
}
|
||||
|
||||
// listNotifications lists all notifications for the authenticated user
|
||||
func (api *NotificationAPI) listNotifications(e *core.RequestEvent) error {
|
||||
authRecord := e.Auth
|
||||
if authRecord == nil {
|
||||
return e.UnauthorizedError("unauthorized", nil)
|
||||
}
|
||||
|
||||
records, err := api.app.FindAllRecords("notifications",
|
||||
dbx.NewExp("user = {:user}", dbx.Params{"user": authRecord.Id}),
|
||||
)
|
||||
if err != nil {
|
||||
return e.InternalServerError("failed to fetch notifications", err)
|
||||
}
|
||||
|
||||
notifications := make([]NotificationResponse, 0, len(records))
|
||||
for _, record := range records {
|
||||
notifications = append(notifications, api.recordToResponse(record))
|
||||
}
|
||||
|
||||
return e.JSON(http.StatusOK, notifications)
|
||||
}
|
||||
|
||||
// createNotification creates a new notification
|
||||
func (api *NotificationAPI) createNotification(e *core.RequestEvent) error {
|
||||
authRecord := e.Auth
|
||||
if authRecord == nil {
|
||||
return e.UnauthorizedError("unauthorized", nil)
|
||||
}
|
||||
|
||||
var req CreateNotificationRequest
|
||||
if err := json.NewDecoder(e.Request.Body).Decode(&req); err != nil {
|
||||
return e.BadRequestError("invalid request body", err)
|
||||
}
|
||||
|
||||
if req.Name == "" || req.Type == "" {
|
||||
return e.BadRequestError("name and type are required", nil)
|
||||
}
|
||||
|
||||
collection, err := api.app.FindCollectionByNameOrId("notifications")
|
||||
if err != nil {
|
||||
return e.InternalServerError("failed to find collection", err)
|
||||
}
|
||||
|
||||
settingsJSON, _ := json.Marshal(req.Settings)
|
||||
|
||||
record := core.NewRecord(collection)
|
||||
record.Set("name", req.Name)
|
||||
record.Set("type", req.Type)
|
||||
record.Set("settings", string(settingsJSON))
|
||||
record.Set("is_default", req.IsDefault)
|
||||
record.Set("active", true)
|
||||
record.Set("user", authRecord.Id)
|
||||
|
||||
if err := api.app.Save(record); err != nil {
|
||||
return e.InternalServerError("failed to create notification", err)
|
||||
}
|
||||
|
||||
return e.JSON(http.StatusCreated, api.recordToResponse(record))
|
||||
}
|
||||
|
||||
// getNotification gets a single notification
|
||||
func (api *NotificationAPI) getNotification(e *core.RequestEvent) error {
|
||||
authRecord := e.Auth
|
||||
if authRecord == nil {
|
||||
return e.UnauthorizedError("unauthorized", nil)
|
||||
}
|
||||
|
||||
id := e.Request.PathValue("id")
|
||||
record, err := api.app.FindRecordById("notifications", id)
|
||||
if err != nil {
|
||||
return e.NotFoundError("notification not found", err)
|
||||
}
|
||||
|
||||
if record.GetString("user") != authRecord.Id {
|
||||
return e.ForbiddenError("not authorized", nil)
|
||||
}
|
||||
|
||||
return e.JSON(http.StatusOK, api.recordToResponse(record))
|
||||
}
|
||||
|
||||
// updateNotification updates a notification
|
||||
func (api *NotificationAPI) updateNotification(e *core.RequestEvent) error {
|
||||
authRecord := e.Auth
|
||||
if authRecord == nil {
|
||||
return e.UnauthorizedError("unauthorized", nil)
|
||||
}
|
||||
|
||||
id := e.Request.PathValue("id")
|
||||
record, err := api.app.FindRecordById("notifications", id)
|
||||
if err != nil {
|
||||
return e.NotFoundError("notification not found", err)
|
||||
}
|
||||
|
||||
if record.GetString("user") != authRecord.Id {
|
||||
return e.ForbiddenError("not authorized", nil)
|
||||
}
|
||||
|
||||
var req UpdateNotificationRequest
|
||||
if err := json.NewDecoder(e.Request.Body).Decode(&req); err != nil {
|
||||
return e.BadRequestError("invalid request body", err)
|
||||
}
|
||||
|
||||
if req.Name != "" {
|
||||
record.Set("name", req.Name)
|
||||
}
|
||||
if req.Settings != nil {
|
||||
settingsJSON, _ := json.Marshal(req.Settings)
|
||||
record.Set("settings", string(settingsJSON))
|
||||
}
|
||||
if req.IsDefault != nil {
|
||||
record.Set("is_default", *req.IsDefault)
|
||||
}
|
||||
if req.Active != nil {
|
||||
record.Set("active", *req.Active)
|
||||
}
|
||||
|
||||
if err := api.app.Save(record); err != nil {
|
||||
return e.InternalServerError("failed to update notification", err)
|
||||
}
|
||||
|
||||
return e.JSON(http.StatusOK, api.recordToResponse(record))
|
||||
}
|
||||
|
||||
// deleteNotification deletes a notification
|
||||
func (api *NotificationAPI) deleteNotification(e *core.RequestEvent) error {
|
||||
authRecord := e.Auth
|
||||
if authRecord == nil {
|
||||
return e.UnauthorizedError("unauthorized", nil)
|
||||
}
|
||||
|
||||
id := e.Request.PathValue("id")
|
||||
record, err := api.app.FindRecordById("notifications", id)
|
||||
if err != nil {
|
||||
return e.NotFoundError("notification not found", err)
|
||||
}
|
||||
|
||||
if record.GetString("user") != authRecord.Id {
|
||||
return e.ForbiddenError("not authorized", nil)
|
||||
}
|
||||
|
||||
if err := api.app.Delete(record); err != nil {
|
||||
return e.InternalServerError("failed to delete notification", err)
|
||||
}
|
||||
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// testNotification sends a test notification
|
||||
func (api *NotificationAPI) testNotification(e *core.RequestEvent) error {
|
||||
authRecord := e.Auth
|
||||
if authRecord == nil {
|
||||
return e.UnauthorizedError("unauthorized", nil)
|
||||
}
|
||||
|
||||
id := e.Request.PathValue("id")
|
||||
record, err := api.app.FindRecordById("notifications", id)
|
||||
if err != nil {
|
||||
return e.NotFoundError("notification not found", err)
|
||||
}
|
||||
|
||||
if record.GetString("user") != authRecord.Id {
|
||||
return e.ForbiddenError("not authorized", nil)
|
||||
}
|
||||
|
||||
notif := ¬ification.Notification{
|
||||
ID: record.Id,
|
||||
Name: record.GetString("name"),
|
||||
Type: record.GetString("type"),
|
||||
Active: record.GetBool("active"),
|
||||
}
|
||||
|
||||
if settingsJSON := record.GetString("settings"); settingsJSON != "" {
|
||||
var settings map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(settingsJSON), &settings); err == nil {
|
||||
notif.Settings = settings
|
||||
}
|
||||
}
|
||||
|
||||
dispatcher := NewDispatcher(api.app)
|
||||
msg := ¬ification.NotificationMessage{
|
||||
Title: "Test Notification",
|
||||
Body: "This is a test notification from Beszel.",
|
||||
MonitorName: "Test Monitor",
|
||||
Status: "UP",
|
||||
Message: "Test message",
|
||||
}
|
||||
|
||||
provider, err := dispatcher.getProvider(notif)
|
||||
if err != nil {
|
||||
return e.InternalServerError("failed to create provider", err)
|
||||
}
|
||||
|
||||
if err := provider.Send(msg); err != nil {
|
||||
return e.InternalServerError("failed to send test notification", err)
|
||||
}
|
||||
|
||||
return e.JSON(http.StatusOK, map[string]string{"status": "sent"})
|
||||
}
|
||||
|
||||
// recordToResponse converts a record to a response
|
||||
func (api *NotificationAPI) recordToResponse(record *core.Record) NotificationResponse {
|
||||
var settings map[string]interface{}
|
||||
if settingsJSON := record.GetString("settings"); settingsJSON != "" {
|
||||
json.Unmarshal([]byte(settingsJSON), &settings)
|
||||
}
|
||||
|
||||
return NotificationResponse{
|
||||
ID: record.Id,
|
||||
Name: record.GetString("name"),
|
||||
Type: record.GetString("type"),
|
||||
IsDefault: record.GetBool("is_default"),
|
||||
Active: record.GetBool("active"),
|
||||
Settings: settings,
|
||||
Created: record.GetDateTime("created").String(),
|
||||
Updated: record.GetDateTime("updated").String(),
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,236 @@
|
||||
package notifications
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
|
||||
"github.com/henrygd/beszel/internal/entities/monitor"
|
||||
"github.com/henrygd/beszel/internal/entities/notification"
|
||||
"github.com/henrygd/beszel/internal/hub/notifications/providers"
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
)
|
||||
|
||||
// Dispatcher manages notification sending for monitor events
|
||||
type Dispatcher struct {
|
||||
app core.App
|
||||
mu sync.RWMutex
|
||||
providers map[string]notification.Provider
|
||||
}
|
||||
|
||||
// NewDispatcher creates a new notification dispatcher
|
||||
func NewDispatcher(app core.App) *Dispatcher {
|
||||
return &Dispatcher{
|
||||
app: app,
|
||||
providers: make(map[string]notification.Provider),
|
||||
}
|
||||
}
|
||||
|
||||
// SendNotification sends a notification for a monitor event
|
||||
func (d *Dispatcher) SendNotification(monitorRecord *monitor.Monitor, heartbeat *monitor.Heartbeat, isRecovery bool) {
|
||||
// Get linked notifications for this monitor
|
||||
notifications, err := d.getMonitorNotifications(monitorRecord.ID)
|
||||
if err != nil {
|
||||
log.Printf("[notification-dispatcher] Failed to get notifications: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(notifications) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Build the message
|
||||
msg := d.buildMessage(monitorRecord, heartbeat, isRecovery)
|
||||
|
||||
// Send to each notification provider
|
||||
for _, n := range notifications {
|
||||
if !n.Active {
|
||||
continue
|
||||
}
|
||||
|
||||
provider, err := d.getProvider(n)
|
||||
if err != nil {
|
||||
log.Printf("[notification-dispatcher] Failed to get provider: %v", err)
|
||||
d.logNotificationEvent(n.ID, monitorRecord.ID, "failed", "", err.Error())
|
||||
continue
|
||||
}
|
||||
|
||||
if err := provider.Send(msg); err != nil {
|
||||
log.Printf("[notification-dispatcher] Failed to send notification: %v", err)
|
||||
d.logNotificationEvent(n.ID, monitorRecord.ID, "failed", "", err.Error())
|
||||
} else {
|
||||
log.Printf("[notification-dispatcher] Sent notification via %s for monitor %s", n.Type, monitorRecord.Name)
|
||||
d.logNotificationEvent(n.ID, monitorRecord.ID, "sent", "", "")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// getMonitorNotifications retrieves all notifications linked to a monitor
|
||||
func (d *Dispatcher) getMonitorNotifications(monitorID string) ([]*notification.Notification, error) {
|
||||
// Find monitor_notification records for this monitor
|
||||
records, err := d.app.FindAllRecords("monitor_notifications",
|
||||
dbx.HashExp{"monitor": monitorID},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(records) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var notifications []*notification.Notification
|
||||
for _, record := range records {
|
||||
notificationID := record.GetString("notification")
|
||||
notifRecord, err := d.app.FindRecordById("notifications", notificationID)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
notif := ¬ification.Notification{
|
||||
ID: notifRecord.Id,
|
||||
Name: notifRecord.GetString("name"),
|
||||
Type: notifRecord.GetString("type"),
|
||||
IsDefault: notifRecord.GetBool("is_default"),
|
||||
Active: notifRecord.GetBool("active"),
|
||||
}
|
||||
|
||||
// Parse settings from JSON
|
||||
if settingsJSON := notifRecord.GetString("settings"); settingsJSON != "" {
|
||||
var settings map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(settingsJSON), &settings); err == nil {
|
||||
notif.Settings = settings
|
||||
}
|
||||
}
|
||||
|
||||
notifications = append(notifications, notif)
|
||||
}
|
||||
|
||||
return notifications, nil
|
||||
}
|
||||
|
||||
// getProvider creates a provider instance for a notification config
|
||||
func (d *Dispatcher) getProvider(n *notification.Notification) (notification.Provider, error) {
|
||||
d.mu.RLock()
|
||||
if provider, ok := d.providers[n.ID]; ok {
|
||||
d.mu.RUnlock()
|
||||
return provider, nil
|
||||
}
|
||||
d.mu.RUnlock()
|
||||
|
||||
var provider notification.Provider
|
||||
|
||||
switch n.Type {
|
||||
case notification.ProviderEmail:
|
||||
settings := n.GetSettings().(notification.EmailSettings)
|
||||
provider = providers.NewEmailProvider(settings)
|
||||
case notification.ProviderWebhook:
|
||||
settings := n.GetSettings().(notification.WebhookSettings)
|
||||
provider = providers.NewWebhookProvider(settings)
|
||||
case notification.ProviderDiscord:
|
||||
settings := n.GetSettings().(notification.DiscordSettings)
|
||||
provider = providers.NewDiscordProvider(settings)
|
||||
case notification.ProviderSlack:
|
||||
settings := n.GetSettings().(notification.SlackSettings)
|
||||
provider = providers.NewSlackProvider(settings)
|
||||
case notification.ProviderTelegram:
|
||||
settings := n.GetSettings().(notification.TelegramSettings)
|
||||
provider = providers.NewTelegramProvider(settings)
|
||||
case notification.ProviderGotify:
|
||||
settings := n.GetSettings().(notification.GotifySettings)
|
||||
provider = providers.NewGotifyProvider(settings)
|
||||
case notification.ProviderPushover:
|
||||
settings := n.GetSettings().(notification.PushoverSettings)
|
||||
provider = providers.NewPushoverProvider(settings)
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown provider type: %s", n.Type)
|
||||
}
|
||||
|
||||
if err := provider.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
d.mu.Lock()
|
||||
d.providers[n.ID] = provider
|
||||
d.mu.Unlock()
|
||||
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
// buildMessage creates a notification message from monitor data
|
||||
func (d *Dispatcher) buildMessage(m *monitor.Monitor, h *monitor.Heartbeat, isRecovery bool) *notification.NotificationMessage {
|
||||
status := "DOWN"
|
||||
if isRecovery {
|
||||
status = "UP"
|
||||
}
|
||||
|
||||
title := fmt.Sprintf("%s is %s", m.Name, status)
|
||||
body := fmt.Sprintf("Monitor %s is %s.", m.Name, status)
|
||||
|
||||
if !isRecovery && h.Msg != "" {
|
||||
body = fmt.Sprintf("Monitor %s is %s. Error: %s", m.Name, status, h.Msg)
|
||||
}
|
||||
|
||||
return ¬ification.NotificationMessage{
|
||||
Title: title,
|
||||
Body: body,
|
||||
MonitorName: m.Name,
|
||||
MonitorURL: d.getMonitorURL(m),
|
||||
Status: status,
|
||||
Timestamp: h.Time,
|
||||
Ping: h.Ping,
|
||||
Message: h.Msg,
|
||||
}
|
||||
}
|
||||
|
||||
// getMonitorURL returns the URL or hostname for display
|
||||
func (d *Dispatcher) getMonitorURL(m *monitor.Monitor) string {
|
||||
if m.URL != "" {
|
||||
return m.URL
|
||||
}
|
||||
if m.Hostname != "" {
|
||||
if m.Port > 0 {
|
||||
return fmt.Sprintf("%s:%d", m.Hostname, m.Port)
|
||||
}
|
||||
return m.Hostname
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// logNotificationEvent logs a notification event to the database
|
||||
func (d *Dispatcher) logNotificationEvent(notificationID, monitorID, status, message, errMsg string) {
|
||||
collection, findErr := d.app.FindCollectionByNameOrId("notification_events")
|
||||
if findErr != nil {
|
||||
return
|
||||
}
|
||||
|
||||
record := core.NewRecord(collection)
|
||||
record.Set("notification", notificationID)
|
||||
record.Set("monitor", monitorID)
|
||||
record.Set("status", status)
|
||||
record.Set("message", message)
|
||||
record.Set("error", errMsg)
|
||||
|
||||
if saveErr := d.app.Save(record); saveErr != nil {
|
||||
log.Printf("[notification-dispatcher] Failed to log notification event: %v", saveErr)
|
||||
}
|
||||
}
|
||||
|
||||
// ClearCache clears the provider cache (call when settings change)
|
||||
func (d *Dispatcher) ClearCache() {
|
||||
d.mu.Lock()
|
||||
d.providers = make(map[string]notification.Provider)
|
||||
d.mu.Unlock()
|
||||
}
|
||||
|
||||
// Check if we need to send notification for this heartbeat
|
||||
func (d *Dispatcher) ShouldNotify(m *monitor.Monitor, heartbeat *monitor.Heartbeat) (bool, bool) {
|
||||
// Check if this is a status change
|
||||
isDown := heartbeat.Status == monitor.StatusDown
|
||||
isRecovery := heartbeat.Status == monitor.StatusUp && m.Status == monitor.StatusDown
|
||||
|
||||
// Only notify on down (after retries) or recovery
|
||||
return isDown && m.Status == monitor.StatusDown, isRecovery
|
||||
}
|
||||
@@ -0,0 +1,99 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/henrygd/beszel/internal/entities/notification"
|
||||
)
|
||||
|
||||
type DiscordProvider struct {
|
||||
settings notification.DiscordSettings
|
||||
}
|
||||
|
||||
func NewDiscordProvider(settings notification.DiscordSettings) *DiscordProvider {
|
||||
return &DiscordProvider{settings: settings}
|
||||
}
|
||||
|
||||
func (p *DiscordProvider) Validate() error {
|
||||
if p.settings.WebhookURL == "" {
|
||||
return fmt.Errorf("Discord webhook URL is required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *DiscordProvider) Send(msg *notification.NotificationMessage) error {
|
||||
if err := p.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
color := 0x00ff00 // Green for UP
|
||||
if msg.Status == "DOWN" {
|
||||
color = 0xff0000 // Red for DOWN
|
||||
}
|
||||
|
||||
embed := map[string]interface{}{
|
||||
"title": msg.Title,
|
||||
"description": msg.Body,
|
||||
"color": color,
|
||||
"timestamp": msg.Timestamp.Format(time.RFC3339),
|
||||
"fields": []map[string]interface{}{
|
||||
{
|
||||
"name": "Monitor",
|
||||
"value": msg.MonitorName,
|
||||
"inline": true,
|
||||
},
|
||||
{
|
||||
"name": "Status",
|
||||
"value": msg.Status,
|
||||
"inline": true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if msg.MonitorURL != "" {
|
||||
embed["fields"] = append(embed["fields"].([]map[string]interface{}), map[string]interface{}{
|
||||
"name": "URL",
|
||||
"value": msg.MonitorURL,
|
||||
"inline": false,
|
||||
})
|
||||
}
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"embeds": []map[string]interface{}{embed},
|
||||
}
|
||||
|
||||
if p.settings.Username != "" {
|
||||
payload["username"] = p.settings.Username
|
||||
}
|
||||
if p.settings.AvatarURL != "" {
|
||||
payload["avatar_url"] = p.settings.AvatarURL
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal payload: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", p.settings.WebhookURL, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("discord webhook request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
return fmt.Errorf("discord webhook returned status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,146 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net/smtp"
|
||||
"strings"
|
||||
|
||||
"github.com/henrygd/beszel/internal/entities/notification"
|
||||
)
|
||||
|
||||
// EmailProvider implements email notifications via SMTP
|
||||
type EmailProvider struct {
|
||||
settings notification.EmailSettings
|
||||
}
|
||||
|
||||
// NewEmailProvider creates a new email provider
|
||||
func NewEmailProvider(settings notification.EmailSettings) *EmailProvider {
|
||||
return &EmailProvider{settings: settings}
|
||||
}
|
||||
|
||||
// Validate checks if the email settings are valid
|
||||
func (p *EmailProvider) Validate() error {
|
||||
if p.settings.SMTPHost == "" {
|
||||
return fmt.Errorf("SMTP host is required")
|
||||
}
|
||||
if p.settings.SMTPPort == 0 {
|
||||
return fmt.Errorf("SMTP port is required")
|
||||
}
|
||||
if p.settings.FromEmail == "" {
|
||||
return fmt.Errorf("from email is required")
|
||||
}
|
||||
if p.settings.ToEmail == "" {
|
||||
return fmt.Errorf("to email is required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Send sends an email notification
|
||||
func (p *EmailProvider) Send(msg *notification.NotificationMessage) error {
|
||||
if err := p.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
subject := fmt.Sprintf("[%s] %s - %s", msg.Status, msg.MonitorName, msg.Title)
|
||||
body := p.formatBody(msg)
|
||||
|
||||
// Build email content
|
||||
email := fmt.Sprintf(
|
||||
"From: %s\r\nTo: %s\r\nSubject: %s\r\nContent-Type: text/plain; charset=UTF-8\r\n\r\n%s",
|
||||
p.settings.FromEmail,
|
||||
p.settings.ToEmail,
|
||||
subject,
|
||||
body,
|
||||
)
|
||||
|
||||
// Connect to SMTP server
|
||||
addr := fmt.Sprintf("%s:%d", p.settings.SMTPHost, p.settings.SMTPPort)
|
||||
|
||||
var auth smtp.Auth
|
||||
if p.settings.SMTPUser != "" {
|
||||
auth = smtp.PlainAuth("", p.settings.SMTPUser, p.settings.SMTPPassword, p.settings.SMTPHost)
|
||||
}
|
||||
|
||||
// Send email
|
||||
if p.settings.UseTLS {
|
||||
return p.sendTLS(addr, auth, email)
|
||||
}
|
||||
|
||||
return smtp.SendMail(
|
||||
addr,
|
||||
auth,
|
||||
p.settings.FromEmail,
|
||||
[]string{p.settings.ToEmail},
|
||||
[]byte(email),
|
||||
)
|
||||
}
|
||||
|
||||
// sendTLS sends email using TLS
|
||||
func (p *EmailProvider) sendTLS(addr string, auth smtp.Auth, email string) error {
|
||||
conn, err := tls.Dial("tcp", addr, &tls.Config{ServerName: p.settings.SMTPHost})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect via TLS: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client, err := smtp.NewClient(conn, p.settings.SMTPHost)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create SMTP client: %w", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
if auth != nil {
|
||||
if err := client.Auth(auth); err != nil {
|
||||
return fmt.Errorf("authentication failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := client.Mail(p.settings.FromEmail); err != nil {
|
||||
return fmt.Errorf("failed to set sender: %w", err)
|
||||
}
|
||||
if err := client.Rcpt(p.settings.ToEmail); err != nil {
|
||||
return fmt.Errorf("failed to set recipient: %w", err)
|
||||
}
|
||||
|
||||
w, err := client.Data()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get data writer: %w", err)
|
||||
}
|
||||
|
||||
_, err = w.Write([]byte(email))
|
||||
if err != nil {
|
||||
w.Close()
|
||||
return fmt.Errorf("failed to write email: %w", err)
|
||||
}
|
||||
|
||||
if err := w.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close data writer: %w", err)
|
||||
}
|
||||
|
||||
return client.Quit()
|
||||
}
|
||||
|
||||
// formatBody formats the email body
|
||||
func (p *EmailProvider) formatBody(msg *notification.NotificationMessage) string {
|
||||
var b strings.Builder
|
||||
|
||||
b.WriteString(fmt.Sprintf("Monitor: %s\n", msg.MonitorName))
|
||||
if msg.MonitorURL != "" {
|
||||
b.WriteString(fmt.Sprintf("URL: %s\n", msg.MonitorURL))
|
||||
}
|
||||
b.WriteString(fmt.Sprintf("Status: %s\n", msg.Status))
|
||||
b.WriteString(fmt.Sprintf("Time: %s\n", msg.Timestamp.Format("2006-01-02 15:04:05")))
|
||||
|
||||
if msg.Ping > 0 {
|
||||
b.WriteString(fmt.Sprintf("Response Time: %dms\n", msg.Ping))
|
||||
}
|
||||
|
||||
if msg.Message != "" {
|
||||
b.WriteString(fmt.Sprintf("\nMessage: %s\n", msg.Message))
|
||||
}
|
||||
|
||||
b.WriteString(fmt.Sprintf("\n%s\n", msg.Body))
|
||||
|
||||
return b.String()
|
||||
}
|
||||
@@ -0,0 +1,67 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/henrygd/beszel/internal/entities/notification"
|
||||
)
|
||||
|
||||
type GotifyProvider struct {
|
||||
settings notification.GotifySettings
|
||||
}
|
||||
|
||||
func NewGotifyProvider(settings notification.GotifySettings) *GotifyProvider {
|
||||
return &GotifyProvider{settings: settings}
|
||||
}
|
||||
|
||||
func (p *GotifyProvider) Validate() error {
|
||||
if p.settings.ServerURL == "" {
|
||||
return fmt.Errorf("Gotify server URL is required")
|
||||
}
|
||||
if p.settings.AppToken == "" {
|
||||
return fmt.Errorf("Gotify app token is required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *GotifyProvider) Send(msg *notification.NotificationMessage) error {
|
||||
if err := p.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"title": msg.Title,
|
||||
"message": msg.Body,
|
||||
"priority": p.settings.Priority,
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal payload: %w", err)
|
||||
}
|
||||
|
||||
apiURL := fmt.Sprintf("%s/message?token=%s", p.settings.ServerURL, p.settings.AppToken)
|
||||
|
||||
req, err := http.NewRequest("POST", apiURL, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("gotify request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
return fmt.Errorf("gotify returned status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,58 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/henrygd/beszel/internal/entities/notification"
|
||||
)
|
||||
|
||||
type PushoverProvider struct {
|
||||
settings notification.PushoverSettings
|
||||
}
|
||||
|
||||
func NewPushoverProvider(settings notification.PushoverSettings) *PushoverProvider {
|
||||
return &PushoverProvider{settings: settings}
|
||||
}
|
||||
|
||||
func (p *PushoverProvider) Validate() error {
|
||||
if p.settings.AppToken == "" {
|
||||
return fmt.Errorf("Pushover app token is required")
|
||||
}
|
||||
if p.settings.UserKey == "" {
|
||||
return fmt.Errorf("Pushover user key is required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *PushoverProvider) Send(msg *notification.NotificationMessage) error {
|
||||
if err := p.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
data := url.Values{}
|
||||
data.Set("token", p.settings.AppToken)
|
||||
data.Set("user", p.settings.UserKey)
|
||||
data.Set("title", msg.Title)
|
||||
data.Set("message", msg.Body)
|
||||
data.Set("priority", fmt.Sprintf("%d", p.settings.Priority))
|
||||
|
||||
if p.settings.Device != "" {
|
||||
data.Set("device", p.settings.Device)
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.PostForm("https://api.pushover.net/1/messages.json", data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("pushover API request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
return fmt.Errorf("pushover API returned status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,101 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/henrygd/beszel/internal/entities/notification"
|
||||
)
|
||||
|
||||
type SlackProvider struct {
|
||||
settings notification.SlackSettings
|
||||
}
|
||||
|
||||
func NewSlackProvider(settings notification.SlackSettings) *SlackProvider {
|
||||
return &SlackProvider{settings: settings}
|
||||
}
|
||||
|
||||
func (p *SlackProvider) Validate() error {
|
||||
if p.settings.WebhookURL == "" {
|
||||
return fmt.Errorf("Slack webhook URL is required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *SlackProvider) Send(msg *notification.NotificationMessage) error {
|
||||
if err := p.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
color := "good" // Green for UP
|
||||
if msg.Status == "DOWN" {
|
||||
color = "danger" // Red for DOWN
|
||||
}
|
||||
|
||||
fields := []map[string]string{
|
||||
{
|
||||
"title": "Monitor",
|
||||
"value": msg.MonitorName,
|
||||
"short": "true",
|
||||
},
|
||||
{
|
||||
"title": "Status",
|
||||
"value": msg.Status,
|
||||
"short": "true",
|
||||
},
|
||||
}
|
||||
|
||||
if msg.MonitorURL != "" {
|
||||
fields = append(fields, map[string]string{
|
||||
"title": "URL",
|
||||
"value": msg.MonitorURL,
|
||||
"short": "false",
|
||||
})
|
||||
}
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"attachments": []map[string]interface{}{
|
||||
{
|
||||
"color": color,
|
||||
"title": msg.Title,
|
||||
"text": msg.Body,
|
||||
"fields": fields,
|
||||
"timestamp": msg.Timestamp.Unix(),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if p.settings.Username != "" {
|
||||
payload["username"] = p.settings.Username
|
||||
}
|
||||
if p.settings.Channel != "" {
|
||||
payload["channel"] = p.settings.Channel
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal payload: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", p.settings.WebhookURL, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("slack webhook request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
return fmt.Errorf("slack webhook returned status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,82 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/henrygd/beszel/internal/entities/notification"
|
||||
)
|
||||
|
||||
type TelegramProvider struct {
|
||||
settings notification.TelegramSettings
|
||||
}
|
||||
|
||||
func NewTelegramProvider(settings notification.TelegramSettings) *TelegramProvider {
|
||||
return &TelegramProvider{settings: settings}
|
||||
}
|
||||
|
||||
func (p *TelegramProvider) Validate() error {
|
||||
if p.settings.BotToken == "" {
|
||||
return fmt.Errorf("Telegram bot token is required")
|
||||
}
|
||||
if p.settings.ChatID == "" {
|
||||
return fmt.Errorf("Telegram chat ID is required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *TelegramProvider) Send(msg *notification.NotificationMessage) error {
|
||||
if err := p.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
icon := "✅"
|
||||
if msg.Status == "DOWN" {
|
||||
icon = "❌"
|
||||
}
|
||||
|
||||
text := fmt.Sprintf("%s *%s*\n\n"+
|
||||
"*Monitor:* %s\n"+
|
||||
"*Status:* %s\n"+
|
||||
"*Time:* %s",
|
||||
icon,
|
||||
msg.Title,
|
||||
msg.MonitorName,
|
||||
msg.Status,
|
||||
msg.Timestamp.Format("2006-01-02 15:04:05"),
|
||||
)
|
||||
|
||||
if msg.MonitorURL != "" {
|
||||
text += fmt.Sprintf("\n*URL:* %s", msg.MonitorURL)
|
||||
}
|
||||
|
||||
if msg.Ping > 0 {
|
||||
text += fmt.Sprintf("\n*Response Time:* %dms", msg.Ping)
|
||||
}
|
||||
|
||||
if msg.Message != "" {
|
||||
text += fmt.Sprintf("\n\n*Message:* %s", msg.Message)
|
||||
}
|
||||
|
||||
apiURL := fmt.Sprintf("https://api.telegram.org/bot%s/sendMessage", p.settings.BotToken)
|
||||
|
||||
data := url.Values{}
|
||||
data.Set("chat_id", p.settings.ChatID)
|
||||
data.Set("text", text)
|
||||
data.Set("parse_mode", "Markdown")
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.PostForm(apiURL, data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("telegram API request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
return fmt.Errorf("telegram API returned status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,83 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/henrygd/beszel/internal/entities/notification"
|
||||
)
|
||||
|
||||
type WebhookProvider struct {
|
||||
settings notification.WebhookSettings
|
||||
}
|
||||
|
||||
func NewWebhookProvider(settings notification.WebhookSettings) *WebhookProvider {
|
||||
return &WebhookProvider{settings: settings}
|
||||
}
|
||||
|
||||
func (p *WebhookProvider) Validate() error {
|
||||
if p.settings.URL == "" {
|
||||
return fmt.Errorf("webhook URL is required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *WebhookProvider) Send(msg *notification.NotificationMessage) error {
|
||||
if err := p.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
method := p.settings.Method
|
||||
if method == "" {
|
||||
method = "POST"
|
||||
}
|
||||
|
||||
body := p.formatBody(msg)
|
||||
req, err := http.NewRequest(method, p.settings.URL, bytes.NewBufferString(body))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
for k, v := range p.settings.Headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
if req.Header.Get("Content-Type") == "" {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("webhook request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
return fmt.Errorf("webhook returned status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *WebhookProvider) formatBody(msg *notification.NotificationMessage) string {
|
||||
if p.settings.BodyTemplate != "" {
|
||||
return p.settings.BodyTemplate
|
||||
}
|
||||
|
||||
data := map[string]interface{}{
|
||||
"title": msg.Title,
|
||||
"body": msg.Body,
|
||||
"monitor": msg.MonitorName,
|
||||
"url": msg.MonitorURL,
|
||||
"status": msg.Status,
|
||||
"timestamp": msg.Timestamp,
|
||||
"ping": msg.Ping,
|
||||
"message": msg.Message,
|
||||
}
|
||||
|
||||
b, _ := json.Marshal(data)
|
||||
return string(b)
|
||||
}
|
||||
@@ -0,0 +1,203 @@
|
||||
package notifications
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
webpush "github.com/SherClockHolmes/webpush-go"
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
)
|
||||
|
||||
// PushNotification represents a push notification message
|
||||
type PushNotification struct {
|
||||
Title string `json:"title"`
|
||||
Body string `json:"body"`
|
||||
Icon string `json:"icon,omitempty"`
|
||||
Badge string `json:"badge,omitempty"`
|
||||
Image string `json:"image,omitempty"`
|
||||
Tag string `json:"tag,omitempty"`
|
||||
Data map[string]string `json:"data,omitempty"`
|
||||
Actions []Action `json:"actions,omitempty"`
|
||||
RequireInteraction bool `json:"requireInteraction,omitempty"`
|
||||
}
|
||||
|
||||
// Action represents a notification action button
|
||||
type Action struct {
|
||||
Action string `json:"action"`
|
||||
Title string `json:"title"`
|
||||
Icon string `json:"icon,omitempty"`
|
||||
}
|
||||
|
||||
// PushSubscription represents a browser push subscription
|
||||
type PushSubscription struct {
|
||||
ID string `json:"id" db:"id"`
|
||||
UserID string `json:"user" db:"user"`
|
||||
Endpoint string `json:"endpoint" db:"endpoint"`
|
||||
P256dh string `json:"p256dh" db:"p256dh"`
|
||||
Auth string `json:"auth" db:"auth"`
|
||||
Created time.Time `json:"created" db:"created"`
|
||||
}
|
||||
|
||||
// PushService handles push notifications
|
||||
type PushService struct {
|
||||
app core.App
|
||||
vapidPriv string
|
||||
vapidPub string
|
||||
}
|
||||
|
||||
// NewPushService creates a new push notification service
|
||||
func NewPushService(app core.App) *PushService {
|
||||
// Generate or load VAPID keys
|
||||
// In production, load from BESZEL_VAPID_PRIVATE_KEY env var
|
||||
privKey, pubKey := generateVAPIDKeys()
|
||||
|
||||
return &PushService{
|
||||
app: app,
|
||||
vapidPriv: privKey,
|
||||
vapidPub: pubKey,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterSubscription registers a push subscription for a user
|
||||
func (s *PushService) RegisterSubscription(userID string, sub *webpush.Subscription) error {
|
||||
collection, err := s.app.FindCollectionByNameOrId("push_subscriptions")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check if subscription already exists
|
||||
existing, _ := s.app.FindFirstRecordByFilter("push_subscriptions",
|
||||
"user = {:user} && endpoint = {:endpoint}",
|
||||
map[string]interface{}{"user": userID, "endpoint": sub.Endpoint})
|
||||
|
||||
if existing != nil {
|
||||
// Update existing
|
||||
existing.Set("p256dh", sub.Keys.P256dh)
|
||||
existing.Set("auth", sub.Keys.Auth)
|
||||
return s.app.Save(existing)
|
||||
}
|
||||
|
||||
// Create new subscription
|
||||
record := core.NewRecord(collection)
|
||||
record.Set("user", userID)
|
||||
record.Set("endpoint", sub.Endpoint)
|
||||
record.Set("p256dh", sub.Keys.P256dh)
|
||||
record.Set("auth", sub.Keys.Auth)
|
||||
record.Set("created", time.Now())
|
||||
|
||||
return s.app.Save(record)
|
||||
}
|
||||
|
||||
// UnregisterSubscription removes a push subscription
|
||||
func (s *PushService) UnregisterSubscription(userID string, endpoint string) error {
|
||||
record, err := s.app.FindFirstRecordByFilter("push_subscriptions",
|
||||
"user = {:user} && endpoint = {:endpoint}",
|
||||
map[string]interface{}{"user": userID, "endpoint": endpoint})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.app.Delete(record)
|
||||
}
|
||||
|
||||
// SendNotification sends a push notification to a user
|
||||
func (s *PushService) SendNotification(userID string, notification *PushNotification) error {
|
||||
// Get all subscriptions for user
|
||||
records, err := s.app.FindAllRecords("push_subscriptions",
|
||||
dbx.NewExp("user = {:user}", dbx.Params{"user": userID}),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(notification)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, record := range records {
|
||||
sub := &webpush.Subscription{
|
||||
Endpoint: record.GetString("endpoint"),
|
||||
Keys: webpush.Keys{
|
||||
P256dh: record.GetString("p256dh"),
|
||||
Auth: record.GetString("auth"),
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := webpush.SendNotification(payload, sub, &webpush.Options{
|
||||
Subscriber: "beszel@localhost",
|
||||
VAPIDPublicKey: s.vapidPub,
|
||||
VAPIDPrivateKey: s.vapidPriv,
|
||||
TTL: 30,
|
||||
})
|
||||
if err != nil {
|
||||
// Log error but continue trying other subscriptions
|
||||
continue
|
||||
}
|
||||
resp.Body.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// BroadcastNotification sends notification to all users
|
||||
func (s *PushService) BroadcastNotification(notification *PushNotification) error {
|
||||
records, err := s.app.FindAllRecords("push_subscriptions")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(notification)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, record := range records {
|
||||
sub := &webpush.Subscription{
|
||||
Endpoint: record.GetString("endpoint"),
|
||||
Keys: webpush.Keys{
|
||||
P256dh: record.GetString("p256dh"),
|
||||
Auth: record.GetString("auth"),
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := webpush.SendNotification(payload, sub, &webpush.Options{
|
||||
Subscriber: "beszel@localhost",
|
||||
VAPIDPublicKey: s.vapidPub,
|
||||
VAPIDPrivateKey: s.vapidPriv,
|
||||
TTL: 30,
|
||||
})
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
resp.Body.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// generateVAPIDKeys generates or loads VAPID keys for web push
|
||||
func generateVAPIDKeys() (privateKey, publicKey string) {
|
||||
// Check for environment variable first
|
||||
if envKey := os.Getenv("BESZEL_VAPID_PRIVATE_KEY"); envKey != "" {
|
||||
// If private key provided, we need to derive public key
|
||||
// For now, return empty public key - will be handled by webpush lib
|
||||
return envKey, ""
|
||||
}
|
||||
|
||||
// Generate new VAPID key pair
|
||||
privKey, pubKey, err := webpush.GenerateVAPIDKeys()
|
||||
if err != nil {
|
||||
// Return empty keys if generation fails
|
||||
return "", ""
|
||||
}
|
||||
|
||||
return privKey, pubKey
|
||||
}
|
||||
|
||||
// GetVAPIDPublicKey returns the VAPID public key for client subscription
|
||||
func (s *PushService) GetVAPIDPublicKey() string {
|
||||
return s.vapidPub
|
||||
}
|
||||
@@ -0,0 +1,42 @@
|
||||
package hub
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/henrygd/beszel"
|
||||
"github.com/henrygd/beszel/internal/hub/utils"
|
||||
)
|
||||
|
||||
// PublicAppInfo defines the structure of the public app information that will be injected into the HTML
|
||||
type PublicAppInfo struct {
|
||||
BASE_PATH string
|
||||
HUB_VERSION string
|
||||
HUB_URL string
|
||||
OAUTH_DISABLE_POPUP bool `json:"OAUTH_DISABLE_POPUP,omitempty"`
|
||||
}
|
||||
|
||||
// modifyIndexHTML injects the public app information into the index.html content
|
||||
func modifyIndexHTML(hub *Hub, html []byte) string {
|
||||
info := getPublicAppInfo(hub)
|
||||
content, err := json.Marshal(info)
|
||||
if err != nil {
|
||||
return string(html)
|
||||
}
|
||||
htmlContent := strings.ReplaceAll(string(html), "./", info.BASE_PATH)
|
||||
return strings.Replace(htmlContent, "\"{info}\"", string(content), 1)
|
||||
}
|
||||
|
||||
func getPublicAppInfo(hub *Hub) PublicAppInfo {
|
||||
parsedURL, _ := url.Parse(hub.appURL)
|
||||
info := PublicAppInfo{
|
||||
BASE_PATH: strings.TrimSuffix(parsedURL.Path, "/") + "/",
|
||||
HUB_VERSION: beszel.Version,
|
||||
HUB_URL: hub.appURL,
|
||||
}
|
||||
if val, _ := utils.GetEnv("OAUTH_DISABLE_POPUP"); val == "true" {
|
||||
info.OAUTH_DISABLE_POPUP = true
|
||||
}
|
||||
return info
|
||||
}
|
||||
@@ -0,0 +1,65 @@
|
||||
//go:build development
|
||||
|
||||
package hub
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/osutils"
|
||||
)
|
||||
|
||||
// Wraps http.RoundTripper to modify dev proxy HTML responses
|
||||
type responseModifier struct {
|
||||
transport http.RoundTripper
|
||||
hub *Hub
|
||||
}
|
||||
|
||||
func (rm *responseModifier) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
resp, err := rm.transport.RoundTrip(req)
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
// Only modify HTML responses
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
if !strings.Contains(contentType, "text/html") {
|
||||
return resp, nil
|
||||
}
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
resp.Body.Close()
|
||||
// Create a new response with the modified body
|
||||
modifiedBody := modifyIndexHTML(rm.hub, body)
|
||||
resp.Body = io.NopCloser(strings.NewReader(modifiedBody))
|
||||
resp.ContentLength = int64(len(modifiedBody))
|
||||
resp.Header.Set("Content-Length", fmt.Sprintf("%d", len(modifiedBody)))
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// startServer sets up the development server for Beszel
|
||||
func (h *Hub) startServer(se *core.ServeEvent) error {
|
||||
proxy := httputil.NewSingleHostReverseProxy(&url.URL{
|
||||
Scheme: "http",
|
||||
Host: "localhost:5173",
|
||||
})
|
||||
|
||||
proxy.Transport = &responseModifier{
|
||||
transport: http.DefaultTransport,
|
||||
hub: h,
|
||||
}
|
||||
|
||||
se.Router.GET("/{path...}", func(e *core.RequestEvent) error {
|
||||
proxy.ServeHTTP(e.Response, e.Request)
|
||||
return nil
|
||||
})
|
||||
_ = osutils.LaunchURL(h.appURL)
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,42 @@
|
||||
//go:build !development
|
||||
|
||||
package hub
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/henrygd/beszel/internal/hub/utils"
|
||||
"github.com/henrygd/beszel/internal/site"
|
||||
|
||||
"github.com/pocketbase/pocketbase/apis"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
)
|
||||
|
||||
// startServer sets up the production server for Beszel
|
||||
func (h *Hub) startServer(se *core.ServeEvent) error {
|
||||
indexFile, _ := fs.ReadFile(site.DistDirFS, "index.html")
|
||||
html := modifyIndexHTML(h, indexFile)
|
||||
// set up static asset serving
|
||||
staticPaths := [2]string{"/static/", "/assets/"}
|
||||
serveStatic := apis.Static(site.DistDirFS, false)
|
||||
// get CSP configuration
|
||||
csp, cspExists := utils.GetEnv("CSP")
|
||||
// add route
|
||||
se.Router.GET("/{path...}", func(e *core.RequestEvent) error {
|
||||
// serve static assets if path is in staticPaths
|
||||
for i := range staticPaths {
|
||||
if strings.Contains(e.Request.URL.Path, staticPaths[i]) {
|
||||
e.Response.Header().Set("Cache-Control", "public, max-age=2592000")
|
||||
return serveStatic(e)
|
||||
}
|
||||
}
|
||||
if cspExists {
|
||||
e.Response.Header().Del("X-Frame-Options")
|
||||
e.Response.Header().Set("Content-Security-Policy", csp)
|
||||
}
|
||||
return e.HTML(http.StatusOK, html)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,434 @@
|
||||
package statuspages
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/henrygd/beszel/internal/entities/monitor"
|
||||
"github.com/henrygd/beszel/internal/entities/statuspage"
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/apis"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
)
|
||||
|
||||
// APIHandler handles status page API requests
|
||||
type APIHandler struct {
|
||||
app core.App
|
||||
}
|
||||
|
||||
// NewAPIHandler creates a new status page API handler
|
||||
func NewAPIHandler(app core.App) *APIHandler {
|
||||
return &APIHandler{app: app}
|
||||
}
|
||||
|
||||
// RegisterRoutes registers status page API routes
|
||||
func (h *APIHandler) RegisterRoutes(se *core.ServeEvent) {
|
||||
// Public status page (no auth required)
|
||||
se.Router.GET("/status/:slug", h.getPublicStatusPage)
|
||||
|
||||
// Protected routes
|
||||
api := se.Router.Group("/api/beszel/status-pages")
|
||||
api.Bind(apis.RequireAuth())
|
||||
|
||||
api.GET("/", h.listStatusPages)
|
||||
api.POST("/", h.createStatusPage)
|
||||
api.GET("/{id}", h.getStatusPage)
|
||||
api.PATCH("/{id}", h.updateStatusPage)
|
||||
api.DELETE("/{id}", h.deleteStatusPage)
|
||||
api.POST("/{id}/monitors", h.addMonitor)
|
||||
api.DELETE("/{id}/monitors/{monitorId}", h.removeMonitor)
|
||||
api.GET("/{id}/monitors", h.listMonitors)
|
||||
}
|
||||
|
||||
// listStatusPages lists all status pages for the authenticated user
|
||||
func (h *APIHandler) listStatusPages(e *core.RequestEvent) error {
|
||||
authRecord := e.Auth
|
||||
if authRecord == nil {
|
||||
return e.UnauthorizedError("unauthorized", nil)
|
||||
}
|
||||
|
||||
records, err := h.app.FindAllRecords("status_pages",
|
||||
dbx.NewExp("user = {:user}", dbx.Params{"user": authRecord.Id}),
|
||||
)
|
||||
if err != nil {
|
||||
return e.InternalServerError("failed to fetch status pages", err)
|
||||
}
|
||||
|
||||
pages := make([]statuspage.StatusPageResponse, 0, len(records))
|
||||
for _, record := range records {
|
||||
pages = append(pages, h.recordToResponse(record))
|
||||
}
|
||||
|
||||
return e.JSON(http.StatusOK, pages)
|
||||
}
|
||||
|
||||
// createStatusPage creates a new status page
|
||||
func (h *APIHandler) createStatusPage(e *core.RequestEvent) error {
|
||||
authRecord := e.Auth
|
||||
if authRecord == nil {
|
||||
return e.UnauthorizedError("unauthorized", nil)
|
||||
}
|
||||
|
||||
var req statuspage.CreateStatusPageRequest
|
||||
if err := json.NewDecoder(e.Request.Body).Decode(&req); err != nil {
|
||||
return e.BadRequestError("invalid request body", err)
|
||||
}
|
||||
|
||||
if req.Name == "" || req.Slug == "" {
|
||||
return e.BadRequestError("name and slug are required", nil)
|
||||
}
|
||||
|
||||
// Check if slug is unique
|
||||
existing, _ := h.app.FindFirstRecordByFilter("status_pages", "slug = {:slug}",
|
||||
dbx.Params{"slug": req.Slug})
|
||||
if existing != nil {
|
||||
return e.BadRequestError("slug already exists", nil)
|
||||
}
|
||||
|
||||
collection, err := h.app.FindCollectionByNameOrId("status_pages")
|
||||
if err != nil {
|
||||
return e.InternalServerError("failed to find collection", err)
|
||||
}
|
||||
|
||||
record := core.NewRecord(collection)
|
||||
record.Set("name", req.Name)
|
||||
record.Set("slug", strings.ToLower(req.Slug))
|
||||
record.Set("title", req.Title)
|
||||
record.Set("description", req.Description)
|
||||
record.Set("logo", req.Logo)
|
||||
record.Set("favicon", req.Favicon)
|
||||
record.Set("theme", statuspage.ValidateTheme(req.Theme))
|
||||
record.Set("custom_css", req.CustomCSS)
|
||||
record.Set("public", req.Public)
|
||||
record.Set("show_uptime", req.ShowUptime)
|
||||
record.Set("user", authRecord.Id)
|
||||
|
||||
if err := h.app.Save(record); err != nil {
|
||||
return e.InternalServerError("failed to create status page", err)
|
||||
}
|
||||
|
||||
return e.JSON(http.StatusCreated, h.recordToResponse(record))
|
||||
}
|
||||
|
||||
// getStatusPage gets a single status page
|
||||
func (h *APIHandler) getStatusPage(e *core.RequestEvent) error {
|
||||
authRecord := e.Auth
|
||||
if authRecord == nil {
|
||||
return e.UnauthorizedError("unauthorized", nil)
|
||||
}
|
||||
|
||||
id := e.Request.PathValue("id")
|
||||
record, err := h.app.FindRecordById("status_pages", id)
|
||||
if err != nil {
|
||||
return e.NotFoundError("status page not found", err)
|
||||
}
|
||||
|
||||
if record.GetString("user") != authRecord.Id {
|
||||
return e.ForbiddenError("not authorized", nil)
|
||||
}
|
||||
|
||||
return e.JSON(http.StatusOK, h.recordToResponse(record))
|
||||
}
|
||||
|
||||
// updateStatusPage updates a status page
|
||||
func (h *APIHandler) updateStatusPage(e *core.RequestEvent) error {
|
||||
authRecord := e.Auth
|
||||
if authRecord == nil {
|
||||
return e.UnauthorizedError("unauthorized", nil)
|
||||
}
|
||||
|
||||
id := e.Request.PathValue("id")
|
||||
record, err := h.app.FindRecordById("status_pages", id)
|
||||
if err != nil {
|
||||
return e.NotFoundError("status page not found", err)
|
||||
}
|
||||
|
||||
if record.GetString("user") != authRecord.Id {
|
||||
return e.ForbiddenError("not authorized", nil)
|
||||
}
|
||||
|
||||
var req statuspage.UpdateStatusPageRequest
|
||||
if err := json.NewDecoder(e.Request.Body).Decode(&req); err != nil {
|
||||
return e.BadRequestError("invalid request body", err)
|
||||
}
|
||||
|
||||
if req.Name != nil {
|
||||
record.Set("name", *req.Name)
|
||||
}
|
||||
if req.Title != nil {
|
||||
record.Set("title", *req.Title)
|
||||
}
|
||||
if req.Description != nil {
|
||||
record.Set("description", *req.Description)
|
||||
}
|
||||
if req.Logo != nil {
|
||||
record.Set("logo", *req.Logo)
|
||||
}
|
||||
if req.Favicon != nil {
|
||||
record.Set("favicon", *req.Favicon)
|
||||
}
|
||||
if req.Theme != nil {
|
||||
record.Set("theme", statuspage.ValidateTheme(*req.Theme))
|
||||
}
|
||||
if req.CustomCSS != nil {
|
||||
record.Set("custom_css", *req.CustomCSS)
|
||||
}
|
||||
if req.Public != nil {
|
||||
record.Set("public", *req.Public)
|
||||
}
|
||||
if req.ShowUptime != nil {
|
||||
record.Set("show_uptime", *req.ShowUptime)
|
||||
}
|
||||
|
||||
if err := h.app.Save(record); err != nil {
|
||||
return e.InternalServerError("failed to update status page", err)
|
||||
}
|
||||
|
||||
return e.JSON(http.StatusOK, h.recordToResponse(record))
|
||||
}
|
||||
|
||||
// deleteStatusPage deletes a status page
|
||||
func (h *APIHandler) deleteStatusPage(e *core.RequestEvent) error {
|
||||
authRecord := e.Auth
|
||||
if authRecord == nil {
|
||||
return e.UnauthorizedError("unauthorized", nil)
|
||||
}
|
||||
|
||||
id := e.Request.PathValue("id")
|
||||
record, err := h.app.FindRecordById("status_pages", id)
|
||||
if err != nil {
|
||||
return e.NotFoundError("status page not found", err)
|
||||
}
|
||||
|
||||
if record.GetString("user") != authRecord.Id {
|
||||
return e.ForbiddenError("not authorized", nil)
|
||||
}
|
||||
|
||||
if err := h.app.Delete(record); err != nil {
|
||||
return e.InternalServerError("failed to delete status page", err)
|
||||
}
|
||||
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// addMonitor adds a monitor to a status page
|
||||
func (h *APIHandler) addMonitor(e *core.RequestEvent) error {
|
||||
authRecord := e.Auth
|
||||
if authRecord == nil {
|
||||
return e.UnauthorizedError("unauthorized", nil)
|
||||
}
|
||||
|
||||
statusPageID := e.Request.PathValue("id")
|
||||
statusPage, err := h.app.FindRecordById("status_pages", statusPageID)
|
||||
if err != nil {
|
||||
return e.NotFoundError("status page not found", err)
|
||||
}
|
||||
|
||||
if statusPage.GetString("user") != authRecord.Id {
|
||||
return e.ForbiddenError("not authorized", nil)
|
||||
}
|
||||
|
||||
var req statuspage.StatusPageMonitorRequest
|
||||
if err := json.NewDecoder(e.Request.Body).Decode(&req); err != nil {
|
||||
return e.BadRequestError("invalid request body", err)
|
||||
}
|
||||
|
||||
// Verify monitor exists and belongs to user
|
||||
monitorRecord, err := h.app.FindRecordById("monitors", req.MonitorID)
|
||||
if err != nil {
|
||||
return e.NotFoundError("monitor not found", err)
|
||||
}
|
||||
|
||||
if monitorRecord.GetString("user") != authRecord.Id {
|
||||
return e.ForbiddenError("not authorized for this monitor", nil)
|
||||
}
|
||||
|
||||
collection, err := h.app.FindCollectionByNameOrId("status_page_monitors")
|
||||
if err != nil {
|
||||
return e.InternalServerError("failed to find collection", err)
|
||||
}
|
||||
|
||||
record := core.NewRecord(collection)
|
||||
record.Set("status_page", statusPageID)
|
||||
record.Set("monitor", req.MonitorID)
|
||||
record.Set("display_name", req.DisplayName)
|
||||
record.Set("group", req.Group)
|
||||
record.Set("sort_order", req.SortOrder)
|
||||
record.Set("user", authRecord.Id)
|
||||
|
||||
if err := h.app.Save(record); err != nil {
|
||||
return e.InternalServerError("failed to add monitor", err)
|
||||
}
|
||||
|
||||
return e.JSON(http.StatusCreated, map[string]string{"status": "added"})
|
||||
}
|
||||
|
||||
// removeMonitor removes a monitor from a status page
|
||||
func (h *APIHandler) removeMonitor(e *core.RequestEvent) error {
|
||||
authRecord := e.Auth
|
||||
if authRecord == nil {
|
||||
return e.UnauthorizedError("unauthorized", nil)
|
||||
}
|
||||
|
||||
statusPageID := e.Request.PathValue("id")
|
||||
monitorID := e.Request.PathValue("monitorId")
|
||||
|
||||
// Find the link record
|
||||
records, err := h.app.FindAllRecords("status_page_monitors",
|
||||
dbx.HashExp{
|
||||
"status_page": statusPageID,
|
||||
"monitor": monitorID,
|
||||
"user": authRecord.Id,
|
||||
},
|
||||
)
|
||||
if err != nil || len(records) == 0 {
|
||||
return e.NotFoundError("monitor link not found", err)
|
||||
}
|
||||
|
||||
if err := h.app.Delete(records[0]); err != nil {
|
||||
return e.InternalServerError("failed to remove monitor", err)
|
||||
}
|
||||
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// listMonitors lists monitors on a status page
|
||||
func (h *APIHandler) listMonitors(e *core.RequestEvent) error {
|
||||
authRecord := e.Auth
|
||||
if authRecord == nil {
|
||||
return e.UnauthorizedError("unauthorized", nil)
|
||||
}
|
||||
|
||||
statusPageID := e.Request.PathValue("id")
|
||||
statusPage, err := h.app.FindRecordById("status_pages", statusPageID)
|
||||
if err != nil {
|
||||
return e.NotFoundError("status page not found", err)
|
||||
}
|
||||
|
||||
if statusPage.GetString("user") != authRecord.Id {
|
||||
return e.ForbiddenError("not authorized", nil)
|
||||
}
|
||||
|
||||
records, err := h.app.FindAllRecords("status_page_monitors",
|
||||
dbx.NewExp("status_page = {:statusPage}", dbx.Params{"statusPage": statusPageID}),
|
||||
)
|
||||
if err != nil {
|
||||
return e.InternalServerError("failed to fetch monitors", err)
|
||||
}
|
||||
|
||||
monitors := make([]map[string]interface{}, 0, len(records))
|
||||
for _, record := range records {
|
||||
monitors = append(monitors, map[string]interface{}{
|
||||
"id": record.Id,
|
||||
"monitor_id": record.GetString("monitor"),
|
||||
"display_name": record.GetString("display_name"),
|
||||
"group": record.GetString("group"),
|
||||
"sort_order": record.GetInt("sort_order"),
|
||||
})
|
||||
}
|
||||
|
||||
return e.JSON(http.StatusOK, monitors)
|
||||
}
|
||||
|
||||
// getPublicStatusPage gets a public status page by slug
|
||||
func (h *APIHandler) getPublicStatusPage(e *core.RequestEvent) error {
|
||||
slug := e.Request.PathValue("slug")
|
||||
|
||||
record, err := h.app.FindFirstRecordByFilter("status_pages", "slug = {:slug} && public = true",
|
||||
dbx.Params{"slug": slug})
|
||||
if err != nil {
|
||||
return e.NotFoundError("status page not found", err)
|
||||
}
|
||||
|
||||
// Build public status page
|
||||
publicPage := h.buildPublicStatusPage(record)
|
||||
|
||||
return e.JSON(http.StatusOK, publicPage)
|
||||
}
|
||||
|
||||
// buildPublicStatusPage builds a public status page from a record
|
||||
func (h *APIHandler) buildPublicStatusPage(record *core.Record) *statuspage.PublicStatusPage {
|
||||
statusPageID := record.Id
|
||||
|
||||
// Get linked monitors
|
||||
links, err := h.app.FindAllRecords("status_page_monitors",
|
||||
dbx.NewExp("status_page = {:statusPage}", dbx.Params{"statusPage": statusPageID}),
|
||||
)
|
||||
if err != nil {
|
||||
links = []*core.Record{}
|
||||
}
|
||||
|
||||
publicMonitors := make([]statuspage.PublicMonitorStatus, 0, len(links))
|
||||
overallStatus := statuspage.StatusOperational
|
||||
|
||||
for _, link := range links {
|
||||
monitorID := link.GetString("monitor")
|
||||
monitorRecord, err := h.app.FindRecordById("monitors", monitorID)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
status := monitorRecord.GetString("status")
|
||||
if status == string(monitor.StatusDown) && overallStatus == statuspage.StatusOperational {
|
||||
overallStatus = statuspage.StatusMajor
|
||||
}
|
||||
|
||||
uptimeStats := make(map[string]float64)
|
||||
if statsJSON := monitorRecord.GetString("uptime_stats"); statsJSON != "" {
|
||||
json.Unmarshal([]byte(statsJSON), &uptimeStats)
|
||||
}
|
||||
|
||||
publicMonitors = append(publicMonitors, statuspage.PublicMonitorStatus{
|
||||
ID: monitorID,
|
||||
Name: monitorRecord.GetString("name"),
|
||||
DisplayName: link.GetString("display_name"),
|
||||
Group: link.GetString("group"),
|
||||
Status: status,
|
||||
Uptime24h: uptimeStats["24h"],
|
||||
Uptime7d: uptimeStats["7d"],
|
||||
Uptime30d: uptimeStats["30d"],
|
||||
LastCheck: monitorRecord.GetDateTime("last_check").Time(),
|
||||
})
|
||||
}
|
||||
|
||||
return &statuspage.PublicStatusPage{
|
||||
ID: record.Id,
|
||||
Name: record.GetString("name"),
|
||||
Title: record.GetString("title"),
|
||||
Description: record.GetString("description"),
|
||||
Logo: record.GetString("logo"),
|
||||
Favicon: record.GetString("favicon"),
|
||||
Theme: record.GetString("theme"),
|
||||
CustomCSS: record.GetString("custom_css"),
|
||||
Monitors: publicMonitors,
|
||||
OverallStatus: overallStatus,
|
||||
UpdatedAt: record.GetDateTime("updated").Time(),
|
||||
}
|
||||
}
|
||||
|
||||
// recordToResponse converts a record to a response
|
||||
func (h *APIHandler) recordToResponse(record *core.Record) statuspage.StatusPageResponse {
|
||||
// Count monitors
|
||||
count := 0
|
||||
links, _ := h.app.FindAllRecords("status_page_monitors",
|
||||
dbx.NewExp("status_page = {:statusPage}", dbx.Params{"statusPage": record.Id}),
|
||||
)
|
||||
count = len(links)
|
||||
|
||||
return statuspage.StatusPageResponse{
|
||||
ID: record.Id,
|
||||
Name: record.GetString("name"),
|
||||
Slug: record.GetString("slug"),
|
||||
Title: record.GetString("title"),
|
||||
Description: record.GetString("description"),
|
||||
Logo: record.GetString("logo"),
|
||||
Favicon: record.GetString("favicon"),
|
||||
Theme: record.GetString("theme"),
|
||||
Public: record.GetBool("public"),
|
||||
ShowUptime: record.GetBool("show_uptime"),
|
||||
MonitorCount: count,
|
||||
Created: record.GetDateTime("created").String(),
|
||||
Updated: record.GetDateTime("updated").String(),
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,785 @@
|
||||
package systems
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"math/rand"
|
||||
"net"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/henrygd/beszel/internal/common"
|
||||
"github.com/henrygd/beszel/internal/hub/transport"
|
||||
"github.com/henrygd/beszel/internal/hub/utils"
|
||||
"github.com/henrygd/beszel/internal/hub/ws"
|
||||
|
||||
"github.com/henrygd/beszel/internal/entities/container"
|
||||
"github.com/henrygd/beszel/internal/entities/smart"
|
||||
"github.com/henrygd/beszel/internal/entities/system"
|
||||
"github.com/henrygd/beszel/internal/entities/systemd"
|
||||
|
||||
"github.com/henrygd/beszel"
|
||||
|
||||
"github.com/blang/semver"
|
||||
"github.com/fxamacker/cbor/v2"
|
||||
"github.com/lxzan/gws"
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type System struct {
|
||||
Id string `db:"id"`
|
||||
Host string `db:"host"`
|
||||
Port string `db:"port"`
|
||||
Status string `db:"status"`
|
||||
manager *SystemManager // Manager that this system belongs to
|
||||
client *ssh.Client // SSH client for fetching data
|
||||
sshTransport *transport.SSHTransport // SSH transport for requests
|
||||
data *system.CombinedData // system data from agent
|
||||
ctx context.Context // Context for stopping the updater
|
||||
cancel context.CancelFunc // Stops and removes system from updater
|
||||
WsConn *ws.WsConn // Handler for agent WebSocket connection
|
||||
agentVersion semver.Version // Agent version
|
||||
updateTicker *time.Ticker // Ticker for updating the system
|
||||
detailsFetched atomic.Bool // True if static system details have been fetched and saved
|
||||
smartFetching atomic.Bool // True if SMART devices are currently being fetched
|
||||
smartInterval time.Duration // Interval for periodic SMART data updates
|
||||
}
|
||||
|
||||
func (sm *SystemManager) NewSystem(systemId string) *System {
|
||||
system := &System{
|
||||
Id: systemId,
|
||||
data: &system.CombinedData{},
|
||||
}
|
||||
system.ctx, system.cancel = system.getContext()
|
||||
return system
|
||||
}
|
||||
|
||||
// StartUpdater starts the system updater.
|
||||
// It first fetches the data from the agent then updates the records.
|
||||
// If the data is not found or the system is down, it sets the system down.
|
||||
func (sys *System) StartUpdater() {
|
||||
// Channel that can be used to set the system down. Currently only used to
|
||||
// allow a short delay for reconnection after websocket connection is closed.
|
||||
var downChan chan struct{}
|
||||
|
||||
// Add random jitter to first WebSocket connection to prevent
|
||||
// clustering if all agents are started at the same time.
|
||||
// SSH connections during hub startup are already staggered.
|
||||
var jitter <-chan time.Time
|
||||
if sys.WsConn != nil {
|
||||
jitter = getJitter()
|
||||
// use the websocket connection's down channel to set the system down
|
||||
downChan = sys.WsConn.DownChan
|
||||
} else {
|
||||
// if the system does not have a websocket connection, wait before updating
|
||||
// to allow the agent to connect via websocket (makes sure fingerprint is set).
|
||||
time.Sleep(11 * time.Second)
|
||||
}
|
||||
|
||||
// update immediately if system is not paused (only for ws connections)
|
||||
// we'll wait a minute before connecting via SSH to prioritize ws connections
|
||||
if sys.Status != paused && sys.ctx.Err() == nil {
|
||||
if err := sys.update(); err != nil {
|
||||
_ = sys.setDown(err)
|
||||
}
|
||||
}
|
||||
|
||||
sys.updateTicker = time.NewTicker(time.Duration(interval) * time.Millisecond)
|
||||
// Go 1.23+ will automatically stop the ticker when the system is garbage collected, however we seem to need this or testing/synctest will block even if calling runtime.GC()
|
||||
defer sys.updateTicker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-sys.ctx.Done():
|
||||
return
|
||||
case <-sys.updateTicker.C:
|
||||
if err := sys.update(); err != nil {
|
||||
_ = sys.setDown(err)
|
||||
}
|
||||
case <-downChan:
|
||||
sys.WsConn = nil
|
||||
downChan = nil
|
||||
_ = sys.setDown(nil)
|
||||
case <-jitter:
|
||||
sys.updateTicker.Reset(time.Duration(interval) * time.Millisecond)
|
||||
if err := sys.update(); err != nil {
|
||||
_ = sys.setDown(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// update updates the system data and records.
|
||||
func (sys *System) update() error {
|
||||
if sys.Status == paused {
|
||||
sys.handlePaused()
|
||||
return nil
|
||||
}
|
||||
options := common.DataRequestOptions{
|
||||
CacheTimeMs: uint16(interval),
|
||||
}
|
||||
// fetch system details if not already fetched
|
||||
if !sys.detailsFetched.Load() {
|
||||
options.IncludeDetails = true
|
||||
}
|
||||
|
||||
data, err := sys.fetchDataFromAgent(options)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// ensure deprecated fields from older agents are migrated to current fields
|
||||
migrateDeprecatedFields(data, !sys.detailsFetched.Load())
|
||||
|
||||
// create system records
|
||||
_, err = sys.createRecords(data)
|
||||
|
||||
// if details were included and fetched successfully, mark details as fetched and update smart interval if set by agent
|
||||
if err == nil && data.Details != nil {
|
||||
sys.detailsFetched.Store(true)
|
||||
// update smart interval if it's set on the agent side
|
||||
if data.Details.SmartInterval > 0 {
|
||||
sys.smartInterval = data.Details.SmartInterval
|
||||
sys.manager.hub.Logger().Info("SMART interval updated from agent details", "system", sys.Id, "interval", sys.smartInterval.String())
|
||||
// make sure we reset expiration of lastFetch to remain as long as the new smart interval
|
||||
// to prevent premature expiration leading to new fetch if interval is different.
|
||||
sys.manager.smartFetchMap.UpdateExpiration(sys.Id, sys.smartInterval+time.Minute)
|
||||
}
|
||||
}
|
||||
|
||||
// Fetch and save SMART devices when system first comes online or at intervals
|
||||
if backgroundSmartFetchEnabled() && sys.detailsFetched.Load() {
|
||||
if sys.smartInterval <= 0 {
|
||||
sys.smartInterval = time.Hour
|
||||
}
|
||||
if sys.shouldFetchSmart() && sys.smartFetching.CompareAndSwap(false, true) {
|
||||
sys.manager.hub.Logger().Info("SMART fetch", "system", sys.Id, "interval", sys.smartInterval.String())
|
||||
go func() {
|
||||
defer sys.smartFetching.Store(false)
|
||||
_ = sys.FetchAndSaveSmartDevices()
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (sys *System) handlePaused() {
|
||||
if sys.WsConn == nil {
|
||||
// if the system is paused and there's no websocket connection, remove the system
|
||||
_ = sys.manager.RemoveSystem(sys.Id)
|
||||
} else {
|
||||
// Send a ping to the agent to keep the connection alive if the system is paused
|
||||
if err := sys.WsConn.Ping(); err != nil {
|
||||
sys.manager.hub.Logger().Warn("Failed to ping agent", "system", sys.Id, "err", err)
|
||||
_ = sys.manager.RemoveSystem(sys.Id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// createRecords updates the system record and adds system_stats and container_stats records
|
||||
func (sys *System) createRecords(data *system.CombinedData) (*core.Record, error) {
|
||||
systemRecord, err := sys.getRecord(sys.manager.hub)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
hub := sys.manager.hub
|
||||
err = hub.RunInTransaction(func(txApp core.App) error {
|
||||
// add system_stats record
|
||||
systemStatsCollection, err := txApp.FindCachedCollectionByNameOrId("system_stats")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
systemStatsRecord := core.NewRecord(systemStatsCollection)
|
||||
systemStatsRecord.Set("system", systemRecord.Id)
|
||||
systemStatsRecord.Set("stats", data.Stats)
|
||||
systemStatsRecord.Set("type", "1m")
|
||||
if err := txApp.SaveNoValidate(systemStatsRecord); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// add containers and container_stats records
|
||||
if len(data.Containers) > 0 {
|
||||
if data.Containers[0].Id != "" {
|
||||
if err := createContainerRecords(txApp, data.Containers, sys.Id); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
containerStatsCollection, err := txApp.FindCachedCollectionByNameOrId("container_stats")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
containerStatsRecord := core.NewRecord(containerStatsCollection)
|
||||
containerStatsRecord.Set("system", systemRecord.Id)
|
||||
containerStatsRecord.Set("stats", data.Containers)
|
||||
containerStatsRecord.Set("type", "1m")
|
||||
if err := txApp.SaveNoValidate(containerStatsRecord); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// add new systemd_stats record
|
||||
if len(data.SystemdServices) > 0 {
|
||||
if err := createSystemdStatsRecords(txApp, data.SystemdServices, sys.Id); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// add system details record
|
||||
if data.Details != nil {
|
||||
if err := createSystemDetailsRecord(txApp, data.Details, sys.Id); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// update system record (do this last because it triggers alerts and we need above records to be inserted first)
|
||||
systemRecord.Set("status", up)
|
||||
systemRecord.Set("info", data.Info)
|
||||
if err := txApp.SaveNoValidate(systemRecord); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
return systemRecord, err
|
||||
}
|
||||
|
||||
func createSystemDetailsRecord(app core.App, data *system.Details, systemId string) error {
|
||||
collectionName := "system_details"
|
||||
params := dbx.Params{
|
||||
"id": systemId,
|
||||
"system": systemId,
|
||||
"hostname": data.Hostname,
|
||||
"kernel": data.Kernel,
|
||||
"cores": data.Cores,
|
||||
"threads": data.Threads,
|
||||
"cpu": data.CpuModel,
|
||||
"os": data.Os,
|
||||
"os_name": data.OsName,
|
||||
"arch": data.Arch,
|
||||
"memory": data.MemoryTotal,
|
||||
"podman": data.Podman,
|
||||
"updated": time.Now().UTC(),
|
||||
}
|
||||
result, err := app.DB().Update(collectionName, params, dbx.HashExp{"id": systemId}).Execute()
|
||||
rowsAffected, _ := result.RowsAffected()
|
||||
if err != nil || rowsAffected == 0 {
|
||||
_, err = app.DB().Insert(collectionName, params).Execute()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func createSystemdStatsRecords(app core.App, data []*systemd.Service, systemId string) error {
|
||||
if len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
// shared params for all records
|
||||
params := dbx.Params{
|
||||
"system": systemId,
|
||||
"updated": time.Now().UTC().UnixMilli(),
|
||||
}
|
||||
|
||||
valueStrings := make([]string, 0, len(data))
|
||||
for i, service := range data {
|
||||
suffix := fmt.Sprintf("%d", i)
|
||||
valueStrings = append(valueStrings, fmt.Sprintf("({:id%[1]s}, {:system}, {:name%[1]s}, {:state%[1]s}, {:sub%[1]s}, {:cpu%[1]s}, {:cpuPeak%[1]s}, {:memory%[1]s}, {:memPeak%[1]s}, {:updated})", suffix))
|
||||
params["id"+suffix] = makeStableHashId(systemId, service.Name)
|
||||
params["name"+suffix] = service.Name
|
||||
params["state"+suffix] = service.State
|
||||
params["sub"+suffix] = service.Sub
|
||||
params["cpu"+suffix] = service.Cpu
|
||||
params["cpuPeak"+suffix] = service.CpuPeak
|
||||
params["memory"+suffix] = service.Mem
|
||||
params["memPeak"+suffix] = service.MemPeak
|
||||
}
|
||||
queryString := fmt.Sprintf(
|
||||
"INSERT INTO systemd_services (id, system, name, state, sub, cpu, cpuPeak, memory, memPeak, updated) VALUES %s ON CONFLICT(id) DO UPDATE SET system = excluded.system, name = excluded.name, state = excluded.state, sub = excluded.sub, cpu = excluded.cpu, cpuPeak = excluded.cpuPeak, memory = excluded.memory, memPeak = excluded.memPeak, updated = excluded.updated",
|
||||
strings.Join(valueStrings, ","),
|
||||
)
|
||||
_, err := app.DB().NewQuery(queryString).Bind(params).Execute()
|
||||
return err
|
||||
}
|
||||
|
||||
// createContainerRecords creates container records
|
||||
func createContainerRecords(app core.App, data []*container.Stats, systemId string) error {
|
||||
if len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
// shared params for all records
|
||||
params := dbx.Params{
|
||||
"system": systemId,
|
||||
"updated": time.Now().UTC().UnixMilli(),
|
||||
}
|
||||
valueStrings := make([]string, 0, len(data))
|
||||
for i, container := range data {
|
||||
suffix := fmt.Sprintf("%d", i)
|
||||
valueStrings = append(valueStrings, fmt.Sprintf("({:id%[1]s}, {:system}, {:name%[1]s}, {:image%[1]s}, {:ports%[1]s}, {:status%[1]s}, {:health%[1]s}, {:cpu%[1]s}, {:memory%[1]s}, {:net%[1]s}, {:updated})", suffix))
|
||||
params["id"+suffix] = container.Id
|
||||
params["name"+suffix] = container.Name
|
||||
params["image"+suffix] = container.Image
|
||||
params["ports"+suffix] = container.Ports
|
||||
params["status"+suffix] = container.Status
|
||||
params["health"+suffix] = container.Health
|
||||
params["cpu"+suffix] = container.Cpu
|
||||
params["memory"+suffix] = container.Mem
|
||||
netBytes := container.Bandwidth[0] + container.Bandwidth[1]
|
||||
if netBytes == 0 {
|
||||
netBytes = uint64((container.NetworkSent + container.NetworkRecv) * 1024 * 1024)
|
||||
}
|
||||
params["net"+suffix] = netBytes
|
||||
}
|
||||
queryString := fmt.Sprintf(
|
||||
"INSERT INTO containers (id, system, name, image, ports, status, health, cpu, memory, net, updated) VALUES %s ON CONFLICT(id) DO UPDATE SET system = excluded.system, name = excluded.name, image = excluded.image, ports = excluded.ports, status = excluded.status, health = excluded.health, cpu = excluded.cpu, memory = excluded.memory, net = excluded.net, updated = excluded.updated",
|
||||
strings.Join(valueStrings, ","),
|
||||
)
|
||||
_, err := app.DB().NewQuery(queryString).Bind(params).Execute()
|
||||
return err
|
||||
}
|
||||
|
||||
// getRecord retrieves the system record from the database.
|
||||
// If the record is not found, it removes the system from the manager.
|
||||
func (sys *System) getRecord(app core.App) (*core.Record, error) {
|
||||
record, err := app.FindRecordById("systems", sys.Id)
|
||||
if err != nil || record == nil {
|
||||
_ = sys.manager.RemoveSystem(sys.Id)
|
||||
return nil, err
|
||||
}
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// HasUser checks if the given user is in the system's users list.
|
||||
// Returns true if SHARE_ALL_SYSTEMS is enabled (any authenticated user can access any system).
|
||||
func (sys *System) HasUser(app core.App, user *core.Record) bool {
|
||||
if user == nil {
|
||||
return false
|
||||
}
|
||||
if v, _ := utils.GetEnv("SHARE_ALL_SYSTEMS"); v == "true" {
|
||||
return true
|
||||
}
|
||||
var recordData = struct {
|
||||
Users string
|
||||
}{}
|
||||
err := app.DB().NewQuery("SELECT users FROM systems WHERE id={:id}").
|
||||
Bind(dbx.Params{"id": sys.Id}).
|
||||
One(&recordData)
|
||||
if err != nil || recordData.Users == "" {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(recordData.Users, user.Id)
|
||||
}
|
||||
|
||||
// setDown marks a system as down in the database.
|
||||
// It takes the original error that caused the system to go down and returns any error
|
||||
// encountered during the process of updating the system status.
|
||||
func (sys *System) setDown(originalError error) error {
|
||||
if sys.Status == down || sys.Status == paused {
|
||||
return nil
|
||||
}
|
||||
record, err := sys.getRecord(sys.manager.hub)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if originalError != nil {
|
||||
sys.manager.hub.Logger().Error("System down", "system", record.GetString("name"), "err", originalError)
|
||||
}
|
||||
record.Set("status", down)
|
||||
return sys.manager.hub.SaveNoValidate(record)
|
||||
}
|
||||
|
||||
func (sys *System) getContext() (context.Context, context.CancelFunc) {
|
||||
if sys.ctx == nil {
|
||||
sys.ctx, sys.cancel = context.WithCancel(context.Background())
|
||||
}
|
||||
return sys.ctx, sys.cancel
|
||||
}
|
||||
|
||||
// request sends a request to the agent, trying WebSocket first, then SSH.
|
||||
// This is the unified request method that uses the transport abstraction.
|
||||
func (sys *System) request(ctx context.Context, action common.WebSocketAction, req any, dest any) error {
|
||||
// Try WebSocket first
|
||||
if sys.WsConn != nil && sys.WsConn.IsConnected() {
|
||||
wsTransport := transport.NewWebSocketTransport(sys.WsConn)
|
||||
if err := wsTransport.Request(ctx, action, req, dest); err == nil {
|
||||
return nil
|
||||
} else if !shouldFallbackToSSH(err) {
|
||||
return err
|
||||
} else if shouldCloseWebSocket(err) {
|
||||
sys.closeWebSocketConnection()
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to SSH if WebSocket fails
|
||||
if err := sys.ensureSSHTransport(); err != nil {
|
||||
return err
|
||||
}
|
||||
err := sys.sshTransport.RequestWithRetry(ctx, action, req, dest, 1)
|
||||
// Keep legacy SSH client/version fields in sync for other code paths.
|
||||
if sys.sshTransport != nil {
|
||||
sys.client = sys.sshTransport.GetClient()
|
||||
sys.agentVersion = sys.sshTransport.GetAgentVersion()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func shouldFallbackToSSH(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) {
|
||||
return true
|
||||
}
|
||||
if errors.Is(err, gws.ErrConnClosed) {
|
||||
return true
|
||||
}
|
||||
return errors.Is(err, transport.ErrWebSocketNotConnected)
|
||||
}
|
||||
|
||||
func shouldCloseWebSocket(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
return errors.Is(err, gws.ErrConnClosed) || errors.Is(err, transport.ErrWebSocketNotConnected)
|
||||
}
|
||||
|
||||
// ensureSSHTransport ensures the SSH transport is initialized and connected.
|
||||
func (sys *System) ensureSSHTransport() error {
|
||||
if sys.sshTransport == nil {
|
||||
if sys.manager.sshConfig == nil {
|
||||
if err := sys.manager.createSSHClientConfig(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
sys.sshTransport = transport.NewSSHTransport(transport.SSHTransportConfig{
|
||||
Host: sys.Host,
|
||||
Port: sys.Port,
|
||||
Config: sys.manager.sshConfig,
|
||||
Timeout: 4 * time.Second,
|
||||
})
|
||||
}
|
||||
// Sync client state with transport
|
||||
if sys.client != nil {
|
||||
sys.sshTransport.SetClient(sys.client)
|
||||
sys.sshTransport.SetAgentVersion(sys.agentVersion)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// fetchDataFromAgent attempts to fetch data from the agent, prioritizing WebSocket if available.
|
||||
func (sys *System) fetchDataFromAgent(options common.DataRequestOptions) (*system.CombinedData, error) {
|
||||
if sys.data == nil {
|
||||
sys.data = &system.CombinedData{}
|
||||
}
|
||||
|
||||
if sys.WsConn != nil && sys.WsConn.IsConnected() {
|
||||
wsData, err := sys.fetchDataViaWebSocket(options)
|
||||
if err == nil {
|
||||
return wsData, nil
|
||||
}
|
||||
// close the WebSocket connection if error and try SSH
|
||||
sys.closeWebSocketConnection()
|
||||
}
|
||||
|
||||
sshData, err := sys.fetchDataViaSSH(options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sshData, nil
|
||||
}
|
||||
|
||||
func (sys *System) fetchDataViaWebSocket(options common.DataRequestOptions) (*system.CombinedData, error) {
|
||||
if sys.WsConn == nil || !sys.WsConn.IsConnected() {
|
||||
return nil, errors.New("no websocket connection")
|
||||
}
|
||||
wsTransport := transport.NewWebSocketTransport(sys.WsConn)
|
||||
err := wsTransport.Request(context.Background(), common.GetData, options, sys.data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sys.data, nil
|
||||
}
|
||||
|
||||
// FetchContainerInfoFromAgent fetches container info from the agent
|
||||
func (sys *System) FetchContainerInfoFromAgent(containerID string) (string, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
var result string
|
||||
err := sys.request(ctx, common.GetContainerInfo, common.ContainerInfoRequest{ContainerID: containerID}, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// FetchContainerLogsFromAgent fetches container logs from the agent
|
||||
func (sys *System) FetchContainerLogsFromAgent(containerID string) (string, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
var result string
|
||||
err := sys.request(ctx, common.GetContainerLogs, common.ContainerLogsRequest{ContainerID: containerID}, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// FetchSystemdInfoFromAgent fetches detailed systemd service information from the agent
|
||||
func (sys *System) FetchSystemdInfoFromAgent(serviceName string) (systemd.ServiceDetails, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
var result systemd.ServiceDetails
|
||||
err := sys.request(ctx, common.GetSystemdInfo, common.SystemdInfoRequest{ServiceName: serviceName}, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// FetchSmartDataFromAgent fetches SMART data from the agent
|
||||
func (sys *System) FetchSmartDataFromAgent() (map[string]smart.SmartData, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||
defer cancel()
|
||||
var result map[string]smart.SmartData
|
||||
err := sys.request(ctx, common.GetSmartData, nil, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
func makeStableHashId(strings ...string) string {
|
||||
hash := fnv.New32a()
|
||||
for _, str := range strings {
|
||||
hash.Write([]byte(str))
|
||||
}
|
||||
return fmt.Sprintf("%x", hash.Sum32())
|
||||
}
|
||||
|
||||
// fetchDataViaSSH handles fetching data using SSH.
|
||||
// This function encapsulates the original SSH logic.
|
||||
// It updates sys.data directly upon successful fetch.
|
||||
func (sys *System) fetchDataViaSSH(options common.DataRequestOptions) (*system.CombinedData, error) {
|
||||
err := sys.runSSHOperation(4*time.Second, 1, func(session *ssh.Session) (bool, error) {
|
||||
stdout, err := session.StdoutPipe()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
stdin, stdinErr := session.StdinPipe()
|
||||
if err := session.Shell(); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
*sys.data = system.CombinedData{}
|
||||
|
||||
if sys.agentVersion.GTE(beszel.MinVersionAgentResponse) && stdinErr == nil {
|
||||
req := common.HubRequest[any]{Action: common.GetData, Data: options}
|
||||
_ = cbor.NewEncoder(stdin).Encode(req)
|
||||
_ = stdin.Close()
|
||||
|
||||
var resp common.AgentResponse
|
||||
if decErr := cbor.NewDecoder(stdout).Decode(&resp); decErr == nil && resp.SystemData != nil {
|
||||
*sys.data = *resp.SystemData
|
||||
if err := session.Wait(); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
|
||||
var decodeErr error
|
||||
if sys.agentVersion.GTE(beszel.MinVersionCbor) {
|
||||
decodeErr = cbor.NewDecoder(stdout).Decode(sys.data)
|
||||
} else {
|
||||
decodeErr = json.NewDecoder(stdout).Decode(sys.data)
|
||||
}
|
||||
|
||||
if decodeErr != nil {
|
||||
return true, decodeErr
|
||||
}
|
||||
|
||||
if err := session.Wait(); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return false, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return sys.data, nil
|
||||
}
|
||||
|
||||
// runSSHOperation establishes an SSH session and executes the provided operation.
|
||||
// The operation can request a retry by returning true as the first return value.
|
||||
func (sys *System) runSSHOperation(timeout time.Duration, retries int, operation func(*ssh.Session) (bool, error)) error {
|
||||
for attempt := 0; attempt <= retries; attempt++ {
|
||||
if sys.client == nil || sys.Status == down {
|
||||
if err := sys.createSSHClient(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
session, err := sys.createSessionWithTimeout(timeout)
|
||||
if err != nil {
|
||||
if attempt >= retries {
|
||||
return err
|
||||
}
|
||||
sys.manager.hub.Logger().Warn("Session closed. Retrying...", "host", sys.Host, "port", sys.Port, "err", err)
|
||||
sys.closeSSHConnection()
|
||||
continue
|
||||
}
|
||||
|
||||
retry, opErr := func() (bool, error) {
|
||||
defer session.Close()
|
||||
return operation(session)
|
||||
}()
|
||||
|
||||
if opErr == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if retry {
|
||||
sys.closeSSHConnection()
|
||||
if attempt < retries {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
return opErr
|
||||
}
|
||||
|
||||
return fmt.Errorf("ssh operation failed")
|
||||
}
|
||||
|
||||
// createSSHClient creates a new SSH client for the system
|
||||
func (s *System) createSSHClient() error {
|
||||
if s.manager.sshConfig == nil {
|
||||
if err := s.manager.createSSHClientConfig(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
network := "tcp"
|
||||
host := s.Host
|
||||
if strings.HasPrefix(host, "/") {
|
||||
network = "unix"
|
||||
} else {
|
||||
host = net.JoinHostPort(host, s.Port)
|
||||
}
|
||||
var err error
|
||||
s.client, err = ssh.Dial(network, host, s.manager.sshConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.agentVersion, _ = extractAgentVersion(string(s.client.Conn.ServerVersion()))
|
||||
s.manager.resetFailedSmartFetchState(s.Id)
|
||||
return nil
|
||||
}
|
||||
|
||||
// createSessionWithTimeout creates a new SSH session with a timeout to avoid hanging
|
||||
// in case of network issues
|
||||
func (sys *System) createSessionWithTimeout(timeout time.Duration) (*ssh.Session, error) {
|
||||
if sys.client == nil {
|
||||
return nil, fmt.Errorf("client not initialized")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(sys.ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
sessionChan := make(chan *ssh.Session, 1)
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
if session, err := sys.client.NewSession(); err != nil {
|
||||
errChan <- err
|
||||
} else {
|
||||
sessionChan <- session
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case session := <-sessionChan:
|
||||
return session, nil
|
||||
case err := <-errChan:
|
||||
return nil, err
|
||||
case <-ctx.Done():
|
||||
return nil, fmt.Errorf("timeout")
|
||||
}
|
||||
}
|
||||
|
||||
// closeSSHConnection closes the SSH connection but keeps the system in the manager
|
||||
func (sys *System) closeSSHConnection() {
|
||||
if sys.sshTransport != nil {
|
||||
sys.sshTransport.Close()
|
||||
}
|
||||
if sys.client != nil {
|
||||
sys.client.Close()
|
||||
sys.client = nil
|
||||
}
|
||||
}
|
||||
|
||||
// closeWebSocketConnection closes the WebSocket connection but keeps the system in the manager
|
||||
// to allow updating via SSH. It will be removed if the WS connection is re-established.
|
||||
// The system will be set as down a few seconds later if the connection is not re-established.
|
||||
func (sys *System) closeWebSocketConnection() {
|
||||
if sys.WsConn != nil {
|
||||
sys.WsConn.Close(nil)
|
||||
}
|
||||
}
|
||||
|
||||
// extractAgentVersion extracts the beszel version from SSH server version string
|
||||
func extractAgentVersion(versionString string) (semver.Version, error) {
|
||||
_, after, _ := strings.Cut(versionString, "_")
|
||||
return semver.Parse(after)
|
||||
}
|
||||
|
||||
// getJitter returns a channel that will be triggered after a random delay
|
||||
// between 51% and 95% of the interval.
|
||||
// This is used to stagger the initial WebSocket connections to prevent clustering.
|
||||
func getJitter() <-chan time.Time {
|
||||
minPercent := 51
|
||||
maxPercent := 95
|
||||
jitterRange := maxPercent - minPercent
|
||||
msDelay := (interval * minPercent / 100) + rand.Intn(interval*jitterRange/100)
|
||||
return time.After(time.Duration(msDelay) * time.Millisecond)
|
||||
}
|
||||
|
||||
// migrateDeprecatedFields moves values from deprecated fields to their new locations if the new
|
||||
// fields are not already populated. Deprecated fields and refs may be removed at least 30 days
|
||||
// and one minor version release after the release that includes the migration.
|
||||
//
|
||||
// This is run when processing incoming system data from agents, which may be on older versions.
|
||||
func migrateDeprecatedFields(cd *system.CombinedData, createDetails bool) {
|
||||
// migration added 0.19.0
|
||||
if cd.Stats.Bandwidth[0] == 0 && cd.Stats.Bandwidth[1] == 0 {
|
||||
cd.Stats.Bandwidth[0] = uint64(cd.Stats.NetworkSent * 1024 * 1024)
|
||||
cd.Stats.Bandwidth[1] = uint64(cd.Stats.NetworkRecv * 1024 * 1024)
|
||||
cd.Stats.NetworkSent, cd.Stats.NetworkRecv = 0, 0
|
||||
}
|
||||
// migration added 0.19.0
|
||||
if cd.Info.BandwidthBytes == 0 {
|
||||
cd.Info.BandwidthBytes = uint64(cd.Info.Bandwidth * 1024 * 1024)
|
||||
cd.Info.Bandwidth = 0
|
||||
}
|
||||
// migration added 0.19.0
|
||||
if cd.Stats.DiskIO[0] == 0 && cd.Stats.DiskIO[1] == 0 {
|
||||
cd.Stats.DiskIO[0] = uint64(cd.Stats.DiskReadPs * 1024 * 1024)
|
||||
cd.Stats.DiskIO[1] = uint64(cd.Stats.DiskWritePs * 1024 * 1024)
|
||||
cd.Stats.DiskReadPs, cd.Stats.DiskWritePs = 0, 0
|
||||
}
|
||||
// migration added 0.19.0 - Move deprecated Info fields to Details struct
|
||||
if cd.Details == nil && cd.Info.Hostname != "" {
|
||||
if createDetails {
|
||||
cd.Details = &system.Details{
|
||||
Hostname: cd.Info.Hostname,
|
||||
Kernel: cd.Info.KernelVersion,
|
||||
Cores: cd.Info.Cores,
|
||||
Threads: cd.Info.Threads,
|
||||
CpuModel: cd.Info.CpuModel,
|
||||
Podman: cd.Info.Podman,
|
||||
Os: cd.Info.Os,
|
||||
MemoryTotal: uint64(cd.Stats.Mem * 1024 * 1024 * 1024),
|
||||
}
|
||||
}
|
||||
// zero the deprecated fields to prevent saving them in systems.info DB json payload
|
||||
cd.Info.Hostname = ""
|
||||
cd.Info.KernelVersion = ""
|
||||
cd.Info.Cores = 0
|
||||
cd.Info.CpuModel = ""
|
||||
cd.Info.Podman = false
|
||||
cd.Info.Os = 0
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,374 @@
|
||||
package systems
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/henrygd/beszel/internal/hub/ws"
|
||||
|
||||
"github.com/henrygd/beszel/internal/entities/system"
|
||||
"github.com/henrygd/beszel/internal/hub/expirymap"
|
||||
|
||||
"github.com/henrygd/beszel/internal/common"
|
||||
|
||||
"github.com/henrygd/beszel"
|
||||
|
||||
"github.com/blang/semver"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/store"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// System status constants
|
||||
const (
|
||||
up string = "up" // System is online and responding
|
||||
down string = "down" // System is offline or not responding
|
||||
paused string = "paused" // System monitoring is paused
|
||||
pending string = "pending" // System is waiting on initial connection result
|
||||
|
||||
// interval is the default update interval in milliseconds (60 seconds)
|
||||
interval int = 60_000
|
||||
// interval int = 10_000 // Debug interval for faster updates
|
||||
|
||||
// sessionTimeout is the maximum time to wait for SSH connections
|
||||
sessionTimeout = 4 * time.Second
|
||||
)
|
||||
|
||||
// errSystemExists is returned when attempting to add a system that already exists
|
||||
var errSystemExists = errors.New("system exists")
|
||||
|
||||
// SystemManager manages a collection of monitored systems and their connections.
|
||||
// It handles system lifecycle, status updates, and maintains both SSH and WebSocket connections.
|
||||
type SystemManager struct {
|
||||
hub hubLike // Hub interface for database and alert operations
|
||||
systems *store.Store[string, *System] // Thread-safe store of active systems
|
||||
sshConfig *ssh.ClientConfig // SSH client configuration for system connections
|
||||
smartFetchMap *expirymap.ExpiryMap[smartFetchState] // Stores last SMART fetch time/result; TTL is only for cleanup
|
||||
}
|
||||
|
||||
// hubLike defines the interface requirements for the hub dependency.
|
||||
// It extends core.App with system-specific functionality.
|
||||
type hubLike interface {
|
||||
core.App
|
||||
GetSSHKey(dataDir string) (ssh.Signer, error)
|
||||
HandleSystemAlerts(systemRecord *core.Record, data *system.CombinedData) error
|
||||
HandleStatusAlerts(status string, systemRecord *core.Record) error
|
||||
CancelPendingStatusAlerts(systemID string)
|
||||
}
|
||||
|
||||
// NewSystemManager creates a new SystemManager instance with the provided hub.
|
||||
// The hub must implement the hubLike interface to provide database and alert functionality.
|
||||
func NewSystemManager(hub hubLike) *SystemManager {
|
||||
return &SystemManager{
|
||||
systems: store.New(map[string]*System{}),
|
||||
hub: hub,
|
||||
smartFetchMap: expirymap.New[smartFetchState](time.Hour),
|
||||
}
|
||||
}
|
||||
|
||||
// GetSystem returns a system by ID from the store
|
||||
func (sm *SystemManager) GetSystem(systemID string) (*System, error) {
|
||||
sys, ok := sm.systems.GetOk(systemID)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("system not found")
|
||||
}
|
||||
return sys, nil
|
||||
}
|
||||
|
||||
// Initialize sets up the system manager by binding event hooks and starting existing systems.
|
||||
// It configures SSH client settings and begins monitoring all non-paused systems from the database.
|
||||
// Systems are started with staggered delays to prevent overwhelming the hub during startup.
|
||||
func (sm *SystemManager) Initialize() error {
|
||||
sm.bindEventHooks()
|
||||
|
||||
// Initialize SSH client configuration
|
||||
err := sm.createSSHClientConfig()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Load existing systems from database (excluding paused ones)
|
||||
var systems []*System
|
||||
err = sm.hub.DB().NewQuery("SELECT id, host, port, status FROM systems WHERE status != 'paused'").All(&systems)
|
||||
if err != nil || len(systems) == 0 {
|
||||
return err
|
||||
}
|
||||
|
||||
// Start systems in background with staggered timing
|
||||
go func() {
|
||||
// Calculate staggered delay between system starts (max 2 seconds per system)
|
||||
delta := interval / max(1, len(systems))
|
||||
delta = min(delta, 2_000)
|
||||
sleepTime := time.Duration(delta) * time.Millisecond
|
||||
|
||||
for _, system := range systems {
|
||||
time.Sleep(sleepTime)
|
||||
_ = sm.AddSystem(system)
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
// bindEventHooks registers event handlers for system and fingerprint record changes.
|
||||
// These hooks ensure the system manager stays synchronized with database changes.
|
||||
func (sm *SystemManager) bindEventHooks() {
|
||||
sm.hub.OnRecordCreate("systems").BindFunc(sm.onRecordCreate)
|
||||
sm.hub.OnRecordAfterCreateSuccess("systems").BindFunc(sm.onRecordAfterCreateSuccess)
|
||||
sm.hub.OnRecordUpdate("systems").BindFunc(sm.onRecordUpdate)
|
||||
sm.hub.OnRecordAfterUpdateSuccess("systems").BindFunc(sm.onRecordAfterUpdateSuccess)
|
||||
sm.hub.OnRecordAfterDeleteSuccess("systems").BindFunc(sm.onRecordAfterDeleteSuccess)
|
||||
sm.hub.OnRecordAfterUpdateSuccess("fingerprints").BindFunc(sm.onTokenRotated)
|
||||
sm.hub.OnRealtimeSubscribeRequest().BindFunc(sm.onRealtimeSubscribeRequest)
|
||||
sm.hub.OnRealtimeConnectRequest().BindFunc(sm.onRealtimeConnectRequest)
|
||||
}
|
||||
|
||||
// onTokenRotated handles fingerprint token rotation events.
|
||||
// When a system's authentication token is rotated, any existing WebSocket connection
|
||||
// must be closed to force re-authentication with the new token.
|
||||
func (sm *SystemManager) onTokenRotated(e *core.RecordEvent) error {
|
||||
systemID := e.Record.GetString("system")
|
||||
system, ok := sm.systems.GetOk(systemID)
|
||||
if !ok {
|
||||
return e.Next()
|
||||
}
|
||||
// No need to close connection if not connected via websocket
|
||||
if system.WsConn == nil {
|
||||
return e.Next()
|
||||
}
|
||||
system.setDown(nil)
|
||||
sm.RemoveSystem(systemID)
|
||||
return e.Next()
|
||||
}
|
||||
|
||||
// onRecordCreate is called before a new system record is committed to the database.
|
||||
// It initializes the record with default values: empty info and pending status.
|
||||
func (sm *SystemManager) onRecordCreate(e *core.RecordEvent) error {
|
||||
e.Record.Set("info", system.Info{})
|
||||
e.Record.Set("status", pending)
|
||||
return e.Next()
|
||||
}
|
||||
|
||||
// onRecordAfterCreateSuccess is called after a new system record is successfully created.
|
||||
// It adds the new system to the manager to begin monitoring.
|
||||
func (sm *SystemManager) onRecordAfterCreateSuccess(e *core.RecordEvent) error {
|
||||
if err := sm.AddRecord(e.Record, nil); err != nil {
|
||||
e.App.Logger().Error("Error adding record", "err", err)
|
||||
}
|
||||
return e.Next()
|
||||
}
|
||||
|
||||
// onRecordUpdate is called before a system record is updated in the database.
|
||||
// It clears system info when the status is changed to paused.
|
||||
func (sm *SystemManager) onRecordUpdate(e *core.RecordEvent) error {
|
||||
if e.Record.GetString("status") == paused {
|
||||
e.Record.Set("info", system.Info{})
|
||||
}
|
||||
return e.Next()
|
||||
}
|
||||
|
||||
// onRecordAfterUpdateSuccess handles system record updates after they're committed to the database.
|
||||
// It manages system lifecycle based on status changes and triggers appropriate alerts.
|
||||
// Status transitions are handled as follows:
|
||||
// - paused: Closes SSH connection and deactivates alerts
|
||||
// - pending: Starts monitoring (reuses WebSocket if available)
|
||||
// - up: Triggers system alerts
|
||||
// - down: Triggers status change alerts
|
||||
func (sm *SystemManager) onRecordAfterUpdateSuccess(e *core.RecordEvent) error {
|
||||
newStatus := e.Record.GetString("status")
|
||||
prevStatus := pending
|
||||
system, ok := sm.systems.GetOk(e.Record.Id)
|
||||
if ok {
|
||||
prevStatus = system.Status
|
||||
system.Status = newStatus
|
||||
}
|
||||
|
||||
switch newStatus {
|
||||
case paused:
|
||||
if ok {
|
||||
// Pause monitoring but keep system in manager for potential resume
|
||||
system.closeSSHConnection()
|
||||
}
|
||||
_ = deactivateAlerts(e.App, e.Record.Id)
|
||||
sm.hub.CancelPendingStatusAlerts(e.Record.Id)
|
||||
return e.Next()
|
||||
case pending:
|
||||
// Resume monitoring, preferring existing WebSocket connection
|
||||
if ok && system.WsConn != nil {
|
||||
go system.update()
|
||||
return e.Next()
|
||||
}
|
||||
// Start new monitoring session
|
||||
if err := sm.AddRecord(e.Record, nil); err != nil {
|
||||
e.App.Logger().Error("Error adding record", "err", err)
|
||||
}
|
||||
_ = deactivateAlerts(e.App, e.Record.Id)
|
||||
return e.Next()
|
||||
}
|
||||
|
||||
// Handle systems not in manager
|
||||
if !ok {
|
||||
return sm.AddRecord(e.Record, nil)
|
||||
}
|
||||
|
||||
// Trigger system alerts when system comes online
|
||||
if newStatus == up {
|
||||
if err := sm.hub.HandleSystemAlerts(e.Record, system.data); err != nil {
|
||||
e.App.Logger().Error("Error handling system alerts", "err", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Trigger status change alerts for up/down transitions
|
||||
if (newStatus == down && prevStatus == up) || (newStatus == up && prevStatus == down) {
|
||||
if err := sm.hub.HandleStatusAlerts(newStatus, e.Record); err != nil {
|
||||
e.App.Logger().Error("Error handling status alerts", "err", err)
|
||||
}
|
||||
}
|
||||
return e.Next()
|
||||
}
|
||||
|
||||
// onRecordAfterDeleteSuccess is called after a system record is successfully deleted.
|
||||
// It removes the system from the manager and cleans up all associated resources.
|
||||
func (sm *SystemManager) onRecordAfterDeleteSuccess(e *core.RecordEvent) error {
|
||||
sm.RemoveSystem(e.Record.Id)
|
||||
return e.Next()
|
||||
}
|
||||
|
||||
// AddSystem adds a system to the manager and starts monitoring it.
|
||||
// It validates required fields, initializes the system context, and starts the update goroutine.
|
||||
// Returns error if a system with the same ID already exists.
|
||||
func (sm *SystemManager) AddSystem(sys *System) error {
|
||||
if sm.systems.Has(sys.Id) {
|
||||
return errSystemExists
|
||||
}
|
||||
if sys.Id == "" || sys.Host == "" {
|
||||
return errors.New("system missing required fields")
|
||||
}
|
||||
|
||||
// Initialize system for monitoring
|
||||
sys.manager = sm
|
||||
sys.ctx, sys.cancel = sys.getContext()
|
||||
sys.data = &system.CombinedData{}
|
||||
sm.systems.Set(sys.Id, sys)
|
||||
|
||||
// Start monitoring in background
|
||||
go sys.StartUpdater()
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveSystem removes a system from the manager and cleans up all associated resources.
|
||||
// It cancels the system's context, closes all connections, and removes it from the store.
|
||||
// Returns an error if the system is not found.
|
||||
func (sm *SystemManager) RemoveSystem(systemID string) error {
|
||||
system, ok := sm.systems.GetOk(systemID)
|
||||
if !ok {
|
||||
return errors.New("system not found")
|
||||
}
|
||||
|
||||
// Stop the update goroutine
|
||||
if system.cancel != nil {
|
||||
system.cancel()
|
||||
}
|
||||
|
||||
// Clean up all connections
|
||||
system.closeSSHConnection()
|
||||
system.closeWebSocketConnection()
|
||||
sm.systems.Remove(systemID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddRecord creates a System instance from a database record and adds it to the manager.
|
||||
// If a system with the same ID already exists, it's removed first to ensure clean state.
|
||||
// If no system instance is provided, a new one is created.
|
||||
// This method is typically called when systems are created or their status changes to pending.
|
||||
func (sm *SystemManager) AddRecord(record *core.Record, system *System) (err error) {
|
||||
// Remove existing system to ensure clean state
|
||||
if sm.systems.Has(record.Id) {
|
||||
_ = sm.RemoveSystem(record.Id)
|
||||
}
|
||||
|
||||
// Create new system if none provided
|
||||
if system == nil {
|
||||
system = sm.NewSystem(record.Id)
|
||||
}
|
||||
|
||||
// Populate system from record
|
||||
system.Status = record.GetString("status")
|
||||
system.Host = record.GetString("host")
|
||||
system.Port = record.GetString("port")
|
||||
|
||||
return sm.AddSystem(system)
|
||||
}
|
||||
|
||||
// AddWebSocketSystem creates and adds a system with an established WebSocket connection.
|
||||
// This method is called when an agent connects via WebSocket with valid authentication.
|
||||
// The system is immediately added to monitoring with the provided connection and version info.
|
||||
func (sm *SystemManager) AddWebSocketSystem(systemId string, agentVersion semver.Version, wsConn *ws.WsConn) error {
|
||||
systemRecord, err := sm.hub.FindRecordById("systems", systemId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sm.resetFailedSmartFetchState(systemId)
|
||||
|
||||
system := sm.NewSystem(systemId)
|
||||
system.WsConn = wsConn
|
||||
system.agentVersion = agentVersion
|
||||
|
||||
if err := sm.AddRecord(systemRecord, system); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// resetFailedSmartFetchState clears only failed SMART cooldown entries so a fresh
|
||||
// agent reconnect retries SMART discovery immediately after configuration changes.
|
||||
func (sm *SystemManager) resetFailedSmartFetchState(systemID string) {
|
||||
state, ok := sm.smartFetchMap.GetOk(systemID)
|
||||
if ok && !state.Successful {
|
||||
sm.smartFetchMap.Remove(systemID)
|
||||
}
|
||||
}
|
||||
|
||||
// createSSHClientConfig initializes the SSH client configuration for connecting to an agent's server
|
||||
func (sm *SystemManager) createSSHClientConfig() error {
|
||||
privateKey, err := sm.hub.GetSSHKey("")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sm.sshConfig = &ssh.ClientConfig{
|
||||
User: "u",
|
||||
Auth: []ssh.AuthMethod{
|
||||
ssh.PublicKeys(privateKey),
|
||||
},
|
||||
Config: ssh.Config{
|
||||
Ciphers: common.DefaultCiphers,
|
||||
KeyExchanges: common.DefaultKeyExchanges,
|
||||
MACs: common.DefaultMACs,
|
||||
},
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
ClientVersion: fmt.Sprintf("SSH-2.0-%s_%s", beszel.AppName, beszel.Version),
|
||||
Timeout: sessionTimeout,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// deactivateAlerts finds all triggered alerts for a system and sets them to inactive.
|
||||
// This is called when a system is paused or goes offline to prevent continued alerts.
|
||||
func deactivateAlerts(app core.App, systemID string) error {
|
||||
// Note: Direct SQL updates don't trigger SSE, so we use the PocketBase API
|
||||
// _, err := app.DB().NewQuery(fmt.Sprintf("UPDATE alerts SET triggered = false WHERE system = '%s'", systemID)).Execute()
|
||||
|
||||
alerts, err := app.FindRecordsByFilter("alerts", fmt.Sprintf("system = '%s' && triggered = 1", systemID), "", -1, 0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, alert := range alerts {
|
||||
alert.Set("triggered", false)
|
||||
if err := app.SaveNoValidate(alert); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,176 @@
|
||||
package systems
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/henrygd/beszel/internal/common"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/subscriptions"
|
||||
)
|
||||
|
||||
type subscriptionInfo struct {
|
||||
subscription string
|
||||
connectedClients uint8
|
||||
}
|
||||
|
||||
var (
|
||||
activeSubscriptions = make(map[string]*subscriptionInfo)
|
||||
workerRunning bool
|
||||
tickerStopChan chan struct{}
|
||||
realtimeMutex sync.Mutex
|
||||
)
|
||||
|
||||
// onRealtimeConnectRequest handles client connection events for realtime subscriptions.
|
||||
// It cleans up existing subscriptions when a client connects.
|
||||
func (sm *SystemManager) onRealtimeConnectRequest(e *core.RealtimeConnectRequestEvent) error {
|
||||
// after e.Next() is the client disconnection
|
||||
e.Next()
|
||||
subscriptions := e.Client.Subscriptions()
|
||||
for k := range subscriptions {
|
||||
sm.removeRealtimeSubscription(k, subscriptions[k])
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// onRealtimeSubscribeRequest handles client subscription events for realtime metrics.
|
||||
// It tracks new subscriptions and unsubscriptions to manage the realtime worker lifecycle.
|
||||
func (sm *SystemManager) onRealtimeSubscribeRequest(e *core.RealtimeSubscribeRequestEvent) error {
|
||||
oldSubs := e.Client.Subscriptions()
|
||||
// after e.Next() is the result of the subscribe request
|
||||
err := e.Next()
|
||||
newSubs := e.Client.Subscriptions()
|
||||
|
||||
// handle new subscriptions
|
||||
for k, options := range newSubs {
|
||||
if _, ok := oldSubs[k]; !ok {
|
||||
if strings.HasPrefix(k, "rt_metrics") {
|
||||
systemId := options.Query["system"]
|
||||
if _, ok := activeSubscriptions[systemId]; !ok {
|
||||
activeSubscriptions[systemId] = &subscriptionInfo{
|
||||
subscription: k,
|
||||
}
|
||||
}
|
||||
activeSubscriptions[systemId].connectedClients += 1
|
||||
sm.onRealtimeSubscriptionAdded()
|
||||
}
|
||||
}
|
||||
}
|
||||
// handle unsubscriptions
|
||||
for k := range oldSubs {
|
||||
if _, ok := newSubs[k]; !ok {
|
||||
sm.removeRealtimeSubscription(k, oldSubs[k])
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// onRealtimeSubscriptionAdded initializes or starts the realtime worker when the first subscription is added.
|
||||
// It ensures only one worker runs at a time.
|
||||
func (sm *SystemManager) onRealtimeSubscriptionAdded() {
|
||||
realtimeMutex.Lock()
|
||||
defer realtimeMutex.Unlock()
|
||||
|
||||
// Start the worker if it's not already running
|
||||
if !workerRunning {
|
||||
workerRunning = true
|
||||
// Create a new stop channel for this worker instance
|
||||
tickerStopChan = make(chan struct{})
|
||||
go sm.startRealtimeWorker()
|
||||
}
|
||||
}
|
||||
|
||||
// checkSubscriptions stops the realtime worker when there are no active subscriptions.
|
||||
// This prevents unnecessary resource usage when no clients are listening for realtime data.
|
||||
func (sm *SystemManager) checkSubscriptions() {
|
||||
if !workerRunning || len(activeSubscriptions) > 0 {
|
||||
return
|
||||
}
|
||||
|
||||
realtimeMutex.Lock()
|
||||
defer realtimeMutex.Unlock()
|
||||
|
||||
// Signal the worker to stop
|
||||
if tickerStopChan != nil {
|
||||
select {
|
||||
case tickerStopChan <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// Mark worker as stopped (will be reset when next subscription comes in)
|
||||
workerRunning = false
|
||||
}
|
||||
|
||||
// removeRealtimeSubscription removes a realtime subscription and checks if the worker should be stopped.
|
||||
// It only processes subscriptions with the "rt_metrics" prefix and triggers cleanup when subscriptions are removed.
|
||||
func (sm *SystemManager) removeRealtimeSubscription(subscription string, options subscriptions.SubscriptionOptions) {
|
||||
if strings.HasPrefix(subscription, "rt_metrics") {
|
||||
systemId := options.Query["system"]
|
||||
if info, ok := activeSubscriptions[systemId]; ok {
|
||||
info.connectedClients -= 1
|
||||
if info.connectedClients <= 0 {
|
||||
delete(activeSubscriptions, systemId)
|
||||
}
|
||||
}
|
||||
sm.checkSubscriptions()
|
||||
}
|
||||
}
|
||||
|
||||
// startRealtimeWorker runs the main loop for fetching realtime data from agents.
|
||||
// It continuously fetches system data and broadcasts it to subscribed clients via WebSocket.
|
||||
func (sm *SystemManager) startRealtimeWorker() {
|
||||
sm.fetchRealtimeDataAndNotify()
|
||||
tick := time.Tick(1 * time.Second)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-tickerStopChan:
|
||||
return
|
||||
case <-tick:
|
||||
if len(activeSubscriptions) == 0 {
|
||||
return
|
||||
}
|
||||
sm.fetchRealtimeDataAndNotify()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// fetchRealtimeDataAndNotify fetches realtime data for all active subscriptions and notifies the clients.
|
||||
func (sm *SystemManager) fetchRealtimeDataAndNotify() {
|
||||
for systemId, info := range activeSubscriptions {
|
||||
system, err := sm.GetSystem(systemId)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
go func() {
|
||||
data, err := system.fetchDataFromAgent(common.DataRequestOptions{CacheTimeMs: 1000})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
bytes, err := json.Marshal(data)
|
||||
if err == nil {
|
||||
notify(sm.hub, info.subscription, bytes)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
// notify broadcasts realtime data to all clients subscribed to a specific subscription.
|
||||
// It iterates through all connected clients and sends the data only to those with matching subscriptions.
|
||||
func notify(app core.App, subscription string, data []byte) error {
|
||||
message := subscriptions.Message{
|
||||
Name: subscription,
|
||||
Data: data,
|
||||
}
|
||||
for _, client := range app.SubscriptionsBroker().Clients() {
|
||||
if !client.HasSubscription(subscription) {
|
||||
continue
|
||||
}
|
||||
client.Send(message)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,135 @@
|
||||
package systems
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/henrygd/beszel/internal/entities/smart"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
)
|
||||
|
||||
type smartFetchState struct {
|
||||
LastAttempt int64
|
||||
Successful bool
|
||||
}
|
||||
|
||||
// FetchAndSaveSmartDevices fetches SMART data from the agent and saves it to the database
|
||||
func (sys *System) FetchAndSaveSmartDevices() error {
|
||||
smartData, err := sys.FetchSmartDataFromAgent()
|
||||
if err != nil {
|
||||
sys.recordSmartFetchResult(err, 0)
|
||||
return err
|
||||
}
|
||||
err = sys.saveSmartDevices(smartData)
|
||||
sys.recordSmartFetchResult(err, len(smartData))
|
||||
return err
|
||||
}
|
||||
|
||||
// recordSmartFetchResult stores a cooldown entry for the SMART interval and marks
|
||||
// whether the last fetch produced any devices, so failed setup can retry on reconnect.
|
||||
func (sys *System) recordSmartFetchResult(err error, deviceCount int) {
|
||||
if sys.manager == nil {
|
||||
return
|
||||
}
|
||||
interval := sys.smartFetchInterval()
|
||||
success := err == nil && deviceCount > 0
|
||||
if sys.manager.hub != nil {
|
||||
sys.manager.hub.Logger().Info("SMART fetch result", "system", sys.Id, "success", success, "devices", deviceCount, "interval", interval.String(), "err", err)
|
||||
}
|
||||
sys.manager.smartFetchMap.Set(sys.Id, smartFetchState{LastAttempt: time.Now().UnixMilli(), Successful: success}, interval+time.Minute)
|
||||
}
|
||||
|
||||
// shouldFetchSmart returns true when there is no active SMART cooldown entry for this system.
|
||||
func (sys *System) shouldFetchSmart() bool {
|
||||
if sys.manager == nil {
|
||||
return true
|
||||
}
|
||||
state, ok := sys.manager.smartFetchMap.GetOk(sys.Id)
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
return !time.UnixMilli(state.LastAttempt).Add(sys.smartFetchInterval()).After(time.Now())
|
||||
}
|
||||
|
||||
// smartFetchInterval returns the agent-provided SMART interval or the default when unset.
|
||||
func (sys *System) smartFetchInterval() time.Duration {
|
||||
if sys.smartInterval > 0 {
|
||||
return sys.smartInterval
|
||||
}
|
||||
return time.Hour
|
||||
}
|
||||
|
||||
// saveSmartDevices saves SMART device data to the smart_devices collection
|
||||
func (sys *System) saveSmartDevices(smartData map[string]smart.SmartData) error {
|
||||
if len(smartData) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
hub := sys.manager.hub
|
||||
collection, err := hub.FindCachedCollectionByNameOrId("smart_devices")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for deviceKey, device := range smartData {
|
||||
if err := sys.upsertSmartDeviceRecord(collection, deviceKey, device); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sys *System) upsertSmartDeviceRecord(collection *core.Collection, deviceKey string, device smart.SmartData) error {
|
||||
hub := sys.manager.hub
|
||||
recordID := makeStableHashId(sys.Id, deviceKey)
|
||||
|
||||
record, err := hub.FindRecordById(collection, recordID)
|
||||
if err != nil {
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
return err
|
||||
}
|
||||
record = core.NewRecord(collection)
|
||||
record.Set("id", recordID)
|
||||
}
|
||||
|
||||
name := device.DiskName
|
||||
if name == "" {
|
||||
name = deviceKey
|
||||
}
|
||||
|
||||
powerOnHours, powerCycles := extractPowerMetrics(device.Attributes)
|
||||
record.Set("system", sys.Id)
|
||||
record.Set("name", name)
|
||||
record.Set("model", device.ModelName)
|
||||
record.Set("state", device.SmartStatus)
|
||||
record.Set("capacity", device.Capacity)
|
||||
record.Set("temp", device.Temperature)
|
||||
record.Set("firmware", device.FirmwareVersion)
|
||||
record.Set("serial", device.SerialNumber)
|
||||
record.Set("type", device.DiskType)
|
||||
record.Set("hours", powerOnHours)
|
||||
record.Set("cycles", powerCycles)
|
||||
record.Set("attributes", device.Attributes)
|
||||
|
||||
return hub.SaveNoValidate(record)
|
||||
}
|
||||
|
||||
// extractPowerMetrics extracts power on hours and power cycles from SMART attributes
|
||||
func extractPowerMetrics(attributes []*smart.SmartAttribute) (powerOnHours, powerCycles uint64) {
|
||||
for _, attr := range attributes {
|
||||
nameLower := strings.ToLower(attr.Name)
|
||||
if powerOnHours == 0 && (strings.Contains(nameLower, "poweronhours") || strings.Contains(nameLower, "power_on_hours")) {
|
||||
powerOnHours = attr.RawValue
|
||||
}
|
||||
if powerCycles == 0 && ((strings.Contains(nameLower, "power") && strings.Contains(nameLower, "cycle")) || strings.Contains(nameLower, "startstopcycles")) {
|
||||
powerCycles = attr.RawValue
|
||||
}
|
||||
if powerOnHours > 0 && powerCycles > 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -0,0 +1,94 @@
|
||||
//go:build testing
|
||||
|
||||
package systems
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/henrygd/beszel/internal/hub/expirymap"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestRecordSmartFetchResult(t *testing.T) {
|
||||
sm := &SystemManager{smartFetchMap: expirymap.New[smartFetchState](time.Hour)}
|
||||
t.Cleanup(sm.smartFetchMap.StopCleaner)
|
||||
|
||||
sys := &System{
|
||||
Id: "system-1",
|
||||
manager: sm,
|
||||
smartInterval: time.Hour,
|
||||
}
|
||||
|
||||
// Successful fetch with devices
|
||||
sys.recordSmartFetchResult(nil, 5)
|
||||
state, ok := sm.smartFetchMap.GetOk(sys.Id)
|
||||
assert.True(t, ok, "expected smart fetch result to be stored")
|
||||
assert.True(t, state.Successful, "expected successful fetch state to be recorded")
|
||||
|
||||
// Failed fetch
|
||||
sys.recordSmartFetchResult(errors.New("failed"), 0)
|
||||
state, ok = sm.smartFetchMap.GetOk(sys.Id)
|
||||
assert.True(t, ok, "expected failed smart fetch state to be stored")
|
||||
assert.False(t, state.Successful, "expected failed smart fetch state to be marked unsuccessful")
|
||||
|
||||
// Successful fetch but no devices
|
||||
sys.recordSmartFetchResult(nil, 0)
|
||||
state, ok = sm.smartFetchMap.GetOk(sys.Id)
|
||||
assert.True(t, ok, "expected fetch with zero devices to be stored")
|
||||
assert.False(t, state.Successful, "expected fetch with zero devices to be marked unsuccessful")
|
||||
}
|
||||
|
||||
func TestShouldFetchSmart(t *testing.T) {
|
||||
sm := &SystemManager{smartFetchMap: expirymap.New[smartFetchState](time.Hour)}
|
||||
t.Cleanup(sm.smartFetchMap.StopCleaner)
|
||||
|
||||
sys := &System{
|
||||
Id: "system-1",
|
||||
manager: sm,
|
||||
smartInterval: time.Hour,
|
||||
}
|
||||
|
||||
assert.True(t, sys.shouldFetchSmart(), "expected initial smart fetch to be allowed")
|
||||
|
||||
sys.recordSmartFetchResult(errors.New("failed"), 0)
|
||||
assert.False(t, sys.shouldFetchSmart(), "expected smart fetch to be blocked while interval entry exists")
|
||||
|
||||
sm.smartFetchMap.Remove(sys.Id)
|
||||
assert.True(t, sys.shouldFetchSmart(), "expected smart fetch to be allowed after interval entry is cleared")
|
||||
}
|
||||
|
||||
func TestShouldFetchSmart_IgnoresExtendedTTLWhenFetchIsDue(t *testing.T) {
|
||||
sm := &SystemManager{smartFetchMap: expirymap.New[smartFetchState](time.Hour)}
|
||||
t.Cleanup(sm.smartFetchMap.StopCleaner)
|
||||
|
||||
sys := &System{
|
||||
Id: "system-1",
|
||||
manager: sm,
|
||||
smartInterval: time.Hour,
|
||||
}
|
||||
|
||||
sm.smartFetchMap.Set(sys.Id, smartFetchState{
|
||||
LastAttempt: time.Now().Add(-2 * time.Hour).UnixMilli(),
|
||||
Successful: true,
|
||||
}, 10*time.Minute)
|
||||
sm.smartFetchMap.UpdateExpiration(sys.Id, 3*time.Hour)
|
||||
|
||||
assert.True(t, sys.shouldFetchSmart(), "expected fetch time to take precedence over updated TTL")
|
||||
}
|
||||
|
||||
func TestResetFailedSmartFetchState(t *testing.T) {
|
||||
sm := &SystemManager{smartFetchMap: expirymap.New[smartFetchState](time.Hour)}
|
||||
t.Cleanup(sm.smartFetchMap.StopCleaner)
|
||||
|
||||
sm.smartFetchMap.Set("system-1", smartFetchState{LastAttempt: time.Now().UnixMilli(), Successful: false}, time.Hour)
|
||||
sm.resetFailedSmartFetchState("system-1")
|
||||
_, ok := sm.smartFetchMap.GetOk("system-1")
|
||||
assert.False(t, ok, "expected failed smart fetch state to be cleared on reconnect")
|
||||
|
||||
sm.smartFetchMap.Set("system-1", smartFetchState{LastAttempt: time.Now().UnixMilli(), Successful: true}, time.Hour)
|
||||
sm.resetFailedSmartFetchState("system-1")
|
||||
_, ok = sm.smartFetchMap.GetOk("system-1")
|
||||
assert.True(t, ok, "expected successful smart fetch state to be preserved")
|
||||
}
|
||||
@@ -0,0 +1,75 @@
|
||||
//go:build testing
|
||||
|
||||
package systems
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestGetSystemdServiceId(t *testing.T) {
|
||||
t.Run("deterministic output", func(t *testing.T) {
|
||||
systemId := "sys-123"
|
||||
serviceName := "nginx.service"
|
||||
|
||||
// Call multiple times and ensure same result
|
||||
id1 := makeStableHashId(systemId, serviceName)
|
||||
id2 := makeStableHashId(systemId, serviceName)
|
||||
id3 := makeStableHashId(systemId, serviceName)
|
||||
|
||||
assert.Equal(t, id1, id2)
|
||||
assert.Equal(t, id2, id3)
|
||||
assert.NotEmpty(t, id1)
|
||||
})
|
||||
|
||||
t.Run("different inputs produce different ids", func(t *testing.T) {
|
||||
systemId1 := "sys-123"
|
||||
systemId2 := "sys-456"
|
||||
serviceName1 := "nginx.service"
|
||||
serviceName2 := "apache.service"
|
||||
|
||||
id1 := makeStableHashId(systemId1, serviceName1)
|
||||
id2 := makeStableHashId(systemId2, serviceName1)
|
||||
id3 := makeStableHashId(systemId1, serviceName2)
|
||||
id4 := makeStableHashId(systemId2, serviceName2)
|
||||
|
||||
// All IDs should be different
|
||||
assert.NotEqual(t, id1, id2)
|
||||
assert.NotEqual(t, id1, id3)
|
||||
assert.NotEqual(t, id1, id4)
|
||||
assert.NotEqual(t, id2, id3)
|
||||
assert.NotEqual(t, id2, id4)
|
||||
assert.NotEqual(t, id3, id4)
|
||||
})
|
||||
|
||||
t.Run("consistent length", func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
systemId string
|
||||
serviceName string
|
||||
}{
|
||||
{"short", "short.service"},
|
||||
{"very-long-system-id-that-might-be-used-in-practice", "very-long-service-name.service"},
|
||||
{"", "empty-system.service"},
|
||||
{"empty-service", ""},
|
||||
{"", ""},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
id := makeStableHashId(tc.systemId, tc.serviceName)
|
||||
// FNV-32 produces 8 hex characters
|
||||
assert.Len(t, id, 8, "ID should be 8 characters for systemId='%s', serviceName='%s'", tc.systemId, tc.serviceName)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("hexadecimal output", func(t *testing.T) {
|
||||
id := makeStableHashId("test-system", "test-service")
|
||||
assert.NotEmpty(t, id)
|
||||
|
||||
// Should only contain hexadecimal characters
|
||||
for _, char := range id {
|
||||
assert.True(t, (char >= '0' && char <= '9') || (char >= 'a' && char <= 'f'),
|
||||
"ID should only contain hexadecimal characters, got: %s", id)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,159 @@
|
||||
//go:build testing
|
||||
|
||||
package systems
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/henrygd/beszel/internal/entities/system"
|
||||
)
|
||||
|
||||
func TestCombinedData_MigrateDeprecatedFields(t *testing.T) {
|
||||
t.Run("Migrate NetworkSent and NetworkRecv to Bandwidth", func(t *testing.T) {
|
||||
cd := &system.CombinedData{
|
||||
Stats: system.Stats{
|
||||
NetworkSent: 1.5, // 1.5 MB
|
||||
NetworkRecv: 2.5, // 2.5 MB
|
||||
},
|
||||
}
|
||||
migrateDeprecatedFields(cd, true)
|
||||
|
||||
expectedSent := uint64(1.5 * 1024 * 1024)
|
||||
expectedRecv := uint64(2.5 * 1024 * 1024)
|
||||
|
||||
if cd.Stats.Bandwidth[0] != expectedSent {
|
||||
t.Errorf("expected Bandwidth[0] %d, got %d", expectedSent, cd.Stats.Bandwidth[0])
|
||||
}
|
||||
if cd.Stats.Bandwidth[1] != expectedRecv {
|
||||
t.Errorf("expected Bandwidth[1] %d, got %d", expectedRecv, cd.Stats.Bandwidth[1])
|
||||
}
|
||||
if cd.Stats.NetworkSent != 0 || cd.Stats.NetworkRecv != 0 {
|
||||
t.Errorf("expected NetworkSent and NetworkRecv to be reset, got %f, %f", cd.Stats.NetworkSent, cd.Stats.NetworkRecv)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Migrate Info.Bandwidth to Info.BandwidthBytes", func(t *testing.T) {
|
||||
cd := &system.CombinedData{
|
||||
Info: system.Info{
|
||||
Bandwidth: 10.0, // 10 MB
|
||||
},
|
||||
}
|
||||
migrateDeprecatedFields(cd, true)
|
||||
|
||||
expected := uint64(10 * 1024 * 1024)
|
||||
if cd.Info.BandwidthBytes != expected {
|
||||
t.Errorf("expected BandwidthBytes %d, got %d", expected, cd.Info.BandwidthBytes)
|
||||
}
|
||||
if cd.Info.Bandwidth != 0 {
|
||||
t.Errorf("expected Info.Bandwidth to be reset, got %f", cd.Info.Bandwidth)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Migrate DiskReadPs and DiskWritePs to DiskIO", func(t *testing.T) {
|
||||
cd := &system.CombinedData{
|
||||
Stats: system.Stats{
|
||||
DiskReadPs: 3.0, // 3 MB
|
||||
DiskWritePs: 4.0, // 4 MB
|
||||
},
|
||||
}
|
||||
migrateDeprecatedFields(cd, true)
|
||||
|
||||
expectedRead := uint64(3 * 1024 * 1024)
|
||||
expectedWrite := uint64(4 * 1024 * 1024)
|
||||
|
||||
if cd.Stats.DiskIO[0] != expectedRead {
|
||||
t.Errorf("expected DiskIO[0] %d, got %d", expectedRead, cd.Stats.DiskIO[0])
|
||||
}
|
||||
if cd.Stats.DiskIO[1] != expectedWrite {
|
||||
t.Errorf("expected DiskIO[1] %d, got %d", expectedWrite, cd.Stats.DiskIO[1])
|
||||
}
|
||||
if cd.Stats.DiskReadPs != 0 || cd.Stats.DiskWritePs != 0 {
|
||||
t.Errorf("expected DiskReadPs and DiskWritePs to be reset, got %f, %f", cd.Stats.DiskReadPs, cd.Stats.DiskWritePs)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Migrate Info fields to Details struct", func(t *testing.T) {
|
||||
cd := &system.CombinedData{
|
||||
Stats: system.Stats{
|
||||
Mem: 16.0, // 16 GB
|
||||
},
|
||||
Info: system.Info{
|
||||
Hostname: "test-host",
|
||||
KernelVersion: "6.8.0",
|
||||
Cores: 8,
|
||||
Threads: 16,
|
||||
CpuModel: "Intel i7",
|
||||
Podman: true,
|
||||
Os: system.Linux,
|
||||
},
|
||||
}
|
||||
migrateDeprecatedFields(cd, true)
|
||||
|
||||
if cd.Details == nil {
|
||||
t.Fatal("expected Details struct to be created")
|
||||
}
|
||||
if cd.Details.Hostname != "test-host" {
|
||||
t.Errorf("expected Hostname 'test-host', got '%s'", cd.Details.Hostname)
|
||||
}
|
||||
if cd.Details.Kernel != "6.8.0" {
|
||||
t.Errorf("expected Kernel '6.8.0', got '%s'", cd.Details.Kernel)
|
||||
}
|
||||
if cd.Details.Cores != 8 {
|
||||
t.Errorf("expected Cores 8, got %d", cd.Details.Cores)
|
||||
}
|
||||
if cd.Details.Threads != 16 {
|
||||
t.Errorf("expected Threads 16, got %d", cd.Details.Threads)
|
||||
}
|
||||
if cd.Details.CpuModel != "Intel i7" {
|
||||
t.Errorf("expected CpuModel 'Intel i7', got '%s'", cd.Details.CpuModel)
|
||||
}
|
||||
if cd.Details.Podman != true {
|
||||
t.Errorf("expected Podman true, got %v", cd.Details.Podman)
|
||||
}
|
||||
if cd.Details.Os != system.Linux {
|
||||
t.Errorf("expected Os Linux, got %d", cd.Details.Os)
|
||||
}
|
||||
expectedMem := uint64(16 * 1024 * 1024 * 1024)
|
||||
if cd.Details.MemoryTotal != expectedMem {
|
||||
t.Errorf("expected MemoryTotal %d, got %d", expectedMem, cd.Details.MemoryTotal)
|
||||
}
|
||||
|
||||
if cd.Info.Hostname != "" || cd.Info.KernelVersion != "" || cd.Info.Cores != 0 || cd.Info.CpuModel != "" || cd.Info.Podman != false || cd.Info.Os != 0 {
|
||||
t.Errorf("expected Info fields to be reset, got %+v", cd.Info)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Do not migrate if Details already exists", func(t *testing.T) {
|
||||
cd := &system.CombinedData{
|
||||
Details: &system.Details{Hostname: "existing-host"},
|
||||
Info: system.Info{
|
||||
Hostname: "deprecated-host",
|
||||
},
|
||||
}
|
||||
migrateDeprecatedFields(cd, true)
|
||||
|
||||
if cd.Details.Hostname != "existing-host" {
|
||||
t.Errorf("expected Hostname 'existing-host', got '%s'", cd.Details.Hostname)
|
||||
}
|
||||
if cd.Info.Hostname != "deprecated-host" {
|
||||
t.Errorf("expected Info.Hostname to remain 'deprecated-host', got '%s'", cd.Info.Hostname)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Do not create details if migrateDetails is false", func(t *testing.T) {
|
||||
cd := &system.CombinedData{
|
||||
Info: system.Info{
|
||||
Hostname: "deprecated-host",
|
||||
},
|
||||
}
|
||||
migrateDeprecatedFields(cd, false)
|
||||
|
||||
if cd.Details != nil {
|
||||
t.Fatal("expected Details struct to not be created")
|
||||
}
|
||||
|
||||
if cd.Info.Hostname != "" {
|
||||
t.Errorf("expected Info.Hostname to be reset, got '%s'", cd.Info.Hostname)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,9 @@
|
||||
//go:build !testing
|
||||
|
||||
package systems
|
||||
|
||||
// Background SMART fetching is enabled in production but disabled for tests (systems_test_helpers.go).
|
||||
//
|
||||
// The hub integration tests create/replace systems and clean up the test apps quickly.
|
||||
// Background SMART fetching can outlive teardown and crash in PocketBase internals (nil DB).
|
||||
func backgroundSmartFetchEnabled() bool { return true }
|
||||
@@ -0,0 +1,480 @@
|
||||
//go:build testing
|
||||
|
||||
package systems_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"testing/synctest"
|
||||
"time"
|
||||
|
||||
"github.com/henrygd/beszel/internal/entities/container"
|
||||
"github.com/henrygd/beszel/internal/entities/system"
|
||||
"github.com/henrygd/beszel/internal/hub/systems"
|
||||
"github.com/henrygd/beszel/internal/tests"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSystemManagerNew(t *testing.T) {
|
||||
hub, err := tests.NewTestHub(t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer hub.Cleanup()
|
||||
sm := hub.GetSystemManager()
|
||||
|
||||
user, err := tests.CreateUser(hub, "test@test.com", "testtesttest")
|
||||
require.NoError(t, err)
|
||||
|
||||
synctest.Test(t, func(t *testing.T) {
|
||||
sm.Initialize()
|
||||
|
||||
record, err := tests.CreateRecord(hub, "systems", map[string]any{
|
||||
"name": "it-was-coney-island",
|
||||
"host": "the-playground-of-the-world",
|
||||
"port": "33914",
|
||||
"users": []string{user.Id},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "pending", record.GetString("status"), "System status should be 'pending'")
|
||||
assert.Equal(t, "pending", sm.GetSystemStatusFromStore(record.Id), "System status should be 'pending'")
|
||||
|
||||
// Verify the system host and port
|
||||
host, port := sm.GetSystemHostPort(record.Id)
|
||||
assert.Equal(t, record.GetString("host"), host, "System host should match")
|
||||
assert.Equal(t, record.GetString("port"), port, "System port should match")
|
||||
|
||||
time.Sleep(13 * time.Second)
|
||||
synctest.Wait()
|
||||
|
||||
assert.Equal(t, "pending", record.Fresh().GetString("status"), "System status should be 'pending'")
|
||||
// Verify the system was added by checking if it exists
|
||||
assert.True(t, sm.HasSystem(record.Id), "System should exist in the store")
|
||||
|
||||
time.Sleep(10 * time.Second)
|
||||
synctest.Wait()
|
||||
|
||||
// system should be set to down after 15 seconds (no websocket connection)
|
||||
assert.Equal(t, "down", sm.GetSystemStatusFromStore(record.Id), "System status should be 'down'")
|
||||
// make sure the system is down in the db
|
||||
record, err = hub.FindRecordById("systems", record.Id)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "down", record.GetString("status"), "System status should be 'down'")
|
||||
|
||||
assert.Equal(t, 1, sm.GetSystemCount(), "System count should be 1")
|
||||
|
||||
err = sm.RemoveSystem(record.Id)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 0, sm.GetSystemCount(), "System count should be 0")
|
||||
assert.False(t, sm.HasSystem(record.Id), "System should not exist in the store after removal")
|
||||
|
||||
// let's also make sure a system is removed from the store when the record is deleted
|
||||
record, err = tests.CreateRecord(hub, "systems", map[string]any{
|
||||
"name": "there-was-no-place-like-it",
|
||||
"host": "in-the-whole-world",
|
||||
"port": "33914",
|
||||
"users": []string{user.Id},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, sm.HasSystem(record.Id), "System should exist in the store after creation")
|
||||
|
||||
time.Sleep(8 * time.Second)
|
||||
synctest.Wait()
|
||||
assert.Equal(t, "pending", sm.GetSystemStatusFromStore(record.Id), "System status should be 'pending'")
|
||||
|
||||
sm.SetSystemStatusInDB(record.Id, "up")
|
||||
time.Sleep(time.Second)
|
||||
synctest.Wait()
|
||||
assert.Equal(t, "up", sm.GetSystemStatusFromStore(record.Id), "System status should be 'up'")
|
||||
|
||||
// make sure the system switches to down after 11 seconds
|
||||
sm.RemoveSystem(record.Id)
|
||||
sm.AddRecord(record, nil)
|
||||
assert.Equal(t, "pending", sm.GetSystemStatusFromStore(record.Id), "System status should be 'pending'")
|
||||
time.Sleep(12 * time.Second)
|
||||
synctest.Wait()
|
||||
assert.Equal(t, "down", sm.GetSystemStatusFromStore(record.Id), "System status should be 'down'")
|
||||
|
||||
// sm.SetSystemStatusInDB(record.Id, "paused")
|
||||
// time.Sleep(time.Second)
|
||||
// synctest.Wait()
|
||||
// assert.Equal(t, "paused", sm.GetSystemStatusFromStore(record.Id), "System status should be 'paused'")
|
||||
|
||||
// delete the record
|
||||
err = hub.Delete(record)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, sm.HasSystem(record.Id), "System should not exist in the store after deletion")
|
||||
})
|
||||
|
||||
testOld(t, hub)
|
||||
|
||||
synctest.Test(t, func(t *testing.T) {
|
||||
time.Sleep(time.Second)
|
||||
synctest.Wait()
|
||||
|
||||
for _, systemId := range sm.GetAllSystemIDs() {
|
||||
err = sm.RemoveSystem(systemId)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, sm.HasSystem(systemId), "System should not exist in the store after deletion")
|
||||
}
|
||||
|
||||
assert.Equal(t, 0, sm.GetSystemCount(), "System count should be 0")
|
||||
|
||||
// TODO: test with websocket client
|
||||
})
|
||||
}
|
||||
|
||||
func testOld(t *testing.T, hub *tests.TestHub) {
|
||||
user, err := tests.CreateUser(hub, "test@testy.com", "testtesttest")
|
||||
require.NoError(t, err)
|
||||
|
||||
sm := hub.GetSystemManager()
|
||||
assert.NotNil(t, sm)
|
||||
|
||||
// error expected when creating a user with a duplicate email
|
||||
_, err = tests.CreateUser(hub, "test@test.com", "testtesttest")
|
||||
require.Error(t, err)
|
||||
|
||||
// Test collection existence. todo: move to hub package tests
|
||||
t.Run("CollectionExistence", func(t *testing.T) {
|
||||
// Verify that required collections exist
|
||||
systems, err := hub.FindCachedCollectionByNameOrId("systems")
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, systems)
|
||||
|
||||
systemStats, err := hub.FindCachedCollectionByNameOrId("system_stats")
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, systemStats)
|
||||
|
||||
containerStats, err := hub.FindCachedCollectionByNameOrId("container_stats")
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, containerStats)
|
||||
})
|
||||
|
||||
t.Run("RemoveSystem", func(t *testing.T) {
|
||||
// Get the count before adding the system
|
||||
countBefore := sm.GetSystemCount()
|
||||
|
||||
// Create a test system record
|
||||
record, err := tests.CreateRecord(hub, "systems", map[string]any{
|
||||
"name": "i-even-got-lost-at-coney-island",
|
||||
"host": "but-they-found-me",
|
||||
"port": "33914",
|
||||
"users": []string{user.Id},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the system count increased
|
||||
countAfterAdd := sm.GetSystemCount()
|
||||
assert.Equal(t, countBefore+1, countAfterAdd, "System count should increase after adding a system via event hook")
|
||||
|
||||
// Verify the system exists
|
||||
assert.True(t, sm.HasSystem(record.Id), "System should exist in the store")
|
||||
|
||||
// Remove the system
|
||||
err = sm.RemoveSystem(record.Id)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Check that the system count decreased
|
||||
countAfterRemove := sm.GetSystemCount()
|
||||
assert.Equal(t, countAfterAdd-1, countAfterRemove, "System count should decrease after removing a system")
|
||||
|
||||
// Verify the system no longer exists
|
||||
assert.False(t, sm.HasSystem(record.Id), "System should not exist in the store after removal")
|
||||
|
||||
// Verify the system is not in the list of all system IDs
|
||||
ids := sm.GetAllSystemIDs()
|
||||
assert.NotContains(t, ids, record.Id, "System ID should not be in the list of all system IDs after removal")
|
||||
|
||||
// Verify the system status is empty
|
||||
status := sm.GetSystemStatusFromStore(record.Id)
|
||||
assert.Equal(t, "", status, "System status should be empty after removal")
|
||||
|
||||
// Try to remove it again - should return an error since it's already removed
|
||||
err = sm.RemoveSystem(record.Id)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("NewRecordPending", func(t *testing.T) {
|
||||
// Create a test system
|
||||
record, err := tests.CreateRecord(hub, "systems", map[string]any{
|
||||
"name": "and-you-know",
|
||||
"host": "i-feel-very-bad",
|
||||
"port": "33914",
|
||||
"users": []string{user.Id},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add the record to the system manager
|
||||
err = sm.AddRecord(record, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test filtering records by status - should be "pending" now
|
||||
filter := "status = 'pending'"
|
||||
pendingSystems, err := hub.FindRecordsByFilter("systems", filter, "-created", 0, 0, nil)
|
||||
require.NoError(t, err)
|
||||
assert.GreaterOrEqual(t, len(pendingSystems), 1)
|
||||
})
|
||||
|
||||
t.Run("SystemStatusUpdate", func(t *testing.T) {
|
||||
// Create a test system record
|
||||
record, err := tests.CreateRecord(hub, "systems", map[string]any{
|
||||
"name": "we-used-to-sleep-on-the-beach",
|
||||
"host": "sleep-overnight-here",
|
||||
"port": "33914",
|
||||
"users": []string{user.Id},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add the record to the system manager
|
||||
err = sm.AddRecord(record, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test status changes
|
||||
initialStatus := sm.GetSystemStatusFromStore(record.Id)
|
||||
|
||||
// Set a new status
|
||||
sm.SetSystemStatusInDB(record.Id, "up")
|
||||
|
||||
// Verify status was updated
|
||||
newStatus := sm.GetSystemStatusFromStore(record.Id)
|
||||
assert.Equal(t, "up", newStatus, "System status should be updated to 'up'")
|
||||
assert.NotEqual(t, initialStatus, newStatus, "Status should have changed")
|
||||
|
||||
// Verify the database was updated
|
||||
updatedRecord, err := hub.FindRecordById("systems", record.Id)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "up", updatedRecord.Get("status"), "Database status should match")
|
||||
})
|
||||
|
||||
t.Run("HandleSystemData", func(t *testing.T) {
|
||||
// Create a test system record
|
||||
record, err := tests.CreateRecord(hub, "systems", map[string]any{
|
||||
"name": "things-changed-you-know",
|
||||
"host": "they-dont-sleep-anymore-on-the-beach",
|
||||
"port": "33914",
|
||||
"users": []string{user.Id},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create test system data
|
||||
testData := &system.CombinedData{
|
||||
Details: &system.Details{
|
||||
Hostname: "data-test.example.com",
|
||||
Kernel: "5.15.0-generic",
|
||||
Cores: 4,
|
||||
Threads: 8,
|
||||
CpuModel: "Test CPU",
|
||||
},
|
||||
Info: system.Info{
|
||||
Uptime: 3600,
|
||||
Cpu: 25.5,
|
||||
MemPct: 40.2,
|
||||
DiskPct: 60.0,
|
||||
Bandwidth: 100.0,
|
||||
AgentVersion: "1.0.0",
|
||||
},
|
||||
Stats: system.Stats{
|
||||
Cpu: 25.5,
|
||||
Mem: 16384.0,
|
||||
MemUsed: 6553.6,
|
||||
MemPct: 40.0,
|
||||
DiskTotal: 1024000.0,
|
||||
DiskUsed: 614400.0,
|
||||
DiskPct: 60.0,
|
||||
NetworkSent: 1024.0,
|
||||
NetworkRecv: 2048.0,
|
||||
},
|
||||
Containers: []*container.Stats{},
|
||||
}
|
||||
|
||||
// Test handling system data. todo: move to hub/alerts package tests
|
||||
err = hub.HandleSystemAlerts(record, testData)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("ErrorHandling", func(t *testing.T) {
|
||||
// Try to add a non-existent record
|
||||
nonExistentId := "non_existent_id"
|
||||
err := sm.RemoveSystem(nonExistentId)
|
||||
assert.Error(t, err)
|
||||
|
||||
// Try to add a system with invalid host
|
||||
system := &systems.System{
|
||||
Host: "",
|
||||
}
|
||||
err = sm.AddSystem(system)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("ConcurrentOperations", func(t *testing.T) {
|
||||
// Create a test system
|
||||
record, err := tests.CreateRecord(hub, "systems", map[string]any{
|
||||
"name": "jfkjahkfajs",
|
||||
"host": "localhost",
|
||||
"port": "33914",
|
||||
"users": []string{user.Id},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Run concurrent operations
|
||||
const goroutines = 5
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(goroutines)
|
||||
|
||||
for i := range goroutines {
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
|
||||
// Alternate between different operations
|
||||
switch i % 3 {
|
||||
case 0:
|
||||
status := fmt.Sprintf("status-%d", i)
|
||||
sm.SetSystemStatusInDB(record.Id, status)
|
||||
case 1:
|
||||
_ = sm.GetSystemStatusFromStore(record.Id)
|
||||
case 2:
|
||||
_, _ = sm.GetSystemHostPort(record.Id)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Verify system still exists and is in a valid state
|
||||
assert.True(t, sm.HasSystem(record.Id), "System should still exist after concurrent operations")
|
||||
status := sm.GetSystemStatusFromStore(record.Id)
|
||||
assert.NotEmpty(t, status, "System should have a status after concurrent operations")
|
||||
})
|
||||
|
||||
t.Run("ContextCancellation", func(t *testing.T) {
|
||||
// Create a test system record
|
||||
record, err := tests.CreateRecord(hub, "systems", map[string]any{
|
||||
"name": "lkhsdfsjf",
|
||||
"host": "localhost",
|
||||
"port": "33914",
|
||||
"users": []string{user.Id},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the system exists in the store
|
||||
assert.True(t, sm.HasSystem(record.Id), "System should exist in the store")
|
||||
|
||||
// Store the original context and cancel function
|
||||
originalCtx, originalCancel, err := sm.GetSystemContextFromStore(record.Id)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Ensure the context is not nil
|
||||
assert.NotNil(t, originalCtx, "System context should not be nil")
|
||||
assert.NotNil(t, originalCancel, "System cancel function should not be nil")
|
||||
|
||||
// Cancel the context
|
||||
originalCancel()
|
||||
|
||||
// Wait a short time for cancellation to propagate
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// Verify the context is done
|
||||
select {
|
||||
case <-originalCtx.Done():
|
||||
// Context was properly cancelled
|
||||
default:
|
||||
t.Fatal("Context was not cancelled")
|
||||
}
|
||||
|
||||
// Verify the system is still in the store (cancellation shouldn't remove it)
|
||||
assert.True(t, sm.HasSystem(record.Id), "System should still exist after context cancellation")
|
||||
|
||||
// Explicitly remove the system
|
||||
err = sm.RemoveSystem(record.Id)
|
||||
assert.NoError(t, err, "RemoveSystem should succeed")
|
||||
|
||||
// Verify the system is removed
|
||||
assert.False(t, sm.HasSystem(record.Id), "System should be removed after RemoveSystem")
|
||||
|
||||
// Try to remove it again - should return an error
|
||||
err = sm.RemoveSystem(record.Id)
|
||||
assert.Error(t, err, "RemoveSystem should fail for non-existent system")
|
||||
|
||||
// Add the system back
|
||||
err = sm.AddRecord(record, nil)
|
||||
require.NoError(t, err, "AddRecord should succeed")
|
||||
|
||||
// Verify the system is back in the store
|
||||
assert.True(t, sm.HasSystem(record.Id), "System should exist after re-adding")
|
||||
|
||||
// Verify a new context was created
|
||||
newCtx, newCancel, err := sm.GetSystemContextFromStore(record.Id)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, newCtx, "New system context should not be nil")
|
||||
assert.NotNil(t, newCancel, "New system cancel function should not be nil")
|
||||
assert.NotEqual(t, originalCtx, newCtx, "New context should be different from original")
|
||||
|
||||
// Clean up
|
||||
err = sm.RemoveSystem(record.Id)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestHasUser(t *testing.T) {
|
||||
hub, err := tests.NewTestHub(t.TempDir())
|
||||
require.NoError(t, err)
|
||||
defer hub.Cleanup()
|
||||
|
||||
sm := hub.GetSystemManager()
|
||||
err = sm.Initialize()
|
||||
require.NoError(t, err)
|
||||
|
||||
user1, err := tests.CreateUser(hub, "user1@test.com", "password123")
|
||||
require.NoError(t, err)
|
||||
user2, err := tests.CreateUser(hub, "user2@test.com", "password123")
|
||||
require.NoError(t, err)
|
||||
|
||||
systemRecord, err := tests.CreateRecord(hub, "systems", map[string]any{
|
||||
"name": "has-user-test",
|
||||
"host": "127.0.0.1",
|
||||
"port": "33914",
|
||||
"users": []string{user1.Id},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
sys, err := sm.GetSystemFromStore(systemRecord.Id)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("user in list returns true", func(t *testing.T) {
|
||||
assert.True(t, sys.HasUser(hub, user1))
|
||||
})
|
||||
|
||||
t.Run("user not in list returns false", func(t *testing.T) {
|
||||
assert.False(t, sys.HasUser(hub, user2))
|
||||
})
|
||||
|
||||
t.Run("unknown user ID returns false", func(t *testing.T) {
|
||||
assert.False(t, sys.HasUser(hub, nil))
|
||||
})
|
||||
|
||||
t.Run("SHARE_ALL_SYSTEMS=true grants access to non-member", func(t *testing.T) {
|
||||
t.Setenv("SHARE_ALL_SYSTEMS", "true")
|
||||
assert.True(t, sys.HasUser(hub, user2))
|
||||
})
|
||||
|
||||
t.Run("BESZEL_HUB_SHARE_ALL_SYSTEMS=true grants access to non-member", func(t *testing.T) {
|
||||
t.Setenv("BESZEL_HUB_SHARE_ALL_SYSTEMS", "true")
|
||||
assert.True(t, sys.HasUser(hub, user2))
|
||||
})
|
||||
|
||||
t.Run("additional user works", func(t *testing.T) {
|
||||
assert.False(t, sys.HasUser(hub, user2))
|
||||
systemRecord.Set("users", []string{user1.Id, user2.Id})
|
||||
err = hub.Save(systemRecord)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, sys.HasUser(hub, user1))
|
||||
assert.True(t, sys.HasUser(hub, user2))
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,127 @@
|
||||
//go:build testing
|
||||
|
||||
package systems
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
entities "github.com/henrygd/beszel/internal/entities/system"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
)
|
||||
|
||||
// The hub integration tests create/replace systems and cleanup the test apps quickly.
|
||||
// Background SMART fetching can outlive teardown and crash in PocketBase internals (nil DB).
|
||||
//
|
||||
// We keep the explicit SMART refresh endpoint / method available, but disable
|
||||
// the automatic background fetch during tests.
|
||||
func backgroundSmartFetchEnabled() bool { return false }
|
||||
|
||||
// TESTING ONLY: GetSystemCount returns the number of systems in the store
|
||||
func (sm *SystemManager) GetSystemCount() int {
|
||||
return sm.systems.Length()
|
||||
}
|
||||
|
||||
// TESTING ONLY: HasSystem checks if a system with the given ID exists in the store
|
||||
func (sm *SystemManager) HasSystem(systemID string) bool {
|
||||
return sm.systems.Has(systemID)
|
||||
}
|
||||
|
||||
// TESTING ONLY: GetSystemStatusFromStore returns the status of a system with the given ID
|
||||
// Returns an empty string if the system doesn't exist
|
||||
func (sm *SystemManager) GetSystemStatusFromStore(systemID string) string {
|
||||
sys, ok := sm.systems.GetOk(systemID)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return sys.Status
|
||||
}
|
||||
|
||||
// TESTING ONLY: GetSystemContextFromStore returns the context and cancel function for a system
|
||||
func (sm *SystemManager) GetSystemContextFromStore(systemID string) (context.Context, context.CancelFunc, error) {
|
||||
sys, ok := sm.systems.GetOk(systemID)
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("no system")
|
||||
}
|
||||
return sys.ctx, sys.cancel, nil
|
||||
}
|
||||
|
||||
// TESTING ONLY: GetSystemFromStore returns a store from the system
|
||||
func (sm *SystemManager) GetSystemFromStore(systemID string) (*System, error) {
|
||||
sys, ok := sm.systems.GetOk(systemID)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no system")
|
||||
}
|
||||
return sys, nil
|
||||
}
|
||||
|
||||
// TESTING ONLY: GetAllSystemIDs returns a slice of all system IDs in the store
|
||||
func (sm *SystemManager) GetAllSystemIDs() []string {
|
||||
data := sm.systems.GetAll()
|
||||
ids := make([]string, 0, len(data))
|
||||
for id := range data {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// TESTING ONLY: GetSystemData returns the combined data for a system with the given ID
|
||||
// Returns nil if the system doesn't exist
|
||||
// This method is intended for testing
|
||||
func (sm *SystemManager) GetSystemData(systemID string) *entities.CombinedData {
|
||||
sys, ok := sm.systems.GetOk(systemID)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return sys.data
|
||||
}
|
||||
|
||||
// TESTING ONLY: GetSystemHostPort returns the host and port for a system with the given ID
|
||||
// Returns empty strings if the system doesn't exist
|
||||
func (sm *SystemManager) GetSystemHostPort(systemID string) (string, string) {
|
||||
sys, ok := sm.systems.GetOk(systemID)
|
||||
if !ok {
|
||||
return "", ""
|
||||
}
|
||||
return sys.Host, sys.Port
|
||||
}
|
||||
|
||||
// TESTING ONLY: SetSystemStatusInDB sets the status of a system directly and updates the database record
|
||||
// This is intended for testing
|
||||
// Returns false if the system doesn't exist
|
||||
func (sm *SystemManager) SetSystemStatusInDB(systemID string, status string) bool {
|
||||
if !sm.HasSystem(systemID) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Update the database record
|
||||
record, err := sm.hub.FindRecordById("systems", systemID)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
record.Set("status", status)
|
||||
err = sm.hub.Save(record)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// TESTING ONLY: RemoveAllSystems removes all systems from the store
|
||||
func (sm *SystemManager) RemoveAllSystems() {
|
||||
for _, system := range sm.systems.GetAll() {
|
||||
sm.RemoveSystem(system.Id)
|
||||
}
|
||||
sm.smartFetchMap.StopCleaner()
|
||||
}
|
||||
|
||||
func (s *System) StopUpdater() {
|
||||
s.cancel()
|
||||
}
|
||||
|
||||
func (s *System) CreateRecords(data *entities.CombinedData) (*core.Record, error) {
|
||||
s.data = data
|
||||
return s.createRecords(data)
|
||||
}
|
||||
@@ -0,0 +1,227 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/blang/semver"
|
||||
"github.com/fxamacker/cbor/v2"
|
||||
"github.com/henrygd/beszel/internal/common"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// SSHTransport implements Transport over SSH connections.
|
||||
type SSHTransport struct {
|
||||
client *ssh.Client
|
||||
config *ssh.ClientConfig
|
||||
host string
|
||||
port string
|
||||
agentVersion semver.Version
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
// SSHTransportConfig holds configuration for creating an SSH transport.
|
||||
type SSHTransportConfig struct {
|
||||
Host string
|
||||
Port string
|
||||
Config *ssh.ClientConfig
|
||||
AgentVersion semver.Version
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
||||
// NewSSHTransport creates a new SSH transport with the given configuration.
|
||||
func NewSSHTransport(cfg SSHTransportConfig) *SSHTransport {
|
||||
timeout := cfg.Timeout
|
||||
if timeout == 0 {
|
||||
timeout = 4 * time.Second
|
||||
}
|
||||
return &SSHTransport{
|
||||
config: cfg.Config,
|
||||
host: cfg.Host,
|
||||
port: cfg.Port,
|
||||
agentVersion: cfg.AgentVersion,
|
||||
timeout: timeout,
|
||||
}
|
||||
}
|
||||
|
||||
// SetClient sets the SSH client for reuse across requests.
|
||||
func (t *SSHTransport) SetClient(client *ssh.Client) {
|
||||
t.client = client
|
||||
}
|
||||
|
||||
// SetAgentVersion sets the agent version (extracted from SSH handshake).
|
||||
func (t *SSHTransport) SetAgentVersion(version semver.Version) {
|
||||
t.agentVersion = version
|
||||
}
|
||||
|
||||
// GetClient returns the current SSH client (for connection management).
|
||||
func (t *SSHTransport) GetClient() *ssh.Client {
|
||||
return t.client
|
||||
}
|
||||
|
||||
// GetAgentVersion returns the agent version.
|
||||
func (t *SSHTransport) GetAgentVersion() semver.Version {
|
||||
return t.agentVersion
|
||||
}
|
||||
|
||||
// Request sends a request to the agent via SSH and unmarshals the response.
|
||||
func (t *SSHTransport) Request(ctx context.Context, action common.WebSocketAction, req any, dest any) error {
|
||||
if t.client == nil {
|
||||
if err := t.connect(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
session, err := t.createSessionWithTimeout(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer session.Close()
|
||||
|
||||
stdout, err := session.StdoutPipe()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
stdin, err := session.StdinPipe()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := session.Shell(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Send request
|
||||
hubReq := common.HubRequest[any]{Action: action, Data: req}
|
||||
if err := cbor.NewEncoder(stdin).Encode(hubReq); err != nil {
|
||||
return fmt.Errorf("failed to encode request: %w", err)
|
||||
}
|
||||
stdin.Close()
|
||||
|
||||
// Read response
|
||||
var resp common.AgentResponse
|
||||
if err := cbor.NewDecoder(stdout).Decode(&resp); err != nil {
|
||||
return fmt.Errorf("failed to decode response: %w", err)
|
||||
}
|
||||
|
||||
if resp.Error != "" {
|
||||
return errors.New(resp.Error)
|
||||
}
|
||||
|
||||
if err := session.Wait(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return UnmarshalResponse(resp, action, dest)
|
||||
}
|
||||
|
||||
// IsConnected returns true if the SSH connection is active.
|
||||
func (t *SSHTransport) IsConnected() bool {
|
||||
return t.client != nil
|
||||
}
|
||||
|
||||
// Close terminates the SSH connection.
|
||||
func (t *SSHTransport) Close() {
|
||||
if t.client != nil {
|
||||
t.client.Close()
|
||||
t.client = nil
|
||||
}
|
||||
}
|
||||
|
||||
// connect establishes a new SSH connection.
|
||||
func (t *SSHTransport) connect() error {
|
||||
if t.config == nil {
|
||||
return errors.New("SSH config not set")
|
||||
}
|
||||
|
||||
network := "tcp"
|
||||
host := t.host
|
||||
if strings.HasPrefix(host, "/") {
|
||||
network = "unix"
|
||||
} else {
|
||||
host = net.JoinHostPort(host, t.port)
|
||||
}
|
||||
|
||||
client, err := ssh.Dial(network, host, t.config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
t.client = client
|
||||
|
||||
// Extract agent version from server version string
|
||||
t.agentVersion, _ = extractAgentVersion(string(client.Conn.ServerVersion()))
|
||||
return nil
|
||||
}
|
||||
|
||||
// createSessionWithTimeout creates a new SSH session with a timeout.
|
||||
func (t *SSHTransport) createSessionWithTimeout(ctx context.Context) (*ssh.Session, error) {
|
||||
if t.client == nil {
|
||||
return nil, errors.New("client not initialized")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, t.timeout)
|
||||
defer cancel()
|
||||
|
||||
sessionChan := make(chan *ssh.Session, 1)
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
session, err := t.client.NewSession()
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
} else {
|
||||
sessionChan <- session
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case session := <-sessionChan:
|
||||
return session, nil
|
||||
case err := <-errChan:
|
||||
return nil, err
|
||||
case <-ctx.Done():
|
||||
return nil, errors.New("timeout creating session")
|
||||
}
|
||||
}
|
||||
|
||||
// extractAgentVersion extracts the beszel version from SSH server version string.
|
||||
func extractAgentVersion(versionString string) (semver.Version, error) {
|
||||
_, after, _ := strings.Cut(versionString, "_")
|
||||
return semver.Parse(after)
|
||||
}
|
||||
|
||||
// RequestWithRetry sends a request with automatic retry on connection failures.
|
||||
func (t *SSHTransport) RequestWithRetry(ctx context.Context, action common.WebSocketAction, req any, dest any, retries int) error {
|
||||
var lastErr error
|
||||
for attempt := 0; attempt <= retries; attempt++ {
|
||||
err := t.Request(ctx, action, req, dest)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
lastErr = err
|
||||
|
||||
// Check if it's a connection error that warrants a retry
|
||||
if isConnectionError(err) && attempt < retries {
|
||||
t.Close()
|
||||
continue
|
||||
}
|
||||
return err
|
||||
}
|
||||
return lastErr
|
||||
}
|
||||
|
||||
// isConnectionError checks if an error indicates a connection problem.
|
||||
func isConnectionError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
errStr := err.Error()
|
||||
return strings.Contains(errStr, "connection") ||
|
||||
strings.Contains(errStr, "EOF") ||
|
||||
strings.Contains(errStr, "closed") ||
|
||||
errors.Is(err, io.EOF)
|
||||
}
|
||||
@@ -0,0 +1,112 @@
|
||||
// Package transport provides a unified abstraction for hub-agent communication
|
||||
// over different transports (WebSocket, SSH).
|
||||
package transport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/fxamacker/cbor/v2"
|
||||
"github.com/henrygd/beszel/internal/common"
|
||||
"github.com/henrygd/beszel/internal/entities/smart"
|
||||
"github.com/henrygd/beszel/internal/entities/system"
|
||||
"github.com/henrygd/beszel/internal/entities/systemd"
|
||||
)
|
||||
|
||||
// Transport defines the interface for hub-agent communication.
|
||||
// Both WebSocket and SSH transports implement this interface.
|
||||
type Transport interface {
|
||||
// Request sends a request to the agent and unmarshals the response into dest.
|
||||
// The dest parameter should be a pointer to the expected response type.
|
||||
Request(ctx context.Context, action common.WebSocketAction, req any, dest any) error
|
||||
// IsConnected returns true if the transport connection is active.
|
||||
IsConnected() bool
|
||||
// Close terminates the transport connection.
|
||||
Close()
|
||||
}
|
||||
|
||||
// UnmarshalResponse unmarshals an AgentResponse into the destination type.
|
||||
// It first checks the generic Data field (0.19+ agents), then falls back
|
||||
// to legacy typed fields for backward compatibility with 0.18.0 agents.
|
||||
func UnmarshalResponse(resp common.AgentResponse, action common.WebSocketAction, dest any) error {
|
||||
if dest == nil {
|
||||
return errors.New("nil destination")
|
||||
}
|
||||
// Try generic Data field first (0.19+)
|
||||
if len(resp.Data) > 0 {
|
||||
if err := cbor.Unmarshal(resp.Data, dest); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal generic response data: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
// Fall back to legacy typed fields for older agents/hubs.
|
||||
return unmarshalLegacyResponse(resp, action, dest)
|
||||
}
|
||||
|
||||
// unmarshalLegacyResponse handles legacy responses that use typed fields.
|
||||
func unmarshalLegacyResponse(resp common.AgentResponse, action common.WebSocketAction, dest any) error {
|
||||
switch action {
|
||||
case common.GetData:
|
||||
d, ok := dest.(*system.CombinedData)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected dest type for GetData: %T", dest)
|
||||
}
|
||||
if resp.SystemData == nil {
|
||||
return errors.New("no system data in response")
|
||||
}
|
||||
*d = *resp.SystemData
|
||||
return nil
|
||||
case common.CheckFingerprint:
|
||||
d, ok := dest.(*common.FingerprintResponse)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected dest type for CheckFingerprint: %T", dest)
|
||||
}
|
||||
if resp.Fingerprint == nil {
|
||||
return errors.New("no fingerprint in response")
|
||||
}
|
||||
*d = *resp.Fingerprint
|
||||
return nil
|
||||
case common.GetContainerLogs:
|
||||
d, ok := dest.(*string)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected dest type for GetContainerLogs: %T", dest)
|
||||
}
|
||||
if resp.String == nil {
|
||||
return errors.New("no logs in response")
|
||||
}
|
||||
*d = *resp.String
|
||||
return nil
|
||||
case common.GetContainerInfo:
|
||||
d, ok := dest.(*string)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected dest type for GetContainerInfo: %T", dest)
|
||||
}
|
||||
if resp.String == nil {
|
||||
return errors.New("no info in response")
|
||||
}
|
||||
*d = *resp.String
|
||||
return nil
|
||||
case common.GetSmartData:
|
||||
d, ok := dest.(*map[string]smart.SmartData)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected dest type for GetSmartData: %T", dest)
|
||||
}
|
||||
if resp.SmartData == nil {
|
||||
return errors.New("no SMART data in response")
|
||||
}
|
||||
*d = resp.SmartData
|
||||
return nil
|
||||
case common.GetSystemdInfo:
|
||||
d, ok := dest.(*systemd.ServiceDetails)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected dest type for GetSystemdInfo: %T", dest)
|
||||
}
|
||||
if resp.ServiceInfo == nil {
|
||||
return errors.New("no systemd info in response")
|
||||
}
|
||||
*d = resp.ServiceInfo
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("unsupported action: %d", action)
|
||||
}
|
||||
@@ -0,0 +1,74 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/fxamacker/cbor/v2"
|
||||
"github.com/henrygd/beszel"
|
||||
"github.com/henrygd/beszel/internal/common"
|
||||
"github.com/henrygd/beszel/internal/hub/ws"
|
||||
)
|
||||
|
||||
// ErrWebSocketNotConnected indicates a WebSocket transport is not currently connected.
|
||||
var ErrWebSocketNotConnected = errors.New("websocket not connected")
|
||||
|
||||
// WebSocketTransport implements Transport over WebSocket connections.
|
||||
type WebSocketTransport struct {
|
||||
wsConn *ws.WsConn
|
||||
}
|
||||
|
||||
// NewWebSocketTransport creates a new WebSocket transport wrapper.
|
||||
func NewWebSocketTransport(wsConn *ws.WsConn) *WebSocketTransport {
|
||||
return &WebSocketTransport{wsConn: wsConn}
|
||||
}
|
||||
|
||||
// Request sends a request to the agent via WebSocket and unmarshals the response.
|
||||
func (t *WebSocketTransport) Request(ctx context.Context, action common.WebSocketAction, req any, dest any) error {
|
||||
if !t.IsConnected() {
|
||||
return ErrWebSocketNotConnected
|
||||
}
|
||||
|
||||
pendingReq, err := t.wsConn.SendRequest(ctx, action, req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Wait for response
|
||||
select {
|
||||
case message := <-pendingReq.ResponseCh:
|
||||
defer message.Close()
|
||||
defer pendingReq.Cancel()
|
||||
|
||||
// Legacy agents (< MinVersionAgentResponse) respond with a raw payload instead of an AgentResponse wrapper.
|
||||
if t.wsConn.AgentVersion().LT(beszel.MinVersionAgentResponse) {
|
||||
return cbor.Unmarshal(message.Data.Bytes(), dest)
|
||||
}
|
||||
|
||||
var agentResponse common.AgentResponse
|
||||
if err := cbor.Unmarshal(message.Data.Bytes(), &agentResponse); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if agentResponse.Error != "" {
|
||||
return errors.New(agentResponse.Error)
|
||||
}
|
||||
|
||||
return UnmarshalResponse(agentResponse, action, dest)
|
||||
|
||||
case <-pendingReq.Context.Done():
|
||||
return pendingReq.Context.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// IsConnected returns true if the WebSocket connection is active.
|
||||
func (t *WebSocketTransport) IsConnected() bool {
|
||||
return t.wsConn != nil && t.wsConn.IsConnected()
|
||||
}
|
||||
|
||||
// Close terminates the WebSocket connection.
|
||||
func (t *WebSocketTransport) Close() {
|
||||
if t.wsConn != nil {
|
||||
t.wsConn.Close(nil)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,93 @@
|
||||
package hub
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"os/exec"
|
||||
|
||||
"github.com/henrygd/beszel/internal/ghupdate"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// Update updates beszel to the latest version
|
||||
func Update(cmd *cobra.Command, _ []string) {
|
||||
dataDir := os.TempDir()
|
||||
|
||||
// set dataDir to ./beszel_data if it exists
|
||||
if _, err := os.Stat("./beszel_data"); err == nil {
|
||||
dataDir = "./beszel_data"
|
||||
}
|
||||
|
||||
// Check if china-mirrors flag is set
|
||||
useMirror, _ := cmd.Flags().GetBool("china-mirrors")
|
||||
|
||||
// Get the executable path before update
|
||||
exePath, err := os.Executable()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
updated, err := ghupdate.Update(ghupdate.Config{
|
||||
ArchiveExecutable: "beszel",
|
||||
DataDir: dataDir,
|
||||
UseMirror: useMirror,
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
if !updated {
|
||||
return
|
||||
}
|
||||
|
||||
// make sure the file is executable
|
||||
if err := os.Chmod(exePath, 0755); err != nil {
|
||||
fmt.Printf("Warning: failed to set executable permissions: %v\n", err)
|
||||
}
|
||||
|
||||
// Fix SELinux context if necessary
|
||||
if err := ghupdate.HandleSELinuxContext(exePath); err != nil {
|
||||
ghupdate.ColorPrintf(ghupdate.ColorYellow, "Warning: SELinux context handling: %v", err)
|
||||
}
|
||||
|
||||
// Try to restart the service if it's running
|
||||
restartService()
|
||||
}
|
||||
|
||||
// restartService attempts to restart the beszel service
|
||||
func restartService() {
|
||||
// Check if we're running as a service by looking for systemd
|
||||
if _, err := exec.LookPath("systemctl"); err == nil {
|
||||
// Check if beszel service exists and is active
|
||||
cmd := exec.Command("systemctl", "is-active", "beszel.service")
|
||||
if err := cmd.Run(); err == nil {
|
||||
ghupdate.ColorPrint(ghupdate.ColorYellow, "Restarting beszel service...")
|
||||
restartCmd := exec.Command("systemctl", "restart", "beszel.service")
|
||||
if err := restartCmd.Run(); err != nil {
|
||||
ghupdate.ColorPrintf(ghupdate.ColorYellow, "Warning: Failed to restart service: %v\n", err)
|
||||
ghupdate.ColorPrint(ghupdate.ColorYellow, "Please restart the service manually: sudo systemctl restart beszel")
|
||||
} else {
|
||||
ghupdate.ColorPrint(ghupdate.ColorGreen, "Service restarted successfully")
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Check for OpenRC (Alpine Linux)
|
||||
if _, err := exec.LookPath("rc-service"); err == nil {
|
||||
cmd := exec.Command("rc-service", "beszel", "status")
|
||||
if err := cmd.Run(); err == nil {
|
||||
ghupdate.ColorPrint(ghupdate.ColorYellow, "Restarting beszel service...")
|
||||
restartCmd := exec.Command("rc-service", "beszel", "restart")
|
||||
if err := restartCmd.Run(); err != nil {
|
||||
ghupdate.ColorPrintf(ghupdate.ColorYellow, "Warning: Failed to restart service: %v\n", err)
|
||||
ghupdate.ColorPrint(ghupdate.ColorYellow, "Please restart the service manually: sudo rc-service beszel restart")
|
||||
} else {
|
||||
ghupdate.ColorPrint(ghupdate.ColorGreen, "Service restarted successfully")
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
ghupdate.ColorPrint(ghupdate.ColorYellow, "Service restart not attempted. If running as a service, restart manually.")
|
||||
}
|
||||
@@ -0,0 +1,12 @@
|
||||
// Package utils provides utility functions for the hub.
|
||||
package utils
|
||||
|
||||
import "os"
|
||||
|
||||
// GetEnv retrieves an environment variable with a "BESZEL_HUB_" prefix, or falls back to the unprefixed key.
|
||||
func GetEnv(key string) (value string, exists bool) {
|
||||
if value, exists = os.LookupEnv("BESZEL_HUB_" + key); exists {
|
||||
return value, exists
|
||||
}
|
||||
return os.LookupEnv(key)
|
||||
}
|
||||
@@ -0,0 +1,72 @@
|
||||
package ws
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/fxamacker/cbor/v2"
|
||||
"github.com/henrygd/beszel/internal/common"
|
||||
"github.com/lxzan/gws"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// ResponseHandler defines interface for handling agent responses.
|
||||
// This is used by handleAgentRequest for legacy response handling.
|
||||
type ResponseHandler interface {
|
||||
Handle(agentResponse common.AgentResponse) error
|
||||
HandleLegacy(rawData []byte) error
|
||||
}
|
||||
|
||||
// BaseHandler provides a default implementation that can be embedded to make HandleLegacy optional
|
||||
type BaseHandler struct{}
|
||||
|
||||
func (h *BaseHandler) HandleLegacy(rawData []byte) error {
|
||||
return errors.New("legacy format not supported")
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////
|
||||
// Fingerprint handling (used for WebSocket authentication)
|
||||
////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// fingerprintHandler implements ResponseHandler for fingerprint requests
|
||||
type fingerprintHandler struct {
|
||||
result *common.FingerprintResponse
|
||||
}
|
||||
|
||||
func (h *fingerprintHandler) HandleLegacy(rawData []byte) error {
|
||||
return cbor.Unmarshal(rawData, h.result)
|
||||
}
|
||||
|
||||
func (h *fingerprintHandler) Handle(agentResponse common.AgentResponse) error {
|
||||
if agentResponse.Fingerprint != nil {
|
||||
*h.result = *agentResponse.Fingerprint
|
||||
return nil
|
||||
}
|
||||
return errors.New("no fingerprint data in response")
|
||||
}
|
||||
|
||||
// GetFingerprint authenticates with the agent using SSH signature and returns the agent's fingerprint.
|
||||
func (ws *WsConn) GetFingerprint(ctx context.Context, token string, signer ssh.Signer, needSysInfo bool) (common.FingerprintResponse, error) {
|
||||
if !ws.IsConnected() {
|
||||
return common.FingerprintResponse{}, gws.ErrConnClosed
|
||||
}
|
||||
|
||||
challenge := []byte(token)
|
||||
signature, err := signer.Sign(nil, challenge)
|
||||
if err != nil {
|
||||
return common.FingerprintResponse{}, err
|
||||
}
|
||||
|
||||
req, err := ws.requestManager.SendRequest(ctx, common.CheckFingerprint, common.FingerprintRequest{
|
||||
Signature: signature.Blob,
|
||||
NeedSysInfo: needSysInfo,
|
||||
})
|
||||
if err != nil {
|
||||
return common.FingerprintResponse{}, err
|
||||
}
|
||||
|
||||
var result common.FingerprintResponse
|
||||
handler := &fingerprintHandler{result: &result}
|
||||
err = ws.handleAgentRequest(req, handler)
|
||||
return result, err
|
||||
}
|
||||
@@ -0,0 +1,199 @@
|
||||
package ws
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/fxamacker/cbor/v2"
|
||||
"github.com/henrygd/beszel/internal/common"
|
||||
"github.com/lxzan/gws"
|
||||
)
|
||||
|
||||
// RequestID uniquely identifies a request
|
||||
type RequestID uint32
|
||||
|
||||
// PendingRequest tracks an in-flight request
|
||||
type PendingRequest struct {
|
||||
ID RequestID
|
||||
ResponseCh chan *gws.Message
|
||||
Context context.Context
|
||||
Cancel context.CancelFunc
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
// RequestManager handles concurrent requests to an agent
|
||||
type RequestManager struct {
|
||||
sync.RWMutex
|
||||
conn *gws.Conn
|
||||
pendingReqs map[RequestID]*PendingRequest
|
||||
nextID atomic.Uint32
|
||||
}
|
||||
|
||||
// NewRequestManager creates a new request manager for a WebSocket connection
|
||||
func NewRequestManager(conn *gws.Conn) *RequestManager {
|
||||
rm := &RequestManager{
|
||||
conn: conn,
|
||||
pendingReqs: make(map[RequestID]*PendingRequest),
|
||||
}
|
||||
return rm
|
||||
}
|
||||
|
||||
// SendRequest sends a request and returns a channel for the response
|
||||
func (rm *RequestManager) SendRequest(ctx context.Context, action common.WebSocketAction, data any) (*PendingRequest, error) {
|
||||
reqID := RequestID(rm.nextID.Add(1))
|
||||
|
||||
// Respect any caller-provided deadline. If none is set, apply a reasonable default
|
||||
// so pending requests don't live forever if the agent never responds.
|
||||
reqCtx := ctx
|
||||
var cancel context.CancelFunc
|
||||
if _, hasDeadline := ctx.Deadline(); hasDeadline {
|
||||
reqCtx, cancel = context.WithCancel(ctx)
|
||||
} else {
|
||||
reqCtx, cancel = context.WithTimeout(ctx, 5*time.Second)
|
||||
}
|
||||
|
||||
req := &PendingRequest{
|
||||
ID: reqID,
|
||||
ResponseCh: make(chan *gws.Message, 1),
|
||||
Context: reqCtx,
|
||||
Cancel: cancel,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
rm.Lock()
|
||||
rm.pendingReqs[reqID] = req
|
||||
rm.Unlock()
|
||||
|
||||
hubReq := common.HubRequest[any]{
|
||||
Id: (*uint32)(&reqID),
|
||||
Action: action,
|
||||
Data: data,
|
||||
}
|
||||
|
||||
// Send the request
|
||||
if err := rm.sendMessage(hubReq); err != nil {
|
||||
rm.cancelRequest(reqID)
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
|
||||
// Start cleanup watcher for timeout/cancellation
|
||||
go rm.cleanupRequest(req)
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// sendMessage encodes and sends a message over WebSocket
|
||||
func (rm *RequestManager) sendMessage(data any) error {
|
||||
if rm.conn == nil {
|
||||
return gws.ErrConnClosed
|
||||
}
|
||||
|
||||
bytes, err := cbor.Marshal(data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
return rm.conn.WriteMessage(gws.OpcodeBinary, bytes)
|
||||
}
|
||||
|
||||
// handleResponse processes a single response message
|
||||
func (rm *RequestManager) handleResponse(message *gws.Message) {
|
||||
var response common.AgentResponse
|
||||
if err := cbor.Unmarshal(message.Data.Bytes(), &response); err != nil {
|
||||
// Legacy response without ID - route to first pending request of any type
|
||||
rm.routeLegacyResponse(message)
|
||||
return
|
||||
}
|
||||
|
||||
if response.Id == nil {
|
||||
rm.routeLegacyResponse(message)
|
||||
return
|
||||
}
|
||||
|
||||
reqID := RequestID(*response.Id)
|
||||
|
||||
rm.RLock()
|
||||
req, exists := rm.pendingReqs[reqID]
|
||||
rm.RUnlock()
|
||||
|
||||
if !exists {
|
||||
// Request not found (might have timed out) - close the message
|
||||
message.Close()
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case req.ResponseCh <- message:
|
||||
// Message successfully delivered - the receiver will close it
|
||||
rm.deleteRequest(reqID)
|
||||
case <-req.Context.Done():
|
||||
// Request was cancelled/timed out - close the message
|
||||
message.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// routeLegacyResponse handles responses that don't have request IDs (backwards compatibility)
|
||||
func (rm *RequestManager) routeLegacyResponse(message *gws.Message) {
|
||||
// Snapshot the oldest pending request without holding the lock during send
|
||||
rm.RLock()
|
||||
var oldestReq *PendingRequest
|
||||
for _, req := range rm.pendingReqs {
|
||||
if oldestReq == nil || req.CreatedAt.Before(oldestReq.CreatedAt) {
|
||||
oldestReq = req
|
||||
}
|
||||
}
|
||||
rm.RUnlock()
|
||||
|
||||
if oldestReq != nil {
|
||||
select {
|
||||
case oldestReq.ResponseCh <- message:
|
||||
// Message successfully delivered - the receiver will close it
|
||||
rm.deleteRequest(oldestReq.ID)
|
||||
case <-oldestReq.Context.Done():
|
||||
// Request was cancelled - close the message
|
||||
message.Close()
|
||||
}
|
||||
} else {
|
||||
// No pending requests - close the message
|
||||
message.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupRequest handles request timeout and cleanup
|
||||
func (rm *RequestManager) cleanupRequest(req *PendingRequest) {
|
||||
<-req.Context.Done()
|
||||
rm.cancelRequest(req.ID)
|
||||
}
|
||||
|
||||
// cancelRequest removes a request and cancels its context
|
||||
func (rm *RequestManager) cancelRequest(reqID RequestID) {
|
||||
rm.Lock()
|
||||
defer rm.Unlock()
|
||||
|
||||
if req, exists := rm.pendingReqs[reqID]; exists {
|
||||
req.Cancel()
|
||||
delete(rm.pendingReqs, reqID)
|
||||
}
|
||||
}
|
||||
|
||||
// deleteRequest removes a request from the pending map without cancelling its context.
|
||||
func (rm *RequestManager) deleteRequest(reqID RequestID) {
|
||||
rm.Lock()
|
||||
defer rm.Unlock()
|
||||
delete(rm.pendingReqs, reqID)
|
||||
}
|
||||
|
||||
// Close shuts down the request manager
|
||||
func (rm *RequestManager) Close() {
|
||||
rm.Lock()
|
||||
defer rm.Unlock()
|
||||
|
||||
// Cancel all pending requests
|
||||
for _, req := range rm.pendingReqs {
|
||||
req.Cancel()
|
||||
}
|
||||
rm.pendingReqs = make(map[RequestID]*PendingRequest)
|
||||
}
|
||||
@@ -0,0 +1,80 @@
|
||||
//go:build testing
|
||||
|
||||
package ws
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestRequestManager_BasicFunctionality tests the request manager without mocking gws.Conn
|
||||
func TestRequestManager_BasicFunctionality(t *testing.T) {
|
||||
// We'll test the core logic without mocking the connection
|
||||
// since the gws.Conn interface is complex to mock properly
|
||||
|
||||
t.Run("request ID generation", func(t *testing.T) {
|
||||
// Test that request IDs are generated sequentially and uniquely
|
||||
rm := &RequestManager{}
|
||||
|
||||
// Simulate multiple ID generations
|
||||
id1 := rm.nextID.Add(1)
|
||||
id2 := rm.nextID.Add(1)
|
||||
id3 := rm.nextID.Add(1)
|
||||
|
||||
assert.NotEqual(t, id1, id2)
|
||||
assert.NotEqual(t, id2, id3)
|
||||
assert.Greater(t, id2, id1)
|
||||
assert.Greater(t, id3, id2)
|
||||
})
|
||||
|
||||
t.Run("pending request tracking", func(t *testing.T) {
|
||||
rm := &RequestManager{
|
||||
pendingReqs: make(map[RequestID]*PendingRequest),
|
||||
}
|
||||
|
||||
// Initially no pending requests
|
||||
assert.Equal(t, 0, rm.GetPendingCount())
|
||||
|
||||
// Add some fake pending requests
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
req1 := &PendingRequest{
|
||||
ID: RequestID(1),
|
||||
Context: ctx,
|
||||
Cancel: cancel,
|
||||
}
|
||||
req2 := &PendingRequest{
|
||||
ID: RequestID(2),
|
||||
Context: ctx,
|
||||
Cancel: cancel,
|
||||
}
|
||||
|
||||
rm.pendingReqs[req1.ID] = req1
|
||||
rm.pendingReqs[req2.ID] = req2
|
||||
|
||||
assert.Equal(t, 2, rm.GetPendingCount())
|
||||
|
||||
// Remove one
|
||||
delete(rm.pendingReqs, req1.ID)
|
||||
assert.Equal(t, 1, rm.GetPendingCount())
|
||||
|
||||
// Remove all
|
||||
delete(rm.pendingReqs, req2.ID)
|
||||
assert.Equal(t, 0, rm.GetPendingCount())
|
||||
})
|
||||
|
||||
t.Run("context cancellation", func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
// Wait for context to timeout
|
||||
<-ctx.Done()
|
||||
|
||||
// Verify context was cancelled
|
||||
assert.Equal(t, context.DeadlineExceeded, ctx.Err())
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,178 @@
|
||||
package ws
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
"weak"
|
||||
|
||||
"github.com/blang/semver"
|
||||
"github.com/henrygd/beszel"
|
||||
|
||||
"github.com/henrygd/beszel/internal/common"
|
||||
|
||||
"github.com/fxamacker/cbor/v2"
|
||||
"github.com/lxzan/gws"
|
||||
)
|
||||
|
||||
const (
|
||||
deadline = 70 * time.Second
|
||||
)
|
||||
|
||||
// Handler implements the WebSocket event handler for agent connections.
|
||||
type Handler struct {
|
||||
gws.BuiltinEventHandler
|
||||
}
|
||||
|
||||
// WsConn represents a WebSocket connection to an agent.
|
||||
type WsConn struct {
|
||||
conn *gws.Conn
|
||||
requestManager *RequestManager
|
||||
DownChan chan struct{}
|
||||
agentVersion semver.Version
|
||||
}
|
||||
|
||||
// FingerprintRecord is fingerprints collection record data in the hub
|
||||
type FingerprintRecord struct {
|
||||
Id string `db:"id"`
|
||||
SystemId string `db:"system"`
|
||||
Fingerprint string `db:"fingerprint"`
|
||||
Token string `db:"token"`
|
||||
}
|
||||
|
||||
var upgrader *gws.Upgrader
|
||||
|
||||
// GetUpgrader returns a singleton WebSocket upgrader instance.
|
||||
func GetUpgrader() *gws.Upgrader {
|
||||
if upgrader != nil {
|
||||
return upgrader
|
||||
}
|
||||
handler := &Handler{}
|
||||
upgrader = gws.NewUpgrader(handler, &gws.ServerOption{})
|
||||
return upgrader
|
||||
}
|
||||
|
||||
// NewWsConnection creates a new WebSocket connection wrapper with agent version.
|
||||
func NewWsConnection(conn *gws.Conn, agentVersion semver.Version) *WsConn {
|
||||
return &WsConn{
|
||||
conn: conn,
|
||||
requestManager: NewRequestManager(conn),
|
||||
DownChan: make(chan struct{}, 1),
|
||||
agentVersion: agentVersion,
|
||||
}
|
||||
}
|
||||
|
||||
// OnOpen sets a deadline for the WebSocket connection and extracts agent version.
|
||||
func (h *Handler) OnOpen(conn *gws.Conn) {
|
||||
conn.SetDeadline(time.Now().Add(deadline))
|
||||
}
|
||||
|
||||
// OnMessage routes incoming WebSocket messages to the request manager.
|
||||
func (h *Handler) OnMessage(conn *gws.Conn, message *gws.Message) {
|
||||
conn.SetDeadline(time.Now().Add(deadline))
|
||||
if message.Opcode != gws.OpcodeBinary || message.Data.Len() == 0 {
|
||||
return
|
||||
}
|
||||
wsConn, ok := conn.Session().Load("wsConn")
|
||||
if !ok {
|
||||
_ = conn.WriteClose(1000, nil)
|
||||
return
|
||||
}
|
||||
wsConn.(*WsConn).requestManager.handleResponse(message)
|
||||
}
|
||||
|
||||
// OnClose handles WebSocket connection closures and triggers system down status after delay.
|
||||
func (h *Handler) OnClose(conn *gws.Conn, err error) {
|
||||
wsConn, ok := conn.Session().Load("wsConn")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
wsConn.(*WsConn).conn = nil
|
||||
// wait 5 seconds to allow reconnection before setting system down
|
||||
// use a weak pointer to avoid keeping references if the system is removed
|
||||
go func(downChan weak.Pointer[chan struct{}]) {
|
||||
time.Sleep(5 * time.Second)
|
||||
downChanValue := downChan.Value()
|
||||
if downChanValue != nil {
|
||||
*downChanValue <- struct{}{}
|
||||
}
|
||||
}(weak.Make(&wsConn.(*WsConn).DownChan))
|
||||
}
|
||||
|
||||
// Close terminates the WebSocket connection gracefully.
|
||||
func (ws *WsConn) Close(msg []byte) {
|
||||
if ws.IsConnected() {
|
||||
ws.conn.WriteClose(1000, msg)
|
||||
}
|
||||
if ws.requestManager != nil {
|
||||
ws.requestManager.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// Ping sends a ping frame to keep the connection alive.
|
||||
func (ws *WsConn) Ping() error {
|
||||
if ws.conn == nil {
|
||||
return gws.ErrConnClosed
|
||||
}
|
||||
ws.conn.SetDeadline(time.Now().Add(deadline))
|
||||
return ws.conn.WritePing(nil)
|
||||
}
|
||||
|
||||
// sendMessage encodes data to CBOR and sends it as a binary message to the agent.
|
||||
// This is kept for backwards compatibility but new actions should use RequestManager.
|
||||
func (ws *WsConn) sendMessage(data common.HubRequest[any]) error {
|
||||
if ws.conn == nil {
|
||||
return gws.ErrConnClosed
|
||||
}
|
||||
bytes, err := cbor.Marshal(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return ws.conn.WriteMessage(gws.OpcodeBinary, bytes)
|
||||
}
|
||||
|
||||
// handleAgentRequest processes a request to the agent, handling both legacy and new formats.
|
||||
func (ws *WsConn) handleAgentRequest(req *PendingRequest, handler ResponseHandler) error {
|
||||
// Wait for response
|
||||
select {
|
||||
case message := <-req.ResponseCh:
|
||||
defer message.Close()
|
||||
// Cancel request context to stop timeout watcher promptly
|
||||
defer req.Cancel()
|
||||
data := message.Data.Bytes()
|
||||
|
||||
// Legacy format - unmarshal directly
|
||||
if ws.agentVersion.LT(beszel.MinVersionAgentResponse) {
|
||||
return handler.HandleLegacy(data)
|
||||
}
|
||||
|
||||
// New format with AgentResponse wrapper
|
||||
var agentResponse common.AgentResponse
|
||||
if err := cbor.Unmarshal(data, &agentResponse); err != nil {
|
||||
return err
|
||||
}
|
||||
if agentResponse.Error != "" {
|
||||
return errors.New(agentResponse.Error)
|
||||
}
|
||||
return handler.Handle(agentResponse)
|
||||
|
||||
case <-req.Context.Done():
|
||||
return req.Context.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// IsConnected returns true if the WebSocket connection is active.
|
||||
func (ws *WsConn) IsConnected() bool {
|
||||
return ws.conn != nil
|
||||
}
|
||||
|
||||
// AgentVersion returns the connected agent's version (as reported during handshake).
|
||||
func (ws *WsConn) AgentVersion() semver.Version {
|
||||
return ws.agentVersion
|
||||
}
|
||||
|
||||
// SendRequest sends a request to the agent and returns a pending request handle.
|
||||
// This is used by the transport layer to send requests.
|
||||
func (ws *WsConn) SendRequest(ctx context.Context, action common.WebSocketAction, data any) (*PendingRequest, error) {
|
||||
return ws.requestManager.SendRequest(ctx, action, data)
|
||||
}
|
||||
@@ -0,0 +1,231 @@
|
||||
//go:build testing
|
||||
|
||||
package ws
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/blang/semver"
|
||||
"github.com/henrygd/beszel/internal/common"
|
||||
|
||||
"github.com/fxamacker/cbor/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// TestGetUpgrader tests the singleton upgrader
|
||||
func TestGetUpgrader(t *testing.T) {
|
||||
// Reset the global upgrader to test singleton behavior
|
||||
upgrader = nil
|
||||
|
||||
// First call should create the upgrader
|
||||
upgrader1 := GetUpgrader()
|
||||
assert.NotNil(t, upgrader1, "Upgrader should not be nil")
|
||||
|
||||
// Second call should return the same instance
|
||||
upgrader2 := GetUpgrader()
|
||||
assert.Same(t, upgrader1, upgrader2, "Should return the same upgrader instance")
|
||||
|
||||
// Verify it's properly configured
|
||||
assert.NotNil(t, upgrader1, "Upgrader should be configured")
|
||||
}
|
||||
|
||||
// TestNewWsConnection tests WebSocket connection creation
|
||||
func TestNewWsConnection(t *testing.T) {
|
||||
// We can't easily mock gws.Conn, so we'll pass nil and test the structure
|
||||
wsConn := NewWsConnection(nil, semver.MustParse("0.12.10"))
|
||||
|
||||
assert.NotNil(t, wsConn, "WebSocket connection should not be nil")
|
||||
assert.Nil(t, wsConn.conn, "Connection should be nil as passed")
|
||||
assert.NotNil(t, wsConn.requestManager, "Request manager should be initialized")
|
||||
assert.NotNil(t, wsConn.DownChan, "Down channel should be initialized")
|
||||
assert.Equal(t, 1, cap(wsConn.DownChan), "Down channel should have capacity of 1")
|
||||
}
|
||||
|
||||
// TestWsConn_IsConnected tests the connection status check
|
||||
func TestWsConn_IsConnected(t *testing.T) {
|
||||
// Test with nil connection
|
||||
wsConn := NewWsConnection(nil, semver.MustParse("0.12.10"))
|
||||
assert.False(t, wsConn.IsConnected(), "Should not be connected when conn is nil")
|
||||
}
|
||||
|
||||
// TestWsConn_Close tests the connection closing with nil connection
|
||||
func TestWsConn_Close(t *testing.T) {
|
||||
wsConn := NewWsConnection(nil, semver.MustParse("0.12.10"))
|
||||
|
||||
// Should handle nil connection gracefully
|
||||
assert.NotPanics(t, func() {
|
||||
wsConn.Close([]byte("test message"))
|
||||
}, "Should not panic when closing nil connection")
|
||||
}
|
||||
|
||||
// TestWsConn_SendMessage_CBOR tests CBOR encoding in sendMessage
|
||||
func TestWsConn_SendMessage_CBOR(t *testing.T) {
|
||||
wsConn := NewWsConnection(nil, semver.MustParse("0.12.10"))
|
||||
|
||||
testData := common.HubRequest[any]{
|
||||
Action: common.GetData,
|
||||
Data: "test data",
|
||||
}
|
||||
|
||||
// This will fail because conn is nil, but we can test the CBOR encoding logic
|
||||
// by checking that the function properly encodes to CBOR before failing
|
||||
err := wsConn.sendMessage(testData)
|
||||
assert.Error(t, err, "Should error with nil connection")
|
||||
|
||||
// Test CBOR encoding separately
|
||||
bytes, err := cbor.Marshal(testData)
|
||||
assert.NoError(t, err, "Should encode to CBOR successfully")
|
||||
|
||||
// Verify we can decode it back
|
||||
var decodedData common.HubRequest[any]
|
||||
err = cbor.Unmarshal(bytes, &decodedData)
|
||||
assert.NoError(t, err, "Should decode from CBOR successfully")
|
||||
assert.Equal(t, testData.Action, decodedData.Action, "Action should match")
|
||||
}
|
||||
|
||||
// TestWsConn_GetFingerprint_SignatureGeneration tests signature creation logic
|
||||
func TestWsConn_GetFingerprint_SignatureGeneration(t *testing.T) {
|
||||
// Generate test key pair
|
||||
_, privKey, err := ed25519.GenerateKey(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
signer, err := ssh.NewSignerFromKey(privKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
token := "test-token"
|
||||
|
||||
// This will timeout since conn is nil, but we can verify the signature logic
|
||||
// We can't test the full flow, but we can test that the signature is created properly
|
||||
challenge := []byte(token)
|
||||
signature, err := signer.Sign(nil, challenge)
|
||||
assert.NoError(t, err, "Should create signature successfully")
|
||||
assert.NotEmpty(t, signature.Blob, "Signature blob should not be empty")
|
||||
assert.Equal(t, signer.PublicKey().Type(), signature.Format, "Signature format should match key type")
|
||||
|
||||
// Test the fingerprint request structure
|
||||
fpRequest := common.FingerprintRequest{
|
||||
Signature: signature.Blob,
|
||||
NeedSysInfo: true,
|
||||
}
|
||||
|
||||
// Test CBOR encoding of fingerprint request
|
||||
fpData, err := cbor.Marshal(fpRequest)
|
||||
assert.NoError(t, err, "Should encode fingerprint request to CBOR")
|
||||
|
||||
var decodedFpRequest common.FingerprintRequest
|
||||
err = cbor.Unmarshal(fpData, &decodedFpRequest)
|
||||
assert.NoError(t, err, "Should decode fingerprint request from CBOR")
|
||||
assert.Equal(t, fpRequest.Signature, decodedFpRequest.Signature, "Signature should match")
|
||||
assert.Equal(t, fpRequest.NeedSysInfo, decodedFpRequest.NeedSysInfo, "NeedSysInfo should match")
|
||||
|
||||
// Test the full hub request structure
|
||||
hubRequest := common.HubRequest[any]{
|
||||
Action: common.CheckFingerprint,
|
||||
Data: fpRequest,
|
||||
}
|
||||
|
||||
hubData, err := cbor.Marshal(hubRequest)
|
||||
assert.NoError(t, err, "Should encode hub request to CBOR")
|
||||
|
||||
var decodedHubRequest common.HubRequest[cbor.RawMessage]
|
||||
err = cbor.Unmarshal(hubData, &decodedHubRequest)
|
||||
assert.NoError(t, err, "Should decode hub request from CBOR")
|
||||
assert.Equal(t, common.CheckFingerprint, decodedHubRequest.Action, "Action should be CheckFingerprint")
|
||||
}
|
||||
|
||||
// TestWsConn_RequestSystemData_RequestFormat tests system data request format
|
||||
func TestWsConn_RequestSystemData_RequestFormat(t *testing.T) {
|
||||
// Test the request format that would be sent
|
||||
request := common.HubRequest[any]{
|
||||
Action: common.GetData,
|
||||
}
|
||||
|
||||
// Test CBOR encoding
|
||||
data, err := cbor.Marshal(request)
|
||||
assert.NoError(t, err, "Should encode request to CBOR")
|
||||
|
||||
// Test decoding
|
||||
var decodedRequest common.HubRequest[any]
|
||||
err = cbor.Unmarshal(data, &decodedRequest)
|
||||
assert.NoError(t, err, "Should decode request from CBOR")
|
||||
assert.Equal(t, common.GetData, decodedRequest.Action, "Should have GetData action")
|
||||
}
|
||||
|
||||
// TestFingerprintRecord tests the FingerprintRecord struct
|
||||
func TestFingerprintRecord(t *testing.T) {
|
||||
record := FingerprintRecord{
|
||||
Id: "test-id",
|
||||
SystemId: "system-123",
|
||||
Fingerprint: "test-fingerprint",
|
||||
Token: "test-token",
|
||||
}
|
||||
|
||||
assert.Equal(t, "test-id", record.Id)
|
||||
assert.Equal(t, "system-123", record.SystemId)
|
||||
assert.Equal(t, "test-fingerprint", record.Fingerprint)
|
||||
assert.Equal(t, "test-token", record.Token)
|
||||
}
|
||||
|
||||
// TestDeadlineConstant tests that the deadline constant is reasonable
|
||||
func TestDeadlineConstant(t *testing.T) {
|
||||
assert.Equal(t, 70*time.Second, deadline, "Deadline should be 70 seconds")
|
||||
}
|
||||
|
||||
// TestCommonActions tests that the common actions are properly defined
|
||||
func TestCommonActions(t *testing.T) {
|
||||
// Test that the actions we use exist and have expected values
|
||||
assert.Equal(t, common.WebSocketAction(0), common.GetData, "GetData should be action 0")
|
||||
assert.Equal(t, common.WebSocketAction(1), common.CheckFingerprint, "CheckFingerprint should be action 1")
|
||||
assert.Equal(t, common.WebSocketAction(2), common.GetContainerLogs, "GetLogs should be action 2")
|
||||
}
|
||||
|
||||
func TestFingerprintHandler(t *testing.T) {
|
||||
var result common.FingerprintResponse
|
||||
h := &fingerprintHandler{result: &result}
|
||||
|
||||
resp := common.AgentResponse{Fingerprint: &common.FingerprintResponse{
|
||||
Fingerprint: "test-fingerprint",
|
||||
Hostname: "test-host",
|
||||
}}
|
||||
err := h.Handle(resp)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "test-fingerprint", result.Fingerprint)
|
||||
assert.Equal(t, "test-host", result.Hostname)
|
||||
}
|
||||
|
||||
// TestHandler tests that we can create a Handler
|
||||
func TestHandler(t *testing.T) {
|
||||
handler := &Handler{}
|
||||
assert.NotNil(t, handler, "Handler should be created successfully")
|
||||
|
||||
// The Handler embeds gws.BuiltinEventHandler, so it should have the embedded type
|
||||
assert.NotNil(t, handler.BuiltinEventHandler, "Should have embedded BuiltinEventHandler")
|
||||
}
|
||||
|
||||
// TestWsConnChannelBehavior tests channel behavior without WebSocket connections
|
||||
func TestWsConnChannelBehavior(t *testing.T) {
|
||||
wsConn := NewWsConnection(nil, semver.MustParse("0.12.10"))
|
||||
|
||||
// Test that channels are properly initialized and can be used
|
||||
select {
|
||||
case wsConn.DownChan <- struct{}{}:
|
||||
// Should be able to write to channel
|
||||
default:
|
||||
t.Error("Should be able to write to DownChan")
|
||||
}
|
||||
|
||||
// Test reading from DownChan
|
||||
select {
|
||||
case <-wsConn.DownChan:
|
||||
// Should be able to read from channel
|
||||
case <-time.After(10 * time.Millisecond):
|
||||
t.Error("Should be able to read from DownChan")
|
||||
}
|
||||
|
||||
// Request manager should have no pending requests initially
|
||||
assert.Equal(t, 0, wsConn.requestManager.GetPendingCount(), "Should have no pending requests initially")
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
//go:build testing
|
||||
|
||||
package ws
|
||||
|
||||
// GetPendingCount returns the number of pending requests (for monitoring)
|
||||
func (rm *RequestManager) GetPendingCount() int {
|
||||
rm.RLock()
|
||||
defer rm.RUnlock()
|
||||
return len(rm.pendingReqs)
|
||||
}
|
||||
Reference in New Issue
Block a user