auth: add session lifecycle (create/lookup/refresh/revoke/purge)

This commit is contained in:
Ivo Oskamp 2026-05-28 15:55:21 +02:00
parent f4dd7a507f
commit 0b7b58efe9
3 changed files with 184 additions and 7 deletions

View File

@ -9,7 +9,7 @@ from __future__ import annotations
from datetime import datetime, timezone
from typing import Any
from sqlalchemy import Boolean, DateTime, ForeignKey, Integer, JSON, String, Text
from sqlalchemy import Boolean, DateTime, ForeignKey, Integer, JSON, String, Text, TypeDecorator
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
@ -17,6 +17,32 @@ def _utcnow() -> datetime:
return datetime.now(timezone.utc)
class UTCDateTime(TypeDecorator):
"""DateTime that always returns tz-aware UTC values.
SQLite (used in tests) does not preserve tzinfo on roundtrip even with
``DateTime(timezone=True)``. This decorator normalises stored and loaded
values to UTC-aware datetimes so app code can rely on tz arithmetic.
"""
impl = DateTime(timezone=True)
cache_ok = True
def process_bind_param(self, value, dialect):
if value is None:
return None
if value.tzinfo is None:
value = value.replace(tzinfo=timezone.utc)
return value.astimezone(timezone.utc)
def process_result_value(self, value, dialect):
if value is None:
return None
if value.tzinfo is None:
return value.replace(tzinfo=timezone.utc)
return value.astimezone(timezone.utc)
class Base(DeclarativeBase):
pass
@ -29,8 +55,8 @@ class User(Base):
password_hash: Mapped[str] = mapped_column(Text, nullable=False)
role: Mapped[str] = mapped_column(String(16), nullable=False)
is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow, nullable=False)
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow, nullable=False)
created_at: Mapped[datetime] = mapped_column(UTCDateTime(), default=_utcnow, nullable=False)
updated_at: Mapped[datetime] = mapped_column(UTCDateTime(), default=_utcnow, nullable=False)
class UserSession(Base):
@ -40,9 +66,9 @@ class UserSession(Base):
user_id: Mapped[int] = mapped_column(
Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow, nullable=False)
expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, index=True)
last_seen_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow, nullable=False)
created_at: Mapped[datetime] = mapped_column(UTCDateTime(), default=_utcnow, nullable=False)
expires_at: Mapped[datetime] = mapped_column(UTCDateTime(), nullable=False, index=True)
last_seen_at: Mapped[datetime] = mapped_column(UTCDateTime(), default=_utcnow, nullable=False)
ip: Mapped[str | None] = mapped_column(String(64), nullable=True)
user_agent: Mapped[str | None] = mapped_column(Text, nullable=True)
remember: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
@ -52,7 +78,7 @@ class AuthAudit(Base):
__tablename__ = "auth_audit"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
ts: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow, nullable=False, index=True)
ts: Mapped[datetime] = mapped_column(UTCDateTime(), default=_utcnow, nullable=False, index=True)
user_id: Mapped[int | None] = mapped_column(
Integer, ForeignKey("users.id", ondelete="SET NULL"), nullable=True
)

View File

@ -0,0 +1,72 @@
"""Session lifecycle: create, look up + refresh, revoke, purge expired."""
from __future__ import annotations
from datetime import datetime, timedelta, timezone
from sqlalchemy import delete
from sqlalchemy.orm import Session
from .models import UserSession
from .security import new_session_id
SLIDING_TTL = timedelta(hours=8)
REMEMBER_TTL = timedelta(days=30)
def _utcnow() -> datetime:
return datetime.now(timezone.utc)
def create_session(
db: Session,
*,
user_id: int,
remember: bool,
ip: str | None,
user_agent: str | None,
) -> tuple[str, datetime]:
ttl = REMEMBER_TTL if remember else SLIDING_TTL
expires = _utcnow() + ttl
sid = new_session_id()
db.add(
UserSession(
id=sid,
user_id=user_id,
expires_at=expires,
ip=ip,
user_agent=user_agent,
remember=remember,
)
)
db.flush()
return sid, expires
def lookup_and_refresh(db: Session, sid: str | None) -> UserSession | None:
if not sid:
return None
row = db.get(UserSession, sid)
if row is None:
return None
now = _utcnow()
expires = row.expires_at if row.expires_at.tzinfo else row.expires_at.replace(tzinfo=timezone.utc)
if expires <= now:
return None
row.last_seen_at = now
if not row.remember:
row.expires_at = now + SLIDING_TTL
return row
def revoke(db: Session, sid: str) -> None:
db.execute(delete(UserSession).where(UserSession.id == sid))
def revoke_all_for_user(db: Session, user_id: int) -> int:
res = db.execute(delete(UserSession).where(UserSession.user_id == user_id))
return res.rowcount or 0
def purge_expired(db: Session) -> int:
res = db.execute(delete(UserSession).where(UserSession.expires_at <= _utcnow()))
return res.rowcount or 0

View File

@ -0,0 +1,79 @@
from datetime import datetime, timedelta, timezone
import pytest
from clearview_app.auth import sessions as S
from clearview_app.auth.models import User, UserSession
@pytest.fixture()
def user(db_session):
u = User(username="alice", password_hash="x", role="admin")
db_session.add(u); db_session.commit(); db_session.refresh(u)
return u
def test_create_session_sliding(db_session, user):
sid, expires = S.create_session(db_session, user_id=user.id, remember=False, ip=None, user_agent=None)
db_session.commit()
assert len(sid) == 32
row = db_session.get(UserSession, sid)
assert row.remember is False
delta = row.expires_at - datetime.now(timezone.utc)
assert timedelta(hours=7, minutes=55) < delta < timedelta(hours=8, minutes=5)
def test_create_session_remember(db_session, user):
sid, _ = S.create_session(db_session, user_id=user.id, remember=True, ip=None, user_agent=None)
db_session.commit()
row = db_session.get(UserSession, sid)
delta = row.expires_at - datetime.now(timezone.utc)
assert delta > timedelta(days=29)
def test_lookup_refresh_sliding_extends(db_session, user):
sid, _ = S.create_session(db_session, user_id=user.id, remember=False, ip=None, user_agent=None)
db_session.commit()
row = db_session.get(UserSession, sid)
row.expires_at = datetime.now(timezone.utc) + timedelta(minutes=5)
db_session.commit()
looked = S.lookup_and_refresh(db_session, sid)
db_session.commit()
assert looked is not None
assert looked.expires_at - datetime.now(timezone.utc) > timedelta(hours=7)
def test_lookup_refresh_remember_does_not_slide(db_session, user):
sid, _ = S.create_session(db_session, user_id=user.id, remember=True, ip=None, user_agent=None)
db_session.commit()
before = db_session.get(UserSession, sid).expires_at
S.lookup_and_refresh(db_session, sid)
db_session.commit()
after = db_session.get(UserSession, sid).expires_at
assert before == after
def test_expired_session_returns_none(db_session, user):
sid, _ = S.create_session(db_session, user_id=user.id, remember=False, ip=None, user_agent=None)
row = db_session.get(UserSession, sid)
row.expires_at = datetime.now(timezone.utc) - timedelta(minutes=1)
db_session.commit()
assert S.lookup_and_refresh(db_session, sid) is None
def test_revoke(db_session, user):
sid, _ = S.create_session(db_session, user_id=user.id, remember=False, ip=None, user_agent=None)
db_session.commit()
S.revoke(db_session, sid); db_session.commit()
assert db_session.get(UserSession, sid) is None
def test_purge_expired(db_session, user):
fresh, _ = S.create_session(db_session, user_id=user.id, remember=False, ip=None, user_agent=None)
stale, _ = S.create_session(db_session, user_id=user.id, remember=False, ip=None, user_agent=None)
db_session.get(UserSession, stale).expires_at = datetime.now(timezone.utc) - timedelta(hours=1)
db_session.commit()
removed = S.purge_expired(db_session); db_session.commit()
assert removed == 1
assert db_session.get(UserSession, fresh) is not None
assert db_session.get(UserSession, stale) is None