functorch

Fixed issue with total_weight in nll_loss_forward_decomposition

Status: Open. vfdev-5 opened this issue 3 years ago • 2 comments

Description:

@Chillee caught that the total_weight output of the C++ implementation of nll_loss_forward_decomposition is wrong:

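```python
import torch
from torch import tensor
from functorch import vmap

aten = torch.ops.aten

# Positional args of aten.nll_loss_forward:
# (input log-probabilities [N, C], target [N], weight, reduction, ignore_index)
# reduction: 0 = none, 1 = mean, 2 = sum
args = (
    tensor([[-4.8270, -7.5824, -0.6047], [-1.5412, -1.9719, -4.1460]], dtype=torch.float64, requires_grad=True),
    tensor([1, 1]),
    None,
    0,
    -100
)
# Reference: call the op directly; it returns (output, total_weight)
ref_out = aten.nll_loss_forward(*args)

# vmap over a batch of size 1 so the call goes through functorch's
# C++ batching rule (nll_loss_forward_decomposition)
decomp_out = vmap(
    aten.nll_loss_forward.default,
    in_dims=(0, None, None, None, None)
)(args[0].unsqueeze(0), args[1], args[2], args[3], args[4])

# Compare output and total_weight against the unbatched reference
torch.testing.assert_close(ref_out[0].unsqueeze(0), decomp_out[0])
torch.testing.assert_close(ref_out[1].unsqueeze(0), decomp_out[1])
```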
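To make the expected total_weight concrete, here is a minimal pure-Python sketch of what nll_loss_forward computes for a 2D input, with no torch dependency. `nll_loss_forward_ref` is a hypothetical helper, not a PyTorch API. It assumes that total_weight is the sum of class weights over non-ignored targets (the count of non-ignored targets when weight is None). It also assumes that total_weight is reported as 0 when reduction is 0 ("none"), which appears to match PyTorch's own Python decomposition. Treat that last point as an assumption to check against aten.nll_loss_forward.

```python
import math
from typing import List, Optional, Tuple, Union

# Integer reduction codes accepted by aten.nll_loss_forward
NONE, MEAN, SUM = 0, 1, 2


def nll_loss_forward_ref(
    log_probs: List[List[float]],
    target: List[int],
    weight: Optional[List[float]],
    reduction: int,
    ignore_index: int,
) -> Tuple[Union[List[float], float], float]:
    """Hypothetical pure-Python reference for a 2D (N x C) input.

    Returns (output, total_weight). ASSUMPTION: with reduction NONE,
    total_weight is reported as 0.0 rather than the sum of weights.
    """
    num_classes = len(log_probs[0])
    # No class weights means every class has weight 1.0
    w = weight if weight is not None else [1.0] * num_classes
    losses: List[float] = []
    total_weight = 0.0
    for row, t in zip(log_probs, target):
        if t == ignore_index:
            # Ignored samples contribute zero loss and zero weight
            losses.append(0.0)
            continue
        losses.append(-w[t] * row[t])
        total_weight += w[t]
    if reduction == NONE:
        return losses, 0.0
    loss_sum = sum(losses)
    if reduction == SUM:
        return loss_sum, total_weight
    # MEAN divides by total_weight; nan when every target is ignored
    mean = loss_sum / total_weight if total_weight != 0 else math.nan
    return mean, total_weight
```

On the repro input above (target [1, 1], no weight), this sketch gives total_weight 0.0 for reduction 0 and 2.0 for reductions 1 and 2.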

Before this PR, the second assert_close (on total_weight) fails:

```
Traceback (most recent call last):
  File "/tmp/fth/repro_nll_loss_issue.py", line 73, in <module>
    torch.testing.assert_close(ref_out[1].unsqueeze(0), decomp_out[1])
  File "/usr/local/lib/python3.8/dist-packages/torch/testing/_comparison.py", line 1317, in assert_close
    assert_equal(
  File "/usr/local/lib/python3.8/dist-packages/torch/testing/_comparison.py", line 1086, in assert_equal
    raise error_metas[0].to_error()
AssertionError: Tensor-likes are not close!

Mismatched elements: 1 / 1 (100.0%)
Greatest absolute difference: 2.0 at index (0,) (up to 1e-07 allowed)
Greatest relative difference: 1.0 at index (0,) (up to 1e-07 allowed)
```

This PR was tested with the following script (there is currently no way to check total_weight on CI):

```python
import torch
from torch import tensor
from functorch import vmap

aten = torch.ops.aten

# Sweep ignore_index and reduction (0 = none, 1 = mean, 2 = sum)
for ignore_value in [-100, 0, 2]:
    for reduction in [0, 1, 2]:
        args = (
            tensor([[-4.8270, -7.5824, -0.6047], [-1.5412, -1.9719, -4.1460]], dtype=torch.float64, requires_grad=True),
            tensor([1, 1]),
            None,
            reduction,
            ignore_value
        )
        ref_out = aten.nll_loss_forward(*args)

        decomp_out = vmap(
            aten.nll_loss_forward.default,
            in_dims=(0, None, None, None, None)
        )(args[0].unsqueeze(0), args[1], args[2], args[3], args[4])

        print(ignore_value, reduction, ref_out, decomp_out)
        # Check both the loss output and total_weight
        torch.testing.assert_close(ref_out[0].unsqueeze(0), decomp_out[0])
        torch.testing.assert_close(ref_out[1].unsqueeze(0), decomp_out[1])
```
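In the sweep above, target is [1, 1] in every iteration. None of the ignore values -100, 0 or 2 matches a target, so the ignored-sample path is never exercised. Here is a quick plain-Python check of which ignore values would actually mask a target (`masked_count` is a hypothetical helper for illustration):

```python
from typing import List


def masked_count(target: List[int], ignore_index: int) -> int:
    """Count how many targets would be skipped for this ignore_index."""
    return sum(1 for t in target if t == ignore_index)


target = [1, 1]
for ignore_value in [-100, 0, 2, 1]:
    # Only ignore_value == 1 matches the targets used in the sweep
    print(ignore_value, masked_count(target, ignore_value))
```

Adding an ignore_index of 1, or mixing targets (e.g. [1, 0]), to the sweep would cover the cases where some or all targets are ignored.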

vfdev-5, May 22 '22 21:05

@Chillee Sure, as we discussed elsewhere, I'll be adding the nll_loss_forward decomposition to PyTorch core as my next task.

EDIT: While coding the decomposition in PyTorch core, I found that this code is still incorrect for other cases. Marking as draft; I will probably close it later.

vfdev-5, May 27 '22 10:05

> While coding the decomposition in PyTorch core, I found that this code is still incorrect for other cases. Marking as draft; I will probably close it later.

Haha, I've run into this a couple times when porting ops into Python :)

Chillee, May 31 '22 18:05