
pjit example for ImageNet

Open • mattiasmar opened this issue 3 years ago • 2 comments

Description of the model to be implemented

ResNet18

Dataset the model could be trained on

Mock data from tensorflow_datasets (tfds.testing.mock_data), e.g.:

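import tensorflow_datasets as tfds
# Inside this context, tfds.load yields randomly generated examples instead of downloading data.
with tfds.testing.mock_data(num_examples=128):
    ds = tfds.load('imagenette', split='train')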

Specific points to consider

  • The existing ImageNet example uses pmap, and replacing pmap with pjit in it is not trivial. A pure pjit example would be great to learn from.
  • The example should also work on a GPU host, for the benefit of the broader community.
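A minimal sketch of what a pmap-free, data-parallel training step could look like, assuming a recent JAX release in which pjit has been merged into jax.jit (which accepts in_shardings and out_shardings). The linear model and the names loss_fn, train_step and LEARNING_RATE are hypothetical stand-ins for ResNet18 and its optimizer; this is not code from the Flax ImageNet example. Parameters are replicated over a one-dimensional device mesh and the global batch is split along its leading axis, so the same code should run on a single CPU, several GPUs, or a TPU slice.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

LEARNING_RATE = 0.1  # hypothetical hyperparameter

# One-dimensional mesh over all visible devices (CPU, GPU or TPU).
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
replicated = NamedSharding(mesh, P())           # full copy on every device
batch_sharded = NamedSharding(mesh, P("data"))  # leading (batch) axis split across devices


def loss_fn(params, x, y):
    # Hypothetical stand-in for the ResNet18 forward pass: a linear model.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


def train_step(params, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Unlike a typical pmap train step, there is no jax.lax.pmean here: the mean
    # is taken over the global batch and the compiler inserts the cross-device reduction.
    new_params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    return new_params, loss


# jax.jit with explicit shardings plays the role pjit used to play.
train_step = jax.jit(
    train_step,
    in_shardings=(replicated, batch_sharded, batch_sharded),
    out_shardings=(replicated, replicated),
)

# Usage: the global batch size must be divisible by the number of devices.
n = 8 * len(jax.devices())
x = jax.device_put(jax.random.normal(jax.random.PRNGKey(0), (n, 4)), batch_sharded)
y = jax.device_put(x @ jnp.arange(4.0) + 1.0, batch_sharded)
params0 = jax.device_put({"w": jnp.zeros(4), "b": jnp.array(0.0)}, replicated)

params = params0
for _ in range(50):
    params, loss = train_step(params, x, y)
```

Because the batch is described by a sharding rather than by a leading device axis, there is no need to reshape inputs to [num_devices, per_device_batch, ...] as with pmap.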

Reference implementations in other frameworks

https://github.com/google/flax/blob/main/examples/imagenet/imagenet.ipynb

mattiasmar, Jan 06 '22 13:01

Hi @mattiasmar, we are considering adding the WMT example with pjit instead, since many people working on large language models are interested in it (and language models are usually much bigger than vision models). Would that work for you as well?

marcvanzee, Apr 27 '22 20:04

Yes, WMT would be a good example. One request: could you apply pjit as high up in the code as possible? As a user, I want to pjit as large a part of my model/program as possible. Ideally, I would use pjit only once in my program.

mattiasmar, Apr 28 '22 18:04
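The request to apply pjit only once, as high up as possible, can be sketched by wrapping an entire training loop in a single top-level jit call and expressing the loop with jax.lax.scan. This assumes a recent JAX release in which pjit has been merged into jax.jit; the linear model and the names train, sgd_update and LEARNING_RATE are hypothetical, not taken from the WMT or ImageNet examples. Helper functions stay ordinary Python and are traced into the one compiled program instead of being wrapped individually.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

LEARNING_RATE = 0.1  # hypothetical hyperparameter

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
replicated = NamedSharding(mesh, P())
# Batches are stacked as [num_steps, global_batch, ...]; split the batch axis.
stacked_batches = NamedSharding(mesh, P(None, "data"))


def loss_fn(params, x, y):
    # Hypothetical linear model standing in for the real network.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


def sgd_update(params, grads):
    return jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)


def train(params, xs, ys):
    # loss_fn, sgd_update and step are not jitted themselves; they are traced
    # into the single program compiled for train.
    def step(params, batch):
        x, y = batch
        loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
        return sgd_update(params, grads), loss

    return jax.lax.scan(step, params, (xs, ys))


# The only jit (formerly pjit) call in the whole program.
train = jax.jit(
    train,
    in_shardings=(replicated, stacked_batches, stacked_batches),
    out_shardings=(replicated, replicated),
)

# Usage: 100 steps, each on a fresh global batch divisible by the device count.
num_steps, n = 100, 8 * len(jax.devices())
xs = jax.random.normal(jax.random.PRNGKey(0), (num_steps, n, 4))
ys = xs @ jnp.arange(4.0) + 1.0
xs, ys = jax.device_put((xs, ys), stacked_batches)
params = jax.device_put({"w": jnp.zeros(4), "b": jnp.array(0.0)}, replicated)
params, losses = train(params, xs, ys)  # losses has shape (num_steps,)
```

Scanning over every step requires all batches to be on device up front, which is impractical for real ImageNet or WMT input pipelines; in that setting the top-level jitted unit is usually a single train step called from a Python loop, which still keeps pjit/jit to one call site.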