make MemoryEmbeddingRetriever act in non-batch mode (#5809)

Stefano Fiorucci 2023-09-14 15:37:20 +02:00 committed by GitHub
parent 1a212420b7
commit 1c69070db6
2 changed files with 16 additions and 29 deletions


@@ -173,7 +173,7 @@ class MemoryEmbeddingRetriever:
     @component.output_types(documents=List[List[Document]])
     def run(
         self,
-        queries_embeddings: List[List[float]],
+        query_embedding: List[float],
         filters: Optional[Dict[str, Any]] = None,
         top_k: Optional[int] = None,
         scale_score: Optional[bool] = None,
@@ -182,7 +182,7 @@ class MemoryEmbeddingRetriever:
         """
         Run the MemoryEmbeddingRetriever on the given input data.

-        :param queries_embeddings: Embeddings of the queries.
+        :param query_embedding: Embedding of the query.
         :param filters: A dictionary with filters to narrow down the search space.
         :param top_k: The maximum number of documents to return.
         :param scale_score: Whether to scale the scores of the retrieved documents or not.
@@ -200,15 +200,12 @@ class MemoryEmbeddingRetriever:
         if return_embedding is None:
             return_embedding = self.return_embedding

-        docs = []
-        for query_embedding in queries_embeddings:
-            docs.append(
-                self.document_store.embedding_retrieval(
-                    query_embedding=query_embedding,
-                    filters=filters,
-                    top_k=top_k,
-                    scale_score=scale_score,
-                    return_embedding=return_embedding,
-                )
-            )
+        docs = self.document_store.embedding_retrieval(
+            query_embedding=query_embedding,
+            filters=filters,
+            top_k=top_k,
+            scale_score=scale_score,
+            return_embedding=return_embedding,
+        )
+
         return {"documents": docs}
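
For reference, a minimal usage sketch of the retriever after this change (not part of the commit): the document store ds and the example embeddings are assumptions, mirroring the tests below. Callers that previously passed several embeddings in a single call now loop over them:

# Sketch only: single-query retrieval with the new non-batch run() signature.
retriever = MemoryEmbeddingRetriever(ds, top_k=3)
result = retriever.run(query_embedding=[0.1, 0.1, 0.1, 0.1], return_embedding=True)
docs = result["documents"]  # a flat list of Documents for this single query

# Multiple queries now require an explicit loop on the caller side.
all_docs = [
    retriever.run(query_embedding=emb)["documents"]
    for emb in [[0.2, 0.4, 0.6, 0.8], [0.1, 0.1, 0.1, 0.1]]
]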


@@ -190,14 +190,11 @@ class TestMemoryRetrievers:
         ds.write_documents(docs)

         retriever = MemoryEmbeddingRetriever(ds, top_k=top_k)
-        result = retriever.run(queries_embeddings=[[0.2, 0.4, 0.6, 0.8], [0.1, 0.1, 0.1, 0.1]], return_embedding=True)
+        result = retriever.run(query_embedding=[0.1, 0.1, 0.1, 0.1], return_embedding=True)

         assert "documents" in result
-        assert len(result["documents"]) == 2
-        assert len(result["documents"][0]) == top_k
-        assert len(result["documents"][1]) == top_k
-        assert result["documents"][0][0].embedding == [0.1, 0.2, 0.3, 0.4]
-        assert result["documents"][1][0].embedding == [1.0, 1.0, 1.0, 1.0]
+        assert len(result["documents"]) == top_k
+        assert result["documents"][0].embedding == [1.0, 1.0, 1.0, 1.0]

     @pytest.mark.parametrize("retriever_cls", [MemoryBM25Retriever, MemoryEmbeddingRetriever])
     @pytest.mark.unit
@@ -244,22 +241,15 @@ class TestMemoryRetrievers:
         pipeline = Pipeline()
         pipeline.add_component("retriever", retriever)
         result: Dict[str, Any] = pipeline.run(
-            data={
-                "retriever": {
-                    "queries_embeddings": [[0.2, 0.4, 0.6, 0.8], [0.1, 0.1, 0.1, 0.1]],
-                    "return_embedding": True,
-                }
-            }
+            data={"retriever": {"query_embedding": [0.1, 0.1, 0.1, 0.1], "return_embedding": True}}
         )

         assert result
         assert "retriever" in result
         results_docs = result["retriever"]["documents"]
         assert results_docs
-        assert len(results_docs[0]) == top_k
-        assert len(results_docs[1]) == top_k
-        assert results_docs[0][0].embedding == [0.1, 0.2, 0.3, 0.4]
-        assert results_docs[1][0].embedding == [1.0, 1.0, 1.0, 1.0]
+        assert len(results_docs) == top_k
+        assert results_docs[0].embedding == [1.0, 1.0, 1.0, 1.0]

     @pytest.mark.integration
     @pytest.mark.parametrize(