diff --git a/cassandra/cluster.py b/cassandra/cluster.py
index 1181c6f686..9da12f49ba 100644
--- a/cassandra/cluster.py
+++ b/cassandra/cluster.py
@@ -3925,23 +3925,32 @@ def _try_connect(self, endpoint):
             sel_peers = self._get_peers_query(self.PeersQueryType.PEERS, connection)
             sel_local = self._SELECT_LOCAL if self._token_meta_enabled else self._SELECT_LOCAL_NO_TOKENS
             peers_query = QueryMessage(query=maybe_add_timeout_to_query(sel_peers, self._metadata_request_timeout),
-                                       consistency_level=ConsistencyLevel.ONE)
+                                       consistency_level=ConsistencyLevel.ONE,
+                                       fetch_size=self._schema_meta_page_size)
             local_query = QueryMessage(query=maybe_add_timeout_to_query(sel_local, self._metadata_request_timeout),
-                                       consistency_level=ConsistencyLevel.ONE)
-            (peers_success, peers_result), (local_success, local_result) = connection.wait_for_responses(
-                peers_query, local_query, timeout=self._timeout, fail_on_error=False)
-
-            if not local_success:
-                raise local_result
-
-            if not peers_success:
-                # error with the peers v2 query, fallback to peers v1
-                self._uses_peers_v2 = False
-                sel_peers = self._get_peers_query(self.PeersQueryType.PEERS, connection)
-                peers_query = QueryMessage(query=maybe_add_timeout_to_query(sel_peers, self._metadata_request_timeout),
-                                           consistency_level=ConsistencyLevel.ONE)
-                peers_result = connection.wait_for_response(
-                    peers_query, timeout=self._timeout)
+                                       consistency_level=ConsistencyLevel.ONE,
+                                       fetch_size=self._schema_meta_page_size)
+
+            with ThreadPoolExecutor(max_workers=2) as executor:
+                local_future = executor.submit(
+                    connection.fetch_all_pages, local_query, self._timeout, False)
+                peers_future = executor.submit(
+                    connection.fetch_all_pages, peers_query, self._timeout, False)
+
+                local_success, local_result = local_future.result()
+
+                if not local_success:
+                    raise local_result
+
+                peers_success, peers_result = peers_future.result()
+
+                if not peers_success:
+                    self._uses_peers_v2 = False
+                    sel_peers = self._get_peers_query(self.PeersQueryType.PEERS, connection)
+                    peers_query = QueryMessage(query=maybe_add_timeout_to_query(sel_peers, self._metadata_request_timeout),
+                                               consistency_level=ConsistencyLevel.ONE,
+                                               fetch_size=self._schema_meta_page_size)
+                    peers_result = connection.fetch_all_pages(peers_query, self._timeout)
 
             shared_results = (peers_result, local_result)
             self._refresh_node_list_and_token_map(connection, preloaded_results=shared_results)
@@ -4084,11 +4093,20 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None,
                 log.debug("[control connection] Refreshing node list and token map")
                 sel_local = self._SELECT_LOCAL
             peers_query = QueryMessage(query=maybe_add_timeout_to_query(sel_peers, self._metadata_request_timeout),
-                                       consistency_level=cl)
+                                       consistency_level=cl,
+                                       fetch_size=self._schema_meta_page_size)
             local_query = QueryMessage(query=maybe_add_timeout_to_query(sel_local, self._metadata_request_timeout),
-                                       consistency_level=cl)
-            peers_result, local_result = connection.wait_for_responses(
-                peers_query, local_query, timeout=self._timeout)
+                                       consistency_level=cl,
+                                       fetch_size=self._schema_meta_page_size)
+
+            with ThreadPoolExecutor(max_workers=2) as executor:
+                peers_future = executor.submit(
+                    connection.fetch_all_pages, peers_query, self._timeout)
+                local_future = executor.submit(
+                    connection.fetch_all_pages, local_query, self._timeout)
+
+                peers_result = peers_future.result()
+                local_result = local_future.result()
 
         peers_result = dict_factory(peers_result.column_names, peers_result.parsed_rows)
 
diff --git a/cassandra/connection.py b/cassandra/connection.py
index eae018649b..ee49900231 100644
--- a/cassandra/connection.py
+++ b/cassandra/connection.py
@@ -1287,6 +1287,69 @@ def wait_for_responses(self, *msgs, **kwargs):
             self.defunct(exc)
             raise
 
+    def fetch_all_pages(self, query_msg, timeout, fail_on_error=True):
+        """Fetch all pages for a query, following paging_state until exhausted.
+
+        Runs the given query and, while the latest response carries a
+        paging_state, requests the next page. The rows of every page are
+        combined into a single list on the returned message.
+
+        query_msg.paging_state is overwritten while follow-up pages are
+        requested and restored to its initial value before returning.
+
+        Args:
+            query_msg: QueryMessage to execute (paging_state may be pre-set).
+            timeout: Per-request timeout passed to wait_for_response; it
+                applies to each page, not to the operation as a whole.
+            fail_on_error: If True (default), raises on error. If False,
+                returns (success, result_or_error).
+
+        Returns:
+            When fail_on_error=True: the last page's message, with
+            parsed_rows replaced by the rows of all pages.
+            When fail_on_error=False: (True, message) or
+            (False, Exception) for the first request that failed.
+        """
+        def request_page():
+            # Normalise both wait_for_response() return shapes to
+            # (success, payload) so the paging logic has one code path.
+            response = self.wait_for_response(query_msg, timeout=timeout, fail_on_error=fail_on_error)
+            return response if not fail_on_error else (True, response)
+
+        success, result = request_page()
+        if not success:
+            # Only reachable with fail_on_error=False.
+            return success, result
+
+        if not result or not result.paging_state:
+            return result if fail_on_error else (True, result)
+
+        # Copy rather than alias: the first page's own row list is left
+        # untouched and this does not rely on parsed_rows being a list.
+        all_rows = list(result.parsed_rows or ())
+        last_page = result
+        original_paging_state = query_msg.paging_state
+
+        try:
+            while result and result.paging_state:
+                query_msg.paging_state = result.paging_state
+                success, result = request_page()
+                if not success:
+                    return success, result
+                if result:
+                    # Remember the last real page so an empty response
+                    # cannot leave us without a message to return.
+                    last_page = result
+                    if result.parsed_rows:
+                        all_rows.extend(result.parsed_rows)
+        finally:
+            # Leave the caller's message exactly as it was handed in.
+            query_msg.paging_state = original_paging_state
+
+        last_page.parsed_rows = all_rows
+
+        return last_page if fail_on_error else (True, last_page)
+
     def register_watcher(self, event_type, callback, register_timeout=None):
         """
         Register a callback for a given event type.
diff --git a/tests/unit/test_control_connection.py b/tests/unit/test_control_connection.py
index fd62323f33..e8dc63e58c 100644
--- a/tests/unit/test_control_connection.py
+++ b/tests/unit/test_control_connection.py
@@ -15,10 +15,10 @@
 import unittest
 
 from concurrent.futures import ThreadPoolExecutor
-from unittest.mock import Mock, ANY, call, patch
+from unittest.mock import Mock, ANY, call, patch, MagicMock
 
-from cassandra import OperationTimedOut, SchemaTargetType, SchemaChangeType
-from cassandra.protocol import ResultMessage, RESULT_KIND_ROWS
+from cassandra import OperationTimedOut, SchemaTargetType, SchemaChangeType, ConsistencyLevel
+from cassandra.protocol import ResultMessage, RESULT_KIND_ROWS, QueryMessage
 from cassandra.cluster import ControlConnection, _Scheduler, ProfileManager, EXEC_PROFILE_DEFAULT, ExecutionProfile
 from cassandra.pool import Host
 from cassandra.connection import EndPoint, DefaultEndPoint, DefaultEndPointFactory
@@ -168,6 +168,21 @@ def __init__(self):
         ]
 
         self.wait_for_responses = Mock(return_value=_node_meta_results(self.local_results, self.peer_results))
+        def wait_for_response_side_effect(query_msg, timeout=None, fail_on_error=True):
+            result = ResultMessage(kind=RESULT_KIND_ROWS)
+            if "peers" in query_msg.query.lower():
+                result.column_names = self.peer_results[0]
+                result.parsed_rows = self.peer_results[1]
+            else:
+                result.column_names = self.local_results[0]
+                result.parsed_rows = self.local_results[1]
+            result.paging_state = None
+            return result if fail_on_error else (True, result)
+        self.wait_for_response = Mock(side_effect=wait_for_response_side_effect)
+
+    def fetch_all_pages(self, query_msg, timeout, fail_on_error=True):
+        return self.wait_for_response(query_msg, timeout=timeout, fail_on_error=fail_on_error)
+
 
 
 class FakeTime(object):
@@ -312,6 +327,91 @@ def test_wait_for_schema_agreement_none_timeout(self):
         cc._time = self.time
         assert cc._wait_for_schema_agreement()
 
+    def test_topology_queries_use_paging(self):
+        self.control_connection.refresh_node_list_and_token_map()
+        assert self.connection.wait_for_response.called
+        recorded_calls = self.connection.wait_for_response.call_args_list
+        for recorded_call in recorded_calls:
+            query_msg = recorded_call[0][0]
+            assert isinstance(query_msg, QueryMessage)
+            assert query_msg.fetch_size == self.control_connection._schema_meta_page_size
+
+    def test_topology_queries_fetch_all_pages(self):
+        from cassandra.connection import Connection as RealConnection
+        mock_connection = MagicMock()
+        mock_connection.endpoint = DefaultEndPoint("192.168.1.0")
+        mock_connection.original_endpoint = mock_connection.endpoint
+        mock_connection.fetch_all_pages = RealConnection.fetch_all_pages.__get__(mock_connection, RealConnection)
+        first_page = ResultMessage(kind=RESULT_KIND_ROWS)
+        first_page.column_names = ["rpc_address", "peer", "schema_version", "data_center", "rack", "tokens", "host_id"]
+        first_page.parsed_rows = [["192.168.1.1", "10.0.0.1", "a", "dc1", "rack1", ["1", "101", "201"], "uuid2"]]
+        first_page.paging_state = b"has_more_pages"
+        second_page = ResultMessage(kind=RESULT_KIND_ROWS)
+        second_page.column_names = ["rpc_address", "peer", "schema_version", "data_center", "rack", "tokens", "host_id"]
+        second_page.parsed_rows = [["192.168.1.2", "10.0.0.2", "a", "dc1", "rack1", ["2", "102", "202"], "uuid3"]]
+        second_page.paging_state = None
+        mock_connection.wait_for_response.side_effect = [first_page, second_page]
+        self.control_connection._connection = mock_connection
+        query_msg = QueryMessage(query="SELECT * FROM system.peers",
+                                 consistency_level=ConsistencyLevel.ONE,
+                                 fetch_size=self.control_connection._schema_meta_page_size)
+        result = mock_connection.fetch_all_pages(query_msg, timeout=5)
+        assert len(result.parsed_rows) == 2
+        assert result.parsed_rows[0][0] == "192.168.1.1"
+        assert result.parsed_rows[1][0] == "192.168.1.2"
+        assert result.paging_state is None
+        assert mock_connection.wait_for_response.call_count == 2
+
+    def test_topology_queries_fetch_all_pages_fail_on_error_false(self):
+        from cassandra.connection import Connection as RealConnection
+        mock_connection = MagicMock()
+        mock_connection.endpoint = DefaultEndPoint("192.168.1.0")
+        mock_connection.original_endpoint = mock_connection.endpoint
+        mock_connection.fetch_all_pages = RealConnection.fetch_all_pages.__get__(mock_connection, RealConnection)
+        first_page = ResultMessage(kind=RESULT_KIND_ROWS)
+        first_page.column_names = ["rpc_address", "peer", "schema_version", "data_center", "rack", "tokens", "host_id"]
+        first_page.parsed_rows = [["192.168.1.1", "10.0.0.1", "a", "dc1", "rack1", ["1", "101", "201"], "uuid2"]]
+        first_page.paging_state = b"has_more_pages"
+        second_page = ResultMessage(kind=RESULT_KIND_ROWS)
+        second_page.column_names = ["rpc_address", "peer", "schema_version", "data_center", "rack", "tokens", "host_id"]
+        second_page.parsed_rows = [["192.168.1.2", "10.0.0.2", "a", "dc1", "rack1", ["2", "102", "202"], "uuid3"]]
+        second_page.paging_state = None
+        mock_connection.wait_for_response.side_effect = [
+            (True, first_page),
+            (True, second_page),
+        ]
+        query_msg = QueryMessage(query="SELECT * FROM system.peers",
+                                 consistency_level=ConsistencyLevel.ONE,
+                                 fetch_size=self.control_connection._schema_meta_page_size)
+        success, result = mock_connection.fetch_all_pages(query_msg, timeout=5, fail_on_error=False)
+        assert success
+        assert len(result.parsed_rows) == 2
+        assert result.parsed_rows[0][0] == "192.168.1.1"
+        assert result.parsed_rows[1][0] == "192.168.1.2"
+        assert mock_connection.wait_for_response.call_count == 2
+
+    def test_topology_queries_fetch_all_pages_page_failure(self):
+        from cassandra.connection import Connection as RealConnection
+        mock_connection = MagicMock()
+        mock_connection.endpoint = DefaultEndPoint("192.168.1.0")
+        mock_connection.original_endpoint = mock_connection.endpoint
+        mock_connection.fetch_all_pages = RealConnection.fetch_all_pages.__get__(mock_connection, RealConnection)
+        first_page = ResultMessage(kind=RESULT_KIND_ROWS)
+        first_page.column_names = ["rpc_address", "peer", "schema_version", "data_center", "rack", "tokens", "host_id"]
+        first_page.parsed_rows = [["192.168.1.1", "10.0.0.1", "a", "dc1", "rack1", ["1", "101", "201"], "uuid2"]]
+        first_page.paging_state = b"has_more_pages"
+        mock_connection.wait_for_response.side_effect = [
+            (True, first_page),
+            (False, OperationTimedOut()),
+        ]
+        query_msg = QueryMessage(query="SELECT * FROM system.peers",
+                                 consistency_level=ConsistencyLevel.ONE,
+                                 fetch_size=self.control_connection._schema_meta_page_size)
+        success, error = mock_connection.fetch_all_pages(query_msg, timeout=5, fail_on_error=False)
+        assert not success
+        assert isinstance(error, OperationTimedOut)
+        assert mock_connection.wait_for_response.call_count == 2
+
     def test_refresh_nodes_and_tokens(self):
         self.control_connection.refresh_node_list_and_token_map()
         meta = self.cluster.metadata
@@ -328,7 +428,7 @@ def test_refresh_nodes_and_tokens(self):
             assert host.datacenter == "dc1"
             assert host.rack == "rack1"
 
-        assert self.connection.wait_for_responses.call_count == 1
+        assert self.connection.wait_for_response.call_count == 2
 
     def test_refresh_nodes_and_tokens_with_invalid_peers(self):
         def refresh_and_validate_added_hosts():
@@ -444,11 +544,11 @@ def test_refresh_nodes_and_tokens_remove_host(self):
 
     def test_refresh_nodes_and_tokens_timeout(self):
 
-        def bad_wait_for_responses(*args, **kwargs):
+        def bad_wait_for_response(*args, **kwargs):
             assert kwargs['timeout'] == self.control_connection._timeout
             raise OperationTimedOut()
 
-        self.connection.wait_for_responses = bad_wait_for_responses
+        self.connection.wait_for_response = Mock(side_effect=bad_wait_for_response)
         self.control_connection.refresh_node_list_and_token_map()
         self.cluster.executor.submit.assert_called_with(self.control_connection._reconnect)
 