
fsdp=16 model=16 gbs=16 should work on 256 chips

Open samos123 opened this issue 1 year ago • 9 comments

fsdp=16 model=16 global_batch_size=16 should work on 256 chips

The use case is being able to use a global batch size smaller than the total number of JAX processes.

This is supported in maxtext by using this trick: https://github.com/AI-Hypercomputer/maxtext/blob/4cf51b7f204e109df502cf2d54b4d5005f597b09/MaxText/train.py#L289-L291

I'm trying to get the 405B model running on v6e-256 (fsdp=16 model=16) but am hitting this error:

I1022 20:32:33.715831 139189201369088 trainer.py:323] gpt_trainer process  19 step       -1] Global mesh: Mesh('pipeline': 1, 'data': 1, 'expert': 1, 'fsdp': 16, 'seq': 1, 'model': 16)
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/local/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/root/axlearn/common/launch_trainer_main.py", line 21, in <module>
    app.run(main)
  File "/opt/venv/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/opt/venv/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/root/axlearn/common/launch_trainer_main.py", line 16, in main
    launch_trainer.run_trainer(trainer_config)
  File "/root/axlearn/common/launch_trainer.py", line 129, in run_trainer
    trainer: SpmdTrainer = trainer_config.instantiate(parent=None)
  File "/root/axlearn/common/config.py", line 734, in instantiate
    return self.klass(self, **kwargs)
  File "/root/axlearn/common/module.py", line 520, in __call__
    instance = super().__call__(*args, **kwds)
  File "/root/axlearn/common/trainer.py", line 244, in __init__
    self._add_child("input", cfg.input.set(is_training=True))
  File "/root/axlearn/common/module.py", line 760, in _add_child
    module = child_config.instantiate(parent=self, **kwargs)
  File "/root/axlearn/common/config.py", line 734, in instantiate
    return self.klass(self, **kwargs)
  File "/root/axlearn/common/module.py", line 520, in __call__
    instance = super().__call__(*args, **kwds)
  File "/root/axlearn/common/input_tf_data.py", line 1185, in __init__
    self._batcher = maybe_set_config(cfg.batcher, is_training=cfg.is_training).instantiate()
  File "/root/axlearn/common/config.py", line 801, in instantiate
    return self.fn(*args, **kwargs)
  File "/root/axlearn/common/input_tf_data.py", line 799, in batch
    raise ValueError(
ValueError: global_batch_size (16.0) must be divisible by number of JAX processes (data feeds) (64).
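
For context, the failing check comes down to simple integer arithmetic; here is a minimal sketch with the numbers from the log above, assuming v6e-256 runs 64 JAX processes (64 hosts with 4 chips each):

    # Minimal sketch of the check that fails above; numbers taken from the log.
    global_batch_size = 16
    num_data_feeds = 64  # jax.process_count() on v6e-256 (assumed: 64 hosts x 4 chips)

    per_feed, remainder = divmod(global_batch_size, num_data_feeds)
    print(per_feed, remainder)  # 0 16 -> each feed would need a fractional (0.25) batch
    if remainder != 0:
        raise ValueError(
            f"global_batch_size ({global_batch_size}) must be divisible by "
            f"number of JAX processes (data feeds) ({num_data_feeds})."
        )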

samos123 avatar Oct 22 '24 20:10 samos123

Hi @samos123 , you can use the input dispatcher: https://github.com/apple/axlearn/blob/ac63eef8a76ee8e7fcb7e539ca1331e885ce286c/axlearn/common/input_tf_data.py#L1165-L1167 https://github.com/apple/axlearn/blob/ac63eef8a76ee8e7fcb7e539ca1331e885ce286c/axlearn/common/input_dispatch.py#L17-L33

Some hosts will produce padding feeds which will be dropped during input dispatch. I have some ideas to make this a bit simpler soon, but this should unblock you for now.

markblee avatar Oct 22 '24 21:10 markblee

Maybe I'm misunderstanding the code... but input_dispatcher is None for the fuji models, so wouldn't it default to InputDispatcher already? @markblee

samos123 avatar Oct 23 '24 17:10 samos123

Are you saying I should create a custom InputDispatcher and pass that instead? That may make sense. Looking into that further.

samos123 avatar Oct 23 '24 17:10 samos123

Would these be the right settings for fsdp=16, model=16, gbs=16 on v6e-256?

        # Usually left unset. Defaults to
        # max(feed_logical_batch_size * num_physical_feeds, jax.device_count()).
        global_physical_batch_size = 16
  
        # The total number of physical feeds across all hosts. Defaults to jax.process_count().
        num_physical_feeds = 64

        # The local physical feed index. Must be in [0, num_physical_feeds).
        # Defaults to jax.process_index().
        physical_feed_index: Optional[int] = None

samos123 avatar Oct 23 '24 18:10 samos123

Currently in Fuji models this is set:

        cfg.input = input_tf_data.Input.default_config().set(
            is_training=True,
            source=train_input_source,
            processor=config_for_function(input_tf_data.identity),
            batcher=config_for_function(input_tf_data.batch).set(
                global_batch_size=train_batch_size,
                prefetch_buffer_size=tf.data.AUTOTUNE,
                pad_example_fn=input_tf_data.default_pad_example_fn,
            ),
        )

This is what's inside the input_tf_data.batch function:

    num_data_feeds = jax.process_count()
    if global_batch_size % num_data_feeds != 0:
        raise ValueError(
            f"global_batch_size ({global_batch_size}) must be divisible by "
            f"number of JAX processes (data feeds) ({num_data_feeds})."
        )

So I suspect I need to modify the batch function directly.

samos123 avatar Oct 23 '24 18:10 samos123

@samos123 please read through the input logic: there is a logical batch and a physical batch. Once you understand those two key pieces of logic, you should be good to go.

kelvin-zou avatar Oct 23 '24 19:10 kelvin-zou

@samos123

global_logical_batch_size = 16
global_physical_batch_size = 256 # You can also try 64 here to see if it works.
logical_feed_indices = list(range(0, 64, 4)) # 1 in 4 hosts reads a "real" batch.

You don't need to specify the other fields because the batcher can infer them. Also, as Kelvin said, be sure to understand the meaning of these fields.
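
A minimal sketch of how these values might be wired up, continuing from the Fuji cfg.input block quoted above and assuming the input_dispatcher field referenced earlier in the thread (the field names come from the snippets in this thread, so treat this as illustrative rather than a verified config):

    # Illustrative only: wiring the suggested values into the input config.
    # Assumes `cfg.input` exposes an `input_dispatcher` field, as referenced earlier.
    from axlearn.common.input_dispatch import InputDispatcher

    cfg.input.input_dispatcher = InputDispatcher.default_config().set(
        global_logical_batch_size=16,    # the batch the model actually trains on
        global_physical_batch_size=256,  # one example per chip on v6e-256 (or try 64)
        # Every 4th host is a logical feed; the other hosts produce padding feeds
        # that are dropped during input dispatch.
        logical_feed_indices=list(range(0, 64, 4)),
    )

num_physical_feeds and physical_feed_index would be left at their defaults (jax.process_count() and jax.process_index(), per the config comments quoted above).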

hanzhi713 avatar Oct 23 '24 20:10 hanzhi713

Just sharing for now since it's related. I also hit this error when trying fsdp=16, model=16 and gbs=128 on 256 chips:

Stack Summary (most recent call last):
  File "/usr/local/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/root/axlearn/common/launch_trainer_main.py", line 21, in <module>
    app.run(main)
  File "/opt/venv/lib/python3.10/site-packages/absl/app.py", line 330, in run
    raise
  File "/opt/venv/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/root/axlearn/common/launch_trainer_main.py", line 16, in main
    launch_trainer.run_trainer(trainer_config)
  File "/root/axlearn/common/launch_trainer.py", line 131, in run_trainer
    output = trainer.run(prng_key)
  Wrapped call axlearn.common.trainer.SpmdTrainer.run(jaxlib.xla_extension.ArrayImpl)
  File "/root/axlearn/common/trainer.py", line 501, in run
    utils.host_to_global_device_array(input_batch),
  File "/root/axlearn/common/utils.py", line 653, in host_to_global_device_array
    device_arrays = jax.tree.map(put_to_devices, host_arrays)
  File "/opt/venv/lib/python3.10/site-packages/jax/_src/tree.py", line 155, in map
    return tree_util.tree_map(f, tree, *rest, is_leaf=is_leaf)
  File "/root/axlearn/common/utils.py", line 637, in put_to_devices_fully_partitioned
    raise ValueError(f"({x.shape}) cannot be sharded across {len_local_devices} devices.")
ValueError: ((2, 8192)) cannot be sharded across 4 devices.
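
For what it's worth, the shape in this error appears to follow from the arithmetic below (a rough sketch, assuming v6e-256 runs 64 JAX processes with 4 local devices each):

    # Rough arithmetic behind the sharding error above (assumed v6e-256 topology).
    global_batch_size = 128
    num_processes = 64                                    # data feeds / hosts
    local_devices = 256 // num_processes                  # 4 chips per host

    per_host_batch = global_batch_size // num_processes   # 2 -> the (2, 8192) array above
    # host_to_global_device_array splits the per-host batch across local devices:
    print(per_host_batch % local_devices)  # 2 % 4 != 0 -> "cannot be sharded across 4 devices"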

samos123 avatar Oct 31 '24 06:10 samos123

@samos123 please understand the physical batch logic: you shouldn't use a global batch size of 128 over 256 chips; the physical batch size is always 256 there. The logical batch is what you care more about, and the physical batch could be auto-configured from the logical one, but we don't have that in place yet.

FWIW, you also cannot use model=16, since it exceeds the number of KV heads, which is 8 in this GQA model. But that's a separate issue and you haven't hit it yet.

kelvin-zou avatar Oct 31 '24 16:10 kelvin-zou