From 9b606230aa9175c9eb732b9dffcd6a8dbde83393 Mon Sep 17 00:00:00 2001 From: Clay McGinnis Date: Wed, 27 May 2026 16:38:50 -0700 Subject: [PATCH] feat: accept str|enum for region/runtime and make them optional MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `connect()` now accepts a plain string for `runtime` and `region` in addition to the `Runtime`/`Region` enums; strings are passed to the API untouched, so new or BYOC regions (e.g. "byoc-acme-us-east-1") work without an SDK release. When either is omitted, the SDK no longer injects the hardcoded `aws-us-west-2` / `tiny` defaults — the field is dropped from the request (`requests` omits None query params; `runtimeId` is sent null) so the API applies the organization's configured default. Enum values are normalized to their string form. Adds a docstring documenting the new behavior. Requires the studio-backend API change that makes region/runtime optional on POST /sql/session. --- README.md | 10 +++++++ tests/test_driver.py | 60 ++++++++++++++++++++++++++++++++++++++++++ wherobots/db/driver.py | 38 +++++++++++++++++++------- 3 files changed, 98 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index ce968f6..07edb56 100644 --- a/README.md +++ b/README.md @@ -83,6 +83,16 @@ with connect( print(results) ``` +`runtime` and `region` are optional and accept either the provided enums (handy +for autocomplete) or a plain string — strings are passed to the API as-is, so +new or BYOC regions (e.g. `region="byoc-acme-us-east-1"`) work without an SDK +upgrade. When omitted, your organization's configured defaults are used: + +```python +with connect(api_key='...') as conn: # uses your org's default runtime + region + ... +``` + The `Cursor` supports the context manager protocol, so you can use it within a `with` statement when needed: diff --git a/tests/test_driver.py b/tests/test_driver.py index ef75cf2..b9cf20a 100644 --- a/tests/test_driver.py +++ b/tests/test_driver.py @@ -14,6 +14,66 @@ connect_direct, ) from wherobots.db.errors import InterfaceError +from wherobots.db.region import Region +from wherobots.db.runtime import Runtime + + +def _run_connect(mock_post, mock_get, **connect_kwargs): + """Drive a successful connect() and return the kwargs passed to requests.post.""" + post_resp = MagicMock() + post_resp.status_code = 200 + post_resp.url = "https://api.example.com/sql/session/test-id" + post_resp.raise_for_status = MagicMock() + mock_post.return_value = post_resp + + get_resp = MagicMock() + get_resp.status_code = 200 + get_resp.raise_for_status = MagicMock() + get_resp.json.return_value = { + "status": "READY", + "appMeta": {"url": "https://compute.example.com/sql/org/session-id"}, + } + mock_get.return_value = get_resp + + with patch("wherobots.db.driver.connect_direct") as mock_cd: + mock_cd.return_value = MagicMock() + connect(api_key="test-key", **connect_kwargs) + + _, kwargs = mock_post.call_args + return kwargs + + +class TestConnectRegionRuntime: + """region/runtime accept enum|str and are omitted when not provided.""" + + @patch("wherobots.db.driver.requests.get") + @patch("wherobots.db.driver.requests.post") + def test_omitted_region_runtime_not_sent(self, mock_post, mock_get): + """Omitting region/runtime sends no value so the API applies the org default.""" + kwargs = _run_connect(mock_post, mock_get) + # `requests` drops query params that are None, so region is not sent. + assert kwargs["params"]["region"] is None + assert kwargs["json"]["runtimeId"] is None + + @patch("wherobots.db.driver.requests.get") + @patch("wherobots.db.driver.requests.post") + def test_enum_region_runtime_serialized(self, mock_post, mock_get): + """Enum values serialize to their string form.""" + kwargs = _run_connect( + mock_post, mock_get, region=Region.AWS_US_WEST_2, runtime=Runtime.TINY + ) + assert kwargs["params"]["region"] == "aws-us-west-2" + assert kwargs["json"]["runtimeId"] == "tiny" + + @patch("wherobots.db.driver.requests.get") + @patch("wherobots.db.driver.requests.post") + def test_string_region_runtime_passthrough(self, mock_post, mock_get): + """Raw strings (e.g. BYOC regions) are passed through untouched.""" + kwargs = _run_connect( + mock_post, mock_get, region="byoc-acme-us-east-1", runtime="x-large" + ) + assert kwargs["params"]["region"] == "byoc-acme-us-east-1" + assert kwargs["json"]["runtimeId"] == "x-large" class TestCheckCancelled: diff --git a/wherobots/db/driver.py b/wherobots/db/driver.py index a85c313..ae2d0d1 100644 --- a/wherobots/db/driver.py +++ b/wherobots/db/driver.py @@ -21,8 +21,6 @@ from .connection import Connection from .constants import ( DEFAULT_ENDPOINT, - DEFAULT_REGION, - DEFAULT_RUNTIME, DEFAULT_READ_TIMEOUT_SECONDS, DEFAULT_SESSION_TYPE, DEFAULT_SESSION_WAIT_TIMEOUT_SECONDS, @@ -72,8 +70,8 @@ def connect( host: str = DEFAULT_ENDPOINT, token: Union[str, None] = None, api_key: Union[str, None] = None, - runtime: Union[Runtime, None] = None, - region: Union[Region, None] = None, + runtime: Union[str, Runtime, None] = None, + region: Union[str, Region, None] = None, version: Union[str, None] = None, wait_timeout: float = DEFAULT_SESSION_WAIT_TIMEOUT_SECONDS, read_timeout: float = DEFAULT_READ_TIMEOUT_SECONDS, @@ -85,6 +83,20 @@ def connect( geometry_representation: Union[GeometryRepresentation, None] = None, cancel_event: Union[threading.Event, None] = None, ) -> Connection: + """Create a connection to a Wherobots SQL session. + + :param runtime: The compute runtime to use. Accepts a ``Runtime`` enum value + or a raw string; strings are passed to the API as-is. Override the + default runtime set for your organization — only set this if you need a + specific runtime instead of the one your administrator has configured. + When omitted, your organization's default runtime is used. + :param region: The compute region to run in. Accepts a ``Region`` enum value + or a raw string (e.g. a BYOC region such as ``byoc-acme-us-east-1``); + strings are passed to the API as-is. Override the default region set for + your organization — only set this if you intend to use a specific region + instead of the one your administrator has configured. When omitted, your + organization's default region is used. + """ if not token and not api_key: raise ValueError("At least one of `token` or `api_key` is required") if token and api_key: @@ -97,16 +109,20 @@ def connect( headers["X-API-Key"] = api_key host = host or DEFAULT_ENDPOINT - runtime = runtime or DEFAULT_RUNTIME - region = region or DEFAULT_REGION session_type = session_type or DEFAULT_SESSION_TYPE + # Normalize enum values to their string form and pass raw strings through + # untouched. When omitted (None) the field is dropped from the request so + # the API applies the organization's configured default. + runtime_id = runtime.value if isinstance(runtime, Runtime) else runtime + region_name = region.value if isinstance(region, Region) else region + logging.info( "Requesting %s%s runtime %sin %s from %s ...", "new " if force_new else "", - runtime.value, + runtime_id or "org-default", f"running {version} " if version else "", - region.value, + region_name or "org-default", host, ) @@ -119,9 +135,11 @@ def connect( try: resp = requests.post( url=f"{host}/sql/session", - params={"region": region.value, "force_new": force_new}, + # `requests` omits query params whose value is None, so an omitted + # region is simply not sent and the API applies the org default. + params={"region": region_name, "force_new": force_new}, json={ - "runtimeId": runtime.value, + "runtimeId": runtime_id, "shutdownAfterInactiveSeconds": shutdown_after_inactive_seconds, "version": version, "sessionType": session_type.value,