fix(oauth): Harden Claude OAuth cache-control handling (#1583)
This commit is contained in:
parent
fc7de7fffe
commit
5ceb19f6f6
|
|
@ -27,6 +27,7 @@ from typing import Any
|
||||||
import anthropic
|
import anthropic
|
||||||
from langchain_anthropic import ChatAnthropic
|
from langchain_anthropic import ChatAnthropic
|
||||||
from langchain_core.messages import BaseMessage
|
from langchain_core.messages import BaseMessage
|
||||||
|
from pydantic import PrivateAttr
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -56,8 +57,8 @@ class ClaudeChatModel(ChatAnthropic):
|
||||||
prompt_cache_size: int = 3
|
prompt_cache_size: int = 3
|
||||||
auto_thinking_budget: bool = True
|
auto_thinking_budget: bool = True
|
||||||
retry_max_attempts: int = MAX_RETRIES
|
retry_max_attempts: int = MAX_RETRIES
|
||||||
_is_oauth: bool = False
|
_is_oauth: bool = PrivateAttr(default=False)
|
||||||
_oauth_access_token: str = ""
|
_oauth_access_token: str = PrivateAttr(default="")
|
||||||
|
|
||||||
model_config = {"arbitrary_types_allowed": True}
|
model_config = {"arbitrary_types_allowed": True}
|
||||||
|
|
||||||
|
|
@ -244,6 +245,39 @@ class ClaudeChatModel(ChatAnthropic):
|
||||||
max_tokens = payload.get("max_tokens", 8192)
|
max_tokens = payload.get("max_tokens", 8192)
|
||||||
thinking["budget_tokens"] = int(max_tokens * THINKING_BUDGET_RATIO)
|
thinking["budget_tokens"] = int(max_tokens * THINKING_BUDGET_RATIO)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _strip_cache_control(payload: dict) -> None:
|
||||||
|
"""Remove cache_control markers before OAuth requests reach Anthropic."""
|
||||||
|
for section in ("system", "messages"):
|
||||||
|
items = payload.get(section)
|
||||||
|
if not isinstance(items, list):
|
||||||
|
continue
|
||||||
|
for item in items:
|
||||||
|
if not isinstance(item, dict):
|
||||||
|
continue
|
||||||
|
item.pop("cache_control", None)
|
||||||
|
content = item.get("content")
|
||||||
|
if isinstance(content, list):
|
||||||
|
for block in content:
|
||||||
|
if isinstance(block, dict):
|
||||||
|
block.pop("cache_control", None)
|
||||||
|
|
||||||
|
tools = payload.get("tools")
|
||||||
|
if isinstance(tools, list):
|
||||||
|
for tool in tools:
|
||||||
|
if isinstance(tool, dict):
|
||||||
|
tool.pop("cache_control", None)
|
||||||
|
|
||||||
|
def _create(self, payload: dict) -> Any:
    """Forward *payload* to the parent class's synchronous ``_create``.

    When the instance is flagged as OAuth (``self._is_oauth``), the payload
    is first passed through ``_strip_cache_control``, which edits it in
    place.

    Args:
        payload: Request payload dict handed on to the parent implementation.

    Returns:
        Whatever the parent ``_create`` returns.
    """
    oauth_active = self._is_oauth
    if oauth_active:
        # The markers are removed in place, so the parent sees the cleaned payload.
        self._strip_cache_control(payload)
    response = super()._create(payload)
    return response
|
||||||
|
|
||||||
|
async def _acreate(self, payload: dict) -> Any:
    """Forward *payload* to the parent class's asynchronous ``_acreate``.

    When the instance is flagged as OAuth (``self._is_oauth``), the payload
    is first passed through ``_strip_cache_control``, which edits it in
    place.

    Args:
        payload: Request payload dict handed on to the parent implementation.

    Returns:
        Whatever the awaited parent ``_acreate`` returns.
    """
    oauth_active = self._is_oauth
    if oauth_active:
        # The markers are removed in place, so the parent sees the cleaned payload.
        self._strip_cache_control(payload)
    response = await super()._acreate(payload)
    return response
|
||||||
|
|
||||||
def _generate(self, messages: list[BaseMessage], stop: list[str] | None = None, **kwargs: Any) -> Any:
|
def _generate(self, messages: list[BaseMessage], stop: list[str] | None = None, **kwargs: Any) -> Any:
|
||||||
"""Override with OAuth patching and retry logic."""
|
"""Override with OAuth patching and retry logic."""
|
||||||
if self._is_oauth:
|
if self._is_oauth:
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,8 @@
|
||||||
"""Tests for ClaudeChatModel._apply_oauth_billing."""
|
"""Tests for ClaudeChatModel._apply_oauth_billing."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
@ -108,3 +110,45 @@ def test_metadata_non_dict_replaced_with_dict(model):
|
||||||
model._apply_oauth_billing(payload)
|
model._apply_oauth_billing(payload)
|
||||||
assert isinstance(payload["metadata"], dict)
|
assert isinstance(payload["metadata"], dict)
|
||||||
assert "user_id" in payload["metadata"]
|
assert "user_id" in payload["metadata"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_sync_create_strips_cache_control_from_oauth_payload(model):
    """The sync ``_create`` path forwards a payload without cache_control markers.

    Assumes the ``model`` fixture yields an OAuth-enabled instance; confirm
    against the fixture definition.
    """
    system_block = {"type": "text", "text": "sys", "cache_control": {"type": "ephemeral"}}
    user_block = {"type": "text", "text": "hi", "cache_control": {"type": "ephemeral"}}
    tool_spec = {"name": "demo", "input_schema": {"type": "object"}, "cache_control": {"type": "ephemeral"}}
    request = {
        "system": [system_block],
        "messages": [{"role": "user", "content": [user_block]}],
        "tools": [tool_spec],
    }

    # Replace the client call so no request leaves the process and the
    # forwarded keyword arguments can be inspected afterwards.
    with mock.patch.object(model._client.messages, "create", return_value=object()) as patched_create:
        model._create(request)

    forwarded = patched_create.call_args.kwargs
    cleaned_blocks = (
        forwarded["system"][0],
        forwarded["messages"][0]["content"][0],
        forwarded["tools"][0],
    )
    for block in cleaned_blocks:
        assert "cache_control" not in block
|
||||||
|
|
||||||
|
|
||||||
|
def test_async_create_strips_cache_control_from_oauth_payload(model):
    """The async ``_acreate`` path forwards a payload without cache_control markers.

    Assumes the ``model`` fixture yields an OAuth-enabled instance; confirm
    against the fixture definition.
    """
    system_block = {"type": "text", "text": "sys", "cache_control": {"type": "ephemeral"}}
    user_block = {"type": "text", "text": "hi", "cache_control": {"type": "ephemeral"}}
    tool_spec = {"name": "demo", "input_schema": {"type": "object"}, "cache_control": {"type": "ephemeral"}}
    request = {
        "system": [system_block],
        "messages": [{"role": "user", "content": [user_block]}],
        "tools": [tool_spec],
    }

    # An AsyncMock stands in for the awaited client call; because it is
    # passed as ``new``, the patch context yields that same mock.
    async_stub = mock.AsyncMock(return_value=object())
    with mock.patch.object(model._async_client.messages, "create", new=async_stub) as patched_create:
        asyncio.run(model._acreate(request))

    forwarded = patched_create.call_args.kwargs
    cleaned_blocks = (
        forwarded["system"][0],
        forwarded["messages"][0]["content"][0],
        forwarded["tools"][0],
    )
    for block in cleaned_blocks:
        assert "cache_control" not in block
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue