121 lines
2.8 KiB
Python
121 lines
2.8 KiB
Python
"""
|
||
适配器注册表
|
||
根据模型名称路由到对应的平台适配器
|
||
"""
|
||
|
||
import os
|
||
from typing import Dict, Optional, Type
|
||
|
||
from .base import BaseAdapter
|
||
|
||
# 模型前缀到平台名称的映射
|
||
MODEL_PREFIX_MAP = {
|
||
# 智谱 GLM
|
||
"glm-": "glm",
|
||
# 阿里云百炼(Qwen 系列)
|
||
"qwen-": "dashscope",
|
||
# OpenAI
|
||
"gpt-": "openai",
|
||
"o1-": "openai",
|
||
"o3-": "openai",
|
||
# Deepseek
|
||
"deepseek-": "deepseek",
|
||
}
|
||
|
||
# 已注册的适配器实例
|
||
_adapters: Dict[str, BaseAdapter] = {}
|
||
|
||
# 已注册的适配器类
|
||
_adapter_classes: Dict[str, Type[BaseAdapter]] = {}
|
||
|
||
|
||
def register_adapter(name: str, adapter_class: Type[BaseAdapter]):
|
||
"""
|
||
注册适配器类
|
||
|
||
Args:
|
||
name: 平台名称(如 'glm', 'dashscope', 'openai')
|
||
adapter_class: 适配器类
|
||
"""
|
||
_adapter_classes[name] = adapter_class
|
||
|
||
|
||
def get_provider_from_model(model: str) -> str:
|
||
"""
|
||
根据模型名称判断所属平台
|
||
|
||
Args:
|
||
model: 模型 ID(如 'glm-4-flash', 'qwen-turbo', 'gpt-4')
|
||
|
||
Returns:
|
||
平台名称(如 'glm', 'dashscope', 'openai')
|
||
"""
|
||
model_lower = model.lower()
|
||
|
||
# 优先精确匹配
|
||
exact_matches = {
|
||
# GLM 精确模型名
|
||
"glm-4": "glm",
|
||
"glm-4v": "glm",
|
||
# Deepseek
|
||
"deepseek-chat": "deepseek",
|
||
"deepseek-reasoner": "deepseek",
|
||
}
|
||
if model_lower in exact_matches:
|
||
return exact_matches[model_lower]
|
||
|
||
# 前缀匹配
|
||
for prefix, provider in MODEL_PREFIX_MAP.items():
|
||
if model_lower.startswith(prefix):
|
||
return provider
|
||
|
||
# 默认使用环境变量或 GLM
|
||
return os.getenv("DEFAULT_PROVIDER", "glm")
|
||
|
||
|
||
def get_adapter(provider: str) -> Optional[BaseAdapter]:
|
||
"""
|
||
获取适配器实例(懒加载)
|
||
|
||
Args:
|
||
provider: 平台名称
|
||
|
||
Returns:
|
||
适配器实例,如果平台未注册则返回 None
|
||
"""
|
||
if provider in _adapters:
|
||
return _adapters[provider]
|
||
|
||
# 懒加载:首次使用时实例化
|
||
if provider in _adapter_classes:
|
||
adapter_class = _adapter_classes[provider]
|
||
adapter = adapter_class()
|
||
_adapters[provider] = adapter
|
||
return adapter
|
||
|
||
return None
|
||
|
||
|
||
def get_all_adapters() -> Dict[str, BaseAdapter]:
|
||
"""
|
||
获取所有已注册的适配器实例
|
||
"""
|
||
result = {}
|
||
for name, adapter_class in _adapter_classes.items():
|
||
if name not in _adapters:
|
||
_adapters[name] = adapter_class()
|
||
result[name] = _adapters[name]
|
||
return result
|
||
|
||
|
||
def get_available_providers() -> list:
|
||
"""
|
||
获取所有可用的平台列表
|
||
"""
|
||
providers = []
|
||
for name, adapter_class in _adapter_classes.items():
|
||
adapter = get_adapter(name)
|
||
if adapter and adapter.is_available():
|
||
providers.append(name)
|
||
return providers
|