[WIP] Add Atlas - Retrieval Augmented Language Model
What does this PR do?
Implements Atlas: Few-shot Learning with Retrieval Augmented Language Models, as requested in https://github.com/huggingface/transformers/issues/20503
Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you read the contributor guideline, Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
- [ ] Did you write any new necessary tests?
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
@patrickvonplaten, @lhoestq, @patil-suraj cc @patrick-s-h-lewis and @gizacard
This branch is very much a WIP currently, but for anyone interested, here is roughly how I plan to structure things, aiming to mesh the shape of the original implementation with Transformers' existing patterns. For the most part, I hope to make its usage as similar as possible to `T5ForConditionalGeneration`.
This is all new to me, so any feedback would be super helpful!
```python
class AtlasConfig:
    pass

class AtlasTrainer(Trainer):
    pass

class AtlasPreTrainedModel(PreTrainedModel):
    pass

class AtlasModel(AtlasPreTrainedModel):
    def __init__(self, query_passage_encoder, reader, retriever):
        self.query_passage_encoder = query_passage_encoder  # UntiedDualEncoder
        self.reader = reader  # FiD
        self.retriever = retriever  # HFIndexBase

class FiD(T5ForConditionalGeneration):
    def __init__(self):
        self.encoder = FiDStack()
        self.decoder = FiDStack()

class FiDStack(T5Stack):
    pass

class UntiedDualEncoder(torch.nn.Module):
    def __init__(self, query_contriever, passage_contriever):
        self.query_contriever = query_contriever
        self.passage_contriever = passage_contriever

class Contriever(BertModel):
    pass

class HFIndexBase:
    pass

class AtlasRetriever:
    def __init__(self, index):
        self.index = index  # HFIndexBase
```
The existing RAG implementation makes its sub-models easily swappable. Here, however, the inputs and outputs expected by the 'reader' model (the name given to the T5 encoder/decoder in the original implementation) are non-standard due to the fusion-in-decoder mechanism, so I don't plan to make these models as easily swappable; I think that would complicate things unnecessarily.
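To illustrate why the reader's inputs are non-standard, here is a minimal sketch of the fusion-in-decoder reshaping, not the actual Atlas code: each (query, passage) pair is encoded independently, then the passage encodings are concatenated along the sequence axis so the decoder cross-attends over all passages jointly. The function name `fid_fuse` and the concrete shapes are my own assumptions for illustration.

```python
import torch

def fid_fuse(encoder_outputs: torch.Tensor, n_passages: int) -> torch.Tensor:
    """Hypothetical illustration of the fusion-in-decoder (FiD) step.

    encoder_outputs: (batch * n_passages, passage_len, hidden), with each
    (query, passage) pair encoded independently by the T5 encoder.
    Returns: (batch, n_passages * passage_len, hidden), i.e. all passage
    encodings concatenated along the sequence axis so the decoder can
    attend over every passage at once.
    """
    total, passage_len, hidden = encoder_outputs.shape
    batch = total // n_passages
    return encoder_outputs.view(batch, n_passages * passage_len, hidden)

# 2 queries, 4 retrieved passages each, 16 tokens per passage, hidden size 8
fused = fid_fuse(torch.randn(2 * 4, 16, 8), n_passages=4)
print(fused.shape)  # torch.Size([2, 64, 8])
```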
Since I'm not doing this, it seems best practice is to copy the implementations of models like `BertModel` and `T5ForConditionalGeneration` (with "Copied from" comments) rather than importing them. If that's the case, I'll switch these across once the PR is almost ready.
There is some complexity here in how we make the model trainable end-to-end within Hugging Face's patterns, which I haven't yet looked into deeply. I wonder whether a `class AtlasTrainer(Trainer)` would make sense, implementing the various continuous re-indexing strategies described in the original paper.
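To make the re-indexing idea concrete, here is a tiny sketch of the scheduling logic such a trainer might own, kept independent of the `Trainer` API. Everything here (the class name, the `refresh_every` hyperparameter) is a hypothetical illustration, not taken from the paper or its code.

```python
class ReindexSchedule:
    """Hypothetical sketch: decide when the passage index should be
    rebuilt with fresh passage embeddings during end-to-end training,
    since the passage encoder's weights drift as training progresses."""

    def __init__(self, refresh_every: int):
        # Assumed hyperparameter: rebuild every N optimizer steps.
        self.refresh_every = refresh_every

    def should_refresh(self, step: int) -> bool:
        # Skip step 0: the index is assumed to be built once up front.
        return step > 0 and step % self.refresh_every == 0

schedule = ReindexSchedule(refresh_every=1000)
print([s for s in (0, 500, 1000, 2000) if schedule.should_refresh(s)])  # [1000, 2000]
```

An `AtlasTrainer.training_step` override could then call something like a `retriever.reindex()` method (also hypothetical) whenever `should_refresh` returns `True`.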
Yes, please do use the approach of copying model code and adding `# Copied from` comments, as it's more in line with the general approach in the library (RAG being an exception :-) )
cc @ArthurZucker
@ArthurZucker @ae99 let me know if you need help with anything - think this is a super cool addition!
Hey @patrickvonplaten and @ArthurZucker! I think the general structure of this model is mostly in place. I'd love to get an early review on the PR from you just to check if things are looking ok and confirm the major things are roughly fitting patterns correctly.
I have a few temporary notebooks, `save_pretrained.ipynb`, `test_retriever.ipynb` and `test_model.ipynb`, in place of actual tests at the moment if you would like to get a sense of usage. Like RAG, I have a dedicated retriever, but I've cut it down to mostly be a small wrapper around a dataset + index for now. Documentation and tests haven't been touched at all yet, and everything is very WIP still!
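As a rough illustration of the "small wrapper around a dataset + index" shape, here is a toy stand-in: it uses a plain numpy inner-product search instead of a real FAISS-backed index, and all names in it are hypothetical rather than from this branch.

```python
import numpy as np

class ToyRetriever:
    """Hypothetical toy stand-in for a dataset + index wrapper: stores
    passage embeddings and returns the top-k passages for a query
    embedding by inner-product score, as a FAISS-backed index would."""

    def __init__(self, passages, passage_embeddings):
        self.passages = passages
        self.embeddings = np.asarray(passage_embeddings, dtype=np.float32)

    def retrieve(self, query_embedding, k=2):
        # Inner-product similarity between the query and every passage.
        scores = self.embeddings @ np.asarray(query_embedding, dtype=np.float32)
        top = np.argsort(-scores)[:k]  # indices of the k highest scores
        return [self.passages[i] for i in top], scores[top]

retriever = ToyRetriever(
    passages=["p0", "p1", "p2"],
    passage_embeddings=[[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]],
)
docs, scores = retriever.retrieve([1.0, 0.0], k=2)
print(docs)  # ['p0', 'p2']
```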
Hi @ae99, I would also like to contribute. Let me know if there is something I can help you with.
@ArthurZucker could you maybe take a look here? :-) Let me know if you need some help
@akashe feel free to give this PR a review as well if you'd like to help a bit :-)
Will review now 😉
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Hey @ae99 , are you still working on the integration? If not then let me know, I would be happy to continue from where you left.
Hey @akashe, that'd be perfect.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Hi! This one would be really relevant for something that we are working on in my org. What's the status on it? We may be able to chip in.
Hey! We have not really picked it up; if the community needs it, we can probably come back to it, but I would advise just putting the model on the Hub following this tutorial! 🤗