# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from haystack import Document, Pipeline
from haystack.components.embedders import SentenceTransformersDocumentEmbedder, SentenceTransformersTextEmbedder
from haystack.components.joiners.document_joiner import DocumentJoiner
from haystack.components.rankers import TransformersSimilarityRanker
from haystack.components.retrievers.in_memory import InMemoryBM25Retriever, InMemoryEmbeddingRetriever
from haystack.document_stores.in_memory import InMemoryDocumentStore


def test_hybrid_doc_search_pipeline(tmp_path):
    # Create the pipeline
    document_store = InMemoryDocumentStore()
    hybrid_pipeline = Pipeline()
    hybrid_pipeline.add_component(instance=InMemoryBM25Retriever(document_store=document_store), name="bm25_retriever")
    hybrid_pipeline.add_component(
        instance=SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2"), name="text_embedder"
    )
    hybrid_pipeline.add_component(
        instance=InMemoryEmbeddingRetriever(document_store=document_store), name="embedding_retriever"
    )
    hybrid_pipeline.add_component(instance=DocumentJoiner(), name="joiner")
    hybrid_pipeline.add_component(instance=TransformersSimilarityRanker(top_k=20), name="ranker")
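
    # Wire the graph: BM25 and embedding retrieval form two branches whose results
    # are merged by the joiner and re-scored by the similarity ranker.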
    hybrid_pipeline.connect("bm25_retriever", "joiner")
    hybrid_pipeline.connect("text_embedder", "embedding_retriever")
    hybrid_pipeline.connect("embedding_retriever", "joiner")
    hybrid_pipeline.connect("joiner", "ranker")
    # Serialize the pipeline to YAML
    with open(tmp_path / "test_hybrid_doc_search_pipeline.yaml", "w") as f:
        hybrid_pipeline.dump(f)

    # Load the pipeline back
    with open(tmp_path / "test_hybrid_doc_search_pipeline.yaml", "r") as f:
        hybrid_pipeline = Pipeline.load(f)

    # Populate the document store
    documents = [
        Document(content="My name is Jean and I live in Paris."),
        Document(content="My name is Mark and I live in Berlin."),
        Document(content="My name is Mario and I live in the capital of Italy."),
        Document(content="My name is Giorgio and I live in Rome."),
    ]
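    # Embed the documents with the same model used by the query-side text embedder,
    # then write them into the store behind the loaded pipeline's embedding retriever.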
    doc_embedder = SentenceTransformersDocumentEmbedder(model="sentence-transformers/all-MiniLM-L6-v2")
    doc_embedder.warm_up()
    embedded_documents = doc_embedder.run(documents=documents)["documents"]
    hybrid_pipeline.get_component("embedding_retriever").document_store.write_documents(embedded_documents)
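
    # The same query string feeds the BM25 retriever, the text embedder, and the ranker.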
    query = "Who lives in Rome?"
    result = hybrid_pipeline.run(
        {"bm25_retriever": {"query": query}, "text_embedder": {"text": query}, "ranker": {"query": query}}
    )
    assert result["ranker"]["documents"][0].content == "My name is Giorgio and I live in Rome."
    assert result["ranker"]["documents"][1].content == "My name is Mario and I live in the capital of Italy."