paper-burner/js/chatbot/agents/embedding-client.js

313 lines
11 KiB
JavaScript
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// js/chatbot/agents/embedding-client.js
// 统一的 Embedding API 客户端,支持 OpenAI 格式的各种服务
(function(window) {
'use strict';
/**
 * Client for embedding APIs that speak the OpenAI request/response format.
 *
 * Services listed by the original author as supported:
 * - OpenAI: text-embedding-3-small, text-embedding-3-large
 * - BGE-M3: BAAI/bge-m3 (through a compatible interface)
 * - Jina AI: jina-embeddings-v2-base-zh (multilingual)
 * - Self-hosted: any OpenAI-compatible service
 *
 * Construction reads the configuration through loadConfig() and starts
 * with an empty in-memory cache.
 */
function EmbeddingClient() {
  const initialConfig = this.loadConfig();
  const memoryCache = new Map(); // lives only as long as this instance

  this.config = initialConfig;
  this.cache = memoryCache;
}
/**
 * Simple delay helper.
 * @param {number} ms - Timeout passed to setTimeout.
 * @returns {Promise<void>} Promise resolved by the timer callback.
 */
EmbeddingClient.prototype._delay = function(ms) {
  return new Promise((done) => {
    setTimeout(done, ms);
  });
};
/**
 * Decide whether an HTTP status is worth retrying.
 * Retried statuses: 401, 403, 408, 429 and every 5xx.
 * @param {number} status - HTTP response status code.
 * @returns {boolean} true when the request should be attempted again.
 */
EmbeddingClient.prototype._shouldRetry = function(status) {
  switch (status) {
    case 401:
    case 403:
    case 408:
    case 429:
      return true;
    default:
      // Anything in the 500-599 range counts as a server-side failure.
      return status >= 500 && status <= 599;
  }
};
/**
 * fetch() wrapper that retries with exponential backoff plus random jitter.
 *
 * A response is retried when it is not ok and _shouldRetry(status) is true;
 * a rejected fetch (network failure) is retried as well. A rejection whose
 * name is "AbortError" is rethrown immediately: the caller cancelled the
 * request, so trying again would ignore that cancellation.
 *
 * @param {string} url - Request URL handed to fetch().
 * @param {Object} [options] - fetch() init object.
 * @param {Object} [retryOpts]
 * @param {number} [retryOpts.maxRetries=3] - Extra attempts after the first one.
 * @param {number} [retryOpts.baseDelay=500] - Delay before the first retry, in ms.
 * @param {number} [retryOpts.maxDelay=4000] - Cap on the exponential part, in ms.
 * @returns {Promise<Response>} The first ok response, or the last non-ok
 *   response (non-retryable status or retries exhausted) for the caller to
 *   inspect.
 * @throws The fetch rejection once retries are exhausted, or at once for an
 *   AbortError.
 */
EmbeddingClient.prototype._fetchWithRetry = async function(url, options = {}, retryOpts = {}) {
  const {
    maxRetries = 3,
    baseDelay = 500,
    maxDelay = 4000,
  } = retryOpts;

  // Capped exponential delay; up to 250 ms of jitter is added after the cap.
  const backoffFor = (attempt) => {
    const jitter = Math.floor(Math.random() * 250);
    return Math.min(maxDelay, baseDelay * Math.pow(2, attempt)) + jitter;
  };

  let lastError = null;
  for (let attempt = 0; attempt <= maxRetries; attempt++) {
    let res;
    try {
      res = await fetch(url, options);
    } catch (err) {
      lastError = err;
      const aborted = Boolean(err) && err.name === 'AbortError';
      if (aborted || attempt === maxRetries) throw err;
      await this._delay(backoffFor(attempt));
      continue;
    }

    if (res.ok) return res;
    if (!this._shouldRetry(res.status) || attempt === maxRetries) {
      return res; // let the caller parse the body / raise the error
    }
    await this._delay(backoffFor(attempt));
  }

  // Only reachable when maxRetries is negative, NaN or not an integer, so the
  // loop ended without hitting the "attempt === maxRetries" exits above.
  if (lastError) throw lastError;
  return fetch(url, options);
};
/**
 * Load the embedding configuration from localStorage key "embeddingConfig".
 *
 * The stored value is returned as-is when it parses to a plain object. If
 * nothing is stored, storage access or parsing fails, or the parsed value is
 * not an object (null, array, number, string), the built-in defaults are
 * returned instead so that later property reads on the config cannot crash.
 *
 * @returns {Object} The stored configuration or the default configuration.
 */
EmbeddingClient.prototype.loadConfig = function() {
  try {
    const saved = localStorage.getItem('embeddingConfig');
    if (saved) {
      const parsed = JSON.parse(saved);
      const isPlainObject = parsed !== null && typeof parsed === 'object' && !Array.isArray(parsed);
      if (isPlainObject) return parsed;
      console.warn('[EmbeddingClient] 加载配置失败:', new TypeError('stored embeddingConfig is not an object'));
    }
  } catch (e) {
    console.warn('[EmbeddingClient] 加载配置失败:', e);
  }
  return {
    provider: 'openai', // openai | jina | custom
    apiKey: '',
    endpoint: 'https://api.openai.com/v1/embeddings',
    model: 'text-embedding-3-small',
    dimensions: 1536,
    maxBatchSize: 2048,
    concurrency: 5,
    enabled: false
  };
};
/**
 * Merge the given fields into the current configuration and persist it.
 *
 * The merged result becomes this.config as a fresh object, then is written
 * to localStorage under "embeddingConfig". A storage failure is logged and
 * not rethrown, so the in-memory configuration is updated either way.
 *
 * @param {Object} config - Fields that override the current configuration.
 */
EmbeddingClient.prototype.saveConfig = function(config) {
  const merged = Object.assign({}, this.config, config);
  this.config = merged;

  try {
    const serialized = JSON.stringify(merged);
    localStorage.setItem('embeddingConfig', serialized);
  } catch (storageError) {
    console.error('[EmbeddingClient] 保存配置失败:', storageError);
  }
};
/**
 * Get the vector representation of one text or of an array of texts.
 *
 * Texts whose cache key is already in the in-memory cache are served from
 * it; only the remaining texts are sent to the configured endpoint in one
 * request. Returned vectors are stored in the cache before returning.
 *
 * @param {string|string[]} input - A single text or an array of texts.
 * @returns {Promise<number[]|number[][]>} One vector for a single text, or
 *   an array of vectors in the same order as the input array.
 * @throws {Error} When the client is disabled or has no API key, when the
 *   HTTP response is not ok, or when the response body does not contain one
 *   embedding per requested text. Errors raised during the API call carry
 *   `status` and `retryable` (both undefined for non-HTTP failures) and the
 *   original error in `cause`.
 */
EmbeddingClient.prototype.embed = async function(input) {
  if (!this.config.enabled || !this.config.apiKey) {
    throw new Error('Embedding API 未配置或未启用');
  }
  const isBatch = Array.isArray(input);
  const texts = isBatch ? input : [input];

  // Split the input into cache hits (placed at their final position) and
  // misses (remembering which position each miss belongs to).
  const cachedResults = [];
  const uncachedTexts = [];
  const uncachedIndices = [];
  texts.forEach((text, idx) => {
    const cacheKey = this.getCacheKey(text);
    if (this.cache.has(cacheKey)) {
      cachedResults[idx] = this.cache.get(cacheKey);
    } else {
      uncachedTexts.push(text);
      uncachedIndices.push(idx);
    }
  });

  // Everything was cached: no request needed.
  if (uncachedTexts.length === 0) {
    return isBatch ? cachedResults : cachedResults[0];
  }

  const requestBody = {
    model: this.config.model,
    input: uncachedTexts
  };

  // Provider-specific request parameters.
  const provider = this.config.provider || 'openai';
  if (provider === 'openai') {
    requestBody.encoding_format = 'float';
    // dimensions is only sent when a value below 1536 is configured.
    if (this.config.dimensions && this.config.dimensions < 1536) {
      requestBody.dimensions = this.config.dimensions;
    }
  } else if (provider === 'alibaba') {
    if (this.config.dimensions) {
      requestBody.dimensions = this.config.dimensions;
    }
  }
  // Any other provider gets no extra parameters.

  try {
    const response = await this._fetchWithRetry(this.config.endpoint, {
      method: 'POST',
      headers: {
        'Content-Type': 'application/json',
        'Authorization': `Bearer ${this.config.apiKey}`
      },
      body: JSON.stringify(requestBody)
    }, { maxRetries: 3, baseDelay: 600, maxDelay: 5000 });

    if (!response.ok) {
      const errText = await response.text();
      const err = new Error(`Embedding API 错误 (${response.status}): ${errText}`);
      err.status = response.status;
      // Auth failures (401/403) are reported as non-retryable to callers,
      // even though _shouldRetry() lets the transport layer retry them.
      const isAuthFailure = response.status === 401 || response.status === 403;
      err.retryable = !isAuthFailure && this._shouldRetry(response.status);
      throw err;
    }

    const data = await response.json();
    // Expected OpenAI-format body: { data: [{ embedding: [...] }], usage: {...} }
    const items = data && data.data;
    const isValid = Array.isArray(items)
      && items.length === uncachedTexts.length
      && items.every(item => item && Array.isArray(item.embedding));
    if (!isValid) {
      // Refuse to cache anything from a short or malformed response; a cached
      // undefined would otherwise be served for that text from then on.
      const received = Array.isArray(items) ? items.length : 0;
      throw new Error(`Embedding API 响应无效: 期望 ${uncachedTexts.length} 个向量, 实际 ${received} 个`);
    }

    // When every item carries an integer `index`, order by it so each vector
    // is matched to the text it was computed for.
    const hasIndex = items.every(item => Number.isInteger(item.index));
    const ordered = hasIndex ? items.slice().sort((a, b) => a.index - b.index) : items;
    const embeddings = ordered.map(item => item.embedding);

    // Cache the new vectors and slot them into their original positions.
    uncachedTexts.forEach((text, idx) => {
      const cacheKey = this.getCacheKey(text);
      this.cache.set(cacheKey, embeddings[idx]);
      cachedResults[uncachedIndices[idx]] = embeddings[idx];
    });
    console.log(`[EmbeddingClient] 成功生成 ${embeddings.length} 个向量使用token: ${(data && data.usage && data.usage.total_tokens) || '未知'}`);
    return isBatch ? cachedResults : cachedResults[0];
  } catch (error) {
    console.error('[EmbeddingClient] 调用API失败:', error);
    const e = new Error((error && error.message) || 'Embedding 调用失败');
    e.status = error && error.status;
    e.retryable = error && error.retryable;
    e.cause = error; // keep the original error and its stack reachable
    throw e;
  }
};
/**
 * Build the in-memory cache key for a text under the configured model.
 *
 * The key contains the full text rather than a short hash of it. The
 * previous 32-bit rolling hash mapped distinct texts to the same key (for
 * example "Aa" and "BB"), which made the cache return another text's vector.
 * JSON-encoding the [model, text] pair keeps the two parts unambiguous.
 *
 * @param {string} text - Text to be embedded.
 * @returns {string} Key that is equal only for the same model and same text.
 */
EmbeddingClient.prototype.getCacheKey = function(text) {
  return JSON.stringify([this.config.model, text]);
};
/**
 * Embed many texts: split them into batches, then process the batches with
 * a bounded number of concurrent workers.
 *
 * Batches are cut by a rough size estimate (text length * 1.5) against
 * config.maxBatchSize. A batch that fails is represented by null entries.
 * A 401/403 failure additionally stops any further batches from starting;
 * batches that were never started are null-filled as well, so the returned
 * array always has one entry per input text, in input order.
 *
 * @param {string[]} texts - Texts to embed.
 * @param {Object} [options]
 * @param {Function} [options.onProgress] - Called as (current, total, message).
 * @returns {Promise<Array<number[]|null>>} One vector (or null) per text.
 */
EmbeddingClient.prototype.batchEmbed = async function(texts, options = {}) {
  const onProgress = (options && typeof options.onProgress === 'function') ? options.onProgress : null;

  // Greedy batching on the estimated size of each text.
  const batches = [];
  let currentBatch = [];
  let currentTokens = 0;
  for (const text of texts) {
    const estimatedTokens = Math.ceil(text.length * 1.5);
    if (currentTokens + estimatedTokens > this.config.maxBatchSize && currentBatch.length > 0) {
      batches.push(currentBatch);
      currentBatch = [text];
      currentTokens = estimatedTokens;
    } else {
      currentBatch.push(text);
      currentTokens += estimatedTokens;
    }
  }
  if (currentBatch.length > 0) {
    batches.push(currentBatch);
  }

  const concurrency = Math.max(1, Math.min(this.config.concurrency || 5, 50));
  console.log(`[EmbeddingClient] 分为 ${batches.length} 批次,并发数: ${concurrency}`);

  const results = new Array(batches.length);
  let nextIndex = 0;
  let completedCount = 0;
  let abortAll = false; // set on a hard error (401/403) to stop scheduling

  // Each worker keeps claiming the next unprocessed batch until none is left
  // or a hard error has been seen. It never rejects because of embed().
  const runWorker = async () => {
    while (!abortAll) {
      const i = nextIndex++;
      if (i >= batches.length) return;
      console.log(`[EmbeddingClient] 处理批次 ${i + 1}/${batches.length}`);
      try {
        results[i] = await this.embed(batches[i]);
      } catch (err) {
        // A failed batch keeps its slots, filled with null.
        results[i] = new Array(batches[i].length).fill(null);
        if (err && (err.status === 401 || err.status === 403)) {
          // Auth failure: stop starting new batches, let in-flight ones finish.
          abortAll = true;
          if (onProgress) {
            onProgress(completedCount, batches.length, `鉴权失败 (${err.status}),停止新的批次,保留部分结果`);
          }
        } else {
          // embed() already retried at the transport level; skip this batch.
          console.warn('[EmbeddingClient] 批次失败,已跳过:', (err && err.message) || err);
        }
      }
      completedCount++;
      if (onProgress) {
        onProgress(completedCount, batches.length, `正在生成向量 ${completedCount}/${batches.length}`);
      }
    }
  };

  const workers = [];
  for (let w = 0; w < Math.min(concurrency, batches.length); w++) {
    workers.push(runWorker());
  }
  // Wait for every worker before surfacing a failure (e.g. a throwing
  // onProgress callback), so no worker is left running unobserved.
  const settled = await Promise.allSettled(workers);
  const rejected = settled.find(r => r.status === 'rejected');
  if (rejected) {
    throw rejected.reason;
  }

  // Flatten per-batch results; batches without a result become nulls.
  const allEmbeddings = [];
  batches.forEach((batch, i) => {
    const batchResult = results[i];
    if (Array.isArray(batchResult)) {
      for (const vector of batchResult) allEmbeddings.push(vector);
    } else {
      for (let k = 0; k < batch.length; k++) allEmbeddings.push(null);
    }
  });
  return allEmbeddings;
};

/** Remove every entry from the in-memory cache. */
EmbeddingClient.prototype.clearCache = function() {
  this.cache.clear();
  console.log('[EmbeddingClient] 缓存已清空');
};

// Export a shared instance on the global object.
window.EmbeddingClient = new EmbeddingClient();
console.log('[EmbeddingClient] Embedding客户端已加载');
})(window);