diff --git a/src/lib.rs b/src/lib.rs index 63a5a6b..0bce6c1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -93,6 +93,13 @@ impl ScalarBound { /// Compare this bound with a DataFusion ScalarValue. /// Returns None if types are incompatible. fn compare_to_scalar(&self, scalar: &ScalarValue) -> Option { + // Coordinate columns are dictionary-encoded, so DataFusion coerces the + // comparison literal to a `Dictionary` scalar (the column stays a bare + // Column). Unwrap it to the underlying value and compare against that, + // so partition pruning keeps working on dictionary coordinates. + if let ScalarValue::Dictionary(_, value) = scalar { + return self.compare_to_scalar(value); + } match (self, scalar) { // Integer comparisons (ScalarBound::Int64(a), ScalarValue::Int64(Some(b))) => Some(a.cmp(b)), @@ -265,8 +272,11 @@ impl PrunableStreamingTable { right: &Expr, meta: &PartitionMetadata, ) -> bool { - // Try to extract column and literal from either side - let (col_name, scalar, flipped) = match (left, right) { + // Try to extract column and literal from either side. Coordinate + // columns are dictionary-encoded, so a filter may arrive as + // `Cast(col AS value_type) op literal`; strip the (lossless) decode cast + // to reach the column. + let (col_name, scalar, flipped) = match (strip_cast(left), strip_cast(right)) { (Expr::Column(c), Expr::Literal(s, _)) => (c.name.clone(), s, false), (Expr::Literal(s, _), Expr::Column(c)) => (c.name.clone(), s, true), _ => return false, // Not a simple column-literal comparison @@ -348,8 +358,8 @@ impl PrunableStreamingTable { return false; } - // Extract column name - let col_name = match between.expr.as_ref() { + // Extract column name (peeling any dictionary-decode cast) + let col_name = match strip_cast(between.expr.as_ref()) { Expr::Column(c) => c.name.clone(), _ => return false, }; @@ -387,8 +397,8 @@ impl PrunableStreamingTable { return false; } - // Extract column name - let col_name = match in_list.expr.as_ref() { + // Extract column name (peeling any dictionary-decode cast) + let col_name = match strip_cast(in_list.expr.as_ref()) { Expr::Column(c) => c.name.clone(), _ => return false, }; @@ -440,13 +450,33 @@ impl PrunableStreamingTable { /// Check if an expression references a dimension column. fn expr_references_dimension(&self, expr: &Expr) -> bool { - match expr { + match strip_cast(expr) { Expr::Column(c) => self.dimension_columns.contains(&c.name), _ => false, } } } +/// Peel `Cast`/`TryCast` wrappers to reach the inner expression. +/// +/// Dictionary-encoded coordinate columns are decoded with a lossless +/// `CAST(col AS value_type)` before a comparison, so a filter on a coordinate +/// can arrive as `Cast(Column) op literal`. Stripping the cast lets pruning see +/// the underlying column; the cast targets the dictionary's own value type, so +/// it never changes the compared value. If a cast were ever lossy, the bound +/// comparison would simply fail to match and the partition is conservatively +/// kept — pruning never becomes wrong, only (at worst) less effective. +fn strip_cast(expr: &Expr) -> &Expr { + let mut current = expr; + loop { + match current { + Expr::Cast(cast) => current = &cast.expr, + Expr::TryCast(cast) => current = &cast.expr, + _ => return current, + } + } +} + /// Extension trait for partition streams that support column projection. /// /// Implemented by `PyArrowStreamPartition` so that `PrunableStreamingTable` @@ -702,6 +732,12 @@ fn fold_bound(a: &ScalarBound, b: &ScalarBound, keep_min: bool) -> Option Option { + // Coordinate columns are dictionary-encoded; their min/max is the min/max of + // the underlying values, so report a value-type scalar (what DataFusion uses + // for dictionary-column statistics). + if let DataType::Dictionary(_, value_type) = dtype { + return bound_to_scalar(bound, value_type); + } match (bound, dtype) { (ScalarBound::Int64(v), DataType::Int64) => Some(ScalarValue::Int64(Some(*v))), (ScalarBound::Int64(v), DataType::Int32) => { @@ -736,21 +772,46 @@ fn bound_to_scalar(bound: &ScalarBound, dtype: &DataType) -> Option } } -/// Exact in-memory byte size of `num_rows` rows of `schema`, or `Absent` if any -/// column is variable-width (e.g. Utf8) and cannot be sized from the row count -/// alone. Our data model is dense fixed-width grids, so this is normally exact. +/// Fixed byte width of one value of `dtype`, or `None` if it is variable-width. +/// +/// A dictionary-encoded column is sized by its *value* type — a safe upper bound +/// on the encoded column's real footprint (the actual indices are narrower). We +/// deliberately over- rather than under-count so a memory-based rule never +/// under-provisions; the caller marks the total `Inexact` to reflect it. +fn fixed_width(dtype: &DataType) -> Option { + match dtype { + DataType::Dictionary(_, value_type) => fixed_width(value_type), + other => other.primitive_width(), + } +} + +/// In-memory byte size of `num_rows` rows of `schema`. +/// +/// `Absent` if any column is variable-width (e.g. Utf8). `Exact` for a plain +/// fixed-width grid; `Inexact` when any column is dictionary-encoded, since we +/// size those by the value type (a safe upper bound) rather than the narrower +/// index — an honest label now that coordinates can be dictionary-encoded. fn total_byte_size(schema: &Schema, num_rows: &Precision) -> Precision { let Precision::Exact(rows) = num_rows else { return Precision::Absent; }; let mut row_width = 0usize; + let mut any_dictionary = false; for field in schema.fields() { - match field.data_type().primitive_width() { + if matches!(field.data_type(), DataType::Dictionary(_, _)) { + any_dictionary = true; + } + match fixed_width(field.data_type()) { Some(w) => row_width += w, None => return Precision::Absent, } } - Precision::Exact(rows.saturating_mul(row_width)) + let total = rows.saturating_mul(row_width); + if any_dictionary { + Precision::Inexact(total) + } else { + Precision::Exact(total) + } } /// Build `Statistics` for a scan over the given partitions. diff --git a/tests/test_cft.py b/tests/test_cft.py index 6560d72..baa4290 100644 --- a/tests/test_cft.py +++ b/tests/test_cft.py @@ -153,18 +153,23 @@ class TestParseSchemaIntegration: def test_noleap_produces_timestamp_us(self, rasm_ds): schema = _parse_schema(rasm_ds[["Tair"]]) time_field = schema.field("time") - assert time_field.type == pa.timestamp("us") + # Coordinate columns are dictionary-encoded; the cftime encoding lives + # in the dictionary's value type (and the field metadata is preserved). + assert pa.types.is_dictionary(time_field.type) + assert time_field.type.value_type == pa.timestamp("us") assert time_field.metadata[b"xarray:calendar"] == b"noleap" def test_360day_produces_int64(self, ds_360day): schema = _parse_schema(ds_360day) time_field = schema.field("time") - assert time_field.type == pa.int64() + assert pa.types.is_dictionary(time_field.type) + assert time_field.type.value_type == pa.int64() assert time_field.metadata[b"xarray:calendar"] == b"360_day" def test_datetime64_unchanged(self): ds = xr.tutorial.open_dataset("air_temperature") schema = _parse_schema(ds) time_field = schema.field("time") - assert pa.types.is_timestamp(time_field.type) + assert pa.types.is_dictionary(time_field.type) + assert pa.types.is_timestamp(time_field.type.value_type) assert time_field.metadata is None # no xarray: metadata for native diff --git a/tests/test_df.py b/tests/test_df.py index 2c81144..de57eba 100644 --- a/tests/test_df.py +++ b/tests/test_df.py @@ -198,6 +198,7 @@ def test_dataset_to_record_batch_matches_pivot(air_small): we sort by the coordinate columns before comparing. """ schema = _parse_schema(air_small) + schema_cols = [f.name for f in schema] dim_cols = [f.name for f in schema if f.name in air_small.dims] blocks = list( block_slices(air_small, chunks={"time": 4, "lat": 3, "lon": 4}) @@ -205,18 +206,20 @@ def test_dataset_to_record_batch_matches_pivot(air_small): for block in blocks: ds_block = air_small.isel(block) - actual_df = ( - dataset_to_record_batch(ds_block, schema) - .to_pandas() - .sort_values(dim_cols) - .reset_index(drop=True) - ) expected_df = ( - pa.RecordBatch.from_pandas(pivot(ds_block), schema=schema) - .to_pandas() + pivot(ds_block)[schema_cols] .sort_values(dim_cols) .reset_index(drop=True) ) + actual_df = dataset_to_record_batch(ds_block, schema).to_pandas()[ + schema_cols + ] + # Coordinate columns are dictionary-encoded, so they come back from + # Arrow as pandas Categorical; decode them to the plain dtype pivot uses + # before comparing values. + for col in dim_cols: + actual_df[col] = actual_df[col].astype(expected_df[col].dtype) + actual_df = actual_df.sort_values(dim_cols).reset_index(drop=True) pd.testing.assert_frame_equal(actual_df, expected_df, check_like=False) @@ -403,19 +406,21 @@ def test_read_xarray_loads_one_chunk_at_a_time(large_ds): peaks.append(cur_peak) for size in sizes: - # Observed range: 1.59–1.83× on macOS, up to ~2.7× on Linux - # (glibc + Arrow allocate more intermediate buffers). - # iter_record_batches holds data-variable arrays (≈1× chunk) while - # yielding sub-batches, plus the current Arrow batch (≈0.65× chunk). - assert chunk_size * 1.3 < size, f"size {size} unexpectedly low" + # iter_record_batches holds the data-variable arrays (≈1× chunk) + # while yielding sub-batches, plus the current Arrow batch. The + # batch's coordinate columns are dictionary-encoded (only the + # distinct values plus small int32 indices), so they add far less + # than a broadcast column would — steady state sits just above 1×. + assert chunk_size * 1.1 < size, f"size {size} unexpectedly low" assert chunk_size * 3.5 > size, f"size {size} unexpectedly high" for peak in peaks: - # Observed range: 1.84–3.28× on macOS, up to ~4.15× on Linux - # (glibc + Arrow hold more intermediate buffers at peak). - # Peak includes data arrays + Arrow batch + temporary coordinate index - # arrays; the first batch of each chunk is highest (Dask compute overhead). - assert chunk_size * 1.5 < peak, f"peak {peak} unexpectedly low" + # Peak includes the data arrays + the current Arrow batch + + # temporary coordinate index arrays; the first batch of each chunk + # is highest (Dask compute overhead). Dictionary-encoded coordinate + # columns keep the batch (and so the peak) smaller than a broadcast + # column would. + assert chunk_size * 1.3 < peak, f"peak {peak} unexpectedly low" assert chunk_size * 5.0 > peak, f"peak {peak} unexpectedly high" assert max(peaks) < large_ds.nbytes diff --git a/tests/test_dict_coords.py b/tests/test_dict_coords.py new file mode 100644 index 0000000..4687183 --- /dev/null +++ b/tests/test_dict_coords.py @@ -0,0 +1,151 @@ +"""Coordinate columns are dictionary-encoded when it is worthwhile and safe. + +A dense grid repeats every coordinate value across the whole partition (a chunk +of shape ``(time, lat, lon)`` carries each latitude ``time × lon`` times). +Encoding a coordinate as an Arrow dictionary keeps only the distinct values plus +an int32 index, shrinking the bytes the engine moves. We encode only when the +int32 key is strictly narrower than the value (8-byte float64/int64/timestamp +coordinates, and variable-width strings) and leave 4-byte float32/int32 +coordinates dense — a narrower key would be needed to win there but overflows +under DataFusion's cross-batch dictionary concatenation. These tests pin which +coordinates get encoded, that the values round-trip correctly, and that the +overflow case no longer crashes. +""" + +import numpy as np +import pyarrow as pa +import pytest +import xarray as xr + +from xarray_sql import XarrayContext +from xarray_sql.df import ( + _coord_index_type, + _parse_schema, + block_slices, + iter_record_batches, +) + + +@pytest.mark.parametrize( + "n_values, expected", + [ + (1, pa.int32()), + (2**31, pa.int32()), # int32 holds indices 0..2**31-1 + (2**31 + 1, pa.int64()), # fallback keeps huge axes representable + ], +) +def test_coord_index_type_boundaries(n_values, expected): + # int32 is the narrowest key we use: a narrower one can overflow when + # DataFusion concatenates per-batch dictionaries across a streaming + # aggregate (see test_float32_groupby_many_partitions_no_overflow). + assert _coord_index_type(n_values) == expected + + +def _grid() -> xr.Dataset: + return xr.Dataset( + {"v": (("time", "lat", "lon"), np.random.rand(6, 4, 5))}, + coords={ + "time": xr.date_range("2020-01-01", periods=6, freq="D"), + "lat": np.array([-90.0, -30.0, 30.0, 90.0]), + "lon": np.array([0.0, 72.0, 144.0, 216.0, 288.0]), + }, + ) + + +def test_coordinate_fields_are_dictionary_encoded(): + """8-byte dimension coordinates are dictionary-typed; data variables are not. + + ``_grid`` uses float64/datetime coordinates (8 bytes), which are wider than + the int32 key, so they are encoded. + """ + schema = _parse_schema(_grid()) + for dim in ("time", "lat", "lon"): + assert pa.types.is_dictionary(schema.field(dim).type), dim + assert not pa.types.is_dictionary(schema.field("v").type) + + +def test_float32_coordinates_stay_dense(): + """4-byte coordinates are left dense: an int32 key is no narrower than the + value, so a dictionary would be pure overhead (and a narrower key is unsafe). + """ + ds = xr.Dataset( + {"v": (("lat", "lon"), np.zeros((3, 4), dtype="float32"))}, + coords={ + "lat": np.array([-90.0, 0.0, 90.0], dtype="float32"), + "lon": np.arange(4, dtype="int32"), + }, + ) + schema = _parse_schema(ds) + assert schema.field("lat").type == pa.float32() + assert schema.field("lon").type == pa.int32() + + +def test_float32_groupby_many_partitions_no_overflow(): + """Regression: GROUP BY on float32 coordinates over many partitions must not + overflow a dictionary key. + + With narrow keys, DataFusion concatenating the per-partition coordinate + dictionaries across the aggregate overflowed the key type ("Dictionary key + bigger than the key type"). float32 coordinates are now dense, so there is no + coordinate dictionary to concatenate, and the aggregate matches xarray. + """ + ds = xr.Dataset( + { + "air": ( + ("time", "lat", "lon"), + np.random.rand(600, 20, 30).astype("float32"), + ) + }, + coords={ + "time": np.arange(600), + "lat": np.linspace(-90, 90, 20).astype("float32"), + "lon": np.linspace(0, 359, 30).astype("float32"), + }, + ) + ctx = XarrayContext() + ctx.from_dataset("air", ds, chunks={"time": 6}) # 100 partitions + got = ctx.sql( + "SELECT lat, lon, AVG(air) AS m FROM air GROUP BY lat, lon" + ).to_dataset(dims=["lat", "lon"]) + ref = ds["air"].mean("time") + xr.testing.assert_allclose( + got.m, ref.reindex_like(got.m), rtol=1e-5, atol=1e-6 + ) + + +def test_iter_record_batches_emits_dictionary_coords(): + """A coordinate column arrives as a DictionaryArray with correct values.""" + ds = _grid() + schema = _parse_schema(ds) + block = next(block_slices(ds, chunks={"time": 6})) + batch = next(iter_record_batches(ds.isel(block), schema, batch_size=1024)) + + lat = batch.column(batch.schema.names.index("lat")) + assert isinstance(lat, pa.DictionaryArray) + # The dictionary holds only the distinct latitudes; decoding reproduces the + # per-row values in row-major order. + assert lat.dictionary.to_pylist() == [-90.0, -30.0, 30.0, 90.0] + decoded = lat.to_numpy(zero_copy_only=False) + assert decoded.shape == (batch.num_rows,) + # First lon×... rows share lat=-90 (lat varies slower than lon in C order). + assert decoded[0] == -90.0 + + # Data variables stay plain (not dictionary-encoded). + v = batch.column(batch.schema.names.index("v")) + assert not isinstance(v, pa.DictionaryArray) + + +def test_groupby_coordinate_roundtrips_through_dictionary(): + """GROUP BY on a dictionary coordinate returns the same numbers as xarray.""" + ds = _grid() + ctx = XarrayContext() + ctx.from_dataset("grid", ds, chunks={"time": 3}) + + got = ctx.sql( + "SELECT lat, AVG(v) AS m FROM grid GROUP BY lat ORDER BY lat" + ).to_dataset(dims=["lat"]) + ref = ds["v"].mean(["time", "lon"]) + + xr.testing.assert_allclose( + got.m, ref.reindex_like(got.m), rtol=1e-6, atol=1e-9 + ) diff --git a/tests/test_ds.py b/tests/test_ds.py index aa3deb2..43627d8 100644 --- a/tests/test_ds.py +++ b/tests/test_ds.py @@ -37,14 +37,28 @@ def test_ctx_sql_returns_xarray_dataframe(air_dataset_small): assert isinstance(result, XarrayDataFrame) -def test_to_pandas_unchanged_behavior(air_dataset_small): - """Wrapped ``.to_pandas()`` is bit-for-bit equal to the un-wrapped path.""" +def test_to_pandas_decodes_dictionary_columns(air_dataset_small): + """Wrapped ``.to_pandas()`` surfaces coordinates as plain values. + + Coordinate columns are dictionary-encoded internally, which the raw + DataFusion ``.to_pandas()`` returns as pandas ``Categorical``. The wrapper + decodes those back to their value dtype so callers see the same plain + columns as before the encoding — no ``Categorical`` leaks, and the values + match the raw path once its categoricals are decoded. + """ from datafusion import SessionContext ctx = XarrayContext() ctx.from_dataset("air", air_dataset_small) wrapped = ctx.sql("SELECT * FROM air LIMIT 7").to_pandas() raw = SessionContext.sql(ctx, "SELECT * FROM air LIMIT 7").to_pandas() + + assert not any( + isinstance(wrapped[c].dtype, pd.CategoricalDtype) for c in wrapped + ) + for c in raw.columns: + if isinstance(raw[c].dtype, pd.CategoricalDtype): + raw[c] = raw[c].astype(raw[c].cat.categories.dtype) pd.testing.assert_frame_equal(wrapped, raw) diff --git a/tests/test_stats.py b/tests/test_stats.py index a1f199e..e51fe49 100644 --- a/tests/test_stats.py +++ b/tests/test_stats.py @@ -48,8 +48,10 @@ def test_exact_byte_size_in_scan_statistics(): ctx = XarrayContext() ctx.from_dataset("air", ds, chunks={"time": 50}) plan = _explain(ctx, "SELECT lat, lon, air FROM air") - # 2000 rows x (lat int64 + lon int64 + air float64) = 2000 x 24 bytes. - assert f"Bytes=Exact({100 * 4 * 5 * 24})" in plan + # 2000 rows x (lat int64 + lon int64 + air float64) = 2000 x 24 bytes, sized + # by value type. lat/lon are dictionary-encoded (int32 keys), so the value + # sizing is a safe upper bound and the total is reported Inexact. + assert f"Bytes=Inexact({100 * 4 * 5 * 24})" in plan def test_dimension_column_min_max_in_scan_statistics(): diff --git a/xarray_sql/df.py b/xarray_sql/df.py index ab80056..df6a227 100644 --- a/xarray_sql/df.py +++ b/xarray_sql/df.py @@ -271,15 +271,29 @@ def dataset_to_record_batch( for field in schema: name = field.name if name in ds.coords and name in ds.dims: - # Broadcast 1-D coordinate to the full N-D partition shape, then ravel. axis = dim_names.index(name) coord = ds.coords[name].values if cft.is_cftime(coord): coord = cft.convert_for_field(coord, field) reshape = [1] * len(shape) reshape[axis] = coord.shape[0] - arr = np.broadcast_to(coord.reshape(reshape), shape).ravel() - arrays.append(pa.array(arr, type=field.type)) + if pa.types.is_dictionary(field.type): + # Dictionary-encode: the distinct coordinate values are the + # dictionary, and the per-row indices are the coordinate's own + # axis index broadcast across the partition — no dense array of + # repeated values is ever built. + idx = np.broadcast_to( + np.arange(coord.shape[0]).reshape(reshape), shape + ).ravel() + dictionary = pa.array(coord, type=field.type.value_type) + indices = pa.array(idx, type=field.type.index_type) + arrays.append( + pa.DictionaryArray.from_arrays(indices, dictionary) + ) + else: + # Broadcast 1-D coordinate to the full partition shape, ravel. + arr = np.broadcast_to(coord.reshape(reshape), shape).ravel() + arrays.append(pa.array(arr, type=field.type)) else: # Data variable: ravel to 1-D (zero-copy for C-contiguous arrays). raw = ds[name].values.ravel() @@ -371,9 +385,22 @@ def iter_record_batches( if name in ds.coords and name in ds.dims: k = dim_names.index(name) coord_idx = (row_idx // strides[k]) % shape[k] - arrays.append( - pa.array(coord_values[name][coord_idx], type=field.type) - ) + if pa.types.is_dictionary(field.type): + # Emit the coordinate as a dictionary array: the distinct + # coordinate values are the dictionary, and the strided + # per-row indices we just computed are exactly the + # dictionary indices — no broadcast of repeated values. + dictionary = pa.array( + coord_values[name], type=field.type.value_type + ) + indices = pa.array(coord_idx, type=field.type.index_type) + arrays.append( + pa.DictionaryArray.from_arrays(indices, dictionary) + ) + else: + arrays.append( + pa.array(coord_values[name][coord_idx], type=field.type) + ) else: arrays.append( pa.array( @@ -386,6 +413,53 @@ def iter_record_batches( yield pa.RecordBatch.from_arrays(arrays, schema=schema) +def _coord_index_type(n_values: int) -> pa.DataType: + """Signed Arrow int type for a coordinate dictionary of ``n_values`` keys. + + The index must survive DataFusion *concatenating* per-batch dictionaries + across a scan. Those concatenations are not always unified (arrow merges + dictionary values on a size heuristic, not a guarantee), so the *combined* + dictionary an operator sees can be many times a single partition's + cardinality — an unchunked coordinate repeated across N partitions can reach + ``card × N``. A narrow key (int8/int16) therefore overflows under streaming + aggregation ("Dictionary key bigger than the key type"), so we do not use + one. ``int32`` covers ~2.1B combined entries (any realistic grid), and + ``int64`` backstops the astronomically-large case. Indices run + ``0 .. n_values - 1``, so a signed max ``M`` holds a cardinality of ``M + 1``. + """ + if n_values <= np.iinfo(np.int32).max + 1: + return pa.int32() + return pa.int64() + + +def _as_dictionary_field(field: pa.Field, n_values: int) -> pa.Field: + """Dictionary-encode a coordinate field when it is worthwhile and safe. + + Coordinate columns repeat each value across the whole grid — a chunk of + shape ``(time, lat, lon)`` carries each latitude ``time × lon`` times. A + dictionary stores only the distinct values plus integer indices, shrinking + the bytes the engine moves and letting ``GROUP BY`` / equality ``JOIN`` + compare integer keys instead of rehashing repeated floats. + + We only encode when the index type (see :func:`_coord_index_type`) is + *strictly narrower* than the value type. That keeps 8-byte coordinates + (``float64``, ``int64``, timestamps) — a 2× win with an overflow-safe int32 + key — while leaving 4-byte ``float32`` / ``int32`` coordinates dense, where a + dictionary would be pure overhead and a narrower (overflow-prone) key would + be the only way to win. Variable-width value types (e.g. strings) are always + encoded — the classic dictionary win. The field's name, nullability, and + metadata (e.g. cftime units/calendar) are preserved. + """ + index_type = _coord_index_type(n_values) + try: + value_bit_width = field.type.bit_width + except (ValueError, NotImplementedError): + value_bit_width = None # variable-width (e.g. string): always encode + if value_bit_width is not None and index_type.bit_width >= value_bit_width: + return field + return field.with_type(pa.dictionary(index_type, field.type)) + + def _parse_schema(ds: xr.Dataset) -> pa.Schema: """Extracts a `pa.Schema` from the Dataset, treating dims and data_vars as columns. @@ -413,10 +487,13 @@ def _parse_schema(ds: xr.Dataset) -> pa.Schema: if coord_name in ds.dims: if cft.is_cftime_index(ds, coord_name): units, calendar = cft.encoding(ds, coord_name) - columns.append(cft.arrow_field(coord_name, units, calendar)) + field = cft.arrow_field(coord_name, units, calendar) else: pa_type = pa.from_numpy_dtype(coord_var.dtype) - columns.append(pa.field(coord_name, pa_type)) + field = pa.field(coord_name, pa_type) + # Dictionary-encode the coordinate column (see _as_dictionary_field), + # sizing the index type to the dimension length. + columns.append(_as_dictionary_field(field, ds.sizes[coord_name])) for var_name, var in ds.data_vars.items(): # Data variables are virtually never cftime, but check dtype as a diff --git a/xarray_sql/ds.py b/xarray_sql/ds.py index 5dfdf42..e8b3a09 100644 --- a/xarray_sql/ds.py +++ b/xarray_sql/ds.py @@ -742,8 +742,19 @@ def __init__( object.__setattr__(self, "_templates", dict(templates or {})) def to_pandas(self) -> pd.DataFrame: - """Materialize the result as a ``pd.DataFrame`` (DataFusion API).""" - return self._inner.to_pandas() + """Materialize the result as a ``pd.DataFrame`` (DataFusion API). + + Coordinate columns are dictionary-encoded internally, which DataFusion + surfaces as pandas ``Categorical``. Decode those back to their value + dtype so callers see the same plain columns as before the encoding — + results sort and compare by value, not by category order. + """ + df = self._inner.to_pandas() + for name in df.columns: + dtype = df[name].dtype + if isinstance(dtype, pd.CategoricalDtype): + df[name] = df[name].astype(dtype.categories.dtype) + return df def to_dataset( self,