
Different Embeddings from sentence-transformers/all-MiniLM-L6-v2 compared to Python

Open · udaychandra opened this issue 1 month ago · 6 comments

First, thank you for the amazing work on Jlama! It's great to have native Java libraries for embeddings and LLMs.

Issue

We're getting different embedding values from Jlama compared to Python's sentence-transformers for the all-MiniLM-L6-v2 model, even though we've verified that tokenization is identical.

Java Code (Jlama)

var modelName = "sentence-transformers/all-MiniLM-L6-v2";
var workingDirectory = System.getProperty("user.home") + "/.jlama/models/";
var downloader = new Downloader(workingDirectory, modelName);
var modelPath = downloader.huggingFaceModel();

var model = ModelSupport.loadEmbeddingModel(modelPath, DType.F32, DType.F32);

String text = "This is a test document about machine learning";
float[] embedding = model.embed(text, Generator.PoolingType.AVG);

System.out.println("First 10 values:");
for (int i = 0; i < 10; i++) {
    System.out.println("  [" + i + "] = " + embedding[i]);
}

Java Output:

Magnitude: 1.0000001
[0] = -0.0009431843
[1] = 0.006532612
[2] = 0.070363656
[3] = 0.0154365115

Python Code (sentence-transformers)

from sentence_transformers import SentenceTransformer

model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
text = "This is a test document about machine learning"
embedding = model.encode(text)

print("First 10 values:")
for i in range(10):
    print(f"  [{i}] = {embedding[i]}")

Python Output:

Magnitude: 1.0
[0] = -0.038466498255729675
[1] = 0.00013165567361284047
[2] = 0.01088548544794321
[3] = 0.040931958705186844

What We've Verified

  1. Tokenization is identical: Both produce the same token IDs: [101, 2023, 2003, 1037, 3231, 6254, 2055, 3698, 4083, 102]
  2. Same pooling strategy: Both use mean/average pooling (PoolingType.AVG in Java, pooling_mode_mean_tokens=True in Python)
  3. Same model source: Both download from HuggingFace sentence-transformers/all-MiniLM-L6-v2
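As a reference point for item 2: as I understand it, the sentence-transformers pipeline for all-MiniLM-L6-v2 is Transformer, then Pooling (mean over tokens, weighted by the attention mask), then Normalize (L2). The sketch below reproduces that post-processing in plain Java so it can be fed hidden states dumped from either side. It assumes the per-token final-layer outputs are available as a `float[][]`; `MeanPoolingSketch` and its methods are hypothetical names, not Jlama APIs.

```java
import java.util.Arrays;

// Hypothetical helper, not a Jlama API: masked mean pooling followed by L2 normalization.
public class MeanPoolingSketch {

    /** Averages the rows of tokenEmbeddings whose attention-mask entry is 1 (padding rows are skipped). */
    public static float[] maskedMeanPool(float[][] tokenEmbeddings, int[] attentionMask) {
        int dim = tokenEmbeddings[0].length;
        float[] sum = new float[dim];
        int count = 0;
        for (int t = 0; t < tokenEmbeddings.length; t++) {
            if (attentionMask[t] == 0) continue;
            for (int d = 0; d < dim; d++) sum[d] += tokenEmbeddings[t][d];
            count++;
        }
        // Avoid dividing by zero if every position is padding.
        float denom = Math.max(count, 1);
        for (int d = 0; d < dim; d++) sum[d] /= denom;
        return sum;
    }

    /** Scales v to unit L2 norm (with a tiny floor to avoid division by zero). */
    public static float[] l2Normalize(float[] v) {
        double sq = 0;
        for (float x : v) sq += (double) x * x;
        double norm = Math.max(Math.sqrt(sq), 1e-12);
        float[] out = new float[v.length];
        for (int i = 0; i < v.length; i++) out[i] = (float) (v[i] / norm);
        return out;
    }

    public static void main(String[] args) {
        float[][] tokens = { {1f, 2f}, {3f, 4f}, {100f, 100f} }; // last row is padding
        int[] mask = {1, 1, 0};
        float[] pooled = maskedMeanPool(tokens, mask);
        System.out.println(Arrays.toString(pooled)); // [2.0, 3.0]
        System.out.println(Arrays.toString(l2Normalize(pooled)));
    }
}
```

For a single unpadded sentence the mask is all ones, so masked mean pooling reduces to a plain average; the mask only matters for padded batches.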

The Problem

The actual embedding values are completely different (not just minor floating-point differences).
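To put a number on "completely different", one option is to dump the full 384-dimensional vectors from both runs and compare them with cosine similarity and the maximum absolute difference. Pure floating-point noise should leave the cosine similarity very close to 1, while a pipeline mismatch usually shows up as a clearly lower value. A minimal sketch; `EmbeddingCompare` is a hypothetical helper, and loading the two vectors (for example from text files) is left out.

```java
// Hypothetical helper for comparing two embedding vectors of equal length.
public class EmbeddingCompare {

    /** Cosine similarity computed in double precision; 1.0 means the same direction. */
    public static double cosine(float[] a, float[] b) {
        if (a.length != b.length) throw new IllegalArgumentException("length mismatch");
        double dot = 0, na = 0, nb = 0;
        for (int i = 0; i < a.length; i++) {
            dot += (double) a[i] * b[i];
            na += (double) a[i] * a[i];
            nb += (double) b[i] * b[i];
        }
        return dot / (Math.sqrt(na) * Math.sqrt(nb));
    }

    /** Largest element-wise absolute difference between the two vectors. */
    public static double maxAbsDiff(float[] a, float[] b) {
        if (a.length != b.length) throw new IllegalArgumentException("length mismatch");
        double m = 0;
        for (int i = 0; i < a.length; i++) m = Math.max(m, Math.abs((double) a[i] - b[i]));
        return m;
    }

    public static void main(String[] args) {
        float[] x = {1f, 0f, 0f};
        float[] y = {0f, 1f, 0f};
        System.out.println(cosine(x, x)); // 1.0
        System.out.println(cosine(x, y)); // 0.0
    }
}
```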

Questions

  1. Is the all-MiniLM-L6-v2 model fully supported/tested with Jlama?
  2. Are we missing any configuration or preprocessing steps?

Any guidance would be greatly appreciated!

udaychandra, Nov 01 '25 23:11

I looked this over quickly and I have a hunch: the KV cache and the batch forwarding might be making the result nondeterministic.

    public float[] embed(String input, PoolingType poolingType) {
        int[] encoded = Arrays.stream(tokenizer.encode(input)).mapToInt(Ints::checkedCast).toArray();

        Preconditions.checkArgument(encoded.length < c.contextLength);
        float[] outputEmbedding = new float[c.embeddingLength];

        try (KvBufferCache.KvBuffer kvmem = kvBufferCache.getEphemeralKvBuffer()) {
            int promptLength = encoded.length;
            float avgp = 1.0f / promptLength;

I'll play with this and see what I can figure out.
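One cheap way to test this hunch is to embed the same text repeatedly and check whether the outputs are bit-identical. The sketch abstracts the embedding call as a `Function<String, float[]>`, so it could wrap something like `text -> model.embed(text, Generator.PoolingType.AVG)` from the original report; `DeterminismCheck` is a hypothetical name. A result of exactly 0.0 over several runs would point away from nondeterminism and toward a systematic difference in the computation.

```java
import java.util.function.Function;

// Hypothetical harness: calls an embedding function several times on the same input.
public class DeterminismCheck {

    /** Returns the largest absolute deviation of runs 2..n from the first run (0.0 means bit-identical). */
    public static double maxDeviation(Function<String, float[]> embed, String text, int runs) {
        float[] first = embed.apply(text);
        double worst = 0;
        for (int r = 1; r < runs; r++) {
            float[] next = embed.apply(text);
            if (next.length != first.length) throw new IllegalStateException("length changed between runs");
            for (int i = 0; i < first.length; i++) {
                worst = Math.max(worst, Math.abs((double) first[i] - next[i]));
            }
        }
        return worst;
    }

    public static void main(String[] args) {
        // Stand-in embedder; replace with a lambda that calls the real model.
        Function<String, float[]> fake = s -> new float[] {s.length(), 1f};
        System.out.println(maxDeviation(fake, "hello", 5)); // 0.0
    }
}
```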

edwardcapriolo, Nov 07 '25 12:11

https://github.com/edwardcapriolo/deliverance/pull/13/files

I hadn't gotten around to adding embedding support to my fork, so I started on it. One thing I noticed:

    try (AbstractTensor r = batchForward(encoded, 0, kvMem)) {
        if (poolingType == PoolingType.MODEL) {
            // ...
            return outputEmbedding;
        }
        for (int i = 0; i < promptLength; i++) {
            AbstractTensor output = r.slice(i);
            // Pooling
            for (int ii = 0; ii < config.embeddingLength; ii++) {
                switch (poolingType) {
                    case AVG:
                        outputEmbedding[ii] += output.get(0, ii) * avgp;
                        break;
                    case MAX:
                        outputEmbedding[ii] = Math.max(outputEmbedding[ii], output.get(0, ii));
                        break;
                    case SUM:
                        outputEmbedding[ii] += output.get(0, ii);
                        break;
                }
            }
        }

The pooling implementation looks fairly simple. I will take some time to deep-dive into the Python code and try to understand what differs between the two.
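For comparing against the Python side, here is a standalone sketch of the same three pooling modes over a `[tokens][hidden]` matrix (`PoolingSketch` and `Mode` are hypothetical names). One detail worth double-checking in the excerpt above: a newly allocated Java `float[]` starts at 0.0, so MAX pooling seeded from that array would clamp any dimension whose values are all negative to 0. The sketch seeds MAX from the first token instead. This does not affect the AVG path used in this issue.

```java
import java.util.Arrays;

// Hypothetical standalone version of AVG/MAX/SUM pooling over token rows.
public class PoolingSketch {
    public enum Mode { AVG, MAX, SUM }

    /** Pools hidden[token][dim] into a single vector of length dim. */
    public static float[] pool(float[][] hidden, Mode mode) {
        int tokens = hidden.length;
        int dim = hidden[0].length;
        // Seed MAX with the first row so negative maxima survive; AVG and SUM start at zero.
        float[] out = mode == Mode.MAX ? hidden[0].clone() : new float[dim];
        int start = mode == Mode.MAX ? 1 : 0;
        for (int t = start; t < tokens; t++) {
            for (int d = 0; d < dim; d++) {
                switch (mode) {
                    case AVG: out[d] += hidden[t][d] / tokens; break;
                    case MAX: out[d] = Math.max(out[d], hidden[t][d]); break;
                    case SUM: out[d] += hidden[t][d]; break;
                }
            }
        }
        return out;
    }

    public static void main(String[] args) {
        float[][] h = { {-3f, 1f}, {-1f, 5f} };
        System.out.println(Arrays.toString(pool(h, Mode.AVG))); // [-2.0, 3.0]
        System.out.println(Arrays.toString(pool(h, Mode.MAX))); // [-1.0, 5.0]
        System.out.println(Arrays.toString(pool(h, Mode.SUM))); // [-4.0, 6.0]
    }
}
```

With a zero-initialized accumulator, MAX on the same input would return `[0.0, 5.0]`.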

edwardcapriolo, Nov 09 '25 12:11

This also seems dubious:

            // BERT seems to use tanh for pooling rather than gelu
            outputEmbedding[i] = ActivationFunction.eval(ActivationFunction.Type.TANH, pooled.get(0, i));
        });
        return outputEmbedding;

We could derive the activation function from the config:

outputEmbedding[i] = config.activationFunction.eval(pooled.get(0, i));

I don't see why it is hardcoded here.
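Before changing that line, it may be worth comparing with the Hugging Face `transformers` BERT code: as far as I know, its `BertPooler` applies a dense layer to the final hidden state of the first ([CLS]) token followed by a fixed `tanh`, independent of `hidden_act`, which selects the feed-forward activation (`gelu` for this model). Note also that sentence-transformers' mean pooling for all-MiniLM-L6-v2 does not use the BERT pooler, so this branch should only matter for `PoolingType.MODEL`. A minimal sketch of that pooler computation with plain arrays; `BertPoolerSketch` is a hypothetical name, and the `[out][in]` weight layout is an assumption matching PyTorch's `nn.Linear` convention.

```java
// Hypothetical sketch of a BERT-style pooler: tanh(W * cls + b).
public class BertPoolerSketch {

    /**
     * cls: final hidden state of token 0, length `in`.
     * w:   weights laid out as [out][in] (assumed, like nn.Linear.weight).
     * b:   bias of length `out`.
     */
    public static float[] pool(float[] cls, float[][] w, float[] b) {
        float[] out = new float[b.length];
        for (int i = 0; i < b.length; i++) {
            double acc = b[i];
            for (int j = 0; j < cls.length; j++) acc += (double) w[i][j] * cls[j];
            out[i] = (float) Math.tanh(acc); // fixed tanh, not the config's hidden_act
        }
        return out;
    }

    public static void main(String[] args) {
        float[] cls = {0.5f, -0.5f};
        float[][] identity = { {1f, 0f}, {0f, 1f} };
        float[] zeroBias = {0f, 0f};
        float[] pooled = pool(cls, identity, zeroBias);
        System.out.println(pooled[0] + " " + pooled[1]);
    }
}
```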

edwardcapriolo, Nov 09 '25 12:11

I wanted to catch you up on what I have been working on. I started by adding additional logging throughout the pipeline. This is still a bit tricky, as the input/output arrays can be large, and printing them to the console to compare them is not a very practical way to debug.
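One alternative to printing whole arrays is to log a compact fingerprint of each intermediate tensor (length, mean, min, max, L2 norm, first few values) at matching points in the Java and Python pipelines, then diff the log lines to find the first step that diverges. A minimal sketch; `TensorSummary` and the label format are hypothetical.

```java
import java.util.Arrays;
import java.util.Locale;

// Hypothetical logging helper: one-line fingerprint of a float array.
public class TensorSummary {

    /** Summarizes v as length, mean, min, max, L2 norm and its first `head` values. */
    public static String summarize(String label, float[] v, int head) {
        double sum = 0, sq = 0;
        float min = Float.POSITIVE_INFINITY, max = Float.NEGATIVE_INFINITY;
        for (float x : v) {
            sum += x;
            sq += (double) x * x;
            min = Math.min(min, x);
            max = Math.max(max, x);
        }
        float[] first = Arrays.copyOf(v, Math.min(head, v.length));
        // Locale.ROOT keeps '.' as the decimal separator so logs diff cleanly across machines.
        return String.format(Locale.ROOT, "%s n=%d mean=%.6f min=%.6f max=%.6f l2=%.6f head=%s",
                label, v.length, sum / v.length, min, max, Math.sqrt(sq), Arrays.toString(first));
    }

    public static void main(String[] args) {
        float[] v = {3f, 4f, 0f, 0f};
        System.out.println(summarize("layer0.attn_out", v, 2));
        // layer0.attn_out n=4 mean=1.750000 min=0.000000 max=4.000000 l2=5.000000 head=[3.0, 4.0]
    }
}
```

On the Python side the same statistics can be printed with NumPy (`x.mean()`, `x.min()`, `x.max()`, `np.linalg.norm(x)`) so the two logs line up.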

I have gone down a different road and started refactoring classes like LayerNorm.java to behave more like their PyTorch counterparts (https://docs.pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html). I will check in both: Python scripts you can run by hand, and unit/integration tests that compare the results of each function in the stack.
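For reference, the PyTorch documentation linked above defines LayerNorm as y = (x - E[x]) / sqrt(Var[x] + eps) * gamma + beta over the last dimension, using the biased variance (divide by N). A minimal Java sketch of that formula for a 2-D input, which could serve as an independent oracle in the unit tests; `LayerNormSketch` is a hypothetical name. Note that `torch.nn.LayerNorm` defaults to eps = 1e-5, whereas BERT-style configs usually specify their own `layer_norm_eps` (commonly 1e-12), so both sides should use the same value.

```java
import java.util.Arrays;

// Hypothetical reference implementation of the LayerNorm formula, row by row.
public class LayerNormSketch {

    /** Normalizes each row of x over its last dimension, then applies gamma and beta. */
    public static float[][] layerNorm(float[][] x, float[] gamma, float[] beta, double eps) {
        float[][] y = new float[x.length][];
        for (int r = 0; r < x.length; r++) {
            float[] row = x[r];
            int n = row.length;
            double mean = 0;
            for (float v : row) mean += v;
            mean /= n;
            // Biased variance: divide by n, not n - 1.
            double variance = 0;
            for (float v : row) variance += (v - mean) * (v - mean);
            variance /= n;
            double inv = 1.0 / Math.sqrt(variance + eps);
            y[r] = new float[n];
            for (int i = 0; i < n; i++) {
                y[r][i] = (float) ((row[i] - mean) * inv * gamma[i] + beta[i]);
            }
        }
        return y;
    }

    public static void main(String[] args) {
        float[][] x = { {1f, 2f, 3f} };
        float[] ones = {1f, 1f, 1f};
        float[] zeros = {0f, 0f, 0f};
        System.out.println(Arrays.toString(layerNorm(x, ones, zeros, 1e-5)[0]));
    }
}
```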

It is a little tricky to tease out the myriad libraries and what each one expects, but I am making progress. Sorry I can't just hit you with an "aha, I fixed it" patch; I'm relatively new to the ML/Torch side of things, so I am learning as I go and comparing each step. Just keeping you posted.

https://github.com/edwardcapriolo/deliverance/commit/eb505c7b9c10984dbf733f0564403507635f8718

edwardcapriolo, Nov 12 '25 12:11

Thank you brother. Super appreciate your help in debugging and potentially fixing this!

udaychandra, Nov 12 '25 20:11

@udaychandra I can use your help here as my understanding of fundamentals is a bit weak.

https://github.com/edwardcapriolo/deliverance/pull/new/layer

I decided to see if I could assert that LayerNorm.java produces results close to PyTorch's LayerNorm. What I notice is that for the 5x7 array the last 4 rows match perfectly, but the 0th row is off. I don't know if this is intended by the implementation (in the diagrams of layer norm, the leftmost slice is blue), but it looks suspicious to me. Maybe have a look and tell me if I am crazy.
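To localize a pattern like this, it can help to print the maximum absolute difference per row instead of eyeballing the whole matrix. If only row 0 exceeds a small tolerance, that suggests something row-specific, such as an index or offset that treats the first row differently, rather than a problem with the normalization formula itself. A minimal sketch; `RowDiff` is a hypothetical helper.

```java
// Hypothetical helper: per-row max absolute difference between two equally shaped matrices.
public class RowDiff {

    public static double[] perRowMaxAbsDiff(float[][] a, float[][] b) {
        if (a.length != b.length) throw new IllegalArgumentException("row count mismatch");
        double[] out = new double[a.length];
        for (int r = 0; r < a.length; r++) {
            if (a[r].length != b[r].length) throw new IllegalArgumentException("column mismatch in row " + r);
            for (int c = 0; c < a[r].length; c++) {
                out[r] = Math.max(out[r], Math.abs((double) a[r][c] - b[r][c]));
            }
        }
        return out;
    }

    public static void main(String[] args) {
        float[][] fromJava = { {0.1f, 0.2f}, {0.3f, 0.4f} };
        float[][] fromTorch = { {0.1f, 0.25f}, {0.3f, 0.4f} };
        double[] d = perRowMaxAbsDiff(fromJava, fromTorch);
        for (int r = 0; r < d.length; r++) System.out.printf("row %d max|diff| = %.3e%n", r, d[r]);
    }
}
```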

edwardcapriolo, Nov 13 '25 19:11