tianshou

How to use a custom loss?

Open. duburcqa opened this issue 4 years ago · 8 comments

I would like to add the following extra term to the loss function: $\lVert y_{pred} - y_{ref} \rVert_2^2$, where $y_{pred}$ is the action sampled from the distribution and $y_{ref}$ can be computed by the actor.

What is the best way to do this using your framework? The point is to take advantage of the analytical gradient computation.

The only way I can think of is to override the whole `learn` method of the policy (here, PPO), but that feels inconvenient just to add one extra line of code...
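For illustration, here is a minimal PyTorch sketch of the extra term, assuming a Gaussian policy represented by `torch.distributions.Normal`. The function name `action_regularizer` and the `coef` weight are hypothetical and not part of Tianshou. Sampling with `rsample()` (the reparameterization trick) rather than `sample()` is what lets autograd propagate the gradient of the term back to the distribution parameters.

```python
import torch
from torch.distributions import Normal


def action_regularizer(dist: Normal, y_ref: torch.Tensor, coef: float = 1.0) -> torch.Tensor:
    """Hypothetical helper: coef * ||y_pred - y_ref||_2^2, averaged over the batch."""
    # rsample() draws y_pred = loc + eps * scale, so gradients flow to loc and scale.
    y_pred = dist.rsample()
    # Squared L2 norm over the action dimension, then mean over the batch.
    return coef * (y_pred - y_ref).pow(2).sum(dim=-1).mean()


if __name__ == "__main__":
    mu = torch.zeros(4, 2, requires_grad=True)  # stand-in for the actor's output
    dist = Normal(mu, torch.full((4, 2), 0.5))
    reg = action_regularizer(dist, torch.zeros(4, 2))
    reg.backward()
    print(reg.item(), mu.grad.shape)
```

If $y_{ref}$ is itself produced by the actor, one design decision is whether it should receive gradients too; calling `.detach()` on it treats it as a fixed target.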

Thank you in advance,

Best, Alexis

duburcqa, May 28 '20 15:05

That's a good question. Currently, you can either inherit from the policy class (as you mentioned) or modify the framework's code directly to meet your needs.

This can be discussed further. Some existing frameworks (like RLlib) modularize the loss function. In my opinion, however, this could make further development inconvenient: since the loss function is highly customizable, abstracting it would double the code complexity.

Trinkle23897, May 29 '20 00:05

OK, I get your point and I agree with you. But what about adding a `loss_fn` to the abstract base class for policies that does nothing by default but can be overridden by the user? I really don't like the idea of having to override `learn` itself, since that is a major source of errors.

In this case, it is not a custom loss strictly speaking, but rather an additional component of the original loss function (a regularization term) that may depend on the actor. So it only requires one extra function call before calling `backward`. I don't know whether doing so is usual or not.
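The hook proposed here could look roughly like the following sketch. Everything in it is hypothetical (`PolicySketch`, `extra_loss` and `ref_fn` are illustrative names, not Tianshou's API), and `base_loss` is a toy placeholder standing in for the algorithm's real objective, such as the PPO surrogate. Because the default hook returns zero, `learn` behaves as before unless a subclass overrides it.

```python
import torch
from torch import nn
from torch.distributions import Normal


class PolicySketch:
    """Hypothetical stand-in for a policy base class (not Tianshou's API)."""

    def __init__(self, actor: nn.Module, lr: float = 1e-3) -> None:
        self.actor = actor
        self.optim = torch.optim.SGD(actor.parameters(), lr=lr)

    def base_loss(self, obs: torch.Tensor, dist: Normal) -> torch.Tensor:
        # Toy placeholder for the algorithm's own loss (e.g. a PPO objective).
        return dist.mean.pow(2).mean()

    def extra_loss(self, obs: torch.Tensor, dist: Normal) -> torch.Tensor:
        # Hook: contributes nothing by default; subclasses may override it.
        return torch.zeros(())

    def learn(self, obs: torch.Tensor) -> float:
        mean = self.actor(obs)
        dist = Normal(mean, torch.ones_like(mean))
        # The only change to the update step: one extra call before backward().
        loss = self.base_loss(obs, dist) + self.extra_loss(obs, dist)
        self.optim.zero_grad()
        loss.backward()
        self.optim.step()
        return loss.item()


class RegularizedPolicy(PolicySketch):
    """Adds coef * ||y_pred - y_ref||_2^2 without touching learn()."""

    def __init__(self, actor: nn.Module, ref_fn, coef: float = 0.1, lr: float = 1e-3) -> None:
        super().__init__(actor, lr)
        self.ref_fn = ref_fn  # hypothetical callable computing y_ref from observations
        self.coef = coef

    def extra_loss(self, obs: torch.Tensor, dist: Normal) -> torch.Tensor:
        y_pred = dist.rsample()  # reparameterized, so gradients reach the actor
        y_ref = self.ref_fn(obs).detach()  # treat y_ref as a fixed target
        return self.coef * (y_pred - y_ref).pow(2).sum(dim=-1).mean()


if __name__ == "__main__":
    torch.manual_seed(0)
    policy = RegularizedPolicy(nn.Linear(3, 2), ref_fn=lambda o: torch.zeros(o.shape[0], 2))
    print(policy.learn(torch.randn(8, 3)))
```

Passing both the observations and the action distribution to the hook is one possible answer to what its input should be; the distribution is needed for reparameterized sampling.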

duburcqa, May 29 '20 06:05

@Trinkle23897 Up!

duburcqa, Jun 30 '20 08:06

> @Trinkle23897 Up!

Between #106 and everything else, I have no time before this Friday. Many things to do.

Trinkle23897, Jun 30 '20 08:06

No problem! I can do it! But what do you think about the idea?

duburcqa, Jun 30 '20 08:06

I think adding `loss_fn` is okay, but what would its input be?

Trinkle23897, Jun 30 '20 10:06

@duburcqa It's a great idea to make it easier with a customized loss. I wondered if you have made any progress on that. Thanks!

oldcricket, Sep 10 '21 19:09

The loss is an integral part of the algorithm, so inheriting and overriding may be better than allowing users to pass custom losses. It's a central design question. I don't see it as necessary for the 1.0.0 release, but I would keep the issue open.

MischaPanch, Oct 14 '23 14:10