memorizing-transformers-pytorch
Any interesting results?
Hey! Cool repo. I like all the knn+lm methods. Did you do some runs yet? Anything interesting to report?
:wave: hello Romain! no not yet, i still need to build out the modular forgetting system
didn't you start your new job?
Ok great, I'll follow up on the progress :)
Indeed I started the new job, pretty interesting!
I hope you don't mind me stalking this project, but I tried this out on enwik8 (https://github.com/igor0/memorizing-transformers-pytorch/commit/d302feee0c3d9655a92c392850c4ec5d86bff77c). I basically just ported the enwik8 training loop from another one of @lucidrains's projects.
The initial finding is that with KNN memories, the training loop is pretty slow, so often (but not always) I'll sample 0% GPU usage. Disabling the KNN memories makes the training loop go much faster (>10x compared to with KNN). So the KNN code may need some optimization, but I don't understand it well enough yet to suggest something constructive.
Edit: Ah, that was with knn_use_gpu=False, I missed that. With knn_use_gpu=True, I seem to get a hang. On the GPU, it's faiss.index_cpu_to_all_gpus(self.index) that's hanging for me, endlessly chewing up CPU cycles. Just FYI.
@igor0 ah thanks for trying it! i'll hook up the enwik8 training myself next week and start profiling and see what's going on :) still want to polish up the pluggable memory expiring strategy (account for memory creation time as well as last retrieved time)
I ended up having two issues with KNN on the GPU. Here are the findings so far.
1. The faiss-gpu wheel package hangs for me on A100 and A10. With the faiss-gpu package installed by pip, I always get a hang in index_cpu_to_all_gpus(). I opened an issue here: https://github.com/kyamagu/faiss-wheels/issues/54. I would guess that the faiss-gpu wheel isn't compatible with CUDA 11, so the A100/A10 GPUs don't work.
Using conda rather than pip to install faiss-gpu seems to work for me.
2. " remove_ids not implemented for this type of index" As far as I can tell, remove_ids is not supported for any GPU indexes. One possible solution may be to simulate a sliding window with two GPU indexes, so that we always completely clear an index, instead of removing entries one-by-one. The fancier expiry will get more complicated and will require some type of manual compaction, at least if you want to run faiss on the GPU.
there is little benefit to using faiss gpu
however if knn operations are slow here, it's likely because a flat index is used
What type of index to use then? One problem is that we don't know the distribution of the keys upfront, and the clustering approaches require that. Furthermore, the distribution of keys changes over time. So, you could keep recomputing the clusters periodically. I'm sure that's doable, but another thing to sort out and tune.
IMO a flat index is a reasonable place to start. And a flat index on a GPU would perform much better than a flat index on a CPU.
If faiss doesn't let us implement a flat index with the ability to replace entries, then we could implement our own sliding window mechanism, or just avoid faiss for now and simply implement the memory directly as a PyTorch tensor. That could be one straightforward solution.
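As a rough sketch of that last option (purely illustrative; the class name, ring-buffer overwrite policy, and inner-product scoring are assumptions, not anything from this repo), the per-batch-element memory could just be a pair of plain tensors with a sliding write pointer:
import torch

class FlatTorchMemory:
    # minimal sketch: fixed-size memory held in plain tensors, overwritten ring-buffer style
    def __init__(self, capacity, dim, device = 'cpu'):
        self.keys = torch.zeros(capacity, dim, device = device)
        self.values = torch.zeros(capacity, dim, device = device)
        self.capacity = capacity
        self.write_pos = 0
        self.filled = 0

    def add(self, new_keys, new_values):
        # overwrite the oldest slots in place, wrapping around once the buffer is full
        n = new_keys.shape[0]
        slots = torch.arange(self.write_pos, self.write_pos + n, device = self.keys.device) % self.capacity
        self.keys[slots] = new_keys
        self.values[slots] = new_values
        self.write_pos = (self.write_pos + n) % self.capacity
        self.filled = min(self.filled + n, self.capacity)

    def search(self, queries, topk = 32):
        # exact "flat index" search: one matmul plus topk over the filled slots
        sims = queries @ self.keys[:self.filled].t()
        scores, indices = sims.topk(min(topk, self.filled), dim = -1)
        return self.values[:self.filled][indices], scores
Replacement here is just an in-place write and everything stays on the GPU with the rest of the model; whether an exact search stays fast enough at memory sizes of 2048-16000 per batch element is the open question.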
Hmm, yeah, maybe this won't be smooth sailing
There is also another library out there called TorchPQ that may fit the bill of running on the GPU and supporting removal of ids. But it is a relatively young library still, so probably not without a few rough edges. I'll take a closer look next week, thanks for trying this out so early!
https://github.com/DeMoriarty/TorchPQ
FlatContainer in TorchPQ looks promising as a potential flat GPU index (to avoid the challenges with clustering): https://github.com/DeMoriarty/TorchPQ/blob/main/torchpq/container/FlatContainer.py
It seems like FlatContainer::set_data_by_address() can arbitrarily overwrite records in the flat container. That would be more efficient than FlatContainer::remove(), because remove() needs to copy a lot of data around. Not sure how much that will matter in the end, but always good to avoid copying when possible.
Could you describe how often these operations are done within memorizing transformers:
- Adding an embedding (and how many at once)
- searching
- removing an embedding
An example of an index that works without training (although it's not obvious that's a good property at this point) is IndexHNSW.
What I meant by "there is little benefit in using faiss GPU" is that faiss indices are usually very fast on CPU (search is done in less than 1ms), and they're not faster on GPU. The only time it's better to use the GPU is if you need to query with a huge batch of embeddings (let's say 1M).
But the choice of index should be based on how often you need to search/add/train/remove and how many vectors you have. So if you give more information on that, I can advise.
@rom1504 thanks for offering your expertise! So basically this paper is adding embeddings at a rate of 512 tokens per training step. To compound the problem, they keep separate indexed memories per batch element, which is why I have to instantiate a number of faiss indices equal to the batch size. Searching is also done every training step (after the first), with a top k of 32, and removal of embeddings starts after the memory hits some capacity limit (in the paper they used 2048, then scaled the memory size up to 16000). A capacity of 2048 means the removing of ids starts on the 5th step (512 new embeddings per step, so the memory fills after 4 steps). So basically high rates of adding, removing, and searching.
The author told me that what they were doing within Google is running each batch on 1 TPU core, and thus they were able to assign it its own index.
Flat would be fine, but it also negates the paper's main selling point, which is that fetching from approximate knn memories should benefit the attention net greatly. Hopefully it isn't the case that it "does not work in practice" due to engineering obstacles.
even in the worst case, I think the lessons from this paper can be carried over to some other architecture (say if one were to generalize https://github.com/lucidrains/HTM-pytorch):
1. storing l2-normed keys / values (cosine sim attention) as memories for stability
2. memories need not be differentiable
3. approximate knn is fine
4. one only needs one or two layers of long term memory fetching at most (placed at the middle of the attention net)
ok interesting. Removing from an index is usually slow, so I would not remove. Instead I would replace the remove operation by adding removed ids to a mask (and when doing the search, you search with a higher K value and apply the mask; a higher K affects the search speed only minimally).
And maybe you rebuild the index from scratch every 1000 steps to save on memory if needed.
About add/search, I would start by simply trying IndexHNSWFlat; I believe it will work well enough. It's a little slow at adding (maybe 10ms for a batch of 512), but search is basically instant (0.1ms).
import faiss
import numpy as np

dimension = 256  # example embedding dimension
index = faiss.IndexHNSWFlat(dimension, 15, faiss.METRIC_INNER_PRODUCT)
index.add(np.random.rand(512, dimension).astype('float32'))  # add a batch of 512 vectors
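For what it's worth, a hedged sketch of the mask-plus-periodic-rebuild idea above, on top of an index like this one (the sizes and the rebuild interval are example values, not anything prescribed):
import faiss
import numpy as np

dimension, topk = 256, 32
index = faiss.IndexHNSWFlat(dimension, 15, faiss.METRIC_INNER_PRODUCT)
index.add(np.random.rand(2048, dimension).astype('float32'))  # ids are assigned 0..2047 in insertion order

removed = set(range(512))  # mark the oldest 512 ids as expired instead of deleting them

queries = np.random.rand(4, dimension).astype('float32')
# oversample by the size of the mask, then filter each row back down to topk
_, ids = index.search(queries, topk + len(removed))
filtered = [[i for i in row if i not in removed][:topk] for row in ids]

# every N steps (say 1000), rebuild the index from the surviving vectors and clear
# `removed`, so the mask never grows without bound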
@rom1504 thanks for the suggestion :+1:
on trying it out, seeing add_with_ids not implemented for this type of index, which would make custom memory expiration strategies unapproachable
The problem is that for each training sample, we need to:
- search seq_len entries
- remove seq_len entries
- add seq_len entries
So, adds : removals : searches are 1 : 1 : 1, or alternatively you can think of it as searches : replacements being 1 : 1. So, we are doing the removals in order to create space for the add. Masking out the elements doesn't really solve the problem for us because it doesn't open up new space.
One solution is to have two indexes: current and previous. We add to current, and once current fills up, we clear previous, current becomes previous, and previous becomes current. So basically, we are only adding. In this scenario, current and previous don't need to support any fancy operations beyond add() and clear(), so they can probably be either flat or HNSW.
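A rough sketch of that two-index rotation, assuming HNSW for both halves; the class name and parameters are illustrative, and the bookkeeping that maps search hits back to stored key/value tensors is left out:
import faiss
import numpy as np

class RotatingIndex:
    # sketch of the sliding window via two indexes: only add() plus a wholesale swap
    def __init__(self, dim, half_capacity):
        self.dim, self.half_capacity = dim, half_capacity
        self.current = faiss.IndexHNSWFlat(dim, 15, faiss.METRIC_INNER_PRODUCT)
        self.previous = faiss.IndexHNSWFlat(dim, 15, faiss.METRIC_INNER_PRODUCT)

    def add(self, vectors):
        self.current.add(np.ascontiguousarray(vectors, dtype = 'float32'))
        if self.current.ntotal >= self.half_capacity:
            # rotate: the oldest half is dropped wholesale, no per-id removal needed
            self.previous = self.current
            self.current = faiss.IndexHNSWFlat(self.dim, 15, faiss.METRIC_INNER_PRODUCT)

    def search(self, queries, topk = 32):
        # query whichever halves are non-empty and merge the best hits per query
        queries = np.ascontiguousarray(queries, dtype = 'float32')
        halves = [ix for ix in (self.previous, self.current) if ix.ntotal > 0]  # assumes add() was called at least once
        scores = np.concatenate([ix.search(queries, topk)[0] for ix in halves], axis = 1)
        order = np.argsort(-scores, axis = 1)[:, :topk]  # inner product: larger is better
        return np.take_along_axis(scores, order, axis = 1)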
Another solution is to use a single flat index that supports replacement: i.e., we can efficiently replace some entries with other entries. Faiss doesn't seem to support this, but you can either implement something from scratch (just represent the memory as a tensor), or use the other library that @lucidrains mentioned.
@igor0 yea, i feel like if we were to go with flat for faiss, there would be no benefit. the whole point is to sell the approach of using approximate knn for long term, non differentiable memory - at least, that was what excited me about the paper initially
maybe it would be best to forget about faiss and scann, and just try to roll something with deepmind's HTM (although i think more thought needs to be put into how to generalize HTM to more than just a depth of 1 hierarchy) - or the alternative is to just forget about this repository and focus on https://github.com/lucidrains/routing-transformer and make sure it supports recurrence and that the routing attention can act on a set of non differentiable memories
on trying it out, seeing add_with_ids not implemented for this type of index, which would make custom memory expiration strategies unapproachable
I don't understand why add_with_ids is needed. custom ids don't do anything particularly interesting. You can either use consecutive ids, or maintain a consecutive/custom ids mapping (as a python dict, or as a numpy array)
you could decide to use faiss.IndexIDMap if you really want add with ids https://github.com/facebookresearch/faiss/wiki/Pre--and-post-processing#the-indexidmap
Masking out the elements doesn't really solve the problem for us because it doesn't open up new space
Why is opening up new space needed? The number of embeddings you add at every iteration is pretty small, so the memory use will be limited until you've done a few thousand steps. At that point you can rebuild the index.
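For reference, a minimal sketch of the IndexIDMap route (the id values here are arbitrary, just to show that add_with_ids works and that searches return the custom ids):
import faiss
import numpy as np

dimension = 256
base = faiss.IndexHNSWFlat(dimension, 15, faiss.METRIC_INNER_PRODUCT)
index = faiss.IndexIDMap(base)  # adds an id translation layer on top of the HNSW index

vectors = np.random.rand(512, dimension).astype('float32')
ids = np.arange(10_000, 10_512).astype('int64')  # e.g. global token positions
index.add_with_ids(vectors, ids)

_, found = index.search(vectors[:4], 32)  # `found` now contains the custom ids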
ohh ok, maybe it could still work then, since faiss can support a ridiculous number of vectors - realistically, each document isn't going to exceed 32k tokens, before the next document comes along and the index needs to be cleared and retrained
i guess the other issue is frequent retraining of indices, since each batch will be maintaining its own index, and every time a new document comes along, it needs to be cleared and retrained
screenshot below for clarity, just imagine the batch dimension being around 32 - 64
The number of embedding you add at every iteration is pretty small, so the memory use will be limited until you do a few thousands steps. At this point you can rebuild the index
OK, you can add elements to the index until it reaches 2 x intended_size and then compact it down to intended_size. During the compaction, you can choose whatever criteria for eviction, and simply create a new index. That's an alternative to having two indices (previous and current).
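A hedged sketch of that compaction step, assuming the stored vectors and a per-entry eviction score (e.g. last-retrieved step) are kept on the side as numpy arrays; the function name and criterion are illustrative:
import faiss
import numpy as np

def compact(keys, scores, dim, keep = 2048):
    # keys: (n, dim) array of all stored vectors, with n up to 2 * keep
    # scores: (n,) eviction criterion, higher = more worth keeping
    survivors = keys[np.argsort(scores)[-keep:]]
    rebuilt = faiss.IndexHNSWFlat(dim, 15, faiss.METRIC_INNER_PRODUCT)
    rebuilt.add(np.ascontiguousarray(survivors, dtype = 'float32'))
    return rebuilt, survivors
At the ~10ms per 512 added vectors quoted above, rebuilding a few thousand vectors should cost well under a second, and it only happens once per compaction.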
i guess the other issue is frequent retraining of indices
I don't think HNSW indices need to be trained, so training wouldn't be an issue. That's also why we shouldn't use clustering-based indices, at least at training time. (The constraints are a bit different at inference time, but let's focus on training time for now.)
realistically, each document isn't going to exceed 32k tokens, before the next document comes along and the index needs to be cleared and retrained
If we don't need to support very large documents, then all the smart forgetting work becomes unnecessary. We can just clear the memory at the end of each document, so no need to forget individual entries at all. Maybe that's good enough to get past enwik8?
indeed hnsw requires no training, that's why I was suggesting it!
if your total number of vectors is below 1M, hnsw is quite fine. If your dimension is 1024, 1M vectors means 4GB in RAM, which is quite reasonable
beyond 10M vectors, you'll have to think a bit more, but with some smart eviction and retraining only every N (like 1000) steps, it should be ok
indeed hnsw requires no training, that's why I was suggesting it!
TIL!
realistically, each document isn't going to exceed 32k tokens, before the next document comes along and the index needs to be cleared and retrained
If we don't need to support very large documents, then all the smart forgetting work becomes unnecessary. We can just clear the memory at the end of each document, so no need to forget individual entries at all. Maybe that's good enough to get past enwik8?
yea that's true, but it would be nice if the work is extended (RL for example)
just to give an idea of how much faiss can scale, https://rom1504.github.io/clip-retrieval/ 's backend is currently holding an index of 5B embeddings of dimension 768. It uses 200MB of ram for the index thanks to faiss's memmapping feature. The index is 800GB. The search time is 400ms because it's big and on disk. (if using sharding and in-memory, the search time could be < 1ms)
so for millions of embeddings, everything should be fine
ok, let me meditate on this for an hour or two (before reckless execution), since it would require a big refactor of the KNN Memory class
thank you both, and hope you are having a great weekend :)