#!/usr/bin/env python3

from argparse import ArgumentParser
from time import sleep

from bundlewrap.exceptions import RemoteException
from bundlewrap.utils.cmdline import get_target_nodes
from bundlewrap.utils.ui import io
from bundlewrap.repo import Repository
from os.path import realpath, dirname


# parse args
parser = ArgumentParser()
parser.add_argument("targets", nargs="*", default=['bundle:routeros'], help="bw nodes selector")
parser.add_argument("--yes", action="store_true", default=False, help="skip confirmation prompts")
args = parser.parse_args()


def wait_up(node, *, timeout=None, initial_delay=5, poll_interval=2, settle_delay=10):
    """Block until ``node`` answers a RouterOS API request again.

    Used after commands that make the device reboot. Sleeps ``initial_delay``
    seconds first, then probes ``/system/resource/print`` every
    ``poll_interval`` seconds until a probe succeeds. After the first
    successful probe it sleeps another ``settle_delay`` seconds before returning.

    :param node: bundlewrap node exposing ``run_routeros``.
    :param timeout: maximum number of seconds to keep probing after the
        initial delay. ``None`` (the default) waits indefinitely, which is the
        historical behaviour.
    :param initial_delay: seconds to wait before the first probe.
    :param poll_interval: seconds to wait between failed probes.
    :param settle_delay: seconds to wait after the first successful probe.
    :raises TimeoutError: if ``timeout`` elapses without a successful probe.
    """
    from time import monotonic

    # Presumably gives the device time to actually go down, so a probe does not
    # hit it before the reboot starts.
    sleep(initial_delay)
    deadline = None if timeout is None else monotonic() + timeout
    while True:
        try:
            node.run_routeros('/system/resource/print')
        except RemoteException as error:
            if deadline is not None and monotonic() >= deadline:
                raise TimeoutError(
                    f"{node.name}: not reachable after {timeout}s"
                ) from error
            sleep(poll_interval)
        else:
            io.debug(f"{node.name}: is up")
            # Presumably lets the device finish booting before the caller
            # queries it again.
            sleep(settle_delay)
            return


def upgrade_switch_os(node):
    """Upgrade RouterOS on ``node`` when a newer version is offered.

    Runs RouterOS's update check. If the installed version is older than the
    latest one, asks for confirmation (unless ``--yes`` was given). It then
    downloads and installs the update, waits for the device to come back, and
    verifies the installed version.

    Relies on the module-level ``bw`` repository (for ``bw.libs.version``) and
    the parsed command line ``args``.

    :param node: bundlewrap node running RouterOS.
    :raises Exception: if the version installed after the upgrade is not the
        expected latest version.
    """
    # get versions for comparison
    # (assumes the last entry of the API reply carries the fields — as the
    # original code does)
    with io.job(f"{node.name}: checking OS version"):
        response = node.run_routeros('/system/package/update/check-for-updates').raw[-1]
        installed_os = bw.libs.version.Version(response['installed-version'])
        latest_os = bw.libs.version.Version(response['latest-version'])
        io.debug(f"{node.name}: OS installed: {installed_os}, latest: {latest_os}")

    if installed_os >= latest_os:
        io.stdout(f"{node.name}: os up to date ({installed_os})")
        return

    # confirm os upgrade
    if not args.yes and not io.ask(
        f"{node.name}: upgrade os from {installed_os} to {latest_os}?", default=True
    ):
        io.stdout(f"{node.name}: skipped by user")
        return

    # download os
    with io.job(f"{node.name}: downloading OS"):
        response = node.run_routeros('/system/package/update/download').raw[-1]
        io.debug(f"{node.name}: OS upgrade download response: {response['status']}")

    # install and wait for reboot
    with io.job(f"{node.name}: upgrading OS"):
        try:
            node.run_routeros('/system/package/update/install')
        except RemoteException:
            # The device reboots during the install, so the API connection may
            # drop before a reply arrives; wait_up() below confirms recovery.
            pass
        wait_up(node)

    # verify new os version
    with io.job(f"{node.name}: checking new OS version"):
        response = node.run_routeros('/system/package/update/check-for-updates').raw[-1]
        new_os = bw.libs.version.Version(response['installed-version'])
        if new_os == latest_os:
            io.stdout(f"{node.name}: OS successfully upgraded from {installed_os} to {new_os}")
        else:
            raise Exception(f"{node.name}: OS upgrade failed, expected {latest_os}, got {new_os}")


def upgrade_switch_firmware(node):
    """Upgrade the RouterBOARD firmware of ``node`` when a newer one is available.

    Reads the current and upgrade firmware versions from
    ``/system/routerboard/print``. If the current one is older, asks for
    confirmation (unless ``--yes`` was given) and issues the firmware upgrade.
    It then reboots the device, waits for it to come back, and verifies the
    running firmware.

    Relies on the module-level ``bw`` repository (for ``bw.libs.version``) and
    the parsed command line ``args``.

    :param node: bundlewrap node running RouterOS.
    :raises Exception: if the firmware reported after the reboot differs from
        the expected upgrade version.
    """
    # get versions for comparison
    with io.job(f"{node.name}: checking Firmware version"):
        response = node.run_routeros('/system/routerboard/print').raw[-1]
        current_firmware = bw.libs.version.Version(response['current-firmware'])
        upgrade_firmware = bw.libs.version.Version(response['upgrade-firmware'])
        io.debug(f"{node.name}: firmware installed: {current_firmware}, upgrade: {upgrade_firmware}")

    if current_firmware >= upgrade_firmware:
        io.stdout(f"{node.name}: firmware is up to date ({current_firmware})")
        return

    # confirm firmware upgrade
    if not args.yes and not io.ask(
        f"{node.name}: upgrade firmware from {current_firmware} to {upgrade_firmware}?", default=True
    ):
        io.stdout(f"{node.name}: skipped by user")
        return

    # upgrade firmware; the explicit reboot below is presumably what activates it
    with io.job(f"{node.name}: upgrading Firmware"):
        node.run_routeros('/system/routerboard/upgrade')

    # reboot and wait
    with io.job(f"{node.name}: rebooting"):
        try:
            node.run_routeros('/system/reboot')
        except RemoteException:
            # The connection may drop while the device goes down; wait_up()
            # below confirms it is reachable again.
            pass
        wait_up(node)

    # verify firmware version (wrapped in a job like the OS verification step)
    with io.job(f"{node.name}: checking new Firmware version"):
        response = node.run_routeros('/system/routerboard/print').raw[-1]
        new_firmware = bw.libs.version.Version(response['current-firmware'])
        if new_firmware == upgrade_firmware:
            io.stdout(f"{node.name}: firmware successfully upgraded from {current_firmware} to {new_firmware}")
        else:
            raise Exception(
                f"{node.name}: firmware upgrade failed, expected {upgrade_firmware}, got {new_firmware}"
            )


def upgrade_switch(node):
    """Run the OS upgrade and then the firmware upgrade for one switch.

    Nodes whose ``os`` is not ``'routeros'``, or that do not answer a
    ``/system/resource/print`` probe, are reported as skipped. Every node
    advances the progress bar by exactly two steps: both at once when it is
    skipped, otherwise one step after each upgrade phase.

    :param node: bundlewrap node to upgrade.
    """
    with io.job(f"{node.name}: checking"):
        if node.os == 'routeros':
            # reachability probe; an unreachable switch is skipped, not fatal
            try:
                node.run_routeros('/system/resource/print')
            except RemoteException as error:
                io.progress_advance(2)
                io.stdout(f"{node.name}: skipped, error {error}")
                return
        else:
            io.progress_advance(2)
            io.stdout(f"{node.name}: skipped, unsupported os {node.os}")
            return

    for phase in (upgrade_switch_os, upgrade_switch_firmware):
        phase(node)
        io.progress_advance(1)


with io:
    # The repository root is the parent of the directory containing this script.
    # `bw` stays module-global: the upgrade functions read bw.libs.version.
    bw = Repository(dirname(dirname(realpath(__file__))))

    # Handle switches in sorted order; each one accounts for two progress steps
    # (OS phase + firmware phase), matching upgrade_switch().
    nodes = sorted(get_target_nodes(bw, args.targets))
    names = ', '.join(node.name for node in nodes)

    io.progress_set_total(2 * len(nodes))
    io.stdout(f"upgrading {len(nodes)} switches: {names}")

    for node in nodes:
        upgrade_switch(node)