torchdrug
Un-standardization of predict method in PropertyPrediction
As far as I understand, the PropertyPrediction task trains a predictor against the standardized target values (for example, the loss is computed against the standardized target values at this line). Therefore, the predict method in PropertyPrediction is designed to output the standardized prediction, whose mean is 0 and standard deviation is 1. I find this a bit confusing for users, because the output of predict has to be un-standardized by the user (multiplied by the standard deviation and shifted by the mean) after calling predict. This step is easy to forget, and forgetting it leads to poor predictive performance.
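A minimal sketch of the manual step users currently have to remember, in plain Python. The helper names (fit_standardizer, standardize, unstandardize) are hypothetical and not torchdrug code; the sketch also assumes the population standard deviation, which may differ from how the library computes std.

```python
from statistics import mean, pstdev


def fit_standardizer(targets):
    """Return the (mean, std) pair used to standardize training targets."""
    return mean(targets), pstdev(targets)


def standardize(values, mu, sigma):
    """Map raw target values to the standardized scale (mean 0, std 1)."""
    return [(v - mu) / sigma for v in values]


def unstandardize(values, mu, sigma):
    """Map standardized model outputs back to the original target scale."""
    return [v * sigma + mu for v in values]


train_targets = [1.0, 2.0, 3.0, 4.0]
mu, sigma = fit_standardizer(train_targets)
z = standardize(train_targets, mu, sigma)

# A model trained against z outputs values on the standardized scale.
# For illustration, assume it reproduces z exactly; the user must then
# remember to un-standardize before comparing with real targets.
raw = unstandardize(z, mu, sigma)
```

Skipping the final unstandardize call leaves every prediction centered at 0 with unit spread, regardless of the units of the actual property.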
To this end, I would like to ask:
- whether my understanding of predict is correct,
- why the output of predict has to be standardized, and
- if there is no solid reason for it, whether I can change it so that predict returns non-standardized output, which would involve breaking API changes.
Hi! Thanks for your advice!
Yes, your understanding of predict is correct. The output of predict was originally designed for the forward function, where we define the loss function. However, we agree that following your suggestion to make the output of predict non-standardized would be more user-friendly.
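A minimal sketch of the split being discussed: the loss stays in standardized space, while predict un-standardizes internally. ToyPropertyTask, its method names, and the toy model are hypothetical stand-ins written in plain Python; this is not the torchdrug API or the change made in the pull request.

```python
from statistics import mean, pstdev


class ToyPropertyTask:
    """Hypothetical property-prediction task (not the torchdrug API)."""

    def __init__(self, model, train_targets):
        self.model = model
        # Statistics of the training targets, used for (un-)standardization.
        self.mean = mean(train_targets)
        self.std = pstdev(train_targets)

    def _standardized_output(self, x):
        # Model output on the standardized scale (mean 0, std 1 on training data).
        return self.model(x)

    def loss(self, x, target):
        # Training loss is still computed against the standardized target.
        z_target = (target - self.mean) / self.std
        return (self._standardized_output(x) - z_target) ** 2

    def predict(self, x):
        # Un-standardize here, so callers get values in the target's own units.
        return self._standardized_output(x) * self.std + self.mean


targets = [1.0, 2.0, 3.0, 4.0]
mu, sigma = mean(targets), pstdev(targets)


def perfect_model(x):
    # Toy model that is exact in standardized space.
    return (x - mu) / sigma


task = ToyPropertyTask(perfect_model, targets)
print(task.predict(3.0))  # approximately 3.0, on the original scale
```

Keeping the loss in standardized space preserves the training behavior, while moving the un-standardization into predict removes the step users could forget.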
Feel free to open a pull request for this issue! We will include this change in the next release.
Hi, I created pull request #111 to fix the issue. I would appreciate it if you could review it and merge it if appropriate.