
Add support for asymmetric embedding models

[Open] br3no opened this issue 3 months ago • 21 comments

Description

This PR adds support for asymmetric embedding models such as https://huggingface.co/intfloat/multilingual-e5-small to the neural-search plugin.

It builds on the work done in https://github.com/opensearch-project/ml-commons/issues/1799.

Asymmetric embedding models behave differently when embedding passages and when embedding queries. To that end, the model must "know" at inference time what kind of data it is embedding.
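To make "knowing" the content type concrete: the model card for intfloat/multilingual-e5-small asks callers to start each input with "query: " or "passage: ". The sketch below assumes the model applies its configured query_prefix or passage_prefix by prepending it to the input text. The Prefixes record and the local EmbeddingContentType enum are hypothetical stand-ins for illustration, not the ml-commons or neural-search classes.

```java
import java.util.List;
import java.util.stream.Collectors;

public class PrefixDemo {

    // Hypothetical stand-in mirroring the EmbeddingContentType named in this PR.
    enum EmbeddingContentType { QUERY, PASSAGE }

    // Hypothetical holder for the optional prefixes an asymmetric model is configured with.
    record Prefixes(String queryPrefix, String passagePrefix) {

        // Prepends the prefix that matches the content type; a null prefix leaves the texts unchanged.
        List<String> apply(List<String> texts, EmbeddingContentType type) {
            String prefix = type == EmbeddingContentType.QUERY ? queryPrefix : passagePrefix;
            if (prefix == null) {
                return texts;
            }
            return texts.stream().map(t -> prefix + t).collect(Collectors.toList());
        }
    }

    public static void main(String[] args) {
        // Prefixes as given on the intfloat/multilingual-e5-small model card.
        Prefixes e5 = new Prefixes("query: ", "passage: ");
        // Ingest path (TextEmbeddingProcessor): documents are passages.
        System.out.println(e5.apply(List.of("OpenSearch is a search engine."), EmbeddingContentType.PASSAGE));
        // Search path (NeuralQueryBuilder): user input is a query.
        System.out.println(e5.apply(List.of("what is opensearch"), EmbeddingContentType.QUERY));
    }
}
```

Running it prints `[passage: OpenSearch is a search engine.]` and then `[query: what is opensearch]`. The same text gets a different prefix, and therefore a different embedding, depending on which side of the search it comes from, which is why the ingest and query paths must tag their inputs differently.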

The changes are:

1. src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java

The processor signals that it is embedding passages by passing the new AsymmetricTextEmbeddingParameters with the content type EmbeddingContentType.PASSAGE.

2. src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java

Analogously, the query builder uses EmbeddingContentType.QUERY.

3. src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java

Most of the work was done here. The class has been extended, in a backwards-compatible way, with inference methods that accept MLAlgoParams objects. Passing AsymmetricTextEmbeddingParameters (which implements MLAlgoParams) is mandatory for asymmetric models, whereas symmetric models do not accept them.

The only way to know whether a model is asymmetric or symmetric is to read its model configuration: if the model's configuration contains a passage_prefix and/or a query_prefix, the model is asymmetric; otherwise, it is symmetric.

The src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java class handles this, keeping the complexity in one place and requiring no API change to the neural-search plugin (as proposed in #620). When calling the inference methods, clients (such as the TextEmbeddingProcessor) may pass the AsymmetricTextEmbeddingParameters object without caring whether the model they are using is symmetric or asymmetric. The accessor first reads the model's configuration (by calling the getModel API of the mlClient) and then acts accordingly: it forwards the parameters to asymmetric models and omits them for symmetric ones.

To avoid adding this extra round trip to every inference call, the accessor caches each model's asymmetry information in memory.
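The accessor behavior described in item 3 (detect asymmetry from the model configuration, cache the result, and attach the content type only for asymmetric models) can be sketched as follows. This is a simplified, synchronous illustration under stated assumptions: ModelConfig, ContentType, and AsymmetryAwareAccessor are hypothetical names, the getModel function stands in for the getModel call on the mlClient, and returning Optional.empty() stands in for "do not send AsymmetricTextEmbeddingParameters". The actual MLCommonsClientAccessor works with ml-commons types and asynchronous listeners rather than these stand-ins.

```java
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;

public class AccessorDemo {

    // Hypothetical stand-in for the content type carried by AsymmetricTextEmbeddingParameters.
    enum ContentType { QUERY, PASSAGE }

    // Hypothetical view of a model configuration, reduced to query_prefix / passage_prefix.
    record ModelConfig(String queryPrefix, String passagePrefix) {
        // Asymmetric if at least one prefix is configured.
        boolean isAsymmetric() {
            return queryPrefix != null || passagePrefix != null;
        }
    }

    // Sketch of the accessor's decision logic; getModel stands in for the mlClient round trip.
    static class AsymmetryAwareAccessor {
        private final Function<String, ModelConfig> getModel;
        private final Map<String, Boolean> asymmetryCache = new ConcurrentHashMap<>();

        AsymmetryAwareAccessor(Function<String, ModelConfig> getModel) {
            this.getModel = getModel;
        }

        // Fetches the model config only on the first call per model id; later calls hit the cache.
        boolean isAsymmetric(String modelId) {
            return asymmetryCache.computeIfAbsent(modelId, id -> getModel.apply(id).isAsymmetric());
        }

        // Parameters to attach to the inference request: the content type for asymmetric
        // models, nothing for symmetric models (which do not accept them).
        Optional<ContentType> paramsFor(String modelId, ContentType requested) {
            return isAsymmetric(modelId) ? Optional.of(requested) : Optional.empty();
        }
    }

    public static void main(String[] args) {
        AtomicInteger lookups = new AtomicInteger();
        Map<String, ModelConfig> configs = Map.of(
                "e5-small", new ModelConfig("query: ", "passage: "),
                "symmetric-model", new ModelConfig(null, null));
        AsymmetryAwareAccessor accessor = new AsymmetryAwareAccessor(id -> {
            lookups.incrementAndGet(); // count simulated getModel round trips
            return configs.get(id);
        });

        System.out.println(accessor.paramsFor("e5-small", ContentType.PASSAGE));      // Optional[PASSAGE]
        System.out.println(accessor.paramsFor("e5-small", ContentType.QUERY));        // Optional[QUERY]
        System.out.println(accessor.paramsFor("symmetric-model", ContentType.QUERY)); // Optional.empty
        System.out.println(lookups.get()); // 2: one lookup per distinct model
    }
}
```

Callers always pass a content type; only the accessor decides whether it reaches the model. Using ConcurrentHashMap.computeIfAbsent keeps the cache safe when several inference calls for the same model arrive concurrently.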

Issues Resolved

#620

Check List

  • [x] New functionality includes testing.
    • [x] All tests pass
  • [ ] New functionality has been documented.
    • [x] New functionality has javadoc added
  • [x] Commits are signed as per the DCO using --signoff

By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license. For more information on following Developer Certificate of Origin and signing off your commits, please check here.

br3no, Apr 25 '24 19:04