From 26d620470003ef9b3c6c7896741799c1b59b8e0a Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Tue, 26 May 2026 14:54:21 -0400 Subject: [PATCH 1/2] feat: expose lambda and higher-order array functions Add a Pythonic API for DataFusion's higher-order array functions and the lambda expressions they consume. - Rust: lambda_, lambda_var, array_transform, and array_any_match pyfunctions, plus a ResolveLambdaVariables analyzer rule so expression-builder plans (which emit unresolved lambda variables) resolve before optimization. - Python: array_transform / array_any_match (with list_transform, any_match, list_any_match aliases) accept either a Python callable or an explicit lambda built with lambda_ / lambda_var. Callables are introspected so their parameter names become the lambda parameters. - Tests and docs (expressions guide + agent skill), noting v1 limits: lambda expressions are not serializable, and SQL arrow syntax needs the DuckDB dialect. --- crates/core/src/analyzer.rs | 59 ++++++ crates/core/src/context.rs | 1 + crates/core/src/functions.rs | 37 ++++ crates/core/src/lib.rs | 1 + .../common-operations/expressions.rst | 42 +++++ python/datafusion/functions.py | 175 +++++++++++++++++- python/tests/test_lambda.py | 110 +++++++++++ skills/datafusion_python/SKILL.md | 25 +++ 8 files changed, 449 insertions(+), 1 deletion(-) create mode 100644 crates/core/src/analyzer.rs create mode 100644 python/tests/test_lambda.py diff --git a/crates/core/src/analyzer.rs b/crates/core/src/analyzer.rs new file mode 100644 index 000000000..3a77e08a3 --- /dev/null +++ b/crates/core/src/analyzer.rs @@ -0,0 +1,59 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Analyzer rules layered on top of DataFusion's defaults. + +use datafusion::common::Result; +use datafusion::common::config::ConfigOptions; +use datafusion::logical_expr::LogicalPlan; +use datafusion::optimizer::AnalyzerRule; + +/// Resolve [`LambdaVariable`] references into bound lambda parameters. +/// +/// DataFusion's SQL planner resolves lambda variables inline as it plans a +/// higher-order function call, so SQL-built plans never carry unresolved +/// variables. Plans assembled programmatically through the Python expression +/// builder (e.g. `array_transform(col("xs"), lambda_(["v"], lambda_var("v")))`) +/// do carry them, and nothing in the default analyzer resolves them. This rule +/// runs [`LogicalPlan::resolve_lambda_variables`] so both construction paths +/// reach the optimizer with bound lambdas. 
+/// +/// [`LambdaVariable`]: datafusion::logical_expr::expr::LambdaVariable +#[derive(Debug)] +pub struct ResolveLambdaVariables {} + +impl ResolveLambdaVariables { + pub fn new() -> Self { + Self {} + } +} + +impl Default for ResolveLambdaVariables { + fn default() -> Self { + Self::new() + } +} + +impl AnalyzerRule for ResolveLambdaVariables { + fn analyze(&self, plan: LogicalPlan, _config: &ConfigOptions) -> Result<LogicalPlan> { + plan.resolve_lambda_variables().map(|t| t.data) + } + + fn name(&self) -> &str { + "resolve_lambda_variables" + } +} diff --git a/crates/core/src/context.rs b/crates/core/src/context.rs index bca8e30f5..cd212a49f 100644 --- a/crates/core/src/context.rs +++ b/crates/core/src/context.rs @@ -396,6 +396,7 @@ impl PySessionContext { .with_config(config) .with_runtime_env(runtime) .with_default_features() + .with_analyzer_rule(Arc::new(crate::analyzer::ResolveLambdaVariables::new())) .build(); let ctx = Arc::new(SessionContext::new_with_state(session_state)); Ok(PySessionContext { diff --git a/crates/core/src/functions.rs b/crates/core/src/functions.rs index 5f47d123b..41286a746 100644 --- a/crates/core/src/functions.rs +++ b/crates/core/src/functions.rs @@ -159,6 +159,37 @@ fn array_slice(array: PyExpr, begin: PyExpr, end: PyExpr, stride: Option<PyExpr> .into() } +/// Create a lambda expression from a list of parameter names and a body +/// expression. The body should reference the parameters via [`lambda_var`]. +/// Exposed to Python as `lambda_` because `lambda` is a reserved keyword. +#[pyfunction] +#[pyo3(name = "lambda_")] +fn py_lambda(params: Vec<String>, body: PyExpr) -> PyExpr { + datafusion::logical_expr::lambda(params, body.into()).into() +} + +/// Create an unresolved lambda variable reference by name. The owning +/// higher-order function resolves it against its lambda parameters during +/// planning.
+#[pyfunction] +fn lambda_var(name: String) -> PyExpr { + datafusion::logical_expr::lambda_var(name).into() +} + +/// Higher-order function: apply `transform` (a lambda) to each element of +/// `array`, returning a new array of the results. +#[pyfunction] +fn array_transform(array: PyExpr, transform: PyExpr) -> PyExpr { + datafusion::functions_nested::expr_fn::array_transform(array.into(), transform.into()).into() +} + +/// Higher-order function: return true if any element of `array` satisfies +/// `predicate` (a lambda returning a boolean). +#[pyfunction] +fn array_any_match(array: PyExpr, predicate: PyExpr) -> PyExpr { + datafusion::functions_nested::expr_fn::array_any_match(array.into(), predicate.into()).into() +} + /// Computes a binary hash of the given data. type is the algorithm to use. /// Standard algorithms are md5, sha224, sha256, sha384, sha512, blake2s, blake2b, and blake3. // #[pyfunction(value, method)] @@ -1082,6 +1113,12 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(encode))?; m.add_wrapped(wrap_pyfunction!(decode))?; + // Lambda / higher-order functions + m.add_wrapped(wrap_pyfunction!(py_lambda))?; + m.add_wrapped(wrap_pyfunction!(lambda_var))?; + m.add_wrapped(wrap_pyfunction!(array_transform))?; + m.add_wrapped(wrap_pyfunction!(array_any_match))?; + // Array Functions m.add_wrapped(wrap_pyfunction!(array_append))?; m.add_wrapped(wrap_pyfunction!(array_concat))?; diff --git a/crates/core/src/lib.rs b/crates/core/src/lib.rs index 8b622d344..fb67b5703 100644 --- a/crates/core/src/lib.rs +++ b/crates/core/src/lib.rs @@ -27,6 +27,7 @@ use mimalloc::MiMalloc; use pyo3::prelude::*; +pub mod analyzer; #[allow(clippy::borrow_deref_ref)] pub mod catalog; pub mod codec; pub mod common; diff --git a/docs/source/user-guide/common-operations/expressions.rst b/docs/source/user-guide/common-operations/expressions.rst index ae1ccc0dc..f7b93496e 100644 ---
a/docs/source/user-guide/common-operations/expressions.rst +++ b/docs/source/user-guide/common-operations/expressions.rst @@ -145,6 +145,48 @@ This function returns a new array with the elements repeated. In this example, the `repeated_array` column will contain `[[1, 2, 3], [1, 2, 3]]`. +Higher-order functions and lambdas +---------------------------------- + +Some array functions are *higher-order*: they take a lambda that runs once per +element. :py:func:`~datafusion.functions.array_transform` maps a lambda over +every element, and :py:func:`~datafusion.functions.array_any_match` returns +whether any element satisfies a predicate lambda. + +The simplest way to supply a lambda is a Python ``lambda``. Its parameter names +become the lambda parameters, and its return value becomes the body. + +.. ipython:: python + + from datafusion import SessionContext, col + from datafusion import functions as f + + ctx = SessionContext() + df = ctx.from_pydict({"a": [[1, 2, 3], [4, 5]]}) + df.select(f.array_transform(col("a"), lambda v: v * 2).alias("doubled")) + df.select(f.array_any_match(col("a"), lambda v: v > 3).alias("has_big")) + +If you need explicit control over parameter names, build the lambda with +:py:func:`~datafusion.functions.lambda_` and reference its parameters with +:py:func:`~datafusion.functions.lambda_var`. The following is equivalent to the +``array_transform`` call above. + +.. ipython:: python + + from datafusion import lit + + double_fn = f.lambda_(["v"], f.lambda_var("v") * lit(2)) + df.select(f.array_transform(col("a"), double_fn).alias("doubled")) + +.. note:: + + Lambda expressions cannot yet be serialized: calling + :py:meth:`~datafusion.expr.Expr.to_bytes` or pickling an expression that + contains a lambda raises ``Lambda not implemented``. SQL lambda syntax + (``x -> x * 2``) is only parsed by dialects that support lambdas; set + ``datafusion.sql_parser.dialect`` to ``DuckDB`` to use it. 
The Python + expression builder shown above works regardless of dialect. + Testing membership in a list ---------------------------- diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index 28c10e005..afbbdf218 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -38,10 +38,14 @@ from __future__ import annotations -from typing import Any +import inspect +from typing import TYPE_CHECKING, Any import pyarrow as pa +if TYPE_CHECKING: + from collections.abc import Callable + from datafusion._internal import functions as f from datafusion.common import NullTreatment from datafusion.expr import ( @@ -61,12 +65,14 @@ "acos", "acosh", "alias", + "any_match", "approx_distinct", "approx_median", "approx_percentile_cont", "approx_percentile_cont_with_weight", "array", "array_agg", + "array_any_match", "array_any_value", "array_append", "array_cat", @@ -108,6 +114,7 @@ "array_slice", "array_sort", "array_to_string", + "array_transform", "array_union", "arrays_overlap", "arrays_zip", @@ -188,6 +195,8 @@ "isnan", "iszero", "lag", + "lambda_", + "lambda_var", "last_value", "lcm", "lead", @@ -195,6 +204,7 @@ "left", "length", "levenshtein", + "list_any_match", "list_any_value", "list_append", "list_cat", @@ -237,6 +247,7 @@ "list_slice", "list_sort", "list_to_string", + "list_transform", "list_union", "list_zip", "ln", @@ -459,6 +470,168 @@ def list_join(expr: Expr, delimiter: Expr | str) -> Expr: return array_to_string(expr, delimiter) +def lambda_var(name: str) -> Expr: + """Create an unresolved reference to a lambda parameter by ``name``. + + Use this inside the body passed to :py:func:`lambda_` to refer to one of the + lambda's parameters. The owning higher-order function (such as + :py:func:`array_transform`) binds the variable to a concrete element type + during query planning. 
+ + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [[1, 2, 3]]}) + >>> double_fn = F.lambda_(["v"], F.lambda_var("v") * lit(2)) + >>> df.select( + ... F.array_transform(col("a"), double_fn).alias("d") + ... ).collect_column("d")[0].as_py() + [2, 4, 6] + + See Also: + :py:func:`lambda_`, :py:func:`array_transform`, :py:func:`array_any_match`. + """ + return Expr(f.lambda_var(name)) + + +def lambda_(params: list[str], body: Expr) -> Expr: + """Create a lambda expression from parameter names and a body expression. + + This is the explicit form of building a lambda. Most callers can instead + pass a Python callable directly to a higher-order function such as + :py:func:`array_transform`, which builds the lambda automatically. Reach for + ``lambda_`` when you want explicit control over the parameter names. + + Args: + params: Ordered lambda parameter names. + body: Body expression that references the parameters via + :py:func:`lambda_var`. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [[1, 2, 3]]}) + >>> double_fn = F.lambda_(["v"], F.lambda_var("v") * lit(2)) + >>> df.select( + ... F.array_transform(col("a"), double_fn).alias("d") + ... ).collect_column("d")[0].as_py() + [2, 4, 6] + + See Also: + :py:func:`lambda_var`, :py:func:`array_transform`, :py:func:`array_any_match`. + """ + return Expr(f.lambda_(params, body.expr)) + + +def _to_lambda(fn: Expr | Callable[..., Any]) -> Expr: + """Coerce ``fn`` to a lambda ``Expr``. + + Accepts either an ``Expr`` produced by :py:func:`lambda_` (returned + unchanged) or a Python callable. A callable is introspected for its + parameter names; those names become :py:func:`lambda_var` references passed + positionally into the callable, and its return value (coerced to an + ``Expr``) becomes the lambda body. 
+ """ + if isinstance(fn, Expr): + return fn + if not callable(fn): + msg = f"expected an Expr or callable, got {type(fn).__name__}" + raise TypeError(msg) + params = list(inspect.signature(fn).parameters) + if not params: + msg = "lambda callable must accept at least one parameter" + raise ValueError(msg) + body = coerce_to_expr(fn(*[lambda_var(p) for p in params])) + return lambda_(params, body) + + +def array_transform(array: Expr, transform: Expr | Callable[..., Any]) -> Expr: + """Transform each element of ``array`` with a lambda. + + ``transform`` may be a Python callable, which is converted to a lambda + automatically (its parameter names become the lambda parameters), or an + explicit lambda built with :py:func:`lambda_`. + + Examples: + Using a Python callable: + + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [[1, 2, 3]]}) + >>> df.select( + ... F.array_transform(col("a"), lambda v: v * 2).alias("d") + ... ).collect_column("d")[0].as_py() + [2, 4, 6] + + Using an explicit lambda built with :py:func:`lambda_`: + + >>> double_fn = F.lambda_(["v"], F.lambda_var("v") * lit(2)) + >>> df.select( + ... F.array_transform(col("a"), double_fn).alias("d") + ... ).collect_column("d")[0].as_py() + [2, 4, 6] + + See Also: + :py:func:`array_any_match`, :py:func:`lambda_`. + """ + return Expr(f.array_transform(array.expr, _to_lambda(transform).expr)) + + +def list_transform(array: Expr, transform: Expr | Callable[..., Any]) -> Expr: + """Transform each element of a list with a lambda. + + See Also: + This is an alias for :py:func:`array_transform`. + """ + return array_transform(array, transform) + + +def array_any_match(array: Expr, predicate: Expr | Callable[..., Any]) -> Expr: + """Return ``True`` if any element of ``array`` satisfies ``predicate``. + + ``predicate`` may be a Python callable, converted to a lambda + automatically, or an explicit lambda built with :py:func:`lambda_`. It must + return a boolean expression. 
+ + Examples: + Using a Python callable: + + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [[1, 2, 3]]}) + >>> df.select( + ... F.array_any_match(col("a"), lambda v: v > 2).alias("m") + ... ).collect_column("m")[0].as_py() + True + + Using an explicit lambda built with :py:func:`lambda_`: + + >>> predicate = F.lambda_(["v"], F.lambda_var("v") > lit(2)) + >>> df.select( + ... F.array_any_match(col("a"), predicate).alias("m") + ... ).collect_column("m")[0].as_py() + True + + See Also: + :py:func:`array_transform`, :py:func:`lambda_`. + """ + return Expr(f.array_any_match(array.expr, _to_lambda(predicate).expr)) + + +def any_match(array: Expr, predicate: Expr | Callable[..., Any]) -> Expr: + """Return ``True`` if any element of an array satisfies a predicate. + + See Also: + This is an alias for :py:func:`array_any_match`. + """ + return array_any_match(array, predicate) + + +def list_any_match(array: Expr, predicate: Expr | Callable[..., Any]) -> Expr: + """Return ``True`` if any element of a list satisfies a predicate. + + See Also: + This is an alias for :py:func:`array_any_match`. + """ + return array_any_match(array, predicate) + + def in_list(arg: Expr, values: list[Expr], negated: bool = False) -> Expr: """Returns whether the argument is contained within the list ``values``. diff --git a/python/tests/test_lambda.py b/python/tests/test_lambda.py new file mode 100644 index 000000000..e8546ad3c --- /dev/null +++ b/python/tests/test_lambda.py @@ -0,0 +1,110 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Tests for lambda expressions and higher-order array functions.""" + +import pytest +from datafusion import SessionConfig, SessionContext, col, lit +from datafusion import functions as f + + +@pytest.fixture +def df(): + ctx = SessionContext() + return ctx.from_pydict({"a": [[1, 2, 3], [4, 5]]}) + + +def _column(df, expr, name): + return df.select(expr.alias(name)).collect_column(name).to_pylist() + + +def test_array_transform_callable(df): + expr = f.array_transform(col("a"), lambda v: v * 2) + assert _column(df, expr, "d") == [[2, 4, 6], [8, 10]] + + +def test_array_transform_explicit_lambda(df): + transform = f.lambda_(["v"], f.lambda_var("v") * lit(2)) + expr = f.array_transform(col("a"), transform) + assert _column(df, expr, "d") == [[2, 4, 6], [8, 10]] + + +def test_array_transform_literal_body_is_coerced(df): + expr = f.array_transform(col("a"), lambda v: 0) + assert _column(df, expr, "z") == [[0, 0, 0], [0, 0]] + + +def test_list_transform_alias(df): + expr = f.list_transform(col("a"), lambda v: v + 1) + assert _column(df, expr, "d") == [[2, 3, 4], [5, 6]] + + +def test_array_any_match_callable(df): + expr = f.array_any_match(col("a"), lambda v: v > 3) + assert _column(df, expr, "m") == [False, True] + + +def test_array_any_match_explicit_lambda(df): + predicate = f.lambda_(["v"], f.lambda_var("v") > lit(2)) + expr = f.array_any_match(col("a"), predicate) + assert _column(df, expr, "m") == [True, True] + + +@pytest.mark.parametrize("alias", [f.any_match, f.list_any_match]) +def test_any_match_aliases(df, alias): + expr = alias(col("a"), lambda v: 
v > 4) + assert _column(df, expr, "m") == [False, True] + + +def test_lambda_param_name_appears_in_plan(df): + # The user-chosen parameter name should survive into the displayed plan + # rather than a synthetic placeholder. + expr = f.array_transform(col("a"), lambda value: value * 2) + assert "value" in expr.canonical_name() + + +def test_to_lambda_rejects_non_callable(): + with pytest.raises(TypeError, match="expected an Expr or callable"): + f.array_transform(col("a"), 42) + + +def test_to_lambda_rejects_zero_arg_callable(): + with pytest.raises(ValueError, match="at least one parameter"): + f.array_transform(col("a"), lambda: lit(1)) + + +def test_sql_lambda_requires_duckdb_dialect(): + # Lambda arrow syntax (``x -> ...``) is only parsed by dialects that + # support lambda functions. The default Generic dialect treats ``->`` as + # the JSON arrow operator, so ``x`` is read as a column reference. + ctx = SessionContext() + with pytest.raises(Exception, match="No field named x"): + ctx.sql("select array_transform([1, 2, 3], x -> x * 2) as d").collect() + + duckdb_ctx = SessionContext( + SessionConfig().set("datafusion.sql_parser.dialect", "DuckDB") + ) + result = duckdb_ctx.sql( + "select array_transform([1, 2, 3], x -> x * 2) as d" + ).collect_column("d") + assert result.to_pylist() == [[2, 4, 6]] + + +def test_pickle_lambda_expr_not_supported(): + # v1 limitation: upstream proto serialization rejects lambda expressions. 
+ expr = f.array_transform(col("a"), lambda v: v * 2) + with pytest.raises(Exception, match="Lambda not implemented"): + expr.to_bytes() diff --git a/skills/datafusion_python/SKILL.md b/skills/datafusion_python/SKILL.md index 98fa2c7aa..49838adcd 100644 --- a/skills/datafusion_python/SKILL.md +++ b/skills/datafusion_python/SKILL.md @@ -488,6 +488,31 @@ col("array_col")[0] # access array element (0-indexed) col("array_col")[1:3] # array slice (0-indexed) ``` +### Higher-Order Array Functions (Lambdas) + +Some array functions take a lambda that runs once per element. Pass a Python +`lambda` directly — its parameter names become the lambda parameters and its +return value becomes the body: + +```python +F.array_transform(col("a"), lambda v: v * 2) # map: [1,2,3] -> [2,4,6] +F.array_any_match(col("a"), lambda v: v > 3) # predicate: any element > 3 +``` + +Aliases: `list_transform` for `array_transform`; `any_match` / `list_any_match` +for `array_any_match`. + +For explicit parameter names, build the lambda by hand: + +```python +F.array_transform(col("a"), F.lambda_(["v"], F.lambda_var("v") * lit(2))) +``` + +Limitations: lambda expressions cannot be serialized (`Expr.to_bytes` / pickle +raise `Lambda not implemented`). SQL lambda syntax (`x -> x * 2`) needs the +DuckDB dialect (`SessionConfig().set("datafusion.sql_parser.dialect", "DuckDB")`); +the Python builder above is dialect-independent. + ## SQL-to-DataFrame Reference | SQL | DataFrame API | From 78df1c029a0bd12883b25f9c11fd04e9d6383d52 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Wed, 27 May 2026 09:54:45 -0400 Subject: [PATCH 2/2] test: fold lambda tests into pytest parameterization Combine the eight higher-order function result tests into a single parametrized test_higher_order_function_results, and the two to_lambda rejection tests into test_to_lambda_rejects_invalid_arg. Each case keeps a readable id via pytest.param. 
Co-Authored-By: Claude --- python/tests/test_lambda.py | 110 +++++++++++++++++++++--------------- 1 file changed, 66 insertions(+), 44 deletions(-) diff --git a/python/tests/test_lambda.py b/python/tests/test_lambda.py index e8546ad3c..c3eeb9a0c 100644 --- a/python/tests/test_lambda.py +++ b/python/tests/test_lambda.py @@ -31,42 +31,57 @@ def _column(df, expr, name): return df.select(expr.alias(name)).collect_column(name).to_pylist() -def test_array_transform_callable(df): - expr = f.array_transform(col("a"), lambda v: v * 2) - assert _column(df, expr, "d") == [[2, 4, 6], [8, 10]] - - -def test_array_transform_explicit_lambda(df): - transform = f.lambda_(["v"], f.lambda_var("v") * lit(2)) - expr = f.array_transform(col("a"), transform) - assert _column(df, expr, "d") == [[2, 4, 6], [8, 10]] - - -def test_array_transform_literal_body_is_coerced(df): - expr = f.array_transform(col("a"), lambda v: 0) - assert _column(df, expr, "z") == [[0, 0, 0], [0, 0]] - - -def test_list_transform_alias(df): - expr = f.list_transform(col("a"), lambda v: v + 1) - assert _column(df, expr, "d") == [[2, 3, 4], [5, 6]] - - -def test_array_any_match_callable(df): - expr = f.array_any_match(col("a"), lambda v: v > 3) - assert _column(df, expr, "m") == [False, True] - - -def test_array_any_match_explicit_lambda(df): - predicate = f.lambda_(["v"], f.lambda_var("v") > lit(2)) - expr = f.array_any_match(col("a"), predicate) - assert _column(df, expr, "m") == [True, True] - - -@pytest.mark.parametrize("alias", [f.any_match, f.list_any_match]) -def test_any_match_aliases(df, alias): - expr = alias(col("a"), lambda v: v > 4) - assert _column(df, expr, "m") == [False, True] +@pytest.mark.parametrize( + ("build_expr", "expected"), + [ + pytest.param( + lambda: f.array_transform(col("a"), lambda v: v * 2), + [[2, 4, 6], [8, 10]], + id="array_transform_callable", + ), + pytest.param( + lambda: f.array_transform( + col("a"), f.lambda_(["v"], f.lambda_var("v") * lit(2)) + ), + [[2, 4, 6], [8, 10]], + 
id="array_transform_explicit_lambda", + ), + pytest.param( + lambda: f.array_transform(col("a"), lambda v: 0), + [[0, 0, 0], [0, 0]], + id="array_transform_literal_body_is_coerced", + ), + pytest.param( + lambda: f.list_transform(col("a"), lambda v: v + 1), + [[2, 3, 4], [5, 6]], + id="list_transform_alias", + ), + pytest.param( + lambda: f.array_any_match(col("a"), lambda v: v > 3), + [False, True], + id="array_any_match_callable", + ), + pytest.param( + lambda: f.array_any_match( + col("a"), f.lambda_(["v"], f.lambda_var("v") > lit(2)) + ), + [True, True], + id="array_any_match_explicit_lambda", + ), + pytest.param( + lambda: f.any_match(col("a"), lambda v: v > 4), + [False, True], + id="any_match_alias", + ), + pytest.param( + lambda: f.list_any_match(col("a"), lambda v: v > 4), + [False, True], + id="list_any_match_alias", + ), + ], +) +def test_higher_order_function_results(df, build_expr, expected): + assert _column(df, build_expr(), "r") == expected def test_lambda_param_name_appears_in_plan(df): @@ -76,14 +91,21 @@ def test_lambda_param_name_appears_in_plan(df): assert "value" in expr.canonical_name() -def test_to_lambda_rejects_non_callable(): - with pytest.raises(TypeError, match="expected an Expr or callable"): - f.array_transform(col("a"), 42) - - -def test_to_lambda_rejects_zero_arg_callable(): - with pytest.raises(ValueError, match="at least one parameter"): - f.array_transform(col("a"), lambda: lit(1)) +@pytest.mark.parametrize( + ("arg", "exc_type", "match"), + [ + pytest.param(42, TypeError, "expected an Expr or callable", id="non_callable"), + pytest.param( + lambda: lit(1), + ValueError, + "at least one parameter", + id="zero_arg_callable", + ), + ], +) +def test_to_lambda_rejects_invalid_arg(arg, exc_type, match): + with pytest.raises(exc_type, match=match): + f.array_transform(col("a"), arg) def test_sql_lambda_requires_duckdb_dialect():