spacecutter
Docs error?
Dear Torch friends
Perhaps I am missing something, but I think the docstring of the "_reduction" function is not correct. Should it be the following?
import torch


def _reduction(loss: torch.Tensor, reduction: str) -> torch.Tensor:
    """
    Reduce loss

    Parameters
    ----------
    loss : torch.Tensor, [batch_size, 1]
        Batch losses.
    reduction : str
        Method for reducing the loss. Options include 'elementwise_mean',
        'none', and 'sum'.

    Returns
    -------
    loss : torch.Tensor
        Reduced loss.
    """
    if reduction == 'elementwise_mean':
        return loss.mean()
    elif reduction == 'none':
        return loss
    elif reduction == 'sum':
        return loss.sum()
    else:
        raise ValueError(f'{reduction} is not a valid reduction')
Yup, you're totally right. I should fix that...