diff --git a/awsiot/_iot_metrics.py b/awsiot/_iot_metrics.py new file mode 100644 index 00000000..3e426abb --- /dev/null +++ b/awsiot/_iot_metrics.py @@ -0,0 +1,69 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0. + +""" +Private IoT SDK metrics module. + +Provides SDK-level metadata (version info) to pass to the CRT layer. +The CRT handles all feature detection (certificate source, TLS settings, etc.) +and embeds the combined metrics in the MQTT CONNECT packet username field. + +""" + +from awscrt.aws_iot_metrics import AWSIoTMetrics, IoTMetricsMetadata + +_SDK_LIBRARY_NAME = "IoTDeviceSDK/Python" + +# The current version of the IoT SDK metrics format. +# This must match the version expected by CRT layer. +_IOT_SDK_METRICS_VERSION = 1 + + +def _get_sdk_version(): + """ + Return the installed ``awsiotsdk`` package version string. + + Falls back to ``"dev"`` if the package metadata is unavailable (e.g. when + running from a source checkout without installing). + + Returns: + str: A version string like ``1.32.0`` or ``"dev"``. + """ + try: + import importlib.metadata + return importlib.metadata.version("awsiotsdk") + except Exception: + return "dev" + + +def _build_sdk_metrics(): + """ + Build the SDK-level :class:`~awscrt.aws_iot_metrics.AWSIoTMetrics` payload + that is passed down to the CRT layer. + + The returned object carries SDK identity and the metrics format version + via two metadata entries: + + - ``IoTSDKVersion``: the installed ``awsiotsdk`` package version, used + to identify the SDK release on the server side. + - ``IoTSDKMetricsVersion``: the metrics format version this SDK supports. + The CRT only merges SDK-supplied features when this value matches the + version it expects, so bumping :data:`_IOT_SDK_METRICS_VERSION` should + be done in lockstep with CRT changes. + + The CRT layer is responsible for detecting connection-level features + (protocol version, certificate source, TLS settings, proxy type, etc.) + and appending them to the metadata before embedding the result in the + MQTT CONNECT packet username field. + + Returns: + AWSIoTMetrics: A populated metrics object ready to attach to an + MQTT5 client or MQTT3 connection configuration. + """ + return AWSIoTMetrics( + library_name=_SDK_LIBRARY_NAME, + metadata_entries=[ + IoTMetricsMetadata(key="IoTSDKVersion", value=_get_sdk_version()), + IoTMetricsMetadata(key="IoTSDKMetricsVersion", value=str(_IOT_SDK_METRICS_VERSION)), + ] + ) diff --git a/awsiot/mqtt5_client_builder.py b/awsiot/mqtt5_client_builder.py index ae9b9750..e31c857a 100644 --- a/awsiot/mqtt5_client_builder.py +++ b/awsiot/mqtt5_client_builder.py @@ -170,8 +170,8 @@ **cipher_pref** (:class:`awscrt.io.TlsCipherPref`): Cipher preference to use for TLS connection. Default is `TlsCipherPref.DEFAULT`. - **enable_metrics_collection** (`bool`): Whether to send the SDK version number in the CONNECT packet. - Default is True. + **disable_metrics** (`bool`): Disable IoT SDK metrics in the CONNECT packet username field. + Defaults to False (metrics enabled). """ @@ -184,6 +184,8 @@ import awscrt.mqtt5 import urllib.parse +from awsiot._iot_metrics import _build_sdk_metrics + DEFAULT_WEBSOCKET_MQTT_PORT = 443 DEFAULT_DIRECT_MQTT_PORT = 8883 @@ -210,35 +212,6 @@ def _get(kwargs, name, default=None): return val -_metrics_str = None - - -def _get_metrics_str(current_username=""): - global _metrics_str - - username_has_query = False - if current_username.find("?") != -1: - username_has_query = True - - if _metrics_str is None: - try: - import importlib.metadata - try: - version = importlib.metadata.version("awsiotsdk") - _metrics_str = "SDK=PythonV2&Version={}".format(version) - except importlib.metadata.PackageNotFoundError: - _metrics_str = "SDK=PythonV2&Version=dev" - except BaseException: - _metrics_str = "" - - if not _metrics_str == "": - if username_has_query: - return "&" + _metrics_str - else: - return "?" + _metrics_str - else: - return "" - def _builder( tls_ctx_options, @@ -251,8 +224,6 @@ def _builder( assert isinstance(cipher_pref, awscrt.io.TlsCipherPref) username = _get(kwargs, 'username', '') - if _get(kwargs, 'enable_metrics_collection', True): - username += _get_metrics_str(username) client_options = _get(kwargs, 'client_options') if client_options is None: @@ -364,6 +335,11 @@ def _builder( tls_ctx = awscrt.io.ClientTlsContext(tls_ctx_options) client_options.tls_ctx = tls_ctx + + # Set SDK metrics for the CRT layer to embed in the CONNECT packet username + disable_metrics = _get(kwargs, 'disable_metrics', False) + client_options.disable_metrics = disable_metrics + client_options.metrics = None if disable_metrics else _build_sdk_metrics() client = awscrt.mqtt5.Client(client_options=client_options) return client diff --git a/awsiot/mqtt_connection_builder.py b/awsiot/mqtt_connection_builder.py index 75144563..d09143a6 100644 --- a/awsiot/mqtt_connection_builder.py +++ b/awsiot/mqtt_connection_builder.py @@ -113,8 +113,8 @@ **cipher_pref** (:class:`awscrt.io.TlsCipherPref`): Cipher preference to use for TLS connection. Default is `TlsCipherPref.DEFAULT`. - **enable_metrics_collection** (`bool`): Whether to send the SDK version number in the CONNECT packet. - Default is True. + **disable_metrics** (`bool`): Disable IoT SDK metrics in the CONNECT packet username field. + Default is False (metrics enabled). **http_proxy_options** (:class: 'awscrt.http.HttpProxyOptions'): HTTP proxy options to use """ @@ -127,6 +127,8 @@ import awscrt.mqtt import urllib.parse +from awsiot._iot_metrics import _build_sdk_metrics + def _check_required_kwargs(**kwargs): for required in ['endpoint', 'client_id']: @@ -148,35 +150,6 @@ def _get(kwargs, name, default=None): return val -_metrics_str = None - - -def _get_metrics_str(current_username=""): - global _metrics_str - - username_has_query = False - if current_username.find("?") != -1: - username_has_query = True - - if _metrics_str is None: - try: - import importlib.metadata - try: - version = importlib.metadata.version("awsiotsdk") - _metrics_str = "SDK=PythonV2&Version={}".format(version) - except importlib.metadata.PackageNotFoundError: - _metrics_str = "SDK=PythonV2&Version=dev" - except BaseException: - _metrics_str = "" - - if not _metrics_str == "": - if username_has_query: - return "&" + _metrics_str - else: - return "?" + _metrics_str - else: - return "" - def _builder( tls_ctx_options, @@ -225,12 +198,14 @@ def _builder( _get(kwargs, 'tcp_keep_alive_max_probes', _get(kwargs, 'tcp_keepalive_max_probes', 0)) username = _get(kwargs, 'username', '') - if _get(kwargs, 'enable_metrics_collection', True): - username += _get_metrics_str(username) if username == "": username = None + # Set SDK metrics for the CRT layer to embed in the CONNECT packet username + disable_metrics = _get(kwargs, 'disable_metrics', False) + metrics = None if disable_metrics else _build_sdk_metrics() + client_bootstrap = _get(kwargs, 'client_bootstrap') if client_bootstrap is None: client_bootstrap = awscrt.io.ClientBootstrap.get_or_create_static_default() @@ -262,6 +237,8 @@ def _builder( on_connection_success=_get(kwargs, 'on_connection_success'), on_connection_failure=_get(kwargs, 'on_connection_failure'), on_connection_closed=_get(kwargs, 'on_connection_closed'), + disable_metrics=disable_metrics, + metrics=metrics ) diff --git a/test/test_get_metrics.py b/test/test_get_metrics.py index 0c0d7d8b..1ff55ab8 100644 --- a/test/test_get_metrics.py +++ b/test/test_get_metrics.py @@ -1,127 +1,196 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0. +import os import unittest +import uuid +import warnings from unittest.mock import patch +import boto3 +import botocore.exceptions -class TestImportlibMetadata(unittest.TestCase): - """Test that importlib.metadata is used instead of pkg_resources""" +from awsiot._iot_metrics import ( + _SDK_LIBRARY_NAME, + _IOT_SDK_METRICS_VERSION, + _get_sdk_version, + _build_sdk_metrics, +) - def setUp(self): - """Reset the metrics string cache before each test""" - # Reset the cached metrics string in both modules - import awsiot.mqtt5_client_builder - import awsiot.mqtt_connection_builder +AWS_DEFAULT_REGION = os.environ.get("AWS_DEFAULT_REGION") - # Reset the global _metrics_str variable - awsiot.mqtt_connection_builder._metrics_str = None - awsiot.mqtt5_client_builder._metrics_str = None - def test_metrics_string_generation_mqtt_connection_builder(self): - """Test that mqtt_connection_builder uses importlib.metadata for version detection""" - from awsiot import mqtt_connection_builder +class Config: + cache = None - # Mock importlib.metadata.version to return a known version - with patch("importlib.metadata.version") as mock_version: - mock_version.return_value = "1.2.3" + def __init__(self, endpoint): + self.endpoint = endpoint - # Call the function that uses version detection - # We need to access the private function for testing - result = mqtt_connection_builder._get_metrics_str("test_username") + @staticmethod + def get(): + """Raises SkipTest if credentials aren't set up correctly""" + if Config.cache: + return Config.cache - # Verify that importlib.metadata.version was called - mock_version.assert_called_once_with("awsiotsdk") + warnings.simplefilter('ignore', ResourceWarning) - # Verify the result contains the expected format - self.assertIn("SDK=PythonV2&Version=1.2.3", result) + try: + secrets = boto3.client('secretsmanager', region_name=AWS_DEFAULT_REGION) + response = secrets.get_secret_value(SecretId='unit-test/endpoint') + endpoint = response['SecretString'] + Config.cache = Config(endpoint) + except (botocore.exceptions.BotoCoreError, botocore.exceptions.ClientError) as ex: + print(ex) + raise unittest.SkipTest("No credentials") - def test_metrics_string_generation_mqtt5_client_builder(self): - """Test that mqtt5_client_builder uses importlib.metadata for version detection""" - from awsiot import mqtt5_client_builder + return Config.cache - # Mock importlib.metadata.version to return a known version - with patch("importlib.metadata.version") as mock_version: - mock_version.return_value = "1.2.3" - # Call the function that uses version detection - # We need to access the private function for testing - result = mqtt5_client_builder._get_metrics_str("test_username") +def create_client_id(): + return 'test-aws-iot-device-sdk-python-v2-unit-test-{0}'.format(uuid.uuid4()) - # Verify that importlib.metadata.version was called - mock_version.assert_called_once_with("awsiotsdk") +class TestGetSdkVersion(unittest.TestCase): - # Verify the result contains the expected format - self.assertIn("SDK=PythonV2&Version=1.2.3", result) + def test_calls_importlib_metadata(self): + with patch("importlib.metadata.version") as mock_version: + mock_version.return_value = "1.2.3" + result = _get_sdk_version() + mock_version.assert_called_once_with("awsiotsdk") + self.assertEqual(result, "1.2.3") - def test_package_not_found_handling_mqtt_connection_builder(self): - """Test that PackageNotFoundError is handled correctly in mqtt_connection_builder""" + def test_fallback_on_package_not_found(self): import importlib.metadata + with patch("importlib.metadata.version") as mock_version: + mock_version.side_effect = importlib.metadata.PackageNotFoundError("not found") + result = _get_sdk_version() + self.assertEqual(result, "dev") - from awsiot import mqtt_connection_builder - - # Mock importlib.metadata.version to raise PackageNotFoundError + def test_fallback_on_general_exception(self): with patch("importlib.metadata.version") as mock_version: - mock_version.side_effect = importlib.metadata.PackageNotFoundError("Package not found") + mock_version.side_effect = Exception("unexpected") + result = _get_sdk_version() + self.assertEqual(result, "dev") - # Call the function that uses version detection - result = mqtt_connection_builder._get_metrics_str("test_username") - # Verify that the fallback version is used - self.assertIn("SDK=PythonV2&Version=dev", result) +class TestBuildSdkMetrics(unittest.TestCase): - def test_package_not_found_handling_mqtt5_client_builder(self): - """Test that PackageNotFoundError is handled correctly in mqtt5_client_builder""" - import importlib.metadata + def test_library_name(self): + metrics = _build_sdk_metrics() + self.assertEqual(metrics.library_name, _SDK_LIBRARY_NAME) - from awsiot import mqtt5_client_builder + def test_contains_sdk_version(self): + with patch("awsiot._iot_metrics._get_sdk_version", return_value="1.2.3"): + metrics = _build_sdk_metrics() + entries = {e.key: e.value for e in metrics.metadata_entries} + self.assertIn("IoTSDKVersion", entries) + self.assertEqual(entries["IoTSDKVersion"], "1.2.3") + + def test_contains_metrics_version(self): + metrics = _build_sdk_metrics() + entries = {e.key: e.value for e in metrics.metadata_entries} + self.assertIn("IoTSDKMetricsVersion", entries) + self.assertEqual(entries["IoTSDKMetricsVersion"], str(_IOT_SDK_METRICS_VERSION)) + + def test_only_two_metadata_entries(self): + metrics = _build_sdk_metrics() + self.assertEqual(len(metrics.metadata_entries), 2) - # Mock importlib.metadata.version to raise PackageNotFoundError + def test_with_dev_fallback_version(self): with patch("importlib.metadata.version") as mock_version: - mock_version.side_effect = importlib.metadata.PackageNotFoundError("Package not found") + mock_version.side_effect = Exception("no package") + metrics = _build_sdk_metrics() + entries = {e.key: e.value for e in metrics.metadata_entries} + self.assertEqual(entries["IoTSDKVersion"], "dev") - # Call the function that uses version detection - result = mqtt5_client_builder._get_metrics_str("test_username") - # Verify that the fallback version is used - self.assertIn("SDK=PythonV2&Version=dev", result) +class TestMqtt3BuilderMetrics(unittest.TestCase): + """Test that mqtt_connection_builder passes disable_metrics correctly.""" - def test_general_exception_handling_mqtt_connection_builder(self): - """Test that general exceptions are handled correctly in mqtt_connection_builder""" + def test_metrics_enabled_by_default(self): + """When disable_metrics is not set, builder should pass metrics to Connection.""" + config = Config.get() + import awscrt.io + import awscrt.mqtt from awsiot import mqtt_connection_builder - # Mock importlib.metadata.version to raise a general exception - with patch("importlib.metadata.version") as mock_version: - mock_version.side_effect = Exception("Some other error") + with patch("awsiot._iot_metrics._get_sdk_version", return_value="2.0.0"), \ + patch.object(awscrt.mqtt, "Connection") as mock_conn, \ + patch.object(awscrt.mqtt, "Client"): + mqtt_connection_builder._builder( + awscrt.io.TlsContextOptions(), + endpoint=config.endpoint, + client_id=create_client_id(), + ) + + kwargs = mock_conn.call_args.kwargs + self.assertFalse(kwargs["disable_metrics"]) + self.assertIsNotNone(kwargs["metrics"]) + entries = {e.key: e.value for e in kwargs["metrics"].metadata_entries} + self.assertEqual(entries["IoTSDKVersion"], "2.0.0") + self.assertEqual(entries["IoTSDKMetricsVersion"], str(_IOT_SDK_METRICS_VERSION)) + + def test_metrics_disabled(self): + """When disable_metrics=True, builder should pass None metrics.""" + config = Config.get() + import awscrt.io + import awscrt.mqtt + from awsiot import mqtt_connection_builder - # Call the function that uses version detection - result = mqtt_connection_builder._get_metrics_str("test_username") + with patch.object(awscrt.mqtt, "Connection") as mock_conn, \ + patch.object(awscrt.mqtt, "Client"): + mqtt_connection_builder._builder( + awscrt.io.TlsContextOptions(), + endpoint=config.endpoint, + client_id=create_client_id(), + disable_metrics=True, + ) - # Verify that empty string is returned on general exception - self.assertEqual(result, "") + kwargs = mock_conn.call_args.kwargs + self.assertTrue(kwargs["disable_metrics"]) + self.assertIsNone(kwargs["metrics"]) - def test_general_exception_handling_mqtt5_client_builder(self): - """Test that general exceptions are handled correctly in mqtt5_client_builder""" - from awsiot import mqtt5_client_builder - # Mock importlib.metadata.version to raise a general exception - with patch("importlib.metadata.version") as mock_version: - mock_version.side_effect = Exception("Some other error") +class TestMqtt5BuilderMetrics(unittest.TestCase): + """Test that mqtt5_client_builder passes disable_metrics correctly.""" - # Call the function that uses version detection - result = mqtt5_client_builder._get_metrics_str("test_username") + def test_metrics_enabled_by_default(self): + """When disable_metrics is not set, builder should set metrics on client_options.""" + config = Config.get() + import awscrt.io + import awscrt.mqtt5 + from awsiot import mqtt5_client_builder - # Verify that empty string is returned on general exception - self.assertEqual(result, "") + with patch("awsiot._iot_metrics._get_sdk_version", return_value="2.0.0"), \ + patch.object(awscrt.mqtt5, "Client") as mock_client: + mqtt5_client_builder._builder( + awscrt.io.TlsContextOptions(), + endpoint=config.endpoint, + ) + + client_options = mock_client.call_args.kwargs["client_options"] + self.assertFalse(client_options.disable_metrics) + self.assertIsNotNone(client_options.metrics) + entries = {e.key: e.value for e in client_options.metrics.metadata_entries} + self.assertEqual(entries["IoTSDKVersion"], "2.0.0") + self.assertEqual(entries["IoTSDKMetricsVersion"], str(_IOT_SDK_METRICS_VERSION)) + + def test_metrics_disabled(self): + """When disable_metrics=True, builder should set None metrics on client_options.""" + config = Config.get() + import awscrt.io + import awscrt.mqtt5 + from awsiot import mqtt5_client_builder - def test_no_pkg_resources_import(self): - """Test that pkg_resources is not imported in the modified files""" - import awsiot.mqtt5_client_builder - import awsiot.mqtt_connection_builder + with patch.object(awscrt.mqtt5, "Client") as mock_client: + mqtt5_client_builder._builder( + awscrt.io.TlsContextOptions(), + endpoint=config.endpoint, + disable_metrics=True, + ) - # Check that pkg_resources is not in the module's globals - self.assertNotIn("pkg_resources", awsiot.mqtt_connection_builder.__dict__) - self.assertNotIn("pkg_resources", awsiot.mqtt5_client_builder.__dict__) + client_options = mock_client.call_args.kwargs["client_options"] + self.assertTrue(client_options.disable_metrics) + self.assertIsNone(client_options.metrics) if __name__ == "__main__":