Ivy Zheng

Results 29 comments of Ivy Zheng

https://github.com/google/flax/pull/3229 includes your change and fixes another issue that appeared with your change. Feel free to give it a try!

In `save_checkpoint_multiprocess`, we remove old checkpoints before writing new ones, which results in the total number of checkpoints being `keep`+1. This is because Orbax `AsyncCheckpointer` does not support custom...
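The `keep`+1 count can be illustrated with a toy model of the ordering described above. This is only a sketch in plain Python: `save_with_prune_first` is a hypothetical helper, not a Flax or Orbax function, and it assumes pruning trims the existing checkpoints down to `keep` before the new one is written.

```python
def save_with_prune_first(existing, new_step, keep):
    """Toy model: prune existing checkpoints to `keep`, then write the new one."""
    # Retain only the `keep` most recent steps that already exist.
    retained = sorted(existing)[-keep:] if keep > 0 else []
    # The new checkpoint is written after pruning, so pruning never counts it.
    return retained + [new_step]


steps = []
for step in range(1, 7):
    steps = save_with_prune_first(steps, step, keep=3)

# Once more than `keep` saves have happened, keep + 1 checkpoints remain.
print(len(steps))
```

Pruning after the write instead would leave exactly `keep` checkpoints, which is the ordering difference the comment points at.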

To restore an array with a certain sharding, you need not only the global `mesh` but also a `PartitionSpec` specifying how the array axes should be sharded. That means for...
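As a rough sketch of what "a `mesh` plus a `PartitionSpec`" looks like in JAX, the snippet below builds a `NamedSharding` and attaches it to an abstract array description. The axis name `"data"`, the array shape, and the 1-D mesh over all local devices are assumptions made for illustration, and how such a target is passed to a particular restore call is not shown here.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# A 1-D mesh over whatever devices are available (a single CPU also works).
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))

# Shard array axis 0 across the "data" mesh axis; leave axis 1 unsharded.
spec = PartitionSpec("data", None)
sharding = NamedSharding(mesh, spec)

# Abstract description of the array: shape, dtype, and target sharding.
# No memory is allocated for the array itself.
rows = 4 * devices.size  # chosen so axis 0 divides evenly across devices
target = jax.ShapeDtypeStruct((rows, 8), jnp.float32, sharding=sharding)
```

The mesh alone only says which devices exist and how they are arranged; the `PartitionSpec` is what maps each array axis onto a mesh axis.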

I don't think they check shapes, but I am not sure it's a good idea to require that check. That means before loading large arrays, users need to pre-create a...

Hi - one of the Orbax team members replied to your StackOverflow thread, please take a look. Regarding the `SaveArgs.aggregate is deprecated` warning, it seems that internally `aggregate=True` will happen whenever...

You'll need to shard the model params across chips - `replicate` forces a full copy of them onto each chip, and a single chip is too small to hold a 7B model. Instead of `replicate`, use `jax.device_put`...
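A minimal sketch of sharding a parameter tree with `jax.device_put` follows. The parameter names, shapes, the `"model"` axis name, and the choice to split the last axis of every array are all assumptions for illustration; a real 7B model would need a partitioning scheme chosen per layer.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# A 1-D mesh over all available devices (runs on a single CPU as well).
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("model",))

# Hypothetical parameter tree; the last axis is sized to divide evenly.
hidden = 8 * devices.size
params = {
    "dense": {
        "kernel": jnp.ones((16, hidden)),
        "bias": jnp.zeros((hidden,)),
    }
}


def shard_last_axis(x):
    """Place `x` so its last axis is split across the "model" mesh axis."""
    spec = PartitionSpec(*([None] * (x.ndim - 1)), "model")
    return jax.device_put(x, NamedSharding(mesh, spec))


# Each device now holds only a slice of every parameter, not a full copy.
sharded_params = jax.tree_util.tree_map(shard_last_axis, params)
```

With N devices, each device stores roughly 1/N of every array, whereas replication stores the full tree N times.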

The error messages you posted all look like normal INFO printouts - is there a more specific error message or stack trace, or did the program just crash after these...

From your description it sounds like the program is blocked, rather than failing and exiting immediately? If blocked, it might be that the gpu devices (or their cpu hosts?) are...

Do you have any printouts in the `train_epoch` function to pinpoint the line where it blocks? We would really benefit from a smaller code sample that can repro the problem and narrow down...

Hey, not sure if it helps but directly using `dataclasses.replace` would work:
```
from flax import struct
import dataclasses

@struct.dataclass
class c123:
  variable1: int = 0

c123instance = dataclasses.replace(c123(variable1=0), variable1=1)
...
```