
How to pretrain an MoE model?

Open · win10ogod opened this issue 1 year ago • 5 comments

How to pretrain an MoE model?

win10ogod · Jan 12 '24 05:01

Have a look at this pretraining tutorial: https://github.com/Lightning-AI/lit-gpt/blob/main/tutorials/pretrain_tinyllama.md

The only MoE model we support is Mixtral. You would need to replace tinyllama with it.

carmocca · Jan 12 '24 09:01
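For context, a minimal sketch of what "replace tinyllama with it" can look like at the config level, assuming the pretraining script builds its model through `lit_gpt.Config.from_name` (the Mixtral config name below should be verified against `lit_gpt/config.py` in your checkout):

```python
import torch
from lit_gpt import GPT, Config

# Swap the TinyLlama config used in the tutorial for the Mixtral MoE config.
# Config name assumed; check lit_gpt/config.py for the exact spelling.
config = Config.from_name("Mixtral-8x7B-v0.1")
print(config._mlp_class, config.n_expert, config.n_expert_per_token)  # LLaMAMoE 8 2

# Instantiate on the meta device just to sanity-check the architecture;
# the real pretraining script materializes the weights through Fabric.
with torch.device("meta"):
    model = GPT(config)
```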

> Have a look at this pretraining tutorial: https://github.com/Lightning-AI/lit-gpt/blob/main/tutorials/pretrain_tinyllama.md
>
> The only MoE model we support is Mixtral. You would need to replace tinyllama with it.

Like this?

    dict(
        name="tiny-llama-1.1b-MOE{}",
        hf_config=dict(org="TinyLlama-MOE", name="TinyLlama-1.1B-MOE{}"),
        block_size=4096,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=11,
        n_query_groups=4,
        n_head=32,
        n_embd=4096,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="RMSNorm",  # original TinyLlama uses FusedRMSNorm
        norm_eps=1e-5,
        intermediate_size=5632,
        _mlp_class="LLaMAMoE",
        n_expert=8,
        n_expert_per_token=2,
    )

win10ogod · Jan 12 '24 10:01
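As an aside, every key in the dict above matches a field of `lit_gpt.Config` (they mirror the built-in Mixtral entry), so an entry like this can also be tried without registering it in `lit_gpt/config.py`. A rough sketch, assuming the "{}" placeholders are dropped from the two name fields:

```python
import torch
from lit_gpt import GPT, Config

# Hypothetical standalone use of the MoE config proposed above (not registered in
# lit_gpt/config.py); all keyword arguments below are ordinary Config fields.
moe_config = Config(
    name="tiny-llama-1.1b-MOE",
    hf_config=dict(org="TinyLlama-MOE", name="TinyLlama-1.1B-MOE"),
    block_size=4096, vocab_size=32000, padding_multiple=64,
    n_layer=11, n_query_groups=4, n_head=32, n_embd=4096,
    rotary_percentage=1.0, parallel_residual=False, bias=False,
    _norm_class="RMSNorm", norm_eps=1e-5, intermediate_size=5632,
    _mlp_class="LLaMAMoE", n_expert=8, n_expert_per_token=2,
)

# Build on the meta device to inspect the architecture without allocating weights;
# with n_embd=4096 and 8 experts per layer this is already a multi-billion-parameter model.
with torch.device("meta"):
    model = GPT(moe_config)
print(sum(p.numel() for p in model.parameters()))
```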

> Have a look at this pretraining tutorial: https://github.com/Lightning-AI/lit-gpt/blob/main/tutorials/pretrain_tinyllama.md
>
> The only MoE model we support is Mixtral. You would need to replace tinyllama with it.

@carmocca

    C:\Users\jmes1\AppData\Local\Programs\Python\Python311\Lib\site-packages\transformers\utils\generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
      _torch_pytree._register_pytree_node(
    Using bfloat16 Automatic Mixed Precision (AMP)
    {'model_name': 'Mixtral-8', 'name': 'Mixtral-8', 'save_interval': 10, 'eval_interval': 1000, 'eval_iters': 100, 'log_interval': 1, 'learning_rate': 0.0006, 'batch_size': 10, 'micro_batch_size': 5, 'gradient_accumulation_steps': 2, 'max_iters': 6000, 'weight_decay': 0.1, 'beta1': 0.9, 'beta2': 0.95, 'grad_clip': 1.0, 'decay_lr': True, 'warmup_iters': 2000, 'lr_decay_iters': 6000, 'min_lr': 6e-05}
    Seed set to 1337
    Loading model with {'name': 'Mixtral-8', 'hf_config': {'org': 'mistralai', 'name': 'Mixtral-8'}, 'block_size': 768, 'vocab_size': 32000, 'padding_multiple': 512, 'padded_vocab_size': 32000, 'n_layer': 4, 'n_head': 32, 'n_embd': 768, 'rotary_percentage': 1.0, 'parallel_residual': False, 'bias': False, 'lm_head_bias': False, 'n_query_groups': 8, 'shared_attention_norm': False, '_norm_class': 'RMSNorm', 'norm_eps': 1e-05, '_mlp_class': 'LLaMAMoE', 'gelu_approximate': 'none', 'intermediate_size': 14336, 'rope_condense_ratio': 1, 'rope_base': 1000000, 'n_expert': 8, 'n_expert_per_token': 2, 'head_size': 24, 'rope_n_elem': 24}
    Time to instantiate model: 0.25 seconds.
    Total parameters 1,112,046,336
    Validating ...
    C:\Users\jmes1\AppData\Local\Programs\Python\Python311\Lib\site-packages\transformers\utils\generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
      _torch_pytree._register_pytree_node(
    C:\Users\jmes1\AppData\Local\Programs\Python\Python311\Lib\site-packages\transformers\utils\generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
      _torch_pytree._register_pytree_node(
    D:\AItest\lit-gpt-moellama\lit_gpt\model.py:238: UserWarning: 1Torch was not compiled with flash attention. (Triggered internally at ..\aten\src\ATen\native\transformers\cuda\sdp_utils.cpp:308.)
      y = torch.nn.functional.scaled_dot_product_attention(
    Estimated TFLOPs: 25.80
    Traceback (most recent call last):
      File "D:\AItest\lit-gpt-moellama\pretrain\openwebtext.py", line 250, in <module>
        CLI(setup)
      File "C:\Users\jmes1\AppData\Local\Programs\Python\Python311\Lib\site-packages\jsonargparse\_cli.py", line 96, in CLI
        return _run_component(components, cfg_init)
      File "C:\Users\jmes1\AppData\Local\Programs\Python\Python311\Lib\site-packages\jsonargparse\_cli.py", line 181, in _run_component
        return component(**cfg)
      File "D:\AItest\lit-gpt-moellama\pretrain\openwebtext.py", line 70, in setup
        fabric.launch(main, resume=resume)
      File "C:\Users\jmes1\AppData\Local\Programs\Python\Python311\Lib\site-packages\lightning\fabric\fabric.py", line 839, in launch
        return self._wrap_and_launch(function, self, *args, **kwargs)
      File "C:\Users\jmes1\AppData\Local\Programs\Python\Python311\Lib\site-packages\lightning\fabric\fabric.py", line 925, in _wrap_and_launch
        return to_run(*args, **kwargs)
      File "C:\Users\jmes1\AppData\Local\Programs\Python\Python311\Lib\site-packages\lightning\fabric\fabric.py", line 930, in _wrap_with_setup
        return to_run(*args, **kwargs)
      File "D:\AItest\lit-gpt-moellama\pretrain\openwebtext.py", line 109, in main
        train(fabric, state, train_dataloader, val_dataloader)
      File "D:\AItest\lit-gpt-moellama\pretrain\openwebtext.py", line 131, in train
        measured_flops = measure_flops(meta_model, forward_fn, loss_fn)
      File "C:\Users\jmes1\AppData\Local\Programs\Python\Python311\Lib\site-packages\lightning\fabric\utilities\throughput.py", line 304, in measure_flops
        loss_fn(forward_fn()).backward()
      File "D:\AItest\lit-gpt-moellama\pretrain\openwebtext.py", line 129, in <lambda>
        forward_fn = lambda: meta_model(x)
      File "C:\Users\jmes1\AppData\Local\Programs\Python\Python311\Lib\site-packages\torch\nn\modules\module.py", line 1511, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "C:\Users\jmes1\AppData\Local\Programs\Python\Python311\Lib\site-packages\torch\nn\modules\module.py", line 1561, in _call_impl
        result = forward_call(*args, **kwargs)
      File "D:\AItest\lit-gpt-moellama\lit_gpt\model.py", line 93, in forward
        x = block(x, cos, sin, mask, input_pos)
      File "C:\Users\jmes1\AppData\Local\Programs\Python\Python311\Lib\site-packages\torch\nn\modules\module.py", line 1511, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "C:\Users\jmes1\AppData\Local\Programs\Python\Python311\Lib\site-packages\torch\nn\modules\module.py", line 1561, in _call_impl
        result = forward_call(*args, **kwargs)
      File "D:\AItest\lit-gpt-moellama\lit_gpt\model.py", line 168, in forward
        x = self.mlp(self.norm_2(x)) + x
      File "C:\Users\jmes1\AppData\Local\Programs\Python\Python311\Lib\site-packages\torch\nn\modules\module.py", line 1511, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "C:\Users\jmes1\AppData\Local\Programs\Python\Python311\Lib\site-packages\torch\nn\modules\module.py", line 1561, in _call_impl
        result = forward_call(*args, **kwargs)
      File "D:\AItest\lit-gpt-moellama\lit_gpt\model.py", line 317, in forward
        token_idx, expert_idx = torch.where(mask)
      File "C:\Users\jmes1\AppData\Local\Programs\Python\Python311\Lib\site-packages\torch\utils\_device.py", line 77, in __torch_function__
        return func(*args, **kwargs)
      File "C:\Users\jmes1\AppData\Local\Programs\Python\Python311\Lib\site-packages\torch\utils\flop_counter.py", line 449, in __torch_dispatch__
        out = func(*args, **kwargs)
      File "C:\Users\jmes1\AppData\Local\Programs\Python\Python311\Lib\site-packages\torch\_ops.py", line 571, in __call__
        return self._op(*args, **(kwargs or {}))
    NotImplementedError: aten::nonzero: attempted to run this operator with Meta tensors, but there was no abstract impl or Meta kernel registered. You may have run into this message while using an operator with PT2 compilation APIs (torch.compile/torch.export); in order to use this operator with those APIs you'll need to add an abstract impl. Please see the following doc for next steps: https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit

win10ogod · Jan 14 '24 11:01
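The failure comes from Fabric's `measure_flops` utility rather than from training itself: it runs the model once on meta tensors to count FLOPs, and the expert routing in LLaMAMoE calls `torch.where(mask)` (i.e. `aten::nonzero`), a data-dependent op that has no meta kernel. A hedged workaround sketch, assuming the call site in `pretrain/openwebtext.py` is the one shown in the traceback, is to skip the measurement when it is unsupported:

```python
from lightning.fabric.utilities.throughput import measure_flops

# Sketch only: fabric, meta_model, forward_fn, and loss_fn are the objects already
# defined at this point in the pretraining script. If an op cannot be traced on meta
# tensors (here aten::nonzero from the MoE routing), fall back to no measured FLOPs.
try:
    measured_flops = measure_flops(meta_model, forward_fn, loss_fn)
except NotImplementedError as err:
    fabric.print(f"Skipping FLOPs measurement: {err}")
    measured_flops = None  # downstream throughput logging may also need to handle None
```

The "Estimated TFLOPs" figure printed just before the crash comes from a separate analytic estimate, so it is still available when the measurement is skipped.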

> Have a look at this pretraining tutorial: https://github.com/Lightning-AI/lit-gpt/blob/main/tutorials/pretrain_tinyllama.md The only MoE model we support is Mixtral. You would need to replace tinyllama with it.
>
> @carmocca [...] NotImplementedError: aten::nonzero: attempted to run this operator with Meta tensors, but there was no abstract impl or Meta kernel registered. [...]

Hi @win10ogod, we encountered the same issue. Did you manage to solve it?

codeplay · Apr 17 '24 01:04

> Hi @win10ogod, we encountered the same issue. Did you manage to solve it?

I didn't resolve this issue.

win10ogod · Apr 17 '24 03:04