// Package ai provides AI integration for embeddings and context injection. package ai import ( "context" "encoding/json" "fmt" "net/http" "os" "strings" "time" ) // OpenAIClient implements AI operations using OpenAI API. type OpenAIClient struct { config *Config httpClient *http.Client } // NewOpenAIClient creates a new OpenAI client. func NewOpenAIClient(config *Config) *OpenAIClient { apiKey := config.APIKey if apiKey == "" { apiKey = os.Getenv("OPENAI_API_KEY") } baseURL := config.BaseURL if baseURL == "" { baseURL = "https://api.openai.com/v1" } return &OpenAIClient{ config: &Config{ Provider: config.Provider, Model: config.Model, Dimensions: config.Dimensions, APIKey: apiKey, BaseURL: baseURL, BatchSize: config.BatchSize, Temperature: config.Temperature, }, httpClient: &http.Client{Timeout: 60 * time.Second}, } } // EmbeddingRequest represents an embedding API request. type EmbeddingRequest struct { Model string `json:"model"` Input []string `json:"input"` Dimensions int `json:"dimensions,omitempty"` } // EmbeddingResponse represents an embedding API response. type EmbeddingResponse struct { Object string `json:"object"` Data []struct { Object string `json:"object"` Index int `json:"index"` Embedding []float32 `json:"embedding"` } `json:"data"` Model string `json:"model"` Usage struct { PromptTokens int `json:"prompt_tokens"` TotalTokens int `json:"total_tokens"` } `json:"usage"` Error *APIError `json:"error,omitempty"` } // APIError represents an API error. type APIError struct { Message string `json:"message"` Type string `json:"type"` Code string `json:"code"` } func (e *APIError) Error() string { return e.Message } // Embed generates embeddings for texts. 
func (c *OpenAIClient) Embed(ctx context.Context, texts []string) ([][]float32, error) { if c.config.APIKey == "" { return nil, fmt.Errorf("OpenAI API key not configured") } model := c.config.Model if model == "" { model = "text-embedding-3-small" } batchSize := c.config.BatchSize if batchSize == 0 { batchSize = 100 } var allEmbeddings [][]float32 // Process in batches for i := 0; i < len(texts); i += batchSize { end := i + batchSize if end > len(texts) { end = len(texts) } batch := texts[i:end] embeddings, err := c.embedBatch(ctx, model, batch) if err != nil { return nil, err } allEmbeddings = append(allEmbeddings, embeddings...) } return allEmbeddings, nil } // embedBatch processes a single batch of texts. func (c *OpenAIClient) embedBatch(ctx context.Context, model string, texts []string) ([][]float32, error) { req := EmbeddingRequest{ Model: model, Input: texts, } // Set dimensions if specified (for text-embedding-3 models) if c.config.Dimensions > 0 && strings.HasPrefix(model, "text-embedding-3") { req.Dimensions = c.config.Dimensions } body, err := json.Marshal(req) if err != nil { return nil, fmt.Errorf("failed to marshal request: %w", err) } httpReq, err := http.NewRequestWithContext(ctx, "POST", c.config.BaseURL+"/embeddings", strings.NewReader(string(body))) if err != nil { return nil, fmt.Errorf("failed to create request: %w", err) } httpReq.Header.Set("Content-Type", "application/json") httpReq.Header.Set("Authorization", "Bearer "+c.config.APIKey) resp, err := c.httpClient.Do(httpReq) if err != nil { return nil, fmt.Errorf("request failed: %w", err) } defer resp.Body.Close() var embeddingResp EmbeddingResponse if err := json.NewDecoder(resp.Body).Decode(&embeddingResp); err != nil { return nil, fmt.Errorf("failed to decode response: %w", err) } if embeddingResp.Error != nil { return nil, embeddingResp.Error } // Extract embeddings in order embeddings := make([][]float32, len(texts)) for _, data := range embeddingResp.Data { embeddings[data.Index] = 
data.Embedding } return embeddings, nil } // ChatRequest represents a chat completion request. type ChatRequest struct { Model string `json:"model"` Messages []ChatMessage `json:"messages"` Temperature float64 `json:"temperature,omitempty"` MaxTokens int `json:"max_tokens,omitempty"` } // ChatMessage represents a chat message. type ChatMessage struct { Role string `json:"role"` Content string `json:"content"` } // ChatResponse represents a chat completion response. type ChatResponse struct { ID string `json:"id"` Object string `json:"object"` Created int64 `json:"created"` Model string `json:"model"` Choices []struct { Index int `json:"index"` Message ChatMessage `json:"message"` Finish string `json:"finish_reason"` } `json:"choices"` Usage struct { PromptTokens int `json:"prompt_tokens"` CompletionTokens int `json:"completion_tokens"` TotalTokens int `json:"total_tokens"` } `json:"usage"` Error *APIError `json:"error,omitempty"` } // QueryWithContext generates a response with context injection. func (c *OpenAIClient) QueryWithContext(ctx context.Context, query string, contextDocs []string) (string, error) { if c.config.APIKey == "" { return "", fmt.Errorf("OpenAI API key not configured") } model := c.config.Model if model == "" || strings.Contains(model, "embedding") { model = "gpt-4o-mini" } // Build context contextText := strings.Join(contextDocs, "\n\n---\n\n") // Build messages messages := []ChatMessage{ { Role: "system", Content: "You are a helpful assistant that answers questions based on the provided context. " + "Use the context to provide accurate and relevant answers. 
" + "If the context doesn't contain relevant information, say so.", }, { Role: "user", Content: fmt.Sprintf("Context:\n%s\n\nQuestion: %s", contextText, query), }, } req := ChatRequest{ Model: model, Messages: messages, } if c.config.Temperature > 0 { req.Temperature = c.config.Temperature } body, err := json.Marshal(req) if err != nil { return "", fmt.Errorf("failed to marshal request: %w", err) } httpReq, err := http.NewRequestWithContext(ctx, "POST", c.config.BaseURL+"/chat/completions", strings.NewReader(string(body))) if err != nil { return "", fmt.Errorf("failed to create request: %w", err) } httpReq.Header.Set("Content-Type", "application/json") httpReq.Header.Set("Authorization", "Bearer "+c.config.APIKey) resp, err := c.httpClient.Do(httpReq) if err != nil { return "", fmt.Errorf("request failed: %w", err) } defer resp.Body.Close() var chatResp ChatResponse if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil { return "", fmt.Errorf("failed to decode response: %w", err) } if chatResp.Error != nil { return "", chatResp.Error } if len(chatResp.Choices) == 0 { return "", fmt.Errorf("no response generated") } return chatResp.Choices[0].Message.Content, nil } // MockClient implements AI operations without external API calls. type MockClient struct { dimensions int } // NewMockClient creates a mock client for testing. func NewMockClient(dimensions int) *MockClient { return &MockClient{dimensions: dimensions} } // Embed generates mock embeddings. func (c *MockClient) Embed(ctx context.Context, texts []string) ([][]float32, error) { embeddings := make([][]float32, len(texts)) for i := range texts { // Generate deterministic but varied embeddings embedding := make([]float32, c.dimensions) for j := range embedding { embedding[j] = float32(i*100+j) / float32(c.dimensions*100) } embeddings[i] = embedding } return embeddings, nil } // QueryWithContext returns a mock response. 
//
// It ignores ctx, query and contextDocs and always returns the same fixed
// string with a nil error; its parameters mirror the signature of
// OpenAIClient.QueryWithContext. The documents parameter is named contextDocs
// so that it does not shadow the imported context package.
func (c *MockClient) QueryWithContext(ctx context.Context, query string, contextDocs []string) (string, error) {
	return "This is a mock response.", nil
}