Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions tests/unit/test_dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# limitations under the License.
import threading
import uuid
import warnings
from unittest.mock import patch

import httpretty
Expand All @@ -27,6 +28,7 @@
from tests.unit.oauth_test_utils import SERVER_ADDRESS
from tests.unit.oauth_test_utils import TOKEN_RESOURCE
from trino import constants
from trino.auth import BasicAuthentication
from trino.auth import OAuth2Authentication
from trino.dbapi import connect
from trino.dbapi import Connection
Expand Down Expand Up @@ -362,3 +364,14 @@ def test_default_encoding_zstd():
def test_default_encoding_all():
connection = Connection("host", 8080, user="test")
assert connection._client_session.encoding == ["json+zstd", "json+lz4", "json"]


def test_warning_when_auth_over_http():
with pytest.warns(UserWarning, match="Authentication credentials are being sent over HTTP"):
Connection("mytrinoserver.domain", auth=BasicAuthentication("u", "p"))


def test_no_warning_when_auth_over_https():
with warnings.catch_warnings():
warnings.simplefilter("error")
Connection("mytrinoserver.domain", http_scheme=constants.HTTPS, auth=BasicAuthentication("u", "p"))
9 changes: 9 additions & 0 deletions trino/dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import datetime
import math
import uuid
import warnings
from collections import OrderedDict
from decimal import Decimal
from itertools import islice
Expand Down Expand Up @@ -215,6 +216,14 @@ def __init__(
else:
self.http_scheme = constants.HTTP

if auth is not None and self.http_scheme == constants.HTTP:
warnings.warn(
"Authentication credentials are being sent over HTTP. "
"To use HTTPS, specify 'https://' in the host URL (which takes precedence "
"over http_scheme), or, if the host URL has no scheme, pass http_scheme='https'.",
stacklevel=2,
)

# Infer connection port: `hostname` takes precedence over explicit `port` argument
# If none is given, use default based on HTTP protocol
default_port = constants.DEFAULT_TLS_PORT if self.http_scheme == constants.HTTPS else constants.DEFAULT_PORT
Expand Down