
[BUG] Assertion failed: t.data.dptr != nullptr. Input x is not allocated!

alexdremov opened this issue on Jun 17, 2024 · 3 comments

While running RMS norm, I got the following exception:

/workspace/TransformerEngine/transformer_engine/common/transformer_engine.cpp:39 in function CheckInputTensor: Assertion failed: t.data.dptr != nullptr. Input x is not allocated!
  File "/usr/local/lib/python3.11/dist-packages/transformer_engine/pytorch/module/rmsnorm.py", line 50, in forward
    rmsnorm_out, rsigma = tex.rmsnorm_fwd(inputmat, rmsnorm_weight,
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/autograd/function.py", line 553, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

However, all input tensors are clearly allocated; I verified this with a debugger.

alexdremov · Jun 17, 2024

Can you provide a minimal reproducer? The following runs for me:

import torch
import transformer_engine.pytorch as te

# Options
batch_size = 128
hidden_size = 128
dtype = torch.float32
device = torch.device("cuda")

# TE module
layer = te.RMSNorm(hidden_size, params_dtype=dtype, device=device)

# Synthetic data
x = torch.randn([batch_size, hidden_size], dtype=dtype, device=device, requires_grad=True)

# Forward and backward pass
y = layer(x)
y.sum().backward()

timmoon10 · Jun 17, 2024

Hey! This appeared when I tried to use Float8Tensor. I'll try to write a minimal example, but it could be rather hard.

alexdremov · Jul 2, 2024

Most of our kernels don't handle Float8Tensor directly. Also, our RMSNorm kernel doesn't support FP8 input at the moment, just FP8 output. As a quick fix, you could manually cast your Float8Tensor to higher precision with:

x = Float8Tensor(...)          # FP8 tensor (construction elided)
y = layer(x.from_float8())     # dequantize to higher precision before RMSNorm
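
For reference, here is a fuller sketch of that workaround, assuming the input starts out as a Float8Tensor. The Float8Tensor.to_float8 call and its import path are assumptions and may differ between Transformer Engine versions; the key step is the from_float8() dequantization before calling the module.

import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch.float8_tensor import Float8Tensor  # import path may vary by TE version

# Options
batch_size = 128
hidden_size = 128
dtype = torch.float32
device = torch.device("cuda")

# TE module
layer = te.RMSNorm(hidden_size, params_dtype=dtype, device=device)

# Synthetic FP8 data (Float8Tensor.to_float8 is an assumption; build the
# Float8Tensor however your existing code does)
x_hp = torch.randn([batch_size, hidden_size], dtype=dtype, device=device)
x_fp8 = Float8Tensor.to_float8(x_hp)

# Dequantize to higher precision before the RMSNorm kernel sees the input
y = layer(x_fp8.from_float8())
y.sum().backward()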

timmoon10 · Jul 3, 2024