"""Tests for AioSandbox concurrent command serialization (#1433)."""

import threading
import time
import uuid
from types import SimpleNamespace
from unittest.mock import MagicMock, patch

import pytest


@pytest.fixture()
def sandbox():
    """Create an AioSandbox with a mocked client.

    ``AioSandboxClient`` is patched for the duration of the constructor call,
    so building the sandbox uses the patched name instead of the real client
    class. Tests then swap in their own fake by assigning to
    ``sandbox._client.shell.exec_command``.

    Returns:
        The ``AioSandbox`` instance built while the patch was active.
    """
    with patch("deerflow.community.aio_sandbox.aio_sandbox.AioSandboxClient"):
        # Imported inside the patch context, next to its only use, so merely
        # collecting this test module does not import the sandbox package.
        from deerflow.community.aio_sandbox.aio_sandbox import AioSandbox

        sb = AioSandbox(id="test-sandbox", base_url="http://localhost:8080")
    return sb


class TestExecuteCommandSerialization:
    """Verify that concurrent exec_command calls are serialized."""

    def test_lock_prevents_concurrent_execution(self, sandbox):
        """Concurrent threads should not overlap inside execute_command.

        Three threads are released together by a barrier and each calls
        ``sandbox.execute_command``. The fake ``exec_command`` records an
        "enter" and an "exit" marker around a short sleep, so overlapping
        calls would show up as interleaved markers in ``call_log``.
        """
        worker_count = 3
        call_log = []
        barrier = threading.Barrier(worker_count)

        def slow_exec(command, **kwargs):
            call_log.append(("enter", command))
            # Stay "inside" the call long enough for another thread to slip
            # in if execute_command were not serialized.
            time.sleep(0.05)
            call_log.append(("exit", command))
            return SimpleNamespace(data=SimpleNamespace(output=f"ok: {command}"))

        sandbox._client.shell.exec_command = slow_exec

        def worker(cmd):
            # Ensure all threads contend for the lock simultaneously. The
            # timeout turns a stuck barrier into a failure instead of a hang.
            barrier.wait(timeout=5)
            sandbox.execute_command(cmd)

        # Daemon threads so a deadlocked worker cannot keep the test process
        # alive after the assertion below has reported the failure.
        threads = [
            threading.Thread(target=worker, args=(f"cmd-{i}",), daemon=True)
            for i in range(worker_count)
        ]
        for t in threads:
            t.start()
        for t in threads:
            t.join(timeout=10)
        assert not any(t.is_alive() for t in threads), "worker threads did not finish (possible deadlock)"

        # Verify serialization: each "enter" should be followed by its own
        # "exit" before the next "enter" (no interleaving).
        enters = [i for i, (action, _) in enumerate(call_log) if action == "enter"]
        exits = [i for i, (action, _) in enumerate(call_log) if action == "exit"]
        assert len(enters) == worker_count
        assert len(exits) == worker_count
        for e_idx, x_idx in zip(enters, exits):
            assert x_idx == e_idx + 1, f"Interleaved execution detected: {call_log}"
            # The adjacent pair must belong to the same command.
            assert call_log[e_idx][1] == call_log[x_idx][1], f"Mismatched enter/exit pair: {call_log}"


class TestErrorObservationRetry:
    """Verify ErrorObservation detection and fresh-session retry."""

    # Output payload the tests use to simulate a corrupted shell session.
    _ERROR_OUTPUT = "'ErrorObservation' object has no attribute 'exit_code'"

    @staticmethod
    def _reply(output):
        """Build a fake exec_command response exposing ``.data.output``."""
        return SimpleNamespace(data=SimpleNamespace(output=output))

    def test_retry_on_error_observation(self, sandbox):
        """When output contains ErrorObservation, retry with a fresh session."""
        call_count = 0

        def mock_exec(command, **kwargs):
            nonlocal call_count
            call_count += 1
            # Only the first attempt reports the ErrorObservation failure.
            if call_count == 1:
                return self._reply(self._ERROR_OUTPUT)
            return self._reply("success")

        sandbox._client.shell.exec_command = mock_exec

        result = sandbox.execute_command("echo hello")
        assert result == "success"
        assert call_count == 2

    def test_retry_passes_fresh_session_id(self, sandbox):
        """The retry call should include a new session id kwarg."""
        calls = []

        def mock_exec(command, **kwargs):
            calls.append(kwargs)
            if len(calls) == 1:
                return self._reply(self._ERROR_OUTPUT)
            return self._reply("ok")

        sandbox._client.shell.exec_command = mock_exec

        sandbox.execute_command("test")
        assert len(calls) == 2
        assert "id" not in calls[0]
        assert "id" in calls[1]
        session_id = calls[1]["id"]
        assert len(session_id) == 36  # UUID format
        # A bare length check accepts any 36-character string; parsing the
        # value raises ValueError unless it really is a UUID.
        uuid.UUID(session_id)

    def test_no_retry_on_clean_output(self, sandbox):
        """Normal output should not trigger a retry."""
        call_count = 0

        def mock_exec(command, **kwargs):
            nonlocal call_count
            call_count += 1
            return self._reply("all good")

        sandbox._client.shell.exec_command = mock_exec

        result = sandbox.execute_command("echo hello")
        assert result == "all good"
        assert call_count == 1


class TestListDirSerialization:
    """Verify that list_dir also acquires the lock."""

    def test_list_dir_uses_lock(self, sandbox):
        """list_dir should hold the lock during execution.

        The fake ``exec_command`` samples ``sandbox._lock.locked()`` at the
        moment it is invoked, which is the only point where the test can
        observe whether ``list_dir`` acquired the lock first.
        """
        lock_was_held = []

        # Canned response; the assertion below expects this output to come
        # back from list_dir as ["/a", "/b"].
        fake_exec = MagicMock(return_value=SimpleNamespace(data=SimpleNamespace(output="/a\n/b")))

        def tracking_exec(command, **kwargs):
            lock_was_held.append(sandbox._lock.locked())
            return fake_exec(command, **kwargs)

        sandbox._client.shell.exec_command = tracking_exec

        result = sandbox.list_dir("/test")
        assert result == ["/a", "/b"]
        assert lock_was_held == [True], "list_dir must hold the lock during exec_command"
        # A lock left held after the call would block every later command.
        assert not sandbox._lock.locked(), "list_dir must release the lock after exec_command"