
Option to return auxiliary data from the primal

Open niklasschmitz opened this issue 9 months ago • 7 comments

A very common use case is wanting not only to differentiate an objective, but also to get some auxiliary output (intermediate results, the predictions of an ML model, data structures of a PDE solver, etc.).

For example, JAX has the has_aux keyword option in jax.value_and_grad, which is actually the most common usage pattern of AD that I have seen in JAX. The pattern looks like this (see e.g. the flax docs for a full example in context):

import jax

def loss_fn(params):
    ...
    return loss, extra_data

grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(loss, extra_data), grads = grad_fn(params)

I typically use some hacky workarounds to achieve similar behavior in Julia, but maybe the pattern is common enough to be worth solving at the interface level?

niklasschmitz avatar Feb 07 '25 13:02 niklasschmitz

The issue I see is that most backends want a single-output function. For instance, to get the gradient of loss_fn with ForwardDiff, I'd have to call ForwardDiff.gradient(p -> loss_fn(p)[1], params), and then extra_data is computed inside the closure only to be thrown away. Which backends can actually discard-but-return this extra_data without calling the function twice?
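
For concreteness, here is a minimal sketch of that closure (loss_fn is an illustrative stand-in mirroring the JAX example above, not code from the issue):

using ForwardDiff

# Illustrative two-output objective: the loss plus some auxiliary data.
function loss_fn(params)
    preds = params .^ 2
    return sum(preds), (; preds)
end

params = [1.0, 2.0, 3.0]

# Only the first output is differentiated; extra_data is computed inside the closure
# but discarded, so recovering it requires a second primal evaluation.
grads = ForwardDiff.gradient(p -> loss_fn(p)[1], params)
loss, extra_data = loss_fn(params)  # redundant second call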

gdalle avatar Feb 07 '25 13:02 gdalle

Fair point! I guess the only way to get the extra_data out of a single call, without returning it directly, is indeed to extract it via a side effect (i.e. global state / save-to-file / ...). That does seem tricky to do in a nice way for all AD backends.
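
One possible shape of such a side-effect workaround, sketched here assuming a Zygote backend (the Ref plays the role of "global state", and the write is hidden from AD via ChainRulesCore.ignore_derivatives; loss_fn is again illustrative):

using Zygote, ChainRulesCore

function loss_fn(params)
    preds = params .^ 2
    return sum(preds), (; preds)
end

params = [1.0, 2.0, 3.0]
aux = Ref{Any}(nothing)  # side channel for the auxiliary data

grad_tuple = Zygote.gradient(params) do p
    loss, extra_data = loss_fn(p)
    ChainRulesCore.ignore_derivatives() do
        aux[] = extra_data  # stash it without differentiating the write
    end
    return loss
end
grads = only(grad_tuple)

aux[]  # auxiliary data recovered from the single primal evaluation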

niklasschmitz avatar Feb 07 '25 14:02 niklasschmitz

Idea in passing: maybe we could define a new type of Context which, unlike Cache, would be guaranteed not to be overwritten, and would allow returning auxiliary data.

gdalle avatar Feb 10 '25 14:02 gdalle

Which backends can actually discard-but-return this extra_data without calling the function twice?

FWIW, Zygote.withgradient supports this. And Flux.withgradient supports it with Enzyme.jl too, using this code, which at present relies on Enzyme's split mode and may not be optimal.
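
For reference, a minimal sketch of the Zygote.withgradient behaviour (loss_fn is illustrative): when the function returns a Tuple, only its first element is differentiated, but the whole value comes back.

using Zygote

function loss_fn(params)
    preds = params .^ 2
    return sum(preds), (; preds)  # (scalar to differentiate, auxiliary data)
end

params = [1.0, 2.0, 3.0]
(; val, grad) = Zygote.withgradient(loss_fn, params)
loss, extra_data = val  # val is the full tuple returned by loss_fn
grads = grad[1]         # gradient with respect to params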

mcabbott avatar Mar 27 '25 22:03 mcabbott

I don't think @niklasschmitz meant just returning the primal value? If that's all that's needed, we do of course return it efficiently, making use of Zygote.withgradient or Enzyme.ReverseWithPrimal when necessary:

https://github.com/JuliaDiff/DifferentiationInterface.jl/blob/13f18590e6914daed496c5ed6162cb3587ee1375/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl#L111-L117
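
(For reference, that primal-plus-gradient case already looks like this through the interface; a minimal sketch with the Zygote backend:)

using DifferentiationInterface
import Zygote

f(x) = sum(abs2, x)
x = [1.0, 2.0, 3.0]

# One call returns both the primal value and the gradient.
y, grad = value_and_gradient(f, AutoZygote(), x)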

I think the question was more about side products of the primal computation (solver timing statistics, that sort of thing), which are separate from the value of the function itself.

gdalle avatar Mar 27 '25 23:03 gdalle

That's right? return loss, extra_data is clearly about auxiliary data, as in the title. And the docstring I linked to says "Allows you to capture auxillary outputs, in addition to the scalar used by gradient".

mcabbott avatar Mar 27 '25 23:03 mcabbott

Oh right, my bad, I thought your mention of withgradient only referred to its straightforward use. But yes, I was aware that Zygote supports this; I didn't know about Enzyme, but I suspect those are pretty much the only backends where this is possible (along with Mooncake, probably). If we write the auxiliary data to a cache as suggested above, more possibilities might open up though.

gdalle avatar Mar 27 '25 23:03 gdalle