
Removes the `retain_graph` flag

Open angeloskath opened this issue 1 year ago • 2 comments

I am setting it as draft to have a chat about it but it can probably be PR-ed as is.

The main change is that it removes the `retain_graph` flag. It introduces a global `currently_tracing()` function, which returns true if we are under a function transformation. The `array::is_tracer` method combines the array's tracer flag with the global flag, meaning that if we are not under a function transformation, we don't care about the array's flag.

During eval we detach arrays that are not tracers.

In higher-order functions, one simply instantiates an `InTracing` object, which automatically sets the tracing flag in its constructor and unsets it in its destructor. The main caveat is that the flag is not thread-safe.
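To make the discussion concrete, here is a minimal standalone sketch of the mechanism described above. It is not the code in the PR: the `sketch` namespace, the `Array` struct, its fields, and the `eval` helper are hypothetical stand-ins. It also uses a depth counter rather than a single boolean so that nested guards behave sensibly, which is an assumption on my part.

```cpp
#include <cassert>

namespace sketch {  // hypothetical namespace, for illustration only

// Global nesting depth of active function transformations.
// Like the flag described above, this is NOT thread-safe.
inline int& tracing_depth() {
  static int depth = 0;
  return depth;
}

// True while at least one InTracing guard is alive.
inline bool currently_tracing() {
  return tracing_depth() > 0;
}

// RAII guard: marks tracing as active in the constructor and
// restores the previous state in the destructor.
struct InTracing {
  InTracing() { ++tracing_depth(); }
  ~InTracing() { --tracing_depth(); }
  InTracing(const InTracing&) = delete;
  InTracing& operator=(const InTracing&) = delete;
};

// Simplified stand-in for an array node in the compute graph.
struct Array {
  bool tracer_flag = false;  // set by a transformation on its inputs
  bool has_graph = true;     // whether the array still references its inputs

  // The per-array flag only matters while a transformation is running.
  bool is_tracer() const { return tracer_flag && currently_tracing(); }

  // Drop the references to the graph that produced this array.
  void detach() { has_graph = false; }
};

// Stand-in for eval: after computing, detach anything that is not a tracer.
inline void eval(Array& a) {
  // ... the actual computation would happen here ...
  if (!a.is_tracer()) {
    a.detach();
  }
}

}  // namespace sketch
```

Under these assumptions, a higher-order function would create an `InTracing` guard on entry; arrays flagged as tracers keep their graph when evaluated inside that scope and are detached when evaluated outside it.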

angeloskath avatar Jan 06 '24 05:01 angeloskath

Verified this fixes cifar example + Batch Norm

awni avatar Jan 06 '24 15:01 awni

From a design / perf / testing standpoint I think this is all exactly right and beautifully done. Let's PR it.

I want to think a bit more about removing `retain_graph` and doing per-array detachment. The only slight downside I see is that we might want global graph retention in the future, but that's not sufficient rationale for keeping it now.

awni avatar Jan 06 '24 16:01 awni

Rebased, applied the comments, and checked speeds and tests. I think it is ready to land.

angeloskath avatar Jan 07 '24 07:01 angeloskath

Closes https://github.com/ml-explore/mlx-examples/issues/233

awni avatar Jan 07 '24 13:01 awni