Typed Matrix literal is not differentiable.

Open: Sebastian-Dawid opened this issue 7 months ago • 1 comment

Differentiating a function that builds a typed array literal (such as Float32[...]) fails with Zygote's "Mutating arrays is not supported" error.

Versions: Julia 1.11.5, Zygote v0.7.7

Minimal working example:

using Zygote
T(x) = Float32[1 0 0 x[1]][4]
gradient(T, [42.0f0])

Error output:

ERROR: Mutating arrays is not supported -- called setindex!(Matrix{Float32}, ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g. setting values with x .= ...)

Possible fixes:
- avoid mutating operations (preferred)
- or read the documentation and solutions for this error
  https://fluxml.ai/Zygote.jl/latest/limitations

Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] _throw_mutation_error(f::Function, args::Matrix{Float32})
    @ Zygote ~/.julia/packages/Zygote/wfLOG/src/lib/array.jl:70
  [3] (::Zygote.var"#553#554"{Matrix{Float32}})(::Nothing)
    @ Zygote ~/.julia/packages/Zygote/wfLOG/src/lib/array.jl:82
  [4] (::Zygote.var"#2643#back#555"{Zygote.var"#553#554"{Matrix{Float32}}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:72
  [5] hvcat_fill!
    @ ./abstractarray.jl:2238 [inlined]
  [6] (::Zygote.Pullback{Tuple{typeof(Base.hvcat_fill!), Matrix{…}, Tuple{…}}, Any})(Δ::ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{ChainRules.var"#672#674"{…}}, ChainRules.var"#671#673"{Matrix{…}, Float32, Tuple{…}}})
    @ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface2.jl:0
  [7] typed_hcat
    @ ./abstractarray.jl:1645 [inlined]
  [8] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{…}, ChainRules.var"#671#673"{…}})
    @ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface2.jl:0
  [9] T
    @ ./REPL[3]:1 [inlined]
 [10] (::Zygote.Pullback{Tuple{typeof(T), Vector{Float32}}, Tuple{Zygote.Pullback{Tuple{…}, Tuple{…}}, Zygote.Pullback{Tuple{…}, Tuple{…}}, Zygote.Pullback{Tuple{…}, Tuple{…}}}})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface2.jl:0
 [11] (::Zygote.var"#88#89"{Zygote.Pullback{Tuple{typeof(T), Vector{Float32}}, Tuple{Zygote.Pullback{Tuple{…}, Tuple{…}}, Zygote.Pullback{Tuple{…}, Tuple{…}}, Zygote.Pullback{Tuple{…}, Tuple{…}}}}})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface.jl:97
 [12] gradient(f::Function, args::Vector{Float32})
    @ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface.jl:154
 [13] top-level scope
    @ REPL[4]:1
Some type information was truncated. Use `show(err)` to see complete types.

Sebastian-Dawid, May 12 '25 18:05

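This syntax lowers to Base.typed_hcat. There is no differentiation rule for that function, so Zygote traces into its implementation and hits the mutating hvcat_fill! call seen in the stack trace. Ideally there would be a rule for typed_hcat in ChainRules.jl, probably a mild adaptation of the existing rule for hcat. The lowered code, followed by a comparison of hcat (works) and typed_hcat (fails):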

julia> Meta.@lower Float32[1 0 0 x[1]]
:($(Expr(:thunk, CodeInfo(
    @ none within `top-level scope`
1 ─ %1 = Float32
│   %2 = x
│   %3 = Base.getindex(%2, 1)
│   %4 = Base.typed_hcat(%1, 1, 0, 0, %3)
└──      return %4
))))

julia> Zygote.gradient(x -> sum(abs2, hcat(1, 2, x)), 3.0)
(6.0,)

julia> Zygote.gradient(x -> sum(abs2, Base.typed_hcat(Float32, 1, 2, x)), 3.0)
ERROR: Mutating arrays is not supported -- called setindex!(Matrix{Float32}, ...)
...
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] _throw_mutation_error(f::Function, args::Matrix{Float32})
    @ Zygote ~/.julia/packages/Zygote/wfLOG/src/lib/array.jl:70
...
  [5] hvcat_fill!
    @ ./abstractarray.jl:2238 [inlined]
...
  [7] typed_hcat
    @ ./abstractarray.jl:1645 [inlined]
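
Since plain hcat differentiates fine in the comparison above, one possible workaround until a typed_hcat rule exists is to avoid the typed literal and get the element type from the arguments instead. The following is only a sketch: the name T_workaround is made up for illustration, and it assumes the input already holds Float32 values so that hcat promotes to a Float32 matrix.

```julia
using Zygote

# Hypothetical workaround, not from the thread: use untyped hcat, which has a rule,
# and pass Float32 literals so the result is still a Matrix{Float32}.
T_workaround(x) = hcat(1f0, 0f0, 0f0, x[1])[4]

# Same input as the original example; this call should no longer reach hvcat_fill!.
g = gradient(T_workaround, [42.0f0])
```

This keeps the function's value the same as the original T for Float32 input, at the cost of writing the element type into every literal.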

Note that there are similar typed_hvcat and typed_hvncat functions too:

julia> Meta.@lower Float32[1 2; 3 4]
:($(Expr(:thunk, CodeInfo(
    @ none within `top-level scope`
1 ─ %1 = Float32
│   %2 = Core.tuple(2, 2)
│   %3 = Base.typed_hvcat(%1, %2, 1, 2, 3, 4)
└──      return %3
))))

julia> Meta.@lower Float32[1;;;]
:($(Expr(:thunk, CodeInfo(
    @ none within `top-level scope`
1 ─ %1 = Base.typed_hvncat(Float32, 3, 1)
└──      return %1
))))
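
To make the "mild adaptation of the hcat rule" idea more concrete, here is a rough sketch of what a typed_hcat rule could look like when every argument is a scalar. It is an assumption-laden illustration, not the rule ChainRules.jl ships or would necessarily accept: it ignores array arguments entirely, it defines a rule on a Base function from user code (type piracy), and typed_hvcat and typed_hvncat would each need their own analogous rule.

```julia
using ChainRulesCore

# Sketch only: scalar arguments, no array blocks. Not an existing ChainRules.jl rule.
function ChainRulesCore.rrule(::typeof(Base.typed_hcat), ::Type{T}, xs::Number...) where {T}
    y = Base.typed_hcat(T, xs...)          # 1×n matrix with element type T
    project_xs = map(ProjectTo, xs)        # project each cotangent back to its input's type
    function typed_hcat_pullback(ȳ)
        dy = unthunk(ȳ)                    # the incoming cotangent may be a thunk
        # Entry i of the output came from argument i, so its cotangent is dy[i].
        dxs = ntuple(i -> project_xs[i](dy[i]), length(xs))
        # No tangent for the function itself or for the element type argument.
        return (NoTangent(), NoTangent(), dxs...)
    end
    return y, typed_hcat_pullback
end

# Direct use of the rule, without Zygote:
y, back = ChainRulesCore.rrule(Base.typed_hcat, Float32, 1, 2, 3.0)
Δ = back(Float32[10 20 30])
```

The pullback returns one entry per argument of typed_hcat, so here Δ has five entries: two NoTangent values followed by the three scalar cotangents.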

Relevant issue: https://github.com/JuliaDiff/ChainRules.jl/issues/743

mcabbott, May 12 '25 22:05