torchdrug

Un-standardization of predict method in PropertyPrediction

Status: Open. kanojikajino opened this issue 2 years ago • 2 comments

As far as I understand, the PropertyPrediction task trains its predictor against standardized target values (for example, at the linked line the loss is computed against the standardized targets). Consequently, the predict method of PropertyPrediction returns predictions on the standardized scale, i.e. the scale on which the training targets have mean 0 and std 1. I find this a bit confusing for users: the output of predict has to be rescaled by the user after calling predict. That step is easy to forget, and forgetting it leaves predictions on the wrong scale and degrades predictive performance.
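
To illustrate the manual step described above, here is a minimal standalone sketch in plain Python (no TorchDrug). All names (standardize, unstandardize, mae) and all numbers are hypothetical. It assumes the task stores the training-set mean and std of the targets, and that the model's raw outputs live on the standardized scale.

```python
from statistics import mean, pstdev

# Hypothetical raw training targets; a task would compute these statistics
# once from the training split and keep them alongside the model.
train_targets = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]
target_mean = mean(train_targets)   # 5.0
target_std = pstdev(train_targets)  # 2.0

def standardize(y):
    """Raw target -> scale the loss is computed on (mean 0, std 1 on train)."""
    return (y - target_mean) / target_std

def unstandardize(z):
    """Standardized model output -> original target scale."""
    return z * target_std + target_mean

def mae(preds, targets):
    """Mean absolute error between two equal-length sequences."""
    return mean(abs(p - t) for p, t in zip(preds, targets))

# Hypothetical standardized outputs from predict() and the matching raw targets.
standardized_preds = [-1.5, 0.0, 1.0]
test_targets = [2.0, 5.0, 7.0]

# Forgetting the rescaling step compares values on two different scales.
mae_forgotten = mae(standardized_preds, test_targets)
mae_rescaled = mae([unstandardize(z) for z in standardized_preds], test_targets)
print(f"MAE without rescaling: {mae_forgotten:.3f}")  # 4.833
print(f"MAE after rescaling:   {mae_rescaled:.3f}")   # 0.000
```

The predictions themselves are identical in both cases. Only the forgotten call to the rescaling function turns a perfect fit into a large apparent error.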

So I would like to ask:

  1. whether my understanding of predict is correct,
  2. why the output of predict has to be standardized, and
  3. if there is no strong reason for it, whether I may change predict so that its output is non-standardized (on the original target scale); this would be a breaking API change.

Posted by kanojikajino on Jul 06 '22 02:07

Hi! Thanks for your advice!

Yes, your understanding of predict is correct. The output of predict was originally designed to be consumed by the forward function, where we define the loss function. However, we agree that following your suggestion and making the output of predict non-standardized would be more user-friendly.
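
The change agreed on here keeps the loss on the standardized scale and rescales only the user-facing output. A minimal sketch of that split, in plain Python, follows. ToyPropertyTask, raw_output and the linear "model" are hypothetical stand-ins, not TorchDrug's actual classes or the code in the eventual pull request.

```python
from statistics import mean, pstdev

class ToyPropertyTask:
    """Hypothetical stand-in for a property-prediction task (not TorchDrug's API)."""

    def __init__(self, train_targets, weight=1.0, bias=0.0):
        # Target statistics taken from the training split only.
        self.mean = mean(train_targets)
        self.std = pstdev(train_targets)
        # A trivial linear "model" whose output lives on the standardized scale.
        self.weight = weight
        self.bias = bias

    def raw_output(self, x):
        """Model output on the standardized scale, used internally for the loss."""
        return self.weight * x + self.bias

    def loss(self, xs, targets):
        """MSE against standardized targets, so training behaviour is unchanged."""
        errs = [(self.raw_output(x) - (t - self.mean) / self.std) ** 2
                for x, t in zip(xs, targets)]
        return mean(errs)

    def predict(self, xs):
        """User-facing predictions, rescaled to the original target scale."""
        return [self.raw_output(x) * self.std + self.mean for x in xs]

task = ToyPropertyTask([2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0], weight=0.5)
print(task.predict([2.0, 0.0]))  # [7.0, 5.0]
print(task.loss([2.0], [7.0]))   # 0.0
```

Because the loss still uses the standardized scale, optimization is unaffected. The breaking part is purely in predict: any user code that already rescaled its output manually would now rescale twice.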

Feel free to open a pull request for this issue! We will include this change in the next release.

Posted by Oxer11 on Jul 13 '22 01:07

Hi, I have opened pull request #111 to fix this issue. I would appreciate it if you could review it and merge it if appropriate.

Posted by kanojikajino on Jul 14 '22 12:07