diff --git a/Cargo.lock b/Cargo.lock index 21dfa95..41a4c7a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3375,13 +3375,14 @@ checksum = "ea2f10b9bb0928dfb1b42b65e1f9e36f7f54dbdf08457afefb38afcdec4fa2bb" [[package]] name = "xarray_sql" -version = "0.3.0" +version = "0.3.1" dependencies = [ "arrow", "async-stream", "async-trait", "datafusion", "datafusion-ffi", + "datafusion-physical-expr-common", "futures", "pyo3", "pyo3-build-config", diff --git a/Cargo.toml b/Cargo.toml index 147c68f..4b0374f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,6 +22,10 @@ arrow = { version = "57.2.0", features = ["pyarrow"] } async-stream = "0.3" async-trait = "0.1" datafusion = { version = "52.0.0" } +# Provides `snapshot_physical_expr`, used to resolve a join's dynamic filter to +# its current bounds for partition pruning. Pinned to the same 52.x line as +# datafusion so the PhysicalExpr ABI matches. +datafusion-physical-expr-common = { version = "52.0.0" } datafusion-ffi = { version = "52.0.0" } futures = { version = "0.3" } # `abi3-py310` builds against CPython's stable ABI, so a single wheel per @@ -29,7 +33,7 @@ futures = { version = "0.3" } # lets the release workflow ship pre-built wheels for every interpreter # without compiling per-version, avoiding local rebuilds on install. pyo3 = { version = "0.26.0", features = ["extension-module", "abi3-py310"] } -tokio = { version = "1.46.1", features = ["rt"] } +tokio = { version = "1.46.1", features = ["rt", "rt-multi-thread"] } [build-dependencies] diff --git a/src/lib.rs b/src/lib.rs index c489609..136d1c7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -47,27 +47,45 @@ use std::ffi::CString; use std::fmt::Debug; use std::sync::Arc; -use arrow::array::RecordBatch; -use arrow::datatypes::{Schema, SchemaRef}; +use arrow::array::{make_array, ArrayData, ArrayRef, BooleanArray, RecordBatch}; +use arrow::datatypes::{DataType, Schema, SchemaRef}; use arrow::pyarrow::FromPyArrow; +use arrow::pyarrow::ToPyArrow; use async_stream::try_stream; use async_trait::async_trait; +use datafusion::catalog::memory::MemorySchemaProvider; use datafusion::catalog::streaming::StreamingTable; use datafusion::catalog::Session; -use datafusion::common::{DataFusionError, Result as DFResult, ScalarValue}; +use datafusion::common::stats::Precision; +use datafusion::common::{Column, ColumnStatistics, Statistics}; +use datafusion::common::{DataFusionError, Result as DFResult, ScalarValue, TableReference}; +use datafusion::config::ConfigOptions; use datafusion::datasource::TableProvider; use datafusion::execution::TaskContext; use datafusion::logical_expr::expr::InList; use datafusion::logical_expr::{ - BinaryExpr, Expr, Operator, TableProviderFilterPushDown, TableType, + BinaryExpr, ColumnarValue, Expr, Operator, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, + Signature, TableProviderFilterPushDown, TableType, Volatility, +}; +use datafusion::physical_expr::PhysicalExpr; +use datafusion::physical_optimizer::pruning::{PruningPredicate, PruningStatistics}; +use datafusion::physical_plan::filter_pushdown::{ + ChildPushdownResult, FilterPushdownPhase, FilterPushdownPropagation, PushedDown, }; use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion::physical_plan::streaming::PartitionStream; -use datafusion::physical_plan::{ExecutionPlan, SendableRecordBatchStream}; +use datafusion::physical_plan::{ + displayable, DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, + SendableRecordBatchStream, +}; +use datafusion::prelude::{col, lit, DataFrame, SessionContext}; use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec; use datafusion_ffi::table_provider::FFI_TableProvider; +use datafusion_physical_expr_common::physical_expr::snapshot_physical_expr; +use futures::StreamExt; use pyo3::prelude::*; -use pyo3::types::{PyCapsule, PyList}; +use pyo3::types::{PyCapsule, PyList, PyTuple}; +use tokio::runtime::Runtime; // ============================================================================ // Partition Metadata Types for Filter Pushdown @@ -136,6 +154,14 @@ pub struct DimensionRange { pub struct PartitionMetadata { /// Dimension ranges for this partition, keyed by column name pub ranges: HashMap, + /// Exact number of rows in this partition (product of the chunk's + /// per-dimension sizes). `None` when the producer did not supply it. + /// Used to report exact `Statistics::num_rows` to the optimizer so + /// cost-based rules (join build-side selection, broadcast vs. shuffle) + /// have real cardinalities instead of guesses. xarray knows this + /// exactly — it is the product of the partition's dimension lengths — + /// so unlike most table providers these statistics are not estimates. + pub num_rows: Option, } impl PartitionMetadata { @@ -507,7 +533,10 @@ fn convert_python_metadata_from_bound(meta_obj: &Bound<'_, PyAny>) -> PyResult

= included_indices + .iter() + .map(|&idx| &self.partitions[idx].1) + .collect(); + let partition_rows: Vec> = included_metas + .iter() + .map(|meta| match meta.num_rows { + Some(n) => Precision::Exact(n), + None => Precision::Absent, + }) + .collect(); + // Owned copies of the surviving partitions' coordinate ranges, so the + // scan node can re-prune at execution time against a dynamic filter. + let owned_metas: Vec = + included_metas.iter().map(|&m| m.clone()).collect(); + // Handle empty case — all partitions pruned, return empty plan if included_indices.is_empty() { let empty_table = StreamingTable::try_new(Arc::clone(&self.schema), vec![])?; - return empty_table.scan(state, projection, filters, limit).await; + let inner = empty_table.scan(state, projection, filters, limit).await?; + let stats = build_scan_statistics(inner.schema().as_ref(), &included_metas); + return Ok(Arc::new(XarrayScanExec::new( + inner, + stats, + partition_rows, + owned_metas, + ))); } // Determine whether to push projection down to the Python factory. @@ -622,7 +677,14 @@ impl TableProvider for PrunableStreamingTable { // StreamingTable already has the projected schema — pass None for // projection so it doesn't wrap the stream in a redundant ProjectionExec. let streaming = StreamingTable::try_new(projected_schema, projected_partitions)?; - streaming.scan(state, None, filters, limit).await + let inner = streaming.scan(state, None, filters, limit).await?; + let stats = build_scan_statistics(inner.schema().as_ref(), &included_metas); + Ok(Arc::new(XarrayScanExec::new( + inner, + stats, + partition_rows, + owned_metas, + ))) } else { // No projection pushdown — factory is called with None (loads all // columns). StreamingTable applies projection via ProjectionExec. @@ -631,11 +693,414 @@ impl TableProvider for PrunableStreamingTable { .map(|&idx| self.partitions[idx].0.clone_as_stream()) .collect(); let streaming = StreamingTable::try_new(Arc::clone(&self.schema), included_partitions)?; - streaming.scan(state, projection, filters, limit).await + let inner = streaming.scan(state, projection, filters, limit).await?; + let stats = build_scan_statistics(inner.schema().as_ref(), &included_metas); + Ok(Arc::new(XarrayScanExec::new( + inner, + stats, + partition_rows, + owned_metas, + ))) + } + } +} + +// ============================================================================ +// Exact Statistics + Scan Wrapper +// ============================================================================ + +/// Sum a set of optional per-partition row counts into a `Precision`. +/// +/// Exact only when *every* partition reports a count; if any is missing we +/// return `Absent` rather than an under-count, so the optimizer never sees a +/// cardinality smaller than reality. +fn sum_row_counts<'a>(metas: impl Iterator) -> Precision { + let mut total: usize = 0; + for meta in metas { + match meta.num_rows { + Some(n) => total += n, + None => return Precision::Absent, + } + } + Precision::Exact(total) +} + +/// Fold two same-variant `ScalarBound`s, keeping the smaller (`keep_min`) or +/// larger one. Returns `None` if the variants differ (never expected within a +/// single dimension) so the caller can fall back to unknown. +fn fold_bound(a: &ScalarBound, b: &ScalarBound, keep_min: bool) -> Option { + let ord = match (a, b) { + (ScalarBound::Int64(x), ScalarBound::Int64(y)) => x.partial_cmp(y), + (ScalarBound::Float64(x), ScalarBound::Float64(y)) => x.partial_cmp(y), + (ScalarBound::TimestampNanos(x), ScalarBound::TimestampNanos(y)) => x.partial_cmp(y), + _ => return None, + }?; + let take_a = if keep_min { + ord != std::cmp::Ordering::Greater + } else { + ord != std::cmp::Ordering::Less + }; + Some(if take_a { a.clone() } else { b.clone() }) +} + +/// Convert a coordinate bound into a `ScalarValue` matching a column's Arrow +/// type, so column statistics line up with the schema. Returns `None` for +/// type combinations we don't convert exactly (e.g. timestamp unit scaling), +/// in which case the column is left without min/max rather than risk a wrong +/// value. +fn bound_to_scalar(bound: &ScalarBound, dtype: &DataType) -> Option { + match (bound, dtype) { + (ScalarBound::Int64(v), DataType::Int64) => Some(ScalarValue::Int64(Some(*v))), + (ScalarBound::Int64(v), DataType::Int32) => { + i32::try_from(*v).ok().map(|x| ScalarValue::Int32(Some(x))) + } + (ScalarBound::Float64(v), DataType::Float64) => Some(ScalarValue::Float64(Some(*v))), + (ScalarBound::Float64(v), DataType::Float32) => Some(ScalarValue::Float32(Some(*v as f32))), + // Timestamp columns are intentionally left without min/max for now: + // bounds are stored in nanoseconds but the column may be a different + // unit, and an unscaled value would be wrong. num_rows already covers + // the cost model; exact timestamp bounds can come with the dynamic + // filter work that actually consumes them. + _ => None, + } +} + +/// Build table-level `Statistics` for a scan over the given partitions. +/// +/// `num_rows` is exact (the summed product of chunk dimension sizes). For +/// numeric dimension columns we also surface exact min/max, folded across the +/// included partitions — these are the join/filter key columns, and unlike +/// most providers the bounds are exact coordinate values, not estimates. +fn build_scan_statistics(output_schema: &Schema, metas: &[&PartitionMetadata]) -> Statistics { + let mut stats = Statistics::new_unknown(output_schema); + stats.num_rows = sum_row_counts(metas.iter().copied()); + + for (col_idx, field) in output_schema.fields().iter().enumerate() { + // Fold this column's min/max across every partition that has a range + // for it. All partitions share the same bound variant per dimension. + let mut folded: Option<(ScalarBound, ScalarBound)> = None; + for meta in metas { + if let Some(range) = meta.ranges.get(field.name()) { + folded = Some(match folded { + None => (range.min.clone(), range.max.clone()), + Some((lo, hi)) => ( + fold_bound(&lo, &range.min, true).unwrap_or(lo), + fold_bound(&hi, &range.max, false).unwrap_or(hi), + ), + }); + } + } + + if let Some((lo, hi)) = folded { + let dtype = field.data_type(); + if let (Some(min), Some(max)) = + (bound_to_scalar(&lo, dtype), bound_to_scalar(&hi, dtype)) + { + stats.column_statistics[col_idx] = ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(max), + min_value: Precision::Exact(min), + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + byte_size: Precision::Absent, + }; + } + } + } + + stats +} + +/// A thin scan operator that wraps an inner `StreamingTableExec` and reports +/// exact `Statistics` to the query optimizer. +/// +/// Execution, schema, ordering, and partitioning are delegated verbatim to the +/// inner plan (so projection mechanics are reused unchanged); the only thing +/// this node adds is real cardinality. When consumed natively (not across the +/// FFI boundary, which drops statistics entirely), this is what lets +/// DataFusion's cost-based `JoinSelection` rule pick a sensible build side and +/// broadcast-vs-shuffle strategy. +struct XarrayScanExec { + inner: Arc, + /// Output schema (== `inner.schema()`), cached for pruning predicates. + schema: SchemaRef, + statistics: Statistics, + /// Exact row count per output partition (parallel to `inner` partitions), + /// so `partition_statistics(Some(i))` is exact too. + partition_rows: Vec>, + /// Coordinate-range metadata per output partition (parallel to `inner` + /// partitions), used to skip partitions a dynamic filter can't match. + metas: Vec, + /// Dynamic filters accepted from joins/TopK during the post-optimization + /// filter-pushdown phase. Empty until a parent pushes one in. + dynamic_filters: Vec>, +} + +impl XarrayScanExec { + fn new( + inner: Arc, + statistics: Statistics, + partition_rows: Vec>, + metas: Vec, + ) -> Self { + let schema = inner.schema(); + Self { + inner, + schema, + statistics, + partition_rows, + metas, + dynamic_filters: Vec::new(), + } + } + + /// Clone this node carrying an additional set of dynamic filters. + fn with_dynamic_filters(&self, filters: Vec>) -> Self { + Self { + inner: Arc::clone(&self.inner), + schema: Arc::clone(&self.schema), + statistics: self.statistics.clone(), + partition_rows: self.partition_rows.clone(), + metas: self.metas.clone(), + dynamic_filters: filters, + } + } +} + +impl Debug for XarrayScanExec { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("XarrayScanExec") + .field("num_partitions", &self.metas.len()) + .field("num_dynamic_filters", &self.dynamic_filters.len()) + .finish() + } +} + +impl DisplayAs for XarrayScanExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result { + let n = self.dynamic_filters.len(); + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "XarrayScanExec: rows={:?}", self.statistics.num_rows)?; + if n > 0 { + write!(f, ", dynamic_filters={n}")?; + } + Ok(()) + } + DisplayFormatType::TreeRender => { + write!(f, "rows={:?}", self.statistics.num_rows) + } } } } +/// `PruningStatistics` view over a single partition's coordinate bounds, so a +/// dynamic filter snapshot can be evaluated against it to decide skipping. +struct SinglePartitionStats<'a> { + meta: &'a PartitionMetadata, + schema: &'a SchemaRef, +} + +impl SinglePartitionStats<'_> { + /// One-element array of this partition's min (or max) for `column`, typed to + /// the column, or `None` if the column has no usable numeric bound. + fn bound_array(&self, column: &Column, want_min: bool) -> Option { + let field = self.schema.field_with_name(&column.name).ok()?; + let dtype = field.data_type(); + let range = self.meta.ranges.get(&column.name)?; + let bound = if want_min { &range.min } else { &range.max }; + let scalar = bound_to_scalar(bound, dtype)?; + ScalarValue::iter_to_array(std::iter::once(scalar)).ok() + } +} + +impl PruningStatistics for SinglePartitionStats<'_> { + fn min_values(&self, column: &Column) -> Option { + self.bound_array(column, true) + } + fn max_values(&self, column: &Column) -> Option { + self.bound_array(column, false) + } + fn num_containers(&self) -> usize { + 1 + } + fn null_counts(&self, _column: &Column) -> Option { + None + } + fn row_counts(&self, _column: &Column) -> Option { + None + } + fn contained( + &self, + _column: &Column, + _values: &std::collections::HashSet, + ) -> Option { + None + } +} + +/// Returns true if every row of `meta`'s partition is provably excluded by at +/// least one of `filters` (evaluated at their current/snapshot value). +/// +/// Conservative: any uncertainty (a predicate `PruningPredicate` can't model, a +/// column without bounds, an error) keeps the partition. Correctness does not +/// depend on this — it only skips work a join/TopK would discard anyway — so a +/// missed prune costs time, never results. +fn partition_pruned( + filters: &[Arc], + meta: &PartitionMetadata, + schema: &SchemaRef, +) -> bool { + let stats = SinglePartitionStats { meta, schema }; + for filter in filters { + // Snapshot resolves any DynamicFilterPhysicalExpr to its current bounds. + let Ok(snapshot) = snapshot_physical_expr(Arc::clone(filter)) else { + continue; + }; + let Ok(predicate) = PruningPredicate::try_new(snapshot, Arc::clone(schema)) else { + continue; + }; + if let Ok(keep) = predicate.prune(&stats) { + // One container in, one bool out: false means "cannot match". + if keep.first() == Some(&false) { + return true; + } + } + } + false +} + +#[async_trait] +impl ExecutionPlan for XarrayScanExec { + fn name(&self) -> &str { + "XarrayScanExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &PlanProperties { + // Delegate partitioning + output ordering + boundedness to the inner + // StreamingTableExec. This is also how declared coordinate ordering + // (when present) reaches the optimizer. + self.inner.properties() + } + + fn children(&self) -> Vec<&Arc> { + // A scan is a leaf; the inner plan is an execution detail, not a child + // the optimizer should rewrite. + vec![] + } + + fn with_new_children( + self: Arc, + _children: Vec>, + ) -> DFResult> { + Ok(self) + } + + fn execute( + &self, + partition: usize, + ctx: Arc, + ) -> DFResult { + // Fast path: no dynamic filters, just stream the inner partition. + if self.dynamic_filters.is_empty() { + return self.inner.execute(partition, ctx); + } + + // Otherwise defer the prune decision to first poll. By the time the + // hash join probes (and this stream is polled) the build side has run + // and updated the dynamic filter's bounds, so the snapshot is final. + // Skipping here means the partition's Python factory is never called — + // no remote read for a partition that cannot match. + let inner = Arc::clone(&self.inner); + let schema = Arc::clone(&self.schema); + let filters = self.dynamic_filters.clone(); + let meta = self.metas[partition].clone(); + + let stream = try_stream! { + if partition_pruned(&filters, &meta, &schema) { + return; + } + let mut s = inner.execute(partition, ctx)?; + while let Some(batch) = s.next().await { + yield batch?; + } + }; + + Ok(Box::pin(RecordBatchStreamAdapter::new( + Arc::clone(&self.schema), + stream, + ))) + } + + fn statistics(&self) -> DFResult { + Ok(self.statistics.clone()) + } + + fn partition_statistics(&self, partition: Option) -> DFResult { + match partition { + None => Ok(self.statistics.clone()), + Some(i) => { + let mut s = self.statistics.clone(); + s.num_rows = self + .partition_rows + .get(i) + .cloned() + .unwrap_or(Precision::Absent); + Ok(s) + } + } + } + + /// Accept dynamic filters (join/TopK) pushed in during the post phase. + /// + /// We mark them `Yes` so the producing join activates its bounds + /// accumulator and keeps updating the filter at runtime. This is safe even + /// though we only prune at partition granularity: a hash join still matches + /// every surviving row, so the filter is a pure optimization here. Static + /// (`Pre`-phase) filters are left to the existing logical pushdown. + fn handle_child_pushdown_result( + &self, + phase: FilterPushdownPhase, + child_pushdown_result: ChildPushdownResult, + _config: &ConfigOptions, + ) -> DFResult>> { + if phase != FilterPushdownPhase::Post || child_pushdown_result.parent_filters.is_empty() { + return Ok(FilterPushdownPropagation::if_all(child_pushdown_result)); + } + + // Re-wrap each filter through `with_new_children`. For a + // `DynamicFilterPhysicalExpr` this clones the shared `inner` Arc, + // registering this scan as a *consumer* of the filter — which is what + // makes the producing hash join treat it as "used" (`is_used()`) and + // therefore actually compute and publish the build-side bounds at + // runtime. Without this the captured filter stays at its initial + // `true` value and never prunes. + let parent_filters: Vec> = child_pushdown_result + .parent_filters + .iter() + .map(|f| { + let original = Arc::clone(&f.filter); + let children: Vec> = + original.children().into_iter().cloned().collect(); + Arc::clone(&original) + .with_new_children(children) + .unwrap_or(original) + }) + .collect(); + + let supports = vec![PushedDown::Yes; parent_filters.len()]; + let updated = self.with_dynamic_filters(parent_filters); + Ok(FilterPushdownPropagation { + filters: supports, + updated_node: Some(Arc::new(updated)), + }) + } +} + /// A partition stream that wraps a Python factory function that creates streams. /// /// The factory is called lazily on each `execute()` invocation, allowing @@ -921,12 +1386,24 @@ impl LazyArrowStreamTable { let mut partition_list: Vec<(Arc, PartitionMetadata)> = Vec::new(); for item_result in partitions.try_iter()? { let item = item_result?; - let (factory_obj, meta_obj): (Py, Py) = item.extract().map_err(|e| { - pyo3::exceptions::PyTypeError::new_err(format!( - "each partition must be a (factory, metadata_dict) tuple: {e}" - )) - })?; - let meta = convert_python_metadata_from_bound(meta_obj.bind(partitions.py()))?; + // Accept either ``(factory, metadata)`` (legacy) or + // ``(factory, metadata, num_rows)`` (preferred — carries the exact + // partition row count for statistics). Try the 3-tuple first. + let (factory_obj, meta_obj, num_rows): (Py, Py, Option) = + match item.extract::<(Py, Py, usize)>() { + Ok((f, m, n)) => (f, m, Some(n)), + Err(_) => { + let (f, m): (Py, Py) = item.extract().map_err(|e| { + pyo3::exceptions::PyTypeError::new_err(format!( + "each partition must be a (factory, metadata_dict) or \ + (factory, metadata_dict, num_rows) tuple: {e}" + )) + })?; + (f, m, None) + } + }; + let mut meta = convert_python_metadata_from_bound(meta_obj.bind(partitions.py()))?; + meta.num_rows = num_rows; let partition: Arc = Arc::new(PyArrowStreamPartition::new(factory_obj, schema_ref.clone())); partition_list.push((partition, meta)); @@ -969,7 +1446,6 @@ impl LazyArrowStreamTable { /// Get the schema of the table as a PyArrow Schema. fn schema(&self, py: Python<'_>) -> PyResult> { - use arrow::pyarrow::ToPyArrow; self.table .schema() .to_pyarrow(py) @@ -981,9 +1457,405 @@ impl LazyArrowStreamTable { } } +// ============================================================================ +// Native Execution Context +// ============================================================================ + +/// Convert a DataFusion error into a Python exception. +fn df_err_to_py(e: DataFusionError) -> PyErr { + pyo3::exceptions::PyRuntimeError::new_err(format!("DataFusion error: {e}")) +} + +/// A DataFusion `SessionContext` that owns its tables *natively* (in-process, +/// same compiled DataFusion), rather than across the FFI boundary. +/// +/// This matters because `datafusion-ffi` does not forward `Statistics` or +/// dynamic-filter pushdown across the boundary: a `ForeignExecutionPlan` +/// reports unknown statistics regardless of what the provider knows. Consuming +/// `PrunableStreamingTable` here, as a native `Arc`, lets the +/// optimizer see the exact cardinalities `XarrayScanExec` reports, accept +/// join-driven dynamic filters, and (later) custom physical rules. +/// +/// Queries are returned as a lazy [`NativeDataFrame`]: nothing executes until +/// the result is streamed, and it is streamed in batches rather than +/// materialised, so a reduction over a petabyte-scale store never holds the +/// whole input (or output) in memory at once. +#[pyclass(name = "NativeContext")] +struct NativeContext { + ctx: SessionContext, + rt: Arc, +} + +#[pymethods] +impl NativeContext { + #[new] + fn new() -> PyResult { + let rt = Runtime::new().map_err(|e| { + pyo3::exceptions::PyRuntimeError::new_err(format!( + "failed to create tokio runtime: {e}" + )) + })?; + Ok(Self { + ctx: SessionContext::new(), + rt: Arc::new(rt), + }) + } + + /// Register a `LazyArrowStreamTable` natively under `name`. + /// + /// Unlike the FFI path (`register_table` with the capsule), this hands the + /// provider directly to the in-process `SessionContext`, so its statistics + /// and physical-plan capabilities are visible to the optimizer. + /// + /// `name` may be schema-qualified (e.g. `"era5.surface"`), so a dataset + /// whose variables span several dimension groups registers as a SQL + /// namespace just like the FFI path. The schema is created on demand. + fn register_table(&self, name: &str, table: &LazyArrowStreamTable) -> PyResult<()> { + let provider: Arc = table.table.clone(); + let reference = TableReference::from(name); + + // For a qualified name, make sure the target schema exists first — + // DataFusion won't auto-create it. + if let Some(schema_name) = reference.schema() { + if let Some(catalog) = self.ctx.catalog("datafusion") { + if catalog.schema(schema_name).is_none() { + catalog + .register_schema(schema_name, Arc::new(MemorySchemaProvider::new())) + .map_err(df_err_to_py)?; + } + } + } + + self.ctx + .register_table(reference, provider) + .map_err(df_err_to_py)?; + Ok(()) + } + + /// Register a Python callable as a native scalar UDF. + /// + /// The callable receives one PyArrow `Array` per argument and returns a + /// PyArrow `Array`. This is how the `cftime()` filter helper — a Python + /// function backed by the `cftime` library — becomes available inside the + /// native engine, which cannot consume a `datafusion-python` UDF across the + /// FFI boundary. `input_types` and `return_type` are PyArrow `DataType`s. + fn register_scalar_udf( + &self, + name: &str, + func: Py, + input_types: Vec>, + return_type: Bound<'_, PyAny>, + ) -> PyResult<()> { + let in_types = input_types + .iter() + .map(DataType::from_pyarrow_bound) + .collect::>>()?; + let ret = DataType::from_pyarrow_bound(&return_type)?; + let udf = PyScalarUdf::new(name.to_string(), in_types, ret, func); + self.ctx.register_udf(ScalarUDF::new_from_impl(udf)); + Ok(()) + } + + /// Plan a SQL query and return a *lazy* `NativeDataFrame`. + /// + /// Planning runs now (so errors surface immediately) but no data is read + /// until the frame is streamed. + fn sql(&self, py: Python<'_>, query: &str) -> PyResult { + let df = py + .detach(|| self.rt.block_on(async { self.ctx.sql(query).await })) + .map_err(df_err_to_py)?; + Ok(NativeDataFrame { + df, + rt: Arc::clone(&self.rt), + }) + } + + /// Return the physical plan for `query` as a string, with statistics shown. + /// + /// Internal: used by tests to confirm that exact cardinalities reach the + /// optimizer (the statistics line is absent on the FFI path). + fn explain(&self, py: Python<'_>, query: &str) -> PyResult { + py.detach(|| { + self.rt.block_on(async { + let df = self.ctx.sql(query).await?; + let plan = df.create_physical_plan().await?; + let rendered = displayable(plan.as_ref()) + .set_show_statistics(true) + .indent(true) + .to_string(); + Ok::<_, DataFusionError>(rendered) + }) + }) + .map_err(df_err_to_py) + } +} + +/// A lazy handle to a planned query on a [`NativeContext`]. +/// +/// Mirrors the slice of the `datafusion-python` `DataFrame` API that the xarray +/// round-trip needs — `schema()`, `execute_stream()`, plus column projection +/// and coordinate filtering for the chunked (lazy) reconstruction path — but +/// every consumer is streaming, so it scales to stores larger than memory. +#[pyclass(name = "NativeDataFrame")] +struct NativeDataFrame { + df: DataFrame, + rt: Arc, +} + +#[pymethods] +impl NativeDataFrame { + /// The PyArrow schema of the result (no execution). + fn schema(&self, py: Python<'_>) -> PyResult> { + self.df + .schema() + .inner() + .as_ref() + .to_pyarrow(py) + .map(|b| b.unbind()) + } + + /// Begin streaming the result. Returns an iterator of PyArrow RecordBatches. + /// + /// The GIL is released while each batch is produced, so DataFusion's worker + /// threads can re-acquire it to pull from the Python-backed partition + /// streams. Batches are yielded one at a time — the full result is never + /// collected. + fn execute_stream(&self, py: Python<'_>) -> PyResult { + let df = self.df.clone(); + let stream = py + .detach(|| self.rt.block_on(async { df.execute_stream().await })) + .map_err(df_err_to_py)?; + Ok(NativeRecordBatchStream { + stream: Some(stream), + rt: Arc::clone(&self.rt), + }) + } + + /// Project to a subset of columns by name (lazy). + fn select_columns(&self, columns: Vec) -> PyResult { + let exprs: Vec = columns.iter().map(|c| col(c.as_str())).collect(); + let df = self.df.clone().select(exprs).map_err(df_err_to_py)?; + Ok(NativeDataFrame { + df, + rt: Arc::clone(&self.rt), + }) + } + + /// The distinct values of `column`, ascending (lazy). + /// + /// Used to discover a dimension's coordinate values for the chunked + /// round-trip; the scan projects to this single column and skips data + /// variables, so discovery reads coordinates only. + fn distinct_sorted(&self, column: String) -> PyResult { + let c = col(column.as_str()); + let df = self + .df + .clone() + .select(vec![c.clone()]) + .map_err(df_err_to_py)? + .distinct() + .map_err(df_err_to_py)? + .sort(vec![c.sort(true, false)]) + .map_err(df_err_to_py)?; + Ok(NativeDataFrame { + df, + rt: Arc::clone(&self.rt), + }) + } + + /// Keep only rows whose `column` is one of `values` (lazy). + /// + /// Used by the chunked round-trip to read a single output chunk: the + /// coordinate predicate pushes into the scan and prunes partitions, so each + /// chunk reads only the partitions it overlaps. `dtype_tag` matches the + /// tags used for partition bounds (`int64`, `float64`, `timestamp_ns`). + fn filter_in( + &self, + column: String, + values: &Bound<'_, PyAny>, + dtype_tag: &str, + ) -> PyResult { + let scalars = python_values_to_scalars(values, dtype_tag)?; + let list: Vec = scalars.into_iter().map(lit).collect(); + if list.is_empty() { + return Ok(NativeDataFrame { + df: self.df.clone(), + rt: Arc::clone(&self.rt), + }); + } + let predicate = col(column.as_str()).in_list(list, false); + let df = self.df.clone().filter(predicate).map_err(df_err_to_py)?; + Ok(NativeDataFrame { + df, + rt: Arc::clone(&self.rt), + }) + } +} + +/// Convert a Python sequence of coordinate values into typed `ScalarValue`s. +fn python_values_to_scalars( + values: &Bound<'_, PyAny>, + dtype_tag: &str, +) -> PyResult> { + let mut out = Vec::new(); + for item in values.try_iter()? { + let item = item?; + let scalar = match dtype_tag { + "timestamp_ns" => ScalarValue::TimestampNanosecond(Some(item.extract::()?), None), + "float64" => ScalarValue::Float64(Some(item.extract::()?)), + "int64" => ScalarValue::Int64(Some(item.extract::()?)), + _ => { + return Err(pyo3::exceptions::PyTypeError::new_err(format!( + "Unsupported dtype tag for filter value: {dtype_tag}" + ))) + } + }; + out.push(scalar); + } + Ok(out) +} + +/// A synchronous Python iterator over a DataFusion record-batch stream. +/// +/// `unsendable` because a `SendableRecordBatchStream` is `Send` but not `Sync` +/// and is only ever advanced from the single Python thread that owns it. +#[pyclass(name = "NativeRecordBatchStream", unsendable)] +struct NativeRecordBatchStream { + stream: Option, + rt: Arc, +} + +#[pymethods] +impl NativeRecordBatchStream { + fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { + slf + } + + /// Pull the next batch (or signal end of iteration). + fn __next__(&mut self, py: Python<'_>) -> PyResult>> { + let Some(stream) = self.stream.as_mut() else { + return Ok(None); + }; + let rt = Arc::clone(&self.rt); + let next = py.detach(|| rt.block_on(async { stream.next().await })); + match next { + Some(Ok(batch)) => Ok(Some(batch.to_pyarrow(py)?.unbind())), + Some(Err(e)) => { + self.stream = None; + Err(df_err_to_py(e)) + } + None => { + self.stream = None; + Ok(None) + } + } + } +} + +/// A native DataFusion scalar UDF backed by a Python callable. +/// +/// Each invocation converts the argument arrays to PyArrow, calls the Python +/// function (acquiring the GIL), and converts the PyArrow result back to Arrow. +/// Used to expose the `cftime()` filter helper to the native engine. +#[derive(Debug)] +struct PyScalarUdf { + name: String, + signature: Signature, + return_type: DataType, + /// `Arc` so the closure moved into `Python::attach` shares the callable + /// without acquiring the GIL just to clone it. + func: Arc>, +} + +impl PyScalarUdf { + fn new( + name: String, + input_types: Vec, + return_type: DataType, + func: Py, + ) -> Self { + Self { + name, + signature: Signature::exact(input_types, Volatility::Immutable), + return_type, + func: Arc::new(func), + } + } +} + +// `ScalarUDFImpl` requires `Eq`/`Hash` (via `DynEq`/`DynHash`), which we can't +// derive through the `Py` callable. Identity is its signature — name, argument +// types, and return type — which is all the optimizer needs to compare UDFs. +impl PartialEq for PyScalarUdf { + fn eq(&self, other: &Self) -> bool { + self.name == other.name + && self.signature == other.signature + && self.return_type == other.return_type + } +} + +impl Eq for PyScalarUdf {} + +impl std::hash::Hash for PyScalarUdf { + fn hash(&self, state: &mut H) { + self.name.hash(state); + self.return_type.hash(state); + } +} + +impl ScalarUDFImpl for PyScalarUdf { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DFResult { + Ok(self.return_type.clone()) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult { + let num_rows = args.number_rows; + // Materialise every argument (scalars broadcast to `num_rows`) so the + // Python function always sees full arrays. + let arrays: Vec = args + .args + .into_iter() + .map(|cv| cv.into_array(num_rows)) + .collect::>()?; + + let func = Arc::clone(&self.func); + let result = Python::attach(|py| -> DFResult { + let py_args = arrays + .iter() + .map(|a| a.to_data().to_pyarrow(py)) + .collect::>>() + .and_then(|args| PyTuple::new(py, args)) + .map_err(|e| DataFusionError::Execution(format!("UDF arg conversion: {e}")))?; + let out = func.call1(py, py_args).map_err(|e| { + DataFusionError::Execution(format!("UDF '{}' failed: {e}", self.name)) + })?; + let data = ArrayData::from_pyarrow_bound(out.bind(py)) + .map_err(|e| DataFusionError::Execution(format!("UDF result conversion: {e}")))?; + Ok(make_array(data)) + })?; + + Ok(ColumnarValue::Array(result)) + } +} + /// Python module initialization #[pymodule] fn _native(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; Ok(()) } diff --git a/tests/test_native.py b/tests/test_native.py new file mode 100644 index 0000000..27e6085 --- /dev/null +++ b/tests/test_native.py @@ -0,0 +1,329 @@ +"""Tests for the native (non-FFI) execution engine. + +The native engine registers tables into an in-process DataFusion +``SessionContext`` compiled into the extension, bypassing the FFI boundary +that drops table statistics. These tests check (a) result parity with the +default FFI engine and (b) that exact cardinalities now reach the optimizer, +which is the whole point of the native path. +""" + +import numpy as np +import pytest +import xarray as xr + +from xarray_sql import XarrayContext + + +@pytest.fixture +def grid_ds(): + """A small, fully in-memory gridded dataset (no network).""" + nt, nlat, nlon = 8, 6, 5 + rng = np.random.default_rng(0) + return xr.Dataset( + {"air": (("time", "lat", "lon"), rng.standard_normal((nt, nlat, nlon)))}, + coords={ + "time": np.arange(nt), + "lat": np.linspace(10, 50, nlat), + "lon": np.linspace(0, 30, nlon), + }, + ) + + +def test_engine_validation(): + with pytest.raises(ValueError, match="engine must be"): + XarrayContext(engine="bogus") + + +def test_native_select_parity(grid_ds): + ffi = XarrayContext() + ffi.from_dataset("air", grid_ds, chunks={"time": 3}) + nat = XarrayContext(engine="native") + nat.from_dataset("air", grid_ds, chunks={"time": 3}) + + q = 'SELECT lat, lon, "air" FROM air' + a = ffi.sql(q).to_pandas().sort_values(["lat", "lon", "air"]).reset_index(drop=True) + b = nat.sql(q).to_pandas().sort_values(["lat", "lon", "air"]).reset_index(drop=True) + assert a.shape == b.shape + np.testing.assert_allclose(a["air"].to_numpy(), b["air"].to_numpy()) + + +def test_native_groupby_parity(grid_ds): + nat = XarrayContext(engine="native") + nat.from_dataset("air", grid_ds, chunks={"time": 3}) + + got = ( + nat.sql("SELECT lat, lon, AVG(air) AS a FROM air GROUP BY lat, lon") + .to_pandas() + .sort_values(["lat", "lon"]) + .reset_index(drop=True) + ) + ref = ( + grid_ds["air"] + .mean("time") + .to_dataframe() + .reset_index() + .rename(columns={"air": "a"}) + .sort_values(["lat", "lon"]) + .reset_index(drop=True) + ) + assert len(got) == grid_ds.sizes["lat"] * grid_ds.sizes["lon"] + np.testing.assert_allclose(got["a"].to_numpy(), ref["a"].to_numpy()) + + +def test_native_to_dataset_roundtrip(grid_ds): + nat = XarrayContext(engine="native") + nat.from_dataset("air", grid_ds, chunks={"time": 3}) + + ds = nat.sql( + "SELECT lat, lon, AVG(air) AS air FROM air GROUP BY lat, lon" + ).to_dataset(dims=["lat", "lon"]) + assert isinstance(ds, xr.Dataset) + assert set(ds.dims) == {"lat", "lon"} + ref = grid_ds["air"].mean("time") + # GROUP BY returns rows in hash order, so sort both by coordinate label + # before comparing raw values. + got = ds["air"].sortby(["lat", "lon"]).transpose("lat", "lon") + ref = ref.sortby(["lat", "lon"]).transpose("lat", "lon") + np.testing.assert_allclose(got.to_numpy(), ref.to_numpy()) + + +def test_native_lazy_chunked_roundtrip(grid_ds): + """to_dataset(chunks=...) is lazy: dask-backed arrays, correct on compute.""" + dask = pytest.importorskip("dask") # noqa: F841 + nat = XarrayContext(engine="native") + nat.from_dataset("air", grid_ds, chunks={"time": 3}) + + out = nat.sql("SELECT time, lat, lon, air FROM air").to_dataset( + dims=["time", "lat", "lon"], chunks={"time": 3} + ) + # The result variable is lazy (a chunked dask array), not a dense ndarray. + assert out["air"].chunks is not None + assert type(out["air"].data).__module__.startswith("dask") + + # Computing a slice reads lazily and matches the reference. + got = out["air"].sel(time=slice(3, 5)).compute().transpose("time", "lat", "lon") + ref = grid_ds["air"].sel(time=slice(3, 5)).transpose("time", "lat", "lon") + np.testing.assert_allclose(got.to_numpy(), ref.to_numpy()) + + # Full materialisation also matches. + full = out["air"].compute().transpose("time", "lat", "lon") + np.testing.assert_allclose( + full.to_numpy(), grid_ds["air"].transpose("time", "lat", "lon").to_numpy() + ) + + +def test_native_sql_returns_lazy_frame(): + """NativeContext.sql plans lazily and streams — it does not collect.""" + from xarray_sql._native import NativeContext + from xarray_sql.reader import read_xarray_table + + ds = xr.Dataset( + {"v": (("x",), np.arange(10, dtype="float64"))}, + coords={"x": np.arange(10)}, + ) + nc = NativeContext() + nc.register_table("t", read_xarray_table(ds, chunks={"x": 5})) + frame = nc.sql("SELECT x, v FROM t") + # A lazy frame exposes schema without executing, and streams batches. + assert [f.name for f in frame.schema()] == ["x", "v"] + total = sum(b.num_rows for b in frame.execute_stream()) + assert total == 10 + + +def test_native_statistics_in_plan(grid_ds): + """Exact row counts must appear at the scan and propagate upward.""" + nat = XarrayContext(engine="native") + nat.from_dataset("air", grid_ds, chunks={"time": 3}) + plan = nat._explain_native("SELECT lat, lon, AVG(air) FROM air GROUP BY lat, lon") + total = grid_ds.sizes["time"] * grid_ds.sizes["lat"] * grid_ds.sizes["lon"] + assert f"rows=Exact({total})" in plan + # The exact count is not dropped to Absent above the scan. + assert "Rows=Absent" not in plan.splitlines()[-1] + + +def test_native_column_minmax_in_plan(grid_ds): + """Numeric dimension columns get exact min/max coordinate bounds.""" + nat = XarrayContext(engine="native") + nat.from_dataset("air", grid_ds, chunks={"time": 3}) + plan = nat._explain_native("SELECT lat, lon, air FROM air") + scan = next(l for l in plan.splitlines() if "XarrayScanExec" in l) + + def fmt(v: float) -> str: + # DataFusion's ScalarValue display drops a trailing ".0". + return str(int(v)) if v == int(v) else str(v) + + lat_min, lat_max = float(grid_ds.lat.min()), float(grid_ds.lat.max()) + lon_min, lon_max = float(grid_ds.lon.min()), float(grid_ds.lon.max()) + assert f"Min=Exact(Float64({fmt(lat_min)}))" in scan + assert f"Max=Exact(Float64({fmt(lat_max)}))" in scan + assert f"Min=Exact(Float64({fmt(lon_min)}))" in scan + assert f"Max=Exact(Float64({fmt(lon_max)}))" in scan + + +def test_native_join_picks_small_build_side(): + """With exact statistics the optimizer broadcasts the smaller table. + + A big (time x lat x lon) table joined to a small (lat x lon) weight table + should plan as a CollectLeft hash join with the small table on the build + side. Without statistics (the FFI path) the optimizer cannot know which + side is smaller and falls back to a Partitioned join. + """ + rng = np.random.default_rng(0) + big = xr.Dataset( + {"t": (("time", "lat", "lon"), rng.standard_normal((200, 8, 8)))}, + coords={"time": np.arange(200), "lat": np.arange(8), "lon": np.arange(8)}, + ) + small = xr.Dataset( + {"w": (("lat", "lon"), rng.standard_normal((8, 8)))}, + coords={"lat": np.arange(8), "lon": np.arange(8)}, + ) + nat = XarrayContext(engine="native") + nat.from_dataset("big", big, chunks={"time": 50}) + nat.from_dataset("small", small, chunks={"lat": 8}) + + plan = nat._explain_native( + "SELECT b.time, SUM(b.t * s.w) AS x FROM big b " + "JOIN small s ON b.lat=s.lat AND b.lon=s.lon GROUP BY b.time" + ) + assert "HashJoinExec: mode=CollectLeft" in plan + + +def _join_datasets(): + """Probe table chunked on time; build table covers a narrow time window.""" + nt = 300 + rng = np.random.default_rng(0) + big = xr.Dataset( + {"t": (("time", "lat", "lon"), rng.standard_normal((nt, 4, 4)))}, + coords={"time": np.arange(nt), "lat": np.arange(4), "lon": np.arange(4)}, + ) + sel = xr.Dataset( + {"w": (("time",), np.ones(10))}, coords={"time": np.arange(10, 20)} + ) + return big, sel + + +_JOIN_SQL = ( + "SELECT b.time, SUM(b.t) x FROM big b JOIN sel s ON b.time=s.time " + "GROUP BY b.time" +) + + +def test_native_join_dynamic_filter_in_plan(): + """The probe-side scan accepts the join's dynamic filter.""" + big, sel = _join_datasets() + nat = XarrayContext(engine="native") + nat.from_dataset("big", big, chunks={"time": 30}) + nat.from_dataset("sel", sel, chunks={"time": 10}) + plan = nat._explain_native(_JOIN_SQL) + # The big (probe) scan carries one dynamic filter from the join. + assert "dynamic_filters=1" in plan + + +def test_native_join_dynamic_filter_correctness(): + """Dynamic-filter pruning must not change results.""" + big, sel = _join_datasets() + nat = XarrayContext(engine="native") + nat.from_dataset("big", big, chunks={"time": 30}) + nat.from_dataset("sel", sel, chunks={"time": 10}) + + got = nat.sql(_JOIN_SQL).to_pandas().sort_values("time").reset_index(drop=True) + ref = ( + big.sel(time=slice(10, 19))["t"] + .sum(["lat", "lon"]) + .to_dataframe() + .reset_index() + .rename(columns={"t": "x"}) + .sort_values("time") + .reset_index(drop=True) + ) + assert len(got) == 10 + np.testing.assert_allclose(got["x"].to_numpy(), ref["x"].to_numpy()) + + +def test_native_join_dynamic_filter_prunes_partitions(): + """At runtime the probe scan skips partitions the join can't match. + + The big table is chunked into ten time-partitions; the build side only + covers times 10..19, which overlaps a single partition. A partition that is + skipped never has its Python factory invoked (no read at all). + + Dynamic-filter pushdown is *opportunistic* -- whether a given partition is + pruned depends on the runtime race between the build side publishing its + bounds and the probe side being polled -- so we don't assert an exact count. + We require that (a) the one overlapping partition is always read and + (b) pruning fires: at least one run reads strictly fewer than all ten. + """ + from xarray_sql.reader import read_xarray_table + from xarray_sql._native import NativeContext + + big, sel = _join_datasets() + + best = 10 + for _ in range(5): + materialized = [] + nc = NativeContext() + nc.register_table( + "big", + read_xarray_table( + big, + chunks={"time": 30}, + _iteration_callback=lambda block, proj: materialized.append( + (block["time"].start, block["time"].stop) + ), + ), + ) + nc.register_table("sel", read_xarray_table(sel, chunks={"time": 10})) + list(nc.sql(_JOIN_SQL).execute_stream()) # force execution + + slices = [(int(a), int(b)) for a, b in materialized] + # The overlapping partition must always be read for a correct result. + assert (0, 30) in slices + best = min(best, len(slices)) + + # Across the attempts, pruning must fire at least once. + assert best < 10, f"dynamic filter never pruned a partition (best={best})" + + +def test_native_multigroup_namespace(weather_dataset): + """A dataset spanning two dim groups registers as a SQL namespace.""" + ds = weather_dataset + two_group = ds.assign( + surface_pressure=( + ("lat", "lon"), + ds["temperature"].isel(time=0, level=0).data, + ) + ) + nat = XarrayContext(engine="native") + nat.from_dataset("wx", two_group, chunks={"time": 3}) + + # Both dim-group tables are queryable under the "wx" schema. + got = nat.sql( + 'SELECT AVG("surface_pressure") AS p FROM wx.lat_lon' + ).to_pandas() + ref = float(two_group["surface_pressure"].mean()) + np.testing.assert_allclose(got["p"].to_numpy()[0], ref, rtol=1e-6) + + n = nat.sql("SELECT COUNT(*) AS n FROM wx.lat_lon_time_level").to_pandas() + assert int(n["n"].to_numpy()[0]) == two_group["temperature"].size + + +def test_native_cftime_udf(rasm_ds): + """The cftime() filter UDF works on the native engine (non-Gregorian cal).""" + pytest.importorskip("cftime") + # rasm uses a noleap (Gregorian-like) calendar; build a 360_day dataset so + # the int64/UDF path is exercised. + import cftime + + times = [cftime.Datetime360Day(2000, m, 1) for m in range(1, 13)] + ds = xr.Dataset( + {"v": (("time",), np.arange(12, dtype="float64"))}, + coords={"time": times}, + ) + nat = XarrayContext(engine="native") + nat.from_dataset("cal", ds, chunks={"time": 6}) + + # cftime('2000-07-01') resolves to the July offset; expect months 7..12. + got = nat.sql( + "SELECT v FROM cal WHERE time >= cftime('2000-07-01') ORDER BY v" + ).to_pandas() + np.testing.assert_allclose(got["v"].to_numpy(), np.arange(6, 12)) diff --git a/xarray_sql/cftime.py b/xarray_sql/cftime.py index 6ce7b4e..d5a8aa1 100644 --- a/xarray_sql/cftime.py +++ b/xarray_sql/cftime.py @@ -214,19 +214,19 @@ def arrow_field(name: str, units: str, cal: str) -> pa.Field: # --------------------------------------------------------------------------- -def make_cftime_udf(units: str, calendar: str): - """Create a DataFusion scalar UDF that converts date strings to int64 offsets. - - This enables ergonomic SQL filtering on non-Gregorian cftime columns:: +def make_cftime_callable(units: str, calendar: str): + """Build the raw ``pa.Array -> pa.Array`` cftime conversion function. - SELECT * FROM ds360 WHERE time > cftime('0500-01-01') + Returned alongside its Arrow signature so it can be wrapped as either a + ``datafusion-python`` UDF (FFI engine) or a native scalar UDF (native + engine). The function parses each input string as a cftime datetime in the + given calendar and returns the int64 offset in the specified units. - The UDF parses the input string as a cftime datetime in the given - calendar system and returns the corresponding int64 offset in the - specified units. + Returns: + ``(func, name, input_types, return_type)`` where ``func`` maps a PyArrow + utf8 ``Array`` to a PyArrow int64 ``Array``. """ import cftime as _cftime - from datafusion import udf def _cftime_scalar(date_strings: pa.Array) -> pa.Array: results: list[int | None] = [] @@ -239,10 +239,17 @@ def _cftime_scalar(date_strings: pa.Array) -> pa.Array: results.append(int(val)) return pa.array(results, type=pa.int64()) - return udf( - _cftime_scalar, - [pa.utf8()], - pa.int64(), - "immutable", - "cftime", - ) + return _cftime_scalar, "cftime", [pa.utf8()], pa.int64() + + +def make_cftime_udf(units: str, calendar: str): + """Create a ``datafusion-python`` scalar UDF converting date strings to offsets. + + This enables ergonomic SQL filtering on non-Gregorian cftime columns:: + + SELECT * FROM ds360 WHERE time > cftime('0500-01-01') + """ + from datafusion import udf + + func, name, input_types, return_type = make_cftime_callable(units, calendar) + return udf(func, input_types, return_type, "immutable", name) diff --git a/xarray_sql/ds.py b/xarray_sql/ds.py index 5dfdf42..047f6e7 100644 --- a/xarray_sql/ds.py +++ b/xarray_sql/ds.py @@ -338,28 +338,38 @@ def _raw_getitem(self, key: tuple) -> np.ndarray: # OR-chain of equalities (DataFusion 52.0.0 does not expose a clean # ``Expr.in_list`` from Python; OR-chained equalities constant-fold # equivalently and stay typed). - predicates = [] - for dim in self._dimension_columns: - if dim in full_dims: - continue - vals = requested[dim] - if len(vals) == 1: - predicates.append(col(f'"{dim}"') == literal(vals[0])) - else: - eq = col(f'"{dim}"') == literal(vals[0]) - for v in vals[1:]: - eq = eq | (col(f'"{dim}"') == literal(v)) - predicates.append(eq) - - filtered = self._inner_df - if predicates: - combined = predicates[0] - for p in predicates[1:]: - combined = combined & p - filtered = filtered.filter(combined) - projected = filtered.select( - *(col(f'"{c}"') for c in self._dimension_columns + [self._var_name]) - ) + wanted = self._dimension_columns + [self._var_name] + if hasattr(self._inner_df, "filter_coord"): + # Native engine: express the per-dim predicate structurally. Each + # coordinate filter pushes into the scan and prunes source + # partitions, so a chunk reads only the partitions it overlaps. + frame = self._inner_df + for dim in self._dimension_columns: + if dim in full_dims: + continue + frame = frame.filter_coord(dim, requested[dim]) + projected = frame.select_columns(wanted) + else: + predicates = [] + for dim in self._dimension_columns: + if dim in full_dims: + continue + vals = requested[dim] + if len(vals) == 1: + predicates.append(col(f'"{dim}"') == literal(vals[0])) + else: + eq = col(f'"{dim}"') == literal(vals[0]) + for v in vals[1:]: + eq = eq | (col(f'"{dim}"') == literal(v)) + predicates.append(eq) + + filtered = self._inner_df + if predicates: + combined = predicates[0] + for p in predicates[1:]: + combined = combined & p + filtered = filtered.filter(combined) + projected = filtered.select(*(col(f'"{c}"') for c in wanted)) # Consume the projected DataFrame as Arrow RecordBatches. The # DataFusion wrapper exposes ``.to_pyarrow()`` to convert each @@ -528,7 +538,11 @@ def _build_lazy_scan( ) if coord_arrays is None: coord_arrays = {} + native = hasattr(inner_df, "distinct_sorted_values") for d in dimension_columns: + if native: + coord_arrays[d] = inner_df.distinct_sorted_values(d) + continue dim_only = ( inner_df.select(col(f'"{d}"')) .distinct() diff --git a/xarray_sql/native.py b/xarray_sql/native.py new file mode 100644 index 0000000..1817e9b --- /dev/null +++ b/xarray_sql/native.py @@ -0,0 +1,104 @@ +"""Lazy, streaming adapter over the native (non-FFI) engine's DataFrame. + +:class:`NativeFrame` wraps the Rust ``NativeDataFrame`` and exposes just the +slice of the ``datafusion-python`` ``DataFrame`` interface the xarray +round-trip consumes — ``schema()``, ``execute_stream()``, ``to_pandas()`` — plus +structured column projection and coordinate filtering for the chunked +reconstruction path. Every consumer streams: nothing is collected up front, so a +reduction (or a chunked scan) over a store larger than memory never holds the +whole input or output at once. +""" + +from __future__ import annotations + +from typing import Any + +import numpy as np +import pyarrow as pa + + +def coord_dtype_tag(values: np.ndarray) -> str | None: + """Map a coordinate array's numpy dtype to a native filter dtype tag. + + Returns ``None`` for dtypes the native filter can't represent (e.g. + strings), so the caller can skip pushing a predicate on that dimension. + """ + kind = values.dtype.kind + if kind == "M": # datetime64 + return "timestamp_ns" + if kind == "f": + return "float64" + if kind in ("i", "u"): + return "int64" + return None + + +def _coord_values_for_filter(values: np.ndarray, dtype_tag: str) -> list: + """Convert coordinate values to the Python scalars the native filter wants.""" + if dtype_tag == "timestamp_ns": + as_ns = values.astype("datetime64[ns]").astype(np.int64) + return [int(v) for v in as_ns] + if dtype_tag == "float64": + return [float(v) for v in values] + return [int(v) for v in values] + + +class _Batch: + """Adapt a PyArrow RecordBatch to datafusion-python's stream-item API.""" + + __slots__ = ("_batch",) + + def __init__(self, batch: pa.RecordBatch) -> None: + self._batch = batch + + def to_pyarrow(self) -> pa.RecordBatch: + return self._batch + + +class NativeFrame: + """Streaming, lazy stand-in for a ``datafusion-python`` DataFrame.""" + + def __init__(self, native_df: Any) -> None: + self._df = native_df + + # -- interface the round-trip consumes ----------------------------------- + + def schema(self) -> pa.Schema: + return self._df.schema() + + def execute_stream(self): + """Yield result batches lazily, wrapped so ``.to_pyarrow()`` works.""" + return (_Batch(b) for b in self._df.execute_stream()) + + def to_pandas(self): + batches = list(self._df.execute_stream()) + if batches: + return pa.Table.from_batches(batches).to_pandas() + return pa.Table.from_batches([], schema=self._df.schema()).to_pandas() + + # -- chunked round-trip helpers ------------------------------------------ + + def select_columns(self, columns: list[str]) -> "NativeFrame": + return NativeFrame(self._df.select_columns(list(columns))) + + def filter_coord(self, column: str, values: np.ndarray) -> "NativeFrame": + """Keep rows whose ``column`` is one of ``values`` (pushes into the scan). + + Coordinate ranges pushed here prune source partitions, so a single + output chunk reads only the partitions it overlaps. + """ + tag = coord_dtype_tag(np.asarray(values)) + if tag is None: + return self + native_values = _coord_values_for_filter(np.asarray(values), tag) + return NativeFrame(self._df.filter_in(column, native_values, tag)) + + def distinct_sorted_values(self, column: str) -> np.ndarray: + """Ascending distinct values of ``column`` (coordinate discovery).""" + frame = self._df.distinct_sorted(column) + batches = list(frame.execute_stream()) + if not batches: + return np.asarray([]) + return np.concatenate( + [b.column(0).to_numpy(zero_copy_only=False) for b in batches] + ) diff --git a/xarray_sql/reader.py b/xarray_sql/reader.py index f8c5975..9b394f6 100644 --- a/xarray_sql/reader.py +++ b/xarray_sql/reader.py @@ -21,6 +21,7 @@ Block, Chunks, DEFAULT_BATCH_SIZE, + _block_len, _block_metadata, _block_slices_from_resolved, _parse_schema, @@ -327,6 +328,10 @@ def partition_pairs(): yield ( make_partition_factory(block), {**static_ranges, **dynamic}, + # Exact row count for this partition (product of the chunk's + # per-dimension sizes), so the native engine can report exact + # Statistics::num_rows to the optimizer. + _block_len(block), ) return LazyArrowStreamTable(partition_pairs(), schema) diff --git a/xarray_sql/sql.py b/xarray_sql/sql.py index 0ec60ad..474d9d9 100644 --- a/xarray_sql/sql.py +++ b/xarray_sql/sql.py @@ -6,13 +6,34 @@ from . import cftime as cft from .df import Chunks from .ds import XarrayDataFrame +from .native import NativeFrame from .reader import read_xarray_table class XarrayContext(SessionContext): - """A datafusion `SessionContext` that also supports `xarray.Dataset`s.""" + """A datafusion `SessionContext` that also supports `xarray.Dataset`s. - def __init__(self, *args, **kwargs): + Two query engines are available, selected with ``engine``: + + * ``"ffi"`` (default): tables are registered with DataFusion through the + Arrow/FFI table-provider protocol. This is the mature path and supports + everything DataFusion's Python API does (scalar UDFs such as ``cftime``, + multi-dimension-group namespaces, lazy/chunked round-trips). + + * ``"native"``: tables are registered into an in-process DataFusion + ``SessionContext`` compiled into this extension, *bypassing the FFI + boundary*. The boundary drops table statistics and dynamic-filter + pushdown, so only the native engine can give the cost-based optimizer + exact cardinalities (for join build-side selection and broadcast-vs- + shuffle). Results stream lazily and round-trip through the same + ``to_pandas`` / ``to_dataset`` (including the chunked, partition-pruning + path) as the default engine, so a reduction or chunked scan over a store + larger than memory never materialises the whole result. Multi-dimension- + group datasets (SQL namespaces) and the ``cftime()`` filter UDF are + supported here too. + """ + + def __init__(self, *args, engine: str = "ffi", **kwargs): super().__init__(*args, **kwargs) # Track registered xarray Datasets so XarrayDataFrame can recover # defaults (dimension_columns) and metadata (var/dataset attrs, @@ -22,6 +43,17 @@ def __init__(self, *args, **kwargs): # ``"era5.surface"`` for one entry from a multi-dim-group split). self._registered_datasets: dict[str, xr.Dataset] = {} + if engine not in ("ffi", "native"): + raise ValueError( + f"engine must be 'ffi' or 'native', got {engine!r}" + ) + self._engine = engine + self._native = None + if engine == "native": + from ._native import NativeContext + + self._native = NativeContext() + def from_dataset( self, name: str, @@ -103,6 +135,21 @@ def from_dataset( ) table_names = table_names or {} + + if self._native is not None: + # Native engine: register each dim-group under a schema-qualified + # name (``name.sub_name``). The Rust side creates the SQL schema on + # demand, so ``SELECT ... FROM era5.surface`` works just like FFI. + for dims, var_names in groups.items(): + sub_name = table_names.get(dims, "_".join(dims) or "scalar") + sub_ds = input_table[var_names] + qualified = f"{name}.{sub_name}" + self._from_dataset( + qualified, sub_ds, chunks, coord_arrays=coord_arrays + ) + self._registered_datasets[qualified] = sub_ds + return self + schema = Schema.memory_schema(self) self.catalog().register_schema(name, schema) @@ -137,23 +184,44 @@ def _from_dataset( Registers a top-level table by default, or a table inside ``schema`` (a SQL namespace) when one is given. """ - register = ( - self.register_table if schema is None else schema.register_table - ) - register( - table_name, - read_xarray_table(input_table, chunks, coord_arrays=coord_arrays), + table = read_xarray_table( + input_table, chunks, coord_arrays=coord_arrays ) + if self._native is not None: + # Register the provider directly into the in-process context so + # its statistics and physical-plan capabilities reach the + # optimizer (the FFI path drops them). + self._native.register_table(table_name, table) + else: + register = ( + self.register_table + if schema is None + else schema.register_table + ) + register(table_name, table) self._maybe_register_cftime_udf(input_table) return self def _maybe_register_cftime_udf(self, ds: xr.Dataset) -> None: - """Auto-register a cftime() UDF for non-Gregorian cftime coordinates.""" + """Auto-register a cftime() UDF for non-Gregorian cftime coordinates. + + On the native engine the UDF is a Python callable registered directly + into the in-process context (a ``datafusion-python`` UDF cannot cross + the FFI boundary); on the FFI engine it is a ``datafusion-python`` UDF. + """ for coord_name in ds.dims: if cft.is_cftime_index(ds, coord_name): units, cal = cft.encoding(ds, coord_name) if not cft.is_gregorian_like(cal): - self.register_udf(cft.make_cftime_udf(units, cal)) + if self._native is not None: + func, name, in_types, ret = cft.make_cftime_callable( + units, cal + ) + self._native.register_scalar_udf( + name, func, in_types, ret + ) + else: + self.register_udf(cft.make_cftime_udf(units, cal)) break # One UDF per context is enough. def sql(self, query: str, *args, **kwargs) -> XarrayDataFrame: @@ -174,9 +242,39 @@ def sql(self, query: str, *args, **kwargs) -> XarrayDataFrame: Returns: An :class:`XarrayDataFrame` wrapping the DataFusion DataFrame. """ + if self._native is not None: + return self._native_sql(query) inner = super().sql(query, *args, **kwargs) return XarrayDataFrame(inner, templates=self._registered_datasets) + def _native_sql(self, query: str) -> XarrayDataFrame: + """Plan *query* on the native engine and wrap the lazy result. + + The native context returns a lazy, streaming ``NativeDataFrame`` (no + data is read until the result is streamed). :class:`NativeFrame` adapts + it to the small DataFrame interface the :class:`XarrayDataFrame` + round-trip consumes, so ``to_pandas`` and ``to_dataset`` — including the + chunked, partition-pruning path — work without materialising the result. + """ + native_df = self._native.sql(query) + return XarrayDataFrame( + NativeFrame(native_df), templates=self._registered_datasets + ) + + def _explain_native(self, query: str) -> str: + """Return the native engine's physical plan (with statistics) for *query*. + + Internal helper (not public API): used by tests to confirm that exact + cardinalities reach the optimizer — e.g. that a join is planned as + ``HashJoinExec: mode=CollectLeft`` with the smaller table on the build + side, or that a probe scan carries a join dynamic filter. + """ + if self._native is None: + raise RuntimeError( + "explain_native is only available when engine='native'" + ) + return self._native.explain(query) + def _group_vars_by_dims(ds: xr.Dataset) -> dict[tuple[str, ...], list[str]]: """Group variables in the dataset based on shared dims.