diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..1320cfd --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,39 @@ +# AGENTS.md + +Guidance for contributors (including AI assistants) working on `xarray-sql`. It +summarizes recurring maintainer review feedback so changes land cleanly. + +## Documentation and comments + +- Keep docstrings and comments self-contained. Do **not** put GitHub issue or PR + numbers in docstrings or code comments; a reader should not need the issue + tracker to understand the code. Issue references belong in the commit message + and PR description (e.g. `Closes #189`), not in the source. +- Do not reference the review conversation, chat, or "the reporter" in comments. + Describe the behavior, not how it came up. + +## API surface + +- Mark internal helpers private with a leading underscore when they are not part + of the public API. +- Prefer doing setup work in the functional entry point (e.g. `read_xarray`) + rather than in a class constructor; keep constructors minimal. + +## Tests + +- Test the public contract (values, dims, coords, attrs), not internal call + counts or private classes, so the suite survives refactors. +- Avoid redundant tests: if a public-path test already covers a behavior, do not + add a second lower-level test for the same thing. +- Make query results deterministic with `ORDER BY` so assertions do not have to + re-sort the output. +- Do not pass `dims=` to `to_dataset()` when inference already resolves them. + Reserve explicit `dims=` / `template=` for genuinely ambiguous cases (multiple + registered Datasets, or a test that is specifically exercising those + arguments). + +## Commits and code style + +- Use conventional commit prefixes: `fix:`, `feat:`, `refactor:`, `chore:`, + `docs:`, `test:`. +- Keep imports at the top of the file. 
diff --git a/tests/test_ds.py b/tests/test_ds.py index aa3deb2..3682dac 100644 --- a/tests/test_ds.py +++ b/tests/test_ds.py @@ -125,7 +125,7 @@ def test_aggregation_drops_dim(air_dataset_small): ctx.from_dataset("air", air_dataset_small) out = ctx.sql( "SELECT lat, lon, AVG(air) AS air_avg FROM air GROUP BY lat, lon" - ).to_dataset(dims=["lat", "lon"]) + ).to_dataset() assert set(out.dims) == {"lat", "lon"} assert "air_avg" in out.data_vars assert "air" not in out.data_vars @@ -139,6 +139,23 @@ def test_aggregation_drops_dim(air_dataset_small): np.testing.assert_allclose(actual, expected) +def test_aggregation_infers_dims(air_dataset_small): + """to_dataset() infers the surviving GROUP BY dim when dims is omitted.""" + ctx = XarrayContext() + ctx.from_dataset("air", air_dataset_small) + + # Grouping by the time coordinate keeps time as the sole dimension; the + # ORDER BY makes the result order deterministic so no sort is needed below. + out = ctx.sql( + 'SELECT "time", AVG("air") AS air FROM "air" ' + 'GROUP BY "time" ORDER BY "time"' + ).to_dataset() + assert set(out.dims) == {"time"} + assert "air" in out.data_vars + expected = air_dataset_small.compute().mean(dim=["lat", "lon"])["air"] + np.testing.assert_allclose(out["air"].values, expected.values) + + def test_barrier_query_scans_source_once(air_dataset_small): """A barrier plan (aggregation) executes the source exactly once. 
@@ -166,7 +183,7 @@ def test_barrier_query_scans_source_once(air_dataset_small): out = ctx.sql( "SELECT lat, lon, AVG(air) AS air_avg FROM air GROUP BY lat, lon" - ).to_dataset(dims=["lat", "lon"]) + ).to_dataset() reads_after_construct = len(reads) out.compute() reads_after_compute = len(reads) @@ -188,7 +205,7 @@ def test_order_by_direction_sets_dim_order(air_dataset_small): ctx.from_dataset("air", air_dataset_small) out = ctx.sql( "SELECT lat, AVG(air) AS air_avg FROM air GROUP BY lat ORDER BY lat DESC" - ).to_dataset(dims=["lat"]) + ).to_dataset() lat = out["lat"].values assert (np.diff(lat) < 0).all(), f"expected descending lat, got {lat}" @@ -289,7 +306,7 @@ def test_fast_path_uses_scanned_tables_coords_not_user_template( def test_round_trip_preserves_descending_lat_on_lazy_path(air_dataset_small): - """Lazy round-trip preserves source dim order (xarray-sql#171). + """Lazy round-trip preserves source dim order. NCEP ``air_temperature`` ships descending lat (75.0 -> 15.0). The discovery path's ``.distinct().sort()`` previously flipped lat to @@ -383,14 +400,12 @@ def test_to_dataset_multi_registered_requires_explicit_template( assert set(out.dims) == {"time", "lat", "lon"} -def test_to_dataset_infer_fails_when_no_template_fits(air_dataset_small): - """If no registered Dataset's dims fit the result -> clear error.""" +def test_to_dataset_infer_fails_when_no_dim_survives(air_dataset_small): + """A global aggregation leaves no registered dim in the result -> clear error.""" ctx = XarrayContext() ctx.from_dataset("air", air_dataset_small) with pytest.raises(ValueError, match="dims cannot be inferred"): - ctx.sql( - "SELECT lat, lon, AVG(air) AS air_avg FROM air GROUP BY lat, lon" - ).to_dataset() + ctx.sql("SELECT AVG(air) AS air_avg FROM air").to_dataset() def test_template_accepts_name_or_dataset(air_dataset_small): @@ -447,7 +462,7 @@ def test_template_aggregation_alias_no_attrs(air_dataset_small): ctx.from_dataset("air", ds) out = ctx.sql( "SELECT lat, lon, 
AVG(air) AS air_avg FROM air GROUP BY lat, lon" - ).to_dataset(dims=["lat", "lon"]) + ).to_dataset() assert "air_avg" in out.data_vars assert out["air_avg"].attrs == {} diff --git a/xarray_sql/ds.py b/xarray_sql/ds.py index 5dfdf42..cd1db39 100644 --- a/xarray_sql/ds.py +++ b/xarray_sql/ds.py @@ -757,10 +757,12 @@ def to_dataset( Args: dims: Result columns to use as Dataset dimensions. When - ``None``, defaults to the dims of the registered Dataset - referenced by the SQL ``FROM`` clause (if exactly one - matches), or any single registered Dataset whose dims are - all present in the result columns. + ``None``, defaults to a registered Dataset's dimensions that + survive into the result columns, so an aggregation that drops + dims (e.g. ``GROUP BY time`` over a ``(time, lat, lon)`` grid) + round-trips on the remaining dim. Raises when no dimension + survives, or when several registered Datasets imply different + dims (pass ``dims`` explicitly then). template: Source to recover metadata (attrs, encoding, non-dim coordinates, dim-coord dtype) from. Either an ``xr.Dataset`` used directly, or the name of a registered table (e.g. @@ -879,33 +881,37 @@ def _infer_dimension_columns( ) -> list[str]: """Pick a default ``dimension_columns`` from the registry, or raise. - Uses the data variable's dim order (via :func:`_ds_var_dims`) so - the round-trip preserves the original axis order. + A registered Dataset's dims that survive into the result columns + become the dimensions, so aggregations that drop dims (e.g. + ``GROUP BY time`` over a ``(time, lat, lon)`` grid) round-trip on the + surviving dim(s). Uses the data variable's dim order (via + :func:`_ds_var_dims`) so the original axis order is preserved. 
""" result_cols = set(self._result_columns()) - if ( - preferred_template is not None - and set(preferred_template.dims) <= result_cols - ): - return _ds_var_dims(preferred_template) + + def surviving(template: xr.Dataset) -> list[str]: + # Template dims still present in the result, in var axis order. + return [d for d in _ds_var_dims(template) if d in result_cols] + + if preferred_template is not None: + preferred = surviving(preferred_template) + if preferred: + return preferred if not self._templates: raise ValueError( "dims cannot be inferred (no registered " "Dataset on this result); pass dims=[...] " "explicitly." ) - candidates = [ - _ds_var_dims(t) - for t in self._templates.values() - if set(t.dims) <= result_cols - ] + candidates = {tuple(surviving(t)) for t in self._templates.values()} + candidates.discard(()) # templates with no surviving dim if len(candidates) == 1: - return candidates[0] + return list(next(iter(candidates))) if not candidates: raise ValueError( - "dims cannot be inferred: no registered " - "Dataset has all of its dims present in the result " - "columns. Pass dims=[...] explicitly." + "dims cannot be inferred: no registered Dataset " + "dimension survives in the result columns. Pass " + "dims=[...] explicitly." ) raise ValueError( "dims cannot be inferred unambiguously: multiple "