Initial commit: Beszel fork with Domain Locker integration

This commit is contained in:
Tomas Dvorak
2026-04-21 15:39:43 +02:00
commit 363d708e91
440 changed files with 160889 additions and 0 deletions
+72
View File
@@ -0,0 +1,72 @@
package ws
import (
"context"
"errors"
"github.com/fxamacker/cbor/v2"
"github.com/henrygd/beszel/internal/common"
"github.com/lxzan/gws"
"golang.org/x/crypto/ssh"
)
// ResponseHandler defines interface for handling agent responses.
// This is used by handleAgentRequest for legacy response handling.
type ResponseHandler interface {
Handle(agentResponse common.AgentResponse) error
HandleLegacy(rawData []byte) error
}
// BaseHandler provides a default implementation that can be embedded to make HandleLegacy optional
type BaseHandler struct{}
func (h *BaseHandler) HandleLegacy(rawData []byte) error {
return errors.New("legacy format not supported")
}
////////////////////////////////////////////////////////////////////////////
// Fingerprint handling (used for WebSocket authentication)
////////////////////////////////////////////////////////////////////////////
// fingerprintHandler implements ResponseHandler for fingerprint requests
type fingerprintHandler struct {
result *common.FingerprintResponse
}
func (h *fingerprintHandler) HandleLegacy(rawData []byte) error {
return cbor.Unmarshal(rawData, h.result)
}
func (h *fingerprintHandler) Handle(agentResponse common.AgentResponse) error {
if agentResponse.Fingerprint != nil {
*h.result = *agentResponse.Fingerprint
return nil
}
return errors.New("no fingerprint data in response")
}
// GetFingerprint authenticates with the agent using SSH signature and returns the agent's fingerprint.
func (ws *WsConn) GetFingerprint(ctx context.Context, token string, signer ssh.Signer, needSysInfo bool) (common.FingerprintResponse, error) {
if !ws.IsConnected() {
return common.FingerprintResponse{}, gws.ErrConnClosed
}
challenge := []byte(token)
signature, err := signer.Sign(nil, challenge)
if err != nil {
return common.FingerprintResponse{}, err
}
req, err := ws.requestManager.SendRequest(ctx, common.CheckFingerprint, common.FingerprintRequest{
Signature: signature.Blob,
NeedSysInfo: needSysInfo,
})
if err != nil {
return common.FingerprintResponse{}, err
}
var result common.FingerprintResponse
handler := &fingerprintHandler{result: &result}
err = ws.handleAgentRequest(req, handler)
return result, err
}
+199
View File
@@ -0,0 +1,199 @@
package ws
import (
"context"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/fxamacker/cbor/v2"
"github.com/henrygd/beszel/internal/common"
"github.com/lxzan/gws"
)
// RequestID uniquely identifies a request
type RequestID uint32
// PendingRequest tracks an in-flight request
type PendingRequest struct {
ID RequestID
ResponseCh chan *gws.Message
Context context.Context
Cancel context.CancelFunc
CreatedAt time.Time
}
// RequestManager handles concurrent requests to an agent
type RequestManager struct {
sync.RWMutex
conn *gws.Conn
pendingReqs map[RequestID]*PendingRequest
nextID atomic.Uint32
}
// NewRequestManager creates a new request manager for a WebSocket connection
func NewRequestManager(conn *gws.Conn) *RequestManager {
rm := &RequestManager{
conn: conn,
pendingReqs: make(map[RequestID]*PendingRequest),
}
return rm
}
// SendRequest sends a request and returns a channel for the response
func (rm *RequestManager) SendRequest(ctx context.Context, action common.WebSocketAction, data any) (*PendingRequest, error) {
reqID := RequestID(rm.nextID.Add(1))
// Respect any caller-provided deadline. If none is set, apply a reasonable default
// so pending requests don't live forever if the agent never responds.
reqCtx := ctx
var cancel context.CancelFunc
if _, hasDeadline := ctx.Deadline(); hasDeadline {
reqCtx, cancel = context.WithCancel(ctx)
} else {
reqCtx, cancel = context.WithTimeout(ctx, 5*time.Second)
}
req := &PendingRequest{
ID: reqID,
ResponseCh: make(chan *gws.Message, 1),
Context: reqCtx,
Cancel: cancel,
CreatedAt: time.Now(),
}
rm.Lock()
rm.pendingReqs[reqID] = req
rm.Unlock()
hubReq := common.HubRequest[any]{
Id: (*uint32)(&reqID),
Action: action,
Data: data,
}
// Send the request
if err := rm.sendMessage(hubReq); err != nil {
rm.cancelRequest(reqID)
return nil, fmt.Errorf("failed to send request: %w", err)
}
// Start cleanup watcher for timeout/cancellation
go rm.cleanupRequest(req)
return req, nil
}
// sendMessage encodes and sends a message over WebSocket
func (rm *RequestManager) sendMessage(data any) error {
if rm.conn == nil {
return gws.ErrConnClosed
}
bytes, err := cbor.Marshal(data)
if err != nil {
return fmt.Errorf("failed to marshal request: %w", err)
}
return rm.conn.WriteMessage(gws.OpcodeBinary, bytes)
}
// handleResponse processes a single response message
func (rm *RequestManager) handleResponse(message *gws.Message) {
var response common.AgentResponse
if err := cbor.Unmarshal(message.Data.Bytes(), &response); err != nil {
// Legacy response without ID - route to first pending request of any type
rm.routeLegacyResponse(message)
return
}
if response.Id == nil {
rm.routeLegacyResponse(message)
return
}
reqID := RequestID(*response.Id)
rm.RLock()
req, exists := rm.pendingReqs[reqID]
rm.RUnlock()
if !exists {
// Request not found (might have timed out) - close the message
message.Close()
return
}
select {
case req.ResponseCh <- message:
// Message successfully delivered - the receiver will close it
rm.deleteRequest(reqID)
case <-req.Context.Done():
// Request was cancelled/timed out - close the message
message.Close()
}
}
// routeLegacyResponse handles responses that don't have request IDs (backwards compatibility)
func (rm *RequestManager) routeLegacyResponse(message *gws.Message) {
// Snapshot the oldest pending request without holding the lock during send
rm.RLock()
var oldestReq *PendingRequest
for _, req := range rm.pendingReqs {
if oldestReq == nil || req.CreatedAt.Before(oldestReq.CreatedAt) {
oldestReq = req
}
}
rm.RUnlock()
if oldestReq != nil {
select {
case oldestReq.ResponseCh <- message:
// Message successfully delivered - the receiver will close it
rm.deleteRequest(oldestReq.ID)
case <-oldestReq.Context.Done():
// Request was cancelled - close the message
message.Close()
}
} else {
// No pending requests - close the message
message.Close()
}
}
// cleanupRequest handles request timeout and cleanup
func (rm *RequestManager) cleanupRequest(req *PendingRequest) {
<-req.Context.Done()
rm.cancelRequest(req.ID)
}
// cancelRequest removes a request and cancels its context
func (rm *RequestManager) cancelRequest(reqID RequestID) {
rm.Lock()
defer rm.Unlock()
if req, exists := rm.pendingReqs[reqID]; exists {
req.Cancel()
delete(rm.pendingReqs, reqID)
}
}
// deleteRequest removes a request from the pending map without cancelling its context.
func (rm *RequestManager) deleteRequest(reqID RequestID) {
rm.Lock()
defer rm.Unlock()
delete(rm.pendingReqs, reqID)
}
// Close shuts down the request manager
func (rm *RequestManager) Close() {
rm.Lock()
defer rm.Unlock()
// Cancel all pending requests
for _, req := range rm.pendingReqs {
req.Cancel()
}
rm.pendingReqs = make(map[RequestID]*PendingRequest)
}
+80
View File
@@ -0,0 +1,80 @@
//go:build testing
package ws
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
// TestRequestManager_BasicFunctionality tests the request manager without mocking gws.Conn
func TestRequestManager_BasicFunctionality(t *testing.T) {
// We'll test the core logic without mocking the connection
// since the gws.Conn interface is complex to mock properly
t.Run("request ID generation", func(t *testing.T) {
// Test that request IDs are generated sequentially and uniquely
rm := &RequestManager{}
// Simulate multiple ID generations
id1 := rm.nextID.Add(1)
id2 := rm.nextID.Add(1)
id3 := rm.nextID.Add(1)
assert.NotEqual(t, id1, id2)
assert.NotEqual(t, id2, id3)
assert.Greater(t, id2, id1)
assert.Greater(t, id3, id2)
})
t.Run("pending request tracking", func(t *testing.T) {
rm := &RequestManager{
pendingReqs: make(map[RequestID]*PendingRequest),
}
// Initially no pending requests
assert.Equal(t, 0, rm.GetPendingCount())
// Add some fake pending requests
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
req1 := &PendingRequest{
ID: RequestID(1),
Context: ctx,
Cancel: cancel,
}
req2 := &PendingRequest{
ID: RequestID(2),
Context: ctx,
Cancel: cancel,
}
rm.pendingReqs[req1.ID] = req1
rm.pendingReqs[req2.ID] = req2
assert.Equal(t, 2, rm.GetPendingCount())
// Remove one
delete(rm.pendingReqs, req1.ID)
assert.Equal(t, 1, rm.GetPendingCount())
// Remove all
delete(rm.pendingReqs, req2.ID)
assert.Equal(t, 0, rm.GetPendingCount())
})
t.Run("context cancellation", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
defer cancel()
// Wait for context to timeout
<-ctx.Done()
// Verify context was cancelled
assert.Equal(t, context.DeadlineExceeded, ctx.Err())
})
}
+178
View File
@@ -0,0 +1,178 @@
package ws
import (
"context"
"errors"
"time"
"weak"
"github.com/blang/semver"
"github.com/henrygd/beszel"
"github.com/henrygd/beszel/internal/common"
"github.com/fxamacker/cbor/v2"
"github.com/lxzan/gws"
)
const (
deadline = 70 * time.Second
)
// Handler implements the WebSocket event handler for agent connections.
type Handler struct {
gws.BuiltinEventHandler
}
// WsConn represents a WebSocket connection to an agent.
type WsConn struct {
conn *gws.Conn
requestManager *RequestManager
DownChan chan struct{}
agentVersion semver.Version
}
// FingerprintRecord is fingerprints collection record data in the hub
type FingerprintRecord struct {
Id string `db:"id"`
SystemId string `db:"system"`
Fingerprint string `db:"fingerprint"`
Token string `db:"token"`
}
var upgrader *gws.Upgrader
// GetUpgrader returns a singleton WebSocket upgrader instance.
func GetUpgrader() *gws.Upgrader {
if upgrader != nil {
return upgrader
}
handler := &Handler{}
upgrader = gws.NewUpgrader(handler, &gws.ServerOption{})
return upgrader
}
// NewWsConnection creates a new WebSocket connection wrapper with agent version.
func NewWsConnection(conn *gws.Conn, agentVersion semver.Version) *WsConn {
return &WsConn{
conn: conn,
requestManager: NewRequestManager(conn),
DownChan: make(chan struct{}, 1),
agentVersion: agentVersion,
}
}
// OnOpen sets a deadline for the WebSocket connection and extracts agent version.
func (h *Handler) OnOpen(conn *gws.Conn) {
conn.SetDeadline(time.Now().Add(deadline))
}
// OnMessage routes incoming WebSocket messages to the request manager.
func (h *Handler) OnMessage(conn *gws.Conn, message *gws.Message) {
conn.SetDeadline(time.Now().Add(deadline))
if message.Opcode != gws.OpcodeBinary || message.Data.Len() == 0 {
return
}
wsConn, ok := conn.Session().Load("wsConn")
if !ok {
_ = conn.WriteClose(1000, nil)
return
}
wsConn.(*WsConn).requestManager.handleResponse(message)
}
// OnClose handles WebSocket connection closures and triggers system down status after delay.
func (h *Handler) OnClose(conn *gws.Conn, err error) {
wsConn, ok := conn.Session().Load("wsConn")
if !ok {
return
}
wsConn.(*WsConn).conn = nil
// wait 5 seconds to allow reconnection before setting system down
// use a weak pointer to avoid keeping references if the system is removed
go func(downChan weak.Pointer[chan struct{}]) {
time.Sleep(5 * time.Second)
downChanValue := downChan.Value()
if downChanValue != nil {
*downChanValue <- struct{}{}
}
}(weak.Make(&wsConn.(*WsConn).DownChan))
}
// Close terminates the WebSocket connection gracefully.
func (ws *WsConn) Close(msg []byte) {
if ws.IsConnected() {
ws.conn.WriteClose(1000, msg)
}
if ws.requestManager != nil {
ws.requestManager.Close()
}
}
// Ping sends a ping frame to keep the connection alive.
func (ws *WsConn) Ping() error {
if ws.conn == nil {
return gws.ErrConnClosed
}
ws.conn.SetDeadline(time.Now().Add(deadline))
return ws.conn.WritePing(nil)
}
// sendMessage encodes data to CBOR and sends it as a binary message to the agent.
// This is kept for backwards compatibility but new actions should use RequestManager.
func (ws *WsConn) sendMessage(data common.HubRequest[any]) error {
if ws.conn == nil {
return gws.ErrConnClosed
}
bytes, err := cbor.Marshal(data)
if err != nil {
return err
}
return ws.conn.WriteMessage(gws.OpcodeBinary, bytes)
}
// handleAgentRequest processes a request to the agent, handling both legacy and new formats.
func (ws *WsConn) handleAgentRequest(req *PendingRequest, handler ResponseHandler) error {
// Wait for response
select {
case message := <-req.ResponseCh:
defer message.Close()
// Cancel request context to stop timeout watcher promptly
defer req.Cancel()
data := message.Data.Bytes()
// Legacy format - unmarshal directly
if ws.agentVersion.LT(beszel.MinVersionAgentResponse) {
return handler.HandleLegacy(data)
}
// New format with AgentResponse wrapper
var agentResponse common.AgentResponse
if err := cbor.Unmarshal(data, &agentResponse); err != nil {
return err
}
if agentResponse.Error != "" {
return errors.New(agentResponse.Error)
}
return handler.Handle(agentResponse)
case <-req.Context.Done():
return req.Context.Err()
}
}
// IsConnected returns true if the WebSocket connection is active.
func (ws *WsConn) IsConnected() bool {
return ws.conn != nil
}
// AgentVersion returns the connected agent's version (as reported during handshake).
func (ws *WsConn) AgentVersion() semver.Version {
return ws.agentVersion
}
// SendRequest sends a request to the agent and returns a pending request handle.
// This is used by the transport layer to send requests.
func (ws *WsConn) SendRequest(ctx context.Context, action common.WebSocketAction, data any) (*PendingRequest, error) {
return ws.requestManager.SendRequest(ctx, action, data)
}
+231
View File
@@ -0,0 +1,231 @@
//go:build testing
package ws
import (
"crypto/ed25519"
"testing"
"time"
"github.com/blang/semver"
"github.com/henrygd/beszel/internal/common"
"github.com/fxamacker/cbor/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"
)
// TestGetUpgrader tests the singleton upgrader
func TestGetUpgrader(t *testing.T) {
// Reset the global upgrader to test singleton behavior
upgrader = nil
// First call should create the upgrader
upgrader1 := GetUpgrader()
assert.NotNil(t, upgrader1, "Upgrader should not be nil")
// Second call should return the same instance
upgrader2 := GetUpgrader()
assert.Same(t, upgrader1, upgrader2, "Should return the same upgrader instance")
// Verify it's properly configured
assert.NotNil(t, upgrader1, "Upgrader should be configured")
}
// TestNewWsConnection tests WebSocket connection creation
func TestNewWsConnection(t *testing.T) {
// We can't easily mock gws.Conn, so we'll pass nil and test the structure
wsConn := NewWsConnection(nil, semver.MustParse("0.12.10"))
assert.NotNil(t, wsConn, "WebSocket connection should not be nil")
assert.Nil(t, wsConn.conn, "Connection should be nil as passed")
assert.NotNil(t, wsConn.requestManager, "Request manager should be initialized")
assert.NotNil(t, wsConn.DownChan, "Down channel should be initialized")
assert.Equal(t, 1, cap(wsConn.DownChan), "Down channel should have capacity of 1")
}
// TestWsConn_IsConnected tests the connection status check
func TestWsConn_IsConnected(t *testing.T) {
// Test with nil connection
wsConn := NewWsConnection(nil, semver.MustParse("0.12.10"))
assert.False(t, wsConn.IsConnected(), "Should not be connected when conn is nil")
}
// TestWsConn_Close tests the connection closing with nil connection
func TestWsConn_Close(t *testing.T) {
wsConn := NewWsConnection(nil, semver.MustParse("0.12.10"))
// Should handle nil connection gracefully
assert.NotPanics(t, func() {
wsConn.Close([]byte("test message"))
}, "Should not panic when closing nil connection")
}
// TestWsConn_SendMessage_CBOR tests CBOR encoding in sendMessage
func TestWsConn_SendMessage_CBOR(t *testing.T) {
wsConn := NewWsConnection(nil, semver.MustParse("0.12.10"))
testData := common.HubRequest[any]{
Action: common.GetData,
Data: "test data",
}
// This will fail because conn is nil, but we can test the CBOR encoding logic
// by checking that the function properly encodes to CBOR before failing
err := wsConn.sendMessage(testData)
assert.Error(t, err, "Should error with nil connection")
// Test CBOR encoding separately
bytes, err := cbor.Marshal(testData)
assert.NoError(t, err, "Should encode to CBOR successfully")
// Verify we can decode it back
var decodedData common.HubRequest[any]
err = cbor.Unmarshal(bytes, &decodedData)
assert.NoError(t, err, "Should decode from CBOR successfully")
assert.Equal(t, testData.Action, decodedData.Action, "Action should match")
}
// TestWsConn_GetFingerprint_SignatureGeneration tests signature creation logic
func TestWsConn_GetFingerprint_SignatureGeneration(t *testing.T) {
// Generate test key pair
_, privKey, err := ed25519.GenerateKey(nil)
require.NoError(t, err)
signer, err := ssh.NewSignerFromKey(privKey)
require.NoError(t, err)
token := "test-token"
// This will timeout since conn is nil, but we can verify the signature logic
// We can't test the full flow, but we can test that the signature is created properly
challenge := []byte(token)
signature, err := signer.Sign(nil, challenge)
assert.NoError(t, err, "Should create signature successfully")
assert.NotEmpty(t, signature.Blob, "Signature blob should not be empty")
assert.Equal(t, signer.PublicKey().Type(), signature.Format, "Signature format should match key type")
// Test the fingerprint request structure
fpRequest := common.FingerprintRequest{
Signature: signature.Blob,
NeedSysInfo: true,
}
// Test CBOR encoding of fingerprint request
fpData, err := cbor.Marshal(fpRequest)
assert.NoError(t, err, "Should encode fingerprint request to CBOR")
var decodedFpRequest common.FingerprintRequest
err = cbor.Unmarshal(fpData, &decodedFpRequest)
assert.NoError(t, err, "Should decode fingerprint request from CBOR")
assert.Equal(t, fpRequest.Signature, decodedFpRequest.Signature, "Signature should match")
assert.Equal(t, fpRequest.NeedSysInfo, decodedFpRequest.NeedSysInfo, "NeedSysInfo should match")
// Test the full hub request structure
hubRequest := common.HubRequest[any]{
Action: common.CheckFingerprint,
Data: fpRequest,
}
hubData, err := cbor.Marshal(hubRequest)
assert.NoError(t, err, "Should encode hub request to CBOR")
var decodedHubRequest common.HubRequest[cbor.RawMessage]
err = cbor.Unmarshal(hubData, &decodedHubRequest)
assert.NoError(t, err, "Should decode hub request from CBOR")
assert.Equal(t, common.CheckFingerprint, decodedHubRequest.Action, "Action should be CheckFingerprint")
}
// TestWsConn_RequestSystemData_RequestFormat tests system data request format
func TestWsConn_RequestSystemData_RequestFormat(t *testing.T) {
// Test the request format that would be sent
request := common.HubRequest[any]{
Action: common.GetData,
}
// Test CBOR encoding
data, err := cbor.Marshal(request)
assert.NoError(t, err, "Should encode request to CBOR")
// Test decoding
var decodedRequest common.HubRequest[any]
err = cbor.Unmarshal(data, &decodedRequest)
assert.NoError(t, err, "Should decode request from CBOR")
assert.Equal(t, common.GetData, decodedRequest.Action, "Should have GetData action")
}
// TestFingerprintRecord tests the FingerprintRecord struct
func TestFingerprintRecord(t *testing.T) {
record := FingerprintRecord{
Id: "test-id",
SystemId: "system-123",
Fingerprint: "test-fingerprint",
Token: "test-token",
}
assert.Equal(t, "test-id", record.Id)
assert.Equal(t, "system-123", record.SystemId)
assert.Equal(t, "test-fingerprint", record.Fingerprint)
assert.Equal(t, "test-token", record.Token)
}
// TestDeadlineConstant tests that the deadline constant is reasonable
func TestDeadlineConstant(t *testing.T) {
assert.Equal(t, 70*time.Second, deadline, "Deadline should be 70 seconds")
}
// TestCommonActions tests that the common actions are properly defined
func TestCommonActions(t *testing.T) {
// Test that the actions we use exist and have expected values
assert.Equal(t, common.WebSocketAction(0), common.GetData, "GetData should be action 0")
assert.Equal(t, common.WebSocketAction(1), common.CheckFingerprint, "CheckFingerprint should be action 1")
assert.Equal(t, common.WebSocketAction(2), common.GetContainerLogs, "GetLogs should be action 2")
}
func TestFingerprintHandler(t *testing.T) {
var result common.FingerprintResponse
h := &fingerprintHandler{result: &result}
resp := common.AgentResponse{Fingerprint: &common.FingerprintResponse{
Fingerprint: "test-fingerprint",
Hostname: "test-host",
}}
err := h.Handle(resp)
assert.NoError(t, err)
assert.Equal(t, "test-fingerprint", result.Fingerprint)
assert.Equal(t, "test-host", result.Hostname)
}
// TestHandler tests that we can create a Handler
func TestHandler(t *testing.T) {
handler := &Handler{}
assert.NotNil(t, handler, "Handler should be created successfully")
// The Handler embeds gws.BuiltinEventHandler, so it should have the embedded type
assert.NotNil(t, handler.BuiltinEventHandler, "Should have embedded BuiltinEventHandler")
}
// TestWsConnChannelBehavior tests channel behavior without WebSocket connections
func TestWsConnChannelBehavior(t *testing.T) {
wsConn := NewWsConnection(nil, semver.MustParse("0.12.10"))
// Test that channels are properly initialized and can be used
select {
case wsConn.DownChan <- struct{}{}:
// Should be able to write to channel
default:
t.Error("Should be able to write to DownChan")
}
// Test reading from DownChan
select {
case <-wsConn.DownChan:
// Should be able to read from channel
case <-time.After(10 * time.Millisecond):
t.Error("Should be able to read from DownChan")
}
// Request manager should have no pending requests initially
assert.Equal(t, 0, wsConn.requestManager.GetPendingCount(), "Should have no pending requests initially")
}
+10
View File
@@ -0,0 +1,10 @@
//go:build testing
package ws
// GetPendingCount returns the number of pending requests (for monitoring)
func (rm *RequestManager) GetPendingCount() int {
rm.RLock()
defer rm.RUnlock()
return len(rm.pendingReqs)
}