mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-09-04 13:53:16 +00:00
make MemoryEmbeddingRetriever act in non-batch mode (#5809)
This commit is contained in:
parent
1a212420b7
commit
1c69070db6
@ -173,7 +173,7 @@ class MemoryEmbeddingRetriever:
|
|||||||
@component.output_types(documents=List[List[Document]])
|
@component.output_types(documents=List[List[Document]])
|
||||||
def run(
|
def run(
|
||||||
self,
|
self,
|
||||||
queries_embeddings: List[List[float]],
|
query_embedding: List[float],
|
||||||
filters: Optional[Dict[str, Any]] = None,
|
filters: Optional[Dict[str, Any]] = None,
|
||||||
top_k: Optional[int] = None,
|
top_k: Optional[int] = None,
|
||||||
scale_score: Optional[bool] = None,
|
scale_score: Optional[bool] = None,
|
||||||
@ -182,7 +182,7 @@ class MemoryEmbeddingRetriever:
|
|||||||
"""
|
"""
|
||||||
Run the MemoryEmbeddingRetriever on the given input data.
|
Run the MemoryEmbeddingRetriever on the given input data.
|
||||||
|
|
||||||
:param queries_embeddings: Embeddings of the queries.
|
:param query_embedding: Embedding of the query.
|
||||||
:param filters: A dictionary with filters to narrow down the search space.
|
:param filters: A dictionary with filters to narrow down the search space.
|
||||||
:param top_k: The maximum number of documents to return.
|
:param top_k: The maximum number of documents to return.
|
||||||
:param scale_score: Whether to scale the scores of the retrieved documents or not.
|
:param scale_score: Whether to scale the scores of the retrieved documents or not.
|
||||||
@ -200,15 +200,12 @@ class MemoryEmbeddingRetriever:
|
|||||||
if return_embedding is None:
|
if return_embedding is None:
|
||||||
return_embedding = self.return_embedding
|
return_embedding = self.return_embedding
|
||||||
|
|
||||||
docs = []
|
docs = self.document_store.embedding_retrieval(
|
||||||
for query_embedding in queries_embeddings:
|
query_embedding=query_embedding,
|
||||||
docs.append(
|
filters=filters,
|
||||||
self.document_store.embedding_retrieval(
|
top_k=top_k,
|
||||||
query_embedding=query_embedding,
|
scale_score=scale_score,
|
||||||
filters=filters,
|
return_embedding=return_embedding,
|
||||||
top_k=top_k,
|
)
|
||||||
scale_score=scale_score,
|
|
||||||
return_embedding=return_embedding,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return {"documents": docs}
|
return {"documents": docs}
|
||||||
|
@ -190,14 +190,11 @@ class TestMemoryRetrievers:
|
|||||||
ds.write_documents(docs)
|
ds.write_documents(docs)
|
||||||
|
|
||||||
retriever = MemoryEmbeddingRetriever(ds, top_k=top_k)
|
retriever = MemoryEmbeddingRetriever(ds, top_k=top_k)
|
||||||
result = retriever.run(queries_embeddings=[[0.2, 0.4, 0.6, 0.8], [0.1, 0.1, 0.1, 0.1]], return_embedding=True)
|
result = retriever.run(query_embedding=[0.1, 0.1, 0.1, 0.1], return_embedding=True)
|
||||||
|
|
||||||
assert "documents" in result
|
assert "documents" in result
|
||||||
assert len(result["documents"]) == 2
|
assert len(result["documents"]) == top_k
|
||||||
assert len(result["documents"][0]) == top_k
|
assert result["documents"][0].embedding == [1.0, 1.0, 1.0, 1.0]
|
||||||
assert len(result["documents"][1]) == top_k
|
|
||||||
assert result["documents"][0][0].embedding == [0.1, 0.2, 0.3, 0.4]
|
|
||||||
assert result["documents"][1][0].embedding == [1.0, 1.0, 1.0, 1.0]
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("retriever_cls", [MemoryBM25Retriever, MemoryEmbeddingRetriever])
|
@pytest.mark.parametrize("retriever_cls", [MemoryBM25Retriever, MemoryEmbeddingRetriever])
|
||||||
@pytest.mark.unit
|
@pytest.mark.unit
|
||||||
@ -244,22 +241,15 @@ class TestMemoryRetrievers:
|
|||||||
pipeline = Pipeline()
|
pipeline = Pipeline()
|
||||||
pipeline.add_component("retriever", retriever)
|
pipeline.add_component("retriever", retriever)
|
||||||
result: Dict[str, Any] = pipeline.run(
|
result: Dict[str, Any] = pipeline.run(
|
||||||
data={
|
data={"retriever": {"query_embedding": [0.1, 0.1, 0.1, 0.1], "return_embedding": True}}
|
||||||
"retriever": {
|
|
||||||
"queries_embeddings": [[0.2, 0.4, 0.6, 0.8], [0.1, 0.1, 0.1, 0.1]],
|
|
||||||
"return_embedding": True,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result
|
assert result
|
||||||
assert "retriever" in result
|
assert "retriever" in result
|
||||||
results_docs = result["retriever"]["documents"]
|
results_docs = result["retriever"]["documents"]
|
||||||
assert results_docs
|
assert results_docs
|
||||||
assert len(results_docs[0]) == top_k
|
assert len(results_docs) == top_k
|
||||||
assert len(results_docs[1]) == top_k
|
assert results_docs[0].embedding == [1.0, 1.0, 1.0, 1.0]
|
||||||
assert results_docs[0][0].embedding == [0.1, 0.2, 0.3, 0.4]
|
|
||||||
assert results_docs[1][0].embedding == [1.0, 1.0, 1.0, 1.0]
|
|
||||||
|
|
||||||
@pytest.mark.integration
|
@pytest.mark.integration
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user