ITensors.jl
[ITensors] [BUG] Automatic differentiation
Description of bug
Automatic differentiation behaves strangely: an error is thrown for one function, but not for a slightly rewritten (though analogous) version of it, nor for a slightly more complicated one.
Minimal code demonstrating the bug or unexpected behavior
Minimal runnable code
using ChainRulesCore
using Zygote
using ITensors
@non_differentiable onehot(::Any...)
indices = [Index(2) for _ in 1:10]
mps = randomMPS(indices; linkdims=10)
state = [randomITensor(index) for index in indices]
The function below leads to an error.
function f(mps, state)
    N = length(mps)
    L = 0.
    tens = 1.
    for i in 1:N
        tens *= mps[i]
        if i < N
            idx = commonind(mps[i], mps[i+1])
            partial_mps = tens * onehot(idx => 1)
        else
            partial_mps = tens
        end
        L += (partial_mps * state[i])[]
        tens *= state[i]
    end
    return L
end
f(x) = f(x, state)
f'(mps)
# DimensionMismatch("cannot add ITensors with different numbers of indices")
However, if I just add two extra lines to that function, differentiation works well.
function f(mps, state)
    N = length(mps)
    L = 0.
    tens = 1.
    for i in 1:N
        tens *= mps[i]
        if i < N
            idx = commonind(mps[i], mps[i+1])
            partial_mps = tens * onehot(idx => 1)
        else
            partial_mps = tens
        end
        phys_idx = commonind(mps[i], state[i])
        L -= sum([(partial_mps * onehot(phys_idx => j))[] for j in 1:1])
        L += (partial_mps * state[i])[]
        tens *= state[i]
    end
    return L
end
f(x) = f(x, state)
f'(mps)
# works well
Finally, if I change just one line of that function (a change that does not affect its output), differentiation fails again with the same error.
function f(mps, state)
    N = length(mps)
    L = 0.
    tens = 1.
    for i in 1:N
        tens *= mps[i]
        if i < N
            idx = commonind(mps[i], mps[i+1])
            partial_mps = tens * onehot(idx => 1)
        else
            partial_mps = tens
        end
        phys_idx = commonind(mps[i], state[i])
        L -= (partial_mps * onehot(phys_idx => 1))[]
        L += (partial_mps * state[i])[]
        tens *= state[i]
    end
    return L
end
f(x) = f(x, state)
f'(mps)
# DimensionMismatch("cannot add ITensors with different numbers of indices")
Expected output or behavior
I would expect all of the functions above to be differentiable with Zygote without errors.
Version information
- Output from versioninfo():
julia> versioninfo()
Julia Version 1.7.0
Commit 3bf9d17731* (2021-11-30 12:12 UTC)
Platform Info:
OS: Linux (x86_64-pc-linux-gnu)
CPU: AMD EPYC 7742 64-Core Processor
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-12.0.1 (ORCJIT, znver2)
Environment:
JULIA_NUM_THREADS = 128
- Output from using Pkg; Pkg.status("ITensors"):
julia> using Pkg; Pkg.status("ITensors")
Status `~/.julia/environments/v1.7/Project.toml`
[9136182c] ITensors v0.2.16
Very strange bug; thanks, @ArtemStrashko. By the way, I would recommend updating to the latest version of ITensors (0.3.14), though I see this bug in the latest version as well, so updating won't fix this issue.
This seems to fix the initial error:
using Zygote
using ITensors
indices = [Index(2) for _ in 1:10]
mps = randomMPS(indices; linkdims=10)
state = [randomITensor(index) for index in indices]
function f(mps, state)
    N = length(mps)
    L = 0.
    tens = 1.
    for i in 1:N
        tens *= mps[i]
        proj = if i < N
            idx = commonind(mps[i], mps[i+1])
            onehot(idx => 1)
        else
            ITensor(1.0)
        end
        partial_mps = tens * proj
        L += (partial_mps * state[i])[]
        tens *= state[i]
    end
    return L
end
f(x) = f(x, state)
f'(mps)
It seems like the issue stems from putting something differentiable (like tens) inside the if-statement. I'm not sure why that's the case; perhaps there's an issue with one of our ChainRules definitions, or it is a Zygote issue.
Actually, there's a more minimal fix. These both run without errors for me:
using Zygote
using ITensors
indices = [Index(2) for _ in 1:10]
mps = randomMPS(indices; linkdims=10)
state = [randomITensor(index) for index in indices]
function f(mps, state)
    N = length(mps)
    L = 0.
    tens = 1.
    for i in 1:N
        tens *= mps[i]
        if i < N
            idx = commonind(mps[i], mps[i+1])
            partial_mps = tens * onehot(idx => 1)
        else
            partial_mps = tens * ITensor(1.0)
        end
        L += (partial_mps * state[i])[]
        tens *= state[i]
    end
    return L
end
f(x) = f(x, state)
f'(mps)
and
using Zygote
using ITensors
indices = [Index(2) for _ in 1:10]
mps = randomMPS(indices; linkdims=10)
state = [randomITensor(index) for index in indices]
function f(mps, state)
    N = length(mps)
    L = 0.
    tens = 1.
    for i in 1:N
        tens *= mps[i]
        if i < N
            idx = commonind(mps[i], mps[i+1])
            partial_mps = tens * onehot(idx => 1)
        else
            partial_mps = tens * ITensor(1.0)
        end
        phys_idx = commonind(mps[i], state[i])
        L -= (partial_mps * onehot(phys_idx => 1))[]
        L += (partial_mps * state[i])[]
        tens *= state[i]
    end
    return L
end
f(x) = f(x, state)
f'(mps)
The only change I made is rewriting:
partial_mps = tens
to:
partial_mps = tens * ITensor(1.0)
I can't say I understand why this is happening. It would be helpful to concoct a more minimal example to investigate this better.
Oddly enough, this seems to be a Zygote bug, since it reproduces without ITensors:
using FiniteDifferences
using Zygote
function f(x)
    y = [[x]', [x]]
    r = 0.0
    o = 1.0
    for n in 1:2
        o *= y[n]
        if n < 2
            proj_o = o * [1.0]
        else
            # Error
            proj_o = o
            # Fix
            # proj_o = o * 1.0
        end
        r += proj_o
    end
    return r
end
x = 1.2
@show f(x)
@show central_fdm(5, 1)(f, x)
@show f'(x)
which throws the error:
f(x) = 2.6399999999999997
(central_fdm(5, 1))(f, x) = 3.4000000000000967
ERROR: LoadError: MethodError: no method matching +(::Float64, ::LinearAlgebra.Adjoint{Float64, Vector{Float64}})
For element-wise addition, use broadcasting with dot syntax: scalar .+ array
Closest candidates are:
+(::Any, ::Any, ::Any, ::Any...) at ~/software/julia-1.7.3/share/julia/base/operators.jl:655
+(::Union{Float16, Float32, Float64}, ::BigFloat) at ~/software/julia-1.7.3/share/julia/base/mpfr.jl:413
+(::ChainRulesCore.Tangent{P}, ::P) where P at ~/.julia/packages/ChainRulesCore/GUvJT/src/tangent_arithmetic.jl:146
...
Stacktrace:
[1] accum(x::Float64, y::LinearAlgebra.Adjoint{Float64, Vector{Float64}})
@ Zygote ~/.julia/packages/Zygote/DkIUK/src/lib/lib.jl:17
[...]
while with the fix it outputs:
f(x) = 2.6399999999999997
(central_fdm(5, 1))(f, x) = 3.4000000000000967
(f')(x) = 3.4000000000000004
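Unrolling the loop of the minimal example by hand may help when investigating this. On the first iteration o = [x]' (a 1×1 Adjoint) and proj_o = [x]' * [1.0] = x; on the second o = [x]' * [x] = x^2 (a Float64) and proj_o = x^2. So f(x) = x + x^2 and f'(x) = 1 + 2x = 3.4 at x = 1.2, which agrees with both the finite-difference estimate and the result with the fix. The sketch below replays that forward pass in plain Julia (no Zygote) and records the type of the loop-carried variable o on each iteration; trace_forward is a hypothetical helper name. One plausible reading, not confirmed here, is that because o changes type from Adjoint to Float64 between iterations, the bare assignment proj_o = o leads Zygote to accumulate cotangents of those two different types, which would match the accum(::Float64, ::Adjoint{Float64, Vector{Float64}}) frame in the stacktrace.

```julia
using LinearAlgebra  # Adjoint lives here; [x]' produces one

# Hypothetical helper: replays the forward pass of the minimal reproducer f
# above (no automatic differentiation involved) and records typeof(o)
# after the multiplication on each loop iteration.
function trace_forward(x)
    y = [[x]', [x]]   # a 1x1 Adjoint row vector and a length-1 Vector
    r = 0.0
    o = 1.0
    otypes = Type[]
    for n in 1:2
        # Iteration 1: Float64 * Adjoint stays an Adjoint.
        # Iteration 2: Adjoint * Vector is a dot product, i.e. a Float64.
        o *= y[n]
        push!(otypes, typeof(o))
        # Same values as the original branches: a scalar on both iterations
        proj_o = n < 2 ? o * [1.0] : o
        r += proj_o
    end
    return r, otypes
end

r, otypes = trace_forward(1.2)
@show r       # x + x^2, up to floating-point rounding
@show otypes  # type of o after iterations 1 and 2
```

Nothing here differentiates anything; it only makes the type change of o visible, which is the main difference between the failing branch (proj_o = o) and the fixed one (proj_o = o * 1.0, a function call Zygote records a pullback for).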