223 lines
6.4 KiB
Python
223 lines
6.4 KiB
Python
"""Authentication provider registry and factory.

This module provides a central registry that instantiates authentication
providers from their database configuration and caches the instances for reuse.
"""
|
|
|
|
from typing import Any
|
|
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.models.identity import IdentityProvider
|
|
from app.services.auth_provider import (
|
|
PROVIDER_CLASSES,
|
|
BaseAuthProvider,
|
|
DingTalkAuthProvider,
|
|
FeishuAuthProvider,
|
|
MicrosoftTeamsAuthProvider,
|
|
WeComAuthProvider,
|
|
)
|
|
|
|
|
|
class AuthProviderRegistry:
    """Registry that creates and caches authentication provider instances.

    ``get_provider`` instantiates the class registered in ``PROVIDER_CLASSES``
    for a provider type, configured from the matching active
    ``IdentityProvider`` row (if one exists). It caches the instance per
    ``(provider_type, tenant_id)`` pair. After flushing their change, the
    create/update/delete helpers invalidate every cached instance of the
    affected provider type.

    NOTE(review): the cache lives inside this registry object only, so
    changes made by other processes are not picked up until the cache is
    cleared. Cached providers also keep a reference to the ORM row loaded
    through the session passed to ``get_provider``. Confirm that the provider
    classes do not lazy-load attributes from that row after the session has
    closed.
    """

    def __init__(self) -> None:
        # Keyed by (provider_type, tenant_id), with tenant_id None for the
        # global provider. A tuple key (rather than "type:tenant" strings)
        # keeps None, "" and a tenant literally named "global" distinct, and
        # lets _clear_cache match provider types exactly.
        self._cache: dict[tuple[str, str | None], BaseAuthProvider] = {}

    async def get_provider(
        self, db: AsyncSession, provider_type: str, tenant_id: str | None = None
    ) -> BaseAuthProvider | None:
        """Get or create an authentication provider instance.

        The instance is configured from the active ``IdentityProvider`` row for
        ``provider_type`` and ``tenant_id``. When no such row exists, the
        provider is still created, with an empty config. Created instances are
        cached and returned on later calls with the same arguments.

        Args:
            db: Database session used to look up the provider configuration.
            provider_type: The type of provider (feishu, dingtalk, etc.).
            tenant_id: Tenant ID for a tenant-specific provider, or None for
                the global (tenant-less) provider.

        Returns:
            Provider instance, or None if the provider type is not supported.

        Raises:
            sqlalchemy.exc.MultipleResultsFound: If more than one active row
                matches. ``create_provider`` does not prevent duplicates.
        """
        cache_key = (provider_type, tenant_id)
        cached = self._cache.get(cache_key)
        if cached is not None:
            return cached

        # An unregistered type can never yield an instance; skip the DB query.
        if provider_type not in PROVIDER_CLASSES:
            return None

        query = select(IdentityProvider).where(
            IdentityProvider.provider_type == provider_type,
            IdentityProvider.is_active == True,  # noqa: E712 (SQL expression)
            # SQLAlchemy renders "== None" as "IS NULL", so tenant_id=None
            # selects the global provider row.
            IdentityProvider.tenant_id == tenant_id,
        )
        result = await db.execute(query)
        provider_model = result.scalar_one_or_none()

        provider = self._create_provider(provider_type, provider_model)
        if provider is not None:
            self._cache[cache_key] = provider

        return provider

    def _create_provider(
        self, provider_type: str, provider_model: IdentityProvider | None
    ) -> BaseAuthProvider | None:
        """Instantiate the provider class registered for ``provider_type``.

        Args:
            provider_type: The type of provider.
            provider_model: Optional IdentityProvider row supplying the config.

        Returns:
            Provider instance, or None if no class is registered for the type.
        """
        provider_class = PROVIDER_CLASSES.get(provider_type)
        if not provider_class:
            return None

        # Fall back to an empty dict when there is no row or its config is
        # None, so provider classes always receive a dict.
        config = provider_model.config if provider_model is not None else None
        return provider_class(provider=provider_model, config=config or {})

    async def list_providers(
        self, db: AsyncSession, tenant_id: str | None = None
    ) -> list[IdentityProvider]:
        """List active identity providers.

        Args:
            db: Database session.
            tenant_id: Tenant ID to filter by. When falsy (None or ""), no
                tenant filter is applied, and active providers of every
                tenant, plus global ones, are returned.

        Returns:
            List of active IdentityProvider records.

        NOTE(review): the unfiltered case exposes every tenant's provider
        rows (including their config). Confirm that callers only reach it from
        admin contexts.
        """
        query = select(IdentityProvider).where(
            IdentityProvider.is_active == True  # noqa: E712 (SQL expression)
        )

        if tenant_id:
            # Only include this tenant's providers; rows with a NULL tenant_id
            # (global providers) do not match the equality and are excluded.
            query = query.where(IdentityProvider.tenant_id == tenant_id)

        result = await db.execute(query)
        return list(result.scalars().all())

    async def create_provider(
        self,
        db: AsyncSession,
        provider_type: str,
        name: str,
        config: dict[str, Any],
        tenant_id: str | None = None,
    ) -> IdentityProvider:
        """Create a new, active identity provider.

        The row is added and flushed (not committed). Committing is left to
        the caller.

        Args:
            db: Database session.
            provider_type: Type of provider.
            name: Display name.
            config: Provider configuration.
            tenant_id: Optional tenant ID for a tenant-specific provider.

        Returns:
            Created IdentityProvider record.
        """
        provider = IdentityProvider(
            provider_type=provider_type,
            name=name,
            is_active=True,
            config=config,
            tenant_id=tenant_id,
        )
        db.add(provider)
        await db.flush()

        # Cached instances of this type (for any tenant) may now be outdated,
        # e.g. one built with an empty config before this row existed.
        self._clear_cache(provider_type)

        return provider

    async def update_provider(
        self,
        db: AsyncSession,
        provider_id: str,
        name: str | None = None,
        config: dict[str, Any] | None = None,
        is_active: bool | None = None,
    ) -> IdentityProvider | None:
        """Update an existing identity provider.

        Only arguments that are not None are applied. The change is flushed
        (not committed).

        Args:
            db: Database session.
            provider_id: Provider ID.
            name: New display name.
            config: New configuration. Replaces the stored config wholesale.
            is_active: New active status.

        Returns:
            Updated IdentityProvider, or None if not found.
        """
        result = await db.execute(
            select(IdentityProvider).where(IdentityProvider.id == provider_id)
        )
        provider = result.scalar_one_or_none()

        if not provider:
            return None

        if name is not None:
            provider.name = name
        if config is not None:
            provider.config = config
        if is_active is not None:
            provider.is_active = is_active

        await db.flush()

        # Drop cached instances built from the previous configuration.
        self._clear_cache(provider.provider_type)

        return provider

    async def delete_provider(self, db: AsyncSession, provider_id: str) -> bool:
        """Delete an identity provider.

        The deletion is flushed (not committed).

        Args:
            db: Database session.
            provider_id: Provider ID.

        Returns:
            True if deleted, False if not found.
        """
        result = await db.execute(
            select(IdentityProvider).where(IdentityProvider.id == provider_id)
        )
        provider = result.scalar_one_or_none()

        if not provider:
            return False

        # Read the type before deleting so the cache can still be invalidated.
        provider_type = provider.provider_type
        await db.delete(provider)
        await db.flush()

        self._clear_cache(provider_type)

        return True

    def _clear_cache(self, provider_type: str) -> None:
        """Drop every cached instance of ``provider_type`` across all tenants."""
        stale_keys = [key for key in self._cache if key[0] == provider_type]
        for key in stale_keys:
            del self._cache[key]

    def clear_all_cache(self) -> None:
        """Clear all cached provider instances."""
        self._cache.clear()
|
|
|
|
|
|
# Shared module-level registry: every importer of this module gets the same
# instance and therefore the same provider cache.
auth_provider_registry: AuthProviderRegistry = AuthProviderRegistry()