Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 119 additions & 0 deletions tests/test_grid_index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
"""Integer grid-position columns (``<dim>_idx``) for exact grid joins.

Regridding and forecast alignment join a source grid to a table on the source
*coordinate*. Joining on the floating-point coordinate value is fragile — any
sub-ULP drift (e.g. a reproject/interp computed in float32) makes the equality
join silently drop rows. The opt-in ``<dim>_idx`` columns give exact integer
grid keys instead. These tests pin that the indices are global (partition
independent) and that an index-keyed regrid stays exact where a float join does
not.
"""

import numpy as np
import pyarrow as pa
import xarray as xr

from xarray_sql import XarrayContext
from xarray_sql.df import _ensure_default_indexes, _parse_schema


# Deliberately not float32-exact, so a float32 round-trip actually perturbs the
# bits (a float32-exact axis like 10.0/20.0 would round-trip unchanged).
_X = np.array([10.1, 20.2, 30.3])
_Y = np.array([1.1, 2.2, 3.3, 4.4])


def _src() -> xr.Dataset:
return xr.Dataset(
{"v": (("x", "y"), np.arange(12.0).reshape(3, 4))},
coords={"x": _X, "y": _Y},
)


def test_index_columns_are_int32_in_schema():
schema = _parse_schema(_ensure_default_indexes(_src()), index_columns=True)
assert schema.field("x_idx").type == pa.int32()
assert schema.field("y_idx").type == pa.int32()
# Not present unless requested.
plain = _parse_schema(_ensure_default_indexes(_src()))
assert "x_idx" not in plain.names


def test_index_columns_are_global_across_chunks():
"""Indices are absolute axis positions, not per-partition local ones."""
ctx = XarrayContext()
# chunk x into 3 single-row partitions: a local index would restart at 0
# in every partition; the global index must run 0,1,2.
ctx.from_dataset("g", _src(), chunks={"x": 1}, index_columns=True)
df = ctx.sql("SELECT x, y, x_idx, y_idx FROM g").to_pandas()

assert df["x_idx"].dtype == np.int32
xpos = {v: i for i, v in enumerate(_X)}
ypos = {v: i for i, v in enumerate(_Y)}
assert (df["x_idx"] == df["x"].map(xpos)).all()
assert (df["y_idx"] == df["y"].map(ypos)).all()


def _weights(perturb_f32: bool = False) -> xr.Dataset:
"""A tiny gather 'regrid': dst cell k reads one src cell, weight 1.0."""
# dst 0..4 map to source cells (x_idx, y_idx):
sx = np.array([0, 2, 1, 0, 2], dtype=np.int32)
sy = np.array([0, 3, 1, 3, 0], dtype=np.int32)
src_x = _X[sx]
src_y = _Y[sy]
if perturb_f32: # coords computed in single precision, as a reproject UDF
src_x = src_x.astype(np.float32).astype(np.float64)
src_y = src_y.astype(np.float32).astype(np.float64)
return xr.Dataset(
{
"dst_id": (("pair",), np.arange(5, dtype=np.int32)),
"src_x_idx": (("pair",), sx),
"src_y_idx": (("pair",), sy),
"src_x": (("pair",), src_x),
"src_y": (("pair",), src_y),
"weight": (("pair",), np.ones(5)),
}
)


INDEX_JOIN = """
SELECT w.dst_id, SUM(s.v * w.weight) AS out
FROM weights w JOIN src s
ON s.x_idx = w.src_x_idx AND s.y_idx = w.src_y_idx
GROUP BY w.dst_id ORDER BY w.dst_id
"""
FLOAT_JOIN = """
SELECT w.dst_id, SUM(s.v * w.weight) AS out
FROM weights w JOIN src s
ON s.x = w.src_x AND s.y = w.src_y
GROUP BY w.dst_id ORDER BY w.dst_id
"""


def _ctx(weights: xr.Dataset) -> XarrayContext:
ctx = XarrayContext()
ctx.from_dataset("src", _src(), chunks={"x": 1}, index_columns=True)
ctx.from_dataset("weights", weights, chunks={"pair": 5})
return ctx


def test_index_join_regrids_exactly():
"""Index-keyed regrid matches a direct numpy gather, across chunks."""
ctx = _ctx(_weights())
got = ctx.sql(INDEX_JOIN).to_pandas().sort_values("dst_id")

v = np.arange(12.0).reshape(3, 4)
sx = np.array([0, 2, 1, 0, 2])
sy = np.array([0, 3, 1, 3, 0])
expected = v[sx, sy]
np.testing.assert_allclose(got["out"].to_numpy(), expected)


def test_index_join_survives_float32_drift_where_float_join_does_not():
ctx = _ctx(_weights(perturb_f32=True))
idx = ctx.sql(INDEX_JOIN).to_pandas()
flt = ctx.sql(FLOAT_JOIN).to_pandas()
# Index join keeps every destination cell; the float-equality join drops the
# cells whose float32-roundtripped coord no longer bit-matches the source.
assert len(idx) == 5
assert len(flt) < 5
70 changes: 65 additions & 5 deletions xarray_sql/df.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,10 @@ def pivot(ds: xr.Dataset) -> pd.DataFrame:


def dataset_to_record_batch(
ds: xr.Dataset, schema: pa.Schema
ds: xr.Dataset,
schema: pa.Schema,
*,
index_offsets: Mapping[str, int] | None = None,
) -> pa.RecordBatch:
"""Convert an xarray Dataset partition to an Arrow RecordBatch.

Expand Down Expand Up @@ -267,10 +270,24 @@ def dataset_to_record_batch(
dim_names = list(ds.sizes.keys())
shape = tuple(ds.sizes[d] for d in dim_names)

offsets = index_offsets or {}
index_fields = {
index_column_name(d): (k, int(offsets.get(d, 0)))
for k, d in enumerate(dim_names)
}

arrays = []
for field in schema:
name = field.name
if name in ds.coords and name in ds.dims:
if name in index_fields:
# Absolute integer grid position for this dim, broadcast+ravelled.
axis, offset = index_fields[name]
idx = offset + np.arange(shape[axis], dtype=np.int64)
reshape = [1] * len(shape)
reshape[axis] = shape[axis]
arr = np.broadcast_to(idx.reshape(reshape), shape).ravel()
arrays.append(pa.array(arr, type=field.type))
elif name in ds.coords and name in ds.dims:
# Broadcast 1-D coordinate to the full N-D partition shape, then ravel.
axis = dim_names.index(name)
coord = ds.coords[name].values
Expand Down Expand Up @@ -302,6 +319,8 @@ def iter_record_batches(
ds: xr.Dataset,
schema: pa.Schema,
batch_size: int = DEFAULT_BATCH_SIZE,
*,
index_offsets: Mapping[str, int] | None = None,
) -> Iterator[pa.RecordBatch]:
"""Yield RecordBatches of at most *batch_size* rows from a partition Dataset.

Expand Down Expand Up @@ -350,11 +369,21 @@ def iter_record_batches(
# Flat row index i → coordinate index for dim k: (i // stride[k]) % shape[k].
strides = [int(np.prod(shape[k + 1 :])) for k in range(len(shape))]

# Integer grid-position (`<dim>_idx`) columns, if the schema requests them:
# map the column name to its dim position and absolute block offset so the
# emitted index is global (partition-independent), which is what a join key
# across partitions requires.
offsets = index_offsets or {}
index_fields = {
index_column_name(d): (k, int(offsets.get(d, 0)))
for k, d in enumerate(dim_names)
}

# Load data-variable arrays fully (triggers Dask/Zarr compute once).
# ravel() is a zero-copy view for C-contiguous arrays.
data_arrays = {}
for field in schema:
if field.name not in ds.dims:
if field.name not in ds.dims and field.name not in index_fields:
raw = ds[field.name].values
if cft.is_cftime(raw):
data_arrays[field.name] = cft.convert_for_field(raw, field)
Expand All @@ -368,7 +397,11 @@ def iter_record_batches(
arrays = []
for field in schema:
name = field.name
if name in ds.coords and name in ds.dims:
if name in index_fields:
k, offset = index_fields[name]
idx = offset + (row_idx // strides[k]) % shape[k]
arrays.append(pa.array(idx, type=field.type))
elif name in ds.coords and name in ds.dims:
k = dim_names.index(name)
coord_idx = (row_idx // strides[k]) % shape[k]
arrays.append(
Expand All @@ -386,14 +419,30 @@ def iter_record_batches(
yield pa.RecordBatch.from_arrays(arrays, schema=schema)


def _parse_schema(ds: xr.Dataset) -> pa.Schema:
#: Suffix for the integer grid-position column of a dimension.
INDEX_COLUMN_SUFFIX = "_idx"


def index_column_name(dim: str) -> str:
"""Name of the integer grid-position column for dimension ``dim``."""
return f"{dim}{INDEX_COLUMN_SUFFIX}"


def _parse_schema(ds: xr.Dataset, *, index_columns: bool = False) -> pa.Schema:
"""Extracts a `pa.Schema` from the Dataset, treating dims and data_vars as columns.

Only *dimension coordinates* become dimension columns, so a dimension
without a coordinate would be dropped. Callers must run the Dataset through
:func:`_ensure_default_indexes` first (the readers do) so every dimension
has a coordinate and appears as a column.

When ``index_columns`` is set, an ``int32`` ``<dim>_idx`` column is appended
for every dimension, carrying each row's absolute integer position along that
axis. These are exact integer keys for grid joins (regridding weight tables,
forecast alignment) — faster than, and free of the float-equality fragility
of, joining on the floating-point coordinate values, while the coordinate
columns remain available for value predicates and display.

Uses the xarray index type to detect cftime coordinates without
materializing their data — important for Dask/Zarr-backed datasets
where .values would trigger eager computation.
Expand Down Expand Up @@ -431,6 +480,17 @@ def _parse_schema(ds: xr.Dataset) -> pa.Schema:
pa_type = pa.from_numpy_dtype(var.dtype)
columns.append(pa.field(var_name, pa_type))

if index_columns:
existing = {f.name for f in columns}
for dim in ds.dims:
name = index_column_name(str(dim))
if name in existing:
raise ValueError(
f"cannot add index column {name!r}: a column with that "
f"name already exists (dimension {dim!r})"
)
columns.append(pa.field(name, pa.int32()))

return pa.schema(columns)


Expand Down
20 changes: 15 additions & 5 deletions xarray_sql/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ def read_xarray_table(
chunks: Chunks = None,
*,
batch_size: int = DEFAULT_BATCH_SIZE,
index_columns: bool = False,
coord_arrays: dict[str, np.ndarray] | None = None,
_iteration_callback: (
Callable[[Block, list[str] | None], None] | None
Expand Down Expand Up @@ -257,7 +258,7 @@ def read_xarray_table(
from ._native import LazyArrowStreamTable

ds = _ensure_default_indexes(ds)
schema = _parse_schema(ds)
schema = _parse_schema(ds, index_columns=index_columns)

# Hoist coordinate reads once; avoids N_partitions remote I/O calls for
# Zarr-backed datasets (e.g. ARCO-ERA5 on GCS). When the caller supplies
Expand All @@ -281,15 +282,16 @@ def make_stream(

if projection_names is not None:
# Restrict to the data variables mentioned in the projection.
# Dimension coordinates come along automatically via coords.
# Dimension coordinates come along automatically via coords;
# index columns are computed from position, not loaded.
data_vars_needed = [
c for c in projection_names if c in data_var_names
]
if data_vars_needed:
ds_block = ds[data_vars_needed].isel(block)
else:
# Only dimension coords requested — drop all data vars to avoid
# loading them unnecessarily (e.g. for queries like SELECT lat, lon).
# Only dimension coords / index columns requested — drop all
# data vars to avoid loading them (e.g. SELECT lat, lon_idx).
ds_block = ds.drop_vars(list(ds.data_vars)).isel(block)
batch_schema = pa.schema(
[schema.field(name) for name in projection_names]
Expand All @@ -298,9 +300,17 @@ def make_stream(
ds_block = ds.isel(block)
batch_schema = schema

# Absolute start of this block on each axis, so `<dim>_idx` columns
# carry global positions that line up across partitions.
index_offsets = {str(dim): (block[dim].start or 0) for dim in block}
return pa.RecordBatchReader.from_batches(
batch_schema,
iter_record_batches(ds_block, batch_schema, batch_size),
iter_record_batches(
ds_block,
batch_schema,
batch_size,
index_offsets=index_offsets,
),
)

return make_stream
Expand Down
23 changes: 21 additions & 2 deletions xarray_sql/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def from_dataset(
*,
table_names: dict[tuple[str, ...], str] | None = None,
chunks: Chunks = None,
index_columns: bool = False,
):
"""Register an xarray Dataset as one or more queryable SQL tables.

Expand Down Expand Up @@ -83,6 +84,13 @@ def from_dataset(
variables with differing dimensions.
chunks: Xarray-like chunks specification. If not provided, uses
the Dataset's existing chunks.
index_columns: When True, add an ``int32`` ``<dim>_idx`` column for
every dimension, carrying each row's absolute integer position
on that axis. These are exact, overflow-free integer keys for
grid joins (regridding weight tables, forecast alignment) — join
on them instead of the floating-point coordinates to avoid
float-equality mismatches, while the coordinate columns stay
available for value predicates and display.

Returns:
self, to allow chaining.
Expand All @@ -99,7 +107,11 @@ def from_dataset(
if len(groups) <= 1:
self._registered_datasets[name] = input_table
return self._from_dataset(
name, input_table, chunks, coord_arrays=coord_arrays
name,
input_table,
chunks,
coord_arrays=coord_arrays,
index_columns=index_columns,
)

table_names = table_names or {}
Expand All @@ -117,6 +129,7 @@ def from_dataset(
chunks,
schema=schema,
coord_arrays=coord_arrays,
index_columns=index_columns,
)
# Track the fully-qualified name so XarrayDataFrame metadata
# recovery can find this Dataset on round-trip.
Expand All @@ -131,6 +144,7 @@ def _from_dataset(
chunks: Chunks = None,
schema: Schema | None = None,
coord_arrays: dict | None = None,
index_columns: bool = False,
):
"""Register a Dataset as a single SQL table.

Expand All @@ -142,7 +156,12 @@ def _from_dataset(
)
register(
table_name,
read_xarray_table(input_table, chunks, coord_arrays=coord_arrays),
read_xarray_table(
input_table,
chunks,
index_columns=index_columns,
coord_arrays=coord_arrays,
),
)
self._maybe_register_cftime_udf(input_table)
return self
Expand Down
Loading