
Remove thunder.compile

Open · IvanYashchuk opened this issue 1 year ago • 6 comments

I removed most of the thunder.compile and make_callable_legacy usage in the tests. ~There are a few left still.~

No test uses thunder.compile anymore, so we can remove it completely.

Ref. https://github.com/Lightning-AI/lightning-thunder/issues/198, https://github.com/Lightning-AI/lightning-thunder/issues/794
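
For reference, a minimal before/after sketch of the migration in a test, assuming thunder.jit is the replacement entrypoint (the function here is illustrative, not taken from the test suite):

import torch
import thunder

def fn(a, b):
    return a + b * a

# Legacy entrypoint removed by this issue:
# cfn = thunder.compile(fn)

# Replacement entrypoint (assumed):
cfn = thunder.jit(fn)

a, b = torch.randn(4), torch.randn(4)
assert torch.allclose(cfn(a, b), fn(a, b))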

IvanYashchuk · Jul 23 '24 11:07

Current failures:

=========================== short test summary info ============================
FAILED thunder/tests/test_grad.py::test_vjp_correctness_getitem_torch_cpu_thunder.dtypes.float64 - RuntimeError: Expected '[tensor([[ 1.2818,  3.2108, -5.9048,  5.2651,  4.6372],
        [ 8.4662, -7.6512,  4.1605, -5.7554,  0.0265],
        [ 2.3152, -6.4835, -4.5604,  7.7270, -6.9754],
        [ 7.9316, -1.2834,  8.9820, -7.3240, -6.4750],
        [ 2.3197, -2.5032, -1.1486, -2.6956, -4.2437]], dtype=torch.float64,
       requires_grad=True), slice(1, 3, 1), slice(2, 4, 2)] to be equal to 'slice(1, 3, 1)
FAILED thunder/tests/test_core.py::test_cse_torch_cpu_None - AssertionError: assert ['t4', 't4', ...', 't17', ...] == ['t4', 't4', ...', 't16', ...]
  
  At index 3 diff: 't7' != 't14'
  
  Full diff:
    [
        't4',
        't4',
        't6',
  -     't14',
  ?       ^^
  +     't7',
  ?       ^
  -     't15',
        't16',
        't17',
  +     't18',
    ]
= 2 failed

IvanYashchuk · Jul 23 '24 12:07

thunder/tests/test_core.py::test_cse_torch_cpu_None fails in CI (AssertionError: assert ['t4', 't4', ...', 't17', ...] == ['t4', 't4', ...', 't16', ...]) but passes locally.

IvanYashchuk · Jul 23 '24 16:07

More failures in the GPU CI:

=========================== short test summary info ============================
FAILED thunder/tests/test_grad.py::test_vjp_correctness_convolution_nvfuser_cuda_thunder.dtypes.float64 - Failed: Timeout >240.0s
FAILED thunder/tests/test_grad.py::test_vjp_correctness_baddbmm_torch_cuda_thunder.dtypes.float64 - thunder.core.interpreter.InterpreterError: Encountered exception Failed: Timeout >240.0s while tracing <function _make_differentiable_wrapper.<locals>.wrapper at 0x7fa7903d6f80>:
FAILED thunder/tests/test_grad.py::test_vjp_correctness_adaptive_avg_pool2d_torch_cuda_thunder.dtypes.float64 - NotImplementedError: VJP for torch.nn.functional.adaptive_avg_pool2d is not implemented
FAILED thunder/tests/test_grad.py::test_vjp_correctness_baddbmm_nvfuser_cuda_thunder.dtypes.float64 - thunder.core.interpreter.InterpreterError: Encountered exception Failed: Timeout >240.0s while tracing <function _make_differentiable_wrapper.<locals>.wrapper at 0x7f8b42dace50>:
FAILED thunder/tests/test_grad.py::test_vjp_correctness_convolution_torch_cuda_thunder.dtypes.float64 - thunder.core.interpreter.InterpreterError: Encountered exception Failed: Timeout >240.0s while tracing <function _make_differentiable_wrapper.<locals>.wrapper at 0x7f7c8eb24f70>:
FAILED thunder/tests/test_core.py::test_cse_nvfuser_cuda_None - AssertionError: assert ['t4', 't4', ...', 't17', ...] == ['t4', 't4', ...', 't16', ...]
  
  At index 3 diff: 't7' != 't14'
  
  Full diff:
    [
        't4',
        't4',
        't6',
  -     't14',
  ?       ^^
  +     't7',
  ?       ^
  -     't15',
        't16',
        't17',
  +     't18',
    ]
FAILED thunder/tests/test_core.py::test_traceback - AssertionError: assert 'torch.neg' in '        return func(*args, **kwargs)'
 +  where '        return func(*args, **kwargs)' = str(<_pytest._code.source.Source object at 0x7f96b8570520>)
 +    where <_pytest._code.source.Source object at 0x7f96b8570520> = <TracebackEntry /usr/local/lib/python3.10/dist-packages/torch/utils/_device.py:78>.statement
FAILED thunder/tests/test_core.py::test_cse_torch_cuda_None - AssertionError: assert ['t4', 't4', ...', 't17', ...] == ['t4', 't4', ...', 't16', ...]
  
  At index 3 diff: 't7' != 't14'
  
  Full diff:
    [
        't4',
        't4',
        't6',
  -     't14',
  ?       ^^
  +     't7',
  ?       ^
  -     't15',
        't16',
        't17',
  +     't18',
    ]
= 8 failed

IvanYashchuk · Jul 23 '24 16:07

I wonder if the CSE test is flaky because it is somehow ordering-dependent (maybe in which of the common subexpressions gets eliminated).
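
To illustrate: a minimal toy sketch (hypothetical, not Thunder's actual CSE pass) of how the surviving temporary name can depend on visit order:

# Toy CSE: keep the first name computed for each right-hand side and
# rewrite later duplicates to alias it, so visit order decides which
# temporary name survives.
def cse(exprs):
    seen = {}  # RHS -> first name that computed it
    out = []
    for name, rhs in exprs:
        if rhs in seen:
            out.append((name, seen[rhs]))  # alias the earlier result
        else:
            seen[rhs] = name
            out.append((name, rhs))
    return out

trace = [("t7", "add(a, b)"), ("t14", "add(a, b)"), ("t16", "mul(t7, c)")]
print(cse(trace))                                  # 't7' survives
print(cse(list(reversed(trace[:2])) + trace[2:]))  # 't14' survives

The 't7' != 't14' flips in the diffs above are consistent with that kind of order dependence.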

The timeouts could be due to CI congestion; it seems to be getting better now, so I'll rerun them.

t-vi · Jul 23 '24 17:07

The timeouts seem to persist. :(

t-vi · Jul 23 '24 18:07

The only failed case (https://github.com/Lightning-AI/lightning-thunder/pull/837#issuecomment-2245732546) that I can reproduce locally is FAILED thunder/tests/test_grad.py::test_vjp_correctness_adaptive_avg_pool2d_torch_cuda_thunder.dtypes.float64 - NotImplementedError: VJP for torch.nn.functional.adaptive_avg_pool2d is not implemented. The reason is that this op only has a torchex.grad_transform: when using initial_trace_vjp_f = thunder.trace()(vjp(f), primals, v), no CompileData is passed in, so no grad transform is found. https://github.com/Lightning-AI/lightning-thunder/blob/fa307a5656d2b7000a07a489f3d434aa7ba8cfd6/thunder/core/transforms.py#L1410-L1425

Even if we pass in the CompileData, like:

trc_cd = CompileData(
    fn=vjp(f),
    executors_list=executor.executors_list(),
    disable_torch_autograd_support=True,
    disable_preprocessing=True,
)
with compile_data_and_stats(trc_cd, None):
    initial_trace_vjp_f = thunder.trace()(vjp(f), primals, v)
    print(initial_trace_vjp_f)

the resulting initial_trace_vjp_f has a naming issue (t5 = torch.torch.ops.aten._adaptive_avg_pool2d_backward(t4, a)) related to https://github.com/Lightning-AI/lightning-thunder/blob/fa307a5656d2b7000a07a489f3d434aa7ba8cfd6/thunder/core/symbol.py#L584:

@torch.no_grad()
@no_autocast
def adaptive_avg_pool2d_bwd_wrapper(a, output_size):
  # a: "cuda:0 f64[3, 3, 3]"
  # output_size: "int 5"
  t1 = torch.nn.functional.adaptive_avg_pool2d(a, output_size)  # t1: "cuda:0 f64[3, 5, 5]"
    # t1 = ltorch.adaptive_avg_pool2d(a, output_size)  # t1: "cuda:0 f64[3, 5, 5]"
  t4 = prims.get_grad(t1)  # t4: "cuda:0 f64[3, 5, 5]"
  t5 = torch.torch.ops.aten._adaptive_avg_pool2d_backward(t4, a)  # t5: "cuda:0 f64[3, 3, 3]"
    # t5 = ltorch.adaptive_avg_pool2d_backward(t4, a)  # t5: "cuda:0 f64[3, 3, 3]"
  prims.put_grad(a, t5)
  return t1

Normally, a torch.ops.aten.* case like this goes into https://github.com/Lightning-AI/lightning-thunder/blob/fa307a5656d2b7000a07a489f3d434aa7ba8cfd6/thunder/core/symbol.py#L586-L587, but we don't have a _call_ctx when calling the trace() function. cc: @IvanYashchuk
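
For illustration only, a hypothetical sketch (not Thunder's actual printer) of how the doubled torch. prefix in the trace above can arise:

# Hypothetical: the printed call is assembled as "<module prefix>.<fn name>(<args>)".
# torch.ops.aten.* names are already fully qualified, so prepending the resolved
# module prefix "torch" a second time duplicates the root package name.
def print_call(module_prefix: str, fn_name: str, args: str) -> str:
    return f"{module_prefix}.{fn_name}({args})"

print(print_call("torch", "torch.ops.aten._adaptive_avg_pool2d_backward", "t4, a"))
# torch.torch.ops.aten._adaptive_avg_pool2d_backward(t4, a)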

kiya00 · Aug 01 '24 15:08

Closing in favor of https://github.com/Lightning-AI/lightning-thunder/pull/1114.

IvanYashchuk · Sep 07 '24 08:09