Aapo Tanskanen
Hi @JackCaoG @bdhirsh, any update on this? I am having this same issue when trying to pretrain a fairseq wav2vec2 model with pytorch-xla on a TPU VM. Fairseq wav2vec2 model has...
Hi, I just published an ONNX version with scripts to do the ONNX conversion here: https://huggingface.co/aapot/bge-m3-onnx
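For anyone who just wants a quick export without going through the repo scripts, here is a minimal sketch of how the conversion could be done with Hugging Face Optimum (the published scripts may do it differently; the model id and output path below are only examples):

```python
# Sketch only: export BAAI/bge-m3 to ONNX via Hugging Face Optimum.
# The scripts in the linked repo may take a different approach.
from optimum.onnxruntime import ORTModelForFeatureExtraction
from transformers import AutoTokenizer

model_id = "BAAI/bge-m3"
save_dir = "bge-m3-onnx"  # example output path

# export=True converts the PyTorch checkpoint to ONNX on the fly
model = ORTModelForFeatureExtraction.from_pretrained(model_id, export=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)

model.save_pretrained(save_dir)
tokenizer.save_pretrained(save_dir)

# quick sanity check: run one sentence through the exported model
inputs = tokenizer("hello world", return_tensors="pt")
outputs = model(**inputs)
print(outputs.last_hidden_state.shape)
```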
> > Hi, I just published an ONNX version with scripts to do the ONNX conversion here: https://huggingface.co/aapot/bge-m3-onnx
>
> Thanks for your work. It seems like a cpu version, right?...
@sanchit-gandhi yep, this would be interesting to get working! Yes, I also tried gradient scaling as it was implemented in the PyTorch pretrain script (basically multiplying gradients by (num devices...
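To illustrate what I mean by gradient scaling on the Flax/JAX side, a rough toy sketch (the loss, optimizer and the scale factor below are placeholders, not what the actual pretraining scripts use):

```python
# Sketch only: scale every gradient leaf by an extra factor before the
# optimizer update, similar in spirit to the PyTorch/fairseq pretrain script.
# The toy loss and the choice of scale factor are placeholders.
import jax
import jax.numpy as jnp
import optax

params = {"w": jnp.ones((4,)), "b": jnp.zeros(())}

def loss_fn(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

x = jnp.ones((8, 4))
y = jnp.zeros((8,))

loss, grads = jax.value_and_grad(loss_fn)(params, x, y)

# the extra scaling step: multiply every gradient leaf by a constant,
# e.g. something derived from the number of devices (placeholder choice)
scale_factor = jax.device_count()
grads = jax.tree_util.tree_map(lambda g: g * scale_factor, grads)

optimizer = optax.adam(1e-3)
opt_state = optimizer.init(params)
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
print(loss)
```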
Thanks for those pointers @sanchit-gandhi, sounds reasonable! I'll start digging into this soon, will keep you updated here.
Just a quick update: I finally had time to start the actual debugging. More info to follow soon.
Alright, here are some findings so far:
Step 1: `mask_time_indices` and `sampled_negative_indices` were not the same as in the PT implementation. Fixed that by pretty much just copying the functions for those from...
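For reference, roughly how the same indices can be generated with the PT helpers and then fed to both models, so the comparison starts from identical inputs (shapes here are just examples; this is a sketch, not the exact diff I made):

```python
# Sketch: generate mask / negative indices with the PyTorch-side helpers so
# that both the PT and the Flax model see exactly the same index arrays.
import numpy as np
import torch
from transformers import Wav2Vec2Config
from transformers.models.wav2vec2.modeling_wav2vec2 import (
    _compute_mask_indices,
    _sample_negative_indices,
)

config = Wav2Vec2Config()
batch_size, seq_len = 2, 200  # example sizes after the feature extractor

mask_time_indices = _compute_mask_indices(
    shape=(batch_size, seq_len),
    mask_prob=config.mask_time_prob,
    mask_length=config.mask_time_length,
)
sampled_negative_indices = _sample_negative_indices(
    features_shape=(batch_size, seq_len),
    num_negatives=config.num_negatives,
    mask_time_indices=mask_time_indices,
)

# the same numpy arrays can be passed to the Flax model directly and,
# wrapped in torch tensors, to the PyTorch model
pt_mask = torch.tensor(mask_time_indices, dtype=torch.long)
pt_negatives = torch.tensor(sampled_negative_indices, dtype=torch.long)
print(mask_time_indices.shape, sampled_negative_indices.shape)
```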
Continuing with the updates:
Step 4: contrastive loss calculation is the same in Flax and PT.
Step 5: diversity loss calculation looks to be the same, but I'll verify that later.
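For the diversity loss check, this is the formulation from the wav2vec 2.0 paper that I am comparing against, as a small numpy sketch (shapes are illustrative and the epsilon is just for numerical stability; the actual Flax/PT implementations may handle masking and other details differently):

```python
# Sketch: diversity loss as in the wav2vec 2.0 paper,
# L_d = (G*V - sum_g perplexity_g) / (G*V),
# computed from codevector probabilities averaged over masked time steps.
import numpy as np

def diversity_loss(avg_probs):
    """avg_probs: (num_groups G, num_codevectors_per_group V), rows sum to 1."""
    num_groups, num_vars = avg_probs.shape
    perplexity = np.exp(-np.sum(avg_probs * np.log(avg_probs + 1e-7), axis=-1))
    return (num_groups * num_vars - perplexity.sum()) / (num_groups * num_vars)

rng = np.random.default_rng(0)
logits = rng.normal(size=(2, 320))  # e.g. G=2 groups, V=320 codevectors each
avg_probs = np.exp(logits) / np.exp(logits).sum(-1, keepdims=True)
print(diversity_loss(avg_probs))
```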
Thanks for the tips @sanchit-gandhi! Actually, I also had in mind to use pre-trained weights to compare model outputs that way; will try it soon. Will also check the fairseq implementation...
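A rough sketch of the kind of output comparison I have in mind: load the same checkpoint into both frameworks and diff the outputs on identical inputs (the checkpoint name and the dummy input are only examples):

```python
# Sketch: load one checkpoint into both the PyTorch and the Flax wav2vec2
# models and compare their outputs on the same input. Example checkpoint.
import numpy as np
import torch
from transformers import FlaxWav2Vec2Model, Wav2Vec2Model

ckpt = "facebook/wav2vec2-base"  # example checkpoint

pt_model = Wav2Vec2Model.from_pretrained(ckpt)
flax_model = FlaxWav2Vec2Model.from_pretrained(ckpt, from_pt=True)

raw_audio = np.random.randn(1, 16000).astype(np.float32)  # dummy 1s of audio

with torch.no_grad():
    pt_out = pt_model(torch.from_numpy(raw_audio)).last_hidden_state.numpy()
flax_out = np.asarray(flax_model(raw_audio).last_hidden_state)

print("max abs diff:", np.abs(pt_out - flax_out).max())
```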
After using pre-trained weights to continue pretraining for one more step with the same input data, I think the following is happening with the model outputs:
- `projected_states` has a max difference of 0.276...