Sharded loading performance question

Status: Open. hr0nix opened this issue 2 years ago • 10 comments

I'm benchmarking the loading of a 65B sharded transformer model on multiple GPUs on the same host. The checkpoint itself was saved unsharded, but the correct sharding is supplied when the model is loaded.

I've noticed several performance-related peculiarities that I would like to understand better; hopefully you can comment on them.

  • For some reason, when the model weights are in scan-friendly format (that is, transformer block weights are stacked along the layer dimension into a single tensor), loading on 8 GPUs takes ~4x longer than when the weights for different layers are stored separately. This is surprising, as I would expect loading time to be either unaffected or smaller for the scan-friendly format.
  • Loading the scan-friendly model on 2 GPUs is 2.5x faster than on 8 GPUs. This is also weird, as I would expect the total number of disk reads to be approximately the same in both cases.
  • When OCDBT is enabled, checkpoint loading seems to take ~5-10% longer. Is this to be expected?
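To make the two layouts compared above concrete, here is a minimal NumPy sketch under assumed toy sizes. The names `layer_{i}`, `layers` and `kernel` are hypothetical and not taken from the actual model; each array leaf stands for one saved checkpoint parameter.

```python
import numpy as np

NUM_LAYERS = 4  # toy value; the real model has far more layers
D_MODEL = 8     # toy hidden size

rng = np.random.default_rng(0)

# "Normal" layout: one array per transformer block, i.e. one
# checkpoint parameter per layer.
unstacked = {
    f"layer_{i}": {
        "kernel": rng.standard_normal((D_MODEL, D_MODEL)).astype(np.float32)
    }
    for i in range(NUM_LAYERS)
}

# Scan-friendly layout: the same weights stacked along a new leading
# layer axis, so all blocks share a single checkpoint parameter.
stacked = {
    "layers": {
        "kernel": np.stack(
            [unstacked[f"layer_{i}"]["kernel"] for i in range(NUM_LAYERS)], axis=0
        )
    }
}


def count_leaves(tree) -> int:
    """Number of array leaves in a nested dict (one leaf = one saved parameter)."""
    if isinstance(tree, dict):
        return sum(count_leaves(v) for v in tree.values())
    return 1


def total_bytes(tree) -> int:
    """Total size of all array leaves in bytes."""
    if isinstance(tree, dict):
        return sum(total_bytes(v) for v in tree.values())
    return tree.nbytes


print(count_leaves(unstacked), count_leaves(stacked))
print(total_bytes(unstacked) == total_bytes(stacked))
```

The stacked layout holds exactly the same bytes in fewer, larger arrays, which is why one would naively expect it to load at least as fast.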

Any clarifications or suggestions on how to speed things up would be much appreciated.

hr0nix, Jul 27 '23 10:07

  • Without OCDBT, checkpoints with more parameters are expected to take longer to load, since each parameter is stored in its own directory. So a model with stacked weights should typically load much faster. Approximately how many parameter directories are present?
  • The number of reads would be the same regardless of the number of devices, but moving arrays from host to device would take longer with more devices. Reads should still dominate the restore time, though, so it's strange that this causes such a slowdown. Actually, what is the absolute time required to load?
  • Depending on the scale of the model, there could still be a few issues with OCDBT, though in general it is expected to perform as well as or better than the non-OCDBT format. You could try removing the experimental_read_coalescing_threshold_bytes option that Orbax sets; that has helped at some scales, so it may or may not help you.
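To answer the question about parameter directories, a quick count can be sketched with the standard library. This assumes the non-OCDBT layout described later in the thread, where every parameter gets its own subdirectory inside the checkpoint item directory; the helper name `count_param_dirs` and the example path are hypothetical.

```python
import os


def count_param_dirs(item_dir) -> int:
    """Count immediate subdirectories of a checkpoint item directory.

    Assumes a non-OCDBT layout in which each parameter is stored in its own
    subdirectory; plain files (e.g. metadata) are ignored.
    """
    with os.scandir(item_dir) as entries:
        return sum(1 for entry in entries if entry.is_dir())


# Usage (placeholder path, adjust to your checkpoint/step/item layout):
# print(count_param_dirs("/path/to/checkpoints/<step>/state"))
```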

cpgaffney1, Jul 27 '23 20:07

> Actually, what is the absolute time required to load?

Loading the 65B model with bfloat16 weights on 8 GPUs takes ~10 min in scan-friendly mode and 2.5 min in normal mode. Reading this amount of data from disk is much faster, so the time is definitely not dominated by reads.
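A back-of-the-envelope check of that claim: 65B bfloat16 parameters are about 130 GB, so the reported times correspond to an effective throughput of roughly 0.22 GB/s (scan-friendly) versus 0.87 GB/s (normal). This sketch counts parameter bytes only and ignores metadata and the lazily loaded optimizer state.

```python
# Figures from the comment above: 65B parameters stored as bfloat16.
NUM_PARAMS = 65e9
BYTES_PER_BF16 = 2


def effective_read_gb_per_s(num_params: float, bytes_per_param: int, minutes: float) -> float:
    """Parameter bytes divided by wall-clock load time, in GB/s (1 GB = 1e9 bytes)."""
    total_bytes = num_params * bytes_per_param
    return total_bytes / 1e9 / (minutes * 60.0)


scan_rate = effective_read_gb_per_s(NUM_PARAMS, BYTES_PER_BF16, minutes=10.0)
normal_rate = effective_read_gb_per_s(NUM_PARAMS, BYTES_PER_BF16, minutes=2.5)
print(f"scan-friendly: {scan_rate:.2f} GB/s, normal: {normal_rate:.2f} GB/s")
```

Both rates are well below what a local disk typically sustains, which supports the point that the load time is not bound by raw reads.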

Maybe I'm doing the loading wrong. Here's how I do it; perhaps you can spot a problem?

```python
import functools
from pathlib import Path

import flax.linen as nn
import jax
import orbax.checkpoint as ocp

# Project-specific names used below are defined elsewhere in the codebase:
# PyTree, LogicalRules, LocalDirWeightSourceConfig, ShardingConfig, ModelBase,
# TokenizerBase, ModelMetadata, create_tokenizer, create_model,
# init_model_params, create_device_mesh.


def logically_partitioned_tree_to_mesh_sharding(
    logically_partitioned_tree: PyTree,
    mesh: jax.sharding.Mesh,
    logical_rules: LogicalRules,
) -> PyTree:
    # Turn Flax logical axis annotations into concrete shardings on `mesh`
    logical_partition_spec = nn.get_partition_spec(logically_partitioned_tree)
    return nn.logical_to_mesh_sharding(
        logical_partition_spec, mesh, rules=list(logical_rules.items())
    )


def create_checkpoint_manager(path: Path, create_path: bool = True):
    return ocp.CheckpointManager(
        directory=path,
        checkpointers={
            "state": ocp.Checkpointer(ocp.PyTreeCheckpointHandler()),
            "metadata": ocp.Checkpointer(ocp.JsonCheckpointHandler()),
        },
        options=ocp.CheckpointManagerOptions(create=create_path),
    )


def load_model_from_local_dir(
    weights: LocalDirWeightSourceConfig,
    sharding_config: ShardingConfig,
) -> tuple[ModelBase, PyTree, TokenizerBase]:
    checkpoint_manager = create_checkpoint_manager(path=weights.path, create_path=False)
    # Compare against None so that an explicit step 0 is respected
    step = weights.step if weights.step is not None else checkpoint_manager.latest_step()
    if step is None:
        raise ValueError(f"No checkpoints found in {weights.path}")

    # First restore model metadata
    metadata_dict = checkpoint_manager.restore(step, items={"metadata": None})[
        "metadata"
    ]
    metadata = ModelMetadata(**metadata_dict)

    # Use metadata to create tokenizer and model
    tokenizer = create_tokenizer(metadata.tokenizer)
    model = create_model(
        metadata.model,
        vocab_size=tokenizer.vocab_size,
    )

    # Figure out parameter sharding induced by the model.
    # The RNG key doesn't matter here, as the params will be overridden anyway.
    init_model_params_partial = functools.partial(
        init_model_params, model, jax.random.PRNGKey(0)
    )
    abstract_model_params = jax.eval_shape(init_model_params_partial)
    device_mesh = create_device_mesh(sharding_config.mesh)
    param_sharding = logically_partitioned_tree_to_mesh_sharding(
        abstract_model_params, device_mesh, sharding_config.logical_rules
    )

    # Figure out the structure of the saved checkpoint
    saved_state_structure = checkpoint_manager.structure()["state"]

    # Lazy-load everything by default to avoid loading the whole optimizer state
    state_restore_args = jax.tree_util.tree_map(
        lambda _: ocp.type_handlers.RestoreArgs(lazy=True),
        saved_state_structure,
    )

    # Override restore args for params to load them eagerly
    # and with the correct partitioning
    state_restore_args["params"] = jax.tree_util.tree_map(
        lambda s: ocp.type_handlers.ArrayRestoreArgs(sharding=s),
        param_sharding,
    )

    # Actually load the params from the same step as the metadata;
    # calling latest_step() here would ignore an explicitly requested step.
    restored = checkpoint_manager.restore(
        step,
        items={"state": state_restore_args},
        restore_kwargs={"state": {"restore_args": state_restore_args}},
    )

    return model, restored["state"]["params"], tokenizer
```

hr0nix, Jul 27 '23 21:07

It may have to do with loading to GPU; this is not something that is well tested. Could you provide a minimal repro, along with details about the environment you're using? It might make sense to have the JAX team take a look at it.

cpgaffney1, Jul 27 '23 23:07

Here's a repro that illustrates (on a smaller scale) that save/load in scan-friendly format is significantly slower: https://gist.github.com/hr0nix/ab08f3fc31d57d02472f190e30d7fe39

On my machine this script prints:

```
$ python3.10 ./repro.py
Scan: True
Saving checkpoint took 173.67339634895325 seconds
Loading checkpoint took 102.56669664382935 seconds
Scan: False
Saving checkpoint took 68.30835676193237 seconds
Loading checkpoint took 52.649762868881226 seconds
```
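Putting those timings side by side: in this repro the stacked (scan-friendly) layout is roughly 2.5x slower to save and roughly 2x slower to load. The short script below just recomputes the ratios from the printed numbers; it does not rerun the repro.

```python
# Timings in seconds, copied from the repro output above.
timings = {
    True: {"save": 173.67339634895325, "load": 102.56669664382935},   # Scan: True
    False: {"save": 68.30835676193237, "load": 52.649762868881226},   # Scan: False
}


def slowdown(op: str) -> float:
    """How many times slower the stacked layout is for `op` ("save" or "load")."""
    return timings[True][op] / timings[False][op]


for op in ("save", "load"):
    print(f"{op}: {slowdown(op):.2f}x slower with stacked weights")
```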

Environment info:

```
>>> jax.__version__
'0.4.14'
>>> orbax.checkpoint.__version__
'0.3.0'
```

hr0nix, Jul 29 '23 01:07

Hey @cpgaffney1, will someone be able to look at the repro in the foreseeable future? We'd really like to switch to scan-based models as they are jit-compiled 60x faster :-)

hr0nix, Jul 31 '23 11:07

We're looking into it as of now, I'll let you know when I have new info!

cpgaffney1, Jul 31 '23 14:07

Hi @hr0nix, sorry for the delay on this, we're just a little stumped by the stacked vs. unstacked issue. We also don't have any tooling set up yet for running on GPUs, so we haven't been able to get a fully accurate reproduction. That said, we are still intending to take a look when we have a chance.

For your 5-10% regression when OCDBT is active, I've learned that removing the experimental_read_coalescing_threshold_bytes option should actually solve the issue. Have you had a chance to try that?

cpgaffney1, Aug 10 '23 21:08

Is there a way to remove this option without modifying the orbax code?

hr0nix, Aug 11 '23 16:08

No, there isn't an option that can be set. Ideally, Tensorstore would optimize this setting away entirely once we know in which contexts removing it gives better performance.

cpgaffney1, Aug 14 '23 13:08

So we think the issue you're observing, where checkpoints with stacked layers are much slower than an equivalent unstacked checkpoint, may be attributed to how the JAX serialization library and Tensorstore interact. If n parameters are stacked into a single parameter, that parameter is now saved as a single chunk. The unstacked version pays the additional cost of creating a directory for each of those n parameters, but in certain configurations, the cost of losing the parallelism of those n separate chunk writes can outweigh the savings on directory creation.

Note that if the larger parameter is sharded into, say, 8 distinct chunks, JAX serialization will write n / 8 chunks in parallel. Ideally, then, we would like a solution that subdivides large chunks, so that the chunk size is not tied exactly to the shard size.
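One way to picture "subdividing large chunks" is a helper that splits a shard's leading axis into ranges no larger than a byte budget, so each range could be written independently. This is a hypothetical sketch, not part of Orbax, JAX or Tensorstore; `split_shard` and `max_chunk_bytes` are illustrative names, and the shape in the usage example is an assumption. Note that the maintainers explicitly want to avoid exposing such a budget as a user-tuned hyperparameter.

```python
def split_shard(shard_shape, itemsize, max_chunk_bytes):
    """Split a shard (at least 1-D) into sub-chunks along axis 0.

    Each sub-chunk is at most `max_chunk_bytes`, unless a single slice along
    axis 0 is already larger, in which case every sub-chunk holds one slice.
    Returns a list of (start, stop) index ranges along axis 0.
    """
    rows = shard_shape[0]
    row_bytes = itemsize
    for dim in shard_shape[1:]:
        row_bytes *= dim
    rows_per_chunk = max(1, max_chunk_bytes // row_bytes)
    return [
        (start, min(start + rows_per_chunk, rows))
        for start in range(0, rows, rows_per_chunk)
    ]


# Illustrative: a stacked bf16 weight of shape (80, 8192, 8192) with a 512 MiB
# budget becomes 20 independent ranges of 4 layers each.
ranges = split_shard((80, 8192, 8192), itemsize=2, max_chunk_bytes=512 * 2**20)
```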

We'll be looking into how to do this optimization, since we don't want it to be a hyperparameter subject to tuning by the user. In the meantime, if you really need to address this yourself, you could override TypeHandler and replace the calls to jax.serialization with your own implementation that splits the chunks currently written in _write_array into smaller chunks.
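The idea behind such a custom write path can be illustrated with a toy that splits an array into pieces and writes them concurrently. This is only a stand-in for the concept: it uses `np.save` and a thread pool instead of Tensorstore, and it does not use Orbax's `TypeHandler` API; `write_in_subchunks` and `read_subchunks` are hypothetical helpers.

```python
import os
from concurrent.futures import ThreadPoolExecutor

import numpy as np


def write_in_subchunks(array: np.ndarray, out_dir: str, num_subchunks: int, max_workers: int = 8):
    """Toy stand-in for a custom array writer: split `array` along axis 0 and
    write each piece to its own .npy file from a thread pool."""
    os.makedirs(out_dir, exist_ok=True)
    pieces = np.array_split(array, num_subchunks, axis=0)

    def write_one(index_and_piece):
        index, piece = index_and_piece
        path = os.path.join(out_dir, f"chunk_{index:05d}.npy")
        np.save(path, piece)
        return path

    # pool.map preserves input order, so the returned paths are in axis-0 order.
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(write_one, enumerate(pieces)))


def read_subchunks(paths):
    """Reassemble the pieces written by write_in_subchunks."""
    return np.concatenate([np.load(p) for p in paths], axis=0)
```

In a real override the pieces would be written as separate Tensorstore chunks of one array rather than as separate files; the toy only shows how one large write becomes several independent ones.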

cpgaffney1, Aug 15 '23 19:08