auth: add session lifecycle (create/lookup/refresh/revoke/purge)
This commit is contained in:
parent
f4dd7a507f
commit
0b7b58efe9
@ -9,7 +9,7 @@ from __future__ import annotations
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import Boolean, DateTime, ForeignKey, Integer, JSON, String, Text
|
||||
from sqlalchemy import Boolean, DateTime, ForeignKey, Integer, JSON, String, Text, TypeDecorator
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||
|
||||
|
||||
@ -17,6 +17,32 @@ def _utcnow() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
class UTCDateTime(TypeDecorator):
|
||||
"""DateTime that always returns tz-aware UTC values.
|
||||
|
||||
SQLite (used in tests) does not preserve tzinfo on roundtrip even with
|
||||
``DateTime(timezone=True)``. This decorator normalises stored and loaded
|
||||
values to UTC-aware datetimes so app code can rely on tz arithmetic.
|
||||
"""
|
||||
|
||||
impl = DateTime(timezone=True)
|
||||
cache_ok = True
|
||||
|
||||
def process_bind_param(self, value, dialect):
|
||||
if value is None:
|
||||
return None
|
||||
if value.tzinfo is None:
|
||||
value = value.replace(tzinfo=timezone.utc)
|
||||
return value.astimezone(timezone.utc)
|
||||
|
||||
def process_result_value(self, value, dialect):
|
||||
if value is None:
|
||||
return None
|
||||
if value.tzinfo is None:
|
||||
return value.replace(tzinfo=timezone.utc)
|
||||
return value.astimezone(timezone.utc)
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
@ -29,8 +55,8 @@ class User(Base):
|
||||
password_hash: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
role: Mapped[str] = mapped_column(String(16), nullable=False)
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow, nullable=False)
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(UTCDateTime(), default=_utcnow, nullable=False)
|
||||
updated_at: Mapped[datetime] = mapped_column(UTCDateTime(), default=_utcnow, nullable=False)
|
||||
|
||||
|
||||
class UserSession(Base):
|
||||
@ -40,9 +66,9 @@ class UserSession(Base):
|
||||
user_id: Mapped[int] = mapped_column(
|
||||
Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow, nullable=False)
|
||||
expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, index=True)
|
||||
last_seen_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(UTCDateTime(), default=_utcnow, nullable=False)
|
||||
expires_at: Mapped[datetime] = mapped_column(UTCDateTime(), nullable=False, index=True)
|
||||
last_seen_at: Mapped[datetime] = mapped_column(UTCDateTime(), default=_utcnow, nullable=False)
|
||||
ip: Mapped[str | None] = mapped_column(String(64), nullable=True)
|
||||
user_agent: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
remember: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
@ -52,7 +78,7 @@ class AuthAudit(Base):
|
||||
__tablename__ = "auth_audit"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
ts: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow, nullable=False, index=True)
|
||||
ts: Mapped[datetime] = mapped_column(UTCDateTime(), default=_utcnow, nullable=False, index=True)
|
||||
user_id: Mapped[int | None] = mapped_column(
|
||||
Integer, ForeignKey("users.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
|
||||
72
containers/clearview/src/clearview_app/auth/sessions.py
Normal file
72
containers/clearview/src/clearview_app/auth/sessions.py
Normal file
@ -0,0 +1,72 @@
|
||||
"""Session lifecycle: create, look up + refresh, revoke, purge expired."""
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from .models import UserSession
|
||||
from .security import new_session_id
|
||||
|
||||
SLIDING_TTL = timedelta(hours=8)
|
||||
REMEMBER_TTL = timedelta(days=30)
|
||||
|
||||
|
||||
def _utcnow() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def create_session(
|
||||
db: Session,
|
||||
*,
|
||||
user_id: int,
|
||||
remember: bool,
|
||||
ip: str | None,
|
||||
user_agent: str | None,
|
||||
) -> tuple[str, datetime]:
|
||||
ttl = REMEMBER_TTL if remember else SLIDING_TTL
|
||||
expires = _utcnow() + ttl
|
||||
sid = new_session_id()
|
||||
db.add(
|
||||
UserSession(
|
||||
id=sid,
|
||||
user_id=user_id,
|
||||
expires_at=expires,
|
||||
ip=ip,
|
||||
user_agent=user_agent,
|
||||
remember=remember,
|
||||
)
|
||||
)
|
||||
db.flush()
|
||||
return sid, expires
|
||||
|
||||
|
||||
def lookup_and_refresh(db: Session, sid: str | None) -> UserSession | None:
|
||||
if not sid:
|
||||
return None
|
||||
row = db.get(UserSession, sid)
|
||||
if row is None:
|
||||
return None
|
||||
now = _utcnow()
|
||||
expires = row.expires_at if row.expires_at.tzinfo else row.expires_at.replace(tzinfo=timezone.utc)
|
||||
if expires <= now:
|
||||
return None
|
||||
row.last_seen_at = now
|
||||
if not row.remember:
|
||||
row.expires_at = now + SLIDING_TTL
|
||||
return row
|
||||
|
||||
|
||||
def revoke(db: Session, sid: str) -> None:
|
||||
db.execute(delete(UserSession).where(UserSession.id == sid))
|
||||
|
||||
|
||||
def revoke_all_for_user(db: Session, user_id: int) -> int:
|
||||
res = db.execute(delete(UserSession).where(UserSession.user_id == user_id))
|
||||
return res.rowcount or 0
|
||||
|
||||
|
||||
def purge_expired(db: Session) -> int:
|
||||
res = db.execute(delete(UserSession).where(UserSession.expires_at <= _utcnow()))
|
||||
return res.rowcount or 0
|
||||
79
containers/clearview/tests/test_sessions.py
Normal file
79
containers/clearview/tests/test_sessions.py
Normal file
@ -0,0 +1,79 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from clearview_app.auth import sessions as S
|
||||
from clearview_app.auth.models import User, UserSession
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def user(db_session):
|
||||
u = User(username="alice", password_hash="x", role="admin")
|
||||
db_session.add(u); db_session.commit(); db_session.refresh(u)
|
||||
return u
|
||||
|
||||
|
||||
def test_create_session_sliding(db_session, user):
|
||||
sid, expires = S.create_session(db_session, user_id=user.id, remember=False, ip=None, user_agent=None)
|
||||
db_session.commit()
|
||||
assert len(sid) == 32
|
||||
row = db_session.get(UserSession, sid)
|
||||
assert row.remember is False
|
||||
delta = row.expires_at - datetime.now(timezone.utc)
|
||||
assert timedelta(hours=7, minutes=55) < delta < timedelta(hours=8, minutes=5)
|
||||
|
||||
|
||||
def test_create_session_remember(db_session, user):
|
||||
sid, _ = S.create_session(db_session, user_id=user.id, remember=True, ip=None, user_agent=None)
|
||||
db_session.commit()
|
||||
row = db_session.get(UserSession, sid)
|
||||
delta = row.expires_at - datetime.now(timezone.utc)
|
||||
assert delta > timedelta(days=29)
|
||||
|
||||
|
||||
def test_lookup_refresh_sliding_extends(db_session, user):
|
||||
sid, _ = S.create_session(db_session, user_id=user.id, remember=False, ip=None, user_agent=None)
|
||||
db_session.commit()
|
||||
row = db_session.get(UserSession, sid)
|
||||
row.expires_at = datetime.now(timezone.utc) + timedelta(minutes=5)
|
||||
db_session.commit()
|
||||
looked = S.lookup_and_refresh(db_session, sid)
|
||||
db_session.commit()
|
||||
assert looked is not None
|
||||
assert looked.expires_at - datetime.now(timezone.utc) > timedelta(hours=7)
|
||||
|
||||
|
||||
def test_lookup_refresh_remember_does_not_slide(db_session, user):
|
||||
sid, _ = S.create_session(db_session, user_id=user.id, remember=True, ip=None, user_agent=None)
|
||||
db_session.commit()
|
||||
before = db_session.get(UserSession, sid).expires_at
|
||||
S.lookup_and_refresh(db_session, sid)
|
||||
db_session.commit()
|
||||
after = db_session.get(UserSession, sid).expires_at
|
||||
assert before == after
|
||||
|
||||
|
||||
def test_expired_session_returns_none(db_session, user):
|
||||
sid, _ = S.create_session(db_session, user_id=user.id, remember=False, ip=None, user_agent=None)
|
||||
row = db_session.get(UserSession, sid)
|
||||
row.expires_at = datetime.now(timezone.utc) - timedelta(minutes=1)
|
||||
db_session.commit()
|
||||
assert S.lookup_and_refresh(db_session, sid) is None
|
||||
|
||||
|
||||
def test_revoke(db_session, user):
|
||||
sid, _ = S.create_session(db_session, user_id=user.id, remember=False, ip=None, user_agent=None)
|
||||
db_session.commit()
|
||||
S.revoke(db_session, sid); db_session.commit()
|
||||
assert db_session.get(UserSession, sid) is None
|
||||
|
||||
|
||||
def test_purge_expired(db_session, user):
|
||||
fresh, _ = S.create_session(db_session, user_id=user.id, remember=False, ip=None, user_agent=None)
|
||||
stale, _ = S.create_session(db_session, user_id=user.id, remember=False, ip=None, user_agent=None)
|
||||
db_session.get(UserSession, stale).expires_at = datetime.now(timezone.utc) - timedelta(hours=1)
|
||||
db_session.commit()
|
||||
removed = S.purge_expired(db_session); db_session.commit()
|
||||
assert removed == 1
|
||||
assert db_session.get(UserSession, fresh) is not None
|
||||
assert db_session.get(UserSession, stale) is None
|
||||
Loading…
Reference in New Issue
Block a user