Correct way to deal with padded inputs?
Hello everyone, I'm looking to use the standalone S4D layer here as a drop-in replacement for the attention mechanism in a transformer model.
I'm wondering what the best way to deal with padded inputs is. Are there any examples of handling padded inputs within the codebase that I can take a look at?
We just 0-pad the inputs. Actually, it might be better if you mask out the inputs inside every S4 layer. The following snippet is from the latest model; you can copy it to replace the beginning of the `forward` function of the S4 layer.
```python
def forward(self, u, state=None, rate=1.0, lengths=None, **kwargs): # absorbs return_output and transformer src mask
    """
    u: (B H L) if self.transposed else (B L H)
    state: (H N) never needed unless you know what you're doing

    Returns: same shape as u
    """
    if not self.transposed: u = u.transpose(-1, -2)
    L = u.size(-1)

    # Mask out padding tokens
    # TODO handle option for mask - instead of lengths, which assumes suffix padding
    if isinstance(lengths, int):
        if lengths != L:
            lengths = torch.tensor(lengths, dtype=torch.long, device=u.device)
        else:
            lengths = None
    if lengths is not None:
        assert isinstance(lengths, torch.Tensor) and lengths.ndim == 1 and lengths.size(0) in [1, u.size(0)]
        mask = torch.where(torch.arange(L, device=lengths.device) < lengths[:, None, None], 1., 0.)
        u = u * mask
```
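For reference, here is a minimal sketch (not from the repo) of how a suffix-padded batch and its `lengths` tensor might be fed to a layer patched with the snippet above. The import path and the `S4` constructor arguments are assumptions based on the v3 branch and may differ in your version:

```python
import torch
from src.models.s4.s4 import S4   # path as of the v3 branch; adjust for your checkout

B, H, L = 4, 256, 128                       # batch, model dim, padded length
layer = S4(d_model=H, transposed=True)      # transposed=True -> inputs are (B, H, L)

x = torch.randn(B, H, L)
lengths = torch.tensor([128, 100, 73, 50])  # true length of each sequence, shape (B,)
x = x * (torch.arange(L) < lengths[:, None, None])  # suffix 0-padding, same rule as the mask above

out = layer(x, lengths=lengths)             # the snippet zeroes positions >= lengths[b]
y = out[0] if isinstance(out, tuple) else out  # some versions return (output, state)
```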
I'll leave the issue open and point to the reference when the code in this repo is updated.
@albertfgu Hello, I have some questions: I get worse output when I 0-pad the inputs as you described, on an NER task. Also, if I set bidirectional to true, does the convolution code need to be revised correspondingly to get correct output? Thank you!

> I get worse output when I 0-pad the inputs as you described, on an NER task.
Worse performance compared to what? What's the alternative to 0-padding?
> Also, if I set bidirectional to true, does the convolution code need to be revised correspondingly to get correct output?
You shouldn't need to change the code. You may want to set the length masks as specified above. That logic has also been pushed to the v3 branch which will get merged into master soon: https://github.com/HazyResearch/state-spaces/blob/v3/src/models/s4/s4.py#L1477
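As a side note, here is a rough sketch of why `lengths` matters more in the bidirectional case (assuming a `bidirectional` constructor flag and the `lengths` argument from the patched forward; both are assumptions about your version): with a causal layer the padded suffix cannot influence earlier positions, but with a bidirectional one it can, so masking actually changes the valid outputs.

```python
import torch
from src.models.s4.s4 import S4   # adjust path for your checkout

layer = S4(d_model=64, bidirectional=True, transposed=True).eval()
x = torch.randn(2, 64, 32)
x[..., 20:] = 1.0                     # nonzero "padding" makes the backward leak visible
lengths = torch.full((2,), 20)

with torch.no_grad():
    out_a = layer(x)                  # no masking
    out_b = layer(x, lengths=lengths) # padded suffix zeroed inside the layer
y_a = out_a[0] if isinstance(out_a, tuple) else out_a
y_b = out_b[0] if isinstance(out_b, tuple) else out_b

# Bidirectional mixing lets the suffix affect valid positions, so these should differ
print(torch.allclose(y_a[..., :20], y_b[..., :20]))   # expected: False
```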
I just masked the inputs as you suggested in the code above, and the output was worse than without masking. Theoretically, for a sequence labelling task, the final outputs with and without masking should be identical if S4 is unidirectional. This is strange.
Right, they should be identical for a unidirectional model. Are you sure the lengths tensor is passed in correctly?
I am sure the lengths tensor is correct. I am wondering whether you have any plans to run NLU tasks such as NER with S4. I have tried pretraining a BERT-style model based on bidirectional S4, and its accuracy on an intent classification task is as high as BERT's. The experiments mentioned above were run without masking, even though they actually need it. Now I face two problems: 1) as mentioned, the output gets worse if I mask the inputs to their actual lengths; and 2) inference time on short texts, e.g. sequence length 32, is nearly 2x longer than BERT's. Do you have any ideas about these two problems?
I have no plans at the moment, but I know there are other groups doing BERT-style pretraining with S4.
- I'm unable to help here without knowing more details. It sounds like there's just a bug somewhere. Personally, I would start by setting up a pipeline where the mask definitely should not affect the output, and if it does, trace the output through a 1-layer model and see where it goes wrong (a minimal sketch of such a check follows below).
- The advantages of S4 will naturally be diminished on shorter sequences. There's a lot of room for more efficient hardware-aware implementations, which will hopefully appear as interest grows.
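Regarding the first point, here is a minimal sketch of such a sanity check for a 1-layer causal model (the import path, the `transposed`/`lengths` arguments, and the tuple return are assumptions about the patched layer shown above):

```python
import torch
from src.models.s4.s4 import S4   # adjust path for your checkout

torch.manual_seed(0)
B, H, L, true_len = 2, 64, 32, 20
layer = S4(d_model=H, transposed=True).eval()   # causal/unidirectional by default

def run(x, lengths=None):
    with torch.no_grad():
        out = layer(x, lengths=lengths)
    return out[0] if isinstance(out, tuple) else out   # some versions return (y, state)

x = torch.randn(B, H, L)
x[..., true_len:] = 0.0                  # suffix 0-padding
lengths = torch.full((B,), true_len)

# (1) With already-zeroed padding the mask is a no-op, so outputs should match exactly
print(torch.allclose(run(x), run(x, lengths)))

# (2) For a causal layer, garbage in the padded region should not change valid positions
x_bad = x.clone()
x_bad[..., true_len:] = torch.randn(B, H, L - true_len)
print(torch.allclose(run(x)[..., :true_len], run(x_bad)[..., :true_len], atol=1e-5))
```

If either check fails with a single layer, the bug is likely in how `lengths` reaches the layer rather than in the mask itself.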