Create time and performance benchmarks for all readers and retrievers (#339)

* add time and perf benchmark for es

* Add retriever benchmarking

* Add Reader benchmarking

* add nq to squad conversion

* add conversion stats

* clean benchmarks

* Add link to dataset

* Update imports

* add first support for neg psgs

* Refactor test

* set max_seq_len

* cleanup benchmark

* begin retriever speed benchmarking

* Add support for retriever query index benchmarking

* improve reader eval, retriever speed benchmarking

* improve retriever speed benchmarking

* Add retriever accuracy benchmark

* Add neg doc shuffling

* Add top_n

* 3x speedup of SQL. add postgres docker run. make shuffle neg a param. add more logging

* Add models to sweep

* add option for faiss index type

* remove unneeded line

* change faiss to faiss_flat

* begin automatic benchmark script

* remove existing postgres docker for benchmarking

* Add data processing scripts

* Remove shuffle in script bc data already shuffled

* switch hnsw setup from 256 to 128

* change es similarity to dot product by default

* Error includes stack trace

* Change ES default timeout

* remove delete_docs() from timing for indexing

* Add support for website export

* update website on push to benchmarks

* add complete benchmarks results

* new json format

* removed NaN as it is not a valid JSON token

* fix benchmarking for faiss hnsw queries. do sql calls in update_embeddings() as batches

* update benchmarks for hnsw 128,20,80

* don't delete full index in delete_all_documents()

* update texts for charts

* update recall column for retriever

* change scale and add units to desc

* add units to legend

* add axis titles. update desc

* add html tags

Co-authored-by: deepset <deepset@Crenolape.localdomain>
Co-authored-by: Malte Pietsch <malte.pietsch@deepset.ai>
Co-authored-by: PiffPaffM <markuspaff.mp@gmail.com>
Branden Chan 2020-10-12 13:34:42 +02:00 committed by GitHub
parent 8edeb844f7
commit 1cebcb7dda
22 changed files with 1233 additions and 85 deletions

View File

@ -4,7 +4,7 @@ name: Deploy website
# events but only for the master branch
on:
push:
branches: [ master ]
branches: [ master, benchmarks ]
jobs:
# This workflow contains a single job called "build"

View File

@ -2,38 +2,18 @@
"chart_type": "BarChart",
"title": "Reader Performance",
"subtitle": "Time and Accuracy Benchmarks",
"description": "",
"description": "Performance benchmarks of different Readers that can be used off-the-shelf in Haystack. Some models are geared towards speed, while others are more performance-focused. Accuracy is measured as F1 score and speed as passages/sec (with passages of 384 tokens). Each Reader is benchmarked using the SQuAD v2.0 development set, which contains 11866 question answer pairs. When tokenized using the BERT tokenizer and split using a sliding window approach, these become 12350 passages that are passed into the model. We set <i>max_seq_len=384</i> and <i>doc_stride=128</i>. These benchmarking tests are run using an AWS p3.2xlarge instance with a Nvidia V100 GPU with this <a href='https://github.com/deepset-ai/haystack/blob/master/test/benchmarks/reader.py'>script</a>. Please note that we are using the FARMReader class rather than the TransformersReader class. Also, the F1 measure that is reported here is in fact calculated on token level, rather than word level as is done in the official SQuAD script.",
"bars": "horizontal",
"columns": [
"Model",
"Top 10 Accuracy",
"Time"
"F1",
"Speed (passages/sec)"
],
"data": [
{
"model": "RoBERTa",
"accuracy": 72.22222222222222,
"time": 56.033346766999784
},
{
"model": "MiniLM",
"accuracy": 38.88888888888889,
"time": 49.28621050500078
},
{
"model": "BERT base",
"accuracy": 31.48148148148148,
"time": 42.67899718899935
},
{
"model": "BERT large",
"accuracy": 53.70370370370371,
"time": 74.21550956300052
},
{
"model": "XLMR large",
"accuracy": 72.22222222222222,
"time": 76.56486266000047
}
{"F1": 80.67985794671885, "Model": "RoBERTa", "Speed": 92.3039712094936},
{"F1": 78.23306265318686, "Model": "MiniLM", "Speed": 98.62387044489223},
{"F1": 74.90271600053505, "Model": "BERT base", "Speed": 99.92750782409666},
{"F1": 82.64545708097472, "Model": "BERT large", "Speed": 39.529824033964466},
{"F1": 85.26275190954586, "Model": "XLM-RoBERTa", "Speed": 39.29142006004379}
]
}
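For orientation, the benchmark described above amounts to indexing the SQuAD v2.0 dev set into a DocumentStore and calling FARMReader.eval() on it, which is what the linked script does. A minimal single-model sketch (paths, index names and the doc_stride default are assumptions based on the description and on test/benchmarks/reader.py further down):

# Sketch only: assumes a local Elasticsearch instance and the SQuAD v2.0 dev set at this path
from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
from haystack.preprocessor.utils import eval_data_from_file
from haystack.reader.farm import FARMReader

doc_store = ElasticsearchDocumentStore(index="eval_document")
docs, labels = eval_data_from_file("../../data/squad20/dev-v2.0.json")
doc_store.write_documents(docs, index="eval_document")
doc_store.write_labels(labels, index="label")

reader = FARMReader("deepset/roberta-base-squad2", max_seq_len=384)  # doc_stride is assumed to default to 128
results = reader.eval(document_store=doc_store, doc_index="eval_document", label_index="label", device="cuda")
print(results["f1"], results["reader_time"], results["seconds_per_query"])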

View File

@ -0,0 +1,103 @@
{
"chart_type": "LineChart",
"title": "Retriever Accuracy",
"subtitle": "mAP at different number of docs",
"description": "Here you can see how the mean avg. precision (mAP) of the retriever decays as the number of documents increases. The set up is the same as the above querying benchmark except that a varying number of negative documents are used to fill the document store.",
"columns": [
"n_docs",
"BM25 / ElasticSearch",
"DPR / ElasticSearch",
"DPR / FAISS (flat)",
"DPR / FAISS (HSNW)"
],
"axis": [
{ "x": "Number of docs", "y": "mAP" }
],
"data": [
[
"model",
"n_docs",
"map"
],
[
"DPR / ElasticSearch",
1000,
0.929
],
[
"BM25 / ElasticSearch",
1000,
0.748
],
[
"BM25 / ElasticSearch",
10000,
0.6609999999999999
],
[
"DPR / ElasticSearch",
10000,
0.898
],
[
"DPR / ElasticSearch",
100000,
0.863
],
[
"BM25 / ElasticSearch",
100000,
0.56
],
[
"DPR / FAISS (flat)",
1000,
0.929
],
[
"DPR / FAISS (flat)",
10000,
0.898
],
[
"DPR / FAISS (flat)",
100000,
0.863
],
[
"DPR / FAISS (HSNW)",
1000,
0.929
],
[
"DPR / FAISS (HSNW)",
10000,
0.8959999999999999
],
[
"DPR / FAISS (HSNW)",
100000,
0.8490000000000001
],
[
"DPR / ElasticSearch",
500000,
0.805
],
[
"DPR / FAISS (flat)",
500000,
0.805
],
[
"BM25 / ElasticSearch",
500000,
0.452
],
[
"DPR / FAISS (HSNW)",
500000,
0.7659999999999999
]
]
}
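The mAP figures above are computed from the ranked lists returned by the retrievers. As a point of reference, a generic formulation of average precision over a top-k result list looks as follows; this is an illustrative sketch, not necessarily the exact computation inside BaseRetriever.eval():

from typing import List, Set

def average_precision(retrieved_ids: List[str], relevant_ids: Set[str]) -> float:
    # Accumulate precision at every rank where a relevant document appears
    hits, score = 0, 0.0
    for rank, doc_id in enumerate(retrieved_ids, start=1):
        if doc_id in relevant_ids:
            hits += 1
            score += hits / rank
    return score / min(len(relevant_ids), len(retrieved_ids)) if relevant_ids else 0.0

def mean_average_precision(runs: List[List[str]], golds: List[Set[str]]) -> float:
    # Mean of the per-query average precision values
    return sum(average_precision(r, g) for r, g in zip(runs, golds)) / len(runs)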

View File

@ -1,37 +1,53 @@
{
"chart_type": "BarChart",
"title": "Retriever Performance",
"subtitle": "Time and Accuracy Benchmarks",
"description": "",
"bars": "horizontal",
"columns": [
"Model",
"Recall",
"Index Time",
"Query Time"
],
"series": {
"s0": "recall",
"s1": "time",
"s2": "time"
"chart_type": "BarChart",
"title": "Retriever Performance",
"subtitle": "Time and Accuracy Benchmarks",
"description": "Comparison of the speed and accuracy of different DocumentStore / Retriever combinations on 100k documents. <b>Indexing speed</b> (in docs/sec) refers to how quickly Documents can be inserted into a DocumentStore. <b>Querying speed</b> (in queries/sec) refers to the speed at which the system returns relevant Documents when presented with a query.\n\nThe dataset used is Wikipedia, split into 100 word passages (from <a href='https://github.com/facebookresearch/DPR/blob/master/data/download_data.py'>here</a>)). \n\nFor querying, we use the Natural Questions development set in combination with the wiki passages. The Document Store is populated with the 100 word passages in which the answer spans occur (i.e. gold passages) as well as a random selection of 100 word passages in which the answer spans do not occur (i.e. negative passages). We take a total of 100k gold and negative passages. Query and document embedding are generated by the <i>\"facebook/dpr-question_encoder-single-nq-base\"</i> and <i>\"facebook/dpr-ctx_encoder-single-nq-base\"</i> models. The retriever returns 10 candidates and both the recall and mAP scores are calculated on these 10.\n\nFor FAISS HNSW, we use <i>n_links=128</i>, <i>efSearch=20</i> and <i>efConstruction=80</i>. Both index and query benchmarks are performed on an AWS P3.2xlarge instance which is accelerated by an Nvidia V100 GPU.",
"bars": "horizontal",
"columns": [
"Model",
"Recall",
"Index Speed (docs/sec)",
"Query Speed (queries/sec)"
],
"series": {
"s0": "recall",
"s1": "time",
"s2": "time"
},
"axes": {
"label": "recall",
"time_side": "top",
"time_label": "seconds"
},
"data": [
{
"model": "DPR / ElasticSearch",
"n_docs": 100000,
"recall": 95.8,
"index_speed": 79.54165185,
"query_speed": 6.5360000000000005
},
"axes": {
"label": "recall",
"time_side": "top",
"time_label": "seconds"
{
"model": "BM25 / ElasticSearch",
"n_docs": 100000,
"recall": 71.7,
"index_speed": 476.9143596,
"query_speed": 162.996
},
"data": [
{
"model": "BM25 (Elasticsearch)",
"recall": 1.0,
"index_time": 3.868239733,
"query_time": 0.47588522500000074
},
{
"model": "DPR (FAISS)",
"recall": 0.9629629629629629,
"index_time": 18.181850872000723,
"query_time": 1.358759985992947
}
]
}
{
"model": "DPR / FAISS (flat)",
"n_docs": 100000,
"recall": 95.8,
"index_speed": 107.8662479,
"query_speed": 5.044
},
{
"model": "DPR / FAISS (HSNW)",
"n_docs": 100000,
"recall": 94.0,
"index_speed": 92.24548333,
"query_speed": 12.815
}
]
}
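The DPR / FAISS combinations in this chart are built from the same classes used in test/benchmarks/utils.py further down. A condensed sketch (the sqlite URL and the sample document are placeholders; the benchmarks themselves use a Postgres container and Wikipedia passages):

from haystack import Document
from haystack.document_store.faiss import FAISSDocumentStore
from haystack.retriever.dense import DensePassageRetriever

doc_store = FAISSDocumentStore(sql_url="sqlite:///haystack_test.db",
                               faiss_index_factory_str="HNSW")
retriever = DensePassageRetriever(document_store=doc_store,
                                  query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
                                  passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
                                  use_gpu=True)

docs = [Document(text="Hamlet is a tragedy written by William Shakespeare.")]
doc_store.write_documents(docs)
doc_store.update_embeddings(retriever)  # embed all passages with the DPR context encoder
candidates = retriever.retrieve("Who wrote Hamlet?", top_k=10)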

View File

@ -0,0 +1,105 @@
{
"chart_type": "LineChart",
"title": "Retriever Speed",
"subtitle": "Query Speed at different number of docs",
"description": "Here you can see how the query speed of different Retriever / DocumentStore combinations scale as the number of documents increases. The set up is the same as the above querying benchmark except that a varying number of negative documents are used to fill the document store.",
"columns": [
"n_docs",
"BM25 / ElasticSearch",
"DPR / ElasticSearch",
"DPR / FAISS (flat)",
"DPR / FAISS (HSNW)"
],
"axis": [
{ "x": "Number of docs", "y": "Docs/sec" }
],
"data":
[
[
"model",
"n_docs",
"query_speed"
],
[
"DPR / ElasticSearch",
1000,
40.802
],
[
"BM25 / ElasticSearch",
1000,
232.97799999999998
],
[
"BM25 / ElasticSearch",
10000,
167.81
],
[
"DPR / ElasticSearch",
10000,
27.006999999999998
],
[
"DPR / ElasticSearch",
100000,
6.5360000000000005
],
[
"BM25 / ElasticSearch",
100000,
162.996
],
[
"DPR / FAISS (flat)",
1000,
40.048
],
[
"DPR / FAISS (flat)",
10000,
23.976999999999997
],
[
"DPR / FAISS (flat)",
100000,
5.044
],
[
"DPR / FAISS (HSNW)",
1000,
37.884
],
[
"DPR / FAISS (HSNW)",
10000,
33.421
],
[
"DPR / FAISS (HSNW)",
100000,
12.815
],
[
"DPR / ElasticSearch",
500000,
1.514
],
[
"DPR / FAISS (flat)",
500000,
1.091
],
[
"BM25 / ElasticSearch",
500000,
95.491
],
[
"DPR / FAISS (HSNW)",
500000,
3.259
]
]
}
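The queries/sec values plotted here are derived from the timing information that BaseRetriever.eval() now returns (see the retriever diff further down). A sketch, assuming a retriever whose document store has already been populated with documents and labels:

raw = retriever.eval()  # returns recall, map, retrieve_time, n_questions, top_k
queries_per_second = raw["n_questions"] / raw["retrieve_time"]
seconds_per_query = raw["retrieve_time"] / raw["n_questions"]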

View File

@ -40,6 +40,8 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
create_index: bool = True,
update_existing_documents: bool = False,
refresh_type: str = "wait_for",
similarity="dot_product",
timeout=30,
):
"""
A DocumentStore using Elasticsearch to store and query the documents for our search.
@ -75,9 +77,14 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
- 'wait_for' => continue only after changes are visible (slow, but safe)
- 'false' => continue directly (fast, but sometimes unintuitive behaviour when docs are not immediately available after ingestion)
More info at https://www.elastic.co/guide/en/elasticsearch/reference/6.8/docs-refresh.html
:param similarity: The similarity function used to compare document vectors. 'dot_product' is the default since it is
more performant with DPR embeddings. 'cosine' is recommended if you are using a Sentence BERT model.
:param timeout: Number of seconds after which an ElasticSearch request times out.
"""
self.client = Elasticsearch(hosts=[{"host": host, "port": port}], http_auth=(username, password),
scheme=scheme, ca_certs=ca_certs, verify_certs=verify_certs)
scheme=scheme, ca_certs=ca_certs, verify_certs=verify_certs, timeout=timeout)
# configure mappings to ES fields that will be used for querying / displaying results
if type(search_fields) == str:
@ -102,6 +109,12 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
self.label_index: str = label_index
self.update_existing_documents = update_existing_documents
self.refresh_type = refresh_type
if similarity == "cosine":
self.similarity_fn_name = "cosineSimilarity"
elif similarity == "dot_product":
self.similarity_fn_name = "dotProduct"
else:
raise Exception("Invalid value for similarity in ElasticSearchDocumentStore constructor. Choose between \'cosine\' and \'dot_product\'")
def _create_document_index(self, index_name):
"""
@ -420,14 +433,14 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
if not self.embedding_field:
raise RuntimeError("Please specify arg `embedding_field` in ElasticsearchDocumentStore()")
else:
# +1 in cosine similarity to avoid negative numbers
# +1 in similarity to avoid negative numbers (for cosine sim)
body= {
"size": top_k,
"query": {
"script_score": {
"query": {"match_all": {}},
"script": {
"source": f"cosineSimilarity(params.query_vector,'{self.embedding_field}') + 1.0",
"source": f"{self.similarity_fn_name}(params.query_vector,'{self.embedding_field}') + 1.0",
"params": {
"query_vector": query_emb.tolist()
}
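For reference, the two parameters added in this file can be passed when constructing the document store (host and index values below are placeholders); 'dot_product' is the default and suits DPR embeddings, while 'cosine' is recommended for Sentence-BERT style models:

from haystack.document_store.elasticsearch import ElasticsearchDocumentStore

doc_store = ElasticsearchDocumentStore(host="localhost",
                                       index="document",
                                       similarity="dot_product",  # or "cosine"
                                       timeout=30)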

View File

@ -1,9 +1,10 @@
import logging
from pathlib import Path
from typing import Union, List, Optional, Dict
from tqdm import tqdm
import faiss
import numpy as np
import random
from haystack import Document
from haystack.document_store.sql import SQLDocumentStore
@ -70,10 +71,11 @@ class FAISSDocumentStore(SQLDocumentStore):
if index_factory == "HNSW" and metric_type == faiss.METRIC_INNER_PRODUCT:
# faiss index factory doesn't give the same results for HNSW IP, therefore direct init.
# defaults here are similar to DPR codebase (good accuracy, but very high RAM consumption)
n_links = kwargs.get("n_links", 256)
n_links = kwargs.get("n_links", 128)
index = faiss.IndexHNSWFlat(vector_dim, n_links, metric_type)
index.hnsw.efSearch = kwargs.get("efSearch", 256)
index.hnsw.efConstruction = kwargs.get("efConstruction", 256)
index.hnsw.efSearch = kwargs.get("efSearch", 20)
index.hnsw.efConstruction = kwargs.get("efConstruction", 80)
logger.info(f"HNSW params: n_links: {n_links}, efSearch: {index.hnsw.efSearch}, efConstruction: {index.hnsw.efConstruction}")
else:
index = faiss.index_factory(vector_dim, index_factory, metric_type)
return index
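The new defaults correspond to constructing the HNSW index directly as below (a standalone sketch outside of Haystack; vector_dim=768 matches DPR embeddings):

import faiss

vector_dim = 768
index = faiss.IndexHNSWFlat(vector_dim, 128, faiss.METRIC_INNER_PRODUCT)  # n_links=128
index.hnsw.efSearch = 20
index.hnsw.efConstruction = 80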
@ -87,7 +89,8 @@ class FAISSDocumentStore(SQLDocumentStore):
:return:
"""
# vector index
self.faiss_index = self.faiss_index or self._create_new_index(vector_dim=self.vector_dim)
if not self.faiss_index:
raise ValueError("Couldn't find a FAISS index. Try to init the FAISSDocumentStore() again ...")
# doc + metadata index
index = index or self.index
document_objects = [Document.from_dict(d) if isinstance(d, dict) else d for d in documents]
@ -124,16 +127,22 @@ class FAISSDocumentStore(SQLDocumentStore):
self.faiss_index.reset()
index = index or self.index
documents = self.get_all_documents(index=index)
logger.info(f"Updating embeddings for {len(documents)} docs ...")
if len(documents) == 0:
logger.warning("Calling DocumentStore.update_embeddings() on an empty index")
self.faiss_index = None
return
logger.info(f"Updating embeddings for {len(documents)} docs...")
embeddings = retriever.embed_passages(documents) # type: ignore
assert len(documents) == len(embeddings)
for i, doc in enumerate(documents):
doc.embedding = embeddings[i]
vector_id_map = {}
for i in range(0, len(documents), self.index_buffer_size):
logger.info("Indexing embeddings and updating vectors_ids...")
for i in tqdm(range(0, len(documents), self.index_buffer_size)):
vector_id_map = {}
vector_id = self.faiss_index.ntotal
embeddings = [doc.embedding for doc in documents[i: i + self.index_buffer_size]]
embeddings = np.array(embeddings, dtype="float32")
@ -142,8 +151,7 @@ class FAISSDocumentStore(SQLDocumentStore):
for doc in documents[i: i + self.index_buffer_size]:
vector_id_map[doc.id] = vector_id
vector_id += 1
self.update_vector_ids(vector_id_map, index=index)
self.update_vector_ids(vector_id_map, index=index)
def train_index(self, documents: Optional[Union[List[dict], List[Document]]], embeddings: Optional[np.array] = None):
"""
@ -193,12 +201,11 @@ class FAISSDocumentStore(SQLDocumentStore):
documents = self.get_documents_by_vector_ids(vector_ids_for_query, index=index)
# assign query score to each document
#assign query score to each document
scores_for_vector_ids: Dict[str, float] = {str(v_id): s for v_id, s in zip(vector_id_matrix[0], score_matrix[0])}
for doc in documents:
doc.score = scores_for_vector_ids[doc.meta["vector_id"]] # type: ignore
doc.probability = (doc.score + 1) / 2
return documents
def save(self, file_path: Union[str, Path]):

View File

View File

@ -3,6 +3,7 @@ import multiprocessing
from pathlib import Path
from typing import List, Optional, Union, Dict, Any
from collections import defaultdict
from time import perf_counter
import numpy as np
from farm.data_handler.data_silo import DataSilo
@ -453,8 +454,10 @@ class FARMReader(BaseReader):
# Convert input format for FARM
farm_input = [v for v in d.values()]
n_queries = len([y for x in farm_input for y in x["qas"]])
# Create DataLoader that can be passed to the Evaluator
tic = perf_counter()
indices = range(len(farm_input))
dataset, tensor_names = self.inferencer.processor.dataset_from_dicts(farm_input, indices=indices)
data_loader = NamedDataLoader(dataset=dataset, batch_size=self.inferencer.batch_size, tensor_names=tensor_names)
@ -462,10 +465,15 @@ class FARMReader(BaseReader):
evaluator = Evaluator(data_loader=data_loader, tasks=self.inferencer.processor.tasks, device=device)
eval_results = evaluator.eval(self.inferencer.model)
toc = perf_counter()
reader_time = toc - tic
results = {
"EM": eval_results[0]["EM"],
"f1": eval_results[0]["f1"],
"top_n_accuracy": eval_results[0]["top_n_accuracy"]
"top_n_accuracy": eval_results[0]["top_n_accuracy"],
"top_n": self.inferencer.model.prediction_heads[0].n_best,
"reader_time": reader_time,
"seconds_per_query": reader_time / n_queries
}
return results

View File

View File

@ -1,6 +1,9 @@
from abc import ABC, abstractmethod
from typing import List
import logging
from time import perf_counter
from functools import wraps
from tqdm import tqdm
from haystack import Document
from haystack.document_store.base import BaseDocumentStore
@ -24,6 +27,18 @@ class BaseRetriever(ABC):
"""
pass
def timing(self, fn):
    @wraps(fn)
    def wrapper(*args, **kwargs):
        if "retrieve_time" not in self.__dict__:
            self.retrieve_time = 0
        tic = perf_counter()
        ret = fn(*args, **kwargs)
        toc = perf_counter()
        self.retrieve_time += toc - tic
        return ret
    return wrapper
def eval(
self,
label_index: str = "label",
@ -55,6 +70,8 @@ class BaseRetriever(ABC):
# Extract all questions for evaluation
filters = {"origin": [label_origin]}
timed_retrieve = self.timing(self.retrieve)
labels = self.document_store.get_all_labels_aggregated(index=label_index, filters=filters)
correct_retrievals = 0
@ -70,9 +87,10 @@ class BaseRetriever(ABC):
question_label_dict[label.question] = deduplicated_doc_ids
# Option 1: Open-domain evaluation by checking if the answer string is in the retrieved docs
logger.info("Performing eval queries...")
if open_domain:
for question, gold_answers in question_label_dict.items():
retrieved_docs = self.retrieve(question, top_k=top_k, index=doc_index)
for question, gold_answers in tqdm(question_label_dict.items()):
retrieved_docs = timed_retrieve(question, top_k=top_k, index=doc_index)
# check if correct doc in retrieved docs
for doc_idx, doc in enumerate(retrieved_docs):
for gold_answer in gold_answers:
@ -82,8 +100,8 @@ class BaseRetriever(ABC):
break
# Option 2: Strict evaluation by document ids that are listed in the labels
else:
for question, gold_ids in question_label_dict.items():
retrieved_docs = self.retrieve(question, top_k=top_k, index=doc_index)
for question, gold_ids in tqdm(question_label_dict.items()):
retrieved_docs = timed_retrieve(question, top_k=top_k, index=doc_index)
# check if correct doc in retrieved docs
for doc_idx, doc in enumerate(retrieved_docs):
for gold_id in gold_ids:
@ -99,4 +117,4 @@ class BaseRetriever(ABC):
logger.info((f"For {correct_retrievals} out of {number_of_questions} questions ({recall:.2%}), the answer was in"
f" the top-{top_k} candidate passages selected by the retriever."))
return {"recall": recall, "map": mean_avg_precision}
return {"recall": recall, "map": mean_avg_precision, "retrieve_time": self.retrieve_time, "n_questions": number_of_questions, "top_k": top_k}

View File

@ -0,0 +1,52 @@
import pickle
from pathlib import Path
from tqdm import tqdm
import json
n_passages = 1_000_000
embeddings_dir = Path("embeddings")
embeddings_filenames = [f"wikipedia_passages_{i}.pkl" for i in range(50)]
neg_passages_filename = "psgs_w100_minus_gold.tsv"
gold_passages_filename = "nq2squad-dev.json"
# Extract gold passage ids
passage_ids = []
gold_data = json.load(open(gold_passages_filename))["data"]
for d in gold_data:
    for p in d["paragraphs"]:
        passage_ids.append(str(p["passage_id"]))
print("gold_ids")
print(len(passage_ids))
print()
# Extract neg passage ids
with open(neg_passages_filename) as f:
    f.readline()  # Ignore column headers
    for _ in range(n_passages - len(passage_ids)):
        l = f.readline()
        passage_ids.append(str(l.split()[0]))
assert len(passage_ids) == len(set(passage_ids))
assert set([type(x) for x in passage_ids]) == {str}
passage_ids = set(passage_ids)
print("all_ids")
print(len(passage_ids))
print()
# Gather vectors for passages
ret = []
for ef in tqdm(embeddings_filenames):
    curr = pickle.load(open(embeddings_dir / ef, "rb"))
    for i, vec in curr:
        if i in passage_ids:
            ret.append((i, vec))
print("n_vectors")
print(len(ret))
print()
# Write vectors to file
with open(f"wikipedia_passages_{n_passages}.pkl", "wb") as f:
pickle.dump(ret, f)

View File

@ -0,0 +1,20 @@
import json
from tqdm import tqdm
import time
import random
random.seed(42)
lines = []
with open("psgs_w100_minus_gold_unshuffled.tsv") as f:
f.readline() # Remove column header
lines = [l for l in tqdm(f)]
tic = time.perf_counter()
random.shuffle(lines)
toc = time.perf_counter()
t = toc - tic
print(t)
with open("psgs_w100_minus_gold.tsv", "w") as f:
f.write("id\ttext\title\n")
for l in tqdm(lines):
f.write(l)

View File

@ -0,0 +1,281 @@
#!/usr/bin/python3
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""
DEEPSET DOCSTRING:
A modified version of the script from here:
https://github.com/google/retrieval-qa-eval/blob/master/nq_to_squad.py
Edits have been made by deepset in order to create a dev set for Haystack benchmarking.
Input should be the official NQ dev set (v1.0-simplified-nq-dev-all.jsonl.gz)
Expected numbers are:
Converted 7830 NQ records into 5678 SQuAD records.
Removed samples: yes/no: 177 multi_short: 648 non_para 1192 long_ans_only: 130 errors: 5
Removed annotations: long_answer: 4610 short_answer: 953 no_answer: ~1006
where:
multi_short - annotations where there are multiple disjoint short answers
non_para - where the annotation occurs in an html element that is not a paragraph
ORIGINAL DOCSTRING:
Convert the Natural Questions dataset into SQuAD JSON format.
To use this utility, first follow the directions at the URL below to download
the complete training dataset.
https://ai.google.com/research/NaturalQuestions/download
Next, run this program, specifying the data you wish to convert. For instance,
the invocation:
python nq_to_squad.py\
--data_pattern=/usr/local/data/tnq/v1.0/train/*.gz\
--output_file=/usr/local/data/tnq/v1.0/train.json
will process all training data and write the results into `train.json`. This
file can, in turn, be provided to squad_eval.py using the --squad argument.
"""
import argparse
import glob
import gzip
import json
import logging
import os
import re
# Dropped samples
n_yn = 0
n_ms = 0
n_non_p = 0
n_long_ans_only = 0
n_error = 0
# Dropped annotations
n_long_ans = 0
n_no_ans = 0
n_short = 0
def clean_text(start_token, end_token, doc_tokens, doc_bytes,
ignore_final_whitespace=True):
"""Remove HTML tags from a text span and reconstruct proper spacing."""
text = ""
for index in range(start_token, end_token):
token = doc_tokens[index]
if token["html_token"]:
continue
text += token["token"]
# Add a single space between two tokens iff there is at least one
# whitespace character between them (outside of an HTML tag). For example:
#
# token1 token2 ==> Add space.
# token1</B> <B>token2 ==> Add space.
# token1</A>token2 ==> No space.
# token1<A href="..." title="...">token2 ==> No space.
# token1<SUP>2</SUP>token2 ==> No space.
next_token = token
last_index = end_token if ignore_final_whitespace else end_token + 1
for next_token in doc_tokens[index + 1:last_index]:
if not next_token["html_token"]:
break
chars = (doc_bytes[token["end_byte"]:next_token["start_byte"]]
.decode("utf-8"))
# Since some HTML tags are missing from the token list, we count '<' and
# '>' to detect if we're inside a tag.
unclosed_brackets = 0
for char in chars:
if char == "<":
unclosed_brackets += 1
elif char == ">":
unclosed_brackets -= 1
elif unclosed_brackets == 0 and re.match(r"\s", char):
# Add a single space after this token.
text += " "
break
return text
def get_anno_type(annotation):
long_answer = annotation["long_answer"]
short_answers = annotation["short_answers"]
yes_no_answer = annotation["yes_no_answer"]
if len(short_answers) > 1:
return "multi_short"
elif yes_no_answer != "NONE":
return yes_no_answer
elif len(short_answers) == 1:
return "short_answer"
elif len(short_answers) == 0:
if long_answer["start_token"] == -1:
return "no_answer"
else:
return "long_answer"
def reduce_annotations(anno_types, answers):
"""
In cases where there is annotator disagreement, this fn picks either only the short_answers or only the no_answers,
depending on which is more numerous, with a bias towards picking short_answers.
Note: By this stage, all long_answer annotations and all samples with yes/no answer have been removed.
This leaves just no_answer and short_answers"""
for at in set(anno_types):
assert at in ("no_answer", "short_answer")
if anno_types.count("short_answer") >= anno_types.count("no_answer"):
majority = "short_answer"
is_impossible = False
else:
majority = "no_answer"
is_impossible = True
answers = [a for at, a in zip(anno_types, answers) if at == majority]
reduction = len(anno_types) - len(answers)
assert reduction < 3
if not is_impossible:
global n_no_ans
n_no_ans += reduction
else:
global n_short
n_short += reduction
answers = []
return answers, is_impossible
def nq_to_squad(record):
"""Convert a Natural Questions record to SQuAD format."""
doc_bytes = record["document_html"].encode("utf-8")
doc_tokens = record["document_tokens"]
question_text = record["question_text"]
question_text = question_text[0].upper() + question_text[1:] + "?"
answers = []
anno_types = []
for annotation in record["annotations"]:
anno_type = get_anno_type(annotation)
long_answer = annotation["long_answer"]
short_answers = annotation["short_answers"]
if anno_type.lower() in ["yes", "no"]:
global n_yn
n_yn += 1
return
# Skip examples that don't have exactly one short answer.
# Note: Consider including multi-span short answers.
if anno_type == "multi_short":
global n_ms
n_ms += 1
return
elif anno_type == "short_answer":
short_answer = short_answers[0]
# Skip examples corresponding to HTML blocks other than <P>.
long_answer_html_tag = doc_tokens[long_answer["start_token"]]["token"]
if long_answer_html_tag != "<P>":
global n_non_p
n_non_p += 1
return
answer = clean_text(
short_answer["start_token"], short_answer["end_token"], doc_tokens,
doc_bytes)
before_answer = clean_text(
0, short_answer["start_token"], doc_tokens,
doc_bytes, ignore_final_whitespace=False)
elif anno_type == "no_answer":
answer = ""
before_answer = ""
# Throw out long answer annotations
elif anno_type == "long_answer":
global n_long_ans
n_long_ans += 1
continue
anno_types.append(anno_type)
answer = {"answer_start": len(before_answer),
"text": answer}
answers.append(answer)
if len(answers) == 0:
global n_long_ans_only
n_long_ans_only += 1
return
answers, is_impossible = reduce_annotations(anno_types, answers)
paragraph = clean_text(
0, len(doc_tokens), doc_tokens,
doc_bytes)
return {"title": record["document_title"],
"paragraphs":
[{"context": paragraph,
"qas": [{"answers": answers,
"id": record["example_id"],
"question": question_text,
"is_impossible": is_impossible}]}]}
def main():
parser = argparse.ArgumentParser(
description="Convert the Natural Questions to SQuAD JSON format.")
parser.add_argument("--data_pattern", dest="data_pattern",
help=("A file pattern to match the Natural Questions "
"dataset."),
metavar="PATTERN", required=True)
parser.add_argument("--version", dest="version",
help="The version label in the output file.",
metavar="LABEL", default="nq-train")
parser.add_argument("--output_file", dest="output_file",
help="The name of the SQuAD JSON formatted output file.",
metavar="FILE", default="nq_as_squad.json")
args = parser.parse_args()
root = logging.getLogger()
root.setLevel(logging.DEBUG)
records = 0
nq_as_squad = {"version": args.version, "data": []}
for file in sorted(glob.iglob(args.data_pattern)):
logging.info("opening %s", file)
with gzip.GzipFile(file, "r") as f:
for line in f:
records += 1
nq_record = json.loads(line)
try:
squad_record = nq_to_squad(nq_record)
except:
squad_record = None
global n_error
n_error += 1
if squad_record:
nq_as_squad["data"].append(squad_record)
if records % 100 == 0:
logging.info("processed %s records", records)
print("Converted %s NQ records into %s SQuAD records." %
(records, len(nq_as_squad["data"])))
print(f"Removed samples: yes/no: {n_yn} multi_short: {n_ms} non_para {n_non_p} long_ans_only: {n_long_ans_only} errors: {n_error}")
print(f"Removed annotations: long_answer: {n_long_ans} short_answer: {n_short} no_answer: ~{n_no_ans}")
with open(args.output_file, "w") as f:
json.dump(nq_as_squad, f, indent=4)
if __name__ == "__main__":
main()

test/benchmarks/reader.py (new file, 54 lines)
View File

@ -0,0 +1,54 @@
from utils import get_document_store, index_to_doc_store, get_reader
from haystack.preprocessor.utils import eval_data_from_file
from pathlib import Path
import pandas as pd
reader_models = ["deepset/roberta-base-squad2", "deepset/minilm-uncased-squad2",
"deepset/bert-base-cased-squad2", "deepset/bert-large-uncased-whole-word-masking-squad2",
"deepset/xlm-roberta-large-squad2", "distilbert-base-uncased-distilled-squad"]
reader_types = ["farm"]
data_dir = Path("../../data/squad20")
filename = "dev-v2.0.json"
# Note that this number is approximate - it was calculated using Bert Base Cased
# This number could vary when using a different tokenizer
n_passages = 12350
doc_index = "eval_document"
label_index = "label"
def benchmark_reader():
    reader_results = []
    doc_store = get_document_store("elasticsearch")
    docs, labels = eval_data_from_file(data_dir / filename)
    index_to_doc_store(doc_store, docs, None, labels)
    for reader_name in reader_models:
        for reader_type in reader_types:
            try:
                reader = get_reader(reader_name, reader_type)
                results = reader.eval(document_store=doc_store,
                                      doc_index=doc_index,
                                      label_index=label_index,
                                      device="cuda")
                results["passages_per_second"] = n_passages / results["reader_time"]
                results["reader"] = reader_name
                results["error"] = ""
                reader_results.append(results)
            except Exception as e:
                results = {'EM': 0.,
                           'f1': 0.,
                           'top_n_accuracy': 0.,
                           'top_n': 0,
                           'reader_time': 0.,
                           "passages_per_second": 0.,
                           "seconds_per_query": 0.,
                           'reader': reader_name,
                           "error": e}
                reader_results.append(results)
    reader_df = pd.DataFrame.from_records(reader_results)
    reader_df.to_csv("reader_results.csv")


if __name__ == "__main__":
    benchmark_reader()

View File

@ -0,0 +1,6 @@
,EM,f1,top_n_accuracy,top_n,reader_time,seconds_per_query,passages_per_second,reader,error
0,0.7589752233271532,0.8067985794671885,0.9671329849991572,5,133.79706027999998,0.011275666634080564,92.30397120949361,deepset/roberta-base-squad2,
1,0.7359683128265633,0.7823306265318686,0.9714309792684982,5,125.22323393199997,0.010553112584864317,98.62387044489225,deepset/minilm-uncased-squad2,
2,0.700825889094893,0.7490271600053505,0.9585369964604753,5,123.58959278499992,0.010415438461570867,99.92750782409666,deepset/bert-base-cased-squad2,
3,0.7821506826226192,0.8264545708097472,0.9762346199224675,5,312.42233685099995,0.026329204184308102,39.529824033964466,deepset/bert-large-uncased-whole-word-masking-squad2,
4,0.8099612337771785,0.8526275190954586,0.9772459126917242,5,314.3179854819998,0.026488958830439897,39.29142006004379,deepset/xlm-roberta-large-squad2,

View File

@ -0,0 +1,101 @@
import json
import pandas as pd
from pprint import pprint
def reader():
model_rename_map = {
'deepset/roberta-base-squad2': "RoBERTa",
'deepset/minilm-uncased-squad2': "MiniLM",
'deepset/bert-base-cased-squad2': "BERT base",
'deepset/bert-large-uncased-whole-word-masking-squad2': "BERT large",
'deepset/xlm-roberta-large-squad2': "XLM-RoBERTa",
}
column_name_map = {
"f1": "F1",
"passages_per_second": "Speed",
"reader": "Model"
}
df = pd.read_csv("reader_results.csv")
df = df[["f1", "passages_per_second", "reader"]]
df["reader"] = df["reader"].map(model_rename_map)
df = df[list(column_name_map)]
df = df.rename(columns=column_name_map)
ret = [dict(row) for i, row in df.iterrows()]
print("Reader overview")
print(json.dumps(ret, indent=2))
def retriever():
column_name_map = {
"model": "model",
"n_docs": "n_docs",
"docs_per_second": "index_speed",
"queries_per_second": "query_speed",
"map": "map"
}
name_cleaning = {
"dpr": "DPR",
"elastic": "BM25",
"elasticsearch": "ElasticSearch",
"faiss": "FAISS",
"faiss_flat": "FAISS (flat)",
"faiss_hnsw": "FAISS (HSNW)"
}
index = pd.read_csv("retriever_index_results.csv")
query = pd.read_csv("retriever_query_results.csv")
df = pd.merge(index, query,
how="right",
left_on=["retriever", "doc_store", "n_docs"],
right_on=["retriever", "doc_store", "n_docs"])
df["retriever"] = df["retriever"].map(name_cleaning)
df["doc_store"] = df["doc_store"].map(name_cleaning)
df["model"] = df["retriever"] + " / " + df["doc_store"]
df = df[list(column_name_map)]
df = df.rename(columns=column_name_map)
print("Retriever overview")
print(retriever_overview(df))
print("Retriever MAP")
print(retriever_map(df))
print("Retriever Speed")
print(retriever_speed(df))
def retriever_map(df):
columns = ["model", "n_docs", "map"]
df = df[columns]
ret = [list(row) for i, row in df.iterrows()]
ret = [columns] + ret
return json.dumps(ret, indent=4)
def retriever_speed(df):
columns = ["model", "n_docs", "query_speed"]
df = df[columns]
ret = [list(row) for i, row in df.iterrows()]
ret = [columns] + ret
return json.dumps(ret, indent=4)
def retriever_overview(df, chosen_n_docs=100_000):
df = df[df["n_docs"] == chosen_n_docs]
ret = [dict(row) for i, row in df.iterrows()]
return json.dumps(ret, indent=2)
if __name__ == "__main__":
reader()
retriever()

View File

@ -0,0 +1,224 @@
import pandas as pd
from pathlib import Path
from time import perf_counter
from utils import get_document_store, get_retriever, index_to_doc_store
from haystack.preprocessor.utils import eval_data_from_file
from haystack import Document
import pickle
import time
from tqdm import tqdm
import logging
import datetime
import random
import traceback
logger = logging.getLogger(__name__)
logging.getLogger("haystack.retriever.base").setLevel(logging.WARN)
logging.getLogger("elasticsearch").setLevel(logging.WARN)
es_similarity = "dot_product"
retriever_doc_stores = [
# ("elastic", "elasticsearch"),
# ("dpr", "elasticsearch"),
# ("dpr", "faiss_flat"),
("dpr", "faiss_hnsw")
]
n_docs_options = [
1000,
10000,
100000,
500000,
]
# If set to None, querying will be run on all queries
n_queries = None
data_dir = Path("../../data/retriever")
filename_gold = "nq2squad-dev.json" # Found at s3://ext-haystack-retriever-eval
filename_negative = "psgs_w100_minus_gold.tsv" # Found at s3://ext-haystack-retriever-eval
embeddings_dir = Path("embeddings")
embeddings_filenames = [f"wikipedia_passages_1m.pkl"] # Found at s3://ext-haystack-retriever-eval
doc_index = "eval_document"
label_index = "label"
seed = 42
random.seed(42)
def prepare_data(data_dir, filename_gold, filename_negative, n_docs=None, n_queries=None, add_precomputed=False):
"""
filename_gold points to a squad format file.
filename_negative points to a csv file where the first column is doc_id and second is document text.
If add_precomputed is True, this fn will look in the embeddings files for precomputed embeddings to add to each Document
"""
gold_docs, labels = eval_data_from_file(data_dir / filename_gold)
# Reduce number of docs
gold_docs = gold_docs[:n_docs]
# Remove labels whose gold docs have been removed
doc_ids = [x.id for x in gold_docs]
labels = [x for x in labels if x.document_id in doc_ids]
# Filter labels down to n_queries
selected_queries = list(set(f"{x.document_id} | {x.question}" for x in labels))
selected_queries = selected_queries[:n_queries]
labels = [x for x in labels if f"{x.document_id} | {x.question}" in selected_queries]
n_neg_docs = max(0, n_docs - len(gold_docs))
neg_docs = prepare_negative_passages(data_dir, filename_negative, n_neg_docs)
docs = gold_docs + neg_docs
if add_precomputed:
docs = add_precomputed_embeddings(data_dir / embeddings_dir, embeddings_filenames, docs)
return docs, labels
def prepare_negative_passages(data_dir, filename_negative, n_docs):
if n_docs == 0:
return []
with open(data_dir / filename_negative) as f:
lines = []
_ = f.readline() # Skip column titles line
for _ in range(n_docs):
lines.append(f.readline()[:-1])
docs = []
for l in lines[:n_docs]:
id, text, title = l.split("\t")
d = {"text": text,
"meta": {"passage_id": int(id),
"title": title}}
d = Document(**d)
docs.append(d)
return docs
def benchmark_indexing():
retriever_results = []
for n_docs in n_docs_options:
for retriever_name, doc_store_name in retriever_doc_stores:
doc_store = get_document_store(doc_store_name, es_similarity=es_similarity)
retriever = get_retriever(retriever_name, doc_store)
docs, _ = prepare_data(data_dir, filename_gold, filename_negative, n_docs=n_docs)
tic = perf_counter()
index_to_doc_store(doc_store, docs, retriever)
toc = perf_counter()
indexing_time = toc - tic
print(indexing_time)
retriever_results.append({
"retriever": retriever_name,
"doc_store": doc_store_name,
"n_docs": n_docs,
"indexing_time": indexing_time,
"docs_per_second": n_docs / indexing_time,
"date_time": datetime.datetime.now()})
retriever_df = pd.DataFrame.from_records(retriever_results)
retriever_df = retriever_df.sort_values(by="retriever").sort_values(by="doc_store")
retriever_df.to_csv("retriever_index_results.csv")
doc_store.delete_all_documents(index=doc_index)
doc_store.delete_all_documents(index=label_index)
time.sleep(10)
del doc_store
del retriever
def benchmark_querying():
""" Benchmark the time it takes to perform querying. Doc embeddings are loaded from file."""
retriever_results = []
for n_docs in n_docs_options:
for retriever_name, doc_store_name in retriever_doc_stores:
try:
logger.info(f"##### Start run: {retriever_name}, {doc_store_name}, {n_docs} docs ##### ")
doc_store = get_document_store(doc_store_name, es_similarity=es_similarity)
retriever = get_retriever(retriever_name, doc_store)
add_precomputed = retriever_name in ["dpr"]
# For DPR, precomputed embeddings are loaded from file
docs, labels = prepare_data(data_dir,
filename_gold,
filename_negative,
n_docs=n_docs,
n_queries=n_queries,
add_precomputed=add_precomputed)
logger.info("Start indexing...")
index_to_doc_store(doc_store, docs, retriever, labels)
logger.info("Start queries...")
raw_results = retriever.eval()
results = {
"retriever": retriever_name,
"doc_store": doc_store_name,
"n_docs": n_docs,
"n_queries": raw_results["n_questions"],
"retrieve_time": raw_results["retrieve_time"],
"queries_per_second": raw_results["n_questions"] / raw_results["retrieve_time"],
"seconds_per_query": raw_results["retrieve_time"]/ raw_results["n_questions"],
"recall": raw_results["recall"],
"map": raw_results["map"],
"top_k": raw_results["top_k"],
"date_time": datetime.datetime.now(),
"error": None
}
doc_store.delete_all_documents()
time.sleep(5)
del doc_store
del retriever
except Exception as e:
tb = traceback.format_exc()
results = {
"retriever": retriever_name,
"doc_store": doc_store_name,
"n_docs": n_docs,
"n_queries": 0,
"retrieve_time": 0.,
"queries_per_second": 0.,
"seconds_per_query": 0.,
"recall": 0.,
"map": 0.,
"top_k": 0,
"date_time": datetime.datetime.now(),
"error": str(tb)
}
logger.info(results)
retriever_results.append(results)
retriever_df = pd.DataFrame.from_records(retriever_results)
retriever_df = retriever_df.sort_values(by="retriever").sort_values(by="doc_store")
retriever_df.to_csv("retriever_query_results.csv")
def add_precomputed_embeddings(embeddings_dir, embeddings_filenames, docs):
ret = []
id_to_doc = {x.meta["passage_id"]: x for x in docs}
for ef in embeddings_filenames:
logger.info(f"Adding precomputed embeddings from {embeddings_dir / ef}")
filename = embeddings_dir / ef
embeds = pickle.load(open(filename, "rb"))
for i, vec in embeds:
if int(i) in id_to_doc:
curr = id_to_doc[int(i)]
curr.embedding = vec
ret.append(curr)
# In the official DPR repo, there are only 20594995 precomputed embeddings for 21015324 wikipedia passages
# If there isn't an embedding for a given doc, we remove it here
ret = [x for x in ret if x.embedding is not None]
logger.info(f"Embeddings loaded for {len(ret)}/{len(docs)} docs")
return ret
if __name__ == "__main__":
# benchmark_indexing()
benchmark_querying()

View File

@ -0,0 +1,13 @@
retriever,doc_store,n_docs,indexing_time,docs_per_second,date_time,Notes
dpr,elasticsearch,1000,14.16526405,70.59522482,2020-10-08 10:30:56,
elastic,elasticsearch,1000,5.805040058,172.2640998,2020-10-08 10:30:25,
elastic,elasticsearch,10000,22.56448254,443.1743553,2020-10-08 13:01:09,
dpr,elasticsearch,10000,126.2442168,79.21154929,2020-10-08 13:03:32,
dpr,elasticsearch,100000,1257.202958,79.54165185,2020-10-08 13:28:16,
elastic,elasticsearch,100000,209.681252,476.9143596,2020-10-08 13:07:05,
dpr,faiss_flat,1000,8.223732258,121.5992895,44112.24392,
dpr,faiss_flat,10000,89.72649358,111.4498026,44112.24663,
dpr,faiss_flat,100000,927.0740565,107.8662479,44112.56656,
dpr,faiss_hnsw,1000,8.86507699,112.8021788,44113.37262,"hnsw 128,20,80"
dpr,faiss_hnsw,10000,100.1804832,99.81984193,44113.37413,"hnsw 128,20,80"
dpr,faiss_hnsw,100000,1084.063917,92.24548333,44113.38721,"hnsw 128,20,80"

View File

@ -0,0 +1,17 @@
retriever,doc_store,n_docs,n_queries,retrieve_time,queries_per_second,seconds_per_query,recall,map,top_k,date_time,error,name,note
dpr,elasticsearch,1000,1085,26.592,40.802,0.025,0.991,0.929,10,2020-10-07 15:06:57,,dpr-elasticsearch,
dpr,elasticsearch,10000,5791,214.425,27.007,0.037,0.975,0.898,10,2020-10-07 15:11:35,,dpr-elasticsearch,
dpr,elasticsearch,100000,5791,886.045,6.536,0.153,0.958,0.863,10,2020-10-07 15:30:52,,dpr-elasticsearch,
dpr,elasticsearch,500000,5791,3824.624,1.514,0.660,0.930,0.805,10,2020-10-07 17:44:02,,dpr-elasticsearch,
dpr,faiss_flat,1000,1085,27.092,40.048,0.025,0.991,0.929,10,2020-10-07 13:06:35,,dpr-faiss_flat,
dpr,faiss_flat,10000,5791,241.524,23.977,0.042,0.975,0.898,10,2020-10-07 13:17:21,,dpr-faiss_flat,
dpr,faiss_flat,100000,5791,1148.181,5.044,0.198,0.958,0.863,10,2020-10-07 14:04:51,,dpr-faiss_flat,
dpr,faiss_flat,500000,5791,5308.016,1.091,0.917,0.930,0.805,10,2020-10-08 10:01:32,,dpr-faiss_flat,
elastic,elasticsearch,1000,1085,4.657,232.978,0.004,0.891,0.748,10,2020-10-07 13:04:47,,elastic-elasticsearch,
elastic,elasticsearch,10000,5791,34.509,167.810,0.006,0.811,0.661,10,2020-10-07 13:07:52,,elastic-elasticsearch,
elastic,elasticsearch,100000,5791,35.529,162.996,0.006,0.717,0.560,10,2020-10-07 13:21:48,,elastic-elasticsearch,
elastic,elasticsearch,500000,5791,60.645,95.491,0.010,0.624,0.452,10,2020-10-07 16:14:52,,elastic-elasticsearch,
dpr,faiss_hnsw,1000,1085,28.640,37.884,0.026,0.991,0.929,10,2020-10-09 07:19:29,,dpr-faiss_hnsw,"128,20,80"
dpr,faiss_hnsw,10000,5791,173.272,33.421,0.030,0.972,0.896,10,2020-10-09 07:23:28,,dpr-faiss_hnsw,"128,20,80"
dpr,faiss_hnsw,100000,5791,451.884,12.815,0.078,0.940,0.849,10,2020-10-09 07:37:56,,dpr-faiss_hnsw,"128,20,80"
dpr,faiss_hnsw,500000,5791,1777.023,3.259,0.307,0.882,0.766,10,2020-10-09,,dpr-faiss_hnsw,"128,20,80"

test/benchmarks/run.py (new file, 24 lines)
View File

@ -0,0 +1,24 @@
from retriever import benchmark_indexing, benchmark_querying
from reader import benchmark_reader
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--reader', default=False, action="store_true",
help='Perform Reader benchmarks')
parser.add_argument('--retriever_index', default=False, action="store_true",
help='Perform Retriever indexing benchmarks')
parser.add_argument('--retriever_query', default=False, action="store_true",
help='Perform Retriever querying benchmarks')
parser.add_argument('--ci', default=False, action="store_true",
help='Perform a smaller subset of benchmarks that are quicker to run')
args = parser.parse_args()
if args.retriever_index:
    benchmark_indexing()
if args.retriever_query:
    benchmark_querying()
if args.reader:
    benchmark_reader()
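With these flags, the script would be invoked as, for example, 'python run.py --reader' or 'python run.py --retriever_index --retriever_query'.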

test/benchmarks/utils.py (new file, 106 lines)
View File

@ -0,0 +1,106 @@
import os
from haystack import Document
from haystack.document_store.sql import SQLDocumentStore
from haystack.document_store.memory import InMemoryDocumentStore
from haystack.document_store.elasticsearch import Elasticsearch, ElasticsearchDocumentStore
from haystack.document_store.faiss import FAISSDocumentStore
from haystack.retriever.sparse import ElasticsearchRetriever, TfidfRetriever
from haystack.retriever.dense import DensePassageRetriever
from haystack.reader.farm import FARMReader
from haystack.reader.transformers import TransformersReader
from time import perf_counter
import pandas as pd
import json
import logging
import subprocess
import time
from pathlib import Path
logger = logging.getLogger(__name__)
reader_models = ["deepset/roberta-base-squad2", "deepset/minilm-uncased-squad2", "deepset/bert-base-cased-squad2", "deepset/bert-large-uncased-whole-word-masking-squad2", "deepset/xlm-roberta-large-squad2"]
reader_types = ["farm"]
data_dir_reader = Path("../../data/squad20")
filename_reader = "dev-v2.0.json"
doc_index = "eval_document"
label_index = "label"
def get_document_store(document_store_type, es_similarity='cosine'):
""" TODO This method is taken from test/conftest.py but maybe should be within Haystack.
Perhaps a class method of DocStore that just takes string for type of DocStore"""
if document_store_type == "sql":
if os.path.exists("haystack_test.db"):
os.remove("haystack_test.db")
document_store = SQLDocumentStore(url="sqlite:///haystack_test.db")
elif document_store_type == "memory":
document_store = InMemoryDocumentStore()
elif document_store_type == "elasticsearch":
# make sure we start from a fresh index
client = Elasticsearch()
client.indices.delete(index='haystack_test*', ignore=[404])
document_store = ElasticsearchDocumentStore(index="eval_document", similarity=es_similarity)
elif document_store_type in("faiss_flat", "faiss_hnsw"):
if document_store_type == "faiss_flat":
index_type = "Flat"
elif document_store_type == "faiss_hnsw":
index_type = "HNSW"
#TEMP FIX for issue with deleting docs
# status = subprocess.run(
# ['docker rm -f haystack-postgres'],
# shell=True)
# time.sleep(3)
# try:
# document_store = FAISSDocumentStore(sql_url="postgresql://postgres:password@localhost:5432/haystack",
# faiss_index_factory_str=index_type)
# except:
# Launch a postgres instance & create empty DB
# logger.info("Didn't find Postgres. Start a new instance...")
status = subprocess.run(
['docker rm -f haystack-postgres'],
shell=True)
time.sleep(1)
status = subprocess.run(
['docker run --name haystack-postgres -p 5432:5432 -e POSTGRES_PASSWORD=password -d postgres'],
shell=True)
time.sleep(3)
status = subprocess.run(
['docker exec -it haystack-postgres psql -U postgres -c "CREATE DATABASE haystack;"'], shell=True)
time.sleep(1)
document_store = FAISSDocumentStore(sql_url="postgresql://postgres:password@localhost:5432/haystack",
faiss_index_factory_str=index_type)
else:
raise Exception(f"No document store fixture for '{document_store_type}'")
return document_store
def get_retriever(retriever_name, doc_store):
if retriever_name == "elastic":
return ElasticsearchRetriever(doc_store)
if retriever_name == "tfidf":
return TfidfRetriever(doc_store)
if retriever_name == "dpr":
return DensePassageRetriever(document_store=doc_store,
query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
use_gpu=True)
def get_reader(reader_name, reader_type, max_seq_len=384):
reader_class = None
if reader_type == "farm":
reader_class = FARMReader
elif reader_type == "transformers":
reader_class = TransformersReader
return reader_class(reader_name, top_k_per_candidate=4, max_seq_len=max_seq_len)
def index_to_doc_store(doc_store, docs, retriever, labels=None):
doc_store.write_documents(docs, doc_index)
if labels:
doc_store.write_labels(labels, index=label_index)
# these lines are not run if the docs.embedding field is already populated with precomputed embeddings
# See the prepare_data() fn in the retriever benchmark script
elif callable(getattr(retriever, "embed_passages", None)) and docs[0].embedding is None:
doc_store.update_embeddings(retriever, index=doc_index)