fsdp=16 model=16 global_batch_size=16 should work on 256 chips
The use case is being able to use a global batch size smaller than the total number of JAX processes.
This is supported in maxtext by using this trick: https://github.com/AI-Hypercomputer/maxtext/blob/4cf51b7f204e109df502cf2d54b4d5005f597b09/MaxText/train.py#L289-L291
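For context, here is a quick sketch of the constraint at play with a plain per-host batcher (a hand-written illustration, not axlearn or MaxText code; numbers assume v6e-256 = 64 hosts with 4 chips each):

# Each JAX process (host) is a data feed and must get an integral, non-zero share
# of the global batch when every host reads real data.
global_batch_size = 16
num_processes = 64  # jax.process_count() on v6e-256 (64 hosts x 4 chips)
print(global_batch_size % num_processes)  # 16, i.e. not divisible -> the ValueError below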
Trying to get the 405B model running on v6e-256 (fsdp=16, model=16), but getting hit with this error:
I1022 20:32:33.715831 139189201369088 trainer.py:323] gpt_trainer process 19 step -1] Global mesh: Mesh('pipeline': 1, 'data': 1, 'expert': 1, 'fsdp': 16, 'seq': 1, 'model': 16)
Traceback (most recent call last):
File "/usr/local/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/usr/local/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/root/axlearn/common/launch_trainer_main.py", line 21, in <module>
app.run(main)
File "/opt/venv/lib/python3.10/site-packages/absl/app.py", line 308, in run
_run_main(main, args)
File "/opt/venv/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
sys.exit(main(argv))
File "/root/axlearn/common/launch_trainer_main.py", line 16, in main
launch_trainer.run_trainer(trainer_config)
File "/root/axlearn/common/launch_trainer.py", line 129, in run_trainer
trainer: SpmdTrainer = trainer_config.instantiate(parent=None)
File "/root/axlearn/common/config.py", line 734, in instantiate
return self.klass(self, **kwargs)
File "/root/axlearn/common/module.py", line 520, in __call__
instance = super().__call__(*args, **kwds)
File "/root/axlearn/common/trainer.py", line 244, in __init__
self._add_child("input", cfg.input.set(is_training=True))
File "/root/axlearn/common/module.py", line 760, in _add_child
module = child_config.instantiate(parent=self, **kwargs)
File "/root/axlearn/common/config.py", line 734, in instantiate
return self.klass(self, **kwargs)
File "/root/axlearn/common/module.py", line 520, in __call__
instance = super().__call__(*args, **kwds)
File "/root/axlearn/common/input_tf_data.py", line 1185, in __init__
self._batcher = maybe_set_config(cfg.batcher, is_training=cfg.is_training).instantiate()
File "/root/axlearn/common/config.py", line 801, in instantiate
return self.fn(*args, **kwargs)
File "/root/axlearn/common/input_tf_data.py", line 799, in batch
raise ValueError(
ValueError: global_batch_size (16.0) must be divisible by number of JAX processes (data feeds) (64).
Hi @samos123, you can use the input dispatcher: https://github.com/apple/axlearn/blob/ac63eef8a76ee8e7fcb7e539ca1331e885ce286c/axlearn/common/input_tf_data.py#L1165-L1167 https://github.com/apple/axlearn/blob/ac63eef8a76ee8e7fcb7e539ca1331e885ce286c/axlearn/common/input_dispatch.py#L17-L33
Some hosts will produce padding feeds which will be dropped during input dispatch. I have some ideas to make this a bit simpler soon, but this should unblock you for now.
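To illustrate the idea (a conceptual sketch only, not axlearn's actual dispatch code; the variable names mirror the InputDispatcher fields referenced above):

# A subset of hosts act as logical feeds that read real examples; the remaining
# hosts emit padding feeds of the same shape, which are dropped during dispatch.
num_physical_feeds = 64                       # one feed per host (jax.process_count())
logical_feed_indices = list(range(0, 64, 4))  # e.g. every 4th host reads real data
global_logical_batch_size = 16                # real examples per step
feed_logical_batch_size = global_logical_batch_size // len(logical_feed_indices)  # 1

for feed_index in range(num_physical_feeds):
    if feed_index in logical_feed_indices:
        print(f"feed {feed_index:2d}: reads {feed_logical_batch_size} real example(s)")
    else:
        print(f"feed {feed_index:2d}: produces padding, dropped during dispatch")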
Maybe I'm misunderstanding the code... but input_dispatcher is None for the fuji models, so wouldn't it default to InputDispatcher already? @markblee
Are you saying I should create a custom InputDispatcher and pass that instead? That may make sense. Looking into that further.
Would these be the right settings for fsdp=16 and model=16 gbs=16 on v6e-256?
# Usually left unset. Defaults to
# max(feed_logical_batch_size * num_physical_feeds, jax.device_count()).
global_physical_batch_size = 16
# The total number of physical feeds across all hosts. Defaults to jax.process_count().
num_physical_feeds = 64
# The local physical feed index. Must be in [0, num_physical_feeds).
# Defaults to jax.process_index().
physical_feed_index: Optional[int] = None
Currently in Fuji models this is set:
cfg.input = input_tf_data.Input.default_config().set(
is_training=True,
source=train_input_source,
processor=config_for_function(input_tf_data.identity),
batcher=config_for_function(input_tf_data.batch).set(
global_batch_size=train_batch_size,
prefetch_buffer_size=tf.data.AUTOTUNE,
pad_example_fn=input_tf_data.default_pad_example_fn,
),
)
This is what's inside input_tf_data.batch function:
num_data_feeds = jax.process_count()
if global_batch_size % num_data_feeds != 0:
raise ValueError(
f"global_batch_size ({global_batch_size}) must be divisible by "
f"number of JAX processes (data feeds) ({num_data_feeds})."
)
So I suspect I need to modify the batch function directly.
@samos123 please read through the input logic: there is a logical batch and a physical batch. Once you understand those two key pieces you should be good to go.
@samos123
global_logical_batch_size = 16
global_physical_batch_size = 256 # You can also try 64 here to see if it works.
logical_feed_indices = list(range(0, 64, 4))  # 1 in 4 hosts reads a single "real" batch.
You don't need to specify other fields because the batcher can infer them. Also like Kelvin said, be sure to understand the meaning of these fields.
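For concreteness, a minimal sketch of wiring those values into the trainer's input config (this assumes the fuji input config exposes the input_dispatcher field discussed above; where exactly to set it for the 405B config is an assumption):

from axlearn.common import input_dispatch

# `cfg` is the trainer config from the fuji snippet above (an assumption about placement).
# Field names follow the InputDispatcher config quoted earlier in the thread.
cfg.input.input_dispatcher = input_dispatch.InputDispatcher.default_config().set(
    global_logical_batch_size=16,
    global_physical_batch_size=256,  # or try 64, per the suggestion above
    logical_feed_indices=tuple(range(0, 64, 4)),  # 1 in 4 hosts reads real data
)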
Just sharing for now since it's related. I also hit this error when trying fsdp=16, model=16 and gbs=128 on 256 chips:
Stack Summary (most recent call last):
  File "/usr/local/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/root/axlearn/common/launch_trainer_main.py", line 21, in <module>
    app.run(main)
  File "/opt/venv/lib/python3.10/site-packages/absl/app.py", line 330, in run
    raise
  File "/opt/venv/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/root/axlearn/common/launch_trainer_main.py", line 16, in main
    launch_trainer.run_trainer(trainer_config)
  File "/root/axlearn/common/launch_trainer.py", line 131, in run_trainer
    output = trainer.run(prng_key)
Wrapped call axlearn.common.trainer.SpmdTrainer.run(jaxlib.xla_extension.ArrayImpl)
  File "/root/axlearn/common/trainer.py", line 501, in run
    utils.host_to_global_device_array(input_batch),
  File "/root/axlearn/common/utils.py", line 653, in host_to_global_device_array
    device_arrays = jax.tree.map(put_to_devices, host_arrays)
  File "/opt/venv/lib/python3.10/site-packages/jax/_src/tree.py", line 155, in map
    return tree_util.tree_map(f, tree, *rest, is_leaf=is_leaf)
  File "/root/axlearn/common/utils.py", line 637, in put_to_devices_fully_partitioned
    raise ValueError(f"({x.shape}) cannot be sharded across {len_local_devices} devices.")
ValueError: ((2, 8192)) cannot be sharded across 4 devices.
@samos123 please understand the physical batch logic: you shouldn't set a global physical batch size of 128 on 256 chips, it is always 256. The logical batch is what you care more about; the physical batch could be auto-configured from the logical one, but we don't have that in place yet.
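Working through the arithmetic behind the error above (a sketch, assuming v6e-256 = 64 hosts with 4 chips each):

# Why gbs=128 trips host_to_global_device_array without the dispatcher:
global_batch_size = 128
num_processes = 64   # jax.process_count(): one data feed per host
local_devices = 4    # jax.local_device_count() per host on v6e-256

per_host = global_batch_size // num_processes  # 2 examples per host -> arrays of shape (2, 8192)
print(per_host % local_devices)                # 2, not 0: (2, 8192) cannot split over 4 chips

# With a physical batch of 256 (one example per chip), each host gets (4, 8192),
# which shards cleanly; 128 (or 16) then only needs to be the logical batch size.
print((256 // num_processes) % local_devices)  # 0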
fwiw, you also cannot do model=16: the model axis is limited by the number of heads, which is 8 for KV in the GQA model. But that's a separate issue and you haven't hit it yet.