nanodl
A Jax-based library for designing and training transformer models from scratch.
nanodl issues
With the latest `jax==0.4.25` and `jaxlib==0.4.25`, I get:
```python
import nanodl
# AttributeError: module 'jax.random' has no attribute 'KeyArray'
```
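The error suggests that `jax.random.KeyArray` no longer exists in that JAX release. One possible workaround, assuming nanodl uses `KeyArray` only in type annotations (not verified against every call site), is to resolve the name at import time and fall back to `jax.Array`, the general array type. The alias `KeyArray` and the helper `init_params` below are hypothetical illustrations, not nanodl code.

```python
import jax

# Hypothetical compatibility alias: use jax.random.KeyArray where it still exists,
# otherwise fall back to jax.Array (getattr returns the default on AttributeError).
KeyArray = getattr(jax.random, "KeyArray", jax.Array)


def init_params(key: KeyArray, shape: tuple) -> jax.Array:
    # Hypothetical helper: the alias appears only in the annotation, so runtime
    # behaviour is the same whichever type KeyArray resolved to.
    return jax.random.normal(key, shape)


params = init_params(jax.random.PRNGKey(0), (2, 3))
```

Because the alias is only an annotation, this keeps imports working on both older and newer JAX versions without changing any numerics.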
Hey, great job with nanodl! I was just looking through the code and noticed that in the LaMDA Trainer the gradients are not being averaged across devices here: https://github.com/HMUNACHI/nanodl/blob/18c7f8e3da3c0bbfe2df3638a5e87857ec84868d/nanodl/__src/models/lamda.py#L564-L565 Not...
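For reference, here is a minimal sketch of the data-parallel pattern the report refers to: inside a `jax.pmap`-ed step, per-device gradients are averaged with `jax.lax.pmean` before the update, so every replica applies the same step and the parameters stay in sync. The toy linear model, the `train_step` name, the `"devices"` axis name and the fixed `LEARNING_RATE` are illustrative assumptions, not the LaMDA Trainer's actual code.

```python
from functools import partial

import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # hypothetical fixed step size


def loss_fn(params, x, y):
    # Mean-squared error of a toy linear model (stand-in for the real model's loss).
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


@partial(jax.pmap, axis_name="devices")
def train_step(params, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Average gradients (and the reported loss) over every device in the pmap.
    grads = jax.lax.pmean(grads, axis_name="devices")
    loss = jax.lax.pmean(loss, axis_name="devices")
    # Plain SGD update; identical on every replica because the grads are identical.
    params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    return params, loss


# Usage: replicate params and shard the batch along a leading device axis.
n_dev = jax.local_device_count()
params = {"w": jnp.zeros((3,)), "b": jnp.zeros(())}
replicated = jax.tree_util.tree_map(lambda p: jnp.stack([p] * n_dev), params)
x = jnp.ones((n_dev, 4, 3))  # 4 examples per device, 3 features
y = jnp.ones((n_dev, 4))
new_params, loss = train_step(replicated, x, y)
```

Without the `pmean` on `grads`, each device would step with its own local gradient and the replicas would drift apart whenever their shards differ.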