Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 73 additions & 12 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,13 @@ impl ScalarBound {
/// Compare this bound with a DataFusion ScalarValue.
/// Returns None if types are incompatible.
fn compare_to_scalar(&self, scalar: &ScalarValue) -> Option<std::cmp::Ordering> {
// Coordinate columns are dictionary-encoded, so DataFusion coerces the
// comparison literal to a `Dictionary` scalar (the column stays a bare
// Column). Unwrap it to the underlying value and compare against that,
// so partition pruning keeps working on dictionary coordinates.
if let ScalarValue::Dictionary(_, value) = scalar {
return self.compare_to_scalar(value);
}
match (self, scalar) {
// Integer comparisons
(ScalarBound::Int64(a), ScalarValue::Int64(Some(b))) => Some(a.cmp(b)),
Expand Down Expand Up @@ -265,8 +272,11 @@ impl PrunableStreamingTable {
right: &Expr,
meta: &PartitionMetadata,
) -> bool {
// Try to extract column and literal from either side
let (col_name, scalar, flipped) = match (left, right) {
// Try to extract column and literal from either side. Coordinate
// columns are dictionary-encoded, so a filter may arrive as
// `Cast(col AS value_type) op literal`; strip the (lossless) decode cast
// to reach the column.
let (col_name, scalar, flipped) = match (strip_cast(left), strip_cast(right)) {
(Expr::Column(c), Expr::Literal(s, _)) => (c.name.clone(), s, false),
(Expr::Literal(s, _), Expr::Column(c)) => (c.name.clone(), s, true),
_ => return false, // Not a simple column-literal comparison
Expand Down Expand Up @@ -348,8 +358,8 @@ impl PrunableStreamingTable {
return false;
}

// Extract column name
let col_name = match between.expr.as_ref() {
// Extract column name (peeling any dictionary-decode cast)
let col_name = match strip_cast(between.expr.as_ref()) {
Expr::Column(c) => c.name.clone(),
_ => return false,
};
Expand Down Expand Up @@ -387,8 +397,8 @@ impl PrunableStreamingTable {
return false;
}

// Extract column name
let col_name = match in_list.expr.as_ref() {
// Extract column name (peeling any dictionary-decode cast)
let col_name = match strip_cast(in_list.expr.as_ref()) {
Expr::Column(c) => c.name.clone(),
_ => return false,
};
Expand Down Expand Up @@ -440,13 +450,33 @@ impl PrunableStreamingTable {

/// Check if an expression references a dimension column.
fn expr_references_dimension(&self, expr: &Expr) -> bool {
match expr {
match strip_cast(expr) {
Expr::Column(c) => self.dimension_columns.contains(&c.name),
_ => false,
}
}
}

/// Peel `Cast`/`TryCast` wrappers to reach the inner expression.
///
/// Dictionary-encoded coordinate columns are decoded with a lossless
/// `CAST(col AS value_type)` before a comparison, so a filter on a coordinate
/// can arrive as `Cast(Column) op literal`. Stripping the cast lets pruning see
/// the underlying column; the cast targets the dictionary's own value type, so
/// it never changes the compared value. If a cast were ever lossy, the bound
/// comparison would simply fail to match and the partition is conservatively
/// kept — pruning never becomes wrong, only (at worst) less effective.
fn strip_cast(expr: &Expr) -> &Expr {
let mut current = expr;
loop {
match current {
Expr::Cast(cast) => current = &cast.expr,
Expr::TryCast(cast) => current = &cast.expr,
_ => return current,
}
}
}

/// Extension trait for partition streams that support column projection.
///
/// Implemented by `PyArrowStreamPartition` so that `PrunableStreamingTable`
Expand Down Expand Up @@ -702,6 +732,12 @@ fn fold_bound(a: &ScalarBound, b: &ScalarBound, keep_min: bool) -> Option<Scalar
/// unit we can't scale exactly), in which case the column is left without
/// min/max rather than risk a wrong value.
fn bound_to_scalar(bound: &ScalarBound, dtype: &DataType) -> Option<ScalarValue> {
// Coordinate columns are dictionary-encoded; their min/max is the min/max of
// the underlying values, so report a value-type scalar (what DataFusion uses
// for dictionary-column statistics).
if let DataType::Dictionary(_, value_type) = dtype {
return bound_to_scalar(bound, value_type);
}
match (bound, dtype) {
(ScalarBound::Int64(v), DataType::Int64) => Some(ScalarValue::Int64(Some(*v))),
(ScalarBound::Int64(v), DataType::Int32) => {
Expand Down Expand Up @@ -736,21 +772,46 @@ fn bound_to_scalar(bound: &ScalarBound, dtype: &DataType) -> Option<ScalarValue>
}
}

/// Exact in-memory byte size of `num_rows` rows of `schema`, or `Absent` if any
/// column is variable-width (e.g. Utf8) and cannot be sized from the row count
/// alone. Our data model is dense fixed-width grids, so this is normally exact.
/// Fixed byte width of one value of `dtype`, or `None` if it is variable-width.
///
/// A dictionary-encoded column is sized by its *value* type — a safe upper bound
/// on the encoded column's real footprint (the actual indices are narrower). We
/// deliberately over- rather than under-count so a memory-based rule never
/// under-provisions; the caller marks the total `Inexact` to reflect it.
fn fixed_width(dtype: &DataType) -> Option<usize> {
match dtype {
DataType::Dictionary(_, value_type) => fixed_width(value_type),
other => other.primitive_width(),
}
}

/// In-memory byte size of `num_rows` rows of `schema`.
///
/// `Absent` if any column is variable-width (e.g. Utf8). `Exact` for a plain
/// fixed-width grid; `Inexact` when any column is dictionary-encoded, since we
/// size those by the value type (a safe upper bound) rather than the narrower
/// index — an honest label now that coordinates can be dictionary-encoded.
fn total_byte_size(schema: &Schema, num_rows: &Precision<usize>) -> Precision<usize> {
let Precision::Exact(rows) = num_rows else {
return Precision::Absent;
};
let mut row_width = 0usize;
let mut any_dictionary = false;
for field in schema.fields() {
match field.data_type().primitive_width() {
if matches!(field.data_type(), DataType::Dictionary(_, _)) {
any_dictionary = true;
}
match fixed_width(field.data_type()) {
Some(w) => row_width += w,
None => return Precision::Absent,
}
}
Precision::Exact(rows.saturating_mul(row_width))
let total = rows.saturating_mul(row_width);
if any_dictionary {
Precision::Inexact(total)
} else {
Precision::Exact(total)
}
}

/// Build `Statistics` for a scan over the given partitions.
Expand Down
11 changes: 8 additions & 3 deletions tests/test_cft.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,18 +153,23 @@ class TestParseSchemaIntegration:
def test_noleap_produces_timestamp_us(self, rasm_ds):
schema = _parse_schema(rasm_ds[["Tair"]])
time_field = schema.field("time")
assert time_field.type == pa.timestamp("us")
# Coordinate columns are dictionary-encoded; the cftime encoding lives
# in the dictionary's value type (and the field metadata is preserved).
assert pa.types.is_dictionary(time_field.type)
assert time_field.type.value_type == pa.timestamp("us")
assert time_field.metadata[b"xarray:calendar"] == b"noleap"

def test_360day_produces_int64(self, ds_360day):
schema = _parse_schema(ds_360day)
time_field = schema.field("time")
assert time_field.type == pa.int64()
assert pa.types.is_dictionary(time_field.type)
assert time_field.type.value_type == pa.int64()
assert time_field.metadata[b"xarray:calendar"] == b"360_day"

def test_datetime64_unchanged(self):
ds = xr.tutorial.open_dataset("air_temperature")
schema = _parse_schema(ds)
time_field = schema.field("time")
assert pa.types.is_timestamp(time_field.type)
assert pa.types.is_dictionary(time_field.type)
assert pa.types.is_timestamp(time_field.type.value_type)
assert time_field.metadata is None # no xarray: metadata for native
41 changes: 23 additions & 18 deletions tests/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,25 +198,28 @@ def test_dataset_to_record_batch_matches_pivot(air_small):
we sort by the coordinate columns before comparing.
"""
schema = _parse_schema(air_small)
schema_cols = [f.name for f in schema]
dim_cols = [f.name for f in schema if f.name in air_small.dims]
blocks = list(
block_slices(air_small, chunks={"time": 4, "lat": 3, "lon": 4})
)

for block in blocks:
ds_block = air_small.isel(block)
actual_df = (
dataset_to_record_batch(ds_block, schema)
.to_pandas()
.sort_values(dim_cols)
.reset_index(drop=True)
)
expected_df = (
pa.RecordBatch.from_pandas(pivot(ds_block), schema=schema)
.to_pandas()
pivot(ds_block)[schema_cols]
.sort_values(dim_cols)
.reset_index(drop=True)
)
actual_df = dataset_to_record_batch(ds_block, schema).to_pandas()[
schema_cols
]
# Coordinate columns are dictionary-encoded, so they come back from
# Arrow as pandas Categorical; decode them to the plain dtype pivot uses
# before comparing values.
for col in dim_cols:
actual_df[col] = actual_df[col].astype(expected_df[col].dtype)
actual_df = actual_df.sort_values(dim_cols).reset_index(drop=True)

pd.testing.assert_frame_equal(actual_df, expected_df, check_like=False)

Expand Down Expand Up @@ -403,19 +406,21 @@ def test_read_xarray_loads_one_chunk_at_a_time(large_ds):
peaks.append(cur_peak)

for size in sizes:
# Observed range: 1.59–1.83× on macOS, up to ~2.7× on Linux
# (glibc + Arrow allocate more intermediate buffers).
# iter_record_batches holds data-variable arrays (≈1× chunk) while
# yielding sub-batches, plus the current Arrow batch (≈0.65× chunk).
assert chunk_size * 1.3 < size, f"size {size} unexpectedly low"
# iter_record_batches holds the data-variable arrays (≈1× chunk)
# while yielding sub-batches, plus the current Arrow batch. The
# batch's coordinate columns are dictionary-encoded (only the
# distinct values plus small int32 indices), so they add far less
# than a broadcast column would — steady state sits just above 1×.
assert chunk_size * 1.1 < size, f"size {size} unexpectedly low"
assert chunk_size * 3.5 > size, f"size {size} unexpectedly high"

for peak in peaks:
# Observed range: 1.84–3.28× on macOS, up to ~4.15× on Linux
# (glibc + Arrow hold more intermediate buffers at peak).
# Peak includes data arrays + Arrow batch + temporary coordinate index
# arrays; the first batch of each chunk is highest (Dask compute overhead).
assert chunk_size * 1.5 < peak, f"peak {peak} unexpectedly low"
# Peak includes the data arrays + the current Arrow batch +
# temporary coordinate index arrays; the first batch of each chunk
# is highest (Dask compute overhead). Dictionary-encoded coordinate
# columns keep the batch (and so the peak) smaller than a broadcast
# column would.
assert chunk_size * 1.3 < peak, f"peak {peak} unexpectedly low"
assert chunk_size * 5.0 > peak, f"peak {peak} unexpectedly high"

assert max(peaks) < large_ds.nbytes
Expand Down
151 changes: 151 additions & 0 deletions tests/test_dict_coords.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
"""Coordinate columns are dictionary-encoded when it is worthwhile and safe.

A dense grid repeats every coordinate value across the whole partition (a chunk
of shape ``(time, lat, lon)`` carries each latitude ``time × lon`` times).
Encoding a coordinate as an Arrow dictionary keeps only the distinct values plus
an int32 index, shrinking the bytes the engine moves. We encode only when the
int32 key is strictly narrower than the value (8-byte float64/int64/timestamp
coordinates, and variable-width strings) and leave 4-byte float32/int32
coordinates dense — a narrower key would be needed to win there but overflows
under DataFusion's cross-batch dictionary concatenation. These tests pin which
coordinates get encoded, that the values round-trip correctly, and that the
overflow case no longer crashes.
"""

import numpy as np
import pyarrow as pa
import pytest
import xarray as xr

from xarray_sql import XarrayContext
from xarray_sql.df import (
_coord_index_type,
_parse_schema,
block_slices,
iter_record_batches,
)


@pytest.mark.parametrize(
"n_values, expected",
[
(1, pa.int32()),
(2**31, pa.int32()), # int32 holds indices 0..2**31-1
(2**31 + 1, pa.int64()), # fallback keeps huge axes representable
],
)
def test_coord_index_type_boundaries(n_values, expected):
# int32 is the narrowest key we use: a narrower one can overflow when
# DataFusion concatenates per-batch dictionaries across a streaming
# aggregate (see test_float32_groupby_many_partitions_no_overflow).
assert _coord_index_type(n_values) == expected


def _grid() -> xr.Dataset:
return xr.Dataset(
{"v": (("time", "lat", "lon"), np.random.rand(6, 4, 5))},
coords={
"time": xr.date_range("2020-01-01", periods=6, freq="D"),
"lat": np.array([-90.0, -30.0, 30.0, 90.0]),
"lon": np.array([0.0, 72.0, 144.0, 216.0, 288.0]),
},
)


def test_coordinate_fields_are_dictionary_encoded():
"""8-byte dimension coordinates are dictionary-typed; data variables are not.

``_grid`` uses float64/datetime coordinates (8 bytes), which are wider than
the int32 key, so they are encoded.
"""
schema = _parse_schema(_grid())
for dim in ("time", "lat", "lon"):
assert pa.types.is_dictionary(schema.field(dim).type), dim
assert not pa.types.is_dictionary(schema.field("v").type)


def test_float32_coordinates_stay_dense():
"""4-byte coordinates are left dense: an int32 key is no narrower than the
value, so a dictionary would be pure overhead (and a narrower key is unsafe).
"""
ds = xr.Dataset(
{"v": (("lat", "lon"), np.zeros((3, 4), dtype="float32"))},
coords={
"lat": np.array([-90.0, 0.0, 90.0], dtype="float32"),
"lon": np.arange(4, dtype="int32"),
},
)
schema = _parse_schema(ds)
assert schema.field("lat").type == pa.float32()
assert schema.field("lon").type == pa.int32()


def test_float32_groupby_many_partitions_no_overflow():
"""Regression: GROUP BY on float32 coordinates over many partitions must not
overflow a dictionary key.

With narrow keys, DataFusion concatenating the per-partition coordinate
dictionaries across the aggregate overflowed the key type ("Dictionary key
bigger than the key type"). float32 coordinates are now dense, so there is no
coordinate dictionary to concatenate, and the aggregate matches xarray.
"""
ds = xr.Dataset(
{
"air": (
("time", "lat", "lon"),
np.random.rand(600, 20, 30).astype("float32"),
)
},
coords={
"time": np.arange(600),
"lat": np.linspace(-90, 90, 20).astype("float32"),
"lon": np.linspace(0, 359, 30).astype("float32"),
},
)
ctx = XarrayContext()
ctx.from_dataset("air", ds, chunks={"time": 6}) # 100 partitions
got = ctx.sql(
"SELECT lat, lon, AVG(air) AS m FROM air GROUP BY lat, lon"
).to_dataset(dims=["lat", "lon"])
ref = ds["air"].mean("time")
xr.testing.assert_allclose(
got.m, ref.reindex_like(got.m), rtol=1e-5, atol=1e-6
)


def test_iter_record_batches_emits_dictionary_coords():
"""A coordinate column arrives as a DictionaryArray with correct values."""
ds = _grid()
schema = _parse_schema(ds)
block = next(block_slices(ds, chunks={"time": 6}))
batch = next(iter_record_batches(ds.isel(block), schema, batch_size=1024))

lat = batch.column(batch.schema.names.index("lat"))
assert isinstance(lat, pa.DictionaryArray)
# The dictionary holds only the distinct latitudes; decoding reproduces the
# per-row values in row-major order.
assert lat.dictionary.to_pylist() == [-90.0, -30.0, 30.0, 90.0]
decoded = lat.to_numpy(zero_copy_only=False)
assert decoded.shape == (batch.num_rows,)
# First lon×... rows share lat=-90 (lat varies slower than lon in C order).
assert decoded[0] == -90.0

# Data variables stay plain (not dictionary-encoded).
v = batch.column(batch.schema.names.index("v"))
assert not isinstance(v, pa.DictionaryArray)


def test_groupby_coordinate_roundtrips_through_dictionary():
"""GROUP BY on a dictionary coordinate returns the same numbers as xarray."""
ds = _grid()
ctx = XarrayContext()
ctx.from_dataset("grid", ds, chunks={"time": 3})

got = ctx.sql(
"SELECT lat, AVG(v) AS m FROM grid GROUP BY lat ORDER BY lat"
).to_dataset(dims=["lat"])
ref = ds["v"].mean(["time", "lon"])

xr.testing.assert_allclose(
got.m, ref.reindex_like(got.m), rtol=1e-6, atol=1e-9
)
Loading
Loading