
jaxlib.xla_extension.XlaRuntimeError: UNIMPLEMENTED: No registered implementation for custom call to xxx for platform CUDA

Open · MoFHeka opened this issue · 4 comments

jaxlib.xla_extension.XlaRuntimeError: UNIMPLEMENTED: No registered implementation for custom call to te_scaled_upper_triang_masked_softmax_forward for platform CUDA

from transformer_engine.jax.flax.transformer import DotProductAttention, MultiHeadAttention, TransformerLayer

When I use the TE Flax layers, all of them raise this "no registered implementation" error.

Image: ghcr.io/nvidia/jax:maxtext-2024-07-17

MoFHeka · Jul 18 '24 02:07

@denera Could you take a look?

ptrendx · Jul 19 '24 18:07

Hi @MoFHeka -- the JAX/XLA custom op for te_scaled_upper_triang_masked_softmax_forward is implemented here, exposed via PyBind11 here and registered with XLA for the CUDA platform here.

TE/Flax modules invoke this custom op via the scaled_upper_triang_softmax_fwd() API. Is that what you're trying to use in your application?
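
If it helps to narrow things down, here is a rough, untested sketch of a smoke test that applies MultiHeadAttention directly (the shapes and RNG keys are just placeholders); it should exercise the same softmax path outside of a full training loop:

# Rough smoke test (sketch, not verified): applying TE's Flax MultiHeadAttention directly
# should exercise the same softmax custom calls that the error message mentions.
import jax
import jax.numpy as jnp
import flax.linen as fnn
from transformer_engine.jax.flax.transformer import MultiHeadAttention

batch, seq_len, d_model, num_heads = 2, 128, 512, 8
mha = MultiHeadAttention(
    head_dim=d_model // num_heads,
    num_attention_heads=num_heads,
    input_layernorm=False,
    dtype=jnp.bfloat16,
)
x = jnp.ones((batch, seq_len, d_model), dtype=jnp.bfloat16)
mask = fnn.attention.make_causal_mask(jnp.ones((batch, seq_len), dtype=jnp.uint8))
rngs = {"params": jax.random.PRNGKey(0), "dropout": jax.random.PRNGKey(1)}
variables = mha.init(rngs, inputs_q=x, inputs_kv=x, mask=mask)
out = mha.apply(variables, inputs_q=x, inputs_kv=x, mask=mask,
                rngs={"dropout": jax.random.PRNGKey(1)})
print(jax.tree_util.tree_map(jnp.shape, out))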

If this is not working for you, could you provide us with a minimal reproducer along with some information about your platform, like the GPU type, CUDA driver version, and CUDA Toolkit version?
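
Something along these lines would cover most of it (a sketch; it assumes jax.print_environment_info() and transformer_engine.__version__ are available in your container):

# Sketch for collecting the requested environment details in one go.
import jax
import transformer_engine

print("Transformer Engine:", transformer_engine.__version__)
print("Devices:", jax.devices())
jax.print_environment_info()  # JAX/jaxlib versions plus GPU/driver info (via nvidia-smi where available)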

denera · Jul 24 '24 00:07

I tried to run python -c 'from transformer_engine.jax.flax.transformer import DotProductAttention, MultiHeadAttention, TransformerLayer' in ghcr.io/nvidia/jax:maxtext-2024-07-17 on an H100 with driver 550.54.14 and CUDA 12.4, and I can't reproduce the error.

zlsh80826 · Jul 24 '24 02:07

Here is a simple demo, nothing special. Could the problem be that the Kubernetes host machine's CUDA driver version (535) is too old?

from dataclasses import dataclass
import os
import time

import flax.linen as fnn
import jax
import jax.nn as jnn
import jax.numpy as jnp
import numpy as np
import optax

from transformer_engine.jax.flax.transformer import DotProductAttention, MultiHeadAttention, TransformerLayer

os.environ["NVIDIA_TF32_OVERRIDE"] = "1"

os.environ["XLA_FLAGS"] = """
    --xla_gpu_enable_triton_gemm=false
    --xla_gpu_graph_level=2
    --xla_gpu_enable_custom_fusions=true
    --xla_gpu_enable_address_computation_fusion=true
"""

@dataclass
class ModelConfig:
    """Configuration for the language models."""

    seq_len: int
    n_layers: int
    d_model: int
    num_heads: int
    ff_dim: int
    dropout: float

    batch_size: int
    learning_rate: float
    max_num_batch: int


class RandomDS:
    def __init__(self, batch_size: int, seq_len: int, use_jax=False):
        self.batch_size = batch_size
        self.seq_len = seq_len
        self.use_jax = use_jax
        if self.use_jax:
            self.rng = jax.random.PRNGKey(1)

    def __iter__(self):
        # Yield batch_size random uint8 "text" batches of shape (batch_size, seq_len)
        if self.use_jax:
            batches = [
                jax.random.bits(self.rng, shape=(self.batch_size, self.seq_len), dtype=jnp.uint8)
                for _ in range(self.batch_size)
            ]
        else:
            batches = [
                np.random.randint(low=0, high=16, size=(self.batch_size, self.seq_len), dtype=np.uint8)
                for _ in range(self.batch_size)
            ]
        return iter(batches)


class TransformerLayer(fnn.Module):  # note: shadows the TE TransformerLayer imported above
    d_model: int
    num_heads: int
    ff_dim: int
    dropout: float

    def setup(self):
        self.mha = MultiHeadAttention(
            head_dim=self.d_model // self.num_heads,
            num_attention_heads=self.num_heads,
            input_layernorm=False,
            dtype=jnp.bfloat16,
        )
        self.layer_norm_1 = fnn.LayerNorm(epsilon=1e-5, dtype=jnp.bfloat16,)
        self.linear_1 = fnn.Dense(
            features=self.ff_dim,
            kernel_init=fnn.initializers.variance_scaling(1/3, "fan_in", "uniform"),
            dtype=jnp.bfloat16,
            param_dtype=jnp.bfloat16
        )
        self.linear_2 = fnn.Dense(
            features=self.d_model,
            kernel_init=fnn.initializers.variance_scaling(1/3, "fan_in", "uniform"),
            dtype=jnp.bfloat16,
            param_dtype=jnp.bfloat16
        )
        self.layer_norm_2 = fnn.LayerNorm(epsilon=1e-5, dtype=jnp.bfloat16)
        self.dropout_layer = fnn.Dropout(self.dropout, deterministic=False)

    def __call__(
        self, x: jnp.ndarray, mask: jnp.ndarray
    ) -> jnp.ndarray:
        # "correct" type annotations for jax arrays are numpy ndarrays
        residual = x
        x = self.layer_norm_1(x)
        # TE's MultiHeadAttention returns a tuple; index 0 is the attention output
        x = residual + self.dropout_layer(self.mha(inputs_q=x, inputs_kv=x, mask=mask)[0])
        x = x + self.dropout_layer(self._ff_block(self.layer_norm_2(x)))
        return x

    def _ff_block(self, x):
        x = jnn.relu(self.linear_1(x))
        x = self.dropout_layer(x)
        x = self.linear_2(x)
        return x


class LM(fnn.Module):
    cfg: ModelConfig

    def setup(self):
        self.byte_embedding = fnn.Embed(
            num_embeddings=256,
            features=self.cfg.d_model,
            embedding_init=jnn.initializers.normal(),
            param_dtype=jnp.bfloat16
        )
        self.positional_encoding = self.param(
            "positional_encoding",
            jnn.initializers.normal(),
            (self.cfg.seq_len, self.cfg.d_model),
            dtype=jnp.bfloat16,
        )
        self.dropout_layer = fnn.Dropout(self.cfg.dropout, deterministic=False)

        self.transformer_layers = [
            TransformerLayer(
                self.cfg.d_model, self.cfg.num_heads, self.cfg.ff_dim, self.cfg.dropout
            )
            for _ in range(self.cfg.n_layers)
        ]
        self.prob_decoder = fnn.Dense(
            features=256,
            kernel_init=fnn.initializers.variance_scaling(1/3, "fan_in", "uniform"),
            dtype=jnp.bfloat16,
            param_dtype=jnp.bfloat16
        )

    def __call__(self, text):
        x = self.byte_embedding(text)
        # Shift x right so causality isn't violated
        x = jnp.concatenate(
                [jnp.zeros([text.shape[0], 1, self.cfg.d_model], dtype=x.dtype), x[:, :-1, :]], axis=1
            )
        x = x + self.positional_encoding
        x = self.dropout_layer(x)

        mask = fnn.attention.make_causal_mask(text)
        for layer in self.transformer_layers:
            x = layer(x, mask=mask)

        return self.prob_decoder(x)

rng = jax.random.PRNGKey(1)
def compute_loss(params, model: LM, text):
    model_out = model.apply(params, text=text, rngs={"dropout": rng})
    one_hots = jnn.one_hot(text, 256)
    loss = optax.softmax_cross_entropy(model_out, one_hots)
    return loss


def setup_model(rng, cfg: ModelConfig):
    model = LM(cfg)

    rng_p, rng_d = jax.random.split(rng)
    params = model.init(
        {"params": rng_p, "dropout": rng_d}, jnp.zeros([cfg.batch_size, cfg.seq_len], dtype=jnp.uint8)
    )
    return params, model


def setup_optimizer(params, cfg: ModelConfig):
    optimizer = optax.adam(cfg.learning_rate)
    opt_state = optimizer.init(params)
    return optimizer, opt_state


def train_loop(
    model: LM, optimizer, opt_state, params, cfg: ModelConfig
):

    def run_train_step(params, opt_state, text_batch):
        loss, grad = jax.value_and_grad(lambda p: compute_loss(p, model, text=text_batch).mean())(params)
        updates, opt_state = optimizer.update(grad, opt_state)
        params = optax.apply_updates(params, updates)
        return params, opt_state, loss

    fast_train_step = jax.jit(run_train_step, donate_argnums=[0, 1])

    losses = []
    t = time.time()
    log_per = 20

    def multi_train_steps(params, opt_state, data):
        # Run several jitted training steps back to back (not used in this demo)
        for single_step_data in data:
            params, opt_state, loss = fast_train_step(params, opt_state, single_step_data)
        return params, opt_state, loss
    
    dataset = list(RandomDS(cfg.batch_size, cfg.seq_len, use_jax=True))


    # Warm up (trigger JIT compilation) for log_per steps before timing
    for idx, batch in enumerate(dataset):
        params, opt_state, loss = fast_train_step(params, opt_state, batch)
        if (idx + 1) % log_per == 0:
            break

    iter_num = 0
    t = time.time()
    for batch in dataset:
        params, opt_state, loss = fast_train_step(params, opt_state, batch)
        losses.append(loss)
        iter_num += 1
    time_elps = time.time() - t
    speed = iter_num * cfg.batch_size / time_elps
    print(
        f"At iter {iter_num}, loss: {np.mean(losses):.4f}, Speed: {int(speed):d}"
    )
    t = time.time()
    losses = []

    return params, opt_state


def setup_all(cfg: ModelConfig, rng=None):
    if rng is None:
        rng = jax.random.PRNGKey(1)
    params, model = setup_model(rng, cfg)
    optimizer, opt_state = setup_optimizer(params, cfg)

    return params, model, optimizer, opt_state


if __name__ == "__main__":
    cfg = ModelConfig(
        seq_len=100,
        n_layers=1,
        d_model=512,
        num_heads=2,
        ff_dim=1024,
        dropout=0.1,
        batch_size=128,
        learning_rate=1e-3,
        max_num_batch=5000,
    )

    params, model, optimizer, opt_state = setup_all(cfg)
    # Parameters and compute dtypes are already bfloat16 (set per-module above), so no extra mixed-precision cast is applied here
    params, opt_state = train_loop(model, optimizer, opt_state, params, cfg)

MoFHeka · Jul 24 '24 07:07