slot-attention-pytorch

Sampling operation prevents gradients from the back-propagation

Open hu-my opened this issue 1 year ago • 0 comments

Thanks for your implementation of the Slot Attention module. However, I found that the sampling operation (Line 40 in model.py) blocks gradient flow during back-propagation. During training, the gradients of slot_mu and slot_sigma will be zero, which means the two variables will never be updated. I think the reparameterization trick is needed to make the sampling operation differentiable.
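
For illustration, here is a minimal sketch of the change, assuming the slot initialization currently looks like `slots = torch.normal(mu, sigma)` with learnable `slots_mu` / `slots_sigma` parameters (the helper name `sample_slots` and the shapes below are just assumptions for the example). Instead of calling `torch.normal`, which returns a sample with no gradient path back to the parameters, one can draw unit-Gaussian noise and scale/shift it:

```python
import torch

def sample_slots(slots_mu: torch.Tensor, slots_sigma: torch.Tensor,
                 batch_size: int, num_slots: int) -> torch.Tensor:
    """Differentiable slot initialization via the reparameterization trick."""
    mu = slots_mu.expand(batch_size, num_slots, -1)
    sigma = slots_sigma.expand(batch_size, num_slots, -1)
    eps = torch.randn_like(mu)       # unit-Gaussian noise, no learnable parameters
    return mu + sigma * eps          # gradients flow back into mu and sigma

# Hypothetical usage: gradients now reach both parameters.
if __name__ == "__main__":
    dim = 64
    slots_mu = torch.nn.Parameter(torch.randn(1, 1, dim))
    slots_sigma = torch.nn.Parameter(torch.rand(1, 1, dim))
    slots = sample_slots(slots_mu, slots_sigma, batch_size=4, num_slots=7)
    slots.sum().backward()
    print(slots_mu.grad is not None, slots_sigma.grad is not None)  # True True
```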

hu-my, Apr 11 '23 06:04