Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 96 additions & 21 deletions sentry_streams_k8s/sentry_streams_k8s/pipeline_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,45 @@ def make_k8s_name(name: str) -> str:
return name


def get_multiprocess_config(pipeline_config: dict[str, Any]) -> tuple[int | None, list[int]]:
"""
Extract multiprocessing configuration from pipeline config.

Iterates through all segments in the pipeline configuration and looks for
parallelism.multi_process.processes configuration in any step.

Examples:
>>> config = {"pipeline": {"segments": [{"steps_config": {"step1": {"parallelism": {"multi_process": {"processes": 4}}}}}]}}
>>> get_multiprocess_config(config)
(4, [0])
"""
segments_with_parallelism: list[int] = []
process_count: int | None = None

segments = pipeline_config["pipeline"]["segments"]

for segment_idx, segment in enumerate(segments):
steps_config = segment.get("steps_config", {})

for step_config in steps_config.values():
parallelism = step_config.get("parallelism")
if not parallelism or not isinstance(parallelism, dict):
continue

multi_process = parallelism.get("multi_process")
if not multi_process:
continue

processes = multi_process.get("processes")
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing type check causes crash on non-dict multi_process

Medium Severity

The code checks isinstance(parallelism, dict) before calling .get() on it, but there's no corresponding check for multi_process. If multi_process is a truthy non-dict value (e.g., multi_process: true or multi_process: 1 in YAML), the condition if not multi_process passes, and then multi_process.get("processes") crashes with AttributeError: 'bool' object has no attribute 'get'.

Fix in Cursor Fix in Web

if processes is not None:
segments_with_parallelism.append(segment_idx)
if process_count is None:
process_count = processes
break # Found parallelism in this segment, move to next segment

return process_count, segments_with_parallelism


def build_container(
container_template: dict[str, Any],
pipeline_name: str,
Expand All @@ -49,6 +88,7 @@ def build_container(
cpu_per_process: int,
memory_per_process: int,
segment_id: int,
process_count: int | None = None,
) -> dict[str, Any]:
"""
Build a complete container specification for the pipeline step.
Expand All @@ -59,10 +99,34 @@ def build_container(
some standard parameters like securityContext
3. building the streaming pipeline specific parameters and merging them
onto the result of step 2.

"""
base_container = load_base_template("container")
container = deepmerge(base_container, container_template)

# CPU and memory are provided per process, so we need to multiply them
# by the number of processes to get the total resources.
cpu_total = cpu_per_process * (process_count or 1)
memory_total = memory_per_process * (process_count or 1)

volume_mounts: list[dict[str, Any]] = [
{
"name": "pipeline-config",
"mountPath": "/etc/pipeline-config",
"readOnly": True,
}
]

    # Shared memory volume is needed to allow communication between processes
    # via shared memory. Only needed when in multiprocess mode.
if process_count is not None and process_count > 1:
volume_mounts.append(
{
"name": "dshm",
"mountPath": "/dev/shm",
}
)

pipeline_additions = {
"name": "pipeline-consumer",
"image": image_name,
Expand All @@ -80,20 +144,14 @@ def build_container(
],
"resources": {
"requests": {
"cpu": f"{cpu_per_process}m",
"memory": f"{memory_per_process}Mi",
"cpu": f"{cpu_total}m",
"memory": f"{memory_total}Mi",
},
"limits": {
"memory": f"{memory_per_process}Mi",
"memory": f"{memory_total}Mi",
},
},
"volumeMounts": [
{
"name": "pipeline-config",
"mountPath": "/etc/pipeline-config",
"readOnly": True,
}
],
"volumeMounts": volume_mounts,
}

return deepmerge(container, pipeline_additions)
Expand Down Expand Up @@ -259,7 +317,13 @@ def run(self, context: dict[str, Any]) -> dict[str, Any]:
replicas = ctx["replicas"]
emergency_patch = ctx.get("emergency_patch", {})

# Create deployment
process_count, segments_with_parallelism = get_multiprocess_config(pipeline_config)
if len(segments_with_parallelism) > 1:
raise ValueError(
f"Multi-processing configuration can only be specified in one segment. "
f"Found parallelism configuration in {len(segments_with_parallelism)} segments "
f"(segment indices: {segments_with_parallelism})."
)

container = build_container(
container_template,
Expand All @@ -269,6 +333,7 @@ def run(self, context: dict[str, Any]) -> dict[str, Any]:
cpu_per_process,
memory_per_process,
segment_id,
process_count,
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Multiprocessing resources applied to wrong segments

High Severity

The process_count from get_multiprocess_config() is applied unconditionally to the current segment_id being deployed, without checking if that segment actually has parallelism configured. If segment 0 has no multiprocessing but segment 1 has processes: 4, deploying segment 0 would incorrectly receive 4x resources and the /dev/shm volume. The code retrieves segments_with_parallelism but never checks whether segment_id is in that list before applying resource scaling.

Additional Locations (1)

Fix in Cursor Fix in Web

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a test to cover this case? I don't think this is a valid bug but what is expected to happen where multiple segments are passed in and the first one is not the one that requires parallelism?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was about to add the test but then I figured out that the comment is wrong.

This is an example of the pipeline with parallelism

pipeline:
  segments:
    - steps_config:
        myinput:
          starts_segment: True
          bootstrap_servers: ["127.0.0.1:9092"]
          parallelism: 1
        parser:
          # Parser is the beginning of the segment.
          # All Map steps in the same segment are chained
          # together in the same process.
          #
          # When adding a step to the segment that is not
          # a map we need to create a new segment as these
          # cannot be ran in a multi process step.
          starts_segment: True
          parallelism:
            multi_process:
              processes: 4
              batch_size: 1000
              batch_time: 0.2
        mysink:
          starts_segment: True
          bootstrap_servers: ["127.0.0.1:9092"]

We pass `--segment-id=0` to run it, but technically speaking, the parallel segment is the second one.
This is because segments do not have sound semantics. We have segments in the segments list, but we can also start segments implicitly inside a single step.

We need to figure this out, though we cannot assert that the segment with the parallelism config would be the one passed to the consumer as this is never the case.

)

base_deployment = load_base_template("deployment")
Expand All @@ -279,6 +344,25 @@ def run(self, context: dict[str, Any]) -> dict[str, Any]:
}
configmap_name = make_k8s_name(f"{service_name}-pipeline-{pipeline_name}")

volumes: list[dict[str, Any]] = [
{
"name": "pipeline-config",
"configMap": {
"name": configmap_name,
},
}
]

        # Shared memory volume is needed to allow communication between processes
        # via shared memory. Only needed when in multiprocess mode.
if process_count is not None and process_count > 1:
volumes.append(
{
"name": "dshm",
"emptyDir": {"medium": "Memory"},
}
)

pipeline_additions = {
"metadata": {
"name": make_k8s_name(f"{service_name}-pipeline-{pipeline_name}-{segment_id}"),
Expand All @@ -295,14 +379,7 @@ def run(self, context: dict[str, Any]) -> dict[str, Any]:
},
"spec": {
"containers": [container],
"volumes": [
{
"name": "pipeline-config",
"configMap": {
"name": configmap_name,
},
}
],
"volumes": volumes,
},
},
},
Expand All @@ -329,7 +406,6 @@ def run(self, context: dict[str, Any]) -> dict[str, Any]:
if emergency_patch:
deployment = deepmerge(deployment, emergency_patch)

# Create configmap
configmap = {
"apiVersion": "v1",
"kind": "ConfigMap",
Expand All @@ -342,7 +418,6 @@ def run(self, context: dict[str, Any]) -> dict[str, Any]:
},
}

# Add namespace if present in deployment template
if "namespace" in deployment.get("metadata", {}):
metadata = cast(dict[str, Any], configmap["metadata"])
metadata["namespace"] = deployment["metadata"]["namespace"]
Expand Down
Loading
Loading