slot-attention-pytorch
Sampling operation prevents gradients from back-propagating
Thanks for your implementation of the Slot Attention module. However, I found that the sampling operation (Line 40 of model.py) blocks gradients during back-propagation. During training, the gradients of slot_mu and slot_sigma are zero, so these two variables never update. I think the reparameterization trick is needed to make the sampling operation differentiable.
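
For reference, a minimal sketch of the reparameterization trick, assuming `slot_mu` and a log-scale parameter are learnable tensors broadcast over the batch and slot dimensions (the names besides `slot_mu`/`slot_sigma` and the shapes below are illustrative, not the repo's exact code):

```python
import torch

# Hypothetical shapes; adjust to match model.py.
batch_size, num_slots, slot_dim = 4, 7, 64

slot_mu = torch.nn.Parameter(torch.randn(1, 1, slot_dim))
slot_log_sigma = torch.nn.Parameter(torch.zeros(1, 1, slot_dim))

# Non-differentiable version: sampling directly from the distribution,
# e.g. torch.normal(mu, sigma), gives no gradient path back to the
# parameters, so slot_mu / slot_sigma never receive updates.

# Reparameterization trick: draw noise independently of the parameters,
# then shift and scale it. Gradients now flow into slot_mu and
# slot_log_sigma through the affine transform.
noise = torch.randn(batch_size, num_slots, slot_dim)
slots = slot_mu + slot_log_sigma.exp() * noise
```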