
Significant performance drop in training

Open Fan-Yixuan opened this issue 7 months ago • 4 comments

❓ Questions and Help

I'm new to xformers. I need to use Transformer Encoders to train on a dataset with a very large variation in sample lengths. My original code was:

tokens = [tokens[mapper == cnt, :] for cnt in range(bs)]
tokens = torch.nn.utils.rnn.pad_sequence(tokens, batch_first=True, padding_value=1e3)
masks = (tokens[:, :, 0] == 1e3)

for _, layer in enumerate(self.inter_layers):
    tokens = layer(tokens, src_key_padding_mask=masks)

where mapper records which sample each token comes from, and layer is a PyTorch TransformerEncoderLayer. I changed it to xformers:

tokens = [tokens[mapper == cnt, :].unsqueeze(0) for cnt in range(bs)]
sample_len = [token.shape[1] for token in tokens]
attn_bias, tokens = fmha.BlockDiagonalMask.from_tensor_list(tokens)

for _, layer in enumerate(self.inter_layers):
    tokens = layer(tokens, attn_bias=attn_bias)

tokens = torch.split(tokens.squeeze(0), sample_len, dim=0)
tokens = torch.nn.utils.rnn.pad_sequence(tokens, batch_first=True, padding_value=1e3)
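
For clarity, here is a minimal, self-contained sketch of what this packing step does (sizes are made up for illustration):

# a minimal sketch of packing variable-length samples with BlockDiagonalMask
import torch
from xformers.ops import fmha

samples = [torch.randn(1, n, 8) for n in (3, 5, 2)]            # per-sample tensors [1, n_i, d]
attn_bias, packed = fmha.BlockDiagonalMask.from_tensor_list(samples)
print(packed.shape)                                            # torch.Size([1, 10, 8])

# tokens only attend within their own block; materialize() shows the dense equivalent
print(attn_bias.materialize((10, 10)).isfinite().long())

# the mask can also split the packed tensor back into per-sample tensors
print([t.shape for t in attn_bias.split(packed)])              # [1, 3, 8], [1, 5, 8], [1, 2, 8]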

and this is what we use for the attention layer:

def forward(self, x: Tensor, attn_bias: Optional[Tensor] = None) -> Tensor:
    _, length, _ = x.shape
    # fused in-projection: [1, length, 3 * embed_dim]
    x = torch._C._nn.linear(x, self.in_proj_weight, self.in_proj_bias)
    # split into q, k, v, each [1, length, embed_dim]
    x = x.unflatten(-1, (3, self.embed_dim)).unsqueeze(0).transpose(0, -2).squeeze(-2)
    q, k, v = x[0], x[1], x[2]
    if not self.training:
        dropout_p = 0.0
    else:
        dropout_p = self.dropout
    # [1, length, num_heads, head_dim], the layout expected by memory_efficient_attention
    q = q.reshape(1, length, self.num_heads, self.head_dim)
    k = k.reshape(1, length, self.num_heads, self.head_dim)
    v = v.reshape(1, length, self.num_heads, self.head_dim)

    x = xops.memory_efficient_attention(q, k, v, attn_bias=attn_bias, p=dropout_p)
    x = x.reshape(1, length, self.embed_dim)
    x = self.out_proj(x)
    return x
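
As a reference point, the xformers call can be sanity-checked against PyTorch's scaled_dot_product_attention using the materialized block-diagonal mask, with dropout disabled so both paths are deterministic (a sketch with made-up shapes, not the actual model code):

# compare memory_efficient_attention against a dense-mask SDPA reference
import torch
import torch.nn.functional as F
import xformers.ops as xops
from xformers.ops import fmha

torch.manual_seed(0)
heads, head_dim = 2, 16
lens = (3, 5, 2)
total = sum(lens)

q = torch.randn(1, total, heads, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)
attn_bias = fmha.BlockDiagonalMask.from_seqlens(list(lens))

out_xf = xops.memory_efficient_attention(q, k, v, attn_bias=attn_bias, p=0.0)

# reference: dense block-diagonal additive mask + SDPA ([1, heads, total, head_dim] layout)
dense = attn_bias.materialize((total, total), dtype=q.dtype, device=q.device)
out_ref = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), attn_mask=dense
).transpose(1, 2)

print((out_xf - out_ref).abs().max())   # should be small (fp16 tolerance)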

Once training starts, the original training-loss curve and the xformers-version curve diverge as shown below (the x axis is training steps; yellow is the original, green is the xformers version). Nothing else is changed. [image: training loss curves]

My env: Python 3.9, PyTorch 2.1.1, CUDA 11.8, latest xformers. The same phenomenon is observed in other environments and on different GPU cards, with both single-card and DDP training.

Please help me with this! Thanks a lot in advance! Padding is really wasting my training time!!

Fan-Yixuan avatar Dec 02 '23 04:12 Fan-Yixuan

Hi, which version of xformers are you using, and which GPU? Can you report the output of python -m xformers.info? We had a bug in earlier versions of xFormers where enabling dropout could cause bad numerics, so I would try with dropout disabled just in case.
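
For the quick test, the attention call from your forward could be pinned to p=0.0 regardless of self.training, e.g.:

    # temporary, for debugging only: rule out the dropout path entirely
    x = xops.memory_efficient_attention(q, k, v, attn_bias=attn_bias, p=0.0)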

Also worth noting, you can replace the following:

    x = x.unflatten(-1, (3, self.embed_dim)).unsqueeze(0).transpose(0, -2).squeeze(-2)
    q, k, v = x[0], x[1], x[2]
    q = q.reshape(1, length, self.num_heads, self.head_dim)
    k = k.reshape(1, length, self.num_heads, self.head_dim)
    v = v.reshape(1, length, self.num_heads, self.head_dim)

With something like this, which is going to be a bit more efficient for the backward pass:

    x = x.reshape(1, length, 3, self.num_heads, self.head_dim)
    q, k, v = xops.unbind(x, 2)
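
A quick way to convince yourself the two forms are equivalent (made-up sizes, just for the check):

    # made-up sizes; verify both forms produce identical q/k/v
    import torch
    import xformers.ops as xops

    length, num_heads, head_dim = 7, 4, 16
    embed_dim = num_heads * head_dim
    x = torch.randn(1, length, 3 * embed_dim)

    a = x.unflatten(-1, (3, embed_dim)).unsqueeze(0).transpose(0, -2).squeeze(-2)
    q0, k0, v0 = (t.reshape(1, length, num_heads, head_dim) for t in (a[0], a[1], a[2]))

    b = x.reshape(1, length, 3, num_heads, head_dim)
    q1, k1, v1 = xops.unbind(b, 2)

    assert torch.equal(q0, q1) and torch.equal(k0, k1) and torch.equal(v0, v1)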

danthe3rd avatar Dec 02 '23 09:12 danthe3rd

Hi @danthe3rd, thanks for the comment. My env:

xFormers 0.0.23.dev703
memory_efficient_attention.cutlassF:               available
memory_efficient_attention.cutlassB:               available
memory_efficient_attention.decoderF:               available
memory_efficient_attention.flshattF@[…]:           available
memory_efficient_attention.flshattB@[…]:           available
memory_efficient_attention.smallkF:                available
memory_efficient_attention.smallkB:                available
memory_efficient_attention.tritonflashattF:        unavailable
memory_efficient_attention.tritonflashattB:        unavailable
memory_efficient_attention.triton_splitKF:         available
indexing.scaled_index_addF:                        available
indexing.scaled_index_addB:                        available
indexing.index_select:                             available
swiglu.dual_gemm_silu:                             available
swiglu.gemm_fused_operand_sum:                     available
swiglu.fused.p.cpp:                                available
is_triton_available:                               True
pytorch.version:                                   2.1.1
pytorch.cuda:                                      available
gpu.compute_capability:                            7.5
gpu.name:                                          NVIDIA GeForce RTX 2080 Ti
build.info:                                        available
build.cuda_version:                                1108
build.python_version:                              3.9.18
build.torch_version:                               2.1.1
build.env.TORCH_CUDA_ARCH_LIST:                    5.0+PTX 6.0 6.1 7.0 7.5 8.0+PTX 9.0
build.env.XFORMERS_BUILD_TYPE:                     Release
build.env.XFORMERS_ENABLE_DEBUG_ASSERTIONS:        None
build.env.NVCC_FLAGS:                              None
build.env.XFORMERS_PACKAGE_FROM:                   conda-main
build.nvcc_version:                                11.8.89
source.privacy:                                    open source

and disabling dropout really works! Thanks a lot. How can I fix this so that I can use dropout again?

Fan-Yixuan avatar Dec 02 '23 09:12 Fan-Yixuan

Hi @danthe3rd, disabling dropout really worked, but how can I use dropout again? Thanks a lot!

Fan-Yixuan avatar Dec 05 '23 07:12 Fan-Yixuan

Hey! I am seeing a very similar problem where the loss starts going up with xformers whenever dropout > 0. Everything is good when dropout == 0.0. Additionally, things are also good when forcing MemoryEfficientAttentionFlashAttentionOp dispatch, even when dropout > 0.
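
For reference, forcing that dispatch looks roughly like this (q, k, v, attn_bias, dropout_p as in the forward posted above):

import xformers.ops as xops

# force the FlashAttention forward/backward kernels instead of the Cutlass ones
x = xops.memory_efficient_attention(
    q, k, v, attn_bias=attn_bias, p=dropout_p,
    op=xops.MemoryEfficientAttentionFlashAttentionOp,
)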

So I guess this is a bug in the Cutlass kernel?

vinamarora8 avatar Mar 05 '24 02:03 vinamarora8