Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions benchmarks/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Benchmarks & demos

Standalone scripts that exercise xarray-sql against real data. Each declares its
own dependencies inline (PEP 723) and points `xarray_sql` at this checkout, so
they run with no setup:

```bash
uv run benchmarks/grad_era5.py
```

## `grad_era5.py` — differentiable SQL over ARCO-ERA5

Demonstrates the autograd feature on a real climate archive
([ARCO-ERA5](https://github.com/google-research/arco-era5), read anonymously
from GCS — needs `gcsfs` and network access).

The key idea: a physical quantity is written as an **analytic SQL formula** over
ERA5 variables, and `grad(...)` differentiates that formula **symbolically**,
evaluated at every grid cell. Because each row is an independent point, this is
the relational equivalent of `jax.vmap(jax.grad(f))`. It is *not* a finite-
difference spatial gradient — `grad(f(u, v), u)` is the exact partial derivative
of `f`.

Two worked cases, each checked against an analytic reference:

| Quantity | SQL | Derivative | Check |
| --- | --- | --- | --- |
| Wind speed | `sqrt(power(u,2) + power(v,2))` | `grad(speed, u) = u/speed` | exact |
| Saturation vapour pressure | `A*exp(B*tc/(tc+C))` | `grad(e_s, T)` | closed-form Clausius-Clapeyron slope |

Each query round-trips back to an `xarray.Dataset` via `.to_dataset(...)`.

## `grad_descent.py` — gradient descent as one declarative SQL query

Fits a line `y ~= a*x + b` by minimising the mean squared error, with the
**entire training loop expressed as a single recursive CTE** — no Python
iteration and no precompiled update rule. `grad(...)` lives *inside* the
recursion: `params(step, a, b)` starts at one row and each recursion appends the
next generation, descending along the gradient that `grad` computes from the
loss formula directly (`AVG(grad(loss, a))` is the relational `d/da (Σ loss) /
N` — differentiation through the aggregate is just linearity):

```sql
WITH RECURSIVE params(step, a, b) AS (
SELECT 0, 0.0, 0.0
UNION ALL
SELECT params.step + 1,
params.a - lr*AVG(grad(loss, a)),
params.b - lr*AVG(grad(loss, b))
FROM params CROSS JOIN d WHERE params.step < STEPS
GROUP BY params.step, params.a, params.b)
SELECT * FROM params ORDER BY step
```

So gradient, update, and iteration are all declarative SQL; the trajectory is
the rows of one query. The fit matches numpy's least-squares solution.
Self-contained (no network).

`grad` works inside the recursive CTE because it is differentiated as a SQL
source-to-source rewrite *before* the query is planned — no Substrait round-trip,
so no plan-shape restrictions. (If you instead want the derivative as a string to
embed yourself, `xql.differentiate_sql(loss, "a", cols)` compiles a single
expression to SQL text.)
113 changes: 113 additions & 0 deletions benchmarks/grad_descent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "xarray_sql",
# "xarray",
# "numpy",
# ]
#
# [tool.uv.sources]
# xarray_sql = { path = "..", editable = true }
# ///
"""Gradient descent as a single declarative SQL query.

Fits a line ``y ~= a*x + b`` by minimising the mean squared error — with the
**entire training loop expressed as one recursive CTE**, no Python iteration and
no precompiled update rule. ``grad(...)`` lives *inside* the recursion:

WITH RECURSIVE params(step, a, b) AS (
SELECT 0, 0.0, 0.0
UNION ALL
SELECT params.step + 1,
params.a - lr * AVG(grad(loss, a)),
params.b - lr * AVG(grad(loss, b))
FROM params CROSS JOIN d
WHERE params.step < STEPS
GROUP BY params.step, params.a, params.b)
SELECT step, a, b FROM params ORDER BY step

Each recursion appends the next generation, descending along the gradient that
``grad`` computes from the loss formula directly. ``AVG(grad(loss, a))`` is the
relational ``d/da (Σ loss) / N`` — differentiation through the aggregate is just
linearity. So gradient, update, and iteration are all one declarative query; the
optimisation trajectory is the rows of that query.

``grad`` is differentiated as a SQL source-to-source rewrite *before* the query
is planned, so the marker works inside the recursive CTE (and any other query
shape) with no Substrait round-trip. The loss is written once, as ordinary SQL,
and the engine differentiates it symbolically — the relational equivalent of
``jax.vmap(jax.grad(f))``, since each row is an independent evaluation point.

Run standalone:

uv run benchmarks/grad_descent.py
"""

from __future__ import annotations

import numpy as np
import xarray as xr

import xarray_sql as xql

# Per-row loss r^2 with residual r = y - (a*x + b). The columns a, b come from
# the recursive `params` relation; x, y come from the data table `d`.
RESIDUAL = "(y - (a * x + b))"
LOSS = f"{RESIDUAL} * {RESIDUAL}"
LR = 0.4
STEPS = 200


def main() -> None:
rng = np.random.default_rng(0)
n = 500
x = rng.uniform(0.0, 1.0, n)
a_true, b_true = 2.0, -1.0
y = a_true * x + b_true + rng.normal(0.0, 0.01, n)

ctx = xql.XarrayContext()
ctx.from_dataset(
"d",
xr.Dataset(
{"x": (("i",), x), "y": (("i",), y)}, coords={"i": np.arange(n)}
),
chunks={"i": n},
)

# The entire training loop is one declarative recursive query: each step
# appends the next generation, descending along the gradient that grad()
# computes from the loss — differentiated inside the recursion itself.
trajectory = ctx.sql(
f"""
WITH RECURSIVE params(step, a, b) AS (
SELECT 0 AS step, CAST(0.0 AS DOUBLE) AS a, CAST(0.0 AS DOUBLE) AS b
UNION ALL
SELECT params.step + 1 AS step,
params.a - {LR} * AVG(grad({LOSS}, a)) AS a,
params.b - {LR} * AVG(grad({LOSS}, b)) AS b
FROM params CROSS JOIN d
WHERE params.step < {STEPS}
GROUP BY params.step, params.a, params.b
)
SELECT step, a, b FROM params ORDER BY step
"""
).to_pandas()

print("trajectory (every 40th generation):")
print(trajectory.iloc[::40].to_string(index=False))

a, b = float(trajectory["a"].iloc[-1]), float(trajectory["b"].iloc[-1])
a_ols, b_ols = np.polyfit(x, y, 1)
print(
f"\nSQL gradient descent: a={a:.4f} b={b:.4f} ({len(trajectory)} generations)"
)
print(f"least-squares (numpy): a={a_ols:.4f} b={b_ols:.4f}")
assert abs(a - a_ols) < 1e-2 and abs(b - b_ols) < 1e-2
print(
"\nOK: a single recursive-CTE query with grad() inside fit the line "
"to the OLS solution."
)


if __name__ == "__main__":
main()
171 changes: 171 additions & 0 deletions benchmarks/grad_era5.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "xarray_sql",
# "xarray[io]",
# "gcsfs",
# "numpy",
# ]
#
# [tool.uv.sources]
# xarray_sql = { path = "..", editable = true }
# ///
"""Differentiable SQL over ARCO-ERA5.

A minimal demonstration of xarray-sql's autograd: take a real climate archive
(ARCO-ERA5, read anonymously from GCS), express a physical quantity as an
*analytic* SQL formula over its variables, and let ``grad(...)`` differentiate
that formula symbolically — evaluated per grid cell, which is the relational
equivalent of ``jax.vmap(jax.grad(f))`` (each row is an independent point).

Note this is *symbolic* differentiation of an expression, not a finite-
difference spatial gradient: ``grad(f(u, v), u)`` is the exact partial
derivative of the formula ``f``, evaluated at every cell's values.

Two cases:

1. Wind-speed magnitude ``speed = sqrt(u^2 + v^2)``. Its sensitivity to the
eastward wind is ``d(speed)/du = u / speed`` — checked exactly.

2. Saturation vapour pressure ``e_s(T)`` (August-Roche-Magnus form of the
Clausius-Clapeyron relation). ``d(e_s)/dT`` governs how fast the atmosphere's
moisture capacity grows with temperature — checked against the closed-form
slope.

Run standalone (builds the local extension on first use):

uv run benchmarks/grad_era5.py
"""

from __future__ import annotations

import time

import numpy as np
import xarray as xr

import xarray_sql as xql

ARCO_ERA5 = (
"gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3"
)

# ERA5 variable names start with a digit, so they must be double-quoted in SQL.
U = '"10m_u_component_of_wind"'
V = '"10m_v_component_of_wind"'
T = '"2m_temperature"'


def load_era5_block() -> xr.Dataset:
"""Open ARCO-ERA5 and pull one timestamp over a small region.

Lazy open of the whole archive; only the requested block is read. We keep
it to a few thousand cells so the demo runs in seconds.
"""
full = xr.open_zarr(
ARCO_ERA5, chunks=None, storage_options={"token": "anon"}
)
block = (
full[
[
"10m_u_component_of_wind",
"10m_v_component_of_wind",
"2m_temperature",
]
]
.sel(time="2020-01-01T00")
# A ~North-America box (index-based to avoid lat-orientation pitfalls).
.isel(latitude=slice(120, 200), longitude=slice(900, 1000))
.load()
)
# One partition, so a SQL `ORDER BY latitude DESC` survives the round-trip
# back to xarray (across multiple partitions, to_dataset reconstructs
# coordinates in ascending order regardless of ORDER BY).
return block.chunk()


def wind_speed_sensitivity(ctx: xql.XarrayContext, ref: xr.Dataset) -> None:
"""grad(sqrt(u^2 + v^2)) checked against the exact u / speed, v / speed."""
speed = f"sqrt(power({U}, 2) + power({V}, 2))"
out = ctx.sql(
f"""
SELECT
latitude,
longitude,
{speed} AS wind_speed,
grad({speed}, {U}) AS d_speed_d_u,
grad({speed}, {V}) AS d_speed_d_v
FROM era5
ORDER BY latitude DESC, longitude
"""
).to_dataset(dims=["latitude", "longitude"])

u = ref["10m_u_component_of_wind"]
v = ref["10m_v_component_of_wind"]
speed_ref = np.sqrt(u**2 + v**2)

xr.testing.assert_allclose(
out["wind_speed"], speed_ref.rename("wind_speed")
)
xr.testing.assert_allclose(
out["d_speed_d_u"], (u / speed_ref).rename("d_speed_d_u")
)
xr.testing.assert_allclose(
out["d_speed_d_v"], (v / speed_ref).rename("d_speed_d_v")
)
print(" wind-speed sensitivity matches u/|w|, v/|w| exactly")
print(out)


def clausius_clapeyron(ctx: xql.XarrayContext, ref: xr.Dataset) -> None:
"""grad(e_s(T)) checked against the closed-form Clausius-Clapeyron slope."""
# August-Roche-Magnus: e_s(T) = A * exp(B * tc / (tc + C)), tc = T - 273.15.
a, b, c = 6.1094, 17.625, 243.04
tc = f"({T} - 273.15)"
es = f"{a} * exp({b} * {tc} / ({tc} + {c}))"
out = ctx.sql(
f"""
SELECT
latitude,
longitude,
{es} AS e_s,
grad({es}, {T}) AS de_s_dt
FROM era5
ORDER BY latitude DESC, longitude
"""
).to_dataset(dims=["latitude", "longitude"])

# Reference in float64 (the columns are float32): the exact derivative is
# d(e_s)/dT = e_s * B*C / (tc + C)^2.
temp = ref["2m_temperature"].astype("float64")
tc_ref = temp - 273.15
es_ref = a * np.exp(b * tc_ref / (tc_ref + c))
des_dt_ref = es_ref * (b * c) / (tc_ref + c) ** 2

xr.testing.assert_allclose(out["e_s"], es_ref.rename("e_s"), rtol=1e-5)
xr.testing.assert_allclose(
out["de_s_dt"], des_dt_ref.rename("de_s_dt"), rtol=1e-5
)
print(" d(e_s)/dT matches the closed-form Clausius-Clapeyron slope")
print(out)


def main() -> None:
t0 = time.time()
ds = load_era5_block()
print(f"loaded ERA5 block {dict(ds.sizes)} in {time.time() - t0:.1f}s")

ctx = xql.XarrayContext()
ctx.from_dataset("era5", ds)

print("\n== wind-speed sensitivity: grad(sqrt(u^2 + v^2)) ==")
wind_speed_sensitivity(ctx, ds)

print("\n== Clausius-Clapeyron: grad(e_s(T)) ==")
clausius_clapeyron(ctx, ds)

print("\nOK: symbolic SQL gradients match the analytic references.")


if __name__ == "__main__":
main()
Loading