#!/usr/bin/env python3

from bundlewrap.repo import Repository
from os.path import realpath, dirname
from ipaddress import ip_interface

# The repository root is the parent of the directory that contains this script.
repo_root = dirname(dirname(realpath(__file__)))
repo = Repository(repo_root)

# Collect every member of the 'debian' group, in sorted order, skipping
# nodes flagged as dummy.
nodes = []
for candidate in sorted(repo.nodes_in_group('debian')):
    if candidate.dummy:
        continue
    nodes.append(candidate)

# Tell the operator up front which nodes are going to be touched.
node_names = sorted(selected.name for selected in nodes)
print('updating nodes:', node_names)

# UPDATE

# For each node: wake it, refresh the package index, show the pending
# upgrades and run a dist-upgrade if there are any.
for node in nodes:
    print('--------------------------------------')
    print('updating', node.name)
    print('--------------------------------------')
    # Wake the node through the repo's wol lib before running commands on it.
    repo.libs.wol.wake(node)
    print(node.run('DEBIAN_FRONTEND=noninteractive apt update').stdout.decode())

    # Query the upgradable packages only once: the same captured output is
    # shown to the operator and used to decide whether an upgrade is needed,
    # which saves one remote command per node.
    upgradable_listing = node.run(
        'DEBIAN_FRONTEND=noninteractive apt list --upgradable'
    ).stdout.decode()
    print(upgradable_listing)

    # Any stdout line mentioning "upgradable" means at least one package is
    # pending, so a plain substring test on the whole output is sufficient.
    if 'upgradable' in upgradable_listing:
        print(node.run('DEBIAN_FRONTEND=noninteractive apt -y dist-upgrade').stdout.decode())

# REBOOT IN ORDER

# Partition the nodes into three groups, preserving the order of `nodes`:
#   - wireguard_servers: have the 'wireguard' bundle and a 'wireguard/my_ip'
#     whose prefix is shorter than the address family's maximum (i.e. the
#     address sits inside a real subnet)
#   - wireguard_s2s: have the 'wireguard' bundle and a 'wireguard/my_ip' with
#     a full-length prefix (a single host address)
#     NOTE(review): "s2s" presumably means site-to-site -- confirm
#   - everything_else: nodes without the 'wireguard' bundle
wireguard_servers = []
wireguard_s2s = []
everything_else = []

for member in nodes:
    if not member.has_bundle('wireguard'):
        everything_else.append(member)
        continue

    # Parse the address once and compare its prefix length to the maximum
    # possible one for that address family.
    wg_network = ip_interface(member.metadata.get('wireguard/my_ip')).network
    if wg_network.prefixlen < wg_network.max_prefixlen:
        wireguard_servers.append(member)
    else:
        # A prefix length can never exceed max_prefixlen, so this branch is
        # exactly the "prefixlen == max_prefixlen" case.
        wireguard_s2s.append(member)

print('======================================')

# Reboot order: nodes without wireguard first, then the wireguard s2s nodes,
# and the wireguard servers last.
# NOTE(review): presumably the servers go last so that other nodes stay
# reachable while they are being rebooted -- confirm.
reboot_order = [
    *everything_else,
    *wireguard_s2s,
    *wireguard_servers,
]

for node in reboot_order:
    # Errors are printed and the loop moves on, so one failing node does not
    # stop the remaining reboots.
    # NOTE(review): the reboot command itself may drop the connection and
    # surface as an exception here -- confirm.
    try:
        # /var/run/reboot-required exists only when a reboot is pending;
        # may_fail=True lets the non-zero exit of `test` through as a result.
        reboot_check = node.run('test -e /var/run/reboot-required', may_fail=True)
        if reboot_check.return_code != 0:
            print('not rebooting', node.name)
            continue

        print('rebooting', node.name)
        reboot_result = node.run('systemctl reboot')
        print(reboot_result.stdout.decode())
    except Exception as error:
        print(error)