Zygote.jl
Modified error message
The error message raised when no pullback is available for new / __new__ for a particular cotangent type is misleading and incomplete.
In particular, it currently reads
(back::Jnew{T})(Δ) where T = error("Need an adjoint for constructor $T. Gradient is of type $(typeof(Δ))")
This suggests to the user that no pullback exists for the constructor of T, which is not true. A pullback does exist (and a slightly different version exists below).
What is actually going on is that a cotangent whose type differs from the expected canonical cotangent type for T (i.e. a NamedTuple, or Nothing) has been passed into the pullback for the constructor.
The current message implies that a user has only one option when confronted with this error: adding a new adjoint tailored specifically to the type in question. In reality, they can also modify other code to ensure that a cotangent of the expected type (a NamedTuple) is provided.
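To make the second option concrete, here is a minimal sketch of building a cotangent of the expected shape for a struct. The struct `Foo` and the helper `structural_cotangent` are hypothetical names made up for illustration; neither is part of Zygote, and the snippet does not need Zygote to run.

```julia
# Hypothetical struct used only for illustration; not taken from Zygote.
struct Foo
    a::Float64
    b::Float64
end

# Hypothetical helper: build a NamedTuple whose field names match the
# fields of `T`, which is the shape of cotangent described above.
function structural_cotangent(::Type{T}, vals...) where {T}
    names = fieldnames(T)
    length(vals) == length(names) ||
        throw(ArgumentError("expected $(length(names)) values for $T, got $(length(vals))"))
    return NamedTuple{names}(vals)
end

Δ = structural_cotangent(Foo, 3.0, 1.0)
```

Code upstream of the constructor's pullback could produce a value like `Δ` instead of, say, a bare number or array, which avoids the need for a type-specific adjoint.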
I propose changing the error message to read something like:
"Unexpected cotangent type $(typeof(Δ)) found in pullback for `new` for $T. Either implement a new adjoint for `__new__{T}`, or ensure that the cotangent passed in to this pullback is a `NamedTuple` with appropriate fields."
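As a rough illustration of how the proposed wording would render, the sketch below builds the message text in a standalone function. The name `unexpected_cotangent_message` is hypothetical and exists only for this example; in Zygote the text would live in the fallback `Jnew` method quoted above.

```julia
# Hypothetical helper that only builds the proposed message text.
# It is not part of Zygote.
function unexpected_cotangent_message(T, Δ)
    return "Unexpected cotangent type $(typeof(Δ)) found in pullback for `new` for $T. " *
           "Either implement a new adjoint for `__new__{T}`, or ensure that the cotangent " *
           "passed in to this pullback is a `NamedTuple` with appropriate fields."
end

# Example: a plain Float64 arriving where a NamedTuple was expected.
msg = unexpected_cotangent_message(Complex{Float64}, 1.0)
println(msg)
```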
Note
@oxinabox @mzgubic I think we had been labouring under the false impression that Zygote can't handle inner constructors properly (at least, I had been). In fact, it handles inner constructors just fine. See, for example:
struct Baz
    x::Float64
    y::Float64
    function Baz(x)
        return new(x, x)
    end
end
julia> out, pb = Zygote._pullback(Zygote.Context(), Baz, 5.0)
(Baz(5.0, 5.0), ∂(Baz))
julia> pb((x=3.0, y=1.0))
(nothing, 4.0)
Pleasingly, this is all type-stable.
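For contrast, the following sketch reuses the `Baz` example and passes a cotangent of an unexpected type to the same pullback. I am assuming here that a plain `Float64` falls through to the fallback `Jnew` method quoted at the top; the exact error text depends on the Zygote version, so the snippet just prints whatever is thrown.

```julia
using Zygote

struct Baz
    x::Float64
    y::Float64
    function Baz(x)
        return new(x, x)
    end
end

out, pb = Zygote._pullback(Zygote.Context(), Baz, 5.0)

# A NamedTuple cotangent with fields matching `Baz` is accepted.
ok = pb((x = 3.0, y = 1.0))

# A plain number is not a structural cotangent for `Baz`, so this call is
# expected to throw (assumption: it reaches the fallback `Jnew` method).
caught = try
    pb(1.0)
    nothing
catch err
    err
end
caught === nothing || println(sprint(showerror, caught))
```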
Sounds reasonable to me. Is it possible to soften "cotangent" to something less precise but more recognizable? The current error message uses "gradient", but I'm not sure if there are better options. Alternatively, we could link to some documentation that explains this terminology, like we do for mutation errors.