"""Authentication helpers: password hashing, session login state, route guards."""

from datetime import datetime
from functools import wraps
from typing import Callable, TypeVar
from urllib.parse import quote, unquote

from flask import abort, g, redirect, request, session
from sqlalchemy import select
from werkzeug.security import check_password_hash, generate_password_hash

from l4d2web.db import session_scope
from l4d2web.models import User

F = TypeVar("F", bound=Callable)


def hash_password(raw: str) -> str:
    """Return the digest of ``raw`` produced by ``generate_password_hash``."""
    return generate_password_hash(raw)


def verify_password(raw: str, digest: str) -> bool:
    """Return whether ``raw`` matches the stored ``digest``."""
    return check_password_hash(digest, raw)


MIN_PASSWORD_LENGTH = 8


def validate_new_password(raw: str) -> str | None:
    """Validate a candidate password.

    Returns:
        An error message describing the problem, or ``None`` when ``raw`` is
        acceptable.
    """
    if raw == "":
        return "password must not be empty"
    if len(raw) < MIN_PASSWORD_LENGTH:
        return f"password must be at least {MIN_PASSWORD_LENGTH} characters"
    return None


def _parse_session_marker(marker: object) -> datetime | None:
    """Parse the session's ``pw_changed_at`` value.

    Returns ``None`` when the value is missing, not a string, or not a valid
    ISO-8601 timestamp, so the caller can treat the session as invalid.
    """
    if not isinstance(marker, str):
        return None
    try:
        return datetime.fromisoformat(marker)
    except ValueError:
        return None


def _marker_predates_change(marker_dt: datetime, changed_at: datetime) -> bool:
    """Return whether the session marker is older than the password change.

    Aware datetimes are compared directly when both sides carry a timezone.
    Otherwise the timezone is dropped from whichever side has one, because
    Python refuses to order a naive datetime against an aware one.
    """
    if marker_dt.tzinfo is not None and changed_at.tzinfo is not None:
        return marker_dt < changed_at
    # user.password_changed_at comes back naive from SQLite; strip tz from the
    # marker so an aware-marker session (just stamped from an in-memory user)
    # compares cleanly with a freshly-loaded user row.
    return marker_dt.replace(tzinfo=None) < changed_at.replace(tzinfo=None)


def load_current_user() -> None:
    """Populate ``g.user`` from the session, or set it to ``None``.

    ``g.user`` is set to the ``User`` row only when all of these hold:

    * the session has a ``user_id`` that is a valid integer,
    * the session has a parseable ``pw_changed_at`` marker,
    * a user with that id exists and is active,
    * the marker is not older than the user's ``password_changed_at``.

    Any malformed or missing value results in ``g.user = None`` rather than
    an exception.
    """
    # Fail closed: only overwritten once every check below has passed.
    g.user = None

    raw_user_id = session.get("user_id")
    if raw_user_id is None:
        return
    try:
        user_id = int(raw_user_id)
    except (TypeError, ValueError):
        return

    # Parsed before the query so an invalid session never costs a DB hit.
    marker_dt = _parse_session_marker(session.get("pw_changed_at"))
    if marker_dt is None:
        return

    with session_scope() as db:
        user = db.scalar(select(User).where(User.id == user_id))
        if user is None or not user.active:
            return
        changed_at = user.password_changed_at
        if changed_at is None:
            # Nothing to compare the marker against; reject the session.
            return
        if _marker_predates_change(marker_dt, changed_at):
            return
        g.user = user


def current_user() -> User | None:
    """Return ``g.user``, or ``None`` if it has not been set."""
    return getattr(g, "user", None)


def login_user(user_id: int, password_changed_at: datetime) -> None:
    """Record a login in the session.

    Stores ``user_id`` and the ISO-8601 form of ``password_changed_at``; the
    latter is the marker that ``load_current_user`` compares against the
    user's current ``password_changed_at``.
    """
    session["user_id"] = user_id
    session["pw_changed_at"] = password_changed_at.isoformat()


def logout_user() -> None:
    """Remove the login state written by ``login_user`` from the session."""
    session.pop("user_id", None)
    # Drop the marker too so no stale login data is left in the session.
    session.pop("pw_changed_at", None)


def _looks_external(candidate: str) -> bool:
    """Return whether ``candidate`` could be read as an off-site URL."""
    if candidate.startswith("//") or "://" in candidate or "\\" in candidate:
        return True
    # Browsers strip ASCII tab/newline characters while parsing a URL, so
    # "/\t/evil.example" would be followed as "//evil.example". Reject every
    # ASCII control character to stay on the safe side.
    return any(ord(ch) < 0x20 or ord(ch) == 0x7F for ch in candidate)


def is_safe_next(target: str | None) -> bool:
    """Return whether ``target`` is a safe same-site redirect path.

    A safe target is non-empty, starts with a single ``/``, and contains no
    ``://``, backslash, or ASCII control character. The same checks are
    applied to the percent-decoded form of ``target``.
    """
    if not target or not target.startswith("/"):
        return False
    return not (_looks_external(target) or _looks_external(unquote(target)))


def login_redirect_for_current_request():
    """Redirect to ``/login``, passing the current URL as ``next`` when safe."""
    # full_path carries a trailing "?" when there is no query string.
    target = request.full_path.rstrip("?")
    if is_safe_next(target):
        return redirect(f"/login?next={quote(target, safe='/')}")
    return redirect("/login")


def require_login(func: F) -> F:
    """Decorate a view so anonymous requests are redirected to the login page."""

    @wraps(func)
    def wrapper(*args, **kwargs):
        if current_user() is None:
            return login_redirect_for_current_request()
        return func(*args, **kwargs)

    return wrapper  # type: ignore[return-value]


def require_admin(func: F) -> F:
    """Decorate a view so only admin users may call it.

    Anonymous requests are redirected to the login page; logged-in users
    without the ``admin`` flag get a 403.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        user = current_user()
        if user is None:
            return login_redirect_for_current_request()
        if not user.admin:
            abort(403)
        return func(*args, **kwargs)

    return wrapper  # type: ignore[return-value]