diff --git a/tests/test_sql.py b/tests/test_sql.py
index 857a7bc..746dfb3 100644
--- a/tests/test_sql.py
+++ b/tests/test_sql.py
@@ -488,3 +488,116 @@ def test_table_names_is_keyword_only(self, mixed_ds):
         ctx = XarrayContext()
         with pytest.raises(TypeError):
             ctx.from_dataset("era5", mixed_ds, {("time",): "x"})
+
+    @pytest.fixture
+    def coordless_dims_ds(self):
+        """Mirror the fashion-mnist layout: a dimension coordinate
+        (``sample``) alongside dimensions without coordinates
+        (``channel``/``height``/``width``), chunked one sample per block."""
+        n_sample, n_channel, n_height, n_width = 4, 1, 3, 3
+        return xr.Dataset(
+            {
+                "images": (
+                    ["sample", "channel", "height", "width"],
+                    np.arange(
+                        n_sample * n_channel * n_height * n_width,
+                        dtype="float32",
+                    ).reshape(n_sample, n_channel, n_height, n_width),
+                ),
+                "labels": (["sample"], np.arange(n_sample, dtype="int64")),
+            },
+            coords={"sample": ("sample", np.arange(n_sample, dtype="int64"))},
+        ).chunk({"sample": 1})
+
+    def test_coordless_dims_appear_as_columns(self, coordless_dims_ds):
+        """Dimensions without coordinates must still be emitted as columns,
+        not silently dropped from the schema."""
+        ctx = XarrayContext()
+        ctx.from_dataset(
+            "mnist",
+            coordless_dims_ds,
+            table_names={
+                ("sample", "channel", "height", "width"): "X",
+                ("sample",): "y",
+            },
+        )
+        result = ctx.sql('SELECT * FROM mnist."X"').to_pandas()
+        assert set(result.columns) == {
+            "sample",
+            "channel",
+            "height",
+            "width",
+            "images",
+        }
+
+    def test_coordless_dims_values_match_xarray(self, coordless_dims_ds):
+        """The X table's rows must match xarray's own pivot exactly, including
+        the synthesized index values for the coordinate-less dimensions."""
+        ctx = XarrayContext()
+        ctx.from_dataset(
+            "mnist",
+            coordless_dims_ds,
+            table_names={
+                ("sample", "channel", "height", "width"): "X",
+                ("sample",): "y",
+            },
+        )
+        dim_cols = ["sample", "channel", "height", "width"]
+        result = (
+            ctx.sql('SELECT * FROM mnist."X"')
+            .to_pandas()
+            .sort_values(dim_cols)
+            .reset_index(drop=True)
+        )
+        expected = (
+            coordless_dims_ds[["images"]]
+            .to_dataframe()
+            .reset_index()
+            .sort_values(dim_cols)
+            .reset_index(drop=True)
+        )
+        pd.testing.assert_frame_equal(
+            result, expected, check_dtype=False, check_like=True
+        )
+
+    def test_coordless_dims_y_table_unaffected(self, coordless_dims_ds):
+        """The 1-D ``y`` group (sample coordinate + labels) is unchanged."""
+        ctx = XarrayContext()
+        ctx.from_dataset(
+            "mnist",
+            coordless_dims_ds,
+            table_names={
+                ("sample", "channel", "height", "width"): "X",
+                ("sample",): "y",
+            },
+        )
+        result = ctx.sql('SELECT * FROM mnist."y"').to_pandas()
+        assert set(result.columns) == {"sample", "labels"}
+        assert len(result) == coordless_dims_ds.sizes["sample"]
+
+    def test_single_table_all_coordless_dims(self):
+        """A uniform-dim dataset whose dims lack coordinates registers as one
+        table with every dimension present as a column, and the coordinate-less
+        dimensions carry their ABSOLUTE index even when chunked."""
+        ds = xr.Dataset(
+            {"a": (("x", "y"), np.arange(6, dtype="float32").reshape(3, 2))}
+        ).chunk({"x": 1})  # chunked along the coordinate-less 'x' dim
+        ctx = XarrayContext()
+        ctx.from_dataset("grid", ds)
+        result = (
+            ctx.sql("SELECT * FROM grid")
+            .to_pandas()
+            .sort_values(["x", "y"])
+            .reset_index(drop=True)
+        )
+        assert set(result.columns) == {"x", "y", "a"}
+        # x must span 0..2 across the three chunks, not restart at 0 each block.
+        expected = (
+            ds.to_dataframe()
+            .reset_index()
+            .sort_values(["x", "y"])
+            .reset_index(drop=True)
+        )
+        pd.testing.assert_frame_equal(
+            result, expected, check_dtype=False, check_like=True
+        )
diff --git a/xarray_sql/df.py b/xarray_sql/df.py
index 99c8408..ab80056 100644
--- a/xarray_sql/df.py
+++ b/xarray_sql/df.py
@@ -73,6 +73,24 @@ def resolve_chunks(
     return {d: tuple(c) for d, c in ds.chunks.items()}
 
 
+def _ensure_default_indexes(ds: xr.Dataset) -> xr.Dataset:
+    """Attach a default integer index coordinate to every dimension lacking one.
+
+    xarray allows "dimensions without coordinates"; these are absent from
+    ``ds.coords``, so they are dropped from the SQL schema and, once a block is
+    sliced out with ``isel``, their position is synthesized *relative to the
+    block* (restarting at 0 in every partition). Materialising an explicit
+    ``arange`` index up front turns them into ordinary dimension coordinates, so
+    they appear as columns and carry their absolute position through chunked
+    reads. Datasets whose dimensions already have coordinates are returned
+    unchanged (the same object, not a copy).
+    """
+    missing = {
+        dim: np.arange(ds.sizes[dim]) for dim in ds.dims if dim not in ds.coords
+    }
+    return ds.assign_coords(missing) if missing else ds
+
+
 def _block_slices_from_resolved(
     ds: xr.Dataset, resolved: Mapping[Hashable, tuple[int, ...]]
 ) -> Iterator[Block]:
@@ -371,6 +389,11 @@ def iter_record_batches(
 def _parse_schema(ds: xr.Dataset) -> pa.Schema:
     """Extracts a `pa.Schema` from the Dataset, treating dims and data_vars as columns.
 
+    Only *dimension coordinates* become dimension columns, so a dimension
+    without a coordinate would be dropped. Callers must run the Dataset through
+    :func:`_ensure_default_indexes` first (the readers do) so every dimension
+    has a coordinate and appears as a column.
+
     Uses the xarray index type to detect cftime coordinates without
     materializing their data — important for Dask/Zarr-backed datasets where
     .values would trigger eager computation.
diff --git a/xarray_sql/reader.py b/xarray_sql/reader.py
index bfb7be2..153ed8c 100644
--- a/xarray_sql/reader.py
+++ b/xarray_sql/reader.py
@@ -24,6 +24,7 @@
     _block_len,
     _block_metadata,
     _block_slices_from_resolved,
+    _ensure_default_indexes,
     _parse_schema,
     block_slices,
     iter_record_batches,
@@ -185,6 +186,7 @@ def read_xarray(ds: xr.Dataset, chunks: Chunks = None) -> pa.RecordBatchReader:
       A PyArrow RecordBatchReader, which is a table representation of the input
       Dataset.
     """
+    ds = _ensure_default_indexes(ds)
     reader = XarrayRecordBatchReader(ds, chunks=chunks)
     return pa.RecordBatchReader.from_stream(reader)
 
@@ -254,6 +256,7 @@ def read_xarray_table(
     """
     from ._native import LazyArrowStreamTable
 
+    ds = _ensure_default_indexes(ds)
     schema = _parse_schema(ds)
 
     # Hoist coordinate reads once; avoids N_partitions remote I/O calls for