ai-chat-ui/server/adapters/base.py

136 lines
3.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
适配器基类定义
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union
@dataclass
class ModelInfo:
    """Description of a single model offered by a platform adapter."""

    id: str
    name: str
    description: str
    max_tokens: int = 4096
    provider: str = "unknown"
    # Capability flags
    supports_thinking: bool = False  # supports deep thinking
    supports_web_search: bool = False  # supports online (web) search
    supports_vision: bool = False  # supports image understanding
    supports_files: bool = False  # supports file attachments (PDF, DOCX, etc.)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize this model description to a plain dict.

        Note the mixed key styles: ``maxTokens`` is camelCase while the
        capability flags are snake_case; both are kept as-is for existing
        consumers.
        """
        return {
            "id": self.id,
            "name": self.name,
            "description": self.description,
            "maxTokens": self.max_tokens,
            "provider": self.provider,
            "supports_thinking": self.supports_thinking,
            "supports_web_search": self.supports_web_search,
            # Deprecated: the key used to be emitted with a typo'd capital "S".
            # Still sent so clients written against the old key keep working.
            "supports_web_Search": self.supports_web_search,
            "supports_vision": self.supports_vision,
            "supports_files": self.supports_files,
        }
@dataclass
class ChatCompletionRequest:
    """Chat request in OpenAI format.

    Besides the standard OpenAI fields, the client-side flags ``deepSearch``,
    ``webSearch`` and ``deepThinking`` are mapped to snake_case attributes.
    """

    model: str
    messages: List[Dict[str, Any]]
    stream: bool = True
    temperature: float = 0.7
    max_tokens: int = 2000
    files: Optional[List[str]] = None
    deep_search: bool = False
    web_search: bool = False
    deep_thinking: bool = False
    # Keys of the raw request body that are not mapped to any attribute above.
    extra: Dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ChatCompletionRequest":
        """Build a request object from a decoded request body.

        Args:
            data: The raw request body as a dict.

        Returns:
            A new ``ChatCompletionRequest``. Missing keys fall back to the
            defaults below (``model`` defaults to ``"glm-4-flash"``). Every key
            not consumed into a dedicated attribute is preserved in ``extra``.
        """
        # Keys consumed into dedicated attributes. "maxTokens" is accepted as an
        # alias for "max_tokens" below, so it must be listed here too; otherwise
        # the value would also be duplicated into ``extra``.
        known_fields = frozenset(
            {
                "model",
                "messages",
                "stream",
                "temperature",
                "max_tokens",
                "maxTokens",
                "files",
                "deepSearch",
                "webSearch",
                "deepThinking",
            }
        )
        extra = {k: v for k, v in data.items() if k not in known_fields}
        return cls(
            model=data.get("model", "glm-4-flash"),
            messages=data.get("messages", []),
            stream=data.get("stream", True),
            temperature=data.get("temperature", 0.7),
            # The snake_case key wins when both spellings are present.
            max_tokens=data.get("max_tokens", data.get("maxTokens", 2000)),
            files=data.get("files"),
            deep_search=data.get("deepSearch", False),
            web_search=data.get("webSearch", False),
            deep_thinking=data.get("deepThinking", False),
            extra=extra,
        )
class BaseAdapter(ABC):
    """Base class for LLM platform adapters.

    Every platform adapter must subclass this and implement
    ``provider_name``, ``chat`` and ``list_models``.
    """

    @property
    @abstractmethod
    def provider_name(self) -> str:
        """Return the platform name (e.g. 'glm', 'dashscope', 'openai')."""
        pass

    @abstractmethod
    async def chat(self, request: ChatCompletionRequest):
        """Handle a chat request.

        Args:
            request: The chat request in OpenAI format.

        Returns:
            A StreamingResponse for streaming requests; otherwise a
            JSONResponse or a plain dict.
        """
        pass

    @abstractmethod
    def list_models(self) -> List[ModelInfo]:
        """Return the models supported by this platform.

        Returns:
            A list of ``ModelInfo`` objects.
        """
        pass

    def is_available(self) -> bool:
        """Report whether this adapter can currently be used.

        The base implementation performs no check and always returns True.
        Subclasses should override it to verify their own prerequisites,
        e.g. that the platform's API key is configured.
        """
        return True

    def get_models_response(self) -> Dict[str, Any]:
        """Return this adapter's models as an OpenAI-style list response.

        Each entry in ``data`` is produced by ``ModelInfo.to_dict``.
        """
        models = self.list_models()
        return {
            "object": "list",
            "data": [m.to_dict() for m in models],
        }