Latency hiding scheduler leads to 5x memory usage when used without jax.lax.scan
Description
Hi, we're training a large (300B parameters, 60 layers) mixture-of-experts transformer on 1000+ GPUs. The layers are not all uniform, so we can't use jax.lax.scan directly to stack them together; instead, we call each layer independently.
The model's structure isn't completely arbitrary: it is (3 layers with one structure, 1 layer with another) repeated 15 times, giving 60 layers in total.
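To make that pattern concrete, here is a minimal sketch (with hypothetical UniformBlock/SpecialBlock modules standing in for our real layers):

```python
import flax.linen as nn


class UniformBlock(nn.Module):
    # Stand-in for the three identically structured layers in each group.
    @nn.compact
    def __call__(self, x):
        return x + nn.Dense(x.shape[-1], use_bias=False)(x)


class SpecialBlock(nn.Module):
    # Stand-in for the differently structured fourth layer.
    @nn.compact
    def __call__(self, x):
        return x + nn.Dense(x.shape[-1], use_bias=False)(nn.relu(x))


class Stack(nn.Module):
    # (3 uniform layers, then 1 special layer) repeated 15 times = 60 layers,
    # applied as individual module calls rather than one jax.lax.scan.
    @nn.compact
    def __call__(self, x):
        for group in range(15):
            for i in range(3):
                x = UniformBlock(name=f"uniform_{group}_{i}")(x)
            x = SpecialBlock(name=f"special_{group}")(x)
        return x
```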
We would benefit a LOT from overlapping computation and communication, but when we enable the latency hiding scheduler with --xla_gpu_enable_latency_hiding_scheduler, memory usage increases by a factor of 4-5 (from 50 GB per GPU to 200-250 GB per GPU), which is completely unusable.
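For reference, this is how we turn the scheduler on (a minimal sketch; XLA_FLAGS has to be set before JAX initializes its backend):

```python
import os

# Append the latency hiding scheduler flag to whatever XLA_FLAGS is already set.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_gpu_enable_latency_hiding_scheduler=true"
)

import jax  # noqa: E402  (imported after setting the flags)
```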
My guess is that the compiler doesn't reuse buffers for async communication across different layers in this case.
We also tested a variant with jax.lax.scan and uniform layers; from a memory-usage point of view it seemed fine, with only a 20-25% overhead from the latency hiding scheduler.
Is this a known problem? Is there any workaround?
System info (python version, jaxlib version, accelerator, etc.)
Tested on 0.4.25/0.4.26 with 1000+ H100 GPUs.
Here is a toy repro, tested on JAX 0.4.34:
```python
import flax.linen as nn
import jax
import jax.ad_checkpoint
import jax.numpy as jnp
import numpy as np
from flax.linen.linear import default_kernel_init

EMB_DIM = 8192
HID_DIM = 8192
BS = 32
SEQ_LEN = 4096
N_LAYERS = 32
SCAN = False

CHECKPOINT_POLICY = jax.checkpoint_policies.save_and_offload_only_these_names(
    names_which_can_be_saved=[],
    names_which_can_be_offloaded=[],
    offload_src="device",
    offload_dst="pinned_host",
)

mesh = jax.sharding.Mesh(np.array(jax.devices()).reshape(8, 1), ("data", "model"))
input_sharding = jax.sharding.NamedSharding(
    mesh, jax.sharding.PartitionSpec("data", None)
)
target_sharding = jax.sharding.NamedSharding(
    mesh,
    jax.sharding.PartitionSpec(
        "data",
    ),
)

rules = (
    ("batch", "data"),
    ("embedding", None),
    ("hidden", "model"),
    ("q_sequence", "model"),
)


class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x_residual = x
        h = nn.Dense(
            HID_DIM,
            kernel_init=nn.with_logical_partitioning(
                default_kernel_init,
                ("embedding", "hidden"),
            ),
            use_bias=False,
        )(x)
        h = nn.relu(h)
        x = nn.Dense(
            EMB_DIM,
            kernel_init=nn.with_logical_partitioning(
                default_kernel_init,
                ("hidden", "embedding"),
            ),
            use_bias=False,
        )(h)
        x = x_residual + x
        # Sequence parallelism
        x = nn.with_logical_constraint(x, ("batch", "q_sequence", None))
        return x


class Output(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(
            features=1,
            kernel_init=nn.with_logical_partitioning(
                default_kernel_init,
                ("hidden", None),
            ),
            use_bias=False,
        )(x)[..., 0]
        x = jnp.mean(x, axis=1)
        return x


class Model(nn.Module):
    @nn.compact
    def __call__(self, x):
        def apply_module(block, block_input, _):
            block_output = block(block_input)
            return block_output, None

        apply_module = nn.remat(
            apply_module,
            policy=CHECKPOINT_POLICY,
            prevent_cse=False,
        )
        if SCAN:
            x, _ = nn.scan(
                apply_module,
                variable_axes={"params": 0},
                split_rngs={"params": True},
                length=N_LAYERS,
                metadata_params={nn.PARTITION_NAME: "layers"},
            )(MLP(), x, None)
        else:
            for i in range(N_LAYERS):
                x = MLP(name=f"block_{i}")(x)
        preds = Output()(x)
        return preds


def loss_fn(preds, target):
    return jnp.mean((preds - target) ** 2)


def calc_loss(params, inputs, target):
    preds = Model().apply(params, inputs)
    loss = loss_fn(preds, target)
    return loss


def train_step(params, inputs, target):
    loss, grads = jax.value_and_grad(calc_loss)(params, inputs, target)
    params = jax.tree_util.tree_map(lambda p, g: p - 1e-8 * g, params, grads)
    return params, loss


def unbox_logically_partioned(tree, apply_constraint: bool = True):
    return jax.tree_util.tree_map(
        lambda leaf: (
            leaf.unbox(apply_constraint=apply_constraint)
            if isinstance(leaf, nn.LogicallyPartitioned)
            else leaf
        ),
        tree,
        is_leaf=lambda node: isinstance(node, nn.LogicallyPartitioned),
    )


def get_gpu_memory_usage() -> dict[str, float]:
    if jax.default_backend() != "gpu":
        return {}
    num_devices = jax.local_device_count("gpu")
    gpu_memory_usage = []
    for i in range(num_devices):
        memory_stats = jax.local_devices()[i].memory_stats()
        gpu_memory_usage.append(
            memory_stats["peak_bytes_in_use"] / memory_stats["bytes_limit"] * 100
        )
    return {f"GPU{i}": val for i, val in enumerate(gpu_memory_usage)}


with mesh, nn.logical_axis_rules(rules):
    fake_inputs = jnp.empty((BS, SEQ_LEN, EMB_DIM))
    fake_inputs = jax.device_put(fake_inputs, input_sharding)
    fake_target = jnp.empty((BS,))
    fake_target = jax.device_put(fake_target, target_sharding)

    params = Model().init(jax.random.PRNGKey(0), fake_inputs)
    params = unbox_logically_partioned(params)

    train_step_fn = (
        jax.jit(
            train_step,
            in_shardings=(
                jax.tree_util.tree_map(lambda x: x.sharding, params),
                input_sharding,
                target_sharding,
            ),
            out_shardings=(
                jax.tree_util.tree_map(lambda x: x.sharding, params),
                jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()),
            ),
            donate_argnums=(0,),
        )
        .lower(params, fake_inputs, fake_target)
        .compile()
    )

    with open("compiled.txt", "w") as f:
        f.write(train_step_fn.as_text())

    memory_analysis = train_step_fn.memory_analysis()
    print(
        f"Total size device = {memory_analysis.temp_size_in_bytes / 1024 / 1024 / 1024} GB, "  # noqa E501
        f"weights = {memory_analysis.argument_size_in_bytes / 1024 / 1024 / 1024} GB, "
        f"total: {(memory_analysis.argument_size_in_bytes + memory_analysis.temp_size_in_bytes) / 1024 / 1024 / 1024} GB"
    )

    for i in range(10):
        inputs = jax.random.normal(jax.random.PRNGKey(i), (BS, SEQ_LEN, EMB_DIM))
        inputs = jax.device_put(inputs, input_sharding)
        target = jax.random.normal(jax.random.PRNGKey(0), (BS,))
        target = jax.device_put(target, target_sharding)
        if i == 3:
            jax.tree_map(lambda x: x.block_until_ready(), params)
            jax.profiler.start_trace("./profile", create_perfetto_trace=True)
        params, loss = train_step_fn(params, inputs, target)
        if i == 3:
            jax.tree_map(lambda x: x.block_until_ready(), params)
            jax.profiler.stop_trace()
        print(loss)

print(get_gpu_memory_usage())
```
Latency hiding scheduler enabled, XLA_FLAGS: --xla_gpu_graph_level=0 --xla_gpu_enable_triton_gemm=false --xla_gpu_enable_command_buffer= --xla_gpu_enable_latency_hiding_scheduler=true

SCAN=False (hits OOM):

```
2024-10-17 10:47:03.275923: W external/xla/xla/service/hlo_rematerialization.cc:3005] Can't reduce memory use below 53.46GiB (57397316392 bytes) by rematerialization; only reduced to 70.25GiB (75430461976 bytes), down from 75.50GiB (81067606600 bytes) originally
Total size device = 54.500054121017456 GB, weights = 16.500030532479286 GB, total: 71.00008465349674 GB
```

SCAN=True:

```
Total size device = 34.5000324845314 GB, weights = 16.500030532479286 GB, total: 51.00006301701069 GB
```

Latency hiding scheduler disabled, XLA_FLAGS: --xla_gpu_graph_level=0 --xla_gpu_enable_triton_gemm=false --xla_gpu_enable_command_buffer=

SCAN=False:

```
Total size device = 37.50002360343933 GB, weights = 16.500030532479286 GB, total: 54.00005413591862 GB
```

SCAN=True:

```
Total size device = 35.00000220537186 GB, weights = 16.500030532479286 GB, total: 51.50003273785114 GB
```
@qGentry Using JAX 0.4.35 with XLA_FLAGS="--xla_gpu_graph_level=0 --xla_gpu_enable_triton_gemm=false --xla_gpu_enable_command_buffer= " and SCAN=False, I'm seeing a failure:
```
*** Check failure stack trace: ***
    @     0x7f26b3b96dc4  absl::lts_20230802::log_internal::LogMessage::SendToLog()
    @     0x7f26b3b96c34  absl::lts_20230802::log_internal::LogMessage::Flush()
    @     0x7f26b3b971e9  absl::lts_20230802::log_internal::LogMessageFatal::~LogMessageFatal()
    @     0x7f26ac056141  xla::PjRtStreamExecutorLoadedExecutable::Execute()
    @     0x7f26abfa5d71  pjrt::PJRT_LoadedExecutable_Execute()
    @     0x7f26bbd699fc  xla::PjRtCApiLoadedExecutable::Execute()
    @     0x7f26c1ab0a25  xla::ifrt::PjRtLoadedExecutable::Execute()
    @     0x7f26c1250c49  xla::(anonymous namespace)::ExecuteShardedOnLocalDevicesInternal<>()
    @     0x7f26c12528ee  xla::PyLoadedExecutable::ExecuteSharded()
    @     0x7f26bbc2fc55  xla::ValueOrThrowWrapper<>::operator()()
    @     0x7f26bbc2fabd  nanobind::detail::func_create<>()::{lambda()#1}::__invoke()
    @     0x7f26c1a86eb8  nanobind::detail::nb_func_vectorcall_complex()
    @     0x56227ff6aabb  _PyEval_EvalFrameDefault
Aborted (core dumped)
```
Any chance you have other flags or env variables set?
@qGentry Can you please set XLA_CLIENT_MEM_FRACTION=0.95 and use --xla_gpu_copy_insertion_use_region_analysis in addition to your existing flags and report back if it resolves the issue?
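Concretely, something like this (a sketch; the environment variable and flag names are taken verbatim from the suggestion above, and must be set before JAX initializes its backend):

```python
import os

# Apply the suggested settings on top of the existing XLA_FLAGS.
os.environ["XLA_CLIENT_MEM_FRACTION"] = "0.95"
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_gpu_copy_insertion_use_region_analysis"
)
```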
@qGentry The xla_gpu_memory_limit_slop_factor flag could also help in this case. The default value is 95, so you can experiment with lower values (90, 80, 70, etc.). You can find more info about this flag at https://github.com/jax-ml/jax/blob/main/docs/gpu_performance_tips.md. Let me know if you see any issues.
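For example (the value below is only illustrative; try progressively lower settings):

```python
import os

# Try a lower slop factor than the default of 95.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_gpu_memory_limit_slop_factor=80"
)
```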