# 84 lines · 2.2 KiB · Python
"""GLM adapter test script.

Exercises GLMAdapter's streaming and non-streaming chat calls, including the
web-search feature.
"""

import asyncio
import os
import sys
from pathlib import Path

# Make the project root importable so `adapters.*` resolves when this script is
# run directly. resolve() turns a relative or symlinked __file__ into an
# absolute real path, so the inserted entry does not depend on the current
# working directory or on how the script was invoked.
root_dir = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(root_dir))

from dotenv import load_dotenv  # noqa: E402

# Load .env before importing project modules, in case the adapters read their
# configuration (e.g. the API key) at import time rather than at construction.
load_dotenv()

from adapters.glm_adapter import GLMAdapter  # noqa: E402
from adapters.base import ChatCompletionRequest  # noqa: E402
async def test_stream():
    """Run a streaming chat completion with web search enabled.

    If ``adapter.is_available()`` is falsy (no API key configured), prints an
    error and returns. Otherwise sends one user message with ``stream=True``
    and echoes each chunk of the response body to stdout as it arrives,
    followed by a final newline.
    """
    import codecs

    adapter = GLMAdapter()

    if not adapter.is_available():
        print("错误:未配置 ZHIPU_API_KEY 或 GLM_API_KEY")
        return

    request = ChatCompletionRequest(
        model="glm-4.6v",
        messages=[{"role": "user", "content": "今天北京天气怎样?"}],
        stream=True,
        temperature=0.7,
        max_tokens=1024,
        web_search=True,
    )

    print("Testing stream with web_search...")
    response = await adapter.chat(request)

    body_iterator = getattr(response, "body_iterator", None)
    if body_iterator is None:
        # NOTE(review): presumably the adapter can return a non-streaming
        # response (e.g. an error payload) even for stream=True; show it
        # instead of crashing with AttributeError.
        body = getattr(response, "body", response)
        if isinstance(body, (bytes, bytearray)):
            body = body.decode("utf-8", errors="replace")
        print(f"Response: {body}")
        return

    # Chunks are expected to be str, but tolerate bytes. An incremental decoder
    # keeps multi-byte UTF-8 characters intact when they are split across chunks.
    decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
    async for chunk in body_iterator:
        if isinstance(chunk, (bytes, bytearray, memoryview)):
            chunk = decoder.decode(chunk)
        # flush=True so each chunk is shown immediately, not when the buffer fills.
        print(chunk, end="", flush=True)
    # Emit any buffered partial character and end the streamed output with a newline.
    print(decoder.decode(b"", final=True))
async def test_sync():
    """Run a non-streaming chat completion with web search enabled.

    If ``adapter.is_available()`` is falsy (no API key configured), prints an
    error and returns. Otherwise sends one user message with ``stream=False``
    and prints the assistant's reply extracted from the JSON body. If no reply
    text is present, it prints the whole payload instead.
    """
    import json

    adapter = GLMAdapter()

    if not adapter.is_available():
        print("错误:未配置 ZHIPU_API_KEY 或 GLM_API_KEY")
        return

    request = ChatCompletionRequest(
        model="glm-4-flash",
        messages=[{"role": "user", "content": "今天几号?武汉天气怎样?"}],
        stream=False,
        temperature=0.7,
        max_tokens=1024,
        web_search=True,
    )

    print("Testing sync with web_search...")
    response = await adapter.chat(request)

    # A non-streaming result is expected to be a JSONResponse exposing the
    # serialized payload as `.body`; anything else is printed as-is.
    if not hasattr(response, "body"):
        print(f"Response: {response}")
        return

    data = json.loads(response.body)
    # `or` guards against a missing, null, or empty "choices" list and a null
    # "message". Indexing `[{}]`-defaulted values directly would raise on those.
    choices = data.get("choices") or [{}]
    message = choices[0].get("message") or {}
    content = message.get("content", "")

    if content:
        print(f"Response: {content}")
    else:
        # No assistant text (e.g. an error payload): show the full body so the
        # failure is visible instead of an empty "Response: " line.
        print(f"Response: {json.dumps(data, ensure_ascii=False)}")
if __name__ == "__main__":
    # Select which test(s) to run: `python <script> [sync|stream|all]`.
    # With no argument only the non-streaming test runs (the previous default).
    runners = {
        "sync": (test_sync,),
        "stream": (test_stream,),
        "all": (test_stream, test_sync),
    }
    mode = sys.argv[1] if len(sys.argv) > 1 else "sync"
    if mode not in runners:
        # sys.exit with a string prints it to stderr and exits with status 1.
        sys.exit(f"usage: {sys.argv[0]} [sync|stream|all]")
    for run_test in runners[mode]:
        asyncio.run(run_test())