Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 26 additions & 6 deletions .github/workflows/pr-checks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@ jobs:
python-version: ${{ env.PYTHON_VERSION }}
cache: pip

- name: Install Ruff
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install ruff
pip install -r requirements-dev.txt

- name: Ruff
run: ruff check main.py core utils
run: ruff check main.py core utils tests

check-types:
if: github.event.pull_request.draft == false
Expand All @@ -43,10 +43,30 @@ jobs:
python-version: ${{ env.PYTHON_VERSION }}
cache: pip

- name: Install Black
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install black
pip install -r requirements-dev.txt

- name: Black
run: black --check main.py core utils
run: black --check main.py core utils tests

check-tests:
if: github.event.pull_request.draft == false
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v5

- name: Set up Python
uses: actions/setup-python@v6
with:
python-version: ${{ env.PYTHON_VERSION }}
cache: pip

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements-dev.txt

- name: Unit tests
run: python -m unittest discover -s tests
3 changes: 3 additions & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
-r requirements.txt
black
ruff
72 changes: 72 additions & 0 deletions tests/test_utils_and_main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import json
import tempfile
import unittest
from pathlib import Path
from unittest.mock import patch

import main
from utils.utils import load_config


class LoadConfigTests(unittest.TestCase):
def test_load_config_returns_parsed_json(self):
with tempfile.TemporaryDirectory() as tmp_dir:
config_path = Path(tmp_dir) / "config.json"
expected = {
"cloudServiceProvider": 2,
"assessmentType": 1,
"providerDetails": {"region": "eu-central-1"},
}
config_path.write_text(json.dumps(expected), encoding="utf-8")

self.assertEqual(load_config(str(config_path)), expected)

def test_load_config_returns_none_for_missing_file(self):
with patch("utils.utils.console.print") as mock_print:
result = load_config("/tmp/does-not-exist-config.json")

self.assertIsNone(result)
mock_print.assert_called_once()

def test_load_config_returns_none_for_invalid_json(self):
with tempfile.TemporaryDirectory() as tmp_dir:
config_path = Path(tmp_dir) / "config.json"
config_path.write_text("{invalid json", encoding="utf-8")

with patch("utils.utils.console.print") as mock_print:
result = load_config(str(config_path))

self.assertIsNone(result)
mock_print.assert_called_once()


class RunAssessmentPreValidationTests(unittest.TestCase):
def test_invalid_config_stops_before_pipeline_side_effects(self):
config = {
"assessmentType": 99,
"cloudServiceProvider": 2,
"providerDetails": {},
}

with (
patch("main.validate_config", side_effect=ValueError("bad config")) as mock_validate,
patch("main.resolve_mode") as mock_resolve_mode,
patch("main.create_directory") as mock_create_directory,
patch("main.verify_credentials") as mock_verify_credentials,
patch("main.print_step") as mock_print_step,
patch("main.console.print"),
):
result = main.run_assessment(config, "aws")

self.assertIsNone(result)
mock_validate.assert_called_once_with(config)
mock_print_step.assert_called_once_with(
"Configuration validation failed.", status="error", logs="bad config"
)
mock_resolve_mode.assert_not_called()
mock_create_directory.assert_not_called()
mock_verify_credentials.assert_not_called()


if __name__ == "__main__":
unittest.main()
108 changes: 108 additions & 0 deletions tests/test_validate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import unittest

from utils.validate import validate_config, validate_region


def build_aws_config():
return {
"name": "Example Assessment",
"assessmentType": 1,
"cloudServiceProvider": 2,
"exitStrategy": 1,
"providerDetails": {
"accessKey": "AKIA_TEST",
"secretKey": "SECRET_TEST",
"region": "eu-central-1",
},
}


def build_azure_config():
return {
"name": "Example Assessment",
"assessmentType": 2,
"cloudServiceProvider": 1,
"exitStrategy": 3,
"providerDetails": {
"tenantId": "tenant-id",
"clientId": "client-id",
"clientSecret": "client-secret",
"subscriptionId": "subscription-id",
"resourceGroupName": "resource-group",
},
}


class ValidateRegionTests(unittest.TestCase):
def test_accepts_known_region(self):
self.assertIsNone(validate_region("eu-central-1"))

def test_rejects_unknown_region(self):
with self.assertRaisesRegex(ValueError, "Invalid AWS region"):
validate_region("moon-central-1")


class ValidateConfigTests(unittest.TestCase):
def test_accepts_valid_aws_config(self):
self.assertTrue(validate_config(build_aws_config()))

def test_accepts_valid_azure_service_principal_config(self):
self.assertTrue(validate_config(build_azure_config()))

def test_accepts_valid_azure_cli_config(self):
config = build_azure_config()
config["providerDetails"] = {
"credential": object(),
"tenantId": "tenant-id",
"subscriptionId": "subscription-id",
"resourceGroupName": "resource-group",
}

self.assertTrue(validate_config(config))

def test_rejects_azure_config_without_client_credentials(self):
config = build_azure_config()
del config["providerDetails"]["clientId"]
del config["providerDetails"]["clientSecret"]

with self.assertRaisesRegex(ValueError, "Missing required fields in providerDetails"):
validate_config(config)

def test_rejects_invalid_assessment_type(self):
config = build_aws_config()
config["assessmentType"] = 9

with self.assertRaisesRegex(ValueError, "Invalid assessmentType"):
validate_config(config)

def test_rejects_non_integer_top_level_fields(self):
config = build_aws_config()
config["assessmentType"] = "basic"

with self.assertRaisesRegex(ValueError, "must be integers"):
validate_config(config)

def test_rejects_invalid_name_characters(self):
config = build_aws_config()
config["name"] = "Bad/Name"

with self.assertRaisesRegex(ValueError, "Assessment name contains invalid characters"):
validate_config(config)

def test_rejects_too_long_name(self):
config = build_aws_config()
config["name"] = "a" * 51

with self.assertRaisesRegex(ValueError, "cannot exceed 50 characters"):
validate_config(config)

def test_rejects_aws_config_with_invalid_region(self):
config = build_aws_config()
config["providerDetails"]["region"] = "invalid-region"

with self.assertRaisesRegex(ValueError, "Invalid AWS region"):
validate_config(config)


if __name__ == "__main__":
unittest.main()
2 changes: 1 addition & 1 deletion utils/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def validate_config(config: Dict[str, Any]) -> bool:
provider_details = config.get("providerDetails", {})
if cloud_service_provider == 1: # Azure
# Skip validation of clientId and clientSecret if using CLI credentials
if isinstance(provider_details.get("credential"), object): # Assuming it's DefaultAzureCredential
if provider_details.get("credential") is not None:
required_fields = ["tenantId", "subscriptionId", "resourceGroupName"]
else:
required_fields = REQUIRED_FIELDS_AZURE
Expand Down
Loading