mlx-examples
[BUG] cifar example fails with ValueError when replacing LayerNorm with BatchNorm with default params
When replacing LayerNorm with BatchNorm (merged in https://github.com/ml-explore/mlx/pull/217) using default parameters, the cifar example fails with a ValueError on mlx 0.0.7:
File "../mlx-examples/cifar/main.py", line 48, in train_epoch
mx.eval(model.parameters(), optimizer.state)
ValueError: [eval] Illegal to eval an array during function transform without graph retention.
Using BatchNorm with nn.BatchNorm(dims, track_running_stats=False) seems to work OK.
Is that expected behavior?
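For reference, here is a minimal sketch of the construction in question, using a hypothetical conv + norm block rather than the exact code on the resnet_batch_norm branch. With default parameters BatchNorm tracks running statistics and triggers the error above; passing track_running_stats=False avoids it:

```python
import mlx.nn as nn

class Block(nn.Module):
    """Illustrative conv + norm block, not the actual branch code."""

    def __init__(self, dims: int):
        super().__init__()
        self.conv = nn.Conv2d(dims, dims, kernel_size=3, padding=1)
        # Default BatchNorm (track_running_stats=True) hits the eval error
        # inside the grad transform; disabling running stats works around it.
        self.bn = nn.BatchNorm(dims, track_running_stats=False)

    def __call__(self, x):
        return nn.relu(self.bn(self.conv(x)))
```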
Reproduce the issue
The ResNet with BatchNorm is available on the resnet_batch_norm branch at https://github.com/menzHSE/mlx-examples. Run python main.py in mlx-examples/cifar.
That is not expected; it sounds like a bug. Thanks for reporting, I will take a look.
I was able to repro using the code @menzHSE provided. If you pass retain_graph=True to the eval() method it suppresses the error, but it trains much slower, as expected. Here are the culprit lines.
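For context, a rough sketch of where that keyword goes, assuming the mlx 0.0.7-era eval() API discussed in this thread; this is a simplified training step, not the exact code from cifar/main.py:

```python
import mlx.core as mx
import mlx.nn as nn

def loss_fn(model, X, y):
    return mx.mean(nn.losses.cross_entropy(model(X), y))

def train_step(model, optimizer, X, y):
    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
    loss, grads = loss_and_grad_fn(model, X, y)
    optimizer.update(model, grads)
    # Retaining the graph suppresses the ValueError, but it is slower and
    # keeps much more memory alive (see the follow-up comment below).
    mx.eval(model.parameters(), optimizer.state, retain_graph=True)
    return loss
```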
Yes, we are very much aware of this issue. Working with @angeloskath on a fix. For now I recommend avoiding the running stats until we fix it.
Thanks! With retain_graph=True, though, it runs out of memory on my 16 GB M1, vs. approx. 840 MB of usage with retain_graph=False.
This should fix it: https://github.com/ml-explore/mlx/pull/385