Incorrect SQL Query Generation for IN and NOT IN Filter Expressions in PGVector Store

Open muthuishere opened this issue 6 months ago • 0 comments

Bug description The PgVectorFilterExpressionConverter class in org.springframework.ai.vectorstore does not correctly convert IN and NOT IN filter expressions for SQL queries against PostgreSQL JSON metadata. The converter emits the filter as a jsonpath expression, and PostgreSQL's jsonpath language has no IN or NOT IN operator, so executing such a query throws a BadSqlGrammarException.

Environment

  • Spring AI version: 1.0.0-M1
  • Java version: 17

Steps to reproduce

  1. Create a SearchRequest using the IN filter expression:

FilterExpressionBuilder b = new FilterExpressionBuilder();
Filter.Expression expression = b.in("country", List.of("BG", "NL")).build();

SearchRequest searchRequest = SearchRequest.query("The World")
	.withFilterExpression(expression)
	.withTopK(5)
	.withSimilarityThresholdAll();
  2. Execute the search, which triggers the query to the vector store.
  3. Observe the resulting BadSqlGrammarException caused by the incorrect SQL generation.

Expected behavior The query should return the records whose country is 'BG' or 'NL' without throwing SQL grammar errors. The IN and NOT IN expressions should be converted to a form that PostgreSQL's jsonpath support accepts.

Minimal Complete Reproducible example (a test in PgVectorStoreIT)

@Test
public void searchWithInFilter() {

 String distanceType="COSINE_DISTANCE";
 contextRunner.withPropertyValues("test.spring.ai.vectorstore.pgvector.distanceType=" + distanceType)
	.run(context -> {

		VectorStore vectorStore = context.getBean(VectorStore.class);

		var bgDocument = new Document("The World is Big and Salvation Lurks Around the Corner",
				Map.of("country", "BG", "year", 2020, "foo bar 1", "bar.foo"));
		var nlDocument = new Document("The World is Big and Salvation Lurks Around the Corner",
				Map.of("country", "NL"));
		var bgDocument2 = new Document("The World is Big and Salvation Lurks Around the Corner",
				Map.of("country", "BG", "year", 2023));

		vectorStore.add(List.of(bgDocument, nlDocument, bgDocument2));

		FilterExpressionBuilder b = new FilterExpressionBuilder();

		Filter.Expression expression = b.in("country", List.of("BG", "NL")).build();

		SearchRequest searchRequest = SearchRequest.query("The World")
			.withFilterExpression(expression)
			.withTopK(5)
			.withSimilarityThresholdAll();

		List<Document> results = vectorStore.similaritySearch(searchRequest);

		assertThat(results).hasSize(3);

		// Remove all documents from the store
		dropTable(context);
	});
}

Error Log:

org.springframework.jdbc.BadSqlGrammarException: PreparedStatementCallback; bad SQL grammar [SELECT *, embedding <=> ? AS distance FROM public.vector_store WHERE embedding <=> ? < ?  AND metadata::jsonb @@ '$.country in ["BG","NL"]'::jsonpath  ORDER BY distance LIMIT ? ]

Proposed Solution Update the PgVectorFilterExpressionConverter class to handle IN and NOT IN expressions separately. An IN expression is converted to a chain of equality comparisons joined by logical OR (||), and a NOT IN expression to the negation of that same chain. Example modifications include:

@Override
protected void doExpression(Expression expression, StringBuilder context) {
    if (expression.type() == Expression.Type.IN) {
        handleIn(expression, context);
    } else if (expression.type() == Expression.Type.NIN) {
        handleNotIn(expression, context);
    } else {
        this.convertOperand(expression.left(), context);
        context.append(getOperationSymbol(expression));
        this.convertOperand(expression.right(), context);
    }
}

private void handleIn(Expression expression, StringBuilder context) {
	context.append(" (");
	Filter.Value right = (Filter.Value) expression.right();

	// Assumes the right operand of an IN expression always wraps a list of values
	List<?> values = (List<?>) right.value();
	for (int i = 0; i < values.size(); i++) {
		// Emit one equality test per value: <key> == <value>
		this.convertOperand(expression.left(), context);
		context.append(" == ");
		this.doSingleValue(values.get(i), context);

		// Join the equality tests with the jsonpath OR operator
		if (i < values.size() - 1) {
			context.append(" || ");
		}
	}
	context.append(") ");
}
// SELECT *, embedding <=> ? AS distance FROM public.vector_store WHERE embedding <=> ? < ?  AND metadata::jsonb @@ ' ($.country == "BG" || $.country == "NL") '::jsonpath  ORDER BY distance LIMIT ? 


private void handleNotIn(Expression expression, StringBuilder context) {
	// Same OR chain as handleIn, wrapped in a jsonpath negation
	context.append(" !(");
	Filter.Value right = (Filter.Value) expression.right();

	// Assumes the right operand of a NOT IN expression always wraps a list of values
	List<?> values = (List<?>) right.value();
	for (int i = 0; i < values.size(); i++) {
		this.convertOperand(expression.left(), context);
		context.append(" == ");
		this.doSingleValue(values.get(i), context);

		if (i < values.size() - 1) {
			context.append(" || ");
		}
	}
	context.append(") ");
}

// SELECT *, embedding <=> ? AS distance FROM public.vector_store WHERE embedding <=> ? < ?  AND metadata::jsonb @@ ' !($.country == "BG" || $.country == "NL") '::jsonpath  ORDER BY distance LIMIT ? 
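
The two handlers above differ only in the leading negation, so the string they produce can be checked in isolation. The following is a minimal standalone sketch of that rendering, not the converter itself: the class JsonPathInSketch and its methods literal, inPredicate and notInPredicate are hypothetical names introduced here for illustration, and it assumes the metadata key is a simple identifier that needs no quoting (unlike the "foo bar 1" key in the test above).

```java
import java.util.List;
import java.util.stream.Collectors;

// Hypothetical helper, not part of Spring AI: renders IN / NOT IN as jsonpath predicates.
class JsonPathInSketch {

    // Renders a value as a jsonpath literal: strings are double-quoted and escaped,
    // anything else (numbers, booleans) falls back to String.valueOf().
    static String literal(Object value) {
        if (value instanceof String) {
            String s = (String) value;
            return "\"" + s.replace("\\", "\\\\").replace("\"", "\\\"") + "\"";
        }
        return String.valueOf(value);
    }

    // IN becomes a parenthesized chain of equality tests joined by ||.
    // Assumes key is a simple identifier such as "country".
    static String inPredicate(String key, List<?> values) {
        if (values.isEmpty()) {
            // An empty list would render "()", which is not a valid predicate.
            throw new IllegalArgumentException("IN requires at least one value");
        }
        return values.stream()
            .map(v -> "$." + key + " == " + literal(v))
            .collect(Collectors.joining(" || ", "(", ")"));
    }

    // NOT IN is the negation of the same chain.
    static String notInPredicate(String key, List<?> values) {
        return "!" + inPredicate(key, values);
    }

    public static void main(String[] args) {
        // prints ($.country == "BG" || $.country == "NL")
        System.out.println(inPredicate("country", List.of("BG", "NL")));
        // prints !($.country == "BG" || $.country == "NL")
        System.out.println(notInPredicate("country", List.of("BG", "NL")));
    }
}
```

Building NOT IN on top of the IN rendering keeps the two cases from drifting apart. The same idea could be applied inside the converter by having handleNotIn delegate to a shared helper, though that refactoring is only a suggestion and not part of the proposed patch.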

I have already completed the fix and added tests to verify it. Please let me know if I can raise a PR.

muthuishere, Aug 07 '24 07:08