Conv1D doesn't output token-wise results consistently.
System Info
Hi, I recently observed with Hugging Face's GPT-2 that

(1) the output logits (y1, ..., yN) from a sequence of N tokens (x1, ..., xN), and
(2) the output logits (z1, ..., zM) from only the first M tokens of that same sequence (x1, ..., xM)

do not match exactly (y1 != z1, ..., yM != zM) during inference (i.e. with the causal mask applied). I tried to figure out why this happens and realized that it is related to how `Conv1D`'s `forward` is implemented: https://github.com/huggingface/transformers/blob/main/src/transformers/pytorch_utils.py#L100-L104

The thing is, it internally uses `addmm` (i.e. b + [x1, ..., xN] * W), which does not produce row-wise outputs consistent with b + [x1, ..., xM] * W, even though the two should be identical in exact arithmetic.

I put together an example and a proposed fix below:
```python
import torch

torch.manual_seed(0)
torch.cuda.manual_seed(0)

input_dim = 786
feature_dim = 2304

x1 = torch.randn((1, 38, input_dim), device='cuda')  # (B, N, Fi) where N is the number of tokens in the sequence.
x2 = x1[:, :10]  # (B, M, Fi) where M=10 gathers the first M tokens of the sequence.
b = torch.randn((feature_dim,), device='cuda')  # bias
w = torch.randn((input_dim, feature_dim), device='cuda')  # weight

def addmm(x, b, w):
    x = x.view(-1, x.size(-1))
    return torch.addmm(b, x, w)

def addbmm(x, b, w):  # (B, N, Fi), (Fh,), (Fi, Fh)
    batch_size, seq_len = x.size(0), x.size(1)  # B, N
    x = x.view(batch_size * seq_len, 1, x.size(-1))  # (B * N, 1, Fi)
    # (1, Fi, Fh).expand((B * N, Fi, Fh)) --> (B * N, Fi, Fh)
    w = w.unsqueeze(0).expand((batch_size * seq_len,) + w.size())
    return torch.matmul(x, w).add(b).view(batch_size * seq_len, -1)  # (B * N, Fh)

print("result (addmm):\n", addmm(x1, b, w)[:10] == addmm(x2, b, w))
print("result (addbmm):\n", addbmm(x1, b, w)[:10] == addbmm(x2, b, w))
```
The first function, `addmm`, is the one used by Hugging Face's `Conv1D`, and the second, `addbmm`, is what I implemented to avoid the numerical error. Ideally every printed value should be `True`, but that is not the case for `addmm`:
```
result (addmm):
tensor([[False, False, False, ..., False, True, True],
[ True, True, False, ..., False, False, True],
[False, False, False, ..., False, False, False],
...,
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, True, ..., False, False, False]], device='cuda:0')
result (addbmm):
tensor([[True, True, True, ..., True, True, True],
[True, True, True, ..., True, True, True],
[True, True, True, ..., True, True, True],
...,
[True, True, True, ..., True, True, True],
[True, True, True, ..., True, True, True],
        [True, True, True, ..., True, True, True]], device='cuda:0')
```
Intuitively, I force a batched matmul by explicitly creating a batch dimension for the tensors, so each row is computed independently, which yields the expected (consistent) results.
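The shape dependence that the batched version avoids can be seen even with a plain matmul. A minimal sketch, continuing the snippet above (the exact discrepancy depends on the GPU and cuBLAS version, so treat this only as an illustration):

```python
# Same rows, different matrix heights: cuBLAS may pick a different kernel/tiling
# per shape, so results can differ in the last bits (assumes x1, w from above).
full = torch.matmul(x1.view(-1, x1.size(-1)), w)  # all 38 rows at once
part = torch.matmul(x1[0, :10], w)                # only the first 10 rows
print((full[:10] - part).abs().max())             # often a small non-zero value on GPU
```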
Thus, I think the `forward()` of `Conv1D` (https://github.com/huggingface/transformers/blob/main/src/transformers/pytorch_utils.py#L100-L104) should be updated to:
```python
def forward(self, x):
    size_out = x.size()[:-1] + (self.nf,)
    x = x.view(x.size()[:-1].numel(), 1, x.size(-1))  # (B * N, 1, Fi)
    weight = self.weight.unsqueeze(0).expand((x.size(0),) + self.weight.size())  # (B * N, Fi, Fh)
    x = torch.matmul(x, weight).add(self.bias)
    x = x.view(size_out)
    return x
```
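For a quick sanity check, the proposed forward can be compared against the current `Conv1D` with shared weights. A sketch (`PatchedConv1D` is just an illustrative name, and 768/2304 are GPT-2's c_attn sizes; bitwise prefix equality is expected per the `addbmm` experiment above, though not formally guaranteed):

```python
import torch
from transformers.pytorch_utils import Conv1D

class PatchedConv1D(Conv1D):
    """Conv1D with the batched-matmul forward proposed above."""
    def forward(self, x):
        size_out = x.size()[:-1] + (self.nf,)
        x = x.view(x.size()[:-1].numel(), 1, x.size(-1))
        weight = self.weight.unsqueeze(0).expand((x.size(0),) + self.weight.size())
        return torch.matmul(x, weight).add(self.bias).view(size_out)

nf, nx = 2304, 768  # GPT-2 c_attn dimensions
ref, patched = Conv1D(nf, nx).cuda(), PatchedConv1D(nf, nx).cuda()
patched.load_state_dict(ref.state_dict())  # share the same weights and bias

x = torch.randn((1, 38, nx), device='cuda')
print(torch.allclose(ref(x), patched(x), atol=1e-4))        # stays numerically close to the current layer
print(torch.equal(patched(x)[:, :10], patched(x[:, :10])))  # prefix-consistent
```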
Who can help?
@ArthurZucker @younesbelkada
Information
- [X] The official example scripts
- [X] My own modified scripts
Tasks
- [X] An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)
- [ ] My own task or dataset (give details below)
Reproduction
I provided an example above.
Expected behavior
After fixing the bug, the logits for the earlier tokens should no longer be affected by the future tokens in the sequence.
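This expectation can be checked end-to-end with a sketch like the one below (the `gpt2` checkpoint, the example sentence, and the prefix length of 5 are arbitrary illustrative choices):

```python
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tok = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2").cuda().eval()

ids = tok("The quick brown fox jumps over the lazy dog", return_tensors="pt").input_ids.cuda()
with torch.no_grad():
    full_logits = model(ids).logits           # logits for the full sequence
    prefix_logits = model(ids[:, :5]).logits  # logits for the first 5 tokens only

# With a causal mask these should be identical in exact arithmetic; with the
# current Conv1D they typically differ by a small amount.
print((full_logits[:, :5] - prefix_logits).abs().max())
```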
Hey! Wow, that's interesting. Two parts to the answer:

- Very cool. We can use `torch.testing.assert_allclose` to check the maximum differences, and indeed I get the following output:
```
In [73]: torch.testing.assert_allclose(addmm(x1, b, w)[:10], addbmm(x2, b, w))
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
Cell In[73], line 1
----> 1 torch.testing.assert_allclose(addmm(x1, b, w)[:10], addbmm(x2, b, w))
File /opt/conda/envs/py39/lib/python3.9/site-packages/torch/testing/_deprecated.py:32, in warn_deprecated.<locals>.outer_wrapper.<locals>.inner_wrapper(*args, **kwargs)
30 @functools.wraps(fn)
31 def inner_wrapper(*args: Any, **kwargs: Any) -> Any:
---> 32 return_value = fn(*args, **kwargs)
33 tail = instructions(name, args, kwargs, return_value) if callable(instructions) else instructions
34 msg = (head + tail).strip()
File /opt/conda/envs/py39/lib/python3.9/site-packages/torch/testing/_deprecated.py:80, in assert_allclose(actual, expected, rtol, atol, equal_nan, msg)
77 if rtol is None and atol is None:
78 rtol, atol = _get_default_rtol_and_atol(actual, expected)
---> 80 torch.testing.assert_close(
81 actual,
82 expected,
83 rtol=rtol,
84 atol=atol,
85 equal_nan=equal_nan,
86 check_device=True,
87 check_dtype=False,
88 check_stride=False,
89 msg=msg or None,
90 )
[... skipping hidden 1 frame]
File /opt/conda/envs/py39/lib/python3.9/site-packages/torch/testing/_comparison.py:1093, in assert_equal(actual, expected, pair_types, sequence_types, mapping_types, msg, **options)
1090 return
1092 # TODO: compose all metas into one AssertionError
-> 1093 raise error_metas[0].to_error(msg)
AssertionError: Tensor-likes are not close!
Mismatched elements: 9 / 23040 (0.0%)
Greatest absolute difference: 4.00543212890625e-05 at index (2, 952) (up to 1e-05 allowed)
Greatest relative difference: 0.0080592538321523 at index (8, 1875) (up to 0.0001 allowed)
```
So the outputs only match up to roughly 1e-2 in relative terms, which is not that great. Your fix is indeed better in terms of precision, as `torch.testing.assert_allclose(addbmm(x1, b, w)[:10], addbmm(x2, b, w))` passes.
- My concern is: is this faster or slower in terms of computation? Is `torch.addmm` more optimised (and does it require fewer calls to different views), and thus faster? (A quick way to measure this is sketched below.) Would the fix break ONNX tracing? And most importantly, is this backward compatible? If it is indeed a fix, meaning it brings our logits closer to the original logits, we might consider it a potentially good change, but the other concerns still stand! The problem is that GPT2 is an old model, and it's very hard to change it (especially something as fundamental as the Conv).
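A rough way to answer the speed question (a sketch with assumed GPT-2-like shapes, not timings from this thread):

```python
import torch
import torch.utils.benchmark as benchmark

# Assumed GPT-2-like shapes for the c_attn projection (batch 8, 512 tokens).
x = torch.randn((8, 512, 768), device='cuda')
w = torch.randn((768, 2304), device='cuda')
b = torch.randn((2304,), device='cuda')

def conv1d_addmm(x, b, w):
    # current Conv1D path
    return torch.addmm(b, x.view(-1, x.size(-1)), w)

def conv1d_batched(x, b, w):
    # proposed batched-matmul path
    rows = x.view(-1, 1, x.size(-1))
    w_exp = w.unsqueeze(0).expand((rows.size(0),) + w.size())
    return torch.matmul(rows, w_exp).add(b).view(rows.size(0), -1)

for name, fn in [("addmm", conv1d_addmm), ("batched matmul", conv1d_batched)]:
    timer = benchmark.Timer(stmt="fn(x, b, w)", globals={"fn": fn, "x": x, "b": b, "w": w})
    print(name, timer.timeit(100))
```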
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.