[Question] Tutorial 3 (PyTorch, JAX): Test accuracy for the model with sigmoid activation function

Status: Open · sy-eng opened this issue 2 years ago · 2 comments

Thank you for your great tutorials!

I have a question about your comment on the test accuracy of the model with the sigmoid activation function (it appears below cell 17 in the PyTorch notebook and below cell 18 in the JAX notebook).

You mention that the result with sigmoid is very poor and that, coincidentally, the JAX model does train. But is that because the PyTorch model is not trained well, and is the JAX result actually the correct one?

I re-trained the PyTorch model and found that training stops at epoch 8, because the result at epoch 1 is better than the results at epochs 2 to 8. This means the saved model is the one from epoch 1.

When I changed the `patience` parameter from 7 to 50, I got a result similar to the JAX one.
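The stopping behavior described above is what patience-based early stopping produces. Below is a minimal, hypothetical sketch (not the tutorial's `train_model`), assuming training stops once `patience` epochs pass without a new best validation score and the best-scoring checkpoint is kept. The validation curve is made up for illustration: a lucky first epoch, a long plateau, then learning that starts suddenly.

```python
from typing import List, Tuple


def early_stopping_epoch(val_scores: List[float], patience: int) -> Tuple[int, int]:
    """Return (stop_epoch, best_epoch), both 1-indexed.

    Hypothetical helper: stops once `patience` epochs have passed without
    beating the best validation score; the kept checkpoint is best_epoch.
    """
    best_score = float("-inf")
    best_epoch = 0
    for epoch, score in enumerate(val_scores, start=1):
        if score > best_score:
            # New best: this is the checkpoint that would be saved
            best_score = score
            best_epoch = epoch
        elif epoch - best_epoch >= patience:
            # No improvement for `patience` epochs: stop here
            return epoch, best_epoch
    # Ran through all epochs without triggering early stopping
    return len(val_scores), best_epoch


# Made-up curve: epoch 1 is lucky, epochs 2-13 sit at chance level,
# and learning kicks in at epoch 14.
curve = [0.30] + [0.10] * 12 + [0.60, 0.75, 0.80, 0.82]

print(early_stopping_epoch(curve, patience=7))   # (8, 1)
print(early_stopping_epoch(curve, patience=50))  # (17, 17)
```

With `patience=7`, the run stops at epoch 8 and keeps the epoch-1 checkpoint; with `patience=50`, it survives the plateau and keeps a checkpoint from after learning has started.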

Thank you.

sy-eng, Oct 08 '23 11:10

Hi, the sigmoid model is indeed a fun one to play around with in this tutorial. :) I had tried a couple of trainings in PyTorch with 50 epochs and noticed that some start learning suddenly, but many continued to fail even after longer trainings. You need to be a bit lucky that the gradients in the early layers don't cancel each other out too much, so that the network actually starts learning. In JAX, the sigmoid networks tend to move into the learning regime slightly more stably. At the same time, if you optimize the initialization, add some normalization, or use Adam, the MLP also trains relatively well with sigmoid activation functions. Nonetheless, the point of the sigmoid training was to show that one shouldn't use sigmoid as the main hidden activation function in a network, since it brings several drawbacks. So I would recommend using other activation functions than trying to over-optimize the sigmoid network :)
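One well-known drawback of sigmoid hidden layers is that its derivative σ'(z) = σ(z)(1 − σ(z)) is at most 0.25, so the backpropagated gradient tends to shrink on its way to the early layers. The NumPy sketch below is a generic illustration, not the tutorial's code: the layer widths, the random input, the bias-free layers, and the unit-norm upstream gradient are all hypothetical choices, and the uniform ±1/sqrt(fan_in) initialization only roughly mirrors PyTorch's `nn.Linear` default.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def backprop_delta_norms(layer_sizes, seed=0):
    """Norm of the gradient w.r.t. each layer's pre-activation, ordered from the
    first (input-side) layer to the last, for a random bias-free sigmoid MLP."""
    rng = np.random.default_rng(seed)
    # Uniform init in [-1/sqrt(fan_in), 1/sqrt(fan_in)] (assumption, see lead-in)
    weights = [rng.uniform(-1.0, 1.0, size=(n_out, n_in)) / np.sqrt(n_in)
               for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:])]

    # Forward pass on one random input, keeping the pre-activations
    a = rng.standard_normal(layer_sizes[0])
    pre_acts = []
    for W in weights:
        z = W @ a
        pre_acts.append(z)
        a = sigmoid(z)

    # Backward pass, starting from an arbitrary unit-norm gradient at the output
    grad_a = np.ones_like(a) / np.sqrt(a.size)
    norms = []
    for W, z in zip(reversed(weights), reversed(pre_acts)):
        s = sigmoid(z)
        delta = grad_a * s * (1.0 - s)  # chain rule: sigma'(z) = s * (1 - s) <= 0.25
        norms.append(float(np.linalg.norm(delta)))
        grad_a = W.T @ delta            # gradient w.r.t. the previous layer's activation
    return norms[::-1]


print(sigmoid(0.0) * (1.0 - sigmoid(0.0)))  # 0.25, the maximum of sigma'(z), at z = 0

# Hypothetical widths: 784 inputs, four hidden layers, 10 outputs
norms = backprop_delta_norms([784, 512, 256, 256, 128, 10])
for i, n in enumerate(norms, start=1):
    print(f"layer {i}: |dL/dz| = {n:.2e}")
```

Each step toward the input multiplies the gradient norm by at most 0.25 times the spectral norm of the weight matrix, which is well below 4 for this initialization, so the first layer receives the smallest gradient. Better initialization, normalization, or an adaptive optimizer such as Adam can counteract this effect, as described above.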

phlippe, Oct 08 '23 14:10

Thank you for your comment.

> I had tried a couple of trainings in PyTorch with 50 epochs and noticed that some start learning suddenly, but many continued to fail even after longer trainings.

I ran the code below, and all test accuracies were higher than 75%... Did many of the sigmoid models really fail to learn?

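```python
# set_seed, Sigmoid, BaseNetwork, train_model and device are defined
# earlier in the Tutorial 3 (PyTorch) notebook.
for i in range(50):
    print(f"Training BaseNetwork with seed {i}")
    set_seed(i)
    act_fn = Sigmoid()
    net_actfn = BaseNetwork(act_fn=act_fn).to(device)
    # patience=50 keeps early stopping from ending the run before learning starts
    train_model(net_actfn, f"FashionMNIST_sigmoid_{i}", overwrite=False, patience=50)
```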

It is true that learning starts suddenly.

> So I would recommend using other activation functions than trying to over-optimize the sigmoid network :)

I agree with this.

Thank you.

sy-eng, Oct 09 '23 12:10