From 00ad6afd5f5f11b26d12cbd05eac442e240d0cd9 Mon Sep 17 00:00:00 2001 From: Jahnvi Thakkar Date: Wed, 27 May 2026 19:35:27 +0530 Subject: [PATCH 1/4] Add credential= parameter for custom Azure Identity credential support Add a new 'credential' parameter to connect() that accepts any object following the Azure TokenCredential protocol (.get_token() method). This allows users to authenticate with any azure-identity credential class without being limited to the driver's hardcoded credential map. Changes: - auth.py: Add _get_token_from_credential() shared helper, acquire_token_from_credential(), acquire_raw_token_from_credential() - db_connection.py: Add credential=None parameter to connect() - connection.py: Validate credential, acquire token, store for bulk copy token refresh. Mutually exclusive with Authentication= - cursor.py: Check _custom_credential before _auth_type in bulk copy - constants.py: Unify _KEY_* constants with _ALLOWED_CONNECTION_STRING_PARAMS to use single source of truth (_CONNECTION_STRING_*_KEY pattern) - test_008_auth.py: Add 12 new tests for custom credential flow --- mssql_python/auth.py | 65 +++++++++++++ mssql_python/connection.py | 26 +++++- mssql_python/constants.py | 30 +++--- mssql_python/cursor.py | 19 +++- mssql_python/db_connection.py | 12 +++ tests/test_008_auth.py | 169 ++++++++++++++++++++++++++++++++++ 6 files changed, 305 insertions(+), 16 deletions(-) diff --git a/mssql_python/auth.py b/mssql_python/auth.py index dd716c2c0..847574f12 100644 --- a/mssql_python/auth.py +++ b/mssql_python/auth.py @@ -238,3 +238,68 @@ def extract_auth_type(parsed_params: Dict[str, str]) -> Optional[str]: """ auth_value = parsed_params.get(_KEY_AUTHENTICATION, "").strip().lower() return _AUTH_TYPE_MAP.get(auth_value) + + +def _get_token_from_credential(credential: object) -> str: + """Internal: call credential.get_token() and return the raw JWT string. + + Centralises the token-acquisition + error-wrapping logic that both + :func:`acquire_token_from_credential` and + :func:`acquire_raw_token_from_credential` need. + + Raises: + RuntimeError: If token acquisition fails. + """ + try: + raw_token = credential.get_token("https://database.windows.net/.default").token + logger.info( + "_get_token_from_credential: Token acquired from %s - length=%d chars", + type(credential).__name__, + len(raw_token), + ) + return raw_token + except Exception as e: + logger.error( + "_get_token_from_credential: Failed - credential=%s, error=%s", + type(credential).__name__, + str(e), + ) + raise RuntimeError( + f"Failed to acquire token from credential " f"({type(credential).__name__}): {e}" + ) from e + + +def acquire_token_from_credential(credential: object) -> bytes: + """Acquire an ODBC token struct from a user-supplied credential object. + + The credential must follow the Azure ``TokenCredential`` protocol — i.e. + have a ``.get_token(scope)`` method returning an object with a ``.token`` + attribute (a raw JWT string). + + Args: + credential: Any object with a ``.get_token(scope)`` method. + + Returns: + bytes: ODBC-compatible token struct for ``SQL_COPT_SS_ACCESS_TOKEN``. + + Raises: + RuntimeError: If token acquisition fails. + """ + return AADAuth.get_token_struct(_get_token_from_credential(credential)) + + +def acquire_raw_token_from_credential(credential: object) -> str: + """Acquire a raw JWT string from a user-supplied credential object. + + Used by bulk copy, which needs the raw JWT rather than the ODBC struct. + + Args: + credential: Any object with a ``.get_token(scope)`` method. + + Returns: + str: Raw JWT token string. + + Raises: + RuntimeError: If token acquisition fails. + """ + return _get_token_from_credential(credential) diff --git a/mssql_python/connection.py b/mssql_python/connection.py index 0d9b4692e..3e81f3bac 100644 --- a/mssql_python/connection.py +++ b/mssql_python/connection.py @@ -246,6 +246,7 @@ def __init__( attrs_before: Optional[Dict[int, Union[int, str, bytes]]] = None, timeout: int = 0, native_uuid: Optional[bool] = None, + credential: Optional[object] = None, **kwargs: Any, ) -> None: """ @@ -334,10 +335,33 @@ def __init__( # fresh token; re-parsing self.connection_str at that point would miss # them because UID is already gone. self._credential_kwargs: Optional[Dict[str, str]] = None + # User-supplied credential object for custom Entra ID authentication. + # Stored so bulk copy can call .get_token() for a fresh JWT later. + self._custom_credential = None + + # Custom credential= parameter — takes priority, mutually exclusive + # with Authentication= in the connection string. + if credential is not None: + if _KEY_AUTHENTICATION in parsed_params: + raise ValueError( + "Cannot specify both 'credential' parameter and " + "'Authentication' in the connection string. " + "Use one or the other." + ) + if not callable(getattr(credential, "get_token", None)): + raise TypeError( + f"credential must have a .get_token() method. " + f"Got {type(credential).__name__}." + ) + from mssql_python.auth import acquire_token_from_credential + + token = acquire_token_from_credential(credential) + self._attrs_before[ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value] = token + self._custom_credential = credential # Handle Entra ID authentication if specified. # The parsed dict is used directly — no re-parsing of the connection string. - if _KEY_AUTHENTICATION in parsed_params: + elif _KEY_AUTHENTICATION in parsed_params: auth_type = process_auth_parameters(parsed_params) if auth_type: diff --git a/mssql_python/constants.py b/mssql_python/constants.py index 3bfd39483..6283d6b90 100644 --- a/mssql_python/constants.py +++ b/mssql_python/constants.py @@ -467,6 +467,16 @@ def get_attribute_set_timing(attribute): _CONNECTION_STRING_DRIVER_KEY = "Driver" _CONNECTION_STRING_APP_KEY = "APP" +_CONNECTION_STRING_AUTH_KEY = "Authentication" +_CONNECTION_STRING_UID_KEY = "UID" +_CONNECTION_STRING_PWD_KEY = "PWD" +_CONNECTION_STRING_TRUSTED_CONNECTION_KEY = "Trusted_Connection" + +# Aliases used by auth.py / connection.py — kept for readability. +_KEY_AUTHENTICATION = _CONNECTION_STRING_AUTH_KEY +_KEY_UID = _CONNECTION_STRING_UID_KEY +_KEY_PWD = _CONNECTION_STRING_PWD_KEY +_KEY_TRUSTED_CONNECTION = _CONNECTION_STRING_TRUSTED_CONNECTION_KEY # Reserved connection string parameters that are controlled by the driver # and cannot be set by users @@ -486,16 +496,16 @@ def get_attribute_set_timing(attribute): "address": "Server", "addr": "Server", # Authentication - "uid": "UID", - "pwd": "PWD", - "authentication": "Authentication", - "trusted_connection": "Trusted_Connection", + "uid": _CONNECTION_STRING_UID_KEY, + "pwd": _CONNECTION_STRING_PWD_KEY, + "authentication": _CONNECTION_STRING_AUTH_KEY, + "trusted_connection": _CONNECTION_STRING_TRUSTED_CONNECTION_KEY, # Database "database": "Database", # Driver (always controlled by mssql-python) - "driver": "Driver", + "driver": _CONNECTION_STRING_DRIVER_KEY, # Application name (always controlled by mssql-python) - "app": "APP", + "app": _CONNECTION_STRING_APP_KEY, # Encryption and Security "encrypt": "Encrypt", "trustservercertificate": "TrustServerCertificate", @@ -519,14 +529,6 @@ def get_attribute_set_timing(attribute): "packetsize": "PacketSize", } -# Canonical normalized key names produced by _ConnectionStringParser._normalize_params. -# Consumer code should reference these instead of hard-coding raw strings so that -# a rename in _ALLOWED_CONNECTION_STRING_PARAMS is caught at import time. -_KEY_AUTHENTICATION = "Authentication" -_KEY_UID = "UID" -_KEY_PWD = "PWD" -_KEY_TRUSTED_CONNECTION = "Trusted_Connection" - def get_info_constants() -> Dict[str, int]: """ diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index 49eb1b92d..966941d8d 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -2956,7 +2956,24 @@ def bulkcopy( pycore_context = connstr_to_pycore_params(params) # Token acquisition — only thing cursor must handle (needs azure-identity SDK) - if self.connection._auth_type: + if self.connection._custom_credential is not None: + # User-supplied credential — use it directly for a fresh token. + from mssql_python.auth import acquire_raw_token_from_credential + + try: + raw_token = acquire_raw_token_from_credential(self.connection._custom_credential) + except RuntimeError as e: + raise RuntimeError( + f"Bulk copy failed: unable to acquire token " f"from custom credential: {e}" + ) from e + pycore_context["access_token"] = raw_token + for key in ("authentication", "user_name", "password"): + pycore_context.pop(key, None) + logger.debug( + "Bulk copy: acquired fresh token from custom credential (%s)", + type(self.connection._custom_credential).__name__, + ) + elif self.connection._auth_type: # Fresh token acquisition for mssql-py-core connection. credential # kwargs (e.g. user-assigned MSI client_id) were captured by # Connection.__init__ before remove_sensitive_params stripped UID diff --git a/mssql_python/db_connection.py b/mssql_python/db_connection.py index fe10b819b..1688a56ed 100644 --- a/mssql_python/db_connection.py +++ b/mssql_python/db_connection.py @@ -15,6 +15,7 @@ def connect( attrs_before: Optional[Dict[int, Union[int, str, bytes]]] = None, timeout: int = 0, native_uuid: Optional[bool] = None, + credential: Optional[object] = None, **kwargs: Any, ) -> Connection: """ @@ -35,6 +36,16 @@ def connect( This per-connection override is useful for migration from pyodbc: connections that need string UUIDs can pass native_uuid=False, while the default (True) returns native uuid.UUID objects. + credential (object, optional): An Azure Identity credential object (or any object with a + ``.get_token(scope)`` method) used for Entra ID authentication. When provided, the + driver calls ``credential.get_token()`` to acquire a token instead of using the + built-in credential map. Cannot be combined with ``Authentication=`` in the + connection string. + + For environment-portable code, prefer ``Authentication=ActiveDirectoryDefault`` in + the connection string — ``DefaultAzureCredential`` automatically picks the right + credential per environment. Use ``credential=`` only when you need explicit control + (e.g., excluding specific providers or using a credential not in the built-in map). Keyword Args: **kwargs: Additional key/value pairs for the connection string. Below attributes are not implemented in the internal driver: @@ -58,6 +69,7 @@ def connect( attrs_before=attrs_before, timeout=timeout, native_uuid=native_uuid, + credential=credential, **kwargs, ) return conn diff --git a/tests/test_008_auth.py b/tests/test_008_auth.py index b127133a5..fe6937723 100644 --- a/tests/test_008_auth.py +++ b/tests/test_008_auth.py @@ -17,6 +17,8 @@ extract_auth_type, _credential_cache, _credential_cache_lock, + acquire_token_from_credential, + acquire_raw_token_from_credential, ) from mssql_python.constants import AuthType, ConstantsDDBC import secrets @@ -522,6 +524,7 @@ def test_bulkcopy_path_preserves_user_assigned_msi_client_id(self): mock_conn.connection_str = "Server=tcp:test.database.windows.net;Database=testdb;" mock_conn._auth_type = "msi" mock_conn._credential_kwargs = {"client_id": client_id} + mock_conn._custom_credential = None mock_conn._is_connected = True cursor = Cursor.__new__(Cursor) @@ -970,3 +973,169 @@ def test_token_output_correct_on_cache_miss_and_hit(self): # Same credential instance for both assert "default" in _credential_cache + + +# ── Custom credential= parameter tests ── + + +class TestAcquireTokenFromCredential: + """Tests for the acquire_token_from_credential helper.""" + + def test_happy_path(self): + """acquire_token_from_credential returns a valid token struct.""" + mock_cred = MagicMock() + mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) + result = acquire_token_from_credential(mock_cred) + assert isinstance(result, bytes) + assert len(result) > 4 + mock_cred.get_token.assert_called_once_with("https://database.windows.net/.default") + + def test_credential_raises_exception(self): + """acquire_token_from_credential wraps credential errors in RuntimeError.""" + mock_cred = MagicMock() + mock_cred.get_token.side_effect = Exception("auth failed") + with pytest.raises(RuntimeError, match="Failed to acquire token from credential"): + acquire_token_from_credential(mock_cred) + + +class TestAcquireRawTokenFromCredential: + """Tests for the acquire_raw_token_from_credential helper.""" + + def test_happy_path(self): + """acquire_raw_token_from_credential returns the raw JWT string.""" + mock_cred = MagicMock() + mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) + result = acquire_raw_token_from_credential(mock_cred) + assert result == SAMPLE_TOKEN + mock_cred.get_token.assert_called_once_with("https://database.windows.net/.default") + + def test_credential_raises_exception(self): + """acquire_raw_token_from_credential wraps credential errors in RuntimeError.""" + mock_cred = MagicMock() + mock_cred.get_token.side_effect = Exception("auth failed") + with pytest.raises(RuntimeError, match="Failed to acquire token from credential"): + acquire_raw_token_from_credential(mock_cred) + + +class TestCustomCredentialConnect: + """Tests for the credential= parameter on connect().""" + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_credential_happy_path(self, mock_ddbc_conn): + """credential= acquires token and sets attrs_before.""" + mock_ddbc_conn.return_value = MagicMock() + mock_cred = MagicMock() + mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) + from mssql_python import connect + + conn = connect("Server=test;Database=testdb", credential=mock_cred) + assert conn._custom_credential is mock_cred + assert ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value in conn._attrs_before + # Existing auth_type should be None (no Authentication= in conn str) + assert conn._auth_type is None + conn.close() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_credential_plus_authentication_raises_valueerror(self, mock_ddbc_conn): + """credential= + Authentication= raises ValueError.""" + mock_ddbc_conn.return_value = MagicMock() + mock_cred = MagicMock() + mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) + from mssql_python import connect + + with pytest.raises(ValueError, match="Cannot specify both"): + connect( + "Server=test;Database=testdb;Authentication=ActiveDirectoryDefault", + credential=mock_cred, + ) + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_credential_plus_authentication_via_kwargs_raises_valueerror(self, mock_ddbc_conn): + """credential= + Authentication via kwargs raises ValueError.""" + mock_ddbc_conn.return_value = MagicMock() + mock_cred = MagicMock() + mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) + from mssql_python import connect + + with pytest.raises(ValueError, match="Cannot specify both"): + connect( + "Server=test;Database=testdb", + credential=mock_cred, + Authentication="ActiveDirectoryDefault", + ) + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_credential_without_get_token_raises_typeerror(self, mock_ddbc_conn): + """Passing an object without .get_token() raises TypeError.""" + mock_ddbc_conn.return_value = MagicMock() + from mssql_python import connect + + with pytest.raises(TypeError, match="credential must have a .get_token"): + connect("Server=test;Database=testdb", credential="not_a_credential") + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_credential_none_uses_existing_flow(self, mock_ddbc_conn): + """credential=None (default) uses existing auth flow, no change.""" + mock_ddbc_conn.return_value = MagicMock() + from mssql_python import connect + + conn = connect("Server=test;Database=testdb;Authentication=ActiveDirectoryDefault") + assert conn._custom_credential is None + assert conn._auth_type == "default" + conn.close() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_credential_with_non_auth_attrs_before(self, mock_ddbc_conn): + """credential= works alongside non-auth attrs_before.""" + mock_ddbc_conn.return_value = MagicMock() + mock_cred = MagicMock() + mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) + from mssql_python import connect + + login_timeout_attr = 113 # SQL_ATTR_LOGIN_TIMEOUT + conn = connect( + "Server=test;Database=testdb", + credential=mock_cred, + attrs_before={login_timeout_attr: 30}, + ) + assert conn._attrs_before[login_timeout_attr] == 30 + assert ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value in conn._attrs_before + conn.close() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_credential_get_token_failure_raises_runtime_error(self, mock_ddbc_conn): + """If credential.get_token() fails, connect() raises RuntimeError.""" + mock_ddbc_conn.return_value = MagicMock() + mock_cred = MagicMock() + mock_cred.get_token.side_effect = Exception("token acquisition failed") + from mssql_python import connect + + with pytest.raises(RuntimeError, match="Failed to acquire token from credential"): + connect("Server=test;Database=testdb", credential=mock_cred) + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_credential_with_non_callable_get_token_raises_typeerror(self, mock_ddbc_conn): + """Object with .get_token as a non-callable attribute raises TypeError.""" + mock_ddbc_conn.return_value = MagicMock() + from mssql_python import connect + + class BadCredential: + get_token = "not_a_method" + + with pytest.raises(TypeError, match="credential must have a .get_token"): + connect("Server=test;Database=testdb", credential=BadCredential()) + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_multiple_connections_share_same_credential(self, mock_ddbc_conn): + """Two connections can share the same credential object safely.""" + mock_ddbc_conn.return_value = MagicMock() + mock_cred = MagicMock() + mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) + from mssql_python import connect + + conn1 = connect("Server=test1;Database=db1", credential=mock_cred) + conn2 = connect("Server=test2;Database=db2", credential=mock_cred) + assert conn1._custom_credential is conn2._custom_credential + assert mock_cred.get_token.call_count == 2 + conn1.close() + conn2.close() From 36ccf52001c083e0675edb96814c5fb2e7af6b50 Mon Sep 17 00:00:00 2001 From: Jahnvi Thakkar Date: Wed, 27 May 2026 19:51:39 +0530 Subject: [PATCH 2/4] Sanitize conn string in credential= path, update .pyi stubs - Strip UID/PWD/Trusted_Connection from connection_str when credential= is used (same as Authentication= path) to avoid leaking unused secrets - Add credential= parameter to Connection.__init__ and connect() in mssql_python.pyi type stubs --- mssql_python/connection.py | 4 ++++ mssql_python/mssql_python.pyi | 2 ++ 2 files changed, 6 insertions(+) diff --git a/mssql_python/connection.py b/mssql_python/connection.py index 3e81f3bac..8ecba7df6 100644 --- a/mssql_python/connection.py +++ b/mssql_python/connection.py @@ -358,6 +358,10 @@ def __init__( token = acquire_token_from_credential(credential) self._attrs_before[ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value] = token self._custom_credential = credential + # Strip sensitive params (UID/PWD/Trusted_Connection) since + # access-token auth is used — same as the Authentication= path. + sanitized = remove_sensitive_params(parsed_params) + self.connection_str = _ConnectionStringBuilder(sanitized).build() # Handle Entra ID authentication if specified. # The parsed dict is used directly — no re-parsing of the connection string. diff --git a/mssql_python/mssql_python.pyi b/mssql_python/mssql_python.pyi index 9b08913d6..ef2655bab 100644 --- a/mssql_python/mssql_python.pyi +++ b/mssql_python/mssql_python.pyi @@ -246,6 +246,7 @@ class Connection: attrs_before: Optional[Dict[int, Union[int, str, bytes]]] = None, timeout: int = 0, native_uuid: Optional[bool] = None, + credential: Optional[object] = None, **kwargs: Any, ) -> None: ... @@ -289,6 +290,7 @@ def connect( attrs_before: Optional[Dict[int, Union[int, str, bytes]]] = None, timeout: int = 0, native_uuid: Optional[bool] = None, + credential: Optional[object] = None, **kwargs: Any, ) -> Connection: ... From 4a5687a950ed7c37d17cbd90aa57691f9768765f Mon Sep 17 00:00:00 2001 From: Jahnvi Thakkar Date: Wed, 27 May 2026 20:02:35 +0530 Subject: [PATCH 3/4] Fix bulkcopy auth test mock to set _custom_credential = None The _make_cursor helper uses MagicMock for the connection, which auto-creates truthy attributes. Without explicitly setting _custom_credential = None, the bulk copy code takes the custom credential path instead of the expected _auth_type path. --- tests/test_020_bulkcopy_auth_cleanup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_020_bulkcopy_auth_cleanup.py b/tests/test_020_bulkcopy_auth_cleanup.py index 164438344..404faca91 100644 --- a/tests/test_020_bulkcopy_auth_cleanup.py +++ b/tests/test_020_bulkcopy_auth_cleanup.py @@ -22,6 +22,7 @@ def _make_cursor(connection_str, auth_type): mock_conn = MagicMock() mock_conn.connection_str = connection_str mock_conn._auth_type = auth_type + mock_conn._custom_credential = None mock_conn._is_connected = True cursor = Cursor.__new__(Cursor) From 2304871b74477c9d1c295f08d1863b073f38293f Mon Sep 17 00:00:00 2001 From: Jahnvi Thakkar Date: Fri, 29 May 2026 12:40:16 +0530 Subject: [PATCH 4/4] Rename credential= to token_provider= per team feedback Rename the public API parameter from 'credential' to 'token_provider' to reduce ambiguity in our multi-auth-path context. 'credential' could be confused with SQL auth username/password; 'token_provider' clearly signals token-based Entra ID auth. - Rename parameter: credential -> token_provider (connect, Connection) - Rename internal attr: _custom_credential -> _token_provider - Update error messages, docstrings, comments, .pyi stubs - Improve docstring with usage example and explicit guidance - All 97 tests pass --- mssql_python/connection.py | 22 ++++++------- mssql_python/cursor.py | 6 ++-- mssql_python/db_connection.py | 35 +++++++++++++------- mssql_python/mssql_python.pyi | 4 +-- tests/test_008_auth.py | 44 ++++++++++++------------- tests/test_020_bulkcopy_auth_cleanup.py | 2 +- 6 files changed, 63 insertions(+), 50 deletions(-) diff --git a/mssql_python/connection.py b/mssql_python/connection.py index 8ecba7df6..be6fc9a57 100644 --- a/mssql_python/connection.py +++ b/mssql_python/connection.py @@ -246,7 +246,7 @@ def __init__( attrs_before: Optional[Dict[int, Union[int, str, bytes]]] = None, timeout: int = 0, native_uuid: Optional[bool] = None, - credential: Optional[object] = None, + token_provider: Optional[object] = None, **kwargs: Any, ) -> None: """ @@ -335,29 +335,29 @@ def __init__( # fresh token; re-parsing self.connection_str at that point would miss # them because UID is already gone. self._credential_kwargs: Optional[Dict[str, str]] = None - # User-supplied credential object for custom Entra ID authentication. + # User-supplied token provider for custom Entra ID authentication. # Stored so bulk copy can call .get_token() for a fresh JWT later. - self._custom_credential = None + self._token_provider = None - # Custom credential= parameter — takes priority, mutually exclusive + # Custom token_provider= parameter — takes priority, mutually exclusive # with Authentication= in the connection string. - if credential is not None: + if token_provider is not None: if _KEY_AUTHENTICATION in parsed_params: raise ValueError( - "Cannot specify both 'credential' parameter and " + "Cannot specify both 'token_provider' parameter and " "'Authentication' in the connection string. " "Use one or the other." ) - if not callable(getattr(credential, "get_token", None)): + if not callable(getattr(token_provider, "get_token", None)): raise TypeError( - f"credential must have a .get_token() method. " - f"Got {type(credential).__name__}." + f"token_provider must have a .get_token() method. " + f"Got {type(token_provider).__name__}." ) from mssql_python.auth import acquire_token_from_credential - token = acquire_token_from_credential(credential) + token = acquire_token_from_credential(token_provider) self._attrs_before[ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value] = token - self._custom_credential = credential + self._token_provider = token_provider # Strip sensitive params (UID/PWD/Trusted_Connection) since # access-token auth is used — same as the Authentication= path. sanitized = remove_sensitive_params(parsed_params) diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index c0b3e02e6..ea47d851e 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -2961,12 +2961,12 @@ def bulkcopy( pycore_context = connstr_to_pycore_params(params) # Token acquisition — only thing cursor must handle (needs azure-identity SDK) - if self.connection._custom_credential is not None: + if self.connection._token_provider is not None: # User-supplied credential — use it directly for a fresh token. from mssql_python.auth import acquire_raw_token_from_credential try: - raw_token = acquire_raw_token_from_credential(self.connection._custom_credential) + raw_token = acquire_raw_token_from_credential(self.connection._token_provider) except RuntimeError as e: raise RuntimeError( f"Bulk copy failed: unable to acquire token " f"from custom credential: {e}" @@ -2976,7 +2976,7 @@ def bulkcopy( pycore_context.pop(key, None) logger.debug( "Bulk copy: acquired fresh token from custom credential (%s)", - type(self.connection._custom_credential).__name__, + type(self.connection._token_provider).__name__, ) elif self.connection._auth_type: # Fresh token acquisition for mssql-py-core connection. credential diff --git a/mssql_python/db_connection.py b/mssql_python/db_connection.py index 1688a56ed..894440009 100644 --- a/mssql_python/db_connection.py +++ b/mssql_python/db_connection.py @@ -15,7 +15,7 @@ def connect( attrs_before: Optional[Dict[int, Union[int, str, bytes]]] = None, timeout: int = 0, native_uuid: Optional[bool] = None, - credential: Optional[object] = None, + token_provider: Optional[object] = None, **kwargs: Any, ) -> Connection: """ @@ -36,16 +36,29 @@ def connect( This per-connection override is useful for migration from pyodbc: connections that need string UUIDs can pass native_uuid=False, while the default (True) returns native uuid.UUID objects. - credential (object, optional): An Azure Identity credential object (or any object with a - ``.get_token(scope)`` method) used for Entra ID authentication. When provided, the - driver calls ``credential.get_token()`` to acquire a token instead of using the - built-in credential map. Cannot be combined with ``Authentication=`` in the - connection string. + token_provider (object, optional): A token provider for Microsoft Entra ID + authentication. This must be any object with a ``.get_token(scope)`` method that + returns an object with a ``.token`` attribute containing a raw JWT string — for + example, any ``azure-identity`` credential class such as + ``DefaultAzureCredential``, ``AzureCliCredential``, ``ManagedIdentityCredential``, + ``CertificateCredential``, etc. - For environment-portable code, prefer ``Authentication=ActiveDirectoryDefault`` in - the connection string — ``DefaultAzureCredential`` automatically picks the right - credential per environment. Use ``credential=`` only when you need explicit control - (e.g., excluding specific providers or using a credential not in the built-in map). + When provided, the driver calls ``token_provider.get_token()`` to acquire an + access token for SQL Server, bypassing the built-in credential map. + Cannot be combined with ``Authentication=`` in the connection string. + + For environment-portable code, prefer ``Authentication=ActiveDirectoryDefault`` + in the connection string — ``DefaultAzureCredential`` automatically picks the + right credential per environment (CLI on dev, Managed Identity in prod). + Use ``token_provider=`` only when you need explicit control over token + acquisition (e.g., excluding specific providers, using a credential not in + the built-in map, or passing custom options to the credential constructor). + + Example:: + + from azure.identity import AzureCliCredential + conn = mssql_python.connect("Server=s;Database=d", + token_provider=AzureCliCredential()) Keyword Args: **kwargs: Additional key/value pairs for the connection string. Below attributes are not implemented in the internal driver: @@ -69,7 +82,7 @@ def connect( attrs_before=attrs_before, timeout=timeout, native_uuid=native_uuid, - credential=credential, + token_provider=token_provider, **kwargs, ) return conn diff --git a/mssql_python/mssql_python.pyi b/mssql_python/mssql_python.pyi index e419b19bb..05aeec499 100644 --- a/mssql_python/mssql_python.pyi +++ b/mssql_python/mssql_python.pyi @@ -248,7 +248,7 @@ class Connection: attrs_before: Optional[Dict[int, Union[int, str, bytes]]] = None, timeout: int = 0, native_uuid: Optional[bool] = None, - credential: Optional[object] = None, + token_provider: Optional[object] = None, **kwargs: Any, ) -> None: ... @@ -292,7 +292,7 @@ def connect( attrs_before: Optional[Dict[int, Union[int, str, bytes]]] = None, timeout: int = 0, native_uuid: Optional[bool] = None, - credential: Optional[object] = None, + token_provider: Optional[object] = None, **kwargs: Any, ) -> Connection: ... diff --git a/tests/test_008_auth.py b/tests/test_008_auth.py index fe6937723..4dd13569b 100644 --- a/tests/test_008_auth.py +++ b/tests/test_008_auth.py @@ -524,7 +524,7 @@ def test_bulkcopy_path_preserves_user_assigned_msi_client_id(self): mock_conn.connection_str = "Server=tcp:test.database.windows.net;Database=testdb;" mock_conn._auth_type = "msi" mock_conn._credential_kwargs = {"client_id": client_id} - mock_conn._custom_credential = None + mock_conn._token_provider = None mock_conn._is_connected = True cursor = Cursor.__new__(Cursor) @@ -975,7 +975,7 @@ def test_token_output_correct_on_cache_miss_and_hit(self): assert "default" in _credential_cache -# ── Custom credential= parameter tests ── +# ── Custom token_provider= parameter tests ── class TestAcquireTokenFromCredential: @@ -1018,18 +1018,18 @@ def test_credential_raises_exception(self): class TestCustomCredentialConnect: - """Tests for the credential= parameter on connect().""" + """Tests for the token_provider= parameter on connect().""" @patch("mssql_python.connection.ddbc_bindings.Connection") def test_credential_happy_path(self, mock_ddbc_conn): - """credential= acquires token and sets attrs_before.""" + """token_provider= acquires token and sets attrs_before.""" mock_ddbc_conn.return_value = MagicMock() mock_cred = MagicMock() mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) from mssql_python import connect - conn = connect("Server=test;Database=testdb", credential=mock_cred) - assert conn._custom_credential is mock_cred + conn = connect("Server=test;Database=testdb", token_provider=mock_cred) + assert conn._token_provider is mock_cred assert ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value in conn._attrs_before # Existing auth_type should be None (no Authentication= in conn str) assert conn._auth_type is None @@ -1037,7 +1037,7 @@ def test_credential_happy_path(self, mock_ddbc_conn): @patch("mssql_python.connection.ddbc_bindings.Connection") def test_credential_plus_authentication_raises_valueerror(self, mock_ddbc_conn): - """credential= + Authentication= raises ValueError.""" + """token_provider= + Authentication= raises ValueError.""" mock_ddbc_conn.return_value = MagicMock() mock_cred = MagicMock() mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) @@ -1046,12 +1046,12 @@ def test_credential_plus_authentication_raises_valueerror(self, mock_ddbc_conn): with pytest.raises(ValueError, match="Cannot specify both"): connect( "Server=test;Database=testdb;Authentication=ActiveDirectoryDefault", - credential=mock_cred, + token_provider=mock_cred, ) @patch("mssql_python.connection.ddbc_bindings.Connection") def test_credential_plus_authentication_via_kwargs_raises_valueerror(self, mock_ddbc_conn): - """credential= + Authentication via kwargs raises ValueError.""" + """token_provider= + Authentication via kwargs raises ValueError.""" mock_ddbc_conn.return_value = MagicMock() mock_cred = MagicMock() mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) @@ -1060,7 +1060,7 @@ def test_credential_plus_authentication_via_kwargs_raises_valueerror(self, mock_ with pytest.raises(ValueError, match="Cannot specify both"): connect( "Server=test;Database=testdb", - credential=mock_cred, + token_provider=mock_cred, Authentication="ActiveDirectoryDefault", ) @@ -1070,23 +1070,23 @@ def test_credential_without_get_token_raises_typeerror(self, mock_ddbc_conn): mock_ddbc_conn.return_value = MagicMock() from mssql_python import connect - with pytest.raises(TypeError, match="credential must have a .get_token"): - connect("Server=test;Database=testdb", credential="not_a_credential") + with pytest.raises(TypeError, match="token_provider must have a .get_token"): + connect("Server=test;Database=testdb", token_provider="not_a_credential") @patch("mssql_python.connection.ddbc_bindings.Connection") def test_credential_none_uses_existing_flow(self, mock_ddbc_conn): - """credential=None (default) uses existing auth flow, no change.""" + """token_provider=None (default) uses existing auth flow, no change.""" mock_ddbc_conn.return_value = MagicMock() from mssql_python import connect conn = connect("Server=test;Database=testdb;Authentication=ActiveDirectoryDefault") - assert conn._custom_credential is None + assert conn._token_provider is None assert conn._auth_type == "default" conn.close() @patch("mssql_python.connection.ddbc_bindings.Connection") def test_credential_with_non_auth_attrs_before(self, mock_ddbc_conn): - """credential= works alongside non-auth attrs_before.""" + """token_provider= works alongside non-auth attrs_before.""" mock_ddbc_conn.return_value = MagicMock() mock_cred = MagicMock() mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) @@ -1095,7 +1095,7 @@ def test_credential_with_non_auth_attrs_before(self, mock_ddbc_conn): login_timeout_attr = 113 # SQL_ATTR_LOGIN_TIMEOUT conn = connect( "Server=test;Database=testdb", - credential=mock_cred, + token_provider=mock_cred, attrs_before={login_timeout_attr: 30}, ) assert conn._attrs_before[login_timeout_attr] == 30 @@ -1111,7 +1111,7 @@ def test_credential_get_token_failure_raises_runtime_error(self, mock_ddbc_conn) from mssql_python import connect with pytest.raises(RuntimeError, match="Failed to acquire token from credential"): - connect("Server=test;Database=testdb", credential=mock_cred) + connect("Server=test;Database=testdb", token_provider=mock_cred) @patch("mssql_python.connection.ddbc_bindings.Connection") def test_credential_with_non_callable_get_token_raises_typeerror(self, mock_ddbc_conn): @@ -1122,8 +1122,8 @@ def test_credential_with_non_callable_get_token_raises_typeerror(self, mock_ddbc class BadCredential: get_token = "not_a_method" - with pytest.raises(TypeError, match="credential must have a .get_token"): - connect("Server=test;Database=testdb", credential=BadCredential()) + with pytest.raises(TypeError, match="token_provider must have a .get_token"): + connect("Server=test;Database=testdb", token_provider=BadCredential()) @patch("mssql_python.connection.ddbc_bindings.Connection") def test_multiple_connections_share_same_credential(self, mock_ddbc_conn): @@ -1133,9 +1133,9 @@ def test_multiple_connections_share_same_credential(self, mock_ddbc_conn): mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) from mssql_python import connect - conn1 = connect("Server=test1;Database=db1", credential=mock_cred) - conn2 = connect("Server=test2;Database=db2", credential=mock_cred) - assert conn1._custom_credential is conn2._custom_credential + conn1 = connect("Server=test1;Database=db1", token_provider=mock_cred) + conn2 = connect("Server=test2;Database=db2", token_provider=mock_cred) + assert conn1._token_provider is conn2._token_provider assert mock_cred.get_token.call_count == 2 conn1.close() conn2.close() diff --git a/tests/test_020_bulkcopy_auth_cleanup.py b/tests/test_020_bulkcopy_auth_cleanup.py index 404faca91..4863c5e0e 100644 --- a/tests/test_020_bulkcopy_auth_cleanup.py +++ b/tests/test_020_bulkcopy_auth_cleanup.py @@ -22,7 +22,7 @@ def _make_cursor(connection_str, auth_type): mock_conn = MagicMock() mock_conn.connection_str = connection_str mock_conn._auth_type = auth_type - mock_conn._custom_credential = None + mock_conn._token_provider = None mock_conn._is_connected = True cursor = Cursor.__new__(Cursor)