From 9533b25cf745db131cf726937242f314d046d783 Mon Sep 17 00:00:00 2001 From: Claude Date: Thu, 2 Jul 2026 16:26:35 +0000 Subject: [PATCH 1/4] Dictionary-encode coordinate columns MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit A dense grid repeats every coordinate value across the whole partition (a chunk of shape (time, lat, lon) carries each latitude time×lon times). The reader materialized coordinate columns as dense, fully-repeated Arrow arrays, so every GROUP BY / JOIN on a coordinate re-hashed a hugely redundant column and the pivot moved far more bytes than the data itself. Encode coordinate columns as Arrow dictionaries: the distinct coordinate values are the dictionary and the strided per-row indices we already compute are the dictionary indices — no broadcast of repeated values. The index type is sized to the dimension length (int8/int16/int32), so a 6-step time chunk uses 1 byte/row and a 721/1440-point lat/lon uses 2. On an ERA5-shaped chunk the coordinate columns shrink ~4.8x (and equality GROUP BY / JOIN keys become small integers). - df.py: _parse_schema declares dimension coordinates as dictionary(index, value) with the value type/metadata preserved; iter_record_batches and dataset_to_record_batch emit DictionaryArrays. - src/lib.rs: keep partition pruning and exact statistics working on the new encoding — DataFusion coerces a coordinate filter to either a Dictionary literal (timestamp) or Cast(col AS value_type) (float), so compare_to_scalar unwraps Dictionary scalars, the pruning matchers strip decode casts, and bound_to_scalar / total_byte_size unwrap the dictionary value type. - tests: pin the encoding contract and the group-by round-trip; update schema and memory-characterization expectations to the (smaller) encoded columns. Co-Authored-By: Claude Opus 4.8 Claude-Session: https://claude.ai/code/session_019VuSeCio99NcME5eubcN3N --- src/lib.rs | 65 +++++++++++++++++++++++++++---- tests/test_cft.py | 11 ++++-- tests/test_df.py | 41 +++++++++++--------- tests/test_dict_coords.py | 74 ++++++++++++++++++++++++++++++++++++ xarray_sql/df.py | 80 +++++++++++++++++++++++++++++++++++---- 5 files changed, 234 insertions(+), 37 deletions(-) create mode 100644 tests/test_dict_coords.py diff --git a/src/lib.rs b/src/lib.rs index 63a5a6b..c09e004 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -93,6 +93,13 @@ impl ScalarBound { /// Compare this bound with a DataFusion ScalarValue. /// Returns None if types are incompatible. fn compare_to_scalar(&self, scalar: &ScalarValue) -> Option { + // Coordinate columns are dictionary-encoded, so DataFusion coerces the + // comparison literal to a `Dictionary` scalar (the column stays a bare + // Column). Unwrap it to the underlying value and compare against that, + // so partition pruning keeps working on dictionary coordinates. + if let ScalarValue::Dictionary(_, value) = scalar { + return self.compare_to_scalar(value); + } match (self, scalar) { // Integer comparisons (ScalarBound::Int64(a), ScalarValue::Int64(Some(b))) => Some(a.cmp(b)), @@ -265,8 +272,11 @@ impl PrunableStreamingTable { right: &Expr, meta: &PartitionMetadata, ) -> bool { - // Try to extract column and literal from either side - let (col_name, scalar, flipped) = match (left, right) { + // Try to extract column and literal from either side. Coordinate + // columns are dictionary-encoded, so a filter may arrive as + // `Cast(col AS value_type) op literal`; strip the (lossless) decode cast + // to reach the column. + let (col_name, scalar, flipped) = match (strip_cast(left), strip_cast(right)) { (Expr::Column(c), Expr::Literal(s, _)) => (c.name.clone(), s, false), (Expr::Literal(s, _), Expr::Column(c)) => (c.name.clone(), s, true), _ => return false, // Not a simple column-literal comparison @@ -348,8 +358,8 @@ impl PrunableStreamingTable { return false; } - // Extract column name - let col_name = match between.expr.as_ref() { + // Extract column name (peeling any dictionary-decode cast) + let col_name = match strip_cast(between.expr.as_ref()) { Expr::Column(c) => c.name.clone(), _ => return false, }; @@ -387,8 +397,8 @@ impl PrunableStreamingTable { return false; } - // Extract column name - let col_name = match in_list.expr.as_ref() { + // Extract column name (peeling any dictionary-decode cast) + let col_name = match strip_cast(in_list.expr.as_ref()) { Expr::Column(c) => c.name.clone(), _ => return false, }; @@ -440,13 +450,33 @@ impl PrunableStreamingTable { /// Check if an expression references a dimension column. fn expr_references_dimension(&self, expr: &Expr) -> bool { - match expr { + match strip_cast(expr) { Expr::Column(c) => self.dimension_columns.contains(&c.name), _ => false, } } } +/// Peel `Cast`/`TryCast` wrappers to reach the inner expression. +/// +/// Dictionary-encoded coordinate columns are decoded with a lossless +/// `CAST(col AS value_type)` before a comparison, so a filter on a coordinate +/// can arrive as `Cast(Column) op literal`. Stripping the cast lets pruning see +/// the underlying column; the cast targets the dictionary's own value type, so +/// it never changes the compared value. If a cast were ever lossy, the bound +/// comparison would simply fail to match and the partition is conservatively +/// kept — pruning never becomes wrong, only (at worst) less effective. +fn strip_cast(expr: &Expr) -> &Expr { + let mut current = expr; + loop { + match current { + Expr::Cast(cast) => current = &cast.expr, + Expr::TryCast(cast) => current = &cast.expr, + _ => return current, + } + } +} + /// Extension trait for partition streams that support column projection. /// /// Implemented by `PyArrowStreamPartition` so that `PrunableStreamingTable` @@ -702,6 +732,12 @@ fn fold_bound(a: &ScalarBound, b: &ScalarBound, keep_min: bool) -> Option Option { + // Coordinate columns are dictionary-encoded; their min/max is the min/max of + // the underlying values, so report a value-type scalar (what DataFusion uses + // for dictionary-column statistics). + if let DataType::Dictionary(_, value_type) = dtype { + return bound_to_scalar(bound, value_type); + } match (bound, dtype) { (ScalarBound::Int64(v), DataType::Int64) => Some(ScalarValue::Int64(Some(*v))), (ScalarBound::Int64(v), DataType::Int32) => { @@ -736,6 +772,19 @@ fn bound_to_scalar(bound: &ScalarBound, dtype: &DataType) -> Option } } +/// Fixed byte width of one value of `dtype`, or `None` if it is variable-width. +/// +/// Dictionary-encoded coordinate columns are sized by their *value* type: the +/// statistic estimates the logical data volume (a conservative upper bound on +/// the encoded column's real footprint), and keeps the reported size stable +/// whether or not a column is dictionary-encoded. +fn fixed_width(dtype: &DataType) -> Option { + match dtype { + DataType::Dictionary(_, value_type) => fixed_width(value_type), + other => other.primitive_width(), + } +} + /// Exact in-memory byte size of `num_rows` rows of `schema`, or `Absent` if any /// column is variable-width (e.g. Utf8) and cannot be sized from the row count /// alone. Our data model is dense fixed-width grids, so this is normally exact. @@ -745,7 +794,7 @@ fn total_byte_size(schema: &Schema, num_rows: &Precision) -> Precision row_width += w, None => return Precision::Absent, } diff --git a/tests/test_cft.py b/tests/test_cft.py index 6560d72..baa4290 100644 --- a/tests/test_cft.py +++ b/tests/test_cft.py @@ -153,18 +153,23 @@ class TestParseSchemaIntegration: def test_noleap_produces_timestamp_us(self, rasm_ds): schema = _parse_schema(rasm_ds[["Tair"]]) time_field = schema.field("time") - assert time_field.type == pa.timestamp("us") + # Coordinate columns are dictionary-encoded; the cftime encoding lives + # in the dictionary's value type (and the field metadata is preserved). + assert pa.types.is_dictionary(time_field.type) + assert time_field.type.value_type == pa.timestamp("us") assert time_field.metadata[b"xarray:calendar"] == b"noleap" def test_360day_produces_int64(self, ds_360day): schema = _parse_schema(ds_360day) time_field = schema.field("time") - assert time_field.type == pa.int64() + assert pa.types.is_dictionary(time_field.type) + assert time_field.type.value_type == pa.int64() assert time_field.metadata[b"xarray:calendar"] == b"360_day" def test_datetime64_unchanged(self): ds = xr.tutorial.open_dataset("air_temperature") schema = _parse_schema(ds) time_field = schema.field("time") - assert pa.types.is_timestamp(time_field.type) + assert pa.types.is_dictionary(time_field.type) + assert pa.types.is_timestamp(time_field.type.value_type) assert time_field.metadata is None # no xarray: metadata for native diff --git a/tests/test_df.py b/tests/test_df.py index 2c81144..de57eba 100644 --- a/tests/test_df.py +++ b/tests/test_df.py @@ -198,6 +198,7 @@ def test_dataset_to_record_batch_matches_pivot(air_small): we sort by the coordinate columns before comparing. """ schema = _parse_schema(air_small) + schema_cols = [f.name for f in schema] dim_cols = [f.name for f in schema if f.name in air_small.dims] blocks = list( block_slices(air_small, chunks={"time": 4, "lat": 3, "lon": 4}) @@ -205,18 +206,20 @@ def test_dataset_to_record_batch_matches_pivot(air_small): for block in blocks: ds_block = air_small.isel(block) - actual_df = ( - dataset_to_record_batch(ds_block, schema) - .to_pandas() - .sort_values(dim_cols) - .reset_index(drop=True) - ) expected_df = ( - pa.RecordBatch.from_pandas(pivot(ds_block), schema=schema) - .to_pandas() + pivot(ds_block)[schema_cols] .sort_values(dim_cols) .reset_index(drop=True) ) + actual_df = dataset_to_record_batch(ds_block, schema).to_pandas()[ + schema_cols + ] + # Coordinate columns are dictionary-encoded, so they come back from + # Arrow as pandas Categorical; decode them to the plain dtype pivot uses + # before comparing values. + for col in dim_cols: + actual_df[col] = actual_df[col].astype(expected_df[col].dtype) + actual_df = actual_df.sort_values(dim_cols).reset_index(drop=True) pd.testing.assert_frame_equal(actual_df, expected_df, check_like=False) @@ -403,19 +406,21 @@ def test_read_xarray_loads_one_chunk_at_a_time(large_ds): peaks.append(cur_peak) for size in sizes: - # Observed range: 1.59–1.83× on macOS, up to ~2.7× on Linux - # (glibc + Arrow allocate more intermediate buffers). - # iter_record_batches holds data-variable arrays (≈1× chunk) while - # yielding sub-batches, plus the current Arrow batch (≈0.65× chunk). - assert chunk_size * 1.3 < size, f"size {size} unexpectedly low" + # iter_record_batches holds the data-variable arrays (≈1× chunk) + # while yielding sub-batches, plus the current Arrow batch. The + # batch's coordinate columns are dictionary-encoded (only the + # distinct values plus small int32 indices), so they add far less + # than a broadcast column would — steady state sits just above 1×. + assert chunk_size * 1.1 < size, f"size {size} unexpectedly low" assert chunk_size * 3.5 > size, f"size {size} unexpectedly high" for peak in peaks: - # Observed range: 1.84–3.28× on macOS, up to ~4.15× on Linux - # (glibc + Arrow hold more intermediate buffers at peak). - # Peak includes data arrays + Arrow batch + temporary coordinate index - # arrays; the first batch of each chunk is highest (Dask compute overhead). - assert chunk_size * 1.5 < peak, f"peak {peak} unexpectedly low" + # Peak includes the data arrays + the current Arrow batch + + # temporary coordinate index arrays; the first batch of each chunk + # is highest (Dask compute overhead). Dictionary-encoded coordinate + # columns keep the batch (and so the peak) smaller than a broadcast + # column would. + assert chunk_size * 1.3 < peak, f"peak {peak} unexpectedly low" assert chunk_size * 5.0 > peak, f"peak {peak} unexpectedly high" assert max(peaks) < large_ds.nbytes diff --git a/tests/test_dict_coords.py b/tests/test_dict_coords.py new file mode 100644 index 0000000..5c2d49c --- /dev/null +++ b/tests/test_dict_coords.py @@ -0,0 +1,74 @@ +"""Coordinate columns are dictionary-encoded. + +A dense grid repeats every coordinate value across the whole partition (a chunk +of shape ``(time, lat, lon)`` carries each latitude ``time × lon`` times). +Encoding coordinate columns as Arrow dictionaries keeps only the distinct values +plus small integer indices, which shrinks the bytes the engine moves and lets +``GROUP BY`` / equality ``JOIN`` on coordinates compare integer keys. These +tests pin that the coordinates are dictionary-encoded end to end and that the +values still round-trip correctly. +""" + +import numpy as np +import pyarrow as pa +import xarray as xr + +from xarray_sql import XarrayContext +from xarray_sql.df import _parse_schema, block_slices, iter_record_batches + + +def _grid() -> xr.Dataset: + return xr.Dataset( + {"v": (("time", "lat", "lon"), np.random.rand(6, 4, 5))}, + coords={ + "time": xr.date_range("2020-01-01", periods=6, freq="D"), + "lat": np.array([-90.0, -30.0, 30.0, 90.0]), + "lon": np.array([0.0, 72.0, 144.0, 216.0, 288.0]), + }, + ) + + +def test_coordinate_fields_are_dictionary_encoded(): + """Dimension coordinates are dictionary-typed; data variables are not.""" + schema = _parse_schema(_grid()) + for dim in ("time", "lat", "lon"): + assert pa.types.is_dictionary(schema.field(dim).type), dim + assert not pa.types.is_dictionary(schema.field("v").type) + + +def test_iter_record_batches_emits_dictionary_coords(): + """A coordinate column arrives as a DictionaryArray with correct values.""" + ds = _grid() + schema = _parse_schema(ds) + block = next(block_slices(ds, chunks={"time": 6})) + batch = next(iter_record_batches(ds.isel(block), schema, batch_size=1024)) + + lat = batch.column(batch.schema.names.index("lat")) + assert isinstance(lat, pa.DictionaryArray) + # The dictionary holds only the distinct latitudes; decoding reproduces the + # per-row values in row-major order. + assert lat.dictionary.to_pylist() == [-90.0, -30.0, 30.0, 90.0] + decoded = lat.to_numpy(zero_copy_only=False) + assert decoded.shape == (batch.num_rows,) + # First lon×... rows share lat=-90 (lat varies slower than lon in C order). + assert decoded[0] == -90.0 + + # Data variables stay plain (not dictionary-encoded). + v = batch.column(batch.schema.names.index("v")) + assert not isinstance(v, pa.DictionaryArray) + + +def test_groupby_coordinate_roundtrips_through_dictionary(): + """GROUP BY on a dictionary coordinate returns the same numbers as xarray.""" + ds = _grid() + ctx = XarrayContext() + ctx.from_dataset("grid", ds, chunks={"time": 3}) + + got = ctx.sql( + "SELECT lat, AVG(v) AS m FROM grid GROUP BY lat ORDER BY lat" + ).to_dataset(dims=["lat"]) + ref = ds["v"].mean(["time", "lon"]) + + xr.testing.assert_allclose( + got.m, ref.reindex_like(got.m), rtol=1e-6, atol=1e-9 + ) diff --git a/xarray_sql/df.py b/xarray_sql/df.py index ab80056..0d091ea 100644 --- a/xarray_sql/df.py +++ b/xarray_sql/df.py @@ -271,15 +271,29 @@ def dataset_to_record_batch( for field in schema: name = field.name if name in ds.coords and name in ds.dims: - # Broadcast 1-D coordinate to the full N-D partition shape, then ravel. axis = dim_names.index(name) coord = ds.coords[name].values if cft.is_cftime(coord): coord = cft.convert_for_field(coord, field) reshape = [1] * len(shape) reshape[axis] = coord.shape[0] - arr = np.broadcast_to(coord.reshape(reshape), shape).ravel() - arrays.append(pa.array(arr, type=field.type)) + if pa.types.is_dictionary(field.type): + # Dictionary-encode: the distinct coordinate values are the + # dictionary, and the per-row indices are the coordinate's own + # axis index broadcast across the partition — no dense array of + # repeated values is ever built. + idx = np.broadcast_to( + np.arange(coord.shape[0]).reshape(reshape), shape + ).ravel() + dictionary = pa.array(coord, type=field.type.value_type) + indices = pa.array(idx, type=field.type.index_type) + arrays.append( + pa.DictionaryArray.from_arrays(indices, dictionary) + ) + else: + # Broadcast 1-D coordinate to the full partition shape, ravel. + arr = np.broadcast_to(coord.reshape(reshape), shape).ravel() + arrays.append(pa.array(arr, type=field.type)) else: # Data variable: ravel to 1-D (zero-copy for C-contiguous arrays). raw = ds[name].values.ravel() @@ -371,9 +385,22 @@ def iter_record_batches( if name in ds.coords and name in ds.dims: k = dim_names.index(name) coord_idx = (row_idx // strides[k]) % shape[k] - arrays.append( - pa.array(coord_values[name][coord_idx], type=field.type) - ) + if pa.types.is_dictionary(field.type): + # Emit the coordinate as a dictionary array: the distinct + # coordinate values are the dictionary, and the strided + # per-row indices we just computed are exactly the + # dictionary indices — no broadcast of repeated values. + dictionary = pa.array( + coord_values[name], type=field.type.value_type + ) + indices = pa.array(coord_idx, type=field.type.index_type) + arrays.append( + pa.DictionaryArray.from_arrays(indices, dictionary) + ) + else: + arrays.append( + pa.array(coord_values[name][coord_idx], type=field.type) + ) else: arrays.append( pa.array( @@ -386,6 +413,40 @@ def iter_record_batches( yield pa.RecordBatch.from_arrays(arrays, schema=schema) +def _coord_index_type(n_values: int) -> pa.DataType: + """Smallest signed Arrow int type that indexes ``n_values`` dictionary keys. + + A coordinate's dictionary indices range over its dimension length, so the + index width is chosen from the full dimension size (a safe upper bound on any + partition's cardinality). Narrower indices mean fewer bytes moved and cheaper + group/join hashing: a 721-point latitude or 1440-point longitude fits int16, + while a multi-million-step time axis needs int32. + """ + if n_values <= np.iinfo(np.int8).max: + return pa.int8() + if n_values <= np.iinfo(np.int16).max: + return pa.int16() + return pa.int32() + + +def _as_dictionary_field(field: pa.Field, n_values: int) -> pa.Field: + """Wrap a coordinate field's value type in a dictionary encoding. + + Coordinate columns repeat each value across the whole grid — a chunk of + shape ``(time, lat, lon)`` carries each latitude ``time × lon`` times. A + dictionary array stores only the distinct values plus small integer indices, + so this shrinks the bytes the query engine moves for coordinate columns and + lets ``GROUP BY`` / equality ``JOIN`` on coordinates compare integer keys + instead of rehashing repeated floats. The index type is sized to the + dimension length (see :func:`_coord_index_type`). The dictionary's value type + and the field's name, nullability, and metadata (e.g. cftime units/calendar) + are preserved, so downstream schema handling is unchanged apart from the + encoding. + """ + index_type = _coord_index_type(n_values) + return field.with_type(pa.dictionary(index_type, field.type)) + + def _parse_schema(ds: xr.Dataset) -> pa.Schema: """Extracts a `pa.Schema` from the Dataset, treating dims and data_vars as columns. @@ -413,10 +474,13 @@ def _parse_schema(ds: xr.Dataset) -> pa.Schema: if coord_name in ds.dims: if cft.is_cftime_index(ds, coord_name): units, calendar = cft.encoding(ds, coord_name) - columns.append(cft.arrow_field(coord_name, units, calendar)) + field = cft.arrow_field(coord_name, units, calendar) else: pa_type = pa.from_numpy_dtype(coord_var.dtype) - columns.append(pa.field(coord_name, pa_type)) + field = pa.field(coord_name, pa_type) + # Dictionary-encode the coordinate column (see _as_dictionary_field), + # sizing the index type to the dimension length. + columns.append(_as_dictionary_field(field, ds.sizes[coord_name])) for var_name, var in ds.data_vars.items(): # Data variables are virtually never cftime, but check dtype as a From f08d9605bb7a6bcc810711cbcb2980780319e22d Mon Sep 17 00:00:00 2001 From: Claude Date: Thu, 2 Jul 2026 17:07:18 +0000 Subject: [PATCH 2/4] Add int64 fallback for dictionary index width MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Cap coordinate dictionary indices at the correct width for any cardinality: the selection now tiers int8 → int16 → int32 → int64 at exact `MAX + 1` boundaries (indices run 0..n-1, so a signed max M holds n = M+1). The int64 fallback keeps astronomically large coordinate axes representable instead of silently overflowing a 32-bit index. Follows the adaptive-key-width note in jayendra13's zarr-datafusion. Co-Authored-By: Claude Opus 4.8 Claude-Session: https://claude.ai/code/session_019VuSeCio99NcME5eubcN3N --- tests/test_dict_coords.py | 24 +++++++++++++++++++++++- xarray_sql/df.py | 13 ++++++++++--- 2 files changed, 33 insertions(+), 4 deletions(-) diff --git a/tests/test_dict_coords.py b/tests/test_dict_coords.py index 5c2d49c..7523056 100644 --- a/tests/test_dict_coords.py +++ b/tests/test_dict_coords.py @@ -11,10 +11,32 @@ import numpy as np import pyarrow as pa +import pytest import xarray as xr from xarray_sql import XarrayContext -from xarray_sql.df import _parse_schema, block_slices, iter_record_batches +from xarray_sql.df import ( + _coord_index_type, + _parse_schema, + block_slices, + iter_record_batches, +) + + +@pytest.mark.parametrize( + "n_values, expected", + [ + (1, pa.int8()), + (128, pa.int8()), # int8 holds indices 0..127 + (129, pa.int16()), + (32768, pa.int16()), # int16 holds indices 0..32767 + (32769, pa.int32()), + (2**31, pa.int32()), # int32 holds indices 0..2**31-1 + (2**31 + 1, pa.int64()), # fallback keeps huge axes representable + ], +) +def test_coord_index_type_boundaries(n_values, expected): + assert _coord_index_type(n_values) == expected def _grid() -> xr.Dataset: diff --git a/xarray_sql/df.py b/xarray_sql/df.py index 0d091ea..dd3dd49 100644 --- a/xarray_sql/df.py +++ b/xarray_sql/df.py @@ -421,12 +421,19 @@ def _coord_index_type(n_values: int) -> pa.DataType: partition's cardinality). Narrower indices mean fewer bytes moved and cheaper group/join hashing: a 721-point latitude or 1440-point longitude fits int16, while a multi-million-step time axis needs int32. + + Indices run ``0 .. n_values - 1``, so a signed type with maximum ``M`` holds + a cardinality of ``M + 1``. The ``int64`` fallback keeps astronomically large + coordinate axes representable rather than silently overflowing a 32-bit index + (cf. the adaptive-key-width note in jayendra13's zarr-datafusion). """ - if n_values <= np.iinfo(np.int8).max: + if n_values <= np.iinfo(np.int8).max + 1: return pa.int8() - if n_values <= np.iinfo(np.int16).max: + if n_values <= np.iinfo(np.int16).max + 1: return pa.int16() - return pa.int32() + if n_values <= np.iinfo(np.int32).max + 1: + return pa.int32() + return pa.int64() def _as_dictionary_field(field: pa.Field, n_values: int) -> pa.Field: From 13823c975315b5b6ef8beec4d247c9b3e4ccf5b5 Mon Sep 17 00:00:00 2001 From: Claude Date: Thu, 2 Jul 2026 20:43:09 +0000 Subject: [PATCH 3/4] Only dictionary-encode coordinates where it is safe and worthwhile MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit A narrow dictionary key (int8/int16) can overflow under DataFusion streaming aggregation: it concatenates the per-batch coordinate dictionaries across the aggregate and does not always unify them (arrow merges dictionary values on a size heuristic, not a guarantee), so the combined index for an unchunked coordinate repeated across N partitions can reach card × N and blow past the key type — "Dictionary key bigger than the key type" (reported by @ghostiee-11 on a float32 GROUP BY lat, lon). Make the encoding overflow-proof and only apply it where it pays off: - `_coord_index_type` floors the key at int32 (~2.1B combined entries covers any realistic grid; int64 backstops the rest). int8/int16 are gone. - `_as_dictionary_field` encodes a coordinate only when the int32 key is strictly narrower than the value type: 8-byte float64/int64/timestamp coordinates (a safe 2x) and variable-width strings, leaving 4-byte float32/int32 coordinates dense (where a dictionary is pure overhead and the only way to win would be an overflow-prone narrow key). - `total_byte_size` reports Inexact when any column is dictionary-encoded, since it sizes those by the value type (a safe upper bound) not the narrower index — honest now that the index is smaller (addresses the "not exact" note). Regression test: float32 GROUP BY lat, lon over 100 partitions matches xarray and no longer overflows. Stats byte-size expectation updated to Inexact. Co-Authored-By: Claude Opus 4.8 Claude-Session: https://claude.ai/code/session_019VuSeCio99NcME5eubcN3N --- src/lib.rs | 28 ++++++++++---- tests/test_dict_coords.py | 79 +++++++++++++++++++++++++++++++++------ tests/test_stats.py | 6 ++- xarray_sql/df.py | 56 ++++++++++++++------------- 4 files changed, 122 insertions(+), 47 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index c09e004..0bce6c1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -774,10 +774,10 @@ fn bound_to_scalar(bound: &ScalarBound, dtype: &DataType) -> Option /// Fixed byte width of one value of `dtype`, or `None` if it is variable-width. /// -/// Dictionary-encoded coordinate columns are sized by their *value* type: the -/// statistic estimates the logical data volume (a conservative upper bound on -/// the encoded column's real footprint), and keeps the reported size stable -/// whether or not a column is dictionary-encoded. +/// A dictionary-encoded column is sized by its *value* type — a safe upper bound +/// on the encoded column's real footprint (the actual indices are narrower). We +/// deliberately over- rather than under-count so a memory-based rule never +/// under-provisions; the caller marks the total `Inexact` to reflect it. fn fixed_width(dtype: &DataType) -> Option { match dtype { DataType::Dictionary(_, value_type) => fixed_width(value_type), @@ -785,21 +785,33 @@ fn fixed_width(dtype: &DataType) -> Option { } } -/// Exact in-memory byte size of `num_rows` rows of `schema`, or `Absent` if any -/// column is variable-width (e.g. Utf8) and cannot be sized from the row count -/// alone. Our data model is dense fixed-width grids, so this is normally exact. +/// In-memory byte size of `num_rows` rows of `schema`. +/// +/// `Absent` if any column is variable-width (e.g. Utf8). `Exact` for a plain +/// fixed-width grid; `Inexact` when any column is dictionary-encoded, since we +/// size those by the value type (a safe upper bound) rather than the narrower +/// index — an honest label now that coordinates can be dictionary-encoded. fn total_byte_size(schema: &Schema, num_rows: &Precision) -> Precision { let Precision::Exact(rows) = num_rows else { return Precision::Absent; }; let mut row_width = 0usize; + let mut any_dictionary = false; for field in schema.fields() { + if matches!(field.data_type(), DataType::Dictionary(_, _)) { + any_dictionary = true; + } match fixed_width(field.data_type()) { Some(w) => row_width += w, None => return Precision::Absent, } } - Precision::Exact(rows.saturating_mul(row_width)) + let total = rows.saturating_mul(row_width); + if any_dictionary { + Precision::Inexact(total) + } else { + Precision::Exact(total) + } } /// Build `Statistics` for a scan over the given partitions. diff --git a/tests/test_dict_coords.py b/tests/test_dict_coords.py index 7523056..4687183 100644 --- a/tests/test_dict_coords.py +++ b/tests/test_dict_coords.py @@ -1,12 +1,15 @@ -"""Coordinate columns are dictionary-encoded. +"""Coordinate columns are dictionary-encoded when it is worthwhile and safe. A dense grid repeats every coordinate value across the whole partition (a chunk of shape ``(time, lat, lon)`` carries each latitude ``time × lon`` times). -Encoding coordinate columns as Arrow dictionaries keeps only the distinct values -plus small integer indices, which shrinks the bytes the engine moves and lets -``GROUP BY`` / equality ``JOIN`` on coordinates compare integer keys. These -tests pin that the coordinates are dictionary-encoded end to end and that the -values still round-trip correctly. +Encoding a coordinate as an Arrow dictionary keeps only the distinct values plus +an int32 index, shrinking the bytes the engine moves. We encode only when the +int32 key is strictly narrower than the value (8-byte float64/int64/timestamp +coordinates, and variable-width strings) and leave 4-byte float32/int32 +coordinates dense — a narrower key would be needed to win there but overflows +under DataFusion's cross-batch dictionary concatenation. These tests pin which +coordinates get encoded, that the values round-trip correctly, and that the +overflow case no longer crashes. """ import numpy as np @@ -26,16 +29,15 @@ @pytest.mark.parametrize( "n_values, expected", [ - (1, pa.int8()), - (128, pa.int8()), # int8 holds indices 0..127 - (129, pa.int16()), - (32768, pa.int16()), # int16 holds indices 0..32767 - (32769, pa.int32()), + (1, pa.int32()), (2**31, pa.int32()), # int32 holds indices 0..2**31-1 (2**31 + 1, pa.int64()), # fallback keeps huge axes representable ], ) def test_coord_index_type_boundaries(n_values, expected): + # int32 is the narrowest key we use: a narrower one can overflow when + # DataFusion concatenates per-batch dictionaries across a streaming + # aggregate (see test_float32_groupby_many_partitions_no_overflow). assert _coord_index_type(n_values) == expected @@ -51,13 +53,66 @@ def _grid() -> xr.Dataset: def test_coordinate_fields_are_dictionary_encoded(): - """Dimension coordinates are dictionary-typed; data variables are not.""" + """8-byte dimension coordinates are dictionary-typed; data variables are not. + + ``_grid`` uses float64/datetime coordinates (8 bytes), which are wider than + the int32 key, so they are encoded. + """ schema = _parse_schema(_grid()) for dim in ("time", "lat", "lon"): assert pa.types.is_dictionary(schema.field(dim).type), dim assert not pa.types.is_dictionary(schema.field("v").type) +def test_float32_coordinates_stay_dense(): + """4-byte coordinates are left dense: an int32 key is no narrower than the + value, so a dictionary would be pure overhead (and a narrower key is unsafe). + """ + ds = xr.Dataset( + {"v": (("lat", "lon"), np.zeros((3, 4), dtype="float32"))}, + coords={ + "lat": np.array([-90.0, 0.0, 90.0], dtype="float32"), + "lon": np.arange(4, dtype="int32"), + }, + ) + schema = _parse_schema(ds) + assert schema.field("lat").type == pa.float32() + assert schema.field("lon").type == pa.int32() + + +def test_float32_groupby_many_partitions_no_overflow(): + """Regression: GROUP BY on float32 coordinates over many partitions must not + overflow a dictionary key. + + With narrow keys, DataFusion concatenating the per-partition coordinate + dictionaries across the aggregate overflowed the key type ("Dictionary key + bigger than the key type"). float32 coordinates are now dense, so there is no + coordinate dictionary to concatenate, and the aggregate matches xarray. + """ + ds = xr.Dataset( + { + "air": ( + ("time", "lat", "lon"), + np.random.rand(600, 20, 30).astype("float32"), + ) + }, + coords={ + "time": np.arange(600), + "lat": np.linspace(-90, 90, 20).astype("float32"), + "lon": np.linspace(0, 359, 30).astype("float32"), + }, + ) + ctx = XarrayContext() + ctx.from_dataset("air", ds, chunks={"time": 6}) # 100 partitions + got = ctx.sql( + "SELECT lat, lon, AVG(air) AS m FROM air GROUP BY lat, lon" + ).to_dataset(dims=["lat", "lon"]) + ref = ds["air"].mean("time") + xr.testing.assert_allclose( + got.m, ref.reindex_like(got.m), rtol=1e-5, atol=1e-6 + ) + + def test_iter_record_batches_emits_dictionary_coords(): """A coordinate column arrives as a DictionaryArray with correct values.""" ds = _grid() diff --git a/tests/test_stats.py b/tests/test_stats.py index a1f199e..e51fe49 100644 --- a/tests/test_stats.py +++ b/tests/test_stats.py @@ -48,8 +48,10 @@ def test_exact_byte_size_in_scan_statistics(): ctx = XarrayContext() ctx.from_dataset("air", ds, chunks={"time": 50}) plan = _explain(ctx, "SELECT lat, lon, air FROM air") - # 2000 rows x (lat int64 + lon int64 + air float64) = 2000 x 24 bytes. - assert f"Bytes=Exact({100 * 4 * 5 * 24})" in plan + # 2000 rows x (lat int64 + lon int64 + air float64) = 2000 x 24 bytes, sized + # by value type. lat/lon are dictionary-encoded (int32 keys), so the value + # sizing is a safe upper bound and the total is reported Inexact. + assert f"Bytes=Inexact({100 * 4 * 5 * 24})" in plan def test_dimension_column_min_max_in_scan_statistics(): diff --git a/xarray_sql/df.py b/xarray_sql/df.py index dd3dd49..df6a227 100644 --- a/xarray_sql/df.py +++ b/xarray_sql/df.py @@ -414,43 +414,49 @@ def iter_record_batches( def _coord_index_type(n_values: int) -> pa.DataType: - """Smallest signed Arrow int type that indexes ``n_values`` dictionary keys. - - A coordinate's dictionary indices range over its dimension length, so the - index width is chosen from the full dimension size (a safe upper bound on any - partition's cardinality). Narrower indices mean fewer bytes moved and cheaper - group/join hashing: a 721-point latitude or 1440-point longitude fits int16, - while a multi-million-step time axis needs int32. - - Indices run ``0 .. n_values - 1``, so a signed type with maximum ``M`` holds - a cardinality of ``M + 1``. The ``int64`` fallback keeps astronomically large - coordinate axes representable rather than silently overflowing a 32-bit index - (cf. the adaptive-key-width note in jayendra13's zarr-datafusion). + """Signed Arrow int type for a coordinate dictionary of ``n_values`` keys. + + The index must survive DataFusion *concatenating* per-batch dictionaries + across a scan. Those concatenations are not always unified (arrow merges + dictionary values on a size heuristic, not a guarantee), so the *combined* + dictionary an operator sees can be many times a single partition's + cardinality — an unchunked coordinate repeated across N partitions can reach + ``card × N``. A narrow key (int8/int16) therefore overflows under streaming + aggregation ("Dictionary key bigger than the key type"), so we do not use + one. ``int32`` covers ~2.1B combined entries (any realistic grid), and + ``int64`` backstops the astronomically-large case. Indices run + ``0 .. n_values - 1``, so a signed max ``M`` holds a cardinality of ``M + 1``. """ - if n_values <= np.iinfo(np.int8).max + 1: - return pa.int8() - if n_values <= np.iinfo(np.int16).max + 1: - return pa.int16() if n_values <= np.iinfo(np.int32).max + 1: return pa.int32() return pa.int64() def _as_dictionary_field(field: pa.Field, n_values: int) -> pa.Field: - """Wrap a coordinate field's value type in a dictionary encoding. + """Dictionary-encode a coordinate field when it is worthwhile and safe. Coordinate columns repeat each value across the whole grid — a chunk of shape ``(time, lat, lon)`` carries each latitude ``time × lon`` times. A - dictionary array stores only the distinct values plus small integer indices, - so this shrinks the bytes the query engine moves for coordinate columns and - lets ``GROUP BY`` / equality ``JOIN`` on coordinates compare integer keys - instead of rehashing repeated floats. The index type is sized to the - dimension length (see :func:`_coord_index_type`). The dictionary's value type - and the field's name, nullability, and metadata (e.g. cftime units/calendar) - are preserved, so downstream schema handling is unchanged apart from the - encoding. + dictionary stores only the distinct values plus integer indices, shrinking + the bytes the engine moves and letting ``GROUP BY`` / equality ``JOIN`` + compare integer keys instead of rehashing repeated floats. + + We only encode when the index type (see :func:`_coord_index_type`) is + *strictly narrower* than the value type. That keeps 8-byte coordinates + (``float64``, ``int64``, timestamps) — a 2× win with an overflow-safe int32 + key — while leaving 4-byte ``float32`` / ``int32`` coordinates dense, where a + dictionary would be pure overhead and a narrower (overflow-prone) key would + be the only way to win. Variable-width value types (e.g. strings) are always + encoded — the classic dictionary win. The field's name, nullability, and + metadata (e.g. cftime units/calendar) are preserved. """ index_type = _coord_index_type(n_values) + try: + value_bit_width = field.type.bit_width + except (ValueError, NotImplementedError): + value_bit_width = None # variable-width (e.g. string): always encode + if value_bit_width is not None and index_type.bit_width >= value_bit_width: + return field return field.with_type(pa.dictionary(index_type, field.type)) From 44087edf11f8b03badd65084b0328e996eb67c10 Mon Sep 17 00:00:00 2001 From: Claude Date: Thu, 2 Jul 2026 20:49:50 +0000 Subject: [PATCH 4/4] Decode dictionary columns to plain values in to_pandas Coordinate columns are dictionary-encoded internally; DataFusion surfaces those as pandas Categorical, which sorts by category order and trips dtype checks. Decode them back to their value dtype at the to_pandas boundary so callers see the same plain columns as before the encoding. Co-Authored-By: Claude Opus 4.8 Claude-Session: https://claude.ai/code/session_019VuSeCio99NcME5eubcN3N --- tests/test_ds.py | 18 ++++++++++++++++-- xarray_sql/ds.py | 15 +++++++++++++-- 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/tests/test_ds.py b/tests/test_ds.py index aa3deb2..43627d8 100644 --- a/tests/test_ds.py +++ b/tests/test_ds.py @@ -37,14 +37,28 @@ def test_ctx_sql_returns_xarray_dataframe(air_dataset_small): assert isinstance(result, XarrayDataFrame) -def test_to_pandas_unchanged_behavior(air_dataset_small): - """Wrapped ``.to_pandas()`` is bit-for-bit equal to the un-wrapped path.""" +def test_to_pandas_decodes_dictionary_columns(air_dataset_small): + """Wrapped ``.to_pandas()`` surfaces coordinates as plain values. + + Coordinate columns are dictionary-encoded internally, which the raw + DataFusion ``.to_pandas()`` returns as pandas ``Categorical``. The wrapper + decodes those back to their value dtype so callers see the same plain + columns as before the encoding — no ``Categorical`` leaks, and the values + match the raw path once its categoricals are decoded. + """ from datafusion import SessionContext ctx = XarrayContext() ctx.from_dataset("air", air_dataset_small) wrapped = ctx.sql("SELECT * FROM air LIMIT 7").to_pandas() raw = SessionContext.sql(ctx, "SELECT * FROM air LIMIT 7").to_pandas() + + assert not any( + isinstance(wrapped[c].dtype, pd.CategoricalDtype) for c in wrapped + ) + for c in raw.columns: + if isinstance(raw[c].dtype, pd.CategoricalDtype): + raw[c] = raw[c].astype(raw[c].cat.categories.dtype) pd.testing.assert_frame_equal(wrapped, raw) diff --git a/xarray_sql/ds.py b/xarray_sql/ds.py index 5dfdf42..e8b3a09 100644 --- a/xarray_sql/ds.py +++ b/xarray_sql/ds.py @@ -742,8 +742,19 @@ def __init__( object.__setattr__(self, "_templates", dict(templates or {})) def to_pandas(self) -> pd.DataFrame: - """Materialize the result as a ``pd.DataFrame`` (DataFusion API).""" - return self._inner.to_pandas() + """Materialize the result as a ``pd.DataFrame`` (DataFusion API). + + Coordinate columns are dictionary-encoded internally, which DataFusion + surfaces as pandas ``Categorical``. Decode those back to their value + dtype so callers see the same plain columns as before the encoding — + results sort and compare by value, not by category order. + """ + df = self._inner.to_pandas() + for name in df.columns: + dtype = df[name].dtype + if isinstance(dtype, pd.CategoricalDtype): + df[name] = df[name].astype(dtype.categories.dtype) + return df def to_dataset( self,