Effective train/eval batch_size is always 1 due to batcher default arg "variable_shapes=True"
Description
When providing inputs with a constant shape (for instance imagenet32, where examples are always of length 3072, but it also applies e.g. to this config: https://github.com/google/trax/blob/master/trax/supervised/configs/reformer_imagenet64.gin) and not specifying variable_shapes=False explicitly, as it isn't in the config above, the effective training and evaluation batch_size is always equal to 1. The reason is the default argument variable_shapes=True in this function, which makes the bucketer group examples by length in a way that drives the effective train/eval batch size down to 1, regardless of what the user specified:
https://github.com/google/trax/blob/master/trax/data/inputs.py#L791
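To illustrate the mechanism, here is a minimal sketch (invented names, not trax's actual code), assuming a bucket-by-length plan in which per-bucket batch sizes halve as length boundaries double; for long constant-shape examples the matching bucket's batch size has already decayed to 1:

# Illustrative sketch only, not trax's actual code: shows how a
# bucket-by-length plan with halving batch sizes ends at batch size 1.

def bucket_plan(batch_size=32, bucket_length=32, max_length=3072):
    """Return (boundaries, batch_sizes): as length boundaries double,
    per-bucket batch sizes halve, bottoming out at 1."""
    boundaries, batch_sizes = [], []
    length, size = bucket_length, batch_size
    while length < max_length:
        boundaries.append(length)
        batch_sizes.append(max(1, size))
        length *= 2
        size //= 2
    # Final bucket catches the longest examples.
    boundaries.append(max_length)
    batch_sizes.append(max(1, size))
    return boundaries, batch_sizes

boundaries, batch_sizes = bucket_plan()
# Constant-shape length-3072 examples (imagenet32) all fall into the
# last bucket, whose batch size has already decayed to 1.
print(list(zip(boundaries, batch_sizes)))
print(batch_sizes[-1])  # -> 1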
That seems like an annoying bug: it causes huge unexplained variance among eval batches, and it makes training with a batch size bigger than 1 per device effectively impossible for constant-shape data unless the user happens to know about the variable_shapes arg.
My repro confirming this was done with the latest trax dev version (>=1.3.7), but the problem probably also exists in 1.3.6 and earlier.
Setting variable_shapes=False explicitly in the gin config solves the problem; however, needing to specify it there doesn't seem like good default behaviour and is likely to lead more people to this bug.
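For reference, the workaround amounts to a single binding; the sketch below applies it from Python (the same line can instead go directly into the gin config file):

import gin
from trax.data import inputs as trax_inputs  # registers the batcher configurable with gin

# Workaround: make the batcher respect the user-specified batch size
# for constant-shape data by disabling variable-shape bucketing.
gin.parse_config('batcher.variable_shapes = False')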
...
Environment information
Environment-independent problem (the issue is in the logic).
For bugs: reproduction and error logs
Steps to reproduce:
To make this repro work, variable_shapes should not be specified in config.gin (its default value is True, which causes the issue), and the input should be of constant shape.
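For completeness, the batch-size part of such a config.gin could look like the following sketch, shown here bound from Python only for illustration (the concrete values are assumptions, and the dataset's required data_streams binding is omitted):

import gin
from trax.data import inputs as trax_inputs  # registers the batcher configurable with gin

# Sketch of the relevant config.gin bindings (values are assumptions).
# batcher.variable_shapes is deliberately left unset, so the default
# True is in effect and triggers the bug.
gin.parse_config("""
batcher.batch_size_per_device = 8
batcher.eval_batch_size = 8
""")

With a complete config in place, the repro itself is: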
import gin
from trax.data.inputs import batcher

# config.gin must NOT set batcher.variable_shapes, so the buggy
# default True is in effect.
gin.parse_config_file('config.gin')
b = batcher()
ev = b.eval_stream(1)  # eval stream for 1 device
# Prints (1, ...) regardless of the train/eval batch size specified
# in config.gin.
print(next(ev)[0].shape)
Error logs:
None - I noticed this by printing eval batch shapes in the debugger.