"""FastAPI dependencies: database sessions and bearer API-key authorization."""

from collections.abc import Iterator, Sequence
from typing import Annotated

from fastapi import Depends, HTTPException, Request, status
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker

from app.db.engine import engine, supabase_engine
from app.db.models import ApiKey
from app.security.api_key import get_prefix, verify_api_key

SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)
SupabaseSessionLocal = sessionmaker(bind=supabase_engine, autoflush=False, autocommit=False)


def get_db() -> Iterator[Session]:
    """Yield a session bound to ``engine``; the session is closed when the dependency exits."""
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()


def get_supabase_db() -> Iterator[Session]:
    """Yield a session bound to ``supabase_engine``; the session is closed when the dependency exits."""
    db = SupabaseSessionLocal()
    try:
        yield db
    finally:
        db.close()


def _unauthorized(detail: str) -> HTTPException:
    """Build a 401 error that carries the bearer challenge header.

    RFC 6750 §3 requires a ``WWW-Authenticate`` header on 401 responses from
    bearer-protected resources. A fresh dict is built each time so no header
    mapping is shared between exceptions.
    """
    return HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=detail,
        headers={"WWW-Authenticate": "Bearer"},
    )


def get_bearer_token(request: Request) -> str:
    """Extract the token from an ``Authorization: Bearer <token>`` header.

    The scheme name is matched case-insensitively, and surrounding whitespace is
    stripped from the token.

    Raises:
        HTTPException: 401 "Missing Authorization" when the header is absent or
            empty, or 401 "Invalid Authorization" when it is not a non-empty
            bearer credential.
    """
    auth = request.headers.get("authorization")
    if not auth:
        raise _unauthorized("Missing Authorization")
    parts = auth.split(" ", 1)
    if len(parts) != 2 or parts[0].lower() != "bearer" or not parts[1].strip():
        raise _unauthorized("Invalid Authorization")
    return parts[1].strip()


def require_permission(permission: str | Sequence[str]):
    """Create a dependency that authenticates the bearer API key and checks permissions.

    Args:
        permission: One permission name, or several. A key is accepted when it
            holds at least one of them. An empty sequence therefore denies every key.

    Returns:
        A FastAPI dependency callable that resolves to the authenticated ``ApiKey``.
        The callable raises a 401 HTTPException when the token matches no active
        key, and a 403 when the key holds none of the required permissions.
    """
    # Snapshot once, at route-definition time. This avoids per-request work and
    # stops a one-shot iterable (e.g. a generator) from being exhausted after
    # the first request, which would make every later request fail with 403.
    required = (permission,) if isinstance(permission, str) else tuple(permission)

    def _dep(
        token: Annotated[str, Depends(get_bearer_token)],
        db: Annotated[Session, Depends(get_db)],
    ) -> ApiKey:
        stmt = select(ApiKey).where(
            ApiKey.key_prefix == get_prefix(token),
            ApiKey.is_active.is_(True),
        )
        # NOTE(review): nothing here guarantees key_prefix is unique among active
        # keys. scalar_one_or_none() would raise MultipleResultsFound (a 500) on a
        # collision, so every candidate is checked and the first one whose hash
        # verifies wins.
        candidates = db.execute(stmt).scalars().all()
        api_key = next(
            (candidate for candidate in candidates if verify_api_key(token, candidate.key_hash)),
            None,
        )
        if api_key is None:
            raise _unauthorized("Invalid API key")

        allowed = set(api_key.permissions or [])
        if allowed.isdisjoint(required):
            raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Permission denied")
        return api_key

    return _dep