Fixed issue with total_weight in nll_loss_forward_decomposition
Description:
@Chillee caught that the `total_weight` output is wrong in the `nll_loss_forward_decomposition` C++ implementation:
```python
import torch
from torch import tensor
from torch._decomp import decomposition_table
from functorch import vmap

aten = torch.ops.aten

args = (
    tensor([[-4.8270, -7.5824, -0.6047], [-1.5412, -1.9719, -4.1460]], dtype=torch.float64, requires_grad=True),
    tensor([1, 1]),
    None,
    0,
    -100,
)
ref_out = aten.nll_loss_forward(*args)
decomp_out = vmap(
    aten.nll_loss_forward.default,
    in_dims=(0, None, None, None, None),
)(args[0].unsqueeze(0), args[1], args[2], args[3], args[4])

torch.testing.assert_close(ref_out[0].unsqueeze(0), decomp_out[0])
torch.testing.assert_close(ref_out[1].unsqueeze(0), decomp_out[1])
```
Before this PR:
```
Traceback (most recent call last):
  File "/tmp/fth/repro_nll_loss_issue.py", line 73, in <module>
    torch.testing.assert_close(ref_out[1].unsqueeze(0), decomp_out[1])
  File "/usr/local/lib/python3.8/dist-packages/torch/testing/_comparison.py", line 1317, in assert_close
    assert_equal(
  File "/usr/local/lib/python3.8/dist-packages/torch/testing/_comparison.py", line 1086, in assert_equal
    raise error_metas[0].to_error()
AssertionError: Tensor-likes are not close!

Mismatched elements: 1 / 1 (100.0%)
Greatest absolute difference: 2.0 at index (0,) (up to 1e-07 allowed)
Greatest relative difference: 1.0 at index (0,) (up to 1e-07 allowed)
```
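For context on what the decomposition should produce: with `weight=None` and a mean or sum reduction, the eager `aten.nll_loss_forward` kernel returns a `total_weight` equal to the number of non-ignored targets. A minimal sketch checking this against eager mode, using the same inputs as the repro with `reduction=1` (mean); the `expected` computation here is my own illustration, not code from this PR:

```python
import torch
from torch import tensor

aten = torch.ops.aten

x = tensor([[-4.8270, -7.5824, -0.6047], [-1.5412, -1.9719, -4.1460]], dtype=torch.float64)
target = tensor([1, 1])
ignore_index = -100
reduction = 1  # mean

out, total_weight = aten.nll_loss_forward(x, target, None, reduction, ignore_index)

# With weight=None, total_weight should be the count of targets
# that are not equal to ignore_index, cast to the input dtype.
expected = (target != ignore_index).sum().to(x.dtype)
assert torch.equal(total_weight, expected)
```

The 2.0 absolute difference in the failure above matches this: the decomposition returned 0 where eager mode returns the non-ignored count of 2.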
The PR is tested with the following snippet (as there is currently no way to check `total_weight` on CI):
```python
for ignore_value in [-100, 0, 2]:
    for reduction in [0, 1, 2]:
        args = (
            tensor([[-4.8270, -7.5824, -0.6047], [-1.5412, -1.9719, -4.1460]], dtype=torch.float64, requires_grad=True),
            tensor([1, 1]),
            None,
            reduction,
            ignore_value,
        )
        ref_out = aten.nll_loss_forward(*args)
        decomp_out = vmap(
            aten.nll_loss_forward.default,
            in_dims=(0, None, None, None, None),
        )(args[0].unsqueeze(0), args[1], args[2], args[3], args[4])
        print(ignore_value, reduction, ref_out, decomp_out)
        torch.testing.assert_close(ref_out[0].unsqueeze(0), decomp_out[0])
        torch.testing.assert_close(ref_out[1].unsqueeze(0), decomp_out[1])
```
@Chillee sure, as we discussed elsewhere, I'll be adding the `nll_loss_forward` decomposition to PyTorch core as my next task.
EDIT: while coding the decomposition in PyTorch core, it looks like this code is still incorrect for other cases. Marking as draft; I'll probably close it later.
Haha, I've run into this a couple of times when porting ops into Python :)