[ITensors] [BUG] Automatic differentiation

Open • ArtemStrashko opened this issue 3 years ago • 4 comments

Description of bug

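Automatic differentiation behaves inconsistently: differentiating one function throws an error, while differentiating a slightly more complicated function with the same structure works. Rewriting one line of that working function in an equivalent form brings the error back.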

Minimal code demonstrating the bug or unexpected behavior

using ChainRulesCore
using Zygote
using ITensors
@non_differentiable onehot(::Any...)

indices = [Index(2) for _ in 1:10]
mps = randomMPS(indices; linkdims=10)
state = [randomITensor(index) for index in indices]

The function below leads to an error.

function f(mps, state)
    N = length(mps)
    L = 0.
    tens = 1.
    for i in 1:N
        tens *= mps[i]
        if i < N
            idx = commonind(mps[i], mps[i+1])
            partial_mps = tens * onehot(idx => 1)
        else
            partial_mps = tens
        end
        L += (partial_mps * state[i])[]
        tens *= state[i]
    end
    return L
end

f(x) = f(x, state)
f'(mps)
# DimensionMismatch("cannot add ITensors with different numbers of indices")

However, if I just add two extra lines to that function, differentiation works well.

function f(mps, state)
    N = length(mps)
    L = 0.
    tens = 1.
    for i in 1:N
        tens *= mps[i]
        if i < N
            idx = commonind(mps[i], mps[i+1])
            partial_mps = tens * onehot(idx => 1)
        else
            partial_mps = tens
        end
        phys_idx = commonind(mps[i], state[i])
        L -= sum([(partial_mps * onehot(phys_idx => j))[] for j in 1:1])
        L += (partial_mps * state[i])[]
        tens *= state[i]
    end
    return L
end

f(x) = f(x, state)
f'(mps)
# works well

Finally, if I change just one line of that function in a way that does not affect its output, differentiation fails again with the same error.

function f(mps, state)
    N = length(mps)
    L = 0.
    tens = 1.
    for i in 1:N
        tens *= mps[i]
        if i < N
            idx = commonind(mps[i], mps[i+1])
            partial_mps = tens * onehot(idx => 1)
        else
            partial_mps = tens
        end
        phys_idx = commonind(mps[i], state[i])
        L -= (partial_mps * onehot(phys_idx => 1))[]
        L += (partial_mps * state[i])[]
        tens *= state[i]
    end
    return L
end

f(x) = f(x, state)
f'(mps)
# DimensionMismatch("cannot add ITensors with different numbers of indices")

Expected output or behavior

I would expect to get no errors when setting up automatic differentiation with the functions above.

Version information

  • Output from versioninfo():
julia> versioninfo()
Julia Version 1.7.0
Commit 3bf9d17731* (2021-11-30 12:12 UTC)
Platform Info:
  OS: Linux (x86_64-pc-linux-gnu)
  CPU: AMD EPYC 7742 64-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-12.0.1 (ORCJIT, znver2)
Environment:
  JULIA_NUM_THREADS = 128

  • Output from using Pkg; Pkg.status("ITensors"):
julia> using Pkg; Pkg.status("ITensors")
Status `~/.julia/environments/v1.7/Project.toml`
  [9136182c] ITensors v0.2.16

ArtemStrashko • Jun 03 '22 14:06

Very strange bug, thanks @ArtemStrashko. By the way, I would recommend updating to the latest version of ITensors (0.3.14), though I see this bug in the latest version as well so that won't fix this issue.

mtfishman • Jun 03 '22 14:06

This seems to fix the initial error:

using Zygote
using ITensors

indices = [Index(2) for _ in 1:10]
mps = randomMPS(indices; linkdims=10)
state = [randomITensor(index) for index in indices]

function f(mps, state)
  N = length(mps)
  L = 0.
  tens = 1.
  for i in 1:N
    tens *= mps[i]
    proj = if i < N
      idx = commonind(mps[i], mps[i+1])
      onehot(idx => 1)
    else
      ITensor(1.0)
    end
    partial_mps = tens * proj
    L += (partial_mps * state[i])[]
    tens *= state[i]
  end
  return L
end

f(x) = f(x, state)
f'(mps)

It seems like the issue stems from assigning something differentiable (like tens) directly inside the if-statement. I'm not sure why that's the case; perhaps there's an issue with one of our ChainRules definitions, or it is a Zygote issue.

mtfishman • Jun 03 '22 19:06

Actually, there's a more minimal fix. These both run without errors for me:

using Zygote
using ITensors

indices = [Index(2) for _ in 1:10]
mps = randomMPS(indices; linkdims=10)
state = [randomITensor(index) for index in indices]

function f(mps, state)
  N = length(mps)
  L = 0.
  tens = 1.
  for i in 1:N
    tens *= mps[i]
    if i < N
      idx = commonind(mps[i], mps[i+1])
      partial_mps = tens * onehot(idx => 1)
    else
      partial_mps = tens * ITensor(1.0)
    end
    L += (partial_mps * state[i])[]
    tens *= state[i]
  end
  return L
end

f(x) = f(x, state)
f'(mps)

and

using Zygote
using ITensors

indices = [Index(2) for _ in 1:10]
mps = randomMPS(indices; linkdims=10)
state = [randomITensor(index) for index in indices]

function f(mps, state)
  N = length(mps)
  L = 0.
  tens = 1.
  for i in 1:N
    tens *= mps[i]
    if i < N
      idx = commonind(mps[i], mps[i+1])
      partial_mps = tens * onehot(idx => 1)
    else
      partial_mps = tens * ITensor(1.0)
    end
    phys_idx = commonind(mps[i], state[i])
    L -= (partial_mps * onehot(phys_idx => 1))[]
    L += (partial_mps * state[i])[]
    tens *= state[i]
  end
  return L 
end

f(x) = f(x, state)
f'(mps)

The only change I made is rewriting:

      partial_mps = tens

to:

      partial_mps = tens * ITensor(1.0)

I can't say I understand why this is happening. It would be helpful to concoct a more minimal example to investigate this better.

mtfishman • Jun 03 '22 19:06

Oddly enough, this seems to be a Zygote bug. It reproduces without ITensors:

using FiniteDifferences
using Zygote

function f(x)
  y = [[x]', [x]]
  r = 0.0
  o = 1.0
  for n in 1:2
    o *= y[n]
    if n < 2
      proj_o = o * [1.0]
    else
      # Error: assigning `o` directly makes `f'(x)` throw
      proj_o = o
      # Fix: multiplying by 1.0 instead lets `f'(x)` succeed
      # proj_o = o * 1.0
    end
    r += proj_o
  end
  return r
end

x = 1.2
@show f(x)
@show central_fdm(5, 1)(f, x)
@show f'(x)

which throws the error:

f(x) = 2.6399999999999997
(central_fdm(5, 1))(f, x) = 3.4000000000000967
ERROR: LoadError: MethodError: no method matching +(::Float64, ::LinearAlgebra.Adjoint{Float64, Vector{Float64}})
For element-wise addition, use broadcasting with dot syntax: scalar .+ array
Closest candidates are:
  +(::Any, ::Any, ::Any, ::Any...) at ~/software/julia-1.7.3/share/julia/base/operators.jl:655
  +(::Union{Float16, Float32, Float64}, ::BigFloat) at ~/software/julia-1.7.3/share/julia/base/mpfr.jl:413
  +(::ChainRulesCore.Tangent{P}, ::P) where P at ~/.julia/packages/ChainRulesCore/GUvJT/src/tangent_arithmetic.jl:146
  ...
Stacktrace:
 [1] accum(x::Float64, y::LinearAlgebra.Adjoint{Float64, Vector{Float64}})
   @ Zygote ~/.julia/packages/Zygote/DkIUK/src/lib/lib.jl:17
[...]

while with the fix it outputs:

f(x) = 2.6399999999999997
(central_fdm(5, 1))(f, x) = 3.4000000000000967
(f')(x) = 3.4000000000000004
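
As a sanity check on those numbers, the reduced function can be written in closed form: the loop adds x in the first iteration and x^2 in the second, so f(x) = x + x^2 and f'(x) = 1 + 2x, which gives 3.4 at x = 1.2. The sketch below uses only Base Julia, and the names f_loop, f_closed, df_closed and central_diff are hypothetical helpers, not part of Zygote or FiniteDifferences. It also traces the types that `o` takes in the loop. The stack trace shows `accum` being called with a Float64 and an Adjoint, and adding those two types fails in plain Julia as well. That `o` changing type between iterations is what triggers the Zygote error is a guess, not something confirmed here.

```julia
# Hypothetical helper names; only Base Julia is used (no Zygote, no FiniteDifferences).

# Forward pass of the reduced example, keeping the branch `proj_o = o` that
# makes differentiation fail. The forward pass itself runs fine.
function f_loop(x)
  y = [[x]', [x]]
  r = 0.0
  o = 1.0
  for n in 1:2
    o *= y[n]              # n = 1: row vector [x]'; n = 2: scalar x^2
    if n < 2
      proj_o = o * [1.0]   # row vector times column vector gives the scalar x
    else
      proj_o = o
    end
    r += proj_o
  end
  return r
end

# Closed form read off from the loop: r = x + x^2, so dr/dx = 1 + 2x.
f_closed(x) = x + x^2
df_closed(x) = 1 + 2x

# Plain central difference, standing in for `central_fdm(5, 1)`.
central_diff(f, x; h=1e-6) = (f(x + h) - f(x - h)) / (2h)

x = 1.2
@show f_loop(x) f_closed(x)
@show central_diff(f_loop, x) df_closed(x)

# The values that `o` takes during the loop: scalar, row vector, scalar.
o0 = 1.0
o1 = o0 * [x]'
o2 = o1 * [x]
@show typeof(o0) typeof(o1) typeof(o2)

# Adding a scalar to a row vector has no method in Julia, which is the
# MethodError reported from `accum` in the stack trace.
scalar_plus_row_fails = try
  1.0 + [1.0]'
  false
catch err
  err isa MethodError
end
@show scalar_plus_row_fails
```

The forward value and the finite-difference derivative agree with the closed form, so 3.4 is the correct derivative and the failure is confined to the reverse pass.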

mtfishman • Jun 03 '22 19:06