ChainRules.jl

reuse factorization in rule for \

Open oxinabox opened this issue 5 years ago • 7 comments

Closes JuliaLang/julia#250. It's pretty cute. I am not 100% sure this is correct, but CI will tell me.

This also stops thunking ∂B. ∂A is computed from ∂B, so whichever of the two is needed, ∂B has to be computed anyway.
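As a rough illustration of the idea (not the actual rule in this PR), here is a minimal sketch that uses only the LinearAlgebra standard library. `solve_with_pullback` is a hypothetical name, and the sketch assumes a square, nonsingular, dense A, so it calls `lu` directly rather than `factorize`. The factorization is computed once in the forward pass and reused in the pullback, where ∂A is built from ∂B.

```julia
using LinearAlgebra

# Hypothetical helper, NOT the ChainRules.jl rule.
# Assumption: A is square and nonsingular, so an LU factorization applies.
function solve_with_pullback(A::AbstractMatrix, B::AbstractVecOrMat)
    F = lu(A)      # factorize once in the forward pass
    Y = F \ B      # primal result, same as A \ B
    function pullback(Ybar)
        ∂B = F' \ Ybar   # adjoint solve reuses the stored factorization
        ∂A = -∂B * Y'    # ∂A is computed from ∂B, so ∂B is always needed
        return ∂A, ∂B
    end
    return Y, pullback
end

A = [4.0 1.0; 2.0 3.0]
B = [1.0, 2.0]
Y, pullback = solve_with_pullback(A, B)
∂A, ∂B = pullback([1.0, 0.0])
```

Because both cotangents depend on the same adjoint solve, deferring ∂B behind a thunk would not save any work here.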

julia> @btime $(x)\$y
  1.032 μs (3 allocations: 1.19 KiB)
10-element Vector{Float64}:
  0.49018802032025616
 -0.33578706090563665
 -0.4615411572036624
  0.24664997800385907
 -0.5310612794797791
  0.3537748064584685
  1.470168551590106
 -0.4108242996218167
  0.11217305102151084
 -0.4276420520809933

julia> @btime $(factorize(x))\$y
  252.892 ns (1 allocation: 160 bytes)
10-element Vector{Float64}:
  0.49018802032025616
 -0.33578706090563665
 -0.4615411572036624
  0.24664997800385907
 -0.5310612794797791
  0.3537748064584685
  1.470168551590106
 -0.4108242996218167
  0.11217305102151084
 -0.4276420520809933

oxinabox avatar Nov 03 '20 00:11 oxinabox

This fails for some inputs because of https://github.com/JuliaLang/julia/issues/38293

oxinabox avatar Nov 03 '20 13:11 oxinabox

@willtebbutt gave me a workaround:

julia> @btime $A'\$B;
  31.750 ms (5 allocations: 15.27 MiB)
julia> @btime $Af'\$B;  # best but not always permitted
  19.356 ms (2 allocations: 7.63 MiB)
julia> @btime ($B'/ $Af)';
  20.160 ms (2 allocations: 7.63 MiB)

oxinabox avatar Nov 03 '20 14:11 oxinabox

Bother: no method matching rdiv!(::Array{Float64,2}, ::QRPivoted{Float64,Array{Float64,2}})

oxinabox avatar Nov 03 '20 16:11 oxinabox

This should be unblocked by https://github.com/JuliaLang/julia/pull/40899, and it needs to be tested against that branch of Julia.

oxinabox avatar May 21 '21 14:05 oxinabox

Testing this against JuliaLang/julia#40899 (cc @andreasnoack for interest), I think it has resolved the problems. It still errors, but the errors are now:

  • inference failing (probably need to look into how the rules are written there)
  • MethodError: no method matching factorize(::Vector{Float64}) (probably should define a _factorize_for_ldiv that returns the input if it is a vector, or if VERSION<v"1.7")
  • DimensionMismatch("overdetermined systems are not supported") (need to look into that one. It was introduced by JuliaLang/julia#40899, but it might be that we shouldn't be using factorize for this case? But / does support over-determined systems)
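The `_factorize_for_ldiv` helper suggested above could look roughly like the following. This is only a sketch of the behaviour described in the list (return the input unchanged for vectors or on Julia versions before 1.7, otherwise call `factorize`); the exact signature and dispatch are assumptions, and the helper is not part of any released API.

```julia
using LinearAlgebra

# Hypothetical helper sketched from the description in this thread.
# Vectors have no `factorize` method, so pass them through unchanged.
_factorize_for_ldiv(A::AbstractVector) = A

function _factorize_for_ldiv(A::AbstractMatrix)
    # On older Julia versions, fall back to the plain matrix.
    VERSION < v"1.7" && return A
    return factorize(A)
end

A = [4.0 1.0; 2.0 3.0]
b = [1.0, 2.0]
x = _factorize_for_ldiv(A) \ b   # same solution as A \ b
```

Note that the matrix method returns different types depending on `VERSION`, which is fine for a sketch but may be relevant to the inference failures mentioned above.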

So I think JuliaLang/julia#40899 probably works; I'm not 100% sure about the last point.

oxinabox avatar May 28 '21 11:05 oxinabox

The error suggests that you end up doing something like factorize(A)'\b for a wide A. For wide A, factorize gives a QR and least squares isn't defined for Adjoint{QR}. Is this a real problem or are you just trying to exercise all possible combinations?
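A small sketch of the situation described here, using a made-up 2×3 matrix. It only inspects what `factorize` returns for a wide input and then attempts the adjoint solve inside a `try` block. Whether that solve throws, and with which error, depends on the Julia version, so no output is claimed.

```julia
using LinearAlgebra

A = [1.0 2.0 3.0; 4.0 5.0 6.0]   # wide: 2 rows, 3 columns (illustrative values)
F = factorize(A)
println(typeof(F))                # a QR-type factorization for non-square input

c = [1.0, 2.0, 3.0]
# A' is 3×2, so A' \ c is an overdetermined (least-squares) problem.
# On some Julia versions the adjoint of a QR factorization does not support it.
result = try
    F' \ c
catch err
    err                           # keep the error object for inspection
end
println(typeof(result))
```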

andreasnoack avatar May 28 '21 13:05 andreasnoack

Worth asking @DhairyaLGandhi, who originally opened this issue in https://github.com/FluxML/Zygote.jl/issues/773 / https://github.com/JuliaDiff/ChainRules.jl/issues/250

I never actually want to AD \ at all, but I think the use of factorizations here is a very neat application of https://juliadiff.org/ChainRulesCore.jl/dev/design/changing_the_primal.html

In general ChainRules exercises all paths because it does.

oxinabox avatar May 28 '21 14:05 oxinabox