From c294260412639b4519bfab459907c1e9509a9a44 Mon Sep 17 00:00:00 2001 From: ghostiee-11 Date: Fri, 3 Jul 2026 01:26:34 +0530 Subject: [PATCH 1/4] fix: infer surviving dims in to_dataset for aggregations (closes #189) to_dataset() inferred dims only when a registered Dataset's entire dim set appeared in the result, so a GROUP BY that aggregates dims away (e.g. SELECT "time", AVG("air") ... GROUP BY "time" over a time/lat/lon grid) raised "dims cannot be inferred" instead of using the surviving group-by dim. _infer_dimension_columns now uses the registered dims that survive into the result columns, in the data-variable axis order, so such aggregations round-trip on the remaining dim(s). Several registered Datasets that imply different surviving dims still raise; a global aggregation with no surviving dim raises a clearer message. Repurposes the now-inferable test to a global-aggregation case and adds an auto-inference regression test. --- tests/test_ds.py | 44 +++++++++++++++++++++++++++++++++++++++----- xarray_sql/ds.py | 46 ++++++++++++++++++++++++++-------------------- 2 files changed, 65 insertions(+), 25 deletions(-) diff --git a/tests/test_ds.py b/tests/test_ds.py index aa3deb2..d530408 100644 --- a/tests/test_ds.py +++ b/tests/test_ds.py @@ -139,6 +139,42 @@ def test_aggregation_drops_dim(air_dataset_small): np.testing.assert_allclose(actual, expected) +def test_aggregation_infers_dims(air_dataset_small): + """#189: to_dataset() infers the surviving GROUP BY dims without dims=.""" + ctx = XarrayContext() + ctx.from_dataset("air", air_dataset_small) + + # GROUP BY lat, lon: time is aggregated away, lat/lon survive. + out = ctx.sql( + "SELECT lat, lon, AVG(air) AS air_avg FROM air GROUP BY lat, lon" + ).to_dataset() + assert set(out.dims) == {"lat", "lon"} + assert "air_avg" in out.data_vars + expected = ( + air_dataset_small.compute() + .sortby(["lat", "lon"]) + .mean(dim="time")["air"] + .values + ) + np.testing.assert_allclose( + out.sortby(["lat", "lon"])["air_avg"].values, expected + ) + + # The reporter's exact case: GROUP BY the time coordinate. + single = ctx.sql( + 'SELECT "time", AVG("air") AS air FROM "air" GROUP BY "time"' + ).to_dataset() + assert set(single.dims) == {"time"} + assert "air" in single.data_vars + expected_t = ( + air_dataset_small.compute() + .sortby("time") + .mean(dim=["lat", "lon"])["air"] + .values + ) + np.testing.assert_allclose(single.sortby("time")["air"].values, expected_t) + + def test_barrier_query_scans_source_once(air_dataset_small): """A barrier plan (aggregation) executes the source exactly once. @@ -383,14 +419,12 @@ def test_to_dataset_multi_registered_requires_explicit_template( assert set(out.dims) == {"time", "lat", "lon"} -def test_to_dataset_infer_fails_when_no_template_fits(air_dataset_small): - """If no registered Dataset's dims fit the result -> clear error.""" +def test_to_dataset_infer_fails_when_no_dim_survives(air_dataset_small): + """A global aggregation leaves no registered dim in the result -> clear error.""" ctx = XarrayContext() ctx.from_dataset("air", air_dataset_small) with pytest.raises(ValueError, match="dims cannot be inferred"): - ctx.sql( - "SELECT lat, lon, AVG(air) AS air_avg FROM air GROUP BY lat, lon" - ).to_dataset() + ctx.sql("SELECT AVG(air) AS air_avg FROM air").to_dataset() def test_template_accepts_name_or_dataset(air_dataset_small): diff --git a/xarray_sql/ds.py b/xarray_sql/ds.py index 5dfdf42..cd1db39 100644 --- a/xarray_sql/ds.py +++ b/xarray_sql/ds.py @@ -757,10 +757,12 @@ def to_dataset( Args: dims: Result columns to use as Dataset dimensions. When - ``None``, defaults to the dims of the registered Dataset - referenced by the SQL ``FROM`` clause (if exactly one - matches), or any single registered Dataset whose dims are - all present in the result columns. + ``None``, defaults to a registered Dataset's dimensions that + survive into the result columns, so an aggregation that drops + dims (e.g. ``GROUP BY time`` over a ``(time, lat, lon)`` grid) + round-trips on the remaining dim. Raises when no dimension + survives, or when several registered Datasets imply different + dims (pass ``dims`` explicitly then). template: Source to recover metadata (attrs, encoding, non-dim coordinates, dim-coord dtype) from. Either an ``xr.Dataset`` used directly, or the name of a registered table (e.g. @@ -879,33 +881,37 @@ def _infer_dimension_columns( ) -> list[str]: """Pick a default ``dimension_columns`` from the registry, or raise. - Uses the data variable's dim order (via :func:`_ds_var_dims`) so - the round-trip preserves the original axis order. + A registered Dataset's dims that survive into the result columns + become the dimensions, so aggregations that drop dims (e.g. + ``GROUP BY time`` over a ``(time, lat, lon)`` grid) round-trip on the + surviving dim(s). Uses the data variable's dim order (via + :func:`_ds_var_dims`) so the original axis order is preserved. """ result_cols = set(self._result_columns()) - if ( - preferred_template is not None - and set(preferred_template.dims) <= result_cols - ): - return _ds_var_dims(preferred_template) + + def surviving(template: xr.Dataset) -> list[str]: + # Template dims still present in the result, in var axis order. + return [d for d in _ds_var_dims(template) if d in result_cols] + + if preferred_template is not None: + preferred = surviving(preferred_template) + if preferred: + return preferred if not self._templates: raise ValueError( "dims cannot be inferred (no registered " "Dataset on this result); pass dims=[...] " "explicitly." ) - candidates = [ - _ds_var_dims(t) - for t in self._templates.values() - if set(t.dims) <= result_cols - ] + candidates = {tuple(surviving(t)) for t in self._templates.values()} + candidates.discard(()) # templates with no surviving dim if len(candidates) == 1: - return candidates[0] + return list(next(iter(candidates))) if not candidates: raise ValueError( - "dims cannot be inferred: no registered " - "Dataset has all of its dims present in the result " - "columns. Pass dims=[...] explicitly." + "dims cannot be inferred: no registered Dataset " + "dimension survives in the result columns. Pass " + "dims=[...] explicitly." ) raise ValueError( "dims cannot be inferred unambiguously: multiple " From dcdcc5c1eb3c5c663614388de04cbc3bb964bf16 Mon Sep 17 00:00:00 2001 From: ghostiee-11 Date: Sat, 4 Jul 2026 05:05:31 +0530 Subject: [PATCH 2/4] test: infer dims in aggregation tests; add AGENTS.md contributor guide Drop now-redundant explicit dims= from single-registration aggregation tests (inference resolves them), make the inference test self-contained and deterministic via ORDER BY, and remove issue-number and conversation references from test docstrings and comments. Add AGENTS.md summarizing the maintainer's review conventions: self-contained docs, private helpers, public-contract tests, ORDER BY in tests, and conventional commits. --- AGENTS.md | 39 +++++++++++++++++++++++++++++++++++++++ tests/test_ds.py | 47 ++++++++++++++--------------------------------- 2 files changed, 53 insertions(+), 33 deletions(-) create mode 100644 AGENTS.md diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..1320cfd --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,39 @@ +# AGENTS.md + +Guidance for contributors (including AI assistants) working on `xarray-sql`. It +summarizes recurring maintainer review feedback so changes land clean. + +## Documentation and comments + +- Keep docstrings and comments self-contained. Do **not** put GitHub issue or PR + numbers in docstrings or code comments; a reader should not need the issue + tracker to understand the code. Issue references belong in the commit message + and PR description (e.g. `Closes #189`), not in the source. +- Do not reference the review conversation, chat, or "the reporter" in comments. + Describe the behavior, not how it came up. + +## API surface + +- Mark internal helpers private with a leading underscore when they are not part + of the public API. +- Prefer doing setup work in the functional entry point (e.g. `read_xarray`) + rather than in a class constructor; keep constructors minimal. + +## Tests + +- Test the public contract (values, dims, coords, attrs), not internal call + counts or private classes, so the suite survives refactors. +- Avoid redundant tests: if a public-path test already covers a behavior, do not + add a second lower-level test for the same thing. +- Make query results deterministic with `ORDER BY` so assertions do not have to + re-sort the output. +- Do not pass `dims=` to `to_dataset()` when inference already resolves them. + Reserve explicit `dims=` / `template=` for genuinely ambiguous cases (multiple + registered Datasets, or a test that is specifically exercising those + arguments). + +## Commits + +- Use conventional commit prefixes: `fix:`, `feat:`, `refactor:`, `chore:`, + `docs:`, `test:`. +- Keep imports at the top of the file. diff --git a/tests/test_ds.py b/tests/test_ds.py index d530408..3682dac 100644 --- a/tests/test_ds.py +++ b/tests/test_ds.py @@ -125,7 +125,7 @@ def test_aggregation_drops_dim(air_dataset_small): ctx.from_dataset("air", air_dataset_small) out = ctx.sql( "SELECT lat, lon, AVG(air) AS air_avg FROM air GROUP BY lat, lon" - ).to_dataset(dims=["lat", "lon"]) + ).to_dataset() assert set(out.dims) == {"lat", "lon"} assert "air_avg" in out.data_vars assert "air" not in out.data_vars @@ -140,39 +140,20 @@ def test_aggregation_drops_dim(air_dataset_small): def test_aggregation_infers_dims(air_dataset_small): - """#189: to_dataset() infers the surviving GROUP BY dims without dims=.""" + """to_dataset() infers the surviving GROUP BY dim when dims is omitted.""" ctx = XarrayContext() ctx.from_dataset("air", air_dataset_small) - # GROUP BY lat, lon: time is aggregated away, lat/lon survive. + # Grouping by the time coordinate keeps time as the sole dimension; the + # ORDER BY makes the result order deterministic so no sort is needed below. out = ctx.sql( - "SELECT lat, lon, AVG(air) AS air_avg FROM air GROUP BY lat, lon" + 'SELECT "time", AVG("air") AS air FROM "air" ' + 'GROUP BY "time" ORDER BY "time"' ).to_dataset() - assert set(out.dims) == {"lat", "lon"} - assert "air_avg" in out.data_vars - expected = ( - air_dataset_small.compute() - .sortby(["lat", "lon"]) - .mean(dim="time")["air"] - .values - ) - np.testing.assert_allclose( - out.sortby(["lat", "lon"])["air_avg"].values, expected - ) - - # The reporter's exact case: GROUP BY the time coordinate. - single = ctx.sql( - 'SELECT "time", AVG("air") AS air FROM "air" GROUP BY "time"' - ).to_dataset() - assert set(single.dims) == {"time"} - assert "air" in single.data_vars - expected_t = ( - air_dataset_small.compute() - .sortby("time") - .mean(dim=["lat", "lon"])["air"] - .values - ) - np.testing.assert_allclose(single.sortby("time")["air"].values, expected_t) + assert set(out.dims) == {"time"} + assert "air" in out.data_vars + expected = air_dataset_small.compute().mean(dim=["lat", "lon"])["air"] + np.testing.assert_allclose(out["air"].values, expected.values) def test_barrier_query_scans_source_once(air_dataset_small): @@ -202,7 +183,7 @@ def test_barrier_query_scans_source_once(air_dataset_small): out = ctx.sql( "SELECT lat, lon, AVG(air) AS air_avg FROM air GROUP BY lat, lon" - ).to_dataset(dims=["lat", "lon"]) + ).to_dataset() reads_after_construct = len(reads) out.compute() reads_after_compute = len(reads) @@ -224,7 +205,7 @@ def test_order_by_direction_sets_dim_order(air_dataset_small): ctx.from_dataset("air", air_dataset_small) out = ctx.sql( "SELECT lat, AVG(air) AS air_avg FROM air GROUP BY lat ORDER BY lat DESC" - ).to_dataset(dims=["lat"]) + ).to_dataset() lat = out["lat"].values assert (np.diff(lat) < 0).all(), f"expected descending lat, got {lat}" @@ -325,7 +306,7 @@ def test_fast_path_uses_scanned_tables_coords_not_user_template( def test_round_trip_preserves_descending_lat_on_lazy_path(air_dataset_small): - """Lazy round-trip preserves source dim order (xarray-sql#171). + """Lazy round-trip preserves source dim order. NCEP ``air_temperature`` ships descending lat (75.0 -> 15.0). The discovery path's ``.distinct().sort()`` previously flipped lat to @@ -481,7 +462,7 @@ def test_template_aggregation_alias_no_attrs(air_dataset_small): ctx.from_dataset("air", ds) out = ctx.sql( "SELECT lat, lon, AVG(air) AS air_avg FROM air GROUP BY lat, lon" - ).to_dataset(dims=["lat", "lon"]) + ).to_dataset() assert "air_avg" in out.data_vars assert out["air_avg"].attrs == {} From f905ad8fada9f5080205d9b640ccefb06eaed4f6 Mon Sep 17 00:00:00 2001 From: ghostiee-11 Date: Sat, 4 Jul 2026 15:56:00 +0530 Subject: [PATCH 3/4] docs: note to_dataset dim inference; refine AGENTS.md per review Document that to_dataset() infers dimensions from the registered table's surviving dims (README and examples): a GROUP BY on a real dimension needs no dims=, while a derived column like month is still named explicitly. Drop the now-redundant dims= from the era5 example. Refine AGENTS.md per maintainer feedback: drop the too-specific entry-point rule, split imports into its own section noting transitive dependencies are safe to import non-locally, and remove the non-essential commit-prefix rule. --- AGENTS.md | 10 ++++------ README.md | 13 +++++++++---- docs/examples.md | 6 ++++-- 3 files changed, 17 insertions(+), 12 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index 1320cfd..fc419a1 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -16,8 +16,6 @@ summarizes recurring maintainer review feedback so changes land clean. - Mark internal helpers private with a leading underscore when they are not part of the public API. -- Prefer doing setup work in the functional entry point (e.g. `read_xarray`) - rather than in a class constructor; keep constructors minimal. ## Tests @@ -32,8 +30,8 @@ summarizes recurring maintainer review feedback so changes land clean. registered Datasets, or a test that is specifically exercising those arguments). -## Commits +## Imports -- Use conventional commit prefixes: `fix:`, `feat:`, `refactor:`, `chore:`, - `docs:`, `test:`. -- Keep imports at the top of the file. +- Keep imports at the top of the file. Assume transitive dependencies are safe + to import non-locally, rather than deferring imports into functions to avoid + a dependency. diff --git a/README.md b/README.md index 0356c6a..03d3b37 100644 --- a/README.md +++ b/README.md @@ -47,9 +47,12 @@ clim = ctx.sql(''' ORDER BY month ''') -# Write the SQL result back to an Xarray Dataset. `month` is a derived -# column, so name it as the dimension; the variable's units are recovered -# from the registered table. The result is one value per month: air(month). +# Write the SQL result back to an Xarray Dataset. `to_dataset()` infers the +# dimensions from the registered table's dims that survive the query, so a +# GROUP BY on a real dimension (e.g. `time`) needs no `dims=`. Here `month` +# is a derived column, not a registered dim, so name it explicitly. The +# variable's units are recovered from the registered table. One value per +# month: air(month). clim_ds = clim.to_dataset(dims=["month"]) # Plot the annual cycle as a time series. @@ -138,7 +141,9 @@ ctx.sql(''' AND TIMESTAMP '2020-01-01 05:00:00' GROUP BY latitude, longitude ORDER BY latitude DESC, longitude -''').to_dataset(dims=['latitude', 'longitude'], template=ds) +# `latitude`/`longitude` are inferred from the registered table's surviving +# dims; `template` is kept only to recover metadata (attrs, encoding). +''').to_dataset(template=ds) # Size: 8MB # Dimensions: (latitude: 721, longitude: 1440) # Coordinates: diff --git a/docs/examples.md b/docs/examples.md index 01b9956..f80c5b5 100644 --- a/docs/examples.md +++ b/docs/examples.md @@ -34,8 +34,10 @@ clim = ctx.sql(''' clim.to_pandas().head() # Option 2: round-trip back to an Xarray Dataset and plot the annual cycle as -# a time series. `month` is a derived column, so name it as the dimension; the -# variable's units are recovered from the registered table. +# a time series. `to_dataset()` infers dimensions from the registered table's +# surviving dims, so a GROUP BY on a real dimension needs no `dims=`. Here +# `month` is a derived column, not a registered dim, so name it explicitly; +# the variable's units are recovered from the registered table. clim_ds = clim.to_dataset(dims=["month"]) clim_ds["air"].plot() ``` From 8c5553022e7ad28765c991338a874ff7fbe33622 Mon Sep 17 00:00:00 2001 From: ghostiee-11 Date: Sat, 4 Jul 2026 16:15:56 +0530 Subject: [PATCH 4/4] docs: trim the quickstart to_dataset comment Shorten the climatology comment in the README quickstart; the fuller explanation of dimension inference lives in docs/examples.md. --- README.md | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 03d3b37..a254d29 100644 --- a/README.md +++ b/README.md @@ -47,12 +47,8 @@ clim = ctx.sql(''' ORDER BY month ''') -# Write the SQL result back to an Xarray Dataset. `to_dataset()` infers the -# dimensions from the registered table's dims that survive the query, so a -# GROUP BY on a real dimension (e.g. `time`) needs no `dims=`. Here `month` -# is a derived column, not a registered dim, so name it explicitly. The -# variable's units are recovered from the registered table. One value per -# month: air(month). +# Round-trip the result back to Xarray. `month` is a derived column, so name +# it as the dimension. clim_ds = clim.to_dataset(dims=["month"]) # Plot the annual cycle as a time series.