
max_pool2d: Shape propagation error for input dimensions used by ResNet

Open · kiya00 opened this issue 1 year ago · 2 comments

🐛 Bug

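import torch
import torch.nn.functional as F
import thunder

# Input shape used by ResNet's first max-pooling layer (requires a CUDA device)
a = torch.randn(1, 64, 112, 112).cuda().requires_grad_()

def func(a):
    # kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False, return_indices=False
    # Expected output (t79 in the ResNet trace): "cuda:0 f32[1, 64, 56, 56]"
    return F.max_pool2d(a, 3, 2, 1, 1, False, False)

cfunc = thunder.jit(func)
b = cfunc(a)

# Shape recorded for the output in the last trace vs. the shape of the actual result
print(thunder.last_traces(cfunc)[-1].output[0]['output'].shape)
print(b.shape)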

Outputs:

(56, 56)
torch.Size([1, 64, 56, 56])

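The output shape recorded in the trace is wrong: it is (56, 56) instead of (1, 64, 56, 56). The function still runs successfully and returns a correctly shaped tensor.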

Trace:

# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def augmented_forward_fn(a):
  # a: "cuda:0 f32[1, 64, 112, 112]"
  (t0, t1) = max_pool2d_with_indices(a, 3, 2, 1, 1, False)
  return {'output': t0, 'flat_args': [a], 'flat_output': (t0,)}, ((a, t1), (False, 3, 2, 1, 1))
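For reference, the expected spatial size follows the output-size formula documented for torch.nn.MaxPool2d: H_out = floor((H_in + 2*padding - dilation*(kernel_size - 1) - 1) / stride + 1). The sketch below is a minimal, stdlib-only helper (the name pool_output_size is hypothetical, not a Thunder or PyTorch API) that applies it to one spatial dimension; the ceil_mode branch follows PyTorch's rule that a window starting in the right padding is dropped. With the repro's arguments it gives 56 per spatial dimension, so the full output should be (1, 64, 56, 56) rather than (56, 56).

```python
def pool_output_size(size, kernel_size, stride, padding=0, dilation=1, ceil_mode=False):
    """Output length of one spatial dimension for max_pool2d (hypothetical helper)."""
    # Padded length minus the effective window extent.
    numerator = size + 2 * padding - dilation * (kernel_size - 1) - 1
    if ceil_mode:
        out = -(-numerator // stride) + 1  # ceiling division
        # Drop the last window if it would start inside the right padding.
        if (out - 1) * stride >= size + padding:
            out -= 1
    else:
        out = numerator // stride + 1  # floor division
    return out


# Arguments from the repro: F.max_pool2d(a, 3, 2, 1, 1, False, False)
print(pool_output_size(112, kernel_size=3, stride=2, padding=1, dilation=1))  # 56
```

The same helper applies to H and W independently; the batch and channel dimensions are not pooled and must be carried through unchanged.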

cc @apaz-cli

kiya00 · May 06 '24 09:05

OK, it looks like max_pool_with_indices was introduced in https://github.com/Lightning-AI/lightning-thunder/pull/163. max_pool without indices already has a well-tested meta function, which could be reused here.

nikitaved · May 06 '24 19:05

Triage review: we should also test that the metadata Thunder produces is consistent with the actual output.
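A minimal sketch of that kind of consistency check, written against a pure-Python reference instead of Thunder and PyTorch so it runs anywhere: it computes the shape a meta function predicts and compares it with the shape of an actually computed result. All names here (naive_max_pool2d, meta_shape, check_meta_consistency) are hypothetical and only assume int pooling arguments and floor mode; a real Thunder test would instead compare the trace's output shape with the shape of the tensor the jitted function returns, as the repro does with thunder.last_traces(cfunc)[-1] and b.shape.

```python
import math


def shape_of(nested):
    """Shape of a rectangular nested list, e.g. [[1, 2], [3, 4]] -> (2, 2)."""
    dims = []
    while isinstance(nested, list):
        dims.append(len(nested))
        nested = nested[0] if nested else None
    return tuple(dims)


def naive_max_pool2d(x, kernel_size, stride, padding=0, dilation=1):
    """Reference max pooling over the last two dims of a nested list (floor mode, -inf padding)."""
    if isinstance(x[0][0], list):  # more than two dims: pool each leading slice
        return [naive_max_pool2d(t, kernel_size, stride, padding, dilation) for t in x]
    k, s, p, d = kernel_size, stride, padding, dilation
    h, w = len(x), len(x[0])
    span = d * (k - 1)  # distance from a window's first to last element
    out = []
    for top in range(-p, h + p - span, s):
        row = []
        for left in range(-p, w + p - span, s):
            # Keep only in-bounds elements; padded positions behave like -inf.
            vals = [
                x[top + d * i][left + d * j]
                for i in range(k)
                for j in range(k)
                if 0 <= top + d * i < h and 0 <= left + d * j < w
            ]
            row.append(max(vals, default=-math.inf))
        out.append(row)
    return out


def meta_shape(shape, kernel_size, stride, padding=0, dilation=1):
    """Predicted output shape: leading dims are kept, the last two are pooled."""
    def size(n):
        return (n + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1
    return tuple(shape[:-2]) + (size(shape[-2]), size(shape[-1]))


def check_meta_consistency(meta_fn, ref_fn, x, *args):
    """Assert that the shape predicted by meta_fn matches the shape ref_fn produces."""
    predicted = meta_fn(shape_of(x), *args)
    actual = shape_of(ref_fn(x, *args))
    assert predicted == actual, f"meta shape {predicted} != actual shape {actual}"
    return actual


# A small (1, 2, 7, 7) input pooled with the repro's arguments (3, 2, 1, 1).
x = [[[[float(c * 49 + r * 7 + col) for col in range(7)] for r in range(7)] for c in range(2)]]
print(check_meta_consistency(meta_shape, naive_max_pool2d, x, 3, 2, 1, 1))  # (1, 2, 4, 4)
```

A meta function that returns only the pooled (H, W) dimensions, as in this issue, would fail the assertion, because the reference result keeps the batch and channel dimensions.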

mruberry · May 13 '24 19:05

Quoting @nikitaved: "OK, looks like max_pool_with_indices comes from #163. max_pool without indices has a well-tested meta-function, and it could be re-used here."

Oops, sorry about that. It looks like I forgot to prepend the leading non-pooling dimensions (here, batch and channel) to the output shape: https://github.com/Lightning-AI/lightning-thunder/pull/163/files#diff-3c6a6ca64f7cd3508bcd348612f5aadea83a0506e521c1c8a232553f047d2321R1268
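To make the fix concrete, here is a minimal stdlib-only sketch of the shape logic, assuming (N, C, H, W) or (C, H, W) inputs, an explicit stride, and floor mode (ceil_mode=False, as in the repro). The function names are hypothetical and are not Thunder's actual meta functions: pooled_spatial_dims reproduces what the trace reported, while max_pool2d_with_indices_shapes prepends the leading dimensions and returns the same shape for both the values and the indices.

```python
def _pair(x):
    """Expand an int to a 2-tuple, as PyTorch does for 2d pooling arguments."""
    return (x, x) if isinstance(x, int) else tuple(x)


def _out_size(size, k, s, p, d):
    # Floor-mode (ceil_mode=False) output size of one spatial dimension.
    return (size + 2 * p - d * (k - 1) - 1) // s + 1


def pooled_spatial_dims(shape, kernel_size, stride, padding=0, dilation=1):
    """Only the pooled (H_out, W_out) dims: what the buggy meta function reported."""
    k, s, p, d = map(_pair, (kernel_size, stride, padding, dilation))
    return tuple(_out_size(shape[-2 + i], k[i], s[i], p[i], d[i]) for i in range(2))


def max_pool2d_with_indices_shapes(shape, kernel_size, stride, padding=0, dilation=1):
    """Shapes of (values, indices): leading non-pooling dims are prepended to (H_out, W_out)."""
    out = tuple(shape[:-2]) + pooled_spatial_dims(shape, kernel_size, stride, padding, dilation)
    return out, out  # values and indices always have the same shape


input_shape = (1, 64, 112, 112)  # from the repro
print(pooled_spatial_dims(input_shape, 3, 2, 1, 1))             # (56, 56)
print(max_pool2d_with_indices_shapes(input_shape, 3, 2, 1, 1))  # ((1, 64, 56, 56), (1, 64, 56, 56))
```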

I'll do a quick fix for this one.

jjsjann123 · May 29 '24 09:05