Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 7 additions & 9 deletions ext/TensorOperationsBumperExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,14 @@ function TensorOperations._butensor(src, ex...)
buf_sym = gensym("buffer")

# TODO: there is no check for doubled tensor kwargs
newex = quote
$buf_sym = $(Expr(:call, GlobalRef(Bumper, :default_buffer)))
$(
Expr(
:macrocall, GlobalRef(TensorOperations, Symbol("@tensor")),
src, :(allocator = $buf_sym), ex...
)
return Expr(
:block,
Expr(:(=), buf_sym, Expr(:call, GlobalRef(Bumper, :default_buffer))),
Expr(
:macrocall, GlobalRef(TensorOperations, Symbol("@tensor")),
src, :(allocator = $buf_sym), ex...
)
end
return Base.remove_linenums!(newex)
)
end

end
2 changes: 1 addition & 1 deletion src/indexnotation/contractiontrees.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ function insertcontractiontrees!(
end
)
end
push!(postexprs, removelinenumbernode(costcompareex))
push!(postexprs, removeinternallinenumbernodes(costcompareex))
return treeex
end

Expand Down
3 changes: 2 additions & 1 deletion src/indexnotation/parser.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ mutable struct TensorParser
contractiontreebuilder = defaulttreebuilder
contractiontreesorter = defaulttreesorter
contractioncostcheck = nothing
postprocessors = [_flatten, removelinenumbernode, addtensoroperations]
postprocessors = [_flatten, addtensoroperations]
return new(
preprocessors,
contractiontreebuilder, contractiontreesorter, contractioncostcheck,
Expand All @@ -34,6 +34,7 @@ function (parser::TensorParser)(ex::Expr)
for p in parser.postprocessors
ex = p(ex)::Expr
end
ex = removeinternallinenumbernodes(ex)::Expr

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure why this is not just added to the list of postprocessors?

return ex
end

Expand Down
35 changes: 35 additions & 0 deletions src/indexnotation/postprocessors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,45 @@ function _flatten(ex)
end
end

# package source directory (with trailing separator), used to recognize `LineNumberNode`s that
# point into the parser's own `quote` blocks rather than into user code.
const _PARSER_SRCDIR = joinpath(dirname(@__DIR__), "")

_isinternallinenumber(@nospecialize(x)) =
x isa LineNumberNode && startswith(String(x.file), _PARSER_SRCDIR)

"""
removeinternallinenumbernodes(ex)

Remove all `LineNumberNode`s that point into the TensorOperations source tree, i.e. the ones
introduced by the parser's own `quote` blocks. `LineNumberNode`s originating from user code are
kept, so that the generated code remains attributable to the user's source lines (e.g. for code
coverage).
"""
function removeinternallinenumbernodes(ex)
if isexpr(ex, :block)
# within a block, `LineNumberNode`s are statement markers: drop the internal ones
args = Any[removeinternallinenumbernodes(e) for e in ex.args
if !_isinternallinenumber(e)]
return Expr(:block, args...)
elseif isa(ex, Expr)
# elsewhere (e.g. the mandatory 2nd argument of a `:macrocall`) a `LineNumberNode` may
# be structurally required, so keep all positions and only recurse into nested blocks
return Expr(ex.head, Any[removeinternallinenumbernodes(e) for e in ex.args]...)
else
return ex
end
end

"""
removelinenumbernode(ex)

Remove all `LineNumberNode`s from an expression.

!!! note
Kept for backwards compatibility. The parser now uses
[`removeinternallinenumbernodes`](@ref), which preserves user `LineNumberNode`s so that
generated code stays attributable to the user's source lines (e.g. for code coverage).
"""
function removelinenumbernode(ex)
if isexpr(ex, :block)
Expand Down
15 changes: 15 additions & 0 deletions test/butensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,21 @@
end

using Bumper
@testset "@butensor preserves user line numbers (issue #280)" begin
# `@butensor` wraps the block in an inner `@tensor`; make sure it does not strip the user's
# line numbers, and does not leak TensorOperations-internal ones.
pkgsrc = dirname(pathof(TensorOperations))
lnns = LineNumberNode[]
collect_lnns(x) = x isa LineNumberNode ? push!(lnns, x) :
x isa Expr && foreach(collect_lnns, x.args)
collect_lnns(@macroexpand @butensor begin
T[a, b] := X[a, c] * Y[c, b]
Z[a, b] := T[a, c] * W[c, b]
end)
@test !any(l -> startswith(String(l.file), pkgsrc), lnns)
@test count(l -> String(l.file) == @__FILE__, lnns) >= 2
end

@testset "Bumper tests with eltype $T" for T in (Float32, ComplexF64)
D1, D2, D3 = 30, 40, 20
d1, d2 = 2, 3
Expand Down
48 changes: 48 additions & 0 deletions test/macro_kwargs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,54 @@ end
end
end

# https://github.com/QuantumKitHub/TensorOperations.jl/issues/280: the generated code must keep
# the user's `LineNumberNode`s (so `@tensor` lines show up in code coverage) while dropping the
# parser's own internal ones (which would otherwise pollute the package's coverage).
@testset "line numbers (issue #280)" begin
collectlinenumbernodes(ex, acc = LineNumberNode[]) =
(ex isa LineNumberNode ? push!(acc, ex) :
ex isa Expr && foreach(e -> collectlinenumbernodes(e, acc), ex.args); acc)
pkgsrc = dirname(pathof(TensorOperations))
pkglnns(lnns) = filter(l -> startswith(String(l.file), pkgsrc), lnns)
userlines(lnns) = sort!(unique!([l.line for l in lnns if String(l.file) == @__FILE__]))

@testset "no internal LineNumberNodes leak into generated code" begin
# covers the scalar, dst-reuse and checkpoint `quote` paths in the parser
exprs = [
@macroexpand(@tensor T[a, b] := A[a, c] * B[c, b]),
@macroexpand(@tensor R[a, b] := A[a, c] * B[c, d] * C[d, e] * E[e, f] * F[f, b]),
@macroexpand(@tensor s = X[a, b] * Y[a, b]),
@macroexpand(@tensoropt R[a, b] := A[a, c] * B[c, d] * C[d, e] * E[e, b]),
@macroexpand(@tensor allocator = alloc R[a, b] := A[a, c] * B[c, d] * C[d, b]),
@macroexpand(@tensor costcheck = warn R[a, b] := A[a, c] * B[c, d] * C[d, b]),
@macroexpand(@tensor contractcheck = true R[a, b] := A[a, c] * B[c, b]),
]
for ex in exprs
@test isempty(pkglnns(collectlinenumbernodes(ex)))
end
end

@testset "user LineNumberNodes are preserved per statement" begin
# multi-statement block, including a nested contraction whose intermediate is reused
block = @macroexpand @tensor begin
T[a, e] := A[a, c] * B[c, d] * C[d, e]
D[a, b] := T[a, e] * E[e, b]
s = D[a, b] * F[a, b]
end
lnns = collectlinenumbernodes(block)
@test isempty(pkglnns(lnns))
@test length(userlines(lnns)) >= 3 # one distinct user line per statement

optblock = @macroexpand @tensoropt begin
T[a, e] := A[a, c] * B[c, d] * C[d, e]
D[a, b] := T[a, e] * E[e, b]
end
optlnns = collectlinenumbernodes(optblock)
@test isempty(pkglnns(optlnns))
@test length(userlines(optlnns)) >= 2
end
end

@testset "opt" begin
A = randn(5, 5, 5, 5)
B = randn(5, 5, 5)
Expand Down
Loading