469 lines
17 KiB
Python
469 lines
17 KiB
Python
"""SSO (Single Sign-On) service for enterprise user authentication.
|
|
|
|
This module handles SSO-based login, user matching, and tenant association.
|
|
"""
|
|
|
|
import re
|
|
import uuid
|
|
from typing import Any
|
|
|
|
from loguru import logger
|
|
from sqlalchemy import select, or_
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.models.identity import IdentityProvider
|
|
from app.models.tenant import Tenant
|
|
from app.models.user import Identity, User
|
|
from app.services.platform_service import platform_service
|
|
|
|
|
|
class SSOService:
|
|
"""Service for handling SSO authentication flows."""
|
|
|
|
# Common email domain to tenant mapping hints
|
|
DOMAIN_TENANT_HINTS: dict[str, str] = {}
|
|
|
|
async def match_user_by_email(
|
|
self, db: AsyncSession, email: str, tenant_id: str | None = None
|
|
) -> User | None:
|
|
"""Find existing user by email address.
|
|
|
|
Args:
|
|
db: Database session
|
|
email: User email address
|
|
tenant_id: Optional tenant ID to scope the search
|
|
|
|
Returns:
|
|
User if found, None otherwise
|
|
"""
|
|
# 1. Try direct match via Identity join
|
|
query = (
|
|
select(User)
|
|
.join(User.identity)
|
|
.where(Identity.email == email)
|
|
)
|
|
if tenant_id:
|
|
query = query.where(User.tenant_id == tenant_id)
|
|
|
|
result = await db.execute(query)
|
|
user = result.scalar_one_or_none()
|
|
|
|
if user:
|
|
return user
|
|
|
|
# 2. If not found and tenant_id is provided, try to find an Identity
|
|
if email:
|
|
id_query = select(Identity).where(Identity.email == email)
|
|
id_result = await db.execute(id_query)
|
|
identity = id_result.scalar_one_or_none()
|
|
if identity:
|
|
# Find any user for this identity (representative)
|
|
u_res = await db.execute(select(User).where(User.identity_id == identity.id).limit(1))
|
|
return u_res.scalar_one_or_none()
|
|
|
|
return None
|
|
|
|
async def match_user_by_mobile(
|
|
self, db: AsyncSession, mobile: str, tenant_id: str | None = None
|
|
) -> User | None:
|
|
"""Find existing user by mobile phone number.
|
|
|
|
Args:
|
|
db: Database session
|
|
mobile: Mobile phone number
|
|
tenant_id: Optional tenant ID to scope the search
|
|
|
|
Returns:
|
|
User if found, None otherwise
|
|
"""
|
|
# Normalize mobile number
|
|
normalized_mobile = re.sub(r"[\s\-\+]", "", mobile)
|
|
if not normalized_mobile:
|
|
return None
|
|
|
|
# 1. Try direct match via Identity join
|
|
query = (
|
|
select(User)
|
|
.join(User.identity)
|
|
.where(Identity.phone == normalized_mobile)
|
|
)
|
|
if tenant_id:
|
|
query = query.where(User.tenant_id == tenant_id)
|
|
|
|
result = await db.execute(query)
|
|
user = result.scalar_one_or_none()
|
|
if user:
|
|
return user
|
|
|
|
# 2. Try Identity match
|
|
id_query = select(Identity).where(Identity.phone == normalized_mobile)
|
|
id_result = await db.execute(id_query)
|
|
identity = id_result.scalar_one_or_none()
|
|
if identity:
|
|
u_res = await db.execute(select(User).where(User.identity_id == identity.id).limit(1))
|
|
return u_res.scalar_one_or_none()
|
|
|
|
return None
|
|
|
|
async def auto_associate_tenant(self, db: AsyncSession, email: str) -> str | None:
|
|
"""Detect tenant based on email domain.
|
|
|
|
Args:
|
|
db: Database session
|
|
email: User email address
|
|
|
|
Returns:
|
|
Tenant ID if found, None otherwise
|
|
"""
|
|
if not email or "@" not in email:
|
|
return None
|
|
|
|
domain = email.split("@")[1].lower()
|
|
|
|
# Check domain hints first
|
|
if domain in self.DOMAIN_TENANT_HINTS:
|
|
return self.DOMAIN_TENANT_HINTS[domain]
|
|
|
|
# Try to find tenant by custom domain
|
|
result = await db.execute(
|
|
select(Tenant).where(Tenant.sso_domain.ilike(f"%{domain}%"))
|
|
)
|
|
tenant = result.scalar_one_or_none()
|
|
|
|
if tenant:
|
|
return str(tenant.id)
|
|
|
|
# Try to find tenant by matching tenant name
|
|
result = await db.execute(
|
|
select(Tenant).where(
|
|
Tenant.name.ilike(f"%{domain.split('.')[0]}%")
|
|
)
|
|
)
|
|
tenant = result.scalar_one_or_none()
|
|
|
|
if tenant:
|
|
return str(tenant.id)
|
|
|
|
return None
|
|
|
|
async def resolve_user_identity(
|
|
self, db: AsyncSession, provider_user_id: str, provider_type: str, tenant_id: str | None = None
|
|
) -> User | None:
|
|
"""Resolve user from external identity via OrgMember.
|
|
|
|
Args:
|
|
db: Database session
|
|
provider_user_id: User ID in the external system (unionid or userid)
|
|
provider_type: Type of provider (feishu, dingtalk, etc.)
|
|
tenant_id: Optional tenant ID to scope the provider search
|
|
|
|
Returns:
|
|
User if found via OrgMember, None otherwise
|
|
"""
|
|
from app.models.org import OrgMember
|
|
|
|
# Get provider
|
|
query = select(IdentityProvider).where(IdentityProvider.provider_type == provider_type)
|
|
if tenant_id:
|
|
query = query.where(IdentityProvider.tenant_id == tenant_id)
|
|
|
|
result = await db.execute(query)
|
|
provider = result.scalar_one_or_none()
|
|
|
|
if not provider:
|
|
return None
|
|
|
|
# Find OrgMember by unionid, external_id, or open_id
|
|
# For Feishu/DingTalk we often use unionid, for WeCom we use external_id (userid)
|
|
member_query = select(OrgMember).where(
|
|
OrgMember.provider_id == provider.id,
|
|
OrgMember.status == "active",
|
|
or_(
|
|
OrgMember.unionid == provider_user_id,
|
|
OrgMember.external_id == provider_user_id,
|
|
OrgMember.open_id == provider_user_id
|
|
)
|
|
)
|
|
member_result = await db.execute(member_query)
|
|
member = member_result.scalar_one_or_none()
|
|
|
|
if not member or not member.user_id:
|
|
return None
|
|
|
|
# Get user
|
|
from sqlalchemy.orm import selectinload
|
|
user_result = await db.execute(
|
|
select(User).where(User.id == member.user_id).options(selectinload(User.identity))
|
|
)
|
|
return user_result.scalar_one_or_none()
|
|
|
|
async def link_identity(
|
|
self,
|
|
db: AsyncSession,
|
|
user_id: str,
|
|
provider_type: str,
|
|
provider_user_id: str,
|
|
identity_data: dict[str, Any] | None = None,
|
|
tenant_id: str | None = None,
|
|
) -> Any:
|
|
"""Link an external identity to an existing user via OrgMember.
|
|
|
|
When an OrgMember already exists (e.g. from org-sync), this also
|
|
enriches its profile fields with fresh SSO data so placeholder
|
|
records become fully hydrated over time.
|
|
|
|
Args:
|
|
db: Database session
|
|
user_id: User ID to link to
|
|
provider_type: Type of provider
|
|
provider_user_id: User ID in the external system
|
|
identity_data: Raw data from the provider (ExternalUserInfo.raw_data);
|
|
used for passive profile enrichment.
|
|
tenant_id: Optional tenant ID for provider lookup
|
|
|
|
Returns:
|
|
The linked OrgMember
|
|
"""
|
|
from app.models.org import OrgMember
|
|
|
|
# Get or create provider
|
|
query = select(IdentityProvider).where(
|
|
IdentityProvider.provider_type == provider_type,
|
|
IdentityProvider.tenant_id == tenant_id
|
|
)
|
|
|
|
result = await db.execute(query)
|
|
provider = result.scalar_one_or_none()
|
|
|
|
if not provider:
|
|
raise ValueError(f"Provider {provider_type} not found for tenant {tenant_id}")
|
|
|
|
uid = uuid.UUID(user_id) if isinstance(user_id, str) else user_id
|
|
|
|
# Extract the raw open_id from identity_data (raw provider response).
|
|
# For Feishu: raw_data has 'open_id' and 'union_id' as separate fields.
|
|
# For DingTalk: raw_data has 'openId' and 'unionId'.
|
|
# Storing open_id separately prevents duplicate user creation when the
|
|
# lookup key alternates between open_id and union_id across SSO sessions.
|
|
raw_open_id = None
|
|
if identity_data:
|
|
raw_open_id = (
|
|
identity_data.get("open_id") # Feishu
|
|
or identity_data.get("openId") # DingTalk
|
|
)
|
|
|
|
# Check if OrgMember already exists for this provider user.
|
|
# Search across unionid, external_id, and open_id to handle the case where
|
|
# the lookup key differs between sync (uses user_id/employee_id as external_id)
|
|
# and SSO (uses union_id or open_id as provider_user_id).
|
|
conditions = [
|
|
OrgMember.unionid == provider_user_id,
|
|
OrgMember.external_id == provider_user_id,
|
|
OrgMember.open_id == provider_user_id,
|
|
]
|
|
if raw_open_id and raw_open_id != provider_user_id:
|
|
# Also search by the actual open_id from raw data, in case the member
|
|
# was created with open_id as its primary key (e.g. from a previous SSO login)
|
|
conditions.append(OrgMember.open_id == raw_open_id)
|
|
conditions.append(OrgMember.external_id == raw_open_id)
|
|
|
|
member_query = select(OrgMember).where(
|
|
OrgMember.provider_id == provider.id,
|
|
OrgMember.status == "active",
|
|
or_(*conditions)
|
|
)
|
|
member_result = await db.execute(member_query)
|
|
member = member_result.scalar_one_or_none()
|
|
|
|
if member:
|
|
# Always link user
|
|
member.user_id = uid
|
|
|
|
# Fill in open_id if not already set — prevents future lookup misses
|
|
if raw_open_id and not member.open_id:
|
|
member.open_id = raw_open_id
|
|
|
|
# Passive identity enrichment: update profile fields from SSO data.
|
|
# OrgMember records created by org-sync may have placeholder values
|
|
# (e.g. name=userid, no avatar/email). We fill them in here so they
|
|
# become accurate after the user's first SSO login, without needing
|
|
# IP-whitelisted batch calls.
|
|
if identity_data:
|
|
incoming_name = (
|
|
identity_data.get("name")
|
|
or identity_data.get("display_name")
|
|
)
|
|
# Only overwrite name if the current value looks like a placeholder
|
|
# (e.g. was set to the raw userid during degraded org sync)
|
|
is_placeholder_name = (
|
|
not member.name
|
|
or member.name == member.external_id
|
|
or member.name == provider_user_id
|
|
or member.name.startswith(f"{provider_type.capitalize()} User")
|
|
)
|
|
if incoming_name and is_placeholder_name:
|
|
member.name = incoming_name
|
|
|
|
incoming_email = identity_data.get("email") or identity_data.get("biz_mail")
|
|
if incoming_email and not member.email:
|
|
member.email = incoming_email
|
|
|
|
incoming_avatar = identity_data.get("avatar")
|
|
if incoming_avatar and not member.avatar_url:
|
|
member.avatar_url = incoming_avatar
|
|
|
|
incoming_mobile = identity_data.get("mobile")
|
|
if incoming_mobile and not member.phone:
|
|
member.phone = incoming_mobile
|
|
|
|
else:
|
|
# Create a shell OrgMember if not synced yet.
|
|
# This handles organizations that skip org-sync and rely purely on SSO.
|
|
member_name = (
|
|
(identity_data.get("name") or identity_data.get("display_name"))
|
|
if identity_data else None
|
|
)
|
|
member = OrgMember(
|
|
name=member_name or f"{provider_type.capitalize()} User {provider_user_id[:8]}",
|
|
email=(identity_data.get("email") or identity_data.get("biz_mail")) if identity_data else None,
|
|
avatar_url=identity_data.get("avatar") if identity_data else None,
|
|
phone=identity_data.get("mobile") if identity_data else None,
|
|
provider_id=provider.id,
|
|
user_id=uid,
|
|
tenant_id=tenant_id,
|
|
# For Feishu/DingTalk: external_id stores union_id (cross-app stable).
|
|
# open_id is stored separately so it can also be matched on next login.
|
|
external_id=provider_user_id,
|
|
unionid=provider_user_id if provider_type != "wecom" else None,
|
|
# Explicitly store the raw open_id so future SSO lookups can match on it
|
|
# even if the lookup key is union_id (and vice versa).
|
|
open_id=raw_open_id,
|
|
)
|
|
db.add(member)
|
|
|
|
await db.flush()
|
|
return member
|
|
|
|
async def unlink_identity(
|
|
self, db: AsyncSession, user_id: str, provider_type: str, tenant_id: str | None = None
|
|
) -> bool:
|
|
"""Unlink an external identity (OrgMember) from a user.
|
|
|
|
Args:
|
|
db: Database session
|
|
user_id: User ID
|
|
provider_type: Type of provider to unlink
|
|
tenant_id: Optional tenant ID
|
|
|
|
Returns:
|
|
True if unlinked, False if not found
|
|
"""
|
|
from app.models.org import OrgMember
|
|
|
|
# Get provider
|
|
query = select(IdentityProvider).where(IdentityProvider.provider_type == provider_type)
|
|
if tenant_id:
|
|
query = query.where(IdentityProvider.tenant_id == tenant_id)
|
|
|
|
result = await db.execute(query)
|
|
provider = result.scalar_one_or_none()
|
|
|
|
if not provider:
|
|
return False
|
|
|
|
# Find OrgMember
|
|
mid = uuid.UUID(user_id) if isinstance(user_id, str) else user_id
|
|
member_result = await db.execute(
|
|
select(OrgMember).where(
|
|
OrgMember.user_id == mid,
|
|
OrgMember.provider_id == provider.id,
|
|
)
|
|
)
|
|
member = member_result.scalar_one_or_none()
|
|
|
|
if not member:
|
|
return False
|
|
|
|
member.user_id = None
|
|
await db.flush()
|
|
|
|
return True
|
|
|
|
async def check_duplicate_identity(
|
|
self, db: AsyncSession, provider_type: str, provider_user_id: str, tenant_id: str | None = None
|
|
) -> User | None:
|
|
"""Check if an external identity is already linked to another user.
|
|
|
|
Args:
|
|
db: Database session
|
|
provider_type: Type of provider
|
|
provider_user_id: User ID in the external system
|
|
tenant_id: Optional tenant ID
|
|
|
|
Returns:
|
|
Existing user if identity is already linked, None otherwise
|
|
"""
|
|
return await self.resolve_user_identity(db, provider_user_id, provider_type, tenant_id)
|
|
|
|
async def validate_sso_enablement(self, db: AsyncSession, tenant_id: uuid.UUID) -> bool:
|
|
"""Check if SSO can be enabled for this tenant under IP restrictions.
|
|
|
|
Only checks when THIS tenant doesn't have SSO enabled yet.
|
|
If tenant already has sso_enabled=True, allows without checking.
|
|
|
|
Returns True if allowed, False if another tenant already has SSO enabled on an IP base.
|
|
"""
|
|
# First check if this tenant already has SSO enabled
|
|
tenant_result = await db.execute(select(Tenant).where(Tenant.id == tenant_id))
|
|
tenant = tenant_result.scalar_one_or_none()
|
|
if tenant and tenant.sso_enabled:
|
|
# Already has SSO enabled, can freely toggle providers
|
|
return True
|
|
|
|
# This tenant doesn't have SSO enabled yet, check IP restriction
|
|
base_url = await platform_service.get_public_base_url(db)
|
|
|
|
# Parse host
|
|
parts = base_url.split("://")
|
|
if len(parts) < 2:
|
|
return True # Conservative default
|
|
|
|
host = parts[1].split(":")[0].split("/")[0]
|
|
|
|
if not platform_service.is_ip_address(host):
|
|
return True
|
|
|
|
# IP Address: only ONE tenant in the whole system can have SSO enabled.
|
|
# Check if any *other* tenant has an active SSO-enabled provider.
|
|
query = select(IdentityProvider).where(
|
|
IdentityProvider.sso_login_enabled == True,
|
|
IdentityProvider.is_active == True,
|
|
IdentityProvider.tenant_id != tenant_id,
|
|
)
|
|
result = await db.execute(query)
|
|
other_providers = result.scalars().all()
|
|
|
|
if other_providers:
|
|
# Collect conflicting tenant names
|
|
conflict_names = []
|
|
for other_provider in other_providers:
|
|
tenant_query = await db.execute(select(Tenant).where(Tenant.id == other_provider.tenant_id))
|
|
conflict_tenant = tenant_query.scalar_one_or_none()
|
|
name = conflict_tenant.name if conflict_tenant else str(other_provider.tenant_id)
|
|
conflict_names.append(f"'{name}'")
|
|
conflict_str = ", ".join(conflict_names)
|
|
logger.warning(f"[SSO] IP conflict: tenant_id={tenant_id} cannot enable SSO, other tenants already have SSO enabled on IP base: {conflict_str}")
|
|
return len(other_providers) == 0
|
|
|
|
def add_domain_hint(self, domain: str, tenant_id: str):
|
|
"""Add a domain to tenant mapping hint.
|
|
|
|
Args:
|
|
domain: Email domain (e.g., "company.com")
|
|
tenant_id: Associated tenant ID
|
|
"""
|
|
self.DOMAIN_TENANT_HINTS[domain.lower()] = tenant_id
|
|
|
|
|
|
# Global SSO service instance
|
|
sso_service = SSOService() |