#!/usr/bin/python3
"""Privileged overlay mount helper for left4me.

Invoked from the systemd unit's ExecStartPre / ExecStopPost via
`+/usr/bin/nsenter --mount=/proc/1/ns/mnt -- …`. The unit-level
nsenter is what makes this work: it runs the helper Python interpreter
inside PID 1's mount namespace. Without it, the `+` Exec prefix
removes the sandbox/credentials but does NOT detach from the unit's
per-service mount namespace, and the helper process itself would pin
that namespace alive — turning every umount into a multi-second EBUSY
race with the kernel's deferred namespace cleanup. With the unit-level
nsenter the helper has no such reference and umount succeeds first try.

Validates inputs strictly, then performs `mount -t overlay` /
`umount` directly — no internal nsenter, since the helper is already
running where the syscalls need to take effect.

Verbs:
    mount <name>    Reads ${LEFT4ME_ROOT}/instances/<name>/instance.env
                    for L4D2_LOWERDIRS, validates every lowerdir is
                    under one of installation/overlays/workshop_cache/
                    global_overlay_cache, bind-mounts each lowerdir
                    owned by l4d2-sandbox with an id mapping to left4me
                    under runtime/<name>/idmap/, then mounts the kernel
                    overlay at runtime/<name>/merged.
    umount <name>   Unmounts runtime/<name>/merged, cleans up the
                    kernel-overlayfs `work/work` orphan, then unwinds
                    the idmap bind mounts and removes runtime/<name>/idmap.

Set LEFT4ME_OVERLAY_PRINT_ONLY=1 to print the would-be argv(s) (one
shell-quoted line per command) and exit 0 instead of running them.
Used by tests.
"""

import os
import pwd
import re
import shlex
import shutil
import subprocess
import sys
from pathlib import Path

# Instance names: 1-64 characters, lowercase alphanumerics plus `_` and `-`,
# starting with an alphanumeric. Checked with fullmatch() in validate_name().
NAME_RE = re.compile(r"^[a-z0-9][a-z0-9_-]{0,63}$")
# Data root used when LEFT4ME_ROOT is unset or empty (see root()).
DEFAULT_ROOT = "/var/lib/left4me"
# Sub-directories of the data root under which a lowerdir must resolve.
LOWERDIR_ALLOWLIST = (
    "installation",
    "overlays",
    "global_overlay_cache",
    "workshop_cache",
)
# Upper bound on the number of L4D2_LOWERDIRS entries parse_lowerdirs() accepts.
MAX_LOWERDIRS = 500
# Absolute paths of the binaries the helper executes.
MOUNT_BIN = "/bin/mount"
UMOUNT_BIN = "/bin/umount"


def die(msg: str) -> None:
    """Report *msg* on stderr with the helper's prefix and exit with status 1.

    Never returns: the process is terminated by raising ``SystemExit``.
    """
    print(f"left4me-overlay: {msg}", file=sys.stderr)
    raise SystemExit(1)


def _lookup_uid(username: str) -> tuple[int, int]:
    """Resolve *username* through the passwd database.

    Returns the account's ``(uid, gid)`` pair.  An unknown account is
    treated as a deployment error and ends the process through ``die``.
    """
    try:
        record = pwd.getpwnam(username)
    except KeyError:
        die(
            f"required system user {username!r} does not exist; "
            "this is a deploy misconfiguration"
        )
    uid, gid = record.pw_uid, record.pw_gid
    return uid, gid


def _get_user_ids() -> tuple[int, int, int, int]:
    """Return (sandbox_uid, sandbox_gid, left4me_uid, left4me_gid).

    Outside PRINT_ONLY mode the ids always come from the passwd entries of
    the ``l4d2-sandbox`` and ``left4me`` accounts.

    In PRINT_ONLY (test) mode the four variables LEFT4ME_TEST_SANDBOX_UID,
    LEFT4ME_TEST_SANDBOX_GID, LEFT4ME_TEST_LEFT4ME_UID and
    LEFT4ME_TEST_LEFT4ME_GID replace the lookup, but only when all four are
    set; otherwise the real lookup still runs.  Ignoring the stubs outside
    PRINT_ONLY keeps a stray unit override from influencing the real
    uid mapping.
    """
    if os.environ.get("LEFT4ME_OVERLAY_PRINT_ONLY") == "1":
        stub_names = (
            "LEFT4ME_TEST_SANDBOX_UID",
            "LEFT4ME_TEST_SANDBOX_GID",
            "LEFT4ME_TEST_LEFT4ME_UID",
            "LEFT4ME_TEST_LEFT4ME_GID",
        )
        stubs = [os.environ[key] for key in stub_names if key in os.environ]
        # Partial stubbing is not honoured: it is all four or the real lookup.
        if len(stubs) == len(stub_names):
            s_uid, s_gid, l_uid, l_gid = (int(value) for value in stubs)
            return s_uid, s_gid, l_uid, l_gid

    return (*_lookup_uid("l4d2-sandbox"), *_lookup_uid("left4me"))


def root() -> Path:
    """Return the data root: ``$LEFT4ME_ROOT`` if set and non-empty, else the default."""
    configured = os.environ.get("LEFT4ME_ROOT")
    if configured:
        return Path(configured)
    return Path(DEFAULT_ROOT)


def validate_name(name: str) -> str:
    """Return *name* unchanged if it is a well-formed instance name.

    A name that does not fully match ``NAME_RE`` ends the process
    through ``die``.
    """
    if NAME_RE.fullmatch(name) is None:
        die(f"invalid instance name: {name!r}")
    return name


def parse_lowerdirs(env_path: Path) -> list[str]:
    """Extract the ``L4D2_LOWERDIRS`` entries from an instance env file.

    The first line whose key (the text before the first ``=``, stripped of
    surrounding whitespace) is ``L4D2_LOWERDIRS`` supplies the value, which
    is taken verbatim and split on ``:``.

    Exits through ``die`` when the file is missing, the key is absent, the
    value is empty, any entry is empty, or there are more than
    ``MAX_LOWERDIRS`` entries.
    """
    if not env_path.is_file():
        die(f"instance.env not found: {env_path}")

    assignments = (
        line.split("=", 1)
        for line in env_path.read_text().splitlines()
        if "=" in line
    )
    # First matching assignment wins; later duplicates are never looked at.
    raw = next(
        (value for key, value in assignments if key.strip() == "L4D2_LOWERDIRS"),
        None,
    )
    if raw is None:
        die(f"L4D2_LOWERDIRS not set in {env_path}")
    if raw == "":
        die(f"L4D2_LOWERDIRS is empty in {env_path}")

    entries = raw.split(":")
    if "" in entries:
        die(f"L4D2_LOWERDIRS contains an empty entry: {raw!r}")
    count = len(entries)
    if count > MAX_LOWERDIRS:
        die(f"L4D2_LOWERDIRS has {count} entries (cap {MAX_LOWERDIRS})")
    return entries


def canonical_under(allowed_roots: list[Path], path: Path) -> Path:
    """Resolve *path* and require it to sit at or below one of *allowed_roots*.

    Args:
        allowed_roots: Already-resolved directories that are acceptable
            ancestors (or exact matches) of the result.
        path: Candidate path; symlinks are resolved and the target must exist.

    Returns:
        The fully resolved path.

    Every failure ends the process through ``die``: a missing path, a
    symlink loop, any other resolution error (a non-directory component,
    missing permission, ...), or a result outside all of *allowed_roots*.
    """
    try:
        canonical = path.resolve(strict=True)
    except (FileNotFoundError, RuntimeError):
        # RuntimeError is how Python < 3.13 reports a symlink loop.
        die(f"path does not exist or has a symlink loop: {path}")
    except OSError as exc:
        # Python >= 3.13 reports symlink loops as OSError; ENOTDIR, EACCES
        # and friends also land here instead of escaping as a traceback.
        die(f"cannot resolve path {path}: {exc.strerror or exc}")
    for r in allowed_roots:
        if canonical == r or r in canonical.parents:
            return canonical
    die(f"path is outside the permitted roots: {path} (resolved: {canonical})")


# os.listxattr is not available on every platform; when it is missing this is
# None and the fuse-overlayfs xattr scan below is skipped entirely.
_LISTXATTR = getattr(os, "listxattr", None)


def _entry_has_fuse_xattr(path: str) -> str | None:
    """Return the first ``user.fuseoverlayfs.*`` xattr name set on *path*.

    Symlinks are inspected themselves rather than followed.  Returns
    ``None`` when no such attribute exists, when xattrs cannot be listed
    for *path*, or when the platform has no ``os.listxattr``.
    """
    if _LISTXATTR is None:
        return None
    try:
        names = _LISTXATTR(path, follow_symlinks=False)
    except OSError:
        # Best effort: an entry whose xattrs cannot be read counts as clean.
        return None
    return next(
        (attr for attr in names if attr.startswith("user.fuseoverlayfs.")),
        None,
    )


def assert_no_fuse_xattrs(upper: Path) -> None:
    """Refuse to continue if *upper* carries fuse-overlayfs metadata.

    Walks the whole tree below *upper* and ends the process through ``die``
    on the first entry that has a ``user.fuseoverlayfs.*`` extended
    attribute.  Does nothing when *upper* does not exist or xattrs cannot
    be listed on this platform.
    """
    if not upper.exists():
        return
    if _LISTXATTR is None:
        return

    for current, subdirs, files in os.walk(upper):
        children = [os.path.join(current, child) for child in (*subdirs, *files)]
        for candidate in [current, *children]:
            hit = _entry_has_fuse_xattr(candidate)
            if hit:
                die(
                    f"upperdir contains fuse-overlayfs xattr {hit!r} on {candidate}; "
                    "wipe upper/ and work/ before mounting"
                )


def _print_argv(argv: list[str]) -> None:
    """Write *argv* to stdout as a single shell-quoted line.

    PRINT_ONLY helper; it only prints and never exits.
    """
    quoted = map(shlex.quote, argv)
    print(" ".join(quoted))


def exec_or_print(argv: list[str]) -> None:
    """Replace this process with *argv*, or just print it in PRINT_ONLY mode.

    With LEFT4ME_OVERLAY_PRINT_ONLY=1 the argv is printed shell-quoted and
    the process exits 0.  Otherwise ``argv[0]`` is executed via
    ``os.execv`` with *argv* as its argument vector.
    """
    dry_run = os.environ.get("LEFT4ME_OVERLAY_PRINT_ONLY") == "1"
    if not dry_run:
        os.execv(argv[0], argv)
        return
    _print_argv(argv)
    sys.exit(0)


def cmd_mount(name: str) -> None:
    """Mount the kernel overlay for instance *name* at ``runtime/<name>/merged``.

    Steps, in order:

    1. Validate *name*; resolve ``runtime/<name>`` and its ``merged`` dir.
    2. Return immediately when ``merged`` is already a mount point
       (skipped in PRINT_ONLY mode).
    3. Read ``L4D2_LOWERDIRS`` from the instance's ``instance.env`` and
       require every entry to resolve under an allow-listed root.
    4. Require ``upper``, ``work`` and ``merged`` to resolve to direct
       children of ``runtime/<name>``.
    5. Reject any path containing a character that the overlay option
       string would misparse.
    6. Reject fuse-overlayfs xattrs in ``upper``.
    7. For every lowerdir owned by the sandbox uid, bind-mount it with a
       uid/gid mapping under ``runtime/<name>/idmap/`` and use that mount
       as the lowerdir instead.
    8. Exec ``mount -t overlay``.  In PRINT_ONLY mode every argv is
       printed instead and the process exits 0.

    Validation failures end the process through ``die``.  A failing bind
    mount raises ``subprocess.CalledProcessError``.
    """
    name = validate_name(name)
    print_only = os.environ.get("LEFT4ME_OVERLAY_PRINT_ONLY") == "1"
    r = root()
    try:
        runtime_name_dir = (r / "runtime" / name).resolve(strict=True)
        merged = (runtime_name_dir / "merged").resolve(strict=True)
    except (OSError, RuntimeError) as exc:
        die(f"cannot resolve runtime directories for {name}: {exc}")

    # Idempotency for unit restart cycles: if a previous start mounted
    # successfully but ExecStart failed afterwards (and Restart=on-failure
    # fires another cycle), the second ExecStartPre would otherwise refuse
    # to mount-on-top. Short-circuit here so the second cycle just gets
    # straight to ExecStart. PRINT_ONLY (test mode) bypasses this so the
    # tests can exercise the full mount argv regardless of mount state.
    if not print_only and os.path.ismount(merged):
        return

    instance_env = r / "instances" / name / "instance.env"
    raw_lowerdirs = parse_lowerdirs(instance_env)

    allowed_roots = [(r / sub).resolve() for sub in LOWERDIR_ALLOWLIST]
    canonical_lowerdirs = [
        str(canonical_under(allowed_roots, Path(p))) for p in raw_lowerdirs
    ]

    try:
        upper = (runtime_name_dir / "upper").resolve(strict=True)
        work = (runtime_name_dir / "work").resolve(strict=True)
    except (OSError, RuntimeError) as exc:
        die(f"cannot resolve runtime directories for {name}: {exc}")
    for label, path in (("upper", upper), ("work", work), ("merged", merged)):
        if path.parent != runtime_name_dir:
            die(f"{label} resolved outside runtime/{name}: {path}")

    # The paths below are spliced into one `-o lowerdir=a:b,upperdir=..,workdir=..`
    # string.  `,` separates mount options, `:` separates lowerdirs, `\` is
    # overlayfs' escape character and `"` is the option-quoting character, so
    # a resolved path containing any of them (possible via a symlink even
    # though the raw value was split on `:`) could smuggle in an unvalidated
    # lowerdir or an extra mount option.  Refuse before touching anything.
    # The idmap targets derived further down only reuse components of these
    # paths, so they need no separate check.
    unsafe_chars = frozenset(',:\\"')
    for option_path in (*canonical_lowerdirs, str(upper), str(work)):
        offending = "".join(sorted(unsafe_chars.intersection(option_path)))
        if offending:
            die(
                "path contains characters that are unsafe in overlay mount "
                f"options ({offending!r}): {option_path}"
            )

    assert_no_fuse_xattrs(upper)

    # Resolve user ids now (fails fast on deploy misconfiguration).
    sandbox_uid, sandbox_gid, left4me_uid, left4me_gid = _get_user_ids()

    # Build the final lowerdir list, substituting idmap bind-mount paths for
    # any lowerdir owned by l4d2-sandbox.  An idmap bind mount makes the kernel
    # see the l4d2-sandbox-owned tree as if it were owned by left4me, so that
    # overlayfs copy-up produces left4me-owned upperdir entries.
    idmap_dir = runtime_name_dir / "idmap"
    final_lowerdirs: list[str] = []
    bind_argvs: list[list[str]] = []
    seen_idmap_targets: dict[Path, str] = {}

    for lowerdir in canonical_lowerdirs:
        try:
            st = os.stat(lowerdir)
        except OSError as exc:
            die(f"failed to stat lowerdir {shlex.quote(lowerdir)}: {exc}")
        if st.st_uid != sandbox_uid:
            final_lowerdirs.append(lowerdir)
            continue

        # This lowerdir needs idmap remapping.
        # Include the parent dirname to avoid basename collisions between
        # lowerdirs from different allowlist roots (e.g. overlays/foo and
        # workshop_cache/foo would otherwise map to the same idmap target).
        p = Path(lowerdir)
        idmap_target = idmap_dir / f"{p.parent.name}_{p.name}"

        # Belt-and-braces: detect if two different lowerdirs would collide
        # on the same idmap target after the parent+name derivation.
        if idmap_target in seen_idmap_targets:
            die(
                f"idmap target collision: lowerdirs {shlex.quote(seen_idmap_targets[idmap_target])}"
                f" and {shlex.quote(lowerdir)} both map to {idmap_target}"
            )
        seen_idmap_targets[idmap_target] = lowerdir

        if not print_only:
            idmap_dir.mkdir(mode=0o700, exist_ok=True)
            idmap_target.mkdir(mode=0o700, exist_ok=True)

        # NOTE: os.path.ismount() cannot detect a bind mount whose source
        # lives on the same filesystem as the target, so a bind left behind
        # by an interrupted teardown may be stacked over rather than reused.
        if print_only or not os.path.ismount(idmap_target):
            # --map-users / --map-groups argument format:
            #   <on-disk-uid>:<in-mount-uid>:<count>
            # The util-linux man page calls these <inner>:<outer>, which is
            # misleading.  Empirically (verified on left4.me, kernel 6.12,
            # ext4) the FIRST number is the on-disk uid and the SECOND is
            # the uid exposed inside the mount.  Don't swap them.
            bind_argvs.append([
                MOUNT_BIN,
                "--bind",
                f"--map-users={sandbox_uid}:{left4me_uid}:1",
                f"--map-groups={sandbox_gid}:{left4me_gid}:1",
                lowerdir,
                str(idmap_target),
            ])

        final_lowerdirs.append(str(idmap_target))

    if print_only:
        # Emit each bind-mount argv first, then fall through to the overlay argv.
        for bind_argv in bind_argvs:
            _print_argv(bind_argv)
    else:
        # Run each bind mount before the overlay mount.
        for bind_argv in bind_argvs:
            subprocess.run(bind_argv, check=True)

    options = f"lowerdir={':'.join(final_lowerdirs)},upperdir={upper},workdir={work}"
    argv = [
        MOUNT_BIN,
        "-t", "overlay",
        "overlay",
        "-o", options,
        str(merged),
    ]
    exec_or_print(argv)


def _mount_points() -> list[str] | None:
    """Return the mount points of the current mount namespace.

    Reads the fifth field of every ``/proc/self/mountinfo`` record and
    undoes the kernel's three-digit octal escaping of special characters.
    Stacked mounts on one path appear once per mount.  Returns ``None``
    when the file cannot be read.
    """
    points: list[str] = []
    try:
        with open(
            "/proc/self/mountinfo",
            encoding="utf-8",
            errors="surrogateescape",
            newline="\n",
        ) as fh:
            for record in fh:
                fields = record.split(" ")
                if len(fields) > 4:
                    points.append(
                        re.sub(
                            r"\\([0-7]{3})",
                            lambda m: chr(int(m.group(1), 8)),
                            fields[4],
                        )
                    )
    except OSError:
        return None
    return points


def _is_mount_point(path: Path) -> bool:
    """Return True if *path* is a mount point, including same-filesystem binds.

    ``os.path.ismount`` only compares device and inode numbers with the
    parent directory, so it misses a bind mount whose source is on the same
    filesystem as the target.  The mount table is consulted as well.
    """
    if os.path.ismount(path):
        return True
    points = _mount_points()
    return points is not None and os.path.realpath(path) in points


def cmd_umount(name: str) -> None:
    """Tear down the overlay of instance *name* and everything mount left behind.

    In order: unmount ``runtime/<name>/merged`` if it is mounted, delete the
    kernel-created ``work/work`` directory, unmount every idmap bind below
    ``runtime/<name>/idmap`` and remove that directory.  Each step is
    skipped when there is nothing to do, so repeated calls are harmless.

    In PRINT_ONLY mode the umount argvs are printed (overlay first, then
    one per idmap sub-directory) and the process exits 0 without changing
    anything.

    Validation failures end the process through ``die``.  A failing
    ``umount`` raises ``subprocess.CalledProcessError``.
    """
    name = validate_name(name)
    r = root()
    try:
        runtime_name_dir = (r / "runtime" / name).resolve(strict=True)
    except (OSError, RuntimeError) as exc:
        die(f"cannot resolve runtime directory for {name}: {exc}")
    merged_path = runtime_name_dir / "merged"
    work_inner = runtime_name_dir / "work" / "work"

    overlay_umount_argv = [
        UMOUNT_BIN,
        # Resolve only if it exists; PRINT_ONLY tests always pre-create it.
        str(merged_path.resolve(strict=True) if merged_path.exists() else merged_path),
    ]

    # Collect idmap bind-umount argvs: one per direct subdir of runtime/<name>/idmap/.
    # Symlinks are never followed: this runs as root, and a planted link
    # would otherwise let the helper unmount (or later delete below) an
    # arbitrary directory.
    idmap_dir = runtime_name_dir / "idmap"
    if idmap_dir.is_symlink():
        die(f"idmap is a symlink, refusing to follow it: {idmap_dir}")
    bind_umount_argvs: list[list[str]] = []
    if idmap_dir.is_dir():
        for entry in sorted(idmap_dir.iterdir()):
            if entry.is_dir() and not entry.is_symlink():
                bind_umount_argvs.append([UMOUNT_BIN, str(entry)])

    # PRINT_ONLY: emit the overlay umount argv, then each bind-umount argv, then exit.
    # Order matches real execution (overlay first, then idmap binds underneath).
    if os.environ.get("LEFT4ME_OVERLAY_PRINT_ONLY") == "1":
        _print_argv(overlay_umount_argv)
        for bind_umount_argv in bind_umount_argvs:
            _print_argv(bind_umount_argv)
        sys.exit(0)

    if merged_path.exists():
        merged = merged_path.resolve(strict=True)
        if merged.parent != runtime_name_dir:
            die(f"merged resolved outside runtime/{name}: {merged}")
        # Idempotency: only umount if currently a mount point. Mirrors
        # cmd_mount's symmetric check; a redundant cleanup pass — or a
        # call after a partial _purge_instance — must be a no-op.
        #
        # No retry loop here: with the helper running in PID 1's mount
        # namespace (via the unit-level `nsenter --mount=/proc/1/ns/mnt`
        # in ExecStopPost), it holds no reference to the unit's
        # per-service mount namespace, so the cgroup-empty → namespace
        # reaped → umount-clears sequence happens without any race
        # window for us to ride out. EBUSY here is a real error.
        if os.path.ismount(merged):
            subprocess.run(overlay_umount_argv, check=True)

    # Kernel-overlayfs creates work_inner during mount with root:root mode
    # 0/0. After unmount it's an orphan that the unit's User= (left4me)
    # cannot traverse via shutil.rmtree, so reset/delete in instances.py
    # blows up with EACCES on `runtime/<name>/work/work`. The helper is
    # the only code path with root that knows about this directory, so
    # the cleanup belongs here. Safe to nuke — the kernel re-creates it
    # on the next mount. Run unconditionally — covers both "we just
    # unmounted" and "previous teardown didn't finish" cases.
    if work_inner.exists():
        # Same containment rule cmd_mount applies to `work`: never let a
        # symlink redirect a root-privileged recursive delete.
        resolved_work_inner = work_inner.resolve(strict=True)
        if resolved_work_inner != work_inner:
            die(f"work/work resolved outside runtime/{name}: {resolved_work_inner}")
        shutil.rmtree(work_inner)

    # Unwind idmap bind mounts.  Each target is unmounted for as long as it
    # is still a mount point, which keeps this idempotent across partial
    # teardowns and also peels off binds that were stacked on one target.
    # A failing umount raises, so the loop cannot spin.
    for bind_umount_argv in bind_umount_argvs:
        target = Path(bind_umount_argv[-1])
        while _is_mount_point(target):
            subprocess.run(bind_umount_argv, check=True)

    # Remove the idmap directory.  A recursive delete would descend through
    # any bind that is still attached and destroy the real lowerdir it
    # exposes, so it is only used once the mount table proves nothing is
    # mounted below idmap/.  Otherwise fall back to removing empty
    # directories only, best effort.
    if idmap_dir.is_dir():
        below_idmap = str(idmap_dir) + os.sep
        points = _mount_points()
        if points is not None and not any(p.startswith(below_idmap) for p in points):
            shutil.rmtree(idmap_dir, ignore_errors=True)
        else:
            leftovers = [Path(argv[-1]) for argv in bind_umount_argvs]
            for leftover in (*leftovers, idmap_dir):
                try:
                    leftover.rmdir()
                except OSError:
                    pass


def main(argv: list[str]) -> None:
    """Dispatch ``<prog> mount|umount <name>`` to the matching command.

    Anything other than exactly one known verb followed by one name prints
    a usage line to stderr and exits with status 2.
    """
    verb = argv[1] if len(argv) == 3 else None
    if verb not in ("mount", "umount"):
        sys.stderr.write("usage: left4me-overlay mount|umount <name>\n")
        sys.exit(2)

    instance = argv[2]
    if verb == "umount":
        cmd_umount(instance)
    else:
        cmd_mount(instance)


# Script entry point: hand the raw command line to main() when run directly.
if __name__ == "__main__":
    main(sys.argv)
