Implementing PReLU
Implementing the Parametric ReLU (PReLU) from the well-known paper "Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification" (ICCV 2015). This activation introduces a learnable parameter "a", which lets the function adapt during training and can improve model accuracy and convergence compared to the standard ReLU or Leaky ReLU functions. It is defined as: f(x) = x if x >= 0, a*x if x < 0. arXiv link: https://arxiv.org/abs/1502.01852
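For reference, the definition above maps to something like the following in JAX (just a sketch of the math, not necessarily the final API):

```python
import jax.numpy as jnp

def prelu(x, a):
    # f(x) = x for x >= 0, a * x for x < 0
    return jnp.where(x >= 0, x, a * x)
```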
Could you please review my implementation and confirm that it is correct?
Hi - thanks for the contribution! It looks like this would be a good addition to jax.nn, but there are a number of changes we'd have to make: mainly, the a value should be an explicit parameter of the function, otherwise we wouldn't be able to differentiate with respect to it. With this in mind, the function should not take init or rng or num_parameters as arguments. Also, the implementation should probably be modeled after that of the existing relu. We would also need to add tests for the new function in tests/nn_test.py. Is this something you'd like to work on?
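As a rough illustration of what I mean (a sketch only, not a final signature), making a an explicit argument lets you differentiate with respect to it directly:

```python
import jax
import jax.numpy as jnp

def prelu(x, a):
    # `a` is an explicit argument, so jax.grad can differentiate with respect to it.
    return jnp.where(x >= 0, x, a * x)

x = jnp.array([-2.0, -0.5, 1.0, 3.0])

# Gradient of a scalar reduction with respect to the slope parameter.
grad_a = jax.grad(lambda a: jnp.sum(prelu(x, a)))(0.25)

# Sanity check: with a == 0 this should match relu.
assert jnp.allclose(prelu(x, 0.0), jax.nn.relu(x))
```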
Yes of course sir, I would.
So would you be willing to give me some time to work on it? I am a bit busy at the moment.
Would like to hear your feedback!
I will also add tests later.
OK, the commits in the PR have been combined.
Hello sir, have you been busy recently?
@MythicArrow It seems that prelu is equivalent to leaky_relu?
Yeah, it looks similar, but there is a significant difference between them. Leaky ReLU uses "a" as a fixed constant, whereas PReLU's "a" is a learnable slope parameter that is trained through backpropagation.
You can pass a JAX array as the second argument to leaky_relu, and JAX will have no issues computing gradients through it.
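For example (a quick sketch):

```python
import jax
import jax.numpy as jnp

x = jnp.array([-2.0, -0.5, 1.0, 3.0])
a = jnp.array(0.1)  # slope passed as a value rather than a hard-coded constant

# Gradient of a scalar reduction of leaky_relu with respect to the slope argument.
grad_a = jax.grad(lambda a: jnp.sum(jax.nn.leaky_relu(x, a)))(a)
```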
Yeah indeed @DanisNone – it looks like this is the same as leaky_relu. Sorry @MythicArrow, I should have noticed that earlier.
I think given that, this PR can probably be closed.
But the PReLU's "a" is learnable, not a constant.
I can make it trainable through backpropagation.
That is true of leaky_relu as well.
Oh ok
No problem sir.
OK then, I will close this PR.
I checked Leaky ReLU's definition and I see that it has a fixed value of "a", unlike the learnable one in PReLU. Would you merge the PR if I edited the implementation so that "a" is learnable and trainable through backpropagation instead of fixed, as it is in Leaky ReLU?
You mention LeakyReLu – are you talking about the function in the stax example library? https://github.com/jax-ml/jax/blob/9678a764e83b5fab3a877348c6ccc503d9808145/jax/example_libraries/stax.py#L158
If so, then we likely wouldn't accept such a contribution, because we are not adding new features to stax.
If you're thinking of jax.nn.leaky_relu, then I can't say I understand your request here. The negative_slope is a parameter just like any other, and you can take the gradient with respect to it, which means its value can be optimized given an appropriate loss function. That makes me think it qualifies as a "learnable parameter" – is there something more you're looking for?
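For instance, here is a toy sketch (the data and "true" slope are made up purely for illustration) in which negative_slope is fit by gradient descent:

```python
import jax
import jax.numpy as jnp

x = jnp.linspace(-3.0, 3.0, 16)
target = jnp.where(x >= 0, x, 0.25 * x)  # pretend the "true" slope is 0.25

def loss(slope):
    return jnp.mean((jax.nn.leaky_relu(x, slope) - target) ** 2)

slope = 0.01
for _ in range(100):
    slope = slope - 0.1 * jax.grad(loss)(slope)
# slope moves toward 0.25, i.e. negative_slope is optimized like any other parameter.
```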
I checked the original Leaky ReLU paper, and there "a" was a fixed constant like 0.01 and wasn't learned. But JAX's leaky_relu function behaves like PReLU in this respect, so I think there is no need to implement PReLU.
Sir, I have researched PReLU and found that it indeed has a learnable parameter that is included in the model's weights. So if I define "a" as a parameter, it will be learnable automatically, unlike in Leaky ReLU. Does JAX allow defining 'a' as a model parameter?
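For example, something like this is what I have in mind (a rough sketch; the params layout here is just an example, not an existing JAX API):

```python
import jax
import jax.numpy as jnp

# Hypothetical params pytree where the PReLU slope lives next to the weights.
params = {"w": jnp.ones(4), "prelu_a": jnp.array(0.25)}

def model(params, x):
    h = params["w"] * x
    return jnp.where(h >= 0, h, params["prelu_a"] * h)

def loss(params, x, y):
    return jnp.mean((model(params, x) - y) ** 2)

x = jnp.array([-1.0, 0.5, -2.0, 3.0])
y = jnp.zeros_like(x)
grads = jax.grad(loss)(params, x, y)  # grads["prelu_a"] is the slope's gradient
```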