Zygote.jl
Add an error for broadcasting with CUDA + complex numbers, etc
Xref https://github.com/FluxML/Zygote.jl/issues/1215
I think making this an error is a good idea. Ideally there would be a way to avoid the error for custom types when you know it is okay not to track their gradients. I guess you can define _dual_safearg for your custom type, but this might be worth describing in the error message, or exposing in a different way.
Yes. I guess the thing you overload should eventually be something like https://github.com/JuliaDiff/ChainRulesCore.jl/pull/528
However, at the moment I believe you get errors from unbroadcast not having appropriate methods, if you try to use some weird type (even just a Symbol, IIRC). So it ought to be safe to make these deliberate errors now, and adjustable later.
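To make the opt-in idea discussed above concrete, here is a minimal, self-contained sketch of the pattern: a conservative trait that defaults to false, a one-line overload for a custom type, and a deliberate error whose message explains how to opt in. All names here (is_safearg, check_broadcast_args, Label) are hypothetical stand-ins for illustration. This is not Zygote's actual code, and it does not call into Zygote at all.

```julia
# Hypothetical stand-in for an opt-in trait in the spirit of `_dual_safearg`.
# These definitions are illustrative only, not Zygote's.
is_safearg(x::Number) = true
is_safearg(x::AbstractArray{<:Number}) = true
is_safearg(x::Union{Type,Val,Symbol}) = true   # nothing to differentiate
is_safearg(x) = false                          # conservative default

# A custom type whose gradient we know we do not need to track.
struct Label
    name::String
end

# Opting in is a single extra method for the custom type.
is_safearg(::Label) = true

# Deliberate, descriptive error that tells the user how to opt in.
function check_broadcast_args(args...)
    for a in args
        is_safearg(a) || throw(ArgumentError(
            "cannot broadcast over an argument of type $(typeof(a)); " *
            "define `is_safearg(::$(typeof(a))) = true` if no gradient is needed"))
    end
    return true
end

check_broadcast_args(1.0, [1, 2, 3], :sym, Label("a"))  # passes
# check_broadcast_args("text")                          # throws ArgumentError
```

The point of the design is that the default stays strict, so unknown types fail loudly, while the error text itself names the method to overload. If the hook later moves to a shared package, only the function named in the message would need to change.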