alpa icon indicating copy to clipboard operation
alpa copied to clipboard

F external/org_tensorflow/tensorflow/compiler/xla/python/xla.cc:227] Check failed: stream_device != nullptr (0 vs. nullptr)

Open chaokunyang opened this issue 1 year ago • 1 comments

Please describe the bug When trainning a MLP using alpa parallelize, it crashes at xla:

2023-03-22 15:40:13.983628: F external/org_tensorflow/tensorflow/compiler/xla/python/xla.cc:227] Check failed: stream_device != nullptr (0 vs. nullptr)
*** Aborted at 1679470813 (unix time) try "date -d @1679470813" if you are using GNU date ***
PC: @                0x0 (unknown)
*** SIGABRT (@0x1f400015d55) received by PID 89429 (TID 0x7f2952b0e740) from PID 89429; stack trace: ***
    @     0x7f2950042d4f google::(anonymous namespace)::FailureSignalHandler()
    @     0x7f2952e4b9d0 __pthread_mutex_cond_lock_full
    @     0x7f2952b4cf35 (unknown)
    @     0x7f2952b368d7 (unknown)
    @     0x7f289b18b650 tsl::internal::LogMessageFatal::~LogMessageFatal()
    @     0x7f2895e350bd _ZZN3xlaL27pybind11_init_xla_extensionERN8pybind117module_EENKUlRNS_10PjRtDeviceEE7_clES4_.isra.2397.cold.2543
    @     0x7f2896327783 _ZZN8pybind1112cpp_function10initializeIZN3xlaL27pybind11_init_xla_extensionERNS_7module_EEUlRNS2_10PjRtDeviceEE7_N3tsl6StatusEJS6_EJNS_4nameENS_9is_methodENS_7siblingEEEEvOT_PFT0_DpT1_EDpRKT2_ENUlRNS_6detail13function_callEE1_4_FUNESQ_
    @     0x7f28962f7ec0 pybind11::cpp_function::dispatcher()
    @     0x56365c425052 PyPickleBuffer_Release
    @     0x56365c41030b validate_stmts.cold
    @     0x56365c424dfd faulthandler_fatal_error
    @     0x56365c40befb obj2ast_stmt.cold
    @     0x56365c4062f1 append_ast_expr.cold
    @     0x56365c41793c unicode_isalpha.cold
    @     0x56365c407729 context_tp_dealloc.cold
    @     0x56365c4062f1 append_ast_expr.cold
    @     0x56365c41793c unicode_isalpha.cold
    @     0x56365c40853f PyAST_obj2mod.cold
    @     0x56365c4062f1 append_ast_expr.cold
    @     0x56365c41793c unicode_isalpha.cold
    @     0x56365c40befb obj2ast_stmt.cold
    @     0x56365c4062f1 append_ast_expr.cold
    @     0x56365c4b8e99 dictbytype
    @     0x56365c4b8e5b dictbytype
    @     0x56365c4d97f9 builtin_any
    @     0x56365c4d87f3 slot_tp_setattro
    @     0x56365c387f73 (unknown)
    @     0x56365c387a77 (unknown)
    @     0x56365c37afdd (unknown)
    @     0x56365c4ac679 ast_for_suite
    @     0x7f2952b38193 (unknown)
    @     0x56365c4ac57d ast_for_suite

Please describe the expected behavior

System information and environment

  • OS Platform and Distribution: Linux REPL7 docker
  • Python version: 3.8.16
  • CUDA version: 11.2
  • NCCL version: nccl-2.14.3.1
  • cupy version: 11.6.0
  • GPU model and memory: A10
  • Alpa version: master
  • TensorFlow version: not installed
  • JAX version: jaxlib-0.3.22.cuda112.cudnn810-cp38-cp38

To Reproduce Steps to reproduce the behavior:

  1. Execute following code in a ray cluster with A10 GPU installed.
  2. Got crash error

Screenshots image

Code snippet to reproduce the problem

import os
import logging
import ray

import alpa
from alpa.testing import assert_allclose
from flax import linen as nn
from flax.training.train_state import TrainState
import jax
import jax.numpy as jnp
from jax import random
import numpy as np
import optax

logger = logging.getLogger(__name__)
logging.basicConfig(format=ray.ray_constants.LOGGER_FORMAT, level=logging.INFO)


state = None


def main(*args):
    logger.info(os.getpid())
    alpa.api.init("ray", 2, 1)
    logger.info("init alpa")

    class MLPModel(nn.Module):
        hidden_dim: int
        num_layers: int

        @nn.compact
        def __call__(self, x):
            for i in range(self.num_layers):
                if i % 2 == 0:
                    x = nn.Dense(features=self.hidden_dim * 4)(x)
                else:
                    x = nn.Dense(features=self.hidden_dim)(x)
                x = nn.relu(x)
            return x

    dim = 2048
    batch_size = 2048
    num_layers = 10

    # Generate ground truth W and b
    rngkey = jax.random.PRNGKey(0)
    k1, k2 = random.split(rngkey)
    W = random.normal(k1, (dim, dim))
    b = random.normal(k2, (dim,))

    # Generate the training data
    ksample, knoise = random.split(k1)
    x = random.normal(ksample, (batch_size, dim))
    y = (x @ W + b) + 0.1 * random.normal(knoise, (batch_size, dim))

    # Initialize a train state, which includes the model paramter and optimizer state.
    model = MLPModel(hidden_dim=dim, num_layers=num_layers)
    params = model.init(rngkey, x)
    tx = optax.adam(learning_rate=1e-3)
    global state
    state = TrainState.create(apply_fn=model.apply, params=params, tx=tx)

    # Define the training function and execute one step
    def train_step(state, batch):
        def loss_func(params):
            out = state.apply_fn(params, batch["x"])
            loss = jnp.mean((out - batch["y"]) ** 2)
            return loss

        grads = jax.grad(loss_func)(state.params)
        new_state = state.apply_gradients(grads=grads)
        return new_state

    batch = {"x": x, "y": y}
    expected_state = train_step(state, batch)

    @alpa.parallelize
    def alpa_train_step(state, batch):
        def loss_func(params):
            out = state.apply_fn(params, batch["x"])
            loss = jnp.mean((out - batch["y"]) ** 2)
            return loss

        grads = jax.grad(loss_func)(state.params)
        new_state = state.apply_gradients(grads=grads)
        return new_state

    # Test correctness
    actual_state = alpa_train_step(state, batch)
    assert_allclose(expected_state.params, actual_state.params, atol=5e-3)

    logger.info("Input parameter type: %s", type(state.params["params"]["Dense_0"]["kernel"]))
    logger.info("Output parameter type: %s", type(actual_state.params["params"]["Dense_0"]["kernel"]))

    state = actual_state  # We need this assignment because the original `state` is "donated" and freed.

    def sync_func():
        jax.local_devices()[0].synchronize_all_activity()

    logger.info("preshard_dynamic_args")
    # Benchmark parallel execution with alpa
    # We distribute arguments in advance for the benchmarking purpose.
    state, batch = alpa_train_step.preshard_dynamic_args(state, batch)
    logger.info("finished preshard_dynamic_args")

    def alpa_execution():
        global state
        logger.info(f"state, batch {state, batch}")
        state = alpa_train_step(state, batch)

    from alpa.util import benchmark_func
    alpa_costs = benchmark_func(alpa_execution, sync_func, warmup=1, number=1, repeat=1) * 1e3
    logger.info(f"Alpa execution time.   Mean: {np.mean(alpa_costs):.2f} ms, Std: {np.std(alpa_costs):.2f} ms")

Additional information Add any other context about the problem here or include any logs that would be helpful to diagnose the problem.

chaokunyang avatar Mar 22 '23 07:03 chaokunyang

我感觉这个问题主要是因为devicemesh.py中 有这么一行代码:

update_jax_platform("cpu")

到这benchmark的sync函数调用 jax.local_devices()[0],的时候不是用的gpu而是用的cpu

Lssyes avatar Jan 13 '24 10:01 Lssyes