Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
822 changes: 407 additions & 415 deletions Cargo.lock

Large diffs are not rendered by default.

10 changes: 5 additions & 5 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,22 @@ exclude = [
]

[dependencies]
arrow = { version = "57.2.0", features = ["pyarrow"] }
arrow = { version = "58", features = ["pyarrow"] }
async-stream = "0.3"
async-trait = "0.1"
datafusion = { version = "52.0.0" }
datafusion-ffi = { version = "52.0.0" }
datafusion = { version = "54.0.0" }
datafusion-ffi = { version = "54.0.0" }
futures = { version = "0.3" }
# `abi3-py310` builds against CPython's stable ABI, so a single wheel per
# platform works on all CPython >= 3.10 (matching `requires-python`). This
# lets the release workflow ship pre-built wheels for every interpreter
# without compiling per-version, avoiding local rebuilds on install.
pyo3 = { version = "0.26.0", features = ["extension-module", "abi3-py310"] }
pyo3 = { version = "0.28.0", features = ["extension-module", "abi3-py310"] }
tokio = { version = "1.46.1", features = ["rt"] }


[build-dependencies]
pyo3-build-config = "0.26"
pyo3-build-config = "0.28"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[lib]
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ classifiers = [
]
dependencies = [
"dask>=2024.8.0",
"datafusion==52.0.0", # This needs to match the cargo datafusion version!!
"datafusion==54.0.0", # This needs to match the cargo datafusion version!!
"xarray>=2024.7.0",
]

Expand Down
370 changes: 328 additions & 42 deletions src/lib.rs

Large diffs are not rendered by default.

36 changes: 18 additions & 18 deletions tests/test_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def test_full_query_iterates_all_blocks(self, small_ds):
ctx.register_table("test_table", table)

# Run a query that needs to scan all data
ctx.sql("SELECT COUNT(*) FROM test_table").collect()
ctx.sql("SELECT * FROM test_table").collect()

# With time=100 and chunks=25, we expect 4 blocks
expected_blocks = 100 // 25
Expand Down Expand Up @@ -318,7 +318,7 @@ def test_multiple_queries_on_same_table(self, small_ds):
ctx.register_table("test_table", table)

# First query
ctx.sql("SELECT COUNT(*) FROM test_table").collect()
ctx.sql("SELECT * FROM test_table").collect()
first_query_iterations = tracker.iteration_count
assert first_query_iterations > 0, "First query should iterate"

Expand Down Expand Up @@ -358,8 +358,8 @@ def test_query_results_are_correct(self, small_ds):
ctx.register_table("test_table", table)

# Get count
result = ctx.sql("SELECT COUNT(*) as cnt FROM test_table").collect()
count = result[0].to_pandas()["cnt"].iloc[0]
result = ctx.sql("SELECT * FROM test_table").collect()
count = sum(b.num_rows for b in result)

# Expected: 100 time steps * 10 lat * 10 lon = 10,000 rows
expected_count = 100 * 10 * 10
Expand Down Expand Up @@ -512,7 +512,7 @@ def test_batches_processed_incrementally(self, small_ds):
ctx.register_table("test_table", table)

# Run query that scans all data
ctx.sql("SELECT COUNT(*) FROM test_table").collect()
ctx.sql("SELECT * FROM test_table").collect()

# All 4 batches should have been processed
assert tracker.batch_count == 4, (
Expand Down Expand Up @@ -580,8 +580,8 @@ def test_large_dataset_streams_correctly(self):
ctx.register_table("test_table", table)

# Run a query that needs all data
result = ctx.sql("SELECT COUNT(*) as cnt FROM test_table").collect()
count = result[0].to_pandas()["cnt"].iloc[0]
result = ctx.sql("SELECT * FROM test_table").collect()
count = sum(b.num_rows for b in result)

# Verify all blocks were processed
assert tracker.batch_count == 20, (
Expand Down Expand Up @@ -639,8 +639,8 @@ def test_many_batches_stream_successfully(self):
ctx = SessionContext()
ctx.register_table("test_table", table)

result = ctx.sql("SELECT COUNT(*) as cnt FROM test_table").collect()
count = result[0].to_pandas()["cnt"].iloc[0]
result = ctx.sql("SELECT * FROM test_table").collect()
count = sum(b.num_rows for b in result)

# All 16 batches should have been processed
assert tracker.batch_count == 16, (
Expand Down Expand Up @@ -722,8 +722,8 @@ def test_large_batch_count_completes(self):
ctx = SessionContext()
ctx.register_table("test_table", table)

result = ctx.sql("SELECT COUNT(*) as cnt FROM test_table").collect()
count = result[0].to_pandas()["cnt"].iloc[0]
result = ctx.sql("SELECT * FROM test_table").collect()
count = sum(b.num_rows for b in result)

# All 50 batches processed
assert tracker.batch_count == 50, (
Expand Down Expand Up @@ -792,8 +792,8 @@ def failing_factory():
raise ValueError("Factory intentionally failed")

schema = pa.schema([("value", pa.int64())])
# partitions is an iterable of (factory, metadata_dict) pairs
table = LazyArrowStreamTable([(failing_factory, {})], schema)
# partitions is an iterable of (factory, metadata_dict, num_rows) tuples
table = LazyArrowStreamTable([(failing_factory, {}, 1)], schema)

ctx = SessionContext()
ctx.register_table("test_table", table)
Expand Down Expand Up @@ -860,8 +860,8 @@ def test_empty_dataset_handled_gracefully(self):
ctx = SessionContext()
ctx.register_table("test_table", table)

result = ctx.sql("SELECT COUNT(*) as cnt FROM test_table").collect()
count = result[0].to_pandas()["cnt"].iloc[0]
result = ctx.sql("SELECT * FROM test_table").collect()
count = sum(b.num_rows for b in result)

assert count == 0, f"Expected 0 rows for empty dataset, got {count}"

Expand Down Expand Up @@ -889,7 +889,7 @@ def counting_callback(block, projection_names=None):
ctx.register_table("test_table", table)

# First query
ctx.sql("SELECT COUNT(*) FROM test_table").collect()
ctx.sql("SELECT * FROM test_table").collect()
first_query_count = call_count["value"]
assert first_query_count == 2, (
f"First query: expected 2, got {first_query_count}"
Expand Down Expand Up @@ -933,8 +933,8 @@ def test_parallel_queries_independent(self, small_ds):
ctx2.register_table("test_table", table2)

# Execute queries
ctx1.sql("SELECT COUNT(*) FROM test_table").collect()
ctx2.sql("SELECT COUNT(*) FROM test_table").collect()
ctx1.sql("SELECT * FROM test_table").collect()
ctx2.sql("SELECT * FROM test_table").collect()

# Each should have its own iteration count
assert tracker1.iteration_count == 4, (
Expand Down
124 changes: 124 additions & 0 deletions tests/test_stats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
"""Exact table statistics reach the optimizer through the FFI boundary.

DataFusion 54 forwards ``Statistics`` across the ``datafusion-ffi`` boundary,
so the exact statistics ``XarrayScanExec`` reports are visible to the query
optimizer: num_rows (product of a chunk's dimension sizes), total byte size,
and per dimension-column min/max bounds. These tests pin that behaviour.
"""

import numpy as np
import xarray as xr

from xarray_sql import XarrayContext


def _explain(ctx: XarrayContext, query: str) -> str:
ctx.sql("SET datafusion.explain.show_statistics = true").collect()
rows = ctx.sql(f"EXPLAIN {query}").to_pandas()
return "\n".join(rows["plan"].tolist())


def test_exact_rows_in_scan_statistics():
"""The scan reports exact row counts (forwarded across FFI)."""
ds = xr.Dataset(
{"air": (("time", "lat", "lon"), np.random.rand(100, 4, 5))},
coords={
"time": np.arange(100),
"lat": np.arange(4),
"lon": np.arange(5),
},
)
ctx = XarrayContext()
ctx.from_dataset("air", ds, chunks={"time": 50})
plan = _explain(ctx, "SELECT lat, lon, air FROM air")
total = 100 * 4 * 5
assert f"Rows=Exact({total})" in plan


def test_exact_byte_size_in_scan_statistics():
"""The scan reports exact byte size (num_rows x fixed row width)."""
ds = xr.Dataset(
{"air": (("time", "lat", "lon"), np.random.rand(100, 4, 5))},
coords={
"time": np.arange(100),
"lat": np.arange(4),
"lon": np.arange(5),
},
)
ctx = XarrayContext()
ctx.from_dataset("air", ds, chunks={"time": 50})
plan = _explain(ctx, "SELECT lat, lon, air FROM air")
# 2000 rows x (lat int64 + lon int64 + air float64) = 2000 x 24 bytes.
assert f"Bytes=Exact({100 * 4 * 5 * 24})" in plan


def test_dimension_column_min_max_in_scan_statistics():
"""Dimension columns carry exact min/max and a zero null count.

These are the join/filter key columns; the bounds come from the same
coordinate metadata used for partition pruning (no data scan), and grid
axes are always fully populated so the null count is exactly zero.
"""
ds = xr.Dataset(
{"air": (("time", "lat", "lon"), np.random.rand(100, 4, 5))},
coords={
"time": np.arange(100),
"lat": np.arange(4),
"lon": np.arange(5),
},
)
ctx = XarrayContext()
ctx.from_dataset("air", ds, chunks={"time": 50})
plan = _explain(ctx, "SELECT lat, lon, air FROM air")
# lat spans 0..3, lon spans 0..4, both never null.
assert "Min=Exact(Int64(0)) Max=Exact(Int64(3)) Null=Exact(0)" in plan
assert "Min=Exact(Int64(0)) Max=Exact(Int64(4)) Null=Exact(0)" in plan


def test_count_star_answered_from_statistics():
"""COUNT(*) returns the exact count from statistics (metadata only)."""
ds = xr.Dataset(
{"air": (("time", "lat", "lon"), np.random.rand(100, 4, 5))},
coords={
"time": np.arange(100),
"lat": np.arange(4),
"lon": np.arange(5),
},
)
ctx = XarrayContext()
ctx.from_dataset("air", ds, chunks={"time": 50})
n = ctx.sql("SELECT COUNT(*) AS n FROM air").to_pandas()["n"][0]
assert int(n) == 100 * 4 * 5


def test_join_picks_small_build_side():
"""With exact stats the optimizer broadcasts the smaller table (CollectLeft).

Without statistics (the pre-54 FFI path) the optimizer could not know which
side was smaller and fell back to a Partitioned hash join.
"""
rng = np.random.default_rng(0)
big = xr.Dataset(
{"t": (("time", "lat", "lon"), rng.standard_normal((200, 8, 8)))},
coords={
"time": np.arange(200),
"lat": np.arange(8),
"lon": np.arange(8),
},
)
small = xr.Dataset(
{"w": (("lat", "lon"), rng.standard_normal((8, 8)))},
coords={"lat": np.arange(8), "lon": np.arange(8)},
)
ctx = XarrayContext()
ctx.from_dataset("big", big, chunks={"time": 50})
ctx.from_dataset("small", small, chunks={"lat": 8})

plan = _explain(
ctx,
"SELECT b.time, SUM(b.t * s.w) AS x FROM big b "
"JOIN small s ON b.lat=s.lat AND b.lon=s.lon GROUP BY b.time",
)
assert "HashJoinExec: mode=CollectLeft" in plan
# The small (build) side's exact cardinality crossed the FFI boundary.
assert "Rows=Exact(64)" in plan
22 changes: 14 additions & 8 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 6 additions & 1 deletion xarray_sql/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
Block,
Chunks,
DEFAULT_BATCH_SIZE,
_block_len,
_block_metadata,
_block_slices_from_resolved,
_parse_schema,
Expand Down Expand Up @@ -315,7 +316,7 @@ def make_stream(
)

def partition_pairs():
"""Lazily yield (factory, metadata) for each partition.
"""Lazily yield (factory, metadata, num_rows) for each partition.

Consuming this generator one item at a time means Python never holds
all N block dicts, metadata dicts, and factory closures simultaneously.
Expand All @@ -327,6 +328,10 @@ def partition_pairs():
yield (
make_partition_factory(block),
{**static_ranges, **dynamic},
# Exact row count for this partition (product of the chunk's
# per-dimension sizes), so the scan can report exact
# Statistics::num_rows to the optimizer.
_block_len(block),
)

return LazyArrowStreamTable(partition_pairs(), schema)
Loading