diff --git a/haystack/preview/components/retrievers/memory.py b/haystack/preview/components/retrievers/memory.py index 5600b43e7..bc82454ca 100644 --- a/haystack/preview/components/retrievers/memory.py +++ b/haystack/preview/components/retrievers/memory.py @@ -173,7 +173,7 @@ class MemoryEmbeddingRetriever: @component.output_types(documents=List[List[Document]]) def run( self, - queries_embeddings: List[List[float]], + query_embedding: List[float], filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None, scale_score: Optional[bool] = None, @@ -182,7 +182,7 @@ class MemoryEmbeddingRetriever: """ Run the MemoryEmbeddingRetriever on the given input data. - :param queries_embeddings: Embeddings of the queries. + :param query_embedding: Embedding of the query. :param filters: A dictionary with filters to narrow down the search space. :param top_k: The maximum number of documents to return. :param scale_score: Whether to scale the scores of the retrieved documents or not. @@ -200,15 +200,12 @@ class MemoryEmbeddingRetriever: if return_embedding is None: return_embedding = self.return_embedding - docs = [] - for query_embedding in queries_embeddings: - docs.append( - self.document_store.embedding_retrieval( - query_embedding=query_embedding, - filters=filters, - top_k=top_k, - scale_score=scale_score, - return_embedding=return_embedding, - ) - ) + docs = self.document_store.embedding_retrieval( + query_embedding=query_embedding, + filters=filters, + top_k=top_k, + scale_score=scale_score, + return_embedding=return_embedding, + ) + return {"documents": docs} diff --git a/test/preview/components/retrievers/test_memory_retriever.py b/test/preview/components/retrievers/test_memory_retriever.py index 9e8361475..19ad0843b 100644 --- a/test/preview/components/retrievers/test_memory_retriever.py +++ b/test/preview/components/retrievers/test_memory_retriever.py @@ -190,14 +190,11 @@ class TestMemoryRetrievers: ds.write_documents(docs) retriever = MemoryEmbeddingRetriever(ds, top_k=top_k) - result = retriever.run(queries_embeddings=[[0.2, 0.4, 0.6, 0.8], [0.1, 0.1, 0.1, 0.1]], return_embedding=True) + result = retriever.run(query_embedding=[0.1, 0.1, 0.1, 0.1], return_embedding=True) assert "documents" in result - assert len(result["documents"]) == 2 - assert len(result["documents"][0]) == top_k - assert len(result["documents"][1]) == top_k - assert result["documents"][0][0].embedding == [0.1, 0.2, 0.3, 0.4] - assert result["documents"][1][0].embedding == [1.0, 1.0, 1.0, 1.0] + assert len(result["documents"]) == top_k + assert result["documents"][0].embedding == [1.0, 1.0, 1.0, 1.0] @pytest.mark.parametrize("retriever_cls", [MemoryBM25Retriever, MemoryEmbeddingRetriever]) @pytest.mark.unit @@ -244,22 +241,15 @@ class TestMemoryRetrievers: pipeline = Pipeline() pipeline.add_component("retriever", retriever) result: Dict[str, Any] = pipeline.run( - data={ - "retriever": { - "queries_embeddings": [[0.2, 0.4, 0.6, 0.8], [0.1, 0.1, 0.1, 0.1]], - "return_embedding": True, - } - } + data={"retriever": {"query_embedding": [0.1, 0.1, 0.1, 0.1], "return_embedding": True}} ) assert result assert "retriever" in result results_docs = result["retriever"]["documents"] assert results_docs - assert len(results_docs[0]) == top_k - assert len(results_docs[1]) == top_k - assert results_docs[0][0].embedding == [0.1, 0.2, 0.3, 0.4] - assert results_docs[1][0].embedding == [1.0, 1.0, 1.0, 1.0] + assert len(results_docs) == top_k + assert results_docs[0].embedding == [1.0, 1.0, 1.0, 1.0] @pytest.mark.integration @pytest.mark.parametrize(