fix: hybrid pipeline e2e test (#6740)

* fix hybrid pipeline e2e test

* warmup

* write to the right docstore
This commit is contained in:
ZanSara 2024-01-15 14:20:02 +01:00 committed by GitHub
parent 8eba053dbc
commit b236ea49e3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,10 +1,11 @@
import json import json
from haystack import Pipeline, Document from haystack import Pipeline, Document
from haystack.components.embedders import SentenceTransformersTextEmbedder from haystack.components.embedders import SentenceTransformersTextEmbedder, SentenceTransformersDocumentEmbedder
from haystack.components.rankers import TransformersSimilarityRanker from haystack.components.rankers import TransformersSimilarityRanker
from haystack.components.joiners.document_joiner import DocumentJoiner from haystack.components.joiners.document_joiner import DocumentJoiner
from haystack.document_stores.in_memory import InMemoryDocumentStore from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.document_stores.types import DuplicatePolicy
from haystack.components.retrievers.in_memory import InMemoryBM25Retriever, InMemoryEmbeddingRetriever from haystack.components.retrievers.in_memory import InMemoryBM25Retriever, InMemoryEmbeddingRetriever
@ -47,6 +48,10 @@ def test_hybrid_doc_search_pipeline(tmp_path):
Document(content="My name is Giorgio and I live in Rome."), Document(content="My name is Giorgio and I live in Rome."),
] ]
hybrid_pipeline.get_component("bm25_retriever").document_store.write_documents(documents) hybrid_pipeline.get_component("bm25_retriever").document_store.write_documents(documents)
doc_embedder = SentenceTransformersDocumentEmbedder(model="sentence-transformers/all-MiniLM-L6-v2")
doc_embedder.warm_up()
embedded_documents = doc_embedder.run(documents=documents)["documents"]
hybrid_pipeline.get_component("embedding_retriever").document_store.write_documents(embedded_documents)
query = "Who lives in Rome?" query = "Who lives in Rome?"
result = hybrid_pipeline.run( result = hybrid_pipeline.run(