Skip to content
56 changes: 56 additions & 0 deletions ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback!
using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_vals_pullback!, eigh_vals_pullback!
using MatrixAlgebraKit: eig_pushforward!, eigh_pushforward!, eig_vals_pushforward!, eigh_vals_pushforward!
using MatrixAlgebraKit: svd_pullback!, svd_vals_pullback!
using MatrixAlgebraKit: svd_pushforward!, svd_vals_pushforward!
using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback!
using MatrixAlgebraKit: left_polar_pushforward!, right_polar_pushforward!
using Enzyme
Expand Down Expand Up @@ -264,6 +265,34 @@ for f in (:svd_compact!, :svd_full!)
!isa(USVᴴ, Const) && make_zero!(USVᴴ.dval)
return (nothing, nothing, nothing)
end
# Forward-mode (width 1) Enzyme rule for the in-place SVD `$f`.
#
# Runs the primal factorization into `USVᴴ.val`. When `A` is not `Const`, its shadow
# `A.dval` is pushed forward into `USVᴴ.dval` with `svd_pushforward!`, and `A.dval` is
# then zeroed. The return value follows `needs_primal(config)` / `needs_shadow(config)`.
function EnzymeRules.forward(
        config::EnzymeRules.FwdConfigWidth{1},
        func::Const{typeof($f)},
        ::Type{RT},
        A::Annotation{TA},
        USVᴴ::Annotation,
        alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm},
    ) where {RT, TA}
    $f(A.val, USVᴴ.val, alg.val)
    if !isa(A, Const)
        # A `Const` annotation has no `dval` field, so the output shadow is only
        # touched when it actually exists.
        if !isa(USVᴴ, Const)
            # `f` is the loop *Symbol*, so it must be compared against a Symbol;
            # comparing against the function object is always `false`.
            if $(f === :svd_compact!)
                make_zero!(USVᴴ.dval[2].diag)
            else
                make_zero!(USVᴴ.dval[2])
            end
            svd_pushforward!(A.dval, A.val, USVᴴ.val, USVᴴ.dval)
        end
        make_zero!(A.dval)
    end
    if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
        return USVᴴ
    elseif EnzymeRules.needs_primal(config)
        return USVᴴ.val
    elseif EnzymeRules.needs_shadow(config)
        return USVᴴ.dval
    else
        return nothing
    end
end
end
end

Expand Down Expand Up @@ -502,5 +531,32 @@ function EnzymeRules.reverse(
!isa(S, Const) && !A_is_arg && make_zero!(S.dval)
return (nothing, nothing, nothing)
end
# Forward-mode (width 1) Enzyme rule for `svd_vals!`.
#
# Computes a compact SVD of `A.val`, copies the singular values into `S.val` and, when
# both `A` and `S` carry shadows, pushes `A.dval` forward into `S.dval` with
# `svd_vals_pushforward!`. `A_is_arg` flags the case where `A` is a `Diagonal` whose
# shadow diagonal is the very same array as `S.dval`; the tangent is then built in a
# temporary and copied back, and `A.dval` is not zeroed afterwards.
function EnzymeRules.forward(
        config::EnzymeRules.FwdConfigWidth{1},
        func::Const{typeof(svd_vals!)},
        ::Type{RT},
        A::Annotation{TA},
        S::Annotation,
        alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm},
    ) where {RT, TA}
    # `Const` annotations have no `dval`; establish which shadows exist up front so
    # that no branch below dereferences a missing field.
    A_active = !isa(A, Const)
    S_active = !isa(S, Const)
    A_is_arg = A_active && S_active && TA <: Diagonal && diagview(A.dval) === S.dval
    U, S_, Vᴴ = svd_compact!(A.val, alg.val)
    if A_active && S_active
        if A_is_arg
            ΔS = make_zero(S.dval)
        else
            # `svd_pushforward!` only writes the leading `r` entries of the tangent,
            # so clear stale contents first (mirrors the `svd_compact!`/`svd_full!` rule).
            make_zero!(S.dval)
            ΔS = S.dval
        end
        svd_vals_pushforward!(A.dval, A.val, (U, Diagonal(diagview(S_)), Vᴴ), ΔS)
        A_is_arg && (S.dval .= ΔS)
    end
    A_active && !A_is_arg && make_zero!(A.dval)
    copyto!(S.val, diagview(S_))
    if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
        return S
    elseif EnzymeRules.needs_primal(config)
        return S.val
    elseif EnzymeRules.needs_shadow(config)
        return S.dval
    else
        return nothing
    end
end

end
52 changes: 48 additions & 4 deletions ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ using MatrixAlgebraKit: eig_trunc_pullback!, eigh_trunc_pullback!, eigh_vals_pul
using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback!
using MatrixAlgebraKit: left_polar_pushforward!, right_polar_pushforward!
using MatrixAlgebraKit: svd_pullback!, svd_trunc_pullback!, svd_vals_pullback!
using MatrixAlgebraKit: svd_pushforward!, svd_trunc_pushforward!, svd_vals_pushforward!
using MatrixAlgebraKit: TruncatedAlgorithm
using LinearAlgebra

Expand Down Expand Up @@ -538,7 +539,7 @@ for (f!, f) in (
(:svd_compact!, :svd_compact),
)
@eval begin
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Tuple{<:Any, <:Any, <:Any}, MatrixAlgebraKit.AbstractAlgorithm}
@is_primitive Mooncake.DefaultCtx Tuple{typeof($f!), Any, Tuple{<:Any, <:Any, <:Any}, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.rrule!!(::CoDual{typeof($f!)}, A_dA::CoDual, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual)
A, dA = arrayify(A_dA)
USVᴴ = Mooncake.primal(USVᴴ_dUSVᴴ)
Expand All @@ -562,7 +563,18 @@ for (f!, f) in (
end
return USVᴴ_dUSVᴴ, svd_adjoint
end
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm}
# Forward rule for the in-place `$f!`: run the primal factorization into the supplied
# output tuple, then fill the matching tangents with `svd_pushforward!`. The incoming
# `Dual` for the outputs is returned unchanged (its contents are updated in place).
function Mooncake.frule!!(::Dual{typeof($f!)}, A_dA::Dual, USVᴴ_dUSVᴴ::Dual, alg_dalg::Dual)
    factors = Mooncake.primal(USVᴴ_dUSVᴴ)
    dfactors = Mooncake.tangent(USVᴴ_dUSVᴴ)
    A, dA = arrayify(A_dA)
    # pair each factor with its tangent as plain arrays
    U, dU = arrayify(factors[1], dfactors[1])
    S, dS = arrayify(factors[2], dfactors[2])
    Vᴴ, dVᴴ = arrayify(factors[3], dfactors[3])
    alg = Mooncake.primal(alg_dalg)
    $f!(A, factors, alg)
    svd_pushforward!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ))
    return USVᴴ_dUSVᴴ
end
@is_primitive Mooncake.DefaultCtx Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual)
A, dA = arrayify(A_dA)
USVᴴ = $f(A, Mooncake.primal(alg_dalg))
Expand All @@ -585,10 +597,23 @@ for (f!, f) in (
end
return USVᴴ_codual, svd_adjoint
end
# Forward rule for the out-of-place `$f`: compute the factorization, attach a fresh zero
# tangent to it, and let `svd_pushforward!` fill that tangent from `dA`.
function Mooncake.frule!!(::Dual{typeof($f)}, A_dA::Dual, alg_dalg::Dual)
    A, dA = arrayify(A_dA)
    factors = $f(A, Mooncake.primal(alg_dalg))
    result = Dual(factors, Mooncake.zero_tangent(factors))
    U_, S_, Vᴴ_ = Mooncake.primal(result)
    dU_, dS_, dVᴴ_ = Mooncake.tangent(result)
    # array views of each factor together with its tangent
    U, dU = arrayify(U_, dU_)
    S, dS = arrayify(S_, dS_)
    Vᴴ, dVᴴ = arrayify(Vᴴ_, dVᴴ_)
    svd_pushforward!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ))
    return result
end
end
end

@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_vals!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
@is_primitive Mooncake.DefaultCtx Tuple{typeof(svd_vals!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.rrule!!(::CoDual{typeof(svd_vals!)}, A_dA::CoDual, S_dS::CoDual, alg_dalg::CoDual)
# compute primal
A, dA = arrayify(A_dA)
Expand All @@ -604,8 +629,17 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals!)}, A_dA::CoDual, S_dS::CoDua
end
return S_dS, svd_vals_adjoint
end
# Forward rule for `svd_vals!`: the singular values come from an out-of-place compact
# SVD of `A`, are copied into `S`, and their tangent is written into `dS`.
function Mooncake.frule!!(::Dual{typeof(svd_vals!)}, A_dA::Dual, S_dS::Dual, alg_dalg::Dual)
    A, dA = arrayify(A_dA)
    S, dS = arrayify(S_dS)
    alg = Mooncake.primal(alg_dalg)
    # primal
    factors = svd_compact(A, alg)
    copy!(S, diagview(factors[2]))
    # tangent
    svd_vals_pushforward!(dA, A, factors, dS)
    return S_dS
end

@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_vals), Any, MatrixAlgebraKit.AbstractAlgorithm}
@is_primitive Mooncake.DefaultCtx Tuple{typeof(svd_vals), Any, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.rrule!!(::CoDual{typeof(svd_vals)}, A_dA::CoDual, alg_dalg::CoDual)
# compute primal
A, dA = arrayify(A_dA)
Expand All @@ -624,6 +658,16 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals)}, A_dA::CoDual, alg_dalg::Co
end
return S_codual, svd_vals_adjoint
end
# Forward rule for the out-of-place `svd_vals`: take the diagonal view of the compact
# SVD's `S` factor as the primal, pair it with a zero tangent, and fill that tangent
# with `svd_vals_pushforward!`.
function Mooncake.frule!!(::Dual{typeof(svd_vals)}, A_dA::Dual, alg_dalg::Dual)
    A, dA = arrayify(A_dA)
    alg = Mooncake.primal(alg_dalg)
    factors = svd_compact(A, alg)
    svals = diagview(factors[2])
    svals_dual = Dual(svals, Mooncake.zero_tangent(svals))
    # only the tangent array is needed here; the primal array is `svals` itself
    _, dS = arrayify(svals_dual)
    svd_vals_pushforward!(dA, A, factors, dS)
    return svals_dual
end

@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual)
Expand Down
3 changes: 2 additions & 1 deletion src/MatrixAlgebraKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using LinearAlgebra: mul!, rmul!, lmul!, adjoint!, rdiv!, ldiv!
using LinearAlgebra: sylvester, lu!, diagm
using LinearAlgebra: isposdef, issymmetric
using LinearAlgebra: Diagonal, Hermitian, diag, diagind, isdiag
using LinearAlgebra: UpperTriangular, LowerTriangular
using LinearAlgebra: UpperTriangular, LowerTriangular, UniformScaling
using LinearAlgebra: BlasFloat, BlasReal, BlasComplex, BlasInt

export isisometric, isunitary, ishermitian, isantihermitian
Expand Down Expand Up @@ -132,6 +132,7 @@ include("pullbacks/polar.jl")
include("pushforwards/polar.jl")
include("pushforwards/eig.jl")
include("pushforwards/eigh.jl")
include("pushforwards/svd.jl")

include("precompile.jl")

Expand Down
7 changes: 6 additions & 1 deletion src/pushforwards/eig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@ function eig_pushforward!(
end
if !iszerotangent(ΔV)
∂K .*= inv_safe.(transpose(diagview(D)) .- diagview(D), degeneracy_atol)
mul!(ΔV, V, ∂K, 1, 0)
mul!(ΔV, V, ∂K)
if eltype(V) <: Complex # fix gauge for `gaugefix!` compatibility
_, I = findmax(abs, V; dims = 1)
infinitesimal_phases = imag.(ΔV[I] ./ V[I])
ΔV .-= im .* V .* infinitesimal_phases
end
end
return ΔDV
end
Expand Down
5 changes: 5 additions & 0 deletions src/pushforwards/eigh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ function eigh_pushforward!(
if !iszerotangent(ΔV)
∂K .*= inv_safe.(transpose(diagview(D)) .- diagview(D), degeneracy_atol)
ΔV = mul!(ΔV, V, ∂K)
if eltype(V) <: Complex # fix gauge for `gaugefix!` compatibility
_, I = findmax(abs, V; dims = 1)
infinitesimal_phases = imag.(ΔV[I] ./ V[I])
ΔV .-= im .* V .* infinitesimal_phases
end
end
return (ΔD, ΔV)
end
Expand Down
77 changes: 77 additions & 0 deletions src/pushforwards/svd.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# svd_pushforward!(ΔA, A, USVᴴ, ΔUSVᴴ, ind = Colon(); rank_atol, kwargs...)
#
# Forward-mode rule for the SVD: given the tangent `ΔA` of the input and the computed
# factors `USVᴴ = (U, S, Vᴴ)`, write the tangents of the factors into the containers in
# `ΔUSVᴴ = (ΔU, ΔS, ΔVᴴ)` (in place) and return `(ΔU, ΔS, ΔVᴴ)`.
#
# - Any of `ΔU`, `ΔS`, `ΔVᴴ` for which `iszerotangent` holds is skipped.
# - Only the leading `r` singular triplets are treated as determined, where
#   `r = svd_rank(S; rank_atol)` (`svd_rank` is defined elsewhere — presumably the
#   numerical rank; TODO confirm).
# - Throws `DimensionMismatch` when `size(ΔA)` differs from `(size(U, 1), size(Vᴴ, 2))`.
# - NOTE(review): `ind` and `kwargs...` are accepted but not used anywhere in this body,
#   so e.g. a `degeneracy_atol` keyword passed by a caller is silently ignored here.
function svd_pushforward!(ΔA, A, USVᴴ, ΔUSVᴴ, ind = Colon(); rank_atol = default_pullback_rank_atol(A), kwargs...)
U, Smat, Vᴴ = USVᴴ
m, n = size(U, 1), size(Vᴴ, 2)
(m, n) == size(ΔA) || throw(DimensionMismatch("size of ΔA ($(size(ΔA))) does not match size of U*S*Vᴴ ($m, $n)"))
# NOTE(review): `minmn` is not referenced again below.
minmn = min(m, n)
S = diagview(Smat)
ΔU, ΔS, ΔVᴴ = ΔUSVᴴ
r = svd_rank(S; rank_atol)

# leading-`r` blocks of the factors (views, no copies)
U₁ = view(U, :, 1:r)
S₁ = view(S, 1:r)
V₁ᴴ = view(Vᴴ, 1:r, :)

# compact region
V₁ = adjoint(V₁ᴴ)
ΔAV₁ = ΔA * V₁
UᴴΔAV₁ = U₁' * ΔAV₁
if !iszerotangent(ΔS)
# only the first `r` diagonal entries of ΔS are written; the rest are left untouched
ΔS₁ = view(diagview(ΔS), 1:r)
ΔS₁ .= real.(diagview(UᴴΔAV₁))
end
if !iszerotangent(ΔU) || !iszerotangent(ΔVᴴ)
# entry (i, j) of `transpose(S₁) .- S₁` is S₁[j] - S₁[i]; `.+` gives S₁[j] + S₁[i].
# `inv_safe` is defined elsewhere — presumably an inverse that guards against
# (near-)zero denominators; TODO confirm.
hUᴴΔAV₁ = inv_safe.(transpose(S₁) .- S₁) .* project_hermitian(UᴴΔAV₁)
Comment thread
kshyatt marked this conversation as resolved.
aUᴴΔAV₁ = inv_safe.(transpose(S₁) .+ S₁) .* project_antihermitian(UᴴΔAV₁)
Comment on lines +23 to +24

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think below only the sum and difference are actually used, we could use a kernel like

function _avgdiff!(A::AbstractArray, B::AbstractArray)
axes(A) == axes(B) || throw(DimensionMismatch())
@simd for I in eachindex(A, B)
@inbounds begin
a = A[I]
b = B[I]
A[I] = (a + b) / 2
B[I] = b - a
end
end
return A, B
end
to avoid the two extra allocations, but I'm also happy to just leave them as-is, it's hard to imagine this really making that huge of a difference

if !iszerotangent(ΔU)
ΔU₁ = view(ΔU, :, 1:r)
K̇ = hUᴴΔAV₁ + aUᴴΔAV₁
mul!(ΔU₁, U₁, K̇)
if m > r
# overwrite ΔAV₁ with ΔAV₁ - U₁ * (U₁' * ΔAV₁), then divide column j by S₁[j]
ΔAV₁ = mul!(ΔAV₁, U₁, UᴴΔAV₁, -1, 1)
ΔU₁ .+= ΔAV₁ ./ transpose(S₁)
end
if size(U, 2) > r # these columns of U are undetermined, but U' * U̇ should be antihermitian
U₂ = view(U, :, (r + 1):size(U, 2))
ΔU₁ᴴU₂ = ΔU₁' * U₂
ΔU₂ = view(ΔU, :, (r + 1):size(U, 2))
# ΔU₂ = -U₁ * (ΔU₁' * U₂)
mul!(ΔU₂, U₁, ΔU₁ᴴU₂, -1, 0)
end
end
if !iszerotangent(ΔVᴴ)
ΔV₁ᴴ = view(ΔVᴴ, 1:r, :)
Ṁ = hUᴴΔAV₁ - aUᴴΔAV₁
mul!(ΔV₁ᴴ, Ṁ', V₁ᴴ)
if n > r
# UᴴΔA₁ = U₁' * ΔA - (U₁' * ΔA * V₁) * V₁ᴴ, then divide row i by S₁[i]
UᴴΔA₁ = U₁' * ΔA
UᴴΔA₁ = mul!(UᴴΔA₁, UᴴΔAV₁, V₁ᴴ, -1, 1)
ΔV₁ᴴ .+= S₁ .\ UᴴΔA₁
end
if size(Vᴴ, 1) > r # these rows of Vᴴ are undetermined, but V * V̇ should be antihermitian
V₂ᴴ = view(Vᴴ, (r + 1):size(Vᴴ, 1), :)
V₂ᴴΔV₁ = V₂ᴴ * ΔV₁ᴴ'
ΔV₂ᴴ = view(ΔVᴴ, (r + 1):size(Vᴴ, 1), :)
# ΔV₂ᴴ = -(V₂ᴴ * ΔV₁ᴴ') * V₁ᴴ
mul!(ΔV₂ᴴ, V₂ᴴΔV₁, V₁ᴴ, -1, 0)
end
end
# The gauge fix runs only for complex factors and only when both ΔU and ΔVᴴ were
# computed above (it reads and updates ΔU₁ and ΔV₁ᴴ from those branches).
if eltype(U) <: Complex && !iszerotangent(ΔU) && !iszerotangent(ΔVᴴ) # fix gauge for `gaugefix!` compatibility
# `I` indexes the largest-magnitude entry of each column of U₁
_, I = findmax(abs, U₁; dims = 1)
infinitesimal_phases = imag.(ΔU₁[I] ./ U₁[I])
ΔU₁ .-= im .* U₁ .* infinitesimal_phases
ΔV₁ᴴ .+= im .* transpose(infinitesimal_phases) .* V₁ᴴ
end
end
return (ΔU, ΔS, ΔVᴴ)
end

# svd_trunc_pushforward!(ΔA, A, USVᴴ, ΔUSVᴴ, ind; rank_atol, kwargs...)
#
# Placeholder for the forward-mode rule of the truncated SVD. It is not implemented yet,
# so calling it raises an error instead of silently returning `nothing` (which a caller
# could mistake for a computed tangent). This follows the review suggestion to "either
# not define this or explicitly error()".
function svd_trunc_pushforward!(ΔA, A, USVᴴ, ΔUSVᴴ, ind; rank_atol = default_pullback_rank_atol(A), kwargs...)
    # TODO: implement the pushforward for the truncated SVD
    return error("svd_trunc_pushforward! is not implemented yet")
end

# svd_vals_pushforward!(ΔA, A, USVᴴ, ΔS, ind = Colon(); rank_atol, degeneracy_atol)
#
# Singular-values-only pushforward: wraps `ΔS` with `diagonal`, passes `nothing` for the
# `U` and `Vᴴ` tangents, and forwards everything to `svd_pushforward!`, whose result is
# returned.
function svd_vals_pushforward!(
        ΔA, A, USVᴴ, ΔS, ind = Colon();
        rank_atol::Real = default_pullback_rank_atol(USVᴴ[2]),
        degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2])
    )
    tangents = (nothing, diagonal(ΔS), nothing)
    return svd_pushforward!(
        ΔA, A, USVᴴ, tangents, ind;
        rank_atol = rank_atol, degeneracy_atol = degeneracy_atol
    )
end
4 changes: 2 additions & 2 deletions test/enzyme/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ using Test
using LinearAlgebra: Diagonal
using CUDA, AMDGPU

BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI
BLASFloats = (Float64, ComplexF64) # full suite is too expensive on CI
GenericFloats = ()
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
using .TestSuite
Expand All @@ -16,6 +16,6 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
if !is_buildkite
TestSuite.test_enzyme_svd(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
AT = Diagonal{T, Vector{T}}
m == n && TestSuite.test_enzyme_svd(AT, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T))
m == n && TestSuite.test_enzyme_svd(AT, m; atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T))
end
end
Loading
Loading