diff --git a/mssql_python/connection.py b/mssql_python/connection.py index 0d9b4692..4578cc73 100644 --- a/mssql_python/connection.py +++ b/mssql_python/connection.py @@ -1114,6 +1114,172 @@ def clear_output_converters(self) -> None: self._conn.clear_output_converters() logger.info("Cleared all output converters") + # ---- Session Metadata / Auditing API ---- + + # Maximum length for session context keys and values to prevent abuse. + _AUDIT_KEY_MAX_LEN: int = 128 + _AUDIT_VALUE_MAX_LEN: int = 4000 + + def set_audit_context( + self, + *, + application: Optional[str] = None, + module: Optional[str] = None, + action: Optional[str] = None, + user_id: Optional[str] = None, + read_only: bool = False, + **extra: str, + ) -> None: + """ + Set session-level auditing / tracing metadata on the current connection. + + This stores name-value pairs in the SQL Server session context via + ``sp_set_session_context``, making them visible to: + + * ``SESSION_CONTEXT()`` in T-SQL queries, triggers, and stored procedures + * Extended Events sessions that capture session context + * ``sys.dm_exec_sessions`` (for *application*) + * Audit specifications that reference session context + + All parameters are optional; only the ones provided will be set. + Calling this method again merges new values with previously-set ones; + to clear a key pass an empty string ``""``. + + Args: + application: Logical application name (sets ``application_name``). + module: Module or component name (sets ``module_name``). + action: Current action or operation (sets ``action_name``). + user_id: End-user identifier (sets ``user_id``). + read_only: If ``True``, the keys become read-only for the + remainder of the session — subsequent calls cannot change them. + **extra: Arbitrary additional key-value pairs to store in the + session context. + + Raises: + InterfaceError: If the connection is closed. + ProgrammingError: If a key or value exceeds length limits + DatabaseError: If ``sp_set_session_context`` execution fails. + + Example:: + + conn.set_audit_context( + application="BillingAPI", + module="InvoiceProcessor", + action="GenerateInvoice", + user_id="123", + ) + # Values are now readable in T-SQL: + # SELECT SESSION_CONTEXT(N'application_name') + """ + if self._closed: + raise InterfaceError( + driver_error="Connection is closed", + ddbc_error="Cannot set audit context on a closed connection", + ) + + # Build the mapping of keys to set + pairs: Dict[str, str] = {} + if application is not None: + pairs["application_name"] = application + if module is not None: + pairs["module_name"] = module + if action is not None: + pairs["action_name"] = action + if user_id is not None: + pairs["user_id"] = user_id + for key, value in extra.items(): + pairs[key] = value + + if not pairs: + return # nothing to do + + # Validate lengths + for key, value in pairs.items(): + if not isinstance(key, str) or not key: + raise ProgrammingError( + driver_error="Invalid audit context key", + ddbc_error="Session context key must be a non-empty string", + ) + if len(key) > self._AUDIT_KEY_MAX_LEN: + raise ProgrammingError( + driver_error="Audit context key too long", + ddbc_error=( + f"Session context key exceeds {self._AUDIT_KEY_MAX_LEN} characters" + ), + ) + if not isinstance(value, str): + raise ProgrammingError( + driver_error="Invalid audit context value", + ddbc_error="Session context values must be strings", + ) + if len(value) > self._AUDIT_VALUE_MAX_LEN: + raise ProgrammingError( + driver_error="Audit context value too long", + ddbc_error=( + f"Session context value exceeds {self._AUDIT_VALUE_MAX_LEN} characters" + ), + ) + + # Initialize local cache if first call + if not hasattr(self, "_audit_context"): + self._audit_context: Dict[str, str] = {} + + # Execute sp_set_session_context for each pair using parameterized queries + cursor = self.cursor() + try: + for key, value in pairs.items(): + # Empty string means "clear"; sp_set_session_context requires NULL + sql_value = None if value == "" else value + if read_only: + cursor.execute( + "EXEC sp_set_session_context @key=?, @value=?, @read_only=1", + key, + sql_value, + ) + else: + cursor.execute( + "EXEC sp_set_session_context @key=?, @value=?", + key, + sql_value, + ) + if value == "": + self._audit_context.pop(key, None) + else: + self._audit_context[key] = value + logger.debug("Set session context: %s", sanitize_user_input(key)) + finally: + cursor.close() + + logger.info( + "Audit context set with %d key(s): %s", + len(pairs), + ", ".join(sanitize_user_input(k) for k in pairs), + ) + + def get_audit_context(self) -> Dict[str, str]: + """ + Return a copy of the session audit context previously set via + :meth:`set_audit_context`. + + This returns the *locally cached* values — it does not round-trip to + the server. To verify server-side values, query + ``SESSION_CONTEXT(N'')`` directly. + + Returns: + dict: A ``{key: value}`` mapping of the current session context. + + Raises: + InterfaceError: If the connection is closed. + """ + if self._closed: + raise InterfaceError( + driver_error="Connection is closed", + ddbc_error="Cannot get audit context on a closed connection", + ) + if not hasattr(self, "_audit_context"): + return {} + return dict(self._audit_context) + def execute(self, sql: str, *args: Any) -> Cursor: """ Creates a new Cursor object, calls its execute method, and returns the new cursor. diff --git a/tests/test_023_audit_context.py b/tests/test_023_audit_context.py new file mode 100644 index 00000000..8c1c723d --- /dev/null +++ b/tests/test_023_audit_context.py @@ -0,0 +1,144 @@ +""" +Tests for the session metadata / auditing API (set_audit_context / get_audit_context). + +Functions: +- test_set_and_get_audit_context: Set named fields and verify local cache. +- test_audit_context_server_roundtrip: Verify values are readable via SESSION_CONTEXT(). +- test_audit_context_extra_keys: Test arbitrary extra key-value pairs. +- test_audit_context_merge: Successive calls merge, not replace. +- test_audit_context_empty_call: Calling with no arguments is a no-op. +- test_audit_context_clear_value: Setting a key to "" clears it server-side. +- test_audit_context_read_only: read_only=True prevents subsequent changes. +- test_audit_context_closed_connection: Raises InterfaceError when connection is closed. +- test_audit_context_key_too_long: Raises ProgrammingError for oversized keys. +- test_audit_context_value_too_long: Raises ProgrammingError for oversized values. +- test_audit_context_non_string_value: Raises ProgrammingError for non-string values. +""" + +import pytest +from mssql_python import connect +from mssql_python.exceptions import InterfaceError, ProgrammingError, DatabaseError + + +@pytest.fixture() +def audit_conn(conn_str): + """Dedicated connection for audit context tests (module-scoped fixtures + would share session state, so we create a fresh connection per test).""" + conn = connect(conn_str) + yield conn + conn.close() + + +class TestAuditContext: + """Tests for Connection.set_audit_context / get_audit_context.""" + + def test_set_and_get_audit_context(self, audit_conn): + """Named fields are reflected in the local cache.""" + audit_conn.set_audit_context( + application="BillingAPI", + module="InvoiceProcessor", + action="GenerateInvoice", + user_id="123", + ) + ctx = audit_conn.get_audit_context() + assert ctx["application_name"] == "BillingAPI" + assert ctx["module_name"] == "InvoiceProcessor" + assert ctx["action_name"] == "GenerateInvoice" + assert ctx["user_id"] == "123" + + def test_audit_context_server_roundtrip(self, audit_conn): + """Values set via set_audit_context are readable with SESSION_CONTEXT().""" + audit_conn.set_audit_context(application="RoundTrip", user_id="42") + cursor = audit_conn.cursor() + try: + cursor.execute("SELECT SESSION_CONTEXT(N'application_name')") + row = cursor.fetchone() + assert row[0] == "RoundTrip" + + cursor.execute("SELECT SESSION_CONTEXT(N'user_id')") + row = cursor.fetchone() + assert row[0] == "42" + finally: + cursor.close() + + def test_audit_context_extra_keys(self, audit_conn): + """Arbitrary extra keys are stored via sp_set_session_context.""" + audit_conn.set_audit_context(tenant_id="ACME", correlation_id="abc-def") + ctx = audit_conn.get_audit_context() + assert ctx["tenant_id"] == "ACME" + assert ctx["correlation_id"] == "abc-def" + + # Verify server-side + cursor = audit_conn.cursor() + try: + cursor.execute("SELECT SESSION_CONTEXT(N'tenant_id')") + assert cursor.fetchone()[0] == "ACME" + finally: + cursor.close() + + def test_audit_context_merge(self, audit_conn): + """Successive calls merge values, not replace.""" + audit_conn.set_audit_context(application="App1") + audit_conn.set_audit_context(module="Mod1") + ctx = audit_conn.get_audit_context() + assert ctx["application_name"] == "App1" + assert ctx["module_name"] == "Mod1" + + def test_audit_context_overwrite(self, audit_conn): + """A second call with the same key overwrites the previous value.""" + audit_conn.set_audit_context(action="First") + audit_conn.set_audit_context(action="Second") + assert audit_conn.get_audit_context()["action_name"] == "Second" + + def test_audit_context_empty_call(self, audit_conn): + """Calling with no arguments is a silent no-op.""" + audit_conn.set_audit_context() + assert audit_conn.get_audit_context() == {} + + def test_audit_context_clear_value(self, audit_conn): + """Setting a key to '' clears it (sends NULL to the server).""" + audit_conn.set_audit_context(user_id="99") + audit_conn.set_audit_context(user_id="") + assert "user_id" not in audit_conn.get_audit_context() + + def test_audit_context_read_only(self, audit_conn): + """read_only=True makes the key immutable for the session.""" + audit_conn.set_audit_context(action="Locked", read_only=True) + # Attempting to change a read-only key should raise a DatabaseError + # from SQL Server (error 15664). + with pytest.raises(DatabaseError): + audit_conn.set_audit_context(action="Changed") + + def test_audit_context_closed_connection_set(self, audit_conn): + """set_audit_context raises InterfaceError on a closed connection.""" + audit_conn.close() + with pytest.raises(InterfaceError): + audit_conn.set_audit_context(application="X") + + def test_audit_context_closed_connection_get(self, audit_conn): + """get_audit_context raises InterfaceError on a closed connection.""" + audit_conn.close() + with pytest.raises(InterfaceError): + audit_conn.get_audit_context() + + def test_audit_context_key_too_long(self, audit_conn): + """Keys longer than 128 characters are rejected.""" + with pytest.raises(ProgrammingError): + audit_conn.set_audit_context(**{"x" * 200: "v"}) + + def test_audit_context_value_too_long(self, audit_conn): + """Values longer than 4000 characters are rejected.""" + with pytest.raises(ProgrammingError): + audit_conn.set_audit_context(user_id="v" * 4001) + + def test_audit_context_non_string_value(self, audit_conn): + """Non-string values are rejected with ProgrammingError.""" + with pytest.raises(ProgrammingError): + audit_conn.set_audit_context(user_id=123) # type: ignore[arg-type] + + def test_get_audit_context_returns_copy(self, audit_conn): + """get_audit_context returns a copy, not the internal dict.""" + audit_conn.set_audit_context(application="Copy") + ctx = audit_conn.get_audit_context() + ctx["application_name"] = "Mutated" + assert audit_conn.get_audit_context()["application_name"] == "Copy"