graphstorm
Clarify/make consistent expected prediction output shape for different loss functions
For classification tasks, cross-entropy loss produces predictions with as many columns as there are classes. For binary tasks the output therefore has 2 columns, where col2 = 1 - col1.
Focal loss, by contrast, produces a single column holding the positive-class score, which carries the same information while being more space-efficient.
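Because col2 = 1 - col1 in the binary case, the two layouts are interchangeable. The minimal sketch below converts between them, assuming predictions are plain lists of rows with the positive class in the second column; the helper names two_col_to_positive and positive_to_two_col are hypothetical and not part of graphstorm.

```python
from typing import List


def two_col_to_positive(probs: List[List[float]]) -> List[float]:
    """Reduce a 2-column (cross-entropy style) prediction to the positive-class column.

    Hypothetical helper: assumes column index 1 holds the positive class.
    """
    for row in probs:
        if len(row) != 2:
            raise ValueError(f"expected 2 columns, got {len(row)}")
    return [row[1] for row in probs]


def positive_to_two_col(pos: List[float]) -> List[List[float]]:
    """Rebuild the 2-column layout from positive scores, using col1 = 1 - col2."""
    return [[1.0 - p, p] for p in pos]


if __name__ == "__main__":
    ce_preds = [[0.25, 0.75], [0.875, 0.125]]
    focal_style = two_col_to_positive(ce_preds)
    print(focal_style)                       # [0.75, 0.125]
    print(positive_to_two_col(focal_style))  # [[0.25, 0.75], [0.875, 0.125]]
```

Keeping only the positive column halves the storage for binary outputs, which is the space saving mentioned above; the round trip shows no information is lost.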
We need to document this for users and possibly enforce consistent behavior for num_classes.
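One way to enforce a consistent rule is to derive the expected number of prediction columns from the loss function and num_classes, then validate outputs against it. This is a hypothetical sketch: the loss names "cross_entropy" and "focal", the helper names, and the assumption that focal loss is binary-only are illustrative choices, not graphstorm's actual configuration or API.

```python
from typing import Sequence


def expected_pred_columns(loss_func: str, num_classes: int) -> int:
    """Hypothetical rule for how many prediction columns a loss should emit."""
    if num_classes < 2:
        raise ValueError("classification needs num_classes >= 2")
    if loss_func == "cross_entropy":
        # One column per class, e.g. 2 columns for a binary task.
        return num_classes
    if loss_func == "focal":
        # Assumption: focal loss is used for binary tasks only and emits
        # a single positive-class score per sample.
        if num_classes != 2:
            raise ValueError("focal loss is assumed to be binary-only here")
        return 1
    raise ValueError(f"unknown loss_func: {loss_func!r}")


def check_pred_shape(preds: Sequence[Sequence[float]], loss_func: str, num_classes: int) -> None:
    """Raise ValueError if any prediction row has the wrong number of columns."""
    expected = expected_pred_columns(loss_func, num_classes)
    for i, row in enumerate(preds):
        if len(row) != expected:
            raise ValueError(
                f"row {i}: got {len(row)} columns, expected {expected} "
                f"for loss_func={loss_func!r}, num_classes={num_classes}"
            )


if __name__ == "__main__":
    check_pred_shape([[0.25, 0.75]], "cross_entropy", 2)  # passes
    check_pred_shape([[0.75]], "focal", 2)                # passes
```

Centralizing the rule in one function gives users a single place to look up which output shape to expect for a given loss.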
Re-opening so that the PR can mark this issue closed for good.