HotFlip invalid_replacement_indices and different shapes bug
**Describe the bug**
The first "bug" is in the `Hotflip` class:

```python
def __init__(self, ...):
    ...
    self.invalid_replacement_indices: List[int] = []
    for i in self.vocab._index_to_token[self.namespace]:
        if not self.vocab._index_to_token[self.namespace][i].isalnum():
            self.invalid_replacement_indices.append(i)
```

`isalnum()` is probably not the best way to find `invalid_replacement_indices`, because in many cases tokens contain the `_` symbol (for example, in BPE encoding).
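A more permissive check could strip common subword markers before testing the remaining characters. The sketch below is only a suggestion, not allennlp code: the `is_valid_replacement` helper is hypothetical, and `_` / `##` are the BPE and WordPiece markers assumed here.

```python
# Hypothetical helper, not allennlp code: treat common subword markers
# ("_" from BPE, "##" from WordPiece) as valid, but still reject
# punctuation and special tokens like <START>.
def is_valid_replacement(token: str) -> bool:
    stripped = token.replace("_", "").replace("##", "")
    return len(stripped) > 0 and stripped.isalnum()

vocab_tokens = ["hello", "_world", "##ing", "<START>", "!"]
invalid_replacement_indices = [
    i for i, token in enumerate(vocab_tokens)
    if not is_valid_replacement(token)
]
print(invalid_replacement_indices)  # [3, 4]
```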
The second bug is related to using token indexers with `start_tokens` and `end_tokens`. In my case, I use the following dataset reader:

```json
"dataset_reader": {
    "type": "text_classification_json",
    "token_indexers": {
        "tokens": {
            "type": "single_id",
            "start_tokens": ["<START>"],
            "end_tokens": ["<END>"],
            "token_min_padding_length": 5
        }
    },
    "tokenizer": {
        "type": "just_spaces"
    }
}
```
In this case, `grad` and `text_field.tokens` will have different shapes, and `text_field.tokens[index_of_token_to_flip] = new_token` will fail when `index_of_token_to_flip >= len(text_field.tokens)`:
```python
def attack_from_json(self, ...):
    ...
    for instance in original_instances:
        text_field = instance[input_field_to_attack]
        ...
        grad = grads[grad_input_field][0]
        ...
        while True:
            text_field.tokens[index_of_token_to_flip] = new_token
```
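To make the mismatch concrete, here is a plain-Python illustration (no allennlp required; the token values are made up): the indexer-added `<START>`/`<END>` positions give the gradient five entries, while the `TextField` still holds only the three original tokens.

```python
# Plain-Python illustration of the shape mismatch (values are made up).
field_tokens = ["the", "movie", "rocks"]          # text_field.tokens
indexed = ["<START>"] + field_tokens + ["<END>"]  # what the model actually sees
grad = [0.1] * len(indexed)                       # one gradient entry per indexed position

index_of_token_to_flip = len(grad) - 1            # 4, chosen from the gradient
# text_field.tokens[index_of_token_to_flip] would raise IndexError,
# since len(field_tokens) == 3.
```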
And the third bug is related to using the `token_min_padding_length > 0` parameter. It also causes a shape mismatch, as shown above.
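One defensive workaround for the shape mismatch in bugs two and three (my own sketch, not a fix in the library; `safe_flip` is a hypothetical helper) is to skip gradient positions that fall outside the `TextField`'s token list before attempting the flip:

```python
# Sketch of a guard against indexer-added or padding positions
# (hypothetical helper, not part of allennlp).
def safe_flip(tokens, index_of_token_to_flip, new_token):
    """Flip only when the index points at a real TextField token."""
    if index_of_token_to_flip >= len(tokens):
        return False  # position belongs to an indexer-added or padding token
    tokens[index_of_token_to_flip] = new_token
    return True

tokens = ["good", "film"]
flipped_out_of_range = safe_flip(tokens, 2, "bad")  # False: index 2 is out of range
flipped_in_range = safe_flip(tokens, 1, "bad")      # True: "film" becomes "bad"
```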
**To Reproduce**

To reproduce the error, train any classifier with the `dataset_reader` provided above and call the `HotFlip.attack_from_json()` method.
You've brought up three issues:
1. For this one, you're right that the current heuristic is not ideal. Really, though, there are a few other constraints to get this exactly right: you don't want a continuation wordpiece as your first token, for instance. A PR to improve the handling of these cases would be welcome.
2 and 3 both break the `TextFieldEmbedder` API. If you have an indexer that adds tokens, your embedder is supposed to remove them. The Hotflip API is written against the `TextField` APIs, so things that break the `TextField` APIs are going to break the Hotflip API also, and I'm not sure there's much that we can do about it, except recommend that you copy the code and customize it for your specific application. If you don't have to worry about covering arbitrary `TextField` inputs, you can be a lot more flexible in your logic.
If I really wanted to fix this, I'd probably add some API method in `TokenIndexer` that tells which tokens were added. Then Hotflip can remove those tokens' gradients before doing its thing. I don't see another way around this issue. I'm also not sure whether it's worth it to add that API method.
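For concreteness, such an API could look roughly like the sketch below; every name here (`SingleIdIndexerSketch`, `num_added_tokens`) is invented for illustration and does not exist in allennlp:

```python
# Invented sketch of the suggested API: the indexer reports how many
# tokens it adds at each end, and HotFlip trims the gradient to match
# the TextField's tokens.
class SingleIdIndexerSketch:
    def __init__(self, start_tokens, end_tokens):
        self.start_tokens = list(start_tokens)
        self.end_tokens = list(end_tokens)

    def num_added_tokens(self):
        # Hypothetical method: (prepended, appended) token counts.
        return len(self.start_tokens), len(self.end_tokens)

indexer = SingleIdIndexerSketch(["<START>"], ["<END>"])
n_start, n_end = indexer.num_added_tokens()
grad = [0.5, 0.1, 0.2, 0.3, 0.5]           # includes the added-token positions
trimmed = grad[n_start:len(grad) - n_end]  # gradients for real tokens only
```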