Position embedding in the DETR model
System Info
According to the argument definitions of `DetrDecoderLayer.forward()` specified here:
https://github.com/huggingface/transformers/blob/bd469c40659ce76c81f69c7726759d249b4aef49/src/transformers/models/detr/modeling_detr.py#L723-L728
the `position_embeddings` argument of the cross-attention should be assigned the `position_embeddings` variable instead of `query_position_embeddings`:
https://github.com/huggingface/transformers/blob/bd469c40659ce76c81f69c7726759d249b4aef49/src/transformers/models/detr/modeling_detr.py#L757-L764
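For clarity, here is a condensed paraphrase of the two linked snippets (heavily abridged; only the lines relevant to the question are kept, and the inline comments summarizing the docstring are mine):

```python
from torch import nn

# Condensed paraphrase of modeling_detr.py at the linked commit (abridged).
class DetrDecoderLayer(nn.Module):
    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_embeddings=None,        # docstring: added to queries/keys in the cross-attention layer
        query_position_embeddings=None,  # docstring: added to queries/keys in the self-attention layer
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=False,
    ):
        ...
        # Cross-attention block: `query_position_embeddings` is passed as
        # `position_embeddings`, even though the docstring says the
        # `position_embeddings` variable is the one used in cross-attention.
        hidden_states, cross_attn_weights = self.encoder_attn(
            hidden_states=hidden_states,
            position_embeddings=query_position_embeddings,
            key_value_states=encoder_hidden_states,
            attention_mask=encoder_attention_mask,
            key_value_position_embeddings=position_embeddings,
            output_attentions=output_attentions,
        )
        ...
```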
Is the error in the argument definition or in the code itself?
Thank you!
Who can help?
@NielsRogge
Information
- [X] The official example scripts
- [ ] My own modified scripts
Tasks
- [ ] An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)
- [X] My own task or dataset (give details below)
Reproduction
It is from the transformers code.
Arguments definition: https://github.com/huggingface/transformers/blob/bd469c40659ce76c81f69c7726759d249b4aef49/src/transformers/models/detr/modeling_detr.py#L723-L728
Cross-attention code: https://github.com/huggingface/transformers/blob/bd469c40659ce76c81f69c7726759d249b4aef49/src/transformers/models/detr/modeling_detr.py#L757-L764
Expected behavior
Either:
- The `position_embeddings` argument for the cross-attention should be assigned the `position_embeddings` variable instead of `query_position_embeddings`, or
- The documentation of the argument should be updated to match the code.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Hey @NielsRogge, could you explain how to solve this issue? You just put the Good first issue label on it, but it's not clear what a contributor would have to do to fix it.
Hi @NielsRogge, I would like to take on this. As Sylvain suggested, could you offer some context on how to go with this? Thanks :)
Yeah, I marked this as a good first issue so someone could take a deeper dive into DETR's position embeddings.
Reading the paper could definitely be helpful for that. But the implementation is correct; it's probably the internal variables/docstrings that need to be updated. From the paper:
Since the decoder is also permutation-invariant, the N input embeddings must be different to produce different results. These input embeddings are learnt positional encodings that we refer to as object queries, and similarly to the encoder, we add them to the input of each attention layer.
So the `position_embeddings` argument of the cross-attention layer corresponds exactly to these input embeddings, often also called "content embeddings" or "object queries".
Then a bit later on in the paper they state:
There are two kinds of positional encodings in our model: spatial positional encodings and output positional encodings (object queries).
So the `key_value_position_embeddings` argument of the cross-attention layer refers to these spatial positional encodings. These are added to the keys and values in the cross-attention operation.
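In other words, the cross-attention roughly does the following (a minimal sketch of the idea just described, not the actual `DetrAttention` code; projections, multiple heads, and masking are omitted):

```python
import torch

def detr_cross_attention_sketch(decoder_states, object_queries, encoder_states, spatial_position_embeddings):
    # Object queries (the learnt output positional encodings) are added to
    # the decoder-side queries.
    q = decoder_states + object_queries
    # Spatial positional encodings are added on the encoder side, as
    # described above.
    k = encoder_states + spatial_position_embeddings
    v = encoder_states + spatial_position_embeddings
    # Plain scaled dot-product attention.
    attn = torch.softmax(q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5, dim=-1)
    return attn @ v
```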
So, for clarity, we could update the `position_embeddings` argument to `object_queries`, and the `key_value_position_embeddings` argument to `spatial_position_embeddings`.
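Concretely, the renamed cross-attention signature might then look something like this (a hypothetical sketch; the final names and ordering would be up to whoever picks this up):

```python
from typing import Optional

import torch
from torch import nn

class DetrAttention(nn.Module):
    # Hypothetical renamed signature; everything else stays the same.
    def forward(
        self,
        hidden_states: torch.Tensor,
        object_queries: Optional[torch.Tensor] = None,               # was `position_embeddings`
        key_value_states: Optional[torch.Tensor] = None,
        spatial_position_embeddings: Optional[torch.Tensor] = None,  # was `key_value_position_embeddings`
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ):
        ...
```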
Hello @daspartho @NielsRogge , wanted to inquire as to whether any progress was made on this? I'd like to take a look.
Hello @NielsRogge, I am currently working on this issue. I've read the paper and I understand what has to be changed. My question is whether we only have to change the `DetrDecoderLayer` class (in the respective `forward` function mentioned above), or whether all `position_embeddings` args have to change too.
I also ran some local tests and noticed that when I renamed the arguments to `object_queries` and `spatial_position_embeddings` only in the `forward` function I mentioned, many tests broke because the callers were still passing the old argument names. To rename these arguments, do we need to update them in the tests as well?
I looked at some of the tests, but I think the problem is in the code itself, since the classes related to this one would still be passing arguments under the old names.
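One pattern I've seen for renames like this (just a sketch of the idea, not necessarily what the maintainers want here) is to keep accepting the old keyword with a deprecation warning:

```python
import warnings

def forward(self, hidden_states, object_queries=None, **kwargs):
    # Hypothetical backward-compatibility shim: still accept the old
    # `position_embeddings` keyword, but warn that it is deprecated.
    position_embeddings = kwargs.pop("position_embeddings", None)
    if position_embeddings is not None:
        if object_queries is not None:
            raise ValueError(
                "Cannot specify both `position_embeddings` and `object_queries`."
            )
        warnings.warn(
            "`position_embeddings` is deprecated; use `object_queries` instead.",
            FutureWarning,
        )
        object_queries = position_embeddings
    ...
```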
This is my first contribution to an open source project this size, and I'm really happy to do it. Thanks in advance.
Hey @NielsRogge, is this issue still open? If yes, can I take it?
Hey @hackpk, I'm putting the finishing touches on my PR to fix this issue, so Idk...
That's great. I'll look for another issue then. Thanks.
No problem, good luck :D
@NielsRogge @amyeroberts I think this can be closed due to #24652