mirror of https://github.com/deepset-ai/haystack.git, synced 2025-08-31 20:03:38 +00:00
make MemoryEmbeddingRetriever act in non-batch mode (#5809)
commit 1c69070db6
parent 1a212420b7
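
In short: MemoryEmbeddingRetriever.run() now takes a single query_embedding (List[float]) instead of queries_embeddings (List[List[float]]), and the updated tests below expect a flat List[Document] in the output instead of one list per query. A minimal sketch of the new call, reusing the values from the updated unit test; the import paths and the Document(text=..., embedding=...) constructor are assumptions about the preview package layout at this commit, not something shown in the diff:

# Sketch of the non-batch call introduced by this commit.
# Assumed preview-era imports; adjust to the package layout of your Haystack version.
from haystack.preview import Document
from haystack.preview.document_stores import MemoryDocumentStore
from haystack.preview.components.retrievers import MemoryEmbeddingRetriever

ds = MemoryDocumentStore()
ds.write_documents(
    [
        Document(text="first doc", embedding=[0.1, 0.2, 0.3, 0.4]),   # field names assumed
        Document(text="second doc", embedding=[1.0, 1.0, 1.0, 1.0]),
    ]
)

retriever = MemoryEmbeddingRetriever(ds, top_k=1)

# Old: run(queries_embeddings=[[...], [...]]) returned {"documents": List[List[Document]]}.
# New: one embedding in, one flat list of Documents out.
result = retriever.run(query_embedding=[0.1, 0.1, 0.1, 0.1], return_embedding=True)
docs = result["documents"]  # List[Document], at most top_k entries
print(docs[0].embedding)    # expected [1.0, 1.0, 1.0, 1.0], matching the updated test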
@@ -173,7 +173,7 @@ class MemoryEmbeddingRetriever:
     @component.output_types(documents=List[List[Document]])
     def run(
         self,
-        queries_embeddings: List[List[float]],
+        query_embedding: List[float],
         filters: Optional[Dict[str, Any]] = None,
         top_k: Optional[int] = None,
         scale_score: Optional[bool] = None,
@@ -182,7 +182,7 @@ class MemoryEmbeddingRetriever:
         """
         Run the MemoryEmbeddingRetriever on the given input data.

-        :param queries_embeddings: Embeddings of the queries.
+        :param query_embedding: Embedding of the query.
         :param filters: A dictionary with filters to narrow down the search space.
         :param top_k: The maximum number of documents to return.
         :param scale_score: Whether to scale the scores of the retrieved documents or not.
@@ -200,15 +200,12 @@ class MemoryEmbeddingRetriever:
         if return_embedding is None:
             return_embedding = self.return_embedding

-        docs = []
-        for query_embedding in queries_embeddings:
-            docs.append(
-                self.document_store.embedding_retrieval(
-                    query_embedding=query_embedding,
-                    filters=filters,
-                    top_k=top_k,
-                    scale_score=scale_score,
-                    return_embedding=return_embedding,
-                )
-            )
+        docs = self.document_store.embedding_retrieval(
+            query_embedding=query_embedding,
+            filters=filters,
+            top_k=top_k,
+            scale_score=scale_score,
+            return_embedding=return_embedding,
+        )

         return {"documents": docs}
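
The run() body now makes a single document_store.embedding_retrieval() call instead of looping over a batch of embeddings. Callers that still have several query embeddings can loop on their own side. A sketch of that compatibility pattern; run_batch and the stub retriever below are hypothetical helpers written so the snippet runs standalone, they are not part of Haystack:

# Hypothetical batch wrapper around the new single-query interface.
from typing import Any, Dict, List


class _StubRetriever:
    """Stand-in exposing the same run() signature as the new MemoryEmbeddingRetriever."""

    def run(self, query_embedding: List[float], **kwargs: Any) -> Dict[str, Any]:
        # Fake retrieval: echo the query embedding back as the single "document".
        return {"documents": [{"embedding": query_embedding}]}


def run_batch(retriever: Any, queries_embeddings: List[List[float]], **kwargs: Any) -> List[List[Any]]:
    """Call run() once per embedding, restoring the old one-list-per-query shape."""
    return [retriever.run(query_embedding=emb, **kwargs)["documents"] for emb in queries_embeddings]


batched = run_batch(_StubRetriever(), [[0.2, 0.4, 0.6, 0.8], [0.1, 0.1, 0.1, 0.1]])
assert len(batched) == 2  # one result list per query, as the removed loop used to return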
@@ -190,14 +190,11 @@ class TestMemoryRetrievers:
         ds.write_documents(docs)

         retriever = MemoryEmbeddingRetriever(ds, top_k=top_k)
-        result = retriever.run(queries_embeddings=[[0.2, 0.4, 0.6, 0.8], [0.1, 0.1, 0.1, 0.1]], return_embedding=True)
+        result = retriever.run(query_embedding=[0.1, 0.1, 0.1, 0.1], return_embedding=True)

         assert "documents" in result
-        assert len(result["documents"]) == 2
-        assert len(result["documents"][0]) == top_k
-        assert len(result["documents"][1]) == top_k
-        assert result["documents"][0][0].embedding == [0.1, 0.2, 0.3, 0.4]
-        assert result["documents"][1][0].embedding == [1.0, 1.0, 1.0, 1.0]
+        assert len(result["documents"]) == top_k
+        assert result["documents"][0].embedding == [1.0, 1.0, 1.0, 1.0]

     @pytest.mark.parametrize("retriever_cls", [MemoryBM25Retriever, MemoryEmbeddingRetriever])
     @pytest.mark.unit
@@ -244,22 +241,15 @@ class TestMemoryRetrievers:
         pipeline = Pipeline()
         pipeline.add_component("retriever", retriever)
         result: Dict[str, Any] = pipeline.run(
-            data={
-                "retriever": {
-                    "queries_embeddings": [[0.2, 0.4, 0.6, 0.8], [0.1, 0.1, 0.1, 0.1]],
-                    "return_embedding": True,
-                }
-            }
+            data={"retriever": {"query_embedding": [0.1, 0.1, 0.1, 0.1], "return_embedding": True}}
        )

         assert result
         assert "retriever" in result
         results_docs = result["retriever"]["documents"]
         assert results_docs
-        assert len(results_docs[0]) == top_k
-        assert len(results_docs[1]) == top_k
-        assert results_docs[0][0].embedding == [0.1, 0.2, 0.3, 0.4]
-        assert results_docs[1][0].embedding == [1.0, 1.0, 1.0, 1.0]
+        assert len(results_docs) == top_k
+        assert results_docs[0].embedding == [1.0, 1.0, 1.0, 1.0]

     @pytest.mark.integration
     @pytest.mark.parametrize(
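
The pipeline test changes in the same way: the retriever's input key becomes query_embedding and carries one vector, and the returned documents form a flat list. A minimal pipeline sketch with that payload; as above, the imports and Document fields are assumptions about the preview layout at this commit, while the run() payload itself mirrors the updated test:

# Assumed preview-era imports; the data payload matches the updated pipeline test.
from haystack.preview import Document, Pipeline
from haystack.preview.document_stores import MemoryDocumentStore
from haystack.preview.components.retrievers import MemoryEmbeddingRetriever

ds = MemoryDocumentStore()
ds.write_documents([Document(text="doc", embedding=[1.0, 1.0, 1.0, 1.0])])

pipeline = Pipeline()
pipeline.add_component("retriever", MemoryEmbeddingRetriever(ds, top_k=1))

result = pipeline.run(data={"retriever": {"query_embedding": [0.1, 0.1, 0.1, 0.1], "return_embedding": True}})
print(result["retriever"]["documents"])  # flat List[Document], not List[List[Document]]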