Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions AGENTS.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# AGENTS.md

Guidance for contributors (including AI assistants) working on `xarray-sql`. It
summarizes recurring maintainer review feedback so that changes land cleanly.

## Documentation and comments

- Keep docstrings and comments self-contained. Do **not** put GitHub issue or PR
numbers in docstrings or code comments; a reader should not need the issue
tracker to understand the code. Issue references belong in the commit message
and PR description (e.g. `Closes #189`), not in the source.
- Do not reference the review conversation, chat, or "the reporter" in comments.
Describe the behavior, not how it came up.

## API surface

- Mark internal helpers private with a leading underscore when they are not part
of the public API.
- Prefer doing setup work in the functional entry point (e.g. `read_xarray`)
rather than in a class constructor; keep constructors minimal.

## Tests

- Test the public contract (values, dims, coords, attrs), not internal call
counts or private classes, so the suite survives refactors.
- Avoid redundant tests: if a public-path test already covers a behavior, do not
add a second lower-level test for the same thing.
- Make query results deterministic with `ORDER BY` so assertions do not have to
re-sort the output.
- Do not pass `dims=` to `to_dataset()` when inference already resolves the
dimensions. Reserve explicit `dims=` / `template=` for genuinely ambiguous
cases (multiple registered Datasets, or a test that is specifically exercising
those arguments).

## Commits

- Use conventional commit prefixes: `fix:`, `feat:`, `refactor:`, `chore:`,
`docs:`, `test:`.
- Keep imports at the top of the file.
35 changes: 25 additions & 10 deletions tests/test_ds.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def test_aggregation_drops_dim(air_dataset_small):
ctx.from_dataset("air", air_dataset_small)
out = ctx.sql(
"SELECT lat, lon, AVG(air) AS air_avg FROM air GROUP BY lat, lon"
).to_dataset(dims=["lat", "lon"])
).to_dataset()
assert set(out.dims) == {"lat", "lon"}
assert "air_avg" in out.data_vars
assert "air" not in out.data_vars
Expand All @@ -139,6 +139,23 @@ def test_aggregation_drops_dim(air_dataset_small):
np.testing.assert_allclose(actual, expected)


def test_aggregation_infers_dims(air_dataset_small):
"""to_dataset() infers the surviving GROUP BY dim when dims is omitted."""
ctx = XarrayContext()
ctx.from_dataset("air", air_dataset_small)

# Grouping by the time coordinate keeps time as the sole dimension; the
# ORDER BY makes the result order deterministic so no sort is needed below.
out = ctx.sql(
'SELECT "time", AVG("air") AS air FROM "air" '
'GROUP BY "time" ORDER BY "time"'
).to_dataset()

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice

assert set(out.dims) == {"time"}
assert "air" in out.data_vars
expected = air_dataset_small.compute().mean(dim=["lat", "lon"])["air"]
np.testing.assert_allclose(out["air"].values, expected.values)


def test_barrier_query_scans_source_once(air_dataset_small):
"""A barrier plan (aggregation) executes the source exactly once.

Expand Down Expand Up @@ -166,7 +183,7 @@ def test_barrier_query_scans_source_once(air_dataset_small):

out = ctx.sql(
"SELECT lat, lon, AVG(air) AS air_avg FROM air GROUP BY lat, lon"
).to_dataset(dims=["lat", "lon"])
).to_dataset()
reads_after_construct = len(reads)
out.compute()
reads_after_compute = len(reads)
Expand All @@ -188,7 +205,7 @@ def test_order_by_direction_sets_dim_order(air_dataset_small):
ctx.from_dataset("air", air_dataset_small)
out = ctx.sql(
"SELECT lat, AVG(air) AS air_avg FROM air GROUP BY lat ORDER BY lat DESC"
).to_dataset(dims=["lat"])
).to_dataset()

lat = out["lat"].values
assert (np.diff(lat) < 0).all(), f"expected descending lat, got {lat}"
Expand Down Expand Up @@ -289,7 +306,7 @@ def test_fast_path_uses_scanned_tables_coords_not_user_template(


def test_round_trip_preserves_descending_lat_on_lazy_path(air_dataset_small):
"""Lazy round-trip preserves source dim order (xarray-sql#171).
"""Lazy round-trip preserves source dim order.

NCEP ``air_temperature`` ships descending lat (75.0 -> 15.0). The
discovery path's ``.distinct().sort()`` previously flipped lat to
Expand Down Expand Up @@ -383,14 +400,12 @@ def test_to_dataset_multi_registered_requires_explicit_template(
assert set(out.dims) == {"time", "lat", "lon"}


def test_to_dataset_infer_fails_when_no_template_fits(air_dataset_small):
"""If no registered Dataset's dims fit the result -> clear error."""
def test_to_dataset_infer_fails_when_no_dim_survives(air_dataset_small):
"""A global aggregation leaves no registered dim in the result -> clear error."""
ctx = XarrayContext()
ctx.from_dataset("air", air_dataset_small)
with pytest.raises(ValueError, match="dims cannot be inferred"):
ctx.sql(
"SELECT lat, lon, AVG(air) AS air_avg FROM air GROUP BY lat, lon"
).to_dataset()
ctx.sql("SELECT AVG(air) AS air_avg FROM air").to_dataset()

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will you also fix all the other tests that have specified dims when they now no longer need them?



def test_template_accepts_name_or_dataset(air_dataset_small):
Expand Down Expand Up @@ -447,7 +462,7 @@ def test_template_aggregation_alias_no_attrs(air_dataset_small):
ctx.from_dataset("air", ds)
out = ctx.sql(
"SELECT lat, lon, AVG(air) AS air_avg FROM air GROUP BY lat, lon"
).to_dataset(dims=["lat", "lon"])
).to_dataset()
assert "air_avg" in out.data_vars
assert out["air_avg"].attrs == {}

Expand Down
46 changes: 26 additions & 20 deletions xarray_sql/ds.py
Original file line number Diff line number Diff line change
Expand Up @@ -757,10 +757,12 @@ def to_dataset(

Args:
dims: Result columns to use as Dataset dimensions. When
``None``, defaults to the dims of the registered Dataset
referenced by the SQL ``FROM`` clause (if exactly one
matches), or any single registered Dataset whose dims are
all present in the result columns.
``None``, defaults to a registered Dataset's dimensions that
survive into the result columns, so an aggregation that drops
dims (e.g. ``GROUP BY time`` over a ``(time, lat, lon)`` grid)
round-trips on the remaining dim. Raises when no dimension
survives, or when several registered Datasets imply different
dims (pass ``dims`` explicitly then).
template: Source to recover metadata (attrs, encoding, non-dim
coordinates, dim-coord dtype) from. Either an ``xr.Dataset``
used directly, or the name of a registered table (e.g.
Expand Down Expand Up @@ -879,33 +881,37 @@ def _infer_dimension_columns(
) -> list[str]:
"""Pick a default ``dimension_columns`` from the registry, or raise.

Uses the data variable's dim order (via :func:`_ds_var_dims`) so
the round-trip preserves the original axis order.
A registered Dataset's dims that survive into the result columns
become the dimensions, so aggregations that drop dims (e.g.
``GROUP BY time`` over a ``(time, lat, lon)`` grid) round-trip on the
surviving dim(s). Uses the data variable's dim order (via
:func:`_ds_var_dims`) so the original axis order is preserved.
"""
result_cols = set(self._result_columns())
if (
preferred_template is not None
and set(preferred_template.dims) <= result_cols
):
return _ds_var_dims(preferred_template)

def surviving(template: xr.Dataset) -> list[str]:
# Template dims still present in the result, in var axis order.
return [d for d in _ds_var_dims(template) if d in result_cols]

if preferred_template is not None:
preferred = surviving(preferred_template)
if preferred:
return preferred
if not self._templates:
raise ValueError(
"dims cannot be inferred (no registered "
"Dataset on this result); pass dims=[...] "
"explicitly."
)
candidates = [
_ds_var_dims(t)
for t in self._templates.values()
if set(t.dims) <= result_cols
]
candidates = {tuple(surviving(t)) for t in self._templates.values()}
candidates.discard(()) # templates with no surviving dim
if len(candidates) == 1:
return candidates[0]
return list(next(iter(candidates)))
if not candidates:
raise ValueError(
"dims cannot be inferred: no registered "
"Dataset has all of its dims present in the result "
"columns. Pass dims=[...] explicitly."
"dims cannot be inferred: no registered Dataset "
"dimension survives in the result columns. Pass "
"dims=[...] explicitly."
)
raise ValueError(
"dims cannot be inferred unambiguously: multiple "
Expand Down
Loading