ai-chat-ui/server/adapters/registry.py

121 lines
2.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
适配器注册表
根据模型名称路由到对应的平台适配器
"""
import os
from typing import Dict, Optional, Type
from .base import BaseAdapter
# 模型前缀到平台名称的映射
MODEL_PREFIX_MAP = {
# 智谱 GLM
"glm": "glm",
# 阿里云百炼Qwen 系列)
"qwen": "dashscope",
# OpenAI
"gpt": "openai",
"o1": "openai",
"o3": "openai",
# Deepseek
"deepseek": "deepseek",
}
# 已注册的适配器实例
_adapters: Dict[str, BaseAdapter] = {}
# 已注册的适配器类
_adapter_classes: Dict[str, Type[BaseAdapter]] = {}
def register_adapter(name: str, adapter_class: Type[BaseAdapter]):
"""
注册适配器类
Args:
name: 平台名称(如 'glm', 'dashscope', 'openai'
adapter_class: 适配器类
"""
_adapter_classes[name] = adapter_class
def get_provider_from_model(model: str) -> str:
"""
根据模型名称判断所属平台
Args:
model: 模型 ID'glm-4-flash', 'qwen-turbo', 'gpt-4'
Returns:
平台名称(如 'glm', 'dashscope', 'openai'
"""
model_lower = model.lower()
# 优先精确匹配
exact_matches = {
# GLM 精确模型名
"glm-4": "glm",
"glm-4v": "glm",
# Deepseek
"deepseek-chat": "deepseek",
"deepseek-reasoner": "deepseek",
}
if model_lower in exact_matches:
return exact_matches[model_lower]
# 前缀匹配
for prefix, provider in MODEL_PREFIX_MAP.items():
if model_lower.startswith(prefix):
return provider
# 默认使用环境变量或 GLM
return os.getenv("DEFAULT_PROVIDER", "glm")
def get_adapter(provider: str) -> Optional[BaseAdapter]:
"""
获取适配器实例(懒加载)
Args:
provider: 平台名称
Returns:
适配器实例,如果平台未注册则返回 None
"""
if provider in _adapters:
return _adapters[provider]
# 懒加载:首次使用时实例化
if provider in _adapter_classes:
adapter_class = _adapter_classes[provider]
adapter = adapter_class()
_adapters[provider] = adapter
return adapter
return None
def get_all_adapters() -> Dict[str, BaseAdapter]:
"""
获取所有已注册的适配器实例
"""
result = {}
for name, adapter_class in _adapter_classes.items():
if name not in _adapters:
_adapters[name] = adapter_class()
result[name] = _adapters[name]
return result
def get_available_providers() -> list:
"""
获取所有可用的平台列表
"""
providers = []
for name, adapter_class in _adapter_classes.items():
adapter = get_adapter(name)
if adapter and adapter.is_available():
providers.append(name)
return providers