
Training more than one epoch

Open • peregilk opened this issue 5 months ago • 4 comments

@aireenmei I'm referring you here because I think this issue is touched on in #571, where you wrote:

> I did not implement the auto restart because some users may not want their model to see repetitive data. I can add the multi-epoch support to our backlog. Meanwhile it should be straightforward to change the shard update logic here: https://github.com/google/maxtext/blob/main/MaxText/input_pipeline/_input_pipeline_utils.py#L105
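Here is a minimal sketch of the kind of change suggested above. It assumes shards are assigned round-robin, so host h reads shards h, h + num_hosts, and so on. `next_shard` is a hypothetical helper written for illustration, not the function in `_input_pipeline_utils.py`. With `allow_repeat=False` it mirrors a stop-and-pad behaviour by returning `None`. With `allow_repeat=True` it wraps the host back to its first shard so that the host starts another epoch.

```python
from typing import Optional


def next_shard(current_shard: int, num_hosts: int, num_shards: int,
               allow_repeat: bool) -> Optional[int]:
    """Hypothetical shard-advance rule for one data-loading host.

    Assumes round-robin assignment, so a host's shards are
    host_index, host_index + num_hosts, host_index + 2 * num_hosts, ...
    and num_hosts <= num_shards.
    """
    candidate = current_shard + num_hosts
    if candidate < num_shards:
        # Normal case: move to this host's next shard in the current epoch.
        return candidate
    if allow_repeat:
        # Multi-epoch case: current_shard % num_hosts recovers the host's
        # first shard, so the host starts a new pass over its own shards.
        return current_shard % num_hosts
    # Single-epoch case: no shard left; the caller would emit padding.
    return None


print(next_shard(3, num_hosts=64, num_shards=256, allow_repeat=False))    # 67
print(next_shard(195, num_hosts=64, num_shards=256, allow_repeat=False))  # None
print(next_shard(195, num_hosts=64, num_shards=256, allow_repeat=True))   # 3
```

Because each host wraps independently, a host with smaller shards would start its next epoch earlier than the others. The hosts would therefore drift out of epoch alignment, which is one form of the repeated-data trade-off mentioned in the quote.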

The behaviour seems to have changed since then, and it may now be even more confusing. I am not sure what has changed in the code here.

What I am trying to do is switch to a different dataset during training, here from step 160k. This is a fairly small, special-task dataset, and I am studying its effect. The dataset has 256 shards, and one epoch is roughly 350 steps.
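These numbers can be cross-checked with a little arithmetic. The sketch below assumes round-robin shard assignment and uses a hypothetical helper, `shards_for_host`. The figure of 64 data-loading hosts is an inference from the stride in the log below, where host 3 moves 3 → 67 → 131 → 195. The figure of 91 steps per shard is the observed gap between the logged shard switches, not a configured value.

```python
def shards_for_host(host_index: int, num_hosts: int, num_shards: int) -> list[int]:
    """Shards one host reads under an assumed round-robin assignment."""
    return list(range(host_index, num_shards, num_hosts))


NUM_SHARDS = 256       # from the dataset description
NUM_HOSTS = 64         # assumption: inferred from the shard stride in the log
STEPS_PER_SHARD = 91   # observed: 160086 -> 160177 -> 160268 -> 160359

host3 = shards_for_host(3, NUM_HOSTS, NUM_SHARDS)
print(host3)                          # [3, 67, 131, 195]
print(len(host3) * STEPS_PER_SHARD)   # 364
```

Four shards of about 91 steps each give roughly 364 steps per host, which matches the ~350-step epoch. Under these assumptions, every host runs out of shards at about the same point.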

Here is what happens, with my comments:

```
# Perfectly normal. Switching to next shard. Weights and loss are fine
Updating host 3 dataset 0, was on shard 3
New shard is 67
completed step: 160086, seconds: 4.090, TFLOP/s/device: 116.004, Tokens/s/device: 2003.075, total_weights: 2080798, loss: 1.113

# Still normal
Updating host 3 dataset 0, was on shard 67
New shard is 131
completed step: 160177, seconds: 4.090, TFLOP/s/device: 115.995, Tokens/s/device: 2002.925, total_weights: 2078579, loss: 1.072

# Still normal
Updating host 3 dataset 0, was on shard 131
New shard is 195
completed step: 160268, seconds: 4.090, TFLOP/s/device: 115.989, Tokens/s/device: 2002.811, total_weights: 2079952, loss: 1.049

# Here things start to go south. The host starts generating all-zero padding
completed step: 160359, seconds: 4.090, TFLOP/s/device: 116.001, Tokens/s/device: 2003.031, total_weights: 2077782, loss: 1.036

# Runs for a while, but then total_weights starts dropping, and the loss starts to drop
completed step: 160367, seconds: 4.091, TFLOP/s/device: 115.971, Tokens/s/device: 2002.507, total_weights: 2034296, loss: 1.030
completed step: 160368, seconds: 4.090, TFLOP/s/device: 116.002, Tokens/s/device: 2003.040, total_weights: 1860858, loss: 1.028
completed step: 160369, seconds: 4.090, TFLOP/s/device: 115.995, Tokens/s/device: 2002.928, total_weights: 1207504, loss: 1.038
completed step: 160370, seconds: 4.090, TFLOP/s/device: 115.991, Tokens/s/device: 2002.854, total_weights: 616193, loss: 1.038
completed step: 160371, seconds: 4.090, TFLOP/s/device: 115.984, Tokens/s/device: 2002.734, total_weights: 184994, loss: 1.037
completed step: 160372, seconds: 4.090, TFLOP/s/device: 115.984, Tokens/s/device: 2002.739, total_weights: 46490, loss: 1.058
completed step: 160373, seconds: 4.091, TFLOP/s/device: 115.976, Tokens/s/device: 2002.600, total_weights: 32596, loss: 0.989
completed step: 160374, seconds: 4.091, TFLOP/s/device: 115.978, Tokens/s/device: 2002.634, total_weights: 32491, loss: 1.041

# A bit later
completed step: 160460, seconds: 4.090, TFLOP/s/device: 115.987, Tokens/s/device: 2002.787, total_weights: 32673, loss: 0.980
completed step: 160461, seconds: 4.091, TFLOP/s/device: 115.970, Tokens/s/device: 2002.484, total_weights: 32503, loss: 1.043
completed step: 160462, seconds: 4.090, TFLOP/s/device: 115.984, Tokens/s/device: 2002.736, total_weights: 1904, loss: 1.068
completed step: 160463, seconds: 4.091, TFLOP/s/device: 115.966, Tokens/s/device: 2002.420, total_weights: 0, loss: 0.000
completed step: 160464, seconds: 4.090, TFLOP/s/device: 115.990, Tokens/s/device: 2002.845, total_weights: 0, loss: 0.000

```
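The drop in total_weights fits this picture if total_weights counts non-padding target tokens. This is an assumption, but it is consistent with the value reaching 0 once everything is padding. The rough estimate below also assumes 64 hosts, inferred from the shard stride. Each host would then contribute about 2,080,000 / 64 = 32,500 weights per step while it still has data. `active_hosts` is a hypothetical helper.

```python
# total_weights values copied from the log above.
TOTAL_WEIGHTS = {
    160359: 2077782,
    160367: 2034296,
    160368: 1860858,
    160369: 1207504,
    160370: 616193,
    160371: 184994,
    160372: 46490,
    160373: 32596,
    160461: 32503,
    160462: 1904,
    160463: 0,
}

NUM_HOSTS = 64          # assumption: inferred from the shard stride
FULL_WEIGHTS = 2080000  # approximate total_weights when every host has data


def active_hosts(total_weights: int) -> float:
    """Rough count of hosts still feeding non-padding tokens."""
    per_host = FULL_WEIGHTS / NUM_HOSTS
    return total_weights / per_host


for step, weights in TOTAL_WEIGHTS.items():
    print(step, round(active_hosts(weights), 1))
```

Under these assumptions, the estimate stays at about 1.0 from step 160373 to step 160461. This suggests that a single host was still feeding real data until it also ran out, just before step 160463.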

This behaviour is somewhat unpredictable, especially since some shards can be smaller than others, which makes it hard to know when the first host will run out of shards. Running out of shards seems to hurt the model.
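One way to predict when padding starts is to compute how many full steps each host can serve from its own shards. This sketch makes three assumptions: shards are assigned round-robin, each host consumes a fixed number of examples per step (sequence packing would break this), and the shard sizes are known, for example from a record count. `steps_until_exhausted` is a hypothetical helper, and the sizes below are made-up illustration values. The first host to run out is the one with the smallest total, and full batches end at that step.

```python
def steps_until_exhausted(shard_sizes: list[int], num_hosts: int,
                          examples_per_step_per_host: int) -> list[int]:
    """Full training steps each host can serve before running out of shards.

    Assumes round-robin shard assignment (host h reads shards h, h + num_hosts, ...)
    and that a host emits padding once its shards are used up.
    """
    steps = []
    for host in range(num_hosts):
        # Sum the examples in every shard this host would read.
        examples = sum(shard_sizes[host::num_hosts])
        steps.append(examples // examples_per_step_per_host)
    return steps


# Hypothetical shard sizes (examples per shard), not measured values.
sizes = [100, 120, 80, 100, 90, 110, 100, 100]
per_host_steps = steps_until_exhausted(sizes, num_hosts=4, examples_per_step_per_host=10)
first_host = min(range(len(per_host_steps)), key=per_host_steps.__getitem__)
print(per_host_steps, first_host)  # [19, 23, 18, 20] 2
```

In this toy case, host 2 runs out after 18 steps even though the average host could serve 20. A single small shard is enough to start padding earlier than the epoch length would suggest.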

What is your advice here?

peregilk, Sep 24 '24 11:09