openai-cookbook
strings_ranked_by_relatedness example function throws spatial error
Identify the file to be fixed
https://cookbook.openai.com/examples/question_answering_using_embeddings
Describe the problem
The provided strings_ranked_by_relatedness function doesn't work:
```python
# search function
def strings_ranked_by_relatedness(
    query: str,
    df: pd.DataFrame,
    relatedness_fn=lambda x, y: 1 - spatial.distance.cosine(x, y),
    top_n: int = 100
) -> tuple[list[str], list[float]]:
    """Returns a list of strings and relatednesses, sorted from most related to least."""
    query_embedding_response = client.embeddings.create(
        model=EMBEDDING_MODEL,
        input=query,
    )
    query_embedding = query_embedding_response.data[0].embedding
    strings_and_relatednesses = [
        (row["text"], relatedness_fn(query_embedding, row["embedding"]))
        for i, row in df.iterrows()
    ]
    strings_and_relatednesses.sort(key=lambda x: x[1], reverse=True)
    strings, relatednesses = zip(*strings_and_relatednesses)
    return strings[:top_n], relatednesses[:top_n]
```
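For reference, `relatedness_fn` computes cosine similarity: SciPy's `spatial.distance.cosine(u, v)` is the cosine distance 1 − u·v / (‖u‖‖v‖), so `1 - cosine(x, y)` leaves u·v / (‖u‖‖v‖). A small sketch of the same quantity in plain NumPy for two 1-D vectors; `cosine_similarity` is an illustrative helper, not part of the cookbook.

```python
import numpy as np


def cosine_similarity(x, y):
    """Illustrative helper: same value as 1 - scipy.spatial.distance.cosine(x, y) for 1-D inputs."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    # Dot product divided by the product of the vector lengths.
    return float(np.dot(x, y) / (np.linalg.norm(x) * np.linalg.norm(y)))


print(cosine_similarity([1.0, 0.0], [1.0, 0.0]))  # 1.0: same direction
print(cosine_similarity([1.0, 0.0], [0.0, 1.0]))  # 0.0: orthogonal
print(cosine_similarity([1.0, 0.0], [1.0, 1.0]))  # about 0.7071
```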
It consistently throws `ValueError: Input vector should be 1-D.`:

```
    relatedness_fn=lambda x, y: 1 - spatial.distance.cosine(x, y),
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/scipy/spatial/distance.py", line 694, in cosine
    return correlation(u, v, w=w, centered=False)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/scipy/spatial/distance.py", line 626, in correlation
    v = _validate_vector(v)
        ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/scipy/spatial/distance.py", line 302, in _validate_vector
    raise ValueError("Input vector should be 1-D.")
ValueError: Input vector should be 1-D.
```
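A minimal sketch of what the check in the traceback appears to reject, assuming the row embeddings reach `relatedness_fn` as text or as a nested list rather than a flat list of floats. In the SciPy version shown, `_validate_vector` converts its input with `np.asarray` and raises this error unless the result is 1-D. The sample values are hypothetical.

```python
import ast

import numpy as np

# Hypothetical value as it might look after reading a CSV: the list is stored as text.
raw_embedding = "[0.1, 0.2, 0.3]"

# NumPy wraps the whole string as a single scalar, so the array is 0-D, not 1-D.
as_array = np.asarray(raw_embedding)
print(as_array.ndim)  # 0

# Parsing the text back into a list of floats gives a proper 1-D vector.
parsed = np.asarray(ast.literal_eval(raw_embedding), dtype=float)
print(parsed.ndim, parsed.shape)  # 1 (3,)

# A nested list such as [[...]] fails the same check because it is 2-D.
nested = np.asarray([[0.1, 0.2, 0.3]])
print(nested.ndim)  # 2
```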
Describe a solution
GPT sayeth:
The issue appears to be that the row embeddings are not of the expected shape. The traceback shows the 1-D check failing on `v`, the row embedding passed as the second argument, so `row["embedding"]` is not being interpreted as a flat vector like the query embedding (shape (1536,)). This suggests that the embeddings may be nested within another structure or not correctly formatted (for example, still stored as strings).
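One common way to end up with string embeddings, offered as an assumption rather than a confirmed diagnosis: CSV has no list type, so embeddings saved to a CSV file come back from `pd.read_csv` as strings. A sketch of converting them back with `ast.literal_eval` before searching; the two-row CSV is made up for illustration.

```python
import ast
import io

import pandas as pd

# Hypothetical two-row CSV standing in for a file of pre-computed embeddings;
# each embedding was written out as the text of a Python list.
csv_text = '''text,embedding
"first passage","[0.1, 0.2, 0.3]"
"second passage","[0.3, 0.2, 0.1]"
'''

df = pd.read_csv(io.StringIO(csv_text))
print(type(df["embedding"].iloc[0]))  # <class 'str'>

# Convert each string back into a list of floats before computing distances.
df["embedding"] = df["embedding"].apply(ast.literal_eval)
print(type(df["embedding"].iloc[0]))  # <class 'list'>
```

After this conversion, every `row["embedding"]` is a flat list, which `spatial.distance.cosine` accepts as a 1-D vector.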
To address this, we should ensure that each embedding is correctly extracted and reshaped. Here's the updated code to handle this:
```python
import pandas as pd
import numpy as np
from scipy import spatial

def strings_ranked_by_relatedness(
    query: str,
    df: pd.DataFrame,
    relatedness_fn=lambda x, y: 1 - spatial.distance.cosine(x, y),
    top_n: int = 100
) -> tuple[list[str], list[float]]:
    """Returns a list of strings and relatednesses, sorted from most related to least."""
    query_embedding_response = client.embeddings.create(
        model=EMBEDDING_MODEL,
        input=query,
    )
    query_embedding = np.array(query_embedding_response.data[0].embedding)
    query_embedding = query_embedding.flatten()  # Ensure the query embedding is 1-D

    def process_embedding(embedding):
        embedding = np.array(embedding)
        if embedding.size == 0 or embedding.ndim != 1 or embedding.shape[0] != query_embedding.shape[0]:  # Check for invalid embeddings
            return None
        return embedding

    print(f"Query embedding shape: {query_embedding.shape}")

    strings_and_relatednesses = []
    for i, row in df.iterrows():
        row_embedding = process_embedding(row["embedding"])
        if row_embedding is None:  # Skip invalid embeddings
            print(f"Skipping row {i} due to invalid embedding shape: {row['embedding']}")
            continue
        print(f"Row {i} embedding shape: {row_embedding.shape}")
        relatedness = relatedness_fn(query_embedding, row_embedding)
        strings_and_relatednesses.append((row["text"], relatedness))

    strings_and_relatednesses.sort(key=lambda x: x[1], reverse=True)

    if strings_and_relatednesses:
        strings, relatednesses = zip(*strings_and_relatednesses)
        return strings[:top_n], relatednesses[:top_n]
    else:
        return [], []

# Assuming `client` and `EMBEDDING_MODEL` are defined elsewhere in your code
```
In this updated code:

- The `process_embedding` function now checks that each embedding is not only non-empty but also 1-D and has the same length as the query embedding.
- If an embedding fails these checks, it is skipped, and a message is printed indicating why it was skipped.

This should handle cases where the embeddings have unexpected shapes and ensure that only valid embeddings are processed.
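Note that if every embedding is a string, `process_embedding` rejects every row (a string becomes a 0-D array), so the function returns empty lists instead of surfacing the cause. As an alternative, and a different technique from the per-row SciPy call above rather than the cookbook's code, the sketch below stacks all row embeddings into one 2-D array, validates the shapes once, and computes cosine similarity for every row with NumPy. It takes a precomputed query embedding so it runs without an API call; `rank_by_cosine` and its parameters are hypothetical.

```python
import numpy as np


def rank_by_cosine(query_embedding, embeddings, texts, top_n=100):
    """Hypothetical helper: rank texts by cosine similarity to a query embedding.

    Assumes no all-zero vectors (those would produce NaN similarities).
    """
    q = np.asarray(query_embedding, dtype=float)
    # Shape (n_rows, dim). dtype=float raises ValueError for string
    # embeddings or rows of unequal length instead of silently skipping them.
    m = np.asarray(embeddings, dtype=float)
    if q.ndim != 1 or m.ndim != 2 or m.shape[1] != q.shape[0]:
        raise ValueError(f"shape mismatch: query {q.shape}, embeddings {m.shape}")
    # Cosine similarity for all rows at once: dot products over norm products.
    sims = m @ q / (np.linalg.norm(m, axis=1) * np.linalg.norm(q))
    # Indices of the most similar rows first.
    order = np.argsort(-sims)[:top_n]
    return [texts[i] for i in order], sims[order].tolist()


ranked, scores = rank_by_cosine(
    query_embedding=[1.0, 0.0],
    embeddings=[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
    texts=["a", "b", "c"],
)
print(ranked)  # ['a', 'c', 'b']
```

With a DataFrame, the inputs would be `df["embedding"].tolist()` and `df["text"].tolist()`. Failing loudly on a malformed embedding column, rather than skipping rows, makes a string-typed column visible immediately.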
This issue is stale because it has been open 60 days with no activity. Remove stale label or comment or this will be closed in 10 days.
This issue was closed because it has been stalled for 10 days with no activity.