About adapting Tutel in fairseq
Hi,
Thanks for the nice project.
I've been trying to integrate Tutel into an older version of fairseq (which supports neither MoE nor Tutel). I modified the source code following https://github.com/microsoft/Tutel/blob/main/tutel/examples/fairseq_moe/fairseq_patch.diff, but without using system.cache() to read the loss.
And in my model, I chose:

    self.moe = moe_layer(
        gate_type={'type': 'top', 'k': 2, 'capacity_factor': 0.0, 'fp32_gate': True, 'gate_noise': 1.0},
        model_dim=self.embed_dim,
        experts={
            'num_experts_per_device': num_experts_per_device,
            'type': 'ffn',
            'hidden_size_per_expert': args.decoder_ffn_embed_dim,
            'activation_fn': lambda x: self.activation_fn(x)
        },
        scan_expert_func = lambda name, param: setattr(param, 'expert', True),  # The mask is only compatible with Fairseq based on legacy_ddp
        parallel_type = "data",
    )
However, I encountered this:

    FloatingPointError: Fatal error: gradients are inconsistent between workers. Try --ddp-backend=legacy_ddp. Or are you mixing up different generation of GPUs in training?
    grad_norm across the workers:
    rank 0 = 120.00565186
    rank 1 = 115.86860352
I am already using the legacy_ddp backend, but the issue persists. I have also run into it with other open-source MoE code. Could you help me here?
Many thanks, Z
May I ask the reason for removing system.cache()?
By the way, the patch was intended for a very old Fairseq checkpoint. Since it is impractical to keep patches up-to-date with all frequently-changing projects (e.g. Fairseq, Megatron), we will soon provide native end-to-end (E2E) training support within Tutel only. Please stay tuned for updates from Tutel. Thank you!
@ghostplant
Thanks for the quick response. I removed system.cache() because I added aux_loss to the transformer layer's return value (return x, attn, aux_loss) and stacked the aux_loss from each layer in the transformer decoder. Is this the problem?
Removing the cache makes it impossible to retrieve the balance losses generated during the forward pass when the loss is calculated. As a result, training may become increasingly imbalanced, which could eventually cause it to fail.
However, the modification of system.cache and the inconsistency between the above-mentioned ranks are two unrelated, independent issues. Regarding the inconsistency you mentioned earlier between rank 0 and rank 1, the expected phenomena are as follows:
(1) If the parameter is a shared parameter (i.e., does NOT have the skip_allreduce flag), then different ranks should have exactly the same values. If your grad_norm discrepancy occurs within this set of parameters, it indicates that there is an issue somewhere.

(2) If the parameter is a non-shared parameter (i.e., it is marked with the skip_allreduce flag), then different ranks should NOT have the same values (in other words, if the grad_norms are identical, that would signal a problem). If your grad_norm discrepancy occurs within this set of parameters, it is actually expected behavior.
For example, an expected grad_norm situation would look like this:
    <gpu-1>                <gpu-2>
    qkv_prog: 1.2          qkv_prog: 1.2          (expected to be identical)
    gate.wg: -0.6          gate.wg: -0.6          (expected to be identical)
    batched_fc1_w: -3.2    batched_fc1_w: 0.71    (expected to be different)
    batched_fc2_w: 2.6     batched_fc2_w: 1.77    (expected to be different)
So, could you further confirm whether the grad_norm value gap you mentioned above happens to shared parameters or non-shared parameters?
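To make the two rules above concrete, here is a minimal sketch of a per-parameter check. The function name `check_grad_norms`, the dictionary layout, and the tolerance are illustrative assumptions and not part of Tutel or fairseq. It only assumes you have already collected one grad_norm per rank for each parameter, and that you know which parameters carry the skip_allreduce flag.

```python
def check_grad_norms(norms_by_rank, skip_allreduce, tol=1e-6):
    """Return (name, reason) pairs for parameters that break the rules.

    norms_by_rank:  {param_name: [norm on rank 0, norm on rank 1, ...]}
    skip_allreduce: {param_name: True if the parameter is non-shared}
    Both structures are hypothetical; collect them however suits your setup.
    """
    suspicious = []
    for name, norms in norms_by_rank.items():
        identical = max(norms) - min(norms) <= tol
        if skip_allreduce.get(name, False):
            # Non-shared (expert) parameters should differ across ranks.
            if identical:
                suspicious.append((name, "non-shared but identical"))
        else:
            # Shared parameters should match on every rank.
            if not identical:
                suspicious.append((name, "shared but different"))
    return suspicious


# Values taken from the example table above.
norms = {
    "qkv_prog": [1.2, 1.2],
    "gate.wg": [-0.6, -0.6],
    "batched_fc1_w": [-3.2, 0.71],
    "batched_fc2_w": [2.6, 1.77],
}
flags = {"batched_fc1_w": True, "batched_fc2_w": True}
print(check_grad_norms(norms, flags))  # prints []
```

An empty list means every parameter behaves as expected. Any entry reported as "shared but different" points at the kind of inconsistency that the fairseq error is complaining about.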
Hi, thanks for the very detailed explanation! I really appreciate your help.
I think I have made it work by setting 'skip_allreduce' to True and passing inequivalent_tokens=True in forward (due to the nature of my data, the number of samples per GPU is not always the same).
For the non-shared parameters (i.e., those marked with the skip_allreduce flag), the weights and grad_norms are now different across GPUs, as expected.
However, I am still confused about parallel_type.
I set parallel_type = "data" , and basically had num_experts_per_device = int(args.moe_expert_num / num_gpus). In my current case, moe_expert_num=8, num_gpus=2, so num_experts_per_device=4.
In my understanding, if parallel_type is 'data', all experts are replicated across GPUs, but that does not seem to be the case here: self.num_global_experts is 8 and self.num_local_experts is 4. Am I understanding it incorrectly?
No. With parallel_type = "data", non-shared parameters are kept in ZeRO-2 style, so they are still unique and have independent gradients.
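As a small worked example of the numbers in this exchange, the sketch below reproduces the expert counts reported above. The helper `expert_layout` is hypothetical, and the contiguous assignment of global expert ids to ranks is an assumption made only for illustration; the thread does not confirm how ids are laid out.

```python
def expert_layout(moe_expert_num, num_gpus):
    # Same computation as in the question:
    # num_experts_per_device = int(moe_expert_num / num_gpus)
    num_local_experts = moe_expert_num // num_gpus
    # Consistent with the reported values (4 local, 8 global on 2 GPUs).
    num_global_experts = num_local_experts * num_gpus
    # ASSUMPTION: each rank owns a contiguous slice of global expert ids.
    owners = {
        rank: list(range(rank * num_local_experts, (rank + 1) * num_local_experts))
        for rank in range(num_gpus)
    }
    return num_local_experts, num_global_experts, owners


local, total, owners = expert_layout(moe_expert_num=8, num_gpus=2)
print(local, total)  # prints 4 8
print(owners)        # prints {0: [0, 1, 2, 3], 1: [4, 5, 6, 7]}
```

Under this reading, each GPU holds its own 4 experts rather than a replica of all 8, which matches the observation that the expert weights and gradients differ between ranks.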
Can I also ask about saving checkpoints when training on multiple GPUs? I'm checking https://github.com/microsoft/Tutel/blob/main/doc/CHECKPOINT.md and also some issues. It seems that a checkpoint needs to be saved for each rank, and that these need to be converted into a single checkpoint for inference on 1 GPU.
However, saving the checkpoint on different ranks does not seem to be straightforward in fairseq. Would you point out some directions for me? I understand it's impractical to keep up with all the different fairseq versions, but I would appreciate it if you could give me some idea.
Many thanks Z
Hi, unless you want to change the training GPU environments, you don't really need to do the conversion.
Assume your model has 20GB of shared parameters and 800GB of non-shared parameters, and you use 8 GPUs for the whole training.
Then each GPU will produce a checkpoint of size 20GB + (800GB / 8) via torch.save(model.state_dict(), checkpoint_path), so 8-GPU training will produce 8 checkpoint files in total. Since one host may have more than one GPU, make sure these 8 checkpoint_path values are distinguished by rank so that the checkpoints do not overwrite each other.
When loading the model, no checkpoint conversion is needed if the GPU count is unchanged (see https://github.com/microsoft/Tutel/blob/main/tutel/examples/helloworld.py#L104C30-L104C58).
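A minimal sketch of the per-rank naming and the size arithmetic described above follows. The helper names and the `checkpoint.rank{N}.pt` pattern are hypothetical choices for illustration, not a Tutel or fairseq convention.

```python
import os


def rank_checkpoint_path(save_dir, rank, prefix="checkpoint"):
    # Hypothetical naming scheme: one file per global rank, so GPUs that
    # share a host never write to the same path.
    return os.path.join(save_dir, f"{prefix}.rank{rank}.pt")


def per_rank_size_gb(shared_gb, non_shared_gb, world_size):
    # Every rank stores all shared parameters plus its own slice of the
    # non-shared (expert) parameters.
    return shared_gb + non_shared_gb / world_size


paths = [rank_checkpoint_path("checkpoints", r) for r in range(8)]
assert len(set(paths)) == 8  # 8 distinct files, no overwrites

print(per_rank_size_gb(20, 800, 8))  # prints 120.0
```

In a training script, each process would then call torch.save(model.state_dict(), rank_checkpoint_path(save_dir, rank)) with its own global rank (for example from torch.distributed.get_rank()), and load from the same path when resuming on the same number of GPUs.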
Thanks loads for the help!!! I will try to figure out how to save checkpoints for different ranks in fairseq.
Hi, I'm reopening this issue because I encountered some problems:
As you mentioned before
    <gpu-1>                <gpu-2>
    qkv_prog: 1.2          qkv_prog: 1.2          (expected to be identical)
    gate.wg: -0.6          gate.wg: -0.6          (expected to be identical)
    batched_fc1_w: -3.2    batched_fc1_w: 0.71    (expected to be different)
    batched_fc2_w: 2.6     batched_fc2_w: 1.77    (expected to be different)
I found that the weights and grad_norm of batched_fc1_w and batched_fc2_w are identical across the two ranks, even when parallel_type = 'model'.
I have been trying to debug this for a long time, but I am still clueless.
Many thanks, Z
Hello, we found that the fairseq_moe instructions are too old, official fairseq is no longer maintained, and the dataset link no longer works, so we are going to remove this example in the next PR. Instead, we have added a new GPT end-to-end training example, which is much simpler: https://github.com/microsoft/Tutel/tree/main/tutel/examples/modded-nanogpt-moe
Since Tutel will maintain its own end-to-end training based on that, I strongly suggest not sticking to the fairseq patch. If it is a must, there are 2 ways to solve your issue within the fairseq repo.
(1) In your case, are the values of the expert parameters identical even before the first training step? If so, you need to prevent expert parameters from being broadcast to all GPUs, following: https://github.com/microsoft/Tutel/blob/main/tutel/examples/modded-nanogpt-moe/train_gpt_v0.py#L378-L379. fairseq seems to broadcast all parameter values regardless of whether they are non-shared, which makes them identical before training starts.
(2) If (1) isn't the case, i.e. the values of the expert parameters become identical after a training step, it indicates that all-reduce is improperly applied to all parameters, whether they are shared or non-shared. In this case, make sure to skip all_reduce for all non-shared parameters, following: https://github.com/microsoft/Tutel/blob/main/tutel/examples/modded-nanogpt-moe/train_gpt_v0.py#L523
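Both fixes come down to separating shared from non-shared parameters before any collective call. The sketch below shows that filtering step only; it does not reproduce the linked lines. The function `split_by_sharing` is a hypothetical helper, and `SimpleNamespace` objects stand in for real torch parameters so that it runs without a GPU.

```python
from types import SimpleNamespace


def split_by_sharing(named_params, flag="skip_allreduce"):
    """Split (name, param) pairs into shared and non-shared name lists.

    A parameter counts as non-shared when it carries a truthy attribute
    named `flag`. Pass flag="expert" if that is the attribute your
    scan_expert_func sets instead.
    """
    shared, non_shared = [], []
    for name, param in named_params:
        if getattr(param, flag, False):
            non_shared.append(name)
        else:
            shared.append(name)
    return shared, non_shared


# Stand-ins for torch parameters (assumption: real ones come from
# model.named_parameters()).
params = [
    ("gate.wg", SimpleNamespace()),
    ("batched_fc1_w", SimpleNamespace(skip_allreduce=True)),
    ("batched_fc2_w", SimpleNamespace(skip_allreduce=True)),
]
shared, non_shared = split_by_sharing(params)
print(shared)      # prints ['gate.wg']
print(non_shared)  # prints ['batched_fc1_w', 'batched_fc2_w']
```

In a real training loop, only the shared group would be passed to torch.distributed.broadcast at initialization and to torch.distributed.all_reduce after the backward pass, while the non-shared group is left untouched on each rank.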