F external/org_tensorflow/tensorflow/compiler/xla/python/xla.cc:227] Check failed: stream_device != nullptr (0 vs. nullptr)
Please describe the bug
When training an MLP with alpa.parallelize, the program crashes inside XLA with:
2023-03-22 15:40:13.983628: F external/org_tensorflow/tensorflow/compiler/xla/python/xla.cc:227] Check failed: stream_device != nullptr (0 vs. nullptr)
*** Aborted at 1679470813 (unix time) try "date -d @1679470813" if you are using GNU date ***
PC: @ 0x0 (unknown)
*** SIGABRT (@0x1f400015d55) received by PID 89429 (TID 0x7f2952b0e740) from PID 89429; stack trace: ***
@ 0x7f2950042d4f google::(anonymous namespace)::FailureSignalHandler()
@ 0x7f2952e4b9d0 __pthread_mutex_cond_lock_full
@ 0x7f2952b4cf35 (unknown)
@ 0x7f2952b368d7 (unknown)
@ 0x7f289b18b650 tsl::internal::LogMessageFatal::~LogMessageFatal()
@ 0x7f2895e350bd _ZZN3xlaL27pybind11_init_xla_extensionERN8pybind117module_EENKUlRNS_10PjRtDeviceEE7_clES4_.isra.2397.cold.2543
@ 0x7f2896327783 _ZZN8pybind1112cpp_function10initializeIZN3xlaL27pybind11_init_xla_extensionERNS_7module_EEUlRNS2_10PjRtDeviceEE7_N3tsl6StatusEJS6_EJNS_4nameENS_9is_methodENS_7siblingEEEEvOT_PFT0_DpT1_EDpRKT2_ENUlRNS_6detail13function_callEE1_4_FUNESQ_
@ 0x7f28962f7ec0 pybind11::cpp_function::dispatcher()
@ 0x56365c425052 PyPickleBuffer_Release
@ 0x56365c41030b validate_stmts.cold
@ 0x56365c424dfd faulthandler_fatal_error
@ 0x56365c40befb obj2ast_stmt.cold
@ 0x56365c4062f1 append_ast_expr.cold
@ 0x56365c41793c unicode_isalpha.cold
@ 0x56365c407729 context_tp_dealloc.cold
@ 0x56365c4062f1 append_ast_expr.cold
@ 0x56365c41793c unicode_isalpha.cold
@ 0x56365c40853f PyAST_obj2mod.cold
@ 0x56365c4062f1 append_ast_expr.cold
@ 0x56365c41793c unicode_isalpha.cold
@ 0x56365c40befb obj2ast_stmt.cold
@ 0x56365c4062f1 append_ast_expr.cold
@ 0x56365c4b8e99 dictbytype
@ 0x56365c4b8e5b dictbytype
@ 0x56365c4d97f9 builtin_any
@ 0x56365c4d87f3 slot_tp_setattro
@ 0x56365c387f73 (unknown)
@ 0x56365c387a77 (unknown)
@ 0x56365c37afdd (unknown)
@ 0x56365c4ac679 ast_for_suite
@ 0x7f2952b38193 (unknown)
@ 0x56365c4ac57d ast_for_suite
Please describe the expected behavior
System information and environment
- OS Platform and Distribution: Linux REPL7 docker
- Python version: 3.8.16
- CUDA version: 11.2
- NCCL version: nccl-2.14.3.1
- cupy version: 11.6.0
- GPU model and memory: A10
- Alpa version: master
- TensorFlow version: not installed
- JAX version: jaxlib-0.3.22.cuda112.cudnn810-cp38-cp38
To Reproduce
Steps to reproduce the behavior:
- Execute the following code on a Ray cluster with A10 GPUs installed.
- Observe the crash above.
Screenshots
Code snippet to reproduce the problem
import os
import logging
import ray
import alpa
from alpa.testing import assert_allclose
from flax import linen as nn
from flax.training.train_state import TrainState
import jax
import jax.numpy as jnp
from jax import random
import numpy as np
import optax
logger = logging.getLogger(__name__)
logging.basicConfig(format=ray.ray_constants.LOGGER_FORMAT, level=logging.INFO)
state = None
def main(*args):
    logger.info(os.getpid())
    alpa.api.init("ray", 2, 1)
    logger.info("init alpa")

    class MLPModel(nn.Module):
        hidden_dim: int
        num_layers: int

        @nn.compact
        def __call__(self, x):
            for i in range(self.num_layers):
                if i % 2 == 0:
                    x = nn.Dense(features=self.hidden_dim * 4)(x)
                else:
                    x = nn.Dense(features=self.hidden_dim)(x)
                x = nn.relu(x)
            return x

    dim = 2048
    batch_size = 2048
    num_layers = 10

    # Generate ground truth W and b
    rngkey = jax.random.PRNGKey(0)
    k1, k2 = random.split(rngkey)
    W = random.normal(k1, (dim, dim))
    b = random.normal(k2, (dim,))

    # Generate the training data
    ksample, knoise = random.split(k1)
    x = random.normal(ksample, (batch_size, dim))
    y = (x @ W + b) + 0.1 * random.normal(knoise, (batch_size, dim))

    # Initialize a train state, which includes the model parameters and optimizer state.
    model = MLPModel(hidden_dim=dim, num_layers=num_layers)
    params = model.init(rngkey, x)
    tx = optax.adam(learning_rate=1e-3)
    global state
    state = TrainState.create(apply_fn=model.apply, params=params, tx=tx)

    # Define the training function and execute one step
    def train_step(state, batch):
        def loss_func(params):
            out = state.apply_fn(params, batch["x"])
            loss = jnp.mean((out - batch["y"]) ** 2)
            return loss

        grads = jax.grad(loss_func)(state.params)
        new_state = state.apply_gradients(grads=grads)
        return new_state

    batch = {"x": x, "y": y}
    expected_state = train_step(state, batch)

    @alpa.parallelize
    def alpa_train_step(state, batch):
        def loss_func(params):
            out = state.apply_fn(params, batch["x"])
            loss = jnp.mean((out - batch["y"]) ** 2)
            return loss

        grads = jax.grad(loss_func)(state.params)
        new_state = state.apply_gradients(grads=grads)
        return new_state

    # Test correctness
    actual_state = alpa_train_step(state, batch)
    assert_allclose(expected_state.params, actual_state.params, atol=5e-3)
    logger.info("Input parameter type: %s", type(state.params["params"]["Dense_0"]["kernel"]))
    logger.info("Output parameter type: %s", type(actual_state.params["params"]["Dense_0"]["kernel"]))

    state = actual_state  # We need this assignment because the original `state` is "donated" and freed.

    def sync_func():
        jax.local_devices()[0].synchronize_all_activity()

    logger.info("preshard_dynamic_args")
    # Benchmark parallel execution with alpa.
    # We distribute arguments in advance for benchmarking purposes.
    state, batch = alpa_train_step.preshard_dynamic_args(state, batch)
    logger.info("finished preshard_dynamic_args")

    def alpa_execution():
        global state
        logger.info(f"state, batch {state, batch}")
        state = alpa_train_step(state, batch)

    from alpa.util import benchmark_func
    alpa_costs = benchmark_func(alpa_execution, sync_func, warmup=1, number=1, repeat=1) * 1e3
    logger.info(f"Alpa execution time. Mean: {np.mean(alpa_costs):.2f} ms, Std: {np.std(alpa_costs):.2f} ms")
Additional information
I suspect this problem is mainly caused by this line in devicemesh.py:
update_jax_platform("cpu")
Because of it, by the time the benchmark's sync function calls jax.local_devices()[0], that device is a CPU device rather than a GPU device.
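A minimal sketch of the suspected failure path, under the assumption that alpa's init switches the driver's local JAX platform to CPU; the device check and the workaround below are illustrative and untested, not alpa's documented behavior:

import jax
import alpa

# Assumption: during init, alpa runs update_jax_platform("cpu") in devicemesh.py,
# so the driver process afterwards sees CPU devices by default.
alpa.api.init("ray", 2, 1)

dev = jax.local_devices()[0]
print(dev.platform)  # suspected to print "cpu", not "gpu"

# Calling the synchronization used by sync_func on this CPU device is what
# appears to hit "Check failed: stream_device != nullptr" in xla.cc:
# dev.synchronize_all_activity()  # aborts with the stack trace above

# Possible (untested) workaround for sync_func: request the GPU backend
# explicitly instead of taking jax.local_devices()[0].
gpu_devices = jax.devices("gpu")
if gpu_devices:
    gpu_devices[0].synchronize_all_activity()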