67 lines
2.1 KiB
Python
67 lines
2.1 KiB
Python
from typing import Annotated
|
|
from collections.abc import Sequence
|
|
|
|
from fastapi import Depends, HTTPException, Request, status
|
|
from sqlalchemy import select
|
|
from sqlalchemy.orm import Session, sessionmaker
|
|
|
|
from app.db.engine import engine, supabase_engine
|
|
from app.db.models import ApiKey
|
|
from app.security.api_key import get_prefix, verify_api_key
|
|
|
|
|
|
# Session factory bound to `engine` (imported from app.db.engine).
# Flushing and committing are left explicit: autoflush and autocommit are off.
SessionLocal = sessionmaker(
    bind=engine,
    autoflush=False,
    autocommit=False,
)

# Session factory bound to `supabase_engine`, configured the same way.
SupabaseSessionLocal = sessionmaker(
    bind=supabase_engine,
    autoflush=False,
    autocommit=False,
)
|
|
|
|
|
|
def get_db():
    """Yield a session created by ``SessionLocal`` and close it afterwards.

    Written as a generator so it can be used with ``Depends``: the session is
    handed to the consumer and closed once the generator is resumed or
    finalised, whether or not an exception was raised in between.
    """
    # Leaving the ``with`` block closes the session, including on error.
    with SessionLocal() as session:
        yield session
|
|
|
|
|
|
def get_supabase_db():
    """Yield a session created by ``SupabaseSessionLocal`` and close it afterwards.

    Generator counterpart of ``get_db`` for the session factory bound to
    ``supabase_engine``; the session is closed once the consumer is done,
    whether or not an exception was raised in between.
    """
    # Leaving the ``with`` block closes the session, including on error.
    with SupabaseSessionLocal() as session:
        yield session
|
|
|
|
|
|
def get_bearer_token(request: Request) -> str:
    """Extract the bearer token from the request's ``Authorization`` header.

    The header must have the form ``Bearer <token>``. The scheme is compared
    case-insensitively and surrounding whitespace is stripped from the token.

    Args:
        request: The incoming request whose headers are inspected.

    Returns:
        The token with leading/trailing whitespace removed.

    Raises:
        HTTPException: 401 with detail ``"Missing Authorization"`` when the
            header is absent or empty, or ``"Invalid Authorization"`` when it
            is not a non-empty ``Bearer`` credential. Both responses carry a
            ``WWW-Authenticate: Bearer`` challenge.
    """
    # RFC 7235 requires every 401 response to include a WWW-Authenticate
    # challenge so clients know which scheme to retry with.
    challenge = {"WWW-Authenticate": "Bearer"}

    auth = request.headers.get("authorization")
    if not auth:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Missing Authorization",
            headers=challenge,
        )

    # Split on the first space only: "<scheme> <credentials>". With no space
    # present, ``credentials`` is empty and the header is rejected below.
    scheme, _, credentials = auth.partition(" ")
    token = credentials.strip()
    if scheme.lower() != "bearer" or not token:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid Authorization",
            headers=challenge,
        )

    return token
|
|
|
|
|
|
def require_permission(permission: str | Sequence[str]):
    """Build a dependency that authenticates an API key and checks its permissions.

    Args:
        permission: A single permission name, or a sequence of names of which
            the key must hold at least one. The value is snapshotted when this
            factory is called.

    Returns:
        A callable suitable for ``Depends``. It takes the bearer token and a
        database session, and returns the matching ``ApiKey``. It raises
        ``HTTPException`` 401 (``"Invalid API key"``) when no active key with
        the token's prefix verifies against the token, and 403
        (``"Permission denied"``) when the key holds none of the required
        permissions.
    """
    # Normalise once at dependency-creation time instead of on every request.
    required = (permission,) if isinstance(permission, str) else tuple(permission)

    def _dep(
        token: Annotated[str, Depends(get_bearer_token)],
        db: Annotated[Session, Depends(get_db)],
    ) -> ApiKey:
        prefix = get_prefix(token)
        stmt = select(ApiKey).where(
            ApiKey.key_prefix == prefix,
            ApiKey.is_active.is_(True),
        )

        # If more than one active key shares this prefix, scalar_one_or_none()
        # would raise MultipleResultsFound and surface as a 500. Fetch every
        # candidate and keep the first whose hash verifies against the token.
        candidates = db.execute(stmt).scalars().all()
        api_key = next(
            (
                candidate
                for candidate in candidates
                if verify_api_key(token, candidate.key_hash)
            ),
            None,
        )
        if api_key is None:
            # Same response whether the prefix is unknown or the hash mismatches.
            # RFC 7235 requires a WWW-Authenticate challenge on 401 responses.
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid API key",
                headers={"WWW-Authenticate": "Bearer"},
            )

        # A missing/None permissions value is treated as "no permissions".
        allowed = set(api_key.permissions or [])
        if not any(p in allowed for p in required):
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Permission denied",
            )

        return api_key

    return _dep
|