mirror of
https://github.com/Dvorinka/Devour.git
synced 2026-06-04 12:33:04 +00:00
342 lines
8.9 KiB
Go
342 lines
8.9 KiB
Go
// Package ai provides AI integration for embeddings and context injection.
|
|
package ai
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"os"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// OpenAIClient implements AI operations using OpenAI API.
|
|
type OpenAIClient struct {
|
|
config *Config
|
|
httpClient *http.Client
|
|
}
|
|
|
|
// NewOpenAIClient creates a new OpenAI client.
|
|
func NewOpenAIClient(config *Config) *OpenAIClient {
|
|
apiKey := config.APIKey
|
|
if apiKey == "" {
|
|
apiKey = os.Getenv("OPENAI_API_KEY")
|
|
}
|
|
|
|
baseURL := config.BaseURL
|
|
if baseURL == "" {
|
|
baseURL = "https://api.openai.com/v1"
|
|
}
|
|
|
|
return &OpenAIClient{
|
|
config: &Config{
|
|
Provider: config.Provider,
|
|
Model: config.Model,
|
|
Dimensions: config.Dimensions,
|
|
APIKey: apiKey,
|
|
BaseURL: baseURL,
|
|
BatchSize: config.BatchSize,
|
|
Temperature: config.Temperature,
|
|
},
|
|
httpClient: &http.Client{Timeout: 60 * time.Second},
|
|
}
|
|
}
|
|
|
|
// EmbeddingRequest represents an embedding API request.
type EmbeddingRequest struct {
	Model      string   `json:"model"`                // embedding model name
	Input      []string `json:"input"`                // texts to embed in this request
	Dimensions int      `json:"dimensions,omitempty"` // omitted from the JSON when zero
}
|
|
|
|
// EmbeddingResponse represents an embedding API response.
|
|
type EmbeddingResponse struct {
|
|
Object string `json:"object"`
|
|
Data []struct {
|
|
Object string `json:"object"`
|
|
Index int `json:"index"`
|
|
Embedding []float32 `json:"embedding"`
|
|
} `json:"data"`
|
|
Model string `json:"model"`
|
|
Usage struct {
|
|
PromptTokens int `json:"prompt_tokens"`
|
|
TotalTokens int `json:"total_tokens"`
|
|
} `json:"usage"`
|
|
Error *APIError `json:"error,omitempty"`
|
|
}
|
|
|
|
// APIError represents an API error.
type APIError struct {
	Message string `json:"message"`
	Type    string `json:"type"`
	Code    string `json:"code"`
}

// Error implements the error interface.
//
// It returns Message when one is present. When Message is empty it falls back
// to Code, then Type, so the error never renders as an empty string. A nil
// receiver yields "<nil>" instead of panicking.
func (e *APIError) Error() string {
	if e == nil {
		return "<nil>"
	}
	if e.Message != "" {
		return e.Message
	}

	switch {
	case e.Code != "":
		return "api error: " + e.Code
	case e.Type != "":
		return "api error: " + e.Type
	}
	return "api error"
}
|
|
|
|
// maxHTTPErrorBodyBytes caps how many bytes of a response body
// formatHTTPStatusError reads when building its error message.
const maxHTTPErrorBodyBytes = 2048
|
|
|
|
// Embed generates embeddings for texts.
|
|
func (c *OpenAIClient) Embed(ctx context.Context, texts []string) ([][]float32, error) {
|
|
if c.config.APIKey == "" {
|
|
return nil, fmt.Errorf("OpenAI API key not configured")
|
|
}
|
|
|
|
model := c.config.Model
|
|
if model == "" {
|
|
model = "text-embedding-3-small"
|
|
}
|
|
|
|
batchSize := c.config.BatchSize
|
|
if batchSize == 0 {
|
|
batchSize = 100
|
|
}
|
|
|
|
var allEmbeddings [][]float32
|
|
|
|
// Process in batches
|
|
for i := 0; i < len(texts); i += batchSize {
|
|
end := i + batchSize
|
|
if end > len(texts) {
|
|
end = len(texts)
|
|
}
|
|
|
|
batch := texts[i:end]
|
|
embeddings, err := c.embedBatch(ctx, model, batch)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
allEmbeddings = append(allEmbeddings, embeddings...)
|
|
}
|
|
|
|
return allEmbeddings, nil
|
|
}
|
|
|
|
// embedBatch processes a single batch of texts.
|
|
func (c *OpenAIClient) embedBatch(ctx context.Context, model string, texts []string) ([][]float32, error) {
|
|
req := EmbeddingRequest{
|
|
Model: model,
|
|
Input: texts,
|
|
}
|
|
|
|
// Set dimensions if specified (for text-embedding-3 models)
|
|
if c.config.Dimensions > 0 && strings.HasPrefix(model, "text-embedding-3") {
|
|
req.Dimensions = c.config.Dimensions
|
|
}
|
|
|
|
body, err := json.Marshal(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
|
}
|
|
|
|
httpReq, err := http.NewRequestWithContext(ctx, "POST", c.config.BaseURL+"/embeddings", strings.NewReader(string(body)))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create request: %w", err)
|
|
}
|
|
|
|
httpReq.Header.Set("Content-Type", "application/json")
|
|
httpReq.Header.Set("Authorization", "Bearer "+c.config.APIKey)
|
|
|
|
resp, err := c.httpClient.Do(httpReq)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("request failed: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
|
return nil, formatHTTPStatusError("embeddings", resp)
|
|
}
|
|
|
|
var embeddingResp EmbeddingResponse
|
|
if err := json.NewDecoder(resp.Body).Decode(&embeddingResp); err != nil {
|
|
return nil, fmt.Errorf("failed to decode response: %w", err)
|
|
}
|
|
|
|
if embeddingResp.Error != nil {
|
|
return nil, embeddingResp.Error
|
|
}
|
|
|
|
// Extract embeddings in order
|
|
embeddings := make([][]float32, len(texts))
|
|
for _, data := range embeddingResp.Data {
|
|
embeddings[data.Index] = data.Embedding
|
|
}
|
|
|
|
return embeddings, nil
|
|
}
|
|
|
|
// ChatRequest represents a chat completion request.
|
|
type ChatRequest struct {
|
|
Model string `json:"model"`
|
|
Messages []ChatMessage `json:"messages"`
|
|
Temperature float64 `json:"temperature,omitempty"`
|
|
MaxTokens int `json:"max_tokens,omitempty"`
|
|
}
|
|
|
|
// ChatMessage represents a chat message.
type ChatMessage struct {
	Role    string `json:"role"`    // e.g. "system" or "user", as used by QueryWithContext
	Content string `json:"content"` // message text
}
|
|
|
|
// ChatResponse represents a chat completion response.
|
|
type ChatResponse struct {
|
|
ID string `json:"id"`
|
|
Object string `json:"object"`
|
|
Created int64 `json:"created"`
|
|
Model string `json:"model"`
|
|
Choices []struct {
|
|
Index int `json:"index"`
|
|
Message ChatMessage `json:"message"`
|
|
Finish string `json:"finish_reason"`
|
|
} `json:"choices"`
|
|
Usage struct {
|
|
PromptTokens int `json:"prompt_tokens"`
|
|
CompletionTokens int `json:"completion_tokens"`
|
|
TotalTokens int `json:"total_tokens"`
|
|
} `json:"usage"`
|
|
Error *APIError `json:"error,omitempty"`
|
|
}
|
|
|
|
// QueryWithContext generates a response with context injection.
|
|
func (c *OpenAIClient) QueryWithContext(ctx context.Context, query string, contextDocs []string) (string, error) {
|
|
if c.config.APIKey == "" {
|
|
return "", fmt.Errorf("OpenAI API key not configured")
|
|
}
|
|
|
|
model := c.config.Model
|
|
if model == "" || strings.Contains(model, "embedding") {
|
|
model = "gpt-4o-mini"
|
|
}
|
|
|
|
// Build context
|
|
contextText := strings.Join(contextDocs, "\n\n---\n\n")
|
|
|
|
// Build messages
|
|
messages := []ChatMessage{
|
|
{
|
|
Role: "system",
|
|
Content: "You are a helpful assistant that answers questions based on the provided context. " +
|
|
"Use the context to provide accurate and relevant answers. " +
|
|
"If the context doesn't contain relevant information, say so.",
|
|
},
|
|
{
|
|
Role: "user",
|
|
Content: fmt.Sprintf("Context:\n%s\n\nQuestion: %s", contextText, query),
|
|
},
|
|
}
|
|
|
|
req := ChatRequest{
|
|
Model: model,
|
|
Messages: messages,
|
|
}
|
|
|
|
if c.config.Temperature > 0 {
|
|
req.Temperature = c.config.Temperature
|
|
}
|
|
|
|
body, err := json.Marshal(req)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to marshal request: %w", err)
|
|
}
|
|
|
|
httpReq, err := http.NewRequestWithContext(ctx, "POST", c.config.BaseURL+"/chat/completions", strings.NewReader(string(body)))
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to create request: %w", err)
|
|
}
|
|
|
|
httpReq.Header.Set("Content-Type", "application/json")
|
|
httpReq.Header.Set("Authorization", "Bearer "+c.config.APIKey)
|
|
|
|
resp, err := c.httpClient.Do(httpReq)
|
|
if err != nil {
|
|
return "", fmt.Errorf("request failed: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
|
return "", formatHTTPStatusError("chat/completions", resp)
|
|
}
|
|
|
|
var chatResp ChatResponse
|
|
if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
|
|
return "", fmt.Errorf("failed to decode response: %w", err)
|
|
}
|
|
|
|
if chatResp.Error != nil {
|
|
return "", chatResp.Error
|
|
}
|
|
|
|
if len(chatResp.Choices) == 0 {
|
|
return "", fmt.Errorf("no response generated")
|
|
}
|
|
|
|
return chatResp.Choices[0].Message.Content, nil
|
|
}
|
|
|
|
// MockClient implements AI operations without external API calls.
type MockClient struct {
	dimensions int // length of every vector produced by Embed
}

// NewMockClient creates a mock client for testing.
//
// A negative dimensions value is clamped to zero so that Embed cannot panic
// while allocating its vectors.
func NewMockClient(dimensions int) *MockClient {
	if dimensions < 0 {
		dimensions = 0
	}
	return &MockClient{dimensions: dimensions}
}

// Embed generates mock embeddings.
//
// It returns one vector of c.dimensions values per input text. The values
// depend only on the text's position and the component index, never on the
// text itself, so results are deterministic. ctx is not consulted and the
// returned error is always nil.
func (c *MockClient) Embed(ctx context.Context, texts []string) ([][]float32, error) {
	dims := c.dimensions
	if dims < 0 {
		// Guards clients built as struct literals rather than via NewMockClient.
		dims = 0
	}

	// The divisor is loop-invariant; with dims == 0 the inner loop never
	// runs, so it is never used as a zero divisor.
	scale := float32(dims * 100)

	embeddings := make([][]float32, len(texts))
	for i := range texts {
		vec := make([]float32, dims)
		for j := range vec {
			vec[j] = float32(i*100+j) / scale
		}
		embeddings[i] = vec
	}
	return embeddings, nil
}

// QueryWithContext returns a mock response.
//
// All arguments are ignored; the reply is always the same fixed string and
// the error is always nil.
func (c *MockClient) QueryWithContext(ctx context.Context, query string, contextDocs []string) (string, error) {
	return "This is a mock response.", nil
}
|
|
|
|
func formatHTTPStatusError(endpoint string, resp *http.Response) error {
|
|
body, err := io.ReadAll(io.LimitReader(resp.Body, maxHTTPErrorBodyBytes))
|
|
if err != nil {
|
|
return fmt.Errorf("openai %s returned status %d (%s) and body read failed: %w", endpoint, resp.StatusCode, http.StatusText(resp.StatusCode), err)
|
|
}
|
|
|
|
return fmt.Errorf(
|
|
"openai %s returned status %d (%s): %s",
|
|
endpoint,
|
|
resp.StatusCode,
|
|
http.StatusText(resp.StatusCode),
|
|
extractHTTPErrorMessage(body),
|
|
)
|
|
}
|
|
|
|
// extractHTTPErrorMessage turns a raw error response body into a short,
// human-readable message.
//
// It returns "<empty body>" when the body is empty or only whitespace. When
// the body is JSON of the form {"error": {"message": "..."}} with a non-blank
// message, the trimmed message is returned. Otherwise the whitespace-trimmed
// body is returned verbatim.
func extractHTTPErrorMessage(body []byte) string {
	trimmed := bytes.TrimSpace(body)
	if len(trimmed) == 0 {
		return "<empty body>"
	}

	// Decode only the message field. Sibling fields with an unexpected JSON
	// type (for example a numeric "code") would otherwise make Unmarshal
	// fail and hide a perfectly usable message.
	var payload struct {
		Error *struct {
			Message string `json:"message"`
		} `json:"error"`
	}
	if err := json.Unmarshal(trimmed, &payload); err == nil && payload.Error != nil {
		if msg := strings.TrimSpace(payload.Error.Message); msg != "" {
			return msg
		}
	}

	return string(trimmed)
}
|