Different Embeddings from sentence-transformers/all-MiniLM-L6-v2 compared to Python
First, thank you for the amazing work on Jlama! It's great to have native Java libraries for embeddings and LLMs.
Issue
We're getting different embedding values from Jlama compared to Python's sentence-transformers for the all-MiniLM-L6-v2 model, even though we've verified that tokenization is identical.
Java Code (Jlama)
var modelName = "sentence-transformers/all-MiniLM-L6-v2";
var workingDirectory = System.getProperty("user.home") + "/.jlama/models/";
var downloader = new Downloader(workingDirectory, modelName);
var modelPath = downloader.huggingFaceModel();
var model = ModelSupport.loadEmbeddingModel(modelPath, DType.F32, DType.F32);
String text = "This is a test document about machine learning";
float[] embedding = model.embed(text, Generator.PoolingType.AVG);
System.out.println("First 10 values:");
for (int i = 0; i < 10; i++) {
System.out.println(" [" + i + "] = " + embedding[i]);
}
Java Output:
Magnitude: 1.0000001
[0] = -0.0009431843
[1] = 0.006532612
[2] = 0.070363656
[3] = 0.0154365115
Python Code (sentence-transformers)
from sentence_transformers import SentenceTransformer
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
text = "This is a test document about machine learning"
embedding = model.encode(text)
print("First 10 values:")
for i in range(10):
print(f" [{i}] = {embedding[i]}")
Python Output:
Magnitude: 1.0
[0] = -0.038466498255729675
[1] = 0.00013165567361284047
[2] = 0.01088548544794321
[3] = 0.040931958705186844
What We've Verified
- Tokenization is identical: Both produce the same token IDs (see the sketch after this list):
  [101, 2023, 2003, 1037, 3231, 6254, 2055, 3698, 4083, 102]
- Same pooling strategy: Both use mean/average pooling (PoolingType.AVG in Java, pooling_mode_mean_tokens=True in Python)
- Same model source: Both download from HuggingFace sentence-transformers/all-MiniLM-L6-v2
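For reference, this is roughly how the token IDs can be reproduced on the Python side (a minimal sketch using the transformers AutoTokenizer, which is the tokenizer sentence-transformers uses under the hood; not necessarily the exact script we ran):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
text = "This is a test document about machine learning"

# Token IDs including the [CLS] (101) and [SEP] (102) special tokens
ids = tokenizer(text)["input_ids"]
print(ids)  # expected: [101, 2023, 2003, 1037, 3231, 6254, 2055, 3698, 4083, 102]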
The Problem
The actual embedding values are completely different (not just minor floating-point differences).
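To quantify this, one way is to compare the full 384-dimensional vectors directly rather than the first few values, e.g. with a cosine-similarity check (a sketch; jlama_vec and st_vec stand for the full embeddings produced by each library):

import numpy as np

def cosine(a, b):
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

# If the two implementations agreed up to floating-point noise, the cosine
# similarity would be ~1.0 and the max absolute difference tiny (~1e-6).
# print(cosine(jlama_vec, st_vec))
# print(np.max(np.abs(np.asarray(jlama_vec) - np.asarray(st_vec))))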
Questions
- Is the all-MiniLM-L6-v2 model fully supported/tested with Jlama?
- Are we missing any configuration or preprocessing steps?
Any guidance would be greatly appreciated!
I looked this over quickly and I have a hunch: the KV cache and the batch forwarding might be making the result non-deterministic.
public float[] embed(String input, PoolingType poolingType) {
    int[] encoded = Arrays.stream(tokenizer.encode(input)).mapToInt(Ints::checkedCast).toArray();
    Preconditions.checkArgument(encoded.length < c.contextLength);
    float[] outputEmbedding = new float[c.embeddingLength];
    try (KvBufferCache.KvBuffer kvmem = kvBufferCache.getEphemeralKvBuffer()) {
        int promptLength = encoded.length;
        float avgp = 1.0f / promptLength;
I'll play with this and see what I can figure out.
https://github.com/edwardcapriolo/deliverance/pull/13/files
I hadn't gotten around to adding embedding support to my fork, so I started on it. One thing I noticed:
try (AbstractTensor r = batchForward(encoded, 0, kvMem)) {
    if (poolingType == PoolingType.MODEL) {
        return outputEmbedding;
    }
    for (int i = 0; i < promptLength; i++) {
        AbstractTensor output = r.slice(i);
        // Pooling
        for (int ii = 0; ii < config.embeddingLength; ii++) {
            switch (poolingType) {
                case AVG:
                    outputEmbedding[ii] += output.get(0, ii) * avgp;
                    break;
                case MAX:
                    outputEmbedding[ii] = Math.max(outputEmbedding[ii], output.get(0, ii));
                    break;
                case SUM:
                    outputEmbedding[ii] += output.get(0, ii);
                    break;
            }
        }
    }
The pooling implementation looks fairly simple. I will take some time to deep-dive the Python code and try to understand what is different between them.
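For reference, my understanding of what the Python side does for this model (worth double-checking against the model's sentence-transformers module config) is roughly: run the transformer, mean-pool the last hidden state weighted by the attention mask, then L2-normalize. A sketch using plain transformers:

import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

name = "sentence-transformers/all-MiniLM-L6-v2"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModel.from_pretrained(name)

inputs = tokenizer("This is a test document about machine learning", return_tensors="pt")
with torch.no_grad():
    out = model(**inputs)

# Mean pooling over the last hidden state, weighted by the attention mask
token_embeddings = out.last_hidden_state               # (1, seq_len, 384)
mask = inputs["attention_mask"].unsqueeze(-1).float()  # (1, seq_len, 1)
pooled = (token_embeddings * mask).sum(1) / mask.sum(1).clamp(min=1e-9)

# sentence-transformers also L2-normalizes the pooled vector for this model
embedding = F.normalize(pooled, p=2, dim=1)[0]
print(embedding[:10])

If that matches model.encode(...), the things to compare on the Jlama side are which layer's output gets pooled (the final hidden state vs. something after a pooler/activation) and whether any extra transformation is applied after the mean.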
Also, this seems dubious:
// BERT seems to use tanh for pooling rather than gelu
outputEmbedding[i] = ActivationFunction.eval(ActivationFunction.Type.TANH, pooled.get(0, i));
});
return outputEmbedding;
We could derive the activation function from the config:
outputEmbedding[i] = config.activationFunction.eval(pooled.get(0, i));
I don't see why it is hardcoded here.
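For context, and worth double-checking: the tanh most likely comes from BERT's pooler head, which in Hugging Face's implementation is a dense layer followed by tanh applied only to the first ([CLS]) token, not to a mean over all tokens. A rough sketch of that module (384 is MiniLM's hidden size):

import torch
import torch.nn as nn

class BertPoolerSketch(nn.Module):
    # Rough equivalent of transformers' BertPooler: dense + tanh on the [CLS] token.
    def __init__(self, hidden_size: int = 384):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states: (batch, seq_len, hidden_size); only the first token is pooled
        first_token = hidden_states[:, 0]
        return self.activation(self.dense(first_token))

If that reading is right, the tanh belongs to the CLS-pooler path, and since this model's sentence-transformers config uses mean pooling of token embeddings (pooling_mode_mean_tokens=True, as noted above), applying tanh in the AVG path could well be part of the discrepancy.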
I wanted to catch you up on what I have been working on. I started by adding some fractional logging all over the pipeline. This is still a bit tricky, as the input/output arrays can be large, and printing them to the console to compare is not the most practical way to debug.
I have gone down a different road and started refactoring classes like LayerNorm.java to be more like the torch ones (https://docs.pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html). I will check both in: a Python version you can run by hand, and unit/integration tests that compare the results of each function in the stack.
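As a hand-runnable reference for that comparison (a sketch, shapes chosen arbitrarily): torch.nn.LayerNorm normalizes over the last dimension, so each row gets its own mean and biased variance before the learned weight and bias are applied, and a Java port should match it row by row:

import torch

torch.manual_seed(0)
x = torch.randn(5, 7)

ln = torch.nn.LayerNorm(7)  # normalized_shape = last dim; eps defaults to 1e-5

# Manual equivalent: per-row mean and biased variance (unbiased=False),
# then scale/shift with the layer's weight and bias.
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, unbiased=False, keepdim=True)
manual = (x - mean) / torch.sqrt(var + ln.eps) * ln.weight + ln.bias

print(torch.allclose(ln(x), manual, atol=1e-6))  # every row should match, including row 0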
It is a little tricky to tease out the myriad of libraries and what is expected, but I am progressing and learning as I go. Sorry not to be able to just hit you with an "aha, I fixed it" patch, but I'm relatively new to the ML/Torch side of things, so I am learning as I go and comparing each step. Just letting you know.
https://github.com/edwardcapriolo/deliverance/commit/eb505c7b9c10984dbf733f0564403507635f8718
Thank you brother. Super appreciate your help in debugging and potentially fixing this!
@udaychandra I can use your help here as my understanding of fundamentals is a bit weak.
https://github.com/edwardcapriolo/deliverance/pull/new/layer
I decided to see if I could assert that the Java LayerNorm is close to the PyTorch LayerNorm. What I notice is that for the 5/7 array, the last 4 rows are perfect but the 0th row is off. I don't know if this is the intent of the implementation, as in the drawings of layer norm the left-most slice is blue, but it looks suspicious to me. Maybe have a look and tell me if I am crazy.