
Clip by Global Norm causes Pipeshard Parallel crash

Open · samblouir opened this issue on Feb 13, 2023 · 1 comment

Please describe the bug
Hi, this code works with other parallel methods, but crashes when PipeshardParallel is used.

Please describe the expected behavior
The model compiles without crashing.

System information and environment

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04, docker): RHEL 8.5, Docker image: cuda_11.8.0-cudnn8-devel-ubuntu22.04
  • Python version: 3.9.6
  • CUDA version: 11.8
  • NCCL version: (2, 10, 3)?
  • cupy version: cupy-cuda11x==11.5.0
  • GPU model and memory: 20x Nvidia A100 80GB or 2x Nvidia A100 80GB
  • Alpa version: commit c164f3e (one commit behind the latest as of Feb 13, 2023): https://github.com/alpa-projects/alpa/commit/c164f3e6d94b500bb2fc23cb85e9a3236e15b59a
  • TensorFlow version: tensorflow==2.8.0
  • JAX version: jax==0.3.22

To Reproduce
Steps to reproduce the behavior (I am starting this with a SLURM script):

  1. Run this .py file with ray (it is a modified copy of an included test file).

  2. If you comment out line 75 (the global_norm_clipper line quoted below), it works again. A sketch of the same clipping written with optax's built-in transform follows after this list.

     global_norm_clipper(kwargs.get("clip_by_global_norm", 1.0)), ## Comment this line to make it work

  2b. If you switch the method from PipeshardParallel to ShardParallel, it also works again.

     # method = alpa.ShardParallel() ## Uncomment this line to make it work

  3. See the error below.
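For reference, here is a minimal sketch of the same clipping expressed with optax's built-in transforms (assuming standard optax behaviour; the function name is just illustrative, and I have not checked whether this variant hits the same crash):

import optax

# Illustrative only: the same global-norm clipping via optax's built-in
# clip_by_global_norm, chained with a plain SGD update.
def builtin_clipping_optimizer(learning_rate: float = 0.01, max_norm: float = 1.0) -> optax.GradientTransformation:
	return optax.chain(
		optax.clip_by_global_norm(max_norm),  # same global-norm reduction as global_norm_clipper
		optax.sgd(learning_rate),
	)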

Screenshots

alpa.compile_executable.debug_compilation_time(): compile_pipeshard_executable::trace: 4.55 s
alpa.compile_executable.debug_compilation_time(): compile_pipeshard_executable::jaxpr operations: 0.13 s
alpa.pipeline_parallel.stage_construction.cluster_layers_and_slice_mesh():  num_devices: 20,  num_stages: 20
alpa.compile_executable.debug_compilation_time(): compile_pipeshard_executable::stage construction: 0.17 s
alpa.compile_executable.debug_compilation_time(): compile_pipeshard_executable::apply grad: 0.24 s
alpa.compile_executable.debug_compilation_time(): compile_pipeshard_executable::shard stages: 26.39 s
alpa.compile_executable.debug_compilation_time(): compile_pipeshard_executable::launch meshes: 1.07 s
alpa.compile_executable.debug_compilation_time(): compile_pipeshard_executable::runtime emitter: 41.77 s
2023-02-13 05:48:11,529	ERROR worker.py:400 -- Unhandled error (suppress with 'RAY_IGNORE_UNHANDLED_ERRORS=1'): ray::MeshHostWorker.create_and_set_cross_mesh_communicators() (pid=1087988, ip=172.16.130.105, repr=<alpa.device_mesh.MeshHostWorker object at 0x7f15952b5cd0>)
  File "/home/sblouir/alpa/alpa/device_mesh.py", line 413, in create_and_set_cross_mesh_communicators
    g.create_and_set_xla_communicators(devices, key)
  File "/home/sblouir/alpa/alpa/collective/collective_group/nccl_collective_group.py", line 463, in create_and_set_xla_communicators
    self.xla_comm_group.nccl_create_communicators(actual_world_size,
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/alpa_nccl_group_base.cc:141: NCCL operation ncclGroupEnd() failed: internal error
2023-02-13 05:48:11,541	ERROR worker.py:400 -- Unhandled error (suppress with 'RAY_IGNORE_UNHANDLED_ERRORS=1'): ray::MeshHostWorker.create_and_set_cross_mesh_communicators() (pid=3878494, ip=172.16.130.108, repr=<alpa.device_mesh.MeshHostWorker object at 0x7fba34210d90>)

This error is printed many times, from different workers.

Code snippet to reproduce the problem

"""Utilities for testing."""
from functools import partial
import unittest
from collections.abc import Iterable
from typing import Callable, Optional

import jax
import jax.numpy as jnp
from jax.tree_util import tree_leaves
from jax.experimental.maps import FrozenDict as FrozenDictJax
import numpy as np
import optax
from flax import linen as nn
from flax.core.frozen_dict import FrozenDict as FrozenDictFlax

import alpa
from alpa.api import init, shutdown, parallelize, value_and_grad
from alpa.model.bert_model import BertConfig, FlaxBertLayer
from alpa.model.model_util import FlaxBaseModelOutput, DynamicScale, TrainState
from alpa.parallel_method import PipeshardParallel
from alpa.pipeline_parallel.layer_construction import (AutoLayerOption,
													   ManualLayerOption)
from alpa.pipeline_parallel.primitive_def import mark_pipeline_boundary
from alpa.pipeline_parallel.stage_construction import (UniformStageOption,
													   StageOption)
from alpa.shard_parallel.auto_sharding import AutoShardingOption
from alpa.testing import assert_allclose  # used by the numerical test below


from typing import Any, NamedTuple, Tuple, Union

import chex
from optax._src import base, clipping, combine, factorized, linear_algebra, transform

ScalarOrSchedule = Union[float, base.Schedule]
MaskOrFn = Optional[Union[Any, Callable[[base.Params], Any]]]



## Code from Optax's Adafactor optimizer
def crashing_optimizer(learning_rate: Optional[float] = None, *args, **kwargs) -> base.GradientTransformation:

	def global_norm_clipper(max_norm: float) -> base.GradientTransformation:
		def init_fn(params):
			del params
			return base.EmptyState()

		def update_fn(updates, state, params=None):
			del params
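			# Clip by global norm: take the l2 norm over all gradient leaves,
			# then rescale so the updates have norm at most max_norm (a no-op
			# when the norm is already <= max_norm).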
			g_norm = linear_algebra.global_norm(updates)
			g_norm = jnp.maximum(max_norm, g_norm)
			updates = jax.tree_util.tree_map(lambda t: (t / g_norm) * max_norm, updates)
			return updates, state

		return base.GradientTransformation(init_fn, update_fn)
	
	
	def _scale_by_learning_rate(learning_rate: ScalarOrSchedule, flip_sign=True):
		m = -1 if flip_sign else 1
		if callable(learning_rate):
			return transform.scale_by_schedule(lambda count: m * learning_rate(count))
		return transform.scale(m * learning_rate)

	tx = [
			_scale_by_learning_rate(learning_rate, flip_sign=False),
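			## Suspected trigger: linear_algebra.global_norm() reduces over every
			## gradient leaf, which with PipeshardParallel appears to require the
			## cross-mesh NCCL communicators that fail in the traceback above.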
			global_norm_clipper(kwargs.get("clip_by_global_norm", 1.0)), ## Comment this line to make it work
			transform.scale(-1),
		]
	return combine.chain(*tx)

class BasicModel(nn.Module):
	num_layers:int
	@nn.compact
	def __call__(self, inputs=None, labels=None, attention_mask=None, loss_mask=None, target_input_ids=None, *args, **kwargs):
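		# Toy model: embed the integer inputs and labels, push the input
		# embedding through a small Dense stack, and return the summed
		# squared error between the two.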
		x = nn.Embed(256, 768)(inputs)
		x = nn.LayerNorm()(x)
		for _ in range(self.num_layers):
			x = nn.Dense(16)(x)
		x = nn.Dense(768)(x)

		y = nn.Embed(256, 768)(labels)
		y = nn.LayerNorm()(y)

		loss = (y-x)**2
		return jnp.sum(loss)

def get_bert_layer_train_state_and_step(batch_size, seq_len, num_layers,
										hidden_size, num_heads,
										clip_by_global_norm, use_dynamic_scale,
										add_manual_pipeline_marker):
	rngkey = jax.random.PRNGKey(0)
	inputs = jax.random.randint(rngkey, (batch_size, seq_len,), 0, 384,)
	labels = jax.random.randint(rngkey, (batch_size, seq_len,), 0, 384,)
	attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int8)
	loss_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int8)
	target_input_ids = jax.random.randint(rngkey, (batch_size, seq_len,), 0, 384,)

	batch = {
		"inputs": inputs, 
		"attention_mask": attention_mask,
		"labels":labels, 
		"loss_mask":loss_mask, 
		"target_input_ids":target_input_ids, 
		}

	model = BasicModel(num_layers=num_layers,)
	params = model.init(rngkey, **batch,)

	# tx = optax.adam(learning_rate=1e-2)
	tx = crashing_optimizer(learning_rate = 0.01, clip_by_global_norm=1.0,)

	if use_dynamic_scale:
		use_master_copy = False
		dynamic_scale = DynamicScale()
	else:
		dynamic_scale = None
		use_master_copy = False

	state = TrainState.create(apply_fn=model.apply,
							  params=params,
							  tx=tx,
							  dynamic_scale=dynamic_scale,
							  use_master_copy=use_master_copy)

	def train_step(state, batch):

		def loss_func(params):
			loss = state.apply_fn(params, **batch,)
			return loss

		dynamic_scale = state.dynamic_scale
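		# With dynamic loss scaling, use its value_and_grad wrapper (which also
		# reports whether the gradients are finite); otherwise use alpa's
		# value_and_grad imported above.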
		if dynamic_scale:
			grad_fn = dynamic_scale.value_and_grad(loss_func)
			dynamic_scale, is_fin, val, grads = grad_fn(state.params)
		else:
			grad_fn = value_and_grad(loss_func)
			val, grads = grad_fn(state.params)

		new_state = state.apply_gradients(grads=grads)

		if dynamic_scale:
			new_state = new_state.replace(
				opt_state=jax.tree_map(partial(jnp.where, is_fin),
									   new_state.opt_state, state.opt_state),
				params=jax.tree_map(partial(jnp.where, is_fin),
									new_state.params, state.params),
				master_copy=jax.tree_map(partial(jnp.where,
												 is_fin), new_state.master_copy,
										 state.master_copy),
				dynamic_scale=dynamic_scale)
		return new_state, val

	return state, batch, train_step


class PipelineBasicTest(unittest.TestCase):

	def setUp(self):
		init(cluster="ray")

	def tearDown(self):
		shutdown()

	def run_n_layer_bert(self,
						 num_layers=alpa.get_global_num_devices(),
						 batch_size=16,
						 seq_len=256,
						 hidden_size=512,
						 num_heads=512 // 64,
						 use_remat=False,
						 clip_by_global_norm=False,
						 use_dynamic_scale=False,
						 inject_train_step=None,
						 manual_pipeline_layer=True,
						 stage_option: Optional[StageOption] = None,
						 as_option: Optional[AutoShardingOption] = None,
						 do_numerical_test: bool = True):
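		# Pipeline parallelism plus intra-stage sharding: layers are grouped
		# into stages (manually or automatically) and the stages are placed on
		# sub-meshes; UniformStageOption assigns them uniformly.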
		method = PipeshardParallel(
			num_micro_batches=4,
			default_auto_sharding_option=as_option or AutoShardingOption(),
			layer_option=ManualLayerOption(remat_layer=use_remat)
			if manual_pipeline_layer else AutoLayerOption(
				layer_num=num_layers,
				remat_mode="coarse_grained_remat" if use_remat else "none"),
			stage_option=stage_option or UniformStageOption())

		# Init model
		state, batch, train_step = get_bert_layer_train_state_and_step(
			batch_size=batch_size,
			seq_len=seq_len,
			num_layers=num_layers,
			hidden_size=hidden_size,
			num_heads=num_heads,
			clip_by_global_norm=clip_by_global_norm,
			use_dynamic_scale=use_dynamic_scale,
			add_manual_pipeline_marker=manual_pipeline_layer)
		if inject_train_step is not None:
			assert isinstance(inject_train_step, Callable)
			train_step = inject_train_step

		# Compile
		serial_train_step = train_step
		parallel_train_step = parallelize(train_step, method=method)
		executable = parallel_train_step.get_executable(state, batch)

		for _ in range(100):
			state, loss = parallel_train_step(state, batch)
			print(f"  loss: {loss}")

		# Run correctness test
		if do_numerical_test:
			expected_new_state = None
			actual_new_state = None
			for i in range(1):
				if i > 0:
					state = expected_new_state
				expected_new_state, expected_val = serial_train_step(
					state, batch)

				if i > 0:
					state = actual_new_state

				actual_new_state, actual_val = parallel_train_step(state, batch)

				assert_allclose(expected_new_state.params,
								actual_new_state.params, 1e-3, 1.5e-3)
				assert_allclose(expected_val, actual_val, 1e-3, 1e-3)

		hlo_text = executable.get_hlo_text()
		return hlo_text

if __name__ == "__main__":
	t = PipelineBasicTest()
	t.setUp()
	x = t.run_n_layer_bert(
		manual_pipeline_layer=False,
		do_numerical_test=False,
	)
	print(f"*" * 60,)
	print(f"  x: {x}")
	print(f"*" * 60,)
	t.tearDown()
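For anyone reproducing this: alpa.init(cluster="ray") expects an existing Ray cluster, so I start one across the SLURM allocation first (ray start --head on one node, ray start --address=<head-ip>:6379 on the others) and then launch the script with plain python.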

Additional information
I've been hunting this one down for a while... Please let me know if there is any more information I can provide to help diagnose it.

samblouir · Feb 13, 2023

I have also encountered a similar problem. May I inquire if you have resolved it?

AryaLiut · Aug 28, 2023