Issue when training a MoE model
I modified some of the Pythia models to use LLaMAMoE. However, it didn't run properly.
I've attached the full log below, but in short, the line
token_idx, expert_idx = torch.where(mask)
errors out with: NotImplementedError: Could not run 'aten::nonzero' with arguments from the 'Meta' backend.
For reference, I followed the instructions in the README and created an Anaconda env with Python 3.10. I can run a normal Pythia model without any issues.
Validating ...
Estimated TFLOPs: 0.54
Traceback (most recent call last):
File "/scratch/admin/lit-gpt/temp.py", line 382, in <module>
CLI(setup)
File "/home/admin/miniconda/envs/lit/lib/python3.11/site-packages/jsonargparse/_cli.py", line 96, in CLI
return _run_component(components, cfg_init)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/admin/miniconda/envs/lit/lib/python3.11/site-packages/jsonargparse/_cli.py", line 181, in _run_component
return component(**cfg)
^^^^^^^^^^^^^^^^
File "/scratch/admin/lit-gpt/temp.py", line 117, in setup
fabric.launch(main, resume=resume)
File "/home/admin/miniconda/envs/lit/lib/python3.11/site-packages/lightning/fabric/fabric.py", line 834, in launch
return self._wrap_and_launch(function, self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/admin/miniconda/envs/lit/lib/python3.11/site-packages/lightning/fabric/fabric.py", line 920, in _wrap_and_launch
return to_run(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^
File "/home/admin/miniconda/envs/lit/lib/python3.11/site-packages/lightning/fabric/fabric.py", line 925, in _wrap_with_setup
return to_run(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^
File "/scratch/admin/lit-gpt/temp.py", line 173, in main
train(fabric, state, train_dataloader, val_dataloader)
File "/scratch/admin/lit-gpt/temp.py", line 263, in train
measured_flops = measure_flops(meta_model, forward_fn, loss_fn)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/admin/miniconda/envs/lit/lib/python3.11/site-packages/lightning/fabric/utilities/throughput.py", line 303, in measure_flops
loss_fn(forward_fn()).backward()
^^^^^^^^^^^^
File "/scratch/admin/lit-gpt/temp.py", line 261, in <lambda>
forward_fn = lambda: meta_model(x)
^^^^^^^^^^^^^
File "/home/admin/miniconda/envs/lit/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/admin/miniconda/envs/lit/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
result = forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/scratch/admin/lit-gpt/lit_gpt/model.py", line 91, in forward
x = block(x, cos, sin, mask, input_pos)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/admin/miniconda/envs/lit/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/admin/miniconda/envs/lit/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
result = forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/scratch/admin/lit-gpt/lit_gpt/model.py", line 158, in forward
x = self.mlp(n_2) + h + x
^^^^^^^^^^^^^
File "/home/admin/miniconda/envs/lit/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/admin/miniconda/envs/lit/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
result = forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/scratch/admin/lit-gpt/lit_gpt/model.py", line 393, in forward
token_idx, expert_idx = torch.where(mask)
^^^^^^^^^^^^^^^^^
File "/home/admin/miniconda/envs/lit/lib/python3.11/site-packages/torch/utils/_device.py", line 77, in __torch_function__
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/admin/miniconda/envs/lit/lib/python3.11/site-packages/torch/utils/flop_counter.py", line 413, in __torch_dispatch__
out = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/admin/miniconda/envs/lit/lib/python3.11/site-packages/torch/_ops.py", line 448, in __call__
return self._op(*args, **kwargs or {})
NotImplementedError: Could not run 'aten::nonzero' with arguments from the 'Meta' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::nonzero' is only available for these backends: [CPU, CUDA, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradHIP, AutogradXLA, AutogradMPS, AutogradIPU, AutogradXPU, AutogradHPU, AutogradVE, AutogradLazy, AutogradMTIA, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, AutogradMeta, AutogradNestedTensor, Tracer, AutocastCPU, AutocastCUDA, FuncTorchBatched, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PreDispatch, PythonDispatcher].
CPU: registered at aten/src/ATen/RegisterCPU.cpp:31188 [kernel]
CUDA: registered at aten/src/ATen/RegisterCUDA.cpp:44143 [kernel]
BackendSelect: fallthrough registered at ../aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
Python: registered at ../aten/src/ATen/core/PythonFallbackKernel.cpp:153 [backend fallback]
FuncTorchDynamicLayerBackMode: registered at ../aten/src/ATen/functorch/DynamicLayer.cpp:498 [backend fallback]
Functionalize: registered at ../aten/src/ATen/FunctionalizeFallbackKernel.cpp:290 [backend fallback]
Named: registered at ../aten/src/ATen/core/NamedRegistrations.cpp:7 [backend fallback]
Conjugate: registered at ../aten/src/ATen/ConjugateFallback.cpp:17 [backend fallback]
Negative: registered at ../aten/src/ATen/native/NegateFallback.cpp:19 [backend fallback]
ZeroTensor: registered at ../aten/src/ATen/ZeroTensorFallback.cpp:86 [backend fallback]
ADInplaceOrView: fallthrough registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:86 [backend fallback]
AutogradOther: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:16838 [autograd kernel]
AutogradCPU: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:16838 [autograd kernel]
AutogradCUDA: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:16838 [autograd kernel]
AutogradHIP: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:16838 [autograd kernel]
AutogradXLA: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:16838 [autograd kernel]
AutogradMPS: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:16838 [autograd kernel]
AutogradIPU: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:16838 [autograd kernel]
AutogradXPU: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:16838 [autograd kernel]
AutogradHPU: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:16838 [autograd kernel]
AutogradVE: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:16838 [autograd kernel]
AutogradLazy: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:16838 [autograd kernel]
AutogradMTIA: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:16838 [autograd kernel]
AutogradPrivateUse1: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:16838 [autograd kernel]
AutogradPrivateUse2: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:16838 [autograd kernel]
AutogradPrivateUse3: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:16838 [autograd kernel]
AutogradMeta: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:16838 [autograd kernel]
AutogradNestedTensor: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:16838 [autograd kernel]
Tracer: registered at ../torch/csrc/autograd/generated/TraceType_0.cpp:16725 [kernel]
AutocastCPU: fallthrough registered at ../aten/src/ATen/autocast_mode.cpp:382 [backend fallback]
AutocastCUDA: fallthrough registered at ../aten/src/ATen/autocast_mode.cpp:249 [backend fallback]
FuncTorchBatched: registered at ../aten/src/ATen/functorch/BatchRulesDynamic.cpp:66 [kernel]
FuncTorchVmapMode: fallthrough registered at ../aten/src/ATen/functorch/VmapModeRegistrations.cpp:28 [backend fallback]
Batched: registered at ../aten/src/ATen/LegacyBatchingRegistrations.cpp:1075 [backend fallback]
VmapMode: fallthrough registered at ../aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]
FuncTorchGradWrapper: registered at ../aten/src/ATen/functorch/TensorWrapper.cpp:203 [backend fallback]
PythonTLSSnapshot: registered at ../aten/src/ATen/core/PythonFallbackKernel.cpp:161 [backend fallback]
FuncTorchDynamicLayerFrontMode: registered at ../aten/src/ATen/functorch/DynamicLayer.cpp:494 [backend fallback]
PreDispatch: registered at ../aten/src/ATen/core/PythonFallbackKernel.cpp:165 [backend fallback]
PythonDispatcher: registered at ../aten/src/ATen/core/PythonFallbackKernel.cpp:157 [backend fallback]
Thanks for the report. Unfortunately, PyTorch cannot run data-dependent ops like aten::nonzero on the meta backend, so we cannot measure the FLOPs of a Mixtral-style MoE model the way we do for dense models.
For the moment, you can avoid the error by skipping the measurement and setting measured_flops = 11996892435054592 * micro_batch_size. The downside is that the reported FLOPs and MFU won't be accurate.