Clip by Global Norm causes Pipeshard Parallel crash
Please describe the bug
Hi, this code works with other parallel methods, but crashes when PipeshardParallel is used.
Please describe the expected behavior
The model compiles without crashing.
System information and environment
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04, docker): RHEL 8.5, Docker image: cuda_11.8.0-cudnn8-devel-ubuntu22.04
- Python version: 3.9.6
- CUDA version: 11.8
- NCCL version: (2, 10, 3)?
- cupy version: cupy-cuda11x==11.5.0
- GPU model and memory: 20x Nvidia A100 80GB or 2x Nvidia A100 80GB
- Alpa version: this git commit (one behind the latest as of Feb 13th, 2023): https://github.com/alpa-projects/alpa/commit/c164f3e6d94b500bb2fc23cb85e9a3236e15b59a
- TensorFlow version: tensorflow==2.8.0
- JAX version: jax==0.3.22
To Reproduce
Steps to reproduce the behavior (I am launching this via a SLURM script):
1. Run the .py file below with Ray (it is a modified copy of an included test file).
2. If you comment out line 75 (the global_norm_clipper line marked in the snippet below), it works again:
   global_norm_clipper(kwargs.get("clip_by_global_norm", 1.0)),  ## Comment this line to make it work
2b. If you switch the method from PipeshardParallel to ShardParallel (see the sketch after these steps), it also works again:
   # method = alpa.ShardParallel()  ## Uncomment this line to make it work
3. See the error below.
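For step 2b, the only difference is which parallel method is passed to alpa.parallelize. A minimal sketch of the two configurations is below (argument values are taken from the script further down; alpa.PipeshardParallel and alpa.ShardParallel are the top-level aliases of the classes the script imports, and the comments reflect my reading of the traceback rather than a confirmed diagnosis):

import alpa

# Crashes: pipeline parallelism places stages on separate device meshes,
# so cross-mesh NCCL communicators have to be created.
method = alpa.PipeshardParallel(num_micro_batches=4)

# Works: intra-operator sharding on a single mesh.
# method = alpa.ShardParallel()

# train_step, state, and batch are defined as in the script below.
parallel_train_step = alpa.parallelize(train_step, method=method)
executable = parallel_train_step.get_executable(state, batch)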
Screenshots
alpa.compile_executable.debug_compilation_time(): compile_pipeshard_executable::trace: 4.55 s
alpa.compile_executable.debug_compilation_time(): compile_pipeshard_executable::jaxpr operations: 0.13 s
alpa.pipeline_parallel.stage_construction.cluster_layers_and_slice_mesh(): num_devices: 20, num_stages: 20
alpa.compile_executable.debug_compilation_time(): compile_pipeshard_executable::stage construction: 0.17 s
alpa.compile_executable.debug_compilation_time(): compile_pipeshard_executable::apply grad: 0.24 s
alpa.compile_executable.debug_compilation_time(): compile_pipeshard_executable::shard stages: 26.39 s
alpa.compile_executable.debug_compilation_time(): compile_pipeshard_executable::launch meshes: 1.07 s
alpa.compile_executable.debug_compilation_time(): compile_pipeshard_executable::runtime emitter: 41.77 s
2023-02-13 05:48:11,529 ERROR worker.py:400 -- Unhandled error (suppress with 'RAY_IGNORE_UNHANDLED_ERRORS=1'): ray::MeshHostWorker.create_and_set_cross_mesh_communicators() (pid=1087988, ip=172.16.130.105, repr=<alpa.device_mesh.MeshHostWorker object at 0x7f15952b5cd0>)
File "/home/sblouir/alpa/alpa/device_mesh.py", line 413, in create_and_set_cross_mesh_communicators
g.create_and_set_xla_communicators(devices, key)
File "/home/sblouir/alpa/alpa/collective/collective_group/nccl_collective_group.py", line 463, in create_and_set_xla_communicators
self.xla_comm_group.nccl_create_communicators(actual_world_size,
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/alpa_nccl_group_base.cc:141: NCCL operation ncclGroupEnd() failed: internal error
2023-02-13 05:48:11,541 ERROR worker.py:400 -- Unhandled error (suppress with 'RAY_IGNORE_UNHANDLED_ERRORS=1'): ray::MeshHostWorker.create_and_set_cross_mesh_communicators() (pid=3878494, ip=172.16.130.108, repr=<alpa.device_mesh.MeshHostWorker object at 0x7fba34210d90>)
This error is printed many times, from different workers.
Code snippet to reproduce the problem
"""Utilities for testing."""
from functools import partial
import unittest
from collections.abc import Iterable
from typing import Callable, Optional
import jax
import jax.numpy as jnp
from jax.tree_util import tree_leaves
from jax.experimental.maps import FrozenDict as FrozenDictJax
import numpy as np
import optax
from flax import linen as nn
from flax.core.frozen_dict import FrozenDict as FrozenDictFlax
import alpa
from alpa.api import init, shutdown, parallelize, value_and_grad
from alpa.model.bert_model import BertConfig, FlaxBertLayer
from alpa.model.model_util import FlaxBaseModelOutput, DynamicScale, TrainState
from alpa.parallel_method import PipeshardParallel
from alpa.pipeline_parallel.layer_construction import (AutoLayerOption,
ManualLayerOption)
from alpa.pipeline_parallel.primitive_def import mark_pipeline_boundary
from alpa.pipeline_parallel.stage_construction import (UniformStageOption,
StageOption)
from alpa.shard_parallel.auto_sharding import AutoShardingOption
from typing import Any, Callable, Optional, Union
from typing import NamedTuple, Optional, Tuple, Callable
from optax._src import base
from optax._src import clipping
from optax._src import combine
from optax._src import factorized
from optax._src import transform
import functools
import numpy as np
from optax._src import base
import optax
import chex
from optax._src import linear_algebra
from optax._src import base
ScalarOrSchedule = Union[float, base.Schedule]
MaskOrFn = Optional[Union[Any, Callable[[base.Params], Any]]]
## Code from Optax's Adafactor optimizer
def crashing_optimizer(learning_rate: float = None, *args, **kwargs) -> base.GradientTransformation:

    def global_norm_clipper(max_norm: float) -> base.GradientTransformation:

        def init_fn(params):
            del params
            return base.EmptyState()

        def update_fn(updates, state, params=None):
            del params
            g_norm = linear_algebra.global_norm(updates)
            g_norm = jnp.maximum(max_norm, g_norm)
            updates = jax.tree_util.tree_map(lambda t: (t / g_norm) * max_norm, updates)
            return updates, state

        return base.GradientTransformation(init_fn, update_fn)

    def _scale_by_learning_rate(learning_rate: ScalarOrSchedule, flip_sign=True):
        m = -1 if flip_sign else 1
        if callable(learning_rate):
            return transform.scale_by_schedule(lambda count: m * learning_rate(count))
        return transform.scale(m * learning_rate)

    tx = [
        _scale_by_learning_rate(learning_rate, flip_sign=False),
        global_norm_clipper(kwargs.get("clip_by_global_norm", 1.0)),  ## Comment this line to make it work
        transform.scale(-1),
    ]
    return combine.chain(*tx)
class BasicModel(nn.Module):
    num_layers: int

    @nn.compact
    def __call__(self, inputs=None, labels=None, attention_mask=None,
                 loss_mask=None, target_input_ids=None, *args, **kwargs):
        x = nn.Embed(256, 768)(inputs)
        x = nn.LayerNorm()(x)
        for _ in range(self.num_layers):
            x = nn.Dense(16)(x)
            x = nn.Dense(768)(x)
        y = nn.Embed(256, 768)(labels)
        y = nn.LayerNorm()(y)
        loss = (y - x) ** 2
        return jnp.sum(loss)
def get_bert_layer_train_state_and_step(batch_size, seq_len, num_layers,
                                        hidden_size, num_heads,
                                        clip_by_global_norm, use_dynamic_scale,
                                        add_manual_pipeline_marker):
    rngkey = jax.random.PRNGKey(0)
    inputs = jax.random.randint(rngkey, (batch_size, seq_len), 0, 384)
    labels = jax.random.randint(rngkey, (batch_size, seq_len), 0, 384)
    attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int8)
    loss_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int8)
    target_input_ids = jax.random.randint(rngkey, (batch_size, seq_len), 0, 384)
    batch = {
        "inputs": inputs,
        "attention_mask": attention_mask,
        "labels": labels,
        "loss_mask": loss_mask,
        "target_input_ids": target_input_ids,
    }

    model = BasicModel(num_layers=num_layers)
    params = model.init(rngkey, **batch)
    # tx = optax.adam(learning_rate=1e-2)
    tx = crashing_optimizer(learning_rate=0.01, clip_by_global_norm=1.0)

    if use_dynamic_scale:
        use_master_copy = False
        dynamic_scale = DynamicScale()
    else:
        dynamic_scale = None
        use_master_copy = False

    state = TrainState.create(apply_fn=model.apply,
                              params=params,
                              tx=tx,
                              dynamic_scale=dynamic_scale,
                              use_master_copy=use_master_copy)

    def train_step(state, batch):

        def loss_func(params):
            loss = state.apply_fn(params, **batch)
            return loss

        dynamic_scale = state.dynamic_scale
        if dynamic_scale:
            grad_fn = dynamic_scale.value_and_grad(loss_func)
            dynamic_scale, is_fin, val, grads = grad_fn(state.params)
        else:
            grad_fn = value_and_grad(loss_func)
            val, grads = grad_fn(state.params)

        new_state = state.apply_gradients(grads=grads)
        if dynamic_scale:
            new_state = new_state.replace(
                opt_state=jax.tree_map(partial(jnp.where, is_fin),
                                       new_state.opt_state, state.opt_state),
                params=jax.tree_map(partial(jnp.where, is_fin),
                                    new_state.params, state.params),
                master_copy=jax.tree_map(partial(jnp.where, is_fin),
                                         new_state.master_copy,
                                         state.master_copy),
                dynamic_scale=dynamic_scale)
        return new_state, val

    return state, batch, train_step
class PipelineBasicTest(unittest.TestCase):

    def setUp(self):
        init(cluster="ray")

    def tearDown(self):
        shutdown()

    def run_n_layer_bert(self,
                         num_layers=alpa.get_global_num_devices(),
                         batch_size=16,
                         seq_len=256,
                         hidden_size=512,
                         num_heads=512 // 64,
                         use_remat=False,
                         clip_by_global_norm=False,
                         use_dynamic_scale=False,
                         inject_train_step=None,
                         manual_pipeline_layer=True,
                         stage_option: Optional[StageOption] = None,
                         as_option: Optional[AutoShardingOption] = None,
                         do_numerical_test: bool = True):
        method = PipeshardParallel(
            num_micro_batches=4,
            default_auto_sharding_option=as_option or AutoShardingOption(),
            layer_option=ManualLayerOption(remat_layer=use_remat)
            if manual_pipeline_layer else AutoLayerOption(
                layer_num=num_layers,
                remat_mode="coarse_grained_remat" if use_remat else "none"),
            stage_option=stage_option or UniformStageOption())
        # method = alpa.ShardParallel()  ## Uncomment this line to make it work

        # Init model
        state, batch, train_step = get_bert_layer_train_state_and_step(
            batch_size=batch_size,
            seq_len=seq_len,
            num_layers=num_layers,
            hidden_size=hidden_size,
            num_heads=num_heads,
            clip_by_global_norm=clip_by_global_norm,
            use_dynamic_scale=use_dynamic_scale,
            add_manual_pipeline_marker=manual_pipeline_layer)

        if inject_train_step is not None:
            assert isinstance(inject_train_step, Callable)
            train_step = inject_train_step

        # Compile
        serial_train_step = train_step
        parallel_train_step = parallelize(train_step, method=method)
        executable = parallel_train_step.get_executable(state, batch)

        for _ in range(100):
            state, loss = parallel_train_step(state, batch)
            print(f" loss: {loss}")

        # Run correctness test
        if do_numerical_test:
            expected_new_state = None
            actual_new_state = None
            for i in range(1):
                if i > 0:
                    state = expected_new_state
                expected_new_state, expected_val = serial_train_step(
                    state, batch)

                if i > 0:
                    state = actual_new_state
                actual_new_state, actual_val = parallel_train_step(state, batch)

                assert_allclose(expected_new_state.params,
                                actual_new_state.params, 1e-3, 1.5e-3)
                assert_allclose(expected_val, actual_val, 1e-3, 1e-3)

        hlo_text = executable.get_hlo_text()
        return hlo_text
if __name__ == "__main__":
t = PipelineBasicTest()
t.setUp()
x = t.run_n_layer_bert(
manual_pipeline_layer=False,
do_numerical_test=False,
)
print(f"*" * 60,)
print(f" x: {x}")
print(f"*" * 60,)
t.tearDown()
Additional information
I've been hunting this one down for a while... Please let me know if any more information would help.
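One additional data point I could try (a hedged suggestion only; I have not verified it under PipeshardParallel): the hand-rolled global_norm_clipper above performs the same kind of global-norm reduction as optax's stock clip_by_global_norm, so swapping in the stock transform should show whether the crash is specific to this implementation or to any global-norm computation inside the optimizer:

# Hypothetical sanity-check variant using only stock optax transforms (untested here).
tx = optax.chain(
    optax.clip_by_global_norm(1.0),  # same global-norm reduction as global_norm_clipper
    optax.scale(0.01),               # learning rate
    optax.scale(-1.0),               # step in the negative gradient direction
)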
I have also encountered a similar problem. Have you been able to resolve it?