diff --git a/Cargo.lock b/Cargo.lock index 8b58e75..b6cbcf1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2656,9 +2656,9 @@ checksum = "1bc711410fbe7399f390ca1c3b60ad0f53f80e95c5eb935e52268a0e2cd49acc" [[package]] name = "serde" -version = "1.0.227" +version = "1.0.228" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "80ece43fc6fbed4eb5392ab50c07334d3e577cbf40997ee896fe7af40bba4245" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" dependencies = [ "serde_core", "serde_derive", @@ -2666,18 +2666,18 @@ dependencies = [ [[package]] name = "serde_core" -version = "1.0.227" +version = "1.0.228" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a576275b607a2c86ea29e410193df32bc680303c82f31e275bbfcafe8b33be5" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.227" +version = "1.0.228" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "51e694923b8824cf0e9b382adf0f60d4e05f348f357b38833a3fa5ed7c2ede04" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" dependencies = [ "proc-macro2", "quote", @@ -2851,18 +2851,18 @@ dependencies = [ [[package]] name = "thiserror" -version = "2.0.16" +version = "2.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3467d614147380f2e4e374161426ff399c91084acd2363eaf549172b3d5e60c0" +checksum = "4288b5bcbc7920c07a1149a35cf9590a2aa808e0bc1eafaade0b80947865fbc4" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "2.0.16" +version = "2.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c5e1be1c48b9172ee610da68fd9cd2770e7a4056cb3fc98710ee6906f0c7960" +checksum = "ebc4ee7f67670e9b64d05fa4253e753e016c6c95ff35b89b7941d6b856dec1d5" dependencies = [ "proc-macro2", "quote", @@ -3375,7 +3375,7 @@ checksum = 
"ea2f10b9bb0928dfb1b42b65e1f9e36f7f54dbdf08457afefb38afcdec4fa2bb" [[package]] name = "xarray_sql" -version = "0.2.3" +version = "0.3.0" dependencies = [ "arrow", "async-stream", @@ -3385,6 +3385,7 @@ dependencies = [ "futures", "pyo3", "pyo3-build-config", + "sqlparser", "tokio", ] diff --git a/Cargo.toml b/Cargo.toml index 1dc95bd..21f5b43 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,6 +23,7 @@ async-stream = "0.3" async-trait = "0.1" datafusion = { version = "52.0.0" } datafusion-ffi = { version = "52.0.0" } +sqlparser = { version = "0.59", features = ["visitor"] } futures = { version = "0.3" } pyo3 = { version = "0.26.0", features = ["extension-module"] } tokio = { version = "1.46.1", features = ["rt"] } diff --git a/src/autograd.rs b/src/autograd.rs new file mode 100644 index 0000000..c729ca3 --- /dev/null +++ b/src/autograd.rs @@ -0,0 +1,838 @@ +//! Symbolic differentiation of DataFusion logical [`Expr`] trees. +//! +//! This is the autograd engine for xarray-sql. Given an [`Expr`] and the name +//! of a column to differentiate with respect to, [`differentiate`] returns a +//! new [`Expr`] for the (symbolic) partial derivative, built entirely from +//! ordinary DataFusion expressions so the result can be planned and evaluated +//! by DataFusion like any other SQL expression. +//! +//! ## Design +//! +//! The approach mirrors JAX's per-primitive rule registry (`defjvp` and +//! friends in `jax/_src/interpreters/ad.py`): every expression node has a +//! differentiation rule, and the chain rule composes them as the tree is +//! walked. Because each row of a relational table is an independent evaluation +//! point, differentiating a column expression and letting DataFusion evaluate +//! it row-by-row is the moral equivalent of `jax.vmap(jax.grad(f))` — the rows +//! *are* the batch dimension. +//! +//! A small simplifier folds the `0`/`1` constants that differentiation +//! produces in abundance (e.g. `d/dx (c) = 0`, `d/dx (x) = 1`), keeping output +//! 
expressions compact. This plays the role of JAX's `Zero` tangents and +//! `add_tangents`: a `0` derivative short-circuits products and drops out of +//! sums, and a `1` factor drops out of products. +//! +//! ## Surface +//! +//! Three scalar operations, all rewritten away before execution: +//! +//! * `grad(expr, column)` — the partial derivative `d(expr)/d(column)`. +//! * `jvp(expr, column, tangent)` — forward-mode directional derivative, +//! `d(expr)/d(column) * tangent` (seed a tangent on an input). +//! * `vjp(expr, column, cotangent)` — reverse-mode pullback, +//! `cotangent * d(expr)/d(column)` (seed a cotangent on the output). +//! +//! All three return a scalar per row, staying in the long/tidy data model. A +//! full gradient or Jacobian is expressed as several scalar columns (e.g. +//! `grad(f, x) AS dfdx, grad(f, y) AS dfdy`) rather than a nested array, which +//! would break the one-value-per-coordinate model. +//! +//! Calls nest, giving higher-order derivatives for free: the rewrite walks +//! bottom-up, so the inner call in `grad(grad(f, x), x)` is differentiated +//! first and the outer call differentiates that result. +//! +//! Differentiation through an aggregate is just linearity and needs no special +//! handling: write the `grad` *inside* the aggregate, e.g. `SUM(grad(f, x))` or +//! `AVG(grad(loss, theta))`. Because the marker is rewritten to plain SQL +//! before the aggregate runs (and the column is in scope there), this is the +//! relational `d/dθ Σ f = Σ ∂f/∂θ` — enough to run gradient descent in SQL. +//! (The transposed form `grad(SUM(f), x)` is rejected by SQL's own scoping, +//! since `x` is gone after aggregation.) 
+ +#![allow(dead_code)] + +use std::any::Any; +use std::collections::HashMap; +use std::f64::consts::{LN_10, LN_2}; +use std::ops::ControlFlow; +use std::sync::Arc; + +use datafusion::arrow::datatypes::{DataType, Field}; +use datafusion::common::tree_node::{Transformed, TreeNode}; +use datafusion::common::{DFSchema, DataFusionError, Result, ScalarValue, TableReference}; +use datafusion::functions::math::expr_fn; +use datafusion::logical_expr::expr::ScalarFunction; +use datafusion::logical_expr::{ + lit, BinaryExpr, Cast, ColumnarValue, Expr, LogicalPlan, Operator, ScalarFunctionArgs, + ScalarUDF, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion::prelude::SessionContext; +use datafusion::sql::unparser::expr_to_sql; +use sqlparser::ast::{Expr as SqlExpr, Visit, VisitMut, Visitor, VisitorMut}; +use sqlparser::dialect::GenericDialect; +use sqlparser::parser::Parser; + +// --------------------------------------------------------------------------- +// Constant helpers and the 0/1-folding builders +// --------------------------------------------------------------------------- + +/// The constant `0.0`, used as the derivative of anything not depending on the +/// differentiation variable. +fn zero() -> Expr { + lit(0.0_f64) +} + +/// The constant `1.0`, used as the derivative of the differentiation variable. +fn one() -> Expr { + lit(1.0_f64) +} + +/// Interpret a [`ScalarValue`] as `f64` if it is a (non-null) numeric scalar. 
+fn scalar_as_f64(sv: &ScalarValue) -> Option { + match sv { + ScalarValue::Float64(Some(v)) => Some(*v), + ScalarValue::Float32(Some(v)) => Some(*v as f64), + ScalarValue::Int64(Some(v)) => Some(*v as f64), + ScalarValue::Int32(Some(v)) => Some(*v as f64), + ScalarValue::Int16(Some(v)) => Some(*v as f64), + ScalarValue::Int8(Some(v)) => Some(*v as f64), + ScalarValue::UInt64(Some(v)) => Some(*v as f64), + ScalarValue::UInt32(Some(v)) => Some(*v as f64), + ScalarValue::UInt16(Some(v)) => Some(*v as f64), + ScalarValue::UInt8(Some(v)) => Some(*v as f64), + _ => None, + } +} + +/// Return the constant `f64` value of a literal expression, if it is one. +fn as_const(e: &Expr) -> Option { + match e { + Expr::Literal(sv, _) => scalar_as_f64(sv), + _ => None, + } +} + +/// True if the expression is a numeric literal exactly equal to zero. +fn is_zero(e: &Expr) -> bool { + matches!(as_const(e), Some(v) if v == 0.0) +} + +/// True if the expression is a numeric literal exactly equal to one. +fn is_one(e: &Expr) -> bool { + matches!(as_const(e), Some(v) if v == 1.0) +} + +fn binary(left: Expr, op: Operator, right: Expr) -> Expr { + Expr::BinaryExpr(BinaryExpr::new(Box::new(left), op, Box::new(right))) +} + +/// `a + b`, dropping a zero operand. +fn add(a: Expr, b: Expr) -> Expr { + if is_zero(&a) { + b + } else if is_zero(&b) { + a + } else { + binary(a, Operator::Plus, b) + } +} + +/// `a - b`, dropping a zero right operand and turning `0 - b` into `-b`. +fn sub(a: Expr, b: Expr) -> Expr { + if is_zero(&b) { + a + } else if is_zero(&a) { + neg(b) + } else { + binary(a, Operator::Minus, b) + } +} + +/// `a * b`, folding `0 * _ = 0` and `1 * b = b` (and the mirror cases). +fn mul(a: Expr, b: Expr) -> Expr { + if is_zero(&a) || is_zero(&b) { + zero() + } else if is_one(&a) { + b + } else if is_one(&b) { + a + } else { + binary(a, Operator::Multiply, b) + } +} + +/// `a / b`, folding `0 / _ = 0` and `a / 1 = a`. 
+fn div(a: Expr, b: Expr) -> Expr { + if is_zero(&a) { + zero() + } else if is_one(&b) { + a + } else { + binary(a, Operator::Divide, b) + } +} + +/// `-a`, folding `-0 = 0`. +fn neg(a: Expr) -> Expr { + if is_zero(&a) { + zero() + } else { + Expr::Negative(Box::new(a)) + } +} + +/// `e * e`. +fn square(e: Expr) -> Expr { + mul(e.clone(), e) +} + +// --------------------------------------------------------------------------- +// The differentiation engine (forward-mode linearization) +// --------------------------------------------------------------------------- + +/// A *leaf rule*: the tangent of a column, i.e. the seed assigned to each input +/// during forward-mode differentiation. +/// +/// `grad` uses a one-hot leaf (`1` for the differentiation variable, `0` +/// otherwise); `jvp` uses an arbitrary seed per input. Everything above the +/// leaves — the chain rule — is shared. +type Leaf<'a> = dyn Fn(&str) -> Expr + 'a; + +/// Linearize `expr`: push tangents from the leaves (per `leaf`) up through the +/// expression via the chain rule, returning the tangent of `expr`. +/// +/// This is forward-mode automatic differentiation. `differentiate` (a single +/// partial derivative) and `jvp` (a directional derivative) are both thin +/// wrappers that only differ in their leaf rule. Returns a +/// [`DataFusionError::NotImplemented`] for nodes or functions without a rule, +/// so callers surface a clear error rather than a silently-wrong derivative. +fn linearize(expr: &Expr, leaf: &Leaf) -> Result { + match expr { + // The leaf rule decides a column's tangent. + Expr::Column(c) => Ok(leaf(&c.name)), + + // Constants have zero tangent. + Expr::Literal(_, _) => Ok(zero()), + + // An alias is transparent; the surrounding query re-applies any naming. + Expr::Alias(a) => linearize(&a.expr, leaf), + + // A numeric cast is (locally) linear: tangent of cast(u) = cast(du). 
+ Expr::Cast(c) => { + let du = linearize(&c.expr, leaf)?; + Ok(Expr::Cast(Cast::new(Box::new(du), c.data_type.clone()))) + } + + // tangent of -u = -(du). + Expr::Negative(inner) => Ok(neg(linearize(inner, leaf)?)), + + Expr::BinaryExpr(be) => linearize_binary(be, leaf), + + Expr::ScalarFunction(sf) => linearize_scalar_function(sf, leaf), + + other => Err(DataFusionError::NotImplemented(format!( + "grad: differentiation is not implemented for this expression: {other}" + ))), + } +} + +/// Differentiate `expr` with respect to the column named `wrt`. +/// +/// Forward-mode with a one-hot seed: `1` on `wrt`, `0` on every other column. +pub fn differentiate(expr: &Expr, wrt: &str) -> Result { + linearize(expr, &|name| if name == wrt { one() } else { zero() }) +} + +/// Forward-mode directional derivative: the tangent of `expr` given a tangent +/// (`seeds[col]`) for each seeded input column; unseeded columns are constant. +fn jvp(expr: &Expr, seeds: &HashMap) -> Result { + linearize(expr, &|name| seeds.get(name).cloned().unwrap_or_else(zero)) +} + +/// Linearize a binary arithmetic expression via the sum/product/quotient rules. +fn linearize_binary(be: &BinaryExpr, leaf: &Leaf) -> Result { + let a = be.left.as_ref(); + let b = be.right.as_ref(); + let da = linearize(a, leaf)?; + let db = linearize(b, leaf)?; + + match be.op { + // tangent of (a + b) = da + db + Operator::Plus => Ok(add(da, db)), + // tangent of (a - b) = da - db + Operator::Minus => Ok(sub(da, db)), + // tangent of (a * b) = da*b + a*db (product rule) + Operator::Multiply => Ok(add(mul(da, b.clone()), mul(a.clone(), db))), + // tangent of (a / b) = (da*b - a*db) / b^2 (quotient rule) + Operator::Divide => { + let numerator = sub(mul(da, b.clone()), mul(a.clone(), db)); + Ok(div(numerator, square(b.clone()))) + } + op => Err(DataFusionError::NotImplemented(format!( + "grad: operator '{op}' is not differentiable" + ))), + } +} + +/// Linearize a scalar-function call via the chain rule. 
+/// +/// For a unary primitive `f(u)`, the tangent is `f'(u) * du`. For `power`, +/// which is binary, we handle the constant-exponent and constant-base cases. +fn linearize_scalar_function(sf: &ScalarFunction, leaf: &Leaf) -> Result { + let name = sf.func.name(); + let args = &sf.args; + + // `power(base, exponent)` is the one binary primitive we linearize. + if name == "power" { + return linearize_power(args, leaf); + } + + if args.len() != 1 { + return Err(DataFusionError::NotImplemented(format!( + "grad: no derivative rule for function '{name}' with {} arguments", + args.len() + ))); + } + + let u = &args[0]; + let du = linearize(u, leaf)?; + // Chain rule short-circuit: if du is 0, the whole tangent is 0 and we avoid + // emitting the (dead) outer derivative term entirely. + if is_zero(&du) { + return Ok(zero()); + } + + let outer = match name { + // Trigonometric. + "sin" => expr_fn::cos(u.clone()), + "cos" => neg(expr_fn::sin(u.clone())), + "tan" => div(one(), square(expr_fn::cos(u.clone()))), + // Inverse trigonometric. + "asin" => div(one(), expr_fn::sqrt(sub(one(), square(u.clone())))), + "acos" => neg(div(one(), expr_fn::sqrt(sub(one(), square(u.clone()))))), + "atan" => div(one(), add(one(), square(u.clone()))), + // Exponential / logarithmic. + "exp" => expr_fn::exp(u.clone()), + "ln" => div(one(), u.clone()), + "log2" => div(one(), mul(u.clone(), lit(LN_2))), + "log10" => div(one(), mul(u.clone(), lit(LN_10))), + "sqrt" => div(one(), mul(lit(2.0_f64), expr_fn::sqrt(u.clone()))), + // Hyperbolic. + "sinh" => expr_fn::cosh(u.clone()), + "cosh" => expr_fn::sinh(u.clone()), + "tanh" => sub(one(), square(expr_fn::tanh(u.clone()))), + // Piecewise-linear: derivative is the sign (undefined at 0, like JAX). + "abs" => expr_fn::signum(u.clone()), + _ => { + return Err(DataFusionError::NotImplemented(format!( + "grad: no derivative rule for function '{name}'" + ))) + } + }; + + Ok(mul(outer, du)) +} + +/// Linearize `power(base, exponent)`. 
+/// +/// * Constant exponent `c`: tangent = `c * base^(c-1) * d(base)`. +/// * Constant base `a`: tangent = `a^u * ln(a) * d(u)`. +/// * Both variable (`u^v`): not supported yet. +fn linearize_power(args: &[Expr], leaf: &Leaf) -> Result { + if args.len() != 2 { + return Err(DataFusionError::NotImplemented( + "grad: power() expects exactly two arguments".to_string(), + )); + } + let base = &args[0]; + let exponent = &args[1]; + + match (as_const(base), as_const(exponent)) { + // Constant exponent (covers the common x^2, x^0.5, ... cases). + (_, Some(c)) => { + let dbase = linearize(base, leaf)?; + if is_zero(&dbase) { + return Ok(zero()); + } + let outer = mul(lit(c), expr_fn::power(base.clone(), lit(c - 1.0))); + Ok(mul(outer, dbase)) + } + // Constant base, variable exponent. + (Some(a), None) => { + let dexp = linearize(exponent, leaf)?; + if is_zero(&dexp) { + return Ok(zero()); + } + let outer = mul(expr_fn::power(base.clone(), exponent.clone()), lit(a.ln())); + Ok(mul(outer, dexp)) + } + // General u^v requires the exp/log trick; deferred for now. + (None, None) => Err(DataFusionError::NotImplemented( + "grad: power(base, exponent) where both depend on the \ + differentiation variable is not yet supported" + .to_string(), + )), + } +} + +// --------------------------------------------------------------------------- +// The `grad` / `jvp` / `vjp` marker UDFs and the plan-level rewrite +// --------------------------------------------------------------------------- + +/// A no-op placeholder UDF for the autograd surface functions. +/// +/// `grad`, `jvp`, and `vjp` are *markers*: they carry the differentiation +/// request intact through SQL parsing, logical planning, and Substrait +/// serialization. They are always rewritten away by [`rewrite_grad_calls`] +/// before execution, so `invoke_with_args` is never reached in normal use (and +/// deliberately errors if it somehow is, rather than returning a wrong value). 
+#[derive(Debug, PartialEq, Eq, Hash)] +pub struct MarkerUdf { + name: String, + signature: Signature, +} + +impl MarkerUdf { + fn new(name: &str, arity: usize) -> Self { + Self { + name: name.to_string(), + signature: Signature::any(arity, Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for MarkerUdf { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + // Every autograd marker rewrites to a scalar derivative expression. + Ok(DataType::Float64) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + Err(DataFusionError::Execution(format!( + "{}() marker reached execution without being rewritten; this is \ + an internal xarray-sql autograd error", + self.name + ))) + } +} + +/// The `grad(expr, column)` marker: scalar partial derivative `d(expr)/dcolumn`. +pub fn grad_marker() -> ScalarUDF { + ScalarUDF::from(MarkerUdf::new("grad", 2)) +} + +/// The `jvp(expr, column, tangent)` marker: forward-mode directional derivative. +pub fn jvp_marker() -> ScalarUDF { + ScalarUDF::from(MarkerUdf::new("jvp", 3)) +} + +/// The `vjp(expr, column, cotangent)` marker: reverse-mode pullback to an input. +pub fn vjp_marker() -> ScalarUDF { + ScalarUDF::from(MarkerUdf::new("vjp", 3)) +} + +/// Rewrite every `grad`/`jvp`/`vjp` call anywhere in a logical plan into its +/// symbolic derivative, leaving everything else untouched. The plan's schema is +/// recomputed afterwards because replacing a marker can change an expression's +/// name or type. +pub fn rewrite_grad_calls(plan: LogicalPlan) -> Result { + let rewritten = plan + .transform_up(|node| node.map_expressions(rewrite_grad_in_expr))? + .data; + rewritten.recompute_schema() +} + +/// Replace any `grad`/`jvp`/`vjp` calls nested anywhere inside a single +/// expression. 
+fn rewrite_grad_in_expr(expr: Expr) -> Result> { + expr.transform_up(|e| { + let Expr::ScalarFunction(sf) = &e else { + return Ok(Transformed::no(e)); + }; + match sf.func.name() { + "grad" => Ok(Transformed::yes(rewrite_grad(&sf.args)?)), + "jvp" => Ok(Transformed::yes(rewrite_jvp(&sf.args)?)), + "vjp" => Ok(Transformed::yes(rewrite_vjp(&sf.args)?)), + _ => Ok(Transformed::no(e)), + } + }) +} + +/// Read a bare column name from a marker argument, or report a clear error. +fn column_arg(func: &str, arg: &Expr) -> Result { + match arg { + Expr::Column(c) => Ok(c.name.clone()), + other => Err(DataFusionError::Plan(format!( + "{func}(): the column argument must be a bare column to \ + differentiate with respect to, got: {other}" + ))), + } +} + +/// `grad(expr, column)` -> `d(expr)/d(column)`. +fn rewrite_grad(args: &[Expr]) -> Result { + if args.len() != 2 { + return Err(DataFusionError::Plan(format!( + "grad() expects two arguments grad(expr, column), got {}", + args.len() + ))); + } + let wrt = column_arg("grad", &args[1])?; + differentiate(&args[0], &wrt) +} + +/// `jvp(expr, column, tangent)` -> forward-mode tangent: seed `tangent` on +/// `column` and push it through `expr`, yielding `d(expr)/d(column) * tangent`. +/// +/// A directional derivative over several inputs is the sum of per-input jvps, +/// e.g. `jvp(f, x, dx) + jvp(f, y, dy)`, since each treats the other inputs as +/// having zero tangent. +fn rewrite_jvp(args: &[Expr]) -> Result { + if args.len() != 3 { + return Err(DataFusionError::Plan(format!( + "jvp() expects three arguments jvp(expr, column, tangent), got {}", + args.len() + ))); + } + let wrt = column_arg("jvp", &args[1])?; + let seeds = HashMap::from([(wrt, args[2].clone())]); + jvp(&args[0], &seeds) +} + +/// `vjp(expr, column, cotangent)` -> reverse-mode pullback: the sensitivity that +/// an output cotangent induces on `column`, i.e. `cotangent * d(expr)/d(column)`. 
+/// +/// For a single scalar output this equals the matching `jvp` (both contract the +/// same partial derivative); the surfaces differ in where the seed lives — `jvp` +/// seeds an input tangent, `vjp` seeds an output cotangent. +fn rewrite_vjp(args: &[Expr]) -> Result { + if args.len() != 3 { + return Err(DataFusionError::Plan(format!( + "vjp() expects three arguments vjp(expr, column, cotangent), got {}", + args.len() + ))); + } + let wrt = column_arg("vjp", &args[1])?; + let derivative = differentiate(&args[0], &wrt)?; + Ok(mul(args[2].clone(), derivative)) +} + +// --------------------------------------------------------------------------- +// SQL source-to-source rewrite +// --------------------------------------------------------------------------- + +/// Rewrite every `grad`/`jvp`/`vjp` call in a SQL statement into its symbolic +/// derivative, returning the rewritten SQL text. +/// +/// Unlike a logical-plan rewrite, this is a pure source-to-source transform run +/// *before* the query is planned, so it works for any query shape the SQL parser +/// accepts — recursive CTEs, DML, and subqueries included. Each marker call is +/// parsed into a DataFusion [`Expr`], differentiated by the engine in this +/// module, and rendered back to SQL in place. Columns are taken from the call's +/// own identifiers (all treated as `Float64`; types don't affect the symbolic +/// result), so no catalog or table schema is needed. +pub fn rewrite_grad_in_sql(sql: &str) -> Result { + let dialect = GenericDialect {}; + let mut statements = Parser::parse_sql(&dialect, sql) + .map_err(|e| DataFusionError::Plan(format!("grad: failed to parse SQL: {e}")))?; + + // A throwaway context that only needs the marker UDFs registered so the + // calls parse into `ScalarFunction` nodes the engine can dispatch on. 
+ let ctx = SessionContext::new(); + ctx.register_udf(grad_marker()); + ctx.register_udf(jvp_marker()); + ctx.register_udf(vjp_marker()); + + let mut rewriter = GradSqlRewriter { ctx: &ctx }; + for stmt in &mut statements { + if let ControlFlow::Break(msg) = stmt.visit(&mut rewriter) { + return Err(DataFusionError::Plan(msg)); + } + } + + Ok(statements + .iter() + .map(ToString::to_string) + .collect::>() + .join("; ")) +} + +/// True if `name` is one of the autograd marker functions (case-insensitive). +fn is_marker_name(name: &str) -> bool { + matches!(name.to_lowercase().as_str(), "grad" | "jvp" | "vjp") +} + +/// Walks a SQL AST and replaces each `grad`/`jvp`/`vjp` call with its derivative. +struct GradSqlRewriter<'a> { + ctx: &'a SessionContext, +} + +impl VisitorMut for GradSqlRewriter<'_> { + type Break = String; + + fn pre_visit_expr(&mut self, expr: &mut SqlExpr) -> ControlFlow { + let is_marker = matches!( + expr, + SqlExpr::Function(f) if is_marker_name(&f.name.to_string()) + ); + if !is_marker { + return ControlFlow::Continue(()); + } + match self.rewrite_call(expr) { + Ok(()) => ControlFlow::Continue(()), + Err(e) => ControlFlow::Break(e), + } + } +} + +impl GradSqlRewriter<'_> { + /// Differentiate a single marker call in place. The replacement is wrapped + /// in parentheses so it keeps the call's precedence in the surrounding SQL. + fn rewrite_call(&self, expr: &mut SqlExpr) -> std::result::Result<(), String> { + let schema = call_schema(expr)?; + let text = expr.to_string(); + let parsed = self + .ctx + .parse_sql_expr(&text, &schema) + .map_err(|e| format!("grad: failed to parse '{text}': {e}"))?; + let derivative = rewrite_grad_in_expr(parsed) + .map_err(|e| format!("grad: failed to differentiate '{text}': {e}"))? 
+ .data; + let rendered = expr_to_sql(&derivative) + .map_err(|e| format!("grad: failed to render derivative for '{text}': {e}"))?; + *expr = SqlExpr::Nested(Box::new(rendered)); + Ok(()) + } +} + +/// Build a `Float64` schema covering every column identifier referenced inside a +/// marker call, so the call's argument expression can be parsed standalone. +fn call_schema(call: &SqlExpr) -> std::result::Result { + let mut collector = ColumnCollector::default(); + let _ = call.visit(&mut collector); + let fields = collector + .cols + .into_iter() + .map(|(qualifier, name)| { + let qualifier = qualifier.map(TableReference::bare); + ( + qualifier, + Arc::new(Field::new(name, DataType::Float64, true)), + ) + }) + .collect(); + DFSchema::new_with_metadata(fields, HashMap::new()) + .map_err(|e| format!("grad: failed to build schema for differentiation: {e}")) +} + +/// Collects the (optional qualifier, name) of every column identifier in a SQL +/// expression tree. +#[derive(Default)] +struct ColumnCollector { + cols: Vec<(Option, String)>, +} + +impl Visitor for ColumnCollector { + type Break = (); + + fn pre_visit_expr(&mut self, expr: &SqlExpr) -> ControlFlow<()> { + let pair = match expr { + SqlExpr::Identifier(ident) => Some((None, ident.value.clone())), + SqlExpr::CompoundIdentifier(parts) => parts.last().map(|last| { + let qualifier = (parts.len() >= 2).then(|| parts[parts.len() - 2].value.clone()); + (qualifier, last.value.clone()) + }), + _ => None, + }; + if let Some(pair) = pair { + if !self.cols.contains(&pair) { + self.cols.push(pair); + } + } + ControlFlow::Continue(()) + } +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use datafusion::logical_expr::col; + + use super::*; + + #[test] + fn constant_has_zero_derivative() { + assert_eq!(differentiate(&lit(3.0_f64), "x").unwrap(), zero()); + } + + #[test] 
+ fn variable_has_unit_derivative() { + assert_eq!(differentiate(&col("x"), "x").unwrap(), one()); + } + + #[test] + fn other_variable_has_zero_derivative() { + assert_eq!(differentiate(&col("y"), "x").unwrap(), zero()); + } + + #[test] + fn sum_rule_folds_constants() { + // d/dx (x + y) = 1 + 0 = 1 + let e = add(col("x"), col("y")); + assert_eq!(differentiate(&e, "x").unwrap(), one()); + } + + #[test] + fn product_rule() { + // d/dx (x * x) = 1*x + x*1 = x + x + let e = binary(col("x"), Operator::Multiply, col("x")); + let expected = add(col("x"), col("x")); + assert_eq!(differentiate(&e, "x").unwrap(), expected); + } + + #[test] + fn quotient_rule() { + // d/dx (x / y) = (1*y - x*0) / (y*y) = y / (y*y) + let e = binary(col("x"), Operator::Divide, col("y")); + let expected = div(col("y"), square(col("y"))); + assert_eq!(differentiate(&e, "x").unwrap(), expected); + } + + #[test] + fn chain_rule_sin() { + // d/dx sin(x) = cos(x) * 1 = cos(x) + let d = differentiate(&expr_fn::sin(col("x")), "x").unwrap(); + assert_eq!(d, expr_fn::cos(col("x"))); + // Readable, precedence-free rendering. + assert_eq!(d.to_string(), "cos(x)"); + } + + #[test] + fn composite_sin_times_x() { + // d/dx (sin(x) * x) = cos(x)*x + sin(x) + let e = binary(expr_fn::sin(col("x")), Operator::Multiply, col("x")); + let d = differentiate(&e, "x").unwrap(); + assert_eq!(d.to_string(), "cos(x) * x + sin(x)"); + } + + #[test] + fn power_constant_exponent() { + // d/dx power(x, 2) = 2 * power(x, 1) * 1 = 2 * power(x, 1) + let e = expr_fn::power(col("x"), lit(2.0_f64)); + let expected = mul(lit(2.0_f64), expr_fn::power(col("x"), lit(1.0_f64))); + assert_eq!(differentiate(&e, "x").unwrap(), expected); + } + + #[test] + fn unsupported_operator_errors() { + let e = binary(col("x"), Operator::Modulo, col("y")); + assert!(differentiate(&e, "x").is_err()); + } + + #[test] + fn unsupported_function_errors() { + // atan2 is binary and has no rule yet. 
+ let e = expr_fn::atan2(col("x"), col("y")); + assert!(differentiate(&e, "x").is_err()); + } + + #[test] + fn higher_order_derivative() { + // Differentiation composes: d2/dx2 sin(x) = -sin(x). + let d1 = differentiate(&expr_fn::sin(col("x")), "x").unwrap(); + let d2 = differentiate(&d1, "x").unwrap(); + assert_eq!(d2, neg(expr_fn::sin(col("x")))); + } + + #[test] + fn jvp_seeds_a_tangent_on_one_input() { + // jvp(x*y, {x: dx}) = product rule with tangent(x)=dx, tangent(y)=0 + // = dx*y + x*0 = dx*y + let f = binary(col("x"), Operator::Multiply, col("y")); + let seeds = HashMap::from([("x".to_string(), col("dx"))]); + let t = jvp(&f, &seeds).unwrap(); + assert_eq!(t, mul(col("dx"), col("y"))); + } + + #[test] + fn jvp_with_unit_seed_matches_grad() { + // A one-hot tangent reproduces the partial derivative. + let f = expr_fn::sin(col("x")); + let seeds = HashMap::from([("x".to_string(), one())]); + assert_eq!(jvp(&f, &seeds).unwrap(), differentiate(&f, "x").unwrap()); + } + + #[test] + fn vjp_equals_cotangent_times_grad() { + // rewrite_vjp(sin(x), x, w) = w * cos(x) + let f = expr_fn::sin(col("x")); + let got = rewrite_vjp(&[f.clone(), col("x"), col("w")]).unwrap(); + assert_eq!(got, mul(col("w"), expr_fn::cos(col("x")))); + } + + #[test] + fn jvp_and_vjp_agree_for_unit_seed() { + // With matching unit seed/cotangent, forward and reverse coincide. + let f = binary(expr_fn::sin(col("x")), Operator::Multiply, col("x")); + let fwd = rewrite_jvp(&[f.clone(), col("x"), one()]).unwrap(); + let rev = rewrite_vjp(&[f, col("x"), one()]).unwrap(); + assert_eq!(fwd, rev); + } + + #[test] + fn sql_rewrite_replaces_grad_call() { + // grad(sin(x), x) -> cos(x); the surrounding SELECT is preserved. 
+ let out = rewrite_grad_in_sql("SELECT grad(sin(x), x) AS d FROM t").unwrap(); + assert_eq!(out, "SELECT (cos(x)) AS d FROM t"); + } + + #[test] + fn sql_rewrite_leaves_non_grad_queries_intact() { + // A query with no marker is still parsed and re-emitted unchanged in + // meaning (the caller only invokes the rewrite when a marker is present). + let out = rewrite_grad_in_sql("SELECT a + b FROM t").unwrap(); + assert_eq!(out, "SELECT a + b FROM t"); + } + + #[test] + fn sql_rewrite_fires_inside_recursive_cte() { + // The #197 capability: a marker inside a recursive term is rewritten, + // a query shape the Substrait bridge could never carry. d/dx(x*x) = x+x. + let out = rewrite_grad_in_sql( + "WITH RECURSIVE r AS (SELECT 1.0 AS x UNION ALL \ + SELECT x - grad(x * x, x) FROM r WHERE x < 10) SELECT x FROM r", + ) + .unwrap(); + assert!(out.contains("(x + x)"), "unexpected rewrite: {out}"); + assert!( + !out.to_lowercase().contains("grad("), + "marker left behind: {out}" + ); + } + + #[test] + fn sql_rewrite_handles_nested_higher_order_grad() { + // grad(grad(power(x, 3), x), x) -> d2/dx2 (x^3) = 6x; bottom-up so the + // inner call is differentiated before the outer one. + let out = rewrite_grad_in_sql("SELECT grad(grad(power(x, 3), x), x) AS d FROM t").unwrap(); + assert!( + !out.to_lowercase().contains("grad("), + "marker left behind: {out}" + ); + } +} diff --git a/src/lib.rs b/src/lib.rs index c489609..d157d72 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -41,6 +41,8 @@ //! Will skip loading partitions whose time ranges are entirely before 2020-02-01. //! Supported operators: `=`, `<`, `>`, `<=`, `>=`, `BETWEEN`, `IN`, `AND`, `OR`. 
+mod autograd; + use std::any::Any; use std::collections::{HashMap, HashSet}; use std::ffi::CString; @@ -48,13 +50,13 @@ use std::fmt::Debug; use std::sync::Arc; use arrow::array::RecordBatch; -use arrow::datatypes::{Schema, SchemaRef}; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::pyarrow::FromPyArrow; use async_stream::try_stream; use async_trait::async_trait; use datafusion::catalog::streaming::StreamingTable; use datafusion::catalog::Session; -use datafusion::common::{DataFusionError, Result as DFResult, ScalarValue}; +use datafusion::common::{DFSchema, DataFusionError, Result as DFResult, ScalarValue}; use datafusion::datasource::TableProvider; use datafusion::execution::TaskContext; use datafusion::logical_expr::expr::InList; @@ -64,6 +66,8 @@ use datafusion::logical_expr::{ use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion::physical_plan::streaming::PartitionStream; use datafusion::physical_plan::{ExecutionPlan, SendableRecordBatchStream}; +use datafusion::prelude::SessionContext; +use datafusion::sql::unparser::expr_to_sql; use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec; use datafusion_ffi::table_provider::FFI_TableProvider; use pyo3::prelude::*; @@ -981,9 +985,91 @@ impl LazyArrowStreamTable { } } +// ============================================================================ +// Autograd: SQL-level grad() rewrite +// ============================================================================ + +/// Rewrite every `grad`/`jvp`/`vjp` call in a SQL query into its symbolic +/// derivative, returning the rewritten SQL text. +/// +/// The autograd engine operates on DataFusion logical `Expr` trees. Rather than +/// round-tripping a whole plan across the cdylib boundary, this rewrites the +/// query as **SQL text** before it is planned: each marker call is parsed, +/// differentiated, and rendered back to SQL in place. 
Because it runs before +/// planning, it works for any query shape the parser accepts — recursive CTEs, +/// DML, and subqueries — which the plan-level Substrait bridge could not carry. +/// +/// Args: +/// query: A SQL query string that may contain `grad`/`jvp`/`vjp` calls. +/// +/// Returns: +/// The rewritten SQL string, ready to pass to ``SessionContext.sql``. +#[pyfunction] +fn rewrite_grad_sql(query: &str) -> PyResult { + autograd::rewrite_grad_in_sql(query).map_err(|e| { + pyo3::exceptions::PyValueError::new_err(format!( + "rewrite_grad_sql: failed to rewrite grad() calls: {e}" + )) + }) +} + +/// Differentiate a SQL scalar expression symbolically and return the +/// derivative as SQL text. +/// +/// Where [`rewrite_grad_sql`] rewrites `grad(...)` calls inside a whole query, this +/// differentiates a single expression and hands back the result as SQL — the +/// autograd engine acting as a "calculus compiler". It lets a caller obtain an +/// update rule once and embed it in queries the Substrait round-trip can't +/// carry a `grad` marker through, such as a recursive-CTE training loop. +/// +/// Args: +/// expr: A SQL scalar expression over `columns` (e.g. `"sin(x) * x"`). +/// wrt: The column name to differentiate with respect to. +/// columns: The column names in scope; all treated as `Float64` (enough to +/// parse and differentiate — types don't affect the symbolic result). +/// +/// Returns: +/// The derivative as a SQL string (e.g. `"cos(x) * x + sin(x)"`). 
+#[pyfunction] +fn differentiate_sql(expr: &str, wrt: &str, columns: Vec) -> PyResult { + let ctx = SessionContext::new(); + + let fields: Vec = columns + .iter() + .map(|name| Field::new(name, DataType::Float64, true)) + .collect(); + let df_schema = DFSchema::try_from(Schema::new(fields)).map_err(|e| { + pyo3::exceptions::PyValueError::new_err(format!( + "differentiate_sql: failed to build schema: {e}" + )) + })?; + + let parsed = ctx.parse_sql_expr(expr, &df_schema).map_err(|e| { + pyo3::exceptions::PyValueError::new_err(format!( + "differentiate_sql: failed to parse expression '{expr}': {e}" + )) + })?; + + let derivative = autograd::differentiate(&parsed, wrt).map_err(|e| { + pyo3::exceptions::PyValueError::new_err(format!( + "differentiate_sql: failed to differentiate: {e}" + )) + })?; + + let sql = expr_to_sql(&derivative).map_err(|e| { + pyo3::exceptions::PyValueError::new_err(format!( + "differentiate_sql: failed to render derivative to SQL: {e}" + )) + })?; + + Ok(sql.to_string()) +} + /// Python module initialization #[pymodule] fn _native(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; + m.add_function(wrap_pyfunction!(rewrite_grad_sql, m)?)?; + m.add_function(wrap_pyfunction!(differentiate_sql, m)?)?; Ok(()) } diff --git a/tests/test_autograd.py b/tests/test_autograd.py new file mode 100644 index 0000000..794194e --- /dev/null +++ b/tests/test_autograd.py @@ -0,0 +1,309 @@ +"""Tests for the SQL autograd surface: ``SELECT grad(expr, column) ...``. + +These exercise the full path — XarrayContext.sql() differentiates every +``grad``/``jvp``/``vjp`` call as SQL text before planning, then DataFusion +executes the rewritten query — and compare results against analytic +derivatives computed with numpy. 
+""" + +import numpy as np +import pyarrow as pa +import pytest +import xarray as xr + +import xarray_sql as xql + + +@pytest.fixture +def ctx(): + val = np.linspace(0.1, 3.0, 16) + ds = xr.Dataset( + {"val": (("i",), val)}, + coords={"i": np.arange(16)}, + ) + context = xql.XarrayContext() + context.from_dataset("t", ds, chunks={"i": 5}) + return context + + +@pytest.fixture +def ctx_xy(): + rng = np.random.default_rng(0) + n = 16 + ds = xr.Dataset( + { + "x": (("i",), rng.uniform(0.5, 2.5, n)), + "y": (("i",), rng.uniform(0.5, 2.5, n)), + }, + coords={"i": np.arange(n)}, + ) + context = xql.XarrayContext() + context.from_dataset("g", ds, chunks={"i": 5}) + return context, ds + + +def _ordered(df, key="i"): + """Collect a result DataFrame into a dict of column -> numpy array, sorted + by the integer key column so comparisons are index-aligned.""" + pdf = df.to_pandas().sort_values(key) + return {c: pdf[c].to_numpy() for c in pdf.columns} + + +def test_grad_sin_is_cos(ctx): + val = np.linspace(0.1, 3.0, 16) + res = _ordered(ctx.sql("SELECT i, grad(sin(val), val) AS d FROM t")) + np.testing.assert_allclose(res["d"], np.cos(val)) + + +def test_grad_product_rule(ctx): + val = np.linspace(0.1, 3.0, 16) + res = _ordered(ctx.sql("SELECT i, grad(sin(val) * val, val) AS d FROM t")) + np.testing.assert_allclose(res["d"], np.cos(val) * val + np.sin(val)) + + +def test_grad_exp_equals_value(ctx): + val = np.linspace(0.1, 3.0, 16) + res = _ordered( + ctx.sql("SELECT i, exp(val) AS v, grad(exp(val), val) AS d FROM t") + ) + np.testing.assert_allclose(res["d"], np.exp(val)) + np.testing.assert_allclose(res["d"], res["v"]) + + +def test_grad_quotient_and_power(ctx): + val = np.linspace(0.1, 3.0, 16) + res = _ordered( + ctx.sql( + "SELECT i, grad(1.0 / val, val) AS dinv, " + "grad(power(val, 3), val) AS dcube FROM t" + ) + ) + np.testing.assert_allclose(res["dinv"], -1.0 / val**2) + np.testing.assert_allclose(res["dcube"], 3.0 * val**2) + + +def test_higher_order_grad(ctx): + # 
Nested grad() differentiates repeatedly: the inner call is rewritten + # first, then the outer differentiates its result. + val = np.linspace(0.1, 3.0, 16) + res = _ordered( + ctx.sql( + "SELECT i, " + "grad(grad(sin(val), val), val) AS d2_sin, " + "grad(grad(power(val, 3), val), val) AS d2_cube FROM t" + ) + ) + np.testing.assert_allclose(res["d2_sin"], -np.sin(val)) # -sin + np.testing.assert_allclose(res["d2_cube"], 6.0 * val) # d2/dx2 x^3 = 6x + + +def test_third_order_grad(ctx): + val = np.linspace(0.1, 3.0, 16) + res = _ordered( + ctx.sql( + "SELECT i, grad(grad(grad(sin(val), val), val), val) AS d3 FROM t" + ) + ) + np.testing.assert_allclose(res["d3"], -np.cos(val)) # d3/dx3 sin = -cos + + +def test_non_grad_query_is_unaffected(ctx): + # Queries without grad() bypass the rewrite and behave normally. + res = _ordered(ctx.sql("SELECT i, val FROM t")) + np.testing.assert_allclose(res["val"], np.linspace(0.1, 3.0, 16)) + + +def test_unsupported_function_raises(ctx): + # atan2 has no derivative rule yet -> a clear error, not a wrong answer. + with pytest.raises(Exception): + ctx.sql("SELECT grad(atan2(val, val), val) AS d FROM t").to_pandas() + + +def test_grad_over_in_memory_table(ctx): + # grad works over plain DataFusion tables too (not just xarray-registered + # ones): here a coefficient lives in an in-memory MemTable cross-joined to + # the xarray data. d/dval (c * val^2) = c * 2*val, with c = 3. + ctx.register_record_batches( + "coef", [[pa.RecordBatch.from_pydict({"c": [3.0]})]] + ) + val = np.linspace(0.1, 3.0, 16) + res = _ordered( + ctx.sql( + "SELECT i, grad(c * val * val, val) AS d FROM t CROSS JOIN coef" + ) + ) + np.testing.assert_allclose(res["d"], 3.0 * 2.0 * val) + + +def test_differentiate_sql_round_trip(ctx): + # differentiate_sql returns the derivative as SQL text; evaluating it must + # match the analytic derivative. d/dval (sin(val)*val) = cos(val)*val + sin(val). 
+ deriv = xql.differentiate_sql("sin(val) * val", "val", ["val"]) + val = np.linspace(0.1, 3.0, 16) + res = _ordered(ctx.sql(f"SELECT i, {deriv} AS d FROM t")) + np.testing.assert_allclose(res["d"], np.cos(val) * val + np.sin(val)) + + +def test_grad_inside_aggregate(ctx): + # Differentiation through an aggregate is just linearity: + # AGG(grad(f, x)) == d/dx AGG(f). grad rewrites to plain SQL before the + # aggregate runs, so this composes with no special machinery. + val = np.linspace(0.1, 3.0, 16) + res = ctx.sql( + "SELECT SUM(grad(val * val, val)) AS s, " + "AVG(grad(sin(val), val)) AS a FROM t" + ).to_pandas() + np.testing.assert_allclose(res["s"][0], np.sum(2 * val)) + np.testing.assert_allclose(res["a"][0], np.mean(np.cos(val))) + + +def test_gradient_descent_in_sql(): + # End to end: fit y ~= a*x + b by minimising MSE, with the gradients + # w.r.t. the parameters computed in SQL via AVG(grad(loss, param)). + rng = np.random.default_rng(0) + n = 200 + x = rng.uniform(0.0, 1.0, n) + a_true, b_true = 2.0, -1.0 + y = a_true * x + b_true + rng.normal(0.0, 0.01, n) + data = xr.Dataset( + {"x": (("i",), x), "y": (("i",), y)}, coords={"i": np.arange(n)} + ) + ctx = xql.XarrayContext() + ctx.from_dataset("d", data, chunks={"i": n}) + + resid = "(y - (a * x + b))" + loss = f"{resid} * {resid}" + a, b, lr = 0.0, 0.0, 0.4 + losses = [] + for _ in range(120): + if "params" in ctx._registered_datasets: + ctx.deregister_table("params") + del ctx._registered_datasets["params"] + params = xr.Dataset( + {"a": (("p",), [a]), "b": (("p",), [b])}, coords={"p": [0]} + ) + ctx.from_dataset("params", params, chunks={"p": 1}) + row = ctx.sql( + f"SELECT AVG({loss}) AS loss, " + f"AVG(grad({loss}, a)) AS dl_da, " + f"AVG(grad({loss}, b)) AS dl_db FROM d CROSS JOIN params" + ).to_pandas() + losses.append(float(row["loss"][0])) + a -= lr * float(row["dl_da"][0]) + b -= lr * float(row["dl_db"][0]) + + assert losses[-1] < losses[0] # loss decreased + np.testing.assert_allclose([a, b], 
[a_true, b_true], atol=0.05) + + +def test_grad_inside_recursive_cte(): + # The headline of #197: grad() *inside* a recursive CTE — a query shape the + # old Substrait bridge could not represent. Newton's method for sqrt(2) + # drives the step with grad(x*x - 2, x) computed in the recursive term: + # x <- x - (x*x - 2) / d/dx(x*x - 2) = x - (x*x - 2) / (2x). + ctx = xql.XarrayContext() + res = ctx.sql( + "WITH RECURSIVE newton AS (" + " SELECT 0 AS step, CAST(1.0 AS DOUBLE) AS x " + " UNION ALL " + " SELECT step + 1 AS step, " + " x - (x * x - 2.0) / grad(x * x - 2.0, x) AS x " + " FROM newton WHERE step < 20" + ") " + "SELECT x FROM newton ORDER BY step DESC LIMIT 1" + ).to_pandas() + np.testing.assert_allclose(res["x"][0], np.sqrt(2.0), atol=1e-9) + + +def test_multi_input_grad_columns(ctx_xy): + # A full Jacobian written as separate scalar grad() columns: + # f = x*y -> df/dx = y, df/dy = x. + context, ds = ctx_xy + res = _ordered( + context.sql( + "SELECT i, grad(x * y, x) AS dfdx, grad(x * y, y) AS dfdy FROM g" + ) + ) + np.testing.assert_allclose(res["dfdx"], ds["y"].values) + np.testing.assert_allclose(res["dfdy"], ds["x"].values) + + +def test_jvp_forward_directional_derivative(ctx_xy): + # jvp(f, x, dx) = df/dx * dx. With f = sin(x)*y and a constant tangent. + context, ds = ctx_xy + x, y = ds["x"].values, ds["y"].values + res = _ordered(context.sql("SELECT i, jvp(sin(x) * y, x, 2.0) AS t FROM g")) + np.testing.assert_allclose(res["t"], (np.cos(x) * y) * 2.0) + + +def test_jvp_multi_input_is_sum(ctx_xy): + # A full directional derivative is the sum of per-input jvp terms: + # df/dx*dx + df/dy*dy for f = x*y, with dx=1, dy=1 -> y + x. + context, ds = ctx_xy + res = _ordered( + context.sql( + "SELECT i, jvp(x * y, x, 1.0) + jvp(x * y, y, 1.0) AS t FROM g" + ) + ) + np.testing.assert_allclose(res["t"], ds["y"].values + ds["x"].values) + + +def test_vjp_reverse_pullback(ctx_xy): + # vjp(f, x, w) = w * df/dx. With f = sin(x)*y and cotangent w = 3.0. 
+ context, ds = ctx_xy + x, y = ds["x"].values, ds["y"].values + res = _ordered(context.sql("SELECT i, vjp(sin(x) * y, x, 3.0) AS s FROM g")) + np.testing.assert_allclose(res["s"], 3.0 * (np.cos(x) * y)) + + +@pytest.fixture +def ctx_mixed(): + # A mixed-dimension dataset registers as schema-qualified tables: + # era5.time_x (surface, 2 dims) + # era5.time_x_level (atmosphere, 3 dims) + rng = np.random.default_rng(1) + ds = xr.Dataset( + { + "sfc": (("time", "x"), rng.uniform(0.5, 2.5, (3, 4))), + "atm": (("time", "x", "level"), rng.uniform(0.5, 2.5, (3, 4, 2))), + }, + coords={"time": [0, 1, 2], "x": np.arange(4.0), "level": [0, 1]}, + ) + context = xql.XarrayContext() + context.from_dataset("era5", ds, chunks={"time": 1}) + return context, ds + + +def test_grad_on_qualified_surface_table(ctx_mixed): + context, ds = ctx_mixed + res = _ordered( + context.sql( + "SELECT time, x, sfc, grad(sin(sfc), sfc) AS d FROM era5.time_x" + ), + key="sfc", + ) + np.testing.assert_allclose(res["d"], np.cos(res["sfc"])) + + +def test_grad_on_qualified_atmosphere_table(ctx_mixed): + context, ds = ctx_mixed + res = _ordered( + context.sql( + "SELECT atm, grad(power(atm, 2), atm) AS d FROM era5.time_x_level" + ), + key="atm", + ) + np.testing.assert_allclose(res["d"], 2.0 * res["atm"]) + + +def test_jvp_and_vjp_agree_for_unit_seed(ctx_xy): + # Forward (unit tangent) and reverse (unit cotangent) coincide for a + # scalar output -- both contract the same partial derivative. + context, _ = ctx_xy + res = _ordered( + context.sql( + "SELECT i, jvp(sin(x) * y, x, 1.0) AS fwd, " + "vjp(sin(x) * y, x, 1.0) AS rev FROM g" + ) + ) + np.testing.assert_allclose(res["fwd"], res["rev"]) diff --git a/xarray_sql/__init__.py b/xarray_sql/__init__.py index d1e5984..c01f295 100644 --- a/xarray_sql/__init__.py +++ b/xarray_sql/__init__.py @@ -1,4 +1,5 @@ from . 
import cftime +from ._native import differentiate_sql from .df import from_map from .reader import read_xarray, read_xarray_table from .sql import XarrayContext @@ -6,6 +7,7 @@ __all__ = [ "cftime", "XarrayContext", + "differentiate_sql", "read_xarray_table", "read_xarray", "from_map", # deprecated diff --git a/xarray_sql/sql.py b/xarray_sql/sql.py index 0ec60ad..46fe8e6 100644 --- a/xarray_sql/sql.py +++ b/xarray_sql/sql.py @@ -1,13 +1,21 @@ +import re + import xarray as xr from datafusion import SessionContext from datafusion.catalog import Schema from collections import defaultdict +from . import _native from . import cftime as cft from .df import Chunks from .ds import XarrayDataFrame from .reader import read_xarray_table +# Matches a call to an autograd marker function (``grad(`` / ``jvp(`` / ``vjp(``, +# case-insensitive), used as a cheap gate so ordinary queries skip the grad +# source-to-source rewrite. +_GRAD_CALL = re.compile(r"\b(grad|jvp|vjp)\s*\(", re.IGNORECASE) + class XarrayContext(SessionContext): """A datafusion `SessionContext` that also supports `xarray.Dataset`s.""" @@ -166,6 +174,11 @@ def sql(self, query: str, *args, **kwargs) -> XarrayDataFrame: ``.to_dataset(dimension_columns=[...])`` for round-tripping the result back to an ``xr.Dataset``. + If the query contains ``grad`` / ``jvp`` / ``vjp`` calls, they are + differentiated and substituted as SQL text *before* planning (see + :meth:`_rewrite_autograd`), so the differentiation works inside any + query shape — recursive CTEs, DML, and subqueries included. + Args: query: A SQL query string. *args: Forwarded to ``SessionContext.sql``. @@ -174,9 +187,34 @@ def sql(self, query: str, *args, **kwargs) -> XarrayDataFrame: Returns: An :class:`XarrayDataFrame` wrapping the DataFusion DataFrame. 
""" + if _GRAD_CALL.search(query): + query = self._rewrite_autograd(query) inner = super().sql(query, *args, **kwargs) return XarrayDataFrame(inner, templates=self._registered_datasets) + def _rewrite_autograd(self, query: str) -> str: + """Differentiate ``grad`` / ``jvp`` / ``vjp`` calls into SQL text. + + The differentiation engine lives in the native (Rust) extension and + operates on DataFusion logical expressions. Rather than round-tripping a + whole plan across that extension's boundary, we hand it the query as SQL + text: it parses each marker call, differentiates it symbolically, and + renders the derivative back into the query in place. The result is an + ordinary SQL string this context can plan and execute directly. + + * ``grad(expr, column)`` -> ``d(expr)/d(column)``. + * ``jvp(expr, column, tangent)`` -> forward-mode directional derivative + ``d(expr)/d(column) * tangent`` (seed a tangent on an input). A + multi-input directional derivative is a sum of jvp terms. + * ``vjp(expr, column, cotangent)`` -> reverse-mode pullback + ``cotangent * d(expr)/d(column)`` (seed a cotangent on the output). + + A full gradient/Jacobian is expressed as several scalar columns, e.g. + ``grad(f, x) AS dfdx, grad(f, y) AS dfdy``. + """ + rewritten: str = _native.rewrite_grad_sql(query) + return rewritten + def _group_vars_by_dims(ds: xr.Dataset) -> dict[tuple[str, ...], list[str]]: """Group variables in the dataset based on shared dims.