From 00589e841b82125a51c3b094c21b337352555e70 Mon Sep 17 00:00:00 2001 From: ER Hapal Date: Tue, 21 Oct 2025 12:51:07 +1300 Subject: [PATCH] feat: add support for custom embedding providers Add flexible embedding provider configuration to support both OpenAI and custom embedding endpoints (e.g., LiteLLM, local models). This enables users to use alternative embedding services while maintaining backward compatibility with OpenAI. Signed-off-by: ER Hapal --- README.md | 45 ++++++-- database.ts | 126 +++++++++++---------- doc2vec.ts | 253 ++++++++++++++++++++---------------------- embedding-factory.ts | 49 ++++++++ embedding-provider.ts | 174 +++++++++++++++++++++++++++++ mcp/package-lock.json | 48 +++++++- mcp/package.json | 3 +- mcp/src/index.ts | 128 ++++++++++++--------- package-lock.json | 6 +- package.json | 2 +- 10 files changed, 580 insertions(+), 254 deletions(-) create mode 100644 embedding-factory.ts create mode 100644 embedding-provider.ts diff --git a/README.md b/README.md index da49910..f5d53c9 100644 --- a/README.md +++ b/README.md @@ -68,9 +68,21 @@ Configuration is managed through two files: ```dotenv # .env - # Required: Your OpenAI API Key + # Required: Your OpenAI API Key (used for both OpenAI and custom providers) OPENAI_API_KEY="sk-..." + # Optional: Embedding provider (defaults to "openai") + PROVIDER="openai" # or "custom" + + # Optional: Custom embedding model (defaults based on provider) + EMBEDDING_MODEL="text-embedding-3-large" # or your preferred model + + # Required if using custom provider: Custom endpoint URL + CUSTOM_ENDPOINT="http://localhost:8000/v1/embeddings" + + # Optional: Vector size of custom embedding model + EMBEDDING_VECTOR_SIZE=1024 + # Required for GitHub sources GITHUB_PERSONAL_ACCESS_TOKEN="ghp_..." @@ -84,19 +96,25 @@ Configuration is managed through two files: 2. **`config.yaml` file:** This file defines the sources to process and how to handle them. Create a `config.yaml` file (or use a different name and pass it as an argument). + **Embedding Provider Configuration:** + + Embedding providers are now configured via environment variables: + - `OPENAI_API_KEY`: API key used for both providers + - `PROVIDER`: Set to "openai" (default) or "custom" + - `EMBEDDING_MODEL`: Model to use (default: "text-embedding-3-large") + - `EMBEDDING_VECTOR_SIZE`: Vector size of the custom embedding model (default: 3072) + - `CUSTOM_ENDPOINT`: Required when using custom provider (e.g., "http://localhost:8000/v1/embeddings") + **Structure:** * `sources`: An array of source configurations. * `type`: Either `'website'`, `'github'`, `'local_directory'`, or `'zendesk'` - For websites (`type: 'website'`): * `url`: The starting URL for crawling the documentation site. * `sitemap_url`: (Optional) URL to the site's XML sitemap for discovering additional pages not linked in navigation. - For GitHub repositories (`type: 'github'`): * `repo`: Repository name in the format `'owner/repo'` (e.g., `'istio/istio'`). * `start_date`: (Optional) Starting date to fetch issues from (e.g., `'2025-01-01'`). - For local directories (`type: 'local_directory'`): * `path`: Path to the local directory to process. * `include_extensions`: (Optional) Array of file extensions to include (e.g., `['.md', '.txt', '.pdf']`). Defaults to `['.md', '.txt', '.html', '.htm', '.pdf']`. @@ -104,7 +122,6 @@ Configuration is managed through two files: * `recursive`: (Optional) Whether to traverse subdirectories (defaults to `true`). 
* `url_rewrite_prefix` (Optional) URL prefix to rewrite `file://` URLs (e.g., `https://mydomain.com`) * `encoding`: (Optional) File encoding to use (defaults to `'utf8'`). Note: PDF files are processed as binary and this setting doesn't apply to them. - For Zendesk (`type: 'zendesk'`): * `zendesk_subdomain`: Your Zendesk subdomain (e.g., `'mycompany'` for mycompany.zendesk.com). * `email`: Your Zendesk admin email address. @@ -131,6 +148,21 @@ Configuration is managed through two files: **Example (`config.yaml`):** ```yaml + # Example with OpenAI embedding provider (default) + embedding_config: + provider: "openai" + openai: + api_key_env: "OPENAI_API_KEY" + + # Example with custom embedding provider (LiteLLM) + # embedding_config: + # provider: "custom" + # custom: + # endpoint: "http://localhost:8000/v1/embeddings" + # model: "text-embedding-ada-002" + # api_key_env: "LITELLM_API_KEY" + # timeout: 30000 + sources: # Website source example - type: 'website' @@ -155,7 +187,6 @@ Configuration is managed through two files: type: 'sqlite' params: db_path: './istio-issues.db' - # Local directory source example - type: 'local_directory' product_name: 'project-docs' @@ -168,7 +199,6 @@ Configuration is managed through two files: type: 'sqlite' params: db_path: './project-docs.db' - # Zendesk example - type: 'zendesk' product_name: 'MyCompany' @@ -186,7 +216,6 @@ Configuration is managed through two files: type: 'sqlite' params: db_path: './zendesk-kb.db' - # Qdrant example - type: 'website' product_name: 'Istio' diff --git a/database.ts b/database.ts index eb9f810..e4eb997 100644 --- a/database.ts +++ b/database.ts @@ -5,11 +5,11 @@ import * as sqliteVec from "sqlite-vec"; import { QdrantClient } from '@qdrant/js-client-rest'; import { Logger } from './logger'; import { Utils } from './utils'; -import { - SourceConfig, - DatabaseConnection, - SqliteDB, - QdrantDB, +import { + SourceConfig, + DatabaseConnection, + SqliteDB, + QdrantDB, DocumentChunk, SqliteDatabaseParams, QdrantDatabaseParams, @@ -17,23 +17,36 @@ import { } from './types'; export class DatabaseManager { + private static getEmbeddingSize(): number { + const envSize = process.env.EMBEDDING_VECTOR_SIZE; + if (envSize) { + const parsed = parseInt(envSize, 10); + if (isNaN(parsed) || parsed <= 0) { + throw new Error(`Invalid EMBEDDING_VECTOR_SIZE: ${envSize}. 
Must be a positive integer.`); + } + return parsed; + } + return 3072; // Default value + } + static async initDatabase(config: SourceConfig, parentLogger: Logger): Promise { const logger = parentLogger.child('database'); const dbConfig = config.database_config; - + if (dbConfig.type === 'sqlite') { const params = dbConfig.params as SqliteDatabaseParams; const dbPath = params.db_path || path.join(process.cwd(), `${config.product_name.replace(/\s+/g, '_')}-${config.version}.db`); - + logger.info(`Opening SQLite database at ${dbPath}`); - + const db = new BetterSqlite3(dbPath, { allowExtension: true } as any); sqliteVec.load(db); - logger.debug(`Creating vec_items table if it doesn't exist`); + const embeddingSize = this.getEmbeddingSize(); + logger.debug(`Creating vec_items table if it doesn't exist with embedding size: ${embeddingSize}`); db.exec(` CREATE VIRTUAL TABLE IF NOT EXISTS vec_items USING vec0( - embedding FLOAT[3072], + embedding FLOAT[${embeddingSize}], product_name TEXT, version TEXT, heading_hierarchy TEXT, @@ -51,7 +64,7 @@ export class DatabaseManager { const qdrantUrl = params.qdrant_url || 'http://localhost:6333'; const qdrantPort = params.qdrant_port || 443; const collectionName = params.collection_name || `${config.product_name.toLowerCase().replace(/\s+/g, '_')}_${config.version}`; - + logger.info(`Connecting to Qdrant at ${qdrantUrl}:${qdrantPort}, collection: ${collectionName}`); const qdrantClient = new QdrantClient({ url: qdrantUrl, apiKey: process.env.QDRANT_API_KEY, port: qdrantPort }); @@ -72,16 +85,17 @@ export class DatabaseManager { const collectionExists = collections.collections.some( (collection: any) => collection.name === collectionName ); - + if (collectionExists) { logger.info(`Collection ${collectionName} already exists`); return; } - - logger.info(`Creating new collection ${collectionName}`); + + const embeddingSize = this.getEmbeddingSize(); + logger.info(`Creating new collection ${collectionName} with embedding size: ${embeddingSize}`); await qdrantClient.createCollection(collectionName, { vectors: { - size: 3072, + size: embeddingSize, distance: "Cosine", }, }); @@ -90,9 +104,9 @@ export class DatabaseManager { if (error instanceof Error) { const errorMsg = error.message.toLowerCase(); const errorString = JSON.stringify(error).toLowerCase(); - + if ( - errorMsg.includes("already exists") || + errorMsg.includes("already exists") || errorString.includes("already exists") || (error as any)?.status === 409 || errorString.includes("conflict") @@ -101,7 +115,7 @@ export class DatabaseManager { return; } } - + logger.error(`Error creating Qdrant collection:`, error); logger.warn(`Continuing with existing collection...`); } @@ -127,12 +141,12 @@ export class DatabaseManager { static async getLastRunDate(dbConnection: DatabaseConnection, repo: string, defaultDate: string, logger: Logger): Promise { const metadataKey = `last_run_${repo.replace('/', '_')}`; - + try { if (dbConnection.type === 'sqlite') { const stmt = dbConnection.db.prepare('SELECT value FROM vec_metadata WHERE key = ?'); const result = stmt.get(metadataKey) as { value: string } | undefined; - + if (result) { logger.info(`Retrieved last run date for ${repo}: ${result.value}`); return result.value; @@ -141,7 +155,7 @@ export class DatabaseManager { // Generate a UUID for this repo's metadata const metadataUUID = Utils.generateMetadataUUID(repo); logger.debug(`Looking up metadata with UUID: ${metadataUUID}`); - + try { // Try to retrieve the metadata point for this repo const response = await 
dbConnection.client.retrieve(dbConnection.collectionName, { @@ -149,7 +163,7 @@ export class DatabaseManager { with_payload: true, with_vector: false }); - + if (response.length > 0 && response[0].payload?.metadata_value) { const lastRunDate = response[0].payload.metadata_value as string; logger.info(`Retrieved last run date for ${repo}: ${lastRunDate}`); @@ -162,14 +176,14 @@ export class DatabaseManager { } catch (error) { logger.warn(`Error retrieving last run date:`, error); } - + logger.info(`No saved run date found for ${repo}, using default: ${defaultDate}`); return defaultDate; } static async updateLastRunDate(dbConnection: DatabaseConnection, repo: string, logger: Logger): Promise { const now = new Date().toISOString(); - + try { if (dbConnection.type === 'sqlite') { const metadataKey = `last_run_${repo.replace('/', '_')}`; @@ -183,13 +197,13 @@ export class DatabaseManager { // Generate UUID for this repo's metadata const metadataUUID = Utils.generateMetadataUUID(repo); const metadataKey = `last_run_${repo.replace('/', '_')}`; - + logger.debug(`Using UUID: ${metadataUUID} for metadata`); - + // Generate a dummy embedding (all zeros) - const dummyEmbeddingSize = 3072; // Same size as your content embeddings + const dummyEmbeddingSize = this.getEmbeddingSize(); // Same size as your content embeddings const dummyEmbedding = new Array(dummyEmbeddingSize).fill(0); - + // Create a point with special metadata payload const metadataPoint = { id: metadataUUID, @@ -204,12 +218,12 @@ export class DatabaseManager { url: 'metadata://' + repo } }; - + await dbConnection.client.upsert(dbConnection.collectionName, { wait: true, points: [metadataPoint] }); - + logger.info(`Updated last run date for ${repo} to ${now}`); } } catch (error) { @@ -236,7 +250,7 @@ export class DatabaseManager { static insertVectorsSQLite(db: Database, chunk: DocumentChunk, embedding: number[], logger: Logger, chunkHash?: string) { const { insertStmt, updateStmt } = this.prepareSQLiteStatements(db); const hash = chunkHash || Utils.generateHash(chunk.content); - + const transaction = db.transaction(() => { const params = [ new Float32Array(embedding), @@ -272,9 +286,9 @@ export class DatabaseManager { } catch (e) { pointId = crypto.randomUUID(); } - + const hash = chunkHash || Utils.generateHash(chunk.content); - + const pointItem = { id: pointId, vector: embedding, @@ -376,9 +390,9 @@ export class DatabaseManager { } static removeObsoleteFilesSQLite( - db: Database, - processedFiles: Set, - pathConfig: { path: string; url_rewrite_prefix?: string } | string, + db: Database, + processedFiles: Set, + pathConfig: { path: string; url_rewrite_prefix?: string } | string, logger: Logger ) { const getChunksForPathStmt = db.prepare(` @@ -386,10 +400,10 @@ export class DatabaseManager { WHERE url LIKE ? 
|| '%' `); const deleteChunkStmt = db.prepare(`DELETE FROM vec_items WHERE chunk_id = ?`); - + // Determine if we're using URL rewriting or direct file paths const isRewriteMode = typeof pathConfig === 'object' && pathConfig.url_rewrite_prefix; - + // Set up the URL prefix for searching let urlPrefix: string; if (isRewriteMode) { @@ -402,19 +416,19 @@ export class DatabaseManager { const cleanedDirPrefix = dirPrefix.replace(/^\.\/+/, ''); urlPrefix = `file://${cleanedDirPrefix}`; } - + logger.debug(`Searching for chunks with URL prefix: ${urlPrefix}`); const existingChunks = getChunksForPathStmt.all(urlPrefix) as { chunk_id: string; url: string }[]; let deletedCount = 0; - + const transaction = db.transaction(() => { for (const { chunk_id, url } of existingChunks) { // Skip if it's not from our URL prefix (safety check) if (!url.startsWith(urlPrefix)) continue; - + let filePath: string; let shouldDelete = false; - + if (isRewriteMode) { // URL rewrite mode: extract relative path and construct full file path const config = pathConfig as { path: string; url_rewrite_prefix?: string }; @@ -426,7 +440,7 @@ export class DatabaseManager { filePath = url.substring(7); // Remove 'file://' prefix shouldDelete = !processedFiles.has(filePath); } - + if (shouldDelete) { logger.debug(`Deleting obsolete chunk from SQLite: ${chunk_id.substring(0, 8)}... (File not processed: ${filePath})`); deleteChunkStmt.run(chunk_id); @@ -435,21 +449,21 @@ export class DatabaseManager { } }); transaction(); - + logger.info(`Deleted ${deletedCount} obsolete chunks from SQLite for URL prefix ${urlPrefix}`); } static async removeObsoleteFilesQdrant( - db: QdrantDB, - processedFiles: Set, - pathConfig: { path: string; url_rewrite_prefix?: string } | string, + db: QdrantDB, + processedFiles: Set, + pathConfig: { path: string; url_rewrite_prefix?: string } | string, logger: Logger ) { const { client, collectionName } = db; try { // Determine if we're using URL rewriting or direct file paths const isRewriteMode = typeof pathConfig === 'object' && pathConfig.url_rewrite_prefix; - + // Set up the URL prefix for searching let urlPrefix: string; if (isRewriteMode) { @@ -462,7 +476,7 @@ export class DatabaseManager { const cleanedDirPrefix = dirPrefix.replace(/^\.\/+/, ''); urlPrefix = `file://${cleanedDirPrefix}`; } - + logger.debug(`Checking for obsolete chunks with URL prefix: ${urlPrefix}`); const response = await client.scroll(collectionName, { limit: 10000, @@ -487,7 +501,7 @@ export class DatabaseManager { ] } }); - + const obsoletePointIds = response.points .filter((point: any) => { const url = point.payload?.url; @@ -495,13 +509,13 @@ export class DatabaseManager { if (point.payload?.is_metadata === true) { return false; } - + if (!url || !url.startsWith(urlPrefix)) { return false; } - + let filePath: string; - + if (isRewriteMode) { // URL rewrite mode: extract relative path and construct full file path const config = pathConfig as { path: string; url_rewrite_prefix?: string }; @@ -511,11 +525,11 @@ export class DatabaseManager { // Direct file path mode: remove file:// prefix to match with processedFiles filePath = url.startsWith('file://') ? 
url.substring(7) : ''; } - + return filePath && !processedFiles.has(filePath); }) .map((point: any) => point.id); - + if (obsoletePointIds.length > 0) { await client.delete(collectionName, { points: obsoletePointIds, @@ -528,4 +542,4 @@ export class DatabaseManager { logger.error(`Error removing obsolete chunks from Qdrant:`, error); } } -} \ No newline at end of file +} diff --git a/doc2vec.ts b/doc2vec.ts index 4a9d8bc..09996a5 100644 --- a/doc2vec.ts +++ b/doc2vec.ts @@ -6,17 +6,18 @@ import * as yaml from 'js-yaml'; import * as fs from 'fs'; import * as path from 'path'; import { Buffer } from 'buffer'; -import { OpenAI } from "openai"; import * as dotenv from "dotenv"; import { Logger, LogLevel } from './logger'; import { Utils } from './utils'; import { DatabaseManager } from './database'; import { ContentProcessor } from './content-processor'; -import { - Config, - SourceConfig, - GithubSourceConfig, - WebsiteSourceConfig, +import { EmbeddingProviderFactory } from './embedding-factory'; +import { EmbeddingProvider } from './embedding-provider'; +import { + Config, + SourceConfig, + GithubSourceConfig, + WebsiteSourceConfig, LocalDirectorySourceConfig, ZendeskSourceConfig, DatabaseConnection, @@ -29,7 +30,7 @@ dotenv.config(); class Doc2Vec { private config: Config; - private openai: OpenAI; + private embeddingProvider: EmbeddingProvider; private contentProcessor: ContentProcessor; private logger: Logger; @@ -40,10 +41,14 @@ class Doc2Vec { useColor: true, prettyPrint: true }); - + this.logger.info('Initializing Doc2Vec'); this.config = this.loadConfig(configPath); - this.openai = new OpenAI({ apiKey: process.env.OPENAI_API_KEY }); + + // Initialize embedding provider based on environment variables + this.embeddingProvider = EmbeddingProviderFactory.createProvider(this.logger); + this.logger.info(`Using embedding provider: ${this.embeddingProvider.getProviderName()}`); + this.contentProcessor = new ContentProcessor(this.logger); } @@ -51,9 +56,9 @@ class Doc2Vec { try { const logger = this.logger.child('config'); logger.info(`Loading configuration from ${configPath}`); - + let configFile = fs.readFileSync(configPath, 'utf8'); - + // Substitute environment variables in the format ${VAR_NAME} configFile = configFile.replace(/\$\{([^}]+)\}/g, (match, varName) => { const envValue = process.env[varName]; @@ -64,9 +69,9 @@ class Doc2Vec { logger.debug(`Substituted ${match} with environment variable value`); return envValue; }); - + let config = yaml.load(configFile) as any; - + const typedConfig = config as Config; logger.info(`Configuration loaded successfully, found ${typedConfig.sources.length} sources`); return typedConfig; @@ -78,12 +83,12 @@ class Doc2Vec { public async run(): Promise { this.logger.section('PROCESSING SOURCES'); - + for (const sourceConfig of this.config.sources) { const sourceLogger = this.logger.child(`source:${sourceConfig.product_name}`); - + sourceLogger.info(`Processing ${sourceConfig.type} source for ${sourceConfig.product_name}@${sourceConfig.version}`); - + if (sourceConfig.type === 'github') { await this.processGithubRepo(sourceConfig, sourceLogger); } else if (sourceConfig.type === 'website') { @@ -96,17 +101,17 @@ class Doc2Vec { sourceLogger.error(`Unknown source type: ${(sourceConfig as any).type}`); } } - + this.logger.section('PROCESSING COMPLETE'); } private async fetchAndProcessGitHubIssues(repo: string, sourceConfig: GithubSourceConfig, dbConnection: DatabaseConnection, logger: Logger): Promise { const [owner, repoName] = repo.split('/'); const 
GITHUB_API_URL = `https://api.github.com/repos/${owner}/${repoName}/issues`; - + // Initialize metadata storage if needed await DatabaseManager.initDatabaseMetadata(dbConnection, logger); - + // Get the last run date from the database const startDate = sourceConfig.start_date || '2025-01-01'; const lastRunDate = await DatabaseManager.getLastRunDate(dbConnection, repo, `${startDate}T00:00:00Z`, logger); @@ -193,31 +198,31 @@ class Doc2Vec { const processIssue = async (issue: any): Promise => { const issueNumber = issue.number; const url = `https://github.com/${repo}/issues/${issueNumber}`; - + logger.info(`Processing issue #${issueNumber}`); - + // Generate markdown for the issue const markdown = await generateMarkdownForIssue(issue); - + // Chunk the markdown content const issueConfig = { ...sourceConfig, product_name: sourceConfig.product_name || repo, max_size: sourceConfig.max_size || Infinity }; - + const chunks = await this.contentProcessor.chunkMarkdown(markdown, issueConfig, url); logger.info(`Issue #${issueNumber}: Created ${chunks.length} chunks`); - + // Process and store each chunk immediately for (const chunk of chunks) { const chunkHash = Utils.generateHash(chunk.content); const chunkId = chunk.metadata.chunk_id.substring(0, 8) + '...'; - + if (dbConnection.type === 'sqlite') { const { checkHashStmt } = DatabaseManager.prepareSQLiteStatements(dbConnection.db); const existing = checkHashStmt.get(chunk.metadata.chunk_id) as { hash: string } | undefined; - + if (existing && existing.hash === chunkHash) { logger.info(`Skipping unchanged chunk: ${chunkId}`); continue; @@ -252,7 +257,7 @@ class Doc2Vec { logger.info(`Skipping unchanged chunk: ${chunkId}`); continue; } - + const embeddings = await this.createEmbeddings([chunk.content]); if (embeddings.length) { await DatabaseManager.storeChunkInQdrant(dbConnection, chunk, embeddings[0], chunkHash); @@ -279,36 +284,36 @@ class Doc2Vec { // Update the last run date in the database after processing all issues await DatabaseManager.updateLastRunDate(dbConnection, repo, logger); - + logger.info(`Successfully processed ${issues.length} issues`); } private async processGithubRepo(config: GithubSourceConfig, parentLogger: Logger): Promise { const logger = parentLogger.child('process'); logger.info(`Starting processing for GitHub repo: ${config.repo}`); - + const dbConnection = await DatabaseManager.initDatabase(config, logger); - + // Initialize metadata storage await DatabaseManager.initDatabaseMetadata(dbConnection, logger); - + logger.section('GITHUB ISSUES'); - + // Process GitHub issues await this.fetchAndProcessGitHubIssues(config.repo, config, dbConnection, logger); - + logger.info(`Finished processing GitHub repo: ${config.repo}`); } private async processWebsite(config: WebsiteSourceConfig, parentLogger: Logger): Promise { const logger = parentLogger.child('process'); logger.info(`Starting processing for website: ${config.url}`); - + const dbConnection = await DatabaseManager.initDatabase(config, logger); const validChunkIds: Set = new Set(); const visitedUrls: Set = new Set(); const urlPrefix = Utils.getUrlPrefix(config.url); - + logger.section('CRAWL AND EMBEDDING'); const crawlResult = await this.contentProcessor.crawlWebsite(config.url, config, async (url, content) => { @@ -399,7 +404,7 @@ class Doc2Vec { logger.info(`Found ${validChunkIds.size} valid chunks across processed pages for ${config.url}`); logger.section('CLEANUP'); - + if (crawlResult.hasNetworkErrors) { logger.warn('Skipping cleanup due to network errors encountered 
during crawling. This prevents removal of valid chunks when the site is temporarily unreachable.'); } else { @@ -418,28 +423,28 @@ class Doc2Vec { private async processLocalDirectory(config: LocalDirectorySourceConfig, parentLogger: Logger): Promise { const logger = parentLogger.child('process'); logger.info(`Starting processing for local directory: ${config.path}`); - + const dbConnection = await DatabaseManager.initDatabase(config, logger); const validChunkIds: Set = new Set(); const processedFiles: Set = new Set(); - + logger.section('FILE SCANNING AND EMBEDDING'); - + await this.contentProcessor.processDirectory( - config.path, - config, + config.path, + config, async (filePath, content) => { processedFiles.add(filePath); - + logger.info(`Processing content from ${filePath} (${content.length} chars)`); try { // Generate URL based on configuration let fileUrl: string; - + if (config.url_rewrite_prefix) { // Replace local path with URL prefix const relativePath = path.relative(config.path, filePath).replace(/\\/g, '/'); - + // If relativePath starts with '..', it means the file is outside the base directory if (relativePath.startsWith('..')) { // For files outside the configured path, use the default file:// scheme @@ -448,10 +453,10 @@ class Doc2Vec { } else { // For files inside the configured path, rewrite the URL // Handle trailing slashes in the URL prefix to avoid double slashes - const prefix = config.url_rewrite_prefix.endsWith('/') - ? config.url_rewrite_prefix.slice(0, -1) + const prefix = config.url_rewrite_prefix.endsWith('/') + ? config.url_rewrite_prefix.slice(0, -1) : config.url_rewrite_prefix; - + fileUrl = `${prefix}/${relativePath}`; logger.debug(`URL rewritten: ${filePath} -> ${fileUrl}`); } @@ -459,26 +464,26 @@ class Doc2Vec { // Use default file:// URL fileUrl = `file://${filePath}`; } - + const chunks = await this.contentProcessor.chunkMarkdown(content, config, fileUrl); logger.info(`Created ${chunks.length} chunks`); - + if (chunks.length > 0) { const chunkProgress = logger.progress(`Embedding chunks for ${filePath}`, chunks.length); - + for (let i = 0; i < chunks.length; i++) { const chunk = chunks[i]; validChunkIds.add(chunk.metadata.chunk_id); - + const chunkId = chunk.metadata.chunk_id.substring(0, 8) + '...'; - + let needsEmbedding = true; const chunkHash = Utils.generateHash(chunk.content); - + if (dbConnection.type === 'sqlite') { const { checkHashStmt } = DatabaseManager.prepareSQLiteStatements(dbConnection.db); const existing = checkHashStmt.get(chunk.metadata.chunk_id) as { hash: string } | undefined; - + if (existing && existing.hash === chunkHash) { needsEmbedding = false; chunkProgress.update(1, `Skipping unchanged chunk ${chunkId}`); @@ -495,13 +500,13 @@ class Doc2Vec { } catch (e) { pointId = crypto.randomUUID(); } - + const existingPoints = await dbConnection.client.retrieve(dbConnection.collectionName, { ids: [pointId], with_payload: true, with_vector: false, }); - + if (existingPoints.length > 0 && existingPoints[0].payload && existingPoints[0].payload.hash === chunkHash) { needsEmbedding = false; chunkProgress.update(1, `Skipping unchanged chunk ${chunkId}`); @@ -511,7 +516,7 @@ class Doc2Vec { logger.error(`Error checking existing point in Qdrant:`, error); } } - + if (needsEmbedding) { const embeddings = await this.createEmbeddings([chunk.content]); if (embeddings.length > 0) { @@ -529,16 +534,16 @@ class Doc2Vec { } } } - + chunkProgress.complete(); } } catch (error) { logger.error(`Error during chunking or embedding for ${filePath}:`, 
error); } - }, + }, logger ); - + logger.section('CLEANUP'); if (dbConnection.type === 'sqlite') { logger.info(`Running SQLite cleanup for local directory ${config.path}`); @@ -547,43 +552,43 @@ class Doc2Vec { logger.info(`Running Qdrant cleanup for local directory ${config.path} in collection ${dbConnection.collectionName}`); await DatabaseManager.removeObsoleteFilesQdrant(dbConnection, processedFiles, config, logger); } - + logger.info(`Finished processing local directory: ${config.path}`); } private async processZendesk(config: ZendeskSourceConfig, parentLogger: Logger): Promise { const logger = parentLogger.child('process'); logger.info(`Starting processing for Zendesk: ${config.zendesk_subdomain}.zendesk.com`); - + const dbConnection = await DatabaseManager.initDatabase(config, logger); - + // Initialize metadata storage await DatabaseManager.initDatabaseMetadata(dbConnection, logger); - + const fetchTickets = config.fetch_tickets !== false; // default true const fetchArticles = config.fetch_articles !== false; // default true - + if (fetchTickets) { logger.section('ZENDESK TICKETS'); await this.fetchAndProcessZendeskTickets(config, dbConnection, logger); } - + if (fetchArticles) { logger.section('ZENDESK ARTICLES'); await this.fetchAndProcessZendeskArticles(config, dbConnection, logger); } - + logger.info(`Finished processing Zendesk: ${config.zendesk_subdomain}.zendesk.com`); } private async fetchAndProcessZendeskTickets(config: ZendeskSourceConfig, dbConnection: DatabaseConnection, logger: Logger): Promise { const baseUrl = `https://${config.zendesk_subdomain}.zendesk.com/api/v2`; const auth = Buffer.from(`${config.email}/token:${config.api_token}`).toString('base64'); - + // Get the last run date from the database const startDate = config.start_date || `${new Date().getFullYear()}-01-01`; const lastRunDate = await DatabaseManager.getLastRunDate(dbConnection, `zendesk_tickets_${config.zendesk_subdomain}`, `${startDate}T00:00:00Z`, logger); - + const fetchWithRetry = async (url: string, retries = 3): Promise => { for (let attempt = 0; attempt < retries; attempt++) { try { @@ -593,14 +598,14 @@ class Doc2Vec { 'Content-Type': 'application/json', }, }); - + if (response.status === 429) { const retryAfter = parseInt(response.headers['retry-after'] || '60'); logger.warn(`Rate limited, waiting ${retryAfter}s before retry`); await new Promise(res => setTimeout(res, retryAfter * 1000)); continue; } - + return response.data; } catch (error: any) { logger.error(`Zendesk API error (attempt ${attempt + 1}):`, error.message); @@ -619,26 +624,26 @@ class Doc2Vec { md += `- **Assignee:** ${ticket.assignee_id || 'Unassigned'}\n`; md += `- **Created:** ${new Date(ticket.created_at).toDateString()}\n`; md += `- **Updated:** ${new Date(ticket.updated_at).toDateString()}\n`; - + if (ticket.tags && ticket.tags.length > 0) { md += `- **Tags:** ${ticket.tags.map((tag: string) => `\`${tag}\``).join(', ')}\n`; } - + // Handle ticket description const description = ticket.description || ''; const cleanDescription = description || '_No description._'; md += `\n## Description\n\n${cleanDescription}\n\n`; - + if (comments && comments.length > 0) { md += `## Comments\n\n`; for (const comment of comments) { if (comment.public) { md += `### ${comment.author_id} - ${new Date(comment.created_at).toDateString()}\n\n`; - + // Handle comment body const rawBody = comment.plain_body || comment.html_body || comment.body || ''; const commentBody = rawBody.replace(/ /g, " ") || '_No content._'; - + md += 
`${commentBody}\n\n---\n\n`; } } @@ -652,36 +657,36 @@ class Doc2Vec { const processTicket = async (ticket: any): Promise => { const ticketId = ticket.id; const url = `https://${config.zendesk_subdomain}.zendesk.com/agent/tickets/${ticketId}`; - + logger.info(`Processing ticket #${ticketId}`); - + // Fetch ticket comments const commentsUrl = `${baseUrl}/tickets/${ticketId}/comments.json`; const commentsData = await fetchWithRetry(commentsUrl); const comments = commentsData?.comments || []; - + // Generate markdown for the ticket const markdown = generateMarkdownForTicket(ticket, comments); - + // Chunk the markdown content const ticketConfig = { ...config, product_name: config.product_name || `zendesk_${config.zendesk_subdomain}`, max_size: config.max_size || Infinity }; - + const chunks = await this.contentProcessor.chunkMarkdown(markdown, ticketConfig, url); logger.info(`Ticket #${ticketId}: Created ${chunks.length} chunks`); - + // Process and store each chunk for (const chunk of chunks) { const chunkHash = Utils.generateHash(chunk.content); const chunkId = chunk.metadata.chunk_id.substring(0, 8) + '...'; - + if (dbConnection.type === 'sqlite') { const { checkHashStmt } = DatabaseManager.prepareSQLiteStatements(dbConnection.db); const existing = checkHashStmt.get(chunk.metadata.chunk_id) as { hash: string } | undefined; - + if (existing && existing.hash === chunkHash) { logger.info(`Skipping unchanged chunk: ${chunkId}`); continue; @@ -716,7 +721,7 @@ class Doc2Vec { logger.info(`Skipping unchanged chunk: ${chunkId}`); continue; } - + const embeddings = await this.createEmbeddings([chunk.content]); if (embeddings.length) { await DatabaseManager.storeChunkInQdrant(dbConnection, chunk, embeddings[0], chunkHash); @@ -732,27 +737,27 @@ class Doc2Vec { }; logger.info(`Fetching Zendesk tickets updated since ${lastRunDate}`); - + // Build query parameters const statusFilter = config.ticket_status || ['new', 'open', 'pending', 'hold', 'solved']; const query = `updated>${lastRunDate.split('T')[0]} status:${statusFilter.join(',status:')}`; - + let nextPage = `${baseUrl}/search.json?query=${encodeURIComponent(query)}&sort_by=updated_at&sort_order=asc`; let totalTickets = 0; - + while (nextPage) { const data = await fetchWithRetry(nextPage); const tickets = data.results || []; - + logger.info(`Processing batch of ${tickets.length} tickets`); - + for (const ticket of tickets) { await processTicket(ticket); totalTickets++; } - + nextPage = data.next_page; - + if (nextPage) { logger.debug(`Fetching next page: ${nextPage}`); // Rate limiting: wait between requests @@ -762,18 +767,18 @@ class Doc2Vec { // Update the last run date in the database await DatabaseManager.updateLastRunDate(dbConnection, `zendesk_tickets_${config.zendesk_subdomain}`, logger); - + logger.info(`Successfully processed ${totalTickets} tickets`); } private async fetchAndProcessZendeskArticles(config: ZendeskSourceConfig, dbConnection: DatabaseConnection, logger: Logger): Promise { const baseUrl = `https://${config.zendesk_subdomain}.zendesk.com/api/v2/help_center`; const auth = Buffer.from(`${config.email}/token:${config.api_token}`).toString('base64'); - + // Get the start date for filtering const startDate = config.start_date || `${new Date().getFullYear()}-01-01`; const startDateObj = new Date(startDate); - + const fetchWithRetry = async (url: string, retries = 3): Promise => { for (let attempt = 0; attempt < retries; attempt++) { try { @@ -783,14 +788,14 @@ class Doc2Vec { 'Content-Type': 'application/json', }, }); - + if 
(response.status === 429) { const retryAfter = parseInt(response.headers['retry-after'] || '60'); logger.warn(`Rate limited, waiting ${retryAfter}s before retry`); await new Promise(res => setTimeout(res, retryAfter * 1000)); continue; } - + return response.data; } catch (error: any) { logger.error(`Zendesk API error (attempt ${attempt + 1}):`, error.message); @@ -808,11 +813,11 @@ class Doc2Vec { md += `- **Updated:** ${new Date(article.updated_at).toDateString()}\n`; md += `- **Vote Sum:** ${article.vote_sum || 0}\n`; md += `- **Vote Count:** ${article.vote_count || 0}\n`; - + if (article.label_names && article.label_names.length > 0) { md += `- **Labels:** ${article.label_names.map((label: string) => `\`${label}\``).join(', ')}\n`; } - + // Handle article content - convert HTML to markdown const articleBody = article.body || ''; let cleanContent = '_No content._'; @@ -825,7 +830,7 @@ class Doc2Vec { cleanContent = articleBody; } } - + md += `\n## Content\n\n${cleanContent}\n`; return md; @@ -834,31 +839,31 @@ class Doc2Vec { const processArticle = async (article: any): Promise => { const articleId = article.id; const url = article.html_url || `https://${config.zendesk_subdomain}.zendesk.com/hc/articles/${articleId}`; - + logger.info(`Processing article #${articleId}: ${article.title}`); - + // Generate markdown for the article const markdown = generateMarkdownForArticle(article); - + // Chunk the markdown content const articleConfig = { ...config, product_name: config.product_name || `zendesk_${config.zendesk_subdomain}`, max_size: config.max_size || Infinity }; - + const chunks = await this.contentProcessor.chunkMarkdown(markdown, articleConfig, url); logger.info(`Article #${articleId}: Created ${chunks.length} chunks`); - + // Process and store each chunk (similar to ticket processing) for (const chunk of chunks) { const chunkHash = Utils.generateHash(chunk.content); const chunkId = chunk.metadata.chunk_id.substring(0, 8) + '...'; - + if (dbConnection.type === 'sqlite') { const { checkHashStmt } = DatabaseManager.prepareSQLiteStatements(dbConnection.db); const existing = checkHashStmt.get(chunk.metadata.chunk_id) as { hash: string } | undefined; - + if (existing && existing.hash === chunkHash) { logger.info(`Skipping unchanged chunk: ${chunkId}`); continue; @@ -893,7 +898,7 @@ class Doc2Vec { logger.info(`Skipping unchanged chunk: ${chunkId}`); continue; } - + const embeddings = await this.createEmbeddings([chunk.content]); if (embeddings.length) { await DatabaseManager.storeChunkInQdrant(dbConnection, chunk, embeddings[0], chunkHash); @@ -909,20 +914,20 @@ class Doc2Vec { }; logger.info(`Fetching Zendesk help center articles updated since ${startDate}`); - + let nextPage = `${baseUrl}/articles.json`; let totalArticles = 0; let processedArticles = 0; - + while (nextPage) { const data = await fetchWithRetry(nextPage); const articles = data.articles || []; - + logger.info(`Processing batch of ${articles.length} articles`); - + for (const article of articles) { totalArticles++; - + // Check if article was updated since the start date const updatedAt = new Date(article.updated_at); if (updatedAt >= startDateObj) { @@ -932,33 +937,21 @@ class Doc2Vec { logger.debug(`Skipping article #${article.id} (updated ${article.updated_at}, before ${startDate})`); } } - + nextPage = data.next_page; - + if (nextPage) { logger.debug(`Fetching next page: ${nextPage}`); // Rate limiting: wait between requests await new Promise(res => setTimeout(res, 1000)); } } - + logger.info(`Successfully processed 
${processedArticles} of ${totalArticles} articles (filtered by date >= ${startDate})`); } private async createEmbeddings(texts: string[]): Promise { - const logger = this.logger.child('embeddings'); - try { - logger.debug(`Creating embeddings for ${texts.length} texts`); - const response = await this.openai.embeddings.create({ - model: "text-embedding-3-large", - input: texts, - }); - logger.debug(`Successfully created ${response.data.length} embeddings`); - return response.data.map(d => d.embedding); - } catch (error) { - logger.error('Failed to create embeddings:', error); - return []; - } + return await this.embeddingProvider.createEmbeddings(texts); } } @@ -970,4 +963,4 @@ if (require.main === module) { } const doc2Vec = new Doc2Vec(configPath); doc2Vec.run().catch(console.error); -} \ No newline at end of file +} diff --git a/embedding-factory.ts b/embedding-factory.ts new file mode 100644 index 0000000..9599e8e --- /dev/null +++ b/embedding-factory.ts @@ -0,0 +1,49 @@ +import { Logger } from './logger'; +import { + EmbeddingProvider, + EmbeddingConfig, + OpenAIEmbeddingProvider, + CustomEmbeddingProvider +} from './embedding-provider'; + +/** + * Factory class for creating embedding providers + */ +export class EmbeddingProviderFactory { + /** + * Creates an embedding provider based on environment variables + * @param logger Logger instance + * @returns Configured embedding provider + */ + static createProvider(logger: Logger): EmbeddingProvider { + const factoryLogger = logger.child('embedding-factory'); + + // Get provider from PROVIDER environment variable, default to openai + const provider = (process.env.PROVIDER || 'openai').toLowerCase(); + + factoryLogger.info(`Creating embedding provider: ${provider}`); + + switch (provider) { + case 'openai': + return new OpenAIEmbeddingProvider(logger); + + case 'custom': + const endpoint = process.env.CUSTOM_ENDPOINT; + if (!endpoint) { + throw new Error('CUSTOM_ENDPOINT environment variable is required when using custom provider'); + } + + // Validate endpoint URL format + try { + new URL(endpoint); + } catch (error) { + throw new Error(`Invalid custom embedding endpoint URL: ${endpoint}`); + } + + return new CustomEmbeddingProvider(endpoint, logger); + + default: + throw new Error(`Unknown embedding provider: ${provider}. 
Must be 'openai' or 'custom'`); + } + } +} diff --git a/embedding-provider.ts b/embedding-provider.ts new file mode 100644 index 0000000..851bc62 --- /dev/null +++ b/embedding-provider.ts @@ -0,0 +1,174 @@ +import axios from 'axios'; +import { OpenAI } from 'openai'; +import { Logger } from './logger'; + +/** + * Abstract interface for embedding providers + */ +export interface EmbeddingProvider { + createEmbeddings(texts: string[]): Promise; + getProviderName(): string; +} + +/** + * Configuration for embedding providers + */ +export interface EmbeddingConfig { + provider: 'openai' | 'custom'; + endpoint?: string; // For custom provider + model?: string; // Model to use + timeout?: number; // Timeout for custom provider +} + +/** + * OpenAI embedding provider implementation + */ +export class OpenAIEmbeddingProvider implements EmbeddingProvider { + private openai: OpenAI; + private logger: Logger; + private model: string; + + constructor(logger: Logger) { + this.logger = logger.child('openai-embeddings'); + + const apiKey = process.env.OPENAI_API_KEY; + if (!apiKey) { + throw new Error('OpenAI API key not found in environment variable: OPENAI_API_KEY'); + } + + // Use EMBEDDING_MODEL env var or default to text-embedding-3-large + this.model = process.env.EMBEDDING_MODEL || 'text-embedding-3-large'; + + this.openai = new OpenAI({ apiKey }); + this.logger.info(`Initialized OpenAI embedding provider with model: ${this.model}`); + } + + async createEmbeddings(texts: string[]): Promise { + const maxRetries = 3; + const baseDelay = 1000; // 1 second + + for (let attempt = 1; attempt <= maxRetries; attempt++) { + try { + this.logger.debug(`Creating embeddings for ${texts.length} texts (attempt ${attempt}/${maxRetries})`); + + const response = await this.openai.embeddings.create({ + model: this.model, + input: texts, + }); + + this.logger.debug(`Successfully created ${response.data.length} embeddings`); + return response.data.map(d => d.embedding); + + } catch (error: any) { + this.logger.warn(`OpenAI embedding attempt ${attempt} failed:`, error.message); + + if (attempt === maxRetries) { + this.logger.error('All OpenAI embedding attempts failed'); + throw error; + } + + // Exponential backoff + const delay = baseDelay * Math.pow(2, attempt - 1); + this.logger.debug(`Retrying in ${delay}ms...`); + await new Promise(resolve => setTimeout(resolve, delay)); + } + } + + return []; + } + + getProviderName(): string { + return 'openai'; + } +} + +/** + * Custom endpoint embedding provider implementation (OpenAI-compatible) + */ +export class CustomEmbeddingProvider implements EmbeddingProvider { + private endpoint: string; + private model: string; + private apiKey?: string; + private timeout: number; + private logger: Logger; + + constructor(endpoint: string, logger: Logger) { + this.logger = logger.child('custom-embeddings'); + + this.endpoint = endpoint; + this.timeout = 30000; // 30 seconds default + + // Use OPENAI_API_KEY for authentication (same as OpenAI provider) + this.apiKey = process.env.OPENAI_API_KEY; + if (!this.apiKey) { + throw new Error('OpenAI API key not found in environment variable: OPENAI_API_KEY'); + } + + // Use EMBEDDING_MODEL env var or default to text-embedding-ada-002 for custom + this.model = process.env.EMBEDDING_MODEL || 'text-embedding-3-large'; + + this.logger.info(`Initialized custom embedding provider: ${this.endpoint} with model: ${this.model}`); + } + + async createEmbeddings(texts: string[]): Promise { + const maxRetries = 3; + const baseDelay = 1000; // 1 second + 
+ for (let attempt = 1; attempt <= maxRetries; attempt++) { + try { + this.logger.debug(`Creating embeddings for ${texts.length} texts (attempt ${attempt}/${maxRetries})`); + + const headers: Record = { + 'Content-Type': 'application/json', + }; + + if (this.apiKey) { + headers['Authorization'] = `Bearer ${this.apiKey}`; + } + + const requestBody = { + model: this.model, + input: texts, + }; + + const response = await axios.post(this.endpoint, requestBody, { + headers, + timeout: this.timeout, + }); + + if (!response.data || !response.data.data) { + throw new Error('Invalid response format from custom embedding endpoint'); + } + + const embeddings = response.data.data.map((item: any) => { + if (!item.embedding || !Array.isArray(item.embedding)) { + throw new Error('Invalid embedding format in response'); + } + return item.embedding; + }); + + this.logger.debug(`Successfully created ${embeddings.length} embeddings`); + return embeddings; + + } catch (error: any) { + this.logger.warn(`Custom embedding attempt ${attempt} failed:`, error.message); + + if (attempt === maxRetries) { + this.logger.error('All custom embedding attempts failed'); + throw error; + } + + // Exponential backoff + const delay = baseDelay * Math.pow(2, attempt - 1); + this.logger.debug(`Retrying in ${delay}ms...`); + await new Promise(resolve => setTimeout(resolve, delay)); + } + } + + return []; + } + + getProviderName(): string { + return 'custom'; + } +} diff --git a/mcp/package-lock.json b/mcp/package-lock.json index 3f496f6..c63824c 100644 --- a/mcp/package-lock.json +++ b/mcp/package-lock.json @@ -1,6 +1,6 @@ { "name": "sqlite-vec-mcp-server", - "version": "1.0.0", + "version": "1.1.0", "lockfileVersion": 3, "requires": true, "packages": { @@ -13,6 +13,7 @@ "@azure/openai": "^2.0.0", "@google/generative-ai": "^0.24.1", "@modelcontextprotocol/sdk": "^1.12.1", + "axios": "^1.12.2", "better-sqlite3": "^11.8.1", "dotenv": "^16.4.7", "express": "^5.1.0", @@ -424,6 +425,17 @@ "resolved": "https://registry.npmjs.org/asynckit/-/asynckit-0.4.0.tgz", "integrity": "sha512-Oei9OH4tRh0YqU3GxhX79dM/mwVgvbZJaSNaRk+bshkj0S5cfHcgYakreBjrHwatXKbz+IoIdYLxrKim2MjW0Q==" }, + "node_modules/axios": { + "version": "1.12.2", + "resolved": "https://registry.npmjs.org/axios/-/axios-1.12.2.tgz", + "integrity": "sha512-vMJzPewAlRyOgxV2dU0Cuz2O8zzzx9VYtbJOaBgXFeLc4IV/Eg50n4LowmehOOR61S8ZMpc2K5Sa7g6A4jfkUw==", + "license": "MIT", + "dependencies": { + "follow-redirects": "^1.15.6", + "form-data": "^4.0.4", + "proxy-from-env": "^1.1.0" + } + }, "node_modules/base64-js": { "version": "1.5.1", "resolved": "https://registry.npmjs.org/base64-js/-/base64-js-1.5.1.tgz", @@ -972,14 +984,36 @@ "node": ">= 0.8" } }, + "node_modules/follow-redirects": { + "version": "1.15.11", + "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.11.tgz", + "integrity": "sha512-deG2P0JfjrTxl50XGCDyfI97ZGVCxIpfKYmfyrQ54n5FO/0gfIES8C/Psl6kWVDolizcaaxZJnTS0QSMxvnsBQ==", + "funding": [ + { + "type": "individual", + "url": "https://github.com/sponsors/RubenVerborgh" + } + ], + "license": "MIT", + "engines": { + "node": ">=4.0" + }, + "peerDependenciesMeta": { + "debug": { + "optional": true + } + } + }, "node_modules/form-data": { - "version": "4.0.2", - "resolved": "https://registry.npmjs.org/form-data/-/form-data-4.0.2.tgz", - "integrity": "sha512-hGfm/slu0ZabnNt4oaRZ6uREyfCj6P4fT/n6A1rGV+Z0VdGXjfOhVUpkn6qVQONHGIFwmveGXyDs75+nr6FM8w==", + "version": "4.0.4", + "resolved": "https://registry.npmjs.org/form-data/-/form-data-4.0.4.tgz", + "integrity": 
"sha512-KrGhL9Q4zjj0kiUt5OO4Mr/A/jlI2jDYs5eHBpYHPcBEVSiipAvn2Ko2HnPe20rmcuuvMHNdZFp+4IlGTMF0Ow==", + "license": "MIT", "dependencies": { "asynckit": "^0.4.0", "combined-stream": "^1.0.8", "es-set-tostringtag": "^2.1.0", + "hasown": "^2.0.2", "mime-types": "^2.1.12" }, "engines": { @@ -1671,6 +1705,12 @@ "node": ">= 0.10" } }, + "node_modules/proxy-from-env": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/proxy-from-env/-/proxy-from-env-1.1.0.tgz", + "integrity": "sha512-D+zkORCbA9f1tdWRK0RaCR3GPv50cMxcrz4X8k5LTSUD1Dkw47mKJEZQNunItRTkWwgtaUSo1RVFRIG9ZXiFYg==", + "license": "MIT" + }, "node_modules/pump": { "version": "3.0.2", "resolved": "https://registry.npmjs.org/pump/-/pump-3.0.2.tgz", diff --git a/mcp/package.json b/mcp/package.json index 55a8ad7..f02170d 100644 --- a/mcp/package.json +++ b/mcp/package.json @@ -1,6 +1,6 @@ { "name": "sqlite-vec-mcp-server", - "version": "1.0.0", + "version": "1.1.0", "description": "MCP Server for querying documentation with sqlite-vec", "main": "build/index.js", "type": "module", @@ -28,6 +28,7 @@ "@azure/openai": "^2.0.0", "@google/generative-ai": "^0.24.1", "@modelcontextprotocol/sdk": "^1.12.1", + "axios": "^1.12.2", "better-sqlite3": "^11.8.1", "dotenv": "^16.4.7", "express": "^5.1.0", diff --git a/mcp/src/index.ts b/mcp/src/index.ts index ecc12df..1a441a1 100644 --- a/mcp/src/index.ts +++ b/mcp/src/index.ts @@ -15,6 +15,7 @@ import * as sqliteVec from "sqlite-vec"; import Database, { Database as DatabaseType } from "better-sqlite3"; import { OpenAI } from 'openai'; import { GoogleGenerativeAI } from '@google/generative-ai'; +import axios from 'axios'; import path from 'path'; import { fileURLToPath } from 'url'; import fs from 'fs'; // Import fs for checking file existence @@ -26,7 +27,7 @@ const __dirname = path.dirname(__filename); // Provider configuration // Note: Anthropic does not provide an embeddings API, only text generation -// Supported providers: 'openai', 'azure', 'gemini' +// Supported providers: 'openai', 'azure', 'gemini', 'custom' const embeddingProvider = process.env.EMBEDDING_PROVIDER || 'openai'; // OpenAI configuration @@ -43,6 +44,9 @@ const azureDeploymentName = process.env.AZURE_OPENAI_DEPLOYMENT_NAME || 'text-em const geminiApiKey = process.env.GEMINI_API_KEY; const geminiModel = process.env.GEMINI_MODEL || 'gemini-embedding-001'; +// Custom endpoint configuration +const customEndpoint = process.env.CUSTOM_ENDPOINT; + const dbDir = process.env.SQLITE_DB_DIR || __dirname; // Default to current dir if not set if (!fs.existsSync(dbDir)) { @@ -71,8 +75,14 @@ if (strictMode) { process.exit(1); } break; + case 'custom': + if (!customEndpoint || !openAIApiKey) { + console.error("Error: CUSTOM_ENDPOINT and OPENAI_API_KEY environment variables are required for custom provider."); + process.exit(1); + } + break; default: - console.error(`Error: Unknown embedding provider '${embeddingProvider}'. Supported providers: openai, azure, gemini`); + console.error(`Error: Unknown embedding provider '${embeddingProvider}'. 
Supported providers: openai, azure, gemini, custom`); console.error("Note: Anthropic does not provide an embeddings API, only text generation models."); process.exit(1); } @@ -104,14 +114,14 @@ async function createEmbeddings(text: string): Promise { } return response.data[0].embedding; } - + case 'azure': { - const azure = new AzureOpenAI({ - apiKey: azureApiKey, - endpoint: azureEndpoint, - deployment: azureDeploymentName, - apiVersion: azureApiVersion, - }); + const azure = new AzureOpenAI({ + apiKey: azureApiKey, + endpoint: azureEndpoint, + deployment: azureDeploymentName, + apiVersion: azureApiVersion, + }); const response = await azure.embeddings.create({ model: azureDeploymentName, // Use deployment name for Azure @@ -122,7 +132,7 @@ async function createEmbeddings(text: string): Promise { } return response.data[0].embedding; } - + case 'gemini': { const genAI = new GoogleGenerativeAI(geminiApiKey!); const model = genAI.getGenerativeModel({ model: geminiModel }); @@ -132,8 +142,24 @@ async function createEmbeddings(text: string): Promise { } return result.embedding.values; } + + case 'custom': { + const response = await axios.post(`${customEndpoint}/embeddings`, { + model: openAIModel, + input: text, + }, { + headers: { + 'Authorization': `Bearer ${openAIApiKey}`, + 'Content-Type': 'application/json', + }, + }); + if (!response.data?.data?.[0]?.embedding) { + throw new Error("Failed to get embedding from custom endpoint response."); + } + return response.data.data[0].embedding; + } default: - throw new Error(`Unsupported embedding provider: ${embeddingProvider}. Supported providers: openai, azure, gemini`); + throw new Error(`Unsupported embedding provider: ${embeddingProvider}. Supported providers: openai, azure, gemini, custom`); } } catch (error) { @@ -161,30 +187,30 @@ function queryCollection(queryEmbedding: number[], filter: { product_name: strin distance FROM vec_items WHERE embedding MATCH @query_embedding`; - + if (filter.product_name) query += ` AND product_name = @product_name`; if (filter.version) query += ` AND version = @version`; - + query += ` ORDER BY distance LIMIT @top_k;`; - + const stmt = db.prepare(query); console.error(`[DB ${dbPath}] Query prepared. Executing...`); const startTime = Date.now(); const rows = stmt.all({ - query_embedding: new Float32Array(queryEmbedding), - product_name: filter.product_name, - version: filter.version, - top_k: topK, + query_embedding: new Float32Array(queryEmbedding), + product_name: filter.product_name, + version: filter.version, + top_k: topK, }); const duration = Date.now() - startTime; console.error(`[DB ${dbPath}] Query executed in ${duration}ms. Found ${rows.length} rows.`); - + rows.forEach((row: any) => { - delete row.embedding; + delete row.embedding; }) - + return rows as QueryResult[]; } catch (error) { console.error(`Error querying collection in ${dbPath}:`, error); @@ -224,9 +250,9 @@ const queryDocumentationToolHandler = async ({ queryText, productName, version, const results = await queryDocumentation(queryText, productName, version, limit); if (results.length === 0) { - return { - content: [{ type: "text" as const, text: `No relevant documentation found for "${queryText}" in product "${productName}" ${version ? `(version ${version})` : ''}.` }], - }; + return { + content: [{ type: "text" as const, text: `No relevant documentation found for "${queryText}" in product "${productName}" ${version ? 
`(version ${version})` : ''}.` }], + }; } const formattedResults = results.map((r, index) => @@ -270,12 +296,12 @@ server.tool( async function main() { const transport_type = process.env.TRANSPORT_TYPE || 'http'; let webserver: any = null; // Store server reference for proper shutdown - + // Common graceful shutdown handler const createGracefulShutdownHandler = (transportCleanup: () => Promise) => { return async (signal: string) => { console.error(`Received ${signal}, initiating graceful shutdown...`); - + const shutdownTimeout = parseInt(process.env.SHUTDOWN_TIMEOUT || '5000', 10); const forceExitTimeout = setTimeout(() => { console.error(`Shutdown timeout (${shutdownTimeout}ms) exceeded, force exiting...`); @@ -311,32 +337,32 @@ async function main() { } }; }; - + if (transport_type === 'stdio') { // Stdio transport for direct communication console.error("Starting MCP server with stdio transport..."); const transport = new StdioServerTransport(); await server.connect(transport); console.error("MCP server connected via stdio."); - + // Add shutdown handler for stdio transport const shutdownHandler = createGracefulShutdownHandler(async () => { console.error('Closing stdio transport...'); // StdioServerTransport doesn't have a close method, but we can clean up the connection // The transport will be cleaned up when the process exits }); - + process.on('SIGTERM', () => shutdownHandler('SIGTERM')); process.on('SIGINT', () => shutdownHandler('SIGINT')); - + } else if (transport_type === 'sse') { // SSE transport for backward compatibility console.error("Starting MCP server with SSE transport..."); - + const app = express(); - + // Storage for SSE transports by session ID - const sseTransports: {[sessionId: string]: SSEServerTransport} = {}; + const sseTransports: { [sessionId: string]: SSEServerTransport } = {}; app.get("/sse", async (_: Request, res: Response) => { console.error('Received SSE connection request'); @@ -370,18 +396,18 @@ async function main() { console.error(`MCP server is running on port ${PORT} with SSE transport`); console.error(`Connect to: http://localhost:${PORT}/sse`); }); - + webserver.keepAliveTimeout = 3000; - + // Keep the process alive webserver.on('error', (error: any) => { console.error('HTTP server error:', error); }); - + // Handle server shutdown with proper SIGTERM/SIGINT support const shutdownHandler = createGracefulShutdownHandler(async () => { console.error('Closing SSE transports...'); - + // Close all active SSE transports for (const [sessionId, transport] of Object.entries(sseTransports)) { try { @@ -393,19 +419,19 @@ async function main() { } } }); - + process.on('SIGTERM', () => shutdownHandler('SIGTERM')); process.on('SIGINT', () => shutdownHandler('SIGINT')); - + } else if (transport_type === 'http') { // Streamable HTTP transport for web-based communication console.error("Starting MCP server with HTTP transport..."); - + const app = express(); - + const transports: Map = new Map(); const servers: Map = new Map(); - + // Handle POST requests for MCP initialization and method calls app.post('/mcp', async (req: Request, res: Response) => { console.error('Received MCP POST request'); @@ -563,20 +589,20 @@ async function main() { app.get("/health", (_: Request, res: Response) => { res.status(200).send("OK"); }); - + const PORT = process.env.PORT || 3001; webserver = app.listen(PORT, () => { console.error(`MCP server is running on port ${PORT} with HTTP transport`); console.error(`Connect to: http://localhost:${PORT}/mcp`); }); - + 
     webserver.keepAliveTimeout = 3000;
-    
+
     // Keep the process alive
     webserver.on('error', (error: any) => {
       console.error('HTTP server error:', error);
     });
-    
+
     // Handle server shutdown with proper SIGTERM/SIGINT support and timeout
     const shutdownHandler = createGracefulShutdownHandler(async () => {
       console.error('Closing HTTP transports and servers...');
@@ -585,17 +611,17 @@ async function main() {
       const transportClosePromises = Array.from(transports.entries()).map(async ([sessionId, transport]) => {
         try {
           console.error(`Closing transport and server for session ${sessionId}`);
-          
+
           // Add timeout to individual transport close operations
           const closeTimeout = new Promise((_, reject) => {
             setTimeout(() => reject(new Error(`Transport close timeout for ${sessionId}`)), 2000);
           });
-          
+
           await Promise.race([
             transport.close(),
             closeTimeout
           ]);
-          
+
           transports.delete(sessionId);
           servers.delete(sessionId);
           console.error(`Transport and server closed for session ${sessionId}`);
@@ -611,10 +637,10 @@ async function main() {
       await Promise.allSettled(transportClosePromises);
       console.error('All transports and servers cleanup completed');
     });
-    
+
     process.on('SIGTERM', () => shutdownHandler('SIGTERM'));
     process.on('SIGINT', () => shutdownHandler('SIGINT'));
-    
+
   } else {
     console.error(`Unknown transport type: ${transport_type}. Use 'stdio', 'sse', or 'http'.`);
     process.exit(1);
diff --git a/package-lock.json b/package-lock.json
index 4afa738..d4235d7 100644
--- a/package-lock.json
+++ b/package-lock.json
@@ -1,12 +1,12 @@
 {
   "name": "doc2vec",
-  "version": "1.1.1",
+  "version": "1.4.0",
   "lockfileVersion": 3,
   "requires": true,
   "packages": {
     "": {
       "name": "doc2vec",
-      "version": "1.1.1",
+      "version": "1.4.0",
       "license": "ISC",
       "dependencies": {
         "@mozilla/readability": "^0.4.4",
diff --git a/package.json b/package.json
index 58cffbb..1e82c11 100644
--- a/package.json
+++ b/package.json
@@ -1,6 +1,6 @@
 {
   "name": "doc2vec",
-  "version": "1.3.0",
+  "version": "1.4.0",
   "type": "commonjs",
   "description": "",
   "main": "dist/doc2vec.js",
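
For reviewers who want to exercise the new provider wiring outside the full crawl pipeline, a small standalone script along these lines can be run with ts-node against the files added in this patch. This is only a sketch and is not part of the patch itself: it assumes the `Logger` constructor options object used in `doc2vec.ts` (including a `level` option and a `LogLevel.DEBUG` enum member), and it assumes `EMBEDDING_VECTOR_SIZE` is set to match the model actually served at `CUSTOM_ENDPOINT`.

```typescript
// verify-embeddings.ts — illustrative sketch only; not included in this patch.
import * as dotenv from 'dotenv';
import { Logger, LogLevel } from './logger';
import { EmbeddingProviderFactory } from './embedding-factory';

dotenv.config();

async function main(): Promise<void> {
  // With PROVIDER=custom and CUSTOM_ENDPOINT set (e.g. a local LiteLLM proxy),
  // the factory returns the CustomEmbeddingProvider; otherwise it falls back to OpenAI.
  // The Logger options below mirror those used in doc2vec.ts and are an assumption.
  const logger = new Logger('verify', { level: LogLevel.DEBUG, useColor: true, prettyPrint: true });
  const provider = EmbeddingProviderFactory.createProvider(logger);

  const [vector] = await provider.createEmbeddings(['hello embeddings']);
  console.log(`${provider.getProviderName()} returned a vector of length ${vector?.length}`);

  // The length printed here must equal EMBEDDING_VECTOR_SIZE, otherwise the
  // vec_items table or Qdrant collection created by DatabaseManager at init
  // time will reject inserts.
  const expected = parseInt(process.env.EMBEDDING_VECTOR_SIZE || '3072', 10);
  if (vector?.length !== expected) {
    throw new Error(`Embedding size ${vector?.length} does not match EMBEDDING_VECTOR_SIZE=${expected}`);
  }
}

main().catch((err) => {
  console.error(err);
  process.exit(1);
});
```

Running this once per environment also surfaces authentication problems early, since the custom provider reuses OPENAI_API_KEY as its bearer token before any long crawl starts.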