Replacing MoE layers in every other transformer block
Hi guys,
Instead of replacing all FFN sub-blocks with MoE layers, I am attempting to insert MoE layers only in every other transformer block (as in the GShard paper), using the `mixtral_moe.py` script.
Naively, I inserted an `if layer_idx % everyotherlayer == 0` check before the `tensor_name.replace` calls.
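To make the intended pattern concrete: with `everyotherlayer = 2`, the check should select only every other block (this is just an illustration of the pattern I am after, not part of the script, and the layer count is made up):

```python
# Illustration only: which block indices the check is supposed to turn into MoE
# blocks, assuming everyotherlayer = 2 and an 8-layer base model.
everyotherlayer = 2
num_hidden_layers = 8
moe_blocks = [i for i in range(num_hidden_layers) if i % everyotherlayer == 0]
print(moe_blocks)  # [0, 2, 4, 6] -> only these should end up with block_sparse_moe experts
```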
Please have a look at my modified version of the loop below.
However, it does not work as expected: when I inspect the created model, I still see a MoE layer in every block.
Does anyone have any idea what goes wrong with this approach?
```python
# modified from https://github.com/arcee-ai/mergekit/blob/d9e0685b60163efc0b8626838d16742f7276a98d/mergekit/scripts/mixtral_moe.py#L335
for layer_idx in range(base_cfg.num_hidden_layers):
    for weight_info in MISTRAL_INFO.layer_weights(index=layer_idx, config=base_cfg):
        tensor_name = weight_info.name
        if ".mlp." in tensor_name:
            if layer_idx % everyotherlayer == 0:  # INSERTED IF STATEMENT
                # MoE block: write every expert's MLP weights under block_sparse_moe names
                for moe_index, expert in enumerate(config.experts):
                    expert_name = tensor_name.replace(
                        ".mlp.gate_proj", f".block_sparse_moe.experts.{moe_index}.w1"
                    )
                    expert_name = expert_name.replace(
                        ".mlp.down_proj", f".block_sparse_moe.experts.{moe_index}.w2"
                    )
                    expert_name = expert_name.replace(
                        ".mlp.up_proj", f".block_sparse_moe.experts.{moe_index}.w3"
                    )
                    expert_loader = loaders.get(expert.model_ref)
                    tensor = expert_loader.get_tensor(
                        tensor_name, aliases=weight_info.aliases
                    )
                    if expert.noise_scale:
                        tensor += torch.randn_like(tensor) * expert.noise_scale
                    writer.save_tensor(
                        expert_name, tensor.to(dtype=out_dtype), clone=True
                    )
            else:
                # non-MoE block: keep the original dense .mlp.* tensor name and
                # copy the weights from the first expert model
                expert = config.experts[0]
                expert_name = tensor_name
                expert_loader = loaders.get(expert.model_ref)
                tensor = expert_loader.get_tensor(
                    tensor_name, aliases=weight_info.aliases
                )
                if expert.noise_scale:
                    tensor += torch.randn_like(tensor) * expert.noise_scale
                writer.save_tensor(
                    expert_name, tensor.to(dtype=out_dtype), clone=True
                )
            continue
        writer.save_tensor(
            tensor_name,
            base_loader.get_tensor(tensor_name, aliases=weight_info.aliases).to(
                dtype=out_dtype
            ),
        )
```
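In case it is relevant, this is roughly how I am inspecting the created model (a minimal sketch; the output directory name is just a placeholder):

```python
# Sketch of my inspection step (output path is a placeholder): load the merged
# model and print what each decoder layer actually contains.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("my-output-model")
for idx, layer in enumerate(model.model.layers):
    kind = "block_sparse_moe" if hasattr(layer, "block_sparse_moe") else "mlp"
    print(idx, kind)  # I see block_sparse_moe here for every single layer
```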