make MemoryEmbeddingRetriever act in non-batch mode (#5809)

This commit is contained in:
Stefano Fiorucci 2023-09-14 15:37:20 +02:00 committed by GitHub
parent 1a212420b7
commit 1c69070db6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 29 deletions

View File

@@ -173,7 +173,7 @@ class MemoryEmbeddingRetriever:
     @component.output_types(documents=List[List[Document]])
     def run(
         self,
-        queries_embeddings: List[List[float]],
+        query_embedding: List[float],
         filters: Optional[Dict[str, Any]] = None,
         top_k: Optional[int] = None,
         scale_score: Optional[bool] = None,
@@ -182,7 +182,7 @@ class MemoryEmbeddingRetriever:
         """
         Run the MemoryEmbeddingRetriever on the given input data.
-        :param queries_embeddings: Embeddings of the queries.
+        :param query_embedding: Embedding of the query.
         :param filters: A dictionary with filters to narrow down the search space.
         :param top_k: The maximum number of documents to return.
         :param scale_score: Whether to scale the scores of the retrieved documents or not.
@@ -200,15 +200,12 @@ class MemoryEmbeddingRetriever:
         if return_embedding is None:
             return_embedding = self.return_embedding
-        docs = []
-        for query_embedding in queries_embeddings:
-            docs.append(
-                self.document_store.embedding_retrieval(
-                    query_embedding=query_embedding,
-                    filters=filters,
-                    top_k=top_k,
-                    scale_score=scale_score,
-                    return_embedding=return_embedding,
-                )
-            )
+        docs = self.document_store.embedding_retrieval(
+            query_embedding=query_embedding,
+            filters=filters,
+            top_k=top_k,
+            scale_score=scale_score,
+            return_embedding=return_embedding,
+        )
         return {"documents": docs}

View File

@ -190,14 +190,11 @@ class TestMemoryRetrievers:
ds.write_documents(docs) ds.write_documents(docs)
retriever = MemoryEmbeddingRetriever(ds, top_k=top_k) retriever = MemoryEmbeddingRetriever(ds, top_k=top_k)
result = retriever.run(queries_embeddings=[[0.2, 0.4, 0.6, 0.8], [0.1, 0.1, 0.1, 0.1]], return_embedding=True) result = retriever.run(query_embedding=[0.1, 0.1, 0.1, 0.1], return_embedding=True)
assert "documents" in result assert "documents" in result
assert len(result["documents"]) == 2 assert len(result["documents"]) == top_k
assert len(result["documents"][0]) == top_k assert result["documents"][0].embedding == [1.0, 1.0, 1.0, 1.0]
assert len(result["documents"][1]) == top_k
assert result["documents"][0][0].embedding == [0.1, 0.2, 0.3, 0.4]
assert result["documents"][1][0].embedding == [1.0, 1.0, 1.0, 1.0]
@pytest.mark.parametrize("retriever_cls", [MemoryBM25Retriever, MemoryEmbeddingRetriever]) @pytest.mark.parametrize("retriever_cls", [MemoryBM25Retriever, MemoryEmbeddingRetriever])
@pytest.mark.unit @pytest.mark.unit
@@ -244,22 +244,15 @@ class TestMemoryRetrievers:
         pipeline = Pipeline()
         pipeline.add_component("retriever", retriever)
         result: Dict[str, Any] = pipeline.run(
-            data={
-                "retriever": {
-                    "queries_embeddings": [[0.2, 0.4, 0.6, 0.8], [0.1, 0.1, 0.1, 0.1]],
-                    "return_embedding": True,
-                }
-            }
+            data={"retriever": {"query_embedding": [0.1, 0.1, 0.1, 0.1], "return_embedding": True}}
         )
         assert result
         assert "retriever" in result
         results_docs = result["retriever"]["documents"]
         assert results_docs
-        assert len(results_docs[0]) == top_k
-        assert len(results_docs[1]) == top_k
-        assert results_docs[0][0].embedding == [0.1, 0.2, 0.3, 0.4]
-        assert results_docs[1][0].embedding == [1.0, 1.0, 1.0, 1.0]
+        assert len(results_docs) == top_k
+        assert results_docs[0].embedding == [1.0, 1.0, 1.0, 1.0]

     @pytest.mark.integration
     @pytest.mark.parametrize(