
BERT Safetensors variable mismatch

Open Christof23 opened this issue 1 year ago • 2 comments

Hi, I was running the BERT example code and noticed that some of the variable names don't align with the current safetensors checkpoint obtained via:

let repo: ApiRepo = api.model("bert-base-uncased".to_string());
let weights_path: PathBuf = repo.get("model.safetensors")?;

For example, the model spec in candle-transformers/src/models/bert.rs results in: Error: TensorNotFound("embeddings.word_embeddings.weight").

The safetensors checkpoint prefixes all variable names with bert. and uses the older gamma/beta naming for the LayerNorm parameters. This issue has also been noted here.
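To illustrate the mismatch, here is a hypothetical helper (a sketch, not the upstream fix) that rewrites a checkpoint tensor name into the form the model looks up, assuming the only differences are the bert. prefix and the legacy gamma/beta naming:

```rust
/// Sketch: map a checkpoint tensor name to the name the model expects.
/// Assumes the only differences are the "bert." prefix and the legacy
/// gamma/beta LayerNorm parameter names.
fn normalize_name(name: &str) -> String {
    // Strip the "bert." prefix if present, leave other names untouched.
    let name = name.strip_prefix("bert.").unwrap_or(name);
    // Rename the legacy LayerNorm parameters to weight/bias.
    name.replace(".gamma", ".weight").replace(".beta", ".bias")
}

fn main() {
    println!("{}", normalize_name("bert.embeddings.LayerNorm.gamma"));
    println!("{}", normalize_name("bert.encoder.layer.0.output.dense.bias"));
}
```

For example, bert.embeddings.LayerNorm.gamma would become embeddings.LayerNorm.weight, which is what the TensorNotFound error above suggests the model is looking for.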

I think the problem is in layer_norm, which expects weight and bias rather than gamma and beta:

let weight = vb.get_with_hints(size, "weight", crate::Init::Const(1.))?;
let bias = if config.affine {
    Some(vb.get_with_hints(size, "bias", crate::Init::Const(0.))?)
} else {
    None
};
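One way to handle both naming conventions is to try the modern weight/bias names first and fall back to the legacy gamma/beta names. Below is a minimal, self-contained sketch of that fallback idea, using a plain HashMap as a toy stand-in for the tensor store; the real candle VarBuilder API differs, so this only illustrates the lookup logic:

```rust
use std::collections::HashMap;

/// Toy sketch of a fallback lookup: prefer "weight"/"bias", fall back
/// to the legacy "gamma"/"beta" names. The HashMap stands in for the
/// checkpoint's tensor store; this is not the actual candle API.
fn layer_norm_params<'a>(
    tensors: &'a HashMap<String, Vec<f32>>,
    prefix: &str,
) -> Option<(&'a [f32], &'a [f32])> {
    let weight = tensors
        .get(&format!("{prefix}.weight"))
        .or_else(|| tensors.get(&format!("{prefix}.gamma")))?;
    let bias = tensors
        .get(&format!("{prefix}.bias"))
        .or_else(|| tensors.get(&format!("{prefix}.beta")))?;
    Some((weight.as_slice(), bias.as_slice()))
}

fn main() {
    let mut tensors = HashMap::new();
    tensors.insert("bert.embeddings.LayerNorm.gamma".to_string(), vec![1.0f32; 4]);
    tensors.insert("bert.embeddings.LayerNorm.beta".to_string(), vec![0.0f32; 4]);
    // Resolves via the gamma/beta fallback since weight/bias are absent.
    let (w, b) = layer_norm_params(&tensors, "bert.embeddings.LayerNorm").unwrap();
    println!("weight len = {}, bias len = {}", w.len(), b.len());
}
```

A fallback like this keeps newer checkpoints (which already use weight/bias) working while also accepting older exports.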

The safetensors variable names are as follows:

bert.embeddings.LayerNorm.beta
bert.embeddings.LayerNorm.gamma
bert.embeddings.position_embeddings.weight
bert.embeddings.token_type_embeddings.weight
bert.embeddings.word_embeddings.weight
bert.encoder.layer.0.attention.output.LayerNorm.beta
bert.encoder.layer.0.attention.output.LayerNorm.gamma
bert.encoder.layer.0.attention.output.dense.bias
bert.encoder.layer.0.attention.output.dense.weight
bert.encoder.layer.0.attention.self.key.bias
bert.encoder.layer.0.attention.self.key.weight
bert.encoder.layer.0.attention.self.query.bias
bert.encoder.layer.0.attention.self.query.weight
bert.encoder.layer.0.attention.self.value.bias
bert.encoder.layer.0.attention.self.value.weight
bert.encoder.layer.0.intermediate.dense.bias
bert.encoder.layer.0.intermediate.dense.weight
bert.encoder.layer.0.output.LayerNorm.beta
bert.encoder.layer.0.output.LayerNorm.gamma
bert.encoder.layer.0.output.dense.bias
bert.encoder.layer.0.output.dense.weight
bert.encoder.layer.1.attention.output.LayerNorm.beta
bert.encoder.layer.1.attention.output.LayerNorm.gamma
bert.encoder.layer.1.attention.output.dense.bias
bert.encoder.layer.1.attention.output.dense.weight
bert.encoder.layer.1.attention.self.key.bias
bert.encoder.layer.1.attention.self.key.weight
bert.encoder.layer.1.attention.self.query.bias
bert.encoder.layer.1.attention.self.query.weight
bert.encoder.layer.1.attention.self.value.bias
bert.encoder.layer.1.attention.self.value.weight
bert.encoder.layer.1.intermediate.dense.bias
bert.encoder.layer.1.intermediate.dense.weight
bert.encoder.layer.1.output.LayerNorm.beta
bert.encoder.layer.1.output.LayerNorm.gamma
bert.encoder.layer.1.output.dense.bias
bert.encoder.layer.1.output.dense.weight
bert.encoder.layer.10.attention.output.LayerNorm.beta
bert.encoder.layer.10.attention.output.LayerNorm.gamma
bert.encoder.layer.10.attention.output.dense.bias
bert.encoder.layer.10.attention.output.dense.weight
bert.encoder.layer.10.attention.self.key.bias
bert.encoder.layer.10.attention.self.key.weight
bert.encoder.layer.10.attention.self.query.bias
bert.encoder.layer.10.attention.self.query.weight
bert.encoder.layer.10.attention.self.value.bias
bert.encoder.layer.10.attention.self.value.weight
bert.encoder.layer.10.intermediate.dense.bias
bert.encoder.layer.10.intermediate.dense.weight
bert.encoder.layer.10.output.LayerNorm.beta
bert.encoder.layer.10.output.LayerNorm.gamma
bert.encoder.layer.10.output.dense.bias
bert.encoder.layer.10.output.dense.weight
bert.encoder.layer.11.attention.output.LayerNorm.beta
bert.encoder.layer.11.attention.output.LayerNorm.gamma
bert.encoder.layer.11.attention.output.dense.bias
bert.encoder.layer.11.attention.output.dense.weight
bert.encoder.layer.11.attention.self.key.bias
bert.encoder.layer.11.attention.self.key.weight
bert.encoder.layer.11.attention.self.query.bias
bert.encoder.layer.11.attention.self.query.weight
bert.encoder.layer.11.attention.self.value.bias
bert.encoder.layer.11.attention.self.value.weight
bert.encoder.layer.11.intermediate.dense.bias
bert.encoder.layer.11.intermediate.dense.weight
bert.encoder.layer.11.output.LayerNorm.beta
bert.encoder.layer.11.output.LayerNorm.gamma
bert.encoder.layer.11.output.dense.bias
bert.encoder.layer.11.output.dense.weight
bert.encoder.layer.2.attention.output.LayerNorm.beta
bert.encoder.layer.2.attention.output.LayerNorm.gamma
bert.encoder.layer.2.attention.output.dense.bias
bert.encoder.layer.2.attention.output.dense.weight
bert.encoder.layer.2.attention.self.key.bias
bert.encoder.layer.2.attention.self.key.weight
bert.encoder.layer.2.attention.self.query.bias
bert.encoder.layer.2.attention.self.query.weight
bert.encoder.layer.2.attention.self.value.bias
bert.encoder.layer.2.attention.self.value.weight
bert.encoder.layer.2.intermediate.dense.bias
bert.encoder.layer.2.intermediate.dense.weight
bert.encoder.layer.2.output.LayerNorm.beta
bert.encoder.layer.2.output.LayerNorm.gamma
bert.encoder.layer.2.output.dense.bias
bert.encoder.layer.2.output.dense.weight
bert.encoder.layer.3.attention.output.LayerNorm.beta
bert.encoder.layer.3.attention.output.LayerNorm.gamma
bert.encoder.layer.3.attention.output.dense.bias
bert.encoder.layer.3.attention.output.dense.weight
bert.encoder.layer.3.attention.self.key.bias
bert.encoder.layer.3.attention.self.key.weight
bert.encoder.layer.3.attention.self.query.bias
bert.encoder.layer.3.attention.self.query.weight
bert.encoder.layer.3.attention.self.value.bias
bert.encoder.layer.3.attention.self.value.weight
bert.encoder.layer.3.intermediate.dense.bias
bert.encoder.layer.3.intermediate.dense.weight
bert.encoder.layer.3.output.LayerNorm.beta
bert.encoder.layer.3.output.LayerNorm.gamma
bert.encoder.layer.3.output.dense.bias
bert.encoder.layer.3.output.dense.weight
bert.encoder.layer.4.attention.output.LayerNorm.beta
bert.encoder.layer.4.attention.output.LayerNorm.gamma
bert.encoder.layer.4.attention.output.dense.bias
bert.encoder.layer.4.attention.output.dense.weight
bert.encoder.layer.4.attention.self.key.bias
bert.encoder.layer.4.attention.self.key.weight
bert.encoder.layer.4.attention.self.query.bias
bert.encoder.layer.4.attention.self.query.weight
bert.encoder.layer.4.attention.self.value.bias
bert.encoder.layer.4.attention.self.value.weight
bert.encoder.layer.4.intermediate.dense.bias
bert.encoder.layer.4.intermediate.dense.weight
bert.encoder.layer.4.output.LayerNorm.beta
bert.encoder.layer.4.output.LayerNorm.gamma
bert.encoder.layer.4.output.dense.bias
bert.encoder.layer.4.output.dense.weight
bert.encoder.layer.5.attention.output.LayerNorm.beta
bert.encoder.layer.5.attention.output.LayerNorm.gamma
bert.encoder.layer.5.attention.output.dense.bias
bert.encoder.layer.5.attention.output.dense.weight
bert.encoder.layer.5.attention.self.key.bias
bert.encoder.layer.5.attention.self.key.weight
bert.encoder.layer.5.attention.self.query.bias
bert.encoder.layer.5.attention.self.query.weight
bert.encoder.layer.5.attention.self.value.bias
bert.encoder.layer.5.attention.self.value.weight
bert.encoder.layer.5.intermediate.dense.bias
bert.encoder.layer.5.intermediate.dense.weight
bert.encoder.layer.5.output.LayerNorm.beta
bert.encoder.layer.5.output.LayerNorm.gamma
bert.encoder.layer.5.output.dense.bias
bert.encoder.layer.5.output.dense.weight
bert.encoder.layer.6.attention.output.LayerNorm.beta
bert.encoder.layer.6.attention.output.LayerNorm.gamma
bert.encoder.layer.6.attention.output.dense.bias
bert.encoder.layer.6.attention.output.dense.weight
bert.encoder.layer.6.attention.self.key.bias
bert.encoder.layer.6.attention.self.key.weight
bert.encoder.layer.6.attention.self.query.bias
bert.encoder.layer.6.attention.self.query.weight
bert.encoder.layer.6.attention.self.value.bias
bert.encoder.layer.6.attention.self.value.weight
bert.encoder.layer.6.intermediate.dense.bias
bert.encoder.layer.6.intermediate.dense.weight
bert.encoder.layer.6.output.LayerNorm.beta
bert.encoder.layer.6.output.LayerNorm.gamma
bert.encoder.layer.6.output.dense.bias
bert.encoder.layer.6.output.dense.weight
bert.encoder.layer.7.attention.output.LayerNorm.beta
bert.encoder.layer.7.attention.output.LayerNorm.gamma
bert.encoder.layer.7.attention.output.dense.bias
bert.encoder.layer.7.attention.output.dense.weight
bert.encoder.layer.7.attention.self.key.bias
bert.encoder.layer.7.attention.self.key.weight
bert.encoder.layer.7.attention.self.query.bias
bert.encoder.layer.7.attention.self.query.weight
bert.encoder.layer.7.attention.self.value.bias
bert.encoder.layer.7.attention.self.value.weight
bert.encoder.layer.7.intermediate.dense.bias
bert.encoder.layer.7.intermediate.dense.weight
bert.encoder.layer.7.output.LayerNorm.beta
bert.encoder.layer.7.output.LayerNorm.gamma
bert.encoder.layer.7.output.dense.bias
bert.encoder.layer.7.output.dense.weight
bert.encoder.layer.8.attention.output.LayerNorm.beta
bert.encoder.layer.8.attention.output.LayerNorm.gamma
bert.encoder.layer.8.attention.output.dense.bias
bert.encoder.layer.8.attention.output.dense.weight
bert.encoder.layer.8.attention.self.key.bias
bert.encoder.layer.8.attention.self.key.weight
bert.encoder.layer.8.attention.self.query.bias
bert.encoder.layer.8.attention.self.query.weight
bert.encoder.layer.8.attention.self.value.bias
bert.encoder.layer.8.attention.self.value.weight
bert.encoder.layer.8.intermediate.dense.bias
bert.encoder.layer.8.intermediate.dense.weight
bert.encoder.layer.8.output.LayerNorm.beta
bert.encoder.layer.8.output.LayerNorm.gamma
bert.encoder.layer.8.output.dense.bias
bert.encoder.layer.8.output.dense.weight
bert.encoder.layer.9.attention.output.LayerNorm.beta
bert.encoder.layer.9.attention.output.LayerNorm.gamma
bert.encoder.layer.9.attention.output.dense.bias
bert.encoder.layer.9.attention.output.dense.weight
bert.encoder.layer.9.attention.self.key.bias
bert.encoder.layer.9.attention.self.key.weight
bert.encoder.layer.9.attention.self.query.bias
bert.encoder.layer.9.attention.self.query.weight
bert.encoder.layer.9.attention.self.value.bias
bert.encoder.layer.9.attention.self.value.weight
bert.encoder.layer.9.intermediate.dense.bias
bert.encoder.layer.9.intermediate.dense.weight
bert.encoder.layer.9.output.LayerNorm.beta
bert.encoder.layer.9.output.LayerNorm.gamma
bert.encoder.layer.9.output.dense.bias
bert.encoder.layer.9.output.dense.weight
bert.pooler.dense.bias
bert.pooler.dense.weight
cls.predictions.bias
cls.predictions.transform.LayerNorm.beta
cls.predictions.transform.LayerNorm.gamma
cls.predictions.transform.dense.bias
cls.predictions.transform.dense.weight
cls.seq_relationship.bias
cls.seq_relationship.weight

Christof23 avatar Mar 20 '24 14:03 Christof23

This PR addresses the above issue, but I'm not sure whether the updates to layer_norm are appropriate: https://github.com/huggingface/candle/pull/1888

Christof23 avatar Mar 20 '24 15:03 Christof23

I'm running into the same issue. I followed the instructions in the Candle reference guide for running a Hugging Face model in Candle, and was surprised that the recommended steps (loading the bert-base-uncased model into the BertModel struct) result in an error.

vrama628 avatar Apr 06 '24 13:04 vrama628