from __future__ import annotations import asyncio from types import SimpleNamespace import pytest from langchain_core.messages import AIMessage from langgraph.errors import GraphBubbleUp from deerflow.agents.middlewares.llm_error_handling_middleware import ( LLMErrorHandlingMiddleware, ) class FakeError(Exception): def __init__( self, message: str, *, status_code: int | None = None, code: str | None = None, headers: dict[str, str] | None = None, body: dict | None = None, ) -> None: super().__init__(message) self.status_code = status_code self.code = code self.body = body self.response = SimpleNamespace(status_code=status_code, headers=headers or {}) if status_code is not None or headers else None def _build_middleware(**attrs: int) -> LLMErrorHandlingMiddleware: middleware = LLMErrorHandlingMiddleware() for key, value in attrs.items(): setattr(middleware, key, value) return middleware def test_async_model_call_retries_busy_provider_then_succeeds( monkeypatch: pytest.MonkeyPatch, ) -> None: middleware = _build_middleware(retry_max_attempts=3, retry_base_delay_ms=25, retry_cap_delay_ms=25) attempts = 0 waits: list[float] = [] events: list[dict] = [] async def fake_sleep(delay: float) -> None: waits.append(delay) def fake_writer(): return events.append async def handler(_request) -> AIMessage: nonlocal attempts attempts += 1 if attempts < 3: raise FakeError("当前服务集群负载较高,请稍后重试,感谢您的耐心等待。 (2064)") return AIMessage(content="ok") monkeypatch.setattr("asyncio.sleep", fake_sleep) monkeypatch.setattr( "langgraph.config.get_stream_writer", fake_writer, ) result = asyncio.run(middleware.awrap_model_call(SimpleNamespace(), handler)) assert isinstance(result, AIMessage) assert result.content == "ok" assert attempts == 3 assert waits == [0.025, 0.025] assert [event["type"] for event in events] == ["llm_retry", "llm_retry"] def test_async_model_call_returns_user_message_for_quota_errors() -> None: middleware = _build_middleware(retry_max_attempts=3) async def handler(_request) -> AIMessage: raise FakeError( "insufficient_quota: account balance is empty", status_code=429, code="insufficient_quota", ) result = asyncio.run(middleware.awrap_model_call(SimpleNamespace(), handler)) assert isinstance(result, AIMessage) assert "out of quota" in str(result.content) def test_sync_model_call_uses_retry_after_header(monkeypatch: pytest.MonkeyPatch) -> None: middleware = _build_middleware(retry_max_attempts=2, retry_base_delay_ms=10, retry_cap_delay_ms=10) waits: list[float] = [] attempts = 0 def fake_sleep(delay: float) -> None: waits.append(delay) def handler(_request) -> AIMessage: nonlocal attempts attempts += 1 if attempts == 1: raise FakeError( "server busy", status_code=503, headers={"Retry-After": "2"}, ) return AIMessage(content="ok") monkeypatch.setattr("time.sleep", fake_sleep) result = middleware.wrap_model_call(SimpleNamespace(), handler) assert isinstance(result, AIMessage) assert result.content == "ok" assert waits == [2.0] def test_sync_model_call_propagates_graph_bubble_up() -> None: middleware = _build_middleware() def handler(_request) -> AIMessage: raise GraphBubbleUp() with pytest.raises(GraphBubbleUp): middleware.wrap_model_call(SimpleNamespace(), handler) def test_async_model_call_propagates_graph_bubble_up() -> None: middleware = _build_middleware() async def handler(_request) -> AIMessage: raise GraphBubbleUp() with pytest.raises(GraphBubbleUp): asyncio.run(middleware.awrap_model_call(SimpleNamespace(), handler))