diff --git a/tests/conftest.py b/tests/conftest.py index cce6a03f..f211f70d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,6 +8,7 @@ import requests import yaml +from metrics import MetricsClient from server import FileActivityService @@ -132,6 +133,13 @@ def fact_config(request, monitored_dir, logs_dir): config_file.close() +@pytest.fixture +def metrics(fact_config): + """Client for taking metrics snapshots from the FACT endpoint.""" + config, _ = fact_config + return MetricsClient(config['endpoint']['address']) + + @pytest.fixture def test_container(request, docker_client, monitored_dir, ignored_dir): """ diff --git a/tests/metrics.py b/tests/metrics.py new file mode 100644 index 00000000..1e9807b2 --- /dev/null +++ b/tests/metrics.py @@ -0,0 +1,112 @@ +import re + +import requests + + +class MetricsSnapshot: + """ + A parsed snapshot of Prometheus/OpenMetrics metrics. + + Supports querying by metric name and labels: + + ss = metrics.snapshot() + assert ss.get("rate_limiter_events", label="Dropped") == 5 + assert ss.get("bpf_events", label="Added") > 0 + + Metric names are matched without the "stackrox_fact_" prefix and + "_total" counter suffix, so "rate_limiter_events" matches + "stackrox_fact_rate_limiter_events_total". + """ + + _PREFIX = "stackrox_fact_" + _TOTAL_SUFFIX = "_total" + _LINE_RE = re.compile( + r'^(?P\S+?)(?:\{(?P[^}]*)\})?\s+(?P\S+)$' + ) + _LABEL_RE = re.compile(r'(\w+)="([^"]*)"') + + def __init__(self, text): + self._entries = [] + for line in text.splitlines(): + if line.startswith('#') or not line.strip(): + continue + + m = self._LINE_RE.match(line) + if not m: + continue + + name, raw, labels = m.group('name', 'value', 'labels') + + value = float(raw) if '.' in raw else int(raw) + labels = dict(self._LABEL_RE.findall(labels or '')) + + self._entries.append((name, labels, value)) + + @classmethod + def _normalize(cls, name): + if name.startswith(cls._PREFIX): + name = name[len(cls._PREFIX):] + if name.endswith(cls._TOTAL_SUFFIX): + name = name[:-len(cls._TOTAL_SUFFIX)] + return name + + def get(self, metric, **labels): + """ + Get the value of a metric, optionally filtered by labels. + + Args: + metric: Metric name, with or without the "stackrox_fact_" + prefix and "_total" suffix. + **labels: Label key=value pairs to match. + + Returns: + The metric value as int or float. + + Raises: + KeyError: If no matching metric is found. + ValueError: If multiple metrics match. + """ + target = self._normalize(metric) + matches = [] + for name, entry_labels, value in self._entries: + if self._normalize(name) != target: + continue + if all(entry_labels.get(k) == v for k, v in labels.items()): + matches.append(value) + + if not matches: + label_desc = ', '.join(f'{k}="{v}"' for k, v in labels.items()) + key = f'{metric}{{{label_desc}}}' if label_desc else metric + available = '\n '.join( + f'{n} {ls} = {v}' for n, ls, v in self._entries + ) + raise KeyError( + f'metric {key!r} not found. Available:\n {available}' + ) + if len(matches) > 1: + raise ValueError( + f'{metric} matched {len(matches)} entries; use labels to ' + f'narrow the result' + ) + return matches[0] + + def get_all(self, metric, **labels): + """Like get(), but returns a list of all matching values.""" + target = self._normalize(metric) + return [ + value for name, entry_labels, value in self._entries + if self._normalize(name) == target + and all(entry_labels.get(k) == v for k, v in labels.items()) + ] + + +class MetricsClient: + """Fetches metrics snapshots from a FACT endpoint.""" + + def __init__(self, address): + self._url = f'http://{address}/metrics' + + def snapshot(self, timeout=30): + resp = requests.get(self._url, timeout=timeout) + resp.raise_for_status() + return MetricsSnapshot(resp.text) diff --git a/tests/test_rate_limit.py b/tests/test_rate_limit.py index 080606f7..f566fb31 100644 --- a/tests/test_rate_limit.py +++ b/tests/test_rate_limit.py @@ -3,11 +3,11 @@ from time import sleep import pytest -import requests import yaml from event import Event, EventType, Process + @pytest.fixture def rate_limited_config(fact, fact_config, monitored_dir): """ @@ -23,11 +23,10 @@ def rate_limited_config(fact, fact_config, monitored_dir): sleep(0.1) return config, config_file -def test_rate_limit_drops_events(rate_limited_config, monitored_dir, server): +def test_rate_limit_drops_events(rate_limited_config, monitored_dir, server, metrics): """ Test that the rate limiter drops events when the rate limit is exceeded. """ - config, _ = rate_limited_config num_files = 100 start_time = time.time() @@ -51,19 +50,8 @@ def test_rate_limit_drops_events(rate_limited_config, monitored_dir, server): assert received_count < num_files, \ f'Expected rate limiting to drop some events, but received all {received_count}' - metrics_response = requests.get(f'http://{config["endpoint"]["address"]}/metrics') - assert metrics_response.status_code == 200 - - metrics_text = metrics_response.text - assert 'rate_limiter_events' in metrics_text, 'rate_limiter_events metric not found' - - dropped_count = 0 - for line in metrics_text.split('\n'): - if 'rate_limiter_events' in line and 'label="Dropped"' in line: - parts = line.split() - if len(parts) >= 2: - dropped_count = int(parts[1]) - break + ss = metrics.snapshot() + dropped_count = ss.get("rate_limiter_events", label="Dropped") assert dropped_count > 0, 'Expected rate limiter to report dropped events in metrics' @@ -71,11 +59,10 @@ def test_rate_limit_drops_events(rate_limited_config, monitored_dir, server): assert total_accounted == num_files, 'Expected rate limiter to see all events' -def test_rate_limit_unlimited(monitored_dir, server, fact_config): +def test_rate_limit_unlimited(monitored_dir, server, metrics): """ Test that the default config (rate_limit=0) allows all events through. """ - config, _ = fact_config num_files = 20 events = [] process = Process.from_proc() @@ -90,18 +77,8 @@ def test_rate_limit_unlimited(monitored_dir, server, fact_config): server.wait_events(events) - metrics_response = requests.get(f'http://{config["endpoint"]["address"]}/metrics') - assert metrics_response.status_code == 200 - - metrics_text = metrics_response.text - - dropped_count = 0 - for line in metrics_text.split('\n'): - if 'rate_limiter_events' in line and 'label="Dropped"' in line: - parts = line.split() - if len(parts) >= 2: - dropped_count = int(parts[1]) - break + ss = metrics.snapshot() + dropped_count = ss.get("rate_limiter_events", label="Dropped") assert dropped_count == 0, \ f'Expected no dropped events with unlimited rate limiting, but got {dropped_count}'