Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions mssql_python/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,3 +238,68 @@ def extract_auth_type(parsed_params: Dict[str, str]) -> Optional[str]:
"""
auth_value = parsed_params.get(_KEY_AUTHENTICATION, "").strip().lower()
return _AUTH_TYPE_MAP.get(auth_value)


def _get_token_from_credential(credential: object) -> str:
"""Internal: call credential.get_token() and return the raw JWT string.

Centralises the token-acquisition + error-wrapping logic that both
:func:`acquire_token_from_credential` and
:func:`acquire_raw_token_from_credential` need.

Raises:
RuntimeError: If token acquisition fails.
"""
try:
raw_token = credential.get_token("https://database.windows.net/.default").token
logger.info(
"_get_token_from_credential: Token acquired from %s - length=%d chars",
type(credential).__name__,
len(raw_token),
)
return raw_token
except Exception as e:
logger.error(
"_get_token_from_credential: Failed - credential=%s, error=%s",
type(credential).__name__,
str(e),
)
raise RuntimeError(
f"Failed to acquire token from credential " f"({type(credential).__name__}): {e}"
) from e


def acquire_token_from_credential(credential: object) -> bytes:
"""Acquire an ODBC token struct from a user-supplied credential object.

The credential must follow the Azure ``TokenCredential`` protocol — i.e.
have a ``.get_token(scope)`` method returning an object with a ``.token``
attribute (a raw JWT string).

Args:
credential: Any object with a ``.get_token(scope)`` method.

Returns:
bytes: ODBC-compatible token struct for ``SQL_COPT_SS_ACCESS_TOKEN``.

Raises:
RuntimeError: If token acquisition fails.
"""
return AADAuth.get_token_struct(_get_token_from_credential(credential))


def acquire_raw_token_from_credential(credential: object) -> str:
"""Acquire a raw JWT string from a user-supplied credential object.

Used by bulk copy, which needs the raw JWT rather than the ODBC struct.

Args:
credential: Any object with a ``.get_token(scope)`` method.

Returns:
str: Raw JWT token string.

Raises:
RuntimeError: If token acquisition fails.
"""
return _get_token_from_credential(credential)
30 changes: 29 additions & 1 deletion mssql_python/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,7 @@ def __init__(
attrs_before: Optional[Dict[int, Union[int, str, bytes]]] = None,
timeout: int = 0,
native_uuid: Optional[bool] = None,
token_provider: Optional[object] = None,
**kwargs: Any,
) -> None:
"""
Expand Down Expand Up @@ -334,10 +335,37 @@ def __init__(
# fresh token; re-parsing self.connection_str at that point would miss
# them because UID is already gone.
self._credential_kwargs: Optional[Dict[str, str]] = None
# User-supplied token provider for custom Entra ID authentication.
# Stored so bulk copy can call .get_token() for a fresh JWT later.
self._token_provider = None

# Custom token_provider= parameter — takes priority, mutually exclusive
# with Authentication= in the connection string.
if token_provider is not None:
if _KEY_AUTHENTICATION in parsed_params:
raise ValueError(
"Cannot specify both 'token_provider' parameter and "
"'Authentication' in the connection string. "
"Use one or the other."
)
if not callable(getattr(token_provider, "get_token", None)):
raise TypeError(
f"token_provider must have a .get_token() method. "
f"Got {type(token_provider).__name__}."
)
from mssql_python.auth import acquire_token_from_credential

token = acquire_token_from_credential(token_provider)
self._attrs_before[ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value] = token
self._token_provider = token_provider
# Strip sensitive params (UID/PWD/Trusted_Connection) since
# access-token auth is used — same as the Authentication= path.
sanitized = remove_sensitive_params(parsed_params)
self.connection_str = _ConnectionStringBuilder(sanitized).build()

# Handle Entra ID authentication if specified.
# The parsed dict is used directly — no re-parsing of the connection string.
if _KEY_AUTHENTICATION in parsed_params:
elif _KEY_AUTHENTICATION in parsed_params:
auth_type = process_auth_parameters(parsed_params)

if auth_type:
Expand Down
30 changes: 16 additions & 14 deletions mssql_python/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,16 @@ def get_attribute_set_timing(attribute):

_CONNECTION_STRING_DRIVER_KEY = "Driver"
_CONNECTION_STRING_APP_KEY = "APP"
_CONNECTION_STRING_AUTH_KEY = "Authentication"
_CONNECTION_STRING_UID_KEY = "UID"
_CONNECTION_STRING_PWD_KEY = "PWD"
_CONNECTION_STRING_TRUSTED_CONNECTION_KEY = "Trusted_Connection"

# Aliases used by auth.py / connection.py — kept for readability.
_KEY_AUTHENTICATION = _CONNECTION_STRING_AUTH_KEY
_KEY_UID = _CONNECTION_STRING_UID_KEY
_KEY_PWD = _CONNECTION_STRING_PWD_KEY
_KEY_TRUSTED_CONNECTION = _CONNECTION_STRING_TRUSTED_CONNECTION_KEY

# Reserved connection string parameters that are controlled by the driver
# and cannot be set by users
Expand All @@ -486,16 +496,16 @@ def get_attribute_set_timing(attribute):
"address": "Server",
"addr": "Server",
# Authentication
"uid": "UID",
"pwd": "PWD",
"authentication": "Authentication",
"trusted_connection": "Trusted_Connection",
"uid": _CONNECTION_STRING_UID_KEY,
"pwd": _CONNECTION_STRING_PWD_KEY,
"authentication": _CONNECTION_STRING_AUTH_KEY,
"trusted_connection": _CONNECTION_STRING_TRUSTED_CONNECTION_KEY,
# Database
"database": "Database",
# Driver (always controlled by mssql-python)
"driver": "Driver",
"driver": _CONNECTION_STRING_DRIVER_KEY,
# Application name (always controlled by mssql-python)
"app": "APP",
"app": _CONNECTION_STRING_APP_KEY,
# Encryption and Security
"encrypt": "Encrypt",
"trustservercertificate": "TrustServerCertificate",
Expand All @@ -519,14 +529,6 @@ def get_attribute_set_timing(attribute):
"packetsize": "PacketSize",
}

# Canonical normalized key names produced by _ConnectionStringParser._normalize_params.
# Consumer code should reference these instead of hard-coding raw strings so that
# a rename in _ALLOWED_CONNECTION_STRING_PARAMS is caught at import time.
_KEY_AUTHENTICATION = "Authentication"
_KEY_UID = "UID"
_KEY_PWD = "PWD"
_KEY_TRUSTED_CONNECTION = "Trusted_Connection"


def get_info_constants() -> Dict[str, int]:
"""
Expand Down
19 changes: 18 additions & 1 deletion mssql_python/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2961,7 +2961,24 @@ def bulkcopy(
pycore_context = connstr_to_pycore_params(params)

# Token acquisition — only thing cursor must handle (needs azure-identity SDK)
if self.connection._auth_type:
if self.connection._token_provider is not None:
# User-supplied credential — use it directly for a fresh token.
from mssql_python.auth import acquire_raw_token_from_credential

try:
raw_token = acquire_raw_token_from_credential(self.connection._token_provider)
except RuntimeError as e:
raise RuntimeError(
f"Bulk copy failed: unable to acquire token " f"from custom credential: {e}"
) from e
pycore_context["access_token"] = raw_token
for key in ("authentication", "user_name", "password"):
pycore_context.pop(key, None)
logger.debug(
"Bulk copy: acquired fresh token from custom credential (%s)",
type(self.connection._token_provider).__name__,
)
elif self.connection._auth_type:
# Fresh token acquisition for mssql-py-core connection. credential
# kwargs (e.g. user-assigned MSI client_id) were captured by
# Connection.__init__ before remove_sensitive_params stripped UID
Expand Down
25 changes: 25 additions & 0 deletions mssql_python/db_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def connect(
attrs_before: Optional[Dict[int, Union[int, str, bytes]]] = None,
timeout: int = 0,
native_uuid: Optional[bool] = None,
token_provider: Optional[object] = None,
**kwargs: Any,
) -> Connection:
"""
Expand All @@ -35,6 +36,29 @@ def connect(
This per-connection override is useful for migration from pyodbc:
connections that need string UUIDs can pass native_uuid=False, while the default (True)
returns native uuid.UUID objects.
token_provider (object, optional): A token provider for Microsoft Entra ID
authentication. This must be any object with a ``.get_token(scope)`` method that
returns an object with a ``.token`` attribute containing a raw JWT string — for
example, any ``azure-identity`` credential class such as
``DefaultAzureCredential``, ``AzureCliCredential``, ``ManagedIdentityCredential``,
``CertificateCredential``, etc.

When provided, the driver calls ``token_provider.get_token()`` to acquire an
access token for SQL Server, bypassing the built-in credential map.
Cannot be combined with ``Authentication=`` in the connection string.

For environment-portable code, prefer ``Authentication=ActiveDirectoryDefault``
in the connection string — ``DefaultAzureCredential`` automatically picks the
right credential per environment (CLI on dev, Managed Identity in prod).
Use ``token_provider=`` only when you need explicit control over token
acquisition (e.g., excluding specific providers, using a credential not in
the built-in map, or passing custom options to the credential constructor).

Example::

from azure.identity import AzureCliCredential
conn = mssql_python.connect("Server=s;Database=d",
token_provider=AzureCliCredential())
Keyword Args:
**kwargs: Additional key/value pairs for the connection string.
Below attributes are not implemented in the internal driver:
Expand All @@ -58,6 +82,7 @@ def connect(
attrs_before=attrs_before,
timeout=timeout,
native_uuid=native_uuid,
token_provider=token_provider,
**kwargs,
)
return conn
2 changes: 2 additions & 0 deletions mssql_python/mssql_python.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ class Connection:
attrs_before: Optional[Dict[int, Union[int, str, bytes]]] = None,
timeout: int = 0,
native_uuid: Optional[bool] = None,
token_provider: Optional[object] = None,
**kwargs: Any,
) -> None: ...

Expand Down Expand Up @@ -291,6 +292,7 @@ def connect(
attrs_before: Optional[Dict[int, Union[int, str, bytes]]] = None,
timeout: int = 0,
native_uuid: Optional[bool] = None,
token_provider: Optional[object] = None,
**kwargs: Any,
) -> Connection: ...

Expand Down
Loading
Loading