diff --git a/containers/clearview/src/clearview_app/auth/models.py b/containers/clearview/src/clearview_app/auth/models.py index 7317812..3736ca1 100644 --- a/containers/clearview/src/clearview_app/auth/models.py +++ b/containers/clearview/src/clearview_app/auth/models.py @@ -9,7 +9,7 @@ from __future__ import annotations from datetime import datetime, timezone from typing import Any -from sqlalchemy import Boolean, DateTime, ForeignKey, Integer, JSON, String, Text +from sqlalchemy import Boolean, DateTime, ForeignKey, Integer, JSON, String, Text, TypeDecorator from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column @@ -17,6 +17,32 @@ def _utcnow() -> datetime: return datetime.now(timezone.utc) +class UTCDateTime(TypeDecorator): + """DateTime that always returns tz-aware UTC values. + + SQLite (used in tests) does not preserve tzinfo on roundtrip even with + ``DateTime(timezone=True)``. This decorator normalises stored and loaded + values to UTC-aware datetimes so app code can rely on tz arithmetic. + """ + + impl = DateTime(timezone=True) + cache_ok = True + + def process_bind_param(self, value, dialect): + if value is None: + return None + if value.tzinfo is None: + value = value.replace(tzinfo=timezone.utc) + return value.astimezone(timezone.utc) + + def process_result_value(self, value, dialect): + if value is None: + return None + if value.tzinfo is None: + return value.replace(tzinfo=timezone.utc) + return value.astimezone(timezone.utc) + + class Base(DeclarativeBase): pass @@ -29,8 +55,8 @@ class User(Base): password_hash: Mapped[str] = mapped_column(Text, nullable=False) role: Mapped[str] = mapped_column(String(16), nullable=False) is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True) - created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow, nullable=False) - updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow, nullable=False) + created_at: Mapped[datetime] = mapped_column(UTCDateTime(), default=_utcnow, nullable=False) + updated_at: Mapped[datetime] = mapped_column(UTCDateTime(), default=_utcnow, nullable=False) class UserSession(Base): @@ -40,9 +66,9 @@ class UserSession(Base): user_id: Mapped[int] = mapped_column( Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True ) - created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow, nullable=False) - expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, index=True) - last_seen_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow, nullable=False) + created_at: Mapped[datetime] = mapped_column(UTCDateTime(), default=_utcnow, nullable=False) + expires_at: Mapped[datetime] = mapped_column(UTCDateTime(), nullable=False, index=True) + last_seen_at: Mapped[datetime] = mapped_column(UTCDateTime(), default=_utcnow, nullable=False) ip: Mapped[str | None] = mapped_column(String(64), nullable=True) user_agent: Mapped[str | None] = mapped_column(Text, nullable=True) remember: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) @@ -52,7 +78,7 @@ class AuthAudit(Base): __tablename__ = "auth_audit" id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) - ts: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow, nullable=False, index=True) + ts: Mapped[datetime] = mapped_column(UTCDateTime(), default=_utcnow, nullable=False, index=True) user_id: Mapped[int | None] = mapped_column( Integer, ForeignKey("users.id", ondelete="SET NULL"), nullable=True ) diff --git a/containers/clearview/src/clearview_app/auth/sessions.py b/containers/clearview/src/clearview_app/auth/sessions.py new file mode 100644 index 0000000..2263f1f --- /dev/null +++ b/containers/clearview/src/clearview_app/auth/sessions.py @@ -0,0 +1,72 @@ +"""Session lifecycle: create, look up + refresh, revoke, purge expired.""" +from __future__ import annotations + +from datetime import datetime, timedelta, timezone + +from sqlalchemy import delete +from sqlalchemy.orm import Session + +from .models import UserSession +from .security import new_session_id + +SLIDING_TTL = timedelta(hours=8) +REMEMBER_TTL = timedelta(days=30) + + +def _utcnow() -> datetime: + return datetime.now(timezone.utc) + + +def create_session( + db: Session, + *, + user_id: int, + remember: bool, + ip: str | None, + user_agent: str | None, +) -> tuple[str, datetime]: + ttl = REMEMBER_TTL if remember else SLIDING_TTL + expires = _utcnow() + ttl + sid = new_session_id() + db.add( + UserSession( + id=sid, + user_id=user_id, + expires_at=expires, + ip=ip, + user_agent=user_agent, + remember=remember, + ) + ) + db.flush() + return sid, expires + + +def lookup_and_refresh(db: Session, sid: str | None) -> UserSession | None: + if not sid: + return None + row = db.get(UserSession, sid) + if row is None: + return None + now = _utcnow() + expires = row.expires_at if row.expires_at.tzinfo else row.expires_at.replace(tzinfo=timezone.utc) + if expires <= now: + return None + row.last_seen_at = now + if not row.remember: + row.expires_at = now + SLIDING_TTL + return row + + +def revoke(db: Session, sid: str) -> None: + db.execute(delete(UserSession).where(UserSession.id == sid)) + + +def revoke_all_for_user(db: Session, user_id: int) -> int: + res = db.execute(delete(UserSession).where(UserSession.user_id == user_id)) + return res.rowcount or 0 + + +def purge_expired(db: Session) -> int: + res = db.execute(delete(UserSession).where(UserSession.expires_at <= _utcnow())) + return res.rowcount or 0 diff --git a/containers/clearview/tests/test_sessions.py b/containers/clearview/tests/test_sessions.py new file mode 100644 index 0000000..3093ace --- /dev/null +++ b/containers/clearview/tests/test_sessions.py @@ -0,0 +1,79 @@ +from datetime import datetime, timedelta, timezone + +import pytest + +from clearview_app.auth import sessions as S +from clearview_app.auth.models import User, UserSession + + +@pytest.fixture() +def user(db_session): + u = User(username="alice", password_hash="x", role="admin") + db_session.add(u); db_session.commit(); db_session.refresh(u) + return u + + +def test_create_session_sliding(db_session, user): + sid, expires = S.create_session(db_session, user_id=user.id, remember=False, ip=None, user_agent=None) + db_session.commit() + assert len(sid) == 32 + row = db_session.get(UserSession, sid) + assert row.remember is False + delta = row.expires_at - datetime.now(timezone.utc) + assert timedelta(hours=7, minutes=55) < delta < timedelta(hours=8, minutes=5) + + +def test_create_session_remember(db_session, user): + sid, _ = S.create_session(db_session, user_id=user.id, remember=True, ip=None, user_agent=None) + db_session.commit() + row = db_session.get(UserSession, sid) + delta = row.expires_at - datetime.now(timezone.utc) + assert delta > timedelta(days=29) + + +def test_lookup_refresh_sliding_extends(db_session, user): + sid, _ = S.create_session(db_session, user_id=user.id, remember=False, ip=None, user_agent=None) + db_session.commit() + row = db_session.get(UserSession, sid) + row.expires_at = datetime.now(timezone.utc) + timedelta(minutes=5) + db_session.commit() + looked = S.lookup_and_refresh(db_session, sid) + db_session.commit() + assert looked is not None + assert looked.expires_at - datetime.now(timezone.utc) > timedelta(hours=7) + + +def test_lookup_refresh_remember_does_not_slide(db_session, user): + sid, _ = S.create_session(db_session, user_id=user.id, remember=True, ip=None, user_agent=None) + db_session.commit() + before = db_session.get(UserSession, sid).expires_at + S.lookup_and_refresh(db_session, sid) + db_session.commit() + after = db_session.get(UserSession, sid).expires_at + assert before == after + + +def test_expired_session_returns_none(db_session, user): + sid, _ = S.create_session(db_session, user_id=user.id, remember=False, ip=None, user_agent=None) + row = db_session.get(UserSession, sid) + row.expires_at = datetime.now(timezone.utc) - timedelta(minutes=1) + db_session.commit() + assert S.lookup_and_refresh(db_session, sid) is None + + +def test_revoke(db_session, user): + sid, _ = S.create_session(db_session, user_id=user.id, remember=False, ip=None, user_agent=None) + db_session.commit() + S.revoke(db_session, sid); db_session.commit() + assert db_session.get(UserSession, sid) is None + + +def test_purge_expired(db_session, user): + fresh, _ = S.create_session(db_session, user_id=user.id, remember=False, ip=None, user_agent=None) + stale, _ = S.create_session(db_session, user_id=user.id, remember=False, ip=None, user_agent=None) + db_session.get(UserSession, stale).expires_at = datetime.now(timezone.utc) - timedelta(hours=1) + db_session.commit() + removed = S.purge_expired(db_session); db_session.commit() + assert removed == 1 + assert db_session.get(UserSession, fresh) is not None + assert db_session.get(UserSession, stale) is None