#!/usr/bin/python3
"""Privileged overlay mount helper for left4me.

Invoked from the systemd unit's ExecStartPre / ExecStopPost via
`+/usr/bin/nsenter --mount=/proc/1/ns/mnt -- …`. The unit-level
nsenter is what makes this work: it runs the helper Python interpreter
inside PID 1's mount namespace. Without it, the `+` Exec prefix
removes the sandbox/credentials but does NOT detach from the unit's
per-service mount namespace, and the helper process itself would pin
that namespace alive — turning every umount into a multi-second EBUSY
race with the kernel's deferred namespace cleanup. With the unit-level
nsenter the helper has no such reference and umount succeeds first try.

Validates inputs strictly, then performs `mount -t overlay` /
`umount` directly — no internal nsenter, since the helper is already
running where the syscalls need to take effect.

Verbs:
    mount <name>    Reads ${LEFT4ME_ROOT}/instances/<name>/instance.env
                    for L4D2_LOWERDIRS, validates every lowerdir is
                    under one of installation/overlays/workshop_cache/
                    global_overlay_cache, then mounts the kernel
                    overlay at runtime/<name>/merged.
    umount <name>   Unmounts runtime/<name>/merged and cleans up the
                    kernel-overlayfs `work/work` orphan.

Set LEFT4ME_OVERLAY_PRINT_ONLY=1 to print the would-be mount/umount argv
(one line, shell-quoted) and exit 0 instead of running it. Used by tests.
"""

import os
import re
import shlex
import shutil
import subprocess
import sys
from pathlib import Path
from typing import NoReturn

# Legal instance names: a lowercase letter or digit, then up to 63 more
# characters from [a-z0-9_-] (64 characters total at most).
NAME_RE = re.compile(r"^[a-z0-9][a-z0-9_-]{0,63}$")

# State root used when $LEFT4ME_ROOT is unset or empty.
DEFAULT_ROOT = "/var/lib/left4me"

# Subdirectories of the state root under which every lowerdir must resolve.
LOWERDIR_ALLOWLIST = ("installation", "overlays", "global_overlay_cache", "workshop_cache")

# Upper bound on the number of colon-separated L4D2_LOWERDIRS entries.
MAX_LOWERDIRS = 500

# Absolute binary paths: execv/subprocess are given these directly, so no
# PATH lookup is involved.
MOUNT_BIN = "/bin/mount"
UMOUNT_BIN = "/bin/umount"


def die(msg: str) -> NoReturn:
    """Report a fatal error on stderr and terminate with exit status 1.

    The message is written as ``left4me-overlay: <msg>``. Annotated
    ``NoReturn`` (it always raises SystemExit) so type checkers know that
    control never continues past a ``die(...)`` call.
    """
    sys.stderr.write(f"left4me-overlay: {msg}\n")
    sys.exit(1)


def root() -> Path:
    """Return the left4me state root directory.

    $LEFT4ME_ROOT wins when it is set to a non-empty value; otherwise
    DEFAULT_ROOT is used.
    """
    configured = os.environ.get("LEFT4ME_ROOT")
    return Path(configured) if configured else Path(DEFAULT_ROOT)


def validate_name(name: str) -> str:
    """Return *name* unchanged when it matches NAME_RE; otherwise exit via die().

    Instance names are later joined into filesystem paths, so anything
    outside the strict pattern (slashes, dots, uppercase, overlong) is refused.
    """
    if NAME_RE.fullmatch(name) is None:
        die(f"invalid instance name: {name!r}")
    return name


def parse_lowerdirs(env_path: Path) -> list[str]:
    """Return the colon-separated entries of L4D2_LOWERDIRS from *env_path*.

    Only the first line whose key (text before the first ``=``, with
    surrounding whitespace stripped) is ``L4D2_LOWERDIRS`` is used; its value
    is taken verbatim. Exits via die() when the file is missing, the key is
    absent or empty, any entry is empty, or there are more than
    MAX_LOWERDIRS entries.
    """
    if not env_path.is_file():
        die(f"instance.env not found: {env_path}")

    lines = env_path.read_text().splitlines()
    raw = next(
        (
            value
            for key, sep, value in (line.partition("=") for line in lines)
            if sep and key.strip() == "L4D2_LOWERDIRS"
        ),
        None,
    )

    if raw is None:
        die(f"L4D2_LOWERDIRS not set in {env_path}")
    if not raw:
        die(f"L4D2_LOWERDIRS is empty in {env_path}")

    entries = raw.split(":")
    if "" in entries:
        die(f"L4D2_LOWERDIRS contains an empty entry: {raw!r}")
    count = len(entries)
    if count > MAX_LOWERDIRS:
        die(f"L4D2_LOWERDIRS has {count} entries (cap {MAX_LOWERDIRS})")
    return entries


def canonical_under(allowed_roots: list[Path], path: Path) -> Path:
    """Resolve *path* and return it if it equals or lies beneath an allowed root.

    *allowed_roots* are compared against the fully resolved path, so callers
    should pass already-resolved roots. Exits via die() when *path* cannot
    be resolved, or when it resolves outside every root.
    """
    try:
        canonical = path.resolve(strict=True)
    except (OSError, RuntimeError):
        # FileNotFoundError is the common case, but strict resolution can
        # also raise NotADirectoryError / PermissionError, and since Python
        # 3.13 a symlink loop raises OSError (ELOOP) instead of RuntimeError.
        # All of them must end in die(), not an uncaught traceback.
        die(f"path does not exist or has a symlink loop: {path}")
    for allowed in allowed_roots:
        if canonical == allowed or allowed in canonical.parents:
            return canonical
    die(f"path is outside the permitted roots: {path} (resolved: {canonical})")


_LISTXATTR = getattr(os, "listxattr", None)


def _entry_has_fuse_xattr(path: str) -> str | None:
    if _LISTXATTR is None:
        return None
    try:
        attrs = _LISTXATTR(path, follow_symlinks=False)
    except OSError:
        return None
    for a in attrs:
        if a.startswith("user.fuseoverlayfs."):
            return a
    return None


def assert_no_fuse_xattrs(upper: Path) -> None:
    """Exit via die() if anything under *upper* carries a fuse-overlayfs xattr.

    Walks *upper* without following symlinks and inspects the root plus
    every child directory and file. No-op when *upper* does not exist or the
    platform lacks listxattr.
    """
    if not upper.exists() or _LISTXATTR is None:
        return
    top = str(upper)
    for dirpath, dirnames, filenames in os.walk(top):
        # Each directory is inspected once, as a child of its parent; only
        # the walk root has no parent listing, so it is added explicitly.
        # (Previously every subdirectory was listxattr'd twice: once in its
        # parent's dirnames and again as dirpath.)
        entries = [os.path.join(dirpath, n) for n in (*dirnames, *filenames)]
        if dirpath == top:
            entries.insert(0, dirpath)
        for entry in entries:
            tainted = _entry_has_fuse_xattr(entry)
            if tainted:
                die(
                    f"upperdir contains fuse-overlayfs xattr {tainted!r} on {entry}; "
                    "wipe upper/ and work/ before mounting"
                )


def exec_or_print(argv: list[str]) -> None:
    """Replace this process with *argv*; in PRINT_ONLY mode print it and exit 0.

    The dry-run output is a single line of shell-quoted arguments.
    """
    dry_run = os.environ.get("LEFT4ME_OVERLAY_PRINT_ONLY") == "1"
    if not dry_run:
        os.execv(argv[0], argv)
    quoted = (shlex.quote(arg) for arg in argv)
    print(" ".join(quoted))
    sys.exit(0)


def cmd_mount(name: str) -> None:
    """Mount the kernel overlay for instance *name* at runtime/<name>/merged.

    Validates *name*, then returns early if merged is already a mount point
    (skipped in PRINT_ONLY mode). Otherwise it:
      * reads L4D2_LOWERDIRS from instances/<name>/instance.env and requires
        every entry to resolve under one of LOWERDIR_ALLOWLIST below the root;
      * requires runtime/<name> to resolve to a direct child of runtime/,
        and upper/work/merged to resolve to direct children of it;
      * rejects any path containing a character that is syntax in the
        overlay mount option string;
      * refuses an upperdir carrying fuse-overlayfs xattrs;
    then execs ``mount -t overlay`` (or prints its argv in PRINT_ONLY mode).

    Returns normally only on the already-mounted short-circuit; otherwise it
    execs mount, exits, or raises (e.g. FileNotFoundError from a missing
    runtime directory).
    """
    name = validate_name(name)
    r = root()
    runtime_root = (r / "runtime").resolve(strict=True)
    runtime_name_dir = (runtime_root / name).resolve(strict=True)
    # Every later containment check is relative to runtime_name_dir, so it
    # must itself stay inside runtime/; a symlinked runtime/<name> would
    # otherwise let the mount target (and upper/work) land anywhere.
    if runtime_name_dir.parent != runtime_root:
        die(f"runtime/{name} resolved outside runtime/: {runtime_name_dir}")
    merged_for_check = (runtime_name_dir / "merged").resolve(strict=True)

    # Idempotency for unit restart cycles: if a previous start mounted
    # successfully but ExecStart failed afterwards (and Restart=on-failure
    # fires another cycle), the second ExecStartPre would otherwise refuse
    # to mount-on-top. Short-circuit here so the second cycle just gets
    # straight to ExecStart. PRINT_ONLY (test mode) bypasses this so the
    # tests can exercise the full mount argv regardless of mount state.
    if (
        os.environ.get("LEFT4ME_OVERLAY_PRINT_ONLY") != "1"
        and os.path.ismount(merged_for_check)
    ):
        return

    instance_env = r / "instances" / name / "instance.env"
    raw_lowerdirs = parse_lowerdirs(instance_env)

    allowed_roots = [(r / sub).resolve() for sub in LOWERDIR_ALLOWLIST]
    canonical_lowerdirs = [str(canonical_under(allowed_roots, Path(p))) for p in raw_lowerdirs]

    upper = (runtime_name_dir / "upper").resolve(strict=True)
    work = (runtime_name_dir / "work").resolve(strict=True)
    merged = merged_for_check
    for label, path in (("upper", upper), ("work", work), ("merged", merged)):
        if path.parent != runtime_name_dir:
            die(f"{label} resolved outside runtime/{name}: {path}")

    # All of these paths are spliced unescaped into a single "-o" string:
    # ',' separates mount options, ':' separates lowerdir layers, and '\'
    # acts as an escape in overlayfs option parsing. A canonical path
    # containing any of them (e.g. reached through a symlink whose target is
    # ".../overlays/a:relative" or ".../overlays/x,metacopy=on") would smuggle
    # extra layers or options past the allowlist checks above.
    reserved = (",", ":", "\\")
    for option_path in (*canonical_lowerdirs, str(upper), str(work)):
        if any(ch in option_path for ch in reserved):
            die(
                "path contains a character reserved in overlay mount options "
                f"(',' ':' '\\'): {option_path}"
            )

    assert_no_fuse_xattrs(upper)

    options = f"lowerdir={':'.join(canonical_lowerdirs)},upperdir={upper},workdir={work}"
    argv = [
        MOUNT_BIN,
        "-t", "overlay",
        "overlay",
        "-o", options,
        str(merged),
    ]
    exec_or_print(argv)


def cmd_umount(name: str) -> None:
    """Unmount runtime/<name>/merged and delete the kernel-created work/work.

    Validates *name* and requires runtime/<name> to resolve to a direct
    child of runtime/. In PRINT_ONLY mode, prints the umount argv and exits 0
    without touching anything. Otherwise it unmounts merged only if it is
    currently a mount point, then removes runtime/<name>/work/work if
    present. Before doing either, it checks that merged and work resolve to
    direct children of runtime/<name>. Repeated calls are no-ops.
    """
    name = validate_name(name)
    r = root()
    runtime_root = (r / "runtime").resolve(strict=True)
    runtime_name_dir = (runtime_root / name).resolve(strict=True)
    # A runtime/<name> symlinked out of runtime/ would make us (as root)
    # umount some unrelated "<dir>/merged" and rmtree "<dir>/work/work".
    if runtime_name_dir.parent != runtime_root:
        die(f"runtime/{name} resolved outside runtime/: {runtime_name_dir}")
    merged_path = runtime_name_dir / "merged"
    work_path = runtime_name_dir / "work"

    argv = [
        UMOUNT_BIN,
        # Resolve only if it exists; PRINT_ONLY tests always pre-create it.
        str(merged_path.resolve(strict=True) if merged_path.exists() else merged_path),
    ]

    # PRINT_ONLY: emit the umount argv and exit. Tests assert exact shape
    # of this dry-run; the post-umount cleanup of work/work is a runtime
    # behaviour exercised on the host, not in unit tests.
    if os.environ.get("LEFT4ME_OVERLAY_PRINT_ONLY") == "1":
        print(" ".join(shlex.quote(a) for a in argv))
        sys.exit(0)

    if merged_path.exists():
        merged = merged_path.resolve(strict=True)
        if merged.parent != runtime_name_dir:
            die(f"merged resolved outside runtime/{name}: {merged}")
        # Idempotency: only umount if currently a mount point. Mirrors
        # cmd_mount's symmetric check; a redundant cleanup pass — or a
        # call after a partial _purge_instance — must be a no-op.
        #
        # No retry loop here: with the helper running in PID 1's mount
        # namespace (via the unit-level `nsenter --mount=/proc/1/ns/mnt`
        # in ExecStopPost), it holds no reference to the unit's
        # per-service mount namespace, so the cgroup-empty → namespace
        # reaped → umount-clears sequence happens without any race
        # window for us to ride out. EBUSY here is a real error.
        if os.path.ismount(merged):
            try:
                subprocess.run(argv, check=True)
            except subprocess.CalledProcessError as exc:
                die(f"umount of {merged} failed with exit status {exc.returncode}")

    # Kernel-overlayfs creates work/work during mount with root:root mode
    # 0/0. After unmount it's an orphan that the unit's User= (left4me)
    # cannot traverse via shutil.rmtree, so reset/delete in instances.py
    # blows up with EACCES on `runtime/<name>/work/work`. The helper is
    # the only code path with root that knows about this directory, so
    # the cleanup belongs here. Safe to nuke — the kernel re-creates it
    # on the next mount. Run unconditionally — covers both "we just
    # unmounted" and "previous teardown didn't finish" cases.
    #
    # This rmtree runs as root, while work/ is presumably writable by the
    # unprivileged service user (instances.py removes it). Resolve work/ and
    # insist it is still a direct child of runtime/<name>, so a symlinked
    # work/ cannot redirect the deletion to "<anywhere>/work".
    # NOTE(review): a rename race between this check and rmtree remains;
    # closing it fully needs an fd-relative rmtree (shutil.rmtree(dir_fd=),
    # Python 3.11+).
    if work_path.exists():
        work = work_path.resolve(strict=True)
        if work.parent != runtime_name_dir:
            die(f"work resolved outside runtime/{name}: {work}")
        work_inner = work / "work"
        if work_inner.exists():
            shutil.rmtree(work_inner)


def main(argv: list[str]) -> None:
    """CLI entry point: ``left4me-overlay mount|umount <name>``.

    *argv* is the full argument vector, including the program name. Any
    other shape prints a usage line to stderr and exits with status 2.
    """
    verb = argv[1] if len(argv) == 3 else None
    if verb not in ("mount", "umount"):
        sys.stderr.write("usage: left4me-overlay mount|umount <name>\n")
        sys.exit(2)
    handler = cmd_mount if verb == "mount" else cmd_umount
    handler(argv[2])


if __name__ == "__main__":
    # Script entry point. main() returns None on success, so sys.exit()
    # reports status 0 unless a command already exited with its own status.
    sys.exit(main(sys.argv))
