Andreas Steiner

Results 36 comments of Andreas Steiner

+1 for *both*:

- adding a broad statement that there are numerical differences, and linking to #3128 as an example
- adding detailed information about known differences (like weight initializers for...

How exactly did you set up your VM? I tried the following:

```
PROJECT_ID=...
TPU_NAME=isssue_3364
ZONE=us-central1-a
gcloud alpha compute tpus tpu-vm create $TPU_NAME \
  --zone=$ZONE \
  --version v2-alpha --accelerator-type v2-32...
```

@gkroiz, what was the exact command you used to start the TPU VMs? Ah, that's probably also why @cgarciae could not reproduce the issue. @cgarciae, can you try with that...

I can confirm I can reproduce the `AttributeError: module 'jax.numpy' has no attribute 'DeviceArray'` on a fresh Colab with a CPU runtime (Python 3.10.12). Setup:

```
!pip install jax==0.4.16 jaxlib==0.4.16...
```
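For code that still refers to `jnp.DeviceArray`, a minimal sketch of the usual replacement, assuming a jax release where `jax.Array` is the public array type (as in the 0.4.x line used above). `is_jax_array` is a hypothetical helper name used only for illustration; the point is the `isinstance` check against `jax.Array`.

```python
import jax
import jax.numpy as jnp
import numpy as np


def is_jax_array(x) -> bool:
    """Return True for JAX arrays (including tracers inside jit), else False.

    Checks that used to read ``isinstance(x, jnp.DeviceArray)`` can test
    against ``jax.Array`` instead on recent jax releases.
    """
    return isinstance(x, jax.Array)


print(is_jax_array(jnp.arange(4)))  # True: a JAX array
print(is_jax_array(np.arange(4)))   # False: a plain NumPy array
```

The same check also holds for the output of a `jax.jit`-compiled function, so it can replace `DeviceArray` checks in both eager and jitted code paths.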

@gkroiz, so the problem seems to be that the `optax` version pinned in the example's `requirements.txt` is not compatible:
https://github.com/google/flax/blob/242f84cac883108eb1e945221c5c544bef6cbd21/examples/imagenet/requirements.txt#L9

If you do a `pip install flax==0.7.4 optax==0.1.7` *after* installing...
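To confirm which versions actually ended up installed after overriding the pins, a small check with the standard library's `importlib.metadata` can compare them against the intended versions. This is a sketch: `check_pins` is a hypothetical helper, and the expected versions are simply the ones from the `pip install` command mentioned above.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Optional, Tuple

# Versions from the `pip install flax==0.7.4 optax==0.1.7` suggestion.
EXPECTED_PINS = {"flax": "0.7.4", "optax": "0.1.7"}


def check_pins(expected: Dict[str, str]) -> Dict[str, Tuple[Optional[str], str]]:
    """Return {package: (installed, expected)} for every pin that does not match.

    A package that is not installed at all is reported with installed=None.
    """
    mismatches = {}
    for package, wanted in expected.items():
        try:
            installed: Optional[str] = version(package)
        except PackageNotFoundError:
            installed = None
        if installed != wanted:
            mismatches[package] = (installed, wanted)
    return mismatches


if __name__ == "__main__":
    problems = check_pins(EXPECTED_PINS)
    print(problems or "all pins match")
```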

@cgarciae wdyt, should we update (some of) our examples to newer flax/jax versions?

@gkroiz, can you specify the system version, Python version, Flax repo commit hash, and the output of `pip freeze` from one of the machines?
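The four pieces of information requested here can be collected in one go. A minimal sketch, assuming it is run on each machine from inside the Flax checkout and that `git` and `pip` are available; `environment_report` is a hypothetical helper, not part of Flax.

```python
import platform
import subprocess
import sys
from typing import Optional


def _run(cmd) -> Optional[str]:
    """Run a command and return its stripped stdout, or None if it fails."""
    try:
        result = subprocess.run(cmd, capture_output=True, text=True, check=True)
    except (OSError, subprocess.CalledProcessError):
        return None
    return result.stdout.strip()


def environment_report(repo_dir: str = ".") -> dict:
    """Collect system, Python version, git commit and `pip freeze` output."""
    commit = _run(["git", "-C", repo_dir, "rev-parse", "HEAD"])
    freeze = _run([sys.executable, "-m", "pip", "freeze"])
    return {
        "system": platform.platform(),
        "python": platform.python_version(),
        "commit": commit,  # None if repo_dir is not a git checkout
        "pip_freeze": freeze.splitlines() if freeze else [],
    }


if __name__ == "__main__":
    for key, value in environment_report().items():
        if key == "pip_freeze":
            print(key + ":")
            print("\n".join("  " + line for line in value))
        else:
            print(f"{key}: {value}")
```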

Glad to see we got the dependency problems sorted out. The remaining problem seems to be caused by the configuration:
https://github.com/google/flax/blob/d059ba8aadfe839a7bb0ce7b2c47afb5d91fdf0a/examples/imagenet/configs/fake_data_benchmark.py#L22-L36

Since we're using `ACCELERATOR_TYPE=v5e-16`, we would have `config.batch_size...

Can you try setting `config.steps_per_eval = 1` and see whether that runs?
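A minimal sketch of the batch-size and eval-step arithmetic behind this kind of configuration problem. It assumes that the training script (a) requires the global batch size to be divisible by the number of devices, (b) splits the batch evenly across processes, and (c) derives the number of eval steps as `num_eval_examples // batch_size` when `steps_per_eval` is left at a sentinel value of `-1`. These are assumptions for illustration, not a copy of the example's code; `derive_batch_and_eval_steps` and all numbers below are hypothetical. Under (c), a large enough global batch yields zero eval steps, which is one situation where pinning `steps_per_eval` to 1 would change the outcome.

```python
from typing import Tuple


def derive_batch_and_eval_steps(
    global_batch_size: int,
    num_devices: int,
    num_processes: int,
    num_eval_examples: int,
    steps_per_eval: int = -1,
) -> Tuple[int, int]:
    """Return (per-process batch size, number of eval steps).

    Hypothetical helper: steps_per_eval == -1 means "derive it from the
    eval set size"; any other value is used as-is.
    """
    if global_batch_size % num_devices:
        raise ValueError(
            f"batch size {global_batch_size} is not divisible by {num_devices} devices"
        )
    if global_batch_size % num_processes:
        raise ValueError(
            f"batch size {global_batch_size} is not divisible by {num_processes} processes"
        )
    local_batch_size = global_batch_size // num_processes
    if steps_per_eval == -1:
        eval_steps = num_eval_examples // global_batch_size
    else:
        eval_steps = steps_per_eval
    return local_batch_size, eval_steps


# Illustrative numbers only (not a claim about the v5e-16 host layout).
print(derive_batch_and_eval_steps(256, 16, 4, 50_000))     # (64, 195)
print(derive_batch_and_eval_steps(65_536, 16, 4, 50_000))  # (16384, 0)
print(derive_batch_and_eval_steps(65_536, 16, 4, 50_000, steps_per_eval=1))  # (16384, 1)
```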

Is the `with mesh:` required with `nn.with_logical_constraint()`? (Under the hood, this calls `lax.with_sharding_constraint()`.)