auth: add require_user / require_admin FastAPI dependencies
This commit is contained in:
parent
96879e75f0
commit
a8cb96aa61
52
containers/clearview/src/clearview_app/auth/dependencies.py
Normal file
52
containers/clearview/src/clearview_app/auth/dependencies.py
Normal file
@ -0,0 +1,52 @@
|
||||
"""FastAPI dependencies that gate API endpoints behind a session."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Cookie, Depends, HTTPException, Request, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ..config import COOKIE_NAME # noqa: F401
|
||||
from ..db import SessionLocal
|
||||
from . import sessions as S
|
||||
from .models import User, UserSession
|
||||
|
||||
|
||||
AuthedUser = User
|
||||
|
||||
|
||||
def get_db():
|
||||
db: Session = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def _load_session(db: Session, sid: str | None) -> tuple[User, UserSession]:
|
||||
session = S.lookup_and_refresh(db, sid)
|
||||
if session is None:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated")
|
||||
user = db.get(User, session.user_id)
|
||||
if user is None or not user.is_active:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated")
|
||||
db.commit()
|
||||
return user, session
|
||||
|
||||
|
||||
def require_user(
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
clearview_session: Annotated[str | None, Cookie()] = None,
|
||||
) -> User:
|
||||
user, _ = _load_session(db, clearview_session)
|
||||
return user
|
||||
|
||||
|
||||
def require_admin(
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
clearview_session: Annotated[str | None, Cookie()] = None,
|
||||
) -> User:
|
||||
user, _ = _load_session(db, clearview_session)
|
||||
if user.role != "admin":
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin required")
|
||||
return user
|
||||
@ -12,6 +12,7 @@ from pathlib import Path
|
||||
import pytest
|
||||
from sqlalchemy import create_engine, event
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
SRC = Path(__file__).resolve().parents[1] / "src"
|
||||
sys.path.insert(0, str(SRC))
|
||||
@ -25,6 +26,7 @@ def db_engine():
|
||||
engine = create_engine(
|
||||
"sqlite+pysqlite:///:memory:",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
future=True,
|
||||
)
|
||||
|
||||
|
||||
91
containers/clearview/tests/test_dependencies.py
Normal file
91
containers/clearview/tests/test_dependencies.py
Normal file
@ -0,0 +1,91 @@
|
||||
import pytest
|
||||
from fastapi import Depends, FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from clearview_app.auth import sessions as S
|
||||
from clearview_app.auth.dependencies import (
|
||||
AuthedUser,
|
||||
get_db,
|
||||
require_admin,
|
||||
require_user,
|
||||
)
|
||||
from clearview_app.auth.models import User
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def app_and_client(db_engine):
|
||||
Session = sessionmaker(bind=db_engine, autoflush=False, autocommit=False, future=True)
|
||||
|
||||
def override_get_db():
|
||||
s = Session()
|
||||
try:
|
||||
yield s
|
||||
finally:
|
||||
s.close()
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
@app.get("/who")
|
||||
def who(u: AuthedUser = Depends(require_user)):
|
||||
return {"id": u.id, "role": u.role}
|
||||
|
||||
@app.get("/admin-only")
|
||||
def admin_only(u: AuthedUser = Depends(require_admin)):
|
||||
return {"ok": True}
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
return app, Session
|
||||
|
||||
|
||||
def _make_user(Session, role: str, username: str = "x"):
|
||||
s = Session()
|
||||
u = User(username=username, password_hash="h", role=role)
|
||||
s.add(u); s.commit(); s.refresh(u); s.close()
|
||||
return u
|
||||
|
||||
|
||||
def _login(Session, user_id: int) -> str:
|
||||
s = Session()
|
||||
sid, _ = S.create_session(s, user_id=user_id, remember=False, ip=None, user_agent=None)
|
||||
s.commit(); s.close()
|
||||
return sid
|
||||
|
||||
|
||||
def test_anon_gets_401(app_and_client):
|
||||
app, _ = app_and_client
|
||||
assert TestClient(app).get("/who").status_code == 401
|
||||
|
||||
|
||||
def test_user_can_access_require_user(app_and_client):
|
||||
app, Session = app_and_client
|
||||
u = _make_user(Session, "user")
|
||||
sid = _login(Session, u.id)
|
||||
c = TestClient(app); c.cookies.set("clearview_session", sid)
|
||||
r = c.get("/who")
|
||||
assert r.status_code == 200 and r.json()["role"] == "user"
|
||||
|
||||
|
||||
def test_user_blocked_from_admin(app_and_client):
|
||||
app, Session = app_and_client
|
||||
u = _make_user(Session, "user")
|
||||
sid = _login(Session, u.id)
|
||||
c = TestClient(app); c.cookies.set("clearview_session", sid)
|
||||
assert c.get("/admin-only").status_code == 403
|
||||
|
||||
|
||||
def test_admin_allowed(app_and_client):
|
||||
app, Session = app_and_client
|
||||
u = _make_user(Session, "admin")
|
||||
sid = _login(Session, u.id)
|
||||
c = TestClient(app); c.cookies.set("clearview_session", sid)
|
||||
assert c.get("/admin-only").status_code == 200
|
||||
|
||||
|
||||
def test_inactive_user_rejected(app_and_client):
|
||||
app, Session = app_and_client
|
||||
u = _make_user(Session, "admin")
|
||||
s = Session(); s.get(User, u.id).is_active = False; s.commit(); s.close()
|
||||
sid = _login(Session, u.id)
|
||||
c = TestClient(app); c.cookies.set("clearview_session", sid)
|
||||
assert c.get("/who").status_code == 401
|
||||
Loading…
Reference in New Issue
Block a user