diff --git a/CHANGELOG.md b/CHANGELOG.md index d473436..7e7ab1f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,7 @@ # Changelog +## v1.12.0 6/15/26 +- Add Azure client + ## v1.11.1 5/20/26 - Reset Snowflake connection to None when it's closed diff --git a/README.md b/README.md index d3a1bab..87f5139 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,7 @@ This package contains common Python utility classes and functions. * Encoding and decoding records using a given Avro schema * Retrieving secrets from AWS Secrets Manager * Downloading files from a remote SSH SFTP server +* Connecting to and querying an Azure SQL database * Connecting to and querying a MySQL database * Connecting to and querying a PostgreSQL database * Connecting to and querying Redshift diff --git a/pyproject.toml b/pyproject.toml index 3ae1a7a..f745dfc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "nypl_py_utils" -version = "1.11.1" +version = "1.12.0" authors = [ { name="Aaron Friedman", email="aaronfriedman@nypl.org" }, ] @@ -28,6 +28,11 @@ avro-client = [ "fastavro==1.12.2", "requests==2.34.0" ] +azure-client = [ + "nypl_py_utils[log-helper]", + "mssql-python==1.9.0", + "pandas==3.0.3" +] cloudlibrary-client = [ "nypl_py_utils[log-helper]", "requests==2.34.0" @@ -91,13 +96,13 @@ obfuscation-helper = [ ] patron-data-helper = [ "nypl_py_utils[postgresql-client,redshift-client,log-helper]", - "pandas==3.0.2" + "pandas==3.0.3" ] research-catalog-identifier-helper = [ "requests==2.34.0" ] development = [ - "nypl_py_utils[avro-client,cloudlibrary-client,kinesis-client,kms-client,mysql-client,oauth2-api-client,postgresql-client,redshift-client,s3-client,secrets-manager-client,sftp-client,snowflake-client,config-helper,log-helper,obfuscation-helper,patron-data-helper,research-catalog-identifier-helper]", + "nypl_py_utils[avro-client,azure-client,cloudlibrary-client,kinesis-client,kms-client,mysql-client,oauth2-api-client,postgresql-client,redshift-client,s3-client,secrets-manager-client,sftp-client,snowflake-client,config-helper,log-helper,obfuscation-helper,patron-data-helper,research-catalog-identifier-helper]", "flake8==7.3.0", "freezegun==1.5.5", "mock==5.2.0", diff --git a/src/nypl_py_utils/classes/azure_client.py b/src/nypl_py_utils/classes/azure_client.py new file mode 100644 index 0000000..fc1262a --- /dev/null +++ b/src/nypl_py_utils/classes/azure_client.py @@ -0,0 +1,132 @@ +import mssql_python +import pandas as pd +import time + +from contextlib import closing +from nypl_py_utils.functions.log_helper import create_log + + +class AzureClient: + """Class for managing connections to a Microsoft Azure SQL database""" + + def __init__(self, server, database, user, password): + self.logger = create_log("azure_client") + self.server = server + self.database = database + self.user = user + self.password = password + self.conn = None + + def connect(self, retry_count=0, backoff_factor=5): + """ + Connects to an Azure database using the given credentials. + + Parameters + ---------- + retry_count: int, optional + The number of times to retry connecting before throwing an error. + By default no retry occurs. + backoff_factor: int, optional + The backoff factor when retrying. The amount of time to wait before + retrying is backoff_factor ** number_of_retries_made. + """ + self.logger.info(f"Connecting to {self.database} database...") + + # Close any existing connection first so reconnecting doesn't leak it + self.close_connection() + + attempt_count = 0 + while attempt_count <= retry_count: + try: + try: + connection_string = ( + f"Server={self.server};" + f"Database={self.database};" + f"UID={self.user};" + f"PWD={self.password};" + f"Encrypt=yes;" + ) + self.conn = mssql_python.connect( + connection_str=connection_string, + timeout=30, + ) + self.conn.setencoding(encoding="utf-8") + self.conn.setdecoding( + sqltype=mssql_python.SQL_WCHAR, encoding="utf-8" + ) + return + except (mssql_python.InterfaceError, + mssql_python.OperationalError): + if attempt_count < retry_count: + self.logger.info("Failed to connect — retrying") + time.sleep(backoff_factor**attempt_count) + attempt_count += 1 + else: + raise + except Exception as e: + msg = f"Error connecting to {self.database} database: {e}" + self.logger.error(msg) + raise AzureClientError(msg) from e + + def execute_query(self, query: str, params=None, dataframe=False): + """ + Executes an arbitrary SQL read query against the database. + + Parameters + ---------- + query: str + The query to execute, assumed to be a read query + params: tuple or list, optional + The parameters to pass into the query, if any. Defaults to None. + dataframe: bool, optional + Whether the data will be returned as a pandas DataFrame. Defaults + to False, which means the data is returned as a list of tuples. + + Returns + ------- + None or sequence + A list of tuples or a pandas DataFrame (based on the `dataframe` + input) + """ + if not self.conn: + msg = "No active database connection" + self.logger.error(msg) + raise AzureClientError(msg) + + try: + # Automatically closes cursor when done, even if there's an error + with closing(self.conn.cursor()) as cursor: + if params is not None: + cursor.execute(query, params) + else: + cursor.execute(query) + if dataframe: + columns = [col[0] for col in cursor.description] + return pd.DataFrame.from_records( + cursor.fetchall(), columns=columns) + return cursor.fetchall() + except Exception as e: + self.close_connection() + msg = f"Error executing {self.database} query '{query}': {e}" + self.logger.error(msg) + raise AzureClientError(msg) from e + + def close_connection(self): + """Rolls back any open transaction and closes the connection""" + if self.conn: + # A rollback failure is logged but doesn't prevent the close + try: + self.conn.rollback() + except Exception: + self.logger.error("Error rolling back open transaction") + self.conn.close() + self.conn = None + self.logger.info(f"Connection to {self.database} closed.") + + +class AzureClientError(Exception): + """Custom exception for AzureClient errors""" + + def __init__(self, message=None): + super().__init__(message) + self.message = message diff --git a/tests/test_azure_client.py b/tests/test_azure_client.py new file mode 100644 index 0000000..64998e5 --- /dev/null +++ b/tests/test_azure_client.py @@ -0,0 +1,207 @@ +import mssql_python +import pandas as pd +import pytest + +from nypl_py_utils.classes.azure_client import AzureClient, AzureClientError +from pandas.testing import assert_frame_equal + + +class TestAzureClient: + @pytest.fixture + def mock_azure_conn(self, mocker): + return mocker.patch( + "nypl_py_utils.classes.azure_client.mssql_python.connect") + + @pytest.fixture + def test_instance(self): + return AzureClient( + server="test_server", + database="test_database", + user="test_user", + password="test_password", + ) + + def test_connect_success(self, mock_azure_conn, test_instance): + test_instance.connect() + + assert test_instance.conn == mock_azure_conn.return_value + mock_azure_conn.return_value.setencoding.assert_called_once_with( + encoding="utf-8" + ) + mock_azure_conn.return_value.setdecoding.assert_called_once_with( + sqltype=mssql_python.SQL_WCHAR, encoding="utf-8" + ) + # credentials are interpolated into connection string + connection_str = mock_azure_conn.call_args.kwargs["connection_str"] + assert connection_str == ( + "Server=test_server;Database=test_database;" + "UID=test_user;PWD=test_password;Encrypt=yes;" + ) + + def test_connect_retry_success( + self, mock_azure_conn, test_instance, mocker, caplog + ): + mock_sleep = mocker.patch( + "nypl_py_utils.classes.azure_client.time.sleep") + success_conn = mocker.MagicMock() + mock_azure_conn.side_effect = [ + mssql_python.OperationalError("busy", "ddbc busy"), + success_conn, + ] + with caplog.at_level("ERROR"): + test_instance.connect(retry_count=2, backoff_factor=2) + + assert test_instance.conn == success_conn + assert mock_azure_conn.call_count == 2 + mock_sleep.assert_called_once_with(2**0) + assert caplog.text == "" + + def test_connect_retry_fail( + self, mock_azure_conn, test_instance, mocker, caplog + ): + mocker.patch("nypl_py_utils.classes.azure_client.time.sleep") + mock_azure_conn.side_effect = mssql_python.OperationalError( + "still busy", "ddbc busy" + ) + + with pytest.raises(AzureClientError): + test_instance.connect(retry_count=2, backoff_factor=2) + + # retry_count=2 -> three attempts total before giving up + assert mock_azure_conn.call_count == 3 + assert "Error connecting to test_database database" in caplog.text + + def test_connect_unexpected_error( + self, mock_azure_conn, test_instance, caplog + ): + mock_azure_conn.side_effect = ValueError("uh oh") + + with pytest.raises(AzureClientError): + test_instance.connect(retry_count=3) + + assert mock_azure_conn.call_count == 1 + assert ( + "Error connecting to test_database database: uh oh" in caplog.text + ) + + def test_execute_query_no_params_success( + self, mock_azure_conn, test_instance, mocker + ): + test_instance.connect() + mock_cursor = mocker.MagicMock() + mock_cursor.fetchall.return_value = [(1, 2), (3, 4)] + test_instance.conn.cursor.return_value = mock_cursor + + result = test_instance.execute_query("SELECT * FROM t") + + assert result == [(1, 2), (3, 4)] + mock_cursor.execute.assert_called_once_with("SELECT * FROM t") + mock_cursor.close.assert_called_once() + + def test_execute_query_with_params_success( + self, mock_azure_conn, test_instance, mocker + ): + test_instance.connect() + mock_cursor = mocker.MagicMock() + mock_cursor.fetchall.return_value = [] + test_instance.conn.cursor.return_value = mock_cursor + + test_instance.execute_query("SELECT ?", params=("a",)) + + mock_cursor.execute.assert_called_once_with("SELECT ?", ("a",)) + + def test_execute_query_no_params_returns_dataframe_success( + self, mock_azure_conn, test_instance, mocker + ): + test_instance.connect() + mock_cursor = mocker.MagicMock() + mock_cursor.description = [("col1",), ("col2",)] + mock_cursor.fetchall.return_value = [(1, 2), (3, 4)] + test_instance.conn.cursor.return_value = mock_cursor + + result = test_instance.execute_query("SELECT * FROM t", dataframe=True) + + expected = pd.DataFrame({"col1": [1, 3], "col2": [2, 4]}) + assert_frame_equal(result, expected) + + def test_execute_query_with_params_returns_dataframe_success( + self, mock_azure_conn, test_instance, mocker + ): + test_instance.connect() + mock_cursor = mocker.MagicMock() + mock_cursor.description = [("col1",), ("col2",)] + mock_cursor.fetchall.return_value = [(1, 2), (3, 4)] + test_instance.conn.cursor.return_value = mock_cursor + expected = pd.DataFrame({"col1": [1, 3], "col2": [2, 4]}) + + result = test_instance.execute_query( + "SELECT * FROM t WHERE col1 = ?", params=("a",), dataframe=True + ) + + mock_cursor.execute.assert_called_once_with( + "SELECT * FROM t WHERE col1 = ?", ("a",) + ) + assert_frame_equal(result, expected) + + def test_execute_query_fail( + self, mock_azure_conn, test_instance, mocker, caplog + ): + test_instance.connect() + mock_conn = test_instance.conn + mock_cursor = mocker.MagicMock() + mock_cursor.execute.side_effect = Exception("bad query") + mock_conn.cursor.return_value = mock_cursor + + with pytest.raises(AzureClientError): + test_instance.execute_query("SELECT bad") + + mock_conn.rollback.assert_called_once() + mock_conn.close.assert_called_once() + assert test_instance.conn is None + mock_cursor.close.assert_called_once() + assert ( + "Error executing test_database query 'SELECT bad'" in caplog.text + ) + + def test_execute_query_fail_with_rollback_error( + self, mock_azure_conn, test_instance, mocker, caplog + ): + test_instance.connect() + mock_conn = test_instance.conn + mock_cursor = mocker.MagicMock() + mock_cursor.execute.side_effect = Exception("bad query") + mock_conn.cursor.return_value = mock_cursor + mock_conn.rollback.side_effect = Exception("rollback issue") + + with pytest.raises(AzureClientError): + test_instance.execute_query("SELECT bad") + + mock_conn.close.assert_called_once() + assert test_instance.conn is None + assert "Error rolling back open transaction" in caplog.text + assert ( + "Error executing test_database query 'SELECT bad'" in caplog.text + ) + + def test_execute_query_without_connection(self, test_instance, caplog): + assert test_instance.conn is None + + with pytest.raises(AzureClientError): + test_instance.execute_query("SELECT 1") + + assert "No active database connection" in caplog.text + + def test_close_connection_success(self, mock_azure_conn, test_instance): + test_instance.connect() + mock_conn = test_instance.conn + + test_instance.close_connection() + + mock_conn.close.assert_called_once() + assert test_instance.conn is None + + def test_close_connection_when_already_closed(self, test_instance): + # no connection -> nothing to close, so nothing happens & no error + assert test_instance.conn is None + test_instance.close_connection() + assert test_instance.conn is None