
Effective train/eval batch_size is always 1 due to batcher default arg "variable_shapes=True"

Open syzymon opened this issue 4 years ago • 0 comments

Description

With constant-shape inputs (for instance imagenet32, where every example has length 3072; the same applies to configs such as https://github.com/google/trax/blob/master/trax/supervised/configs/reformer_imagenet64.gin), the effective training and evaluation batch_size is always 1 unless variable_shapes=False is set explicitly, which the config above does not do.

The cause is the variable_shapes argument of this function, which defaults to True: https://github.com/google/trax/blob/master/trax/data/inputs.py#L791. With that default the bucketer takes over batching, and the effective train/eval batch ends up as 1 regardless of what the user specified.
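To illustrate how a length-based bucketer can collapse the batch size to 1 for long constant-shape examples, here is a minimal sketch. It is not trax's actual code: the function name `bucket_batch_size`, the `bucket_length=32` default, and the boundary and batch-size lists are all assumptions made for illustration. The idea is only that longer examples fall into buckets with smaller batches, and anything past the last boundary gets a batch of 1.

```python
def bucket_batch_size(example_length, batch_size, bucket_length=32):
    """Return the batch size a hypothetical length-based bucketer would pick.

    Assumption: boundaries grow geometrically around `bucket_length`, and the
    per-bucket batch size shrinks by the same factor. This is an illustrative
    scheme, not the one implemented in trax.
    """
    boundaries = [
        bucket_length // 4, bucket_length // 2, bucket_length,
        bucket_length * 2, bucket_length * 4, bucket_length * 8,
        bucket_length * 16,
    ]
    sizes = [
        batch_size * 4, batch_size * 2, batch_size,
        batch_size // 2, batch_size // 4, batch_size // 8,
        batch_size // 16,
    ]
    for boundary, size in zip(boundaries, sizes):
        if example_length <= boundary:
            # Never return a batch smaller than one example.
            return max(1, size)
    # Longer than every boundary: the overflow bucket holds one example.
    return 1


# An example of length 3072 exceeds the largest boundary (32 * 16 = 512),
# so the requested batch size is ignored in this sketch.
short_example = bucket_batch_size(32, batch_size=8)
long_example = bucket_batch_size(3072, batch_size=8)
```

Under these assumed boundaries, an example of length 32 keeps the requested batch size, while a 3072-long example always lands in the overflow bucket, which matches the symptom described above.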

This looks like a bug with two annoying consequences: it causes large, unexplained variance among eval batches, and it means that a batch size greater than 1 per device is only nominal. A user with constant-shape data who is not aware of the variable_shapes argument never actually trains on the configured batch size.

I confirmed this with a repro on the latest trax dev (>=1.3.7); the problem probably also exists in 1.3.6 and earlier.

Setting variable_shapes=False explicitly in the gin config solves the problem. However, having to specify it there does not seem like good default behaviour, and many more people are likely to run into this bug.

...

Environment information

Environment-independent problem (the issue is in the logic).

For bugs: reproduction and error logs

Steps to reproduce:

For this repro to work, variable_shapes must not be specified in config.gin (its default value of True is what causes the issue), and the input must be of constant shape.

from trax.data.inputs import batcher
import gin

gin.parse_config_file('config.gin')

b = batcher()
ev = b.eval_stream(1)
print(next(ev)[0].shape)  # Prints (1, ...) regardless of the train/eval batch size specified in config.gin

Error logs:

None. I noticed this by printing eval batch shapes in a debugger.
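Since the problem produces no error, a small guard that fails loudly may help catch it. The sketch below is an assumption-laden helper, not part of trax: `check_leading_batch_dim` is a hypothetical name, and it assumes only that the stream yields tuples whose first element has a `.shape` attribute (as in the repro above). A stand-in stream is used so the snippet runs without trax installed.

```python
from types import SimpleNamespace


def check_leading_batch_dim(stream, expected_batch_size):
    """Pull one batch from `stream` and verify its leading dimension.

    Hypothetical helper: assumes each batch is a tuple whose first element
    exposes `.shape`, with the batch dimension first.
    """
    batch = next(stream)
    actual = batch[0].shape[0]
    if actual != expected_batch_size:
        raise ValueError(
            f"expected batch size {expected_batch_size}, got {actual} "
            f"(full shape {batch[0].shape})"
        )
    return actual


# Stand-in for an eval stream that yields batches of size 1, mimicking the
# behaviour reported in this issue.
fake_stream = iter([
    (SimpleNamespace(shape=(1, 3072)), SimpleNamespace(shape=(1, 3072))),
])

try:
    check_leading_batch_dim(fake_stream, expected_batch_size=8)
    message = None
except ValueError as err:
    message = str(err)
```

In a real run the stand-in stream would be replaced by something like `b.eval_stream(1)` from the repro, with `expected_batch_size` set to the value configured in config.gin.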

syzymon, Feb 05 '21 11:02