feat(klipper): convert single to multi instance

Signed-off-by: Dominik Willner <th33xitus@gmail.com>
This commit is contained in:
dw-0
2023-10-31 20:54:44 +01:00
parent 09e874214b
commit c9e8c4807e
6 changed files with 374 additions and 241 deletions

View File

@@ -72,14 +72,6 @@ class BaseInstance(ABC):
def create(self) -> None: def create(self) -> None:
raise NotImplementedError("Subclasses must implement the create method") raise NotImplementedError("Subclasses must implement the create method")
@abstractmethod
def read(self) -> None:
raise NotImplementedError("Subclasses must implement the read method")
@abstractmethod
def update(self) -> None:
raise NotImplementedError("Subclasses must implement the update method")
@abstractmethod @abstractmethod
def delete(self, del_remnants: bool) -> None: def delete(self, del_remnants: bool) -> None:
raise NotImplementedError("Subclasses must implement the delete method") raise NotImplementedError("Subclasses must implement the delete method")

View File

@@ -50,36 +50,12 @@ class Klipper(BaseInstance):
service_file_target = f"{SYSTEMD}/{service_file_name}" service_file_target = f"{SYSTEMD}/{service_file_name}"
env_file_target = os.path.abspath(f"{self.sysd_dir}/klipper.env") env_file_target = os.path.abspath(f"{self.sysd_dir}/klipper.env")
# create folder structure
dirs = [
self.data_dir,
self.cfg_dir,
self.log_dir,
self.comms_dir,
self.sysd_dir,
]
for _dir in dirs:
create_directory(Path(_dir))
try: try:
# writing the klipper service file (requires sudo!) self.create_folder_structure()
service_content = self._prep_service_file( self.write_service_file(
service_template_path, env_file_target service_template_path, service_file_target, env_file_target
) )
command = ["sudo", "tee", service_file_target] self.write_env_file(env_template_file_path, env_file_target)
subprocess.run(
command,
input=service_content.encode(),
stdout=subprocess.DEVNULL,
check=True,
)
Logger.print_ok(f"Service file created: {service_file_target}")
# writing the klipper.env file
env_file_content = self._prep_env_file(env_template_file_path)
with open(env_file_target, "w") as env_file:
env_file.write(env_file_content)
Logger.print_ok(f"Env file created: {env_file_target}")
except subprocess.CalledProcessError as e: except subprocess.CalledProcessError as e:
Logger.print_error( Logger.print_error(
@@ -90,12 +66,6 @@ class Klipper(BaseInstance):
Logger.print_error(f"Error creating env file {env_file_target}: {e}") Logger.print_error(f"Error creating env file {env_file_target}: {e}")
raise raise
def read(self) -> None:
print("Reading Klipper Instance")
def update(self) -> None:
print("Updating Klipper Instance")
def delete(self, del_remnants: bool) -> None: def delete(self, del_remnants: bool) -> None:
service_file = self.get_service_file_name(extension=True) service_file = self.get_service_file_name(extension=True)
service_file_path = self._get_service_file_path() service_file_path = self._get_service_file_path()
@@ -113,6 +83,45 @@ class Klipper(BaseInstance):
if del_remnants: if del_remnants:
self._delete_klipper_remnants() self._delete_klipper_remnants()
def create_folder_structure(self) -> None:
dirs = [
self.data_dir,
self.cfg_dir,
self.log_dir,
self.comms_dir,
self.sysd_dir,
]
for _dir in dirs:
create_directory(Path(_dir))
def write_service_file(
self, service_template_path: str, service_file_target: str, env_file_target: str
):
service_content = self._prep_service_file(
service_template_path, env_file_target
)
command = ["sudo", "tee", service_file_target]
subprocess.run(
command,
input=service_content.encode(),
stdout=subprocess.DEVNULL,
check=True,
)
Logger.print_ok(f"Service file created: {service_file_target}")
def write_env_file(self, env_template_file_path: str, env_file_target: str):
env_file_content = self._prep_env_file(env_template_file_path)
with open(env_file_target, "w") as env_file:
env_file.write(env_file_content)
Logger.print_ok(f"Env file created: {env_file_target}")
def get_service_file_name(self, extension=False) -> str:
name = self.prefix if self.name is None else self.prefix + "-" + self.name
return name if not extension else f"{name}.service"
def _get_service_file_path(self):
return f"{SYSTEMD}/{self.get_service_file_name(extension=True)}"
def _delete_klipper_remnants(self) -> None: def _delete_klipper_remnants(self) -> None:
try: try:
Logger.print_info(f"Delete {self.klipper_dir} ...") Logger.print_info(f"Delete {self.klipper_dir} ...")
@@ -127,13 +136,6 @@ class Klipper(BaseInstance):
Logger.print_ok("Directories successfully deleted.") Logger.print_ok("Directories successfully deleted.")
def get_service_file_name(self, extension=False) -> str:
name = self.prefix if self.name is None else self.prefix + "-" + self.name
return name if not extension else f"{name}.service"
def _get_service_file_path(self):
return f"{SYSTEMD}/{self.get_service_file_name(extension=True)}"
def _get_data_dir_from_name(self, name: str) -> str: def _get_data_dir_from_name(self, name: str) -> str:
if name is None: if name is None:
return "printer" return "printer"

View File

@@ -0,0 +1,98 @@
#!/usr/bin/env python
# ======================================================================= #
# Copyright (C) 2020 - 2023 Dominik Willner <th33xitus@gmail.com> #
# #
# This file is part of KIAUH - Klipper Installation And Update Helper #
# https://github.com/dw-0/kiauh #
# #
# This file may be distributed under the terms of the GNU GPLv3 license #
# ======================================================================= #
from typing import List
from kiauh.instance_manager.base_instance import BaseInstance
from kiauh.menus.base_menu import print_back_footer
from kiauh.utils.constants import COLOR_GREEN, RESET_FORMAT, COLOR_YELLOW, COLOR_CYAN
def print_instance_overview(
instances: List[BaseInstance], show_index=False, show_select_all=False
):
headline = f"{COLOR_GREEN}The following Klipper instances were found:{RESET_FORMAT}"
print("/=======================================================\\")
print(f"|{'{:^64}'.format(headline)}|")
print("|-------------------------------------------------------|")
if show_select_all:
select_all = f" {COLOR_YELLOW}a) Select all{RESET_FORMAT}"
print(f"|{'{:64}'.format(select_all)}|")
print("| |")
for i, s in enumerate(instances):
index = f"{i})" if show_index else ""
instance = s.get_service_file_name()
line = f"{'{:53}'.format(f'{index} {instance}')}"
print(f"| {COLOR_CYAN}{line}{RESET_FORMAT}|")
print_back_footer()
def print_select_instance_count_dialog():
print("/=======================================================\\")
print("| Please select the number of Klipper instances to set |")
print("| up. The number of Klipper instances will determine |")
print("| the amount of printers you can run from this host. |")
print("| |")
print(
f"| {COLOR_YELLOW}WARNING:{RESET_FORMAT} |"
)
print(
f"| {COLOR_YELLOW}Setting up too many instances may crash your system.{RESET_FORMAT} |"
)
print_back_footer()
def print_select_custom_name_dialog():
print("/=======================================================\\")
print("| You can now assign a custom name to each instance. |")
print("| If skipped, each instance will get an index assigned |")
print("| in ascending order, starting at index '1'. |")
print("| |")
print(
f"| {COLOR_YELLOW}INFO:{RESET_FORMAT} |"
)
print(
f"| {COLOR_YELLOW}Only alphanumeric characters are allowed!{RESET_FORMAT} |"
)
print_back_footer()
def print_missing_usergroup_dialog(missing_groups) -> None:
print("/=======================================================\\")
print(
f"| {COLOR_YELLOW}WARNING: Your current user is not in group:{RESET_FORMAT} |"
)
if "tty" in missing_groups:
print(
f"| {COLOR_CYAN}● tty{RESET_FORMAT} |"
)
if "dialout" in missing_groups:
print(
f"| {COLOR_CYAN}● dialout{RESET_FORMAT} |"
)
print("| |")
print("| It is possible that you won't be able to successfully |")
print("| connect and/or flash the controller board without |")
print("| your user being a member of that group. |")
print("| If you want to add the current user to the group(s) |")
print("| listed above, answer with 'Y'. Else skip with 'n'. |")
print("| |")
print(
f"| {COLOR_YELLOW}INFO:{RESET_FORMAT} |"
)
print(
f"| {COLOR_YELLOW}Relog required for group assignments to take effect!{RESET_FORMAT} |"
)
print("\\=======================================================/")

View File

@@ -9,27 +9,31 @@
# This file may be distributed under the terms of the GNU GPLv3 license # # This file may be distributed under the terms of the GNU GPLv3 license #
# ======================================================================= # # ======================================================================= #
import grp
import os import os
import re
import subprocess import subprocess
import textwrap
from pathlib import Path from pathlib import Path
from typing import Optional, List, Union from typing import List, Union
from kiauh.config_manager.config_manager import ConfigManager from kiauh.config_manager.config_manager import ConfigManager
from kiauh.instance_manager.instance_manager import InstanceManager from kiauh.instance_manager.instance_manager import InstanceManager
from kiauh.modules.klipper.klipper import Klipper from kiauh.modules.klipper.klipper import Klipper
from kiauh.modules.klipper.klipper_utils import ( from kiauh.modules.klipper.klipper_dialogs import (
print_instance_overview, print_instance_overview,
print_missing_usergroup_dialog, print_select_instance_count_dialog,
)
from kiauh.modules.klipper.klipper_utils import (
handle_convert_single_to_multi_instance_names,
handle_new_multi_instance_names,
handle_existing_multi_instance_names,
handle_disruptive_system_packages,
check_user_groups,
handle_single_to_multi_conversion,
) )
from kiauh.repo_manager.repo_manager import RepoManager from kiauh.repo_manager.repo_manager import RepoManager
from kiauh.utils.constants import CURRENT_USER, KLIPPER_DIR, KLIPPER_ENV_DIR from kiauh.utils.constants import KLIPPER_DIR, KLIPPER_ENV_DIR
from kiauh.utils.input_utils import ( from kiauh.utils.input_utils import (
get_confirm, get_confirm,
get_number_input, get_number_input,
get_string_input,
get_selection_input, get_selection_input,
) )
from kiauh.utils.logger import Logger from kiauh.utils.logger import Logger
@@ -39,7 +43,6 @@ from kiauh.utils.system_utils import (
install_python_requirements, install_python_requirements,
update_system_package_lists, update_system_package_lists,
install_system_packages, install_system_packages,
mask_system_service,
) )
@@ -92,19 +95,37 @@ def handle_existing_instances(instance_manager: InstanceManager) -> bool:
def install_klipper(instance_manager: InstanceManager) -> None: def install_klipper(instance_manager: InstanceManager) -> None:
instance_list = instance_manager.get_instances() instance_list = instance_manager.get_instances()
if_adding = " additional" if len(instance_list) > 0 else ""
install_count = get_number_input( print_select_instance_count_dialog()
f"Number of{if_adding} Klipper instances to set up", 1, default=1 question = f"Number of{' additional' if len(instance_list) > 0 else ''} Klipper instances to set up"
) install_count = get_number_input(question, 1, default=1, allow_go_back=True)
if install_count is None:
Logger.print_info("Exiting Klipper setup ...")
return
instance_names = set_instance_names(instance_list, install_count) instance_names = set_instance_names(instance_list, install_count)
if instance_names is None:
Logger.print_info("Exiting Klipper setup ...")
return
if len(instance_list) < 1: if len(instance_list) < 1:
setup_klipper_prerequesites() setup_klipper_prerequesites()
convert_single_to_multi = (
True
if len(instance_list) == 1
and instance_list[0].name is None
and install_count >= 1
else False
)
for name in instance_names: for name in instance_names:
current_instance = Klipper(name=name) if convert_single_to_multi:
instance_manager.set_current_instance(current_instance) handle_single_to_multi_conversion(instance_manager, name)
convert_single_to_multi = False
else:
instance_manager.set_current_instance(Klipper(name=name))
instance_manager.create_instance() instance_manager.create_instance()
instance_manager.enable_instance() instance_manager.enable_instance()
instance_manager.start_instance() instance_manager.start_instance()
@@ -122,11 +143,11 @@ def setup_klipper_prerequesites() -> None:
cm = ConfigManager() cm = ConfigManager()
cm.read_config() cm.read_config()
repo = ( repo = str(
cm.get_value("klipper", "repository_url") cm.get_value("klipper", "repository_url")
or "https://github.com/Klipper3D/klipper" or "https://github.com/Klipper3D/klipper"
) )
branch = cm.get_value("klipper", "branch") or "master" branch = str(cm.get_value("klipper", "branch") or "master")
repo_manager = RepoManager( repo_manager = RepoManager(
repo=repo, repo=repo,
@@ -159,64 +180,23 @@ def install_klipper_packages(klipper_dir: Path) -> None:
def set_instance_names(instance_list, install_count: int) -> List[Union[str, None]]: def set_instance_names(instance_list, install_count: int) -> List[Union[str, None]]:
instance_count = len(instance_list) instance_count = len(instance_list)
# default single instance install # new single instance install
if instance_count == 0 and install_count == 1: if instance_count == 0 and install_count == 1:
return [None] return [None]
# convert single instance install to multi install
elif instance_count == 1 and instance_list[0].name is None and install_count >= 1:
return handle_convert_single_to_multi_instance_names(install_count)
# new multi instance install # new multi instance install
elif ( elif instance_count == 0 and install_count > 1:
(instance_count == 0 and install_count > 1) return handle_new_multi_instance_names(instance_count, install_count)
# or convert single instance install to multi instance install
or (instance_count == 1 and install_count >= 1)
):
if get_confirm("Assign custom names?", False):
return assign_custom_names(instance_count, install_count, None)
else:
_range = range(1, install_count + 1)
return [str(i) for i in _range]
# existing multi instance install # existing multi instance install
elif instance_count > 1: elif instance_count > 1:
if has_custom_names(instance_list): return handle_existing_multi_instance_names(
return assign_custom_names(instance_count, install_count, instance_list) instance_count, install_count, instance_list
else: )
start = get_highest_index(instance_list) + 1
_range = range(start, start + install_count)
return [str(i) for i in _range]
def has_custom_names(instance_list: List[Klipper]) -> bool:
pattern = re.compile("^\d+$")
for instance in instance_list:
if not pattern.match(instance.name):
return True
return False
def assign_custom_names(
instance_count: int, install_count: int, instance_list: Optional[List[Klipper]]
) -> List[str]:
instance_names = []
exclude = Klipper.blacklist()
# if an instance_list is provided, exclude all existing instance names
if instance_list is not None:
for instance in instance_list:
exclude.append(instance.name)
for i in range(instance_count + install_count):
question = f"Enter name for instance {i + 1}"
name = get_string_input(question, exclude=exclude)
instance_names.append(name)
exclude.append(name)
return instance_names
def get_highest_index(instance_list: List[Klipper]) -> int:
indices = [int(instance.name.split("-")[-1]) for instance in instance_list]
return max(indices)
def remove_single_instance(instance_manager: InstanceManager) -> None: def remove_single_instance(instance_manager: InstanceManager) -> None:
@@ -262,69 +242,3 @@ def remove_multi_instance(instance_manager: InstanceManager) -> None:
instance_manager.delete_instance(del_remnants=False) instance_manager.delete_instance(del_remnants=False)
instance_manager.reload_daemon() instance_manager.reload_daemon()
def check_user_groups():
current_groups = [grp.getgrgid(gid).gr_name for gid in os.getgroups()]
missing_groups = []
if "tty" not in current_groups:
missing_groups.append("tty")
if "dialout" not in current_groups:
missing_groups.append("dialout")
if not missing_groups:
return
print_missing_usergroup_dialog(missing_groups)
if not get_confirm(f"Add user '{CURRENT_USER}' to group(s) now?"):
Logger.warn(
"Skipped adding user to required groups. You might encounter issues."
)
return
try:
for group in missing_groups:
Logger.print_info(f"Adding user '{CURRENT_USER}' to group {group} ...")
command = ["sudo", "usermod", "-a", "-G", group, CURRENT_USER]
subprocess.run(command, check=True)
Logger.print_ok(f"Group {group} assigned to user '{CURRENT_USER}'.")
except subprocess.CalledProcessError as e:
Logger.print_error(f"Unable to add user to usergroups: {e}")
raise
Logger.print_warn(
"Remember to relog/restart this machine for the group(s) to be applied!"
)
def handle_disruptive_system_packages() -> None:
services = []
brltty_status = subprocess.run(
["systemctl", "is-enabled", "brltty"], capture_output=True, text=True
)
modem_manager_status = subprocess.run(
["systemctl", "is-enabled", "ModemManager"], capture_output=True, text=True
)
if "enabled" in brltty_status.stdout:
services.append("brltty")
if "enabled" in modem_manager_status.stdout:
services.append("ModemManager")
for service in services if services else []:
try:
Logger.print_info(
f"{service} service detected! Masking {service} service ..."
)
mask_system_service(service)
Logger.print_ok(f"{service} service masked!")
except subprocess.CalledProcessError:
warn_msg = textwrap.dedent(
f"""
KIAUH was unable to mask the {service} system service.
Please fix the problem manually. Otherwise, this may have
undesirable effects on the operation of Klipper.
"""
)[1:]
Logger.print_warn(warn_msg)

View File

@@ -9,60 +9,178 @@
# This file may be distributed under the terms of the GNU GPLv3 license # # This file may be distributed under the terms of the GNU GPLv3 license #
# ======================================================================= # # ======================================================================= #
from typing import List import os
import re
import grp
import subprocess
import textwrap
from kiauh.instance_manager.base_instance import BaseInstance from typing import List, Union
from kiauh.menus.base_menu import print_back_footer
from kiauh.utils.constants import COLOR_GREEN, COLOR_CYAN, COLOR_YELLOW, RESET_FORMAT
from kiauh.instance_manager.instance_manager import InstanceManager
def print_instance_overview( from kiauh.modules.klipper.klipper import Klipper
instances: List[BaseInstance], show_index=False, show_select_all=False from kiauh.modules.klipper.klipper_dialogs import (
): print_missing_usergroup_dialog,
headline = f"{COLOR_GREEN}The following Klipper instances were found:{RESET_FORMAT}" print_select_custom_name_dialog,
print("/=======================================================\\")
print(f"|{'{:^64}'.format(headline)}|")
print("|-------------------------------------------------------|")
if show_select_all:
select_all = f" {COLOR_YELLOW}a) Select all{RESET_FORMAT}"
print(f"|{'{:64}'.format(select_all)}|")
print("| |")
for i, s in enumerate(instances):
index = f"{i})" if show_index else ""
instance = s.get_service_file_name()
line = f"{'{:53}'.format(f'{index} {instance}')}"
print(f"| {COLOR_CYAN}{line}{RESET_FORMAT}|")
print_back_footer()
def print_missing_usergroup_dialog(missing_groups) -> None:
print("/=======================================================\\")
print(
f"| {COLOR_YELLOW}WARNING: Your current user is not in group:{RESET_FORMAT} |"
) )
if "tty" in missing_groups: from kiauh.utils.constants import CURRENT_USER
print( from kiauh.utils.input_utils import get_confirm, get_string_input
f"| {COLOR_CYAN}● tty{RESET_FORMAT} |" from kiauh.utils.logger import Logger
from kiauh.utils.system_utils import mask_system_service
def assign_custom_names(
instance_count: int, install_count: int, instance_list: List[Klipper] = None
) -> List[str]:
instance_names = []
exclude = Klipper.blacklist()
# if an instance_list is provided, exclude all existing instance names
if instance_list is not None:
for instance in instance_list:
exclude.append(instance.name)
for i in range(instance_count + install_count):
question = f"Enter name for instance {i + 1}"
name = get_string_input(question, exclude=exclude)
instance_names.append(name)
exclude.append(name)
return instance_names
def handle_convert_single_to_multi_instance_names(
install_count: int,
) -> Union[List[str], None]:
print_select_custom_name_dialog()
choice = get_confirm("Assign custom names?", False, allow_go_back=True)
if choice is True:
# instance_count = 0 and install_count + 1 as we want to assign a new name to the existing single install
return assign_custom_names(0, install_count + 1)
elif choice is False:
# "install_count + 2" as we need to account for the existing single install
_range = range(1, install_count + 2)
return [str(i) for i in _range]
return None
def handle_new_multi_instance_names(
instance_count: int, install_count: int
) -> Union[List[str], None]:
print_select_custom_name_dialog()
choice = get_confirm("Assign custom names?", False, allow_go_back=True)
if choice is True:
return assign_custom_names(instance_count, install_count)
elif choice is False:
_range = range(1, install_count + 1)
return [str(i) for i in _range]
return None
def handle_existing_multi_instance_names(
instance_count: int, install_count: int, instance_list: List[Klipper]
) -> List[str]:
if has_custom_names(instance_list):
return assign_custom_names(instance_count, install_count, instance_list)
else:
start = get_highest_index(instance_list) + 1
_range = range(start, start + install_count)
return [str(i) for i in _range]
def handle_single_to_multi_conversion(
instance_manager: InstanceManager, name: str
) -> None:
instance_list = instance_manager.get_instances()
instance_manager.set_current_instance(instance_list[0])
old_data_dir_name = instance_manager.get_instances()[0].data_dir
instance_manager.stop_instance()
instance_manager.disable_instance()
instance_manager.delete_instance(del_remnants=False)
instance_manager.set_current_instance(Klipper(name=name))
new_data_dir_name = instance_manager.get_current_instance().data_dir
try:
os.rename(old_data_dir_name, new_data_dir_name)
except OSError as e:
log = f"Cannot rename {old_data_dir_name} to {new_data_dir_name}:\n{e}"
Logger.print_error(log)
def check_user_groups():
current_groups = [grp.getgrgid(gid).gr_name for gid in os.getgroups()]
missing_groups = []
if "tty" not in current_groups:
missing_groups.append("tty")
if "dialout" not in current_groups:
missing_groups.append("dialout")
if not missing_groups:
return
print_missing_usergroup_dialog(missing_groups)
if not get_confirm(f"Add user '{CURRENT_USER}' to group(s) now?"):
log = "Skipped adding user to required groups. You might encounter issues."
Logger.warn(log)
return
try:
for group in missing_groups:
Logger.print_info(f"Adding user '{CURRENT_USER}' to group {group} ...")
command = ["sudo", "usermod", "-a", "-G", group, CURRENT_USER]
subprocess.run(command, check=True)
Logger.print_ok(f"Group {group} assigned to user '{CURRENT_USER}'.")
except subprocess.CalledProcessError as e:
Logger.print_error(f"Unable to add user to usergroups: {e}")
raise
log = "Remember to relog/restart this machine for the group(s) to be applied!"
Logger.print_warn(log)
def handle_disruptive_system_packages() -> None:
services = []
brltty_status = subprocess.run(
["systemctl", "is-enabled", "brltty"], capture_output=True, text=True
) )
if "dialout" in missing_groups: modem_manager_status = subprocess.run(
print( ["systemctl", "is-enabled", "ModemManager"], capture_output=True, text=True
f"| {COLOR_CYAN}● dialout{RESET_FORMAT} |"
) )
print("| |")
print("| It is possible that you won't be able to successfully |") if "enabled" in brltty_status.stdout:
print("| connect and/or flash the controller board without |") services.append("brltty")
print("| your user being a member of that group. |") if "enabled" in modem_manager_status.stdout:
print("| If you want to add the current user to the group(s) |") services.append("ModemManager")
print("| listed above, answer with 'Y'. Else skip with 'n'. |")
print("| |") for service in services if services else []:
print( try:
f"| {COLOR_YELLOW}INFO:{RESET_FORMAT} |" Logger.print_info(
f"{service} service detected! Masking {service} service ..."
) )
print( mask_system_service(service)
f"| {COLOR_YELLOW}Relog required for group assignments to take effect!{RESET_FORMAT} |" Logger.print_ok(f"{service} service masked!")
) except subprocess.CalledProcessError:
print("\\=======================================================/") warn_msg = textwrap.dedent(
f"""
KIAUH was unable to mask the {service} system service.
Please fix the problem manually. Otherwise, this may have
undesirable effects on the operation of Klipper.
"""
)[1:]
Logger.print_warn(warn_msg)
def has_custom_names(instance_list: List[Klipper]) -> bool:
pattern = re.compile("^\d+$")
for instance in instance_list:
if not pattern.match(instance.name):
return True
return False
def get_highest_index(instance_list: List[Klipper]) -> int:
indices = [int(instance.name.split("-")[-1]) for instance in instance_list]
return max(indices)

View File

@@ -9,15 +9,18 @@
# This file may be distributed under the terms of the GNU GPLv3 license # # This file may be distributed under the terms of the GNU GPLv3 license #
# ======================================================================= # # ======================================================================= #
from typing import Optional, List from typing import Optional, List, Union
from kiauh.utils.constants import COLOR_CYAN, RESET_FORMAT from kiauh.utils.constants import COLOR_CYAN, RESET_FORMAT
from kiauh.utils.logger import Logger from kiauh.utils.logger import Logger
def get_confirm(question: str, default_choice=True) -> bool: def get_confirm(
question: str, default_choice=True, allow_go_back=False
) -> Union[bool, None]:
options_confirm = ["y", "yes"] options_confirm = ["y", "yes"]
options_decline = ["n", "no"] options_decline = ["n", "no"]
options_go_back = ["b", "B"]
if default_choice: if default_choice:
def_choice = "(Y/n)" def_choice = "(Y/n)"
@@ -28,7 +31,7 @@ def get_confirm(question: str, default_choice=True) -> bool:
while True: while True:
choice = ( choice = (
input(f"{COLOR_CYAN}###### {question} {def_choice} {RESET_FORMAT}") input(f"{COLOR_CYAN}###### {question} {def_choice}: {RESET_FORMAT}")
.strip() .strip()
.lower() .lower()
) )
@@ -37,28 +40,34 @@ def get_confirm(question: str, default_choice=True) -> bool:
return True return True
elif choice in options_decline: elif choice in options_decline:
return False return False
elif allow_go_back and choice in options_go_back:
return None
else: else:
Logger.print_error("Invalid choice. Please select 'y' or 'n'.") Logger.print_error("Invalid choice. Please select 'y' or 'n'.")
def get_number_input( def get_number_input(
question: str, min_count: int, max_count=None, default=None question: str, min_count: int, max_count=None, default=None, allow_go_back=False
) -> int: ) -> Union[int, None]:
options_go_back = ["b", "B"]
_question = question + f" (default={default})" if default else question _question = question + f" (default={default})" if default else question
_question = f"{COLOR_CYAN}###### {_question}: {RESET_FORMAT}" _question = f"{COLOR_CYAN}###### {_question}: {RESET_FORMAT}"
while True: while True:
try: try:
num = input(_question) _input = input(_question)
if num == "": if allow_go_back and _input in options_go_back:
return None
if _input == "":
return default return default
if max_count is not None: if max_count is not None:
if min_count <= int(num) <= max_count: if min_count <= int(_input) <= max_count:
return int(num) return int(_input)
else: else:
raise ValueError raise ValueError
elif int(num) >= min_count: elif int(_input) >= min_count:
return int(num) return int(_input)
else: else:
raise ValueError raise ValueError
except ValueError: except ValueError: