Clawith/backend/app/services/auth_registry.py

223 lines
6.4 KiB
Python

"""Authentication provider registry and factory.
This module provides a centralized way to manage and instantiate auth providers.
"""
from typing import Any
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.identity import IdentityProvider
from app.services.auth_provider import (
PROVIDER_CLASSES,
BaseAuthProvider,
DingTalkAuthProvider,
FeishuAuthProvider,
MicrosoftTeamsAuthProvider,
WeComAuthProvider,
)
class AuthProviderRegistry:
    """Registry for managing authentication provider instances.

    Builds provider instances from the ``PROVIDER_CLASSES`` mapping and
    caches them in-process, keyed by provider type and tenant. The
    create/update/delete helpers invalidate every cached instance of the
    affected provider type (for all tenants).
    """

    def __init__(self):
        # Maps "<provider_type>:<tenant_id, or 'global' when None>" -> instance.
        self._cache: dict[str, BaseAuthProvider] = {}

    @staticmethod
    def _cache_key(provider_type: str, tenant_id: str | None) -> str:
        """Build the cache key for a provider type / tenant pair.

        Only ``None`` maps to ``'global'``, mirroring the DB query in
        ``get_provider`` (which matches NULL only for ``None``). This keeps
        a falsy-but-present tenant id such as ``""`` from sharing a slot with
        the global provider.
        """
        tenant_part = "global" if tenant_id is None else tenant_id
        return f"{provider_type}:{tenant_part}"

    async def get_provider(
        self, db: AsyncSession, provider_type: str, tenant_id: str | None = None
    ) -> BaseAuthProvider | None:
        """Get or create an authentication provider instance.

        Args:
            db: Database session
            provider_type: The type of provider (feishu, dingtalk, etc.)
            tenant_id: Optional tenant ID for tenant-specific providers.
                ``None`` selects the row whose ``tenant_id`` is NULL.

        Returns:
            Provider instance or None if provider type is not supported

        Raises:
            sqlalchemy.exc.MultipleResultsFound: if more than one active row
                matches ``provider_type`` and ``tenant_id``.
        """
        # Unsupported types can never yield an instance; skip the DB round-trip.
        if provider_type not in PROVIDER_CLASSES:
            return None

        cache_key = self._cache_key(provider_type, tenant_id)
        cached = self._cache.get(cache_key)
        if cached is not None:
            return cached

        query = select(IdentityProvider).where(
            IdentityProvider.provider_type == provider_type,
            IdentityProvider.is_active == True,  # noqa: E712 - SQL expression
            IdentityProvider.tenant_id == tenant_id,
        )
        result = await db.execute(query)
        provider_model = result.scalar_one_or_none()

        # NOTE: when no active row exists, the provider is still built (with an
        # empty config) and cached until the cache is cleared.
        provider = self._create_provider(provider_type, provider_model)
        if provider is not None:
            self._cache[cache_key] = provider
        return provider

    def _create_provider(
        self, provider_type: str, provider_model: IdentityProvider | None
    ) -> BaseAuthProvider | None:
        """Create a provider instance based on type.

        Args:
            provider_type: The type of provider
            provider_model: Optional IdentityProvider model from database

        Returns:
            Provider instance, or None if ``provider_type`` is not registered
            in ``PROVIDER_CLASSES``.
        """
        provider_class = PROVIDER_CLASSES.get(provider_type)
        if provider_class is None:
            return None
        # A NULL config column is treated like a missing row, so provider
        # classes never receive ``None`` as their config.
        if provider_model is not None:
            config = provider_model.config or {}
        else:
            config = {}
        return provider_class(provider=provider_model, config=config)

    async def list_providers(
        self, db: AsyncSession, tenant_id: str | None = None
    ) -> list[IdentityProvider]:
        """List active identity providers.

        Args:
            db: Database session
            tenant_id: If given, only providers belonging to that tenant are
                returned. If ``None``, active providers of every tenant (and
                rows without a tenant) are returned.

        Returns:
            List of IdentityProvider records
        """
        query = select(IdentityProvider).where(
            IdentityProvider.is_active == True  # noqa: E712 - SQL expression
        )
        # Compare against None explicitly: a falsy-but-present tenant_id
        # (e.g. "") must not widen the query to every tenant's providers.
        if tenant_id is not None:
            query = query.where(IdentityProvider.tenant_id == tenant_id)
        result = await db.execute(query)
        return list(result.scalars().all())

    async def create_provider(
        self,
        db: AsyncSession,
        provider_type: str,
        name: str,
        config: dict[str, Any],
        tenant_id: str | None = None,
    ) -> IdentityProvider:
        """Create a new, active identity provider.

        The row is added and flushed but not committed; committing is left to
        the caller's session handling.

        Args:
            db: Database session
            provider_type: Type of provider
            name: Display name
            config: Provider configuration
            tenant_id: Optional tenant ID for tenant-specific provider

        Returns:
            Created IdentityProvider record
        """
        provider = IdentityProvider(
            provider_type=provider_type,
            name=name,
            is_active=True,
            config=config,
            tenant_id=tenant_id,
        )
        db.add(provider)
        await db.flush()
        self._clear_cache(provider_type)
        return provider

    async def _get_by_id(
        self, db: AsyncSession, provider_id: str
    ) -> IdentityProvider | None:
        """Fetch one IdentityProvider row by its ``id`` column, or None."""
        result = await db.execute(
            select(IdentityProvider).where(IdentityProvider.id == provider_id)
        )
        return result.scalar_one_or_none()

    async def update_provider(
        self,
        db: AsyncSession,
        provider_id: str,
        name: str | None = None,
        config: dict[str, Any] | None = None,
        is_active: bool | None = None,
    ) -> IdentityProvider | None:
        """Update an existing identity provider.

        Only arguments that are not ``None`` are applied. The change is
        flushed but not committed.

        Args:
            db: Database session
            provider_id: Provider ID
            name: New display name
            config: New configuration (replaces the stored config wholesale)
            is_active: New active status

        Returns:
            Updated IdentityProvider or None if not found
        """
        provider = await self._get_by_id(db, provider_id)
        if provider is None:
            return None

        if name is not None:
            provider.name = name
        if config is not None:
            provider.config = config
        if is_active is not None:
            provider.is_active = is_active

        await db.flush()
        self._clear_cache(provider.provider_type)
        return provider

    async def delete_provider(self, db: AsyncSession, provider_id: str) -> bool:
        """Delete an identity provider.

        The deletion is flushed but not committed.

        Args:
            db: Database session
            provider_id: Provider ID

        Returns:
            True if deleted, False if not found
        """
        provider = await self._get_by_id(db, provider_id)
        if provider is None:
            return False

        # Capture the type before the row is deleted and flushed.
        provider_type = provider.provider_type
        await db.delete(provider)
        await db.flush()
        self._clear_cache(provider_type)
        return True

    def _clear_cache(self, provider_type: str):
        """Drop every cached instance of ``provider_type`` (all tenants)."""
        prefix = f"{provider_type}:"
        # Materialize the matching keys first: the dict cannot be mutated
        # while it is being iterated.
        for key in [k for k in self._cache if k.startswith(prefix)]:
            del self._cache[key]

    def clear_all_cache(self):
        """Clear all cached provider instances."""
        self._cache.clear()
# Process-wide shared registry; importers use this single instance so that
# its provider cache is shared across callers in the same process.
auth_provider_registry: AuthProviderRegistry = AuthProviderRegistry()