diff --git a/sentry_streams_k8s/sentry_streams_k8s/__init__.py b/sentry_streams_k8s/sentry_streams_k8s/__init__.py index af0ba071..66cf2748 100644 --- a/sentry_streams_k8s/sentry_streams_k8s/__init__.py +++ b/sentry_streams_k8s/sentry_streams_k8s/__init__.py @@ -1,4 +1,9 @@ -from sentry_streams_k8s.merge import TypeMismatchError +from sentry_streams_k8s.merge import ScalarOverwriteError, TypeMismatchError from sentry_streams_k8s.pipeline_step import PipelineStep, PipelineStepContext -__all__ = ["PipelineStep", "PipelineStepContext", "TypeMismatchError"] +__all__ = [ + "PipelineStep", + "PipelineStepContext", + "ScalarOverwriteError", + "TypeMismatchError", +] diff --git a/sentry_streams_k8s/sentry_streams_k8s/merge.py b/sentry_streams_k8s/sentry_streams_k8s/merge.py index aa45519d..32aa5fbc 100644 --- a/sentry_streams_k8s/sentry_streams_k8s/merge.py +++ b/sentry_streams_k8s/sentry_streams_k8s/merge.py @@ -10,7 +10,18 @@ class TypeMismatchError(TypeError): pass -def deepmerge(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]: +class ScalarOverwriteError(ValueError): + """Raised when attempting to overwrite a scalar value during deepmerge.""" + + pass + + +def deepmerge( + base: dict[str, Any], + override: dict[str, Any], + fail_on_scalar_overwrite: bool = False, + _path: list[str] | None = None, +) -> dict[str, Any]: """ Deep merge two dictionaries with specific semantics for Kubernetes manifests. @@ -20,12 +31,12 @@ def deepmerge(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]: - Lists: concatenate (append override elements to base) - Type mismatches (e.g., dict + list, dict + str): raises TypeMismatchError - Returns: - A new dictionary with merged values (base and override are not mutated) Raises: TypeMismatchError: When attempting to merge incompatible types (e.g., trying to merge a dict with a list, or a list with a string) + ScalarOverwriteError: When fail_on_scalar_overwrite is True and attempting + to overwrite a scalar value with a different scalar value Examples: >>> base = {"a": 1, "b": {"c": 2}} @@ -44,10 +55,17 @@ def deepmerge(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]: Traceback (most recent call last): ... TypeMismatchError: Cannot merge key 'key': base type is dict but override type is str + """ + if _path is None: + _path = [] + result = copy.deepcopy(base) for key, override_value in override.items(): + current_path = _path + [key] + path_str = ".".join(current_path) + if key not in result: result[key] = copy.deepcopy(override_value) else: @@ -55,7 +73,12 @@ def deepmerge(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]: # Both base and override have this key if isinstance(base_value, dict) and isinstance(override_value, dict): - result[key] = deepmerge(base_value, override_value) + result[key] = deepmerge( + base_value, + override_value, + fail_on_scalar_overwrite=fail_on_scalar_overwrite, + _path=current_path, + ) elif isinstance(base_value, list) and isinstance(override_value, list): result[key] = base_value + copy.deepcopy(override_value) elif type(base_value) is not type(override_value): @@ -63,6 +86,11 @@ def deepmerge(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]: f"Cannot merge key '{key}': base type is {type(base_value)} but override type is {type(override_value)}" ) else: + # Scalar to scalar replacement + if fail_on_scalar_overwrite and base_value != override_value: + raise ScalarOverwriteError( + f"Cannot overwrite scalar at '{path_str}': would change {base_value!r} to {override_value!r}" + ) result[key] = copy.deepcopy(override_value) return result diff --git a/sentry_streams_k8s/sentry_streams_k8s/pipeline_step.py b/sentry_streams_k8s/sentry_streams_k8s/pipeline_step.py index e795425a..fc1facdd 100644 --- a/sentry_streams_k8s/sentry_streams_k8s/pipeline_step.py +++ b/sentry_streams_k8s/sentry_streams_k8s/pipeline_step.py @@ -8,7 +8,7 @@ import yaml from libsentrykube.ext import ExternalMacro -from sentry_streams_k8s.merge import deepmerge +from sentry_streams_k8s.merge import ScalarOverwriteError, deepmerge from sentry_streams_k8s.validation import validate_pipeline_config @@ -141,6 +141,7 @@ def parse_context(context: dict[str, Any]) -> PipelineStepContext: "cpu_per_process": context["cpu_per_process"], "memory_per_process": context["memory_per_process"], "segment_id": context["segment_id"], + "replicas": context.get("replicas", 1), "emergency_patch": emergency_patch_parsed, } @@ -158,6 +159,7 @@ class PipelineStepContext(TypedDict): cpu_per_process: int memory_per_process: int segment_id: int + replicas: int emergency_patch: NotRequired[dict[str, Any]] @@ -201,6 +203,7 @@ class PipelineStep(ExternalMacro): "segment_id": 0, "cpu_per_process": 1000, "memory_per_process": 512, + "replicas": 3, } ) }} @@ -253,6 +256,7 @@ def run(self, context: dict[str, Any]) -> dict[str, Any]: pipeline_name = ctx["pipeline_name"] segment_id = ctx["segment_id"] service_name = ctx["service_name"] + replicas = ctx["replicas"] emergency_patch = ctx.get("emergency_patch", {}) # Create deployment @@ -268,7 +272,6 @@ def run(self, context: dict[str, Any]) -> dict[str, Any]: ) base_deployment = load_base_template("deployment") - deployment = deepmerge(base_deployment, deployment_template) labels = { "pipeline-app": make_k8s_name(pipeline_module), @@ -282,6 +285,7 @@ def run(self, context: dict[str, Any]) -> dict[str, Any]: "labels": labels, }, "spec": { + "replicas": replicas, "selector": { "matchLabels": labels, }, @@ -304,6 +308,22 @@ def run(self, context: dict[str, Any]) -> dict[str, Any]: }, } + # Check for scalar conflicts between user template and pipeline additions + # This ensures pipeline additions don't override user-provided values + # while still allowing both to override base template defaults + try: + # Perform a test merge to detect conflicts + deepmerge(deployment_template, pipeline_additions, fail_on_scalar_overwrite=True) + except ScalarOverwriteError as e: + raise ScalarOverwriteError( + f"{e}\n\n" + f"This field is automatically set by PipelineStep and conflicts with your deployment_template. " + f"Note: Lists and dicts can be provided (they get merged), but scalar values cannot be overridden." + ) from e + + # No conflicts found, proceed with merging + # Both user template and pipeline additions can override base template + deployment = deepmerge(base_deployment, deployment_template) deployment = deepmerge(deployment, pipeline_additions) if emergency_patch: diff --git a/sentry_streams_k8s/tests/test_merge.py b/sentry_streams_k8s/tests/test_merge.py index 01fe855a..f8e1db45 100644 --- a/sentry_streams_k8s/tests/test_merge.py +++ b/sentry_streams_k8s/tests/test_merge.py @@ -384,3 +384,117 @@ def test_deepmerge_kubernetes_deployment_example() -> None: }, }, } + + +def test_fail_on_scalar_overwrite_catches_conflicts() -> None: + """Test that fail_on_scalar_overwrite raises error when overwriting different scalars.""" + from sentry_streams_k8s.merge import ScalarOverwriteError + + base = {"replicas": 1, "name": "old-name"} + override = {"replicas": 5, "extra": "value"} + + # Should raise when trying to overwrite replicas with different value + with pytest.raises(ScalarOverwriteError, match="replicas.*1.*5"): + deepmerge(base, override, fail_on_scalar_overwrite=True) + + +def test_fail_on_scalar_overwrite_allows_same_values() -> None: + """Test that fail_on_scalar_overwrite allows overwriting with same value.""" + base = {"replicas": 1, "name": "my-name"} + override = {"replicas": 1, "extra": "value"} + + # Should not raise when overwriting with same value + result = deepmerge(base, override, fail_on_scalar_overwrite=True) + assert result == {"replicas": 1, "name": "my-name", "extra": "value"} + + +def test_fail_on_scalar_overwrite_allows_dicts_and_lists() -> None: + """Test that fail_on_scalar_overwrite still allows dict and list merging.""" + base = { + "labels": {"app": "my-app", "version": "1.0"}, + "volumes": [{"name": "vol1"}], + "replicas": 1, + } + override = { + "labels": {"env": "prod"}, # Dict merge - should work + "volumes": [{"name": "vol2"}], # List append - should work + "replicas": 1, # Same value - should work + } + + result = deepmerge(base, override, fail_on_scalar_overwrite=True) + assert result == { + "labels": {"app": "my-app", "version": "1.0", "env": "prod"}, + "volumes": [{"name": "vol1"}, {"name": "vol2"}], + "replicas": 1, + } + + +def test_fail_on_scalar_overwrite_nested_path() -> None: + """Test that fail_on_scalar_overwrite provides correct path in error message.""" + from sentry_streams_k8s.merge import ScalarOverwriteError + + base = { + "metadata": { + "labels": { + "pipeline": "old-value", + } + } + } + override = { + "metadata": { + "labels": { + "pipeline": "new-value", + } + } + } + + with pytest.raises(ScalarOverwriteError, match="metadata.labels.pipeline"): + deepmerge(base, override, fail_on_scalar_overwrite=True) + + +def test_fail_on_scalar_overwrite_multiple_levels() -> None: + """Test that fail_on_scalar_overwrite works correctly with deeply nested structures.""" + from sentry_streams_k8s.merge import ScalarOverwriteError + + base = { + "spec": { + "template": { + "spec": { + "replicas": 1, + "containers": [{"name": "base-container"}], + } + } + } + } + override = { + "spec": { + "template": { + "spec": { + "replicas": 3, # Conflict here + "containers": [{"name": "override-container"}], # This is fine (list) + } + } + } + } + + with pytest.raises(ScalarOverwriteError, match="spec.template.spec.replicas"): + deepmerge(base, override, fail_on_scalar_overwrite=True) + + +def test_fail_on_scalar_overwrite_disabled_by_default() -> None: + """Test that scalar overwriting works normally when flag is not set.""" + base = {"replicas": 1, "name": "old-name"} + override = {"replicas": 5, "name": "new-name"} + + # Should work fine without the flag + result = deepmerge(base, override) + assert result == {"replicas": 5, "name": "new-name"} + + +def test_fail_on_scalar_overwrite_with_new_keys() -> None: + """Test that fail_on_scalar_overwrite allows adding new keys.""" + base = {"replicas": 1} + override = {"replicas": 1, "new_key": "new_value", "another": 42} + + result = deepmerge(base, override, fail_on_scalar_overwrite=True) + assert result == {"replicas": 1, "new_key": "new_value", "another": 42} diff --git a/sentry_streams_k8s/tests/test_pipeline_step.py b/sentry_streams_k8s/tests/test_pipeline_step.py index 6cc33399..efa30795 100644 --- a/sentry_streams_k8s/tests/test_pipeline_step.py +++ b/sentry_streams_k8s/tests/test_pipeline_step.py @@ -4,6 +4,7 @@ import yaml from jsonschema import ValidationError +from sentry_streams_k8s.merge import ScalarOverwriteError from sentry_streams_k8s.pipeline_step import ( PipelineStep, build_container, @@ -53,6 +54,7 @@ def test_parse_context() -> None: "cpu_per_process": 1000, "memory_per_process": 512, "segment_id": 0, + "replicas": 2, } parsed_context = parse_context(context) @@ -81,6 +83,7 @@ def test_parse_context() -> None: assert parsed_context["cpu_per_process"] == 1000 assert parsed_context["memory_per_process"] == 512 assert parsed_context["segment_id"] == 0 + assert parsed_context["replicas"] == 2 context["deployment_template"] = yaml.dump(context["deployment_template"]) context["container_template"] = yaml.dump(context["container_template"]) @@ -104,6 +107,7 @@ def test_parse_context() -> None: ] }, } + assert parsed_context["replicas"] == 2 def test_build_container() -> None: @@ -197,6 +201,7 @@ def test_validate_context_valid() -> None: "cpu_per_process": 1000, "memory_per_process": 512, "segment_id": 0, + "replicas": 1, } # Should not raise any exception @@ -310,6 +315,7 @@ def test_run_generates_complete_manifests() -> None: "cpu_per_process": 1000, "memory_per_process": 512, "segment_id": 0, + "replicas": 1, } pipeline_step = PipelineStep() @@ -418,6 +424,7 @@ def test_run_with_base_templates() -> None: "cpu_per_process": 1000, "memory_per_process": 512, "segment_id": 0, + "replicas": 1, } pipeline_step = PipelineStep() @@ -445,14 +452,12 @@ def test_run_with_base_templates() -> None: def test_user_template_overrides_base() -> None: - """Test that user template values take precedence over base template values.""" + """Test that user template values can override base template for non-controlled fields.""" context: dict[str, Any] = { "service_name": "my-service", "pipeline_name": "profiles", "deployment_template": { - # Override base replicas "spec": { - "replicas": 5, "template": { "spec": { # Override base terminationGracePeriodSeconds @@ -485,6 +490,7 @@ def test_user_template_overrides_base() -> None: "cpu_per_process": 1000, "memory_per_process": 512, "segment_id": 0, + "replicas": 1, # Use base template value to avoid conflict } pipeline_step = PipelineStep() @@ -492,8 +498,9 @@ def test_user_template_overrides_base() -> None: deployment = result["deployment"] - # Check that user overrides took effect - assert deployment["spec"]["replicas"] == 5 # User override, not base 1 + # Check that replicas parameter took effect + assert deployment["spec"]["replicas"] == 1 + # Check that user template overrides for non-controlled fields worked assert ( deployment["spec"]["template"]["spec"]["terminationGracePeriodSeconds"] == 60 ) # User override, not base 30 @@ -575,6 +582,46 @@ def test_user_volumes_and_containers_preserved() -> None: assert "pipeline-config" in volume_mount_names +def test_template_conflict_scalar_overwrite() -> None: + """Test that PipelineStep detects and prevents scalar field conflicts in templates.""" + # Test conflict with replicas field + context: dict[str, Any] = { + "service_name": "my-service", + "pipeline_name": "profiles", + "deployment_template": { + "spec": { + "replicas": 5, # User tries to set replicas - conflicts with parameter + } + }, + "container_template": {}, + "pipeline_config": { + "env": {}, + "pipeline": { + "segments": [ + { + "steps_config": { + "myinput": { + "starts_segment": True, + "bootstrap_servers": ["127.0.0.1:9092"], + } + } + } + ] + }, + }, + "pipeline_module": "sbc.profiles", + "image_name": "my-image:latest", + "cpu_per_process": 1000, + "memory_per_process": 512, + "segment_id": 0, + "replicas": 3, # Different from template value + } + + pipeline_step = PipelineStep() + with pytest.raises(ScalarOverwriteError, match="spec.replicas"): + pipeline_step.run(context) + + def test_emergency_patch_overrides_final_deployment() -> None: """Test that emergency_patch overrides all other layers including pipeline additions.""" context: dict[str, Any] = {