#!/usr/bin/env python3

from bundlewrap.repo import Repository
from os.path import realpath, dirname
from ipaddress import ip_interface

# The repository root is the parent of the directory containing this script.
repo_root = dirname(dirname(realpath(__file__)))
repo = Repository(repo_root)

# Every member of the 'debian' group is a target, except nodes flagged as dummy.
nodes = [
    candidate
    for candidate in repo.nodes_in_group('debian')
    if not candidate.dummy
]

# Announce the targets up front, sorted by name for readability.
target_names = sorted(target.name for target in nodes)
print('updating nodes:', target_names)

# UPDATE

# Upgrade the nodes one after another, printing each command's stdout.
separator = '--------------------------------------'
upgrade_commands = (
    'DEBIAN_FRONTEND=noninteractive apt update',
    'DEBIAN_FRONTEND=noninteractive apt -y dist-upgrade',
)

for node in nodes:
    print(separator)
    print('updating', node.name)
    print(separator)
    # Presumably sends a wake-on-LAN request so sleeping nodes are reachable
    # before commands are run on them -- confirm against libs/wol.
    repo.libs.wol.wake(node)
    # NOTE(review): apt documents its command line as not stable for use in
    # scripts; consider apt-get here.
    for command in upgrade_commands:
        result = node.run(command)
        print(result.stdout.decode())

# REBOOT IN ORDER

def _wireguard_prefix_lengths(wg_node):
    """Return (prefixlen, max_prefixlen) of the node's 'wireguard/my_ip' network."""
    network = ip_interface(wg_node.metadata.get('wireguard/my_ip')).network
    return network.prefixlen, network.max_prefixlen


# Split the nodes into three groups, preserving their relative order:
#   everything_else   -- nodes without the wireguard bundle
#   wireguard_s2s     -- wireguard nodes whose address is a single-host
#                        network (prefix length equals the address length)
#   wireguard_servers -- wireguard nodes whose address sits in a larger
#                        network (prefix length shorter than the address length)
everything_else = []
wireguard_s2s = []
wireguard_servers = []

for candidate in nodes:
    if not candidate.has_bundle('wireguard'):
        everything_else.append(candidate)
        continue

    prefix_length, full_length = _wireguard_prefix_lengths(candidate)
    if prefix_length == full_length:
        wireguard_s2s.append(candidate)
    elif prefix_length < full_length:
        wireguard_servers.append(candidate)

print('======================================')
# Group sizes, in reboot order.
print(len(everything_else), len(wireguard_s2s), len(wireguard_servers))

# Reboot order: nodes without wireguard first, then wireguard nodes with a
# single-host address, and wireguard nodes with a wider network last.
# NOTE(review): presumably the last group is what other nodes connect through,
# so it goes last to keep them reachable until their own reboot is issued --
# confirm.
reboot_order = [
    *everything_else,
    *wireguard_s2s,
    *wireguard_servers,
]

for node in reboot_order:
    print('rebooting', node.name)
    # The reboot can tear down the SSH session before the command reports a
    # clean exit. With the default may_fail=False that non-zero status raises
    # and aborts the loop, leaving the remaining nodes upgraded but not
    # rebooted. Tolerate the failure and report it instead.
    result = node.run('systemctl reboot', may_fail=True)
    print(result.stdout.decode())
    if result.return_code != 0:
        print(
            'reboot command on', node.name,
            'exited with code', result.return_code,
        )