deerflow2/src/rag/vikingdb_knowledge_base.py

319 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import asyncio
import hashlib
import hmac
import json
import os
import urllib.parse
from datetime import datetime
from urllib.parse import urlparse
import requests
from src.rag.retriever import Chunk, Document, Resource, Retriever
class VikingDBKnowledgeBaseProvider(Retriever):
    """
    VikingDBKnowledgeBaseProvider is a provider that uses VikingDB Knowledge base API to retrieve documents.

    Every HTTP call is signed with an HMAC-SHA256 based scheme: a canonical
    request is hashed, combined with a date/region/service credential scope,
    and signed with a key derived from the secret key (see
    ``_create_signature``).

    Configuration is read from environment variables in ``__init__``.
    """

    api_url: str
    api_ak: str
    api_sk: str
    retrieval_size: int = 10
    region: str = "cn-north-1"
    service: str = "air"

    def __init__(self):
        """Load connection settings from environment variables.

        Reads ``VIKINGDB_KNOWLEDGE_BASE_API_URL``, ``..._API_AK`` and
        ``..._API_SK`` (all required), plus the optional
        ``..._RETRIEVAL_SIZE`` and ``..._REGION`` overrides.

        Raises:
            ValueError: If a required variable is unset or empty, or if the
                retrieval size is not a valid integer.
        """
        api_url = os.getenv("VIKINGDB_KNOWLEDGE_BASE_API_URL")
        if not api_url:
            raise ValueError("VIKINGDB_KNOWLEDGE_BASE_API_URL is not set")
        self.api_url = api_url

        api_ak = os.getenv("VIKINGDB_KNOWLEDGE_BASE_API_AK")
        if not api_ak:
            raise ValueError("VIKINGDB_KNOWLEDGE_BASE_API_AK is not set")
        self.api_ak = api_ak

        api_sk = os.getenv("VIKINGDB_KNOWLEDGE_BASE_API_SK")
        if not api_sk:
            raise ValueError("VIKINGDB_KNOWLEDGE_BASE_API_SK is not set")
        self.api_sk = api_sk

        retrieval_size = os.getenv("VIKINGDB_KNOWLEDGE_BASE_RETRIEVAL_SIZE")
        if retrieval_size:
            self.retrieval_size = int(retrieval_size)

        # The region may be overridden through the environment; it defaults
        # to the same value as the class attribute.
        region = os.getenv("VIKINGDB_KNOWLEDGE_BASE_REGION", "cn-north-1")
        self.region = region

    def _split_api_url(self) -> tuple[str, str]:
        """Return ``(scheme, host)`` for the configured ``api_url``.

        ``api_url`` may be configured either as a bare host or with an
        explicit ``http://`` / ``https://`` prefix. A bare host is served over
        ``https``. A trailing slash is dropped so that joining the host with
        an absolute path does not produce ``//``.
        """
        for scheme in ("https", "http"):
            prefix = f"{scheme}://"
            if self.api_url.startswith(prefix):
                return scheme, self.api_url[len(prefix):].rstrip("/")
        return "https", self.api_url.rstrip("/")

    def _hmac_sha256(self, key: bytes, content: str) -> bytes:
        """Return the raw HMAC-SHA256 digest of UTF-8 encoded *content*."""
        return hmac.new(key, content.encode("utf-8"), hashlib.sha256).digest()

    def _hash_sha256(self, data: bytes) -> bytes:
        """Return the raw SHA-256 digest of *data*."""
        return hashlib.sha256(data).digest()

    def _get_signed_key(
        self, secret_key: str, date: str, region: str, service: str
    ) -> bytes:
        """Derive the signing key by chaining HMACs.

        The chain is secret key -> date -> region -> service -> the literal
        ``"request"``; each step uses the previous digest as the HMAC key.
        """
        k_date = self._hmac_sha256(secret_key.encode("utf-8"), date)
        k_region = self._hmac_sha256(k_date, region)
        k_service = self._hmac_sha256(k_region, service)
        return self._hmac_sha256(k_service, "request")

    def _create_canonical_request(
        self, method: str, path: str, query_params: dict, headers: dict, payload: bytes
    ) -> tuple[str, str]:
        """Build the canonical request string that gets hashed and signed.

        Args:
            method: HTTP method; upper-cased in the output.
            path: Request path; an empty path is treated as ``"/"``.
            query_params: Query parameters; sorted by key and percent-encoded
                with no safe characters.
            headers: Headers to sign; names are lower-cased and sorted
                case-insensitively, values are stripped.
            payload: Raw request body, hashed with SHA-256.

        Returns:
            A ``(canonical_request, signed_headers)`` tuple, where
            ``signed_headers`` is the ``;``-joined list of lower-cased header
            names.
        """
        canonical_method = method.upper()
        canonical_uri = path if path else "/"

        if query_params:
            canonical_query_string = "&".join(
                f"{urllib.parse.quote(str(key), safe='')}"
                f"={urllib.parse.quote(str(query_params[key]), safe='')}"
                for key in sorted(query_params.keys())
            )
        else:
            canonical_query_string = ""

        canonical_headers_list = []
        signed_headers_list = []
        for header_name in sorted(headers.keys(), key=str.lower):
            header_name_lower = header_name.lower()
            header_value = str(headers[header_name]).strip()
            canonical_headers_list.append(f"{header_name_lower}:{header_value}")
            signed_headers_list.append(header_name_lower)

        # Every canonical header line, including the last one, ends with a
        # newline; joining the sections below therefore leaves an empty line
        # between the headers and the signed-header list.
        canonical_headers = "\n".join(canonical_headers_list) + "\n"
        signed_headers = ";".join(signed_headers_list)
        payload_hash = self._hash_sha256(payload).hex()

        canonical_request = "\n".join(
            [
                canonical_method,
                canonical_uri,
                canonical_query_string,
                canonical_headers,
                signed_headers,
                payload_hash,
            ]
        )
        return canonical_request, signed_headers

    def _create_signature(
        self, method: str, path: str, query_params: dict, headers: dict, payload: bytes
    ) -> dict:
        """Add the signing headers to *headers* and return it.

        Sets ``X-Date``, ``Host``, ``X-Content-Sha256`` and ``Content-Type``,
        signs the resulting canonical request, and stores the result in the
        ``Authorization`` header.

        Note:
            *headers* is mutated in place; the returned dict is the same
            object.
        """
        # Function-scope import: the module only imports ``datetime`` itself.
        from datetime import timezone

        # ``datetime.utcnow()`` is deprecated; an aware UTC datetime formats
        # to the same timestamp string.
        now = datetime.now(timezone.utc)
        date_stamp = now.strftime("%Y%m%dT%H%M%SZ")
        auth_date = date_stamp[:8]

        _, host = self._split_api_url()
        headers["X-Date"] = date_stamp
        headers["Host"] = host
        headers["X-Content-Sha256"] = self._hash_sha256(payload).hex()
        headers["Content-Type"] = "application/json"

        canonical_request, signed_headers = self._create_canonical_request(
            method, path, query_params, headers, payload
        )

        algorithm = "HMAC-SHA256"
        credential_scope = f"{auth_date}/{self.region}/{self.service}/request"
        canonical_request_hash = self._hash_sha256(
            canonical_request.encode("utf-8")
        ).hex()
        string_to_sign = "\n".join(
            [algorithm, date_stamp, credential_scope, canonical_request_hash]
        )

        signing_key = self._get_signed_key(
            self.api_sk, auth_date, self.region, self.service
        )
        signature = hmac.new(
            signing_key, string_to_sign.encode("utf-8"), hashlib.sha256
        ).hexdigest()

        headers["Authorization"] = (
            f"{algorithm} "
            f"Credential={self.api_ak}/{credential_scope}, "
            f"SignedHeaders={signed_headers}, "
            f"Signature={signature}"
        )
        return headers

    def _make_signed_request(
        self, method: str, path: str, params: dict | None = None, data: dict | None = None
    ):
        """Send a signed HTTP request and return the ``requests`` response.

        Args:
            method: HTTP method.
            path: Absolute request path appended to the API host.
            params: Optional query parameters.
            data: Optional JSON-serialisable body; ``None`` sends no body.

        Raises:
            ValueError: If sending the request fails; the original exception
                is chained as ``__cause__``.
        """
        payload = b"" if data is None else json.dumps(data).encode("utf-8")
        if params is None:
            params = {}

        # Use the same normalised host for the URL and the signed Host
        # header, so an ``api_url`` configured with a scheme does not end up
        # as ``https://https://...``.
        scheme, host = self._split_api_url()
        url = f"{scheme}://{host}{path}"

        signed_headers = self._create_signature(method, path, params, {}, payload)
        try:
            return requests.request(
                method=method,
                url=url,
                headers=signed_headers,
                params=params,
                data=payload if payload else None,
                timeout=30,
            )
        except Exception as e:
            raise ValueError(f"Request failed: {e}") from e

    def _extract_data(self, response, error_prefix: str) -> dict:
        """Decode an API response and return its ``data`` object.

        Args:
            response: Object returned by ``_make_signed_request``.
            error_prefix: Text placed before the server message when the
                response carries a non-zero ``code``.

        Returns:
            The ``data`` member of the response, or an empty dict when it is
            missing or null.

        Raises:
            ValueError: If the body is not valid JSON or ``code`` is not 0.
        """
        try:
            response_data = response.json()
        except ValueError as e:
            # Every JSON decode error that ``response.json()`` can raise is a
            # ValueError subclass, whichever JSON backend requests uses.
            raise ValueError(f"Failed to parse JSON response: {e}") from e

        if response_data.get("code") != 0:
            raise ValueError(f"{error_prefix}: {response_data.get('message')}")
        return response_data.get("data") or {}

    def query_relevant_documents(
        self, query: str, resources: list[Resource] | None = None
    ) -> list[Document]:
        """
        Query relevant documents from the knowledge base

        One search request is sent per resource. Returned chunks are grouped
        by ``doc_id`` across all resources; results without a ``doc_id`` are
        skipped.

        Args:
            query: Search text.
            resources: Resources to search; ``None`` or an empty list returns
                an empty result without calling the API.

        Raises:
            ValueError: If a request fails, a response is not valid JSON, or
                the API reports a non-zero ``code``.
        """
        if not resources:
            return []

        all_documents = {}
        path = "/api/knowledge/collection/search_knowledge"
        for resource in resources:
            resource_id, document_id = parse_uri(resource.uri)
            request_params = {
                "resource_id": resource_id,
                "query": query,
                "limit": self.retrieval_size,
                "dense_weight": 0.5,
                "pre_processing": {
                    "need_instruction": True,
                    "rewrite": False,
                    "return_token_usage": True,
                },
                "post_processing": {
                    "rerank_switch": True,
                    "chunk_diffusion_count": 0,
                    "chunk_group": True,
                    "get_attachment_link": True,
                },
            }
            if document_id:
                # A URI fragment restricts the search to a single document.
                doc_filter = {"op": "must", "field": "doc_id", "conds": [document_id]}
                request_params["query_param"] = {"doc_filter": doc_filter}

            # Use the signed request helper.
            response = self._make_signed_request(
                method="POST", path=path, data=request_params
            )
            rsp_data = self._extract_data(
                response, "Failed to query documents from resource"
            )

            for item in rsp_data.get("result_list") or []:
                doc_info = item.get("doc_info", {})
                doc_id = doc_info.get("doc_id")
                if not doc_id:
                    continue
                if doc_id not in all_documents:
                    all_documents[doc_id] = Document(
                        id=doc_id, title=doc_info.get("doc_name"), chunks=[]
                    )
                chunk = Chunk(
                    content=item.get("content", ""), similarity=item.get("score", 0.0)
                )
                all_documents[doc_id].chunks.append(chunk)

        return list(all_documents.values())

    async def query_relevant_documents_async(
        self, query: str, resources: list[Resource] | None = None
    ) -> list[Document]:
        """
        Asynchronous version of query_relevant_documents.
        Wraps the synchronous implementation in asyncio.to_thread() to avoid blocking the event loop.
        """
        return await asyncio.to_thread(
            self.query_relevant_documents, query, resources
        )

    def list_resources(self, query: str | None = None) -> list[Resource]:
        """
        List resources (knowledge bases) from the knowledge base service

        Args:
            query: Optional case-insensitive substring; only collections
                whose name contains it are returned.

        Raises:
            ValueError: If the request fails, the response is not valid JSON,
                or the API reports a non-zero ``code``.
        """
        path = "/api/knowledge/collection/list"
        response = self._make_signed_request(method="POST", path=path)
        rsp_data = self._extract_data(response, "Failed to list resources")

        needle = query.lower() if query else None
        resources = []
        for item in rsp_data.get("collection_list") or []:
            collection_name = item.get("collection_name", "")
            description = item.get("description", "")
            if needle and needle not in collection_name.lower():
                continue
            resource_id = item.get("resource_id", "")
            resources.append(
                Resource(
                    uri=f"rag://dataset/{resource_id}",
                    title=collection_name,
                    description=description,
                )
            )
        return resources

    async def list_resources_async(self, query: str | None = None) -> list[Resource]:
        """
        Asynchronous version of list_resources.
        Wraps the synchronous implementation in asyncio.to_thread() to avoid blocking the event loop.
        """
        return await asyncio.to_thread(self.list_resources, query)
def parse_uri(uri: str) -> tuple[str, str]:
    """Split a ``rag://`` resource URI into its resource and document ids.

    The resource id is the first path segment and the document id is the URI
    fragment, so ``rag://dataset/abc#doc1`` yields ``("abc", "doc1")``. The
    document id is an empty string when the URI has no fragment.

    Raises:
        ValueError: If the scheme is not ``rag`` or the URI has no path
            segment to take the resource id from.
    """
    parsed = urlparse(uri)
    if parsed.scheme != "rag":
        raise ValueError(f"Invalid URI: {uri}")

    segments = parsed.path.split("/")
    # An empty path (e.g. "rag://dataset") splits into a single element, so
    # there is no segment at index 1; report that as an invalid URI instead
    # of letting an IndexError escape.
    if len(segments) < 2:
        raise ValueError(f"Invalid URI: {uri}")
    return segments[1], parsed.fragment