# Clawith/backend/tests/test_chat_sessions_api.py

import uuid
from datetime import UTC, datetime
from types import SimpleNamespace
import pytest
from app.api import chat_sessions as chat_sessions_api
class DummyResult:
    """Canned result object handed back by ``RecordingDB.execute``.

    It mimics the SQLAlchemy-style result accessors the API code presumably
    calls (``scalar_one_or_none``, ``scalars().all()``, ``scalar``).

    ``values`` holds row objects and ``scalar_value`` holds a single value.
    The two accessors consult them in opposite order:

    - ``scalar_one_or_none`` prefers the first row.
    - ``scalar`` prefers ``scalar_value``.
    """

    def __init__(self, values=None, scalar_value=None):
        # Any falsy ``values`` (None, empty sequence, ...) means "no rows".
        self._values = list(values) if values else []
        self._scalar_value = scalar_value

    def _first_or(self, fallback):
        """Return the first stored row, or ``fallback`` when there are none."""
        return self._values[0] if self._values else fallback

    def scalar_one_or_none(self):
        """Return the first row if any, otherwise the configured scalar value."""
        return self._first_or(self._scalar_value)

    def scalars(self):
        """Return ``self`` so that ``result.scalars().all()`` chains work."""
        return self

    def all(self):
        """Return a fresh copy of the stored rows."""
        return self._values[:]

    def scalar(self):
        """Return ``scalar_value`` when set, else the first row, else ``None``."""
        if self._scalar_value is None:
            return self._first_or(None)
        return self._scalar_value
class RecordingDB:
    """Fake async DB session that replays scripted results and records writes.

    Each ``execute()`` call consumes the next entry of ``responses``, front to
    back, regardless of the statement passed. Once the script is exhausted,
    any further ``execute()`` raises ``AssertionError``, so extra queries fail
    the test.

    ``add``, ``commit`` and ``refresh`` only record what they were given, for
    later assertions.
    """

    def __init__(self, responses=None):
        # Copy the caller's list so popping never mutates the test's fixture.
        self.responses = list(responses) if responses else []
        self.added = []
        self.committed = False
        self.refreshed = []

    async def execute(self, _statement, _params=None):
        """Pop and return the next scripted result; statement and params are ignored."""
        if self.responses:
            return self.responses.pop(0)
        raise AssertionError("unexpected execute() call")

    def add(self, value):
        """Remember an object handed to ``add``."""
        self.added.append(value)

    async def commit(self):
        """Flag that ``commit`` was awaited."""
        self.committed = True

    async def refresh(self, value):
        """Remember an object handed to ``refresh``."""
        self.refreshed.append(value)
@pytest.mark.asyncio
async def test_org_admin_can_list_all_sessions(monkeypatch):
    """An org admin using ``scope="all"`` sees a session owned by another user.

    The fake DB replays four scripted results in order. The trailing "Alice"
    scalar is the value expected back as the session's ``username``.
    """
    agent_id = uuid.uuid4()
    owner_id = uuid.uuid4()
    timestamp = datetime.now(UTC)
    admin = SimpleNamespace(id=uuid.uuid4(), role="org_admin")
    agent = SimpleNamespace(id=agent_id, creator_id=uuid.uuid4())
    owned_session = SimpleNamespace(
        id=uuid.uuid4(),
        agent_id=agent_id,
        user_id=owner_id,
        source_channel="web",
        title="Customer follow-up",
        created_at=timestamp,
        last_message_at=timestamp,
        peer_agent_id=None,
        is_group=False,
        group_name=None,
    )
    scripted = [
        DummyResult([agent]),
        DummyResult([owned_session]),
        DummyResult(scalar_value=3),  # presumably a per-session count; confirm against list_sessions
        DummyResult(scalar_value="Alice"),  # owner's display name, asserted below
    ]
    db = RecordingDB(responses=scripted)

    async def fake_check_agent_access(_db, _user, _agent_id):
        # Replace the real access check: always return the agent with "use" access.
        return agent, "use"

    monkeypatch.setattr(chat_sessions_api, "check_agent_access", fake_check_agent_access)

    listed = await chat_sessions_api.list_sessions(
        agent_id=agent_id,
        scope="all",
        current_user=admin,
        db=db,
    )

    assert len(listed) == 1
    first = listed[0]
    assert first.id == str(owned_session.id)
    assert first.user_id == str(owner_id)
    assert first.username == "Alice"
@pytest.mark.asyncio
async def test_org_admin_can_view_other_users_session_messages(monkeypatch):
    """An org admin can read the messages of a session owned by someone else.

    The fake DB replays the session row first, then its single message.
    """
    agent_id, owner_id, session_id = uuid.uuid4(), uuid.uuid4(), uuid.uuid4()
    sent_at = datetime.now(UTC)
    admin = SimpleNamespace(id=uuid.uuid4(), role="org_admin")
    foreign_session = SimpleNamespace(
        id=session_id,
        agent_id=agent_id,
        peer_agent_id=None,
        user_id=owner_id,
        source_channel="web",
    )
    greeting = SimpleNamespace(
        role="user",
        content="hello",
        created_at=sent_at,
        participant_id=None,
    )
    db = RecordingDB(
        responses=[DummyResult([foreign_session]), DummyResult([greeting])]
    )

    async def fake_check_agent_access(_db, _user, _agent_id):
        # Replace the real access check with an unconditional "use" grant.
        return SimpleNamespace(id=agent_id), "use"

    monkeypatch.setattr(chat_sessions_api, "check_agent_access", fake_check_agent_access)

    messages = await chat_sessions_api.get_session_messages(
        agent_id=agent_id,
        session_id=session_id,
        current_user=admin,
        db=db,
    )

    expected = [
        {"role": "user", "content": "hello", "created_at": sent_at.isoformat()}
    ]
    assert messages == expected
@pytest.mark.asyncio
async def test_create_session_returns_web_session_shape(monkeypatch):
    """``create_session`` adds and commits one web session for the calling user."""
    member_id = uuid.uuid4()
    agent_id = uuid.uuid4()
    member = SimpleNamespace(id=member_id, role="member")
    # No scripted responses: any execute() issued by create_session fails the test.
    db = RecordingDB()

    async def fake_check_agent_access(_db, _user, _agent_id):
        # Replace the real access check with an unconditional "use" grant.
        return SimpleNamespace(id=agent_id), "use"

    monkeypatch.setattr(chat_sessions_api, "check_agent_access", fake_check_agent_access)

    created = await chat_sessions_api.create_session(
        agent_id=agent_id,
        current_user=member,
        db=db,
    )

    # Shape of the returned session.
    assert created.agent_id == str(agent_id)
    assert created.user_id == str(member_id)
    assert created.source_channel == "web"
    assert created.participant_type == "user"
    assert created.is_group is False
    # Persistence side effects recorded by the fake DB.
    assert db.committed is True
    assert len(db.added) == 1