Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 38 additions & 20 deletions cassandra/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -3925,23 +3925,32 @@ def _try_connect(self, endpoint):
sel_peers = self._get_peers_query(self.PeersQueryType.PEERS, connection)
sel_local = self._SELECT_LOCAL if self._token_meta_enabled else self._SELECT_LOCAL_NO_TOKENS
peers_query = QueryMessage(query=maybe_add_timeout_to_query(sel_peers, self._metadata_request_timeout),
consistency_level=ConsistencyLevel.ONE)
consistency_level=ConsistencyLevel.ONE,
fetch_size=self._schema_meta_page_size)
local_query = QueryMessage(query=maybe_add_timeout_to_query(sel_local, self._metadata_request_timeout),
consistency_level=ConsistencyLevel.ONE)
(peers_success, peers_result), (local_success, local_result) = connection.wait_for_responses(
peers_query, local_query, timeout=self._timeout, fail_on_error=False)

if not local_success:
raise local_result

if not peers_success:
# error with the peers v2 query, fallback to peers v1
self._uses_peers_v2 = False
sel_peers = self._get_peers_query(self.PeersQueryType.PEERS, connection)
peers_query = QueryMessage(query=maybe_add_timeout_to_query(sel_peers, self._metadata_request_timeout),
consistency_level=ConsistencyLevel.ONE)
peers_result = connection.wait_for_response(
peers_query, timeout=self._timeout)
consistency_level=ConsistencyLevel.ONE,
fetch_size=self._schema_meta_page_size)

with ThreadPoolExecutor(max_workers=2) as executor:
local_future = executor.submit(
connection.fetch_all_pages, local_query, self._timeout, False)
peers_future = executor.submit(
connection.fetch_all_pages, peers_query, self._timeout, False)

local_success, local_result = local_future.result()

if not local_success:
raise local_result

peers_success, peers_result = peers_future.result()

if not peers_success:
self._uses_peers_v2 = False
sel_peers = self._get_peers_query(self.PeersQueryType.PEERS, connection)
peers_query = QueryMessage(query=maybe_add_timeout_to_query(sel_peers, self._metadata_request_timeout),
consistency_level=ConsistencyLevel.ONE,
fetch_size=self._schema_meta_page_size)
peers_result = connection.fetch_all_pages(peers_query, self._timeout)

shared_results = (peers_result, local_result)
self._refresh_node_list_and_token_map(connection, preloaded_results=shared_results)
Expand Down Expand Up @@ -4084,11 +4093,20 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None,
log.debug("[control connection] Refreshing node list and token map")
sel_local = self._SELECT_LOCAL
peers_query = QueryMessage(query=maybe_add_timeout_to_query(sel_peers, self._metadata_request_timeout),
consistency_level=cl)
consistency_level=cl,
fetch_size=self._schema_meta_page_size)
local_query = QueryMessage(query=maybe_add_timeout_to_query(sel_local, self._metadata_request_timeout),
consistency_level=cl)
peers_result, local_result = connection.wait_for_responses(
peers_query, local_query, timeout=self._timeout)
consistency_level=cl,
fetch_size=self._schema_meta_page_size)

with ThreadPoolExecutor(max_workers=2) as executor:
peers_future = executor.submit(
connection.fetch_all_pages, peers_query, self._timeout)
local_future = executor.submit(
connection.fetch_all_pages, local_query, self._timeout)

peers_result = peers_future.result()
local_result = local_future.result()

peers_result = dict_factory(peers_result.column_names, peers_result.parsed_rows)

Expand Down
57 changes: 57 additions & 0 deletions cassandra/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1287,6 +1287,63 @@ def wait_for_responses(self, *msgs, **kwargs):
self.defunct(exc)
raise

def fetch_all_pages(self, query_msg, timeout, fail_on_error=True):
"""Fetch all pages for a query, following paging_state until exhausted.

Runs the given query and, if the response has a paging_state,
continues fetching subsequent pages until no paging_state remains.
Concatenates all parsed_rows into a single result.

Args:
query_msg: QueryMessage to execute (paging_state may be pre-set).
timeout: Per-request timeout passed to wait_for_response.
fail_on_error: If True (default), raises on error. If False,
returns (success, result_or_error).

Returns:
When fail_on_error=True: the fully accumulated ResultMessage.
When fail_on_error=False: (True, ResultMessage) or
(False, Exception).
"""
response = self.wait_for_response(query_msg, timeout=timeout, fail_on_error=fail_on_error)

if not fail_on_error:
success, result = response
if not success:
return response
else:
result = response

if not result or not result.paging_state:
return response if not fail_on_error else result

all_rows = result.parsed_rows
if all_rows is None:
all_rows = []
original_paging_state = query_msg.paging_state

try:
while result and result.paging_state:
query_msg.paging_state = result.paging_state
page_response = self.wait_for_response(query_msg, timeout=timeout, fail_on_error=fail_on_error)

if not fail_on_error:
page_success, page_result = page_response
if not page_success:
return page_response
result = page_result
else:
result = page_response

if result and result.parsed_rows:
all_rows.extend(result.parsed_rows)
finally:
query_msg.paging_state = original_paging_state

result.parsed_rows = all_rows

return (True, result) if not fail_on_error else result

def register_watcher(self, event_type, callback, register_timeout=None):
"""
Register a callback for a given event type.
Expand Down
112 changes: 106 additions & 6 deletions tests/unit/test_control_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
import unittest

from concurrent.futures import ThreadPoolExecutor
from unittest.mock import Mock, ANY, call, patch
from unittest.mock import Mock, ANY, call, patch, MagicMock

from cassandra import OperationTimedOut, SchemaTargetType, SchemaChangeType
from cassandra.protocol import ResultMessage, RESULT_KIND_ROWS
from cassandra import OperationTimedOut, SchemaTargetType, SchemaChangeType, ConsistencyLevel
from cassandra.protocol import ResultMessage, RESULT_KIND_ROWS, QueryMessage
from cassandra.cluster import ControlConnection, _Scheduler, ProfileManager, EXEC_PROFILE_DEFAULT, ExecutionProfile
from cassandra.pool import Host
from cassandra.connection import EndPoint, DefaultEndPoint, DefaultEndPointFactory
Expand Down Expand Up @@ -168,6 +168,21 @@ def __init__(self):
]
self.wait_for_responses = Mock(return_value=_node_meta_results(self.local_results, self.peer_results))

def wait_for_response_side_effect(query_msg, timeout=None, fail_on_error=True):
result = ResultMessage(kind=RESULT_KIND_ROWS)
if "peers" in query_msg.query.lower():
result.column_names = self.peer_results[0]
result.parsed_rows = self.peer_results[1]
else:
result.column_names = self.local_results[0]
result.parsed_rows = self.local_results[1]
result.paging_state = None
return result
self.wait_for_response = Mock(side_effect=wait_for_response_side_effect)

def fetch_all_pages(self, query_msg, timeout, fail_on_error=True):
return self.wait_for_response(query_msg, timeout=timeout, fail_on_error=fail_on_error)


class FakeTime(object):

Expand Down Expand Up @@ -312,6 +327,91 @@ def test_wait_for_schema_agreement_none_timeout(self):
cc._time = self.time
assert cc._wait_for_schema_agreement()

def test_topology_queries_use_paging(self):
self.control_connection.refresh_node_list_and_token_map()
assert self.connection.wait_for_response.called
calls = self.connection.wait_for_response.call_args_list
for call in calls:
query_msg = call[0][0]
assert isinstance(query_msg, QueryMessage)
assert query_msg.fetch_size == self.control_connection._schema_meta_page_size

def test_topology_queries_fetch_all_pages(self):
from cassandra.connection import Connection as RealConnection
mock_connection = MagicMock()
mock_connection.endpoint = DefaultEndPoint("192.168.1.0")
mock_connection.original_endpoint = mock_connection.endpoint
mock_connection.fetch_all_pages = RealConnection.fetch_all_pages.__get__(mock_connection, RealConnection)
first_page = ResultMessage(kind=RESULT_KIND_ROWS)
first_page.column_names = ["rpc_address", "peer", "schema_version", "data_center", "rack", "tokens", "host_id"]
first_page.parsed_rows = [["192.168.1.1", "10.0.0.1", "a", "dc1", "rack1", ["1", "101", "201"], "uuid2"]]
first_page.paging_state = b"has_more_pages"
second_page = ResultMessage(kind=RESULT_KIND_ROWS)
second_page.column_names = ["rpc_address", "peer", "schema_version", "data_center", "rack", "tokens", "host_id"]
second_page.parsed_rows = [["192.168.1.2", "10.0.0.2", "a", "dc1", "rack1", ["2", "102", "202"], "uuid3"]]
second_page.paging_state = None
mock_connection.wait_for_response.side_effect = [first_page, second_page]
self.control_connection._connection = mock_connection
query_msg = QueryMessage(query="SELECT * FROM system.peers",
consistency_level=ConsistencyLevel.ONE,
fetch_size=self.control_connection._schema_meta_page_size)
result = mock_connection.fetch_all_pages(query_msg, timeout=5)
assert len(result.parsed_rows) == 2
assert result.parsed_rows[0][0] == "192.168.1.1"
assert result.parsed_rows[1][0] == "192.168.1.2"
assert result.paging_state is None
assert mock_connection.wait_for_response.call_count == 2

def test_topology_queries_fetch_all_pages_fail_on_error_false(self):
from cassandra.connection import Connection as RealConnection
mock_connection = MagicMock()
mock_connection.endpoint = DefaultEndPoint("192.168.1.0")
mock_connection.original_endpoint = mock_connection.endpoint
mock_connection.fetch_all_pages = RealConnection.fetch_all_pages.__get__(mock_connection, RealConnection)
first_page = ResultMessage(kind=RESULT_KIND_ROWS)
first_page.column_names = ["rpc_address", "peer", "schema_version", "data_center", "rack", "tokens", "host_id"]
first_page.parsed_rows = [["192.168.1.1", "10.0.0.1", "a", "dc1", "rack1", ["1", "101", "201"], "uuid2"]]
first_page.paging_state = b"has_more_pages"
second_page = ResultMessage(kind=RESULT_KIND_ROWS)
second_page.column_names = ["rpc_address", "peer", "schema_version", "data_center", "rack", "tokens", "host_id"]
second_page.parsed_rows = [["192.168.1.2", "10.0.0.2", "a", "dc1", "rack1", ["2", "102", "202"], "uuid3"]]
second_page.paging_state = None
mock_connection.wait_for_response.side_effect = [
(True, first_page),
(True, second_page),
]
query_msg = QueryMessage(query="SELECT * FROM system.peers",
consistency_level=ConsistencyLevel.ONE,
fetch_size=self.control_connection._schema_meta_page_size)
success, result = mock_connection.fetch_all_pages(query_msg, timeout=5, fail_on_error=False)
assert success
assert len(result.parsed_rows) == 2
assert result.parsed_rows[0][0] == "192.168.1.1"
assert result.parsed_rows[1][0] == "192.168.1.2"
assert mock_connection.wait_for_response.call_count == 2

def test_topology_queries_fetch_all_pages_page_failure(self):
from cassandra.connection import Connection as RealConnection
mock_connection = MagicMock()
mock_connection.endpoint = DefaultEndPoint("192.168.1.0")
mock_connection.original_endpoint = mock_connection.endpoint
mock_connection.fetch_all_pages = RealConnection.fetch_all_pages.__get__(mock_connection, RealConnection)
first_page = ResultMessage(kind=RESULT_KIND_ROWS)
first_page.column_names = ["rpc_address", "peer", "schema_version", "data_center", "rack", "tokens", "host_id"]
first_page.parsed_rows = [["192.168.1.1", "10.0.0.1", "a", "dc1", "rack1", ["1", "101", "201"], "uuid2"]]
first_page.paging_state = b"has_more_pages"
mock_connection.wait_for_response.side_effect = [
(True, first_page),
(False, OperationTimedOut()),
]
query_msg = QueryMessage(query="SELECT * FROM system.peers",
consistency_level=ConsistencyLevel.ONE,
fetch_size=self.control_connection._schema_meta_page_size)
success, error = mock_connection.fetch_all_pages(query_msg, timeout=5, fail_on_error=False)
assert not success
assert isinstance(error, OperationTimedOut)
assert mock_connection.wait_for_response.call_count == 2

def test_refresh_nodes_and_tokens(self):
self.control_connection.refresh_node_list_and_token_map()
meta = self.cluster.metadata
Expand All @@ -328,7 +428,7 @@ def test_refresh_nodes_and_tokens(self):
assert host.datacenter == "dc1"
assert host.rack == "rack1"

assert self.connection.wait_for_responses.call_count == 1
assert self.connection.wait_for_response.call_count == 2

def test_refresh_nodes_and_tokens_with_invalid_peers(self):
def refresh_and_validate_added_hosts():
Expand Down Expand Up @@ -444,11 +544,11 @@ def test_refresh_nodes_and_tokens_remove_host(self):

def test_refresh_nodes_and_tokens_timeout(self):

def bad_wait_for_responses(*args, **kwargs):
def bad_wait_for_response(*args, **kwargs):
assert kwargs['timeout'] == self.control_connection._timeout
raise OperationTimedOut()

self.connection.wait_for_responses = bad_wait_for_responses
self.connection.wait_for_response = Mock(side_effect=bad_wait_for_response)
self.control_connection.refresh_node_list_and_token_map()
self.cluster.executor.submit.assert_called_with(self.control_connection._reconnect)

Expand Down
Loading