| 
									
										
										
										
											2020-10-12 13:34:42 +02:00
										 |  |  | import os | 
					
						
							| 
									
										
										
										
											2022-09-20 10:22:08 +02:00
										 |  |  | from haystack.document_stores import SQLDocumentStore | 
					
						
							|  |  |  | from haystack.document_stores import InMemoryDocumentStore | 
					
						
							|  |  |  | from haystack.document_stores import ElasticsearchDocumentStore, OpenSearchDocumentStore | 
					
						
							|  |  |  | from haystack.document_stores.elasticsearch import Elasticsearch | 
					
						
							|  |  |  | from haystack.document_stores import FAISSDocumentStore | 
					
						
							|  |  |  | from haystack.nodes import BM25Retriever, TfidfRetriever | 
					
						
							|  |  |  | from haystack.nodes import DensePassageRetriever, EmbeddingRetriever | 
					
						
							|  |  |  | from haystack.nodes import FARMReader | 
					
						
							|  |  |  | from haystack.nodes import TransformersReader | 
					
						
							| 
									
										
										
										
											2023-05-19 16:37:38 +02:00
										 |  |  | from haystack.utils import launch_es, launch_opensearch | 
					
						
							| 
									
										
										
										
											2021-09-28 16:34:24 +02:00
										 |  |  | from haystack.modeling.data_handler.processor import http_get | 
					
						
							| 
									
										
										
										
											2021-04-09 17:24:16 +02:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-10-12 13:34:42 +02:00
										 |  |  | import logging | 
					
						
							|  |  |  | import subprocess | 
					
						
							|  |  |  | import time | 
					
						
							| 
									
										
										
										
											2020-10-21 17:59:44 +02:00
										 |  |  | import json | 
					
						
							| 
									
										
										
										
											2021-04-09 17:24:16 +02:00
										 |  |  | from typing import Union | 
					
						
							| 
									
										
										
										
											2020-10-12 13:34:42 +02:00
										 |  |  | from pathlib import Path | 
					
						
							| 
									
										
										
										
											2022-02-03 13:43:18 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-10-12 13:34:42 +02:00
										 |  |  | logger = logging.getLogger(__name__) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-02-03 13:43:18 +01:00
# Reader models to benchmark (Hugging Face model identifiers).
reader_models = [
    "deepset/roberta-base-squad2",
    "deepset/minilm-uncased-squad2",
    "deepset/bert-base-cased-squad2",
    "deepset/bert-large-uncased-whole-word-masking-squad2",
    "deepset/xlm-roberta-large-squad2",
]

# Reader implementation flavors to benchmark; get_reader() also accepts "transformers".
reader_types = ["farm"]

# Index names used by the benchmark document stores for documents and evaluation labels.
doc_index = "eval_document"
label_index = "label"
					
						
							| 
									
										
										
										
											2022-02-03 13:43:18 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  | def get_document_store(document_store_type, similarity="dot_product", index="document"): | 
					
						
							|  |  |  |     """TODO This method is taken from test/conftest.py but maybe should be within Haystack.
 | 
					
						
							| 
									
										
										
										
											2020-10-12 13:34:42 +02:00
										 |  |  |     Perhaps a class method of DocStore that just takes string for type of DocStore"""
 | 
					
						
							|  |  |  |     if document_store_type == "sql": | 
					
						
							|  |  |  |         if os.path.exists("haystack_test.db"): | 
					
						
							|  |  |  |             os.remove("haystack_test.db") | 
					
						
							|  |  |  |         document_store = SQLDocumentStore(url="sqlite:///haystack_test.db") | 
					
						
							| 
									
										
										
										
											2021-04-09 17:24:16 +02:00
										 |  |  |         assert document_store.get_document_count() == 0 | 
					
						
							| 
									
										
										
										
											2020-10-12 13:34:42 +02:00
										 |  |  |     elif document_store_type == "memory": | 
					
						
							|  |  |  |         document_store = InMemoryDocumentStore() | 
					
						
							|  |  |  |     elif document_store_type == "elasticsearch": | 
					
						
							| 
									
										
										
										
											2021-07-26 10:52:52 +02:00
										 |  |  |         launch_es() | 
					
						
							| 
									
										
										
										
											2022-09-20 10:22:08 +02:00
										 |  |  |         time.sleep(5) | 
					
						
							| 
									
										
										
										
											2020-10-12 13:34:42 +02:00
										 |  |  |         # make sure we start from a fresh index | 
					
						
							|  |  |  |         client = Elasticsearch() | 
					
						
							| 
									
										
										
										
											2022-02-03 13:43:18 +01:00
										 |  |  |         client.indices.delete(index="haystack_test*", ignore=[404]) | 
					
						
							| 
									
										
										
										
											2021-02-03 11:45:18 +01:00
										 |  |  |         document_store = ElasticsearchDocumentStore(index="eval_document", similarity=similarity, timeout=3000) | 
					
						
							| 
									
										
										
										
											2021-07-26 10:52:52 +02:00
										 |  |  |     elif document_store_type in ("faiss_flat", "faiss_hnsw"): | 
					
						
							| 
									
										
										
										
											2020-10-12 13:34:42 +02:00
										 |  |  |         if document_store_type == "faiss_flat": | 
					
						
							|  |  |  |             index_type = "Flat" | 
					
						
							|  |  |  |         elif document_store_type == "faiss_hnsw": | 
					
						
							|  |  |  |             index_type = "HNSW" | 
					
						
							| 
									
										
										
										
											2022-02-03 13:43:18 +01:00
										 |  |  |         status = subprocess.run(["docker rm -f haystack-postgres"], shell=True) | 
					
						
							| 
									
										
										
										
											2020-10-12 13:34:42 +02:00
										 |  |  |         time.sleep(1) | 
					
						
							|  |  |  |         status = subprocess.run( | 
					
						
							| 
									
										
										
										
											2022-02-03 13:43:18 +01:00
										 |  |  |             ["docker run --name haystack-postgres -p 5432:5432 -e POSTGRES_PASSWORD=password -d postgres"], shell=True | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2021-02-03 11:45:18 +01:00
										 |  |  |         time.sleep(6) | 
					
						
							| 
									
										
										
										
											2020-10-12 13:34:42 +02:00
										 |  |  |         status = subprocess.run( | 
					
						
							| 
									
										
										
										
											2022-02-03 13:43:18 +01:00
										 |  |  |             ['docker exec haystack-postgres psql -U postgres -c "CREATE DATABASE haystack;"'], shell=True | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2020-10-12 13:34:42 +02:00
										 |  |  |         time.sleep(1) | 
					
						
							| 
									
										
										
										
											2021-06-04 11:05:18 +02:00
										 |  |  |         document_store = FAISSDocumentStore( | 
					
						
							|  |  |  |             sql_url="postgresql://postgres:password@localhost:5432/haystack", | 
					
						
							|  |  |  |             faiss_index_factory_str=index_type, | 
					
						
							|  |  |  |             similarity=similarity, | 
					
						
							| 
									
										
										
										
											2022-02-03 13:43:18 +01:00
										 |  |  |             index=index, | 
					
						
							| 
									
										
										
										
											2021-06-04 11:05:18 +02:00
										 |  |  |         ) | 
					
						
							| 
									
										
										
										
											2021-04-09 17:24:16 +02:00
										 |  |  |         assert document_store.get_document_count() == 0 | 
					
						
							| 
									
										
										
										
											2021-07-26 10:52:52 +02:00
										 |  |  |     elif document_store_type in ("opensearch_flat", "opensearch_hnsw"): | 
					
						
							| 
									
										
										
										
											2022-11-28 14:36:45 +01:00
										 |  |  |         launch_opensearch(local_port=9201) | 
					
						
							| 
									
										
										
										
											2021-07-26 10:52:52 +02:00
										 |  |  |         if document_store_type == "opensearch_flat": | 
					
						
							|  |  |  |             index_type = "flat" | 
					
						
							|  |  |  |         elif document_store_type == "opensearch_hnsw": | 
					
						
							|  |  |  |             index_type = "hnsw" | 
					
						
							| 
									
										
										
										
											2021-10-19 14:40:53 +02:00
										 |  |  |         document_store = OpenSearchDocumentStore(index_type=index_type, port=9201, timeout=3000) | 
					
						
							| 
									
										
										
										
											2020-10-12 13:34:42 +02:00
										 |  |  |     else: | 
					
						
							|  |  |  |         raise Exception(f"No document store fixture for '{document_store_type}'") | 
					
						
							|  |  |  |     return document_store | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-02-03 13:43:18 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-09-10 13:25:02 +02:00
										 |  |  | def get_retriever(retriever_name, doc_store, devices): | 
					
						
							| 
									
										
										
										
											2020-10-12 13:34:42 +02:00
										 |  |  |     if retriever_name == "elastic": | 
					
						
							| 
									
										
										
										
											2022-04-26 16:09:39 +02:00
										 |  |  |         return BM25Retriever(doc_store) | 
					
						
							| 
									
										
										
										
											2020-10-12 13:34:42 +02:00
										 |  |  |     if retriever_name == "tfidf": | 
					
						
							|  |  |  |         return TfidfRetriever(doc_store) | 
					
						
							|  |  |  |     if retriever_name == "dpr": | 
					
						
							| 
									
										
										
										
											2022-02-03 13:43:18 +01:00
										 |  |  |         return DensePassageRetriever( | 
					
						
							|  |  |  |             document_store=doc_store, | 
					
						
							|  |  |  |             query_embedding_model="facebook/dpr-question_encoder-single-nq-base", | 
					
						
							|  |  |  |             passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base", | 
					
						
							|  |  |  |             use_gpu=True, | 
					
						
							|  |  |  |             use_fast_tokenizers=False, | 
					
						
							|  |  |  |             devices=devices, | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2021-04-09 17:24:16 +02:00
										 |  |  |     if retriever_name == "sentence_transformers": | 
					
						
							| 
									
										
										
										
											2022-02-03 13:43:18 +01:00
										 |  |  |         return EmbeddingRetriever( | 
					
						
							|  |  |  |             document_store=doc_store, | 
					
						
							|  |  |  |             embedding_model="nq-distilbert-base-v1", | 
					
						
							|  |  |  |             use_gpu=True, | 
					
						
							|  |  |  |             model_format="sentence_transformers", | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-10-12 13:34:42 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  | def get_reader(reader_name, reader_type, max_seq_len=384): | 
					
						
							|  |  |  |     reader_class = None | 
					
						
							|  |  |  |     if reader_type == "farm": | 
					
						
							|  |  |  |         reader_class = FARMReader | 
					
						
							|  |  |  |     elif reader_type == "transformers": | 
					
						
							|  |  |  |         reader_class = TransformersReader | 
					
						
							|  |  |  |     return reader_class(reader_name, top_k_per_candidate=4, max_seq_len=max_seq_len) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-02-03 13:43:18 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-10-12 13:34:42 +02:00
										 |  |  | def index_to_doc_store(doc_store, docs, retriever, labels=None): | 
					
						
							|  |  |  |     doc_store.write_documents(docs, doc_index) | 
					
						
							|  |  |  |     if labels: | 
					
						
							|  |  |  |         doc_store.write_labels(labels, index=label_index) | 
					
						
							|  |  |  |     # these lines are not run if the docs.embedding field is already populated with precomputed embeddings | 
					
						
							|  |  |  |     # See the prepare_data() fn in the retriever benchmark script | 
					
						
							| 
									
										
										
										
											2021-10-28 12:17:56 +02:00
										 |  |  |     if callable(getattr(retriever, "embed_documents", None)) and docs[0].embedding is None: | 
					
						
							| 
									
										
										
										
											2022-09-20 10:22:08 +02:00
										 |  |  |         doc_store.update_embeddings(retriever, index=doc_index, batch_size=200) | 
					
						
							| 
									
										
										
										
											2020-10-12 13:34:42 +02:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-02-03 13:43:18 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-10-15 18:12:17 +02:00
										 |  |  | def load_config(config_filename, ci): | 
					
						
							|  |  |  |     conf = json.load(open(config_filename)) | 
					
						
							|  |  |  |     if ci: | 
					
						
							|  |  |  |         params = conf["params"]["ci"] | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         params = conf["params"]["full"] | 
					
						
							|  |  |  |     filenames = conf["filenames"] | 
					
						
							|  |  |  |     max_docs = max(params["n_docs_options"]) | 
					
						
							|  |  |  |     n_docs_keys = sorted([int(x) for x in list(filenames["embeddings_filenames"])]) | 
					
						
							|  |  |  |     for k in n_docs_keys: | 
					
						
							|  |  |  |         if max_docs <= k: | 
					
						
							|  |  |  |             filenames["embeddings_filenames"] = [filenames["embeddings_filenames"][str(k)]] | 
					
						
							|  |  |  |             filenames["filename_negative"] = filenames["filenames_negative"][str(k)] | 
					
						
							|  |  |  |             break | 
					
						
							|  |  |  |     return params, filenames | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-02-03 13:43:18 +01:00
										 |  |  | def download_from_url(url: str, filepath: Union[str, Path]): | 
					
						
							| 
									
										
										
										
											2021-04-09 17:24:16 +02:00
										 |  |  |     """
 | 
					
						
							|  |  |  |     Download from a url to a local file. Skip already existing files. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     :param url: Url | 
					
						
							|  |  |  |     :param filepath: local path where the url content shall be stored | 
					
						
							|  |  |  |     :return: local path of the downloaded file | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-19 18:18:32 +02:00
										 |  |  |     logger.info("Downloading %s", url) | 
					
						
							| 
									
										
										
										
											2021-04-09 17:24:16 +02:00
										 |  |  |     # Create local folder | 
					
						
							|  |  |  |     folder, filename = os.path.split(filepath) | 
					
						
							|  |  |  |     if not os.path.exists(folder): | 
					
						
							|  |  |  |         os.makedirs(folder) | 
					
						
							|  |  |  |     # Download file if not present locally | 
					
						
							|  |  |  |     if os.path.exists(filepath): | 
					
						
							| 
									
										
										
										
											2022-09-19 18:18:32 +02:00
										 |  |  |         logger.info("Skipping %s (exists locally)", url) | 
					
						
							| 
									
										
										
										
											2021-04-09 17:24:16 +02:00
										 |  |  |     else: | 
					
						
							| 
									
										
										
										
											2022-09-19 18:18:32 +02:00
										 |  |  |         logger.info("Downloading %s to %s", filepath) | 
					
						
							| 
									
										
										
										
											2021-04-09 17:24:16 +02:00
										 |  |  |         with open(filepath, "wb") as file: | 
					
						
							|  |  |  |             http_get(url=url, temp_file=file) | 
					
						
							| 
									
										
										
										
											2021-09-10 13:25:02 +02:00
										 |  |  |     return filepath |