Request to include new updates to the TensorFlow colab
Hi,
Thanks for your hard work in porting the weights from TF. It has been helpful for the students who use PyTorch. I just updated the checkpoint and model to be the exact one we used in the paper. Also added some code to handle new Keras variable naming. Would it be possible for you to update the PyTorch port as well? It should result in a model that counts better.
Thanks, Debidatta
Hi @debidatta, thank you for the update. Yes, I had it in mind but never got around to it.
Is the architecture exactly the same as the old model, so that it's just a matter of parameter naming? Or are there additional parameters or logic that need to be added to the port?
Yes, a couple of changes are needed, based on this commit.
Architecturally, pos_encoding1 and pos_encoding2 here now have 512 channels, and the fc_layer for period prediction here now has size 64.
During the forward pass, you no longer need to add +1 here. Optionally, once the forward pass is done, you can keep only 32 channels in the last axis here to match the paper.
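To make the shape changes above concrete, here is a minimal PyTorch sketch, not the actual port. The class name `PeriodHeadSketch`, the clip length of 64 frames, the use of `fc_layer` as the final layer, and the choice to keep the *first* 32 channels are all assumptions made for illustration. Only the 512-channel positional encoding, the fc size of 64, the removed +1, and the optional 32-channel trim come from the description above.

```python
import torch
from torch import nn

NUM_FRAMES = 64      # assumption: clip length, not stated in this thread
EMBED_DIM = 512      # from the thread: pos_encoding now has 512 channels
FC_DIM = 64          # from the thread: fc_layer size for period prediction
PAPER_CHANNELS = 32  # from the thread: optionally keep 32 channels


class PeriodHeadSketch(nn.Module):
    """Hypothetical stand-in for the period-prediction branch."""

    def __init__(self) -> None:
        super().__init__()
        # Learned positional encoding with 512 channels in the last axis.
        self.pos_encoding = nn.Parameter(torch.zeros(1, NUM_FRAMES, EMBED_DIM))
        # Assumption: fc_layer is treated as the last layer of the branch.
        self.fc_layer = nn.Linear(EMBED_DIM, FC_DIM)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.pos_encoding
        # No "+1" is applied anywhere in this forward pass.
        return self.fc_layer(x)


head = PeriodHeadSketch()
features = torch.randn(2, NUM_FRAMES, EMBED_DIM)  # dummy per-frame features
out = head(features)                              # last axis has 64 channels

# Optional trim after the forward pass; which 32 channels to keep is an
# assumption here (the first 32).
out_paper = out[..., :PAPER_CHANNELS]
```

With the dummy input above, `out` has 64 channels in the last axis and `out_paper` has 32. The real port would load the converted checkpoint rather than use freshly initialized parameters.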
Thank you, that's super useful!