# Clawith/backend/tests/test_password_reset_and_not... (370 lines, 11 KiB, Python)

import contextlib
import uuid
from datetime import datetime, timedelta, timezone
from types import SimpleNamespace
import pytest
from fastapi import HTTPException
from starlette.background import BackgroundTasks
from app.api import auth as auth_api
from app.api.notification import BroadcastRequest, broadcast_notification
from app.core.security import verify_password
from app.models.user import User
from app.schemas.schemas import ForgotPasswordRequest, ResetPasswordRequest
from app.services import password_reset_service, system_email_service
class DummyScalars:
    """Fake for the object returned by ``DummyResult.scalars()``.

    Presumably mirrors the ``.all()`` call the code under test makes on a
    SQLAlchemy scalar result.
    """

    def __init__(self, values):
        # Snapshot the iterable so later changes by the caller are not observed.
        self._values = [*values]

    def all(self):
        """Return a fresh list copy of the stored values."""
        return self._values.copy()
class DummyResult:
    """Canned result returned from ``RecordingDB.execute``.

    ``value`` backs ``scalar_one_or_none()``; ``values`` backs
    ``scalars().all()``. A falsy ``values`` is treated as "no rows".
    """

    def __init__(self, value=None, values=None):
        self._value = value
        self._values = list(values) if values else []

    def scalar_one_or_none(self):
        """Return the single configured value (``None`` by default)."""
        return self._value

    def scalars(self):
        """Wrap the configured row list in a ``DummyScalars``."""
        return DummyScalars(self._values)
class MockRedis:
    """In-memory async stand-in for the few Redis client calls these tests need.

    Supports ``get``, ``delete``, ``setex`` and ``pipeline``. Every ``delete``
    and ``setex`` call is recorded so tests can assert on the keys touched.
    ``pipeline()`` returns the mock itself, so commands issued "inside" a
    pipeline are applied immediately rather than buffered until ``execute()``.
    """

    def __init__(self, initial_data=None):
        # Copy the seed mapping so the mock never mutates a dict owned by the
        # caller; aliasing it would let setex/delete leak into test fixtures.
        self._data = dict(initial_data or {})
        # Keys passed to ``delete``, in call order.
        self.deleted = []
        # ``(key, ttl, value)`` tuples passed to ``setex``, in call order.
        self.setex_calls = []

    async def get(self, key):
        """Return the stored value for *key*, or ``None`` if it is absent."""
        return self._data.get(key)

    async def delete(self, key):
        """Record *key* as deleted and drop it from the store if present."""
        self.deleted.append(key)
        self._data.pop(key, None)

    async def setex(self, key, ttl, value):
        """Record the call and store *value*; *ttl* is recorded but never enforced."""
        self.setex_calls.append((key, ttl, value))
        self._data[key] = value

    def pipeline(self, transaction=True):
        """Return this mock as its own pipeline; *transaction* is ignored."""
        return self

    async def __aenter__(self):
        return self

    async def __aexit__(self, *_):
        # Returning None (falsy) lets exceptions raised in the block propagate.
        pass

    async def execute(self):
        """No-op: pipelined commands were already applied when issued."""
        pass
class RecordingDB:
    """Fake async DB session that records calls and replays queued results.

    ``execute`` hands back the queued ``responses`` in FIFO order and falls
    back to an empty ``DummyResult`` once the queue is exhausted.
    """

    def __init__(self, responses=None):
        self.responses = list(responses) if responses else []
        self.executed = []  # every statement passed to execute(), in order
        self.added = []  # every object passed to add(), in order
        self.flushed = False
        self.committed = False

    async def execute(self, statement):
        """Record *statement* and return the next queued result."""
        self.executed.append(statement)
        if not self.responses:
            return DummyResult()
        return self.responses.pop(0)

    def add(self, obj):
        """Record *obj* as added to the session."""
        self.added.append(obj)

    async def flush(self):
        """Mark the session as flushed."""
        self.flushed = True

    async def commit(self):
        """Mark the session as committed."""
        self.committed = True
def make_user(**overrides):
    """Build an in-memory ``User`` populated with test defaults.

    A fresh ``id`` and ``tenant_id`` are generated on every call; any keyword
    argument replaces the matching default field (or adds a new one).
    """
    defaults = dict(
        id=uuid.uuid4(),
        username="alice",
        email="alice@example.com",
        password_hash="old-hash",
        display_name="Alice",
        role="member",
        tenant_id=uuid.uuid4(),
        is_active=True,
    )
    return User(**{**defaults, **overrides})
@pytest.mark.asyncio
async def test_create_password_reset_token_invalidates_older_tokens(monkeypatch):
    """Issuing a new reset token must revoke the same user's previous token.

    The user's pointer key is seeded with an old token hash. After a new token
    is issued, the old token's key must have been deleted and the new token's
    key written via ``setex``.
    """
    monkeypatch.setattr(
        password_reset_service,
        "get_settings",
        lambda: SimpleNamespace(PASSWORD_RESET_TOKEN_EXPIRE_MINUTES=15, PUBLIC_BASE_URL=""),
    )
    user_id = uuid.uuid4()
    # Seed the pointer under the SAME user id the token is issued for. With a
    # mismatched key the service would never see an older token to revoke.
    mock_redis = MockRedis(initial_data={f"pwd_reset:user:{user_id}": "old-token-hash"})

    async def fake_get_redis():
        return mock_redis

    monkeypatch.setattr(password_reset_service, "get_redis", fake_get_redis)

    raw_token, expires_at = await password_reset_service.create_password_reset_token(user_id)

    # Verify old token invalidation
    assert "pwd_reset:token:old-token-hash" in mock_redis.deleted
    # Verify new token storage: two setex writes (matching the token key and
    # user pointer key that the consume path reads back).
    assert len(mock_redis.setex_calls) == 2
    new_token_key = f"pwd_reset:token:{password_reset_service._hash_token(raw_token)}"
    assert new_token_key in {key for key, _ttl, _value in mock_redis.setex_calls}
    # Verify raw token is long
    assert len(raw_token) >= 20
    assert expires_at > datetime.now(timezone.utc)
@pytest.mark.asyncio
async def test_build_password_reset_url_uses_env_public_base_url(monkeypatch):
    """PUBLIC_BASE_URL from settings drives the reset link without doubling its trailing slash."""
    monkeypatch.setattr(
        password_reset_service,
        "get_settings",
        lambda: SimpleNamespace(
            PASSWORD_RESET_TOKEN_EXPIRE_MINUTES=30,
            PUBLIC_BASE_URL="https://app.example.com/",
        ),
    )
    # Any DB lookup the service performs yields an empty result
    # (scalar_one_or_none() -> None).
    fake_session = RecordingDB([DummyResult(None)])

    reset_url = await password_reset_service.build_password_reset_url(fake_session, "abc123")

    assert reset_url == "https://app.example.com/reset-password?token=abc123"
@pytest.mark.asyncio
async def test_consume_password_reset_token_works_correctly(monkeypatch):
    """Consuming a valid token returns its user id and deletes both Redis keys."""
    user_id = uuid.uuid4()
    raw_token = "raw-token"
    token_hash = password_reset_service._hash_token(raw_token)
    token_key = f"pwd_reset:token:{token_hash}"
    user_key = f"pwd_reset:user:{user_id}"
    mock_redis = MockRedis(initial_data={token_key: str(user_id), user_key: token_hash})

    async def fake_get_redis():
        return mock_redis

    monkeypatch.setattr(password_reset_service, "get_redis", fake_get_redis)

    result = await password_reset_service.consume_password_reset_token(raw_token)

    assert result is not None
    assert result["user_id"] == user_id
    # Should be deleted after consumption: both the token key and the user pointer.
    assert token_key in mock_redis.deleted
    assert user_key in mock_redis.deleted
@pytest.mark.asyncio
async def test_forgot_password_returns_generic_response_for_unknown_email():
    """An unknown email gets the generic success response and queues no email.

    Presumably the response is identical for known and unknown addresses so
    callers cannot probe which emails are registered.
    """
    lookup_session = RecordingDB([DummyResult(None)])  # user lookup finds nobody
    tasks = BackgroundTasks()
    expected = {
        "ok": True,
        "message": "If an account with that email exists, a password reset email has been sent.",
    }

    response = await auth_api.forgot_password(
        ForgotPasswordRequest(email="missing@example.com"),
        tasks,
        lookup_session,
    )

    assert response == expected
    assert tasks.tasks == []
@pytest.mark.asyncio
async def test_forgot_password_queues_background_email(monkeypatch):
    """A known email commits the session and queues exactly one background task."""
    account = make_user()
    session = RecordingDB([DummyResult(account)])
    tasks = BackgroundTasks()

    async def stub_create_token(*_args, **_kwargs):
        expiry = datetime.now(timezone.utc) + timedelta(minutes=30)
        return "raw-token", expiry

    async def stub_build_url(*_args, **_kwargs):
        return "https://app.example.com/reset-password?token=raw-token"

    for attribute, replacement in (
        ("create_password_reset_token", stub_create_token),
        ("build_password_reset_url", stub_build_url),
    ):
        monkeypatch.setattr(password_reset_service, attribute, replacement)

    request = ForgotPasswordRequest(email=account.email)
    response = await auth_api.forgot_password(request, tasks, session)

    assert response["ok"] is True
    assert session.committed is True
    assert len(tasks.tasks) == 1
def test_send_system_email_uses_configured_timeout(monkeypatch):
    """The SMTP-over-SSL connection is opened with ``smtp_timeout_seconds`` from the config."""
    seen = {}

    class FakeSMTPSSL:
        """Minimal ``smtplib.SMTP_SSL`` replacement that records what it receives."""

        def __init__(self, host: str, port: int, context=None, timeout: int | None = None):
            seen.update(host=host, port=port, timeout=timeout)

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc, tb):
            return False  # never swallow exceptions raised inside the with-block

        def login(self, username: str, password: str):
            seen.update(username=username, password=password)

        def sendmail(self, from_address: str, to_addresses: list[str], message: str):
            seen["from"] = from_address
            seen["to"] = to_addresses
            seen["has_message"] = bool(message)

    smtp_config = system_email_service.SystemEmailConfig(
        from_address="bot@example.com",
        from_name="Clawith",
        smtp_host="smtp.example.com",
        smtp_port=465,
        smtp_username="bot@example.com",
        smtp_password="secret",
        smtp_ssl=True,
        smtp_timeout_seconds=27,
    )
    monkeypatch.setattr(system_email_service.smtplib, "SMTP_SSL", FakeSMTPSSL)
    # Neutralise the IPv4-forcing wrapper so no real socket configuration runs.
    monkeypatch.setattr(system_email_service, "force_ipv4", lambda: contextlib.nullcontext())

    system_email_service._send_email_with_config_sync(smtp_config, "alice@example.com", "subject", "body")

    assert seen["timeout"] == 27
    assert seen["to"] == ["alice@example.com"]
@pytest.mark.asyncio
async def test_reset_password_updates_user(monkeypatch):
    """A consumed token lets the user set a new password, which is stored hashed and flushed."""
    account = make_user(password_hash=auth_api.hash_password("old-password"))
    session = RecordingDB([DummyResult(account)])

    async def stub_consume(*_args, **_kwargs):
        return {"user_id": account.id}

    monkeypatch.setattr(password_reset_service, "consume_password_reset_token", stub_consume)
    payload = ResetPasswordRequest(token="t" * 20, new_password="new-password")

    response = await auth_api.reset_password(payload, session)

    assert response == {"ok": True}
    assert verify_password("new-password", account.password_hash)
    assert session.flushed is True
@pytest.mark.asyncio
async def test_broadcast_notification_rejects_missing_system_email_config(monkeypatch):
    """Requesting email delivery with no system email config yields HTTP 400."""
    admin = make_user(role="org_admin")

    async def no_email_config(db):
        return None

    monkeypatch.setattr(
        "app.services.system_email_service.resolve_email_config_async",
        no_email_config,
    )
    request = BroadcastRequest(title="Maintenance", body="Tonight", send_email=True)

    with pytest.raises(HTTPException) as excinfo:
        await broadcast_notification(
            request,
            background_tasks=BackgroundTasks(),
            current_user=admin,
            db=RecordingDB(),
        )

    error = excinfo.value
    assert error.status_code == 400
    assert "System email is not configured" in error.detail
@pytest.mark.asyncio
async def test_broadcast_notification_queues_email_delivery(monkeypatch):
    """One matching user yields one notification, one counted email and one queued task."""
    admin = make_user(role="org_admin")
    recipient = make_user(email="bob@example.com", tenant_id=admin.tenant_id)
    # Queued results for the service's queries, consumed in order.
    session = RecordingDB([
        DummyResult(values=[recipient]),
        DummyResult(values=[]),
    ])
    tasks = BackgroundTasks()

    async def configured_email(db):
        return system_email_service.SystemEmailConfig(
            from_address="bot@example.com",
            from_name="Clawith",
            smtp_host="smtp.example.com",
            smtp_port=465,
            smtp_username="bot@example.com",
            smtp_password="secret",
            smtp_ssl=True,
            smtp_timeout_seconds=15,
        )

    monkeypatch.setattr(
        "app.services.system_email_service.resolve_email_config_async",
        configured_email,
    )

    sent_notifications = []

    async def record_notification(*_args, **kwargs):
        sent_notifications.append(kwargs)

    monkeypatch.setattr("app.services.notification_service.send_notification", record_notification)

    response = await broadcast_notification(
        BroadcastRequest(title="Maintenance", body="Tonight", send_email=True),
        background_tasks=tasks,
        current_user=admin,
        db=session,
    )

    assert response["ok"] is True
    assert response["emails_sent"] == 1
    assert session.committed is True
    assert len(sent_notifications) == 1
    assert len(tasks.tasks) == 1
@pytest.mark.asyncio
async def test_deliver_broadcast_emails_continues_after_single_failure(monkeypatch):
    """A failing recipient must not stop delivery to the recipients after it."""
    from app.services.system_email_service import BroadcastEmailRecipient, deliver_broadcast_emails

    delivered = []

    async def flaky_send(email: str, subject: str, body: str) -> None:
        # Only the "bad" address fails; every other send is recorded.
        if email != "bad@example.com":
            delivered.append((email, subject, body))
            return
        raise RuntimeError("smtp down")

    monkeypatch.setattr("app.services.system_email_service.send_system_email", flaky_send)
    recipients = [
        BroadcastEmailRecipient(email="bad@example.com", subject="s1", body="b1"),
        BroadcastEmailRecipient(email="good@example.com", subject="s2", body="b2"),
    ]

    await deliver_broadcast_emails(recipients)

    assert delivered == [("good@example.com", "s2", "b2")]