keras-nlp
[Help][BUG] `KeyError: 'lm_head.weight'` on loading llama 3.2
I'm trying to load Llama 3.2 on a TPU VM (v3-8) with this code:
import keras
import keras_nlp
device_mesh = keras.distribution.DeviceMesh((1, 8), ["batch", "model"], devices=keras.distribution.list_devices())
layout_map = keras.distribution.LayoutMap(device_mesh)
layout_map["token_embedding/embeddings"] = ("model", None)
layout_map["decoder_block.*attention.*(query|key|value)/kernel"] = ("model", None, None)
layout_map["decoder_block.*attention_output/kernel"] = ("model", None, None)
layout_map["decoder_block.*ffw_gating.*/kernel"] = (None, "model")
layout_map["decoder_block.*ffw_linear/kernel"] = ("model", None)
model_parallel = keras.distribution.ModelParallel(layout_map=layout_map, batch_dim_name="batch")
keras.distribution.set_distribution(model_parallel)
model = keras_nlp.models.Llama3CausalLM.from_preset("meta-llama/Llama-3.2-3B-Instruct")
but it throws this error:
KeyError: 'lm_head.weight'
Note: I took the layout_map code from this example. I don't know whether the problem comes from layout_map or from Llama3CausalLM.
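One observation that may help narrow this down: `lm_head.weight` does not appear in any of the layout_map patterns above, so the KeyError looks more like a checkpoint key lookup than a sharding problem. The sketch below illustrates one possible cause, under the assumption (not confirmed here) that this checkpoint ties its output projection to the token embedding and therefore has no separate `lm_head.weight` entry. The function name, the `tie_word_embeddings` flag, and the `model.embed_tokens.weight` key are hypothetical stand-ins for illustration, not keras-nlp's actual loader.

```python
def resolve_output_projection_key(weights, tie_word_embeddings):
    """Return the checkpoint key to use for the output projection.

    `weights` is a hypothetical name -> tensor mapping read from a checkpoint.
    `tie_word_embeddings` is an assumed config flag meaning the output
    projection reuses the token embedding matrix.
    """
    # Untied checkpoints store the output projection under its own key.
    if "lm_head.weight" in weights:
        return "lm_head.weight"
    # Tied checkpoints (assumption) only store the embedding matrix,
    # so fall back to it instead of failing.
    if tie_word_embeddings and "model.embed_tokens.weight" in weights:
        return "model.embed_tokens.weight"
    # Neither key is usable: same failure as in the report.
    raise KeyError("lm_head.weight")


# Hypothetical checkpoint contents with tied embeddings (no lm_head entry).
tied_checkpoint = {"model.embed_tokens.weight": [[0.0, 1.0]]}
key = resolve_output_projection_key(tied_checkpoint, tie_word_embeddings=True)
```

If a lookup like this were the cause, loading the same preset without calling `keras.distribution.set_distribution` should raise the same KeyError, which would rule out layout_map.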