136 lines
3.7 KiB
Python
136 lines
3.7 KiB
Python
"""
Adapter base class definitions.
"""
|
||
|
||
from abc import ABC, abstractmethod
|
||
from dataclasses import dataclass, field
|
||
from typing import Any, Dict, List, Optional, Union
|
||
|
||
|
||
@dataclass
class ModelInfo:
    """Metadata describing a single model exposed by a platform adapter."""

    id: str
    name: str
    description: str
    max_tokens: int = 4096
    provider: str = "unknown"
    # Capability flags
    supports_thinking: bool = False    # extended/deep thinking
    supports_web_search: bool = False  # online web search
    supports_vision: bool = False      # image understanding
    supports_files: bool = False       # file attachments (PDF, DOCX, ...)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to the wire format used by the model-list response.

        NOTE(review): key casing is intentionally mixed — "maxTokens" is
        camelCase while the capability flags stay snake_case. Preserved
        as-is since consumers presumably depend on these exact keys.
        """
        payload: Dict[str, Any] = {
            "id": self.id,
            "name": self.name,
            "description": self.description,
            "maxTokens": self.max_tokens,
            "provider": self.provider,
        }
        for flag in (
            "supports_thinking",
            "supports_web_search",
            "supports_vision",
            "supports_files",
        ):
            payload[flag] = getattr(self, flag)
        return payload
|
||
|
||
|
||
@dataclass
class ChatCompletionRequest:
    """OpenAI-style chat completion request.

    Attributes mirror the OpenAI chat-completions body; the camelCase
    booleans sent by clients (deepSearch/webSearch/deepThinking) are
    normalized to snake_case fields. Any keys not mapped to a field are
    retained verbatim in ``extra``.
    """

    model: str
    messages: List[Dict[str, Any]]
    stream: bool = True
    temperature: float = 0.7
    max_tokens: int = 2000
    files: Optional[List[str]] = None
    deep_search: bool = False
    web_search: bool = False
    deep_thinking: bool = False
    # Raw request body leftovers (fields not mapped above)
    extra: Dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ChatCompletionRequest":
        """Build a request object from a decoded JSON body.

        Both ``max_tokens`` and its camelCase alias ``maxTokens`` are
        accepted (``max_tokens`` wins when both are present). Unknown
        keys are preserved in ``extra``.
        """
        known_fields = {
            "model",
            "messages",
            "stream",
            "temperature",
            "max_tokens",
            # BUGFIX: "maxTokens" is consumed as a fallback below, so it is
            # a known field — previously it also leaked into `extra`.
            "maxTokens",
            "files",
            "deepSearch",
            "webSearch",
            "deepThinking",
        }
        extra = {k: v for k, v in data.items() if k not in known_fields}

        return cls(
            model=data.get("model", "glm-4-flash"),
            messages=data.get("messages", []),
            stream=data.get("stream", True),
            temperature=data.get("temperature", 0.7),
            max_tokens=data.get("max_tokens", data.get("maxTokens", 2000)),
            files=data.get("files"),
            deep_search=data.get("deepSearch", False),
            web_search=data.get("webSearch", False),
            deep_thinking=data.get("deepThinking", False),
            extra=extra,
        )
|
||
|
||
|
||
class BaseAdapter(ABC):
    """Base class for LLM platform adapters.

    Every platform adapter must subclass this and implement the
    abstract members below.
    """

    @property
    @abstractmethod
    def provider_name(self) -> str:
        """Platform name, e.g. 'glm', 'dashscope', 'openai'."""

    @abstractmethod
    async def chat(self, request: ChatCompletionRequest):
        """Handle a chat completion request.

        Args:
            request: the OpenAI-format chat request.

        Returns:
            A StreamingResponse for streaming requests; otherwise a
            JSONResponse or a plain dict.
        """

    @abstractmethod
    def list_models(self) -> List[ModelInfo]:
        """Return the models this platform supports.

        Returns:
            A list of ModelInfo objects.
        """

    def is_available(self) -> bool:
        """Report whether this adapter is usable (e.g. API key configured).

        The base implementation always reports available; subclasses are
        expected to override this with a real configuration check.
        """
        return True

    def get_models_response(self) -> Dict[str, Any]:
        """Build the OpenAI-format model-list response."""
        entries = [model.to_dict() for model in self.list_models()]
        return {"object": "list", "data": entries}
|