
Implementing PReLU

Open MythicArrow opened this issue 6 months ago • 23 comments

Implementing the Parametric ReLU (PReLU) from the well-known paper "Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification" (ICCV 2015). This function introduces a learnable parameter "a", which lets the activation adapt during training, potentially improving model accuracy and convergence compared to the standard ReLU or Leaky ReLU.

It is defined as: f(x) = x if x >= 0, and f(x) = a·x if x < 0.

arXiv link: https://arxiv.org/abs/1502.01852

MythicArrow avatar May 31 '25 20:05 MythicArrow
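(For reference, a minimal sketch of what such a function could look like in JAX — illustrative only, not the code from the PR:)

```python
import jax.numpy as jnp

def prelu(x, a):
    """Parametric ReLU: returns x where x >= 0, and a * x otherwise."""
    return jnp.where(x >= 0, x, a * x)

print(prelu(jnp.array([-2.0, 3.0]), 0.25))  # [-0.5  3. ]
```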

Could you please review my implementation and confirm that it is correct?

MythicArrow avatar May 31 '25 20:05 MythicArrow

Hi - thanks for the contribution! It looks like this would be a good contribution to jax.nn, but there are a number of changes we'd have to make: mainly, the a value should be an explicit parameter of the function; otherwise we wouldn't be able to differentiate with respect to it. With this in mind, the function should not take init or rng or num_parameters as arguments. Also, the implementation should probably be modeled after that of the existing relu. We would also need to add tests for the new function, in tests/nn_test.py. Is this something you'd like to work on?

jakevdp avatar Jun 02 '25 18:06 jakevdp

Yes of course sir, I would.

MythicArrow avatar Jun 02 '25 21:06 MythicArrow

So would you be willing to give me some time to work on it? Currently, I am a bit busy.

MythicArrow avatar Jun 02 '25 21:06 MythicArrow

Would like to hear your feedback!

MythicArrow avatar Jun 02 '25 21:06 MythicArrow

I will also add tests later.

MythicArrow avatar Jun 02 '25 21:06 MythicArrow

OK, the PR has been put together.

MythicArrow avatar Jun 03 '25 20:06 MythicArrow

Hello sir, have you been busy recently?

MythicArrow avatar Jun 08 '25 11:06 MythicArrow

@MythicArrow It seems that prelu is equivalent to leaky_relu?

DanisNone avatar Jun 10 '25 19:06 DanisNone

Yeah, it looks similar, but there is a significant difference between them. Leaky ReLU uses "a" as a fixed constant, whereas PReLU's "a" is a learnable slope parameter that is trained through backpropagation.

MythicArrow avatar Jun 10 '25 19:06 MythicArrow

You can pass a JAX array as the second argument to leaky_relu, and JAX will have no issues computing gradients through it.

DanisNone avatar Jun 10 '25 19:06 DanisNone
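(A quick sketch to verify that claim — the example values are assumptions, not taken from the thread:)

```python
import jax
import jax.numpy as jnp

x = jnp.array([-2.0, -1.0, 3.0])

# Differentiate a scalar reduction of leaky_relu with respect to the slope argument.
grad_wrt_slope = jax.grad(lambda a: jnp.sum(jax.nn.leaky_relu(x, a)))(0.25)
print(grad_wrt_slope)  # -3.0, i.e. the sum of the negative inputs
```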

Yeah indeed @DanisNone – it looks like this is the same as leaky_relu. Sorry @MythicArrow, I should have noticed that earlier.

jakevdp avatar Jun 10 '25 20:06 jakevdp

I think given that, this PR can probably be closed.

jakevdp avatar Jun 10 '25 20:06 jakevdp

But PReLU's "a" is learnable, not a constant.

MythicArrow avatar Jun 10 '25 20:06 MythicArrow

I can make it trainable through backpropagation.

MythicArrow avatar Jun 10 '25 20:06 MythicArrow

But PReLU's "a" is learnable, not a constant.

That is true of leaky_relu as well.

jakevdp avatar Jun 10 '25 20:06 jakevdp

But PReLU's "a" is learnable, not a constant.

That is true of leaky_relu as well.

Oh ok

MythicArrow avatar Jun 10 '25 20:06 MythicArrow

Yeah indeed @DanisNone – it looks like this is the same as leaky_relu. Sorry @MythicArrow, I should have noticed that earlier.

No problem sir.

MythicArrow avatar Jun 10 '25 20:06 MythicArrow

OK, then I will close this PR.

MythicArrow avatar Jun 10 '25 20:06 MythicArrow

I checked Leaky ReLU's definition and see that its "a" is a fixed value, unlike the learnable one in PReLU. Would you merge the PR if I edited the implementation so that "a" is learnable and trainable through backpropagation, rather than fixed as in Leaky ReLU?

MythicArrow avatar Jun 15 '25 22:06 MythicArrow

You mention LeakyReLu – are you talking about the function in the stax example library? https://github.com/jax-ml/jax/blob/9678a764e83b5fab3a877348c6ccc503d9808145/jax/example_libraries/stax.py#L158 If so, then we likely wouldn't accept such a contribution, because we are not adding new features to stax.

If you're thinking of jax.nn.leaky_relu, then I can't say I understand your request here. The negative_slope is a parameter just like any other, and you can take the gradient with respect to it, which means its value can be optimized given an appropriate loss function. That makes me think it qualifies as a "learnable parameter" – is there something more you're looking for?

jakevdp avatar Jun 16 '25 04:06 jakevdp
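(As a toy illustration of that point — with an assumed target and learning rate, not part of the thread — negative_slope can be fit by plain gradient descent on a loss:)

```python
import jax
import jax.numpy as jnp

x = jnp.linspace(-3.0, 3.0, 16)
target = jax.nn.leaky_relu(x, 0.1)       # pretend the "true" slope is 0.1

def loss(a):
    return jnp.mean((jax.nn.leaky_relu(x, a) - target) ** 2)

a = 0.5                                   # start away from the true value
for _ in range(100):
    a = a - 0.2 * jax.grad(loss)(a)       # gradient descent on the slope itself
print(a)                                  # converges toward 0.1
```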

I checked the original Leaky ReLU paper, and its "a" was a fixed constant like 0.01 and wasn't learned. But JAX's leaky_relu function behaves like PReLU in this respect, so I think there is no need to implement PReLU.

MythicArrow avatar Jun 18 '25 11:06 MythicArrow

Sir, I have researched PReLU and found that it indeed has a learnable parameter that is included in the model's weights. So if I define it as a parameter, it will be learnable automatically, unlike in Leaky ReLU. Does JAX allow defining 'a' as a model parameter?

MythicArrow avatar Dec 11 '25 08:12 MythicArrow
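(A hedged sketch of one way to do that — the tiny model, names, and data below are illustrative assumptions, not from the thread. Any array stored as a leaf of the params pytree receives a gradient from jax.grad, so it can be updated by an optimizer like any other weight:)

```python
import jax
import jax.numpy as jnp

# Hypothetical tiny model: a linear layer followed by leaky_relu, whose slope
# "a" lives in the same params pytree as the weights.
params = {
    "w": jnp.ones((3, 2)),
    "a": jnp.array(0.25),   # the PReLU-style slope, stored as a model parameter
}

def forward(params, x):
    return jax.nn.leaky_relu(x @ params["w"], params["a"])

def loss(params, x, y):
    return jnp.mean((forward(params, x) - y) ** 2)

x = jnp.array([[1.0, -2.0, 0.5]])
y = jnp.zeros((1, 2))

grads = jax.grad(loss)(params, x, y)
print(grads["a"])   # "a" gets a gradient, so an optimizer can update it like any weight
```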