ITensors.jl
[ITensors] [BUG] AD numerical error when calling `apply` on MPS/MPO multiple times
Description of bug
Taking derivatives of functions that apply gates to an MPS/MPO with `apply` gives the wrong result for the derivative when `apply` is called multiple times.
Minimal code demonstrating the bug or unexpected behavior
Minimal runnable code
```julia
using ITensors
using Zygote

function main(; n, θ, p)
  s = siteinds("S=1/2", n)
  ψ₀ₘₚₛ = MPS(s, "↑")
  # Dense ITensor version of the same state, to compare against the MPS version
  ψ₀ = contract(ψ₀ₘₚₛ)
  # A single gate θ * Z acting on site 1
  U(θ) = [θ * op("Z", s, 1)]

  # f: calls `apply` p times, once per gate application
  function f(θ, ψ)
    ψθ = ψ
    Uθ = U(θ)
    for _ in 1:p
      ψθ = apply(Uθ, ψθ)
    end
    return inner(ψ, ψθ)
  end

  # g: concatenates p copies of the gate and calls `apply` only once
  function g(θ, ψ)
    Uθ = U(θ)
    Utot = Uθ
    for _ in 2:p
      Utot = [Utot; Uθ]
    end
    ψθ = apply(Utot, ψ)
    return inner(ψ, ψθ)
  end

  f_itensor(θ) = f(θ, ψ₀)
  f_mps(θ) = f(θ, ψ₀ₘₚₛ)
  g_itensor(θ) = g(θ, ψ₀)
  g_mps(θ) = g(θ, ψ₀ₘₚₛ)

  # Compare values and derivatives against the exact results θ^p and p * θ^(p - 1)
  @show f_itensor(θ), θ^p
  @show f_mps(θ), θ^p
  @show f_itensor'(θ), p * θ^(p - 1)
  @show f_mps'(θ), p * θ^(p - 1)
  @show g_itensor(θ), θ^p
  @show g_mps(θ), θ^p
  @show g_itensor'(θ), p * θ^(p - 1)
  @show g_mps'(θ), p * θ^(p - 1)
  return nothing
end
```
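As a dependency-free cross-check of the reference values `θ^p` and `p * θ^(p - 1)`, the single-site case (`n=1`) can be reproduced with dense 2×2 matrices. This is a sketch assuming that `op("Z", s, 1)` is the Pauli Z matrix and that `"↑"` corresponds to the basis vector `[1, 0]`; the names `f_dense`, `g_dense` and `central_diff` are hypothetical and not part of ITensors.

```julia
using LinearAlgebra

# Assumed dense equivalents of op("Z") and the "↑" state for one S=1/2 site
const Z = [1.0 0.0; 0.0 -1.0]
const up = [1.0, 0.0]

# The gate θ * Z as a plain matrix
U(θ) = θ * Z

# Analogue of f: apply the gate p times, one matrix-vector product per iteration
function f_dense(θ, p; ψ=up)
  ψθ = ψ
  for _ in 1:p
    ψθ = U(θ) * ψθ
  end
  return dot(ψ, ψθ)
end

# Analogue of g: multiply p copies of the gate together, then apply the product once
function g_dense(θ, p; ψ=up)
  Utot = prod(fill(U(θ), p))
  return dot(ψ, Utot * ψ)
end

# Central finite difference for a scalar function of a scalar
central_diff(f, x; h=1e-6) = (f(x + h) - f(x - h)) / (2h)
```

With `θ = 3.0` and `p = 2`, `f_dense` and `g_dense` both return `9.0`, and `central_diff(θ -> f_dense(θ, 2), 3.0)` is approximately `6.0`, which is the derivative `f_mps'` is expected to return.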
This works fine when `apply` is called once (`p=1`):
```julia
julia> main(; n=1, θ=3.0, p=1)
(f_itensor(θ), θ ^ p) = (3.0, 3.0)
(f_mps(θ), θ ^ p) = (3.0, 3.0)
((f_itensor')(θ), p * θ ^ (p - 1)) = (1.0, 1.0)
((f_mps')(θ), p * θ ^ (p - 1)) = (1.0, 1.0)
(g_itensor(θ), θ ^ p) = (3.0, 3.0)
(g_mps(θ), θ ^ p) = (3.0, 3.0)
((g_itensor')(θ), p * θ ^ (p - 1)) = (1.0, 1.0)
((g_mps')(θ), p * θ ^ (p - 1)) = (1.0, 1.0)
```
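but when `apply` is called more than once (`p=2`), the derivative of `f_mps` is wrong (4.0 instead of 6.0), while `f_itensor` and the single-`apply` versions `g_itensor` and `g_mps` are still correct: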
```julia
julia> main(; n=1, θ=3.0, p=2)
(f_itensor(θ), θ ^ p) = (9.0, 9.0)
(f_mps(θ), θ ^ p) = (9.0, 9.0)
((f_itensor')(θ), p * θ ^ (p - 1)) = (6.0, 6.0)
((f_mps')(θ), p * θ ^ (p - 1)) = (4.0, 6.0)
(g_itensor(θ), θ ^ p) = (9.0, 9.0)
(g_mps(θ), θ ^ p) = (9.0, 9.0)
((g_itensor')(θ), p * θ ^ (p - 1)) = (6.0, 6.0)
((g_mps')(θ), p * θ ^ (p - 1)) = (6.0, 6.0)
```
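The reference values `θ^p` and `p * θ^(p - 1)` used in `main` follow from a short calculation. This sketch assumes that `"Z"` for `"S=1/2"` is the Pauli matrix, so the all-up product state $|\psi_0\rangle$ is an eigenstate of $Z_1$ with eigenvalue $+1$ (consistent with the `p=1` output above). Since the gate acts only on site 1 and $|\psi_0\rangle$ is normalized, the result holds for any `n`:

```latex
\begin{aligned}
U(\theta)\,|\psi_0\rangle &= \theta\, Z_1 |\psi_0\rangle = \theta\,|\psi_0\rangle, \\
f(\theta) = \langle\psi_0|\,U(\theta)^p\,|\psi_0\rangle &= \theta^p \langle\psi_0|\psi_0\rangle = \theta^p, \\
f'(\theta) &= p\,\theta^{p-1}, \qquad f'(3) = 2 \cdot 3^{1} = 6 \quad (p = 2).
\end{aligned}
```

`g` applies the same product of $p$ gates in a single call, so it has the same value and derivative; the `4.0` returned by `f_mps'` therefore disagrees with the exact result.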
Version information
- Output from `versioninfo()`:

```julia
julia> versioninfo()
Julia Version 1.7.2
Commit bf53498635 (2022-02-06 15:21 UTC)
Platform Info:
  OS: Linux (x86_64-pc-linux-gnu)
  CPU: Intel(R) Xeon(R) E-2176M CPU @ 2.70GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-12.0.1 (ORCJIT, skylake)
Environment:
  JULIA_EDITOR = vim
```

- Output from `using Pkg; Pkg.status("ITensors")`:

```julia
julia> using Pkg; Pkg.status("ITensors")
Status `~/.julia/environments/v1.7/Project.toml`
  [9136182c] ITensors v0.3.12
```
@GTorlai this is a strange one. Maybe we can test an older version of the `rrule` for `apply` that was written only for MPS, which should be easier to investigate.
I will investigate the `rrule`; let's see if I can find the problem.
That would be great, thanks!
It was a small mistake.
Do you know if there is a way for Zygote to automatically test the derivatives with finite differences?
There is, through the ChainRulesTestUtils.jl package, but we haven't set it up for MPS/MPO yet. We have it set up for rrules written purely in terms of ITensors, but that requires writing some overloads to make our custom types compatible with FiniteDifferences.jl: https://github.com/ITensor/ITensors.jl/blob/main/test/ITensorChainRules/utils/chainrulestestutils.jl. I haven't worked out the proper definitions for MPS yet, but it shouldn't be too hard (maybe we can map an MPS to a Vector{Array}, which should hopefully work with FiniteDifferences.jl). For now we have been testing by hand, but ideally this would be more automated.
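The idea of mapping an MPS to a `Vector{Array}` can be sketched without FiniteDifferences.jl: flatten the arrays into one vector, perturb each entry, and take central differences. This is a hypothetical illustration, not the FiniteDifferences.jl or ChainRulesTestUtils.jl API; `flatten` and `fd_gradient` are made-up names, and the step that converts each MPS tensor to a dense `Array` is assumed and not shown.

```julia
# Hypothetical helper: flatten a Vector of Float64 arrays into one Vector{Float64}
# and return a closure that rebuilds arrays with the original shapes.
function flatten(xs::Vector{<:Array{Float64}})
  shapes = map(size, xs)
  v = reduce(vcat, map(vec, xs); init=Float64[])
  function rebuild(w::AbstractVector{Float64})
    out = Array{Float64}[]
    offset = 0
    for sz in shapes
      len = prod(sz)
      push!(out, reshape(w[(offset + 1):(offset + len)], sz))
      offset += len
    end
    return out
  end
  return v, rebuild
end

# Hypothetical helper: central finite-difference gradient of a scalar function
# of a Vector of arrays, returned with the same shapes as the input arrays.
function fd_gradient(f, xs; h=1e-6)
  v, rebuild = flatten(xs)
  grad = similar(v)
  for i in eachindex(v)
    vp = copy(v); vp[i] += h
    vm = copy(v); vm[i] -= h
    grad[i] = (f(rebuild(vp)) - f(rebuild(vm))) / (2h)
  end
  return rebuild(grad)
end

# Usage: compare against the analytic gradients of sum(A * B)
A = [1.0 2.0; 3.0 4.0]
B = [0.5 -1.0; 2.0 1.5]
gA, gB = fd_gradient(xs -> sum(xs[1] * xs[2]), [A, B])
isapprox(gA, ones(2, 2) * B'; atol=1e-6)  # ∂/∂A of sum(A * B)
isapprox(gB, A' * ones(2, 2); atol=1e-6)  # ∂/∂B of sum(A * B)
```

In a real test, the finite-difference result would be compared against the gradient returned by Zygote for the same function, which is the comparison ChainRulesTestUtils.jl automates for supported types.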
Closed by #981, with this minimal example added as a test in #1077.