From 069d122b55a4ba535d02e816f000387b29c505db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bence=20H=C3=A9zs=C5=91?= Date: Thu, 7 May 2026 12:55:20 +0200 Subject: [PATCH] Add unit tests and update PR checks workflow --- .github/workflows/pr-checks.yml | 32 ++++++++-- requirements-dev.txt | 3 + tests/test_utils_and_main.py | 72 +++++++++++++++++++++ tests/test_validate.py | 108 ++++++++++++++++++++++++++++++++ utils/validate.py | 2 +- 5 files changed, 210 insertions(+), 7 deletions(-) create mode 100644 requirements-dev.txt create mode 100644 tests/test_utils_and_main.py create mode 100644 tests/test_validate.py diff --git a/.github/workflows/pr-checks.yml b/.github/workflows/pr-checks.yml index 62e594b..79e7ed0 100644 --- a/.github/workflows/pr-checks.yml +++ b/.github/workflows/pr-checks.yml @@ -23,13 +23,13 @@ jobs: python-version: ${{ env.PYTHON_VERSION }} cache: pip - - name: Install Ruff + - name: Install dependencies run: | python -m pip install --upgrade pip - pip install ruff + pip install -r requirements-dev.txt - name: Ruff - run: ruff check main.py core utils + run: ruff check main.py core utils tests check-types: if: github.event.pull_request.draft == false @@ -43,10 +43,30 @@ jobs: python-version: ${{ env.PYTHON_VERSION }} cache: pip - - name: Install Black + - name: Install dependencies run: | python -m pip install --upgrade pip - pip install black + pip install -r requirements-dev.txt - name: Black - run: black --check main.py core utils + run: black --check main.py core utils tests + + check-tests: + if: github.event.pull_request.draft == false + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v5 + + - name: Set up Python + uses: actions/setup-python@v6 + with: + python-version: ${{ env.PYTHON_VERSION }} + cache: pip + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements-dev.txt + + - name: Unit tests + run: python -m unittest discover -s tests diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 0000000..39ee357 --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,3 @@ +-r requirements.txt +black +ruff diff --git a/tests/test_utils_and_main.py b/tests/test_utils_and_main.py new file mode 100644 index 0000000..d07bb59 --- /dev/null +++ b/tests/test_utils_and_main.py @@ -0,0 +1,72 @@ +import json +import tempfile +import unittest +from pathlib import Path +from unittest.mock import patch + +import main +from utils.utils import load_config + + +class LoadConfigTests(unittest.TestCase): + def test_load_config_returns_parsed_json(self): + with tempfile.TemporaryDirectory() as tmp_dir: + config_path = Path(tmp_dir) / "config.json" + expected = { + "cloudServiceProvider": 2, + "assessmentType": 1, + "providerDetails": {"region": "eu-central-1"}, + } + config_path.write_text(json.dumps(expected), encoding="utf-8") + + self.assertEqual(load_config(str(config_path)), expected) + + def test_load_config_returns_none_for_missing_file(self): + with patch("utils.utils.console.print") as mock_print: + result = load_config("/tmp/does-not-exist-config.json") + + self.assertIsNone(result) + mock_print.assert_called_once() + + def test_load_config_returns_none_for_invalid_json(self): + with tempfile.TemporaryDirectory() as tmp_dir: + config_path = Path(tmp_dir) / "config.json" + config_path.write_text("{invalid json", encoding="utf-8") + + with patch("utils.utils.console.print") as mock_print: + result = load_config(str(config_path)) + + self.assertIsNone(result) + mock_print.assert_called_once() + + +class RunAssessmentPreValidationTests(unittest.TestCase): + def test_invalid_config_stops_before_pipeline_side_effects(self): + config = { + "assessmentType": 99, + "cloudServiceProvider": 2, + "providerDetails": {}, + } + + with ( + patch("main.validate_config", side_effect=ValueError("bad config")) as mock_validate, + patch("main.resolve_mode") as mock_resolve_mode, + patch("main.create_directory") as mock_create_directory, + patch("main.verify_credentials") as mock_verify_credentials, + patch("main.print_step") as mock_print_step, + patch("main.console.print"), + ): + result = main.run_assessment(config, "aws") + + self.assertIsNone(result) + mock_validate.assert_called_once_with(config) + mock_print_step.assert_called_once_with( + "Configuration validation failed.", status="error", logs="bad config" + ) + mock_resolve_mode.assert_not_called() + mock_create_directory.assert_not_called() + mock_verify_credentials.assert_not_called() + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_validate.py b/tests/test_validate.py new file mode 100644 index 0000000..a88a962 --- /dev/null +++ b/tests/test_validate.py @@ -0,0 +1,108 @@ +import unittest + +from utils.validate import validate_config, validate_region + + +def build_aws_config(): + return { + "name": "Example Assessment", + "assessmentType": 1, + "cloudServiceProvider": 2, + "exitStrategy": 1, + "providerDetails": { + "accessKey": "AKIA_TEST", + "secretKey": "SECRET_TEST", + "region": "eu-central-1", + }, + } + + +def build_azure_config(): + return { + "name": "Example Assessment", + "assessmentType": 2, + "cloudServiceProvider": 1, + "exitStrategy": 3, + "providerDetails": { + "tenantId": "tenant-id", + "clientId": "client-id", + "clientSecret": "client-secret", + "subscriptionId": "subscription-id", + "resourceGroupName": "resource-group", + }, + } + + +class ValidateRegionTests(unittest.TestCase): + def test_accepts_known_region(self): + self.assertIsNone(validate_region("eu-central-1")) + + def test_rejects_unknown_region(self): + with self.assertRaisesRegex(ValueError, "Invalid AWS region"): + validate_region("moon-central-1") + + +class ValidateConfigTests(unittest.TestCase): + def test_accepts_valid_aws_config(self): + self.assertTrue(validate_config(build_aws_config())) + + def test_accepts_valid_azure_service_principal_config(self): + self.assertTrue(validate_config(build_azure_config())) + + def test_accepts_valid_azure_cli_config(self): + config = build_azure_config() + config["providerDetails"] = { + "credential": object(), + "tenantId": "tenant-id", + "subscriptionId": "subscription-id", + "resourceGroupName": "resource-group", + } + + self.assertTrue(validate_config(config)) + + def test_rejects_azure_config_without_client_credentials(self): + config = build_azure_config() + del config["providerDetails"]["clientId"] + del config["providerDetails"]["clientSecret"] + + with self.assertRaisesRegex(ValueError, "Missing required fields in providerDetails"): + validate_config(config) + + def test_rejects_invalid_assessment_type(self): + config = build_aws_config() + config["assessmentType"] = 9 + + with self.assertRaisesRegex(ValueError, "Invalid assessmentType"): + validate_config(config) + + def test_rejects_non_integer_top_level_fields(self): + config = build_aws_config() + config["assessmentType"] = "basic" + + with self.assertRaisesRegex(ValueError, "must be integers"): + validate_config(config) + + def test_rejects_invalid_name_characters(self): + config = build_aws_config() + config["name"] = "Bad/Name" + + with self.assertRaisesRegex(ValueError, "Assessment name contains invalid characters"): + validate_config(config) + + def test_rejects_too_long_name(self): + config = build_aws_config() + config["name"] = "a" * 51 + + with self.assertRaisesRegex(ValueError, "cannot exceed 50 characters"): + validate_config(config) + + def test_rejects_aws_config_with_invalid_region(self): + config = build_aws_config() + config["providerDetails"]["region"] = "invalid-region" + + with self.assertRaisesRegex(ValueError, "Invalid AWS region"): + validate_config(config) + + +if __name__ == "__main__": + unittest.main() diff --git a/utils/validate.py b/utils/validate.py index 1590bf5..fc72afa 100644 --- a/utils/validate.py +++ b/utils/validate.py @@ -39,7 +39,7 @@ def validate_config(config: Dict[str, Any]) -> bool: provider_details = config.get("providerDetails", {}) if cloud_service_provider == 1: # Azure # Skip validation of clientId and clientSecret if using CLI credentials - if isinstance(provider_details.get("credential"), object): # Assuming it's DefaultAzureCredential + if provider_details.get("credential") is not None: required_fields = ["tenantId", "subscriptionId", "resourceGroupName"] else: required_fields = REQUIRED_FIELDS_AZURE