
Latency Hiding Scheduler leads to 5x memory usage when used without jax.lax.scan

qGentry opened this issue 1 year ago

Description

Hi, we're training a large (300B parameters, 60 layers) mixture-of-experts transformer on 1000+ GPUs. The layers are not all identical, so we can't use jax.lax.scan directly to stack them - instead, we call each layer independently. The structure isn't completely irregular, though: it is (3 layers of one structure, 1 of another) repeated 15 times, for 60 layers in total. We would benefit a lot from overlapping computation and communication, but when we enable the latency hiding scheduler with --xla_gpu_enable_latency_hiding_scheduler, memory usage increases by a factor of 4-5 (from 50 GB per GPU to 200-250 GB per GPU), which is completely unusable. My guess is that the compiler doesn't reuse buffers for async comms across the different layers in this case.

We've also tested a variant with jax.lax.scan and uniform layers; memory usage looked fine there - only a 20-25% overhead from the latency hiding scheduler.
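For illustration, here is a rough sketch of the direction we'd like to take (module names, shapes, and the stand-in layer bodies below are made up): group the repeated (3 uniform + 1 different) pattern into one super-block and scan over it 15 times, so only one block's worth of buffers is visible to the scheduler at a time.

import flax.linen as nn
import jax
import jax.numpy as jnp


class UniformLayer(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x):
        # Stand-in for the 3 identically structured layers.
        return x + nn.Dense(self.features, use_bias=False)(nn.relu(x))


class SpecialLayer(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x):
        # Stand-in for the layer with a different structure.
        return x + nn.Dense(self.features, use_bias=False)(x)


class SuperBlock(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x, _):
        # One repeat unit: 3 uniform layers followed by 1 special layer.
        for _ in range(3):
            x = UniformLayer(self.features)(x)
        x = SpecialLayer(self.features)(x)
        return x, None


class Stack(nn.Module):
    features: int
    n_repeats: int = 15  # 15 * (3 + 1) = 60 layers

    @nn.compact
    def __call__(self, x):
        x, _ = nn.scan(
            SuperBlock,
            variable_axes={"params": 0},
            split_rngs={"params": True},
            length=self.n_repeats,
        )(self.features)(x, None)
        return x


x = jnp.ones((2, 16, 64))
params = Stack(features=64).init(jax.random.PRNGKey(0), x)
y = Stack(features=64).apply(params, x)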

Is this a known problem? Is there any workaround?

System info (python version, jaxlib version, accelerator, etc.)

Tested on JAX 0.4.25/0.4.26, 1000+ H100 GPUs.

qGentry avatar Apr 15 '24 13:04 qGentry

Here is a toy repro, tested on JAX 0.4.34:

import flax.linen as nn
import jax
import jax.ad_checkpoint
import jax.numpy as jnp
import numpy as np
from flax.linen.linear import default_kernel_init

EMB_DIM = 8192
HID_DIM = 8192

BS = 32
SEQ_LEN = 4096
N_LAYERS = 32

# Toggle between nn.scan over layers (True) and an unrolled Python loop (False).
SCAN = False

# Full rematerialization: nothing is saved on device or offloaded to host.
CHECKPOINT_POLICY = jax.checkpoint_policies.save_and_offload_only_these_names(
    names_which_can_be_saved=[],
    names_which_can_be_offloaded=[],
    offload_src="device",
    offload_dst="pinned_host",
)


# 8 local devices: 8-way "data" axis, "model" axis of size 1.
mesh = jax.sharding.Mesh(np.array(jax.devices()).reshape(8, 1), ("data", "model"))
input_sharding = jax.sharding.NamedSharding(
    mesh, jax.sharding.PartitionSpec("data", None)
)
target_sharding = jax.sharding.NamedSharding(
    mesh,
    jax.sharding.PartitionSpec(
        "data",
    ),
)
rules = (
    ("batch", "data"),
    ("embedding", None),
    ("hidden", "model"),
    ("q_sequence", "model"),
)


class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x_residual = x
        h = nn.Dense(
            HID_DIM,
            kernel_init=nn.with_logical_partitioning(
                default_kernel_init,
                ("embedding", "hidden"),
            ),
            use_bias=False,
        )(x)
        h = nn.relu(h)
        
        x = nn.Dense(
            EMB_DIM,
            kernel_init=nn.with_logical_partitioning(
                default_kernel_init,
                ("hidden", "embedding"),
            ),
            use_bias=False,
        )(h)
        x = x_residual + x
        # Sequence parallelism
        x = nn.with_logical_constraint(x, ("batch", "q_sequence", None))
        return x


class Output(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(
            features=1,
            kernel_init=nn.with_logical_partitioning(
                default_kernel_init,
                ("hidden", None),
            ),
            use_bias=False,
        )(x)[..., 0]
        x = jnp.mean(x, axis=1)
        return x


class Model(nn.Module):
    @nn.compact
    def __call__(self, x):
        def apply_module(block, block_input, _):
            block_output = block(block_input)
            return block_output, None

        apply_module = nn.remat(
            apply_module,
            policy=CHECKPOINT_POLICY,
            prevent_cse=False,
        )

        if SCAN:
            x, _ = nn.scan(
                apply_module,
                variable_axes={"params": 0},
                split_rngs={"params": True},
                length=N_LAYERS,
                metadata_params={nn.PARTITION_NAME: "layers"},
            )(MLP(), x, None)
        else:
            for i in range(N_LAYERS):
                x = MLP(name=f"block_{i}")(x)

        preds = Output()(x)
        return preds


def loss_fn(preds, target):
    return jnp.mean((preds - target) ** 2)


def calc_loss(params, inputs, target):
    preds = Model().apply(params, inputs)
    loss = loss_fn(preds, target)
    return loss


def train_step(params, inputs, target):
    loss, grads = jax.value_and_grad(calc_loss)(params, inputs, target)
    params = jax.tree_util.tree_map(lambda p, g: p - 1e-8 * g, params, grads)
    return params, loss


def unbox_logically_partitioned(tree, apply_constraint: bool = True):
    return jax.tree_util.tree_map(
        lambda leaf: (
            leaf.unbox(apply_constraint=apply_constraint)
            if isinstance(leaf, nn.LogicallyPartitioned)
            else leaf
        ),
        tree,
        is_leaf=lambda node: isinstance(node, nn.LogicallyPartitioned),
    )


def get_gpu_memory_usage() -> dict[str, float]:
    if jax.default_backend() != "gpu":
        return {}
    num_devices = jax.local_device_count("gpu")
    gpu_memory_usage = []
    for i in range(num_devices):
        memory_stats = jax.local_devices()[i].memory_stats()
        gpu_memory_usage.append(
            memory_stats["peak_bytes_in_use"] / memory_stats["bytes_limit"] * 100
        )
    return {f"GPU{i}": val for i, val in enumerate(gpu_memory_usage)}


with mesh, nn.logical_axis_rules(rules):
    fake_inputs = jnp.empty((BS, SEQ_LEN, EMB_DIM))
    fake_inputs = jax.device_put(fake_inputs, input_sharding)
    fake_target = jnp.empty((BS,))
    fake_target = jax.device_put(fake_target, target_sharding)

    params = Model().init(jax.random.PRNGKey(0), fake_inputs)
    params = unbox_logically_partitioned(params)

    train_step_fn = (
        jax.jit(
            train_step,
            in_shardings=(
                jax.tree_util.tree_map(lambda x: x.sharding, params),
                input_sharding,
                target_sharding,
            ),
            out_shardings=(
                jax.tree_util.tree_map(lambda x: x.sharding, params),
                jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()),
            ),
            donate_argnums=(0,),
        )
        .lower(params, fake_inputs, fake_target)
        .compile()
    )
    with open("compiled.txt", "w") as f:
        f.write(train_step_fn.as_text())

    memory_analysis = train_step_fn.memory_analysis()
    print(
        f"Total size device = {memory_analysis.temp_size_in_bytes / 1024 / 1024 / 1024} GB, "  # noqa E501
        f"weights = {memory_analysis.argument_size_in_bytes / 1024 / 1024 / 1024} GB, "
        f"total: {(memory_analysis.argument_size_in_bytes + memory_analysis.temp_size_in_bytes) / 1024 / 1024 / 1024} GB"
    )

    for i in range(10):
        inputs = jax.random.normal(jax.random.PRNGKey(i), (BS, SEQ_LEN, EMB_DIM))
        inputs = jax.device_put(inputs, input_sharding)

        target = jax.random.normal(jax.random.PRNGKey(0), (BS,))
        target = jax.device_put(target, target_sharding)

        if i == 3:
            jax.tree_util.tree_map(lambda x: x.block_until_ready(), params)
            jax.profiler.start_trace("./profile", create_perfetto_trace=True)
        params, loss = train_step_fn(params, inputs, target)
        if i == 3:
            jax.tree_util.tree_map(lambda x: x.block_until_ready(), params)
            jax.profiler.stop_trace()
        print(loss)
        print(get_gpu_memory_usage())

Latency hiding scheduler enabled, XLA_FLAGS: --xla_gpu_graph_level=0 --xla_gpu_enable_triton_gemm=false --xla_gpu_enable_command_buffer= --xla_gpu_enable_latency_hiding_scheduler=true

SCAN=False (hits OOM):

2024-10-17 10:47:03.275923: W external/xla/xla/service/hlo_rematerialization.cc:3005] Can't reduce memory use below 53.46GiB (57397316392 bytes) by rematerialization; only reduced to 70.25GiB (75430461976 bytes), down from 75.50GiB (81067606600 bytes) originally
Total size device = 54.500054121017456 GB, weights = 16.500030532479286 GB, total: 71.00008465349674 GB

SCAN=True:

Total size device = 34.5000324845314 GB, weights = 16.500030532479286 GB, total: 51.00006301701069 GB

Latency hiding scheduler disabled, XLA_FLAGS=--xla_gpu_graph_level=0 --xla_gpu_enable_triton_gemm=false --xla_gpu_enable_command_buffer=

SCAN=False:

Total size device = 37.50002360343933 GB, weights = 16.500030532479286 GB, total: 54.00005413591862 GB

SCAN=True:

Total size device = 35.00000220537186 GB, weights = 16.500030532479286 GB, total: 51.50003273785114 GB

qGentry avatar Oct 17 '24 10:10 qGentry

@qGentry Using JAX 0.4.35 with XLA_FLAGS="--xla_gpu_graph_level=0 --xla_gpu_enable_triton_gemm=false --xla_gpu_enable_command_buffer= " and SCAN=False, I'm seeing a failure.

*** Check failure stack trace: ***
    @     0x7f26b3b96dc4  absl::lts_20230802::log_internal::LogMessage::SendToLog()
    @     0x7f26b3b96c34  absl::lts_20230802::log_internal::LogMessage::Flush()
    @     0x7f26b3b971e9  absl::lts_20230802::log_internal::LogMessageFatal::~LogMessageFatal()
    @     0x7f26ac056141  xla::PjRtStreamExecutorLoadedExecutable::Execute()
    @     0x7f26abfa5d71  pjrt::PJRT_LoadedExecutable_Execute()
    @     0x7f26bbd699fc  xla::PjRtCApiLoadedExecutable::Execute()
    @     0x7f26c1ab0a25  xla::ifrt::PjRtLoadedExecutable::Execute()
    @     0x7f26c1250c49  xla::(anonymous namespace)::ExecuteShardedOnLocalDevicesInternal<>()
    @     0x7f26c12528ee  xla::PyLoadedExecutable::ExecuteSharded()
    @     0x7f26bbc2fc55  xla::ValueOrThrowWrapper<>::operator()()
    @     0x7f26bbc2fabd  nanobind::detail::func_create<>()::{lambda()#1}::__invoke()
    @     0x7f26c1a86eb8  nanobind::detail::nb_func_vectorcall_complex()
    @     0x56227ff6aabb  _PyEval_EvalFrameDefault
Aborted (core dumped)

Any chance you have other flags or env variables set?

sfvaroglu avatar Oct 30 '24 20:10 sfvaroglu

@qGentry Can you please set XLA_CLIENT_MEM_FRACTION=0.95 and add --xla_gpu_copy_insertion_use_region_analysis to your existing flags, and report back whether it resolves the issue?
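For reference, a minimal sketch of applying these settings (values taken verbatim from the suggestion above; they must be set before the first jax import):

import os

# Suggested settings: raise the client memory fraction and enable region
# analysis for copy insertion, in addition to the flags already in use.
os.environ["XLA_CLIENT_MEM_FRACTION"] = "0.95"
os.environ["XLA_FLAGS"] = (
    "--xla_gpu_graph_level=0 "
    "--xla_gpu_enable_triton_gemm=false "
    "--xla_gpu_enable_command_buffer= "
    "--xla_gpu_enable_latency_hiding_scheduler=true "
    "--xla_gpu_copy_insertion_use_region_analysis"
)

import jax  # noqa: E402  # flags are read when jax/XLA is first initialized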

sfvaroglu avatar Nov 04 '24 18:11 sfvaroglu

@qGentry The xla_gpu_memory_limit_slop_factor flag could also help in this case. The default value is 95, so you can experiment with lower values (90, 80, 70, etc.). You can find more info about this flag at https://github.com/jax-ml/jax/blob/main/docs/gpu_performance_tips.md. Let me know if you see any issues.
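A minimal sketch of trying a lower slop factor (80 here is just an example value), appended to the existing flags before the first jax import:

import os

# Example only: lower the memory-limit slop factor from its default of 95.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_gpu_memory_limit_slop_factor=80"
)

import jax  # noqa: E402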

sfvaroglu avatar Dec 16 '24 17:12 sfvaroglu