feat(aio_sandbox): add extra_env parameter for thread_id injection in sandbox creation
This commit is contained in:
parent
d337e46868
commit
77801c03ff
|
|
@ -514,7 +514,7 @@ class AioSandboxProvider(SandboxProvider):
|
||||||
# that is actively serving a thread.
|
# that is actively serving a thread.
|
||||||
logger.warning(f"All {replicas} replica slots are in active use; creating sandbox {sandbox_id} beyond the soft limit")
|
logger.warning(f"All {replicas} replica slots are in active use; creating sandbox {sandbox_id} beyond the soft limit")
|
||||||
|
|
||||||
info = self._backend.create(thread_id, sandbox_id, extra_mounts=extra_mounts or None)
|
info = self._backend.create(thread_id, sandbox_id, extra_mounts=extra_mounts or None, extra_env={"THREAD_ID": thread_id} if thread_id else None)
|
||||||
|
|
||||||
# Wait for sandbox to be ready
|
# Wait for sandbox to be ready
|
||||||
if not wait_for_sandbox_ready(info.sandbox_url, timeout=60):
|
if not wait_for_sandbox_ready(info.sandbox_url, timeout=60):
|
||||||
|
|
|
||||||
|
|
@ -44,7 +44,7 @@ class SandboxBackend(ABC):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def create(self, thread_id: str, sandbox_id: str, extra_mounts: list[tuple[str, str, bool]] | None = None) -> SandboxInfo:
|
def create(self, thread_id: str, sandbox_id: str, extra_mounts: list[tuple[str, str, bool]] | None = None, extra_env: dict[str, str] | None = None) -> SandboxInfo:
|
||||||
"""Create/provision a new sandbox.
|
"""Create/provision a new sandbox.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -52,6 +52,9 @@ class SandboxBackend(ABC):
|
||||||
sandbox_id: Deterministic sandbox identifier.
|
sandbox_id: Deterministic sandbox identifier.
|
||||||
extra_mounts: Additional volume mounts as (host_path, container_path, read_only) tuples.
|
extra_mounts: Additional volume mounts as (host_path, container_path, read_only) tuples.
|
||||||
Ignored by backends that don't manage containers (e.g., remote).
|
Ignored by backends that don't manage containers (e.g., remote).
|
||||||
|
extra_env: Additional environment variables to inject at runtime (e.g. THREAD_ID).
|
||||||
|
These are merged after static config env vars, so runtime values override same-key static values.
|
||||||
|
Ignored by backends that don't manage containers (e.g., remote).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
SandboxInfo with connection details.
|
SandboxInfo with connection details.
|
||||||
|
|
|
||||||
|
|
@ -110,7 +110,7 @@ class LocalContainerBackend(SandboxBackend):
|
||||||
|
|
||||||
# ── SandboxBackend interface ──────────────────────────────────────────
|
# ── SandboxBackend interface ──────────────────────────────────────────
|
||||||
|
|
||||||
def create(self, thread_id: str, sandbox_id: str, extra_mounts: list[tuple[str, str, bool]] | None = None) -> SandboxInfo:
|
def create(self, thread_id: str, sandbox_id: str, extra_mounts: list[tuple[str, str, bool]] | None = None, extra_env: dict[str, str] | None = None) -> SandboxInfo:
|
||||||
"""Start a new container and return its connection info.
|
"""Start a new container and return its connection info.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -137,7 +137,7 @@ class LocalContainerBackend(SandboxBackend):
|
||||||
for _attempt in range(10):
|
for _attempt in range(10):
|
||||||
port = get_free_port(start_port=_next_start)
|
port = get_free_port(start_port=_next_start)
|
||||||
try:
|
try:
|
||||||
container_id = self._start_container(container_name, port, extra_mounts)
|
container_id = self._start_container(container_name, port, extra_mounts, extra_env=extra_env)
|
||||||
break
|
break
|
||||||
except RuntimeError as exc:
|
except RuntimeError as exc:
|
||||||
release_port(port)
|
release_port(port)
|
||||||
|
|
@ -229,6 +229,7 @@ class LocalContainerBackend(SandboxBackend):
|
||||||
container_name: str,
|
container_name: str,
|
||||||
port: int,
|
port: int,
|
||||||
extra_mounts: list[tuple[str, str, bool]] | None = None,
|
extra_mounts: list[tuple[str, str, bool]] | None = None,
|
||||||
|
extra_env: dict[str, str] | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Start a new container.
|
"""Start a new container.
|
||||||
|
|
||||||
|
|
@ -260,9 +261,11 @@ class LocalContainerBackend(SandboxBackend):
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
# Environment variables
|
# Environment variables (static config first, runtime overrides last)
|
||||||
for key, value in self._environment.items():
|
for key, value in self._environment.items():
|
||||||
cmd.extend(["-e", f"{key}={value}"])
|
cmd.extend(["-e", f"{key}={value}"])
|
||||||
|
for key, value in (extra_env or {}).items():
|
||||||
|
cmd.extend(["-e", f"{key}={value}"])
|
||||||
|
|
||||||
# Config-level volume mounts
|
# Config-level volume mounts
|
||||||
for mount in self._config_mounts:
|
for mount in self._config_mounts:
|
||||||
|
|
|
||||||
|
|
@ -60,6 +60,7 @@ class RemoteSandboxBackend(SandboxBackend):
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
sandbox_id: str,
|
sandbox_id: str,
|
||||||
extra_mounts: list[tuple[str, str, bool]] | None = None,
|
extra_mounts: list[tuple[str, str, bool]] | None = None,
|
||||||
|
extra_env: dict[str, str] | None = None,
|
||||||
) -> SandboxInfo:
|
) -> SandboxInfo:
|
||||||
"""Create a sandbox Pod + Service via the provisioner.
|
"""Create a sandbox Pod + Service via the provisioner.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,6 @@
|
||||||
from deerflow.community.aio_sandbox.local_backend import _format_container_mount
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from deerflow.community.aio_sandbox.local_backend import LocalContainerBackend, _format_container_mount
|
||||||
|
|
||||||
|
|
||||||
def test_format_container_mount_uses_mount_syntax_for_docker_windows_paths():
|
def test_format_container_mount_uses_mount_syntax_for_docker_windows_paths():
|
||||||
|
|
@ -26,3 +28,90 @@ def test_format_container_mount_keeps_volume_syntax_for_apple_container():
|
||||||
"-v",
|
"-v",
|
||||||
"/host/path:/mnt/path:ro",
|
"/host/path:/mnt/path:ro",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# ── extra_env injection ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _make_backend(runtime: str = "docker") -> LocalContainerBackend:
|
||||||
|
"""Build a minimal LocalContainerBackend without real config."""
|
||||||
|
backend = LocalContainerBackend.__new__(LocalContainerBackend)
|
||||||
|
backend._runtime = runtime
|
||||||
|
backend._container_prefix = "test"
|
||||||
|
backend._environment = {}
|
||||||
|
backend._config_mounts = []
|
||||||
|
backend._base_port = 9000
|
||||||
|
backend._image = "test-image:latest"
|
||||||
|
return backend
|
||||||
|
|
||||||
|
|
||||||
|
def test_start_container_injects_extra_env(monkeypatch):
|
||||||
|
"""_start_container must append -e KEY=VALUE for each extra_env entry."""
|
||||||
|
backend = _make_backend()
|
||||||
|
|
||||||
|
captured: list[list[str]] = []
|
||||||
|
|
||||||
|
def fake_run(cmd, **_kwargs):
|
||||||
|
captured.append(list(cmd))
|
||||||
|
result = MagicMock()
|
||||||
|
result.returncode = 0
|
||||||
|
result.stdout = "fake-container-id\n"
|
||||||
|
return result
|
||||||
|
|
||||||
|
monkeypatch.setattr("deerflow.community.aio_sandbox.local_backend.subprocess.run", fake_run)
|
||||||
|
|
||||||
|
backend._start_container("c", 9000, extra_env={"THREAD_ID": "thread-abc", "FOO": "bar"})
|
||||||
|
|
||||||
|
cmd = captured[0]
|
||||||
|
assert "-e" in cmd
|
||||||
|
env_pairs = {cmd[i + 1] for i in range(len(cmd)) if cmd[i] == "-e"}
|
||||||
|
assert "THREAD_ID=thread-abc" in env_pairs
|
||||||
|
assert "FOO=bar" in env_pairs
|
||||||
|
|
||||||
|
|
||||||
|
def test_start_container_no_extra_env_does_not_inject(monkeypatch):
|
||||||
|
"""_start_container with no extra_env must not add unexpected -e flags."""
|
||||||
|
backend = _make_backend()
|
||||||
|
|
||||||
|
captured: list[list[str]] = []
|
||||||
|
|
||||||
|
def fake_run(cmd, **_kwargs):
|
||||||
|
captured.append(list(cmd))
|
||||||
|
result = MagicMock()
|
||||||
|
result.returncode = 0
|
||||||
|
result.stdout = "fake-container-id\n"
|
||||||
|
return result
|
||||||
|
|
||||||
|
monkeypatch.setattr("deerflow.community.aio_sandbox.local_backend.subprocess.run", fake_run)
|
||||||
|
|
||||||
|
backend._start_container("c", 9000)
|
||||||
|
|
||||||
|
cmd = captured[0]
|
||||||
|
env_pairs = {cmd[i + 1] for i in range(len(cmd)) if cmd[i] == "-e"}
|
||||||
|
assert all("THREAD_ID" not in pair for pair in env_pairs)
|
||||||
|
|
||||||
|
|
||||||
|
def test_start_container_extra_env_overrides_static_env(monkeypatch):
|
||||||
|
"""Runtime extra_env values must appear after static env, effectively overriding same-key entries."""
|
||||||
|
backend = _make_backend()
|
||||||
|
backend._environment = {"MY_VAR": "static"}
|
||||||
|
|
||||||
|
captured: list[list[str]] = []
|
||||||
|
|
||||||
|
def fake_run(cmd, **_kwargs):
|
||||||
|
captured.append(list(cmd))
|
||||||
|
result = MagicMock()
|
||||||
|
result.returncode = 0
|
||||||
|
result.stdout = "fake-container-id\n"
|
||||||
|
return result
|
||||||
|
|
||||||
|
monkeypatch.setattr("deerflow.community.aio_sandbox.local_backend.subprocess.run", fake_run)
|
||||||
|
|
||||||
|
backend._start_container("c", 9000, extra_env={"MY_VAR": "runtime"})
|
||||||
|
|
||||||
|
cmd = captured[0]
|
||||||
|
env_pairs = [cmd[i + 1] for i in range(len(cmd)) if cmd[i] == "-e"]
|
||||||
|
# Both entries should be present; the runtime one comes after, which Docker respects
|
||||||
|
assert "MY_VAR=static" in env_pairs
|
||||||
|
assert "MY_VAR=runtime" in env_pairs
|
||||||
|
assert env_pairs.index("MY_VAR=runtime") > env_pairs.index("MY_VAR=static")
|
||||||
|
|
|
||||||
|
|
@ -134,3 +134,68 @@ def test_discover_or_create_only_unlocks_when_lock_succeeds(tmp_path, monkeypatc
|
||||||
provider._discover_or_create_with_lock("thread-5", "sandbox-5")
|
provider._discover_or_create_with_lock("thread-5", "sandbox-5")
|
||||||
|
|
||||||
assert unlock_calls == []
|
assert unlock_calls == []
|
||||||
|
|
||||||
|
|
||||||
|
# ── THREAD_ID env injection ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_sandbox_passes_thread_id_as_extra_env(tmp_path, monkeypatch):
|
||||||
|
"""_create_sandbox must pass extra_env={'THREAD_ID': thread_id} to backend.create."""
|
||||||
|
aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider")
|
||||||
|
monkeypatch.setattr(aio_mod, "get_paths", lambda: MagicMock())
|
||||||
|
monkeypatch.setattr(aio_mod.AioSandboxProvider, "_get_extra_mounts", lambda self, tid: [])
|
||||||
|
|
||||||
|
provider = _make_provider(tmp_path)
|
||||||
|
provider._config = {"replicas": 100}
|
||||||
|
provider._warm_pool = {}
|
||||||
|
provider._sandbox_infos = {}
|
||||||
|
provider._thread_sandboxes = {}
|
||||||
|
provider._thread_locks = {}
|
||||||
|
provider._last_activity = {}
|
||||||
|
|
||||||
|
fake_info = MagicMock()
|
||||||
|
fake_info.sandbox_url = "http://localhost:9999"
|
||||||
|
backend_mock = MagicMock()
|
||||||
|
backend_mock.create.return_value = fake_info
|
||||||
|
provider._backend = backend_mock
|
||||||
|
|
||||||
|
with patch.object(aio_mod, "wait_for_sandbox_ready", return_value=True):
|
||||||
|
provider._create_sandbox("thread-xyz", "sandbox-1")
|
||||||
|
|
||||||
|
backend_mock.create.assert_called_once_with(
|
||||||
|
"thread-xyz",
|
||||||
|
"sandbox-1",
|
||||||
|
extra_mounts=None,
|
||||||
|
extra_env={"THREAD_ID": "thread-xyz"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_sandbox_no_thread_id_passes_no_extra_env(tmp_path, monkeypatch):
|
||||||
|
"""_create_sandbox with thread_id=None must not inject THREAD_ID."""
|
||||||
|
aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider")
|
||||||
|
monkeypatch.setattr(aio_mod, "get_paths", lambda: MagicMock())
|
||||||
|
monkeypatch.setattr(aio_mod.AioSandboxProvider, "_get_extra_mounts", lambda self, tid: [])
|
||||||
|
|
||||||
|
provider = _make_provider(tmp_path)
|
||||||
|
provider._config = {"replicas": 100}
|
||||||
|
provider._warm_pool = {}
|
||||||
|
provider._sandbox_infos = {}
|
||||||
|
provider._thread_sandboxes = {}
|
||||||
|
provider._thread_locks = {}
|
||||||
|
provider._last_activity = {}
|
||||||
|
|
||||||
|
fake_info = MagicMock()
|
||||||
|
fake_info.sandbox_url = "http://localhost:9999"
|
||||||
|
backend_mock = MagicMock()
|
||||||
|
backend_mock.create.return_value = fake_info
|
||||||
|
provider._backend = backend_mock
|
||||||
|
|
||||||
|
with patch.object(aio_mod, "wait_for_sandbox_ready", return_value=True):
|
||||||
|
provider._create_sandbox(None, "sandbox-2")
|
||||||
|
|
||||||
|
backend_mock.create.assert_called_once_with(
|
||||||
|
None,
|
||||||
|
"sandbox-2",
|
||||||
|
extra_mounts=None,
|
||||||
|
extra_env=None,
|
||||||
|
)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue