Tuana Celik d49e92e21c
ElasticsearchRetriever to BM25Retriever (#2423)
* change class names to bm25

* Update Documentation & Code Style

* Update Documentation & Code Style

* Update Documentation & Code Style

* Add back all_terms_must_match

* fix syntax

* Update Documentation & Code Style

* Update Documentation & Code Style

* Creating a wrapper for old ES retriever with deprecated wrapper

* Update Documentation & Code Style

* New method for deprecating old ESRetriever

* New attempt for deprecating the ESRetriever

* Reverting to the simplest solution - warning logged

* Update Documentation & Code Style

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Sara Zan <sara.zanzottera@deepset.ai>
2022-04-26 16:09:39 +02:00

186 lines
7.3 KiB
Python

import os
from haystack.document_stores.sql import SQLDocumentStore
from haystack.document_stores.memory import InMemoryDocumentStore
from haystack.document_stores.elasticsearch import Elasticsearch, ElasticsearchDocumentStore, OpenSearchDocumentStore
from haystack.document_stores.faiss import FAISSDocumentStore
from haystack.document_stores.milvus import MilvusDocumentStore
from haystack.nodes.retriever.sparse import BM25Retriever, TfidfRetriever
from haystack.nodes.retriever.dense import DensePassageRetriever, EmbeddingRetriever
from haystack.nodes.reader.farm import FARMReader
from haystack.nodes.reader.transformers import TransformersReader
from haystack.utils import launch_milvus, launch_es, launch_opensearch
from haystack.modeling.data_handler.processor import http_get
import logging
import subprocess
import time
import json
from typing import Union
from pathlib import Path
logger = logging.getLogger(__name__)

# Reader model identifiers to evaluate; each is passed as the model name to get_reader().
reader_models = [
    "deepset/roberta-base-squad2",
    "deepset/minilm-uncased-squad2",
    "deepset/bert-base-cased-squad2",
    "deepset/bert-large-uncased-whole-word-masking-squad2",
    "deepset/xlm-roberta-large-squad2",
]

# Reader implementations to evaluate (get_reader() also understands "transformers").
reader_types = ["farm"]

# Index names used by index_to_doc_store() for documents and labels respectively.
doc_index = "eval_document"
label_index = "label"
def _start_postgres_container():
    """(Re)create the ``haystack-postgres`` Docker container and the ``haystack`` database in it.

    Exit codes are deliberately ignored (e.g. ``docker rm -f`` fails when no such
    container exists yet). The fixed sleeps give Docker/Postgres time to come up;
    no readiness check is performed.
    """
    subprocess.run(["docker", "rm", "-f", "haystack-postgres"], check=False)
    time.sleep(1)
    subprocess.run(
        [
            "docker", "run", "--name", "haystack-postgres",
            "-p", "5432:5432",
            "-e", "POSTGRES_PASSWORD=password",
            "-d", "postgres",
        ],
        check=False,
    )
    time.sleep(6)
    subprocess.run(
        ["docker", "exec", "haystack-postgres", "psql", "-U", "postgres", "-c", "CREATE DATABASE haystack;"],
        check=False,
    )
    time.sleep(1)


def get_document_store(document_store_type, similarity="dot_product", index="document"):
    """Create a document store of the requested type for benchmarking.

    TODO This method is taken from test/conftest.py but maybe should be within Haystack.
    Perhaps a class method of DocStore that just takes string for type of DocStore.

    :param document_store_type: One of "sql", "memory", "elasticsearch", "milvus_flat",
        "milvus_hnsw", "faiss_flat", "faiss_hnsw", "opensearch_flat", "opensearch_hnsw".
    :param similarity: Similarity function passed to the Elasticsearch, Milvus, FAISS
        and OpenSearch stores.
    :param index: Default index name passed to the Milvus, FAISS and OpenSearch stores
        (the Elasticsearch store always uses "eval_document").
    :return: The initialised document store.
    :raises ValueError: If ``document_store_type`` is not recognised.
    """
    if document_store_type == "sql":
        # Start from an empty SQLite file.
        if os.path.exists("haystack_test.db"):
            os.remove("haystack_test.db")
        document_store = SQLDocumentStore(url="sqlite:///haystack_test.db")
        assert document_store.get_document_count() == 0
    elif document_store_type == "memory":
        document_store = InMemoryDocumentStore()
    elif document_store_type == "elasticsearch":
        launch_es()
        # make sure we start from a fresh index
        # NOTE(review): this deletes "haystack_test*" indices, but the store below uses
        # the "eval_document" index, which is therefore NOT reset here — confirm intent.
        client = Elasticsearch()
        client.indices.delete(index="haystack_test*", ignore=[404])
        document_store = ElasticsearchDocumentStore(index="eval_document", similarity=similarity, timeout=3000)
    elif document_store_type in ("milvus_flat", "milvus_hnsw"):
        launch_milvus()
        if document_store_type == "milvus_flat":
            index_type = "FLAT"
            index_param = None
            search_param = None
        else:  # "milvus_hnsw"
            index_type = "HNSW"
            index_param = {"M": 64, "efConstruction": 80}
            search_param = {"ef": 20}
        document_store = MilvusDocumentStore(
            similarity=similarity,
            index_type=index_type,
            index_param=index_param,
            search_param=search_param,
            index=index,
        )
        assert document_store.get_document_count(index="eval_document") == 0
    elif document_store_type in ("faiss_flat", "faiss_hnsw"):
        index_type = "Flat" if document_store_type == "faiss_flat" else "HNSW"
        # FAISS keeps document metadata in SQL; back it with a fresh Postgres container.
        _start_postgres_container()
        document_store = FAISSDocumentStore(
            sql_url="postgresql://postgres:password@localhost:5432/haystack",
            faiss_index_factory_str=index_type,
            similarity=similarity,
            index=index,
        )
        assert document_store.get_document_count() == 0
    elif document_store_type in ("opensearch_flat", "opensearch_hnsw"):
        launch_opensearch()
        index_type = "flat" if document_store_type == "opensearch_flat" else "hnsw"
        # Forward similarity/index like the other vector stores do; previously they were
        # silently ignored, so e.g. a "cosine" benchmark ran with the store's default.
        document_store = OpenSearchDocumentStore(
            index_type=index_type,
            port=9201,
            timeout=3000,
            similarity=similarity,
            index=index,
        )
    else:
        raise ValueError(f"No document store fixture for '{document_store_type}'")
    return document_store
def get_retriever(retriever_name, doc_store, devices):
    """Build the retriever called ``retriever_name`` on top of ``doc_store``.

    :param retriever_name: "elastic" (a BM25Retriever — the key predates the
        ElasticsearchRetriever -> BM25Retriever rename), "tfidf", "dpr" or
        "sentence_transformers".
    :param doc_store: Document store the retriever reads from.
    :param devices: Devices handed to DensePassageRetriever (only used for "dpr").
    :return: The retriever instance.
    :raises ValueError: If ``retriever_name`` is not recognised (previously ``None``
        was returned, deferring the failure to a confusing later error).
    """
    if retriever_name == "elastic":
        return BM25Retriever(doc_store)
    if retriever_name == "tfidf":
        return TfidfRetriever(doc_store)
    if retriever_name == "dpr":
        return DensePassageRetriever(
            document_store=doc_store,
            query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
            passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
            use_gpu=True,
            use_fast_tokenizers=False,
            devices=devices,
        )
    if retriever_name == "sentence_transformers":
        return EmbeddingRetriever(
            document_store=doc_store,
            embedding_model="nq-distilbert-base-v1",
            use_gpu=True,
            model_format="sentence_transformers",
        )
    raise ValueError(
        f"Unknown retriever '{retriever_name}'. "
        "Expected one of: 'elastic', 'tfidf', 'dpr', 'sentence_transformers'."
    )
def get_reader(reader_name, reader_type, max_seq_len=384):
    """Instantiate a reader model for benchmarking.

    :param reader_name: Model name passed as the first argument to the reader class.
    :param reader_type: "farm" (FARMReader) or "transformers" (TransformersReader).
    :param max_seq_len: Maximum sequence length passed to the reader.
    :return: The reader, created with ``top_k_per_candidate=4``.
    :raises ValueError: If ``reader_type`` is not recognised (previously this surfaced
        as an opaque ``TypeError: 'NoneType' object is not callable``).
    """
    if reader_type == "farm":
        reader_class = FARMReader
    elif reader_type == "transformers":
        reader_class = TransformersReader
    else:
        raise ValueError(f"Unknown reader type '{reader_type}'. Expected 'farm' or 'transformers'.")
    return reader_class(reader_name, top_k_per_candidate=4, max_seq_len=max_seq_len)
def index_to_doc_store(doc_store, docs, retriever, labels=None):
    """Write ``docs`` (and optionally ``labels``) to ``doc_store`` and embed them if needed.

    Documents are written to the ``doc_index`` index and labels to ``label_index``.
    Embeddings are computed via ``doc_store.update_embeddings`` only when ``retriever``
    has a callable ``embed_documents`` attribute and the documents do not already
    carry embeddings.

    :param doc_store: Document store to write into.
    :param docs: Sequence of documents to index (may be empty).
    :param retriever: Retriever used to compute embeddings, if it supports it.
    :param labels: Optional labels to write; skipped when falsy.
    """
    doc_store.write_documents(docs, doc_index)
    if labels:
        doc_store.write_labels(labels, index=label_index)
    # these lines are not run if the docs.embedding field is already populated with precomputed embeddings
    # See the prepare_data() fn in the retriever benchmark script
    # Only the first document is inspected, so docs are assumed to be either all
    # embedded or all un-embedded. The `docs and` guard avoids an IndexError on an
    # empty list (nothing new was written, so nothing needs embedding).
    if docs and callable(getattr(retriever, "embed_documents", None)) and docs[0].embedding is None:
        doc_store.update_embeddings(retriever, index=doc_index)
def load_config(config_filename, ci):
    """Load benchmark parameters and data filenames from a JSON config file.

    :param config_filename: Path to the JSON config. It is expected to contain
        ``params.ci`` / ``params.full`` (each with an ``n_docs_options`` list) and
        ``filenames`` with ``embeddings_filenames`` and ``filenames_negative`` dicts
        keyed by document count written as a string.
    :param ci: If truthy, use ``params.ci``; otherwise ``params.full``.
    :return: ``(params, filenames)``. For the smallest document-count key that is at
        least ``max(params["n_docs_options"])``, ``filenames["embeddings_filenames"]``
        is replaced by a one-element list holding that entry and
        ``filenames["filename_negative"]`` is set from ``filenames_negative``.
        If no key is large enough, ``filenames`` is returned unchanged.
    """
    # Use a context manager so the config file handle is closed promptly.
    with open(config_filename, encoding="utf-8") as config_file:
        conf = json.load(config_file)
    params = conf["params"]["ci" if ci else "full"]
    filenames = conf["filenames"]
    max_docs = max(params["n_docs_options"])
    # Keys are stored as strings; sort them numerically so the first match is the
    # smallest file that still covers the largest requested document count.
    n_docs_keys = sorted(int(x) for x in filenames["embeddings_filenames"])
    for k in n_docs_keys:
        if max_docs <= k:
            filenames["embeddings_filenames"] = [filenames["embeddings_filenames"][str(k)]]
            filenames["filename_negative"] = filenames["filenames_negative"][str(k)]
            break
    return params, filenames
def download_from_url(url: str, filepath: Union[str, Path]) -> Union[str, Path]:
    """
    Download from a url to a local file. Skip already existing files.

    :param url: Url
    :param filepath: local path where the url content shall be stored
    :return: local path of the downloaded file (``filepath`` as passed in)
    """
    # Create the parent folder if needed. A bare filename has no folder component,
    # and os.makedirs("") would raise FileNotFoundError, so only create a non-empty one.
    folder = os.path.dirname(filepath)
    if folder:
        os.makedirs(folder, exist_ok=True)

    # Download file if not present locally
    if os.path.exists(filepath):
        logger.info(f"Skipping {url} (exists locally)")
        return filepath

    logger.info(f"Downloading {url} to {filepath} ")
    try:
        with open(filepath, "wb") as file:
            http_get(url=url, temp_file=file)
    except BaseException:
        # Don't leave a truncated file behind: the next call would treat it as
        # "exists locally" and silently skip re-downloading it.
        if os.path.exists(filepath):
            os.remove(filepath)
        raise
    return filepath