Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 113 additions & 0 deletions tests/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,3 +488,116 @@ def test_table_names_is_keyword_only(self, mixed_ds):
ctx = XarrayContext()
with pytest.raises(TypeError):
ctx.from_dataset("era5", mixed_ds, {("time",): "x"})

@pytest.fixture
def coordless_dims_ds(self):
    """Build a small chunked dataset shaped like fashion-mnist.

    ``sample`` is a real dimension coordinate, whereas ``channel``,
    ``height`` and ``width`` are dimensions without coordinates. The
    dataset is chunked one sample per block.
    """
    image_dims = ["sample", "channel", "height", "width"]
    shape = (4, 1, 3, 3)
    n_sample = shape[0]

    # Distinct, monotonically increasing pixel values so every cell is
    # identifiable after the dataset is flattened into rows.
    pixels = np.arange(int(np.prod(shape)), dtype="float32").reshape(shape)
    labels = np.arange(n_sample, dtype="int64")
    sample_index = np.arange(n_sample, dtype="int64")

    dataset = xr.Dataset(
        {
            "images": (image_dims, pixels),
            "labels": (["sample"], labels),
        },
        coords={"sample": ("sample", sample_index)},
    )
    return dataset.chunk({"sample": 1})

def test_coordless_dims_appear_as_columns(self, coordless_dims_ds):
    """Every dimension of the ``X`` table must surface as a column.

    Dimensions that have no coordinate must not be silently dropped from
    the schema.
    """
    table_names = {
        ("sample", "channel", "height", "width"): "X",
        ("sample",): "y",
    }
    expected_columns = {"sample", "channel", "height", "width", "images"}

    ctx = XarrayContext()
    ctx.from_dataset("mnist", coordless_dims_ds, table_names=table_names)

    frame = ctx.sql('SELECT * FROM mnist."X"').to_pandas()
    assert set(frame.columns) == expected_columns

def test_coordless_dims_values_match_xarray(self, coordless_dims_ds):
    """Rows of the ``X`` table must equal xarray's own flattening.

    This includes the index values synthesized for the dimensions that
    carry no coordinate.
    """
    dim_cols = ["sample", "channel", "height", "width"]

    def _canonical(frame):
        # Row order is not part of the contract; sort by the dimension
        # columns and renumber so both frames line up.
        return frame.sort_values(dim_cols).reset_index(drop=True)

    table_names = {tuple(dim_cols): "X", ("sample",): "y"}
    ctx = XarrayContext()
    ctx.from_dataset("mnist", coordless_dims_ds, table_names=table_names)

    result = _canonical(ctx.sql('SELECT * FROM mnist."X"').to_pandas())
    pivoted = coordless_dims_ds[["images"]].to_dataframe().reset_index()
    expected = _canonical(pivoted)

    pd.testing.assert_frame_equal(
        result,
        expected,
        check_dtype=False,
        check_like=True,
    )

def test_coordless_dims_y_table_unaffected(self, coordless_dims_ds):
    """The 1-D ``y`` table keeps just ``sample`` and ``labels``."""
    table_names = {
        ("sample", "channel", "height", "width"): "X",
        ("sample",): "y",
    }
    ctx = XarrayContext()
    ctx.from_dataset("mnist", coordless_dims_ds, table_names=table_names)

    y_frame = ctx.sql('SELECT * FROM mnist."y"').to_pandas()
    n_samples = coordless_dims_ds.sizes["sample"]

    assert set(y_frame.columns) == {"sample", "labels"}
    # One row per sample.
    assert len(y_frame) == n_samples

def test_single_table_all_coordless_dims(self):
    """A dataset with only coordinate-less dims registers as one table.

    Every dimension must be present as a column, and the coordinate-less
    dimensions must carry their ABSOLUTE index even when chunked.
    """
    sort_keys = ["x", "y"]
    values = np.arange(6, dtype="float32").reshape(3, 2)
    # Chunk along 'x', which has no coordinate: three blocks of one row.
    ds = xr.Dataset({"a": (("x", "y"), values)}).chunk({"x": 1})

    ctx = XarrayContext()
    ctx.from_dataset("grid", ds)

    queried = ctx.sql("SELECT * FROM grid").to_pandas()
    result = queried.sort_values(sort_keys).reset_index(drop=True)
    assert set(result.columns) == {"x", "y", "a"}

    # xarray's own flattening is the reference: x must run 0..2 across
    # the three chunks instead of restarting at 0 inside each block.
    flattened = ds.to_dataframe().reset_index()
    expected = flattened.sort_values(sort_keys).reset_index(drop=True)
    pd.testing.assert_frame_equal(
        result,
        expected,
        check_dtype=False,
        check_like=True,
    )
23 changes: 23 additions & 0 deletions xarray_sql/df.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,24 @@ def resolve_chunks(
return {d: tuple(c) for d, c in ds.chunks.items()}


def _ensure_default_indexes(ds: xr.Dataset) -> xr.Dataset:
    """Give every coordinate-less dimension an explicit integer index.

    xarray permits "dimensions without coordinates". Such dimensions are
    missing from ``ds.coords``, which means they would be left out of the
    SQL schema, and after a block is cut out with ``isel`` their position
    would be synthesized *relative to that block* (starting again at 0 in
    each partition). Assigning an ``arange`` coordinate up front makes them
    ordinary dimension coordinates, so they show up as columns and keep
    their absolute position across chunked reads.

    Args:
        ds: The Dataset to normalise.

    Returns:
        ``ds`` itself when every dimension already has a coordinate;
        otherwise the result of ``ds.assign_coords`` with one ``arange``
        coordinate per dimension that lacked one.
    """
    default_indexes = {}
    for name in ds.dims:
        if name in ds.coords:
            # Already backed by a coordinate; leave it alone.
            continue
        default_indexes[name] = np.arange(ds.sizes[name])

    if not default_indexes:
        return ds
    return ds.assign_coords(default_indexes)


def _block_slices_from_resolved(
ds: xr.Dataset, resolved: Mapping[Hashable, tuple[int, ...]]
) -> Iterator[Block]:
Expand Down Expand Up @@ -371,6 +389,11 @@ def iter_record_batches(
def _parse_schema(ds: xr.Dataset) -> pa.Schema:
"""Extracts a `pa.Schema` from the Dataset, treating dims and data_vars as columns.

Only *dimension coordinates* become dimension columns, so a dimension
without a coordinate would be dropped. Callers must run the Dataset through
:func:`_ensure_default_indexes` first (the readers do) so every dimension
has a coordinate and appears as a column.

Uses the xarray index type to detect cftime coordinates without
materializing their data — important for Dask/Zarr-backed datasets
where .values would trigger eager computation.
Expand Down
3 changes: 3 additions & 0 deletions xarray_sql/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
_block_len,
_block_metadata,
_block_slices_from_resolved,
_ensure_default_indexes,
_parse_schema,
block_slices,
iter_record_batches,
Expand Down Expand Up @@ -185,6 +186,7 @@ def read_xarray(ds: xr.Dataset, chunks: Chunks = None) -> pa.RecordBatchReader:
A PyArrow RecordBatchReader, which is a table representation of the input
Dataset.
"""
ds = _ensure_default_indexes(ds)
reader = XarrayRecordBatchReader(ds, chunks=chunks)
return pa.RecordBatchReader.from_stream(reader)

Expand Down Expand Up @@ -254,6 +256,7 @@ def read_xarray_table(
"""
from ._native import LazyArrowStreamTable

ds = _ensure_default_indexes(ds)
schema = _parse_schema(ds)

# Hoist coordinate reads once; avoids N_partitions remote I/O calls for
Expand Down
Loading