diff --git a/tests/unit/test_dbapi.py b/tests/unit/test_dbapi.py index e3821bba..dfe71fc0 100644 --- a/tests/unit/test_dbapi.py +++ b/tests/unit/test_dbapi.py @@ -11,6 +11,7 @@ # limitations under the License. import threading import uuid +import warnings from unittest.mock import patch import httpretty @@ -27,6 +28,7 @@ from tests.unit.oauth_test_utils import SERVER_ADDRESS from tests.unit.oauth_test_utils import TOKEN_RESOURCE from trino import constants +from trino.auth import BasicAuthentication from trino.auth import OAuth2Authentication from trino.dbapi import connect from trino.dbapi import Connection @@ -362,3 +364,14 @@ def test_default_encoding_zstd(): def test_default_encoding_all(): connection = Connection("host", 8080, user="test") assert connection._client_session.encoding == ["json+zstd", "json+lz4", "json"] + + +def test_warning_when_auth_over_http(): + with pytest.warns(UserWarning, match="Authentication credentials are being sent over HTTP"): + Connection("mytrinoserver.domain", auth=BasicAuthentication("u", "p")) + + +def test_no_warning_when_auth_over_https(): + with warnings.catch_warnings(): + warnings.simplefilter("error") + Connection("mytrinoserver.domain", http_scheme=constants.HTTPS, auth=BasicAuthentication("u", "p")) diff --git a/trino/dbapi.py b/trino/dbapi.py index 42eeb547..60d87808 100644 --- a/trino/dbapi.py +++ b/trino/dbapi.py @@ -20,6 +20,7 @@ import datetime import math import uuid +import warnings from collections import OrderedDict from decimal import Decimal from itertools import islice @@ -215,6 +216,14 @@ def __init__( else: self.http_scheme = constants.HTTP + if auth is not None and self.http_scheme == constants.HTTP: + warnings.warn( + "Authentication credentials are being sent over HTTP. " + "To use HTTPS, specify 'https://' in the host URL (which takes precedence " + "over http_scheme), or, if the host URL has no scheme, pass http_scheme='https'.", + stacklevel=2, + ) + # Infer connection port: `hostname` takes precedence over explicit `port` argument # If none is given, use default based on HTTP protocol default_port = constants.DEFAULT_TLS_PORT if self.http_scheme == constants.HTTPS else constants.DEFAULT_PORT