Clawith/backend/app/services/quota_guard.py

249 lines
9.0 KiB
Python

"""Usage quota guard — check and enforce usage limits."""
import uuid
from datetime import datetime, timedelta, timezone
from sqlalchemy import select, func as sa_func
from app.database import async_session
class QuotaExceeded(Exception):
    """Signal that a usage quota has been exhausted.

    Attributes:
        message: Human-readable description of the exceeded limit.
        quota_type: Short tag naming the quota that was hit; this module
            raises it with "conversation", "agent_llm" and "max_agents".
            Defaults to "generic".
    """

    def __init__(self, message: str, quota_type: str = "generic"):
        super().__init__(message)
        self.quota_type = quota_type
        self.message = message
class AgentExpired(Exception):
    """Signal that the requested agent is past its expiry and unusable.

    Attributes:
        message: Human-readable explanation that names the agent.
    """

    def __init__(self, agent_name: str = ""):
        text = f"Agent '{agent_name}' has expired and is no longer available."
        super().__init__(text)
        self.message = text
# ── Conversation quota ──────────────────────────────────────────────
async def check_conversation_quota(user_id: uuid.UUID) -> None:
    """Ensure a user still has message quota left in the current period.

    Missing users and admins (role "platform_admin" or "org_admin") are never
    limited. When the period is not "permanent" and a start timestamp exists,
    a period whose full length (see ``_get_period_duration``) has elapsed is
    rolled over first: the used counter is zeroed, the start moves to the
    current UTC time, and the change is committed.

    Raises:
        QuotaExceeded: with ``quota_type="conversation"`` when
            ``quota_messages_used`` has reached ``quota_message_limit``.
    """
    from app.models.user import User

    async with async_session() as db:
        lookup = await db.execute(select(User).where(User.id == user_id))
        user = lookup.scalar_one_or_none()
        if not user or user.role in ("platform_admin", "org_admin"):
            return

        current = datetime.now(timezone.utc)
        tracks_periods = user.quota_message_period != "permanent"
        if tracks_periods and user.quota_period_start:
            window = _get_period_duration(user.quota_message_period)
            if current - user.quota_period_start >= window:
                # The period is over: start a fresh one with an empty counter.
                user.quota_messages_used = 0
                user.quota_period_start = current
                await db.commit()

        used = user.quota_messages_used
        limit = user.quota_message_limit
        if used < limit:
            return
        raise QuotaExceeded(
            f"Message quota exceeded ({used}/{limit}). "
            f"Period: {user.quota_message_period}.",
            quota_type="conversation",
        )
async def increment_conversation_usage(user_id: uuid.UUID) -> None:
    """Record one consumed message against the user's conversation quota.

    Missing users and admins (role "platform_admin" or "org_admin") are
    ignored. For non-"permanent" periods:

    * if no period has started yet, it starts now;
    * if the current period has already fully elapsed (see
      ``_get_period_duration``), it is rolled over here (counter zeroed,
      start moved to now) before counting.

    The second case mirrors ``increment_agent_llm_usage``. Without it, a
    message recorded after the period ended would be added to the stale
    counter, which the next ``check_conversation_quota`` then resets to 0,
    so that message would go uncounted.
    """
    from app.models.user import User

    async with async_session() as db:
        result = await db.execute(select(User).where(User.id == user_id))
        user = result.scalar_one_or_none()
        if not user:
            return
        if user.role in ("platform_admin", "org_admin"):
            return

        now = datetime.now(timezone.utc)
        if user.quota_message_period != "permanent":
            if not user.quota_period_start:
                # First tracked message: open the period now.
                user.quota_period_start = now
            elif now - user.quota_period_start >= _get_period_duration(user.quota_message_period):
                # Period lapsed since it was last checked: count this message
                # in a fresh period instead of the expired one.
                user.quota_messages_used = 0
                user.quota_period_start = now
        user.quota_messages_used += 1
        await db.commit()
# ── Agent expiry ────────────────────────────────────────────────────
async def check_agent_expired(agent_id: uuid.UUID) -> None:
    """Raise ``AgentExpired`` if the agent is, or has just become, expired.

    Missing agents pass silently. An agent already flagged ``is_expired``
    raises immediately. Otherwise, when ``expires_at`` is set and the current
    UTC time has reached it, the agent is flagged expired, its status set to
    "stopped" and heartbeat disabled, and the change is committed. Then
    ``AgentExpired`` is raised.
    """
    from app.models.agent import Agent

    async with async_session() as db:
        row = await db.execute(select(Agent).where(Agent.id == agent_id))
        agent = row.scalar_one_or_none()
        if not agent:
            return
        if agent.is_expired:
            raise AgentExpired(agent.name)

        deadline = agent.expires_at
        if not deadline or datetime.now(timezone.utc) < deadline:
            return

        # Deadline passed: persist the expiry before rejecting the call.
        agent.is_expired = True
        agent.status = "stopped"
        agent.heartbeat_enabled = False
        await db.commit()
        raise AgentExpired(agent.name)
async def get_agent_expiry_reply(agent_name: str) -> str:
    """Build the canned reply sent on behalf of an expired agent.

    Args:
        agent_name: Display name interpolated into the reply.

    Returns:
        A fixed apology that directs the user to the platform administrator.
    """
    reply = (
        f"I'm sorry, but I ({agent_name}) am currently unavailable. "
        "My service period has ended. "
        "Please contact the platform administrator for assistance."
    )
    return reply
# ── Agent LLM call quota ───────────────────────────────────────────
async def check_agent_llm_quota(agent_id: uuid.UUID) -> None:
    """Ensure the agent still has LLM calls left for today.

    Missing agents pass silently. If the date of ``llm_calls_reset_at`` is
    earlier than today's UTC date, the daily counter is zeroed, the reset
    timestamp moves to now, and the change is committed before the check.

    Raises:
        QuotaExceeded: with ``quota_type="agent_llm"`` once
            ``llm_calls_today`` has reached ``max_llm_calls_per_day``.
    """
    from app.models.agent import Agent

    async with async_session() as db:
        query = select(Agent).where(Agent.id == agent_id)
        agent = (await db.execute(query)).scalar_one_or_none()
        if not agent:
            return

        now = datetime.now(timezone.utc)
        last_reset = agent.llm_calls_reset_at
        if last_reset and now.date() > last_reset.date():
            # A new calendar day began since the last reset.
            agent.llm_calls_today = 0
            agent.llm_calls_reset_at = now
            await db.commit()

        calls = agent.llm_calls_today
        cap = agent.max_llm_calls_per_day
        if calls >= cap:
            raise QuotaExceeded(
                f"Agent '{agent.name}' has reached daily LLM call limit "
                f"({calls}/{cap}).",
                quota_type="agent_llm",
            )
async def increment_agent_llm_usage(agent_id: uuid.UUID) -> None:
    """Count one LLM call against the agent's daily budget.

    Missing agents are ignored. When there is no reset timestamp yet, or its
    date is before today's UTC date, the counter restarts at 1 and the reset
    timestamp becomes now; otherwise the counter is incremented. The session
    is committed whenever the agent exists.
    """
    from app.models.agent import Agent

    async with async_session() as db:
        found = await db.execute(select(Agent).where(Agent.id == agent_id))
        agent = found.scalar_one_or_none()
        if not agent:
            return

        now = datetime.now(timezone.utc)
        last_reset = agent.llm_calls_reset_at
        if last_reset and now.date() <= last_reset.date():
            # Still the same day as the last reset: keep accumulating.
            agent.llm_calls_today += 1
        else:
            agent.llm_calls_today = 1
            agent.llm_calls_reset_at = now
        await db.commit()
# ── Agent creation quota ───────────────────────────────────────────
async def check_agent_creation_quota(user_id: uuid.UUID) -> None:
    """Ensure the user is allowed to create another agent.

    Missing users and admins (role "platform_admin" or "org_admin") pass
    silently. Otherwise the agents whose ``creator_id`` is this user and that
    are not flagged ``is_expired`` are counted against ``quota_max_agents``.

    Raises:
        QuotaExceeded: with ``quota_type="max_agents"`` when that count has
            reached ``quota_max_agents``.
    """
    from app.models.user import User
    from app.models.agent import Agent

    async with async_session() as db:
        result = await db.execute(select(User).where(User.id == user_id))
        user = result.scalar_one_or_none()
        if not user:
            return
        if user.role in ("platform_admin", "org_admin"):
            return

        # Count the user's non-expired agents. `.is_(False)` is SQLAlchemy's
        # explicit boolean comparison and replaces the lint-flagged
        # `== False` (E712) form.
        count_result = await db.execute(
            select(sa_func.count())
            .select_from(Agent)
            .where(
                Agent.creator_id == user_id,
                Agent.is_expired.is_(False),
            )
        )
        current_count = count_result.scalar() or 0
        if current_count >= user.quota_max_agents:
            raise QuotaExceeded(
                f"Agent creation limit reached ({current_count}/{user.quota_max_agents}).",
                quota_type="max_agents",
            )
# ── Heartbeat floor enforcement ────────────────────────────────────
async def enforce_heartbeat_floor(tenant_id: uuid.UUID, floor: int | None = None, db=None) -> int:
    """Enforce heartbeat floor on all agents in the tenant.

    Every agent of the tenant whose ``heartbeat_interval_minutes`` is below
    the floor has it raised to exactly the floor.

    Args:
        tenant_id: The tenant to enforce for.
        floor: The minimum interval in minutes. If None, it is read from the
            tenant's ``min_heartbeat_interval_minutes``.
        db: Optional existing database session to reuse (avoids session
            isolation bugs). If any agent is adjusted, this session is
            committed.

    Returns:
        Number of agents adjusted. 0 when the tenant does not exist, when the
        tenant has no floor configured (None), or when no agent is below it.
    """
    from app.models.agent import Agent
    from app.models.tenant import Tenant

    async def _enforce(session, floor_val):
        # Fall back to the tenant-level setting when no explicit floor is given.
        if floor_val is None:
            result = await session.execute(select(Tenant).where(Tenant.id == tenant_id))
            tenant = result.scalar_one_or_none()
            if not tenant:
                return 0
            floor_val = tenant.min_heartbeat_interval_minutes
        # No floor configured: nothing to enforce. Without this guard the
        # `<` comparison below would be built against None, which SQLAlchemy
        # rejects (only =, != and IS-style operators accept None).
        if floor_val is None:
            return 0

        agents_result = await session.execute(
            select(Agent).where(
                Agent.tenant_id == tenant_id,
                Agent.heartbeat_interval_minutes < floor_val,
            )
        )
        agents = agents_result.scalars().all()
        for agent in agents:
            agent.heartbeat_interval_minutes = floor_val
        # Commit only when at least one agent was changed.
        if agents:
            await session.commit()
        return len(agents)

    if db is not None:
        return await _enforce(db, floor)
    async with async_session() as new_db:
        return await _enforce(new_db, floor)
# ── Helper ─────────────────────────────────────────────────────────
def _get_period_duration(period: str) -> timedelta:
"""Convert period string to timedelta."""
mapping = {
"daily": timedelta(days=1),
"weekly": timedelta(weeks=1),
"monthly": timedelta(days=30),
}
return mapping.get(period, timedelta(days=36500)) # permanent = ~100 years