Clawith/backend/app/core/security.py

207 lines
6.4 KiB
Python

"""Security utilities: JWT, password hashing, and authentication dependencies."""
import base64
import os
import uuid
from datetime import datetime, timedelta, timezone
import bcrypt
from Crypto.Cipher import AES
from Crypto.Util.Padding import pad, unpad
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from jose import JWTError, jwt
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.config import get_settings
from app.database import get_db
# Application settings, fetched once at import time; the JWT helpers in this
# module read the JWT_* values (secret key, algorithm, token lifetime) from it.
settings = get_settings()
# Bearer token scheme: HTTPBearer instance injected via Depends(security) so the
# auth dependencies receive the request's HTTPAuthorizationCredentials.
security = HTTPBearer()
def encrypt_data(plaintext: str, key: str) -> str:
    """Encrypt text with AES-256 in CBC mode.

    The 32-byte AES key is the SHA-256 digest of ``key`` (UTF-8 encoded). A
    fresh random 16-byte IV is drawn on every call and stored in front of the
    ciphertext so that ``decrypt_data`` can recover it. The output carries only
    the IV and the ciphertext; no authentication tag is produced.

    Args:
        plaintext: Text to encrypt. A falsy value short-circuits to ``""``.
        key: Secret from which the AES key is derived.

    Returns:
        Base64 text of ``IV + ciphertext``, or ``""`` when ``plaintext`` is empty.
    """
    import hashlib

    # Nothing to protect: keep empty values empty instead of encrypting padding.
    if not plaintext:
        return ""

    # SHA-256 yields exactly the 32 bytes AES-256 requires, whatever the secret's length.
    derived_key = hashlib.sha256(key.encode("utf-8")).digest()
    init_vector = os.urandom(16)

    encryptor = AES.new(derived_key, AES.MODE_CBC, init_vector)
    # CBC works on whole blocks, so the payload is padded to the AES block size first.
    sealed = encryptor.encrypt(pad(plaintext.encode("utf-8"), AES.block_size))

    # Ship the IV together with the ciphertext as a single base64 string.
    return base64.b64encode(init_vector + sealed).decode("utf-8")
def decrypt_data(ciphertext: str, key: str) -> str:
    """Reverse ``encrypt_data``: decode, decrypt and unpad a stored value.

    Args:
        ciphertext: Base64 text whose decoded bytes are a 16-byte IV followed
            by the AES-CBC ciphertext. A falsy value short-circuits to ``""``.
        key: Secret from which the AES key is derived (SHA-256 of its UTF-8
            bytes); it must be the one used for encryption.

    Returns:
        The decrypted text, or ``""`` when ``ciphertext`` is empty.

    Raises:
        ValueError: For any failure while decoding, decrypting, unpadding or
            UTF-8 decoding (wrong key, corrupted data, ...). The original
            exception is attached as the cause.
    """
    if not ciphertext:
        return ""

    try:
        import hashlib

        blob = base64.b64decode(ciphertext)
        # Layout written by encrypt_data: first 16 bytes are the IV, the rest is ciphertext.
        init_vector, sealed = blob[:16], blob[16:]

        derived_key = hashlib.sha256(key.encode("utf-8")).digest()
        decryptor = AES.new(derived_key, AES.MODE_CBC, init_vector)
        return unpad(decryptor.decrypt(sealed), AES.block_size).decode("utf-8")
    except Exception as exc:
        # Collapse every failure mode into one exception type for callers.
        raise ValueError(f"Decryption failed: {exc}") from exc
def hash_password(password: str) -> str:
    """Return a bcrypt hash of ``password`` as text.

    The password is UTF-8 encoded, hashed with a freshly generated salt
    (``bcrypt.gensalt()``), and the resulting bytes are decoded to ``str``.
    """
    password_bytes = password.encode("utf-8")
    hashed_bytes = bcrypt.hashpw(password_bytes, bcrypt.gensalt())
    return hashed_bytes.decode("utf-8")
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check a plaintext password against a stored bcrypt hash.

    Args:
        plain_password: Password supplied by the user.
        hashed_password: Stored bcrypt hash, as produced by ``hash_password``.

    Returns:
        ``True`` only when the password matches the hash. A missing password,
        a missing/empty stored hash, or a stored value that bcrypt rejects as
        malformed all yield ``False`` instead of raising.
    """
    # A missing password or an absent stored hash can never authenticate;
    # answer False instead of failing on ``None.encode``.
    if plain_password is None or not hashed_password:
        return False
    try:
        return bcrypt.checkpw(plain_password.encode("utf-8"), hashed_password.encode("utf-8"))
    except ValueError:
        # bcrypt raises ValueError when the stored value is not a valid bcrypt
        # hash (e.g. "Invalid salt"); treat that as a failed verification.
        return False
def create_access_token(user_id: str, role: str, expires_delta: timedelta | None = None) -> str:
    """Create a signed JWT access token.

    Args:
        user_id: Value stored in the ``sub`` claim.
        role: Value stored in the ``role`` claim.
        expires_delta: Token lifetime measured from now (UTC). When ``None``,
            ``settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES`` is used. An explicit
            zero delta is honoured rather than replaced by the default.

    Returns:
        The token encoded with ``settings.JWT_SECRET_KEY`` and
        ``settings.JWT_ALGORITHM``.
    """
    # Compare against None explicitly: timedelta(0) is falsy, so an ``or``
    # fallback would silently swap a zero lifetime for the full default one.
    if expires_delta is None:
        expires_delta = timedelta(minutes=settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES)
    expire = datetime.now(timezone.utc) + expires_delta
    to_encode = {
        "sub": user_id,
        "role": role,
        "exp": expire,
    }
    return jwt.encode(to_encode, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM)
def decode_access_token(token: str) -> dict:
    """Decode a JWT access token and return its claims.

    The token is verified with ``settings.JWT_SECRET_KEY`` and is only accepted
    when signed with ``settings.JWT_ALGORITHM``.

    Raises:
        HTTPException: 401 "Invalid or expired token" when the JWT library
            rejects the token with a ``JWTError``.
    """
    try:
        claims = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
    except JWTError:
        # Translate library-level failures into an HTTP 401 for the API layer.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or expired token",
        )
    else:
        return claims
async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(security),
    db: AsyncSession = Depends(get_db),
):
    """Dependency to get the current authenticated and active user.

    Decodes the bearer token, reads the user id from its ``sub`` claim and
    loads the matching ``User`` row with its ``identity`` relationship
    eagerly loaded.

    Raises:
        HTTPException: 401 when the token is invalid or expired, when ``sub``
            is missing or not a valid UUID, or when the user does not exist
            or is not active.
    """
    # Imported here rather than at module level — presumably to avoid a
    # circular import with the models package (confirm before moving it).
    from app.models.user import User

    payload = decode_access_token(credentials.credentials)
    user_id = payload.get("sub")
    if not user_id:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token")
    try:
        user_uuid = uuid.UUID(user_id)
    except (ValueError, AttributeError, TypeError):
        # A correctly signed token whose subject is not a UUID is a client
        # error (401), not a server failure (500).
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token") from None
    result = await db.execute(
        select(User)
        .where(User.id == user_uuid)
        .options(selectinload(User.identity))
    )
    user = result.scalar_one_or_none()
    if not user or not user.is_active:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found or inactive")
    return user
async def get_authenticated_user(
    credentials: HTTPAuthorizationCredentials = Depends(security),
    db: AsyncSession = Depends(get_db),
):
    """Dependency to get the current authenticated user (even if not active yet).

    Decodes the bearer token, reads the user id from its ``sub`` claim and
    loads the matching ``User`` row with its ``identity`` relationship
    eagerly loaded. The user's ``is_active`` flag is not checked here.

    Raises:
        HTTPException: 401 when the token is invalid or expired, when ``sub``
            is missing or not a valid UUID, or when no such user exists.
    """
    # Imported here rather than at module level — presumably to avoid a
    # circular import with the models package (confirm before moving it).
    from app.models.user import User

    payload = decode_access_token(credentials.credentials)
    user_id = payload.get("sub")
    if not user_id:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token")
    try:
        user_uuid = uuid.UUID(user_id)
    except (ValueError, AttributeError, TypeError):
        # A correctly signed token whose subject is not a UUID is a client
        # error (401), not a server failure (500).
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token") from None
    result = await db.execute(
        select(User)
        .where(User.id == user_uuid)
        .options(selectinload(User.identity))
    )
    user = result.scalar_one_or_none()
    if not user:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found")
    return user
async def get_current_admin(current_user=Depends(get_current_user)):
    """Dependency that admits only administrators.

    Returns the user resolved by ``get_current_user`` when their role is
    ``platform_admin`` or ``org_admin``.

    Raises:
        HTTPException: 403 "Admin access required" for any other role.
    """
    admin_roles = ("platform_admin", "org_admin")
    if current_user.role in admin_roles:
        return current_user
    raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin access required")
# Role hierarchy: higher index = more privileges
# (listed from "member" at index 0 up to "platform_admin" at the last index).
# NOTE(review): nothing in the visible part of this module reads this list —
# get_current_admin and require_role compare role names directly; confirm
# where the ordering is actually consumed.
ROLE_HIERARCHY = ["member", "agent_admin", "org_admin", "platform_admin"]
def require_role(*allowed_roles: str):
    """Build a dependency that admits only users holding one of ``allowed_roles``.

    The check is an exact membership test on ``current_user.role``; no role
    ordering is applied.

    Usage:
        @router.post("/", dependencies=[Depends(require_role("org_admin", "platform_admin"))])
        async def my_endpoint(...):

    Returns:
        An async dependency that returns the current user on success and raises
        ``HTTPException`` 403 otherwise. The error detail is a Chinese message
        ("One of the following roles is required: ...") listing the allowed roles.
    """

    async def _check(current_user=Depends(get_current_user)):
        # Happy path first: a matching role passes the user straight through.
        if current_user.role in allowed_roles:
            return current_user
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=f"需要以下角色之一: {', '.join(allowed_roles)}",
        )

    return _check