mirror of
https://github.com/Dvorinka/Devour.git
synced 2026-06-04 04:23:02 +00:00
first commit
This commit is contained in:
@@ -0,0 +1,38 @@
|
||||
// Package ai provides AI integration for embeddings and context injection.
|
||||
package ai
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
// Config holds AI configuration.
|
||||
type Config struct {
|
||||
Provider string `yaml:"provider"`
|
||||
Model string `yaml:"model"`
|
||||
Dimensions int `yaml:"dimensions"`
|
||||
APIKey string `yaml:"api_key"`
|
||||
BatchSize int `yaml:"batch_size"`
|
||||
BaseURL string `yaml:"base_url"`
|
||||
Temperature float64 `yaml:"temperature"`
|
||||
}
|
||||
|
||||
// Client provides AI operations.
|
||||
type Client interface {
|
||||
// Embed generates embeddings for texts.
|
||||
Embed(ctx context.Context, texts []string) ([][]float32, error)
|
||||
|
||||
// QueryWithContext generates a response with context injection.
|
||||
QueryWithContext(ctx context.Context, query string, context []string) (string, error)
|
||||
}
|
||||
|
||||
// NewClient creates a new AI client based on provider.
|
||||
func NewClient(config *Config) Client {
|
||||
switch config.Provider {
|
||||
case "openai":
|
||||
return NewOpenAIClient(config)
|
||||
case "mock":
|
||||
return NewMockClient(config.Dimensions)
|
||||
default:
|
||||
return NewMockClient(1536)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,298 @@
|
||||
// Package ai provides AI integration for embeddings and context injection.
|
||||
package ai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// OpenAIClient implements AI operations using OpenAI API.
|
||||
type OpenAIClient struct {
|
||||
config *Config
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// NewOpenAIClient creates a new OpenAI client.
|
||||
func NewOpenAIClient(config *Config) *OpenAIClient {
|
||||
apiKey := config.APIKey
|
||||
if apiKey == "" {
|
||||
apiKey = os.Getenv("OPENAI_API_KEY")
|
||||
}
|
||||
|
||||
baseURL := config.BaseURL
|
||||
if baseURL == "" {
|
||||
baseURL = "https://api.openai.com/v1"
|
||||
}
|
||||
|
||||
return &OpenAIClient{
|
||||
config: &Config{
|
||||
Provider: config.Provider,
|
||||
Model: config.Model,
|
||||
Dimensions: config.Dimensions,
|
||||
APIKey: apiKey,
|
||||
BaseURL: baseURL,
|
||||
BatchSize: config.BatchSize,
|
||||
Temperature: config.Temperature,
|
||||
},
|
||||
httpClient: &http.Client{Timeout: 60 * time.Second},
|
||||
}
|
||||
}
|
||||
|
||||
// EmbeddingRequest represents an embedding API request.
|
||||
type EmbeddingRequest struct {
|
||||
Model string `json:"model"`
|
||||
Input []string `json:"input"`
|
||||
Dimensions int `json:"dimensions,omitempty"`
|
||||
}
|
||||
|
||||
// EmbeddingResponse represents an embedding API response.
|
||||
type EmbeddingResponse struct {
|
||||
Object string `json:"object"`
|
||||
Data []struct {
|
||||
Object string `json:"object"`
|
||||
Index int `json:"index"`
|
||||
Embedding []float32 `json:"embedding"`
|
||||
} `json:"data"`
|
||||
Model string `json:"model"`
|
||||
Usage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
} `json:"usage"`
|
||||
Error *APIError `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// APIError represents an API error.
|
||||
type APIError struct {
|
||||
Message string `json:"message"`
|
||||
Type string `json:"type"`
|
||||
Code string `json:"code"`
|
||||
}
|
||||
|
||||
func (e *APIError) Error() string {
|
||||
return e.Message
|
||||
}
|
||||
|
||||
// Embed generates embeddings for texts.
|
||||
func (c *OpenAIClient) Embed(ctx context.Context, texts []string) ([][]float32, error) {
|
||||
if c.config.APIKey == "" {
|
||||
return nil, fmt.Errorf("OpenAI API key not configured")
|
||||
}
|
||||
|
||||
model := c.config.Model
|
||||
if model == "" {
|
||||
model = "text-embedding-3-small"
|
||||
}
|
||||
|
||||
batchSize := c.config.BatchSize
|
||||
if batchSize == 0 {
|
||||
batchSize = 100
|
||||
}
|
||||
|
||||
var allEmbeddings [][]float32
|
||||
|
||||
// Process in batches
|
||||
for i := 0; i < len(texts); i += batchSize {
|
||||
end := i + batchSize
|
||||
if end > len(texts) {
|
||||
end = len(texts)
|
||||
}
|
||||
|
||||
batch := texts[i:end]
|
||||
embeddings, err := c.embedBatch(ctx, model, batch)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
allEmbeddings = append(allEmbeddings, embeddings...)
|
||||
}
|
||||
|
||||
return allEmbeddings, nil
|
||||
}
|
||||
|
||||
// embedBatch processes a single batch of texts.
|
||||
func (c *OpenAIClient) embedBatch(ctx context.Context, model string, texts []string) ([][]float32, error) {
|
||||
req := EmbeddingRequest{
|
||||
Model: model,
|
||||
Input: texts,
|
||||
}
|
||||
|
||||
// Set dimensions if specified (for text-embedding-3 models)
|
||||
if c.config.Dimensions > 0 && strings.HasPrefix(model, "text-embedding-3") {
|
||||
req.Dimensions = c.config.Dimensions
|
||||
}
|
||||
|
||||
body, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", c.config.BaseURL+"/embeddings", strings.NewReader(string(body)))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("Authorization", "Bearer "+c.config.APIKey)
|
||||
|
||||
resp, err := c.httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var embeddingResp EmbeddingResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&embeddingResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode response: %w", err)
|
||||
}
|
||||
|
||||
if embeddingResp.Error != nil {
|
||||
return nil, embeddingResp.Error
|
||||
}
|
||||
|
||||
// Extract embeddings in order
|
||||
embeddings := make([][]float32, len(texts))
|
||||
for _, data := range embeddingResp.Data {
|
||||
embeddings[data.Index] = data.Embedding
|
||||
}
|
||||
|
||||
return embeddings, nil
|
||||
}
|
||||
|
||||
// ChatRequest represents a chat completion request.
|
||||
type ChatRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []ChatMessage `json:"messages"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
}
|
||||
|
||||
// ChatMessage represents a chat message.
|
||||
type ChatMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
// ChatResponse represents a chat completion response.
|
||||
type ChatResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
Choices []struct {
|
||||
Index int `json:"index"`
|
||||
Message ChatMessage `json:"message"`
|
||||
Finish string `json:"finish_reason"`
|
||||
} `json:"choices"`
|
||||
Usage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
} `json:"usage"`
|
||||
Error *APIError `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// QueryWithContext generates a response with context injection.
|
||||
func (c *OpenAIClient) QueryWithContext(ctx context.Context, query string, contextDocs []string) (string, error) {
|
||||
if c.config.APIKey == "" {
|
||||
return "", fmt.Errorf("OpenAI API key not configured")
|
||||
}
|
||||
|
||||
model := c.config.Model
|
||||
if model == "" || strings.Contains(model, "embedding") {
|
||||
model = "gpt-4o-mini"
|
||||
}
|
||||
|
||||
// Build context
|
||||
contextText := strings.Join(contextDocs, "\n\n---\n\n")
|
||||
|
||||
// Build messages
|
||||
messages := []ChatMessage{
|
||||
{
|
||||
Role: "system",
|
||||
Content: "You are a helpful assistant that answers questions based on the provided context. " +
|
||||
"Use the context to provide accurate and relevant answers. " +
|
||||
"If the context doesn't contain relevant information, say so.",
|
||||
},
|
||||
{
|
||||
Role: "user",
|
||||
Content: fmt.Sprintf("Context:\n%s\n\nQuestion: %s", contextText, query),
|
||||
},
|
||||
}
|
||||
|
||||
req := ChatRequest{
|
||||
Model: model,
|
||||
Messages: messages,
|
||||
}
|
||||
|
||||
if c.config.Temperature > 0 {
|
||||
req.Temperature = c.config.Temperature
|
||||
}
|
||||
|
||||
body, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", c.config.BaseURL+"/chat/completions", strings.NewReader(string(body)))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("Authorization", "Bearer "+c.config.APIKey)
|
||||
|
||||
resp, err := c.httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var chatResp ChatResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
|
||||
return "", fmt.Errorf("failed to decode response: %w", err)
|
||||
}
|
||||
|
||||
if chatResp.Error != nil {
|
||||
return "", chatResp.Error
|
||||
}
|
||||
|
||||
if len(chatResp.Choices) == 0 {
|
||||
return "", fmt.Errorf("no response generated")
|
||||
}
|
||||
|
||||
return chatResp.Choices[0].Message.Content, nil
|
||||
}
|
||||
|
||||
// MockClient implements AI operations without external API calls.
|
||||
type MockClient struct {
|
||||
dimensions int
|
||||
}
|
||||
|
||||
// NewMockClient creates a mock client for testing.
|
||||
func NewMockClient(dimensions int) *MockClient {
|
||||
return &MockClient{dimensions: dimensions}
|
||||
}
|
||||
|
||||
// Embed generates mock embeddings.
|
||||
func (c *MockClient) Embed(ctx context.Context, texts []string) ([][]float32, error) {
|
||||
embeddings := make([][]float32, len(texts))
|
||||
for i := range texts {
|
||||
// Generate deterministic but varied embeddings
|
||||
embedding := make([]float32, c.dimensions)
|
||||
for j := range embedding {
|
||||
embedding[j] = float32(i*100+j) / float32(c.dimensions*100)
|
||||
}
|
||||
embeddings[i] = embedding
|
||||
}
|
||||
return embeddings, nil
|
||||
}
|
||||
|
||||
// QueryWithContext returns a mock response.
|
||||
func (c *MockClient) QueryWithContext(ctx context.Context, query string, context []string) (string, error) {
|
||||
return "This is a mock response.", nil
|
||||
}
|
||||
Reference in New Issue
Block a user