
Replacing MoE layers in every other transformer block

spliew opened this issue on Apr 16, 2024

Hi guys,

Instead of replacing all FFN sub-blocks with MoE layers, I am attempting to insert MoE layers only in every other transformer block (as in the GShard paper), using the mixtral_moe.py script. Naively, I inserted an `if layer_idx % everyotherlayer == 0` statement before the `tensor_name.replace` calls; please have a look at the code below. However, it does not work as expected: upon inspecting the created model, I still see an MoE layer in every block.

Does anyone have any idea what is going wrong with this approach?

```python
# modified from https://github.com/arcee-ai/mergekit/blob/d9e0685b60163efc0b8626838d16742f7276a98d/mergekit/scripts/mixtral_moe.py#L335
    for layer_idx in range(base_cfg.num_hidden_layers):
        
        for weight_info in MISTRAL_INFO.layer_weights(index=layer_idx, config=base_cfg):
            tensor_name = weight_info.name

            if ".mlp." in tensor_name:
                if layer_idx % everyotherlayer == 0:  # INSERTED IF STATEMENT (everyotherlayer is set elsewhere, e.g. 2)
                    # MoE block: write each expert's FFN weights under the
                    # Mixtral block_sparse_moe parameter names
                    for moe_index, expert in enumerate(config.experts):
                        expert_name = tensor_name.replace(
                            ".mlp.gate_proj", f".block_sparse_moe.experts.{moe_index}.w1"
                        )
                        expert_name = expert_name.replace(
                            ".mlp.down_proj", f".block_sparse_moe.experts.{moe_index}.w2"
                        )
                        expert_name = expert_name.replace(
                            ".mlp.up_proj", f".block_sparse_moe.experts.{moe_index}.w3"
                        )
                        expert_loader = loaders.get(expert.model_ref)
                        tensor = expert_loader.get_tensor(
                            tensor_name, aliases=weight_info.aliases
                        )
                        if expert.noise_scale:
                            tensor += torch.randn_like(tensor) * expert.noise_scale
                        writer.save_tensor(
                            expert_name, tensor.to(dtype=out_dtype), clone=True
                        )
                else:
                    # Dense block: keep the FFN weights under their original
                    # ".mlp.*" names, copied from the first expert
                    expert = config.experts[0]
                    expert_loader = loaders.get(expert.model_ref)
                    tensor = expert_loader.get_tensor(
                        tensor_name, aliases=weight_info.aliases
                    )
                    if expert.noise_scale:
                        tensor += torch.randn_like(tensor) * expert.noise_scale
                    writer.save_tensor(
                        tensor_name, tensor.to(dtype=out_dtype), clone=True
                    )
                
                continue
                
            # everything else in this layer (attention and norm weights) is
            # copied straight from the base model
            writer.save_tensor(
                tensor_name,
                base_loader.get_tensor(tensor_name, aliases=weight_info.aliases).to(
                    dtype=out_dtype
                ),
            )
```
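
For reference, this is roughly how I inspect the produced checkpoint: I list the saved tensor names and check whether each layer ended up with `.mlp.*` (dense) or `.block_sparse_moe.*` (MoE) weights. The output directory below is just a placeholder for whatever path the script wrote to:

```python
from pathlib import Path
from safetensors import safe_open

out_dir = Path("./merged-model")  # placeholder: the script's output directory
for shard in sorted(out_dir.glob("*.safetensors")):
    with safe_open(str(shard), framework="pt") as f:
        for name in f.keys():
            # dense FFN weights keep ".mlp.", MoE experts use ".block_sparse_moe."
            if ".mlp." in name or ".block_sparse_moe." in name:
                print(name)
```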
