"""SandboxAuditMiddleware - bash command security auditing."""
|
|
|
|
import json
|
|
import logging
|
|
import re
|
|
import shlex
|
|
from collections.abc import Awaitable, Callable
|
|
from datetime import UTC, datetime
|
|
from typing import override
|
|
|
|
from langchain.agents.middleware import AgentMiddleware
|
|
from langchain_core.messages import ToolMessage
|
|
from langgraph.prebuilt.tool_node import ToolCallRequest
|
|
from langgraph.types import Command
|
|
|
|
from deerflow.agents.thread_state import ThreadState
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Command classification rules
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Each pattern is compiled once at import time.
|
|
_HIGH_RISK_PATTERNS: list[re.Pattern[str]] = [
|
|
re.compile(r"rm\s+-[^\s]*r[^\s]*\s+(/\*?|~/?\*?|/home\b|/root\b)\s*$"), # rm -rf / /* ~ /home /root
|
|
re.compile(r"(curl|wget).+\|\s*(ba)?sh"), # curl|sh, wget|sh
|
|
re.compile(r"dd\s+if="),
|
|
re.compile(r"mkfs"),
|
|
re.compile(r"cat\s+/etc/shadow"),
|
|
re.compile(r">\s*/etc/"), # overwrite /etc/ files
|
|
]
|
|
|
|
_MEDIUM_RISK_PATTERNS: list[re.Pattern[str]] = [
|
|
re.compile(r"chmod\s+777"), # overly permissive, but reversible
|
|
re.compile(r"pip\s+install"),
|
|
re.compile(r"pip3\s+install"),
|
|
re.compile(r"apt(-get)?\s+install"),
|
|
]
|
|
|
|
|
|
def _classify_command(command: str) -> str:
|
|
"""Return 'block', 'warn', or 'pass'."""
|
|
# Normalize for matching (collapse whitespace)
|
|
normalized = " ".join(command.split())
|
|
|
|
for pattern in _HIGH_RISK_PATTERNS:
|
|
if pattern.search(normalized):
|
|
return "block"
|
|
|
|
# Also try shlex-parsed tokens for high-risk detection
|
|
try:
|
|
tokens = shlex.split(command)
|
|
joined = " ".join(tokens)
|
|
for pattern in _HIGH_RISK_PATTERNS:
|
|
if pattern.search(joined):
|
|
return "block"
|
|
except ValueError:
|
|
# shlex.split fails on unclosed quotes — treat as suspicious
|
|
return "block"
|
|
|
|
for pattern in _MEDIUM_RISK_PATTERNS:
|
|
if pattern.search(normalized):
|
|
return "warn"
|
|
|
|
return "pass"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Middleware
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class SandboxAuditMiddleware(AgentMiddleware[ThreadState]):
    """Bash command security auditing middleware.

    For every ``bash`` tool call:

    1. **Command classification**: regex + shlex analysis grades commands as
       high-risk (block), medium-risk (warn), or safe (pass).
    2. **Audit log**: every bash call is recorded as a structured JSON entry
       via the standard logger (visible in langgraph.log).

    High-risk commands (e.g. ``rm -rf /``, ``curl url | bash``) are blocked:
    the handler is not called and an error ``ToolMessage`` is returned so the
    agent loop can continue gracefully.

    Medium-risk commands (e.g. ``pip install``, ``chmod 777``) are executed
    normally; a warning is appended to the tool result so the LLM is aware.
    """

    state_schema = ThreadState

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _get_thread_id(self, request: ToolCallRequest) -> str | None:
        """Best-effort lookup of the LangGraph thread id for audit records.

        Checks ``runtime.context`` first, then falls back to
        ``runtime.config["configurable"]["thread_id"]``. Returns ``None``
        when neither is present (e.g. in unit tests).
        """
        runtime = request.runtime  # ToolRuntime; may be None-like in tests
        if runtime is None:
            return None
        ctx = getattr(runtime, "context", None) or {}
        thread_id = ctx.get("thread_id") if isinstance(ctx, dict) else None
        if thread_id is None:
            cfg = getattr(runtime, "config", None) or {}
            thread_id = cfg.get("configurable", {}).get("thread_id")
        return thread_id

    # Maximum number of command characters reproduced verbatim in a
    # truncated audit record (see _write_audit).
    _AUDIT_COMMAND_LIMIT = 200

    def _write_audit(self, thread_id: str | None, command: str, verdict: str, *, truncate: bool = False) -> None:
        """Emit one structured JSON audit record via the module logger.

        When *truncate* is set, commands longer than ``_AUDIT_COMMAND_LIMIT``
        are shortened (with a length marker appended) so a payload-injection
        attempt cannot flood the log.
        """
        audited_command = command
        if truncate and len(command) > self._AUDIT_COMMAND_LIMIT:
            audited_command = f"{command[: self._AUDIT_COMMAND_LIMIT]}... ({len(command)} chars)"
        record = {
            "timestamp": datetime.now(UTC).isoformat(),
            "thread_id": thread_id or "unknown",
            "command": audited_command,
            "verdict": verdict,
        }
        logger.info("[SandboxAudit] %s", json.dumps(record, ensure_ascii=False))

    def _build_block_message(self, request: ToolCallRequest, reason: str) -> ToolMessage:
        """Build the error ``ToolMessage`` returned instead of running a blocked command."""
        tool_call_id = str(request.tool_call.get("id") or "missing_id")
        return ToolMessage(
            content=f"Command blocked: {reason}. Please use a safer alternative approach.",
            tool_call_id=tool_call_id,
            name="bash",
            status="error",
        )

    def _append_warn_to_result(self, result: ToolMessage | Command, command: str) -> ToolMessage | Command:
        """Append a warning note to the tool result for medium-risk commands.

        ``Command`` results are passed through untouched; only ``ToolMessage``
        content is augmented.
        """
        if not isinstance(result, ToolMessage):
            return result
        warning = f"\n\n⚠️ Warning: `{command}` is a medium-risk command that may modify the runtime environment."
        if isinstance(result.content, list):
            new_content = list(result.content) + [{"type": "text", "text": warning}]
        else:
            new_content = str(result.content) + warning
        # model_copy keeps every other field of the message (artifact from
        # content-and-artifact tools, additional_kwargs, response_metadata,
        # ...) intact; rebuilding a ToolMessage from scratch here silently
        # dropped them.
        return result.model_copy(update={"content": new_content})

    # ------------------------------------------------------------------
    # Input sanitisation
    # ------------------------------------------------------------------

    # Normal bash commands rarely exceed a few hundred characters. 10 000 is
    # well above any legitimate use case yet a tiny fraction of Linux ARG_MAX.
    # Anything longer is almost certainly a payload injection or base64-encoded
    # attack string.
    _MAX_COMMAND_LENGTH = 10_000

    def _validate_input(self, command: str) -> str | None:
        """Return ``None`` if *command* is acceptable, else a rejection reason.

        Rejects empty/whitespace-only commands, commands exceeding
        ``_MAX_COMMAND_LENGTH``, and commands containing NUL bytes.
        """
        if not command.strip():
            return "empty command"
        if len(command) > self._MAX_COMMAND_LENGTH:
            return "command too long"
        if "\x00" in command:
            return "null byte detected"
        return None

    # ------------------------------------------------------------------
    # Core logic (shared between sync and async paths)
    # ------------------------------------------------------------------

    def _pre_process(self, request: ToolCallRequest) -> tuple[str, str | None, str, str | None]:
        """Sanitise, classify, and audit a bash tool call.

        Returns ``(command, thread_id, verdict, reject_reason)`` where
        *verdict* is ``'block'``, ``'warn'``, or ``'pass'`` and
        *reject_reason* is non-``None`` only for input-sanitisation
        rejections.
        """
        args = request.tool_call.get("args", {})
        raw_command = args.get("command")
        # Non-string args (malformed tool call) fall through to the empty-
        # command rejection below.
        command = raw_command if isinstance(raw_command, str) else ""
        thread_id = self._get_thread_id(request)

        # ① input sanitisation — reject malformed input before regex analysis
        reject_reason = self._validate_input(command)
        if reject_reason:
            self._write_audit(thread_id, command, "block", truncate=True)
            logger.warning("[SandboxAudit] INVALID INPUT thread=%s reason=%s", thread_id, reject_reason)
            return command, thread_id, "block", reject_reason

        # ② classify command
        verdict = _classify_command(command)

        # ③ audit log
        self._write_audit(thread_id, command, verdict)

        if verdict == "block":
            logger.warning("[SandboxAudit] BLOCKED thread=%s cmd=%r", thread_id, command)
        elif verdict == "warn":
            logger.warning("[SandboxAudit] WARN (medium-risk) thread=%s cmd=%r", thread_id, command)

        return command, thread_id, verdict, None

    # ------------------------------------------------------------------
    # wrap_tool_call hooks
    # ------------------------------------------------------------------

    @override
    def wrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], ToolMessage | Command],
    ) -> ToolMessage | Command:
        """Sync hook: audit bash calls, block high-risk ones, annotate medium-risk results.

        Non-bash tool calls are passed straight through to *handler*.
        """
        if request.tool_call.get("name") != "bash":
            return handler(request)

        command, _, verdict, reject_reason = self._pre_process(request)
        if verdict == "block":
            reason = reject_reason or "security violation detected"
            return self._build_block_message(request, reason)
        result = handler(request)
        if verdict == "warn":
            result = self._append_warn_to_result(result, command)
        return result

    @override
    async def awrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
    ) -> ToolMessage | Command:
        """Async counterpart of :meth:`wrap_tool_call` applying the identical policy."""
        if request.tool_call.get("name") != "bash":
            return await handler(request)

        command, _, verdict, reject_reason = self._pre_process(request)
        if verdict == "block":
            reason = reject_reason or "security violation detected"
            return self._build_block_message(request, reason)
        result = await handler(request)
        if verdict == "warn":
            result = self._append_warn_to_result(result, command)
        return result