Some basic svd forward rules and tests #247

Merged
kshyatt merged 23 commits into main from ksh/svd_fwd
Jun 23, 2026

kshyatt commented Jun 8, 2026

Member

Definitely not optimized...

Comment thread src/pushforwards/svd.jl Outdated
kshyatt commented Jun 11, 2026

Member Author

Took another look at this. The Enzyme tests seem to be failing because of finite differences. For svd_full, for example, when size(A) == (17, 19), the finite-difference result for dU is all over the place compared to the result from the rule, in the section of U that is only present in the full case. I'll try to think of a nice way to handle this. I think it does not occur for Mooncake because Mooncake uses a different technique for checking against FD.

Comment thread src/pushforwards/svd.jl Outdated
github-actions Bot commented Jun 11, 2026

Contributor

Your PR no longer requires formatting changes. Thank you for your contribution!

Comment thread src/common/initialization.jl Outdated
Comment thread src/pushforwards/svd.jl Outdated
Comment thread src/pushforwards/svd.jl Outdated
Jutho commented Jun 15, 2026

Member

In the gauge fixing parts, we could use more views for the slicing ΔU₁[I] etc, but I did not want to push my luck on the GPU with views into views using CartesianIndices.

Jutho commented Jun 15, 2026

Member

Anyway, I will leave this aside now for a while and look at some other PRs. In principle, tests can now be expanded to cover the complex case and svd_compact of rectangular matrices as well (but as always, I would only consider double precision for finite difference comparisons).
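
To give a feel for why single precision is a poor fit for finite-difference checks, here is a minimal sketch. It is not taken from the test suite: the matrix `A`, the direction `dA`, the helper name `fd_error`, and the step size `cbrt(eps(T))` are all illustrative assumptions. It compares a central difference of the singular values against the standard first-order formula `dS[i] = real(U[:, i]' * dA * V[:, i])`, which holds for distinct singular values.

```julia
using LinearAlgebra

# Maximum mismatch between a central finite difference of the singular values
# and the first-order formula dS[i] = real(U[:, i]' * dA * V[:, i]).
# `fd_error` is a hypothetical helper written for this illustration only.
function fd_error(::Type{T}) where {T}
    A  = T[4 1 0; 1 3 1; 0 1 1]   # illustrative matrix with well-separated singular values
    dA = T[0 1 2; 1 0 1; 2 1 0]   # illustrative perturbation direction
    h  = cbrt(eps(T))             # common step-size heuristic for central differences
    F  = svd(A)
    dS_rule = real(diag(F.U' * dA * F.V))
    dS_fd   = (svdvals(A + h * dA) - svdvals(A - h * dA)) / (2h)
    return maximum(abs, dS_fd - dS_rule)
end

err64 = fd_error(Float64)
err32 = fd_error(Float32)
```

On a typical run `err32` should come out several orders of magnitude larger than `err64`, so a tolerance tight enough to catch a wrong rule in double precision would usually be too tight for single precision.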

@lkdvos lkdvos left a comment

Member

Left some final small comments, otherwise good to go?

Comment thread src/pushforwards/svd.jl Outdated
Comment on lines +23 to +24
hUᴴΔAV₁ = inv_safe.(transpose(S₁) .- S₁) .* project_hermitian(UᴴΔAV₁)
aUᴴΔAV₁ = inv_safe.(transpose(S₁) .+ S₁) .* project_antihermitian(UᴴΔAV₁)

Member

I think below only the sum and difference are actually used, we could use a kernel like

function _avgdiff!(A::AbstractArray, B::AbstractArray)
    axes(A) == axes(B) || throw(DimensionMismatch())
    @simd for I in eachindex(A, B)
        @inbounds begin
            a = A[I]
            b = B[I]
            A[I] = (a + b) / 2
            B[I] = b - a
        end
    end
    return A, B
end
to avoid the two extra allocations, but I'm also happy to just leave them as-is; it's hard to imagine this really making that much of a difference.
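
A quick check of the kernel proposed above on small CPU arrays, as a sketch. The kernel is repeated verbatim so the snippet runs on its own, and the inputs are made up for illustration. Note that the explicit scalar-indexing loop is CPU-oriented; GPU arrays would presumably need a broadcast-based variant.

```julia
# Kernel as proposed in the review comment, repeated so the snippet is self-contained.
function _avgdiff!(A::AbstractArray, B::AbstractArray)
    axes(A) == axes(B) || throw(DimensionMismatch())
    @simd for I in eachindex(A, B)
        @inbounds begin
            a = A[I]
            b = B[I]
            A[I] = (a + b) / 2  # the average overwrites A
            B[I] = b - a        # the difference overwrites B
        end
    end
    return A, B
end

# Illustrative inputs (not from the PR).
A = [1.0 2.0; 3.0 4.0]
B = [5.0 4.0; 3.0 2.0]
_avgdiff!(A, B)
# A now holds the elementwise average and B the elementwise difference,
# both computed in place without temporaries.
```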

Comment thread src/pushforwards/svd.jl Outdated
kshyatt commented Jun 15, 2026

Member Author

The GPU tests will probably fail until a new version is tagged at GPUArrays (JuliaGPU/GPUArrays.jl#738)

Comment thread test/testsuite/mooncake/svd.jl Outdated
Comment thread test/testsuite/mooncake/svd.jl Outdated
kshyatt commented Jun 17, 2026

Member Author

Ran it locally and it worked, let's hope!

S, dS = arrayify(USVᴴ[2], dUSVᴴ[2])
Vᴴ, dVᴴ = arrayify(USVᴴ[3], dUSVᴴ[3])
$f!(A, USVᴴ, Mooncake.primal(alg_dalg))
svd_pushforward!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ))

Member

This again works because svd_pushforward! doesn't actually need A, since A is destroyed at this point. Not sure if it is worth adding a comment.

Comment thread ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl Outdated
Comment on lines +216 to +217
# have to override this as methods are missing in GPUArrays for the various
# views of Diagonal of ΔA

Member

This is a bit of a confusing comment: what exactly is missing?
Might be useful to keep track of this, in case it gets fixed.

Member Author

It's again a situation of mul!(::Diagonal{T, CuVector{T}}, [horrific view of adjoint of view], CuArray) which GPUArrays cannot dispatch onto at all.

Comment thread ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl Outdated
Comment thread ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl Outdated
kshyatt commented Jun 19, 2026

Member Author

Of course when I ran this locally (5 times!!!) it passed every single time.

kshyatt commented Jun 19, 2026

Member Author

Some of these look like tolerance issues to me at least

Comment thread src/pushforwards/svd.jl
 if eltype(U) <: Complex && !iszerotangent(ΔU) && !iszerotangent(ΔVᴴ) # fix gauge for `gaugefix!` compatibility
     _, I = findmax(abs, U₁; dims = 1)
-    infinitesimal_phases = imag.(ΔU₁[I] ./ U₁[I])
+    infinitesimal_phases = imag.(ΔU₁[I] .* inv_safe.(U₁[I]))

Member

I would be surprised if that is needed or makes a difference. U₁[I] is the maximum element of every column of U₁. As U₁ is an isometric matrix, all of its columns have norm 1, and therefore the largest element needs to be at least 1/sqrt(m) in magnitude, with m the number of rows. Typically, it will be larger.
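
The bound follows from m · max|u|² ≥ Σ|u|² = 1 for each unit-norm column. A minimal numerical sketch, assuming a random complex isometry built from a QR factorization as a stand-in for U₁ (the sizes are arbitrary):

```julia
using LinearAlgebra

m, n = 6, 4
# Assumption: a random isometry (orthonormal columns) stands in for U₁.
U₁ = Matrix(qr(randn(ComplexF64, m, n)).Q)

# Largest-magnitude entry of each column, plus its CartesianIndex.
maxabs, I = findmax(abs, U₁; dims = 1)

# Each column has unit norm, so m * max|u|^2 >= sum|u|^2 = 1,
# hence max|u| >= 1/sqrt(m) (small slack for rounding).
bound_holds = all(maxabs .>= 1 / sqrt(m) - 1e-12)

# Indexing with I recovers exactly the entries the maxima came from.
same_entries = all(abs.(U₁[I]) .== maxabs)
```

So dividing by U₁[I] directly should never hit a near-zero denominator as long as U₁ really is isometric.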

Member Author

I would be quite surprised too, but I'm otherwise very confused about where the NaNs are emerging from (the ./transpose(S1) line has the same objection). I'll try stepping through the pushforward to see if I can find the culprit.

Jutho commented Jun 19, 2026

Member

I am also happy to take a look at the test failures.

kshyatt commented Jun 19, 2026

Member Author

> I am also happy to take a look at the test failures.

If you like! I'm about to get on a plane and am unsure whether it will have WiFi to push with if I find the problem.

Jutho commented Jun 21, 2026

Member

Ok, I located the problem, but I also started changing various other things, so this turned out a bit bigger than anticipated. The problem was that the initialization of the tangent in the frule is random, introducing NaNs, subnormals and other random stuff. But svd_pushforward only touches the diagonal of S and did not set the off-diagonal entries (in the case of svd_full) to zero.
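
A minimal sketch of this failure mode, under stated assumptions: a NaN-filled rectangular buffer stands in for the randomly initialized tangent of S, and the diagonal values are made up. It only illustrates the mechanism described above and is not the code from the PR.

```julia
using LinearAlgebra

# Assumption: a NaN-filled buffer stands in for a randomly initialized tangent.
m, n = 3, 5
dS = fill(NaN, m, n)

# Writing only the diagonal leaves the off-diagonal garbage untouched.
dS_diag_only = copy(dS)
dS_diag_only[diagind(dS_diag_only)] .= [0.1, 0.2, 0.3]
still_has_nan = any(isnan, dS_diag_only)

# Zeroing the whole buffer first makes the result independent of its prior contents.
dS_fixed = copy(dS)
fill!(dS_fixed, 0)
dS_fixed[diagind(dS_fixed)] .= [0.1, 0.2, 0.3]
clean = !any(isnan, dS_fixed)
```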

Jutho commented Jun 21, 2026

Member

I also tried to make some steps in the various implementations more consistent, and to simplify them a bit here and there. I introduced a has_equal_storage for simplifying and generalizing the detection of the input matrix A overlapping with some of the output arguments.

Comment thread ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl Outdated
Comment thread src/common/utility.jl
Comment thread src/pushforwards/eig.jl
 ∂K .*= inv_safe.(transpose(diagview(D)) .- diagview(D), degeneracy_atol)
-mul!(ΔV, V, ∂K, 1, 0)
+mul!(ΔV, V, ∂K)
 if eltype(V) <: Complex # fix gauge for `gaugefix!` compatibility

Member

Are we enforcing that the primal computation is always gauge fixing, and is that relevant for this?

Member

No, we are not enforcing this, but it is the default. This is relevant to the extent that if gaugefix = false, there is no way to compute the pushforward at all. So we could throw a warning. There is some asymmetry between the pullbacks and the pushforwards in that the latter don't throw warnings for cases where there is gauge freedom.

But the way to deal with this is also very different between the two.
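
For readers less familiar with the gauge freedom being discussed, a minimal sketch using only the `LinearAlgebra` standard library (the matrix and phases are random and purely illustrative): rescaling each eigenvector by an arbitrary unit-modulus phase yields another valid eigendecomposition, so without a fixed convention the tangent of V is not uniquely defined. The sketch shows only the freedom itself, not the package's `gaugefix!` convention.

```julia
using LinearAlgebra

A = randn(ComplexF64, 4, 4)
D, V = eigen(A)          # eigenvalues D (vector) and eigenvectors V (columns)

# Rescale every eigenvector by an arbitrary phase of modulus one.
phases = cis.(2π .* rand(4))
V2 = V * Diagonal(phases)

# Both V and V2 satisfy the eigenvalue equation A * V = V * Diagonal(D).
ok_original = A * V ≈ V * Diagonal(D)
ok_rescaled = A * V2 ≈ V2 * Diagonal(D)
```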

@lkdvos lkdvos left a comment

Member

Minor detail, but otherwise this looks ready to me?

Comment thread ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl Outdated
Co-authored-by: Lukas Devos <ldevos98@gmail.com>
@kshyatt kshyatt enabled auto-merge (squash) June 23, 2026 23:55
@kshyatt kshyatt disabled auto-merge June 23, 2026 23:55
@kshyatt kshyatt merged commit 0dfcb52 into main Jun 23, 2026
36 checks passed
@kshyatt kshyatt deleted the ksh/svd_fwd branch June 23, 2026 23:55
Comment thread src/common/utility.jl
 end
 has_equal_storage(A::AbstractMatrix, B::AbstractVector) = false
-has_equal_storage(A::AbstractVector, B::AbstractMatrix) = false
\ No newline at end of file
+has_equal_storage(A::AbstractVector, B::AbstractMatrix) = false

Member

So this was my formatting issue? I don't actually see any change? Is it some invisible character?

Member Author

I think it's that we were missing a \n after the false

3 participants