import uuid from datetime import UTC, datetime from types import SimpleNamespace import pytest from app.api import chat_sessions as chat_sessions_api class DummyResult: def __init__(self, values=None, scalar_value=None): self._values = list(values or []) self._scalar_value = scalar_value def scalar_one_or_none(self): if self._values: return self._values[0] return self._scalar_value def scalars(self): return self def all(self): return list(self._values) def scalar(self): if self._scalar_value is not None: return self._scalar_value return self._values[0] if self._values else None class RecordingDB: def __init__(self, responses=None): self.responses = list(responses or []) self.added = [] self.committed = False self.refreshed = [] async def execute(self, _statement, _params=None): if not self.responses: raise AssertionError("unexpected execute() call") return self.responses.pop(0) def add(self, value): self.added.append(value) async def commit(self): self.committed = True async def refresh(self, value): self.refreshed.append(value) @pytest.mark.asyncio async def test_org_admin_can_list_all_sessions(monkeypatch): viewer_id = uuid.uuid4() agent_id = uuid.uuid4() owner_id = uuid.uuid4() now = datetime.now(UTC) current_user = SimpleNamespace(id=viewer_id, role="org_admin") agent = SimpleNamespace(id=agent_id, creator_id=uuid.uuid4()) session = SimpleNamespace( id=uuid.uuid4(), agent_id=agent_id, user_id=owner_id, source_channel="web", title="Customer follow-up", created_at=now, last_message_at=now, peer_agent_id=None, is_group=False, group_name=None, ) db = RecordingDB( responses=[ DummyResult([agent]), DummyResult([session]), DummyResult(scalar_value=3), DummyResult(scalar_value="Alice"), ] ) async def fake_check_agent_access(_db, _user, _agent_id): return agent, "use" monkeypatch.setattr(chat_sessions_api, "check_agent_access", fake_check_agent_access) sessions = await chat_sessions_api.list_sessions( agent_id=agent_id, scope="all", current_user=current_user, db=db, ) assert len(sessions) == 1 assert sessions[0].id == str(session.id) assert sessions[0].user_id == str(owner_id) assert sessions[0].username == "Alice" @pytest.mark.asyncio async def test_org_admin_can_view_other_users_session_messages(monkeypatch): viewer_id = uuid.uuid4() agent_id = uuid.uuid4() owner_id = uuid.uuid4() session_id = uuid.uuid4() now = datetime.now(UTC) current_user = SimpleNamespace(id=viewer_id, role="org_admin") session = SimpleNamespace( id=session_id, agent_id=agent_id, peer_agent_id=None, user_id=owner_id, source_channel="web", ) message = SimpleNamespace( role="user", content="hello", created_at=now, participant_id=None, ) db = RecordingDB( responses=[ DummyResult([session]), DummyResult([message]), ] ) async def fake_check_agent_access(_db, _user, _agent_id): return SimpleNamespace(id=agent_id), "use" monkeypatch.setattr(chat_sessions_api, "check_agent_access", fake_check_agent_access) messages = await chat_sessions_api.get_session_messages( agent_id=agent_id, session_id=session_id, current_user=current_user, db=db, ) assert messages == [ { "role": "user", "content": "hello", "created_at": now.isoformat(), } ] @pytest.mark.asyncio async def test_create_session_returns_web_session_shape(monkeypatch): user_id = uuid.uuid4() agent_id = uuid.uuid4() current_user = SimpleNamespace(id=user_id, role="member") db = RecordingDB() async def fake_check_agent_access(_db, _user, _agent_id): return SimpleNamespace(id=agent_id), "use" monkeypatch.setattr(chat_sessions_api, "check_agent_access", fake_check_agent_access) session = await chat_sessions_api.create_session( agent_id=agent_id, current_user=current_user, db=db, ) assert session.agent_id == str(agent_id) assert session.user_id == str(user_id) assert session.source_channel == "web" assert session.participant_type == "user" assert session.is_group is False assert db.committed is True assert len(db.added) == 1