Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,16 @@ with connect(
print(results)
```

`runtime` and `region` are optional and accept either the provided enums (handy
for autocomplete) or a plain string — strings are passed to the API as-is, so
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note: The new paragraph is correctly placed, but the existing README still has a > [!WARNING] block (in the "Runtime and region selection" section) that warns users the region parameter will become mandatory in a future SDK version. With this PR, the behavior is now the opposite — omitting region is intentional and org-default-driven. That stale warning will confuse anyone who reads past the basic usage section. Consider removing or updating it in the same PR.

new or BYOC regions (e.g. `region="byoc-acme-us-east-1"`) work without an SDK
upgrade. When omitted, your organization's configured defaults are used:

```python
with connect(api_key='...') as conn: # uses your org's default runtime + region
...
```

The `Cursor` supports the context manager protocol, so you can use it
within a `with` statement when needed:

Expand Down
60 changes: 60 additions & 0 deletions tests/test_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,66 @@
connect_direct,
)
from wherobots.db.errors import InterfaceError
from wherobots.db.region import Region
from wherobots.db.runtime import Runtime


def _run_connect(mock_post, mock_get, **connect_kwargs):
"""Drive a successful connect() and return the kwargs passed to requests.post."""
post_resp = MagicMock()
post_resp.status_code = 200
post_resp.url = "https://api.example.com/sql/session/test-id"
post_resp.raise_for_status = MagicMock()
mock_post.return_value = post_resp

get_resp = MagicMock()
get_resp.status_code = 200
get_resp.raise_for_status = MagicMock()
get_resp.json.return_value = {
"status": "READY",
"appMeta": {"url": "https://compute.example.com/sql/org/session-id"},
}
mock_get.return_value = get_resp

with patch("wherobots.db.driver.connect_direct") as mock_cd:
mock_cd.return_value = MagicMock()
connect(api_key="test-key", **connect_kwargs)

_, kwargs = mock_post.call_args
return kwargs


class TestConnectRegionRuntime:
"""region/runtime accept enum|str and are omitted when not provided."""

@patch("wherobots.db.driver.requests.get")
@patch("wherobots.db.driver.requests.post")
def test_omitted_region_runtime_not_sent(self, mock_post, mock_get):
"""Omitting region/runtime sends no value so the API applies the org default."""
kwargs = _run_connect(mock_post, mock_get)
# `requests` drops query params that are None, so region is not sent.
assert kwargs["params"]["region"] is None
assert kwargs["json"]["runtimeId"] is None

@patch("wherobots.db.driver.requests.get")
@patch("wherobots.db.driver.requests.post")
def test_enum_region_runtime_serialized(self, mock_post, mock_get):
"""Enum values serialize to their string form."""
kwargs = _run_connect(
mock_post, mock_get, region=Region.AWS_US_WEST_2, runtime=Runtime.TINY
)
assert kwargs["params"]["region"] == "aws-us-west-2"
assert kwargs["json"]["runtimeId"] == "tiny"

@patch("wherobots.db.driver.requests.get")
@patch("wherobots.db.driver.requests.post")
def test_string_region_runtime_passthrough(self, mock_post, mock_get):
"""Raw strings (e.g. BYOC regions) are passed through untouched."""
kwargs = _run_connect(
mock_post, mock_get, region="byoc-acme-us-east-1", runtime="x-large"
)
assert kwargs["params"]["region"] == "byoc-acme-us-east-1"
assert kwargs["json"]["runtimeId"] == "x-large"


class TestCheckCancelled:
Expand Down
38 changes: 28 additions & 10 deletions wherobots/db/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@
from .connection import Connection
from .constants import (
DEFAULT_ENDPOINT,
DEFAULT_REGION,
DEFAULT_RUNTIME,
DEFAULT_READ_TIMEOUT_SECONDS,
DEFAULT_SESSION_TYPE,
DEFAULT_SESSION_WAIT_TIMEOUT_SECONDS,
Expand Down Expand Up @@ -72,8 +70,8 @@ def connect(
host: str = DEFAULT_ENDPOINT,
token: Union[str, None] = None,
api_key: Union[str, None] = None,
runtime: Union[Runtime, None] = None,
region: Union[Region, None] = None,
runtime: Union[str, Runtime, None] = None,
region: Union[str, Region, None] = None,
version: Union[str, None] = None,
wait_timeout: float = DEFAULT_SESSION_WAIT_TIMEOUT_SECONDS,
read_timeout: float = DEFAULT_READ_TIMEOUT_SECONDS,
Expand All @@ -85,6 +83,20 @@ def connect(
geometry_representation: Union[GeometryRepresentation, None] = None,
cancel_event: Union[threading.Event, None] = None,
) -> Connection:
"""Create a connection to a Wherobots SQL session.

:param runtime: The compute runtime to use. Accepts a ``Runtime`` enum value
or a raw string; strings are passed to the API as-is. Override the
default runtime set for your organization — only set this if you need a
specific runtime instead of the one your administrator has configured.
When omitted, your organization's default runtime is used.
:param region: The compute region to run in. Accepts a ``Region`` enum value
or a raw string (e.g. a BYOC region such as ``byoc-acme-us-east-1``);
strings are passed to the API as-is. Override the default region set for
your organization — only set this if you intend to use a specific region
instead of the one your administrator has configured. When omitted, your
organization's default region is used.
"""
if not token and not api_key:
raise ValueError("At least one of `token` or `api_key` is required")
if token and api_key:
Expand All @@ -97,16 +109,20 @@ def connect(
headers["X-API-Key"] = api_key

host = host or DEFAULT_ENDPOINT
runtime = runtime or DEFAULT_RUNTIME
region = region or DEFAULT_REGION
session_type = session_type or DEFAULT_SESSION_TYPE

# Normalize enum values to their string form and pass raw strings through
# untouched. When omitted (None) the field is dropped from the request so
# the API applies the organization's configured default.
runtime_id = runtime.value if isinstance(runtime, Runtime) else runtime
region_name = region.value if isinstance(region, Region) else region

logging.info(
"Requesting %s%s runtime %sin %s from %s ...",
"new " if force_new else "",
runtime.value,
runtime_id or "org-default",
f"running {version} " if version else "",
region.value,
region_name or "org-default",
host,
)

Expand All @@ -119,9 +135,11 @@ def connect(
try:
resp = requests.post(
url=f"{host}/sql/session",
params={"region": region.value, "force_new": force_new},
# `requests` omits query params whose value is None, so an omitted
# region is simply not sent and the API applies the org default.
params={"region": region_name, "force_new": force_new},
json={
"runtimeId": runtime.value,
"runtimeId": runtime_id,
"shutdownAfterInactiveSeconds": shutdown_after_inactive_seconds,
"version": version,
"sessionType": session_type.value,
Expand Down
Loading