pjit example for ImageNet
Description of the model to be implemented
ResNet18
Dataset the model could be trained on
tensorflow_datasets mock data, e.g.:

    import tensorflow_datasets as tfds

    with tfds.testing.mock_data(num_examples=128):
        ds = tfds.load('imagenette', split='train')
Specific points to consider
- The existing ImageNet example uses pmap, and replacing pmap with pjit there is not trivial. A pure pjit example would be great to learn from.
- The example should also work on a GPU host, for the benefit of the broader community.
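A minimal sketch of a pure pjit data-parallel train step. It assumes JAX 0.4 or later, where `pjit` accepts `in_shardings`/`out_shardings` and `jax.sharding` provides `Mesh` and `NamedSharding`. Older releases used `in_axis_resources` and a `with mesh:` context instead, and in recent releases `jax.jit` accepts the same sharding arguments. A tiny linear classifier stands in for ResNet18, and `loss_fn`, `train_step` and `LEARNING_RATE` are hypothetical names. Unlike the pmap version, there is no per-device leading axis and no explicit `jax.lax.pmean`. The batch is sharded along a `'data'` mesh axis, the parameters are replicated, and the compiler inserts the gradient reduction. Because the mesh is built from `jax.devices()`, the same code should run on CPU, GPU or TPU hosts.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental.pjit import pjit
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

LEARNING_RATE = 0.1  # hypothetical SGD step size

# A 1-D mesh over every visible device, whatever the backend.
mesh = Mesh(np.array(jax.devices()), axis_names=('data',))
replicated = NamedSharding(mesh, P())           # params: a full copy on each device
batch_sharded = NamedSharding(mesh, P('data'))  # batch: split along axis 0


def loss_fn(params, x, y):
    """Softmax cross-entropy of a linear classifier (stand-in for ResNet18)."""
    logits = x @ params['w'] + params['b']
    log_probs = jax.nn.log_softmax(logits)
    one_hot = jax.nn.one_hot(y, logits.shape[-1])
    return -jnp.mean(jnp.sum(one_hot * log_probs, axis=-1))


def train_step(params, x, y):
    # Written as single-device code: no pmean and no per-device axis.
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    return params, loss


# A single sharding acts as a prefix for the whole params dict.
p_train_step = pjit(
    train_step,
    in_shardings=(replicated, batch_sharded, batch_sharded),
    out_shardings=(replicated, replicated),
)
```

The global batch size has to be divisible by the number of devices along the `'data'` axis.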
Reference implementations in other frameworks
https://github.com/google/flax/blob/main/examples/imagenet/imagenet.ipynb
Hi @mattiasmar, we are considering adding the WMT example with pjit instead, since many people working on large language models are interested in it (and language models are usually much bigger than vision models). Would that work for you as well?
Yes, WMT would be a good example. One request: could you apply pjit as high up in the code as possible? As a user, I want to pjit as large a part of my model/program as possible. Ideally, I would use pjit only once in my program.
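One way to read "pjit as high up as possible" is to put a whole multi-step training loop inside a single pjit call, with `jax.lax.scan` iterating over a stack of batches. The sketch below assumes JAX 0.4 or later and uses a toy linear classifier. `train_many_steps`, `LEARNING_RATE` and the array shapes are hypothetical. Input pipelines such as tfds still run outside the compiled function, so a real program would likely feed a chunk of batches per call rather than a whole epoch.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental.pjit import pjit
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

LEARNING_RATE = 0.1  # hypothetical SGD step size

mesh = Mesh(np.array(jax.devices()), axis_names=('data',))
replicated = NamedSharding(mesh, P())
# Stacked batches: axis 0 is the step, axis 1 (the batch) is split across devices.
stacked_batches = NamedSharding(mesh, P(None, 'data'))


def loss_fn(params, x, y):
    """Softmax cross-entropy of a linear classifier."""
    logits = x @ params['w'] + params['b']
    log_probs = jax.nn.log_softmax(logits)
    one_hot = jax.nn.one_hot(y, logits.shape[-1])
    return -jnp.mean(jnp.sum(one_hot * log_probs, axis=-1))


def train_many_steps(params, xs, ys):
    """Runs one SGD step per leading slice of xs/ys; returns params and per-step losses."""

    def step(params, batch):
        x, y = batch
        loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
        params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
        return params, loss

    return jax.lax.scan(step, params, (xs, ys))


# The only pjit in the program: it wraps the entire training loop.
p_train_many_steps = pjit(
    train_many_steps,
    in_shardings=(replicated, stacked_batches, stacked_batches),
    out_shardings=(replicated, replicated),
)
```

The trade-off is compile time and memory: the stacked batches for every step in the call must fit on the devices at once.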