55 lines
1.7 KiB
Python
55 lines
1.7 KiB
Python
from typing import Annotated
|
|
|
|
from fastapi import Depends, HTTPException, Request, status
|
|
from sqlalchemy import select
|
|
from sqlalchemy.orm import Session, sessionmaker
|
|
|
|
from app.db.engine import engine
|
|
from app.db.models import ApiKey
|
|
from app.security.api_key import get_prefix, verify_api_key
|
|
|
|
|
|
# Session factory bound to the application engine. autoflush is disabled,
# so queries do not implicitly flush pending changes first.
SessionLocal = sessionmaker(
    autocommit=False,
    autoflush=False,
    bind=engine,
)
|
|
|
|
|
|
def get_db():
    """Yield a database session for the duration of one dependency resolution.

    A new session is created from ``SessionLocal`` and is closed when the
    generator finishes, whether it completes normally or an exception is
    thrown into it.
    """
    # Session's context-manager exit calls close(), mirroring try/finally.
    with SessionLocal() as session:
        yield session
|
|
|
|
|
|
def get_bearer_token(request: Request) -> str:
    """Extract the bearer token from the request's ``Authorization`` header.

    The header must look like ``Bearer <token>``. The scheme is matched
    case-insensitively, and whitespace around the token is stripped.

    Args:
        request: The incoming request.

    Returns:
        The token, stripped of surrounding whitespace. It is never empty.

    Raises:
        HTTPException: 401 ``"Missing Authorization"`` when the header is
            absent or empty. 401 ``"Invalid Authorization"`` when the scheme
            is not ``Bearer`` or no token follows it. Both responses carry a
            ``WWW-Authenticate: Bearer`` challenge header.
    """
    # RFC 7235 §3.1: every 401 response must include a WWW-Authenticate challenge.
    challenge = {"WWW-Authenticate": "Bearer"}

    auth = request.headers.get("authorization")
    if not auth:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Missing Authorization",
            headers=challenge,
        )

    # Split on the first space only; when there is no space, token is "".
    scheme, _, token = auth.partition(" ")
    token = token.strip()
    if scheme.lower() != "bearer" or not token:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid Authorization",
            headers=challenge,
        )

    return token
|
|
|
|
|
|
def require_permission(permission: str):
    """Build a dependency that authenticates an API key and checks a permission.

    Args:
        permission: The permission name that must appear in the key's
            ``permissions``.

    Returns:
        A dependency callable that resolves to the authenticated ``ApiKey``.
    """

    def _dep(
        token: Annotated[str, Depends(get_bearer_token)],
        db: Annotated[Session, Depends(get_db)],
    ) -> ApiKey:
        """Resolve the bearer token to an active ``ApiKey`` that holds ``permission``.

        Raises:
            HTTPException: 401 ``"Invalid API key"`` when no active key matches
                the token's prefix or the token fails ``verify_api_key``; this
                response carries a ``WWW-Authenticate: Bearer`` challenge.
                403 ``"Permission denied"`` when the key lacks ``permission``.
        """
        prefix = get_prefix(token)
        stmt = select(ApiKey).where(
            ApiKey.key_prefix == prefix,
            ApiKey.is_active.is_(True),
        )
        # NOTE(review): scalar_one_or_none() raises MultipleResultsFound if
        # more than one active key shares a prefix. Presumably key_prefix is
        # unique; confirm against the model and migrations.
        api_key = db.execute(stmt).scalar_one_or_none()

        # Unknown prefix and failed verification give an identical response,
        # presumably so callers cannot tell which check failed.
        if api_key is None or not verify_api_key(token, api_key.key_hash):
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid API key",
                headers={"WWW-Authenticate": "Bearer"},
            )

        # NOTE(review): if ``permissions`` were stored as a string rather than
        # a list, ``in`` would be a substring test. Confirm the column type.
        if permission not in (api_key.permissions or []):
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Permission denied",
            )

        return api_key

    return _dep
|