From c9e8c4807e295c720aa79e54add06d54e8ea6ae0 Mon Sep 17 00:00:00 2001 From: dw-0 Date: Tue, 31 Oct 2023 20:54:44 +0100 Subject: [PATCH] feat(klipper): convert single to multi instance Signed-off-by: Dominik Willner --- kiauh/instance_manager/base_instance.py | 8 - kiauh/modules/klipper/klipper.py | 84 ++++----- kiauh/modules/klipper/klipper_dialogs.py | 98 +++++++++++ kiauh/modules/klipper/klipper_setup.py | 182 +++++-------------- kiauh/modules/klipper/klipper_utils.py | 212 ++++++++++++++++++----- kiauh/utils/input_utils.py | 31 ++-- 6 files changed, 374 insertions(+), 241 deletions(-) create mode 100644 kiauh/modules/klipper/klipper_dialogs.py diff --git a/kiauh/instance_manager/base_instance.py b/kiauh/instance_manager/base_instance.py index 169ff74..3dcda97 100644 --- a/kiauh/instance_manager/base_instance.py +++ b/kiauh/instance_manager/base_instance.py @@ -72,14 +72,6 @@ class BaseInstance(ABC): def create(self) -> None: raise NotImplementedError("Subclasses must implement the create method") - @abstractmethod - def read(self) -> None: - raise NotImplementedError("Subclasses must implement the read method") - - @abstractmethod - def update(self) -> None: - raise NotImplementedError("Subclasses must implement the update method") - @abstractmethod def delete(self, del_remnants: bool) -> None: raise NotImplementedError("Subclasses must implement the delete method") diff --git a/kiauh/modules/klipper/klipper.py b/kiauh/modules/klipper/klipper.py index e21d6a0..5c187e2 100644 --- a/kiauh/modules/klipper/klipper.py +++ b/kiauh/modules/klipper/klipper.py @@ -50,36 +50,12 @@ class Klipper(BaseInstance): service_file_target = f"{SYSTEMD}/{service_file_name}" env_file_target = os.path.abspath(f"{self.sysd_dir}/klipper.env") - # create folder structure - dirs = [ - self.data_dir, - self.cfg_dir, - self.log_dir, - self.comms_dir, - self.sysd_dir, - ] - for _dir in dirs: - create_directory(Path(_dir)) - try: - # writing the klipper service file (requires sudo!) - service_content = self._prep_service_file( - service_template_path, env_file_target + self.create_folder_structure() + self.write_service_file( + service_template_path, service_file_target, env_file_target ) - command = ["sudo", "tee", service_file_target] - subprocess.run( - command, - input=service_content.encode(), - stdout=subprocess.DEVNULL, - check=True, - ) - Logger.print_ok(f"Service file created: {service_file_target}") - - # writing the klipper.env file - env_file_content = self._prep_env_file(env_template_file_path) - with open(env_file_target, "w") as env_file: - env_file.write(env_file_content) - Logger.print_ok(f"Env file created: {env_file_target}") + self.write_env_file(env_template_file_path, env_file_target) except subprocess.CalledProcessError as e: Logger.print_error( @@ -90,12 +66,6 @@ class Klipper(BaseInstance): Logger.print_error(f"Error creating env file {env_file_target}: {e}") raise - def read(self) -> None: - print("Reading Klipper Instance") - - def update(self) -> None: - print("Updating Klipper Instance") - def delete(self, del_remnants: bool) -> None: service_file = self.get_service_file_name(extension=True) service_file_path = self._get_service_file_path() @@ -113,6 +83,45 @@ class Klipper(BaseInstance): if del_remnants: self._delete_klipper_remnants() + def create_folder_structure(self) -> None: + dirs = [ + self.data_dir, + self.cfg_dir, + self.log_dir, + self.comms_dir, + self.sysd_dir, + ] + for _dir in dirs: + create_directory(Path(_dir)) + + def write_service_file( + self, service_template_path: str, service_file_target: str, env_file_target: str + ): + service_content = self._prep_service_file( + service_template_path, env_file_target + ) + command = ["sudo", "tee", service_file_target] + subprocess.run( + command, + input=service_content.encode(), + stdout=subprocess.DEVNULL, + check=True, + ) + Logger.print_ok(f"Service file created: {service_file_target}") + + def write_env_file(self, env_template_file_path: str, env_file_target: str): + env_file_content = self._prep_env_file(env_template_file_path) + with open(env_file_target, "w") as env_file: + env_file.write(env_file_content) + Logger.print_ok(f"Env file created: {env_file_target}") + + def get_service_file_name(self, extension=False) -> str: + name = self.prefix if self.name is None else self.prefix + "-" + self.name + return name if not extension else f"{name}.service" + + def _get_service_file_path(self): + return f"{SYSTEMD}/{self.get_service_file_name(extension=True)}" + def _delete_klipper_remnants(self) -> None: try: Logger.print_info(f"Delete {self.klipper_dir} ...") @@ -127,13 +136,6 @@ class Klipper(BaseInstance): Logger.print_ok("Directories successfully deleted.") - def get_service_file_name(self, extension=False) -> str: - name = self.prefix if self.name is None else self.prefix + "-" + self.name - return name if not extension else f"{name}.service" - - def _get_service_file_path(self): - return f"{SYSTEMD}/{self.get_service_file_name(extension=True)}" - def _get_data_dir_from_name(self, name: str) -> str: if name is None: return "printer" diff --git a/kiauh/modules/klipper/klipper_dialogs.py b/kiauh/modules/klipper/klipper_dialogs.py new file mode 100644 index 0000000..204ea43 --- /dev/null +++ b/kiauh/modules/klipper/klipper_dialogs.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python + +# ======================================================================= # +# Copyright (C) 2020 - 2023 Dominik Willner # +# # +# This file is part of KIAUH - Klipper Installation And Update Helper # +# https://github.com/dw-0/kiauh # +# # +# This file may be distributed under the terms of the GNU GPLv3 license # +# ======================================================================= # + +from typing import List + +from kiauh.instance_manager.base_instance import BaseInstance +from kiauh.menus.base_menu import print_back_footer +from kiauh.utils.constants import COLOR_GREEN, RESET_FORMAT, COLOR_YELLOW, COLOR_CYAN + + +def print_instance_overview( + instances: List[BaseInstance], show_index=False, show_select_all=False +): + headline = f"{COLOR_GREEN}The following Klipper instances were found:{RESET_FORMAT}" + + print("/=======================================================\\") + print(f"|{'{:^64}'.format(headline)}|") + print("|-------------------------------------------------------|") + + if show_select_all: + select_all = f" {COLOR_YELLOW}a) Select all{RESET_FORMAT}" + print(f"|{'{:64}'.format(select_all)}|") + print("| |") + + for i, s in enumerate(instances): + index = f"{i})" if show_index else "●" + instance = s.get_service_file_name() + line = f"{'{:53}'.format(f'{index} {instance}')}" + print(f"| {COLOR_CYAN}{line}{RESET_FORMAT}|") + + print_back_footer() + + +def print_select_instance_count_dialog(): + print("/=======================================================\\") + print("| Please select the number of Klipper instances to set |") + print("| up. The number of Klipper instances will determine |") + print("| the amount of printers you can run from this host. |") + print("| |") + print( + f"| {COLOR_YELLOW}WARNING:{RESET_FORMAT} |" + ) + print( + f"| {COLOR_YELLOW}Setting up too many instances may crash your system.{RESET_FORMAT} |" + ) + print_back_footer() + + +def print_select_custom_name_dialog(): + print("/=======================================================\\") + print("| You can now assign a custom name to each instance. |") + print("| If skipped, each instance will get an index assigned |") + print("| in ascending order, starting at index '1'. |") + print("| |") + print( + f"| {COLOR_YELLOW}INFO:{RESET_FORMAT} |" + ) + print( + f"| {COLOR_YELLOW}Only alphanumeric characters are allowed!{RESET_FORMAT} |" + ) + print_back_footer() + + +def print_missing_usergroup_dialog(missing_groups) -> None: + print("/=======================================================\\") + print( + f"| {COLOR_YELLOW}WARNING: Your current user is not in group:{RESET_FORMAT} |" + ) + if "tty" in missing_groups: + print( + f"| {COLOR_CYAN}● tty{RESET_FORMAT} |" + ) + if "dialout" in missing_groups: + print( + f"| {COLOR_CYAN}● dialout{RESET_FORMAT} |" + ) + print("| |") + print("| It is possible that you won't be able to successfully |") + print("| connect and/or flash the controller board without |") + print("| your user being a member of that group. |") + print("| If you want to add the current user to the group(s) |") + print("| listed above, answer with 'Y'. Else skip with 'n'. |") + print("| |") + print( + f"| {COLOR_YELLOW}INFO:{RESET_FORMAT} |" + ) + print( + f"| {COLOR_YELLOW}Relog required for group assignments to take effect!{RESET_FORMAT} |" + ) + print("\\=======================================================/") diff --git a/kiauh/modules/klipper/klipper_setup.py b/kiauh/modules/klipper/klipper_setup.py index 865d2bf..5c41c40 100644 --- a/kiauh/modules/klipper/klipper_setup.py +++ b/kiauh/modules/klipper/klipper_setup.py @@ -9,27 +9,31 @@ # This file may be distributed under the terms of the GNU GPLv3 license # # ======================================================================= # -import grp import os -import re import subprocess -import textwrap from pathlib import Path -from typing import Optional, List, Union +from typing import List, Union from kiauh.config_manager.config_manager import ConfigManager from kiauh.instance_manager.instance_manager import InstanceManager from kiauh.modules.klipper.klipper import Klipper -from kiauh.modules.klipper.klipper_utils import ( +from kiauh.modules.klipper.klipper_dialogs import ( print_instance_overview, - print_missing_usergroup_dialog, + print_select_instance_count_dialog, +) +from kiauh.modules.klipper.klipper_utils import ( + handle_convert_single_to_multi_instance_names, + handle_new_multi_instance_names, + handle_existing_multi_instance_names, + handle_disruptive_system_packages, + check_user_groups, + handle_single_to_multi_conversion, ) from kiauh.repo_manager.repo_manager import RepoManager -from kiauh.utils.constants import CURRENT_USER, KLIPPER_DIR, KLIPPER_ENV_DIR +from kiauh.utils.constants import KLIPPER_DIR, KLIPPER_ENV_DIR from kiauh.utils.input_utils import ( get_confirm, get_number_input, - get_string_input, get_selection_input, ) from kiauh.utils.logger import Logger @@ -39,7 +43,6 @@ from kiauh.utils.system_utils import ( install_python_requirements, update_system_package_lists, install_system_packages, - mask_system_service, ) @@ -92,19 +95,37 @@ def handle_existing_instances(instance_manager: InstanceManager) -> bool: def install_klipper(instance_manager: InstanceManager) -> None: instance_list = instance_manager.get_instances() - if_adding = " additional" if len(instance_list) > 0 else "" - install_count = get_number_input( - f"Number of{if_adding} Klipper instances to set up", 1, default=1 - ) + + print_select_instance_count_dialog() + question = f"Number of{' additional' if len(instance_list) > 0 else ''} Klipper instances to set up" + install_count = get_number_input(question, 1, default=1, allow_go_back=True) + if install_count is None: + Logger.print_info("Exiting Klipper setup ...") + return instance_names = set_instance_names(instance_list, install_count) + if instance_names is None: + Logger.print_info("Exiting Klipper setup ...") + return if len(instance_list) < 1: setup_klipper_prerequesites() + convert_single_to_multi = ( + True + if len(instance_list) == 1 + and instance_list[0].name is None + and install_count >= 1 + else False + ) + for name in instance_names: - current_instance = Klipper(name=name) - instance_manager.set_current_instance(current_instance) + if convert_single_to_multi: + handle_single_to_multi_conversion(instance_manager, name) + convert_single_to_multi = False + else: + instance_manager.set_current_instance(Klipper(name=name)) + instance_manager.create_instance() instance_manager.enable_instance() instance_manager.start_instance() @@ -122,11 +143,11 @@ def setup_klipper_prerequesites() -> None: cm = ConfigManager() cm.read_config() - repo = ( + repo = str( cm.get_value("klipper", "repository_url") or "https://github.com/Klipper3D/klipper" ) - branch = cm.get_value("klipper", "branch") or "master" + branch = str(cm.get_value("klipper", "branch") or "master") repo_manager = RepoManager( repo=repo, @@ -159,64 +180,23 @@ def install_klipper_packages(klipper_dir: Path) -> None: def set_instance_names(instance_list, install_count: int) -> List[Union[str, None]]: instance_count = len(instance_list) - # default single instance install + # new single instance install if instance_count == 0 and install_count == 1: return [None] + # convert single instance install to multi install + elif instance_count == 1 and instance_list[0].name is None and install_count >= 1: + return handle_convert_single_to_multi_instance_names(install_count) + # new multi instance install - elif ( - (instance_count == 0 and install_count > 1) - # or convert single instance install to multi instance install - or (instance_count == 1 and install_count >= 1) - ): - if get_confirm("Assign custom names?", False): - return assign_custom_names(instance_count, install_count, None) - else: - _range = range(1, install_count + 1) - return [str(i) for i in _range] + elif instance_count == 0 and install_count > 1: + return handle_new_multi_instance_names(instance_count, install_count) # existing multi instance install elif instance_count > 1: - if has_custom_names(instance_list): - return assign_custom_names(instance_count, install_count, instance_list) - else: - start = get_highest_index(instance_list) + 1 - _range = range(start, start + install_count) - return [str(i) for i in _range] - - -def has_custom_names(instance_list: List[Klipper]) -> bool: - pattern = re.compile("^\d+$") - for instance in instance_list: - if not pattern.match(instance.name): - return True - - return False - - -def assign_custom_names( - instance_count: int, install_count: int, instance_list: Optional[List[Klipper]] -) -> List[str]: - instance_names = [] - exclude = Klipper.blacklist() - - # if an instance_list is provided, exclude all existing instance names - if instance_list is not None: - for instance in instance_list: - exclude.append(instance.name) - - for i in range(instance_count + install_count): - question = f"Enter name for instance {i + 1}" - name = get_string_input(question, exclude=exclude) - instance_names.append(name) - exclude.append(name) - - return instance_names - - -def get_highest_index(instance_list: List[Klipper]) -> int: - indices = [int(instance.name.split("-")[-1]) for instance in instance_list] - return max(indices) + return handle_existing_multi_instance_names( + instance_count, install_count, instance_list + ) def remove_single_instance(instance_manager: InstanceManager) -> None: @@ -262,69 +242,3 @@ def remove_multi_instance(instance_manager: InstanceManager) -> None: instance_manager.delete_instance(del_remnants=False) instance_manager.reload_daemon() - - -def check_user_groups(): - current_groups = [grp.getgrgid(gid).gr_name for gid in os.getgroups()] - - missing_groups = [] - if "tty" not in current_groups: - missing_groups.append("tty") - if "dialout" not in current_groups: - missing_groups.append("dialout") - - if not missing_groups: - return - - print_missing_usergroup_dialog(missing_groups) - if not get_confirm(f"Add user '{CURRENT_USER}' to group(s) now?"): - Logger.warn( - "Skipped adding user to required groups. You might encounter issues." - ) - return - - try: - for group in missing_groups: - Logger.print_info(f"Adding user '{CURRENT_USER}' to group {group} ...") - command = ["sudo", "usermod", "-a", "-G", group, CURRENT_USER] - subprocess.run(command, check=True) - Logger.print_ok(f"Group {group} assigned to user '{CURRENT_USER}'.") - except subprocess.CalledProcessError as e: - Logger.print_error(f"Unable to add user to usergroups: {e}") - raise - - Logger.print_warn( - "Remember to relog/restart this machine for the group(s) to be applied!" - ) - - -def handle_disruptive_system_packages() -> None: - services = [] - brltty_status = subprocess.run( - ["systemctl", "is-enabled", "brltty"], capture_output=True, text=True - ) - modem_manager_status = subprocess.run( - ["systemctl", "is-enabled", "ModemManager"], capture_output=True, text=True - ) - - if "enabled" in brltty_status.stdout: - services.append("brltty") - if "enabled" in modem_manager_status.stdout: - services.append("ModemManager") - - for service in services if services else []: - try: - Logger.print_info( - f"{service} service detected! Masking {service} service ..." - ) - mask_system_service(service) - Logger.print_ok(f"{service} service masked!") - except subprocess.CalledProcessError: - warn_msg = textwrap.dedent( - f""" - KIAUH was unable to mask the {service} system service. - Please fix the problem manually. Otherwise, this may have - undesirable effects on the operation of Klipper. - """ - )[1:] - Logger.print_warn(warn_msg) diff --git a/kiauh/modules/klipper/klipper_utils.py b/kiauh/modules/klipper/klipper_utils.py index 0d42b86..e018250 100644 --- a/kiauh/modules/klipper/klipper_utils.py +++ b/kiauh/modules/klipper/klipper_utils.py @@ -9,60 +9,178 @@ # This file may be distributed under the terms of the GNU GPLv3 license # # ======================================================================= # -from typing import List +import os +import re +import grp +import subprocess +import textwrap -from kiauh.instance_manager.base_instance import BaseInstance -from kiauh.menus.base_menu import print_back_footer -from kiauh.utils.constants import COLOR_GREEN, COLOR_CYAN, COLOR_YELLOW, RESET_FORMAT +from typing import List, Union + +from kiauh.instance_manager.instance_manager import InstanceManager +from kiauh.modules.klipper.klipper import Klipper +from kiauh.modules.klipper.klipper_dialogs import ( + print_missing_usergroup_dialog, + print_select_custom_name_dialog, +) +from kiauh.utils.constants import CURRENT_USER +from kiauh.utils.input_utils import get_confirm, get_string_input +from kiauh.utils.logger import Logger +from kiauh.utils.system_utils import mask_system_service -def print_instance_overview( - instances: List[BaseInstance], show_index=False, show_select_all=False -): - headline = f"{COLOR_GREEN}The following Klipper instances were found:{RESET_FORMAT}" +def assign_custom_names( + instance_count: int, install_count: int, instance_list: List[Klipper] = None +) -> List[str]: + instance_names = [] + exclude = Klipper.blacklist() - print("/=======================================================\\") - print(f"|{'{:^64}'.format(headline)}|") - print("|-------------------------------------------------------|") + # if an instance_list is provided, exclude all existing instance names + if instance_list is not None: + for instance in instance_list: + exclude.append(instance.name) - if show_select_all: - select_all = f" {COLOR_YELLOW}a) Select all{RESET_FORMAT}" - print(f"|{'{:64}'.format(select_all)}|") - print("| |") + for i in range(instance_count + install_count): + question = f"Enter name for instance {i + 1}" + name = get_string_input(question, exclude=exclude) + instance_names.append(name) + exclude.append(name) - for i, s in enumerate(instances): - index = f"{i})" if show_index else "●" - instance = s.get_service_file_name() - line = f"{'{:53}'.format(f'{index} {instance}')}" - print(f"| {COLOR_CYAN}{line}{RESET_FORMAT}|") - - print_back_footer() + return instance_names -def print_missing_usergroup_dialog(missing_groups) -> None: - print("/=======================================================\\") - print( - f"| {COLOR_YELLOW}WARNING: Your current user is not in group:{RESET_FORMAT} |" +def handle_convert_single_to_multi_instance_names( + install_count: int, +) -> Union[List[str], None]: + print_select_custom_name_dialog() + choice = get_confirm("Assign custom names?", False, allow_go_back=True) + if choice is True: + # instance_count = 0 and install_count + 1 as we want to assign a new name to the existing single install + return assign_custom_names(0, install_count + 1) + elif choice is False: + # "install_count + 2" as we need to account for the existing single install + _range = range(1, install_count + 2) + return [str(i) for i in _range] + + return None + + +def handle_new_multi_instance_names( + instance_count: int, install_count: int +) -> Union[List[str], None]: + print_select_custom_name_dialog() + choice = get_confirm("Assign custom names?", False, allow_go_back=True) + if choice is True: + return assign_custom_names(instance_count, install_count) + elif choice is False: + _range = range(1, install_count + 1) + return [str(i) for i in _range] + + return None + + +def handle_existing_multi_instance_names( + instance_count: int, install_count: int, instance_list: List[Klipper] +) -> List[str]: + if has_custom_names(instance_list): + return assign_custom_names(instance_count, install_count, instance_list) + else: + start = get_highest_index(instance_list) + 1 + _range = range(start, start + install_count) + return [str(i) for i in _range] + + +def handle_single_to_multi_conversion( + instance_manager: InstanceManager, name: str +) -> None: + instance_list = instance_manager.get_instances() + instance_manager.set_current_instance(instance_list[0]) + old_data_dir_name = instance_manager.get_instances()[0].data_dir + instance_manager.stop_instance() + instance_manager.disable_instance() + instance_manager.delete_instance(del_remnants=False) + instance_manager.set_current_instance(Klipper(name=name)) + new_data_dir_name = instance_manager.get_current_instance().data_dir + try: + os.rename(old_data_dir_name, new_data_dir_name) + except OSError as e: + log = f"Cannot rename {old_data_dir_name} to {new_data_dir_name}:\n{e}" + Logger.print_error(log) + + +def check_user_groups(): + current_groups = [grp.getgrgid(gid).gr_name for gid in os.getgroups()] + + missing_groups = [] + if "tty" not in current_groups: + missing_groups.append("tty") + if "dialout" not in current_groups: + missing_groups.append("dialout") + + if not missing_groups: + return + + print_missing_usergroup_dialog(missing_groups) + if not get_confirm(f"Add user '{CURRENT_USER}' to group(s) now?"): + log = "Skipped adding user to required groups. You might encounter issues." + Logger.warn(log) + return + + try: + for group in missing_groups: + Logger.print_info(f"Adding user '{CURRENT_USER}' to group {group} ...") + command = ["sudo", "usermod", "-a", "-G", group, CURRENT_USER] + subprocess.run(command, check=True) + Logger.print_ok(f"Group {group} assigned to user '{CURRENT_USER}'.") + except subprocess.CalledProcessError as e: + Logger.print_error(f"Unable to add user to usergroups: {e}") + raise + + log = "Remember to relog/restart this machine for the group(s) to be applied!" + Logger.print_warn(log) + + +def handle_disruptive_system_packages() -> None: + services = [] + brltty_status = subprocess.run( + ["systemctl", "is-enabled", "brltty"], capture_output=True, text=True ) - if "tty" in missing_groups: - print( - f"| {COLOR_CYAN}● tty{RESET_FORMAT} |" - ) - if "dialout" in missing_groups: - print( - f"| {COLOR_CYAN}● dialout{RESET_FORMAT} |" - ) - print("| |") - print("| It is possible that you won't be able to successfully |") - print("| connect and/or flash the controller board without |") - print("| your user being a member of that group. |") - print("| If you want to add the current user to the group(s) |") - print("| listed above, answer with 'Y'. Else skip with 'n'. |") - print("| |") - print( - f"| {COLOR_YELLOW}INFO:{RESET_FORMAT} |" + modem_manager_status = subprocess.run( + ["systemctl", "is-enabled", "ModemManager"], capture_output=True, text=True ) - print( - f"| {COLOR_YELLOW}Relog required for group assignments to take effect!{RESET_FORMAT} |" - ) - print("\\=======================================================/") + + if "enabled" in brltty_status.stdout: + services.append("brltty") + if "enabled" in modem_manager_status.stdout: + services.append("ModemManager") + + for service in services if services else []: + try: + Logger.print_info( + f"{service} service detected! Masking {service} service ..." + ) + mask_system_service(service) + Logger.print_ok(f"{service} service masked!") + except subprocess.CalledProcessError: + warn_msg = textwrap.dedent( + f""" + KIAUH was unable to mask the {service} system service. + Please fix the problem manually. Otherwise, this may have + undesirable effects on the operation of Klipper. + """ + )[1:] + Logger.print_warn(warn_msg) + + +def has_custom_names(instance_list: List[Klipper]) -> bool: + pattern = re.compile("^\d+$") + for instance in instance_list: + if not pattern.match(instance.name): + return True + + return False + + +def get_highest_index(instance_list: List[Klipper]) -> int: + indices = [int(instance.name.split("-")[-1]) for instance in instance_list] + return max(indices) diff --git a/kiauh/utils/input_utils.py b/kiauh/utils/input_utils.py index eaf97f6..5ab04c5 100644 --- a/kiauh/utils/input_utils.py +++ b/kiauh/utils/input_utils.py @@ -9,15 +9,18 @@ # This file may be distributed under the terms of the GNU GPLv3 license # # ======================================================================= # -from typing import Optional, List +from typing import Optional, List, Union from kiauh.utils.constants import COLOR_CYAN, RESET_FORMAT from kiauh.utils.logger import Logger -def get_confirm(question: str, default_choice=True) -> bool: +def get_confirm( + question: str, default_choice=True, allow_go_back=False +) -> Union[bool, None]: options_confirm = ["y", "yes"] options_decline = ["n", "no"] + options_go_back = ["b", "B"] if default_choice: def_choice = "(Y/n)" @@ -28,7 +31,7 @@ def get_confirm(question: str, default_choice=True) -> bool: while True: choice = ( - input(f"{COLOR_CYAN}###### {question} {def_choice} {RESET_FORMAT}") + input(f"{COLOR_CYAN}###### {question} {def_choice}: {RESET_FORMAT}") .strip() .lower() ) @@ -37,28 +40,34 @@ def get_confirm(question: str, default_choice=True) -> bool: return True elif choice in options_decline: return False + elif allow_go_back and choice in options_go_back: + return None else: Logger.print_error("Invalid choice. Please select 'y' or 'n'.") def get_number_input( - question: str, min_count: int, max_count=None, default=None -) -> int: + question: str, min_count: int, max_count=None, default=None, allow_go_back=False +) -> Union[int, None]: + options_go_back = ["b", "B"] _question = question + f" (default={default})" if default else question _question = f"{COLOR_CYAN}###### {_question}: {RESET_FORMAT}" while True: try: - num = input(_question) - if num == "": + _input = input(_question) + if allow_go_back and _input in options_go_back: + return None + + if _input == "": return default if max_count is not None: - if min_count <= int(num) <= max_count: - return int(num) + if min_count <= int(_input) <= max_count: + return int(_input) else: raise ValueError - elif int(num) >= min_count: - return int(num) + elif int(_input) >= min_count: + return int(_input) else: raise ValueError except ValueError: