keras-nlp

[Help][BUG] `KeyError: 'lm_head.weight'` when loading Llama 3.2

Open · steveepreston opened this issue 1 year ago · 4 comments

I'm trying to load Llama 3.2 on a TPU VM v3-8 with the following code:

import keras
import keras_nlp

# 1 x 8 mesh over the 8 TPU cores: batch axis of size 1, 8-way model parallelism.
device_mesh = keras.distribution.DeviceMesh((1, 8), ["batch", "model"], devices=keras.distribution.list_devices())
layout_map = keras.distribution.LayoutMap(device_mesh)
# Keys are regex patterns on variable paths; values say which weight axis is split over "model".
layout_map["token_embedding/embeddings"] = ("model", None)
layout_map["decoder_block.*attention.*(query|key|value)/kernel"] = ("model", None, None)
layout_map["decoder_block.*attention_output/kernel"] = ("model", None, None)
layout_map["decoder_block.*ffw_gating.*/kernel"] = (None, "model")
layout_map["decoder_block.*ffw_linear/kernel"] = ("model", None)
model_parallel = keras.distribution.ModelParallel(layout_map=layout_map, batch_dim_name="batch")
keras.distribution.set_distribution(model_parallel)

# Load the checkpoint with the distribution above already active.
model = keras_nlp.models.Llama3CausalLM.from_preset("meta-llama/Llama-3.2-3B-Instruct")

but it throws this error:

KeyError: 'lm_head.weight'
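One plausible reading of this error, offered as an assumption rather than a confirmed diagnosis: the 1B and 3B Llama 3.2 checkpoints on Hugging Face set `tie_word_embeddings` in `config.json`, so the output projection reuses `model.embed_tokens.weight` and no separate `lm_head.weight` tensor is stored. A converter that always looks up `lm_head.weight` would then fail with exactly this `KeyError`. The minimal sketch below uses a plain dict of NumPy arrays as a stand-in for the checkpoint and a hypothetical helper `get_lm_head`; it is not the keras-nlp converter code.

```python
import numpy as np


def get_lm_head(weights, config):
    """Hypothetical helper: return the output projection as (hidden, vocab).

    `weights` maps Hugging Face tensor names to arrays and `config` is the
    parsed config.json. This is an illustration, not the keras-nlp loader.
    """
    if "lm_head.weight" in weights:
        # Untied checkpoint: lm_head.weight is stored as (vocab, hidden).
        return weights["lm_head.weight"].T
    if config.get("tie_word_embeddings", False):
        # Tied checkpoint: reuse the input embedding, also (vocab, hidden).
        return weights["model.embed_tokens.weight"].T
    # Neither case applies: raise the same error a strict loader would.
    raise KeyError("lm_head.weight")


vocab, hidden = 10, 4
embed = np.arange(vocab * hidden, dtype=np.float32).reshape(vocab, hidden)
tied = {"model.embed_tokens.weight": embed}

head = get_lm_head(tied, {"tie_word_embeddings": True})
print(head.shape)  # (4, 10)

try:
    get_lm_head(tied, {"tie_word_embeddings": False})
except KeyError as err:
    print("KeyError:", err)  # KeyError: 'lm_head.weight'
```

If this is the cause, a loader that falls back to the embedding matrix whenever the config declares tied weights avoids the failed lookup, while untied checkpoints keep using their own `lm_head.weight`.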

Note: I took the `layout_map` code from this example. I don't know whether the problem comes from the `layout_map` or from `Llama3CausalLM`.
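To check whether the `layout_map` patterns actually reach the model's weights, it helps to know how the lookup works: as far as I understand Keras 3, `keras.distribution.LayoutMap` first tries an exact key match on the variable path, then a regex search, and a variable matching no pattern is left replicated rather than raising. The sketch below roughly mimics that lookup in plain Python (a simplified stand-in, not the Keras implementation). The patterns are copied from the code above; the variable paths, including `transformer_layer_0/self_attention/query/kernel`, are hypothetical examples.

```python
import re

# Patterns copied from the layout_map in the report above.
LAYOUT_PATTERNS = {
    "token_embedding/embeddings": ("model", None),
    r"decoder_block.*attention.*(query|key|value)/kernel": ("model", None, None),
    r"decoder_block.*attention_output/kernel": ("model", None, None),
    r"decoder_block.*ffw_gating.*/kernel": (None, "model"),
    r"decoder_block.*ffw_linear/kernel": ("model", None),
}


def lookup_layout(patterns, path):
    """Simplified stand-in for LayoutMap lookup: exact key, then re.search."""
    if path in patterns:
        return patterns[path]
    matches = [p for p in patterns if re.search(p, path)]
    if len(matches) > 1:
        raise ValueError(f"{path!r} matches several patterns: {matches}")
    # None means no sharding rule applies, so the variable stays replicated.
    return patterns[matches[0]] if matches else None


# Hypothetical variable paths; replace them with the model's real paths.
paths = [
    "token_embedding/embeddings",
    "decoder_block_0/attention/query/kernel",
    "transformer_layer_0/self_attention/query/kernel",
]
for path in paths:
    print(path, "->", lookup_layout(LAYOUT_PATTERNS, path))
```

The third path prints `None`: if the Llama layers are not named `decoder_block...`, those patterns silently match nothing and the weights stay unsharded (a memory concern), but in this sketch an unmatched path never raises, so a pattern mismatch alone would not explain a `KeyError` on a checkpoint tensor name. To compare against the real model, the variable paths can be listed with `[w.path for w in model.weights]` on a model built without a distribution set.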

steveepreston, Oct 13 '24 13:10