[breaking] force logit output of focal loss to conform to (N, 2) shape for binary classification
Issue #, if available:
https://github.com/awslabs/graphstorm/issues/1242
Description of changes:
By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.
The key problem with setting decoder_output_dim=2 is that it changes the implementation of the decoder layer. Specifically, its parameter size will be 2x larger, and the parameters are affected only through the [:, 1] output.
Theoretically, the effective parameters of both decoder implementations are identical. The difference is that, with the new implementation, the output [:, 0] is meaningless. @thvasilo Can you set [:, 0] to all zeros when returning it for inference?
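To make the point above concrete, here is a minimal PyTorch sketch, not the GraphStorm code. It assumes the focal loss reads only `logits[:, 1]`, which is my reading of the comment above. `hidden_dim`, `dec_1`, `dec_2`, `sigmoid_focal_loss_on_col1`, and `zero_negative_column` are hypothetical names used only for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
hidden_dim = 8  # hypothetical hidden size

# Stand-ins for the two decoder variants (plain linear layers, an assumption).
dec_1 = nn.Linear(hidden_dim, 1)  # single-logit decoder
dec_2 = nn.Linear(hidden_dim, 2)  # (N, 2) decoder, i.e. decoder_output_dim=2

n_params_1 = sum(p.numel() for p in dec_1.parameters())
n_params_2 = sum(p.numel() for p in dec_2.parameters())  # twice n_params_1


def sigmoid_focal_loss_on_col1(logits, target, alpha=0.25, gamma=2.0):
    """Sigmoid focal loss computed from column 1 only (assumed behavior)."""
    z = logits[:, 1]
    p = torch.sigmoid(z)
    bce = nn.functional.binary_cross_entropy_with_logits(z, target, reduction="none")
    p_t = p * target + (1 - p) * (1 - target)
    alpha_t = alpha * target + (1 - alpha) * (1 - target)
    return (alpha_t * (1 - p_t) ** gamma * bce).mean()


x = torch.randn(4, hidden_dim)
y = torch.tensor([0.0, 1.0, 1.0, 0.0])
logits = dec_2(x)
sigmoid_focal_loss_on_col1(logits, y).backward()
# Row 0 of dec_2.weight.grad (and dec_2.bias.grad[0]) is all zeros here:
# the parameters behind [:, 0] never receive a training signal.


def zero_negative_column(logits):
    """Return a copy of the logits with the unused [:, 0] column zeroed."""
    out = logits.clone()
    out[:, 0] = 0.0
    return out


inference_logits = zero_negative_column(logits.detach())
```

Under that assumption, half of the larger decoder's parameters stay at their initial values, so zeroing `[:, 0]` at inference only removes untrained output and leaves `[:, 1]` unchanged.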