#!/usr/bin/python3
"""Privileged overlay mount helper for left4me.

Invoked via sudo by the left4me runtime user. Validates inputs strictly,
then execs nsenter to run mount(8)/umount(8) inside PID 1's mount
namespace, so the resulting mount lives in the host namespace and is
visible to the systemd-managed gameserver units.

Verbs:
    mount <name>    Reads ${LEFT4ME_ROOT}/instances/<name>/instance.env
                    for L4D2_LOWERDIRS, resolves every lowerdir (following
                    symlinks) and checks it lies under one of the
                    installation, overlays, workshop_cache or
                    global_overlay_cache directories, then mounts the
                    kernel overlay at runtime/<name>/merged using
                    runtime/<name>/upper and runtime/<name>/work.
    umount <name>   Unmounts runtime/<name>/merged.

Set LEFT4ME_OVERLAY_PRINT_ONLY=1 to print the would-be argv (one line,
shell-quoted) and exit 0 instead of execv. Used by tests.
"""

import os
import re
import shlex
import sys
from pathlib import Path
from typing import NoReturn

# Legal instance names: a lowercase letter or digit, then up to 63 more of
# [a-z0-9_-]. No '.' or '/', so a name is always a single path component.
NAME_RE = re.compile(r"^[a-z0-9][a-z0-9_-]{0,63}$")

# Data root used when LEFT4ME_ROOT is unset or empty.
DEFAULT_ROOT = "/var/lib/left4me"

# Subdirectories of the data root under which every lowerdir must resolve.
LOWERDIR_ALLOWLIST = (
    "installation", "overlays", "global_overlay_cache", "workshop_cache",
)

# Upper bound on the number of L4D2_LOWERDIRS entries (presumably chosen to
# match the kernel overlayfs stack-depth limit -- TODO confirm).
MAX_LOWERDIRS = 500

# Programs are referenced by absolute path rather than looked up by name.
NSENTER = "/usr/bin/nsenter"
MOUNT_BIN, UMOUNT_BIN = "/bin/mount", "/bin/umount"


def die(msg: str) -> NoReturn:
    """Write ``left4me-overlay: <msg>`` to stderr and exit with status 1.

    Annotated ``NoReturn`` (it always raises SystemExit) so type checkers
    know control never continues past a call; e.g. canonical_under ends in
    a die() call with no explicit return.
    """
    sys.stderr.write(f"left4me-overlay: {msg}\n")
    sys.exit(1)


def root() -> Path:
    """Return the left4me data root as an unresolved Path.

    A non-empty LEFT4ME_ROOT environment variable wins; otherwise
    DEFAULT_ROOT is used.

    NOTE(review): this helper runs under sudo, so the caller choosing
    LEFT4ME_ROOT would also choose the roots the lowerdir allowlist is
    checked against. Confirm the sudoers policy does not pass it through
    (env_keep / SETENV).
    """
    configured = os.environ.get("LEFT4ME_ROOT", "")
    return Path(configured if configured else DEFAULT_ROOT)


def validate_name(name: str) -> str:
    """Return *name* unchanged if it is a legal instance name; else die().

    The name must match NAME_RE in full. It is later used as a single path
    component under instances/ and runtime/, and the pattern admits no '.'
    or '/'.
    """
    matched = NAME_RE.fullmatch(name)
    if matched is None:
        die(f"invalid instance name: {name!r}")
    return name


def parse_lowerdirs(env_path: Path) -> list[str]:
    """Return the ``:``-separated entries of L4D2_LOWERDIRS in *env_path*.

    The first line whose key (the text before the first ``=``, with
    surrounding whitespace stripped) is ``L4D2_LOWERDIRS`` supplies the
    value. The value is used verbatim: no unquoting and no whitespace
    stripping. Lines without ``=`` are skipped. Entries are returned as
    raw strings, with no path checks.

    Exits via die() in any of these cases:
    - the file is missing, unreadable or not valid UTF-8;
    - the key is absent or its value is empty;
    - any entry is empty (``a::b``, or a leading or trailing ``:``);
    - there are more than MAX_LOWERDIRS entries.
    """
    if not env_path.is_file():
        die(f"instance.env not found: {env_path}")
    try:
        # Explicit encoding so parsing does not depend on the invoking
        # locale. Read/decode failures become a clean die() instead of a
        # traceback.
        text = env_path.read_text(encoding="utf-8")
    except (OSError, UnicodeDecodeError) as exc:
        die(f"cannot read {env_path}: {exc}")
    raw = None
    for line in text.splitlines():
        key, sep, value = line.partition("=")
        if sep and key.strip() == "L4D2_LOWERDIRS":
            raw = value
            break
    if raw is None:
        die(f"L4D2_LOWERDIRS not set in {env_path}")
    if raw == "":
        die(f"L4D2_LOWERDIRS is empty in {env_path}")
    parts = raw.split(":")
    if "" in parts:
        die(f"L4D2_LOWERDIRS contains an empty entry: {raw!r}")
    if len(parts) > MAX_LOWERDIRS:
        die(f"L4D2_LOWERDIRS has {len(parts)} entries (cap {MAX_LOWERDIRS})")
    return parts


def canonical_under(allowed_roots: list[Path], path: Path) -> Path:
    """Resolve *path* (following symlinks) and require it under an allowed root.

    Returns the canonical path when it equals one of *allowed_roots* or lies
    beneath one. Roots are compared as given, so callers should pass
    already-resolved paths. Exits via die() if *path* cannot be resolved or
    its canonical form falls outside every root.
    """
    try:
        canonical = path.resolve(strict=True)
    except (OSError, RuntimeError):
        # OSError covers FileNotFoundError plus NotADirectoryError,
        # PermissionError and, from Python 3.13, ELOOP on symlink loops.
        # Older Pythons raise RuntimeError for a loop instead.
        die(f"path does not exist or has a symlink loop: {path}")
    for r in allowed_roots:
        if canonical == r or r in canonical.parents:
            return canonical
    die(f"path is outside the permitted roots: {path} (resolved: {canonical})")


_LISTXATTR = getattr(os, "listxattr", None)


def _entry_has_fuse_xattr(path: str) -> str | None:
    if _LISTXATTR is None:
        return None
    try:
        attrs = _LISTXATTR(path, follow_symlinks=False)
    except OSError:
        return None
    for a in attrs:
        if a.startswith("user.fuseoverlayfs."):
            return a
    return None


def assert_no_fuse_xattrs(upper: Path) -> None:
    """Die if *upper* or anything beneath it carries a fuse-overlayfs xattr.

    Looks for ``user.fuseoverlayfs.*`` attributes (presumably left behind
    by an earlier fuse-overlayfs mount -- confirm). If one is found, die()
    tells the operator to wipe upper/ and work/. This is a no-op when
    *upper* does not exist or listxattr is unavailable. It is best effort:
    os.walk skips directories it cannot list, and entries whose xattrs
    cannot be read count as clean.
    """
    if not upper.exists() or _LISTXATTR is None:
        return
    for index, (dirpath, dirnames, filenames) in enumerate(os.walk(upper)):
        entries = [os.path.join(dirpath, n) for n in (*dirnames, *filenames)]
        if index == 0:
            # Only the walk root needs checking as a dirpath. Each deeper
            # directory was already checked as an entry of its parent, so
            # checking it again as dirpath would listxattr it twice.
            entries.insert(0, dirpath)
        for entry in entries:
            tainted = _entry_has_fuse_xattr(entry)
            if tainted:
                die(
                    f"upperdir contains fuse-overlayfs xattr {tainted!r} on {entry}; "
                    "wipe upper/ and work/ before mounting"
                )


def exec_or_print(argv: list[str]) -> NoReturn:
    """Replace this process with *argv*, or print it and exit 0 in test mode.

    With LEFT4ME_OVERLAY_PRINT_ONLY=1, argv is written to stdout as one
    shell-quoted line and the process exits 0. Otherwise argv[0] is exec'd
    as the program path; execv does no PATH lookup. If the exec itself
    fails, exit through die() with a message instead of a traceback.
    """
    if os.environ.get("LEFT4ME_OVERLAY_PRINT_ONLY") == "1":
        print(shlex.join(argv))
        sys.exit(0)
    try:
        os.execv(argv[0], argv)
    except OSError as exc:
        die(f"failed to exec {argv[0]}: {exc}")


def cmd_mount(name: str) -> None:
    """Validate everything for instance *name*, then exec the overlay mount.

    Steps:
    1. Validate *name*.
    2. Read L4D2_LOWERDIRS from instances/<name>/instance.env and
       canonicalize each lowerdir against LOWERDIR_ALLOWLIST.
    3. Resolve runtime/<name>/{upper,work,merged} and require each to be a
       direct child of the resolved runtime/<name>.
    4. Reject paths that would break the overlay option string.
    5. Refuse an upperdir carrying fuse-overlayfs xattrs.
    6. Exec ``nsenter --mount=/proc/1/ns/mnt -- mount -t overlay ...``, or
       print it in test mode.

    Any failure exits via die().

    NOTE(review): validation happens before mount(8) runs and is not tied
    to it. If the sudo caller can write inside runtime/<name>, it could swap
    upper/work/merged for symlinks in between. Confirm that directory is not
    writable by the caller, or move to fd-based mounting.
    """
    name = validate_name(name)
    r = root()
    instance_env = r / "instances" / name / "instance.env"
    raw_lowerdirs = parse_lowerdirs(instance_env)

    allowed_roots = [(r / sub).resolve() for sub in LOWERDIR_ALLOWLIST]
    canonical_lowerdirs = [str(canonical_under(allowed_roots, Path(p))) for p in raw_lowerdirs]

    try:
        runtime_name_dir = (r / "runtime" / name).resolve(strict=True)
        upper = (runtime_name_dir / "upper").resolve(strict=True)
        work = (runtime_name_dir / "work").resolve(strict=True)
        merged = (runtime_name_dir / "merged").resolve(strict=True)
    except (OSError, RuntimeError) as exc:
        # Missing pieces or symlink loops give a clean error, not a traceback.
        die(f"cannot resolve runtime/{name} layout: {exc}")
    for label, path in (("upper", upper), ("work", work), ("merged", merged)):
        if path.parent != runtime_name_dir:
            die(f"{label} resolved outside runtime/{name}: {path}")

    # These paths are interpolated unescaped into one -o string. That string
    # is split on ',' into options and on ':' into lowerdir layers, with '\'
    # acting as an escape and '"' as quoting when mount(8) splits options.
    # A resolved path containing any of them (e.g. reached through a
    # symlink) could inject extra options or lowerdir layers that bypass the
    # allowlist checks above. Legitimate paths never need these characters.
    checked = [("lowerdir", p) for p in canonical_lowerdirs]
    checked += [("upperdir", str(upper)), ("workdir", str(work))]
    for label, path_str in checked:
        if any(c in path_str for c in ',:\\"'):
            die(f"{label} contains a mount-option delimiter: {path_str!r}")

    assert_no_fuse_xattrs(upper)

    options = f"lowerdir={':'.join(canonical_lowerdirs)},upperdir={upper},workdir={work}"
    argv = [
        NSENTER,
        "--mount=/proc/1/ns/mnt",
        "--",
        MOUNT_BIN,
        "-t", "overlay",
        "overlay",
        "-o", options,
        str(merged),
    ]
    exec_or_print(argv)


def cmd_umount(name: str) -> None:
    """Exec ``umount`` of runtime/<name>/merged inside PID 1's mount namespace.

    *name* is validated first. runtime/<name> and its ``merged`` entry must
    both resolve, and ``merged`` must resolve to a direct child of the
    resolved runtime/<name>. Any failure exits via die(), consistent with
    cmd_mount.
    """
    name = validate_name(name)
    r = root()
    try:
        runtime_name_dir = (r / "runtime" / name).resolve(strict=True)
        merged = (runtime_name_dir / "merged").resolve(strict=True)
    except (OSError, RuntimeError) as exc:
        # Missing directories or symlink loops give a clean error, not a traceback.
        die(f"cannot resolve runtime/{name}/merged: {exc}")
    if merged.parent != runtime_name_dir:
        die(f"merged resolved outside runtime/{name}: {merged}")
    argv = [
        NSENTER,
        "--mount=/proc/1/ns/mnt",
        "--",
        UMOUNT_BIN,
        str(merged),
    ]
    exec_or_print(argv)


def main(argv: list[str]) -> None:
    """Dispatch ``left4me-overlay mount|umount <name>``.

    *argv* is the full argument vector, including the program name.
    Anything other than exactly one known verb plus one name prints the
    usage line to stderr and exits with status 2.
    """
    verb = argv[1] if len(argv) == 3 else None
    if verb not in ("mount", "umount"):
        sys.stderr.write("usage: left4me-overlay mount|umount <name>\n")
        sys.exit(2)
    handler = cmd_mount if verb == "mount" else cmd_umount
    handler(argv[2])


if __name__ == "__main__":
    main(sys.argv)
