"""Unit tests for the chat session API handlers in ``app.api.chat_sessions``."""
import uuid
|
|
from datetime import UTC, datetime
|
|
from types import SimpleNamespace
|
|
|
|
import pytest
|
|
|
|
from app.api import chat_sessions as chat_sessions_api
|
|
|
|
|
|
class DummyResult:
    """Canned query result handed out by ``RecordingDB.execute``.

    Holds a list of row values and/or a single scalar value and exposes the
    handful of accessor methods defined below.
    """

    def __init__(self, values=None, scalar_value=None) -> None:
        """Store the canned rows and scalar.

        Args:
            values: Iterable of row objects. ``None`` (or any falsy value)
                means "no rows". The items are copied into a new list.
            scalar_value: Value returned by the scalar accessors.
        """
        self._values = list(values or [])
        self._scalar_value = scalar_value

    def scalar_one_or_none(self):
        """Return the first row when rows exist, otherwise the scalar value."""
        if self._values:
            return self._values[0]
        return self._scalar_value

    def scalars(self) -> "DummyResult":
        """Return ``self`` so that ``result.scalars().all()`` can be chained."""
        return self

    def all(self) -> list:
        """Return a new list containing every stored row."""
        return list(self._values)

    def scalar(self):
        """Return the scalar value, falling back to the first row.

        The scalar value wins whenever it is not ``None`` (so ``0`` and ``""``
        are returned as-is). Otherwise the first row is returned, or ``None``
        when there are no rows.
        """
        if self._scalar_value is not None:
            return self._scalar_value
        return self._values[0] if self._values else None
|
|
|
|
|
|
class RecordingDB:
    """Fake async database session that replays queued results and records calls.

    Attributes are public so tests can inspect them after the call under test:
    ``responses`` (results not yet consumed), ``added``, ``committed`` and
    ``refreshed``.
    """

    def __init__(self, responses=None) -> None:
        """Queue the results to hand out.

        Args:
            responses: Results returned by successive ``execute`` calls, in
                order. ``None`` means no ``execute`` call is expected.
        """
        self.responses = list(responses or [])
        self.added = []
        self.committed = False
        self.refreshed = []

    async def execute(self, _statement, _params=None):
        """Return the next queued result; the statement and params are ignored.

        Raises:
            AssertionError: If the queue is empty, i.e. the code under test
                issued more queries than the test prepared results for.
        """
        if not self.responses:
            raise AssertionError("unexpected execute() call")
        # FIFO: results are consumed in the order the test listed them.
        return self.responses.pop(0)

    def add(self, value) -> None:
        """Record an object passed to ``add``."""
        self.added.append(value)

    async def commit(self) -> None:
        """Record that ``commit`` was awaited."""
        self.committed = True

    async def refresh(self, value) -> None:
        """Record an object passed to ``refresh``."""
        self.refreshed.append(value)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_org_admin_can_list_all_sessions(monkeypatch):
    """An ``org_admin`` listing with ``scope="all"`` gets another user's session.

    The viewer and the session owner are different users; the returned entry
    must carry the owner's id and the username supplied by the fake database.
    """
    viewer_id = uuid.uuid4()
    agent_id = uuid.uuid4()
    owner_id = uuid.uuid4()
    now = datetime.now(UTC)

    current_user = SimpleNamespace(id=viewer_id, role="org_admin")
    agent = SimpleNamespace(id=agent_id, creator_id=uuid.uuid4())
    session = SimpleNamespace(
        id=uuid.uuid4(),
        agent_id=agent_id,
        user_id=owner_id,
        source_channel="web",
        title="Customer follow-up",
        created_at=now,
        last_message_at=now,
        peer_agent_id=None,
        is_group=False,
        group_name=None,
    )
    # Results are replayed in order, one per execute() call.
    # NOTE(review): presumably agent lookup, session rows, message count and
    # owner username — confirm against the query order in list_sessions.
    db = RecordingDB(
        responses=[
            DummyResult([agent]),
            DummyResult([session]),
            DummyResult(scalar_value=3),
            DummyResult(scalar_value="Alice"),
        ]
    )

    async def fake_check_agent_access(_db, _user, _agent_id):
        # Stub the access check so it returns the prepared agent with "use".
        return agent, "use"

    monkeypatch.setattr(chat_sessions_api, "check_agent_access", fake_check_agent_access)

    sessions = await chat_sessions_api.list_sessions(
        agent_id=agent_id,
        scope="all",
        current_user=current_user,
        db=db,
    )

    assert len(sessions) == 1
    assert sessions[0].id == str(session.id)
    assert sessions[0].user_id == str(owner_id)
    assert sessions[0].username == "Alice"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_org_admin_can_view_other_users_session_messages(monkeypatch):
    """An ``org_admin`` can read the messages of a session owned by someone else.

    The session's ``user_id`` differs from the viewer's id; the call must
    return the stored message serialized as a plain dict.
    """
    viewer_id = uuid.uuid4()
    agent_id = uuid.uuid4()
    owner_id = uuid.uuid4()
    session_id = uuid.uuid4()
    now = datetime.now(UTC)

    current_user = SimpleNamespace(id=viewer_id, role="org_admin")
    session = SimpleNamespace(
        id=session_id,
        agent_id=agent_id,
        peer_agent_id=None,
        user_id=owner_id,
        source_channel="web",
    )
    message = SimpleNamespace(
        role="user",
        content="hello",
        created_at=now,
        participant_id=None,
    )
    # Results are replayed in order, one per execute() call.
    # NOTE(review): presumably the session lookup followed by the message
    # query — confirm against get_session_messages.
    db = RecordingDB(
        responses=[
            DummyResult([session]),
            DummyResult([message]),
        ]
    )

    async def fake_check_agent_access(_db, _user, _agent_id):
        # Stub the access check so it returns a bare agent object with "use".
        return SimpleNamespace(id=agent_id), "use"

    monkeypatch.setattr(chat_sessions_api, "check_agent_access", fake_check_agent_access)

    messages = await chat_sessions_api.get_session_messages(
        agent_id=agent_id,
        session_id=session_id,
        current_user=current_user,
        db=db,
    )

    assert messages == [
        {
            "role": "user",
            "content": "hello",
            "created_at": now.isoformat(),
        }
    ]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_create_session_returns_web_session_shape(monkeypatch):
    """``create_session`` returns a web, single-user session and persists it.

    Checks the returned fields and that exactly one object was added to the
    database and the transaction was committed.
    """
    user_id = uuid.uuid4()
    agent_id = uuid.uuid4()

    current_user = SimpleNamespace(id=user_id, role="member")
    # No queued results: any execute() call would raise AssertionError.
    db = RecordingDB()

    async def fake_check_agent_access(_db, _user, _agent_id):
        # Stub the access check so it returns a bare agent object with "use".
        return SimpleNamespace(id=agent_id), "use"

    monkeypatch.setattr(chat_sessions_api, "check_agent_access", fake_check_agent_access)

    session = await chat_sessions_api.create_session(
        agent_id=agent_id,
        current_user=current_user,
        db=db,
    )

    assert session.agent_id == str(agent_id)
    assert session.user_id == str(user_id)
    assert session.source_channel == "web"
    assert session.participant_type == "user"
    assert session.is_group is False
    assert db.committed is True
    assert len(db.added) == 1
|