
Issue when training a MoE model


I modified some of the Pythia models to use `LLaMAMoE`. However, it didn't run properly. I attached the full log below, but in short, the line `token_idx, expert_idx = torch.where(mask)` errors out with `NotImplementedError: Could not run 'aten::nonzero' with arguments from the 'Meta' backend.`

For reference, I followed the instructions in the README and created an Anaconda env with Python 3.10. I can run a normal Pythia model without any issues.
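The failure also reproduces in isolation, independent of lit-gpt: `torch.where(mask)` with a single argument lowers to `aten::nonzero`, whose output shape depends on the tensor's values, and a meta tensor has a shape and dtype but no values. A minimal sketch (assuming any recent PyTorch):

```python
import torch

# `torch.where(mask)` with one argument lowers to `aten::nonzero`, whose
# output shape depends on the tensor's values; a meta tensor carries only
# shape and dtype, so the meta backend cannot run the op.
mask = torch.zeros(8, 4, dtype=torch.bool, device="meta")
token_idx, expert_idx = torch.where(mask)  # raises NotImplementedError
```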

```
Validating ...
Estimated TFLOPs: 0.54
Traceback (most recent call last):
  File "/scratch/admin/lit-gpt/temp.py", line 382, in <module>
    CLI(setup)
  File "/home/admin/miniconda/envs/lit/lib/python3.11/site-packages/jsonargparse/_cli.py", line 96, in CLI
    return _run_component(components, cfg_init)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/admin/miniconda/envs/lit/lib/python3.11/site-packages/jsonargparse/_cli.py", line 181, in _run_component
    return component(**cfg)
           ^^^^^^^^^^^^^^^^
  File "/scratch/admin/lit-gpt/temp.py", line 117, in setup
    fabric.launch(main, resume=resume)
  File "/home/admin/miniconda/envs/lit/lib/python3.11/site-packages/lightning/fabric/fabric.py", line 834, in launch
    return self._wrap_and_launch(function, self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/admin/miniconda/envs/lit/lib/python3.11/site-packages/lightning/fabric/fabric.py", line 920, in _wrap_and_launch
    return to_run(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/admin/miniconda/envs/lit/lib/python3.11/site-packages/lightning/fabric/fabric.py", line 925, in _wrap_with_setup
    return to_run(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/admin/lit-gpt/temp.py", line 173, in main
    train(fabric, state, train_dataloader, val_dataloader)
  File "/scratch/admin/lit-gpt/temp.py", line 263, in train
    measured_flops = measure_flops(meta_model, forward_fn, loss_fn)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/admin/miniconda/envs/lit/lib/python3.11/site-packages/lightning/fabric/utilities/throughput.py", line 303, in measure_flops
    loss_fn(forward_fn()).backward()
            ^^^^^^^^^^^^
  File "/scratch/admin/lit-gpt/temp.py", line 261, in <lambda>
    forward_fn = lambda: meta_model(x)
                         ^^^^^^^^^^^^^
  File "/home/admin/miniconda/envs/lit/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/admin/miniconda/envs/lit/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
    result = forward_call(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/admin/lit-gpt/lit_gpt/model.py", line 91, in forward
    x = block(x, cos, sin, mask, input_pos)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/admin/miniconda/envs/lit/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/admin/miniconda/envs/lit/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
    result = forward_call(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/admin/lit-gpt/lit_gpt/model.py", line 158, in forward
    x = self.mlp(n_2) + h + x
        ^^^^^^^^^^^^^
  File "/home/admin/miniconda/envs/lit/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/admin/miniconda/envs/lit/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
    result = forward_call(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/admin/lit-gpt/lit_gpt/model.py", line 393, in forward
    token_idx, expert_idx = torch.where(mask)
                            ^^^^^^^^^^^^^^^^^
  File "/home/admin/miniconda/envs/lit/lib/python3.11/site-packages/torch/utils/_device.py", line 77, in __torch_function__
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/admin/miniconda/envs/lit/lib/python3.11/site-packages/torch/utils/flop_counter.py", line 413, in __torch_dispatch__
    out = func(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^
  File "/home/admin/miniconda/envs/lit/lib/python3.11/site-packages/torch/_ops.py", line 448, in __call__
    return self._op(*args, **kwargs or {})

NotImplementedError: Could not run 'aten::nonzero' with arguments from the 'Meta' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::nonzero' is only available for these backends: [CPU, CUDA, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradHIP, AutogradXLA, AutogradMPS, AutogradIPU, AutogradXPU, AutogradHPU, AutogradVE, AutogradLazy, AutogradMTIA, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, AutogradMeta, AutogradNestedTensor, Tracer, AutocastCPU, AutocastCUDA, FuncTorchBatched, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PreDispatch, PythonDispatcher].

CPU: registered at aten/src/ATen/RegisterCPU.cpp:31188 [kernel]
CUDA: registered at aten/src/ATen/RegisterCUDA.cpp:44143 [kernel]
BackendSelect: fallthrough registered at ../aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
Python: registered at ../aten/src/ATen/core/PythonFallbackKernel.cpp:153 [backend fallback]
FuncTorchDynamicLayerBackMode: registered at ../aten/src/ATen/functorch/DynamicLayer.cpp:498 [backend fallback]
Functionalize: registered at ../aten/src/ATen/FunctionalizeFallbackKernel.cpp:290 [backend fallback]
Named: registered at ../aten/src/ATen/core/NamedRegistrations.cpp:7 [backend fallback]
Conjugate: registered at ../aten/src/ATen/ConjugateFallback.cpp:17 [backend fallback]
Negative: registered at ../aten/src/ATen/native/NegateFallback.cpp:19 [backend fallback]
ZeroTensor: registered at ../aten/src/ATen/ZeroTensorFallback.cpp:86 [backend fallback]
ADInplaceOrView: fallthrough registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:86 [backend fallback]
AutogradOther: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:16838 [autograd kernel]
AutogradCPU: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:16838 [autograd kernel]
AutogradCUDA: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:16838 [autograd kernel]
AutogradHIP: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:16838 [autograd kernel]
AutogradXLA: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:16838 [autograd kernel]
AutogradMPS: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:16838 [autograd kernel]
AutogradIPU: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:16838 [autograd kernel]
AutogradXPU: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:16838 [autograd kernel]
AutogradHPU: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:16838 [autograd kernel]
AutogradVE: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:16838 [autograd kernel]
AutogradLazy: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:16838 [autograd kernel]
AutogradMTIA: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:16838 [autograd kernel]
AutogradPrivateUse1: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:16838 [autograd kernel]
AutogradPrivateUse2: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:16838 [autograd kernel]
AutogradPrivateUse3: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:16838 [autograd kernel]
AutogradMeta: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:16838 [autograd kernel]
AutogradNestedTensor: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:16838 [autograd kernel]
Tracer: registered at ../torch/csrc/autograd/generated/TraceType_0.cpp:16725 [kernel]
AutocastCPU: fallthrough registered at ../aten/src/ATen/autocast_mode.cpp:382 [backend fallback]
AutocastCUDA: fallthrough registered at ../aten/src/ATen/autocast_mode.cpp:249 [backend fallback]
FuncTorchBatched: registered at ../aten/src/ATen/functorch/BatchRulesDynamic.cpp:66 [kernel]
FuncTorchVmapMode: fallthrough registered at ../aten/src/ATen/functorch/VmapModeRegistrations.cpp:28 [backend fallback]
Batched: registered at ../aten/src/ATen/LegacyBatchingRegistrations.cpp:1075 [backend fallback]
VmapMode: fallthrough registered at ../aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]
FuncTorchGradWrapper: registered at ../aten/src/ATen/functorch/TensorWrapper.cpp:203 [backend fallback]
PythonTLSSnapshot: registered at ../aten/src/ATen/core/PythonFallbackKernel.cpp:161 [backend fallback]
FuncTorchDynamicLayerFrontMode: registered at ../aten/src/ATen/functorch/DynamicLayer.cpp:494 [backend fallback]
PreDispatch: registered at ../aten/src/ATen/core/PythonFallbackKernel.cpp:165 [backend fallback]
PythonDispatcher: registered at ../aten/src/ATen/core/PythonFallbackKernel.cpp:157 [backend fallback]
```

lfsszd, Feb 29 '24

Thanks for the report. Unfortunately, PyTorch doesn't support running `aten::nonzero` on the meta backend (the op's output shape is data-dependent), so we cannot measure the FLOPs of MoE models like Mixtral the way we do for dense models.
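If actually executing the model is acceptable, one alternative (a sketch, not what lit-gpt does) is to count FLOPs from a real forward/backward pass using `torch.utils.flop_counter.FlopCounterMode`, the same counter `measure_flops` uses under the hood, since `aten::nonzero` works fine on a real device. The model, input, and loss below are toy stand-ins:

```python
import torch
from torch.utils.flop_counter import FlopCounterMode

# Toy stand-ins for the model, input, and loss from temp.py, instantiated
# on a real device so data-dependent ops like `aten::nonzero` can run.
model = torch.nn.Linear(16, 16)
x = torch.randn(4, 16)
loss_fn = lambda y: y.sum()

# Count FLOPs over one forward/backward pass without printing the table.
flop_counter = FlopCounterMode(display=False)
with flop_counter:
    loss_fn(model(x)).backward()
measured_flops = flop_counter.get_total_flops()
print(measured_flops)
```

The trade-off is memory and time: the model must actually fit and run on a real device, which the meta-device approach was designed to avoid.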

For the moment, you can avoid the error by setting `measured_flops = 11996892435054592 * micro_batch_size`. The downside is that the reported FLOPs and MFU won't be correct.
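A sketch of how that could be wired into the training script; the `try`/`except` guard is a suggestion, not something the script currently does, and `meta_model`, `forward_fn`, `loss_fn`, and `micro_batch_size` are the names from temp.py in the traceback above:

```python
from lightning.fabric.utilities.throughput import measure_flops

try:
    measured_flops = measure_flops(meta_model, forward_fn, loss_fn)
except NotImplementedError:
    # MoE routing hits `aten::nonzero`, which the meta backend cannot run;
    # fall back to the hardcoded per-sample estimate above.
    measured_flops = 11996892435054592 * micro_batch_size
```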

carmocca, Feb 29 '24