From 214399bb2ea9c386012d6be5bb2cd80f0218b029 Mon Sep 17 00:00:00 2001 From: Alex Merose Date: Tue, 30 Jun 2026 16:34:30 +0300 Subject: [PATCH] Add differentiable-SQL demos: ARCO-ERA5 and gradient descent MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Stacked demo branch (on the autograd feature) holding the runnable benchmark scripts, kept out of the core branch so it stays reviewable. * grad_era5.py: symbolic grad over real ARCO-ERA5 data (wind-speed sensitivity checked exactly; saturation vapour pressure checked against the closed-form Clausius-Clapeyron slope). The queries ORDER BY latitude DESC, longitude to match ERA5's native order, so results line up with the xarray reference with no sorting on either side (single partition, so the order survives to_dataset). * grad_descent.py: gradient descent as ONE declarative recursive-CTE query, with grad() inside the recursion. The loss is written once and differentiated in place; AVG(grad(loss, a)) descends the gradient each step. No Python loop and no precompiled rule — grad() is rewritten to SQL before planning, so it works inside the recursive CTE. Fit matches numpy least-squares. Co-Authored-By: Claude Opus 4.8 --- benchmarks/README.md | 63 ++++++++++++++ benchmarks/grad_descent.py | 113 ++++++++++++++++++++++++ benchmarks/grad_era5.py | 171 +++++++++++++++++++++++++++++++++++++ 3 files changed, 347 insertions(+) create mode 100644 benchmarks/README.md create mode 100644 benchmarks/grad_descent.py create mode 100644 benchmarks/grad_era5.py diff --git a/benchmarks/README.md b/benchmarks/README.md new file mode 100644 index 0000000..e7348da --- /dev/null +++ b/benchmarks/README.md @@ -0,0 +1,63 @@ +# Benchmarks & demos + +Standalone scripts that exercise xarray-sql against real data. 
Each declares its +own dependencies inline (PEP 723) and points `xarray_sql` at this checkout, so +they run with no setup: + +```bash +uv run benchmarks/grad_era5.py +``` + +## `grad_era5.py` — differentiable SQL over ARCO-ERA5 + +Demonstrates the autograd feature on a real climate archive +([ARCO-ERA5](https://github.com/google-research/arco-era5), read anonymously +from GCS — needs `gcsfs` and network access). + +The key idea: a physical quantity is written as an **analytic SQL formula** over +ERA5 variables, and `grad(...)` differentiates that formula **symbolically**, +evaluated at every grid cell. Because each row is an independent point, this is +the relational equivalent of `jax.vmap(jax.grad(f))`. It is *not* a +finite-difference spatial gradient — `grad(f(u, v), u)` is the exact partial +derivative of `f`. + +Two worked cases, each checked against an analytic reference: + +| Quantity | SQL | Derivative | Check | +| --- | --- | --- | --- | +| Wind speed | `sqrt(power(u,2) + power(v,2))` | `grad(speed, u) = u/speed` | exact | +| Saturation vapour pressure | `A*exp(B*tc/(tc+C))` | `grad(e_s, T)` | closed-form Clausius-Clapeyron slope | + +Each query round-trips back to an `xarray.Dataset` via `.to_dataset(...)`. + +## `grad_descent.py` — gradient descent as one declarative SQL query + +Fits a line `y ~= a*x + b` by minimising the mean squared error, with the +**entire training loop expressed as a single recursive CTE** — no Python +iteration and no precompiled update rule. 
`grad(...)` lives *inside* the +recursion: `params(step, a, b)` starts at one row and each recursion appends the +next generation, descending along the gradient that `grad` computes from the +loss formula directly (`AVG(grad(loss, a))` is the relational `d/da (Σ loss) / +N` — differentiation through the aggregate is just linearity): + +```sql +WITH RECURSIVE params(step, a, b) AS ( + SELECT 0, 0.0, 0.0 + UNION ALL + SELECT params.step + 1, + params.a - lr*AVG(grad(loss, a)), + params.b - lr*AVG(grad(loss, b)) + FROM params CROSS JOIN d WHERE params.step < STEPS + GROUP BY params.step, params.a, params.b) +SELECT * FROM params ORDER BY step +``` + +So gradient, update, and iteration are all declarative SQL; the trajectory is +the rows of one query. The fit matches numpy's least-squares solution. +Self-contained (no network). + +`grad` works inside the recursive CTE because it is differentiated as a SQL +source-to-source rewrite *before* the query is planned — no Substrait round-trip, +so no plan-shape restrictions. (If you instead want the derivative as a string to +embed yourself, `xql.differentiate_sql(loss, "a", cols)` compiles a single +expression to SQL text.) diff --git a/benchmarks/grad_descent.py b/benchmarks/grad_descent.py new file mode 100644 index 0000000..1d4babf --- /dev/null +++ b/benchmarks/grad_descent.py @@ -0,0 +1,113 @@ +# /// script +# requires-python = ">=3.10" +# dependencies = [ +# "xarray_sql", +# "xarray", +# "numpy", +# ] +# +# [tool.uv.sources] +# xarray_sql = { path = "..", editable = true } +# /// +"""Gradient descent as a single declarative SQL query. + +Fits a line ``y ~= a*x + b`` by minimising the mean squared error — with the +**entire training loop expressed as one recursive CTE**, no Python iteration and +no precompiled update rule. 
``grad(...)`` lives *inside* the recursion: + + WITH RECURSIVE params(step, a, b) AS ( + SELECT 0, 0.0, 0.0 + UNION ALL + SELECT params.step + 1, + params.a - lr * AVG(grad(loss, a)), + params.b - lr * AVG(grad(loss, b)) + FROM params CROSS JOIN d + WHERE params.step < STEPS + GROUP BY params.step, params.a, params.b) + SELECT step, a, b FROM params ORDER BY step + +Each recursion appends the next generation, descending along the gradient that +``grad`` computes from the loss formula directly. ``AVG(grad(loss, a))`` is the +relational ``d/da (Σ loss) / N`` — differentiation through the aggregate is just +linearity. So gradient, update, and iteration are all one declarative query; the +optimisation trajectory is the rows of that query. + +``grad`` is differentiated as a SQL source-to-source rewrite *before* the query +is planned, so the marker works inside the recursive CTE (and any other query +shape) with no Substrait round-trip. The loss is written once, as ordinary SQL, +and the engine differentiates it symbolically — the relational equivalent of +``jax.vmap(jax.grad(f))``, since each row is an independent evaluation point. + +Run standalone: + + uv run benchmarks/grad_descent.py +""" + +from __future__ import annotations + +import numpy as np +import xarray as xr + +import xarray_sql as xql + +# Per-row loss r^2 with residual r = y - (a*x + b). The columns a, b come from +# the recursive `params` relation; x, y come from the data table `d`. 
+RESIDUAL = "(y - (a * x + b))" +LOSS = f"{RESIDUAL} * {RESIDUAL}" +LR = 0.4 +STEPS = 200 + + +def main() -> None: + rng = np.random.default_rng(0) + n = 500 + x = rng.uniform(0.0, 1.0, n) + a_true, b_true = 2.0, -1.0 + y = a_true * x + b_true + rng.normal(0.0, 0.01, n) + + ctx = xql.XarrayContext() + ctx.from_dataset( + "d", + xr.Dataset( + {"x": (("i",), x), "y": (("i",), y)}, coords={"i": np.arange(n)} + ), + chunks={"i": n}, + ) + + # The entire training loop is one declarative recursive query: each step + # appends the next generation, descending along the gradient that grad() + # computes from the loss — differentiated inside the recursion itself. + trajectory = ctx.sql( + f""" + WITH RECURSIVE params(step, a, b) AS ( + SELECT 0 AS step, CAST(0.0 AS DOUBLE) AS a, CAST(0.0 AS DOUBLE) AS b + UNION ALL + SELECT params.step + 1 AS step, + params.a - {LR} * AVG(grad({LOSS}, a)) AS a, + params.b - {LR} * AVG(grad({LOSS}, b)) AS b + FROM params CROSS JOIN d + WHERE params.step < {STEPS} + GROUP BY params.step, params.a, params.b + ) + SELECT step, a, b FROM params ORDER BY step + """ + ).to_pandas() + + print("trajectory (every 40th generation):") + print(trajectory.iloc[::40].to_string(index=False)) + + a, b = float(trajectory["a"].iloc[-1]), float(trajectory["b"].iloc[-1]) + a_ols, b_ols = np.polyfit(x, y, 1) + print( + f"\nSQL gradient descent: a={a:.4f} b={b:.4f} ({len(trajectory)} generations)" + ) + print(f"least-squares (numpy): a={a_ols:.4f} b={b_ols:.4f}") + assert abs(a - a_ols) < 1e-2 and abs(b - b_ols) < 1e-2 + print( + "\nOK: a single recursive-CTE query with grad() inside fit the line " + "to the OLS solution." 
+ ) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/grad_era5.py b/benchmarks/grad_era5.py new file mode 100644 index 0000000..866f066 --- /dev/null +++ b/benchmarks/grad_era5.py @@ -0,0 +1,171 @@ +# /// script +# requires-python = ">=3.10" +# dependencies = [ +# "xarray_sql", +# "xarray[io]", +# "gcsfs", +# "numpy", +# ] +# +# [tool.uv.sources] +# xarray_sql = { path = "..", editable = true } +# /// +"""Differentiable SQL over ARCO-ERA5. + +A minimal demonstration of xarray-sql's autograd: take a real climate archive +(ARCO-ERA5, read anonymously from GCS), express a physical quantity as an +*analytic* SQL formula over its variables, and let ``grad(...)`` differentiate +that formula symbolically — evaluated per grid cell, which is the relational +equivalent of ``jax.vmap(jax.grad(f))`` (each row is an independent point). + +Note this is *symbolic* differentiation of an expression, not a finite- +difference spatial gradient: ``grad(f(u, v), u)`` is the exact partial +derivative of the formula ``f``, evaluated at every cell's values. + +Two cases: + +1. Wind-speed magnitude ``speed = sqrt(u^2 + v^2)``. Its sensitivity to the + eastward wind is ``d(speed)/du = u / speed`` — checked exactly. + +2. Saturation vapour pressure ``e_s(T)`` (August-Roche-Magnus form of the + Clausius-Clapeyron relation). ``d(e_s)/dT`` governs how fast the atmosphere's + moisture capacity grows with temperature — checked against the closed-form + slope. + +Run standalone (builds the local extension on first use): + + uv run benchmarks/grad_era5.py +""" + +from __future__ import annotations + +import time + +import numpy as np +import xarray as xr + +import xarray_sql as xql + +ARCO_ERA5 = ( + "gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3" +) + +# ERA5 variable names start with a digit, so they must be double-quoted in SQL. 
+U = '"10m_u_component_of_wind"' +V = '"10m_v_component_of_wind"' +T = '"2m_temperature"' + + +def load_era5_block() -> xr.Dataset: + """Open ARCO-ERA5 and pull one timestamp over a small region. + + Lazy open of the whole archive; only the requested block is read. We keep + it to a few thousand cells so the demo runs in seconds. + """ + full = xr.open_zarr( + ARCO_ERA5, chunks=None, storage_options={"token": "anon"} + ) + block = ( + full[ + [ + "10m_u_component_of_wind", + "10m_v_component_of_wind", + "2m_temperature", + ] + ] + .sel(time="2020-01-01T00") + # A ~North-America box (index-based to avoid lat-orientation pitfalls). + .isel(latitude=slice(120, 200), longitude=slice(900, 1000)) + .load() + ) + # One partition, so a SQL `ORDER BY latitude DESC` survives the round-trip + # back to xarray (across multiple partitions, to_dataset reconstructs + # coordinates in ascending order regardless of ORDER BY). + return block.chunk() + + +def wind_speed_sensitivity(ctx: xql.XarrayContext, ref: xr.Dataset) -> None: + """grad(sqrt(u^2 + v^2)) checked against the exact u / speed, v / speed.""" + speed = f"sqrt(power({U}, 2) + power({V}, 2))" + out = ctx.sql( + f""" + SELECT + latitude, + longitude, + {speed} AS wind_speed, + grad({speed}, {U}) AS d_speed_d_u, + grad({speed}, {V}) AS d_speed_d_v + FROM era5 + ORDER BY latitude DESC, longitude + """ + ).to_dataset(dims=["latitude", "longitude"]) + + u = ref["10m_u_component_of_wind"] + v = ref["10m_v_component_of_wind"] + speed_ref = np.sqrt(u**2 + v**2) + + xr.testing.assert_allclose( + out["wind_speed"], speed_ref.rename("wind_speed") + ) + xr.testing.assert_allclose( + out["d_speed_d_u"], (u / speed_ref).rename("d_speed_d_u") + ) + xr.testing.assert_allclose( + out["d_speed_d_v"], (v / speed_ref).rename("d_speed_d_v") + ) + print(" wind-speed sensitivity matches u/|w|, v/|w| exactly") + print(out) + + +def clausius_clapeyron(ctx: xql.XarrayContext, ref: xr.Dataset) -> None: + """grad(e_s(T)) checked against the 
closed-form Clausius-Clapeyron slope.""" + # August-Roche-Magnus: e_s(T) = A * exp(B * tc / (tc + C)), tc = T - 273.15. + a, b, c = 6.1094, 17.625, 243.04 + tc = f"({T} - 273.15)" + es = f"{a} * exp({b} * {tc} / ({tc} + {c}))" + out = ctx.sql( + f""" + SELECT + latitude, + longitude, + {es} AS e_s, + grad({es}, {T}) AS de_s_dt + FROM era5 + ORDER BY latitude DESC, longitude + """ + ).to_dataset(dims=["latitude", "longitude"]) + + # Reference in float64 (the columns are float32): the exact derivative is + # d(e_s)/dT = e_s * B*C / (tc + C)^2. + temp = ref["2m_temperature"].astype("float64") + tc_ref = temp - 273.15 + es_ref = a * np.exp(b * tc_ref / (tc_ref + c)) + des_dt_ref = es_ref * (b * c) / (tc_ref + c) ** 2 + + xr.testing.assert_allclose(out["e_s"], es_ref.rename("e_s"), rtol=1e-5) + xr.testing.assert_allclose( + out["de_s_dt"], des_dt_ref.rename("de_s_dt"), rtol=1e-5 + ) + print(" d(e_s)/dT matches the closed-form Clausius-Clapeyron slope") + print(out) + + +def main() -> None: + t0 = time.time() + ds = load_era5_block() + print(f"loaded ERA5 block {dict(ds.sizes)} in {time.time() - t0:.1f}s") + + ctx = xql.XarrayContext() + ctx.from_dataset("era5", ds) + + print("\n== wind-speed sensitivity: grad(sqrt(u^2 + v^2)) ==") + wind_speed_sensitivity(ctx, ds) + + print("\n== Clausius-Clapeyron: grad(e_s(T)) ==") + clausius_clapeyron(ctx, ds) + + print("\nOK: symbolic SQL gradients match the analytic references.") + + +if __name__ == "__main__": + main()