From 7b5522ac94f9fc6999d39ea1ce657076b2328ac2 Mon Sep 17 00:00:00 2001 From: dw-0 Date: Fri, 3 Jul 2026 20:57:28 +0200 Subject: [PATCH] feat(tests): add unit tests for shared utilities --- kiauh/utils/common.py | 4 +- kiauh/utils/git_utils.py | 24 +- kiauh/utils/tests/conftest.py | 21 + kiauh/utils/tests/test_common.py | 206 +++++++ kiauh/utils/tests/test_config_utils.py | 103 ++++ kiauh/utils/tests/test_fs_utils.py | 245 ++++++++ kiauh/utils/tests/test_git_utils.py | 535 ++++++++++++++++++ kiauh/utils/tests/test_input_utils.py | 158 ++++++ kiauh/utils/tests/test_instance_type.py | 17 + kiauh/utils/tests/test_instance_utils.py | 98 ++++ kiauh/utils/tests/test_sys_utils.py | 686 +++++++++++++++++++++++ pyproject.toml | 2 +- 12 files changed, 2095 insertions(+), 4 deletions(-) create mode 100644 kiauh/utils/tests/conftest.py create mode 100644 kiauh/utils/tests/test_common.py create mode 100644 kiauh/utils/tests/test_config_utils.py create mode 100644 kiauh/utils/tests/test_fs_utils.py create mode 100644 kiauh/utils/tests/test_git_utils.py create mode 100644 kiauh/utils/tests/test_input_utils.py create mode 100644 kiauh/utils/tests/test_instance_type.py create mode 100644 kiauh/utils/tests/test_instance_utils.py create mode 100644 kiauh/utils/tests/test_sys_utils.py diff --git a/kiauh/utils/common.py b/kiauh/utils/common.py index 13b3ffa..da6ebb7 100644 --- a/kiauh/utils/common.py +++ b/kiauh/utils/common.py @@ -36,13 +36,15 @@ from utils.sys_utils import ( update_system_package_lists, ) +from kiauh import PROJECT_ROOT + def get_kiauh_version() -> str: """ Helper method to get the current KIAUH version by reading the latest tag :return: string of the latest tag or a default value if no tags exist """ - tags: List[str] = get_local_tags(Path(__file__).parent.parent) + tags: List[str] = get_local_tags(PROJECT_ROOT) if tags: return tags[-1] else: diff --git a/kiauh/utils/git_utils.py b/kiauh/utils/git_utils.py index ead448a..0ba6e56 100644 --- a/kiauh/utils/git_utils.py +++ b/kiauh/utils/git_utils.py @@ -67,7 +67,7 @@ def git_pull_wrapper(target_dir: Path) -> None: Logger.print_status("Updating repository ...") try: git_cmd_pull(target_dir) - except CalledProcessError: + except (CalledProcessError, GitException): log = "An unexpected error occured during updating the repository." Logger.print_error(log) return @@ -102,6 +102,9 @@ def get_current_branch(repo: Path) -> str | None: :param repo: Path to the local Git repository :return: Current branch or None if not determinable """ + if not repo.exists() or not repo.joinpath(".git").exists(): + return None + try: cmd = ["git", "branch", "--show-current"] result: str = check_output(cmd, stderr=DEVNULL, cwd=repo).decode( @@ -120,6 +123,8 @@ def get_local_tags(repo_path: Path, _filter: str | None = None) -> List[str]: :param _filter: Optional filter to filter the tags by :return: List of tags """ + if not repo_path.exists() or not repo_path.joinpath(".git").is_dir(): + return [] def parse_version(version: str) -> tuple: # Remove 'v' prefix if present @@ -337,6 +342,11 @@ def git_cmd_checkout(branch: str | None, target_dir: Path) -> None: if branch is None: return + if not target_dir.exists() or not target_dir.joinpath(".git").exists(): + log = f"'{target_dir}' is not a valid git repository." + Logger.print_error(log) + raise GitException(log) + try: command = ["git", "checkout", f"{branch}"] run(command, cwd=target_dir, check=True) @@ -349,6 +359,11 @@ def git_cmd_checkout(branch: str | None, target_dir: Path) -> None: def git_cmd_pull(target_dir: Path) -> None: + if not target_dir.exists() or not target_dir.joinpath(".git").exists(): + log = f"'{target_dir}' is not a valid git repository." + Logger.print_error(log) + raise GitException(log) + try: command = ["git", "pull"] run(command, cwd=target_dir, check=True) @@ -359,6 +374,11 @@ def git_cmd_pull(target_dir: Path) -> None: def rollback_repository(repo_dir: Path, instance: Type[InstanceType]) -> None: + if not repo_dir.exists() or not repo_dir.joinpath(".git").exists(): + log = f"'{repo_dir}' is not a valid git repository." + Logger.print_error(log) + raise GitException(log) + q1 = "How many commits do you want to roll back" amount = get_number_input(q1, 1, allow_go_back=True) @@ -394,7 +414,7 @@ def get_repo_url(repo_dir: Path) -> str | None: :param repo_dir: Path to the git repository :return: URL of the remote repository or None if not found """ - if not repo_dir.exists(): + if not repo_dir.exists() or not repo_dir.joinpath(".git").exists(): return None try: diff --git a/kiauh/utils/tests/conftest.py b/kiauh/utils/tests/conftest.py new file mode 100644 index 0000000..ed49280 --- /dev/null +++ b/kiauh/utils/tests/conftest.py @@ -0,0 +1,21 @@ +import sys +from pathlib import Path + +import pytest + +PROJECT_ROOT = Path(__file__).resolve().parents[3] +if str(PROJECT_ROOT) not in sys.path: + sys.path.insert(0, str(PROJECT_ROOT)) + + +@pytest.fixture(autouse=True) +def silence_logger(monkeypatch: pytest.MonkeyPatch) -> None: + for name in ( + "print_info", + "print_ok", + "print_warn", + "print_error", + "print_status", + "print_dialog", + ): + monkeypatch.setattr(f"core.logger.Logger.{name}", lambda *a, **k: None) diff --git a/kiauh/utils/tests/test_common.py b/kiauh/utils/tests/test_common.py new file mode 100644 index 0000000..ae7ada0 --- /dev/null +++ b/kiauh/utils/tests/test_common.py @@ -0,0 +1,206 @@ +from __future__ import annotations + +from datetime import datetime +from pathlib import Path +from typing import List, Set + +import pytest +from core.constants import GLOBAL_DEPS +from utils.common import ( + check_install_dependencies, + convert_camelcase_to_kebabcase, + get_current_date, + get_install_status, + get_kiauh_version, + moonraker_exists, + trunc_string, +) + + +class TestGetKiauhVersion: + def test_uses_project_root(self, monkeypatch) -> None: + expected_root = Path(__file__).parent.parent.parent.parent + captured: List[Path] = [] + + def fake_get_local_tags(path: Path, _filter: str | None = None) -> List[str]: + captured.append(path) + return ["v6.3.0", "v6.3.1"] + + monkeypatch.setattr("utils.common.get_local_tags", fake_get_local_tags) + result = get_kiauh_version() + + assert captured == [expected_root] + assert result == "v6.3.1" + + def test_fallback_when_no_tags(self, monkeypatch) -> None: + monkeypatch.setattr("utils.common.get_local_tags", lambda *_a, **_k: []) + assert get_kiauh_version() == "v?.?.?" + + +class TestConvertCamelcaseToKebabcase: + @pytest.mark.parametrize( + "name,expected", + [ + ("Klipper", "klipper"), + ("Moonraker", "moonraker"), + ("MoonrakerObico", "moonraker-obico"), + ("HTTPResponse", "h-t-t-p-response"), + ("already", "already"), + ], + ) + def test_converts(self, name: str, expected: str) -> None: + assert convert_camelcase_to_kebabcase(name) == expected + + +class TestGetCurrentDate: + def test_returns_formatted_values(self) -> None: + result = get_current_date() + now = datetime.today() + + assert set(result.keys()) == {"date", "time"} + assert result["date"] == now.strftime("%Y%m%d") + assert result["time"] == now.strftime("%H%M%S") + + +class TestCheckInstallDependencies: + def test_with_global_and_custom(self, monkeypatch) -> None: + checked: Set[str] = set() + updated: List[bool] = [] + installed_pkgs: List[List[str]] = [] + + def fake_check_package_install(deps: Set[str]) -> List[str]: + checked.update(deps) + return ["extra-pkg"] + + monkeypatch.setattr( + "utils.common.check_package_install", fake_check_package_install + ) + monkeypatch.setattr( + "utils.common.update_system_package_lists", + lambda silent: updated.append(silent), + ) + monkeypatch.setattr( + "utils.common.install_system_packages", + lambda pkgs: installed_pkgs.append(pkgs), + ) + + check_install_dependencies({"custom-pkg"}, include_global=True) + + assert "custom-pkg" in checked + assert all(dep in checked for dep in GLOBAL_DEPS) + assert updated == [False] + assert installed_pkgs == [["extra-pkg"]] + + def test_no_requirements(self, monkeypatch) -> None: + monkeypatch.setattr("utils.common.check_package_install", lambda *_a, **_k: []) + monkeypatch.setattr( + "utils.common.update_system_package_lists", + lambda *a, **k: pytest.fail("should not update when nothing to install"), + ) + monkeypatch.setattr( + "utils.common.install_system_packages", + lambda *a, **k: pytest.fail("should not install when nothing to install"), + ) + + check_install_dependencies({"pkg"}) + + +class _FakeInstanceType: + def __init__(self, suffix: str): + self.suffix = suffix + + def __eq__(self, other): + return isinstance(other, _FakeInstanceType) and self.suffix == other.suffix + + +class TestGetInstallStatus: + def test_not_installed(self, tmp_path: Path, monkeypatch) -> None: + repo = tmp_path / "repo" + env = tmp_path / "env" + + monkeypatch.setattr("utils.common.get_current_branch", lambda *_a, **_k: None) + monkeypatch.setattr("utils.common.get_repo_name", lambda *_a, **_k: (None, None)) + monkeypatch.setattr("utils.common.get_repo_url", lambda *_a, **_k: None) + monkeypatch.setattr("utils.common.get_local_commit", lambda *_a, **_k: None) + monkeypatch.setattr("utils.common.get_remote_commit", lambda *_a, **_k: None) + monkeypatch.setattr("utils.instance_utils.get_instances", lambda *_a, **_k: []) + + status = get_install_status(repo, env, _FakeInstanceType) + + assert status.status == 0 + assert status.instances == 0 + + def test_fully_installed(self, tmp_path: Path, monkeypatch) -> None: + repo = tmp_path / "repo" + env = tmp_path / "env" + repo.mkdir() + env.mkdir() + (repo / ".git").mkdir() + extra_file = tmp_path / "extra" + extra_file.write_text("x") + + monkeypatch.setattr( + "utils.instance_utils.get_instances", lambda *_a, **_k: [_FakeInstanceType("")] + ) + monkeypatch.setattr("utils.common.get_current_branch", lambda *_a, **_k: "main") + monkeypatch.setattr( + "utils.common.get_repo_name", lambda *_a, **_k: ("dw-0", "kiauh") + ) + monkeypatch.setattr( + "utils.common.get_repo_url", lambda *_a, **_k: "https://github.com/dw-0/kiauh" + ) + monkeypatch.setattr("utils.common.get_local_commit", lambda *_a, **_k: "abc") + monkeypatch.setattr("utils.common.get_remote_commit", lambda *_a, **_k: "def") + + status = get_install_status(repo, env, _FakeInstanceType, files=[extra_file]) + + assert status.status == 2 + assert status.instances == 1 + assert status.owner == "dw-0" + assert status.repo == "kiauh" + assert status.branch == "main" + assert status.local == "abc" + assert status.remote == "def" + + def test_incomplete(self, tmp_path: Path, monkeypatch) -> None: + repo = tmp_path / "repo" + env = tmp_path / "env" + repo.mkdir() + + monkeypatch.setattr( + "utils.instance_utils.get_instances", lambda *_a, **_k: [_FakeInstanceType("")] + ) + monkeypatch.setattr("utils.common.get_current_branch", lambda *_a, **_k: "main") + monkeypatch.setattr("utils.common.get_repo_name", lambda *_a, **_k: (None, None)) + monkeypatch.setattr("utils.common.get_repo_url", lambda *_a, **_k: None) + monkeypatch.setattr("utils.common.get_local_commit", lambda *_a, **_k: None) + monkeypatch.setattr("utils.common.get_remote_commit", lambda *_a, **_k: None) + + status = get_install_status(repo, env, _FakeInstanceType) + + assert status.status == 1 + + +class TestMoonrakerExists: + def test_returns_instances(self, monkeypatch) -> None: + fake = object() + monkeypatch.setattr("utils.common.get_instances", lambda *_a, **_k: [fake]) + assert moonraker_exists() == [fake] + + def test_warns_when_none(self, monkeypatch) -> None: + monkeypatch.setattr("utils.common.get_instances", lambda *_a, **_k: []) + assert moonraker_exists("SomeInstaller") == [] + + +class TestTruncString: + @pytest.mark.parametrize( + "value,length,expected", + [ + ("short", 10, "short"), + ("exactly seven", 20, "exactly seven"), + ("much longer string", 10, "much lo..."), + ("abcdef", 5, "ab..."), + ], + ) + def test_truncates(self, value: str, length: int, expected: str) -> None: + assert trunc_string(value, length) == expected diff --git a/kiauh/utils/tests/test_config_utils.py b/kiauh/utils/tests/test_config_utils.py new file mode 100644 index 0000000..4eb6bc1 --- /dev/null +++ b/kiauh/utils/tests/test_config_utils.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +from pathlib import Path + +from utils.config_utils import ( + add_config_section, + add_config_section_at_top, + remove_config_section, +) + + +class _FakeInstance: + def __init__(self, cfg_file: Path): + self.cfg_file = cfg_file + + +def _write_cfg(path: Path, content: str) -> None: + path.write_text(content, encoding="utf-8") + + +class TestAddConfigSection: + def test_creates_section_and_options(self, tmp_path: Path) -> None: + cfg = tmp_path / "printer.cfg" + _write_cfg(cfg, "[existing]\noption: value\n") + instance = _FakeInstance(cfg) + + add_config_section( + "new_section", + [instance], + options=[("opt1", "val1"), ("opt2", ["line1", "line2"])], + ) + + text = cfg.read_text(encoding="utf-8") + assert "[new_section]" in text + assert "opt1: val1" in text + assert " line1" in text + assert " line2" in text + + def test_skips_existing_section(self, tmp_path: Path) -> None: + cfg = tmp_path / "printer.cfg" + _write_cfg(cfg, "[section]\noption: value\n") + instance = _FakeInstance(cfg) + + add_config_section("section", [instance]) + + text = cfg.read_text(encoding="utf-8") + assert text.count("[section]") == 1 + + def test_warns_when_file_missing(self, tmp_path: Path) -> None: + cfg = tmp_path / "missing.cfg" + instance = _FakeInstance(cfg) + + add_config_section("section", [instance]) + + assert not cfg.exists() + + +class TestAddConfigSectionAtTop: + def test_prepends_section(self, tmp_path: Path) -> None: + cfg = tmp_path / "printer.cfg" + original = "[old]\noption: value\n" + _write_cfg(cfg, original) + instance = _FakeInstance(cfg) + + add_config_section_at_top("top_section", [instance]) + + text = cfg.read_text(encoding="utf-8") + lines = text.splitlines() + assert lines[0] == "[top_section]" + assert "[old]" in text + assert text.endswith("\n") + + +class TestRemoveConfigSection: + def test_removes_existing(self, tmp_path: Path) -> None: + cfg = tmp_path / "printer.cfg" + _write_cfg(cfg, "[keep]\noption: 1\n[drop]\noption: 2\n") + instance = _FakeInstance(cfg) + + removed = remove_config_section("drop", [instance]) + + assert removed == [instance] + text = cfg.read_text(encoding="utf-8") + assert "[drop]" not in text + assert "[keep]" in text + + def test_skips_missing_section(self, tmp_path: Path) -> None: + cfg = tmp_path / "printer.cfg" + _write_cfg(cfg, "[keep]\noption: 1\n") + instance = _FakeInstance(cfg) + + removed = remove_config_section("missing", [instance]) + + assert removed == [] + assert cfg.read_text(encoding="utf-8") == "[keep]\noption: 1\n" + + def test_warns_when_file_missing(self, tmp_path: Path) -> None: + cfg = tmp_path / "missing.cfg" + instance = _FakeInstance(cfg) + + removed = remove_config_section("section", [instance]) + + assert removed == [] diff --git a/kiauh/utils/tests/test_fs_utils.py b/kiauh/utils/tests/test_fs_utils.py new file mode 100644 index 0000000..d60919d --- /dev/null +++ b/kiauh/utils/tests/test_fs_utils.py @@ -0,0 +1,245 @@ +from __future__ import annotations + +from pathlib import Path +from subprocess import CalledProcessError +from typing import Any, List +from zipfile import ZipFile + +import pytest +from utils.fs_utils import ( + check_file_exist, + create_folders, + create_symlink, + get_data_dir, + remove_file, + remove_with_sudo, + run_remove_routines, + unzip, +) + + +class TestCheckFileExist: + def test_returns_true_for_existing_file(self, tmp_path: Path) -> None: + file = tmp_path / "file.txt" + file.write_text("x") + assert check_file_exist(file) is True + + def test_returns_false_for_missing_file(self, tmp_path: Path) -> None: + assert check_file_exist(tmp_path / "missing") is False + + def test_returns_false_for_broken_symlink(self, tmp_path: Path) -> None: + link = tmp_path / "link" + link.symlink_to(tmp_path / "target") + assert check_file_exist(link) is False + + def test_with_sudo_uses_subprocess(self, monkeypatch) -> None: + calls: List[List[str]] = [] + + def fake_check_output(cmd: List[str], **kwargs: Any) -> bytes: + calls.append(cmd) + return b"" + + monkeypatch.setattr("utils.fs_utils.check_output", fake_check_output) + path = Path("/some/path") + assert check_file_exist(path, sudo=True) is True + assert calls[0] == ["sudo", "find", "-L", "/some/path", "-maxdepth", "0"] + + def test_with_sudo_returns_false_on_error(self, monkeypatch) -> None: + monkeypatch.setattr( + "utils.fs_utils.check_output", + lambda *a, **k: (_ for _ in ()).throw(CalledProcessError(1, "find")), + ) + assert check_file_exist(Path("/some/path"), sudo=True) is False + + +class TestCreateSymlink: + def test_calls_ln_with_correct_args(self, monkeypatch) -> None: + runs: List[List[str]] = [] + + def fake_run(cmd: List[str], **kwargs: Any) -> Any: + runs.append(cmd) + return None + + monkeypatch.setattr("utils.fs_utils.run", fake_run) + create_symlink(Path("/src"), Path("/dst")) + assert runs == [["ln", "-sf", "/src", "/dst"]] + + def test_uses_sudo(self, monkeypatch) -> None: + runs: List[List[str]] = [] + + def fake_run(cmd: List[str], **kwargs: Any) -> Any: + runs.append(cmd) + return None + + monkeypatch.setattr("utils.fs_utils.run", fake_run) + create_symlink(Path("/src"), Path("/dst"), sudo=True) + assert runs == [["sudo", "ln", "-sf", "/src", "/dst"]] + + def test_raises_on_failure(self, monkeypatch) -> None: + monkeypatch.setattr( + "utils.fs_utils.run", + lambda *a, **k: (_ for _ in ()).throw(CalledProcessError(1, "ln")), + ) + with pytest.raises(CalledProcessError): + create_symlink(Path("/src"), Path("/dst")) + + +class TestRemoveWithSudo: + def test_removes_existing_files(self, monkeypatch) -> None: + calls: List[tuple] = [] + + def fake_call(cmd: List[str], **kwargs: Any) -> int: + calls.append(("call", cmd)) + return 0 + + def fake_run(cmd: List[str], **kwargs: Any) -> Any: + calls.append(("run", cmd)) + return None + + monkeypatch.setattr("utils.fs_utils.call", fake_call) + monkeypatch.setattr("utils.fs_utils.run", fake_run) + + result = remove_with_sudo(Path("/some/file")) + + assert result is True + assert ("call", ["sudo", "find", "/some/file"]) in calls + assert ("run", ["sudo", "rm", "-rf", "/some/file"]) in calls + + def test_skips_missing_files(self, monkeypatch) -> None: + def fake_call(cmd: List[str], **kwargs: Any) -> int: + return 1 + + monkeypatch.setattr("utils.fs_utils.call", fake_call) + monkeypatch.setattr( + "utils.fs_utils.run", + lambda *a, **k: pytest.fail("should not run rm for missing file"), + ) + + assert remove_with_sudo(Path("/some/file")) is False + + def test_accepts_list(self, monkeypatch) -> None: + runs: List[List[str]] = [] + + def fake_call(cmd: List[str], **kwargs: Any) -> int: + return 0 + + def fake_run(cmd: List[str], **kwargs: Any) -> Any: + runs.append(cmd) + return None + + monkeypatch.setattr("utils.fs_utils.call", fake_call) + monkeypatch.setattr("utils.fs_utils.run", fake_run) + + remove_with_sudo([Path("/a"), Path("/b")]) + + assert runs == [ + ["sudo", "rm", "-rf", "/a"], + ["sudo", "rm", "-rf", "/b"], + ] + + +class TestRemoveFile: + def test_calls_shell_rm(self, monkeypatch) -> None: + runs: List[Any] = [] + + def fake_run(cmd: str, **kwargs: Any) -> Any: + runs.append((cmd, kwargs.get("shell"))) + return None + + monkeypatch.setattr("utils.fs_utils.run", fake_run) + + with pytest.warns(DeprecationWarning): + remove_file(Path("/some/file"), sudo=True) + + assert runs == [("sudo rm -f /some/file", True)] + + +class TestRunRemoveRoutines: + def test_returns_false_for_missing(self, tmp_path: Path) -> None: + assert run_remove_routines(tmp_path / "missing") is False + + def test_removes_file(self, tmp_path: Path) -> None: + file = tmp_path / "file.txt" + file.write_text("x") + assert run_remove_routines(file) is True + assert not file.exists() + + def test_removes_directory(self, tmp_path: Path) -> None: + directory = tmp_path / "dir" + directory.mkdir() + (directory / "child").write_text("x") + assert run_remove_routines(directory) is True + assert not directory.exists() + + def test_removes_symlink(self, tmp_path: Path) -> None: + target = tmp_path / "target" + target.write_text("x") + link = tmp_path / "link" + link.symlink_to(target) + assert run_remove_routines(link) is True + assert not link.exists() + assert target.exists() + + +class TestUnzip: + def test_extracts_contents(self, tmp_path: Path) -> None: + archive = tmp_path / "archive.zip" + target = tmp_path / "out" + target.mkdir() + + with ZipFile(archive, "w") as zf: + zf.writestr("hello.txt", "world") + + unzip(archive, target) + + assert (target / "hello.txt").read_text() == "world" + + +class TestCreateFolders: + def test_creates_missing_directories(self, tmp_path: Path) -> None: + dirs = [tmp_path / "a", tmp_path / "b"] + create_folders(dirs) + assert all(d.exists() for d in dirs) + + def test_skips_existing(self, tmp_path: Path) -> None: + existing = tmp_path / "exists" + existing.mkdir() + create_folders([existing]) + assert existing.exists() + + +class TestGetDataDir: + def test_reads_from_service_file(self, tmp_path: Path, monkeypatch) -> None: + service = tmp_path / "klipper.service" + service.write_text( + "EnvironmentFile=/home/user/printer_data/systemd/klipper.env\n" + ) + + def fake_service_path(instance_type: type, suffix: str) -> Path: + return service + + monkeypatch.setattr("utils.sys_utils.get_service_file_path", fake_service_path) + monkeypatch.setattr("utils.fs_utils.Path.home", lambda: tmp_path / "home") + + result = get_data_dir(object, "") + assert result == Path("/home/user/printer_data") + + def test_falls_back_to_suffixed_data_dir(self, tmp_path: Path, monkeypatch) -> None: + def fake_service_path(instance_type: type, suffix: str) -> Path: + return tmp_path / "no-such.service" + + monkeypatch.setattr("utils.sys_utils.get_service_file_path", fake_service_path) + home = tmp_path / "home" + monkeypatch.setattr("utils.fs_utils.Path.home", lambda: home) + + assert get_data_dir(object, "1") == home / "printer_1_data" + + def test_falls_back_to_default_data_dir(self, tmp_path: Path, monkeypatch) -> None: + def fake_service_path(instance_type: type, suffix: str) -> Path: + return tmp_path / "no-such.service" + + monkeypatch.setattr("utils.sys_utils.get_service_file_path", fake_service_path) + home = tmp_path / "home" + monkeypatch.setattr("utils.fs_utils.Path.home", lambda: home) + + assert get_data_dir(object, "") == home / "printer_data" diff --git a/kiauh/utils/tests/test_git_utils.py b/kiauh/utils/tests/test_git_utils.py new file mode 100644 index 0000000..f3bf23d --- /dev/null +++ b/kiauh/utils/tests/test_git_utils.py @@ -0,0 +1,535 @@ +from __future__ import annotations + +from pathlib import Path +from subprocess import CalledProcessError +from typing import Any, List + +import pytest +from utils.git_utils import ( + GitException, + compare_semver_tags, + get_current_branch, + get_latest_remote_tag, + get_latest_unstable_tag, + get_local_commit, + get_local_tags, + get_remote_commit, + get_remote_tags, + get_repo_name, + get_repo_url, + git_clone_wrapper, + git_cmd_checkout, + git_cmd_clone, + git_cmd_pull, + git_pull_wrapper, + rollback_repository, +) +from utils.instance_type import InstanceType + + +class TestGitCmdPull: + def test_missing_dir_raises(self, tmp_path: Path) -> None: + missing = tmp_path / "does-not-exist" + with pytest.raises(GitException): + git_cmd_pull(missing) + + def test_dir_without_git_raises(self, tmp_path: Path) -> None: + empty = tmp_path / "no-git" + empty.mkdir() + with pytest.raises(GitException): + git_cmd_pull(empty) + + def test_success_runs_git_pull(self, monkeypatch) -> None: + repo = Path("/fake/repo") + runs: List[List[str]] = [] + + def fake_run(cmd: List[str], **kwargs: Any) -> Any: + runs.append(cmd) + return None + + monkeypatch.setattr("utils.git_utils.Path.exists", lambda self: True) + monkeypatch.setattr( + "utils.git_utils.Path.joinpath", lambda self, name: repo / name + ) + monkeypatch.setattr("utils.git_utils.run", fake_run) + + git_cmd_pull(repo) + assert runs == [["git", "pull"]] + + +class TestGitCmdCheckout: + def test_missing_dir_raises(self, tmp_path: Path) -> None: + missing = tmp_path / "does-not-exist" + with pytest.raises(GitException): + git_cmd_checkout("main", missing) + + def test_dir_without_git_raises(self, tmp_path: Path) -> None: + empty = tmp_path / "no-git" + empty.mkdir() + with pytest.raises(GitException): + git_cmd_checkout("main", empty) + + def test_none_branch_returns(self, monkeypatch) -> None: + monkeypatch.setattr( + "utils.git_utils.run", + lambda *a, **k: pytest.fail("should not run checkout for None branch"), + ) + git_cmd_checkout(None, Path("/repo")) + + +class TestGitPullWrapper: + def test_missing_dir_does_not_raise(self, tmp_path: Path) -> None: + missing = tmp_path / "does-not-exist" + git_pull_wrapper(missing) + + def test_dir_without_git_does_not_raise(self, tmp_path: Path) -> None: + empty = tmp_path / "no-git" + empty.mkdir() + git_pull_wrapper(empty) + + def test_success_calls_git_pull(self, monkeypatch) -> None: + repo = Path("/fake/repo") + called: List[Path] = [] + + def fake_git_cmd_pull(path: Path) -> None: + called.append(path) + + monkeypatch.setattr("utils.git_utils.git_cmd_pull", fake_git_cmd_pull) + git_pull_wrapper(repo) + assert called == [repo] + + +class TestRollbackRepository: + def test_missing_dir_raises( + self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch + ) -> None: + missing = tmp_path / "does-not-exist" + called: list[bool] = [] + monkeypatch.setattr( + "utils.git_utils.get_number_input", + lambda *_a, **_k: called.append(True) or 1, + ) + with pytest.raises(GitException): + rollback_repository(missing, InstanceType) + assert not called + + def test_dir_without_git_raises( + self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch + ) -> None: + empty = tmp_path / "no-git" + empty.mkdir() + called: list[bool] = [] + monkeypatch.setattr( + "utils.git_utils.get_number_input", + lambda *_a, **_k: called.append(True) or 1, + ) + with pytest.raises(GitException): + rollback_repository(empty, InstanceType) + assert not called + + def test_aborts_when_not_confirmed(self, tmp_path: Path, monkeypatch) -> None: + repo = tmp_path / "repo" + repo.mkdir() + (repo / ".git").mkdir() + + monkeypatch.setattr("utils.git_utils.get_number_input", lambda *a, **k: 2) + monkeypatch.setattr("utils.git_utils.get_confirm", lambda *a, **k: False) + monkeypatch.setattr( + "utils.git_utils.get_instances", lambda *a, **k: ["instance"] + ) + monkeypatch.setattr( + "utils.git_utils.InstanceManager.stop_all", + lambda *a, **k: pytest.fail("should not stop when aborted"), + ) + monkeypatch.setattr( + "utils.git_utils.run", + lambda *a, **k: pytest.fail("should not reset when aborted"), + ) + monkeypatch.setattr( + "utils.git_utils.InstanceManager.start_all", + lambda *a, **k: pytest.fail("should not start when aborted"), + ) + + rollback_repository(repo, InstanceType) + + def test_resets_and_restarts_when_confirmed(self, tmp_path: Path, monkeypatch) -> None: + repo = tmp_path / "repo" + repo.mkdir() + (repo / ".git").mkdir() + + stops: List[List[Any]] = [] + starts: List[List[Any]] = [] + resets: List[List[str]] = [] + + monkeypatch.setattr("utils.git_utils.get_number_input", lambda *a, **k: 3) + monkeypatch.setattr("utils.git_utils.get_confirm", lambda *a, **k: True) + monkeypatch.setattr( + "utils.git_utils.get_instances", lambda *a, **k: ["instance"] + ) + monkeypatch.setattr( + "utils.git_utils.InstanceManager.stop_all", + lambda instances: stops.append(instances), + ) + monkeypatch.setattr( + "utils.git_utils.InstanceManager.start_all", + lambda instances: starts.append(instances), + ) + + def fake_run(cmd: List[str], **kwargs: Any) -> Any: + resets.append(cmd) + return None + + monkeypatch.setattr("utils.git_utils.run", fake_run) + + rollback_repository(repo, InstanceType) + + assert stops == [["instance"]] + assert resets == [["git", "reset", "--hard", "HEAD~3"]] + assert starts == [["instance"]] + + +class TestGetRepoName: + def test_extracts_org_and_repo(self, monkeypatch, tmp_path: Path) -> None: + repo = tmp_path / "repo" + repo.mkdir() + (repo / ".git").mkdir() + + monkeypatch.setattr( + "utils.git_utils.check_output", + lambda *a, **k: b"https://github.com/dw-0/kiauh.git\n", + ) + assert get_repo_name(repo) == ("dw-0", "kiauh") + + def test_returns_none_for_missing_repo(self, tmp_path: Path) -> None: + assert get_repo_name(tmp_path / "missing") == (None, None) + + def test_returns_none_on_git_error(self, monkeypatch, tmp_path: Path) -> None: + repo = tmp_path / "repo" + repo.mkdir() + (repo / ".git").mkdir() + + monkeypatch.setattr( + "utils.git_utils.check_output", + lambda *a, **k: (_ for _ in ()).throw(CalledProcessError(1, "git")), + ) + assert get_repo_name(repo) == (None, None) + + +class TestGetCurrentBranch: + def test_returns_branch(self, monkeypatch, tmp_path: Path) -> None: + repo = tmp_path / "repo" + repo.mkdir() + (repo / ".git").mkdir() + + monkeypatch.setattr( + "utils.git_utils.check_output", lambda *a, **k: b"feature-x\n" + ) + assert get_current_branch(repo) == "feature-x" + + def test_returns_none_for_missing_repo(self, tmp_path: Path) -> None: + assert get_current_branch(tmp_path / "missing") is None + + +class TestGetLocalTags: + def test_sorts_semver(self, monkeypatch, tmp_path: Path) -> None: + repo = tmp_path / "repo" + repo.mkdir() + (repo / ".git").mkdir() + + monkeypatch.setattr( + "utils.git_utils.check_output", + lambda *a, **k: b"v1.0.0\nv1.0.1\nv1.0.10\nv1.0.2\nv2.0.0-beta.1\n", + ) + assert get_local_tags(repo) == [ + "v1.0.0", + "v1.0.1", + "v1.0.2", + "v1.0.10", + "v2.0.0-beta.1", + ] + + def test_returns_empty_for_missing_repo(self, tmp_path: Path) -> None: + assert get_local_tags(tmp_path / "missing") == [] + + +class _FakeResponse: + def __init__(self, code: int, body: bytes = b""): + self._code = code + self._body = body + + def __enter__(self): + return self + + def __exit__(self, *args): + return None + + def getcode(self) -> int: + return self._code + + def read(self) -> bytes: + return self._body + + +class TestGetRemoteTags: + def test_parses_github_api(self, monkeypatch) -> None: + body = b'[{"name":"v1.0.0"},{"name":"v1.1.0"}]' + + class FakeUrlLib: + @staticmethod + def urlopen(url: str): + return _FakeResponse(200, body) + + monkeypatch.setattr("utils.git_utils.urllib.request", FakeUrlLib()) + assert get_remote_tags("dw-0/kiauh") == ["v1.0.0", "v1.1.0"] + + def test_returns_empty_on_http_error(self, monkeypatch) -> None: + class FakeUrlLib: + @staticmethod + def urlopen(url: str): + return _FakeResponse(404) + + monkeypatch.setattr("utils.git_utils.urllib.request", FakeUrlLib()) + assert get_remote_tags("dw-0/kiauh") == [] + + +class TestGetLatestRemoteTag: + def test_returns_first_tag(self, monkeypatch) -> None: + monkeypatch.setattr( + "utils.git_utils.get_remote_tags", lambda *_a, **_k: ["v2.0.0", "v1.0.0"] + ) + assert get_latest_remote_tag("dw-0/kiauh") == "v2.0.0" + + def test_returns_empty_when_no_tags(self, monkeypatch) -> None: + monkeypatch.setattr("utils.git_utils.get_remote_tags", lambda *_a, **_k: []) + assert get_latest_remote_tag("dw-0/kiauh") == "" + + +class TestGetLatestUnstableTag: + def test_filters_prereleases(self, monkeypatch) -> None: + monkeypatch.setattr( + "utils.git_utils.get_remote_tags", + lambda *_a, **_k: ["v2.0.0", "v2.0.0-rc.1", "v1.0.0-beta.2"], + ) + assert get_latest_unstable_tag("dw-0/kiauh") == "v2.0.0-rc.1" + + def test_returns_empty_when_stable_only(self, monkeypatch) -> None: + monkeypatch.setattr( + "utils.git_utils.get_remote_tags", lambda *_a, **_k: ["v2.0.0", "v1.0.0"] + ) + assert get_latest_unstable_tag("dw-0/kiauh") == "" + + +class TestCompareSemverTags: + @pytest.mark.parametrize( + "tag1,tag2,expected", + [ + ("v1.0.0", "v1.0.1", False), + ("v1.1.0", "v1.0.1", True), + ("v1.0.0", "v1.0.0", False), + ("v2.0.0", "v1.9.9", True), + ], + ) + def test_comparison(self, tag1: str, tag2: str, expected: bool) -> None: + assert compare_semver_tags(tag1, tag2) is expected + + +class TestGetLocalCommit: + def test_describes_head(self, monkeypatch, tmp_path: Path) -> None: + repo = tmp_path / "repo" + repo.mkdir() + (repo / ".git").mkdir() + + monkeypatch.setattr( + "utils.git_utils.check_output", + lambda *a, **k: "v1.0.0-0-gabc1234", + ) + assert get_local_commit(repo) == "v1.0.0-0-gabc1234" + + def test_returns_none_for_missing_repo(self, tmp_path: Path) -> None: + assert get_local_commit(tmp_path / "missing") is None + + +class TestGetRemoteCommit: + def test_describes_origin(self, monkeypatch, tmp_path: Path) -> None: + repo = tmp_path / "repo" + repo.mkdir() + (repo / ".git").mkdir() + + def fake_check_output(cmd: str, **kwargs: Any) -> str: + if "HEAD" in cmd: + return "v1.0.0" + return "origin/main" + + monkeypatch.setattr( + "utils.git_utils.get_current_branch", lambda *_a, **_k: "main" + ) + monkeypatch.setattr("utils.git_utils.check_output", fake_check_output) + assert get_remote_commit(repo) == "origin/main" + + def test_returns_none_for_missing_repo(self, tmp_path: Path) -> None: + assert get_remote_commit(tmp_path / "missing") is None + + +class TestGitCmdClone: + def test_without_blobless(self, monkeypatch) -> None: + runs: List[List[str]] = [] + + def fake_run(cmd: List[str], **kwargs: Any) -> Any: + runs.append(cmd) + return None + + monkeypatch.setattr("utils.git_utils.run", fake_run) + git_cmd_clone("https://github.com/dw-0/kiauh", Path("/target")) + assert runs == [["git", "clone", "https://github.com/dw-0/kiauh", "/target"]] + + def test_with_blobless(self, monkeypatch) -> None: + runs: List[List[str]] = [] + + def fake_run(cmd: List[str], **kwargs: Any) -> Any: + runs.append(cmd) + return None + + monkeypatch.setattr("utils.git_utils.run", fake_run) + git_cmd_clone( + "https://github.com/dw-0/kiauh", Path("/target"), blobless=True + ) + assert runs == [ + [ + "git", + "clone", + "--filter=blob:none", + "https://github.com/dw-0/kiauh", + "/target", + ] + ] + + +class TestGitCmdCheckoutSingle: + def test_runs_git_checkout(self, monkeypatch, tmp_path: Path) -> None: + repo = tmp_path / "repo" + repo.mkdir() + (repo / ".git").mkdir() + runs: List[List[str]] = [] + + def fake_run(cmd: List[str], **kwargs: Any) -> Any: + runs.append(cmd) + return None + + monkeypatch.setattr("utils.git_utils.run", fake_run) + git_cmd_checkout("dev", repo) + assert runs == [["git", "checkout", "dev"]] + + +class TestGetRepoUrl: + def test_extracts_remote_url(self, monkeypatch, tmp_path: Path) -> None: + repo = tmp_path / "repo" + repo.mkdir() + (repo / ".git").mkdir() + + class FakeResult: + stdout = "https://github.com/dw-0/kiauh.git\n" + + monkeypatch.setattr("utils.git_utils.run", lambda *a, **k: FakeResult()) + assert get_repo_url(repo) == "https://github.com/dw-0/kiauh.git" + + def test_returns_none_for_missing_repo(self, tmp_path: Path) -> None: + assert get_repo_url(tmp_path / "missing") is None + + def test_returns_none_on_git_error(self, monkeypatch, tmp_path: Path) -> None: + repo = tmp_path / "repo" + repo.mkdir() + (repo / ".git").mkdir() + + monkeypatch.setattr( + "utils.git_utils.run", + lambda *a, **k: (_ for _ in ()).throw(CalledProcessError(1, "git")), + ) + assert get_repo_url(repo) is None + + +class _CloneRecorder: + def __init__(self): + self.calls: List[tuple] = [] + self.checkouts: List[tuple] = [] + self.removed: List[Path] = [] + + def fake_clone(self, repo: str, target: Path, blobless: bool = False) -> None: + self.calls.append((repo, target, blobless)) + + def fake_checkout(self, branch: str | None, target: Path) -> None: + self.checkouts.append((branch, target)) + + +class TestGitCloneWrapper: + def test_clones_when_target_missing(self, monkeypatch, tmp_path: Path) -> None: + target = tmp_path / "kiauh" + recorder = _CloneRecorder() + monkeypatch.setattr("utils.git_utils.git_cmd_clone", recorder.fake_clone) + monkeypatch.setattr("utils.git_utils.git_cmd_checkout", recorder.fake_checkout) + + git_clone_wrapper("https://github.com/dw-0/kiauh", target, branch="dev") + + assert recorder.calls == [("https://github.com/dw-0/kiauh", target, True)] + assert recorder.checkouts == [("dev", target)] + + def test_skips_checkout_for_main(self, monkeypatch, tmp_path: Path) -> None: + target = tmp_path / "kiauh" + recorder = _CloneRecorder() + monkeypatch.setattr("utils.git_utils.git_cmd_clone", recorder.fake_clone) + monkeypatch.setattr("utils.git_utils.git_cmd_checkout", recorder.fake_checkout) + + git_clone_wrapper("https://github.com/dw-0/kiauh", target, branch="main") + + assert recorder.checkouts == [] + + def test_prompts_before_overwrite(self, monkeypatch, tmp_path: Path) -> None: + target = tmp_path / "kiauh" + target.mkdir() + recorder = _CloneRecorder() + removed: List[Path] = [] + + monkeypatch.setattr("utils.git_utils.git_cmd_clone", recorder.fake_clone) + monkeypatch.setattr("utils.git_utils.git_cmd_checkout", recorder.fake_checkout) + monkeypatch.setattr("utils.git_utils.shutil.rmtree", lambda p: removed.append(p)) + monkeypatch.setattr("utils.git_utils.get_confirm", lambda *a, **k: True) + + git_clone_wrapper("https://github.com/dw-0/kiauh", target, branch="dev") + + assert removed == [target] + assert recorder.calls == [("https://github.com/dw-0/kiauh", target, True)] + + def test_respects_decline_to_overwrite(self, monkeypatch, tmp_path: Path) -> None: + target = tmp_path / "kiauh" + target.mkdir() + recorder = _CloneRecorder() + + monkeypatch.setattr("utils.git_utils.git_cmd_clone", recorder.fake_clone) + monkeypatch.setattr("utils.git_utils.git_cmd_checkout", recorder.fake_checkout) + monkeypatch.setattr( + "utils.git_utils.shutil.rmtree", + lambda *a, **k: pytest.fail("should not remove"), + ) + monkeypatch.setattr("utils.git_utils.get_confirm", lambda *a, **k: False) + + git_clone_wrapper("https://github.com/dw-0/kiauh", target) + + assert recorder.calls == [] + + def test_force_overwrites_without_prompt(self, monkeypatch, tmp_path: Path) -> None: + target = tmp_path / "kiauh" + target.mkdir() + recorder = _CloneRecorder() + removed: List[Path] = [] + + monkeypatch.setattr("utils.git_utils.git_cmd_clone", recorder.fake_clone) + monkeypatch.setattr("utils.git_utils.git_cmd_checkout", recorder.fake_checkout) + monkeypatch.setattr("utils.git_utils.shutil.rmtree", lambda p: removed.append(p)) + monkeypatch.setattr( + "utils.git_utils.get_confirm", + lambda *a, **k: pytest.fail("should not prompt when forced"), + ) + + git_clone_wrapper("https://github.com/dw-0/kiauh", target, force=True) + + assert removed == [target] + assert recorder.calls == [("https://github.com/dw-0/kiauh", target, True)] diff --git a/kiauh/utils/tests/test_input_utils.py b/kiauh/utils/tests/test_input_utils.py new file mode 100644 index 0000000..4005fc2 --- /dev/null +++ b/kiauh/utils/tests/test_input_utils.py @@ -0,0 +1,158 @@ +from __future__ import annotations + +from typing import Any, List + +import pytest +from utils.input_utils import ( + format_question, + get_confirm, + get_number_input, + get_selection_input, + get_string_input, + validate_number_input, +) + + +def _input_sequence(answers: List[str]): + it = iter(answers) + + def _input(_prompt: str = "") -> str: + return next(it) + + return _input + + +class TestGetConfirm: + def test_accepts_yes(self, monkeypatch) -> None: + monkeypatch.setattr("builtins.input", _input_sequence(["y"])) + assert get_confirm("go?") is True + + def test_accepts_no(self, monkeypatch) -> None: + monkeypatch.setattr("builtins.input", _input_sequence(["n"])) + assert get_confirm("go?") is False + + def test_default_yes_on_empty(self, monkeypatch) -> None: + monkeypatch.setattr("builtins.input", _input_sequence([""])) + assert get_confirm("go?", default_choice=True) is True + + def test_default_no_on_empty(self, monkeypatch) -> None: + monkeypatch.setattr("builtins.input", _input_sequence([""])) + assert get_confirm("go?", default_choice=False) is False + + def test_handles_invalid_then_valid(self, monkeypatch) -> None: + monkeypatch.setattr("builtins.input", _input_sequence(["maybe", "yes"])) + assert get_confirm("go?") is True + + def test_go_back_returns_none(self, monkeypatch) -> None: + monkeypatch.setattr("builtins.input", _input_sequence(["b"])) + assert get_confirm("go?", allow_go_back=True) is None + + +class TestGetNumberInput: + def test_returns_valid(self, monkeypatch) -> None: + monkeypatch.setattr("builtins.input", _input_sequence(["5"])) + assert get_number_input("count?", 1, 10) == 5 + + def test_uses_default(self, monkeypatch) -> None: + monkeypatch.setattr("builtins.input", _input_sequence([""])) + assert get_number_input("count?", 1, default=3) == 3 + + def test_enforces_minimum(self, monkeypatch) -> None: + monkeypatch.setattr("builtins.input", _input_sequence(["0", "2"])) + assert get_number_input("count?", 1) == 2 + + def test_enforces_maximum(self, monkeypatch) -> None: + monkeypatch.setattr("builtins.input", _input_sequence(["11", "9"])) + assert get_number_input("count?", 1, 10) == 9 + + def test_go_back_returns_none(self, monkeypatch) -> None: + monkeypatch.setattr("builtins.input", _input_sequence(["b"])) + assert get_number_input("count?", 1, allow_go_back=True) is None + + +class TestGetStringInput: + def test_accepts_alphanumeric(self, monkeypatch) -> None: + monkeypatch.setattr("builtins.input", _input_sequence(["abc123"])) + assert get_string_input("name?") == "abc123" + + def test_rejects_empty(self, monkeypatch) -> None: + monkeypatch.setattr("builtins.input", _input_sequence(["", "value"])) + assert get_string_input("name?") == "value" + + def test_uses_default(self, monkeypatch) -> None: + monkeypatch.setattr("builtins.input", _input_sequence([""])) + assert get_string_input("name?", default="fallback") == "fallback" + + def test_validates_regex(self, monkeypatch) -> None: + monkeypatch.setattr("builtins.input", _input_sequence(["@", "#"])) + assert get_string_input("name?", regex=r"^#+$") == "#" + + def test_rejects_excluded(self, monkeypatch) -> None: + monkeypatch.setattr("builtins.input", _input_sequence(["taken", "free"])) + assert get_string_input("name?", exclude=["taken"]) == "free" + + def test_allows_special_chars(self, monkeypatch) -> None: + monkeypatch.setattr("builtins.input", _input_sequence(["a-b_c"])) + assert get_string_input("name?", allow_special_chars=True) == "a-b_c" + + def test_allows_empty_with_special_chars(self, monkeypatch) -> None: + monkeypatch.setattr("builtins.input", _input_sequence([""])) + assert ( + get_string_input("name?", allow_empty=True, allow_special_chars=True) == "" + ) + + +class TestGetSelectionInput: + def test_from_list(self, monkeypatch) -> None: + monkeypatch.setattr("builtins.input", _input_sequence(["b"])) + assert get_selection_input("pick?", ["a", "b", "c"]) == "b" + + def test_from_dict(self, monkeypatch) -> None: + monkeypatch.setattr("builtins.input", _input_sequence(["two"])) + assert get_selection_input("pick?", {"one": 1, "two": 2}) == "two" + + def test_invalid_then_valid(self, monkeypatch) -> None: + monkeypatch.setattr("builtins.input", _input_sequence(["z", "a"])) + assert get_selection_input("pick?", ["a", "b"]) == "a" + + def test_invalid_type_raises(self, monkeypatch) -> None: + monkeypatch.setattr("builtins.input", _input_sequence(["x"])) + with pytest.raises(ValueError): + get_selection_input("pick?", 123) # type: ignore[arg-type] + + +class TestFormatQuestion: + def test_includes_default(self) -> None: + assert "default=5" in format_question("count", 5) + + def test_no_default(self) -> None: + assert "count" in format_question("count") + assert "default" not in format_question("count") + + +class TestValidateNumberInput: + @pytest.mark.parametrize( + "value,min_count,max_count,expected", + [ + ("5", 1, 10, 5), + ("1", 1, 10, 1), + ("10", 1, 10, 10), + ("3", 1, None, 3), + ], + ) + def test_valid( + self, value: str, min_count: int, max_count: Any, expected: int + ) -> None: + assert validate_number_input(value, min_count, max_count) == expected + + @pytest.mark.parametrize( + "value,min_count,max_count", + [ + ("0", 1, 10), + ("11", 1, 10), + ("-1", 0, None), + ], + ) + def test_raises(self, value: str, min_count: int, max_count: Any) -> None: + with pytest.raises(ValueError): + validate_number_input(value, min_count, max_count) diff --git a/kiauh/utils/tests/test_instance_type.py b/kiauh/utils/tests/test_instance_type.py new file mode 100644 index 0000000..b1541c3 --- /dev/null +++ b/kiauh/utils/tests/test_instance_type.py @@ -0,0 +1,17 @@ +from __future__ import annotations + +from typing import TypeVar + +from components.klipper.klipper import Klipper +from components.moonraker.moonraker import Moonraker +from utils.instance_type import InstanceType + + +class TestInstanceType: + def test_is_typevar(self) -> None: + assert isinstance(InstanceType, TypeVar) + + def test_bound_classes_include_components(self) -> None: + bound = InstanceType.__constraints__ + assert Klipper in bound + assert Moonraker in bound diff --git a/kiauh/utils/tests/test_instance_utils.py b/kiauh/utils/tests/test_instance_utils.py new file mode 100644 index 0000000..b7b92f0 --- /dev/null +++ b/kiauh/utils/tests/test_instance_utils.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +from pathlib import Path +from typing import List + +import pytest +from utils.instance_utils import ( + get_instance_suffix, + get_instances, + stop_klipper_instances_interactively, +) + + +class Klipper: + def __init__(self, suffix: str): + self.suffix = suffix + + def __eq__(self, other): + return isinstance(other, Klipper) and self.suffix == other.suffix + + def __repr__(self): + return f"Klipper({self.suffix!r})" + + +class TestGetInstances: + def test_returns_empty_when_no_services(self, tmp_path: Path, monkeypatch) -> None: + monkeypatch.setattr("utils.instance_utils.SYSTEMD", tmp_path) + assert get_instances(Klipper) == [] + + def test_raises_when_not_a_class(self, tmp_path: Path, monkeypatch) -> None: + monkeypatch.setattr("utils.instance_utils.SYSTEMD", tmp_path) + with pytest.raises(ValueError): + get_instances("not-a-class") # type: ignore[arg-type] + + def test_finds_and_sorts_instances(self, tmp_path: Path, monkeypatch) -> None: + monkeypatch.setattr("utils.instance_utils.SYSTEMD", tmp_path) + + (tmp_path / "klipper.service").write_text("") + (tmp_path / "klipper-1.service").write_text("") + (tmp_path / "klipper-10.service").write_text("") + (tmp_path / "klipper-a.service").write_text("") + + instances = get_instances(Klipper) + assert [i.suffix for i in instances] == ["", "1", "10", "a"] + + def test_excludes_blacklisted_suffixes(self, tmp_path: Path, monkeypatch) -> None: + monkeypatch.setattr("utils.instance_utils.SYSTEMD", tmp_path) + + (tmp_path / "klipper.service").write_text("") + (tmp_path / "klipper-mcu.service").write_text("") + + instances = get_instances(Klipper) + assert [i.suffix for i in instances] == [""] + + +class TestGetInstanceSuffix: + @pytest.mark.parametrize( + "name,service,expected", + [ + ("klipper", "klipper.service", ""), + ("klipper", "klipper-1.service", "1"), + ("klipper", "klipper-10.service", "10"), + ("moonraker", "moonraker-foo.service", "foo"), + ], + ) + def test_suffix(self, name: str, service: str, expected: str) -> None: + assert get_instance_suffix(name, Path(service)) == expected + + +class TestStopKlipperInstancesInteractively: + def test_empty_returns_true(self, monkeypatch) -> None: + monkeypatch.setattr( + "utils.instance_utils.get_confirm", + lambda *a, **k: pytest.fail("no prompt when no instances"), + ) + assert stop_klipper_instances_interactively([]) is True + + def test_stops_on_confirm(self, monkeypatch) -> None: + stopped: List[Klipper] = [] + instance = Klipper("") + + monkeypatch.setattr("utils.instance_utils.get_confirm", lambda *a, **k: True) + monkeypatch.setattr( + "utils.instance_utils.InstanceManager.stop_all", + lambda instances: stopped.extend(instances), + ) + + assert stop_klipper_instances_interactively([instance], "update") is True + assert [i.suffix for i in stopped] == [""] + + def test_aborts_on_decline(self, monkeypatch) -> None: + monkeypatch.setattr("utils.instance_utils.get_confirm", lambda *a, **k: False) + monkeypatch.setattr( + "utils.instance_utils.InstanceManager.stop_all", + lambda *a, **k: pytest.fail("should not stop when declined"), + ) + + assert stop_klipper_instances_interactively([Klipper("")]) is False diff --git a/kiauh/utils/tests/test_sys_utils.py b/kiauh/utils/tests/test_sys_utils.py new file mode 100644 index 0000000..76f305f --- /dev/null +++ b/kiauh/utils/tests/test_sys_utils.py @@ -0,0 +1,686 @@ +from __future__ import annotations + +import builtins +from io import StringIO +from pathlib import Path +from subprocess import CalledProcessError +from typing import Any, List + +import pytest +from utils.sys_utils import ( + VenvCreationFailedException, + check_package_install, + check_python_version, + cmd_sysctl_manage, + cmd_sysctl_service, + create_env_file, + create_python_venv, + create_service_file, + download_file, + download_progress, + get_distro_info, + get_ipv4_addr, + get_service_file_path, + get_system_timezone, + get_upgradable_packages, + install_python_packages, + install_python_requirements, + install_system_packages, + kill, + log_process, + parse_packages_from_file, + remove_system_service, + set_nginx_permissions, + unit_file_exists, + update_python_pip, + update_system_package_lists, + upgrade_system_packages, +) + + +class TestKill: + def test_exits_with_error(self, monkeypatch) -> None: + exited: List[int] = [] + + def fake_exit(code: int) -> None: + exited.append(code) + raise SystemExit(code) + + monkeypatch.setattr("utils.sys_utils.sys.exit", fake_exit) + with pytest.raises(SystemExit): + kill("boom") + assert exited == [1] + + +class TestCheckPythonVersion: + def test_old(self, monkeypatch) -> None: + info = type("VI", (), {"major": 3, "minor": 7})() + monkeypatch.setattr("utils.sys_utils.sys.version_info", info) + assert check_python_version(3, 8) is False + + def test_current(self, monkeypatch) -> None: + info = type("VI", (), {"major": 3, "minor": 9})() + monkeypatch.setattr("utils.sys_utils.sys.version_info", info) + assert check_python_version(3, 8) is True + + +class TestParsePackagesFromFile: + def test_reads_pkglist(self, tmp_path: Path) -> None: + script = tmp_path / "install.sh" + script.write_text('PKGLIST="git curl wget"\nOTHER="x"\n') + assert parse_packages_from_file(script) == ["git", "curl", "wget"] + + +class TestCreatePythonVenv: + def test_creates_when_missing(self, monkeypatch) -> None: + runs: List[List[str]] = [] + + def fake_run(cmd: List[str], **kwargs: Any) -> Any: + runs.append(cmd) + return None + + monkeypatch.setattr("utils.sys_utils.run", fake_run) + target = Path("/tmp/venv") + + assert create_python_venv(target) is True + assert runs == [ + ["virtualenv", "-p", "/usr/bin/python3", "/tmp/venv"], + ] + + def test_declines_recreate(self, monkeypatch) -> None: + target = Path("/tmp/venv") + monkeypatch.setattr("utils.sys_utils.Path.exists", lambda self: self == target) + monkeypatch.setattr("utils.sys_utils.get_confirm", lambda *a, **k: False) + monkeypatch.setattr( + "utils.sys_utils.run", + lambda *a, **k: pytest.fail("should not recreate when declined"), + ) + + assert create_python_venv(target) is False + + def test_confirms_recreate(self, monkeypatch) -> None: + target = Path("/tmp/venv") + state = {"exists": True} + removed: List[Path] = [] + runs: List[List[str]] = [] + + def fake_exists(self: Path) -> bool: + return state["exists"] and str(self) == str(target) + + def fake_run(cmd: List[str], **kwargs: Any) -> Any: + runs.append(cmd) + return None + + def fake_rmtree(p: Path) -> None: + removed.append(p) + state["exists"] = False + + monkeypatch.setattr("utils.sys_utils.Path.exists", fake_exists) + monkeypatch.setattr("utils.sys_utils.get_confirm", lambda *a, **k: True) + monkeypatch.setattr("utils.sys_utils.shutil.rmtree", fake_rmtree) + monkeypatch.setattr("utils.sys_utils.run", fake_run) + + assert ( + create_python_venv(target, allow_access_to_system_site_packages=True) + is True + ) + assert removed == [target] + assert runs == [ + [ + "virtualenv", + "-p", + "/usr/bin/python3", + "/tmp/venv", + "--system-site-packages", + ], + ] + + def test_force_recreate(self, monkeypatch) -> None: + target = Path("/tmp/venv") + state = {"exists": True} + removed: List[Path] = [] + + def fake_exists(self: Path) -> bool: + return state["exists"] and str(self) == str(target) + + def fake_rmtree(p: Path) -> None: + removed.append(p) + state["exists"] = False + + monkeypatch.setattr("utils.sys_utils.Path.exists", fake_exists) + monkeypatch.setattr( + "utils.sys_utils.get_confirm", + lambda *a, **k: pytest.fail("should not prompt when forced"), + ) + monkeypatch.setattr("utils.sys_utils.shutil.rmtree", fake_rmtree) + monkeypatch.setattr("utils.sys_utils.run", lambda *a, **k: None) + + assert create_python_venv(target, force=True) is True + assert removed == [target] + + def test_creation_failure(self, monkeypatch) -> None: + monkeypatch.setattr( + "utils.sys_utils.run", + lambda *a, **k: (_ for _ in ()).throw(CalledProcessError(1, "virtualenv")), + ) + assert create_python_venv(Path("/tmp/venv")) is False + + def test_remove_failure(self, monkeypatch) -> None: + target = Path("/tmp/venv") + + def fake_exists(self: Path) -> bool: + return str(self) == str(target) + + monkeypatch.setattr("utils.sys_utils.Path.exists", fake_exists) + monkeypatch.setattr("utils.sys_utils.get_confirm", lambda *a, **k: True) + monkeypatch.setattr( + "utils.sys_utils.shutil.rmtree", + lambda *a, **k: (_ for _ in ()).throw(OSError("locked")), + ) + + assert create_python_venv(target) is False + + +class TestUpdatePythonPip: + def test_raises_when_pip_missing(self, monkeypatch) -> None: + monkeypatch.setattr("utils.sys_utils.check_file_exist", lambda *a, **k: False) + with pytest.raises(FileNotFoundError): + update_python_pip(Path("/tmp/venv")) + + def test_runs_upgrade(self, monkeypatch) -> None: + runs: List[List[str]] = [] + + def fake_run(cmd: List[str], **kwargs: Any) -> Any: + runs.append(cmd) + return type("R", (), {"returncode": 0, "stderr": ""})() + + monkeypatch.setattr("utils.sys_utils.check_file_exist", lambda *a, **k: True) + monkeypatch.setattr("utils.sys_utils.run", fake_run) + + update_python_pip(Path("/tmp/venv")) + assert runs == [["/tmp/venv/bin/pip", "install", "-U", "pip"]] + + def test_logs_stderr(self, monkeypatch, capsys) -> None: + def fake_run(cmd: List[str], **kwargs: Any) -> Any: + return type("R", (), {"returncode": 0, "stderr": "some warning"})() + + monkeypatch.setattr("utils.sys_utils.check_file_exist", lambda *a, **k: True) + monkeypatch.setattr("utils.sys_utils.run", fake_run) + + update_python_pip(Path("/tmp/venv")) + + +class TestInstallPythonRequirements: + def test_success(self, monkeypatch) -> None: + def fake_run(cmd: List[str], **kwargs: Any) -> Any: + return type("R", (), {"returncode": 0, "stderr": ""})() + + monkeypatch.setattr("utils.sys_utils.run", fake_run) + install_python_requirements(Path("/tmp/venv"), Path("/tmp/req.txt")) + + def test_failure(self, monkeypatch) -> None: + def fake_run(cmd: List[str], **kwargs: Any) -> Any: + return type("R", (), {"returncode": 1, "stderr": "nope"})() + + monkeypatch.setattr("utils.sys_utils.run", fake_run) + with pytest.raises(VenvCreationFailedException): + install_python_requirements(Path("/tmp/venv"), Path("/tmp/req.txt")) + + +class TestInstallPythonPackages: + def test_success(self, monkeypatch) -> None: + captured: List[List[str]] = [] + + def fake_run(cmd: List[str], **kwargs: Any) -> Any: + captured.append(cmd) + return type("R", (), {"returncode": 0, "stderr": ""})() + + monkeypatch.setattr("utils.sys_utils.run", fake_run) + install_python_packages(Path("/tmp/venv"), ["a", "b"]) + assert captured == [["/tmp/venv/bin/pip", "install", "a", "b"]] + + +class TestUpdateSystemPackageLists: + def test_skips_when_recent(self, monkeypatch) -> None: + monkeypatch.setattr("utils.sys_utils.time.time", lambda: 1000) + monkeypatch.setattr("utils.sys_utils.os.path.getmtime", lambda p: 900) + monkeypatch.setattr( + "utils.sys_utils.run", + lambda *a, **k: pytest.fail("should not update when recent"), + ) + + update_system_package_lists(silent=True) + + def test_runs_when_old(self, monkeypatch) -> None: + monkeypatch.setattr("utils.sys_utils.time.time", lambda: 100_000) + monkeypatch.setattr("utils.sys_utils.os.path.getmtime", lambda p: 0) + + runs: List[List[str]] = [] + + def fake_run(cmd: List[str], **kwargs: Any) -> Any: + runs.append(cmd) + return type("R", (), {"returncode": 0, "stderr": ""})() + + monkeypatch.setattr("utils.sys_utils.run", fake_run) + update_system_package_lists(silent=True) + + assert runs == [["sudo", "apt-get", "update"]] + + def test_allows_releaseinfo_change(self, monkeypatch) -> None: + monkeypatch.setattr("utils.sys_utils.time.time", lambda: 100_000) + monkeypatch.setattr("utils.sys_utils.os.path.getmtime", lambda p: 0) + + runs: List[List[str]] = [] + + def fake_run(cmd: List[str], **kwargs: Any) -> Any: + runs.append(cmd) + return type("R", (), {"returncode": 0, "stderr": ""})() + + monkeypatch.setattr("utils.sys_utils.run", fake_run) + update_system_package_lists(silent=True, rls_info_change=True) + + assert runs == [["sudo", "apt-get", "update", "--allow-releaseinfo-change"]] + + +class TestGetUpgradablePackages: + def test_parses_apt_list(self, monkeypatch) -> None: + output = ( + "package1/stable 1.0 [upgradable from: 0.9]\n" + "package2/testing 2.0 [upgradable from: 1.0]\n" + ) + monkeypatch.setattr("utils.sys_utils.check_output", lambda *a, **k: output) + assert get_upgradable_packages() == ["package1", "package2"] + + +class TestCheckPackageInstall: + def test_detects_installed(self, monkeypatch) -> None: + def fake_run(cmd: List[str], **kwargs: Any) -> Any: + return type("R", (), {"stdout": "install ok installed"})() + + monkeypatch.setattr("utils.sys_utils.run", fake_run) + assert check_package_install({"git"}) == [] + + def test_detects_missing(self, monkeypatch) -> None: + def fake_run(cmd: List[str], **kwargs: Any) -> Any: + return type("R", (), {"stdout": "not-installed"})() + + monkeypatch.setattr("utils.sys_utils.run", fake_run) + assert check_package_install({"missing"}) == ["missing"] + + +class TestInstallSystemPackages: + def test_runs_apt(self, monkeypatch) -> None: + runs: List[List[str]] = [] + + def fake_run(cmd: List[str], **kwargs: Any) -> Any: + runs.append(cmd) + return None + + monkeypatch.setattr("utils.sys_utils.run", fake_run) + install_system_packages(["git", "curl"]) + assert runs == [["sudo", "apt-get", "install", "-y", "git", "curl"]] + + +class TestUpgradeSystemPackages: + def test_runs_apt(self, monkeypatch) -> None: + runs: List[List[str]] = [] + + def fake_run(cmd: List[str], **kwargs: Any) -> Any: + runs.append(cmd) + return None + + monkeypatch.setattr("utils.sys_utils.run", fake_run) + upgrade_system_packages(["git"]) + assert runs == [["sudo", "apt-get", "upgrade", "-y", "git"]] + + +class TestGetIpv4Addr: + def test_returns_socket_address(self, monkeypatch) -> None: + class FakeSocket: + def __init__(self, *args): + pass + + def settimeout(self, value: float) -> None: + pass + + def connect(self, addr: tuple) -> None: + pass + + def getsockname(self) -> tuple: + return ("192.168.1.50", 54321) + + def close(self) -> None: + pass + + monkeypatch.setattr("utils.sys_utils.socket.socket", FakeSocket) + assert get_ipv4_addr() == "192.168.1.50" + + def test_falls_back_to_loopback(self, monkeypatch) -> None: + class FakeSocket: + def __init__(self, *args): + pass + + def settimeout(self, value: float) -> None: + pass + + def connect(self, addr: tuple) -> None: + raise OSError("no route") + + def close(self) -> None: + pass + + monkeypatch.setattr("utils.sys_utils.socket.socket", FakeSocket) + assert get_ipv4_addr() == "127.0.0.1" + + +class TestDownloadFile: + def test_without_progress(self, monkeypatch) -> None: + calls: List[tuple] = [] + + def fake_urlretrieve(url: str, target: Path, reporthook=None) -> None: + calls.append((url, str(target), reporthook)) + + monkeypatch.setattr( + "utils.sys_utils.urllib.request.urlretrieve", fake_urlretrieve + ) + download_file("http://x/file", Path("/target"), show_progress=False) + assert calls == [("http://x/file", "/target", None)] + + def test_with_progress(self, monkeypatch) -> None: + calls: List[tuple] = [] + + def fake_urlretrieve(url: str, target: Path, reporthook=None) -> None: + calls.append((url, str(target), reporthook)) + if reporthook: + reporthook(1, 1024, 2048) + + monkeypatch.setattr( + "utils.sys_utils.urllib.request.urlretrieve", fake_urlretrieve + ) + download_file("http://x/file", Path("/target"), show_progress=True) + assert calls[0][2] is not None + + +class TestDownloadProgress: + def test_writes_to_stdout(self, capsys) -> None: + download_progress(1, 1024, 2048) + captured = capsys.readouterr() + assert "Downloading:" in captured.out + assert "50.00%" in captured.out + + +class TestSetNginxPermissions: + def test_no_change_when_executable(self, monkeypatch) -> None: + monkeypatch.setattr( + "utils.sys_utils.run", + lambda cmd, **kwargs: type("R", (), {"stdout": "drwxr-xr-x"})() + if "ls" in cmd + else pytest.fail("should not chmod"), + ) + set_nginx_permissions() + + def test_adds_execute(self, monkeypatch) -> None: + commands: List[Any] = [] + + def fake_run(cmd, **kwargs): + commands.append(cmd) + if isinstance(cmd, str): + return type("R", (), {"stdout": "drwxr------"})() + return None + + monkeypatch.setattr("utils.sys_utils.run", fake_run) + monkeypatch.setattr("utils.sys_utils.Path.home", lambda: Path("/home/user")) + set_nginx_permissions() + assert ["chmod", "og+x", Path("/home/user")] in commands + + +class TestCmdSysctlService: + def test_runs_systemctl(self, monkeypatch) -> None: + runs: List[List[str]] = [] + + def fake_run(cmd: List[str], **kwargs: Any) -> Any: + runs.append(cmd) + return None + + monkeypatch.setattr("utils.sys_utils.run", fake_run) + cmd_sysctl_service("klipper", "restart") + assert runs == [["sudo", "systemctl", "restart", "klipper"]] + + +class TestCmdSysctlManage: + def test_runs_systemctl(self, monkeypatch) -> None: + runs: List[List[str]] = [] + + def fake_run(cmd: List[str], **kwargs: Any) -> Any: + runs.append(cmd) + return None + + monkeypatch.setattr("utils.sys_utils.run", fake_run) + cmd_sysctl_manage("daemon-reload") + assert runs == [["sudo", "systemctl", "daemon-reload"]] + + +class TestUnitFileExists: + def test_finds_matching_service(self, tmp_path: Path, monkeypatch) -> None: + monkeypatch.setattr("utils.sys_utils.SYSTEMD", tmp_path) + (tmp_path / "klipper.service").write_text("") + (tmp_path / "klipper-1.service").write_text("") + (tmp_path / "moonraker.service").write_text("") + + assert unit_file_exists("klipper", "service") is True + assert unit_file_exists("moonraker", "service") is True + assert unit_file_exists("klipper", "timer") is False + + def test_respects_exclude(self, tmp_path: Path, monkeypatch) -> None: + monkeypatch.setattr("utils.sys_utils.SYSTEMD", tmp_path) + (tmp_path / "klipper-mcu.service").write_text("") + + assert unit_file_exists("klipper", "service", exclude=["mcu"]) is False + + +class TestLogProcess: + def test_prints_stdout(self, monkeypatch, capsys) -> None: + lines = iter(["line1\n", "line2\n", ""]) + poll_results = iter([None, 0]) + + class FakeStdout: + def fileno(self) -> int: + return 7 + + def readline(self) -> str: + return next(lines) + + class FakeProcess: + stdout = FakeStdout() + + def poll(self): + return next(poll_results) + + monkeypatch.setattr( + "utils.sys_utils.select.select", lambda r, w, x: ([7], [], []) + ) + log_process(FakeProcess()) # type: ignore[arg-type] + + captured = capsys.readouterr() + assert "line1" in captured.out + assert "line2" in captured.out + + +class TestCreateServiceFile: + def test_writes_via_tee(self, monkeypatch) -> None: + runs: List[tuple] = [] + + def fake_run(cmd: List[str], **kwargs: Any) -> Any: + runs.append((cmd, kwargs.get("input"))) + return None + + monkeypatch.setattr("utils.sys_utils.SYSTEMD", Path("/etc/systemd/system")) + monkeypatch.setattr("utils.sys_utils.run", fake_run) + create_service_file("klipper.service", "[Unit]\n") + + assert runs[0][0] == [ + "sudo", + "tee", + Path("/etc/systemd/system/klipper.service"), + ] + assert runs[0][1] == b"[Unit]\n" + + +class TestCreateEnvFile: + def test_writes_file(self, tmp_path: Path) -> None: + path = tmp_path / "env" + create_env_file(path, "KEY=value\n") + assert path.read_text() == "KEY=value\n" + + +class TestRemoveSystemService: + def test_rejects_bad_name(self) -> None: + with pytest.raises(ValueError): + remove_system_service("klipper") + + def test_skips_missing_file(self, tmp_path: Path, monkeypatch) -> None: + monkeypatch.setattr("utils.sys_utils.SYSTEMD", tmp_path) + remove_system_service("klipper.service") + + def test_full_removal(self, monkeypatch) -> None: + sysd = Path("/fake/systemd") + service_file = sysd / "klipper.service" + monkeypatch.setattr("utils.sys_utils.SYSTEMD", sysd) + + monkeypatch.setattr( + "utils.sys_utils.Path.exists", + lambda self: str(self) == str(service_file), + ) + monkeypatch.setattr( + "utils.sys_utils.Path.is_file", + lambda self: str(self) == str(service_file), + ) + + service_calls: List[tuple] = [] + manage_calls: List[str] = [] + removed: List[Path] = [] + + monkeypatch.setattr( + "utils.sys_utils.cmd_sysctl_service", + lambda name, action: service_calls.append((name, action)), + ) + monkeypatch.setattr( + "utils.sys_utils.cmd_sysctl_manage", + lambda action: manage_calls.append(action), + ) + monkeypatch.setattr( + "utils.sys_utils.remove_with_sudo", lambda p: removed.append(p) + ) + + remove_system_service("klipper.service") + + assert service_calls == [ + ("klipper.service", "stop"), + ("klipper.service", "disable"), + ] + assert removed == [service_file] + assert manage_calls == ["daemon-reload", "reset-failed"] + + +class _FakeInstanceType: + pass + + +_FakeInstanceType.__name__ = "Klipper" + + +class TestGetServiceFilePath: + def test_builds_path(self, monkeypatch) -> None: + monkeypatch.setattr("utils.sys_utils.SYSTEMD", Path("/etc/systemd/system")) + assert get_service_file_path(_FakeInstanceType, "") == Path( + "/etc/systemd/system/klipper.service" + ) + assert get_service_file_path(_FakeInstanceType, "1") == Path( + "/etc/systemd/system/klipper-1.service" + ) + + +class TestGetDistroInfo: + def test_parses_os_release(self, monkeypatch) -> None: + content = """ +ID="ubuntu" +ID_LIKE="debian" +VERSION_ID="22.04" +""" + monkeypatch.setattr( + "utils.sys_utils.check_output", lambda *a, **k: content.encode() + ) + assert get_distro_info() == ("ubuntu", "22.04") + + def test_remaps_raspbian(self, monkeypatch) -> None: + content = """ +ID="raspbian" +ID_LIKE="debian" +VERSION_ID="11" +""" + monkeypatch.setattr( + "utils.sys_utils.check_output", lambda *a, **k: content.encode() + ) + assert get_distro_info() == ("debian", "11") + + def test_raises_on_missing_id(self, monkeypatch) -> None: + monkeypatch.setattr( + "utils.sys_utils.check_output", lambda *a, **k: b'VERSION_ID="1"\n' + ) + with pytest.raises(ValueError): + get_distro_info() + + +class TestGetSystemTimezone: + def test_from_etc_timezone(self, monkeypatch) -> None: + def fake_open(path: str, mode: str = "r", *args, **kwargs): + if path == "/etc/timezone": + return StringIO("Europe/Berlin\n") + return builtins.open(path, mode, *args, **kwargs) + + monkeypatch.setattr("builtins.open", fake_open) + assert get_system_timezone() == "Europe/Berlin" + + def test_fallback_to_timedatectl(self, monkeypatch) -> None: + def fake_open(path: str, mode: str = "r", *args, **kwargs): + raise FileNotFoundError(path) + + monkeypatch.setattr("builtins.open", fake_open) + + def fake_run(cmd: List[str], **kwargs: Any) -> Any: + class Result: + stdout = "Timezone=America/New_York\n" + + return Result() + + monkeypatch.setattr("utils.sys_utils.run", fake_run) + assert get_system_timezone() == "America/New_York" + + def test_fallback_to_readlink(self, monkeypatch) -> None: + def fake_open(path: str, mode: str = "r", *args, **kwargs): + raise FileNotFoundError(path) + + monkeypatch.setattr("builtins.open", fake_open) + + def fake_run(cmd: List[str], **kwargs: Any) -> Any: + if cmd[:2] == ["timedatectl", "show"]: + raise CalledProcessError(1, "timedatectl") + + class Result: + stdout = "/usr/share/zoneinfo/Asia/Tokyo\n" + + return Result() + + monkeypatch.setattr("utils.sys_utils.run", fake_run) + assert get_system_timezone() == "Asia/Tokyo" + + def test_defaults_to_utc(self, monkeypatch) -> None: + def fake_open(path: str, mode: str = "r", *args, **kwargs): + raise FileNotFoundError(path) + + monkeypatch.setattr("builtins.open", fake_open) + monkeypatch.setattr( + "utils.sys_utils.run", + lambda *a, **k: (_ for _ in ()).throw(CalledProcessError(1, "timedatectl")), + ) + assert get_system_timezone() == "UTC" diff --git a/pyproject.toml b/pyproject.toml index 9c5b287..bd63166 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,5 +33,5 @@ warn_unreachable = true [tool.pytest.ini_options] minversion = "8.2.1" -testpaths = ["kiauh/core/simple_config_parser/tests"] +testpaths = ["kiauh/core/simple_config_parser/tests", "kiauh/utils/tests"] pythonpath = ["kiauh"]