recurrent-memory-transformer-pytorch
Question: first read memories
During the first run, mems == None, and the model doesn't attend to any "read" tokens, as per:
https://github.com/lucidrains/recurrent-memory-transformer-pytorch/blob/3be7d43604c6921a7dbdc68f88c7f3c534f82d2a/recurrent_memory_transformer_pytorch/recurrent_memory_transformer.py#L350-L355
Why not attend to read_memory_emb, and replace it with:
read_mem_length = mem_length
read_mems = repeat(self.read_memory_emb, 'm d -> b m d', b = batch)
if exists(read_memories):
    read_mems += read_memories
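For reference, a minimal standalone sketch of that proposal (assuming einops.repeat and an exists helper as in the repo; build_read_memories and the shapes are hypothetical names used here only for illustration):

```python
import torch
from einops import repeat

def exists(val):
    return val is not None

def build_read_memories(read_memory_emb, read_memories, batch):
    # read_memory_emb: (mem_length, dim) learned read-memory embedding
    # read_memories:   (batch, mem_length, dim) carried over from the previous segment, or None
    read_mem_length = read_memory_emb.shape[0]

    # always materialize the read-memory tokens from the learned embedding
    read_mems = repeat(read_memory_emb, 'm d -> b m d', b = batch)

    # add the carried-over memories when they exist (i.e. after the first segment)
    if exists(read_memories):
        read_mems = read_mems + read_memories

    return read_mems, read_mem_length
```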
@pfeatherstone ohh, yeah, I could add it and mask it out
Is that the only thing keeping it from being ONNX-exportable?
Pretty much. Basically nothing can be optional with tracing and ONNX export, so both mems and mask must be specified.
cool, will get it done later today!
@pfeatherstone decided to take the easy way out: https://github.com/lucidrains/recurrent-memory-transformer-pytorch/commit/90de2ac64c1ce2d2ef90f3b63dbdcecf8af2a024 Let me know if this does it.
You still need to pass mems in the forward pass when tracing. So I think the fix would be to pass mems full of zeros and make the code handle that as if it were None.
Then you'd probably need some control flow in a torch.jit.script function that handles the different cases.
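Something like this rough sketch (merge_read_memories and the shapes are made up for illustration, not the repo's API): the traced graph always receives a concrete mems tensor, and a scripted helper decides whether it actually carries state:

```python
import torch

@torch.jit.script
def merge_read_memories(read_mem_emb: torch.Tensor, mems: torch.Tensor) -> torch.Tensor:
    # data-dependent control flow lives in a scripted function so it
    # survives tracing; an all-zero `mems` tensor stands in for `mems == None`
    if bool(mems.abs().sum() == 0):
        return read_mem_emb
    return read_mem_emb + mems

# when tracing / exporting, every forward argument must be a real tensor
batch, mem_length, dim = 2, 4, 64
zero_mems = torch.zeros(batch, mem_length, dim)        # "no memories yet"
read_mem_emb = torch.randn(batch, mem_length, dim)     # read-memory embeddings

out = merge_read_memories(read_mem_emb, zero_mems)     # returns the embeddings unchanged
```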
@pfeatherstone have you tried this manually?
Yeah, I have. I could submit a PR, but I didn't know if passing zeros and letting the model attend to them was OK.
@pfeatherstone it is harmless to attend to them, as the zeros just get summed with the read-memory positional embeddings.
However, I decided to add a keyword argument on forward that can mask out the read memories if need be.
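For anyone following along, masking out the read-memory positions with a key-padding-style mask could look roughly like this (build_attention_mask and the [read mems | sequence | write mems] layout are assumptions for illustration; the actual keyword argument in the repo may differ):

```python
import torch

def build_attention_mask(seq_len: int, read_mem_length: int, write_mem_length: int,
                         mask_read_memories: bool, batch: int) -> torch.Tensor:
    # key padding mask over [read mems | sequence | write mems];
    # True = attend, False = masked out
    total_len = read_mem_length + seq_len + write_mem_length
    mask = torch.ones(batch, total_len, dtype = torch.bool)

    if mask_read_memories:
        # hide the read-memory positions, e.g. on the very first segment
        mask[:, :read_mem_length] = False

    return mask

mask = build_attention_mask(seq_len = 128, read_mem_length = 4, write_mem_length = 4,
                            mask_read_memories = True, batch = 2)
```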
@pfeatherstone all good?
Ah yes, I still need to try it. Can it wait till Tuesday?
@pfeatherstone yeah, just checking; take your time.