
[BUG] cifar example fails with ValueError when replacing LayerNorm with BatchNorm with default params

Status: Open. menzHSE opened this issue 1 year ago (6 comments).

When I replace LayerNorm with BatchNorm (merged in https://github.com/ml-explore/mlx/pull/217) using the default parameters, the cifar example fails with a ValueError on mlx 0.0.7:

```
  File "../mlx-examples/cifar/main.py", line 48, in train_epoch
    mx.eval(model.parameters(), optimizer.state)
ValueError: [eval] Illegal to eval an array during function transform without graph retention.
```

Using BatchNorm as `nn.BatchNorm(dims, track_running_stats=False)` seems to work. Is that the expected behavior?
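To make the difference between the two settings concrete: with running stats enabled, each training-mode forward pass also writes updated running statistics back into the layer, whereas with running stats disabled the layer holds no extra state and its output depends only on its input. The sketch below is a pure-Python toy, not MLX's implementation; the `ToyBatchNorm` class, the momentum of 0.1, and the moving-average update rule are assumptions chosen for illustration.

```python
# Hypothetical toy batch norm in pure Python; NOT mlx.nn.BatchNorm.
# Normalizes each feature over the batch using the biased batch variance.


class ToyBatchNorm:
    def __init__(self, num_features, momentum=0.1, eps=1e-5,
                 track_running_stats=True):
        self.eps = eps
        self.momentum = momentum  # assumed default, for illustration only
        self.track_running_stats = track_running_stats
        if track_running_stats:
            # Module state that every training-mode call overwrites.
            self.running_mean = [0.0] * num_features
            self.running_var = [1.0] * num_features

    def __call__(self, batch):
        n = len(batch)
        cols = list(zip(*batch))  # one tuple of values per feature
        mean = [sum(c) / n for c in cols]
        var = [sum((x - m) ** 2 for x in c) / n for c, m in zip(cols, mean)]
        if self.track_running_stats:
            # Side effect: exponential moving average of the batch statistics.
            mu = self.momentum
            self.running_mean = [(1 - mu) * r + mu * m
                                 for r, m in zip(self.running_mean, mean)]
            self.running_var = [(1 - mu) * r + mu * v
                                for r, v in zip(self.running_var, var)]
        # The returned value depends only on `batch` in both modes.
        return [[(x - m) / (v + self.eps) ** 0.5
                 for x, m, v in zip(row, mean, var)]
                for row in batch]


batch = [[1.0, 2.0], [3.0, 6.0]]
stateful = ToyBatchNorm(2)
stateless = ToyBatchNorm(2, track_running_stats=False)
out = stateful(batch)
print(stateful.running_mean)  # [0.2, 0.4]
print(stateless(batch) == out)  # True
```

Keeping the normalization identical in both modes isolates the one difference the flag makes in this toy: whether a forward pass mutates the layer's state.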

Reproduce the issue

The ResNet with BatchNorm is available on the `resnet_batch_norm` branch at https://github.com/menzHSE/mlx-examples. Run `python main.py` in `mlx-examples/cifar`.

menzHSE, Jan 05 '24 08:01

That is not expected; it sounds like a bug. Thanks for reporting, I will take a look.

awni, Jan 05 '24 14:01

I was able to reproduce this using the code @menzHSE provided. Passing `retain_graph=True` to the `eval()` method suppresses the error, but training is much slower, as expected. Here are the culprit lines.

beverm2391, Jan 05 '24 23:01

Yes, we are very much aware of this issue. I'm working with @angeloskath on a fix.

awni, Jan 05 '24 23:01

For now, I recommend avoiding the running stats (`track_running_stats=False`) until we fix it.

awni, Jan 05 '24 23:01

> I was able to repro using the code @menzHSE provided. If you pass retain_graph=True to the eval() method it suppresses the error, but it trains much slower, as expected. Here are the culprit lines.

Thanks! With `retain_graph=True`, though, it runs out of memory on my 16 GB M1, compared with approx. 840 MB of memory usage with `retain_graph=False`.
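The memory blow-up described above is consistent with a general property of lazily built compute graphs: if evaluated values keep references to the operations that produced them, state carried across training steps drags the entire history along with it. The toy model below counts how many graph nodes remain reachable after a number of steps. It is a simplified illustration, not MLX's internals; the `Node`, `toy_eval`, and `train` helpers are hypothetical, and it assumes each step's running statistic depends on the previous step's.

```python
# Hypothetical toy lazy-graph model; NOT MLX's implementation.


class Node:
    def __init__(self, value, parents=()):
        self.value = value
        self.parents = list(parents)  # references that keep ancestors alive


def count_reachable(node):
    # Number of nodes still referenced from `node` (a proxy for retained memory).
    seen, stack = set(), [node]
    while stack:
        n = stack.pop()
        if id(n) in seen:
            continue
        seen.add(id(n))
        stack.extend(n.parents)
    return len(seen)


def toy_eval(node, retain_graph=False):
    # Values are computed eagerly in this toy; "evaluation" only decides
    # whether the node keeps references to the graph that produced it.
    if not retain_graph:
        node.parents = []  # detach so ancestors can be garbage-collected
    return node


def train(steps, retain_graph):
    stat = Node(0.0)  # e.g. a running statistic carried across steps
    for step in range(steps):
        batch_stat = Node(float(step))
        stat = Node(0.9 * stat.value + 0.1 * batch_stat.value,
                    parents=(stat, batch_stat))
        toy_eval(stat, retain_graph=retain_graph)
    return count_reachable(stat)


print(train(100, retain_graph=False))  # 1
print(train(100, retain_graph=True))   # 201
```

Without retention the reachable graph stays constant in size; with retention it grows by two nodes per step, so memory use scales with the number of training steps.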

menzHSE, Jan 06 '24 08:01

This should fix it: https://github.com/ml-explore/mlx/pull/385

awni, Jan 06 '24 13:01