Clawith/backend/app/api/enterprise.py

1654 lines
64 KiB
Python

"""Enterprise management API routes: LLM pool, enterprise info, approvals, audit logs."""
import uuid
import logging
logger = logging.getLogger(__name__)
from fastapi import APIRouter, Depends, HTTPException, status, BackgroundTasks, Request
from pydantic import BaseModel
from sqlalchemy import select, func, update
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.security import get_current_admin, get_current_user, require_role
from app.database import get_db
from app.models.org import OrgDepartment, OrgMember
from app.models.identity import IdentityProvider
from app.models.user import User
from app.models.agent import Agent
from app.models.llm import LLMModel
from app.models.audit import AuditLog, ApprovalRequest, EnterpriseInfo
from app.schemas.schemas import (
ApprovalAction, ApprovalRequestOut, AuditLogOut, EnterpriseInfoOut,
EnterpriseInfoUpdate, LLMModelCreate, LLMModelOut, LLMModelUpdate,
IdentityProviderOut, UserInviteRequest
)
from app.services.autonomy_service import autonomy_service
from app.services.enterprise_sync import enterprise_sync_service
from app.services.llm_utils import get_provider_manifest
from app.services.platform_service import platform_service
from app.services.sso_service import sso_service
router = APIRouter(prefix="/enterprise", tags=["enterprise"])
# ─── Public: Check Email Exists ────────────────────────
class CheckEmailRequest(BaseModel):
email: str
@router.post("/check-email-exists")
async def check_email_exists(
data: CheckEmailRequest,
db: AsyncSession = Depends(get_db),
):
"""Public endpoint — check if an email address is already registered on this platform.
Used by the invitation flow to decide whether to show the login or register form.
Only returns a boolean; does not expose any user data.
"""
from app.models.user import Identity
result = await db.execute(
select(Identity).where(Identity.email == data.email.strip().lower())
)
exists = result.scalar_one_or_none() is not None
return {"exists": exists}
@router.get("/llm-providers")
async def list_llm_providers(
current_user: User = Depends(get_current_user),
):
"""List supported LLM providers and capabilities from registry."""
return get_provider_manifest()
class LLMTestRequest(BaseModel):
provider: str
model: str
api_key: str | None = None
base_url: str | None = None
model_id: str | None = None # existing model ID to use stored API key
@router.post("/llm-test")
async def test_llm_model(
data: LLMTestRequest,
current_user: User = Depends(get_current_admin),
db: AsyncSession = Depends(get_db),
):
"""Test an LLM model configuration by making a simple API call."""
import time
from app.services.llm_client import create_llm_client
# Resolve API key: use provided key, or look up from stored model
api_key = data.api_key if data.api_key and not data.api_key.startswith('****') else None
if not api_key and data.model_id:
result = await db.execute(select(LLMModel).where(LLMModel.id == data.model_id))
existing = result.scalar_one_or_none()
if existing:
api_key = existing.api_key_encrypted
if not api_key:
return {"success": False, "latency_ms": 0, "error": "API Key is required"}
start = time.time()
try:
client = create_llm_client(
provider=data.provider,
model=data.model,
api_key=api_key,
base_url=data.base_url or None,
)
# Simple test: ask model to say "ok"
from app.services.llm_client import LLMMessage
response = await client.complete(
messages=[LLMMessage(role="user", content="Say 'ok' and nothing else.")],
max_tokens=16,
)
latency_ms = int((time.time() - start) * 1000)
reply = (response.content or "")[:100] if response else ""
return {"success": True, "latency_ms": latency_ms, "reply": reply}
except Exception as e:
latency_ms = int((time.time() - start) * 1000)
return {"success": False, "latency_ms": latency_ms, "error": str(e)[:500]}
@router.get("/llm-models", response_model=list[LLMModelOut])
async def list_llm_models(
tenant_id: str | None = None,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""List LLM models scoped to the selected tenant."""
# Authorization: non-platform admins can only see their own tenant's models
if tenant_id and current_user.role != "platform_admin":
if str(current_user.tenant_id) != tenant_id:
raise HTTPException(status_code=403, detail="Cannot access other tenant's models")
tid = tenant_id or str(current_user.tenant_id) if current_user.tenant_id else None
query = select(LLMModel).order_by(LLMModel.created_at.desc())
if tid:
query = query.where(LLMModel.tenant_id == uuid.UUID(tid))
result = await db.execute(query)
models = []
for m in result.scalars().all():
out = LLMModelOut.model_validate(m)
# Mask API key: show last 4 chars
key = m.api_key_encrypted or ""
out.api_key_masked = f"****{key[-4:]}" if len(key) > 4 else "****"
models.append(out)
return models
@router.post("/llm-models", response_model=LLMModelOut, status_code=status.HTTP_201_CREATED)
async def add_llm_model(
data: LLMModelCreate,
tenant_id: str | None = None,
current_user: User = Depends(get_current_admin),
db: AsyncSession = Depends(get_db),
):
"""Add a new LLM model to the tenant's pool (admin)."""
tid = tenant_id or (str(current_user.tenant_id) if current_user.tenant_id else None)
model = LLMModel(
provider=data.provider,
model=data.model,
api_key_encrypted=data.api_key, # TODO: encrypt
base_url=data.base_url,
label=data.label,
temperature=data.temperature,
max_tokens_per_day=data.max_tokens_per_day,
enabled=data.enabled,
supports_vision=data.supports_vision,
max_output_tokens=data.max_output_tokens,
request_timeout=data.request_timeout,
tenant_id=uuid.UUID(tid) if tid else None,
)
db.add(model)
await db.flush()
return LLMModelOut.model_validate(model)
@router.delete("/llm-models/{model_id}", status_code=status.HTTP_204_NO_CONTENT)
async def remove_llm_model(
model_id: uuid.UUID,
force: bool = False,
current_user: User = Depends(get_current_admin),
db: AsyncSession = Depends(get_db),
):
"""Remove an LLM model from the pool."""
result = await db.execute(select(LLMModel).where(LLMModel.id == model_id))
model = result.scalar_one_or_none()
if not model:
raise HTTPException(status_code=404, detail="Model not found")
# Check if any agents reference this model
from sqlalchemy import or_
ref_result = await db.execute(
select(Agent.name).where(
or_(Agent.primary_model_id == model_id, Agent.fallback_model_id == model_id)
)
)
agent_names = [row[0] for row in ref_result.all()]
if agent_names and not force:
raise HTTPException(
status_code=409,
detail={
"message": f"This model is used by {len(agent_names)} agent(s)",
"agents": agent_names,
},
)
# Nullify FK references in agents before deleting
if agent_names:
await db.execute(
update(Agent).where(Agent.primary_model_id == model_id).values(primary_model_id=None)
)
await db.execute(
update(Agent).where(Agent.fallback_model_id == model_id).values(fallback_model_id=None)
)
await db.delete(model)
await db.commit()
@router.put("/llm-models/{model_id}", response_model=LLMModelOut)
async def update_llm_model(
model_id: uuid.UUID,
data: LLMModelUpdate,
current_user: User = Depends(get_current_admin),
db: AsyncSession = Depends(get_db),
):
"""Update an existing LLM model in the pool (admin)."""
result = await db.execute(select(LLMModel).where(LLMModel.id == model_id))
model = result.scalar_one_or_none()
if not model:
raise HTTPException(status_code=404, detail="Model not found")
try:
if data.provider:
model.provider = data.provider
if data.model:
model.model = data.model
if data.label is not None:
model.label = data.label
if hasattr(data, 'base_url') and data.base_url is not None:
model.base_url = data.base_url
if data.api_key and data.api_key.strip() and not data.api_key.startswith('****'): # Skip masked values
model.api_key_encrypted = data.api_key.strip()
if data.temperature is not None:
model.temperature = data.temperature
if data.max_tokens_per_day is not None:
model.max_tokens_per_day = data.max_tokens_per_day
if data.enabled is not None:
model.enabled = data.enabled
if hasattr(data, 'supports_vision') and data.supports_vision is not None:
model.supports_vision = data.supports_vision
if hasattr(data, 'max_output_tokens') and data.max_output_tokens is not None:
model.max_output_tokens = data.max_output_tokens
if hasattr(data, 'request_timeout') and data.request_timeout is not None:
model.request_timeout = data.request_timeout
await db.commit()
await db.refresh(model)
return LLMModelOut.model_validate(model)
except SQLAlchemyError as e:
await db.rollback()
raise HTTPException(status_code=500, detail="Failed to update model")
# ─── Enterprise Info ────────────────────────────────────
@router.get("/info", response_model=list[EnterpriseInfoOut])
async def list_enterprise_info(
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""List all enterprise information entries."""
result = await db.execute(select(EnterpriseInfo).order_by(EnterpriseInfo.info_type))
return [EnterpriseInfoOut.model_validate(e) for e in result.scalars().all()]
@router.put("/info/{info_type}", response_model=EnterpriseInfoOut)
async def update_enterprise_info(
info_type: str,
data: EnterpriseInfoUpdate,
current_user: User = Depends(get_current_admin),
db: AsyncSession = Depends(get_db),
):
"""Create or update enterprise information. Triggers sync to agents."""
info = await enterprise_sync_service.update_enterprise_info(
db, info_type, data.content, data.visible_roles, current_user.id
)
# Sync to all running agents
await enterprise_sync_service.sync_to_all_agents(db)
return EnterpriseInfoOut.model_validate(info)
# ─── Approvals ──────────────────────────────────────────
@router.get("/approvals", response_model=list[ApprovalRequestOut])
async def list_approvals(
tenant_id: str | None = None,
status_filter: str | None = None,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""List approval requests scoped to a tenant."""
query = select(ApprovalRequest)
# Scope by tenant: only show approvals for agents belonging to this tenant
tid = tenant_id or (str(current_user.tenant_id) if current_user.tenant_id else None)
if tid:
tenant_agent_ids = select(Agent.id).where(Agent.tenant_id == tid)
query = query.where(ApprovalRequest.agent_id.in_(tenant_agent_ids))
# Non-admins further restricted to their own agents
if current_user.role != "platform_admin":
query = query.where(ApprovalRequest.agent_id.in_(
select(Agent.id).where(Agent.creator_id == current_user.id)
))
if status_filter:
query = query.where(ApprovalRequest.status == status_filter)
query = query.order_by(ApprovalRequest.created_at.desc())
result = await db.execute(query)
approvals = result.scalars().all()
# Batch-load agent names
agent_ids_set = {a.agent_id for a in approvals}
agent_names: dict[uuid.UUID, str] = {}
if agent_ids_set:
agents_r = await db.execute(select(Agent.id, Agent.name).where(Agent.id.in_(agent_ids_set)))
agent_names = {row.id: row.name for row in agents_r.all()}
out = []
for a in approvals:
d = ApprovalRequestOut.model_validate(a)
d.agent_name = agent_names.get(a.agent_id)
out.append(d)
return out
@router.post("/approvals/{approval_id}/resolve", response_model=ApprovalRequestOut)
async def resolve_approval(
approval_id: uuid.UUID,
data: ApprovalAction,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""Approve or reject a pending approval request."""
try:
approval = await autonomy_service.resolve_approval(
db, approval_id, current_user, data.action
)
return ApprovalRequestOut.model_validate(approval)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
# ─── Audit Logs ─────────────────────────────────────────
@router.get("/audit-logs", response_model=list[AuditLogOut])
async def list_audit_logs(
agent_id: uuid.UUID | None = None,
tenant_id: str | None = None,
limit: int = 50,
current_user: User = Depends(get_current_admin),
db: AsyncSession = Depends(get_db),
):
"""List audit logs scoped to a tenant (admin only)."""
query = select(AuditLog).order_by(AuditLog.created_at.desc()).limit(limit)
# Scope by tenant: only show logs for agents belonging to this tenant
tid = tenant_id or (str(current_user.tenant_id) if current_user.tenant_id else None)
if tid:
tenant_agent_ids = select(Agent.id).where(Agent.tenant_id == tid)
query = query.where(AuditLog.agent_id.in_(tenant_agent_ids))
if agent_id:
query = query.where(AuditLog.agent_id == agent_id)
result = await db.execute(query)
return [AuditLogOut.model_validate(log) for log in result.scalars().all()]
# ─── Dashboard Stats ────────────────────────────────────
@router.get("/stats")
async def get_enterprise_stats(
tenant_id: str | None = None,
current_user: User = Depends(get_current_admin),
db: AsyncSession = Depends(get_db),
):
"""Get enterprise dashboard statistics, optionally scoped to a tenant."""
# Determine which tenant to filter by
tid = tenant_id
if tid and isinstance(tid, str):
tid = uuid.UUID(tid)
elif not tid:
tid = current_user.tenant_id
# Base queries
agent_q = select(func.count(Agent.id))
user_q = select(func.count(User.id)).where(User.is_active == True)
approval_q = select(func.count(ApprovalRequest.id))
if tid:
agent_q = agent_q.where(Agent.tenant_id == tid)
user_q = user_q.where(User.tenant_id == tid)
# For approvals, we only see requests for agents in this tenant
approval_q = approval_q.where(ApprovalRequest.agent_id.in_(
select(Agent.id).where(Agent.tenant_id == tid)
))
total_agents = await db.execute(agent_q)
running_agents = await db.execute(
agent_q.where(Agent.status == "running")
)
total_users = await db.execute(user_q)
pending_approvals = await db.execute(
approval_q.where(ApprovalRequest.status == "pending")
)
return {
"total_agents": total_agents.scalar() or 0,
"running_agents": running_agents.scalar() or 0,
"total_users": total_users.scalar() or 0,
"pending_approvals": pending_approvals.scalar() or 0,
}
# ─── Tenant Quota Settings ──────────────────────────────
from app.models.tenant import Tenant
class TenantQuotaUpdate(BaseModel):
default_message_limit: int | None = None
default_message_period: str | None = None
default_max_agents: int | None = None
default_agent_ttl_hours: int | None = None
default_max_llm_calls_per_day: int | None = None
min_heartbeat_interval_minutes: int | None = None
default_max_triggers: int | None = None
min_poll_interval_floor: int | None = None
max_webhook_rate_ceiling: int | None = None
@router.get("/tenant-quotas")
async def get_tenant_quotas(
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""Get tenant quota defaults and heartbeat settings."""
if not current_user.tenant_id:
return {}
result = await db.execute(select(Tenant).where(Tenant.id == current_user.tenant_id))
tenant = result.scalar_one_or_none()
if not tenant:
return {}
return {
"default_message_limit": tenant.default_message_limit,
"default_message_period": tenant.default_message_period,
"default_max_agents": tenant.default_max_agents,
"default_agent_ttl_hours": tenant.default_agent_ttl_hours,
"default_max_llm_calls_per_day": tenant.default_max_llm_calls_per_day,
"min_heartbeat_interval_minutes": tenant.min_heartbeat_interval_minutes,
"default_max_triggers": tenant.default_max_triggers,
"min_poll_interval_floor": tenant.min_poll_interval_floor,
"max_webhook_rate_ceiling": tenant.max_webhook_rate_ceiling,
}
@router.patch("/tenant-quotas")
async def update_tenant_quotas(
data: TenantQuotaUpdate,
current_user: User = Depends(get_current_admin),
db: AsyncSession = Depends(get_db),
):
"""Update tenant quota defaults (admin only). Enforces heartbeat floor on existing agents."""
if not current_user.tenant_id:
raise HTTPException(status_code=400, detail="No tenant assigned")
result = await db.execute(select(Tenant).where(Tenant.id == current_user.tenant_id))
tenant = result.scalar_one_or_none()
if not tenant:
raise HTTPException(status_code=404, detail="Tenant not found")
if data.default_message_limit is not None:
tenant.default_message_limit = data.default_message_limit
if data.default_message_period is not None:
tenant.default_message_period = data.default_message_period
if data.default_max_agents is not None:
tenant.default_max_agents = data.default_max_agents
if data.default_agent_ttl_hours is not None:
tenant.default_agent_ttl_hours = data.default_agent_ttl_hours
if data.default_max_llm_calls_per_day is not None:
tenant.default_max_llm_calls_per_day = data.default_max_llm_calls_per_day
# Handle heartbeat floor — enforce on existing agents
adjusted_count = 0
if data.min_heartbeat_interval_minutes is not None:
tenant.min_heartbeat_interval_minutes = data.min_heartbeat_interval_minutes
from app.services.quota_guard import enforce_heartbeat_floor
adjusted_count = await enforce_heartbeat_floor(
tenant.id, floor=data.min_heartbeat_interval_minutes, db=db
)
# Handle trigger limit fields
if data.default_max_triggers is not None:
tenant.default_max_triggers = data.default_max_triggers
if data.min_poll_interval_floor is not None:
tenant.min_poll_interval_floor = data.min_poll_interval_floor
if data.max_webhook_rate_ceiling is not None:
tenant.max_webhook_rate_ceiling = data.max_webhook_rate_ceiling
await db.commit()
return {
"message": "Tenant quotas updated",
"heartbeat_agents_adjusted": adjusted_count,
}
# ── System Email: Test & Templates ──────────────────────
class TestEmailRequest(BaseModel):
email: str
@router.post("/system-email/test")
async def send_test_email_endpoint(
data: TestEmailRequest,
current_user: User = Depends(get_current_admin),
db: AsyncSession = Depends(get_db),
):
"""Send a test email to verify SMTP configuration (admin only)."""
from app.services.system_email_service import send_test_email
try:
await send_test_email(data.email, db=db)
return {"success": True, "message": f"Test email sent to {data.email}"}
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
@router.get("/email-templates")
async def get_email_templates_endpoint(
current_user: User = Depends(get_current_admin),
db: AsyncSession = Depends(get_db),
):
"""Get email templates (current values + available variables per scenario)."""
from app.services.system_email_service import (
get_email_templates,
EMAIL_TEMPLATE_VARIABLES,
DEFAULT_EMAIL_TEMPLATES,
)
templates = await get_email_templates(db=db)
return {
"templates": templates,
"variables": EMAIL_TEMPLATE_VARIABLES,
"defaults": DEFAULT_EMAIL_TEMPLATES,
}
class EmailTemplatesUpdate(BaseModel):
templates: dict
@router.put("/email-templates")
async def update_email_templates_endpoint(
data: EmailTemplatesUpdate,
current_user: User = Depends(get_current_admin),
db: AsyncSession = Depends(get_db),
):
"""Save email templates (admin only)."""
from app.services.system_email_service import EMAIL_TEMPLATE_VARIABLES
# Validate that only known scenario keys are provided
for key in data.templates:
if key not in EMAIL_TEMPLATE_VARIABLES:
raise HTTPException(
status_code=400,
detail=f"Unknown email template scenario: {key}"
)
result = await db.execute(
select(SystemSetting).where(SystemSetting.key == "email_templates")
)
setting = result.scalar_one_or_none()
if setting:
setting.value = data.templates
else:
setting = SystemSetting(key="email_templates", value=data.templates)
db.add(setting)
await db.commit()
return {"success": True, "message": "Email templates saved"}
# ─── System Settings ───────────────────────────────────
from app.models.system_settings import SystemSetting
class SettingUpdate(BaseModel):
value: dict
@router.get("/system-settings/notification_bar/public")
async def get_notification_bar_public(
db: AsyncSession = Depends(get_db),
):
"""Public (no auth) endpoint to read the notification bar config."""
result = await db.execute(
select(SystemSetting).where(SystemSetting.key == "notification_bar")
)
setting = result.scalar_one_or_none()
if not setting or not setting.value:
return {"enabled": False, "text": ""}
return {
"enabled": setting.value.get("enabled", False),
"text": setting.value.get("text", ""),
}
@router.get("/system-settings/{key}")
async def get_system_setting(
key: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""Get a system setting by key."""
result = await db.execute(select(SystemSetting).where(SystemSetting.key == key))
setting = result.scalar_one_or_none()
if not setting:
return {"key": key, "value": {}}
return {"key": setting.key, "value": setting.value, "updated_at": setting.updated_at.isoformat() if setting.updated_at else None}
@router.put("/system-settings/{key}")
async def update_system_setting(
key: str,
data: SettingUpdate,
current_user: User = Depends(get_current_admin),
db: AsyncSession = Depends(get_db),
):
"""Create or update a system setting."""
# Platform-level settings (e.g. PUBLIC_BASE_URL) require platform_admin
if key == "platform" and current_user.role != "platform_admin":
raise HTTPException(status_code=403, detail="Only platform admin can modify platform settings")
result = await db.execute(select(SystemSetting).where(SystemSetting.key == key))
setting = result.scalar_one_or_none()
if setting:
setting.value = data.value
else:
setting = SystemSetting(key=key, value=data.value)
db.add(setting)
await db.commit()
# When public_base_url changes, regenerate sso_domain for all SSO-enabled tenants
if key == "platform" and data.value.get("public_base_url"):
await _regenerate_all_sso_domains(db)
return {"key": setting.key, "value": setting.value}
# ─── SSO Derived State Helper ───────────────────────────
async def _sync_tenant_sso_state(db: AsyncSession, tenant_id: uuid.UUID):
"""Recompute tenant.sso_enabled based on channel-level sso_login_enabled flags.
When any identity provider has sso_login_enabled=True, the tenant's
sso_enabled is set to True and sso_domain is auto-assigned if empty.
When all providers have sso_login_enabled=False, sso_enabled becomes False
but sso_domain is preserved for potential re-enablement.
Raises HTTPException(400) if IP mode and another tenant already owns the sso_domain.
"""
from app.models.tenant import Tenant
count_result = await db.execute(
select(func.count(IdentityProvider.id)).where(
IdentityProvider.tenant_id == tenant_id,
IdentityProvider.sso_login_enabled == True,
IdentityProvider.is_active == True,
)
)
active_sso_count = count_result.scalar() or 0
tenant_result = await db.execute(select(Tenant).where(Tenant.id == tenant_id))
tenant = tenant_result.scalar_one_or_none()
if not tenant:
return
tenant.sso_enabled = active_sso_count > 0
# Auto-assign subdomain on first SSO enablement based on Platform rules
if tenant.sso_enabled and not tenant.sso_domain:
sso_base = await platform_service.get_tenant_sso_base_url(db, tenant)
host = sso_base.split("://")[-1].split(":")[0].split("/")[0]
is_ip = platform_service.is_ip_address(host)
if is_ip:
# IP mode: first clear ALL other tenants' sso_domain, then set for this tenant
# (unique constraint - only one tenant can hold the IP domain)
await db.execute(
update(Tenant)
.where(Tenant.id != tenant_id)
.values(sso_domain=None, sso_enabled=False)
)
logger.info(f"[SSO] IP mode: cleared sso_domain for all other tenants, setting for tenant_id={tenant_id}")
tenant.sso_domain = sso_base
await db.commit()
async def _regenerate_all_sso_domains(db: AsyncSession):
"""Regenerate sso_domain for ALL tenants when public_base_url changes.
- Domain mode: every tenant gets {slug}.{domain}, regardless of SSO status.
- IP mode: only ONE tenant can hold the IP domain (unique constraint).
The first SSO-enabled tenant keeps it; all others get sso_domain=None.
If no SSO-enabled tenant exists, the first tenant in the list gets it.
"""
base_url = await platform_service.get_public_base_url(db)
host = base_url.split("://")[-1].split(":")[0].split("/")[0]
is_ip = platform_service.is_ip_address(host)
# Fetch all tenants; put SSO-enabled ones first so they win the IP slot
all_tenants_result = await db.execute(
select(Tenant).order_by(Tenant.sso_enabled.desc(), Tenant.created_at.asc())
)
tenants = all_tenants_result.scalars().all()
for i, tenant in enumerate(tenants):
if is_ip:
# IP mode: only one tenant can have SSO domain
if i == 0:
sso_base = await platform_service.get_tenant_sso_base_url(db, tenant)
tenant.sso_domain = sso_base
else:
tenant.sso_domain = None
else:
# Domain mode: each tenant gets their own subdomain
sso_base = await platform_service.get_tenant_sso_base_url(db, tenant)
tenant.sso_domain = sso_base
logger.info(f"[SSO regen] tenant={tenant.slug} sso_domain={tenant.sso_domain}")
if tenants:
await db.commit()
# ─── Identity Providers ─────────────────────────────────
@router.get("/identity-providers", response_model=list[IdentityProviderOut])
async def list_identity_providers(
tenant_id: str | None = None,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""List identity providers configured for the tenant."""
# Authorization: non-platform admins can only see their own tenant's providers
if tenant_id and current_user.role != "platform_admin":
if str(current_user.tenant_id) != tenant_id:
raise HTTPException(status_code=403, detail="Cannot access other tenant's providers")
query = select(IdentityProvider).order_by(IdentityProvider.created_at.desc())
tid = tenant_id or (str(current_user.tenant_id) if current_user.tenant_id else None)
# Require tenant context
if not tid:
if current_user.role == "platform_admin":
# Admin without tenant_id filter sees all
pass
else:
raise HTTPException(status_code=400, detail="tenant_id is required for identity providers")
else:
import uuid as _uuid
query = query.where(IdentityProvider.tenant_id == _uuid.UUID(tid))
result = await db.execute(query)
providers = []
for p in result.scalars().all():
data = IdentityProviderOut.model_validate(p).model_dump()
data["last_synced_at"] = (p.config or {}).get("last_synced_at")
providers.append(data)
return providers
class IdentityProviderCreate(BaseModel):
provider_type: str
name: str
is_active: bool = True
sso_login_enabled: bool = False
config: dict = {}
tenant_id: uuid.UUID | None = None
class OAuth2Config(BaseModel):
"""OAuth2 provider configuration with friendly field names."""
app_id: str | None = None # Alias for client_id
app_secret: str | None = None # Alias for client_secret
authorize_url: str | None = None # OAuth2 authorize endpoint
token_url: str | None = None # OAuth2 token endpoint
user_info_url: str | None = None # OAuth2 user info endpoint
scope: str | None = "openid profile email"
def to_config_dict(self) -> dict:
"""Convert to config dict with both naming conventions for compatibility."""
config = {}
if self.app_id:
config["app_id"] = self.app_id
config["client_id"] = self.app_id
if self.app_secret:
config["app_secret"] = self.app_secret
config["client_secret"] = self.app_secret
if self.authorize_url:
config["authorize_url"] = self.authorize_url
if self.token_url:
config["token_url"] = self.token_url
if self.user_info_url:
config["user_info_url"] = self.user_info_url
if self.scope:
config["scope"] = self.scope
return config
@classmethod
def from_config_dict(cls, config: dict) -> "OAuth2Config":
"""Create from config dict, supporting both naming conventions."""
return cls(
app_id=config.get("app_id") or config.get("client_id"),
app_secret=config.get("app_secret") or config.get("client_secret"),
authorize_url=config.get("authorize_url"),
token_url=config.get("token_url"),
user_info_url=config.get("user_info_url"),
scope=config.get("scope"),
)
class IdentityProviderOAuth2Create(BaseModel):
"""Simplified OAuth2 provider creation with dedicated fields."""
provider_type: str = "oauth2"
name: str
is_active: bool = True
app_id: str
app_secret: str
authorize_url: str
token_url: str
user_info_url: str
scope: str | None = "openid profile email"
tenant_id: uuid.UUID | None = None
def normalize_oauth2_config(config: dict) -> dict:
"""Normalize OAuth2 config to use both naming conventions for compatibility."""
if "app_id" in config or "app_secret" in config or "authorize_url" in config:
# Mix of naming conventions - normalize
normalized = {}
if "app_id" in config:
normalized["app_id"] = config["app_id"]
normalized["client_id"] = config["app_id"]
elif "client_id" in config:
normalized["app_id"] = config["client_id"]
normalized["client_id"] = config["client_id"]
if "app_secret" in config:
normalized["app_secret"] = config["app_secret"]
normalized["client_secret"] = config["app_secret"]
elif "client_secret" in config:
normalized["app_secret"] = config["client_secret"]
normalized["client_secret"] = config["client_secret"]
# Copy URLs if present
for key in ["authorize_url", "token_url", "user_info_url", "scope"]:
if key in config:
normalized[key] = config[key]
return normalized
return config
def validate_provider_config(provider_type: str, config: dict):
"""Validate identity provider config. Specific field checks are handled by the frontend."""
if not isinstance(config, dict):
raise HTTPException(status_code=422, detail="Configuration must be a JSON object")
return
@router.post("/identity-providers", response_model=IdentityProviderOut)
async def create_identity_provider(
data: IdentityProviderCreate,
current_user: User = Depends(get_current_admin),
db: AsyncSession = Depends(get_db),
):
"""Create a new identity provider (Admin only)."""
# Validate config
validate_provider_config(data.provider_type, data.config)
# Validate and determine tenant_id
tid = data.tenant_id
if current_user.role == "platform_admin":
# Platform admins can use any tenant_id (including None for global providers)
pass
else:
# Non-platform admins: use request tenant_id if provided, else fall back to user's tenant
if tid is None:
tid = current_user.tenant_id
elif str(tid) != str(current_user.tenant_id):
# Validate they can only manage their own tenant
raise HTTPException(status_code=403, detail="Can only create providers for your own tenant")
if not tid:
raise HTTPException(status_code=400, detail="tenant_id is required to create an identity provider")
if data.sso_login_enabled:
if not await sso_service.validate_sso_enablement(db, tid):
raise HTTPException(
status_code=400,
detail="IP address does not support multi-tenant SSO. Another tenant already has SSO enabled."
)
provider = IdentityProvider(
provider_type=data.provider_type,
name=data.name,
is_active=data.is_active,
sso_login_enabled=data.sso_login_enabled,
config=data.config,
tenant_id=tid
)
db.add(provider)
await db.commit()
await db.refresh(provider)
return IdentityProviderOut.model_validate(provider)
@router.post("/identity-providers/oauth2", response_model=IdentityProviderOut)
async def create_oauth2_provider(
data: IdentityProviderOAuth2Create,
current_user: User = Depends(get_current_admin),
db: AsyncSession = Depends(get_db),
):
"""Create a new OAuth2 identity provider with simplified fields (app_id, app_secret, authorize_url, etc.)."""
# Convert to config dict
oauth_config = OAuth2Config(
app_id=data.app_id,
app_secret=data.app_secret,
authorize_url=data.authorize_url,
token_url=data.token_url,
user_info_url=data.user_info_url,
scope=data.scope,
)
config = oauth_config.to_config_dict()
# Validate
validate_provider_config("oauth2", config)
# Validate and determine tenant_id
tid = data.tenant_id
if current_user.role == "platform_admin":
# Platform admins can use any tenant_id (including None for global providers)
pass
else:
# Non-platform admins: use request tenant_id if provided, else fall back to user's tenant
if tid is None:
tid = current_user.tenant_id
elif str(tid) != str(current_user.tenant_id):
# Validate they can only manage their own tenant
raise HTTPException(status_code=403, detail="Can only create providers for your own tenant")
if not tid:
raise HTTPException(status_code=400, detail="tenant_id is required to create an identity provider")
provider = IdentityProvider(
provider_type="oauth2",
name=data.name,
is_active=data.is_active,
config=config,
tenant_id=tid
)
db.add(provider)
await db.commit()
await db.refresh(provider)
return IdentityProviderOut.model_validate(provider)
class OAuth2ConfigUpdate(BaseModel):
"""OAuth2 provider configuration update with dedicated fields."""
name: str | None = None
is_active: bool | None = None
app_id: str | None = None
app_secret: str | None = None # Set to None to keep existing, empty to clear
authorize_url: str | None = None
token_url: str | None = None
user_info_url: str | None = None
scope: str | None = None
@router.patch("/identity-providers/{provider_id}/oauth2", response_model=IdentityProviderOut)
async def update_oauth2_provider(
provider_id: uuid.UUID,
data: OAuth2ConfigUpdate,
current_user: User = Depends(get_current_admin),
db: AsyncSession = Depends(get_db),
):
"""Update an OAuth2 identity provider with simplified fields."""
result = await db.execute(select(IdentityProvider).where(IdentityProvider.id == provider_id))
provider = result.scalar_one_or_none()
if not provider:
raise HTTPException(status_code=404, detail="Provider not found")
if provider.provider_type != "oauth2":
raise HTTPException(status_code=400, detail="Provider is not an OAuth2 provider")
if current_user.role != "platform_admin" and provider.tenant_id != current_user.tenant_id:
raise HTTPException(status_code=403, detail="Not authorized to update this provider")
# Update name and is_active
if data.name is not None:
provider.name = data.name
if data.is_active is not None:
provider.is_active = data.is_active
# Update config fields
if any([data.app_id, data.app_secret is not None, data.authorize_url, data.token_url, data.user_info_url, data.scope]):
current_config = provider.config.copy()
if data.app_id is not None:
current_config["app_id"] = data.app_id
current_config["client_id"] = data.app_id
if data.app_secret is not None:
# Only update if explicitly set (not None) - allows clearing
if data.app_secret:
current_config["app_secret"] = data.app_secret
current_config["client_secret"] = data.app_secret
else:
current_config.pop("app_secret", None)
current_config.pop("client_secret", None)
if data.authorize_url is not None:
current_config["authorize_url"] = data.authorize_url
if data.token_url is not None:
current_config["token_url"] = data.token_url
if data.user_info_url is not None:
current_config["user_info_url"] = data.user_info_url
if data.scope is not None:
current_config["scope"] = data.scope
# Validate the updated config
validate_provider_config("oauth2", current_config)
provider.config = current_config
await db.commit()
await db.refresh(provider)
return IdentityProviderOut.model_validate(provider)
class IdentityProviderUpdate(BaseModel):
name: str | None = None
is_active: bool | None = None
sso_login_enabled: bool | None = None
config: dict | None = None
@router.put("/identity-providers/{provider_id}", response_model=IdentityProviderOut)
async def update_identity_provider(
provider_id: uuid.UUID,
data: IdentityProviderUpdate,
current_user: User = Depends(get_current_admin),
db: AsyncSession = Depends(get_db),
):
"""Update an existing identity provider."""
result = await db.execute(select(IdentityProvider).where(IdentityProvider.id == provider_id))
provider = result.scalar_one_or_none()
if not provider:
raise HTTPException(status_code=404, detail="Provider not found")
if current_user.role != "platform_admin" and provider.tenant_id != current_user.tenant_id:
raise HTTPException(status_code=403, detail="Not authorized to update this provider")
if data.name is not None:
provider.name = data.name
if data.is_active is not None:
provider.is_active = data.is_active
if data.sso_login_enabled is not None:
if data.sso_login_enabled is True and not provider.sso_login_enabled:
# Pre-check IP restriction before writing anything
if not await sso_service.validate_sso_enablement(db, provider.tenant_id):
raise HTTPException(
status_code=400,
detail="IP address does not support multi-tenant SSO. Another tenant already has SSO enabled."
)
provider.sso_login_enabled = data.sso_login_enabled
if data.config is not None:
# Merge config
new_config = provider.config.copy()
new_config.update(data.config)
# Validate merged config
validate_provider_config(provider.provider_type, new_config)
provider.config = new_config
await db.commit()
await db.refresh(provider)
# Recompute tenant.sso_enabled derived state whenever sso_login_enabled changes
sso_domain = None
if data.sso_login_enabled is not None and provider.tenant_id:
await _sync_tenant_sso_state(db, provider.tenant_id)
from app.models.tenant import Tenant
tenant_result = await db.execute(select(Tenant).where(Tenant.id == provider.tenant_id))
t = tenant_result.scalar_one_or_none()
if t:
sso_domain = t.sso_domain
out = IdentityProviderOut.model_validate(provider)
out.sso_domain = sso_domain
return out
@router.delete("/identity-providers/{provider_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_identity_provider(
provider_id: uuid.UUID,
current_user: User = Depends(get_current_admin),
db: AsyncSession = Depends(get_db),
):
"""Delete an identity provider."""
result = await db.execute(select(IdentityProvider).where(IdentityProvider.id == provider_id))
provider = result.scalar_one_or_none()
if not provider:
raise HTTPException(status_code=404, detail="Provider not found")
if current_user.role != "platform_admin" and provider.tenant_id != current_user.tenant_id:
raise HTTPException(status_code=403, detail="Not authorized to delete this provider")
try:
# Nullify references in synced org data before deleting the provider
from sqlalchemy import update
await db.execute(
update(OrgMember).where(OrgMember.provider_id == provider_id).values(provider_id=None)
)
await db.execute(
update(OrgDepartment).where(OrgDepartment.provider_id == provider_id).values(provider_id=None)
)
await db.delete(provider)
await db.commit()
except SQLAlchemyError as e:
await db.rollback()
logger.error(f"Failed to delete identity provider {provider_id}: {e}")
raise HTTPException(status_code=500, detail="Failed to delete identity provider due to database constraints")
# ─── Org Structure ──────────────────────────────────────
from app.models.org import OrgDepartment, OrgMember
@router.get("/org/departments")
async def list_org_departments(
tenant_id: str | None = None,
provider_id: str | None = None,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""List all departments, optionally filtered by tenant or provider."""
# Tenant isolation rules:
# 1. If tenant_id param is explicitly provided:
# - non-platform-admins: must match their own tenant_id
# - platform_admin with a tenant in token: must match that tenant
# - platform_admin without a tenant (global view): any tenant allowed
# 2. If tenant_id param is NOT provided:
# - auto-scope to current_user.tenant_id when it is set (applies to ALL roles)
# - only a platform_admin with NO tenant_id in token can query unrestricted
effective_tenant_id = str(current_user.tenant_id) if current_user.tenant_id else None
is_global_admin = (current_user.role == "platform_admin" and not effective_tenant_id)
if tenant_id:
# Validate requested tenant against user context
if not is_global_admin and effective_tenant_id and effective_tenant_id != tenant_id:
raise HTTPException(status_code=403, detail="Cannot access other tenant's data")
else:
# Auto-scope: use the user's own tenant when available
tenant_id = effective_tenant_id # None only for true global admin
query = select(OrgDepartment, IdentityProvider.name.label("provider_name"), IdentityProvider.provider_type).outerjoin(
IdentityProvider, OrgDepartment.provider_id == IdentityProvider.id
).where(OrgDepartment.status == "active")
if tenant_id:
query = query.where(OrgDepartment.tenant_id == uuid.UUID(tenant_id))
if provider_id:
query = query.where(OrgDepartment.provider_id == uuid.UUID(provider_id))
result = await db.execute(query.order_by(OrgDepartment.name))
rows = result.all()
# Calculate total members for this scope (for the "All" entry in frontend)
total_q = select(func.count(OrgMember.id)).where(OrgMember.status == "active")
if tenant_id:
total_q = total_q.where(OrgMember.tenant_id == uuid.UUID(tenant_id))
if provider_id:
total_q = total_q.where(OrgMember.provider_id == uuid.UUID(provider_id))
total_result = await db.execute(total_q)
total_member = total_result.scalar() or 0
return {
"items": [
{
"id": str(d.id),
"external_id": d.external_id,
"provider_id": str(d.provider_id) if d.provider_id else None,
"provider_name": provider_name if d.provider_id else None,
"provider_type": provider_type if d.provider_id else None,
"name": d.name,
"parent_id": str(d.parent_id) if d.parent_id else None,
"path": d.path,
"member_count": d.member_count,
}
for d, provider_name, provider_type in rows
],
"total_member": total_member,
}
from sqlalchemy import or_
@router.get("/org/members")
async def list_org_members(
department_id: str | None = None,
search: str | None = None,
tenant_id: str | None = None,
provider_id: str | None = None,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""List org members, optionally filtered by department, search, tenant, or provider."""
# Tenant isolation rules:
# 1. If tenant_id param is explicitly provided:
# - non-platform-admins: must match their own tenant_id
# - platform_admin with a tenant in token: must match that tenant
# - platform_admin without a tenant (global view): any tenant allowed
# 2. If tenant_id param is NOT provided:
# - auto-scope to current_user.tenant_id when it is set (applies to ALL roles)
# - only a platform_admin with NO tenant_id in token can query unrestricted
effective_tenant_id = str(current_user.tenant_id) if current_user.tenant_id else None
is_global_admin = (current_user.role == "platform_admin" and not effective_tenant_id)
if tenant_id:
# Validate requested tenant against user context
if not is_global_admin and effective_tenant_id and effective_tenant_id != tenant_id:
raise HTTPException(status_code=403, detail="Cannot access other tenant's data")
else:
# Auto-scope: use the user's own tenant when available
tenant_id = effective_tenant_id # None only for true global admin
query = select(OrgMember, IdentityProvider.name.label("provider_name"), IdentityProvider.provider_type).outerjoin(
IdentityProvider, OrgMember.provider_id == IdentityProvider.id
).where(OrgMember.status == "active")
if tenant_id:
query = query.where(OrgMember.tenant_id == uuid.UUID(tenant_id))
if department_id:
# Get the department to find its path and then include all sub-departments
dept_result = await db.execute(select(OrgDepartment).where(OrgDepartment.id == uuid.UUID(department_id)))
target_dept = dept_result.scalar_one_or_none()
if target_dept:
# Build sub-department query: the selected dept itself, plus any dept whose path
# starts with its path followed by a "/" (i.e., all descendants).
sub_dept_conditions = [OrgDepartment.id == target_dept.id]
if target_dept.path:
# Use SQL LIKE to find all descendants based on path prefix
sub_dept_conditions.append(OrgDepartment.path.like(f"{target_dept.path}/%"))
sub_depts_query = select(OrgDepartment.id).where(or_(*sub_dept_conditions))
sub_dept_ids_result = await db.execute(sub_depts_query)
sub_dept_ids = [row[0] for row in sub_dept_ids_result.all()]
query = query.where(OrgMember.department_id.in_(sub_dept_ids))
else:
# Fallback: exact match
query = query.where(OrgMember.department_id == uuid.UUID(department_id))
if provider_id:
query = query.where(OrgMember.provider_id == uuid.UUID(provider_id))
if search:
query = query.where(
or_(
OrgMember.name.ilike(f"%{search}%"),
OrgMember.name_translit_full.ilike(f"%{search}%"),
OrgMember.name_translit_initial.ilike(f"%{search}%"),
)
)
query = query.order_by(OrgMember.name).limit(100)
result = await db.execute(query)
rows = result.all()
return [
{
"id": str(m.id),
"name": m.name,
"email": m.email,
"title": m.title,
"department_path": m.department_path,
"avatar_url": m.avatar_url,
"external_id": m.external_id,
"provider_id": str(m.provider_id) if m.provider_id else None,
"provider_name": provider_name if m.provider_id else None,
"provider_type": provider_type if m.provider_id else None,
}
for m, provider_name, provider_type in rows
]
@router.post("/org/sync")
async def trigger_org_sync(
provider_id: str | None = None,
current_user: User = Depends(get_current_admin),
db: AsyncSession = Depends(get_db),
):
"""Manually trigger org structure sync from a specific identity provider."""
from app.services.org_sync_service import org_sync_service
if not provider_id:
raise HTTPException(status_code=400, detail="provider_id is required")
try:
pid = uuid.UUID(provider_id)
except Exception:
raise HTTPException(status_code=400, detail="Invalid provider_id")
result = await db.execute(select(IdentityProvider).where(IdentityProvider.id == pid))
provider = result.scalar_one_or_none()
if not provider:
raise HTTPException(status_code=404, detail="Provider not found")
if not provider.tenant_id:
raise HTTPException(status_code=400, detail="Provider must be bound to a tenant")
if current_user.role != "platform_admin" and provider.tenant_id != current_user.tenant_id:
raise HTTPException(status_code=403, detail="Cannot sync other tenant's provider")
return await org_sync_service.sync_provider(db, provider_id)
@router.get("/org/wecom-verify/{provider_id}")
async def wecom_org_sync_verify(
provider_id: uuid.UUID,
msg_signature: str = "",
timestamp: str = "",
nonce: str = "",
echostr: str = "",
db: AsyncSession = Depends(get_db),
):
"""Handle WeCom receive-message-server URL verification for the org sync app.
WeCom sends a GET request with msg_signature, timestamp, nonce, echostr when
the admin first saves the receive message server URL in the app settings.
This endpoint decrypts and returns the echostr to complete the handshake.
After this verification succeeds, the WeCom app's trusted IP whitelist becomes
configurable, which is the prerequisite for using App-level credentials (AgentID +
Secret) that have full contact read permission.
Configure URL in WeCom: {BASE_URL}/api/enterprise/org/wecom-verify/{provider_id}
Required provider config keys (set via Clawith WeCom config page):
- verify_token: the Token string set in both WeCom and Clawith
- verify_aes_key: the EncodingAESKey provided by WeCom (43 chars, base64url)
"""
from fastapi.responses import Response as _Response
from app.api.wecom import _decrypt_msg, _verify_signature
result = await db.execute(select(IdentityProvider).where(IdentityProvider.id == provider_id))
provider = result.scalar_one_or_none()
if not provider:
return _Response(status_code=404)
config = provider.config or {}
token = config.get("verify_token", "")
aes_key = config.get("verify_aes_key", "")
if not token or not aes_key:
logger.warning(
f"[WeCom Verify] Provider {provider_id} is missing verify_token or verify_aes_key in config. "
"Please configure them in the WeCom provider settings."
)
return _Response(status_code=400)
# Verify signature to authenticate the request from WeCom
expected_sig = _verify_signature(token, timestamp, nonce, echostr)
if expected_sig != msg_signature:
logger.warning(f"[WeCom Verify] Signature mismatch for provider {provider_id}")
return _Response(status_code=403)
# Decrypt echostr and return plaintext (WeCom confirms URL ownership)
try:
decrypted, _ = _decrypt_msg(aes_key, echostr)
logger.info(f"[WeCom Verify] Successfully verified org sync callback for provider {provider_id}")
return _Response(content=decrypted, media_type="text/plain")
except Exception as e:
logger.error(f"[WeCom Verify] Failed to decrypt echostr for provider {provider_id}: {e}")
return _Response(status_code=500)
@router.get("/org/wecom-callback/{token}", include_in_schema=False)
async def wecom_callback_verify_universal(
token: str,
aes_key: str = "",
msg_signature: str = "",
timestamp: str = "",
nonce: str = "",
echostr: str = "",
):
"""Universal WeCom callback URL verification endpoint (no database lookup required).
Used to unlock the 企业可信IP configuration in the WeCom admin console.
Unlike the provider-based endpoint, this accepts the verify_token in the URL
path and the EncodingAESKey as a query parameter, so any tenant can use the
publicly accessible server (e.g. try.clawith.ai) regardless of which server
the WeCom provider is actually configured on.
URL format to configure in WeCom App → 接收消息服务器URL:
https://{public_host}/api/enterprise/org/wecom-callback/{verify_token}?aes_key={encoding_aes_key}
WeCom will append msg_signature, timestamp, nonce, echostr to this URL automatically.
Once WeCom verifies this URL, the app's 企业可信IP whitelist becomes configurable and
the user can add their API server IPs to allow App-level user/get calls.
"""
from fastapi.responses import Response as _Response
from app.api.wecom import _decrypt_msg, _verify_signature
if not token:
return _Response(status_code=400, content="verify_token is required in URL path")
if not aes_key:
logger.warning("[WeCom Callback] Missing aes_key query param in universal callback URL")
return _Response(status_code=400, content="aes_key query param is required")
# Verify signature to authenticate the request as coming from WeCom servers
expected_sig = _verify_signature(token, timestamp, nonce, echostr)
if expected_sig != msg_signature:
logger.warning(
f"[WeCom Callback] Signature mismatch: token={token[:8]}... "
f"expected={expected_sig[:16]}... got={msg_signature[:16]}..."
)
return _Response(status_code=403)
# Decrypt echostr and return plaintext to complete WeCom URL verification
try:
decrypted, _ = _decrypt_msg(aes_key, echostr)
logger.info(f"[WeCom Callback] Universal callback verified successfully for token={token[:8]}...")
return _Response(content=decrypted, media_type="text/plain")
except Exception as e:
logger.error(f"[WeCom Callback] Failed to decrypt echostr: {e}")
return _Response(status_code=500)
# ─── Invitation Codes ───────────────────────────────────
from app.models.invitation_code import InvitationCode
class InvitationCodeCreate(BaseModel):
count: int = 1 # how many codes to generate
max_uses: int = 1 # max registrations per code
def _require_tenant_admin(current_user: User) -> None:
"""Check that the user is org_admin or platform_admin with a tenant."""
if current_user.role not in ("platform_admin", "org_admin"):
raise HTTPException(status_code=403, detail="Requires admin privileges")
if not current_user.tenant_id:
raise HTTPException(status_code=400, detail="No company assigned")
@router.post("/invitation-codes")
async def create_invitation_codes(
data: InvitationCodeCreate,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""Batch-create invitation codes for the current user's company."""
_require_tenant_admin(current_user)
import random
import string
codes_created = []
for _ in range(min(data.count, 100)): # cap at 100 per batch
code_str = ''.join(random.choices(string.ascii_uppercase + string.digits, k=8))
code = InvitationCode(
code=code_str,
tenant_id=current_user.tenant_id,
max_uses=data.max_uses,
created_by=current_user.id,
)
db.add(code)
codes_created.append(code_str)
await db.commit()
return {"created": len(codes_created), "codes": codes_created}
@router.post("/invite-users")
async def invite_users(
request: Request,
data: UserInviteRequest,
background_tasks: BackgroundTasks,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""Batch-invite users via email to the current user's company."""
_require_tenant_admin(current_user)
if not data.emails:
raise HTTPException(status_code=400, detail="No emails provided")
import random
import string
from app.services.system_email_service import send_company_invitation_email
from app.services.platform_service import platform_service
from app.models.tenant import Tenant
tenant_result = await db.execute(select(Tenant).where(Tenant.id == current_user.tenant_id))
tenant = tenant_result.scalar_one_or_none()
if not tenant:
raise HTTPException(status_code=404, detail="Company not found")
base_url = await platform_service.get_public_base_url(db, request=request)
invited_count = 0
codes = []
for email in data.emails:
email = email.lower().strip()
if not email:
continue
code_str = ''.join(random.choices(string.ascii_uppercase + string.digits, k=8))
code = InvitationCode(
code=code_str,
tenant_id=current_user.tenant_id,
max_uses=1,
created_by=current_user.id,
)
db.add(code)
codes.append(code)
invite_url = f"{base_url}/login?code={code_str}&email={email}"
inviter_name = current_user.display_name or current_user.username
# Use background task to send email
background_tasks.add_task(
send_company_invitation_email,
to=email,
inviter_name=inviter_name,
company_name=tenant.name,
invite_url=invite_url,
)
invited_count += 1
if invited_count > 0:
await db.commit()
return {"invited": invited_count, "message": "Invitations sent successfully"}
@router.get("/invitation-codes")
async def list_invitation_codes(
page: int = 1,
page_size: int = 20,
search: str = "",
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""List invitation codes for the current user's company."""
_require_tenant_admin(current_user)
from sqlalchemy import func as sqla_func
base_filter = InvitationCode.tenant_id == current_user.tenant_id
stmt = select(InvitationCode).where(base_filter)
count_stmt = select(sqla_func.count()).select_from(InvitationCode).where(base_filter)
if search:
stmt = stmt.where(InvitationCode.code.ilike(f"%{search}%"))
count_stmt = count_stmt.where(InvitationCode.code.ilike(f"%{search}%"))
total_result = await db.execute(count_stmt)
total = total_result.scalar() or 0
offset = (max(page, 1) - 1) * page_size
result = await db.execute(
stmt.order_by(InvitationCode.created_at.desc()).offset(offset).limit(page_size)
)
codes = result.scalars().all()
return {
"items": [
{
"id": str(c.id),
"code": c.code,
"max_uses": c.max_uses,
"used_count": c.used_count,
"is_active": c.is_active,
"created_at": c.created_at.isoformat() if c.created_at else None,
}
for c in codes
],
"total": total,
"page": page,
"page_size": page_size,
}
@router.get("/invitation-codes/export")
async def export_invitation_codes_csv(
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""Export invitation codes for the current user's company as CSV."""
_require_tenant_admin(current_user)
import csv
import io
from fastapi.responses import StreamingResponse
result = await db.execute(
select(InvitationCode)
.where(InvitationCode.tenant_id == current_user.tenant_id)
.order_by(InvitationCode.created_at.asc())
)
codes = result.scalars().all()
output = io.StringIO()
writer = csv.writer(output)
writer.writerow(["Code", "Max Uses", "Used Count", "Active", "Created At"])
for c in codes:
writer.writerow([
c.code,
c.max_uses,
c.used_count,
"Yes" if c.is_active else "No",
c.created_at.strftime("%Y-%m-%d %H:%M:%S") if c.created_at else "",
])
output.seek(0)
return StreamingResponse(
iter([output.getvalue()]),
media_type="text/csv",
headers={"Content-Disposition": "attachment; filename=invitation_codes.csv"},
)
@router.delete("/invitation-codes/{code_id}")
async def deactivate_invitation_code(
code_id: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""Deactivate an invitation code (must belong to current user's company)."""
_require_tenant_admin(current_user)
import uuid as _uuid
result = await db.execute(
select(InvitationCode).where(
InvitationCode.id == _uuid.UUID(code_id),
InvitationCode.tenant_id == current_user.tenant_id,
)
)
code = result.scalar_one_or_none()
if not code:
raise HTTPException(status_code=404, detail="Code not found")
code.is_active = False
await db.commit()
return {"status": "deactivated"}