Files
2026-02-24 12:10:13 +01:00

275 lines
7.1 KiB
Go

// Package server provides an MCP server implementation with stdio and HTTP transports.
package server
import (
"bufio"
"context"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"os"
"strings"
"sync"
"time"
)
// Config carries the settings used to construct and run a Server.
type Config struct {
	// Mode selects the implementation: NewServer returns the HTTP server when
	// Mode equals "remote" (case-insensitive) and the stdio server otherwise.
	Mode string `yaml:"mode"`
	// Transport is loaded from the "transport" YAML key.
	// NOTE(review): nothing in this file reads it — confirm who consumes it.
	Transport string `yaml:"transport"`
	// Host is the HTTP listen host; HTTPServer.Start uses "localhost" when empty.
	Host string `yaml:"host"`
	// Port is the HTTP listen port; HTTPServer.Start uses 8080 when zero.
	Port int `yaml:"port"`
	// Handler executes incoming RPC methods. It is excluded from YAML and both
	// server implementations refuse to start when it is nil.
	Handler MethodHandler `yaml:"-"`
}
// MethodHandler is the callback that serves one RPC call. It receives the
// request context, the method name and the still-encoded params, and returns
// either the result payload to send back or an error.
type MethodHandler func(ctx context.Context, method string, params json.RawMessage) (any, error)
// Server is the lifecycle contract shared by the MCP server transports.
type Server interface {
	// Start runs the server and begins accepting requests.
	Start(ctx context.Context) error
	// Stop asks the server to shut down gracefully.
	Stop(ctx context.Context) error
}
// QueryRequest is the JSON payload describing a search query.
type QueryRequest struct {
	// Query is the search text; it is always present in the JSON encoding.
	Query string `json:"query"`
	// Limit is omitted from the JSON encoding when zero.
	Limit int `json:"limit,omitempty"`
	// Threshold is omitted from the JSON encoding when zero.
	Threshold float64 `json:"threshold,omitempty"`
}
// QueryResponse is the JSON payload carrying the results of a search query.
type QueryResponse struct {
	// Query is encoded under the "query" key.
	Query string `json:"query"`
	// Results holds the individual hits.
	Results []Result `json:"results"`
	// Total is encoded under the "total" key.
	Total int `json:"total"`
	// TookMs is encoded under the "took_ms" key.
	TookMs int64 `json:"took_ms"`
}
// Result is one hit inside a QueryResponse.
type Result struct {
	// ID is encoded under the "id" key.
	ID string `json:"id"`
	// DocumentID is encoded under the "document_id" key.
	DocumentID string `json:"document_id"`
	// Content is encoded under the "content" key.
	Content string `json:"content"`
	// Score is encoded under the "score" key.
	Score float64 `json:"score"`
	// Source is encoded under the "source" key.
	Source string `json:"source"`
	// Metadata holds free-form extra attributes; omitted from JSON when empty.
	Metadata map[string]any `json:"metadata,omitempty"`
}
// rpcRequest is the wire format of an incoming JSON-RPC call.
type rpcRequest struct {
	// JSONRPC is the protocol version string sent by the client.
	JSONRPC string `json:"jsonrpc"`
	// ID is the caller-chosen request identifier, echoed back in the response.
	ID any `json:"id"`
	// Method names the operation to invoke.
	Method string `json:"method"`
	// Params is kept undecoded so the MethodHandler can parse it itself.
	Params json.RawMessage `json:"params,omitempty"`
}
// rpcResponse is the wire format of an outgoing JSON-RPC reply.
type rpcResponse struct {
	// JSONRPC is the protocol version string.
	JSONRPC string `json:"jsonrpc"`
	// ID echoes the identifier of the request being answered.
	ID any `json:"id"`
	// Result is the handler's payload; omitted from JSON when unset.
	Result any `json:"result,omitempty"`
	// Error describes a failure; omitted from JSON when nil.
	Error *rpcError `json:"error,omitempty"`
}
// rpcError is the error object embedded in an rpcResponse.
type rpcError struct {
	// Code is the numeric error code (this file uses -32700, -32600 and -32000).
	Code int `json:"code"`
	// Message is the human-readable description of the failure.
	Message string `json:"message"`
}
// NewServer creates a new MCP server for the given configuration.
//
// It returns the HTTP server when config.Mode equals "remote"
// (case-insensitive) and the stdio server in every other case. A nil config is
// tolerated here instead of panicking: it yields a stdio server whose Start
// reports that the config is required.
func NewServer(config *Config) Server {
	if config != nil && strings.EqualFold(config.Mode, "remote") {
		return NewHTTPServer(config)
	}
	return NewStdioServer(config)
}
// NewHTTPServer returns an HTTPServer bound to config. The server does not
// listen until Start is called.
func NewHTTPServer(config *Config) *HTTPServer {
	srv := new(HTTPServer)
	srv.config = config
	return srv
}
// NewStdioServer returns a StdioServer bound to config. Nothing is read from
// stdin until Start is called.
func NewStdioServer(config *Config) *StdioServer {
	srv := new(StdioServer)
	srv.config = config
	return srv
}
// HTTPServer is the Server implementation that exposes RPC over HTTP.
type HTTPServer struct {
	// config is the configuration supplied to NewHTTPServer.
	config *Config
	// http is the underlying net/http server; it stays nil until Start creates it.
	http *http.Server
	// mu guards access to http from Start and Stop.
	mu sync.Mutex
}
// Start serves the RPC endpoints over HTTP and blocks until the server exits.
//
// It fails immediately when the config or its Handler is missing. Otherwise it
// listens on Host:Port (defaults: "localhost" and 8080) and serves /health and
// the POST-only /rpc route.
//
// Start returns nil when the server is closed through Stop, a wrapped
// ctx.Err() when ctx is canceled (after a shutdown attempt bounded to five
// seconds), and a wrapped listener error otherwise.
//
// The mutex is held only while the server is being set up, never while
// serving, so that a concurrent Stop can acquire it and shut the server down.
func (s *HTTPServer) Start(ctx context.Context) error {
	srv, err := s.buildHTTPServer()
	if err != nil {
		return err
	}

	// Buffered so the serving goroutine can always deliver its result and
	// exit, even when Start has already returned via the ctx.Done branch.
	errCh := make(chan error, 1)
	go func() {
		errCh <- srv.ListenAndServe()
	}()

	select {
	case <-ctx.Done():
		shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		// Best-effort drain: the cancellation cause is what the caller is told.
		_ = srv.Shutdown(shutdownCtx)
		return wrapTransportError("http", "server context canceled", ctx.Err())
	case serveErr := <-errCh:
		if serveErr != nil && serveErr != http.ErrServerClosed {
			return wrapTransportError("http", "listen and serve", serveErr)
		}
		return nil
	}
}

// buildHTTPServer validates the config, wires the routes and stores the new
// http.Server in s.http under the mutex.
func (s *HTTPServer) buildHTTPServer() (*http.Server, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	if s.config == nil {
		return nil, fmt.Errorf("server config is required")
	}
	if s.config.Handler == nil {
		return nil, fmt.Errorf("server handler is required")
	}

	mux := http.NewServeMux()
	mux.HandleFunc("/health", func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		// A failed write means the client went away; there is nothing to recover.
		_, _ = io.WriteString(w, `{"ok":true}`)
	})
	mux.HandleFunc("/rpc", s.serveRPC)

	host := s.config.Host
	if host == "" {
		host = "localhost"
	}
	port := s.config.Port
	if port == 0 {
		port = 8080
	}

	s.http = &http.Server{
		Addr:    fmt.Sprintf("%s:%d", host, port),
		Handler: mux,
		// Bound header reads so idle or slow clients cannot pin connections forever.
		ReadHeaderTimeout: 10 * time.Second,
	}
	return s.http, nil
}

// serveRPC handles one /rpc request: it decodes a single JSON-RPC request
// (reading at most 2 MiB of body), dispatches it and writes the response.
func (s *HTTPServer) serveRPC(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		w.WriteHeader(http.StatusMethodNotAllowed)
		return
	}
	defer r.Body.Close()

	var req rpcRequest
	if err := json.NewDecoder(io.LimitReader(r.Body, 2<<20)).Decode(&req); err != nil {
		parseFailure := rpcResponse{JSONRPC: "2.0", Error: &rpcError{Code: -32700, Message: "parse error"}}
		writeRPCOrReport(w, "encode parse-error response", parseFailure)
		return
	}
	writeRPCOrReport(w, "encode rpc response", s.handleRPC(r.Context(), req))
}

// writeRPCOrReport writes resp and, when that fails, logs the wrapped error and
// answers with a 500 carrying the same message.
func writeRPCOrReport(w http.ResponseWriter, operation string, resp rpcResponse) {
	if err := writeRPC(w, resp); err != nil {
		wrapped := wrapTransportError("http", operation, err)
		log.Printf("%v", wrapped)
		http.Error(w, wrapped.Error(), http.StatusInternalServerError)
	}
}
// Stop shuts down the underlying HTTP server, using ctx to bound the shutdown.
// It is a no-op returning nil when Start has not created a server yet.
func (s *HTTPServer) Stop(ctx context.Context) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	if srv := s.http; srv != nil {
		return srv.Shutdown(ctx)
	}
	return nil
}
// handleRPC dispatches req to the configured Handler through the shared
// package-level handleRPC function.
func (s *HTTPServer) handleRPC(ctx context.Context, req rpcRequest) rpcResponse {
	handler := s.config.Handler
	return handleRPC(ctx, handler, req)
}
// StdioServer is the Server implementation that speaks line-delimited JSON-RPC
// over stdin and stdout.
type StdioServer struct {
	// config is the configuration supplied to NewStdioServer.
	config *Config
	// mu guards stop.
	mu sync.Mutex
	// stop is set by Stop to ask the read loop in Start to exit.
	stop bool
}
// Start reads newline-delimited JSON-RPC requests from stdin, dispatches each
// one to the configured Handler and writes one JSON response per line to
// stdout. It blocks until stdin reaches EOF, ctx is canceled, Stop is called,
// or a response cannot be written.
//
// It fails immediately when the config or its Handler is missing. Blank lines
// are skipped and undecodable lines are answered with a -32700 parse error.
//
// The mutex is only taken briefly to read the stop flag; it is not held for the
// duration of the loop, so Stop can set the flag while Start is running.
// Cancellation and Stop are observed after the pending read returns, because
// reading from stdin blocks.
func (s *StdioServer) Start(ctx context.Context) error {
	if s.config == nil {
		return fmt.Errorf("server config is required")
	}
	if s.config.Handler == nil {
		return fmt.Errorf("server handler is required")
	}

	// Same 2 MiB request cap as the HTTP transport; the scanner's default
	// 64 KiB token limit would abort the server on larger requests.
	const maxLineBytes = 2 << 20
	scanner := bufio.NewScanner(os.Stdin)
	scanner.Buffer(make([]byte, 0, 64*1024), maxLineBytes)
	out := json.NewEncoder(os.Stdout)

	for scanner.Scan() {
		if ctx.Err() != nil || s.stopRequested() {
			break
		}
		line := strings.TrimSpace(scanner.Text())
		if line == "" {
			continue
		}

		var req rpcRequest
		if err := json.Unmarshal([]byte(line), &req); err != nil {
			parseFailure := rpcResponse{JSONRPC: "2.0", Error: &rpcError{Code: -32700, Message: "parse error"}}
			if encodeErr := out.Encode(parseFailure); encodeErr != nil {
				return wrapTransportError("stdio", "encode parse-error response", encodeErr)
			}
			continue
		}

		if err := out.Encode(handleRPC(ctx, s.config.Handler, req)); err != nil {
			return wrapTransportError("stdio", "encode rpc response", err)
		}
	}
	if err := scanner.Err(); err != nil {
		return wrapTransportError("stdio", "scan stdin", err)
	}
	return nil
}

// stopRequested reports whether Stop has been called.
func (s *StdioServer) stopRequested() bool {
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.stop
}
// Stop sets the stop flag under the mutex so that the read loop in Start can
// exit. It always returns nil; ctx is not used.
func (s *StdioServer) Stop(ctx context.Context) error {
	s.mu.Lock()
	s.stop = true
	s.mu.Unlock()
	return nil
}
// handleRPC turns one decoded request into a response. Every response carries
// version "2.0" and echoes the request ID. A request without a method yields a
// -32600 "invalid request" error; a handler failure yields a -32000 error whose
// message is the handler's error text; otherwise the handler's result is
// returned.
func handleRPC(ctx context.Context, handler MethodHandler, req rpcRequest) rpcResponse {
	resp := rpcResponse{JSONRPC: "2.0", ID: req.ID}

	switch {
	case req.Method == "":
		resp.Error = &rpcError{Code: -32600, Message: "invalid request"}
	default:
		result, err := handler(ctx, req.Method, req.Params)
		if err != nil {
			resp.Error = &rpcError{Code: -32000, Message: err.Error()}
		} else {
			resp.Result = result
		}
	}
	return resp
}
// writeRPC serializes payload as JSON onto w. The Content-Type is set to
// application/json, and a payload that carries an Error is sent with status
// 400 Bad Request. An encoding failure is returned wrapped.
func writeRPC(w http.ResponseWriter, payload rpcResponse) error {
	w.Header().Set("Content-Type", "application/json")
	if payload.Error != nil {
		w.WriteHeader(http.StatusBadRequest)
	}

	encoder := json.NewEncoder(w)
	err := encoder.Encode(payload)
	if err == nil {
		return nil
	}
	return fmt.Errorf("encode rpc response: %w", err)
}
// wrapTransportError wraps err with the transport name and the operation that
// failed, in the form "<transport> rpc <operation> failed: <err>". The wrapped
// error stays reachable through errors.Is and errors.As.
func wrapTransportError(transport, operation string, err error) error {
	const layout = "%s rpc %s failed: %w"
	return fmt.Errorf(layout, transport, operation, err)
}