From d790a4cdaf58bd37245af2c04c7d386b2eba42c2 Mon Sep 17 00:00:00 2001 From: Claude Date: Fri, 3 Jul 2026 09:20:16 +0000 Subject: [PATCH] Add opt-in integer grid-position columns for exact grid joins MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Regridding and forecast alignment join a source grid to a table on the source coordinate. Joining on the floating-point coordinate value is fragile: any sub-ULP drift — e.g. a reproject/interp UDF that computes in float32 — makes the equality join silently drop rows, producing a wrong answer with no error. At realistic scale a float32 round-trip of the weight-table coordinates drops ~99.9% of destination cells. Add `from_dataset(..., index_columns=True)`: for every dimension, emit an `int32` `_idx` column carrying each row's *absolute* integer position on that axis. Joining grids on these integer keys is exact (no float-equality mismatch), a bit faster (integer hashing), and half the key bytes — and, unlike dictionary-encoded coordinates, they are plain Int32 columns, so nothing in DataFusion's join/aggregate/scalar-function paths is stressed. The indices are global, not per-partition: the reader adds each block's start offset so keys line up across chunks (a local index would restart at 0 in every partition and mis-join). Coordinate columns stay dense and available for value predicates (`WHERE lat > 45`, `date_part`) and display; the index columns are opt-in and off by default. - df.py: `_parse_schema(index_columns=)` appends the `_idx` fields (with a collision guard); `iter_record_batches` / `dataset_to_record_batch` emit them from the strided position plus a block offset. - reader.py / sql.py: thread `index_columns` through `read_xarray_table` and `from_dataset`, computing per-block offsets so indices are global. - tests: global-across-chunks, exact index-keyed regrid, and the float32-drift case where the float join drops cells but the index join stays exact. Co-Authored-By: Claude Opus 4.8 Claude-Session: https://claude.ai/code/session_019VuSeCio99NcME5eubcN3N --- tests/test_grid_index.py | 119 +++++++++++++++++++++++++++++++++++++++ xarray_sql/df.py | 70 +++++++++++++++++++++-- xarray_sql/reader.py | 20 +++++-- xarray_sql/sql.py | 23 +++++++- 4 files changed, 220 insertions(+), 12 deletions(-) create mode 100644 tests/test_grid_index.py diff --git a/tests/test_grid_index.py b/tests/test_grid_index.py new file mode 100644 index 0000000..84f6688 --- /dev/null +++ b/tests/test_grid_index.py @@ -0,0 +1,119 @@ +"""Integer grid-position columns (``_idx``) for exact grid joins. + +Regridding and forecast alignment join a source grid to a table on the source +*coordinate*. Joining on the floating-point coordinate value is fragile — any +sub-ULP drift (e.g. a reproject/interp computed in float32) makes the equality +join silently drop rows. The opt-in ``_idx`` columns give exact integer +grid keys instead. These tests pin that the indices are global (partition +independent) and that an index-keyed regrid stays exact where a float join does +not. +""" + +import numpy as np +import pyarrow as pa +import xarray as xr + +from xarray_sql import XarrayContext +from xarray_sql.df import _ensure_default_indexes, _parse_schema + + +# Deliberately not float32-exact, so a float32 round-trip actually perturbs the +# bits (a float32-exact axis like 10.0/20.0 would round-trip unchanged). +_X = np.array([10.1, 20.2, 30.3]) +_Y = np.array([1.1, 2.2, 3.3, 4.4]) + + +def _src() -> xr.Dataset: + return xr.Dataset( + {"v": (("x", "y"), np.arange(12.0).reshape(3, 4))}, + coords={"x": _X, "y": _Y}, + ) + + +def test_index_columns_are_int32_in_schema(): + schema = _parse_schema(_ensure_default_indexes(_src()), index_columns=True) + assert schema.field("x_idx").type == pa.int32() + assert schema.field("y_idx").type == pa.int32() + # Not present unless requested. + plain = _parse_schema(_ensure_default_indexes(_src())) + assert "x_idx" not in plain.names + + +def test_index_columns_are_global_across_chunks(): + """Indices are absolute axis positions, not per-partition local ones.""" + ctx = XarrayContext() + # chunk x into 3 single-row partitions: a local index would restart at 0 + # in every partition; the global index must run 0,1,2. + ctx.from_dataset("g", _src(), chunks={"x": 1}, index_columns=True) + df = ctx.sql("SELECT x, y, x_idx, y_idx FROM g").to_pandas() + + assert df["x_idx"].dtype == np.int32 + xpos = {v: i for i, v in enumerate(_X)} + ypos = {v: i for i, v in enumerate(_Y)} + assert (df["x_idx"] == df["x"].map(xpos)).all() + assert (df["y_idx"] == df["y"].map(ypos)).all() + + +def _weights(perturb_f32: bool = False) -> xr.Dataset: + """A tiny gather 'regrid': dst cell k reads one src cell, weight 1.0.""" + # dst 0..4 map to source cells (x_idx, y_idx): + sx = np.array([0, 2, 1, 0, 2], dtype=np.int32) + sy = np.array([0, 3, 1, 3, 0], dtype=np.int32) + src_x = _X[sx] + src_y = _Y[sy] + if perturb_f32: # coords computed in single precision, as a reproject UDF + src_x = src_x.astype(np.float32).astype(np.float64) + src_y = src_y.astype(np.float32).astype(np.float64) + return xr.Dataset( + { + "dst_id": (("pair",), np.arange(5, dtype=np.int32)), + "src_x_idx": (("pair",), sx), + "src_y_idx": (("pair",), sy), + "src_x": (("pair",), src_x), + "src_y": (("pair",), src_y), + "weight": (("pair",), np.ones(5)), + } + ) + + +INDEX_JOIN = """ + SELECT w.dst_id, SUM(s.v * w.weight) AS out + FROM weights w JOIN src s + ON s.x_idx = w.src_x_idx AND s.y_idx = w.src_y_idx + GROUP BY w.dst_id ORDER BY w.dst_id +""" +FLOAT_JOIN = """ + SELECT w.dst_id, SUM(s.v * w.weight) AS out + FROM weights w JOIN src s + ON s.x = w.src_x AND s.y = w.src_y + GROUP BY w.dst_id ORDER BY w.dst_id +""" + + +def _ctx(weights: xr.Dataset) -> XarrayContext: + ctx = XarrayContext() + ctx.from_dataset("src", _src(), chunks={"x": 1}, index_columns=True) + ctx.from_dataset("weights", weights, chunks={"pair": 5}) + return ctx + + +def test_index_join_regrids_exactly(): + """Index-keyed regrid matches a direct numpy gather, across chunks.""" + ctx = _ctx(_weights()) + got = ctx.sql(INDEX_JOIN).to_pandas().sort_values("dst_id") + + v = np.arange(12.0).reshape(3, 4) + sx = np.array([0, 2, 1, 0, 2]) + sy = np.array([0, 3, 1, 3, 0]) + expected = v[sx, sy] + np.testing.assert_allclose(got["out"].to_numpy(), expected) + + +def test_index_join_survives_float32_drift_where_float_join_does_not(): + ctx = _ctx(_weights(perturb_f32=True)) + idx = ctx.sql(INDEX_JOIN).to_pandas() + flt = ctx.sql(FLOAT_JOIN).to_pandas() + # Index join keeps every destination cell; the float-equality join drops the + # cells whose float32-roundtripped coord no longer bit-matches the source. + assert len(idx) == 5 + assert len(flt) < 5 diff --git a/xarray_sql/df.py b/xarray_sql/df.py index ab80056..cdd7f22 100644 --- a/xarray_sql/df.py +++ b/xarray_sql/df.py @@ -231,7 +231,10 @@ def pivot(ds: xr.Dataset) -> pd.DataFrame: def dataset_to_record_batch( - ds: xr.Dataset, schema: pa.Schema + ds: xr.Dataset, + schema: pa.Schema, + *, + index_offsets: Mapping[str, int] | None = None, ) -> pa.RecordBatch: """Convert an xarray Dataset partition to an Arrow RecordBatch. @@ -267,10 +270,24 @@ def dataset_to_record_batch( dim_names = list(ds.sizes.keys()) shape = tuple(ds.sizes[d] for d in dim_names) + offsets = index_offsets or {} + index_fields = { + index_column_name(d): (k, int(offsets.get(d, 0))) + for k, d in enumerate(dim_names) + } + arrays = [] for field in schema: name = field.name - if name in ds.coords and name in ds.dims: + if name in index_fields: + # Absolute integer grid position for this dim, broadcast+ravelled. + axis, offset = index_fields[name] + idx = offset + np.arange(shape[axis], dtype=np.int64) + reshape = [1] * len(shape) + reshape[axis] = shape[axis] + arr = np.broadcast_to(idx.reshape(reshape), shape).ravel() + arrays.append(pa.array(arr, type=field.type)) + elif name in ds.coords and name in ds.dims: # Broadcast 1-D coordinate to the full N-D partition shape, then ravel. axis = dim_names.index(name) coord = ds.coords[name].values @@ -302,6 +319,8 @@ def iter_record_batches( ds: xr.Dataset, schema: pa.Schema, batch_size: int = DEFAULT_BATCH_SIZE, + *, + index_offsets: Mapping[str, int] | None = None, ) -> Iterator[pa.RecordBatch]: """Yield RecordBatches of at most *batch_size* rows from a partition Dataset. @@ -350,11 +369,21 @@ def iter_record_batches( # Flat row index i → coordinate index for dim k: (i // stride[k]) % shape[k]. strides = [int(np.prod(shape[k + 1 :])) for k in range(len(shape))] + # Integer grid-position (`_idx`) columns, if the schema requests them: + # map the column name to its dim position and absolute block offset so the + # emitted index is global (partition-independent), which is what a join key + # across partitions requires. + offsets = index_offsets or {} + index_fields = { + index_column_name(d): (k, int(offsets.get(d, 0))) + for k, d in enumerate(dim_names) + } + # Load data-variable arrays fully (triggers Dask/Zarr compute once). # ravel() is a zero-copy view for C-contiguous arrays. data_arrays = {} for field in schema: - if field.name not in ds.dims: + if field.name not in ds.dims and field.name not in index_fields: raw = ds[field.name].values if cft.is_cftime(raw): data_arrays[field.name] = cft.convert_for_field(raw, field) @@ -368,7 +397,11 @@ def iter_record_batches( arrays = [] for field in schema: name = field.name - if name in ds.coords and name in ds.dims: + if name in index_fields: + k, offset = index_fields[name] + idx = offset + (row_idx // strides[k]) % shape[k] + arrays.append(pa.array(idx, type=field.type)) + elif name in ds.coords and name in ds.dims: k = dim_names.index(name) coord_idx = (row_idx // strides[k]) % shape[k] arrays.append( @@ -386,7 +419,16 @@ def iter_record_batches( yield pa.RecordBatch.from_arrays(arrays, schema=schema) -def _parse_schema(ds: xr.Dataset) -> pa.Schema: +#: Suffix for the integer grid-position column of a dimension. +INDEX_COLUMN_SUFFIX = "_idx" + + +def index_column_name(dim: str) -> str: + """Name of the integer grid-position column for dimension ``dim``.""" + return f"{dim}{INDEX_COLUMN_SUFFIX}" + + +def _parse_schema(ds: xr.Dataset, *, index_columns: bool = False) -> pa.Schema: """Extracts a `pa.Schema` from the Dataset, treating dims and data_vars as columns. Only *dimension coordinates* become dimension columns, so a dimension @@ -394,6 +436,13 @@ def _parse_schema(ds: xr.Dataset) -> pa.Schema: :func:`_ensure_default_indexes` first (the readers do) so every dimension has a coordinate and appears as a column. + When ``index_columns`` is set, an ``int32`` ``_idx`` column is appended + for every dimension, carrying each row's absolute integer position along that + axis. These are exact integer keys for grid joins (regridding weight tables, + forecast alignment) — faster than, and free of the float-equality fragility + of, joining on the floating-point coordinate values, while the coordinate + columns remain available for value predicates and display. + Uses the xarray index type to detect cftime coordinates without materializing their data — important for Dask/Zarr-backed datasets where .values would trigger eager computation. @@ -431,6 +480,17 @@ def _parse_schema(ds: xr.Dataset) -> pa.Schema: pa_type = pa.from_numpy_dtype(var.dtype) columns.append(pa.field(var_name, pa_type)) + if index_columns: + existing = {f.name for f in columns} + for dim in ds.dims: + name = index_column_name(str(dim)) + if name in existing: + raise ValueError( + f"cannot add index column {name!r}: a column with that " + f"name already exists (dimension {dim!r})" + ) + columns.append(pa.field(name, pa.int32())) + return pa.schema(columns) diff --git a/xarray_sql/reader.py b/xarray_sql/reader.py index 153ed8c..c85800c 100644 --- a/xarray_sql/reader.py +++ b/xarray_sql/reader.py @@ -196,6 +196,7 @@ def read_xarray_table( chunks: Chunks = None, *, batch_size: int = DEFAULT_BATCH_SIZE, + index_columns: bool = False, coord_arrays: dict[str, np.ndarray] | None = None, _iteration_callback: ( Callable[[Block, list[str] | None], None] | None @@ -257,7 +258,7 @@ def read_xarray_table( from ._native import LazyArrowStreamTable ds = _ensure_default_indexes(ds) - schema = _parse_schema(ds) + schema = _parse_schema(ds, index_columns=index_columns) # Hoist coordinate reads once; avoids N_partitions remote I/O calls for # Zarr-backed datasets (e.g. ARCO-ERA5 on GCS). When the caller supplies @@ -281,15 +282,16 @@ def make_stream( if projection_names is not None: # Restrict to the data variables mentioned in the projection. - # Dimension coordinates come along automatically via coords. + # Dimension coordinates come along automatically via coords; + # index columns are computed from position, not loaded. data_vars_needed = [ c for c in projection_names if c in data_var_names ] if data_vars_needed: ds_block = ds[data_vars_needed].isel(block) else: - # Only dimension coords requested — drop all data vars to avoid - # loading them unnecessarily (e.g. for queries like SELECT lat, lon). + # Only dimension coords / index columns requested — drop all + # data vars to avoid loading them (e.g. SELECT lat, lon_idx). ds_block = ds.drop_vars(list(ds.data_vars)).isel(block) batch_schema = pa.schema( [schema.field(name) for name in projection_names] @@ -298,9 +300,17 @@ def make_stream( ds_block = ds.isel(block) batch_schema = schema + # Absolute start of this block on each axis, so `_idx` columns + # carry global positions that line up across partitions. + index_offsets = {str(dim): (block[dim].start or 0) for dim in block} return pa.RecordBatchReader.from_batches( batch_schema, - iter_record_batches(ds_block, batch_schema, batch_size), + iter_record_batches( + ds_block, + batch_schema, + batch_size, + index_offsets=index_offsets, + ), ) return make_stream diff --git a/xarray_sql/sql.py b/xarray_sql/sql.py index 0ec60ad..486aa63 100644 --- a/xarray_sql/sql.py +++ b/xarray_sql/sql.py @@ -29,6 +29,7 @@ def from_dataset( *, table_names: dict[tuple[str, ...], str] | None = None, chunks: Chunks = None, + index_columns: bool = False, ): """Register an xarray Dataset as one or more queryable SQL tables. @@ -83,6 +84,13 @@ def from_dataset( variables with differing dimensions. chunks: Xarray-like chunks specification. If not provided, uses the Dataset's existing chunks. + index_columns: When True, add an ``int32`` ``_idx`` column for + every dimension, carrying each row's absolute integer position + on that axis. These are exact, overflow-free integer keys for + grid joins (regridding weight tables, forecast alignment) — join + on them instead of the floating-point coordinates to avoid + float-equality mismatches, while the coordinate columns stay + available for value predicates and display. Returns: self, to allow chaining. @@ -99,7 +107,11 @@ def from_dataset( if len(groups) <= 1: self._registered_datasets[name] = input_table return self._from_dataset( - name, input_table, chunks, coord_arrays=coord_arrays + name, + input_table, + chunks, + coord_arrays=coord_arrays, + index_columns=index_columns, ) table_names = table_names or {} @@ -117,6 +129,7 @@ def from_dataset( chunks, schema=schema, coord_arrays=coord_arrays, + index_columns=index_columns, ) # Track the fully-qualified name so XarrayDataFrame metadata # recovery can find this Dataset on round-trip. @@ -131,6 +144,7 @@ def _from_dataset( chunks: Chunks = None, schema: Schema | None = None, coord_arrays: dict | None = None, + index_columns: bool = False, ): """Register a Dataset as a single SQL table. @@ -142,7 +156,12 @@ def _from_dataset( ) register( table_name, - read_xarray_table(input_table, chunks, coord_arrays=coord_arrays), + read_xarray_table( + input_table, + chunks, + index_columns=index_columns, + coord_arrays=coord_arrays, + ), ) self._maybe_register_cftime_udf(input_table) return self