haystack/examples/basic_qa_pipeline.py
Nicola Procopio c102b152dc
fix: Run update_embeddings in examples (#6008)
* added hybrid search example

Added an example of hybrid search for the FAQ pipeline on a COVID dataset

* formatted with black formatter

* renamed document

* fixed

* fixed typos

* added test

added test for hybrid search

* fixed whitespaces

* removed test for hybrid search

* fixed pylint

* commented logging

* updated hybrid search example

* release notes

* Update hybrid_search_faq_pipeline.py-815df846dca7e872.yaml

* Update hybrid_search_faq_pipeline.py

* mention hybrid search example in release notes

* reduce installed dependencies in examples test workflow

* do not install cuda dependencies

* skip models if API key not set; delete document indices

* keep roberta-base model and inference extra

* pylint

* disable pylint no-logging-basicconfig rule

---------

Co-authored-by: Julian Risch <julian.risch@deepset.ai>
2023-10-10 16:38:52 +02:00

# Disable pylint errors for logging basicConfig
# pylint: disable=no-logging-basicconfig
import logging
from pathlib import Path

from haystack.document_stores import ElasticsearchDocumentStore
from haystack.nodes import BM25Retriever, FARMReader
from haystack.nodes.file_classifier import FileTypeClassifier
from haystack.nodes.file_converter import TextConverter
from haystack.nodes.preprocessor import PreProcessor
from haystack.pipelines import Pipeline
from haystack.utils import fetch_archive_from_http, launch_es, print_answers

logging.basicConfig(format="%(levelname)s - %(name)s - %(message)s", level=logging.WARNING)
logging.getLogger("haystack").setLevel(logging.INFO)


def basic_qa_pipeline():
    # Initialize a DocumentStore
    document_store = ElasticsearchDocumentStore(host="localhost", username="", password="", index="example-document")
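    # Note: this assumes a local Elasticsearch instance on the default port 9200
    # (launch_es() in the __main__ block starts one via Docker). If yours runs
    # elsewhere, point the store at it instead, e.g. (illustrative values):
    # document_store = ElasticsearchDocumentStore(host="elasticsearch", port=9200, index="example-document")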

    # fetch, pre-process and write documents
    doc_dir = "data/basic_qa_pipeline"
    s3_url = "https://core-engineering.s3.eu-central-1.amazonaws.com/public/scripts/wiki_gameofthrones_txt1.zip"
    fetch_archive_from_http(url=s3_url, output_dir=doc_dir)

    file_paths = [p for p in Path(doc_dir).glob("**/*")]
    files_metadata = [{"name": path.name} for path in file_paths]

    # Indexing Pipeline
    indexing_pipeline = Pipeline()

    # Makes sure the file is a TXT file (FileTypeClassifier node)
    classifier = FileTypeClassifier()
    indexing_pipeline.add_node(classifier, name="Classifier", inputs=["File"])

    # Converts a file into text and performs basic cleaning (TextConverter node)
    text_converter = TextConverter(remove_numeric_tables=True)
    indexing_pipeline.add_node(text_converter, name="Text_converter", inputs=["Classifier.output_1"])

    # - Pre-processes the text by performing splits and adding metadata to the text (Preprocessor node)
    preprocessor = PreProcessor(
        clean_whitespace=True,
        clean_empty_lines=True,
        split_length=100,
        split_overlap=50,
        split_respect_sentence_boundary=True,
    )
    indexing_pipeline.add_node(preprocessor, name="Preprocessor", inputs=["Text_converter"])

    # - Writes the resulting documents into the document store
    indexing_pipeline.add_node(document_store, name="Document_Store", inputs=["Preprocessor"])

    # Then we run it with the documents and their metadata as input
    indexing_pipeline.run(file_paths=file_paths, meta=files_metadata)
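
    # Optional sanity check: confirm the documents were actually written before
    # querying (get_document_count() is part of the document store API).
    print(f"Indexed {document_store.get_document_count()} documents")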

    # Initialize Retriever & Reader
    retriever = BM25Retriever(document_store=document_store)
    reader = FARMReader(model_name_or_path="deepset/roberta-base-squad2", use_gpu=True)
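    # FARMReader falls back to CPU when no CUDA device is available, so
    # use_gpu=True is safe on CPU-only machines (inference is just slower).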

    # Query Pipeline
    pipeline = Pipeline()
    pipeline.add_node(component=retriever, name="Retriever", inputs=["Query"])
    pipeline.add_node(component=reader, name="Reader", inputs=["Retriever"])

    prediction = pipeline.run(
        query="Who is the father of Arya Stark?", params={"Retriever": {"top_k": 10}, "Reader": {"top_k": 5}}
    )

    print_answers(prediction, details="minimum")
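    # Besides pretty-printing, the result can be consumed programmatically:
    # prediction["answers"] holds haystack.schema.Answer objects, e.g. (illustrative):
    # best_answer = prediction["answers"][0]
    # print(best_answer.answer, best_answer.score, best_answer.context)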

    # Remove the index once we're done to save space
    document_store.delete_index(index="example-document")

    return prediction
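

# Running this script requires Docker: launch_es() below starts a local
# Elasticsearch container before indexing begins. If you already have
# Elasticsearch running, skip launch_es() and call basic_qa_pipeline() directly.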
if __name__ == "__main__":
    launch_es()
    basic_qa_pipeline()