# Source: haystack/test/benchmarks/retriever_simplified.py (web-view snapshot, 2021-06-04 11:05:18 +02:00)
"""
This script performs the same query benchmarking as `retriever.py`, but with fewer of the loops that
iterate over all the parameters, so that it is easier to inspect what is happening.
"""
from haystack.document_store import MilvusDocumentStore, FAISSDocumentStore
from haystack.retriever import DensePassageRetriever
from retriever import prepare_data
import datetime
from pprint import pprint
from milvus import IndexType
2021-06-04 11:05:18 +02:00
from utils import get_document_store
2021-06-04 11:05:18 +02:00
def benchmark_querying(index_type, n_docs=100_000, similarity="dot_product"):
    """
    Run one query benchmark against a single document store and print the metrics.

    The function loads the evaluation data via `prepare_data`, writes documents and
    labels into the store returned by `get_document_store`, evaluates a
    `DensePassageRetriever` on them and pretty-prints a summary dict.

    :param index_type: Identifier forwarded to `get_document_store` as
                       `document_store_type` (this script uses "milvus_flat",
                       "milvus_hnsw", "faiss_flat" and "faiss_hnsw").
    :param n_docs: Number of documents requested from `prepare_data`.
    :param similarity: Similarity setting forwarded to `get_document_store`.
    :return: The dict of metrics that is also printed.
    :raises ZeroDivisionError: If the evaluation reports zero questions or a zero
                               retrieve time (the per-query rates cannot be computed).
    """
    doc_index = "document"
    label_index = "label"

    docs, labels = prepare_data(
        data_dir="data/",
        filename_gold="nq2squad-dev.json",
        filename_negative="psgs_w100_minus_gold_100k.tsv",
        remote_url="https://ext-haystack-retriever-eval.s3-eu-west-1.amazonaws.com/",
        embeddings_filenames=["wikipedia_passages_100k.pkl"],
        embeddings_dir="embeddings/",
        n_docs=n_docs,
        add_precomputed=True
    )

    # Store construction (including any index-specific parameters) is delegated
    # to the shared benchmark helper instead of being hand-built here.
    doc_store = get_document_store(
        document_store_type=index_type,
        similarity=similarity
    )

    try:
        doc_store.write_documents(documents=docs, index=doc_index)
        doc_store.write_labels(labels=labels, index=label_index)

        retriever = DensePassageRetriever(
            document_store=doc_store,
            query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
            passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
            use_gpu=True,
            use_fast_tokenizers=True
        )
        raw_results = retriever.eval(label_index=label_index, doc_index=doc_index)

        n_queries = raw_results["n_questions"]
        retrieve_time = raw_results["retrieve_time"]
        results = {
            "n_queries": n_queries,
            "retrieve_time": retrieve_time,
            "queries_per_second": n_queries / retrieve_time,
            "seconds_per_query": retrieve_time / n_queries,
            # recall and map are scaled by 100 for reporting.
            "recall": raw_results["recall"] * 100,
            "map": raw_results["map"] * 100,
            "top_k": raw_results["top_k"],
            "date_time": datetime.datetime.now(),
            "error": None
        }
        pprint(results)
        return results
    finally:
        # Always empty both indices, even when writing or evaluation fails, so a
        # failed run does not leave documents behind for the next benchmark.
        doc_store.delete_all_documents(index=doc_index)
        doc_store.delete_all_documents(index=label_index)
if __name__ == "__main__":
    # Settings shared by every run below.
    similarity = "l2"
    n_docs = 1000

    # Benchmark each document store variant in turn, in a fixed order.
    for store_type in ("milvus_flat", "milvus_hnsw", "faiss_flat", "faiss_hnsw"):
        benchmark_querying(index_type=store_type, similarity=similarity, n_docs=n_docs)