"""MCP (Model Context Protocol) Client — connects to external MCP servers.

Supports two transport modes:
1. Streamable HTTP (modern) — single URL, POST JSON-RPC, response as JSON or SSE
2. SSE Transport (legacy but widely used) — GET /sse for event stream, POST /messages for requests

Transport is auto-detected: tries Streamable HTTP first, falls back to SSE.
Reference: https://modelcontextprotocol.io/docs
"""

import asyncio
import json
from urllib.parse import parse_qs, urlencode, urljoin, urlparse, urlunparse

import httpx
from loguru import logger


class MCPClient:
|
|
"""Client for connecting to MCP servers via Streamable HTTP or SSE transport.
|
|
|
|
Auto-detects the transport mode on first request.
|
|
"""
|
|
|
|
def __init__(self, server_url: str, api_key: str | None = None):
|
|
# Extract apiKey from URL query params and move to Authorization header
|
|
parsed = urlparse(server_url)
|
|
qs = parse_qs(parsed.query, keep_blank_values=True)
|
|
|
|
self.api_key = api_key
|
|
if not self.api_key and "apiKey" in qs:
|
|
self.api_key = qs.pop("apiKey")[0]
|
|
|
|
# Rebuild URL without apiKey in query string
|
|
remaining_qs = urlencode({k: v[0] for k, v in qs.items()}) if qs else ""
|
|
self.server_url = urlunparse(parsed._replace(query=remaining_qs)).rstrip("/")
|
|
|
|
# Transport state
|
|
self._transport: str | None = None # "streamable" or "sse"
|
|
self._session_id: str | None = None
|
|
self._sse_messages_url: str | None = None # POST endpoint for SSE transport
|
|
|
|
def _headers(self) -> dict:
|
|
"""Build request headers with proper MCP and auth headers."""
|
|
h = {
|
|
"Content-Type": "application/json",
|
|
"Accept": "application/json, text/event-stream",
|
|
}
|
|
if self.api_key:
|
|
h["Authorization"] = f"Bearer {self.api_key}"
|
|
if self._session_id:
|
|
h["Mcp-Session-Id"] = self._session_id
|
|
return h
|
|
|
|
def _parse_response(self, resp: httpx.Response) -> dict:
|
|
"""Parse response — handles both JSON and SSE (text/event-stream) formats."""
|
|
content_type = resp.headers.get("content-type", "")
|
|
|
|
# Save session ID if the server returns one
|
|
session_id = resp.headers.get("mcp-session-id")
|
|
if session_id:
|
|
self._session_id = session_id
|
|
|
|
if "text/event-stream" in content_type:
|
|
return self._parse_sse_response(resp.text)
|
|
else:
|
|
return resp.json()
|
|
|
|
def _parse_sse_response(self, text: str) -> dict:
|
|
"""Extract the last JSON-RPC result from an SSE stream."""
|
|
last_data = None
|
|
for line in text.splitlines():
|
|
if line.startswith("data:"):
|
|
raw = line[5:].strip()
|
|
if raw and raw != "[DONE]":
|
|
try:
|
|
last_data = json.loads(raw)
|
|
except json.JSONDecodeError:
|
|
pass
|
|
if last_data is None:
|
|
raise Exception("No valid JSON found in SSE response")
|
|
return last_data
|
|
|
|
# ── Streamable HTTP Transport ────────────────────────────────
|
|
|
|
async def _streamable_initialize(self, client: httpx.AsyncClient) -> None:
|
|
"""Send MCP initialize + initialized handshake (Streamable HTTP)."""
|
|
try:
|
|
resp = await client.post(
|
|
self.server_url,
|
|
json={
|
|
"jsonrpc": "2.0",
|
|
"id": 0,
|
|
"method": "initialize",
|
|
"params": {
|
|
"protocolVersion": "2024-11-05",
|
|
"capabilities": {},
|
|
"clientInfo": {"name": "clawith", "version": "1.0"},
|
|
},
|
|
},
|
|
headers=self._headers(),
|
|
)
|
|
if resp.status_code == 200:
|
|
self._parse_response(resp) # captures Mcp-Session-Id if present
|
|
# Send initialized notification (required by MCP spec before other requests)
|
|
await client.post(
|
|
self.server_url,
|
|
json={"jsonrpc": "2.0", "method": "notifications/initialized"},
|
|
headers=self._headers(),
|
|
)
|
|
except Exception:
|
|
pass # initialization failure is non-fatal — server may be stateless
|
|
|
|
async def _streamable_request(self, method: str, params: dict | None = None) -> dict:
|
|
"""Send a JSON-RPC request via Streamable HTTP transport."""
|
|
async with httpx.AsyncClient(timeout=30, follow_redirects=True) as client:
|
|
if not self._session_id:
|
|
await self._streamable_initialize(client)
|
|
|
|
body: dict = {"jsonrpc": "2.0", "id": 1, "method": method, "params": params or {}}
|
|
|
|
resp = await client.post(self.server_url, json=body, headers=self._headers())
|
|
if resp.status_code not in (200, 201):
|
|
raise Exception(f"HTTP {resp.status_code}")
|
|
return self._parse_response(resp)
|
|
|
|
# ── SSE Transport ────────────────────────────────────────────
|
|
|
|
async def _sse_connect(self) -> str:
|
|
"""Connect to SSE endpoint (GET /sse) and extract the messages URL.
|
|
|
|
Returns the full POST URL for sending JSON-RPC messages.
|
|
"""
|
|
# Determine SSE URL: if server_url ends with /sse use it directly,
|
|
# otherwise append /sse
|
|
sse_url = self.server_url if self.server_url.endswith("/sse") else f"{self.server_url}/sse"
|
|
parsed = urlparse(sse_url)
|
|
base_url = f"{parsed.scheme}://{parsed.netloc}"
|
|
|
|
headers = {"Accept": "text/event-stream"}
|
|
if self.api_key:
|
|
headers["Authorization"] = f"Bearer {self.api_key}"
|
|
|
|
messages_url = None
|
|
|
|
async with httpx.AsyncClient(timeout=15, follow_redirects=True) as client:
|
|
async with client.stream("GET", sse_url, headers=headers) as resp:
|
|
if resp.status_code != 200:
|
|
raise Exception(f"SSE connect failed: HTTP {resp.status_code}")
|
|
|
|
# Read SSE events until we get the endpoint event
|
|
async for line in resp.aiter_lines():
|
|
line = line.strip()
|
|
if line.startswith("event:"):
|
|
event_type = line[6:].strip()
|
|
elif line.startswith("data:"):
|
|
data = line[5:].strip()
|
|
if event_type == "endpoint" and data:
|
|
# data is typically a relative URL like /messages?sessionId=xxx
|
|
if data.startswith("http"):
|
|
messages_url = data
|
|
else:
|
|
messages_url = base_url + data
|
|
break
|
|
elif line == "":
|
|
# Empty line = end of SSE event block
|
|
pass
|
|
|
|
if not messages_url:
|
|
raise Exception("SSE endpoint did not return a messages URL")
|
|
|
|
return messages_url
|
|
|
|
async def _sse_request(self, method: str, params: dict | None = None) -> dict:
|
|
"""Send a JSON-RPC request via SSE transport.
|
|
|
|
Opens a fresh SSE connection each call to get the messages endpoint,
|
|
sends the JSON-RPC request, then reads responses from the SSE stream.
|
|
"""
|
|
# Connect to SSE to get the messages endpoint
|
|
sse_url = self.server_url if self.server_url.endswith("/sse") else f"{self.server_url}/sse"
|
|
parsed = urlparse(sse_url)
|
|
base_url = f"{parsed.scheme}://{parsed.netloc}"
|
|
|
|
headers_sse = {"Accept": "text/event-stream"}
|
|
headers_post = {"Content-Type": "application/json", "Accept": "application/json, text/event-stream"}
|
|
if self.api_key:
|
|
headers_sse["Authorization"] = f"Bearer {self.api_key}"
|
|
headers_post["Authorization"] = f"Bearer {self.api_key}"
|
|
|
|
body: dict = {"jsonrpc": "2.0", "id": 1, "method": method, "params": params or {}}
|
|
|
|
timeout = 60 if method == "tools/call" else 30
|
|
|
|
async with httpx.AsyncClient(timeout=timeout, follow_redirects=True) as client:
|
|
# Open the SSE stream
|
|
async with client.stream("GET", sse_url, headers=headers_sse) as sse_resp:
|
|
if sse_resp.status_code != 200:
|
|
raise Exception(f"SSE connect failed: HTTP {sse_resp.status_code}")
|
|
|
|
messages_url = None
|
|
event_type = ""
|
|
|
|
# Phase 1: Read until we get the endpoint event
|
|
line_iter = sse_resp.aiter_lines()
|
|
async for line in line_iter:
|
|
line = line.strip()
|
|
if line.startswith("event:"):
|
|
event_type = line[6:].strip()
|
|
elif line.startswith("data:"):
|
|
data = line[5:].strip()
|
|
if event_type == "endpoint" and data:
|
|
if data.startswith("http"):
|
|
messages_url = data
|
|
else:
|
|
messages_url = base_url + data
|
|
break
|
|
|
|
if not messages_url:
|
|
raise Exception("SSE endpoint did not return a messages URL")
|
|
|
|
# Phase 2: MCP handshake — initialize + initialized notification
|
|
init_body = {
|
|
"jsonrpc": "2.0", "id": 0, "method": "initialize",
|
|
"params": {
|
|
"protocolVersion": "2024-11-05",
|
|
"capabilities": {},
|
|
"clientInfo": {"name": "clawith", "version": "1.0"},
|
|
},
|
|
}
|
|
await client.post(messages_url, json=init_body, headers=headers_post)
|
|
# Send initialized notification (required before other requests)
|
|
await client.post(
|
|
messages_url,
|
|
json={"jsonrpc": "2.0", "method": "notifications/initialized"},
|
|
headers=headers_post,
|
|
)
|
|
|
|
# Send the actual request
|
|
post_resp = await client.post(messages_url, json=body, headers=headers_post)
|
|
|
|
# Phase 3: Read the response — either from POST response or from SSE stream
|
|
if post_resp.status_code == 200:
|
|
ct = post_resp.headers.get("content-type", "")
|
|
if "application/json" in ct:
|
|
return post_resp.json()
|
|
|
|
# Read response from SSE stream
|
|
result = None
|
|
async for line in line_iter:
|
|
line = line.strip()
|
|
if line.startswith("event:"):
|
|
event_type = line[6:].strip()
|
|
elif line.startswith("data:"):
|
|
data = line[5:].strip()
|
|
if event_type == "message" and data:
|
|
try:
|
|
parsed_data = json.loads(data)
|
|
# Match our request ID
|
|
if isinstance(parsed_data, dict) and parsed_data.get("id") in (0, 1):
|
|
result = parsed_data
|
|
if parsed_data.get("id") == 1:
|
|
break # Got our actual request response
|
|
except json.JSONDecodeError:
|
|
pass
|
|
|
|
if result is None:
|
|
raise Exception("No response received from SSE transport")
|
|
return result
|
|
|
|
# ── Auto-detect Transport ────────────────────────────────────
|
|
|
|
async def _detect_and_request(self, method: str, params: dict | None = None) -> dict:
|
|
"""Auto-detect transport and send request.
|
|
|
|
Strategy: If transport is already known, use it directly.
|
|
Otherwise try Streamable HTTP first, fall back to SSE.
|
|
"""
|
|
if self._transport == "sse":
|
|
return await self._sse_request(method, params)
|
|
if self._transport == "streamable":
|
|
return await self._streamable_request(method, params)
|
|
|
|
# Auto-detect: try Streamable HTTP first
|
|
try:
|
|
result = await self._streamable_request(method, params)
|
|
self._transport = "streamable"
|
|
return result
|
|
except Exception as streamable_err:
|
|
logger.info(f"[MCPClient] Streamable HTTP failed ({streamable_err}), trying SSE transport...")
|
|
|
|
# Fallback to SSE
|
|
try:
|
|
result = await self._sse_request(method, params)
|
|
self._transport = "sse"
|
|
return result
|
|
except Exception as sse_err:
|
|
raise Exception(
|
|
f"Both transports failed. "
|
|
f"Streamable HTTP: {streamable_err}; "
|
|
f"SSE: {sse_err}"
|
|
)
|
|
|
|
# ── Public API ───────────────────────────────────────────────
|
|
|
|
async def list_tools(self) -> list[dict]:
|
|
"""Fetch available tools from the MCP server."""
|
|
try:
|
|
data = await self._detect_and_request("tools/list")
|
|
|
|
if "error" in data:
|
|
err = data["error"]
|
|
msg = err.get("message", str(err)) if isinstance(err, dict) else str(err)
|
|
raise Exception(f"MCP error: {msg}")
|
|
|
|
result = data.get("result", {})
|
|
tools = result.get("tools", []) if isinstance(result, dict) else []
|
|
return [
|
|
{
|
|
"name": t.get("name", ""),
|
|
"description": t.get("description", ""),
|
|
"inputSchema": t.get("inputSchema", {}),
|
|
}
|
|
for t in tools
|
|
]
|
|
except httpx.HTTPError as e:
|
|
raise Exception(f"Connection failed: {str(e)[:200]}")
|
|
|
|
async def call_tool(self, tool_name: str, arguments: dict) -> str:
|
|
"""Execute a tool on the MCP server."""
|
|
try:
|
|
data = await self._detect_and_request(
|
|
"tools/call",
|
|
{"name": tool_name, "arguments": arguments},
|
|
)
|
|
|
|
if "error" in data:
|
|
err = data["error"]
|
|
msg = err.get("message", str(err)) if isinstance(err, dict) else str(err)
|
|
return f"❌ MCP tool execution error: {msg[:200]}"
|
|
|
|
result = data.get("result", {})
|
|
if isinstance(result, str):
|
|
return result
|
|
|
|
# MCP returns content as list of content blocks
|
|
content_blocks = result.get("content", []) if isinstance(result, dict) else []
|
|
texts = []
|
|
for block in content_blocks:
|
|
if isinstance(block, str):
|
|
texts.append(block)
|
|
elif isinstance(block, dict):
|
|
if block.get("type") == "text":
|
|
texts.append(block.get("text", ""))
|
|
elif block.get("type") == "image":
|
|
texts.append(f"[Image: {block.get('mimeType', 'image')}]")
|
|
else:
|
|
texts.append(str(block))
|
|
else:
|
|
texts.append(str(block))
|
|
|
|
return "\n".join(texts) if texts else str(result)
|
|
|
|
except httpx.HTTPError as e:
|
|
return f"❌ MCP connection failed: {str(e)[:200]}"
|