max_pool2d: Shape propagation error for input dimensions used by ResNet
🐛 Bug
import torch
import torch.nn.functional as F
import thunder
a = torch.randn(1, 64, 112, 112).cuda().requires_grad_()
def func(a):
    return F.max_pool2d(a, 3, 2, 1, 1, False, False)  # t79: "cuda:0 f32[1, 64, 56, 56]"
cfunc = thunder.jit(func)
b = cfunc(a)
print(thunder.last_traces(cfunc)[-1].output[0]['output'].shape)
print(b.shape)
Outputs:
(56, 56)
torch.Size([1, 64, 56, 56])
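The shape recorded in the trace for the output is wrong: it is (56, 56) instead of (1, 64, 56, 56). The function still runs successfully and returns a tensor of the correct shape.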
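For reference, the PyTorch `MaxPool2d` documentation gives the pooled size of each spatial dimension as follows, with kernel size k, stride s, padding p, dilation d, and ceil_mode=False. With the repro's arguments (k = 3, s = 2, p = 1, d = 1), each 112-pixel dimension shrinks to 56. The batch and channel dimensions are not pooled and pass through unchanged, so the trace should record (1, 64, 56, 56).

```latex
H_{\text{out}} = \left\lfloor \frac{H_{\text{in}} + 2p - d\,(k - 1) - 1}{s} + 1 \right\rfloor
               = \left\lfloor \frac{112 + 2 - 2 - 1}{2} + 1 \right\rfloor
               = \lfloor 56.5 \rfloor = 56
```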
Trace:
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def augmented_forward_fn(a):
    # a: "cuda:0 f32[1, 64, 112, 112]"
    (t0, t1) = max_pool2d_with_indices(a, 3, 2, 1, 1, False)
    return {'output': t0, 'flat_args': [a], 'flat_output': (t0,)}, ((a, t1), (False, 3, 2, 1, 1))
cc @apaz-cli
OK, it looks like max_pool_with_indices comes from https://github.com/Lightning-AI/lightning-thunder/pull/163. max_pool without indices has a well-tested meta-function, and it could be reused here.
Triage review: we should also test that the metadata Thunder produces is consistent with the actual output.
Oops, sorry about that. It looks like I forgot to prepend the non-pooling dimensions (here, batch and channel) to the output shape: https://github.com/Lightning-AI/lightning-thunder/pull/163/files#diff-3c6a6ca64f7cd3508bcd348612f5aadea83a0506e521c1c8a232553f047d2321R1268
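The fix amounts to keeping every dimension except the last two when computing the output shape. Below is a minimal pure-Python sketch of such a shape function. It is hypothetical and is not Thunder's actual meta-function. It assumes square (integer) kernel, stride, padding, and dilation, and it follows the output-size formula from the PyTorch `MaxPool2d` documentation. That includes the `ceil_mode` rule that a window starting in the right padding is ignored.

```python
from typing import Sequence, Tuple


def pool_out_size(size: int, kernel: int, stride: int, padding: int,
                  dilation: int, ceil_mode: bool) -> int:
    """Length of one pooled dimension (formula from the PyTorch MaxPool2d docs)."""
    span = size + 2 * padding - dilation * (kernel - 1) - 1
    if not ceil_mode:
        return span // stride + 1
    out = -(-span // stride) + 1  # ceiling division
    # With ceil_mode, a last window that would start in the right padding is dropped.
    if (out - 1) * stride >= size + padding:
        out -= 1
    return out


def max_pool2d_output_shape(input_shape: Sequence[int], kernel: int, stride: int,
                            padding: int, dilation: int,
                            ceil_mode: bool) -> Tuple[int, ...]:
    """Hypothetical meta-function: pool the last two dims, keep all leading dims."""
    *leading, h, w = input_shape
    h_out = pool_out_size(h, kernel, stride, padding, dilation, ceil_mode)
    w_out = pool_out_size(w, kernel, stride, padding, dilation, ceil_mode)
    # Returning only (h_out, w_out) here would reproduce the (56, 56) seen in the trace.
    return (*leading, h_out, w_out)


# Arguments from the repro: kernel 3, stride 2, padding 1, dilation 1, ceil_mode False.
print(max_pool2d_output_shape((1, 64, 112, 112), 3, 2, 1, 1, False))  # (1, 64, 56, 56)
```

Unbatched 3-D inputs such as (64, 112, 112) work the same way, because whatever precedes the last two dimensions is passed through unchanged.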
I'll do a quick fix for this one.