Integrate sentence transformers into benchmarks (#843)

* Integrate sentence transformers into benchmarks

* Add doc store asserts

* Switch data downloads from the S3 client to HTTPS; add license info

* Fix mypy, revert config

Co-authored-by: Malte Pietsch <malte.pietsch@deepset.ai>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Timo Moeller 2021-04-09 17:24:16 +02:00 committed by GitHub
parent d38c07e0ee
commit 837dea4e6d
6 changed files with 67 additions and 24 deletions


@@ -462,7 +462,7 @@ that are most relevant to the query.
#### embed
```python
| embed(texts: Union[List[str], str]) -> List[np.ndarray]
| embed(texts: Union[List[List[str]], List[str], str]) -> List[np.ndarray]
```
Create embeddings for each text in a list of texts using the retriever's model (`self.embedding_model`)
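For orientation, a minimal sketch of the widened input shapes, assuming an already-initialized `EmbeddingRetriever` named `retriever` running in `sentence_transformers` mode (names and texts here are illustrative):

```python
# Sketch only; `retriever` is assumed to be an EmbeddingRetriever with
# model_format="sentence_transformers".
texts = ["Berlin is the capital of Germany.", "Paris is the capital of France."]
vectors = retriever.embed(texts)          # list of numpy arrays, one per text

# New with this change: [title, text] pairs for models trained on title + passage input
pairs = [["Berlin", "Berlin is the capital of Germany."],
         ["Paris", "Paris is the capital of France."]]
vectors = retriever.embed(pairs)          # still one numpy array per pair
```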
@@ -496,7 +496,7 @@ Embeddings, one per input query
#### embed\_passages
```python
| embed_passages(docs: List[Document]) -> List[np.ndarray]
| embed_passages(docs: List[Document]) -> Union[List[str], List[List[str]]]
```
Create embeddings for a list of passages. For this Retriever type, this is the same as calling .embed()


@@ -524,7 +524,7 @@ class EmbeddingRetriever(BaseRetriever):
top_k=top_k, index=index)
return documents
def embed(self, texts: Union[List[str], str]) -> List[np.ndarray]:
def embed(self, texts: Union[List[List[str]], List[str], str]) -> List[np.ndarray]:
"""
Create embeddings for each text in a list of texts using the retriever's model (`self.embedding_model`)
@@ -543,9 +543,9 @@ class EmbeddingRetriever(BaseRetriever):
emb = self.embedding_model.inference_from_dicts(dicts=[{"text": t} for t in texts])
emb = [(r["vec"]) for r in emb]
elif self.model_format == "sentence_transformers":
# text is single string, sentence-transformers needs a list of strings
# texts can be a list of strings or a list of [title, text]
# get back list of numpy embedding vectors
emb = self.embedding_model.encode(texts)
emb = self.embedding_model.encode(texts, batch_size=200, show_progress_bar=False)
emb = [r for r in emb]
return emb
@@ -558,13 +558,15 @@ class EmbeddingRetriever(BaseRetriever):
"""
return self.embed(texts)
def embed_passages(self, docs: List[Document]) -> List[np.ndarray]:
def embed_passages(self, docs: List[Document]) -> Union[List[str], List[List[str]]]:
"""
Create embeddings for a list of passages. For this Retriever type, this is the same as calling .embed()
:param docs: List of documents to embed
:return: Embeddings, one per input passage
"""
texts = [d.text for d in docs]
return self.embed(texts)
if self.model_format == "sentence_transformers":
passages = [[d.meta["name"] if d.meta and "name" in d.meta else "", d.text] for d in docs] # type: ignore
else:
passages = [d.text for d in docs] # type: ignore
return self.embed(passages)
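A rough illustration of what the new branch does with documents that do and do not carry a `meta["name"]` title, assuming `retriever` is an `EmbeddingRetriever` in `sentence_transformers` mode and `Document` is Haystack's document class used in the signatures above (the documents themselves are made up):

```python
# Hypothetical documents; the first carries a title in meta["name"], the second does not.
docs = [
    Document(text="The Eiffel Tower is in Paris.", meta={"name": "Eiffel Tower"}),
    Document(text="An untitled passage."),
]
vectors = retriever.embed_passages(docs)
# Internally this becomes
#   [["Eiffel Tower", "The Eiffel Tower is in Paris."], ["", "An untitled passage."]]
# and yields one embedding vector per document.
```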


@@ -6,6 +6,10 @@
"elastic",
"elasticsearch"
],
[
"sentence_transformers",
"elasticsearch"
],
[
"dpr",
"elasticsearch"
@@ -41,7 +45,7 @@
}
},
"filenames": {
"data_s3_url": "s3://ext-haystack-retriever-eval/",
"data_s3_url": "https://ext-haystack-retriever-eval.s3-eu-west-1.amazonaws.com/",
"data_dir": "../../data/retriever/",
"filename_gold": "nq2squad-dev.json",
"filenames_negative": {
@@ -55,4 +59,4 @@
"100000": "wikipedia_passages_100k.pkl",
"1000000": "wikipedia_passages_1m.pkl"}
}
}
}
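As a rough sketch of how a script might consume the fields shown in this config snippet (only keys visible above are used; the benchmark's own loader is the `load_config` helper shown further down in utils):

```python
import json

# Read the benchmark config and assemble the download URL for the gold file.
with open("config.json") as f:            # path is illustrative
    config = json.load(f)

filenames = config["filenames"]
gold_url = filenames["data_s3_url"] + filenames["filename_gold"]
# -> https://ext-haystack-retriever-eval.s3-eu-west-1.amazonaws.com/nq2squad-dev.json
data_dir = filenames["data_dir"]          # "../../data/retriever/"
```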


@@ -1,7 +1,7 @@
import pandas as pd
from pathlib import Path
from time import perf_counter
from utils import get_document_store, get_retriever, index_to_doc_store, load_config
from utils import get_document_store, get_retriever, index_to_doc_store, load_config, download_from_url
from haystack.preprocessor.utils import eval_data_from_json
from haystack.document_store.faiss import FAISSDocumentStore
@@ -15,7 +15,7 @@ import random
import traceback
import os
import requests
from farm.file_utils import download_from_s3
from farm.file_utils import http_get
import json
from results_to_json import retriever as retriever_json
from templates import RETRIEVER_TEMPLATE, RETRIEVER_MAP_TEMPLATE, RETRIEVER_SPEED_TEMPLATE
@@ -50,7 +50,7 @@ def benchmark_indexing(n_docs_options, retriever_doc_stores, data_dir, filename_
docs, _ = prepare_data(data_dir=data_dir,
filename_gold=filename_gold,
filename_negative=filename_negative,
data_s3_url=data_s3_url,
remote_url=data_s3_url,
embeddings_filenames=embeddings_filenames,
embeddings_dir=embeddings_dir,
n_docs=n_docs)
@@ -134,7 +134,7 @@
for retriever_name, doc_store_name in retriever_doc_stores:
try:
logger.info(f"##### Start querying run: {retriever_name}, {doc_store_name}, {n_docs} docs ##### ")
if retriever_name == "elastic":
if retriever_name in ["elastic", "sentence_transformers"]:
similarity = "cosine"
else:
similarity = "dot_product"
@@ -145,7 +145,7 @@
docs, labels = prepare_data(data_dir=data_dir,
filename_gold=filename_gold,
filename_negative=filename_negative,
data_s3_url=data_s3_url,
remote_url=data_s3_url,
embeddings_filenames=embeddings_filenames,
embeddings_dir=embeddings_dir,
n_docs=n_docs,
@@ -254,7 +254,7 @@ def add_precomputed_embeddings(embeddings_dir, embeddings_filenames, docs):
return ret
def prepare_data(data_dir, filename_gold, filename_negative, data_s3_url, embeddings_filenames, embeddings_dir, n_docs=None, n_queries=None, add_precomputed=False):
def prepare_data(data_dir, filename_gold, filename_negative, remote_url, embeddings_filenames, embeddings_dir, n_docs=None, n_queries=None, add_precomputed=False):
"""
filename_gold points to a squad format file.
filename_negative points to a csv file where the first column is doc_id and second is document text.
@@ -262,11 +262,11 @@ def prepare_data(data_dir, filename_gold, filename_negative, data_s3_url, embed
"""
logging.getLogger("farm").setLevel(logging.INFO)
download_from_s3(data_s3_url + filename_gold, cache_dir=data_dir)
download_from_s3(data_s3_url + filename_negative, cache_dir=data_dir)
download_from_url(remote_url + filename_gold, filepath=data_dir + filename_gold)
download_from_url(remote_url + filename_negative, filepath=data_dir + filename_negative)
if add_precomputed:
for embedding_filename in embeddings_filenames:
download_from_s3(data_s3_url + str(embeddings_dir) + embedding_filename, cache_dir=data_dir)
download_from_url(remote_url + str(embeddings_dir) + embedding_filename, filepath=data_dir + str(embeddings_dir) + embedding_filename)
logging.getLogger("farm").setLevel(logging.WARN)
gold_docs, labels = eval_data_from_json(data_dir + filename_gold)
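An illustrative call of the reworked `prepare_data` with values from the config above; the negative-passages filename and the embeddings directory are placeholders, since the config snippet truncates those entries:

```python
docs, labels = prepare_data(
    data_dir="../../data/retriever/",
    filename_gold="nq2squad-dev.json",
    filename_negative="<negative_passages_file>",   # hypothetical, truncated in the config above
    remote_url="https://ext-haystack-retriever-eval.s3-eu-west-1.amazonaws.com/",
    embeddings_filenames=["wikipedia_passages_100k.pkl"],
    embeddings_dir="<embeddings_dir>/",             # hypothetical
    n_docs=100000,
    add_precomputed=True,
)
```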


@@ -1,3 +1,9 @@
# The benchmarks use
# - a variant of the Natural Questions Dataset (https://ai.google.com/research/NaturalQuestions) from Google Research
# licensed under CC BY-SA 3.0 (https://creativecommons.org/licenses/by-sa/3.0/)
# - the SQuAD 2.0 Dataset (https://rajpurkar.github.io/SQuAD-explorer/) from Rajpurkar et al.
# licensed under CC BY-SA 4.0 (https://creativecommons.org/licenses/by-sa/4.0/legalcode)
from retriever import benchmark_indexing, benchmark_querying
from reader import benchmark_reader
from utils import load_config


@@ -4,14 +4,16 @@ from haystack.document_store.memory import InMemoryDocumentStore
from haystack.document_store.elasticsearch import Elasticsearch, ElasticsearchDocumentStore
from haystack.document_store.faiss import FAISSDocumentStore
from haystack.retriever.sparse import ElasticsearchRetriever, TfidfRetriever
from haystack.retriever.dense import DensePassageRetriever
from haystack.retriever.dense import DensePassageRetriever, EmbeddingRetriever
from haystack.reader.farm import FARMReader
from haystack.reader.transformers import TransformersReader
from farm.file_utils import http_get
import logging
import subprocess
import time
import json
from typing import Union
from pathlib import Path
logger = logging.getLogger(__name__)
@@ -29,6 +31,7 @@ def get_document_store(document_store_type, similarity='dot_product'):
if os.path.exists("haystack_test.db"):
os.remove("haystack_test.db")
document_store = SQLDocumentStore(url="sqlite:///haystack_test.db")
assert document_store.get_document_count() == 0
elif document_store_type == "memory":
document_store = InMemoryDocumentStore()
elif document_store_type == "elasticsearch":
@@ -36,6 +39,7 @@ def get_document_store(document_store_type, similarity='dot_product'):
client = Elasticsearch()
client.indices.delete(index='haystack_test*', ignore=[404])
document_store = ElasticsearchDocumentStore(index="eval_document", similarity=similarity, timeout=3000)
assert document_store.get_document_count(index="eval_document") == 0
elif document_store_type in ("faiss_flat", "faiss_hnsw"):
if document_store_type == "faiss_flat":
index_type = "Flat"
@@ -55,10 +59,10 @@ def get_document_store(document_store_type, similarity='dot_product'):
document_store = FAISSDocumentStore(sql_url="postgresql://postgres:password@localhost:5432/haystack",
faiss_index_factory_str=index_type,
similarity=similarity)
assert document_store.get_document_count() == 0
else:
raise Exception(f"No document store fixture for '{document_store_type}'")
assert document_store.get_document_count() == 0
return document_store
def get_retriever(retriever_name, doc_store):
@@ -72,6 +76,11 @@ def get_retriever(retriever_name, doc_store):
passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
use_gpu=True,
use_fast_tokenizers=False)
if retriever_name == "sentence_transformers":
return EmbeddingRetriever(document_store=doc_store,
embedding_model="nq-distilbert-base-v1",
use_gpu=True,
model_format="sentence_transformers")
def get_reader(reader_name, reader_type, max_seq_len=384):
reader_class = None
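Putting the new option together, roughly as the querying loop above does (cosine similarity for sentence-transformers; `docs` as returned by `prepare_data`):

```python
# Sketch of the benchmark wiring for the new retriever option.
doc_store = get_document_store("elasticsearch", similarity="cosine")
retriever = get_retriever("sentence_transformers", doc_store)
index_to_doc_store(doc_store, docs, retriever)   # triggers embed_passages via update_embeddings
```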
@@ -87,7 +96,7 @@ def index_to_doc_store(doc_store, docs, retriever, labels=None):
doc_store.write_labels(labels, index=label_index)
# these lines are not run if the docs.embedding field is already populated with precomputed embeddings
# See the prepare_data() fn in the retriever benchmark script
elif callable(getattr(retriever, "embed_passages", None)) and docs[0].embedding is None:
if callable(getattr(retriever, "embed_passages", None)) and docs[0].embedding is None:
doc_store.update_embeddings(retriever, index=doc_index)
def load_config(config_filename, ci):
@@ -107,3 +116,25 @@ def load_config(config_filename, ci):
return params, filenames
def download_from_url(url: str, filepath: Union[str, Path]):
"""
Download from a URL to a local file. Skip files that already exist locally.
:param url: URL to download from
:param filepath: local path where the downloaded content shall be stored
:return: local path of the downloaded file
"""
logger.info(f"Downloading {url}")
# Create local folder
folder, filename = os.path.split(filepath)
if not os.path.exists(folder):
os.makedirs(folder)
# Download file if not present locally
if os.path.exists(filepath):
logger.info(f"Skipping {url} (exists locally)")
else:
logger.info(f"Downloading {url} to {filepath} ")
with open(filepath, "wb") as file:
http_get(url=url, temp_file=file)
return filepath
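A short usage example; the URL and path mirror the config values above, and on a second call the download is skipped because the file already exists:

```python
path = download_from_url(
    url="https://ext-haystack-retriever-eval.s3-eu-west-1.amazonaws.com/nq2squad-dev.json",
    filepath="../../data/retriever/nq2squad-dev.json",
)
```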