ChainRulesCore.jl
CachingThunks
It came up in discussion of JuliaDiff/ChainRules.jl#21 that it would be handy to have an object which caches its value when evaluated (probably a Differentiable, but potentially not?), letting us handle functions where some computation is shared by several of the returned differentiables.
A possible alternative is to just give this behaviour to all Thunks.
I will play around and make sure this is doable with type inference.
I think this is as good as it can get:
mutable struct CachingCallable{R, F}
    f::F
    val::Any
    # Inner constructor leaves `val` undefined until the first call.
    CachingCallable{R, F}(f) where {R, F} = new(f)
end

function CachingCallable(f::F) where F
    # Run inference up front to get the return type; used only as a performance hint.
    R = Core.Compiler.return_type(f, Tuple{})
    return CachingCallable{R, F}(f)
end

function (self::CachingCallable{R})()::R where R
    if !isdefined(self, :val)
        self.val = self.f()  # compute and cache on the first call
    end
    return self.val  # the ::R annotation recovers inferability despite the Any-typed field
end
It is type-inferred, because we manually trigger type inference at construction time. We do pay for that: running type inference manually is expensive, so it becomes a bit of a trade-off.
julia> using BenchmarkTools
julia> @btime CachingCallable(()->(sleep(1); 100))
2.139 μs (13 allocations: 752 bytes)
CachingCallable{Int64,getfield(Main, Symbol("##25#27"))}(getfield(Main, Symbol("##25#27"))(), #undef)
julia> @btime ()->(sleep(1); 100)
0.029 ns (0 allocations: 0 bytes)
#28 (generic function with 1 method)
julia> x() = (sleep(1); 100);
julia> @time x(); @time x();
1.012006 seconds (1.79 k allocations: 85.381 KiB)
1.001536 seconds (9 allocations: 352 bytes)
julia> cc = CachingCallable(x)
CachingCallable{Int64,getfield(Main, Symbol("##23#24"))}(getfield(Main, Symbol("##23#24"))(), #undef)
julia> @time cc(); @time cc();
1.016891 seconds (10.70 k allocations: 557.840 KiB)
0.000008 seconds (4 allocations: 160 bytes)
How bad does this get if the return type of the function isn't inferable at all and we get Any?
idk, I need to work out some nontrivial benchmarks.
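One way to set up such a benchmark (a sketch; the opaque/unstable names are made up here, and Ref{Any} is just a trick to defeat inference so return_type gives Any):

julia> opaque = Ref{Any}(100);

julia> unstable() = opaque[];  # return type infers only as Any

julia> cc_any = CachingCallable(unstable);  # here R == Any

julia> @code_warntype cc_any()  # body is ::Any, so every downstream use of the result dispatches dynamically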
ChainRules actually had something like this a while ago. I got rid of it because the usage of return_type made me nervous, and it ended up being the case that not many of the rules we had at the time benefited from it.
@willtebbutt thinks we will need it for some things.
I think ldiv was one?
I think ldiv would actually be okay. I suspect that we would at least require a ternary function in which the sensitivities w.r.t. two of the arguments share some computation, while the third does not (see the sketch below). Perhaps we should put this on the back-burner until we come across such an example (possibly in BLAS sensitivities, as many of those have large numbers of arguments)?
Thanks for checking this out though @oxinabox.
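For concreteness, here is a hypothetical shape such a rule could take, written with current ChainRulesCore names (f, expensive_common_part, and ∂f∂z are all invented for illustration):

using ChainRulesCore

function ChainRulesCore.rrule(::typeof(f), x, y, z)
    Ω = f(x, y, z)
    # The expensive part shared by ∂x and ∂y; computed at most once.
    shared = CachingCallable(() -> expensive_common_part(x, y))
    function f_pullback(ΔΩ)
        ∂x = @thunk(shared() * ΔΩ)      # first evaluation pays for the shared work
        ∂y = @thunk(2 * shared() * ΔΩ)  # later evaluations reuse the cached value
        ∂z = @thunk(∂f∂z(z) * ΔΩ)       # independent of the shared part; no caching needed
        return (NoTangent(), ∂x, ∂y, ∂z)
    end
    return Ω, f_pullback
end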
Another thing to compare is dropping the type-inferability requirement and using an Any-typed store.
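Something like the following, perhaps (a sketch, untested): drop the return_type call entirely and accept that every call site sees Any:

mutable struct AnyCachingCallable{F}
    f::F
    val::Any
    AnyCachingCallable(f::F) where F = new{F}(f)  # no inference run at construction
end

function (self::AnyCachingCallable)()
    if !isdefined(self, :val)
        self.val = self.f()  # compute and cache on the first call
    end
    return self.val  # inferred as Any; callers can typeassert if they need stability
end

Construction would then be nearly free, at the cost of inferability at every call site.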
ChainRules actually had something like this a while ago. I got rid of it because the usage of return_type made me nervous, and it ended up being the case that not many of the rules we had at the time benefited from it.
Isn't it actually an OK use case of the inference API? Quickly searching discourse, I find a nice guideline summary by Jeff Bezanson:
- Best: don’t ever call promote_op or return_type.
- Mostly OK: call return_type, but only use it as a performance hint, making sure to give the same result no matter what it returns.
- Bad: program result depends on return_type.
--- https://discourse.julialang.org/t/missing-data-and-namedtuple-compatibility/8136/34
I think Memoize is the second "Mostly OK" case. It relies on the inference API, but the result of (::Memoize)() does not depend on the inference result. This is very different from, e.g., using the inference result as the eltype of an array, which is the third "Bad" case.
Of course, the type of the Memoize object itself depends on the inference result, so this is somewhat delicate if you consider the type to be part of the result. But I guess thunks are usually treated as opaque objects, so I don't think it is a problem.
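That property is easy to check for the sketch above: the cached value is identical whether or not inference succeeds; only performance differs (Ref{Any} again used as a hypothetical inference-defeating trick):

julia> stable = CachingCallable(() -> 100);  # R == Int64

julia> unstable = CachingCallable(() -> Ref{Any}(100)[]);  # R == Any

julia> stable() == unstable()  # same value either way; R is only a performance hint
true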
Following up on the discussion in FluxML/Zygote.jl#966: one worry with caching thunks is additional memory allocation. In cases where the result is computationally expensive and/or allocates substantial memory itself, a caching thunk would be preferable; in other cases we would prefer a non-caching thunk. However, there's no reason we can't have both, right? The author of the rrule will usually have enough information to decide between caching and non-caching, so we could offer @thunk and @cached_thunk (using double-checked locking or so internally).
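For what it's worth, a minimal sketch of what the @cached_thunk internals could look like with double-checked locking (the type and field names are invented, and this assumes Julia ≥ 1.7 for @atomic fields):

mutable struct CachedThunk{F}
    f::F
    lock::ReentrantLock
    @atomic done::Bool
    val::Any
    CachedThunk(f::F) where F = new{F}(f, ReentrantLock(), false)
end

function (t::CachedThunk)()
    if !(@atomic t.done)          # fast path: skip the lock once the value is cached
        lock(t.lock) do
            if !(@atomic t.done)  # re-check under the lock: another task may have won the race
                t.val = t.f()
                @atomic t.done = true  # publish only after val has been written
            end
        end
    end
    return t.val
end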