feat(aio_sandbox): add extra_env parameter for thread_id injection in sandbox creation

This commit is contained in:
Titan 2026-04-17 18:38:02 +08:00
parent d337e46868
commit 77801c03ff
6 changed files with 167 additions and 6 deletions

View File

@ -514,7 +514,7 @@ class AioSandboxProvider(SandboxProvider):
# that is actively serving a thread. # that is actively serving a thread.
logger.warning(f"All {replicas} replica slots are in active use; creating sandbox {sandbox_id} beyond the soft limit") logger.warning(f"All {replicas} replica slots are in active use; creating sandbox {sandbox_id} beyond the soft limit")
info = self._backend.create(thread_id, sandbox_id, extra_mounts=extra_mounts or None) info = self._backend.create(thread_id, sandbox_id, extra_mounts=extra_mounts or None, extra_env={"THREAD_ID": thread_id} if thread_id else None)
# Wait for sandbox to be ready # Wait for sandbox to be ready
if not wait_for_sandbox_ready(info.sandbox_url, timeout=60): if not wait_for_sandbox_ready(info.sandbox_url, timeout=60):

View File

@ -44,7 +44,7 @@ class SandboxBackend(ABC):
""" """
@abstractmethod @abstractmethod
def create(self, thread_id: str, sandbox_id: str, extra_mounts: list[tuple[str, str, bool]] | None = None) -> SandboxInfo: def create(self, thread_id: str, sandbox_id: str, extra_mounts: list[tuple[str, str, bool]] | None = None, extra_env: dict[str, str] | None = None) -> SandboxInfo:
"""Create/provision a new sandbox. """Create/provision a new sandbox.
Args: Args:
@ -52,6 +52,9 @@ class SandboxBackend(ABC):
sandbox_id: Deterministic sandbox identifier. sandbox_id: Deterministic sandbox identifier.
extra_mounts: Additional volume mounts as (host_path, container_path, read_only) tuples. extra_mounts: Additional volume mounts as (host_path, container_path, read_only) tuples.
Ignored by backends that don't manage containers (e.g., remote). Ignored by backends that don't manage containers (e.g., remote).
extra_env: Additional environment variables to inject at runtime (e.g. THREAD_ID).
These are merged after static config env vars, so runtime values override same-key static values.
Ignored by backends that don't manage containers (e.g., remote).
Returns: Returns:
SandboxInfo with connection details. SandboxInfo with connection details.

View File

@ -110,7 +110,7 @@ class LocalContainerBackend(SandboxBackend):
# ── SandboxBackend interface ────────────────────────────────────────── # ── SandboxBackend interface ──────────────────────────────────────────
def create(self, thread_id: str, sandbox_id: str, extra_mounts: list[tuple[str, str, bool]] | None = None) -> SandboxInfo: def create(self, thread_id: str, sandbox_id: str, extra_mounts: list[tuple[str, str, bool]] | None = None, extra_env: dict[str, str] | None = None) -> SandboxInfo:
"""Start a new container and return its connection info. """Start a new container and return its connection info.
Args: Args:
@ -137,7 +137,7 @@ class LocalContainerBackend(SandboxBackend):
for _attempt in range(10): for _attempt in range(10):
port = get_free_port(start_port=_next_start) port = get_free_port(start_port=_next_start)
try: try:
container_id = self._start_container(container_name, port, extra_mounts) container_id = self._start_container(container_name, port, extra_mounts, extra_env=extra_env)
break break
except RuntimeError as exc: except RuntimeError as exc:
release_port(port) release_port(port)
@ -229,6 +229,7 @@ class LocalContainerBackend(SandboxBackend):
container_name: str, container_name: str,
port: int, port: int,
extra_mounts: list[tuple[str, str, bool]] | None = None, extra_mounts: list[tuple[str, str, bool]] | None = None,
extra_env: dict[str, str] | None = None,
) -> str: ) -> str:
"""Start a new container. """Start a new container.
@ -260,9 +261,11 @@ class LocalContainerBackend(SandboxBackend):
] ]
) )
# Environment variables # Environment variables (static config first, runtime overrides last)
for key, value in self._environment.items(): for key, value in self._environment.items():
cmd.extend(["-e", f"{key}={value}"]) cmd.extend(["-e", f"{key}={value}"])
for key, value in (extra_env or {}).items():
cmd.extend(["-e", f"{key}={value}"])
# Config-level volume mounts # Config-level volume mounts
for mount in self._config_mounts: for mount in self._config_mounts:

View File

@ -60,6 +60,7 @@ class RemoteSandboxBackend(SandboxBackend):
thread_id: str, thread_id: str,
sandbox_id: str, sandbox_id: str,
extra_mounts: list[tuple[str, str, bool]] | None = None, extra_mounts: list[tuple[str, str, bool]] | None = None,
extra_env: dict[str, str] | None = None,
) -> SandboxInfo: ) -> SandboxInfo:
"""Create a sandbox Pod + Service via the provisioner. """Create a sandbox Pod + Service via the provisioner.

View File

@ -1,4 +1,6 @@
from deerflow.community.aio_sandbox.local_backend import _format_container_mount from unittest.mock import MagicMock, patch
from deerflow.community.aio_sandbox.local_backend import LocalContainerBackend, _format_container_mount
def test_format_container_mount_uses_mount_syntax_for_docker_windows_paths(): def test_format_container_mount_uses_mount_syntax_for_docker_windows_paths():
@ -26,3 +28,90 @@ def test_format_container_mount_keeps_volume_syntax_for_apple_container():
"-v", "-v",
"/host/path:/mnt/path:ro", "/host/path:/mnt/path:ro",
] ]
# ── extra_env injection ──────────────────────────────────────────────────────
def _make_backend(runtime: str = "docker") -> LocalContainerBackend:
"""Build a minimal LocalContainerBackend without real config."""
backend = LocalContainerBackend.__new__(LocalContainerBackend)
backend._runtime = runtime
backend._container_prefix = "test"
backend._environment = {}
backend._config_mounts = []
backend._base_port = 9000
backend._image = "test-image:latest"
return backend
def test_start_container_injects_extra_env(monkeypatch):
"""_start_container must append -e KEY=VALUE for each extra_env entry."""
backend = _make_backend()
captured: list[list[str]] = []
def fake_run(cmd, **_kwargs):
captured.append(list(cmd))
result = MagicMock()
result.returncode = 0
result.stdout = "fake-container-id\n"
return result
monkeypatch.setattr("deerflow.community.aio_sandbox.local_backend.subprocess.run", fake_run)
backend._start_container("c", 9000, extra_env={"THREAD_ID": "thread-abc", "FOO": "bar"})
cmd = captured[0]
assert "-e" in cmd
env_pairs = {cmd[i + 1] for i in range(len(cmd)) if cmd[i] == "-e"}
assert "THREAD_ID=thread-abc" in env_pairs
assert "FOO=bar" in env_pairs
def test_start_container_no_extra_env_does_not_inject(monkeypatch):
"""_start_container with no extra_env must not add unexpected -e flags."""
backend = _make_backend()
captured: list[list[str]] = []
def fake_run(cmd, **_kwargs):
captured.append(list(cmd))
result = MagicMock()
result.returncode = 0
result.stdout = "fake-container-id\n"
return result
monkeypatch.setattr("deerflow.community.aio_sandbox.local_backend.subprocess.run", fake_run)
backend._start_container("c", 9000)
cmd = captured[0]
env_pairs = {cmd[i + 1] for i in range(len(cmd)) if cmd[i] == "-e"}
assert all("THREAD_ID" not in pair for pair in env_pairs)
def test_start_container_extra_env_overrides_static_env(monkeypatch):
"""Runtime extra_env values must appear after static env, effectively overriding same-key entries."""
backend = _make_backend()
backend._environment = {"MY_VAR": "static"}
captured: list[list[str]] = []
def fake_run(cmd, **_kwargs):
captured.append(list(cmd))
result = MagicMock()
result.returncode = 0
result.stdout = "fake-container-id\n"
return result
monkeypatch.setattr("deerflow.community.aio_sandbox.local_backend.subprocess.run", fake_run)
backend._start_container("c", 9000, extra_env={"MY_VAR": "runtime"})
cmd = captured[0]
env_pairs = [cmd[i + 1] for i in range(len(cmd)) if cmd[i] == "-e"]
# Both entries should be present; the runtime one comes after, which Docker respects
assert "MY_VAR=static" in env_pairs
assert "MY_VAR=runtime" in env_pairs
assert env_pairs.index("MY_VAR=runtime") > env_pairs.index("MY_VAR=static")

View File

@ -134,3 +134,68 @@ def test_discover_or_create_only_unlocks_when_lock_succeeds(tmp_path, monkeypatc
provider._discover_or_create_with_lock("thread-5", "sandbox-5") provider._discover_or_create_with_lock("thread-5", "sandbox-5")
assert unlock_calls == [] assert unlock_calls == []
# ── THREAD_ID env injection ──────────────────────────────────────────────────
def test_create_sandbox_passes_thread_id_as_extra_env(tmp_path, monkeypatch):
"""_create_sandbox must pass extra_env={'THREAD_ID': thread_id} to backend.create."""
aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider")
monkeypatch.setattr(aio_mod, "get_paths", lambda: MagicMock())
monkeypatch.setattr(aio_mod.AioSandboxProvider, "_get_extra_mounts", lambda self, tid: [])
provider = _make_provider(tmp_path)
provider._config = {"replicas": 100}
provider._warm_pool = {}
provider._sandbox_infos = {}
provider._thread_sandboxes = {}
provider._thread_locks = {}
provider._last_activity = {}
fake_info = MagicMock()
fake_info.sandbox_url = "http://localhost:9999"
backend_mock = MagicMock()
backend_mock.create.return_value = fake_info
provider._backend = backend_mock
with patch.object(aio_mod, "wait_for_sandbox_ready", return_value=True):
provider._create_sandbox("thread-xyz", "sandbox-1")
backend_mock.create.assert_called_once_with(
"thread-xyz",
"sandbox-1",
extra_mounts=None,
extra_env={"THREAD_ID": "thread-xyz"},
)
def test_create_sandbox_no_thread_id_passes_no_extra_env(tmp_path, monkeypatch):
"""_create_sandbox with thread_id=None must not inject THREAD_ID."""
aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider")
monkeypatch.setattr(aio_mod, "get_paths", lambda: MagicMock())
monkeypatch.setattr(aio_mod.AioSandboxProvider, "_get_extra_mounts", lambda self, tid: [])
provider = _make_provider(tmp_path)
provider._config = {"replicas": 100}
provider._warm_pool = {}
provider._sandbox_infos = {}
provider._thread_sandboxes = {}
provider._thread_locks = {}
provider._last_activity = {}
fake_info = MagicMock()
fake_info.sandbox_url = "http://localhost:9999"
backend_mock = MagicMock()
backend_mock.create.return_value = fake_info
provider._backend = backend_mock
with patch.object(aio_mod, "wait_for_sandbox_ready", return_value=True):
provider._create_sandbox(None, "sandbox-2")
backend_mock.create.assert_called_once_with(
None,
"sandbox-2",
extra_mounts=None,
extra_env=None,
)