deerflow2/backend/packages/harness/deerflow/agents/middlewares/sandbox_audit_middleware.py

"""SandboxAuditMiddleware - bash command security auditing."""
import json
import logging
import re
import shlex
from collections.abc import Awaitable, Callable
from datetime import UTC, datetime
from typing import override

from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import ToolMessage
from langgraph.prebuilt.tool_node import ToolCallRequest
from langgraph.types import Command

from deerflow.agents.thread_state import ThreadState

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Command classification rules
# ---------------------------------------------------------------------------
# Each pattern is compiled once at import time.
_HIGH_RISK_PATTERNS: list[re.Pattern[str]] = [
re.compile(r"rm\s+-[^\s]*r[^\s]*\s+(/\*?|~/?\*?|/home\b|/root\b)\s*$"), # rm -rf / /* ~ /home /root
re.compile(r"(curl|wget).+\|\s*(ba)?sh"), # curl|sh, wget|sh
re.compile(r"dd\s+if="),
re.compile(r"mkfs"),
re.compile(r"cat\s+/etc/shadow"),
re.compile(r">\s*/etc/"), # overwrite /etc/ files
]
_MEDIUM_RISK_PATTERNS: list[re.Pattern[str]] = [
re.compile(r"chmod\s+777"), # overly permissive, but reversible
re.compile(r"pip\s+install"),
re.compile(r"pip3\s+install"),
re.compile(r"apt(-get)?\s+install"),
]
def _classify_command(command: str) -> str:
"""Return 'block', 'warn', or 'pass'."""
# Normalize for matching (collapse whitespace)
normalized = " ".join(command.split())
for pattern in _HIGH_RISK_PATTERNS:
if pattern.search(normalized):
return "block"
# Also try shlex-parsed tokens for high-risk detection
try:
tokens = shlex.split(command)
joined = " ".join(tokens)
for pattern in _HIGH_RISK_PATTERNS:
if pattern.search(joined):
return "block"
except ValueError:
# shlex.split fails on unclosed quotes — treat as suspicious
return "block"
for pattern in _MEDIUM_RISK_PATTERNS:
if pattern.search(normalized):
return "warn"
return "pass"
# ---------------------------------------------------------------------------
# Middleware
# ---------------------------------------------------------------------------
class SandboxAuditMiddleware(AgentMiddleware[ThreadState]):
"""Bash command security auditing middleware.
For every ``bash`` tool call:
1. **Command classification**: regex + shlex analysis grades commands as
high-risk (block), medium-risk (warn), or safe (pass).
2. **Audit log**: every bash call is recorded as a structured JSON entry
via the standard logger (visible in langgraph.log).
High-risk commands (e.g. ``rm -rf /``, ``curl url | bash``) are blocked:
the handler is not called and an error ``ToolMessage`` is returned so the
agent loop can continue gracefully.
Medium-risk commands (e.g. ``pip install``, ``chmod 777``) are executed
normally; a warning is appended to the tool result so the LLM is aware.
"""
state_schema = ThreadState
# ------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------
def _get_thread_id(self, request: ToolCallRequest) -> str | None:
runtime = request.runtime # ToolRuntime; may be None-like in tests
if runtime is None:
return None
ctx = getattr(runtime, "context", None) or {}
thread_id = ctx.get("thread_id") if isinstance(ctx, dict) else None
if thread_id is None:
cfg = getattr(runtime, "config", None) or {}
thread_id = cfg.get("configurable", {}).get("thread_id")
return thread_id
_AUDIT_COMMAND_LIMIT = 200
def _write_audit(self, thread_id: str | None, command: str, verdict: str, *, truncate: bool = False) -> None:
audited_command = command
if truncate and len(command) > self._AUDIT_COMMAND_LIMIT:
audited_command = f"{command[: self._AUDIT_COMMAND_LIMIT]}... ({len(command)} chars)"
record = {
"timestamp": datetime.now(UTC).isoformat(),
"thread_id": thread_id or "unknown",
"command": audited_command,
"verdict": verdict,
}
logger.info("[SandboxAudit] %s", json.dumps(record, ensure_ascii=False))

    def _build_block_message(self, request: ToolCallRequest, reason: str) -> ToolMessage:
        tool_call_id = str(request.tool_call.get("id") or "missing_id")
        return ToolMessage(
            content=f"Command blocked: {reason}. Please use a safer alternative approach.",
            tool_call_id=tool_call_id,
            name="bash",
            status="error",
        )

    def _append_warn_to_result(self, result: ToolMessage | Command, command: str) -> ToolMessage | Command:
        """Append a warning note to the tool result for medium-risk commands."""
        if not isinstance(result, ToolMessage):
            return result
        warning = f"\n\n⚠️ Warning: `{command}` is a medium-risk command that may modify the runtime environment."
        if isinstance(result.content, list):
            new_content = list(result.content) + [{"type": "text", "text": warning}]
        else:
            new_content = str(result.content) + warning
        return ToolMessage(
            content=new_content,
            tool_call_id=result.tool_call_id,
            name=result.name,
            status=result.status,
        )

    # ------------------------------------------------------------------
    # Input sanitisation
    # ------------------------------------------------------------------
    # Normal bash commands rarely exceed a few hundred characters. 10 000 is
    # well above any legitimate use case yet a tiny fraction of Linux ARG_MAX.
    # Anything longer is almost certainly a payload injection or base64-encoded
    # attack string.
    _MAX_COMMAND_LENGTH = 10_000

    def _validate_input(self, command: str) -> str | None:
        """Return ``None`` if *command* is acceptable, else a rejection reason."""
        if not command.strip():
            return "empty command"
        if len(command) > self._MAX_COMMAND_LENGTH:
            return "command too long"
        if "\x00" in command:
            return "null byte detected"
        return None

    # ------------------------------------------------------------------
    # Core logic (shared between sync and async paths)
    # ------------------------------------------------------------------
    def _pre_process(self, request: ToolCallRequest) -> tuple[str, str | None, str, str | None]:
        """
        Returns (command, thread_id, verdict, reject_reason).

        verdict is 'block', 'warn', or 'pass'.
        reject_reason is non-None only for input sanitisation rejections.
        """
        args = request.tool_call.get("args", {})
        raw_command = args.get("command")
        command = raw_command if isinstance(raw_command, str) else ""
        thread_id = self._get_thread_id(request)
        # ① input sanitisation — reject malformed input before regex analysis
        reject_reason = self._validate_input(command)
        if reject_reason:
            self._write_audit(thread_id, command, "block", truncate=True)
            logger.warning("[SandboxAudit] INVALID INPUT thread=%s reason=%s", thread_id, reject_reason)
            return command, thread_id, "block", reject_reason
        # ② classify command
        verdict = _classify_command(command)
        # ③ audit log
        self._write_audit(thread_id, command, verdict)
        if verdict == "block":
            logger.warning("[SandboxAudit] BLOCKED thread=%s cmd=%r", thread_id, command)
        elif verdict == "warn":
            logger.warning("[SandboxAudit] WARN (medium-risk) thread=%s cmd=%r", thread_id, command)
        return command, thread_id, verdict, None

    # ------------------------------------------------------------------
    # wrap_tool_call hooks
    # ------------------------------------------------------------------
    @override
    def wrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], ToolMessage | Command],
    ) -> ToolMessage | Command:
        if request.tool_call.get("name") != "bash":
            return handler(request)
        command, _, verdict, reject_reason = self._pre_process(request)
        if verdict == "block":
            reason = reject_reason or "security violation detected"
            return self._build_block_message(request, reason)
        result = handler(request)
        if verdict == "warn":
            result = self._append_warn_to_result(result, command)
        return result

    @override
    async def awrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
    ) -> ToolMessage | Command:
        if request.tool_call.get("name") != "bash":
            return await handler(request)
        command, _, verdict, reject_reason = self._pre_process(request)
        if verdict == "block":
            reason = reject_reason or "security violation detected"
            return self._build_block_message(request, reason)
        result = await handler(request)
        if verdict == "warn":
            result = self._append_warn_to_result(result, command)
        return result
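

# ---------------------------------------------------------------------------
# Usage sketch
# ---------------------------------------------------------------------------
# A minimal sketch of attaching this middleware to an agent, assuming the
# ``create_agent`` entry point from langchain and a ``bash_tool`` object
# defined elsewhere in the codebase (both are assumptions, not part of this
# module):
#
#     from langchain.agents import create_agent
#
#     agent = create_agent(
#         model="openai:gpt-4o",
#         tools=[bash_tool],
#         middleware=[SandboxAuditMiddleware()],
#     )
#
# Because the middleware checks ``request.tool_call.get("name") != "bash"``
# up front, every non-bash tool call is passed through to the handler
# untouched.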