"""Session-based authentication helpers for the Flask application.

Provides password hashing wrappers, per-request user loading, login/logout
session bookkeeping, validation of post-login redirect targets, and view
decorators that require an authenticated (or admin) user.
"""

from functools import wraps
from typing import Callable, TypeVar
from urllib.parse import quote, unquote

from flask import abort, g, redirect, request, session
from sqlalchemy import select
from werkzeug.security import check_password_hash, generate_password_hash

from l4d2web.db import session_scope
from l4d2web.models import User

F = TypeVar("F", bound=Callable)


def hash_password(raw: str) -> str:
    """Return a Werkzeug password digest for the plain-text password *raw*."""
    return generate_password_hash(raw)


def verify_password(raw: str, digest: str) -> bool:
    """Return True if the plain-text password *raw* matches *digest*.

    Note the argument order is (password, digest), the reverse of
    ``check_password_hash``.
    """
    return check_password_hash(digest, raw)


def load_current_user() -> None:
    """Populate ``g.user`` from the ``user_id`` stored in the Flask session.

    ``g.user`` is set to ``None`` when the session holds no ``user_id``;
    otherwise it is set to the result of looking the user up by id, which is
    ``None`` if no matching row exists.
    """
    user_id = session.get("user_id")
    if user_id is None:
        g.user = None
        return
    with session_scope() as db:
        g.user = db.scalar(select(User).where(User.id == int(user_id)))


def current_user() -> User | None:
    """Return the user stored on ``g``, or ``None`` if none has been set."""
    return getattr(g, "user", None)


def login_user(user_id: int) -> None:
    """Record *user_id* in the Flask session."""
    session["user_id"] = user_id


def logout_user() -> None:
    """Remove ``user_id`` from the Flask session if it is present."""
    session.pop("user_id", None)


def _is_unsafe_redirect_form(candidate: str) -> bool:
    """Return True if *candidate* could be read as pointing off-site."""
    # "//host" is protocol-relative, "://" indicates an absolute URL, and
    # browsers treat "\" like "/" when parsing URLs.
    if candidate.startswith("//") or "://" in candidate or "\\" in candidate:
        return True
    # Browsers strip tab/CR/LF while parsing a URL, so "/\t/evil.example"
    # can collapse to "//evil.example". Reject every C0 control and DEL.
    return any(ch < " " or ch == "\x7f" for ch in candidate)


def is_safe_next(target: str | None) -> bool:
    """Return True if *target* is an acceptable local redirect path.

    A target is accepted only when it is non-empty, starts with a single
    ``/``, and neither its raw form nor its percent-decoded form contains
    ``//`` at the start, ``://``, a backslash, or an ASCII control character.
    """
    if not target or not target.startswith("/"):
        return False
    # Check both forms so a percent-encoded payload cannot slip through.
    forms = (target, unquote(target))
    return not any(_is_unsafe_redirect_form(form) for form in forms)


def login_redirect_for_current_request():
    """Return a redirect to ``/login`` that remembers the current request.

    The current path and query string are passed as the ``next`` parameter
    when they pass ``is_safe_next``; otherwise the redirect goes to a bare
    ``/login``.
    """
    # full_path always ends with "?" when there is no query string.
    target = request.full_path.rstrip("?")
    if is_safe_next(target):
        return redirect(f"/login?next={quote(target, safe='/')}")
    return redirect("/login")


def require_login(func: F) -> F:
    """Decorate a view so it redirects to the login page when no user is set."""

    @wraps(func)
    def wrapper(*args, **kwargs):
        if current_user() is None:
            return login_redirect_for_current_request()
        return func(*args, **kwargs)

    return wrapper  # type: ignore[return-value]


def require_admin(func: F) -> F:
    """Decorate a view so only users with a truthy ``admin`` flag may call it.

    With no current user the request is redirected to the login page; a
    current user without the admin flag gets a 403 response.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        user = current_user()
        if user is None:
            return login_redirect_for_current_request()
        if not user.admin:
            abort(403)
        return func(*args, **kwargs)

    return wrapper  # type: ignore[return-value]