nanodl

A Jax-based library for designing and training transformer models from scratch.

2 nanodl issues

With the latest `jax==0.4.25` and `jaxlib==0.4.25`, importing nanodl fails:

```python
import nanodl
# AttributeError: module 'jax.random' has no attribute 'KeyArray'
```
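
One possible workaround is a compatibility alias. This is a minimal sketch that assumes the error comes from type annotations referencing `jax.random.KeyArray`. Newer JAX releases point to `jax.Array` for annotating PRNG keys. The `KeyArray` alias and the `init_key` helper are hypothetical and are not part of nanodl:

```python
import jax

# Compatibility alias (hypothetical): newer JAX releases no longer provide
# jax.random.KeyArray, so fall back to jax.Array for type annotations.
KeyArray = getattr(jax.random, "KeyArray", jax.Array)


def init_key(seed: int) -> KeyArray:
    # Hypothetical helper: create a PRNG key with an annotation that
    # resolves on both older and newer JAX versions.
    return jax.random.PRNGKey(seed)


key = init_key(0)
key, subkey = jax.random.split(key)
```

With this alias, the annotation works whether or not `jax.random.KeyArray` exists. `getattr` with a default catches the `AttributeError` shown above.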

Hey, great job with nanodl! I was looking through the code and noticed that in LaMDA's Trainer the gradients are not averaged across devices here: https://github.com/HMUNACHI/nanodl/blob/18c7f8e3da3c0bbfe2df3638a5e87857ec84868d/nanodl/__src/models/lamda.py#L564-L565 Not...
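
In JAX, the usual data-parallel pattern averages gradients with `jax.lax.pmean` over the `pmap` axis before applying the update. This is a minimal sketch of that pattern and assumes a `jax.pmap`-based training step. The toy linear model, `loss_fn`, `train_step`, `LEARNING_RATE`, and the axis name `"devices"` are hypothetical and are not taken from nanodl:

```python
import functools

import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # hypothetical step size for the toy example


def loss_fn(params, x, y):
    # Toy linear model with mean squared error (stand-in for a real model loss).
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


@functools.partial(jax.pmap, axis_name="devices")
def train_step(params, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Average gradients across all devices on the pmap axis so every replica
    # applies the same update and the parameter copies stay in sync.
    grads = jax.lax.pmean(grads, axis_name="devices")
    loss = jax.lax.pmean(loss, axis_name="devices")
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads
    )
    return new_params, loss


n_dev = jax.local_device_count()
params = {"w": jnp.zeros((3, 1)), "b": jnp.zeros((1,))}
# Replicate parameters with a leading device axis, as pmap expects.
replicated = jax.tree_util.tree_map(lambda p: jnp.stack([p] * n_dev), params)
# Each device receives its own shard of 8 examples.
x = jnp.ones((n_dev, 8, 3))
y = jnp.ones((n_dev, 8, 1))
replicated, loss = train_step(replicated, x, y)
```

If the `pmean` on `grads` is removed, each replica steps using only the gradients from its own shard. The parameter copies can then diverge whenever the shards differ.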