add previous fix bug forgotting commit-push
This commit is contained in:
@@ -1,11 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import HTTPException, Request, status
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from sqladmin import Admin, ModelView
|
||||
from sqladmin.authentication import AuthenticationBackend
|
||||
from starlette.responses import RedirectResponse
|
||||
from starlette.responses import HTMLResponse, RedirectResponse
|
||||
from starlette.datastructures import URL
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from wtforms import StringField
|
||||
from wtforms import BooleanField, SelectField, StringField
|
||||
from wtforms.validators import Optional
|
||||
|
||||
from app.core.config import settings
|
||||
@@ -36,40 +38,214 @@ class AdminAuth(AuthenticationBackend):
|
||||
class ApiClientAdmin(ModelView, model=ApiClient):
|
||||
column_list = [ApiClient.id, ApiClient.name, ApiClient.is_active]
|
||||
|
||||
async def insert_model(self, request: Request, data: dict) -> ApiClient:
|
||||
obj: ApiClient = await super().insert_model(request, data)
|
||||
|
||||
plain_key = generate_api_key()
|
||||
|
||||
db = sessionmaker(bind=engine, autoflush=False, autocommit=False)()
|
||||
try:
|
||||
api_key = ApiKey(
|
||||
client_id=obj.id,
|
||||
name="auto",
|
||||
key_prefix=get_prefix(plain_key),
|
||||
key_hash=hash_api_key(plain_key),
|
||||
permissions=[],
|
||||
is_active=True,
|
||||
)
|
||||
db.add(api_key)
|
||||
db.commit()
|
||||
db.refresh(api_key)
|
||||
|
||||
request.session["generated_api_key"] = {
|
||||
"client_id": obj.id,
|
||||
"client_name": obj.name,
|
||||
"key_id": api_key.id,
|
||||
"api_key": plain_key,
|
||||
}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
return obj
|
||||
|
||||
|
||||
class ApiKeyAdmin(ModelView, model=ApiKey):
|
||||
column_list = [ApiKey.id, ApiKey.client_id, ApiKey.name, ApiKey.is_active, ApiKey.permissions]
|
||||
form_excluded_columns = [ApiKey.key_hash, ApiKey.key_prefix, ApiKey.created_at]
|
||||
|
||||
form_extra_fields = {
|
||||
"plain_key": StringField("Plain Key", validators=[Optional()]),
|
||||
"permissions_csv": StringField("Permissions (comma)", validators=[Optional()]),
|
||||
"endpoint_path": SelectField("Endpoint", choices=[], validators=[Optional()]),
|
||||
"perm_read": BooleanField("Read (GET)"),
|
||||
"perm_write": BooleanField("Write (POST/PATCH)"),
|
||||
"perm_delete": BooleanField("Delete (DELETE)"),
|
||||
}
|
||||
|
||||
async def on_model_change(self, data: dict, model: ApiKey, is_created: bool, request: Request) -> None:
|
||||
plain_key = data.get("plain_key")
|
||||
if not plain_key and is_created:
|
||||
plain_key = generate_api_key()
|
||||
|
||||
if plain_key:
|
||||
model.key_prefix = get_prefix(plain_key)
|
||||
model.key_hash = hash_api_key(plain_key)
|
||||
|
||||
if is_created:
|
||||
request.state.generated_api_key_plain = plain_key
|
||||
|
||||
permissions: list[str] = []
|
||||
endpoint_path = data.get("endpoint_path")
|
||||
if endpoint_path:
|
||||
if data.get("perm_read"):
|
||||
permissions.append(f"{endpoint_path}:read")
|
||||
if data.get("perm_write"):
|
||||
permissions.append(f"{endpoint_path}:write")
|
||||
if data.get("perm_delete"):
|
||||
permissions.append(f"{endpoint_path}:delete")
|
||||
|
||||
permissions_csv = data.get("permissions_csv")
|
||||
if permissions_csv is not None:
|
||||
perms = [p.strip() for p in permissions_csv.split(",") if p.strip()]
|
||||
model.permissions = perms
|
||||
permissions.extend(perms)
|
||||
|
||||
if permissions:
|
||||
seen: set[str] = set()
|
||||
deduped: list[str] = []
|
||||
for p in permissions:
|
||||
if p not in seen:
|
||||
seen.add(p)
|
||||
deduped.append(p)
|
||||
model.permissions = deduped
|
||||
|
||||
async def after_model_change(self, data: dict, model: ApiKey, is_created: bool, request: Request) -> None:
|
||||
if not is_created:
|
||||
return
|
||||
|
||||
plain_key = getattr(request.state, "generated_api_key_plain", None)
|
||||
if not plain_key:
|
||||
return
|
||||
|
||||
request.session["generated_api_key"] = {
|
||||
"client_id": model.client_id,
|
||||
"client_name": str(getattr(model, "client", "")) if getattr(model, "client", None) else "",
|
||||
"key_id": model.id,
|
||||
"api_key": plain_key,
|
||||
}
|
||||
|
||||
|
||||
def mount_admin(app):
|
||||
auth_backend = AdminAuth(secret_key=settings.ADMIN_SECRET_KEY)
|
||||
admin = Admin(app=app, engine=engine, authentication_backend=auth_backend)
|
||||
|
||||
class CustomAdmin(Admin):
|
||||
def get_save_redirect_url(
|
||||
self, request: Request, form, model_view: ModelView, obj
|
||||
):
|
||||
if (
|
||||
getattr(model_view, "model", None) in (ApiClient, ApiKey)
|
||||
and request.session.get("generated_api_key")
|
||||
):
|
||||
root_path = request.scope.get("root_path") or ""
|
||||
return URL(f"{root_path}/admin/generated-api-key")
|
||||
|
||||
return super().get_save_redirect_url(
|
||||
request=request,
|
||||
form=form,
|
||||
model_view=model_view,
|
||||
obj=obj,
|
||||
)
|
||||
|
||||
admin = CustomAdmin(
|
||||
app=app,
|
||||
engine=engine,
|
||||
authentication_backend=auth_backend,
|
||||
title="My Service Management",
|
||||
base_url="/admin",
|
||||
)
|
||||
|
||||
openapi = app.openapi()
|
||||
paths = openapi.get("paths") or {}
|
||||
endpoint_choices: list[tuple[str, str]] = []
|
||||
for path in sorted(paths.keys()):
|
||||
if not path.startswith("/api/"):
|
||||
continue
|
||||
methods = paths.get(path) or {}
|
||||
available = sorted([m.upper() for m in methods.keys()])
|
||||
label = f"{path} [{' '.join(available)}]" if available else path
|
||||
endpoint_choices.append((path, label))
|
||||
ApiKeyAdmin.form_extra_fields["endpoint_path"].kwargs["choices"] = endpoint_choices
|
||||
|
||||
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)
|
||||
|
||||
admin.add_view(ApiClientAdmin)
|
||||
admin.add_view(ApiKeyAdmin)
|
||||
|
||||
@app.get("/admin")
|
||||
async def _admin_redirect(request: Request):
|
||||
@app.get("/admin/generated-api-key")
|
||||
async def _admin_generated_api_key(request: Request):
|
||||
if not request.session.get("admin"):
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated")
|
||||
|
||||
key_info = request.session.pop("generated_api_key", None)
|
||||
root_path = request.scope.get("root_path") or ""
|
||||
return RedirectResponse(url=f"{root_path}/admin/")
|
||||
clients_url = f"{root_path}/admin/{ApiClientAdmin.identity}/list"
|
||||
|
||||
if not key_info:
|
||||
return HTMLResponse(
|
||||
f"<h2>No API key to display</h2><p>The API key was already shown or expired.</p><p><a href=\"{clients_url}\">Back to clients</a></p>",
|
||||
status_code=200,
|
||||
)
|
||||
|
||||
client_name = key_info.get("client_name", "")
|
||||
client_id = key_info.get("client_id", "")
|
||||
key_id = key_info.get("key_id", "")
|
||||
api_key = key_info.get("api_key", "")
|
||||
|
||||
return HTMLResponse(
|
||||
(
|
||||
"<h2>API key generated</h2>"
|
||||
"<p>Copy this API key now. You won't be able to view it again.</p>"
|
||||
f"<p><b>Client</b>: {client_name} (ID: {client_id})</p>"
|
||||
f"<p><b>Key ID</b>: {key_id}</p>"
|
||||
f"<pre style=\"padding:12px;border:1px solid #ddd;background:#f7f7f7;\">{api_key}</pre>"
|
||||
f"<p><a href=\"{clients_url}\">Back to clients</a></p>"
|
||||
),
|
||||
status_code=200,
|
||||
)
|
||||
|
||||
@app.get("/admin/clients/{client_id}/generate-api-key")
|
||||
async def _admin_generate_api_key_get(
|
||||
request: Request,
|
||||
client_id: int,
|
||||
permissions: str = "",
|
||||
name: str | None = None,
|
||||
):
|
||||
if not request.session.get("admin"):
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated")
|
||||
|
||||
perms = [p.strip() for p in permissions.split(",") if p.strip()]
|
||||
plain_key = generate_api_key()
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
client = db.get(ApiClient, client_id)
|
||||
if not client:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Client not found")
|
||||
|
||||
api_key = ApiKey(
|
||||
client_id=client_id,
|
||||
name=name,
|
||||
key_prefix=get_prefix(plain_key),
|
||||
key_hash=hash_api_key(plain_key),
|
||||
permissions=perms,
|
||||
is_active=True,
|
||||
)
|
||||
db.add(api_key)
|
||||
db.commit()
|
||||
db.refresh(api_key)
|
||||
|
||||
return {"key_id": api_key.id, "api_key": plain_key, "permissions": perms}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@app.post("/admin/api-keys/generate")
|
||||
async def _admin_generate_api_key(
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Annotated
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
@@ -10,12 +11,15 @@ from sqlalchemy.orm import Session
|
||||
from app.api.v1.schemas import FeedCheckpointIn
|
||||
from app.core.config import settings
|
||||
from app.db.models import RawOpdCheckpoint
|
||||
from app.security.dependencies import get_db, require_permission
|
||||
from app.security.dependencies import get_db, get_supabase_db, require_permission
|
||||
from app.utils.supabase_client import SupabaseAPIError, upsert_to_supabase_sync
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/v1")
|
||||
|
||||
PERM_FEED_CHECKPOINT_WRITE = "feed.checkpoint:write"
|
||||
PERM_FEED_CHECKPOINT_WRITE = "/api/v1/feed/checkpoint:write"
|
||||
PERM_FEED_CHECKPOINT_WRITE_LEGACY = "feed.checkpoint:write"
|
||||
|
||||
|
||||
def _to_tz(dt):
|
||||
@@ -26,33 +30,59 @@ def _to_tz(dt):
|
||||
return dt.astimezone(ZoneInfo(settings.TIMEZONE))
|
||||
|
||||
|
||||
def _to_iso(dt):
|
||||
"""Convert datetime to ISO 8601 string for Supabase API."""
|
||||
if dt is None:
|
||||
return None
|
||||
return dt.isoformat()
|
||||
|
||||
|
||||
@router.post("/feed/checkpoint")
|
||||
def upsert_feed_checkpoint(
|
||||
payload: list[FeedCheckpointIn],
|
||||
_: Annotated[object, Depends(require_permission(PERM_FEED_CHECKPOINT_WRITE))],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
):
|
||||
rows = []
|
||||
supabase_rows = []
|
||||
|
||||
#clean_data = payload.model_dump(exclude_none=True)
|
||||
for item in payload:
|
||||
rows.append(
|
||||
{
|
||||
"id": item.id,
|
||||
"hn": item.hn,
|
||||
"vn": item.vn,
|
||||
"location": item.location,
|
||||
"type": item.type,
|
||||
"timestamp_in": _to_tz(item.timestamp_in),
|
||||
"timestamp_out": _to_tz(item.timestamp_out),
|
||||
"waiting_time": item.waiting_time,
|
||||
"bu": item.bu,
|
||||
}
|
||||
)
|
||||
# Prepare data for local database 'default' if item.id is None else
|
||||
row = {
|
||||
"id": item.id,
|
||||
"hn": item.hn,
|
||||
"vn": item.vn,
|
||||
"location": item.location,
|
||||
"type": item.type,
|
||||
"timestamp_in": _to_tz(item.timestamp_in),
|
||||
"timestamp_out": _to_tz(item.timestamp_out),
|
||||
"waiting_time": item.waiting_time,
|
||||
"bu": item.bu,
|
||||
}
|
||||
if item.id is None:
|
||||
del(row["id"])
|
||||
rows.append(row)
|
||||
|
||||
# Prepare data for Supabase API (convert datetime to ISO string) 'default' if item.id is None else
|
||||
supabase_row = {
|
||||
"id": item.id,
|
||||
"hn": item.hn,
|
||||
"vn": item.vn,
|
||||
"location": item.location,
|
||||
"type": item.type,
|
||||
"timestamp_in": _to_iso(_to_tz(item.timestamp_in)),
|
||||
"timestamp_out": _to_iso(_to_tz(item.timestamp_out)),
|
||||
"waiting_time": item.waiting_time,
|
||||
"bu": item.bu,
|
||||
}
|
||||
if item.id is None:
|
||||
del(supabase_row["id"])
|
||||
supabase_rows.append(supabase_row)
|
||||
|
||||
# Insert/update to local database
|
||||
stmt = insert(RawOpdCheckpoint).values(rows)
|
||||
update_cols = {
|
||||
"hn": stmt.excluded.hn,
|
||||
"vn": stmt.excluded.vn,
|
||||
"location": stmt.excluded.location,
|
||||
"id": stmt.excluded.id,
|
||||
"type": stmt.excluded.type,
|
||||
"timestamp_in": stmt.excluded.timestamp_in,
|
||||
"timestamp_out": stmt.excluded.timestamp_out,
|
||||
@@ -60,8 +90,38 @@ def upsert_feed_checkpoint(
|
||||
"bu": stmt.excluded.bu,
|
||||
}
|
||||
|
||||
stmt = stmt.on_conflict_do_update(index_elements=[RawOpdCheckpoint.id], set_=update_cols)
|
||||
stmt = stmt.on_conflict_do_update(
|
||||
index_elements=[RawOpdCheckpoint.hn, RawOpdCheckpoint.vn, RawOpdCheckpoint.location, RawOpdCheckpoint.timestamp_in],
|
||||
set_=update_cols,
|
||||
)
|
||||
result = db.execute(stmt)
|
||||
db.commit()
|
||||
|
||||
return {"upserted": len(rows), "rowcount": result.rowcount}
|
||||
# Send data to Supabase via API call
|
||||
supabase_result = None
|
||||
supabase_error = None
|
||||
|
||||
try:
|
||||
logger.info(f"Sending {len(supabase_rows)} records to Supabase API")
|
||||
supabase_result = upsert_to_supabase_sync(
|
||||
table="raw_opd_checkpoint",
|
||||
data=supabase_rows,
|
||||
on_conflict="hn,vn,location,timestamp_in",
|
||||
)
|
||||
logger.info(f"Successfully sent data to Supabase: {supabase_result.get('status_code')}")
|
||||
except SupabaseAPIError as e:
|
||||
logger.error(f"Failed to send data to Supabase: {str(e)}")
|
||||
supabase_error = str(e)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error sending data to Supabase: {str(e)}")
|
||||
supabase_error = f"Unexpected error: {str(e)}"
|
||||
|
||||
return {
|
||||
"upserted": len(rows),
|
||||
"rowcount": result.rowcount,
|
||||
"supabase": {
|
||||
"success": supabase_result is not None,
|
||||
"result": supabase_result,
|
||||
"error": supabase_error,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@ from pydantic import BaseModel
|
||||
|
||||
|
||||
class FeedCheckpointIn(BaseModel):
|
||||
id: int
|
||||
id: int | None = None
|
||||
hn: int
|
||||
vn: int
|
||||
location: str
|
||||
|
||||
@@ -11,7 +11,17 @@ class Settings(BaseSettings):
|
||||
DB_USER: str
|
||||
DB_PASSWORD: str
|
||||
DB_NAME: str
|
||||
DB_SSLMODE: str = "prefer"
|
||||
DB_SSLMODE: str = "disable"
|
||||
|
||||
SUPABASE_DB_HOST: str
|
||||
SUPABASE_DB_PORT: int = 5432
|
||||
SUPABASE_DB_USER: str
|
||||
SUPABASE_DB_PASSWORD: str
|
||||
SUPABASE_DB_NAME: str
|
||||
SUPABASE_DB_SSLMODE: str = "disable"
|
||||
|
||||
SUPABASE_API_URL: str
|
||||
SUPABASE_API_KEY: str
|
||||
|
||||
ROOT_PATH: str = ""
|
||||
|
||||
|
||||
@@ -18,4 +18,18 @@ def build_db_url() -> str:
|
||||
)
|
||||
|
||||
|
||||
def build_supabase_db_url() -> str:
|
||||
user = quote_plus(settings.SUPABASE_DB_USER)
|
||||
password = quote_plus(settings.SUPABASE_DB_PASSWORD)
|
||||
host = settings.SUPABASE_DB_HOST
|
||||
port = settings.SUPABASE_DB_PORT
|
||||
db = quote_plus(settings.SUPABASE_DB_NAME)
|
||||
|
||||
return (
|
||||
f"postgresql+psycopg://{user}:{password}@{host}:{port}/{db}"
|
||||
f"?sslmode={quote_plus(settings.SUPABASE_DB_SSLMODE)}"
|
||||
)
|
||||
|
||||
|
||||
engine = create_engine(build_db_url(), pool_pre_ping=True)
|
||||
supabase_engine = create_engine(build_supabase_db_url(), pool_pre_ping=True)
|
||||
|
||||
@@ -5,8 +5,9 @@ from app.db.engine import engine
|
||||
|
||||
|
||||
def init_db() -> None:
|
||||
with engine.begin() as conn:
|
||||
conn.execute(text("CREATE SCHEMA IF NOT EXISTS fastapi"))
|
||||
conn.execute(text("CREATE SCHEMA IF NOT EXISTS operationbi"))
|
||||
# with engine.begin() as conn:
|
||||
# conn.execute(text("CREATE SCHEMA IF NOT EXISTS fastapi"))
|
||||
# conn.execute(text("CREATE SCHEMA IF NOT EXISTS operationbi"))
|
||||
|
||||
Base.metadata.create_all(bind=conn)
|
||||
# Base.metadata.create_all(bind=conn)
|
||||
pass
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import BigInteger, Boolean, DateTime, ForeignKey, Integer, String, Text, func
|
||||
from sqlalchemy import BigInteger, Boolean, DateTime, ForeignKey, Integer, String, Text, UniqueConstraint, func
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
@@ -11,7 +11,10 @@ from app.db.base import Base
|
||||
|
||||
class RawOpdCheckpoint(Base):
|
||||
__tablename__ = "raw_opd_checkpoint"
|
||||
__table_args__ = {"schema": "operationbi"}
|
||||
__table_args__ = (
|
||||
UniqueConstraint("hn", "vn", "location", name="uq_raw_opd_checkpoint_hn_vn_location"),
|
||||
{"schema": "rawdata"},
|
||||
)
|
||||
|
||||
id: Mapped[int] = mapped_column(BigInteger, primary_key=True)
|
||||
hn: Mapped[int] = mapped_column(BigInteger, nullable=False)
|
||||
@@ -38,6 +41,15 @@ class ApiClient(Base):
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
client_id = getattr(self, "id", None)
|
||||
if client_id is None:
|
||||
return self.name
|
||||
return f"{self.name} ({client_id})"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return str(self)
|
||||
|
||||
|
||||
class ApiKey(Base):
|
||||
__tablename__ = "api_key"
|
||||
|
||||
@@ -1,21 +1,99 @@
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
from starlette.datastructures import Headers
|
||||
from starlette.middleware.sessions import SessionMiddleware
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
class ForceHTTPSMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request, call_next):
|
||||
# บังคับให้ FastAPI มองว่า Request ที่เข้ามาเป็น HTTPS เสมอ
|
||||
# เพื่อให้ url_for() เจนลิงก์ CSS/JS เป็น https://
|
||||
request.scope["scheme"] = "https"
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
|
||||
class ForwardedProtoMiddleware:
|
||||
def __init__(self, app):
|
||||
self.app = app
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
if scope["type"] in {"http", "websocket"}:
|
||||
headers = Headers(scope=scope)
|
||||
forwarded_proto = headers.get("x-forwarded-proto")
|
||||
if forwarded_proto:
|
||||
proto = forwarded_proto.split(",", 1)[0].strip()
|
||||
if proto:
|
||||
new_scope = dict(scope)
|
||||
new_scope["scheme"] = proto
|
||||
return await self.app(new_scope, receive, send)
|
||||
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
|
||||
# class RootPathStripMiddleware:
|
||||
# def __init__(self, app, prefix: str):
|
||||
# self.app = app
|
||||
# self.prefix = (prefix or "").rstrip("/")
|
||||
|
||||
# async def __call__(self, scope, receive, send):
|
||||
# if scope["type"] in {"http", "websocket"} and self.prefix:
|
||||
# path = scope.get("path") or ""
|
||||
# new_scope = dict(scope)
|
||||
# new_scope["root_path"] = self.prefix
|
||||
|
||||
# if path == self.prefix or path.startswith(self.prefix + "/"):
|
||||
# new_path = path[len(self.prefix) :]
|
||||
# new_scope["path"] = new_path if new_path else "/"
|
||||
|
||||
# return await self.app(new_scope, receive, send)
|
||||
|
||||
# return await self.app(scope, receive, send)
|
||||
|
||||
from app.admin import mount_admin
|
||||
from app.api.v1.routes import router as v1_router
|
||||
from app.core.config import settings
|
||||
from app.db.init_db import init_db
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from sqladmin import Admin
|
||||
import os
|
||||
import sqladmin
|
||||
|
||||
# รายชื่อ Origins ที่อนุญาตให้ยิง API มาหาเราได้
|
||||
origins = [
|
||||
"http://localhost:80400", # สำหรับตอนพัฒนา Frontend
|
||||
"https://ai.sriphat.com", # Domain หลักของคุณ
|
||||
"http://ai.sriphat.com",
|
||||
]
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(_: FastAPI):
|
||||
init_db()
|
||||
yield
|
||||
|
||||
print(settings.ROOT_PATH, flush=True)
|
||||
|
||||
sqladmin_dir = os.path.dirname(sqladmin.__file__)
|
||||
statics_path = os.path.join(sqladmin_dir, "statics")
|
||||
|
||||
app = FastAPI(title=settings.APP_NAME, root_path=settings.ROOT_PATH, lifespan=lifespan)
|
||||
#if settings.ROOT_PATH:
|
||||
# app.add_middleware(RootPathStripMiddleware, prefix=settings.ROOT_PATH)
|
||||
app.add_middleware(ForceHTTPSMiddleware)
|
||||
app.add_middleware(SessionMiddleware, secret_key=settings.ADMIN_SECRET_KEY)
|
||||
app.add_middleware(ForwardedProtoMiddleware)
|
||||
app.include_router(v1_router)
|
||||
app.mount("/admin/statics", StaticFiles(directory=statics_path), name="admin_statics")
|
||||
app.mount("/apiservice/admin/statics", StaticFiles(directory=statics_path), name="proxy_admin_statics")
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=origins, # หรือ ["*"] ถ้าต้องการอนุญาตทั้งหมด (ไม่แนะนำใน production)
|
||||
allow_credentials=True, # สำคัญมาก! ต้องเป็น True ถ้าหน้า Admin/API มีการใช้ Cookies/Sessions
|
||||
allow_methods=["*"], # อนุญาตทุก HTTP Method (GET, POST, PUT, DELETE, etc.)
|
||||
allow_headers=["*"], # อนุญาตทุก Headers
|
||||
)
|
||||
|
||||
mount_admin(app)
|
||||
|
||||
|
||||
@@ -1,15 +1,17 @@
|
||||
from typing import Annotated
|
||||
from collections.abc import Sequence
|
||||
|
||||
from fastapi import Depends, HTTPException, Request, status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from app.db.engine import engine
|
||||
from app.db.engine import engine, supabase_engine
|
||||
from app.db.models import ApiKey
|
||||
from app.security.api_key import get_prefix, verify_api_key
|
||||
|
||||
|
||||
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)
|
||||
SupabaseSessionLocal = sessionmaker(bind=supabase_engine, autoflush=False, autocommit=False)
|
||||
|
||||
|
||||
def get_db():
|
||||
@@ -20,6 +22,14 @@ def get_db():
|
||||
db.close()
|
||||
|
||||
|
||||
def get_supabase_db():
|
||||
db = SupabaseSessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def get_bearer_token(request: Request) -> str:
|
||||
auth = request.headers.get("authorization")
|
||||
if not auth:
|
||||
@@ -32,7 +42,7 @@ def get_bearer_token(request: Request) -> str:
|
||||
return parts[1].strip()
|
||||
|
||||
|
||||
def require_permission(permission: str):
|
||||
def require_permission(permission: str | Sequence[str]):
|
||||
def _dep(
|
||||
token: Annotated[str, Depends(get_bearer_token)],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
@@ -46,7 +56,9 @@ def require_permission(permission: str):
|
||||
if not verify_api_key(token, api_key.key_hash):
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key")
|
||||
|
||||
if permission not in (api_key.permissions or []):
|
||||
allowed = set(api_key.permissions or [])
|
||||
required = [permission] if isinstance(permission, str) else list(permission)
|
||||
if not any(p in allowed for p in required):
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Permission denied")
|
||||
|
||||
return api_key
|
||||
|
||||
Reference in New Issue
Block a user